Code to implement the multi-modal LLM (GPT-4o) with reasoning model (GPT-o1) for CXR analysis
This notebook is for method demonstration. If you want to replicate the experiment, please contact the author.

In [None]:
# uncomment the following lines to install the required packages

# ! pip install --upgrade pip
# ! pip install -q openai
# ! pip install -q langchain
# ! pip install -q transformers
# ! pip install torch
# ! pip install -q trainer
# ! pip install -q open_clip_torch

In [1]:
import langchain_core

In [1]:
from openai import OpenAI
import torch
from torch import nn
from transformers import AutoModel, AutoImageProcessor, AutoModelForImageClassification
from transformers import MobileViTFeatureExtractor, MobileViTForImageClassification
from transformers import Trainer
from PIL import Image
import torch.nn.functional as F
import pandas as pd
import requests
import argparse
import base64
import json
import os
import json

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

2025-01-10 15:13:52.538979: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-01-10 15:13:52.554616: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-01-10 15:13:52.559576: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-01-10 15:13:52.571460: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [None]:
from openai import OpenAI


# Set your OpenAI API key
api_key = "***"
client = OpenAI(api_key=api_key)

In [None]:
# the encoder for text and encoder for gpt image processing

def encoder(text, model="text-embedding-3-small"):
    text = text.replace("\n", " ")
    return client.embeddings.create(input = [text], model=model).data[0].embedding


def encode_image(image_path):
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode('utf-8')
    
encoded_dimention = 1536

In [None]:
# select the images to be processed

# the file list file is not provided, but it should contain the paths to the images and their corresponding crop paths
# the file should have a column 'fold' to indicate the fold number, 'crop_path' for cropped image paths, and 'whole_path' for whole image paths
# to access the image files, please conatct the author of the code or the dataset provider
data = pd.read_csv('file_list.csv')
val_data = data[data['fold'] == 1]
val_cropimages = val_data['crop_path'].tolist()
val_images = val_data['whole_path'].tolist()

cropimg_paths = val_cropimages[:5]
img_paths = val_images[:5]

print(f"crop image: {cropimg_paths}")
print(f"whole image: {img_paths}")

crop image: ['/root/GAN_models/covid_cv/crop_fold1/train/positive/1.2.826.0.1.3680043.10.474.419639.298706138782242084784855945802.png', '/root/GAN_models/covid_cv/crop_fold1/train/positive/chest_XR_205.png', '/root/GAN_models/covid_cv/crop_fold1/train/positive/1.2.826.0.1.3680043.10.474.419639.198791789717449124503990981583.png', '/root/GAN_models/covid_cv/crop_fold1/train/positive/chest_XR_1220.png', '/root/GAN_models/covid_cv/crop_fold1/train/positive/1.2.826.0.1.3680043.10.474.419639.692619205424417339465290726479.png']
whole image: ['/root/GAN_models/covid_cv/fold1/train/positive/1.2.826.0.1.3680043.10.474.419639.298706138782242084784855945802.png', '/root/GAN_models/covid_cv/fold1/train/positive/chest_XR_205.png', '/root/GAN_models/covid_cv/fold1/train/positive/1.2.826.0.1.3680043.10.474.419639.198791789717449124503990981583.png', '/root/GAN_models/covid_cv/fold1/train/positive/chest_XR_1220.png', '/root/GAN_models/covid_cv/fold1/train/positive/1.2.826.0.1.3680043.10.474.419639.6

In [6]:
# define helper function for GPT-4o model

def get_gpt_image_info(img_path, prompts, key, model):
    
    base64_image = encode_image(img_path)
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {key}"
    }
    
    
    role_instruct = """
    You are a radiologist responsible for reviewing the given chest X-ray and provide discription for it. \
    Your answer should cover the following contents. \
    Does the image present COVID-19 pneumonia ? \
    what is the mRALE(modified Radiographic Assessment of Lung Edema) score presented on the image ?\
    Describe the clinical findings ? \
    rate the pneumonia severity in one of the four levels: low, mild, moderate, and severe. \
    """
    
    # prompt = "Please retrieve the information from the input image."
    messages = [
        {"role": "system", 'content': role_instruct},
        {"role": "user", 'content': [
            {"type": "text", "text": prompt},
            {"type": "image_url", "image_url":{"url": f"data:image/jpeg;base64,{base64_image}"}}],
        "max_tokens": 200},
    ]
    
    payload = {
        "model": model,
        "messages": messages,
        "temperature": 0.3,
    }
    response = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=payload)
    
    if response.status_code == 200:
        response_json = response.json()
        # Extract the reply text
        reply = response_json['choices'][0]['message']['content']
        return reply
    else:
        print(f"Request failed with status code {response.status_code}")
        return 0

In [7]:
# extract key info from the GPT-4o response

