# Домашнее задание 2 — FashionMNIST на PyTorch Lightning


In [1]:
!pip install numpy torch torchvision lightning

Collecting lightning
  Downloading lightning-2.6.0-py3-none-any.whl.metadata (44 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.9/44.9 kB[0m [31m4.7 MB/s[0m eta [36m0:00:00[0m
Collecting lightning-utilities<2.0,>=0.10.0 (from lightning)
  Downloading lightning_utilities-0.15.2-py3-none-any.whl.metadata (5.7 kB)
Collecting torchmetrics<3.0,>0.7.0 (from lightning)
  Downloading torchmetrics-1.8.2-py3-none-any.whl.metadata (22 kB)
Collecting pytorch-lightning (from lightning)
  Downloading pytorch_lightning-2.6.0-py3-none-any.whl.metadata (21 kB)
Downloading lightning-2.6.0-py3-none-any.whl (845 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m846.0/846.0 kB[0m [31m25.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading lightning_utilities-0.15.2-py3-none-any.whl (29 kB)
Downloading torchmetrics-1.8.2-py3-none-any.whl (983 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m983.2/983.2 kB[0m [31m57.4 MB/s[0m eta [36m0:00:00[0

Импорты, версии и воспроизводимость

In [2]:
import os
import math
import random
from dataclasses import dataclass
from typing import Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split

import torchvision
from torchvision import transforms


import lightning.pytorch as pl
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint, LearningRateMonitor
from lightning.pytorch.loggers import TensorBoardLogger

import torchmetrics

print("torch:", torch.__version__)
print("torchvision:", torchvision.__version__)
print("lightning:", pl.__version__)
print("torchmetrics:", torchmetrics.__version__)

SEED = 42
pl.seed_everything(SEED, workers=True)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("device:", DEVICE)


INFO: Seed set to 42
INFO:lightning.fabric.utilities.seed:Seed set to 42


torch: 2.9.0+cu126
torchvision: 0.24.0+cu126
lightning: 2.6.0
torchmetrics: 1.8.2
device: cuda


## `FashionMNISTDataModule`

Требования:
- загрузка данных
- предобработка: ToTensor + Normalize (+ опционально аугментации)
- разбиение на train/val/test
- dataloader'ы

**Нормализация.** Для FashionMNIST обычно используют mean/std, оцененные по train. В литературе часто встречается `mean≈0.286`, `std≈0.353` для FashionMNIST. Здесь зададим эти значения явно.

**Аугментации.**  Самый минимум: `RandomHorizontalFlip` (для одежды допустимо), можно добавить `RandomAffine` с малыми углами.


In [3]:
class FashionMNISTDataModule(pl.LightningDataModule):
    def __init__(
        self,
        data_dir: str = "./data",
        batch_size: int = 128,
        num_workers: int = 4,
        val_size: int = 5000,
        seed: int = 42,
        pin_memory: bool = True,
    ):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.val_size = val_size
        self.seed = seed
        self.pin_memory = pin_memory
        self.mean = (0.2860,)
        self.std = (0.3530,)

        self.train_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.Normalize(self.mean, self.std),
        ])

        self.test_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(self.mean, self.std),
        ])

        self.ds_train = None
        self.ds_val = None
        self.ds_test = None

    def prepare_data(self) -> None:
        torchvision.datasets.FashionMNIST(self.data_dir, train=True, download=True)
        torchvision.datasets.FashionMNIST(self.data_dir, train=False, download=True)

    def setup(self, stage: Optional[str] = None) -> None:
        if stage in (None, "fit"):
            full_train = torchvision.datasets.FashionMNIST(
                self.data_dir, train=True, transform=self.train_transform, download=False
            )

            train_size = len(full_train) - self.val_size
            if train_size <= 0:
                raise ValueError(f"val_size={self.val_size} слишком большой для train={len(full_train)}")

            g = torch.Generator().manual_seed(self.seed)
            self.ds_train, self.ds_val = random_split(full_train, [train_size, self.val_size], generator=g)

        if stage in (None, "test"):
            self.ds_test = torchvision.datasets.FashionMNIST(
                self.data_dir, train=False, transform=self.test_transform, download=False
            )

    def train_dataloader(self) -> DataLoader:
        return DataLoader(
            self.ds_train,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory and torch.cuda.is_available(),
            persistent_workers=self.num_workers > 0,
        )

    def val_dataloader(self) -> DataLoader:
        return DataLoader(
            self.ds_val,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory and torch.cuda.is_available(),
            persistent_workers=self.num_workers > 0,
        )

    def test_dataloader(self) -> DataLoader:
        return DataLoader(
            self.ds_test,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory and torch.cuda.is_available(),
            persistent_workers=self.num_workers > 0,
        )


Проверяем, что все собирается и размеры корректные.


In [4]:
dm = FashionMNISTDataModule(batch_size=128, num_workers=4, val_size=5000, seed=SEED)
dm.prepare_data()
dm.setup("fit")

x, y = next(iter(dm.train_dataloader()))
print("batch x:", x.shape, x.dtype)
print("batch y:", y.shape, y.dtype, "classes:", y.min().item(), "..", y.max().item())


100%|██████████| 26.4M/26.4M [00:02<00:00, 8.94MB/s]
100%|██████████| 29.5k/29.5k [00:00<00:00, 208kB/s]
100%|██████████| 4.42M/4.42M [00:01<00:00, 3.58MB/s]
100%|██████████| 5.15k/5.15k [00:00<00:00, 25.0MB/s]


batch x: torch.Size([128, 1, 28, 28]) torch.float32
batch y: torch.Size([128]) torch.int64 classes: 0 .. 9


## `FashionMNIST` модель (`LightningModule`)

Требования:
- `training_step`, `validation_step`, `test_step`
- метрики TorchMetrics: **F1**, **ROC AUC** на val/test
- логирование loss/метрик по эпохам
- подбор optimizer + lr-scheduler

### Почему такая архитектура
FashionMNIST - 28×28 (один канал), 10 классов, относительно простой датасет.
- **CNN**: локальные признаки (края, формы) важны, и CNN учится на порядок эффективнее.
- Архитектура: 3 блока Conv-BN-ReLU-Pool, потом FC.
- Dropout чтобы чуть уменьшить переобучение.

### Optimizer + Scheduler
- **AdamW**: дефолт для CNN, weight decay адекватная регуляризация.
- **ReduceLROnPlateau**: простая стратегия с EarlyStopping.


In [5]:
from torchmetrics.classification import MulticlassF1Score, MulticlassAUROC

class FashionMNIST(pl.LightningModule):
    def __init__(
        self,
        num_classes: int = 10,
        lr: float = 1e-3,
        weight_decay: float = 1e-4,
        lr_factor: float = 0.5,
        lr_patience: int = 2,
    ):
        super().__init__()
        self.save_hyperparameters()
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),

            nn.Conv2d(32, 64, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),

            nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d((1, 1)),
        )

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(p=0.2),
            nn.Linear(128, num_classes),
        )

        self.criterion = nn.CrossEntropyLoss()

        self.val_f1 = MulticlassF1Score(num_classes=num_classes, average="macro")
        self.val_auroc = MulticlassAUROC(num_classes=num_classes, average="macro")

        self.test_f1 = MulticlassF1Score(num_classes=num_classes, average="macro")
        self.test_auroc = MulticlassAUROC(num_classes=num_classes, average="macro")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.features(x)
        x = self.classifier(x)
        return x

    def _shared_step(self, batch: Tuple[torch.Tensor, torch.Tensor]):
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        probs = F.softmax(logits, dim=1)
        preds = torch.argmax(logits, dim=1)
        return loss, probs, preds, y

    def training_step(self, batch, batch_idx):
        loss, probs, preds, y = self._shared_step(batch)
        self.log("train_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, probs, preds, y = self._shared_step(batch)

        self.val_f1.update(preds, y)
        self.val_auroc.update(probs, y)

        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log("val_f1", self.val_f1, on_step=False, on_epoch=True, prog_bar=True)
        self.log("val_auroc", self.val_auroc, on_step=False, on_epoch=True, prog_bar=False)

    def test_step(self, batch, batch_idx):
        loss, probs, preds, y = self._shared_step(batch)

        self.test_f1.update(preds, y)
        self.test_auroc.update(probs, y)

        self.log("test_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log("test_f1", self.test_f1, on_step=False, on_epoch=True, prog_bar=True)
        self.log("test_auroc", self.test_auroc, on_step=False, on_epoch=True, prog_bar=False)

    def configure_optimizers(self):
        opt = torch.optim.AdamW(
            self.parameters(),
            lr=self.hparams.lr,
            weight_decay=self.hparams.weight_decay,
        )
        sch = torch.optim.lr_scheduler.ReduceLROnPlateau(
            opt,
            mode="min",
            factor=self.hparams.lr_factor,
            patience=self.hparams.lr_patience,
            min_lr=1e-6,
        )
        return {
            "optimizer": opt,
            "lr_scheduler": {
                "scheduler": sch,
                "monitor": "val_loss",
                "interval": "epoch",
                "frequency": 1,
            },
        }


Проверяем, что модель делает forward и выдает правильный shape.


In [6]:
model = FashionMNIST()
model.eval()
with torch.no_grad():
    out = model(x[:8])
print("logits:", out.shape)


logits: torch.Size([8, 10])


## Обучение через `Trainer`

Требования:
- `EarlyStopping`
- TensorBoard логирование
- интерпретация по графикам
- тест

Добавим также `ModelCheckpoint`, чтобы сохранить лучший чекпоинт (по `val_f1`).


In [7]:
logger = TensorBoardLogger(save_dir="tb_logs", name="fashionmnist_hw2")

callbacks = [
    EarlyStopping(monitor="val_loss", mode="min", patience=5),
    ModelCheckpoint(monitor="val_f1", mode="max", save_top_k=1, filename="best-{epoch:02d}-{val_f1:.4f}"),
    LearningRateMonitor(logging_interval="epoch"),
]

trainer = pl.Trainer(
    max_epochs=30,
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    devices=1,
    logger=logger,
    callbacks=callbacks,
    deterministic=True,
    log_every_n_steps=50,
    enable_checkpointing=True,
)

dm = FashionMNISTDataModule(batch_size=128, num_workers=4, val_size=5000, seed=SEED)
model = FashionMNIST(lr=1e-3, weight_decay=1e-4)

trainer.fit(model, datamodule=dm)


INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Output()

INFO: `Trainer.fit` stopped: `max_epochs=30` reached.
INFO:lightning.pytorch.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=30` reached.


## Тестирование лучшего чекпоинта

Lightning автоматически хранит путь до лучшего чекпоинта в `trainer.checkpoint_callback.best_model_path`.


In [8]:
best_ckpt = trainer.checkpoint_callback.best_model_path
best_score = trainer.checkpoint_callback.best_model_score
print("Best checkpoint:", best_ckpt)
print("Best val_f1:", best_score)

test_results = trainer.test(model=None, datamodule=dm, ckpt_path="best")
test_results


Best checkpoint: tb_logs/fashionmnist_hw2/version_0/checkpoints/best-epoch=25-val_f1=0.9204.ckpt
Best val_f1: tensor(0.9204, device='cuda:0')


INFO: Restoring states from the checkpoint path at tb_logs/fashionmnist_hw2/version_0/checkpoints/best-epoch=25-val_f1=0.9204.ckpt
INFO:lightning.pytorch.utilities.rank_zero:Restoring states from the checkpoint path at tb_logs/fashionmnist_hw2/version_0/checkpoints/best-epoch=25-val_f1=0.9204.ckpt
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: Loaded model weights from the checkpoint at tb_logs/fashionmnist_hw2/version_0/checkpoints/best-epoch=25-val_f1=0.9204.ckpt
INFO:lightning.pytorch.utilities.rank_zero:Loaded model weights from the checkpoint at tb_logs/fashionmnist_hw2/version_0/checkpoints/best-epoch=25-val_f1=0.9204.ckpt


Output()

[{'test_loss': 0.25107261538505554,
  'test_f1': 0.9108181595802307,
  'test_auroc': 0.9942920804023743}]