# Imports

In [None]:
from dataset import GradingDataset
from grading_model import GradingModel
import torch
from torch import nn
from torch.utils.data import DataLoader, random_split
from torcheval.metrics import MulticlassAccuracy, MulticlassAUPRC, MulticlassAUROC, MulticlassF1Score
from torch.utils.tensorboard import SummaryWriter
import cv2

### Modify `sys.path` so that the UNet module can be imported

In [None]:
import os
import sys

# The notebook lives one level below the project root; make that root
# importable so `segmentation.unet` can be resolved.
project_root = os.path.abspath(os.path.join(os.getcwd(), os.pardir))

already_importable = project_root in sys.path
if not already_importable:
    sys.path.insert(0, project_root)

In [None]:
# Debug aid: list every import search location, one per line, to check that
# the project root is present.
for search_dir in sys.path:
    print(search_dir)

In [None]:
from segmentation.unet import UNet

# Main code

In [None]:
BATCH_SIZE: int = 1  # samples per batch for every DataLoader below

In [None]:
# Prefer the first GPU visible to PyTorch; fall back to the CPU otherwise.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

In [None]:
# Root of the IDRiD disease-grading data. Defaults to the original location;
# set the GRADING_DATA_ROOT environment variable to run on another machine.
grading_data_root = os.environ.get(
    "GRADING_DATA_ROOT", "/home/wilk/diabetic_retinopathy/datasets/grading"
)

train_images_dir = os.path.join(grading_data_root, "train_set", "images")
train_labels_csv = os.path.join(
    grading_data_root,
    "train_set",
    "labels",
    "a. IDRiD_Disease Grading_Training Labels.csv",
)

test_images_dir = os.path.join(grading_data_root, "test_set", "images")
test_labels_csv = os.path.join(
    grading_data_root,
    "test_set",
    "labels",
    "b. IDRiD_Disease Grading_Testing Labels.csv",
)

In [None]:
full_train_dataset = GradingDataset(train_images_dir, train_labels_csv)

