In [1]:
import os
os.environ['HF_HOME'] = '/vol/biomedic3/bglocker/ugproj2324/nns20/cxr-agent/.hf_cache' ## THIS HAS TO BE BEFORE YOU IMPORT TRANSFORMERS

import transformers
import torch

from abc import ABC, abstractmethod
from pathlib import Path
from typing import Optional
from PIL import Image
from my_secrets import LLAMA3_INSTRUCT_ACCESS_TOKEN, GEMINI_ACCESS_TOKEN
from agent_utils import select_best_gpu

from pathology_detector import PathologyDetector, CheXagentVisionTransformerPathologyDetector
from pathology_sets import Pathologies

from phrase_grounder import PhraseGrounder, BioVilTPhraseGrounder


class GenerationEngine(ABC):
    """Abstract base class for the language-model back ends used to answer chest X-ray queries.

    Subclasses implement ``generate_model_output``. The two helpers below take no
    instance state and are exposed as static methods, so they can be called either
    on the class (``GenerationEngine.detect_and_localise_pathologies(...)``) or on
    an instance.
    """

    @abstractmethod
    def generate_model_output(self, system_prompt: str , image_context_prompt: str, user_prompt:Optional[str] = None, image_path = None):
        """Produce the model's answer for the given prompts.

        Args:
            system_prompt: High-level behavioural instructions for the model.
            image_context_prompt: Prompt carrying the detection/localisation data.
            user_prompt: The user's question, if any.
            image_path: Path to the chest X-ray, for engines that read the image directly.
        """
        pass

    @staticmethod
    def detect_and_localise_pathologies(image_path,pathology_detector,phrase_grounder,pathology_detection_threshold = 0.3, ignore_pathologies = None, do_not_localise = None):
        """Detect pathologies in an image and ask the phrase grounder where they are.

        Args:
            image_path: Image passed to both the detector and the grounder.
            pathology_detector: Object exposing ``detect_pathologies(image_path, threshold=...)``
                returning a mapping of pathology name to confidence.
            phrase_grounder: Object exposing ``get_pathology_lateral_position(image_path, pathologies)``.
            pathology_detection_threshold: Threshold forwarded to the detector.
            ignore_pathologies: Pathology names dropped from the detector output.
                ``None`` (the default) means nothing is dropped.
            do_not_localise: Pathology names kept in the detector output but not sent
                to the phrase grounder. ``None`` (the default) means all are localised.

        Returns:
            A ``(pathology_confidences, localised_pathologies)`` tuple: the filtered
            detector output and whatever the phrase grounder returned.
        """
        # The defaults are None; membership tests against None raise TypeError,
        # so normalise them to empty collections first.
        ignored = () if ignore_pathologies is None else ignore_pathologies
        not_localised = () if do_not_localise is None else do_not_localise

        pathology_confidences = pathology_detector.detect_pathologies(image_path, threshold = pathology_detection_threshold)

        pathology_confidences = {pathology: confidence for pathology, confidence in pathology_confidences.items() if pathology not in ignored}
        pathologies_to_localise = [pathology for pathology in pathology_confidences if pathology not in not_localised]

        localised_pathologies = phrase_grounder.get_pathology_lateral_position(image_path, pathologies_to_localise)

        return pathology_confidences, localised_pathologies

    @staticmethod
    def generate_prompts(detected_pathologies, localised_pathologies, examples = False):
        """Build the ``(system_prompt, image_context_prompt)`` pair for a generation engine.

        Args:
            detected_pathologies: Detector output; interpolated verbatim into the prompt.
            localised_pathologies: Phrase-grounder output; interpolated verbatim, or
                replaced by a fixed sentence when empty.
            examples: When true, a model example sentence is added to the prompt.

        Returns:
            A ``(system_prompt, image_context_prompt)`` tuple of strings.
        """
        # With nothing detected there is no data to describe, so use the short fallback prompts.
        if len(detected_pathologies) == 0:
            system_prompt = """ You are a helpful assistant, specialising in radiology and interpreting Chest X-rays. However there is insufficient data to make any comments on pathologies. Just mention it is possible there are no findings and this should be double checked by a radiologist."""
            image_context_prompt = """No pathologies were detected in the chest X-ray. The user will now interact with you."""
            return system_prompt, image_context_prompt

        if len(localised_pathologies) == 0:
            localised_pathologies = "No lateral positions could be confidently determined for any pathologies detected."

        system_prompt = """You are a helpful assistant, specialising in radiology and interpreting Chest X-rays. Please answer CONCISELY and professionally as a radiologist would. You should not be confident unless the data is confident. Use language that reflects the confidence of the data."""
        image_context_prompt = f"""
        You are being provided with data derived from analyzing a chest X-ray, which includes findings on potential pathologies alongside their confidence levels and, separately, possible lateral locations of these pathologies with their own confidence scores.
        This information comes from two specialized diagnostic tools.
        It's important to recognize how to interpret these datasets together when responding to queries:

        Pathology Detection with Confidence Scores:
        {detected_pathologies}

        Phrase Grounding Locations with Confidence Scores:
        {localised_pathologies}

        This separate dataset provides potential lateral locations for some of the detected pathologies, each with its own confidence score, indicating the model's certainty about each pathology's location.
        For instance, left Pleural Effusion is listed with a confidence score of 0.53, suggesting the location of Pleural Effusion to be on the left side with moderate confidence.

        When you interact with end users, remember:

        - A pathology and its lateral location (e.g., Pleural Effusion and left Pleural Effusion) are part of the same finding. The location attribute is an additional detail about where the pathology is likely found, not an indicator of a separate pathology. 
        - Synthesize the pathology detection and localization data. DO NOT TALK ABOUT THEM SEPERATELY. {"Here is a model example, 'Highly likely there is Pleural Effusion (detection confidence: 0.80), and it is possibly on the left side (localisation confidence: 0.53).'" if examples else ""}
        - Confidence scores from the pathology detection and phrase grounding tools are not directly comparable. They serve as indicators of confidence within their respective contexts of pathology detection and localisation.
        - A missing lateral location does not imply the absence of a pathology; it indicates the localisation could not be confidently determined.
        - If there is any discrepancy between the pathology detection and phrase grounding tools, detection data takes precedence as it more reliably identifies pathologies.

        It is important to factor medical knowledge and the specifics of each case, if supplied, into your responses. For example, pathologies located on both sides are called bilateral. Heart related observations are usually on the left/ middle.
        This understanding is crucial for accurately processing and responding to queries on the chest X-ray analysis. Structure your answers based on confidence and pathologies.
        Double check before you submit your response to ensure you have factored in all the data and followed my instructions carefully.
        You will now interact with the user.
        """

        return system_prompt, image_context_prompt

