In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm

# Device configuration (GPU if available)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Hyperparameters and settings
num_epochs = 30
batch_size = 128
learning_rate = 1e-3
snapshot_interval = 100  # save embedding every 100 iterations

In [None]:
from sklearn.metrics import f1_score, roc_auc_score
from sklearn.preprocessing import label_binarize
import numpy as np
import torch.nn.functional as F

def evaluate(model, loader):
    """
    Return tuple:
        (avg_loss, accuracy, macro_f1, macro_auroc)
    Works for multi-class (10 classes). Uses OVR AUROC.
    """
    model.eval()
    tot, correct, loss_sum = 0, 0, 0.0
    all_labels, all_preds, all_probs = [], [], []

    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            loss_sum += criterion(logits, y).item() * y.size(0)

            probs = F.softmax(logits, dim=1) # (B, 10)
            preds = probs.argmax(1)

            tot     += y.size(0)
            correct += (preds == y).sum().item()

            all_labels.append(y.cpu())
            all_preds.append(preds.cpu())
            all_probs.append(probs.cpu())

    # concat everything
    labels = torch.cat(all_labels).numpy() # (N,)
    preds  = torch.cat(all_preds).numpy() # (N,)
    probs  = torch.cat(all_probs).numpy() # (N, 10)

    avg_loss = loss_sum / tot
    acc = correct / tot
    macro_f1 = f1_score(labels, preds, average='macro')

    # one-hot labels for AUROC
    labels_bin = label_binarize(labels, classes=np.arange(10))
    macro_auc  = roc_auc_score(labels_bin, probs, multi_class='ovr', average='macro')

    return avg_loss, acc, macro_f1, macro_auc

In [3]:
# MNIST data loaders
transform = transforms.Compose([transforms.ToTensor()])
train_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_data  = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_loader  = DataLoader(test_data, batch_size=batch_size, shuffle=False)

In [None]:
# Define the LeNet-style CNN with 2D embedding
class LeNet2D(nn.Module):
    def __init__(self):
        super(LeNet2D, self).__init__()
        # Convolutional feature extractor
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5)   # 28x28 -> 24x24 (6 feature maps)
        self.pool  = nn.MaxPool2d(2)                 # 24x24 -> 12x12 (downsample)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)  # 12x12 -> 8x8 (16 feature maps)
        # After second pool: 8x8 -> 4x4
        # Flattened feature vector length = 16*4*4 = 256
        self.fc1 = nn.Linear(16*4*4, 64)   # a hidden fully-connected layer
        self.fc2 = nn.Linear(64, 2)        # 2D embedding layer
        self.fc3 = nn.Linear(2, 10)        # output layer (10 classes)
    def forward(self, x, return_embedding=False):
        x = torch.relu(self.conv1(x)) # apply Conv1 + ReLU
        x = self.pool(x)              # downsample
        x = torch.relu(self.conv2(x)) # Conv2 + ReLU
        x = self.pool(x)              # downsample again
        x = x.view(x.size(0), -1)     # flatten (B, 256)
        x = torch.relu(self.fc1(x))   # dense + ReLU
        embed = self.fc2(x)           # 2D embedding (no activation to allow negative coords)
        out = self.fc3(embed)         # class scores
        if return_embedding:
            return embed, out
        else:
            return out

In [None]:
class FCNN2D(nn.Module):
    """
    • 28×28 MNIST flattened to 784.
    • Hidden dims chosen so total params ≈19 k (same as LeNet2D).
    """
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28*28, 24)   # 784 -> 24
        self.ln1 = nn.LayerNorm(24)       # normalise
        self.fc2 = nn.Linear(24, 16)      # 24 -> 16
        self.ln2 = nn.LayerNorm(16)
        self.fc_embed = nn.Linear(16, 2)  # 16 -> 2  (latent)
        self.fc_out   = nn.Linear(2, 10)  # 2  -> 10 (logits)

    def forward(self, x, return_embedding=False):
        x = x.view(x.size(0), -1) # flatten to (B,784)
        x = torch.relu(self.ln1(self.fc1(x))) 
        x = torch.relu(self.ln2(self.fc2(x)))
        embed = self.fc_embed(x) # (B,2)
        out   = self.fc_out(embed) # (B,10)
        if return_embedding:
            return embed, out
        return out

