# Mount drive and append path to PYTHONPATH


In [None]:
import os
import sys

from google.colab import drive

# Mount Google Drive into the Colab filesystem at /content/drive.
drive.mount("/content/drive")
# Add the project folder on Drive to the import search path; presumably the
# project modules imported further down (colab_functions, colab_utils,
# prepare_data, train_NN) live there — TODO confirm.
sys.path.append("/content/drive/MyDrive/DeepLCMS/train_google_colab")

# Import and install libraries

In [None]:
%%capture
!pip install lightning
!pip install timm
!pip install torchinfo
!pip install scikit-posthocs
!pip install optuna

In [None]:
import gc
from typing import Optional, Tuple
from pathlib import Path

import colab_functions
import colab_utils
import pandas as pd
import prepare_data
import pytorch_lightning as pl
import timm
import torch
import torch.nn.functional as F
import torchinfo
import train_NN
from google.colab import drive
from lightning.pytorch.loggers import CSVLogger
from pytorch_lightning import LightningModule
from pytorch_lightning.callbacks import Callback, EarlyStopping
from pytorch_lightning.trainer.trainer import Trainer
from timm import create_model
from torchmetrics.classification import (
    BinaryAUROC,
    BinaryF1Score,
    BinaryPrecision,
    BinaryRecall,
)

import optuna
from torch import nn
from torch.optim import Adam, SGD, RMSprop
from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingLR
from pytorch_lightning.callbacks import EarlyStopping
from torchmetrics.classification import BinaryF1Score, BinaryPrecision, BinaryRecall


from optuna.visualization import plot_optimization_history
from optuna.visualization import plot_parallel_coordinate
from optuna.visualization import plot_param_importances
from optuna.visualization import plot_contour

In [None]:
# Restrict CUDA in this process to the GPU with index 0 by setting the
# CUDA_VISIBLE_DEVICES environment variable.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Unzip data

In [None]:
%%script echo skipping
!unzip -q "*.zip"

# Check if GPU is used

In [None]:
# Ask the project helper for the device to use; presumably it selects the GPU
# when one is available — TODO confirm against colab_functions.get_device.
device = colab_functions.get_device()

# Getting a tunable model

