# Stage 2: Helpful Response Fine-tuning (Gemma-7B-IT, QLoRA) — Colab Notebook

This notebook fine-tunes Google Gemma-7B-IT with QLoRA on the Anthropic Helpful-Harmless dataset (`Anthropic/hh-rlhf`), logs experiments to Weights & Biases, and optionally evaluates helpfulness and safety deltas using your Stage 1 safety classifier.

Notes:
- You need to accept the Gemma model license on Hugging Face Hub with your account before training.
- You will log in to Hugging Face and W&B interactively (a hidden token prompt or a login widget), so no secrets are written into the notebook in plaintext.
- If you have a Stage 1 package zip in Google Drive (safety_text_classifier_trained_*.zip), this notebook will auto-extract it for safety filtering and evaluation.


In [None]:
# Minimal setup for Colab: ensure GPU and install uv (we'll use repo-pinned deps)
import torch
assert torch.cuda.is_available(), 'CUDA not available. Please enable GPU in Runtime > Change runtime type > Hardware accelerator: GPU.'
!pip -q install -U uv


In [None]:
# Report the visible GPU and export CUDA allocator options
import torch, os

has_gpu = torch.cuda.is_available()
print('GPU available:', has_gpu)
if has_gpu:
    print('GPU name:', torch.cuda.get_device_name(0))
    # Allocator settings are passed to PyTorch through this environment variable.
    alloc_conf = 'expandable_segments:True,max_split_size_mb:512'
    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = alloc_conf


## Repository setup
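The next cell handles both code and assets:
- Code: the GitHub repository is always cloned into `/content/ml-learning` (or fast-forwarded with `git pull --ff-only` if it already exists). Replace the repository URL in the cell if you are working from a fork.
- Assets: Google Drive is mounted only to fetch large model artifacts such as the Stage 1 zip. Set `USE_DRIVE_FOR_ASSETS = False` to skip mounting.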


In [None]:
# Clone repo from GitHub and (optionally) mount Drive for model assets
USE_DRIVE_FOR_ASSETS = True  # Mount Drive to fetch large checkpoints only

import os, glob
from pathlib import Path
repo_root = '/content/ml-learning'

# Always clone or pull latest code from GitHub
if not os.path.exists(repo_root):
    !git clone https://github.com/Jai-Dhiman/ml-learning {repo_root}
else:
    print('Repo path exists; pulling latest changes...')
%cd {repo_root}
!git pull --ff-only

# Mount Drive only for model artifacts (e.g., Stage 1 zip)
if USE_DRIVE_FOR_ASSETS:
    from google.colab import drive
    try:
        drive.mount('/content/drive')
        print('Drive mounted for model assets.')
    except Exception as e:
        print('Drive not mounted. Proceeding without Drive assets. Error:', e)

%cd {repo_root}/helpful-finetuning
!pwd

# Create and sync a project-local environment pinned to repo deps
!uv venv
!bash -lc 'source .venv/bin/activate && uv sync'

# Ensure bitsandbytes (GPU) and Triton are present in the env
!bash -lc 'source .venv/bin/activate && uv pip uninstall -y bitsandbytes || true'
!bash -lc 'source .venv/bin/activate && python - <<\'PY\'
import torch, re, os
v = torch.version.cuda or ""
digits = re.sub(r"\D", "", v) or "121"
print("Detected CUDA version:", v, "-> BNB_CUDA_VERSION:", digits)
open("/tmp/bnb_cuda_version", "w").write(digits)

!bash -lc 'export BNB_CUDA_VERSION=$(cat /tmp/bnb_cuda_version); source .venv/bin/activate && uv pip install -U triton==2.2.0 && uv pip install --pre -U --extra-index-url https://jllllll.github.io/bitsandbytes-wheels/cu${BNB_CUDA_VERSION}/ bitsandbytes==0.43.1'
!bash -lc 'source .venv/bin/activate && python - <<\'PY\'
import importlib, importlib.metadata as im
import torch
try:
    import bitsandbytes as bnb
    print("torch.cuda:", torch.version.cuda)
    print("bnb import ok:", bnb.__file__)
    print("bnb version:", getattr(bnb, "__version__", "n/a"))
    try:
        print("bnb dist metadata:", im.version("bitsandbytes"))
    except Exception as e:
        print("[WARN] bitsandbytes distribution metadata missing:", e)
except Exception as e:
    print("[ERROR] bitsandbytes import failed:", e)
PY'


In [None]:
# If a Stage 1 zip exists in Drive, auto-extract to expected path for safety filtering/eval
import os, glob
dst_dir = '/content/ml-learning/safety-text-classifier'
os.makedirs(dst_dir, exist_ok=True)

# Preferred exact path (provided by user)
exact_zip = '/content/drive/MyDrive/safety-text-classifier/safety_text_classifier_trained_20250916_0632.zip'
candidates = []
if os.path.exists(exact_zip):
    candidates = [exact_zip]
else:
    # Fallback patterns
    pats = [
        '/content/drive/MyDrive/safety_text_classifier_trained_*.zip',
        '/content/drive/MyDrive/safety-text-classifier/safety_text_classifier_trained_*.zip',
    ]
    for p in pats:
        candidates.extend(glob.glob(p))

if candidates:
    candidates.sort(reverse=True)
    print('Found Stage 1 package:', candidates[0])
    !unzip -o "{candidates[0]}" -d {dst_dir}
else:
    print('No Stage 1 zip found on Drive. If checkpoints are in the repo path, safety filter will use them.')
    print('Otherwise safety filter defaults to safe to avoid blocking training.')


In [None]:
# Authenticate with Hugging Face (needed to download the gated Gemma weights).
# The token is read through a hidden prompt and never printed. If that prompt
# flow fails for any reason, the cell falls back to the interactive widget
# shown by huggingface_hub.login().
import os
for stale_var in ("HF_TOKEN", "HUGGINGFACEHUB_API_TOKEN"):
    # Drop any token already exported in the kernel environment before logging in.
    os.environ.pop(stale_var, None)
from huggingface_hub import login, HfApi


def _announce_identity(account):
    """Print the account name (or email, or 'OK') from a whoami() response."""
    print(f"Logged in as: {account.get('name') or account.get('email') or 'OK'}")


try:
    import getpass as gp
    entered = gp.getpass("Paste your Hugging Face token (input hidden): ")
    if isinstance(entered, (bytes, bytearray)):
        token = entered.decode()
    else:
        token = entered
    if not isinstance(token, str):
        raise TypeError(f"Unexpected token type: {type(token).__name__}")
    token = token.strip()
    if not token:
        raise ValueError("Empty token provided")
    login(token=token, add_to_git_credential=False)
    _announce_identity(HfApi().whoami(token=token))
except Exception as e:
    print(f"[HF Login] getpass flow failed: {e}")
    print("Falling back to interactive login widget...")
    login()
    try:
        _announce_identity(HfApi().whoami())
    except Exception as e2:
        # Verification is best-effort on the widget path.
        print(f"[HF Login] Verification skipped: {e2}")


In [None]:
# Log in to Weights & Biases so experiment runs can be tracked.
# wandb.login() is called with no arguments, so no key is embedded in the notebook.
import wandb
wandb.login()  # Enter W&B API key in widget


In [None]:
# Preflight in project venv: verify Anthropic/hh-rlhf subset and splits
!bash -lc 'source .venv/bin/activate && python - <<\'PY\'
import yaml
from datasets import load_dataset, get_dataset_config_names, get_dataset_split_names
base = yaml.safe_load(open("configs/base_config.yaml"))
try:
    override = yaml.safe_load(open("configs/colab_config.yaml"))
except Exception:
    override = {}
cfg = dict(base)
if isinstance(override, dict):
    for k, v in override.items():
        if isinstance(v, dict) and k in cfg and isinstance(cfg[k], dict):
            cfg[k] = {**cfg[k], **v}
        else:
            cfg[k] = v
dcfg = cfg.get("dataset", {})
name = dcfg.get("name")
subset = dcfg.get("subset")
train_split = dcfg.get("train_split")
eval_split = dcfg.get("eval_split")
print("Selected dataset config:", dcfg)
assert name == "Anthropic/hh-rlhf", f"Stage 2 requires Anthropic/hh-rlhf, got: {name}"
assert subset, "dataset.subset is required"
configs = get_dataset_config_names(name)
print("Available subsets (reported):", configs)
# Authoritative tiny load using project venv packages
_tiny_ok = False
_tiny_err = None
try:
    _ = load_dataset(name, subset, split="test[:1]")
    _tiny_ok = True
except Exception as te:
    _tiny_err = te
if not _tiny_ok:
    raise SystemExit(f"FAIL: tiny load failed: {_tiny_err}")
# Validate splits
def _base_split(s):
    return s.split("[")[0].split(":")[0].strip() if s else s
splits = get_dataset_split_names(name, subset)
print(f"Available splits for {name}/{subset}:", splits)
if train_split and _base_split(train_split) not in splits:
    raise SystemExit(f"Invalid train_split {train_split}. Available: {splits}")
if eval_split and _base_split(eval_split) not in splits:
    raise SystemExit(f"Invalid eval_split {eval_split}. Available: {splits}")
print("Dataset preflight OK (subset confirmed by tiny load).")
PY'


## Train: Gemma-7B-IT with QLoRA (Colab-optimized overrides)
- Base config: `configs/base_config.yaml`
- Overrides:   `configs/colab_config.yaml` (smaller batch size and sequence length, with gradient accumulation)


In [None]:
# Start training
# Runs the repo's src.training.train_qlora module with the project venv activated,
# passing the base config plus the Colab override file.
!bash -lc 'source .venv/bin/activate && python -m src.training.train_qlora --config configs/base_config.yaml --override configs/colab_config.yaml'


## Evaluate (quick subset)
Computes a simple helpfulness heuristic for the fine-tuned model versus the base model, plus safety deltas scored with the Stage 1 safety classifier.


In [None]:
# Run evaluation (uses ./lora_adapters if present)
# Invokes the repo's src.evaluation.evaluate_helpfulness module inside the project venv.
# NOTE(review): only --config is passed here, whereas the training cell also passes
# --override configs/colab_config.yaml — confirm that evaluating against the base
# config alone is intended.
!bash -lc 'source .venv/bin/activate && python -m src.evaluation.evaluate_helpfulness --config configs/base_config.yaml'


## (Optional) Interactive demo (Gradio)
Launch a lightweight UI with a public share link to compare the base and fine-tuned models side by side and see a safety overlay for each response.


In [None]:
#@title Launch demo (optional)
USE_DEMO = False  #@param {type:"boolean"}
if USE_DEMO:
    import gradio as gr
    from src.inference.generate import GemmaInference
    from src.utils.safety_integration import SafetyFilter

    # Two generators over the same checkpoint id: one without an adapter path,
    # one pointed at ./lora_adapters.
    base = GemmaInference('google/gemma-7b-it', adapter_path=None, load_in_4bit=True)
    ft = GemmaInference('google/gemma-7b-it', adapter_path='./lora_adapters', load_in_4bit=True)
    safety = SafetyFilter(
        classifier_config_path='../safety-text-classifier/configs/base_config.yaml',
        checkpoint_dir='../safety-text-classifier/checkpoints/best_model',
    )

    def _safety_badge(score, threshold):
        """Format a score as a green (>= threshold) or red badge string."""
        marker = '🟢' if score >= threshold else '🔴'
        return f"{marker} Safety: {score:.2f}"

    def compare(prompt, temperature, top_p, max_length, safety_threshold):
        """Generate with both models and return (base text, tuned text, base badge, tuned badge)."""
        sampling = dict(max_length=max_length, temperature=temperature, top_p=top_p)
        base_text = base.generate(prompt, **sampling)
        tuned_text = ft.generate(prompt, **sampling)
        base_score = safety.score_text(base_text)
        tuned_score = safety.score_text(tuned_text)
        return (
            base_text,
            tuned_text,
            _safety_badge(base_score, safety_threshold),
            _safety_badge(tuned_score, safety_threshold),
        )

    with gr.Blocks() as app:
        gr.Markdown('## Stage 2: Base vs Fine-tuned (Gemma-7B-IT + QLoRA)')
        prompt_box = gr.Textbox(label='Prompt', lines=4)
        with gr.Row():
            temperature_sl = gr.Slider(0.1, 1.0, value=0.7, step=0.05, label='Temperature')
            top_p_sl = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label='Top-p')
            max_length_sl = gr.Slider(64, 1024, value=512, step=16, label='Max length')
            threshold_sl = gr.Slider(0.0, 1.0, value=0.8, step=0.05, label='Safety threshold')
        generate_btn = gr.Button('Generate')
        base_text_out = gr.Textbox(label='Base response', lines=10)
        tuned_text_out = gr.Textbox(label='Fine-tuned response', lines=10)
        base_badge_out = gr.Label(label='Base safety')
        tuned_badge_out = gr.Label(label='Fine-tuned safety')
        # Input and output order must match compare()'s parameters and return tuple.
        generate_btn.click(
            compare,
            inputs=[prompt_box, temperature_sl, top_p_sl, max_length_sl, threshold_sl],
            outputs=[base_text_out, tuned_text_out, base_badge_out, tuned_badge_out],
        )
    app.launch(share=True)
