# Imports


In [None]:
from os.path import join as pjoin
import datetime
import os
import re

import pandas as pd
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.tensorboard import SummaryWriter

from torchmetrics.classification import MulticlassAccuracy, MulticlassPrecision, MulticlassRecall, MulticlassStatScores
from torchmetrics.functional.classification.accuracy import _accuracy_reduce

import matplotlib.pyplot as plt
import matplotlib

import torchvision
from torchvision import transforms
from torchvision.io import read_image

from tqdm import tqdm
import torchinfo

# Константы 

In [None]:
# Checkpoint policy used by the training loop.
WEIGHT_SAVER = "last" # "all" / "nothing" / "last"
epochs = 25
NUM_WORKERS = 0

batch_size = 128

# Expected class names; replaced by the dataset's own class list further down
# if the two differ.
CLASSES = [
    'abraham_grampa_simpson',
    'agnes_skinner',
    'apu_nahasapeemapetilon',
    'barney_gumble',
    'bart_simpson',
    'carl_carlson',
    'charles_montgomery_burns',
    'chief_wiggum',
    'cletus_spuckler',
    'comic_book_guy',
    'disco_stu',
    'edna_krabappel',
    'fat_tony',
    'gil',
    'groundskeeper_willie',
    'homer_simpson',
    'kent_brockman',
    'krusty_the_clown',
    'lenny_leonard',
    'lionel_hutz',
    'lisa_simpson',
    'maggie_simpson',
    'marge_simpson',
    'martin_prince',
    'mayor_quimby',
    'milhouse_van_houten',
    'miss_hoover',
    'moe_szyslak',
    'ned_flanders',
    'nelson_muntz',
    'otto_mann',
    'patty_bouvier',
    'principal_skinner',
    'professor_john_frink',
    'rainier_wolfcastle',
    'ralph_wiggum',
    'selma_bouvier',
    'sideshow_bob',
    'sideshow_mel',
    'snake_jailbird',
    'troy_mcclure',
    'waylon_smithers',
]

# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
NUM_CLASSES = len(CLASSES)
IMAGE_RESIZE = (64, 64)

# Per-channel statistics passed to transforms.Normalize and used again to
# undo the normalisation before displaying images.
NORMALIZE_MEAN = (0.5, 0.5, 0.5)
NORMALIZE_DEVIATIONS = (0.5, 0.5, 0.5)

# TensorBoard event writer shared by the rest of the notebook.
writer = SummaryWriter("runs/simpsons")

# Дополнительные функции для интеграции с TensorBoard

In [None]:
def matplotlib_imshow(img, title=None, plt_ax=plt):
    """Display a normalised CHW image tensor with matplotlib.

    The normalisation applied by the input pipeline is undone using
    NORMALIZE_MEAN / NORMALIZE_DEVIATIONS before drawing.

    Args:
        img: CPU tensor laid out as (channels, height, width).
        title: Optional title to put above the image.
        plt_ax: Drawing target; either the ``matplotlib.pyplot`` module
            (default) or an ``Axes`` object.
    """
    # CHW -> HWC, which is the layout imshow expects.
    img = img.numpy().transpose((1, 2, 0))
    mean = np.array(NORMALIZE_MEAN)
    std = np.array(NORMALIZE_DEVIATIONS)
    # Clip so that rounding noise cannot push values outside the [0, 1]
    # range imshow accepts for float RGB data.
    img = np.clip(std * img + mean, 0.0, 1.0)
    plt_ax.imshow(img)
    if title is not None:
        # Axes objects expose set_title(); the pyplot module exposes title().
        if hasattr(plt_ax, "set_title"):
            plt_ax.set_title(title)
        else:
            plt_ax.title(title)
    plt_ax.grid(False)

def images_to_probs(net, images):
    """Return predicted class indices and their softmax probabilities.

    Args:
        net: Callable mapping a batch of images to per-class scores of
            shape (batch, num_classes).
        images: Batch fed to ``net``.

    Returns:
        A pair ``(preds, probs)``: ``preds`` is a 1-D numpy array with the
        arg-max class index of every sample, ``probs`` is a list of floats
        holding the softmax probability of each predicted class.
    """
    # Inference only: no autograd graph is needed here.
    with torch.no_grad():
        output = net(images)
    # convert output scores to predicted class
    preds_tensor = torch.max(output, 1)[1].cpu()
    # Keep the array 1-D (no squeeze) so a batch of one sample stays iterable.
    preds = preds_tensor.numpy()
    probabilities = nn.functional.softmax(output, dim=1)
    probs = [
        probabilities[row, col].item()
        for row, col in enumerate(preds_tensor.tolist())
    ]
    return preds, probs


