# Import Packages

In [None]:
import logging
import os
import sys
import tempfile
from glob import glob
import time

import torch
from PIL import Image
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

import monai
from monai.data import create_test_image_2d, list_data_collate, decollate_batch
from monai.inferers import sliding_window_inference, SimpleInferer
from monai.metrics import DiceMetric
from monai.transforms import (
    Activations,
    AddChanneld,
    AsDiscrete,
    Compose,
    LoadImaged,
    RandCropByPosNegLabeld,
    RandRotate90d,
    ScaleIntensityd,
    EnsureTyped,
    EnsureType,
    AsChannelFirstd,
    AsChannelLast,
    Resized,
    RandScaleCropd,
    RandRotated,
    SaveImage,
)
from monai.visualize import plot_2d_or_3d_image

# Check MONAI configurations

In [None]:
# Print MONAI's configuration report, then route INFO-level (and higher) log
# records to stdout so they show up in the cell output.
monai.config.print_config()
logging.basicConfig(
    level=logging.INFO,
    stream=sys.stdout,
)


# Process VGH Data

In [None]:
# Set the Data folder.
# The location can be overridden with the SEG_DATA_PATH environment variable so
# the notebook is usable on other machines; when it is unset (or empty) the
# original hard-coded directory is used.
data_path = os.environ.get("SEG_DATA_PATH") or (
    "C:\\Users\\alzoo\\Desktop\\monai\\SEG_Train_Datasets\\SEG_Train_Datasets\\"
)
# Later cells build sub-folder paths by plain string concatenation
# (data_path + "Train_Images\\"), so the value must end with a separator.
if not data_path.endswith(("\\", "/")):
    data_path += os.sep

## -obtain train data list

In [None]:
# Load train files: collect the training images (*.jpg) and their masks (*.png).
tempdir = data_path + "Train_Images\\"
train_images = sorted(glob(os.path.join(tempdir, "*.jpg")))

tempdir = data_path + "msk_img\\"
train_segs = sorted(glob(os.path.join(tempdir, "*.png")))
print(f" {len(train_images)} train_images and {len(train_segs)} train_segs")

# Images and masks are paired purely by their position in the sorted lists.
# With unequal counts zip() would silently drop the surplus files and could
# pair images with the wrong masks, so fail loudly instead.
if len(train_images) != len(train_segs):
    raise ValueError(
        f"train image/mask count mismatch: {len(train_images)} images vs {len(train_segs)} masks"
    )
train_files = [{"img": img, "seg": seg} for img, seg in zip(train_images, train_segs)]


## -obtain validation data list

In [None]:
# Load validation files: collect the validation images (*.jpg) and masks (*.png).
tempdir = data_path + "valid_img\\"
valid_images = sorted(glob(os.path.join(tempdir, "*.jpg")))

tempdir = data_path + "valid_msk_img\\"
valid_segs = sorted(glob(os.path.join(tempdir, "*.png")))
print(f" {len(valid_images)} valid_images and {len(valid_segs)} valid_segs")

# Images and masks are paired purely by their position in the sorted lists.
# With unequal counts zip() would silently drop the surplus files and could
# pair images with the wrong masks, so fail loudly instead.
if len(valid_images) != len(valid_segs):
    raise ValueError(
        f"validation image/mask count mismatch: {len(valid_images)} images vs {len(valid_segs)} masks"
    )
val_files = [{"img": img, "seg": seg} for img, seg in zip(valid_images, valid_segs)]


# Define Transform for image and segmentation

