# **Medical LLM Fine-Tuning Notebook**

## **1. Setup & Installation**
```python
!pip install unsloth
!pip install --force-reinstall --no-cache-dir --no-deps git+https://github.com/unslothai/unsloth.git
!pip install trl==0.14.0 peft==0.14.0 xformers==0.0.28.post3
!pip install torch==2.5.1 torchvision==0.20.1 --index-url https://download.pytorch.org/whl/cu124
!pip install --upgrade datasets huggingface_hub evaluate wandb
```
**Why?**  
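- **Unsloth**: Patches the model and LoRA training loop for speed and memory (the project advertises roughly 2x faster training)  
- **TRL/PEFT**: TRL supplies `SFTTrainer`; PEFT supplies the LoRA adapters  
- **Torch 2.5.1**: Pinned alongside the matching `xformers` build, since `xformers` wheels target a specific PyTorch release  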
- **WandB**: Tracks training metrics  
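Pins only help if they actually took effect, because a later `pip install` can silently upgrade `torch`. Here is a minimal sketch for checking installed versions against the pins using the standard library's `importlib.metadata`. The helper name `check_pins` is hypothetical, and the `PINS` dictionary just restates the versions from the install cell.

```python
from importlib.metadata import PackageNotFoundError, version

# Versions pinned in the install cell above (restated here, not read from pip).
PINS = {
    "trl": "0.14.0",
    "peft": "0.14.0",
    "xformers": "0.0.28.post3",
    "torch": "2.5.1",
    "torchvision": "0.20.1",
}

def check_pins(pins):
    """Return {package: installed_version_or_None} for every package whose
    installed version does not match its pin. A local build tag such as the
    '+cu124' in '2.5.1+cu124' is ignored when comparing."""
    mismatches = {}
    for package, pinned in pins.items():
        try:
            installed = version(package)
        except PackageNotFoundError:
            mismatches[package] = None  # not installed at all
            continue
        if installed.split("+")[0] != pinned:
            mismatches[package] = installed
    return mismatches

print(check_pins(PINS) or "All pinned packages match.")
```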

---

## **2. Initialize Model**
```python
from unsloth import FastLanguageModel
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
    max_seq_length=2048,
    load_in_4bit=True,  # Quantization for memory efficiency
    token=hf_token
)
```
**Key Choices:**  
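- **4-bit Loading**: Cuts weight memory by ~75% relative to 16-bit weights  
- **2048 seq_len**: Accommodates long medical reasoning chains  
- **DeepSeek-R1-Distill-Llama-8B**: A Llama-based model distilled from DeepSeek-R1 that already writes step-by-step reasoning inside `<think>` tags  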
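The ~75% figure comes from simple arithmetic on weight storage: 4 bits per parameter instead of 16. The sketch below assumes a round 8.0 billion parameters. It covers weights only and ignores quantization constants, layers usually left in 16-bit (such as embeddings and the output head), activations, the KV cache, and optimizer state, so real VRAM use is higher.

```python
def weight_memory_gb(n_params, bits_per_param):
    """Memory needed to store n_params weights at the given precision, in GB (1e9 bytes)."""
    return n_params * bits_per_param / 8 / 1e9

N_PARAMS = 8.0e9  # approximate parameter count of an 8B model (assumption)

fp16_gb = weight_memory_gb(N_PARAMS, 16)  # 16-bit weights
nf4_gb = weight_memory_gb(N_PARAMS, 4)    # 4-bit weights
reduction = 1 - nf4_gb / fp16_gb

print(f"16-bit: {fp16_gb:.1f} GB, 4-bit: {nf4_gb:.1f} GB, saving {reduction:.0%}")
# 16-bit: 16.0 GB, 4-bit: 4.0 GB, saving 75%
```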

---

## **3. Pre-Fine-Tuning Test**
```python
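def generate_response(question):
    # Same header layout as the training prompts in step 4, ending at the opening <think> tag
    prompt = f"\n### Clinical Case: {question}\n### Analysis: <think>"
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
    outputs = model.generate(**inputs, max_new_tokens=1200)  # Reasoning chains are long
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

test_case = "A 61-year-old woman with urinary incontinence..."
pre_tuning_response = generate_response(test_case)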
**Purpose:**  
Establishes baseline performance before fine-tuning  

---

## **4. Dataset Preparation**
```python
from datasets import load_dataset

dataset = load_dataset("FreedomIntelligence/medical-o1-reasoning-SFT", "en", split="train[:500]")
dataset = dataset.train_test_split(test_size=0.1)  # 450 train / 50 test

def format_prompt(example):
    return f"""
### Clinical Case: {example['Question']}
### Analysis: <think>{example['Complex_CoT']}</think>
### Diagnosis: {example['Response']}
"""
```
**Why This Format?**  
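- **Fixed section headers** give the model one consistent template to learn, and the same headers are reused in the inference prompt  
- **Chain-of-Thought (CoT)**: Wrapping the dataset's `Complex_CoT` field in `<think>` tags trains the model to write out its reasoning before the final answer  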
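Generation simply continues whatever text it is given, so the inference prompt should be an exact prefix of the training text; the model then picks up at the point where it learned to reason. A minimal sketch of that check follows, using the `### Clinical Case:` / `### Step-by-Step Analysis:` / `### Final Assessment:` template from the full notebook below (the same idea applies to the condensed template above). The names `TRAIN_TEMPLATE`, `build_training_text`, and `build_inference_prompt` are hypothetical.

```python
# Training template, mirroring format_prompt in the notebook cells below.
TRAIN_TEMPLATE = """
### Clinical Case:
{question}

### Step-by-Step Analysis:
<think>
{cot}
</think>

### Final Assessment:
{response}
"""

def build_training_text(question, cot, response):
    return TRAIN_TEMPLATE.format(question=question, cot=cot, response=response)

def build_inference_prompt(question):
    # Everything up to and including the opening <think> tag; the model
    # is expected to continue with its reasoning from here.
    return f"""
### Clinical Case:
{question}

### Step-by-Step Analysis:
<think>
"""

q = "A 61-year-old woman with urinary incontinence..."
train_text = build_training_text(q, "Reasoning goes here.", "Answer goes here.")
prompt = build_inference_prompt(q)
print(train_text.startswith(prompt))  # True
```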

---

## **5. LoRA Configuration**
```python
model_lora = FastLanguageModel.get_peft_model(
    model,
    r=16,  # LoRA rank
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    lora_alpha=16,
    lora_dropout=0  # Disabled for Unsloth optimization
)
```
**LoRA Parameters Explained:**  
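- **r=16**: Rank of the low-rank update; higher ranks add capacity and trainable parameters, and with them more room to overfit a small dataset  
- **Target modules**: The attention projections that receive LoRA adapters (the full notebook below also adapts the MLP projections `gate_proj`, `up_proj`, `down_proj`)  
- **alpha=16**: The LoRA update is scaled by alpha/r, so alpha = r gives a scaling factor of 1  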
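To make `r` and `lora_alpha` concrete: each adapted linear layer mapping d_in to d_out gets two small matrices, A (r × d_in) and B (d_out × r). That adds r·(d_in + d_out) trainable parameters, and the update BA is scaled by alpha/r. The rough count below covers the four attention projections above and assumes Llama-3.1-8B dimensions (hidden size 4096, grouped-query attention with 8 key/value heads of size 128, 32 layers), which the notebook itself does not state.

```python
# Assumed Llama-3.1-8B attention shapes (not stated in the notebook).
HIDDEN = 4096
KV_DIM = 8 * 128  # grouped-query attention: k_proj and v_proj output 1024 features
N_LAYERS = 32

# (in_features, out_features) for each targeted projection
ATTN_SHAPES = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, KV_DIM),
    "v_proj": (HIDDEN, KV_DIM),
    "o_proj": (HIDDEN, HIDDEN),
}

def lora_params(shapes, r, n_layers):
    # A LoRA adapter on a (d_in -> d_out) layer adds A (r x d_in) and B (d_out x r).
    per_layer = sum(r * (d_in + d_out) for d_in, d_out in shapes.values())
    return per_layer * n_layers

r, lora_alpha = 16, 16
print(f"trainable LoRA params: {lora_params(ATTN_SHAPES, r, N_LAYERS):,}")  # 13,631,488
print(f"update scaling alpha/r = {lora_alpha / r}")  # 1.0
```

Roughly 13.6M trainable parameters is a small fraction of an 8B model, which is why LoRA training fits alongside a 4-bit base model.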

---

## **6. Training Setup**
```python
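from transformers import TrainingArguments
from trl import SFTTrainer

# Turn each example into a single "text" string using format_prompt from step 4
dataset = dataset.map(lambda example: {"text": format_prompt(example)})

training_args = TrainingArguments(
    output_dir="./medical-lora",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,  # Effective batch size = 2 x 4 = 8
    learning_rate=2e-4,  # Common starting point for LoRA
    eval_strategy="steps",
    eval_steps=50  # Validate every 50 steps
)

trainer = SFTTrainer(
    model=model_lora,
    tokenizer=tokenizer,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],  # Needed when eval_strategy is set
    dataset_text_field="text",
    max_seq_length=2048,
    args=training_args
)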
```
**Training Strategy:**  
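- Small per-device batch size → Fits in the memory of a free-tier Colab GPU alongside the 4-bit model  
- Gradient accumulation → Simulates an effective batch of 8 without the extra memory  
- Periodic eval on the held-out split → Helps catch overfitting early  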
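Here is a quick check of what these settings mean for the run in the full notebook below (500 examples split 90/10, one epoch, a single GPU assumed). The step count is approximate, because the way the Trainer rounds partial accumulation windows depends on the `transformers` version.

```python
train_examples = 450  # 500 examples after train_test_split(test_size=0.1)
per_device_batch = 2
grad_accum = 4
num_gpus = 1          # assumption: a single Colab GPU
num_epochs = 1
eval_steps = 50

effective_batch = per_device_batch * grad_accum * num_gpus  # 8
steps_per_epoch = train_examples // effective_batch          # ~56 optimizer steps
total_steps = steps_per_epoch * num_epochs
num_evals = total_steps // eval_steps

print(effective_batch, total_steps, num_evals)  # 8 56 1
```

With only about 56 optimizer steps, `eval_steps=50` yields a single mid-run evaluation. A smaller value (for example 10) would give a more useful validation curve on a dataset this size.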

---

## **7. Post-Training Evaluation**
```python
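post_tuning_response = generate_response(test_case)

def compare_responses(question, pre, post):
    # Show only the text after the last "Diagnosis:" header (or the whole response if it is absent)
    print(f"QUESTION: {question}")
    print(f"PRE: {pre.split('Diagnosis:')[-1]}")
    print(f"POST: {post.split('Diagnosis:')[-1]}")

compare_responses(test_case, pre_tuning_response, post_tuning_response)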
```
**Expected Improvements** (to check qualitatively against the pre-tuning output):  
1. More accurate diagnoses  
2. Better explanation of pathophysiology  
3. Reduced hallucinations  
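One way to put a number on these improvements is to score generated final answers against the dataset's reference `Response` on the held-out split. The notebook installs `evaluate` and `rouge_score` for this but never calls them. The sketch below is a deliberately simplified stand-in: a unigram-overlap F1 in the spirit of ROUGE-1 (no stemming, naive tokenization). It is not the `rouge_score` implementation, and the function names are hypothetical.

```python
import re
from collections import Counter

def tokenize(text):
    # Lowercase and keep alphanumeric runs; a simplification of ROUGE tokenization.
    return re.findall(r"[a-z0-9]+", text.lower())

def unigram_f1(prediction, reference):
    pred, ref = Counter(tokenize(prediction)), Counter(tokenize(reference))
    overlap = sum((pred & ref).values())  # clipped count of shared unigrams
    if overlap == 0:
        return 0.0
    precision = overlap / sum(pred.values())
    recall = overlap / sum(ref.values())
    return 2 * precision * recall / (precision + recall)

reference = "Stress urinary incontinence"
print(round(unigram_f1("stress incontinence", reference), 3))  # 0.8
```

For reported numbers, the `evaluate` library's standard metric (`evaluate.load("rouge").compute(predictions=..., references=...)`) is the more appropriate route. Word overlap still says nothing about clinical correctness, so manual review remains necessary.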

---

## **8. Gradio Deployment**
```python
import gradio as gr

demo = gr.Interface(
    fn=generate_response,
    inputs=gr.Textbox(lines=5),
    outputs="text",
    examples=[test_case]
)
demo.launch()
```
**Deployment Notes:**  
- Runs locally in Colab  
- In Colab, `launch()` creates a temporary public link by default; elsewhere, pass `share=True` to get one  

---

## **Key Takeaways**
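1. **LoRA Efficiency**: Only the small adapter matrices are trained, on a 500-example subset (450 train / 50 validation)  
2. **Unsloth Benefits**: Unsloth advertises roughly 2x faster training than a standard PEFT setup  
3. **Medical Specialization**: Training on clinical CoT traces teaches the model the dataset's reasoning and answer format  
4. **Quantization**: A 4-bit base model plus LoRA adapters makes fine-tuning feasible on consumer GPUs  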


```python
# Install required dependencies
!pip install unsloth
!pip install --force-reinstall --no-cache-dir --no-deps git+https://github.com/unslothai/unsloth.git
!pip install trl==0.14.0 peft==0.14.0 xformers==0.0.28.post3
!pip install torch==2.5.1 torchvision==0.20.1 --index-url https://download.pytorch.org/whl/cu124
```

```python
!pip install --upgrade datasets huggingface_hub
!pip install -qU evaluate rouge_score wandb
```


```python
from unsloth import FastLanguageModel
import torch
from trl import SFTTrainer
from unsloth import is_bfloat16_supported
from huggingface_hub import login
from transformers import TrainingArguments
from datasets import load_dataset
import wandb
```

```python
from google.colab import userdata
hf_token = userdata.get('HF_TOKEN')
login(hf_token)
```

```python
from unsloth import FastLanguageModel
from transformers import TextStreamer

# Configuration
model_name = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
max_seq_length = 2048  # For long medical reasoning chains

# Load 4-bit quantized model
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=model_name,
    max_seq_length=max_seq_length,
    dtype=None,  # Let Unsloth pick bfloat16 or float16 for the GPU
    load_in_4bit=True,
    token=hf_token  # Read from Colab secrets in the previous cell
)

# Set pad token if missing
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
```

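```python
def generate_response(question):
    # Inference prompt: the training template up to the opening <think> tag,
    # so the model continues with its reasoning
    prompt = f"""
### Clinical Case:
{question}

### Step-by-Step Analysis:
<think>
"""
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to("cuda")
    streamer = TextStreamer(tokenizer)  # Prints tokens as they are generated

    # `model` is looked up at call time, so this also works after the model is reloaded
    outputs = model.generate(
        **inputs,
        max_new_tokens=1200,
        do_sample=True,  # Sampled output, so pre/post comparisons vary between runs
        temperature=0.7,
        streamer=streamer,
        pad_token_id=tokenizer.eos_token_id
    )
    # Returns the prompt plus the generated continuation
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
```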

```python
# Test Case 1: Urinary incontinence (a classic clinical reasoning challenge)
question = """A 61-year-old woman presents with involuntary urine leakage when coughing or sneezing,
but no nighttime symptoms. Gynecological exam shows a hypermobile urethra.
What is the most likely diagnosis and what would cystometry show?"""

print("=== Pre-Fine-Tuning Response ===")
pre_tuning_response = generate_response(question)
```


```python
from datasets import load_dataset

# Load the first 500 English examples
dataset = load_dataset(
    "FreedomIntelligence/medical-o1-reasoning-SFT", "en",
    split="train[:500]"
)

# Split into train/validation (450 / 50)
dataset = dataset.train_test_split(test_size=0.1, seed=42)

# Format a single example. The trailing EOS token marks where an answer ends,
# so the model can learn to stop instead of running on to max_new_tokens.
def format_prompt(example):
    return f"""
### Clinical Case:
{example['Question']}

### Step-by-Step Analysis:
<think>
{example['Complex_CoT']}
</think>

### Final Assessment:
{example['Response']}
""" + tokenizer.eos_token

# Apply formatting to a batch of examples
def preprocess_function(batch):
    return {"text": [format_prompt({
        "Question": q,
        "Complex_CoT": cot,
        "Response": r
    }) for q, cot, r in zip(
        batch["Question"],
        batch["Complex_CoT"],
        batch["Response"]
    )]}

# Map the formatting across both splits
dataset = dataset.map(
    preprocess_function,
    batched=True,
    remove_columns=dataset["train"].column_names  # Keep only the new "text" column
)

# Preview one fully formatted example
print(dataset["train"][0]["text"])
```
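Examples longer than `max_seq_length` (2048 tokens) are truncated during training, which can cut off the `### Final Assessment:` section the model is meant to learn. Here is a minimal sketch for counting such examples. It accepts any callable that maps text to a list of token ids, so in the notebook you could pass something like `lambda t: tokenizer(t)["input_ids"]` over `dataset["train"]["text"]`. The helper name is hypothetical, and the whitespace split in the usage line is only a stand-in tokenizer.

```python
def count_over_length(texts, encode, max_len):
    """Return (number of texts longer than max_len tokens, longest length seen)."""
    lengths = [len(encode(t)) for t in texts]
    too_long = sum(1 for n in lengths if n > max_len)
    return too_long, max(lengths, default=0)

# Stand-in tokenizer for illustration: whitespace split.
sample = ["short example", "a much longer example " * 10]
print(count_over_length(sample, str.split, max_len=20))  # (1, 40)
```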

```python
# Attach LoRA adapters to the 4-bit base model loaded earlier
model_lora = FastLanguageModel.get_peft_model(
    model,
    r=16,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    lora_alpha=16,
    lora_dropout=0,
    bias="none",
    use_gradient_checkpointing="unsloth",  # Unsloth's memory-saving checkpointing
    random_state=3047,
    use_rslora=False,
    loftq_config=None
)

# Training arguments
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./medical-lora",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    warmup_steps=10,
    num_train_epochs=1,
    learning_rate=2e-4,
    fp16=not torch.cuda.is_bf16_supported(),
    bf16=torch.cuda.is_bf16_supported(),
    logging_steps=10,
    eval_strategy="steps",  # Named evaluation_strategy in older transformers releases
    eval_steps=50,
    save_strategy="steps",
    save_steps=100,
    report_to="wandb"
)

# Initialize the trainer on the LoRA-wrapped model
from trl import SFTTrainer

trainer = SFTTrainer(
    model=model_lora,
    tokenizer=tokenizer,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    args=training_args
)
```

```python
import wandb
wandb.init(project="medical-llm-finetuning")

# Start training
trainer.train()

# Save the LoRA adapter weights (not the full model) and the tokenizer
model_lora.save_pretrained("medical-deepseek-lora")
tokenizer.save_pretrained("medical-deepseek-lora")

wandb.finish()
```

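```python
# Reload the saved adapter for inference; Unsloth loads the base model
# recorded in the adapter config and attaches the LoRA weights
model, tokenizer = FastLanguageModel.from_pretrained(
    "medical-deepseek-lora",
    max_seq_length=max_seq_length,
    dtype=None,
    load_in_4bit=True,
)
FastLanguageModel.for_inference(model)  # Switch Unsloth to its inference mode
```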


```python
# Test Case 1 revisited: compare responses to the same question
print("=== Post-Fine-Tuning Response ===")
post_tuning_response = generate_response(question)
```


```python
question2 = """A 42-year-old IV drug user presents with fever, shortness of breath,
and a new murmur. Blood cultures grow gram-positive cocci in clusters.
Echocardiography shows vegetations on the tricuspid valve.
What is the most likely organism and why is this valve involved?"""

print("=== Post-Fine-Tuning Response (New Case) ===")
post_tuning_response2 = generate_response(question2)
```


````python
from IPython.display import display, Markdown

def compare_responses(question, pre, post):
    # Keep only the text after the last "### Final Assessment:" header;
    # if the header is missing (typical before fine-tuning), the whole response is shown
    display(Markdown(f"""
### **Question**:
{question}

#### **Pre-Fine-Tuning**:
```text
{pre.split('### Final Assessment:')[-1].strip()}
```

#### **Post-Fine-Tuning**:
```text
{post.split('### Final Assessment:')[-1].strip()}
```
"""))

compare_responses(question, pre_tuning_response, post_tuning_response)
````
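`compare_responses` takes everything after the last `### Final Assessment:` header. The pre-fine-tuning model was never trained on that header, so its output usually lacks it, and the whole decoded text (prompt included) ends up on screen. Below is a hypothetical helper that falls back to the text after the closing `</think>` tag, assuming the tag survives decoding, and otherwise returns the full response.

```python
def extract_final_answer(response, header="### Final Assessment:"):
    """Best-effort extraction of the final answer from a decoded response."""
    if header in response:
        return response.split(header)[-1].strip()
    if "</think>" in response:
        # No header: fall back to whatever follows the reasoning block.
        return response.split("</think>")[-1].strip()
    return response.strip()

print(extract_final_answer("<think>\nreasoning\n</think>\n\nStress incontinence."))
# Stress incontinence.
```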