# Import packages

In [None]:
import logging
import os
import sys
import tempfile
from glob import glob

import torch
from PIL import Image
from torch.utils.data import DataLoader

import matplotlib.pyplot as plt

import monai
from monai.data import create_test_image_2d, list_data_collate, decollate_batch
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric
from monai.networks.nets import UNet
from monai.transforms import Activations, AddChanneld, AsDiscrete, Compose, LoadImaged, SaveImage, ScaleIntensityd, EnsureTyped, EnsureType, AsChannelFirstd, Resized

# Check MONAI configurations

In [None]:
# Print MONAI's configuration report, then route INFO-level (and above) log
# records to stdout so they appear in the cell output.
monai.config.print_config()
logging.basicConfig(level=logging.INFO, stream=sys.stdout)

# Process VGH data

In [None]:
# Root folder of the VGH segmentation dataset (note the trailing slash).
# The test image and mask sub-folders are resolved relative to this path.
data_path = "/Workspace/data/VGH_Seg_IMG_Label/"

## Obtain the testing data list

In [None]:
# Build the list of test samples: each entry pairs an image path ("img") with
# its mask path ("seg"). Images (*.jpg) and masks (*.png) are matched purely by
# their position in the sorted listings, so both folders are assumed to contain
# corresponding files that sort in the same order.
# The trailing "" keeps a trailing separator on the directory path.
tempdir = os.path.join(data_path, "Test", "img", "")
test_images = sorted(glob(os.path.join(tempdir, "*.jpg")))

tempdir = os.path.join(data_path, "Test", "msk_img", "")
test_segs = sorted(glob(os.path.join(tempdir, "*.png")))

# An empty listing almost always means a wrong data_path; fail here rather
# than much later when the metric is aggregated over zero samples.
if not test_images:
    raise FileNotFoundError(f"No test images (*.jpg) found under {data_path!r}")

# zip() silently truncates to the shorter list, which would drop samples or
# pair images with the wrong masks when a file is missing; fail loudly instead.
if len(test_images) != len(test_segs):
    raise ValueError(
        f"Found {len(test_images)} test images but {len(test_segs)} masks "
        f"under {data_path!r}"
    )

test_files = [{"img": img, "seg": seg} for img, seg in zip(test_images, test_segs)]


# Define transforms for image and segmentation

In [None]:
# Preprocessing pipeline applied to every test sample, where a sample is a
# dict holding an "img" entry and a "seg" entry.
test_transforms = Compose(
    [
        # Read both files from disk.
        LoadImaged(keys=["img", "seg"]),
        # Channel handling differs per entry: the mask goes through
        # AddChanneld, the image through AsChannelFirstd.
        AddChanneld(keys=["seg"]),
        AsChannelFirstd(keys=["img"]),
        # Intensity rescaling is applied to the image and the mask alike.
        ScaleIntensityd(keys=["img", "seg"]),
        # Optional fixed-size resize, currently disabled:
        # Resized(keys=["img", "seg"], spatial_size=[800, 800]),
        EnsureTyped(keys=["img", "seg"]),
    ]
)

# Dataset that runs the pipeline above on each entry of test_files.
test_ds = monai.data.Dataset(transform=test_transforms, data=test_files)

# Create Data Loader, Save Output, Model Architecture

In [None]:
# Sliding-window inference is fed one image per iteration, hence batch_size=1.
test_loader = DataLoader(
    test_ds,
    batch_size=1,
    num_workers=4,
    collate_fn=list_data_collate,
)

# Dice score with the background channel included, reduced with "mean".
dice_metric = DiceMetric(
    include_background=True,
    reduction="mean",
    get_not_nans=False,
)

# Post-processing for network outputs: sigmoid, then binarise at 0.5.
post_trans = Compose(
    [
        EnsureType(),
        Activations(sigmoid=True),
        AsDiscrete(threshold=0.5),
    ]
)

# Writer for predicted masks: PNG files in ./output with the "seg" postfix,
# all in one folder (no per-image sub-folder).
saver = SaveImage(
    output_dir="./output",
    output_ext=".png",
    output_postfix="seg",
    scale=255,
    separate_folder=False,
)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 2D UNet taking 3-channel input and producing a 1-channel output.
# A narrower alternative, channels=(16, 32, 64, 128, 256), was left disabled.
model = UNet(
    spatial_dims=2,
    in_channels=3,
    out_channels=1,
    channels=(32, 64, 128, 256, 512),
    strides=(2, 2, 2, 2),
    num_res_units=2,
).to(device)

# Load previous model

In [None]:
# Restore the trained weights from the checkpoint file.
# map_location remaps the stored tensors onto the device selected for this
# session, so a checkpoint saved from a GPU still loads when the notebook has
# fallen back to the CPU.
model.load_state_dict(
    torch.load("best_metric_model_segmentation2d_dict.pth", map_location=device)
)

In [None]:
def visualize(**images):
    """Plot the given images side by side in a single row.

    Each keyword argument becomes one panel. The keyword name, with
    underscores replaced by spaces and title-cased, is used as the panel
    title; the value is handed to ``plt.imshow``. Axis ticks are hidden.
    """
    panel_count = len(images)
    plt.figure(figsize=(16, 16))
    for position, (label, picture) in enumerate(images.items(), start=1):
        plt.subplot(1, panel_count, position)
        plt.xticks([])
        plt.yticks([])
        plt.title(label.replace('_', ' ').title())
        plt.imshow(picture)
    plt.show()

# Run evaluation on testing data

In [None]:
# Evaluate the model on the test set: run sliding-window inference on each
# image, show image / ground truth / prediction side by side, accumulate the
# Dice score and save every predicted mask through `saver`.
model.eval()

# Sliding-window size and window batch size are the same for every image,
# so they are defined once outside the loop.
roi_size = (800, 800)
sw_batch_size = 4

# Ground-truth masks must only be binarised, never passed through post_trans:
# post_trans applies a sigmoid first, which maps every label value >= 0 to a
# value >= 0.5, so its 0.5 threshold can no longer separate background from
# foreground (AsDiscrete is believed to compare with >=, which would turn the
# whole mask into foreground -- confirm for the installed MONAI version).
label_trans = Compose([EnsureType(), AsDiscrete(threshold=0.5)])

with torch.no_grad():
    for test_data in test_loader:
        test_images, test_labels = test_data["img"].to(device), test_data["seg"].to(device)
        test_outputs = sliding_window_inference(test_images, roi_size, sw_batch_size, model)

        # Post-process before visualising so the displayed prediction is the
        # same binary mask that is scored and saved (rounding raw logits does
        # not match thresholding the sigmoid at 0.5).
        test_outputs = [post_trans(i) for i in decollate_batch(test_outputs)]
        test_labels = [label_trans(i) for i in decollate_batch(test_labels)]

        visualize(
            image=test_images[0].cpu().permute(1, 2, 0),
            ground_truth_mask=test_labels[0].cpu().permute(1, 2, 0),
            predicted_mask=test_outputs[0].squeeze().cpu().numpy(),
        )

        # compute metric for current iteration
        dice_metric(y_pred=test_outputs, y=test_labels)
        for test_output in test_outputs:
            # NOTE(review): `saver` is created with scale=255 elsewhere in the
            # notebook, so this extra *255 may be redundant -- confirm before
            # removing it.
            saver(test_output * 255)
    # aggregate the final mean dice result
    print("evaluation metric:", dice_metric.aggregate().item())
    # reset the status
    dice_metric.reset()