In [79]:
import torch
import torch.nn as nn
from typing import Tuple, Callable
from torch.utils.data import DataLoader
import torchvision
import numpy as np

In [80]:
def get_device() -> torch.device:
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    x = torch.ones(1, device=device)

    return device

In [81]:
def get_data(batch_size: int) -> Tuple[DataLoader, DataLoader]:
    transform = torchvision.transforms.Compose(
        [
            torchvision.transforms.Resize((120, 80)),  # 3:2 aspect ratio
            torchvision.transforms.Grayscale(num_output_channels=1),
            torchvision.transforms.ToTensor(),
        ]
    )

    data_dir = "data"
    dataset = torchvision.datasets.ImageFolder(data_dir, transform)
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=len(val_dataset), shuffle=False)
    return train_loader, val_loader

In [82]:
def train(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    loader: DataLoader,
    device: torch.device,
    epochs: int
) -> None:
    model.train()
    for epoch in range(epochs):
        for batchIdx, (data, labels) in enumerate(iter(loader)):
            data, labels = data.to(device), labels.to(device)
            predictions = model(data)
            batch_loss = loss_fn(predictions, labels)
            optimizer.step()
            optimizer.zero_grad()


In [83]:
def validate_accuracy(model: nn.Module, loader: DataLoader, device: torch.device) -> float:
    model.eval()
    is_correct = torch.empty(0)
    for i, (data, labels) in enumerate(iter(loader)):
        data, labels = data.to(device), labels.to(device)
        with torch.no_grad():
            prediction = model(data)
        max_values, argmaxes = prediction.max(-1)
        is_correct = torch.cat((is_correct, argmaxes == labels))
    return np.mean(is_correct.cpu().numpy())

In [84]:
device = get_device()
train_loader, val_loader = get_data(256)
model = nn.Sequential(
    # 1x120x80
    nn.Conv2d(1, 4, 3),
    # 4x118x78
    nn.ReLU(),
    nn.Conv2d(4, 8, 3),
    # 8x116x76
    nn.ReLU(),
    nn.MaxPool2d(3),
    # 8x38x25
    nn.Flatten(),
    # 7600
    nn.Linear(7600, 128),
    nn.ReLU(),
    nn.Linear(128, 4),
    nn.Softmax()
).to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters())
train(model, optimizer, loss_fn, train_loader, device, 1)

In [85]:
acc = validate_accuracy(model, val_loader, device)
print(f"Acc: {acc * 100:.2f}%")

Acc: 26.53%