class Llama3Generation(GenerationEngine):
    """Generation engine backed by a Hugging Face text-generation pipeline for Llama 3 8B Instruct."""

    def __init__(self):
        """Load the Llama 3 instruct pipeline in bfloat16 on the device chosen by ``select_best_gpu()``."""
        self.model_id = "meta-llama/Meta-Llama-3-8B-Instruct"

        self.pipeline = transformers.pipeline(
            "text-generation",
            model=self.model_id,
            model_kwargs={"torch_dtype": torch.bfloat16},
            device_map= select_best_gpu() ,
            token=LLAMA3_INSTRUCT_ACCESS_TOKEN,
        )

    @staticmethod
    def generate_prompts(detected_pathologies, localised_pathologies, examples = False):
        """Build the Llama 3 flavoured ``(system_prompt, image_context_prompt)`` pair.

        Takes no instance state, so it is a static method and can be called on the
        class or on an instance.

        Args:
            detected_pathologies: Detector output; interpolated verbatim into the prompt.
            localised_pathologies: Phrase-grounder output; printed, then interpolated
                verbatim (or replaced by a fixed sentence when empty).
            examples: Accepted for signature parity with the base class; this prompt
                does not use it.

        Returns:
            A ``(system_prompt, image_context_prompt)`` tuple of strings.
        """
        if len(detected_pathologies) == 0:
            system_prompt = """ You are a helpful assistant, specialising in radiology and interpreting Chest X-rays. However there is insufficient data to make any comments on pathologies. Just mention it is possible there are no findings and this should be double checked by a radiologist."""
            image_context_prompt = """No pathologies were detected in the chest X-ray. The user will now interact with you."""
            return system_prompt, image_context_prompt

        print(localised_pathologies)
        if len(localised_pathologies) == 0:
            localised_pathologies = "No lateral positions could be confidently determined for any pathologies detected."


        system_prompt = """You are a helpful assistant, specialising in radiology and interpreting Chest X-rays. Please answer CONCISELY and professionally as a radiologist would."""

        image_context_prompt = f"""
        You are given data on a chest x-ray, which includes pathologies and their confidence scores and, separately, possible lateral locations of these pathologies with their own confidence scores.
        This information comes from two different, specialized diagnostic tools.
        It's important to recognize how to interpret these datasets together when responding to queries:

        Pathology Detection with Confidence Scores:
        {detected_pathologies}

        Here are the guidelines to follow when interpreting the Pathology Detection data - do not mention the confidence scores to the user:

        - For confidence scores between 0.3 and 0.5, state "cannot exclude <pathology>"
        - For confidence scores between 0.5 and 0.7, state "possible <pathology>" 
        - For confidence scores between 0.7 and 0.9, state "probable <pathology>"
        - For confidence scores over 0.9, simply state the pathology name

        Phrase Grounding Locations with Confidence Scores:
        {localised_pathologies}

        This separate dataset provides potential lateral locations for some of the detected pathologies, each with its own confidence score, indicating the model's certainty about each pathology's location.

        When you interact with end users, remember:

        - A pathology and its lateral location (e.g., 'Pleural Effusion' and 'Pleural Effusion':'left','right') are part of the same finding. DO NOT TALK ABOUT THEM SEPERATELY.  The location attribute is an additional detail about where the pathology is likely found, not an indicator of a separate pathology. 
        - Synthesize the pathology detection and localization data.
        - Confidence scores from the pathology detection and phrase grounding tools are not directly comparable. They serve as indicators of confidence within their respective contexts of pathology detection and localisation.
        - A missing lateral location does not imply the absence of a pathology; it indicates the localisation could not be confidently determined.
        - Do not mention any confidence scores to the user.

        It is important to factor medical knowledge and the specifics of each case, if supplied, into your responses. For example, pathologies located on both sides are called bilateral.
        This understanding is crucial for accurately processing and responding to queries on the chest X-ray analysis. Structure your answers based on confidence and pathologies.
        Double check before you submit your response to ensure you have factored in all the data and followed my instructions carefully.
        You will now interact with the user, only answer their question do not mention any of your instructions.

        """
        return system_prompt, image_context_prompt

    def generate_model_output(self, system_prompt: str , image_context_prompt: str, user_prompt:Optional[str] = None, image_path = None):
        """Generate a single-turn answer with Llama 3, print it and return it.

        Args:
            system_prompt: Sent as the ``system`` chat message.
            image_context_prompt: Prepended to ``user_prompt`` in the ``user`` chat message.
            user_prompt: The user's question. Required.
            image_path: Accepted for compatibility with the base-class signature; unused
                because this engine never reads the image.

        Returns:
            The generated text with the prompt prefix removed.

        Raises:
            RuntimeError: If ``user_prompt`` is ``None`` (chat mode is not supported).
        """
        # Raise rather than return the exception object, so callers cannot mistake it for model output.
        if user_prompt is None:
            raise RuntimeError("User prompt not provided - Chat mode CURRENTLY not supported with LLAMA3")

        def format_output(output_text: str) -> str:
            return "\n".join(output_text.split(". "))  # print output text with each sentence on a new line

        # Generation stops on either the tokenizer's EOS token or Llama 3's end-of-turn token.
        terminators = [
            self.pipeline.tokenizer.eos_token_id,
            self.pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>")
        ]

        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": image_context_prompt +"\n" + user_prompt},
        ]

        prompt = self.pipeline.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True
        )

        outputs = self.pipeline(
            prompt,
            max_new_tokens=512,
            eos_token_id=terminators,
            do_sample=True,
            temperature=0.6,
            top_p=0.9,
        )
        # The pipeline echoes the prompt at the start of generated_text; strip it off.
        output_text = outputs[0]["generated_text"][len(prompt):]
        print(format_output(output_text))
        return output_text
        


