In [1]:
%cd /home/is/dwipraseetyo-a/NAS_HAI/Project/Qwen2.5-Omni

import pandas as pd
import numpy as np
import soundfile as sf
from pathlib import Path
import os, librosa, random, pickle, pydicom, requests, torch, re
from pydicom.datadict import keyword_for_tag
from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig
from IPython.display import Markdown, display
from openai import OpenAI
from PIL import Image
from tqdm import tqdm

import commons, const_variable
random.seed(42)

/home/ldap-users-2/dwipraseetyo-a/Project/Qwen2.5-Omni


In [2]:
DATA_PATH = "/home/is/dwipraseetyo-a/NAS_HAI/Datasets/cidrz"

In [5]:
system_prompt = '''
You are an advanced medical assistant AI specialized in analyzing and diagnosing clinical conditions, capable of perceiving auditory and visual inputs. 
You can interpret and reason over various medical inputs, including auditory inputs, visual inputs, and patient symptoms, individually or in combination, depending on what is provided. 
Your task is to analyze the given input, explain your reasoning, and give a possible diagnosis. 
Always respond in the following format:

## ⚠️ Points to Review and Disclaimer
<If no auditory or visual input is provided>

## 🧠 Overview
<Diagnosis sentence>

## 📋 Observations
**Chest X-ray:**
<Your explanation based on the relevant visual input>"

**Symptoms:**
<Your explanation based on the input symptoms>"

**Audio:**
<Your explanation based on the input audio>"
'''

In [12]:
for split in ['train', 'dev']:
    df = pd.read_csv(f"{DATA_PATH}/metadata_cut_processed.csv.{split}")
    df_llm_symptoms = ( pd.read_csv(f"datas/reasoning/symptoms/gpt4_symptoms.csv.{split}").groupby('barcode', group_keys=False).apply(lambda x: x.sample(1), include_groups=True).reset_index(drop=True) ) 
    df_llm_images = ( pd.read_csv(f"datas/reasoning/xray/medgemma_xray_formatted.csv.{split}").groupby('path_file_image', group_keys=False).apply(lambda x: x.sample(1), include_groups=True).reset_index(drop=True) )
    df = pd.merge(df, df_llm_symptoms, on='barcode', how='left')
    df = pd.merge(df, df_llm_images, on='path_file_image', how='left')
    df = df.rename(columns={'coughdur': 'cough_duration', 'ngtsweats': 'night_sweets', 'weightloss': 'weight_loss', 'body_wt': 'body_weight'})

    instruct_array = []
    for now_row in tqdm(df.itertuples(), desc=f"Processing {split}", total=len(df)):
        row_dict = now_row._asdict()
        now_audiopath = "/home/is/dwipraseetyo-a/NAS_HAI/Datasets/cidrz/" +  now_row.path_file_audio
        now_imgaepath = "/home/is/dwipraseetyo-a/NAS_HAI/Datasets/cidrz/" +  now_row.path_file_image

        for i in range(6):
            array_df = [None, None]
            question, modalities = commons.unique_modalities_generator(const_variable.prompt_templates)
            question += ". "
            answer = commons.generate_tb_response(modalities, now_row.llm_analyze_symptoms, now_row.llm_analyze_image, positive=(now_row.ground_truth_tb == 1))
            
            if now_row.path_file_audio == 'Unknown':
                modalities = [m for m in modalities if m != "audio"]

            array_df = [None, None]
            if "symptoms" in modalities:
                row_dict = now_row._asdict()
                selected_feats = random.sample(const_variable.columns_soundfeat, k=random.randint(3, len(const_variable.columns_soundfeat)))
                symptom_descriptions = ", ".join(
                    f"{feat.replace('_', ' ')} is {row_dict[feat]}"
                    for feat in selected_feats
                    if row_dict.get(feat) != "Unknown"
                )
                if symptom_descriptions:
                    question += f" Also, the patient presents with: {symptom_descriptions}."

            if "audio" in modalities:
                array_df[0] = now_audiopath

            if "xray" in modalities:
                array_df[1] = now_imgaepath
                xray_descriptions = ", ".join(
                    f"{feat.replace('_', ' ')} is {row_dict[feat]}"
                    for feat in const_variable.columns_imagefeat
                    if row_dict.get(feat) != "Unknown"
                )
                if xray_descriptions:
                    question += f" Additional chest x-ray metadata include: {xray_descriptions}."

            question = question.strip()
            question = question.rstrip(",.")
            if not question.endswith("."):
                question += "."

            temp_instruct = {"messages": [
                {"role": "system",
                    "content": [
                        {"type": "text", "text": system_prompt}
                    ],
                },
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": question},
                    ],
                },
                {"role": "assistant", "content": [
                    {"type": "text", "text": answer}]},
            ]}
            if array_df[0] != None:
                temp_instruct["messages"][1]['content'].append({"type": "audio", "audio": array_df[0]})
            if array_df[1] != None:
                temp_instruct["messages"][1]['content'].append({"type": "image", "image":array_df[1]})
            instruct_array.append(temp_instruct)
    
    with open(f"datas/instruct.pkl.{split}", "wb") as f:
        pickle.dump(instruct_array, f)

  df_llm_symptoms = ( pd.read_csv(f"datas/reasoning/symptoms/gpt4_symptoms.csv.{split}").groupby('barcode', group_keys=False).apply(lambda x: x.sample(1), include_groups=True).reset_index(drop=True) )
  df_llm_images = ( pd.read_csv(f"datas/reasoning/xray/medgemma_xray_formatted.csv.{split}").groupby('path_file_image', group_keys=False).apply(lambda x: x.sample(1), include_groups=True).reset_index(drop=True) )
