<a href="https://colab.research.google.com/github/TienNguyen93/clinical-generation/blob/main/clinical_generation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Clinical Note Generation**

In [None]:
%pip install datasets evaluate rouge_score

## **Import libraries**

In [None]:
from datasets import load_dataset
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig, TrainingArguments, Trainer
import torch
import time
import evaluate
import pandas as pd
import numpy as np

## **Load dataset**

In [None]:
# Load the clinical visit-note summarization corpus. Later cells read its
# 'train', 'validation' and 'test' splits and the 'prompt' / 'completion'
# text columns.
ds = load_dataset("316usman/research_clinical_visit_note_summarization_corpus_mts")

In [None]:
ds

### **Find the longest and shortest sequences in the train, validation, and test sets**

In [None]:
# TODO

### **Prepare dataset**

 Convert the dialog-summary (prompt-response) pairs into explicit instructions

In [None]:
"""
Preprocessing function needs to:

* Prefix the input with a prompt so T5 knows this is a summarization task. Some models capable of multiple NLP tasks require prompting for specific tasks.
* Use the keyword text_target argument when tokenizing labels.
* Truncate sequences to be no longer than the maximum length set by the max_length parameter.
"""

# tokenize function
def t5_tokenize_function(example):
    """Tokenize a batch of prompt/completion pairs for FLAN-T5 fine-tuning.

    Each dialogue is wrapped in a summarization instruction before being
    tokenized; the reference summaries become the training labels.

    Args:
        example: A batch as passed by ``ds.map(..., batched=True)``, where
            ``example["prompt"]`` and ``example["completion"]`` are
            equal-length lists of strings.

    Returns:
        The same batch with ``input_ids``, ``attention_mask`` and ``labels``
        added (lists of token-id lists, padded to the tokenizer's maximum
        length).
    """
    start_prompt = 'Summarize the following conversation.\n\n'
    end_prompt = '\n\nSummary: '

    prompt = [start_prompt + dialogue + end_prompt for dialogue in example["prompt"]]

    # Use the tokenizer that is loaded together with the T5 model; the
    # notebook does not define a global called `tokenizer` before this
    # function is mapped over the dataset.
    model_inputs = t5_tokenizer(prompt, padding="max_length", truncation=True)
    targets = t5_tokenizer(
        text_target=example["completion"], padding="max_length", truncation=True
    )

    example['input_ids'] = model_inputs["input_ids"]
    # Keep the mask so the encoder does not attend to the padding positions.
    example['attention_mask'] = model_inputs["attention_mask"]
    # Replace padded label positions with -100 so they are ignored by the
    # loss. The attention mask (not the pad id) decides what is padding, so
    # genuine tokens are never masked by accident.
    example['labels'] = [
        [token if keep else -100 for token, keep in zip(ids, mask)]
        for ids, mask in zip(targets["input_ids"], targets["attention_mask"])
    ]

    return example

## **Load models**

### **T5 model**

In [None]:
# Load the pretrained FLAN-T5 checkpoint as a sequence-to-sequence model.
# `t5_name` is reused below so the tokenizer comes from the same checkpoint.
t5_name ='google/flan-t5-base'
t5_model = AutoModelForSeq2SeqLM.from_pretrained(t5_name)

# Matching T5 tokenizer.
# `use_fast=True` requests the fast tokenizer implementation.
t5_tokenizer = AutoTokenizer.from_pretrained(t5_name, use_fast=True)

In [None]:
# Tokenize every split in batches, then drop the raw text columns so that
# only the tokenized features remain.
t5_tokenized_ds = ds.map(t5_tokenize_function, batched=True).remove_columns(
    ['prompt', 'completion']
)

In [None]:
# t5_tokenized_ds = t5_tokenized_ds.filter(lambda example, index: index % 100 == 0, with_indices=True)

# Report the shape of each tokenized split.
print("Shapes of the datasets:")
for split_label, split_key in (
    ("Training", "train"),
    ("Validation", "validation"),
    ("Test", "test"),
):
    print(f"{split_label}: {t5_tokenized_ds[split_key].shape}")

In [None]:
t5_tokenized_ds

In [None]:
"""
 create a batch of examples using DataCollatorForSeq2Seq.
 It’s more efficient to dynamically pad the sentences to the longest length in a batch during collation,
 instead of padding the whole dataset to the maximum length.
"""

# from transformers import DataCollatorForSeq2Seq

# data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=checkpoint)

