# Learning Without Forgetting

In [161]:
import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Normalize

## Data Preprocessing

In [162]:
class FashionMNISTtask1(datasets.FashionMNIST):
  def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
    super(FashionMNISTtask1, self).__init__(root, train=train, transform=transform, target_transform=target_transform,download=download)
    self.classes = self.classes[:6]

  def __getitem__(self, index):
    img, target = super(FashionMNISTtask1, self).__getitem__(index)
    if target < 6:
        return img, target
    else:
        return img, -1


In [163]:
class FashionMNISTtask2(datasets.FashionMNIST):
    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        super(FashionMNISTtask2, self).__init__(root, train=train, transform=transform, target_transform=target_transform,download=download)
        self.classes = self.classes[6:]

    def __getitem__(self, index):
        img, target = super(FashionMNISTtask2, self).__getitem__(index)
        if target >= 6:
            return img, target
        else:
            return img, -1

In [164]:
transform = torchvision.transforms.Compose([ToTensor(), Normalize((0.5), (0.5))])

train_dataset_task1 = FashionMNISTtask1(root='./data1', train=True, transform=transform, download=True)
test_dataset_task1 = FashionMNISTtask1(root='./data1', train=False, transform=transform, download=True)

In [165]:
train_dataset_task2 = FashionMNISTtask2(root='./data2', train=True, transform=transform, download=True)
test_dataset_task2 = FashionMNISTtask2(root='./data2', train=False, transform=transform, download=True)

In [166]:
train_dataset_filtered_task1 = [data for data in train_dataset_task1 if data[1] != -1]
test_dataset_filtered_task1 = [data for data in test_dataset_task1 if data[1] != -1]

train_dataset_filtered_task2 = [data for data in train_dataset_task2 if data[1] != -1]
test_dataset_filtered_task2 = [data for data in test_dataset_task2 if data[1] != -1]

In [167]:
train_dataloader_task1 = DataLoader(train_dataset_filtered_task1, batch_size=64, shuffle=True)
test_dataloader_task1 = DataLoader(test_dataset_filtered_task1, batch_size=256, shuffle=False)

for X, y in train_dataloader_task1:
  print(f"Shape of X [N, C, H, W]: {X.shape}")
  print(f"SHape of y: {y.shape}, dtype: {y.dtype}")
  break

Shape of X [N, C, H, W]: torch.Size([64, 1, 28, 28])
SHape of y: torch.Size([64]), dtype: torch.int64


In [168]:
len(train_dataloader_task1)

563

In [169]:
train_dataloader_task2 = DataLoader(train_dataset_filtered_task2, batch_size=64, shuffle=True)
test_dataloader_task2 = DataLoader(test_dataset_filtered_task2, batch_size=256, shuffle=False)

In [170]:
device = ('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using {device} device")

Using cpu device


# NN Architecture

In [171]:
def kaiming_normal_init(m):
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
    elif isinstance(m, nn.Linear):
        nn.init.kaiming_normal_(m.weight, nonlinearity='sigmoid')

In [172]:
class NeuralNetwork(nn.Module):
    def __init__(self, num_classes=10, hidden_size=512):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()

        self.fc1 = nn.Linear(28*28, hidden_size)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.relu2 = nn.ReLU()
        self.classifier = nn.Linear(hidden_size, num_classes)

        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, nonlinearity='sigmoid')
            elif isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, nonlinearity='relu')

    def forward(self, x):
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.fc2(x)
        x = self.relu2(x)
        logits = self.classifier(x)

        return logits

