# Prediction notebook pipeline

### Set up imports

In [None]:
# Imports

# Python imports
import os
import cv2
import sys
import glob
import json
import random
import pathlib
import numpy as np
import seaborn as sns
from PIL import Image
from barbar import Bar
from natsort import natsorted
import time

from models import UNet, init_net
from dataloader import EndoMaskDataset
from tensorboardX import SummaryWriter

# For experimentation purpose
import torch
import torchvision
import albumentations as alb
from torchsummary import summary
from torchvision.utils import save_image
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader, random_split
import albumentations.augmentations.transforms as alb_tr

# Project imports
import utils as ut

# Expose only physical GPU 1 to this process; it is then addressed as 'cuda:0'.
# This has to happen before CUDA is initialised, i.e. before any CUDA call below.
os.environ['CUDA_VISIBLE_DEVICES'] = str(1)
# Fall back to the CPU when no (visible) GPU exists so the notebook still runs.
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Lower-case alias: other cells in this notebook refer to `device`.
device = DEVICE

# Setup interact widget
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets

# NOTE: the `%` lines below are IPython magics; they only work inside
# Jupyter/IPython, not when this code is run as a plain Python script.

# Load the TensorBoard notebook extension
%load_ext tensorboard

# Auto-reload magic function setup (presumably so edits to the project modules
# imported above -- models, dataloader, utils -- are picked up without a restart)
%load_ext autoreload
%autoreload 2

# Matplotlib magic function setup: render figures inline in the notebook
%matplotlib inline
import matplotlib.pyplot as plt
# Default figure size for this notebook; the visualisation cells below
# override it with (40, 20).
plt.rcParams["figure.figsize"] = (20,10)

### Load model and the trained experiment

In [None]:
# Experiment directory and the epoch whose weights should be loaded
load_epoch = 200
exp_dir = ""  # path to exp dir
model_path = os.path.join(exp_dir, "model_weights", "weights_{}".format(str(load_epoch)))

# Fail fast with a clear message instead of an obscure file error further down.
if not exp_dir:
    raise ValueError("Set `exp_dir` to the experiment directory before running this cell.")

exp_name = pathlib.Path(exp_dir).name

# NOTE(review): the config filename is an assumption -- adjust it to whatever
# the training run actually saved inside the experiment directory.
config_file = os.path.join(exp_dir, "config.json")
with open(config_file, 'r') as configfile:
    exp_opts = json.load(configfile)
    print("Loaded experiment configs!")

# NOTE(review): assumes the checkpoint is `model_path` + ".pt"; point this at
# the actual '.pt' file if the training script stores it elsewhere.
checkpoint_path = model_path + ".pt"

# The class imported from `models` is `UNet`. Constructor arguments, if the
# training run used any, may need to come from `exp_opts` -- TODO confirm.
model = UNet().to(DEVICE)
# map_location lets a checkpoint saved on GPU load onto whatever DEVICE is in
# use (including CPU). torch.load unpickles -- only load trusted checkpoints.
checkpoint = torch.load(checkpoint_path, map_location=DEVICE)
model.load_state_dict(checkpoint["model_state_dict"])

# Switch to evaluation mode for inference; the trailing ';' stops the notebook
# from echoing the model's repr.
model.eval();

### Set up test data

In [None]:
# Predict on the val split when the val loss was NOT used for early stopping;
# otherwise point this at a separate test split file instead.
split_file_path = "../aicm_sim_dataset/fold_1/{}_files.txt"
predict_filenames = ut.read_lines_from_text_file(split_file_path.format("val"))

# Input resolution: 448 x 448 is only a default, change it as needed.
HEIGHT, WIDTH = 448, 448
# Example data root -- replace with the location of your own data.
DATAROOT = "/mnt/sds-stud/guest/data_preprocessed/data_coco_final_v3"

# Both augmentation hooks are disabled (None) for prediction.
predict_dataset = EndoMaskDataset(
    data_root_folder=DATAROOT,
    filenames=predict_filenames,
    height=HEIGHT,
    width=WIDTH,
    image_aug=None,
    image_mask_aug=None,
)

# One sample per batch, in file order, keeping a final partial batch.
predict_dataloader = DataLoader(
    predict_dataset, batch_size=1, shuffle=False, drop_last=False
)

#----------------------------------------------------------------#
# If you want to visualise a random sample or a specified sample from the prediction data, 
# you can use this code, else comment it out


def _to_hwc_numpy(x):
    """Return ``x`` as a numpy array laid out for ``plt.imshow``.

    3-D inputs are assumed to be (C, H, W) -- the same layout the prediction
    loop transposes with ``(0, 2, 3, 1)`` -- and are moved to (H, W, C).
    Inputs of any other rank are returned unchanged.
    """
    if torch.is_tensor(x):
        x = x.detach().cpu().numpy()
    x = np.asarray(x)
    return x.transpose((1, 2, 0)) if x.ndim == 3 else x


