In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision

# Let float32 matmuls use faster, slightly lower-precision kernels where available.
torch.set_float32_matmul_precision('high')

# Run on the GPU when one is visible, otherwise on the CPU.
# (Hard-code torch.device('cpu') here for easier debugging.)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Single seed shared by every source of randomness in the notebook.
RANDOM_SEED = 42
RANDOM_GENERATOR = torch.Generator()
RANDOM_GENERATOR.manual_seed(RANDOM_SEED)

In [None]:
from torch.utils.data import DataLoader, Dataset
from torchvision.io.image import ImageReadMode
from torchvision.transforms import v2

# One character per class: the black-side characters come first, then the red-side ones.
# A class id is simply the character's position inside VOCAB.
BLACK_PIECES = "将士象車馬砲卒"
RED_PIECES = "帅仕相俥傌炮兵"
VOCAB = "".join([BLACK_PIECES, RED_PIECES])
NUM_CLASSES = len(VOCAB)

# Padding uses the first id past the valid class range, so it can never collide with a class.
PADDING_VALUE = NUM_CLASSES
# IGNORE_INDEX = -100 # Specify to ignore certain targets during loss computation

# Every label is padded up to this many characters.
MAX_LABEL_LEN = 1


def get_image_paths(folder) -> list:
    """Return the path of every entry in ``folder``, sorted by entry name.

    ``os.listdir`` returns entries in arbitrary, filesystem-dependent order.
    Sorting makes the result deterministic, so a seeded shuffle of the returned
    list produces the same train/validation split on every machine and run.

    Args:
        folder: directory to list.

    Returns:
        A list of ``os.path.join(folder, entry)`` strings in ascending name order.

    Raises:
        OSError: if ``folder`` cannot be listed (e.g. it does not exist).
    """
    return [os.path.join(folder, entry) for entry in sorted(os.listdir(folder))]


def get_filenames(paths):
    """Reduce each path to its base name with everything after the first '.' removed.

    The cut happens at the FIRST dot of the base name, so ``dir/name.tar.gz``
    becomes ``name`` and a dot-file such as ``.hidden`` becomes the empty string.
    """
    stems = []
    for path in paths:
        base_name = os.path.basename(path)
        stem, _, _ = base_name.partition('.')
        stems.append(stem)
    return stems


def label_to_integers(label):
    """Encode a label string as a fixed-length list of class ids.

    Each character is replaced by its position in ``VOCAB``; the list is then
    right-padded with ``PADDING_VALUE`` until it holds ``MAX_LABEL_LEN`` entries.

    Raises:
        ValueError: if a character is not present in ``VOCAB``.
        AssertionError: if the label is longer than ``MAX_LABEL_LEN``.
    """
    assert MAX_LABEL_LEN is not None
    encoded = []
    for char in label:
        encoded.append(VOCAB.index(char))
    assert len(encoded) <= MAX_LABEL_LEN

    # Fill the remaining slots with the padding id.
    missing = MAX_LABEL_LEN - len(encoded)
    return encoded + [PADDING_VALUE] * missing


def paths_to_labels(paths):
    """Derive the encoded label of every image from its file name.

    File names are expected to look like ``<label>-<index>.<ext>``. The label is
    everything before the LAST hyphen, so multi-digit indices (``x-10.png``) are
    handled as well as single-digit ones. A name without any hyphen falls back
    to the historical behaviour of dropping its last two characters.

    Args:
        paths: iterable of image paths.

    Returns:
        A ``torch.long`` tensor built from one ``label_to_integers`` row per
        path, i.e. ``MAX_LABEL_LEN`` class ids per image.
    """
    labels = []
    for filename in get_filenames(paths):
        if '-' in filename:
            # Strip the "-<index>" suffix regardless of how many digits it has.
            label = filename.rsplit('-', 1)[0]
        else:
            # No separator present: keep the original fixed-width trimming.
            label = filename[:-2]
        labels.append(label_to_integers(label))

    return torch.tensor(labels, dtype=torch.long)

In [None]:
from functools import cache