In [173]:
class CNN(nn.Module):
    def __init__(self, num_classes=10, hidden_size=512):
        super(CNN, self).__init__()

        self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(16, 32, 3, 1, 1)
        self.relu2 = nn.ReLU()
        self.pool = nn.MaxPool2d(2, 2)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(32*14*14, hidden_size)
        self.relu3 = nn.ReLU()
        self.classifier = nn.Linear(hidden_size, num_classes)

        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, nonlinearity='sigmoid')

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.conv2(x)
        x = self.relu2(x)
        x = self.pool(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu3(x)
        logits = self.classifier(x)

        return logits

    def forward_features(self, x):
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.conv2(x)
        x = self.relu2(x)
        x = self.pool(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu3(x)

        return x


## Training

In [174]:
def train(dataloader, model, loss_fn, optimizer):
  size = len(dataloader.dataset)
  model.train()
  for batch, (X, y) in enumerate(dataloader):
    X, y = X.to(device), y.to(device)
    pred = model(X)
    loss = loss_fn(pred, y)

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    if batch % 100 == 0:
      loss, current = loss.item(), (batch+1) * len(X)
      print(f"Loss: {loss:>7f}, {current:>5d}/{size:>5d}")

In [175]:
def test(dataloader, model, loss_fn):
  size = len(dataloader.dataset)
  num_batches = len(dataloader)
  model.eval()

  test_loss, correct = 0, 0
  with torch.no_grad():
    for X, y in dataloader:
      X, y = X.to(device), y.to(device)
      pred = model(X)
      test_loss += loss_fn(pred, y).item()
      correct += (pred.argmax(1) == y).type(torch.float).sum().item()
  test_loss /= num_batches
  correct /= size
  print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}, Avg Loss: {test_loss:>8f}\n")


In [176]:
model_task1 = CNN(num_classes=6, hidden_size=512).to(device)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model_task1.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)

In [177]:
model_task1.classifier.out_features

6

In [178]:
epochs = 3
for t in range(epochs):
  print(f"Epoch {t+1}\n---------------------------")
  train(train_dataloader_task1, model_task1, loss_fn, optimizer)
  test(test_dataloader_task1, model_task1, loss_fn)
print("Done!")

torch.save(model_task1.state_dict(), "model_old.pth")

Epoch 1
---------------------------
Loss: 2.021631,    64/36000
Loss: 0.401100,  6464/36000
Loss: 0.176412, 12864/36000
Loss: 0.144074, 19264/36000
Loss: 0.289997, 25664/36000
Loss: 0.307687, 32064/36000
Test Error: 
 Accuracy: 92.0, Avg Loss: 0.221119

Epoch 2
---------------------------
Loss: 0.107748,    64/36000
Loss: 0.194388,  6464/36000
Loss: 0.190120, 12864/36000
Loss: 0.136756, 19264/36000
Loss: 0.106814, 25664/36000
Loss: 0.177587, 32064/36000
Test Error: 
 Accuracy: 93.0, Avg Loss: 0.195162

Epoch 3
---------------------------
Loss: 0.105263,    64/36000
Loss: 0.102043,  6464/36000
Loss: 0.215910, 12864/36000
Loss: 0.161934, 19264/36000
Loss: 0.221108, 25664/36000
Loss: 0.174191, 32064/36000
Test Error: 
 Accuracy: 94.1, Avg Loss: 0.163600

Done!


ModifiedCNN has an additional linear hidden layer and new classifier with all 10 classes

In [179]:
class ModifiedCNN(CNN):
    def __init__(self, num_classes, hidden_size, new_hidden_size):
        super(ModifiedCNN, self).__init__(num_classes=num_classes, hidden_size=hidden_size)
        self.new_fc = nn.Linear(hidden_size, new_hidden_size)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=0.5)
        self.classifier = nn.Linear(new_hidden_size, num_classes)

    def forward(self, x):
        x = super().forward_features(x)
        x = self.new_fc(x)
        x = self.relu(x)
        x = self.classifier(x)
        return x

In [180]:
# Model initialization
model = ModifiedCNN(num_classes=10, hidden_size=512, new_hidden_size=512)

In [181]:
# Check the architecture
model

ModifiedCNN(
  (conv1): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu1): ReLU()
  (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu2): ReLU()
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (fc1): Linear(in_features=6272, out_features=512, bias=True)
  (relu3): ReLU()
  (classifier): Linear(in_features=512, out_features=10, bias=True)
  (new_fc): Linear(in_features=512, out_features=512, bias=True)
  (relu): ReLU()
  (dropout): Dropout(p=0.5, inplace=False)
)

# LWF

In [182]:
# Initialize models with old NN architecture
model_task2 = CNN(num_classes=6, hidden_size=512).to(device)
model_task1 = CNN(num_classes=6, hidden_size=512).to(device)

