# Creating segmentation masks
### Creating masks using the multi-class model

In [None]:
import torch, torchvision
import os
import random
import metrics
import time

import constants as cst
import numpy as np
import torch.nn as nn
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from torchvision.utils import draw_bounding_boxes
import torchvision.transforms.functional as F
from torchvision.ops import masks_to_boxes
from torchvision.utils import draw_segmentation_masks

from unet import UNET
import utils
from PIL import Image

In [None]:
def predict_img(model, image, device, transform, index, out_threshold=0.5):
    """Return a boolean mask for a single output channel of ``model``.

    The forward pass runs without gradient tracking. ``transform`` is applied
    to the raw logits, a softmax is taken over dimension 1, and channel
    ``index`` of the result is thresholded.

    Args:
        model: callable network returning logits.
        image: input tensor; presumably shaped (1, C, H, W) — TODO confirm.
        device: device the input is moved to before the forward pass.
        transform: callable applied to the logits before the softmax.
        index: channel index selected after the softmax.
        out_threshold: probabilities strictly greater than this become True.

    Returns:
        numpy bool array of the thresholded channel (on the CPU), with the
        leading dimension removed when it has size 1.
    """
    with torch.no_grad():
        raw_logits = model(image.to(device))
        probabilities = nn.functional.softmax(transform(raw_logits), dim=1)
        channel = probabilities[:, index, :, :].cpu().squeeze(0)
        return channel.numpy() > out_threshold

# Annotation terms: TERMS[t] names the ground-truth folder for term t and is
# paired with output channel t of the model.
TERMS = cst.COMBINED_TERM

# Test images that belong to every testing set ...
names = [
    "120220_4xzoom_v_fish 16.jpg",
    "120220_4xzoom_v2_fish 14.jpg",
    "26022020 inx +- 4xzoom fish08 v3.jpg",
    "26022020 inx +- 4xzoom fish11 v.jpg",
    "120220_4xzoom_v_fish 01.jpg",
    "120220_4xzoom_v_fish 09.jpg",
    "120220_4xzoom_v1_fish 19.jpg",
    "120220_4xzoom_v2_fish 15.jpg",
    "120220_4xzoom_v2_fish 17.jpg",
    "120220_4xzoom_v1_fish 23.jpg",
    "120220_4xzoom_v2_fish 21.jpg",
    "120220_4xzoom_v1_fish 11.jpg",
]
# ... and the short prefixes ("img1" .. "img12") used for their output files.
img_names = [f"img{k}" for k in range(1, 13)]

# Seed every random number generator in use so runs are reproducible.
random.seed(cst.SEED)
torch.manual_seed(cst.SEED)
np.random.seed(cst.SEED)

# Trained model checkpoint(s), where they are stored, and where outputs go.
model_names = ["alles_different_masks_Focal_Fold_2_Epoch_43_MaxEpochs_250_Adam_LR_0.0001.pth"]
model_path = os.path.join(cst.DIR, "final_multi")
save_path = os.path.join(cst.DIR, "pred_multi_focal")

# Per-annotation probability threshold (index t matches TERMS[t]).
thresholds = [0.96, 0.88, 0.42, 0.8, 0.44, 0.92, 0.86,
              0.88, 0.78, 0.48, 0.44, 0.9, 0.94, 0.96]

# Per-annotation overlay colour (index t matches TERMS[t]).
all_colors = ["red", "coral", "sandybrown", "gold", "greenyellow",
              "seagreen", "cyan", "steelblue", "blue", "mediumslateblue",
              "darkorchid", "magenta", "lightpink", "dimgrey"]

loss_name = "Focal"
fold = 0

# Network input size (height, width) before padding.
SIZE = (384, 512)

DEVICE_NAME = 'cuda:0' if torch.cuda.is_available() else "cpu"
DEVICE = torch.device(DEVICE_NAME)

transform_tensor = transforms.ToTensor()
# Model input: resize to SIZE, then pad 64 px on the top and bottom.
transform = transforms.Compose([transforms.Resize(SIZE),
                                transforms.Pad((0, 64, 0, 64))])
# Model output: crop the padding back off, then resize to 1932 x 2576.
untransform = transforms.Compose([transforms.CenterCrop(SIZE),
                                  transforms.Resize((1932, 2576))])

# For every test image: save each term's ground-truth mask and predicted mask,
# then overlay all predicted / all ground-truth masks on the image and save them.


def _show_and_save(image, filename):
    """Show *image* without axes and save the current figure as save_path/filename."""
    plt.imshow(image)
    plt.axis('off')
    plt.savefig(os.path.join(save_path, filename))
    plt.show()


# The model depends on neither the image nor the term, so load it once rather
# than once per (image, term) pair.
model = utils.load_model_all(os.path.join(model_path, model_names[0]))
model.to(DEVICE)
# Inference mode (affects layers such as dropout / batch-norm, if present).
# eval() is idempotent, so this is harmless if load_model_all already calls it.
model.eval()

# savefig raises if the output directory does not exist yet.
os.makedirs(save_path, exist_ok=True)

img_path = os.path.join(cst.DIR, "images")

for name_img, save_name in zip(names, img_names):
    print("Image:", name_img)
    print("Model:", model_names[0])

    # Close the file once decoded; keep only the first three channels
    # (drops e.g. an alpha channel).
    with Image.open(os.path.join(img_path, name_img)) as pil_img:
        img = transform_tensor(pil_img)[:3, :, :]
    img_copy = img
    img = img.unsqueeze(dim=0)
    # Resize + pad once per image; the same network input serves every term.
    model_input = transform(img)

    all_groundtruth = []
    all_predictions = []

    for t, term in enumerate(TERMS):
        print("Threshold:", thresholds[t])

        with Image.open(os.path.join(cst.DIR, term, name_img)) as gt_file:
            grdtruth = transform_tensor(gt_file)
        _show_and_save(grdtruth.squeeze(), save_name + "_gt_" + term + ".jpg")
        all_groundtruth.append(grdtruth.unsqueeze(dim=0))

        # Output channel t is taken as the prediction for TERMS[t].
        prediction = predict_img(model, model_input, DEVICE, untransform, t,
                                 out_threshold=thresholds[t])
        pred = torch.from_numpy(prediction)
        all_predictions.append(pred)
        _show_and_save(pred, save_name + "_pred_" + loss_name + "_" + term + ".jpg")

    # draw_segmentation_masks needs a uint8 image. ToTensor scaled pixels to
    # [0, 1]; round before casting so float error in x/255*255 cannot
    # truncate a pixel value down by one.
    im = (img_copy * 255).round().type(torch.uint8)
    im_gt = im
    for p, (pred, gt) in enumerate(zip(all_predictions, all_groundtruth)):
        # Ground truth: first channel, binarised at 0.9.
        gt_bool = gt.squeeze(dim=0) > 0.9
        im = draw_segmentation_masks(im, pred, alpha=1, colors=all_colors[p])
        im_gt = draw_segmentation_masks(im_gt, gt_bool[0], alpha=1, colors=all_colors[p])

    # Tensors are (C, H, W); imshow expects (H, W, C).
    _show_and_save(im.permute(1, 2, 0), "0_" + save_name + "_pred_on_img_" + loss_name + ".jpg")
    _show_and_save(im_gt.permute(1, 2, 0), "0_" + save_name + "_gt_on_img_" + loss_name + ".jpg")
        