#### Fine-tune T5

In [None]:
# Directory handed to TrainingArguments as the output location of this run.
output_dir = "./results"

training_args = TrainingArguments(
    output_dir=output_dir,
    # --- optimisation ---
    learning_rate=1e-5,
    weight_decay=0.01,
    num_train_epochs=3,
    # max_steps=1,
    # --- batching ---
    per_device_train_batch_size=8,
    auto_find_batch_size=True,
    # --- logging / evaluation ---
    logging_steps=10,
    eval_strategy='epoch',
    report_to="none",
)

# Fine-tune FLAN-T5 on the tokenized train split, evaluating on validation.
trainer = Trainer(
    model=t5_model,
    args=training_args,
    train_dataset=t5_tokenized_ds['train'],
    eval_dataset=t5_tokenized_ds['validation'],
    tokenizer=t5_tokenizer,
)

trainer.train()

In [None]:
# Load the fine-tuned weights from the most recent checkpoint saved under
# `output_dir`. Resolving it at run time avoids depending on an absolute
# Colab path and on an exact step number, which changes with dataset size
# and with the batch size picked by `auto_find_batch_size`.
from transformers.trainer_utils import get_last_checkpoint

t5_checkpoint_dir = get_last_checkpoint(output_dir)
if t5_checkpoint_dir is None:
    raise FileNotFoundError(
        f"No checkpoint-* directory found under {output_dir!r}; "
        "run the fine-tuning cell first."
    )

t5_instruct_model = AutoModelForSeq2SeqLM.from_pretrained(t5_checkpoint_dir)

#### Evaluate the T5 Qualitatively

In [None]:
# Compare the human summary with both models' output for one test example.
index = 100
test_example = ds['test'][index]
dialogue = test_example['prompt']
human_baseline_summary = test_example['completion']

prompt = f"""
Summarize the following conversation.

{dialogue}

Summary:
"""

input_ids = t5_tokenizer(prompt, return_tensors="pt").input_ids

# Same greedy-decoding settings for both models.
generation_config = GenerationConfig(max_new_tokens=200, num_beams=1)

# NOTE(review): `t5_model` is the object that was handed to the Trainer, so
# once the fine-tuning cell has run it carries fine-tuned weights. Reload
# `t5_name` into a separate model if a true pre-fine-tuning baseline is needed.
#
# The two models can live on different devices (training may have moved
# `t5_model` to an accelerator while the reloaded checkpoint stays on CPU),
# so the inputs are sent to each model's own device before generating.
t5_res = t5_model.generate(
    input_ids=input_ids.to(t5_model.device),
    generation_config=generation_config,
)
t5_text_res = t5_tokenizer.decode(t5_res[0], skip_special_tokens=True)

t5_instruct_res = t5_instruct_model.generate(
    input_ids=input_ids.to(t5_instruct_model.device),
    generation_config=generation_config,
)
t5_instruct_text_res = t5_tokenizer.decode(t5_instruct_res[0], skip_special_tokens=True)

# 99 dashes, the same separator the join-based idiom produced.
dash_line = '-' * 99
print(dash_line)
print(f'BASELINE HUMAN SUMMARY:\n{human_baseline_summary}')
print(dash_line)
print(f'ORIGINAL MODEL:\n{t5_text_res}')
print(dash_line)
print(f'INSTRUCT MODEL:\n{t5_instruct_text_res}')

#### Evaluate the T5 Quantitatively

ROUGE Metric

In [None]:
# Load the ROUGE metric; `rouge.compute(...)` is called on the generated
# summaries in the cells below.
rouge = evaluate.load('rouge')

In [None]:
# Use the first three test examples as a small evaluation batch.
test_batch = ds['test'][0:3]
dialogues = test_batch['prompt']
human_baseline_summaries = test_batch['completion']

original_model_summaries = []
instruct_model_summaries = []

# Built once and reused for every generate() call.
generation_config = GenerationConfig(max_new_tokens=200)

for dialogue in dialogues:
    prompt = f"""
Summarize the following conversation.

{dialogue}

Summary: """
    input_ids = t5_tokenizer(prompt, return_tensors="pt").input_ids

    # Inputs are moved to each model's own device: after training `t5_model`
    # may sit on an accelerator while the reloaded checkpoint is on CPU.
    original_model_outputs = t5_model.generate(
        input_ids=input_ids.to(t5_model.device),
        generation_config=generation_config,
    )
    original_model_text_output = t5_tokenizer.decode(original_model_outputs[0], skip_special_tokens=True)
    original_model_summaries.append(original_model_text_output)

    instruct_model_outputs = t5_instruct_model.generate(
        input_ids=input_ids.to(t5_instruct_model.device),
        generation_config=generation_config,
    )
    instruct_model_text_output = t5_tokenizer.decode(instruct_model_outputs[0], skip_special_tokens=True)
    instruct_model_summaries.append(instruct_model_text_output)

