# PyTorch Yatt

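Yatt is yet another PyTorch trainer. This notebook uses it to train a convolutional autoencoder on CIFAR-10, CelebA, or FGVC-Aircraft and to log sample reconstructions during validation.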

## Imports

In [6]:
import os
from typing import Literal

import numpy as np
import torch
from torch import Size, Tensor, nn
from torch.utils import data
from torchvision import datasets as vdata
from torchvision import transforms as vtransforms
from torchvision import utils as vutils

from yatt import DataLoaderConfig, OptimizerConfig, HParams, Trainer

## Set Up the Architecture

In [7]:
class Residual(nn.Module):
    """Wrap a module with an identity skip connection: ``forward(x) = x + child(x)``.

    ``child`` must return a tensor that can be added to its input (same shape, or
    broadcast-compatible with it).
    """

    def __init__(
        self,
        child: nn.Module,
    ) -> None:
        super().__init__()
        # Registered as a submodule, so its parameters live under the "_child." prefix
        # in the state dict.
        self._child = child

    def forward(self, x: Tensor) -> Tensor:
        # Invoke the module (not ``.forward``) so the child's registered hooks run.
        return x + self._child(x)


class AutoEncoder(nn.Module):
    """Convolutional autoencoder whose stages each end in a residual 3x3 conv.

    The encoder has one stage per entry of ``hidden_dims``; each stage halves the
    spatial resolution with a stride-2 conv. The decoder mirrors the encoder with
    stride-2 transposed convs that double the resolution, and ends in ``Tanh`` so
    outputs lie in ``[-1, 1]``.

    Args:
        in_shape: ``(channels, height, width)`` of the input images. Only the channel
            count is used here; height and width should be divisible by
            ``2 ** len(hidden_dims)`` for the reconstruction to match the input size.
        hidden_dims: output channels of each encoder stage, in order.
        latent_dim: accepted for configuration compatibility but currently unused —
            the bottleneck has ``hidden_dims[-1]`` channels.
    """

    def __init__(
        self,
        in_shape: tuple[int, int, int],
        hidden_dims: list[int],
        latent_dim: int,
    ) -> None:
        super().__init__()

        channels = [in_shape[0], *hidden_dims]
        self.encoder = nn.Sequential(*[
            self._stage(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1), c_out)
            for c_in, c_out in zip(channels[:-1], channels[1:])
        ])

        reversed_channels = channels[::-1]
        self.decoder = nn.Sequential(
            *[
                self._stage(
                    nn.ConvTranspose2d(
                        c_in, c_out, kernel_size=3, stride=2, padding=1, output_padding=1,
                    ),
                    c_out,
                )
                for c_in, c_out in zip(reversed_channels[:-1], reversed_channels[1:])
            ],
            nn.Tanh(),
        )

    @staticmethod
    def _stage(resample: nn.Module, channels: int) -> nn.Sequential:
        """One stage: resampling conv -> SELU -> residual 3x3 conv -> SELU.

        The submodule layout (indices 0..3, residual child under ``_child``) matches the
        original inline construction, so existing checkpoints keep loading.
        """
        return nn.Sequential(
            resample,
            nn.SELU(),
            Residual(nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1)),
            nn.SELU(),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Encode ``x`` and return its reconstruction."""
        z = self.encoder(x)
        return self.decoder(z)

## Set Up the Trainer

In [8]:
class PooParams(HParams):
    """Hyperparameters for the autoencoder training run in this notebook."""

    # Name of the torchvision dataset that MyTrainer.configure_data_loaders loads.
    dataset: Literal["cifar10", "celeba", "fgvc"]
    # (channels, height, width); height/width are the resize + center-crop target.
    img_shape: tuple[int, int, int]
    # Output channels of each encoder stage; the decoder mirrors them in reverse.
    hidden_dims: list[int]
    # Forwarded to AutoEncoder. NOTE(review): the AutoEncoder in this notebook does not
    # appear to use it — confirm before relying on it.
    latent_dim: int
    # Learning rate intended for the optimizer.
    learning_rate: float
    # Batch size for the data loaders.
    batch_size: int
    # DataLoader worker processes: half of os.cpu_count(), or 0 when the CPU count is
    # unknown (cpu_count() returned None). Evaluated once, when the class is defined.
    num_workers: int = (os.cpu_count() or 0) // 2


class MyTrainer(Trainer[PooParams, AutoEncoder]):

    @classmethod
    def configure_model(cls, hp: PooParams) -> nn.Module:
        model = AutoEncoder(
            in_shape=hp.img_shape,
            hidden_dims=hp.hidden_dims,
            latent_dim=hp.latent_dim,
        )

        return model

    @classmethod
    def configure_optimizer(cls, hp: PooParams, model: AutoEncoder) -> OptimizerConfig:
        optimizer = torch.optim.Adam(model.parameters())
        lr_scheduler = torch.optim.lr_scheduler.LinearLR(optimizer)
        return OptimizerConfig(optimizer, lr_scheduler)

    @classmethod
    def configure_data_loaders(cls, hp: PooParams) -> DataLoaderConfig:
        transform = vtransforms.Compose([
            vtransforms.Resize(hp.img_shape[-2:]),
            vtransforms.CenterCrop(hp.img_shape[-2:]),
            vtransforms.ToTensor(),
            vtransforms.Normalize(0.5, 0.5),
        ])
        match hp.dataset:
            case "cifar10":
                train_ds = vdata.CIFAR10("../data", train=True, transform=transform, download=True)
                train_ds, val_ds = data.random_split(train_ds, [0.9, 0.1])
                test_ds = vdata.CIFAR10("../data", train=False, transform=transform, download=True)
            case "celeba":
                train_ds = vdata.CelebA("../data", split="train", transform=transform, download=True)
                val_ds = vdata.CelebA("../data", split="valid", transform=transform, download=True)
                test_ds = vdata.CelebA("../data", split="test", transform=transform, download=True)
            case "fgvc":
                train_ds = vdata.FGVCAircraft("../data", "train", transform=transform, download=True)
                val_ds = vdata.FGVCAircraft("../data", "val", transform=transform, download=True)
                test_ds = vdata.FGVCAircraft("../data", "test", transform=transform, download=True)
            case _:
                raise ValueError

        train_dl = data.DataLoader(train_ds,
                                   shuffle=True,
                                   batch_size=hp.batch_size,
                                   pin_memory=True,
                                   num_workers=hp.num_workers,
                                   persistent_workers=hp.num_workers > 0)
        val_dl = data.DataLoader(val_ds,
                                 batch_size=hp.batch_size,
                                 pin_memory=True,
                                 num_workers=hp.num_workers,
                                 persistent_workers=hp.num_workers > 0)
        test_dl = data.DataLoader(test_ds,
                                 pin_memory=True,
                                 num_workers=hp.num_workers,
                                 persistent_workers=hp.num_workers > 0)
        return DataLoaderConfig(
            train=train_dl,
            val=val_dl,
            test=test_dl,
        )

    def get_loss(self, x: Tensor) -> Tensor:
        xhat = self.model(x)
        loss = torch.nn.functional.mse_loss(x, xhat)
        return loss

    def train_step(self, batch: list[Tensor], batch_idx: int) -> Tensor:
        return self.get_loss(batch[0])

    def val_step(self, batch: list[Tensor], batch_idx: int) -> Tensor:
        return self.get_loss(batch[0])

    def test_step(self, batch: list[Tensor], batch_idx: int) -> Tensor:
        return self.get_loss(batch[0])


    def train_epoch_begin(self) -> None:
        pass
    def train_epoch_end(self) -> None:
        pass

    def val_epoch_begin(self) -> None:
        pass
    def val_epoch_end(self) -> None:
        if self.data_loaders.val == None:
            return
        x = next(iter(self.data_loaders.val))[0][:8].to(self.device)
        y = self.model(x)
        grid = vutils.make_grid(torch.cat([x, y]), normalize=True)
        self.log_image("val/sample", grid, self.epoch)
        self.log_graph(x)


In [9]:
%%html
<!-- Fix ipywidget styling (presumably for the trainer's progress-bar widgets) so widget text stays readable on both light and dark notebook themes. -->

<style>
    /* White text blended with mix-blend-mode: difference renders as the inverse of the
       background behind it, so it contrasts with light and dark backgrounds alike. */
    html .widget-html {
        color: white !important;
        mix-blend-mode: difference;
    }

    /* Make the background drawn behind ipywidget output transparent. */
    html .cell-output-ipywidget-background {
        background: transparent !important;
    }
</style>

## Execution

In [10]:
# Fixed seed so weight initialization and data shuffling are reproducible on a fresh kernel.
SEED = 0
torch.manual_seed(SEED)

hp = PooParams(
    dataset="celeba",
    img_shape=(3, 64, 64),
    hidden_dims=[16, 32, 64, 128],
    latent_dim=512,
    learning_rate=1e-3,
    batch_size=512,
)

# Fall back to CPU so the notebook still runs on machines without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

trainer = MyTrainer(
    f"auto_encoder.{hp.dataset}.img_shape={hp.img_shape}.latent_dim={hp.latent_dim}",
    save_best_count=5,
    max_epochs=1000,
    log_interval=200,
    device=device,
)

# Resume from this checkpoint when it exists locally; otherwise start a fresh run
# configured from `hp`. Set to None to always start fresh.
RESUME_CHECKPOINT: str | None = (
    "runs/auto_encoder.celeba.img_shape=(3, 64, 64).latent_dim=512/"
    "2023-03-14@14:33:31/checkpoints/latest.loss=0.018214423209428787.epoch=1.ckpt"
)
if RESUME_CHECKPOINT is not None and os.path.exists(RESUME_CHECKPOINT):
    trainer.configure_checkpoint(RESUME_CHECKPOINT)
else:
    trainer.configure(hp)

trainer.train()

┌─────────────────────────────────────┐
│             AutoEncoder             │
└─────────────────────────────────────┘
┌─────────────────────────────────────┐
│               HParams               │
├─────────────────┬───────────────────┤
│ dataset         │ celeba            │
│ img_shape       │ (3, 64, 64)       │
│ hidden_dims     │ [16, 32, 64, 128] │
│ latent_dim      │ 512               │
│ learning_rate   │ 0.001             │
│ batch_size      │ 512               │
│ num_workers     │ 8                 │
└─────────────────┴───────────────────┘
┌─────────────────────────────────────┐
│                Stats                │
├─────────────────┬───────────────────┤
│ Parameter Count │ 32                │
│ Parameter Size  │ 1.7MiB            │
│ Buffer Count    │ 0                 │
│ Buffer Size     │ 0.0B              │
│ Total Size      │ 1.7MiB            │
└─────────────────┴───────────────────┘


Train 2:   0%|          | 0/318 [00:00<?, ?it/s]

KeyboardInterrupt: 