<a href="https://colab.research.google.com/github/Michael-David-Lam/Medical-Dialogue-Summary/blob/Bart-base-model/Medical_Dialogue_Generation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Install Dependencies

In [2]:
!pip install kaggle
!pip install -U transformers
!pip install -U datasets
!pip install -U accelerate
!pip install -U evaluate
!pip install -U rouge_score
!pip install -U peft
!pip install sentencepiece
!mkdir ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json




The syntax of the command is incorrect.
'cp' is not recognized as an internal or external command,
operable program or batch file.
'chmod' is not recognized as an internal or external command,
operable program or batch file.


# Import Dataset GitHub Repo

In [1]:
import kagglehub
import pandas as pd
import re
import numpy as np
!git clone https://github.com/abachaa/MTS-Dialog.git
# Load data
training_data =pd.read_csv('/content/MTS-Dialog/Main-Dataset/MTS-Dialog-TrainingSet.csv')
validation_data = pd.read_csv('/content/MTS-Dialog/Main-Dataset/MTS-Dialog-ValidationSet.csv')
Test_data = pd.read_csv('/content/MTS-Dialog/Main-Dataset/MTS-Dialog-TestSet-1-MEDIQA-Chat-2023.csv')
# Rename columns
training_data = training_data.rename(columns={'context': 'input_text', 'target': 'target_text'})

from datasets import Dataset
train_dataset = Dataset.from_pandas(training_data)
val_dataset = Dataset.from_pandas(validation_data)
test_dataset = Dataset.from_pandas(Test_data)


fatal: destination path 'MTS-Dialog' already exists and is not an empty directory.


# Define Model and Preprocess Data

In [2]:
from transformers import BartTokenizer, BartModel

model_name = "facebook/bart-base"
tokenizer = BartTokenizer.from_pretrained(model_name)

def preprocess_data(df):
    # Define section mapping and ordering
    SECTION_ORDER = [
        ('chief_complaint', 'cc'),
        ('history_of_present_illness', 'genhx'),
        ('past_medical_history', 'pastmedicalhx'),
        ('past_surgeries', 'pastsurgical'),
        ('medications', 'medications'),
        ('allergies', 'allergy'),
        ('social_history', 'fam/sochx'),
        ('educational_courses', 'edcourse'),
        ('review_of_systems', 'ros'),
        ('physical_exam', 'exam'),
        ('assessment', 'assessment'),
        ('exam','exam'),
        ('procedures','procedures'),
        ('labs','labs'),
        ('plan', 'plan'),
        ('disposition', 'disposition')
    ]

    df['dialogue_id'] = df['ID'].astype(str)
    grouped = df.groupby('dialogue_id')

    structured_data = []

    for dialogue_id, group in grouped:
        # Combine all dialogue turns (more robust than iloc[0])
        full_dialogue = ' '.join(group['dialogue'].tolist())

        # Extract all sections
        sections = {}
        for _, row in group.iterrows():
            section_key = row['section_header'].lower().strip()
            sections[section_key] = row['section_text'].strip()

        # Build target text in XML-style format
        target_parts = []
        for standard_name, source_name in SECTION_ORDER:
            if source_name in sections and sections[source_name]:
                target_parts.append(f"<{standard_name}>{sections[source_name]}</{standard_name}>")

        # Add unmapped sections at the end
        for section_key, text in sections.items():
            if section_key not in [x[1] for x in SECTION_ORDER] and text:
                target_parts.append(f"<{section_key}>{text}</{section_key}>")

        target_text = ' '.join(target_parts)

        structured_data.append({
            'input_text': f"Generate clinical note: {full_dialogue}",
            'target_text': target_text,
            'dialogue_id': dialogue_id
        })

    return pd.DataFrame(structured_data)

# Apply preprocessing
training_structured = preprocess_data(training_data)
validation_structured = preprocess_data(validation_data)


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [3]:
from torch.utils.data import Dataset, DataLoader

from torch.utils.data import Dataset

