<a href="https://colab.research.google.com/github/LShahmiri/Continual-Learning/blob/main/ewc_mnist_fashionmnist.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import autograd
import numpy as np
from torch.utils.data import DataLoader

from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm

def get_accuracy(model, dataloader):
    model = model.eval()
    acc = 0
    for input, target in dataloader:
        o = model(input.to(device))
        acc += (o.argmax(dim=1).long() == target.to(device)).float().mean()
    return acc / len(dataloader)

class LinearLayer(nn.Module):
    # from https://github.com/shivamsaboo17/Overcoming-Catastrophic-forgetting-in-Neural-Networks/blob/master/elastic_weight_consolidation.py
    def __init__(self, input_dim, output_dim, act='relu', use_bn=False):
        super(LinearLayer, self).__init__()
        self.use_bn = use_bn
        self.lin = nn.Linear(input_dim, output_dim)
        self.act = nn.ReLU() if act == 'relu' else act
        if use_bn:
            self.bn = nn.BatchNorm1d(output_dim)
    def forward(self, x):
        if self.use_bn:
            return self.bn(self.act(self.lin(x)))
        return self.act(self.lin(x))

class Flatten(nn.Module):

    def forward(self, x):
        return x.view(x.shape[0], -1)

class Model(nn.Module):

    def __init__(self, num_inputs, num_hidden, num_outputs):
        super(Model, self).__init__()
        self.f1 = Flatten()
        self.lin1 = LinearLayer(num_inputs, num_hidden, use_bn=True)
        self.lin2 = LinearLayer(num_hidden, num_hidden, use_bn=True)
        self.lin3 = nn.Linear(num_hidden, num_outputs)

    def forward(self, x):
        return self.lin3(self.lin2(self.lin1(self.f1(x))))

# Load MNIST dataset, representint task A
mnist_train = datasets.MNIST("../data", train=True, download=True, transform=transforms.ToTensor())
mnist_test = datasets.MNIST("../data", train=False, download=True, transform=transforms.ToTensor())
train_loader = DataLoader(mnist_train, batch_size = 100, shuffle=True)
test_loader = DataLoader(mnist_test, batch_size = 100, shuffle=False)

# FashiomMNIST is task B
f_mnist_train = datasets.FashionMNIST("../data", train=True, download=True, transform=transforms.ToTensor())
f_mnist_test = datasets.FashionMNIST("../data", train=False, download=True, transform=transforms.ToTensor())
f_train_loader = DataLoader(f_mnist_train, batch_size = 100, shuffle=True)
f_test_loader = DataLoader(f_mnist_test, batch_size = 100, shuffle=False)

100%|██████████| 9.91M/9.91M [00:00<00:00, 20.7MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 497kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 4.61MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 13.4MB/s]
100%|██████████| 26.4M/26.4M [00:02<00:00, 12.9MB/s]
100%|██████████| 29.5k/29.5k [00:00<00:00, 200kB/s]
100%|██████████| 4.42M/4.42M [00:01<00:00, 3.75MB/s]
100%|██████████| 5.15k/5.15k [00:00<00:00, 27.8MB/s]


In [11]:
# parameters
EPOCHS = 4
lr=0.001
weight=100000
accuracies = {}

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

criterion = nn.CrossEntropyLoss()

# train model on task A
model = Model(28 * 28, 100, 10).to(device)
optimizer = optim.Adam(model.parameters(), lr)

for _ in range(EPOCHS):
    for input, target in tqdm(train_loader):
        output = model(input.to(device))
        loss = criterion(output, target.to(device))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

accuracies['mnist_initial'] = get_accuracy(model, test_loader)

100%|██████████| 600/600 [00:07<00:00, 76.71it/s]
100%|██████████| 600/600 [00:07<00:00, 79.74it/s]
100%|██████████| 600/600 [00:06<00:00, 88.11it/s]
100%|██████████| 600/600 [00:07<00:00, 80.79it/s]


In [12]:
accuracies

{'mnist_initial': tensor(0.9765, device='cuda:0')}

In [13]:
def ewc_loss(model, weight, estimated_fishers, estimated_means):
    losses = []
    for param_name, param in model.named_parameters():
        estimated_mean = estimated_means[param_name]
        estimated_fisher = estimated_fishers[param_name]
        losses.append((estimated_fisher * (param - estimated_mean) ** 2).sum())

    return (weight / 2) * sum(losses)

def estimate_ewc_params(model, train_ds, batch_size=100, num_batch=300, estimate_type='true'):
    estimated_mean = {}

    for param_name, param in model.named_parameters():
        estimated_mean[param_name] = param.data.clone()

    estimated_fisher = {}
    dl = DataLoader(train_ds, batch_size, shuffle=True)

    for n, p in model.named_parameters():
        estimated_fisher[n] = torch.zeros_like(p)

    model.eval()
    for i, (input, target) in enumerate(dl):
        if i > num_batch:
            break
        model.zero_grad()

        output = model(input.to(device))
        # https://www.inference.vc/on-empirical-fisher-information/ - more on this here
        if ESTIMATE_TYPE == 'empirical':
            # empirical
            label = target.to(device)
        else:
            # true estimate
            label = output.max(1)[1]

        loss = F.nll_loss(F.log_softmax(output, dim=1), label)
        loss.backward()

        # accumulate all the gradients
        for n, p in model.named_parameters():
            estimated_fisher[n].data += p.grad.data ** 2 / len(dl)

    estimated_fisher = {n: p for n, p in estimated_fisher.items()}
    return estimated_mean, estimated_fisher

In [14]:
# compute fisher and mean parameters for EWC loss
ESTIMATE_TYPE = 'true'
estimated_mean, estimated_fisher = estimate_ewc_params(model, mnist_train)

# Train task B fashion mnist
for _ in range(EPOCHS):
    for input, target in tqdm(f_train_loader):
        output = model(input.to(device))
        loss = ewc_loss(model, weight, estimated_fisher, estimated_mean) + criterion(output, target.to(device))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

accuracies['mnist_EWC'] = get_accuracy(model, test_loader)
accuracies['f_mnist_EWC'] = get_accuracy(model, f_test_loader)

100%|██████████| 600/600 [00:07<00:00, 75.76it/s]
100%|██████████| 600/600 [00:08<00:00, 69.60it/s]
100%|██████████| 600/600 [00:08<00:00, 69.88it/s]
100%|██████████| 600/600 [00:08<00:00, 72.16it/s]


In [15]:
accuracies

{'mnist_initial': tensor(0.9765, device='cuda:0'),
 'mnist_EWC': tensor(0.9690, device='cuda:0'),
 'f_mnist_EWC': tensor(0.8366, device='cuda:0')}