In [None]:
import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Normalize

In [None]:
class FashionMNISTtask1(datasets.FashionMNIST):
    """FashionMNIST variant for the first task: labels below 6 are kept as-is.

    Samples with any other label are still returned, but their label is
    replaced by the sentinel ``-1``; nothing is removed from the dataset here.
    """

    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        super().__init__(
            root,
            train=train,
            transform=transform,
            target_transform=target_transform,
            download=download,
        )
        # Shadow the inherited class-name list with its first six entries.
        self.classes = self.classes[:6]

    def __getitem__(self, index):
        """Return ``(img, target)``, with ``target`` set to ``-1`` unless it is below 6."""
        image, label = super().__getitem__(index)
        return image, (label if label < 6 else -1)

In [None]:
class FashionMNISTtask2(datasets.FashionMNIST):
    """FashionMNIST variant for the second task: labels of 6 and above are kept as-is.

    Samples with a smaller label are still returned, but their label is
    replaced by the sentinel ``-1``; nothing is removed from the dataset here.
    Kept labels are not re-indexed.
    """

    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        super().__init__(
            root,
            train=train,
            transform=transform,
            target_transform=target_transform,
            download=download,
        )
        # Shadow the inherited class-name list with the entries from index 6 onwards.
        self.classes = self.classes[6:]

    def __getitem__(self, index):
        """Return ``(img, target)``, with ``target`` set to ``-1`` unless it is at least 6."""
        image, label = super().__getitem__(index)
        return image, (label if label >= 6 else -1)

In [None]:
# Shared preprocessing: convert to a tensor, then normalise the single channel
# with mean 0.5 and std 0.5. Normalize takes per-channel sequences, so 1-tuples
# are passed explicitly; `(0.5)` without the trailing comma is just the float 0.5.
transform = torchvision.transforms.Compose([ToTensor(), Normalize((0.5,), (0.5,))])

# Train/test splits for the first task (labels outside the task are marked -1).
train_dataset_1 = FashionMNISTtask1(root='./data1', train=True, transform=transform, download=True)
test_dataset_1 = FashionMNISTtask1(root='./data1', train=False, transform=transform, download=True)

In [None]:
# Train/test splits for the second task (labels outside the task are marked -1).
train_dataset_2 = FashionMNISTtask2(
    root='./data2', train=True, transform=transform, download=True
)
test_dataset_2 = FashionMNISTtask2(
    root='./data2', train=False, transform=transform, download=True
)

# Plain FashionMNIST test split with every label left untouched.
test_dataset_3 = datasets.FashionMNIST(root='data3', train=False, transform=transform, download=True)

In [None]:
# Materialise each split as a list of (image, label) pairs, dropping every
# sample whose label was replaced by the -1 sentinel.
train_dataset_filtered_old = [(img, label) for img, label in train_dataset_1 if label != -1]
test_dataset_filtered_old = [(img, label) for img, label in test_dataset_1 if label != -1]

train_dataset_filtered_new = [(img, label) for img, label in train_dataset_2 if label != -1]
test_dataset_filtered_new = [(img, label) for img, label in test_dataset_2 if label != -1]

In [None]:
# Loaders for the first ("old") task: shuffled mini-batches for training,
# larger fixed-order batches for testing.
train_dataloader_old = DataLoader(
    train_dataset_filtered_old,
    batch_size=64,
    shuffle=True,
)
test_dataloader_old = DataLoader(
    test_dataset_filtered_old,
    batch_size=256,
    shuffle=False,
)

# Sanity check: report the tensor shapes of the first training batch only.
for X, y in train_dataloader_old:
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"SHape of y: {y.shape}, dtype: {y.dtype}")
    break

In [None]:
# Loaders for the second ("new") task, configured like the first-task loaders.
train_dataloader_new = DataLoader(
    train_dataset_filtered_new,
    batch_size=64,
    shuffle=True,
)
test_dataloader_new = DataLoader(
    test_dataset_filtered_new,
    batch_size=256,
    shuffle=False,
)

In [None]:
# Loader over the unfiltered test split. Evaluation does not need shuffling;
# a fixed order keeps runs reproducible and matches the other test loaders.
eval_dataloader = DataLoader(test_dataset_3, batch_size=256, shuffle=False)

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Using {device} device.")

In [None]:
class NeuralNetwork(nn.Module):
    """Fully connected classifier: flatten -> two ReLU hidden layers -> linear output.

    Args:
        num_classes: Number of output logits per sample.
        hidden_size: Width of both hidden layers.
    """

    def __init__(self, num_classes=10, hidden_size=512):
        # Zero-argument super(): the previous call passed a class named `CNN`
        # that does not exist, so instantiating the model raised NameError.
        super().__init__()

        self.flatten = nn.Flatten()
        # The first layer expects 28*28 = 784 features per sample after flattening.
        self.fc1 = nn.Linear(28*28, hidden_size)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.relu2 = nn.ReLU()
        self.classifier = nn.Linear(hidden_size, num_classes)

        self._initialize_weights()

    def _initialize_weights(self):
        """Apply Kaiming-normal init (ReLU gain) to every Linear/Conv2d weight.

        Biases are not touched and keep PyTorch's default initialisation.
        """
        for m in self.modules():
            if isinstance(m, (nn.Linear, nn.Conv2d)):
                nn.init.kaiming_normal_(m.weight, nonlinearity='relu')

    def forward(self, x):
        """Return raw (un-softmaxed) logits of shape ``(batch, num_classes)``."""
        x = self.flatten(x)
        x = self.relu1(self.fc1(x))
        x = self.relu2(self.fc2(x))
        logits = self.classifier(x)

        return logits

