In [1]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms

In [2]:
class SimpleNN(nn.Module):
    """Two-layer MLP classifier: flatten -> Linear -> ReLU -> Linear.

    Args:
        input_size: number of features per sample after flattening
            (28 * 28 by default).
        hidden_size: width of the hidden layer.
        output_size: number of output scores produced per sample.
    """

    def __init__(self, input_size=28 * 28, hidden_size=256, output_size=10):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Return raw scores (no softmax) of shape (batch, output_size)."""
        batch_size = x.shape[0]
        flat = x.view(batch_size, -1)  # collapse all non-batch dims
        hidden = self.relu(self.fc1(flat))
        return self.fc2(hidden)

In [4]:
# Device configuration: use the GPU whenever CUDA is available, else the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [7]:
class EWC:
    """Elastic Weight Consolidation (EWC) regulariser anchored to one task.

    On construction, a diagonal Fisher estimate is computed from `dataloader`
    and the model's current trainable parameter values are stored as anchors.
    `penalty(model)` then returns
    ``importance * sum_n sum(F_n * (p_n - anchor_n) ** 2)``.

    Args:
        model: network to protect; its current weights become the anchors.
        dataloader: iterable of ``(inputs, labels)`` batches used for the
            Fisher estimate; must support ``len()``.
        importance: scalar weight applied to the quadratic penalty.
    """

    def __init__(self, model, dataloader, importance=1000):
        self.model = model
        self.dataloader = dataloader
        self.importance = importance
        # Only trainable parameters are regularised.
        self.params = {n: p for n, p in model.named_parameters()
                       if p.requires_grad}
        self.means = {}
        self.fisher_matrix = {}

        # Compute Fisher Information Matrix
        self._compute_fisher_matrix()

    def _compute_fisher_matrix(self):
        """Estimate the diagonal Fisher and snapshot the anchor parameters.

        The estimate is the average over batches of the squared gradient of
        the batch-mean cross-entropy loss (a batch-level empirical
        approximation, not a per-sample Fisher). Leaves the model in eval
        mode with its gradients cleared.
        """
        self.model.eval()
        fisher = {n: torch.zeros_like(p) for n, p in self.params.items()}
        n_batches = len(self.dataloader)
        # Use the device the parameters live on rather than a global `device`.
        param_device = next(
            (p.device for p in self.params.values()), torch.device('cpu'))
        criterion = nn.CrossEntropyLoss()

        for inputs, labels in self.dataloader:
            inputs, labels = inputs.to(param_device), labels.to(param_device)
            self.model.zero_grad()
            loss = criterion(self.model(inputs), labels)
            loss.backward()

            for n, p in self.params.items():
                # Parameters that did not take part in the forward pass have no grad.
                if p.grad is not None:
                    fisher[n] += p.grad.detach() ** 2 / n_batches

        # Do not leave the last Fisher batch's gradients on the model.
        self.model.zero_grad()
        self.fisher_matrix = fisher

        # Snapshot anchors DETACHED from autograd. A plain `p.clone()` stays
        # connected to `p`, so the penalty gradient flowing back through the
        # clone exactly cancels the direct gradient and EWC would have no effect.
        for n, p in self.params.items():
            self.means[n] = p.detach().clone()

    def penalty(self, model):
        """Return the scaled EWC penalty of `model` against the stored anchors.

        Only parameters whose names were trainable at construction time are
        included. Returns a scalar tensor, or 0 when no parameter name matches.
        """
        penalty = 0
        for n, p in model.named_parameters():
            if n in self.fisher_matrix:
                fisher = self.fisher_matrix[n]
                mean = self.means[n]
                penalty = penalty + (fisher * (p - mean) ** 2).sum()
        return self.importance * penalty

# Training function


def train(model, dataloader, optimizer, criterion, ewc=None, device=None):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: network to train; switched to train mode.
        dataloader: iterable of ``(inputs, labels)`` batches supporting ``len()``.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss function called as ``criterion(outputs, labels)``.
        ewc: optional object with a ``penalty(model)`` method whose value is
            added to every batch loss.
        device: where to move each batch. Defaults to the device of the
            model's first parameter (CPU for a parameterless model).

    Returns:
        Average over batches of the total loss (task loss plus EWC penalty),
        so values are not directly comparable between runs with and without EWC.
    """
    if device is None:
        device = next((p.device for p in model.parameters()), torch.device('cpu'))
    model.train()
    total_loss = 0
    for inputs, labels in dataloader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        if ewc is not None:
            # Out-of-place add: avoid mutating the criterion's output tensor.
            loss = loss + ewc.penalty(model)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(dataloader)

# Testing function


