In [1]:
import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Normalize

In [2]:
class FashionMNISTtask1(datasets.FashionMNIST):
  def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
    super(FashionMNISTtask1, self).__init__(root, train=train, transform=transform, target_transform=target_transform,download=download)
    self.classes = self.classes[:6]

  def __getitem__(self, index):
    img, target = super(FashionMNISTtask1, self).__getitem__(index)
    if target < 6:
        return img, target
    else:
        return img, -1

In [3]:
class FashionMNISTtask2(datasets.FashionMNIST):
    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        super(FashionMNISTtask2, self).__init__(root, train=train, transform=transform, target_transform=target_transform,download=download)
        self.classes = self.classes[6:]

    def __getitem__(self, index):
        img, target = super(FashionMNISTtask2, self).__getitem__(index)
        if target >= 6:
            return img, target
        else:
            return img, -1

In [4]:
transform = torchvision.transforms.Compose([ToTensor(), Normalize((0.5), (0.5))])

train_dataset_1 = FashionMNISTtask1(root='./data1', train=True, transform=transform, download=True)
test_dataset_1 = FashionMNISTtask1(root='./data1', train=False, transform=transform, download=True)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./data1/FashionMNISTtask1/raw/train-images-idx3-ubyte.gz


100%|██████████| 26421880/26421880 [00:02<00:00, 10486401.96it/s]


Extracting ./data1/FashionMNISTtask1/raw/train-images-idx3-ubyte.gz to ./data1/FashionMNISTtask1/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ./data1/FashionMNISTtask1/raw/train-labels-idx1-ubyte.gz


100%|██████████| 29515/29515 [00:00<00:00, 172930.49it/s]


Extracting ./data1/FashionMNISTtask1/raw/train-labels-idx1-ubyte.gz to ./data1/FashionMNISTtask1/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ./data1/FashionMNISTtask1/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 4422102/4422102 [00:01<00:00, 3182112.04it/s]


Extracting ./data1/FashionMNISTtask1/raw/t10k-images-idx3-ubyte.gz to ./data1/FashionMNISTtask1/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ./data1/FashionMNISTtask1/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 5148/5148 [00:00<00:00, 6861225.61it/s]

Extracting ./data1/FashionMNISTtask1/raw/t10k-labels-idx1-ubyte.gz to ./data1/FashionMNISTtask1/raw






In [5]:
train_dataset_2 = FashionMNISTtask2(root='./data2', train=True, transform=transform, download=True)
test_dataset_2 = FashionMNISTtask2(root='./data2', train=False, transform=transform, download=True)
test_dataset_3 = datasets.FashionMNIST(
    root='data3',
    train=False,
    download=True,
    transform=transform
)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./data2/FashionMNISTtask2/raw/train-images-idx3-ubyte.gz


100%|██████████| 26421880/26421880 [00:02<00:00, 9715041.91it/s] 


Extracting ./data2/FashionMNISTtask2/raw/train-images-idx3-ubyte.gz to ./data2/FashionMNISTtask2/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ./data2/FashionMNISTtask2/raw/train-labels-idx1-ubyte.gz


100%|██████████| 29515/29515 [00:00<00:00, 178104.19it/s]


Extracting ./data2/FashionMNISTtask2/raw/train-labels-idx1-ubyte.gz to ./data2/FashionMNISTtask2/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ./data2/FashionMNISTtask2/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 4422102/4422102 [00:01<00:00, 3153656.35it/s]


Extracting ./data2/FashionMNISTtask2/raw/t10k-images-idx3-ubyte.gz to ./data2/FashionMNISTtask2/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ./data2/FashionMNISTtask2/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 5148/5148 [00:00<00:00, 20781787.29it/s]


Extracting ./data2/FashionMNISTtask2/raw/t10k-labels-idx1-ubyte.gz to ./data2/FashionMNISTtask2/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data3/FashionMNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 26421880/26421880 [00:02<00:00, 10869285.39it/s]


Extracting data3/FashionMNIST/raw/train-images-idx3-ubyte.gz to data3/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data3/FashionMNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 29515/29515 [00:00<00:00, 177107.40it/s]


Extracting data3/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data3/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data3/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 4422102/4422102 [00:01<00:00, 3167696.13it/s]


Extracting data3/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data3/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data3/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 5148/5148 [00:00<00:00, 19665097.44it/s]

Extracting data3/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data3/FashionMNIST/raw






In [6]:
train_dataset_filtered_old = [data for data in train_dataset_1 if data[1] != -1]
test_dataset_filtered_old = [data for data in test_dataset_1 if data[1] != -1]

train_dataset_filtered_new = [data for data in train_dataset_2 if data[1] != -1]
test_dataset_filtered_new = [data for data in test_dataset_2 if data[1] != -1]