def plot_classes_preds(net, images, labels):
    """Build a figure showing up to four images with prediction vs. label.

    Each subplot title shows the predicted class, its probability and the
    true label; the title is green for a correct prediction and red
    otherwise.

    Args:
        net: Model passed to ``images_to_probs``.
        images: Batch of image tensors.
        labels: Tensor of integer class indices, one per image.

    Returns:
        The matplotlib figure.
    """
    preds, probs = images_to_probs(net, images)
    fig = plt.figure(figsize=(12, 8))
    # Never index past the end of a batch that holds fewer than four samples.
    for idx in range(min(4, len(images))):
        ax = fig.add_subplot(1, 4, idx + 1, xticks=[], yticks=[])
        # Draw on the subplot that was just created.
        matplotlib_imshow(images[idx].cpu(), plt_ax=ax)
        true_idx = labels[idx].item()
        ax.set_title("{0}, {1:.1f}%\n(label: {2})".format(
            CLASSES[preds[idx]],
            probs[idx] * 100.0,
            CLASSES[true_idx]),
                    color=("green" if preds[idx] == true_idx else "red"))
    return fig

# Transformers 

In [None]:
# Preprocessing applied to every image loaded through this transform:
# resize, random horizontal flip, tensor conversion, per-channel normalisation.
# NOTE(review): the same pipeline (including the random flip) is attached to
# the full dataset, so the validation split is augmented as well - confirm
# this is intended.
transform = transforms.Compose(
    [
        transforms.Resize(size=IMAGE_RESIZE),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=NORMALIZE_MEAN, std=NORMALIZE_DEVIATIONS),
    ]
)

class LabelTransformer():
    """Two-way mapping between class names and their integer indices.

    Calling an instance with a string returns the index of that class name;
    calling it with an integer returns the class name at that index.
    """

    def __init__(self, initial_list=None):
        """Store the ordered collection of class names.

        Args:
            initial_list: Ordered class names (list or numpy array). When
                omitted, the module-level ``CLASSES`` is looked up at
                construction time, so a later rebinding of ``CLASSES`` is
                picked up instead of a stale default.
        """
        self.initial_list = CLASSES if initial_list is None else initial_list

    def __call__(self, val):
        """Dispatch on the argument type; return None for unsupported types."""
        if isinstance(val, str):
            return self.toInt(val)
        # bool is a subclass of int but is not a meaningful class index.
        if isinstance(val, (int, np.integer)) and not isinstance(val, bool):
            return self.toStr(val)
        return None

    def toInt(self, label : str) -> int:
        """Return the index of ``label``.

        Raises:
            IndexError: if ``label`` is not among the known class names.
        """
        # Convert to an array first: comparing a plain Python list with a
        # string yields a single False instead of an element-wise mask.
        matches = np.where(np.asarray(self.initial_list) == label)[0]
        return int(matches[0])

    def toStr(self, ind : int) -> str:
        """Return the class name stored at position ``ind``."""
        return self.initial_list[ind]

# Датасет

In [None]:
# Load every image under the dataset root; ImageFolder derives the class list
# from the sub-directory names.
full_dataset = torchvision.datasets.ImageFolder(
    root="dataset/characters/",
    transform=transform,
)

# Fixed seed so the 80/20 train/validation split is reproducible.
generator = torch.Generator()
generator.manual_seed(42)
train_dataset, valid_dataset = torch.utils.data.random_split(
    full_dataset,
    [0.8, 0.2],
    generator=generator,
)

# The dataset on disk is the source of truth for the class list.
if full_dataset.classes != CLASSES:
    print("Информация о классах обновлена в соответствии с датасетом")
    CLASSES = full_dataset.classes
    NUM_CLASSES = len(CLASSES)

# DataLoaders

In [None]:
# Training batches are reshuffled every epoch.
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=NUM_WORKERS,
)
# Validation metrics do not depend on sample order, so shuffling is skipped.
valid_dataloader = DataLoader(
    dataset=valid_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=NUM_WORKERS,
)

# One sample batch; `images` stays normalised because it is reused later as
# the example input for the model graph.
images, labels = next(iter(train_dataloader))

# Undo the input normalisation for display only: the normalised tensors
# contain negative values, which would be clipped when written as an image.
_display_mean = torch.tensor(NORMALIZE_MEAN).view(1, -1, 1, 1)
_display_std = torch.tensor(NORMALIZE_DEVIATIONS).view(1, -1, 1, 1)
img_grid = torchvision.utils.make_grid(
    (images * _display_std + _display_mean).clamp(0.0, 1.0)
)
writer.add_image('train samples', img_grid)

# Модель