class CheXagentLanguageModelGeneration(GenerationEngine):
    """Generation engine that uses the CheXagent model as a text-only language model.

    No image is passed to the processor; the detection/localisation data arrives
    through ``image_context_prompt`` instead.
    """

    def __init__(self, processor, model, generation_config, device, dtype):
        """Store the already-loaded CheXagent components; nothing is loaded here."""
        self.processor = processor
        self.model = model
        self.generation_config = generation_config
        self.device = device
        self.dtype = dtype

    def generate_model_output(self, system_prompt: str , image_context_prompt: str, user_prompt:Optional[str] = None, image_path = None):
        """Generate a text-only answer with CheXagent, print it and return it.

        Args:
            system_prompt: Accepted for interface compatibility; not included in the model input.
            image_context_prompt: Wrapped in ``[INST]...[/INST]`` ahead of the user turn.
            user_prompt: The user's question. Required.
            image_path: Accepted for compatibility with the base-class signature; unused.

        Returns:
            The decoded model response.

        Raises:
            RuntimeError: If ``user_prompt`` is ``None`` (chat mode is not supported).
        """
        # Raise rather than return the exception object, so callers cannot mistake it for model output.
        if user_prompt is None:
            raise RuntimeError("User prompt not provided - Chat mode CURRENTLY not supported with CheXagent Language Model")
        inputs = self.processor(
            images=None, text=f"[INST]{image_context_prompt}[/INST] USER: <s>{user_prompt} ASSISTANT: <s>", return_tensors="pt"
        ).to(device=self.device)
        # NOTE(review): generate_written_output is not a standard Hugging Face generate() argument;
        # presumably the project's CheXagent model handles it - confirm.
        output = self.model.generate(**inputs, generation_config=self.generation_config,generate_written_output = True)[0]
        response = self.processor.tokenizer.decode(output, skip_special_tokens=True)
        print(response)
        return response
    

