In [1]:
import os
os.environ['HF_HOME'] = '/vol/biomedic3/bglocker/ugproj2324/nns20/cxr-agent/.hf_cache' ## THIS HAS TO BE BEFORE YOU IMPORT TRANSFORMERS

import csv
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Optional, Tuple

import transformers
import torch

from my_secrets import LLAMA3_INSTRUCT_ACCESS_TOKEN
from agent_utils import select_best_gpu

from pathology_detector import PathologyDetector, CheXagentVisionTransformerPathologyDetector
from pathology_sets import Pathologies

from phrase_grounder import PhraseGrounder, BioVilTPhraseGrounder

class GenerationEngine(ABC):
    """Abstract base for backends that turn a chest X-ray into a text report.

    Subclasses must override ``generate_report``; the class itself cannot be
    instantiated because that method is abstract.

    NOTE(review): the abstract signature declares ``prompt`` and ``output_dir``,
    whereas ``Llama3Generation.generate_report`` takes ``pathology_detector``,
    ``phrase_grounder`` and ``prompt`` instead — the two should be reconciled.
    """

    @abstractmethod
    def generate_report(self, image_path: Path, prompt: Optional[str], output_dir: Optional[str]) -> str:
        """Return a report for the image at ``image_path``."""
        ...


class Llama3Generation(GenerationEngine):
    """Generation engine backed by a Hugging Face ``text-generation`` pipeline.

    The pipeline is stored on ``self.pipeline`` so callers can reach its
    tokenizer (e.g. to apply the chat template) and invoke it directly.
    """

    # Default checkpoint; kept as a class attribute so it is discoverable.
    DEFAULT_MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"

    def __init__(self, model_id: str = DEFAULT_MODEL_ID):
        """Build the text-generation pipeline.

        Args:
            model_id: Hugging Face model id to load. Defaults to Llama 3 8B
                Instruct, so existing ``Llama3Generation()`` calls are unchanged.

        The model weights are loaded in bfloat16 onto the device returned by
        ``select_best_gpu()``, authenticating with ``LLAMA3_INSTRUCT_ACCESS_TOKEN``.
        """
        self.model_id = model_id

        self.pipeline = transformers.pipeline(
            "text-generation",
            model=self.model_id,
            model_kwargs={"torch_dtype": torch.bfloat16},
            device_map=select_best_gpu(),
            token=LLAMA3_INSTRUCT_ACCESS_TOKEN,
        )

    def generate_report(self, image_path: Path, pathology_detector: PathologyDetector, phrase_grounder: Optional[PhraseGrounder] = None,  prompt: Optional[str] = "Write up a findings section based on these observations") -> str:
        """Placeholder: report generation is not implemented yet.

        Ignores all arguments and returns the string ``"TODO"``.

        NOTE(review): this signature differs from the abstract
        ``GenerationEngine.generate_report`` (which declares ``prompt`` and
        ``output_dir``); reconcile the two when implementing.
        """
        return "TODO"

In [2]:
# Load the LLM, the pathology detector and the phrase grounder.
# Each load prints a "Selecting GPU ..." line (see output below): the device appears
# to be chosen by free memory at that moment, so the order of these statements
# affects which GPU each model lands on. Keep the order unless placement is re-planned.
l3 = Llama3Generation()
pathology_detector = CheXagentVisionTransformerPathologyDetector(pathologies=Pathologies.CHEXPERT)
phrase_grounder = BioVilTPhraseGrounder(detection_threshold=0.5)

GPU 0: NVIDIA RTX 6000 Ada Generation, Free memory: 46315 MB
GPU 1: NVIDIA RTX 6000 Ada Generation, Free memory: 48643 MB
Selecting GPU 1 with 48643 MB free memory, Device = cuda:1


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


GPU 0: NVIDIA RTX 6000 Ada Generation, Free memory: 46315 MB
GPU 1: NVIDIA RTX 6000 Ada Generation, Free memory: 32769 MB
Selecting GPU 0 with 46315 MB free memory, Device = cuda:0


Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]

CheXagent Model loaded


The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'BertTokenizer'. 
The class this function is called from is 'CXRBertTokenizer'.
You are using a model of type bert to instantiate a model of type cxr-bert. This is not supported for all configurations of models and can yield errors.
  return self.fget.__get__(instance, owner)()
Some weights of the model checkpoint at microsoft/BiomedVLP-BioViL-T were not used when initializing CXRBertModel: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
- This IS expected if you are initializing CXRBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing CXRBertModel from the checkpoint of a model that you expect to 

