In [113]:
import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Normalize

In [114]:
class FashionMNISTtask1(datasets.FashionMNIST):
  def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
    super(FashionMNISTtask1, self).__init__(root, train=train, transform=transform,
                                            target_transform=target_transform,download=download)
    self.classes = self.classes[:6]

  def __getitem__(self, index):
    img, target = super(FashionMNISTtask1, self).__getitem__(index)
    if target < 6:
        return img, target
    else:
        return img, -1

In [115]:
class FashionMNISTtask2(datasets.FashionMNIST):
    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        super(FashionMNISTtask2, self).__init__(root, train=train, transform=transform,
                                                target_transform=target_transform,download=download)
        self.classes = self.classes[6:]

    def __getitem__(self, index):
        img, target = super(FashionMNISTtask2, self).__getitem__(index)
        if target >= 6:
            return img, target
        else:
            return img, -1

In [116]:
transform = torchvision.transforms.Compose([ToTensor(), Normalize((0.5), (0.5))])

train_dataset_1 = FashionMNISTtask1(root='./data1', train=True, transform=transform, download=True)
test_dataset_1 = FashionMNISTtask1(root='./data1', train=False, transform=transform, download=True)

In [117]:
train_dataset_2 = FashionMNISTtask2(root='./data2', train=True, transform=transform, download=True)
test_dataset_2 = FashionMNISTtask2(root='./data2', train=False, transform=transform, download=True)
test_dataset_3 = datasets.FashionMNIST(
    root='data3',
    train=False,
    download=True,
    transform=transform
)

In [118]:
train_dataset_filtered_first = [data for data in train_dataset_1 if data[1] != -1]
test_dataset_filtered_first = [data for data in test_dataset_1 if data[1] != -1]

train_dataset_filtered_second = [data for data in train_dataset_2 if data[1] != -1]
test_dataset_filtered_second = [data for data in test_dataset_2 if data[1] != -1]

In [119]:
unique_labels_1 = set()
unique_labels_2 = set()

for _, target in train_dataset_filtered_first:
    unique_labels_1.add(target)

for _, target in train_dataset_filtered_second:
    unique_labels_2.add(target)

print(f"First: {unique_labels_1}")
print(f"Second: {unique_labels_2}")

First: {0, 1, 2, 3, 4, 5}
Second: {8, 9, 6, 7}


In [120]:
train_dataloader_first = DataLoader(train_dataset_filtered_first, batch_size=64, shuffle=True)
test_dataloader_first = DataLoader(test_dataset_filtered_first, batch_size=256, shuffle=False)

for X, y in train_dataloader_first:
  print(f"Shape of X [N, C, H, W]: {X.shape}")
  print(f"SHape of y: {y.shape}, dtype: {y.dtype}")
  break

Shape of X [N, C, H, W]: torch.Size([64, 1, 28, 28])
SHape of y: torch.Size([64]), dtype: torch.int64


In [121]:
train_dataloader_second = DataLoader(train_dataset_filtered_second, batch_size=64, shuffle=True)
test_dataloader_second = DataLoader(test_dataset_filtered_second, batch_size=256, shuffle=False)

In [122]:
eval_dataloader = DataLoader(test_dataset_3, batch_size=256, shuffle=True)

