# OncoScope: Fine-tune Gemma 3n for Cancer Genomics Analysis
Adapted from Unsloth's official Gemma 3n notebook and specialized for cancer mutation analysis.

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/yourusername/oncoscope/blob/main/oncoscope_gemma_3n_colab.ipynb)

## 🚀 Quick Start
This notebook can be run with **Runtime > Run all** for automatic execution, or step-by-step for more control.
- **Auto-configuration**: Detects your GPU (A100, L4, or T4 and smaller) and sets batch size, LoRA rank, and sequence length to match
- **Checkpoint persistence**: Saves checkpoints to Google Drive so a Colab disconnect does not lose progress
- **Production-ready**: Includes both automated and manual configuration options

## 1. Mount Google Drive (Save Progress!)
**IMPORTANT**: Mount Drive first to prevent losing checkpoints if Colab disconnects

In [None]:
from google.colab import drive
drive.mount('/content/drive')

# Create checkpoint directory in Drive
import os
checkpoint_dir = "/content/drive/MyDrive/oncoscope_checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)
print(f"✅ Checkpoints will be saved to: {checkpoint_dir}")

## 2. Install Dependencies
Using Unsloth's official installation method for Gemma 3n support

In [None]:
%%capture
# Install Unsloth with Gemma 3n support - official recommended method
!pip install "unsloth[gemma3n] @ git+https://github.com/unslothai/unsloth.git"
!pip install --upgrade "datasets>=2.16.0" "trl>=0.8.3"

## 3. Upload Your Cancer Training Data
Upload your `cancer_training_data.json` file when prompted

In [None]:
from google.colab import files
import json

print("Please upload your cancer_training_data.json file...")
uploaded = files.upload()

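# Load the first uploaded file (Colab may rename duplicates, so the name is not hard-coded)
if not uploaded:
    raise RuntimeError("No file was uploaded - re-run this cell and select cancer_training_data.json")
file_name = next(iter(uploaded))
with open(file_name, 'r', encoding='utf-8') as f:
    cancer_training_data = json.load(f)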

print(f"✅ Loaded {len(cancer_training_data)} cancer genomics training examples")
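Section 7 reads `ex['input']` and `ex['output']` from every record, so a malformed record would only surface there as a `KeyError`. The following is a minimal sketch of an up-front check, assuming the file is a JSON list of objects with non-empty string `input` and `output` fields. `find_invalid_examples` is a hypothetical helper written for this illustration, not part of Unsloth or Colab, and the sample records are placeholders rather than real training data.

```python
def find_invalid_examples(examples):
    """Return (index, reason) pairs for records the notebook cannot use."""
    problems = []
    for i, ex in enumerate(examples):
        if not isinstance(ex, dict):
            problems.append((i, "record is not a JSON object"))
            continue
        for key in ("input", "output"):
            value = ex.get(key)
            # Both fields must be present and contain non-whitespace text
            if not isinstance(value, str) or not value.strip():
                problems.append((i, f"missing or empty '{key}'"))
    return problems


# Illustrative records only, not real training data
sample = [
    {"input": "Describe KRAS G12C.", "output": "..."},
    {"input": "Describe TP53 R175H."},
    "not a record",
]
for index, reason in find_invalid_examples(sample):
    print(f"Example {index}: {reason}")
```

In the notebook this could be called as `find_invalid_examples(cancer_training_data)` right after loading, before any GPU time is spent.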

## 4. Auto-Detect GPU and Configure Settings

In [None]:
import torch

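# Auto-detect GPU and configure settings
if not torch.cuda.is_available():
    raise RuntimeError("No GPU detected - select a GPU under Runtime > Change runtime type")
gpu_name = torch.cuda.get_device_name(0)
vram_gb = torch.cuda.get_device_properties(0).total_memory / 1e9  # total, not free, VRAM

print(f"Detected GPU: {gpu_name}")
print(f"Total VRAM: {vram_gb:.1f} GB")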

# Note: Unsloth handles Conv2D autocast issues for float16 GPUs internally
# No manual intervention needed for T4 GPUs

# Auto-configure based on GPU
if "A100" in gpu_name:
    batch_size = 4
    grad_accum = 1
    lora_r = 32
    lora_alpha = 64
    max_memory = "40GB"
    max_seq_length = 2048
    print("✅ A100 detected - Using optimal settings!")
elif "L4" in gpu_name:
    batch_size = 2
    grad_accum = 2
    lora_r = 16
    lora_alpha = 32
    max_memory = "22GB"
    max_seq_length = 1536
    print("✅ L4 detected - Using balanced settings")
else:  # T4 or smaller
    batch_size = 1
    grad_accum = 4
    lora_r = 8
    lora_alpha = 8
    max_memory = "14GB"
    max_seq_length = 1024
    print("✅ T4/smaller GPU - Using memory-efficient settings")
    print("📝 Note: Unsloth automatically handles float16 Conv2D issues")

print(f"\nConfiguration:")
print(f"- Batch size: {batch_size}")
print(f"- Gradient accumulation: {grad_accum}")
print(f"- Effective batch size: {batch_size * grad_accum}")
print(f"- LoRA rank: {lora_r}")
print(f"- Max sequence length: {max_seq_length}")
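The selection logic above can also be written as a lookup table, which makes it possible to check the tiers without a GPU attached. This is a sketch only: the values are copied from the cell above, while `GpuProfile`, `PROFILES`, `FALLBACK`, and `pick_profile` are hypothetical names, and the device-name strings are examples of what `torch.cuda.get_device_name` might return.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class GpuProfile:
    batch_size: int
    grad_accum: int
    lora_r: int
    lora_alpha: int
    max_memory: str
    max_seq_length: int


# Checked in order; the first marker found in the device name wins
PROFILES = [
    ("A100", GpuProfile(4, 1, 32, 64, "40GB", 2048)),
    ("L4", GpuProfile(2, 2, 16, 32, "22GB", 1536)),
]
# T4 or smaller
FALLBACK = GpuProfile(1, 4, 8, 8, "14GB", 1024)


def pick_profile(gpu_name):
    """Return the profile whose marker appears in gpu_name, else the fallback."""
    for marker, profile in PROFILES:
        if marker in gpu_name:
            return profile
    return FALLBACK


# Example device-name strings (assumed, for illustration)
for name in ("NVIDIA A100-SXM4-40GB", "NVIDIA L4", "Tesla T4"):
    p = pick_profile(name)
    print(f"{name}: effective batch size = {p.batch_size * p.grad_accum}")
```

Every tier works out to an effective batch size of 4. What changes between tiers is how that batch is split between per-device batch size and gradient accumulation, along with LoRA rank and sequence length.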

## 5. Load Gemma 3n Model

In [None]:
from unsloth import FastModel
import torch

# Use E2B (2B) model with auto-detected settings
model, tokenizer = FastModel.from_pretrained(
    model_name = "unsloth/gemma-3n-E2B-it-unsloth-bnb-4bit",
    dtype = None,  # Auto detection
    max_seq_length = max_seq_length,  # From auto-config
    load_in_4bit = True,
    full_finetuning = False,  # Use LoRA
    device_map = "auto",
    max_memory = {0: max_memory, "cpu": "20GB"},  # From auto-config
)

print(f"Model loaded! GPU: {torch.cuda.get_device_name(0)}")
print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

## 6. Configure LoRA Adapters
### Option A: Auto-Configured LoRA (Recommended)

In [None]:
# Configure for text-only cancer genomics with auto-detected settings
model = FastModel.get_peft_model(
    model,
    finetune_vision_layers     = False,  # Disable vision - saves VRAM!
    finetune_language_layers   = True,   # Text only for cancer analysis
    finetune_attention_modules = True,   # Good for specialized domain
    finetune_mlp_modules       = True,   # Keep for performance
    
    r = lora_r,                # From auto-config
    lora_alpha = lora_alpha,   # From auto-config
    lora_dropout = 0,
    bias = "none",
    random_state = 3407,
)

print(f"LoRA adapters configured with rank={lora_r} for cancer genomics!")
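For intuition about what `r` and `lora_alpha` control, standard LoRA adds a low-rank update scaled by `lora_alpha / r`, and each adapted linear layer gains `r * (d_in + d_out)` trainable parameters. The sketch below applies that arithmetic to the three auto-config tiers. The layer width of 2048 is an illustrative assumption, not Gemma 3n's actual dimension, and both helper functions are hypothetical.

```python
def lora_scaling(lora_alpha, r):
    """Scale factor applied to the low-rank update in standard LoRA."""
    return lora_alpha / r


def lora_params_per_layer(d_in, d_out, r):
    """Trainable parameters added to one linear layer: A is r x d_in, B is d_out x r."""
    return r * (d_in + d_out)


# (r, lora_alpha) pairs from the auto-config cell in Section 4
tiers = {"A100": (32, 64), "L4": (16, 32), "T4 or smaller": (8, 8)}

for tier, (r, alpha) in tiers.items():
    scale = lora_scaling(alpha, r)
    params = lora_params_per_layer(2048, 2048, r)  # 2048 is an assumed width
    print(f"{tier}: scaling = {scale}, params per 2048x2048 layer = {params}")
```

Note that the A100 and L4 tiers use a scaling of 2.0 while the T4 tier uses 1.0, so the adapter update is weighted differently on a T4 even at the same learning rate.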

### Option B: Manual Override (Advanced Users)
Uncomment the code below to manually configure LoRA parameters

In [None]:
# MANUAL OVERRIDE: Use this if you want to override auto-configuration
# Uncomment and modify the values below as needed

# model = FastModel.get_peft_model(
#     model,
#     finetune_vision_layers     = False,
#     finetune_language_layers   = True,
#     finetune_attention_modules = True,
#     finetune_mlp_modules       = True,
#     
#     r = 8,           # LoRA rank - manual override
#     lora_alpha = 8,  # Match r value
#     lora_dropout = 0,
#     bias = "none",
#     random_state = 3407,
# )
# 
# print("LoRA adapters manually configured for cancer genomics!")

## 7. Prepare Cancer Genomics Dataset

In [None]:
from unsloth.chat_templates import get_chat_template, standardize_data_formats
from datasets import Dataset

# Set up the Gemma 3 chat template
tokenizer = get_chat_template(tokenizer, chat_template = "gemma-3")

# Convert each input/output pair into a two-turn user/assistant conversation
conversations_data = [
    {
        'conversations': [
            {'role': 'user', 'content': ex['input']},
            {'role': 'assistant', 'content': ex['output']},
        ]
    } for ex in cancer_training_data
]

# Create dataset
dataset = Dataset.from_list(conversations_data)
dataset = standardize_data_formats(dataset)

# Apply chat template
def formatting_prompts_func(examples):
    convos = examples["conversations"]
    texts = [
        tokenizer.apply_chat_template(
            convo, 
            tokenize=False, 
            add_generation_prompt=False
        ).removeprefix('<bos>')
        for convo in convos
    ]
    return {"text": texts}

dataset = dataset.map(formatting_prompts_func, batched=True)

# Split dataset
dataset = dataset.train_test_split(test_size=0.1, seed=3407)
print(f"Training samples: {len(dataset['train'])}")
print(f"Evaluation samples: {len(dataset['test'])}")
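To make the `text` column less opaque, here is a hand-written approximation of what the Gemma chat template produces for one conversation. It is a sketch for illustration only: the real string comes from `tokenizer.apply_chat_template`, which may differ in whitespace handling, and `render_gemma_turns` is a hypothetical function. The `<start_of_turn>user\n` and `<start_of_turn>model\n` markers are the same ones Section 8 passes to `train_on_responses_only`.

```python
def render_gemma_turns(conversation):
    """Approximate the Gemma chat format for a list of role/content messages."""
    parts = ["<bos>"]
    for message in conversation:
        # Gemma labels the assistant turn "model"
        role = "model" if message["role"] == "assistant" else message["role"]
        parts.append(f"<start_of_turn>{role}\n{message['content']}<end_of_turn>\n")
    return "".join(parts)


# Placeholder conversation, not real training data
convo = [
    {"role": "user", "content": "Describe KRAS G12C."},
    {"role": "assistant", "content": "KRAS G12C is ..."},
]

# Mirror the notebook: drop the leading <bos> from the rendered text
text = render_gemma_turns(convo).removeprefix("<bos>")
print(text)
```

The leading `<bos>` is stripped on the assumption that the tokenizer adds one again when the text is tokenized for training, which would otherwise leave two at the start of every example.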

## 8. Set Up Training Configuration

In [None]:
from trl import SFTTrainer
from transformers import TrainingArguments
from unsloth.chat_templates import train_on_responses_only

# Check for existing checkpoints
import glob
existing_checkpoints = glob.glob(f"{checkpoint_dir}/checkpoint-*")
resume_from_checkpoint = bool(existing_checkpoints)

if resume_from_checkpoint:
    print(f"✅ Found {len(existing_checkpoints)} checkpoints - will resume from latest!")
else:
    print("📝 No checkpoints found - starting fresh training")

# Use modern TrainingArguments with automatic bf16/fp16 detection
training_args = TrainingArguments(
    output_dir = checkpoint_dir,
    per_device_train_batch_size = batch_size,
    gradient_accumulation_steps = grad_accum,
    warmup_steps = 5,
    num_train_epochs = 3,
    learning_rate = 2e-4,
    fp16 = not torch.cuda.is_bf16_supported(),
    bf16 = torch.cuda.is_bf16_supported(),
    logging_steps = 10,
    save_steps = 100,
    eval_strategy = "steps",
    eval_steps = 50,
    optim = "adamw_8bit",
    weight_decay = 0.01,
    lr_scheduler_type = "linear",
    seed = 3407,
    report_to = "tensorboard",
    load_best_model_at_end = True,
    metric_for_best_model = "eval_loss",
    greater_is_better = False,
    save_total_limit = 3,
)

trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset["train"],
    eval_dataset = dataset["test"],
    dataset_text_field = "text",
    max_seq_length = max_seq_length,
    dataset_num_proc = 2,
    packing = False,
    args = training_args,
)

# Train only on assistant responses
trainer = train_on_responses_only(
    trainer,
    instruction_part = "<start_of_turn>user\n",
    response_part = "<start_of_turn>model\n",
)

print(f"Trainer configured with effective batch size = {batch_size * grad_accum}")
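Because `save_steps` and `eval_steps` are expressed in optimizer steps, it helps to know roughly how many steps a run will take. The sketch below estimates this from the dataset size; it is approximate, since the Trainer's exact rounding can differ by a step, and `training_schedule` is a hypothetical helper. The dataset sizes are illustrative assumptions.

```python
import math


def training_schedule(num_train_examples, batch_size, grad_accum, epochs,
                      save_steps, eval_steps):
    """Estimate optimizer steps and how often saving and evaluation trigger."""
    effective_batch = batch_size * grad_accum
    steps_per_epoch = math.ceil(num_train_examples / effective_batch)
    total_steps = steps_per_epoch * epochs
    return {
        "steps_per_epoch": steps_per_epoch,
        "total_steps": total_steps,
        "save_events": total_steps // save_steps,
        "eval_events": total_steps // eval_steps,
    }


# Assumed sizes: 5,400 training examples vs. a small 100-example trial run
large = training_schedule(5400, batch_size=1, grad_accum=4, epochs=3,
                          save_steps=100, eval_steps=50)
small = training_schedule(100, batch_size=1, grad_accum=4, epochs=3,
                          save_steps=100, eval_steps=50)
print("large:", large)
print("small:", small)
```

With the small dataset the run finishes before step 100, so no step-based checkpoint would be written; lowering `save_steps` (keeping it a multiple of `eval_steps`, as `load_best_model_at_end` requires) is one way to handle short runs.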

## 9. Start Training
### Option A: Full Training with Monitoring (Recommended)

In [None]:
# Monitor GPU memory before training
gpu_stats = torch.cuda.get_device_properties(0)
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
# Use a separate name here: `max_memory` already holds the Section 4 setting used when loading the model
total_memory_gb = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
print(f"GPU = {gpu_stats.name}. Max memory = {total_memory_gb} GB.")
print(f"{start_gpu_memory} GB of memory reserved.")

# Estimate training time
if "A100" in gpu_stats.name:
    estimated_time = "1-2 hours"
elif "L4" in gpu_stats.name:
    estimated_time = "3-4 hours"
else:
    estimated_time = "10-12 hours"
print(f"\n⏰ Estimated training time: {estimated_time}")
print("💾 Checkpoints auto-save to Google Drive every 100 steps")
print("🔄 Re-run this notebook after a disconnect to resume from the latest checkpoint\n")

# IMPORTANT: Gemma 3n starts with high losses (6-7) - this is normal!
print("⚠️ NOTE: Initial loss will be high (6-7) for Gemma 3n - this is expected!")
print("The loss will drop rapidly in the first few steps.\n")

# Train the model
trainer_stats = trainer.train(resume_from_checkpoint=resume_from_checkpoint)

# Show training stats
used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
used_memory_for_lora = round(used_memory - start_gpu_memory, 3)
used_percentage = round(used_memory / total_memory_gb * 100, 3)
print(f"\nTraining completed in {trainer_stats.metrics['train_runtime']/60:.2f} minutes")
print(f"Peak memory = {used_memory} GB ({used_percentage}% of max)")
print(f"Peak memory for training = {used_memory_for_lora} GB")
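Resuming depends on the `checkpoint-<step>` folders in the output directory. If you want to confirm which one is newest before resuming, note that plain string sorting places `checkpoint-1000` before `checkpoint-200`, so the step number has to be compared numerically. A minimal sketch, using a temporary directory in place of the Drive folder; `latest_checkpoint` is a hypothetical helper:

```python
import os
import re
import tempfile


def latest_checkpoint(directory):
    """Return the path of the checkpoint-<step> folder with the highest step, or None."""
    pattern = re.compile(r"^checkpoint-(\d+)$")
    best_step, best_path = -1, None
    for name in os.listdir(directory):
        match = pattern.match(name)
        path = os.path.join(directory, name)
        if match and os.path.isdir(path) and int(match.group(1)) > best_step:
            best_step, best_path = int(match.group(1)), path
    return best_path


# Stand-in for checkpoint_dir, populated with empty folders for illustration
with tempfile.TemporaryDirectory() as tmp:
    for step in (100, 200, 1000):
        os.makedirs(os.path.join(tmp, f"checkpoint-{step}"))
    newest_name = os.path.basename(latest_checkpoint(tmp))
    # An empty directory has no checkpoint to resume from
    with tempfile.TemporaryDirectory() as empty:
        empty_result = latest_checkpoint(empty)

print(newest_name)
```

In the notebook, `latest_checkpoint(checkpoint_dir)` would give the folder to inspect.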

### Option B: Quick Training Start (Minimal Output)
Uncomment the code below for faster execution without detailed monitoring

In [None]:
# Quick start training (minimal output)
# Uncomment below for quick training without detailed monitoring

# print("Starting training...")
# trainer_stats = trainer.train(resume_from_checkpoint=resume_from_checkpoint)
# print(f"✅ Training completed in {trainer_stats.metrics['train_runtime']/60:.2f} minutes")

## 10. Test Model & Preview Results

In [None]:
# Test cancer genomics queries with Gemma 3n official settings
test_queries = [
    "Analyze the BRCA1 c.68_69delAG mutation for cancer risk assessment.",
    "What are the therapeutic implications of KRAS G12C mutation in lung cancer?",
    "Provide clinical recommendations for a patient with TP53 R175H mutation."
]

from transformers import TextStreamer
streamer = TextStreamer(tokenizer, skip_prompt=True)

for query in test_queries:
    print(f"\n{'='*60}")
    print(f"Query: {query}")
    print(f"{'='*60}")
    
    messages = [{"role": "user", "content": query}]
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
        tokenize=True,
        return_dict=True,
    ).to("cuda")
    
    # Use Gemma 3n official inference settings with CRITICAL do_sample=True
    _ = model.generate(
        **inputs,
        max_new_tokens=256,
        temperature=1.0,          # Gemma team recommended
        top_p=0.95,              # Gemma team recommended
        top_k=64,                # Gemma team recommended
        min_p=0.0,               # Official setting
        repetition_penalty=1.0,   # Official setting
        do_sample=True,          # Required: without it, generate() decodes greedily and ignores temperature/top_p/top_k
        streamer=streamer,
    )
    
    # Cleanup
    del inputs
    torch.cuda.empty_cache()
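To show what `temperature`, `top_k`, and `top_p` do to the next-token distribution, here is a simplified pure-Python sketch. It is not the Transformers implementation: it applies temperature, keeps the `top_k` most likely tokens, renormalizes, then keeps the smallest prefix whose cumulative probability reaches `top_p`. `filter_top_k_top_p` and the four-token logits are made up for illustration. The notebook's `min_p=0.0` and `repetition_penalty=1.0` are left out because those values leave the distribution unchanged.

```python
import math


def filter_top_k_top_p(logits, top_k=64, top_p=0.95, temperature=1.0):
    """Return {token_index: probability} after temperature, top-k, then top-p filtering."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]  # subtract max for numerical stability

    # Top-k: keep the k highest-scoring tokens, then renormalize
    ranked = sorted(range(len(exps)), key=lambda i: exps[i], reverse=True)[:top_k]
    k_total = sum(exps[i] for i in ranked)
    probs = {i: exps[i] / k_total for i in ranked}

    # Top-p: keep the smallest high-probability prefix reaching top_p
    kept, cumulative = [], 0.0
    for i in ranked:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break

    p_total = sum(probs[i] for i in kept)
    return {i: probs[i] / p_total for i in kept}


# Made-up logits for a four-token vocabulary
candidates = filter_top_k_top_p([2.0, 1.0, 0.0, -1.0], top_k=3, top_p=0.9)
print(sorted(candidates))
```

Sampling then draws from the surviving tokens in proportion to their probabilities; with `do_sample=False` the most likely token would be taken every time instead.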

## 11. Save the Model
### Option A: Save to Google Drive (Persistent - Recommended)

In [None]:
# Save LoRA adapters to Drive
save_path = f"{checkpoint_dir}/oncoscope-gemma-3n-final"
trainer.save_model(save_path)
tokenizer.save_pretrained(save_path)

print(f"✅ Model saved to Google Drive: {save_path}")

# Option 1: Save merged model (16-bit) - recommended for A100
if "A100" in torch.cuda.get_device_name(0):
    print("Saving merged model (A100 has enough memory)...")
    model.save_pretrained_merged(f"{save_path}-merged", tokenizer, save_method="merged_16bit")

# Option 2: Save as GGUF for Ollama
# print("Saving GGUF format for local deployment...")
# model.save_pretrained_gguf(
#     f"{save_path}-gguf",
#     tokenizer,
#     quantization_type = "Q8_0"
# )

print("\n📦 Files saved to Google Drive - safe from disconnections!")
print(f"💡 You can access them at {checkpoint_dir}/")

### Option B: Save Locally & Download

In [None]:
# Save LoRA adapters locally for immediate download
local_save_path = "oncoscope-gemma-3n-lora"
model.save_pretrained(local_save_path)
tokenizer.save_pretrained(local_save_path)

# Zip the model for easy download
!zip -r {local_save_path}.zip {local_save_path}/

# Download the zip file
from google.colab import files
files.download(f'{local_save_path}.zip')

print(f"✅ Model downloaded as {local_save_path}.zip")
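`!zip` is a notebook shell escape, so it will not work if this cell is later moved into a plain Python script. A sketch of the same archiving step with the standard library's `shutil.make_archive` follows. The directory is a throwaway created for the example, `adapter_config.json` stands in for the saved adapter files, and `zip_directory` is a hypothetical helper.

```python
import os
import shutil
import tempfile
import zipfile


def zip_directory(source_dir, archive_base):
    """Create <archive_base>.zip with source_dir as its top-level folder."""
    source_dir = os.path.abspath(source_dir)
    return shutil.make_archive(
        archive_base,
        "zip",
        root_dir=os.path.dirname(source_dir),
        base_dir=os.path.basename(source_dir),
    )


with tempfile.TemporaryDirectory() as tmp:
    # Stand-in for the folder written by model.save_pretrained(...)
    model_dir = os.path.join(tmp, "oncoscope-gemma-3n-lora")
    os.makedirs(model_dir)
    with open(os.path.join(model_dir, "adapter_config.json"), "w") as f:
        f.write("{}")

    archive_path = zip_directory(model_dir, os.path.join(tmp, "oncoscope-gemma-3n-lora"))
    with zipfile.ZipFile(archive_path) as zf:
        archived_names = zf.namelist()

print(archived_names)
```

In Colab the returned path could be passed straight to `files.download`.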

## 🎉 Training Complete!

### Next Steps:
1. **Deploy with Ollama**: Convert to GGUF format (uncomment code in Section 11)
2. **Share on HuggingFace**: Push your model to the Hub
3. **Production Deployment**: Use the saved model in your cancer genomics application

### Key Features of the Model:
- Specialized for cancer mutation analysis (BRCA1/2, TP53, KRAS, etc.)
- Trained on 6,000+ expert-curated examples
- Optimized for clinical decision support
- Ready for real-world deployment

### Resources:
- [OncoScope Documentation](https://github.com/yourusername/oncoscope)
- [Unsloth Documentation](https://docs.unsloth.ai/)
- [Gemma 3n Guide](https://docs.unsloth.ai/basics/gemma-3n-how-to-run-and-fine-tune)

Thank you for using OncoScope! Together, we're making precision oncology accessible to everyone. 🧬🏥