# Лабораторная работа №3 "Сверточные нейронные сети"

In [1]:
from dataclasses import dataclass
from functools import cache
from itertools import islice
from typing import IO, Any, Callable, Mapping, Optional, Sequence, Sized, Tuple, cast
from zipfile import ZipFile

import torch.nn as nn
from matplotlib import pyplot as plt
from PIL import Image as image
from PIL.Image import Image
from torch import Tensor, cuda, device, mps, no_grad, optim
from torch import max as torch_max
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.datasets import VisionDataset


def RPSClassifier(num_classes: int = 3, image_size: int = 150) -> nn.Sequential:
    """Build the convolutional classifier for the rock-paper-scissors images.

    Three ``conv(3x3, padding=1) -> ReLU -> max-pool(2x2)`` stages turn a
    3-channel ``image_size`` x ``image_size`` input into a 256-channel feature
    map, which a two-layer MLP maps to ``num_classes`` scores.

    Args:
        num_classes: Number of output classes (3 for rock/paper/scissors).
        image_size: Side length of the square input images. It must match the
            preprocessing resize, because the input width of the first
            ``Linear`` layer is derived from it.

    Returns:
        An ``nn.Sequential`` mapping a ``(N, 3, image_size, image_size)`` batch
        to raw, unnormalised logits of shape ``(N, num_classes)``.
    """
    # The 3x3 convolutions with padding=1 keep the spatial size; each 2x2,
    # stride-2 max-pool floors it by half, so three of them yield
    # floor(image_size / 8)  (150 -> 75 -> 37 -> 18).
    feature_size = image_size // 8

    return nn.Sequential(
        nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),
        nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),
        nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),
        nn.Flatten(),
        nn.Linear(256 * feature_size * feature_size, 512),
        nn.ReLU(),
        # No final activation: nn.CrossEntropyLoss applies log-softmax itself
        # and expects unnormalised logits. A Sigmoid here would squash the
        # scores into (0, 1), limiting how confident the softmax can get and
        # weakening the gradients during training.
        nn.Linear(512, num_classes),
    )


class ZipImageDataset[TSample, TTarget](VisionDataset):
    _zip: ZipFile
    _loader: Callable[[IO[bytes]], Image]
    _classes: Sequence[str]
    _target_mapper: Mapping[str, int]
    _dataset: Sequence[Tuple[str, int]]

    def __init__(
        self,
        zip_filename: str,
        loader: Optional[Callable[[IO[bytes]], Image]] = None,
        transform: Optional[Callable[[Image], TSample]] = None,
        target_transform: Optional[Callable[[int], TTarget]] = None,
    ) -> None:
        super().__init__(
            zip_filename, transform=transform, target_transform=target_transform
        )
        self._zip = ZipFile(zip_filename, "r")
        self._loader = loader or ZipImageDataset.default_loader

        self._classes = self._build_classes()
        self._target_mapper = self._build_target_mapper()
        self._dataset = self._build_dataset()

    @property
    def classes(self) -> Sequence[str]:
        return self._classes

    def close(self) -> None:
        self._zip.close()

    def _build_classes(self) -> Sequence[str]:
        directories = [e for e in self._zip.infolist() if e.is_dir()]
        max_deep = max(e.filename.count("/") for e in directories)

        return sorted(
            ZipImageDataset._get_target(e.filename)
            for e in directories
            if e.filename.count("/") == max_deep
        )

    def _build_target_mapper(self) -> Mapping[str, int]:
        return {e: i for i, e in enumerate(self._classes)}

    def _build_dataset(self) -> Sequence[Tuple[str, int]]:
        return [
            (e.filename, self._target_mapper[ZipImageDataset._get_target(e.filename)])
            for e in self._zip.infolist()
            if not e.is_dir()
        ]

    def _unpack_file(self, filename: str) -> IO[bytes]:
        return self._zip.open(filename)

    def __len__(self) -> int:
        return len(self._dataset)

    def __getitem__(self, index: int) -> Tuple[TSample, TTarget]:
        path, target = self._dataset[index]

        with self._unpack_file(path) as stream:
            sample = self._loader(stream)

        if self.transform is not None:
            sample = self.transform(sample)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return cast(TSample, sample), cast(TTarget, target)

    def __enter__(self) -> "ZipImageDataset":
        return self

    def __exit__(self, *_) -> None:
        self.close()

    @staticmethod
    def _get_target(path: str) -> str:
        return path.split("/")[-2]

    @staticmethod
    def default_loader(stream: IO[bytes]) -> Any:
        img = image.open(stream)
        return img.convert("RGB")


