<a href="https://colab.research.google.com/github/Michael-David-Lam/Medical-Dialogue-Summary/blob/dev3michael/Medical_Dialogue_Generation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Install Dependencies

In [24]:
!pip install kaggle
!pip install -U transformers
!pip install -U datasets
!pip install -U accelerate
!pip install -U evaluate
!pip install -U rouge_score
!pip install -U peft
!pip install sentencepiece
!mkdir ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json


mkdir: cannot create directory ‘/root/.kaggle’: File exists
cp: cannot stat 'kaggle.json': No such file or directory
chmod: cannot access '/root/.kaggle/kaggle.json': No such file or directory


# Import Dataset GitHub Repo

In [36]:
import kagglehub
import pandas as pd
import re
import numpy as np
!git clone https://github.com/abachaa/MTS-Dialog.git
# Load data
training_data =pd.read_csv('/content/MTS-Dialog/Main-Dataset/MTS-Dialog-TrainingSet.csv')
validation_data = pd.read_csv('/content/MTS-Dialog/Main-Dataset/MTS-Dialog-ValidationSet.csv')
Test_data = pd.read_csv('/content/MTS-Dialog/Main-Dataset/MTS-Dialog-TestSet-1-MEDIQA-Chat-2023.csv')
# Rename columns
training_data = training_data.rename(columns={'context': 'input_text', 'target': 'target_text'})

from datasets import Dataset
train_dataset = Dataset.from_pandas(training_data)
val_dataset = Dataset.from_pandas(validation_data)
test_dataset = Dataset.from_pandas(Test_data)


fatal: destination path 'MTS-Dialog' already exists and is not an empty directory.


# Define Model and Preprocess Data

In [37]:
from transformers import BartTokenizer, BartModel

model_name = "facebook/bart-base"
tokenizer = BartTokenizer.from_pretrained(model_name)

def preprocess_data(df):
    # Define section mapping and ordering
    SECTION_ORDER = [
        ('chief_complaint', 'cc'),
        ('history_of_present_illness', 'genhx'),
        ('past_medical_history', 'pastmedicalhx'),
        ('past_surgeries', 'pastsurgical'),
        ('medications', 'medications'),
        ('allergies', 'allergy'),
        ('social_history', 'fam/sochx'),
        ('educational_courses', 'edcourse'),
        ('review_of_systems', 'ros'),
        ('physical_exam', 'exam'),
        ('assessment', 'assessment'),
        ('exam','exam'),
        ('procedures','procedures'),
        ('labs','labs'),
        ('plan', 'plan'),
        ('disposition', 'disposition')
    ]

    df['dialogue_id'] = df['ID'].astype(str)
    grouped = df.groupby('dialogue_id')

    structured_data = []

    for dialogue_id, group in grouped:
        # Combine all dialogue turns (more robust than iloc[0])
        full_dialogue = ' '.join(group['dialogue'].tolist())

        # Combine section headers and texts into lists
        section_texts = group['section_text'].tolist()
        section_header = group['section_header'].tolist()

        # Set standard name value based on current section_header
        section_header_name =''
        for standard_name, header_name in SECTION_ORDER:
            if section_header[0].lower() == header_name:
                section_header_name = standard_name


        # Combine section text to include section header's proper name
        combined_sections = []
        for text, header in zip(group['section_text'], group['section_header']):
            full_name = section_header_name.replace('_', ' ')   # Remove '_' from section header names
            combined_sections.append(f"{full_name.strip()}: {text.strip()}")

        full_note = "\n".join(combined_sections)

        structured_data.append({
            'input_text': f"Summarize the following doctor-patient dialogue into a detailed {full_name} clinical note: {full_dialogue}",
            'target_text': full_note,
            'section_header': section_header[0],
            'dialogue_id': dialogue_id
        })
    return pd.DataFrame(structured_data)

# Apply preprocessing
training_structured = preprocess_data(training_data)
print(training_structured.head())
validation_structured = preprocess_data(validation_data)

                                          input_text  \
