# Generate Discharge Summary Sections with the Trained Checkpoints

This notebook loads the models that e-Health CSIRO trained for the ["Discharge Me!" Shared Task](https://stanford-aimi.github.io/discharge-me/) and shows how to run inference on samples from the validation set. The inference setup is the same as the one used for the final shared task submission.

The checkpoints are also available in the [repo](https://github.com/JHLiu7/bionlp24-shared-task-discharge-me).

In [1]:
import datasets
import evaluate
import numpy as np
import pandas as pd
from tqdm import tqdm
from datasets import load_dataset

import transformers
import logging
import re
import sys
import torch

from transformers import (
    AutoModelForSeq2SeqLM,
    AutoModelForCausalLM,
    AutoTokenizer, 
    pipeline
)

from peft import LoraConfig, get_peft_model, PeftModelForCausalLM

## 1. Prepare data
We use the validation set to illustrate how we prepared and used the data. Our models use the content preceding each target section (`discharge instructions` (`DI`) and `brief hospital course` (`BHC`)) as the input context for generation. To prepare the data, we:

1. Follow the official repo, which uses an updated method to extract `BHC` (`DI` is taken unchanged from the released files as of v1.3);
2. Extract the content preceding `BHC` and `DI` and use it as the source context for the corresponding target section. For `DI`, the context is the span from `BHC` (inclusive) up to `DI`.
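To make the slicing concrete, here is a simplified toy sketch. The function `slice_contexts` and the note text are hypothetical; the actual `process_hadm` code further down also normalizes section headers and appends radiology reports.

```python
def slice_contexts(note: str, bhc: str, di: str) -> dict:
    """Toy version of the context slicing (hypothetical helper): each target
    section is paired with the note text that precedes it."""
    # Everything before the BHC section becomes the BHC source context.
    ctx_bhc = note[: note.index(bhc)]
    # For DI, keep the span from the start of BHC (inclusive) up to DI.
    ctx_di = note[note.index(bhc): note.index(di)]
    return {"source-bhc": ctx_bhc, "source-di": ctx_di}


# Made-up note; real notes come from discharge.csv.gz.
note = (
    "Chief Complaint: leg pain\n"
    "Brief Hospital Course: Underwent ORIF, recovered well.\n"
    "Discharge Medications: acetaminophen\n"
    "Discharge Instructions: Keep the splint dry.\n"
)
parts = slice_contexts(
    note,
    bhc="Brief Hospital Course: Underwent ORIF, recovered well.",
    di="Discharge Instructions: Keep the splint dry.",
)
print(repr(parts["source-bhc"]))
print(repr(parts["source-di"]))
```

The `DI` context therefore contains the `BHC` section and anything between it and `DI`, but not the earlier parts of the note.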

In [2]:
fold = 'valid'
df_val_text = pd.read_csv(f'./physionet.org/files/discharge-me/1.3/{fold}/discharge.csv.gz').sort_values(by=['hadm_id'])
df_val_target = pd.read_csv(f'./physionet.org/files/discharge-me/1.3/{fold}/discharge_target.csv.gz').sort_values(by=['hadm_id'])
df_val_rad = pd.read_csv(f'./physionet.org/files/discharge-me/1.3/{fold}/radiology.csv.gz').sort_values(by=['hadm_id'])


print(df_val_text.hadm_id.nunique())
print(df_val_target.hadm_id.nunique())
print(df_val_rad.hadm_id.nunique())

14719
14719
14719


In [3]:
# New method to extract bhc
# from https://colab.research.google.com/drive/1yW-29KcDYoswMrqwjEMO6l6i2Ll3p5BX?usp=sharing
# see https://www.codabench.org/forums/1927/275/
from collections import OrderedDict
input_sections = OrderedDict([
    ('Brief Hospital Course', 'Brief Hospital Course'),
    ('Medications on Admission', '[A-Za-z_]+ on Admission'),
    ('Discharge Medications', '[A-Za-z_]+ Medications'),
    ('Discharge Disposition', '[A-Za-z_]+ Disposition'),
    ('Discharge Diagnosis', '[A-Za-z_]+ Diagnosis'),
    ('Discharge Condition', '[A-Za-z_]+ Condition')
])


def parse_brief_hospital_course(row):
    discharge_summary = row['text']
    section_name = 'Brief Hospital Course'
    section = input_sections.get(section_name)
    # find the first later section that actually follows BHC in this note
    for next_section in list(input_sections.values())[1:]:
        search = re.findall(section + ".+\n" + next_section, discharge_summary, re.DOTALL)
        if len(search) > 0:
            break
    # capture the text between "Brief Hospital Course:" and that next section header
    rex = r'(%s?):\s*\n{0,2}(.+?)\s*(\n\s*){1,10}(%s):\n' % (section, next_section)

    section_ext = re.findall(rex, discharge_summary, re.DOTALL)
    if len(section_ext) > 0:
        return section_ext[-1][1]  # group 2 is the section body
    else:
        return None  # BHC section not found

In [4]:
df_val_text['brief_hospital_course'] = df_val_text.apply(parse_brief_hospital_course, axis=1)
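
To check what the regex in `parse_brief_hospital_course` captures without loading MIMIC data, here is a standalone copy that takes a plain string instead of a DataFrame row. The name `extract_bhc` and the toy note are hypothetical; the patterns are copied from the cell above.

```python
import re
from collections import OrderedDict

# Same section patterns as input_sections in the notebook.
INPUT_SECTIONS = OrderedDict([
    ('Brief Hospital Course', 'Brief Hospital Course'),
    ('Medications on Admission', '[A-Za-z_]+ on Admission'),
    ('Discharge Medications', '[A-Za-z_]+ Medications'),
    ('Discharge Disposition', '[A-Za-z_]+ Disposition'),
    ('Discharge Diagnosis', '[A-Za-z_]+ Diagnosis'),
    ('Discharge Condition', '[A-Za-z_]+ Condition'),
])


def extract_bhc(text):
    """Hypothetical standalone copy of parse_brief_hospital_course."""
    section = INPUT_SECTIONS['Brief Hospital Course']
    # Pick the first later section that actually follows BHC in this note.
    for next_section in list(INPUT_SECTIONS.values())[1:]:
        if re.findall(section + ".+\n" + next_section, text, re.DOTALL):
            break
    rex = r'(%s?):\s*\n{0,2}(.+?)\s*(\n\s*){1,10}(%s):\n' % (section, next_section)
    matches = re.findall(rex, text, re.DOTALL)
    # Group 2 holds the section body; take the last match, as the notebook does.
    return matches[-1][1] if matches else None


# Made-up note for illustration.
toy_note = (
    "Chief Complaint:\nleg pain\n\n"
    "Brief Hospital Course:\nUnderwent ORIF.\nRecovered well.\n\n"
    "Medications on Admission:\nNone\n"
)
print(repr(extract_bhc(toy_note)))
```

The header itself and the trailing blank lines are excluded from the capture, which is why `process_hadm` later prepends `"Brief Hospital Course: "` to the extracted text.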


In [5]:
# slice notes and use prior ctx as input

def _query_hadm(hadm, df, col):
    row = df[df['hadm_id'] == hadm]
    assert len(row) == 1
    return row[col].iloc[0]

def _get_notes(hadm, df_text, df_target, df_rad):
    dnote = _query_hadm(hadm, df_text, 'text')
    bhc = _query_hadm(hadm, df_text, 'brief_hospital_course') # new extraction
    di = _query_hadm(hadm, df_target, 'discharge_instructions')
    return dnote, bhc, di

def process_hadm(hadm, df_text, df_target, df_rad):

    dnote = _query_hadm(hadm, df_text, 'text')
    di = _query_hadm(hadm, df_target, 'discharge_instructions')
    
    bhc = _query_hadm(hadm, df_text, 'brief_hospital_course') # new extraction

    # prepare dnote, bhc, di
    dnote = re.sub(r'Brief Hospital Course:\s*', 'Brief Hospital Course: ', dnote)
    dnote = re.sub(r'Discharge Instructions:\s*', 'Discharge Instructions: ', dnote)

    if "Brief Hospital Course:" not in bhc:
        bhc = "Brief Hospital Course: " + bhc
    else:
        bhc = re.sub(r'Brief Hospital Course:\s*', 'Brief Hospital Course: ', bhc)
    if "Discharge Instructions:" not in di:
        di = "Discharge Instructions: "+di
    else:
        di = re.sub(r'Discharge Instructions:\s*', 'Discharge Instructions: ', di)

    assert dnote.find(bhc) > 0, hadm
    assert dnote.find(di) > 0, hadm

    # slice dnote
    ctx_bhc, rgt_bhc = dnote.split(bhc)  # content before BHC is the BHC source
    ctx_di_long, _ = dnote.split(di)  # all content before DI (not used below)
    ctx_di_short, _ = dnote.replace(ctx_bhc, '').split(di)  # content from BHC (included) up to DI; used as the DI source

    # get rad reports
    rad = df_rad[df_rad['hadm_id'] == hadm].sort_values(by='charttime', ascending=False) # keep latest first
    ctx_rad = '\n'.join(rad['text'].tolist())

    # collect
    return {
        'source-bhc-dnote': ctx_bhc,
        'source-bhc-dnote_rad': '\n'.join([ctx_bhc, ctx_rad]),
        
        'source-di-dnote': ctx_di_short,
        'source-di-dnote_rad': '\n'.join([ctx_di_short, ctx_rad]),
        
        'source-rad': ctx_rad,
        
        'target-bhc': bhc,
        'target-di': di,

        'hadm_id': hadm
    }

def process_fold(df_text, df_target, df_rad):

    data = []

    assert len(df_text) == len(df_target)

    hadms = df_target.hadm_id.tolist()

    for hadm in tqdm(hadms):
        data.append(
            process_hadm(
                hadm=hadm, 
                df_target=df_target,
                df_text=df_text,
                df_rad=df_rad,
            )
        )
    
    return pd.DataFrame(data)
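
`process_hadm` rewrites the whitespace after the `Brief Hospital Course:` and `Discharge Instructions:` headers to a single space before calling `find`/`split`, because the targets are prefixed with `"<header>: "` and would otherwise not occur verbatim in the note. A toy sketch of why this matters (the helper `normalize` and the note are hypothetical):

```python
import re


def normalize(text: str, header: str) -> str:
    # Hypothetical helper: collapse any whitespace after "<header>:" into a
    # single space, mirroring the re.sub calls in process_hadm.
    return re.sub(re.escape(header) + r':\s*', header + ': ', text)


raw_note = (
    "History: fell at home\n\n"
    "Brief Hospital Course:\n\nUnderwent ORIF.\n\n"
    "Discharge Instructions:\nKeep the splint dry.\n"
)
bhc = "Brief Hospital Course: " + "Underwent ORIF."
di = "Discharge Instructions: " + "Keep the splint dry."

# Without normalization, the prefixed target does not appear in the note.
print(raw_note.find(bhc))  # -1

note = normalize(normalize(raw_note, "Brief Hospital Course"), "Discharge Instructions")
print(note.find(bhc) > 0, note.find(di) > 0)  # True True
```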

In [6]:
new_df_val = process_fold(df_val_text, df_val_target, df_val_rad)

100%|██████████| 14719/14719 [00:16<00:00, 895.29it/s]


In [7]:
val_dataset = datasets.Dataset.from_pandas(new_df_val)
val_dataset

Dataset({
    features: ['source-bhc-dnote', 'source-bhc-dnote_rad', 'source-di-dnote', 'source-di-dnote_rad', 'source-rad', 'target-bhc', 'target-di', 'hadm_id'],
    num_rows: 14719
})

## 2. Generate `Brief Hospital Course`

Here we load the trained models for generating `BHC`. As in our paper, we set the maximum length of the generated section to `1280` tokens; for PRIMERA, this is automatically capped at 1024 due to the model's limit.

In [8]:
SECTION = 'bhc'
MAX_LEN = 1280

text_column = f'source-{SECTION}-dnote'
summary_column = f'target-{SECTION}'

all_columns = val_dataset.column_names
columns_to_remove = [col for col in all_columns if col not in [text_column, summary_column]]

bhc_val_dataset = val_dataset \
        .select(range(100)) \
        .remove_columns(columns_to_remove) 

### 2.1. Using decoder-only model (Llama-3)
We load Llama-3 and merge our fine-tuned LoRA module for `BHC` into it.

In [9]:
model_id = 'meta-llama/Meta-Llama-3-8B'
cache_dir = f'/scratch3/liu217/Llama/{model_id}' # you may want to change this accordingly

tokenizer = AutoTokenizer.from_pretrained(
    model_id, cache_dir=cache_dir, padding_side='left'
)
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype = torch.bfloat16,
    device_map='auto',
    # attn_implementation="flash_attention_2", # you can choose to use flash attn 
    cache_dir=cache_dir
)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.



In [10]:
# load and merge lora
adapter_path = 'jhliu/DischargeGen-Llama3-lora-BHC/'

model = PeftModelForCausalLM.from_pretrained(model, adapter_path, adapter_name="main")
model = model.merge_and_unload()

pipe = pipeline('text-generation', tokenizer=tokenizer, model=model)
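
`merge_and_unload` folds the LoRA weights into the base model, so the pipeline runs an ordinary `transformers` model without a separate adapter path. Under the standard LoRA formulation, a single linear layer's merged weight is `W + (alpha / r) * B @ A`. The NumPy sketch below checks that the merged and unmerged forms give the same output; the sizes, `alpha`, and `r` are toy assumptions, not values read from the `jhliu/DischargeGen-Llama3-lora-*` adapters, and this is not PEFT's internal code.

```python
import numpy as np

rng = np.random.default_rng(0)
# Toy sizes (assumptions); real values come from the adapter's LoRA config.
d_in, d_out, r, alpha = 16, 8, 4, 8

W = rng.normal(size=(d_out, d_in))  # frozen base weight
A = rng.normal(size=(r, d_in))      # LoRA down-projection
B = rng.normal(size=(d_out, r))     # LoRA up-projection
scaling = alpha / r
x = rng.normal(size=(d_in,))

# Unmerged: base path plus the low-rank adapter path.
y_adapter = W @ x + scaling * (B @ (A @ x))
# Merged: fold the low-rank update into the base weight once.
W_merged = W + scaling * (B @ A)
y_merged = W_merged @ x

print(np.allclose(y_adapter, y_merged))  # True
```

Merging trades the ability to swap adapters for a plain model, which is why the notebook reloads the base Llama-3 weights before merging the `DI` adapter in Section 3.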

We format the input with the prompt templates below and use the `transformers` pipeline for inference.

In [11]:
TEMPLATE_BHC = \
'''Summarize the below clinical text into a section of brief hospital course.
    
### Input:
{input_text}

### Summary:
'''

TEMPLATE_DI = \
'''Summarize the below clinical text into a section of discharge instruction.
    
### Input:
{input_text}

### Summary:
'''

def inference_dec(input_text):
    TEMPLATE = TEMPLATE_DI if SECTION == 'di' else TEMPLATE_BHC
    prompt = TEMPLATE.format(input_text=input_text)
    
    gen = pipe(prompt, max_new_tokens=MAX_LEN, do_sample=False)
    assert len(gen) == 1
    output = gen[0]['generated_text'][len(prompt):]
    return output
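
By default, the text-generation pipeline's `generated_text` contains the prompt followed by the continuation, so `inference_dec` slices off the first `len(prompt)` characters. A minimal sketch with a hypothetical stand-in for the pipeline (`fake_pipe`) and a shortened template (not the exact `TEMPLATE_BHC`):

```python
def fake_pipe(prompt, **kwargs):
    """Hypothetical stand-in for the text-generation pipeline: the output
    echoes the prompt before the generated continuation."""
    return [{"generated_text": prompt + "Brief Hospital Course: The patient ..."}]


# Shortened template for illustration only.
TEMPLATE = "Summarize the below clinical text.\n\n### Input:\n{input_text}\n\n### Summary:\n"


def generate(input_text, pipe=fake_pipe):
    prompt = TEMPLATE.format(input_text=input_text)
    gen = pipe(prompt, max_new_tokens=16, do_sample=False)
    # Keep only the continuation after the prompt.
    return gen[0]["generated_text"][len(prompt):]


print(generate("Chief Complaint: leg pain"))  # Brief Hospital Course: The patient ...
```

The real pipeline also accepts `return_full_text=False`, which returns only the new text; the notebook slices the string instead.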

In [12]:
print(inference_dec(bhc_val_dataset[text_column][20]))



Brief Hospital Course: The patient presented to the emergency department and was  evaluated by the orthopedic surgery team. The patient was found  to have a left femur fracture and left pilon fracture and was  admitted to the orthopedic surgery service. The patient was  taken to the operating room on ___ for left femur retrograde  nail and left pilon external fixation, which the patient  tolerated well. For full details of the procedure please see the  separately dictated operative report. The patient was taken from  the OR to the PACU in stable condition and after satisfactory  recovery from anesthesia was transferred to the floor. The  patient was initially given IV fluids and IV pain medications,  and progressed to a regular diet and oral medications by POD#1.  The patient was given ___ antibiotics and  anticoagulation per routine. The patient's home medications were  continued throughout this hospitalization. The patient worked  with ___ who determined that discharge to rehab was a

### 2.2. Using encoder-decoder model (PRIMERA)
We load and use the fully fine-tuned PRIMERA for `BHC`.

In [13]:
model_id = 'jhliu/DischargeGen-PRIMERA-BHC'

tokenizer_s2s = AutoTokenizer.from_pretrained(model_id, cache_dir='cache')
model_s2s = AutoModelForSeq2SeqLM.from_pretrained(model_id, cache_dir='cache').eval().cuda()

In [14]:
def inference_s2s(input_text):
    model_inputs = tokenizer_s2s(input_text, max_length=4096, return_tensors='pt',
                             padding=False, truncation=True)

    input_ids = model_inputs['input_ids'].to('cuda')
    attention_mask = model_inputs['attention_mask'].to('cuda')

    out = model_s2s.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=MAX_LEN,
    )

    output = tokenizer_s2s.batch_decode(out, skip_special_tokens=True, clean_up_tokenization_spaces=True)
    assert len(output) == 1
    return output[0].replace('\n', ' ')

In [15]:
print(inference_s2s(bhc_val_dataset[text_column][20]))

Input ids are automatically padded from 445 to 512 to be a multiple of `config.attention_window`: 512


Brief Hospital Course: The patient presented to the emergency department and was evaluated by the orthopedic  surgery team. The patient was found to have a left femur fracture and was admitted to the  orthopedic surgery service. The orthopedics team was consulted for operative  repair. The medicine team was also consulted for management of his bipolar disorder and  anxiety. The patients home medications were continued throughout this hospitalization. The  patient was taken to the operating room on ___ for L femur retrograde nail, L pilon  ex-fix, which the patient tolerated well. For full details of the procedure please see  the separately dictated operative report. The left femoral nail was placed retrograde  nail, L pin on ex-fix. The lateral pin site was dry and the patient was placed in a  splint. The ___ hospital course was otherwise unremarkable.  At the time of discharge the patient's pain was well controlled with oral medications,  incisions were clean/dry/intact, and the LLE w
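
The `Input ids are automatically padded` message comes from the Longformer-style (LED) encoder that PRIMERA is built on, which pads inputs up to a multiple of `config.attention_window` (512 here). The rounding can be sketched as follows; `led_padded_length` is a hypothetical helper, and the three lengths are the ones reported in this notebook's logs.

```python
import math


def led_padded_length(n_tokens: int, attention_window: int = 512) -> int:
    # Hypothetical helper: round the input length up to the next
    # multiple of the attention window.
    return math.ceil(n_tokens / attention_window) * attention_window


for n in (445, 1264, 1609):
    print(n, "->", led_padded_length(n))
# 445 -> 512
# 1264 -> 1536
# 1609 -> 2048
```

Since `inference_s2s` truncates inputs to 4096 tokens, itself a multiple of 512, the padded length never exceeds 4096.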

## 3. Generate `Discharge Instruction`

Here we load the trained models for generating `DI`. As in our paper, we set the maximum length of the generated section to `512` tokens.

In [16]:
SECTION = 'di'
MAX_LEN = 512

text_column = f'source-{SECTION}-dnote'
summary_column = f'target-{SECTION}'

all_columns = val_dataset.column_names
columns_to_remove = [col for col in all_columns if col not in [text_column, summary_column]]

di_val_dataset = val_dataset \
        .select(range(100)) \
        .remove_columns(columns_to_remove) 


### 3.1. Using decoder-only model (Llama-3)

In [17]:
model_id = 'meta-llama/Meta-Llama-3-8B'
cache_dir = f'/scratch3/liu217/Llama/{model_id}' # you may want to change this accordingly

tokenizer = AutoTokenizer.from_pretrained(
    model_id, cache_dir=cache_dir, padding_side='left'
)
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype = torch.bfloat16,
    device_map='auto',
    # attn_implementation="flash_attention_2", # you can choose to use flash attn 
    cache_dir=cache_dir
)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.



In [18]:
# load and merge lora
adapter_path = 'jhliu/DischargeGen-Llama3-lora-DI/'

model = PeftModelForCausalLM.from_pretrained(model, adapter_path, adapter_name="main")
model = model.merge_and_unload()

pipe = pipeline('text-generation', tokenizer=tokenizer, model=model)

In [19]:
print(inference_dec(di_val_dataset[text_column][5]))

Discharge Instructions:
Dear Ms. ___,   You were hospitalized due to symptoms of visual deficit and  unsteady gait resulting from an ACUTE ISCHEMIC STROKE, a  condition where a blood vessel providing oxygen and nutrients to  the brain is blocked by a clot. The brain is the part of your  body that controls and directs all the other parts of your body,  so damage to the brain from being deprived of its blood supply  can result in a variety of symptoms.   Stroke can have many different causes, so we assessed you for  medical conditions that might raise your risk of having stroke.  In order to prevent future strokes, we plan to modify those risk  factors. Your risk factors are:   - atrial fibrillation - high blood pressure - high cholesterol  We are changing your medications as follows:   - continue your Coumadin as previously prescribed - take aspirin 81mg daily for one week, then stop  Please take your other medications as prescribed.   Please followup with Neurology and your primary car

### 3.2. Using encoder-decoder model (PRIMERA)

In [20]:
model_id = 'jhliu/DischargeGen-PRIMERA-DI'

tokenizer_s2s = AutoTokenizer.from_pretrained(model_id, cache_dir='cache')
model_s2s = AutoModelForSeq2SeqLM.from_pretrained(model_id, cache_dir='cache').eval().cuda()

In [21]:
print(inference_s2s(di_val_dataset[text_column][5]))


Input ids are automatically padded from 1264 to 1536 to be a multiple of `config.attention_window`: 512


Discharge Instructions: Dear Ms. ___,   You were hospitalized due to symptoms of vision  changes resulting from an ACUTE ISCHEMIC STROKE, a  condition where a blood vessel providing oxygen and  nutrients to the brain is blocked by a clot. The brain is  the part of your body that controls and directs all the  other parts of your life, so damage to thebrain from being  deprived of its blood supply can result in a variety  of symptoms.     Stroke can have many different causes, so we assessed  you for medical conditions that might raise your risk   of having stroke. In order to prevent future strokes,  we plan to modify those risk factors. Your risk factors  are:   Atrial fibrillation     We are changing your medications as follows:    - Please continue taking your Coumadin as prescribed.  Please have your INR checked on ___. If it is   below  therapeutic (___), please instruct your doctor to  stop your aspirin 81mg (for bridging). If not therapeutic,    Please take your other medications

## 4. Compare the two models

In [22]:
i=10

print("#"*50+ " Llama-3 " +"#"*50)
print(inference_dec(di_val_dataset[text_column][i]))

print()

print("#"*50+ " Primera " +"#"*50)
print(inference_s2s(di_val_dataset[text_column][i]))


################################################## Llama-3 ##################################################


Input ids are automatically padded from 1609 to 2048 to be a multiple of `config.attention_window`: 512


Discharge Instructions:
Dear Ms. ___,  
You were hospitalized due to symptoms of slurred speech and right  sided weakness resulting from an ACUTE ISCHEMIC STROKE, a  condition where a blood vessel providing oxygen and nutrients to  the brain is blocked by a clot. The brain is the part of your  body that controls and directs all the other parts of your body,  so damage to the brain from being deprived of its blood supply  can result in a variety of symptoms.  
Stroke can have many different causes, so we assessed you for  medical conditions that might raise your risk of having stroke.  In order to prevent future strokes, we plan to modify those risk  factors. Your risk factors are:   - high blood pressure - high cholesterol - heart disease   We are changing your medications as follows:   - aspirin 81 mg daily - atorvastatin 80 mg daily  Please take your other medications as prescribed.   Please follow up with Neurology and your primary care physician  as listed below.   If you experienc
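
To go beyond eyeballing the two outputs, one could score each against the reference in `summary_column`. The notebook imports `evaluate` (e.g. `evaluate.load("rouge")`), but as a dependency-free sketch, here is a simplified unigram-overlap F1 in the spirit of ROUGE-1 (lowercased whitespace tokens, no stemming). `unigram_f1` is hypothetical; it is not the shared task's official scorer and will not reproduce its numbers.

```python
from collections import Counter


def unigram_f1(hypothesis: str, reference: str) -> float:
    """Simplified ROUGE-1-style F1 (hypothetical): clipped unigram overlap
    over lowercased whitespace tokens."""
    hyp = Counter(hypothesis.lower().split())
    ref = Counter(reference.lower().split())
    overlap = sum((hyp & ref).values())  # counts clipped to the reference
    if overlap == 0:
        return 0.0
    precision = overlap / sum(hyp.values())
    recall = overlap / sum(ref.values())
    return 2 * precision * recall / (precision + recall)


print(unigram_f1("the patient was discharged", "the patient was admitted"))  # 0.75
```

With the notebook's objects, a call would look like `unigram_f1(inference_dec(di_val_dataset[text_column][i]), di_val_dataset[summary_column][i])`.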