<a href="https://colab.research.google.com/github/Michael-David-Lam/Medical-Dialogue-Summary/blob/dev3michael/Medical_Dialogue_Generation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Install Dependencies

In [54]:
!pip install kaggle
!pip install -U transformers
!pip install -U datasets
!pip install -U accelerate
!pip install -U evaluate
!pip install -U rouge_score
!pip install -U peft
!pip install sentencepiece
!mkdir ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json


mkdir: cannot create directory ‘/root/.kaggle’: File exists
cp: cannot stat 'kaggle.json': No such file or directory
chmod: cannot access '/root/.kaggle/kaggle.json': No such file or directory


# Import Dataset GitHub Repo

In [55]:
import kagglehub
import pandas as pd
import re
import numpy as np
!git clone https://github.com/abachaa/MTS-Dialog.git
# Load data
training_data =pd.read_csv('/content/MTS-Dialog/Main-Dataset/MTS-Dialog-TrainingSet.csv')
validation_data = pd.read_csv('/content/MTS-Dialog/Main-Dataset/MTS-Dialog-ValidationSet.csv')
Test_data = pd.read_csv('/content/MTS-Dialog/Main-Dataset/MTS-Dialog-TestSet-1-MEDIQA-Chat-2023.csv')
# Rename columns
training_data = training_data.rename(columns={'context': 'input_text', 'target': 'target_text'})

from datasets import Dataset
train_dataset = Dataset.from_pandas(training_data)
val_dataset = Dataset.from_pandas(validation_data)
test_dataset = Dataset.from_pandas(Test_data)


fatal: destination path 'MTS-Dialog' already exists and is not an empty directory.


# Define Model and Preprocess Data

In [65]:
from transformers import BartTokenizer, BartModel

model_name = "facebook/bart-base"
tokenizer = BartTokenizer.from_pretrained(model_name)

def preprocess_data(df):
    # Define section mapping and ordering
    SECTION_ORDER = [
        ('chief_complaint', 'cc'),
        ('history_of_present_illness', 'genhx'),
        ('past_medical_history', 'pastmedicalhx'),
        ('past_surgeries', 'pastsurgical'),
        ('medications', 'medications'),
        ('allergies', 'allergy'),
        ('social_history', 'fam/sochx'),
        ('educational_courses', 'edcourse'),
        ('review_of_systems', 'ros'),
        ('physical_exam', 'exam'),
        ('assessment', 'assessment'),
        ('exam','exam'),
        ('procedures','procedures'),
        ('labs','labs'),
        ('plan', 'plan'),
        ('disposition', 'disposition')
    ]

    df['dialogue_id'] = df['ID'].astype(str)
    grouped = df.groupby('dialogue_id')

    structured_data = []

    for dialogue_id, group in grouped:
        # Combine all dialogue turns (more robust than iloc[0])
        full_dialogue = ' '.join(group['dialogue'].tolist())

        # Combine section headers and texts into lists
        section_header = group['section_header'].tolist()
        section_texts = group['section_text'].tolist()


        structured_data.append({
            'input_text': f"Generate detailed {section_header[0]} clinical note: {full_dialogue}",
            'target_text': section_texts,
            'dialogue_id': dialogue_id
        })
    return pd.DataFrame(structured_data)

# Apply preprocessing
training_structured = preprocess_data(training_data)
print(training_structured.head())
validation_structured = preprocess_data(validation_data)

                                          input_text  \
0  Generate detailed GENHX clinical note: Doctor:...   
1  Generate detailed GENHX clinical note: Doctor:...   
2  Generate detailed GENHX clinical note: Doctor:...   
3  Generate detailed PASTMEDICALHX clinical note:...   
4  Generate detailed PASTMEDICALHX clinical note:...   

                                         target_text dialogue_id  
