# Imports + general

In [None]:
import warnings
warnings.filterwarnings("ignore", category=UserWarning)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pickle

import numpy as np
import pandas as pd
import random
import os

from tqdm import tqdm

import torchvision
import torchvision.models as models
from torchvision import transforms, datasets
import torch.utils.data as data
import torchvision.datasets
from torch.utils.data.sampler import WeightedRandomSampler

from sklearn.model_selection import train_test_split
from sklearn.metrics import average_precision_score, precision_recall_fscore_support, confusion_matrix, classification_report, accuracy_score

from torchvision.models.resnet import conv3x3, _resnet, ResNet18_Weights

import matplotlib.pyplot as plt
from PIL import ImageOps, Image

# Seed every random number generator used in this notebook with the same
# fixed value so repeated runs start from the same state.
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_rng(0)

# Restrict cuDNN to deterministic algorithms.
torch.backends.cudnn.deterministic = True

# Use the first CUDA device when one is available, otherwise the CPU.
# (Set `device = 'cpu'` here to force CPU execution.)
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# Root directory of the image dataset
data_folder = "/mnt/d/Bachelor_work/data_for_model/kkanji2_known_unknown"

# PreactBasicBlock

In [None]:
class PreactBasicBlock(nn.Module):
    """Residual block that applies normalization and ReLU before each 3x3 convolution.

    The data path is ``norm -> ReLU -> conv3x3 -> norm -> ReLU -> conv3x3``;
    the shortcut (the raw input, or ``downsample(input)`` when a downsample
    module is given) is then added to the result. No activation follows the
    addition.

    Args:
        inplanes: Number of channels of the block input.
        planes: Number of channels produced by both convolutions.
        stride: Stride of the first convolution.
        downsample: Optional module applied to the block input to build the
            shortcut; needed whenever the main path changes the tensor shape.
        groups: Must be 1 (accepted for signature compatibility).
        base_width: Must be 64 (accepted for signature compatibility).
        dilation: Must not exceed 1 (accepted for signature compatibility).
        norm_layer: Callable that builds a normalization layer from a channel
            count. Defaults to ``nn.BatchNorm2d``.

    Raises:
        ValueError: If ``groups != 1`` or ``base_width != 64``.
        NotImplementedError: If ``dilation > 1``.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super().__init__()

        if norm_layer is None:
            norm_layer = nn.BatchNorm2d

        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')

        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")

        # Both self.conv1 and self.downsample layers downsample the input when stride != 1.
        # The normalization layers honour `norm_layer` so the block uses the same
        # normalization type the caller requested for the rest of the network.
        self.bn1 = norm_layer(inplanes)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = conv3x3(inplanes, planes, stride)

        self.bn2 = norm_layer(planes)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)

        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``conv2(relu(norm(conv1(relu(norm(x)))))) + shortcut(x)``."""
        # The shortcut starts from the un-normalized input.
        identity = x

        out = self.bn1(x)
        out = self.relu1(out)
        out = self.conv1(out)

        out = self.bn2(out)
        out = self.relu2(out)
        out = self.conv2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity

        return out

# Cycle function (one pass over a dataloader)

