### Preprocessing for Mimic4 Dataset

**Note that the AUG dataset is incomplete, so we skip it here**

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from medrlcot.config.env import MedRL_CoT
from medrlcot import data_manager
from medrlcot.medrlcot_logger import setup_logger
from dotenv import load_dotenv
from datasets import Features, Value
import torch
from IPython.display import display
import numpy as np
import pandas as pd
import datasets as hf_datasets
import logging
import os
import json

# pd.set_option('display.max_colwidth', None)  
# pd.set_option('display.width', 0)    
# pd.set_option('display.max_rows', None)  

In [3]:
model_cfg_path = os.path.join(os.getcwd(), "medrlcot/config/.env")
medrlcot_config = MedRL_CoT(model_cfg_path)

setup_logger()
logger = logging.getLogger("MedRL-CoT Preprocess")

2025-06-04 01:58:52,214 || INFO || Logger - Setup for MedRL-CoT's log done. This is the beginning of the log.


Generated new log file logs/medrlcot144.log


In [4]:
# processed_dirs = [os.path.join(os.getcwd(), medrlcot_config.data_dir, ds, 'processed') for ds in medrlcot_config.datasets]
processed_dirs = {ds: os.path.join(os.getcwd(), medrlcot_config.data_dir, ds, 'processed') for ds in medrlcot_config.datasets}
processed_dirs.keys()

dict_keys(['aug_med_notes', 'mimic4'])

In [5]:
classes = np.array(['symptoms_labs', 'thought_process', 'diagnosis'])

In [6]:
def load_labeled(arrow_dir):
    arrows = [os.path.join(arrow_dir, f) for f in os.listdir(arrow_dir) if f.endswith(".arrow")]
    processed_dataset = hf_datasets.concatenate_datasets([hf_datasets.Dataset.from_file(arrow) for arrow in arrows]).to_dict()

    return processed_dataset

In [7]:
processed_datasets = {key: pd.DataFrame(load_labeled(processed_dir)) for key, processed_dir in processed_dirs.items()}
processed_datasets.keys()

dict_keys(['aug_med_notes', 'mimic4'])

In [8]:
def mimic_preprocess(dataset):
    cleaned_ds = dataset.copy()

    # For loggin purposes
    N = cleaned_ds.shape[0]
    logger.info(f"Found {N} rows")
    num_renames = cleaned_ds[cleaned_ds['class'].isin(['symptoms_lbs', 'symptoms_lads'])].shape[0]
    logger.info(f"Fixed class naming for {num_renames} rows ({(num_renames / N)*100} %)")
    
    # Fix naming of some classes
    cleaned_ds['class'] = cleaned_ds['class'].replace({'symptoms_lbs': 'symptoms_labs', 'symptoms_lads': 'symptoms_labs'})
    
    pos_invalids = cleaned_ds[~cleaned_ds['class'].isin(classes)]
    swapped_values = pos_invalids[pos_invalids['sentence'].str.lower().isin(classes)]    # Rows with swapped values
    
    # Clearly invalids, temp drop to remove from our ceaned list
    invalid_classes = pos_invalids[pos_invalids['class'].str.lower().isin(['', '0', '__', 'None'])]  # collect empty sentence and classes (Note that doing it here will catch the invalid swapped sentences as well)
    invalid_sentences = pos_invalids[pos_invalids['sentence'].str.lower().isin(['', '__', 'None'])]
    invalids = pos_invalids.loc[invalid_classes.index.union(invalid_sentences.index)]
    nonstd_classes = pos_invalids.drop(index=swapped_values.index.union(invalids.index))   # Get list of non-standard classes

    # Get list of classes that can be classified as "other" with enough occurence (non-outliery)
    # value_cnts = nonstd_classes['class'].value_counts()
    # other_classes = value_cnts[value_cnts >= 5].index.tolist()
    value_cnts = nonstd_classes['class'].value_counts()
    other_classes = value_cnts[value_cnts >= 5].index.tolist()
    other_class_indices = nonstd_classes[nonstd_classes['class'].isin(other_classes)].index
    # print(value_cnts[value_cnts >= 5])
    # display(cleaned_ds[cleaned_ds['class'] == 'past_surgical_history'])
    # display(cleaned_ds[cleaned_ds['class'] == 'followup_instructions'])
    # display(cleaned_ds[cleaned_ds['class'] == 'demographic_data'])
    # display(cleaned_ds[cleaned_ds['class'] == 'past_surgical_history'])
    
    # Clean the dataset
    swapped_indices = swapped_values.index
    cleaned_ds.loc[swapped_indices, ['sentence', 'class']] = cleaned_ds.loc[swapped_indices, ['class', 'sentence']].values # swap the values in indices where it's swapped
    cleaned_ds.loc[other_class_indices, 'class'] = 'other'
    # cleaned_ds['class'] = cleaned_ds['class'].apply(lambda x: 'other' if x in other_classes else x)  # relabel non-standards to 'other'
    drop_indices = cleaned_ds[~cleaned_ds['class'].isin(np.append(classes, 'other'))].index
    cleaned_ds = cleaned_ds.drop(index=drop_indices) # drop all others that aren't in our list of classes + 'other'  (basically all invalids)

    # Summary
    num_reclass = cleaned_ds[cleaned_ds['class'] == 'other'].shape[0]
    logger.info(f"Re-classified {len(other_classes)} classes as 'other', or {num_reclass} rows ({(num_reclass / N)*100} %)")
    logger.info(f'Swapped class and sentence values of {swapped_indices.shape[0]} rows ({(swapped_indices.shape[0] / N)*100} %)')
    logger.info(f'Dropped {drop_indices.shape[0]} invalid rows ({(drop_indices.shape[0] / N)*100} %)')

    # change case_id into int
    cleaned_ds['case_id'] = cleaned_ds['case_id'].astype(int)

    return cleaned_ds

