In [31]:
import os
import pandas as pd
import numpy as np
import json
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score
import re

### Load the dataset with the predictions

In [32]:
# Read the modelling table; the pred_* columns are transformed and scored further down.
PREDS_CSV = '../../data/mimic-acute-mi_modelling_with_preds.csv'
df = pd.read_csv(PREDS_CSV)
display(df.head())
print(df.shape)

Unnamed: 0.1,Unnamed: 0,hadm_id,ecg_time,st_elevation,st_depression,t_wave,Acute_MI,study_id,simple_note,full_note,...,comments,troponin,STEMI,NSTEMI,Q2,pred_st_elevation,pred_st_depression,pred_t_wave,pred_Q2,pred_troponin
0,0,26913865.0,2189-06-27 10:24:00,0,0,0,1,49144190,Chief Complaint:\ndyspnea\n \nPast Medical His...,Allergies: \nNo Known Allergies / Adverse Drug...,...,___,1,0,1,0,1.723308,-4.258851,-3.849245,-3.775553,0.902153
1,1,21727510.0,2187-11-24 09:06:00,0,0,0,0,42151877,Chief Complaint:\nShortness of breath\n \nPast...,Allergies: \nNo Known Allergies / Adverse Drug...,...,CTROPNT > 0.10 NG/ML SUGGESTS ACUTE MI.,1,0,0,0,-0.933257,-1.280811,-3.238944,-3.052563,-0.30116
2,2,20125756.0,2186-01-19 22:17:00,0,0,0,0,42113752,Chief Complaint:\nchest pain\n \nPast Medical ...,Allergies: \nNo Known Allergies / Adverse Drug...,...,CTROPNT > 0.10 NG/ML SUGGESTS ACUTE MI.,1,0,0,0,-4.724946,-5.166714,-4.467826,-5.640689,-4.209776
3,3,24299957.0,2167-07-18 09:37:00,0,0,0,0,41216642,Chief Complaint:\nCritical renal artery stenos...,Allergies: \nShellfish\n \nChief Complaint:\nC...,...,,0,0,0,0,-3.929287,-3.027426,-6.42435,-2.726993,-4.34295
4,4,29035697.0,2121-03-11 17:43:00,0,0,0,1,41803997,Chief Complaint:\nChest pain\n \nPast Medical ...,Allergies: \nPatient recorded as having No Kno...,...,___,1,0,1,0,-2.019528,-4.272657,-5.044726,-4.304021,-0.200219


(30000, 21)


In [33]:
# Names of the prediction columns that the sigmoid below is applied to.
columns_to_transform = ['pred_st_elevation', 'pred_st_depression', 'pred_t_wave', 'pred_Q2', 'pred_troponin']

# Define the sigmoid function
def sigmoid(x):
    """Logistic function 1 / (1 + exp(-x)), evaluated without overflow.

    Written as exp(-log(1 + exp(-x))) with `np.logaddexp` doing the inner
    part, so very negative inputs do not push `np.exp` past the float range
    (the naive formula emits an overflow RuntimeWarning there). The result
    is mathematically identical to the naive formula and can differ from it
    only by floating-point rounding.

    Parameters
    ----------
    x : scalar, numpy array or pandas Series
        Input values. Only numpy ufuncs are used, so a Series in gives a
        Series out, which is what `DataFrame.apply(sigmoid)` relies on.

    Returns
    -------
    Same kind of object as `x`, with values in [0, 1].
    """
    # logaddexp(0, -x) == log(exp(0) + exp(-x)) == log(1 + exp(-x))
    return np.exp(-np.logaddexp(0, -x))

# Apply the sigmoid transformation to the specified columns.
# The columns are overwritten in place, so running this cell twice on the same
# frame would squash the values a second time. A flag stored in df.attrs marks
# the frame as already transformed; a freshly loaded df has no flag and is
# transformed as usual.
if not df.attrs.get('pred_sigmoid_applied', False):
    df[columns_to_transform] = df[columns_to_transform].apply(sigmoid)
    df.attrs['pred_sigmoid_applied'] = True