In [None]:
class Model(nn.Module):
    """Convolutional classifier with two residual depthwise bottlenecks.

    Layout: conv1 -> conv2 -> (shuffle1_v2 + skip) -> conv3 -> conv4 ->
    (shuffle2_v2 + skip) -> conv5 -> res (pooling + linear head).
    The head outputs one score per entry of ``CLASSES``; its first Linear
    layer takes 1024 features, so the flattened conv5 output after
    ``MaxPool2d(4)`` must have exactly 1024 elements per sample.
    """

    def __init__(self):
        super().__init__()
        # Sub-modules are created in the same order and under the same
        # attribute names as before, so state_dict keys are unchanged.
        self.conv1 = self._conv_stage(3, 64, pool=False)
        self.conv2 = self._conv_stage(64, 128, pool=True)
        self.shuffle1_v2 = self._bottleneck(128, 512)
        self.conv3 = self._conv_stage(128, 256, pool=True)
        self.conv4 = self._conv_stage(256, 512, pool=True)
        self.shuffle2_v2 = self._bottleneck(512, 2048)
        self.conv5 = self._conv_stage(512, 1024, pool=True)

        # Classification head: pooling, flatten, light dropout and three
        # Linear layers applied back to back (no activation in between).
        self.res = nn.Sequential(
            nn.MaxPool2d(4),
            nn.Flatten(),
            nn.Dropout(0.1),
            nn.Linear(1024, 512),
            nn.Linear(512, 128),
            nn.Linear(128, len(CLASSES)),
        )

    @staticmethod
    def _conv_stage(in_channels, out_channels, pool):
        """3x3 conv + batch norm + in-place ReLU, optionally followed by 2x2 max pooling."""
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True),
        ]
        if pool:
            layers.append(nn.MaxPool2d(2))
        return nn.Sequential(*layers)

    @staticmethod
    def _bottleneck(channels, hidden_channels):
        """1x1 expand -> 3x3 depthwise -> 1x1 project, then batch norm + in-place ReLU.

        Input and output both have ``channels`` channels, which is what lets
        ``forward`` add the block's output to its input.
        """
        return nn.Sequential(
            nn.Conv2d(channels, hidden_channels, kernel_size=1),
            nn.Conv2d(hidden_channels, hidden_channels, kernel_size=3,
                      groups=hidden_channels, padding=1),
            nn.Conv2d(hidden_channels, channels, kernel_size=1),
            nn.BatchNorm2d(channels),
            nn.ReLU(True),
        )

    def forward(self, x):
        """Run a batch through the network and return per-class scores."""
        features = self.conv2(self.conv1(x))
        # First residual bottleneck.
        features = self.shuffle1_v2(features) + features
        features = self.conv4(self.conv3(features))
        # Second residual bottleneck.
        features = self.shuffle2_v2(features) + features
        features = self.conv5(features)
        return self.res(features)


# Оптимизатор, критерий, сеть

In [None]:
net = Model()

# Loss and optimiser shared by the training and validation steps.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

# Record the model graph while the model and the sample batch are both still
# on the CPU.
writer.add_graph(net, images)

net.to(device)

# Layer-by-layer summary; kept in `model_sum` so it can also be written to the
# training log.
model_sum = torchinfo.summary(
    net,
    input_size=(batch_size, 3, *IMAGE_RESIZE),
    row_settings=["var_names"],
    verbose=0,
    col_names=[
        "input_size",
        "output_size",
        "num_params",
        "params_percent",
        "kernel_size",
        "mult_adds",
        "trainable",
    ],
)
print(model_sum)


In [None]:
# One log folder per run, named after the epoch count and the start time
# (truncated to whole seconds).
_run_started = datetime.datetime.now().replace(microsecond=0)
logLearnFolder = f"{os.path.abspath('.')}/log [{epochs}] {_run_started}"
os.makedirs(logLearnFolder)

logLearnFileName = f"{logLearnFolder}/log.txt"
# The handle is left open on purpose: the training loop appends to it.
learnLogFile = open(logLearnFileName, mode='w', encoding='utf8')
# The log starts with the model summary followed by a newline.
learnLogFile.write(str(model_sum) + '\n')

# Шаги обучения и валидации

