This notebook trains CNNs on FashionMNIST in a simple continual-learning setup: first on two classes, then on four.

### Import modules

In [1]:
import numpy as np
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.autograd import Variable
from torch.utils.data import DataLoader

### Import and process data

In [None]:
dataset_full = torchvision.datasets.FashionMNIST(
    "data", train=True, download=True, transform=transforms.ToTensor()
)

# Class indices we want: 0 T-shirt/top, 1 Trouser, 2 Pullover, 3 Dress
idx_tshirt_trouser = [0, 1]
idx_pullover_dress = [2, 3]
idx_tshirt_trouser_pullover_dress = [0, 1, 2, 3]


def _class_subset(dataset, classes):
    """Return a ``Subset`` of ``dataset`` holding only samples whose label is in ``classes``.

    Filters on ``dataset.targets`` so no image is loaded or transformed while
    selecting indices (iterating the dataset itself would run ``ToTensor`` on
    every sample).
    """
    wanted = set(classes)
    indices = [
        i for i, label in enumerate(dataset.targets.tolist()) if label in wanted
    ]
    return torch.utils.data.Subset(dataset, indices)


# Creating the per-class datasets
dataset_tshirt_trouser = _class_subset(dataset_full, idx_tshirt_trouser)  # 01
dataset_pullover_dress = _class_subset(dataset_full, idx_pullover_dress)  # 23
dataset_tshirt_trouser_pullover_dress = _class_subset(
    dataset_full, idx_tshirt_trouser_pullover_dress
)

# Splits the datasets between train and test (split sizes must sum to the
# length of the dataset being split)
train_dataset_all, test_dataset_all = torch.utils.data.random_split(
    dataset_full, [50000, 10000]
)
train_dataset_01, test_dataset_01 = torch.utils.data.random_split(
    dataset_tshirt_trouser, [10000, 2000]
)
train_dataset_23, test_dataset_23 = torch.utils.data.random_split(
    dataset_pullover_dress, [10000, 2000]
)
# The 4-class split is built from the 2-class splits instead of being drawn
# independently: an independent split would place ~5/6 of the 01/23 test
# images in the 4-class training set, so accuracy on test_dataset_01/23 would
# be measured on images the model had already trained on.
train_dataset_0123 = torch.utils.data.ConcatDataset(
    [train_dataset_01, train_dataset_23]
)
test_dataset_0123 = torch.utils.data.ConcatDataset([test_dataset_01, test_dataset_23])

# Creating PyTorch DataLoaders with a batch_size of 8. Training loaders
# shuffle every epoch (required for train_loader_0123, whose ConcatDataset
# would otherwise yield all 01 samples before any 23 sample).
batch_size = 8
train_loader_all = DataLoader(train_dataset_all, batch_size=batch_size, shuffle=True)
test_loader_all = DataLoader(test_dataset_all, batch_size=batch_size)

train_loader_01 = DataLoader(train_dataset_01, batch_size=batch_size, shuffle=True)
test_loader_01 = DataLoader(test_dataset_01, batch_size=batch_size)

train_loader_23 = DataLoader(train_dataset_23, batch_size=batch_size, shuffle=True)
test_loader_23 = DataLoader(test_dataset_23, batch_size=batch_size)

train_loader_0123 = DataLoader(train_dataset_0123, batch_size=batch_size, shuffle=True)
test_loader_0123 = DataLoader(test_dataset_0123, batch_size=batch_size)

### Architecture of the model

In [3]:
# A small CNN: two convolutional blocks (conv -> batch norm -> ReLU -> max-pool)
# followed by three fully connected layers.

class FashionCNN(nn.Module):
    """CNN classifier for single-channel 28x28 images.

    Args:
        num_classes: number of logits produced by ``fc3``. Defaults to 4, the
            number of classes used in this experiment.

    Note: ``forward`` applies no activation between ``fc1``, ``fc2`` and
    ``fc3``, so those three layers compose to a single affine map.
    """

    def __init__(self, num_classes=4):
        super().__init__()

        # 1x28x28 -> 32x28x28 (padding=1 keeps size) -> 32x14x14 after pooling
        self.layer1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        # 32x14x14 -> 64x12x12 (unpadded 3x3 conv) -> 64x6x6 after pooling
        self.layer2 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

        self.fc1 = nn.Linear(in_features=64 * 6 * 6, out_features=600)
        self.fc2 = nn.Linear(in_features=600, out_features=120)
        self.fc3 = nn.Linear(in_features=120, out_features=num_classes)

    def forward(self, x):
        """Map a batch shaped (N, 1, 28, 28) to logits shaped (N, num_classes)."""
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.view(out.size(0), -1)  # flatten to (N, 64*6*6)
        out = self.fc1(out)
        out = self.fc2(out)
        out = self.fc3(out)

        return out

To simulate continual learning, we first train on two classes (T-shirt/top and Trouser) and then continue training on all four classes (adding Pullover and Dress).

### Parameters


In [None]:
# Phase 1 must use the *training* loader: test_loader_01 is what accuracy on
# classes 0/1 is measured with, so training on it would leak the test set.
train_first_two = train_loader_01  # phase 1: classes 0 and 1 only
train_full = train_loader_0123  # phase 2: all four classes
not_freeze_list = ["fc3.weight", "fc3.bias", "fc2.weight", "fc2.bias"]  # parameters left trainable in phase 2
name = "model_01_0123"  # base name for the saved models; run i uses f"{name}_ex_{i}"
n_iter = 50  # number of independent training runs
accuracy_01 = []  # one accuracy on the 01 test set per run
accuracy_23 = []  # one accuracy on the 23 test set per run

