## Utils

### ✅ Checkpoint & Resume Feature Added

This notebook now includes automatic checkpoint/resume functionality to handle API quota limits:

**Key Features:**
1. **Auto-save progress**: Each processing step saves a checkpoint after every record
2. **Smart resume**: When you re-run a cell, it automatically resumes from where it stopped
3. **Rate limiting**: Automatically adds 7-second delay between API calls to prevent 429 errors
4. **Smart retry**: Detects 429 quota errors, extracts retry delay from error message, and waits before retrying
5. **Progress tracking**: Shows how many records have been processed

**Rate Limiting:**
- Free tier limit: 10 requests per minute
- Automatic delay: 7 seconds between each API call (configurable via `API_DELAY_SECONDS`)
- If the rate limit is hit, the code automatically waits for the server-suggested retry time and then retries the request

**How to use:**
- Just run the cells normally
- The code will automatically handle rate limiting with delays between requests
- If you hit the quota limit, progress is saved automatically
- Re-run the same cell later (after the quota resets) — it will resume from where it stopped
- No need to start from beginning!

**Files saved as checkpoints:**
- `note_simp.pkl` - Simplified notes
- `history_of_present_illness.pkl` - History processing
- `past_medical_history.pkl` - Past medical history
- `allergies.pkl` - Allergies extraction
- `med_on_adm.pkl` - Medications on admission

In [None]:
import pandas as pd
import dill
from google import genai
import time
from tqdm import tqdm, trange
import os
import re

# --- Gemini model / shared client --------------------------------------------
# Model used for every request (can be switched to "gemini-1.5-flash-latest"
# for a cheaper option).
GEMINI_MODEL = "gemini-2.5-flash"

# Shared client; stays None until the API configuration cell assigns it.
# call_api() refuses to run while it is None.
client: genai.Client | None = None

# --- Client-side rate limiting -----------------------------------------------
# Minimum spacing, in seconds, between consecutive API calls. The free tier
# allows 10 requests/minute (one every 6 s); 7 s leaves a safety margin.
API_DELAY_SECONDS = 7.0
# Time of the most recent call, updated by call_api(); 0 means "no call yet".
_last_api_call_time = 0


def call_api(content: str) -> str:
    """Send one prompt to Gemini and return the generated text.

    Before issuing the request, sleeps as needed so that consecutive calls are
    at least ``API_DELAY_SECONDS`` apart (client-side rate limiting).

    :param content: prompt text, sent as ``contents``.
    :return: the model's text response.
    :raises RuntimeError: if the shared ``client`` has not been configured yet.
    :raises ValueError: if the response carries no text (``response.text`` is
        None, e.g. an empty or blocked response). Raising lets get_msg's retry
        logic try again instead of passing ``None`` to code that expects a str.
    """
    global _last_api_call_time

    if client is None:
        raise RuntimeError("Gemini client is not initialized. Run the configuration cell first.")

    # Keep at least API_DELAY_SECONDS between calls. A monotonic clock is used
    # so wall-clock adjustments cannot produce wrong (negative/huge) waits.
    elapsed = time.monotonic() - _last_api_call_time
    if elapsed < API_DELAY_SECONDS:
        time.sleep(API_DELAY_SECONDS - elapsed)

    _last_api_call_time = time.monotonic()

    response = client.models.generate_content(
        model=GEMINI_MODEL,
        contents=content,
        # To pin the sampling temperature, add:
        # config=genai.types.GenerateContentConfig(temperature=0),
    )
    text = response.text
    if text is None:
        raise ValueError("Gemini returned a response without text content.")
    return text


def extract_retry_delay(error_msg: str) -> float:
    """Extract retry delay from error message if available."""
    # Try to extract retry delay from error message
    # Format: "Please retry in X.XXXXXXs"
    match = re.search(r'retry in ([\d.]+)s', error_msg, re.IGNORECASE)
    if match:
        try:
            return float(match.group(1)) + 1.0  # Add 1 second buffer
        except:
            pass
    return None


def get_msg(content: str, max_retries: int = 3) -> str:
    """Call Gemini with simple retry handling.

    Rate-limit / quota errors (message contains "429", "RESOURCE_EXHAUSTED" or
    "quota") are retried after the server-suggested delay when one can be
    parsed. Without such a hint they are re-raised immediately so the calling
    cell can save its checkpoint and stop. Other errors are retried with
    exponential backoff (1 s, 2 s, ...).

    :param content: prompt text.
    :param max_retries: total number of attempts (at least one is made).
    :return: the model's text response.
    :raises Exception: the last error once all attempts are used up, or a
        quota error that carries no retry hint.
    """
    attempts = max(1, max_retries)
    for attempt in range(attempts):
        is_last_attempt = attempt == attempts - 1
        try:
            return call_api(content)
        except Exception as e:
            error_msg = str(e)
            if "429" in error_msg or "RESOURCE_EXHAUSTED" in error_msg or "quota" in error_msg.lower():
                retry_delay = extract_retry_delay(error_msg)
                if retry_delay is None:
                    # Nothing to wait for: let the caller checkpoint and stop.
                    print("\n⚠️ Quota exceeded and no retry delay was suggested. Please try again later.")
                    raise
                if not is_last_attempt:
                    print(f"\n⚠️ Rate limit hit. Waiting {retry_delay:.1f} seconds before retry...")
                    time.sleep(retry_delay)
                    continue
                # Last attempt: do not sleep just to raise afterwards.
            print(f"Attempt {attempt + 1} failed: {e}")
            if is_last_attempt:
                raise
            time.sleep(2 ** attempt)


