In [1]:
import torch
import torch.nn.functional as F  # noqa: N812
from omegaconf import OmegaConf
from torch import Tensor, nn

from datasets import load_cifar10
from models import BaseModel, FFCBlock
from train import train

In [2]:
class CustomModel(BaseModel):
    """Small image classifier: a conv stem, three FFC blocks and a linear head.

    ``forward`` returns the training loss and ``validate_batch`` returns loss
    plus accuracy; the keys of those dicts are the ones listed in
    ``train_keys`` and ``val_keys`` respectively.
    """

    # Metric names produced by forward() / validate_batch().
    # NOTE(review): presumably consumed by BaseModel / train() for logging — confirm.
    train_keys = ("loss",)
    val_keys = ("loss", "acc")

    def __init__(self):
        """Create the stem, the FFC stack and the classification head."""
        super().__init__()

        # Stem: 3 -> 16 channels, followed by a 2x2 max-pool with stride 2.
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1),
            nn.BatchNorm2d(num_features=16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        # (ratio_in, ratio_out) for each FFC block, in execution order.
        # NOTE(review): the ratios presumably describe the local/global channel
        # split handled by FFCBlock — confirm against models.FFCBlock.
        ffc_ratios = (
            ((1.0, 0.0), (0.5, 0.5)),
            ((0.5, 0.5), (0.5, 0.5)),
            ((0.5, 0.5), (1.0, 0.0)),
        )
        self.layer = nn.Sequential(
            *(
                FFCBlock(
                    16,
                    16,
                    16,
                    ratio_in=block_in,
                    ratio_out=block_out,
                    stride=2,
                    enable_lfu=True,
                )
                for block_in, block_out in ffc_ratios
            )
        )

        # Head: flatten everything but the batch axis, then map 64 features
        # to 10 class scores.
        self.layer_final = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features=64, out_features=10),
        )

    def get_output(self, x: Tensor):
        """Run the full network on ``x`` and return the 10-way class scores."""
        stem_out = self.conv1(x)
        # The FFC stack returns a pair; only its first element is used.
        # The head expects 64 flattened features per sample (see Linear(64, 10));
        # presumably 16 channels x 2 x 2 for 32x32 inputs — TODO confirm.
        features, _ = self.layer(stem_out)
        return self.layer_final(features)

    def forward(self, x: Tensor, target: Tensor):
        """Return ``{"loss": ...}`` with the cross-entropy loss as a tensor."""
        logits = self.get_output(x)
        return {"loss": F.cross_entropy(logits, target)}

    @torch.inference_mode()
    def validate_batch(self, x: Tensor, target: Tensor):
        """Evaluate one batch under inference mode.

        Returns a dict with the cross-entropy ``loss`` and the ``acc``
        (fraction of samples whose argmax prediction equals ``target``),
        both as Python floats.
        """
        logits = self.get_output(x)
        batch_loss = F.cross_entropy(logits, target).item()

        # Element-wise match between predicted class and label, averaged.
        hits = logits.argmax(dim=1).eq(target)
        batch_acc = hits.float().mean().item()

        return {"loss": batch_loss, "acc": batch_acc}

In [3]:
# Both splits are read from the same local data directory.
train_set = load_cifar10(kind="train", root="./data")
val_set = load_cifar10(kind="val", root="./data")

In [4]:
# Run settings (including the target device) come from the YAML config.
config = OmegaConf.load("./configs/simple_ffc.yaml")
# Instantiate the network first, then move it to the configured device.
model = CustomModel()
model = model.to(config.device)

In [None]:
# Start training; the progress bars below report train/loss, val/loss and
# val/acc for each epoch.
train(model, config, train_set, val_set)

Epoch  0: 100%|██████████| 313/313 [00:13<00:00, 23.32it/s, train/loss=2.296, val/loss=2.213, val/acc=0.169]
Epoch  1: 100%|██████████| 313/313 [00:13<00:00, 23.43it/s, train/loss=2.119, val/loss=2.084, val/acc=0.215]
Epoch  2: 100%|██████████| 313/313 [00:13<00:00, 22.97it/s, train/loss=2.010, val/loss=1.989, val/acc=0.256]
Epoch  3: 100%|██████████| 313/313 [00:12<00:00, 24.10it/s, train/loss=1.923, val/loss=1.909, val/acc=0.288]
Epoch  4: 100%|██████████| 313/313 [00:12<00:00, 25.60it/s, train/loss=1.851, val/loss=1.845, val/acc=0.309]
Epoch  5: 100%|██████████| 313/313 [00:12<00:00, 25.34it/s, train/loss=1.799, val/loss=1.802, val/acc=0.327]
Epoch  6: 100%|██████████| 313/313 [00:12<00:00, 24.75it/s, train/loss=1.758, val/loss=1.762, val/acc=0.344]
Epoch  7: 100%|██████████| 313/313 [00:12<00:00, 24.30it/s, train/loss=1.721, val/loss=1.724, val/acc=0.353]
Epoch  8: 100%|██████████| 313/313 [00:12<00:00, 25.57it/s, train/loss=1.686, val/loss=1.691, val/acc=0.364]
Epoch  9: 100%|████