## Notebook to filter data for Dr-LLaVA Experiment

In [1]:
import os
import pandas as pd
import numpy as np
import json
from tqdm import tqdm
from sklearn.model_selection import train_test_split

In [2]:
all_ecg_df = pd.read_csv('../../data/mimic-acute-mi.csv')

In [3]:
all_ecg_df = all_ecg_df[~all_ecg_df['text'].isna()]
all_ecg_df = all_ecg_df[(all_ecg_df['troponin'] & ~all_ecg_df['valuenum'].isnull() & ~all_ecg_df['comments'].isnull()) | (all_ecg_df['troponin'] == 0)]
display(all_ecg_df.head())
print(all_ecg_df.shape)

Unnamed: 0,hadm_id,ecg_time,st_elevation,st_depression,t_wave,Acute_MI,study_id,text,valuenum,comments,troponin,STEMI,NSTEMI
0,29196424.0,2087-05-08 00:05:00,0,0,0,0,44673611,Allergies: \nlisinopril / vancomycin\n \nChief...,,,0,0,0
1,21091437.0,2087-05-08 00:05:00,0,0,0,0,43819938,Allergies: \namitriptyline / Sulfa (Sulfonamid...,0.03,cTropnT > 0.10 ng/mL suggests Acute MI.,1,0,0
3,21236438.0,2088-04-14 02:58:00,0,0,0,0,47239325,Allergies: \nMULTIPLE - SEE LIST **** / amoxic...,0.02,cTropnT > 0.10 ng/mL suggests Acute MI.,1,0,0
4,23821411.0,2089-12-11 00:06:00,0,0,0,0,41010651,Allergies: \nchlorhexidine\n \nChief Complaint...,0.07,cTropnT > 0.10 ng/mL suggests Acute MI.,1,0,0
5,28274927.0,2090-08-21 00:12:00,0,0,0,0,49791716,Allergies: \nNo Known Allergies / Adverse Drug...,,,0,0,0


(272109, 13)


In [4]:
print(all_ecg_df.STEMI.sum())
print(all_ecg_df.NSTEMI.sum())
print()

print(all_ecg_df.st_elevation.sum())
print(all_ecg_df.st_depression.sum())
print(all_ecg_df.t_wave.sum())

8482
30144

31925
1572
493


In [5]:
# Function to sample rows for a specific condition
def sample_condition(df, condition, count, sampled_indices):
    condition_df = df[condition & ~df.index.isin(sampled_indices)]
    if len(condition_df) < count:
        raise ValueError(f"Not enough rows to sample for condition: {condition}")
    sampled = condition_df.sample(n=count, random_state=RANDOM_STATE)
    sampled_dfs.append(sampled)
    return sampled.index

# Set a random seed for reproducibility
RANDOM_STATE = 42

# Define the minimum required counts for each ECG category
required_counts = {
    'STEMI': 900,
    'NSTEMI': 3200,
    'ST_elevation': 3500,
    'ST_depression': 500,
    'T_wave_inversion': 300
}

# Create a copy of the original dataframe to work with
df = all_ecg_df.copy()

# Initialize an empty list to store the sampled dataframes
sampled_dfs = []

# Initialize a set to keep track of sampled indices to avoid duplication where necessary
sampled_indices = set()

# Sample STEMI ECGs
sampled_indices.update(sample_condition(
    df,
    df['STEMI'] == 1,
    required_counts['STEMI'],
    sampled_indices
))

# Sample NSTEMI ECGs
sampled_indices.update(sample_condition(
    df,
    df['NSTEMI'] == 1,
    required_counts['NSTEMI'],
    sampled_indices
))

# Sample ST-elevation ECGs
sampled_indices.update(sample_condition(
    df,
    df['st_elevation'] == 1,
    required_counts['ST_elevation'],
    sampled_indices
))

# Sample ST-depression ECGs
sampled_indices.update(sample_condition(
    df,
    df['st_depression'] == 1,
    required_counts['ST_depression'],
    sampled_indices
))

# Sample T-wave inversion ECGs
sampled_indices.update(sample_condition(
    df,
    df['t_wave'] == 1,
    required_counts['T_wave_inversion'],
    sampled_indices
))

# Concatenate all sampled dataframes
downsampled_df = pd.concat(sampled_dfs)

# Calculate the remaining number of rows to reach 30,000
remaining_rows = 30000 - len(downsampled_df)

# Check if there are enough remaining rows to sample
if remaining_rows > 0:
    # Exclude already sampled indices
    remaining_df = df[~df.index.isin(sampled_indices)]
    
    if len(remaining_df) < remaining_rows:
        raise ValueError("Not enough remaining rows to reach 30,000 after sampling required categories.")
    
    # Sample the remaining rows
    remaining_sampled = remaining_df.sample(n=remaining_rows, random_state=RANDOM_STATE)
    downsampled_df = pd.concat([downsampled_df, remaining_sampled])

# Shuffle the final dataframe
downsampled_df = downsampled_df.sample(frac=1, random_state=RANDOM_STATE).reset_index(drop=True)

# Optional: Verify the counts
print("Downsampled DataFrame Shape:", downsampled_df.shape)
print("STEMI Count:", downsampled_df['STEMI'].sum())
print("NSTEMI Count:", downsampled_df['NSTEMI'].sum())
print("ST Elevation Count:", (downsampled_df['st_elevation'] == 1).sum())
print("ST Depression Count:", (downsampled_df['st_depression'] == 1).sum())
print("T-wave Inversion Count:", (downsampled_df['t_wave'] == 1).sum())


Downsampled DataFrame Shape: (30000, 13)
STEMI Count: 2421
NSTEMI Count: 5547
ST Elevation Count: 6969
ST Depression Count: 707
T-wave Inversion Count: 354


In [6]:
study_ids = downsampled_df['study_id'].tolist()
print(study_ids[:10])
# Path to the image directory
image_dir = '../../data/image_folder'

# Initialize a set to store the integer filenames
image_ids = set()

# Traverse through each file in the image directory
for filename in os.listdir(image_dir):
    # Check if the file is a JPEG image
    if filename.lower().endswith('.jpeg'):
        # Remove the '.jpeg' extension
        name_without_ext = filename[:-5]
        try:
            # Convert the filename to an integer and add to the set
            file_id = int(name_without_ext)
            image_ids.add(file_id)
        except ValueError:
            # If the filename is not an integer, skip it
            print(f"Skipping file with non-integer name: {filename}")

# Convert study_ids to a set for efficient lookup
study_ids_set = set(study_ids)

# Find the intersection of study_ids and image_ids
matching_ids = study_ids_set.intersection(image_ids)

# Calculate the number of matching IDs
num_matching = len(matching_ids)
total_study_ids = len(study_ids)

# Print the results
print(f"Total number of study_ids: {total_study_ids}")
print(f"Number of study_ids with corresponding image files: {num_matching}")
print(f"Percentage matched: { (num_matching / total_study_ids) * 100:.2f}%")

[49931631, 47699806, 45604321, 49910525, 40829195, 48006520, 48573578, 48742280, 49282801, 49289007]
Total number of study_ids: 30000
Number of study_ids with corresponding image files: 29965
Percentage matched: 99.88%


In [7]:
downsampled_df = downsampled_df[downsampled_df['study_id'].isin(matching_ids)] 
print(downsampled_df.shape)

(29965, 13)


In [8]:
downsampled_df.to_csv('../../data/mimic-acute-mi_modelling.csv')

In [9]:
with open('../../data/conversations_new.json') as f:
    conversations = json.load(f)

In [10]:
study_id_list = downsampled_df['study_id'].tolist()
filtered_conversations = [item for item in tqdm(conversations) if int(item['id']) in downsampled_df['study_id'].tolist()]
print(f"Number of conversations after filtering: {len(filtered_conversations)}")


  0%|          | 0/272109 [00:00<?, ?it/s]

100%|██████████| 272109/272109 [03:42<00:00, 1223.88it/s]

Number of conversations after filtering: 29965





In [11]:
with open('../../data/conversations_modelling.json', 'w') as f:
    json.dump(filtered_conversations, f)

### Create Train and Test Set

In [12]:
# Set parameters
RANDOM_STATE = 42
TRAIN_RATIO = 0.8

# Paths to JSON files
train_json_path = '../../data/train_conversations.json'
test_json_path = '../../data/test_conversations.json'

# Paths to save split DataFrames
train_df_path = '../../data/train_downsampled_df.csv'
test_df_path = '../../data/test_downsampled_df.csv'


# 1. Split the study_ids
study_ids = downsampled_df['study_id'].astype(int).unique()
train_ids, test_ids = train_test_split(
    study_ids,
    test_size=1 - TRAIN_RATIO,
    random_state=RANDOM_STATE,
    shuffle=True
)

print(f"Training study_ids: {len(train_ids)}")
print(f"Testing study_ids: {len(test_ids)}")

# 2. Create training and test DataFrames
train_df = downsampled_df[downsampled_df['study_id'].astype(int).isin(train_ids)].reset_index(drop=True)
test_df = downsampled_df[downsampled_df['study_id'].astype(int).isin(test_ids)].reset_index(drop=True)

print(f"Training DataFrame shape: {train_df.shape}")
print(f"Test DataFrame shape: {test_df.shape}")

# 3. Split conversations
train_conversations = [conv for conv in filtered_conversations if int(conv['id']) in train_ids]
test_conversations = [conv for conv in filtered_conversations if int(conv['id']) in test_ids]

print(f"Training conversations count: {len(train_conversations)}")
print(f"Test conversations count: {len(test_conversations)}")

# 4. Save the split DataFrames
train_df.to_csv(train_df_path, index=False)
test_df.to_csv(test_df_path, index=False)
print("Training and test DataFrames saved successfully!")

# 5. Save the conversations
with open(train_json_path, 'w') as f:
    json.dump(train_conversations, f)

with open(test_json_path, 'w') as f:
    json.dump(test_conversations, f)

Training study_ids: 23972
Testing study_ids: 5993
Training DataFrame shape: (23972, 13)
Test DataFrame shape: (5993, 13)


Training conversations count: 23972
Test conversations count: 5993
Training and test DataFrames saved successfully!


In [13]:
study_id_set = downsampled_df['study_id'].tolist()
with open('../../data/modelling_ids.json', 'w') as json_file:
    json.dump(list(study_id_set), json_file)

### Represent Test conversations as single QA data

In [14]:

def transform_to_qa(test_conversations):
    qa_list = []  # List to hold all QA pairs

    # Iterate over each conversation item with a progress bar
    for item in tqdm(test_conversations, desc='Processing conversations'):
        convs = item.get('conversations', [])
        conv_id = item.get('id')
        image = item.get('image')
        diagnosis = item.get('diagnosis')

        # Initialize index
        i = 0
        while i < len(convs):
            # Check if the current turn is from 'human'
            if convs[i].get('from') == 'human':
                human_msg = convs[i].get('value', '').strip()

                # Check if the next turn exists and is from 'gpt'
                if i + 1 < len(convs) and convs[i + 1].get('from') == 'gpt':
                    gpt_msg = convs[i + 1].get('value', '').strip()

                    # Append the QA pair to the list, maintaining 'from' and 'value'
                    qa_list.append({
                        'id': conv_id,
                        'image': image,
                        'conversations': [
                            {
                                'from': 'human',
                                'value': human_msg
                            },
                            {
                                'from': 'gpt',
                                'value': gpt_msg
                            }
                        ],
                        #'diagnosis': diagnosis
                    })

                    # Move to the next pair
                    i += 2
                else:
                    # If there's no corresponding 'gpt' response, skip to next
                    i += 1
            else:
                # If the current turn is not from 'human', skip to next
                i += 1
    return qa_list

qa_list = transform_to_qa(test_conversations)

print(f"Total QA pairs extracted: {len(qa_list)}")

Processing conversations: 100%|██████████| 5993/5993 [00:00<00:00, 88070.49it/s]

Total QA pairs extracted: 23972





In [15]:
qa_list[:5]

[{'id': 40569673,
  'image': '40569673.jpeg',
  'conversations': [{'from': 'human',
   {'from': 'gpt',
    'value': 'No, there are no signs of ST elevation on the ECG.'}]},
 {'id': 40569673,
  'image': '40569673.jpeg',
  'conversations': [{'from': 'human',
    'value': "Examine the patient's ECG and ascertain if there are signs of ST depression or T wave inversion."},
   {'from': 'gpt',
    'value': 'The ECG does not exhibit ST depression or T-wave inversion.'}]},
 {'id': 40569673,
  'image': '40569673.jpeg',
  'conversations': [{'from': 'human',
    'value': '. Is it advisable to proceed with a troponin test?'},
   {'from': 'gpt', 'value': 'Yes, we should arrange for a troponin test.'}]},
 {'id': 40569673,
  'image': '40569673.jpeg',
  'conversations': [{'from': 'human',
    'value': '. Troponin test result: 1.09,cTropnT > 0.10 ng/mL suggests Acute MI.. Might this patient be having a heart attack?'},
   {'from': 'gpt',
    'value': 'The patient is currently experiencing a heart attack

In [16]:
with open('../../data/test_conversations_single_qa.json', 'w') as f:
    json.dump(qa_list, f)

In [17]:
qa_list = transform_to_qa(train_conversations)
print(f"Total QA pairs extracted: {len(qa_list)}")

with open('../../data/train_conversations_single_qa.json', 'w') as f:
    json.dump(qa_list, f)

Processing conversations: 100%|██████████| 23972/23972 [00:00<00:00, 26812.80it/s]


Total QA pairs extracted: 95888
