In [None]:
import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Normalize

In [None]:
class FashionMNISTtask1(datasets.FashionMNIST):
  def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
    super(FashionMNISTtask1, self).__init__(root, train=train, transform=transform, target_transform=target_transform,download=download)
    self.classes = self.classes[:6]

  def __getitem__(self, index):
    img, target = super(FashionMNISTtask1, self).__getitem__(index)
    if target < 6:
        return img, target
    else:
        return img, -1


In [None]:
class FashionMNISTtask2(datasets.FashionMNIST):
    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        super(FashionMNISTtask2, self).__init__(root, train=train, transform=transform, target_transform=target_transform,download=download)
        self.classes = self.classes[6:]

    def __getitem__(self, index):
        img, target = super(FashionMNISTtask2, self).__getitem__(index)
        if target >= 6:
            return img, target
        else:
            return img, -1

In [None]:
transform = torchvision.transforms.Compose([ToTensor(), Normalize((0.5), (0.5))])

train_dataset_1 = FashionMNISTtask1(root='./data1', train=True, transform=transform, download=True)
test_dataset_1 = FashionMNISTtask1(root='./data1', train=False, transform=transform, download=True)

In [None]:
train_dataset_2 = FashionMNISTtask2(root='./data2', train=True, transform=transform, download=True)
test_dataset_2 = datasets.FashionMNIST(
    root='data2',
    train=False,
    download=True,
    transform=transform
)

In [None]:
print(dir(datasets.FashionMNIST))

['__add__', '__annotations__', '__class__', '__class_getitem__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__getitem__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__len__', '__lt__', '__module__', '__ne__', '__new__', '__orig_bases__', '__parameters__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__slots__', '__str__', '__subclasshook__', '__weakref__', '_check_exists', '_check_legacy_exist', '_format_transform_repr', '_is_protocol', '_load_data', '_load_legacy_data', '_repr_indent', 'class_to_idx', 'classes', 'download', 'extra_repr', 'mirrors', 'processed_folder', 'raw_folder', 'resources', 'test_data', 'test_file', 'test_labels', 'train_data', 'train_labels', 'training_file']


In [None]:
train_dataset_filtered_old = [data for data in train_dataset_1 if data[1] != -1]
test_dataset_filtered_old = [data for data in test_dataset_1 if data[1] != -1]

train_dataset_filtered_new = [data for data in train_dataset_2 if data[1] != -1]

In [None]:
train_dataloader_old = DataLoader(train_dataset_filtered_old, batch_size=64, shuffle=True)
test_dataloader_old = DataLoader(test_dataset_filtered_old, batch_size=64, shuffle=False)

for X, y in train_dataloader_old:
  print(f"Shape of X [N, C, H, W]: {X.shape}")
  print(f"SHape of y: {y.shape}, dtype: {y.dtype}")
  break

Shape of X [N, C, H, W]: torch.Size([64, 1, 28, 28])
SHape of y: torch.Size([64]), dtype: torch.int64


In [None]:
len(train_dataloader_old)

563

In [None]:
train_dataloader_new = DataLoader(train_dataset_filtered_new, batch_size=64, shuffle=True)
test_dataloader_new = DataLoader(test_dataset_2, batch_size=64, shuffle=False)

In [None]:
device = ('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using {device} device")

Using cpu device


In [None]:
def kaiming_normal_init(m):
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
    elif isinstance(m, nn.Linear):
        nn.init.kaiming_normal_(m.weight, nonlinearity='sigmoid')

In [94]:
class NeuralNetwork(nn.Module):
    def __init__(self, num_classes=10, hidden_size=512):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()

        self.fc1 = nn.Linear(28*28, hidden_size)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.relu2 = nn.ReLU()
        self.classifier = nn.Linear(hidden_size, num_classes)

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, nonlinearity='sigmoid')
            elif isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, nonlinearity='relu')

    def forward(self, x):
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.fc2(x)
        x = self.relu2(x)
        logits = self.fc3(x)

        return logits

In [95]:
def train(dataloader, model, loss_fn, optimizer):
  size = len(dataloader.dataset)
  model.train()
  for batch, (X, y) in enumerate(dataloader):
    X, y = X.to(device), y.to(device)
    pred = model(X)
    loss = loss_fn(pred, y)

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    if batch % 100 == 0:
      loss, current = loss.item(), (batch+1) * len(X)
      print(f"Loss: {loss:>7f}, {current:>5d}/{size:>5d}")

In [96]:
def test(dataloader, model, loss_fn):
  size = len(dataloader.dataset)
  num_batches = len(dataloader)
  model.eval()

  test_loss, correct = 0, 0
  with torch.no_grad():
    for X, y in dataloader:
      X, y = X.to(device), y.to(device)
      pred = model(X)
      test_loss += loss_fn(pred, y).item()
      correct += (pred.argmax(1) == y).type(torch.float).sum().item()
  test_loss /= num_batches
  correct /= size
  print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}, Avg Loss: {test_loss:>8f}\n")


In [97]:
pre_model = NeuralNetwork(num_classes=6, hidden_size=512).to(device)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(pre_model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)

In [102]:
print(dir(nn.Linear(2, 2)))

