# Creating segmentation masks
### Creating masks from images cropped around the head of the fish (the crop is located with the fish-segmentation model).

In [None]:
import torch, torchvision
import os
import random
import metrics
import time

import constants as cst
import numpy as np
import torch.nn as nn
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from torchvision.utils import draw_bounding_boxes
import torchvision.transforms.functional as F
from torchvision.ops import masks_to_boxes
from torchvision.utils import draw_segmentation_masks

from unet import UNET
import utils
from PIL import Image

In [None]:
def predict_img(model, image, device, transform, out_threshold=0.5):
    """Return a boolean foreground mask predicted by ``model`` for ``image``.

    ``image`` is moved to ``device`` and passed through ``model``; ``transform``
    is then applied to the raw logits (e.g. to undo the input resizing/padding).
    A softmax over dim 1 converts the logits into class probabilities, and the
    class-1 probability map is compared against ``out_threshold``.

    Assumes a batch of exactly one image, i.e. a model output of shape
    (1, num_classes, H, W) with num_classes >= 2 — TODO confirm for all models.

    Returns:
        numpy bool array (H, W): True where P(class 1) > out_threshold.
    """
    with torch.no_grad():
        logits = transform(model(image.to(device)))
        probabilities = torch.softmax(logits, dim=1)
        # Drop the batch axis, then keep only the foreground (class 1) channel.
        foreground = probabilities.detach().cpu().squeeze(0).numpy()[1, :, :]
    return foreground > out_threshold


# Annotation names. Each entry is also the sub-directory of cst.DIR holding the
# ground-truth masks for that annotation.
TERMS = cst.COMBINED_TERM

# Images that belong to every testing set
names = ["120220_4xzoom_v_fish 16.jpg",
         "120220_4xzoom_v2_fish 14.jpg",
         "26022020 inx +- 4xzoom fish08 v3.jpg",
         "26022020 inx +- 4xzoom fish11 v.jpg",
         "120220_4xzoom_v_fish 01.jpg",
         "120220_4xzoom_v_fish 09.jpg",
         "120220_4xzoom_v1_fish 19.jpg",
         "120220_4xzoom_v2_fish 15.jpg",
         "120220_4xzoom_v2_fish 17.jpg",
         "120220_4xzoom_v1_fish 23.jpg",
         "120220_4xzoom_v2_fish 21.jpg",
         "120220_4xzoom_v1_fish 11.jpg"]

# Short output-file prefixes ("img1", "img2", ...), one per entry of `names`.
img_names = [f"img{i}" for i in range(1, len(names) + 1)]

random.seed(cst.SEED)
torch.manual_seed(cst.SEED)
np.random.seed(cst.SEED)

# One model per annotation
model_names = ["br1_Focal_Fold_2_Epoch_95_MaxEpochs_250_Adam_LR_0.0001.pth",
               "br2_CE_Fold_2_Epoch_225_MaxEpochs_250_Adam_LR_0.0001.pth",
               "cb_Tversky_Fold_3_Epoch_97_MaxEpochs_250_Adam_LR_0.0001.pth",
               "ch_Tversky_Fold_2_Epoch_190_MaxEpochs_250_Adam_LR_0.0001.pth",
               "cl_Focal_Fold_2_Epoch_47_MaxEpochs_250_Adam_LR_0.0001.pth",
               "d_CE_Fold_3_Epoch_194_MaxEpochs_250_Adam_LR_0.0001.pth",
               "en_Tversky_Fold_4_Epoch_105_MaxEpochs_250_Adam_LR_0.0001.pth",
               "hm_Tversky_Fold_2_Epoch_176_MaxEpochs_250_Adam_LR_0.0001.pth",
               "m_CE_Fold_4_Epoch_247_MaxEpochs_250_Adam_LR_0.0001.pth",
               "n_Focal_Fold_3_Epoch_96_MaxEpochs_250_Adam_LR_0.0001.pth",
               "oc_Focal_Fold_0_Epoch_47_MaxEpochs_250_Adam_LR_0.0001.pth",
               "op_Focal_Fold_3_Epoch_57_MaxEpochs_250_Adam_LR_0.0001.pth",
               "p_CE_Fold_3_Epoch_129_MaxEpochs_250_Adam_LR_0.0001.pth"]


