### Preprocessing for Mimic4 Dataset

**Note that the AUG dataset is incomplete, so we skip it here**

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from medrlcot.config.env import MedRL_CoT
from medrlcot import data_manager
from medrlcot.medrlcot_logger import setup_logger
from dotenv import load_dotenv
from datasets import Features, Value
from IPython.display import display
import numpy as np
import pandas as pd
import datasets as hf_datasets
import logging
import os
import json

# pd.set_option('display.max_colwidth', None)  
# pd.set_option('display.width', 0)    
# pd.set_option('display.max_rows', None)  

In [3]:
# Build the path to the project's .env file relative to the kernel's current
# working directory (so the cwd must be the directory that contains `medrlcot/`).
model_cfg_path = os.path.join(os.getcwd(), "medrlcot/config/.env")
medrlcot_config = MedRL_CoT(model_cfg_path)

# Initialise the project's logging setup, then fetch the named logger that the
# cleaning functions in this notebook write their summaries to.
setup_logger()
logger = logging.getLogger("MedRL-CoT Preprocess")

2025-06-03 02:44:17,528 || INFO || Logger - Setup for MedRL-CoT's log done. This is the beginning of the log.


Generated new log file logs/medrlcot092.log


In [4]:
# Map every configured dataset name to its `<cwd>/<data_dir>/<dataset>/processed`
# directory; a dict (rather than a list) keeps the dataset name attached to its path.
processed_dirs = {ds: os.path.join(os.getcwd(), medrlcot_config.data_dir, ds, 'processed') for ds in medrlcot_config.datasets}
processed_dirs.keys()

dict_keys(['aug_med_notes', 'mimic4'])

In [5]:
# The three sentence labels that the cleaning functions below accept as valid;
# any other value in the `class` column is treated as a candidate for repair or removal.
classes = np.array(['symptoms_labs', 'thought_process', 'diagnosis'])

In [6]:
def load_labeled(arrow_dir):
    """Load every ``.arrow`` shard found in ``arrow_dir`` and merge them into one dict.

    Shards are loaded in sorted filename order so that the row order of the
    result is reproducible between runs and machines; ``os.listdir`` on its own
    returns directory entries in arbitrary order.

    Args:
        arrow_dir: Directory to scan (non-recursively) for files whose name
            ends with ``.arrow``.

    Returns:
        The value of ``to_dict()`` called on the concatenation of all shards.

    Raises:
        ValueError: If ``arrow_dir`` contains no ``.arrow`` file.
    """
    arrow_paths = sorted(
        os.path.join(arrow_dir, fname)
        for fname in os.listdir(arrow_dir)
        if fname.endswith(".arrow")
    )
    if not arrow_paths:
        # Fail with the offending directory in the message instead of a generic
        # "empty list" error raised further down the stack.
        raise ValueError(f"No .arrow files found in {arrow_dir!r}")

    shards = [hf_datasets.Dataset.from_file(path) for path in arrow_paths]
    return hf_datasets.concatenate_datasets(shards).to_dict()

In [7]:
# Load the labelled shards of every dataset into a pandas DataFrame, keyed by dataset name.
processed_datasets = {key: pd.DataFrame(load_labeled(processed_dir)) for key, processed_dir in processed_dirs.items()}
processed_datasets.keys()

dict_keys(['aug_med_notes', 'mimic4'])

