In [16]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader, Subset
import numpy as np
import matplotlib.pyplot as plt

In [17]:
def get_split_mnist(task_num=2):
    """Build train/test DataLoaders for one half of a two-task MNIST split.

    Task 1 keeps the digits 0-4 and task 2 keeps the digits 5-9. Labels are
    left as the original digit values; they are not remapped per task.

    Note: the MNIST files are read from (and, if missing, downloaded to)
    ``./data``.

    Args:
        task_num: Which half of the split to return, 1 or 2.

    Returns:
        A ``(train_loader, test_loader)`` tuple. The training loader is
        shuffled with batch size 64; the test loader is unshuffled with batch
        size 1000.

    Raises:
        ValueError: If ``task_num`` is neither 1 nor 2.
    """
    task_labels = {1: (0, 1, 2, 3, 4), 2: (5, 6, 7, 8, 9)}
    # Validate before touching the dataset so a bad argument fails immediately
    # with a clear message instead of an UnboundLocalError after loading.
    if task_num not in task_labels:
        raise ValueError(f"task_num must be 1 or 2, got {task_num!r}")
    wanted_labels = torch.tensor(task_labels[task_num])

    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
    mnist_train = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    mnist_test = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

    def _indices_with_labels(dataset):
        # One vectorised membership test instead of a Python-level loop that
        # compares a tensor element per sample.
        targets = torch.as_tensor(dataset.targets)
        return torch.isin(targets, wanted_labels).nonzero(as_tuple=True)[0].tolist()

    train_dataset = Subset(mnist_train, _indices_with_labels(mnist_train))
    test_dataset = Subset(mnist_test, _indices_with_labels(mnist_test))

    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)

    return train_loader, test_loader

In [18]:
class SimpleNN(nn.Module):
    """Three-layer fully connected classifier: 784 -> 256 -> 256 -> 10.

    The input is reshaped to rows of 28*28 values, passed through two
    ReLU-activated hidden layers, and projected to 10 raw (un-normalised)
    output scores.
    """

    def __init__(self):
        super().__init__()
        flat_pixels = 28 * 28
        hidden_units = 256
        self.fc1 = nn.Linear(flat_pixels, hidden_units)
        self.fc2 = nn.Linear(hidden_units, hidden_units)
        self.fc3 = nn.Linear(hidden_units, 10)

    def forward(self, x):
        """Return the 10 output scores for each flattened row of ``x``."""
        activations = x.view(-1, 28 * 28)
        # Both hidden layers share the same linear -> ReLU pattern.
        for hidden_layer in (self.fc1, self.fc2):
            activations = torch.relu(hidden_layer(activations))
        # The final layer is left without an activation.
        return self.fc3(activations)

In [1]:
class EWC:
    """Elastic Weight Consolidation penalty anchored at a model's current weights.

    On construction this snapshots every trainable parameter of ``model`` and
    estimates one importance value per parameter element from ``dataset``.
    ``penalty`` then returns the importance-weighted squared distance between
    a model's parameters and that snapshot.

    Args:
        model: Module whose trainable parameters are snapshotted and whose
            gradients are used for the importance estimate.
        dataset: Sized iterable yielding ``(inputs, labels)`` batches, e.g. a
            ``DataLoader``. ``len(dataset)`` is used as the normaliser, so for
            a ``DataLoader`` this is the number of batches.
    """

    def __init__(self, model, dataset):
        self.model = model
        self.dataset = dataset
        # Detached copies: the anchor must not move when the model trains on.
        self.params = {n: p.clone().detach() for n, p in self.model.named_parameters() if p.requires_grad}
        self._precision_matrices = self._calculate_fisher()

    def _calculate_fisher(self):
        """Estimate per-parameter importance from squared loss gradients.

        For each batch, the cross-entropy loss against the batch's own labels
        is back-propagated and the squared gradients are accumulated, divided
        by ``len(self.dataset)``. The gradient is that of the batch-mean loss,
        so this is a batch-level approximation rather than a per-sample one.

        The model is put in eval mode for the estimate; its previous
        train/eval mode is restored and its gradients are cleared afterwards.

        Returns:
            Dict mapping parameter name to a tensor shaped like the parameter.
        """
        # zeros_like keeps each accumulator on the parameter's device and
        # dtype; torch.zeros(p.size()) would always be a default-dtype CPU
        # tensor.
        precision_matrices = {
            n: torch.zeros_like(p)
            for n, p in self.model.named_parameters()
            if p.requires_grad
        }
        was_training = self.model.training
        self.model.eval()
        try:
            num_batches = len(self.dataset)
            for inputs, labels in self.dataset:
                self.model.zero_grad()
                outputs = self.model(inputs)
                loss = torch.nn.functional.cross_entropy(outputs, labels)
                loss.backward()

                for n, p in self.model.named_parameters():
                    # A trainable parameter that did not take part in the
                    # forward pass has no gradient; its importance stays zero.
                    if p.requires_grad and p.grad is not None:
                        precision_matrices[n] += p.grad.detach() ** 2 / num_batches
        finally:
            # Do not leak the estimation gradients or the eval-mode switch
            # into whatever the caller does next.
            self.model.zero_grad()
            self.model.train(was_training)

        return precision_matrices

    def penalty(self, model):
        """Return the importance-weighted squared drift of ``model`` from the anchor.

        Args:
            model: Module whose trainable parameter names match those seen at
                construction time.

        Returns:
            A scalar tensor (or the int 0 if ``model`` has no trainable
            parameters). Multiply by the desired EWC weight before adding it
            to a task loss.
        """
        loss = 0
        for n, p in model.named_parameters():
            if p.requires_grad:
                drift = (p - self.params[n]) ** 2
                loss = loss + (self._precision_matrices[n] * drift).sum()
        return loss

