In [None]:
import os
import sys
from pathlib import Path
import matplotlib.pyplot as plot
import torch

# Project root, taken as the parent of the current working directory.
BASE_DIR = Path(os.path.abspath("")).parent
# Fraction of the samples held out for evaluation.
TEST_SIZE = 0.2
# Seed for torch's global random number generator, so that weight
# initialisation and DataLoader shuffling repeat from one run to the next.
SEED = 42

# Make the packages under `src` importable. The membership check keeps the
# cell idempotent: re-running it does not stack duplicate entries on sys.path.
SRC_DIR = str(BASE_DIR / "src")
if SRC_DIR not in sys.path:
    sys.path.insert(0, SRC_DIR)

plot.style.use("dark_background")

torch.manual_seed(SEED)

# Prefer the GPU when one is visible to torch; otherwise run on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
from collections.abc import Callable
from typing import Any
from torch import Tensor
from torch.utils.data import Subset
from torchvision import transforms
from torchvision.datasets.vision import VisionDataset
from PIL import Image
from PIL.Image import Image as PILImage
from sklearn.model_selection import train_test_split

class KnotCrossingCountDataset(VisionDataset):
    """Image dataset whose labels come from the sub-directory names of ``root``.

    Every immediate sub-directory of ``root`` is assigned the class given by
    the part of its name before the first underscore, so several directories
    may share one class. Files anywhere below those directories are kept only
    when ``is_valid_file`` accepts them.

    Attributes set by ``__init__``:
        classes: Sorted list of the distinct class names.
        class_to_index: Mapping from class name to its position in ``classes``.
        samples: ``(image_path, class_index)`` pairs in a deterministic,
            name-sorted traversal order.
        targets: Class index of every sample, in the same order as ``samples``.
    """

    def __init__(
        self,
        root: str | Path,
        transforms: Callable[[PILImage, int], Any] | None = None,
        transform: Callable[[PILImage], Any] | None = None,
        target_transform: Callable[[int], Any] | None = None,
    ) -> None:
        """Index the images below ``root``.

        Args:
            root: Directory containing one sub-directory per image group.
            transforms: Joint callable applied to ``(image, class_index)``.
            transform: Callable applied to the image only.
            target_transform: Callable applied to the class index only.

        The transform arguments are forwarded unchanged to ``VisionDataset``.
        NOTE(review): ``__getitem__`` relies on the base class exposing the
        effective joint callable as ``self.transforms`` — confirm against the
        installed torchvision version.
        """
        super().__init__(
            root,
            transforms,
            transform,
            target_transform,
        )

        classes, class_to_index = self.find_classes(self.root)
        self.classes = classes
        self.class_to_index = class_to_index

        samples = self.find_samples(self.root, self.class_to_index)
        self.samples = samples
        self.targets = [target for _, target in samples]

    def __getitem__(self, index: int) -> Any:
        """Load sample ``index`` and return it, transformed when configured.

        Returns the raw ``(image, class_index)`` pair when ``self.transforms``
        is ``None``; otherwise returns whatever ``self.transforms`` produces
        for that pair.
        """
        image_path, class_index = self.samples[index]
        image = self.load_image(image_path)

        if self.transforms is None:
            return image, class_index

        return self.transforms(image, class_index)

    def __len__(self) -> int:
        """Return the number of indexed samples."""
        return len(self.samples)

    def find_classes(self, root: str | Path) -> tuple[list[str], dict[str, int]]:
        """Collect the distinct class names found directly under ``root``.

        Returns:
            The sorted class names and a mapping from each name to its index
            in that sorted list.
        """
        # The context manager closes the directory handle as soon as the
        # listing has been consumed.
        with os.scandir(root) as entries:
            classes = sorted(
                {
                    self.get_class_from_dir_name(entry.name)
                    for entry in entries
                    if entry.is_dir()
                }
            )

        class_to_index = {class_: i for i, class_ in enumerate(classes)}

        return classes, class_to_index

    def find_samples(
        self,
        root: str | Path,
        class_to_index: dict[str, int],
    ) -> list[tuple[Path, int]]:
        """List every accepted file below ``root`` together with its class index.

        Directories and files are visited in name-sorted order so the result
        (and therefore any index-based split) does not depend on the order in
        which the filesystem happens to report entries.

        Raises:
            KeyError: If a sub-directory maps to a class that is missing from
                ``class_to_index``.
        """
        samples: list[tuple[Path, int]] = []

        with os.scandir(root) as entries:
            class_dirs = sorted(
                (entry for entry in entries if entry.is_dir()),
                key=lambda entry: entry.name,
            )

        for class_dir in class_dirs:
            class_ = self.get_class_from_dir_name(class_dir.name)
            class_index = class_to_index[class_]

            for parent_dir, dir_names, file_names in os.walk(class_dir.path):
                # Sorting in place makes os.walk descend in a stable order.
                dir_names.sort()

                for file_name in sorted(file_names):
                    file_path = Path(parent_dir) / file_name

                    if self.is_valid_file(file_path):
                        samples.append((file_path, class_index))

        return samples

    def is_valid_file(self, path: Path) -> bool:
        """Decide whether ``path`` should be part of the dataset.

        The second-to-last underscore-separated field of the file stem is read
        as a base-2 bitmask of applied transforms. A file is accepted when
        bit 3 (the elastic transform) is not set. Files whose name does not
        follow that scheme, such as OS metadata files, are rejected instead of
        raising.
        """
        fields = path.stem.split("_")
        if len(fields) < 2:
            return False

        try:
            transforms_bitmask = int(fields[-2], 2)
        except ValueError:
            return False

        return (transforms_bitmask & (1 << 3)) == 0  # No elastic transform

    def load_image(self, path: Path) -> PILImage:
        """Open ``path`` and return its contents as an RGB PIL image."""
        # The conversion runs while the file is still open, so the returned
        # image does not depend on the handle closed by the `with` block.
        with open(path, "rb") as image_file:
            return Image.open(image_file).convert("RGB")

    def get_class_from_dir_name(self, name: str) -> str:
        """Return the part of ``name`` before its first underscore."""
        return name.split("_")[0]