In [93]:
def mimic_preprocess(dataset, min_other_count=5):
    """Clean the sentence/class labels of the MIMIC-IV dataframe.

    Works on a copy of ``dataset``; the input frame is left untouched. Uses the
    notebook-level ``classes`` array (valid labels) and ``logger``.

    Steps, in order:
      1. Rename the misspelt labels ``symptoms_lbs`` / ``symptoms_lads`` to
         ``symptoms_labs``.
      2. Among rows whose class is not in ``classes``, find rows whose
         ``sentence`` is itself a class name (case-insensitive): the two
         columns were written the wrong way round, so they are swapped back and
         the recovered label is stored in lower case so it matches ``classes``.
      3. Relabel as ``'other'`` every remaining non-standard class that occurs
         at least ``min_other_count`` times.
      4. Drop every row whose class is still neither in ``classes`` nor ``'other'``.
      5. Cast ``case_id`` to ``int``.

    Non-standard rows whose class or sentence is missing or an empty/placeholder
    token (``''``, ``'0'``, ``'__'``, ``'none'`` in any letter case) are neither
    swapped nor relabelled, so step 4 removes them.

    Args:
        dataset: DataFrame with ``sentence``, ``class`` and ``case_id`` columns.
        min_other_count: Minimum number of occurrences a non-standard class
            needs to be kept under the ``'other'`` label (default 5).

    Returns:
        The cleaned DataFrame; surviving rows keep their original index labels.
    """
    cleaned_ds = dataset.copy()

    # For logging purposes
    N = cleaned_ds.shape[0]

    def pct(count):
        # An empty frame would otherwise raise ZeroDivisionError in the log lines.
        return (count / N) * 100 if N else 0.0

    logger.info(f"Found {N} rows")
    typo_fixes = {'symptoms_lbs': 'symptoms_labs', 'symptoms_lads': 'symptoms_labs'}
    num_renames = int(cleaned_ds['class'].isin(list(typo_fixes)).sum())
    logger.info(f"Fixed class naming for {num_renames} rows ({pct(num_renames)} %)")

    # Fix naming of some classes
    cleaned_ds['class'] = cleaned_ds['class'].replace(typo_fixes)

    # Every row with a non-standard class is a candidate for repair or removal.
    pos_invalids = cleaned_ds[~cleaned_ds['class'].isin(classes)]
    class_lower = pos_invalids['class'].str.lower()
    sentence_lower = pos_invalids['sentence'].str.lower()

    # Rows whose sentence column actually holds a class name (values swapped).
    is_swapped = sentence_lower.isin(classes)

    # Clearly invalid rows. The placeholder tokens are written in lower case
    # because they are compared against lower-cased text; missing values
    # (NaN after .str.lower()) count as invalid too.
    is_invalid = (
        class_lower.isna()
        | sentence_lower.isna()
        | class_lower.isin(['', '0', '__', 'none'])
        | sentence_lower.isin(['', '__', 'none'])
    )

    # Only swap rows that yield a usable sentence; invalid ones fall through to the final drop.
    swapped_indices = pos_invalids[is_swapped & ~is_invalid].index
    nonstd_classes = pos_invalids[~is_swapped & ~is_invalid]   # rows with a genuine non-standard class

    # Non-standard classes that occur often enough (non-outliers) become "other".
    value_cnts = nonstd_classes['class'].value_counts()
    other_classes = value_cnts[value_cnts >= min_other_count].index.tolist()
    other_class_indices = nonstd_classes[nonstd_classes['class'].isin(other_classes)].index

    # Clean the dataset
    if len(swapped_indices):
        # Capture the misplaced sentences before the class column is overwritten.
        misplaced_sentences = cleaned_ds.loc[swapped_indices, 'class'].copy()
        cleaned_ds.loc[swapped_indices, 'class'] = sentence_lower.loc[swapped_indices]
        cleaned_ds.loc[swapped_indices, 'sentence'] = misplaced_sentences
    if len(other_class_indices):
        cleaned_ds.loc[other_class_indices, 'class'] = 'other'

    # Drop everything that is still not in our list of classes + 'other' (basically all invalids).
    drop_indices = cleaned_ds[~cleaned_ds['class'].isin(np.append(classes, 'other'))].index
    cleaned_ds = cleaned_ds.drop(index=drop_indices)

    # Summary
    num_reclass = int((cleaned_ds['class'] == 'other').sum())
    logger.info(f"Re-classified {len(other_classes)} classes as 'other', or {num_reclass} rows ({pct(num_reclass)} %)")
    logger.info(f'Swapped class and sentence values of {swapped_indices.shape[0]} rows ({pct(swapped_indices.shape[0])} %)')
    logger.info(f'Dropped {drop_indices.shape[0]} invalid rows ({pct(drop_indices.shape[0])} %)')

    # change case_id into int
    cleaned_ds['case_id'] = cleaned_ds['case_id'].astype(int)

    return cleaned_ds

mimic_preprocess(processed_datasets['mimic4'])

2025-06-03 03:22:10,897 || INFO || Logger - Setup for MedRL-CoT's log done. This is the beginning of the log.
2025-06-03 03:22:10,902 || INFO || MedRL-CoT Preprocess - Found 29654 rows
2025-06-03 03:22:10,905 || INFO || MedRL-CoT Preprocess - Fixed class naming for 282 rows (0.9509678289606799 %)
2025-06-03 03:22:10,924 || INFO || MedRL-CoT Preprocess - Re-classified 18 classes as 'other', or 1073 rows (3.618398866931948 %)
2025-06-03 03:22:10,926 || INFO || MedRL-CoT Preprocess - Swapped class and sentence values of 571 rows (1.925541242328185 %)
2025-06-03 03:22:10,927 || INFO || MedRL-CoT Preprocess - Dropped 169 invalid rows (0.5699062521076415 %)


