In [1]:
import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Normalize

In [2]:
class FashionMNISTtask1(datasets.FashionMNIST):
  def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
    super(FashionMNISTtask1, self).__init__(root, train=train, transform=transform, target_transform=target_transform,download=download)
    self.classes = self.classes[:6]

  def __getitem__(self, index):
    img, target = super(FashionMNISTtask1, self).__getitem__(index)
    if target < 6:
        return img, target
    else:
        return img, -1


In [3]:
class FashionMNISTtask2(datasets.FashionMNIST):
    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        super(FashionMNISTtask2, self).__init__(root, train=train, transform=transform, target_transform=target_transform,download=download)
        self.classes = self.classes[6:]

    def __getitem__(self, index):
        img, target = super(FashionMNISTtask2, self).__getitem__(index)
        if target >= 6:
            return img, target
        else:
            return img, -1

In [4]:
transform = torchvision.transforms.Compose([ToTensor(), Normalize((0.5), (0.5))])

train_dataset_task1 = FashionMNISTtask1(root='./data1', train=True, transform=transform, download=True)
test_dataset_task1 = FashionMNISTtask1(root='./data1', train=False, transform=transform, download=True)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./data1/FashionMNISTtask1/raw/train-images-idx3-ubyte.gz


100%|██████████| 26421880/26421880 [00:03<00:00, 7051358.33it/s]


Extracting ./data1/FashionMNISTtask1/raw/train-images-idx3-ubyte.gz to ./data1/FashionMNISTtask1/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ./data1/FashionMNISTtask1/raw/train-labels-idx1-ubyte.gz


100%|██████████| 29515/29515 [00:00<00:00, 305205.18it/s]


Extracting ./data1/FashionMNISTtask1/raw/train-labels-idx1-ubyte.gz to ./data1/FashionMNISTtask1/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ./data1/FashionMNISTtask1/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 4422102/4422102 [00:00<00:00, 5478483.10it/s]


Extracting ./data1/FashionMNISTtask1/raw/t10k-images-idx3-ubyte.gz to ./data1/FashionMNISTtask1/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ./data1/FashionMNISTtask1/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 5148/5148 [00:00<00:00, 17626348.56it/s]

Extracting ./data1/FashionMNISTtask1/raw/t10k-labels-idx1-ubyte.gz to ./data1/FashionMNISTtask1/raw






In [5]:
train_dataset_task2 = FashionMNISTtask2(root='./data2', train=True, transform=transform, download=True)
test_dataset_task2 = FashionMNISTtask2(root='./data2', train=False, transform=transform, download=True)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./data2/FashionMNISTtask2/raw/train-images-idx3-ubyte.gz


100%|██████████| 26421880/26421880 [00:01<00:00, 16439850.96it/s]


Extracting ./data2/FashionMNISTtask2/raw/train-images-idx3-ubyte.gz to ./data2/FashionMNISTtask2/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ./data2/FashionMNISTtask2/raw/train-labels-idx1-ubyte.gz


100%|██████████| 29515/29515 [00:00<00:00, 302069.14it/s]


Extracting ./data2/FashionMNISTtask2/raw/train-labels-idx1-ubyte.gz to ./data2/FashionMNISTtask2/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ./data2/FashionMNISTtask2/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 4422102/4422102 [00:00<00:00, 5544086.06it/s]


Extracting ./data2/FashionMNISTtask2/raw/t10k-images-idx3-ubyte.gz to ./data2/FashionMNISTtask2/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ./data2/FashionMNISTtask2/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 5148/5148 [00:00<00:00, 5007485.39it/s]

Extracting ./data2/FashionMNISTtask2/raw/t10k-labels-idx1-ubyte.gz to ./data2/FashionMNISTtask2/raw






In [6]:
train_dataset_filtered_task1 = [data for data in train_dataset_task1 if data[1] != -1]
test_dataset_filtered_task1 = [data for data in test_dataset_task1 if data[1] != -1]

train_dataset_filtered_task2 = [data for data in train_dataset_task2 if data[1] != -1]
test_dataset_filtered_task2 = [data for data in test_dataset_task2 if data[1] != -1]

In [7]:
train_dataloader_task1 = DataLoader(train_dataset_filtered_task1, batch_size=64, shuffle=True)
test_dataloader_task1 = DataLoader(test_dataset_filtered_task1, batch_size=256, shuffle=False)

for X, y in train_dataloader_task1:
  print(f"Shape of X [N, C, H, W]: {X.shape}")
  print(f"SHape of y: {y.shape}, dtype: {y.dtype}")
  break

Shape of X [N, C, H, W]: torch.Size([64, 1, 28, 28])
SHape of y: torch.Size([64]), dtype: torch.int64


In [9]:
len(train_dataloader_task1)

563

In [10]:
train_dataloader_task2 = DataLoader(train_dataset_filtered_task2, batch_size=64, shuffle=True)
test_dataloader_task2 = DataLoader(test_dataset_filtered_task2, batch_size=256, shuffle=False)

