# Stage 2: Helpful Response Fine-tuning (Gemma-7B-IT, QLoRA) — Colab Notebook

This notebook fine-tunes Google Gemma-7B-IT using QLoRA on the Anthropic Helpful-Harmless dataset, logs experiments to Weights & Biases, and optionally evaluates helpfulness and safety deltas using your Stage 1 safety classifier.

Notes:
- You need to accept the Gemma model license on Hugging Face Hub with your account before training.
- You will log in to Hugging Face and W&B via Colab widgets (no plaintext secrets).
- If you have a Stage 1 package zip in Google Drive (safety_text_classifier_trained_*.zip), this notebook will auto-extract it for safety filtering and evaluation.


In [None]:
# Install dependencies (QLoRA stack; make safety/JAX optional)
!pip -q install -U "pyarrow<20.0.0"
!pip -q install -U bitsandbytes==0.43.1 accelerate datasets wandb evaluate pyyaml tqdm sentencepiece
!pip -q install -U git+https://github.com/huggingface/transformers.git
!pip -q install -U git+https://github.com/huggingface/peft.git
# JAX: avoid Colab plugin mismatch; only install if missing
import subprocess, sys
from importlib.metadata import version, PackageNotFoundError

def pip_quiet(args):
    """Run ``pip install -q <args>`` with the current interpreter.

    Args:
        args: list of extra command-line arguments (requirement specifiers,
            options) appended after ``pip install -q``.

    Returns:
        True when pip exits with status 0, False otherwise.
    """
    pip_command = [sys.executable, '-m', 'pip', 'install', '-q'] + args
    exit_status = subprocess.run(pip_command).returncode
    return exit_status == 0

# Remove potentially incompatible CUDA plugin (Colab ships newer jaxlib).
# stdout/stderr are captured and discarded so the cell stays quiet whether or
# not the plugin was installed.
subprocess.run(
    [sys.executable, '-m', 'pip', 'uninstall', '-y', 'jax-cuda12-plugin'],
    stdout=subprocess.PIPE,
    stderr=subprocess.PIPE,
)


