In [1]:
# Switch the kernel's working directory into the MiniGPT-4 folder; later cells use
# relative paths such as 'eval_configs/minigpt4_eval.yaml'.
# NOTE: the path is relative, so run this cell only once per kernel session.
%cd MiniGPT-4

/home/jupyter/opthollm/MiniGPT-4


### Import Necessary Packages
Import minigpt4 and necessary helper libraries

In [11]:
#@title Import
import argparse
import os
import random

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import gradio as gr

from minigpt4.common.config import Config
from minigpt4.common.dist_utils import get_rank
from minigpt4.common.registry import registry
from minigpt4.conversation.multi_img_conversation import Chat, CONV_VISION

# imports modules for registration
from minigpt4.datasets.builders import *
from minigpt4.models import *
from minigpt4.processors import *
from minigpt4.runners import *
from minigpt4.tasks import *
import os

import argparse as argparse

### Helper Methods
Define helper methods, including `encode_diagnosis`, which maps the model's free-text answer to a numeric label.

In [12]:
#@title Methods
def parse_args():
    """Parse the command-line options of the demo.

    Returns:
        argparse.Namespace with the attributes ``cfg_path``, ``gpu_id``,
        ``num_beams``, ``temperature``, ``english``, ``prompt_en``,
        ``prompt_zh`` and ``options``.
    """

    def _str_to_bool(value):
        """Convert a textual flag value such as 'true' or 'no' to a bool."""
        # ``type=bool`` would call bool() on the raw string, which is True for
        # every non-empty string -- including "False".
        lowered = value.strip().lower()
        if lowered in ("true", "t", "yes", "y", "1"):
            return True
        # The empty string stays falsy, as it was under ``type=bool``.
        if lowered in ("false", "f", "no", "n", "0", ""):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser(description="Demo")
    parser.add_argument("--cfg-path", required=True, help="path to configuration file.")
    parser.add_argument("--gpu-id", type=int, default=0, help="specify the gpu to load the model.")
    parser.add_argument("--num-beams", type=int, default=2, help="number of beams used for beam search during generation.")
    # The temperature is fractional (default 0.9), so it must be parsed as a float;
    # with type=int a value such as "0.5" is rejected by argparse.
    parser.add_argument("--temperature", type=float, default=0.9, help="sampling temperature used during generation.")
    parser.add_argument("--english", type=_str_to_bool, default=True, help="chinese or english")
    parser.add_argument("--prompt-en", type=str, default="can you describe the current picture?", help="Can you describe the current picture?")
    parser.add_argument("--prompt-zh", type=str, default="你能描述一下当前的图片？", help="Can you describe the current picture?")
    parser.add_argument(
        "--options",
        nargs="+",
        help="override some settings in the used config, the key-value pair "
        "in xxx=yyy format will be merged into config file (deprecate), "
        "change to --cfg-options instead.",
    )
    return parser.parse_args()


def setup_seeds(config):
    """Seed the Python, NumPy and PyTorch RNGs and pin the cuDNN flags.

    The seed is ``config.run_cfg.seed`` plus the value returned by ``get_rank()``.
    """
    base_seed = config.run_cfg.seed
    process_seed = base_seed + get_rank()

    # Same seed for every generator, applied in a fixed order.
    for seed_generator in (random.seed, np.random.seed, torch.manual_seed):
        seed_generator(process_seed)

    # Request deterministic cuDNN kernels and turn off benchmark mode.
    cudnn.deterministic = True
    cudnn.benchmark = False


# NOTE: this is still a keyword heuristic. Negations other than "non"/"not"/"no"
# directly before "glaucomatous" are not recognised.

# determines if the LLM thinks the image is glaucomatous or not based on the keywords in its answer
def encode_diagnosis(diagnosis):
    """Map a free-text diagnosis to a numeric label.

    Args:
        diagnosis: text produced by the model.

    Returns:
        1 if the text mentions "glaucomatous" without a directly preceding
        negation, 0 if it mentions "normal" as a whole word or only negated
        forms such as "non-glaucomatous" / "not glaucomatous", and 2 if
        neither keyword is found.
    """
    import re  # local import keeps this helper self-contained

    text = diagnosis.lower()

    # "non-glaucomatous", "nonglaucomatous", "not glaucomatous", "no glaucomatous ..."
    negated_glaucoma = re.compile(r"\b(?:non|not|no)[\s-]*glaucomatous\b")
    has_negation = negated_glaucoma.search(text) is not None
    # Remove negated mentions so that only affirmative ones can yield label 1.
    affirmative_text = negated_glaucoma.sub(" ", text)

    if re.search(r"\bglaucomatous\b", affirmative_text):
        return 1
    # Word boundaries keep "abnormal" from being read as "normal".
    if has_negation or re.search(r"\bnormal\b", affirmative_text):
        return 0
    return 2

# finds the true label of an image based on where it's stored in file path
def fetch_ground_truth(img_path):
    """Return 1 if a "/"-separated component of ``img_path`` is exactly "glaucoma", else 0.

    Args:
        img_path: path of the image, using "/" as the separator.
    """
    # A plain membership test replaces list.index() wrapped in a bare ``except``,
    # which would also have swallowed unrelated errors such as KeyboardInterrupt.
    path_components = img_path.split("/")
    return 1 if "glaucoma" in path_components else 0

# helper method that gets all files from a directory
def get_all_files(directory):
    """Return the paths of all files found under ``directory``, including subdirectories.

    Each path is the walked folder joined with the file name; the order is the
    order in which ``os.walk`` yields them.
    """
    return [
        os.path.join(folder, filename)
        for folder, _subfolders, filenames in os.walk(directory)
        for filename in filenames
    ]

