In [1]:
%load_ext autoreload
%autoreload 2
# %reload_ext autoreload

### Download required modules
Make sure you're using Python3.10

Use this to make sure all modules required is installed:
```
pip install datasets dotenv torch SentencePiece accelerate huggingface_hub bitsandbytes transformers
```

In [2]:
# Install required modules
# !pip install datasets dotenv torch SentencePiece accelerate huggingface_hub
# !pip install -U bitsandbytes
# !pip install --force-reinstall transformers==4.35.2
# !pip install --force-reinstall --upgrade bitsandbytes

In [3]:
from medrlcot.config.env import MedRL_CoT
from medrlcot import data_manager
from medrlcot.medrlcot_logger import setup_logger
from dotenv import load_dotenv
from datasets import Features, Value
import datasets as hf_datasets
import logging
import os
import json

# MIMIC-IV Processing/Labeling

In [4]:
# env = load_dotenv()
# model_cfg_path = os.path.join(os.getcwd(), os.getenv('model_config'))
model_cfg_path = os.path.join(os.getcwd(), "medrlcot/config/.env")
medrlcot_config = MedRL_CoT(model_cfg_path)

setup_logger()
logger = logging.getLogger("MedRL-CoT Processing")

# Download datasets
datasets = data_manager.load_datasets(medrlcot_config.datasets, data_dir=medrlcot_config.data_dir)

datasets
# print(datasets['aug_med_notes']['full_note'][0])
# data_manager.load_datasets(medrlcot_config.datasets, data_dir=medrlcot_config.data_dir, load=False)

2025-05-31 23:47:29,822 || INFO || Logger - Setup for MedRL-CoT's log done. This is the beginning of the log.
2025-05-31 23:47:29,823 || INFO || DataManager - Loading datasets: ['aug_med_notes', 'mimic4']
2025-05-31 23:47:29,824 || INFO || DataManager - AGBonnet/augmented-clinical-notes dataset already exists in disk. If the dataset is giving errors or you'd like a fresh install, delete the /home/shared/medrlcot/data/aug_med_notes/train directory.
2025-05-31 23:47:29,825 || INFO || DataManager - Loading saved hugginface AGBonnet/augmented-clinical-notes dataset.
2025-05-31 23:47:29,832 || INFO || DataManager - Successfully loaded AGBonnet/augmented-clinical-notes as key aug_med_notes
2025-05-31 23:47:29,833 || INFO || DataManager - discharge.csv.gz dataset already exists in disk as huggin_face dataset. If the dataset is giving errors or you'd like a fresh install, delete the /home/shared/medrlcot/data/mimic4/hf directory.
2025-05-31 23:47:29,834 || INFO || DataManager - Loading saved h

Generated new log file logs/medrlcot010.log


{'aug_med_notes': Dataset({
     features: ['idx', 'note', 'full_note', 'conversation', 'summary'],
     num_rows: 30000
 }),
 'mimic4': Dataset({
     features: ['note_id', 'subject_id', 'hadm_id', 'note_type', 'note_seq', 'charttime', 'storetime', 'text'],
     num_rows: 331793
 })}

In [5]:
ckpt_file = os.path.join(os.getcwd(), medrlcot_config.data_dir, 'mimic4', 'checkpoint.json')
processed_dir = os.path.join(os.getcwd(), medrlcot_config.data_dir, 'mimic4', 'processed')

In [6]:
# Get checkpoitn index
checkpoint_idx = 0
if os.path.exists(ckpt_file):
    with open(ckpt_file, 'r') as f:
        checkpoint_idx = json.load(f)['ckpt-idx']

checkpoint_idx

284

In [7]:
# Load model directly
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import bitsandbytes

# from transformers import AutoTokenizer, AutoModelForCausalLM

# tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.3-70B-Instruct")
# model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.3-70B-Instruct")