def _installed_version(dist_name):
    """Return the installed version string of ``dist_name``, or 'unknown' if no metadata exists."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return 'unknown'


# Try to import existing JAX; if unavailable, install CPU-only JAX to avoid GPU plugin issues
try:
    import jax
    print('JAX present:', _installed_version('jax'), 'jaxlib:', _installed_version('jaxlib'))
except Exception as e:
    print('JAX not present or broken, installing CPU-only JAX...', e)
    jax_installed = pip_quiet(["jax[cpu]==0.4.38"])
    extras_installed = pip_quiet(["flax>=0.8.4,<0.9.0", "optax>=0.2.2,<0.3.0"])
    if not (jax_installed and extras_installed):
        # pip_quiet returns False on a non-zero exit; surface that instead of ignoring it.
        print('WARNING: pip reported a failure while installing jax/flax/optax.')
    # JAX is optional for this notebook, so a failed install must not abort the cell.
    try:
        import jax
        print('JAX version:', jax.__version__)
    except Exception as install_err:
        print('JAX is still unavailable after the install attempt; continuing without it:', install_err)

# If JAX imports, print devices (may be CPU)
try:
    import jax
    print('JAX devices:', jax.devices())
except Exception as e:
    print('JAX devices check failed:', e)


In [None]:
# GPU check & memory tweaks
import torch, os
print('GPU available:', torch.cuda.is_available())
if torch.cuda.is_available():
    # Helpful memory settings on Colab.
    # Set this BEFORE the first call that initialises CUDA in this process
    # (get_device_name below): the caching allocator is expected to read the
    # variable when CUDA is initialised, so setting it afterwards would only
    # affect subprocesses launched later, not code run inside this kernel.
    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True,max_split_size_mb:512'
    print('GPU name:', torch.cuda.get_device_name(0))


## Repository setup
You have two options:
- A) Mount Google Drive if you already have your repo under Drive (recommended)
- B) Clone your GitHub repository (replace the placeholder URL)


In [None]:
# Option A: Mount Google Drive and work from there (if you keep your repo in Drive)
USE_DRIVE = True  # set to False to use GitHub clone instead

import os, glob
from pathlib import Path
repo_root = '/content/ml-learning'

if USE_DRIVE:
    from google.colab import drive
    try:
        drive.mount('/content/drive')
        print('Drive mounted. Expecting repo under /content/drive/MyDrive/ml-learning')
        if os.path.exists('/content/drive/MyDrive/ml-learning'):
            !rsync -a --delete /content/drive/MyDrive/ml-learning/ /content/ml-learning/
        else:
            print('Repo not found in Drive. Proceeding with placeholder GitHub clone.')
            USE_DRIVE = False
    except Exception as e:
        print('Drive not mounted. Using GitHub clone. Error:', e)
        USE_DRIVE = False

if not USE_DRIVE:
    # Option B: Clone from GitHub (replace with your repo URL)
    if not os.path.exists(repo_root):
        !git clone https://github.com/yourusername/ml-learning.git {repo_root}

%cd {repo_root}/helpful-finetuning
!pwd


In [None]:
# If a Stage 1 zip exists in Drive, auto-extract to expected path for safety filtering/eval
import os, glob
zip_candidates = glob.glob('/content/drive/MyDrive/safety_text_classifier_trained_*.zip')
if zip_candidates:
    os.makedirs('/content/ml-learning/safety-text-classifier', exist_ok=True)
    print('Found Stage 1 package:', zip_candidates[0])
    !unzip -o "{zip_candidates[0]}" -d /content/ml-learning/safety-text-classifier
else:
    print('No Stage 1 zip found. If checkpoints are in the repo path, safety filter will use them.
Otherwise safety filter defaults to safe to avoid blocking training.')


In [None]:
# Log in to Hugging Face and Weights & Biases interactively, so that no token
# or API key is stored in the notebook source.
from huggingface_hub import login
import wandb
login()  # prompts for a Hugging Face token (required to download Gemma)
wandb.login()  # prompts for a W&B API key


## Train: Gemma-7B-IT with QLoRA (Colab-optimized overrides)
- Base config: `configs/base_config.yaml`
- Overrides:   `configs/colab_config.yaml` (smaller batch size and sequence length, with gradient accumulation)


In [None]:
# Start training.
# Runs the src.training.train_qlora module with the base config plus the Colab
# override file. The module and config paths are relative, so the working
# directory must be the helpful-finetuning folder (the repository-setup cell
# changes into it with %cd).
!python -m src.training.train_qlora --config configs/base_config.yaml --override configs/colab_config.yaml


## Evaluate (quick subset)
Computes a simple helpfulness heuristic for the fine-tuned model against the base model, plus safety deltas measured with the Stage 1 safety classifier.


In [None]:
# Run evaluation (uses ./lora_adapters if present).
# Paths are relative, so run this from the helpful-finetuning directory.
# NOTE(review): only the base config is passed here, not the Colab override
# used for training — confirm that is intended.
!python -m src.evaluation.evaluate_helpfulness --config configs/base_config.yaml


## (Optional) Interactive demo (Gradio)
Launch a lightweight UI (with a public share link) to compare the base and fine-tuned models side by side, with a safety score shown for each response.


In [None]:
#@title Launch demo (optional)
USE_DEMO = False  #@param {type:"boolean"}
if USE_DEMO:
    import gradio as gr
    from src.inference.generate import GemmaInference
    from src.utils.safety_integration import SafetyFilter
    base = GemmaInference('google/gemma-7b-it', adapter_path=None, load_in_4bit=True)
    ft   = GemmaInference('google/gemma-7b-it', adapter_path='./lora_adapters', load_in_4bit=True)
    safety = SafetyFilter(
        classifier_config_path='../safety-text-classifier/configs/base_config.yaml',
        checkpoint_dir='../safety-text-classifier/checkpoints/best_model',
    )
    def compare(prompt, temperature, top_p, max_length, safety_threshold):
        b = base.generate(prompt, max_length=max_length, temperature=temperature, top_p=top_p)
        f = ft.generate(prompt,   max_length=max_length, temperature=temperature, top_p=top_p)
        bs = safety.score_text(b)
        fs = safety.score_text(f)
        bf = '🟢' if bs >= safety_threshold else '🔴'
        ff = '🟢' if fs >= safety_threshold else '🔴'
        return b, f, f"{bf} Safety: {bs:.2f}", f"{ff} Safety: {fs:.2f}"
    with gr.Blocks() as app:
        gr.Markdown('## Stage 2: Base vs Fine-tuned (Gemma-7B-IT + QLoRA)')
        prompt = gr.Textbox(label='Prompt', lines=4)
        with gr.Row():
            temperature = gr.Slider(0.1, 1.0, value=0.7, step=0.05, label='Temperature')
            top_p       = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label='Top-p')
            max_length  = gr.Slider(64, 1024, value=512, step=16, label='Max length')
            safety_thr  = gr.Slider(0.0, 1.0, value=0.8, step=0.05, label='Safety threshold')
        go = gr.Button('Generate')
        base_out = gr.Textbox(label='Base response', lines=10)
        ft_out   = gr.Textbox(label='Fine-tuned response', lines=10)
        base_s   = gr.Label(label='Base safety')
        ft_s     = gr.Label(label='Fine-tuned safety')
        go.click(compare, inputs=[prompt, temperature, top_p, max_length, safety_thr], outputs=[base_out, ft_out, base_s, ft_s])
    app.launch(share=True)