In [None]:
# Instantiate models
cnn_model = LeNet2D().to(device)
fcnn_model = FCNN2D().to(device)

# Optimizers and loss function
criterion = nn.CrossEntropyLoss()
optimizer_cnn = torch.optim.Adam(cnn_model.parameters(), lr=learning_rate)
optimizer_fc = torch.optim.Adam(fcnn_model.parameters(), lr=learning_rate)

In [None]:
# Helper function to compute embeddings for all data in a loader
def compute_embeddings(model, data_loader):
    """
    Helper that extracts the 2D latent codes for all samples in data_loader (train or test).
    Returns two NumPy arrays: embeds shape (𝑁,2)(N,2) and labels shape (𝑁,)(N,).
    """
    model.eval()
    embeds_list, labels_list = [], []
    with torch.no_grad(): # Context manager turn off gradient tracking
        for images, labels in data_loader: # per mini-batch
            images = images.to(device) 
            emb, _ = model(images, return_embedding=True) # forward pass to get 2D latent emb
            embeds_list.append(emb.cpu()) # back to cpu to free GPU RAM
            labels_list.append(labels)
    # concat the tensors 
    embeds = torch.cat(embeds_list, dim=0).numpy()   # shape (N, 2)
    labels = torch.cat(labels_list, dim=0).numpy()   # shape (N,)
    return embeds, labels 

In [None]:
# Training loop for CNN
cnn_snapshots = []
global_step = 0
cnn_metrics = []                
model_name = "CNN"
for epoch in range(num_epochs):
    for images, labels in tqdm(train_loader, desc=f"CNN Epoch {epoch+1}", leave=False):
        images, labels = images.to(device), labels.to(device)
        # Forward + backward + optimize
        outputs = cnn_model(images)
        loss = criterion(outputs, labels)
        optimizer_cnn.zero_grad()
        loss.backward()
        optimizer_cnn.step()
        global_step += 1
        # Snapshot embedding every n iterations
        if global_step % snapshot_interval == 0:
            embeds, lbls = compute_embeddings(cnn_model, test_loader)  # using test set for visualization
            cnn_snapshots.append((embeds, lbls))
    # compute epoch-end snapshot as well
    embeds, lbls = compute_embeddings(cnn_model, test_loader)
    cnn_snapshots.append((embeds, lbls))
    tr_loss, tr_acc, tr_f1, tr_auc = evaluate(cnn_model, train_loader)
    te_loss, te_acc, te_f1, te_auc = evaluate(cnn_model, test_loader)

    print(f"[{model_name}] epoch {epoch+1:02d} | "
          f"train loss {tr_loss:.3f} acc {tr_acc*100:.1f}%  ||  "
          f"test loss {te_loss:.3f} acc {te_acc*100:.1f}% "
          f"F1 {te_f1:.3f}  AUROC {te_auc:.3f}")

    cnn_metrics.append((epoch+1, tr_loss, te_loss, tr_acc,  te_acc, te_f1,   te_auc))


                                                               

[CNN] epoch 01 | train loss 0.826 acc 68.5%  ||  test loss 0.803 acc 69.3% F1 0.661  AUROC 0.968


                                                               

[CNN] epoch 02 | train loss 0.498 acc 88.1%  ||  test loss 0.476 acc 88.5% F1 0.885  AUROC 0.989


                                                               

[CNN] epoch 03 | train loss 0.329 acc 92.3%  ||  test loss 0.317 acc 92.9% F1 0.929  AUROC 0.994


                                                               

[CNN] epoch 04 | train loss 0.264 acc 94.0%  ||  test loss 0.259 acc 94.4% F1 0.944  AUROC 0.996


                                                               

[CNN] epoch 05 | train loss 0.229 acc 94.7%  ||  test loss 0.232 acc 94.9% F1 0.949  AUROC 0.996


                                                               

[CNN] epoch 06 | train loss 0.205 acc 95.2%  ||  test loss 0.220 acc 95.0% F1 0.950  AUROC 0.997


                                                               

[CNN] epoch 07 | train loss 0.195 acc 95.2%  ||  test loss 0.211 acc 95.1% F1 0.951  AUROC 0.997


                                                               

