In [1]:
import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Normalize

In [2]:
class FashionMNISTtask1(datasets.FashionMNIST):
  def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
    super(FashionMNISTtask1, self).__init__(root, train=train, transform=transform, target_transform=target_transform,download=download)
    self.classes = self.classes[:6]

  def __getitem__(self, index):
    img, target = super(FashionMNISTtask1, self).__getitem__(index)
    if target < 6:
        return img, target
    else:
        return img, -1

In [3]:
class FashionMNISTtask2(datasets.FashionMNIST):
    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        super(FashionMNISTtask2, self).__init__(root, train=train, transform=transform, target_transform=target_transform,download=download)
        self.classes = self.classes[6:]

    def __getitem__(self, index):
        img, target = super(FashionMNISTtask2, self).__getitem__(index)
        if target >= 6:
            return img, target
        else:
            return img, -1

In [4]:
transform = torchvision.transforms.Compose([ToTensor(), Normalize((0.5), (0.5))])

train_dataset_task1 = FashionMNISTtask1(root='./data1', train=True, transform=transform, download=True)
test_dataset_task1 = FashionMNISTtask1(root='./data1', train=False, transform=transform, download=True)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./data1/FashionMNISTtask1/raw/train-images-idx3-ubyte.gz


100%|██████████| 26421880/26421880 [00:03<00:00, 7624414.92it/s] 


Extracting ./data1/FashionMNISTtask1/raw/train-images-idx3-ubyte.gz to ./data1/FashionMNISTtask1/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ./data1/FashionMNISTtask1/raw/train-labels-idx1-ubyte.gz


100%|██████████| 29515/29515 [00:00<00:00, 136220.66it/s]


Extracting ./data1/FashionMNISTtask1/raw/train-labels-idx1-ubyte.gz to ./data1/FashionMNISTtask1/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ./data1/FashionMNISTtask1/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 4422102/4422102 [00:01<00:00, 2541652.80it/s]


Extracting ./data1/FashionMNISTtask1/raw/t10k-images-idx3-ubyte.gz to ./data1/FashionMNISTtask1/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ./data1/FashionMNISTtask1/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 5148/5148 [00:00<00:00, 8244473.84it/s]

Extracting ./data1/FashionMNISTtask1/raw/t10k-labels-idx1-ubyte.gz to ./data1/FashionMNISTtask1/raw






In [5]:
train_dataset_task2 = FashionMNISTtask2(root='./data2', train=True, transform=transform, download=True)
test_dataset_task2 = FashionMNISTtask2(root='./data2', train=False, transform=transform, download=True)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./data2/FashionMNISTtask2/raw/train-images-idx3-ubyte.gz


100%|██████████| 26421880/26421880 [00:03<00:00, 7714507.84it/s] 


Extracting ./data2/FashionMNISTtask2/raw/train-images-idx3-ubyte.gz to ./data2/FashionMNISTtask2/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ./data2/FashionMNISTtask2/raw/train-labels-idx1-ubyte.gz


100%|██████████| 29515/29515 [00:00<00:00, 140804.33it/s]


Extracting ./data2/FashionMNISTtask2/raw/train-labels-idx1-ubyte.gz to ./data2/FashionMNISTtask2/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ./data2/FashionMNISTtask2/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 4422102/4422102 [00:01<00:00, 2609766.26it/s]


Extracting ./data2/FashionMNISTtask2/raw/t10k-images-idx3-ubyte.gz to ./data2/FashionMNISTtask2/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ./data2/FashionMNISTtask2/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 5148/5148 [00:00<00:00, 4113598.21it/s]

Extracting ./data2/FashionMNISTtask2/raw/t10k-labels-idx1-ubyte.gz to ./data2/FashionMNISTtask2/raw






In [6]:
train_dataset_filtered_task1 = [data for data in train_dataset_task1 if data[1] != -1]
test_dataset_filtered_task1 = [data for data in test_dataset_task1 if data[1] != -1]

train_dataset_filtered_task2 = [data for data in train_dataset_task2 if data[1] != -1]
test_dataset_filtered_task2 = [data for data in test_dataset_task2 if data[1] != -1]

In [7]:
train_dataloader_task1 = DataLoader(train_dataset_filtered_task1, batch_size=64, shuffle=True)
test_dataloader_task1 = DataLoader(test_dataset_filtered_task1, batch_size=256, shuffle=False)

