In [1]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
from torchinfo import summary
if torch.cuda.is_available():
    print(torch.cuda.get_device_name())
else:
    print("CPU")

NVIDIA GeForce RTX 3050 6GB Laptop GPU


In [2]:
class TaskConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, task_ids, 
                 stride=1, padding=0, dilation=1, groups=1, bias=True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size)
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.task_ids = task_ids
        
        self.weight = nn.Parameter(torch.randn(out_channels, in_channels, *self.kernel_size) * 0.01)
        if bias:
            self.bias = nn.Parameter(torch.randn(out_channels) * 0.01)
        else:
            self.register_parameter('bias', None)
            
        self.register_buffer("channel_task_ids", torch.tensor(task_ids).long())
    
    def forward(self, x, task_id):
        out = F.conv2d(x, self.weight, self.bias, self.stride, 
                      self.padding, self.dilation, self.groups)
        
        mask = (self.channel_task_ids == task_id).float()
        self._mask = self.channel_task_ids == task_id
        mask = mask.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
        
        return out * mask
    
    def apply_gradient_mask(self):
        if self.weight.grad is not None:
            inactive = ~self._mask
            self.weight.grad[inactive] = 0
            if self.bias is not None:
                self.bias.grad[inactive] = 0

In [3]:
class TaskLinear(nn.Module):
    def __init__(self, in_features, out_features, task_ids):
        super().__init__()
        self.in_features = in_features 
        self.out_features = out_features
        self.task_ids = task_ids
        self.weight = nn.Parameter(torch.randn(out_features, in_features) * 0.01)
        self.bias = nn.Parameter(torch.randn(out_features) * 0.01)
        self.register_buffer("neuron_ids", torch.tensor(task_ids).long())

    def forward(self, x, task_id):
        out = F.linear(x, self.weight, self.bias)
        mask = (self.neuron_ids == task_id).float().unsqueeze(0)
        self._mask = self.neuron_ids == task_id
        return out * mask

    def apply_gradient_mask(self):
        if self.weight.grad is not None:
            inactive = ~self._mask
            self.weight.grad[inactive] = 0
            self.bias.grad[inactive] = 0