0  Summarize the following doctor-patient dialogu...   
1  Summarize the following doctor-patient dialogu...   
2  Summarize the following doctor-patient dialogu...   
3  Summarize the following doctor-patient dialogu...   
4  Summarize the following doctor-patient dialogu...   

                                         target_text section_header  \
0  history of present illness: The patient is a 7...          GENHX   
1  history of present illness: The patient is a 2...          GENHX   
2  history of present illness: This 19-year-old C...          GENHX   
3  past medical history: Significant for moderate...  PASTMEDICALHX   
4              past medical history: Nonsignificant.  PASTMEDICALHX   

  dialogue_id  
0           0  
1           1  
2          10  
3         100  
4        1000  


In [38]:
from torch.utils.data import Dataset, DataLoader

from torch.utils.data import Dataset

class DoctorPatientDataset(Dataset):
    def __init__(self, data, tokenizer, max_input_length=512, max_target_length=256):
        self.data = data
        self.tokenizer = tokenizer
        self.max_input_length = max_input_length
        self.max_target_length = max_target_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Access data using .iloc to ensure integer-based indexing
        item = self.data.iloc[idx]  # Use .iloc for integer-based indexing
        input_text = item['input_text']
        target_text = item['target_text']

        # Tokenize inputs
        input_encodings = self.tokenizer(
            input_text,
            max_length=self.max_input_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        # Tokenize targets
        target_encodings = self.tokenizer(
            target_text,
            max_length=self.max_target_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        # Replace padding token id with -100 for loss calculation
        labels = target_encodings['input_ids']
        labels[labels == self.tokenizer.pad_token_id] = -100

        return {
            'input_ids': input_encodings['input_ids'].flatten(),
            'attention_mask': input_encodings['attention_mask'].flatten(),
            'labels': labels.flatten()
        }

## Create Train/Val Tokenized Datasets

In [39]:

# Then create datasets
train_tokenized = DoctorPatientDataset(training_structured, tokenizer)
val_tokenized = DoctorPatientDataset(validation_structured, tokenizer)

## Init Model and Lora Config

In [40]:
from transformers import BartForConditionalGeneration
from peft import LoraConfig, get_peft_model, TaskType
from peft import LoraConfig, get_peft_model, TaskType
# Example data preparation

# Initialize model with LoRA
model = BartForConditionalGeneration.from_pretrained(model_name)
#best with r=8 and a=64
lora_config = LoraConfig(
    r=8,
    lora_alpha=64,
    lora_dropout=0.01,
    bias="none",
    task_type=TaskType.SEQ_2_SEQ_LM
)
# Wrap model with LoRA
model = get_peft_model(model, lora_config)


# Define Training Args and Metrics

In [41]:
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer
import evaluate
from transformers import GenerationConfig
from transformers.trainer_utils import IntervalStrategy, SaveStrategy
rouge = evaluate.load("rouge")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred

    # Replace -100 with the pad token id for decoding labels
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

    # Ensure token IDs are within valid range
    vocab_size = len(tokenizer)
    predictions = np.where(
        (predictions >= 0) & (predictions < vocab_size),
        predictions,
        tokenizer.unk_token_id  # Replace out-of-range IDs with unknown token
    )

    # Decode predictions and labels
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Compute ROUGE scores
    result = rouge.compute(
        predictions=decoded_preds,
        references=decoded_labels,
        use_stemmer=True
    )

    return {k: round(v, 4) for k, v in result.items()}

training_args = Seq2SeqTrainingArguments(
    output_dir="./results",
    eval_strategy=IntervalStrategy.EPOCH,
    save_strategy=IntervalStrategy.EPOCH,
    learning_rate=2e-4,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=20,
    predict_with_generate=True,
    fp16=True,
    generation_max_length=128,
    report_to="none",
    load_best_model_at_end=True,
    logging_strategy ="epoch",
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_tokenized,
    eval_dataset=val_tokenized,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

  trainer = Seq2SeqTrainer(
No label_names provided for model class `PeftModelForSeq2SeqLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


## Train Model

In [42]:
trainer.train()
model.save_pretrained("./clinical_note_model")
tokenizer.save_pretrained("./clinical_note_model")

Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
1,2.9977,2.389215,0.4,0.2019,0.3591,0.3598
2,2.5846,2.31234,0.4087,0.2121,0.3702,0.3697
3,2.4985,2.279534,0.4319,0.2331,0.3875,0.3891
4,2.441,2.237248,0.4409,0.2357,0.3915,0.3913
5,2.3628,2.211381,0.4552,0.2511,0.4013,0.4004
6,2.3236,2.180698,0.4578,0.2432,0.399,0.3988
7,2.2685,2.166882,0.4483,0.2423,0.3944,0.394
8,2.2507,2.165949,0.4536,0.2584,0.4012,0.4009
9,2.2275,2.140306,0.4702,0.2699,0.4189,0.4183
10,2.2109,2.140665,0.4523,0.2565,0.4059,0.4048


('./clinical_note_model/tokenizer_config.json',
 './clinical_note_model/special_tokens_map.json',
 './clinical_note_model/vocab.json',
 './clinical_note_model/merges.txt',
 './clinical_note_model/added_tokens.json')

# Generate Summary

In [46]:
from transformers import GenerationConfig

generation_config = GenerationConfig(
    temperature=0.9,
    top_k=50,
    top_p=0.95,
    do_sample=False,
    repetition_penalty=2.0,
    no_repeat_ngram_size=4,
    num_beams=4,
    max_length=256
)

# Function to generate notes from dialogue
def generate_note(dialogue, section_header):
    # Inject the target header as a prompt prefix
    prompt = f"Summarize the following doctor-patient dialogue into a detailed {section_header} clinical note: {dialogue}"

    inputs = tokenizer(
        prompt,
        max_length=128,
        padding='max_length',
        truncation=True,
        return_tensors="pt"
    ).to(model.device)

    outputs = model.generate(
        input_ids=inputs['input_ids'],
        attention_mask=inputs['attention_mask'],
        generation_config=generation_config
    )

    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# Example usage
SECTION_NAME_MAP = {
    'CC': 'Chief Complaint',
    'GENHX': 'History of Present Illness',
    'PASTMEDICALHX': 'Past Medical History',
    'PASTSURGICAL': 'Past Surgeries',
    'MEDICATIONS': 'Medications',
    'ALLERGY': 'Allergies',
    'FAM/SOCHX': 'Social History',
    'EDCOURSE': 'Educational Courses',
    'ROS': 'Review of Systems',
    'EXAM': 'Physical Exam',
    'ASSESSMENT': 'Assessment',
    'PROCEDURES': 'Procedures',
    'LABS': 'Labs',
    'PLAN': 'Plan',
    'DISPOSITION': 'Disposition'
}

# List of section headers necessary
options = ["CC", "GENHX", "PASTMEDICALHX", "DIAGNOSIS", "PLAN"]

for example in test_dataset:
    # if example['section_header'] in options:
      note = generate_note(example['dialogue'], SECTION_NAME_MAP.get(example['section_header'])) # Pass dialogue column and section header name
      print(example['section_header'])
      print(note)
      print(example["section_text"])



GENHX
history of Present Illness: The patient is a 50-year-old African American woman who presents with history of present illness.  She has been in the hospital since 07/09/08.
The patient is a 55-year-old African-American male that was last seen in clinic on 07/29/2008 with diagnosis of new onset seizures and an MRI scan, which demonstrated right contrast-enhancing temporal mass.  Given the characteristics of this mass and his new onset seizures, it is significantly concerning for a high-grade glioma.
FAM/SOCHX
social History: Noncontributory.
Positive for stroke and sleep apnea.
ROS
review of Systems: No joint pain, stiffness, weakness, or back pain.
MSK: Negative myalgia, negative joint pain, negative stiffness, negative weakness, negative back pain.
FAM/SOCHX
social History: Noncontributory.
Noncontributory.
FAM/SOCHX
social History: Father died of thoracic aortic aneurysm. Mother died of stroke.
Father died of a thoracic aortic aneurysm, age 71. Mother died of stroke, age 81.
FAM