In [1]:
%load_ext autoreload
%autoreload 2
# %reload_ext autoreload

### Download required modules
Make sure you're using Python3.10

Use this to make sure all modules required is installed:
```
pip install datasets dotenv torch SentencePiece accelerate huggingface_hub bitsandbytes transformers
```

In [2]:
# Install required modules
# !pip install datasets dotenv torch SentencePiece accelerate huggingface_hub
# !pip install -U bitsandbytes
# !pip install --force-reinstall transformers==4.35.2
# !pip install --force-reinstall --upgrade bitsandbytes

In [3]:
from medrlcot.config.env import MedRL_CoT
from medrlcot import data_manager
from medrlcot.medrlcot_logger import setup_logger
from dotenv import load_dotenv
from datasets import Features, Value
import datasets as hf_datasets
import logging
import os
import json

In [4]:
# env = load_dotenv()
# model_cfg_path = os.path.join(os.getcwd(), os.getenv('model_config'))
model_cfg_path = os.path.join(os.getcwd(), "medrlcot/config/.env")
medrlcot_config = MedRL_CoT(model_cfg_path)

setup_logger()
logger = logging.getLogger("MedRL-CoT Setup")

# Download datasets
datasets = data_manager.load_datasets(medrlcot_config.datasets, data_dir=medrlcot_config.data_dir)

datasets
# print(datasets['aug_med_notes']['full_note'][0])
# data_manager.load_datasets(medrlcot_config.datasets, data_dir=medrlcot_config.data_dir, load=False)

2025-05-31 22:34:26,898 || INFO || Logger - Setup for MedRL-CoT's log done. This is the beginning of the log.
2025-05-31 22:34:26,899 || INFO || DataManager - Loading datasets: ['aug_med_notes', 'mimic4']
2025-05-31 22:34:26,900 || INFO || DataManager - AGBonnet/augmented-clinical-notes dataset already exists in disk. If the dataset is giving errors or you'd like a fresh install, delete the /home/ubuntu/medrlcot/data/aug_med_notes directory.
2025-05-31 22:34:26,901 || INFO || DataManager - Loading saved hugginface AGBonnet/augmented-clinical-notes dataset.
2025-05-31 22:34:26,909 || INFO || DataManager - Successfully loaded AGBonnet/augmented-clinical-notes as key aug_med_notes


Generated new log file logs/medrlcot004.log


Generating train split: 0 examples [00:00, ? examples/s]

Saving the dataset (0/8 shards):   0%|          | 0/331793 [00:00<?, ? examples/s]

2025-05-31 22:35:59,633 || INFO || DataManager - Done loading discharge.csv.gz and saved hugging_face dataset to /home/ubuntu/medrlcot/data/mimic4/hf.
2025-05-31 22:35:59,635 || INFO || DataManager - Successfully loaded 2 datasets: ['aug_med_notes', 'mimic4']


{'aug_med_notes': Dataset({
     features: ['idx', 'note', 'full_note', 'conversation', 'summary'],
     num_rows: 30000
 }),
 'mimic4': DatasetDict({
     train: Dataset({
         features: ['note_id', 'subject_id', 'hadm_id', 'note_type', 'note_seq', 'charttime', 'storetime', 'text'],
         num_rows: 331793
     })
 })}

In [5]:
ckpt_file = os.path.join(os.getcwd(), medrlcot_config.data_dir, 'aug_med_notes', 'checkpoint.json')
processed_dir = os.path.join(os.getcwd(), medrlcot_config.data_dir, 'aug_med_notes', 'processed')

In [6]:
# Get checkpoitn index
checkpoint_idx = 0
if os.path.exists(ckpt_file):
    with open(ckpt_file, 'r') as f:
        checkpoint_idx = json.load(f)['ckpt-idx']

checkpoint_idx

881

In [7]:
# Load model directly
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import bitsandbytes

# from transformers import AutoTokenizer, AutoModelForCausalLM

# tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.3-70B-Instruct")
# model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.3-70B-Instruct")

tokenizer = AutoTokenizer.from_pretrained("NousResearch/Nous-Hermes-2-Mistral-7B-DPO", use_fast=False)
model = AutoModelForCausalLM.from_pretrained("NousResearch/Nous-Hermes-2-Mistral-7B-DPO", torch_dtype=torch.float16, 
                                             device_map="auto", load_in_8bit=False, load_in_4bit=True)
# model = AutoModelForCausalLM.from_pretrained("NousResearch/Nous-Hermes-2-Mistral-7B-DPO")

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.
2025-05-31 22:37:08,735 || INFO || accelerate.utils.modeling - We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [8]:
import json
import re

