In [1]:
# MOUNT DRIVE
# Mount Google Drive inside the Colab VM at /content/drive. The default
# model_dir and dataset_dir set in the SET VARIABLES cell both point below
# this mount point, so this cell has to run first in a fresh runtime.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
# IMPORT LIBRARIES
import cv2
import numpy as np
from google.colab.patches import cv2_imshow
import torch
from tqdm.auto import tqdm
import argparse
import glob
from PIL import Image
import shutil
import os
import pandas as pd
from scipy import ndimage
import pandas as pd
from scipy.stats import pearsonr

In [3]:
# SET VARIABLES
# Central configuration read by the later cells of this notebook.
model_dir = '/content/drive/MyDrive/PraNet/' # directory of the model under test; later cells chdir here and use Output/ or Output_L/ relative to it
dataset_dir = '/content/drive/MyDrive/PraNet/'  # directory that holds the data; the final cell reads ground-truth masks from a TestDataset/ folder
mtype = 'S' # model size: 'S' for PraNet, CaraNet, SSFormer-S, UACANet-S; 'L' for SSFormer-L, UACANet-L ('L' selects Output_L/, anything else Output/)
rotation_angles = [90,180,270] # rotation angles in degrees; the rotation helpers only handle 90, 180 and 270
threshold = 0.90 # trustworthiness threshold: a sample whose mean Dice consistency is >= this value is reported as Trustworthy
datasets = ['CVC-300'] # polyp test sets to process; other names used in this notebook: 'CVC-ClinicDB','CVC-ColonDB','ETIS-LaribPolypDB','Kvasir'

In [5]:
# MAKE FOLDERS
# Build the output tree used by the later cells:
#   <root>/CSV_files/                              csv files
#   <root>/<dataset>/                              predictions on the unrotated input
#   <root>/rotation<angle>/rotated_pred/<dataset>/         predictions on the rotated input
#   <root>/rotation<angle>/invert_rotated_pred/<dataset>/  those predictions rotated back
# <root> is Output_L for mtype 'L' and Output otherwise.
#
# The tree is anchored at model_dir instead of the current working directory:
# every cell that consumes these folders first does os.chdir(model_dir), so on a
# fresh kernel a cwd-relative tree would be created where nothing looks for it.
output_root = os.path.join(model_dir, 'Output_L' if mtype == "L" else 'Output')

if mtype == "L":
  dataset_names = list(datasets)
else:
  # Keep creating folders for the five standard benchmarks (previous behaviour)
  # and additionally for any configured dataset that is not one of them, so a
  # custom entry in `datasets` no longer ends up without output folders.
  standard_names = ['CVC-300','CVC-ClinicDB','CVC-ColonDB','ETIS-LaribPolypDB','Kvasir']
  dataset_names = standard_names + [name for name in datasets if name not in standard_names]

os.makedirs(os.path.join(output_root, 'CSV_files'), exist_ok=True)

for name in dataset_names:
  os.makedirs(os.path.join(output_root, str(name)), exist_ok=True)

for angle in rotation_angles:
  for stage in ('rotated_pred', 'invert_rotated_pred'):
    stage_dir = os.path.join(output_root, 'rotation' + str(angle), stage)
    os.makedirs(stage_dir, exist_ok=True)
    for name in dataset_names:
      os.makedirs(os.path.join(stage_dir, str(name)), exist_ok=True)


In [4]:
# ALL REQUIRED FUNCTIONS
def rotation90(i):
  """Return image ``i`` rotated by 90 degrees counter-clockwise (via cv2.rotate)."""
  return cv2.rotate(i, cv2.ROTATE_90_COUNTERCLOCKWISE)

def rotation180(i):
  """Return image ``i`` flipped around both axes (cv2.flip with flip code -1), i.e. turned by 180 degrees."""
  return cv2.flip(i, -1)

def rotation270(i):
  """Return image ``i`` rotated by 90 degrees clockwise, which is the 270-degree variant used in this notebook."""
  return cv2.rotate(i, cv2.ROTATE_90_CLOCKWISE)

