In [5]:
import torch
import torch.nn as nn
import torchmetrics
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms.v2 as T

In [6]:
toTensor = T.Compose(
    [
        # Convert the incoming image (PIL image / ndarray / tensor) to a torchvision Image tensor.
        T.ToImage(),
        # Cast to float32; scale=True rescales integer pixel values into [0, 1].
        T.ToDtype(torch.float32, scale=True),
    ]
)

In [7]:
# NOTE(review): validate_cuda_device is not used in this cell; it looks like an
# accidental autocomplete import. Delete it if nothing else in the notebook needs it.
from torch.serialization import validate_cuda_device
train_and_valid_data = torchvision.datasets.FashionMNIST(
    root="datasets", train=True, download=True, transform=toTensor
)
test_data = torchvision.datasets.FashionMNIST(
    root="datasets", train=False, download=True, transform=toTensor
)

# Randomly hold out N_VALID training examples for validation. The training size
# is derived from the dataset length rather than hard-coded, so the two sizes
# always sum to len(train_and_valid_data) as random_split requires.
N_VALID = 5_000
torch.manual_seed(42)  # makes the random train/valid split reproducible
train_data, valid_data = torch.utils.data.random_split(
    train_and_valid_data, [len(train_and_valid_data) - N_VALID, N_VALID])

100%|██████████| 26.4M/26.4M [00:02<00:00, 12.5MB/s]
100%|██████████| 29.5k/29.5k [00:00<00:00, 210kB/s]
100%|██████████| 4.42M/4.42M [00:01<00:00, 3.93MB/s]
100%|██████████| 5.15k/5.15k [00:00<00:00, 24.5MB/s]


In [8]:
# One batch size shared by all loaders, so it can be tuned in a single place.
BATCH_SIZE = 32
# Only the training loader shuffles; evaluation aggregates over all batches,
# so batch order does not matter for the validation/test loaders.
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(valid_data, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE)

In [9]:
# Pull the first (image, label) pair out of the training split for quick inspection.
(X_sample, y_sample) = train_data[0]

