# Exercise 3.2

In [9]:
import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import pickle
from collections import OrderedDict
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix, precision_score, recall_score, accuracy_score
from tqdm import tqdm
from sklearn.model_selection import train_test_split

In [None]:


# 3. Define model for case 'a'
case = 'a'
num_classes = 10 
num_epochs = 500

if case == 'a':
    inputs, n_hidden0, n_hidden1, out = 784*3, 64, 16, 10
    ckpt_pth = 'best_model_NN.pth'
    model = nn.Sequential(
        nn.Linear(inputs, n_hidden0, bias=True), 
        nn.Tanh(),
        nn.Linear(n_hidden0, n_hidden1, bias=True),
        nn.Tanh(),
        nn.Linear(n_hidden1, out, bias=True),
        nn.Softmax(dim=1)
    ).to('cuda')
elif case == 'b':
    ckpt_pth = 'best_model_CNN.pth'
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    model = torch.hub.load('pytorch/vision:v0.10.0', 'alexnet', pretrained=True)
    model.classifier[6] = nn.Linear(4096, num_classes)
    model = model.to('cuda')
else:
    raise ValueError('Case choice is invalid')

model.train()

#load data
dev_path = './data/0_development_data.pkl'
test_path = './data/0_test_data.pkl'

with open(dev_path, 'rb') as f:
    devel_data = pickle.load(f)

with open(test_path, 'rb') as f:
    test_data = pickle.load(f)

#combine
combined_imgs = devel_data[0] + test_data[0]
combined_labels = [int(i[0].split('/')[-2]) for i in combined_imgs]

#split the train and test
train_imgs, temp_imgs, train_labels, temp_labels = train_test_split(
    combined_imgs, combined_labels, test_size=0.25, stratify=combined_labels, random_state=42)

#split the train and validation
val_imgs, test_imgs, val_labels, test_labels = train_test_split(
    temp_imgs, temp_labels, test_size=0.4, stratify=temp_labels, random_state=42)

print(f"Train: {len(train_imgs)}, Validation: {len(val_imgs)}, Test: {len(test_imgs)}")


In [34]:

class CustomDataset(Dataset):
    def __init__(self, image_list, labels, transform=None):
        self.image_list = image_list
        self.labels = labels
        self.transform = transform
    
    def __len__(self):
        return len(self.image_list)
    
    def __getitem__(self, idx):
        if case == 'a':
            image = self.image_list[idx].astype(float)
            image /= 255.0
            image -= np.sum(np.sum(image, 0), 0) / (image.shape[0] * image.shape[1])
        elif case == 'b':
            img_tmp = self.image_list[idx]
            image = preprocess(Image.fromarray(img_tmp))
        label = self.labels[idx]
        return image, label


criterion = nn.CrossEntropyLoss()
if case == 'a':
    optimizer = optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
else:
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    
train_array_list = [i[1] for i in train_imgs]
val_array_list = [i[1] for i in val_imgs]
test_array_list = [i[1] for i in test_imgs]

dataset_train = CustomDataset(train_array_list, train_labels, transform=None)
dataset_val = CustomDataset(val_array_list, val_labels, transform=None)
dataset_test = CustomDataset(test_array_list, test_labels, transform=None)

batch_size = 32
dataloader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True)
dataloader_val = DataLoader(dataset_val, batch_size=batch_size, shuffle=True)
dataloader_test = DataLoader(dataset_test, batch_size=batch_size, shuffle=False)

In [38]:
# Training loop
early_stopping_patience = 10

best_val_loss = 10000.0
for epoch in range(num_epochs):
    running_loss, running_val_loss = 0.0, 0.0
    model.train()
    for inputs_, labels_ in tqdm(dataloader_train):
        
        if case == 'a': inputs_ = torch.reshape(inputs_, (inputs_.shape[0], -1))
        inputs_, labels_ = inputs_.to(torch.float).to('cuda'), labels_.to('cuda')
        optimizer.zero_grad()
        outputs = model(inputs_)
        loss = criterion(outputs, labels_)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    
    model.eval()
    with torch.no_grad():
        for inputs_val, labels_val in tqdm(dataloader_val):
            if case == 'a': inputs_val = torch.reshape(inputs_val, (inputs_val.shape[0], -1))
            inputs_val, labels_val = inputs_val.to(torch.float).to('cuda'), labels_val.to('cuda')
            outputs_val = model(inputs_val)
            val_loss = criterion(outputs_val, labels_val)
            running_val_loss += val_loss.item()
            
    epoch_val_loss = running_val_loss/len(dataloader_val)
    if epoch_val_loss < best_val_loss:
        early_stopping_counter = 0
        best_val_loss = float(epoch_val_loss)
        torch.save(model.state_dict(), ckpt_pth)
    else:
        early_stopping_counter += 1
        if early_stopping_counter==early_stopping_patience:
            print('-------- Early Stopping ------------')
            print(f'Epoch {epoch+1}, Train loss: {running_loss/len(dataloader_train)}, Val loss: {running_val_loss/len(dataloader_val)}')
            break
        
    print(f'Epoch {epoch+1}, Train loss: {running_loss/len(dataloader_train)}, Val loss: {epoch_val_loss}')

