In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import pickle
import random

from monai.data import Dataset, list_data_collate # , decollate_batch
# from monai.handlers.utils import from_engine
from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric
from monai.networks.nets import UNet

from monai.transforms import (
    LoadImage,
    Activations,
    AddChannel,
    AsDiscrete,
    Compose,
    LoadImage,
    RandRotate,
    RandSpatialCrop,
    ScaleIntensity,
    AsChannelFirst,
    AsChannelLast, 
    RandFlip,
    ToTensor,
    Resize
    # EnsureType,
)
from monai.visualize import plot_2d_or_3d_image
from monai.data import ArrayDataset, create_test_image_2d # , decollate_batch
from torchvision.transforms import Lambda

from monai.utils import set_determinism
from monai.utils.misc import first
import torch

In [None]:
from monai.metrics import DiceMetric

In [None]:
# Run parameters: the first `nr_train_samples` image files are used for
# training and the following `nr_val_samples` files for validation.
nr_train_samples = 500
nr_val_samples = 100

# Seed the global RNGs (model initialisation, DataLoader shuffling, ...) so
# runs are repeatable. `set_determinism` comes from monai.utils (imports cell).
# NOTE(review): per its docs it may also force deterministic cuDNN, which can
# slow training somewhat.
SEED = 0
set_determinism(seed=SEED)

In [None]:
# Resized REFUGE-2 training data (non-glaucoma subset): input images and their
# ground-truth files live in sibling directories under a common root.
data_root = "/kvh4/optic_disc/data/REFUGE-2/REFUGE2-Training/resized_data/non_glaucoma"
image_dir = os.path.join(data_root, "images")
gt_dir = os.path.join(data_root, "ground_truth")

# os.listdir returns entries in arbitrary, filesystem-dependent order; sorting
# makes the positional train/validation split below reproducible.
images = sorted(os.listdir(image_dir))
image_paths = [os.path.join(image_dir, name) for name in images]

# The ground truth for "<stem>.<ext>" is "<stem>.bmp"; splitext handles
# extensions of any length.
gt_paths = [os.path.join(gt_dir, os.path.splitext(name)[0] + ".bmp") for name in images]

In [None]:
# Contiguous split by file position: the first `nr_train_samples` files train
# the model, the next `nr_val_samples` files are held out for validation.
training_images, training_gt = (
    image_paths[:nr_train_samples],
    gt_paths[:nr_train_samples],
)
validation_images, validation_gt = (
    image_paths[nr_train_samples:nr_train_samples + nr_val_samples],
    gt_paths[nr_train_samples:nr_train_samples + nr_val_samples],
)

### Test the dataloader

In [None]:
# Small fixed subset used only to sanity-check the transforms visually.
test_images, test_gt = image_paths[:10], gt_paths[:10]

In [None]:
# Preview copies of the training transformations (separately for input and gt).
# Both pipelines contain the same random transforms in the same order; per the
# MONAI ArrayDataset docs the two pipelines are seeded identically per sample,
# which keeps the flips/rotation of an image and its gt aligned. Keep the order
# of RandFlip/RandFlip/RandRotate identical in both lists.
# MONAI's RandRotate takes range_x in radians, so 15 degrees is converted here.
test_imtransforms = Compose(
    [
        LoadImage(image_only=True),
        AsChannelFirst(),
        RandFlip(spatial_axis=1, prob=0.5),
        RandFlip(spatial_axis=0, prob=0.5),
        RandRotate(range_x=np.deg2rad(15), prob=0.3, keep_size=True),
        ScaleIntensity(),
        ToTensor(),
    ]
)

test_gttransforms = Compose(
    [
        LoadImage(image_only=True),
        AsChannelFirst(),
        RandFlip(spatial_axis=1, prob=0.5),
        RandFlip(spatial_axis=0, prob=0.5),
        # Nearest-neighbour interpolation keeps label values discrete; the
        # default bilinear mode blends neighbouring label values at boundaries.
        RandRotate(range_x=np.deg2rad(15), prob=0.3, keep_size=True, mode="nearest"),
        ToTensor(),
        # TODO(review): the gt is not one-hot encoded here. Earlier drafts used
        # Lambda(lambda x: torch.cat([x == 255, x == 0, x == 128], 0)); confirm
        # the ground-truth encoding expected by the 3-channel model and loss.
    ]
)

In [None]:
# Pair each preview image with its gt; batches come out in file order
# because shuffling is left at the DataLoader default (off).
test_ds = ArrayDataset(
    test_images, test_imtransforms,
    test_gt, test_gttransforms,
)
test_loader = torch.utils.data.DataLoader(
    test_ds,
    batch_size=6,
)

