### Preprocessing for Mimic4 Dataset

**Note that the AUG dataset is incomplete, so we skip it here**

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from medrlcot.config.env import MedRL_CoT
from medrlcot import data_manager
from medrlcot.medrlcot_logger import setup_logger
from dotenv import load_dotenv
from datasets import Features, Value
from IPython.display import display
import numpy as np
import pandas as pd
import datasets as hf_datasets
import logging
import os
import json

# pd.set_option('display.max_colwidth', None)  
# pd.set_option('display.width', 0)    
# pd.set_option('display.max_rows', None)  

In [3]:
# Locate the project's .env relative to the directory the kernel was started from,
# so this cell only works when the notebook is run from the repository root.
model_cfg_path = os.path.join(os.getcwd(), "medrlcot/config/.env")
medrlcot_config = MedRL_CoT(model_cfg_path)

# Configure logging once, then grab the named logger used by the cleaning functions below.
setup_logger()
logger = logging.getLogger("MedRL-CoT Preprocess")

2025-06-03 18:54:08,528 || INFO || Logger - Setup for MedRL-CoT's log done. This is the beginning of the log.


Generated new log file logs/medrlcot121.log


In [4]:
# Map each configured dataset name to its "<cwd>/<data_dir>/<dataset>/processed" directory.
processed_dirs = {ds: os.path.join(os.getcwd(), medrlcot_config.data_dir, ds, 'processed') for ds in medrlcot_config.datasets}
processed_dirs.keys()

dict_keys(['aug_med_notes', 'mimic4'])

In [5]:
# The three labels a sentence is expected to carry once the data has been cleaned.
classes = np.array([
    'symptoms_labs',
    'thought_process',
    'diagnosis',
])

In [6]:
def load_labeled(arrow_dir):
    """Load every ``.arrow`` shard in ``arrow_dir`` and merge them into one column dict.

    Args:
        arrow_dir: Directory that directly contains the ``.arrow`` files.

    Returns:
        The concatenated dataset converted with ``Dataset.to_dict()``.

    Raises:
        ValueError: If ``arrow_dir`` contains no ``.arrow`` file.
    """
    # os.listdir() returns entries in arbitrary order; sorting keeps the row order
    # (and therefore the DataFrame index built from it) identical across runs.
    arrows = sorted(
        os.path.join(arrow_dir, f) for f in os.listdir(arrow_dir) if f.endswith(".arrow")
    )
    if not arrows:
        # Fail here with the offending path instead of deep inside the concatenation call.
        raise ValueError(f"No .arrow files found in {arrow_dir!r}")

    shards = [hf_datasets.Dataset.from_file(arrow) for arrow in arrows]
    return hf_datasets.concatenate_datasets(shards).to_dict()

In [7]:
# Load each dataset's labelled shards into its own DataFrame, keyed by dataset name.
processed_datasets = {key: pd.DataFrame(load_labeled(processed_dir)) for key, processed_dir in processed_dirs.items()}
processed_datasets.keys()

dict_keys(['aug_med_notes', 'mimic4'])

