# $\text{Import libraries}$

## $\text{Base libraries}$

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from varname import argname
import polars as pl
from pathlib import Path
import subprocess
import shutil
import platform
import random

## $\text{PyTorch libraries}$

In [None]:
import torch
from torch import nn
from torch import tensor
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
from torchinfo import summary
import torchmetrics.classification as tmc

## $\text{Lightning libraries}$

In [None]:
import lightning.pytorch as ptl
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks import LearningRateFinder
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.callbacks import ModelCheckpoint, Callback
from lightning.pytorch.tuner.tuning import Tuner

# $\text{Utils functions}$

In [None]:
def get_metric_result_parameter(predictions, targets, task):
    """Derive the parameters the metric cells need from raw model outputs.

    Args:
        predictions: Model outputs with the class scores along dimension 1
            (the layout the torchmetrics ``Multiclass*`` metrics expect).
        targets: Ground-truth labels. Currently unused; kept so every task
            shares the same call signature.
        task: Kind of problem being evaluated. Only ``"classification"`` is
            supported.

    Returns:
        A ``(num_class, predict_class)`` tuple: the number of classes and the
        index of the highest-scoring class for every sample.

    Raises:
        ValueError: If ``task`` is not supported or ``predictions`` has no
            class dimension.
    """
    if task != "classification":
        raise ValueError(f"Unsupported task: {task!r}")
    if predictions.ndim < 2:
        raise ValueError(
            "predictions must have a class dimension, "
            f"got shape {tuple(predictions.shape)}"
        )

    num_class = predictions.shape[1]
    predict_class = predictions.argmax(dim=1)
    return num_class, predict_class

# $\text{Config class}$

In [None]:
class Config:
    """Hyper-parameters shared by the data pipeline and the Lightning wrapper.

    Attributes are plain instance fields so individual experiments can
    overwrite them after construction.
    """

    def __init__(self):
        # Fractions of the data used for each split (they sum to 1.0).
        self.train_split = 0.7
        self.val_split = 0.2
        self.test_split = 0.1

        # Optimisation settings.
        self.lr = 0.0001
        self.batch_size = 256

        # Read by Lightning_Wrapper.__init__; without them building the
        # wrapper fails with AttributeError. These defaults target multiclass
        # classification (matching the metrics section) -- swap as needed.
        self.loss_fn = nn.CrossEntropyLoss()
        # Optimizer *class*: it is called later with (parameters, lr=...).
        self.optimizer = torch.optim.Adam

In [None]:
# Single shared configuration object passed to the data and model cells below.
CONFIG = Config()

In [None]:
# Seed every random number generator the notebook imports so that runs are
# repeatable. NumPy was previously left unseeded.
SEED = 6742
torch.manual_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)

# $\text{Read and prepare data}$

# $\text{Make data config}$

In [None]:
class Data_Dataset(Dataset):
    """Minimal in-memory dataset of ``(data, label)`` pairs.

    Serves as the template dataset: replace or extend it with real loading
    logic. Constructing it without arguments yields an empty dataset.

    Args:
        data: Indexable collection of samples (list, tensor, array, ...).
        labels: Indexable collection of labels, same length as ``data``.
        *args, **kwargs: Accepted and ignored, kept for subclasses.

    Raises:
        ValueError: If ``data`` and ``labels`` differ in length.
    """

    def __init__(self, data=None, labels=None, *args, **kwargs):
        super().__init__()
        self.data = [] if data is None else data
        self.labels = [] if labels is None else labels
        if len(self.data) != len(self.labels):
            raise ValueError(
                f"data and labels differ in length: "
                f"{len(self.data)} != {len(self.labels)}"
            )

    def __len__(self, *args, **kwargs):
        # Must be an int: len() and the DataLoader samplers rely on it.
        return len(self.data)

    def __getitem__(self, idx, *args, **kwargs):
        data = self.data[idx]
        label = self.labels[idx]
        return data, label

In [None]:
class Dataset_config:
    """Bundle the train / validation / test datasets with their dataloaders.

    Args:
        config: Object exposing ``batch_size``; used for all three loaders.
        *args, **kwargs: Accepted and ignored.
    """

    def __init__(self, config, *args, **kwargs):
        self.train_dataset = Data_Dataset()
        self.val_dataset = Data_Dataset()
        self.test_dataset = Data_Dataset()

        # Only the training loader shuffles; validation and test keep the
        # dataset order.
        self.train_dataloader = DataLoader(
            self.train_dataset,
            shuffle=True,
            # num_workers=2,
            batch_size=config.batch_size,
        )
        self.val_dataloader = DataLoader(
            self.val_dataset,
            shuffle=False,
            # num_workers=2,
            batch_size=config.batch_size,
        )
        self.test_dataloader = DataLoader(
            self.test_dataset,
            shuffle=False,
            # num_workers=2,
            batch_size=config.batch_size,
        )