mimic_preprocess(processed_datasets['mimic4'])['class'].value_counts()

2025-06-04 01:58:52,634 || INFO || MedRL-CoT Preprocess - Found 29654 rows
2025-06-04 01:58:52,637 || INFO || MedRL-CoT Preprocess - Fixed class naming for 282 rows (0.9509678289606799 %)
2025-06-04 01:58:52,654 || INFO || MedRL-CoT Preprocess - Re-classified 18 classes as 'other', or 1073 rows (3.618398866931948 %)
2025-06-04 01:58:52,655 || INFO || MedRL-CoT Preprocess - Swapped class and sentence values of 571 rows (1.925541242328185 %)
2025-06-04 01:58:52,656 || INFO || MedRL-CoT Preprocess - Dropped 169 invalid rows (0.5699062521076415 %)


class
symptoms_labs      16680
diagnosis           9112
thought_process     2620
other               1073
Name: count, dtype: int64

In [9]:
def aug_preprocess(dataset):
    cleaned_ds = dataset.copy()

    # For loggin purposes
    N = cleaned_ds.shape[0]
    logger.info(f"Found {N} rows")
    num_renames = cleaned_ds[cleaned_ds['class'].isin(['symptoms_lbs', 'symptoms_lads'])].shape[0]
    logger.info(f"Fixed class naming for {num_renames} rows ({(num_renames / N)*100} %)")
    
    # Fix naming of some classes
    cleaned_ds['class'] = cleaned_ds['class'].replace({'symptoms_lbs': 'symptoms_labs', 'symptoms_lads': 'symptoms_labs'})
    
    pos_invalids = cleaned_ds[~cleaned_ds['class'].isin(classes)]
    swapped_values = pos_invalids[pos_invalids['sentence'].isin(classes)]    # Rows with swapped values
    invalid_classes = pos_invalids[pos_invalids['class'].str.lower().isin(['', '0', '__', 'None', '[]', 'False', 'The', 'No'])]  # collect empty sentence and classes (Note that doing it here will catch the invalid swapped sentences as well)
    ignore_classes = pos_invalids[pos_invalids['sentence'].str.contains('not a sentence')]
    invalid_sentences = pos_invalids[pos_invalids['sentence'].str.lower().isin(['', '__', 'None', '[]', 'False', '()'])]

    invalids = pos_invalids.loc[invalid_classes.index.union(invalid_sentences.index.union(ignore_classes.index))]
    # invalids = pos_invalids.loc[invalid_classes.index.union(invalid_sentences.index)]
    # invalids = pos_invalids[pos_invalids ['class'].isin(['', '0', '[]', 'False'])]   # Clearly invalids, temp drop to remove from our ceaned list
    nonstd_classes = pos_invalids.drop(index=swapped_values.index.union(invalids.index))   # Get list of non-standard class rows
    # print(ignore_classes.index[0] in list(nonstd_classes.index))
    # print(ignore_classes.index[0] in list(invalids.index))
    # print(nonstd_classes[nonstd_classes['sentence'].str.contains('not a sentence')])

    # Get list of classes that can be classified as "other" with enough occurence (non-outliery)
    # value_cnts = nonstd_classes['class'].value_counts()
    # other_classes = value_cnts[value_cnts >= 5].index.tolist()
    # other_class_rows = nonstd_classes[nonstd_classes['class'].isin(other_classes)]
    # other_class_rows = other_class_rows[~other_class_rows['sentence'].str.contains('thought_process')]
    # other_class_rows = other_class_rows[~other_class_rows['sentence'].str.contains('symptoms_labs')]
    # other_class_indices = other_class_rows[~other_class_rows['sentence'].str.contains('diagnosis')].index  # Note we get rid of this because bad classification when it should've been "diagnosis", also removes "diagnosis: " sentences
    # print(other_class_indices)
    # print(value_cnts[value_cnts >= 5])
    # display(cleaned_ds[cleaned_ds['class'] == 'symptoms_lbs'])
    # display(cleaned_ds[cleaned_ds['class'] == 'A'])
    # display(cleaned_ds[cleaned_ds['class'] == 'classification'])
    # display(cleaned_ds[cleaned_ds['class'] == 'No'])
    # display(cleaned_ds[cleaned_ds['class'] == 'The'])

    swapped_indices = swapped_values.index
    cleaned_ds.loc[swapped_indices, ['sentence', 'class']] = cleaned_ds.loc[swapped_indices, ['class', 'sentence']].values # swap the values in indices where it's swapped
    # cleaned_ds.loc[other_class_indices, 'class'] = 'other'
    # cleaned_ds['class'] = cleaned_ds['class'].apply(lambda x: 'other' if x in other_classes else x)  # relabel non-standards to 'other'
    drop_indices = cleaned_ds[~cleaned_ds['class'].isin(classes)].index
    # print(ignore_classes.index in list(drop_indices))
    # print(ignore_classes.index)
    # print(list(drop_indices))
    
    cleaned_ds = cleaned_ds.drop(index=drop_indices) # drop all others that aren't in our list of classes + 'other' (basically all invalids)

    # Summary
    num_reclass = cleaned_ds[cleaned_ds['class'] == 'other'].shape[0]
    # logger.info(f"Re-classified {len(other_classes)} classes as 'other', or {num_reclass} rows ({(num_reclass / N)*100} %)")
    logger.info(f'Swapped class and sentence values of {swapped_indices.shape[0]} rows ({(swapped_indices.shape[0] / N)*100} %)')
    logger.info(f'Dropped {drop_indices.shape[0]} invalid rows ({(drop_indices.shape[0] / N)*100} %)')

    # change case_id into int
    cleaned_ds['case_id'] = cleaned_ds['case_id'].astype(int)
    # print(cleaned_ds[cleaned_ds['sentence'].str.contains('not a sentence')])
    
    return cleaned_ds