In [7]:
train_dataloader_old = DataLoader(train_dataset_filtered_old, batch_size=64, shuffle=True)
test_dataloader_old = DataLoader(test_dataset_filtered_old, batch_size=256, shuffle=False)

for X, y in train_dataloader_old:
  print(f"Shape of X [N, C, H, W]: {X.shape}")
  print(f"SHape of y: {y.shape}, dtype: {y.dtype}")
  break

Shape of X [N, C, H, W]: torch.Size([64, 1, 28, 28])
SHape of y: torch.Size([64]), dtype: torch.int64


In [8]:
train_dataloader_new = DataLoader(train_dataset_filtered_new, batch_size=64, shuffle=True)
test_dataloader_new = DataLoader(test_dataset_filtered_new, batch_size=256, shuffle=False)

In [9]:
eval_dataloader = DataLoader(test_dataset_3, batch_size=256, shuffle=True)

In [21]:
device = ('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using {device} device")

Using cuda device


In [11]:
class NeuralNetwork(nn.Module):
    def __init__(self, num_classes=10, hidden_size=512):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()

        self.fc1 = nn.Linear(28*28, hidden_size)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.relu2 = nn.ReLU()
        self.classifier = nn.Linear(hidden_size, num_classes)

    #     self._initialize_weights()

    # def _initialize_weights(self):
    #     for m in self.modules():
    #         if isinstance(m, nn.Linear):
    #             nn.init.kaiming_normal_(m.weight, nonlinearity='sigmoid')
    #         elif isinstance(m, nn.Conv2d):
    #             nn.init.kaiming_normal_(m.weight, nonlinearity='relu')

    def forward(self, x):
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.fc2(x)
        x = self.relu2(x)
        logits = self.classifier(x)

        return logits

In [12]:
def train(dataloader, model, loss_fn, optimizer):
  size = len(dataloader.dataset)
  model.train()
  for batch, (X, y) in enumerate(dataloader):
    X, y = X.to(device), y.to(device)
    pred = model(X)
    loss = loss_fn(pred, y)

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    if batch % 100 == 0:
      loss, current = loss.item(), (batch+1) * len(X)
      print(f"Loss: {loss:>7f}, {current:>5d}/{size:>5d}")

In [13]:
def test(dataloader, model, loss_fn):
  size = len(dataloader.dataset)
  num_batches = len(dataloader)
  model.eval()

  test_loss, correct = 0, 0
  with torch.no_grad():
    for X, y in dataloader:
      X, y = X.to(device), y.to(device)
      pred = model(X)
      test_loss += loss_fn(pred, y).item()
      correct += (pred.argmax(1) == y).type(torch.float).sum().item()
  test_loss /= num_batches
  correct /= size
  print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}, Avg Loss: {test_loss:>8f}\n")

In [14]:
pre_model = NeuralNetwork(num_classes=10, hidden_size=512).to(device)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(pre_model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)

In [16]:
epochs = 10
for t in range(epochs):
  print(f"Epoch {t+1}\n---------------------------")
  train(train_dataloader_old, pre_model, loss_fn, optimizer)
  test(test_dataloader_old, pre_model, loss_fn)
print("Done!")

torch.save(pre_model.state_dict(), "model_old.pth")

Epoch 1
---------------------------
Loss: 2.283751,    64/36000
Loss: 0.243989,  6464/36000
Loss: 0.184512, 12864/36000
Loss: 0.344231, 19264/36000
Loss: 0.393142, 25664/36000
Loss: 0.231987, 32064/36000
Test Error: 
 Accuracy: 89.7, Avg Loss: 0.282104

Epoch 2
---------------------------
Loss: 0.229488,    64/36000
Loss: 0.310935,  6464/36000
Loss: 0.317108, 12864/36000
Loss: 0.150361, 19264/36000
Loss: 0.290324, 25664/36000
Loss: 0.325517, 32064/36000
Test Error: 
 Accuracy: 90.0, Avg Loss: 0.268644

Epoch 3
---------------------------
Loss: 0.183788,    64/36000
Loss: 0.179374,  6464/36000
Loss: 0.152611, 12864/36000
Loss: 0.284833, 19264/36000
Loss: 0.278665, 25664/36000
Loss: 0.163754, 32064/36000
Test Error: 
 Accuracy: 89.4, Avg Loss: 0.280450

Epoch 4
---------------------------
Loss: 0.199625,    64/36000
Loss: 0.414821,  6464/36000
Loss: 0.212038, 12864/36000
Loss: 0.176117, 19264/36000
Loss: 0.157161, 25664/36000
Loss: 0.191457, 32064/36000
Test Error: 
 Accuracy: 89.4, Avg 

