In [1]:
import matplotlib.pyplot as plt
import numpy as np

import torch
import torchvision
import torchvision.transforms as transforms

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.utils.data import DataLoader
from tqdm import tqdm

In [2]:
device = torch.device("mps" if torch.backends.mps.is_available() else
                      "cuda" if torch.cuda.is_available() else
                      "cpu")

In [3]:
classes = ('T-shirt_top', 'Trouser', 'Pullover', 'Dress', 'Coat',
        'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle_Boot')

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5), std=(0.5))
])

train_data = torchvision.datasets.FashionMNIST(
    "data",
    download=True,
    train=True,
    transform=transform)

test_data = torchvision.datasets.FashionMNIST(
    "data",
    download=True,
    train=False,
    transform=transform)

train_loader = DataLoader(train_data, batch_size=64, shuffle=True, num_workers=2)

test_loader = DataLoader(test_data, batch_size=64, shuffle=False, num_workers=2)

In [4]:
class Net(nn.Module):
    def __init__(self):
        super().__init__()

        self.bn1 = nn.BatchNorm2d(32)
        self.bn2 = nn.BatchNorm2d(64)

        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5)

        self.fc1 = nn.Linear(64*4*4, 400)
        self.fc2 = nn.Linear(400, 100)
        self.fc3 = nn.Linear(100, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.bn1(self.conv1(x))), kernel_size=2, stride=2)
        x = F.max_pool2d(F.relu(self.bn2(self.conv2(x))), kernel_size=2, stride=2)
        x = x.view(-1, 64*4*4)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


model = Net().to(device)
criterion = nn.CrossEntropyLoss()
all_criterion = nn.CrossEntropyLoss(reduction="none")
optimizer1 = optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)
optimizer2 = optim.Adam(model.parameters(), lr=1e-3)

In [None]:
from torch.utils.tensorboard import SummaryWriter

SummaryWriter()

In [6]:
num_epochs = 10

In [7]:
writer_sgd = SummaryWriter("runs/sgd_fashion")

model.train()
for epoch in range(1, num_epochs + 1):
    # ---- Train ----
    model.train()
    running_loss, correct, seen = 0.0, 0, 0
    for imgs, labels in tqdm(train_loader):
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer1.zero_grad(set_to_none=True)

        logits = model(imgs)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer1.step()

        running_loss += loss.item() * imgs.size(0)
        correct += (logits.argmax(1) == labels).sum().item()
        seen += imgs.size(0)

    train_loss = running_loss / seen
    train_acc = correct / seen
    writer_sgd.add_scalar('Loss/train', train_loss, epoch)
    writer_sgd.add_scalar('Accuracy/train', train_acc, epoch)

    # ---- Validation ----
    model.eval()
    v_running, v_correct, v_seen = 0.0, 0, 0
    all_labels_list, all_predictions_list, all_losses_list = [], [], []

    with torch.no_grad():
        for imgs, labels in test_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            logits = model(imgs)
            loss = criterion(logits, labels)

            v_running += loss.item() * imgs.size(0)
            v_correct += (logits.argmax(1) == labels).sum().item()
            v_seen += imgs.size(0)

            all_losses_list.append(all_criterion(logits, labels).cpu())
            all_predictions_list.append(logits.argmax(1).cpu())
            all_labels_list.append(labels.cpu())

    val_loss = v_running / v_seen
    val_acc = v_correct / v_seen
    writer_sgd.add_scalar('Loss/val', val_loss, epoch)
    writer_sgd.add_scalar('Accuracy/val', val_acc, epoch)

    print(f"Epoch {epoch}: train_loss={train_loss:.4f} val_loss={val_loss:.4f} "
          f"train_acc={train_acc:.3f} val_acc={val_acc:.3f}")

writer_sgd.flush()
writer_sgd.close()

100%|██████████| 938/938 [00:05<00:00, 163.18it/s]


Epoch 1: train_loss=0.4670 val_loss=0.3383 train_acc=0.833 val_acc=0.878


100%|██████████| 938/938 [00:05<00:00, 180.14it/s]


Epoch 2: train_loss=0.2906 val_loss=0.2916 train_acc=0.892 val_acc=0.893


100%|██████████| 938/938 [00:05<00:00, 176.11it/s]


Epoch 3: train_loss=0.2468 val_loss=0.2609 train_acc=0.909 val_acc=0.906


100%|██████████| 938/938 [00:05<00:00, 175.92it/s]


Epoch 4: train_loss=0.2196 val_loss=0.2822 train_acc=0.917 val_acc=0.898


100%|██████████| 938/938 [00:05<00:00, 175.91it/s]


Epoch 5: train_loss=0.1963 val_loss=0.2792 train_acc=0.928 val_acc=0.899