[CNN] epoch 08 | train loss 0.165 acc 96.3%  ||  test loss 0.186 acc 96.2% F1 0.962  AUROC 0.998


                                                               

[CNN] epoch 09 | train loss 0.172 acc 96.0%  ||  test loss 0.198 acc 95.5% F1 0.955  AUROC 0.997


                                                                

[CNN] epoch 10 | train loss 0.142 acc 96.5%  ||  test loss 0.177 acc 96.3% F1 0.962  AUROC 0.998


                                                                

[CNN] epoch 11 | train loss 0.133 acc 96.7%  ||  test loss 0.168 acc 96.5% F1 0.964  AUROC 0.998


                                                                

[CNN] epoch 12 | train loss 0.126 acc 96.9%  ||  test loss 0.165 acc 96.6% F1 0.966  AUROC 0.998


                                                                

[CNN] epoch 13 | train loss 0.155 acc 96.6%  ||  test loss 0.206 acc 96.2% F1 0.962  AUROC 0.997


                                                                

[CNN] epoch 14 | train loss 0.111 acc 97.1%  ||  test loss 0.160 acc 96.4% F1 0.964  AUROC 0.998


                                                                

[CNN] epoch 15 | train loss 0.109 acc 97.2%  ||  test loss 0.161 acc 96.5% F1 0.965  AUROC 0.998


                                                                

[CNN] epoch 16 | train loss 0.098 acc 97.4%  ||  test loss 0.156 acc 97.0% F1 0.970  AUROC 0.998


                                                                

[CNN] epoch 17 | train loss 0.089 acc 97.7%  ||  test loss 0.146 acc 96.9% F1 0.969  AUROC 0.998


                                                                

[CNN] epoch 18 | train loss 0.094 acc 97.6%  ||  test loss 0.153 acc 97.2% F1 0.972  AUROC 0.998


                                                                

[CNN] epoch 19 | train loss 0.086 acc 97.7%  ||  test loss 0.133 acc 97.2% F1 0.972  AUROC 0.999


                                                                

[CNN] epoch 20 | train loss 0.080 acc 97.8%  ||  test loss 0.140 acc 97.2% F1 0.971  AUROC 0.999


                                                                

[CNN] epoch 21 | train loss 0.068 acc 98.2%  ||  test loss 0.143 acc 97.5% F1 0.975  AUROC 0.999


                                                                

[CNN] epoch 22 | train loss 0.070 acc 98.1%  ||  test loss 0.150 acc 97.3% F1 0.973  AUROC 0.999


                                                                

[CNN] epoch 23 | train loss 0.080 acc 97.7%  ||  test loss 0.128 acc 97.2% F1 0.971  AUROC 0.999


                                                                

[CNN] epoch 24 | train loss 0.068 acc 98.2%  ||  test loss 0.143 acc 97.5% F1 0.975  AUROC 0.999


                                                                

[CNN] epoch 25 | train loss 0.068 acc 98.1%  ||  test loss 0.139 acc 97.4% F1 0.974  AUROC 0.999


                                                                

[CNN] epoch 26 | train loss 0.063 acc 98.2%  ||  test loss 0.140 acc 97.5% F1 0.975  AUROC 0.999


                                                                

[CNN] epoch 27 | train loss 0.072 acc 98.0%  ||  test loss 0.166 acc 97.1% F1 0.971  AUROC 0.998


                                                                

[CNN] epoch 28 | train loss 0.047 acc 98.7%  ||  test loss 0.132 acc 97.7% F1 0.977  AUROC 0.999


                                                                

[CNN] epoch 29 | train loss 0.054 acc 98.5%  ||  test loss 0.141 acc 97.6% F1 0.976  AUROC 0.999


                                                                

[CNN] epoch 30 | train loss 0.044 acc 98.8%  ||  test loss 0.145 acc 97.7% F1 0.977  AUROC 0.999


In [None]:
# Training loop for FCNN
fcnn_snapshots = []
fcnn_metrics   = []
model_name     = "FCNN2D"
global_step    = 0