class CheXagentEndToEndGeneration(GenerationEngine):
    """Generation engine that feeds the raw image and the user prompt straight to CheXagent."""

    def __init__(self, processor, model, generation_config, device, dtype):
        """Store the already-loaded CheXagent components; nothing is loaded here."""
        self.processor = processor
        self.model = model
        self.generation_config = generation_config
        self.device = device
        self.dtype = dtype

    def generate_model_output(self, system_prompt: str , image_context_prompt: str, user_prompt:Optional[str],image_path: Path):
        """Generate an answer from the image itself, print it and return it.

        Args:
            system_prompt: Accepted for interface compatibility; not included in the model input.
            image_context_prompt: Accepted for interface compatibility; not included in the model input.
            user_prompt: The user's question, placed in the ``USER:`` turn.
            image_path: Path of the chest X-ray to open and convert to RGB.

        Returns:
            The decoded model response.
        """
        # Open inside a context manager so the file handle is released promptly;
        # convert() returns an independent in-memory copy that outlives the file.
        with Image.open(image_path) as raw_image:
            images = [raw_image.convert("RGB")]
        inputs = self.processor(
            images=images, text=f" USER: <s>{user_prompt} ASSISTANT: <s>", return_tensors="pt"
        ).to(device=self.device, dtype=self.dtype)
        output = self.model.generate(**inputs, generation_config=self.generation_config)[0]
        response = self.processor.tokenizer.decode(output, skip_special_tokens=True)
        print(response)
        return response


