# Step 1: Set Up the Environment

In [1]:
# Confirm Colab Pro RAM and GPU
!free -h
!nvidia-smi

# Install required libraries
!pip install transformers datasets peft bitsandbytes torch accelerate rouge_score bert_score

# Mount Google Drive for saving models
from google.colab import drive
drive.mount('/content/drive')



               total        used        free      shared  buff/cache   available
Mem:            83Gi       965Mi        78Gi       1.0Mi       4.5Gi        81Gi
Swap:             0B          0B          0B
Sat Mar 22 19:35:39 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100-SXM4-40GB          Off |   00000000:00:04.0 Off |                    0 |
| N/A   31C    P0             57W /  400W |       0MiB /  40960MiB |      0%      Default |
|                        

In [None]:
# Log in to Hugging Face with a read token (since LLaMA 3 8B is gated)
from huggingface_hub import login
login(token="")  # Replace with your read token

# Step 2: Import Libraries

In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer, BitsAndBytesConfig
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from rouge_score import rouge_scorer
from bert_score import score
import re
import os
import pandas as pd
from transformers import pipeline
from rouge_score import rouge_scorer
from bert_score import score as bert_score
import numpy as np



# Step 3: Load and Preprocess the Dataset

In [None]:
# Load the full dataset
dataset = load_dataset("omi-health/medical-dialogue-to-soap-summary")
print("Original Training Split Size:", len(dataset["train"]))
print("Original Validation Split Size:", len(dataset["validation"]))
print("Original Test Split Size:", len(dataset["test"]))

# Preprocessing function
def preprocess_dialogue(example):
    dialogue = example["dialogue"]
    soap = example["soap"]

    # Clean text
    dialogue = re.sub(r'[^A-Za-z0-9\s.,:?-]', '', dialogue).lower()
    soap = re.sub(r'[^A-Za-z0-9\s.,:?-]', '', soap).lower()

    # Add role tags
    dialogue = dialogue.replace("Doctor:", "[Doctor]:")
    dialogue = dialogue.replace("Patient:", "[Patient]:")

    return {"dialogue": dialogue, "soap": soap}

# Apply preprocessing
processed_dataset = dataset.map(preprocess_dialogue)

Original Training Split Size: 9250
Original Validation Split Size: 500
Original Test Split Size: 250


# Step 4: Tokenize the Dataset

In [35]:
# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")

# Set padding token (LLaMA 3 might not have a pad token by default)
tokenizer.pad_token = tokenizer.eos_token

# Tokenization function with labels
def tokenize_function(example):
    tokenized = tokenizer(
        example["text"],
        truncation=True,
        padding="max_length",
        max_length=512,
        return_tensors="pt"
    )
    tokenized["labels"] = tokenized["input_ids"].clone()
    return tokenized

# Apply tokenization
tokenized_dataset = processed_dataset.map(tokenize_function, batched=True)

# Remove unnecessary columns and set format
tokenized_dataset = tokenized_dataset.remove_columns(["dialogue", "soap", "text"])
tokenized_dataset.set_format("torch")

# Split into train and eval datasets
train_dataset = tokenized_dataset["train"]
eval_dataset = tokenized_dataset["validation"]

# Debug: Verify 'labels' key
print("Sample from train_dataset:", train_dataset[0].keys())
print("Sample 'labels' shape:", train_dataset[0]["labels"].shape)
print("Sample from eval_dataset:", eval_dataset[0].keys())
print("Sample 'labels' shape (eval):", eval_dataset[0]["labels"].shape)

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

Sample from train_dataset: dict_keys(['prompt', 'messages', 'messages_nosystem', 'input_ids', 'attention_mask', 'labels'])
Sample 'labels' shape: torch.Size([512])
Sample from eval_dataset: dict_keys(['prompt', 'messages', 'messages_nosystem', 'input_ids', 'attention_mask', 'labels'])
Sample 'labels' shape (eval): torch.Size([512])


# Step 5: Load the Model with QLoRA

In [36]:
# Quantization config for QLoRA
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype="float16"
)

# Load the model
model_name = "meta-llama/Meta-Llama-3-8B"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=quant_config,
    device_map="auto",
    torch_dtype=torch.float16
)

# LoRA config
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

# Apply LoRA
model = get_peft_model(model, lora_config)