In [4]:
class MultiTaskConvSynapticNet(nn.Module):
    def __init__(self, input_channels=3, initial_conv_channels=64, 
                 hidden_size=512, output_size=10, num_conv_layers=4):
        super().__init__()
        self.input_channels = input_channels
        self.conv_channels = initial_conv_channels
        self.hidden_size = hidden_size
        self.num_conv_layers = num_conv_layers
        
        # Convolutional layers
        self.conv_layers = nn.ModuleList()
        
        # First conv layer (handles different input channels for different datasets)
        task_ids = [1] * initial_conv_channels
        self.conv_layers.append(
            TaskConv2d(input_channels, initial_conv_channels, 3, task_ids, padding=1)
        )
        
        # Additional conv layers
        for i in range(1, num_conv_layers):
            task_ids = [1] * initial_conv_channels
            self.conv_layers.append(
                TaskConv2d(initial_conv_channels, initial_conv_channels, 3, task_ids, padding=1)
            )
        
        # MaxPool and dropout
        self.pool = nn.MaxPool2d(2, 2)
        self.adaptive_pool = nn.AdaptiveAvgPool2d((4, 4))  # Ensures consistent output size
        
        # Calculate flattened size after adaptive pooling
        flattened_size = initial_conv_channels * 4 * 4
        
        # Fully connected layers
        self.fc1 = TaskLinear(flattened_size, hidden_size, [1] * hidden_size)
        self.fc2 = TaskLinear(hidden_size, output_size, [1] * output_size)
        
        # Activation functions
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.2)
    
    def grow(self, grow_conv_channels=64, grow_hidden=512, grow_output=10, task_id=2):
        """Grow the network for a new task"""
        new_conv_channels = self.conv_channels + grow_conv_channels
        new_hidden_size = self.hidden_size + grow_hidden
        new_output_size = self.fc2.out_features + grow_output
        
        new_conv_task_ids = [task_id] * grow_conv_channels
        new_hidden_task_ids = [task_id] * grow_hidden
        new_output_task_ids = [task_id] * grow_output
        
        # Grow convolutional layers
        new_conv_layers = nn.ModuleList()
        
        for i, old_conv in enumerate(self.conv_layers):
            old_task_ids = old_conv.task_ids
            combined_task_ids = old_task_ids + new_conv_task_ids
            
            if i == 0:
                new_conv = TaskConv2d(
                    self.input_channels, new_conv_channels, 3, 
                    combined_task_ids, padding=1
                )
            else:
                new_conv = TaskConv2d(
                    new_conv_channels, new_conv_channels, 3,
                    combined_task_ids, padding=1
                )
            
            # Copy old weights
            with torch.no_grad():
                old_out_channels = old_conv.out_channels
                if i == 0:
                    new_conv.weight[:old_out_channels].copy_(old_conv.weight)
                else:
                    new_conv.weight[:old_out_channels, :self.conv_channels].copy_(old_conv.weight)
                new_conv.bias[:old_out_channels].copy_(old_conv.bias)
            
            new_conv_layers.append(new_conv)
        
        self.conv_layers = new_conv_layers
        
        # Update flattened size calculation
        new_flattened_size = new_conv_channels * 4 * 4
        
        # Grow FC1 layer
        old_fc1_task_ids = self.fc1.task_ids
        combined_fc1_task_ids = old_fc1_task_ids + new_hidden_task_ids
        
        new_fc1 = TaskLinear(new_flattened_size, new_hidden_size, combined_fc1_task_ids)
        with torch.no_grad():
            old_features = self.fc1.in_features
            old_hidden = self.fc1.out_features
            new_fc1.weight[:old_hidden, :old_features].copy_(self.fc1.weight)
            new_fc1.bias[:old_hidden].copy_(self.fc1.bias)
        
        self.fc1 = new_fc1
        
        # Grow FC2 layer
        old_fc2_task_ids = self.fc2.task_ids
        combined_fc2_task_ids = old_fc2_task_ids + new_output_task_ids
        
        new_fc2 = TaskLinear(new_hidden_size, new_output_size, combined_fc2_task_ids)
        with torch.no_grad():
            old_output = self.fc2.out_features
            old_hidden = self.fc2.in_features
            new_fc2.weight[:old_output, :old_hidden].copy_(self.fc2.weight)
            new_fc2.bias[:old_output].copy_(self.fc2.bias)
        
        self.fc2 = new_fc2
        
        # Update dimensions
        self.conv_channels = new_conv_channels
        self.hidden_size = new_hidden_size
    
    def forward(self, x, task_id):
        # Convolutional layers with pooling
        for i, conv_layer in enumerate(self.conv_layers):
            x = self.relu(conv_layer(x, task_id))
            if i < len(self.conv_layers) - 1:  # Don't pool after last conv layer
                x = self.pool(x)
        
        # Adaptive pooling to handle different input sizes
        x = self.adaptive_pool(x)
        
        # Flatten
        x = x.view(x.size(0), -1)
        
        # Fully connected layers
        x = self.dropout(self.relu(self.fc1(x, task_id)))
        x = self.fc2(x, task_id)
        
        return x
    
    def apply_task_gradient_mask(self):
        for conv_layer in self.conv_layers:
            conv_layer.apply_gradient_mask()
        self.fc1.apply_gradient_mask()
        self.fc2.apply_gradient_mask()


In [5]:
class TaskDataset(Dataset):
    def __init__(self, dataset, task_id, label_offset=0, input_size=(32, 32)):
        self.dataset = dataset
        self.task_id = task_id
        self.label_offset = label_offset
        self.input_size = input_size

    def __getitem__(self, idx):
        x, y = self.dataset[idx]
        
        # Handle different input sizes (MNIST vs CIFAR)
        if x.shape[1:] != self.input_size:
            # Resize if needed (for MNIST -> 32x32)
            resize_transform = transforms.Resize(self.input_size)
            x = resize_transform(x)
        
        # Handle grayscale vs RGB (MNIST vs CIFAR)
        if x.shape[0] == 1 and self.input_size == (32, 32):
            # Convert grayscale to RGB for consistency
            x = x.repeat(3, 1, 1)
        
        return x, y + self.label_offset, self.task_id

    def __len__(self):
        return len(self.dataset)