In [None]:
# Build the datasets and dataloaders once, using the shared CONFIG.
DATA_CONFIG = Dataset_config(config=CONFIG)

In [None]:
# Display the sizes of the train / validation / test datasets as a tuple.
DATA_CONFIG.train_dataset.__len__(), \
DATA_CONFIG.val_dataset.__len__(), \
DATA_CONFIG.test_dataset.__len__()

# $\text{Make model}$

## $\text{Base PyTorch model}$

In [None]:
class Base_Pytorch_Model(nn.Module):
    """Placeholder network: declares no layers and produces no output yet."""

    def __init__(self):
        super().__init__()

    def forward(self, data, *args, **kwargs):
        """Return ``None`` for any input until a real architecture is added."""
        return None

In [None]:
# Instance used for the summary below (a fresh one is created again before training).
base_model = Base_Pytorch_Model()

In [None]:
# Iterator over training batches, used to grab one sample batch.
data_iter = iter(DATA_CONFIG.train_dataloader)

In [None]:
# Summarise the model using the shape of the data part (index 0) of one batch.
summary(base_model, next(data_iter)[0].shape, depth=5)

## $\text{Lightning model}$

In [None]:
class Lightning_Wrapper(ptl.LightningModule):
    """LightningModule that wraps a plain ``nn.Module`` for classification.

    Args:
        model: The network to train; called as ``model(data)``.
        config: Object exposing ``lr``, ``loss_fn`` (callable taking
            ``(output, label)``) and ``optimizer`` (optimizer class called
            with ``(parameters, lr=...)``).
    """

    def __init__(self, model, config):
        super().__init__()
        self.model = model
        self.learning_rate = config.lr
        self.loss_fn = config.loss_fn
        self.optimizer = config.optimizer

    def configure_optimizers(self):
        """Instantiate the configured optimizer on the wrapped model."""
        return self.optimizer(self.model.parameters(), lr=self.learning_rate)

    def forward(self, data):
        """Run the wrapped model on ``data``."""
        return self.model(data)

    def _shared_step(self, batch, stage):
        """Compute and log the loss and accuracy (in percent) for one batch.

        Errors are deliberately not caught: a swallowed exception would make
        the step return ``None`` and hide the failure from the trainer.

        Args:
            batch: Sequence whose first two items are ``(data, label)``; any
                further items are ignored.
            stage: Prefix for the logged metric names (``train``/``val``/``test``).

        Returns:
            The loss for the batch.
        """
        data, label, *_ = batch
        result = self(data)
        loss = self.loss_fn(result, label)

        # Assumes ``result`` holds per-class scores along dim 1 -- TODO confirm
        # against the concrete model.
        predicted_class = result.argmax(dim=1)
        accuracy = (predicted_class == label).float().mean() * 100

        self.log_dict({f"{stage}_acc": accuracy})
        self.log(f"{stage}_loss", loss, prog_bar=True)
        return loss

    # __________________TRAIN_____________________________________________________

    def training_step(self, batch, batch_idx):
        """Log ``train_acc`` / ``train_loss`` and return the loss."""
        return self._shared_step(batch, "train")

    # __________________VALIDATION_____________________________________________________

    def validation_step(self, batch, batch_idx):
        """Log ``val_acc`` / ``val_loss`` and return the loss."""
        return self._shared_step(batch, "val")

    # __________________TEST_____________________________________________________

    def test_step(self, batch, batch_idx):
        """Log ``test_acc`` / ``test_loss`` and return the loss."""
        return self._shared_step(batch, "test")

    # _____________________PREDICT_____________________________________________________

    def predict_step(self, batch, batch_idx):
        """Return ``(model_output, label)`` for one batch."""
        data, label, *_ = batch
        result = self(data)
        return (result, label)

In [None]:
# Optional cleanup when re-running the notebook: uncomment the lines below to
# move the models to the CPU and drop the objects from a previous run before
# collecting garbage.
# base_model.cpu()
# lightning_model.cpu()
# del trainer, base_model, lightning_model
# del tuner

import gc
gc.collect()

# $\text{Logging}$

