### Import Libraries

In [None]:
import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
import imgaug.augmenters as iaa

import sklearn.metrics as metrics
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path

from dataset import LungDataset
from model import UNet

### Data Preparation

In [None]:
# Augmentation pipeline handed to the training dataset.
# Each augmenter is named first so the composed sequence reads top to bottom.
affine = iaa.Affine(scale=(0.85, 1.15), rotate=(-45, 45))
elastic = iaa.ElasticTransformation()

seq = iaa.Sequential([affine, elastic])

In [None]:
# Directories holding the preprocessed training and validation splits
train_path, val_path = (
    Path('../Preprocessed/train/'),
    Path('../Preprocessed/val/'),
)

# The training dataset receives the augmentation pipeline `seq`;
# the validation dataset receives None in its place.
train_dataset = LungDataset(train_path, seq)
val_dataset = LungDataset(val_path, None)


#### Oversampling

In [None]:
# Label every training sample: 1 if its mask contains at least one non-zero
# (tumorous) pixel, 0 otherwise.
# NOTE(review): train_dataset was constructed with the augmentation pipeline,
# so the masks inspected here may already be augmented - confirm this is intended.
target_list = [1 if np.any(label) else 0 for _, label in train_dataset]

# Count each class by its label value instead of by position, so a missing
# class is detected rather than silently mis-indexed.
unique = np.unique(target_list, return_counts=True)
class_counts = dict(zip(unique[0].tolist(), unique[1]))
num_negative = class_counts.get(0, 0)
num_positive = class_counts.get(1, 0)

if num_negative == 0 or num_positive == 0:
    raise ValueError(
        "Oversampling needs both tumor and non-tumor samples in the training set "
        f"(found {num_negative} without tumor, {num_positive} with tumor)."
    )

# Number of tumor-free samples per tumorous sample; used as the sampling
# weight of the tumorous class.
ratio = num_negative / num_positive
ratio

In [None]:
# One sampling weight per training sample: tumor-free samples (target 0)
# get weight 1, every other sample gets the class ratio.
weight_list = [1 if target == 0 else ratio for target in target_list]

In [None]:
# Draw as many samples per epoch as there are weights, using the per-sample weights
sampler = torch.utils.data.WeightedRandomSampler(
    weights=weight_list,
    num_samples=len(weight_list),
)

In [None]:
# Settings shared by both data loaders
batch_size = 8
num_workers = 4
loader_kwargs = {"batch_size": batch_size, "num_workers": num_workers}

# Training batches are drawn through the weighted sampler
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    sampler=sampler,
    **loader_kwargs,
)

# Validation batches are read in dataset order
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    shuffle=False,
    **loader_kwargs,
)

### Create Segmentation Model