In [8]:
def mimic_preprocess(dataset, valid_classes=None, log=None):
    """Clean the sentence-level labels of the MIMIC-IV dataset.

    Steps, in order:
      1. Map the misspelled labels ``symptoms_lbs`` / ``symptoms_lads`` to ``symptoms_labs``.
      2. On rows where the label landed in ``sentence`` (matched case-insensitively) and the
         text landed in ``class``, swap the two columns and lower-case the label.
      3. Relabel every remaining non-standard class that occurs at least 5 times as ``'other'``.
      4. Drop every row whose class is still neither a valid class nor ``'other'``.
      5. Cast ``case_id`` to ``int``.

    Args:
        dataset: DataFrame with ``sentence``, ``class`` and ``case_id`` columns. It is copied,
            never modified.
        valid_classes: Accepted labels. Defaults to the notebook-level ``classes``.
        log: Logger that receives the summary. Defaults to the notebook-level ``logger``.

    Returns:
        The cleaned copy of ``dataset``.
    """
    valid = list(classes if valid_classes is None else valid_classes)
    log = logger if log is None else log
    min_other_count = 5  # a non-standard class needs this many rows to be kept as 'other'

    cleaned_ds = dataset.copy()

    # For logging purposes
    N = cleaned_ds.shape[0]

    def _pct(count):
        # An empty dataset must not crash the summary with a division by zero.
        return (count / N) * 100 if N else 0.0

    log.info(f"Found {N} rows")

    # Fix naming of some classes
    typo_fixes = {'symptoms_lbs': 'symptoms_labs', 'symptoms_lads': 'symptoms_labs'}
    num_renames = int(cleaned_ds['class'].isin(list(typo_fixes)).sum())
    cleaned_ds['class'] = cleaned_ds['class'].replace(typo_fixes)
    log.info(f"Fixed class naming for {num_renames} rows ({_pct(num_renames)} %)")

    pos_invalids = cleaned_ds[~cleaned_ds['class'].isin(valid)]
    class_lc = pos_invalids['class'].str.lower()
    sentence_lc = pos_invalids['sentence'].str.lower()

    # Rows whose two columns were written the wrong way round.
    is_swapped = sentence_lc.isin(valid)
    # Placeholder values. Both sides are lower-cased, so the lists must be lower-case too
    # ('None' could never match the output of .str.lower()).
    is_invalid = class_lc.isin(['', '0', '__', 'none']) | sentence_lc.isin(['', '__', 'none'])

    # A swapped row whose real text is a placeholder is junk: leave it unswapped so the
    # final filter drops it instead of keeping a valid label with an empty sentence.
    swapped_indices = pos_invalids.index[is_swapped & ~is_invalid]
    nonstd_classes = pos_invalids[~is_swapped & ~is_invalid]  # rows with non-standard classes

    # Non-standard classes that occur often enough (not outliers) are kept as "other".
    value_cnts = nonstd_classes['class'].value_counts()
    other_classes = value_cnts[value_cnts >= min_other_count].index.tolist()
    other_class_indices = nonstd_classes.index[nonstd_classes['class'].isin(other_classes)]

    # Clean the dataset
    if len(swapped_indices):
        swapped_rows = cleaned_ds.loc[swapped_indices]
        cleaned_ds.loc[swapped_indices, 'sentence'] = swapped_rows['class'].to_numpy()
        # The match above ignored case, so normalise the label; otherwise e.g. 'Diagnosis'
        # would be swapped in and then immediately dropped as an unknown class.
        cleaned_ds.loc[swapped_indices, 'class'] = swapped_rows['sentence'].str.lower().to_numpy()
    cleaned_ds.loc[other_class_indices, 'class'] = 'other'

    # Drop everything that is still not one of our classes or 'other' (basically all invalids).
    drop_indices = cleaned_ds.index[~cleaned_ds['class'].isin(valid + ['other'])]
    cleaned_ds = cleaned_ds.drop(index=drop_indices)

    # Summary
    num_reclass = cleaned_ds[cleaned_ds['class'] == 'other'].shape[0]
    log.info(f"Re-classified {len(other_classes)} classes as 'other', or {num_reclass} rows ({_pct(num_reclass)} %)")
    log.info(f'Swapped class and sentence values of {swapped_indices.shape[0]} rows ({_pct(swapped_indices.shape[0])} %)')
    log.info(f'Dropped {drop_indices.shape[0]} invalid rows ({_pct(drop_indices.shape[0])} %)')

    # change case_id into int
    cleaned_ds['case_id'] = cleaned_ds['case_id'].astype(int)

    return cleaned_ds

mimic_preprocess(processed_datasets['mimic4'])['class'].value_counts()

2025-06-03 18:54:08,943 || INFO || MedRL-CoT Preprocess - Found 29654 rows
2025-06-03 18:54:08,947 || INFO || MedRL-CoT Preprocess - Fixed class naming for 282 rows (0.9509678289606799 %)
2025-06-03 18:54:08,965 || INFO || MedRL-CoT Preprocess - Re-classified 18 classes as 'other', or 1073 rows (3.618398866931948 %)
2025-06-03 18:54:08,966 || INFO || MedRL-CoT Preprocess - Swapped class and sentence values of 571 rows (1.925541242328185 %)
2025-06-03 18:54:08,967 || INFO || MedRL-CoT Preprocess - Dropped 169 invalid rows (0.5699062521076415 %)