def get_random_file(directory):
    """Return the path of one randomly chosen file from ``get_all_files(directory)``."""
    return random.choice(get_all_files(directory))

### Initialize Model

In [13]:
print('Initializing Chat')

# In the notebook the options are supplied directly as a Namespace instead of
# being read from the command line via parse_args().
args = argparse.Namespace(
    cfg_path='eval_configs/minigpt4_eval.yaml',
    gpu_id=0,
    num_beams=2,
    temperature=0.9,
    english=True,
    prompt_en='can you describe the current picture?',
    prompt_zh='你能描述一下当前的图片？',
    options=None,
)
cfg = Config(args)

# Look up the model class named by the config and build it on the chosen GPU.
model_config = cfg.model_cfg
model_config.device_8bit = args.gpu_id
model_cls = registry.get_model_class(model_config.arch)
target_device = f'cuda:{args.gpu_id}'
model = model_cls.from_config(model_config).to(target_device)

# The image processor is built from the cc_sbu_align "train" processor config.
vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
processor_cls = registry.get_processor_class(vis_processor_cfg.name)
vis_processor = processor_cls.from_config(vis_processor_cfg)

chat = Chat(model, vis_processor, device=target_device)
print('Initialization Finished')

print('Intializing Test')

Initializing Chat
Loading VIT
Loading VIT Done
Loading Q-Former
Loading Q-Former Done
Loading LLAMA


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading LLAMA Done
Load 4 training prompts
Prompt Example 
###Human: <Img><ImageHere></Img> Could you describe the contents of this image for me? ###Assistant: 
Load BLIP2-LLM Checkpoint: pretrained_minigpt4.pth
Initialization Finished
Intializing Test


### Test Few-Shot Learning

In [14]:
# Result store: one list per column; later cells append one entry to each list
# for every evaluated image.
few_shot_data = {
    column: []
    for column in ('img_path:', 'prediction:', 'ground_truth:', 'llm_message')
}

# define examples for few shot learning
# Later cells index each row of `examples` as row[0] (the example image passed to
# chat.embed_imgs) and row[1] (the diagnosis text interpolated into the prompt).

from chain_of_thought_imgs import img_descriptions
examples = img_descriptions.chain_of_thought_imgs

In [27]:
directory = 'RIM-ONE_DL_images/partitioned_randomly/training_set'

# Start every run with an empty image list and a fresh copy of the conversation template.
img_list = []
chat_state = CONV_VISION.copy()

# pick random training image to test on
image = get_random_file(directory)

# Example images are passed first and the query image last, matching the order of
# the <ImageHere> placeholders in the prompt.
# (presumably embed_imgs adds the embeddings to img_list -- confirm against Chat)
chat.embed_imgs([row[0] for row in examples], img_list)
chat.embed_imgs([image], img_list)

# ask the prompt that has multiple examples (few shot inference)
# The prompt is generated from `examples`, so it has one solved block per example
# row plus one open block for the query image. A hand-written prompt with a fixed
# number of blocks goes out of sync with the embedded images when `examples`
# gains or loses rows.
shot_header = "<Img><ImageHere></Img>\nPlease diagnose the image as glaucomatous or normal:\n\nDiagnosis:"
solved_shots = "".join(f"{shot_header} {row[1]}\n\n" for row in examples)
prompt = f"\n\n{solved_shots}{shot_header}\n"

chat.ask(prompt, conv=chat_state, img_list=img_list)

# have the model answer and display
llm_message = chat.answer(
        conv=chat_state,
        img_list=img_list,
        num_beams=args.num_beams,
        temperature=args.temperature,
        max_new_tokens=300,
        max_length=2000
    )[0]

# Record the result only after inference has succeeded, so the four lists in
# few_shot_data always keep the same length even if a run fails part-way.
few_shot_data['img_path:'].append(image)
few_shot_data['ground_truth:'].append(fetch_ground_truth(image))
few_shot_data['llm_message'].append(llm_message)
few_shot_data['prediction:'].append(encode_diagnosis(llm_message))

print(f"Img: {image} - Prediction: {few_shot_data['prediction:'][-1]} - Ground Truth: {few_shot_data['ground_truth:'][-1]} - LLM Message: {few_shot_data['llm_message'][-1]}")
  

In [28]:
# Debug check: splitting on the placeholder yields one more segment than there are
# '<ImageHere>' markers, so the list length shows how many image slots the prompt contains.
prompt.split('<ImageHere>')

['\n\n<Img>',
 "</Img>\nPlease diagnose the image as glaucomatous or normal:\n\nDiagnosis: Color fundus photography of both eyes in a 71-year old woman with severe open-angle glaucoma. Right eye (left image) demonstrates marked cupping of the optic nerve. Retinal vessels can be seen 'bayoneting' superonasal (arrow). Left eye (right image) demonstrates greater cupping than the right eye with near complete loss of neuroretinal rim and surrounding peripapillary atrophy. Retinal vessels are seen 'bayoneting' inferotemporal with brief loss of visualization upon entering the cup (arrow). Optic nerve heads are magnified in the lower right corner of each image.\n\n<Img>",
 '</Img>\nPlease diagnose the image as glaucomatous or normal:\n\nDiagnosis: The optic nerve shows moderate cupping but there is a prominent inferior notch. The normal sheen of the nerve fiber layer is absent in a distribution radiating temporally from this notch due to cellular loss. The notch and nerve fiber layer defect ca