In [1]:
pip_source = "hi-ml-multimodal"
from typing import List
from typing import Tuple

import tempfile
from pathlib import Path

import torch
from IPython.display import display
from IPython.display import Markdown
import random 

import os
# Redirect the Hugging Face Transformers cache to project storage. This is set
# before the health_multimodal imports below; presumably the variable must be in
# place before those libraries are first imported -- TODO confirm.
# NOTE(review): hard-coded absolute, user-specific path; consider a configurable
# cache directory so the notebook runs on other machines.
os.environ["TRANSFORMERS_CACHE"] = "/vol/biomedic3/bglocker/ugproj2324/nns20/.hi-ml-cache"

from health_multimodal.common.visualization import plot_phrase_grounding_similarity_map
from health_multimodal.text import get_bert_inference
from health_multimodal.text.utils import BertEncoderType
from health_multimodal.image import get_image_inference
from health_multimodal.image.utils import ImageModelType
from health_multimodal.vlp import ImageTextInferenceEngine

from agent_utils import select_best_gpu



In [2]:
# Build the BioViL-T text and image inference engines and wrap them in a single
# image-text engine; later cells query it for phrase-grounding similarity maps.
text_inference = get_bert_inference(BertEncoderType.BIOVIL_T_BERT)
image_inference = get_image_inference(ImageModelType.BIOVIL_T)
image_text_inference = ImageTextInferenceEngine(
    image_inference_engine=image_inference,
    text_inference_engine=text_inference,
)
# select_best_gpu is imported from the local agent_utils module. Judging by the
# cell output it lists each GPU's free memory and picks the one with the most --
# TODO confirm against agent_utils.
device = select_best_gpu()
# Move the combined engine to the selected device.
image_text_inference.to(device)

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'BertTokenizer'. 
The class this function is called from is 'CXRBertTokenizer'.
You are using a model of type bert to instantiate a model of type cxr-bert. This is not supported for all configurations of models and can yield errors.
  return self.fget.__get__(instance, owner)()
Some weights of the model checkpoint at microsoft/BiomedVLP-BioViL-T were not used when initializing CXRBertModel: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
- This IS expected if you are initializing CXRBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing CXRBertModel from the checkpoint of a model that you expect to 

Downloading https://cdn-lfs.huggingface.co/repos/63/2f/632fbb459426d5c3e8e64aa8be737ccf0c8ba541980f23a79ecf1ab6e87df8b4/b2399d73dc2a68b9f3a1950e864ae0ecd24093fb07aa459d7e65807ebdc0fb77?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27biovil_t_image_model_proj_size_128.pt%3B+filename%3D%22biovil_t_image_model_proj_size_128.pt%22%3B&Expires=1715695775&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcxNTY5NTc3NX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5odWdnaW5nZmFjZS5jby9yZXBvcy82My8yZi82MzJmYmI0NTk0MjZkNWMzZThlNjRhYThiZTczN2NjZjBjOGJhNTQxOTgwZjIzYTc5ZWNmMWFiNmU4N2RmOGI0L2IyMzk5ZDczZGMyYTY4YjlmM2ExOTUwZTg2NGFlMGVjZDI0MDkzZmIwN2FhNDU5ZDdlNjU4MDdlYmRjMGZiNzc%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=a9gxkLcB9AO6UyUQ8xDNoeb1MR2fwG%7EMLzlJux1yxGodr5vUteMpYOk91ZztTgoq0f87zl4WWJaywq4u-jk-AIlz7VYOSzJjvYze1RGPHNFa8ckXMi4f2xd%7EaIHLnfK5vIVu%7E74FLve7Cdw1rxSJrHllZ3%7ERZmrIGJrk7XZN7P8g2mabzM3UF3zsYGxVyh4OZM1sHO%7EKptJdnscVL

100%|██████████| 109745561/109745561 [00:00<00:00, 113390221.99it/s]


GPU 0: NVIDIA GeForce RTX 2080 Ti, Free memory: 11002 MB
GPU 1: NVIDIA GeForce RTX 2080 Ti, Free memory: 10977 MB
Selecting GPU 0 with 11002 MB free memory, Device = cuda:0


In [8]:
threshold = 0.2
top_n = 25

def get_top_values(similarity_map, threshold, top_n=top_n):
    """Return the strongest above-threshold cells of a 2-D similarity map.

    Args:
        similarity_map: 2-D object exposing ``.shape`` and ``[row, col]`` indexing.
        threshold: only cells whose value is strictly greater than this are kept.
        top_n: maximum number of cells to return (defaults to the module-level
            ``top_n`` at definition time).

    Returns:
        A list of ``(row, col, value)`` tuples sorted by value in descending
        order, truncated to ``top_n`` entries. Cells with equal values keep
        their row-major scan order.
    """
    # Row-major scan; the strict ``>`` comparison also drops NaN cells because
    # any comparison against NaN is False.
    candidates = [
        (row, col, similarity_map[row, col])
        for row in range(similarity_map.shape[0])
        for col in range(similarity_map.shape[1])
        if similarity_map[row, col] > threshold
    ]
    # Stable in-place sort, highest activation first.
    candidates.sort(key=lambda entry: entry[2], reverse=True)
    return candidates[:top_n]