tokenizer = AutoTokenizer.from_pretrained("NousResearch/Nous-Hermes-2-Mistral-7B-DPO", use_fast=False)
model = AutoModelForCausalLM.from_pretrained("NousResearch/Nous-Hermes-2-Mistral-7B-DPO", torch_dtype=torch.float16, 
                                             device_map="auto", load_in_8bit=False, bnb_4bit_compute_dtype=torch.float16, load_in_4bit=True)
# model = AutoModelForCausalLM.from_pretrained("NousResearch/Nous-Hermes-2-Mistral-7B-DPO")

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.
2025-05-31 23:47:34,035 || INFO || accelerate.utils.modeling - We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [8]:
import json
import re

def json_to_dict(full_response:str):
    match = re.search(r'assistant\s*(\{.*|\d+:\s*".*?)$', full_response, re.DOTALL)
    json_data = None
    if match:
        json_str = match.group(1)
        try:
            json_data = json.loads(json_str)
            logger.info("Successful JSON parsed!")
        except json.JSONDecodeError as e:
            # print("Invalid JSON:", e)
            logger.warning("Invalid JSON, trying conversion")
            try:
                pattern = re.compile(r'\d+:\s*"(.+?)"\s*[:\-]*>*\s*(\w+)', re.DOTALL)   # Parse to match and obtain json part
                json_data = {}
                # print(json_str)
                for match in pattern.findall(json_str):
                    # print(match)
                    sentence, label = match
                    cleaned_sentence = sentence.strip()
                    json_data[cleaned_sentence] = label.strip()
                assert len(json_data) > 0
                logger.info("Successful conversion!")
            except:
                logger.error("Provided data schema not parseable/known, skipping this chunk")
                json_data = None
    else:
        logger.error("Nothing found, bad example!")

    return json_data

In [9]:
# Load checkpoints
dataset_example_idx = checkpoint_idx
if checkpoint_idx > 0:
    arrow_dir = os.path.join(processed_dir)
    arrows = [os.path.join(arrow_dir, f) for f in os.listdir(arrow_dir) if f.endswith(".arrow")]
    processed_dataset = hf_datasets.concatenate_datasets([hf_datasets.Dataset.from_file(arrow) for arrow in arrows]).to_dict()
    assert dataset_example_idx == int(processed_dataset['case_id'][-1]) + 1, "Checkpoint index case does not match expected next case index in dataset!"
else:
    processed_dataset = {'sentence': [], 'class': []}
print(dataset_example_idx)
print(len(processed_dataset['sentence']))
processed_dataset

284
29654