for epoch in range(num_epochs):
    for images, labels in tqdm(train_loader,desc=f"{model_name} Epoch {epoch+1}",leave=False):
        images, labels = images.to(device), labels.to(device)

        # forward / backward / update
        outputs = fcnn_model(images)
        loss = criterion(outputs, labels)
        optimizer_fc.zero_grad()
        loss.backward()
        optimizer_fc.step()
        global_step += 1

        # snapshot every N iterations
        if global_step % snapshot_interval == 0:
            embeds, lbls = compute_embeddings(fcnn_model, test_loader)
            fcnn_snapshots.append((embeds, lbls))

    # epoch-end snapshot (optional but nice)
    embeds, lbls = compute_embeddings(fcnn_model, test_loader)
    fcnn_snapshots.append((embeds, lbls))

    # evaluate & log
    tr_loss, tr_acc, tr_f1, tr_auc = evaluate(fcnn_model, train_loader)
    te_loss, te_acc, te_f1, te_auc = evaluate(fcnn_model, test_loader)

    print(f"[{model_name}] epoch {epoch+1:02d} | "
          f"train loss {tr_loss:.3f} acc {tr_acc*100:.1f}%  ||  "
          f"test  loss {te_loss:.3f} acc {te_acc*100:.1f}% "
          f"F1 {te_f1:.3f}  AUROC {te_auc:.3f}")

    fcnn_metrics.append((epoch+1, tr_loss, te_loss,
                         tr_acc,  te_acc,
                         te_f1,   te_auc))

                                                                  

[FCNN2D] epoch 01 | train loss 0.830 acc 77.1%  ||  test  loss 0.817 acc 77.1% F1 0.729  AUROC 0.981


                                                                  

[FCNN2D] epoch 02 | train loss 0.529 acc 89.3%  ||  test  loss 0.532 acc 89.4% F1 0.892  AUROC 0.988


                                                                  

[FCNN2D] epoch 03 | train loss 0.438 acc 90.7%  ||  test  loss 0.466 acc 90.2% F1 0.901  AUROC 0.989


                                                                  

[FCNN2D] epoch 04 | train loss 0.354 acc 92.6%  ||  test  loss 0.414 acc 91.8% F1 0.917  AUROC 0.990


                                                                  

[FCNN2D] epoch 05 | train loss 0.303 acc 93.3%  ||  test  loss 0.353 acc 92.8% F1 0.927  AUROC 0.992


                                                                  

[FCNN2D] epoch 06 | train loss 0.259 acc 94.2%  ||  test  loss 0.333 acc 93.1% F1 0.930  AUROC 0.993


                                                                 

[FCNN2D] epoch 07 | train loss 0.243 acc 94.4%  ||  test  loss 0.314 acc 93.6% F1 0.935  AUROC 0.993


                                                                  

[FCNN2D] epoch 08 | train loss 0.239 acc 94.5%  ||  test  loss 0.330 acc 93.5% F1 0.935  AUROC 0.993


                                                                 

[FCNN2D] epoch 09 | train loss 0.220 acc 94.7%  ||  test  loss 0.317 acc 93.6% F1 0.935  AUROC 0.993


                                                                   

[FCNN2D] epoch 10 | train loss 0.219 acc 94.7%  ||  test  loss 0.320 acc 93.5% F1 0.934  AUROC 0.993


                                                                   

[FCNN2D] epoch 11 | train loss 0.205 acc 94.9%  ||  test  loss 0.310 acc 93.5% F1 0.935  AUROC 0.993


                                                                   

[FCNN2D] epoch 12 | train loss 0.208 acc 94.8%  ||  test  loss 0.324 acc 93.5% F1 0.935  AUROC 0.993


                                                                   

[FCNN2D] epoch 13 | train loss 0.223 acc 94.5%  ||  test  loss 0.337 acc 93.3% F1 0.933  AUROC 0.993


                                                                   

[FCNN2D] epoch 14 | train loss 0.213 acc 94.5%  ||  test  loss 0.328 acc 92.9% F1 0.929  AUROC 0.993


                                                                   

[FCNN2D] epoch 15 | train loss 0.172 acc 95.7%  ||  test  loss 0.319 acc 93.8% F1 0.938  AUROC 0.994


                                                                   

[FCNN2D] epoch 16 | train loss 0.165 acc 95.8%  ||  test  loss 0.304 acc 94.0% F1 0.940  AUROC 0.994


                                                                   

[FCNN2D] epoch 17 | train loss 0.175 acc 95.4%  ||  test  loss 0.308 acc 93.7% F1 0.936  AUROC 0.994


                                                                   