In [None]:
def cycle(model, threshold, loader, criterion, train=False, optimizer=None):
    """Run one pass over ``loader`` and return the mean loss and accuracy.

    The network is kept in evaluation mode for the whole pass. For every
    sample the largest softmax probability is taken as its confidence, and
    the sample is predicted positive when that confidence exceeds
    ``threshold``.

    A hard ``confidences > threshold`` comparison yields a boolean tensor that
    carries no gradient, so the loss is computed on the differentiable margin
    ``confidences - threshold`` instead, presented to ``criterion`` as
    two-class logits ``[-margin, margin]``.

    Args:
        model: Module called as ``model(X_batch)``; its output is soft-maxed
            over dimension 1.
        threshold: Scalar tensor compared against the confidences. Its
            ``requires_grad`` flag is set to ``train``.
        loader: Iterable of ``(inputs, labels)`` batches.
        criterion: Loss called as ``criterion(logits, labels)`` with logits of
            shape ``(batch, 2)``.
            NOTE(review): assumes integer labels 0/1 where 1 means "confidence
            should exceed the threshold" - confirm against the dataset's
            class order.
        train: When True, back-propagate the loss and step ``optimizer`` after
            every batch.
        optimizer: Optimizer to step; required when ``train`` is True.

    Returns:
        Tuple ``(mean_loss, mean_accuracy)`` of Python floats, averaged over
        the batches.

    Raises:
        ValueError: If ``train`` is True without an optimizer, or if the
            loader yields no batches.
    """
    if train and optimizer is None:
        raise ValueError("optimizer is required when train=True")

    model.eval()
    threshold.requires_grad = train

    batch_losses = []
    batch_accuracies = []

    # Iterate the loader directly: errors raised while fetching a batch must
    # surface instead of silently restarting the iterator.
    for X_batch, y_batch in tqdm(loader):
        X_batch = X_batch.to(device)
        y_batch = y_batch.to(device)

        if train:
            optimizer.zero_grad()

        # Gradients are tracked only during training passes.
        with torch.set_grad_enabled(train):
            preds = model(X_batch)
            confidences = F.softmax(preds, dim=1).max(dim=1).values

            # Positive margin <=> confidence above the threshold.
            margin = confidences - threshold.to(confidences.device)
            logits = torch.stack((-margin, margin), dim=1)

            # .mean() is a no-op for an already reduced loss and reduces a
            # per-sample loss to the scalar that backward() needs.
            loss = criterion(logits, y_batch).mean()

        if train:
            loss.backward()
            optimizer.step()

        binary_predictions = margin.detach() > 0
        batch_losses.append(loss.item())
        batch_accuracies.append(
            (binary_predictions.to(y_batch.dtype) == y_batch).float().mean().item()
        )

    if not batch_losses:
        raise ValueError("loader produced no batches")

    mean_loss = sum(batch_losses) / len(batch_losses)
    mean_accuracy = sum(batch_accuracies) / len(batch_accuracies)
    return mean_loss, mean_accuracy



        

# Dataloaders

In [None]:
# Dataset initialization
def get_dataloaders(batch_size: int = 4096, train_test_indices_path: str = None):
    """Build train and test DataLoaders over the image folder at ``data_folder``.

    Images are converted to 3-channel grayscale, resized with ``Resize(64)``,
    turned into tensors and normalized with mean 0.5 / std 0.5 per channel.

    The 70/30 stratified split is read from ``train_test_indices_path`` when
    that file exists. Otherwise a new split is drawn and, if a path was given,
    written there so later calls reuse the same split.

    Args:
        batch_size: Batch size used by both loaders.
        train_test_indices_path: Optional pickle file holding a dict with the
            keys ``"train_indices"`` and ``"test_indices"``. With ``None`` the
            split is neither loaded nor saved.

    Returns:
        Tuple ``(trainloader, testloader, full_dataset)``.
    """
    my_transform = transforms.Compose([
        transforms.Grayscale(num_output_channels=3),
        transforms.Resize(64),
        transforms.ToTensor(),
        # Maps each channel from [0, 1] to [-1, 1]
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])

    full_dataset = datasets.ImageFolder(root=data_folder, transform=my_transform)

    if train_test_indices_path is not None and os.path.exists(train_test_indices_path):
        # NOTE: unpickling can execute arbitrary code - only load index files
        # that were produced by this notebook.
        with open(train_test_indices_path, "rb") as f:
            indices = pickle.load(f)
        train_indices = indices["train_indices"]
        test_indices = indices["test_indices"]
    else:
        train_indices, test_indices = train_test_split(
            list(range(len(full_dataset))),
            test_size=0.3,
            stratify=[label for _, label in full_dataset.samples],
        )

        # Persist the split only when the caller supplied a location for it.
        if train_test_indices_path is not None:
            with open(train_test_indices_path, "wb") as f:
                pickle.dump({"train_indices": train_indices, "test_indices": test_indices}, f)

    train_dataset = data.Subset(full_dataset, train_indices)
    test_dataset = data.Subset(full_dataset, test_indices)

    # Both loaders keep the order of their index lists (shuffle=False).
    trainloader = data.DataLoader(train_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
    testloader = data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

    return trainloader, testloader, full_dataset

# Getting stuff for training

In [None]:
def getting_stuff_for_training(pathes: dict = None, lr: float = 1e-3, scheduler_step_size: int = 10, scheduler_gamma: float = 0.1):
    """Create, or restore from disk, every object the training loop needs.

    Each entry is first built with its default value and then replaced by the
    saved version when the corresponding file in ``pathes`` exists.

    Args:
        pathes: Dict mapping the keys ``"last_epoch"``, ``"model"``,
            ``"threshold"``, ``"optimizer"``, ``"scheduler"``,
            ``"train_loss_history"``, ``"train_accuracy_history"``,
            ``"val_loss_history"``, ``"val_accuracy_history"``,
            ``"max_val_accuracy"`` and ``"early_stopping"`` to file paths.
        lr: Learning rate of the Adam optimizer that updates the threshold.
        scheduler_step_size: ``step_size`` of the StepLR scheduler.
        scheduler_gamma: ``gamma`` of the StepLR scheduler.

    Returns:
        Dict with the same keys as ``pathes`` holding the restored or freshly
        created objects.

    Raises:
        ValueError: If ``pathes`` is None.
    """
    if pathes is None:
        raise ValueError("pathes must be a dict of checkpoint file paths")

    def load_scalar(key, default, cast):
        # Convert the 0-d array returned by np.load into a plain Python number.
        if os.path.exists(pathes[key]):
            return cast(np.load(pathes[key]))
        return default

    def load_history(key):
        # Saved histories come back as plain lists so they can be appended to.
        if os.path.exists(pathes[key]):
            return np.load(pathes[key]).tolist()
        return []

    stuff = {}
    stuff["last_epoch"] = load_scalar("last_epoch", 0, int)

    # ResNet-18 layout built from pre-activation blocks, without pretrained
    # weights; the final layer is replaced by a 300-output linear layer.
    model = _resnet(PreactBasicBlock, [2, 2, 2, 2], None, progress=False)  # 'resnet18'
    model.fc = nn.Linear(model.fc.in_features, 300)
    if os.path.exists(pathes["model"]):
        # map_location lets a checkpoint written on a GPU load on a CPU-only machine.
        model.load_state_dict(torch.load(pathes["model"], map_location=device))
    model = model.to(device)
    stuff["model"] = model

    threshold = nn.Parameter(torch.tensor(0.5))
    if os.path.exists(pathes["threshold"]):
        # The threshold is saved right after a validation pass, which leaves it
        # with requires_grad=False; re-wrap it so it is a trainable leaf again.
        saved_threshold = torch.load(pathes["threshold"])
        threshold = nn.Parameter(torch.as_tensor(saved_threshold).detach().clone())
    stuff["threshold"] = threshold

    # Only the threshold is handed to the optimizer.
    optimizer = torch.optim.Adam([threshold], lr=lr)  # , weight_decay=0.)
    if os.path.exists(pathes["optimizer"]):
        optimizer.load_state_dict(torch.load(pathes["optimizer"]))
    stuff["optimizer"] = optimizer

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=scheduler_step_size, gamma=scheduler_gamma)
    if os.path.exists(pathes["scheduler"]):
        scheduler.load_state_dict(torch.load(pathes["scheduler"]))
    stuff["scheduler"] = scheduler

    for history_key in ("train_loss_history", "train_accuracy_history",
                        "val_loss_history", "val_accuracy_history"):
        stuff[history_key] = load_history(history_key)

    stuff["max_val_accuracy"] = load_scalar("max_val_accuracy", 0, float)
    stuff["early_stopping"] = load_scalar("early_stopping", 5, int)

    return stuff