def extract_label(prompt, client, model='gpt-4o-mini'):
    role_instruct = """
    You are a medical AI expert. Your task is to extract the key information from the response of a fine-tuned gpt-4o model.\
    Your final answer includes: \
    label: positive or negative \
    level: pneumonia severity from low, mild, moderate, severe \
    score: the mRALE score for lung edema from 0 to 24 \
    return the final response in a dict format.
    """
    messages = [
        {"role": "system", 'content': role_instruct},
        {"role": "user", 'content': f"response from the fine-tuned model: {prompt}"},
        ]
    
    response = client.chat.completions.create(
        model=model,
        messages=messages,
        temperature=0.3, # this is the degree of randomness of the model's output, 0 - 1.0
    )
    json_string = response.choices[0].message.content
    # Convert the string to a dictionary
    response_dict = json.loads(json_string)
    return response_dict
    

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# load the dinov2 regressor

model_name = "facebook/dinov2-base"
# the checkpoint of the trained regressor is saved to the regressor_model.zip, please unzip it (multiple files) to the checkpoint_dir
checkpoint_dir = 'path/to/your/checkpoint'  # replace with your checkpoint path
dinov2_feature_extractor = AutoImageProcessor.from_pretrained(model_name)
regressor = AutoModelForImageClassification.from_pretrained(checkpoint_dir, num_labels=1)
regressor.to(device)
regressor.eval()
print("model loaded to GPU")

model loaded to GPU


In [None]:
# load the mobile ViT classifier
model_name = "apple/mobilevit-small"
# the checkpoint of the trained classifier is saved to the classifier_model.zip, please unzip it (multiple files) to the checkpoint_dir
checkpoint_dir = 'path/to/your/classifier_checkpoint'  # replace with your checkpoint path
classifier = MobileViTForImageClassification.from_pretrained(checkpoint_dir)
classifier.eval()
classifier = classifier.to(device)
mobilevit_processor = MobileViTFeatureExtractor.from_pretrained(model_name)
print("model loaded to GPU")

model loaded to GPU




In [10]:
# helper function to convert the prediction by the classifier and regressor to text
# LLM can only process text; using numbers as input causes errors

def parse_classify_text(classify_result):
    predict_label, prob_dict = classify_result
    message = f"The image is predicted as COVID-19 penumonia {predict_label}. The probability of being negative is {prob_dict['negative']:.3f}, and the probability of being positive is {prob_dict['positive']:.3f}."
    return message

def parse_regression(pred_score):
    message = f"The image is predicted to have an mRALE score of {pred_score}."
    return message

In [11]:
# function to use classifier and regressor

import torch.nn.functional as F

def classify_image(image_path, classifier):
    # Preprocess the image
    class_dict = {0:'negative', 1:'positive'}
    image = Image.open(image_path).convert('RGB')
    image = image.resize((224, 224))
    inputs = mobilevit_processor(images=image, return_tensors="pt")
    
    # Make predictions
    with torch.no_grad():
        inputs = inputs.to(device)
        outputs = classifier(**inputs)
    logits = outputs.logits
    probabilities = F.softmax(logits, dim=1)
    predicted_class = torch.argmax(logits, dim=1).item()
    if not predicted_class in [0, 1]:
        predicted_class = 0
    class_probabilities = {class_dict[i]: probabilities[0, i].item() for i in range(len(class_dict))}
    return class_dict[predicted_class], class_probabilities

def predict_score(image_path, regressor):
    # Preprocess the image
    image = Image.open(image_path).convert('RGB')
    image = image.resize((224, 224))
    inputs = dinov2_feature_extractor(images=image, return_tensors="pt")
    
    with torch.no_grad():
        inputs = inputs.to(device)
        outputs = regressor(**inputs)
    score = outputs.logits.item()
    if score < 0:
        score = 0
    return round(score, 2)

In [12]:
# load the Biomed CLIP model
# load the Biomed CLIP - PubMedBERT model
# helper functions to prompt the Biomed CLIP for prediction

from open_clip import create_model_from_pretrained, get_tokenizer # works on open-clip-torch>=2.23.0, timm>=0.9.8

biomed_model, biomed_preprocessor = create_model_from_pretrained('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224')
biomed_tokenizer = get_tokenizer('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224')


def pubmed_predict(model, preprocessor, tokenizer, image_path, label_list):
    model.to(device)
    model.eval()
    file_name = os.path.basename(image_path)
    context_length = 256
    template = 'this is a photo of '
    image = preprocessor(Image.open(image_path)).to(device)
    texts = tokenizer([template + l for l in label_list], context_length=256).to(device)
    with torch.no_grad():
        image_features, text_features, logit_scale = model(image.unsqueeze(0), texts)
        logits = (logit_scale * image_features @ text_features.t()).detach().softmax(dim=-1)
        sorted_indices = torch.argsort(logits, dim=-1, descending=True)
        logits = logits.cpu().numpy()
        sorted_indices = sorted_indices.cpu().numpy()
        # remove the batch dim
        return file_name, sorted_indices[0], logits[0]
    

