# MiniGPT-4
MiniGPT-4 (13B-parameter version): zero-shot classification with 3 different prompts

In [None]:
%cd MiniGPT-4

### Imports

In [None]:
import argparse
import os
import random

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import gradio as gr

from minigpt4.common.config import Config
from minigpt4.common.dist_utils import get_rank
from minigpt4.common.registry import registry
from minigpt4.conversation.conversation import Chat, CONV_VISION

# imports modules for registration
from minigpt4.datasets.builders import *
from minigpt4.models import *
from minigpt4.processors import *
from minigpt4.runners import *
from minigpt4.tasks import *
import os

import argparse as argparse

### Helper Methods

In [None]:
def _str_to_bool(value):
    """Convert a command-line token such as ``"true"`` or ``"0"`` to a bool.

    ``type=bool`` does not work with argparse: ``bool("False")`` is ``True``
    because every non-empty string is truthy.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ("true", "t", "yes", "y", "1"):
        return True
    # The empty string stays falsy, as it was with the former ``type=bool``.
    if token in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args():
    """Parse the demo's command-line options from ``sys.argv``.

    Returns:
        argparse.Namespace with the attributes ``cfg_path``, ``gpu_id``,
        ``num_beams``, ``temperature``, ``english``, ``prompt_en``,
        ``prompt_zh`` and ``options``.
    """
    parser = argparse.ArgumentParser(description="Demo")
    parser.add_argument("--cfg-path", required=True, help="path to configuration file.")
    parser.add_argument("--gpu-id", type=int, default=0, help="specify the gpu to load the model.")
    parser.add_argument("--num-beams", type=int, default=2, help="number of beams used when generating the answer.")
    # The temperature is fractional (default 0.9), so it must be parsed as a
    # float; ``type=int`` rejected values such as "0.5".
    parser.add_argument("--temperature", type=float, default=0.9, help="sampling temperature used when generating the answer.")
    parser.add_argument("--english", type=_str_to_bool, default=True, help="chinese or english")
    parser.add_argument("--prompt-en", type=str, default="can you describe the current picture?", help="Can you describe the current picture?")
    parser.add_argument("--prompt-zh", type=str, default="你能描述一下当前的图片？", help="Can you describe the current picture?")
    parser.add_argument(
        "--options",
        nargs="+",
        help="override some settings in the used config, the key-value pair "
        "in xxx=yyy format will be merged into config file (deprecate), "
        "change to --cfg-options instead.",
    )
    return parser.parse_args()


def setup_seeds(config):
    """Seed the Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    The seed is ``config.run_cfg.seed`` plus the value returned by
    ``get_rank()``.
    """
    rank_seed = config.run_cfg.seed + get_rank()

    # Same seed for every generator, applied in the order random -> numpy -> torch.
    for seed_rng in (random.seed, np.random.seed, torch.manual_seed):
        seed_rng(rank_seed)

    # Disable cuDNN autotuning and request deterministic kernels.
    cudnn.deterministic = True
    cudnn.benchmark = False


### Keyword heuristic only. TODO: negated answers ("not glaucomatous") are still coded as glaucoma.
def encode_diagnosis(diagnosis):
    """Map a free-text model answer to a numeric diagnosis code.

    Args:
        diagnosis: the answer text produced by the model.

    Returns:
        1 if the text mentions "glaucoma" in any form (e.g. "Glaucoma",
        "glaucomatous"), 0 if it contains the standalone word "normal",
        and 2 if neither keyword is found. A glaucoma mention takes priority
        when both keywords appear.
    """
    import re  # local import: the notebook's import cell does not provide ``re``

    text = diagnosis.lower()

    # Match the whole word family: the prompts ask for 'Glaucoma' as well as
    # 'Glaucomatous', and "glaucomatous" contains "glaucoma".
    if 'glaucoma' in text:
        return 1
    # Whole-word match so that "abnormal" is not read as a healthy finding.
    if re.search(r"\bnormal\b", text):
        return 0
    return 2

def fetch_ground_truth(img_path):
    """Derive the ground-truth label from the folders in an image path.

    Args:
        img_path: path of the image, using "/" or "\\" as separator.

    Returns:
        1 if any path component is exactly "glaucoma", otherwise 0.
    """
    # Accept both separators so paths produced by os.walk on Windows are
    # split into components as well.
    components = img_path.replace("\\", "/").split("/")

    # A membership test replaces the former bare ``except:`` around
    # ``list.index``, which would also have swallowed unrelated errors.
    return 1 if "glaucoma" in components else 0

def get_all_files(directory):
    """Return the path of every file found under *directory*, recursively.

    Paths are built by joining each folder visited by ``os.walk`` with the
    file names it contains, in the order ``os.walk`` yields them.
    """
    return [
        os.path.join(parent, filename)
        for parent, _subdirs, filenames in os.walk(directory)
        for filename in filenames
    ]

def ask_model(chat_state, image, prompt, img_list):
    """Show one image to the model, ask one question and return the answer.

    Uses the module-level ``chat`` object and the generation settings stored
    in the module-level ``args`` (``num_beams``, ``temperature``).

    Args:
        chat_state: conversation object that accumulates the dialogue.
        image: the image passed to ``chat.upload_img``.
        prompt: the question text passed to ``chat.ask``.
        img_list: list handed to ``chat.upload_img`` and ``chat.answer``.

    Returns:
        The first element of the value returned by ``chat.answer``.
    """
    # upload_img and ask are methods of the Chat helper and take the
    # conversation as an argument; they do not exist on the conversation
    # object itself, so they must be called on ``chat``.
    chat.upload_img(image, chat_state, img_list)
    chat.ask(prompt, chat_state)
    llm_message = chat.answer(
        conv=chat_state,
        img_list=img_list,
        num_beams=args.num_beams,
        temperature=args.temperature,
        max_new_tokens=300,
        max_length=2000
    )[0]
    return llm_message

### Initialize the Model

In [None]:
print('Initializing Chat')
# The arguments are hard-coded in a Namespace here instead of being read
# through parse_args(); the two disabled alternatives are kept for reference.
#args = parse_args()
#args = preset_args
args = argparse.Namespace(cfg_path='eval_configs/minigpt4_eval.yaml', gpu_id=0, num_beams=2, temperature=0.9, english=True, prompt_en='can you describe the current picture?', prompt_zh='你能描述一下当前的图片？', options=None)
cfg = Config(args)

# Look up the model class named by the config's ``arch`` entry, build it from
# the config and move it to the GPU selected by ``args.gpu_id``.
model_config = cfg.model_cfg
model_config.device_8bit = args.gpu_id
model_cls = registry.get_model_class(model_config.arch)
model = model_cls.from_config(model_config).to('cuda:{}'.format(args.gpu_id))

# Build the image processor from the cc_sbu_align dataset's "train" processor
# config, then wrap model and processor in the Chat helper used by ask_model.
vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
chat = Chat(model, vis_processor, device='cuda:{}'.format(args.gpu_id))
print('Initialization Finished')

print('Intializing Test')

### Set the directory of the test images 

In [None]:
import random

# Root folder that is searched recursively for the test images.
# NOTE(review): fetch_ground_truth labels an image by looking for a path
# component named "glaucoma" - confirm the folder layout matches.
directory = 'RIM-ONE_DL_images/partitioned_randomly/test_set'

# Collect every file below the directory, then shuffle the list in place.
# NOTE(review): no seed is set in the visible code before this shuffle, so the
# processing order presumably differs between runs - confirm if that matters.
files = get_all_files(directory)
random.shuffle(files)

#print(files)



## Run Test

In [None]:
# Prompt 1: a single-sentence request for a binary diagnosis.
prompt1 = "Please diagnose the fundus image as either glaucomatous or normal."
# Prompt 2: gives the model an assistant persona and asks for a one-word answer
# that contains either 'Normal' or 'Glaucoma'.
prompt2 = """You are ophthoLLM, an ophthalmologist AI assistant that provides diagnoses on fundus 
    images in order to assist doctors. You understand that it is important to recommend consulting 
    a medical professional if there is any uncertainty, and before taking any action. You give a binary, 
    one-word diagnosis on images. You either state that the image is Glaucomatous if there are signs of 
    glaucoma, or Normal if the image appears healthy. Following these instructions, and making sure to only give 
    your answer as either “Glaucomatous” or “Normal,” please diagnose the image. Make sure to include either 'Normal' or 'Glaucoma' in your answer."""

# Prompt 3: gives the model a persona plus a list of fundus-image criteria for
# diagnosing glaucoma. The text is sent to the model verbatim, so its wording
# is part of the experiment and must not be edited casually.
prompt3 = """You are OphthoLLM, an ophthalmology expert AI that diagnosis Glaucoma. You have been provided a fundus image. 

