In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms

In [None]:
import torchvision
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms

In [None]:
batch_size = 64

# Build the MNIST training split first and the test split second. Both are
# stored under "dataset/", requested with download=True, and passed through
# ToTensor whenever a sample is read.
train_dataset, test_dataset = (
    datasets.MNIST(
        root="dataset/",
        train=is_train_split,
        transform=transforms.ToTensor(),
        download=True,
    )
    for is_train_split in (True, False)
)

In [None]:
batch_size = 64

# Boolean masks marking the samples labelled 0 or 1 in each split.
train_indices = (train_dataset.targets == 0) | (train_dataset.targets == 1)
test_indices = (test_dataset.targets == 0) | (test_dataset.targets == 1)

# torch.where(mask)[0] gives the positions of the selected samples; Subset
# then exposes only those positions of the underlying dataset.
train_dataset_01 = torch.utils.data.Subset(
    train_dataset, torch.where(train_indices)[0]
)
test_dataset_01 = torch.utils.data.Subset(
    test_dataset, torch.where(test_indices)[0]
)

# Both loaders draw shuffled mini-batches of `batch_size` samples.
train_loader_01 = DataLoader(
    dataset=train_dataset_01, batch_size=batch_size, shuffle=True
)
test_loader_01 = DataLoader(
    dataset=test_dataset_01, batch_size=batch_size, shuffle=True
)