class ImageDataset(Dataset):
    """Dataset of grayscale images whose labels are encoded in their file names.

    Each item is an ``(image, label)`` pair placed on the notebook-level
    ``device``. The image is read as a single channel, inverted, scaled to
    ``[0, 1]`` and finally passed through ``transform`` when one is given.

    Loaded items are memoised per instance. The previous ``functools.cache``
    decorator on ``__getitem__`` kept one class-wide cache keyed on
    ``(self, idx)``, which pinned every dataset ever created (and the device
    tensors it had produced) for the life of the kernel.

    Args:
        image_paths: sequence of image file paths.
        transform: optional callable applied to the preprocessed image tensor.
        is_train: stored on the instance; not used by the code in this class.
        cache_items: keep loaded items in memory and reuse them. Set to False
            when ``transform`` is not deterministic (e.g. random augmentation),
            otherwise the same augmented sample would be returned every epoch.
    """

    def __init__(self, image_paths, transform=None, is_train=True, cache_items=True):
        self.image_paths = image_paths
        self.labels = paths_to_labels(self.image_paths)
        self.transform = transform
        self.is_train = is_train
        self.cache_items = cache_items
        # idx -> (image, label); owned by this instance so it is freed with it.
        self._item_cache = {}

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        if self.cache_items and idx in self._item_cache:
            return self._item_cache[idx]

        item = self._load_item(idx)
        if self.cache_items:
            self._item_cache[idx] = item
        return item

    def _load_item(self, idx):
        """Read, preprocess and transform the image at ``idx`` and pair it with its label."""
        label = self.labels[idx].to(device)

        img_path = self.image_paths[idx]
        image = torchvision.io.read_image(img_path, mode=ImageReadMode.GRAY)
        image = image.to(device).float()

        # Invert so that pixels that were white (255) become 0-intensity.
        image = 255.0 - image

        # Scale from [0, 255] to [0, 1].
        image /= 255.0

        if self.transform:
            image = self.transform(image)

        return image, label

In [None]:
# Dataset for loading CAPTCHA images
import random
from typing import Optional

# Per-channel statistics available for v2.Normalize.
RGB_MEAN, RGB_STD = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
GRAYSCALE_MEAN, GRAYSCALE_STD = [0.5], [0.5]

# Statistics actually used by the pipeline -- point these at the set that matches the images.
NORM_MEAN = GRAYSCALE_MEAN
NORM_STD = GRAYSCALE_STD

# Fractions of the shuffled image list given to training and validation.
TRAIN_VAL_SPLIT = [0.9, 0.1]
BATCH_SIZE = 8

# Every image is resized to IMG_HEIGHT x IMG_WIDTH before it reaches the model.
IMG_HEIGHT = IMG_WIDTH = 32

# Preprocessing applied to every image: resize to the model's input size, then
# normalize with the configured mean/std.
# Random augmentations that were tried and left disabled:
#   v2.ColorJitter((1.0,1.0), (2.0,2.0))
#   v2.RandomAffine(degrees=30, translate=(0.1, 0.1), scale=(0.9, 1.1))
#   v2.RandomRotation(45)
#   v2.Grayscale()
transform = v2.Compose(
    [
        v2.Resize((IMG_HEIGHT, IMG_WIDTH)),
        v2.Normalize(mean=NORM_MEAN, std=NORM_STD),
    ]
)

# Gather every image from the training folders into one list.
image_folders = ["train/"]
image_paths = []
for image_folder in image_folders:
    image_paths += get_image_paths(image_folder)

# Shuffle with a fixed seed before cutting the list into train / validation parts.
random.seed(RANDOM_SEED)
random.shuffle(image_paths)

N_train = round(len(image_paths) * TRAIN_VAL_SPLIT[0])
train_paths = image_paths[:N_train]
val_paths = image_paths[N_train:]

# All three datasets share the same preprocessing pipeline.
train_dataset = ImageDataset(train_paths, transform=transform, is_train=True)
val_dataset = ImageDataset(val_paths, transform=transform, is_train=False)
test_dataset = ImageDataset(get_image_paths("test/"), transform=transform, is_train=False)

# Training uses small batches; the evaluation loaders use larger ones.
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=True)

