In [1]:
!pip install datasets bitsandbytes trl==0.12.1 transformers peft huggingface-hub accelerate safetensors pandas matplotlib numpy==1.26.4

Collecting bitsandbytes
  Downloading bitsandbytes-0.48.1-py3-none-manylinux_2_24_x86_64.whl.metadata (10 kB)
Collecting trl==0.12.1
  Downloading trl-0.12.1-py3-none-any.whl.metadata (10 kB)
Collecting numpy==1.26.4
  Downloading numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.0/61.0 kB[0m [31m3.1 MB/s[0m eta [36m0:00:00[0m
Downloading trl-0.12.1-py3-none-any.whl (310 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m310.9/310.9 kB[0m [31m12.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m18.0/18.0 MB[0m [31m125.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading bitsandbytes-0.48.1-py3-none-manylinux_2_24_x86_64.whl (60.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m60.1/60.1 MB[0m [31m39.

# Install necessary libraries

In [1]:
import os
import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    #AutoPeftModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    TrainingArguments,
    pipeline,
    logging,
)
from trl import SFTTrainer, SFTConfig
# from trl.trainer.utils import DataCollatorForCompletionOnlyLM
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training, get_peft_model, AutoPeftModelForCausalLM, PeftConfig # Added to peft
from huggingface_hub import notebook_login
from trl import SFTTrainer, SFTConfig, setup_chat_format, DataCollatorForCompletionOnlyLM


# Check for bf16 support and set compute dtype


In [2]:
support = torch.cuda.is_bf16_supported(including_emulation=False)
calculate_dtype = torch.bfloat16 if support else torch.float32

In [3]:
print(calculate_dtype)

torch.bfloat16


#bnb config for loading 4 bit model with nf4 quant type
* loading model with quantization config
* device map to cuda
* 4bit true

In [4]:
bnb_config = BitsAndBytesConfig(
    load_in_4bit = True,
    bnb_4bit_quant_type= "nf4",
    bnb_4bit_compute_dtype= calculate_dtype, #calculate_dtype can be bf16 or float32- use bf16 if supported
    bnb_4bit_use_double_quant= True
    )
repo = "google/gemma-2-2b-it"
model = AutoModelForCausalLM.from_pretrained(repo, quantization_config= bnb_config, device_map= "cuda:0")

config.json:   0%|          | 0.00/838 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/24.2k [00:00<?, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/241M [00:00<?, ?B/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/187 [00:00<?, ?B/s]

#Check model memory footprint

In [5]:
print(model.get_memory_footprint()/1024/1024)

2090.7119140625


#model config

In [6]:
model

Gemma2ForCausalLM(
  (model): Gemma2Model(
    (embed_tokens): Embedding(256000, 2304, padding_idx=0)
    (layers): ModuleList(
      (0-25): 26 x Gemma2DecoderLayer(
        (self_attn): Gemma2Attention(
          (q_proj): Linear4bit(in_features=2304, out_features=2048, bias=False)
          (k_proj): Linear4bit(in_features=2304, out_features=1024, bias=False)
          (v_proj): Linear4bit(in_features=2304, out_features=1024, bias=False)
          (o_proj): Linear4bit(in_features=2048, out_features=2304, bias=False)
        )
        (mlp): Gemma2MLP(
          (gate_proj): Linear4bit(in_features=2304, out_features=9216, bias=False)
          (up_proj): Linear4bit(in_features=2304, out_features=9216, bias=False)
          (down_proj): Linear4bit(in_features=9216, out_features=2304, bias=False)
          (act_fn): GELUTanh()
        )
        (input_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
        (post_attention_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
        (pre_feedfor

#Prepare model for kbit training
##Use Lora Config


1.   rank [4,8,16,32] - choose one
2.   lora_alpha is a scalling factor which should be 2x the rank of matrix.
3.   dropout range from 0.03 to 0.10 which helps prevent overfit
4.   module - choose module as per requirement


In [7]:
model = prepare_model_for_kbit_training(model)

config = LoraConfig(
    r = 8, #. rank of LoRA - [4-16]
    bias = "none", # ["all", "lora_only"] - for train bias term
    lora_alpha = 16, # scalling factor
    lora_dropout = 0.10, # prevent overfit- used for regularisation
    target_modules=["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"],
    task_type = "CAUSAL_LM"

)

model = get_peft_model(model, config)
model

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): Gemma2ForCausalLM(
      (model): Gemma2Model(
        (embed_tokens): Embedding(256000, 2304, padding_idx=0)
        (layers): ModuleList(
          (0-25): 26 x Gemma2DecoderLayer(
            (self_attn): Gemma2Attention(
              (q_proj): lora.Linear4bit(
                (base_layer): Linear4bit(in_features=2304, out_features=2048, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.1, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=2304, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=2048, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
           

#once again check memory footprint

In [8]:
print(model.get_memory_footprint()/1024/1024)

3255.78271484375


#Print base model to compare

In [9]:
print(model.get_base_model)

<bound method PeftModel.get_base_model of PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): Gemma2ForCausalLM(
      (model): Gemma2Model(
        (embed_tokens): Embedding(256000, 2304, padding_idx=0)
        (layers): ModuleList(
          (0-25): 26 x Gemma2DecoderLayer(
            (self_attn): Gemma2Attention(
              (q_proj): lora.Linear4bit(
                (base_layer): Linear4bit(in_features=2304, out_features=2048, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.1, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=2304, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=2048, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector)

In [10]:
print(model.get_memory_footprint()/1e6)

3413.935616


#Check for trainable Parameters and its percentage for a mathematical view.

In [11]:
trainable_params, total_params = model.get_nb_trainable_parameters()
percentage = (trainable_params / total_params) * 100

print(f"Trainable Parameters: {trainable_params:,}")
print(f"Total Parameters: {total_params:,}")
print(f"Percentage Trainable: {percentage:.2f}%")

Trainable Parameters: 10,383,360
Total Parameters: 2,624,725,248
Percentage Trainable: 0.40%


#ETL Process for Dataset Prep stage, Tokenizer load and define chat template if needed.

In [6]:
from datasets import load_dataset
from transformers import AutoTokenizer
import torch

# ============================================
# SETUP: Gemma-2-2B Model & Tokenizer
# ============================================
model_name = "google/gemma-2-2b-it"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

# Add padding token if it doesn't exist
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# ============================================
# LOAD MEDICAL DATASET
# ============================================
print("Loading ChatDoctor-HealthCareMagic dataset...")

# Load the medical Q&A dataset (using 10k samples for reasonable training time)
raw_dataset = load_dataset(
    "lavita/ChatDoctor-HealthCareMagic-100k",
    split="train[:10000]"  # Using first 10k samples
)

print(f"Dataset loaded: {len(raw_dataset)} samples")
print(f"Sample entry: {raw_dataset[0]}")
print(f"Dataset columns: {raw_dataset.column_names}")

# ============================================
# FORMAT PROMPTS FOR GEMMA-2
# ============================================

def format_prompt(example):
    """
    Format medical Q&A into Gemma-2's chat format with safety disclaimer

    Gemma-2 uses a specific chat template format:
    <start_of_turn>user
    {user message}<end_of_turn>
    <start_of_turn>model
    {assistant response}<end_of_turn>
    """

    # Extract patient question and doctor response
    # The dataset has 'input' (patient question) and 'output' (doctor answer)
    patient_question = example["input"]
    doctor_response = example["output"]

    # Add safety disclaimer to medical responses
    safe_response = (
        f"{doctor_response}\n\n"
        "⚕️ Disclaimer: This information is for educational purposes only. "
        "Please consult a qualified healthcare professional for medical advice."
    )

    # Format in Gemma-2 chat template
    prompt = (
        f"<start_of_turn>user\n"
        f"{patient_question}<end_of_turn>\n"
        f"<start_of_turn>model\n"
        f"{safe_response}<end_of_turn>"
    )

    return {"text": prompt}

# ============================================
# TOKENIZATION FUNCTION
# ============================================

def tokenize_function(examples):
    """Tokenize the formatted text with proper padding and labels"""

    # Tokenize the text
    tokenized = tokenizer(
        examples["text"],
        truncation=True,
        padding="max_length",
        max_length=1024,  # Increased for medical responses (often longer)
        return_tensors=None
    )

    # Create labels for causal LM training
    # Labels = input_ids, but -100 for padded tokens (ignored in loss)
    labels = []
    for input_ids, attention_mask in zip(tokenized["input_ids"], tokenized["attention_mask"]):
        # Convert to list if needed
        label = input_ids.copy() if isinstance(input_ids, list) else list(input_ids)

        # Set padded positions to -100 (ignored in loss calculation)
        for i, mask in enumerate(attention_mask):
            if mask == 0:  # This is a padded token
                label[i] = -100

        labels.append(label)

    tokenized["labels"] = labels
    return tokenized

# ============================================
# PROCESS DATASET
# ============================================

print("\nFormatting medical Q&A prompts...")
formatted_dataset = raw_dataset.map(
    format_prompt,
    desc="Formatting prompts with Gemma-2 chat template"
)

print("\nTokenizing dataset...")
tokenized_dataset = formatted_dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=formatted_dataset.column_names,  # Remove original columns
    desc="Tokenizing medical conversations"
)

# Final dataset ready for training
final_dataset = tokenized_dataset

# ============================================
# SUMMARY
# ============================================

print("\n" + "="*60)
print("✅ MEDICAL DATASET PREPARATION COMPLETE!")
print("="*60)
print(f"📊 Total samples: {len(final_dataset)}")
print(f"📏 Max sequence length: {len(final_dataset[0]['input_ids'])}")
print(f"🔑 Dataset keys: {final_dataset.column_names}")
print(f"💾 Approximate size: {len(final_dataset) * 1024 * 4 / 1024 / 1024:.2f} MB")
print("\n🚀 Ready to use 'final_dataset' in your SFTTrainer!")
print("="*60)

# Optional: Print a sample to verify formatting
print("\n📋 Sample formatted prompt (first 500 chars):")
print(formatted_dataset[0]["text"][:500] + "...")

tokenizer_config.json:   0%|          | 0.00/47.0k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]

Loading ChatDoctor-HealthCareMagic dataset...


README.md:   0%|          | 0.00/542 [00:00<?, ?B/s]

data/train-00000-of-00001-5e7cb295b9cff0(…):   0%|          | 0.00/70.5M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/112165 [00:00<?, ? examples/s]

Dataset loaded: 10000 samples
Sample entry: {'instruction': "If you are a doctor, please answer the medical questions based on the patient's description.", 'input': 'I woke up this morning feeling the whole room is spinning when i was sitting down. I went to the bathroom walking unsteadily, as i tried to focus i feel nauseous. I try to vomit but it wont come out.. After taking panadol and sleep for few hours, i still feel the same.. By the way, if i lay down or sit down, my head do not spin, only when i want to move around then i feel the whole world is spinning.. And it is normal stomach discomfort at the same time? Earlier after i relieved myself, the spinning lessen so i am not sure whether its connected or coincidences.. Thank you doc!', 'output': 'Hi, Thank you for posting your query. The most likely cause for your symptoms is benign paroxysmal positional vertigo (BPPV), a type of peripheral vertigo. In this condition, the most common symptom is dizziness or giddiness, which is ma

Formatting prompts with Gemma-2 chat template:   0%|          | 0/10000 [00:00<?, ? examples/s]


Tokenizing dataset...


Tokenizing medical conversations:   0%|          | 0/10000 [00:00<?, ? examples/s]


✅ MEDICAL DATASET PREPARATION COMPLETE!
📊 Total samples: 10000
📏 Max sequence length: 1024
🔑 Dataset keys: ['input_ids', 'attention_mask', 'labels']
💾 Approximate size: 39.06 MB

🚀 Ready to use 'final_dataset' in your SFTTrainer!

📋 Sample formatted prompt (first 500 chars):
<start_of_turn>user
I woke up this morning feeling the whole room is spinning when i was sitting down. I went to the bathroom walking unsteadily, as i tried to focus i feel nauseous. I try to vomit but it wont come out.. After taking panadol and sleep for few hours, i still feel the same.. By the way, if i lay down or sit down, my head do not spin, only when i want to move around then i feel the whole world is spinning.. And it is normal stomach discomfort at the same time? Earlier after i relie...


In [8]:

# My optimized parameters for LoRA training on medical data
min_effective_batch_size = 4  # Reduced for 2B model with longer sequences (1024 tokens)
lr = 2e-5  # Slightly lower LR for Gemma-2 stability with medical data
max_seq_length = 1024  # Increased for detailed medical responses
collator_fn = None  # I'm not using a custom collator since I pre-pad in tokenization
packing = False  # I disabled packing since I'm using fixed-length sequences
steps = 20  # My logging and saving frequency (adjusted for larger dataset)
num_train_epochs = 3  # Standard for medical fine-tuning with LoRA
warmup_ratio = 0.1  # Warmup for stable training start

# My SFT configuration for Gemma-2-2B medical assistant
sft_config = SFTConfig(
    # I'm saving my medical model to a dedicated directory
    output_dir = '/content/drive/MyDrive/gemma-2-2b-medical/Gemma-2-2B-MedicalQA-finetuned',

    # My data processing settings
    packing = packing,
    max_seq_length = max_seq_length,

    # I disabled gradient checkpointing (not needed with 4-bit quantization)
    gradient_checkpointing = False,

    # My training batch and precision settings
    # Note: Smaller batch size due to longer sequences (1024 vs 512)
    per_device_train_batch_size = min_effective_batch_size,
    auto_find_batch_size = True,  # Let trainer optimize for available VRAM
    bf16 = True,  # Using bf16 for better numerical stability (important for medical accuracy)

    # My training schedule optimized for medical domain
    num_train_epochs = num_train_epochs,
    learning_rate = lr,
    lr_scheduler_type = "cosine",  # Smooth learning rate decay
    warmup_ratio = warmup_ratio,
    weight_decay = 0.01,  # Regularization to prevent overfitting on medical terminology
    max_grad_norm = 1.0,  # Gradient clipping for stability

    # My logging and monitoring setup
    report_to = 'wandb',  # Tracking my medical AI experiment
    run_name = "Gemma-2-2B-MedicalQA-LoRA-r8-alpha16",  # Descriptive run name with LoRA config

    # My logging directory
    logging_dir = '/content/drive/MyDrive/gemma-2-2b-medical/Gemma-2-2B-MedicalQA-finetuned/logs',

    # My checkpoint and logging strategy
    logging_strategy = 'steps',
    save_strategy = 'steps',
    logging_steps = steps,  # I log every 20 steps
    save_steps = steps,     # I save checkpoint every 20 steps (more frequent due to larger dataset)
    save_total_limit = 2,   # Keep last 2 checkpoints to save Google Drive space
)

# I create my trainer with the medical dataset and configuration
trainer = SFTTrainer(
    model = model,                    # My Gemma-2-2B model with LoRA adapters (rank=8, alpha=16, dropout=0.1)
    train_dataset = final_dataset.select(range(1500)),    # My 10k medical Q&A dataset
    processing_class = tokenizer,     # Gemma-2 tokenizer
    data_collator = collator_fn,      # Using default collator
    args = sft_config,               # My medical training configuration
)

# I start the medical fine-tuning process
print("🏥 Starting Gemma-2-2B Medical Fine-tuning...")
print(f"📊 Training on {len(final_dataset.select(range(1500)))} medical Q&A samples")
print(f"🔧 LoRA Config: rank=8, alpha=16, dropout=0.1")
print(f"💾 4-bit quantization: nf4")
print(f"⏱️ Estimated time: ~2-3 hours on T4 GPU")
print("-" * 60)

trainer.train()

print("\n" + "="*60)
print("✅ Medical fine-tuning completed!")
print(f"💾 Model saved to: {sft_config.output_dir}")
print("🏥 Your Gemma-2-2B Medical Assistant is ready!")
print("="*60)

NameError: name 'SFTConfig' is not defined

In [None]:
# Step 1: I'm saving my trained medical model locally first
print("💾 Saving my trained Gemma-2-2B medical model...")
trainer.save_model('/content/gemma-2-medical-saved')

# Step 2: I load and merge the LoRA adapter with the base model
from peft import AutoPeftModelForCausalLM

print("🔧 Loading my PEFT model and merging adapter...")
# I load the saved PEFT model (use the same path as Step 1)
peft_model = AutoPeftModelForCausalLM.from_pretrained('/content/gemma-2-medical-saved')

# I merge and unload the adapter to get a single model
print("⚙️ Merging LoRA weights into base model...")
merged_model = peft_model.merge_and_unload()

# Step 3: I save the merged model with tokenizer
print("💾 Saving my merged medical model...")
merged_model.save_pretrained('/content/gemma-2-medical-merged')
tokenizer.save_pretrained('/content/gemma-2-medical-merged')

# Step 4: I upload my model to Hugging Face Hub
from huggingface_hub import HfApi

print("☁️ Uploading my medical assistant to Hugging Face Hub...")
api = HfApi()
api.upload_folder(
    folder_path='/content/gemma-2-medical-merged',
    repo_id="sweatSmile/Gemma-2-2B-MedicalQA-Assistant",  # My medical AI repo
    repo_type="model",
    commit_message="Upload Gemma-2-2B fine-tuned on 1.5k medical Q&A (ChatDoctor-HealthCareMagic) with LoRA (r=8, alpha=16)"
)

print("\n" + "="*60)
print("✅ Model upload completed! 🎉")
print("🏥 Your Gemma-2-2B Medical Assistant is now live!")
print("="*60)
print("🔗 Model URL: https://huggingface.co/sweatSmile/Gemma-2-2B-MedicalQA-Assistant")
print("\n📋 Model Details:")
print("   - Base: google/gemma-2-2b-it")
print("   - Dataset: ChatDoctor-HealthCareMagic (1.5k samples)")
print("   - LoRA: rank=8, alpha=16, dropout=0.1")
print("   - Quantization: 4-bit (nf4)")
print("   - Domain: Medical Q&A with safety disclaimers")
print("="*60)