def parse_pubmed_result(file_name, sorted_indices, logits, label_list, threshold=0.1):
    info_dict = {}
    info_dict['filename'] = file_name
    info_dict['text'] = []
    info_dict['prob'] = []
    for idx in range(len(label_list)):
        if logits[sorted_indices[idx]] >= threshold:
            info_dict['text'].append(label_list[sorted_indices[idx]])
            info_dict['prob'].append(logits[sorted_indices[idx]])
        else:
            break
    return info_dict

# helper function for BioMed CLIP

def parse_pubmed_result(file_name, sorted_indices, logits, label_list, threshold=0.1):
    info_dict = {}
    info_dict['filename'] = file_name
    info_dict['text'] = []
    info_dict['prob'] = []
    for idx in range(len(label_list)):
        if logits[sorted_indices[idx]] >= threshold:
            info_dict['text'].append(label_list[sorted_indices[idx]])
            info_dict['prob'].append(logits[sorted_indices[idx]])
        else:
            continue
    
    filename = info_dict['filename']
    text_list = info_dict['text']
    score_list = info_dict['prob']
    output_text = "image information retrieved by Microsoft BiomedCLIP-PubMedBERT model: \n"
    for info_text, score in zip(text_list, score_list):
        pred = f"{info_text}: probability: {score:.3f} \n"
        output_text += pred
    return info_dict, output_text

In [None]:
# load the fine-tuned GPT-4o text-to-image model
# use different model IDs for different folds of cross-validation.
# The model is fine-tuned on the COVID-19 pneumonia dataset, and it can be used to describe the image based on the system instruction.
# The fine-tuned model is not available in the public domain, so you need to use your own model ID.

prompt = "describe the image based on the system instruction."
model_id = "****"

In [47]:
# classify and predict mRALE score
# need to convert numbers to text for downstream LLM processing

classify_prediction = classify_image(img_paths[1], classifier)
mobilevit_prompt = parse_classify_text(classify_prediction)

print(mobilevit_prompt)

pred_score = predict_score(cropimg_paths[1], regressor)
dino_prompt = parse_regression(pred_score)

print(dino_prompt)

The image is predicted as COVID-19 penumonia positive. The probability of being negative is 0.000, and the probability of being positive is 0.914.
The image is predicted to have an mRALE score of 11.75.


In [48]:
# call fine-tuned GPT-4o for image-to-text

gpt_prompt = get_gpt_image_info(img_paths[1], prompt, api_key, model_id)
print(gpt_prompt)

The image presents COVID-19 pneumonia with a mRALE score of 2. There is mild lung involvement with limited ground-glass opacities, affecting a small area of the lung. The pneumonia severity is mild.


In [49]:
# format the key info from the GPT-4o response

gpt_format_answer = extract_label(gpt_prompt, client)
print(gpt_format_answer)

{'label': 'positive', 'level': 'mild', 'score': 2}


In [9]:
# set the input key words list to be detected by BiomedCLIP - PubmedBERT

labels = [
    'chest X-ray',
    'chest CT',
    'chest MRI',
    'Centred Rotated Inspiration Adequate Expiratory image',
    'Centred Rotated Inspiration Hyperinflated',
    'TRACHEA Normal in position',
    'TRACHEA Shifted towards right',
    'TRACHEA Shifted towards left.',
    'MEDIASTINUM Normal outlines',
    'MEDIASTINUM Widened mediastinal shadow',
    'MEDIASTINUM Widened with lobulated outline',
    'AORTA Tortuous aorta',
    'AORTA Mural calcifications in aorta',
    'AORTA Tortuous aorta with mural calcifications',
    'HEART',
    'HEART Normal in transverse diameter',
    'HEART Enlarged in transverse diameter',
    'HEART Reduced in transverse diameter',
    'HEART almost tubular in shape',
    'HEART Cardiac diameter may not be represented accurately',
    'HILA LUNG FIELDS Normal pulmonary vasculature',
    'HILA LUNG FIELDS Oligaemic',
    'HILA LUNG FIELDS Plethoric',
    'HILA Bronchovascular markings prominent',
    'HILA Bronchovascular markings more prominent in upper lung fields — upper lobe diversion',
    'HILA Bronchovascular No focal or diffuse parenchymal lesion',
    'DIAPHRAGM Normal position & contour',
    'DIAPHRAGM Elevated on right side',
    'DIAPHRAGM Elevated on left side',
    'DIAPHRAGM Hump on right side',
    'DIAPHRAGM Hump on left side',
    'DIAPHRAGM Lowered on RL side',
    'DIAPHRAGM Flattened on RL side',
    'Costophrenic (CP) ANGLES Clear',
    'Costophrenic (CP) ANGLES Obscured on right side',
    'Costophrenic (CP) ANGLES Obscured on left side',
    'Costophrenic (CP) ANGLES Obscured on both sides',
    'BONES No fracture',
    'BONES lytic',
    'BONES sclerotic lesion',
    'SOFT TISSUE No calcification',
    'SOFT TISSUE obvious swelling',
    'IMPRESSION No detectable finding Inflammatory changes',
    'IMPRESSION Inflammatory & fibrotic changes',
    'IMPRESSION Inflammatory & bronchiectatic changes',
    'IMPRESSION Chronic inflammatory & fibrotic changes',
    'Bilateral pulmonary emphysematous changes [does not include mild disease]',
    'COVID-19',
    'covid',
    'Severe acute respiratory syndrome coronavirus 2',
    'SARS‑CoV‑2',
    'CXR CATEGORY Classic covid',
    'CXR CATEGORY Probable covid',
    'CXR CATEGORY Indeterminate for covid',
    'CXR CATEGORY Non-covid pathology',
    'CXR GRADING Mild',
    'CXR GRADING Moderate',
    'CXR GRADING Severe',
    'COMPARISON Stable',
    'COMPARISON Marginal improvement',
    'COMPARISON Progression',
    'COMPARISON Significant improvement',
]