### Training loop

In [None]:
def _set_train_mode(model):
    """Put ``model`` in training mode, keeping fully frozen BatchNorm layers in eval mode.

    ``requires_grad = False`` does not stop a BatchNorm layer from updating its
    running mean/variance in training mode; eval mode does.
    """
    model.train()
    for module in model.modules():
        if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            params = list(module.parameters())
            if params and not any(p.requires_grad for p in params):
                module.eval()


def _train_step(model, images, labels, criterion, optimizer, device):
    """Run one forward/backward/optimizer step on a single batch."""
    images, labels = images.to(device), labels.to(device)
    loss = criterion(model(images.view(-1, 1, 28, 28)), labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


def _accuracy(model, loader, device):
    """Return the percentage of samples in ``loader`` that ``model`` labels correctly.

    Switches ``model`` to eval mode and disables autograd, so BatchNorm uses
    (and does not update) its running statistics; the caller restores the
    training mode afterwards. Returns a 0-d tensor ``correct * 100 / total``.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            predictions = model(images.view(-1, 1, 28, 28)).argmax(dim=1)
            correct += (predictions == labels).sum()
            total += len(labels)
    return correct * 100 / total


def one_training(
    not_freeze_list,
    name_mdl,
    train_full,
    train__first_two,
    *,
    learning_rate=0.001,
    num_epochs=1,
    eval_steps=100,
    threshold=160,
):
    """Train a fresh ``FashionCNN`` in two phases; return accuracies on the 01 and 23 test sets.

    Phase 1 trains all parameters on ``train__first_two``. Phase 2 freezes every
    parameter whose name is not in ``not_freeze_list`` and trains the rest on
    ``train_full``. For phase-2 steps 1 .. ``eval_steps - 1`` the model is
    evaluated on the global ``test_loader_01`` and ``test_loader_23`` after each
    step. As soon as the two accuracies sum to more than ``threshold``, the
    model is saved to ``models/<name_mdl>.pth`` and those accuracies are
    returned. Otherwise the accuracies after the final step are returned, and
    nothing is saved.

    Args:
        not_freeze_list: parameter names (as in ``named_parameters()``) kept
            trainable in phase 2.
        name_mdl: file stem for the saved checkpoint.
        train_full: DataLoader for phase 2.
        train__first_two: DataLoader for phase 1.
        learning_rate: Adam learning rate for both phases.
        num_epochs: epochs per phase.
        eval_steps: evaluation happens only while the phase-2 step count is
            below this value.
        threshold: required sum of the two accuracies (in percent) to save
            and return early.

    Returns:
        ``(accuracy_01, accuracy_23)`` as 0-d tensors, in percent.
    """
    from pathlib import Path

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    model = FashionCNN()
    model.to(device)

    criterion = nn.CrossEntropyLoss()

    # Phase 1: every parameter is trainable.
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    _set_train_mode(model)
    for _ in range(num_epochs):
        for images, labels in train__first_two:
            _train_step(model, images, labels, criterion, optimizer, device)

    # Phase 2: freeze everything not explicitly listed.
    for param_name, param in model.named_parameters():
        if param_name not in not_freeze_list:
            param.requires_grad = False
    # Use a fresh optimizer over the trainable parameters only. The phase-1 Adam
    # still holds momentum for the frozen tensors, and would keep moving them
    # whenever their .grad is zeroed rather than set to None.
    optimizer = torch.optim.Adam(
        [p for p in model.parameters() if p.requires_grad], lr=learning_rate
    )
    _set_train_mode(model)

    step = 0
    for _ in range(num_epochs):
        for images, labels in train_full:
            _train_step(model, images, labels, criterion, optimizer, device)
            step += 1

            if step % 10 == 0:
                print(step)

            if step < eval_steps:
                acc_01 = _accuracy(model, test_loader_01, device)
                acc_23 = _accuracy(model, test_loader_23, device)
                _set_train_mode(model)  # _accuracy left the model in eval mode

                if acc_01 + acc_23 > threshold:
                    path = Path("models") / f"{name_mdl}.pth"
                    path.parent.mkdir(parents=True, exist_ok=True)
                    torch.save(
                        {
                            "model_state_dict": model.state_dict(),
                            "metrics": [acc_01, acc_23],
                        },
                        path,
                    )
                    return acc_01, acc_23

    # Threshold never reached: report the final accuracies instead of
    # implicitly returning None, which would break callers that unpack.
    return (
        _accuracy(model, test_loader_01, device),
        _accuracy(model, test_loader_23, device),
    )

### Training

In [None]:
# Run n_iter independent trainings and report the mean test accuracies.
for iteration in range(n_iter):  # not `iter`, which would shadow the builtin
    print(f"Iteration {iteration+1}/{n_iter}")
    name_iter = f"{name}_ex_{iteration}"
    # one_training takes four positional arguments; pass both loaders explicitly.
    result = one_training(not_freeze_list, name_iter, train_full, train_first_two)
    if result is None:
        # Defensive: skip runs that report no accuracies instead of crashing on unpacking.
        print(f"Iteration {iteration+1}/{n_iter}: no accuracies returned, skipped")
        continue
    acc_01, acc_23 = result
    accuracy_01.append(acc_01)
    accuracy_23.append(acc_23)

# float() handles both CPU/CUDA 0-d tensors and plain numbers.
print("Average test for classes 01", np.mean([float(x) for x in accuracy_01]))
print("Average test for classes 23", np.mean([float(x) for x in accuracy_23]))