In [None]:
# Start from a clean log directory: remove everything from previous runs
# (errors while deleting are ignored), then recreate the empty folder.
shutil.rmtree("lightning_logs", ignore_errors=True)
Path("lightning_logs").mkdir(exist_ok=True)

In [None]:
# Best-effort shutdown of any TensorBoard process left over from a previous run.
if platform.system() == "Windows":
    # Windows has no `kill` command; taskkill ends processes by image name.
    kill_command = ["taskkill", "/F", "/IM", "tensorboard.exe"]
else:
    kill_command = ["killall", "tensorboard"]

try:
    # check=False: a non-zero exit (e.g. nothing was running) is not an error.
    subprocess.run(kill_command, check=False)
except OSError as error:
    # The kill utility itself could not be started; report it and carry on.
    print(f"Could not run {kill_command[0]!r}: {error}")

In [None]:
%reload_ext tensorboard
%tensorboard --logdir=lightning_logs --port=6006

In [None]:
# Logger, checkpoint callbacks, trainer and model used for the training run.
version = None  # None lets the logger pick the next free version number
log_step = None  # None keeps the Trainer's default logging interval
logger = TensorBoardLogger("lightning_logs", name="", version=version)

# Store checkpoints inside the directory the logger actually writes to, so
# they stay with the matching run whatever `version` resolves to.
checkpoint_root = Path(logger.log_dir) / "checkpoints"

# Keeps the single checkpoint with the best monitored training loss.
checkpoint_callback_train = ModelCheckpoint(
    save_top_k=1,
    monitor="train_loss",
    dirpath=checkpoint_root / "train",
    filename="train_model-{epoch}-{train_loss}-{val_loss}",
    verbose=True,
    save_on_train_epoch_end=True,
)
# Keeps the single checkpoint with the best monitored validation loss.
checkpoint_callback_val = ModelCheckpoint(
    save_top_k=1,
    monitor="val_loss",
    dirpath=checkpoint_root / "val",
    filename="val_model-{epoch}-{train_loss}-{val_loss}",
    verbose=True,
    save_on_train_epoch_end=True,
)

trainer = ptl.Trainer(
    #     fast_dev_run = True,
    max_epochs=100,
    callbacks=[
        checkpoint_callback_train,
        checkpoint_callback_val,
    ],
    accelerator="auto",
    #     accumulate_grad_batches=64,
    benchmark=True,
    logger=logger,
    log_every_n_steps=log_step,
)

# Fresh model instances for this run.
base_model = Base_Pytorch_Model()
lightning_model = Lightning_Wrapper(base_model, CONFIG)

# $\text{Training step}$

In [None]:
# Run the learning-rate finder on the model before training.
# NOTE(review): presumably this writes the suggested rate back to
# `lightning_model.learning_rate` -- confirm against the Tuner documentation.
tuner = Tuner(trainer)
tuner.lr_find(lightning_model,
              train_dataloaders=DATA_CONFIG.train_dataloader,
              val_dataloaders=DATA_CONFIG.val_dataloader,
              )

# Train with the training loader and validate with the validation loader.
trainer.fit(lightning_model,
            train_dataloaders=DATA_CONFIG.train_dataloader,
            val_dataloaders=DATA_CONFIG.val_dataloader,
            )

# $\text{Testing step}$

## $\text{Train model best}$

In [None]:
# First "train_model*.ckpt" file found anywhere under lightning_logs
# (raises IndexError if no such checkpoint exists).
train_pth = list(Path("lightning_logs").rglob("train_model*.ckpt"))[0]
print(train_pth)

In [None]:
# Evaluate on the test loader using the weights stored in the train checkpoint.
train_result = trainer.test(
    lightning_model, dataloaders=DATA_CONFIG.test_dataloader, ckpt_path=train_pth
)

## $\text{Val model best}$

In [None]:
# First "val_model*.ckpt" file found anywhere under lightning_logs
# (raises IndexError if no such checkpoint exists).
val_pth = list(Path("lightning_logs").rglob("val_model*.ckpt"))[0]
print(val_pth)

In [None]:
# Evaluate on the test loader using the weights stored in the val checkpoint.
val_result = trainer.test(
    lightning_model, dataloaders=DATA_CONFIG.test_dataloader, ckpt_path=val_pth
)

In [None]:
# Keep whichever checkpoint reached the lower test loss; ties go to the val
# checkpoint. `best_model_path` is used by the prediction cell.
# NOTE(review): this selects a model on the test set, so the reported test
# metrics are optimistic -- consider comparing on validation data instead.
if train_result[0]["test_loss"] < val_result[0]["test_loss"]:
    best_model_path = train_pth
    print("Train model is better")
    print(f"train_model loss = {train_result[0]['test_loss']}")
    print(f"train_model accuracy = {train_result[0]['test_acc']}")