100%|██████████| 329/329 [00:00<00:00, 482.42it/s]
100%|██████████| 329/329 [00:00<00:00, 751.32it/s]


Epoch 1, Train loss: 1.4833602984987857, Val loss: 1.4919023401469083


100%|██████████| 329/329 [00:00<00:00, 473.16it/s]
100%|██████████| 329/329 [00:00<00:00, 754.34it/s]


Epoch 2, Train loss: 1.4803156834605256, Val loss: 1.4933049356321433


100%|██████████| 329/329 [00:00<00:00, 483.84it/s]
100%|██████████| 329/329 [00:00<00:00, 757.27it/s]


Epoch 3, Train loss: 1.4766990638431445, Val loss: 1.4865819132436735


100%|██████████| 329/329 [00:00<00:00, 482.79it/s]
100%|██████████| 329/329 [00:00<00:00, 743.08it/s]


Epoch 4, Train loss: 1.4747128562724336, Val loss: 1.4862191970587502


100%|██████████| 329/329 [00:00<00:00, 484.50it/s]
100%|██████████| 329/329 [00:00<00:00, 755.96it/s]


Epoch 5, Train loss: 1.4724069246767502, Val loss: 1.4849313870751748


100%|██████████| 329/329 [00:00<00:00, 483.54it/s]
100%|██████████| 329/329 [00:00<00:00, 750.26it/s]


Epoch 6, Train loss: 1.4713787235387552, Val loss: 1.489074838922379


100%|██████████| 329/329 [00:00<00:00, 479.98it/s]
100%|██████████| 329/329 [00:01<00:00, 183.58it/s]


Epoch 7, Train loss: 1.471195677493481, Val loss: 1.4852214692936117


100%|██████████| 329/329 [00:00<00:00, 361.67it/s]
100%|██████████| 329/329 [00:00<00:00, 738.57it/s]


Epoch 8, Train loss: 1.4711558373141072, Val loss: 1.4820141285023791


100%|██████████| 329/329 [00:00<00:00, 462.10it/s]
100%|██████████| 329/329 [00:00<00:00, 737.27it/s]


Epoch 9, Train loss: 1.4698161256349558, Val loss: 1.4883903370077487


100%|██████████| 329/329 [00:00<00:00, 475.86it/s]
100%|██████████| 329/329 [00:00<00:00, 737.32it/s]


Epoch 10, Train loss: 1.468535166380978, Val loss: 1.4815568315221908


100%|██████████| 329/329 [00:00<00:00, 475.28it/s]
100%|██████████| 329/329 [00:00<00:00, 738.54it/s]


Epoch 11, Train loss: 1.4678967390379283, Val loss: 1.480916760612766


100%|██████████| 329/329 [00:00<00:00, 475.79it/s]
100%|██████████| 329/329 [00:00<00:00, 739.73it/s]


Epoch 12, Train loss: 1.4692315033503942, Val loss: 1.4820133545478427


100%|██████████| 329/329 [00:00<00:00, 475.46it/s]
100%|██████████| 329/329 [00:00<00:00, 735.67it/s]


Epoch 13, Train loss: 1.4676182208452544, Val loss: 1.484737014335702


100%|██████████| 329/329 [00:00<00:00, 475.97it/s]
100%|██████████| 329/329 [00:00<00:00, 738.57it/s]


Epoch 14, Train loss: 1.4685658455619697, Val loss: 1.4811377162991324


100%|██████████| 329/329 [00:00<00:00, 475.80it/s]
100%|██████████| 329/329 [00:00<00:00, 739.44it/s]


Epoch 15, Train loss: 1.4676379806002584, Val loss: 1.4818147022311086


100%|██████████| 329/329 [00:00<00:00, 475.93it/s]
100%|██████████| 329/329 [00:00<00:00, 739.40it/s]


Epoch 16, Train loss: 1.4673447148778156, Val loss: 1.4821851941952227


100%|██████████| 329/329 [00:00<00:00, 475.38it/s]
100%|██████████| 329/329 [00:00<00:00, 738.87it/s]


Epoch 17, Train loss: 1.4664753389213585, Val loss: 1.4803088291075455


100%|██████████| 329/329 [00:00<00:00, 475.89it/s]
100%|██████████| 329/329 [00:00<00:00, 730.12it/s]