In [6]:
def get_all_dataloaders(batch_size=64):
    # Different transforms for grayscale vs color datasets
    transform_gray = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
    
    transform_color = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    # Load datasets
    mnist = datasets.MNIST(root='./data', train=True, download=True, transform=transform_gray)
    fmnist = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform_gray)
    cifar10 = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_color)
    cifar100 = datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_color)

    # Create task datasets with appropriate label offsets
    mnist_task = TaskDataset(mnist, task_id=1, label_offset=0, input_size=(32, 32))
    fmnist_task = TaskDataset(fmnist, task_id=2, label_offset=10, input_size=(32, 32))
    cifar10_task = TaskDataset(cifar10, task_id=3, label_offset=20, input_size=(32, 32))
    cifar100_task = TaskDataset(cifar100, task_id=4, label_offset=30, input_size=(32, 32))

    # Create data loaders
    mnist_loader = DataLoader(mnist_task, batch_size=batch_size, shuffle=True)
    fmnist_loader = DataLoader(fmnist_task, batch_size=batch_size, shuffle=True)
    cifar10_loader = DataLoader(cifar10_task, batch_size=batch_size, shuffle=True)
    cifar100_loader = DataLoader(cifar100_task, batch_size=batch_size, shuffle=True)

    return mnist_loader, fmnist_loader, cifar10_loader, cifar100_loader


In [7]:
# Training function
def train_task(model, task_id, dataloader, epochs=5, lr=1e-3):
    model.train()
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr)
    criterion = nn.CrossEntropyLoss()
    
    for epoch in range(epochs):
        total_loss = 0
        correct = 0
        total = 0
        
        for batch_idx, (x, y, tid) in enumerate(dataloader):
            x, y = x.to(device), y.to(device)
            
            optimizer.zero_grad()
            out = model(x, task_id)
            loss = criterion(out, y)
            loss.backward()
            
            model.apply_task_gradient_mask()
            optimizer.step()
            
            total_loss += loss.item()
            preds = torch.argmax(out, dim=1)
            correct += (preds == y).sum().item()
            total += y.size(0)
            
            # Print progress every 200 batches
            if batch_idx % 200 == 0:
                batch_acc = 100 * (preds == y).sum().item() / y.size(0)
                print(f"[Task {task_id}] Epoch {epoch+1}, Batch {batch_idx} | Loss: {loss.item():.4f} | Batch Acc: {batch_acc:.2f}%")
        
        acc = 100 * correct / total
        print(f"[Task {task_id}] Epoch {epoch+1} | Loss: {total_loss:.4f} | Accuracy: {acc:.2f}%")


In [8]:
# Evaluation function
def evaluate_task(model, task_id, dataloader, task_name=""):
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for x, y, tid in dataloader:
            x, y = x.to(device), y.to(device)
            out = model(x, task_id)
            preds = torch.argmax(out, dim=1)
            correct += (preds == y).sum().item()
            total += y.size(0)

    acc = 100 * correct / total
    print(f"[Task {task_id}] {task_name} Evaluation Accuracy: {acc:.2f}%")
    return acc

In [9]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [10]:
model = MultiTaskConvSynapticNet(
        input_channels=3,  # RGB for all datasets (MNIST will be converted)
        initial_conv_channels=64, 
        hidden_size=512, 
        output_size=10,  # Start with 10 classes for MNIST
        num_conv_layers=4
).to(device)

In [11]:
# Get all data loaders
mnist_loader, fmnist_loader, cifar10_loader, cifar100_loader = get_all_dataloaders()

In [12]:
# Task 1: MNIST
print("=" * 50)
print("Training on MNIST (Task 1)")
print("=" * 50)
train_task(model, task_id=1, dataloader=mnist_loader, epochs=3)

Training on MNIST (Task 1)
[Task 1] Epoch 1, Batch 0 | Loss: 2.3016 | Batch Acc: 15.62%
[Task 1] Epoch 1, Batch 200 | Loss: 0.3233 | Batch Acc: 89.06%
[Task 1] Epoch 1, Batch 400 | Loss: 0.1036 | Batch Acc: 96.88%
[Task 1] Epoch 1, Batch 600 | Loss: 0.0399 | Batch Acc: 98.44%
[Task 1] Epoch 1, Batch 800 | Loss: 0.1133 | Batch Acc: 93.75%
[Task 1] Epoch 1 | Loss: 262.2521 | Accuracy: 90.34%
[Task 1] Epoch 2, Batch 0 | Loss: 0.0543 | Batch Acc: 96.88%
[Task 1] Epoch 2, Batch 200 | Loss: 0.0943 | Batch Acc: 98.44%
[Task 1] Epoch 2, Batch 400 | Loss: 0.0244 | Batch Acc: 100.00%
[Task 1] Epoch 2, Batch 600 | Loss: 0.0215 | Batch Acc: 100.00%
[Task 1] Epoch 2, Batch 800 | Loss: 0.0235 | Batch Acc: 100.00%
[Task 1] Epoch 2 | Loss: 54.3520 | Accuracy: 98.24%
[Task 1] Epoch 3, Batch 0 | Loss: 0.0165 | Batch Acc: 100.00%
[Task 1] Epoch 3, Batch 200 | Loss: 0.0178 | Batch Acc: 100.00%
[Task 1] Epoch 3, Batch 400 | Loss: 0.0275 | Batch Acc: 98.44%
[Task 1] Epoch 3, Batch 600 | Loss: 0.0576 | Batch

