### Preprocessing for Mimic4 Dataset

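**Note: the AUG dataset (`aug_med_notes`) is incomplete. It was meant to be skipped here, but the cells below still run `aug_preprocess` on it alongside `mimic4`.**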

In [70]:
from medrlcot.config.env import MedRL_CoT
from medrlcot import data_manager
from medrlcot.medrlcot_logger import setup_logger
from dotenv import load_dotenv
from datasets import Features, Value
from IPython.display import display
import numpy as np
import pandas as pd
import datasets as hf_datasets
import logging
import os
import json

# pd.set_option('display.max_colwidth', None)  
# pd.set_option('display.width', 0)    
# pd.set_option('display.max_rows', None)  

In [71]:
# Path of the project's .env file, resolved against the current working directory
# (presumably the notebook is launched from the repository root — TODO confirm).
model_cfg_path = os.path.join(os.getcwd(), "medrlcot/config/.env")
# Project configuration object; later cells read `.data_dir` and `.datasets` from it.
medrlcot_config = MedRL_CoT(model_cfg_path)

# Run the project's logger setup first, then fetch the named logger used by this notebook.
setup_logger()
logger = logging.getLogger("MedRL-CoT Preprocess")

2025-06-02 14:09:03,636 || INFO || Logger - Setup for MedRL-CoT's log done. This is the beginning of the log.


Generated new log file logs/medrlcot079.log


In [72]:
# Map each configured dataset name to the absolute path of its 'processed' folder.
processed_dirs = {
    name: os.path.join(os.getcwd(), medrlcot_config.data_dir, name, 'processed')
    for name in medrlcot_config.datasets
}
processed_dirs.keys()

dict_keys(['aug_med_notes', 'mimic4'])

In [73]:
# The three standard sentence labels used by the preprocessing functions.
classes = np.asarray(('symptoms_labs', 'thought_process', 'diagnosis'))

In [74]:
def load_labeled(arrow_dir):
    """Load every ``.arrow`` shard in ``arrow_dir`` into one column dictionary.

    Parameters
    ----------
    arrow_dir : str
        Directory that directly contains the ``*.arrow`` files (sub-directories
        are not searched).

    Returns
    -------
    dict
        The result of ``Dataset.to_dict()`` on the concatenation of all shards.
    """
    # os.listdir returns entries in arbitrary order; sort so that the shards are
    # always concatenated in the same order and row positions are reproducible.
    shard_paths = sorted(
        os.path.join(arrow_dir, fname)
        for fname in os.listdir(arrow_dir)
        if fname.endswith(".arrow")
    )
    shards = [hf_datasets.Dataset.from_file(path) for path in shard_paths]

    return hf_datasets.concatenate_datasets(shards).to_dict()

In [75]:
# Load each dataset's labelled shards and wrap the resulting column dict in a DataFrame.
processed_datasets = {
    name: pd.DataFrame(load_labeled(folder))
    for name, folder in processed_dirs.items()
}
processed_datasets.keys()

dict_keys(['aug_med_notes', 'mimic4'])

In [338]:
def mimic_preprocess(dataset, min_other_count=5):
    """Clean the labelled MIMIC-IV sentences and return a new DataFrame.

    The input frame is copied, never modified. Relies on the module-level
    ``classes`` array (standard labels) and ``logger``.

    Steps, in order:

    1. Rename the misspelt labels ``symptoms_lbs`` / ``symptoms_lads`` to
       ``symptoms_labs``.
    2. Among the rows whose label is not a standard class:
       * rows whose label or sentence is an empty/placeholder token are invalid;
       * rows whose ``sentence`` holds a standard label (any letter case) have
         their two columns swapped back, with the label lower-cased;
       * every remaining label seen at least ``min_other_count`` times is
         relabelled ``'other'``.
    3. Drop the invalid rows and every row whose label is still neither a
       standard class nor ``'other'``.

    Parameters
    ----------
    dataset : pandas.DataFrame
        Frame with string columns ``sentence`` and ``class``.
    min_other_count : int, default 5
        Minimum number of occurrences for a non-standard label to be kept as
        ``'other'`` instead of being dropped.

    Returns
    -------
    pandas.DataFrame
        The cleaned copy; surviving rows keep their original index labels.
    """
    cleaned_ds = dataset.copy()

    # Placeholder tokens are compared against lower-cased text, so they must be
    # lower case themselves ('None' could never match after .str.lower()).
    invalid_class_tokens = ['', '0', '__', 'none']
    invalid_sentence_tokens = ['', '__', 'none']

    # For logging purposes
    N = cleaned_ds.shape[0]

    def pct(count):
        # Guard against an empty dataset instead of raising ZeroDivisionError.
        return (count / N) * 100 if N else 0.0

    logger.info(f"Found {N} rows")
    num_renames = int(cleaned_ds['class'].isin(['symptoms_lbs', 'symptoms_lads']).sum())
    logger.info(f"Fixed class naming for {num_renames} rows ({pct(num_renames)} %)")

    # Fix naming of some classes
    cleaned_ds['class'] = cleaned_ds['class'].replace({'symptoms_lbs': 'symptoms_labs', 'symptoms_lads': 'symptoms_labs'})

    # Everything without a standard label is a candidate for repair or removal.
    pos_invalids = cleaned_ds[~cleaned_ds['class'].isin(classes)]
    sentence_lc = pos_invalids['sentence'].str.lower()
    class_lc = pos_invalids['class'].str.lower()

    # Rows whose sentence column actually holds a label (columns were swapped).
    swapped_values = pos_invalids[sentence_lc.isin(classes)]

    # Clearly invalid rows: placeholder label or placeholder sentence. A swapped row
    # with a placeholder in its class column has no real sentence, so it lands here too.
    invalids = pos_invalids[class_lc.isin(invalid_class_tokens) | sentence_lc.isin(invalid_sentence_tokens)]

    # What is left are genuine non-standard labels.
    nonstd_classes = pos_invalids.drop(index=swapped_values.index.union(invalids.index))

    # Labels frequent enough (non-outliers) to be kept under the 'other' class.
    value_cnts = nonstd_classes['class'].value_counts()
    other_classes = value_cnts[value_cnts >= min_other_count].index.tolist()

    # Swap back only the rows that will survive; invalid ones are dropped below.
    swapped_indices = swapped_values.index.difference(invalids.index)
    if len(swapped_indices):
        swapped_rows = cleaned_ds.loc[swapped_indices, ['sentence', 'class']]
        cleaned_ds.loc[swapped_indices, 'sentence'] = swapped_rows['class'].to_numpy()
        # Detection was case-insensitive, so normalise the label or it would fail the
        # case-sensitive membership test below and be dropped.
        cleaned_ds.loc[swapped_indices, 'class'] = swapped_rows['sentence'].str.lower().to_numpy()

    # Relabel the frequent non-standard classes to 'other'.
    cleaned_ds.loc[cleaned_ds['class'].isin(other_classes), 'class'] = 'other'

    # Drop rows whose label is still unknown, plus every row flagged invalid (an
    # invalid row may carry an acceptable label but has no usable sentence).
    unknown_label_idx = cleaned_ds[~cleaned_ds['class'].isin(np.append(classes, 'other'))].index
    drop_indices = unknown_label_idx.union(invalids.index)
    cleaned_ds = cleaned_ds.drop(index=drop_indices)

    # Summary
    num_reclass = int((cleaned_ds['class'] == 'other').sum())
    logger.info(f"Re-classified {len(other_classes)} classes as 'other', or {num_reclass} rows ({pct(num_reclass)} %)")
    logger.info(f'Swapped class and sentence values of {swapped_indices.shape[0]} rows ({pct(swapped_indices.shape[0])} %)')
    logger.info(f'Dropped {drop_indices.shape[0]} invalid rows ({pct(drop_indices.shape[0])} %)')

    return cleaned_ds

# Trial run on the MIMIC-IV frame: the cleaned copy is shown as the cell output but not stored.
mimic_preprocess(processed_datasets['mimic4'])

2025-06-02 17:42:33,392 || INFO || MedRL-CoT Preprocess - Found 29654 rows
2025-06-02 17:42:33,394 || INFO || MedRL-CoT Preprocess - Fixed class naming for 282 rows (0.9509678289606799 %)
2025-06-02 17:42:33,420 || INFO || MedRL-CoT Preprocess - Re-classified 18 classes as 'other', or 1076 rows (3.628515545963445 %)
2025-06-02 17:42:33,421 || INFO || MedRL-CoT Preprocess - Swapped class and sentence values of 571 rows (1.925541242328185 %)
2025-06-02 17:42:33,423 || INFO || MedRL-CoT Preprocess - Dropped 166 invalid rows (0.5597895730761449 %)


Unnamed: 0,sentence,class,case_id
0,History of Present Illness:,symptoms_labs,-1
1,Pt reports self-discontinuing lasix and spirno...,symptoms_labs,-1
2,She does not follow Na-restricted diets.,symptoms_labs,-1
3,"In the past week, she notes that she has been ...",symptoms_labs,-1
4,"She denies ___ edema, or SOB, or orthopnea.",symptoms_labs,-1
...,...,...,...
29649,Please call cardiac surgery office with any qu...,diagnosis,283
29650,Patient presented with a chief complaint of ch...,symptoms_labs,283
29651,The ECG showed no signs of ischemia.,symptoms_labs,283
29652,The patient's history of hypertension and hype...,thought_process,283


In [337]:
def aug_preprocess(dataset, min_other_count=5):
    """Clean the labelled augmented-notes sentences and return a new DataFrame.

    The input frame is copied, never modified. Relies on the module-level
    ``classes`` array (standard labels) and ``logger``.

    Steps, in order:

    1. Rename the misspelt labels ``symptoms_lbs`` / ``symptoms_lads`` to
       ``symptoms_labs``.
    2. Among the rows whose label is not a standard class:
       * rows whose label or sentence is an empty/placeholder token are invalid;
       * rows whose ``sentence`` is exactly a standard label have their two
         columns swapped back;
       * every remaining label seen at least ``min_other_count`` times is
         relabelled ``'other'``.
    3. Drop the invalid rows and every row whose label is still neither a
       standard class nor ``'other'``.

    Parameters
    ----------
    dataset : pandas.DataFrame
        Frame with string columns ``sentence`` and ``class``.
    min_other_count : int, default 5
        Minimum number of occurrences for a non-standard label to be kept as
        ``'other'`` instead of being dropped.

    Returns
    -------
    pandas.DataFrame
        The cleaned copy; surviving rows keep their original index labels.
    """
    cleaned_ds = dataset.copy()

    # Placeholder tokens are compared against lower-cased text, so they must be lower
    # case themselves ('None', 'False', 'The', 'No' could never match after .str.lower()).
    invalid_class_tokens = ['', '0', '__', 'none', '[]', 'false', 'the', 'no']
    invalid_sentence_tokens = ['', '__', 'none', '[]', 'false', '()']

    # For logging purposes
    N = cleaned_ds.shape[0]

    def pct(count):
        # Guard against an empty dataset instead of raising ZeroDivisionError.
        return (count / N) * 100 if N else 0.0

    logger.info(f"Found {N} rows")
    num_renames = int(cleaned_ds['class'].isin(['symptoms_lbs', 'symptoms_lads']).sum())
    logger.info(f"Fixed class naming for {num_renames} rows ({pct(num_renames)} %)")

    # Fix naming of some classes
    cleaned_ds['class'] = cleaned_ds['class'].replace({'symptoms_lbs': 'symptoms_labs', 'symptoms_lads': 'symptoms_labs'})

    # Everything without a standard label is a candidate for repair or removal.
    pos_invalids = cleaned_ds[~cleaned_ds['class'].isin(classes)]

    # Rows whose sentence column actually holds a label (columns were swapped).
    swapped_values = pos_invalids[pos_invalids['sentence'].isin(classes)]

    # Clearly invalid rows: placeholder label or placeholder sentence. A swapped row
    # with a placeholder in its class column has no real sentence, so it lands here too.
    class_is_placeholder = pos_invalids['class'].str.lower().isin(invalid_class_tokens)
    sentence_is_placeholder = pos_invalids['sentence'].str.lower().isin(invalid_sentence_tokens)
    invalids = pos_invalids[class_is_placeholder | sentence_is_placeholder]

    # What is left are genuine non-standard labels.
    nonstd_classes = pos_invalids.drop(index=swapped_values.index.union(invalids.index))

    # Labels frequent enough (non-outliers) to be kept under the 'other' class.
    value_cnts = nonstd_classes['class'].value_counts()
    other_classes = value_cnts[value_cnts >= min_other_count].index.tolist()

    # Swap back only the rows that will survive; invalid ones are dropped below.
    swapped_indices = swapped_values.index.difference(invalids.index)
    if len(swapped_indices):
        cleaned_ds.loc[swapped_indices, ['sentence', 'class']] = cleaned_ds.loc[swapped_indices, ['class', 'sentence']].values

    # Relabel the frequent non-standard classes to 'other'.
    cleaned_ds.loc[cleaned_ds['class'].isin(other_classes), 'class'] = 'other'

    # Drop rows whose label is still unknown, plus every row flagged invalid (an
    # invalid row may carry an acceptable label but has no usable sentence).
    unknown_label_idx = cleaned_ds[~cleaned_ds['class'].isin(np.append(classes, 'other'))].index
    drop_indices = unknown_label_idx.union(invalids.index)
    cleaned_ds = cleaned_ds.drop(index=drop_indices)

    # Summary
    num_reclass = int((cleaned_ds['class'] == 'other').sum())
    logger.info(f"Re-classified {len(other_classes)} classes as 'other', or {num_reclass} rows ({pct(num_reclass)} %)")
    logger.info(f'Swapped class and sentence values of {swapped_indices.shape[0]} rows ({pct(swapped_indices.shape[0])} %)')
    logger.info(f'Dropped {drop_indices.shape[0]} invalid rows ({pct(drop_indices.shape[0])} %)')

    return cleaned_ds

# Trial run on the augmented-notes frame: the cleaned copy is shown as the cell output but not stored.
aug_preprocess(processed_datasets['aug_med_notes'])

2025-06-02 17:42:21,243 || INFO || MedRL-CoT Preprocess - Found 19968 rows
2025-06-02 17:42:21,246 || INFO || MedRL-CoT Preprocess - Fixed class naming for 21 rows (0.10516826923076923 %)
2025-06-02 17:42:21,272 || INFO || MedRL-CoT Preprocess - Re-classified 5 classes as 'other', or 77 rows (0.3856169871794872 %)
2025-06-02 17:42:21,274 || INFO || MedRL-CoT Preprocess - Swapped class and sentence values of 131 rows (0.6560496794871795 %)
2025-06-02 17:42:21,275 || INFO || MedRL-CoT Preprocess - Dropped 163 invalid rows (0.8163060897435898 %)


Unnamed: 0,sentence,class,case_id
0,"A sixteen year-old girl, presented to our Outp...",symptoms_labs,-1
1,She was not able to maintain an erect posture ...,symptoms_labs,-1
2,She would keep her head turned to the right an...,symptoms_labs,-1
3,There was a sideways bending of the back in th...,symptoms_labs,-1
4,To counter the abnormal positioning of the bac...,symptoms_labs,-1
...,...,...,...
19963,"No significant muscle atrophy was present, and...",symptoms_labs,912
19964,The physical examination showed that both of h...,symptoms_labs,915
19965,Both knees were very soft and could touch the ...,symptoms_labs,915
19966,"Upon palpation, the continuity of the quadrice...",symptoms_labs,915


In [339]:
# Dispatch table: dataset name -> the cleaning function to run on it.
proc_funcs = dict(mimic4=mimic_preprocess, aug_med_notes=aug_preprocess)

In [344]:
# Run every loaded dataset through its own cleaner, framing each run's log
# output with separator lines; results are collected under the dataset name.
preprocessed_datasets = {}
for key, item in processed_datasets.items():
    logger.info("=" * 50)
    logger.info(f"Cleaning up {key} dataset")
    preprocessed_datasets[key] = proc_funcs[key](item)
    logger.info("=" * 50)

2025-06-02 17:47:33,791 || INFO || MedRL-CoT Preprocess - Cleaning up aug_med_notes dataset
2025-06-02 17:47:33,802 || INFO || MedRL-CoT Preprocess - Found 19968 rows
2025-06-02 17:47:33,806 || INFO || MedRL-CoT Preprocess - Fixed class naming for 21 rows (0.10516826923076923 %)
2025-06-02 17:47:33,840 || INFO || MedRL-CoT Preprocess - Re-classified 5 classes as 'other', or 77 rows (0.3856169871794872 %)
2025-06-02 17:47:33,842 || INFO || MedRL-CoT Preprocess - Swapped class and sentence values of 131 rows (0.6560496794871795 %)
2025-06-02 17:47:33,842 || INFO || MedRL-CoT Preprocess - Dropped 163 invalid rows (0.8163060897435898 %)
2025-06-02 17:47:33,847 || INFO || MedRL-CoT Preprocess - Cleaning up mimic4 dataset
2025-06-02 17:47:33,850 || INFO || MedRL-CoT Preprocess - Found 29654 rows
2025-06-02 17:47:33,853 || INFO || MedRL-CoT Preprocess - Fixed class naming for 282 rows (0.9509678289606799 %)
2025-06-02 17:47:33,886 || INFO || MedRL-CoT Preprocess - Re-classified 18 classes as 

In [345]:
# Show which datasets now have a cleaned frame.
preprocessed_datasets.keys()

dict_keys(['aug_med_notes', 'mimic4'])