In [None]:
import json
import random
from pathlib import Path

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.models import resnet18
from torchvision.transforms import v2, ToTensor
from tqdm import tqdm

In [None]:
# Training-time augmentation pipeline, applied to each image of the training
# dataset before it is converted to a float tensor.
transforms = v2.Compose([
    # NOTE(review): mirrored digits are generally not valid digits — confirm
    # that horizontal flipping is an intended augmentation for MNIST.
    v2.RandomHorizontalFlip(),
    v2.RandomRotation(10),
    v2.RandomAffine(0, shear=10, scale=(0.8, 1.2)),
    # NOTE(review): MNIST images are single-channel, so the saturation term
    # presumably has no visible effect — verify before relying on it.
    v2.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    # v2.ToTensor() is deprecated; ToImage + ToDtype(scale=True) is the
    # documented replacement and likewise yields float32 values in [0, 1].
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
])

In [None]:
# Both splits are read from ./nodes/data; download=False means the files must
# already be present there.

# Training split: images go through the augmentation pipeline `transforms`.
train_dataset = datasets.MNIST("./nodes/data", train=True, transform=transforms, download=False)

# Held-out split: no augmentation, only conversion to a tensor.
test_dataset = datasets.MNIST("./nodes/data", train=False, transform=ToTensor(), download=False)

In [None]:
# Training hyperparameters.
batch_size = 32
epochs = 100
num_classes = 10
learning_rate = 1e-3
weight_decay = 1e-4

# Seed the RNGs so that weight initialisation, random augmentations and any
# shuffling are repeatable across "Restart & Run All".
seed = 42
random.seed(seed)
torch.manual_seed(seed)

In [None]:
# Split the held-out MNIST data in half: the first half becomes the validation
# set, the second half the final test set.
# `test_dataset` is rebound below, so unwrap an existing Subset first;
# otherwise re-running this cell would halve the test set again.
full_test_dataset = (
    test_dataset.dataset
    if isinstance(test_dataset, torch.utils.data.Subset)
    else test_dataset
)
val_size = len(full_test_dataset) // 2
val_dataset = torch.utils.data.Subset(full_test_dataset, range(val_size))
test_dataset = torch.utils.data.Subset(
    full_test_dataset, range(val_size, len(full_test_dataset))
)

# Shuffle only the training data so every epoch sees the samples in a new
# order; evaluation loaders keep a fixed order.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# The validation loader yields the whole validation subset as a single batch.
val_dataloader = DataLoader(val_dataset, batch_size=val_size)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size)

In [None]:
class LeNet5(nn.Module):
    """LeNet-5 style CNN for single-channel images.

    Two unpadded 5x5 convolutions, each followed by 2x2 average pooling, feed
    three fully connected layers. Tanh is applied after every convolution,
    every pooling step and the first two linear layers. ``fc1`` expects
    16 * 4 * 4 features, which is what a 28x28 input produces.

    Args:
        num_classes: size of the output layer (number of logits).
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Registration order is kept as-is so parameter order and state-dict
        # keys stay stable.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.ap = nn.AvgPool2d(2)  # stride defaults to the kernel size
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.fc1 = nn.Linear(in_features=16 * 4 * 4, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=num_classes)
        self.tanh = nn.Tanh()

    def forward(self, x):
        """Return the raw class logits for a batch of images ``x``."""
        activate = self.tanh
        # Convolutional stage: conv -> tanh -> avg-pool -> tanh, twice.
        for conv in (self.conv1, self.conv2):
            x = activate(self.ap(activate(conv(x))))
        # Collapse everything except the batch dimension.
        x = x.flatten(1)
        # Classifier head: two tanh-activated layers, then linear logits.
        for hidden in (self.fc1, self.fc2):
            x = activate(hidden(x))
        return self.fc3(x)

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Network, optimiser, loss and learning-rate schedule used by the training cells.
model = LeNet5(num_classes=num_classes).to(device)
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=learning_rate,
    weight_decay=weight_decay,
)
loss_fn = nn.CrossEntropyLoss()
# The scheduler is stepped with the validation loss ("min" mode, patience 3).
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", patience=3)

In [None]:
def train():
    """Run one optimisation pass over ``train_dataloader``.

    Uses the module-level ``model``, ``optimizer``, ``loss_fn`` and ``device``.

    Returns:
        float: mean of the per-batch losses seen during this pass.
    """
    model.train()
    losses = []
    # Wrap the dataloader itself (not enumerate(...)) so tqdm can read its
    # length and display a total / ETA.
    for X, y in tqdm(train_dataloader):
        X, y = X.to(device), y.to(device)

        # Clear gradients before the backward pass so nothing left over from
        # an earlier call can leak into this update.
        optimizer.zero_grad()
        pred = model(X)
        loss = loss_fn(pred, y)
        loss.backward()
        optimizer.step()

        losses.append(loss.item())
    return sum(losses) / len(losses)

In [None]:
def validate():
    """Compute the mean validation loss over all of ``val_dataloader``.

    Every batch is evaluated and weighted by its size, so the result does not
    depend on the loader's batch size. Assumes ``loss_fn`` returns the mean
    over the batch, as the default ``nn.CrossEntropyLoss`` does.

    Uses the module-level ``model``, ``loss_fn`` and ``device``.

    Returns:
        float: average loss per validation sample.
    """
    model.eval()
    total_loss = 0.0
    n_samples = 0
    with torch.no_grad():
        for X, y in val_dataloader:
            X, y = X.to(device), y.to(device)
            batch_loss = loss_fn(model(X), y)
            # Undo the per-batch mean so unequal batch sizes are weighted correctly.
            total_loss += batch_loss.item() * len(y)
            n_samples += len(y)
    return total_loss / n_samples

In [None]:
# Per-epoch loss history: one record per epoch with 1-based epoch numbers.
data = []
for e in range(epochs):
    # Tuple elements are evaluated left to right: train first, then validate.
    loss, val_loss = train(), validate()
    # Feed the validation loss to the learning-rate scheduler.
    scheduler.step(val_loss)
    data.append(dict(epoch=e + 1, train_loss=loss, val_loss=val_loss))

In [None]:
# Persist the per-epoch loss history as pretty-printed JSON.
results_path = Path("./results/mnist.json")
# Create the output directory if it is missing, so a completed training run is
# not lost to a FileNotFoundError at the very last step.
results_path.parent.mkdir(parents=True, exist_ok=True)
with results_path.open("w") as f:
    json.dump(data, f, indent=2)