In [None]:
def update_omega(W: dict[str, torch.Tensor], epsilon: float, omega=None):
    """Fold per-parameter contributions ``W`` into the accumulated importance ``omega``.

    For every trainable parameter ``n`` of the global ``model``:

        omega[n] += W[n] / ((p_current - old_params[n]) ** 2 + epsilon)

    and the global ``old_params[n]`` is then replaced by a detached copy of the
    current value, so the next call measures change from this point.
    The formula looks like the end-of-task consolidation step of Synaptic
    Intelligence (Zenke et al., 2017) -- NOTE(review): confirm that is the intent.

    Args:
        W: Maps parameter name to its contribution; a tensor shaped like the
            parameter (a plain scalar also works through broadcasting).
        epsilon: Added to the squared change so the division stays finite for
            parameters that did not move.
        omega: Optional dict of previously accumulated importance. It is updated
            in place; names missing from it start from zero. When omitted, a new
            dict is created.

    Returns:
        The ``omega`` dict with one entry per trainable parameter.
    """
    if omega is None:
        omega = {}

    with torch.no_grad():
        for n, p in model.named_parameters():
            if not p.requires_grad:
                continue

            p_current = p.detach().clone()
            p_change = p_current - old_params[n]
            omega_add = W[n] / (p_change ** 2 + epsilon)

            previous = omega.get(n)
            omega[n] = omega_add if previous is None else previous + omega_add

            # The current weights become the reference point for the next call.
            old_params[n] = p_current

    return omega

In [None]:
# --- Hyperparameters --------------------------------------------------------
num_epochs_per_task = 3
c = 0.1          # strength of the SI penalty (tunable; was used but never defined)
epsilon = 1e-3   # damping term: keeps importance finite for weights that barely moved

# --- Model / optimiser ------------------------------------------------------
# Move the model to the target device before handing its parameters to Adam.
model = NeuralNetwork().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

trainable_params = [(name, param) for name, param in model.named_parameters()
                    if param.requires_grad]

# --- Synaptic Intelligence state (Zenke et al., 2017) -----------------------
# Consolidated per-parameter importance. It starts at zero so the first task
# is trained without any regularisation.
param_importance = {name: torch.zeros_like(param, requires_grad=False)
                    for name, param in trainable_params}
# Anchor weights: detached copies of the parameters at the end of the previous
# task (the initial weights before the first task).
old_params = {name: param.detach().clone() for name, param in trainable_params}
# Running per-task sum of -(task gradient) * (parameter update); reset after each task.
path_integral = {name: torch.zeros_like(param, requires_grad=False)
                 for name, param in trainable_params}

continual_learning_tasks = [train_dataloader_old, train_dataloader_new]

# --- Continual learning loop ------------------------------------------------
for task_data in continual_learning_tasks:
    for epoch in range(num_epochs_per_task):
        for X, y in task_data:
            X, y = X.to(device), y.to(device)
            optimizer.zero_grad()

            # Gradients of the task loss alone; only these feed the importance estimate.
            pred = model(X)
            task_loss = criterion(pred, y)
            task_loss.backward()
            task_grads = {name: param.grad.detach().clone()
                          for name, param in trainable_params}

            # Quadratic penalty pulling important weights back to their anchor.
            # It must be back-propagated BEFORE optimizer.step(); adding it to
            # the loss after backward() has no effect on the update.
            si_penalty = sum(
                torch.sum(param_importance[name] * (param - old_params[name]) ** 2)
                for name, param in trainable_params
            )
            (c * si_penalty).backward()  # accumulates onto the task gradients

            # Total objective value, kept for logging/inspection.
            loss = task_loss.detach() + c * si_penalty.detach()

            params_before_step = {name: param.detach().clone()
                                  for name, param in trainable_params}
            optimizer.step()

            # Accumulate each parameter's contribution to the drop in task loss.
            with torch.no_grad():
                for name, param in trainable_params:
                    path_integral[name] -= task_grads[name] * (param - params_before_step[name])

    # End of task: convert the path integral into importance, move the anchor
    # to the current weights, and start the next task with a fresh integral.
    with torch.no_grad():
        for name, param in trainable_params:
            total_change = param - old_params[name]
            param_importance[name] += path_integral[name] / (total_change ** 2 + epsilon)
            old_params[name] = param.detach().clone()
            path_integral[name].zero_()