In [1]:
import json
import random
from pathlib import Path

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.models import resnet18
from torchvision.transforms import v2, ToTensor
from tqdm import tqdm

In [2]:
# Augmentation pipeline applied to every training sample.
# The geometric and colour jitter steps run first; the final two steps turn
# the result into a float32 tensor scaled to [0, 1].
transforms = v2.Compose([
    v2.RandomHorizontalFlip(),
    v2.RandomRotation(10),
    v2.RandomAffine(0, shear=10, scale=(0.8, 1.2)),
    v2.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    # `v2.ToTensor()` is deprecated; ToImage + ToDtype(scale=True) is the
    # replacement recommended by the torchvision v2 documentation.
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
])



In [3]:
# Arguments shared by both splits. `download=False` means nothing is fetched
# here, so the data is expected to already be present under the root folder.
fmnist_kwargs = dict(root="./nodes/data", download=False)

# Training split: uses the augmentation pipeline defined in the previous cell.
train_dataset = datasets.FashionMNIST(train=True, transform=transforms, **fmnist_kwargs)

# Held-out split: no augmentation, only conversion to a tensor.
test_dataset = datasets.FashionMNIST(train=False, transform=ToTensor(), **fmnist_kwargs)

In [4]:
# Training configuration; each value is read by a later cell.
batch_size = 32
epochs = 100
num_classes = 10
learning_rate = 1e-3
weight_decay = 1e-4

# Seed the Python and PyTorch RNGs so weight initialisation and the random
# augmentations repeat from run to run. (Bit-exact results on GPU may still
# need additional determinism settings.)
seed = 42
random.seed(seed)
torch.manual_seed(seed)

In [5]:
# Split the held-out data in half: first half for validation, second for test.
# This cell rebinds `test_dataset`, so on a re-run it may already be a Subset;
# recover the underlying full dataset first so the split is not halved again.
full_test_dataset = (
    test_dataset.dataset
    if isinstance(test_dataset, torch.utils.data.Subset)
    else test_dataset
)
val_size = len(full_test_dataset) // 2
val_dataset = torch.utils.data.Subset(full_test_dataset, range(val_size))
test_dataset = torch.utils.data.Subset(
    full_test_dataset, range(val_size, len(full_test_dataset))
)

# Training batches are reshuffled every epoch; without `shuffle=True` the
# model would see the samples in the same fixed order on every pass.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# The whole validation set is served as a single batch.
val_dataloader = DataLoader(val_dataset, batch_size=val_size)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size)