100%|██████████| 938/938 [00:05<00:00, 179.75it/s]


Epoch 6: train_loss=0.1784 val_loss=0.2460 train_acc=0.934 val_acc=0.911


100%|██████████| 938/938 [00:05<00:00, 177.96it/s]


Epoch 7: train_loss=0.1603 val_loss=0.2742 train_acc=0.939 val_acc=0.903


100%|██████████| 938/938 [00:05<00:00, 177.47it/s]


Epoch 8: train_loss=0.1450 val_loss=0.2542 train_acc=0.947 val_acc=0.914


100%|██████████| 938/938 [00:05<00:00, 177.50it/s]


Epoch 9: train_loss=0.1299 val_loss=0.2632 train_acc=0.951 val_acc=0.914


100%|██████████| 938/938 [00:05<00:00, 177.58it/s]


Epoch 10: train_loss=0.1166 val_loss=0.2717 train_acc=0.957 val_acc=0.914


In [8]:
writer_adam = SummaryWriter("runs/adam_fashion")

model.train()
for epoch in range(1, num_epochs + 1):
    # ---- Train ----
    model.train()
    running_loss, correct, seen = 0.0, 0, 0
    for imgs, labels in tqdm(train_loader):
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer2.zero_grad(set_to_none=True)

        logits = model(imgs)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer2.step()

        running_loss += loss.item() * imgs.size(0)
        correct += (logits.argmax(1) == labels).sum().item()
        seen += imgs.size(0)

    train_loss = running_loss / seen
    train_acc = correct / seen
    writer_adam.add_scalar('Loss/train', train_loss, epoch)
    writer_adam.add_scalar('Accuracy/train', train_acc, epoch)

    # ---- Validation ----
    model.eval()
    v_running, v_correct, v_seen = 0.0, 0, 0
    all_labels_list, all_predictions_list, all_losses_list = [], [], []

    with torch.no_grad():
        for imgs, labels in test_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            logits = model(imgs)
            loss = criterion(logits, labels)

            v_running += loss.item() * imgs.size(0)
            v_correct += (logits.argmax(1) == labels).sum().item()
            v_seen += imgs.size(0)

            all_losses_list.append(all_criterion(logits, labels).cpu())
            all_predictions_list.append(logits.argmax(1).cpu())
            all_labels_list.append(labels.cpu())

    val_loss = v_running / v_seen
    val_acc = v_correct / v_seen
    writer_adam.add_scalar('Loss/val', val_loss, epoch)
    writer_adam.add_scalar('Accuracy/val', val_acc, epoch)

    print(f"Epoch {epoch}: train_loss={train_loss:.4f} val_loss={val_loss:.4f} "
          f"train_acc={train_acc:.3f} val_acc={val_acc:.3f}")

writer_adam.flush()
writer_adam.close()

100%|██████████| 938/938 [00:06<00:00, 145.32it/s]


Epoch 1: train_loss=0.1912 val_loss=0.3298 train_acc=0.929 val_acc=0.896


100%|██████████| 938/938 [00:06<00:00, 154.67it/s]


Epoch 2: train_loss=0.1616 val_loss=0.2978 train_acc=0.940 val_acc=0.899


100%|██████████| 938/938 [00:06<00:00, 156.13it/s]


Epoch 3: train_loss=0.1401 val_loss=0.2730 train_acc=0.947 val_acc=0.909


100%|██████████| 938/938 [00:05<00:00, 156.46it/s]


Epoch 4: train_loss=0.1245 val_loss=0.2675 train_acc=0.953 val_acc=0.910


100%|██████████| 938/938 [00:06<00:00, 155.68it/s]


Epoch 5: train_loss=0.1086 val_loss=0.2937 train_acc=0.959 val_acc=0.913


100%|██████████| 938/938 [00:06<00:00, 152.74it/s]


Epoch 6: train_loss=0.1002 val_loss=0.3060 train_acc=0.963 val_acc=0.912


100%|██████████| 938/938 [00:05<00:00, 157.56it/s]


Epoch 7: train_loss=0.0897 val_loss=0.3326 train_acc=0.966 val_acc=0.910


100%|██████████| 938/938 [00:06<00:00, 156.20it/s]


Epoch 8: train_loss=0.0797 val_loss=0.3130 train_acc=0.970 val_acc=0.908


100%|██████████| 938/938 [00:05<00:00, 156.75it/s]


Epoch 9: train_loss=0.0702 val_loss=0.3388 train_acc=0.974 val_acc=0.910


100%|██████████| 938/938 [00:05<00:00, 156.42it/s]


Epoch 10: train_loss=0.0630 val_loss=0.3874 train_acc=0.977 val_acc=0.909
