## Second Opinion model practice on MNIST

Hello and welcome to my TED talk. This notebook is about an architecture idea I had while learning DL. I have no idea how effective it is, or even whether it is original. It sounds similar to MoE (Mixture of Experts), but while MoE seems to be about splitting distinct tasks between a few models, my idea is to take MNIST and train 10 models that are each responsible for their own digit and nothing else. The hope is that they would be easier to tweak individually and, in theory, improve accuracy. \
I also realised that this can be an Evangelion reference if you squint at it.

# Dataset

But first we need to load the victim of many amateur machine learning students: MNIST, a dataset of tens of thousands of pictures of handwritten digits that we will use to teach our """experts""" how to recognize numbers.

In [None]:
import torch
from torchvision import datasets
import torchvision.transforms as transforms

DATA_WORKERS = 0
BATCH_SIZE = 64

#We will use target_transform soon
def get_loaders(target_transform=None):
    #No data augmentation necessary. It's literally just 28x28 pixels
    transform = transforms.ToTensor()

    train_data = datasets.MNIST(root='data', 
                                train=True,
                                download=True, 
                                transform=transform,
                                target_transform=target_transform)
    #Data loader
    train_loader = torch.utils.data.DataLoader(train_data, 
                                            batch_size=BATCH_SIZE,
                                            num_workers=DATA_WORKERS,
                                            shuffle=True)

    val_data = datasets.MNIST(root='data', 
                                train=False,
                                download=True, 
                                transform=transform,
                                target_transform=target_transform)
    #Data loader
    val_loader = torch.utils.data.DataLoader(val_data, 
                                            batch_size=BATCH_SIZE,
                                            num_workers=DATA_WORKERS)
    
    return train_loader, val_loader

target_transform = transforms.Lambda(lambda y: torch.zeros(
    10, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1))
train_loader, val_loader = get_loaders(target_transform)
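
The `target_transform` above turns each integer label into a one-hot vector of length 10. As a small sanity check, here is a sketch comparing that `scatter_` construction with `torch.nn.functional.one_hot`. The helper names `one_hot_scatter` and `one_hot_builtin` are made up for this illustration and are not used anywhere else in the notebook.

```python
import torch
import torch.nn.functional as F

def one_hot_scatter(y, num_classes=10):
    # Same construction as the Lambda transform: start from zeros, write a 1 at index y
    return torch.zeros(num_classes, dtype=torch.float).scatter_(
        dim=0, index=torch.tensor(y), value=1)

def one_hot_builtin(y, num_classes=10):
    # F.one_hot returns an int64 tensor, so cast to float to match the version above
    return F.one_hot(torch.tensor(y), num_classes=num_classes).float()

example = one_hot_scatter(3)
print(example)
print(torch.equal(example, one_hot_builtin(3)))
```

Both versions should give the same vector. Keep in mind that once labels are one-hot, anything that compares predictions with labels needs to convert them back to class indices first (for example with `argmax`).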

Now let's take a look at what we are dealing with

In [None]:
%matplotlib inline

import matplotlib.pyplot as plt


data_iter = iter(train_loader)
image_batch, labels = next(data_iter)
image_batch = image_batch.numpy()

fig, axes = plt.subplots(figsize=(7,7), nrows=3, ncols=3, sharey=True, sharex=True)
for ax, img in zip(axes.flatten(), image_batch):
    ax.xaxis.set_visible(False)
    ax.yaxis.set_visible(False)
    im = ax.imshow(img.reshape((28,28)), cmap='Greys_r')

If you have spent any amount of time on image classification, these digits had better be burned into your mind.

# The Fun Stuff™

Now we can get to the architecture. For this particular experiment I will be going back to the good ol' days of dense (fully connected) layers. I'm not trying to get state-of-the-art performance here, so it's nice to not have to overthink things.

# The Baseline 

We will begin with creating a regular fully connected classifier for MNIST and see how it performs. We will use this as the baseline on which to judge the second opinion models

In [22]:
from torch import nn
import torch.nn.functional as F

class Solo_Expert(nn.Module):
    def __init__(self, hidden_dim):
        super(Solo_Expert, self).__init__()
        
        # define hidden linear layers
        self.fc1 = nn.Linear(28*28, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim*2)
        self.fc3 = nn.Linear(hidden_dim*2, hidden_dim*4)
        
        # final fully-connected layer
        self.fc4 = nn.Linear(hidden_dim*4, 10)
        
        # dropout layer 
        self.dropout = nn.Dropout(0.3)
        self.flatten = nn.Flatten()

    def forward(self, x):
        x = self.flatten(x)
        
        # all hidden layers
        x = F.leaky_relu(self.fc1(x), 0.2) # (input, negative_slope=0.2)
        x = self.dropout(x)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = self.dropout(x)
        x = F.leaky_relu(self.fc3(x), 0.2)
        x = self.dropout(x)
        # final layer: raw logits, no activation (CrossEntropyLoss applies log-softmax itself)
        out = self.fc4(x)

        return out
    
#Check what device to use: CUDA first, then Apple MPS, then CPU
use_cuda = torch.cuda.is_available()
use_mps = torch.backends.mps.is_available()
device = torch.device("cuda" if use_cuda else "mps" if use_mps else "cpu")
print(f"Device is {device}")

HIDDEN_DIM = 32
solo_model = Solo_Expert(HIDDEN_DIM).to(device)

Device is cpu


In [None]:
import torch.optim as optim

from tqdm import tqdm 
import datetime

EPOCHS = 30
LEARNING_RATE = 0.0001
target = "full"

loss_fn = nn.CrossEntropyLoss() 

def training_loop(model, target, train_loader, val_loader):
    #Each model gets its own optimizer, so whatever model is passed in is the one being updated
    optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=0.9)
    beginning = datetime.datetime.now()
    #Best validation loss so far, kept across epochs
    best_loss = float("inf")

    for epoch in range(1, EPOCHS + 1):
        total_loss = 0.0
        total_val_loss = 0.0
        
        #Train (dropout active)
        model.train()
        for (imgs, labels) in tqdm(train_loader, desc=f"Training [{target}]"):
            imgs = imgs.to(device)
            labels = labels.to(device)

            out = model(imgs)
            loss = loss_fn(out, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        #Validate (dropout off, no gradients needed)
        model.eval()
        with torch.no_grad():
            for (val_imgs, val_labels) in tqdm(val_loader, desc=f"Validation [{target}]"):
                val_imgs = val_imgs.to(device)
                val_labels = val_labels.to(device)

                val_out = model(val_imgs)
                val_loss = loss_fn(val_out, val_labels)

                total_val_loss += val_loss.item()

        epoch_val_loss = total_val_loss / len(val_loader)
        epoch_loss = total_loss / len(train_loader)
            
        #Only overwrite the checkpoint when validation loss improves
        if epoch_val_loss < best_loss:
            best_loss = epoch_val_loss
            torch.save(model.state_dict(), "data/" + f"MNIST_[{beginning}].pth")

        # if epoch == 1 or epoch % 10 == 0:
        now = datetime.datetime.now()
        print(f"{now}\nEpoch {epoch}\ntr_loss {epoch_loss:.5}\nval_loss {epoch_val_loss:.5}\n")

training_loop(solo_model, target, train_loader, val_loader)
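
The "save only when validation loss improves" part of the loop depends on the running best being remembered from one epoch to the next. Here is a tiny sketch of just that bookkeeping, with no model involved. The function name `best_epochs` and the loss values are made up for the example.

```python
def best_epochs(val_losses):
    """Return the epochs (1-based) at which a new best validation loss is reached."""
    best_loss = float("inf")  # initialised once, before the epoch loop
    improved = []
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best_loss:
            best_loss = loss
            improved.append(epoch)  # this is where a checkpoint would be written
    return improved

# Made-up validation losses for five epochs
val_losses = [0.90, 0.70, 0.75, 0.60, 0.65]
print(best_epochs(val_losses))  # [1, 2, 4]
```

If `best_loss` were reset inside the loop instead, every epoch would count as an improvement and the checkpoint would simply hold the last epoch.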

And now to test the accuracy

In [26]:
solo_model.load_state_dict(torch.load(f"data/MNIST_[2023-05-19 20:01:21.259617].pth"))

def validate(model, val_loader):
    model.eval()
    for name, loader in [("val", val_loader)]:
        correct = 0
        total = 0

        with torch.no_grad():
            for imgs, labels in loader:
                imgs = imgs.to(device)
                labels = labels.to(device)
                outputs = model(imgs)
                _, predicted = torch.max(outputs, dim=1)
                #The loaders yield one-hot labels, so turn them back into class indices
                if labels.dim() > 1:
                    labels = labels.argmax(dim=1)
                total += labels.shape[0]
                correct += int((predicted == labels).sum())

        print("Accuracy {}: {:.2f}".format(name , correct / total))

validate(solo_model, val_loader)

An 86-89% accuracy might not be the knife's edge in terms of classification but it's a fair start (for a dense network)
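
For reference, this is how I understand the accuracy calculation: take the index of the largest logit per image and compare it with the true class index. Since the loaders here yield one-hot labels of shape `(batch, 10)` while the predictions have shape `(batch,)`, the two have to be brought to the same shape before comparing. A minimal sketch on a toy batch with 3 classes (the helper `batch_correct` and the numbers are made up):

```python
import torch

def batch_correct(logits, labels):
    """Count correct predictions; labels may be class indices or one-hot rows."""
    predicted = logits.argmax(dim=1)
    if labels.dim() > 1:
        # one-hot rows -> class indices
        labels = labels.argmax(dim=1)
    return int((predicted == labels).sum())

# Toy batch: 3 images, 3 classes
logits = torch.tensor([[0.1, 2.0, -1.0],
                       [1.5, 0.2, 0.3],
                       [0.0, 0.1, 3.0]])
one_hot = torch.tensor([[0., 1., 0.],
                        [0., 0., 1.],
                        [0., 0., 1.]])

print(batch_correct(logits, one_hot))  # 2 (the second image is misclassified)
```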

# Second Opinion

Now that we have seen the performance of the baseline, we can compare it to a small horde of single-minded models. A layer has been removed from these models because trying to classify whether something is 7 or not should not require as many parameters as distinguishing between all 10 numbers. We will start by creating our own transform for labels.

In [None]:
#SLOW ?
class Relabel:
    def __init__(self, target):
        self.target = target
    
    def __call__(self, label):
        #0 means "this is the expert's digit", 1 means "anything else"
        return 0 if label == self.target else 1
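
To make the relabeling concrete, here is a plain-Python sketch of the same idea applied to a handful of labels. The function `relabel_one_vs_rest` and the label list are made up for this illustration. The convention is the one used in this notebook: 0 for the expert's own digit, 1 for everything else.

```python
def relabel_one_vs_rest(label, target):
    # 0 = "this is the expert's digit", 1 = "anything else"
    return 0 if label == target else 1

# Made-up sample of digit labels
labels = [7, 2, 1, 0, 4, 1, 4, 9, 5, 9]
relabeled = [relabel_one_vs_rest(y, target=4) for y in labels]

print(relabeled)  # [1, 1, 1, 1, 0, 1, 0, 1, 1, 1]
print(relabeled.count(0) / len(relabeled))  # 0.2
```

One thing this makes visible: on the full dataset only roughly one label in ten belongs to the expert's digit, so an expert that always answers "anything else" would already score around 90% accuracy. Plain accuracy is therefore a weak yardstick for the individual experts.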

And we will make a slightly lobotomized model for the simpler task (overfitting be damned)

In [None]:
class Expert(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        
        self.fc1 = nn.Linear(28*28, hidden_dim*2)
        self.fc2 = nn.Linear(hidden_dim*2, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, 2)
        
        #Overfitting be damned
        self.dropout = nn.Dropout(0.4)
        #forward() calls self.flatten, so it has to exist
        self.flatten = nn.Flatten()

    def forward(self, x):
        x = self.flatten(x)
        
        # all hidden layers
        x = F.leaky_relu(self.fc1(x), 0.2) # (input, negative_slope=0.2)
        x = self.dropout(x)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = self.dropout(x)
        # final layer with tanh applied
        out = F.tanh(self.fc3(x))

        return out
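
Out of curiosity, here is a quick back-of-the-envelope parameter count for the two architectures, assuming `HIDDEN_DIM = 32` as set earlier. It only uses the layer sizes from the class definitions (weights plus biases of each `nn.Linear`). The helper names `linear_params` and `mlp_params` are made up.

```python
def linear_params(in_features, out_features):
    # weight matrix plus one bias per output unit
    return in_features * out_features + out_features

def mlp_params(sizes):
    # sum over consecutive layer pairs, e.g. [784, 32, 10] -> (784->32) + (32->10)
    return sum(linear_params(a, b) for a, b in zip(sizes, sizes[1:]))

HIDDEN_DIM = 32
solo_sizes = [28 * 28, HIDDEN_DIM, HIDDEN_DIM * 2, HIDDEN_DIM * 4, 10]
expert_sizes = [28 * 28, HIDDEN_DIM * 2, HIDDEN_DIM, 2]

solo_total = mlp_params(solo_sizes)
expert_total = mlp_params(expert_sizes)

print(solo_total)         # 36842
print(expert_total)       # 52386
print(10 * expert_total)  # 523860
```

So although an expert has one layer fewer, its first layer is twice as wide (784 to 64 instead of 784 to 32), and that first layer dominates the count. With these sizes a single expert ends up with more parameters than the whole baseline. Starting the expert at `hidden_dim` instead of `hidden_dim*2` would be one way to actually make it smaller.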

Now let's modify the training for this hive mind

In [None]:
for target in range(0, 10):
    # print(target)
    target_transform = Relabel(target)
    train_loader, val_loader = get_loaders(target_transform)
    hiveling = Expert(HIDDEN_DIM).to(device)

    #Peek at one batch of relabeled targets instead of printing the whole dataset
    _, y = next(iter(train_loader))
    print(y)

    training_loop(model=hiveling,
                target=target,
                train_loader=train_loader,
                val_loader=val_loader)
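
The notebook stops before the ten experts are merged into one prediction, so what follows is only one possible sketch and not a tested recipe. The assumption is that each expert outputs two logits, with index 0 meaning "this is my digit" (the convention used for relabeling here). The function `combine_experts` is a made-up name, and the experts below are untrained stand-ins that exist purely to show the shapes.

```python
import torch
from torch import nn

def combine_experts(experts, imgs):
    """For each image, pick the digit whose expert is most confident it is 'its' digit."""
    scores = []
    with torch.no_grad():
        for expert in experts:
            expert.eval()
            logits = expert(imgs)  # shape (batch, 2)
            # probability of class 0, i.e. "this is my digit"
            scores.append(torch.softmax(logits, dim=1)[:, 0])
    # stack to (batch, 10) and take the most confident expert per image
    return torch.stack(scores, dim=1).argmax(dim=1)

# Untrained stand-in experts, one per digit (assumption: same input/output shapes as Expert)
torch.manual_seed(0)
experts = [nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 2)) for _ in range(10)]
imgs = torch.rand(4, 1, 28, 28)

preds = combine_experts(experts, imgs)
print(preds.shape)  # torch.Size([4])
```

Taking the most confident expert is the classic one-vs-rest way of doing it. Whether the confidences of ten independently trained models are actually comparable to each other is an open question for this experiment.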