In [1]:
import torch
from torch import nn
import torchvision
from torchvision import transforms

In [2]:
from torcheval.metrics import MulticlassAccuracy, Mean
# NOTE(review): `metric` is not referenced anywhere else in this notebook and is
# created without a `device` argument; it looks like a leftover experiment —
# consider deleting it.
metric = MulticlassAccuracy()




In [3]:
# FashionMNIST train/test splits, converted to tensors with ToTensor; both are
# stored under ./data and downloaded there if not already present.
transform = transforms.Compose([transforms.ToTensor()])
train_dataset, test_dataset = (
    torchvision.datasets.FashionMNIST(root='./data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

In [4]:
from torch.utils.data import DataLoader, default_collate

In [5]:
def dataloader(train_dataset, test_dataset, batch_size):
    """Wrap the train and test datasets in DataLoaders.

    Args:
        train_dataset: training split; must expose a ``classes`` attribute.
        test_dataset: evaluation split.
        batch_size: mini-batch size used by both loaders.

    Returns:
        ``(train_loader, test_loader, class_names)``. Only the training loader
        shuffles; ``class_names`` is ``train_dataset.classes``.
    """
    def make_loader(ds, shuffle):
        # Both loaders share batch size and collate function; only shuffling differs.
        return DataLoader(dataset=ds, batch_size=batch_size, shuffle=shuffle, collate_fn=default_collate)

    train_loader = make_loader(train_dataset, True)
    test_loader = make_loader(test_dataset, False)
    return train_loader, test_loader, train_dataset.classes

In [6]:
# One batch size for both loaders; `class_names` comes from the training split.
BATCH_SIZE = 64
train_dataloader, test_dataloader, class_names = dataloader(
    train_dataset,
    test_dataset,
    batch_size=BATCH_SIZE,
)

In [7]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
device

'cuda'

In [8]:
from conv import conv  # explicit import instead of `*`, so the origin of `conv` is visible

# Five stacked conv blocks. The last one maps to 10 channels (one per FashionMNIST
# class) with `act=False` (presumably no activation), so the flattened output is
# used directly as logits by the loss.
# NOTE(review): Flatten only yields (batch, 10) if `conv` downsamples the
# 1x28x28 input down to 1x1 (e.g. stride-2 convolutions) — confirm in conv.py.
cnn = nn.Sequential(
    conv(1, 4),
    conv(4, 8),
    conv(8, 16),
    conv(16, 16),
    conv(16, 10, act=False),
    nn.Flatten()
)

In [9]:
# Optimisation setup: cross-entropy on the raw outputs of `cnn`, SGD with default settings.
loss_fn = nn.CrossEntropyLoss()
lr = 0.4
optimizer = torch.optim.SGD(cnn.parameters(), lr=lr)

In [10]:
from torcheval.metrics import MulticlassAccuracy, Mean

In [11]:
# Test-set metric objects, allocated on the selected `device`.
loss_metric, accuracy_metric = Mean(device=device), MulticlassAccuracy(device=device)

In [12]:
# Training loop: one optimisation pass per epoch, followed by test-set evaluation.


def _train_one_epoch(model, dataloader, loss_fn, optimizer, device):
    """Run one optimisation pass of ``model`` over ``dataloader``.

    Expects ``model`` to already live on ``device``; each ``(X, y)`` batch is
    moved there before the forward pass.
    """
    model.train()
    for X, y in dataloader:
        X, y = X.to(device), y.to(device)
        loss = loss_fn(model(X), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


def _evaluate(model, dataloader, loss_fn, device, loss_metric, accuracy_metric):
    """Return ``(mean_loss, accuracy)`` of ``model`` over ``dataloader``.

    Both metrics are reset before use, so stale state (e.g. from an interrupted
    earlier run) cannot leak into the result. They are reset again afterwards,
    so they are handed back clean.
    """
    model.eval()
    loss_metric.reset()
    accuracy_metric.reset()
    with torch.inference_mode():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            logits = model(X)  # a single forward pass feeds both metrics
            # Assumes loss_fn returns a per-batch mean (nn.CrossEntropyLoss default).
            # Weighting by batch size turns the result into a per-sample mean even
            # when the final batch is smaller than the rest.
            loss_metric.update(loss_fn(logits, y), weight=len(y))
            accuracy_metric.update(logits, y)
        mean_loss = loss_metric.compute().item()
        accuracy = accuracy_metric.compute().item()
    loss_metric.reset()
    accuracy_metric.reset()
    return mean_loss, accuracy


def fit(epochs, model, train_dataloader, test_dataloader, loss_fn, optimizer, device,
        loss_metric=None, accuracy_metric=None):
    """Train ``model`` for ``epochs`` epochs, printing test-set metrics after each.

    After every epoch, prints a dict with keys ``mode`` (always ``"test"``),
    ``epoch``, ``loss`` (per-sample mean test loss) and ``accuracy``.

    Args:
        epochs: number of passes over ``train_dataloader``.
        model: module to train; moved to ``device`` (in place) before training.
        train_dataloader: iterable of ``(X, y)`` training batches.
        test_dataloader: iterable of ``(X, y)`` evaluation batches.
        loss_fn: criterion called as ``loss_fn(model_output, y)``.
        optimizer: optimizer already bound to ``model``'s parameters.
        device: device that batches, the model and the default metrics use.
        loss_metric: optional torcheval-style weighted mean metric
            (``update(x, weight=...)``/``compute``/``reset``). Defaults to a
            fresh ``torcheval.metrics.Mean`` on ``device``.
        accuracy_metric: optional torcheval-style accuracy metric. Defaults to
            a fresh ``torcheval.metrics.MulticlassAccuracy`` on ``device``.
            Because both defaults are created here, results never depend on
            notebook-level metric objects or their device.
    """
    if loss_metric is None or accuracy_metric is None:
        from torcheval.metrics import Mean, MulticlassAccuracy
        if loss_metric is None:
            loss_metric = Mean(device=device)
        if accuracy_metric is None:
            accuracy_metric = MulticlassAccuracy(device=device)

    model.to(device)  # loop-invariant: move once, not every epoch
    for epoch in range(epochs):
        _train_one_epoch(model, train_dataloader, loss_fn, optimizer, device)
        test_loss, test_accuracy = _evaluate(
            model, test_dataloader, loss_fn, device, loss_metric, accuracy_metric
        )
        print({"mode": "test", "epoch": epoch, "loss": test_loss, "accuracy": test_accuracy})

In [13]:
# Train for three epochs; one dict of test metrics is printed per epoch.
fit(
    epochs=3,
    model=cnn,
    train_dataloader=train_dataloader,
    test_dataloader=test_dataloader,
    loss_fn=loss_fn,
    optimizer=optimizer,
    device=device,
)

{'mode': 'test', 'epoch': 0, 'loss': 0.6504706662551613, 'accuracy': 0.7534000277519226}
{'mode': 'test', 'epoch': 1, 'loss': 0.5538710277930946, 'accuracy': 0.7993000149726868}
{'mode': 'test', 'epoch': 2, 'loss': 0.4363217795160925, 'accuracy': 0.8348000049591064}