In [20]:
def train_model(model, optimizer, criterion, dataloader, ewc=None, ewc_lambda=0):
    """Run one pass over ``dataloader``, updating ``model`` after every batch.

    Args:
        model: Module to train; it is switched to training mode.
        optimizer: Optimizer stepping ``model``'s parameters.
        criterion: Callable ``(outputs, labels) -> scalar loss tensor``.
        dataloader: Iterable yielding ``(inputs, labels)`` batches.
        ewc: Optional object exposing ``penalty(model)``; when given, its
            penalty scaled by ``ewc_lambda`` is added to each batch loss.
        ewc_lambda: Weight applied to the EWC penalty.

    Returns:
        The mean per-batch loss (including any penalty term) as a float, or
        0.0 if ``dataloader`` yielded no batches.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    for inputs, labels in dataloader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        # Test for presence, not truthiness: a penalty object that happens to
        # evaluate as falsy must still be applied.
        if ewc is not None:
            loss = loss + ewc_lambda * ewc.penalty(model)

        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    # Count batches while iterating so an empty loader does not divide by
    # zero and un-sized iterables are accepted too.
    if num_batches == 0:
        return 0.0
    return total_loss / num_batches

def evaluate_model(model, dataloader):
    """Return the classification accuracy of ``model`` over ``dataloader``.

    The model is switched to eval mode and left there; gradients are not
    tracked during the pass.

    Args:
        model: Module producing one row of class scores per input.
        dataloader: Iterable yielding ``(inputs, labels)`` batches.

    Returns:
        Fraction of samples whose highest-scoring class equals the label, as
        a float in [0, 1]; 0.0 if no samples were seen.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in dataloader:
            # Index of the largest score along the class dimension.
            predicted = model(inputs).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    # Guard the empty case instead of raising ZeroDivisionError.
    if total == 0:
        return 0.0
    return correct / total

In [21]:
# Seed PyTorch's global RNG so weight initialisation and the shuffled training
# order repeat from run to run.
SEED = 0
torch.manual_seed(SEED)

# Create the model
model = SimpleNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Train on Task 1 (a single pass over the task-1 training loader)
train_loader_1, test_loader_1 = get_split_mnist(task_num=1)
train_model(model, optimizer, criterion, train_loader_1)

# Evaluate on Task 1
accuracy_1 = evaluate_model(model, test_loader_1)
print(f'Test Accuracy on Task 1: {accuracy_1 * 100:.2f}%')

# Apply EWC for Task 2: anchor the penalty at the weights learned on Task 1,
# using the Task 1 training batches for the importance estimate.
ewc = EWC(model, train_loader_1)

# Train on Task 2 (a single pass, with the EWC penalty added to the loss)
train_loader_2, test_loader_2 = get_split_mnist(task_num=2)
# NOTE(review): in the run saved with this notebook, Task 1 accuracy fell to
# 0.00% after this step, so the penalty at this weight did not prevent
# forgetting. The weight probably needs tuning - confirm experimentally.
ewc_lambda = 0.4
train_model(model, optimizer, criterion, train_loader_2, ewc, ewc_lambda)

# Evaluate on Task 2
accuracy_2 = evaluate_model(model, test_loader_2)
print(f'Test Accuracy on Task 2: {accuracy_2 * 100:.2f}%')

# Evaluate on Task 1 again to measure how much of it was forgotten
accuracy_1_revisited = evaluate_model(model, test_loader_1)
print(f'Test Accuracy on Task 1 (revisited): {accuracy_1_revisited * 100:.2f}%')

Test Accuracy on Task 1: 95.97%
Test Accuracy on Task 2: 91.73%
Test Accuracy on Task 1 (revisited): 0.00%