Generated new log file logs/medrlcot095.log


Unnamed: 0,sentence,class,case_id
0,History of Present Illness:,symptoms_labs,-1
1,Pt reports self-discontinuing lasix and spirno...,symptoms_labs,-1
2,She does not follow Na-restricted diets.,symptoms_labs,-1
3,"In the past week, she notes that she has been ...",symptoms_labs,-1
4,"She denies ___ edema, or SOB, or orthopnea.",symptoms_labs,-1
...,...,...,...
29649,Please call cardiac surgery office with any qu...,diagnosis,283
29650,Patient presented with a chief complaint of ch...,symptoms_labs,283
29651,The ECG showed no signs of ischemia.,symptoms_labs,283
29652,The patient's history of hypertension and hype...,thought_process,283


In [94]:
def aug_preprocess(dataset):
    """Clean the sentence/class labels of the augmented medical-notes dataframe.

    Works on a copy of ``dataset``; the input frame is left untouched. Uses the
    notebook-level ``classes`` array (valid labels) and ``logger``.

    Steps, in order:
      1. Rename the misspelt labels ``symptoms_lbs`` / ``symptoms_lads`` to
         ``symptoms_labs``.
      2. Among rows whose class is not in ``classes``, find rows whose
         ``sentence`` is exactly a class name: the two columns were written the
         wrong way round, so they are swapped back. Rows whose class column
         holds a missing value or a placeholder token (``''``, ``'0'``,
         ``'__'``, ``'none'``, ``'[]'``, ``'false'``, ``'the'``, ``'no'`` in any
         letter case) are not swapped, because the swap would produce a
         meaningless sentence.
      3. Drop every row whose class is still not in ``classes``. Unlike the
         MIMIC routine, no ``'other'`` class is created here.
      4. Cast ``case_id`` to ``int``.

    Rows with a placeholder sentence, or a sentence containing
    ``'not a sentence'``, need no special handling: they only ever appear with
    a non-standard class here and are therefore removed by step 3.

    Args:
        dataset: DataFrame with ``sentence``, ``class`` and ``case_id`` columns.

    Returns:
        The cleaned DataFrame; surviving rows keep their original index labels.
    """
    cleaned_ds = dataset.copy()

    # For logging purposes
    N = cleaned_ds.shape[0]

    def pct(count):
        # An empty frame would otherwise raise ZeroDivisionError in the log lines.
        return (count / N) * 100 if N else 0.0

    logger.info(f"Found {N} rows")
    typo_fixes = {'symptoms_lbs': 'symptoms_labs', 'symptoms_lads': 'symptoms_labs'}
    num_renames = int(cleaned_ds['class'].isin(list(typo_fixes)).sum())
    logger.info(f"Fixed class naming for {num_renames} rows ({pct(num_renames)} %)")

    # Fix naming of some classes
    cleaned_ds['class'] = cleaned_ds['class'].replace(typo_fixes)

    # Every row with a non-standard class is a candidate for repair or removal.
    pos_invalids = cleaned_ds[~cleaned_ds['class'].isin(classes)]

    # Rows whose sentence column actually holds a class name (values swapped).
    is_swapped = pos_invalids['sentence'].isin(classes)

    # Placeholder tokens are written in lower case because they are compared
    # against lower-cased text; missing values (NaN after .str.lower()) count too.
    class_lower = pos_invalids['class'].str.lower()
    has_placeholder_class = class_lower.isna() | class_lower.isin(
        ['', '0', '__', 'none', '[]', 'false', 'the', 'no']
    )

    # Only swap rows that yield a usable sentence; the rest fall through to the drop below.
    swapped_indices = pos_invalids[is_swapped & ~has_placeholder_class].index

    if len(swapped_indices):
        # Capture the misplaced sentences before the class column is overwritten.
        misplaced_sentences = cleaned_ds.loc[swapped_indices, 'class'].copy()
        cleaned_ds.loc[swapped_indices, 'class'] = cleaned_ds.loc[swapped_indices, 'sentence']
        cleaned_ds.loc[swapped_indices, 'sentence'] = misplaced_sentences

    # Drop everything that is still not in our list of classes (basically all invalids).
    drop_indices = cleaned_ds[~cleaned_ds['class'].isin(classes)].index
    cleaned_ds = cleaned_ds.drop(index=drop_indices)

    # Summary
    logger.info(f'Swapped class and sentence values of {swapped_indices.shape[0]} rows ({pct(swapped_indices.shape[0])} %)')
    logger.info(f'Dropped {drop_indices.shape[0]} invalid rows ({pct(drop_indices.shape[0])} %)')

    # change case_id into int
    cleaned_ds['case_id'] = cleaned_ds['case_id'].astype(int)

    return cleaned_ds