def simplify_note(note):
    """Ask Gemini to condense a discharge-summary chunk into five fixed sections.

    The prompt spells out an output template (HISTORY OF PRESENT ILLNESS,
    PAST MEDICAL HISTORY, ALLERGIES, MEDICATIONS ON ADMISSION, DISCHARGE
    MEDICATIONS) and then appends the note text itself.

    :param note: discharge-summary text (or one chunk of it), str.
    :return: the model's summary, as returned by get_msg.
    """
    template_parts = (
        'Please summarize specific sections from a patient\'s discharge summary: 1. HISTORY OF PRESENT ILLNESS, 2. PAST MEDICAL HISTORY, 3. ALLERGIES, 4. MEDICATIONS ON ADMISSION 5.DISCHARGE MEDICATIONS. Ignore other details while in hospital and focus only on these sections.\n',
        'output template:\n',
        'HISTORY OF PRESENT ILLNESS:\n',
        '(Language summary as short as possible)\n',
        'PAST MEDICAL HISTORY:\n',
        '(Language summary as short as possible)\n',
        'ALLERGIES:\n',
        '(A series of allergies names, separated by commas, does not require any other information)\n',
        'MEDICATIONS ON ADMISSION:\n',
        '(A series of drug names, separated by commas, remove dosage information. Maybe None.)\n',
        'DISCHARGE MEDICATIONS:\n',
        '(A series of drug names, separated by commas, remove dosage information. Maybe None.)\n',
        'Note:',
    )
    prompt = ''.join(template_parts) + note + '\n' + 'Summarize result in five aspects in a concise paragraph without any other words:\n'
    return get_msg(prompt)