In [2]:
# Tools shared by every experiment cell below: the pathology detector and the phrase grounder.
pathology_detector = CheXagentVisionTransformerPathologyDetector(
    pathologies=Pathologies.CHEXPERT,
)
phrase_grounder = BioVilTPhraseGrounder(detection_threshold=0.2)

# Generation back ends. Both CheXagent engines reuse the components already
# held by the pathology detector rather than loading their own.
l3 = Llama3Generation()
cheXagent_lm = CheXagentLanguageModelGeneration(
    pathology_detector.processor,
    pathology_detector.model,
    pathology_detector.generation_config,
    pathology_detector.device,
    pathology_detector.dtype,
)
cheXagent_e2e = CheXagentEndToEndGeneration(
    pathology_detector.processor,
    pathology_detector.model,
    pathology_detector.generation_config,
    pathology_detector.device,
    pathology_detector.dtype,
)

GPU 0: NVIDIA GeForce RTX 4090, Free memory: 24177 MB
GPU 1: NVIDIA GeForce RTX 4090, Free memory: 24203 MB
Selecting GPU 1 with 24203 MB free memory, Device = cuda:1


Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]

CheXagent Model loaded


The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'BertTokenizer'. 
The class this function is called from is 'CXRBertTokenizer'.
You are using a model of type bert to instantiate a model of type cxr-bert. This is not supported for all configurations of models and can yield errors.
  return self.fget.__get__(instance, owner)()
Some weights of the model checkpoint at microsoft/BiomedVLP-BioViL-T were not used when initializing CXRBertModel: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
- This IS expected if you are initializing CXRBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing CXRBertModel from the checkpoint of a model that you expect to 

Using downloaded and verified file: /tmp/biovil_t_image_model_proj_size_128.pt
GPU 0: NVIDIA GeForce RTX 4090, Free memory: 23788 MB
GPU 1: NVIDIA GeForce RTX 4090, Free memory: 6991 MB
Selecting GPU 0 with 23788 MB free memory, Device = cuda:0
GPU 0: NVIDIA GeForce RTX 4090, Free memory: 23216 MB
GPU 1: NVIDIA GeForce RTX 4090, Free memory: 6991 MB
Selecting GPU 0 with 23216 MB free memory, Device = cuda:0


Downloading shards:   0%|          | 0/4 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


### Llama3 Prompt Engineering

