# Elastic Weight Consolidation

In [2]:
import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Normalize
from torch.utils.data import Subset

## Data Preparation

In [3]:
transform = torchvision.transforms.Compose([ToTensor(), Normalize((0.5), (0.5))])

train_dataset = datasets.FashionMNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.FashionMNIST('./data', train=False, transform=transform)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 26421880/26421880 [00:01<00:00, 16215374.24it/s]


Extracting ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 29515/29515 [00:00<00:00, 273964.31it/s]


Extracting ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 4422102/4422102 [00:00<00:00, 4978216.55it/s]


Extracting ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 5148/5148 [00:00<00:00, 5935205.33it/s]

Extracting ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw






In [4]:
# Splitting datasets for two tasks
train_dataset_task1 = Subset(train_dataset, [i for i in range(len(train_dataset)) if train_dataset.targets[i] <= 5])
train_dataset_task2 = Subset(train_dataset, [i for i in range(len(train_dataset)) if train_dataset.targets[i] > 5])
test_dataset_task1 = Subset(test_dataset, [i for i in range(len(test_dataset)) if test_dataset.targets[i] <= 5])
test_dataset_task2 = Subset(test_dataset, [i for i in range(len(test_dataset)) if test_dataset.targets[i] > 5])

In [5]:
unique_labels_1 = set()
unique_labels_2 = set()

for _, target in train_dataset_task1:
    unique_labels_1.add(target)

for _, target in train_dataset_task2:
    unique_labels_2.add(target)

print(f"First: {unique_labels_1}")
print(f"Second: {unique_labels_2}")

First: {0, 1, 2, 3, 4, 5}
Second: {8, 9, 6, 7}


In [6]:
train_dataloader_task1 = DataLoader(train_dataset_task1, batch_size=64, shuffle=True)
test_dataloader_task1 = DataLoader(test_dataset_task1, batch_size=256, shuffle=True)

for X, y in train_dataloader_task1:
  print(f"Shape of X [N, C, H, W]: {X.shape}")
  print(f"SHape of y: {y.shape}, dtype: {y.dtype}")
  break

Shape of X [N, C, H, W]: torch.Size([64, 1, 28, 28])
SHape of y: torch.Size([64]), dtype: torch.int64


In [7]:
train_dataloader_task2 = DataLoader(train_dataset_task2, batch_size=64, shuffle=True)
test_dataloader_task2 = DataLoader(test_dataset_task2, batch_size=256, shuffle=True)

In [8]:
device = ('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using {device} device")

Using cuda device


## NN Architecture

In [9]:
class NeuralNetwork(nn.Module):
    def __init__(self, num_classes=10, hidden_size=512):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()

        self.fc1 = nn.Linear(28*28, hidden_size)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.relu2 = nn.ReLU()
        self.classifier = nn.Linear(hidden_size, num_classes)

        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, nonlinearity='sigmoid')
            elif isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, nonlinearity='relu')

    def forward(self, x):
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.fc2(x)
        x = self.relu2(x)
        logits = self.classifier(x)

        return logits

In [10]:
def train(dataloader, model, loss_fn, optimizer):
  size = len(dataloader.dataset)
  model.train()
  for batch, (X, y) in enumerate(dataloader):
    X, y = X.to(device), y.to(device)
    pred = model(X)
    loss = loss_fn(pred, y)

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    if batch % 100 == 0:
      loss, current = loss.item(), (batch+1) * len(X)
      print(f"Loss: {loss:>7f}, {current:>5d}/{size:>5d}")

## Training on Task 1

In [11]:
def test(dataloader, model, loss_fn):
  size = len(dataloader.dataset)
  num_batches = len(dataloader)
  model.eval()

  test_loss, correct = 0, 0
  with torch.no_grad():
    for X, y in dataloader:
      X, y = X.to(device), y.to(device)
      pred = model(X)
      test_loss += loss_fn(pred, y).item()
      correct += (pred.argmax(1) == y).type(torch.float).sum().item()
  test_loss /= num_batches
  correct /= size
  print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}, Avg Loss: {test_loss:>8f}\n")

In [12]:
model_task1 = NeuralNetwork(num_classes=10, hidden_size=512).to(device)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model_task1.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)

In [13]:
epochs = 5
for t in range(epochs):
  print(f"Epoch {t+1}\n---------------------------")
  train(train_dataloader_task1, model_task1, loss_fn, optimizer)
  test(test_dataloader_task1, model_task1, loss_fn)
print("Done!")

Epoch 1
---------------------------
Loss: 2.251544,    64/36000
Loss: 0.337350,  6464/36000
Loss: 0.355125, 12864/36000
Loss: 0.299372, 19264/36000
Loss: 0.337214, 25664/36000
Loss: 0.191454, 32064/36000
Test Error: 
 Accuracy: 88.8, Avg Loss: 0.291615

Epoch 2
---------------------------
Loss: 0.264742,    64/36000
Loss: 0.221406,  6464/36000
Loss: 0.263498, 12864/36000
Loss: 0.315568, 19264/36000
Loss: 0.139890, 25664/36000
Loss: 0.131490, 32064/36000
Test Error: 
 Accuracy: 90.9, Avg Loss: 0.261338

Epoch 3
---------------------------
Loss: 0.154787,    64/36000
Loss: 0.326340,  6464/36000
Loss: 0.217241, 12864/36000
Loss: 0.223733, 19264/36000
Loss: 0.183263, 25664/36000
Loss: 0.182134, 32064/36000
Test Error: 
 Accuracy: 91.0, Avg Loss: 0.246638

