In [1]:
import torch
from models import ClassifierMNIST
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

In [12]:
def train_mnist_classifier(model: nn.Module, 
                           batch_size: int = 64, 
                           epochs: int = 5, 
                           learning_rate: float = 0.001):
    # Transform the data
    transform = transforms.Compose([transforms.ToTensor()])

    # Load the datasets
    train_dataset = datasets.MNIST('../data', train=True, download=True, transform=transform)
    test_dataset = datasets.MNIST('../data', train=False, transform=transform)

    # Prepare data loaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # Initialize the loss function and optimizer
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Training loop
    for epoch in range(epochs):
        model.train()
        with tqdm(train_loader) as tepoch:
            for data, target in tepoch:
                tepoch.set_description(f"Epoch {epoch}")
                optimizer.zero_grad()
                output = model(data)
                loss = loss_function(output, target)
                loss.backward()
                optimizer.step()
                
                tepoch.set_postfix(loss=loss.item())

    # Testing loop
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)
            test_loss += loss_function(output, target).item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} '
          f'({100. * correct / len(test_loader.dataset):.0f}%)')

In [13]:
classifier = ClassifierMNIST()

In [14]:
# Train the classifer
train_mnist_classifier(classifier)

  0%|          | 0/938 [00:00<?, ?it/s]

  0%|          | 0/938 [00:00<?, ?it/s]

  0%|          | 0/938 [00:00<?, ?it/s]

  0%|          | 0/938 [00:00<?, ?it/s]

  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.0006, Accuracy: 9878/10000 (99%)


In [15]:
# Save model weights
torch.save(classifier.state_dict(),f='models/mnist_classifier.pt')