In [3]:
def generate_l3_prompts(detected_pathologies, localised_pathologies, examples = True):
    """Build the ``(system_prompt, image_context_prompt)`` pair used for the Llama 3 agent.

    Args:
        detected_pathologies: Detector output; interpolated verbatim into the prompt.
        localised_pathologies: Phrase-grounder output; printed, then interpolated
            verbatim (or replaced by a fixed sentence when empty).
        examples: When true, a line of model example phrasings is added to the prompt.

    Returns:
        A ``(system_prompt, image_context_prompt)`` tuple of strings.
    """
    # Nothing detected: return the short fallback prompts straight away.
    if len(detected_pathologies) == 0:
        return (
            """ You are a helpful assistant, specialising in radiology and interpreting Chest X-rays. However there is insufficient data to make any comments on pathologies. Just mention it is possible there are no findings and this should be double checked by a radiologist.""",
            """No pathologies were detected in the chest X-ray. The user will now interact with you.""",
        )

    print(localised_pathologies)
    if len(localised_pathologies) == 0:
        localised_pathologies = "No lateral positions could be confidently determined for any pathologies detected."

    # Optional example line, resolved up front so the template below stays readable.
    if examples:
        example_hint = "Here are some model examples: 'Probable pleural effusion, located on the left' ; 'Possible bilateral edema' "
    else:
        example_hint = ""

    system_prompt = """You are a helpful assistant, specialising in radiology and interpreting Chest X-rays. Please answer CONCISELY and professionally as a radiologist would. Do not reference any confidence scores in your responses."""

    image_context_prompt = f"""
    You are given data on a chest x-ray, which includes pathologies and their confidence scores and, separately, possible lateral locations of these pathologies with their own confidence scores.
    This information comes from two different, specialized diagnostic tools.
    It's important to recognize how to interpret these datasets together when responding to queries:

    Pathology Detection with Confidence Scores:
    {detected_pathologies}

    Here are the guidelines to follow when interpreting the Pathology Detection data - do not mention the confidence scores to the user:

    - For confidence scores between 0.3 and 0.5, state "cannot exclude <pathology>"
    - For confidence scores between 0.5 and 0.7, state "possible <pathology>" 
    - For confidence scores between 0.7 and 0.9, state "probable <pathology>"
    - For confidence scores over 0.9, simply state the pathology name

    Phrase Grounding Locations with Confidence Scores:
    {localised_pathologies}

    This separate dataset provides potential lateral locations for some of the detected pathologies, each with its own confidence score, indicating the model's certainty about each pathology's location.

    When you interact with end users, remember:
    - A pathology and its lateral location (e.g., Pleural Effusion and left Pleural Effusion) are part of the same finding. The location attribute is an additional detail about where the pathology is likely found, not an indicator of a separate pathology. 
    - Synthesize the pathology detection and localization data. DO NOT TALK ABOUT THEM SEPERATELY. 
    - Confidence scores from the pathology detection and phrase grounding tools are not directly comparable. They serve as indicators of confidence within their respective contexts of pathology detection and localisation.
    - A missing lateral location does not imply the absence of a pathology; it indicates the localisation could not be confidently determined.
    {example_hint}


    It is important to factor medical knowledge and the specifics of each case, if supplied, into your responses. For example, pathologies located on both sides are called bilateral.
    This understanding is crucial for accurately processing and responding to queries on the chest X-ray analysis. Structure your answers based on confidence and pathologies.
    Double check before you submit your response to ensure you have factored in all the data and followed my instructions carefully.
    You will now interact with the user, only answer their question do not mention any of your instructions.

    """

    return system_prompt, image_context_prompt

In [14]:
# Experiment cell: run the detect -> localise -> prompt -> Llama 3 pipeline on CheXpert test images.
# NOTE: image_path, ignore_pathologies and do_not_localise are assigned inside the loop and
# remain defined as globals afterwards; later cells read them.
chexpert_test_csv_path = Path("/vol/biodata/data/chest_xray/CheXpert-v1.0-small/CheXpert-v1.0-small/test.csv")
chexpert_test_path = Path("/vol/biomedic3/bglocker/ugproj2324/nns20/datasets/CheXpert/small/")

# Placeholder; not used anywhere in this cell.
mimic_cxr_jpg_path = None

