**Please run this notebook on Kaggle, where you get a free T4 × 2 GPU runtime.**

See my project on GitHub: [Fine-Tuning LLM for Medical Chat Summarization](https://github.com/aslamsikder/Fine-Tuning-LLM-for-Medical-Chat-Summarization)  
See my Kaggle Notebook: [Fine-Tuning LLM for Medical Chat Summarization](https://www.kaggle.com/code/aslamsikder/lora-fine-tuning-gemma-2b-medical-summarization?scriptVersionId=270967244)  

✍️ Author Information
Developed by **Aslam Sikder**, October 2025  
Email: [aslamsikder.edu@gmail.com](mailto:aslamsikder.edu@gmail.com)  
LinkedIn: [Aslam Sikder - LinkedIn](https://www.linkedin.com/in/aslamsikder)  
Google Scholar: [Aslam Sikder - Google Scholar](https://scholar.google.com/citations?hl=en&user=Ip1qQi8AAAAJ)

**Cell 1: Project Setup and Installation**

This first step sets up the environment by installing the required Python libraries. The most important one is unsloth, which is designed to make fine-tuning large language models (LLMs) faster and more memory-efficient. It is what makes this project feasible on a free, resource-constrained GPU (a single T4 in Google Colab, or T4 × 2 on Kaggle), where limited VRAM would otherwise make standard fine-tuning impractical. We also install trl for its supervised fine-tuning tools, peft for parameter-efficient fine-tuning, and other standard libraries from the Hugging Face ecosystem.

In [None]:
# Cell 1
# Install required libraries
# We use unsloth for memory-efficient and faster fine-tuning of LLMs.
!pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git" -q
!pip install "trl<0.9.0" peft accelerate bitsandbytes -q
!pip install datasets evaluate bert_score rouge_score -q

In [None]:
# Import required libraries
import torch
from unsloth import FastLanguageModel
from datasets import load_dataset
from transformers import TrainingArguments
from trl import SFTTrainer
from huggingface_hub import login
from kaggle_secrets import UserSecretsClient
from flask import Flask, request, jsonify
from tqdm import tqdm
import evaluate
import shutil
import os
import re 

**Cell 2: Hugging Face Authentication**

To download the model and dataset, we authenticate with Hugging Face. This step uses the login function to connect to your Hugging Face account. Generate an access token in your Hugging Face account settings (a "read" token is enough for downloads; "write" is needed only if you also push the fine-tuned model to the Hub) and save it as a secret named HF_TOKEN in your environment (in Colab/Kaggle, this is done via the "Secrets" tab). This keeps your credentials out of the code.

In [None]:
# Cell 2
# Hugging Face login (for Kaggle)

# Authenticate with Hugging Face to access the Gemma 2B model
# Ensure you have saved your Hugging Face token as 'HF_TOKEN' in Kaggle secrets
try:
    secret_label = "HF_TOKEN"  # This should match the name of your secret in Kaggle
    hf_token = UserSecretsClient().get_secret(secret_label)
    
    # Log in securely to Hugging Face
    login(token=hf_token)
    print("Successfully logged into Hugging Face.")
except Exception as e:
    print(f"Could not log in. Please ensure 'HF_TOKEN' is set in Kaggle secrets. Error: {e}")


**Cell 3: Load Model and Tokenizer**

Here, we load the Gemma 2B model and its corresponding tokenizer. We use unsloth's **FastLanguageModel** class, which is a highly optimized wrapper around the standard Hugging Face model class.

Key parameters used:
*   model_name = "unsloth/gemma-2b-it-bnb-4bit": We use an instruction-tuned version of Gemma 2B that has been pre-quantized to 4-bit precision and optimized for Unsloth. This dramatically reduces the memory required to load the model.
*   load_in_4bit = True: This argument explicitly tells the model to load the weights in 4-bit precision, which is the core of our memory-saving strategy.
*   max_seq_length = 2048: We set a maximum sequence length to balance context understanding with memory constraints.

With 4-bit weights, a model with roughly two billion parameters loads comfortably within the 16 GB of VRAM on a single T4, leaving room for activations, LoRA adapters, and optimizer state during training.
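As a rough back-of-the-envelope check on the memory saving, the sketch below estimates the size of the weights alone at 16-bit and 4-bit precision. The parameter count is an assumed, approximate figure used only for illustration, and the estimate ignores quantization constants, layers kept in higher precision, and activations.

```python
def weight_memory_gb(n_params: int, bits_per_param: int) -> float:
    """Approximate memory for the weights alone, in gigabytes (1 GB = 1e9 bytes)."""
    return n_params * bits_per_param / 8 / 1e9

# Assumed, approximate parameter count for a "2B"-class model (illustrative only)
N_PARAMS = 2_500_000_000

for bits in (16, 4):
    print(f"{bits}-bit weights: ~{weight_memory_gb(N_PARAMS, bits):.2f} GB")
```

Actual usage on the GPU will be higher than these figures, since training also needs memory for activations, adapters, and optimizer state.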

In [None]:
# Cell 3

# Define model loading parameters
max_seq_length = 2048  # Maximum sequence length
dtype = None           # Unsloth will automatically choose the best dtype (bf16/fp16)
load_in_4bit = True    # Load in 4-bit quantized precision for memory efficiency

print("Loading Gemma 2B model and tokenizer with Unsloth...")

# Load the pre-quantized Gemma 2B instruction-tuned model
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/gemma-2b-it-bnb-4bit",  # Instruction-tuned, 4-bit version
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=load_in_4bit,
    trust_remote_code=True,  # Necessary for custom model implementations
)

# Verification
print("Model and tokenizer loaded successfully.")
print(f"Model Type: {model.config.model_type}")
print(f"Max Seq Length: {max_seq_length}")


**Cell 4: Load and Explore the Dataset**

We load our medical dialogue dataset directly from the Hugging Face Hub. The dataset is conveniently split into train, validation, and test sets. We will use the datasets library for this. After loading, we'll inspect the data to understand its structure, confirming the presence of the dialogue and soap columns which will serve as our input and target, respectively.

In [None]:
# Cell 4


# Load the medical dialogue dataset from Hugging Face
dataset_name = "omi-health/medical-dialogue-to-soap-summary"

print(f"Loading dataset: {dataset_name} ...")

# Load train, validation, and test splits
train_dataset = load_dataset(dataset_name, split="train")
validation_dataset = load_dataset(dataset_name, split="validation")
test_dataset = load_dataset(dataset_name, split="test")

# Display dataset information
print("\nDataset loaded successfully!")
print(f"Train samples: {len(train_dataset)}")
print(f"Validation samples: {len(validation_dataset)}")
print(f"Test samples: {len(test_dataset)}")

# Display column names
print("\nAvailable columns:", train_dataset.column_names)

# Show one formatted example (to verify fields before preprocessing)
print("\nExample sample from training set:")
print(train_dataset[0])


**Cell 5: Data Preprocessing and Prompt Formatting**

Instruction-tuned models like Gemma 2B perform best when the input data is formatted as a clear instruction. We define a chat template that follows Gemma's turn format, which separates the user's request from the model's response. Each example is converted into a prompt where the user instructs the model to summarize a given medical dialogue into a SOAP note, and the model turn contains the corresponding reference summary. This way the fine-tuned model sees the same structure in training that it will see at inference time. We define a function format_chat_template to apply this transformation and then use the .map() method to process all splits of our dataset.
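To make the target format concrete, here is a minimal sketch that builds the same turn structure with plain string formatting, without a tokenizer. The helper name render_gemma_turns and the example dialogue are hypothetical and are not taken from the dataset.

```python
def render_gemma_turns(dialogue: str, soap: str) -> str:
    """Render one training example as a user turn followed by a model turn."""
    instruction = "Summarize the following medical dialogue into a SOAP note:\n\n"
    user_turn = f"<bos><start_of_turn>user\n{instruction}{dialogue}<end_of_turn>\n"
    model_turn = f"<start_of_turn>model\n{soap}<end_of_turn>"
    return user_turn + model_turn

# Hypothetical, shortened example (not from the dataset)
example = render_gemma_turns(
    dialogue="Doctor: What brings you in today?\nPatient: I have had a cough for a week.",
    soap="S: Patient reports a one-week history of cough.",
)
print(example)
```

Printing the result shows where each turn starts and ends, which is useful when checking the output of the real template.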

In [None]:
# Cell 5

# Define a Gemma-style chat template once, outside the mapping function,
# so it is not re-assigned for every row
tokenizer.chat_template = (
    "{% for message in messages %}"
    "{% if message['role'] == 'user' %}"
    "{{ '<bos><start_of_turn>user\n' + message['content'] + '<end_of_turn>\n' }}"
    "{% elif message['role'] == 'model' %}"
    "{{ '<start_of_turn>model\n' + message['content'] + '<end_of_turn>' }}"
    "{% endif %}"
    "{% endfor %}"
)

def format_chat_template(row):
    # Build a single-turn conversation: user instruction + reference SOAP note
    messages = [
        {
            "role": "user",
            "content": f"Summarize the following medical dialogue into a SOAP note:\n\n{row['dialogue']}"
        },
        {
            "role": "model",
            "content": f"{row['soap']}"
        }
    ]

    # Apply template without tokenizing
    return {"text": tokenizer.apply_chat_template(messages, tokenize=False)}

# Apply to all splits
train_dataset_formatted = train_dataset.map(format_chat_template)
validation_dataset_formatted = validation_dataset.map(format_chat_template)
test_dataset_formatted = test_dataset.map(format_chat_template)

# Display a formatted example
print("\nFormatted prompt example:")
print(train_dataset_formatted[0]["text"])


**Cell 6: Configure LoRA (Parameter-Efficient Fine-Tuning)**

In this cell, we prepare the Gemma 2B model for LoRA-based fine-tuning. LoRA (Low-Rank Adaptation) allows us to update only a small subset of parameters instead of the full model, making fine-tuning memory- and compute-efficient, especially for large models like Gemma 2B.

* r = 16: The rank of the low-rank matrices injected into the model. Higher rank allows more capacity for adaptation but increases memory usage.

* lora_alpha = 32: Scaling factor that balances the contribution of LoRA updates to the original weights.

* target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]: Specifies which layers of the model will receive LoRA adapters. These include attention projection layers and MLP layers, which are critical for capturing task-specific behavior.

* lora_dropout = 0.05: Applies a small dropout to LoRA updates for stability and regularization. Can be set to 0 if desired.

* bias = "none": Keeps the original biases frozen and only trains LoRA parameters.

* use_gradient_checkpointing = "unsloth": Enables memory-efficient backpropagation by storing fewer intermediate activations. This is crucial for training large models with limited GPU memory.

* random_state = 3407: Ensures reproducibility of LoRA initialization.

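* use_rslora = False: Rank-stabilized LoRA (rsLoRA) scales the adapter update by lora_alpha / sqrt(r) instead of lora_alpha / r; we keep the standard scaling here.

* loftq_config = None: Configuration for LoftQ, an alternative way to initialize LoRA weights for quantized models; not needed in our current setup.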

Overall, this configuration allows us to adapt Gemma 2B to our task efficiently, keeping memory consumption low while still enabling the model to learn the SOAP note summarization task effectively.

Output: Confirms that LoRA adapters are successfully applied to the model.
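The effect of r and lora_alpha can be checked with a little arithmetic. The sketch below counts the trainable parameters LoRA adds to a single weight matrix and computes the scaling factor applied to the update. The layer dimensions are hypothetical and chosen only for illustration; they are not read from the model.

```python
import math

def lora_added_params(d_in: int, d_out: int, r: int) -> int:
    """Trainable parameters LoRA adds to one d_out x d_in weight matrix.

    LoRA learns B (d_out x r) and A (r x d_in), so the update B @ A has rank <= r.
    """
    return r * (d_in + d_out)

def lora_scaling(alpha: float, r: int, rank_stabilized: bool = False) -> float:
    """Factor applied to the update: alpha / r, or alpha / sqrt(r) for rsLoRA."""
    return alpha / math.sqrt(r) if rank_stabilized else alpha / r

# Hypothetical layer size, for illustration only
d_in = d_out = 2048
full = d_in * d_out
added = lora_added_params(d_in, d_out, r=16)
print(f"Full matrix: {full:,} params; LoRA adds {added:,} ({added / full:.2%})")
print("Scaling (standard):", lora_scaling(32, 16))
print("Scaling (rsLoRA):  ", lora_scaling(32, 16, rank_stabilized=True))
```

Doubling r doubles the number of added parameters per adapted matrix, which is the memory trade-off mentioned for the rank setting.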

In [None]:
# Cell 6

# Configure the model for LoRA (Parameter-Efficient Fine-Tuning)
model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    lora_alpha=32,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",   # Attention projections
        "gate_proj", "up_proj", "down_proj"       # MLP projections
    ],
    lora_dropout=0.05,  # Small dropout for stability (can set to 0 if needed)
    bias="none",
    use_gradient_checkpointing="unsloth",  # Memory-efficient training
    random_state=3407,
    use_rslora=False,   # Set True to use rank-stabilized LoRA (rsLoRA) scaling
    loftq_config=None,  # Not needed unless using LoftQ initialization
)

print("LoRA adapters successfully configured on the Gemma 2B model.")

**Cell 7: Define Training Arguments**

We define the hyperparameters for our fine-tuning process using the TrainingArguments class from the Hugging Face Transformers library. These settings are optimized for Gemma 2B fine-tuning with LoRA and 4-bit quantization using Unsloth. Each parameter is carefully chosen to balance training stability, speed, and memory efficiency.

* output_dir = "./outputs": Directory where model checkpoints and logs will be saved.

* num_train_epochs = 1: A single epoch is often sufficient for instruction fine-tuning, as the goal is to align the model with task-specific behavior rather than train from scratch.

* per_device_train_batch_size = 2: Small batch size ensures that the 4-bit quantized model fits comfortably within GPU memory limits.

* gradient_accumulation_steps = 4: Accumulates gradients over 4 mini-batches before performing an optimization step. This effectively simulates a larger batch size of 2 × 4 = 8, enhancing training stability without extra memory usage.

* warmup_steps = 5: Gradually increases the learning rate during the initial steps to prevent training instability and allow smoother convergence.

* learning_rate = 2e-4: A standard and reliable learning rate for QLoRA-based fine-tuning of instruction-tuned models.

* fp16 = not torch.cuda.is_bf16_supported(): Enables half-precision (16-bit) floating-point computation on GPUs that don’t support bf16, improving speed and reducing memory use.

* bf16 = torch.cuda.is_bf16_supported(): Enables bf16 precision when supported (e.g., A100, H100 GPUs) for faster and more stable mixed-precision training.

* optim = "adamw_8bit": Uses a memory-efficient 8-bit version of the AdamW optimizer, significantly reducing GPU memory consumption.

* weight_decay = 0.01: Adds a small regularization term to prevent overfitting during fine-tuning.

* lr_scheduler_type = "linear": Uses a linear learning rate schedule, which gradually decays the learning rate as training progresses.

* logging_steps = 10: Logs training progress every 10 steps for easier monitoring.

* report_to = "none": Disables third-party logging integrations like Weights & Biases for a cleaner local training setup.

Overall, this configuration is a reasonable starting point for parameter-efficient fine-tuning (PEFT) on limited GPU resources while keeping convergence stable.
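The interaction between batch size, gradient accumulation, warm-up, and the linear schedule can be sketched in a few lines. The total number of optimizer steps below is hypothetical, since the real value depends on the dataset size; the schedule is a plain re-implementation of linear warm-up followed by linear decay, written for illustration.

```python
def effective_batch_size(per_device: int, accumulation_steps: int, num_devices: int = 1) -> int:
    """Number of examples that contribute to each optimizer step."""
    return per_device * accumulation_steps * num_devices

def linear_lr(step: int, base_lr: float, warmup_steps: int, total_steps: int) -> float:
    """Linear warm-up from 0 to base_lr, then linear decay back to 0."""
    if step < warmup_steps:
        return base_lr * (step / max(1, warmup_steps))
    remaining = max(0, total_steps - step)
    return base_lr * (remaining / max(1, total_steps - warmup_steps))

TOTAL_STEPS = 125  # hypothetical; the real value depends on dataset size

print("Effective batch size:", effective_batch_size(2, 4))
for step in (0, 5, 65, 125):
    print(f"step {step:>3}: lr = {linear_lr(step, 2e-4, 5, TOTAL_STEPS):.6f}")
```

With warmup_steps = 5, the learning rate reaches its peak of 2e-4 at step 5 and then falls linearly to zero at the final step.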

In [None]:
# Cell 7



# Define the training arguments for Gemma 2B LoRA fine-tuning
training_args = TrainingArguments(
    output_dir="./outputs",              # Directory to save model checkpoints and logs
    num_train_epochs=1,                  # Number of epochs (increase for full training)
    per_device_train_batch_size=2,       # Small batch size for memory efficiency
    gradient_accumulation_steps=4,       # Accumulate gradients to simulate larger batches
    warmup_steps=5,                      # Steps for LR warm-up
    learning_rate=2e-4,                  # Good default for LoRA + 4-bit
    fp16=not torch.cuda.is_bf16_supported(),  # Use fp16 if bf16 not available
    bf16=torch.cuda.is_bf16_supported(),      # Prefer bf16 on newer GPUs (A100, H100)
    logging_steps=10,                    # Log every 10 steps
    save_strategy="epoch",               # Save checkpoints each epoch
    save_total_limit=2,                  # Keep only last 2 checkpoints
    optim="adamw_8bit",                  # Memory-efficient optimizer from bitsandbytes
    weight_decay=0.01,                   # Regularization
    lr_scheduler_type="linear",          # Linear LR schedule
    seed=3407,                           # Reproducibility
    report_to="none",                    # Disable external logging (e.g., W&B)
)

print("Training arguments configured successfully.")


**Cell 8: Initialize SFTTrainer and Start Fine-Tuning**

In this cell, we initialize the SFTTrainer from the trl library to perform LoRA-based fine-tuning of Gemma 2B on our SOAP summarization task. This trainer handles the complete training loop, including gradient accumulation and mixed-precision training. The validation split is passed as eval_dataset, but with the arguments from Cell 7 no evaluation runs during training unless an evaluation strategy is also set in TrainingArguments.

* model = model: The Gemma 2B model with LoRA adapters applied in Cell 6, ready for parameter-efficient fine-tuning.

* tokenizer = tokenizer: The Gemma tokenizer used to tokenize the text field; the chat template itself was already applied in Cell 5.

* train_dataset / eval_dataset: The datasets preprocessed with format_chat_template in Cell 5, where each example is formatted as a clear instruction-response pair.

* dataset_text_field = "text": Specifies the column in the dataset containing the formatted prompts and responses.

* max_seq_length = max_seq_length: Truncates sequences longer than this limit so that memory use stays bounded.

* dataset_num_proc = 2: Uses 2 parallel processes to speed up dataset preprocessing.

* packing = False: Disables sequence packing; useful for dialogue-style sequences where prompts and responses should remain separate.

* args = training_args: Passes the hyperparameters defined in Cell 7, which are optimized for memory-efficient, stable fine-tuning with LoRA.

* trainer.train(): Starts the fine-tuning process, handling all forward and backward passes, gradient accumulation, optimizer steps, and mixed-precision computations automatically.

This setup ensures that Gemma 2B learns the instruction-following task efficiently, while keeping GPU memory usage low and maintaining training stability.
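To illustrate what the packing option changes, the following sketch contrasts keeping each example separate with joining examples into fixed-size blocks. It is a simplified illustration of the idea on toy token IDs, not trl's implementation; the function names and the separator ID are hypothetical.

```python
def truncate_each(sequences, max_len):
    """packing=False: every example stays separate and is cut at max_len."""
    return [seq[:max_len] for seq in sequences]

def pack_sequences(sequences, max_len, sep_id):
    """Packing idea: join examples into one stream, then cut fixed-size blocks."""
    stream = []
    for seq in sequences:
        stream.extend(seq)
        stream.append(sep_id)  # separator between examples
    return [stream[i:i + max_len] for i in range(0, len(stream), max_len)]

# Toy token IDs (hypothetical)
examples = [[1, 2, 3], [4, 5], [6]]
separate = truncate_each(examples, max_len=4)
packed = pack_sequences(examples, max_len=4, sep_id=0)

print("Separate:", separate)
print("Packed:  ", packed)
```

In the packed version a block can contain pieces of several examples, which is why packing is left off here so that each dialogue and its SOAP note stay together in one sequence.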

In [None]:
# Cell 8



print("🚀 Initializing the SFTTrainer...")

# Initialize the trainer for LoRA fine-tuning
trainer = SFTTrainer(
    model=model,                             # Gemma 2B model with LoRA adapters
    tokenizer=tokenizer,                     # Corresponding tokenizer
    train_dataset=train_dataset_formatted,   # Preprocessed training dataset
    eval_dataset=validation_dataset_formatted, # Preprocessed validation dataset
    dataset_text_field="text",               # Field containing the formatted prompt text
    max_seq_length=max_seq_length,           # Maximum sequence length
    dataset_num_proc=2,                      # Number of processes for dataset preprocessing
    packing=False,                           # Disable packing for dialogue tasks; can enable for very short sequences
    args=training_args,                      # Training arguments defined in Cell 7
)

# Start the fine-tuning process
print("Starting model training...")
trainer_stats = trainer.train()
print("Training complete.")

**Cell 9: Save Fine-Tuned LoRA Adapters and Tokenizer**

After completing the LoRA fine-tuning, it is crucial to save the trained adapters and tokenizer for future use, such as inference, further fine-tuning, or deployment. This ensures that the model can be reloaded exactly as it was trained, maintaining the instruction-following behavior.

* output_directory: Specifies the folder where the model adapters and tokenizer will be saved. Using a descriptive name helps identify the model and task.

* os.makedirs(output_directory, exist_ok=True): Creates the directory if it doesn’t already exist, preventing errors during saving.

* model.save_pretrained(output_directory): Saves the LoRA adapters (parameter-efficient fine-tuning weights) to the directory. If LoRA adapters are not applied, it saves the full model weights.

* tokenizer.save_pretrained(output_directory): Saves the tokenizer configuration and vocabulary, ensuring consistent tokenization when the model is later loaded for inference.

* print statements: Confirm that the adapters and tokenizer have been saved successfully, providing clear feedback for the user.

This step finalizes the training workflow, producing a ready-to-use fine-tuned Gemma 2B model that can generate SOAP notes from medical dialogues.

In [None]:
# Cell 9



# Define the output directory
output_directory = "lora-fine-tuned-gemma-2b-medical-dialogue-to-soap-summary"

# Create the directory if it doesn't exist
os.makedirs(output_directory, exist_ok=True)

# Save only the LoRA adapters (parameter-efficient fine-tuning)
if hasattr(model, "peft_config"):
    print("Saving LoRA adapters...")
    model.save_pretrained(output_directory)
else:
    print("No LoRA adapters found. Saving full model instead.")
    model.save_pretrained(output_directory)

# Save the tokenizer
print("Saving tokenizer...")
tokenizer.save_pretrained(output_directory)

print(f"Model adapters and tokenizer successfully saved to '{output_directory}'.")

In [None]:
# zip the finetuned model



# Define the path to the model directory and the zip output path
model_dir = '/kaggle/working/lora-fine-tuned-gemma-2b-medical-dialogue-to-soap-summary'
zip_output = '/kaggle/working/lora_fine_tuned_model.zip'

# Create a zip archive of the model directory
shutil.make_archive(zip_output.replace('.zip', ''), 'zip', model_dir)
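Before downloading the archive, it can help to confirm what it contains. The sketch below wraps the same shutil.make_archive call in a hypothetical helper, zip_and_list, and demonstrates it on a temporary directory with placeholder files rather than on the real model folder.

```python
import os
import shutil
import tempfile
import zipfile

def zip_and_list(model_dir: str, archive_base: str) -> list[str]:
    """Zip model_dir to archive_base + '.zip' and return the archived file names."""
    archive_path = shutil.make_archive(archive_base, "zip", model_dir)
    with zipfile.ZipFile(archive_path) as zf:
        return sorted(zf.namelist())

# Demonstration on a temporary directory with placeholder files
with tempfile.TemporaryDirectory() as tmp:
    demo_dir = os.path.join(tmp, "demo-adapter")
    os.makedirs(demo_dir)
    for name in ("adapter_config.json", "tokenizer_config.json"):
        with open(os.path.join(demo_dir, name), "w") as f:
            f.write("{}")
    # Write the archive next to (not inside) the directory being zipped
    archived = zip_and_list(demo_dir, os.path.join(tmp, "demo_adapter"))
    print(archived)
```

The archive is written outside the directory being zipped so that it does not end up containing itself.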


**Cell 10: Evaluate Baseline Gemma 2B Model**

In this cell, we evaluate the original, unfine-tuned Gemma 2B model on a sample from the test dataset. This provides a baseline performance to compare against the fine-tuned LoRA model, helping us understand how much improvement the instruction fine-tuning has achieved.

* base_model, base_tokenizer: Loads the original Gemma 2B model and tokenizer in 4-bit precision for memory-efficient evaluation.

* base_model.to("cuda") & base_model.eval(): Moves the model to GPU and sets it to evaluation mode, disabling dropout and other training behaviors.

* num_samples = 1: Specifies how many test samples to generate summaries for. Can be increased for more comprehensive evaluation.

* messages: Formats the user input as a clear instruction prompt (“Summarize the following medical dialogue into a SOAP note”) to align with the model’s instruction-following capability.

* prompt = base_tokenizer.apply_chat_template(...): Converts the messages into the proper chat template expected by Gemma 2B, ensuring the model receives input in the same structured format as during training.

* inputs & attention_mask: Tokenizes the prompt and generates an attention mask to tell the model which tokens to pay attention to.

* base_model.generate(...): Produces the summary for the given dialogue using the baseline model.

* baseline_summary = base_tokenizer.batch_decode(...): Decodes the generated token IDs into readable text, skipping special tokens.

* Print statements: Display the original dialogue, reference SOAP summary, and the generated summary from the baseline model for easy side-by-side comparison.

This evaluation step gives a qualitative reference point for judging the improvement after fine-tuning: the same dialogue can be compared side by side between the pretrained baseline and the LoRA-adapted model.

In [None]:
# Cell 10



# Load the original base Gemma 2B model and tokenizer for baseline comparison
print("Loading the original base Gemma 2B model...")
base_model, base_tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/gemma-2b-it-bnb-4bit",
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=load_in_4bit,
    trust_remote_code=True,  # Ensure custom model code is loaded
)
base_model.to("cuda")
base_model.eval()  # Set to evaluation mode

# Number of test samples to evaluate
num_samples = 1  # Adjust as needed

for i in range(num_samples):
    # Select a sample from the test dataset
    sample = test_dataset[i]
    dialogue = sample["dialogue"]
    reference_summary = sample["soap"]

    # Format the messages using the chat template
    messages = [
        {
            "role": "user",
            "content": f"Summarize the following medical dialogue into a SOAP note:\n\n{dialogue}"
        }
    ]

    # Format prompt using the tokenizer's chat template
    prompt = base_tokenizer.apply_chat_template(
        messages, 
        tokenize=False, 
        add_generation_prompt=True  # Ensures the model knows to generate a response
    )

    # Tokenize the prompt and move inputs to GPU
    inputs = base_tokenizer([prompt], return_tensors="pt").to("cuda")

    # Reuse the attention mask produced by the tokenizer
    attention_mask = inputs["attention_mask"]

    # Generate output from the baseline (unfine-tuned) model
    print(f"\nGenerating summary for sample {i + 1} (BASE model)...")
    outputs = base_model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=attention_mask,
        max_new_tokens=512,
        eos_token_id=base_tokenizer.eos_token_id,
        pad_token_id=base_tokenizer.pad_token_id
    )

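    # Decode only the newly generated tokens (generate() returns prompt + continuation)
    prompt_length = inputs["input_ids"].shape[1]
    baseline_summary = base_tokenizer.batch_decode(outputs[:, prompt_length:], skip_special_tokens=True)
    generated_summary = baseline_summary[0].strip()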

    # Print results for comparison
    print("\n--- BASELINE MODEL (NO FINE-TUNING) ---")
    print("\nDIALOGUE:")
    print(dialogue)
    print("\nREFERENCE SOAP SUMMARY:")
    print(reference_summary)
    print("\nGENERATED SUMMARY (BASELINE):")
    print(generated_summary)
    print("\n--- END OF BASELINE ---")


**Cell 11: Evaluate Fine-Tuned Gemma 2B Model**

In this cell, we evaluate the fine-tuned Gemma 2B model on a sample from the test dataset. This allows us to see how well the model generates SOAP notes after LoRA-based instruction fine-tuning, and to compare it with the baseline model.

* fine_tuned_model, fine_tuned_tokenizer: Loads the Gemma 2B model fine-tuned on medical dialogues and the corresponding tokenizer. Using load_in_4bit ensures memory-efficient loading.

* fine_tuned_model.eval(): Sets evaluation mode to disable dropout. The code does not call .to("cuda") here; the 4-bit model is placed on the GPU at load time.

* num_samples = 1: Specifies how many test samples to evaluate; can be increased for more extensive testing.

* messages: Formats the user input as a clear instruction prompt, telling the model to summarize the dialogue into a SOAP note.

* prompt = fine_tuned_tokenizer.apply_chat_template(...): Converts the messages into the proper chat template expected by Gemma 2B, maintaining consistency with training.

* inputs & attention_mask: Tokenizes the prompt and creates an attention mask, indicating which tokens the model should attend to.

* torch.no_grad(): Disables gradient computation to save memory and speed up inference since we are only generating outputs.

* fine_tuned_model.generate(...): Generates the SOAP note for the given dialogue, using max_new_tokens to control the output length and properly handling EOS and padding tokens.

* fine_tuned_tokenizer.batch_decode(...): Decodes the generated tokens into readable text. Leading role labels (user, model, assistant) and stray artifacts are then stripped with re.sub before printing.

* Print statements: Display the original dialogue, reference SOAP summary, and the generated summary from the fine-tuned model, making it easy to compare qualitative improvements over the baseline.

This step demonstrates the effectiveness of LoRA fine-tuning and provides concrete examples of how the model performs on real medical dialogues, forming the basis for both qualitative and quantitative evaluation.
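For decoder-only models, generate() returns the prompt tokens followed by the new tokens, so the prompt has to be removed before the summary is inspected or scored. The sketch below shows two ways to do this on toy data: slicing by prompt length, and splitting decoded text on a role marker. The token IDs and the decoded string are hypothetical, and the text-level view assumes the turn delimiters are dropped as special tokens while the plain role labels remain.

```python
def strip_prompt(output_ids: list[int], prompt_length: int) -> list[int]:
    """Keep only the tokens generated after the prompt."""
    return output_ids[prompt_length:]

# Toy token IDs (hypothetical): 4 prompt tokens followed by 3 generated tokens
prompt_ids = [2, 106, 1645, 108]
output_ids = prompt_ids + [901, 902, 903]
new_ids = strip_prompt(output_ids, len(prompt_ids))

# Text-level view of a Gemma-style transcript after special tokens are skipped
decoded = "user\nSummarize the following medical dialogue...\nmodel\nS: Cough for one week."
by_model_marker = decoded.split("model\n")[-1].strip()
by_assistant_marker = decoded.split("assistant\n")[-1].strip()

print(new_ids)
print(repr(by_model_marker))
print(by_assistant_marker == decoded)  # the marker is absent, so nothing is removed
```

Slicing by prompt length does not depend on any marker string, which makes it the more robust option; on a batched tensor the equivalent is outputs[:, prompt_length:].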

In [None]:
# Cell 11

# Load the fine-tuned Gemma 2B model and tokenizer
print("Loading the fine-tuned model for evaluation...")
fine_tuned_model, fine_tuned_tokenizer = FastLanguageModel.from_pretrained(
    model_name="aslamsikder/lora-fine-tuned-gemma-2b-medical-dialogue-to-soap-summary",
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=load_in_4bit,
    trust_remote_code=True,
)

fine_tuned_model.eval()  # Set to evaluation mode

num_samples = 1  # Adjust as needed

for i in range(num_samples):
    sample = test_dataset[i]
    dialogue = sample["dialogue"]
    reference_summary = sample["soap"]

    # Prepare chat messages
    messages = [
        {
            "role": "user",
            "content": f"Summarize the following medical dialogue into a SOAP note:\n\n{dialogue}"
        }
    ]

    prompt = fine_tuned_tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    inputs = fine_tuned_tokenizer([prompt], return_tensors="pt").to("cuda")
    attention_mask = inputs["attention_mask"]  # Mask returned by the tokenizer

    print(f"\nGenerating summary for sample {i + 1} (FINE-TUNED model)...")
    with torch.no_grad():
        outputs = fine_tuned_model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=attention_mask,
            max_new_tokens=512,
            eos_token_id=fine_tuned_tokenizer.eos_token_id,
            pad_token_id=fine_tuned_tokenizer.pad_token_id
        )

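    # Decode only the newly generated tokens (generate() returns prompt + continuation)
    prompt_length = inputs["input_ids"].shape[1]
    fine_tuned_summary = fine_tuned_tokenizer.batch_decode(outputs[:, prompt_length:], skip_special_tokens=True)
    raw_output = fine_tuned_summary[0]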

    # Remove assistant/user prefixes and unwanted artifacts
    cleaned_output = re.sub(r'^(user|model|assistant)\n', '', raw_output, flags=re.IGNORECASE)
    cleaned_output = re.sub(r'strtoA:.*', '', cleaned_output) 
    cleaned_output = cleaned_output.strip()

    print("\n--- FINE-TUNED MODEL (CLEANED) ---")
    print("\nDIALOGUE:")
    print(dialogue)
    print("\nREFERENCE SOAP SUMMARY:")
    print(reference_summary)
    print("\nGENERATED SUMMARY (FINE-TUNED):")
    print(cleaned_output)
    print("\n--- END OF FINE-TUNED ---")


**Cell 12: Quantitative Evaluation on the Test Set**

In this cell, we evaluate the performance of the fine-tuned Gemma 2B model on the test dataset using standard text summarization metrics: ROUGE and BERTScore. This provides an objective measure of how well the model generates SOAP notes from medical dialogues.

* predictions / references: Lists to store the generated summaries and their corresponding ground-truth SOAP notes.

* messages: Formats the user input as a clear instruction (“Summarize the following medical dialogue into a SOAP note”) using the chat template, ensuring consistency with training.

* tokenizer.apply_chat_template(...): Converts the messages into the proper prompt format expected by Gemma 2B.

* inputs & attention_mask: Tokenizes the prompt and generates the attention mask, telling the model which tokens to attend to during generation.

* torch.no_grad(): Disables gradient calculation to save memory and speed up evaluation since no backpropagation is required.

* model.generate(...): Generates the summary for each dialogue. max_new_tokens limits the summary length, while eos_token_id and pad_token_id ensure proper sequence termination and padding handling.

* tokenizer.batch_decode(...): Decodes token IDs into readable text, skipping special tokens; the prompt portion is then removed so that only the generated summary is scored.

* ROUGE: Measures n-gram and longest-common-subsequence overlap between generated summaries and reference SOAP notes; the evaluate implementation reports F-measures for rouge1, rouge2, rougeL, and rougeLsum.

* BERTScore: Measures semantic similarity between predictions and references using contextual embeddings; the average F1 score provides a single overall metric.

* Print results: Displays ROUGE and BERTScore in a readable percentage format for easy comparison.

This evaluation step provides a quantitative benchmark for the fine-tuned model and helps validate the quality of the generated SOAP notes. To quantify the improvement over the baseline, run the same loop with the base model from Cell 10 and compare the scores.
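To show what a ROUGE-1 score measures, here is a simplified unigram-overlap calculation in plain Python. It is a sketch for intuition only: it lowercases and splits on whitespace, whereas the rouge_score package applies its own tokenization, so the numbers will not match the library exactly. The example sentences are hypothetical.

```python
from collections import Counter

def rouge1(prediction: str, reference: str) -> dict[str, float]:
    """Unigram-overlap precision, recall, and F1 (simplified ROUGE-1)."""
    pred_counts = Counter(prediction.lower().split())
    ref_counts = Counter(reference.lower().split())
    # Clipped overlap: each word counts at most as often as it appears in both texts
    overlap = sum((pred_counts & ref_counts).values())
    precision = overlap / max(1, sum(pred_counts.values()))
    recall = overlap / max(1, sum(ref_counts.values()))
    f1 = 0.0 if overlap == 0 else 2 * precision * recall / (precision + recall)
    return {"precision": precision, "recall": recall, "f1": f1}

scores = rouge1(
    "the patient has a fever",
    "the patient reports a fever and cough",
)
print({name: round(value, 3) for name, value in scores.items()})
```

Here four of the five predicted words appear in the reference, so precision is 0.8, while recall is lower because the reference contains seven words.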

In [None]:
# Cell 12

# Load evaluation metrics
rouge = evaluate.load('rouge')
bertscore = evaluate.load("bertscore")

# Initialize lists to store predictions and references
predictions = []
references = []

print("\nRunning quantitative evaluation on the test set...")

# Loop through the entire test dataset
for sample in tqdm(test_dataset):
    dialogue = sample["dialogue"]
    reference = sample["soap"]

    # Format prompt for model using chat template
    messages = [
        {
            "role": "user",
            "content": f"Summarize the following medical dialogue into a SOAP note:\n\n{dialogue}"
        }
    ]

    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True  # Appends the model-turn header so generation starts there
    )

    # Tokenize and move to GPU
    inputs = tokenizer([prompt], return_tensors="pt").to("cuda")

    # Generate summary from the fine-tuned model
    with torch.no_grad():
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=512,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id
        )

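    # Decode only the newly generated tokens. Gemma labels the reply turn
    # "model", so splitting on "assistant\n" would leave the prompt in place.
    prompt_length = inputs["input_ids"].shape[1]
    generated_text = tokenizer.batch_decode(outputs[:, prompt_length:], skip_special_tokens=True)
    prediction = generated_text[0].strip()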

    # Append to lists
    predictions.append(prediction)
    references.append(reference)

# Compute evaluation metrics
rouge_results = rouge.compute(predictions=predictions, references=references)
bertscore_results = bertscore.compute(predictions=predictions, references=references, lang="en")

# Print results
print("\n--- QUANTITATIVE EVALUATION RESULTS ---")
print("\nROUGE Scores:")
for key, value in rouge_results.items():
    print(f"{key}: {value*100:.2f}")

print("\nBERTScore:")
# Use mean F1 score for summary evaluation
avg_bert_f1 = sum(bertscore_results['f1']) / len(bertscore_results['f1'])
print(f"Average F1 Score: {avg_bert_f1*100:.2f}")
print("\n--- END OF EVALUATION ---")
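To make the ROUGE numbers easier to interpret, here is a minimal sketch of what ROUGE-1 F1 measures: unigram overlap between a prediction and its reference. It is a simplification, not the `evaluate` implementation, which uses its own tokenizer and aggregation. The function name `rouge1_f1` and the two example strings are made up for illustration.

```python
from collections import Counter


def rouge1_f1(prediction: str, reference: str) -> float:
    """Simplified ROUGE-1 F1: lowercase, whitespace tokens, no stemming."""
    pred_tokens = prediction.lower().split()
    ref_tokens = reference.lower().split()
    if not pred_tokens or not ref_tokens:
        return 0.0

    # Clipped overlap: each token counts at most as often as it appears in both texts
    overlap = sum((Counter(pred_tokens) & Counter(ref_tokens)).values())
    if overlap == 0:
        return 0.0

    precision = overlap / len(pred_tokens)  # share of predicted tokens found in the reference
    recall = overlap / len(ref_tokens)      # share of reference tokens recovered by the prediction
    return 2 * precision * recall / (precision + recall)


# Illustrative strings only, not taken from the dataset
score = rouge1_f1(
    "patient reports chest pain",
    "patient reports mild chest pain today",
)
print(f"ROUGE-1 F1: {score:.2f}")
```

A prediction that copies every reference word but adds nothing else reaches high recall; one that is short but accurate reaches high precision. The F1 value reported above balances the two.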


**Cell 13: Deploy Fine-Tuned Gemma 2B Model via Flask API**

This cell sets up a Flask web API to allow remote inference with your fine-tuned Gemma 2B model. It provides a simple endpoint to submit medical dialogues and receive generated SOAP summaries, making the model accessible for web applications or integration with other tools.

* Flask app initialization: `app = Flask(__name__)` creates the Flask application instance.
* Model loading (outside the request handler):
  - Loads the fine-tuned Gemma 2B model and tokenizer once at startup, so individual requests do not pay the loading cost.
  - `load_in_4bit=True` loads the weights in 4-bit quantized form to reduce GPU memory use.
  - `model.to("cuda")` moves the model to GPU for faster inference.
  - `model.eval()` switches dropout and other training-specific layers to inference behavior.
  - `FastLanguageModel.for_inference(model)` optimizes the model for inference.
* `/summarize` endpoint (POST):
  - Accepts JSON input containing a `"dialogue"` field.
  - Returns JSON output with the generated SOAP summary.
* Prompt formatting:
  - The user input is converted into the Llama 3 chat template (messages → prompt) to match the fine-tuning format.
  - `add_generation_prompt=True` appends the assistant header so the model knows it should generate an assistant response.
* Tokenization and attention mask:
  - The prompt is tokenized and moved to GPU.
  - The attention mask identifies which tokens are real versus padding, ensuring proper handling during generation.
* Text generation:
  - `torch.no_grad()` disables gradient computation for memory efficiency.
  - `model.generate(...)` produces the summary with a bounded length (`max_new_tokens`) and correct EOS/pad token handling.
* Extract assistant response: `generated_text[0].split("assistant\n")[-1].strip()` keeps only the text after the last assistant marker, so only the generated SOAP note is returned.
* Error handling: Returns descriptive JSON errors if the model isn't loaded, the input is invalid, or generation fails.
* Running the app: `app.run(host='0.0.0.0', port=7860)` starts the server on all network interfaces; 7860 is compatible with Hugging Face Spaces.
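The extraction step in the list above can be illustrated in isolation. This is a small sketch that assumes the decoded output, after special tokens are skipped, has the layout `user ... assistant ... <generated text>`; the helper name `extract_assistant_response` and the sample string are hypothetical.

```python
def extract_assistant_response(decoded: str, marker: str = "assistant\n") -> str:
    """Return the text after the last occurrence of the assistant marker."""
    # split() keeps everything after the final marker in the last element
    return decoded.split(marker)[-1].strip()


# Hypothetical decoded output; the real text depends on the chat template in use
decoded = (
    "user\n\n"
    "Summarize the following medical dialogue into a SOAP note:\n\n"
    "Doctor: What brings you in today?\n"
    "Patient: I have had a cough for three days.\n"
    "assistant\n\n"
    "S: Cough for three days.\n"
    "O: Afebrile."
)

summary = extract_assistant_response(decoded)
print(summary)
```

One caveat of this string-based approach: if the generated note itself happens to contain the marker text, everything before that point is dropped. Slicing the generated token IDs at the prompt length before decoding would avoid that, at the cost of departing from the code shown in this cell.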

This setup provides a lightweight API for generating SOAP notes from medical dialogues using the fine-tuned Gemma 2B model, making it easy to integrate into web apps or other services. For production traffic, run it behind a WSGI server such as Gunicorn rather than Flask's built-in development server.
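For reference, a client call to the endpoint could look like the following sketch. It uses only the standard library and assumes the server is reachable at `http://localhost:7860`; that address and the helper names `build_summarize_request` and `summarize_remote` are assumptions for illustration, not part of the notebook.

```python
import json
import urllib.request


def build_summarize_request(
    dialogue: str, base_url: str = "http://localhost:7860"
) -> urllib.request.Request:
    """Build the POST request the /summarize endpoint expects."""
    payload = json.dumps({"dialogue": dialogue}).encode("utf-8")
    return urllib.request.Request(
        url=f"{base_url}/summarize",
        data=payload,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def summarize_remote(
    dialogue: str, base_url: str = "http://localhost:7860", timeout: float = 120.0
) -> str:
    """Send the dialogue to a running server and return the 'summary' field."""
    req = build_summarize_request(dialogue, base_url)
    # urlopen raises urllib.error.HTTPError for the endpoint's 400/500 responses
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        return json.loads(resp.read().decode("utf-8"))["summary"]


# Building the request needs no server; sending it does
example_request = build_summarize_request(
    "Doctor: What brings you in today?\nPatient: I have had a cough for three days."
)
print(example_request.get_method(), example_request.full_url)
```

Calling `summarize_remote(...)` only works while the Flask app from this cell is running. The generous timeout reflects that generating up to 512 new tokens can take a while on a T4.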

In [None]:
# Cell 13
from flask import Flask, request, jsonify

# Initialize Flask app
app = Flask(__name__)

# --- Model Loading (Load once, outside request handler) ---
try:
    max_seq_length = 2048
    dtype = None
    load_in_4bit = True

    # Load the fine-tuned Gemma 2B model and tokenizer
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name="aslamsikder/fine-tuned-gemma-2b-medical-dialogue-to-soap-summary",
        max_seq_length=max_seq_length,
        dtype=dtype,
        load_in_4bit=load_in_4bit,
        trust_remote_code=True  # Ensures custom Hugging Face code is loaded
    )
    model.to("cuda")
    model.eval()  # Set to evaluation mode
    FastLanguageModel.for_inference(model)
    print("Model loaded successfully for inference.")
except Exception as e:
    print(f"Error loading model: {e}")
    model, tokenizer = None, None
# --- End Model Loading ---

@app.route('/summarize', methods=['POST'])
def summarize():
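    # Explicit None checks: truthiness of model/tokenizer objects is not a reliable "loaded" signal
    if model is None or tokenizer is None: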
        return jsonify({"error": "Model is not loaded"}), 500

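    # Parse JSON input; silent=True returns None for a missing or malformed JSON body,
    # so the client receives the JSON error below instead of Flask's default error page
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or not isinstance(data.get('dialogue'), str) or not data['dialogue'].strip():
        return jsonify({"error": "Invalid input. A non-empty 'dialogue' string is required."}), 400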

    dialogue = data['dialogue']

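    # Format prompt using Llama 3 chat template, with the same instruction
    # prefix as the evaluation prompt in Cell 12
    messages = [{
        "role": "user",
        "content": f"Summarize the following medical dialogue into a SOAP note:\n\n{dialogue}"
    }]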
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    # Tokenize and move inputs to GPU
    inputs = tokenizer([prompt], return_tensors="pt").to("cuda")

    # Generate summary
    try:
        with torch.no_grad():
            outputs = model.generate(
                input_ids=inputs["input_ids"],
                # Use the tokenizer's own mask; deriving it from pad_token_id is wrong
                # when the pad token is unset or also appears as a real token
                attention_mask=inputs["attention_mask"],
                max_new_tokens=512,
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.pad_token_id
            )

        generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
        # Extract assistant response
        summary = generated_text[0].split("assistant\n")[-1].strip()

        return jsonify({"summary": summary})

    except Exception as e:
        return jsonify({"error": f"Error during generation: {str(e)}"}), 500

if __name__ == '__main__':
    # For production, use a WSGI server like Gunicorn; port 7860 is default for HF Spaces
    app.run(host='0.0.0.0', port=7860)
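The input check inside the handler can also be factored into a pure function, which makes it testable without loading the model or starting the server. The sketch below is one way to do that; `validate_payload` is a hypothetical helper, and its error strings are illustrative rather than the ones the cell above returns.

```python
from typing import Any, Optional, Tuple


def validate_payload(data: Any) -> Tuple[Optional[str], Optional[str]]:
    """Return (dialogue, None) for valid input, or (None, error_message) otherwise."""
    # A missing or malformed JSON body typically arrives here as None
    if not isinstance(data, dict):
        return None, "Request body must be a JSON object."

    dialogue = data.get("dialogue")
    # Reject missing, non-string, and whitespace-only values
    if not isinstance(dialogue, str) or not dialogue.strip():
        return None, "Invalid input. 'dialogue' key is required."

    return dialogue, None


print(validate_payload({"dialogue": "Doctor: Hello."}))
print(validate_payload({"text": "wrong key"}))
```

Inside the route, the handler would call this helper on the parsed JSON and return a 400 response whenever the second element is not `None`.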