def split_string(s, splitted_num):
    """Split ``s`` into about ``splitted_num`` consecutive pieces at text boundaries.

    Target cut points are spaced evenly through ``s``. From each target, the
    first following '.' or '\\n' is used, unless a blank line ('\\n\\n')
    starts within 200 characters of it, in which case the paragraph break is
    preferred. Each piece includes its boundary character, and the pieces are
    contiguous, so ``''.join(result) == s``.

    :param s: text to split.
    :param splitted_num: desired number of pieces; fewer are returned when the
        text runs out of boundaries.
    :return: list of non-empty pieces (``[s]`` when no split is possible or
        ``s`` is empty).
    """
    if splitted_num <= 1 or not s:
        return [s]

    def _first_hit(*positions):
        # str.find() reports "not found" as -1, which must never win a min().
        hits = [p for p in positions if p != -1]
        return min(hits) if hits else -1

    split_indices = [i * len(s) // splitted_num for i in range(1, splitted_num)]

    result = []
    start = 0
    for index in split_indices:
        # Never search behind the previous cut; otherwise pieces would overlap
        # and text would be duplicated.
        search_from = max(index, start)
        end_0 = _first_hit(s.find('.', search_from), s.find('\n', search_from))
        end_new = s.find('\n\n', search_from)
        if end_new != -1 and (end_0 == -1 or abs(end_new - end_0) < 200):
            end = end_new
        else:
            end = end_0
        if end == -1:
            # No boundary after this point: the remainder becomes the last piece.
            break
        result.append(s[start:end + 1])
        start = end + 1

    if start < len(s):
        result.append(s[start:])
    return result


def devide_list(origin_text_list, max_chars=15000):
    """Split texts until each piece fits Gemini's context budget.

    Character count is used as a proxy for tokens (1 token ≈ 4 characters),
    so the default of 15000 characters is roughly 3.5k-4k tokens. Texts longer
    than ``max_chars`` are cut with ``split_string``, repeatedly, until the
    list stops changing.

    :param origin_text_list: list of texts to split.
    :param max_chars: maximum characters per piece. This is best effort: a text
        without usable boundaries is kept whole even if it is longer.
    :return: new list of pieces, in the original order.
    """
    current = list(origin_text_list)
    while True:
        new_text_list = []
        for text in current:
            if len(text) <= max_chars:
                new_text_list.append(text)
                continue
            splitted_num = len(text) // max_chars + 1
            # Empty fragments carry no text and would be sent as empty prompts.
            pieces = [piece for piece in split_string(text, splitted_num) if piece]
            if len(pieces) > 1 and all(len(piece) < len(text) for piece in pieces):
                new_text_list.extend(pieces)
            else:
                # No real progress (no boundary found, or the whole text came
                # back): keep it as is so the loop reaches a fixed point.
                new_text_list.append(text)
        if new_text_list == current:
            return new_text_list
        current = new_text_list


def check_note(note):
    """Return True if a simplified note has all five section headers, in order.

    Matching is case-insensitive and uses the first occurrence of each header.
    """
    upper_note = note.upper()
    headers = (
        'HISTORY OF PRESENT ILLNESS',
        'PAST MEDICAL HISTORY',
        'ALLERGIES',
        'MEDICATIONS ON ADMISSION',
        'DISCHARGE MEDICATIONS',
    )
    positions = [upper_note.find(header) for header in headers]
    if -1 in positions:
        return False
    return all(earlier <= later for earlier, later in zip(positions, positions[1:]))

def generate_note(row, max_attempts=10):
    """Summarize one admission's discharge summary with Gemini.

    The note text is split into context-sized chunks with ``devide_list``, and
    each chunk is summarized by ``simplify_note``. A chunk's summary is
    re-requested until it passes ``check_note`` or ``max_attempts`` calls have
    been made; the last summary is kept either way.

    :param row: a noteevents row exposing ``HADM_ID`` and ``TEXT``.
    :param max_attempts: maximum model calls per chunk.
    :return: ``(hadm_id, [summary_of_chunk_1, ...])``.
    """
    hadm_id = row['HADM_ID']
    chunks = devide_list([row['TEXT']])
    return hadm_id, [_simplify_until_valid(chunk, max_attempts) for chunk in chunks]


def _simplify_until_valid(text, max_attempts=10):
    """Return the first simplify_note() result passing check_note(), else the last one."""
    note = None
    # Exactly one API call per attempt; the previous version made a second,
    # immediately-discarded call after every failed check.
    for _ in range(max(1, max_attempts)):
        note = simplify_note(text)
        if check_note(note):
            break
    return note

In [None]:
# ===== PREVIEW MODE CONFIGURATION =====
# Cap on how many rows / admissions the processing cells handle, for quick
# previews. Set it to None to process everything.
PREVIEW_LIMIT = 250
# ======================================


In [None]:
# Cấu hình Gemini API bằng Google GenAI SDK mới
from google import genai
import os

# The API key comes from the GEMINI_API_KEY environment variable (preferred);
# if it is not set, the placeholder below is used and should be replaced with
# a real key.
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY", "YOUR_GEMINI_API_KEY_HERE")

# Shared client used by call_api() in the utils cell.
client = genai.Client(api_key=GEMINI_API_KEY)


# --- Pipeline inputs ---
noteevents_path = 'data_process/input/mimic-iii/NOTEEVENTS.csv'
data4LLM_path = 'data_process/output/mimic-iii/data4LLM.csv'

# --- Intermediate pickles (the LLM steps also use them as resume checkpoints) ---
filter_noteevents_path = 'data_process/output/mimic-iii/note/noteevents_filtered.pkl'
simplified_note_path = 'data_process/output/mimic-iii/note/note_simp.pkl'
note_p1_path = 'data_process/output/mimic-iii/note/history_of_present_illness.pkl'
note_p2_path = 'data_process/output/mimic-iii/note/past_medical_history.pkl'
note_p3_path = 'data_process/output/mimic-iii/note/allergies.pkl'
note_p4_path = 'data_process/output/mimic-iii/note/med_on_adm.pkl'
note_content_path = 'data_process/output/mimic-iii/note/note_content.pkl'

# --- Final output ---
data4LLM_with_note_path = 'data_process/output/mimic-iii/data4LLM_with_note.csv'

## generate noteevents_filtered.pkl

In [None]:
import pandas as pd
import dill

# Load the raw note table and the cohort table.
noteevents = pd.read_csv(noteevents_path)
data4LLM = pd.read_csv(data4LLM_path)

# show the statistics of the data and noteevents
print('data4LLM shape:', data4LLM.shape)
print('noteevents shape before filtering:', noteevents.shape)

# Keep only the discharge-summary reports of admissions present in data4LLM.
# Filtering first avoids sorting the whole NOTEEVENTS table. pandas'
# multi-column sort is stable, so the final row order (which decides which
# duplicate note is kept later) matches sorting before and after filtering.
keep = (
    noteevents['HADM_ID'].isin(data4LLM['HADM_ID'])
    & (noteevents['CATEGORY'] == 'Discharge summary')
    & (noteevents['DESCRIPTION'] == 'Report')
)
noteevents = noteevents[keep]
noteevents = noteevents.sort_values(by=['SUBJECT_ID', 'HADM_ID', 'CHARTDATE', 'CHARTTIME'], kind='stable')
noteevents = noteevents.reset_index(drop=True)
print('noteevents shape after filtering:', noteevents.shape)

In [None]:
def check_hadm_id_in_noteevents(noteevents, data4LLM):
    """Print how many (and which) HADM_IDs of data4LLM are absent from noteevents.

    :param noteevents: DataFrame with a HADM_ID column.
    :param data4LLM: DataFrame with a HADM_ID column.
    :return: None (report only).
    """
    wanted = set(data4LLM['HADM_ID'].tolist())
    available = set(noteevents['HADM_ID'].tolist())
    missing = wanted - available
    print('num_hadm_id_not_in_noteevents:', len(missing))
    print('hadm_id in data4LLM but not in noteevents:', missing)

# check whether all hadm_id in data4LLM are in noteevents, i.e. report the
# admissions for which no discharge-summary report remains after filtering
check_hadm_id_in_noteevents(noteevents, data4LLM)

In [None]:
def check_hadm_id_appear_only_once(noteevents):
    """Print and return the HADM_IDs that occur on more than one row.

    :param noteevents: DataFrame with a HADM_ID column.
    :return: list of duplicated HADM_IDs, in ``value_counts`` order (most
        frequent first).
    """
    ids = noteevents['HADM_ID']
    counts = ids.value_counts()
    duplicated = counts[counts > 1].index.tolist()
    print('nunique_hadm_id:', ids.nunique(), '\t', len(duplicated), 'hadm_ids appear more than once', duplicated)
    return duplicated

# List the HADM_IDs that still have more than one discharge-summary report;
# the next cell keeps only the first report of each.
print('check whether all hadm_id appears only once in noteevents...')
hadm_id_appear_more_than_once = check_hadm_id_appear_only_once(noteevents)

In [None]:
# for those hadm_id that appear more than once, we keep the first appearance
# (first in the current SUBJECT_ID / HADM_ID / CHARTDATE / CHARTTIME order)
print('noteevents shape before filtering:', noteevents.shape)
# .copy() makes this an independent frame, so the column assignment below is
# not a write into a possible view of `noteevents` (avoids SettingWithCopyWarning).
noteevents_filtered = noteevents.drop_duplicates(subset=['HADM_ID'], keep='first').copy()
print('noteevents shape after filtering:', noteevents_filtered.shape)
_ = check_hadm_id_appear_only_once(noteevents_filtered)
# Convert HADM_ID to int (read_csv may have parsed it as float).
noteevents_filtered['HADM_ID'] = noteevents_filtered['HADM_ID'].astype(int)

# Make sure the output directory exists, and close the file deterministically.
os.makedirs(os.path.dirname(filter_noteevents_path), exist_ok=True)
with open(filter_noteevents_path, 'wb') as f:
    dill.dump(noteevents_filtered, f)

## split and simplify the note

In [None]:
# note_data = dill.load(open(filter_noteevents_path, 'rb'))

# # Apply PREVIEW_LIMIT if set
# if PREVIEW_LIMIT is not None:
#     note_data = note_data.head(PREVIEW_LIMIT)
#     print(f"⚠️ PREVIEW MODE: Processing only first {PREVIEW_LIMIT} rows")

# # Load checkpoint if exists
# if os.path.exists(simplified_note_path):
#     result_pd = dill.load(open(simplified_note_path, 'rb'))
#     processed_hadm_ids = set(result_pd['HADM_ID'].unique())
#     print(f"Resuming from checkpoint: {len(processed_hadm_ids)} HADM_IDs already processed")
# else:
#     result_pd = pd.DataFrame(columns=['HADM_ID', 'NOTE'])
#     processed_hadm_ids = set()

# try:
#     for idx, row in tqdm(note_data.iterrows(), total=len(note_data)):
#         hadm_id = row['HADM_ID']

#         # Skip if already processed
#         if hadm_id in processed_hadm_ids:
#             continue

#         hadm_id, note_list = generate_note(row)

#         for note in note_list:
#             result_pd.loc[len(result_pd)] = [hadm_id, note]

#         # Save checkpoint after each HADM_ID
#         dill.dump(result_pd, open(simplified_note_path, 'wb'))
#         processed_hadm_ids.add(hadm_id)

# except Exception as e:
#     print(f"\n❌ Error occurred: {e}")
#     print(f"✅ Progress saved. Processed {len(processed_hadm_ids)} HADM_IDs so far.")
#     dill.dump(result_pd, open(simplified_note_path, 'wb'))
#     raise

# print(f"✅ Complete! Total {len(processed_hadm_ids)} HADM_IDs processed.")

## Manually modify notes that still do not meet the format requirements

In [None]:
# process_note = dill.load(open(simplified_note_path, 'rb'))

# for i, row in process_note.iterrows():
#     if not check_note(row['NOTE']):
#         print(row['NOTE'])
#         print('*********************************************')
#         process_note.at[i, 'NOTE'] = input('Please input the correct note: ')

# dill.dump(process_note, open(simplified_note_path, 'wb'))

## 1 history of present illness

In [None]:
# final_note = dill.load(open(simplified_note_path, 'rb'))

# # Apply PREVIEW_LIMIT if set
# if PREVIEW_LIMIT is not None:
#     # Limit to unique HADM_IDs up to PREVIEW_LIMIT
#     unique_hadm_ids = final_note['HADM_ID'].unique()[:PREVIEW_LIMIT]
#     final_note = final_note[final_note['HADM_ID'].isin(unique_hadm_ids)]
#     print(f"⚠️ PREVIEW MODE: Processing only first {len(unique_hadm_ids)} unique HADM_IDs")

# idx = 0

# result_pd_1 = pd.DataFrame(columns=['HADM_ID', 'CON_NOTE'])

# while idx < len(final_note):
#     hadm_id = final_note.loc[idx].HADM_ID
#     visit_note = ' '
#     note_part = final_note.loc[idx].NOTE
#     start = note_part.upper().find('HISTORY OF PRESENT ILLNESS')
#     end = note_part.upper().find('PAST MEDICAL HISTORY')
#     visit_note += note_part[start + len('HISTORY OF PRESENT ILLNESS:'):end].strip()
#     while idx + 1 < len(final_note) and final_note.loc[idx + 1].HADM_ID == hadm_id:
#         idx += 1
#         visit_note += '  +  '
#         note_part = final_note.loc[idx].NOTE
#         start = note_part.upper().find('HISTORY OF PRESENT ILLNESS')
#         end = note_part.upper().find('PAST MEDICAL HISTORY')
#         visit_note += note_part[start + len('HISTORY OF PRESENT ILLNESS:'):end].strip()

#     idx += 1

#     result_pd_1.loc[len(result_pd_1)] = [hadm_id, visit_note]

# # Load checkpoint if exists
# if os.path.exists(note_p1_path):
#     history_of_present_illness_pd = dill.load(open(note_p1_path, 'rb'))
#     processed_hadm_ids = set(history_of_present_illness_pd['HADM_ID'].tolist())
#     start_idx = len(history_of_present_illness_pd)
#     print(f"Resuming from checkpoint: {start_idx} records already processed")
# else:
#     history_of_present_illness_pd = pd.DataFrame(columns=['HADM_ID', 'HISTORY OF PRESENT ILLNESS'])
#     processed_hadm_ids = set()
#     start_idx = 0

# try:
#     # LIMIT TO PREVIEW_LIMIT FOR QUICK PREVIEW
#     end_idx = min(len(result_pd_1), PREVIEW_LIMIT) if PREVIEW_LIMIT is not None else len(result_pd_1)
#     for idx in trange(start_idx, end_idx):
#         hadm_id = result_pd_1.loc[idx].HADM_ID

#         # Skip if already processed
#         if hadm_id in processed_hadm_ids:
#             continue

#         note = result_pd_1.loc[idx].CON_NOTE
#         prompt = '''
# I'll provide you with an input containing the history of present illness for a patient. Your task is to:
# 1.Retain the descriptions of the patient's history of present illness before admission and on admission, while removing all descriptions after admission and at discharge.
# 2.Consolidate the text to produce a concise output.

# input: ''' + note + '''

# You only need to answer the refined results, no other explanation is needed!

# output:
# '''
#         result = get_msg(prompt)

#         history_of_present_illness_pd.loc[len(history_of_present_illness_pd)] = [hadm_id, result]
#         processed_hadm_ids.add(hadm_id)

#         # Save checkpoint after each record
#         dill.dump(history_of_present_illness_pd, open(note_p1_path, 'wb'))

# except Exception as e:
#     print(f"\n❌ Error occurred: {e}")
#     print(f"✅ Progress saved at index {len(history_of_present_illness_pd)}/{len(result_pd_1)}")
#     dill.dump(history_of_present_illness_pd, open(note_p1_path, 'wb'))
#     raise

# print(f"✅ Complete! Total {len(history_of_present_illness_pd)} records processed.")

## 2 past medical history

In [None]:
# idx = 0

# final_note = dill.load(open(simplified_note_path, 'rb'))

# # Apply PREVIEW_LIMIT if set
# if PREVIEW_LIMIT is not None:
#     # Limit to unique HADM_IDs up to PREVIEW_LIMIT
#     unique_hadm_ids = final_note['HADM_ID'].unique()[:PREVIEW_LIMIT]
#     final_note = final_note[final_note['HADM_ID'].isin(unique_hadm_ids)]
#     print(f"⚠️ PREVIEW MODE: Processing only first {len(unique_hadm_ids)} unique HADM_IDs")

# result_pd_2 = pd.DataFrame(columns=['HADM_ID', 'CON_NOTE'])

# while idx < len(final_note):
#     hadm_id = final_note.loc[idx].HADM_ID
#     visit_note = ' '
#     note_part = final_note.loc[idx].NOTE
#     start = note_part.upper().find('PAST MEDICAL HISTORY')
#     end = note_part.upper().find('ALLERGIES')
#     visit_note += note_part[start + len('PAST MEDICAL HISTORY:'):end].strip()
#     while idx + 1 < len(final_note) and final_note.loc[idx + 1].HADM_ID == hadm_id:
#         idx += 1
#         visit_note += '  +  '
#         note_part = final_note.loc[idx].NOTE
#         start = note_part.upper().find('PAST MEDICAL HISTORY')
#         end = note_part.upper().find('ALLERGIES')
#         visit_note += note_part[start + len('PAST MEDICAL HISTORY:'):end].strip()

#     idx += 1

#     result_pd_2.loc[len(result_pd_2)] = [hadm_id, visit_note]

# # Load checkpoint if exists
# if os.path.exists(note_p2_path):
#     past_medical_history_pd = dill.load(open(note_p2_path, 'rb'))
#     processed_hadm_ids = set(past_medical_history_pd['HADM_ID'].tolist())
#     start_idx = len(past_medical_history_pd)
#     print(f"Resuming from checkpoint: {start_idx} records already processed")
# else:
#     past_medical_history_pd = pd.DataFrame(columns=['HADM_ID', 'PAST MEDICAL HISTORY'])
#     processed_hadm_ids = set()
#     start_idx = 0

# try:
#     # LIMIT TO PREVIEW_LIMIT FOR QUICK PREVIEW
#     end_idx = min(len(result_pd_2), PREVIEW_LIMIT) if PREVIEW_LIMIT is not None else len(result_pd_2)
#     for idx in trange(start_idx, end_idx):
#         hadm_id = result_pd_2.loc[idx].HADM_ID

#         # Skip if already processed
#         if hadm_id in processed_hadm_ids:
#             continue

#         note = result_pd_2.loc[idx].CON_NOTE
#         prompt = '''
# I'll provide you with input containing a patient's past medical history. I need you to consolidate the text and output a concise summary.

# input: ''' + note + '''

# You only need to answer the refined results, no other explanation is needed!

# output:
# '''
#         result = get_msg(prompt)

#         past_medical_history_pd.loc[len(past_medical_history_pd)] = [hadm_id, result]
#         processed_hadm_ids.add(hadm_id)

#         # Save checkpoint after each record
#         dill.dump(past_medical_history_pd, open(note_p2_path, 'wb'))

# except Exception as e:
#     print(f"\n❌ Error occurred: {e}")
#     print(f"✅ Progress saved at index {len(past_medical_history_pd)}/{len(result_pd_2)}")
#     dill.dump(past_medical_history_pd, open(note_p2_path, 'wb'))
#     raise

# print(f"✅ Complete! Total {len(past_medical_history_pd)} records processed.")

## 3 allergies

In [None]:
# idx = 0

# final_note = dill.load(open(simplified_note_path, 'rb'))

# # Apply PREVIEW_LIMIT if set
# if PREVIEW_LIMIT is not None:
#     # Limit to unique HADM_IDs up to PREVIEW_LIMIT
#     unique_hadm_ids = final_note['HADM_ID'].unique()[:PREVIEW_LIMIT]
#     final_note = final_note[final_note['HADM_ID'].isin(unique_hadm_ids)]
#     print(f"⚠️ PREVIEW MODE: Processing only first {len(unique_hadm_ids)} unique HADM_IDs")

# result_pd_3 = pd.DataFrame(columns=['HADM_ID', 'CON_NOTE'])

# while idx < len(final_note):
#     hadm_id = final_note.loc[idx].HADM_ID
#     visit_note = ' '
#     note_part = final_note.loc[idx].NOTE
#     start = note_part.upper().find('ALLERGIES')
#     end = note_part.upper().find('MEDICATIONS ON ADMISSION')
#     visit_note += note_part[start + len('ALLERGIES:'):end].strip()
#     while idx + 1 < len(final_note) and final_note.loc[idx + 1].HADM_ID == hadm_id:
#         idx += 1
#         visit_note += '  +  '
#         note_part = final_note.loc[idx].NOTE
#         start = note_part.upper().find('ALLERGIES')
#         end = note_part.upper().find('MEDICATIONS ON ADMISSION')
#         visit_note += note_part[start + len('ALLERGIES:'):end].strip()
#     result_pd_3.loc[len(result_pd_3)] = [hadm_id, visit_note]

#     idx += 1

# # Load checkpoint if exists
# if os.path.exists(note_p3_path):
#     allergies_pd = dill.load(open(note_p3_path, 'rb'))
#     processed_hadm_ids = set(allergies_pd['HADM_ID'].tolist())
#     start_idx = len(allergies_pd)
#     print(f"Resuming from checkpoint: {start_idx} records already processed")
# else:
#     allergies_pd = pd.DataFrame(columns=['HADM_ID', 'ALLERGIES'])
#     processed_hadm_ids = set()
#     start_idx = 0

# try:
#     # LIMIT TO PREVIEW_LIMIT FOR QUICK PREVIEW
#     end_idx = min(len(result_pd_3), PREVIEW_LIMIT) if PREVIEW_LIMIT is not None else len(result_pd_3)
#     for idx in trange(start_idx, end_idx):
#         hadm_id = result_pd_3.loc[idx].HADM_ID

#         # Skip if already processed
#         if hadm_id in processed_hadm_ids:
#             continue

#         note = result_pd_3.loc[idx].CON_NOTE
#         prompt = '''
# I'm going to give you an input, which is a bunch of text and some plus signs. I need you to extract all the drug names for me from each input, and output the corresponding list.

# Here are some of the input and output sample:

# input1:No Known Allergies to Drugs.  +  None mentioned.

# output1:[]

# input2:None mentioned.  +  The patient is allergic to cefazolin and penicillins.

# output2:[cefazolin, penicillins]

# Now you need to provide the corresponding output of input3, without any other words:

# input3:''' + note + '''

# You only need to output a list!

# output3:
# '''
#         result = get_msg(prompt)

#         allergies_pd.loc[len(allergies_pd)] = [hadm_id, result]
#         processed_hadm_ids.add(hadm_id)

#         # Save checkpoint after each record
#         dill.dump(allergies_pd, open(note_p3_path, 'wb'))

# except Exception as e:
#     print(f"\n❌ Error occurred: {e}")
#     print(f"✅ Progress saved at index {len(allergies_pd)}/{len(result_pd_3)}")
#     dill.dump(allergies_pd, open(note_p3_path, 'wb'))
#     raise

# print(f"✅ Complete! Total {len(allergies_pd)} records processed.")

##  4 med_on_adm

In [None]:
def _extract_section(note_text, start_header, end_header):
    """Return the text of one section of a simplified note.

    The section starts right after ``start_header`` plus one separator
    character (the ':' of the output template) and ends at the next
    ``end_header`` that follows it. Header matching is case-insensitive.

    :param note_text: one simplified note.
    :param start_header: header opening the section, e.g. 'MEDICATIONS ON ADMISSION'.
    :param end_header: header closing it, e.g. 'DISCHARGE MEDICATIONS'.
    :return: the stripped section text. Returns '' if ``note_text`` is not a
        str or lacks ``start_header``. The section runs to the end of the note
        when no ``end_header`` follows (previously a missing header produced a
        slice from index -1, i.e. garbage text).
    """
    if not isinstance(note_text, str):
        return ''
    upper_text = note_text.upper()
    start = upper_text.find(start_header)
    if start == -1:
        return ''
    body_start = start + len(start_header) + 1  # skip the header and its ':'
    end = upper_text.find(end_header, body_start)
    if end == -1:
        end = len(note_text)
    return note_text[body_start:end].strip()


def _save_checkpoint(df, path):
    """Pickle ``df`` to ``path`` via a temporary file and ``os.replace``.

    An interruption mid-write leaves the previous checkpoint intact instead of
    a truncated pickle that would break the next resume.
    """
    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    tmp_path = path + '.tmp'
    with open(tmp_path, 'wb') as f:
        dill.dump(df, f)
    os.replace(tmp_path, path)


with open(simplified_note_path, 'rb') as f:
    final_note = dill.load(f)

# Apply PREVIEW_LIMIT if set
if PREVIEW_LIMIT is not None:
    # Limit to unique HADM_IDs up to PREVIEW_LIMIT
    unique_hadm_ids = final_note['HADM_ID'].unique()[:PREVIEW_LIMIT]
    final_note = final_note[final_note['HADM_ID'].isin(unique_hadm_ids)]
    print(f"⚠️ PREVIEW MODE: Processing only first {len(unique_hadm_ids)} unique HADM_IDs")

# One row per admission: the MEDICATIONS ON ADMISSION section of each of its
# chunk summaries, prefixed with ' ' and joined with '  +  ' (the format the
# prompt examples below show). groupby(sort=False) keeps first-appearance
# order and does not rely on final_note having a 0..n-1 index.
con_note_rows = []
for hadm_id, group in final_note.groupby('HADM_ID', sort=False):
    sections = [
        _extract_section(text, 'MEDICATIONS ON ADMISSION', 'DISCHARGE MEDICATIONS')
        for text in group['NOTE']
    ]
    con_note_rows.append([hadm_id, ' ' + '  +  '.join(sections)])
result_pd_4 = pd.DataFrame(con_note_rows, columns=['HADM_ID', 'CON_NOTE'])

# Load checkpoint if exists
if os.path.exists(note_p4_path):
    with open(note_p4_path, 'rb') as f:
        med_on_adm_pd = dill.load(f)
    processed_hadm_ids = set(med_on_adm_pd['HADM_ID'].tolist())
    print(f"Resuming from checkpoint: {len(med_on_adm_pd)} records already processed")
else:
    med_on_adm_pd = pd.DataFrame(columns=['HADM_ID', 'MEDICATIONS ON ADMISSION'])
    processed_hadm_ids = set()

try:
    # LIMIT TO PREVIEW_LIMIT FOR QUICK PREVIEW
    end_idx = min(len(result_pd_4), PREVIEW_LIMIT) if PREVIEW_LIMIT is not None else len(result_pd_4)
    # Visit every admission and skip those already in the checkpoint, rather
    # than assuming the checkpoint length equals a row offset.
    for idx in trange(end_idx):
        hadm_id = result_pd_4.at[idx, 'HADM_ID']

        # Skip if already processed
        if hadm_id in processed_hadm_ids:
            continue

        note = result_pd_4.at[idx, 'CON_NOTE']
        prompt = '''
I'm going to give you an input, which is a bunch of text and some plus signs. I need you to extract all the drug names for me from each input, and output the corresponding list.

Here are some of the input and output sample:

input1:None.  +   Nifedipine XL, Calcitriol, Lisinopril, Aspirin, Lasix, Glyburide, Clonidine, Zoloft, Simvastatin, Tums, Procrit, Lupron, Niferex.

output1:[Nifedipine XL, Calcitriol, Lisinopril, Aspirin, Lasix, Glyburide, Clonidine, Zoloft, Simvastatin, Tums, Procrit, Lupron, Niferex]

input2: The patient was taking Aspirin, Atovaquone, Levofloxacin  +  The patient was on multiple medications including Emtriva, Lisinoprol, Metoprolol, Stavudine.

output2:[Aspirin, Atovaquone, Levofloxacin, Emtriva, Lisinoprol, Metoprolol, Stavudine]

Now you need to provide the corresponding output of input3, without any other words:

input3:''' + note + '''

You only need to output a list!

output3:
'''
        result = get_msg(prompt)

        med_on_adm_pd.loc[len(med_on_adm_pd)] = [hadm_id, result]
        processed_hadm_ids.add(hadm_id)

        # Save checkpoint after each record
        _save_checkpoint(med_on_adm_pd, note_p4_path)

except Exception as e:
    print(f"\n❌ Error occurred: {e}")
    _save_checkpoint(med_on_adm_pd, note_p4_path)
    print(f"✅ Progress saved at index {len(med_on_adm_pd)}/{len(result_pd_4)}")
    raise

print(f"✅ Complete! Total {len(med_on_adm_pd)} records processed.")

## combine

In [None]:
def _load_pickle(path):
    """Load a dill pickle from ``path``, closing the file afterwards."""
    with open(path, 'rb') as f:
        return dill.load(f)


def _first_value_by_hadm(df):
    """Map HADM_ID -> the frame's second column, keeping the first row per HADM_ID."""
    first_rows = df.drop_duplicates(subset=['HADM_ID'], keep='first')
    return dict(zip(first_rows['HADM_ID'], first_rows.iloc[:, 1]))


history_of_present_illness_pd = _load_pickle(note_p1_path)
past_medical_history_pd = _load_pickle(note_p2_path)
allergies_pd = _load_pickle(note_p3_path)
med_on_adm_pd = _load_pickle(note_p4_path)

# Dict lookups instead of filtering each section DataFrame once per admission
# (which was O(n^2)).
past_medical_history_by_id = _first_value_by_hadm(past_medical_history_pd)
allergies_by_id = _first_value_by_hadm(allergies_pd)
med_on_adm_by_id = _first_value_by_hadm(med_on_adm_pd)

rows_5 = []
missing_hadm_ids = []
for hadm_id, history_note in zip(history_of_present_illness_pd['HADM_ID'],
                                 history_of_present_illness_pd.iloc[:, 1]):
    try:
        past_medical_history_note = past_medical_history_by_id[hadm_id]
        allergies_note = allergies_by_id[hadm_id]
        med_on_adm_note = med_on_adm_by_id[hadm_id]
    except KeyError:
        # Some section checkpoint does not cover this admission yet (e.g. that
        # step was interrupted). Skip it instead of crashing with IndexError.
        missing_hadm_ids.append(hadm_id)
        continue

    note_content = 'History of present illness: ' + history_note + ',\nPast medical history: ' + past_medical_history_note + ',\nAllergies: ' + allergies_note + ',\nMedications on admission: ' + med_on_adm_note
    rows_5.append([hadm_id, note_content])

if missing_hadm_ids:
    print(f"⚠️ Skipped {len(missing_hadm_ids)} HADM_IDs missing from at least one section checkpoint: {missing_hadm_ids}")

result_pd_5 = pd.DataFrame(rows_5, columns=['HADM_ID', 'NOTE_CONTENT'])

with open(note_content_path, 'wb') as f:
    dill.dump(result_pd_5, f)

## generate data4LLM_with_note.csv

In [None]:
original_data = pd.read_csv(data4LLM_path)
with open(note_content_path, 'rb') as f:
    note_content = dill.load(f)

# Keep each row of original_data whose HADM_ID has a generated note, and
# append that note (the first one per HADM_ID) as a new last column 'NOTE'.
# A vectorized selection replaces the per-row iterrows() scan, which was
# O(rows x notes) and can upcast ints to floats in all-numeric rows.
first_notes = note_content.drop_duplicates(subset=['HADM_ID'], keep='first')
note_by_hadm = dict(zip(first_notes['HADM_ID'], first_notes.iloc[:, 1]))

has_note = original_data['HADM_ID'].isin(list(note_by_hadm))
data4LLM_with_note = original_data.loc[has_note].copy()
data4LLM_with_note['NOTE'] = [note_by_hadm[hadm_id] for hadm_id in data4LLM_with_note['HADM_ID']]

data4LLM_with_note.to_csv(data4LLM_with_note_path, index=False)


