### Preprocessing for Mimic4 Dataset

**Note that the AUG dataset is incomplete, so we skip it here**

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from medrlcot.config.env import MedRL_CoT
from medrlcot import data_manager
from medrlcot.medrlcot_logger import setup_logger
from dotenv import load_dotenv
from datasets import Features, Value
import torch
from IPython.display import display
import numpy as np
import pandas as pd
import datasets as hf_datasets
import logging
import os
import json

# pd.set_option('display.max_colwidth', None)  
# pd.set_option('display.width', 0)    
# pd.set_option('display.max_rows', None)  

In [3]:
# Path to the project's .env file, resolved against the current working directory
# (presumably the notebook must be started from the repository root — TODO confirm).
model_cfg_path = os.path.join(os.getcwd(), "medrlcot/config/.env")
medrlcot_config = MedRL_CoT(model_cfg_path)

# Initialise the project logging, then fetch the named logger used by the cells below.
setup_logger()
logger = logging.getLogger("MedRL-CoT Preprocess")

2025-06-04 07:20:23,500 || INFO || Logger - Setup for MedRL-CoT's log done. This is the beginning of the log.


Generated new log file logs/medrlcot039.log


In [4]:
# Dataset name -> absolute path of that dataset's `processed` directory.
processed_dirs = dict(
    (dataset_name, os.path.join(os.getcwd(), medrlcot_config.data_dir, dataset_name, 'processed'))
    for dataset_name in medrlcot_config.datasets
)
processed_dirs.keys()

dict_keys(['aug_med_notes', 'mimic4'])

In [5]:
# The standard sentence labels; the preprocessing functions below test membership against this array.
classes = np.array(['symptoms_labs', 'thought_process', 'diagnosis'])

In [6]:
def load_labeled(arrow_dir):
    """Load every ``.arrow`` shard found directly in ``arrow_dir`` and merge them.

    Args:
        arrow_dir: Directory containing the ``.arrow`` files (sub-directories are not searched).

    Returns:
        dict: the result of ``Dataset.to_dict()`` on the concatenation of all shards.
    """
    # os.listdir() yields entries in arbitrary order; sorting makes the concatenation
    # order (and therefore the row positions) identical on every run and machine.
    arrows = sorted(os.path.join(arrow_dir, f) for f in os.listdir(arrow_dir) if f.endswith(".arrow"))
    shards = [hf_datasets.Dataset.from_file(arrow) for arrow in arrows]
    processed_dataset = hf_datasets.concatenate_datasets(shards).to_dict()

    return processed_dataset

In [7]:
# Load the labelled shards of every dataset into a DataFrame, keyed by dataset name.
processed_datasets = {key: pd.DataFrame(load_labeled(processed_dir)) for key, processed_dir in processed_dirs.items()}
processed_datasets.keys()

dict_keys(['aug_med_notes', 'mimic4'])