In [None]:
def train_step(net, dataloader, epoch : int):
    """Train ``net`` for one pass over ``dataloader``.

    Uses the module-level ``device``, ``optimizer`` and ``criterion``.

    Args:
        net: Model to optimise; switched to training mode.
        dataloader: Iterable yielding ``(images, labels)`` batches.
        epoch: Index of the current epoch (currently unused).

    Returns:
        Mean of the per-batch losses as a Python float.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    net.train()
    running_loss = 0.0
    num_batches = 0
    for images, labels in dataloader:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        output = net(images)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()

        # Accumulate a plain float: summing the loss tensors themselves would
        # keep their autograd history alive for the whole epoch.
        running_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        raise ValueError("train_step received an empty dataloader")
    # Average over the batches actually processed, not over a global loader.
    return running_loss / num_batches

def valid_step(net, dataloader, epoch : int):
    """Evaluate ``net`` on ``dataloader`` without updating its weights.

    Uses the module-level ``device`` and ``criterion``.

    Args:
        net: Model to evaluate; switched to evaluation mode.
        dataloader: Iterable yielding ``(images, labels)`` batches.
        epoch: Index of the current epoch (currently unused).

    Returns:
        Tuple ``(mean_loss, accuracy, recall, precision)`` of Python floats.
        Accuracy is micro-averaged; recall and precision are macro-averaged.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    net.eval()
    num_classes = len(CLASSES)
    acc = MulticlassAccuracy(num_classes=num_classes, average="micro")
    recall = MulticlassRecall(num_classes=num_classes, average="macro")
    precision = MulticlassPrecision(num_classes=num_classes, average="macro")
    # Metric state has to live on the same device as the model outputs.
    acc.to(device)
    recall.to(device)
    precision.to(device)

    running_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for images, labels in dataloader:
            images = images.to(device)
            labels = labels.to(device)

            output = net(images)

            acc.update(output, labels)
            recall.update(output, labels)
            precision.update(output, labels)

            running_loss += criterion(output, labels).item()
            num_batches += 1

    if num_batches == 0:
        raise ValueError("valid_step received an empty dataloader")
    # Average over the batches actually processed, not over a global loader.
    valid_loss = running_loss / num_batches

    return valid_loss, acc.compute().item(), recall.compute().item(), precision.compute().item()

# Цикл обучения

In [None]:
# Lowest validation loss seen so far; any real loss beats infinity.
best_loss = float("inf")

# Per-epoch history, appended to at the end of every epoch.
lrs = []
train_losses = []
valid_losses = []
valid_accs = []
valid_recs = []
valid_precs = []

for epoch in (pbar := tqdm(range(epochs))):
    train_loss = train_step(net, train_dataloader, epoch)
    valid_loss, valid_acc, valid_recall, valid_precision = valid_step(net, valid_dataloader, epoch)
    # Learning-rate scheduling is disabled: no scheduler object is created.
    # scheduler.step(valid_loss)

    # Checkpoint on validation-loss improvement, skipping epochs 0-3.
    if WEIGHT_SAVER != "nothing" and valid_loss < best_loss and epoch > 3:
        best_loss = valid_loss

        print(f"Saved weights with acc/rec/prec: {valid_acc:.2f}/{valid_recall:.2f}/{valid_precision:.2f} | loss: {valid_loss:.4f}")
        if learnLogFile is not None:
            learnLogFile.write(
                f"Saved weights with acc/rec/prec: {valid_acc:.2f}/{valid_recall:.2f}/{valid_precision:.2f}")
            learnLogFile.write(f"| loss: {valid_loss:.4f}\n")

        if WEIGHT_SAVER == "all":
            # Keep one file per improving epoch.
            torch.save(net.state_dict(),
                       f"{logLearnFolder}/weights_{epoch}.pth")
        elif WEIGHT_SAVER == "last":
            # Keep overwriting a single file with the best weights so far.
            torch.save(net.state_dict(),
                       f"{logLearnFolder}/weights_last.pth")

    writer.add_scalar('valid loss', valid_loss, epoch)
    writer.add_scalar('train loss', train_loss, epoch)

    # Each metric gets its own tag so it does not overwrite the loss curve.
    writer.add_scalar('valid accuracy', valid_acc, epoch)
    writer.add_scalar('valid recall', valid_recall, epoch)
    writer.add_scalar('valid precision', valid_precision, epoch)

    lrs.extend(float(param_group['lr']) for param_group in optimizer.param_groups)
    train_losses.append(train_loss)
    valid_losses.append(valid_loss)
    valid_accs.append(valid_acc)
    valid_recs.append(valid_recall)
    valid_precs.append(valid_precision)

    if learnLogFile is not None:
        learnLogFile.write(
            f"[{epoch}] acc/rec/prec: {valid_acc:.2f}/{valid_recall:.2f}/{valid_precision:.2f} ")
        learnLogFile.write(f"| train/valid loss: {train_loss:.4f}/{valid_loss:.4f}\n")
        # Push the epoch record to disk so an interrupted run keeps its log.
        learnLogFile.flush()
    pbar.set_description(
        f'acc/rec/prec: {valid_acc:.2f}/{valid_recall:.2f}/{valid_precision:.2f}  | train/valid loss: {train_loss:.4f}/{valid_loss:.4f}')

# Make sure all pending TensorBoard events are written out.
writer.flush()