In [123]:
device = ('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using {device} device")

Using cuda device


In [124]:
class NeuralNetwork(nn.Module):
    def __init__(self, num_classes=10, hidden_size=512):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()

        self.fc1 = nn.Linear(28*28, hidden_size)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.relu2 = nn.ReLU()
        self.classifier = nn.Linear(hidden_size, num_classes)

    #     self._initialize_weights()

    # def _initialize_weights(self):
    #     for m in self.modules():
    #         if isinstance(m, nn.Linear):
    #             nn.init.kaiming_normal_(m.weight, nonlinearity='sigmoid')
    #         elif isinstance(m, nn.Conv2d):
    #             nn.init.kaiming_normal_(m.weight, nonlinearity='relu')

    def forward(self, x):
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.fc2(x)
        x = self.relu2(x)
        logits = self.classifier(x)

        return logits

In [125]:
def train(dataloader, model, loss_fn, optimizer):
  size = len(dataloader.dataset)
  model.train()
  for batch, (X, y) in enumerate(dataloader):
    X, y = X.to(device), y.to(device)
    pred = model(X)
    loss = loss_fn(pred, y)

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    if batch % 100 == 0:
      loss, current = loss.item(), (batch+1) * len(X)
      print(f"Loss: {loss:>7f}, {current:>5d}/{size:>5d}")

In [126]:
def test(dataloader, model, loss_fn):
  size = len(dataloader.dataset)
  num_batches = len(dataloader)
  model.eval()

  test_loss, correct = 0, 0
  with torch.no_grad():
    for X, y in dataloader:
      X, y = X.to(device), y.to(device)
      pred = model(X)
      test_loss += loss_fn(pred, y).item()
      correct += (pred.argmax(1) == y).type(torch.float).sum().item()
  test_loss /= num_batches
  correct /= size
  print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}, Avg Loss: {test_loss:>8f}\n")

In [127]:
model = NeuralNetwork(num_classes=10, hidden_size=512).to(device)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.05, momentum=0.9, weight_decay=5e-4)

In [128]:
epochs = 10
for t in range(epochs):
  print(f"Epoch {t+1}\n---------------------------")
  train(train_dataloader_first, model, loss_fn, optimizer)
  test(test_dataloader_first, model, loss_fn)
print("Done!")

torch.save(model.state_dict(), "model_old.pth")

Epoch 1
---------------------------
Loss: 2.298927,    64/36000
Loss: 0.396251,  6464/36000
Loss: 0.468095, 12864/36000
Loss: 0.442391, 19264/36000
Loss: 0.337982, 25664/36000
Loss: 0.302926, 32064/36000
Test Error: 
 Accuracy: 89.2, Avg Loss: 0.288838

Epoch 2
---------------------------
Loss: 0.237493,    64/36000
Loss: 0.282463,  6464/36000
Loss: 0.448286, 12864/36000
Loss: 0.237903, 19264/36000
Loss: 0.211039, 25664/36000
Loss: 0.153253, 32064/36000
Test Error: 
 Accuracy: 89.5, Avg Loss: 0.282752

Epoch 3
---------------------------
Loss: 0.253466,    64/36000
Loss: 0.154371,  6464/36000
Loss: 0.113784, 12864/36000
Loss: 0.242314, 19264/36000
Loss: 0.179281, 25664/36000
Loss: 0.199503, 32064/36000
Test Error: 
 Accuracy: 90.4, Avg Loss: 0.262568

Epoch 4
---------------------------
Loss: 0.184641,    64/36000
Loss: 0.401508,  6464/36000
Loss: 0.286990, 12864/36000
Loss: 0.275291, 19264/36000
Loss: 0.305218, 25664/36000
Loss: 0.141361, 32064/36000
Test Error: 
 Accuracy: 91.1, Avg 

In [129]:
def val(epoch):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for batch, (X, y) in enumerate(eval_dataloader):
            X, y = X.to(device), y.to(device)
            outputs = model(X)
            _, predicted_old = outputs.max(1)
            total += len(y)
            correct += predicted_old.eq(y).sum().item()
        print(f"Validation Acc: {100. * correct / total}\n")

In [130]:
val(1)

Validation Acc: 54.07



________

In [131]:
# model = NeuralNetwork()
# model.load_state_dict(torch.load("model_old.pth"))

In [132]:
from copy import deepcopy

def get_fisher_diag(model, dataloader, params, empirical=False):
    fisher = {}
    params_dict = dict(params)
    for n, p in deepcopy(params_dict).items():
        p.data.zero_()
        fisher[n] = p.data.clone().detach().requires_grad_()

    model.eval()

    for input, gt_label in dataloader:
        input, gt_label = input.to(device), gt_label.to(device)
        model.zero_grad()
        output = model(input)

        if empirical:
            label = output.max(1)[1]
        else:
            label = gt_label

        # label = gt_label.repeat(output.size(0))
        negloglikelihood = torch.nn.functional.nll_loss(torch.nn.functional.log_softmax(output, dim=1), label)
        negloglikelihood.backward()

        for n, p in model.named_parameters():
            fisher[n].data += p.grad.data ** 2 / len(dataloader.dataset)

    fisher = {n: p for n, p in fisher.items()}
    return fisher


