In [None]:
import os

import torch
from torch import nn
from torch.utils.data import DataLoader, random_split

from dataset import GradingDataset
from grading_model import GradingModel

In [None]:
# Number of samples per mini-batch; passed to every DataLoader in this notebook.
BATCH_SIZE: int = 8

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# Root of the IDRiD disease-grading data. The default is a machine-specific absolute
# path; set the GRADING_DATA_DIR environment variable to run this notebook elsewhere.
GRADING_DATA_DIR = os.environ.get(
    "GRADING_DATA_DIR", "/home/wilk/diabetic_retinopathy/datasets/grading"
)

train_images_dir = f"{GRADING_DATA_DIR}/train_set/images"
train_labels_csv = f"{GRADING_DATA_DIR}/train_set/ground_truths/a. IDRiD_Disease Grading_Training Labels.csv"

test_images_dir = f"{GRADING_DATA_DIR}/test_set/images"
test_labels_csv = f"{GRADING_DATA_DIR}/test_set/ground_truths/b. IDRiD_Disease Grading_Testing Labels.csv"

In [None]:
# Seed for the train/validation split so the same images land in each subset on
# every Restart & Run All (an unseeded split changes between runs).
SPLIT_SEED = 42

# Keep the full training set under its own name instead of rebinding `train_dataset`
# to the split subset below.
full_train_dataset = GradingDataset(train_images_dir, train_labels_csv)

# Hold out 20% of the labelled training images for validation.
train_dataset, validation_dataset = random_split(
    full_train_dataset,
    [0.8, 0.2],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
test_dataset = GradingDataset(test_images_dir, test_labels_csv)

In [None]:
# Only the training loader is shuffled. Evaluation order does not matter, and a fixed
# order keeps the batch-averaged validation/test losses reproducible.
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
validation_dataloader = DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=False)
# Bound to its own name: assigning this loader back to `test_dataset` would clobber the
# dataset, and re-running the cell would wrap a DataLoader inside a DataLoader.
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Build the grading network and move its weights to the selected device
# (nn.Module.to returns the module itself, so the name still refers to the model).
grading_model = GradingModel().to(device)
grading_model

In [None]:
# Adam over every trainable parameter of the grading model.
LEARNING_RATE = 1e-4
optimizer = torch.optim.Adam(params=grading_model.parameters(), lr=LEARNING_RATE)

In [None]:
# CrossEntropyLoss applies log-softmax internally, so it is fed the model's raw logits.
criterion = nn.CrossEntropyLoss()

In [None]:
def validate(grading_model, validation_dataloader, criterion, device=None):
    """Return the mean per-batch loss of ``grading_model`` on a held-out loader.

    The model is put in eval mode (and left in it on return). No autograd graph
    is recorded.

    Args:
        grading_model: module whose forward returns a 2-tuple ``(logits, f_high)``;
            only ``logits`` is used here.
        validation_dataloader: iterable of ``(input_batch, target_batch)`` pairs.
            This is what gets evaluated (the previous version mistakenly iterated
            the global ``train_dataloader``).
        criterion: loss callable invoked as ``criterion(logits, target_batch)``.
        device: where to move each batch. Defaults to the device of the model's
            first parameter, or ``"cpu"`` if the model has no parameters.

    Returns:
        The loss averaged over the number of batches seen, as a Python float.

    Raises:
        ValueError: if ``validation_dataloader`` yields no batches.
    """
    if device is None:
        # Follow the model rather than a notebook global, so the function works
        # stand-alone and batches always land next to the weights.
        first_param = next(grading_model.parameters(), None)
        device = first_param.device if first_param is not None else "cpu"

    grading_model.eval()
    validation_loss = 0.0
    n_batches = 0
    # No backward pass happens here, so skip recording the autograd graph.
    with torch.no_grad():
        for input_batch, target_batch in validation_dataloader:
            input_batch = input_batch.to(device)
            target_batch = target_batch.to(device)

            logits, _ = grading_model(input_batch)

            validation_loss += criterion(logits, target_batch).item()
            n_batches += 1

    if n_batches == 0:
        raise ValueError("validation_dataloader yielded no batches")
    return validation_loss / n_batches

In [None]:
def train(grading_model, train_dataloader, validation_dataloader, optimizer, criterion, n_epochs):
    """Fit ``grading_model`` for ``n_epochs`` passes over ``train_dataloader``.

    After each epoch, prints the batch-averaged training loss and the value
    returned by ``validate`` on ``validation_dataloader``.

    Args:
        grading_model: module whose forward returns a 2-tuple; the first element
            is fed to ``criterion``.
        train_dataloader: iterable of ``(input_batch, target_batch)`` pairs with a
            ``len()`` (used to average the epoch loss).
        validation_dataloader: loader passed through to ``validate``.
        optimizer: optimizer over the model's parameters.
        criterion: loss callable invoked as ``criterion(logits, targets)``.
        n_epochs: number of full passes over ``train_dataloader``.

    Note:
        Batches are moved to the notebook-global ``device``.
    """
    for epoch in range(n_epochs):
        grading_model.train()
        running_loss = 0

        for inputs, targets in train_dataloader:
            optimizer.zero_grad()
            inputs, targets = inputs.to(device), targets.to(device)

            logits, _ = grading_model(inputs)
            batch_loss = criterion(logits, targets)
            batch_loss.backward()
            optimizer.step()

            running_loss += batch_loss.item()

        mean_training_loss = running_loss / len(train_dataloader)
        mean_validation_loss = validate(grading_model, validation_dataloader, criterion)

        print(f"Epoch: {epoch}, Mean training loss: {mean_training_loss}, Mean validation loss: {mean_validation_loss}")

In [None]:
# Number of full passes over the training split.
N_EPOCHS = 100
train(
    grading_model,
    train_dataloader,
    validation_dataloader,
    optimizer,
    criterion,
    n_epochs=N_EPOCHS,
)