# LOAD LIBRARIES

In [None]:
import pandas as pd

In [None]:
import os
import glob
from pathlib import Path
import PIL
from PIL import Image

In [None]:
import matplotlib.colors as mcolors
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches

In [None]:
import argparse
import numpy as np
import torch

# import datasets
import utils.misc as misc
from utils.box_utils import xywh2xyxy, bbox_iou, calculate_iou_mask, calculate_dice_mask
from utils.visual_bbox import visualBBox
from models import build_model
import datasets.transforms as T
import PIL.Image as Image
import data_loader
from transformers import AutoTokenizer

import pandas as pd
import os
import json


# get_args_parser()

In [None]:
!nvidia-smi


In [None]:
# Device every tensor is moved to before inference: the GPU when PyTorch can see
# one, otherwise the CPU, so the notebook does not depend on CUDA being present.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Annotation table read from the working directory. Later cells use its columns
# image_id_jpg, bbox_name, original_x1, original_y1, original_width and original_height.
gold_500 = pd.read_csv('gold_500.csv')

# LOAD MODEL

In [None]:

# Every command-line option as a (flag, keyword-arguments) pair. The pairs are
# registered on the parser in this order by the loop at the bottom of the cell.
_ARGUMENT_SPECS = [
    # Optimisation / schedule
    ('--lr_bert', dict(default=0., type=float)),
    ('--lr_visu_cnn', dict(default=0., type=float)),
    ('--lr_visu_tra', dict(default=1e-5, type=float)),
    ('--batch_size', dict(default=32, type=int)),
    ('--weight_decay', dict(default=1e-4, type=float)),
    ('--epochs', dict(default=100, type=int)),
    ('--lr_power', dict(default=0.9, type=float, help='lr poly power')),
    ('--clip_max_norm', dict(default=0., type=float,
                             help='gradient clipping max norm')),
    ('--eval', dict(dest='eval', default=False, action='store_true',
                    help='if evaluation only')),
    ('--optimizer', dict(default='rmsprop', type=str)),
    ('--lr_scheduler', dict(default='poly', type=str)),
    ('--lr_drop', dict(default=80, type=int)),

    # Model parameters
    ('--model_name', dict(type=str, default='TransVG_ca',
                          help="Name of model to be exploited.")),

    # Transformers in two branches
    ('--bert_enc_num', dict(default=12, type=int)),
    ('--detr_enc_num', dict(default=6, type=int)),

    # DETR parameters: backbone
    ('--backbone', dict(default='resnet50', type=str,
                        help="Name of the convolutional backbone to use")),
    ('--dilation', dict(
        action='store_true',
        help="If true, we replace stride with dilation in the last convolutional block (DC5)")),
    ('--position_embedding', dict(
        default='sine', type=str, choices=('sine', 'learned'),
        help="Type of positional embedding to use on top of the image features")),

    # DETR parameters: transformer
    ('--enc_layers', dict(default=6, type=int,
                          help="Number of encoding layers in the transformer")),
    ('--dec_layers', dict(default=0, type=int,
                          help="Number of decoding layers in the transformer")),
    ('--dim_feedforward', dict(
        default=2048, type=int,
        help="Intermediate size of the feedforward layers in the transformer blocks")),
    ('--hidden_dim', dict(
        default=256, type=int,
        help="Size of the embeddings (dimension of the transformer)")),
    ('--dropout', dict(default=0.1, type=float,
                       help="Dropout applied in the transformer")),
    ('--nheads', dict(
        default=8, type=int,
        help="Number of attention heads inside the transformer's attentions")),
    ('--num_queries', dict(default=100, type=int,
                           help="Number of query slots")),
    ('--pre_norm', dict(action='store_true')),

    ('--imsize', dict(default=640, type=int, help='image size')),
    ('--emb_size', dict(default=512, type=int,
                        help='fusion module embedding dimensions')),

    # Vision-Language Transformer
    ('--use_vl_type_embed', dict(action='store_true',
                                 help="If true, use vl_type embedding")),
    ('--vl_dropout', dict(
        default=0.1, type=float,
        help="Dropout applied in the vision-language transformer")),
    ('--vl_nheads', dict(
        default=8, type=int,
        help="Number of attention heads inside the vision-language transformer's attentions")),
    ('--vl_hidden_dim', dict(
        default=256, type=int,
        help='Size of the embeddings (dimension of the vision-language transformer)')),
    ('--vl_dim_feedforward', dict(
        default=2048, type=int,
        help="Intermediate size of the feedforward layers in the vision-language transformer blocks")),
    ('--vl_enc_layers', dict(
        default=6, type=int,
        help='Number of encoders in the vision-language transformer')),

    ('--dataset', dict(default='MS_CXR', type=str,
                       help='referit/flickr/unc/unc+/gref')),
    ('--max_query_len', dict(
        default=20, type=int,
        help='maximum time steps (lang length) per batch')),

    # Dataset / runtime parameters
    ('--output_dir', dict(default='answers',
                          help='path where to save, empty for no saving')),
    ('--device', dict(default='cuda',
                      help='device to use for training / testing')),

    ('--detr_model', dict(default='./saved_models/detr-r50.pth', type=str,
                          help='detr model')),
    ('--bert_model', dict(default='bert-base-uncased', type=str,
                          help='bert model')),

    ('--eval_model', dict(default='released_checkpoint/MedMPG_MS_CXR.pth', type=str)),

    ('--body_part', dict(default='cardiac silhouette', type=str)),
]