In [10]:
len(labels)

62

In [52]:
torch.cuda.empty_cache()


file_name, sorted_indices, logits = pubmed_predict(biomed_model, biomed_preprocessor, biomed_tokenizer, img_paths[1], labels)
biomed_dict, biomed_prompt = parse_pubmed_result(file_name, sorted_indices, logits, labels)

In [53]:
print(biomed_dict)
print(biomed_prompt)

{'filename': 'chest_XR_205.png', 'text': ['CXR GRADING Mild', 'CXR CATEGORY Non-covid pathology'], 'prob': [0.44146204, 0.19367959]}
image information retrieved by Microsoft BiomedCLIP-PubMedBERT model: 
CXR GRADING Mild: probability: 0.441 
CXR CATEGORY Non-covid pathology: probability: 0.194 



In [35]:
# GPT o1-mini
# use this function to integrate all outputs into the GPT-o1 model for final reasoning

def reasoning_predict(gpt_4o_prompt, gpt_format_answer, dino_prompt, vit_prompt, biomed_prompt, client, model='o1-mini'):
    role_instruct = """
    You are a medical AI expert. Your task is to extract the key information from the predictions of multiple AI models including: \
    A fine-tuned Mobile ViT model to predict whether the chest x-ray image has COVID-19 pneumonia with prediction Accuracy of 0.86, F1 Score of 0.92, Sensitivity of 0.95 for positive, and Specificity of 0.54. \
    A fine-tuned Dinov2 model to predict the mRALE (Modified Radiographic Assessment of Lung Edema) score to evaluate the severity of the pneumonia as presented in the image. \
    If the mRALE is 0, the image should be classified as low for penumonia; if the mRALE is between 1 to 10, the image should be classified as mild penumonia; \
    if the mRALE is between 11 to 18, the image should be classified as moderate pnuemonia; if the mRALE is above 19, the image should be classified as severe penumonia; \
    The fine-tuned Dinov2 model has the prediction Mean Squared Error of 73.74, Mean Absolute Error of 6.85, Root Mean Squared Error of 8.59, R-Squared of -0.14, and Explained Variance Score of -0.13 for mRALE prediction. \
    A fine-tuned GPT-4o for image-to-text generation to predict chest x-ray images on the COVID-19 pneumonia and pneumonia severity. \
    For COVID-19 pneumonia prediction, the fine-tuned GPT-4o has Accuracy of 0.86, F1 Score of 0.92, Sensitivity of 0.97, and Specificity of 0.48. \ 
    For mRALE score prediction, the fine-tuned GPT-4o model has the prediction Mean Squared Error of 31.82, Mean Absolute Error of 3.69, Root Mean Squared Error of 5.64, R-Squared of 0.51, and Explained Variance Score of 0.54 for mRALE prediction. \
    A pretrained BiomedCLIP-PubMedBERT model to retrieve the diagnostic entities from the image. Note that a probability score greater than 0.1 is considered significant. \
    Your final decision should not be based on the prediction by any single model exclusively. If the predictions by different models have big difference, consider assign proper weights to each prediction to finalize your answer. \ 
    Your final answer must include the following content: \
    whether the image is COVID-19 pneumonia positive or COVID-19 pneumonia negative \
    The level of pneumonia severity from low, mild, moderate, or severe \
    pneumonia score: the mRALE score for lung edema from 0 to 24 \
    return the final response in the dict format with the key "covid_19" for the COVID-19 pneumonia, with the key "severity" for the pneumonia sererity level, \
    with the key "mRALE" for the numeric mRALE score, and explanation of the reasoning procedure respectively.
    """
    messages = [
        {"role": "user", 'content': role_instruct},
        {"role": "user", 'content': f"predicted information from the fine-tuned gpt-4o model: {gpt_4o_prompt}"},
        {"role": "user", 'content': f"predictions by the fine-tuned gpt-4o model: covid pneumonia: {gpt_format_answer['label']}, severity: {gpt_format_answer['level']}, mRALE score: {gpt_format_answer['score']}"},
        {"role": "user", 'content': f"predicted information from the fine-tuned Dinov2 model for mRALE score with mean absolut error (MAE) of 5 points: {dino_prompt}"},
        {"role": "user", 'content': f"predicted information from the fine-tuned Mobile ViT model for COVID-19 pneumonia mean accuracy of 0.86: {vit_prompt}"},
        {"role": "user", 'content': f"predicted information from the pretrained BiomedCLIP-PubMedBERT model to retrieve the diagnostic entities: {biomed_prompt}"},
        ]
    
    response = client.chat.completions.create(
        model=model,
        messages=messages
    )
    return response.choices[0].message.content