Using downloaded and verified file: /tmp/biovil_t_image_model_proj_size_128.pt
GPU 0: NVIDIA RTX 6000 Ada Generation, Free memory: 29063 MB
GPU 1: NVIDIA RTX 6000 Ada Generation, Free memory: 32769 MB
Selecting GPU 1 with 32769 MB free memory, Device = cuda:1


In [84]:
def contextualise_model(generation_engine: GenerationEngine, image_path: Path, pathology_detector: PathologyDetector, phrase_grounder: Optional[PhraseGrounder] = None, pathology_detection_threshold: float = 0.5) -> Tuple[str, str]:
    """Run the vision tools on one X-ray and build the prompts that ground the LLM.

    Args:
        generation_engine: Currently unused; kept so existing call sites still work.
        image_path: Path to the chest X-ray image.
        pathology_detector: Tool producing per-pathology confidences. Required.
        phrase_grounder: Optional tool estimating lateral positions for the
            pathologies whose confidence exceeds ``pathology_detection_threshold``.
            When omitted, the context prompt contains detection results only.
        pathology_detection_threshold: Passed to the detector and used to select
            which pathologies are sent to the phrase grounder.

    Returns:
        ``(system_prompt, image_context_prompt)``, ready for ``generate_model_output``.

    Raises:
        RuntimeError: if ``pathology_detector`` is None.
    """
    # Raise (not return) so callers unpacking the tuple get a clear error.
    if pathology_detector is None:
        raise RuntimeError("Pathology detector not provided")

    pathology_confidences = pathology_detector.detect_pathologies(image_path, threshold = pathology_detection_threshold)
    print(pathology_confidences)

    #### PROMPT PIPELINE ###

    system_prompt = """You are a helpful assistant, specialising in radiology and interpreting Chest X-rays. Please answer CONCISELY and professionally as a radiologist would. You should not be confident unless the data is confident. Use language that reflects the confidence of the data."""

    # TODO: have list of prompts (pathology, phrase_grounder, etc) and build final user prompt by joining list with \n

    if phrase_grounder is not None:
        # get list of pathologies detected
        pathologies = [pathology for pathology, confidence in pathology_confidences.items() if confidence > pathology_detection_threshold]
        grounded_pathologies_confidences = phrase_grounder.get_pathology_lateral_position(image_path, pathologies)
        print(grounded_pathologies_confidences)
        if len(grounded_pathologies_confidences) == 0:
            grounded_pathologies_confidences = "No lateral positions could be confidently determined for any pathologies detected."

        image_context_prompt = f"""
        You are being provided with data derived from analyzing a chest X-ray, which includes findings on potential pathologies alongside their confidence levels and, separately, possible lateral locations of these pathologies with their own confidence scores.
        This information comes from two specialized diagnostic tools.
        It's important to recognize how to interpret these datasets together when responding to queries:

        Pathology Detection with Confidence Scores:
        {pathology_confidences}

        Phrase Grounding Locations with Confidence Scores:
        {grounded_pathologies_confidences}

        This separate dataset provides potential lateral locations for some of the detected pathologies, each with its own confidence score, indicating the model's certainty about each pathology's location.
        For instance, left Pleural Effusion is listed with a confidence score of 0.53, suggesting the location of Pleural Effusion to be on the left side with moderate confidence.
        
        When you interact with end users, remember:

        - A pathology and its lateral location (e.g., Pleural Effusion and left Pleural Effusion) are part of the same finding. The location attribute is an additional detail about where the pathology is likely found, not an indicator of a separate pathology. 
        - Synthesize the pathology detection and localization data. DO NOT TALK ABOUT THEM SEPERATELY. Here is a model example, "Highly likely there is Pleural Effusion (detection confidence: 0.80), and it is possibly on the left side (localisation confidence: 0.53)."
        - Confidence scores from the pathology detection and phrase grounding tools are not directly comparable. They serve as indicators of confidence within their respective contexts of pathology detection and localisation.
        - A missing lateral location does not imply the absence of a pathology; it indicates the localisation could not be confidently determined.
        - If there is any discrepancy between the pathology detection and phrase grounding tools, detection data takes precedence as it more reliably identifies pathologies.

        It is important to factor medical knowledge and the specifics of each case, if supplied, into your responses. For example, pathologies located on both sides are called bilateral. Heart related observations are usually on the left/ middle.
        
        This understanding is crucial for accurately processing and responding to queries on the chest X-ray analysis. Structure your answers based on confidence and pathologies.
        The end-user will now interact with you.
        """
    else:
        # No phrase grounder, so no localisation data: describe detection results only.
        # This branch guarantees image_context_prompt is always defined.
        image_context_prompt = f"""
        You are being provided with data derived from analyzing a chest X-ray, which includes findings on potential pathologies alongside their confidence levels.
        This information comes from a specialized diagnostic tool.

        Pathology Detection with Confidence Scores:
        {pathology_confidences}

        No localisation data is available for this X-ray, so do not state which side a pathology is on.

        When you interact with end users, remember:

        - The confidence scores indicate how certain the detection tool is that each pathology is present. Use language that reflects this confidence.

        It is important to factor medical knowledge and the specifics of each case, if supplied, into your responses.

        This understanding is crucial for accurately processing and responding to queries on the chest X-ray analysis. Structure your answers based on confidence and pathologies.
        The end-user will now interact with you.
        """

    return system_prompt, image_context_prompt

