In [1]:
import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Normalize

In [2]:
class FashionMNISTtask1(datasets.FashionMNIST):
    """FashionMNIST view for the first task: labels 0-5.

    Every sample of the underlying split is still indexable; samples whose
    label is 6 or higher come back with the sentinel label -1 so callers can
    filter them out. ``classes`` is trimmed to the first six names.
    """

    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        super().__init__(
            root,
            train=train,
            transform=transform,
            target_transform=target_transform,
            download=download,
        )
        self.classes = self.classes[:6]

    def __getitem__(self, index):
        img, target = super().__getitem__(index)
        # Out-of-task samples keep their image but get the -1 sentinel label.
        return (img, target) if target < 6 else (img, -1)

In [3]:
class FashionMNISTtask2(datasets.FashionMNIST):
    """FashionMNIST view for the second task: labels 6-9.

    Labels are NOT remapped (they stay 6..9). Samples with a label below 6
    come back with the sentinel label -1 so callers can filter them out.
    ``classes`` is trimmed to the last four names.
    """

    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        super().__init__(
            root,
            train=train,
            transform=transform,
            target_transform=target_transform,
            download=download,
        )
        self.classes = self.classes[6:]

    def __getitem__(self, index):
        img, target = super().__getitem__(index)
        if target < 6:
            # Belongs to the first task: mark it for filtering.
            return img, -1
        return img, target

In [None]:
# ToTensor maps pixels to [0, 1]; Normalize then applies (x - 0.5) / 0.5,
# giving [-1, 1]. Normalize expects per-channel sequences; FashionMNIST has a
# single channel, hence the 1-tuples (a bare `(0.5)` is just the float 0.5).
transform = torchvision.transforms.Compose([ToTensor(), Normalize((0.5,), (0.5,))])

# Task 1 ("old" task): samples with labels 0-5 keep their label, the rest get -1.
train_dataset_1 = FashionMNISTtask1(root='./data1', train=True, transform=transform, download=True)
test_dataset_1 = FashionMNISTtask1(root='./data1', train=False, transform=transform, download=True)

In [None]:
# Task 2 ("new" task): samples with labels 6-9 keep their label, the rest get -1.
train_dataset_2 = FashionMNISTtask2(root='./data2', train=True,
                                    transform=transform, download=True)
test_dataset_2 = FashionMNISTtask2(root='./data2', train=False,
                                   transform=transform, download=True)

# Unfiltered test split covering all ten classes (wrapped by eval_dataloader).
test_dataset_3 = datasets.FashionMNIST(root='data3', train=False,
                                       download=True, transform=transform)

In [6]:
# Keep only in-task samples (the task datasets label the others -1).
# Iterating a dataset runs its transform, so these lists hold the already
# normalised image tensors in memory.
train_dataset_filtered_old = [(img, label) for img, label in train_dataset_1 if label != -1]
test_dataset_filtered_old = [(img, label) for img, label in test_dataset_1 if label != -1]

train_dataset_filtered_new = [(img, label) for img, label in train_dataset_2 if label != -1]
test_dataset_filtered_new = [(img, label) for img, label in test_dataset_2 if label != -1]

In [7]:
# Loaders for the old task (labels 0-5); only the training loader is shuffled.
train_dataloader_old = DataLoader(train_dataset_filtered_old, batch_size=64, shuffle=True)
test_dataloader_old = DataLoader(test_dataset_filtered_old, batch_size=256, shuffle=False)

# Sanity-check a single training batch: images as [N, C, H, W], integer labels.
X, y = next(iter(train_dataloader_old))
print(f"Shape of X [N, C, H, W]: {X.shape}")
print(f"Shape of y: {y.shape}, dtype: {y.dtype}")

Shape of X [N, C, H, W]: torch.Size([64, 1, 28, 28])
SHape of y: torch.Size([64]), dtype: torch.int64


In [8]:
# Loaders for the new task (labels 6-9), built the same way as the old-task ones.
train_dataloader_new = DataLoader(
    train_dataset_filtered_new, batch_size=64, shuffle=True
)
test_dataloader_new = DataLoader(
    test_dataset_filtered_new, batch_size=256, shuffle=False
)

