In [1]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
from torchinfo import summary
if torch.cuda.is_available():
    print(torch.cuda.get_device_name())
else:
    print("CPU")

NVIDIA GeForce RTX 3050 6GB Laptop GPU


In [2]:
class TaskLinear(nn.Module):
    def __init__(self, in_features, out_features, task_ids):
        super().__init__()
        self.in_features = in_features 
        self.out_features = out_features
        self.task_ids = task_ids
        self.weight = nn.Parameter(torch.randn(out_features,in_features)*0.01)
        self.bias = nn.Parameter(torch.randn(out_features))
        self.register_buffer("neuron_ids", torch.tensor(task_ids).long())

    def forward(self, x, task_id):
        out = F.linear(x,self.weight , self.bias)
        mask = (self.neuron_ids == task_id).float().unsqueeze(0)
        self._mask = self.neuron_ids == task_id

        return out*mask

    def apply_gradient_mask(self):
        if self.weight.grad is not None:
            inactive = ~self._mask
            self.weight.grad[inactive] = 0
            self.bias.grad[inactive]=0

In [3]:
class SynapticNet(nn.Module):
    def __init__(self, input_size = 28*28 , hidden_size = 512 , output_szie =10 , hidden_layers = 10):
        super().__init__()
        self.hidden_layers = hidden_layers
        self.hidden_size = hidden_size
        self.layers = nn.ModuleList()
        for _ in range(self.hidden_layers):
            task_ids = [1]*512
            self.layers.append(TaskLinear(hidden_size, hidden_size,task_ids))
        self.input_layer = TaskLinear(input_size,hidden_size,[1]*512)
        self.output_layer = TaskLinear(hidden_size, output_szie , [1]*10)
        self.relu = nn.ReLU()

    def grow(self, grow_size = 512 , output_grow_size = 10, task_id = [2]):
        new_hidden_size = self.hidden_size+grow_size
        new_output_size = self.output_layer.out_features + output_grow_size
        new_task_ids = task_id*grow_size
        old_task_ids = self.input_layer.task_ids
        combined_task_ids = old_task_ids + new_task_ids
        old_input_w = self.input_layer.weight.data.clone()
        old_input_b = self.input_layer.bias.data.clone()
        new_input = TaskLinear(self.input_layer.in_features , new_hidden_size , combined_task_ids)
        with torch.no_grad():
            new_input.weight[:self.hidden_size, :].copy_(old_input_w)
            new_input.bias[:self.hidden_size].copy_(old_input_b)
        self.input_layer = new_input

        new_layers = nn.ModuleList()
        for i in range(self.hidden_layers):
            old_layer = self.layers[i]
            old_w = old_layer.weight.data.clone()
            old_b =  old_layer.bias.data.clone()
            new_layer = TaskLinear(new_hidden_size , new_hidden_size , combined_task_ids)
            with torch.no_grad():
                new_layer.weight[:self.hidden_size, :self.hidden_size] = old_w
                new_layer.bias[:self.hidden_size] = old_b
            new_layers.append(new_layer)
        self.layers = new_layers

        old_out_w = self.output_layer.weight.data.clone()
        old_out_b = self.output_layer.bias.data.clone()
        old_out_task_ids = self.output_layer.task_ids
        new_out_task_ids = task_id*output_grow_size
        combined_out_task_ids = old_out_task_ids + new_out_task_ids
        new_output = TaskLinear(new_hidden_size,new_output_size,combined_out_task_ids)
        with torch.no_grad():
            new_output.weight[:self.output_layer.out_features , :self.hidden_size].copy_(old_out_w)
            new_output.bias[:self.output_layer.out_features].copy_(old_out_b)
        self.output_layer = new_output

        self.hidden_size = new_hidden_size
    
    def forward(self,x,task_id):
        x = self.relu(self.input_layer(x,task_id))
        for layer in self.layers:
            x = self.relu(layer(x,task_id))
        x = self.output_layer(x,task_id)
        return x
        
    def apply_task_gradient_mask(self):
        self.input_layer.apply_gradient_mask()
        for layer in self.layers:
            layer.apply_gradient_mask()
        self.output_layer.apply_gradient_mask()

In [4]:
class TaskDataset(Dataset):
    def __init__(self, dataset, task_id, label_offset=0):
        self.dataset = dataset
        self.task_id = task_id
        self.label_offset = label_offset

    def __getitem__(self, idx):
        x, y = self.dataset[idx]
        return x.view(-1), y + self.label_offset, self.task_id

    def __len__(self):
        return len(self.dataset)