def calculate_mean(similarity_map_top_values):
    """Average the value component of ``(row, col, value)`` entries.

    Args:
        similarity_map_top_values: sequence of 3-tuples whose third element is
            the activation value (as produced by ``get_top_values``).

    Returns:
        The arithmetic mean of the third elements, or ``0`` for an empty sequence.
    """
    count = len(similarity_map_top_values)
    if count > 0:
        activations = (entry[2] for entry in similarity_map_top_values)
        return sum(activations) / count
    # Nothing passed the threshold: report zero activation.
    return 0
    

In [28]:
vindr_pathology_left_or_right_path = Path("/vol/biomedic3/bglocker/ugproj2324/nns20/datasets/VinDr-CXR/image_text_reasoning_datasets/test_pathology_left_or_right")
vindr_png_path = Path('/vol/biodata/data/chest_xray/VinDr-CXR/1.0.0_png_512/raw/test')

detection_thresholds = [0.25,0.35,0.45,0.55]
# Evaluation is capped at the first 101 annotation lines. The earlier version
# broke out of the loop once `index == 100`, i.e. after processing 101 samples;
# the cap is kept identical so results stay comparable.
max_samples = 101


def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or NaN when the denominator is zero."""
    if denominator == 0:
        return float("nan")
    return numerator / denominator


# Read the annotations up front so the file handle is released before the
# (slow) inference loop starts.
with open(vindr_pathology_left_or_right_path) as f:
    lines = f.readlines()

# Each non-blank line is expected to be "image_id,text_prompt,ground_truth_side".
samples = [line.strip() for line in lines if line.strip()][:max_samples]

# One counter set per threshold, kept in the same order as detection_thresholds.
per_threshold_counts = [
    {"total": 0, "exact_match": 0, "left_correct": 0, "right_correct": 0}
    for _ in detection_thresholds
]

# Ground-truth side counts do not depend on the detection threshold.
# NOTE(review): because of the `elif`, a ground truth containing both "left"
# and "right" is counted under total_left only -- confirm this is intended.
total_left = 0
total_right = 0

for sample in samples:
    image_id, text_prompt, ground_truth_side = sample.split(",")
    image_path = vindr_png_path / f"{image_id}.png"

    # The similarity maps do not depend on the detection threshold, so the model
    # is run once per image and query and the maps are reused for every threshold.
    left_similarity_map = image_text_inference.get_similarity_map_from_raw_data(
        image_path=image_path,
        query_text=f"left {text_prompt}",
        interpolation="bilinear",
    )

    right_similarity_map = image_text_inference.get_similarity_map_from_raw_data(
        image_path=image_path,
        query_text=f"right {text_prompt}",
        interpolation="bilinear",
    )

    if "left" in ground_truth_side:
        total_left += 1
    elif "right" in ground_truth_side:
        total_right += 1

    for detection_threshold, counts in zip(detection_thresholds, per_threshold_counts):
        # Mean of the strongest activations above the threshold (0 when none pass).
        left_mean_activation = calculate_mean(get_top_values(left_similarity_map, detection_threshold))
        right_mean_activation = calculate_mean(get_top_values(right_similarity_map, detection_threshold))

        locations = []
        if left_mean_activation >= detection_threshold:
            locations.append("left")

        if right_mean_activation >= detection_threshold:
            locations.append("right")

        predictions = " and ".join(locations)

        # `total` counts only samples where at least one side produced an activation.
        if left_mean_activation + right_mean_activation > 0:
            counts["total"] += 1

        if ground_truth_side == predictions:
            counts["exact_match"] += 1

        if "left" in ground_truth_side and "left" in locations:
            counts["left_correct"] += 1

        if "right" in ground_truth_side and "right" in locations:
            counts["right_correct"] += 1

# Report one summary per threshold.
# NOTE(review): all three ratios are divided by `total` (samples with any
# detection), not by the number of evaluated samples or by total_left /
# total_right -- confirm this is the intended definition of "accuracy".
for detection_threshold, counts in zip(detection_thresholds, per_threshold_counts):
    total = counts["total"]
    exact_match_accuracy = _safe_ratio(counts["exact_match"], total)
    left_accuracy = _safe_ratio(counts["left_correct"], total)
    right_accuracy = _safe_ratio(counts["right_correct"], total)

    print(f"{detection_threshold=}")
    print(f"{total=}")
    print(f"{total_left=}")
    print(f"{total_right=}")

    # NaN is printed instead of raising ZeroDivisionError when nothing was detected.
    print(f"Exact Match Accuracy: {exact_match_accuracy}")
    print(f"Left Accuracy: {left_accuracy}")
    print(f"Right Accuracy: {right_accuracy}")

    print("\n")
      

detection_threshold=0.25
total=88
total_left=69
total_right=32
Exact Match Accuracy: 0.3181818181818182
Left Accuracy: 0.6136363636363636
Right Accuracy: 0.4431818181818182


detection_threshold=0.35
total=64
total_left=69
total_right=32
Exact Match Accuracy: 0.40625
Left Accuracy: 0.515625
Right Accuracy: 0.453125


detection_threshold=0.45
total=36
total_left=69
total_right=32
Exact Match Accuracy: 0.6111111111111112
Left Accuracy: 0.3888888888888889
Right Accuracy: 0.6388888888888888


detection_threshold=0.55
total=21
total_left=69
total_right=32
Exact Match Accuracy: 0.7142857142857143
Left Accuracy: 0.2857142857142857
Right Accuracy: 0.7142857142857143


