In [16]:
import os
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.model_selection import KFold
from torch.utils.tensorboard import SummaryWriter

from dataset import DRSegmentationDataset
from unet import UNet

In [None]:
# Number of samples per DataLoader batch, shared by the train and validation loaders.
BATCH_SIZE: int = 4

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Using device:", device)

In [None]:
# Root of the preprocessed segmentation data. Set the DR_DATA_DIR environment
# variable to point elsewhere instead of editing a user-specific absolute path;
# the default keeps the original location.
DATA_DIR = Path(os.environ.get(
    "DR_DATA_DIR",
    "/home/wilk/diabetic_retinopathy/datasets/processed_segmentation_dataset",
))
# NOTE(review): test_dataset is not used by the training cells below —
# presumably reserved for a final held-out evaluation; confirm.
test_dataset = DRSegmentationDataset(str(DATA_DIR / "test_set"))
train_dataset = DRSegmentationDataset(str(DATA_DIR / "train_set"))

In [None]:
# All TensorBoard scalars of this experiment go into one run directory.
LOG_DIR = "runs/Adam_original"
writer = SummaryWriter(log_dir=LOG_DIR)

# Create training loop

In [None]:
def reset_weights(m):
    """Re-initialise the parameters of every direct child of ``m``.

    Only the immediate children of ``m`` are visited, not ``m`` itself and not
    deeper descendants. Children without a ``reset_parameters`` method are
    skipped. The notebook calls it through ``model.apply(reset_weights)`` so
    that each fold starts from freshly initialised weights instead of weights
    leaked from an earlier fold.
    """
    resettable = (child for child in m.children()
                  if hasattr(child, "reset_parameters"))
    for child in resettable:
        child.reset_parameters()

In [None]:
def _run_epoch(model, loader, loss_fn, optimizer=None):
    """Make one pass over ``loader`` and return the per-sample mean loss.

    If ``optimizer`` is given, the model is put in training mode and updated
    after every batch. Otherwise the model is put in evaluation mode and the
    pass runs with gradients disabled.

    Args:
        model: network to run (already moved to ``device``).
        loader: DataLoader whose batches hold inputs at index 0 and targets
            at index 1.
        loss_fn: criterion called as ``loss_fn(output, target)``; assumed to
            use mean reduction, as ``torch.nn.BCELoss`` does by default.
        optimizer: optimizer to step, or ``None`` for a validation pass.

    Returns:
        Each batch's mean loss weighted by its batch size and divided by the
        total number of samples, so a smaller final batch does not skew it.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    is_training = optimizer is not None
    model.train(is_training)

    total_loss = 0.0
    n_samples = 0
    with torch.set_grad_enabled(is_training):
        for batch in loader:
            input_tensor = batch[0].to(device)
            target_tensor = batch[1].to(device)

            output = model(input_tensor)
            loss_value = loss_fn(output, target_tensor)

            if is_training:
                optimizer.zero_grad()
                loss_value.backward()
                optimizer.step()

            batch_size = input_tensor.size(0)
            total_loss += loss_value.item() * batch_size
            n_samples += batch_size

    if n_samples == 0:
        raise ValueError("loader yielded no samples")
    return total_loss / n_samples


def train(train_dataset, epochs, k_folds=5, lr=0.0002, betas=(0.5, 0.5),
          random_state=None):
    """Cross-validate a freshly initialised UNet on ``train_dataset``.

    For each of ``k_folds`` folds, a new ``UNet(in_channels=3, out_channels=5)``
    is trained with Adam and ``BCELoss`` for ``epochs`` epochs. Per-epoch mean
    train and validation losses are printed and logged to the module-level
    TensorBoard ``writer`` under ``Fold{fold}/Loss/Train`` and
    ``Fold{fold}/Loss/Val``. A fold's score is its validation loss after the
    last epoch; the mean of those scores is printed, logged and returned.

    Args:
        train_dataset: dataset split into train/validation folds by index.
        epochs: number of training epochs per fold. If 0, fold scores are NaN.
        k_folds: number of folds (``KFold`` requires at least 2).
        lr: Adam learning rate.
        betas: Adam ``(beta1, beta2)`` coefficients.
        random_state: seed for the fold shuffle. ``None`` gives a different
            split on every call; an int makes the split reproducible.

    Returns:
        The mean over folds of the final-epoch validation loss.
    """
    kfold = KFold(n_splits=k_folds, shuffle=True, random_state=random_state)
    fold_results = {}

    for fold, (train_ids, val_ids) in enumerate(kfold.split(train_dataset)):
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=BATCH_SIZE,
            sampler=torch.utils.data.SubsetRandomSampler(train_ids))
        val_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=BATCH_SIZE,
            sampler=torch.utils.data.SubsetRandomSampler(val_ids))

        # A brand-new model per fold; reset_weights is an extra guard against
        # weights carrying over between folds.
        model = UNet(in_channels=3, out_channels=5)
        model.to(device)
        model.apply(reset_weights)

        # NOTE(review): beta2 = 0.5 is far from Adam's customary 0.999 and
        # makes the second-moment estimate very noisy. It is kept as the
        # default so existing runs stay comparable; confirm it is intentional.
        optimizer = torch.optim.Adam(model.parameters(), lr=lr,
                                     betas=tuple(betas))

        # NOTE(review): BCELoss expects probabilities in [0, 1]. This assumes
        # UNet ends with a sigmoid; if it returns logits, use
        # BCEWithLogitsLoss instead. TODO confirm.
        loss_fn = torch.nn.BCELoss()

        # Defined up front so that epochs == 0 does not leave it unbound.
        validation_epoch_loss = float("nan")
        for epoch in range(epochs):
            training_epoch_loss = _run_epoch(model, train_loader, loss_fn,
                                             optimizer)
            validation_epoch_loss = _run_epoch(model, val_loader, loss_fn)

            writer.add_scalar(f'Fold{fold}/Loss/Train', training_epoch_loss, epoch)
            writer.add_scalar(f'Fold{fold}/Loss/Val', validation_epoch_loss, epoch)

            print(f"Fold: {fold}, Epoch: {epoch}, Mean training loss: {training_epoch_loss}, Mean validation loss: {validation_epoch_loss}")

        # The fold's score is its validation loss after the last epoch.
        fold_results[fold] = validation_epoch_loss

    average_validation_result = np.mean(list(fold_results.values()))
    print("Average validation result:", average_validation_result)

    writer.add_scalar("Average validation result:", average_validation_result)
    # Push the summary to disk now instead of waiting for the periodic flush.
    writer.flush()
    return average_validation_result

In [None]:
# Run 5-fold cross-validation, training every fold for 200 epochs.
train(train_dataset, epochs=200, k_folds=5)