class
symptoms_labs      16680
diagnosis           9112
thought_process     2620
other               1073
Name: count, dtype: int64

In [9]:
def aug_preprocess(dataset, valid_classes=None, log=None):
    """Clean the sentence-level labels of the augmented medical-notes dataset.

    Steps, in order:
      1. Map the misspelled labels ``symptoms_lbs`` / ``symptoms_lads`` to ``symptoms_labs``.
      2. On rows where the label landed in ``sentence`` and the text landed in ``class``,
         swap the two columns, unless the text is only a placeholder.
      3. Drop every row whose class is still not a valid class (no ``'other'`` bucket here).
      4. Cast ``case_id`` to ``int``.

    Args:
        dataset: DataFrame with ``sentence``, ``class`` and ``case_id`` columns. It is copied,
            never modified.
        valid_classes: Accepted labels. Defaults to the notebook-level ``classes``.
        log: Logger that receives the summary. Defaults to the notebook-level ``logger``.

    Returns:
        The cleaned copy of ``dataset``.
    """
    valid = list(classes if valid_classes is None else valid_classes)
    log = logger if log is None else log

    cleaned_ds = dataset.copy()

    # For logging purposes
    N = cleaned_ds.shape[0]

    def _pct(count):
        # An empty dataset must not crash the summary with a division by zero.
        return (count / N) * 100 if N else 0.0

    log.info(f"Found {N} rows")

    # Fix naming of some classes
    typo_fixes = {'symptoms_lbs': 'symptoms_labs', 'symptoms_lads': 'symptoms_labs'}
    num_renames = int(cleaned_ds['class'].isin(list(typo_fixes)).sum())
    cleaned_ds['class'] = cleaned_ds['class'].replace(typo_fixes)
    log.info(f"Fixed class naming for {num_renames} rows ({_pct(num_renames)} %)")

    pos_invalids = cleaned_ds[~cleaned_ds['class'].isin(valid)]

    # Rows whose two columns were written the wrong way round (exact, case-sensitive match).
    is_swapped = pos_invalids['sentence'].isin(valid)
    # On a swapped row the real text sits in 'class'. The comparison is made on lower-cased
    # text, so the placeholder list is lower-case as well.
    placeholders = ['', '0', '__', 'none', '[]', 'false', 'the', 'no', '()']
    is_placeholder = pos_invalids['class'].str.lower().isin(placeholders)

    # Swapping a placeholder in would keep a valid label with a junk sentence; leaving the
    # row untouched lets the class filter below drop it.
    swapped_indices = pos_invalids.index[is_swapped & ~is_placeholder]
    if len(swapped_indices):
        swapped_rows = cleaned_ds.loc[swapped_indices]
        cleaned_ds.loc[swapped_indices, 'sentence'] = swapped_rows['class'].to_numpy()
        cleaned_ds.loc[swapped_indices, 'class'] = swapped_rows['sentence'].to_numpy()

    # Drop everything that is still not one of our classes (basically all invalids).
    drop_indices = cleaned_ds.index[~cleaned_ds['class'].isin(valid)]
    cleaned_ds = cleaned_ds.drop(index=drop_indices)

    # Summary
    log.info(f'Swapped class and sentence values of {swapped_indices.shape[0]} rows ({_pct(swapped_indices.shape[0])} %)')
    log.info(f'Dropped {drop_indices.shape[0]} invalid rows ({_pct(drop_indices.shape[0])} %)')

    # change case_id into int
    cleaned_ds['case_id'] = cleaned_ds['case_id'].astype(int)

    return cleaned_ds

aug_preprocess(processed_datasets['aug_med_notes'])

2025-06-03 18:54:09,022 || INFO || MedRL-CoT Preprocess - Found 19968 rows
2025-06-03 18:54:09,026 || INFO || MedRL-CoT Preprocess - Fixed class naming for 21 rows (0.10516826923076923 %)
2025-06-03 18:54:09,040 || INFO || MedRL-CoT Preprocess - Swapped class and sentence values of 131 rows (0.6560496794871795 %)
2025-06-03 18:54:09,041 || INFO || MedRL-CoT Preprocess - Dropped 240 invalid rows (1.201923076923077 %)