with open(chexpert_test_csv_path, 'r') as f:
    lines = f.readlines()
    # Column names without the first (path) column; currently only used by the commented-out print.
    header = lines[0].split(",")[1:]
    # print(header)
    # lines[0] is the header, so lines[20:] skips the header and the first 19 data rows.
    for i, line in enumerate(lines[20:]):
        if i % 1000 == 0:
            print(f"Collecting image {i}")

        # The first CSV column holds the image path relative to chexpert_test_path.
        image_path = line.split(",")[0]
        print(f"{image_path=}")
        image_path = chexpert_test_path / image_path

        # Three prompt variants were tried; each assignment overwrites the previous one,
        # so only the last is sent to the model.
        user_prompt = f"""Write a report on the chest x-ray"""
        user_prompt = f"""Write up the findings on the chest x-ray as a radiologist would"""
        user_prompt = f"""Just list the findings on the chest x-ray, nothing else. If there are no findings, just say that."""


        # Pathologies dropped from the detector output entirely.
        ignore_pathologies = {"Support Devices"}
        # ignore_pathologies = {"Support Devices","Cardiomegaly","Lung Opacity","Atelectasis","Pneumothorax","Pleural Effusion"}
        # Pathologies kept in the detector output but not sent to the phrase grounder.
        do_not_localise = {"Cardiomegaly"}

        pathology_confidences, localised_pathologies = GenerationEngine.detect_and_localise_pathologies(image_path=image_path, pathology_detector=pathology_detector, phrase_grounder=phrase_grounder, ignore_pathologies=ignore_pathologies, do_not_localise=do_not_localise)      
        l3_system_prompt, l3_image_context_prompt = generate_l3_prompts(pathology_confidences, localised_pathologies, examples = False)

        # print("System Prompt: \n")
        # print(f"{l3_system_prompt}\n")
        # print("Image Context Prompt:\n")
        # print(f"{l3_image_context_prompt}\n")
        
        print("L3 agent:") 
        l3.generate_model_output(l3_system_prompt, l3_image_context_prompt, user_prompt=user_prompt)
        
        # Disabled comparisons against the CheXagent engines.
        # NOTE(review): GenerationEngine.contextualise_model is not defined on the class shown in this notebook.
        # system_prompt, image_context_prompt = GenerationEngine.contextualise_model(image_path=image_path,pathology_detector=pathology_detector, phrase_grounder=phrase_grounder, examples=False)
        # print("CheXagent agent:")
        # cheXagent_lm.generate_model_output(system_prompt, image_context_prompt,user_prompt=user_prompt)
        
        # print("CheXagent:")
        # cheXagent_e2e.generate_model_output(system_prompt, image_context_prompt,user_prompt=user_prompt, image_path=image_path)
        print(f" -----------------")
        # Stop after the first processed image; remove this to run over the rest of the file.
        if i == 0: 
            break


Collecting image 0
image_path='test/patient64753/study1/view1_frontal.jpg'


Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.


{'Lung Opacity': [('left', 0.34), ('right', 0.24)], 'Atelectasis': [('left', 0.29), ('right', 0.3)], 'Pneumothorax': [('left', 0.26), ('right', 0.28)], 'Pleural Effusion': [('left', 0.33), ('right', 0.23)]}
L3 agent:
cannot exclude Cardiomegaly, possible Lung Opacity, possible Atelectasis, probable Pleural Effusion.
 -----------------


#### GEMINI 1.5 Flash Model

In [5]:
def generate_gemini_prompts(detected_pathologies, localised_pathologies, examples = True):
    """Build the ``(system_prompt, image_context_prompt)`` pair used for the Gemini model.

    Args:
        detected_pathologies: Detector output; interpolated verbatim into the prompt.
        localised_pathologies: Phrase-grounder output; printed, then interpolated
            verbatim (or replaced by a fixed sentence when empty).
        examples: When true, a line of model example phrasings is added to the prompt.

    Returns:
        A ``(system_prompt, image_context_prompt)`` tuple of strings.
    """
    # Nothing detected: return the short fallback prompts straight away.
    if len(detected_pathologies) == 0:
        return (
            """ You are a helpful assistant, specialising in radiology and interpreting Chest X-rays. However there is insufficient data to make any comments on pathologies. Just mention it is possible there are no findings and this should be double checked by a radiologist.""",
            """No pathologies were detected in the chest X-ray. The user will now interact with you.""",
        )

    print(localised_pathologies)
    if len(localised_pathologies) == 0:
        localised_pathologies = "No lateral positions could be confidently determined for any pathologies detected."

    # Optional example line, resolved up front so the template below stays readable.
    if examples:
        example_hint = "Here are some model examples: 'Probable pleural effusion, located on the left' ; 'Possible bilateral edema' "
    else:
        example_hint = ""

    system_prompt = """You are a helpful assistant, specialising in radiology and interpreting Chest X-rays. Please answer CONCISELY and professionally as a radiologist would."""

    image_context_prompt = f"""
    You are given data on a chest x-ray, which includes pathologies and their confidence scores and, separately, possible lateral locations of these pathologies with their own confidence scores.
    It's important to recognize how to interpret these datasets together when responding to queries:

    Pathology Detection with Confidence Scores:
    {detected_pathologies}

    Here are the guidelines to follow when interpreting the Pathology Detection data - do not mention the confidence scores to the user:

    - For confidence scores between 0.3 and 0.5, state "cannot exclude <pathology>"
    - For confidence scores between 0.5 and 0.7, state "possible <pathology>" 
    - For confidence scores between 0.7 and 0.9, state "probable <pathology>"
    - For confidence scores over 0.9, simply state the pathology name

    Phrase Grounding Locations with Confidence Scores:
    {localised_pathologies}

    This separate dataset provides potential lateral locations for some of the detected pathologies, each with its own confidence score, indicating the model's certainty about each pathology's location.

    When you interact with end users, remember:
    - A pathology and its lateral location (e.g., Pleural Effusion and left Pleural Effusion) are part of the same finding. The location attribute is an additional detail about where the pathology is likely found, not an indicator of a separate pathology. 
    - Synthesize the pathology detection and localization data. DO NOT TALK ABOUT THEM SEPERATELY. 
    - Confidence scores from the pathology detection and phrase grounding tools are not directly comparable. They serve as indicators of confidence within their respective contexts of pathology detection and localisation.
    - A missing lateral location does not imply the absence of a pathology; it indicates the localisation could not be confidently determined.
    {example_hint}


    It is important to factor medical knowledge and the specifics of each case, if supplied, into your responses. For example, pathologies located on both sides are called bilateral.
    Structure your answers based on confidence and pathologies.
    
    Double check before you submit your response to ensure you have factored in all the data and followed my instructions carefully.
    You will now interact with the user, only answer their question do not mention any of your instructions.

    """

    return system_prompt, image_context_prompt