aug_preprocess(processed_datasets['aug_med_notes'])

2025-06-04 01:58:52,708 || INFO || MedRL-CoT Preprocess - Found 19968 rows
2025-06-04 01:58:52,712 || INFO || MedRL-CoT Preprocess - Fixed class naming for 21 rows (0.10516826923076923 %)
2025-06-04 01:58:52,724 || INFO || MedRL-CoT Preprocess - Swapped class and sentence values of 131 rows (0.6560496794871795 %)
2025-06-04 01:58:52,725 || INFO || MedRL-CoT Preprocess - Dropped 240 invalid rows (1.201923076923077 %)


Unnamed: 0,sentence,class,case_id
0,"A sixteen year-old girl, presented to our Outp...",symptoms_labs,-1
1,She was not able to maintain an erect posture ...,symptoms_labs,-1
2,She would keep her head turned to the right an...,symptoms_labs,-1
3,There was a sideways bending of the back in th...,symptoms_labs,-1
4,To counter the abnormal positioning of the bac...,symptoms_labs,-1
...,...,...,...
19963,"No significant muscle atrophy was present, and...",symptoms_labs,912
19964,The physical examination showed that both of h...,symptoms_labs,915
19965,Both knees were very soft and could touch the ...,symptoms_labs,915
19966,"Upon palpation, the continuity of the quadrice...",symptoms_labs,915


In [10]:
# Testing
proc_funcs = {'mimic4': mimic_preprocess, 'aug_med_notes': aug_preprocess}

In [11]:
preprocessed_datasets = dict()
for key, item in processed_datasets.items():
    logger.info("=" * 50)
    logger.info(f"Cleaning up {key} dataset")
    # processed_datasets[key] = mimic_preprocess(processed_datasets[key])
    # mimic_preprocess(processed_datasets[key])
    preprocessed_datasets[key] = proc_funcs[key](processed_datasets[key])
    logger.info("=" * 50)

2025-06-04 01:58:52,815 || INFO || MedRL-CoT Preprocess - Cleaning up aug_med_notes dataset
2025-06-04 01:58:52,817 || INFO || MedRL-CoT Preprocess - Found 19968 rows
2025-06-04 01:58:52,819 || INFO || MedRL-CoT Preprocess - Fixed class naming for 21 rows (0.10516826923076923 %)
2025-06-04 01:58:52,831 || INFO || MedRL-CoT Preprocess - Swapped class and sentence values of 131 rows (0.6560496794871795 %)
2025-06-04 01:58:52,832 || INFO || MedRL-CoT Preprocess - Dropped 240 invalid rows (1.201923076923077 %)
2025-06-04 01:58:52,836 || INFO || MedRL-CoT Preprocess - Cleaning up mimic4 dataset
2025-06-04 01:58:52,838 || INFO || MedRL-CoT Preprocess - Found 29654 rows
2025-06-04 01:58:52,841 || INFO || MedRL-CoT Preprocess - Fixed class naming for 282 rows (0.9509678289606799 %)
2025-06-04 01:58:52,858 || INFO || MedRL-CoT Preprocess - Re-classified 18 classes as 'other', or 1073 rows (3.618398866931948 %)
2025-06-04 01:58:52,860 || INFO || MedRL-CoT Preprocess - Swapped class and sentence 

