In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
import numpy as np
import copy
import matplotlib.pyplot as plt

# Basic MLP
class MLP(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear.

    Each sample is flattened before entering the network, so any input whose
    per-sample element count equals ``input_dim`` is accepted.
    """

    def __init__(self, input_dim=784, hidden_dim=256, output_dim=10):
        super().__init__()
        # Built input-to-output and registered under `net`, so parameter names
        # are `net.0.*` (hidden layer) and `net.2.*` (output layer).
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the network output of shape (batch, output_dim) for batch `x`."""
        batch = x.shape[0]
        # `view` keeps the batch axis and collapses every other axis into one.
        flat = x.view(batch, -1)
        return self.net(flat)

# Fisher Estimator
def estimate_fisher(model, dataset, criterion, sample_size=1024, batch_size=128):
    """Estimate a diagonal Fisher information proxy from squared gradients.

    For each mini-batch, the gradient of the scalar ``criterion(model(x), y)``
    is computed and its element-wise square, divided by ``sample_size``, is
    accumulated per parameter. Gradients are squared at batch level (not per
    sample), so this is a coarse approximation of the empirical Fisher.

    Args:
        model: Module whose named parameters are scored. It is switched to
            eval mode and left in that mode.
        dataset: Dataset yielding ``(x, y)`` pairs.
        criterion: Callable returning a scalar loss from ``(output, y)``.
        sample_size: Approximate number of samples to use. Batches are drawn
            until ``batch_index * batch_size >= sample_size``. Also the
            normalisation constant, regardless of how many samples were
            actually seen.
        batch_size: Mini-batch size for the internal loader.

    Returns:
        dict mapping parameter name -> tensor shaped like the parameter.
    """
    model.eval()
    params = dict(model.named_parameters())
    fisher = {name: torch.zeros_like(p) for name, p in params.items()}
    if not fisher:
        # Nothing to score (parameter-free model).
        return fisher

    # Run on whatever device the model lives on instead of assuming CUDA,
    # so the function also works on CPU-only machines.
    device = next(iter(params.values())).device

    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    for i, (x, y) in enumerate(loader):
        if i * batch_size >= sample_size:
            break
        x, y = x.to(device), y.to(device)
        model.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        for name, p in params.items():
            # Parameters that did not take part in the loss have no gradient.
            if p.grad is not None:
                fisher[name] += p.grad.detach() ** 2 / sample_size
    return fisher

# Online EWC penalty
def ewc_loss(model, fisher, prev_params, lambda_):
    """Fisher-weighted quadratic penalty on drift from stored parameters.

    Args:
        model: Module whose current parameters are penalised.
        fisher: dict mapping parameter name -> importance tensor.
        prev_params: Sequence of reference tensors, in the same order as
            ``model.named_parameters()`` (they are paired positionally).
        lambda_: Multiplier applied to the summed penalty.

    Returns:
        ``lambda_ * sum_n sum(fisher[n] * (p_n - prev_n) ** 2)``.
    """
    paired = zip(model.named_parameters(), prev_params)
    # Start from 0.0 so an empty parameter list still yields a float.
    penalty = sum(
        (
            (fisher[name] * (current - anchor).pow(2)).sum()
            for (name, current), anchor in paired
        ),
        0.0,
    )
    return lambda_ * penalty

# Task data with pixel permutations
def get_permuted_mnist(task_id, seed=0, root='./data'):
    """Build the MNIST train/test datasets for one pixel-permutation task.

    A fixed permutation of the 784 pixel positions is drawn from a private
    random generator seeded with ``seed + task_id`` and applied to every
    image of both splits.

    Args:
        task_id: Task index; combined with ``seed`` to pick the permutation.
        seed: Base seed shared by all tasks.
        root: Directory where MNIST is stored (downloaded if missing).

    Returns:
        ``(train, test)`` torchvision MNIST datasets whose items are
        ``(1, 28, 28)`` permuted image tensors with their labels.
    """
    # A dedicated generator keeps the permutation reproducible without
    # reseeding the global RNG, which would silently affect the caller's
    # weight initialisation and data shuffling.
    gen = torch.Generator()
    gen.manual_seed(seed + task_id)
    perm = torch.randperm(28 * 28, generator=gen)

    def permute_pixels(img):
        # Flatten, reorder the pixels, then restore the single-channel layout.
        return img.view(-1)[perm].view(1, 28, 28)

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(permute_pixels),
    ])
    train = datasets.MNIST(root=root, train=True, download=True, transform=transform)
    test = datasets.MNIST(root=root, train=False, download=True, transform=transform)
    return train, test