model_path = os.path.join(cst.DIR, "final_cropped")
save_path = os.path.join(cst.DIR, "pred_cropped")

# Foreground-probability threshold for each model/annotation
# (same order as model_names).
thresholds = [0.06,
              0.02,
              0.02,
              0.02,
              0.08,
              0.02,
              0.02,
              0.0,
              0.02,
              0.1,
              0.08,
              0.08,
              0.08]

# Color for each annotation (same order as model_names).
all_colors = ["red",
              "coral",
              "sandybrown",
              "gold",
              "greenyellow",
              "seagreen",
              "cyan",
              "steelblue",
              "blue",
              "mediumslateblue",
              "darkorchid",
              "magenta",
              "lightpink",
              "dimgrey"]

# Every model is indexed in parallel with its threshold and its color; fail
# early rather than with an IndexError halfway through the prediction loop.
if len(thresholds) != len(model_names):
    raise ValueError(
        f"Expected one threshold per model, got {len(thresholds)} thresholds "
        f"for {len(model_names)} models"
    )
if len(all_colors) < len(model_names):
    raise ValueError(
        f"Expected at least one color per model, got {len(all_colors)} colors "
        f"for {len(model_names)} models"
    )

# Tag inserted into the output file names.
loss_name = "best"
fold = 0

# (height, width) the full image is resized to before padding for the fish model.
SIZE = (384, 512)


DEVICE_NAME = "cuda:0" if torch.cuda.is_available() else "cpu"
DEVICE = torch.device(DEVICE_NAME)

transform_tensor = transforms.ToTensor()
# Resize to SIZE, then pad 64 px at the top and bottom: 384 + 2*64 = 512, so the
# fish model receives a 512x512 input.
transform = transforms.Compose([transforms.Resize(SIZE),
                                transforms.Pad((0, 64, 0, 64))])
# Inverse of `transform` for the model output: drop the padding, then resize to
# 1932x2576 (presumably the original image resolution — TODO confirm).
untransform = transforms.Compose([transforms.CenterCrop(SIZE),
                                  transforms.Resize((1932, 2576))])

# Model used to segment the fish out of the image
fish_model_name = "fish_CE_Fold_4_Epoch_95_MaxEpochs_600_Adam_LR_0.0001.pth"
fish_model_path = os.path.join(cst.MODEL, fish_model_name)

fish_model = utils.load_model(fish_model_path)
fish_model.to(DEVICE)
# Inference only: make dropout / batch-norm layers use their inference
# behaviour (a no-op if utils.load_model already did this).
fish_model.eval()

# Annotation indices whose prediction is NOT drawn on the combined overlay (only
# their ground truth is). NOTE(review): index 7 is model_names[7], whose threshold
# is 0.0, so its thresholded mask presumably covers almost everything — confirm.
GT_ONLY_TERMS = {7}

img_path = os.path.join(cst.DIR, "images")


def _save_and_show(image, file_name):
    """Display `image` without axes, save the figure as save_path/file_name, show it."""
    plt.imshow(image)
    plt.axis('off')
    plt.savefig(os.path.join(save_path, file_name))
    plt.show()