# Plot func

In [None]:
def custom_plot(train_data : list = None, val_data : list = None, title: str = None,  ylabel: str = None):
    """Plot the training and validation series on the current pyplot figure.

    Args:
        train_data: Values drawn as the curve labelled "train".
        val_data: Values drawn as the curve labelled "validation".
        title: Title of the plot.
        ylabel: Label of the y axis; the x axis is always labelled "Epochs".
    """
    curves = ((train_data, "train"), (val_data, "validation"))
    for series, series_label in curves:
        plt.plot(series, label=series_label)

    plt.title(title)
    plt.xlabel('Epochs')
    plt.ylabel(ylabel)
    plt.legend()

# Create paths dictionary

In [None]:
def create_pathes_vocabulary(path : str = None):
    """Return a dict mapping each checkpoint artifact name to its file under ``path``.

    Artifacts written with ``torch.save`` get a ``.pth`` extension, the ones
    written with ``np.save`` get ``.npy``. Every file is named after its key.
    """
    torch_artifacts = ("threshold", "model", "optimizer", "scheduler")
    artifact_names = (
        "threshold",
        "last_epoch",
        "model",
        "optimizer",
        "scheduler",
        "train_loss_history",
        "train_accuracy_history",
        "val_loss_history",
        "val_accuracy_history",
        "max_val_accuracy",
        "early_stopping",
    )

    pathes = {}
    for artifact in artifact_names:
        extension = "pth" if artifact in torch_artifacts else "npy"
        pathes[artifact] = f"{path}/{artifact}.{extension}"

    return pathes
    

# Train loop

