# SR-010: Recovery LoRA Training

**Version:** 1.0.0 | **Last Updated:** 2026-01-07 PST

Train lightweight LoRA adapters on quantized V2 models to recover accuracy.
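As background, a LoRA adapter leaves the base weight `W` frozen and learns a low-rank correction, so a layer computes `W x + (alpha / r) * B (A x)`. The sketch below is framework-free and illustrative only. It assumes the common `alpha / r` scaling and a zero-initialized `B`, and the class name `LoRALinear` is hypothetical. This notebook does not show how `scripts/train_recovery_lora.py` attaches adapters to the quantized V2 layers or which scaling it uses.

```python
import random

def matvec(m, v):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * x for w, x in zip(row, v)) for row in m]

class LoRALinear:
    """Hypothetical layer: y = W x + (alpha / r) * B (A x); W frozen, only A and B trainable."""

    def __init__(self, weight, r, alpha, seed=0):
        rng = random.Random(seed)
        d_out, d_in = len(weight), len(weight[0])
        self.weight = weight                 # frozen base weight (d_out x d_in)
        self.scale = alpha / r
        # A: r x d_in with small random init; B: d_out x r with zero init,
        # so the adapter contributes nothing until training updates B
        self.A = [[rng.gauss(0.0, 0.01) for _ in range(d_in)] for _ in range(r)]
        self.B = [[0.0] * r for _ in range(d_out)]

    def forward(self, x):
        base = matvec(self.weight, x)
        delta = matvec(self.B, matvec(self.A, x))
        return [b + self.scale * d for b, d in zip(base, delta)]

    def trainable_params(self):
        return len(self.A) * len(self.A[0]) + len(self.B) * len(self.B[0])

W = [[1.0, 0.0, 2.0], [0.0, 1.0, -1.0]]
layer = LoRALinear(W, r=2, alpha=4)
print(layer.forward([1.0, 2.0, 3.0]))  # [7.0, -1.0]: B is zero, so output equals W x
print(layer.trainable_params())        # 10 (A: 2*3, B: 2*2)
```

Because `B` starts at zero, training begins from the quantized model's exact outputs and learns only a correction on top of the frozen weights. This is one reason LoRA fits the recovery use case.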

**Training Modes:**
- `recover`: Cross-entropy (CE) loss on raw text (default)
- `sft`: Supervised fine-tuning on instruction/response pairs
- `kd`: Knowledge distillation from a teacher model
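In `recover` mode the objective is plain next-token prediction. A common way to prepare raw text for it is to concatenate tokenized documents, separate them with an end-of-sequence token, and cut the stream into fixed-length blocks. This approach is assumed here, not taken from the training script. The loss at position `t` is the cross-entropy of token `t + 1`. Below is a stdlib-only sketch with toy token IDs. `pack_into_blocks`, `next_token_ce`, and `eos_id` are hypothetical names.

```python
import math

def pack_into_blocks(docs, seq_len, eos_id):
    """Concatenate token lists (each followed by eos_id) and split into full seq_len blocks."""
    stream = []
    for doc in docs:
        stream.extend(doc)
        stream.append(eos_id)
    n_full = len(stream) // seq_len  # drop the ragged tail
    return [stream[i * seq_len:(i + 1) * seq_len] for i in range(n_full)]

def next_token_ce(logits, tokens):
    """Mean cross-entropy of predicting tokens[t + 1] from logits[t] (logits: seq_len x vocab)."""
    losses = []
    for t in range(len(tokens) - 1):
        row = logits[t]
        m = max(row)
        log_z = m + math.log(sum(math.exp(v - m) for v in row))  # stable log-sum-exp
        losses.append(log_z - row[tokens[t + 1]])
    return sum(losses) / len(losses)

blocks = pack_into_blocks([[5, 6, 7], [8, 9]], seq_len=3, eos_id=0)
print(blocks)  # [[5, 6, 7], [0, 8, 9]] (the trailing eos is dropped)

# Uniform logits over a 4-token vocabulary give a loss of log(4) for any targets
uniform = [[0.0] * 4 for _ in range(3)]
print(round(next_token_ce(uniform, [1, 2, 3]), 4))  # 1.3863
```

Here `seq_len` plays the role of the notebook's `SEQ_LEN`. This notebook does not state whether the script packs, pads, or truncates documents.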

**Prerequisites:**
- Trained V2 QAT checkpoint (e.g., from SR-008)
- CUDA GPU (recommended: T4, A100, or L4)

**Changelog:**
- v1.0.0: Initial release with recover/sft/kd modes

## 1. Setup

In [None]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Clone repository
%cd /content
!git clone https://github.com/Anemll/qwen3_apple_style_2bit_qat_lora.git repo
%cd repo
!git pull

In [None]:
# Install dependencies
!pip install -q transformers accelerate datasets sentencepiece protobuf wandb

# Login to W&B via Colab secret (optional)
# Go to: Colab menu -> Secrets (key icon) -> Add "WANDB_API_KEY"
try:
    from google.colab import userdata
    wandb_key = userdata.get('WANDB_API_KEY')
    if wandb_key:
        import wandb
        wandb.login(key=wandb_key)
        print("W&B: Logged in via secret")
    else:
        print("W&B: No API key found in secrets (optional)")
except Exception as e:
    print(f"W&B: Secret not configured (optional) - {e}")

In [None]:
# Check GPU
import torch
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

## 2. Configuration

In [None]:
# === CONFIGURATION ===

# Base model
MODEL_ID = "Qwen/Qwen3-0.6B"

# V2 QAT checkpoint path (from Google Drive)
V2_CHECKPOINT = "/content/drive/MyDrive/qwen3_runs/SR-008a-revive-hermes/final_v2_q4_r32_fp32_20260107_015153.pt"

# Training mode: "recover", "sft", or "kd"
LORA_MODE = "recover"

# Teacher model (only for kd mode)
TEACHER_MODEL = None  # e.g., "Qwen/Qwen3-4B-Instruct" for KD mode

# LoRA config
RECOVERY_R = 8        # LoRA rank (8-32 typical)
MLP_ONLY = False      # True = only MLP layers, False = MLP + attention

# Training config
MAX_STEPS = 500
SEQ_LEN = 1024
BATCH_SIZE = 4
LR = 3e-4
LOG_INTERVAL = 10
SAVE_STEPS = 100

# Data source (HuggingFace dataset)
TRAIN_DATA_HF = "Salesforce/wikitext"
HF_SUBSET = "wikitext-103-v1"
HF_SPLIT = "train"
HF_MAX_SAMPLES = 10000

# Output
OUTPUT_DIR = "runs/recovery_lora"
GDRIVE_OUTPUT = "/content/drive/MyDrive/qwen3_runs/SR-010-recovery-lora"

# W&B (optional)
USE_WANDB = False
WANDB_PROJECT = "qwen3-recovery-lora"
WANDB_RUN = "sr010-recover"

In [None]:
# Verify checkpoint exists and sync to local storage
import os
import shutil
from pathlib import Path

# Check if checkpoint exists on Google Drive
if not os.path.exists(V2_CHECKPOINT):
    print(f"ERROR: Checkpoint not found: {V2_CHECKPOINT}")
    print("\nAvailable checkpoints in qwen3_runs:")
    runs_dir = Path("/content/drive/MyDrive/qwen3_runs")
    if runs_dir.exists():
        for d in sorted(runs_dir.iterdir()):
            if d.is_dir():
                pts = list(d.glob("*.pt"))
                if pts:
                    print(f"  {d.name}/")
                    for pt in sorted(pts)[:5]:
                        size_mb = pt.stat().st_size / 1e6
                        print(f"    - {pt.name} ({size_mb:.1f} MB)")
else:
    size_mb = os.path.getsize(V2_CHECKPOINT) / 1e6
    print(f"Checkpoint found: {V2_CHECKPOINT}")
    print(f"Size: {size_mb:.1f} MB")
    
    # Sync to local storage for faster training
    LOCAL_CHECKPOINT_DIR = "/content/checkpoints"
    os.makedirs(LOCAL_CHECKPOINT_DIR, exist_ok=True)
    
    local_ckpt = Path(LOCAL_CHECKPOINT_DIR) / Path(V2_CHECKPOINT).name
    
    # Re-copy if the local file is missing or its size differs (e.g., an interrupted copy)
    if not local_ckpt.exists() or local_ckpt.stat().st_size != os.path.getsize(V2_CHECKPOINT):
        print("\nSyncing checkpoint to local storage...")
        shutil.copy(V2_CHECKPOINT, local_ckpt)
        print(f"  Copied to: {local_ckpt}")

        # Also copy config.json if it exists
        config_src = Path(V2_CHECKPOINT).parent / "config.json"
        if config_src.exists():
            shutil.copy(config_src, Path(LOCAL_CHECKPOINT_DIR) / "config.json")
            print("  Copied config.json")
    else:
        print(f"\nLocal checkpoint already exists: {local_ckpt}")

    # Later cells use V2_CHECKPOINT_LOCAL when it is set; V2_CHECKPOINT keeps the Drive path
    V2_CHECKPOINT_LOCAL = str(local_ckpt)
    print(f"\nUsing local checkpoint: {V2_CHECKPOINT_LOCAL}")

## 3. Train Recovery LoRA

In [None]:
# Build training command (uses local checkpoint for speed)
# V2_CHECKPOINT_LOCAL is set by the sync cell above; otherwise fall back to the Drive path
import shlex

checkpoint_path = V2_CHECKPOINT_LOCAL if 'V2_CHECKPOINT_LOCAL' in globals() else V2_CHECKPOINT

# Collect arguments in a list so optional flags stay on the same command line
args = [
    "python", "scripts/train_recovery_lora.py",
    "--model", MODEL_ID,
    "--v2-checkpoint", checkpoint_path,
    "--train-data-hf", TRAIN_DATA_HF,
    "--hf-subset", HF_SUBSET,
    "--hf-split", HF_SPLIT,
    "--hf-max-samples", HF_MAX_SAMPLES,
    "--lora-mode", LORA_MODE,
    "--recovery-r", RECOVERY_R,
    "--lr", LR,
    "--max-steps", MAX_STEPS,
    "--seq-len", SEQ_LEN,
    "--batch-size", BATCH_SIZE,
    "--log-interval", LOG_INTERVAL,
    "--save-steps", SAVE_STEPS,
    "--output", OUTPUT_DIR,
]

# Add optional flags
if MLP_ONLY:
    args.append("--mlp-only")

if LORA_MODE == "kd":
    if TEACHER_MODEL:
        args += ["--teacher", TEACHER_MODEL]
    else:
        print("WARNING: LORA_MODE is 'kd' but TEACHER_MODEL is not set")

if USE_WANDB:
    args += ["--wandb", "--wandb-project", WANDB_PROJECT, "--wandb-run", WANDB_RUN]

# Shell-quote every argument and join into a single command line
cmd = " ".join(shlex.quote(str(a)) for a in args)

print("Training command:")
print(cmd)

In [None]:
# Run training
!{cmd}

## 4. Test Inference

In [None]:
# Find the most recent checkpoint
from pathlib import Path

output_path = Path(OUTPUT_DIR)
# Sort by modification time; name order can misplace numbered files ("1000" sorts before "200")
checkpoints = sorted(output_path.glob("recovery_*.pt"), key=lambda p: p.stat().st_mtime)

if checkpoints:
    latest_ckpt = checkpoints[-1]
    print(f"Found {len(checkpoints)} checkpoints")
    print(f"Latest: {latest_ckpt}")
else:
    print("No checkpoints found!")

In [None]:
# Test inference with the recovery LoRA
# (loads the full checkpoint, which has the LoRA weights embedded)
!python scripts/test_inference.py "{latest_ckpt}" \
    --prompt "What is the capital of France?" \
    --max-tokens 256

In [None]:
# Interactive mode (uncomment to run)
# !python scripts/test_inference.py "{latest_ckpt}" --interactive

## 5. Compare: Base vs LoRA

In [None]:
# Test prompts for comparison
TEST_PROMPTS = [
    "What is the capital of France?",
    "Explain quantum mechanics briefly.",
    "Write a haiku about coding.",
]

In [None]:
# Test base model (without LoRA)
base_ckpt = V2_CHECKPOINT_LOCAL if 'V2_CHECKPOINT_LOCAL' in dir() else V2_CHECKPOINT

print("=" * 60)
print("BASE MODEL (no LoRA)")
print("=" * 60)
for prompt in TEST_PROMPTS:
    print(f"\nPrompt: {prompt}")
    !python scripts/test_inference.py "{base_ckpt}" --prompt "{prompt}" --max-tokens 128 2>/dev/null | tail -5

In [None]:
# Test with LoRA
print("=" * 60)
print("WITH RECOVERY LoRA")
print("=" * 60)
for prompt in TEST_PROMPTS:
    print(f"\nPrompt: {prompt}")
    !python scripts/test_inference.py "{latest_ckpt}" --prompt "{prompt}" --max-tokens 128 2>/dev/null | tail -5

## 6. Save to Google Drive

In [None]:
# Create output directory on Google Drive
import os
import shutil
from pathlib import Path

os.makedirs(GDRIVE_OUTPUT, exist_ok=True)

# Copy checkpoints
output_path = Path(OUTPUT_DIR)
for f in output_path.glob("*.pt"):
    dest = Path(GDRIVE_OUTPUT) / f.name
    print(f"Copying {f.name} -> {dest}")
    shutil.copy(f, dest)

# Copy config.json if exists
config_path = output_path / "config.json"
if config_path.exists():
    shutil.copy(config_path, Path(GDRIVE_OUTPUT) / "config.json")
    print("Copied config.json")

# Copy training log if exists
log_path = output_path / "training_log.csv"
if log_path.exists():
    shutil.copy(log_path, Path(GDRIVE_OUTPUT) / "training_log.csv")
    print("Copied training_log.csv")

print(f"\nSaved to: {GDRIVE_OUTPUT}")

## 7. Plot Training Loss

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path

log_file = Path(OUTPUT_DIR) / "training_log.csv"
if log_file.exists():
    df = pd.read_csv(log_file)
    
    plt.figure(figsize=(10, 4))
    plt.plot(df['step'], df['train_loss'], label='Train Loss')
    if 'eval_loss' in df.columns:
        plt.plot(df['step'], df['eval_loss'], label='Eval Loss', linestyle='--')
    plt.xlabel('Step')
    plt.ylabel('Loss')
    plt.title('Recovery LoRA Training')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show()
    
    print(f"\nFinal train loss: {df['train_loss'].iloc[-1]:.4f}")
    print(f"Lowest train loss: {df['train_loss'].min():.4f}")
else:
    print("No training log found")

---

## Alternative: Knowledge Distillation Mode

Use a larger teacher model to guide the LoRA training.
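The KD command below passes `--kd-temperature 2.0` and `--kd-alpha 0.5`. In the standard (Hinton-style) formulation, the student matches the teacher's temperature-softened distribution through a KL-divergence term scaled by `T**2`, blended with ordinary cross-entropy on the true next token. Whether `train_recovery_lora.py` combines the terms exactly this way, and which term `alpha` weights, is an assumption. Below is a stdlib-only sketch for a single token position. `kd_loss` is a hypothetical name.

```python
import math

def log_softmax(logits, temperature=1.0):
    """Numerically stable log-softmax of logits / temperature."""
    scaled = [v / temperature for v in logits]
    m = max(scaled)
    log_z = m + math.log(sum(math.exp(v - m) for v in scaled))
    return [v - log_z for v in scaled]

def kd_loss(student_logits, teacher_logits, target, temperature=2.0, alpha=0.5):
    """alpha * T^2 * KL(teacher_T || student_T) + (1 - alpha) * CE(student, target)."""
    log_p_t = log_softmax(teacher_logits, temperature)
    log_p_s = log_softmax(student_logits, temperature)
    # KL(P || Q) = sum_i p_i * (log p_i - log q_i)
    kl = sum(math.exp(lt) * (lt - ls) for lt, ls in zip(log_p_t, log_p_s))
    ce = -log_softmax(student_logits)[target]  # CE at temperature 1 on the true token
    return alpha * temperature ** 2 * kl + (1 - alpha) * ce

# Identical logits: the KL term is 0, so the loss reduces to (1 - alpha) * CE
logits = [2.0, 0.5, -1.0]
print(round(kd_loss(logits, logits, target=0), 6))
print(round(0.5 * -log_softmax(logits)[0], 6))
```

This per-position form assumes the teacher and student share a vocabulary, so their logits line up index by index.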

In [None]:
# KD mode training (requires more VRAM)
kd_ckpt = V2_CHECKPOINT_LOCAL if 'V2_CHECKPOINT_LOCAL' in dir() else V2_CHECKPOINT

KD_CMD = f"""
python scripts/train_recovery_lora.py \
    --model {MODEL_ID} \
    --v2-checkpoint "{kd_ckpt}" \
    --train-data-hf {TRAIN_DATA_HF} \
    --hf-subset {HF_SUBSET} \
    --lora-mode kd \
    --teacher Qwen/Qwen3-4B-Instruct \
    --kd-temperature 2.0 \
    --kd-alpha 0.5 \
    --recovery-r 8 \
    --lr 3e-4 \
    --max-steps 500 \
    --seq-len 512 \
    --batch-size 2 \
    --output runs/recovery_kd
"""

print("KD mode command (requires A100 or similar):")
print(KD_CMD)

# Uncomment to run:
# !{KD_CMD}

---

## Alternative: SFT Mode with Alpaca

Supervised fine-tuning on instruction data.
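Alpaca records carry `instruction`, `input` (often empty), and `output` fields. A common SFT choice is to train only on response tokens by setting the labels at prompt positions to `-100`. That value is the default `ignore_index` of PyTorch's `CrossEntropyLoss`, and Hugging Face causal-LM models use it too. This notebook does not show the prompt template that the script applies with `--template-mode think`. `format_prompt` below is therefore a hypothetical stand-in, and `toy_tokenize` (one ID per whitespace-separated word) replaces the real tokenizer.

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss

def format_prompt(record):
    """Hypothetical prompt layout; the script's real template may differ."""
    prompt = f"Instruction: {record['instruction']}\n"
    if record.get("input"):
        prompt += f"Input: {record['input']}\n"
    return prompt + "Response: "

def build_example(record, tokenize, seq_len):
    """Return (input_ids, labels) with prompt positions masked out of the loss."""
    prompt_ids = tokenize(format_prompt(record))
    response_ids = tokenize(record["output"])
    input_ids = (prompt_ids + response_ids)[:seq_len]
    labels = ([IGNORE_INDEX] * len(prompt_ids) + response_ids)[:seq_len]
    return input_ids, labels

# Toy tokenizer: assigns each new whitespace-separated word the next integer ID
vocab = {}
def toy_tokenize(text):
    return [vocab.setdefault(w, len(vocab)) for w in text.split()]

rec = {"instruction": "Name a color.", "input": "", "output": "Blue."}
ids, labels = build_example(rec, toy_tokenize, seq_len=16)
print(len(ids), labels)  # 6 [-100, -100, -100, -100, -100, 5]
```

Labels stay aligned with `input_ids`. Hugging Face causal-LM models shift them by one position internally when computing the loss.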

In [None]:
# SFT mode with Alpaca dataset
sft_ckpt = V2_CHECKPOINT_LOCAL if 'V2_CHECKPOINT_LOCAL' in dir() else V2_CHECKPOINT

SFT_CMD = f"""
python scripts/train_recovery_lora.py \
    --model {MODEL_ID} \
    --v2-checkpoint "{sft_ckpt}" \
    --train-data-hf tatsu-lab/alpaca \
    --dataset-format alpaca \
    --template-mode think \
    --lora-mode sft \
    --recovery-r 8 \
    --lr 3e-4 \
    --max-steps 1000 \
    --seq-len 1024 \
    --batch-size 4 \
    --output runs/recovery_sft
"""

print("SFT mode command:")
print(SFT_CMD)

# Uncomment to run:
# !{SFT_CMD}

---

## Notes

**Memory Requirements:**
- `recover` mode: ~8-10 GB VRAM (T4 OK)
- `sft` mode: ~8-10 GB VRAM (T4 OK)
- `kd` mode: ~16-20 GB VRAM (A100/L4 recommended)
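The `kd` row is higher largely because the teacher's weights must sit in VRAM alongside the student. Weights alone cost parameters × bytes per parameter, and each full logits tensor costs batch × sequence length × vocabulary × bytes. The sketch below is rough and rests on several assumptions: the parameter counts (read loosely from the model names), the 151,936-token Qwen3 vocabulary, and the dtypes. Check all of these against each model's `config.json` and the script.

```python
def weight_memory_gb(n_params, bytes_per_param):
    """Memory for model weights alone (no activations, gradients, or optimizer state)."""
    return n_params * bytes_per_param / 1e9

def logits_memory_gb(batch, seq_len, vocab, bytes_per_value):
    """Memory for one full [batch, seq_len, vocab] logits tensor."""
    return batch * seq_len * vocab * bytes_per_value / 1e9

# Assumed parameter counts, taken loosely from the model names
MODELS = {"student (Qwen3-0.6B)": 0.6e9, "teacher (Qwen3-4B)": 4.0e9}
for name, n in MODELS.items():
    print(f"{name}: {weight_memory_gb(n, 2):.1f} GB bf16, {weight_memory_gb(n, 4):.1f} GB fp32")

# KD example settings (--batch-size 2 --seq-len 512), assuming fp32 logits for both models
per_model = logits_memory_gb(2, 512, 151_936, 4)
print(f"logits: {per_model:.2f} GB per model, {2 * per_model:.2f} GB for both")
```

Under these assumptions, a bf16 4B teacher needs about 8 GB for its weights before any activations. That is consistent with `kd` needing more memory than the other two modes.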

**Recommended Settings:**
- Start with `--recovery-r 8`; increase to 16 or 32 if needed
- Use `--mlp-only` first for faster iteration
- Use a sequence length of 1024-2048 for general recovery

**Checkpoint Types:**
- Full checkpoint: Contains base + LoRA weights (~5GB)
- LoRA-only (`--save-lora-only`): Just LoRA weights (~17MB)
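Adapter size grows linearly with rank. A LoRA pair on a `d_in -> d_out` projection adds `r * (d_in + d_out)` parameters, because `A` is `r x d_in` and `B` is `d_out x r`. The sketch below is a back-of-the-envelope estimate. The Qwen3-0.6B shapes are assumptions: hidden size 1024, intermediate size 3072, 16 query heads and 8 KV heads of dimension 128, and 28 layers. The set of adapted projections is also assumed. Verify both against the model's `config.json` and the script's target modules.

```python
def lora_params(d_in, d_out, r):
    """Parameters in one LoRA pair: A is r x d_in, B is d_out x r."""
    return r * (d_in + d_out)

# Assumed Qwen3-0.6B shapes (verify against config.json)
HIDDEN, INTER, N_LAYERS = 1024, 3072, 28
Q_DIM, KV_DIM = 16 * 128, 8 * 128

# Assumed target modules, as (d_in, d_out) per projection
ATTN = {"q_proj": (HIDDEN, Q_DIM), "k_proj": (HIDDEN, KV_DIM),
        "v_proj": (HIDDEN, KV_DIM), "o_proj": (Q_DIM, HIDDEN)}
MLP = {"gate_proj": (HIDDEN, INTER), "up_proj": (HIDDEN, INTER),
       "down_proj": (INTER, HIDDEN)}

def total_lora_params(r, mlp_only=False):
    modules = dict(MLP) if mlp_only else {**ATTN, **MLP}
    per_layer = sum(lora_params(d_in, d_out, r) for d_in, d_out in modules.values())
    return per_layer * N_LAYERS

for r in (8, 16, 32):
    n = total_lora_params(r)
    print(f"r={r:>2}: {n / 1e6:.2f}M params, ~{n * 2 / 1e6:.0f} MB fp16, ~{n * 4 / 1e6:.0f} MB fp32")
```

Under these assumptions, `MLP_ONLY = True` drops the four attention adapters per layer and cuts the adapter parameter count by roughly 45%. Doubling the rank doubles both the parameter count and the LoRA-only file size.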