In [8]:
def mimic_preprocess(dataset, min_other_count=5):
    """Clean the sentence-level labels of the mimic4 frame.

    Works on a copy; ``dataset`` itself is left untouched. Steps:

    1. Rename the misspelled labels ``symptoms_lbs`` / ``symptoms_lads`` to ``symptoms_labs``.
    2. Swap ``sentence`` and ``class`` on rows whose ``sentence`` holds a standard label
       (matched case-insensitively; the recovered label is stored lower-cased).
    3. Relabel non-standard classes occurring at least ``min_other_count`` times as ``'other'``.
    4. Drop every row whose class is then neither a standard class nor ``'other'``.
    5. Cast ``case_id`` to ``int``.

    Args:
        dataset: DataFrame with ``sentence``, ``class`` and ``case_id`` columns.
        min_other_count: Minimum number of rows a non-standard class needs in order to be
            kept as ``'other'`` instead of being dropped.

    Returns:
        The cleaned DataFrame; surviving rows keep their original index labels.
    """
    cleaned_ds = dataset.copy()

    # For logging purposes
    N = cleaned_ds.shape[0]

    def pct(count):
        # Share of the input rows in percent; 0.0 for an empty frame instead of ZeroDivisionError.
        return (count / N) * 100 if N else 0.0

    logger.info(f"Found {N} rows")
    num_renames = cleaned_ds[cleaned_ds['class'].isin(['symptoms_lbs', 'symptoms_lads'])].shape[0]
    logger.info(f"Fixed class naming for {num_renames} rows ({pct(num_renames)} %)")

    # Fix naming of some classes
    cleaned_ds['class'] = cleaned_ds['class'].replace({'symptoms_lbs': 'symptoms_labs', 'symptoms_lads': 'symptoms_labs'})

    pos_invalids = cleaned_ds[~cleaned_ds['class'].isin(classes)]
    swapped_values = pos_invalids[pos_invalids['sentence'].str.lower().isin(classes)]    # Rows with swapped values

    # Placeholder tokens that mark a clearly invalid row. The columns are lower-cased
    # before the comparison, so the tokens must be lower-case too ('None' never matched).
    invalid_class_tokens = ['', '0', '__', 'none']
    invalid_sentence_tokens = ['', '__', 'none']
    invalid_classes = pos_invalids[pos_invalids['class'].str.lower().isin(invalid_class_tokens)]
    invalid_sentences = pos_invalids[pos_invalids['sentence'].str.lower().isin(invalid_sentence_tokens)]
    invalid_indices = invalid_classes.index.union(invalid_sentences.index)

    # Rows with a non-standard class that are neither swapped nor clearly invalid
    nonstd_classes = pos_invalids.drop(index=swapped_values.index.union(invalid_indices))

    # Non-standard classes with enough occurrences (not outliers) are kept as "other"
    value_cnts = nonstd_classes['class'].value_counts()
    other_classes = value_cnts[value_cnts >= min_other_count].index.tolist()
    other_class_indices = nonstd_classes[nonstd_classes['class'].isin(other_classes)].index

    # Clean the dataset
    swapped_indices = swapped_values.index
    swapped_rows = cleaned_ds.loc[swapped_indices, ['class', 'sentence']]
    cleaned_ds.loc[swapped_indices, 'sentence'] = swapped_rows['class'].values
    # The swap was detected case-insensitively, so normalise the recovered label;
    # otherwise e.g. 'Diagnosis' would be swapped into place and then dropped below.
    cleaned_ds.loc[swapped_indices, 'class'] = swapped_rows['sentence'].str.lower().values
    cleaned_ds.loc[other_class_indices, 'class'] = 'other'
    drop_indices = cleaned_ds[~cleaned_ds['class'].isin(np.append(classes, 'other'))].index
    cleaned_ds = cleaned_ds.drop(index=drop_indices) # drop everything that is not a standard class or 'other'

    # Summary
    num_reclass = cleaned_ds[cleaned_ds['class'] == 'other'].shape[0]
    logger.info(f"Re-classified {len(other_classes)} classes as 'other', or {num_reclass} rows ({pct(num_reclass)} %)")
    logger.info(f'Swapped class and sentence values of {swapped_indices.shape[0]} rows ({pct(swapped_indices.shape[0])} %)')
    logger.info(f'Dropped {drop_indices.shape[0]} invalid rows ({pct(drop_indices.shape[0])} %)')

    # change case_id into int
    cleaned_ds['case_id'] = cleaned_ds['case_id'].astype(int)

    return cleaned_ds

mimic_preprocess(processed_datasets['mimic4'])['class'].value_counts()

2025-06-04 07:20:23,829 || INFO || MedRL-CoT Preprocess - Found 29654 rows
2025-06-04 07:20:23,832 || INFO || MedRL-CoT Preprocess - Fixed class naming for 282 rows (0.9509678289606799 %)
2025-06-04 07:20:23,854 || INFO || MedRL-CoT Preprocess - Re-classified 18 classes as 'other', or 1073 rows (3.618398866931948 %)
2025-06-04 07:20:23,855 || INFO || MedRL-CoT Preprocess - Swapped class and sentence values of 571 rows (1.925541242328185 %)
2025-06-04 07:20:23,856 || INFO || MedRL-CoT Preprocess - Dropped 169 invalid rows (0.5699062521076415 %)


class
symptoms_labs      16680
diagnosis           9112
thought_process     2620
other               1073
Name: count, dtype: int64