def invert_rotation90(i):
  """Undo ``rotation90``: turn image ``i`` by 90 degrees clockwise."""
  return cv2.rotate(i, cv2.ROTATE_90_CLOCKWISE)

def invert_rotation180(i):
  """Undo ``rotation180``: flip image ``i`` around both axes again (cv2.flip with flip code -1)."""
  return cv2.flip(i, -1)

def invert_rotation270(i):
  """Undo ``rotation270``: turn image ``i`` by 90 degrees counter-clockwise."""
  return cv2.rotate(i, cv2.ROTATE_90_COUNTERCLOCKWISE)

def rotation_dataset(img_path,mask_path,save_rotated_img_path,save_rotated_mask_path,rotation_angle):
  """Write rotated copies of every image/mask pair of a dataset.

  File names are taken from ``img_path`` (processed in sorted order); the mask
  is read from ``mask_path`` under the same file name. All four path arguments
  are used as plain string prefixes (``prefix + file_name``), so directory
  paths must end with a separator.

  Args:
    img_path: prefix of the input images.
    mask_path: prefix of the input masks.
    save_rotated_img_path: prefix under which rotated images are written.
    save_rotated_mask_path: prefix under which rotated masks are written.
    rotation_angle: 90, 180 or 270.

  Raises:
    ValueError: if ``rotation_angle`` is not 90, 180 or 270 (checked before
      any file is touched).
    FileNotFoundError: if an image or its mask cannot be read.
    OSError: if a rotated file cannot be written.
  """
  rotate_by_angle = {90: rotation90, 180: rotation180, 270: rotation270}
  if rotation_angle not in rotate_by_angle:
    raise ValueError("wrong rotation angle!!")
  rotate = rotate_by_angle[rotation_angle]

  # cv2.imwrite does not create missing folders; it just reports failure.
  # Create the folder part of each output prefix up front.
  for out_prefix in (save_rotated_img_path, save_rotated_mask_path):
    out_dir = os.path.dirname(out_prefix)
    if out_dir:
      os.makedirs(out_dir, exist_ok=True)

  for file_name in tqdm(sorted(os.listdir(img_path))):
    img = cv2.imread(img_path + file_name)
    mask = cv2.imread(mask_path + file_name)
    # cv2.imread signals an unreadable/missing file by returning None.
    if img is None:
      raise FileNotFoundError("could not read image: " + img_path + file_name)
    if mask is None:
      raise FileNotFoundError("could not read mask: " + mask_path + file_name)
    if not cv2.imwrite(save_rotated_img_path + file_name, rotate(img)):
      raise OSError("could not write rotated image: " + save_rotated_img_path + file_name)
    if not cv2.imwrite(save_rotated_mask_path + file_name, rotate(mask)):
      raise OSError("could not write rotated mask: " + save_rotated_mask_path + file_name)

def invert_rotation_prediction(saved_pred_path, saved_invert_pred_path, rotation_angle):
  """Rotate every saved prediction back to the orientation of the unrotated input.

  Both path arguments are used as plain string prefixes
  (``prefix + file_name``), so directory paths must end with a separator.

  Args:
    saved_pred_path: prefix of the predictions made on the rotated input.
    saved_invert_pred_path: prefix under which the back-rotated predictions
      are written.
    rotation_angle: angle the input was rotated by; 90, 180 or 270.

  Raises:
    ValueError: if ``rotation_angle`` is not 90, 180 or 270. Previously an
      unsupported angle surfaced as an UnboundLocalError on the first file.
    FileNotFoundError: if a prediction cannot be read.
    OSError: if a back-rotated prediction cannot be written.
  """
  invert_by_angle = {90: invert_rotation90, 180: invert_rotation180, 270: invert_rotation270}
  if rotation_angle not in invert_by_angle:
    raise ValueError("wrong rotation angle!!")
  invert = invert_by_angle[rotation_angle]

  # cv2.imwrite does not create missing folders; make sure the target exists.
  out_dir = os.path.dirname(saved_invert_pred_path)
  if out_dir:
    os.makedirs(out_dir, exist_ok=True)

  for file_name in tqdm(os.listdir(saved_pred_path)):
    pred = cv2.imread(saved_pred_path + file_name)
    # cv2.imread signals an unreadable/missing file by returning None.
    if pred is None:
      raise FileNotFoundError("could not read prediction: " + saved_pred_path + file_name)
    if not cv2.imwrite(saved_invert_pred_path + file_name, invert(pred)):
      raise OSError("could not write prediction: " + saved_invert_pred_path + file_name)