class CachedDataset[TSample, TTarget](Dataset):
    def __init__(self, dataset: Dataset) -> None:
        self._dataset = dataset

    @cache
    def __len__(self) -> int:
        if isinstance(self._dataset, Sized):
            return len(self._dataset)

        raise NotImplementedError()

    @cache
    def __getitem__(self, index: int) -> Tuple[TSample, TTarget]:
        return self._dataset[index]

    def __getattr__(self, attr: str) -> Any:
        return getattr(self._dataset, attr)


@dataclass(frozen=True, slots=True)
class EpochResults:
    correct: int
    total: int
    loss: float

    @property
    def accuracy(self) -> float:
        return self.correct / self.total


def show_samples(samples: Sequence[Tuple[Tensor, str]]) -> None:
    """Plot images side by side in a single row, each titled with its label.

    Args:
        samples: ``(image, label)`` pairs. Each image is a ``(C, H, W)`` tensor
            assumed to be normalised with mean 0.5 / std 0.5, i.e. to lie in
            ``[-1, 1]``; the label is shown title-cased. Nothing is drawn for
            an empty sequence.
    """
    if not samples:
        return

    # squeeze=False always returns a 2-D array of axes, so a single sample is
    # handled like any other count (plt.subplots returns a bare Axes for 1x1).
    _, axes = plt.subplots(1, len(samples), figsize=(12, 12), squeeze=False)

    for ax, (sample, target) in zip(axes[0], samples):
        ax.set_title(target.title())

        # Undo the normalisation: [-1, 1] -> [0, 1]. Clamping removes float
        # rounding error that imshow would otherwise clip with a warning;
        # detach().cpu() makes the conversion work for GPU / autograd tensors.
        pixels = (sample.detach().cpu() / 2 + 0.5).clamp(0, 1)

        # (channels, height, width) -> (height, width, channels) for imshow;
        # squeeze drops a singleton channel axis for grayscale images.
        ax.imshow(pixels.numpy().transpose(1, 2, 0).squeeze())

    plt.tight_layout()
    plt.show()


def get_device() -> device:
    """Return the preferred torch device: Apple MPS, then CUDA, then CPU."""
    # torch.backends.mps.is_available() is the documented MPS availability
    # check (available since PyTorch 1.12); the torch.mps module does not
    # expose is_available() in every release.
    from torch.backends import mps as mps_backend

    if mps_backend.is_available():
        return device("mps")

    if cuda.is_available():
        return device("cuda")

    return device("cpu")


def do_epoch(
    model: nn.Module,
    loader: DataLoader,
    device: device,
    criterion: nn.Module,
    optimizer: optim.Optimizer | None = None,
) -> EpochResults:
    """Run *model* over every batch of *loader* once and collect metrics.

    With an *optimizer* each batch is also back-propagated and an optimisation
    step is taken (training); with ``None`` the model is only evaluated.
    Switching the model between train/eval mode and disabling autograd for
    evaluation (e.g. ``torch.no_grad()``) is left to the caller.

    Args:
        model: Classifier whose outputs are reduced with ``max`` over dim 1 to
            obtain the predicted class index.
        loader: Yields ``(images, labels)`` batches.
        device: Device every batch is moved to; should match *model*'s.
        criterion: Loss, called as ``criterion(outputs, labels)``.
        optimizer: Optimizer used to update *model*, or ``None`` to evaluate.

    Returns:
        Correct-prediction and sample counts plus the mean per-batch loss.

    Raises:
        ZeroDivisionError: If *loader* yields no batches.
    """
    running_loss = 0.0
    correct = 0
    total = 0
    # Count batches as they are consumed instead of using len(loader): loaders
    # over iterable-style datasets may have no length, or only an estimate.
    batches = 0

    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)

        if optimizer is not None:
            optimizer.zero_grad()

        outputs = model(images)
        loss = criterion(outputs, labels)

        if optimizer is not None:
            loss.backward()
            optimizer.step()

        running_loss += loss.item()
        batches += 1

        _, predicted = torch_max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    return EpochResults(correct, total, running_loss / batches)