parser = argparse.ArgumentParser()
for _flag, _options in _ARGUMENT_SPECS:
    parser.add_argument(_flag, **_options)




In [None]:
# parse_known_args() rather than parse_args(): arguments the parser does not define
# are returned in `unknown` instead of aborting the cell.
# NOTE(review): presumably needed because the notebook kernel passes its own argv — confirm.
args, unknown = parser.parse_known_args()


In [None]:
# Instantiate the network from the parsed options (build_model is imported from `models`).
model = build_model(args)


In [None]:
# Place the model on the same device the inference loop moves its inputs to.
model.to(device)
# The checkpoint is deserialized onto the CPU; load_state_dict then copies the
# values into the parameters wherever the model lives.
# NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source.
checkpoint = torch.load(args.eval_model, map_location='cpu')
model.load_state_dict(checkpoint['model'])

In [None]:

def make_transforms(imsize):
    """
    Build the image preprocessing pipeline.

    The pipeline applies ``T.ToTensor()`` followed by
    ``T.NormalizeAndPad(size=imsize)``, wrapped in a ``T.Compose``.
    """
    steps = [
        T.ToTensor(),
        T.NormalizeAndPad(size=imsize),
    ]
    return T.Compose(steps)


# NECESSARY DICT AND FUNCTIONS

In [None]:
# the folder for all the directories
# (created from Python so the cell can be re-run when the folder already exists)
os.makedirs("<YOUR_FOLDER_NAME>", exist_ok=True)

In [None]:
# Distinct values of the bbox_name column; one result bucket is kept per value.
body_parts = gold_500.bbox_name.unique()

# Each container below is a two-level mapping: part -> {image_name -> value}.
# Every part starts with its own empty inner dictionary.

# value: coordinates of the original (ground-truth) bounding box
bboxes_orig = {region: {} for region in body_parts}

# value: coordinates of the bounding box predicted by the model
bboxes_medRPG = {region: {} for region in body_parts}

# value: IoU of the predicted bounding box with respect to the original one
bboxes_iou_medrpg = {region: {} for region in body_parts}


In [None]:
def calculate_iou_mask(predictions, targets, epsilon=1e-7):
    """
    Intersection-over-union of two masks (predicted and original).

    Both inputs are converted with ``.byte()`` and combined bitwise; the
    counts are summed over dimensions 0 and 1. ``epsilon`` is added to the
    numerator and the denominator, so two empty masks give an IoU of 1
    instead of a division by zero.
    """
    pred_bits = predictions.byte()
    target_bits = targets.byte()
    overlap = (pred_bits & target_bits).sum(dim=(0, 1))
    combined = (pred_bits | target_bits).sum(dim=(0, 1))
    return (overlap + epsilon) / (combined + epsilon)


# GETTING THE MODEL'S PREDICTIONS

In [None]:
from IPython.display import clear_output
from tqdm import tqdm

# Side length (pixels) of the square canvas used for plotting and for the IoU masks.
CANVAS_SIZE = 500
# Side length (pixels) the model input image is resized to.
MODEL_INPUT_SIZE = 640

parent_dir = "<YOUR_FOLDER_NAME>"
# Image file names are appended with os.path.join, which inserts the separator
# when the directory string does not already end with one.
resized_dir = "<path to resized mimic>/"
original_dir = "<path to originalsized mimic>"


def _box_mask(x, y, w, h, size=CANVAS_SIZE):
    """Return a (size, size) 0/1 mask with the box (x, y, w, h) filled in.

    The box is clipped to the canvas first. Without clipping, a negative
    start coordinate would be read as a from-the-end index by the slice and
    the box would silently turn into an empty mask.

    The first axis is indexed with x and the second with y, as both masks
    are built the same way on a square canvas.
    """
    mask = torch.zeros((size, size))
    x0, y0 = max(x, 0), max(y, 0)
    x1, y1 = min(x + w, size), min(y + h, size)
    if x1 > x0 and y1 > y0:
        mask[x0:x1, y0:y1] = 1
    return mask


# Loop invariants: loaded/built once instead of once per image.
tokenizer = AutoTokenizer.from_pretrained(args.bert_model, do_lower_case=True)
transform = make_transforms(imsize=MODEL_INPUT_SIZE)
model.eval()