In [23]:
pred_text = reasoning_predict(gpt_prompt, gpt_format_answer, dino_prompt, mobilevit_prompt, biomed_prompt, client)
# reason_pred = reason_text_format(pred_text)
print(pred_text)
# print(reason_pred)

```json
{
  "covid_19": "positive",
  "severity": "mild",
  "mRALE": 6,
  "explanation": "After evaluating predictions from multiple models, the image is classified as COVID-19 pneumonia positive based on consistent outputs from both the Mobile ViT and GPT-4o models, which showed high probabilities for a positive diagnosis. For the mRALE score, the GPT-4o model predicted a score of 1 with a lower mean absolute error (MAE) of 3.69 compared to the Dinov2 model's prediction of 11.75 and MAE of 5. Given the lower error margin, greater weight was assigned to the GPT-4o prediction. Combining the scores with appropriate weighting results in an adjusted mRALE score of approximately 6, categorizing the pneumonia severity as mild."
}
```


# pick the fold index

In [None]:
# load the file list for cross-validation
# the file list file is not provided, but it should contain the paths to the images and their corresponding crop paths
# the file should have a column 'fold' to indicate the fold number, 'crop_path' for cropped image paths, and 'whole_path' for whole image paths
# to access the image files, please contact the author of the code or the dataset provider
file_data = pd.read_csv('file_list.csv')
val_data = file_data[file_data['fold'] == 5] # change the fold number to select the validation set
val_cropimages = val_data['crop_path'].tolist()
val_images = val_data['whole_path'].tolist()
print(len(val_images))
# cropimg_paths = val_cropimages[:2]
# img_paths = val_images[:2]

171


In [55]:
biomed_labels = [
    'chest X-ray',
    'chest CT',
    'chest MRI',
    'Centred Rotated Inspiration Adequate Expiratory image',
    'Centred Rotated Inspiration Hyperinflated',
    'TRACHEA Normal in position',
    'TRACHEA Shifted towards right',
    'TRACHEA Shifted towards left.',
    'MEDIASTINUM Normal outlines',
    'MEDIASTINUM Widened mediastinal shadow',
    'MEDIASTINUM Widened with lobulated outline',
    'AORTA Tortuous aorta',
    'AORTA Mural calcifications in aorta',
    'AORTA Tortuous aorta with mural calcifications',
    'HEART',
    'HEART Normal in transverse diameter',
    'HEART Enlarged in transverse diameter',
    'HEART Reduced in transverse diameter',
    'HEART almost tubular in shape'
    'HEART Cardiac diameter may not be represented accurately',
    'HILA LUNG FIELDS Normal pulmonary vasculature',
    'HILA LUNG FIELDS Oligaemic',
    'HILA LUNG FIELDS Plethoric',
    'HILA Bronchovascular markings prominent',
    'HILA Bronchovascular markings more prominent in upper lung fields — upper lobe diversion',
    'HILA Bronchovascular No focal or diffuse parenchymal lesion',
    'DIAPHRAGM Normal position & contour',
    'DIAPHRAGM Elevated on right side',
    'DIAPHRAGM Elevated on left side',
    'DIAPHRAGM Hump on right side',
    'DIAPHRAGM Hump on left side',
    'DIAPHRAGM Lowered on RL side',
    'DIAPHRAGM Flattened on RL side',
    'Costophrenic (CP) ANGLES Clear',
    'Costophrenic (CP) ANGLES Obscured on right side',
    'Costophrenic (CP) ANGLES Obscured on left side',
    'Costophrenic (CP) ANGLES Obscured on both sides',
    'BONES No fracture',
    'BONES lytic',
    'BONES sclerotic lesion',
    'SOFT TISSUE No calcification',
    'SOFT TISSUE obvious swelling',
    'IMPRESSION No detectable finding Inflammatory changes',
    'IMPRESSION Inflammatory & fibrotic changes',
    'IMPRESSION Inflammatory & bronchiectatic changes',
    'IMPRESSION Chronic inflammatory & fibrotic changes',
    'Bilateral pulmonary emphysematous changes [does not include mild disease]',
    'COVID-19',
    'covid',
    'Severe acute respiratory syndrome coronavirus 2',
    'SARS‑CoV‑2',
    'CXR CATEGORY Classic covid',
    'CXR CATEGORY Probable covid',
    'CXR CATEGORY Indeterminate for covid',
    'CXR CATEGORY Non-covid pathology',
    'CXR GRADING Mild',
    'CXR GRADING Moderate',
    'CXR GRADING Severe',
    'COMPARISON Stable',
    'COMPARISON Marginal improvement',
    'COMPARISON Progression',
    'COMPARISON Significant improvement',
]