for X, y in train_dataloader_task1:
  print(f"Shape of X [N, C, H, W]: {X.shape}")
  print(f"SHape of y: {y.shape}, dtype: {y.dtype}")
  break

Shape of X [N, C, H, W]: torch.Size([64, 1, 28, 28])
SHape of y: torch.Size([64]), dtype: torch.int64


In [8]:
train_dataloader_task2 = DataLoader(train_dataset_filtered_task2, batch_size=64, shuffle=True)
test_dataloader_task2 = DataLoader(test_dataset_filtered_task2, batch_size=256, shuffle=False)

In [9]:
device = ('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using {device} device.")

Using cpu device.


In [16]:
class NeuralNetwork(nn.Module):
    def __init__(self, num_classes=10, hidden_size=512):
        super().__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(28*28, hidden_size)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.relu2 = nn.ReLU()
        self.classifier = nn.Linear(hidden_size, num_classes)

        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
            elif isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, nonlinearity='relu')

    def forward(self, x):
        x = self.flatten(x)
        x = self.relu1(self.fc1(x))
        x = self.relu2(self.fc2(x))
        logits = self.classifier(x)

        return logits

In [17]:
model = NeuralNetwork()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

model.to(device)

NeuralNetwork(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (fc1): Linear(in_features=784, out_features=512, bias=True)
  (relu1): ReLU()
  (fc2): Linear(in_features=512, out_features=512, bias=True)
  (relu2): ReLU()
  (classifier): Linear(in_features=512, out_features=10, bias=True)
)

In [18]:
W = {n: torch.zeros_like(p, requires_grad=False) for n, p in model.named_parameters() if p.requires_grad}
p_old = {n: p.clone().detach() for n, p in model.named_parameters() if p.requires_grad}

In [19]:
# Surrogate loss function for Synaptic Intelligence
def surrogate_loss(W, p_old):
    loss = 0
    for n, p in model.named_parameters():
        if p.requires_grad:
            loss += (W[n] * (p - p_old[n]) ** 2).sum()
    return loss

In [20]:
# Update omega after completing a task
def update_omega(W, param_importance, epsilon=0.1):
    for n, p in model.named_parameters():
        if p.requires_grad:
            delta = p.detach() - p_old[n]
            W[n] += param_importance[n] / (delta ** 2 + epsilon)

In [25]:
# Synaptic Intelligence regularization coefficient
c = 0.1

# Training loop
for task_data in [train_dataloader_task1, train_dataloader_task2]:  # Assuming these are DataLoader instances
    param_importance = {n: torch.zeros_like(p, requires_grad=False) for n, p in model.named_parameters() if p.requires_grad}

    for epoch in range(3):
        for X, y in task_data:
            X, y = X.to(device), y.to(device)
            optimizer.zero_grad()
            output = model(X)
            loss = criterion(output, y)

            # Compute gradients for the current task
            loss.backward(retain_graph=True)  # Retain graph if needed for subsequent operations

            # Update parameter importance dynamically during training
            for n, p in model.named_parameters():
                if p.requires_grad:
                    if p.grad is not None:  # Ensure gradients exist
                        param_importance[n] += p.grad.abs() * (p.detach() - p_old[n]).abs()

            # Apply Synaptic Intelligence regularization
            si_loss = surrogate_loss(W, p_old)
            total_loss = loss + c * si_loss
            optimizer.zero_grad()  # Clear gradients to avoid accumulation
            total_loss.backward()  # Backward pass on total loss

            optimizer.step()

    # Update omega (W) for the next task after training is complete
    update_omega(W, param_importance)

    # Update old parameters (p_old) for the next task
    p_old = {n: p.clone().detach() for n, p in model.named_parameters() if p.requires_grad}


In [28]:
def evaluate_model(model, task_loader):
    model.eval()
    correct = 0
    total = 0
    for inputs, labels in task_loader:
        inputs, labels = inputs.to(device), labels.to(device)  # Ensure inputs and labels are on the correct device
        outputs = model(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    accuracy = correct / total  # Calculate accuracy for the entire task
    print(f"Task Accuracy: {accuracy*100}")
    return accuracy

In [29]:
acc = evaluate_model(model, test_dataloader_task1)
acc

Task Accuracy: 13.0


0.13