In [None]:
# Switch the working directory to MiniGPT-4: later cells import the `minigpt4`
# package and use paths relative to the working directory (e.g. eval_configs/...).
# NOTE(review): presumably this is the cloned repository root — confirm.
%cd MiniGPT-4


# Imports 

In [None]:
#@title Import
import argparse
import os
import random

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import gradio as gr

from minigpt4.common.config import Config
from minigpt4.common.dist_utils import get_rank
from minigpt4.common.registry import registry
from minigpt4.conversation.conversation import Chat, CONV_VISION

# imports modules for registration
from minigpt4.datasets.builders import *
from minigpt4.models import *
from minigpt4.processors import *
from minigpt4.runners import *
from minigpt4.tasks import *
import os

import argparse as argparse


In [None]:
#@title Methods
def _str_to_bool(value):
    """Convert a command-line token to a real boolean.

    ``type=bool`` is unsuitable for argparse because ``bool("False")`` is True
    (any non-empty string is truthy). This helper accepts the usual spellings
    and raises ``argparse.ArgumentTypeError`` for anything else.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y"):
        return True
    # The empty string stays falsy, as it was under ``type=bool``.
    if lowered in ("", "0", "false", "f", "no", "n"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got {!r}".format(value))


def parse_args(argv=None):
    """Parse the demo's command-line options.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse falls back to ``sys.argv[1:]``, exactly as before. Passing
            an explicit list makes the function usable from a notebook, where
            ``sys.argv`` holds the kernel's own arguments.

    Returns:
        argparse.Namespace with ``cfg_path``, ``gpu_id``, ``num_beams``,
        ``temperature``, ``english``, ``prompt_en``, ``prompt_zh`` and
        ``options``.

    Raises:
        SystemExit: raised by argparse when ``--cfg-path`` is missing or a
            value cannot be converted.
    """
    parser = argparse.ArgumentParser(description="Demo")
    parser.add_argument("--cfg-path", required=True, help="path to configuration file.")
    parser.add_argument("--gpu-id", type=int, default=0, help="specify the gpu to load the model.")
    parser.add_argument("--num-beams", type=int, default=2, help="number of beams to use when generating the answer.")
    # The default is 0.9, so the option must be parsed as a float, not an int.
    parser.add_argument("--temperature", type=float, default=0.9, help="sampling temperature to use when generating the answer.")
    parser.add_argument("--english", type=_str_to_bool, default=True, help="chinese or english")
    parser.add_argument("--prompt-en", type=str, default="can you describe the current picture?", help="english prompt sent to the model.")
    parser.add_argument("--prompt-zh", type=str, default="你能描述一下当前的图片？", help="chinese prompt sent to the model.")
    parser.add_argument(
        "--options",
        nargs="+",
        help="override some settings in the used config, the key-value pair "
        "in xxx=yyy format will be merged into config file (deprecate), "
        "change to --cfg-options instead.",
    )
    args = parser.parse_args(argv)
    return args


def setup_seeds(config):
    """Seed the Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    The seed used is ``config.run_cfg.seed`` plus the value returned by
    ``get_rank()``.
    """
    base_seed = config.run_cfg.seed
    seed = base_seed + get_rank()

    # Apply the same seed to every random number generator in use.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # Disable cuDNN auto-tuning and request deterministic kernels.
    cudnn.deterministic = True
    cudnn.benchmark = False
    
def encode_diagnosis(diagnosis):
    """Map a free-text answer to an integer class code.

    Matching is case-insensitive, so answers such as "Normal" or
    "Glaucomatous" are recognised as well as their lower-case forms.

    Args:
        diagnosis: The answer text to classify.

    Returns:
        0 if the text contains "normal", otherwise 1 if it contains
        "glaucoma", otherwise 2 (no recognised keyword).
    """
    text = diagnosis.lower()
    # "normal" is tested first and by substring, so a text mentioning both
    # keywords (or a word like "abnormal") is still coded as 0.
    if 'normal' in text:
        return 0
    if 'glaucoma' in text:
        return 1
    return 2

def fetch_ground_truth(img_path):
    """Derive the label of an image from its path.

    Args:
        img_path: Path string of the image file.

    Returns:
        1 if one of the path components is exactly "glaucoma", else 0.
    """
    # Accept both "/" and "\\" separators so OS-native paths also work.
    components = img_path.replace("\\", "/").split("/")

    # Membership test instead of list.index(): index() raises ValueError when
    # the item is missing, it never returns -1.
    return 1 if "glaucoma" in components else 0

def get_all_files(directory):
    """Return the paths of all files found under ``directory``, recursively.

    Each path is the walked folder joined with the file name, in the order
    ``os.walk`` yields them.
    """
    return [
        os.path.join(parent, name)
        for parent, _subdirs, names in os.walk(directory)
        for name in names
    ]

In [None]:
print('Initializing Chat')
# The options are pinned directly in a Namespace rather than parsed from the
# command line.
args = argparse.Namespace(
    cfg_path='eval_configs/minigpt4_eval.yaml',
    gpu_id=0,
    num_beams=2,
    temperature=0.9,
    english=True,
    prompt_en='can you describe the current picture?',
    prompt_zh='你能描述一下当前的图片？',
    options=None,
)
cfg = Config(args)

# Look up the model class named in the config and place the model on the GPU.
model_config = cfg.model_cfg
model_config.device_8bit = args.gpu_id
model_cls = registry.get_model_class(model_config.arch)
model = model_cls.from_config(model_config).to(f'cuda:{args.gpu_id}')

# Build the visual processor from the cc_sbu_align "train" processor config.
vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
vis_processor_cls = registry.get_processor_class(vis_processor_cfg.name)
vis_processor = vis_processor_cls.from_config(vis_processor_cfg)

chat = Chat(model, vis_processor, device=f'cuda:{args.gpu_id}')
print('Initialization Finished')

print('Intializing Test')

In [None]:
# Gather every file under the test-set directory.
directory = 'RIM-ONE_DL_images/partitioned_randomly/test_set'
files = get_all_files(directory)

# Column-oriented results: one (initially empty) list per column.
data = {
    column: []
    for column in ('img_path', 'diagnosis', 'ground_truth', 'llm_message')
}

In [None]:
# Ask the model for a diagnosis of every test image and record the outcome.
for image in files:

    # Label derived from the image path.
    ground_truth = fetch_ground_truth(image)

    # Each image gets its own conversation state and image list.
    img_list = []
    chat_state = CONV_VISION.copy()
    chat.upload_img(image, chat_state, img_list)
    # NOTE: the backslash continuations keep the leading spaces of the
    # following lines inside the prompt text; left as-is on purpose so the
    # prompt sent to the model does not change.
    chat.ask("You are ophthoLLM, an ophthalmologist AI assistant that provides diagnoses on fundus \
    images in order to assist doctors. You understand that it is important to recommend consulting \
    a medical professional if there is any uncertainty, and before taking any action. You give a binary, \
    one-word diagnosis on images. You either state that the image is Glaucomatous if there are signs of \
    glaucoma, or Normal if the image appears healthy. Following these instructions, and making sure to only give \
    your answer as either “Glaucomatous” or “Normal,” please diagnose the image. Make sure to include either 'Normal' or 'Glaucoma' in your answer", chat_state)
    llm_message = chat.answer(
        conv=chat_state,
        img_list=img_list,
        num_beams=args.num_beams,
        temperature=args.temperature,
        max_new_tokens=300,
        max_length=2000
    )[0]

    diagnosis = encode_diagnosis(llm_message)

    # Append all four fields together, only after every step has succeeded.
    # If an iteration raises part-way through, no partial row is left behind
    # and the column lists keep equal lengths (required by pd.DataFrame).
    data['img_path'].append(image)
    data['llm_message'].append(llm_message)
    data['diagnosis'].append(diagnosis)
    data['ground_truth'].append(ground_truth)

    print(f""" Image: {image} | Diagnosis: {diagnosis} | Label: {ground_truth}""")

### Convert data to dataframe 

In [None]:
import pandas as pd

# Turn the column lists into a table and preview its first five rows.
zero_shot_data = pd.DataFrame(data=data)
zero_shot_data.head(n=5)

### Evaluate Metrics

In [None]:
import sklearn.metrics as metrics