________

# I must try replacing the last layer (classifier) with a new one with 10 classes. Maybe that will work

In [22]:
model = NeuralNetwork()
model.load_state_dict(torch.load("model_old.pth"))

<All keys matched successfully>

In [None]:
print(list(model.named_parameters()))

In [42]:
from copy import deepcopy

def get_fisher_diag(model, dataset, params, empirical=True):
  fisher = {}
  params_dict = dict(params)
  for n, p in deepcopy(params_dict).items():
    p.data.zero_()
    fisher[n] = p.data.clone().detach().requires_grad_()

  model.eval()

  for input, gt_label in dataset:
    input, gt_label = input.to(device), gt_label.to(device)

    model.zero_grad()
    output = model(input)
    # print(output.shape)
    # output = output.view(1, -1)

    if empirical:
      label = gt_label
    else:
      label = output.max(1)[1].view(-1)
    # label = gt_label.repeat(output.size(0))
    # print(label.shape)
    negloglikelihood = torch.nn.functional.nll_loss(torch.nn.functional.log_softmax(output, dim=1), label)
    negloglikelihood.backward()

    for n, p in model.named_parameters():
      fisher[n].data += p.grad.data ** 2 / len(dataset)

  fisher = {n: p for n, p in fisher.items()}
  return fisher

def get_ewc_loss(model, fisher, p_old):
  loss = 0
  for n, p in model.named_parameters():
    _loss = fisher[n] * (p - p_old[n]) ** 2
    loss += _loss.sum()
  return loss

In [43]:
model.to(device)

ewc_lambda = 0.1

fisher_matrix = get_fisher_diag(model, train_dataloader_new, model.named_parameters())
prev_params = {n: p.data.clone() for n, p in model.named_parameters()}

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(pre_model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)

In [52]:
def train(dataloader, model, loss_fn, optimizer, fisher_matrix, prev_params):
    model.train()
    size = len(dataloader.dataset)
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        pred = model(X)

        # Original loss
        ce_loss = loss_fn(pred, y)

        # EWC loss
        ewc_loss = get_ewc_loss(model, fisher_matrix, prev_params)

        loss = ce_loss + ewc_lambda * ewc_loss

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if batch % 100 == 0:
            loss, current = loss.item(), (batch+1)*len(X)
            print(f"Loss: {loss:>7f}, {current:>5d}/{size:>5d}")

    # Update fisher matrix and previous parameters after each epoch
    fisher_matrix = get_fisher_diag(model, dataloader, model.named_parameters())
    prev_params = {n: p.data.clone() for n, p in model.named_parameters()}

In [45]:
def val(epoch):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for batch, (X, y) in enumerate(eval_dataloader):
            X, y = X.to(device), y.to(device)
            outputs = model(X)
            _, predicted_old = outputs.max(1)
            total += len(y)
            correct += predicted_old.eq(y).sum().item()
        print(f"Validation Acc: {100. * correct / total}\n")

In [53]:
for epoch in range(10):
    print(f"Epoch {epoch+1}: ----------------------")
    train(train_dataloader_new, model, loss_fn, optimizer, fisher_matrix, prev_params)
    test(test_dataloader_new, model, loss_fn)
    val(epoch)
print("Done!")

Epoch 1: ----------------------
Loss: 12.452131,    64/24000
Loss: 11.685606,  6464/24000
Loss: 12.589679, 12864/24000
Loss: 12.411456, 19264/24000
Test Error: 
 Accuracy: 0.0, Avg Loss: 12.390607

Validation Acc: 54.94

Epoch 2: ----------------------
Loss: 12.516727,    64/24000
Loss: 12.418256,  6464/24000
Loss: 12.863902, 12864/24000
Loss: 12.051398, 19264/24000
Test Error: 
 Accuracy: 0.0, Avg Loss: 12.390607

Validation Acc: 54.94

Epoch 3: ----------------------
Loss: 12.676291,    64/24000
Loss: 11.888682,  6464/24000
Loss: 11.509149, 12864/24000
Loss: 12.062311, 19264/24000
Test Error: 
 Accuracy: 0.0, Avg Loss: 12.390607

Validation Acc: 54.94

Epoch 4: ----------------------
Loss: 12.501888,    64/24000
Loss: 12.030292,  6464/24000
Loss: 12.016770, 12864/24000
Loss: 11.271650, 19264/24000
Test Error: 
 Accuracy: 0.0, Avg Loss: 12.390607

Validation Acc: 54.94

Epoch 5: ----------------------
Loss: 12.083328,    64/24000
Loss: 12.295427,  6464/24000
Loss: 12.509932, 12864/240