In [9]:
def aug_preprocess(dataset):
    """Clean the sentence-level labels of the aug_med_notes frame.

    Works on a copy; ``dataset`` itself is left untouched. Steps:

    1. Rename the misspelled labels ``symptoms_lbs`` / ``symptoms_lads`` to ``symptoms_labs``.
    2. Swap ``sentence`` and ``class`` on rows whose ``sentence`` is exactly a standard label.
    3. Drop every row whose class is then not a standard class. Unlike ``mimic_preprocess``,
       non-standard classes are not relabelled to ``'other'`` here.
    4. Cast ``case_id`` to ``int``.

    Args:
        dataset: DataFrame with ``sentence``, ``class`` and ``case_id`` columns.

    Returns:
        The cleaned DataFrame; surviving rows keep their original index labels.
    """
    cleaned_ds = dataset.copy()

    # For logging purposes
    N = cleaned_ds.shape[0]

    def pct(count):
        # Share of the input rows in percent; 0.0 for an empty frame instead of ZeroDivisionError.
        return (count / N) * 100 if N else 0.0

    logger.info(f"Found {N} rows")
    num_renames = cleaned_ds[cleaned_ds['class'].isin(['symptoms_lbs', 'symptoms_lads'])].shape[0]
    logger.info(f"Fixed class naming for {num_renames} rows ({pct(num_renames)} %)")

    # Fix naming of some classes
    cleaned_ds['class'] = cleaned_ds['class'].replace({'symptoms_lbs': 'symptoms_labs', 'symptoms_lads': 'symptoms_labs'})

    # Rows whose class is non-standard but whose sentence is exactly a standard label
    # had their two fields swapped.
    pos_invalids = cleaned_ds[~cleaned_ds['class'].isin(classes)]
    swapped_indices = pos_invalids[pos_invalids['sentence'].isin(classes)].index
    cleaned_ds.loc[swapped_indices, ['sentence', 'class']] = cleaned_ds.loc[swapped_indices, ['class', 'sentence']].values # swap the values in indices where it's swapped

    # Everything still outside the standard classes is treated as invalid and dropped.
    drop_indices = cleaned_ds[~cleaned_ds['class'].isin(classes)].index
    cleaned_ds = cleaned_ds.drop(index=drop_indices)

    # Summary
    logger.info(f'Swapped class and sentence values of {swapped_indices.shape[0]} rows ({pct(swapped_indices.shape[0])} %)')
    logger.info(f'Dropped {drop_indices.shape[0]} invalid rows ({pct(drop_indices.shape[0])} %)')

    # change case_id into int
    cleaned_ds['case_id'] = cleaned_ds['case_id'].astype(int)

    return cleaned_ds

aug_preprocess(processed_datasets['aug_med_notes'])

2025-06-04 07:20:23,898 || INFO || MedRL-CoT Preprocess - Found 19968 rows
2025-06-04 07:20:23,901 || INFO || MedRL-CoT Preprocess - Fixed class naming for 21 rows (0.10516826923076923 %)
2025-06-04 07:20:23,917 || INFO || MedRL-CoT Preprocess - Swapped class and sentence values of 131 rows (0.6560496794871795 %)
2025-06-04 07:20:23,917 || INFO || MedRL-CoT Preprocess - Dropped 240 invalid rows (1.201923076923077 %)


Unnamed: 0,sentence,class,case_id
0,"A sixteen year-old girl, presented to our Outp...",symptoms_labs,-1
1,She was not able to maintain an erect posture ...,symptoms_labs,-1
2,She would keep her head turned to the right an...,symptoms_labs,-1
3,There was a sideways bending of the back in th...,symptoms_labs,-1
4,To counter the abnormal positioning of the bac...,symptoms_labs,-1
...,...,...,...
19963,"No significant muscle atrophy was present, and...",symptoms_labs,912
19964,The physical examination showed that both of h...,symptoms_labs,915
19965,Both knees were very soft and could touch the ...,symptoms_labs,915
19966,"Upon palpation, the continuity of the quadrice...",symptoms_labs,915


In [10]:
# Dispatch table: dataset name -> the cleaning function to run on that dataset.
proc_funcs = {'mimic4': mimic_preprocess, 'aug_med_notes': aug_preprocess}

In [11]:
# Run every labelled frame through its dataset-specific cleaner, framing each run
# in the log with a separator line.
preprocessed_datasets = {}
separator = "=" * 50
for dataset_name, labeled_df in processed_datasets.items():
    logger.info(separator)
    logger.info(f"Cleaning up {dataset_name} dataset")
    preprocessed_datasets[dataset_name] = proc_funcs[dataset_name](labeled_df)
    logger.info(separator)

2025-06-04 07:20:23,983 || INFO || MedRL-CoT Preprocess - Cleaning up aug_med_notes dataset
2025-06-04 07:20:23,985 || INFO || MedRL-CoT Preprocess - Found 19968 rows
2025-06-04 07:20:23,986 || INFO || MedRL-CoT Preprocess - Fixed class naming for 21 rows (0.10516826923076923 %)
2025-06-04 07:20:24,002 || INFO || MedRL-CoT Preprocess - Swapped class and sentence values of 131 rows (0.6560496794871795 %)
2025-06-04 07:20:24,002 || INFO || MedRL-CoT Preprocess - Dropped 240 invalid rows (1.201923076923077 %)
2025-06-04 07:20:24,006 || INFO || MedRL-CoT Preprocess - Cleaning up mimic4 dataset
2025-06-04 07:20:24,008 || INFO || MedRL-CoT Preprocess - Found 29654 rows
2025-06-04 07:20:24,010 || INFO || MedRL-CoT Preprocess - Fixed class naming for 282 rows (0.9509678289606799 %)
2025-06-04 07:20:24,031 || INFO || MedRL-CoT Preprocess - Re-classified 18 classes as 'other', or 1073 rows (3.618398866931948 %)
2025-06-04 07:20:24,031 || INFO || MedRL-CoT Preprocess - Swapped class and sentence 