In [None]:
# Define the TumorSegmentation class as a LightningModule
class TumorSegmentation(pl.LightningModule):
    """LightningModule that trains a ``UNet`` with a BCE-with-logits loss.

    The loss is ``torch.nn.BCEWithLogitsLoss``, so the network output is
    treated as raw logits throughout; a sigmoid is applied only where a
    binary mask is needed for display.
    """

    def __init__(self):
        super().__init__()

        # Segmentation network
        self.model = UNet()

        # Adam over the network parameters with a learning rate of 1e-4
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-4)

        # Binary cross-entropy computed directly on logits
        self.loss_fn = torch.nn.BCEWithLogitsLoss()

    def forward(self, data):
        """Return the raw output of the wrapped model for ``data``."""
        return self.model(data)

    def _shared_step(self, batch, batch_idx, stage):
        """Run one train/val step: forward pass, loss, and logging.

        Args:
            batch: A ``(ct_scan, mask)`` pair as produced by the data loader.
            batch_idx: Index of the batch within the current epoch.
            stage: ``"Train"`` or ``"Val"``; prefixes the logged loss and figure.

        Returns:
            The loss for this batch.
        """
        ct_scan, mask = batch
        mask = mask.float()

        pred = self(ct_scan)
        loss = self.loss_fn(pred, mask)

        self.log(f"{stage} Loss", loss)
        # Only every 50th batch is visualised, to keep the event file small
        if batch_idx % 50 == 0:
            # detach(): the figure must not hold on to the autograd graph
            self.log_images(ct_scan.cpu(), pred.detach().cpu(), mask.cpu(), stage)

        return loss

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch."""
        return self._shared_step(batch, batch_idx, "Train")

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss for one batch."""
        return self._shared_step(batch, batch_idx, "Val")

    def log_images(self, ct_scan, pred, mask, name):
        """Log a side-by-side figure of the actual and the predicted mask.

        Only the first sample / first channel of each input is drawn
        (``x[0][0]``); presumably inputs are (batch, channel, H, W) - TODO confirm.

        Args:
            ct_scan: CT images on the CPU.
            pred: Raw model output (logits) on the CPU.
            mask: Ground-truth masks on the CPU.
            name: Prefix of the figure tag (e.g. ``"Train"`` or ``"Val"``).
        """
        # The model output is a logit, so map it to a probability before
        # thresholding at 0.5.
        # NOTE(review): assumes UNet has no final sigmoid - confirm.
        pred = (torch.sigmoid(pred.detach()) > 0.5).float()

        fig, axis = plt.subplots(1, 2)
        overlays = (("Actual", mask[0][0]), ("Predicted", pred[0][0]))
        for ax, (title, overlay) in zip(axis, overlays):
            ax.imshow(ct_scan[0][0], cmap="bone")
            # Hide background pixels so only the segmented region is overlaid
            ax.imshow(np.ma.masked_where(overlay == 0, overlay), alpha=0.6)
            ax.set_title(title)

        # Add the figure to the experiment's logger for visualization
        self.logger.experiment.add_figure(f"{name} Actual vs Prediction", fig, self.global_step)

    def log_confusion_matrix(self, mask, pred):
        """Plot a pixel-wise confusion matrix heatmap and show it.

        Args:
            mask: Ground-truth mask tensor.
            pred: Prediction tensor. It is thresholded at 0.5 exactly as given,
                so pass probabilities rather than raw logits.
        """
        pred = (pred > 0.5).float().cpu()
        mask = mask.float().cpu()
        cm = metrics.confusion_matrix(mask.flatten(), pred.flatten())

        # Plot the confusion matrix using a heatmap
        plt.figure()
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
        plt.title('Confusion Matrix')
        plt.xlabel('Predicted')
        plt.ylabel('Actual')
        plt.show()

    def log_roc_curve(self, mask, pred):
        """Draw the ROC curve with its AUC on the current matplotlib axes.

        ``mask`` and ``pred`` are passed unchanged to
        ``sklearn.metrics.roc_curve``.
        NOTE(review): that function presumably needs 1-D inputs - flatten
        before calling; confirm.
        """
        fpr, tpr, thresholds = metrics.roc_curve(mask, pred)
        roc_auc = metrics.auc(fpr, tpr)

        # Create a plot for the ROC curve
        sns.set_style("darkgrid")
        roc_df = pd.DataFrame({'fpr': fpr, 'tpr': tpr})
        sns.lineplot(data=roc_df, x='fpr', y='tpr', color='blue', label=f'ROC AUC = {roc_auc:.2f}')
        # Diagonal reference line from (0, 0) to (1, 1)
        sns.lineplot(x=[0, 1], y=[0, 1], color='gray', linestyle='--')
        plt.title('Receiver Operating Characteristic (ROC) Curve')
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.legend(loc="lower right")

    def configure_optimizers(self):
        """Return the optimizer created in ``__init__``."""
        return [self.optimizer]

In [None]:
# Checkpoint callback: keeps one checkpoint, chosen by the minimum logged 'Val Loss'
checkpoint_callback = ModelCheckpoint(
    dirpath='./models',
    filename='best_model',
    monitor='Val Loss',
    mode='min',
    save_top_k=1,
)

In [None]:
# Initialize the segmentation LightningModule
model = TumorSegmentation()

In [None]:
# TensorBoard logger writing under ../logs
tb_logger = TensorBoardLogger(save_dir='../logs')

# Single-GPU trainer: one epoch, metrics logged every step, checkpoint callback attached
trainer = pl.Trainer(
    accelerator='gpu',
    devices=1,
    max_epochs=1,
    log_every_n_steps=1,
    logger=tb_logger,
    callbacks=checkpoint_callback,
)

In [None]:
# Fit the model, passing the training loader and then the validation loader
trainer.fit(model, train_loader, val_loader)