## Загрузка данных

In [2]:
BATCH_SIZE = 64
IMAGE_SIZE = 150

transform = transforms.Compose(
    [
        # An explicit (height, width) pair guarantees square images whatever the
        # source aspect ratio. A single int only resizes the shorter edge, but
        # the classifier's first Linear layer needs a fixed 150x150 input.
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        # Map every RGB channel from [0, 1] to [-1, 1]. The per-channel values
        # must be tuples: `(0.5) * 3` (no comma) is just the float 1.5, which
        # would shift the pixels into [-1, -1/3] instead.
        transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
    ]
)

# CachedDataset keeps each decoded, transformed image after its first access,
# so later epochs do not re-read the archive.
train = CachedDataset(ZipImageDataset("data/rps.zip", transform=transform))
train_loader = DataLoader(train, batch_size=BATCH_SIZE, shuffle=True)

test = CachedDataset(ZipImageDataset("data/rps-test-set.zip", transform=transform))
test_loader = DataLoader(test, batch_size=BATCH_SIZE, shuffle=False)

## Проверка предобработанных данных

In [None]:
# Preview the first five images of one (shuffled) training batch, each titled
# with its class name.
images, labels = next(iter(train_loader))
samples = (
    (sample, train.classes[target])
    for sample, target in zip(images, labels)
)

show_samples(list(islice(samples, 5)))

## Обучение модели

In [None]:
DEVICE = get_device()
EPOCHS = 10
LEARNING_RATE = 0.001

criterion = nn.CrossEntropyLoss()
model = RPSClassifier().to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# One EpochResults per epoch, plotted by the next cell.
train_results = []

# do_epoch does not switch modes itself, so enable training mode once here.
model.train()
for epoch in range(EPOCHS):
    epoch_results = do_epoch(model, train_loader, DEVICE, criterion, optimizer)
    train_results.append(epoch_results)

    print(
        f"Epoch {epoch + 1} / {EPOCHS}, "
        f"Loss: {epoch_results.loss:.4f}, "
        f"Accuracy: {epoch_results.accuracy * 100:.2f}%"
    )

In [None]:
# Build the x axis from the recorded results rather than EPOCHS, so the plot
# still works (instead of failing on mismatched lengths) when training was
# interrupted before every epoch finished.
epochs = [*range(1, len(train_results) + 1)]
accuracies = [e.accuracy for e in train_results]
losses = [e.loss for e in train_results]

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

ax1.plot(epochs, accuracies, color="b", label="Accuracy")
ax1.set_xlabel("Epoch")
ax1.set_ylabel("Accuracy")
ax1.set_title("Accuracy Over Epochs")
ax1.grid(True)

ax2.plot(epochs, losses, color="r", label="Loss")
ax2.set_xlabel("Epoch")
ax2.set_ylabel("Loss")
ax2.set_title("Loss Over Epochs")
ax2.grid(True)

plt.tight_layout()
plt.show()

## Проверка модели на тестовой выборке

In [None]:
# Evaluate on the held-out test set: eval mode, autograd disabled, and no
# optimizer, so do_epoch performs no parameter updates.
model.eval()
with no_grad():
    results = do_epoch(model, test_loader, DEVICE, criterion, optimizer=None)

print(
    f"Test dataset, Loss: {results.loss:.4f}, "
    f"Accuracy: {results.accuracy * 100:.2f}%"
)