def json_to_dict(full_response:str):
    match = re.search(r'assistant\s*(\{.*|\d+:\s*".*?)$', full_response, re.DOTALL)
    json_data = None
    if match:
        json_str = match.group(1)
        try:
            json_data = json.loads(json_str)
            logger.info("Successful JSON parsed!")
        except json.JSONDecodeError as e:
            # print("Invalid JSON:", e)
            logger.warning("Invalid JSON, trying conversion")
            try:
                pattern = re.compile(r'\d+:\s*"(.+?)"\s*[:\-]*>*\s*(\w+)', re.DOTALL)   # Parse to match and obtain json part
                json_data = {}
                # print(json_str)
                for match in pattern.findall(json_str):
                    # print(match)
                    sentence, label = match
                    cleaned_sentence = sentence.strip()
                    json_data[cleaned_sentence] = label.strip()
                assert len(output) > 0
                logger.info("Successful conversion!")
            except:
                logger.error("Provided data schema not parseable/known, skipping this chunk")
                json_data = None
    else:
        logger.error("Nothing found, bad example!")

    return json_data

In [9]:
# Load checkpoints
dataset_example_idx = checkpoint_idx
if checkpoint_idx > 0:
    arrow_dir = os.path.join(processed_dir)
    arrows = [os.path.join(arrow_dir, f) for f in os.listdir(arrow_dir) if f.endswith(".arrow")]
    processed_dataset = hf_datasets.concatenate_datasets([hf_datasets.Dataset.from_file(arrow) for arrow in arrows]).to_dict()
    assert dataset_example_idx == int(processed_dataset['case_id'][-1]) + 1, "Checkpoint index case does not match expected next case index in dataset!"
else:
    processed_dataset = {'sentence': [], 'class': []}
print(dataset_example_idx)
print(len(processed_dataset['sentence']))
processed_dataset

881
19642