In [12]:
preprocessed_datasets.keys()

dict_keys(['aug_med_notes', 'mimic4'])

In [13]:
preprocessed_datasets['mimic4']['case_id'].unique()

array([ -1,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,  14,
        15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
        28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,
        41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,
        54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,
        67,  68,  69,  70,  71,  72,  73,  74,  75,  76,  77,  78,  79,
        80,  81,  82,  83,  84,  85,  86,  87,  88,  89,  90,  91,  92,
        93,  94,  95,  96,  97,  98,  99, 100, 101, 102, 103, 104, 105,
       106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118,
       119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131,
       132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144,
       145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157,
       158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170,
       171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 18

In [14]:
# def join_sentence_class(group):
#     return ' '.join(f"{row['sentence']} <{row['class']}>" for _, row in group.iterrows())
    
# # preprocessed_datasets['mimic4'].groupby('case_id').apply(join_sentence_class).reset_index(name='full_class_case').sort_values('case_id')

In [15]:
# cases_datasets = dict()
# for key, dataset in preprocessed_datasets.items():
#     cases_datasets[key] = dataset.groupby('case_id').apply(join_sentence_class).reset_index(name='full_class_case').sort_values('case_id')

In [16]:
# cases_datasets['aug_med_notes']

In [17]:
# cases_datasets['aug_med_notes']

In [18]:
# preprocessed_datasets['aug_med_notes'][preprocessed_datasets['aug_med_notes']['class'] == 'other']

In [19]:
def xy_split_processing_sft(group, x_func=None, y_func=None):
    """Turn the labelled sentences of one case into an (X, Y) pair for SFT.

    Rows labelled ``symptoms_labs`` or ``other`` make up the model input; every other
    row makes up the target.

    Args:
        group: DataFrame with the rows of a single case; needs ``sentence`` and ``class`` columns.
        x_func: Optional callable that receives the list of input rows and returns the case
            text. Default: the sentences joined by a single space.
        y_func: Optional callable that receives the list of target rows and returns the
            target text. Default: each sentence followed by its ``<class>`` tag.

    Returns:
        pd.Series with ``'X'`` (the full prompt) and ``'Y'`` (the target text).
    """
    input_classes = ('symptoms_labs', 'other')
    X = []
    Y = []

    for _, row in group.iterrows():
        if row['class'] in input_classes:
            X.append(row)
        else:
            Y.append(row)

    X_case = x_func(X) if x_func else ' '.join([str(row['sentence']) for row in X])
    # y_func used to be accepted but ignored; honour it the same way as x_func.
    Y_case = y_func(Y) if y_func else ' '.join([f"{row['sentence']} <{row['class']}> " for row in Y])    # Output with only thought_process and diagnosis

    # The indentation inside the triple-quoted string is part of the prompt text.
    X_prompt = f"""Below is a clinical case. Your task is to provide a step-by-step clinical reasoning followed by the diagnosis.

    {X_case}
    """

    return pd.Series({'X': X_prompt, 'Y': Y_case})

# From here on the packaged preprocessing is used: this rebinds `preprocessed_datasets`
# (built by the notebook-local functions in an earlier cell) to the output of
# `mp.preprocess_datasets()`.
import medrlcot.preprocessing as mp
preprocessed_datasets = mp.preprocess_datasets()

# Collapse the sentence rows of each case into a single (X, Y) example for SFT.
# Note that the package's `mp.xy_split_processing_sft` is applied, not the function
# defined at the top of this cell.
cases_datasets = dict()
for key, dataset in preprocessed_datasets.items():
    cases_datasets[key] = dataset.groupby('case_id').apply(mp.xy_split_processing_sft).reset_index().sort_values('case_id')
    # Whitespace-separated word counts of the prompt and of the target
    cases_datasets[key]['wc_x'] = cases_datasets[key]['X'].apply(lambda x: len(x.split())) 
    cases_datasets[key]['wc_y'] = cases_datasets[key]['Y'].apply(lambda x: len(x.split())) 
    # Keep only cases with at most 800 prompt words and at most 900 target words
    cases_datasets[key] = cases_datasets[key][(cases_datasets[key]['wc_x'] <= 800) & (cases_datasets[key]['wc_y'] <= 900)].reset_index(drop=True)

# `os` and `data_manager` are already imported in the first import cell; the config is
# rebuilt here through the `mce` alias.
import os
from medrlcot import data_manager
import medrlcot.config.env as mce
model_cfg_path = os.path.join(os.getcwd(), "medrlcot/config/.env")
medrlcot_config = mce.MedRL_CoT(model_cfg_path)
raw_datasets = data_manager.load_datasets(medrlcot_config.datasets, data_dir=medrlcot_config.data_dir)  # Load raw dataset (original cases) for RM

2025-06-04 07:20:24,262 || INFO || Logger - Setup for MedRL-CoT's log done. This is the beginning of the log.
2025-06-04 07:20:24,436 || INFO || MedRL-CoT Preprocess - Cleaning up aug_med_notes dataset
2025-06-04 07:20:24,437 || INFO || MedRL-CoT Preprocess - Found 19968 rows
2025-06-04 07:20:24,440 || INFO || MedRL-CoT Preprocess - Fixed class naming for 21 rows (0.10516826923076923 %)


Generated new log file logs/medrlcot040.log


2025-06-04 07:20:24,472 || INFO || MedRL-CoT Preprocess - Re-classified 5 classes as 'other', or 76 rows (0.38060897435897434 %)
2025-06-04 07:20:24,473 || INFO || MedRL-CoT Preprocess - Swapped class and sentence values of 131 rows (0.6560496794871795 %)
2025-06-04 07:20:24,473 || INFO || MedRL-CoT Preprocess - Dropped 224 invalid rows (1.1217948717948718 %)
2025-06-04 07:20:24,476 || INFO || MedRL-CoT Preprocess - Cleaning up mimic4 dataset
2025-06-04 07:20:24,478 || INFO || MedRL-CoT Preprocess - Found 29654 rows
2025-06-04 07:20:24,481 || INFO || MedRL-CoT Preprocess - Fixed class naming for 282 rows (0.9509678289606799 %)
2025-06-04 07:20:24,512 || INFO || MedRL-CoT Preprocess - Swapped class and sentence values of 571 rows (1.925541242328185 %)
2025-06-04 07:20:24,513 || INFO || MedRL-CoT Preprocess - Dropped 1461 invalid rows (4.9268226883388415 %)
  cases_datasets[key] = dataset.groupby('case_id').apply(mp.xy_split_processing_sft).reset_index().sort_values('case_id')
  cases_da

In [20]:
cases_datasets['aug_med_notes']

Unnamed: 0,case_id,X,Y,wc_x,wc_y
0,100,Provide a step-by-step clinical reasoning foll...,"A 34 year old Persian woman, gravida 1, para 0...",159,304
1,101,Provide a step-by-step clinical reasoning foll...,She was treated with thrombolysis and endovasc...,180,255
2,102,Provide a step-by-step clinical reasoning foll...,classification [diagnosis],12,2
3,103,Provide a step-by-step clinical reasoning foll...,Staged bilateral total hip arthroplasties were...,374,314
4,104,Provide a step-by-step clinical reasoning foll...,The initial workup was unrevealing for an etio...,248,149
...,...,...,...,...,...
732,93,Provide a step-by-step clinical reasoning foll...,A 45 mm × 37 mm pseudoaneurysm in lateral side...,329,185
733,94,Provide a step-by-step clinical reasoning foll...,The working diagnosis was a collection seconda...,147,415
734,96,Provide a step-by-step clinical reasoning foll...,She had undergone endovascular trapping of the...,160,587
735,97,Provide a step-by-step clinical reasoning foll...,The patient underwent general anesthesia for t...,287,183


In [21]:
preprocessed_datasets['aug_med_notes']['class'].value_counts()

class
symptoms_labs      9977
thought_process    5318
diagnosis          4373
other                76
Name: count, dtype: int64

In [22]:
# filtered_datasets = dict()
# for key, dataset in cases_datasets.items():
#     dataset['wc_x'] = dataset['X'].apply(lambda x: len(x.split())) 
#     dataset['wc_y'] = dataset['Y'].apply(lambda x: len(x.split())) 
#     filtered_datasets[key] = dataset[(dataset['wc_x'] <= 800) & (dataset['wc_y'] <= 900)].reset_index(drop=True)

# print(cases_datasets['aug_med_notes'].shape, filtered_datasets['aug_med_notes'].shape)
# print(cases_datasets['mimic4'].shape, filtered_datasets['mimic4'].shape)

In [23]:
# Walk over every case prompt, remembering the two largest word counts (together with
# the dataset name and row they came from) and counting prompts above 750 words.
max_words, next_max = (None, 0), (None, -1)
num_exceed = 0
for key, dataset in cases_datasets.items():
    for indexed_row in dataset.iterrows():
        n_words = len(indexed_row[1]['X'].split())
        entry = ((key, indexed_row), n_words)
        if n_words > 750:
            num_exceed += 1
        if n_words > max_words[1]:
            # New maximum: the previous maximum becomes the runner-up.
            max_words, next_max = entry, max_words
        elif n_words >= next_max[1]:
            # Not a new maximum, but at least as long as the current runner-up.
            next_max = entry
print(max_words, next_max)
num_exceed

(('mimic4', (47, case_id                                                  152
X          Provide a step-by-step clinical reasoning foll...
Y          She received 1L NS, 5 mg of morphine, 25 mg of...
wc_x                                                     795
wc_y                                                     362
Name: 47, dtype: object)), 795) (('mimic4', (191, case_id                                                   56
X          Provide a step-by-step clinical reasoning foll...
Y          Thought process: logical, linear, perservated ...
wc_x                                                     793
wc_y                                                     527
Name: 191, dtype: object)), 793)


16

In [24]:
# Create the custom dataset object
from transformers import AutoTokenizer
# NOTE(review): `device_map` is usually a model-loading option — confirm that it has
# any effect when passed to a tokenizer.
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base", device_map="auto")
# Wrap the per-dataset case frames; the seed is fixed at 0.
cases_data = data_manager.MedRL_CoT_Dataset(cases_datasets, seed=0, tokenizer=tokenizer)

2025-06-04 07:20:27,711 || INFO || DataManager - Using train-val split: [0.75, 0.25]
2025-06-04 07:20:27,712 || INFO || DataManager - Split-Shuffle with seed 0
2025-06-04 07:20:27,713 || INFO || DataManager - Splitting aug_med_notes dataset.
2025-06-04 07:20:27,715 || INFO || DataManager - Split into 552 train rows and 185 val rows
2025-06-04 07:20:27,716 || INFO || DataManager - Splitting mimic4 dataset.
2025-06-04 07:20:27,718 || INFO || DataManager - Split into 174 train rows and 59 val rows
2025-06-04 07:20:27,719 || INFO || DataManager - Creating a single joint dataset of dict_keys(['aug_med_notes', 'mimic4']) dataset splits
2025-06-04 07:20:27,721 || INFO || DataManager - Joined 726 train rows
2025-06-04 07:20:27,722 || INFO || DataManager - Shuffling train's rows
2025-06-04 07:20:27,723 || INFO || DataManager - Joined 244 val rows
2025-06-04 07:20:27,724 || INFO || DataManager - Shuffling val's rows
2025-06-04 07:20:27,726 || INFO || DataManager - Returning shuffled train-val sp

Index([], dtype='int64')
Index([], dtype='int64')


In [25]:
# cases_data['train'] = cases_data['train'].drop(cases_data['train'][cases_data['train']['Y'].str.strip() == ''].index)
# print(cases_data['train'][cases_data['train']['X'].str.strip() == ''].index)

In [26]:
cases_data['train']

Unnamed: 0,X,Y
0,Provide a step-by-step clinical reasoning foll...,Pain of the first mentioned episode was subsid...
1,Provide a step-by-step clinical reasoning foll...,The swelling was seen extending from mesial as...
2,Provide a step-by-step clinical reasoning foll...,The patient was diagnosed as TN at another hos...
3,Provide a step-by-step clinical reasoning foll...,After being referred to several gastroenterolo...
4,Provide a step-by-step clinical reasoning foll...,While mast cells represent the effector cell i...
...,...,...
721,Provide a step-by-step clinical reasoning foll...,Discharge Disposition: [diagnosis] Discharge ...
722,Provide a step-by-step clinical reasoning foll...,Sinus plain radiographs demonstrated ‘sinusiti...
723,Provide a step-by-step clinical reasoning foll...,"After admission, the patient maintained with a..."
724,Provide a step-by-step clinical reasoning foll...,An 80-year-old woman with a history of collaps...


In [27]:
cases_data['train'][cases_data['train']['Y'].str.strip() == '']

Unnamed: 0,X,Y


In [28]:
cases_data['train']

Unnamed: 0,X,Y
0,Provide a step-by-step clinical reasoning foll...,Pain of the first mentioned episode was subsid...
1,Provide a step-by-step clinical reasoning foll...,The swelling was seen extending from mesial as...
2,Provide a step-by-step clinical reasoning foll...,The patient was diagnosed as TN at another hos...
3,Provide a step-by-step clinical reasoning foll...,After being referred to several gastroenterolo...
4,Provide a step-by-step clinical reasoning foll...,While mast cells represent the effector cell i...
...,...,...
721,Provide a step-by-step clinical reasoning foll...,Discharge Disposition: [diagnosis] Discharge ...
722,Provide a step-by-step clinical reasoning foll...,Sinus plain radiographs demonstrated ‘sinusiti...
723,Provide a step-by-step clinical reasoning foll...,"After admission, the patient maintained with a..."
724,Provide a step-by-step clinical reasoning foll...,An 80-year-old woman with a history of collaps...


In [29]:
cases_data['train']['Y'].iloc[0]

'Pain of the first mentioned episode was subsided spontaneously and the patient underwent endoscopic retrograde cholangiopancreatography (ERCP) 1 month later [diagnosis]  but there was no visible CBD stone on ERCP at that time. [diagnosis]  Differential diagnosis made by CT scan was a polypoid mass arising from biliary ducts or thick sludge fulfilling biliary ducts. [diagnosis]  classification [diagnosis]  Hence, tissue sampling of mentioned infiltrative mass was performed in May 2019 revealing cholangiocarcinoma developing in intraductal papillary neoplasm of bile duct in microscopic pathological study and mucin secreting neoplasm (adenocarcinoma) with GI origin on immunohistochemistry study. [diagnosis]  Another contrast enhanced abdominal CT scan was ordered in May, 2019 before surgery [] which demonstrates heterogeneous mass in size of 70 mm × 42 mm at left liver lobe accompanied with perilesional staining and focal peripheral biliary duct ectasia. [thought_process]  Branching and 

In [30]:
cases_data.get_dataloader('train')

<torch.utils.data.dataloader.DataLoader at 0x7697466c1090>

In [31]:
cases_data.get_dataloader('val')

<torch.utils.data.dataloader.DataLoader at 0x7697466c2620>

In [32]:
from transformers import AutoModelForSeq2SeqLM
model  = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base")

In [33]:
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments, DataCollatorForSeq2Seq

# Fine-tuning configuration: batch size 1 per device for training and evaluation,
# 3 epochs at a learning rate of 5e-5, one checkpoint per epoch under ./checkpoints.
training_args = Seq2SeqTrainingArguments(
    output_dir="./checkpoints",
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    predict_with_generate=True,
    # fp16=True,
    learning_rate=5e-5,
    num_train_epochs=3,
    save_strategy="epoch",
    logging_dir="./logs"
)

In [34]:
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=model,
    padding="max_length",
    return_tensors="pt"
)

In [35]:
def compute_metrics(eval_preds):
    """Decode predictions and labels and score them with ``metric``.

    Adapted from https://www.datacamp.com/tutorial/flan-t5-tutorial

    NOTE(review): neither ``nltk`` nor ``metric`` is imported or defined in the visible
    notebook cells, so this would raise NameError when called unless they are
    provided elsewhere — confirm before enabling evaluation.

    Args:
        eval_preds: Pair ``(preds, labels)`` of token-id arrays.

    Returns:
        Whatever ``metric.compute`` returns for the decoded texts.
    """
    predictions, references = eval_preds

    # Label entries equal to -100 are replaced by the pad token id before decoding.
    references = np.where(references != -100, references, tokenizer.pad_token_id)

    def _decode_one_sentence_per_line(token_ids):
        # rougeLSum expects newline after each sentence
        texts = tokenizer.batch_decode(token_ids, skip_special_tokens=True)
        return ["\n".join(nltk.sent_tokenize(text.strip())) for text in texts]

    return metric.compute(
        predictions=_decode_one_sentence_per_line(predictions),
        references=_decode_one_sentence_per_line(references),
        use_stemmer=True,
    )

In [36]:
train_tok_dataset = cases_data.get_torch_dataset('train')
val_tok_dataset = cases_data.get_torch_dataset('val')

In [37]:
t = val_tok_dataset[100]
t['labels']

tensor([14298,   769, 18780,    23,   152,   592,    47, 13090,   227,     3,
            9,  6743,    47,   990,     5,   784, 11841,    17,   834, 15056,
          908,    37,   255,     9,   189,    47,   646,    16,   286,    11,
            8,  3979,    47,   703,   127,  1054,     5,   784, 11841,    17,
          834, 15056,   908,    37,  1868,    47, 24731, 16928,  1427,  5711,
           11,  1026,    12, 29216,    26,    12,    51,  5984,    41,  6227,
           61, 19083,    21,  5002,    13, 27970,  2871,     5,   784, 11841,
           17,   834, 15056,   908,    37,  8668,  5924,  5111,  1223,    54,
           29,  7830,    13,     8,     3,     9,   127,  1225, 26806,    28,
            8,   255,     9,   189,     5,   784, 25930,  4844,   159,   908,
           37,  1868,    47,  1461, 10250,    57, 16352,    12,    69,  6568,
           21,  1146,   593,    13,     3, 15100,  3730,   124,     5,   784,
        11841,    17,   834, 15056,   908,    37,  1868,    47, 

In [38]:
tokenizer.decode(t['input_ids'])

'Provide a step-by-step clinical reasoning followed by the diagnosis: An 88-year-old man was transferred from a referring hospital for descending thoracic aortic injury after attempted pacemaker placement; 3 days prior he was admitted with a transient ischemic attack. He had new onset atrial fibrillation and sinus bradycardia that prompted pacemaker placement. After placement of a five French sheath, arterial blood return was noted. The patient arrived intubated and sedated with the sheath in place in the left chest covered with a dressing and a left chest tube in place with 100cc of sanguinous output. Thought Process: Diagnosis: </s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pa

In [39]:
# Generate for validation example 100 with the model as loaded (this cell precedes
# the trainer.train() cell).
inputs = val_tok_dataset[100]
# Add a batch dimension and move the tensors to the model's device.
input_ids = inputs["input_ids"].unsqueeze(0).to(model.device)
attention_mask = inputs["attention_mask"].unsqueeze(0).to(model.device)

# Inference only, so gradient tracking is switched off.
with torch.no_grad():
    outputs = model.generate(input_ids, attention_mask=attention_mask, max_new_tokens=200)

decoded_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(decoded_output)

Aortic aortic injury


In [40]:
sample = train_tok_dataset[0]
print(tokenizer.decode(sample["input_ids"]))
print(sample['labels'])

Provide a step-by-step clinical reasoning followed by the diagnosis: In December 2006, about 6 months earlier than the first episode, checkup ultrasound had been done for the patient that showed only mild focal fatty change in lateral segments of left liver lobe without similar findings of biliary ducts. An abdominal computed tomography (CT) scan was performed on June 2011, (images not provided) which demonstrates a hypodense mass-like lesion in the origin of the left hepatic duct measuring 38 mm <unk> 18 mm <unk> 34 mm with extension to peripheral branches causing dilation of left main hepatic duct and left intrahepatic biliary ducts. At this time (September 2018) MRCP was ordered for more evaluation. MRCP demonstrates a T2 high signal intensity lesion in the left liver lobe accompanied with intrahepatic biliary duct ectasia and parenchymal shrinkage []. EUS was done in March, 2019 which demonstrates a 50 mm <unk> 30 mm hypoechoic mass in the left liver lobe containing dilated intrahe

In [41]:
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_tok_dataset,
    eval_dataset=val_tok_dataset,
    compute_metrics=compute_metrics,
)

In [42]:
import torch  # already imported in the first import cell; repeated here
sample = train_tok_dataset[0]
# Add a batch dimension to every tensor of the example and move it to the model's device.
batch = {k: v.unsqueeze(0).to(model.device) for k, v in sample.items()}
# One forward pass without gradients to check that the model produces a loss for this example.
with torch.no_grad():
    output = model(**batch)
print("Loss:", output.loss.item())


Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


Loss: 3.027536630630493


In [43]:
model.gradient_checkpointing_enable()

In [None]:
trainer.train()

  batch["labels"] = torch.tensor(batch["labels"], dtype=torch.int64)
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Step,Training Loss


In [None]:
# example_case = cases_data['val'].iloc[-1]

# val_input = example_case["X"]
# val_target = example_case["Y"]
# val_input

In [None]:
val_tok_dataset[100]

In [None]:
# tokenizer.decode(inputs['input_ids'])

In [None]:
tokenizer.decode(val_tok_dataset[100]['input_ids'])

In [None]:
# Generate for validation example 99 with the current model weights.
inputs = val_tok_dataset[99]
# Add a batch dimension and move the tensors to the model's device.
input_ids = inputs["input_ids"].unsqueeze(0).to(model.device)
attention_mask = inputs["attention_mask"].unsqueeze(0).to(model.device)

# Inference only, so gradient tracking is switched off.
with torch.no_grad():
    outputs = model.generate(input_ids, attention_mask=attention_mask, max_new_tokens=200)

decoded_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(decoded_output)

In [None]:
# example_case = """
# John Doe is a 40-year-old male who was involved in a motor vehicle accident. 
# He was a restrained passenger who was rear-ended by another car going approximately 45 mph. 
# He has no medical history and takes no medications at home. 
# He has no allergies. His primary physician is Dr. Johnson. He has been awake, alert, and oriented with a Glasgow Coma Score of 15 since arrival. 
# He is moving all extremities and has good strength bilaterally. 
# His chief complaint is neck pain, rated a 5 out of 10, and he remains in cervical-spine precautions until the trauma team clears him. 
# He is to be kept NPO until cleared by the trauma team as well. CT scans have been completed of the head, cervical spine, chest, abdomen, and pelvis. 
# We are pending reports from radiology. Last vital signs, 15 minutes ago, were temperature 36.6, pulse 80, respiratory rate 14, and blood pressure 120/80. Lung sounds are clear.
# The abdomen is soft and non-tender. The patient has two largebore IVs. The Right AC 18 gauge with 1 liter Lactated Ringers at TKO. 
# The left AC has been saline locked. The patient received fentanyl 50 mcg slow IVP, and stated relief, with pain now 1 out of 10. 
# The family is at the bedside, and the patient remains in good spirits.
# """