def dice_metric(pred, target):
  """Dice coefficient ``2 * sum(pred * target) / (sum(pred) + sum(target))``.

  Returns 1.0 when both inputs sum to zero (two empty masks count as a
  perfect match) instead of dividing by zero.
  """
  overlap = (target * pred).sum()
  target_total = target.sum()
  pred_total = pred.sum()
  both_empty = target_total == 0 and pred_total == 0
  if both_empty:
    return 1.0
  return 2.0 * overlap / (target_total + pred_total)

def load(path):
    """Read an image file as a float32 grayscale array scaled by its own maximum.

    The image is converted to 8-bit grayscale (mode 'L'), cast to float32 and
    divided by ``max + 1e-8``; the epsilon keeps an all-black image from
    dividing by zero, so such an image stays all zeros.

    Args:
        path: path of the image file.

    Returns:
        float32 array holding the scaled grayscale values.
    """
    # Open through a context manager so the file handle is released as soon as
    # the pixel data has been converted; the previous bare open() was never
    # closed and leaked one handle per loaded mask.
    with Image.open(path) as img:
        mask = np.asarray(img.convert('L'), np.float32)
    mask /= (mask.max() + 1e-8)
    return mask

def intersection(m):
  """Zero every entry of ``m`` that is below 1 and return ``m``.

  The array is modified in place; the returned object is the input itself.
  """
  below_one = m < 1
  m[below_one] = 0
  return m

def get_YA_YB(mtype, i, data_name, cso=False):
  """Pick the two mutually most consistent predictions for one sample.

  Four predictions of sample ``i`` are loaded: the one on the unrotated input
  and the back-rotated ones for 90, 180 and 270 degrees. For each prediction
  the Dice scores against the other three are averaged; the two predictions
  with the highest averages are returned.

  Args:
    mtype: 'L' reads below ``Output_L/``, any other value below ``Output/``
      (both relative to the current working directory).
    i: file name of the sample.
    data_name: dataset folder name.
    cso: accepted for interface compatibility; not used by this function.

  Returns:
    ``(mask1, mask2)``: the two selected masks, ordered 0, 90, 180, 270 (the
    prediction with the smaller angle comes first). Ties between averages are
    resolved in favour of the smaller angle.
  """
  # The former per-call os.listdir()/sort() of the prediction folder was never
  # used and cost one directory scan per sample; it has been dropped.
  output_root = "Output_L/" if mtype == 'L' else "Output/"
  pred_dirs = [output_root + data_name + "/"]
  pred_dirs += [
      output_root + "rotation" + str(angle) + "/invert_rotated_pred/" + data_name + "/"
      for angle in (90, 180, 270)
  ]
  mask0, mask90, mask180, mask270 = [load(pred_dir + i) for pred_dir in pred_dirs]

  dice0_90 = dice_metric(mask0,mask90)
  dice0_180 = dice_metric(mask0,mask180)
  dice0_270 = dice_metric(mask0,mask270)
  dice90_180 = dice_metric(mask90,mask180)
  dice90_270 = dice_metric(mask90,mask270)
  dice180_270 = dice_metric(mask180,mask270)

  # Summation order is kept exactly as before so that floating-point ties
  # between averages fall the same way.
  avg0 = (dice0_90 + dice0_180 + dice0_270)/3
  avg90 = (dice0_90 + dice90_180 + dice90_270)/3
  avg180 = (dice90_180 + dice0_180 + dice180_270)/3
  avg270 = (dice0_270 + dice180_270 + dice90_270)/3
  avg = [avg0, avg90, avg180, avg270]
  masks = [mask0, mask90, mask180, mask270]

  # sorted() is stable, also with reverse=True, so equal averages keep their
  # index order. Taking the top two indices and returning them in ascending
  # order selects the same pair as the former six-branch equality chain.
  ranked = sorted(range(len(avg)), key=lambda k: avg[k], reverse=True)
  first, second = sorted(ranked[:2])
  return masks[first], masks[second]