Processing train: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2151/2151 [00:00<00:00, 5277.76it/s]
  df_llm_symptoms = ( pd.read_csv(f"datas/reasoning/symptoms/gpt4_symptoms.csv.{split}").groupby('barcode', group_keys=False).apply(lambda x: x.sample(1), include_groups=True).reset_index(drop=True) )
  df_llm_images = ( pd.read_csv(f"datas/reasoning/xray/medgemma_xray_formatted.csv.{split}").groupby('path_file_image', group_keys=False).apply(lam

In [7]:
df = pd.read_csv(f"{DATA_PATH}/metadata_cut_processed.csv.test")
df_llm_symptoms = ( pd.read_csv(f"datas/reasoning/symptoms/o4-mini_symptoms.csv.test").groupby('barcode', group_keys=False).apply(lambda x: x.sample(1), include_groups=True).reset_index(drop=True) ) 
df_llm_images = ( pd.read_csv(f"datas/reasoning/xray/medgemma_xray_formatted.csv.test").groupby('path_file_image', group_keys=False).apply(lambda x: x.sample(1), include_groups=True).reset_index(drop=True) )
df = pd.merge(df, df_llm_symptoms, on='barcode', how='left')
df = pd.merge(df, df_llm_images, on='path_file_image', how='left')
df = df.rename(columns={'coughdur': 'cough_duration', 'ngtsweats': 'night_sweets', 'weightloss': 'weight_loss', 'body_wt': 'body_weight'})

instruct_array_positive = []
instruct_array_negative = []

for now_row in tqdm(df.itertuples(), desc=f"Processing Datasets", total=len(df)):
    row_dict = now_row._asdict()
    now_audiopath = "/home/is/dwipraseetyo-a/NAS_HAI/Datasets/cidrz/" +  now_row.path_file_audio
    now_imgaepath = "/home/is/dwipraseetyo-a/NAS_HAI/Datasets/cidrz/" +  now_row.path_file_image

    modalities = ["audio", "xray", "symptoms"]
    key = tuple(modalities)
    if key in const_variable.prompt_templates:
        question = random.choice(const_variable.prompt_templates[key]) + ". "
    answer = commons.generate_tb_response(modalities, now_row.llm_analyze_symptoms, now_row.llm_analyze_image, positive=(now_row.ground_truth_tb == 1))

    if now_row.path_file_audio == 'Unknown':
        modalities = [m for m in modalities if m != "audio"]

    array_df = [None, None]
    if "symptoms" in modalities:
        row_dict = now_row._asdict()
        selected_feats = random.sample(const_variable.columns_soundfeat, k=random.randint(3, len(const_variable.columns_soundfeat)))
        symptom_descriptions = ", ".join(
            f"{feat.replace('_', ' ')} is {row_dict[feat]}"
            for feat in selected_feats
            if row_dict.get(feat) != "Unknown"
        )
        if symptom_descriptions:
            question += f" Also, the patient presents with: {symptom_descriptions}."

    if "audio" in modalities:
        array_df[0] = now_audiopath

    if "xray" in modalities:
        array_df[1] = now_imgaepath
        xray_descriptions = ", ".join(
            f"{feat.replace('_', ' ')} is {row_dict[feat]}"
            for feat in const_variable.columns_imagefeat
            if row_dict.get(feat) != "Unknown"
        )
        if xray_descriptions:
            question += f" Additional chest x-ray metadata include: {xray_descriptions}."

    question = question.strip()
    question = question.rstrip(",.")
    if not question.endswith("."):
        question += "."

    temp_instruct = {"messages": [
        {"role": "system",
            "content": [
                {"type": "text", "text": system_prompt}
            ],
        },
        {
            "role": "user",
            "content": [
                {"type": "text", "text": question},
            ],
        },
        {"role": "assistant", "content": [
            {"type": "text", "text": answer}]},
    ]}
    if array_df[0] != None:
        temp_instruct["messages"][1]['content'].append({"type": "audio", "audio": array_df[0]})
    if array_df[1] != None:
        temp_instruct["messages"][1]['content'].append({"type": "image", "image":array_df[1]})

    if now_row.ground_truth_tb == 1:
        instruct_array_positive.append(temp_instruct)
    elif now_row.ground_truth_tb == 0:
        instruct_array_negative.append(temp_instruct)
        
with open(f"datas/positive_instruct.pkl.test", "wb") as f:
    pickle.dump(instruct_array_positive, f)

with open(f"datas/negative_instruct.pkl.test", "wb") as f:
    pickle.dump(instruct_array_negative, f)

  df_llm_symptoms = ( pd.read_csv(f"datas/reasoning/symptoms/o4-mini_symptoms.csv.test").groupby('barcode', group_keys=False).apply(lambda x: x.sample(1), include_groups=True).reset_index(drop=True) )
  df_llm_images = ( pd.read_csv(f"datas/reasoning/xray/medgemma_xray_formatted.csv.test").groupby('path_file_image', group_keys=False).apply(lambda x: x.sample(1), include_groups=True).reset_index(drop=True) )
Processing Datasets: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 270/270 [00:00<00:00, 22871.55it/s]


In [None]:
def clear_memory():
    # Delete variables if they exist in the current global scope
    if "inputs" in globals():
        del globals()["inputs"]
    if "model" in globals():
        del globals()["model"]
    if "processor" in globals():
        del globals()["processor"]
    if "trainer" in globals():
        del globals()["trainer"]
    if "peft_model" in globals():
        del globals()["peft_model"]
    if "bnb_config" in globals():
        del globals()["bnb_config"]
    time.sleep(2)

    # Garbage collection and clearing CUDA memory
    gc.collect()
    time.sleep(2)
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    time.sleep(2)
    gc.collect()
    time.sleep(2)

    print(f"GPU allocated memory: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
    print(f"GPU reserved memory: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")


clear_memory()