<a href="https://colab.research.google.com/github/Michael-David-Lam/Medical-Dialogue-Summary/blob/dev3michael/Medical_Dialogue_Generation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Install Dependencies

In [1]:
!pip install kaggle
!pip install -U transformers
!pip install -U datasets
!pip install -U accelerate
!pip install -U evaluate
!pip install -U rouge_score
!pip install -U peft
!pip install sentencepiece
!mkdir ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json


Collecting datasets
  Downloading datasets-3.5.1-py3-none-any.whl.metadata (19 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2025.3.0,>=2023.1.0 (from fsspec[http]<=2025.3.0,>=2023.1.0->datasets)
  Downloading fsspec-2025.3.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.5.1-py3-none-any.whl (491 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m491.4/491.4 kB[0m [31m12.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m7.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2025.3.0-py3-none-any.whl (

# Import Dataset GitHub Repo

In [2]:
import kagglehub
import pandas as pd
import re
import numpy as np
!git clone https://github.com/abachaa/MTS-Dialog.git
# Load data
training_data =pd.read_csv('/content/MTS-Dialog/Main-Dataset/MTS-Dialog-TrainingSet.csv')
validation_data = pd.read_csv('/content/MTS-Dialog/Main-Dataset/MTS-Dialog-ValidationSet.csv')
Test_data = pd.read_csv('/content/MTS-Dialog/Main-Dataset/MTS-Dialog-TestSet-1-MEDIQA-Chat-2023.csv')
# Rename columns
training_data = training_data.rename(columns={'context': 'input_text', 'target': 'target_text'})

from datasets import Dataset
train_dataset = Dataset.from_pandas(training_data)
val_dataset = Dataset.from_pandas(validation_data)
test_dataset = Dataset.from_pandas(Test_data)


fatal: destination path 'MTS-Dialog' already exists and is not an empty directory.


# Define Model and Preprocess Data

In [11]:
from transformers import BartTokenizer, BartModel

model_name = "facebook/bart-base"
tokenizer = BartTokenizer.from_pretrained(model_name)

def preprocess_data(df):
    # Define section mapping and ordering
    SECTION_ORDER = [
        ('chief_complaint', 'cc'),
        ('history_of_present_illness', 'genhx'),
        ('past_medical_history', 'pastmedicalhx'),
        ('past_surgeries', 'pastsurgical'),
        ('medications', 'medications'),
        ('allergies', 'allergy'),
        ('social_history', 'fam/sochx'),
        ('educational_courses', 'edcourse'),
        ('review_of_systems', 'ros'),
        ('physical_exam', 'exam'),
        ('assessment', 'assessment'),
        ('exam','exam'),
        ('procedures','procedures'),
        ('labs','labs'),
        ('plan', 'plan'),
        ('disposition', 'disposition')
    ]

    df['dialogue_id'] = df['ID'].astype(str)
    grouped = df.groupby('dialogue_id')

    structured_data = []

    for dialogue_id, group in grouped:
        # Combine all dialogue turns (more robust than iloc[0])
        full_dialogue = ' '.join(group['dialogue'].tolist())

        # Combine section headers and texts into lists
        section_texts = group['section_text'].tolist()
        section_header = group['section_header'].tolist()

        # Set standard name value based on current section_header
        section_header_name =''
        for standard_name, header_name in SECTION_ORDER:
            if section_header[0].lower() == header_name:
                section_header_name = standard_name


        # Combine section texts to include section header
        combined_sections = []
        for text, header in zip(group['section_text'], group['section_header']):
            combined_sections.append(f"{header.strip()}: {text.strip()}")

        full_note = "\n".join(combined_sections)

        structured_data.append({
            'input_text': f"Generate a {section_header_name} clinical note: {full_dialogue}",
            'target_text': full_note,
            'section_header': section_header,
            'dialogue_id': dialogue_id
        })
    return pd.DataFrame(structured_data)

# Apply preprocessing
training_structured = preprocess_data(training_data)
print(training_structured.head())
validation_structured = preprocess_data(validation_data)

                                          input_text  \
0  Generate a history_of_present_illness clinical...   
1  Generate a history_of_present_illness clinical...   
2  Generate a history_of_present_illness clinical...   
3  Generate a past_medical_history clinical note:...   
4  Generate a past_medical_history clinical note:...   

                                         target_text   section_header  \
0  GENHX: The patient is a 76-year-old white fema...          [GENHX]   
1  GENHX: The patient is a 25-year-old right-hand...          [GENHX]   
2  GENHX: This 19-year-old Caucasian female prese...          [GENHX]   
3  PASTMEDICALHX: Significant for moderate to sev...  [PASTMEDICALHX]   
4                     PASTMEDICALHX: Nonsignificant.  [PASTMEDICALHX]   

  dialogue_id  
0           0  
1           1  
2          10  
3         100  
4        1000  


In [12]:
from torch.utils.data import Dataset, DataLoader

from torch.utils.data import Dataset

class DoctorPatientDataset(Dataset):
    def __init__(self, data, tokenizer, max_input_length=512, max_target_length=256):
        self.data = data
        self.tokenizer = tokenizer
        self.max_input_length = max_input_length
        self.max_target_length = max_target_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Access data using .iloc to ensure integer-based indexing
        item = self.data.iloc[idx]  # Use .iloc for integer-based indexing
        input_text = item['input_text']
        target_text = item['target_text']

        # Tokenize inputs
        input_encodings = self.tokenizer(
            input_text,
            max_length=self.max_input_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        # Tokenize targets
        target_encodings = self.tokenizer(
            target_text,
            max_length=self.max_target_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        # Replace padding token id with -100 for loss calculation
        labels = target_encodings['input_ids']
        labels[labels == self.tokenizer.pad_token_id] = -100

        return {
            'input_ids': input_encodings['input_ids'].flatten(),
            'attention_mask': input_encodings['attention_mask'].flatten(),
            'labels': labels.flatten()
        }

## Create Train/Val Tokenized Datasets

In [24]:

# Then create datasets
train_tokenized = DoctorPatientDataset(training_structured, tokenizer)
val_tokenized = DoctorPatientDataset(validation_structured, tokenizer)

## Init Model and Lora Config

In [23]:
from transformers import BartForConditionalGeneration
from peft import LoraConfig, get_peft_model, TaskType
from peft import LoraConfig, get_peft_model, TaskType
# Example data preparation

# Initialize model with LoRA
model = BartForConditionalGeneration.from_pretrained(model_name)
lora_config = LoraConfig(
    r=4,
    lora_alpha=32,
    # target_modules=["q", "v"],
    lora_dropout=0.10,
    bias="none",
    task_type=TaskType.SEQ_2_SEQ_LM
)
# Wrap model with LoRA
model = get_peft_model(model, lora_config)


In [22]:
from transformers import GenerationConfig

generation_config = GenerationConfig(
    temperature=0.9,
    top_k=50,
    top_p=0.95,
    do_sample=True,
    repetition_penalty=2.0,
    no_repeat_ngram_size=4,
    num_beams=1,
    max_length=128
)


# Define Training Args and Metrics

In [21]:
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer
import evaluate
from transformers import GenerationConfig
from transformers.trainer_utils import IntervalStrategy, SaveStrategy
rouge = evaluate.load("rouge")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred

    # Replace -100 with the pad token id for decoding labels
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

    # Ensure token IDs are within valid range
    vocab_size = len(tokenizer)
    predictions = np.where(
        (predictions >= 0) & (predictions < vocab_size),
        predictions,
        tokenizer.unk_token_id  # Replace out-of-range IDs with unknown token
    )

    # Decode predictions and labels
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Compute ROUGE scores
    result = rouge.compute(
        predictions=decoded_preds,
        references=decoded_labels,
        use_stemmer=True
    )

    return {k: round(v, 4) for k, v in result.items()}

training_args = Seq2SeqTrainingArguments(
    output_dir="./results",
    # Replace 'evaluation_strategy' with 'eval_strategy'
    eval_strategy=IntervalStrategy.EPOCH,
    save_strategy=IntervalStrategy.EPOCH,
    learning_rate=5e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=15,
    predict_with_generate=True,
    fp16=True,
    generation_max_length=128,
    report_to="none",
    load_best_model_at_end=True,
    logging_strategy ="epoch",
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_tokenized,
    eval_dataset=val_tokenized,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

  trainer = Seq2SeqTrainer(
No label_names provided for model class `PeftModelForSeq2SeqLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


## Train Model

In [25]:
trainer.train()
model.save_pretrained("./clinical_note_model")
tokenizer.save_pretrained("./clinical_note_model")

Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
1,2.0909,2.1171,0.4335,0.2072,0.3814,0.3802
2,2.0744,2.10812,0.4273,0.2044,0.376,0.3756
3,2.0734,2.10803,0.431,0.21,0.38,0.379
4,2.0758,2.119206,0.4149,0.2019,0.3671,0.366
5,2.0504,2.105944,0.4411,0.2224,0.3867,0.3853
6,2.0677,2.107839,0.4243,0.2041,0.3704,0.3698
7,2.0313,2.104823,0.4253,0.2029,0.373,0.3718
8,2.0428,2.102727,0.4394,0.2206,0.3864,0.3855
9,2.0598,2.094616,0.432,0.2146,0.3818,0.3807
10,2.0582,2.099574,0.4317,0.2079,0.3815,0.381


('./clinical_note_model/tokenizer_config.json',
 './clinical_note_model/special_tokens_map.json',
 './clinical_note_model/vocab.json',
 './clinical_note_model/merges.txt',
 './clinical_note_model/added_tokens.json')

# Generate Summary

In [20]:
from transformers import GenerationConfig

# Define your generation config once
generation_config = GenerationConfig(
    temperature=0.7,
    top_k=60,
    top_p=0.95,
    do_sample=True,
    repetition_penalty=2.4,
    no_repeat_ngram_size=2,
    num_beams=1,
    max_length=128  # You can adjust this
)

# Function to generate notes from dialogue
def generate_note(dialogue, section_header):
    # Inject the target header as a prompt prefix
    prompt = f"Generate detailed {section_header} clinical note: {dialogue}"

    inputs = tokenizer(
        prompt,
        max_length=128,
        padding='max_length',
        truncation=True,
        return_tensors="pt"
    ).to(model.device)

    outputs = model.generate(
        input_ids=inputs['input_ids'],
        attention_mask=inputs['attention_mask'],
        generation_config=generation_config  # ✅ This is where it goes
    )

    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# Example usage

# List of section headers necessary
options = ["CC", "GENHX", "PASTMEDICALHX", "DIAGNOSIS", "PLAN"]

for example in test_dataset:
    # if example['section_header'] in options:
      note = generate_note(example['dialogue'], example['section_header'])  # Access the 'dialogue' column
      print(example['section_header'])
      print(note)
      print(example["section_text"])



GENHX
GENHX: The patient is a 50-year-old African American female who was last seen on July 28, 2008.  She has been in the hospital since that time.
The patient is a 55-year-old African-American male that was last seen in clinic on 07/29/2008 with diagnosis of new onset seizures and an MRI scan, which demonstrated right contrast-enhancing temporal mass.  Given the characteristics of this mass and his new onset seizures, it is significantly concerning for a high-grade glioma.
FAM/SOCHX
FAM/SOCHX: Stroke.
Positive for stroke and sleep apnea.
ROS
ROS: No back pain.
MSK: Negative myalgia, negative joint pain, negative stiffness, negative weakness, negative back pain.
FAM/SOCHX
FAM/SOCHX: Noncontributory.
Noncontributory.
FAM/SOCHX
FAM/SOCHX: His father died of thoracic aortic aneurysm.  His mother died at age 70.
Father died of a thoracic aortic aneurysm, age 71. Mother died of stroke, age 81.
FAM/SOCHX
FAM/SOCHX: Negative.
Reviewed and remained unchanged.
GENHX
GENHX: The patient is a 7-y

KeyboardInterrupt: 