In [5]:
# ROTATE INPUT DATA
# For every configured dataset and rotation angle, write rotated copies of the
# test images and masks to TestDataset/RotatedDataset<angle>/<dataset>/.
#
# Paths are anchored at dataset_dir instead of the current working directory,
# matching the final cell, which reads the ground truth from the TestDataset/
# folder below that directory; a cwd-relative "TestDataset/" only resolved when
# the kernel happened to be in the right folder already.
test_root = os.path.join(dataset_dir, "TestDataset")
for data_name in datasets:
  # The trailing "" makes os.path.join end the path with a separator, which
  # rotation_dataset needs because it builds file paths as prefix + file name.
  img_path = os.path.join(test_root, data_name, "images", "")   # input image path
  mask_path = os.path.join(test_root, data_name, "masks", "")  # input mask path
  for rotation_angle in rotation_angles:
    rotated_root = os.path.join(test_root, "RotatedDataset" + str(rotation_angle), data_name)
    save_rotated_img_path = os.path.join(rotated_root, "images", "")  # rotated input image path
    save_rotated_mask_path = os.path.join(rotated_root, "masks", "")  # rotated input mask path
    # cv2.imwrite does not create missing folders, so make them here.
    os.makedirs(save_rotated_img_path, exist_ok=True)
    os.makedirs(save_rotated_mask_path, exist_ok=True)
    rotation_dataset(img_path,mask_path,save_rotated_img_path,save_rotated_mask_path,rotation_angle)

In [6]:
####.......Run this cell after getting all the MIS predictions of input and corresponding variants on existing models.....#####
# INVERT ROTATION ON DIFFERENT PREDICTIONS
# For every dataset and angle, rotate the predictions stored in
# <root>/rotation<angle>/rotated_pred/<dataset>/ back and store them in
# <root>/rotation<angle>/invert_rotated_pred/<dataset>/, where <root> is
# Output_L/ for mtype 'L' and Output/ otherwise (relative to model_dir).
os.chdir(model_dir)
pred_root = "Output_L/" if mtype == 'L' else "Output/"
for data_name in datasets:
  for rotation_angle in rotation_angles:
    rotation_root = pred_root + "rotation" + str(rotation_angle)
    saved_pred_path = rotation_root + "/rotated_pred/" + data_name + "/"  # saved rotated prediction path
    saved_invert_pred_path = rotation_root + "/invert_rotated_pred/" + data_name + "/"  # invert rotated prediction path
    invert_rotation_prediction(saved_pred_path, saved_invert_pred_path, rotation_angle)

In [8]:
####.......Run this cell after getting all the MIS predictions of input and corresponding variants on existing models.....#####
def int_method(mtype, i, data_name):
  """Build the replacement mask S (INT method) for one sample.

  The two masks YA and YB come from ``get_YA_YB``. If the larger of their
  non-zero pixel counts exceeds 0.2 x the pixel count, the mask owning that
  count is returned unchanged (YA on a tie). Otherwise the result is
  ``intersection((YA + YB) / 2)``, which keeps only the entries where the
  averaged mask reaches 1.

  NOTE(review): the threshold in the code is 0.2 of all pixels, i.e. 20%;
  an earlier comment described it as 0.2% - confirm which one is intended.
  """
  YA, YB = get_YA_YB(mtype, i, data_name)
  foreground_a = cv2.countNonZero(YA)  # non-zero pixels in YA
  foreground_b = cv2.countNonZero(YB)  # non-zero pixels in YB
  total_pixels = YA.shape[0] * YB.shape[1]  # rows of YA times columns of YB
  largest_foreground = max(foreground_a, foreground_b)
  if largest_foreground <= 0.2 * total_pixels:
    # Small foreground: combine both masks (S = YA intersect YB).
    return intersection((YA + YB) / 2.0)
  # Large foreground: keep the mask that holds the larger foreground.
  return YA if largest_foreground == foreground_a else YB