# Load the wights from trained on task 1 model
model_task2.load_state_dict(torch.load("model_old.pth"))
model_task1.load_state_dict(torch.load("model_old.pth"))

<All keys matched successfully>

In [183]:
for n, p in model_task1.named_parameters():
    print(f"{n}, {p}")

conv1.weight, Parameter containing:
tensor([[[[-0.1057,  0.7276, -0.1905],
          [-0.1395, -0.1583, -0.6713],
          [ 0.6999, -0.0707,  0.0365]]],


        [[[-0.1791, -0.0330,  0.0873],
          [ 0.0945, -0.5090, -0.1908],
          [-0.0037,  0.7521, -1.1215]]],


        [[[-0.3484,  0.2708, -0.2843],
          [-0.3917, -0.2529,  0.6720],
          [-0.4581,  0.4667,  0.3419]]],


        [[[ 0.5713,  0.7453, -0.1456],
          [-1.1476, -0.1132,  1.0575],
          [ 0.5034, -0.4015,  0.7360]]],


        [[[-0.0569, -0.4772, -0.5208],
          [ 0.2953, -0.1534, -0.3075],
          [ 1.3770, -0.2166,  0.1164]]],


        [[[-0.6752,  0.2039, -0.1308],
          [ 0.4807,  0.4267, -0.2956],
          [-0.5032, -0.4513,  0.0596]]],


        [[[-0.0339,  0.3412, -0.4523],
          [ 0.4974, -0.2907,  0.3372],
          [-0.2468,  0.0887, -0.4683]]],


        [[[-0.0616, -0.2475, -0.3191],
          [-0.3377,  0.5390,  0.4910],
          [ 0.2113,  0.8815, -0.1580]]]

In [184]:
model_task2.classifier.weight

Parameter containing:
tensor([[ 0.0359,  0.0103, -0.0372,  ...,  0.0387,  0.0418, -0.0350],
        [ 0.0887, -0.0033, -0.0498,  ..., -0.0822, -0.0123,  0.0434],
        [ 0.0436,  0.0154, -0.0520,  ...,  0.0721, -0.0191, -0.0233],
        [ 0.0813, -0.0383,  0.1036,  ..., -0.0864, -0.0380,  0.0220],
        [-0.0208, -0.0674, -0.0162,  ..., -0.0632, -0.0303, -0.0616],
        [ 0.0258, -0.0105,  0.0341,  ...,  0.0470,  0.0022, -0.0353]],
       requires_grad=True)

In [185]:
model_task2 = ModifiedCNN(num_classes=10, hidden_size=512, new_hidden_size=512)

In [186]:
for n, p in model_task2.named_parameters():
    print(f"{n}, {p}")

conv1.weight, Parameter containing:
tensor([[[[-0.0424, -0.7092,  0.8925],
          [-0.2007, -0.3538,  0.4573],
          [-0.0191,  0.3872,  0.1740]]],


        [[[ 0.2868,  0.4410, -0.1939],
          [ 0.8882, -0.0712,  0.0279],
          [ 0.5257,  0.3444, -0.4839]]],


        [[[-0.9819,  0.4272, -0.4246],
          [ 0.4682,  0.3654, -0.6178],
          [-0.0928, -0.2722,  0.1935]]],


        [[[-1.0172, -0.8041, -0.3511],
          [-0.5816,  0.3598,  0.4628],
          [ 0.3011,  0.3005, -0.0158]]],


        [[[-0.2432, -0.6164,  0.0258],
          [-1.2331, -0.1153, -0.6541],
          [ 0.5195,  1.0962,  0.7474]]],


        [[[ 0.2384, -0.3966,  0.9313],
          [-0.2341, -0.2060, -0.0087],
          [-0.6331,  0.4961, -0.5190]]],


        [[[ 0.1002,  0.5174,  0.2105],
          [-0.3143, -0.0340,  0.0095],
          [ 0.6088,  0.6862,  0.5214]]],


        [[[ 0.1544, -0.2725, -0.6826],
          [-0.2867,  0.2775, -0.7737],
          [-0.0959, -0.9020, -0.0610]]]

In [187]:
model_task2.classifier.weight

Parameter containing:
tensor([[-0.0291,  0.0266,  0.0046,  ...,  0.0420, -0.0244,  0.0118],
        [-0.0081,  0.0377,  0.0040,  ...,  0.0394, -0.0178,  0.0197],
        [ 0.0378,  0.0112,  0.0073,  ...,  0.0038,  0.0103,  0.0161],
        ...,
        [-0.0027, -0.0182,  0.0116,  ...,  0.0419, -0.0274, -0.0309],
        [-0.0212, -0.0006,  0.0388,  ..., -0.0239,  0.0288,  0.0331],
        [ 0.0326,  0.0351, -0.0070,  ...,  0.0417,  0.0278, -0.0340]],
       requires_grad=True)

In [188]:
model_task2

ModifiedCNN(
  (conv1): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu1): ReLU()
  (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu2): ReLU()
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (fc1): Linear(in_features=6272, out_features=512, bias=True)
  (relu3): ReLU()
  (classifier): Linear(in_features=512, out_features=10, bias=True)
  (new_fc): Linear(in_features=512, out_features=512, bias=True)
  (relu): ReLU()
  (dropout): Dropout(p=0.5, inplace=False)
)

In [189]:
conv1_weight = model_task1.conv1.weight.data
conv1_bias = model_task1.conv1.bias.data
conv2_weight = model_task1.conv2.weight.data
conv2_bias = model_task1.conv2.bias.data
fc1_weight = model_task1.fc1.weight.data
fc1_bias = model_task1.fc1.bias.data

model_task2.conv1.weight = nn.Parameter(conv1_weight)
model_task2.conv1.bias = nn.Parameter(conv1_bias)
model_task2.conv2.weight = nn.Parameter(conv2_weight)
model_task2.conv2.bias = nn.Parameter(conv2_bias)
model_task2.fc1.weight = nn.Parameter(fc1_weight)
model_task2.fc1.bias = nn.Parameter(fc1_bias)

In [190]:
in_features = model_task1.classifier.in_features
out_features = model_task1.classifier.out_features

new_weights = torch.zeros_like(model_task2.classifier.weight)
new_biases = torch.zeros_like(model_task2.classifier.bias)

weight = model_task1.classifier.weight.data
bias = model_task1.classifier.bias.data

new_weights[:out_features, :] = weight
new_biases[:out_features] = bias

In [191]:
model_task2.classifier.weight = nn.Parameter(new_weights)
model_task2.classifier.bias = nn.Parameter(new_biases)

In [192]:
model_task2.classifier.weight

Parameter containing:
tensor([[ 0.0359,  0.0103, -0.0372,  ...,  0.0387,  0.0418, -0.0350],
        [ 0.0887, -0.0033, -0.0498,  ..., -0.0822, -0.0123,  0.0434],
        [ 0.0436,  0.0154, -0.0520,  ...,  0.0721, -0.0191, -0.0233],
        ...,
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],
       requires_grad=True)

In [193]:
model_task2 = model_task2.to(device)
print("New head numbers: ", model_task2.classifier.out_features)

New head numbers:  10


In [194]:
for param in model_task1.parameters():
  param.requires_grad = False

In [195]:
# Initially freeze all the layers except for classifier
for name, param in model_task2.named_parameters():
    if name not in ['classifier.weight', 'classifier.bias', 'new_fc.weight', 'new_fc.bias']:
        param.requires_grad = False

In [196]:
for name, _ in model_task2.named_parameters():
    print(name)

conv1.weight
conv1.bias
conv2.weight
conv2.bias
fc1.weight
fc1.bias
classifier.weight
classifier.bias
new_fc.weight
new_fc.bias


Changes in training and testing

In [197]:
def train(alpha, T):
    size = len(train_dataloader_task2.dataset)
    # We set net_new to evaluation mode to prevent it from being updated
    # while computing the distillation loss from the old model

    model_task2.train()
    for batch, (X, y) in enumerate(train_dataloader_task2):
        X, y = X.to(device), y.to(device)

        outputs = model_task2(X)
        soft_y = model_task1(X)

        loss1 = loss_fn(outputs, y)

        outputs_S = nn.functional.softmax(outputs[:, :out_features] / T, dim=1)
        outputs_T = nn.functional.softmax(soft_y[:, :out_features] / T, dim=1)

        loss2 = outputs_T.mul(-1 * torch.log(outputs_S))
        loss2 = loss2.sum(1)
        loss2 = loss2.mean() * T * T

        loss = loss1 + alpha * loss2

        loss.backward(retain_graph=True)
        optimizer.step()
        optimizer.zero_grad()

        if batch % 100 == 0:
            loss, current = loss.item(), (batch+1) * len(X)
            print(f"Loss: {loss:>7f}, {current:>5d}/{size:>5d}")


In [198]:
def test(alpha, T):
  size = len(test_dataloader_task2.dataset)
  num_batches = len(test_dataloader_task2)
  model_task2.eval()

  test_loss, correct = 0, 0
  with torch.no_grad():
    for X, y in test_dataloader_task2:
      X, y = X.to(device), y.to(device)

      outputs = model_task2(X)
      soft_y = model_task1(X)

      loss1 = loss_fn(outputs, y)

      outputs_S = nn.functional.softmax(outputs[:, :out_features] / T, dim=1)
      outputs_T = nn.functional.softmax(soft_y[:, :out_features] / T, dim=1)

      loss2 = outputs_T.mul(-1 * torch.log(outputs_S))
      loss2 = loss2.sum(1)
      loss2 = loss2.mean() * T * T

      loss = loss1 * alpha + loss2 * (1 - alpha)

      test_loss += loss.item()
      _, predicted = outputs.max(1)
      correct += predicted.eq(y).sum().item()
      # correct += (pred.argmax(1) == y).type(torch.float).sum().item()
  test_loss /= num_batches
  correct /= size
  print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}, Avg Loss: {test_loss:>8f}")


In [199]:
def val(epoch):
    model_task2.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for batch, (X, y) in enumerate(test_dataloader_task1):
            X, y = X.to(device), y.to(device)
            outputs = model_task2(X)
            _, predicted_old = outputs.max(1)
            total += len(y)
            correct += predicted_old.eq(y).sum().item()
        print(f"Validation Acc: {100. * correct / total}\n")

In [200]:
T = 2
alpha = 0.9
loss_fn = nn.CrossEntropyLoss()

optimizer = torch.optim.SGD(model_task2.parameters(), lr=0.0001, momentum=0.9, weight_decay=5e-4)

for epoch in range(5):
    print(f"Epoch {epoch+1}: ----------------------")
    if epoch > 2:
        for param in model_task2.parameters():
            param.requires_grad = True
        optimizer = torch.optim.SGD(model_task2.parameters(), lr=0.0001, momentum=0.9)
    train(alpha, T)
    test(alpha, T)
    val(epoch)

Epoch 1: ----------------------
Loss: 9.026002,    64/24000
Loss: 6.970753,  6464/24000
Loss: 5.930604, 12864/24000
Loss: 5.666660, 19264/24000
Test Error: 
 Accuracy: 18.5, Avg Loss: 1.763167
Validation Acc: 38.483333333333334

Epoch 2: ----------------------
Loss: 4.964440,    64/24000
Loss: 4.754537,  6464/24000
Loss: 5.003913, 12864/24000
Loss: 4.465080, 19264/24000
Test Error: 
 Accuracy: 67.5, Avg Loss: 1.214632
Validation Acc: 4.266666666666667

Epoch 3: ----------------------
Loss: 4.393607,    64/24000
Loss: 4.129972,  6464/24000
Loss: 4.054672, 12864/24000
Loss: 4.213043, 19264/24000
Test Error: 
 Accuracy: 80.8, Avg Loss: 1.017046
Validation Acc: 0.016666666666666666

Epoch 4: ----------------------
Loss: 3.811729,    64/24000
Loss: 3.227595,  6464/24000
Loss: 2.818718, 12864/24000
Loss: 2.401695, 19264/24000
Test Error: 
 Accuracy: 91.2, Avg Loss: 0.506622
Validation Acc: 0.0

Epoch 5: ----------------------
Loss: 2.101202,    64/24000
Loss: 1.723558,  6464/24000
Loss: 1.64