[FCNN2D] epoch 18 | train loss 0.163 acc 95.7%  ||  test  loss 0.296 acc 93.9% F1 0.938  AUROC 0.994


                                                                   

[FCNN2D] epoch 19 | train loss 0.154 acc 95.9%  ||  test  loss 0.323 acc 93.8% F1 0.938  AUROC 0.994


                                                                   

[FCNN2D] epoch 20 | train loss 0.167 acc 95.6%  ||  test  loss 0.343 acc 93.5% F1 0.935  AUROC 0.993


                                                                   

[FCNN2D] epoch 21 | train loss 0.144 acc 96.2%  ||  test  loss 0.326 acc 93.9% F1 0.939  AUROC 0.993


                                                                   

[FCNN2D] epoch 22 | train loss 0.168 acc 95.3%  ||  test  loss 0.342 acc 93.2% F1 0.932  AUROC 0.993


                                                                   

[FCNN2D] epoch 23 | train loss 0.137 acc 96.4%  ||  test  loss 0.305 acc 94.3% F1 0.942  AUROC 0.994


                                                                   

[FCNN2D] epoch 24 | train loss 0.141 acc 96.3%  ||  test  loss 0.349 acc 93.9% F1 0.938  AUROC 0.993


                                                                   

[FCNN2D] epoch 25 | train loss 0.130 acc 96.6%  ||  test  loss 0.325 acc 94.3% F1 0.943  AUROC 0.994


                                                                   

[FCNN2D] epoch 26 | train loss 0.141 acc 96.3%  ||  test  loss 0.360 acc 94.0% F1 0.939  AUROC 0.993


                                                                   

[FCNN2D] epoch 27 | train loss 0.137 acc 96.4%  ||  test  loss 0.343 acc 93.8% F1 0.937  AUROC 0.993


                                                                   

[FCNN2D] epoch 28 | train loss 0.135 acc 96.2%  ||  test  loss 0.331 acc 93.9% F1 0.938  AUROC 0.994


                                                                   

[FCNN2D] epoch 29 | train loss 0.122 acc 96.7%  ||  test  loss 0.320 acc 94.2% F1 0.941  AUROC 0.994


                                                                   

[FCNN2D] epoch 30 | train loss 0.121 acc 96.7%  ||  test  loss 0.325 acc 94.0% F1 0.939  AUROC 0.994


In [14]:
import matplotlib.pyplot as plt
import numpy as np
import imageio

def create_embedding_video(snapshot_data, output_filename):
    """snapshot_data is a list of (embeds, labels) tuples."""
    images = []
    # Define colors for 10 classes
    cmap = plt.get_cmap('tab10')  # tab10 has 10 distinct colors
    for i, (embeds, labels) in enumerate(snapshot_data):
        plt.figure(figsize=(16,9))
        plt.xlim(-200, 200)
        plt.ylim(-200, 200)
        plt.gca().set_facecolor('black')
        plt.title(f"Iteration {i * snapshot_interval}")  # simple title to indicate training step
        # Plot each class in its color
        for digit in range(10):
            pts = embeds[labels == digit]
            if len(pts) > 0:
                plt.scatter(pts[:,0], pts[:,1], s=5, color=cmap(digit), label=str(digit))
        plt.legend(loc='upper right', fontsize='x-small')
        # Render plot to array and append
        plt.tight_layout()
        canvas = plt.gca().figure.canvas
        canvas.draw()
        frame = np.frombuffer(canvas.tostring_rgb(), dtype='uint8')
        frame = frame.reshape(canvas.get_width_height()[::-1] + (3,))
        images.append(frame)
        plt.close()
    # Save frames as video
    imageio.mimsave(output_filename, images, fps=10, quality=8)
# Create videos for CNN 
create_embedding_video(cnn_snapshots, "cnn_embedding_evolution.mp4")

  frame = np.frombuffer(canvas.tostring_rgb(), dtype='uint8')


