
# Fine-tuning LLaMA 3.2 (3B) on Medical Chain-of-Thought Dataset

## Task Overview

In this notebook, we fine-tune the LLaMA 3.2 (3B) model on a medical Chain-of-Thought (CoT) dataset using parameter-efficient fine-tuning (PEFT) with Unsloth. The goal is to improve the model's ability to generate step-by-step medical reasoning with structured responses.

## 1. Environment Setup

### Install Required Libraries

First, install the required dependencies (Unsloth from GitHub, plus the training and evaluation libraries):

In [None]:
# Install unsloth and unsloth-zoo from GitHub
!pip install --upgrade --force-reinstall --no-cache-dir --no-deps git+https://github.com/unslothai/unsloth.git git+https://github.com/unslothai/unsloth-zoo.git
# Install required dependencies
!pip install bitsandbytes accelerate xformers==0.0.29.post3 peft trl==0.15.2 triton cut_cross_entropy unsloth_zoo
!pip install sentencepiece protobuf "datasets>=3.4.1" huggingface_hub hf_transfer
!pip install transformers==4.51.3
!pip install unsloth
!pip install rouge_score evaluate

###   Import Dependencies

In [None]:
import os
import torch
import random
import numpy as np
from datasets import load_dataset
import nltk
from unsloth import FastLanguageModel
import wandb
from rouge_score import rouge_scorer
from transformers import TrainingArguments
from trl import SFTTrainer

###   Set Random Seeds for Reproducibility

In [None]:
# Seed every RNG the pipeline draws from (Python, NumPy, PyTorch CPU) with 42
# so shuffling and weight initialisation are repeatable across runs.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(42)
if torch.cuda.is_available():
    # Also seed the generator of every visible GPU.
    torch.cuda.manual_seed_all(42)
del _seed_fn

### Download NLTK Data

In [None]:
# Fetch NLTK's "punkt" resource without progress output.
# NOTE(review): no later visible cell uses punkt — confirm it is still needed.
nltk.download(info_or_id="punkt", quiet=True)

###   Check GPU availability

In [None]:
# Report which accelerator the kernel sees; the training cells expect a CUDA GPU.
print(f"CUDA available: {torch.cuda.is_available()}")
if not torch.cuda.is_available():
    print("No GPU detected. This notebook requires GPU acceleration.")
else:
    # Device 0 is the GPU the single-process trainer will use.
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")


###  Configure Weights & Biases Authentication

In [None]:
# Authenticate with Weights & Biases using the Kaggle secret store.
# Any failure (not running on Kaggle, missing secret, login error) disables W&B
# logging instead of stopping the notebook; `wandb_enabled` selects report_to later.
# The Hugging Face token is read in the upload cell, so it is not required here.
try:
    from kaggle_secrets import UserSecretsClient

    user_secrets = UserSecretsClient()
    wandb_api_key = user_secrets.get_secret("WANDB_API_KEY")

    # Environment variables picked up by the W&B integration of the HF Trainer.
    os.environ["WANDB_API_KEY"] = wandb_api_key
    os.environ["WANDB_PROJECT"] = "llama-medical-cot"
    os.environ["WANDB_WATCH"] = "gradients"

    wandb.login(key=wandb_api_key)
    print("Logged in to Weights & Biases successfully!")
    wandb_enabled = True
except Exception as e:
    print(f"Warning: Failed to log in to Weights & Biases. Error: {e}")
    print("Training will continue but metrics won't be logged to W&B.")
    wandb_enabled = False

## 2. Dataset Preparation

###  Load Medical Chain-of-Thought Dataset


In [None]:
# Pull the English ("en") configuration of the medical-o1 reasoning SFT corpus from the Hub.
print("Loading dataset...")
dataset = load_dataset(path="FreedomIntelligence/medical-o1-reasoning-SFT", name="en")

### Display Dataset Information

In [None]:
# Overview of the loaded DatasetDict, then the row count of its "train" split.
print(dataset)
print("Number of training examples: {}".format(len(dataset['train'])))

### Show Sample Data

In [None]:
# Peek at one raw record. The chain of thought can be long, so only its first
# 200 characters are shown; the "Complex_CoT:" label is printed in both cases.
print("\nSample from the dataset:")
sample = dataset["train"][0]
print(f"Question: {sample['Question']}")
print(f"Response: {sample['Response']}")
cot_text = sample['Complex_CoT']
cot_preview = f"{cot_text[:200]}..." if len(cot_text) > 200 else cot_text
print(f"Complex_CoT: {cot_preview}")

### Shuffle Dataset

