## Importing Libraries

In [None]:
from pathlib import Path
import requests
import pickle
import gzip
from matplotlib import pyplot as plt
import tqdm

# PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms

## Device

In [None]:
# Device setup: prefer an Nvidia GPU, then an Apple Silicon GPU, else the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Device = {device}")

## Hyperparameters

In [None]:
# Mini-batch size used by the training DataLoaders.
batch_size = 64
# Batch size used by the validation and test DataLoaders.
val_batch_size = 1_000
# Fraction of the training data held out as a validation split.
validation_size = 0.2

## Dataset

In [None]:
# Dataset path: everything is stored under ./data/mnist, which is created
# (together with any missing parent directory) if it does not exist yet.
DATA_PATH = Path("data")
PATH = DATA_PATH.joinpath("mnist")
PATH.mkdir(parents=True, exist_ok=True)

In [None]:
# Dataset download
URL = "https://github.com/pytorch/tutorials/raw/main/_static/"
FILENAME = "mnist.pkl.gz"

# Download only when no local copy exists yet.
if not (PATH / FILENAME).exists():
    response = requests.get(URL + FILENAME, timeout=30)
    # Fail loudly on an HTTP error instead of caching an error page as the
    # dataset archive (which would break every later run at the gzip step).
    response.raise_for_status()
    content = response.content
    # write_bytes opens, writes and closes the file in one call.
    (PATH / FILENAME).write_bytes(content)

# SECURITY NOTE: pickle.load can execute arbitrary code from the file it reads.
# Only run this against an archive obtained from a source you trust.
with gzip.open((PATH / FILENAME).as_posix(), "rb") as f:
    # The archive holds three (inputs, labels) pairs; the third one is discarded.
    ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding="latin-1")

In [None]:
# Custom Dataset
class MNIST_Dataset(Dataset):
    """Map-style dataset pairing image rows with their labels.

    Args:
        x: Indexable collection of samples; each ``x[idx]`` must support
            ``.reshape(1, 28, 28)``.
        y: Indexable collection of labels, aligned with ``x``.
        transform: Optional callable applied to the reshaped sample.
    """

    def __init__(self, x, y, transform=None):
        self.x, self.y, self.transform = x, y, transform

    def __len__(self):
        # The number of samples is defined by the inputs alone.
        return len(self.x)

    def __getitem__(self, idx):
        """Return ``(sample, label)`` for position ``idx``."""
        image = self.x[idx].reshape(1, 28, 28)
        label = self.y[idx]
        if self.transform:
            image = self.transform(image)
        return image, label
    
# Preprocessing applied to every sample: convert to a tensor, then normalise
# with mean 0.1307 and std 0.3081.
# NOTE(review): torchvision documents ToTensor's ndarray input as (H, W, C),
# while MNIST_Dataset hands it (1, 28, 28) arrays -- confirm the resulting axis
# order is the one the saved checkpoints were trained with before changing
# either side.
# The misspelled name `transfrom` is kept because the dataset cells use it.
transfrom = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)

In [None]:
# Full Dataset
# The pickle's "valid" arrays serve as the held-out test set here; a separate
# validation split is carved out of the training arrays below.
test_ds = MNIST_Dataset(x_valid, y_valid, transform=transfrom)
train_ds = MNIST_Dataset(x_train, y_train, transform=transfrom)

# Split sizes: (1 - validation_size) of the samples for training, the rest for validation.
train_size = int(len(train_ds) * (1 - validation_size))
valid_size = len(train_ds) - train_size
train_ds, valid_ds = random_split(train_ds, [train_size, valid_size])

# Only the training loader shuffles; evaluation loaders keep a fixed order.
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
valid_dl, test_dl = (
    DataLoader(ds, batch_size=val_batch_size, shuffle=False)
    for ds in (valid_ds, test_ds)
)

In [None]:
# 0&1 Dataset: keep only the samples whose label is <= 1 (digits 0 and 1).
x_train_01, y_train_01 = (arr[y_train <= 1] for arr in (x_train, y_train))
x_valid_01, y_valid_01 = (arr[y_valid <= 1] for arr in (x_valid, y_valid))

# Print the shape of each filtered array, one per line.
print(
    *(arr.shape for arr in (x_train_01, y_train_01, x_valid_01, y_valid_01)),
    sep="\n",
)

In [None]:
# Datasets and loaders for the 0/1-only subset, built the same way as the full ones.
test_ds_01 = MNIST_Dataset(x_valid_01, y_valid_01, transform=transfrom)
train_ds_01 = MNIST_Dataset(x_train_01, y_train_01, transform=transfrom)

# Split sizes: (1 - validation_size) of the samples for training, the rest for validation.
train_size_01 = int(len(train_ds_01) * (1 - validation_size))
valid_size_01 = len(train_ds_01) - train_size_01
train_ds_01, valid_ds_01 = random_split(train_ds_01, [train_size_01, valid_size_01])

# Only the training loader shuffles; evaluation loaders keep a fixed order.
train_dl_01 = DataLoader(train_ds_01, batch_size=batch_size, shuffle=True)
valid_dl_01, test_dl_01 = (
    DataLoader(ds, batch_size=val_batch_size, shuffle=False)
    for ds in (valid_ds_01, test_ds_01)
)

## Model

