In [45]:
import os
import shutil
from datetime import datetime
from enum import Enum
from typing import Callable, Optional, Tuple

import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
from torchsummary import summary
from tqdm import tqdm

In [46]:
class Mode(Enum):
    """Selects whether the notebook trains a new model or loads a saved one."""

    # Build a fresh network and run the training loop.
    TRAIN = 0
    # Read an already-saved model from disk and skip training.
    LOAD = 1

class Classes(Enum):
    """Named integer class ids."""

    # NOTE(review): presumably these mirror the label indices that
    # torchvision's ImageFolder derives from the sub-folder names of the data
    # directory (it sorts folder names) -- confirm against
    # ``dataset.class_to_idx`` before relying on this mapping.
    PAPER = 0
    OTHER = 1
    ROCK = 2
    SCISSORS = 3

# Active mode for this run: Mode.LOAD reads a saved model from disk,
# Mode.TRAIN builds a new network and trains it.
MODE = Mode.LOAD

In [47]:
def get_device() -> torch.device:
    """Pick the best available compute device.

    Preference order is CUDA, then Apple MPS, then CPU.

    Returns:
        The selected ``torch.device``.

    Raises:
        RuntimeError: If the selected backend reports itself as available but
            a tensor cannot actually be allocated on it.
    """
    # Builds of torch without an MPS backend have no ``torch.backends.mps``
    # attribute at all, so look it up defensively instead of assuming it.
    mps_backend = getattr(torch.backends, "mps", None)

    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif mps_backend is not None and mps_backend.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    # Smoke test: allocate one element so a broken backend fails here rather
    # than deep inside the first forward pass. The tensor itself is not needed.
    torch.ones(1, device=device)

    return device

In [48]:
def get_data(
    batch_size: int,
    resolution: Tuple[int, int],
    data_dir: str = "data",
    seed: Optional[int] = 42,
) -> Tuple[DataLoader, DataLoader]:
    """Build the training and validation loaders from an image folder.

    Every image is resized to ``resolution``, converted to a single grayscale
    channel and turned into a tensor. 80% of the samples go to the training
    split and the remainder to the validation split.

    Args:
        batch_size: Batch size used by both loaders.
        resolution: Target size handed to ``torchvision.transforms.Resize``
            (that transform interprets a pair as ``(height, width)``).
        data_dir: Root directory read by ``torchvision.datasets.ImageFolder``.
        seed: Seed for the train/validation split. A fixed seed keeps the
            split identical across kernel restarts, so a model saved in one
            session is not later "validated" on images it was trained on.
            Pass ``None` to draw the split from the global RNG instead.

    Returns:
        ``(train_loader, val_loader)``. The training loader shuffles; the
        validation loader keeps dataset order.
    """
    transform = torchvision.transforms.Compose(
        [
            torchvision.transforms.Resize(resolution),
            torchvision.transforms.Grayscale(num_output_channels=1),
            torchvision.transforms.ToTensor(),
        ]
    )

    dataset = torchvision.datasets.ImageFolder(data_dir, transform)
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size

    # Only pass a generator when a seed was requested, so seed=None falls back
    # to random_split's own default RNG.
    split_kwargs = {}
    if seed is not None:
        split_kwargs["generator"] = torch.Generator().manual_seed(seed)
    train_dataset, val_dataset = torch.utils.data.random_split(
        dataset, [train_size, val_size], **split_kwargs
    )

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    # No shuffling: downstream code maps result rows back to dataset positions.
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    return train_loader, val_loader

In [49]:
def get_predictions_and_labels(
    model: nn.Module,
    loader: DataLoader,
    device: torch.device,
    batch_lim: int = 0,
    silent: bool = False,
) -> torch.Tensor:
    """Run ``model`` over ``loader`` and pair each prediction with its label.

    Args:
        model: Network to evaluate. It is switched to eval mode and left there.
        loader: Iterable yielding ``(data, labels)`` batches.
        device: Device the batches are moved to before the forward pass.
        batch_lim: Maximum number of batches to evaluate; ``0`` or a negative
            value evaluates every batch.
        silent: When True, no tqdm progress bar is shown.

    Returns:
        A CPU tensor of shape ``(N, 2)`` whose rows are ``[prediction, label]``,
        in the order the loader produced the samples. The shape is ``(0, 2)``
        when the loader yields no batch.
    """
    model.eval()

    batches = enumerate(loader)
    if not silent:
        batches = tqdm(batches)

    # Collect per-batch results and concatenate once at the end instead of
    # re-allocating a growing tensor on every iteration.
    collected = []
    with torch.no_grad():
        for i, (data, labels) in batches:
            data, labels = data.to(device), labels.to(device)
            output = model(data)
            _, predictions = torch.max(output, 1)
            collected.append(torch.stack((predictions, labels), dim=1).cpu())
            # ``i`` is zero-based, so ``i + 1`` batches are done at this point.
            if batch_lim > 0 and i + 1 >= batch_lim:
                break

    if not collected:
        return torch.empty((0, 2), dtype=torch.long)
    return torch.cat(collected, dim=0)