In [None]:
# Shuffle the "train" split with a fixed seed (42) so the validation rows
# selected in the next cell are the same on every run.
train_data = dataset["train"].shuffle(seed=42)

### Split Dataset (100 validation, rest training)

In [None]:
# Derive both splits from a fresh seeded shuffle of the raw "train" split. This is the
# same permutation as the shuffle cell above, and it keeps this cell idempotent:
# slicing `train_data` into itself would drop another 100 rows on every re-run.
VAL_SIZE = 100  # Exactly 100 rows as required
shuffled_train = dataset["train"].shuffle(seed=42)
assert len(shuffled_train) > VAL_SIZE, "dataset too small for a 100-row validation split"
val_data = shuffled_train.select(range(VAL_SIZE))
train_data = shuffled_train.select(range(VAL_SIZE, len(shuffled_train)))

### Format Data for Chain-of-Thought Training

In [None]:
def format_medical_cot(example):
    """Render one raw record as a single SFT training string.

    The prompt (instruction line plus the question) is followed directly by the
    completion: the reasoning wrapped in <think>...</think>, then the final answer
    wrapped in <response>...</response>. All parts are newline-separated.

    Args:
        example: a dataset row with 'Question', 'Complex_CoT' and 'Response' fields.

    Returns:
        dict with a single "text" key holding prompt + completion.
    """
    pieces = [
        # --- prompt ---
        "Below is a medical question. Think step by step to solve it.",
        f"Question: {example['Question']}",
        # --- completion ---
        "<think>",
        f"{example['Complex_CoT']}",
        "</think>",
        "<response>",
        f"{example['Response']}",
        "</response>",
    ]
    return {"text": "\n".join(pieces)}

# Map both splits through the formatter; remove_columns leaves only the "text" column.
print("Formatting datasets...")
train_formatted, val_formatted = (
    split.map(format_medical_cot, remove_columns=split.column_names)
    for split in (train_data, val_data)
)

for split_label, split_ds in (("Training", train_formatted), ("Validation", val_formatted)):
    print(f"{split_label} samples formatted: {len(split_ds)}")

# Preview the first 300 characters of one formatted example.
print("\nSample formatted data:")
print(f"{train_formatted[0]['text'][:300]}...")

## 3. Model Setup with Unsloth

###  Import Unsloth FastLanguageModel

In [None]:
from unsloth import FastLanguageModel, UnslothTrainer

### Load LLaMA 3.2 3B Model with 4-bit Quantization

In [None]:
# Load the Llama 3.2 3B Instruct checkpoint through Unsloth with 4-bit quantized weights
# to save memory. dtype=None leaves the dtype to auto-detection (float16, bfloat16 or
# float32); max_seq_length was reduced from 4096 to 2048 to save memory.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/llama-3.2-3b-Instruct",
    load_in_4bit=True,
    dtype=None,
    max_seq_length=2048,
)

###  Configure LoRA (Parameter-Efficient Fine-tuning)

In [None]:
# Wrap the base model with LoRA adapters on the attention and MLP projections;
# only the adapter weights are meant to be trainable.
model = FastLanguageModel.get_peft_model(
    model,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",   # attention projections
        "gate_proj", "up_proj", "down_proj",      # MLP projections
    ],
    r=16,             # LoRA rank
    lora_alpha=16,    # LoRA scaling numerator (scale = lora_alpha / r)
    lora_dropout=0,   # no dropout on the adapter path
    bias="none",      # train no bias terms
)

### Display Trainable Parameters

In [None]:
# Count trainable vs. total parameters with PEFT's own counter. It compensates for
# bitsandbytes 4-bit packing (two 4-bit weights per stored byte), which a raw sum of
# numel() would undercount. Without that, the total and the percentage are misleading.
trainable_params, total_params = model.get_nb_trainable_parameters()
print(f"Trainable parameters: {trainable_params:,} ({100 * trainable_params / total_params:.4f}% of total)")

In [None]:
# Recap of the split sizes built in section 2; nothing is re-processed here.
print("Data already prepared for Unsloth training...")
for split_label, split_ds in (("Training", train_formatted), ("Validation", val_formatted)):
    print(f"{split_label} samples: {len(split_ds)}")

## 4. Training Configuration

### Configure Training Arguments

In [None]:
from transformers import TrainingArguments
from unsloth import is_bfloat16_supported
import datetime

# Timestamped run name, shared by the output directory and the W&B run.
run_name = f"llama3.2-3b-medical-cot-{datetime.datetime.now().strftime('%Y%m%d-%H%M')}"