# Print model to verify
print("Model Loaded with PEFT:", model)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Model Loaded with PEFT: PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): LlamaForCausalLM(
      (model): LlamaModel(
        (embed_tokens): Embedding(128256, 4096)
        (layers): ModuleList(
          (0-31): 32 x LlamaDecoderLayer(
            (self_attn): LlamaAttention(
              (q_proj): lora.Linear4bit(
                (base_layer): Linear4bit(in_features=4096, out_features=4096, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=16, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=16, out_features=4096, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
   

# Step 6: Fine-Tune the Model

In [None]:
# Clear GPU memory
torch.cuda.empty_cache()

In [None]:


# Set PyTorch memory management
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Training arguments
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=5,  # Increased from 3 to 5
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=10,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    report_to="none"
)

# Initialize the Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
)

# Train the model
trainer.train()

# Save the model (optional, for later use)
model.save_pretrained("/content/drive/MyDrive/fine_tuned_model")
tokenizer.save_pretrained("/content/drive/MyDrive/fine_tuned_model")

  trainer = Trainer(
No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


Epoch,Training Loss,Validation Loss
1,1.3675,1.344921
2,1.3222,1.328012
3,1.2905,1.321401
4,1.2795,1.321004
5,1.2421,1.322907


('/content/drive/MyDrive/fine_tuned_model/tokenizer_config.json',
 '/content/drive/MyDrive/fine_tuned_model/special_tokens_map.json',
 '/content/drive/MyDrive/fine_tuned_model/tokenizer.json')

# Step 7: Clear Memory After Training

In [41]:
# Clear memory after training
del trainer, model
torch.cuda.empty_cache()

# Step 8: Load Model for Evaluation

In [None]:
# Load the quantization config

quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype="float16"
)

# Load the fine-tuned model and tokenizer
model = AutoModelForCausalLM.from_pretrained(
    "/content/drive/MyDrive/fine_tuned_model",
    quantization_config=quant_config,
    device_map="auto",
    torch_dtype=torch.float16
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
tokenizer.pad_token = tokenizer.eos_token

# Load the test split
dataset = load_dataset("omi-health/medical-dialogue-to-soap-summary", split="test")

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

# Step 9: Generate Predictions for Evaluation

In [55]:
# Generate predictions
generated_soaps = []
reference_soaps = []

num_samples = len(dataset)
for i in range(num_samples):
    # Extract the dialogue for this sample
    dialogue_input = dataset[i]["dialogue"]

    # Format the input prompt
    prompt = f"""
You are a medical assistant AI designed to generate SOAP notes (Subjective, Objective, Assessment, and Plan) from medical dialogues.

A SOAP note has four sections:
Subjective (S): Patient's reported symptoms, history, and concerns.
Objective (O): Measurable or observed data (e.g., vitals, exam findings, labs).
Assessment (A): Diagnosis or clinical impression.
Plan (P): Treatment plan (e.g., tests, medications, follow-up).

Now, carefully read the following medical dialogue and create a detailed SOAP note in this format:

Dialogue:
{dialogue_input}

SOAP Note:
"""

    # Use the prompt as the input text
    input_text = prompt

    # Tokenize the input
    inputs = tokenizer(
        input_text,
        return_tensors="pt",
        truncation=True,
        max_length=1024
    ).to("cuda")
    generate_inputs = {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"]
    }

    # Generate with adjusted parameters
    outputs = model.generate(
        **generate_inputs,
        max_new_tokens=800,
        pad_token_id=tokenizer.eos_token_id,
        temperature=0.1,  # Reduced for less hallucination
        top_p=0.95,  # Adjusted for more diversity
        repetition_penalty=1.2,
        do_sample=True
    )
    generated_soap = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Extract the generated SOAP note content
    if "SOAP Note:" in generated_soap:
        generated_soap_content = generated_soap.split("SOAP Note:")[-1].strip()
    else:
        generated_soap_content = "No SOAP note generated."

    # Print only the sample identifier and the raw SOAP note content
    print(f"Sample {i}\n{generated_soap_content}\n")

    # Store the generated and reference SOAP notes
    generated_soaps.append(generated_soap_content)
    reference_soaps.append(dataset[i]["soap"])

    # Clear memory
    del inputs, outputs
    torch.cuda.empty_cache()

Sample 0
Subjective: The patient reports painless blurry vision in the right eye for one week, along with intermittent fever, headaches, body aches, and a non-pruritic maculopapular rash on both lower limbs present for six months. There is no associated neck stiffness, nausea, vomiting, Raynauds phenomenon, oral ulcers, chest pain, dyspnea, abdominal pain, or photosensitivity. Past medical history includes occasional episodes of left knee and testicle swelling but no known exposure to toxins or unhealthy lifestyle practices such as smoking, alcohol consumption, or illegal drugs. Currently employed as a flooring installer.
Objective: Vital signs within normal limits. Physical examination revealed bilateral papilledema and optic nerve erythema more pronounced in the right eye compared to the left, accompanied by a right inferior nasal quadrant visual field deficit and a relative afferent pupillary defect. Muscle tone and deep tendon reflexes were unremarkable; sensory testing indicated i

# Step 10: Evaluation

In [68]:
# Install dependencies for ROUGE, BERTScore, and BLEURT
!pip install rouge-score bert-score



In [None]:

# Initialize lists to store scores
rouge1_scores = []
rouge2_scores = []
rougel_scores = []
bertscore_f1_scores = []

# Initialize ROUGE scorer
rouge_scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)

# Compute metrics for each sample
for i in range(len(generated_soaps)):
    gen_soap = generated_soaps[i]
    ref_soap = reference_soaps[i]

    # Handle empty or invalid SOAP notes
    if not gen_soap or gen_soap == "No SOAP note generated." or not ref_soap:
        print(f"Skipping Sample {i}: Invalid or empty SOAP note.")
        rouge1_scores.append(0.0)
        rouge2_scores.append(0.0)
        rougel_scores.append(0.0)
        bertscore_f1_scores.append(0.0)
        continue

    # ROUGE Scores
    rouge_scores = rouge_scorer.score(ref_soap, gen_soap)
    rouge1_scores.append(rouge_scores['rouge1'].fmeasure)
    rouge2_scores.append(rouge_scores['rouge2'].fmeasure)
    rougel_scores.append(rouge_scores['rougeL'].fmeasure)

    # BERTScore with a different model to avoid warnings
    P, R, F1 = bert_score([gen_soap], [ref_soap], lang="en", model_type="bert-base-uncased", verbose=False)
    bertscore_f1_scores.append(F1.item())

# Compute averages
avg_rouge1 = np.mean(rouge1_scores)
avg_rouge2 = np.mean(rouge2_scores)
avg_rougel = np.mean(rougel_scores)
avg_bertscore_f1 = np.mean(bertscore_f1_scores)

# Print results
print("Automated Metrics Across All Samples:")
print(f"Average ROUGE-1 F1: {avg_rouge1:.4f}")
print(f"Average ROUGE-2 F1: {avg_rouge2:.4f}")
print(f"Average ROUGE-L F1: {avg_rougel:.4f}")
print(f"Average BERTScore F1: {avg_bertscore_f1:.4f}")

Automated Metrics Across All Samples:
Average ROUGE-1 F1: 0.4703
Average ROUGE-2 F1: 0.1897
Average ROUGE-L F1: 0.3008
Average BERTScore F1: 0.6934


# SOAP 2 Summarized and Explainable Report

In [77]:
# Clear memory after training
torch.cuda.empty_cache()

In [None]:

# Load the Hugging Face model
model_name = "Remiscus/MediGen"  # Replace with your preferred model if needed
generator = pipeline("text2text-generation", model=model_name)

# Function to create a prompt for summarization
def create_prompt(soap_note):
    return f"""
You are a medical assistant. Your task is to take a SOAP (Subjective, Objective, Assessment, Plan) report and generate a summarized, easy-to-understand explanation. The output should be clear and accessible to anyone without a medical background, avoiding jargon unless fully explained.

SOAP Report:
{soap_note}

Summarized and Explainable Report:
"""

# Function to generate the summary
def generate_summary(prompt):
    response = generator(prompt, max_length=500, num_return_sequences=1, temperature=0.01)
    return response[0]['generated_text']

# Read the CSV file
csv_file_path = '/content/output.csv'  # Replace with your actual CSV file path
df = pd.read_csv(csv_file_path)
df = df.head(5)
# Ensure the column name matches your CSV
if 'SOAP' not in df.columns:
    raise ValueError("The CSV file must contain a 'SOAP' column")

# Process each row
for index, row in df.iterrows():
    soap_note = row['SOAP']
    prompt = create_prompt(soap_note)
    simplified = generate_summary(prompt)

    print(f"\nRow {index + 1} - Original SOAP Note:\n", soap_note)
    print(f"\nRow {index + 1} - Summarized and Explainable Report:\n", simplified)
    print("=" * 100)


`low_cpu_mem_usage` was None, now default to True since model is quantized.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Device set to use cuda:0
The model 'LlamaForCausalLM' is not supported for text2text-generation. Supported models are ['BartForConditionalGeneration', 'BigBirdPegasusForConditionalGeneration', 'BlenderbotForConditionalGeneration', 'BlenderbotSmallForConditionalGeneration', 'EncoderDecoderModel', 'FSMTForConditionalGeneration', 'GPTSanJapaneseForConditionalGeneration', 'LEDForConditionalGeneration', 'LongT5ForConditionalGeneration', 'M2M100ForConditionalGeneration', 'MarianMTModel', 'MBartForConditionalGeneration', 'MT5ForConditionalGeneration', 'MvpForConditionalGeneration', 'NllbMoeForConditionalGeneration', 'PegasusForConditionalGeneration', 'PegasusXForConditionalGeneration', 'PLBartForConditionalGeneration', 'ProphetNetForConditionalGeneration', 'Qwen2AudioForConditionalGeneration', 'SeamlessM4TForTextToText', 'SeamlessM4Tv2ForTextToText', 'SwitchTransformersForConditionalGeneration', 'T5ForConditionalGeneration', 'UMT5ForConditionalGeneration', 'XLMProphetNetForConditionalGenerati


Row 1 - Original SOAP Note:
 Subjective: The patient reports painless blurry vision in the right eye for one week, along with intermittent fever, headaches, body aches, and a non-pruritic maculopapular rash on both lower limbs present for six months. There is no associated neck stiffness, nausea, vomiting, Raynauds phenomenon, oral ulcers, chest pain, dyspnea, abdominal pain, or photosensitivity. Past medical history includes occasional episodes of left knee and testicle swelling but no known exposure to toxins or unhealthy lifestyle practices such as smoking, alcohol consumption, or illegal drugs. Currently employed as a flooring installer.
Objective: Vital signs within normal limits. Physical examination revealed bilateral papilledema and optic nerve erythema more pronounced in the right eye compared to the left, accompanied by a right inferior nasal quadrant visual field deficit and a relative afferent pupillary defect. Muscle tone and deep tendon reflexes were unremarkable; sensor

In [None]:

# Load a summarization model (BART is well-suited for this task)
model_name = "facebook/bart-large-cnn"
generator = pipeline("summarization", model=model_name)

# Function to create a prompt for summarization
def create_prompt(soap_note):
    return f"""
Summarize the following SOAP (Subjective, Objective, Assessment, Plan) report into a concise, easy-to-understand explanation for someone without a medical background. Avoid technical terms unless they’re explained simply. Focus on what the patient feels, what the doctor found, what might be wrong, and what will happen next.

SOAP Report:
{soap_note}

Summarized and Explainable Report:
"""

# Function to generate the summary
def generate_summary(soap_note):
    try:
        # Use the SOAP note directly as input for summarization, with a custom prompt
        response = generator(soap_note, max_length=150, min_length=50, do_sample=False)
        summary = response[0]['summary_text']
        # Post-process to ensure it’s clear and fits the "explainable" goal
        return summary.replace("SOAP", "").strip()
    except Exception as e:
        return f"Error generating summary: {str(e)}"

# Read the CSV file
csv_file_path = '/content/output.csv'  # Replace with your actual CSV file path
try:
    df = pd.read_csv(csv_file_path)
    df = df.head(5)  # Limit to first 5 rows for testing
except FileNotFoundError:
    print(f"Error: File '{csv_file_path}' not found.")
    exit()

# Ensure the column name matches your CSV
if 'SOAP' not in df.columns:
    raise ValueError("The CSV file must contain a 'SOAP' column")

# List to store results
results = []

# Process each row
for index, row in df.iterrows():
    soap_note = row['SOAP']
    simplified = generate_summary(soap_note)

    # Store results
    results.append({"Row": index + 1, "Original SOAP Note": soap_note, "Simplified Report": simplified})

    # Print with improved formatting
    print(f"\n{'*' * 50} Row {index + 1} {'*' * 50}")
    print(f"Original Note:\n{soap_note}\n")
    print(f"Simplified Explanation:\n{simplified}")
    print(f"{'*' * 110}")

# Save results to a new CSV file
output_df = pd.DataFrame(results)
output_df.to_csv('simplified_soap_reports.csv', index=False)
print("\nResults saved to 'simplified_soap_reports.csv'")


config.json:   0%|          | 0.00/1.58k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.63G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/363 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Device set to use cuda:0



************************************************** Row 1 **************************************************
Original Note:
Subjective: The patient reports painless blurry vision in the right eye for one week, along with intermittent fever, headaches, body aches, and a non-pruritic maculopapular rash on both lower limbs present for six months. There is no associated neck stiffness, nausea, vomiting, Raynauds phenomenon, oral ulcers, chest pain, dyspnea, abdominal pain, or photosensitivity. Past medical history includes occasional episodes of left knee and testicle swelling but no known exposure to toxins or unhealthy lifestyle practices such as smoking, alcohol consumption, or illegal drugs. Currently employed as a flooring installer.
Objective: Vital signs within normal limits. Physical examination revealed bilateral papilledema and optic nerve erythema more pronounced in the right eye compared to the left, accompanied by a right inferior nasal quadrant visual field deficit and a rela