aug_preprocess(processed_datasets['aug_med_notes'])

2025-06-03 03:22:11,032 || INFO || MedRL-CoT Preprocess - Found 19968 rows
2025-06-03 03:22:11,035 || INFO || MedRL-CoT Preprocess - Fixed class naming for 21 rows (0.10516826923076923 %)
2025-06-03 03:22:11,048 || INFO || MedRL-CoT Preprocess - Swapped class and sentence values of 131 rows (0.6560496794871795 %)
2025-06-03 03:22:11,049 || INFO || MedRL-CoT Preprocess - Dropped 240 invalid rows (1.201923076923077 %)


Unnamed: 0,sentence,class,case_id
0,"A sixteen year-old girl, presented to our Outp...",symptoms_labs,-1
1,She was not able to maintain an erect posture ...,symptoms_labs,-1
2,She would keep her head turned to the right an...,symptoms_labs,-1
3,There was a sideways bending of the back in th...,symptoms_labs,-1
4,To counter the abnormal positioning of the bac...,symptoms_labs,-1
...,...,...,...
19963,"No significant muscle atrophy was present, and...",symptoms_labs,912
19964,The physical examination showed that both of h...,symptoms_labs,915
19965,Both knees were very soft and could touch the ...,symptoms_labs,915
19966,"Upon palpation, the continuity of the quadrice...",symptoms_labs,915


In [95]:
# Dispatch table: dataset name -> the cleaning function the loop below applies to it.
proc_funcs = {'mimic4': mimic_preprocess, 'aug_med_notes': aug_preprocess}

In [96]:
# Run every loaded dataset through its own cleaning routine (looked up in
# `proc_funcs`), bracketing each dataset's log output with a separator line.
preprocessed_datasets = {}
for key, item in processed_datasets.items():
    logger.info("=" * 50)
    logger.info(f"Cleaning up {key} dataset")
    cleaner = proc_funcs[key]
    preprocessed_datasets[key] = cleaner(item)
    logger.info("=" * 50)

2025-06-03 03:22:11,448 || INFO || MedRL-CoT Preprocess - Cleaning up aug_med_notes dataset
2025-06-03 03:22:11,450 || INFO || MedRL-CoT Preprocess - Found 19968 rows
2025-06-03 03:22:11,452 || INFO || MedRL-CoT Preprocess - Fixed class naming for 21 rows (0.10516826923076923 %)
2025-06-03 03:22:11,466 || INFO || MedRL-CoT Preprocess - Swapped class and sentence values of 131 rows (0.6560496794871795 %)
2025-06-03 03:22:11,468 || INFO || MedRL-CoT Preprocess - Dropped 240 invalid rows (1.201923076923077 %)
2025-06-03 03:22:11,474 || INFO || MedRL-CoT Preprocess - Cleaning up mimic4 dataset
2025-06-03 03:22:11,476 || INFO || MedRL-CoT Preprocess - Found 29654 rows
2025-06-03 03:22:11,479 || INFO || MedRL-CoT Preprocess - Fixed class naming for 282 rows (0.9509678289606799 %)
2025-06-03 03:22:11,503 || INFO || MedRL-CoT Preprocess - Re-classified 18 classes as 'other', or 1073 rows (3.618398866931948 %)
2025-06-03 03:22:11,504 || INFO || MedRL-CoT Preprocess - Swapped class and sentence 

In [97]:
preprocessed_datasets.keys()

dict_keys(['aug_med_notes', 'mimic4'])

In [98]:
preprocessed_datasets['mimic4']['case_id'].unique()