In [12]:
preprocessed_datasets.keys()

dict_keys(['aug_med_notes', 'mimic4'])

In [13]:
preprocessed_datasets['mimic4']['case_id'].unique()

array([ -1,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,  14,
        15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
        28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,
        41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,
        54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,
        67,  68,  69,  70,  71,  72,  73,  74,  75,  76,  77,  78,  79,
        80,  81,  82,  83,  84,  85,  86,  87,  88,  89,  90,  91,  92,
        93,  94,  95,  96,  97,  98,  99, 100, 101, 102, 103, 104, 105,
       106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118,
       119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131,
       132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144,
       145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157,
       158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170,
       171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 18

In [14]:
# def join_sentence_class(group):
#     return ' '.join(f"{row['sentence']} <{row['class']}>" for _, row in group.iterrows())
    
# # preprocessed_datasets['mimic4'].groupby('case_id').apply(join_sentence_class).reset_index(name='full_class_case').sort_values('case_id')

In [15]:
# cases_datasets = dict()
# for key, dataset in preprocessed_datasets.items():
#     cases_datasets[key] = dataset.groupby('case_id').apply(join_sentence_class).reset_index(name='full_class_case').sort_values('case_id')

In [16]:
# cases_datasets['aug_med_notes']

In [17]:
# cases_datasets['aug_med_notes']

In [18]:
# preprocessed_datasets['aug_med_notes'][preprocessed_datasets['aug_med_notes']['class'] == 'other']

In [19]:
def xy_split_processing_sft(group, x_func=None, y_func=None):
    X = []
    Y = []
    
    for _, row in group.iterrows():
        if row['class'] == 'symptoms_labs' or row['class'] == 'other':
            X.append(row)
        else:
            Y.append(row)

    # X_case = ' '.join(f"{row['sentence']} <{row['class']}>" for _, row in group.iterrows()) # Input with all
    X_case = x_func(X) if x_func else ' '.join([str(row['sentence']) for row in X])
    Y_case = ' '.join([f"{row['sentence']} <{row['class']}> " for row in Y])    # Output with only thought_process and diagnosis
    
    X_prompt = f"""Below is a clinical case. Your task is to provide a step-by-step clinical reasoning followed by the diagnosis.

    {X_case}
    """
    
    return pd.Series({'X': X_prompt, 'Y': Y_case})

import medrlcot.preprocessing as mp
preprocessed_datasets = mp.preprocess_datasets()

# Combine cases into one for cases as example for SFT
cases_datasets = dict()
for key, dataset in preprocessed_datasets.items():
    cases_datasets[key] = dataset.groupby('case_id').apply(mp.xy_split_processing_sft).reset_index().sort_values('case_id')
    cases_datasets[key]['wc_x'] = cases_datasets[key]['X'].apply(lambda x: len(x.split())) 
    cases_datasets[key]['wc_y'] = cases_datasets[key]['Y'].apply(lambda x: len(x.split())) 
    cases_datasets[key] = cases_datasets[key][(cases_datasets[key]['wc_x'] <= 400) & (cases_datasets[key]['wc_y'] <= 200)].reset_index(drop=True)

import os
from medrlcot import data_manager
import medrlcot.config.env as mce
model_cfg_path = os.path.join(os.getcwd(), "medrlcot/config/.env")
medrlcot_config = mce.MedRL_CoT(model_cfg_path)
raw_datasets = data_manager.load_datasets(medrlcot_config.datasets, data_dir=medrlcot_config.data_dir)  # Load raw dataset (original cases) for RM

2025-06-04 01:58:53,237 || INFO || Logger - Setup for MedRL-CoT's log done. This is the beginning of the log.


Generated new log file logs/medrlcot145.log


2025-06-04 01:58:53,444 || INFO || MedRL-CoT Preprocess - Cleaning up aug_med_notes dataset
2025-06-04 01:58:53,446 || INFO || MedRL-CoT Preprocess - Found 19968 rows
2025-06-04 01:58:53,448 || INFO || MedRL-CoT Preprocess - Fixed class naming for 21 rows (0.10516826923076923 %)
2025-06-04 01:58:53,471 || INFO || MedRL-CoT Preprocess - Re-classified 5 classes as 'other', or 76 rows (0.38060897435897434 %)
2025-06-04 01:58:53,472 || INFO || MedRL-CoT Preprocess - Swapped class and sentence values of 131 rows (0.6560496794871795 %)
2025-06-04 01:58:53,473 || INFO || MedRL-CoT Preprocess - Dropped 224 invalid rows (1.1217948717948718 %)
2025-06-04 01:58:53,476 || INFO || MedRL-CoT Preprocess - Cleaning up mimic4 dataset
2025-06-04 01:58:53,478 || INFO || MedRL-CoT Preprocess - Found 29654 rows
2025-06-04 01:58:53,480 || INFO || MedRL-CoT Preprocess - Fixed class naming for 282 rows (0.9509678289606799 %)
2025-06-04 01:58:53,504 || INFO || MedRL-CoT Preprocess - Swapped class and sentence 

In [20]:
cases_datasets['aug_med_notes']

Unnamed: 0,case_id,X,Y,wc_x,wc_y
0,100,Provide a step-by-step clinical reasoning foll...,"Thought Process: A 34 year old Persian woman, ...",159,168
1,101,Provide a step-by-step clinical reasoning foll...,Thought Process: She was treated with thrombol...,180,154
2,102,Provide a step-by-step clinical reasoning foll...,Thought Process: \n\n Diagnosis: \n,12,3
3,103,Provide a step-by-step clinical reasoning foll...,"Thought Process: In the past medical history, ...",374,155
4,104,Provide a step-by-step clinical reasoning foll...,Thought Process: The initial workup was unreve...,248,109
...,...,...,...,...,...
512,915,Provide a step-by-step clinical reasoning foll...,Thought Process: \n\n Diagnosis: \n,87,3
513,93,Provide a step-by-step clinical reasoning foll...,Thought Process: The patient initially refused...,329,48
514,94,Provide a step-by-step clinical reasoning foll...,Thought Process: The working diagnosis was a c...,147,49
515,97,Provide a step-by-step clinical reasoning foll...,Thought Process: \n\n Diagnosis: \n,287,3


In [21]:
preprocessed_datasets['aug_med_notes']['class'].value_counts()

class
symptoms_labs      9977
thought_process    5318
diagnosis          4373
other                76
Name: count, dtype: int64

In [22]:
# filtered_datasets = dict()
# for key, dataset in cases_datasets.items():
#     dataset['wc_x'] = dataset['X'].apply(lambda x: len(x.split())) 
#     dataset['wc_y'] = dataset['Y'].apply(lambda x: len(x.split())) 
#     filtered_datasets[key] = dataset[(dataset['wc_x'] <= 800) & (dataset['wc_y'] <= 900)].reset_index(drop=True)

# print(cases_datasets['aug_med_notes'].shape, filtered_datasets['aug_med_notes'].shape)
# print(cases_datasets['mimic4'].shape, filtered_datasets['mimic4'].shape)

In [23]:
max_words = (None, 0)
next_max = (None, -1)
num_exceed = 0
for key, dataset in cases_datasets.items():
    for row in dataset.iterrows():
        cnt = ((key, row), len(row[1]['X'].split()))
        if cnt[1] > 750:
            num_exceed += 1
        if cnt[1] > max_words[1]:
            next_max = max_words
            max_words = cnt
        else:
            next_max = max(cnt, next_max, key=lambda x: x[1])
print(max_words, next_max)
num_exceed

(('aug_med_notes', (175, case_id                                                  383
X          Provide a step-by-step clinical reasoning foll...
Y                 Thought Process: \n\n    Diagnosis: \n    
wc_x                                                     399
wc_y                                                       3
Name: 175, dtype: object)), 399) (('mimic4', (31, case_id                                                  196
X          Provide a step-by-step clinical reasoning foll...
Y          Thought Process: ___ 05:25AM   estGFR-Using th...
wc_x                                                     399
wc_y                                                     116
Name: 31, dtype: object)), 399)


0

In [24]:
# Create customd dataset obj
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small", device_map="auto")
cases_data = data_manager.MedRL_CoT_Dataset(cases_datasets, seed=0, tokenizer=tokenizer)

2025-06-04 01:58:56,467 || INFO || DataManager - Using train-val split: [0.75, 0.25]
2025-06-04 01:58:56,468 || INFO || DataManager - Split-Shuffle with seed 0
2025-06-04 01:58:56,469 || INFO || DataManager - Splitting aug_med_notes dataset.
2025-06-04 01:58:56,471 || INFO || DataManager - Split into 387 train rows and 130 val rows
2025-06-04 01:58:56,472 || INFO || DataManager - Splitting mimic4 dataset.
2025-06-04 01:58:56,474 || INFO || DataManager - Split into 80 train rows and 27 val rows
2025-06-04 01:58:56,475 || INFO || DataManager - Creating a single joint dataset of dict_keys(['aug_med_notes', 'mimic4']) dataset splits
2025-06-04 01:58:56,476 || INFO || DataManager - Joined 467 train rows
2025-06-04 01:58:56,477 || INFO || DataManager - Shuffling train's rows
2025-06-04 01:58:56,479 || INFO || DataManager - Joined 157 val rows
2025-06-04 01:58:56,480 || INFO || DataManager - Shuffling val's rows
2025-06-04 01:58:56,481 || INFO || DataManager - Returning shuffled train-val spl

Index([], dtype='int64')
Index([], dtype='int64')


In [25]:
# cases_data['train'] = cases_data['train'].drop(cases_data['train'][cases_data['train']['Y'].str.strip() == ''].index)
# print(cases_data['train'][cases_data['train']['X'].str.strip() == ''].index)

In [26]:
cases_data['train']

Unnamed: 0,X,Y
0,Provide a step-by-step clinical reasoning foll...,Thought Process: Informed written consent was ...
1,Provide a step-by-step clinical reasoning foll...,Thought Process: Clinically a differential dia...
2,Provide a step-by-step clinical reasoning foll...,Thought Process: continue until INR >2.\n\n ...
3,Provide a step-by-step clinical reasoning foll...,Thought Process: Pt admitting to walking aroun...
4,Provide a step-by-step clinical reasoning foll...,Thought Process: \n\n Diagnosis: \n
...,...,...
462,Provide a step-by-step clinical reasoning foll...,Thought Process: He gave a history of having w...
463,Provide a step-by-step clinical reasoning foll...,Thought Process: Because of metastatic adenoca...
464,Provide a step-by-step clinical reasoning foll...,Thought Process: \n\n Diagnosis: \n
465,Provide a step-by-step clinical reasoning foll...,Thought Process: No other investigations were ...


In [27]:
cases_data['train'][cases_data['train']['Y'].str.strip() == '']

Unnamed: 0,X,Y


In [28]:
cases_data['train']

Unnamed: 0,X,Y
0,Provide a step-by-step clinical reasoning foll...,Thought Process: Informed written consent was ...
1,Provide a step-by-step clinical reasoning foll...,Thought Process: Clinically a differential dia...
2,Provide a step-by-step clinical reasoning foll...,Thought Process: continue until INR >2.\n\n ...
3,Provide a step-by-step clinical reasoning foll...,Thought Process: Pt admitting to walking aroun...
4,Provide a step-by-step clinical reasoning foll...,Thought Process: \n\n Diagnosis: \n
...,...,...
462,Provide a step-by-step clinical reasoning foll...,Thought Process: He gave a history of having w...
463,Provide a step-by-step clinical reasoning foll...,Thought Process: Because of metastatic adenoca...
464,Provide a step-by-step clinical reasoning foll...,Thought Process: \n\n Diagnosis: \n
465,Provide a step-by-step clinical reasoning foll...,Thought Process: No other investigations were ...


In [29]:
cases_data['train']['Y'].iloc[0]

'Thought Process: Informed written consent was taken from the patient\'s parent. The patient\'s parents were made aware of nature of disease and instructed to take possible precautions." -> We have kept him under observation because any surgical intervention of ossified muscle might lead to further deterioration of condition as it has been experienced by patient with two previous surgeries." ->\n\n    Diagnosis: \n    '

In [30]:
cases_data.get_dataloader('train')

<torch.utils.data.dataloader.DataLoader at 0x7fd9db3ca290>

In [31]:
cases_data.get_dataloader('val')

<torch.utils.data.dataloader.DataLoader at 0x7fd9db3cbdc0>

In [32]:
from transformers import AutoModelForSeq2SeqLM
model  = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")

In [33]:
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments, DataCollatorForSeq2Seq

training_args = Seq2SeqTrainingArguments(
    output_dir="./checkpoints",
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    predict_with_generate=True,
    # fp16=True,
    learning_rate=5e-5,
    num_train_epochs=3,
    save_strategy="epoch",
    logging_dir="./logs"
)

In [34]:
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=model,
    padding="longest",
    return_tensors="pt"
)

In [35]:
def compute_metrics(eval_preds): # https://www.datacamp.com/tutorial/flan-t5-tutorial
   preds, labels = eval_preds

   # decode preds and labels
   labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
   decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
   decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

   # rougeLSum expects newline after each sentence
   decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]
   decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels]

   result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
  
   return result