In [10]:
# Pick the best available accelerator, in order of preference: CUDA, then
# Apple MPS, then CPU. Later checks only run when the earlier ones fail.
device = (
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

device

'cuda'

In [11]:
class ImageCLassifier(nn.Module):
    """Fully connected classifier with two ReLU hidden layers.

    Layout (the Sequential indices are part of the state_dict keys):
    Flatten -> Linear(n_inputs, n_hidden1) -> ReLU
            -> Linear(n_hidden1, n_hidden2) -> ReLU
            -> Linear(n_hidden2, n_classes)

    The output is the last Linear layer's raw scores (no softmax is applied).
    """

    def __init__(self, n_inputs, n_hidden1, n_hidden2, n_classes):
        super().__init__()
        widths = [n_inputs, n_hidden1, n_hidden2]
        layers = [nn.Flatten()]
        # Hidden layers: each Linear is followed by a ReLU, created in order so
        # seeded weight initialisation is unchanged.
        for width_in, width_out in zip(widths, widths[1:]):
            layers += [nn.Linear(width_in, width_out), nn.ReLU()]
        # Output layer: no activation.
        layers.append(nn.Linear(n_hidden2, n_classes))
        self.mlp = nn.Sequential(*layers)

    def forward(self, X):
        logits = self.mlp(X)
        return logits

n_epochs = 20
# The model's last layer has no softmax, which matches CrossEntropyLoss's
# expectation of raw scores.
xentropy = nn.CrossEntropyLoss()

# Reseed right before building the model so its weight initialisation is reproducible.
torch.manual_seed(42)
model = ImageCLassifier(
    n_inputs=28 * 28,  # each 28x28 image is flattened into one vector
    n_hidden1=300,
    n_hidden2=100,
    n_classes=10,
)

In [15]:
def train2(model, optimizer, criterion, metric, train_loader, valid_loader,
               n_epochs):
    """Train ``model`` for ``n_epochs`` epochs, tracking loss and a metric.

    After each epoch the model is scored on ``valid_loader`` via
    ``evaluate_tm`` and a one-line summary is printed.

    Args:
        model: module to train. Batches are moved to the global ``device``,
            so the model is expected to already live there.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss, called as ``criterion(y_pred, y_batch)``.
        metric: object with ``reset``/``update``/``compute`` (e.g. a
            torchmetrics metric). It is reused for the training and the
            validation metric, so it is reset between them.
        train_loader: iterable of ``(X_batch, y_batch)`` training batches.
        valid_loader: batches passed to ``evaluate_tm`` each epoch.
        n_epochs: number of passes over ``train_loader``.

    Returns:
        dict with per-epoch lists ``"train_losses"``, ``"train_metrics"`` and
        ``"valid_metrics"``.

    Raises:
        ValueError: if ``train_loader`` yields no batches (the mean loss would
            otherwise be a division by zero).
    """
    history = {"train_losses": [], "train_metrics": [], "valid_metrics": []}
    for epoch in range(n_epochs):
        # evaluate_tm() leaves the model in eval mode at the end of each epoch,
        # so training mode is restored once per epoch; nothing inside the
        # batch loop changes the mode.
        model.train()
        total_loss = 0.
        n_batches = 0
        metric.reset()
        for X_batch, y_batch in train_loader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            y_pred = model(X_batch)
            loss = criterion(y_pred, y_batch)
            total_loss += loss.item()
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            metric.update(y_pred, y_batch)
            n_batches += 1
        if n_batches == 0:
            raise ValueError("train_loader yielded no batches")
        # Counting batches (instead of len(train_loader)) also supports
        # iterables without __len__.
        mean_loss = total_loss / n_batches
        history["train_losses"].append(mean_loss)
        # Read the training metric before evaluate_tm() resets the shared metric.
        history["train_metrics"].append(metric.compute().item())
        history["valid_metrics"].append(
            evaluate_tm(model, valid_loader, metric).item())
        print(f"Epoch {epoch + 1}/{n_epochs}, "
              f"train loss: {history['train_losses'][-1]:.4f}, "
              f"train metric: {history['train_metrics'][-1]:.4f}, "
              f"valid metric: {history['valid_metrics'][-1]:.4f}")
    return history

In [16]:
def evaluate_tm(model, data_loader, metric):
    """Score ``model`` on every batch of ``data_loader`` with ``metric``.

    The model is switched to eval mode and gradients are disabled. The metric
    is reset, updated once per batch, and its aggregate over the whole loader
    is returned.

    Args:
        model: module to evaluate. Batches are moved to the global ``device``.
        data_loader: iterable of ``(X_batch, y_batch)`` batches.
        metric: object with ``reset``/``update``/``compute`` (e.g. a
            torchmetrics metric).

    Returns:
        The result of ``metric.compute()`` (a tensor for torchmetrics metrics).

    Note:
        The model is left in eval mode; callers that resume training must call
        ``model.train()`` themselves.
    """
    model.eval()
    metric.reset()  # start from a clean state so earlier updates don't leak in
    with torch.no_grad():
        for X_batch, y_batch in data_loader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            y_pred = model(X_batch)
            metric.update(y_pred, y_batch)  # accumulate per-batch statistics
    return metric.compute()  # aggregate over all batches

In [17]:
# Move the parameters to the device before the optimizer is built around them.
model.to(device)
accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

z = train2(
    model,
    optimizer,
    criterion=xentropy,
    metric=accuracy,
    train_loader=train_loader,
    valid_loader=valid_loader,
    n_epochs=n_epochs,
)

Epoch 1/20, train loss: 0.4074, train metric: 0.8483, valid metric: 0.8518
Epoch 2/20, train loss: 0.3639, train metric: 0.8653, valid metric: 0.8576
Epoch 3/20, train loss: 0.3360, train metric: 0.8751, valid metric: 0.8562
Epoch 4/20, train loss: 0.3162, train metric: 0.8829, valid metric: 0.8710
Epoch 5/20, train loss: 0.2996, train metric: 0.8878, valid metric: 0.8730
Epoch 6/20, train loss: 0.2851, train metric: 0.8942, valid metric: 0.8734
Epoch 7/20, train loss: 0.2730, train metric: 0.8980, valid metric: 0.8604
Epoch 8/20, train loss: 0.2620, train metric: 0.9015, valid metric: 0.8786
Epoch 9/20, train loss: 0.2527, train metric: 0.9043, valid metric: 0.8728
Epoch 10/20, train loss: 0.2435, train metric: 0.9075, valid metric: 0.8700
Epoch 11/20, train loss: 0.2348, train metric: 0.9110, valid metric: 0.8726
Epoch 12/20, train loss: 0.2282, train metric: 0.9120, valid metric: 0.8870
Epoch 13/20, train loss: 0.2222, train metric: 0.9144, valid metric: 0.8848
Epoch 14/20, train lo