0  [The patient is a 76-year-old white female who...           0  
1  [The patient is a 25-year-old right-handed Cau...           1  
2  [This 19-year-old Caucasian female presents to...          10  
3  [Significant for moderate to severe aortic ste...         100  
4                                  [Nonsignificant.]        1000  


In [66]:
from torch.utils.data import Dataset, DataLoader

from torch.utils.data import Dataset

class DoctorPatientDataset(Dataset):
    def __init__(self, data, tokenizer, max_input_length=512, max_target_length=256):
        self.data = data
        self.tokenizer = tokenizer
        self.max_input_length = max_input_length
        self.max_target_length = max_target_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Access data using .iloc to ensure integer-based indexing
        item = self.data.iloc[idx]  # Use .iloc for integer-based indexing
        input_text = item['input_text']
        target_text = item['target_text']

        # Tokenize inputs
        input_encodings = self.tokenizer(
            input_text,
            max_length=self.max_input_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        # Tokenize targets
        target_encodings = self.tokenizer(
            target_text,
            max_length=self.max_target_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        # Replace padding token id with -100 for loss calculation
        labels = target_encodings['input_ids']
        labels[labels == self.tokenizer.pad_token_id] = -100

        return {
            'input_ids': input_encodings['input_ids'].flatten(),
            'attention_mask': input_encodings['attention_mask'].flatten(),
            'labels': labels.flatten()
        }

## Create Train/Val Tokenized Datasets

In [67]:

# Then create datasets
train_tokenized = DoctorPatientDataset(training_structured, tokenizer)
val_tokenized = DoctorPatientDataset(validation_structured, tokenizer)

## Init Model and Lora Config

In [68]:
from transformers import BartForConditionalGeneration
from peft import LoraConfig, get_peft_model, TaskType
from peft import LoraConfig, get_peft_model, TaskType
# Example data preparation

# Initialize model with LoRA
model = BartForConditionalGeneration.from_pretrained(model_name)
lora_config = LoraConfig(
    r=4,
    lora_alpha=32,
    # target_modules=["q", "v"],
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.SEQ_2_SEQ_LM
)
# Wrap model with LoRA
model = get_peft_model(model, lora_config)


In [69]:
from transformers import GenerationConfig

generation_config = GenerationConfig(
    temperature=0.9,
    top_k=50,
    top_p=0.95,
    do_sample=True,
    repetition_penalty=2.0,
    no_repeat_ngram_size=4,
    num_beams=1,
    max_length=128
)


# Define Training Args and Metrics

In [70]:
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer
import evaluate
from transformers import GenerationConfig
from transformers.trainer_utils import IntervalStrategy, SaveStrategy
rouge = evaluate.load("rouge")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred

    # Replace -100 with the pad token id for decoding labels
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

    # Ensure token IDs are within valid range
    vocab_size = len(tokenizer)
    predictions = np.where(
        (predictions >= 0) & (predictions < vocab_size),
        predictions,
        tokenizer.unk_token_id  # Replace out-of-range IDs with unknown token
    )

    # Decode predictions and labels
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Compute ROUGE scores
    result = rouge.compute(
        predictions=decoded_preds,
        references=decoded_labels,
        use_stemmer=True
    )

    return {k: round(v, 4) for k, v in result.items()}

training_args = Seq2SeqTrainingArguments(
    output_dir="./results",
    # Replace 'evaluation_strategy' with 'eval_strategy'
    eval_strategy=IntervalStrategy.EPOCH,
    save_strategy=IntervalStrategy.EPOCH,
    learning_rate=3e-4,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=20,
    predict_with_generate=True,
    fp16=True,
    generation_max_length=128,
    report_to="none",
    load_best_model_at_end=True,
    logging_strategy ="epoch",
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_tokenized,
    eval_dataset=val_tokenized,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

  trainer = Seq2SeqTrainer(
No label_names provided for model class `PeftModelForSeq2SeqLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


## Train Model

In [71]:
trainer.train()
model.save_pretrained("./clinical_note_model")
tokenizer.save_pretrained("./clinical_note_model")

Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
1,3.1499,2.56367,0.2318,0.0888,0.1984,0.1987
2,2.791,2.486053,0.2725,0.1038,0.224,0.2235
3,2.706,2.449391,0.2874,0.1059,0.2451,0.2457
4,2.6591,2.424839,0.2955,0.1016,0.2495,0.2496
5,2.5824,2.399325,0.3025,0.1115,0.2489,0.2488
6,2.5567,2.369612,0.331,0.1123,0.2686,0.2683
7,2.5003,2.366112,0.2906,0.106,0.2415,0.2396
8,2.4623,2.357732,0.3159,0.1142,0.261,0.2606
9,2.4558,2.361243,0.3012,0.1151,0.2569,0.2557
10,2.44,2.351985,0.2889,0.1119,0.2347,0.2347


('./clinical_note_model/tokenizer_config.json',
 './clinical_note_model/special_tokens_map.json',
 './clinical_note_model/vocab.json',
 './clinical_note_model/merges.txt',
 './clinical_note_model/added_tokens.json')

# Generate Summary

In [72]:
from transformers import GenerationConfig

# Define your generation config once
generation_config = GenerationConfig(
    temperature=0.7,
    top_k=60,
    top_p=0.95,
    do_sample=True,
    repetition_penalty=2.4,
    no_repeat_ngram_size=2,
    num_beams=1,
    max_length=128  # You can adjust this
)

# Function to generate notes from dialogue
# Function to generate notes from dialogue
def generate_note(dialogue, section_header):
    # Inject the target header as a prompt prefix
    prompt = f"Section: {section_header}, Generate note: {dialogue}"

    inputs = tokenizer(
        prompt,
        max_length=128,
        padding='max_length',
        truncation=True,
        return_tensors="pt"
    ).to(model.device)

    outputs = model.generate(
        input_ids=inputs['input_ids'],
        attention_mask=inputs['attention_mask'],
        generation_config=generation_config  # ✅ This is where it goes
    )

    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# Example usage

# List of section headers necessary
options = ["CC", "GENHX", "PASTMEDICALHX", "DIAGNOSIS", "PLAN"]

for example in test_dataset:
    if example['section_header'] in options:
        note = generate_note(example['dialogue'], example['section_header'])  # Access the 'dialogue' column
        print(example['section_header'])
        print(note)  # Or store the note for later use
        print(example["section_text"])



GENHX
The patient is a 50-year-old African American woman who comes in today for a birthday visit.  She has not had her chart with her since July 08, 2008.
The patient is a 55-year-old African-American male that was last seen in clinic on 07/29/2008 with diagnosis of new onset seizures and an MRI scan, which demonstrated right contrast-enhancing temporal mass.  Given the characteristics of this mass and his new onset seizures, it is significantly concerning for a high-grade glioma.
GENHX
The patient is a 17-year-old female who was sedated with Ativan and appeared short of breath upon arrival.  We immediately had Xray done to scan her lungs, which showed what we believe to be free air under her right di
The patient is a 77-year-old female who is unable to give any information.  She has been sedated with Ativan and came into the emergency room obtunded and unable to give any history.  On a chest x-ray for what appeared to be shortness of breath she was found to have what was thought to b