# Un-normalize the image
def unnormalize(image, mean, std):
    """Undo a per-channel ``(x - mean) / std`` normalization and rescale to 0-255.

    The statistics are created on the same device as ``image``, so the function
    works for tensors living on the GPU as well as on the CPU.

    Args:
        image: tensor of shape ``(C, H, W)`` (any leading dims broadcast) holding
            normalized values.
        mean: per-channel means that were subtracted during normalization.
        std: per-channel standard deviations that were divided by.

    Returns:
        A tensor of the same shape with values mapped back to the ``[0, 255]``
        pixel range (for inputs that were in ``[0, 1]`` before normalization).
    """
    # Integer images cannot hold the fractional statistics; compute in float32 then.
    dtype = image.dtype if image.is_floating_point() else torch.float32
    mean = torch.as_tensor(mean, dtype=dtype, device=image.device).view(-1, 1, 1)
    std = torch.as_tensor(std, dtype=dtype, device=image.device).view(-1, 1, 1)
    return (image * std + mean) * 255.0


def _decode_label(label: Optional[torch.Tensor]) -> Optional[str]:
    """Turn a tensor of class ids into its text form, skipping padding ids.

    Accepts 0-d tensors (a single class id) as well as 1-d tensors; ``None`` is
    passed through unchanged.
    """
    if label is None:
        return None
    # reshape(-1) makes a 0-d tensor iterable; tolist() yields plain Python ints.
    class_ids = label.int().reshape(-1).tolist()
    return ''.join(VOCAB[class_id] for class_id in class_ids if class_id != PADDING_VALUE)


def display(
    image: torch.Tensor,
    actual_label: Optional[torch.Tensor] = None,
    predicted_label: Optional[torch.Tensor] = None
):
    """
    Display a sample image and print its predicted / actual labels.

    Args:
        image: normalized image tensor; each slice along the first dimension is
            drawn in its own subplot, and slices that are entirely zero after
            un-normalization are skipped.
        actual_label: optional tensor of ground-truth class ids (0-d or 1-d).
        predicted_label: optional tensor of predicted class ids (0-d or 1-d).
    """
    image = unnormalize(image, NORM_MEAN, NORM_STD)
    plt.figure(figsize=(12,24), facecolor='black')

    actual_label = _decode_label(actual_label)
    predicted_label = _decode_label(predicted_label)

    print(f"Predicted: {predicted_label}, Actual: {actual_label}")

    # Keep the 15-column layout, but widen it if there are more slices than that.
    n_cols = max(15, len(image))
    for i in range(len(image)):
        if torch.all(image[i] == 0):
            continue
        plt.subplot(1, n_cols, i+1)
        plt.axis('off')
        plt.imshow(image[i].squeeze(0).cpu().numpy().astype(np.uint8), cmap='gray')

    plt.tight_layout()
    plt.show()


# Pull one batch to check the data pipeline visually.
images, labels = next(iter(train_dataloader))
display(images[0], actual_label=labels[0]) # extract first sample from batch

# The dataset yields (C, H, W) tensors and the DataLoader stacks them into
# (N, C, H, W), so the channel count is dimension 1 (dimension 2 is the height).
input_channels = images.shape[1]
print("Input channels:", input_channels)

In [None]:
# Show the shape of the sample batch fetched in the previous cell.
images.shape