# The model was loaded with dtype=None (auto-detected), so bfloat16 is used on GPUs
# that support it. Pick the matching mixed-precision mode instead of hard-coding
# fp16=True, which conflicts with a bfloat16-loaded model.
use_bf16 = is_bfloat16_supported()

training_args = TrainingArguments(
    output_dir=f"./results/{run_name}",
    num_train_epochs=1,               # single pass over the training split
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=8,    # effective batch = 2 x 8 = 16 sequences per device per optimizer step
    learning_rate=5e-4,               # relatively high LR to converge within one epoch
    weight_decay=0.01,
    logging_steps=25,
    eval_steps=500,                   # infrequent evaluation to save time
    save_steps=1000,                  # infrequent checkpointing to save time/disk
    eval_strategy="steps",
    save_strategy="steps",
    report_to="wandb" if wandb_enabled else "none",
    run_name=run_name,
    save_total_limit=1,               # keep only the most recent checkpoint on disk
    fp16=not use_bf16,
    bf16=use_bf16,
    warmup_steps=25,
    gradient_checkpointing=True,      # trade compute for activation memory
)

print("Training arguments configured successfully")

###  Initialize Trainer and Start Fine-tuning

In [None]:
from trl import SFTTrainer

# Padding: keep the tokenizer's own pad token when it already has one, and fall back
# to EOS only otherwise. If pad == EOS, a causal-LM collator that sets labels to -100
# wherever input_ids == pad_token_id (presumably the default used here) would also
# hide the EOS targets appended below.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"


def _append_eos(example, eos):
    """Return the example's "text" with `eos` appended unless it already ends with it."""
    text = example["text"]
    return {"text": text if text.endswith(eos) else text + eos}


# Terminate every training string with EOS so the model learns where an answer ends.
# Without it, generation keeps running until max_new_tokens. The endswith check
# prevents a duplicate EOS if the text already carries one.
eos_token = tokenizer.eos_token
if eos_token:
    train_for_sft = train_formatted.map(_append_eos, fn_kwargs={"eos": eos_token})
    val_for_sft = val_formatted.map(_append_eos, fn_kwargs={"eos": eos_token})
else:
    train_for_sft, val_for_sft = train_formatted, val_formatted

# Start a fresh W&B run (closing any leftover run from a previous execution).
if wandb_enabled:
    if wandb.run is not None:
        wandb.finish()
    wandb.init(project="llama-medical-cot", name=run_name)

trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_for_sft,
    eval_dataset=val_for_sft,
    tokenizer=tokenizer,
    dataset_text_field="text",
    max_seq_length=2048,  # same limit the model was loaded with
    packing=False,        # one example per sequence
)

print("Starting fine-tuning...")
train_result = trainer.train()
print("Training complete!")

## 5. Evaluation and Saving

### Evaluate the Model, Save It Locally, and Upload It to the Hugging Face Hub

In [None]:
# Evaluate, save locally, then (best effort) publish to the Hugging Face Hub.
import json
import os

from huggingface_hub import login

# Evaluate on the validation split given to the trainer as eval_dataset.
print("Evaluating model...")
eval_results = trainer.evaluate()
print("Evaluation results:")
for key, value in eval_results.items():
    if isinstance(value, (int, float)):
        print(f"  {key}: {value:.4f}")

# Persist everything locally BEFORE any network step, so a missing secret or an
# auth/upload failure cannot cost the trained weights or the metrics.
print("Saving model...")
local_path = "/kaggle/working/llama-3b-medical-cot"
model.save_pretrained(local_path)
tokenizer.save_pretrained(local_path)
print(f"Model saved locally to {local_path}")

with open("/kaggle/working/eval_results.json", "w") as f:
    json.dump(eval_results, f, indent=2)
print("Evaluation results saved to eval_results.json")

# Read the Hub token from Kaggle secrets. Outside Kaggle, or without the secret,
# hf_token stays None and the upload is skipped instead of crashing the cell.
hf_token = None
try:
    from kaggle_secrets import UserSecretsClient
    hf_token = UserSecretsClient().get_secret("HF_TOKEN")
except Exception as e:
    print(f"Warning: could not read HF_TOKEN from Kaggle secrets: {e}")

print("Uploading to Hugging Face Hub...")
repo_name = "AzzamShahid/llama-3b-medical-cot"  # Your HuggingFace username
if hf_token:
    try:
        login(token=hf_token)
        model.push_to_hub(repo_name, token=hf_token)
        tokenizer.push_to_hub(repo_name, token=hf_token)
        print(f"Model successfully uploaded to: https://huggingface.co/{repo_name}")
    except Exception as e:
        print(f"Error uploading to Hub: {e}")
        print("Model is saved locally and can be downloaded from Kaggle output")
