In [30]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import datasets, transforms
import numpy as np
from elastic_weight_consolidation import EWC

In [31]:
# PermutedMNIST dataset class
class PermutedMNIST(Dataset):
    def __init__(self, data, permutation=None):
        self.data = data
        if permutation is not None:
            self.permutation = permutation
        else:
            self.permutation = np.random.permutation(28 * 28)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img, label = self.data[idx]

        img_flattened = np.array(img).flatten()
        permuted_image = img_flattened[self.permutation]
        permuted_image = permuted_image.reshape(28, 28)

        permuted_image_tensor = torch.tensor(permuted_image, dtype=torch.float32).unsqueeze(0)

        return permuted_image_tensor, label

# Generate permuted datasets
def permute_dataset():
    train_data = datasets.MNIST(root='./data', train=True, download=True)

    task_datasets = []
    num_tasks = 2  # Task A and Task B
    for i in range(num_tasks):
        random_permutation = np.random.permutation(28 * 28)
        permuted_dataset = PermutedMNIST(train_data, random_permutation)
        task_datasets.append(permuted_dataset)

    return task_datasets

In [32]:
# Define a simple neural network
class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 256)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Helper function to train the model
def train_model(model, dataloader, criterion, optimizer, ewc=None, lambda_ewc=0.0, epochs=5, device='cpu'):
    model.train()
    for epoch in range(epochs):
        total_loss = 0.0
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)

            # Zero gradients
            optimizer.zero_grad()

            # Forward pass
            outputs = model(inputs)
            task_loss = criterion(outputs, targets)

            # Add EWC loss if applicable
            ewc_loss = ewc.compute_ewc_loss(lambda_ewc) if ewc else 0.0
            loss = task_loss + ewc_loss

            # Backward pass and optimize
            loss.backward()
            optimizer.step()

            total_loss += task_loss.item()

        print(f"Epoch {epoch + 1}/{epochs}, Loss: {total_loss / len(dataloader):.4f}")

# Helper function to evaluate the model
def evaluate_model(model, dataloader, device='cpu'):
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()

    return correct / total


In [33]:
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load permuted datasets
task_datasets = permute_dataset()
loader_a = DataLoader(task_datasets[0], batch_size=64, shuffle=True)
loader_b = DataLoader(task_datasets[1], batch_size=64, shuffle=True)

# Initialize the model, criterion, and optimizer
model_sgd = SimpleNN().to(device)
model_ewc = SimpleNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer_sgd = optim.SGD(model_sgd.parameters(), lr=0.01)
optimizer_ewc = optim.SGD(model_ewc.parameters(), lr=0.01)

In [28]:
# Initialize the model, criterion, and optimizer
model_sgd = SimpleNN().to(device)
model_ewc = SimpleNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer_sgd = optim.SGD(model_sgd.parameters(), lr=0.01)
optimizer_ewc = optim.SGD(model_ewc.parameters(), lr=0.01)

# Train and evaluate with SGD
print("Training with SGD...")
train_model(model_sgd, loader_a, criterion, optimizer_sgd, epochs=20)  # Train on Task A
acc_a_sgd = evaluate_model(model_sgd, loader_a)
train_model(model_sgd, loader_b, criterion, optimizer_sgd)  # Train on Task B
acc_b_sgd = evaluate_model(model_sgd, loader_b)
acc_ba_sgd = evaluate_model(model_sgd, loader_a)

print(f"SGD - Accuracy on Task A: {acc_a_sgd:.2f}, Task B: {acc_b_sgd:.2f}, Accuracy on A after B: {acc_ba_sgd:.2f}")

Training with SGD...
Epoch 1/20, Loss: 3.3376
Epoch 2/20, Loss: 1.2570
Epoch 3/20, Loss: 1.0638
Epoch 4/20, Loss: 1.0091
Epoch 5/20, Loss: 0.9465
Epoch 6/20, Loss: 0.7951
Epoch 7/20, Loss: 0.8601
Epoch 8/20, Loss: 0.7825
Epoch 9/20, Loss: 0.8297
Epoch 10/20, Loss: 0.7013
Epoch 11/20, Loss: 0.6354
Epoch 12/20, Loss: 0.5890
Epoch 13/20, Loss: 0.5915
Epoch 14/20, Loss: 0.5382
Epoch 15/20, Loss: 0.5069
Epoch 16/20, Loss: 0.5949
Epoch 17/20, Loss: 0.5166
Epoch 18/20, Loss: 0.5894
Epoch 19/20, Loss: 0.7752
Epoch 20/20, Loss: 0.8007
Epoch 1/5, Loss: 1.9405
Epoch 2/5, Loss: 1.5016
Epoch 3/5, Loss: 1.1362
Epoch 4/5, Loss: 1.0629
Epoch 5/5, Loss: 1.1564
SGD - Accuracy on Task A: 0.80, Task B: 0.48, Accuracy on A after B: 0.53


In [29]:
# Train and evaluate with EWC
print("Training with EWC...")
ewc = EWC(model_ewc, loader_a, device=device)
train_model(model_ewc, loader_a, criterion, optimizer_ewc, epochs=20)  # Train on Task A
ewc.compute_fisher()
ewc.store_prev_params()
acc_a_ewc = evaluate_model(model_ewc, loader_a)

train_model(model_ewc, loader_b, criterion, optimizer_ewc, ewc=ewc, lambda_ewc=100.0)  # Train on Task B
acc_b_ewc = evaluate_model(model_ewc, loader_b)
acc_ba_ewc = evaluate_model(model_ewc, loader_a)

print(f"EWC - Accuracy on Task A: {acc_a_ewc:.2f}, Task B: {acc_b_ewc:.2f}, Accuracy on A after B: {acc_ba_ewc:.2f}")

Training with EWC...
Epoch 1/20, Loss: 7.1965
Epoch 2/20, Loss: 0.9511
Epoch 3/20, Loss: 0.9432
Epoch 4/20, Loss: 0.9446
Epoch 5/20, Loss: 0.8819
Epoch 6/20, Loss: 0.9131
Epoch 7/20, Loss: 0.8584
Epoch 8/20, Loss: 0.8794
Epoch 9/20, Loss: 0.8803
Epoch 10/20, Loss: 0.8303
Epoch 11/20, Loss: 0.7707
Epoch 12/20, Loss: 0.8610
Epoch 13/20, Loss: 0.8667
Epoch 14/20, Loss: 0.7741
Epoch 15/20, Loss: 0.7384
Epoch 16/20, Loss: 0.7271
Epoch 17/20, Loss: 0.7591
Epoch 18/20, Loss: 0.6638
Epoch 19/20, Loss: 0.6595
Epoch 20/20, Loss: 0.6242
Epoch 1/5, Loss: 1.7872
Epoch 2/5, Loss: 1.2489
Epoch 3/5, Loss: 1.1441
Epoch 4/5, Loss: 1.0308
Epoch 5/5, Loss: 1.0393
EWC - Accuracy on Task A: 0.83, Task B: 0.64, Accuracy on A after B: 0.84