# Display the dataframe to check the transformation
print(df[columns_to_transform].head())

   pred_st_elevation  pred_st_depression  pred_t_wave   pred_Q2  pred_troponin
0           0.848554            0.013941     0.020852  0.022411       0.711392
1           0.282264            0.217412     0.037726  0.045107       0.425274
2           0.008793            0.005671     0.011342  0.003538       0.014632
3           0.019279            0.046202     0.001619  0.061399       0.012831
4           0.117168            0.013753     0.006402  0.013334       0.450112


In [34]:
## Check AUCs are ok
print(df.columns)
# Each label column is scored against its `pred_<label>` counterpart.
for label_col in ('st_elevation', 'st_depression', 't_wave', 'troponin', 'Q2'):
    print(roc_auc_score(df[label_col], df['pred_' + label_col]))

Index(['Unnamed: 0', 'hadm_id', 'ecg_time', 'st_elevation', 'st_depression',
       't_wave', 'Acute_MI', 'study_id', 'simple_note', 'full_note',
       'valuenum', 'comments', 'troponin', 'STEMI', 'NSTEMI', 'Q2',
       'pred_st_elevation', 'pred_st_depression', 'pred_t_wave', 'pred_Q2',
       'pred_troponin'],
      dtype='object')
0.8452305038536526
0.8773483603032629
0.7523178343800515
0.7327958141945257
0.8146190568348426


In [35]:
# Number of rows/columns, then number of distinct hadm_id values
# (dropna=False keeps a missing id counted as one value, like len(unique())).
print(df.shape)
print(df['hadm_id'].nunique(dropna=False))

(30000, 21)
21520


### Improve the medical note representation

In [36]:
# Example of a raw simple_note (row label 5)
print(df.loc[5, 'simple_note'])

Chief Complaint:
Fevers
 
Past Medical History:
Medications on Admission:
The Preadmission Medication list is accurate and complete.
1. Aspirin 81 mg PO DAILY 
2. Atorvastatin 20 mg PO QPM 
3. Mycophenolate Mofetil 500 mg PO BID 
4. Omeprazole 20 mg PO BID 
5. PredniSONE 5 mg PO DAILY 
6. Tacrolimus 2 mg PO Q12H 
7. Nystatin Oral Suspension 5 mL PO QID 
8. Sulfameth/Trimethoprim SS 1 TAB PO DAILY 

 



In [37]:
import re

# Your original string
text = "Hello,\nthis is a string with strange symbols! @#$%^&*()"

# Replace every character that is not a letter, digit or whitespace with a space,
# so that words separated only by punctuation are not glued together
cleaned_text = re.sub(r'[^a-zA-Z0-9\s]', ' ', text)

# Collapse whitespace runs (newlines included) into single spaces and trim the ends
cleaned_text = re.sub(r'\s+', ' ', cleaned_text).strip()

print(cleaned_text)

Hello this is a string with strange symbols


In [38]:
SECTIONS = ['Chief Complaint'] #, 'Past Medical History'
def clean_note(note, sections=None):
    """Extract selected sections from a note and flatten them to plain text.

    For every heading in `sections`, the text between "<heading>:" and the
    next line that looks like a heading (a line ending its first segment
    with ':') is kept. The kept text is then stripped of punctuation and
    collapsed to a single line.

    Parameters
    ----------
    note : str
        Note text. Anything that is not a string (e.g. a NaN from a missing
        CSV cell) yields an empty string instead of raising.
    sections : list of str, optional
        Headings to keep, without the trailing colon. Defaults to the
        module-level SECTIONS.

    Returns
    -------
    str
        "<heading> <content>" for each heading found, joined by single
        spaces; '' when none of the headings occur.
    """
    if not isinstance(note, str):
        return ''
    if sections is None:
        sections = SECTIONS

    cleaned = ''
    for heading in sections:
        content = note.split(heading+':')
        if len(content) == 1: continue
        content = re.split(r'\n[^\n:]+:', content[1])[0] # find next section heading and cut
        cleaned += heading+':' + content + '\n'

    # Replace every non-alphanumeric, non-whitespace character with a space.
    # Deleting them outright would merge neighbouring words ("pain/SOB" -> "painSOB").
    cleaned_text = re.sub(r'[^a-zA-Z0-9\s]', ' ', cleaned)

    # Collapse whitespace runs (newlines included) into single spaces
    cleaned_text = re.sub(r'\s+', ' ', cleaned_text).strip()
    return cleaned_text

