In [None]:
import os

import torch
from torch import nn
from torch.utils.data import DataLoader, random_split
from torch.utils.tensorboard import SummaryWriter
from torcheval.metrics import MulticlassAccuracy, MulticlassAUPRC, MulticlassAUROC, MulticlassF1Score

from dataset import GradingDataset
from grading_model import GradingModel

In [None]:
BATCH_SIZE = 8

# Seed PyTorch's global RNG once, up front, so later draws from the default
# generator (DataLoader shuffling and, presumably, GradingModel's weight init)
# are repeatable across kernel restarts.
SEED = 42
torch.manual_seed(SEED);

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# Root of the IDRiD "Disease Grading" data. Set the IDRID_GRADING_ROOT
# environment variable to point at a different copy instead of editing this cell.
DATA_ROOT = os.environ.get(
    "IDRID_GRADING_ROOT", "/home/wilk/diabetic_retinopathy/datasets/grading"
)

train_images_dir = os.path.join(DATA_ROOT, "train_set", "images")
train_labels_csv = os.path.join(
    DATA_ROOT, "train_set", "ground_truths", "a. IDRiD_Disease Grading_Training Labels.csv"
)

test_images_dir = os.path.join(DATA_ROOT, "test_set", "images")
test_labels_csv = os.path.join(
    DATA_ROOT, "test_set", "ground_truths", "b. IDRiD_Disease Grading_Testing Labels.csv"
)

In [None]:
# Hold out 20% of the official training set for validation; the official test
# set stays untouched. The split gets its own seeded generator so re-running
# this cell always yields the same partition, regardless of how much of the
# global RNG other cells have consumed.
SPLIT_SEED = 42
full_train_dataset = GradingDataset(train_images_dir, train_labels_csv)

train_dataset, validation_dataset = random_split(
    full_train_dataset,
    [0.8, 0.2],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
test_dataset = GradingDataset(test_images_dir, test_labels_csv)

In [None]:
# Only training batches are shuffled. Evaluation metrics are aggregated over
# the whole split, so validation/test order does not matter and stays fixed.
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
validation_dataloader = DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=False)
# Bind the loader to its own name so the underlying test Dataset is not overwritten.
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Build the grading network and place its parameters on the selected device.
# Assumes GradingModel is a torch nn.Module, whose .to() moves it in place and
# returns the module itself.
grading_model = GradingModel().to(device)
grading_model

In [None]:
LEARNING_RATE = 1e-4  # Adam step size for all model parameters

optimizer = torch.optim.Adam(params=grading_model.parameters(), lr=LEARNING_RATE)

In [None]:
# Multiclass loss; CrossEntropyLoss applies log-softmax itself, so it is fed
# the model's raw logits.
criterion = nn.CrossEntropyLoss()

In [None]:
# TensorBoard event files for this experiment are written under this directory.
RUN_LOG_DIR = "runs/grading_experiment_1"
writer = SummaryWriter(log_dir=RUN_LOG_DIR)