# 80/20 train/validation split with a fixed seed, so every run of the notebook
# (and any later evaluation of a saved checkpoint) uses the same validation set.
SPLIT_SEED = 42
train_dataset, validation_dataset = random_split(
    full_train_dataset,
    [0.8, 0.2],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
test_dataset = GradingDataset(test_images_dir, test_labels_csv)

In [None]:
# Only the training loader is shuffled; evaluation loaders keep a fixed order.
train_dataloader = DataLoader(train_dataset, BATCH_SIZE, shuffle=True)
validation_dataloader = DataLoader(validation_dataset, BATCH_SIZE, shuffle=False)
# Bound to its own name so that `test_dataset` keeps referring to the dataset.
test_dataloader = DataLoader(test_dataset, BATCH_SIZE, shuffle=False)

In [None]:
# nn.Module.to() moves parameters in place and returns the same module.
grading_model = GradingModel().to(device)

In [None]:
# Pretrained mask generator whose output is fed to the grading model.
# NOTE(review): UNet(3, 5) presumably means 3 input channels and 5 output
# mask channels — confirm against segmentation.unet.UNet.
segmentation_model = UNet(3, 5)
segmentation_model.to(device)
segmentation_model.load_state_dict(
    torch.load(
        "/home/wilk/diabetic_retinopathy/segmentation/segmentation_generator.pth",
        weights_only=True,  # only load tensors; do not unpickle arbitrary objects
        map_location=device,
    )
)
# The generator is only used for inference in this notebook, so put it in eval
# mode immediately instead of relying on a later cell to do so.
segmentation_model.eval()

In [None]:
# Only the grading model is optimised; the segmentation model is not registered here.
optimizer = torch.optim.Adam(params=grading_model.parameters(), lr=1e-5)

In [None]:
criterion = nn.CrossEntropyLoss()  # expects raw logits and integer class targets

In [None]:
# TensorBoard event files for this experiment go under runs/grading_experiment_1.
writer = SummaryWriter(log_dir="runs/grading_experiment_1")

# Pretraining the classifier

In [None]:
def validate(grading_model, validation_dataloader, criterion, epoch=None, num_classes=5):
    """Evaluate the image-only ``grading_model`` and compute classification metrics.

    The model is switched to eval mode and run under ``torch.no_grad()``; the
    softmax of its logits is scored with torcheval's multiclass accuracy, F1,
    AUPRC and AUROC.

    Args:
        grading_model: Model whose forward returns ``(logits, extra)``; only
            ``logits`` is used here.
        validation_dataloader: Iterable of ``(input_batch, target_batch)`` pairs.
        criterion: Loss applied as ``criterion(logits, target_batch)``.
        epoch: Not used by this variant; accepted because ``train`` passes it.
        num_classes: Number of grading classes passed to every metric.

    Returns:
        ``(mean_loss, accuracy, f1, auprc, auroc)``: the average per-batch
        loss followed by each metric's ``compute()`` result.

    Raises:
        ValueError: If the dataloader yields no batches.
    """
    grading_model.eval()
    total_loss = 0.0
    probability_batches = []
    target_batches = []

    # Evaluation only: no autograd graph is needed, which saves memory and time.
    with torch.no_grad():
        for input_batch, target_batch in validation_dataloader:
            input_batch = input_batch.to(device)
            target_batch = target_batch.to(device)

            logits, _ = grading_model(input_batch)
            # Restore the batch dimension if the model returned 1-D logits,
            # so the loss and the metrics always see (batch, classes).
            if logits.dim() < 2:
                logits = logits.unsqueeze(0)

            total_loss += criterion(logits, target_batch).item()
            probability_batches.append(torch.softmax(logits, dim=-1).float().cpu())
            # reshape(-1) treats a single-sample batch exactly like a larger one.
            target_batches.append(target_batch.reshape(-1).long().cpu())

    if not probability_batches:
        raise ValueError("validation_dataloader yielded no batches")

    mean_validation_loss = total_loss / len(probability_batches)
    predicted_values = torch.cat(probability_batches)
    targets = torch.cat(target_batches)

    scores = {}
    for metric_name, metric_cls in (
        ("accuracy", MulticlassAccuracy),
        ("f1", MulticlassF1Score),
        ("auprc", MulticlassAUPRC),
        ("auroc", MulticlassAUROC),
    ):
        metric = metric_cls(num_classes=num_classes)
        metric.update(predicted_values, targets)
        scores[metric_name] = metric.compute()

    return (
        mean_validation_loss,
        scores["accuracy"],
        scores["f1"],
        scores["auprc"],
        scores["auroc"],
    )

In [None]:
def train(grading_model, train_dataloader, validation_dataloader, optimizer, criterion, n_epochs):
    """Pretrain ``grading_model`` on images alone for ``n_epochs`` epochs.

    Each epoch performs one optimisation pass over ``train_dataloader``, then
    evaluates both splits with ``validate``. All five results per split are
    logged to the TensorBoard ``writer``, and both mean losses are printed.
    Only the validation call receives ``epoch``.
    """
    train_tags = ("train/Loss", "train/Accuracy", "train/F1 Score", "train/AUPRC", "train/AUROC")
    validation_tags = (
        "validation/Loss",
        "validation/Accuracy",
        "validation/F1 Score",
        "validation/AUPRC",
        "validation/AUROC",
    )

    for epoch in range(n_epochs):
        grading_model.train()
        segmentation_model.eval()

        for images, labels in train_dataloader:
            optimizer.zero_grad()
            images, labels = images.to(device), labels.to(device)
            logits, _ = grading_model(images)
            criterion(logits, labels).backward()
            optimizer.step()

        train_results = validate(grading_model, train_dataloader, criterion)
        validation_results = validate(grading_model, validation_dataloader, criterion, epoch)

        # Same tag/value pairs and order as logging each scalar individually.
        for tag, value in zip(train_tags + validation_tags, train_results + validation_results):
            writer.add_scalar(tag, value, epoch)

        print(f"Epoch: {epoch}, Mean training loss: {train_results[0]}, Mean validation loss: {validation_results[0]}")

In [None]:
# Pretraining phase: 60 epochs without segmentation masks.
train(grading_model, train_dataloader, validation_dataloader, optimizer, criterion, n_epochs=60)

# Training with the mask generator

In [None]:
def _to_uint8_image(tensor):
    """Min-max scale ``tensor`` to [0, 255] and return it as a ``uint8`` NumPy array.

    A constant tensor (max == min) maps to all zeros instead of producing NaNs
    through a division by zero.
    """
    tensor = tensor.detach().float().cpu()
    low, high = tensor.min(), tensor.max()
    if high > low:
        tensor = (tensor - low) / (high - low)
    else:
        tensor = torch.zeros_like(tensor)
    # Truncating cast, matching numpy's astype('uint8') on non-negative floats.
    return (tensor * 255).to(torch.uint8).contiguous().numpy()


def _save_debug_images(input_batch, attention_maps, epoch, num_classes):
    """Write the batch's first image and its per-class attention maps as PNG files.

    Writes ``output_masks/reference_image.png`` (overwritten on every call) and
    ``output_masks/mask_{i}/mask_{epoch}.png`` for each class ``i``. Target
    directories are created first, because ``cv2.imwrite`` reports failure
    only through its return value rather than by raising.
    """
    os.makedirs("output_masks", exist_ok=True)
    # Reorder (C, H, W) -> (H, W, C) for OpenCV.
    # NOTE(review): OpenCV treats 3-channel data as BGR; if the dataset yields
    # RGB the saved colours are swapped — confirm against GradingDataset.
    reference_image = _to_uint8_image(input_batch[0].permute(1, 2, 0))
    cv2.imwrite("output_masks/reference_image.png", reference_image)

    # assumes attention_maps[0][i] is a 2-D map for class i — TODO confirm
    for i in range(num_classes):
        os.makedirs(f"output_masks/mask_{i}", exist_ok=True)
        attention_map = _to_uint8_image(attention_maps[0][i])
        cv2.imwrite(f"output_masks/mask_{i}/mask_{epoch}.png", attention_map)


def validate(grading_model, validation_dataloader, criterion, epoch=None, num_classes=5):
    """Evaluate ``grading_model`` with segmentation masks and compute metrics.

    Masks come from the global ``segmentation_model``. Both models run in eval
    mode under ``torch.no_grad()``. When ``epoch`` is a multiple of 5, the
    first sample of the first batch and its attention maps are saved under
    ``output_masks/`` for visual inspection.

    Args:
        grading_model: Model called as ``grading_model(images, masks)`` that
            returns ``(logits, attention_maps)``.
        validation_dataloader: Iterable of ``(input_batch, target_batch)`` pairs.
        criterion: Loss applied as ``criterion(logits, target_batch)``.
        epoch: Current epoch; ``None`` disables the image dumps.
        num_classes: Number of grading classes (metrics and saved maps).

    Returns:
        ``(mean_loss, accuracy, f1, auprc, auroc)``: the average per-batch
        loss followed by each torcheval metric's ``compute()`` result.

    Raises:
        ValueError: If the dataloader yields no batches.
    """
    grading_model.eval()
    # Keep the mask generator in inference mode even if validate() runs before
    # train() has switched it.
    segmentation_model.eval()

    total_loss = 0.0
    probability_batches = []
    target_batches = []

    # Evaluation only: no autograd graph is needed for either model.
    with torch.no_grad():
        for batch_index, (input_batch, target_batch) in enumerate(validation_dataloader):
            input_batch = input_batch.to(device)
            target_batch = target_batch.to(device)

            masks = segmentation_model(input_batch)
            logits, attention_maps = grading_model(input_batch, masks)

            if epoch is not None and epoch % 5 == 0 and batch_index == 0:
                _save_debug_images(input_batch, attention_maps, epoch, num_classes)

            # Restore the batch dimension if the model returned 1-D logits.
            if logits.dim() < 2:
                logits = logits.unsqueeze(0)

            total_loss += criterion(logits, target_batch).item()
            probability_batches.append(torch.softmax(logits, dim=-1).float().cpu())
            # reshape(-1) treats a single-sample batch exactly like a larger one.
            target_batches.append(target_batch.reshape(-1).long().cpu())

    if not probability_batches:
        raise ValueError("validation_dataloader yielded no batches")

    mean_validation_loss = total_loss / len(probability_batches)
    predicted_values = torch.cat(probability_batches)
    targets = torch.cat(target_batches)

    scores = {}
    for metric_name, metric_cls in (
        ("accuracy", MulticlassAccuracy),
        ("f1", MulticlassF1Score),
        ("auprc", MulticlassAUPRC),
        ("auroc", MulticlassAUROC),
    ):
        metric = metric_cls(num_classes=num_classes)
        metric.update(predicted_values, targets)
        scores[metric_name] = metric.compute()

    return (
        mean_validation_loss,
        scores["accuracy"],
        scores["f1"],
        scores["auprc"],
        scores["auroc"],
    )

In [None]:
def _train_step(grading_model, input_batch, target_batch, optimizer, criterion):
    """Run one optimisation step of ``grading_model`` on a single batch.

    Per-batch tensors (device copies, masks, logits) are local to this helper.
    They are therefore released as soon as it returns, so no explicit ``del``
    is needed afterwards.
    """
    optimizer.zero_grad()
    input_batch = input_batch.to(device)
    target_batch = target_batch.to(device)

    # The segmentation model only supplies a fixed input for the grading model
    # and is never back-propagated into, so skip building its autograd graph.
    with torch.no_grad():
        masks = segmentation_model(input_batch)

    logits, _ = grading_model(input_batch, masks)
    # Restore the batch dimension if the model returned 1-D logits.
    if logits.dim() < 2:
        logits = logits.unsqueeze(0)

    loss = criterion(logits, target_batch)
    loss.backward()
    optimizer.step()


def train(grading_model, train_dataloader, validation_dataloader, optimizer, criterion, n_epochs):
    """Train ``grading_model`` with segmentation-model masks for ``n_epochs`` epochs.

    Each epoch performs one optimisation pass over ``train_dataloader``, frees
    cached GPU memory, then evaluates both splits with ``validate``. All five
    results per split are logged to the TensorBoard ``writer``, and both mean
    losses are printed. Only the validation call receives ``epoch``.
    """
    train_tags = ("train/Loss", "train/Accuracy", "train/F1 Score", "train/AUPRC", "train/AUROC")
    validation_tags = (
        "validation/Loss",
        "validation/Accuracy",
        "validation/F1 Score",
        "validation/AUPRC",
        "validation/AUROC",
    )

    for epoch in range(n_epochs):
        grading_model.train()
        segmentation_model.eval()
        for input_batch, target_batch in train_dataloader:
            _train_step(grading_model, input_batch, target_batch, optimizer, criterion)

        # Return cached allocator blocks before the memory-heavy evaluation passes.
        torch.cuda.empty_cache()

        train_results = validate(grading_model, train_dataloader, criterion)
        validation_results = validate(grading_model, validation_dataloader, criterion, epoch)

        for tag, value in zip(train_tags + validation_tags, train_results + validation_results):
            writer.add_scalar(tag, value, epoch)

        print(f"Epoch: {epoch}, Mean training loss: {train_results[0]}, Mean validation loss: {validation_results[0]}")

In [None]:
# Second phase: 100 epochs using masks from the segmentation model.
train(grading_model, train_dataloader, validation_dataloader, optimizer, criterion, n_epochs=100)

In [None]:
# Persist only the learned parameters (state_dict), not the module object.
torch.save(obj=grading_model.state_dict(), f="grading_model_2.pth")