In [None]:
class MNIST_CNN(nn.Module):
    """Three-layer convolutional classifier producing 10 scores per image.

    Every convolution uses a 3x3 kernel with stride 2 and padding 1, so a
    28x28 input shrinks to 14x14, 7x7 and finally 4x4 before pooling.
    """

    def __init__(self):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, stride=2, padding=1)
        # Attribute names define the state_dict keys used by saved checkpoints.
        self.conv1 = nn.Conv2d(1, 16, **conv_kwargs)
        self.conv2 = nn.Conv2d(16, 16, **conv_kwargs)
        self.conv3 = nn.Conv2d(16, 10, **conv_kwargs)

    def forward(self, xb):
        """Return a ``(batch, 10)`` tensor of non-negative class scores.

        Args:
            xb: Tensor whose elements can be viewed as ``(-1, 1, 28, 28)``.
        """
        features = xb.view(-1, 1, 28, 28)
        for conv in (self.conv1, self.conv2, self.conv3):
            features = F.relu(conv(features))
        # Average the remaining 4x4 map of each channel down to a single value.
        pooled = F.avg_pool2d(features, 4)
        return pooled.view(-1, pooled.size(1))

In [None]:
# Load model
# map_location remaps the stored tensors onto the device selected above, so a
# checkpoint saved on a GPU machine still loads on a CPU- or MPS-only machine.
# NOTE: torch.load unpickles the file; only load checkpoints you trust.
base_model = MNIST_CNN().to(device)
base_model.load_state_dict(torch.load("mnist_cnn.pth", map_location=device))
model_01 = MNIST_CNN().to(device)
model_01.load_state_dict(torch.load("mnist_cnn_01.pth", map_location=device))

## Evaluation

In [None]:
# Per-label tallies, one slot per label 0-9: correct predictions and samples seen.
correct, total = [0] * 10, [0] * 10

In [None]:
# Accuracy on each label
def evaluate_labels(model, test_dl, num_classes=10):
    """Compute the accuracy of ``model`` separately for every label.

    Args:
        model: Module mapping a batch of inputs to per-class scores; the
            predicted label is the index of the maximum along dim 1.
        test_dl: Iterable yielding ``(inputs, labels)`` tensor batches with
            integer labels in ``[0, num_classes)``.
        num_classes: Number of labels to report on (default 10).

    Returns:
        A list of ``num_classes`` floats. Entry ``k`` is the fraction of
        label-``k`` samples that were predicted correctly, or ``nan`` when
        label ``k`` never occurred in ``test_dl``.
    """
    # Counters are local so that every call starts from zero; sharing them
    # across calls would blend the results of different evaluations.
    correct = [0] * num_classes
    total = [0] * num_classes

    # Run on the device that holds the model's weights; a module without
    # parameters receives the batches unchanged.
    first_param = next(model.parameters(), None)
    target_device = None if first_param is None else first_param.device

    model.eval()
    with torch.no_grad():
        for xb, yb in test_dl:
            if target_device is not None:
                xb, yb = xb.to(target_device), yb.to(target_device)
            output = model(xb)
            _, predicted = torch.max(output, 1)
            # Convert each batch to Python ints once instead of indexing
            # tensors sample by sample.
            for label, guess in zip(yb.tolist(), predicted.tolist()):
                total[label] += 1
                if guess == label:
                    correct[label] += 1

    # A label with no samples has no defined accuracy: report nan, don't crash.
    return [
        hits / seen if seen else float("nan")
        for hits, seen in zip(correct, total)
    ]

In [None]:
# Evaluate once and reuse the result for both the printout and the chart,
# instead of running a second full pass over the test set.
base_accuracies = evaluate_labels(base_model, test_dl)

for i, acc in enumerate(base_accuracies):
    print(f"Label {i}, Accuracy: {acc:.4f}")

plt.bar(range(10), base_accuracies)
plt.xlabel("label")
plt.ylabel("accuracy")
plt.show()

## Arithmetic Operations

In [None]:
# Subtract model weights
def subtract(model1, model2):
    """Subtract the parameters of ``model2`` from ``model1`` in place.

    Parameters are paired in the order returned by ``parameters()``. Only
    parameters are modified; ``model2`` is left unchanged.

    Args:
        model1: Module whose parameters are overwritten with the difference.
        model2: Module providing the values to subtract.

    Raises:
        ValueError: If the two models expose a different number of parameters.
    """
    params1 = list(model1.parameters())
    params2 = list(model2.parameters())
    # zip() would silently ignore the surplus parameters of the larger model.
    if len(params1) != len(params2):
        raise ValueError(
            f"Parameter count mismatch: {len(params1)} vs {len(params2)}"
        )
    # In-place update under no_grad rather than through the discouraged `.data`.
    with torch.no_grad():
        for p1, p2 in zip(params1, params2):
            p1.sub_(p2)

# Overwrites base_model's parameters in place with (base_model - model_01).
# Re-running this cell subtracts model_01 a second time, so reload the
# checkpoints before executing it again.
subtract(base_model, model_01)        

## Re-Evaluation

In [None]:
## Re-Evaluation&Inference
# Evaluate the modified model once and reuse the result for both the printout
# and the chart, instead of running a second full pass over the test set.
subtracted_accuracies = evaluate_labels(base_model, test_dl)

for i, acc in enumerate(subtracted_accuracies):
    print(f"Label {i}, Accuracy: {acc:.4f}")

plt.bar(range(10), subtracted_accuracies)
plt.xlabel("label")
plt.ylabel("accuracy")
plt.show()