# One row per example: human reference next to both model outputs.
zipped_summaries = list(zip(human_baseline_summaries, original_model_summaries, instruct_model_summaries))

df = pd.DataFrame(zipped_summaries, columns = ['human_baseline_summaries', 'original_model_summaries', 'instruct_model_summaries'])
df

In [None]:
# Score both sets of generated summaries against the human references with
# identical ROUGE settings; the references are cut to the number of
# predictions available.
original_model_results, instruct_model_results = (
    rouge.compute(
        predictions=model_summaries,
        references=human_baseline_summaries[0:len(model_summaries)],
        use_aggregator=True,
        use_stemmer=True,
    )
    for model_summaries in (original_model_summaries, instruct_model_summaries)
)

for heading, scores in (
    ('ORIGINAL MODEL:', original_model_results),
    ('INSTRUCT MODEL:', instruct_model_results),
):
    print(heading)
    print(scores)

BERTScore, and

BLEURT

### **BART**

In [None]:
# Load the pretrained BART checkpoint as a sequence-to-sequence model.
# `bart_name` is reused below so the tokenizer comes from the same checkpoint.
bart_name = 'facebook/bart-large-cnn'
bart_model = AutoModelForSeq2SeqLM.from_pretrained(bart_name)

# Matching BART tokenizer; `use_fast=True` requests the fast implementation.
bart_tokenizer = AutoTokenizer.from_pretrained(bart_name, use_fast=True)

In [None]:
def bart_tokenize_function(example):
    """Tokenize a batch of prompt/completion pairs for BART fine-tuning.

    Each dialogue is wrapped in a summarization instruction and encoded to a
    fixed length of 512 tokens; the reference summaries are encoded to 128
    tokens and used as labels.

    Args:
        example: A batch as passed by ``ds.map(..., batched=True)``, where
            ``example["prompt"]`` and ``example["completion"]`` are
            equal-length lists of strings.

    Returns:
        The tokenizer output for the prompts with an extra ``labels`` entry
        holding the summary token ids, padded positions set to -100.
    """
    start_prompt = 'Summarize the following conversation.\n\n'
    end_prompt = '\n\nSummary: '

    prompt = [start_prompt + dialogue + end_prompt for dialogue in example["prompt"]]

    model_inputs = bart_tokenizer(prompt, padding="max_length", truncation=True, max_length=512)
    targets = bart_tokenizer(
        text_target=example["completion"],
        padding="max_length",
        truncation=True,
        max_length=128,
    )

    # Replace padded label positions with -100 so they are ignored by the
    # loss. The attention mask (not the pad id) decides what is padding, so
    # the real end-of-sequence token is kept even if the pad token has been
    # set equal to EOS.
    model_inputs["labels"] = [
        [token if keep else -100 for token, keep in zip(ids, mask)]
        for ids, mask in zip(targets["input_ids"], targets["attention_mask"])
    ]
    return model_inputs

In [None]:
# NOTE(review): this makes padding use the EOS token. The BART tokenizer
# presumably ships with its own pad token already — confirm whether this
# override is actually needed.
bart_tokenizer.pad_token = bart_tokenizer.eos_token

# Tokenize every split in batches, then drop the raw text columns so that
# only the tokenized features remain.
bart_tokenized_ds = ds.map(bart_tokenize_function, batched=True).remove_columns(
    ['prompt', 'completion']
)

In [None]:
# Same hyperparameters as the T5 run, with a separate output directory.
training_args_bart = TrainingArguments(
    output_dir='./bart-clinical',
    # --- optimisation ---
    learning_rate=1e-5,
    weight_decay=0.01,
    num_train_epochs=3,
    # max_steps=1,
    # --- batching ---
    per_device_train_batch_size=8,
    auto_find_batch_size=True,
    # --- logging / evaluation ---
    logging_steps=10,
    eval_strategy='epoch',
    report_to="none",
)