In [36]:
train_tok_dataset = cases_data.get_torch_dataset('train')
val_tok_dataset = cases_data.get_torch_dataset('val')

In [37]:
t = val_tok_dataset[80]
t['labels']

tensor([ 4229,    17, 10272,    10, 17656,    47, 21119, 10058,    12,  1130,
           95,    38,  6640,    38,     8,    16,    26,  2091,    53,     3,
         1462,   449,   138,     3,     7,  4669,   398,    36,  3641,    11,
           42,  2509,    26,     5, 10747,     7,    15,   638, 11706,    41,
         7171,   302,   342, 19049,    61,   164,    43,   118, 14327,    12,
         1792,   442,    18, 31058,  6900, 11537,   257,    42,  6900, 11537,
          257,  1341,    12,   169,    13,     3,    29,  4667,  9798,  1406,
        11208,     5,  2678, 14175,    15,     3,    99,  6044, 22429,    42,
        28582,  1344,     7,     5,   638, 11706,    19,     3,     9, 22429,
           18, 12369,    35,    49,     6,  4486,     3,     9,    50,   226,
         1528,     5,    37,  1868,    31,     7,   892,    13,  6676, 13177,
           11,  6676, 19437, 11658,    47,  1702,     5,  5267,  6715,     7,
          159,    10,     3,     1])