In [None]:
import matplotlib.pyplot as plt
import numpy as np
import imageio
def create_embedding_video(snapshot_data, output_filename):
    """snapshot_data is a list of (embeds, labels) tuples."""
    images = []
    # Define colors for 10 classes
    cmap = plt.get_cmap('tab10')  # tab10 has 10 distinct colors
    for i, (embeds, labels) in enumerate(snapshot_data):
        plt.figure(figsize=(16,9))
        plt.xlim(-50, 50)
        plt.ylim(-50, 50)
        plt.gca().set_facecolor('black')
        plt.title(f"Iteration {i * snapshot_interval}")  # simple title to indicate training step
        # Plot each class in its color
        for digit in range(10):
            pts = embeds[labels == digit]
            if len(pts) > 0:
                plt.scatter(pts[:,0], pts[:,1], s=5, color=cmap(digit), label=str(digit))
        plt.legend(loc='upper right', fontsize='x-small')
        # Render plot to array and append
        plt.tight_layout()
        canvas = plt.gca().figure.canvas
        canvas.draw()
        frame = np.frombuffer(canvas.tostring_rgb(), dtype='uint8')
        frame = frame.reshape(canvas.get_width_height()[::-1] + (3,))
        images.append(frame)
        plt.close()
    # Save frames as video
    imageio.mimsave(output_filename, images, fps=10, quality=8)

In [41]:
create_embedding_video(fcnn_snapshots , "FCNN_embedding_evolution1.mp4")

  frame = np.frombuffer(canvas.tostring_rgb(), dtype='uint8')


In [None]:
import pandas as pd
df_cnn  = pd.DataFrame(cnn_metrics,
                       columns=["epoch","train_loss","test_loss",
                                "train_acc","test_acc",
                                "test_F1","test_AUROC"])
# df_trans = pd.DataFrame(trans_metrics, columns=df_cnn.columns)

# quick look
print(f"CNN: {df_cnn.shape} rows")
print(df_cnn.tail())
# print(f"Transformer: {df_trans.shape} rows")
# print(df_trans.tail())
df_fcnn = pd.DataFrame(fcnn_metrics, columns=df_cnn.columns)
# quick look
print(f"FCNN: {df_fcnn.shape} rows")
print(df_fcnn.tail())

CNN: (30, 7) rows
    epoch  train_loss  test_loss  train_acc  test_acc   test_F1  test_AUROC
25     26    0.063267   0.140310   0.982217    0.9752  0.975038    0.998679
26     27    0.072480   0.166481   0.980283    0.9714  0.971310    0.998467
27     28    0.046527   0.132425   0.986600    0.9771  0.977060    0.998821
28     29    0.053943   0.140672   0.984817    0.9758  0.975781    0.998789
29     30    0.044175   0.145351   0.987983    0.9773  0.977121    0.998788
Transformer: (30, 7) rows
    epoch  train_loss  test_loss  train_acc  test_acc   test_F1  test_AUROC
25     26    0.191636   0.292215   0.954000    0.9433  0.943458    0.994094
26     27    0.207650   0.286001   0.947250    0.9411  0.941530    0.994038
27     28    0.164413   0.264923   0.958867    0.9456  0.945535    0.994730
28     29    0.157626   0.252403   0.961367    0.9514  0.951172    0.995220
29     30    0.171211   0.268318   0.957267    0.9453  0.945462    0.994534
FCNN: (30, 7) rows
    epoch  train_loss  te

In [None]:
import pandas as pd, numpy as np, matplotlib.pyplot as plt, seaborn as sns
from sklearn.metrics import roc_curve, auc, confusion_matrix
from matplotlib import gridspec

sns.set(style="white", font_scale=1.0) # clean aesthetics
cmap10 = plt.get_cmap('tab10', 10)

# pull last row (epoch, train_loss, test_loss, train_acc, test_acc, F1, AUROC)
cnn_last   = cnn_metrics[-1]

summary_df = pd.DataFrame([
    ["CNN"        , *cnn_last[1:]],      # skip epoch col
    ["Transformer", *trans_last[1:]]
], columns=["model", "train_loss", "test_loss",
            "train_acc", "test_acc",
            "test_F1", "test_AUROC"]).round(3)

print("\n=== FINAL EPOCH METRICS ===")
display(summary_df)


=== FINAL EPOCH METRICS ===


Unnamed: 0,model,train_loss,test_loss,train_acc,test_acc,test_F1,test_AUROC
0,CNN,0.044,0.145,0.988,0.977,0.977,0.999
1,Transformer,0.171,0.268,0.957,0.945,0.945,0.995


  plt.tight_layout()