# Derive the cleaned chief-complaint text from the full note, one row at a time.
df['chief_complaint'] = df['full_note'].apply(clean_note)

In [39]:
# Spot-check the cleaned text for row label 2
print(df.loc[2, 'chief_complaint'])

Chief Complaint chest pain


In [40]:
# Raw full note for row label 1, for comparison with the cleaned text
print(df.loc[1, 'full_note'])

Allergies: 
No Known Allergies / Adverse Drug Reactions
 
Chief Complaint:
Shortness of breath
 
Major Surgical or Invasive Procedure:
Right heart Cath ___
TEE ___

 
History of Present Illness:
___ with non-ischemic cardiomyopathy, CHF with EF ___, 
biventricular pacing, recent admission for VT storm, atrial 
fibrillation with CHADS-VASC of 4, CKD stage 3, hypertension, 
now presenting with ___ days of dyspnea. He also reports 
increased heart rate, palpitations, paroxysmal nocturnal 
dyspnea, increased lower extremity edema. He states he has a 
chronic nonproductive cough but has no fevers / chills. He notes 
he has been taking all of his medications and has not missed his 
diuretics. He has not noticed any significant weight changes but 
does not weigh himself regularly. He has baseline orthopnea. He 
has been eating sausage, pizza, chips with resulting increased 
thirst and fluid intake as an outpatient. He denies chest pain, 
syncope but has had light headedness since admission in

### Check on some conversations

In [41]:
# Load the training conversations file into a list of dicts.
with open('../../data/train_conversations_with_preds.json') as conv_file:
    conversations = json.load(conv_file)

In [42]:
# Function to filter and print the conversation by specific ID
def print_conversation_by_id(conversations, target_id):
    """Pretty-print the first conversation whose 'id' equals `target_id`.

    Parameters
    ----------
    conversations : iterable of dict
        Conversation records; each is looked up with `.get('id')`.
    target_id
        Value compared (with ==) against each record's 'id'.

    Prints the record as indented JSON followed by a dashed separator, or a
    "not found" message when no record matches. Returns None.
    """
    match = next((entry for entry in conversations if entry.get('id') == target_id), None)
    if match is None:
        print(f"No conversation found with ID {target_id}")
        return

    print(f"Conversation with ID {target_id}:")
    print(json.dumps(match, indent=4))
    print("\n" + "-"*40 + "\n")

# Example usage: look up one conversation, then report how many were loaded.
target_id = 43919386  # Replace this with the desired ID
print_conversation_by_id(conversations, target_id=target_id)
print(len(conversations))