In [None]:
# Fetch one batch from the preview loader; later cells use check_data[0] as the
# image batch and check_data[1] as the ground-truth batch.
check_data = first(test_loader)

In [None]:
# Distinct values in the first ground-truth sample of the preview batch. The
# commented-out one-hot Lambda in the gt transforms assumes the values 0, 128, 255.
np.unique(check_data[1][0].numpy())

In [None]:
# Visual sanity check: one row per sample showing the input image and two
# ground-truth channels. Assumes each gt tensor has at least 3 channels.
n_preview = min(6, len(check_data[0]))
fig, axes = plt.subplots(n_preview, 3, figsize=(12, 20 * n_preview / 6), squeeze=False)
for row in range(n_preview):
    # Channel-first -> channel-last so imshow renders the first three channels as RGB.
    axes[row, 0].imshow(check_data[0][row][:3].permute(1, 2, 0).numpy())
    axes[row, 0].set_title("image")
    axes[row, 1].imshow(check_data[1][row][1])
    axes[row, 1].set_title("gt channel 1")
    axes[row, 2].imshow(check_data[1][row][2])
    axes[row, 2].set_title("gt channel 2")
fig.tight_layout()
plt.show()

## Training and validation datasets

In [None]:
# Training transformations (separately for input and gt).
# Both training pipelines contain the same random transforms in the same order;
# per the MONAI ArrayDataset docs the two pipelines are seeded identically per
# sample, which keeps the augmentation of an image and its gt aligned. Keep the
# order of RandFlip/RandFlip/RandRotate identical in both lists.
# MONAI's RandRotate takes range_x in radians, so 15 degrees is converted here.
training_imtransforms = Compose(
    [
        LoadImage(image_only=True),
        AsChannelFirst(),
        RandFlip(spatial_axis=1, prob=0.5),
        RandFlip(spatial_axis=0, prob=0.5),
        RandRotate(range_x=np.deg2rad(15), prob=0.3, keep_size=True),
        ScaleIntensity(),
        ToTensor(),
    ]
)

training_gttransforms = Compose(
    [
        LoadImage(image_only=True),
        AsChannelFirst(),
        RandFlip(spatial_axis=1, prob=0.5),
        RandFlip(spatial_axis=0, prob=0.5),
        # Nearest-neighbour interpolation keeps label values discrete; the
        # default bilinear mode blends neighbouring label values at boundaries.
        RandRotate(range_x=np.deg2rad(15), prob=0.3, keep_size=True, mode="nearest"),
        ToTensor(),
        # TODO(review): the gt is not one-hot encoded here. Earlier drafts used
        # Lambda(lambda x: torch.cat([x == 255, x == 0, x == 128], 0)); confirm
        # the ground-truth encoding expected by the 3-channel model and loss.
    ]
)


# Validation transformations: deterministic (no augmentation).
validation_imtransforms = Compose(
    [
        LoadImage(image_only=True),
        AsChannelFirst(),
        ScaleIntensity(),
        ToTensor(),
    ]
)

validation_gttransforms = Compose(
    [
        LoadImage(image_only=True),
        AsChannelFirst(),
        ToTensor(),
        # TODO(review): same open question about one-hot encoding as the training gt.
    ]
)

In [None]:
# Training batches are reshuffled every epoch.
training_ds = ArrayDataset(
    training_images, training_imtransforms,
    training_gt, training_gttransforms,
)
training_loader = torch.utils.data.DataLoader(
    training_ds,
    batch_size=6,
    shuffle=True,
)

# Validation runs one sample at a time in fixed order; the validation loop
# indexes the sliding-window output with [0], which relies on batch_size=1.
validation_ds = ArrayDataset(
    validation_images, validation_imtransforms,
    validation_gt, validation_gttransforms,
)
validation_loader = torch.utils.data.DataLoader(
    validation_ds,
    batch_size=1,
    shuffle=False,
)

In [None]:
# Number of training images selected by the split.
len(training_images)

In [None]:
# Every training image needs exactly one ground-truth path.
assert len(training_gt) == len(training_images), "training image/gt count mismatch"
len(training_gt)

## Set up training

In [None]:
# Use GPU index 3 when the host has it; otherwise fall back to CPU so the
# notebook still runs on machines with fewer (or no) CUDA devices.
device = torch.device("cuda:3" if torch.cuda.device_count() > 3 else "cpu")

In [None]:
# Number of training epochs, and the directory that receives the per-epoch
# checkpoints and the saved loss/metric histories.
epoch_num = 100

model_dir = "/kvh4/optic_disc/models/01_UNet"
# exist_ok=True avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs(model_dir, exist_ok=True)

