In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import datasets, transforms
import numpy as np
from elastic_weight_consolidation import EWC

  _torch_pytree._register_pytree_node(


In [2]:
# PermutedMNIST dataset class
class PermutedMNIST(Dataset):
    """Dataset view that shuffles the pixels of every image the same way.

    Wraps an indexable, sized collection of ``(image, label)`` pairs and
    serves each image with its flattened pixels reordered by one fixed
    permutation.

    Args:
        data: Collection supporting ``len()`` and integer indexing, where
            each item unpacks to ``(image, label)``. The image is converted
            with ``np.array`` and, after permutation, reshaped to 28x28.
        permutation: Index array applied to the flattened image. When
            ``None``, a random permutation of ``range(28 * 28)`` is drawn
            from NumPy's global RNG at construction time.
    """

    def __init__(self, data, permutation=None):
        self.data = data
        # Only draw a random permutation when the caller did not supply one.
        self.permutation = (
            np.random.permutation(28 * 28) if permutation is None else permutation
        )

    def __len__(self):
        """Number of samples in the wrapped collection."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(tensor, label)`` for sample ``idx``.

        The tensor is float32 with shape ``(1, 28, 28)``; pixel values are
        passed through unscaled. The label is returned exactly as stored.
        """
        image, label = self.data[idx]

        # Flatten, reorder with the task's permutation, restore the 2-D grid.
        pixels = np.array(image).flatten()
        shuffled = pixels[self.permutation].reshape(28, 28)

        # Leading channel axis so the sample looks like a 1-channel image.
        sample = torch.tensor(shuffled, dtype=torch.float32).unsqueeze(0)
        return sample, label

# Generate permuted datasets
def permute_dataset(num_tasks=3, root='./data', seed=None):
    """Build one permuted-MNIST training dataset per task.

    The MNIST training split is loaded (and downloaded if missing) once and
    shared by every task; each task differs only in the pixel permutation
    applied by ``PermutedMNIST``.

    Args:
        num_tasks: Number of permuted tasks to create. Defaults to 3, the
            count this function previously hard-coded.
        root: Directory where MNIST is stored or downloaded to.
        seed: Optional integer. When given, permutations come from a private
            ``np.random.RandomState(seed)`` so they are reproducible and the
            global NumPy RNG is left untouched. When ``None``, the global
            ``np.random`` stream is used, as before.

    Returns:
        list[PermutedMNIST]: ``num_tasks`` datasets, each with its own
        independently drawn permutation of the 28 * 28 pixel positions.

    Raises:
        ValueError: If ``num_tasks`` is smaller than 1.
    """
    if num_tasks < 1:
        raise ValueError(f"num_tasks must be at least 1, got {num_tasks}")

    train_data = datasets.MNIST(root=root, train=True, download=True)

    # np.random (the module) and RandomState both expose .permutation, so
    # either can serve as the source of randomness here.
    rng = np.random if seed is None else np.random.RandomState(seed)

    return [
        PermutedMNIST(train_data, rng.permutation(28 * 28))
        for _ in range(num_tasks)
    ]

In [3]:
# Define a simple neural network
class SimpleNN(nn.Module):
    """Two-layer fully connected classifier.

    Architecture: ``Linear(784 -> 256)`` followed by ReLU, then
    ``Linear(256 -> 10)``. The forward pass returns raw logits.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 256)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        """Flatten each sample and return logits of shape ``(batch, 10)``.

        Everything after the first axis of ``x`` is collapsed, so each
        sample must contain 28 * 28 values in total.
        """
        batch_size = x.size(0)
        hidden = torch.relu(self.fc1(x.view(batch_size, -1)))
        return self.fc2(hidden)

# Helper function to train the model
def train_model(model, dataloader, criterion, optimizer, ewc=None, lambda_ewc=0.0, epochs=1, device=None):
    """Train ``model`` on ``dataloader``, optionally adding an EWC penalty.

    Args:
        model: ``nn.Module`` to optimise; switched to training mode.
        dataloader: Iterable yielding ``(inputs, targets)`` tensor batches.
        criterion: Callable ``criterion(outputs, targets)`` returning the
            task loss tensor.
        optimizer: Optimizer holding the model's parameters.
        ewc: Optional object exposing ``compute_ewc_loss(lambda_ewc)``; its
            result is added to the task loss before backpropagation.
        lambda_ewc: Value forwarded to ``ewc.compute_ewc_loss``.
        epochs: Number of passes over ``dataloader``.
        device: Device the batches are moved to. When ``None`` (default) the
            device of the model's first parameter is used, so batches always
            land where the model lives; a parameter-less model falls back
            to CPU.

    Returns:
        list[float]: Mean task loss (EWC penalty excluded) for each epoch.
        An epoch that saw no batches reports ``nan`` instead of raising
        ``ZeroDivisionError``.
    """
    if device is None:
        # Follow the model rather than assuming a device: a hard-coded
        # default breaks as soon as the model was moved elsewhere.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    model.train()
    epoch_losses = []
    for epoch in range(epochs):
        total_loss = 0.0
        num_batches = 0
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)

            # Zero gradients
            optimizer.zero_grad()

            # Forward pass
            outputs = model(inputs)
            task_loss = criterion(outputs, targets)

            # Add EWC loss if applicable. An explicit None check avoids
            # depending on the truthiness of the EWC object.
            loss = task_loss
            if ewc is not None:
                loss = loss + ewc.compute_ewc_loss(lambda_ewc)

            # Backward pass and optimize
            loss.backward()
            optimizer.step()

            # Only the task loss is reported, so runs with and without the
            # penalty stay comparable.
            total_loss += task_loss.item()
            num_batches += 1

        # Count batches actually seen so an empty loader cannot divide by zero.
        mean_loss = total_loss / num_batches if num_batches else float('nan')
        epoch_losses.append(mean_loss)
        print(f"Epoch {epoch + 1}/{epochs}, Loss: {mean_loss:.4f}")

    return epoch_losses

# Helper function to evaluate the model
def evaluate_model(model, dataloader, device=None):
    """Compute classification accuracy of ``model`` over ``dataloader``.

    Args:
        model: ``nn.Module`` producing per-class scores of shape
            ``(batch, num_classes)``; switched to evaluation mode.
        dataloader: Iterable yielding ``(inputs, targets)`` tensor batches.
        device: Device the batches are moved to. When ``None`` (default) the
            device of the model's first parameter is used, so the call works
            on CPU-only machines and wherever the model was placed; a
            parameter-less model falls back to CPU.

    Returns:
        float: Fraction of samples whose highest-scoring class equals the
        target. ``0.0`` when the dataloader yields no samples.
    """
    if device is None:
        # Follow the model instead of assuming CUDA is present.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            # Index of the largest score along the class axis.
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()

    # Guard against an empty loader instead of dividing by zero.
    return correct / total if total else 0.0


In [4]:
# Configuration: the values a reader is most likely to tune.
SEED = 0
BATCH_SIZE = 64
LEARNING_RATE = 0.01

# Seed the RNGs before anything stochastic runs so Restart & Run All is
# reproducible: NumPy draws the pixel permutations, torch initialises the
# model weights and drives DataLoader shuffling.
np.random.seed(SEED)
torch.manual_seed(SEED)

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load permuted datasets (one per task: A, B, C)
task_datasets = permute_dataset()
loader_a = DataLoader(task_datasets[0], batch_size=BATCH_SIZE, shuffle=True)
loader_b = DataLoader(task_datasets[1], batch_size=BATCH_SIZE, shuffle=True)
loader_c = DataLoader(task_datasets[2], batch_size=BATCH_SIZE, shuffle=True)

# Initialize the models, criterion, and optimizers. The plain-SGD baseline
# and the EWC run each get their own model and optimizer.
model_sgd = SimpleNN().to(device)
model_ewc = SimpleNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer_sgd = optim.SGD(model_sgd.parameters(), lr=LEARNING_RATE)
optimizer_ewc = optim.SGD(model_ewc.parameters(), lr=LEARNING_RATE)

In [5]:
# # Train and evaluate with SGD
# print("Training with SGD...")
# train_model(model_sgd, loader_a, criterion, optimizer_sgd, epochs=20)  # Train on Task A
# acc_a_sgd = evaluate_model(model_sgd, loader_a, device='cpu')
# train_model(model_sgd, loader_b, criterion, optimizer_sgd)  # Train on Task B
# acc_b_sgd = evaluate_model(model_sgd, loader_b, device='cpu')
# acc_ba_sgd = evaluate_model(model_sgd, loader_a, device='cpu')
# train_model(model_sgd, loader_c, criterion, optimizer_sgd)  # Train on Task C
# acc_c_sgd = evaluate_model(model_sgd, loader_c, device='cpu')
# acc_cb_sgd = evaluate_model(model_sgd, loader_b, device='cpu')
# acc_ca_sgd = evaluate_model(model_sgd, loader_a, device='cpu')


# print(f"SGD - Accuracy on Task A: {acc_a_sgd:.2f}, Task B: {acc_b_sgd:.2f}, Accuracy on A after B: {acc_ba_sgd:.2f}")
# print(f"SGD - Accuracy on Task C: {acc_c_sgd:.2f}, Task B after C: {acc_cb_sgd:.2f}, Task A after C: {acc_ca_sgd:.2f}")

In [6]:
# Train and evaluate with EWC
print("Training with EWC...")
ewc = EWC(model_ewc, device=device)
# Pass the notebook-wide `device` (the one model_ewc was moved to) instead of
# a hard-coded 'cpu'; otherwise batches and weights end up on different
# devices whenever CUDA is available.
train_model(model_ewc, loader_a, criterion, optimizer_ewc, epochs=1, device=device)  # Train on Task A
acc_a_ewc = evaluate_model(model_ewc, loader_a, device=device)

Training with EWC...
Epoch 1/1, Loss: 6.8240


In [7]:
# Strength of the EWC penalty, shared by both continual-learning stages below.
LAMBDA_EWC = 100.0

# Consolidate Task A before moving on.
# NOTE(review): EWC is defined outside this notebook; presumably
# compute_fisher estimates parameter importance on the given loader and
# update_params snapshots the current weights as the anchor — confirm.
ewc.compute_fisher(loader_a)
ewc.update_params(model_ewc)

# Every call passes the notebook-wide `device` (where model_ewc lives) so
# batches and weights stay on the same device when CUDA is available.
train_model(model_ewc, loader_b, criterion, optimizer_ewc, ewc=ewc, epochs=1, lambda_ewc=LAMBDA_EWC, device=device)  # Train on Task B
acc_b_ewc = evaluate_model(model_ewc, loader_b, device=device)
acc_ba_ewc = evaluate_model(model_ewc, loader_a, device=device)  # Task A after training on B

# Consolidate Task B, then learn Task C.
ewc.compute_fisher(loader_b)
ewc.update_params(model_ewc)
train_model(model_ewc, loader_c, criterion, optimizer_ewc, ewc=ewc, epochs=1, lambda_ewc=LAMBDA_EWC, device=device)  # Train on Task C
acc_c_ewc = evaluate_model(model_ewc, loader_c, device=device)
acc_cb_ewc = evaluate_model(model_ewc, loader_b, device=device)  # Task B after training on C
acc_ca_ewc = evaluate_model(model_ewc, loader_a, device=device)  # Task A after training on C
print(f"EWC - Accuracy on Task A: {acc_a_ewc:.2f}, Task B: {acc_b_ewc:.2f}, Accuracy on A after B: {acc_ba_ewc:.2f}")
print(f"EWC - Accuracy on Task C: {acc_c_ewc:.2f}, Task B: {acc_cb_ewc:.2f}, Accuracy on A after C: {acc_ca_ewc:.2f}")

Epoch 1/1, Loss: 1.6068
Epoch 1/1, Loss: 1.7459
EWC - Accuracy on Task A: 0.64, Task B: 0.62, Accuracy on A after B: 0.48
EWC - Accuracy on Task C: 0.38, Task B: 0.52, Accuracy on A after C: 0.45