Epoch 18, Train loss: 1.4669296060289656, Val loss: 1.4800791791144838


100%|██████████| 329/329 [00:00<00:00, 472.28it/s]
100%|██████████| 329/329 [00:00<00:00, 743.65it/s]


Epoch 19, Train loss: 1.4664943877686845, Val loss: 1.4799080565345324


100%|██████████| 329/329 [00:00<00:00, 479.52it/s]
100%|██████████| 329/329 [00:00<00:00, 744.38it/s]


Epoch 20, Train loss: 1.4659537689301743, Val loss: 1.4808744347928868


100%|██████████| 329/329 [00:00<00:00, 478.61it/s]
100%|██████████| 329/329 [00:00<00:00, 744.60it/s]


Epoch 21, Train loss: 1.466612387573103, Val loss: 1.4794182505651086


100%|██████████| 329/329 [00:00<00:00, 477.16it/s]
100%|██████████| 329/329 [00:00<00:00, 743.20it/s]


Epoch 22, Train loss: 1.4666648175578711, Val loss: 1.4783940177558041


100%|██████████| 329/329 [00:00<00:00, 476.84it/s]
100%|██████████| 329/329 [00:00<00:00, 743.23it/s]


Epoch 23, Train loss: 1.4657494394974868, Val loss: 1.478183537268711


100%|██████████| 329/329 [00:00<00:00, 478.36it/s]
100%|██████████| 329/329 [00:00<00:00, 743.06it/s]


Epoch 24, Train loss: 1.465164792573923, Val loss: 1.4788499871285854


100%|██████████| 329/329 [00:00<00:00, 477.42it/s]
100%|██████████| 329/329 [00:00<00:00, 742.09it/s]


Epoch 25, Train loss: 1.4652407937499166, Val loss: 1.4783583470028585


100%|██████████| 329/329 [00:00<00:00, 478.38it/s]
100%|██████████| 329/329 [00:00<00:00, 743.23it/s]


Epoch 26, Train loss: 1.464950050264144, Val loss: 1.4777205439686414


100%|██████████| 329/329 [00:00<00:00, 478.45it/s]
100%|██████████| 329/329 [00:00<00:00, 740.59it/s]


Epoch 27, Train loss: 1.4648097188277085, Val loss: 1.4790284662623536


100%|██████████| 329/329 [00:00<00:00, 478.63it/s]
100%|██████████| 329/329 [00:00<00:00, 743.40it/s]


Epoch 28, Train loss: 1.4663040029241683, Val loss: 1.4809010485385328


100%|██████████| 329/329 [00:00<00:00, 477.96it/s]
100%|██████████| 329/329 [00:00<00:00, 741.88it/s]


Epoch 29, Train loss: 1.4663963147572108, Val loss: 1.4803774936583267


100%|██████████| 329/329 [00:00<00:00, 478.90it/s]
100%|██████████| 329/329 [00:00<00:00, 735.61it/s]


Epoch 30, Train loss: 1.4662104486332113, Val loss: 1.4788206214238082


100%|██████████| 329/329 [00:00<00:00, 479.49it/s]
100%|██████████| 329/329 [00:00<00:00, 742.37it/s]


Epoch 31, Train loss: 1.4655148185857523, Val loss: 1.4802744678450934


100%|██████████| 329/329 [00:00<00:00, 479.97it/s]
100%|██████████| 329/329 [00:00<00:00, 746.14it/s]


Epoch 32, Train loss: 1.465321468002528, Val loss: 1.4796846412960156


100%|██████████| 329/329 [00:01<00:00, 184.40it/s]
100%|██████████| 329/329 [00:00<00:00, 742.30it/s]


Epoch 33, Train loss: 1.4656198292880074, Val loss: 1.4791077676152748


100%|██████████| 329/329 [00:00<00:00, 477.90it/s]
100%|██████████| 329/329 [00:00<00:00, 744.70it/s]


Epoch 34, Train loss: 1.4656779592160396, Val loss: 1.4836738605992048


100%|██████████| 329/329 [00:00<00:00, 477.20it/s]
100%|██████████| 329/329 [00:00<00:00, 742.88it/s]


Epoch 35, Train loss: 1.4660280073305032, Val loss: 1.480588198070468


100%|██████████| 329/329 [00:00<00:00, 478.46it/s]
100%|██████████| 329/329 [00:00<00:00, 739.78it/s]

-------- Early Stopping ------------
Epoch 36, Train loss: 1.464792055561912, Val loss: 1.4787404895915812





In [39]:
# Make predictions on the train data
model.load_state_dict(torch.load(ckpt_pth, weights_only=True))
model.eval()

Sequential(
  (0): Linear(in_features=2352, out_features=64, bias=True)
  (1): Tanh()
  (2): Linear(in_features=64, out_features=16, bias=True)
  (3): Tanh()
  (4): Linear(in_features=16, out_features=10, bias=True)
  (5): Softmax(dim=None)
)