In [56]:
data = pd.read_csv('/root/GAN_models/covid_cv/reason_val.csv')
data['pred_class'] = data['pred_class'].astype(str)
data['gpt_class'] = data['gpt_class'].astype(str)
data['gpt_text'] = data['gpt_text'].astype(str)
data['gpt_class'] = data['gpt_class'].astype(str)
data['gpt_level'] = data['gpt_level'].astype(str)
data['biomed_text'] = data['biomed_text'].astype(str)
data['reason_class'] = data['reason_class'].astype(str)
data['reason_level'] = data['reason_level'].astype(str)
data['explain'] = data['explain'].astype(str)
data['reason_text'] = data['reason_text'].astype(str)


for image_path, cropimage_path in zip(val_images, val_cropimages):
    image_name = os.path.basename(image_path)
    row_index = data.index[data['filename'] == image_name].tolist()
    if len(row_index) > 1:
        print("more than one image is identified.")
    idx = row_index[0]

    if not row_index:
        print(f"Image {image_name} cannot be found in the dataset.")
        continue
    
    # classify and predict mRALE score
    classify_prediction = classify_image(image_path, classifier)
    mobilevit_prompt = parse_classify_text(classify_prediction)
    pred_score = predict_score(cropimage_path, regressor)
    dino_prompt = parse_regression(pred_score)
    gpt_prompt = "describe the image based on the system instruction."
    gpt_prompt = get_gpt_image_info(image_path, gpt_prompt, api_key, model_id)
    format_answer = extract_label(gpt_prompt, client)
    torch.cuda.empty_cache()
    file_name, sorted_indices, logits = pubmed_predict(biomed_model, biomed_preprocessor, biomed_tokenizer, image_path, biomed_labels)
    biomed_dict, biomed_prompt = parse_pubmed_result(file_name, sorted_indices, logits, biomed_labels)
    pred_text = reasoning_predict(gpt_prompt, format_answer, dino_prompt, mobilevit_prompt, biomed_prompt, client)
    # json_text = pred_text.replace("```json", "").replace("```", "")
    # json_dict = json.loads(json_text)


    data.loc[idx, 'pred_class'] = classify_prediction[0]
    data.loc[idx, 'pred_score'] = pred_score
    data.loc[idx, 'gpt_text'] = gpt_prompt
    data.loc[idx, 'gpt_class'] = format_answer['label']
    data.loc[idx, 'gpt_score'] = format_answer['score']
    data.loc[idx, 'gpt_level'] = format_answer['level']
    data.loc[idx, 'biomed_text'] = biomed_prompt
    data.loc[idx, 'reason_text'] = pred_text
    # data.loc[idx, 'reason_class'] = json_dict['covid_19']
    # data.loc[idx, 'reason_score'] = json_dict['mRALE']
    # data.loc[idx, 'reason_level'] = json_dict['severity']
    # data.loc[idx, 'explain'] = json_dict['explanation']
    
data.to_csv('/root/GAN_models/covid_cv/reason_val.csv', index=False)
print("predictions saved.")

predictions saved.


In [None]:
from openai import OpenAI

# Set your OpenAI API key
# Replace '***' with your actual OpenAI API key
api_key = "***"
client = OpenAI(api_key=api_key)

In [None]:
# Load the CSV file and ensure all columns are of string type
# The CSV file should contain the predictions from the previous step
# The file should have columns: 'pred_class', 'gpt_class', 'gpt_text', 'gpt_level', 'biomed_text', 'reason_class', 'reason_level', 'explain', 'reason_text', 'reason_score'
# please contact the author of the code or the dataset provider to access the file
data = pd.read_csv('reason_val.csv')
data['pred_class'] = data['pred_class'].astype(str)
data['gpt_class'] = data['gpt_class'].astype(str)
data['gpt_text'] = data['gpt_text'].astype(str)
data['gpt_class'] = data['gpt_class'].astype(str)
data['gpt_level'] = data['gpt_level'].astype(str)
data['biomed_text'] = data['biomed_text'].astype(str)
data['reason_class'] = data['reason_class'].astype(str)
data['reason_level'] = data['reason_level'].astype(str)
data['explain'] = data['explain'].astype(str)
data['reason_text'] = data['reason_text'].astype(str)
data['reason_class'] = data['reason_class'].astype(str)
data['reason_score'] = data['reason_score'].astype(str)
data['reason_level'] = data['reason_level'].astype(str)
data['explain'] = data['explain'].astype(str)

In [17]:
# extract key info from the GPT-4o response

import json