['T_destination', '__annotations__', '__call__', '__class__', '__constants__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattr__', '__getattribute__', '__getstate__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__setstate__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', '_apply', '_backward_hooks', '_backward_pre_hooks', '_buffers', '_call_impl', '_compiled_call_impl', '_forward_hooks', '_forward_hooks_always_called', '_forward_hooks_with_kwargs', '_forward_pre_hooks', '_forward_pre_hooks_with_kwargs', '_get_backward_hooks', '_get_backward_pre_hooks', '_get_name', '_is_full_backward_hook', '_load_from_state_dict', '_load_state_dict_post_hooks', '_load_state_dict_pre_hooks', '_maybe_warn_non_full_backward_hook', '_modules', '_named_members', '_non_persistent_buffers_set', '_parameters', '_register_load_s

In [98]:
epochs = 10
for t in range(epochs):
  print(f"Epoch {t+1}\n---------------------------")
  train(train_dataloader_old, pre_model, loss_fn, optimizer)
  test(test_dataloader_old, pre_model, loss_fn)
print("Done!")

torch.save(pre_model.state_dict(), "model_old.pth")

Epoch 1
---------------------------
Loss: 1.799676,    64/36000
Loss: 0.452803,  6464/36000
Loss: 0.300028, 12864/36000
Loss: 0.317498, 19264/36000
Loss: 0.255805, 25664/36000
Loss: 0.399239, 32064/36000
Test Error: 
 Accuracy: 89.8, Avg Loss: 0.276300

Epoch 2
---------------------------
Loss: 0.247731,    64/36000
Loss: 0.219728,  6464/36000
Loss: 0.160258, 12864/36000
Loss: 0.529205, 19264/36000
Loss: 0.131535, 25664/36000
Loss: 0.126601, 32064/36000
Test Error: 
 Accuracy: 90.2, Avg Loss: 0.265142

Epoch 3
---------------------------
Loss: 0.245242,    64/36000
Loss: 0.178187,  6464/36000
Loss: 0.174892, 12864/36000
Loss: 0.142251, 19264/36000
Loss: 0.211528, 25664/36000
Loss: 0.164661, 32064/36000
Test Error: 
 Accuracy: 91.0, Avg Loss: 0.244063

Epoch 4
---------------------------
Loss: 0.151808,    64/36000
Loss: 0.161958,  6464/36000
Loss: 0.178810, 12864/36000
Loss: 0.196824, 19264/36000
Loss: 0.193289, 25664/36000
Loss: 0.259371, 32064/36000
Test Error: 
 Accuracy: 91.1, Avg 

RuntimeError: ignored

In [99]:
torch.save(pre_model.state_dict(), "model_old.pth")

# LWF

In [101]:
net_new = NeuralNetwork(num_classes=6, hidden_size=512).to(device)
net_old = NeuralNetwork(num_classes=6, hidden_size=512).to(device)

net_new.load_state_dict(torch.load("model_old.pth"))
net_old.load_state_dict(torch.load("model_old.pth"))

in_features = net_old.classifier.in_features
out_features = net_old.classifier.out_features

weight = net_old.classifier.weight.data
bias = net_old.classifier.bias.data

new_out_features = 6 + 4

new_fc = nn.Linear(in_features, new_out_features)
kaiming_normal_init(new_fc.weight)

new_fc.weight.data[:out_features] = weight
new_fc.bias.data[:out_features] = bias

net_new.classifier = new_fc
net_new = net_new.to(device)
print("New head numbers: ", net_new.classifier.out_fetures)

for param in net_old.parameters():
  param.requires_grad = False

AttributeError: ignored

Changes in training and testing

In [None]:
def train(alpha, T):
  size = len(train_dataloader_2.dataset)
  net_new.train()
  for batch, (X, y) in enumerate(train_dataloader_2):
    X, y = X.to(device), y.to(device)

    outputs = net_new(X)
    soft_y = net_old(X)

    loss1 = loss_fn(outputs, y)

    outputs_S = F.softmax(outputs[:, :out_fetures] / T, dim=1)
    outputs_T = F.softmax(soft_y[:. :out_fetures] / T, dim=1)

    loss2 = outputs_T.mul(-1 * torch.log(outputs_S))
    loss2 = loss2.sum(1)
    loss2 = loss2.mean() * T * T

    loss = loss1 + alpha * loss2


    loss.backward(retain_graph=True)
    optimizer.step()
    optimizer.zero_grad()

    if batch % 100 == 0:
      loss, current = loss.item(), (batch+1) * len(X)
      print(f"Loss: {loss:>7f}, {current:>5d}/{size:>5d}")


In [None]:
def test(alpha, T):
  size = len(dataloader.dataset)
  num_batches = len(dataloader)
  net_new.eval()

  test_loss, correct = 0, 0
  with torch.no_grad():
    for X, y in dataloader:
      X, y = X.to(device), y.to(device)

      outputs = net_new(X)
      soft_y = net_old(X)

      loss1 = loss_fn(outputs, y)

      outputs_S = F.softmax(outputs[:, :out_fetures] / T, dim=1)
      outputs_T = F.softmax(soft_y[:. :out_fetures] / T, dim=1)

      loss2 = outputs_T.mul(-1 * torch.log(outputs_S))
      loss2 = loss2.sum(1)
      loss2 = loss2.mean() * T * T

      loss = loss1 * alpha + loss2 * (1 - alpha)

      test_loss += loss.item()
      _, predicted = outputs.max(1)
      correct += predicted.eq(y).sum().item()
      # correct += (pred.argmax(1) == y).type(torch.float).sum().item()
  test_loss /= num_batches
  correct /= size
  print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}, Avg Loss: {test_loss:>8f}\n")


In [None]:
T = 2
alpha = 0.5
loss_fn = nn.CrossEntropyLoss()

optimizer = optim.SGD(filter(lambda p: p.requires_grad, net_new.parametrs()), lr=0.01, momentum=0.9, weight_decay=5e-4)

for epoch in range(10):
  train(alpha, T)
  test(alpha, T)

torch.save(net_new.state_dict(), "model.pth")