array([ -1,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,  14,
        15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
        28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,
        41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,
        54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,
        67,  68,  69,  70,  71,  72,  73,  74,  75,  76,  77,  78,  79,
        80,  81,  82,  83,  84,  85,  86,  87,  88,  89,  90,  91,  92,
        93,  94,  95,  96,  97,  98,  99, 100, 101, 102, 103, 104, 105,
       106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118,
       119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131,
       132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144,
       145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157,
       158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170,
       171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 18

In [99]:
# def join_sentence_class(group):
#     return ' '.join(f"{row['sentence']} <{row['class']}>" for _, row in group.iterrows())
    
# # preprocessed_datasets['mimic4'].groupby('case_id').apply(join_sentence_class).reset_index(name='full_class_case').sort_values('case_id')

In [100]:
# cases_datasets = dict()
# for key, dataset in preprocessed_datasets.items():
#     cases_datasets[key] = dataset.groupby('case_id').apply(join_sentence_class).reset_index(name='full_class_case').sort_values('case_id')

In [101]:
# cases_datasets['aug_med_notes']

In [102]:
# cases_datasets['aug_med_notes']

In [103]:
preprocessed_datasets['aug_med_notes'][preprocessed_datasets['aug_med_notes']['class'] == 'other']

Unnamed: 0,sentence,class,case_id


In [104]:
def xy_split_processing(group):
    """Collapse one case's labelled sentences into a single (X, Y) text pair.

    Sentences labelled ``symptoms_labs`` or ``other`` form the input text ``X``,
    joined by single spaces. Every other sentence goes into the target text
    ``Y``, each rendered as ``"<sentence> <<class>> "`` (note the trailing
    space) and joined by single spaces.

    Args:
        group: DataFrame holding the rows of one case, with ``sentence`` and
            ``class`` columns.

    Returns:
        A two-element Series with keys ``'X'`` and ``'Y'``.
    """
    is_input = group['class'].isin(['symptoms_labs', 'other'])

    input_sentences = group.loc[is_input, 'sentence']
    target_rows = group.loc[~is_input, ['sentence', 'class']]

    X_case = ' '.join(map(str, input_sentences))
    Y_case = ' '.join(
        f"{sentence} <{label}> "
        for sentence, label in zip(target_rows['sentence'], target_rows['class'])
    )

    return pd.Series({'X': X_case, 'Y': Y_case})

# NOTE(review): this rebinds `preprocessed_datasets`, replacing the frames built by
# the notebook-local cleaning functions above with whatever the packaged
# `mp.preprocess_datasets()` returns — confirm this is intended.
import medrlcot.preprocessing as mp
preprocessed_datasets = mp.preprocess_datasets()

# Collapse each case's sentences into one (X, Y) row to serve as SFT examples;
# `case_id` is only needed for grouping, so it is dropped from the result.
cases_datasets = dict()
for key, dataset in preprocessed_datasets.items():
    cases_datasets[key] = dataset.groupby('case_id').apply(xy_split_processing).reset_index().drop(columns=['case_id'])

# NOTE(review): `os` and `data_manager` are already imported in the notebook's
# import cell, and `model_cfg_path` / `medrlcot_config` are rebuilt here exactly
# as in the config cell — presumably so this cell can run on its own.
import os
from medrlcot import data_manager
import medrlcot.config.env as mce
model_cfg_path = os.path.join(os.getcwd(), "medrlcot/config/.env")
medrlcot_config = mce.MedRL_CoT(model_cfg_path)
raw_datasets = data_manager.load_datasets(medrlcot_config.datasets, data_dir=medrlcot_config.data_dir)  # Load raw dataset (original cases) for RM

# Create custom dataset obj (disabled here; the next cell builds it as `test`)
# cases_data = data_manager.Dataset(cases_datasets)

2025-06-03 03:22:14,687 || INFO || MedRL-CoT Preprocess - Cleaning up aug_med_notes dataset
2025-06-03 03:22:14,689 || INFO || MedRL-CoT Preprocess - Found 19968 rows
2025-06-03 03:22:14,691 || INFO || MedRL-CoT Preprocess - Fixed class naming for 21 rows (0.10516826923076923 %)
2025-06-03 03:22:14,705 || INFO || MedRL-CoT Preprocess - Re-classified 5 classes as 'other', or 76 rows (0.38060897435897434 %)
2025-06-03 03:22:14,707 || INFO || MedRL-CoT Preprocess - Swapped class and sentence values of 131 rows (0.6560496794871795 %)
2025-06-03 03:22:14,707 || INFO || MedRL-CoT Preprocess - Dropped 164 invalid rows (0.8213141025641026 %)
2025-06-03 03:22:14,711 || INFO || MedRL-CoT Preprocess - Cleaning up mimic4 dataset
2025-06-03 03:22:14,713 || INFO || MedRL-CoT Preprocess - Found 29654 rows
2025-06-03 03:22:14,716 || INFO || MedRL-CoT Preprocess - Fixed class naming for 282 rows (0.9509678289606799 %)
2025-06-03 03:22:14,735 || INFO || MedRL-CoT Preprocess - Swapped class and sentence 

In [125]:
# Build the project's Dataset wrapper from the per-case X/Y frames.
test = data_manager.Dataset(cases_datasets)
# Sizes of the pieces stored in the first entry of the 'train' split.
# NOTE(review): the layout of `test.data` is defined in data_manager, not in this notebook — confirm.
print([len(t) for t in test.data['train'][0]])
# Stack those pieces into one object with a fresh 0..n-1 index for display.
pd.concat(test.data['train'][0], ignore_index=True)

2025-06-03 03:58:34,075 || INFO || DataManager - Using train-val split: [0.75, 0.25]
2025-06-03 03:58:34,076 || INFO || DataManager - Splitting aug_med_notes dataset.
2025-06-03 03:58:34,079 || INFO || DataManager - Splitting mimic4 dataset.
2025-06-03 03:58:34,082 || INFO || DataManager - Creating a single joint dataset of dict_keys(['aug_med_notes', 'mimic4']) dataset splits


[561, 211]


0      The goal of no headache more than twice a week...
1      Local anesthesia containing 1.8 mL lidocaine a...
2      A 65-year-old male, without any comorbidities ...
3      On examination there was firm, bony expansion ...
4      Five years later, a CT scan indicated that the...
                             ...                        
767    PMH: Chronic Diastolic Heart Failure, Aortic s...
768    She reports that she swallowed an underwire of...
769    Chief Complaint: chest pain, SOB After dischar...
770    Chief Complaint: Major Surgical or Invasive Pr...
771    History of Present Illness: Patient is a ___ y...
Name: X, Length: 772, dtype: object

In [26]:
cases_datasets['aug_med_notes']

Unnamed: 0,X,Y
0,"A sixteen year-old girl, presented to our Outp...",The introduction and subsequent withdrawal of ...
1,Her past medical history included mild Multipl...,"A 34 year old Persian woman, gravida 1, para 0..."
2,A 60-year-old female who was previously health...,She was treated with thrombolysis and endovasc...
3,,classification <diagnosis>
4,The chief complains included bilateral hips an...,Staged bilateral total hip arthroplasties were...
...,...,...
743,The patient was a 60-year-old man who referred...,A 45 mm × 37 mm pseudoaneurysm in lateral side...
744,A 38 year old Vietnamese man was admitted with...,The working diagnosis was a collection seconda...
745,A 30-year-old woman was admitted to our instit...,She had undergone endovascular trapping of the...
746,A 51-year-old hypertensive Pakistani male pati...,The patient underwent general anesthesia for t...


In [110]:
# NOTE(review): `torch` is never referenced by name in this cell; the tokenizer
# call only asks for PyTorch tensors through return_tensors='pt'.
import torch
from transformers import AutoTokenizer

# Sanity check: tokenize every X text of the augmented-notes cases, padding the
# batch to a common length and truncating over-long inputs.
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
tokenizer(list(cases_datasets['aug_med_notes']['X']), padding=True, truncation=True, return_tensors='pt')

{'input_ids': tensor([[ 101, 1037, 7032,  ..., 1037, 6578,  102],
        [ 101, 2014, 2627,  ...,    0,    0,    0],
        [ 101, 1037, 3438,  ...,    0,    0,    0],
        ...,
        [ 101, 1037, 2382,  ...,    0,    0,    0],
        [ 101, 1037, 4868,  ...,    0,    0,    0],
        [ 101, 2019, 2324,  ...,    0,    0,    0]]), 'token_type_ids': tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]])}

In [108]:
cases_datasets['aug_med_notes']['X']

0      A sixteen year-old girl, presented to our Outp...
1      Her past medical history included mild Multipl...
2      A 60-year-old female who was previously health...
3                                                       
4      The chief complains included bilateral hips an...
                             ...                        
743    The patient was a 60-year-old man who referred...
744    A 38 year old Vietnamese man was admitted with...
745    A 30-year-old woman was admitted to our instit...
746    A 51-year-old hypertensive Pakistani male pati...
747    An 18-year-old male patient (height 140 cm, we...
Name: X, Length: 748, dtype: object