In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
import numpy as np

# 1. Define the Neural Network Model
class SimpleNN(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear.

    Args:
        input_size: Number of input features per sample.
        hidden_size: Width of the hidden layer.
        output_size: Number of output scores per sample.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Map a batch of feature vectors to raw output scores (no softmax applied)."""
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)

# 2. Implement Elastic Weight Consolidation (EWC)
class EWC:
    """Elastic Weight Consolidation penalty anchored at the model's current weights.

    On construction, snapshots the model's trainable parameters (the anchor
    ``theta*``) and estimates a diagonal empirical Fisher information: the mean
    of squared per-sample gradients of the cross-entropy loss on ``dataset``.
    Build it *after* training on a task, with that task's data, so that
    ``penalty()`` discourages later training from moving the weights that
    mattered for that task.

    Args:
        model: Network whose trainable parameters are anchored. ``penalty()``
            reads the same parameter objects live.
        dataset: ``(x, y)`` samples of the task being protected; each ``x`` is
            flattened per sample before the forward pass.
        lambda_ewc: Strength multiplying the Fisher-weighted squared distance.
    """

    def __init__(self, model, dataset, lambda_ewc):
        self.model = model
        self.dataset = dataset
        self.lambda_ewc = lambda_ewc
        self.params = {n: p for n, p in self.model.named_parameters() if p.requires_grad}
        self._means = {}
        self._fisher = {}
        self._compute_fisher_and_means()

    def _compute_fisher_and_means(self):
        """Snapshot parameter values and fill ``self._fisher`` with mean squared grads.

        The model is put in eval mode for the pass. Its previous train/eval
        mode is restored afterwards, and the gradients left by the pass are
        cleared.
        """
        was_training = self.model.training
        self.model.eval()
        for n, p in self.params.items():
            self._means[n] = p.detach().clone()
            self._fisher[n] = torch.zeros_like(p)

        criterion = nn.CrossEntropyLoss()
        # batch_size=1: the Fisher diagonal needs squared *per-sample* gradients.
        # Squaring a batch-averaged gradient gives a different quantity.
        # Sample order is irrelevant for a sum, so no shuffling.
        data_loader = DataLoader(self.dataset, batch_size=1, shuffle=False)
        num_samples = 0
        for x, y in data_loader:
            self.model.zero_grad()
            output = self.model(x.view(x.size(0), -1))
            criterion(output, y).backward()
            for n, p in self.params.items():
                # Parameters that are not on this loss's path receive no gradient.
                if p.grad is not None:
                    self._fisher[n] += p.grad.detach() ** 2
            num_samples += 1

        # Normalize exactly once by the sample count so the Fisher is a mean.
        # Dividing again by N would shrink it (and the penalty) by a further factor of N.
        if num_samples:
            for n in self._fisher:
                self._fisher[n] /= num_samples

        self.model.zero_grad()
        self.model.train(was_training)

    def penalty(self):
        """Return ``lambda_ewc * sum(F * (theta - theta*)**2)`` over anchored parameters.

        No 1/2 factor is applied; it is folded into ``lambda_ewc``. The result is
        built from the live parameters, so it can be backpropagated. It is a plain
        ``0`` scaled by ``lambda_ewc`` if no parameters are tracked.
        """
        loss = 0
        for n, p in self.params.items():
            _loss = self._fisher[n] * (p - self._means[n]) ** 2
            loss += _loss.sum()
        return self.lambda_ewc * loss

# 3. Training Loop with EWC and Learning Rate Scheduler
def train_ewc(model, datasets, epochs, learning_rate, lambda_ewc, *,
              scheduler_step_size=5, scheduler_gamma=0.1, batch_size=64):
    """Train ``model`` on each dataset in ``datasets`` in order, with EWC between tasks.

    After a task finishes, an ``EWC`` object is built from that task's data at
    the just-trained weights. This happens only if another task follows and
    ``lambda_ewc`` is non-zero. Every later task adds the penalties of all
    previously finished tasks to its cross-entropy loss.

    A fresh SGD optimizer and StepLR scheduler are created for each task. Each
    task therefore starts at ``learning_rate`` instead of inheriting the rate
    already decayed during earlier tasks.

    Args:
        model: Network trained in place; inputs are flattened per sample.
        datasets: Iterable of datasets yielding ``(inputs, labels)``, one per task.
        epochs: Number of epochs per task.
        learning_rate: Initial SGD learning rate for every task.
        lambda_ewc: EWC penalty strength; 0 disables the penalty.
        scheduler_step_size: StepLR ``step_size``, in epochs.
        scheduler_gamma: StepLR multiplicative decay factor.
        batch_size: Training mini-batch size.
    """
    datasets = list(datasets)  # len() is needed below; also accept iterators
    criterion = nn.CrossEntropyLoss()
    consolidated = []  # one EWC per finished task that later tasks must protect

    for i, dataset in enumerate(datasets):
        data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
        optimizer = optim.SGD(model.parameters(), lr=learning_rate)
        scheduler = optim.lr_scheduler.StepLR(
            optimizer, step_size=scheduler_step_size, gamma=scheduler_gamma
        )

        for epoch in range(epochs):
            model.train()
            total_loss = 0.0
            for inputs, labels in data_loader:
                optimizer.zero_grad()
                outputs = model(inputs.view(inputs.size(0), -1))
                loss = criterion(outputs, labels)
                for ewc in consolidated:
                    loss = loss + ewc.penalty()
                loss.backward()
                optimizer.step()
                total_loss += loss.item()

            # Step the learning rate scheduler once per epoch
            scheduler.step()

            mean_loss = total_loss / max(len(data_loader), 1)
            print(f"Task {i+1}, Epoch {epoch+1}, Loss: {mean_loss}, LR: {scheduler.get_last_lr()}")

        # The Fisher must come from *this* task's data at the weights that solve it.
        # Skip it when the penalty is disabled or no later task would use it.
        if lambda_ewc != 0 and i + 1 < len(datasets):
            consolidated.append(EWC(model, dataset, lambda_ewc))

# 4. Evaluation Metrics
def evaluate_task(model, dataset):
    """Return the classification accuracy of ``model`` on ``dataset`` as a percentage.

    The model is switched to eval mode and run without gradient tracking.
    Each input is flattened per sample. The predicted class is the index of the
    largest output score.
    """
    loader = DataLoader(dataset, batch_size=64, shuffle=False)
    model.eval()
    n_correct, n_seen = 0, 0
    with torch.no_grad():
        for batch_x, batch_y in loader:
            scores = model(batch_x.view(batch_x.size(0), -1))
            preds = scores.argmax(dim=1)
            n_seen += batch_y.size(0)
            n_correct += (preds == batch_y).sum().item()
    return 100 * n_correct / n_seen

def evaluate_all_tasks(model, datasets):
    """Evaluate ``model`` on every task, printing each accuracy and their mean.

    Returns:
        ``(accuracies, average_accuracy)``: the per-task accuracy percentages in
        ``datasets`` order, and their arithmetic mean.
    """
    per_task = []
    for task_no, task_data in enumerate(datasets, start=1):
        acc = evaluate_task(model, task_data)
        per_task.append(acc)
        print(f"Accuracy on Task {task_no}: {acc:.2f}%")
    mean_acc = sum(per_task) / len(per_task)
    print(f"Average Accuracy: {mean_acc:.2f}%")
    return per_task, mean_acc

def calculate_forgetting(initial_accuracies, final_accuracies):
    """Compute the per-task accuracy drop between two evaluations, and its mean.

    ``forgetting[k] = initial_accuracies[k] - final_accuracies[k]``. A positive
    value means task k got worse. Prints the mean and returns
    ``(forgetting, avg_forgetting)``.
    """
    per_task = [before - final_accuracies[k] for k, before in enumerate(initial_accuracies)]
    mean_drop = sum(per_task) / len(per_task)
    print(f"Average Forgetting: {mean_drop:.2f}%")
    return per_task, mean_drop

def calculate_forward_transfer(model, datasets, *, input_size=784, hidden_size=256,
                               output_size=2, epochs=10, learning_rate=0.01):
    """Compare ``model`` with single-task baselines on every task after the first.

    For each task i >= 1, a fresh ``SimpleNN`` is trained on ``datasets[i]``
    alone (no EWC) and evaluated on that same dataset. The reported transfer is
    ``model``'s accuracy on ``datasets[i]`` minus the baseline's.

    NOTE(review): ``model`` has already trained on these tasks, so this is a gap
    to an independently trained baseline. It is not the usual forward-transfer
    metric (accuracy *before* learning a task vs. a random-init model). Confirm
    which definition is intended.

    Keyword args configure the baseline network and its training. Their
    defaults match the values previously hard-coded here.

    Returns:
        ``(forward_transfer, avg_forward_transfer)``. The average is NaN when
        fewer than two datasets are given, because no task is measured.
    """
    forward_transfer = []
    for i in range(1, len(datasets)):
        fresh_model = SimpleNN(input_size=input_size, hidden_size=hidden_size,
                               output_size=output_size)
        train_ewc(fresh_model, [datasets[i]], epochs=epochs,
                  learning_rate=learning_rate, lambda_ewc=0)
        fresh_accuracy = evaluate_task(fresh_model, datasets[i])

        current_accuracy = evaluate_task(model, datasets[i])
        transfer = current_accuracy - fresh_accuracy
        forward_transfer.append(transfer)
        print(f"Forward Transfer for Task {i+1}: {transfer:.2f}%")

    # With a single task there is nothing to average; avoid ZeroDivisionError.
    if forward_transfer:
        avg_forward_transfer = sum(forward_transfer) / len(forward_transfer)
    else:
        avg_forward_transfer = float("nan")
    print(f"Average Forward Transfer: {avg_forward_transfer:.2f}%")
    return forward_transfer, avg_forward_transfer

# 5. Example Usage with Fashion-MNIST Dataset

SEED = 0
EPOCHS = 10
LEARNING_RATE = 0.01
LAMBDA_EWC = 0.4

# Each task is a binary problem given as (negative, positive) Fashion-MNIST class ids:
# Task 1: Classifying T-shirts/Tops (label 0) vs Trousers (label 1)
# Task 2: Classifying Pullovers (label 2) vs Dresses (label 3)
TASK_LABEL_PAIRS = [(0, 1), (2, 3)]


def _split_binary_tasks(base, label_pairs):
    """Return one ``Subset`` of ``base`` per label pair, with targets relabeled to 0/1.

    The positive class of each pair becomes 1 and the negative class becomes 0.
    ``base.targets`` (assumed to be a torch tensor, as for torchvision's
    FashionMNIST) is modified in place. Every task's indices are therefore
    collected before any relabeling, because afterwards the original class ids
    no longer identify the tasks. Assumes the label pairs are disjoint.
    """
    targets = base.targets
    index_sets = [
        torch.nonzero((targets == neg) | (targets == pos)).squeeze(1)
        for neg, pos in label_pairs
    ]
    for (_, pos), idx in zip(label_pairs, index_sets):
        # One vectorized assignment instead of a per-sample Python loop.
        targets[idx] = (targets[idx] == pos).to(targets.dtype)
    return [Subset(base, idx.tolist()) for idx in index_sets]


def _train_from_seed(tasks):
    """Seed torch, build a fresh SimpleNN, train it on ``tasks`` in order, and return it."""
    torch.manual_seed(SEED)
    net = SimpleNN(input_size=784, hidden_size=256, output_size=2)
    train_ewc(net, tasks, epochs=EPOCHS, learning_rate=LEARNING_RATE, lambda_ewc=LAMBDA_EWC)
    return net


# Load Fashion-MNIST; accuracy metrics use the held-out test split
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
fashion_mnist = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)
fashion_mnist_test = datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform)

dataset_task1, dataset_task2 = _split_binary_tasks(fashion_mnist, TASK_LABEL_PAIRS)
test_task1, test_task2 = _split_binary_tasks(fashion_mnist_test, TASK_LABEL_PAIRS)
indices_task1, indices_task2 = dataset_task1.indices, dataset_task2.indices

# NOTE: this rebinds ``datasets``, shadowing the torchvision ``datasets`` module
# imported at the top of the file. The name is kept for code that may read it.
datasets = [dataset_task1, dataset_task2]
test_datasets = [test_task1, test_task2]

# Initialize the model and train on all tasks in sequence
model = _train_from_seed(datasets)

# Accuracy on each task right after it was learned. Re-running the same seeded
# protocol on the prefix datasets[:k + 1] reproduces the model's state after
# task k. This assumes deterministic kernels (e.g. CPU); TODO confirm on GPU.
# After the last task, that state is ``model`` itself.
initial_accuracies = []
for k, test_set in enumerate(test_datasets):
    learner = model if k == len(datasets) - 1 else _train_from_seed(datasets[: k + 1])
    initial_accuracies.append(evaluate_task(learner, test_set))

# Evaluate performance after learning all tasks
final_accuracies = evaluate_all_tasks(model, test_datasets)[0]
forgetting, avg_forgetting = calculate_forgetting(initial_accuracies, final_accuracies)
# calculate_forward_transfer trains its baselines on the datasets it receives and
# evaluates on those same datasets, so it is given the training splits.
forward_transfer, avg_forward_transfer = calculate_forward_transfer(model, datasets)

# Summary of results
print(f"Average Accuracy: {sum(final_accuracies)/len(final_accuracies):.2f}%")
print(f"Average Forgetting: {avg_forgetting:.2f}%")
print(f"Average Forward Transfer: {avg_forward_transfer:.2f}%")


Task 1, Epoch 1, Loss: 0.16201461815929158, LR: [0.01]
Task 1, Epoch 2, Loss: 0.0644669945331964, LR: [0.01]
Task 1, Epoch 3, Loss: 0.0517908925032045, LR: [0.01]
Task 1, Epoch 4, Loss: 0.04648446631023383, LR: [0.01]
Task 1, Epoch 5, Loss: 0.042838892379300074, LR: [0.001]
Task 1, Epoch 6, Loss: 0.041561374313227754, LR: [0.001]
Task 1, Epoch 7, Loss: 0.04076181133921714, LR: [0.001]
Task 1, Epoch 8, Loss: 0.04047835495402204, LR: [0.001]
Task 1, Epoch 9, Loss: 0.04026413850109786, LR: [0.001]
Task 1, Epoch 10, Loss: 0.04003787631827823, LR: [0.0001]
Task 2, Epoch 1, Loss: 0.7996512993853143, LR: [0.0001]
Task 2, Epoch 2, Loss: 0.5752830529149543, LR: [0.0001]
Task 2, Epoch 3, Loss: 0.4648738855377157, LR: [0.0001]
Task 2, Epoch 4, Loss: 0.3955476814286506, LR: [0.0001]
Task 2, Epoch 5, Loss: 0.34809246453198983, LR: [1e-05]
Task 2, Epoch 6, Loss: 0.3267561170331975, LR: [1e-05]
Task 2, Epoch 7, Loss: 0.32395265655631716, LR: [1e-05]
Task 2, Epoch 8, Loss: 0.3212822937267892, LR: [1e-