else:
    best_model_path = val_pth
    print("Val model is better")
    print(f"val_model loss = {val_result[0]['test_loss']}")
    print(f"val_model accuracy = {val_result[0]['test_acc']}")

# $\text{Predicting step}$

In [None]:
# Run prediction over the test loader with the selected checkpoint; `result`
# holds one predict_step return value per batch.
result = trainer.predict(
    lightning_model, dataloaders=DATA_CONFIG.test_dataloader, ckpt_path=best_model_path
)

In [None]:
# Stitch the per-batch (output, label) pairs into one tensor of model outputs
# and one tensor of ground-truth labels.
predict_output = torch.cat([batch_output for batch_output, _ in result])
predict_truth = torch.cat([batch_label for _, batch_label in result])

# $\text{Metrics calculation}$

## $\text{Parameters}$

### $\text{Classification}$

In [None]:
# Number of classes and hard class predictions used by the metric cells below.
num_class, predict_class = get_metric_result_parameter(predict_output, predict_truth, "classification")
# Class index -> display name, used for plot legends and axis labels.
# Defaults to the index itself so every class has an entry; replace the
# values with real class names.
name_dict = {class_idx: str(class_idx) for class_idx in range(num_class)}

## $\text{Accuracy}$

In [None]:
# Accuracy with average="none", i.e. no averaging across classes.
accuracy = tmc.MulticlassAccuracy(
    num_classes=num_class,
    average="none"
)
accuracy(predict_output, predict_truth)

## $\text{Precision}$

In [None]:
# Precision with average="none", i.e. no averaging across classes.
precision = tmc.MulticlassPrecision(
    num_classes=num_class,
    average="none",
)

precision(predict_output, predict_truth)

## $\text{Recall}$

In [None]:
# Recall with average="none", i.e. no averaging across classes.
recall = tmc.MulticlassRecall(
    num_classes=num_class,
    average="none",
)

recall(predict_output, predict_truth)

## $\text{F1 score}$

In [None]:
# F1 score with average='none', i.e. no averaging across classes.
F1_score = tmc.MulticlassF1Score(
    num_classes=num_class,
    average='none',
)

F1_score(predict_output, predict_truth)

## $\text{ROC}$

In [None]:
# ROC curves for the collected predictions.
ROC = tmc.MulticlassROC(
    num_classes=num_class,
)

ROC.update(predict_output, predict_truth)

fig_ROC, ax_ROC = ROC.plot(score=True)

# Relabel the legend: the first whitespace-separated token of each label is
# parsed as a class index and replaced by its name from `name_dict`; the
# second token is kept as is. Raises KeyError if `name_dict` lacks an index.
handles_ROC, labels_ROC = ax_ROC.get_legend_handles_labels()
labels_ROC = [f"{name_dict[int(x.split()[0])]} {x.split()[1]}" for x in labels_ROC]
ax_ROC.legend(handles_ROC, labels_ROC)

plt.show()

## $\text{Precision Recall curve}$

In [None]:
# Precision-recall curves for the collected predictions.
PRC = tmc.MulticlassPrecisionRecallCurve(
    num_classes=num_class,
)

PRC.update(predict_output, predict_truth)

fig_PRC, ax_PRC = PRC.plot(score=True)
handles_PRC, labels_PRC = ax_PRC.get_legend_handles_labels()

# Relabel the legend: the first whitespace-separated token of each label is
# parsed as a class index and replaced by its name from `name_dict`; the
# second token is kept as is. Raises KeyError if `name_dict` lacks an index.
labels_PRC = [
    f"{name_dict[int(x.split()[0])]} {x.split()[1]}" for x in labels_PRC
]
ax_PRC.legend(handles_PRC, labels_PRC)

plt.show()

## $\text{Confusion matrix}$

In [None]:
# Confusion matrix built from the hard class predictions, created with
# normalize='true'.
confusion_matrix = tmc.MulticlassConfusionMatrix(
    num_classes=num_class,
    normalize='true',
)

confusion_matrix.update(predict_class, predict_truth)

# Pass the class names as a real list rather than a dict view, matching the
# list of labels that plot() documents.
fig_confusion_matrix, ax_confusion_matrix = confusion_matrix.plot(
    labels=list(name_dict.values()))
plt.show()