In [13]:
# Grow for Task 2: Fashion-MNIST
print("\nGrowing network for Fashion-MNIST (Task 2)...")
model.grow(grow_conv_channels=64, grow_hidden=512, grow_output=10, task_id=2)
model = model.to(device)


Growing network for Fashion-MNIST (Task 2)...


In [21]:
print("=" * 50)
print("Training on Fashion-MNIST (Task 2)")
print("=" * 50)
train_task(model, task_id=2, dataloader=fmnist_loader, epochs=10)

Training on Fashion-MNIST (Task 2)
[Task 2] Epoch 1, Batch 0 | Loss: 0.3362 | Batch Acc: 82.81%
[Task 2] Epoch 1, Batch 200 | Loss: 0.3196 | Batch Acc: 84.38%
[Task 2] Epoch 1, Batch 400 | Loss: 0.2849 | Batch Acc: 90.62%
[Task 2] Epoch 1, Batch 600 | Loss: 0.3722 | Batch Acc: 89.06%
[Task 2] Epoch 1, Batch 800 | Loss: 0.4145 | Batch Acc: 85.94%
[Task 2] Epoch 1 | Loss: 274.1069 | Accuracy: 89.24%
[Task 2] Epoch 2, Batch 0 | Loss: 0.3152 | Batch Acc: 89.06%
[Task 2] Epoch 2, Batch 200 | Loss: 0.2286 | Batch Acc: 90.62%
[Task 2] Epoch 2, Batch 400 | Loss: 0.3027 | Batch Acc: 85.94%
[Task 2] Epoch 2, Batch 600 | Loss: 0.1682 | Batch Acc: 93.75%
[Task 2] Epoch 2, Batch 800 | Loss: 0.2939 | Batch Acc: 85.94%
[Task 2] Epoch 2 | Loss: 245.1158 | Accuracy: 90.38%
[Task 2] Epoch 3, Batch 0 | Loss: 0.2448 | Batch Acc: 87.50%
[Task 2] Epoch 3, Batch 200 | Loss: 0.1968 | Batch Acc: 92.19%
[Task 2] Epoch 3, Batch 400 | Loss: 0.2354 | Batch Acc: 92.19%
[Task 2] Epoch 3, Batch 600 | Loss: 0.1724 | B

In [15]:
# Grow for Task 3: CIFAR-10
print("\nGrowing network for CIFAR-10 (Task 3)...")
model.grow(grow_conv_channels=64, grow_hidden=512, grow_output=10, task_id=3)
model = model.to(device)


Growing network for CIFAR-10 (Task 3)...


In [23]:
print("=" * 50)
print("Training on CIFAR-10 (Task 3)")
print("=" * 50)
train_task(model, task_id=3, dataloader=cifar10_loader, epochs=10)

Training on CIFAR-10 (Task 3)
[Task 3] Epoch 1, Batch 0 | Loss: 0.2766 | Batch Acc: 90.62%
[Task 3] Epoch 1, Batch 200 | Loss: 0.5149 | Batch Acc: 85.94%
[Task 3] Epoch 1, Batch 400 | Loss: 0.2926 | Batch Acc: 93.75%
[Task 3] Epoch 1, Batch 600 | Loss: 0.4333 | Batch Acc: 82.81%
[Task 3] Epoch 1 | Loss: 377.0778 | Accuracy: 82.99%
[Task 3] Epoch 2, Batch 0 | Loss: 0.4307 | Batch Acc: 79.69%
[Task 3] Epoch 2, Batch 200 | Loss: 0.5021 | Batch Acc: 85.94%
[Task 3] Epoch 2, Batch 400 | Loss: 0.5158 | Batch Acc: 84.38%
[Task 3] Epoch 2, Batch 600 | Loss: 0.5037 | Batch Acc: 78.12%
[Task 3] Epoch 2 | Loss: 344.2313 | Accuracy: 84.42%
[Task 3] Epoch 3, Batch 0 | Loss: 0.2526 | Batch Acc: 89.06%
[Task 3] Epoch 3, Batch 200 | Loss: 0.6546 | Batch Acc: 79.69%
[Task 3] Epoch 3, Batch 400 | Loss: 0.3946 | Batch Acc: 87.50%
[Task 3] Epoch 3, Batch 600 | Loss: 0.2156 | Batch Acc: 93.75%
[Task 3] Epoch 3 | Loss: 321.6060 | Accuracy: 85.40%
[Task 3] Epoch 4, Batch 0 | Loss: 0.3311 | Batch Acc: 85.94%