def generate_model_output(generation_engine: GenerationEngine, system_prompt: str , image_context_prompt: str, user_prompt:Optional[str] = None):

    def format_output(output_text: str) -> str:
        return "\n".join(output_text.split(". "))  # print output text with each sentence on a new line
    
    if user_prompt is not None:
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": image_context_prompt +"\n" + user_prompt},
        ]

        prompt = generation_engine.pipeline.tokenizer.apply_chat_template(
                messages, 
                tokenize=False, 
                add_generation_prompt=True
        )

        terminators = [
            generation_engine.pipeline.tokenizer.eos_token_id,
            generation_engine.pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>")
        ]

        outputs = generation_engine.pipeline(
            prompt,
            max_new_tokens=512,
            eos_token_id=terminators,
            do_sample=True,
            temperature=0.6,
            top_p=0.9,
        )
        output_text = outputs[0]["generated_text"][len(prompt):]
        print(format_output(output_text))
        return outputs[0]["generated_text"][len(prompt):]
    
    else:
        # setup chat loop
        chat_history = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": image_context_prompt},
        ]

        user_prompt = ""
        while user_prompt != "exit":
            # user_query = input("User: ")
            # messages.append({"role": "user", "content": user_query})
            prompt = generation_engine.pipeline.tokenizer.apply_chat_template(
                chat_history, 
                tokenize=False, 
                add_generation_prompt=True
            )

            terminators = [
                generation_engine.pipeline.tokenizer.eos_token_id,
                generation_engine.pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>")
            ]

            outputs = generation_engine.pipeline(
                prompt,
                max_new_tokens=512,
                eos_token_id=terminators,
                do_sample=True,
                temperature=0.6,
                top_p=0.9,
            )
            output_text = outputs[0]["generated_text"][len(prompt):]
            chat_history.append({"role": "assistant", "content": output_text})
            print(format_output(output_text))

            user_prompt = input("User: ")
            print(f"\n{user_prompt}")
            chat_history.append({"role": "user", "content": user_prompt})

        

In [None]:
chexpert_test_csv_path = Path("/vol/biodata/data/chest_xray/CheXpert-v1.0-small/CheXpert-v1.0-small/test.csv")
chexpert_test_path = Path("/vol/biomedic3/bglocker/ugproj2324/nns20/datasets/CheXpert/small/")

# Which CSV rows to process: FIRST_ROW indexes the file (row 0 is the header);
# MAX_IMAGES limits how many rows are taken from there (None = all remaining rows).
FIRST_ROW = 10
MAX_IMAGES = 1
# True: open an interactive chat per image (type "exit" to move on).
# False: send USER_PROMPT once per image and print the reply.
INTERACTIVE_CHAT = True
USER_PROMPT = "Write up a findings section based on these observations"

# Read the whole CSV first, using the csv module to handle quoted fields.
# The file is then closed before the (slow, interactive) model calls start.
with open(chexpert_test_csv_path, newline="") as f:
    csv_rows = list(csv.reader(f))
header = csv_rows[0][1:]  # label column names; column 0 is the image path

last_row = None if MAX_IMAGES is None else FIRST_ROW + MAX_IMAGES
for i, row in enumerate(csv_rows[FIRST_ROW:last_row]):
    if not row:
        continue  # blank line in the CSV
    if i % 1000 == 0:
        print(f"Collecting image {i}")

    image_path = row[0]
    print(f"{image_path=}")
    image_path = chexpert_test_path / image_path

    system_prompt, image_context_prompt = contextualise_model(l3, image_path, pathology_detector=pathology_detector, phrase_grounder=phrase_grounder)
    generate_model_output(l3, system_prompt, image_context_prompt, user_prompt=None if INTERACTIVE_CHAT else USER_PROMPT)

    print(" \n -------------------------------- \n")