for name_img, save_name in zip(names, img_names):
    # Load the full image; keep only the first three (RGB) channels.
    with Image.open(os.path.join(img_path, name_img)) as pil_img:
        img = transform_tensor(pil_img)
    img = img[:3, :, :]

    # --- 1. Locate the fish with the fish-segmentation model. ---
    # NOTE(review): `untransform` resizes the mask to 1932x2576; this assumes the
    # input images have that size, otherwise box coordinates will not match `img`.
    fish_image = transform(img)
    fish_mask = predict_img(fish_model, fish_image.unsqueeze(dim=0), DEVICE, untransform)
    fish_mask_tensor = torch.from_numpy(fish_mask)
    if not fish_mask_tensor.any():
        print("No fish detected, skipping image:", name_img)
        continue
    # masks_to_boxes returns (x_min, y_min, x_max, y_max), max coordinates inclusive.
    box = masks_to_boxes(fish_mask_tensor.unsqueeze(dim=0))[0]
    h1 = int(box[0])
    v1 = int(box[1])
    v2 = int(box[3]) + 1

    # --- 2. Crop = leftmost 3/5 of the fish bounding box. ---
    # Round the box width up to a multiple of 10 so that 3/5 of it is an exact,
    # even integer (a multiple of 6); this keeps the square padding symmetric.
    h_length = int(box[2]) + 1 - h1
    if h_length % 10 != 0:
        h_length += 10 - h_length % 10
    h_length = (3 * h_length) // 5
    h2 = h1 + h_length

    # Make the crop height even by growing it one row: upwards when possible,
    # downwards when the box already touches the top edge (v1 - 1 would be -1).
    v_length = v2 - v1
    if v_length % 2 == 1:
        if v1 > 0:
            v1 -= 1
        else:
            v2 += 1
        v_length += 1

    # --- 3. Square 512x512 model input and its inverse. ---
    # post_tr pads the crop to a square and resizes to 512x512; untr resizes the
    # prediction back to the square side and crops it to (v_length, h_length).
    if h_length > v_length:
        padding = (h_length - v_length) // 2
        post_tr = transforms.Compose([transforms.Pad((0, padding, 0, padding)),
                                      transforms.Resize((512, 512))])
        untr = transforms.Compose([transforms.Resize((h_length, h_length)),
                                   transforms.CenterCrop((v_length, h_length))])
    elif h_length < v_length:
        padding = (v_length - h_length) // 2
        post_tr = transforms.Compose([transforms.Pad((padding, 0, padding, 0)),
                                      transforms.Resize((512, 512))])
        untr = transforms.Compose([transforms.Resize((v_length, v_length)),
                                   transforms.CenterCrop((v_length, h_length))])
    else:
        post_tr = transforms.Compose([transforms.Resize((512, 512))])
        untr = transforms.Compose([transforms.Resize((h_length, h_length))])

    cropped = img[:, v1:v2, h1:h2]
    img_copy = cropped
    img = cropped.unsqueeze(dim=0)
    # The padded/resized input is identical for every annotation model.
    model_input = post_tr(img)

    all_groundtruth = []
    all_predictions = []

    # --- 4. One model per annotation: save the ground truth and the prediction. ---
    for t in range(len(TERMS) - 1):
        print("Image:", name_img)
        print("Model:", model_names[t])
        print("Threshold:", thresholds[t])
        model = utils.load_model(os.path.join(model_path, model_names[t]))
        model.to(DEVICE)
        # Inference only (a no-op if utils.load_model already did this).
        model.eval()

        with Image.open(os.path.join(cst.DIR, TERMS[t], name_img)) as gt_img:
            grdtruth = transform_tensor(gt_img)
        grdtruth = grdtruth[:, v1:v2, h1:h2]
        _save_and_show(grdtruth.squeeze(), save_name + "_gt_" + TERMS[t] + ".jpg")
        all_groundtruth.append(grdtruth.unsqueeze(dim=0))

        prediction = predict_img(model, model_input, DEVICE, untr, out_threshold=thresholds[t])
        pred = torch.from_numpy(prediction)
        all_predictions.append(pred)
        _save_and_show(pred, save_name + "_pred_" + loss_name + "_" + TERMS[t] + ".jpg")

    # --- 5. Overlay every mask on the crop (draw_segmentation_masks needs uint8). ---
    img_copy = (img_copy * 255).type(torch.uint8)
    im = img_copy
    im_gt = img_copy
    for p, (pred_mask, gt) in enumerate(zip(all_predictions, all_groundtruth)):
        # Binarise the first channel of the ground-truth image.
        gt_bool = gt.squeeze(dim=0)[0] > 0.9
        im_gt = draw_segmentation_masks(im_gt, gt_bool, alpha=1, colors=all_colors[p])
        if p in GT_ONLY_TERMS:
            continue
        im = draw_segmentation_masks(im, pred_mask, alpha=1, colors=all_colors[p])

    _save_and_show(im.permute(1, 2, 0), "0_" + save_name + "_pred_on_img_" + loss_name + ".jpg")
    _save_and_show(im_gt.permute(1, 2, 0), "0_" + save_name + "_gt_on_img_" + loss_name + ".jpg")