{'sentence': ['A sixteen year-old girl, presented to our Outpatient department with the complaints of discomfort in the neck and lower back as well as restriction of body movements.',
  'She was not able to maintain an erect posture and would tend to fall on either side while standing up from a sitting position.',
  'She would keep her head turned to the right and upwards due to the sustained contraction of the neck muscles.',
  'There was a sideways bending of the back in the lumbar region.',
  'To counter the abnormal positioning of the back and neck, she would keep her limbs in a specific position to allow her body weight to be supported.',
  'Due to the restrictions with the body movements at the neck and in the lumbar region, she would require assistance in standing and walking.',
  'She would require her parents to help her with daily chores, including all activities of self-care.',
  'She had been experiencing these difficulties for the past four months since when she was introd

In [None]:
from tqdm import tqdm
from datasets import Dataset
# def preprocess_func(model, dataset):
    # processed_dataset = {'symptoms_labs': [], 'thought_process': [], 'diagnosis': []}
def process_example(example_note):
    sentences = example_note.split('.')
    chunks = example_note.split('\\n')
    
    processed_example = {'sentence': [], 'class': []}
    
    initit_prompt = """<|im_start|>system
You are a medical assistant helping organize a patient's clinical notes  that will presented by paragraph/chunk.
    
Your tasks is to read through the chunk of a clinical note and classify different sentences as one of the following keys "symptoms_labs", "thought_process", and "diagnosis":
symptoms_labs: 
- All references to symptoms they are currently or was experiencing, test results, or lab findings
thought_process: 
- The reasoning, considerations, or requested labs made by the doctor
diagnosis: 
- The doctor's final or working diagnosis, including the explanation or background of a diagnosis.

Respond ONLY with a valid JSON object **EXACTLY** with the sentence as key and the classification as the value: "symptoms_labs", "thought_process", and "diagnosis". 

**Do not rephrase or paraphrase sentences. Use the original sentences exactly as provided.**
<|im_end|>
"""

    
    prompts = [f"""{initit_prompt}
    <|im_start|>user
    {chunk}
    <|im_end|>
    <|im_start|>assistant
    """ for chunk in chunks
    ]

    responses = []
    for prompt in tqdm(prompts, "Classifying prompts", leave=False):
        # print(prompt)
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        output = model.generate(**inputs, max_new_tokens=1024, do_sample=False, pad_token_id=tokenizer.eos_token_id)
        # output = model.generate(**inputs, max_new_tokens=1024, do_sample=False, temperature=0.0)

        response = tokenizer.decode(output[0], skip_special_tokens=True)
        parsed_dict = json_to_dict(response)
        if parsed_dict is not None:
            for sent, cat in parsed_dict.items():
                processed_example['class'].append(str(cat))
                processed_example['sentence'].append(str(sent))
                # processed_dataset[cat.strip()].append(sent)
        # break

    return processed_example

# processed_dataset = []

for i in tqdm(range(0, len(datasets['aug_med_notes']['full_note'])), "Preprocessing dataset"):
    logger.info(f"Preprocessing example {i}")
    if i >= dataset_example_idx:
        processed_example = process_example(datasets['aug_med_notes']['full_note'][i])
        processed_dataset['class'].extend(processed_example['class'])
        processed_dataset['sentence'].extend(processed_example['sentence'])
        processed_dataset['case_id'].extend([str(i) for _ in processed_example['class']])

        logger.info(f"Dumping checkpoint")
        # processed_dir = os.path.join(os.getcwd(), medrlcot_config.data_dir, 'aug_med_notes', 'processed')

        features = Features({
            "sentence": Value("string"),
            "class": Value("string"),
            "case_id": Value("string")
        })
        hf_dataset = Dataset.from_dict(processed_dataset, features=features)
        hf_dataset.save_to_disk(processed_dir)

        dataset_example_idx += 1
        ckpt_preprocess = {
            'ckpt-idx': dataset_example_idx
        }

        with open(ckpt_file, 'w') as f:
            json.dump(ckpt_preprocess, f)
    else:
        logger.info(f"Example {i} already complete, skipping")
    # processed_dataset.append(process_example(example))
    # if i > 0:
    #     break

    


# return processed_dataset

Preprocessing dataset:   0%|                                                                                                                                                        | 0/30000 [00:00<?, ?it/s]2025-05-31 22:41:30,313 || INFO || MedRL-CoT Setup - Preprocessing example 0
2025-05-31 22:41:30,314 || INFO || MedRL-CoT Setup - Example 0 already complete, skipping
2025-05-31 22:41:30,315 || INFO || MedRL-CoT Setup - Preprocessing example 1
2025-05-31 22:41:30,315 || INFO || MedRL-CoT Setup - Example 1 already complete, skipping
2025-05-31 22:41:30,316 || INFO || MedRL-CoT Setup - Preprocessing example 2
2025-05-31 22:41:30,316 || INFO || MedRL-CoT Setup - Example 2 already complete, skipping
2025-05-31 22:41:30,317 || INFO || MedRL-CoT Setup - Preprocessing example 3
2025-05-31 22:41:30,318 || INFO || MedRL-CoT Setup - Example 3 already complete, skipping
2025-05-31 22:41:30,318 || INFO || MedRL-CoT Setup - Preprocessing example 4
2025-05-31 22:41:30,319 || INFO || MedRL-CoT Setu

Saving the dataset (0/1 shards):   0%|          | 0/19671 [00:00<?, ? examples/s]

Preprocessing dataset:   3%|████                                                                                                                                        | 882/30000 [03:03<6:42:49,  1.20it/s]2025-05-31 22:44:33,880 || INFO || MedRL-CoT Setup - Preprocessing example 882

Classifying prompts:   0%|                                                                                                                                                              | 0/4 [00:00<?, ?it/s][A2025-05-31 22:45:12,327 || INFO || MedRL-CoT Setup - Successful JSON parsed!

Classifying prompts:  25%|█████████████████████████████████████▌                                                                                                                | 1/4 [00:38<01:54, 38.25s/it][A2025-05-31 22:46:39,588 || ERROR || MedRL-CoT Setup - Nothing found, bad example!

2025-05-31 22:46:58,962 || ERROR || MedRL-CoT Setup - Provided data schema not parseable/known, skipping this chunk

Classifying prompt

Saving the dataset (0/1 shards):   0%|          | 0/19684 [00:00<?, ? examples/s]

Preprocessing dataset:   3%|████                                                                                                                                       | 883/30000 [05:57<15:43:00,  1.94s/it]2025-05-31 22:47:27,318 || INFO || MedRL-CoT Setup - Preprocessing example 883

2025-05-31 22:47:47,206 || ERROR || MedRL-CoT Setup - Provided data schema not parseable/known, skipping this chunk

Classifying prompts:  50%|███████████████████████████████████████████████████████████████████████████                                                                           | 1/2 [00:19<00:19, 19.78s/it][A2025-05-31 22:48:54,702 || INFO || MedRL-CoT Setup - Successful JSON parsed!

Classifying prompts: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [01:27<00:00, 47.85s/it][A
                                                                                                    

Saving the dataset (0/1 shards):   0%|          | 0/19704 [00:00<?, ? examples/s]

Preprocessing dataset:   3%|████                                                                                                                                       | 884/30000 [07:24<22:02:11,  2.72s/it]2025-05-31 22:48:54,863 || INFO || MedRL-CoT Setup - Preprocessing example 884

2025-05-31 22:49:09,123 || ERROR || MedRL-CoT Setup - Provided data schema not parseable/known, skipping this chunk

2025-05-31 22:49:47,972 || ERROR || MedRL-CoT Setup - Provided data schema not parseable/known, skipping this chunk

2025-05-31 22:50:13,712 || ERROR || MedRL-CoT Setup - Provided data schema not parseable/known, skipping this chunk

Classifying prompts: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [01:18<00:00, 27.34s/it][A
                                                                                                                                                         

Saving the dataset (0/1 shards):   0%|          | 0/19704 [00:00<?, ? examples/s]

Preprocessing dataset:   3%|████                                                                                                                                       | 885/30000 [08:43<29:58:40,  3.71s/it]2025-05-31 22:50:13,871 || INFO || MedRL-CoT Setup - Preprocessing example 885

Classifying prompts:   0%|                                                                                                                                                              | 0/5 [00:00<?, ?it/s][A2025-05-31 22:50:46,467 || INFO || MedRL-CoT Setup - Successful JSON parsed!

2025-05-31 22:51:03,414 || ERROR || MedRL-CoT Setup - Provided data schema not parseable/known, skipping this chunk

2025-05-31 22:51:18,412 || ERROR || MedRL-CoT Setup - Provided data schema not parseable/known, skipping this chunk

Classifying prompts:  60%|██████████████████████████████████████████████████████████████████████████████████████████                                                            | 3/5 [01:04<00:3

Saving the dataset (0/1 shards):   0%|          | 0/19724 [00:00<?, ? examples/s]

Preprocessing dataset:   3%|████                                                                                                                                       | 886/30000 [10:29<44:56:13,  5.56s/it]2025-05-31 22:52:00,014 || INFO || MedRL-CoT Setup - Preprocessing example 886

Classifying prompts:   0%|                                                                                                                                                              | 0/4 [00:00<?, ?it/s][A2025-05-31 22:52:30,737 || INFO || MedRL-CoT Setup - Successful JSON parsed!

2025-05-31 22:52:46,550 || ERROR || MedRL-CoT Setup - Provided data schema not parseable/known, skipping this chunk

2025-05-31 22:53:06,350 || ERROR || MedRL-CoT Setup - Provided data schema not parseable/known, skipping this chunk

2025-05-31 22:53:32,833 || ERROR || MedRL-CoT Setup - Provided data schema not parseable/known, skipping this chunk

Classifying prompts: 100%|██████████████████████████████████████████████████

Saving the dataset (0/1 shards):   0%|          | 0/19735 [00:00<?, ? examples/s]

Preprocessing dataset:   3%|████                                                                                                                                       | 887/30000 [12:02<62:42:55,  7.76s/it]2025-05-31 22:53:32,993 || INFO || MedRL-CoT Setup - Preprocessing example 887

2025-05-31 22:53:46,942 || ERROR || MedRL-CoT Setup - Provided data schema not parseable/known, skipping this chunk

2025-05-31 22:54:05,552 || ERROR || MedRL-CoT Setup - Provided data schema not parseable/known, skipping this chunk

2025-05-31 22:54:28,985 || ERROR || MedRL-CoT Setup - Provided data schema not parseable/known, skipping this chunk

Classifying prompts:  75%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                                     | 3/4 [00:55<00:19, 19.75s/it][A

In [12]:
processed_dataset['sentence'][-40:-10]

['The patient was electively taken to the operating theatre for a middle ear exploration and mastoidectomy with the intention of confirming the nature of the tumour and performing a subtotal removal.',
 'Intraoperative frozen section raised the possibility of a neuroendocrine tumour.',
 'The patient was successfully discharged on postoperative day 1.',
 'At 2 months follow-up postoperative hearing was preserved with normal facial nerve function and no evidence of recurrence.',
 'Examination of histological sections stained with haematoxylin and eosin revealed small fragments of mucosa with underlying pieces of vital bone.',
 'The mucosa was surfaced by nonkeratinising stratified squamous epithelium which showed no epithelial dysplasia with no evidence of surface origin of the tumour.',
 'The underlying lamina propria was extensively infiltrated by an unencapsulated tumour.',
 'The tumour cells were arranged in small irregular nests and trabeculae with surrounding fibrosis.',
 'Occasion

In [None]:
processed_dataset['class'][-40:-10]

In [274]:
# processed_dataset = preprocess_func(model, datasets['aug_med_notes']['full_note'])

The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Setting `pad_token_id` to `eos_token_id`:32000 for open-end generation.


KeyboardInterrupt: 

In [273]:
# print(len(processed_dataset['class']), len(processed_dataset['sentence']))

53 53


In [330]:
Dataset.from_file(os.path.join(processed_dir, 'data-00000-of-00001.arrow'))

Dataset({
    features: ['sentence', 'class', 'case_id'],
    num_rows: 569
})