# Seed for the train/test partition, so that the same samples end up in the
# same subset on every run (given the same sample ordering in the dataset).
SPLIT_SEED = 42

# Images are reduced to a single channel and resized to 64x64 before being
# converted to tensors.
dataset = KnotCrossingCountDataset(
    BASE_DIR / "data" / "augmented" / "transformed_knots",
    transform=transforms.Compose([
        transforms.Grayscale(),
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
    ]),
)

# Split sample *indexes* rather than samples, stratified on the class index so
# both subsets keep the class proportions of the full dataset.
split_indexes: list[list[int]] = train_test_split(
    list(range(len(dataset))),
    test_size=TEST_SIZE,
    stratify=dataset.targets,
    random_state=SPLIT_SEED,
)

train_indexes, test_indexes = split_indexes

train_set = Subset[tuple[Tensor, int]](dataset, train_indexes)
test_set = Subset[tuple[Tensor, int]](dataset, test_indexes)

In [None]:
import random

# Sanity check: display one randomly chosen training image with its class name.
random_sample = train_set[random.randrange(len(train_set))]
sample_image, sample_class_index = random_sample

# squeeze(0) drops the leading size-1 dimension of the image tensor before it
# is handed to imshow.
# "gray" is the canonical colormap name; the "grey" spelling is an alias that
# older Matplotlib releases may not register.
plot.imshow(sample_image.squeeze(0), cmap="gray")
plot.title(dataset.classes[sample_class_index])
plot.show()

In [None]:
from torch import nn