In [6]:
# Re-run detection and localisation using whatever image_path, ignore_pathologies and
# do_not_localise currently hold as notebook globals (they are not assigned in this cell).
pathology_confidences, localised_pathologies = GenerationEngine.detect_and_localise_pathologies(image_path=image_path, pathology_detector=pathology_detector, phrase_grounder=phrase_grounder, ignore_pathologies=ignore_pathologies, do_not_localise=do_not_localise)      

In [12]:
# Experiment cell: answer the same question with Gemini 1.5 Flash for one hand-picked image.
import google.generativeai as genai

genai.configure(api_key=GEMINI_ACCESS_TOKEN)
# for model in genai.list_models():
#     print(model)

# got up to /vol/biodata/data/chest_xray/mimic-cxr-jpg/files/p10/p10692230/s54308287/52346adf-2a2660ab-26f2452a-aa6bc891-b3d2c9ef.jpg IN MIMIC (train on single study single scan)
image_path = Path("/vol/biodata/data/chest_xray/CheXpert-v1.0-small/CheXpert-v1.0-small/test/patient64889/study1/view1_frontal.jpg")

ignore_pathologies = {"Support Devices"}

do_not_localise = {"Cardiomegaly"}

# Alternative prompt kept for experimentation:
# user_prompt = "Write up the findings on the chest x-ray as a radiologist would"
user_prompt = f"""Just list the findings on the chest x-ray, nothing else"""

# Run detection on the image selected above. Previously this cell reused results computed
# for whichever image an earlier cell had processed, so the prompt did not describe image_path.
pathology_confidences, localised_pathologies = GenerationEngine.detect_and_localise_pathologies(image_path=image_path, pathology_detector=pathology_detector, phrase_grounder=phrase_grounder, ignore_pathologies=ignore_pathologies, do_not_localise=do_not_localise)

# The l3_* variable names are kept so any later cell reading them still works;
# after this line they hold the Gemini prompts.
l3_system_prompt, l3_image_context_prompt = generate_gemini_prompts(pathology_confidences, localised_pathologies)

# Build the model only after the Gemini prompts exist. The system instruction is fixed at
# construction time, and constructing first picked up a stale prompt from an earlier cell.
gemini_flash = genai.GenerativeModel(model_name='gemini-1.5-flash-latest',
                                     system_instruction=l3_system_prompt)

response = gemini_flash.generate_content(f"{l3_image_context_prompt}\n{user_prompt}")
print(response.text)

No findings. 