In [None]:
#  The final 2D snapshots 
# assume last element of snapshots lists
emb_cnn, lbl_cnn = cnn_snapshots[-1]
emb_fcnn, lbl_fcnn = fcnn_snapshots[-1]

fig, axes = plt.subplots(1,2, figsize=(8,4), dpi=200)
for ax, (emb,lbl,title) in zip(axes,
        [(emb_cnn,lbl_cnn,"CNN final"), (emb_fcnn,lbl_fcnn,"FCNN final")]):
    ax.scatter(emb[:,0], emb[:,1], c=lbl, cmap='tab10', s=4, alpha=0.7)
    ax.set_xlim(-50,50); ax.set_ylim(-50,50)
    ax.set_facecolor('black'); ax.set_title(title)
    ax.set_xticks([]); ax.set_yticks([])
fig.suptitle("Final 2-D embeddings (test set)")
plt.tight_layout()
fig.savefig("final_embeddings_compare1.png", dpi=300, facecolor='black')
plt.close(fig)

print("Saved:")
print("  summary_metrics.csv")
print("  cm_CNN.png  cm_FCNN1.png")
print("  final_embeddings_compare1.png")


Saved:
  summary_metrics.csv
  cm_CNN.png  cm_FCNN1.png
  final_embeddings_compare1.png


In [None]:
def get_out_weights_bias(model):
    """
    Return (W, b) from the linear layer that maps 2D -> 10 logits.
    Tries common attribute names, otherwise grabs the last nn.Linear.
    """
    for cand in ["fc_out", "fc3", "output", "classifier"]:
        layer = getattr(model, cand, None) 
        if isinstance(layer, nn.Linear) and layer.out_features == 10:
            return layer.weight.detach().cpu().numpy(), layer.bias.detach().cpu().numpy()

    # fallback: scan modules in reverse for a suitable Linear
    for m in reversed(list(model.modules())):
        if isinstance(m, nn.Linear) and m.out_features == 10:
            return m.weight.detach().cpu().numpy(), m.bias.detach().cpu().numpy()
    raise ValueError("Couldn't locate 2→10 output linear layer")

def plot_embedding_with_regions(model, embeds, labels,
                                title="latent space", save="emb_regions.png",
                                xy_lim=30, res=200):
    W, b = get_out_weights_bias(model) 

    lin = np.linspace(-xy_lim, xy_lim, res) 
    xx, yy = np.meshgrid(lin, lin)
    grid   = np.stack([xx.ravel(), yy.ravel()], axis=1)
    logits = grid @ W.T + b
    pred   = logits.argmax(1).reshape(res, res)

    cmap_bg = plt.cm.get_cmap("tab10", 10)
    fig, ax = plt.subplots(figsize=(5,5), dpi=200)
    ax.imshow(pred, extent=[-xy_lim, xy_lim, -xy_lim, xy_lim],
              origin='lower', cmap=cmap_bg, alpha=0.15)
    sc = ax.scatter(embeds[:,0], embeds[:,1],
                    c=labels, cmap='tab10', s=6, edgecolors='k', linewidths=0.15)
    ax.set_xlim(-xy_lim, xy_lim); ax.set_ylim(-xy_lim, xy_lim)
    ax.set_facecolor('black'); ax.set_title(title)
    ax.set_xticks([]); ax.set_yticks([])
    fig.tight_layout(); fig.savefig(save, dpi=300, facecolor='black')
    plt.close(fig); print("saved", save)


In [44]:
emb_fcnn, lbl_fcnn = fcnn_snapshots[-1]
plot_embedding_with_regions(fcnn_model, emb_fcnn, lbl_fcnn,
                            title="FCNN – decision regions",
                            save="fcnn_regions_zoomed.png")

  cmap_bg = plt.cm.get_cmap("tab10", 10)


saved fcnn_regions_zoomed.png


In [23]:
# final snapshot data:
emb_cnn , lbl_cnn  = cnn_snapshots[-1]

plot_embedding_with_regions(cnn_model , emb_cnn , lbl_cnn ,
                            title="CNN – decision regions",
                            save="cnn_regions.png")

  cmap_bg = plt.cm.get_cmap("tab10", 10)


saved cnn_regions.png