else:
    print("Skipping upload: no Hugging Face token available.")
    print("Model is saved locally and can be downloaded from Kaggle output")

### Define the Inference Function and Test It on Sample Questions

In [None]:
def build_medical_prompt(question):
    """Return the exact prompt prefix used by `format_medical_cot` during training.

    Training strings are the instruction line, "Question: <q>", and then directly
    "<think>", with single newlines in between. Inference must use the identical
    prefix (no extra blank lines or "Answer:" preamble) for the learned format to apply.
    """
    return (
        "Below is a medical question. Think step by step to solve it.\n"
        f"Question: {question}\n"
    )


def generate_medical_answer(question, model=None, tokenizer=None, max_new_tokens=512):
    """Generate a medical answer with CoT reasoning using the fine-tuned model.

    Args:
        question: free-text medical question.
        model: the causal LM to sample from (required despite the None default).
        tokenizer: tokenizer matching `model` (required despite the None default).
        max_new_tokens: generation budget for the reasoning plus the response.

    Returns:
        The decoded continuation (prompt tokens excluded), stripped of surrounding
        whitespace. If anything fails, returns a string starting with
        "Error generating response:" instead of raising.
    """
    try:
        if model is None or tokenizer is None:
            raise ValueError("model and tokenizer must be provided")

        # Same prompt layout as the training data.
        prompt = build_medical_prompt(question)

        # Put the Unsloth model into inference mode before generating.
        FastLanguageModel.for_inference(model)

        # Right-side truncation at 512 tokens; only very long questions are affected.
        inputs = tokenizer(
            [prompt],
            return_tensors="pt",
            truncation=True,
            max_length=512,
        )

        # Place the input tensors on whatever device holds the model's parameters.
        device = next(model.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():  # no autograd bookkeeping during sampling
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=0.7,
                do_sample=True,
                top_p=0.95,
                top_k=40,
                repetition_penalty=1.1,
                use_cache=True,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )

        # Decode only the newly generated tokens, not the echoed prompt.
        input_length = inputs["input_ids"].shape[1]
        response = tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True)
        return response.strip()

    except Exception as e:
        return f"Error generating response: {str(e)}"

# Smoke test: run a few hand-written clinical questions through the fine-tuned model.
def test_model(model, tokenizer):
    """Print the model's answers to a fixed set of sample medical questions.

    Args:
        model: fine-tuned model, forwarded to generate_medical_answer.
        tokenizer: matching tokenizer, forwarded to generate_medical_answer.
    """
    sample_questions = (
        "What are potential causes of acute chest pain, and how would you differentiate between them?",
        "A 65-year-old patient presents with shortness of breath and ankle swelling. What should be considered in the differential diagnosis?",
        "What are the key differences between Type 1 and Type 2 diabetes mellitus?",
    )
    banner = "=" * 50

    print("Testing fine-tuned medical model:")
    print(banner)

    for idx, question in enumerate(sample_questions):
        print(f"\nTest {idx + 1}:")
        print(f"Question: {question}")
        print("-" * 30)

        answer = generate_medical_answer(question, model, tokenizer)
        print(f"Model Response:\n{answer}")
        print(banner)

# Run the smoke test only when both objects exist in the notebook namespace;
# any exception is reported instead of stopping the notebook.
if 'model' not in locals() or 'tokenizer' not in locals():
    print("Model or tokenizer not available for testing")
else:
    try:
        test_model(model, tokenizer)
    except Exception as e:
        print(f"Error during testing: {e}")

## 6. Summary

###  Display Training Summary

In [None]:
print("\n==== Training Summary ====")

# Dataset: both formatted splits must exist in the namespace and be non-empty.
try:
    if not locals().keys() >= {'train_formatted', 'val_formatted'}:
        print("❌ Dataset formatting failed")
    elif len(train_formatted) == 0 or len(val_formatted) == 0:
        print("❌ Dataset formatted but empty")
    else:
        print("✅ Dataset loaded and formatted successfully")
        print(f"   Training samples: {len(train_formatted)}")
        print(f"   Validation samples: {len(val_formatted)}")
except Exception as e:
    print(f"❌ Dataset loading failed: {str(e)}")

# Model: must be defined and not None.
try:
    if locals().get('model') is None:
        print("❌ Model loading failed")
    else:
        print("✅ Model loaded successfully")
        print(f"   Model type: {type(model).__name__}")