In [38]:
tokenizer.decode(t['input_ids'])

"Provide a step-by-step clinical reasoning followed by the diagnosis: History of Present Illness: This is a ___ year old female who presents with left flank pain x4 days. Patient has short term memory loss and is a poor historian, history obtained with assistance of patient's mother. Pain in the left flank started approximately 4 days ago associated with dysuria and nausea. Unsure if she has had fever, chills but currently denies f/c/cp/sob. Never had a stone before. Last PO last night. Social History: Family History: Physical Exam: GEN: NAD, resting comfortably, AAO HEENT: NCAT, EOMI, anicteric sclera PULM: nonlabored breathing, normal chest rise ABD: soft, NT, ND, no rebound/guarding EXT: WWP Pertinent Results: 04:50PM BLOOD WBC-9.1 RBC-3.96 Hgb-12.1 Hct-37.0 MCV-93 MCH-30.6 MCHC-32.7 RDW-13.4 RDWSD-45.9 Plt ___ 06:05AM BLOOD Glucose-102* UreaN-19 Creat-1.7* Na-142 K-3.8 Cl-105 HCO3-20* AnGap-17 Medications on Admission: The Preadmission Medication list may be inaccurate and requires

In [39]:
label_ids = t['labels']
# Remove masked (-100) values
cleaned_labels = [token_id for token_id in label_ids if token_id != -100]
decoded = tokenizer.decode(cleaned_labels, skip_special_tokens=True)
print(decoded)

