In [14]:
import torch
import torch.nn as nn
from typing import Tuple, Callable
from torch.utils.data import DataLoader
import torchvision
import numpy as np
from torchsummary import summary
from tqdm import tqdm
from datetime import datetime
import os
from enum import Enum

In [15]:
class Mode(Enum):
    TRAIN = 0
    LOAD = 1


MODE = Mode.LOAD

In [16]:
def get_device() -> torch.device:
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    x = torch.ones(1, device=device)

    return device

In [17]:
def get_data(
    batch_size: int, resolution: Tuple[int, int]
) -> Tuple[DataLoader, DataLoader]:
    transform = torchvision.transforms.Compose(
        [
            torchvision.transforms.Resize(resolution),  # 3:2 aspect ratio
            torchvision.transforms.Grayscale(num_output_channels=1),
            torchvision.transforms.ToTensor(),
        ]
    )

    data_dir = "data"
    dataset = torchvision.datasets.ImageFolder(data_dir, transform)
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(
        dataset, [train_size, val_size]
    )
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    return train_loader, val_loader

In [18]:
def validate_accuracy(
    model: nn.Module,
    loader: DataLoader,
    device: torch.device,
    batch_lim: int = 0,
    silent: bool = False,
) -> float:

    model.eval()

    correct_count = 0
    total_count = 0

    for i, (data, labels) in tqdm(enumerate(iter(loader)), disable=silent):

        data, labels = data.to(device), labels.to(device)

        with torch.no_grad():

            output = model(data)

        _, predictions = torch.max(output, 1)
        correct_count += torch.sum(predictions == labels).item()
        total_count += labels.size(0)
        if batch_lim > 0 and i >= batch_lim:
            break

    return correct_count / total_count

In [19]:
def train(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    train_loader: DataLoader,
    val_loader: DataLoader,
    device: torch.device,
    epochs: int,
) -> None:
    for epoch in range(epochs):
        progress_bar = tqdm(
            enumerate(iter(train_loader)),
            total=len(train_loader),
            desc=f"Epoch {epoch+1}/{epochs}",
        )
        model.train()
        for batchIdx, (data, labels) in progress_bar:
            data, labels = data.to(device), labels.to(device)
            optimizer.zero_grad()
            predictions = model(data)
            batch_loss = loss_fn(predictions, labels)
            batch_loss.backward()
            optimizer.step()
            if batchIdx == 0:
                batch_losses = batch_loss.detach().unsqueeze(0)
            else:
                batch_losses = torch.cat(
                    (batch_losses, batch_loss.detach().unsqueeze(0))
                )
            progress_bar.set_postfix(
                {"Mean batch loss": torch.mean(batch_losses).item()}
            )
        val_acc = validate_accuracy(model, val_loader, device, batch_lim=1, silent=True)
        train_acc = validate_accuracy(
            model, train_loader, device, batch_lim=1, silent=True
        )
        print(f"Val Acc: {val_acc * 100:.2f}% Train Acc: {train_acc * 100:.2f}%")

In [20]:
def test(net, loader, device) -> float:
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            output = net(data)
            _, prediction = torch.max(output, 1)
            correct += torch.sum(prediction == target).item()
            total += target.size(0)
    return 100 * correct / total

In [21]:
device = get_device()
print(f"Device: {device}")


resolution = (150, 100)
batch_size = 256


train_loader, val_loader = get_data(batch_size, resolution)

Device: cuda


In [None]:
MODEL_FILE_NAME = "rps2.pt"

if MODE == Mode.LOAD:
    model = torch.load(MODEL_FILE_NAME, weights_only=False).to(device)
    summary(model, (1, *resolution), batch_size, device.type)
elif MODE == Mode.TRAIN:
    model = nn.Sequential(
        nn.Conv2d(1, 6, 5),
        nn.ReLU(),
        nn.MaxPool2d(2, 2),
        nn.Conv2d(6, 16, 5),
        nn.ReLU(),
        nn.MaxPool2d(2, 2),
        nn.Flatten(),
        nn.Linear(11968, 128),
        nn.ReLU(),
        nn.Linear(128, 84),
        nn.ReLU(),
        nn.Linear(84, 4),
    ).to(device)
    summary(model, (1, *resolution), batch_size, device.type)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), weight_decay=0.03)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [256, 6, 146, 96]             156
              ReLU-2          [256, 6, 146, 96]               0
         MaxPool2d-3           [256, 6, 73, 48]               0
            Conv2d-4          [256, 16, 69, 44]           2,416
              ReLU-5          [256, 16, 69, 44]               0
         MaxPool2d-6          [256, 16, 34, 22]               0
           Flatten-7               [256, 11968]               0
            Linear-8                 [256, 128]       1,532,032
              ReLU-9                 [256, 128]               0
           Linear-10                  [256, 84]          10,836
             ReLU-11                  [256, 84]               0
           Linear-12                   [256, 4]             340
Total params: 1,545,780
Trainable params: 1,545,780
Non-trainable params: 0
---------------------------

In [23]:
summary(model, (1, *resolution), batch_size, device.type)


if MODE == Mode.TRAIN:

    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), weight_decay=0.03)
    train(model, optimizer, loss_fn, train_loader, val_loader, device, epochs=20)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [256, 6, 146, 96]             156
              ReLU-2          [256, 6, 146, 96]               0
         MaxPool2d-3           [256, 6, 73, 48]               0
            Conv2d-4          [256, 16, 69, 44]           2,416
              ReLU-5          [256, 16, 69, 44]               0
         MaxPool2d-6          [256, 16, 34, 22]               0
           Flatten-7               [256, 11968]               0
            Linear-8                 [256, 128]       1,532,032
              ReLU-9                 [256, 128]               0
           Linear-10                  [256, 84]          10,836
             ReLU-11                  [256, 84]               0
           Linear-12                   [256, 4]             340
Total params: 1,545,780
Trainable params: 1,545,780
Non-trainable params: 0
---------------------------

In [24]:
acc = test(model, val_loader, device)
print(f"Val Acc: {acc:.2f}")

Val Acc: 97.38


In [25]:
val_acc = validate_accuracy(model, val_loader, device)
train_acc = validate_accuracy(model, train_loader, device)
print(f"Train Acc: {train_acc * 100:.2f}%, Val Acc: {val_acc * 100:.2f}%")

5it [00:05,  1.13s/it]
18it [00:21,  1.20s/it]

Train Acc: 96.71%, Val Acc: 97.38%





In [26]:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
dir = os.path.join("run_artifacts", timestamp)
os.makedirs(dir, exist_ok=True)
path = os.path.join(dir, MODEL_FILE_NAME)
torch.save(model, path)