In [17]:
# Grow for Task 4: CIFAR-100
print("\nGrowing network for CIFAR-100 (Task 4)...")
model.grow(grow_conv_channels=128, grow_hidden=1024, grow_output=100, task_id=4)  # Larger growth for 100 classes
model = model.to(device)


Growing network for CIFAR-100 (Task 4)...


In [25]:
print("=" * 50)
print("Training on CIFAR-100 (Task 4)")
print("=" * 50)
train_task(model, task_id=4, dataloader=cifar100_loader, epochs=10)

Training on CIFAR-100 (Task 4)
[Task 4] Epoch 1, Batch 0 | Loss: 0.5195 | Batch Acc: 84.38%
[Task 4] Epoch 1, Batch 200 | Loss: 0.7242 | Batch Acc: 73.44%
[Task 4] Epoch 1, Batch 400 | Loss: 0.6208 | Batch Acc: 84.38%
[Task 4] Epoch 1, Batch 600 | Loss: 0.9654 | Batch Acc: 67.19%
[Task 4] Epoch 1 | Loss: 666.2282 | Accuracy: 73.52%
[Task 4] Epoch 2, Batch 0 | Loss: 0.7491 | Batch Acc: 70.31%
[Task 4] Epoch 2, Batch 200 | Loss: 0.9515 | Batch Acc: 73.44%
[Task 4] Epoch 2, Batch 400 | Loss: 0.7985 | Batch Acc: 75.00%
[Task 4] Epoch 2, Batch 600 | Loss: 0.9242 | Batch Acc: 71.88%
[Task 4] Epoch 2 | Loss: 613.4397 | Accuracy: 75.30%
[Task 4] Epoch 3, Batch 0 | Loss: 0.6444 | Batch Acc: 79.69%
[Task 4] Epoch 3, Batch 200 | Loss: 0.6371 | Batch Acc: 81.25%
[Task 4] Epoch 3, Batch 400 | Loss: 0.8191 | Batch Acc: 71.88%
[Task 4] Epoch 3, Batch 600 | Loss: 0.7537 | Batch Acc: 75.00%
[Task 4] Epoch 3 | Loss: 591.1565 | Accuracy: 76.12%
[Task 4] Epoch 4, Batch 0 | Loss: 0.5244 | Batch Acc: 84.38%

In [26]:
# Evaluate all tasks
print("\n" + "=" * 50)
print("FINAL EVALUATION ON ALL TASKS")
print("=" * 50)
    
evaluate_task(model, task_id=1, dataloader=mnist_loader, task_name="MNIST")
evaluate_task(model, task_id=2, dataloader=fmnist_loader, task_name="Fashion-MNIST")
evaluate_task(model, task_id=3, dataloader=cifar10_loader, task_name="CIFAR-10")
evaluate_task(model, task_id=4, dataloader=cifar100_loader, task_name="CIFAR-100")


FINAL EVALUATION ON ALL TASKS
[Task 1] MNIST Evaluation Accuracy: 98.89%
[Task 2] Fashion-MNIST Evaluation Accuracy: 95.62%
[Task 3] CIFAR-10 Evaluation Accuracy: 93.62%
[Task 4] CIFAR-100 Evaluation Accuracy: 93.89%


93.888

In [20]:
# Print final network size
total_params = sum(p.numel() for p in model.parameters())
print(f"\nFinal network parameters: {total_params:,}")
print(f"Final conv channels: {model.conv_channels}")
print(f"Final hidden size: {model.hidden_size}")
print(f"Final output size: {model.fc2.out_features}")


Final network parameters: 16,217,410
Final conv channels: 320
Final hidden size: 2560
Final output size: 130