In [None]:
# 2D residual U-Net with 3 input and 3 output channels.
# NOTE(review): out_channels=3 presumably means one channel per class; confirm
# against the ground-truth encoding.
# MONAI's UNet consumes len(channels) - 1 strides. A fourth stride was silently
# ignored (newer MONAI versions warn about it), so only three are passed; the
# resulting network is unchanged.
model = UNet(
    dimensions=2,
    in_channels=3,
    out_channels=3,
    channels=(32, 64, 128, 256),
    strides=(2, 2, 2),
    num_res_units=2,
).to(device)

In [None]:
# Dice loss on the raw model outputs; DiceLoss applies the softmax itself.
loss_function = DiceLoss(softmax=True)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Validation metric: Dice including the background channel, mean-reduced.
dice_metric = DiceMetric(include_background=True, reduction="mean")

# Post-processing of validation predictions.
post_trans_1 = Compose([AddChannel(), Activations(softmax=True)])
# NOTE(review): post_trans_2 is not referenced by any other cell in this notebook.
post_trans_2 = Compose([Activations(softmax=True), AsDiscrete(threshold_values=True)])

In [None]:
# Train for `epoch_num` epochs, saving a checkpoint every epoch and validating
# every `val_interval` epochs. The loss and validation-metric histories are
# written to `model_dir` as .npy files once training finishes.
val_interval = 1
roi_size = (96, 96)  # sliding-window patch size for validation inference
sw_batch_size = 4    # number of windows evaluated per forward pass
epoch_loss_values = []
metric_values = []
for epoch in range(epoch_num):
    print("-" * 10)
    print(f"epoch {epoch + 1}/{epoch_num}")
    model.train()
    epoch_loss = 0.0
    step = 0
    for batch_data in training_loader:
        step += 1
        inputs, labels = batch_data[0].to(device), batch_data[1].to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    # max(step, 1) keeps the average defined if the loader yields no batches.
    epoch_loss /= max(step, 1)
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

    savepath = os.path.join(model_dir, "epoch_" + str(epoch + 1) + ".pth")
    print("savepath: ", savepath)
    torch.save(model.state_dict(), savepath)
    print("saved model")

    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.no_grad():
            # These stay bound after the loop so later cells can inspect the
            # last validation sample.
            val_images = None
            val_labels = None
            val_outputs = None
            metric_sum = 0.0
            metric_count = 0
            for val_data in validation_loader:
                val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
                val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
                # validation_loader uses batch_size=1, so [0] is the only sample;
                # post_trans_1 adds a leading dim back and applies softmax.
                val_outputs = post_trans_1(val_outputs[0])
                # Compute the metric once per sample.
                value = dice_metric(y_pred=val_outputs, y=val_labels)
                # NOTE(review): the return value of DiceMetric.__call__ differs between
                # MONAI versions (tensor vs. (metric, not_nans) tuple, scalar vs. per-
                # channel); confirm len(value) / value.item() match the installed version.
                metric_count += len(value)
                metric_sum += value.item() * len(value)
            metric = metric_sum / metric_count if metric_count else float("nan")
            metric_values.append(metric)
        print(f"current epoch: {epoch + 1} current mean dice: {metric:.4f}")

np.save(os.path.join(model_dir, "epoch_loss.npy"), epoch_loss_values)
np.save(os.path.join(model_dir, "val_metrics.npy"), metric_values)


In [None]:
# Shape of the first element of the last validation output left by the training cell.
val_outputs[0].shape

In [None]:
# Size of the leading dimension of the last validation label tensor.
len(val_labels)

In [None]:
# NOTE(review): redefines post_trans_1 identically to the loss/metric setup cell;
# this duplicate definition can likely be removed.
post_trans_1 = Compose([ AddChannel(), Activations(softmax=True)])



In [None]:
# Inspect the validation output left by the training cell. post_trans_1 was
# already applied inside the validation loop, so it is not re-applied here:
# doing so (a second softmax plus another leading dim) and rebinding
# val_outputs made this cell non-idempotent across re-runs.
val_outputs.shape

In [None]:
# Dice for the last validation prediction/label pair (displayed, not stored).
dice_metric(y_pred=val_outputs, y=val_labels)

In [None]:
# Prediction shape passed to dice_metric.
val_outputs.shape

In [None]:
# Label shape passed to dice_metric, for comparison with val_outputs.shape.
val_labels.shape

In [None]:
# Same Dice call as above, kept in `value` so it can be inspected in the next cell.
value = dice_metric(y_pred=val_outputs, y=val_labels)

In [None]:
# .item() only succeeds when `value` is a tensor holding exactly one element.
value.item()