Unnamed: 0,sentence,class,case_id
0,"A sixteen year-old girl, presented to our Outp...",symptoms_labs,-1
1,She was not able to maintain an erect posture ...,symptoms_labs,-1
2,She would keep her head turned to the right an...,symptoms_labs,-1
3,There was a sideways bending of the back in th...,symptoms_labs,-1
4,To counter the abnormal positioning of the bac...,symptoms_labs,-1
...,...,...,...
19963,"No significant muscle atrophy was present, and...",symptoms_labs,912
19964,The physical examination showed that both of h...,symptoms_labs,915
19965,Both knees were very soft and could touch the ...,symptoms_labs,915
19966,"Upon palpation, the continuity of the quadrice...",symptoms_labs,915


In [10]:
# Dispatch table: dataset name -> cleaning function defined in this notebook.
proc_funcs = {'mimic4': mimic_preprocess, 'aug_med_notes': aug_preprocess}

In [11]:
# Run the matching cleaning function on every loaded dataset; the raw frames in
# processed_datasets are left as they are and the cleaned copies are stored separately.
preprocessed_datasets = {}
for key, item in processed_datasets.items():
    logger.info("=" * 50)
    logger.info(f"Cleaning up {key} dataset")
    preprocessed_datasets[key] = proc_funcs[key](item)
    logger.info("=" * 50)

2025-06-03 18:54:09,131 || INFO || MedRL-CoT Preprocess - Cleaning up aug_med_notes dataset
2025-06-03 18:54:09,132 || INFO || MedRL-CoT Preprocess - Found 19968 rows
2025-06-03 18:54:09,135 || INFO || MedRL-CoT Preprocess - Fixed class naming for 21 rows (0.10516826923076923 %)
2025-06-03 18:54:09,148 || INFO || MedRL-CoT Preprocess - Swapped class and sentence values of 131 rows (0.6560496794871795 %)
2025-06-03 18:54:09,149 || INFO || MedRL-CoT Preprocess - Dropped 240 invalid rows (1.201923076923077 %)
2025-06-03 18:54:09,153 || INFO || MedRL-CoT Preprocess - Cleaning up mimic4 dataset
2025-06-03 18:54:09,155 || INFO || MedRL-CoT Preprocess - Found 29654 rows
2025-06-03 18:54:09,158 || INFO || MedRL-CoT Preprocess - Fixed class naming for 282 rows (0.9509678289606799 %)
2025-06-03 18:54:09,177 || INFO || MedRL-CoT Preprocess - Re-classified 18 classes as 'other', or 1073 rows (3.618398866931948 %)
2025-06-03 18:54:09,179 || INFO || MedRL-CoT Preprocess - Swapped class and sentence 

In [12]:
preprocessed_datasets.keys()

dict_keys(['aug_med_notes', 'mimic4'])

In [13]:
preprocessed_datasets['mimic4']['case_id'].unique()