Conversation with ID 43919386:
{
    "id": 43919386,
    "image": "43919386.jpeg",
    "conversations": [
        {
            "from": "human",
            "value": "Chief Complaint:\nChest Pain\n \nPast Medical History:\nHTN  \nHyperlipidemia  \nR knee meniscal repair  \nmelanoma removed from right thigh ___ \n\n \nMedications on Admission:\nZocor 20 mg daily  \nCozaar 50 mg daily\n\n \n\n pred_st_elevation: 0.3637114267568077  pred_st_depression: 0.0051662996433517894 pred_t_wave 0.002297407935460237 pred_Q2 0.005083725698716161 pred_troponin  0.24447508535836818  Please review the ECG of this patient. Check for any signs of ST elevation."
        },
        {
            "from": "gpt",
            "value": "The ECG shows evidence of ST elevation."
        },
        {
            "from": "human",
            "value": " pred_st_elevation: 0.3637114267568077  pred_st_depression: 0.0051662996433517894 pred_t_wave 0.002297407935460237 pred_Q2 0.005083725698716161 pred_troponin  0.24447

In [43]:
with open('../../data/train_conversations_with_preds.json') as f:
    conversations = json.load(f)

N = 5  # Number of elements to display

# Show the last N conversations. N must be >= 1: a slice of [-0:] would
# return the whole list instead of nothing.
for i, item in enumerate(conversations[-N:], start=1):
    print(f"Element {i}:")
    print(json.dumps(item, indent=4))
    print("\n" + "-"*40 + "\n")

Element 1:
{
    "id": 40285773,
    "image": "40285773.jpeg",
    "conversations": [
        {
            "from": "human",
            "value": "Chief Complaint:\ntransferred for possible CVVH\n \nPast Medical History:\n-History of hypertension, more recently with hypotension (BP as \nlow as 78/46 in clinic during the past year)\n-HFrEF (___) s/p ICD; akinetic in LAD distribution, so \nthought to be ischemic cardiomyopathy. Also with diastolic \ndysfunction and moderate AS.\n-CAD s/p CABG in ___\n-Type II DM \n-neuropathy and b/l Charcot foot\n-CKD (baseline Cr 2.0)\n-OSA\n-Nocturnal desaturations, started on home O2\n-Thrombocytopenia\n-Asthma/COPD (FEV1/FVC 79%, FEV1 58%)\n-Anemia \n-PVD\n-Chronic pain on opiates\n-Chronic constipation\n-Anxiety\n-Depression \n-Tremors/myoclonic jerks\n-Squamous cell skin cancer of R hand\n-C diff\n\nSURGICAL HISTORY\nappendectomy\nCataracts\nUteroscopy for post-menopausal bleeding\nBone resection for ?osteomyelitis of sternum\nClosed R trimalleola

In [44]:
# NOTE(review): this cell repeats the preview of the cell before it (same file,
# same slice); consider dropping one of the two.
with open('../../data/train_conversations_with_preds.json') as f:
    conversations = json.load(f)

N = 5  # Number of elements to display

# The slice is driven by N (N >= 1; with N == 0 the slice [-0:] is the whole list).
tail = conversations[-N:]
for i, item in enumerate(tail, start=1):
    print(f"Element {i}:")
    print(json.dumps(item, indent=4))
    print("\n" + "-"*40 + "\n")

Element 1:
{
    "id": 40285773,
    "image": "40285773.jpeg",
    "conversations": [
        {
            "from": "human",
            "value": "Chief Complaint:\ntransferred for possible CVVH\n \nPast Medical History:\n-History of hypertension, more recently with hypotension (BP as \nlow as 78/46 in clinic during the past year)\n-HFrEF (___) s/p ICD; akinetic in LAD distribution, so \nthought to be ischemic cardiomyopathy. Also with diastolic \ndysfunction and moderate AS.\n-CAD s/p CABG in ___\n-Type II DM \n-neuropathy and b/l Charcot foot\n-CKD (baseline Cr 2.0)\n-OSA\n-Nocturnal desaturations, started on home O2\n-Thrombocytopenia\n-Asthma/COPD (FEV1/FVC 79%, FEV1 58%)\n-Anemia \n-PVD\n-Chronic pain on opiates\n-Chronic constipation\n-Anxiety\n-Depression \n-Tremors/myoclonic jerks\n-Squamous cell skin cancer of R hand\n-C diff\n\nSURGICAL HISTORY\nappendectomy\nCataracts\nUteroscopy for post-menopausal bleeding\nBone resection for ?osteomyelitis of sternum\nClosed R trimalleola

### Split dataset into test and training dataframe

In [45]:
# Test split: read the conversation file and collect every 'id' as a float
with open('../../data/test_conversations.json', 'r') as file:
    data = json.load(file)
test_ids = list(map(lambda entry: float(entry['id']), data))
print(len(test_ids))

# Train split: same procedure on the training conversation file
with open('../../data/train_conversations.json', 'r') as file:
    data = json.load(file)
train_ids = list(map(lambda entry: float(entry['id']), data))
print(len(train_ids))

6000
24000


In [46]:
# Rows are assigned to a split by membership of their study_id in the id lists.
in_train = df['study_id'].isin(train_ids)
train_df = df.loc[in_train]
print(train_df.shape)
in_test = df['study_id'].isin(test_ids)
test_df = df.loc[in_test]
print(test_df.shape)

(24000, 22)
(6000, 22)


### Regenerate conversation files

In [47]:
# Template lists, looked up by key (e.g. 'st_elev_q') when conversations are generated.
with open('../../data/conv_templates_with_preds.json') as template_file: #diag_q
    conv_templates = json.load(template_file)

In [48]:
import random

# Boilerplate carried by the follow-up question templates; it is blanked out
# rather than filled in.
_PRIOR_ANSWER_STUB = 'Your prior assessment was: {INSERT PRIOR ANSWER}'
_PREDICTIONS_SLOT = "{PREDICTIONS HERE}"


def _predictions_text(row):
    """Render the five pred_* values of `row` as the text inserted into every question."""
    return f"pred_st_elevation: {row['pred_st_elevation']}  pred_st_depression: {row['pred_st_depression']} pred_t_wave {row['pred_t_wave']} pred_Q2 {row['pred_Q2']} pred_troponin  {row['pred_troponin']}"


def _turn(sender, text):
    """Build one conversation turn in the {"from", "value"} layout."""
    return {"from": sender, "value": text}


def _follow_up(template, preds, result=None):
    """Fill a follow-up question: drop the prior-answer stub, then insert result and predictions."""
    text = template.replace(_PRIOR_ANSWER_STUB, '')
    if result is not None:
        text = text.replace('{INSERT_RESULT}', result)
    return text.replace(_PREDICTIONS_SLOT, preds)


def generate_conversations(df, conv_templates, rng=None):
    """Build one four-question conversation per row of `df`.

    Every conversation has the same shape (human/gpt turns alternate):
      1. ST-elevation question (carries the patient note) and its answer,
      2. ST-depression / T-wave question and its answer,
      3. troponin question and its answer,
      4. diagnosis question (with the troponin result when troponin is set)
         and the final diagnosis answer.
    Question and answer wordings are drawn at random from `conv_templates`.

    Parameters
    ----------
    df : pandas.DataFrame
        Must provide the columns read below: study_id, chief_complaint,
        st_elevation, st_depression, t_wave, troponin, valuenum, comments,
        STEMI, NSTEMI and the five pred_* columns.
    conv_templates : dict
        Maps template keys ('st_elev_q', 'st_elev_a_true', ...) to lists of
        candidate strings.
    rng : random.Random, optional
        Source of the template draws. When omitted the global `random`
        module is used. Pass `random.Random(seed)` for reproducible output.

    Returns
    -------
    list of dict
        One dict per row with keys 'id', 'image', 'conversations' and
        'diagnosis' ('STEMI', 'NSTEMI' or 'NO ACS').
    """
    pick = (random if rng is None else rng).choice
    conversations = []

    for _, row in df.iterrows():
        study_id = row['study_id']
        conv = {
            "id": study_id,
            "image": f"{study_id}.jpeg",
            "conversations": []
        }

        if row['STEMI']:
            conv['diagnosis'] = 'STEMI'
        elif row['NSTEMI']:
            conv['diagnosis'] = 'NSTEMI'
        else:
            conv['diagnosis'] = 'NO ACS'

        turns = conv['conversations']
        preds = _predictions_text(row)
        has_st_elevation = bool(row['st_elevation'])

        # 1) ST elevation -- the only question that carries the patient note.
        opening = pick(conv_templates['st_elev_q'])
        turns.append(_turn(
            "human",
            opening.replace('{INSERT PATIENT NOTE}', row['chief_complaint']).replace(_PREDICTIONS_SLOT, preds),
        ))
        turns.append(_turn("gpt", pick(conv_templates['st_elev_a_true' if has_st_elevation else 'st_elev_a_false'])))

        # 2) ST depression / T-wave finding.
        turns.append(_turn("human", _follow_up(pick(conv_templates['st_depres_twi_q']), preds)))
        has_depression_or_twi = bool(row['st_depression'] or row['t_wave'])
        turns.append(_turn("gpt", pick(conv_templates['st_depres_twi_a_true' if has_depression_or_twi else 'st_depres_twi_a_false'])))

        # 3) Troponin, followed by the diagnosis question.
        turns.append(_turn("human", _follow_up(pick(conv_templates['troponin_q']), preds)))
        if row['troponin']:
            turns.append(_turn("gpt", pick(conv_templates['troponin_a_true'])))
            # NOTE(review): missing valuenum/comments are rendered as the text "nan" here.
            lab_result = f"{row['valuenum']},{row['comments']}"
            turns.append(_turn("human", _follow_up(pick(conv_templates['diag_q_trop']), preds, result=lab_result)))
        else:
            turns.append(_turn("gpt", pick(conv_templates['troponin_a_false'])))
            turns.append(_turn("human", _follow_up(pick(conv_templates['diag_q']), preds)))

        # 4) Final answer. With ST elevation only the STEMI flag is consulted,
        #    without it only the NSTEMI flag; everything else gets the "normal" answer.
        if has_st_elevation:
            final_key = 'diag_a_stemi' if row['STEMI'] else 'diag_a_normal'
        else:
            final_key = 'diag_a_nstemi' if row['NSTEMI'] else 'diag_a_normal'
        turns.append(_turn("gpt", pick(conv_templates[final_key])))

        conversations.append(conv)

    return conversations


In [49]:
# train_df = train_df[~train_df['simple_note'].isna()]
# train_df = train_df[(df['troponin'] & ~df['valuenum'].isnull() & ~df['comments'].isnull()) | (df['troponin'] == 0)]

In [50]:
# generate_conversations draws its templates from the global `random` module.
# Seeding it here makes a fresh run of the notebook regenerate the same
# train/test conversations.
random.seed(42)
train_conversations = generate_conversations(train_df, conv_templates)
test_conversations = generate_conversations(test_df, conv_templates)

In [51]:
# Write each split to its own JSON file (train first, then test).
for out_path, split_conversations in (
    ('../../data/train_conversations_with_preds_simple.json', train_conversations),
    ('../../data/test_conversations_with_preds_simple.json', test_conversations),
):
    with open(out_path, 'w') as f:
        json.dump(split_conversations, f)


### Generate corresponding Single Turn QA Datasets 

In [52]:

def transform_to_qa(test_conversations):
    """Split multi-turn conversations into single-turn question/answer records.

    A pair is a 'human' turn immediately followed by a 'gpt' turn. Turns
    that are not part of such a pair are skipped.

    Parameters
    ----------
    test_conversations : iterable of dict
        Records with optional 'id', 'image' and 'conversations' keys.

    Returns
    -------
    list of dict
        One record per pair: {'id', 'image', 'conversations': [human, gpt]}.
        Both messages are stripped; '<image>' is removed from the question
        after stripping.
    """
    qa_list = []

    for item in tqdm(test_conversations, desc='Processing conversations'):
        turns = item.get('conversations', [])
        conv_id, image = item.get('id'), item.get('image')

        pos = 0
        while pos < len(turns):
            current = turns[pos]

            # Anything that is not a question cannot start a pair.
            if current.get('from') != 'human':
                pos += 1
                continue

            question = current.get('value', '').strip().replace('<image>', '')

            # A question without a directly following answer is dropped.
            if pos + 1 >= len(turns) or turns[pos + 1].get('from') != 'gpt':
                pos += 1
                continue

            answer = turns[pos + 1].get('value', '').strip()
            qa_list.append({
                'id': conv_id,
                'image': image,
                'conversations': [
                    {'from': 'human', 'value': question},
                    {'from': 'gpt', 'value': answer},
                ],
            })
            pos += 2  # consume both turns of the pair

    return qa_list

# Single-turn pairs for the test split
qa_list = transform_to_qa(test_conversations)

n_pairs = len(qa_list)
print(f"Total QA pairs extracted: {n_pairs}")

Processing conversations: 100%|██████████| 6000/6000 [00:00<00:00, 95112.17it/s]

Total QA pairs extracted: 24000





In [53]:
# Persist the pairs currently held in qa_list as one JSON array.
with open('../../data/test_conversations_single_qa_with_preds_simple.json', 'w') as out_file:
    json.dump(qa_list, out_file)

In [54]:
# Single-turn pairs for the training split; qa_list is rebound to them.
qa_list = transform_to_qa(train_conversations)
n_pairs = len(qa_list)
print(f"Total QA pairs extracted: {n_pairs}")

with open('../../data/train_conversations_single_qa_with_preds_simple.json', 'w') as out_file:
    json.dump(qa_list, out_file)

Processing conversations:   0%|          | 0/24000 [00:00<?, ?it/s]

Processing conversations: 100%|██████████| 24000/24000 [00:00<00:00, 49608.24it/s]


Total QA pairs extracted: 96000


In [55]:
def transform_to_q_and_a(test_conversations):
    """Split conversations into parallel question and answer record lists.

    A pair is a 'human' turn immediately followed by a 'gpt' turn. Turns
    outside such a pair are skipped. Entry k of both returned lists belongs
    to the same pair.

    Parameters
    ----------
    test_conversations : iterable of dict
        Records with optional 'id', 'image', 'diagnosis' and
        'conversations' keys.

    Returns
    -------
    (q_list, a_list) : tuple of two lists of dict
        Each record has 'question_id', 'image', 'text', 'category' (the
        conversation's diagnosis) and 'number' (index of the human turn
        floor-divided by 2).
    """
    q_list, a_list = [], []

    for item in tqdm(test_conversations, desc='Processing conversations'):
        turns = item.get('conversations', [])
        conv_id = item.get('id')
        image = item.get('image')
        diagnosis = item.get('diagnosis')

        def make_record(text, pair_number):
            # Key order is fixed here because it determines the serialized field order.
            return {
                'question_id': conv_id,
                'image': image,
                'text': text,
                'category': diagnosis,
                'number': pair_number,
            }

        pos = 0
        while pos < len(turns):
            current = turns[pos]

            # Anything that is not a question cannot start a pair.
            if current.get('from') != 'human':
                pos += 1
                continue

            question = current.get('value', '').strip().replace('<image>', '')

            # A question without a directly following answer is dropped.
            if pos + 1 >= len(turns) or turns[pos + 1].get('from') != 'gpt':
                pos += 1
                continue

            answer = turns[pos + 1].get('value', '').strip()
            q_list.append(make_record(question, pos // 2))
            a_list.append(make_record(answer, pos // 2))
            pos += 2  # consume both turns of the pair

    return q_list, a_list

# Parallel question/answer lists for the test split
q_list, a_list = transform_to_q_and_a(test_conversations)

n_pairs = len(q_list)
print(f"Total QA pairs extracted: {n_pairs}")

Processing conversations: 100%|██████████| 6000/6000 [00:00<00:00, 119749.44it/s]

Total QA pairs extracted: 24000





In [56]:
# Save q_list to a JSON file
with open('../../data/test_simple_LLaVA_acs_q.json', 'w') as f:
    for item in q_list:
        json.dump(item, f)
        f.write('\n')

# Save a_list to a JSON file
with open('../../data/test_simple_LLaVA_acs_a.json', 'w') as f:
    for item in a_list:
        json.dump(item, f)
        f.write('\n')

In [57]:
# Parallel question/answer lists for the training split; q_list/a_list are rebound.
q_list, a_list = transform_to_q_and_a(train_conversations)

n_pairs = len(q_list)
print(f"Total QA pairs extracted: {n_pairs}")

Processing conversations:   0%|          | 0/24000 [00:00<?, ?it/s]

Processing conversations: 100%|██████████| 24000/24000 [00:00<00:00, 118647.53it/s]

Total QA pairs extracted: 96000





In [58]:
# Save q_list to a JSON file
with open('../../data/train_simple_LLaVA_acs_q.json', 'w') as f:
    for item in q_list:
        json.dump(item, f)
        f.write('\n')

# Save a_list to a JSON file
with open('../../data/train_simple_LLaVA_acs_a.json', 'w') as f:
    for item in a_list:
        json.dump(item, f)
        f.write('\n')

In [59]:
# Scratch check of the `i//2` expression used for the 'number' field: index 2 -> 1
i = 2
pair_number = i // 2
print(pair_number)

1