{'sentence': ['History of Present Illness:',
  "Pt reports self-discontinuing lasix and spirnolactone ___ weeks ago, because she feels like 'they don't do anything' and that she 'doesn't want to put more chemicals in her.'",
  'She does not follow Na-restricted diets.',
  'In the past week, she notes that she has been having worsening abd distension and discomfort.',
  'She denies ___ edema, or SOB, or orthopnea.',
  'She denies f/c/n/v, d/c, dysuria.',
  'She had food poisoning a week ago from eating stale cake (n/v 20 min after food ingestion), which resolved the same day.',
  'She denies other recent illness or sick contacts.',
  'She notes that she has been noticing gum bleeding while brushing her teeth in recent weeks.',
  'she denies easy bruising, melena, BRBPR, hemetesis, hemoptysis, or hematuria.',
  'Because of her abd pain, she went to OSH ED and was transferred to ___ for further care.',
  'Per ED report, pt has brief period of confusion - she did not recall the ultrasound 

In [None]:
from tqdm import tqdm
from datasets import Dataset
# def preprocess_func(model, dataset):
    # processed_dataset = {'symptoms_labs': [], 'thought_process': [], 'diagnosis': []}
def process_example(example_note):
    sentences = example_note.split('.')
    chunks = example_note.split('\n\n')
    print(len(chunks))
    
    processed_example = {'sentence': [], 'class': []}
    
    initit_prompt = """<|im_start|>system
You are a medical assistant helping organize a patient's clinical notes  that will presented by paragraph/chunk.
    
Your tasks is to read through the chunk of a clinical note and classify different sentences as one of the following keys "symptoms_labs", "thought_process", and "diagnosis":
symptoms_labs: 
- All references to symptoms they are currently or was experiencing, test results, or lab findings
thought_process: 
- The reasoning, considerations, or requested labs made by the doctor
diagnosis: 
- The doctor's final or working diagnosis, including the explanation or background of a diagnosis.

Respond ONLY with a valid JSON object **EXACTLY** with the sentence as key and the classification as the value: "symptoms_labs", "thought_process", and "diagnosis". 

**Do not rephrase or paraphrase sentences. Use the original sentences exactly as provided.**
<|im_end|>
"""

    
    prompts = [f"""{initit_prompt}
    <|im_start|>user
    {chunk}
    <|im_end|>
    <|im_start|>assistant
    """ for chunk in chunks
    ]

    responses = []
    for prompt in tqdm(prompts, "Classifying prompts", leave=False):
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        output = model.generate(**inputs, max_new_tokens=1024, do_sample=False, pad_token_id=tokenizer.eos_token_id)
        # output = model.generate(**inputs, max_new_tokens=1024, do_sample=False, temperature=0.0)

        response = tokenizer.decode(output[0], skip_special_tokens=True)
        parsed_dict = json_to_dict(response)
        if parsed_dict is not None:
            for sent, cat in parsed_dict.items():
                processed_example['class'].append(str(cat))
                processed_example['sentence'].append(str(sent))
                # processed_dataset[cat.strip()].append(sent)
        # break

    return processed_example

# processed_dataset = []

for i in tqdm(range(0, len(datasets['mimic4'])), "Preprocessing dataset"):
# for i in tqdm(range(0, len(datasets['aug_med_notes']['full_note'])), "Preprocessing dataset"):
    logger.info(f"Preprocessing example {i}")
    if i >= dataset_example_idx:
        processed_example = process_example(datasets['mimic4'][i]['text'])
        processed_dataset['class'].extend(processed_example['class'])
        processed_dataset['sentence'].extend(processed_example['sentence'])
        processed_dataset['case_id'].extend([str(i) for _ in processed_example['class']])
            
        logger.info(f"Dumping checkpoint")
        # processed_dir = os.path.join(os.getcwd(), medrlcot_config.data_dir, 'mimic4', 'processed')
        features = Features({
            "sentence": Value("string"),
            "class": Value("string"),
            "case_id": Value("string")
        })
        hf_dataset = Dataset.from_dict(processed_dataset, features=features)
        hf_dataset.save_to_disk(processed_dir)
        # hf_dataset = Dataset.from_dict(processed_dataset)
        # hf_dataset.save_to_disk(processed_dir)
        
        dataset_example_idx += 1
        ckpt_preprocess = {
            'ckpt-idx': dataset_example_idx
        }

        with open(ckpt_file, 'w') as f:
            json.dump(ckpt_preprocess, f)
    else:
        logger.info(f"Example {i} already complete, skipping")
    # processed_dataset.append(process_example(example))
    # if i > 0:
    #     break



# return processed_dataset

Preprocessing dataset:   0%|                         | 0/331793 [00:00<?, ?it/s]2025-05-31 22:33:00,850 || INFO || MedRL-CoT Setup - Preprocessing example 0
2025-05-31 22:33:00,851 || INFO || MedRL-CoT Setup - Example 0 already complete, skipping
2025-05-31 22:33:00,852 || INFO || MedRL-CoT Setup - Preprocessing example 1
2025-05-31 22:33:00,853 || INFO || MedRL-CoT Setup - Example 1 already complete, skipping
2025-05-31 22:33:00,853 || INFO || MedRL-CoT Setup - Preprocessing example 2
2025-05-31 22:33:00,854 || INFO || MedRL-CoT Setup - Example 2 already complete, skipping
2025-05-31 22:33:00,855 || INFO || MedRL-CoT Setup - Preprocessing example 3
2025-05-31 22:33:00,855 || INFO || MedRL-CoT Setup - Example 3 already complete, skipping
2025-05-31 22:33:00,860 || INFO || MedRL-CoT Setup - Preprocessing example 4
2025-05-31 22:33:00,861 || INFO || MedRL-CoT Setup - Example 4 already complete, skipping
2025-05-31 22:33:00,862 || INFO || MedRL-CoT Setup - Preprocessing example 5
2025-05-

5



Classifying prompts:   0%|                                | 0/5 [00:00<?, ?it/s][A2025-05-31 22:33:18,018 || INFO || MedRL-CoT Setup - Successful JSON parsed!

Preprocessing dataset:   0%|              | 244/331793 [00:19<09:14, 597.96it/s]2025-05-31 22:33:46,138 || INFO || MedRL-CoT Setup - Successful JSON parsed!

Classifying prompts:  40%|█████████▌              | 2/5 [00:44<01:10, 23.42s/it][A2025-05-31 22:34:07,837 || INFO || MedRL-CoT Setup - Successful JSON parsed!

Classifying prompts:  60%|██████████████▍         | 3/5 [01:06<00:45, 22.63s/it][A2025-05-31 22:34:34,327 || INFO || MedRL-CoT Setup - Successful JSON parsed!

Classifying prompts:  80%|███████████████████▏    | 4/5 [01:33<00:24, 24.16s/it][A2025-05-31 22:34:53,937 || INFO || MedRL-CoT Setup - Successful JSON parsed!

Classifying prompts: 100%|████████████████████████| 5/5 [01:52<00:00, 22.52s/it][A
                                                                                [A2025-05-31 22:34:53,943 || INF

Saving the dataset (0/1 shards):   0%|          | 0/29048 [00:00<?, ? examples/s]

Preprocessing dataset:   0%|            | 276/331793 [01:53<74:14:44,  1.24it/s]2025-05-31 22:34:54,170 || INFO || MedRL-CoT Setup - Preprocessing example 276


27



2025-05-31 22:35:36,195 || ERROR || MedRL-CoT Setup - Provided data schema not parseable/known, skipping this chunk

Classifying prompts:   4%|▊                      | 1/27 [00:42<18:12, 42.02s/it][A2025-05-31 22:35:56,246 || INFO || MedRL-CoT Setup - Successful JSON parsed!

Classifying prompts:   7%|█▋                     | 2/27 [01:02<12:07, 29.10s/it][A2025-05-31 22:36:12,032 || INFO || MedRL-CoT Setup - Successful JSON parsed!

Classifying prompts:  11%|██▌                    | 3/27 [01:17<09:12, 23.02s/it][A2025-05-31 22:36:13,637 || INFO || MedRL-CoT Setup - Successful JSON parsed!

Classifying prompts:  15%|███▍                   | 4/27 [01:19<05:34, 14.57s/it][A2025-05-31 22:36:15,820 || INFO || MedRL-CoT Setup - Successful JSON parsed!

Classifying prompts:  19%|████▎                  | 5/27 [01:21<03:42, 10.10s/it][A2025-05-31 22:36:30,681 || INFO || MedRL-CoT Setup - Successful JSON parsed!

Classifying prompts:  22%|█████                  | 6/27 [01:36<04:06, 11.72s/

Saving the dataset (0/1 shards):   0%|          | 0/29171 [00:00<?, ? examples/s]

Preprocessing dataset:   0%|           | 277/331793 [07:16<373:10:36,  4.05s/it]2025-05-31 22:40:16,950 || INFO || MedRL-CoT Setup - Preprocessing example 277


11



Classifying prompts:   0%|                               | 0/11 [00:00<?, ?it/s][A2025-05-31 22:40:18,392 || INFO || MedRL-CoT Setup - Successful JSON parsed!

Classifying prompts:   9%|██                     | 1/11 [00:01<00:14,  1.44s/it][A2025-05-31 22:40:26,717 || INFO || MedRL-CoT Setup - Successful JSON parsed!

Classifying prompts:  18%|████▏                  | 2/11 [00:09<00:49,  5.49s/it][A2025-05-31 22:40:30,927 || INFO || MedRL-CoT Setup - Successful JSON parsed!

Classifying prompts:  27%|██████▎                | 3/11 [00:13<00:39,  4.91s/it][A2025-05-31 22:40:39,447 || INFO || MedRL-CoT Setup - Successful JSON parsed!

Classifying prompts:  36%|████████▎              | 4/11 [00:22<00:44,  6.33s/it][A2025-05-31 22:40:46,391 || INFO || MedRL-CoT Setup - Successful JSON parsed!

Classifying prompts:  45%|██████████▍            | 5/11 [00:29<00:39,  6.55s/it][A2025-05-31 22:40:48,760 || INFO || MedRL-CoT Setup - Successful JSON parsed!

Classifying prompts:  55%|███████

Saving the dataset (0/1 shards):   0%|          | 0/29226 [00:00<?, ? examples/s]

Preprocessing dataset:   0%|           | 278/331793 [09:21<531:24:07,  5.77s/it]2025-05-31 22:42:22,017 || INFO || MedRL-CoT Setup - Preprocessing example 278


22



Classifying prompts:   0%|                               | 0/22 [00:00<?, ?it/s][A2025-05-31 22:42:31,942 || INFO || MedRL-CoT Setup - Successful JSON parsed!

2025-05-31 22:43:23,817 || ERROR || MedRL-CoT Setup - Provided data schema not parseable/known, skipping this chunk

Classifying prompts:   9%|██                     | 2/22 [01:01<11:31, 34.60s/it][A2025-05-31 22:43:25,388 || INFO || MedRL-CoT Setup - Successful JSON parsed!

Classifying prompts:  14%|███▏                   | 3/22 [01:03<06:10, 19.52s/it][A2025-05-31 22:43:27,803 || INFO || MedRL-CoT Setup - Successful JSON parsed!

Classifying prompts:  18%|████▏                  | 4/22 [01:05<03:49, 12.77s/it][A2025-05-31 22:43:49,024 || INFO || MedRL-CoT Setup - Successful JSON parsed!

Classifying prompts:  23%|█████▏                 | 5/22 [01:27<04:28, 15.82s/it][A2025-05-31 22:44:33,579 || ERROR || MedRL-CoT Setup - Nothing found, bad example!

Classifying prompts:  27%|██████▎                | 6/22 [02:11<06:49, 25

Saving the dataset (0/1 shards):   0%|          | 0/29323 [00:00<?, ? examples/s]

Preprocessing dataset:   0%|          | 279/331793 [13:51<1014:54:35, 11.02s/it]2025-05-31 22:46:51,898 || INFO || MedRL-CoT Setup - Preprocessing example 279


22



Classifying prompts:   0%|                               | 0/22 [00:00<?, ?it/s][A2025-05-31 22:47:01,792 || INFO || MedRL-CoT Setup - Successful JSON parsed!

Classifying prompts:   5%|█                      | 1/22 [00:09<03:27,  9.89s/it][A2025-05-31 22:47:29,973 || INFO || MedRL-CoT Setup - Successful JSON parsed!

Classifying prompts:   9%|██                     | 2/22 [00:38<06:52, 20.65s/it][A2025-05-31 22:47:32,408 || INFO || MedRL-CoT Setup - Successful JSON parsed!

Classifying prompts:  14%|███▏                   | 3/22 [00:40<03:54, 12.33s/it][A2025-05-31 22:48:24,156 || INFO || MedRL-CoT Setup - Successful JSON parsed!

Classifying prompts:  18%|████▏                  | 4/22 [01:32<08:22, 27.89s/it][A2025-05-31 22:48:32,422 || INFO || MedRL-CoT Setup - Successful JSON parsed!

Classifying prompts:  23%|█████▏                 | 5/22 [01:40<05:53, 20.82s/it][A2025-05-31 22:48:34,572 || INFO || MedRL-CoT Setup - Successful JSON parsed!

Classifying prompts:  27%|██████▎

Saving the dataset (0/1 shards):   0%|          | 0/29430 [00:00<?, ? examples/s]

Preprocessing dataset:   0%|          | 280/331793 [17:17<1511:00:39, 16.41s/it]2025-05-31 22:50:18,000 || INFO || MedRL-CoT Setup - Preprocessing example 280


12



2025-05-31 22:50:35,048 || ERROR || MedRL-CoT Setup - Provided data schema not parseable/known, skipping this chunk

Classifying prompts:   8%|█▉                     | 1/12 [00:17<03:07, 17.04s/it][A2025-05-31 22:50:37,769 || INFO || MedRL-CoT Setup - Successful JSON parsed!

2025-05-31 22:51:46,919 || ERROR || MedRL-CoT Setup - Provided data schema not parseable/known, skipping this chunk

Classifying prompts:  25%|█████▊                 | 3/12 [01:28<05:26, 36.26s/it][A2025-05-31 22:51:49,090 || INFO || MedRL-CoT Setup - Successful JSON parsed!

Classifying prompts:  33%|███████▋               | 4/12 [01:31<03:02, 22.80s/it][A2025-05-31 22:51:56,072 || INFO || MedRL-CoT Setup - Successful JSON parsed!

Classifying prompts:  42%|█████████▌             | 5/12 [01:38<01:59, 17.10s/it][A2025-05-31 22:51:58,537 || INFO || MedRL-CoT Setup - Successful JSON parsed!

Classifying prompts:  50%|███████████▌           | 6/12 [01:40<01:12, 12.12s/it][A2025-05-31 22:52:15,912 || INFO || Med

Saving the dataset (0/1 shards):   0%|          | 0/29483 [00:00<?, ? examples/s]

Preprocessing dataset:   0%|          | 281/331793 [21:01<2237:03:05, 24.29s/it]2025-05-31 22:54:02,143 || INFO || MedRL-CoT Setup - Preprocessing example 281


15



2025-05-31 22:54:14,567 || ERROR || MedRL-CoT Setup - Provided data schema not parseable/known, skipping this chunk

Classifying prompts:   7%|█▌                     | 1/15 [00:12<02:53, 12.42s/it][A2025-05-31 22:54:17,130 || INFO || MedRL-CoT Setup - Successful JSON parsed!

Classifying prompts:  13%|███                    | 2/15 [00:14<01:26,  6.62s/it][A2025-05-31 22:54:19,898 || INFO || MedRL-CoT Setup - Successful JSON parsed!

Classifying prompts:  20%|████▌                  | 3/15 [00:17<00:58,  4.86s/it][A2025-05-31 22:54:32,721 || INFO || MedRL-CoT Setup - Successful JSON parsed!

Classifying prompts:  27%|██████▏                | 4/15 [00:30<01:28,  8.01s/it][A

In [10]:
processed_dataset['sentence'][-40:-10]

['-There are complex (>4mm) and simple atheroma in the descending thoracic aorta.',
 'There are complex & simple atheroma in the aortic arch.',
 '-There are three aortic valve leaflets.',
 'The aortic valve leaflets are severely thickened/deformed.',
 'There is critical aortic valve stenosis (valve area <0.8cm2).',
 'Mild to moderate aortic regurgitation is seen.',
 '-The mitral valve leaflets are moderately thickened, with severe mitral annular calcification.',
 'Moderate to severe (3+) mitral regurgitation is seen.',
 '-There is no pericardial effusion.',
 'The patient is AV paced on low dose phenylephrine infusion.',
 'There is a well seated bioprosthetic valve in the aortic position.',
 'There is appropriate leaflet excursion.',
 'There is no AI.',
 'Gradients across aortic valve are appropriate.',
 'The mitral valve regurgitation is unchanged vs slightly worsened from prebypass exam.',
 'Biventricular function is maintained.',
 'The aorta remains intact.',
 'Dr. ___ was notified o

In [11]:
processed_dataset['class'][-40:-10]

['diagnosis',
 'diagnosis',
 'diagnosis',
 'diagnosis',
 'diagnosis',
 'diagnosis',
 'diagnosis',
 'diagnosis',
 'diagnosis',
 'symptoms_labs',
 'symptoms_labs',
 'symptoms_labs',
 'symptoms_labs',
 'symptoms_labs',
 'symptoms_labs',
 'symptoms_labs',
 'symptoms_labs',
 'thought_process',
 'thought_process',
 'thought_process',
 'diagnosis',
 'diagnosis',
 'diagnosis',
 'diagnosis',
 'diagnosis',
 'diagnosis',
 'diagnosis',
 'symptoms_labs',
 'symptoms_labs',
 'symptoms_labs']

In [17]:
hf_dataset = hf_datasets.Dataset.from_dict(processed_dataset)

AttributeError: 'Dataset' object has no attribute 'items'

In [54]:
processed_dir = os.path.join(os.getcwd(), medrlcot_config.data_dir, 'mimic4', 'processed')
hf_dataset = Dataset.from_dict(processed_dataset)
hf_dataset.save_to_disk(processed_dir)

Saving the dataset (0/1 shards):   0%|          | 0/666 [00:00<?, ? examples/s]

# AUG Medical Notes Processing/Labeling

In [None]:
# Get checkpoitn index
checkpoint_idx = 0
if os.path.exists(ckpt_file):
    with open(ckpt_file, 'r') as f:
        checkpoint_idx = json.load(f)['ckpt-idx']

checkpoint_idx

In [None]:
ckpt_file = os.path.join(os.getcwd(), medrlcot_config.data_dir, 'aug_med_notes', 'checkpoint.json')
processed_dir = os.path.join(os.getcwd(), medrlcot_config.data_dir, 'aug_med_notes', 'processed')

In [None]:
# Load checkpoints
dataset_example_idx = checkpoint_idx
if checkpoint_idx > 0:
    arrow_dir = os.path.join(processed_dir)
    arrows = [os.path.join(arrow_dir, f) for f in os.listdir(arrow_dir) if f.endswith(".arrow")]
    processed_dataset = hf_datasets.concatenate_datasets([hf_datasets.Dataset.from_file(arrow) for arrow in arrows]).to_dict()
    assert dataset_example_idx == int(processed_dataset['case_id'][-1]) + 1, "Checkpoint index case does not match expected next case index in dataset!"
else:
    processed_dataset = {'sentence': [], 'class': []}
print(dataset_example_idx)
print(len(processed_dataset['sentence']))
processed_dataset

In [None]:
for i in tqdm(range(0, len(datasets['aug_med_notes']['full_note'])), "Preprocessing dataset"):
    logger.info(f"Preprocessing example {i}")
    if i >= dataset_example_idx:
        processed_example = process_example(datasets['aug_med_notes']['full_note'][i])
        processed_dataset['class'].extend(processed_example['class'])
        processed_dataset['sentence'].extend(processed_example['sentence'])
        processed_dataset['case_id'].extend([str(i) for _ in processed_example['class']])

        logger.info(f"Dumping checkpoint")
        # processed_dir = os.path.join(os.getcwd(), medrlcot_config.data_dir, 'aug_med_notes', 'processed')

        features = Features({
            "sentence": Value("string"),
            "class": Value("string"),
            "case_id": Value("string")
        })
        hf_dataset = Dataset.from_dict(processed_dataset, features=features)
        hf_dataset.save_to_disk(processed_dir)

        dataset_example_idx += 1
        ckpt_preprocess = {
            'ckpt-idx': dataset_example_idx
        }

        with open(ckpt_file, 'w') as f:
            json.dump(ckpt_preprocess, f)
    else:
        logger.info(f"Example {i} already complete, skipping")