def test(model, dataloader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total

In [9]:
# Load the MNIST train and test splits; the only transform applied is
# ToTensor (no normalisation).
transform = transforms.Compose([transforms.ToTensor()])
train_dataset, test_dataset = (
    datasets.MNIST(root='./data', train=is_train,
                   transform=transform, download=True)
    for is_train in (True, False)
)

# Split the ten digit labels into five class-incremental tasks of two
# consecutive labels each: [0, 1], [2, 3], ..., [8, 9].
tasks = [[first, first + 1] for first in range(0, 10, 2)]

In [None]:

# Incremental Learning Loop
# Hyperparameters for the class-incremental EWC run.
SEED = 0
BATCH_SIZE = 64
EPOCHS_PER_TASK = 5
LEARNING_RATE = 0.001
EWC_IMPORTANCE = 5000


def task_subset(dataset, classes):
    """Return a lazy Subset of `dataset` holding only samples labelled in `classes`.

    Filters on `dataset.targets` so images are not loaded/transformed just to
    read their labels.
    """
    wanted = set(classes)
    labels = torch.as_tensor(dataset.targets).tolist()
    indices = [idx for idx, label in enumerate(labels) if label in wanted]
    return Subset(dataset, indices)


torch.manual_seed(SEED)  # reproducible weight init and shuffling
model = SimpleNN().to(device)
optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss()
ewc = None

# Per-task test loaders never change, so build them once up front.
test_loaders = [
    DataLoader(task_subset(test_dataset, classes),
               batch_size=BATCH_SIZE, shuffle=False)
    for classes in tasks
]

for task_idx, task_classes in enumerate(tasks):
    print(f"Task {task_idx + 1}: Classes {task_classes}")

    # Create the training loader for the current task
    task_train_dataset = task_subset(train_dataset, task_classes)
    task_train_loader = DataLoader(
        task_train_dataset, batch_size=BATCH_SIZE, shuffle=True)

    # Train on the current task
    for epoch in range(EPOCHS_PER_TASK):
        train_loss = train(model, task_train_loader, optimizer, criterion, ewc)
        print(f"Epoch {epoch+1}, Loss: {train_loss:.4f}")

    # Evaluate on all tasks seen so far
    for past_task_idx in range(task_idx + 1):
        past_task_classes = tasks[past_task_idx]
        accuracy = test(model, test_loaders[past_task_idx])
        print(
            f"Accuracy on Task {past_task_idx + 1} ({past_task_classes}): {accuracy:.2f}%")

    # Update EWC. This replaces the previous regulariser, so later training only
    # uses the most recently finished task's Fisher estimate and anchors.
    ewc = EWC(model, task_train_loader, EWC_IMPORTANCE)

In [None]:

# Incremental Learning Loop
# NOTE(review): this cell re-runs the same experiment as the previous cell;
# consider deleting one copy or wrapping the loop in a function parameterised
# by the hyperparameters below.
SEED = 0
BATCH_SIZE = 64
EPOCHS_PER_TASK = 5
LEARNING_RATE = 0.001
EWC_IMPORTANCE = 5000


def task_subset(dataset, classes):
    """Return a lazy Subset of `dataset` holding only samples labelled in `classes`.

    Filters on `dataset.targets` so images are not loaded/transformed just to
    read their labels.
    """
    wanted = set(classes)
    labels = torch.as_tensor(dataset.targets).tolist()
    indices = [idx for idx, label in enumerate(labels) if label in wanted]
    return Subset(dataset, indices)


torch.manual_seed(SEED)  # reproducible weight init and shuffling
model = SimpleNN().to(device)
optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss()
ewc = None

# Per-task test loaders never change, so build them once up front.
test_loaders = [
    DataLoader(task_subset(test_dataset, classes),
               batch_size=BATCH_SIZE, shuffle=False)
    for classes in tasks
]

for task_idx, task_classes in enumerate(tasks):
    print(f"Task {task_idx + 1}: Classes {task_classes}")

    # Create the training loader for the current task
    task_train_dataset = task_subset(train_dataset, task_classes)
    task_train_loader = DataLoader(
        task_train_dataset, batch_size=BATCH_SIZE, shuffle=True)

    # Train on the current task
    for epoch in range(EPOCHS_PER_TASK):
        train_loss = train(model, task_train_loader, optimizer, criterion, ewc)
        print(f"Epoch {epoch+1}, Loss: {train_loss:.4f}")

    # Evaluate on all tasks seen so far
    for past_task_idx in range(task_idx + 1):
        past_task_classes = tasks[past_task_idx]
        accuracy = test(model, test_loaders[past_task_idx])
        print(
            f"Accuracy on Task {past_task_idx + 1} ({past_task_classes}): {accuracy:.2f}%")

    # Update EWC. This replaces the previous regulariser, so later training only
    # uses the most recently finished task's Fisher estimate and anchors.
    ewc = EWC(model, task_train_loader, EWC_IMPORTANCE)