Thought Process: Patient was explicitly advised to follow up as directed as the indwelling ureteral stent must be removed and or exchanged. False Colace (docusate sodium) may have been prescribed to avoid post-surgical constipation or constipation related to use of narcotic pain medications. Discontinue if loose stool or diarrhea develops. Colace is a stool-softener, NOT a laxative. The patient's history of hypertension and hyperlipidemia was considered. Diagnosis: 


In [40]:
inputs = val_tok_dataset[80]
input_ids = inputs["input_ids"].unsqueeze(0).to(model.device)
attention_mask = inputs["attention_mask"].unsqueeze(0).to(model.device)

with torch.no_grad():
    outputs = model.generate(input_ids, attention_mask=attention_mask, max_new_tokens=200)

decoded_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(decoded_output)

1. Acetaminophen 325-650 mg PO Q6H:PRN pain, fever 2. Atorvastatin 325-650 mg PO Q6H:PRN pain, fever 3. Atorvastatin 4 mg PO Q6H:PRN pain, fever 4. Gabapentin 5 mg PO BID 5. Multivitamins 1 TAB PO DAILY 6. Omeprazole 20 mg PO QPM 7. Topiramate (Topamax) 100 mg PO BID 8. Warfarin 2.5 mg PO 3X/WEEK (_____) 9. Warfarin 5 mg PO 4X/WEEK (________________________________________________________


In [41]:
# sample = val_tok_dataset[0]
# print(tokenizer.decode(sample["input_ids"]))|
# print(tokenizer.decode(sample['labels']))

In [42]:
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_tok_dataset,
    eval_dataset=val_tok_dataset,
    compute_metrics=compute_metrics,
)

In [43]:
import torch
sample = train_tok_dataset[0]
batch = {k: v.unsqueeze(0).to(model.device) for k, v in sample.items()}
with torch.no_grad():
    output = model(**batch)
print("Loss:", output.loss.item())


