# Training
This file trains a model that is used as a test model for my metric.

This code was heavily inspired by the [Official PyTorch MNIST CNN example](https://github.com/pytorch/examples/tree/main/mnist)

By running this file, you should get a file called `mnist_cnn.pt`, which is a torch.jit model, that contains the entire model to be loaded and tested with the metric.\
This file has to be copied into `./tests/data/mnist_cnn.pt` in order to be used by the project

In [1]:
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np

# File the trained model is written to (via torch.jit.save) when the user
# chooses to save it at the end of the training cell.
MODEL_PATH = "mnist_cnn.pt"

def exit():
    """
    Abort the current cell by raising a ``StopExecution`` exception.

    The exception type reports an empty traceback through its
    ``_render_traceback_`` method, which is presumably what keeps the notebook
    from printing a stack trace when a cell is stopped on purpose
    (NOTE(review): relies on the front end honouring that hook -- confirm).

    :raises StopExecution: always.
    """
    class StopExecution(Exception):
        """Exception whose rendered traceback is empty."""

        def _render_traceback_(self):
            # No lines to render -> nothing is shown for this exception.
            return list()

    raise StopExecution()

## Model Architecture - Simple CNN model

In [2]:
from model import Model
from model import get_mnist_dataset_loaders, load_cnn_model

In [3]:
def calculate_metrics(model: nn.Module, test_loader: DataLoader, num_classes: int = 10) -> tuple[np.ndarray, float]:
    """
    Evaluate ``model`` on ``test_loader`` and build a confusion matrix.

    Side effects: the model is moved to CUDA when available (CPU otherwise)
    and switched to evaluation mode; it is left in that state on return, so a
    caller that keeps training afterwards must call ``model.train()`` itself.

    :param model: module whose output holds one score per class along dim 1.
    :param test_loader: iterable yielding ``(images_batch, labels_batch)`` pairs,
        with one integer class label per sample.
    :param num_classes: number of classes, i.e. the side length of the
        confusion matrix (defaults to 10).
    :return confusion_matrix, accuracy: ``confusion_matrix[true][pred]`` holds
        sample counts; ``accuracy`` is a percentage in [0, 100], or ``nan``
        when the loader yielded no samples.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    confusion_matrix = np.zeros((num_classes, num_classes), dtype=int)
    with torch.no_grad():
        for images_batch, labels_batch in test_loader:
            predicted = model(images_batch.to(device)).argmax(dim=1)
            # One device->host transfer per batch instead of one .item() call
            # (and one synchronisation) per sample.
            true_idx = labels_batch.long().cpu().numpy().reshape(-1)
            pred_idx = predicted.cpu().numpy().reshape(-1)
            # np.add.at accumulates correctly when the same (true, pred) pair
            # occurs several times in a batch; plain fancy-index += would not.
            np.add.at(confusion_matrix, (true_idx, pred_idx), 1)

    # Every sample lands in exactly one cell, so the matrix sum is the sample count.
    total = int(confusion_matrix.sum())
    if total == 0:
        # Accuracy is undefined without samples; report nan without the
        # divide-by-zero warning.
        return confusion_matrix, float("nan")

    accuracy = 100 * float(np.trace(confusion_matrix)) / total
    return confusion_matrix, accuracy

In [4]:
def train_new_model(train_loader, test_loader, num_epochs=30) -> Model:
    """
    Train a freshly constructed ``Model`` with Adam (lr=0.001) and cross-entropy loss.

    After every epoch the summed batch loss and the accuracy on ``test_loader``
    are printed. Pressing Ctrl+C stops training early and returns the model in
    whatever state it has reached.

    :param train_loader: iterable yielding ``(images_batch, labels_batch)`` training pairs.
    :param test_loader: loader passed to ``calculate_metrics`` after each epoch.
    :param num_epochs: number of passes over ``train_loader``.
    :return: the trained (or partially trained, if interrupted) model, placed
        on CUDA when available and on CPU otherwise.
    """
    model = Model()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    try:
        for epoch in range(1, num_epochs + 1):
            # calculate_metrics() switches the model to eval mode at the end of
            # each epoch; without re-enabling training mode here, every epoch
            # after the first would run with train-only behaviour turned off
            # (e.g. dropout / batch-norm updates, if Model uses any).
            model.train()
            epoch_loss = 0
            for images_batch, labels_batch in tqdm(train_loader, desc=f"Epoch {epoch}/{num_epochs}"):
                images_batch, labels_batch = images_batch.to(device), labels_batch.to(device)

                pred = model(images_batch)
                loss = loss_fn(pred, labels_batch)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                # Sum of per-batch losses (not averaged over the epoch).
                epoch_loss += loss.item()

            print("Epoch loss:", epoch_loss)
            _, accuracy = calculate_metrics(model, test_loader)
            print(f"Accuracy after epoch {epoch}: {accuracy:.2f}%")
    except KeyboardInterrupt:
        print("Training interrupted, returning the current model state")
    return model


In [5]:
# Build the data loaders and try to obtain an already trained model;
# load_cnn_model() yields None when there is nothing to reuse.
train_loader, test_loader = get_mnist_dataset_loaders()
model: Model | None = load_cnn_model()
if model is None:
    print("No model found")
else:
    print("Model already trained on the default path")
    _, accuracy = calculate_metrics(model, test_loader)
    print(f"Previous model accuracy: {accuracy:.2f}%")
    print("Train new model? (y/n): ", end="")
    choice = input().strip().lower()
    # Only an explicit "n" stops the cell; any other answer retrains.
    if choice == 'n':
        exit()

print("Proceeding to training a new model. To interrupt training and potentially save the model weights, press Ctrl+C.")
model = train_new_model(train_loader, test_loader, num_epochs=15)

# Report the final test accuracy, then let the user decide whether to keep the model.
_, accuracy = calculate_metrics(model, test_loader)
print(f"\nModel accuracy on test dataset: {accuracy:.2f}%")
print("Save model? (y/n): ", end="")
choice = input().strip().lower()
if choice == 'y':
    # Script and save from the CPU copy of the model.
    model.cpu()
    model_scripted = torch.jit.script(model)
    torch.jit.save(model_scripted, MODEL_PATH)
    print(f"Model saved to {MODEL_PATH}")

Error loading model from mnist_cnn.pt: The provided filename mnist_cnn.pt does not exist
No model found
Proceeding to training a new model. To interrupt training and potentially save the model weights, press Ctrl+C.


Epoch 1/15: 100%|██████████| 938/938 [00:08<00:00, 105.99it/s]


Epoch loss: 182.07243093289435
Accuracy after epoch 1: 98.43%


Epoch 2/15: 100%|██████████| 938/938 [00:08<00:00, 109.07it/s]


Epoch loss: 34.91340115336061
Accuracy after epoch 2: 98.99%


Epoch 3/15: 100%|██████████| 938/938 [00:08<00:00, 110.54it/s]


Epoch loss: 21.02453301843343
Accuracy after epoch 3: 98.84%


Epoch 4/15: 100%|██████████| 938/938 [00:08<00:00, 110.58it/s]


Epoch loss: 16.039115211915487
Accuracy after epoch 4: 98.94%


Epoch 5/15: 100%|██████████| 938/938 [00:08<00:00, 110.95it/s]


Epoch loss: 11.270184685272397
Accuracy after epoch 5: 98.80%


Epoch 6/15: 100%|██████████| 938/938 [00:08<00:00, 110.03it/s]


Epoch loss: 9.827236655492015
Accuracy after epoch 6: 98.98%


Epoch 7/15: 100%|██████████| 938/938 [00:08<00:00, 109.92it/s]


Epoch loss: 7.261012126436981
Accuracy after epoch 7: 99.03%


Epoch 8/15: 100%|██████████| 938/938 [00:08<00:00, 110.37it/s]


Epoch loss: 5.351004158766045
Accuracy after epoch 8: 98.94%


Epoch 9/15: 100%|██████████| 938/938 [00:08<00:00, 110.62it/s]


Epoch loss: 6.326982337201798
Accuracy after epoch 9: 98.98%


Epoch 10/15: 100%|██████████| 938/938 [00:08<00:00, 111.12it/s]


Epoch loss: 4.98749559967554
Accuracy after epoch 10: 98.96%


Epoch 11/15: 100%|██████████| 938/938 [00:08<00:00, 110.50it/s]


Epoch loss: 4.527220163001328
Accuracy after epoch 11: 98.95%


Epoch 12/15: 100%|██████████| 938/938 [00:08<00:00, 110.27it/s]


Epoch loss: 3.199637951901776
Accuracy after epoch 12: 98.92%


Epoch 13/15: 100%|██████████| 938/938 [00:08<00:00, 110.38it/s]


Epoch loss: 4.473641462640632
Accuracy after epoch 13: 98.86%


Epoch 14/15: 100%|██████████| 938/938 [00:08<00:00, 109.57it/s]


Epoch loss: 4.265924488226407
Accuracy after epoch 14: 99.11%


Epoch 15/15: 100%|██████████| 938/938 [00:08<00:00, 110.94it/s]


Epoch loss: 3.4115158143936277
Accuracy after epoch 15: 98.95%

Model accuracy on test dataset: 98.95%
Save model? (y/n): Model saved to mnist_cnn.pt


In [6]:
# Calculate the Lipschitz metric on one batch of the test set
from Lipschitz import measure

# Compare against None explicitly (as the training cell does): the truth value
# of a module falls back to __len__ when it defines one, so `if model:` could
# wrongly treat a loaded, container-like model as "not loaded".
if model is not None:
    model.to(torch.device('cpu'))
    model.eval()
    # Only the first batch yielded by the test loader is measured.
    images, labels = next(iter(test_loader))
    images, labels = images.cpu(), labels.cpu()
    scores = measure(model, images, labels)
    print("Local Lipschitz Estimates:", scores)
else:
    print("Model is not loaded, please train a model or make sure the mnist_cnn.pt file exists")

  from .autonotebook import tqdm as notebook_tqdm


Local Lipschitz Estimates: [2.1467957496643066]
Local Lipschitz Estimates: [{'name': 'local_lipschitz_estimate', 'score': 2.1467957496643066, 'time': datetime.datetime(2025, 11, 25, 23, 55, 9, 594359)}]
