In [1]:
from torchcfm.models.mnist_classifier.lenet import LeNet5
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.datasets.mnist import MNIST
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import numpy as np

# Load the pretrained MNIST classifier

In [3]:
def load_mnist_classifier(
    device: torch.device | None = None,
    weights_path: str = '../../torchcfm/models/mnist_classifier/weights/lenet_epoch=12_test_acc=0.991.pth',
):
    """Load the pretrained LeNet5 MNIST classifier in eval mode.

    Args:
        device: device to place the network on; defaults to
            ``torch.get_default_device()``.
        weights_path: path to the saved ``state_dict`` (relative to the
            notebook's working directory).

    Returns:
        A ``LeNet5`` with the pretrained weights loaded, in eval mode, on ``device``.
    """
    if device is None:
        device = torch.get_default_device()
    net = LeNet5().eval()
    # map_location makes the requested device apply to loading too: without it a
    # checkpoint saved on GPU cannot be loaded on a CPU-only machine.
    # weights_only avoids unpickling arbitrary objects; a state_dict only holds tensors.
    state_dict = torch.load(weights_path, map_location=device, weights_only=True)
    net.load_state_dict(state_dict)
    return net.to(device)


# Pretrained classifier used below to score the MNIST test set.
net = load_mnist_classifier(device=None)

# Load the MNIST test set

In [4]:
data_root = '../data/'
# Digits are resized to 32x32, presumably the input size the LeNet5 classifier
# expects (TODO confirm against the model definition).
data_test = MNIST(
    data_root,
    train=False,
    download=True,
    transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
    ]),
)
# One batch covering the whole split, so the notebook does not depend on the split
# size being exactly 10000. shuffle=False keeps the dataset's original order.
data_loader = DataLoader(data_test, batch_size=len(data_test), shuffle=False)
(ims, labs) = next(iter(data_loader))  # ims: [N, C, 32, 32], labs: [N]

In [11]:
# Sanity check: smallest pixel value in the first test image.
torch.min(ims[0])

tensor(0.)

# Compute test accuracy (overall and per class)

In [5]:
@torch.inference_mode()
def mnist_acc(net, ims, labs):
    """Overall top-1 accuracy of a classifier.

    Args:
        net: classifier; ``net(ims)`` returns scores with classes on dim 1.
        ims: input batch, passed to ``net`` unchanged.
        labs: ground-truth class indices, one per row of ``net(ims)``.

    Returns:
        0-d float tensor: fraction of rows whose arg-max score equals the label.
    """
    hits = net(ims).argmax(dim=1).eq(labs)
    return hits.sum().float() / len(labs)


@torch.inference_mode()
def mnist_acc_per_class(net, ims, labs, num_classes=10):
    """
    Per-class top-1 accuracy of a classifier.

    Args:
        net: classifier; ``net(ims)`` returns scores with classes on dim 1
        ims: input images [N,C,H,W]
        labs: ground truth labels [N]
        num_classes: number of classes to report (default=10 for MNIST);
            labels outside ``range(num_classes)`` are not reported
    Returns:
        dict mapping every class index in ``range(num_classes)`` to its
        accuracy in [0, 1] as a Python float, or NaN if ``labs`` has no
        sample of that class
    """
    predicted = torch.argmax(net(ims), dim=1)

    def _class_accuracy(cls):
        in_class = labs == cls
        count = in_class.sum()
        if count == 0:
            return float("nan")  # no samples of this class
        hits = (predicted[in_class] == labs[in_class]).sum().float()
        return (hits / count).item()

    return {cls: _class_accuracy(cls) for cls in range(num_classes)}

    

overall = mnist_acc(net, ims, labs)
print(f"Total acc: {overall:.4f}")
per_class = mnist_acc_per_class(net, ims, labs)
for cls_idx in per_class:
    print(f"Class {cls_idx} acc: {per_class[cls_idx]:.4f}")


Total acc: 0.9907
Class 0 acc: 0.9969
Class 1 acc: 0.9956
Class 2 acc: 0.9893
Class 3 acc: 0.9931
Class 4 acc: 0.9888
Class 5 acc: 0.9865
Class 6 acc: 0.9896
Class 7 acc: 0.9922
Class 8 acc: 0.9846
Class 9 acc: 0.9891


# Generate digits by integrating a class-conditional model with an ODE solver

In [None]:
from torchdiffeq import odeint