for part in body_parts:
    clear_output(wait=True)

    # One output folder per part; exist_ok lets the cell be re-run.
    path = os.path.join(parent_dir, part)
    os.makedirs(path, exist_ok=True)

    # The text query depends only on `part`, so it is encoded once per part.
    examples = data_loader.read_examples(part, 1)
    features = data_loader.convert_examples_to_features(
        examples=examples, seq_length=args.max_query_len, tokenizer=tokenizer, usemarker=None)
    word_id = torch.tensor(features[0].input_ids)
    word_mask = torch.tensor(features[0].input_mask)
    text_data = misc.NestedTensor(word_id.unsqueeze(0), word_mask.unsqueeze(0)).to(device)

    # Annotations of this part only; filtered per image below.
    part_rows = gold_500[gold_500.bbox_name == part]

    for image in tqdm(gold_500.image_id_jpg.unique(), desc=part):
        rows = part_rows[part_rows.image_id_jpg == image]
        if rows.empty:
            continue

        img_path = os.path.join(resized_dir, image)
        size_path = os.path.join(original_dir, image)
        if not (os.path.exists(img_path) and os.path.exists(size_path)):
            continue

        # First matching annotation, as before.
        row = rows.iloc[0]

        # The original-size image gives the scale of the annotation and the
        # picture that is plotted; the resized image is what the model sees.
        with Image.open(size_path) as original_img:
            orig_w, orig_h = original_img.size
            display_img = np.flipud(np.asarray(original_img.resize((CANVAS_SIZE, CANVAS_SIZE))))
        with Image.open(img_path) as resized_img:
            model_img = resized_img.convert("RGB").resize((MODEL_INPUT_SIZE, MODEL_INPUT_SIZE))

        # Ground-truth box (x, y, width, height) rescaled to the canvas.
        x_o = int(round(row.original_x1 / orig_w * CANVAS_SIZE))
        y_o = int(round(row.original_y1 / orig_h * CANVAS_SIZE))
        w_o = int(round(row.original_width / orig_w * CANVAS_SIZE))
        h_o = int(round(row.original_height / orig_h * CANVAS_SIZE))
        bboxes_orig[part][image] = (x_o, y_o, w_o, h_o)

        # The transform expects a 'box' entry; an all-zero placeholder is passed.
        sample = transform({'img': model_img, 'box': torch.zeros(4), 'text': part})
        img_data = misc.NestedTensor(
            sample['img'].unsqueeze(0), sample['mask'].unsqueeze(0)).to(device)

        with torch.no_grad():
            outputs = model(img_data, text_data)
        pred_box = outputs['pred_box'].detach().cpu().numpy()[0] * CANVAS_SIZE

        # predicted coordinates are the coordinates of center, width and height.
        # We need to transform them into corner, width, height
        cx, cy, w_p, h_p = (int(round(float(value))) for value in pred_box)
        x_p = int(round(cx - 0.5 * w_p))
        y_p = int(round(cy - 0.5 * h_p))
        bboxes_medRPG[part][image] = (x_p, y_p, w_p, h_p)

        # The picture is flipped and drawn with origin='lower', so both boxes
        # are reflected vertically (anchor at CANVAS_SIZE - y, negative height).
        fig, ax = plt.subplots()
        ax.imshow(display_img, origin='lower')
        ax.add_patch(patches.Rectangle(
            (x_p, CANVAS_SIZE - y_p), w_p, -h_p,
            linewidth=3, edgecolor='red', facecolor='none'))
        ax.add_patch(patches.Rectangle(
            (x_o, CANVAS_SIZE - y_o), w_o, -h_o,
            linewidth=3, edgecolor='yellow', facecolor='none'))
        fig.savefig(os.path.join(path, image))
        # Release the figure; otherwise one figure per image stays in memory.
        plt.close(fig)

        mask_orig = _box_mask(x_o, y_o, w_o, h_o)
        mask_pred = _box_mask(x_p, y_p, w_p, h_p)
        bboxes_iou_medrpg[part][image] = calculate_iou_mask(mask_pred, mask_orig).item()

    # Results are rewritten after every part (as before), so the files always
    # hold everything processed so far.
    # NOTE: replace the two '<NAME>.json' placeholders with DISTINCT file names;
    # left identical, the second dump overwrites the first.
    with open('<NAME>.json', 'w') as json_file:
        json.dump(bboxes_orig, json_file, allow_nan=True)  # ground truth bboxes

    with open('<NAME>.json', 'w') as json_file:
        json.dump(bboxes_medRPG, json_file, allow_nan=True)  # MedRPG predicted bboxes

    with open('IOU_MedRPG_v2.json', 'w') as json_file:
        json.dump(bboxes_iou_medrpg, json_file, allow_nan=True)  # iou computed for every picture