In [None]:
# define transforms for image and segmentation (training, with augmentation)
#
# Interpolation modes are given per key, in the order of `keys`: the image keeps
# the library defaults (bilinear rotation, area resize) while the mask uses
# nearest-neighbour so that rotating/resizing never blends foreground and
# background into fractional label values.
train_transforms = Compose(
    [
        LoadImaged(keys=["img", "seg"]),
        # Bring both arrays to channel-first layout: the mask gets a new channel
        # axis, the image has its existing channel axis moved to the front.
        AddChanneld(keys=["seg"]),
        AsChannelFirstd(keys=["img"]),
        ScaleIntensityd(keys=["img", "seg"]),
        RandScaleCropd(keys=["img", "seg"], roi_scale=0.5),
        RandRotated(keys=["img", "seg"], range_x=3.14, mode=["bilinear", "nearest"]),
        Resized(keys=["img", "seg"], spatial_size=[800, 800], mode=["area", "nearest"]),
        #RandCropByPosNegLabeld(
        #    keys=["img", "seg"], label_key="seg", spatial_size=[96, 96], pos=1, neg=1, num_samples=4
        #),
        RandRotate90d(keys=["img", "seg"], prob=0.5, spatial_axes=[0, 1]),
        EnsureTyped(keys=["img", "seg"]),
    ]
)
# Validation pipeline: same loading/scaling as training but without random
# augmentation. The resize mode is given per key, in the order of `keys`: the
# image keeps the default area interpolation, the mask uses nearest-neighbour
# so its values are not blended into fractional labels.
val_transforms = Compose(
    [
        LoadImaged(keys=["img", "seg"]),
        # Bring both arrays to channel-first layout: the mask gets a new channel
        # axis, the image has its existing channel axis moved to the front.
        AddChanneld(keys=["seg"]),
        AsChannelFirstd(keys=["img"]),
        ScaleIntensityd(keys=["img", "seg"]),
        Resized(keys=["img", "seg"], spatial_size=[800, 800], mode=["area", "nearest"]),
        EnsureTyped(keys=["img", "seg"]),
    ]
)

# Check and visualize the transform results

In [None]:
# Dataset that applies the training transforms; used below only to inspect
# what the pipeline produces.
check_ds = monai.data.Dataset(
    transform=train_transforms,
    data=train_files,
)

In [None]:
# Pull a single batch of 8 samples through the training pipeline and print the
# shapes of the collated image and mask tensors.
check_loader = DataLoader(
    check_ds,
    collate_fn=list_data_collate,
    num_workers=12,
    batch_size=8,
)
check_data = monai.utils.misc.first(check_loader)
print(*(check_data[key].shape for key in ("img", "seg")))


import matplotlib.pyplot as plt

# Show every sample of the sanity-check batch as one row: image on the left,
# mask on the right.
# The number of rows follows the batch that was actually returned (at most 8),
# so a dataset with fewer than 8 samples no longer raises IndexError.
n_show = min(8, len(check_data["img"]))
plt.figure("visualize", (16, 8 * n_show))
for i in range(n_show):
    # permute(1, 2, 0): channel-first tensor -> channel-last layout for imshow
    plt.subplot(n_show, 2, 2 * i + 1)
    plt.imshow(check_data["img"][i].permute(1, 2, 0))
    plt.subplot(n_show, 2, 2 * i + 2)
    plt.imshow(check_data["seg"][i].permute(1, 2, 0))

# Create DataLoader for train and validation data

In [None]:
# Training dataset and loader: batches of 8 samples, reshuffled each epoch,
# loaded by 8 worker processes. Memory pinning is enabled only when CUDA is
# available.
train_ds = monai.data.Dataset(
    transform=train_transforms,
    data=train_files,
)
train_loader = DataLoader(
    train_ds,
    collate_fn=list_data_collate,
    pin_memory=torch.cuda.is_available(),
    num_workers=8,
    shuffle=True,
    batch_size=8,
)

# Validation dataset and loader: batches of 4 samples in a fixed order (no
# shuffling), loaded by 4 worker processes.
val_ds = monai.data.Dataset(
    transform=val_transforms,
    data=val_files,
)
val_loader = DataLoader(
    val_ds,
    collate_fn=list_data_collate,
    num_workers=4,
    batch_size=4,
)



# Define metric and post-processing

In [None]:
# Dice metric, mean-reduced, used for validation and testing.
# NOTE(review): the network has a single output channel, so confirm how
# include_background=False is treated for one-channel input in the installed
# MONAI version.
dice_metric = DiceMetric(
    reduction="mean",
    get_not_nans=False,
    include_background=False,
)
# Post-processing chain: type conversion, sigmoid activation, then
# binarisation at 0.5.
post_trans = Compose(
    [
        EnsureType(),
        Activations(sigmoid=True),
        AsDiscrete(threshold=0.5),
    ]
)

# Build Model

In [None]:
# Build the network, the loss and the optimizer.
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 2D residual UNet: 3 input channels, 1 output channel (a single logit map).
# Lighter alternative for the feature widths: channels=(16, 32, 64, 128, 256).
model = monai.networks.nets.UNet(
    spatial_dims=2,
    in_channels=3,
    out_channels=1,
    channels=(32, 64, 128, 256, 512),
    strides=(2, 2, 2, 2),
    num_res_units=2,
)
model = model.to(device)

# DiceLoss applies the sigmoid itself, so the model outputs stay raw logits.
loss_function = monai.losses.DiceLoss(sigmoid=True)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [None]:
# Report whether CUDA is usable and which PyTorch version is installed.
# The availability check used to be a bare expression that was not the last
# statement of the cell, so its result was discarded; print it explicitly.
print(torch.cuda.is_available())
print(torch.__version__)