In [9]:
# Full 10-class test split for final evaluation. Sample order does not affect
# accuracy, so shuffling only costs determinism; keep it off like the other
# test loaders.
eval_dataloader = DataLoader(test_dataset_3, batch_size=256, shuffle=False)

In [10]:
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Using {device} device.")

Using cpu device.


In [11]:
class CNN1(nn.Module):
    """Fully connected representation network: 28x28 image -> 784 features.

    Despite the name there are no convolutions. The input is flattened and
    passed through three ReLU-activated linear layers
    (784 -> hidden -> 2*hidden -> hidden), then a final linear layer back to
    28*28 outputs, so the result has the size of a flattened 28x28 image.

    Args:
        num_classes: accepted for signature parity with ``CNN2``; unused.
        hidden_size: width of the first and third hidden layers; the middle
            layer is twice as wide.
    """

    def __init__(self, num_classes=10, hidden_size=512):
        super().__init__()
        wide = 2 * hidden_size
        # Keep registration order fixed: nn.Linear construction and
        # _initialize_weights both draw from the global RNG in this order.
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(28 * 28, hidden_size)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, wide)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(wide, hidden_size)
        self.relu3 = nn.ReLU()
        self.classifier = nn.Linear(hidden_size, 28 * 28)
        self._initialize_weights()

    def _initialize_weights(self):
        """Re-draw every Linear/Conv2d weight with Kaiming-normal (ReLU gain).

        Biases keep PyTorch's default initialisation.
        """
        kaiming_targets = (nn.Linear, nn.Conv2d)
        for module in self.modules():
            if isinstance(module, kaiming_targets):
                nn.init.kaiming_normal_(module.weight, nonlinearity='relu')

    def forward(self, x):
        h = self.flatten(x)
        hidden_stages = (
            (self.fc1, self.relu1),
            (self.fc2, self.relu2),
            (self.fc3, self.relu3),
        )
        for linear, activation in hidden_stages:
            h = activation(linear(h))
        return self.classifier(h)