In [11]:
device = ('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using {device} device")

Using cpu device


In [12]:
def kaiming_normal_init(m):
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
    elif isinstance(m, nn.Linear):
        nn.init.kaiming_normal_(m.weight, nonlinearity='sigmoid')

In [13]:
class NeuralNetwork(nn.Module):
    def __init__(self, num_classes=10, hidden_size=512):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()

        self.fc1 = nn.Linear(28*28, hidden_size)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.relu2 = nn.ReLU()
        self.classifier = nn.Linear(hidden_size, num_classes)

        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, nonlinearity='sigmoid')
            elif isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, nonlinearity='relu')

    def forward(self, x):
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.fc2(x)
        x = self.relu2(x)
        logits = self.classifier(x)

        return logits

In [14]:
class CNN(nn.Module):
    def __init__(self, num_classes=10, hidden_size=512):
        super(CNN, self).__init__()

        self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(16, 32, 3, 1, 1)
        self.relu2 = nn.ReLU()
        self.pool = nn.MaxPool2d(2, 2)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(32*14*14, hidden_size)
        self.relu3 = nn.ReLU()
        self.classifier = nn.Linear(hidden_size, num_classes)

        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, nonlinearity='sigmoid')

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.conv2(x)
        x = self.relu2(x)
        x = self.pool(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu3(x)
        logits = self.classifier(x)

        return logits


In [15]:
def train(dataloader, model, loss_fn, optimizer):
  size = len(dataloader.dataset)
  model.train()
  for batch, (X, y) in enumerate(dataloader):
    X, y = X.to(device), y.to(device)
    pred = model(X)
    loss = loss_fn(pred, y)

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    if batch % 100 == 0:
      loss, current = loss.item(), (batch+1) * len(X)
      print(f"Loss: {loss:>7f}, {current:>5d}/{size:>5d}")

In [16]:
def test(dataloader, model, loss_fn):
  size = len(dataloader.dataset)
  num_batches = len(dataloader)
  model.eval()

  test_loss, correct = 0, 0
  with torch.no_grad():
    for X, y in dataloader:
      X, y = X.to(device), y.to(device)
      pred = model(X)
      test_loss += loss_fn(pred, y).item()
      correct += (pred.argmax(1) == y).type(torch.float).sum().item()
  test_loss /= num_batches
  correct /= size
  print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}, Avg Loss: {test_loss:>8f}\n")


In [17]:
model_task1 = CNN(num_classes=6, hidden_size=512).to(device)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model_task1.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)

In [18]:
model_task1.classifier.out_features

6

In [19]:
epochs = 5
for t in range(epochs):
  print(f"Epoch {t+1}\n---------------------------")
  train(train_dataloader_task1, model_task1, loss_fn, optimizer)
  test(test_dataloader_task1, model_task1, loss_fn)
print("Done!")

torch.save(model_task1.state_dict(), "model_old.pth")

Epoch 1
---------------------------
Loss: 1.898291,    64/36000
Loss: 0.276725,  6464/36000
Loss: 0.217151, 12864/36000
Loss: 0.194844, 19264/36000
Loss: 0.277112, 25664/36000
Loss: 0.085803, 32064/36000
Test Error: 
 Accuracy: 92.3, Avg Loss: 0.211253

Epoch 2
---------------------------
Loss: 0.115818,    64/36000
Loss: 0.222595,  6464/36000
Loss: 0.104931, 12864/36000
Loss: 0.178367, 19264/36000
Loss: 0.129727, 25664/36000
Loss: 0.049222, 32064/36000
Test Error: 
 Accuracy: 93.7, Avg Loss: 0.175326

Epoch 3
---------------------------
Loss: 0.067484,    64/36000
Loss: 0.249414,  6464/36000
Loss: 0.071581, 12864/36000
Loss: 0.159191, 19264/36000
Loss: 0.160739, 25664/36000
Loss: 0.180911, 32064/36000
Test Error: 
 Accuracy: 93.7, Avg Loss: 0.169220

Epoch 4
---------------------------
Loss: 0.052266,    64/36000
Loss: 0.080109,  6464/36000
Loss: 0.140392, 12864/36000
Loss: 0.038281, 19264/36000
Loss: 0.052388, 25664/36000
Loss: 0.089193, 32064/36000
Test Error: 
 Accuracy: 94.4, Avg 

# LWF

In [20]:
model_task2 = CNN(num_classes=6, hidden_size=512).to(device)
model_task1 = CNN(num_classes=6, hidden_size=512).to(device)

model_task2.load_state_dict(torch.load("model_old.pth"))
model_task1.load_state_dict(torch.load("model_old.pth"))

in_features = model_task1.classifier.in_features
out_features = model_task1.classifier.out_features

weight = model_task1.classifier.weight.data
bias = model_task1.classifier.bias.data

new_out_features = 10

new_fc = nn.Linear(in_features, new_out_features)
kaiming_normal_init(new_fc.weight)

new_fc.weight.data[:out_features] = weight
new_fc.bias.data[:out_features] = bias