# Do you want to load previous model?

In [None]:
#model.load_state_dict(torch.load("Dice_55_best_metric_model_segmentation2d_dict.pth"))

# Create Visualize Function

In [None]:
def visualize(**images):
    """Plot the given images side by side in a single row.

    Each keyword argument becomes one panel. The keyword name, with
    underscores replaced by spaces and title-cased, is used as the panel
    title; the value is handed to ``plt.imshow``. Axis ticks are hidden and
    the figure is displayed with ``plt.show()``.
    """
    panel_count = len(images)
    plt.figure(figsize=(16, 16))
    for position, key in enumerate(images, start=1):
        plt.subplot(1, panel_count, position)
        plt.xticks([])
        plt.yticks([])
        plt.title(key.replace('_', ' ').title())
        plt.imshow(images[key])
    plt.show()

# Define training parameters and Start training

In [None]:
#### start a typical PyTorch training
total_epochs = 20
val_interval = 1
# DiceMetric is a similarity score (higher is better). Start below any
# attainable value so the first validation result becomes the initial best,
# and keep the checkpoint with the HIGHEST validation Dice.
best_metric = -1
best_metric_epoch = -1
epoch_loss_values = list()
metric_values = list()
writer = SummaryWriter()

# Ground-truth masks must only be binarised. Sending them through `post_trans`
# would apply a sigmoid first, which maps every value in [0, 1] to >= 0.5, so
# the following 0.5 threshold could no longer separate background from
# foreground and the Dice score would be computed against a corrupted mask.
post_label = Compose([EnsureType(), AsDiscrete(threshold=0.5)])

# Optimisation steps per epoch. len(train_loader) includes the final partial
# batch, so TensorBoard global steps never overlap between epochs.
epoch_len = len(train_loader)

for epoch in range(total_epochs):
    print("-" * 10)
    print(f"epoch {epoch + 1}/{total_epochs}")
    model.train()
    epoch_loss = 0
    step = 0
    for batch_data in train_loader:
        step += 1
        inputs, labels = batch_data["img"].to(device), batch_data["seg"].to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        # print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
        writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
    # average training loss over the batches of this epoch
    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    local_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
    print(f"{local_time} epoch {epoch + 1} average loss: {epoch_loss:.4f}")

    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.no_grad():
            # Kept outside the loop so the last batch can be logged to
            # TensorBoard after validation finishes.
            val_images = None
            val_labels = None
            val_outputs = None
            # only the first validation batch is displayed inline
            show_val = True
            for val_data in val_loader:
                val_images, val_labels = val_data["img"].to(device), val_data["seg"].to(device)
                roi_size = (800, 800)
                sw_batch_size = 4
                val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)

                if show_val:
                    # val_outputs are still raw logits at this point
                    visualize(
                        image=val_images[0].cpu().permute(1, 2, 0),
                        ground_truth_mask=val_labels[0].cpu().permute(1, 2, 0),
                        predicted_mask=val_outputs[0].cpu().permute(1, 2, 0),
                    )
                show_val = False

                # logits -> sigmoid -> binary mask; labels -> binary mask only
                val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
                val_labels = [post_label(i) for i in decollate_batch(val_labels)]
                # compute metric for current iteration
                dice_metric(y_pred=val_outputs, y=val_labels)
            # aggregate the final mean dice result
            metric = dice_metric.aggregate().item()
            # reset the status for next validation round
            dice_metric.reset()
            metric_values.append(metric)
            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(), "best_metric_model_segmentation2d_dict.pth")
                print("saved new best metric model")
            print(
                "current epoch: {} current val mean dice: {:.4f} best val mean dice: {:.4f} at epoch {}".format(
                    epoch + 1, metric, best_metric, best_metric_epoch
                )
            )
            # the logged value is a Dice score, not a loss
            writer.add_scalar("val_mean_dice", metric, epoch + 1)
            # plot the last model output as GIF image in TensorBoard with the corresponding image and label
            plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag="image")
            plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="label")
            plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="output")


print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
writer.close()


# Save the weights from the last epoch (not necessarily the best ones).
# NOTE(review): the file name says "40_epoches" while total_epochs is 20; the
# name is kept unchanged so anything that loads this file still finds it.
torch.save(model.state_dict(), "Final_model_40_epoches_segmentation2d_dict.pth")

## -obtain testing data list