def get_correct_percent(results: torch.Tensor) -> float:
    """Return the percentage of rows whose prediction equals the label.

    Args:
        results: Tensor of shape ``(N, 2)`` with rows ``[prediction, label]``.

    Returns:
        Accuracy in the range 0-100. An empty ``results`` yields ``0.0``
        rather than a division-by-zero error.
    """
    sample_count = len(results)
    if sample_count == 0:
        return 0.0
    correct = (results[:, 0] == results[:, 1]).sum().item()
    return correct / sample_count * 100


def save_miscategorized_images(
    results: torch.Tensor,
    val_loader: DataLoader,
    path: str,
    max_img_count: int = 20,
) -> None:
    """Copy the source files of misclassified samples into ``path/misclassified``.

    Row ``k`` of ``results`` is taken to describe the ``k``-th sample of
    ``val_loader.dataset``, so ``results`` must come from a loader that does
    not shuffle.

    Args:
        results: Tensor of shape ``(N, 2)`` with rows ``[prediction, label]``.
        val_loader: Loader whose ``dataset`` is either a dataset exposing
            ``samples`` (a list of ``(file_path, label)`` pairs) or a subset
            wrapper exposing ``indices`` and the wrapped ``dataset``.
        path: Directory under which the ``misclassified`` folder is created.
        max_img_count: Upper bound on the number of files copied.
    """
    misclassified_mask = results[:, 0] != results[:, 1]
    misclassified_indices = misclassified_mask.nonzero(as_tuple=True)[0][:max_img_count]

    # Create the output directory even when nothing was misclassified.
    misclassified_dir = os.path.join(path, "misclassified")
    os.makedirs(misclassified_dir, exist_ok=True)

    dataset = val_loader.dataset

    # A subset wrapper stores positions into the wrapped dataset; a plain
    # dataset is indexed directly and owns ``samples`` itself.
    if hasattr(dataset, "indices"):
        indices_map = dataset.indices
        samples = dataset.dataset.samples
    else:
        indices_map = range(len(dataset))
        samples = dataset.samples

    for idx, mis_idx in enumerate(misclassified_indices.tolist()):
        img_path, _ = samples[indices_map[mis_idx]]
        pred = int(results[mis_idx, 0].item())
        actual = int(results[mis_idx, 1].item())
        dst = os.path.join(
            misclassified_dir,
            f"misclassified_{idx}_pred{pred}_actual{actual}_{os.path.basename(img_path)}"
        )
        shutil.copy(img_path, dst)

In [50]:
def train(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    train_loader: DataLoader,
    val_loader: DataLoader,
    device: torch.device,
    epochs: int,
) -> None:
    """Train ``model`` in place and print quick accuracy estimates per epoch.

    Args:
        model: Network to optimize.
        optimizer: Optimizer already bound to ``model``'s parameters.
        loss_fn: Maps ``(predictions, labels)`` to a scalar loss tensor.
        train_loader: Loader providing the training batches.
        val_loader: Loader used for the per-epoch validation estimate.
        device: Device the batches are moved to.
        epochs: Number of passes over ``train_loader``.

    After every epoch the validation and training accuracy are estimated with
    ``batch_lim=1`` (a small sample, not the full split) and printed.
    """
    for epoch in range(epochs):
        progress_bar = tqdm(
            enumerate(train_loader),
            total=len(train_loader),
            desc=f"Epoch {epoch+1}/{epochs}",
        )
        # The accuracy checks below switch to eval mode, so re-enable training
        # mode at the start of every epoch.
        model.train()
        # Running sum of the batch losses; avoids concatenating and
        # re-averaging a growing loss tensor on every step.
        loss_sum = 0.0
        for batchIdx, (data, labels) in progress_bar:
            data, labels = data.to(device), labels.to(device)
            optimizer.zero_grad()
            predictions = model(data)
            batch_loss = loss_fn(predictions, labels)
            batch_loss.backward()
            optimizer.step()
            loss_sum += batch_loss.item()
            progress_bar.set_postfix(
                {"Mean batch loss": loss_sum / (batchIdx + 1)}
            )
        val_acc = get_correct_percent(
            get_predictions_and_labels(model, val_loader, device, batch_lim=1, silent=True)
        )
        train_acc = get_correct_percent(
            get_predictions_and_labels(model, train_loader, device, batch_lim=1, silent=True)
        )
        print(f"Val Acc: {val_acc:.2f}% Train Acc: {train_acc:.2f}%")

In [51]:
def test(net, loader, device) -> float:
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            output = net(data)
            _, prediction = torch.max(output, 1)
            correct += torch.sum(prediction == target).item()
            total += target.size(0)
    return 100 * correct / total

In [52]:
# Settings shared by the cells below.
batch_size = 256
resolution = (150, 100)  # size handed to the Resize transform inside get_data