In [None]:
def validate(grading_model, validation_dataloader, criterion, num_classes=5, device=None):
    """Evaluate ``grading_model`` on every batch of ``validation_dataloader``.

    The model is put in eval mode and run without gradient tracking. The
    per-batch ``criterion`` losses are averaged, and multiclass metrics are
    computed from the softmax probabilities of all samples.

    Args:
        grading_model: module whose forward returns ``(logits, f_high)``;
            only the logits are used here.
        validation_dataloader: iterable of ``(input_batch, target_batch)``
            pairs. It must yield at least one batch.
        criterion: loss called as ``criterion(logits, targets)``.
        num_classes: number of classes passed to the torcheval metrics.
        device: where batches are moved before the forward pass. Defaults to
            the device of the model's first parameter (CPU if it has none).

    Returns:
        ``(mean_loss, accuracy, f1, auprc, auroc)``. ``mean_loss`` is a float.
        The four metrics are the values returned by each torcheval metric's
        ``compute()``.

    Raises:
        ValueError: if ``validation_dataloader`` yields no batches.
    """
    if device is None:
        # Inputs must live where the model's weights live.
        first_param = next(grading_model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    total_loss = 0.0
    n_batches = 0
    probability_batches = []
    target_batches = []

    grading_model.eval()
    with torch.no_grad():
        for input_batch, target_batch in validation_dataloader:
            input_batch = input_batch.to(device)
            target_batch = target_batch.to(device)

            logits, _ = grading_model(input_batch)
            total_loss += criterion(logits, target_batch).item()
            n_batches += 1

            # Keep whole batch tensors and concatenate later; squeezing would
            # drop the batch dimension when a batch holds a single sample.
            probability_batches.append(torch.softmax(logits, dim=-1).cpu())
            target_batches.append(target_batch.reshape(-1).cpu())

    if n_batches == 0:
        raise ValueError("validation_dataloader yielded no batches")

    mean_validation_loss = total_loss / n_batches
    predicted_values = torch.cat(probability_batches)
    targets = torch.cat(target_batches)

    # One fresh metric object per score; each score comes from its own compute().
    scores = []
    for metric_cls in (MulticlassAccuracy, MulticlassF1Score, MulticlassAUPRC, MulticlassAUROC):
        metric = metric_cls(num_classes=num_classes)
        metric.update(predicted_values, targets)
        scores.append(metric.compute())
    accuracy_score, f1_score, auprc_score, auroc_score = scores

    return mean_validation_loss, accuracy_score, f1_score, auprc_score, auroc_score

In [None]:
def train(grading_model, train_dataloader, validation_dataloader, optimizer, criterion, n_epochs):
    """Optimise ``grading_model`` for ``n_epochs`` epochs and log per-epoch metrics.

    Each epoch does one pass of gradient updates over ``train_dataloader``.
    It then calls ``validate`` on both the training and the validation loader.
    Results go to the notebook-level TensorBoard ``writer`` under the
    ``train/`` and ``validation/`` prefixes. Batches are moved to the
    notebook-level ``device``.

    Args:
        grading_model: module whose forward returns ``(logits, f_high)``;
            only the logits feed the loss.
        train_dataloader: yields ``(input_batch, target_batch)`` pairs used for
            the gradient updates.
        validation_dataloader: yields held-out ``(input_batch, target_batch)`` pairs.
        optimizer: optimiser over ``grading_model``'s parameters.
        criterion: loss called as ``criterion(logits, targets)``.
        n_epochs: number of passes over ``train_dataloader``.
    """
    # Order matches the tuple returned by validate().
    metric_names = ("Loss", "Accuracy", "F1 Score", "AUPRC", "AUROC")

    for epoch in range(n_epochs):
        grading_model.train()
        running_loss = 0.0
        n_batches = 0
        for input_batch, target_batch in train_dataloader:
            optimizer.zero_grad()

            input_batch = input_batch.to(device)
            target_batch = target_batch.to(device)

            logits, _ = grading_model(input_batch)

            loss = criterion(logits, target_batch)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            n_batches += 1

        # Mean loss observed during the updates (train mode, weights still
        # changing). Logged as its own curve so it is not lost.
        mean_running_loss = running_loss / max(n_batches, 1)

        # Re-evaluate both splits in eval mode after the epoch so the train and
        # validation numbers are directly comparable.
        train_results = validate(grading_model, train_dataloader, criterion)
        validation_results = validate(grading_model, validation_dataloader, criterion)

        for split, results in (("train", train_results), ("validation", validation_results)):
            for metric_name, value in zip(metric_names, results):
                writer.add_scalar(f"{split}/{metric_name}", value, epoch)
        writer.add_scalar("train/Running Loss", mean_running_loss, epoch)

        mean_training_loss = train_results[0]
        mean_validation_loss = validation_results[0]
        print(f"Epoch: {epoch}, Mean training loss: {mean_training_loss}, Mean validation loss: {mean_validation_loss}")

In [None]:
N_EPOCHS = 100

train(grading_model, train_dataloader, validation_dataloader, optimizer, criterion, N_EPOCHS)

# Flush any buffered TensorBoard events to disk once training has finished.
writer.close()