In [1]:
from pathlib import Path

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets
from torchvision.transforms import v2
from tqdm import tqdm

In [2]:
# Seed the global torch RNG so that model initialisation (next cell) and the
# training DataLoader's shuffle order are reproducible on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

batch_size = 32
MNIST_ROOT = '../datasets/mnist_data/'

# Convert images to tensors and cast to float32; scale=True rescales the
# integer pixel values into [0, 1]. One shared transform keeps both splits
# preprocessed identically.
mnist_transform = v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])

# MNIST Dataset (downloaded into MNIST_ROOT on first run)
mnist_train_dataset = datasets.MNIST(root=MNIST_ROOT, train=True, transform=mnist_transform, download=True)
mnist_val_dataset = datasets.MNIST(root=MNIST_ROOT, train=False, transform=mnist_transform, download=True)

# Data Loader (Input Pipeline): shuffle only the training split; validation
# order stays fixed so per-epoch validation passes are comparable.
mnist_train_loader = torch.utils.data.DataLoader(dataset=mnist_train_dataset, batch_size=batch_size, shuffle=True)
mnist_val_loader = torch.utils.data.DataLoader(dataset=mnist_val_dataset, batch_size=batch_size, shuffle=False)

In [8]:
from my_model import MyModel

# Run configuration consumed by train_and_val: the "train" loader is used for
# optimisation and the "valid" loader for evaluation, max_epochs passes each.
max_epochs = 10
loaders = {"train": mnist_train_loader, "valid": mnist_val_loader}

# Fresh model, optimised with Adam at its default hyper-parameters.
model = MyModel()
optimizer = torch.optim.Adam(model.parameters())
# NOTE(review): nn.CrossEntropyLoss expects unnormalised logits, so this assumes
# MyModel's forward returns raw scores (no softmax) — TODO confirm in my_model.
criterion = nn.CrossEntropyLoss()

In [None]:
def train_and_val(model, criterion, optimizer, loaders, max_epochs=10):
    """Train ``model`` and record per-epoch classification accuracy per loader.

    Every epoch iterates over ``loaders`` in insertion order. The loader under
    the key ``"train"`` is used for optimisation (forward, loss, backward,
    optimizer step) with the model in train mode. Every other loader is only
    evaluated, with the model in eval mode and gradients disabled.

    Args:
        model: ``nn.Module`` mapping an input batch to per-class scores; the
            prediction is the argmax over the last dimension.
        criterion: loss callable ``criterion(outputs, targets)``; used only on
            the ``"train"`` loader.
        optimizer: optimizer over ``model``'s parameters.
        loaders: mapping of phase name -> iterable of ``(x_batch, y_batch)``.
        max_epochs: number of passes over every loader.

    Returns:
        dict mapping phase name -> list with one accuracy per epoch. The
        ``"train"`` and ``"valid"`` keys are always present (possibly empty);
        any other loader key gets its own list. A loader that yields no samples
        in an epoch records ``float("nan")`` instead of raising.
    """
    accuracy = {"train": [], "valid": []}
    for epoch in range(max_epochs):
        for phase, dataloader in loaders.items():
            is_train = phase == "train"
            # Set the mode once per phase (train(False) == eval()); this affects
            # layers such as dropout and batch-norm.
            model.train(is_train)
            n_correct = 0
            n_seen = 0
            # Only build the autograd graph when backward() will be called.
            with torch.set_grad_enabled(is_train):
                for x_batch, y_batch in dataloader:
                    outp = model(x_batch)
                    if is_train:
                        optimizer.zero_grad()
                        loss = criterion(outp, y_batch)
                        loss.backward()
                        optimizer.step()
                    # Accuracy is measured on the outputs computed before this
                    # batch's optimizer step.
                    preds = outp.argmax(-1)
                    n_correct += (preds == y_batch).sum().item()
                    n_seen += len(y_batch)
            if is_train:
                print(f"Epoch: {epoch+1}")
            # Guard against an empty loader instead of dividing by zero.
            epoch_acc = n_correct / n_seen if n_seen else float("nan")
            print(f"Loader: {phase}. Accuracy: {epoch_acc}")
            accuracy.setdefault(phase, []).append(epoch_acc)
    return accuracy

In [None]:
accuracy = train_and_val(model, criterion, optimizer, loaders, max_epochs)

# Number epochs from 1 so the x-axis matches the "Epoch: N" lines printed
# during training.
epochs = range(1, max_epochs + 1)

fig, ax = plt.subplots(figsize=(16, 10))
ax.plot(epochs, accuracy['train'], label="train", linewidth=2)
ax.plot(epochs, accuracy['valid'], label="validation", linewidth=2)
ax.set(title="Accuracy", xlabel="Epoch", ylabel="Accuracy")
ax.legend()
plt.show()

In [None]:
# Persist only the learned weights (state_dict); reload them into a fresh
# MyModel() via load_state_dict(torch.load(MODEL_PATH)).
MODEL_PATH = Path('../../models/baseline')
# torch.save does not create missing directories, so ensure the target exists.
MODEL_PATH.parent.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), MODEL_PATH)