# Training loop for each task
def train_task(model, trainset, testsets, ewc_data, use_ewc=False, lambda_=0.4, gamma=1.0, epochs=3):
    """Train `model` on one task and return updated EWC bookkeeping.

    Args:
        model: Network to train in place.
        trainset: Dataset of ``(x, y)`` pairs for the current task.
        testsets: Unused here; kept so existing call sites keep working.
        ewc_data: ``None``/empty for the first task, otherwise a dict with
            ``'params'`` (list of reference tensors) and ``'fisher'``
            (dict name -> importance tensor) from the previous call.
        use_ewc: Add the EWC penalty to the loss when previous data exists.
        lambda_: Strength of the EWC penalty.
        gamma: Weight applied to the previous Fisher estimate when it is
            added to the newly estimated one.
        epochs: Number of passes over `trainset`.

    Returns:
        dict with ``'params'`` (detached copies of the trained parameters)
        and ``'fisher'`` (new estimate plus ``gamma`` times the previous one,
        if any).
    """
    model.train()

    # Send batches to the model's own device rather than assuming CUDA,
    # so training also runs on CPU-only machines.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    optimizer = optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()
    loader = DataLoader(trainset, batch_size=128, shuffle=True)

    for epoch in range(epochs):
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            output = model(x)
            loss = criterion(output, y)

            if use_ewc and ewc_data:
                loss = loss + ewc_loss(model, ewc_data['fisher'], ewc_data['params'], lambda_)

            loss.backward()
            optimizer.step()

    # Snapshot the trained parameters; they anchor the penalty for later tasks.
    new_params = [p.detach().clone() for p in model.parameters()]
    fisher = estimate_fisher(model, trainset, criterion)

    if ewc_data:
        # Incremental EWC update: fold the previous importance into the new one.
        for n in fisher:
            fisher[n] = fisher[n] + gamma * ewc_data['fisher'][n]
    return {'params': new_params, 'fisher': fisher}

# Evaluate on all tasks
def evaluate(model, testsets):
    """Compute classification accuracy of `model` on each dataset.

    Args:
        model: Network producing per-class scores; it is switched to eval
            mode and left in that mode.
        testsets: Iterable of datasets yielding ``(x, y)`` pairs.

    Returns:
        List of accuracies (fraction correct), one per dataset in order.
        An empty dataset yields ``nan`` instead of raising.
    """
    model.eval()

    # Evaluate on the model's own device rather than assuming CUDA.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    accs = []
    for testset in testsets:
        correct = 0
        total = 0
        loader = DataLoader(testset, batch_size=256)
        with torch.no_grad():
            for x, y in loader:
                x, y = x.to(device), y.to(device)
                pred = model(x).argmax(1)
                correct += (pred == y).sum().item()
                total += y.size(0)
        # Guard against division by zero when a dataset has no samples.
        accs.append(correct / total if total else float('nan'))
    return accs


ModuleNotFoundError: No module named 'torch'

In [None]:
# Setup
torch.manual_seed(0)
num_tasks = 5
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Two models: baseline and with EWC.
# The EWC model is a deep copy of the baseline so both start from identical
# weights; any accuracy gap then comes from the EWC penalty, not from a
# different random initialisation.
model_base = MLP().to(device)
model_ewc = copy.deepcopy(model_base)
ewc_data = None

# results_*[i] holds the accuracies on tasks 0..i measured after training task i,
# so row i has i + 1 entries.
results_base = []
results_ewc = []
testsets = []

for task_id in range(num_tasks):
    print(f"\n=== Task {task_id+1} ===")
    trainset, testset = get_permuted_mnist(task_id)
    testsets.append(testset)

    # Train baseline (no penalty, no memory of earlier tasks)
    train_task(model_base, trainset, testsets, None, use_ewc=False)
    acc_base = evaluate(model_base, testsets)
    results_base.append(acc_base)

    # Train EWC model, carrying parameters and Fisher estimates forward
    ewc_data = train_task(model_ewc, trainset, testsets, ewc_data, use_ewc=True)
    acc_ewc = evaluate(model_ewc, testsets)
    results_ewc.append(acc_ewc)


In [None]:
def _pad_accuracy_rows(rows, n_tasks):
    """Pack ragged per-stage accuracy lists into an (n_tasks, n_tasks) array.

    Row i holds the accuracies recorded after training stage i; tasks not yet
    seen at that stage (and stages that were never run) stay NaN.
    """
    padded = np.full((n_tasks, n_tasks), np.nan)
    for stage, row in enumerate(rows):
        padded[stage, :len(row)] = row
    return padded


# The result lists are ragged (row i has i + 1 entries), so np.array() on them
# cannot produce a 2-D array; pad explicitly instead. New names are used so
# the raw lists stay intact and this cell can be re-run safely.
acc_matrix_base = _pad_accuracy_rows(results_base, num_tasks)
acc_matrix_ewc = _pad_accuracy_rows(results_ewc, num_tasks)

fig, ax = plt.subplots(figsize=(10, 5))
for t in range(num_tasks):
    # Task t is first evaluated after stage t, hence the slice from row t.
    stages = range(t + 1, num_tasks + 1)
    ax.plot(stages, acc_matrix_base[t:, t], label=f'Base: Task {t+1}')
    ax.plot(stages, acc_matrix_ewc[t:, t], '--', label=f'EWC: Task {t+1}')
ax.set_xlabel('Task index (after)')
ax.set_ylabel('Accuracy on Task')
ax.set_title('Continual Learning: Baseline vs EWC')
ax.legend()
ax.grid(True)
fig.tight_layout()
plt.show()