# Kept as a 1-element array because the "Visualise outputs" cell reads index[0].
# Replace with e.g. np.array([42]) to inspect a specific sample.
index = np.random.choice(len(predict_dataset), 1)
sample = predict_dataset[index[0]]
# Only the first two items (image, mask) are used; any extra items are ignored.
image, mask = map(_to_hwc_numpy, (sample[0], sample[1]))

plt.rcParams["figure.figsize"] = (40, 20)

plt.subplot(1, 3, 1)
plt.title("Image")
plt.imshow(image)

plt.subplot(1, 3, 2)
plt.title("Mask")
plt.imshow(np.squeeze(mask))

plt.subplot(1, 3, 3)
plt.title("Image + mask")
plt.imshow(image + mask)

### Prediction routine

In [None]:
# --- Prediction settings ------------------------------------------------------
# Set save_png = True to also write inputs, ground truths and predictions as PNGs.
save_png = False
# NOTE(review): output locations are an assumption -- change them as needed.
save_image_path = os.path.join(exp_dir, "predictions", "images")
save_gt_mask_path = os.path.join(exp_dir, "predictions", "gt_masks")
save_pred_mask_path = os.path.join(exp_dir, "predictions", "pred_masks")
if save_png:
    for _out_dir in (save_image_path, save_gt_mask_path, save_pred_mask_path):
        os.makedirs(_out_dir, exist_ok=True)


def mse(pred, target):
    """Return the mean squared error between ``pred`` and ``target`` (0-d tensor)."""
    return torch.nn.functional.mse_loss(pred, target)


def _to_nhwc_numpy(batch_tensor):
    """Move a (N, C, H, W) tensor to the CPU as an (N, H, W, C) numpy array."""
    return batch_tensor.detach().cpu().numpy().transpose((0, 2, 3, 1))


images, gt_masks, pred_masks = [], [], []
metrics = []

# Inference only: no_grad skips building the autograd graph (less memory, faster).
with torch.no_grad():
    for i, batch in enumerate(Bar(predict_dataloader)):
        # First two items are the image and the ground-truth mask; any extra
        # items the dataset returns are ignored.
        image, gt_mask = batch[0], batch[1]
        image_input, gt_mask = image.to(DEVICE), gt_mask.to(DEVICE)
        pred_mask = model(image_input)

        metric = mse(pred=pred_mask, target=gt_mask)

        # torchvision's save_image writes the tensors directly as PNG files.
        if save_png:
            png_name = "{:06d}.png".format(i + 1)
            save_image(image_input, os.path.join(save_image_path, png_name))
            save_image(gt_mask, os.path.join(save_gt_mask_path, png_name))
            save_image(pred_mask, os.path.join(save_pred_mask_path, png_name))

        images.append(_to_nhwc_numpy(image))
        gt_masks.append(_to_nhwc_numpy(gt_mask))
        pred_masks.append(_to_nhwc_numpy(pred_mask))
        metrics.append(metric.item())

# np.concatenate fails on an empty list, so report the real cause instead.
if not metrics:
    raise RuntimeError("predict_dataloader yielded no batches; check predict_filenames / DATAROOT.")

# Here you get the images, gt and pred as (N, H, W, C) np arrays and then
# you can do whatever you want with them
images = np.concatenate(images)
gt_masks = np.concatenate(gt_masks)
pred_masks = np.concatenate(pred_masks)

# MSE is not a percentage, so report the plain mean over all batches.
print("Evaluation completed. Mean MSE over {} batches: {:.6f}".format(len(metrics), np.mean(metrics)))

### Visualise outputs

In [None]:
# Reuse the sample chosen in the "Set up test data" cell when that code was run;
# otherwise (e.g. it was commented out) pick a random prediction to display.
try:
    sample_idx = int(index[0])
except NameError:
    sample_idx = int(np.random.randint(len(images)))

plt.rcParams["figure.figsize"] = (40, 20)

# Input, ground truth, prediction and overlay side by side for comparison.
panels = [
    ("Image", images[sample_idx]),
    ("Ground truth", np.squeeze(gt_masks[sample_idx])),
    ("Prediction", np.squeeze(pred_masks[sample_idx])),
    ("Image + prediction", images[sample_idx] + pred_masks[sample_idx]),
]
for pos, (title, img) in enumerate(panels, start=1):
    plt.subplot(1, len(panels), pos)
    plt.title(title)
    plt.imshow(img)

#plt.savefig("...pred.png")  # use if needed

In [None]:
# Compute any metrics using the np arrays...
# Visualise the metrics...
# Save metrics to disk...

* End of program