array([ -1,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,  14,
        15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
        28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,
        41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,
        54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,
        67,  68,  69,  70,  71,  72,  73,  74,  75,  76,  77,  78,  79,
        80,  81,  82,  83,  84,  85,  86,  87,  88,  89,  90,  91,  92,
        93,  94,  95,  96,  97,  98,  99, 100, 101, 102, 103, 104, 105,
       106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118,
       119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131,
       132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144,
       145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157,
       158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170,
       171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 18

In [14]:
# def join_sentence_class(group):
#     return ' '.join(f"{row['sentence']} <{row['class']}>" for _, row in group.iterrows())
    
# # preprocessed_datasets['mimic4'].groupby('case_id').apply(join_sentence_class).reset_index(name='full_class_case').sort_values('case_id')

In [15]:
# cases_datasets = dict()
# for key, dataset in preprocessed_datasets.items():
#     cases_datasets[key] = dataset.groupby('case_id').apply(join_sentence_class).reset_index(name='full_class_case').sort_values('case_id')

In [16]:
# cases_datasets['aug_med_notes']

In [17]:
# cases_datasets['aug_med_notes']

In [18]:
# preprocessed_datasets['aug_med_notes'][preprocessed_datasets['aug_med_notes']['class'] == 'other']

In [19]:
def xy_split_processing_sft(group, x_func=None, y_func=None):
    """Collapse the sentences of one case into a single (input, target) text pair.

    Sentences labelled ``symptoms_labs`` or ``other`` form the input ``X``; every other
    sentence forms the target ``Y``.

    Args:
        group: DataFrame holding the rows of a single case, with ``sentence`` and ``class``
            columns.
        x_func: Optional callable that receives the list of input rows and returns the value
            stored under ``'X'``. Default: the sentences joined by a single space.
        y_func: Optional callable that receives the list of target rows and returns the value
            stored under ``'Y'``. Default: each sentence followed by ``" <class> "``, joined
            by a single space.

    Returns:
        ``pd.Series`` with the two entries ``'X'`` and ``'Y'``.
    """
    X, Y = [], []
    for _, row in group.iterrows():
        if row['class'] in ('symptoms_labs', 'other'):
            X.append(row)
        else:
            Y.append(row)

    X_case = x_func(X) if x_func else ' '.join(str(row['sentence']) for row in X)
    if y_func:
        # y_func used to be accepted but silently ignored; honour it like x_func.
        Y_case = y_func(Y)
    else:
        # Output with only thought_process and diagnosis, each tagged with its class.
        Y_case = ' '.join(f"{row['sentence']} <{row['class']}> " for row in Y)

    return pd.Series({'X': X_case, 'Y': Y_case})

import medrlcot.preprocessing as mp

# From here on the packaged pipeline supplies the cleaned frames; this rebinds the
# preprocessed_datasets name that an earlier cell filled with the notebook-local functions.
preprocessed_datasets = mp.preprocess_datasets()

# Build one SFT example per case: group the sentences by case, add word counts for both
# sides, and keep only cases with at most 800 input words and at most 900 target words.
cases_datasets = {}
for key, dataset in preprocessed_datasets.items():
    cases_datasets[key] = (
        dataset.groupby('case_id')
        .apply(mp.xy_split_processing_sft)
        .reset_index()
        .sort_values('case_id')
        .assign(
            wc_x=lambda frame: frame['X'].apply(lambda x: len(x.split())),
            wc_y=lambda frame: frame['Y'].apply(lambda x: len(x.split())),
        )
        .loc[lambda frame: (frame['wc_x'] <= 800) & (frame['wc_y'] <= 900)]
        .reset_index(drop=True)
    )

# NOTE(review): `os` and `data_manager` are already imported in the notebook's import cell,
# and model_cfg_path / medrlcot_config are rebuilt here from the same path used earlier.
import os
from medrlcot import data_manager
import medrlcot.config.env as mce
model_cfg_path = os.path.join(os.getcwd(), "medrlcot/config/.env")
medrlcot_config = mce.MedRL_CoT(model_cfg_path)
raw_datasets = data_manager.load_datasets(medrlcot_config.datasets, data_dir=medrlcot_config.data_dir)  # Load raw dataset (original cases) for RM

2025-06-03 18:54:09,530 || INFO || Logger - Setup for MedRL-CoT's log done. This is the beginning of the log.


Generated new log file logs/medrlcot122.log


2025-06-03 18:54:09,734 || INFO || MedRL-CoT Preprocess - Cleaning up aug_med_notes dataset
2025-06-03 18:54:09,736 || INFO || MedRL-CoT Preprocess - Found 19968 rows
2025-06-03 18:54:09,739 || INFO || MedRL-CoT Preprocess - Fixed class naming for 21 rows (0.10516826923076923 %)
2025-06-03 18:54:09,767 || INFO || MedRL-CoT Preprocess - Re-classified 5 classes as 'other', or 76 rows (0.38060897435897434 %)
2025-06-03 18:54:09,768 || INFO || MedRL-CoT Preprocess - Swapped class and sentence values of 131 rows (0.6560496794871795 %)
2025-06-03 18:54:09,769 || INFO || MedRL-CoT Preprocess - Dropped 224 invalid rows (1.1217948717948718 %)
2025-06-03 18:54:09,772 || INFO || MedRL-CoT Preprocess - Cleaning up mimic4 dataset
2025-06-03 18:54:09,773 || INFO || MedRL-CoT Preprocess - Found 29654 rows
2025-06-03 18:54:09,776 || INFO || MedRL-CoT Preprocess - Fixed class naming for 282 rows (0.9509678289606799 %)
2025-06-03 18:54:09,798 || INFO || MedRL-CoT Preprocess - Swapped class and sentence 

In [20]:
cases_datasets['aug_med_notes']

Unnamed: 0,case_id,X,Y,wc_x,wc_y
0,100,Her past medical history included mild Multipl...,"A 34 year old Persian woman, gravida 1, para 0...",147,304
1,101,A 60-year-old female who was previously health...,She was treated with thrombolysis and endovasc...,168,255
2,102,,classification <diagnosis>,0,2
3,103,The chief complains included bilateral hips an...,Staged bilateral total hip arthroplasties were...,362,314
4,104,The patient is a male in his 60s with a past m...,The initial workup was unrevealing for an etio...,236,149
...,...,...,...,...,...
732,93,The patient was a 60-year-old man who referred...,A 45 mm × 37 mm pseudoaneurysm in lateral side...,317,185
733,94,A 38 year old Vietnamese man was admitted with...,The working diagnosis was a collection seconda...,135,415
734,96,A 30-year-old woman was admitted to our instit...,She had undergone endovascular trapping of the...,148,587
735,97,A 51-year-old hypertensive Pakistani male pati...,The patient underwent general anesthesia for t...,275,183


In [21]:
preprocessed_datasets['aug_med_notes']['class'].value_counts()

class
symptoms_labs      9977
thought_process    5318
diagnosis          4373
other                76
Name: count, dtype: int64

In [22]:
# filtered_datasets = dict()
# for key, dataset in cases_datasets.items():
#     dataset['wc_x'] = dataset['X'].apply(lambda x: len(x.split())) 
#     dataset['wc_y'] = dataset['Y'].apply(lambda x: len(x.split())) 
#     filtered_datasets[key] = dataset[(dataset['wc_x'] <= 800) & (dataset['wc_y'] <= 900)].reset_index(drop=True)

# print(cases_datasets['aug_med_notes'].shape, filtered_datasets['aug_med_notes'].shape)
# print(cases_datasets['mimic4'].shape, filtered_datasets['mimic4'].shape)

In [23]:
# Scan every case for the two longest inputs (by whitespace-separated words) and count
# how many inputs are longer than 750 words.
max_words = (None, 0)
next_max = (None, -1)
num_exceed = 0
for key, dataset in cases_datasets.items():
    for row in dataset.iterrows():
        # row is the (index, Series) pair yielded by iterrows()
        cnt = ((key, row), len(row[1]['X'].split()))
        num_exceed += int(cnt[1] > 750)
        if cnt[1] > max_words[1]:
            # New longest case: the previous longest becomes the runner-up.
            max_words, next_max = cnt, max_words
        elif cnt[1] >= next_max[1]:
            next_max = cnt
print(max_words, next_max)
num_exceed

(('mimic4', (32, case_id                                                  135
X          History of Present Illness: Mr. ___ is a ___ y...
Y          1. <diagnosis>  2. <diagnosis>  3. <diagnosis>...
wc_x                                                     791
wc_y                                                     486
Name: 32, dtype: object)), 791) (('mimic4', (48, case_id                                                  152
X          Today, patient has a new cough with productive...
Y          She received 1L NS, 5 mg of morphine, 25 mg of...
wc_x                                                     783
wc_y                                                     362
Name: 48, dtype: object)), 783)


13

In [24]:
# Create the custom dataset object from the per-case frames, using the FLAN-T5 tokenizer.
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-xl")
cases_data = data_manager.MedRL_CoT_Dataset(cases_datasets, seed=0, tokenizer=tokenizer)

2025-06-03 18:54:12,775 || INFO || DataManager - Using train-val split: [0.75, 0.25]
2025-06-03 18:54:12,777 || INFO || DataManager - Split-Shuffle with seed 0
2025-06-03 18:54:12,778 || INFO || DataManager - Splitting aug_med_notes dataset.
2025-06-03 18:54:12,782 || INFO || DataManager - Split into 552 train rows and 185 val rows
2025-06-03 18:54:12,783 || INFO || DataManager - Splitting mimic4 dataset.
2025-06-03 18:54:12,785 || INFO || DataManager - Split into 175 train rows and 59 val rows
2025-06-03 18:54:12,786 || INFO || DataManager - Creating a single joint dataset of dict_keys(['aug_med_notes', 'mimic4']) dataset splits
2025-06-03 18:54:12,789 || INFO || DataManager - Joined 727 train rows
2025-06-03 18:54:12,790 || INFO || DataManager - Shuffling train's rows
2025-06-03 18:54:12,792 || INFO || DataManager - Joined 244 val rows
2025-06-03 18:54:12,792 || INFO || DataManager - Shuffling val's rows
2025-06-03 18:54:12,794 || INFO || DataManager - Returning shuffled train-val sp

In [25]:
cases_data['train']

Unnamed: 0,X,Y
0,"In December 2006, about 6 months earlier than ...",Pain of the first mentioned episode was subsid...
1,The patient was suffering from occasional blee...,An 80-year-old male was referred to us with sq...
2,The culture result showed polymicrobial growth.,The patient was hospitalized for 2 months to r...
3,The 55 years old male patient received a locki...,"The patient consented to implant removal, with..."
4,Prior to treatment with omalizumab and after c...,While mast cells represent the effector cell i...
...,...,...
722,History of Present Illness: Ms. ___ is a ___ w...,"Of note, patient was recently admitted <thoug..."
723,This 36-year-old man presented with a three-mo...,Sinus plain radiographs demonstrated ‘sinusiti...
724,A 60-year-old male who had a history of liver ...,"After admission, the patient maintained with a..."
725,"She also had diabetes mellitus, hypertension, ...",An 80-year-old woman with a history of collaps...


In [26]:
cases_data.get_dataloader('train')

<torch.utils.data.dataloader.DataLoader at 0x7717d35d7c40>

In [27]:
cases_data.get_dataloader('val')

<torch.utils.data.dataloader.DataLoader at 0x7717d35d5a80>

In [28]:
from transformers import AutoModelForSeq2SeqLM
model  = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-xl")

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [29]:
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments, DataCollatorForSeq2Seq

# Training configuration: batch size 1 per device, 3 epochs, one checkpoint per epoch.
# NOTE(review): no evaluation schedule is set here, so whether eval_dataset and
# compute_metrics are used during trainer.train() depends on the library default - TODO confirm.
training_args = Seq2SeqTrainingArguments(
    output_dir="./checkpoints",
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    predict_with_generate=True,
    # NOTE(review): T5-family checkpoints are often reported to be numerically unstable
    # under fp16 - confirm the training loss stays finite, or consider bf16 if supported.
    fp16=True,
    learning_rate=5e-5,
    num_train_epochs=3,
    save_strategy="epoch",
    logging_dir="./logs"
)

In [30]:
# Collator that batches the tokenized examples and returns PyTorch tensors.
# NOTE(review): padding="max_length" is requested without an explicit max_length, so the
# padded length comes from the tokenizer/collator defaults - TODO confirm it fits the inputs.
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=model,
    padding="max_length",
    return_tensors="pt"
)

In [31]:
def compute_metrics(eval_preds): # https://www.datacamp.com/tutorial/flan-t5-tutorial
    """Decode generated ids and reference ids and score them with ``metric``.

    NOTE(review): ``nltk`` and ``metric`` are not imported or defined anywhere in the visible
    notebook; they must exist in the kernel before this function is called.

    Args:
        eval_preds: ``(predictions, labels)`` pair handed over by the trainer.

    Returns:
        Whatever ``metric.compute`` returns for the decoded texts.
    """
    preds, labels = eval_preds
    # Predictions can arrive as a tuple with extra model outputs; the token ids come first.
    if isinstance(preds, tuple):
        preds = preds[0]

    # decode preds and labels
    # -100 only marks ignored positions and is not a real token id, so replace it with the
    # pad id on both sides before handing the arrays to the tokenizer.
    preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # rougeLSum expects newline after each sentence
    decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]
    decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels]

    result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)

    return result

In [32]:
train_tok_dataset = cases_data.get_torch_dataset('train')
val_tok_dataset = cases_data.get_torch_dataset('val')

In [33]:
# Wire the model, arguments, collator, tokenized train/val splits and the metric function
# into one trainer object.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_tok_dataset,
    eval_dataset=val_tok_dataset,
    compute_metrics=compute_metrics,
)

In [None]:
trainer.train()

  batch["labels"] = torch.tensor(batch["labels"], dtype=torch.int64)
Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


Step,Training Loss