Epoch 4
---------------------------
Loss: 0.447313,    64/36000
Loss: 0.151534,  6464/36000
Loss: 0.176930, 12864/36000
Loss: 0.279525, 19264/36000
Loss: 0.197010, 25664/36000
Loss: 0.183053, 32064/36000
Test Error: 
 Accuracy: 91.7, Avg 

In [14]:
def val(dataloader, model):
    size = len(dataloader.dataset)
    model.eval()

    correct = 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}\n")

In [15]:
val(test_dataloader_task1, model_task1)

Test Error: 
 Accuracy: 90.3



________

## Fisher Matrix Calculation

In [16]:
from copy import deepcopy

def get_fisher_diag(model, dataloader, params, empirical=False):
    fisher = {}
    params_dict = dict(params)
    for n, p in deepcopy(params_dict).items():
        p.data.zero_()
        fisher[n] = p.data.clone().detach().requires_grad_()

    model.eval()

    for input, gt_label in dataloader:
        input, gt_label = input.to(device), gt_label.to(device)
        model.zero_grad()
        output = model(input)

        if empirical:
            label = output.max(1)[1]
        else:
            label = gt_label

        # label = gt_label.repeat(output.size(0))
        negloglikelihood = torch.nn.functional.nll_loss(torch.nn.functional.log_softmax(output, dim=1), label)
        negloglikelihood.backward()

        for n, p in model.named_parameters():
            fisher[n].data += p.grad.data ** 2 / len(dataloader.dataset)

    fisher = {n: p for n, p in fisher.items()}
    return fisher


def get_ewc_loss(model, fisher, p_old):
    loss = 0
    for n, p in model.named_parameters():
        _loss = fisher[n] * (p - p_old[n]) ** 2
        loss += _loss.sum()
    return loss

## Training on Task 2

In [17]:
model_task2 = model_task1

In [18]:
model_task2.to(device)

ewc_lambda = 1_000_000

fisher_matrix = get_fisher_diag(model_task2, train_dataloader_task1, model_task2.named_parameters())
prev_params = {n: p.data.clone() for n, p in model_task2.named_parameters()}

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model_task2.parameters(), lr=0.001, momentum=0.9)

In [19]:
def train(dataloader, model, loss_fn, optimizer, fisher_matrix, prev_params):
    model.train()
    size = len(dataloader.dataset)
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        pred = model(X)

        # Original loss
        ce_loss = loss_fn(pred, y)

        # EWC loss
        ewc_loss = get_ewc_loss(model, fisher_matrix, prev_params)

        loss = ce_loss + ewc_lambda * ewc_loss

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if batch % 100 == 0:
            loss, current = loss.item(), (batch+1)*len(X)
            print(f"Loss: {loss:>7f}, {current:>5d}/{size:>5d}")

In [20]:
def val(dataloader, model):
    size = len(dataloader.dataset)
    model.eval()

    correct = 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            _, predicted_old = pred.max(1)
            print(torch.unique(predicted_old))
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    correct /= size
    print(f"Validation Accuracy: {(100*correct):>0.1f}\n")

In [21]:
for epoch in range(1):
    print(f"Epoch {epoch+1}: ----------------------")
    train(train_dataloader_task2, model_task2, loss_fn, optimizer, fisher_matrix, prev_params)
    test(test_dataloader_task2, model_task2, loss_fn)
print("Done!")

Epoch 1: ----------------------
Loss: 12.328086,    64/24000
Loss: 0.617575,  6464/24000
Loss: 0.392870, 12864/24000
Loss: 0.258788, 19264/24000
Test Error: 
 Accuracy: 93.9, Avg Loss: 0.185084

Done!


In [22]:
val(test_dataloader_task1, model_task2)

tensor([0, 1, 3, 6, 7, 8, 9], device='cuda:0')
tensor([0, 1, 3, 6, 7, 8, 9], device='cuda:0')
tensor([0, 1, 4, 6, 7, 8, 9], device='cuda:0')
tensor([0, 1, 6, 7, 8, 9], device='cuda:0')
tensor([1, 6, 7, 8, 9], device='cuda:0')
tensor([0, 1, 6, 7, 8, 9], device='cuda:0')
tensor([0, 1, 3, 6, 7, 8, 9], device='cuda:0')
tensor([1, 6, 7, 8, 9], device='cuda:0')
tensor([0, 1, 3, 6, 7, 8, 9], device='cuda:0')
tensor([1, 6, 7, 8, 9], device='cuda:0')
tensor([0, 1, 3, 4, 6, 7, 8, 9], device='cuda:0')
tensor([0, 1, 6, 7, 8, 9], device='cuda:0')
tensor([0, 1, 2, 6, 7, 8, 9], device='cuda:0')
tensor([1, 3, 6, 7, 8, 9], device='cuda:0')
tensor([0, 1, 4, 6, 7, 8, 9], device='cuda:0')
tensor([0, 1, 6, 7, 8, 9], device='cuda:0')
tensor([0, 1, 3, 6, 7, 8, 9], device='cuda:0')
tensor([0, 1, 6, 7, 8, 9], device='cuda:0')
tensor([0, 1, 6, 7, 8, 9], device='cuda:0')
tensor([0, 1, 6, 7, 8, 9], device='cuda:0')
tensor([0, 1, 3, 6, 7, 8, 9], device='cuda:0')
tensor([0, 1, 6, 7, 8, 9], device='cuda:0')
tensor([