In [None]:
# Load testing files: collect the test images (*.jpg) and their masks (*.png).
tempdir = data_path + "Test/img/"
test_images = sorted(glob(os.path.join(tempdir, "*.jpg")))

tempdir = data_path + "Test/msk_img/"
test_segs = sorted(glob(os.path.join(tempdir, "*.png")))

print(f" {len(test_images)} test_images and {len(test_segs)} test_segs")

# Images and masks are paired purely by their position in the sorted lists.
# With unequal counts zip() would silently drop the surplus files and could
# pair images with the wrong masks, so fail loudly instead.
if len(test_images) != len(test_segs):
    raise ValueError(
        f"test image/mask count mismatch: {len(test_images)} images vs {len(test_segs)} masks"
    )
test_files = [{"img": img, "seg": seg} for img, seg in zip(test_images, test_segs)]


# Define Transform for image and Segmentation

In [None]:
# Test-time pipeline: load, bring both arrays to channel-first layout and
# rescale intensities. There is no resizing and no random augmentation, so the
# images keep their loaded spatial size.
test_transforms = Compose(
    [
        LoadImaged(keys=("img", "seg")),
        AddChanneld(keys="seg"),
        AsChannelFirstd(keys="img"),
        ScaleIntensityd(keys=("img", "seg")),
        # Resizing is intentionally disabled here:
        # Resized(keys=["img", "seg"], spatial_size=[800, 800]),
        EnsureTyped(keys=("img", "seg")),
    ]
)
test_ds = monai.data.Dataset(
    transform=test_transforms,
    data=test_files,
)

# Save IM (images), GT (ground truths) and PD (predictions) in the ./output/ folder

In [None]:
# One sample per batch at test time.
test_loader = DataLoader(
    test_ds,
    collate_fn=list_data_collate,
    num_workers=4,
    batch_size=1,
)

# Three PNG writers that share ./output (no per-file sub-folders) and differ
# only in the file-name postfix: PD = prediction, GT = ground truth,
# IM = input image. scale=255 writes the data as 8-bit intensities.
saverPD = SaveImage(
    output_dir="./output",
    output_postfix="PD",
    output_ext=".png",
    separate_folder=False,
    scale=255,
)
saverGT = SaveImage(
    output_dir="./output",
    output_postfix="GT",
    output_ext=".png",
    separate_folder=False,
    scale=255,
)
saverIM = SaveImage(
    output_dir="./output",
    output_postfix="IM",
    output_ext=".png",
    separate_folder=False,
    scale=255,
)

# Load another model?

In [None]:
#model.load_state_dict(torch.load("best_metric_model_segmentation2d_dict.pth"))

# Inference on Test data

In [None]:
# Ground-truth masks must only be binarised. Sending them through `post_trans`
# would apply a sigmoid first, which maps every value in [0, 1] to >= 0.5, so
# the following 0.5 threshold could no longer separate background from
# foreground and the Dice score would be computed against a corrupted mask.
post_label = Compose([EnsureType(), AsDiscrete(threshold=0.5)])

# Make sure the network is in inference mode regardless of what ran before
# (e.g. a freshly loaded checkpoint or a training loop that ended in train mode).
model.eval()
with torch.no_grad():
    for test_data in test_loader:
        test_images, test_labels = test_data["img"].to(device), test_data["seg"].to(device)
        # define sliding window size and batch size for windows inference
        roi_size = (800, 800)
        sw_batch_size = 4
        test_outputs = sliding_window_inference(test_images, roi_size, sw_batch_size, model)

        # Input image and ground truth are saved exactly as loaded.
        saverGT(test_labels[0].cpu())
        saverIM(test_images[0].cpu())

        # logits -> sigmoid -> binary mask; labels -> binary mask only
        test_outputs = [post_trans(i) for i in decollate_batch(test_outputs)]
        test_labels = [post_label(i) for i in decollate_batch(test_labels)]

        # Display and save the binarised prediction, i.e. the same mask that is
        # scored below, instead of rounding/clipping the raw logits.
        visualize(
            image=test_images[0].cpu().permute(1, 2, 0),
            ground_truth_mask=test_labels[0].cpu().permute(1, 2, 0),
            predicted_mask=test_outputs[0].cpu().permute(1, 2, 0),
        )
        saverPD(test_outputs[0].cpu())

        # compute metric for current iteration
        dice_metric(y_pred=test_outputs, y=test_labels)
    # aggregate the final mean dice result
    print("evaluation metric:", dice_metric.aggregate().item())
    # reset the status
    dice_metric.reset()