def extract_reason(prompt, client, model='gpt-4o-mini'):
    role_instruct = """
    You are a medical AI expert. Your task is to extract the key information from the response from the gpt-o1 model.\
    Your final answer includes: \
    covid_19: positive or negative \
    severity: pneumonia severity from low, mild, moderate, severe \
    mRALE: the mRALE score for lung edema from 0 to 24 \
    explanation: the explanation text for how the decision is made. \
    Output a JSON object with the extracted information.".
    """
    messages = [
        {"role": "system", 'content': role_instruct},
        {"role": "user", 'content': f"response from the fine-tuned model: {prompt}"},
        ]
    
    response = client.chat.completions.create(
        model=model,
        messages=messages,
        temperature=0.3, # this is the degree of randomness of the model's output, 0 - 1.0
    )
    json_text = response.choices[0].message.content.strip()
    # Convert the string to a dictionary
    # data_dict = json.loads(json_text)
    return json_text



def format_reason(prompt):
    json_text = prompt.replace("```json", "").replace("```", "")
    data_dict = json.loads(json_text)
    return data_dict['covid_19'], data_dict['severity'], data_dict['mRALE'], data_dict['explanation']

In [18]:
explain_list = data['reason_text'].tolist()
prompt = explain_list[1]
response = extract_reason(prompt, client)
classify, level, score, explain = format_reason(response)
print(classify, level, score, explain)

positive mild 1 Multiple models indicate a positive diagnosis for COVID-19 pneumonia with high confidence. The Mobile ViT and GPT-4o models both predict COVID-19 pneumonia as positive, supported by high sensitivity and strong probability scores for positivity. Regarding pneumonia severity, the GPT-4o model predicts an mRALE score of 1, which falls into the 'mild' category based on the provided classification criteria (mRALE 1-10). Although the Dinov2 model predicts a higher mRALE score of 12.04, its performance metrics (Mean Squared Error of 73.74, R-Squared of -0.14) suggest lower reliability. Additionally, the BiomedCLIP-PubMedBERT model's CXR grading further supports a mild severity classification with the highest probability assigned to 'mild'. Therefore, considering the overall consensus and reliability of each model's predictions, the final assessment is a positive COVID-19 pneumonia diagnosis with mild severity and an mRALE score of 1.


In [22]:
data['reason_class'] = data['reason_class'].astype(str)
data['reason_level'] = data['reason_level'].astype(str)
data['explain'] = data['explain'].astype(str)
data['reason_score'] = data['reason_score'].astype(float)

for index, row in data.iterrows():
    input_text = row['reason_text']
    extract_json = extract_reason(input_text, client)
    classify, level, score, explain = format_reason(extract_json)
    data.loc[index, 'reason_class'] = classify
    data.loc[index, 'reason_score'] = score
    data.loc[index, 'reason_level'] = level
    data.loc[index, 'explain'] = explain
    
data.to_csv('/root/GAN_models/covid_cv/reason_val.csv', index=False)
print("Finished processing")

Finished processing


In [25]:
data.head()

Unnamed: 0.1,Unnamed: 0,label,filename,extent_right,density_right,extent_left,density_left,extent_right_numerical,density_right_numerical,extent_left_numerical,...,gpt_score,gpt_level,gpt_text,reason_class,reason_score,reason_level,reason_text,explain,whole_path,crop_path
0,0,positive,1.2.826.0.1.3680043.10.474.419639.298706138782...,,,,,0.0,0.0,0.0,...,1.0,low,This image presents COVID-19 pneumonia. The mR...,positive,5.0,mild,"```json\n{\n ""covid_19"": ""positive"",\n ""seve...",Multiple AI models were analyzed to determine ...,/root/GAN_models/covid_cv/fold1/train/positive...,/root/GAN_models/covid_cv/crop_fold1/train/pos...
1,1,positive,chest_XR_835.png,,,,,0.0,0.0,0.0,...,1.0,low,This image presents COVID-19 pneumonia with a ...,positive,1.0,mild,"```json\n{\n ""covid_19"": ""positive"",\n ""seve...",Multiple models indicate a positive diagnosis ...,/root/GAN_models/covid_cv/fold1/train/positive...,/root/GAN_models/covid_cv/crop_fold1/train/pos...
2,2,positive,1.2.826.0.1.3680043.10.474.419639.270940636091...,,,,,0.0,0.0,0.0,...,0.0,low,The image presents COVID-19 pneumonia with a m...,positive,0.0,low,"```json\n{\n ""covid_19"": ""positive"",\n ""seve...",The final assessment integrates predictions fr...,/root/GAN_models/covid_cv/fold1/train/positive...,/root/GAN_models/covid_cv/crop_fold1/train/pos...
3,3,positive,1.2.826.0.1.3680043.10.474.419639.246480861434...,,,,,0.0,0.0,0.0,...,1.0,mild,The image presents COVID-19 pneumonia with a m...,positive,1.0,mild,"```json\n{\n ""covid_19"": ""positive"",\n ""seve...",Both the fine-tuned GPT-4o and Mobile ViT mode...,/root/GAN_models/covid_cv/fold1/train/positive...,/root/GAN_models/covid_cv/crop_fold1/train/pos...
4,4,positive,chest_XR_611.png,,,,,0.0,0.0,0.0,...,0.0,low,This is a case of COVID-19 pneumonia with no l...,positive,4.0,mild,"```json\n{\n ""covid_19"": ""positive"",\n ""seve...",Multiple models indicate a positive diagnosis ...,/root/GAN_models/covid_cv/fold1/val/positive/c...,/root/GAN_models/covid_cv/crop_fold1/val/posit...