@torch.inference_mode()
def generate_mnist_testset(model, n: int = 1000):
    """Sample a class-balanced set of digits from a class-conditional model.

    Starts from standard Gaussian noise at t=0 and integrates the model's
    vector field to t=1 with the adaptive ``dopri5`` solver (atol=rtol=1e-4).

    Args:
        model: conditional vector-field model, called as
            ``model.forward(t, x, y)`` where ``y`` holds one class label per
            sample in ``x``.
        n: number of samples per class (10 classes, so ``10 * n`` images total).

    Returns:
        Tensor of shape ``(10 * n, 1, 28, 28)``, the state at t=1. Sample ``i``
        is conditioned on class ``i % 10`` (labels are ``torch.arange(10).repeat(n)``).
    """
    generated_class_list = torch.arange(10).repeat(n)
    # Draw exactly one noise sample per label so x and y share the batch dimension.
    x0 = torch.randn(len(generated_class_list), 1, 28, 28)
    t_span = torch.linspace(0, 1, 2)
    trajectory = odeint(
        lambda t, x: model.forward(t, x, generated_class_list),
        x0,
        t_span,
        atol=1e-4,
        rtol=1e-4,
        method="dopri5",
    )
    # NOTE(review): the classifier in this notebook is evaluated on test images
    # resized to 32x32 (via ToTensor). These samples are 28x28, so they presumably
    # need resizing/padding, and possibly value-range matching, before they are
    # scored with the same classifier. TODO confirm.
    return trajectory[-1]

# Summary: all helper functions collected in one cell

In [None]:
def load_mnist_classifier(
    device: torch.device | None = None,
    weights_path: str = '../../torchcfm/models/mnist_classifier/weights/lenet_epoch=12_test_acc=0.991.pth',
):
    """Load the pretrained LeNet5 MNIST classifier in eval mode.

    Args:
        device: device to place the network on; defaults to
            ``torch.get_default_device()``.
        weights_path: path to the saved ``state_dict`` (relative to the
            notebook's working directory).

    Returns:
        A ``LeNet5`` with the pretrained weights loaded, in eval mode, on ``device``.
    """
    if device is None:
        device = torch.get_default_device()
    net = LeNet5().eval()
    # map_location makes the requested device apply to loading too: without it a
    # checkpoint saved on GPU cannot be loaded on a CPU-only machine.
    # weights_only avoids unpickling arbitrary objects; a state_dict only holds tensors.
    state_dict = torch.load(weights_path, map_location=device, weights_only=True)
    net.load_state_dict(state_dict)
    return net.to(device)


@torch.inference_mode()
def mnist_acc(net, ims, labs):
    """Overall top-1 accuracy of a classifier.

    Args:
        net: classifier; ``net(ims)`` returns scores with classes on dim 1.
        ims: input batch, passed to ``net`` unchanged.
        labs: ground-truth class indices, one per row of ``net(ims)``.

    Returns:
        0-d float tensor: fraction of rows whose arg-max score equals the label.
    """
    hits = net(ims).argmax(dim=1).eq(labs)
    return hits.sum().float() / len(labs)


@torch.inference_mode()
def mnist_acc_per_class(net, ims, labs, num_classes=10):
    """
    Per-class top-1 accuracy of a classifier.

    Args:
        net: classifier; ``net(ims)`` returns scores with classes on dim 1
        ims: input images [N,C,H,W]
        labs: ground truth labels [N]
        num_classes: number of classes to report (default=10 for MNIST);
            labels outside ``range(num_classes)`` are not reported
    Returns:
        dict mapping every class index in ``range(num_classes)`` to its
        accuracy in [0, 1] as a Python float, or NaN if ``labs`` has no
        sample of that class
    """
    predicted = torch.argmax(net(ims), dim=1)

    def _class_accuracy(cls):
        in_class = labs == cls
        count = in_class.sum()
        if count == 0:
            return float("nan")  # no samples of this class
        hits = (predicted[in_class] == labs[in_class]).sum().float()
        return (hits / count).item()

    return {cls: _class_accuracy(cls) for cls in range(num_classes)}


@torch.inference_mode()
def generate_mnist_testset(model, n: int = 1000):
    """Sample a class-balanced set of digits from a class-conditional model.

    Starts from standard Gaussian noise at t=0 and integrates the model's
    vector field to t=1 with the adaptive ``dopri5`` solver (atol=rtol=1e-4).

    Args:
        model: conditional vector-field model, called as
            ``model.forward(t, x, y)`` where ``y`` holds one class label per
            sample in ``x``.
        n: number of samples per class (10 classes, so ``10 * n`` images total).

    Returns:
        Tensor of shape ``(10 * n, 1, 28, 28)``, the state at t=1. Sample ``i``
        is conditioned on class ``i % 10`` (labels are ``torch.arange(10).repeat(n)``).
    """
    generated_class_list = torch.arange(10).repeat(n)
    # Draw exactly one noise sample per label so x and y share the batch dimension.
    x0 = torch.randn(len(generated_class_list), 1, 28, 28)
    t_span = torch.linspace(0, 1, 2)
    trajectory = odeint(
        lambda t, x: model.forward(t, x, generated_class_list),
        x0,
        t_span,
        atol=1e-4,
        rtol=1e-4,
        method="dopri5",
    )
    # NOTE(review): the classifier in this notebook is evaluated on test images
    # resized to 32x32 (via ToTensor). These samples are 28x28, so they presumably
    # need resizing/padding, and possibly value-range matching, before they are
    # scored with the same classifier. TODO confirm.
    return trajectory[-1]