In [None]:
class NN(nn.Module):
    """Fully connected classifier with two ReLU-activated hidden layers.

    Args:
        input_size: Number of features in each input row.
        num_classes: Number of scores produced per input row.
        hidden_size: Width of both hidden layers. Defaults to 400, the value
            that was previously hard-coded, so existing calls are unaffected.
    """

    def __init__(self, input_size, num_classes, hidden_size=400):
        super().__init__()
        # Layers are created in the order fc1, fc2, fc3 and keep these
        # attribute names because other cells access them directly.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return the raw output-layer scores for a batch ``x``.

        ``x`` is expected to have ``input_size`` features in its last
        dimension. No softmax is applied to the result.
        """
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [None]:
# Use the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Settings for the first (two-class) model.
input_size, num_classes = 784, 2  # input features per sample, output classes
learning_rate = 1e-3              # step size handed to the optimizer
num_epochs = 3                    # full passes over the training loader

In [None]:
# Build the network with the configured sizes, then move it to `device`.
model = NN(input_size=input_size, num_classes=num_classes)
model = model.to(device)

# Adam updates every parameter of `model`; the loss is cross-entropy
# computed on the network's raw output scores.
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

In [None]:
# Fit `model` on the 0/1 subset: `num_epochs` passes over train_loader_01,
# one optimizer step per mini-batch.
for epoch in range(num_epochs):
    print(f"Epoch: {epoch}")
    for batch_idx, (data, targets) in enumerate(train_loader_01):
        # Move the batch to the compute device and flatten each sample
        # into a single row of features.
        data = data.to(device).reshape(len(data), -1)
        targets = targets.to(device)

        # Drop gradients from the previous step, then forward, backward
        # and update.
        optimizer.zero_grad()
        loss = criterion(model(data), targets)
        loss.backward()
        optimizer.step()

Epoch: 0
Epoch: 1
Epoch: 2


In [None]:
def check_accuracy(loader, model):
    """Print how many samples from ``loader`` the ``model`` classifies correctly.

    Every batch is moved to the device the model's parameters live on,
    flattened to one row per sample and scored; the predicted class is the
    index of the highest score. The result is printed as
    ``Got <correct> / <total> with accuracy <percent>`` and nothing is
    returned.

    Args:
        loader: Iterable yielding ``(inputs, labels)`` tensor batches.
        model: ``nn.Module`` that maps a ``(batch, features)`` tensor to
            per-class scores of shape ``(batch, num_classes)``.
    """
    num_correct = 0
    num_samples = 0

    # Take the device from the model itself rather than from a notebook-level
    # global, so the function does not depend on which cell last set `device`.
    # A model without parameters leaves the batches where they are.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else None

    # Remember the current mode so evaluation does not silently flip a model
    # that was deliberately left in eval mode back to training mode.
    was_training = model.training
    model.eval()

    try:
        with torch.no_grad():
            for x, y in loader:
                if target_device is not None:
                    x = x.to(device=target_device)
                    y = y.to(device=target_device)
                x = x.reshape(x.shape[0], -1)

                scores = model(x)
                _, predictions = scores.max(1)
                num_correct += (predictions == y).sum().item()
                num_samples += predictions.size(0)
    finally:
        model.train(was_training)

    # An empty loader would otherwise divide by zero; report 0.00 instead.
    if num_samples:
        accuracy = float(num_correct) / float(num_samples) * 100
    else:
        accuracy = 0.0

    print(
        f"Got {num_correct} / {num_samples} with accuracy"
        f" {accuracy:.2f}"
    )


# Report the two-class model's accuracy on its training subset first and on
# the test subset second.
for loader_01 in (train_loader_01, test_loader_01):
    check_accuracy(loader_01, model)

Got 12662 / 12665 with accuracy 99.98
Got 2113 / 2115 with accuracy 99.91


In [None]:
import copy
num_classes = 4

# Work on an independent copy so the trained two-class `model` stays intact
# and can be compared against later.
model2 = copy.deepcopy(model)

# Swap the output layer for a freshly initialised one with `num_classes`
# outputs. The input width is read from the layer being replaced instead of
# repeating the hidden size as a literal, so the two cannot drift apart.
model2.fc3 = nn.Linear(model2.fc3.in_features, num_classes).to(device)

In [None]:
batch_size = 64

# Boolean masks marking the samples labelled 0, 1, 2 or 3 in each split,
# accumulated one digit at a time.
train_indices = train_dataset.targets == 0
test_indices = test_dataset.targets == 0
for digit in (1, 2, 3):
    train_indices |= train_dataset.targets == digit
    test_indices |= test_dataset.targets == digit

# torch.where(mask)[0] gives the positions of the selected samples; Subset
# then exposes only those positions of the underlying dataset.
train_dataset_02 = torch.utils.data.Subset(
    train_dataset, torch.where(train_indices)[0]
)
test_dataset_02 = torch.utils.data.Subset(
    test_dataset, torch.where(test_indices)[0]
)

# Both loaders draw shuffled mini-batches of `batch_size` samples.
train_loader_02 = DataLoader(
    dataset=train_dataset_02, batch_size=batch_size, shuffle=True
)
test_loader_02 = DataLoader(
    dataset=test_dataset_02, batch_size=batch_size, shuffle=True
)


In [None]:
# Same device choice as before: CUDA when usable, CPU otherwise.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Settings for the transfer-learning run on four classes.
input_size = 784      # input features per sample
num_classes = 4       # output classes
learning_rate = 1e-2  # step size handed to the optimizer
num_epochs = 2        # full passes over the training loader

In [None]:
# Freeze the two hidden layers copied from the two-class model and keep the
# replaced output layer trainable. This is done BEFORE the optimizer is
# built so the optimizer never sees the frozen tensors.
model2.fc1.requires_grad_(False)
model2.fc2.requires_grad_(False)
model2.fc3.requires_grad_(True)

criterion = nn.CrossEntropyLoss()

# Hand Adam only the parameters that still require gradients. Registering
# the frozen ones as well would leave them protected only for as long as
# their .grad stays None (e.g. adding weight decay could still move them).
trainable_params = [p for p in model2.parameters() if p.requires_grad]
optimizer = optim.Adam(trainable_params, lr=learning_rate)

In [None]:
# Transfer-learning pass: fit `model2` on the 0-3 subset, taking one
# optimizer step per mini-batch for `num_epochs` passes over train_loader_02.
for epoch in range(num_epochs):
    print(f"Epoch: {epoch}")
    for batch_idx, (data, targets) in enumerate(train_loader_02):
        # Move the batch to the compute device, then flatten every sample
        # into one row of features.
        targets = targets.to(device)
        data = data.to(device)
        data = data.reshape(data.size(0), -1)

        # Clear gradients left over from the previous step before the new
        # forward/backward pass.
        optimizer.zero_grad()
        scores = model2(data)
        loss = criterion(scores, targets)
        loss.backward()
        optimizer.step()


Epoch: 0
Epoch: 1


In [None]:
# Report the four-class model's accuracy on its training subset first and on
# the test subset second.
for loader_02 in (train_loader_02, test_loader_02):
    check_accuracy(loader_02, model2)

Got 23216 / 24754 with accuracy 93.79
Got 3923 / 4157 with accuracy 94.37


In [None]:
# Check that the hidden layers of `model2` still match those of `model`.
# Each layer counts as unchanged only if both its weight and its bias match.
fc1_unchanged = (
    torch.allclose(model.fc1.weight, model2.fc1.weight)
    and torch.allclose(model.fc1.bias, model2.fc1.bias)
)
fc2_unchanged = (
    torch.allclose(model.fc2.weight, model2.fc2.weight)
    and torch.allclose(model.fc2.bias, model2.fc2.bias)
)

# Both layers must be unchanged. Combining the two results with OR would
# report "equal" even when only one of the layers had been left untouched.
are_parameters_equal = fc1_unchanged and fc2_unchanged

if are_parameters_equal:
    print("Parameters of fc1 and fc2 are equal.")
else:
    print("Parameters of fc1 and fc2 are not equal.")

Parameters of fc1 and fc2 are equal.