In [59]:
import json

json_text = extract_json.replace("```json", "").replace("```", "")
data_dict = json.loads(json_text)

In [60]:
data_dict

{'covid_19': 'positive',
 'severity': 'moderate',
 'mRALE': 11.8,
 'explanation': 'The final assessment classifies the image as COVID-19 pneumonia positive with moderate severity based on a weighted average mRALE score of 11.8, which falls within the moderate severity range.'}

In [36]:
def get_png_files(img_dir):
    png_files = []
    for subdir, _, files in os.walk(img_dir):
        for file in files:
            if file.endswith('.png'):
                # Get the absolute path and add it to the list
                png_files.append(os.path.abspath(os.path.join(subdir, file)))
    return png_files

In [40]:
val_dir = '/root/GAN_models/covid_cv/fold1/val'
val_imgs = get_png_files(val_dir)

In [41]:
print(len(val_imgs))

177


In [None]:
# use this part to reconstruct the outputs into structured data for performance evaluation

from PIL import Image
import pandas as pd

# load the validation set file
# the file should have a column 'filename' for the image file names, and 'cvt-1' for the predictions by the BiomedCLIP model
# please contact the author of the code or the dataset provider to access the file
data = pd.read_csv('val_chatgpt.csv')

# data['segformer-2'] = pd.NA
# data['cvt-1'] = data['cvt-1'].astype(str)
model_id = '****'  # replace with your fine-tuned model ID
pred_labels = []
for image_path in val_imgs:
    image_name = os.path.basename(image_path)
    row_index = data.index[data['filename'] == image_name].tolist()
    if not row_index:
        print(f'image {image_name} cannot be found')
    prompt1 = "given the image, detect if the image has covid-19 positive pneumonia or covid-19 negative with a confidence score from 0 to 1."
    prompt2 = "rate the pneumonia at one of the four level: low, mild, moderate, severe with a confidence score from 0 to 1."
    prompt3 = "predict the mRALE (from 0 to 24, zero means no pneumonia) score of the image"
    prompts = [prompt1, prompt2, prompt3]
    answer = ""
    for prompt in prompts:
        response = get_image_info(img_path, prompt, api_key, model_id)
        answer += response

    format_answer = extract_label(answer, client)
    print(format_answer)
    
    data.loc[row_index, 'gpt_covid'] = format_answer['label']
    data.loc[row_index, 'covid_conf'] = format_answer['label_score']
    data.loc[row_index, 'gpt_mrale'] = format_answer['score']
    data.loc[row_index, 'gpt_level'] = format_answer['level']
    data.loc[row_index, 'level_conf'] = format_answer['level_score']
    
data.to_csv('val_chatgpt.csv', index=False)
print("predictions saved.")

{'label': 'positive', 'label_score': 0.95, 'level': 'severe', 'level_score': 0.9, 'score': 18}


  data.loc[row_index, 'gpt_covid'] = format_answer['label']
  data.loc[row_index, 'gpt_level'] = format_answer['level']


{'label': 'positive', 'label_score': 0.95, 'level': 'severe', 'level_score': 0.9, 'score': 18}
{'label': 'positive', 'label_score': 0.95, 'level': 'severe', 'level_score': 0.9, 'score': 18}
{'label': 'positive', 'label_score': 0.95, 'level': 'severe', 'level_score': 0.9, 'score': 18}
{'label': 'positive', 'label_score': 0.95, 'level': 'severe', 'level_score': 0.9, 'score': 18}
{'label': 'positive', 'label_score': 0.95, 'level': 'severe', 'level_score': 0.9, 'score': 18}
{'label': 'positive', 'label_score': 0.95, 'level': 'severe', 'level_score': 0.95, 'score': 18}
{'label': 'positive', 'label_score': 0.95, 'level': 'severe', 'level_score': 0.9, 'score': 18}
{'label': 'positive', 'label_score': 0.95, 'level': 'severe', 'level_score': 0.95, 'score': 18}
{'label': 'positive', 'label_score': 0.95, 'level': 'severe', 'level_score': 0.9, 'score': 18}
{'label': 'positive', 'label_score': 1.0, 'level': 'severe', 'level_score': 0.95, 'score': 18}
{'label': 'positive', 'label_score': 0.95, 'leve

TypeError: can only concatenate str (not "int") to str