from modules import (
    Conv2d,
    Pad2dPropsWithSame,
    EncoderLayer,
    AttentionNoChannelsProps,
    Conv2dNoChannelsFixedPaddingProps,
)


class Encoder(nn.Sequential):
    """Sequential image classifier built from convolutional attention stages.

    Layer order (identical to the hand-written original, so ``state_dict``
    keys are unchanged):

    1. a 3x3 ``Conv2d`` with "same" padding lifting 1 input channel to 64;
    2. four repetitions of ``EncoderLayer`` followed by ``MaxPool2d(2)``;
    3. ``Flatten`` and a ``Linear`` layer producing one logit per class.

    NOTE(review): the size of the linear layer assumes 64x64 inputs and that
    ``EncoderLayer`` preserves both the channel count and the spatial size, as
    the original ``64 * 4 * 4`` did — confirm against ``modules``.
    """

    def __init__(self, num_classes: int | None = None) -> None:
        """Build the layer stack.

        Args:
            num_classes: Number of output logits. When omitted, the number of
                classes of the notebook-level ``dataset`` is used, which is
                what the original hard-wired.
        """
        if num_classes is None:
            num_classes = len(dataset.classes)

        embedding_size = 64
        stage_count = 4
        input_size = 64
        # Every stage ends in a 2x2 max-pool, halving each spatial side.
        feature_map_size = input_size // 2**stage_count

        layers: list[nn.Module] = [
            Conv2d(
                in_channels=1,
                out_channels=embedding_size,
                kernel_size=3,
                padding_props=Pad2dPropsWithSame(
                    padding="same",
                ),
            ),
        ]

        for _ in range(stage_count):
            # A fresh layer (and fresh props objects) per stage: the stages
            # must not share parameters.
            layers.append(self._make_encoder_layer(embedding_size))
            layers.append(nn.MaxPool2d(kernel_size=2))

        layers.append(nn.Flatten())
        layers.append(
            nn.Linear(
                in_features=embedding_size * feature_map_size * feature_map_size,
                out_features=num_classes,
            )
        )

        super().__init__(*layers)

    @staticmethod
    def _make_encoder_layer(embedding_size: int) -> nn.Module:
        """Create one ``EncoderLayer`` with the configuration shared by all stages."""
        return EncoderLayer(
            embedding_size=embedding_size,
            attn_props=AttentionNoChannelsProps(
                heads=6,
                key_size=96,
                key_conv_props=Conv2dNoChannelsFixedPaddingProps(
                    kernel_size=3,
                ),
                value_size=96,
                value_conv_props=Conv2dNoChannelsFixedPaddingProps(
                    kernel_size=3,
                ),
                attn_conv_props=Conv2dNoChannelsFixedPaddingProps(
                    kernel_size=7,
                ),
                out_conv_props=Conv2dNoChannelsFixedPaddingProps(
                    kernel_size=3,
                ),
            ),
            feedforward_conv_props=Conv2dNoChannelsFixedPaddingProps(
                kernel_size=3,
            ),
        )

In [None]:
from torch.utils.data import DataLoader

# Number of samples per optimisation step.
BATCH_SIZE = 1

# Training batches are reshuffled every epoch; evaluation keeps dataset order.
train_loader, test_loader = (
    DataLoader(subset, batch_size=BATCH_SIZE, shuffle=reshuffle)
    for subset, reshuffle in ((train_set, True), (test_set, False))
)

In [None]:
# Build the classifier with freshly initialised weights; re-running this cell
# discards any training done so far.
model = Encoder()

In [None]:
from torch import optim

# Step size used by Adam.
LEARNING_RATE = 1e-4

# Adam updates every parameter of the model; the loss is the cross-entropy
# between the predicted logits and the integer class indexes.
optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)
loss_func = nn.CrossEntropyLoss()

In [None]:
import typing
from tqdm.auto import tqdm