Here are some guidelines for diagnosing Glaucoma using a fundus image. 

Optic Disc Size: The size of the optic disc should be evaluated, as variations can be a normal characteristic. However, an unusually small or large optic disc may indicate specific conditions or risk factors.

Cup-to-Disc Ratio: The cup-to-disc ratio measures the size of the cup (the depression in the center of the optic nerve head) relative to the size of the entire optic disc. An increased cup-to-disc ratio may suggest glaucomatous damage.

Cup Shape: The shape of the cup should be observed, as an asymmetric or vertically elongated cup can be an indication of glaucoma.

Optic Disc Rim: The appearance of the neuroretinal rim, which surrounds the cup, is assessed. In glaucoma, this rim tends to thin and become pale or grayish.

Rim Notching: Notching or notches in the neuroretinal rim, particularly in the inferior and superior regions, can be a characteristic sign of glaucoma.

Disc Hemorrhages: Presence of hemorrhages, small bleeding spots, at or around the optic nerve head may indicate glaucomatous damage.

Nerve Fiber Layer Defects: The doctor will look for thinning or gaps in the retinal nerve fiber layer (RNFL) around the optic nerve head, which is a common early sign of glaucoma.

Vascular Changes: Changes in the blood vessels, such as vascular narrowing, crossing defects, or bayoneting, may suggest glaucoma or other optic nerve disorders.