In [12]:
class CNN2(nn.Module):
    """Fully connected classifier head: flattened 28x28 input -> class logits.

    There are no convolutions despite the name. The input is flattened and
    passed through 784 -> hidden -> 2*hidden -> hidden (ReLU after each),
    then a final linear layer producing ``num_classes`` raw logits (no softmax).

    Args:
        num_classes: number of output logits.
        hidden_size: width of the first and third hidden layers; the middle
            layer is twice as wide.
    """

    def __init__(self, num_classes=10, hidden_size=512):
        super().__init__()
        in_features = 28 * 28
        # Keep registration order fixed: nn.Linear construction and
        # _initialize_weights both draw from the global RNG in this order.
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(in_features, hidden_size)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, hidden_size * 2)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(hidden_size * 2, hidden_size)
        self.relu3 = nn.ReLU()
        self.classifier = nn.Linear(hidden_size, num_classes)
        self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-normal (ReLU gain) for every Linear/Conv2d weight; biases untouched."""
        weighted = filter(lambda m: isinstance(m, (nn.Linear, nn.Conv2d)), self.modules())
        for layer in weighted:
            nn.init.kaiming_normal_(layer.weight, nonlinearity='relu')

    def forward(self, x):
        hidden = self.relu1(self.fc1(self.flatten(x)))
        hidden = self.relu2(self.fc2(hidden))
        hidden = self.relu3(self.fc3(hidden))
        return self.classifier(hidden)

In [13]:
def train(model, dataloader, loss_fn, optimizer):
    """Run one training epoch of ``model`` over ``dataloader``.

    For each batch: forward pass, loss, backward, optimizer step, then the
    gradients are cleared. On every 100th batch (starting with the first) the
    batch loss and an approximate count of samples seen so far are printed.
    Inputs are moved to the module-level ``device``. Returns None.
    """
    model.train()
    n_samples = len(dataloader.dataset)
    for step, (inputs, targets) in enumerate(dataloader):
        inputs, targets = inputs.to(device), targets.to(device)

        batch_loss = loss_fn(model(inputs), targets)
        batch_loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if step % 100 != 0:
            continue
        # Assumes full batches, so the count overshoots on a short final batch.
        seen = (step + 1) * len(inputs)
        print(f"Loss: {batch_loss.item():>7f}, {seen:>5d}/{n_samples:>5d}")


In [14]:
def test(model, dataloader, loss_fn):
    """Evaluate ``model`` on ``dataloader`` and print accuracy and mean loss.

    Accuracy is the fraction of samples whose argmax logit equals the label,
    printed as a percentage. The loss is averaged over batches, not samples.
    Inputs are moved to the module-level ``device``. Returns None.
    """
    model.eval()
    n_samples = len(dataloader.dataset)
    n_batches = len(dataloader)

    total_loss = 0
    n_correct = 0

    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            logits = model(inputs)
            total_loss += loss_fn(logits, labels).item()
            hits = (logits.argmax(1) == labels).type(torch.float)
            n_correct += hits.sum().item()

        avg_loss = total_loss / n_batches
        accuracy = n_correct / n_samples
        print(f"Test loss:\nAccuracy: {100*accuracy:>0.1f}, Avg Loss: {avg_loss:>8f}\n")

In [2]:
def val(model, epoch, dataloader=None):
    """Print and return the top-1 accuracy of ``model`` on an evaluation loader.

    Args:
        model: module mapping a batch of images to class logits. It is put
            into eval mode, and inference runs without autograd.
        epoch: unused; kept so existing call sites keep working.
        dataloader: loader to evaluate on. Defaults to the module-level
            ``eval_dataloader`` (the unfiltered 10-class test split), looked
            up at call time.

    Returns:
        Accuracy in percent, or ``float('nan')`` if the loader yielded no
        samples (previously this raised ZeroDivisionError).
    """
    if dataloader is None:
        dataloader = eval_dataloader
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            _, predicted = model(X).max(1)
            total += len(y)
            correct += predicted.eq(y).sum().item()
    accuracy = 100. * correct / total if total else float('nan')
    print(f"Validation Acc: {accuracy}\n")
    return accuracy

In [None]:
%time
from tqdm import tqdm

total_runs = 3
num_epochs = 1
learning_rate = 1e-3

for runs in range(total_runs):
    torch.manual_seed(runs)

    # Your CNN model definition
    model_theta = CNN1()
    model_theta = model_theta.to(device)
    optimizer_theta = torch.optim.Adam(model_theta.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()

    for dataloader in [train_dataloader_old, train_dataloader_new]:

        dataloader = tqdm(dataloader, total=len(dataloader))

        model_W = CNN2()
        model_W.to(device)
        optimizer_W = torch.optim.Adam(model_W.parameters(), lr=learning_rate)
        # Alternate training between d_old and d_new in each epoch
        for epoch in range(num_epochs):
            for X, y in dataloader:
                X, y = X.to(device), y.to(device)

                # Training the model for classification (model_W)
                optimizer_W.zero_grad()
                pred = model_W(model_theta(X))
                loss = criterion(pred, y)
                loss.backward(create_graph=True)
                optimizer_W.step()

            # Learning the representation after each dataloader
            for X, y in dataloader:
                X, y = X.to(device), y.to(device)

                optimizer_theta.zero_grad()
                pred = model_W(model_theta(X))
                loss = criterion(pred, y)
                loss.backward(create_graph=True)
                optimizer_theta.step()

    print(f"One run finished")

In [1]:
# Evaluate the composed pipeline (representation model_theta -> classifier
# head model_W of the last task) on the full 10-class test split.
# Requires the training cell above to have run in this kernel.
# nn.Sequential only wraps the two existing modules (no copy). As a result,
# `val` switches *both* to eval mode, and its accuracy loop is reused rather
# than duplicated here.
composed_model = nn.Sequential(model_theta, model_W)
val(composed_model, epoch=num_epochs)

NameError: ignored