== Input Text ==
Provide a step-by-step clinical reasoning followed by the diagnosis: A 5-year-old male child came to the department of oral and maxillofacial surgery with a complaint of difficulty in opening the mouth for the past 2 years. The birth history was normal. No other member of the family was similarly affected. His mouth opening was normal until he met with a trauma with wooden thorn on the left cheek region. Thorn was removed by a surgeon immediately, but the patient had progressively reduced mouth opening since then. Computed tomography scan showed a radiodense mass extending in front of the anterior border of the ramus of the mandible, suggestive of ossified masseter muscle []. Magnetic resonance imaging revealed evidence of a large elongated T1–T2 intermediate signal intensity lesion with mild surrounding edema in the substance of left masseter muscle, abutting the ramus of left mandible likely to represent heterotopic bone formation (at the expected location of previou

Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


Loss: 3.3646023273468018


In [44]:
model.gradient_checkpointing_enable()

In [45]:
trainer.train()

  batch["labels"] = torch.tensor(batch["labels"], dtype=torch.int64)
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Step,Training Loss
500,2.9701
1000,2.6926


== Input Text ==
Provide a step-by-step clinical reasoning followed by the diagnosis: A 5-year-old male child came to the department of oral and maxillofacial surgery with a complaint of difficulty in opening the mouth for the past 2 years. The birth history was normal. No other member of the family was similarly affected. His mouth opening was normal until he met with a trauma with wooden thorn on the left cheek region. Thorn was removed by a surgeon immediately, but the patient had progressively reduced mouth opening since then. Computed tomography scan showed a radiodense mass extending in front of the anterior border of the ramus of the mandible, suggestive of ossified masseter muscle []. Magnetic resonance imaging revealed evidence of a large elongated T1–T2 intermediate signal intensity lesion with mild surrounding edema in the substance of left masseter muscle, abutting the ramus of left mandible likely to represent heterotopic bone formation (at the expected location of previou

TrainOutput(global_step=1401, training_loss=2.7813816547053443, metrics={'train_runtime': 276.8587, 'train_samples_per_second': 5.06, 'train_steps_per_second': 5.06, 'total_flos': 173532779882496.0, 'train_loss': 2.7813816547053443, 'epoch': 3.0})

In [46]:
# example_case = cases_data['val'].iloc[-1]

# val_input = example_case["X"]
# val_target = example_case["Y"]
# val_input

In [47]:
val_tok_dataset[80]

{'input_ids': tensor([ 7740,     3,     9,  1147,    18,   969,    18,  7910,  3739, 20893,
          2348,    57,     8,  8209,    10,  5528,    13, 18795,    27,   195,
           655,    10,   100,    19,     3,     9,     3,   834,   834,   834,
           215,   625,  3955,   113,  6621,    28,   646, 24397,  1406,     3,
           226,   591,   477,     5, 17656,    65,   710,  1657,  2594,  1453,
            11,    19,     3,     9,  2714, 18637,     6,   892,  5105,    28,
          2927,    13,  1868,    31,     7,  2039,     5, 19043,    16,     8,
           646, 24397,   708,  3241,   314,   477,   977,  1968,    28, 16633,
           459,     9,    11, 25808,     5,   597,  4334,     3,    99,   255,
            65,   141, 17055,     6, 10191,     7,    68,  1083,   177,   725,
             3,    89,    87,    75,    87,    75,   102,    87,     7,    32,
           115,     5,  8400,   141,     3,     9,  3372,   274,     5,  2506,
          9915,   336,   706,     5,  2

In [48]:
# tokenizer.decode(inputs['input_ids'])

In [49]:
tokenizer.decode(train_tok_dataset[0]['input_ids'])

== Input Text ==
Provide a step-by-step clinical reasoning followed by the diagnosis: A 5-year-old male child came to the department of oral and maxillofacial surgery with a complaint of difficulty in opening the mouth for the past 2 years. The birth history was normal. No other member of the family was similarly affected. His mouth opening was normal until he met with a trauma with wooden thorn on the left cheek region. Thorn was removed by a surgeon immediately, but the patient had progressively reduced mouth opening since then. Computed tomography scan showed a radiodense mass extending in front of the anterior border of the ramus of the mandible, suggestive of ossified masseter muscle []. Magnetic resonance imaging revealed evidence of a large elongated T1–T2 intermediate signal intensity lesion with mild surrounding edema in the substance of left masseter muscle, abutting the ramus of left mandible likely to represent heterotopic bone formation (at the expected location of previou

"Provide a step-by-step clinical reasoning followed by the diagnosis: A 5-year-old male child came to the department of oral and maxillofacial surgery with a complaint of difficulty in opening the mouth for the past 2 years. The birth history was normal. No other member of the family was similarly affected. His mouth opening was normal until he met with a trauma with wooden thorn on the left cheek region. Thorn was removed by a surgeon immediately, but the patient had progressively reduced mouth opening since then. Computed tomography scan showed a radiodense mass extending in front of the anterior border of the ramus of the mandible, suggestive of ossified masseter muscle []. Magnetic resonance imaging revealed evidence of a large elongated T1–T2 intermediate signal intensity lesion with mild surrounding edema in the substance of left masseter muscle, abutting the ramus of left mandible likely to represent heterotopic bone formation (at the expected location of previous surgery for th

In [50]:
inputs = train_tok_dataset[10]
input_ids = inputs["input_ids"].unsqueeze(0).to(model.device)
attention_mask = inputs["attention_mask"].unsqueeze(0).to(model.device)

with torch.no_grad():
    outputs = model.generate(input_ids, attention_mask=attention_mask, max_new_tokens=256)

decoded_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(decoded_output)



Though),


In [51]:
# example_case = """
# John Doe is a 40-year-old male who was involved in a motor vehicle accident. 
# He was a restrained passenger who was rear-ended by another car going approximately 45 mph. 
# He has no medical history and takes no medications at home. 
# He has no allergies. His primary physician is Dr. Johnson. He has been awake, alert, and oriented with a Glasgow Coma Score of 15 since arrival. 
# He is moving all extremities and has good strength bilaterally. 
# His chief complaint is neck pain, rated a 5 out of 10, and he remains in cervical-spine precautions until the trauma team clears him. 
# He is to be kept NPO until cleared by the trauma team as well. CT scans have been completed of the head, cervical spine, chest, abdomen, and pelvis. 
# We are pending reports from radiology. Last vital signs, 15 minutes ago, were temperature 36.6, pulse 80, respiratory rate 14, and blood pressure 120/80. Lung sounds are clear.
# The abdomen is soft and non-tender. The patient has two largebore IVs. The Right AC 18 gauge with 1 liter Lactated Ringers at TKO. 
# The left AC has been saline locked. The patient received fentanyl 50 mcg slow IVP, and stated relief, with pain now 1 out of 10. 
# The family is at the bedside, and the patient remains in good spirits.
# """