# Pick the compute device (CUDA, then MPS, then CPU) and report the choice.
device = get_device()
print(f"Device: {device}")

# Build the training / validation loaders from the image folder.
train_loader, val_loader = get_data(batch_size=batch_size, resolution=resolution)

Device: cuda


In [None]:
MODEL_FILE_NAME = "rps2.pt"

if MODE == Mode.LOAD:
    # NOTE(security): weights_only=False unpickles arbitrary Python objects, so
    # only load checkpoint files you created yourself or otherwise trust.
    # map_location lets a checkpoint saved on one device type (e.g. CUDA) be
    # loaded on a machine that only has another (e.g. CPU).
    model = torch.load(
        MODEL_FILE_NAME, map_location=device, weights_only=False
    ).to(device)
    summary(model, (1, *resolution), batch_size, device.type)
elif MODE == Mode.TRAIN:
    # Spatial size left after two rounds of (5x5 unpadded conv -> 2x2 max-pool):
    # each conv trims 4 pixels per side length, each pool halves (floor).
    # For resolution (150, 100) this gives 34 x 22, i.e. 16 * 34 * 22 = 11968
    # flattened features.
    feat_h, feat_w = (((side - 4) // 2 - 4) // 2 for side in resolution)
    model = nn.Sequential(
        nn.Conv2d(1, 6, 5),
        nn.ReLU(),
        nn.MaxPool2d(2, 2),
        nn.Conv2d(6, 16, 5),
        nn.ReLU(),
        nn.MaxPool2d(2, 2),
        nn.Flatten(),
        nn.Linear(16 * feat_h * feat_w, 128),
        nn.ReLU(),
        nn.Linear(128, 84),
        nn.ReLU(),
        # One output logit per member of the Classes enum.
        nn.Linear(84, len(Classes)),
    ).to(device)
    summary(model, (1, *resolution), batch_size, device.type)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), weight_decay=0.03)
else:
    # Fail loudly instead of leaving ``model`` undefined for the later cells.
    raise ValueError(f"Unsupported MODE: {MODE!r}")

RuntimeError: Error(s) in loading state_dict for Net:
	Missing key(s) in state_dict: "conv_layers.12.weight", "conv_layers.12.bias", "conv_layers.15.weight", "conv_layers.15.bias", "fc_layers.5.weight", "fc_layers.5.bias". 
	size mismatch for fc_layers.2.weight: copying a param with shape torch.Size([4, 128]) from checkpoint, the shape in current model is torch.Size([64, 512]).
	size mismatch for fc_layers.2.bias: copying a param with shape torch.Size([4]) from checkpoint, the shape in current model is torch.Size([64]).

In [None]:
# Print the layer-by-layer overview of whichever model is active.
summary(model, (1, *resolution), batch_size, device.type)

# Training only happens in TRAIN mode; a loaded model is used as-is.
if MODE == Mode.TRAIN:
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), weight_decay=0.03)
    train(
        model=model,
        optimizer=optimizer,
        loss_fn=loss_fn,
        train_loader=train_loader,
        val_loader=val_loader,
        device=device,
        epochs=50,
    )

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [256, 6, 146, 96]             156
              ReLU-2          [256, 6, 146, 96]               0
         MaxPool2d-3           [256, 6, 73, 48]               0
            Conv2d-4          [256, 16, 69, 44]           2,416
              ReLU-5          [256, 16, 69, 44]               0
         MaxPool2d-6          [256, 16, 34, 22]               0
           Flatten-7               [256, 11968]               0
            Linear-8                 [256, 128]       1,532,032
              ReLU-9                 [256, 128]               0
           Linear-10                  [256, 84]          10,836
             ReLU-11                  [256, 84]               0
           Linear-12                   [256, 4]             340
Total params: 1,545,780
Trainable params: 1,545,780
Non-trainable params: 0
---------------------------

In [None]:
# Full pass over the validation split: one [prediction, label] row per sample.
results = get_predictions_and_labels(model, val_loader, device)

# Every run writes its artifacts into its own timestamped directory.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
run_artifacts_dir = os.path.join("run_artifacts", timestamp)
os.makedirs(run_artifacts_dir, exist_ok=True)

# Keep copies of (up to 100) misclassified validation images for inspection.
save_miscategorized_images(results, val_loader, run_artifacts_dir, max_img_count=100)
val_acc = get_correct_percent(results)

# Same measurement over the complete training split.
train_results = get_predictions_and_labels(model, train_loader, device)
train_acc = get_correct_percent(train_results)
print(f"Train Acc: {train_acc:.2f}%, Val Acc: {val_acc:.2f}%")

2it [00:04,  2.29s/it]
5it [00:16,  3.21s/it]

Train Acc: 85.33%, Val Acc: 83.22%





In [None]:
# Store the whole model object (not just a state_dict) next to this run's
# other artifacts, under the same file name the LOAD branch reads.
path = os.path.join(run_artifacts_dir, MODEL_FILE_NAME)
torch.save(obj=model, f=path)