In [5]:
def get_dataloaders(batch_size=64):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])

    mnist = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    fmnist = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)

    mnist = TaskDataset(mnist, task_id=1, label_offset=0)
    fmnist = TaskDataset(fmnist, task_id=2, label_offset=10)

    return DataLoader(mnist, batch_size=batch_size, shuffle=True), DataLoader(fmnist, batch_size=batch_size, shuffle=True)

In [6]:
def train_task(model, task_id, dataloader, epochs=3, lr=1e-3):
    model.train()
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr = lr)
    criterion = nn.CrossEntropyLoss()
    for epoch in range(epochs):
        total_loss = 0
        correct = 0
        total = 0
        for x,y, tid in dataloader:
            x,y = x.to(device), y.to(device)
            out = model(x, task_id)
            loss = criterion(out,y)
            optimizer.zero_grad()
            loss.backward()
            model.apply_task_gradient_mask()
            optimizer.step()
            total_loss += loss.item()
            preds = torch.argmax(out, dim=1)
            correct += (preds == y).sum().item()
            total += y.size(0)
        acc = 100 * correct / total
        print(f"[Task {task_id}] Epoch {epoch+1} | Loss: {total_loss:.4f} | Accuracy: {acc:.2f}%")

In [7]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = SynapticNet().to(device)
mnist_loader, fmnist_loader = get_dataloaders()

In [9]:
print("\n Training on MNIST (Task 1)")
train_task(model, task_id=1, dataloader=mnist_loader , epochs = 10)


 Training on MNIST (Task 1)
[Task 1] Epoch 1 | Loss: 368.7942 | Accuracy: 89.34%
[Task 1] Epoch 2 | Loss: 329.8951 | Accuracy: 90.33%
[Task 1] Epoch 3 | Loss: 283.3505 | Accuracy: 91.89%
[Task 1] Epoch 4 | Loss: 244.8758 | Accuracy: 92.88%
[Task 1] Epoch 5 | Loss: 222.1942 | Accuracy: 93.61%
[Task 1] Epoch 6 | Loss: 208.6383 | Accuracy: 93.93%
[Task 1] Epoch 7 | Loss: 196.4845 | Accuracy: 94.24%
[Task 1] Epoch 8 | Loss: 184.1229 | Accuracy: 94.59%
[Task 1] Epoch 9 | Loss: 192.5516 | Accuracy: 94.25%
[Task 1] Epoch 10 | Loss: 170.7890 | Accuracy: 95.05%


In [11]:
model.grow(grow_size=512 , output_grow_size=10, task_id=[2])
model.to(device)

SynapticNet(
  (layers): ModuleList(
    (0-9): 10 x TaskLinear()
  )
  (input_layer): TaskLinear()
  (output_layer): TaskLinear()
  (relu): ReLU()
)

In [12]:
print("\n Training on FMNIST (Task 2)")
train_task(model, task_id=2, dataloader=fmnist_loader , epochs = 10)


 Training on FMNIST (Task 2)
[Task 2] Epoch 1 | Loss: 997.7255 | Accuracy: 57.73%
[Task 2] Epoch 2 | Loss: 666.7542 | Accuracy: 74.37%
[Task 2] Epoch 3 | Loss: 584.0280 | Accuracy: 78.60%
[Task 2] Epoch 4 | Loss: 514.6503 | Accuracy: 81.29%
[Task 2] Epoch 5 | Loss: 467.1802 | Accuracy: 83.34%
[Task 2] Epoch 6 | Loss: 435.0229 | Accuracy: 84.80%
[Task 2] Epoch 7 | Loss: 404.4991 | Accuracy: 85.64%
[Task 2] Epoch 8 | Loss: 407.8861 | Accuracy: 85.45%
[Task 2] Epoch 9 | Loss: 385.7003 | Accuracy: 86.48%
[Task 2] Epoch 10 | Loss: 384.5262 | Accuracy: 86.69%


In [13]:
def evaluate_task(model, task_id, dataloader):
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for x, y, tid in dataloader:
            x, y = x.to(device), y.to(device)
            out = model(x, task_id)
            preds = torch.argmax(out, dim=1)
            correct += (preds == y).sum().item()
            total += y.size(0)

    acc = 100 * correct / total
    print(f"[Task {task_id}] Evaluation Accuracy: {acc:.2f}%")
    return acc

In [15]:
print("\n testing on MNIST(Task 1)\n")
evaluate_task(model , task_id = 1, dataloader=mnist_loader)
print("\n testing on FMNIST(Task 1)\n")
evaluate_task(model , task_id = 2, dataloader=fmnist_loader)


 testing on MNIST(Task 1)

[Task 1] Evaluation Accuracy: 89.59%

 testing on FMNIST(Task 1)

[Task 2] Evaluation Accuracy: 86.85%


86.84666666666666