In [1]:
import torch.nn as nn
import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
from loguru import logger

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Fashion-MNIST training split; downloaded into ./data on first use and
# converted to tensors by ToTensor().
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)
# Matching held-out split (train=False), stored under the same root directory.
test_data = datasets.FashionMNIST(
    root="data", train=False, download=True, transform=ToTensor()
)

In [3]:
batch_size = 64

# Training loader: reshuffle every epoch and drop the trailing partial batch so
# every optimisation step sees exactly `batch_size` samples.
train_dataloader = DataLoader(
    training_data, batch_size=batch_size, shuffle=True, drop_last=True
)

# Evaluation loader: every test sample must be visited exactly once.
# - drop_last=False: dropping the final partial batch would silently exclude
#   samples, while the evaluation cell divides the number of correct
#   predictions by the full dataset length, biasing the accuracy downwards.
# - shuffle=False: sample order does not affect the metrics, so a fixed order
#   keeps evaluation deterministic.
test_dataloader = DataLoader(
    test_data, batch_size=batch_size, shuffle=False, drop_last=False
)


In [4]:
# Choose the compute backend by preference: CUDA first, then Apple's MPS
# backend, and finally the CPU when neither accelerator is available.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

In [5]:
class NeuralNetwork(nn.Module):
    """Fully connected classifier: flatten -> 784-512-512-10 with ReLU in between.

    The input is flattened by ``nn.Flatten`` and pushed through two hidden
    linear layers of width 512 (each followed by ReLU) and a final linear
    layer that produces 10 raw scores (logits) per sample.
    """

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        # Layers are created in the same order they are applied.
        hidden_layers = [
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
        ]
        output_layer = nn.Linear(512, 10)
        self.linear_relu_stack = nn.Sequential(*hidden_layers, output_layer)

    def forward(self, x):
        """Return the un-normalised class scores (logits) for batch ``x``."""
        return self.linear_relu_stack(self.flatten(x))

In [6]:
# Build the network and move its parameters to the selected device in one step.
model = NeuralNetwork().to(device)

In [7]:
# Cross-entropy over the raw logits returned by the model.
loss_func = nn.CrossEntropyLoss()
# Adam over all model parameters with a learning rate of 2e-4.
opt = torch.optim.Adam(params=model.parameters(), lr=2e-4)

In [8]:
# Number of full passes over the training set.
EPOCH = 5

# Put the model in training mode explicitly so this cell behaves the same no
# matter which mode an earlier (or re-run) cell left the model in.
model.train()

for i in range(EPOCH):
    logger.info(f"Epoch {i+1}")
    for idx, (x, y) in enumerate(train_dataloader):
        x = x.to(device)
        y = y.to(device)

        # Clear gradients *before* the backward pass: if a previous run of this
        # cell was interrupted between backward() and zero_grad(), the stale
        # gradients would otherwise be accumulated into the first update.
        opt.zero_grad()

        y_pred = model(x)
        loss = loss_func(y_pred, y)
        loss.backward()
        opt.step()

        # Report the current mini-batch loss every 100 steps.
        if idx % 100 == 0:
            logger.info(f"step {idx}: loss {loss.item():.4f}")

[32m2024-02-05 00:07:13.551[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m3[0m - [1mEpoch 1[0m
[32m2024-02-05 00:07:16.029[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m13[0m - [1mloss2.306499481201172[0m
[32m2024-02-05 00:07:16.739[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m13[0m - [1mloss0.9275075197219849[0m
[32m2024-02-05 00:07:17.426[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m13[0m - [1mloss0.6783807277679443[0m
[32m2024-02-05 00:07:18.113[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m13[0m - [1mloss0.515682578086853[0m
[32m2024-02-05 00:07:18.944[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m13[0m - [1mloss0.4708956480026245[0m
[32m2024-02-05 00:07:19.744[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m13[0m - [1mloss0.49175596237182617[0m
[32m2024-02-05 00:07:20.466[0m | [1mINFO    [0m | [36m__main__[0m:

In [11]:
# Evaluate the model on the test loader: per-sample average loss and accuracy.
test_loss = 0.0
correct = 0
num_samples = 0  # samples actually evaluated (may differ from the dataset size)

# Switch to evaluation mode for the measurement and restore the previous mode
# afterwards so this cell has no lasting effect on the model.
was_training = model.training
model.eval()
with torch.no_grad():
    for x, y in test_dataloader:
        x, y = x.to(device), y.to(device)
        pred = model(x)
        batch_len = y.size(0)
        # loss_func returns the batch mean; weight it by the batch size so a
        # smaller final batch does not get the same weight as a full one.
        test_loss += loss_func(pred, y).item() * batch_len
        correct += (pred.argmax(dim=1) == y).sum().item()
        num_samples += batch_len
model.train(was_training)

# Normalise by the number of samples that were really seen. Dividing by
# len(test_dataloader.dataset) under-reports accuracy whenever the loader
# skips samples (e.g. drop_last=True).
test_loss = test_loss / num_samples
correct = correct / num_samples
logger.info(f"Avg loss is {test_loss}, correct is {correct}")

[32m2024-02-05 11:52:44.921[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m11[0m - [1mAvg loss is 0.3726440172355909, correct is 0.8652[0m