In [None]:
def train_loop(trainloader, testloader, stuff : dict = None, criterion=nn.CrossEntropyLoss(), save_path : str = None, epochs : int = 100, model_name : str = None):
    """Train for up to ``epochs`` further epochs with early stopping.

    Every epoch runs one training pass and one validation pass through
    ``cycle`` and then steps the scheduler. Whenever the validation accuracy
    beats the best value so far, the patience counter is reset to 5 and a
    full checkpoint is written to ``save_path``; otherwise the counter is
    decremented, and training stops once it reaches zero.

    Args:
        trainloader: Loader used for the training pass.
        testloader: Loader used for the validation pass.
        stuff: Dict with the keys ``"last_epoch"``, ``"model"``,
            ``"threshold"``, ``"optimizer"``, ``"scheduler"``, the four
            ``"*_history"`` lists, ``"max_val_accuracy"`` and
            ``"early_stopping"``. The history lists are appended to in place.
        criterion: Loss forwarded to ``cycle``.
        save_path: Directory for checkpoints; created if missing. With
            ``None`` no checkpoint is written.
        epochs: Number of epochs to run on top of ``stuff["last_epoch"]``.
        model_name: Name used in the final summary line.

    Returns:
        Tuple ``(train_accuracy_history, val_accuracy_history,
        train_loss_history, val_loss_history)``.

    Raises:
        ValueError: If ``stuff`` is None.
    """
    if stuff is None:
        raise ValueError("stuff must be the dict of training objects")

    # int() also accepts the 0-d array that np.load returns for a saved epoch.
    last_epoch = int(stuff["last_epoch"])
    final_epoch = last_epoch + epochs

    model = stuff["model"]
    threshold = stuff["threshold"]
    optimizer = stuff["optimizer"]
    scheduler = stuff["scheduler"]

    train_loss_history = stuff["train_loss_history"]
    train_accuracy_history = stuff["train_accuracy_history"]

    val_loss_history = stuff["val_loss_history"]
    val_accuracy_history = stuff["val_accuracy_history"]

    max_val_accuracy = stuff["max_val_accuracy"]
    early_stopping = stuff["early_stopping"]

    if save_path is not None:
        # torch.save / np.save do not create missing directories.
        os.makedirs(save_path, exist_ok=True)

    def save_checkpoint(next_epoch):
        # Write everything needed to resume from the best epoch so far.
        torch.save(model.state_dict(), f'{save_path}/model.pth')
        torch.save(threshold, f'{save_path}/threshold.pth')
        torch.save(optimizer.state_dict(), f'{save_path}/optimizer.pth')
        torch.save(scheduler.state_dict(), f'{save_path}/scheduler.pth')

        np.save(f'{save_path}/train_loss_history.npy', train_loss_history)
        np.save(f'{save_path}/val_loss_history.npy', val_loss_history)
        np.save(f'{save_path}/train_accuracy_history.npy', train_accuracy_history)
        np.save(f'{save_path}/val_accuracy_history.npy', val_accuracy_history)

        np.save(f'{save_path}/last_epoch.npy', next_epoch)
        np.save(f'{save_path}/max_val_accuracy.npy', max_val_accuracy)
        np.save(f'{save_path}/early_stopping.npy', early_stopping)

    for epoch in tqdm(range(last_epoch, final_epoch)):

        train_loss, train_accuracy = cycle(model, threshold, trainloader, criterion, train=True, optimizer=optimizer)
        train_loss_history.append(train_loss)
        train_accuracy_history.append(train_accuracy)

        val_loss, val_accuracy = cycle(model, threshold, testloader, criterion)
        val_loss_history.append(val_loss)
        val_accuracy_history.append(val_accuracy)

        scheduler.step()

        print('Epoch:', epoch+1)
        print('Train: loss', train_loss, 'accuracy', train_accuracy)
        print('Validation: loss', val_loss, 'accuracy', val_accuracy)

        if val_accuracy > max_val_accuracy:
            max_val_accuracy = val_accuracy
            early_stopping = 5

            if save_path is not None:
                save_checkpoint(epoch + 1)
        else:
            early_stopping -= 1

        # <= rather than == so a counter restored at or below zero still stops.
        if early_stopping <= 0:
            break

    # The history is empty when no epoch ran and nothing was restored.
    if val_accuracy_history:
        best_val_accuracy = max(val_accuracy_history)
        print(f'Best validation accuracy for {model_name}:', best_val_accuracy, 'Epoch', val_accuracy_history.index(best_val_accuracy)+1)

    return train_accuracy_history, val_accuracy_history, train_loss_history, val_loss_history