In [6]:
class LeNet5(nn.Module):
    """LeNet-5 style CNN: two conv/avg-pool stages followed by three linear layers.

    The first convolution takes a single input channel, and the first linear
    layer expects 16 * 4 * 4 flattened features, which is what 28x28 inputs
    produce after the two 5x5 convolutions and two 2x2 poolings.

    Args:
        num_classes: Size of the output layer (number of logits per sample).
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Feature extractor; one pooling module is shared by both stages.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1)
        self.ap = nn.AvgPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1)
        # Classifier head.
        self.fc1 = nn.Linear(in_features=16 * 4 * 4, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=num_classes)
        self.tanh = nn.Tanh()

    def forward(self, x):
        """Return the raw (un-normalised) class logits for a batch of images."""
        # Each stage is conv -> tanh -> avg-pool -> tanh.
        for conv in (self.conv1, self.conv2):
            x = self.tanh(conv(x))
            x = self.tanh(self.ap(x))
        # Collapse everything except the batch dimension.
        x = x.flatten(start_dim=1)
        for hidden in (self.fc1, self.fc2):
            x = self.tanh(hidden(x))
        # No activation on the output layer.
        return self.fc3(x)

In [7]:
# Use the GPU when CUDA is available, otherwise stay on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Build the network and move its parameters to the selected device.
model = LeNet5(num_classes=num_classes)
model = model.to(device)

# Cross-entropy on the logits, optimised with Adam plus weight decay.
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=learning_rate,
    weight_decay=weight_decay,
)

# Scheduler driven by a metric that should decrease ("min" mode), with a
# patience of 3 scheduler steps.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="min",
    patience=3,
)

In [8]:
def train():
    """Run one optimisation pass over `train_dataloader`.

    Uses the notebook-level `model`, `optimizer`, `loss_fn` and `device`.

    Returns:
        float: Mean of the per-batch training losses for this pass.
    """
    model.train()
    losses = []
    # Wrap the loader itself (not `enumerate(...)`) so tqdm can read its
    # length and display a real progress bar with total and ETA.
    for X, y in tqdm(train_dataloader, desc="train"):
        X, y = X.to(device), y.to(device)
        pred = model(X)
        loss = loss_fn(pred, y)

        loss.backward()
        optimizer.step()
        # Clear gradients so they do not accumulate into the next batch.
        optimizer.zero_grad()
        losses.append(loss.item())
    return sum(losses) / len(losses)

In [9]:
def validate():
    """Compute the mean loss over the whole of `val_dataloader`.

    Every batch is evaluated, so the result does not depend on the loader's
    batch size. Per-batch losses are weighted by batch size, which assumes
    `loss_fn` returns a per-batch mean (the default reduction of
    `nn.CrossEntropyLoss`).

    Returns:
        float: Sample-weighted mean validation loss.

    Raises:
        ValueError: If the validation loader yields no samples.
    """
    model.eval()
    total_loss = 0.0
    n_samples = 0
    with torch.no_grad():
        for X, y in val_dataloader:
            X, y = X.to(device), y.to(device)
            batch_n = y.size(0)
            # Undo the per-batch mean so unequal batch sizes average correctly.
            total_loss += loss_fn(model(X), y).item() * batch_n
            n_samples += batch_n
    if n_samples == 0:
        raise ValueError("val_dataloader yielded no samples")
    return total_loss / n_samples

In [10]:
# Per-epoch loss history; written to disk by the final cell.
data = []
for epoch in range(1, epochs + 1):
    train_loss = train()
    val_loss = validate()
    # The scheduler monitors the validation loss.
    scheduler.step(val_loss)
    record = {
        "epoch": epoch,
        "train_loss": train_loss,
        "val_loss": val_loss,
    }
    data.append(record)

1875it [00:41, 45.02it/s]
1875it [00:39, 47.14it/s]
1875it [00:40, 46.74it/s]
1875it [00:39, 46.92it/s]
1875it [00:44, 42.50it/s]
1875it [00:43, 43.14it/s]
1875it [00:44, 41.83it/s]
1875it [00:43, 43.31it/s]
1875it [00:42, 43.62it/s]
1875it [00:43, 42.66it/s]
1875it [00:42, 44.57it/s]
1875it [00:45, 41.34it/s]
1875it [00:40, 46.67it/s]
1875it [00:42, 43.94it/s]
1875it [00:43, 43.03it/s]
1875it [00:40, 46.52it/s]
1875it [00:45, 40.94it/s]
1875it [00:42, 44.49it/s]
1875it [00:41, 44.76it/s]
1875it [00:40, 46.39it/s]
1875it [00:44, 41.67it/s]
1875it [00:45, 41.56it/s]
1875it [00:44, 41.85it/s]
1875it [00:39, 47.09it/s]
1875it [00:42, 44.23it/s]
1875it [00:46, 40.67it/s]
1875it [00:45, 41.30it/s]
1875it [00:42, 44.48it/s]
1875it [00:39, 47.52it/s]
1875it [00:37, 49.92it/s]
1875it [00:43, 43.37it/s]
1875it [00:42, 44.30it/s]
1875it [00:40, 46.56it/s]
1875it [00:38, 49.07it/s]
1875it [00:39, 46.88it/s]
1875it [00:39, 46.91it/s]
1875it [00:39, 47.23it/s]
1875it [00:39, 47.38it/s]
1875it [00:4

In [11]:
# Persist the per-epoch loss history as pretty-printed JSON.
results_path = Path("./results/fmnist.json")
# Create the output folder if needed, so a fresh checkout does not lose a
# long training run to a FileNotFoundError at the very last step.
results_path.parent.mkdir(parents=True, exist_ok=True)
with results_path.open("w") as f:
    json.dump(data, f, indent=2)