<a href="https://colab.research.google.com/github/LShahmiri/Continual-Learning/blob/main/ewc_cifar10_fmnist.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Data Loaders
transform_cifar = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

transform_fmnist = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.Grayscale(num_output_channels=3),  # match CIFAR input
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

cifar_train = datasets.CIFAR10("../data", train=True, download=True, transform=transform_cifar)
cifar_test = datasets.CIFAR10("../data", train=False, download=True, transform=transform_cifar)
train_loader = DataLoader(cifar_train, batch_size=100, shuffle=True)
test_loader = DataLoader(cifar_test, batch_size=100, shuffle=False)

fmnist_train = datasets.FashionMNIST("../data", train=True, download=True, transform=transform_fmnist)
fmnist_test = datasets.FashionMNIST("../data", train=False, download=True, transform=transform_fmnist)
f_train_loader = DataLoader(fmnist_train, batch_size=100, shuffle=True)
f_test_loader = DataLoader(fmnist_test, batch_size=100, shuffle=False)

# Model
class SimpleCNN(nn.Module):
    def __init__(self, num_classes=10):
        super(SimpleCNN, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),  # input 3x32x32
            nn.ReLU(),
            nn.MaxPool2d(2),  # 32x16x16
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)   # 64x8x8
        )
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 8 * 8, 256),
            nn.ReLU(),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        x = self.conv(x)
        return self.fc(x)

# Accuracy
def get_accuracy(model, dataloader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for input, target in dataloader:
            output = model(input.to(device))
            pred = output.argmax(dim=1)
            correct += (pred == target.to(device)).sum().item()
            total += target.size(0)
    return correct / total

# EWC
def ewc_loss(model, weight, fisher, means):
    loss = 0
    for name, param in model.named_parameters():
        if param.requires_grad:
            loss += (fisher[name] * (param - means[name])**2).sum()
    return (weight / 2) * loss

def estimate_ewc_params(model, dataset, num_batches=300):
    fisher = {}
    means = {}
    model.eval()

    for name, param in model.named_parameters():
        fisher[name] = torch.zeros_like(param)
        means[name] = param.data.clone()

    dataloader = DataLoader(dataset, batch_size=100, shuffle=True)

    for i, (input, target) in enumerate(dataloader):
        if i > num_batches:
            break
        model.zero_grad()
        output = model(input.to(device))
        label = output.max(1)[1]
        loss = F.nll_loss(F.log_softmax(output, dim=1), label)
        loss.backward()
        for name, param in model.named_parameters():
            if param.grad is not None:
                fisher[name] += param.grad.data.clone().pow(2) / len(dataloader)

    return means, fisher

# Train CIFAR-10 (Task A)
EPOCHS = 5
model = SimpleCNN().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

for _ in range(EPOCHS):
    model.train()
    for input, target in tqdm(train_loader, desc="Training Task A"):
        output = model(input.to(device))
        loss = criterion(output, target.to(device))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

acc1 = get_accuracy(model, test_loader)
print(f"\nAccuracy on CIFAR-10 after Task A: {acc1:.4f}")

# Estimate Fisher and Means
means, fisher = estimate_ewc_params(model, cifar_train)

# Train Fashion-MNIST (Task B) with EWC
for _ in range(EPOCHS):
    model.train()
    for input, target in tqdm(f_train_loader, desc="Training Task B"):
        output = model(input.to(device))
        loss = criterion(output, target.to(device)) + ewc_loss(model, weight=5000, fisher=fisher, means=means)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

acc2 = get_accuracy(model, test_loader)
acc3 = get_accuracy(model, f_test_loader)
print(f"\nAccuracy on CIFAR-10 after Task B with EWC: {acc2:.4f}")
print(f"Accuracy on Fashion-MNIST after Task B: {acc3:.4f}")


100%|██████████| 170M/170M [00:02<00:00, 67.0MB/s]
100%|██████████| 26.4M/26.4M [00:01<00:00, 17.4MB/s]
100%|██████████| 29.5k/29.5k [00:00<00:00, 302kB/s]
100%|██████████| 4.42M/4.42M [00:00<00:00, 5.52MB/s]
100%|██████████| 5.15k/5.15k [00:00<00:00, 13.5MB/s]
Training Task A: 100%|██████████| 500/500 [00:58<00:00,  8.50it/s]
Training Task A: 100%|██████████| 500/500 [00:59<00:00,  8.44it/s]
Training Task A: 100%|██████████| 500/500 [00:56<00:00,  8.87it/s]
Training Task A: 100%|██████████| 500/500 [00:56<00:00,  8.93it/s]
Training Task A: 100%|██████████| 500/500 [00:56<00:00,  8.89it/s]



Accuracy on CIFAR-10 after Task A: 0.7180


Training Task B: 100%|██████████| 600/600 [01:13<00:00,  8.11it/s]
Training Task B: 100%|██████████| 600/600 [01:13<00:00,  8.16it/s]
Training Task B: 100%|██████████| 600/600 [01:17<00:00,  7.77it/s]
Training Task B: 100%|██████████| 600/600 [01:16<00:00,  7.88it/s]
Training Task B: 100%|██████████| 600/600 [01:16<00:00,  7.86it/s]



Accuracy on CIFAR-10 after Task B with EWC: 0.5901
Accuracy on Fashion-MNIST after Task B: 0.8915