def get_ewc_loss(model, fisher, p_old):
    loss = 0
    for n, p in model.named_parameters():
        _loss = fisher[n] * (p - p_old[n]) ** 2
        loss += _loss.sum()
    return loss

In [133]:
model.to(device)

ewc_lambda = 500_000

fisher_matrix = get_fisher_diag(model, train_dataloader_first, model.named_parameters())
prev_params = {n: p.data.clone() for n, p in model.named_parameters()}

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)

In [134]:
print(fisher_matrix)

{'fc1.weight': tensor([[3.7443e-09, 3.7439e-09, 3.7436e-09,  ..., 2.5215e-09, 3.4699e-09,
         3.7169e-09],
        [1.6326e-11, 1.6238e-11, 1.6049e-11,  ..., 1.4991e-11, 1.6045e-11,
         1.6240e-11],
        [2.9330e-08, 2.9330e-08, 2.9311e-08,  ..., 2.8306e-08, 2.9205e-08,
         2.9335e-08],
        ...,
        [6.1033e-08, 6.1013e-08, 6.0929e-08,  ..., 5.7351e-08, 5.9659e-08,
         6.0806e-08],
        [7.6741e-13, 7.6732e-13, 7.6720e-13,  ..., 7.3480e-13, 7.3585e-13,
         7.6046e-13],
        [1.2944e-09, 1.2944e-09, 1.2947e-09,  ..., 1.2968e-09, 1.2910e-09,
         1.2945e-09]], device='cuda:0', requires_grad=True), 'fc1.bias': tensor([3.7443e-09, 1.6326e-11, 2.9330e-08, 5.4086e-08, 1.8312e-10, 1.0067e-09,
        1.3593e-08, 7.5410e-13, 8.1404e-08, 5.7398e-08, 4.0001e-09, 2.2646e-12,
        3.4448e-09, 2.9741e-12, 4.5721e-17, 2.7898e-08, 4.6300e-13, 1.0759e-08,
        5.7502e-08, 2.4081e-08, 7.3653e-09, 7.6308e-08, 3.3324e-11, 2.4336e-10,
        2.0743e-14,

In [135]:
def train(dataloader, model, loss_fn, optimizer, fisher_matrix, prev_params):
    model.train()
    size = len(dataloader.dataset)
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        pred = model(X)

        # Original loss
        ce_loss = loss_fn(pred, y)

        # EWC loss
        ewc_loss = get_ewc_loss(model, fisher_matrix, prev_params)

        loss = ce_loss + ewc_lambda * ewc_loss

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if batch % 100 == 0:
            loss, current = loss.item(), (batch+1)*len(X)
            print(f"Loss: {loss:>7f}, {current:>5d}/{size:>5d}")

In [136]:
def val(epoch):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for batch, (X, y) in enumerate(eval_dataloader):
            X, y = X.to(device), y.to(device)
            outputs = model(X)
            _, predicted_old = outputs.max(1)
            print(torch.unique(predicted_old))
            total += len(y)
            correct += predicted_old.eq(y).sum().item()
        print(f"Validation Acc: {100. * correct / total}\n")

In [137]:
for epoch in range(3):
    print(f"Epoch {epoch+1}: ----------------------")
    train(train_dataloader_second, model, loss_fn, optimizer, fisher_matrix, prev_params)
    test(test_dataloader_second, model, loss_fn)
print("Done!")

Epoch 1: ----------------------
Loss: 14.331015,    64/24000
Loss: 0.281682,  6464/24000
Loss: 0.269779, 12864/24000
Loss: 0.096627, 19264/24000
Test Error: 
 Accuracy: 95.5, Avg Loss: 0.125358

Epoch 2: ----------------------
Loss: 0.157882,    64/24000
Loss: 0.140260,  6464/24000
Loss: 0.069510, 12864/24000
Loss: 0.083681, 19264/24000
Test Error: 
 Accuracy: 96.0, Avg Loss: 0.110670

Epoch 3: ----------------------
Loss: 0.114369,    64/24000
Loss: 0.064815,  6464/24000
Loss: 0.142015, 12864/24000
Loss: 0.082133, 19264/24000
Test Error: 
 Accuracy: 96.7, Avg Loss: 0.097747

Done!


In [138]:
val(1)

tensor([1, 3, 6, 7, 8, 9], device='cuda:0')
tensor([1, 3, 6, 7, 8, 9], device='cuda:0')
tensor([1, 3, 6, 7, 8, 9], device='cuda:0')
tensor([1, 3, 6, 7, 8, 9], device='cuda:0')
tensor([0, 1, 6, 7, 8, 9], device='cuda:0')
tensor([1, 3, 6, 7, 8, 9], device='cuda:0')
tensor([1, 3, 6, 7, 8, 9], device='cuda:0')
tensor([1, 3, 6, 7, 8, 9], device='cuda:0')
tensor([1, 3, 6, 7, 8, 9], device='cuda:0')
tensor([1, 3, 6, 7, 8, 9], device='cuda:0')
tensor([1, 3, 6, 7, 8, 9], device='cuda:0')
tensor([1, 3, 6, 7, 8, 9], device='cuda:0')
tensor([1, 3, 6, 7, 8, 9], device='cuda:0')
tensor([1, 3, 6, 7, 8, 9], device='cuda:0')
tensor([1, 3, 6, 7, 8, 9], device='cuda:0')
tensor([1, 3, 6, 7, 8, 9], device='cuda:0')
tensor([1, 3, 6, 7, 8, 9], device='cuda:0')
tensor([1, 3, 6, 7, 8, 9], device='cuda:0')
tensor([1, 3, 6, 7, 8, 9], device='cuda:0')
tensor([1, 3, 6, 7, 8, 9], device='cuda:0')
tensor([1, 3, 6, 7, 8, 9], device='cuda:0')
tensor([1, 3, 6, 7, 8, 9], device='cuda:0')
tensor([1, 3, 6, 7, 8, 9], devic

In [139]:
fisher_matrix = get_fisher_diag(model, train_dataloader_first, model.named_parameters())

In [140]:
print(fisher_matrix)

{'fc1.weight': tensor([[2.7461e-07, 2.7452e-07, 2.7371e-07,  ..., 1.4784e-07, 2.4882e-07,
         2.7414e-07],
        [2.0734e-07, 2.0731e-07, 2.0710e-07,  ..., 1.9095e-07, 2.0407e-07,
         2.0714e-07],
        [3.0424e-06, 3.0421e-06, 3.0385e-06,  ..., 2.6609e-06, 2.9924e-06,
         3.0401e-06],
        ...,
        [4.4956e-06, 4.4947e-06, 4.4922e-06,  ..., 4.3381e-06, 4.5258e-06,
         4.4957e-06],
        [3.5764e-07, 3.5759e-07, 3.5729e-07,  ..., 3.2840e-07, 3.5391e-07,
         3.5759e-07],
        [4.5066e-09, 4.5063e-09, 4.5060e-09,  ..., 4.2264e-09, 4.4487e-09,
         4.4994e-09]], device='cuda:0', requires_grad=True), 'fc1.bias': tensor([2.7461e-07, 2.0734e-07, 3.0424e-06, 1.1354e-05, 6.0132e-08, 2.2335e-08,
        2.1391e-07, 1.6250e-11, 3.5841e-06, 4.4533e-07, 9.3733e-07, 4.0175e-08,
        1.0547e-11, 4.9968e-09, 5.6397e-12, 4.9934e-06, 3.7685e-12, 6.9628e-06,
        7.3301e-06, 1.4020e-06, 8.9563e-07, 1.1601e-05, 4.6595e-13, 7.9518e-07,
        1.5050e-11,