class DoctorPatientDataset(Dataset):
    def __init__(self, data, tokenizer, max_input_length=512, max_target_length=256):
        self.data = data
        self.tokenizer = tokenizer
        self.max_input_length = max_input_length
        self.max_target_length = max_target_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Access data using .iloc to ensure integer-based indexing
        item = self.data.iloc[idx]  # Use .iloc for integer-based indexing
        input_text = item['input_text']
        target_text = item['target_text']

        # Tokenize inputs
        input_encodings = self.tokenizer(
            input_text,
            max_length=self.max_input_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        # Tokenize targets
        target_encodings = self.tokenizer(
            target_text,
            max_length=self.max_target_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        # Replace padding token id with -100 for loss calculation
        labels = target_encodings['input_ids']
        labels[labels == self.tokenizer.pad_token_id] = -100

        return {
            'input_ids': input_encodings['input_ids'].flatten(),
            'attention_mask': input_encodings['attention_mask'].flatten(),
            'labels': labels.flatten()
        }

## Create Train/Val Tokenized Datasets

In [4]:

# Then create datasets
train_tokenized = DoctorPatientDataset(training_structured, tokenizer)
val_tokenized = DoctorPatientDataset(validation_structured, tokenizer)

## Init Model and Lora Config

In [None]:
from transformers import BartForConditionalGeneration
from peft import LoraConfig, get_peft_model, TaskType
# Example data preparation

# Initialize model with LoRA
model = BartForConditionalGeneration.from_pretrained(model_name)
lora_config = LoraConfig(
    r=4,
    lora_alpha=32,
    # target_modules=["q", "v"],
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.SEQ_2_SEQ_LM
)
# Wrap model with LoRA
model = get_peft_model(model, lora_config)


In [9]:
from transformers import GenerationConfig

generation_config = GenerationConfig(
    temperature=0.9,
    top_k=50,
    top_p=0.95,
    do_sample=True,
    repetition_penalty=2.0,
    no_repeat_ngram_size=4,
    num_beams=1,
    max_length=128
)


# Define Training Args and Metrics

In [10]:
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer
import evaluate
from transformers import GenerationConfig
from transformers.trainer_utils import IntervalStrategy, SaveStrategy
rouge = evaluate.load("rouge")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred

    # Replace -100 with the pad token id for decoding labels
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

    # Ensure token IDs are within valid range
    vocab_size = len(tokenizer)
    predictions = np.where(
        (predictions >= 0) & (predictions < vocab_size),
        predictions,
        tokenizer.unk_token_id  # Replace out-of-range IDs with unknown token
    )

    # Decode predictions and labels
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Compute ROUGE scores
    result = rouge.compute(
        predictions=decoded_preds,
        references=decoded_labels,
        use_stemmer=True
    )

    return {k: round(v, 4) for k, v in result.items()}

training_args = Seq2SeqTrainingArguments(
    output_dir="./results",
    # Replace 'evaluation_strategy' with 'eval_strategy'
    eval_strategy=IntervalStrategy.EPOCH,
    save_strategy=IntervalStrategy.EPOCH,
    learning_rate=3e-4,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=20,
    predict_with_generate=True,
    fp16=True,
    generation_max_length=128,
    report_to="none",
    load_best_model_at_end=True,
    logging_strategy ="epoch",
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_tokenized,
    eval_dataset=val_tokenized,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

Downloading builder script:   0%|          | 0.00/6.27k [00:00<?, ?B/s]

  trainer = Seq2SeqTrainer(
No label_names provided for model class `PeftModelForSeq2SeqLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


## Train Model

In [11]:
trainer.train()
model.save_pretrained("./clinical_note_model")
tokenizer.save_pretrained("./clinical_note_model")

Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
1,2.9025,2.172034,0.258,0.1,0.1928,0.1929
2,2.349,2.078383,0.3165,0.121,0.2408,0.2416
3,2.2563,2.03674,0.348,0.1483,0.2767,0.2778
4,2.1991,2.001751,0.3397,0.1459,0.2756,0.2769
5,2.1342,1.973599,0.3655,0.1596,0.2941,0.2938
6,2.1119,1.947342,0.3533,0.1541,0.2842,0.2846
7,2.0521,1.959473,0.3929,0.1822,0.3276,0.3274
8,2.0371,1.941308,0.377,0.1718,0.3066,0.3055
9,2.0199,1.919962,0.3972,0.1824,0.3267,0.3261
10,1.996,1.913205,0.4074,0.1914,0.3425,0.3411


('./clinical_note_model/tokenizer_config.json',
 './clinical_note_model/special_tokens_map.json',
 './clinical_note_model/vocab.json',
 './clinical_note_model/merges.txt',
 './clinical_note_model/added_tokens.json')

# Generate Summary

In [13]:
from transformers import GenerationConfig

# Define your generation config once
generation_config = GenerationConfig(
    temperature=0.7,
    top_k=60,
    top_p=0.95,
    do_sample=True,
    repetition_penalty=2.4,
    no_repeat_ngram_size=2,
    num_beams=1,
    max_length=128  # You can adjust this
)

# Function to generate notes from dialogue
def generate_note(dialogue):
    inputs = tokenizer(
        dialogue,
        max_length=128,
        padding='max_length',
        truncation=True,
        return_tensors="pt"
    ).to(model.device)

    outputs = model.generate(
        input_ids=inputs['input_ids'],
        attention_mask=inputs['attention_mask'],
        generation_config=generation_config  # ✅ This is where it goes
    )

    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# Example usage

# List of section headers necessary
options = ["CC", "GENHX", "PASTMEDICALHX", "DIAGNOSIS", "PLAN"]

for example in test_dataset:
    if example['section_header'] in options:
        note = generate_note(example['dialogue'])  # Access the 'dialogue' column
        print(example['section_header'])
        print(note)  # Or store the note for later use
        # print(example['dialogue'])


GENHX
<history_of_present_illness>The patient is a 50-year-old African American who was last seen on 07/09/08.  The patient did not have her chart with her at that time, but the nurse brought it to me today.Â She has had no history of MRSI results since then. No previous hospitalizations.</history__of_(present)...
GENHX
<history_of_present_illness>The patient is a 17-year-old female who was sedated with Ativan.  She appeared short of breath upon arrival and immediately had Xray come in to scan her lungs. The patient's right diaphragm showed what we believe to be free air under her right ventricular fibrillation, which has been shown to have free oxygen under the right tracheostomy.</History_OF_CONSTITUTION>
GENHX
<history_of_present_illness>The patient is a 75-year-old African American female who presents to the hospital for evaluation.  The patient weighs approximately 1,100 pounds and does not have any major medical conditions.</History_OF_Present_Illnesses>
GENHX
<history_of_present