except Exception as e:
    print(f"❌ Model loading failed: {str(e)}")

# Tokenizer: must be defined and not None.
try:
    if locals().get('tokenizer') is None:
        print("❌ Tokenizer loading failed")
    else:
        print("✅ Tokenizer loaded successfully")
except Exception as e:
    print(f"❌ Tokenizer loading failed: {str(e)}")

# Training: report the final train_loss from trainer.train()'s result when available.
try:
    if locals().get('train_result') is None:
        print("❌ Training not completed")
    else:
        print("✅ Training completed successfully")
        if not hasattr(train_result, 'metrics'):
            print("   Training metrics: Not available")
        else:
            final_loss = train_result.metrics.get('train_loss', 'N/A')
            if final_loss == 'N/A':
                print("   Training loss: Not available")
            else:
                print(f"   Final training loss: {final_loss:.4f}")
except Exception as e:
    print(f"❌ Training check failed: {str(e)}")

# Evaluation: show eval_loss plus any common extra metrics that happen to be present.
try:
    if locals().get('eval_results') is None:
        print("❌ Evaluation not performed")
    else:
        print("✅ Evaluation completed successfully")
        if not isinstance(eval_results, dict):
            print("   Evaluation results format unexpected")
        else:
            eval_loss = eval_results.get('eval_loss', 'N/A')
            if eval_loss != 'N/A':
                print(f"   Evaluation loss: {eval_loss:.4f}")

            for metric_key in ('eval_accuracy', 'eval_f1', 'eval_bleu'):
                if metric_key not in eval_results:
                    continue
                metric_label = metric_key.replace('eval_', '').title()
                print(f"   {metric_label}: {eval_results[metric_key]:.4f}")
except Exception as e:
    print(f"❌ Evaluation check failed: {str(e)}")

# Check model saving
def summarize_saved_files(file_names):
    """Report which artifacts appear in a saved-model directory listing.

    Args:
        file_names: iterable of file names (e.g. the result of os.listdir).

    Returns:
        (has_weights, has_tokenizer).
        has_weights is True when any name contains "adapter_model", "pytorch_model"
        or ".safetensors", which covers PEFT adapters (.safetensors or .bin) and
        sharded checkpoints. has_tokenizer is True when any name contains "tokenizer".
    """
    names = list(file_names)
    weight_markers = ("adapter_model", "pytorch_model", ".safetensors")
    has_weights = any(marker in name for name in names for marker in weight_markers)
    has_tokenizer = any("tokenizer" in name for name in names)
    return has_weights, has_tokenizer


try:
    import os
    local_path = "/kaggle/working/llama-3b-medical-cot"
    # isdir (not exists) so a stray file at this path is not mistaken for a save directory.
    if os.path.isdir(local_path):
        print(f"✅ Model saved locally to {local_path}")
        model_files = os.listdir(local_path)
        has_weights, has_tokenizer = summarize_saved_files(model_files)
        # Report missing artifacts explicitly instead of printing nothing.
        print("   ✅ Model files found" if has_weights else "   ⚠️  No model weight files found")
        print("   ✅ Tokenizer files found" if has_tokenizer else "   ⚠️  No tokenizer files found")
    else:
        print("❌ Local model save directory not found")
except Exception as e:
    print(f"❌ Model save check failed: {str(e)}")

# Hugging Face token: present (not None) means the Hub upload step had credentials.
try:
    if locals().get('hf_token') is None:
        print("⚠️  HuggingFace token not found - upload may fail")
    else:
        print("✅ HuggingFace token loaded")
        print("✅ Ready for HuggingFace upload!")
except Exception as e:
    print(f"❌ HuggingFace token check failed: {str(e)}")

print("\n" + "=" * 50)
print("Training pipeline status check complete!")

# Overall verdict: one boolean per pipeline stage, in display order
# (Dataset, Model, Tokenizer, Training, Evaluation).
all_components = [
    len(locals().get('train_formatted', ())) > 0,
    locals().get('model') is not None,
    locals().get('tokenizer') is not None,
    locals().get('train_result') is not None,
    locals().get('eval_results') is not None,
]

if not all(all_components):
    component_names = ['Dataset', 'Model', 'Tokenizer', 'Training', 'Evaluation']
    failed_components = [
        stage for stage, ok in zip(component_names, all_components) if not ok
    ]
    print(f"⚠️  Some components failed: {', '.join(failed_components)}")
    print("   Please check the errors above before proceeding.")
else:
    print("🎉 All components successful - Model is ready for deployment!")