In [40]:
# Inference function
def get_predictions(input_batch, model):

    # move the input and model to GPU for speed if available
    if torch.cuda.is_available():
        input_batch = input_batch.to('cuda')
        model.to('cuda')

    with torch.no_grad():
        output = model(input_batch)

    probabilities = torch.nn.functional.softmax(output, dim=1)
    return probabilities

In [44]:
# Run inference
preds_list = []
batch_size = 64
with torch.no_grad():
    for inputs_test, _ in tqdm(dataloader_test):
        if case == 'a': inputs_test = torch.reshape(inputs_test, (inputs_test.shape[0], -1))
        inputs_test = inputs_test.to(torch.float).to('cuda')
        preds_list.append(get_predictions(inputs_test, model).cpu().numpy())
final_preds = np.argmax(np.reshape(np.vstack(preds_list), (-1,10)),1)

100%|██████████| 657/657 [00:00<00:00, 695.05it/s]


In [45]:
# Generate all interesting metrics
def multiclass_metrics(y_true, y_pred, labels):
    """
    Compute per-class accuracy, sensitivity (recall), specificity, and precision.
    
    y_true, y_pred : array-like of shape (n_samples,)
    labels         : list of class labels, e.g. [0,1,...,9]
    """
    # Compute the full confusion matrix once
    cm = confusion_matrix(y_true, y_pred, labels=labels)
    # cm[i, j] is count of true class i predicted as class j
    
    # Prepare containers
    metrics = {
        "class": [],
        "accuracy": [],
        "sensitivity (recall)": [],
        "specificity": [],
        "precision": []
    }
    
    # Total samples
    total = cm.sum()
    
    for idx, cls in enumerate(labels):
        TP = cm[idx, idx]
        FN = cm[idx, :].sum() - TP
        FP = cm[:, idx].sum() - TP
        TN = total - TP - FP - FN
        
        # Per-class metrics
        acc = (TP + TN) / total
        sens = TP / (TP + FN) if (TP + FN) > 0 else 0.0
        spec = TN / (TN + FP) if (TN + FP) > 0 else 0.0
        prec = TP / (TP + FP) if (TP + FP) > 0 else 0.0
        
        metrics["class"].append(cls)
        metrics["accuracy"].append(acc)
        metrics["sensitivity (recall)"].append(sens)
        metrics["specificity"].append(spec)
        metrics["precision"].append(prec)
    
    return pd.DataFrame(metrics)


In [46]:
#performance
report_df = multiclass_metrics(test_labels, final_preds, np.arange(10).tolist()).set_index('class')
report_df

Unnamed: 0_level_0,accuracy,sensitivity (recall),specificity,precision
class,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
0,0.901062,0.0,0.999366,0.0
1,0.998048,0.988471,0.99925,0.993989
2,0.994287,0.972236,0.996722,0.970377
3,0.995477,0.976562,0.997663,0.979714
4,0.996286,0.994106,0.99652,0.968421
5,0.901633,0.989989,0.892855,0.478604
6,0.996334,0.990817,0.996937,0.972486
7,0.995096,0.970014,0.998032,0.982965
8,0.993334,0.959646,0.996943,0.971116
9,0.995144,0.976122,0.99725,0.975191


In [None]:


#how many train, validation, test coutns there are
print("Train:\n", pd.Series(train_labels).value_counts().sort_index())
print("Validation:\n", pd.Series(val_labels).value_counts().sort_index())
print("Test:\n", pd.Series(test_labels).value_counts().sort_index())

In [None]:
#plotting

train_counts = pd.Series(train_labels).value_counts().sort_index()
val_counts = pd.Series(val_labels).value_counts().sort_index()
test_counts = pd.Series(test_labels).value_counts().sort_index()

#make a dataframe
df = pd.DataFrame({
    'Class': train_counts.index,
    'Train': train_counts.values,
    'Validation': val_counts.values,
    'Test': test_counts.values
})

# Plot
x = np.arange(len(df['Class']))
width = 0.25

plt.figure(figsize=(10, 6))
plt.bar(x - width, df['Train'], width, label='Train')
plt.bar(x, df['Validation'], width, label='Validation')
plt.bar(x + width, df['Test'], width, label='Test')

plt.xlabel('Class')
plt.ylabel('Number of Samples')
plt.title('Class Distribution Across Train, Validation, and Test Sets')
plt.xticks(x, df['Class'])
plt.legend()
plt.grid(axis='y', linestyle='--', alpha=0.5)
plt.tight_layout()

# Save and download
plt.savefig('class_distribution.png', dpi=300)
plt.show()

# Optional: download image from Colab
files.download('class_distribution.png')