# Fine-tune BART on the tokenized train split, evaluating on validation.
trainer_bart = Trainer(
    model=bart_model,
    args=training_args_bart,
    train_dataset=bart_tokenized_ds['train'],
    eval_dataset=bart_tokenized_ds['validation'],
    tokenizer=bart_tokenizer,
)

trainer_bart.train()

## **Evaluation**

In [None]:
# ROUGE, BERTScore, and BLEURT.

# **Examples**

## **View an instance of dialogue**

In [None]:
# Test-set rows used as running examples in the cells below.
example_indices = [40, 200]

# Separator line made of 99 dashes.
dash_line = '-' * 99

# Print each example's dialogue followed by its human-written summary.
for example_number, index in enumerate(example_indices, start=1):
    record = ds['test'][index]
    print(dash_line)
    print('Example ', example_number)
    print(dash_line)
    print('INPUT DIALOGUE:')
    print(record['prompt'])
    print(dash_line)
    print('BASELINE HUMAN SUMMARY:')
    print(record['completion'])
    print(dash_line)
    print()

In [None]:
# Sanity check: encode one sentence with the T5 tokenizer and decode it back.
sentence = "What time is it, Tom?"

sentence_encoded = t5_tokenizer(sentence, return_tensors='pt')
token_ids = sentence_encoded["input_ids"][0]
sentence_decoded = t5_tokenizer.decode(token_ids, skip_special_tokens=True)

print(
    'ENCODED SENTENCE:',
    token_ids,
    '\nDECODED SENTENCE:',
    sentence_decoded,
    sep='\n',
)

## **Summarize Dialogue without Prompt Engineering**

In [None]:
# Registry of the models compared in this section: name -> (tokenizer, model).
# It is defined here because the loops below iterate over `models`, which no
# cell visible in this notebook creates.
models = {
    t5_name: (t5_tokenizer, t5_model),
    bart_name: (bart_tokenizer, bart_model),
}

for model_name, (tokenizer, model) in models.items():
    print("Model:", model_name)

    for i, index in enumerate(example_indices):
        test_example = ds['test'][index]
        dialogue = test_example['prompt']
        summary = test_example['completion']

        # Feed the raw dialogue with no instruction wrapped around it.
        inputs = tokenizer(dialogue, return_tensors='pt')
        # Send the ids to the model's own device; after training the model
        # may no longer be on the CPU where the tokenizer creates tensors.
        generated = model.generate(
            inputs["input_ids"].to(model.device),
            max_new_tokens=50,
        )
        output = tokenizer.decode(generated[0], skip_special_tokens=True)

        print(dash_line)
        print('Example ', i + 1)
        print(dash_line)
        print(f'INPUT PROMPT:\n{dialogue}')
        print(dash_line)
        print(f'BASELINE HUMAN SUMMARY:\n{summary}')
        print(dash_line)
        print(f'MODEL GENERATION - WITHOUT PROMPT ENGINEERING:\n{output}\n')

## **Summarize Dialogue with an Instruction Prompt**

### Zero Shot Inference with an Instruction Prompt

In [None]:
# Zero-shot summarization: every model gets the dialogue wrapped in an
# explicit "Summarize the following conversation." instruction.
for model_name, (tokenizer, model) in models.items():
    print("Model:", model_name)

    for i, index in enumerate(example_indices):
        test_example = ds['test'][index]
        dialogue = test_example['prompt']
        summary = test_example['completion']

        # Same text as the original multi-line literal, including the
        # whitespace that followed "Summary:".
        prompt = f"\nSummarize the following conversation.\n\n{dialogue}\n\nSummary:\n    "

        inputs = tokenizer(prompt, return_tensors='pt')
        # Send the ids to the model's own device; after training the model
        # may no longer be on the CPU where the tokenizer creates tensors.
        generated = model.generate(
            inputs["input_ids"].to(model.device),
            max_new_tokens=50,
        )
        output = tokenizer.decode(generated[0], skip_special_tokens=True)

        print(dash_line)
        print('Example ', i + 1)
        print(dash_line)
        print(f'INPUT PROMPT:\n{prompt}')
        print(dash_line)
        print(f'BASELINE HUMAN SUMMARY:\n{summary}')
        print(dash_line)
        print(f'MODEL GENERATION - ZERO SHOT:\n{output}\n')

### Zero Shot Inference with the Prompt Template