model_task2.classifier = new_fc
model_task2 = model_task2.to(device)
print("New head numbers: ", model_task2.classifier.out_features)

for param in model_task1.parameters():
  param.requires_grad = False

New head numbers:  10


Changes in training and testing

In [21]:
def train(alpha, T):
  size = len(train_dataloader_task2.dataset)
  # We set net_new to evaluation mode to prevent it from being updated
  # while computing the distillation loss from the old model
  model_task2.train()
  for batch, (X, y) in enumerate(train_dataloader_task2):
    X, y = X.to(device), y.to(device)

    outputs = model_task2(X)
    soft_y = model_task1(X)

    loss1 = loss_fn(outputs, y)

    outputs_S = nn.functional.softmax(outputs[:, :out_features] / T, dim=1)
    outputs_T = nn.functional.softmax(soft_y[:, :out_features] / T, dim=1)

    loss2 = outputs_T.mul(-1 * torch.log(outputs_S))
    loss2 = loss2.sum(1)
    loss2 = loss2.mean() * T * T

    loss = loss1 + alpha * loss2

    loss.backward(retain_graph=True)
    optimizer.step()
    optimizer.zero_grad()

    if batch % 100 == 0:
      loss, current = loss.item(), (batch+1) * len(X)
      print(f"Loss: {loss:>7f}, {current:>5d}/{size:>5d}")


In [22]:
def test(alpha, T):
  size = len(test_dataloader_task2.dataset)
  num_batches = len(test_dataloader_task2)
  model_task2.eval()

  test_loss, correct = 0, 0
  with torch.no_grad():
    for X, y in test_dataloader_task2:
      X, y = X.to(device), y.to(device)

      outputs = model_task2(X)
      soft_y = model_task1(X)

      loss1 = loss_fn(outputs, y)

      outputs_S = nn.functional.softmax(outputs[:, :out_features] / T, dim=1)
      outputs_T = nn.functional.softmax(soft_y[:, :out_features] / T, dim=1)

      loss2 = outputs_T.mul(-1 * torch.log(outputs_S))
      loss2 = loss2.sum(1)
      loss2 = loss2.mean() * T * T

      loss = loss1 * alpha + loss2 * (1 - alpha)

      test_loss += loss.item()
      _, predicted = outputs.max(1)
      correct += predicted.eq(y).sum().item()
      # correct += (pred.argmax(1) == y).type(torch.float).sum().item()
  test_loss /= num_batches
  correct /= size
  print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}, Avg Loss: {test_loss:>8f}")


In [23]:
def val(epoch):
    model_task2.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for batch, (X, y) in enumerate(test_dataloader_task1):
            X, y = X.to(device), y.to(device)
            outputs = model_task2(X)
            _, predicted_old = outputs.max(1)
            total += len(y)
            correct += predicted_old.eq(y).sum().item()
        print(f"Validation Acc: {100. * correct / total}\n")

In [24]:
T = 2
alpha = 0.5
loss_fn = nn.CrossEntropyLoss()

optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model_task2.parameters()), lr=0.001, momentum=0.9, weight_decay=5e-4)

# warmup_epochs = 5
# initial_lr = 0.0001
# final_lr = 0.01

# warmup_scheduler = torch.optim.lr_scheduler.LambdaLR(
#     optimizer,
#     lr_lambda= lambda epoch: (epoch+1)/warmup_epochs if epoch < warmup_epochs else final_lr/initial_lr
# )

for epoch in range(5):
    print(f"Epoch {epoch+1}: ----------------------")
    train(alpha, T)
    test(alpha, T)
    val(epoch)

Epoch 1: ----------------------
Loss: 9.306654,    64/24000
Loss: 1.464222,  6464/24000
Loss: 1.442121, 12864/24000
Loss: 1.512976, 19264/24000
Test Error: 
 Accuracy: 96.2, Avg Loss: 1.176230
Validation Acc: 14.6

Epoch 2: ----------------------
Loss: 1.303878,    64/24000
Loss: 1.424517,  6464/24000
Loss: 1.333093, 12864/24000
Loss: 1.102066, 19264/24000
Test Error: 
 Accuracy: 97.0, Avg Loss: 1.160700
Validation Acc: 13.616666666666667

Epoch 3: ----------------------
Loss: 1.244305,    64/24000
Loss: 0.891064,  6464/24000
Loss: 1.163645, 12864/24000
Loss: 1.110944, 19264/24000
Test Error: 
 Accuracy: 97.2, Avg Loss: 1.154813
Validation Acc: 13.483333333333333

Epoch 4: ----------------------
Loss: 1.109710,    64/24000
Loss: 1.190669,  6464/24000
Loss: 1.268371, 12864/24000
Loss: 1.263525, 19264/24000
Test Error: 
 Accuracy: 97.4, Avg Loss: 1.151523
Validation Acc: 13.633333333333333

Epoch 5: ----------------------
Loss: 1.015058,    64/24000
Loss: 1.222240,  6464/24000
Loss: 1.16