def it_method(gt_path,rotation_angles,data_name,mtype,threshold):
  """Score every sample of a dataset with the IT/INT procedure.

  For each file name found in ``gt_path`` (sorted), the prediction on the
  unrotated input is compared by Dice with the back-rotated prediction of
  every angle. The mean of these scores is the confidence: at or above
  ``threshold`` the unrotated prediction is kept, below it the mask from
  ``int_method`` is used. The chosen mask is then scored against the ground
  truth. Independently, the Dice between the two masks from ``get_YA_YB`` is
  recorded for the CSO method.

  Args:
    gt_path: prefix of the ground-truth masks (used as ``gt_path + name``).
    rotation_angles: angles whose back-rotated predictions are compared.
    data_name: dataset folder name.
    mtype: 'L' reads below ``Output_L/``, any other value below ``Output/``.
    threshold: minimum confidence for a prediction to count as trustworthy.

  Returns:
    ``(z, dice_cso)``: two dicts keyed by file name - the Dice of the final
    mask against the ground truth, and the Dice between YA and YB.

  Raises:
    ValueError: if ``rotation_angles`` is empty (the confidence would be a
      division by zero).
  """
  output_root = "Output_L/" if mtype == 'L' else "Output/"
  original_pred_path = output_root + data_name + "/"   # prediction without rotation
  variant_pred_paths = [
      output_root + "rotation" + str(rotation_angle) + "/invert_rotated_pred/" + data_name + "/"
      for rotation_angle in rotation_angles
  ]   # invert rotated prediction paths
  if not variant_pred_paths:
    raise ValueError("rotation_angles must contain at least one angle")

  z = {}
  dice_cso = {}
  for i in tqdm(sorted(os.listdir(gt_path))):
    m1 = load(original_pred_path + i)  # actual pred Y, loaded once per sample
    # consistency between actual and each variant prediction (IT method)
    consistencies = [dice_metric(m1, load(variant_path + i)) for variant_path in variant_pred_paths]
    conf = sum(consistencies) / len(consistencies)
    if conf >= threshold:
      print("prediction of sample "+ str(i)+ " is Trustworthy (T) with confidence " + str(conf) + ".")
      S = m1
    else:
      print("prediction of sample "+ str(i)+ " is Non-Trustworthy (NT) with confidence " + str(conf) + ".")
      S = int_method(mtype, i, data_name)
    z[i] = dice_metric(S, load(gt_path + i))
    ### FOR CSO ###
    YA, YB = get_YA_YB(mtype, i, data_name)
    # Recorded inside the loop: the update used to sit after the loop, so only
    # the last sample was stored and an empty folder raised a NameError.
    dice_cso[i] = dice_metric(YA, YB)  # dice used in CSO method
  return z, dice_cso


In [10]:
####.......Run this cell after getting all the MIS predictions of input and corresponding variants on existing models.....#####
# FUNCTION CALLING
# Run the IT/INT evaluation for every configured dataset and print the mean
# Dice of the final masks against the ground truth.
#
# The ground-truth folder is derived from dataset_dir instead of a hard-coded
# absolute path, so changing dataset_dir in the SET VARIABLES cell takes
# effect here. It is resolved to an absolute path before changing into
# model_dir, which is where it_method reads the predictions from.
dataset_root = os.path.abspath(dataset_dir)
os.chdir(model_dir)     # change dir to current model dir
for data_name in datasets:
  # The trailing "" keeps the separator at the end: it_method builds file
  # paths as gt_path + file name.
  gt_path = os.path.join(dataset_root, "TestDataset", data_name, "masks", "") # ground truth mask
  z, dice_cso = it_method(gt_path,rotation_angles,data_name,mtype,threshold)
  if not z:
    # No samples scored: report it instead of dividing by zero below.
    print("No ground-truth masks found in " + gt_path + " for dataset " + str(data_name) + "; skipping.")
    continue
  print("The improved INT performance (c^INT) for dataset "+ str(data_name) + " is : " + str(sum(z.values())/len(z)))