def train_step(
    model: nn.Module,
    loss_func: nn.Module,
    optimizer: optim.Optimizer,
    device: str = "cpu",
    loader: DataLoader | None = None,
) -> tuple[float, float]:
    """Run one optimisation pass over the training data.

    Args:
        model: Network to train; switched to training mode here.
        loss_func: Callable mapping ``(logits, targets)`` to a scalar loss.
            It is assumed to return the *mean* over the batch, which is why
            the value is re-weighted by the batch size below.
        optimizer: Optimizer stepping the parameters of ``model``.
        device: Device the batches are moved to. It must be the device the
            model lives on.
        loader: Iterable of ``(inputs, targets)`` batches. Defaults to the
            notebook-level ``train_loader``.

    Returns:
        ``(average loss per sample, accuracy)`` over the samples processed.

    Raises:
        ValueError: If the loader yields no samples.
    """
    if loader is None:
        loader = train_loader

    model.train()
    sample_count = 0
    total_correct_count = 0
    total_loss = 0.0

    for x, y in tqdm(loader):
        x = x.to(device)
        y = y.to(device)
        y_logits: Tensor = model(x)
        loss: Tensor = loss_func(y_logits, y)

        batch_size = x.shape[0]
        sample_count += batch_size
        total_loss += loss.item() * batch_size
        # The predicted class is the index of the largest logit.
        total_correct_count += int((y_logits.argmax(dim=-1) == y).sum().item())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    if sample_count == 0:
        raise ValueError("train_step received a loader that yielded no samples")

    # Normalise by the samples actually seen, so the metrics stay correct for
    # any loader passed in.
    avg_loss = total_loss / sample_count
    accuracy = total_correct_count / sample_count

    return avg_loss, accuracy

def test_step(
    model: nn.Module,
    loss_func: nn.Module,
    device: str = "cpu",
    loader: DataLoader | None = None,
) -> tuple[float, float]:
    """Evaluate the model on held-out data without updating it.

    Args:
        model: Network to evaluate; switched to evaluation mode here.
        loss_func: Callable mapping ``(logits, targets)`` to a scalar loss.
            It is assumed to return the *mean* over the batch, which is why
            the value is re-weighted by the batch size below.
        device: Device the batches are moved to. It must be the device the
            model lives on.
        loader: Iterable of ``(inputs, targets)`` batches. Defaults to the
            notebook-level ``test_loader``.

    Returns:
        ``(average loss per sample, accuracy)`` over the samples processed.

    Raises:
        ValueError: If the loader yields no samples.
    """
    if loader is None:
        loader = test_loader

    model.eval()
    sample_count = 0
    total_correct_count = 0
    total_loss = 0.0

    # No gradients are needed for evaluation.
    with torch.inference_mode():
        for x, y in tqdm(loader):
            x = x.to(device)
            y = y.to(device)
            y_logits: Tensor = model(x)
            loss: Tensor = loss_func(y_logits, y)

            batch_size = x.shape[0]
            sample_count += batch_size
            total_loss += loss.item() * batch_size
            # The predicted class is the index of the largest logit.
            total_correct_count += int((y_logits.argmax(dim=-1) == y).sum().item())

    if sample_count == 0:
        raise ValueError("test_step received a loader that yielded no samples")

    # Normalise by the samples actually seen, so the metrics stay correct for
    # any loader passed in.
    avg_loss = total_loss / sample_count
    accuracy = total_correct_count / sample_count

    return avg_loss, accuracy

In [None]:
# Move the parameters to the selected device once, before training starts.
model.to(device)
epochs = 10

for epoch in range(epochs):
    # The device must be passed explicitly: both step functions default to
    # "cpu", which would feed CPU batches to a model living on the GPU.
    train_loss, train_acc = train_step(model, loss_func, optimizer, device=device)
    test_loss, test_acc = test_step(model, loss_func, device=device)

    print(f"epoch={epoch} train_loss={train_loss} train_acc={train_acc} test_loss={test_loss} test_acc={test_acc}")