In [None]:
# Zero-shot summarization with an alternative template that frames the task
# as a question about the dialogue.
for model_name, (tokenizer, model) in models.items():
    print("Model:", model_name)

    for i, index in enumerate(example_indices):
        test_example = ds['test'][index]
        dialogue = test_example['prompt']
        summary = test_example['completion']

        # Same text as the original multi-line literal.
        prompt = f"\nDialogue:\n\n{dialogue}\n\nWhat was going on?\n"

        inputs = tokenizer(prompt, return_tensors='pt')
        # Send the ids to the model's own device; after training the model
        # may no longer be on the CPU where the tokenizer creates tensors.
        generated = model.generate(
            inputs["input_ids"].to(model.device),
            max_new_tokens=50,
        )
        output = tokenizer.decode(generated[0], skip_special_tokens=True)

        print(dash_line)
        print('Example ', i + 1)
        print(dash_line)
        print(f'INPUT PROMPT:\n{prompt}')
        print(dash_line)
        print(f'BASELINE HUMAN SUMMARY:\n{summary}')
        print(dash_line)
        print(f'MODEL GENERATION - ZERO SHOT (another template):\n{output}\n')

## **Summarize Dialogue with One Shot and Few Shot Inference**

### One Shot Inference

In [None]:
def make_prompt(example_indices_full, example_index_to_summarize):
    """Build a one-shot / few-shot prompt from test-set examples.

    Args:
        example_indices_full: Indices into ``ds['test']`` whose dialogue AND
            reference summary are included as worked demonstrations.
        example_index_to_summarize: Index into ``ds['test']`` of the dialogue
            the model should summarize; its summary is left out.

    Returns:
        A single string: one "Dialogue / What was going on? / summary" section
        per demonstration, followed by the final dialogue and an unanswered
        "What was going on?" question.
    """
    test_split = ds['test']

    sections = []
    for shot_index in example_indices_full:
        shot = test_split[shot_index]
        # The three newlines after the summary act as the stop sequence
        # between demonstrations (noted as important for FLAN-T5; other
        # models may prefer a different one).
        sections.append(
            f"\nDialogue:\n\n{shot['prompt']}\n\nWhat was going on?\n{shot['completion']}\n\n\n"
        )

    target_dialogue = test_split[example_index_to_summarize]['prompt']
    sections.append(f"\nDialogue:\n\n{target_dialogue}\n\nWhat was going on?\n")

    return ''.join(sections)

In [None]:
# One-shot setup: test example 40 is the single worked demonstration and
# test example 200 is the dialogue left for the model to summarize.
example_indices_full = [40]
example_index_to_summarize = 200

one_shot_prompt = make_prompt(example_indices_full, example_index_to_summarize)

# Show the assembled prompt for inspection.
print(one_shot_prompt)

In [None]:
# Human reference for the dialogue being summarized.
summary = ds['test'][example_index_to_summarize]['completion']

inputs = t5_tokenizer(one_shot_prompt, return_tensors='pt')
# Send the ids to the model's own device; after training `t5_model` may no
# longer be on the CPU where the tokenizer creates tensors.
generated = t5_model.generate(
    inputs["input_ids"].to(t5_model.device),
    max_new_tokens=50,
)
output = t5_tokenizer.decode(generated[0], skip_special_tokens=True)

print(dash_line)
print(f'BASELINE HUMAN SUMMARY:\n{summary}\n')
print(dash_line)
print(f'MODEL GENERATION - ONE SHOT:\n{output}')

### Few Shot Inference

In [None]:
# Few-shot setup: test examples 40, 80 and 120 are the worked demonstrations
# and test example 200 is the dialogue left for the model to summarize.
example_indices_full = [40, 80, 120]
example_index_to_summarize = 200

few_shot_prompt = make_prompt(example_indices_full, example_index_to_summarize)

# Show the assembled prompt for inspection.
print(few_shot_prompt)

In [None]:
# Human reference for the dialogue being summarized.
summary = ds['test'][example_index_to_summarize]['completion']

inputs = t5_tokenizer(few_shot_prompt, return_tensors='pt')
# Send the ids to the model's own device; after training `t5_model` may no
# longer be on the CPU where the tokenizer creates tensors.
generated = t5_model.generate(
    inputs["input_ids"].to(t5_model.device),
    max_new_tokens=50,
)
output = t5_tokenizer.decode(generated[0], skip_special_tokens=True)

print(dash_line)
print(f'BASELINE HUMAN SUMMARY:\n{summary}\n')
print(dash_line)
print(f'MODEL GENERATION - FEW SHOT:\n{output}')