Optic Disc Color: The color of the optic nerve head is assessed, and any abnormal discoloration, such as pallor or hyperemia, may raise suspicion of optic nerve damage.
Peripapillary Atrophy: Doctors will look for areas of atrophy (thinning) of the retinal pigment epithelium around the optic disc, which can be associated with glaucoma.

Optic Nerve Head Excavation: The depth of the optic nerve head excavation or the cup depth is evaluated, as increased cupping can indicate glaucomatous damage.

Presence of Drusen: In elderly patients, the presence of drusen, small yellowish deposits in the optic disc, should be noted, as they may mimic glaucomatous changes.

Please answer my questions.
"""

# Evaluation order; the position (1-based) is used in the output file names.
prompts = [prompt1, prompt2, prompt3]

In [None]:
import pandas as pd

# Run every prompt over every test image and write one CSV per prompt with the
# image path, the encoded diagnosis, the ground-truth label and the raw answer.
for num, prompt in enumerate(prompts):
    data = {'img_path': [],
            'diagnosis' : [],
            'ground_truth': [],
            'llm_message': [],
            }
    for image in files:
        ground_truth = fetch_ground_truth(image)
        # Start from a fresh conversation so no image shares chat history.
        chat_state = CONV_VISION.copy()
        # Keyword arguments keep image and prompt in the slots ask_model
        # expects; they were previously passed positionally in swapped order.
        llm_message = ask_model(chat_state, image=image, prompt=prompt, img_list=[])
        diagnosis = encode_diagnosis(llm_message)

        data['img_path'].append(image)
        data['llm_message'].append(llm_message)
        data['diagnosis'].append(diagnosis)
        data['ground_truth'].append(ground_truth)
        print(f""" Image: {image} | Diagnosis: {diagnosis} | Label: {ground_truth} | LLM: {llm_message}""")

    # One result file per prompt, numbered from 1.
    data_output_dir = "zero_shot_data_prompt"+str(num+1)+".csv"
    zero_shot_data = pd.DataFrame(data)
    zero_shot_data.to_csv(data_output_dir)

In [None]:
# Metrics
#
# For each prompt, reload its result CSV, plot the confusion matrix,
# precision-recall and ROC curves, and write the summary metrics to
# "zero_shot_metrics_prompt<N>.csv".

import sklearn.metrics as metrics
import matplotlib.pyplot as plt

all_metrics = []

for num in range(3):
    data_dir = "zero_shot_data_prompt"+str(num+1)+".csv"
    zero_shot_data = pd.read_csv(data_dir)

    # encode_diagnosis writes 2 when an answer names neither class. The binary
    # metrics below only accept 0/1, so those rows are left out of the scores
    # and their count is reported separately instead.
    is_binary = zero_shot_data['diagnosis'].isin([0, 1])
    n_undetermined = int((~is_binary).sum())
    scored = zero_shot_data[is_binary]
    if scored.empty:
        print(f"prompt {num+1}: no 0/1 diagnoses in {data_dir}; skipping metrics")
        continue

    y_pred = scored['diagnosis']
    y_true = scored['ground_truth']

    metrics.ConfusionMatrixDisplay.from_predictions(y_true, y_pred)
    plt.show()
    metrics.PrecisionRecallDisplay.from_predictions(y_true, y_pred)
    plt.show()

    # Fixing the label order guarantees a 2x2 matrix even when one class is
    # absent, so the four-way unpacking cannot fail.
    tn, fp, fn, tp = metrics.confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()

    specificity = tn / (tn+fp)

    prec = tp/(tp+fp)

    rec = tp/(tp+fn)

    npv = tn/(fn + tn)

    accuracy = metrics.accuracy_score(y_true, y_pred)

    precision, recall, _ = metrics.precision_recall_curve(y_true, y_pred)
    prc = metrics.PrecisionRecallDisplay(precision=precision, recall=recall)
    prc.plot()
    plt.show()

    metrics.RocCurveDisplay.from_predictions(y_true, y_pred)
    plt.show()

    auroc = metrics.roc_auc_score(y_true, y_pred)
    auprc = metrics.average_precision_score(y_true, y_pred)

    f1_score = metrics.f1_score(y_true, y_pred)

    zero_shot_metrics = {
                        'name' : 'prompt '+str(num+1),
                        'accuracy': accuracy,
                        'precision': prec,
                        'recall': rec,
                        'auroc': auroc,
                        'auprc': auprc,
                        'f1_score': f1_score,
                        'tn': tn,
                        'fp': fp,
                        'fn': fn,
                        'tp': tp,
                        'negative predictive value': npv,
                        'specificity': specificity,
                        'undetermined': n_undetermined,
                        }

    all_metrics.append(zero_shot_metrics)

    zero_shot_metrics_df = pd.DataFrame(zero_shot_metrics, index=[0])
    metrics_output_dir = "zero_shot_metrics_prompt"+str(num+1)+".csv"
    # Write the DataFrame; the path string itself has no to_csv method.
    zero_shot_metrics_df.to_csv(metrics_output_dir)