In [None]:
class Resnet_model(pl.LightningModule):
    """Binary classifier: frozen pretrained ResNet-50d with a trainable head.

    The timm ``resnet50d.a3_in1k`` backbone is loaded with pretrained weights
    and frozen; its ``fc`` layer is replaced by a small MLP that ends in a
    single logit. Optimizer, scheduler, learning rate and dropout are read
    from ``hyperparameters`` so the model can be tuned externally.

    Args:
        hyperparameters: Mapping with the keys ``"optimizer"`` (``"Adam"``,
            ``"SGD"`` or ``"RMSprop"``), ``"scheduler"``
            (``"ReduceLROnPlateau"`` or ``"CosineAnnealingLR"``), ``"lr"``
            and ``"dropout"``.
    """

    def __init__(self, hyperparameters):
        super().__init__()
        self.hyperparameters = hyperparameters
        self.model = create_model("resnet50d.a3_in1k", pretrained=True, num_classes=1)

        # Freeze every pretrained parameter. The replacement head is created
        # afterwards, so its parameters stay trainable.
        for param in self.model.parameters():
            param.requires_grad = False

        self.model.fc = nn.Sequential(
            nn.Linear(in_features=2048, out_features=512, bias=True),
            nn.ReLU(),
            nn.Dropout(p=self.hyperparameters["dropout"]),
            nn.Linear(in_features=512, out_features=256, bias=True),
            nn.ReLU(),
            nn.Linear(in_features=256, out_features=1, bias=True),
        )

    def forward(self, x):
        """Return the raw (pre-sigmoid) logits of the network for ``x``."""
        return self.model(x)

    def _shared_step(self, batch, stage, metrics_on_prog_bar):
        """Compute the loss for one batch and log loss and metrics.

        Args:
            batch: ``(inputs, targets)`` pair from the dataloader.
            stage: Prefix for the logged names, e.g. ``"train"`` or ``"val"``.
            metrics_on_prog_bar: Whether accuracy, F1, precision and recall
                are shown on the progress bar (the loss always is).

        Returns:
            The binary cross-entropy loss of the batch.
        """
        x, y = batch

        # Drop only the trailing singleton dimension. A bare ``squeeze()``
        # would also remove the batch dimension when the batch holds a single
        # sample, breaking the loss and the accuracy computation.
        y_pred_logits = self(x).squeeze(-1)

        # BCEWithLogitsLoss fuses the sigmoid into the loss, which is more
        # numerically stable than sigmoid followed by BCELoss.
        loss = nn.BCEWithLogitsLoss()(y_pred_logits, y.float())
        self.log(
            f"{stage}_loss",
            loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )

        # Hard 0/1 predictions at the 0.5 probability threshold.
        y_pred_class = torch.round(torch.sigmoid(y_pred_logits))

        acc = (y_pred_class == y).float().mean()
        self.log(
            f"{stage}_acc",
            acc,
            on_step=False,
            on_epoch=True,
            prog_bar=metrics_on_prog_bar,
            logger=True,
        )

        # NOTE: each score is computed per batch; with on_epoch=True the
        # logged epoch value is the aggregate of these per-batch scores.
        metric_classes = {
            "f1": BinaryF1Score,
            "precision": BinaryPrecision,
            "recall": BinaryRecall,
        }
        for metric_name, metric_cls in metric_classes.items():
            metric = metric_cls().to(y.device)
            self.log(
                f"{stage}_{metric_name}",
                metric(y_pred_class, y),
                on_step=False,
                on_epoch=True,
                prog_bar=metrics_on_prog_bar,
                logger=True,
            )

        return loss

    def training_step(self, batch, batch_idx):
        """Run one training batch, log ``train_*`` values and return the loss."""
        return self._shared_step(batch, stage="train", metrics_on_prog_bar=False)

    def validation_step(self, batch, batch_idx):
        """Run one validation batch and log ``val_*`` values."""
        self._shared_step(batch, stage="val", metrics_on_prog_bar=True)

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Return logits for a prediction batch.

        ``batch`` may be a tensor, or a list/tuple whose first element is the
        input tensor (any further elements are ignored).
        """
        if isinstance(batch, (list, tuple)):
            # Assuming the first element in the sequence is the input tensor
            input_tensor = batch[0]
            return self(input_tensor)
        else:
            # If batch is already a tensor, proceed as usual
            print("Input Shape:", batch.shape)
            return self(batch)

    def configure_optimizers(self):
        """Build the optimizer and LR scheduler chosen in the hyperparameters.

        Returns:
            A ``([optimizer], [scheduler])`` pair. For ``ReduceLROnPlateau``
            the scheduler entry is a dict that also names ``val_loss`` as the
            monitored metric.

        Raises:
            ValueError: If the optimizer or scheduler name is not supported.
        """
        optimizer_classes = {"Adam": Adam, "SGD": SGD, "RMSprop": RMSprop}
        optimizer_name = self.hyperparameters["optimizer"]
        if optimizer_name not in optimizer_classes:
            raise ValueError(
                f"Unsupported optimizer {optimizer_name!r}; "
                f"expected one of {sorted(optimizer_classes)}"
            )
        optimizer = optimizer_classes[optimizer_name](
            self.parameters(), lr=self.hyperparameters["lr"], weight_decay=2e-5
        )

        scheduler_name = self.hyperparameters["scheduler"]
        if scheduler_name == "ReduceLROnPlateau":
            scheduler = {
                "scheduler": ReduceLROnPlateau(
                    optimizer, mode="min", factor=0.1, patience=3
                ),
                "monitor": "val_loss",
            }
        elif scheduler_name == "CosineAnnealingLR":
            scheduler = CosineAnnealingLR(optimizer, T_max=50, eta_min=0)
        else:
            raise ValueError(
                f"Unsupported scheduler {scheduler_name!r}; expected "
                "'ReduceLROnPlateau' or 'CosineAnnealingLR'"
            )

        return [optimizer], [scheduler]


def objective(trial):
    """Optuna objective: train one ``Resnet_model`` and return its val loss.

    Samples an optimizer, a scheduler, a learning rate and a dropout rate
    from ``trial``, trains with early stopping on ``val_loss`` using the
    module-level ``train_dataloader`` and ``val_dataloader``, and writes CSV
    logs under ``logs/<trial number>``.

    Args:
        trial: The Optuna trial used to sample hyperparameters.

    Returns:
        The last logged ``val_loss`` as a Python float.
    """
    hyperparameters = {
        "optimizer": trial.suggest_categorical("optimizer", ["Adam", "SGD", "RMSprop"]),
        "scheduler": trial.suggest_categorical(
            "scheduler", ["ReduceLROnPlateau", "CosineAnnealingLR"]
        ),
        # suggest_float(..., log=True) is the supported replacement for the
        # deprecated suggest_loguniform and samples from the same distribution.
        "lr": trial.suggest_float("lr", 1e-5, 1e-1, log=True),
        # NOTE(review): the upper bound of 1 allows a dropout that zeroes
        # every activation of the head — confirm this range is intended.
        "dropout": trial.suggest_float("dropout", 0.01, 1),
    }

    model = Resnet_model(hyperparameters)
    logger = CSVLogger("logs", name=str(trial.number))
    # NOTE(review): EarlyStopping patience (1) is shorter than the
    # ReduceLROnPlateau patience (3) set in the model — confirm intended.
    trainer = pl.Trainer(
        logger=logger,
        max_epochs=50,
        callbacks=[EarlyStopping(monitor="val_loss", patience=1)],
    )

    trainer.fit(model, train_dataloader, val_dataloader)

    return trainer.callback_metrics["val_loss"].item()


def print_callback(study, trial):
    """Print the number, objective value and parameters of a finished trial.

    Matches the ``(study, trial)`` callback signature; ``study`` is not used.
    """
    number = trial.number
    value = trial.value
    params = trial.params
    summary = f"Trial {number} finished with value: {value} and parameters: {params}"
    print(summary)

In [None]:
# Derive the train/val/test preprocessing pipelines from the project's model.
preprocess_train, preprocess_val, preprocess_test = prepare_data.get_timm_transforms(
    train_NN.Resnet_model()
)

# Build one dataloader per split using the matching preprocessing pipeline.
train_dataloader, val_dataloader, test_dataloader = prepare_data.get_dataloaders(
    preprocess_train=preprocess_train,
    preprocess_val=preprocess_val,
    preprocess_test=preprocess_test,
)

# Search for the hyperparameters that minimize the value returned by
# `objective`, reporting every finished trial through `print_callback`.
study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=100, callbacks=[print_callback])

In [None]:
# Objective value (validation loss) of each trial over the course of the study.
plot_optimization_history(study)

In [None]:
# Parallel-coordinate view relating sampled hyperparameter values to the objective.
plot_parallel_coordinate(study)

In [None]:
# Contour plots of the objective value over pairs of hyperparameters.
plot_contour(study)

In [None]:
# Estimated importance of each hyperparameter for the objective value.
plot_param_importances(study)