In [None]:
class ImageClassificationModel(nn.Module):
    """Four-stage CNN followed by a three-layer linear classification head.

    Every CNN stage is Conv2d -> MaxPool2d(2) -> ReLU -> BatchNorm2d -> Dropout2d.
    The first stage outputs 32 channels with a 3x3 kernel; each later stage
    doubles the channel count and uses a 5x5 kernel. Because each stage halves
    the spatial resolution, the size of the flattened feature vector depends on
    the input resolution and is computed from ``image_size`` instead of being
    hard-coded (the previous fixed value of 4096 only matched 64x64 inputs).

    Args:
        num_classes: number of output logits.
        input_channels: number of channels of the input images.
        seq_length: stored on the instance; not used by the layers built here.
        image_size: ``(height, width)`` of the input images. When omitted, the
            notebook-level ``(IMG_HEIGHT, IMG_WIDTH)`` is used.

    Raises:
        ValueError: if ``image_size`` is too small to survive the pooling stages.
    """

    def __init__(self, num_classes, input_channels, seq_length, image_size=None):
        super().__init__()
        self.seq_length = seq_length

        if image_size is None:
            # Follow the resize target configured for the transform pipeline.
            image_size = (IMG_HEIGHT, IMG_WIDTH)
        height, width = image_size

        # CNN layers
        first_layer_channels = 32
        n_cnn_layers = 4
        in_channels = input_channels
        cnn_layers = []
        for i in range(1, n_cnn_layers + 1):
            out_channels = first_layer_channels if i == 1 else in_channels * 2
            kernel_size = 3 if i == 1 else 5
            cnn_layers.extend([
                nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=1, padding='same'),
                nn.MaxPool2d(kernel_size=2, stride=2),
                nn.ReLU(),
                nn.BatchNorm2d(out_channels),
                nn.Dropout2d(p=0.30),
            ])
            in_channels = out_channels

        self.cnn = nn.Sequential(*cnn_layers)

        # 'same' convolutions keep the resolution; each 2x2 max-pool floors it to half.
        downscale = 2 ** n_cnn_layers
        feature_height = height // downscale
        feature_width = width // downscale
        if feature_height < 1 or feature_width < 1:
            raise ValueError(
                f"image_size {tuple(image_size)} is too small for {n_cnn_layers} pooling stages"
            )

        num_features = in_channels * feature_height * feature_width
        self.fc = nn.Sequential(
            nn.Linear(num_features, num_features // 2),
            nn.Linear(num_features // 2, num_features // 4),
            nn.Linear(num_features // 4, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a batch of images ``(N, C, H, W)`` to class logits ``(N, num_classes)``."""
        x = self.cnn(x)
        x = x.flatten(start_dim=1)  # (N, C', h, w) -> (N, C' * h * w)
        return self.fc(x)


def compute_loss(criterion, logits: torch.Tensor, labels: torch.Tensor):
    """Apply ``criterion`` to ``logits`` and ``labels``, reconciling the label shape.

    Labels built per sample with one slot per label position arrive as
    ``(N, 1)``, while class-index criteria such as ``nn.CrossEntropyLoss`` expect
    ``(N,)`` targets for ``(N, C)`` logits. Integer labels with extra dimensions
    are therefore flattened in that case; every other combination (already-flat
    labels, floating-point class-probability targets, higher-rank logits) is
    passed through untouched.

    Args:
        criterion: callable taking ``(logits, labels)`` and returning the loss.
        logits: model outputs.
        labels: target tensor.

    Returns:
        Whatever ``criterion`` returns.
    """
    if logits.dim() == 2 and labels.dim() > 1 and not labels.is_floating_point():
        labels = labels.reshape(-1)
    return criterion(logits, labels)

In [None]:
def train(model, train_dataloader, criterion, optimizer, val_dataloader=None, epochs=10):
    """Train ``model`` and checkpoint it whenever the validation loss improves.

    After each epoch the sample-weighted mean training loss is printed. When a
    validation loader is supplied, the model is also evaluated on it and its
    state dict is written to ``model.pth`` each time the validation loss is at
    least as good as the best seen so far. Without a validation loader only the
    training loss is reported and no checkpoint is written (previously the log
    line tried to format ``None`` and raised ``TypeError``).

    Args:
        model: module to optimize; expected to live on ``device`` already.
        train_dataloader: iterable of ``(images, labels)`` batches.
        criterion: loss passed on to ``compute_loss``.
        optimizer: optimizer over the model's parameters.
        val_dataloader: optional iterable of validation batches.
        epochs: number of passes over ``train_dataloader``.
    """
    best_val_loss = float('inf')
    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        num_samples = 0
        for images, labels in train_dataloader:
            images = images.to(device)
            labels = labels.to(device)

            # Zero the parameter gradients
            optimizer.zero_grad()

            # Forward pass
            logits: torch.Tensor = model(images)

            # Compute loss
            loss = compute_loss(criterion, logits, labels)

            # Backward pass and optimize
            loss.backward()
            optimizer.step()

            # The criterion averages over the batch and the last batch may be
            # smaller, so weight each batch loss by its size before summing.
            batch_size = images.size(0)
            total_loss += loss.item() * batch_size
            num_samples += batch_size

        train_loss = total_loss / num_samples
        message = f"epoch [{epoch+1}/{epochs}], train_loss: {train_loss:.2f}"

        if val_dataloader is None:
            print(message)
            continue

        val_loss, val_acc = test(model, val_dataloader, criterion=criterion)
        print(f"{message}, val_loss: {val_loss:.2f}, val_acc: {val_acc:.2f}")
        if val_loss <= best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), "model.pth")
            print("Saved model state dict to model.pth")


def test(model, dataloader, criterion=None) -> "float | tuple[float, float]":
    """Evaluate ``model`` on ``dataloader``.

    Labels are flattened to one class id per sample before they are compared
    with the arg-max predictions. Comparing ``(N,)`` predictions with ``(N, 1)``
    labels would broadcast to an ``(N, N)`` matrix and inflate the number of
    "correct" predictions.

    Args:
        model: module to evaluate; switched to eval mode.
        dataloader: iterable of ``(images, labels)`` batches.
        criterion: optional loss; when given, the mean loss is reported too.

    Returns:
        ``accuracy`` as a Python float when ``criterion`` is None, otherwise the
        tuple ``(mean_loss, accuracy)``.
    """
    model.eval()
    total_loss = 0.0
    num_correct = 0
    num_samples = 0

    with torch.no_grad():
        for images, labels in dataloader:
            images = images.to(device)
            labels = labels.to(device).reshape(-1)  # (N, 1) -> (N,)

            # Forward pass
            logits: torch.Tensor = model(images)

            # Compute loss, weighted by batch size so uneven batches average correctly
            if criterion is not None:
                loss = compute_loss(criterion, logits, labels)
                total_loss += loss.item() * images.size(0)

            y_pred = logits.argmax(dim=1)

            # Count element-wise matches as a plain Python int
            num_correct += (y_pred == labels).sum().item()
            num_samples += images.size(0)

    accuracy = num_correct / num_samples
    if criterion is not None:
        return total_loss / num_samples, accuracy

    return accuracy

In [None]:
# Training block
model = ImageClassificationModel(
    num_classes=NUM_CLASSES,
    input_channels=input_channels,
    seq_length=MAX_LABEL_LEN,
)
model = model.to(device)

# Resume from checkpoint. Comment out to start training from scratch
# state_dict = torch.load("model.pth", map_location=device)
# model.load_state_dict(state_dict)
# # print(model)

# Loss (padding targets are ignored) and optimizer
criterion = nn.CrossEntropyLoss(ignore_index=PADDING_VALUE)
criterion = criterion.to(device)
optimizer = optim.AdamW(model.parameters(), lr=0.01, amsgrad=True)

# Train the model, validating after every epoch
train(
    model,
    train_dataloader,
    criterion,
    optimizer,
    val_dataloader=val_dataloader,
    epochs=150,
)

In [None]:
import gc

# Collect unreachable Python objects first so any tensors they hold are released,
# then ask PyTorch to hand its unused cached CUDA memory back to the driver.
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Testing block
# Rebuild the architecture and restore the best checkpoint written during training.
model = ImageClassificationModel(
    num_classes=NUM_CLASSES,
    input_channels=input_channels,
    seq_length=MAX_LABEL_LEN,
).to(device)
state_dict = torch.load("model.pth", map_location=device)
model.load_state_dict(state_dict)

# Accuracy on each split, reported as a percentage.
train_acc = test(model, train_dataloader)
print(f'train_acc: {train_acc * 100:.2f}%')

val_acc = test(model, val_dataloader)
print(f'val_acc: {val_acc * 100:.2f}%')

test_acc = test(model, test_dataloader)
print(f'test_acc: {test_acc * 100:.2f}%')

In [None]:
# Model playground: show one test image next to the model's prediction.
model.eval()
with torch.no_grad():  # inference only, no autograd graph needed
    for images, labels in test_dataloader:
        images = images.to(device)

        # Forward pass
        logits: torch.Tensor = model(images)

        y_pred = logits.argmax(dim=1)
        # Pass the prediction as a 1-element tensor: display iterates over its
        # label arguments, and a 0-d tensor (y_pred[0]) cannot be iterated.
        display(images[0], actual_label=labels[0], predicted_label=y_pred[:1])
        break

In [None]:
import torch

# Load onto the notebook's device so this cell also works on CPU-only machines.
state_dict = torch.load("model.pth", map_location=device)

# Summarize the checkpoint (name and shape per tensor) rather than dumping every weight.
for name, tensor in state_dict.items():
    print(f"{name}: {tuple(tensor.shape)}")