# StreamGuard ML Training - Complete Notebook

**Version:** 1.7 (Safety Features Available - See instructions at end)  
**Last Updated:** 2025-11-01  
**Platform:** Google Colab (Free/Pro/Pro+)  
**GPU:** T4/V100/A100 (Adaptive Configuration)  
**Duration:** 11-24 hours (depends on GPU & config)  

This notebook trains all three StreamGuard models with **adaptive configuration** that automatically optimizes for your GPU.

## 🎯 Training Phases
1. **Enhanced SQL Intent Transformer** (2-8 hours depending on GPU)
2. **Enhanced Taint-Flow GNN** (4-12 hours depending on GPU)
3. **Fusion Layer** (2-10 hours depending on GPU)

## ✨ What's New in v1.7 (Safety Features)

**NEW: Optional Safety Features Available**
- ✅ **LR Finder with Safety Validation** (auto-detects optimal learning rate, 5e-4 cap, smart fallback)
- ✅ **LR Caching** (skip 5-10 min LR Finder on reruns, 168-hour cache)
- ✅ **Triple Weighting Auto-Adjustment** (prevents overcorrection when using sampler + weights + focal)
- ✅ **Enhanced Checkpoint Metadata** (includes seed, git commit, LR analysis)
- ✅ **Unit Tests** (14 tests verify all safety features)

**See instructions at the END of this notebook for how to use these features.**

**Backward Compatible:** All existing cells work exactly as before. New features are opt-in via CLI flags.

## ✨ What's New in v1.6 (Issue #11 - Training Collapse Fix)

### **CRITICAL: Training Collapse Fixed (Issue #11)**
- ✅ **Class-balanced loss with inverse frequency weights** (fixes model predicting only safe class)
- ✅ **LR scaling for large batches** (square-root rule: batch 64 gets 2x base LR)
- ✅ **Per-step scheduler** (moved inside train_epoch, was per-epoch before)
- ✅ **Gradient clipping** (max_norm=1.0 prevents exploding gradients)
- ✅ **Prediction distribution monitoring** (detects collapse early)
- ✅ **Enhanced collapse detection** (stops training if model predicts only one class)
- ✅ **Conservative label smoothing** (0.05 instead of 0.1)
- ✅ **Simplified loss calculation** (removed unnecessary sample-level weighting)

**Root Cause (Issue #11):** Model collapsed from F1=0.4337 (epoch 1) to F1=0.0000 (epoch 3+) due to:
1. No class balancing (54.2% safe vs 45.8% vulnerable)
2. LR designed for batch=16 but using batch=64
3. Scheduler stepping per-epoch instead of per-step
4. No gradient clipping
5. No early collapse detection

**The Fix:** All 8 critical fixes implemented in train_transformer.py (see `docs/ISSUE_11_TRAINING_COLLAPSE_COMPLETE_FIX.md`)

### **Previous Fixes (v1.5 - Issue #10)**
- ✅ **Max seq length configuration fixed** (512 for all GPUs, not 1024/768)
- ✅ **Automatic validation** to prevent exceeding CodeBERT's 512-token limit
- ✅ **Tensor size mismatch error prevented**
- ✅ **Updated PyTorch AMP API** (torch.amp instead of torch.cuda.amp)

### **Previous Fixes (v1.4 - Issue #9)**
- ✅ **Fixed CrossEntropyLoss tensor-to-scalar error**
- ✅ **Fixed sample weights handling**
- ✅ **Updated deprecated autocast/GradScaler**
- ✅ **Added Cell 1.5** (robust GPU detection with fallback)

### **Previous Fixes (v1.3 - Issue #8)**
- ✅ **Fixed NumPy binary incompatibility** (numpy==1.26.4 enforced)
- ✅ **Fixed tokenizers/transformers conflict** (tokenizers 0.14.1)
- ✅ **Fixed PyG circular import errors**

### **Adaptive GPU Configuration (Colab Pro)**
- 🔍 **Auto-detects GPU type** (T4/V100/A100) via Cell 1.5
- ⚙️  **Selects optimal hyperparameters** automatically
- 📊 **Three configuration tiers**:
  - **OPTIMIZED** (T4): 10/150/30 epochs, batch 32/64, seq 512, ~13-17h
  - **ENHANCED** (V100): 15/200/50 epochs, batch 48/96, seq 512, ~18-22h (2-3x faster)
  - **AGGRESSIVE** (A100): 20/300/100 epochs, batch 64/128, seq 512, ~20-24h (5-7x faster)

**Note:** All configurations use `max_seq_len = 512` (CodeBERT/RoBERTa model limit). Better GPUs benefit from larger batch sizes and more epochs.

### **Colab Pro Benefits**
- ✅ 24-hour runtime (vs 12h free)
- ✅ Better GPU access (V100, A100)
- ✅ Background execution
- ✅ **Larger batches → better gradient estimates**

**Recommended:** V100 on Colab Pro ($10/mo) for best balance of speed and availability.

## 🔧 All Critical Fixes Applied (v1.1 → v1.7)

### **v1.7 Fixes (Safety Features) - NEW**
- ✅ LR Finder with safety validation (5e-4 cap, 1e-5 fallback)
- ✅ LR caching (168-hour default, dataset fingerprint-based)
- ✅ Triple weighting auto-adjustment (20% reduction when all enabled)
- ✅ Enhanced checkpoint metadata (seed, git, LR analysis)
- ✅ Unit tests (14 tests for all safety features)

### **v1.6 Fixes (Issue #11)**
- ✅ Class-balanced loss with inverse frequency weights
- ✅ LR scaling for large batches (square-root rule)
- ✅ Warmup ratio adjustment (proportional, capped at 20%)
- ✅ Per-step scheduler (moved inside train_epoch)
- ✅ Gradient clipping (max_norm=1.0)
- ✅ Prediction distribution monitoring
- ✅ Enhanced collapse detection
- ✅ Conservative label smoothing (0.05)
- ✅ Drive-based data workflow (automatic copy to local storage)
- ✅ Pre-training validation tests

### **v1.5 Fixes (Issue #10)**
- ✅ Max seq length configuration fixed
- ✅ Automatic validation added
- ✅ Tensor size mismatch prevented
- ✅ PyTorch AMP API updated

### **v1.4 Fixes (Issue #9)**
- ✅ CrossEntropyLoss tensor-to-scalar error fixed
- ✅ Sample weights handling validated
- ✅ Deprecated API updated
- ✅ GPU detection robustness improved

### **v1.3 Fixes (Issue #8)**
- ✅ NumPy binary compatibility fixed
- ✅ tokenizers/transformers conflict resolved
- ✅ PyG circular import fixed

### **v1.1-v1.2 Fixes (Issues #1-#7)**
- ✅ Runtime-aware PyTorch Geometric installation
- ✅ Robust tree-sitter build with fallback
- ✅ Version compatibility validation
- ✅ Enhanced dependency conflict detection
- ✅ Optimized OOF fusion

## 📋 Before Starting

### **Colab Configuration:**
1. Enable GPU: **Runtime → Change runtime type → GPU**
2. **Recommended:** Subscribe to Colab Pro ($10/mo) for:
   - 24-hour runtime (required for full training)
   - Access to V100/A100 GPUs (2-7x faster than T4)
   - Background execution

### **Data Requirements - IMPORTANT:**

**You MUST upload preprocessed data files to Google Drive:**

```
My Drive/streamguard/data/processed/codexglue/
├── train.jsonl (504 MB, 21,854 samples)
├── valid.jsonl (63 MB, 2,732 samples)
├── test.jsonl (63 MB, 2,732 samples)
└── preprocessing_metadata.json (1.6 KB)
```

**Total size:** ~630 MB

**Why Google Drive?**
- Data files are too large for GitHub (exceeds 100 MB limit)
- They are in `.gitignore` and won't be cloned from the repository
- **Cell 6** will automatically mount Drive and copy data to Colab local storage
- Local storage provides faster I/O during training (vs reading from Drive each time)

**How to upload:**
1. Open Google Drive: https://drive.google.com/
2. Create folder structure: `My Drive/streamguard/data/processed/codexglue/`
3. Upload the 4 data files to this folder
4. Run notebook Cell 6 - it will copy files to Colab automatically

## 📊 Expected Results by Configuration

| Config | GPU | Time | Batch Sizes (T/G) | Seq Len | Speed vs T4 |
|--------|-----|------|-------------------|---------|-------------|
| **OPTIMIZED** | T4 | 13-17h | 32 / 64 | 512 | 1.0x |
| **ENHANCED** | V100 | 18-22h | 48 / 96 | 512 | 2-3x faster |
| **AGGRESSIVE** | A100 | 20-24h | 64 / 128 | 512 | 5-7x faster |

*Note: All configs use max_seq_len=512 (CodeBERT limit). Better GPUs use larger batches/epochs for quality.*

## 🚀 Quick Start

1. **Upload data to Drive** (see Data Requirements above)
2. Run **Cell 1**: Verify GPU is enabled
3. Run **Cell 1.5**: Auto-detect GPU and select configuration  
4. Run **Cell 2**: Install dependencies with compatibility fixes
5. Run **Cell 2.5**: Validate compatibility
6. Run **Cell 3**: Clone repository from GitHub
7. Run **Cell 4**: Setup tree-sitter
8. Run **Cell 6**: Mount Drive and copy data to local storage ⭐
9. **Run TEST CELLS 6.5 & 6.6**: Verify Issue #11 fixes (5-15 min total)
10. Run **Cells 7, 9, 11**: Full training with adaptive configuration
11. Monitor progress (can close browser with Colab Pro)

**IMPORTANT:** Run the test cells (6.5 & 6.6) before full training to verify all fixes are working!

**NEW:** For v1.7 safety features, see instructions at the END of this notebook.

## 🔗 Documentation

- **Training Collapse Fix:** See [docs/ISSUE_11_TRAINING_COLLAPSE_COMPLETE_FIX.md](https://github.com/VimalSajanGeorge/streamguard/blob/master/docs/ISSUE_11_TRAINING_COLLAPSE_COMPLETE_FIX.md)
- **Final Recommendations:** See [docs/ISSUE_11_FINAL_CAUTIONS_AND_RECOMMENDATIONS.md](https://github.com/VimalSajanGeorge/streamguard/blob/master/docs/ISSUE_11_FINAL_CAUTIONS_AND_RECOMMENDATIONS.md)
- **Max Seq Length Fix:** See [docs/ISSUE_10_MAX_SEQ_LEN_FIX.md](https://github.com/VimalSajanGeorge/streamguard/blob/master/docs/ISSUE_10_MAX_SEQ_LEN_FIX.md)
- **Critical Fixes Details:** See [docs/COLAB_CRITICAL_FIXES.md](https://github.com/VimalSajanGeorge/streamguard/blob/master/docs/COLAB_CRITICAL_FIXES.md)
- **Troubleshooting:** Check Issue #8, #9, #10, and #11 documentation for common errors

---
## Part 1: Environment Setup
Run these cells once at the beginning

In [None]:
# Cell 1: Verify GPU
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    print(f"CUDA Version: {torch.version.cuda}")
else:
    print("⚠️  WARNING: GPU not available! Enable GPU in Runtime → Change runtime type")

In [None]:
# Cell 1.5: GPU Detection & Adaptive Configuration (Colab Pro Optimization)
import subprocess
import json
import torch
import re

def get_gpu_info():
    """Detect GPU type and memory with robust fallback."""
    try:
        # Try nvidia-smi first (most reliable)
        result = subprocess.run(
            ['nvidia-smi', '--query-gpu=name,memory.total', '--format=csv,noheader'],
            capture_output=True, text=True, timeout=5
        )
        if result.returncode == 0:
            lines = result.stdout.strip().split('\n')
            # Use first GPU if multiple
            gpu_line = lines[0].split(',')
            gpu_name = gpu_line[0].strip()

            # Parse memory (handle "15360 MiB" or "15.36 GB")
            mem_str = gpu_line[1].strip()
            if 'MiB' in mem_str:
                gpu_memory = float(re.findall(r'\d+', mem_str)[0]) / 1024  # MiB to GB
            else:
                gpu_memory = float(re.findall(r'[\d.]+', mem_str)[0])

            return gpu_name, gpu_memory
    except (subprocess.TimeoutExpired, FileNotFoundError, IndexError, ValueError):
        pass

    # Fallback to PyTorch
    if torch.cuda.is_available():
        gpu_name = torch.cuda.get_device_name(0)
        gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)  # Bytes to GB
        return gpu_name, gpu_memory

    # No GPU available
    return "CPU", 0.0

gpu_name, gpu_memory_gb = get_gpu_info()
gpu_name_lower = gpu_name.lower()

# Determine configuration tier (case-insensitive matching)
# CRITICAL FIX (Issue #9): CodeBERT max_seq_len is 512 (514 with special tokens) - RoBERTa limitation
# Using max_seq_len > 512 causes: RuntimeError: The expanded size of the tensor (1024) must match the existing size (514)
if 'a100' in gpu_name_lower:
    config_tier = 'AGGRESSIVE'
    config = {
        'transformer': {'epochs': 20, 'batch_size': 64, 'max_seq_len': 512, 'patience': 5},
        'gnn': {'epochs': 300, 'batch_size': 128, 'hidden_dim': 512, 'num_layers': 5, 'patience': 15},
        'fusion': {'n_folds': 10, 'epochs': 100}
    }
    note = "Maximum configuration - larger batches and more epochs for best training quality"
elif 'v100' in gpu_name_lower:
    config_tier = 'ENHANCED'
    config = {
        'transformer': {'epochs': 15, 'batch_size': 48, 'max_seq_len': 512, 'patience': 3},
        'gnn': {'epochs': 200, 'batch_size': 96, 'hidden_dim': 384, 'num_layers': 5, 'patience': 12},
        'fusion': {'n_folds': 5, 'epochs': 50}
    }
    note = "Enhanced configuration - 2-3x faster than T4, larger batches for better gradient estimates"
else:  # T4 or other
    config_tier = 'OPTIMIZED'
    config = {
        'transformer': {'epochs': 10, 'batch_size': 32, 'max_seq_len': 512, 'patience': 2},
        'gnn': {'epochs': 150, 'batch_size': 64, 'hidden_dim': 256, 'num_layers': 4, 'patience': 10},
        'fusion': {'n_folds': 5, 'epochs': 30}
    }
    note = "Optimized for T4 - reliable and cost-effective"

# Save config for training cells
config_data = {'tier': config_tier, 'gpu': gpu_name, 'config': config}
with open('/tmp/gpu_training_config.json', 'w') as f:
    json.dump(config_data, f)

print("="*70)
print("ADAPTIVE GPU CONFIGURATION")
print("="*70)
print(f"Detected GPU: {gpu_name}")
print(f"GPU Memory: {gpu_memory_gb:.2f} GB")
print(f"\nConfiguration Tier: {config_tier}")
print(f"Note: {note}")
print("\nHyperparameters:")
print(f"  Transformer: {config['transformer']['epochs']} epochs, batch {config['transformer']['batch_size']}, seq {config['transformer']['max_seq_len']}")
print(f"  GNN: {config['gnn']['epochs']} epochs, batch {config['gnn']['batch_size']}, hidden {config['gnn']['hidden_dim']}")
print(f"  Fusion: {config['fusion']['n_folds']} folds, {config['fusion']['epochs']} epochs")
print("\n💡 Note: max_seq_len is 512 for all configs (CodeBERT/RoBERTa model limit)")
print("="*70)

In [None]:
# Cell 2: Install dependencies with runtime detection and compatibility fixes
# ⚠️ CRITICAL: Includes NumPy compatibility fix, correct tokenizers version, and PyG error handling

import subprocess
import sys
import importlib

def run_cmd(cmd):
    """Run shell command and return success status."""
    print(f"Running: {cmd}")
    result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
    if result.returncode != 0:
        print(f"Error: {result.stderr}")
        return False
    return True

print("="*70)
print("INSTALLING DEPENDENCIES WITH COMPATIBILITY FIXES")
print("="*70)

# [1/9] CRITICAL: Fix NumPy version FIRST (before any torch imports)
print("\n[1/9] Ensuring NumPy compatibility...")
try:
    import numpy
    numpy_ver = numpy.__version__
    numpy_major = int(numpy_ver.split('.')[0])

    if numpy_major >= 2:
        print(f"⚠️  Detected NumPy {numpy_ver} (v2.x)")
        print("   PyTorch wheels may have binary incompatibility")
        print("   Downgrading to NumPy 1.26.4...")
        subprocess.run([sys.executable, "-m", "pip", "install", "-q", "numpy==1.26.4", "--force-reinstall"], check=True)
        print("✓ NumPy downgraded to 1.26.4")
        # Reload numpy
        importlib.reload(numpy)
        print(f"✓ NumPy {numpy.__version__} loaded (binary compatible)")
    else:
        print(f"✓ NumPy {numpy_ver} (v1.x - already compatible)")
except ImportError:
    print("NumPy not installed, installing 1.26.4...")
    subprocess.run([sys.executable, "-m", "pip", "install", "-q", "numpy==1.26.4"], check=True)
    import numpy
    print(f"✓ NumPy {numpy.__version__} installed")

# [2/9] Detect PyTorch and CUDA versions (now safe with correct numpy)
print("\n[2/9] Detecting PyTorch and CUDA versions...")
import torch

torch_version = torch.__version__.split('+')[0]  # e.g., '2.8.0'
cuda_version = torch.version.cuda  # e.g., '12.6'
cuda_tag = f"cu{cuda_version.replace('.', '')}" if cuda_version else 'cpu'  # e.g., 'cu126'

print(f"✓ Detected PyTorch {torch_version}")
print(f"✓ Detected CUDA {cuda_version if cuda_version else 'N/A (CPU only)'}")
print(f"✓ Using wheel tag: {cuda_tag}")

# [3/9] Install PyTorch Geometric with enhanced error handling
print("\n[3/9] Installing PyTorch Geometric (runtime-aware with fallback)...")
pyg_wheel_url = f"https://data.pyg.org/whl/torch-{torch_version}+{cuda_tag}.html"
print(f"Wheel URL: {pyg_wheel_url}")

pyg_packages = ['torch-scatter', 'torch-sparse', 'torch-cluster', 'torch-spline-conv']
pyg_install_success = True

for pkg in pyg_packages:
    print(f"  Installing {pkg}...")
    if not run_cmd(f"pip install -q {pkg} -f {pyg_wheel_url}"):
        print(f"    ⚠️  Wheel install failed, trying source build...")
        if not run_cmd(f"pip install -q {pkg} --no-binary {pkg}"):
            print(f"    ❌ Failed to install {pkg}")
            pyg_install_success = False
        else:
            print(f"    ✓ {pkg} installed from source (slower)")
    else:
        print(f"    ✓ {pkg} installed from wheel")

if pyg_install_success:
    run_cmd("pip install -q torch-geometric==2.4.0")
    print("✅ PyTorch Geometric installed successfully")
else:
    print("⚠️  Some PyG packages failed - GNN training may have issues")

# [4/9] Install Transformers with COMPATIBLE tokenizers version
print("\n[4/9] Installing Transformers with compatible tokenizers...")
print("⚠️  Note: Using tokenizers 0.14.1 (compatible with transformers 4.35.0)")

# Install transformers first, then pin tokenizers to compatible version
if not run_cmd("pip install -q transformers==4.35.0"):
    print("❌ Transformers installation failed")
else:
    # Now pin tokenizers to compatible version
    if not run_cmd("pip install -q tokenizers==0.14.1"):
        print("⚠️  Could not pin tokenizers to 0.14.1, using auto-resolved version")
    else:
        print("✓ Tokenizers 0.14.1 installed (compatible)")

# Install accelerate
run_cmd("pip install -q accelerate==0.24.1")

# [5/9] Install tree-sitter
print("\n[5/9] Installing tree-sitter...")
run_cmd("pip install -q tree-sitter==0.20.4")

# [6/9] Install additional packages
print("\n[6/9] Installing additional packages...")
run_cmd("pip install -q scikit-learn==1.3.2 scipy==1.11.4 tqdm")

# [7/9] Verify installations with enhanced checks
print("\n[7/9] Verifying installations...")
try:
    # Check NumPy first (critical)
    import numpy
    numpy_ver = numpy.__version__
    numpy_major = int(numpy_ver.split('.')[0])
    if numpy_major >= 2:
        print(f"⚠️  WARNING: NumPy {numpy_ver} detected (should be 1.x)")
        print("   Binary compatibility issues may occur")
    else:
        print(f"✓ NumPy: {numpy_ver} (binary compatible)")

    # Check other packages
    import torch
    import torch_geometric
    import transformers
    import tree_sitter
    import sklearn

    print(f"✓ PyTorch: {torch.__version__}")
    print(f"✓ PyTorch Geometric: {torch_geometric.__version__}")
    print(f"✓ Transformers: {transformers.__version__}")

    # Check tokenizers compatibility
    import tokenizers
    tokenizers_ver = tokenizers.__version__
    print(f"✓ Tokenizers: {tokenizers_ver}")

    if tokenizers_ver.startswith("0.15"):
        print(f"  ⚠️  WARNING: tokenizers {tokenizers_ver} may conflict with transformers 4.35.0")
    elif tokenizers_ver.startswith("0.14"):
        print(f"  ✓ Tokenizers version compatible")

    print(f"✓ tree-sitter: {tree_sitter.__version__}")
    print(f"✓ scikit-learn: {sklearn.__version__}")

except Exception as e:
    print(f"❌ Verification failed: {e}")
    print("   Please restart runtime and try again")
    print("   If issue persists, check:")
    print("   1. NumPy version (should be 1.26.4)")
    print("   2. Tokenizers version (should be 0.14.1)")

# [8/9] Test PyTorch Geometric installation
print("\n[8/9] Testing PyTorch Geometric...")
try:
    from torch_geometric.data import Data
    test_data = Data(x=torch.randn(5, 3), edge_index=torch.tensor([[0, 1], [1, 0]]))
    print("✓ PyTorch Geometric working correctly")
    print(f"✓ Test data created: {test_data}")
except Exception as e:
    print(f"⚠️  PyTorch Geometric test failed: {e}")
    print("   GNN training may have issues")
    print("   Possible causes:")
    print("   1. NumPy binary incompatibility")
    print("   2. PyG wheel installation failed")
    print("   3. CUDA version mismatch")

# [9/9] Display final summary
print("\n[9/9] Installation Summary:")
print("="*70)

success_indicators = {
    'numpy_compatible': numpy_major < 2 if 'numpy_major' in locals() else False,
    'pyg_installed': pyg_install_success,
    'transformers_installed': True,  # Assume success if we got here
    'tokenizers_compatible': tokenizers_ver.startswith("0.14") if 'tokenizers_ver' in locals() else False
}

all_success = all(success_indicators.values())

if all_success:
    print("✅ ALL INSTALLATIONS SUCCESSFUL")
    print("✓ NumPy 1.x (binary compatible)")
    print("✓ PyTorch Geometric with correct wheels")
    print("✓ Transformers with compatible tokenizers")
    print("✓ All packages verified")
else:
    print("⚠️  INSTALLATION COMPLETED WITH WARNINGS:")
    if not success_indicators['numpy_compatible']:
        print("  • NumPy version may cause binary incompatibility")
    if not success_indicators['pyg_installed']:
        print("  • PyG packages had installation issues")
    if not success_indicators['tokenizers_compatible']:
        print("  • Tokenizers version may conflict with transformers")
    print("\n  Training may still work, but monitor for errors")
print("="*70)

In [None]:
# Cell 2.5: Enhanced Version & Dependency Compatibility Check (v1.1)
# Validates versions, checks for dependency conflicts, validates PyG wheels

import torch
import torch_geometric
import transformers
import importlib
import sys

print("="*70)
print("ENHANCED DEPENDENCY & VERSION COMPATIBILITY CHECK")
print("="*70)

# [1/4] Check core versions
torch_ver = torch.__version__
pyg_ver = torch_geometric.__version__
transformers_ver = transformers.__version__
cuda_ver = torch.version.cuda if torch.cuda.is_available() else "N/A"

print(f"\n[1/4] Installed Core Versions:")
print(f"  PyTorch: {torch_ver}")
print(f"  PyTorch Geometric: {pyg_ver}")
print(f"  Transformers: {transformers_ver}")
print(f"  CUDA: {cuda_ver}")

# [2/4] Check for problematic optional dependencies (CRITICAL FIX #4)
print(f"\n[2/4] Checking Optional Dependencies:")
optional_deps = {
    'sentence_transformers': None,
    'datasets': None,
    'fsspec': None,
    'gcsfs': None
}

for pkg_name in optional_deps.keys():
    try:
        pkg = importlib.import_module(pkg_name)
        version = getattr(pkg, '__version__', 'unknown')
        optional_deps[pkg_name] = version
        print(f"  ⚠️  {pkg_name}: {version} (not needed for training)")
    except ImportError:
        print(f"  ✓ {pkg_name}: not installed (correct)")

# Check for version conflicts
has_conflicts = False
if optional_deps.get('sentence_transformers'):
    print("\n  ⚠️  WARNING: sentence-transformers detected")
    print("     May conflict with transformers==4.35.0")
    print("     If errors occur, uninstall: !pip uninstall -y sentence-transformers")
    has_conflicts = True

if optional_deps.get('datasets'):
    print("\n  ⚠️  WARNING: datasets library detected")
    print("     May pull incompatible transformers/tokenizers versions")
    has_conflicts = True

# [3/4] Validate PyG wheel URL (CRITICAL FIX #4)
print(f"\n[3/4] Validating PyTorch Geometric Installation:")
torch_version = torch_ver.split('+')[0]
cuda_tag = f"cu{cuda_ver.replace('.', '')}" if cuda_ver != "N/A" else 'cpu'
pyg_wheel_url = f"https://data.pyg.org/whl/torch-{torch_version}+{cuda_tag}.html"

print(f"  Expected wheel URL: {pyg_wheel_url}")

# Quick test PyG installation
try:
    from torch_geometric.data import Data
    test_data = Data(x=torch.randn(5, 3), edge_index=torch.tensor([[0, 1], [1, 0]]))
    print(f"  ✓ PyTorch Geometric working correctly")
    print(f"  ✓ Wheels matched PyTorch {torch_version} + {cuda_tag}")
except Exception as e:
    print(f"  ❌ PyTorch Geometric test failed: {e}")
    print(f"  ⚠️  Wheel URL may be incorrect - check {pyg_wheel_url}")

# [4/4] Core compatibility checks
print(f"\n[4/4] Core Compatibility Checks:")
warnings = []
errors = []

# Check PyTorch version
torch_major = int(torch_ver.split('.')[0])
if torch_major < 2:
    warnings.append("⚠️  PyTorch 2.x+ recommended (you have {torch_ver})")

# Check CUDA availability (CRITICAL)
if not torch.cuda.is_available():
    errors.append("❌ CUDA not available - training will be EXTREMELY slow")
    errors.append("   Enable GPU: Runtime → Change runtime type → GPU")

# Check PyG compatibility
pyg_major = int(pyg_ver.split('.')[0])
if pyg_major < 2:
    warnings.append("⚠️  PyTorch Geometric 2.x+ recommended")

# Check GPU memory
if torch.cuda.is_available():
    gpu_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    if gpu_mem_gb < 12:
        warnings.append(f"⚠️  GPU has only {gpu_mem_gb:.1f} GB RAM (16GB+ recommended)")
        warnings.append("   Consider reducing batch sizes if OOM errors occur")

# Display results
print("\n" + "="*70)
if errors:
    print("🔴 CRITICAL ERRORS:")
    for e in errors:
        print(f"  {e}")
    print("\n❌ CANNOT PROCEED - Fix errors above")
    print("="*70)
    raise RuntimeError("Environment validation failed")
elif warnings or has_conflicts:
    if warnings:
        print("⚠️  Compatibility Warnings:")
        for w in warnings:
            print(f"  {w}")
    if has_conflicts:
        print("\n⚠️  Dependency Conflicts Detected:")
        print("  Monitor for errors during training")
        print("  If issues occur, restart runtime and reinstall dependencies")
    print("\n✓ You can proceed but may need adjustments")
else:
    print("✅ ALL CHECKS PASSED - Ready for production training!")

print("="*70)

In [None]:
# Cell 3: Clone/Update repository from GitHub
import os
from pathlib import Path

# Clone or update StreamGuard repository
if not Path('streamguard').exists():
    print("Cloning StreamGuard repository...")
    !git clone https://github.com/VimalSajanGeorge/streamguard.git
    print("✓ Repository cloned")
else:
    print("✓ Repository already exists")
    print("Pulling latest changes...")
    os.chdir('streamguard')
    !git pull origin master
    print("✓ Repository updated")
    os.chdir('..')

os.chdir('streamguard')
print(f"\nWorking directory: {os.getcwd()}")
print("\n💡 All code changes from GitHub are now available!")
print("   No need to manually upload files to Google Drive")

In [None]:
# Cell 4: Setup tree-sitter with robust error handling
# ⚠️ CRITICAL: Includes fallback if build fails

from pathlib import Path
from tree_sitter import Language

print("="*70)
print("TREE-SITTER SETUP (with fallback support)")
print("="*70)

# Clone tree-sitter-c
vendor_dir = Path('vendor')
vendor_dir.mkdir(exist_ok=True)

if not (vendor_dir / 'tree-sitter-c').exists():
    print("\n[1/3] Cloning tree-sitter-c...")
    !cd vendor && git clone --depth 1 https://github.com/tree-sitter/tree-sitter-c.git
    print("✓ tree-sitter-c cloned")
else:
    print("\n[1/3] ✓ tree-sitter-c already exists")

# Build library with error handling
build_dir = Path('build')
build_dir.mkdir(exist_ok=True)
lib_path = build_dir / 'my-languages.so'

build_success = False

if not lib_path.exists():
    print("\n[2/3] Building tree-sitter library...")
    try:
        Language.build_library(
            str(lib_path),
            [str(vendor_dir / 'tree-sitter-c')]
        )
        print("✓ Build completed")

        # Verify build
        if lib_path.exists():
            print("\n[3/3] Verifying build...")
            try:
                test_lang = Language(str(lib_path), 'c')
                print("✓ tree-sitter library verified successfully")
                build_success = True
            except Exception as e:
                print(f"⚠️  Verification failed: {e}")
        else:
            print("⚠️  Build completed but library file not found")

    except Exception as e:
        print(f"⚠️  Build failed: {e}")
        print("   Common causes: missing compiler, permission issues")
else:
    print("\n[2/3] ✓ tree-sitter library already exists")
    print("\n[3/3] Verifying existing build...")
    try:
        test_lang = Language(str(lib_path), 'c')
        print("✓ Existing library verified")
        build_success = True
    except Exception as e:
        print(f"⚠️  Existing library invalid: {e}")

# Display final status
print("\n" + "="*70)
if build_success:
    print("✅ AST PARSING ENABLED (optimal)")
    print("   Preprocessing will use full AST structure")
else:
    print("⚠️  AST PARSING WILL USE FALLBACK MODE")
    print("   Preprocessing will use token-sequence graphs")
    print("   ✓ Training will still work correctly")
    print("   ✓ Performance impact: minimal (<5%)")
print("="*70)

### Platform Notes: tree-sitter on Windows/Linux

**Google Colab (Linux):**
- ✅ Works out-of-the-box with `.so` libraries
- ✅ GCC compiler available by default

**Windows (Local Development):**
- ⚠️  Requires Microsoft Visual C++ 14.0+ (MSVC)
- ⚠️  May fail with "compiler not found" errors
- **Solution 1:** Use WSL (Windows Subsystem for Linux) for preprocessing
- **Solution 2:** Use Colab for all preprocessing tasks
- **Solution 3:** Install Visual Studio Build Tools (large download)
- ✓ **Fallback:** Token-sequence graphs work fine (<5% performance impact)

**Recommendation:** For Windows users, use Colab for data preprocessing and training. Download preprocessed data to Windows only for inference/deployment.

---
## Part 1.5: Pre-Training Validation Tests (Issue #11 Fix Verification)

**IMPORTANT:** Run these test cells BEFORE full training to verify all Issue #11 fixes are working correctly.

These tests verify:
1. ✅ Class-balanced loss is working (model doesn't collapse to one class)
2. ✅ LR scaling and warmup are correct
3. ✅ Scheduler steps properly (per-step, not per-epoch)
4. ✅ Gradient clipping prevents exploding gradients
5. ✅ Prediction distribution monitoring detects collapse
6. ✅ Checkpoint saving/loading works with PyTorch 2.6+

**Expected Results:**
- **Test 1 (Tiny Overfitting Test):** Loss should decrease to near 0, F1 should reach 0.9+
- **Test 2 (Short Full-Data Test):** F1 should increase each epoch, prediction distribution should be balanced
- If tests pass, proceed to full training with confidence!

**Duration:** 5-10 minutes total

In [None]:
# Cell 6: Setup data from Google Drive
import os
import shutil
from pathlib import Path
import json

print("="*70)
print("SETTING UP DATA FROM GOOGLE DRIVE")
print("="*70)

# Ensure we're in the streamguard directory
os.chdir('/content/streamguard')
print(f"Working directory: {os.getcwd()}")

# Step 1: Mount Google Drive
print(f"\n[1/5] Mounting Google Drive...")
from google.colab import drive
drive.mount('/content/drive', force_remount=False)
print("✓ Google Drive mounted")

# Step 2: Check if data exists in Drive
drive_data_path = Path('/content/drive/MyDrive/streamguard/data/processed/codexglue')
print(f"\n[2/5] Checking for data in Google Drive...")
print(f"   Looking in: {drive_data_path}")

if not drive_data_path.exists():
    print(f"❌ ERROR: Data not found in Google Drive!")
    print(f"\n💡 Please upload the preprocessed data to Google Drive:")
    print(f"   1. Create folder: My Drive/streamguard/data/processed/codexglue/")
    print(f"   2. Upload these files:")
    print(f"      • train.jsonl (504 MB)")
    print(f"      • valid.jsonl (63 MB)")
    print(f"      • test.jsonl (63 MB)")
    print(f"      • preprocessing_metadata.json (1.6 KB)")
    print(f"\n   Total: ~630 MB")
    raise FileNotFoundError(f"Data not found in Drive: {drive_data_path}")

print(f"✓ Data found in Google Drive")

# Step 3: Check all required files
print(f"\n[3/5] Verifying data files in Drive...")
required_files = ['train.jsonl', 'valid.jsonl', 'test.jsonl', 'preprocessing_metadata.json']
missing_files = []

drive_sizes = {}
for file in required_files:
    file_path = drive_data_path / file
    if file_path.exists():
        size_mb = file_path.stat().st_size / (1024 * 1024)
        drive_sizes[file] = size_mb
        print(f"  ✓ {file:<30} ({size_mb:>8.2f} MB)")
    else:
        print(f"  ❌ {file:<30} MISSING")
        missing_files.append(file)

if missing_files:
    print(f"\n❌ ERROR: Missing {len(missing_files)} required file(s) in Drive")
    print(f"   Missing: {', '.join(missing_files)}")
    raise FileNotFoundError(f"Missing data files in Drive: {missing_files}")

total_size = sum(drive_sizes.values())
print(f"\n📦 Total data size in Drive: {total_size:.2f} MB")

# Step 4: Create local data directory and copy files
local_data_path = Path('/content/streamguard/data/processed/codexglue')
local_data_path.mkdir(parents=True, exist_ok=True)

print(f"\n[4/5] Copying data from Drive to Colab local storage...")
print(f"   Source: {drive_data_path}")
print(f"   Destination: {local_data_path}")
print(f"   (This provides faster I/O during training)\n")

for file in required_files:
    src = drive_data_path / file
    dst = local_data_path / file

    if dst.exists():
        # Check if sizes match (skip if already copied)
        src_size = src.stat().st_size
        dst_size = dst.stat().st_size
        if src_size == dst_size:
            print(f"  ✓ {file:<30} (already copied, skipping)")
            continue

    print(f"  📋 Copying {file:<30} ({drive_sizes[file]:.2f} MB)...", end='', flush=True)
    shutil.copy2(src, dst)
    print(" ✓")

print(f"\n✅ All data files copied to local storage!")

# Step 5: Load and display metadata
print(f"\n[5/5] Loading dataset statistics...")
metadata_path = local_data_path / 'preprocessing_metadata.json'
if metadata_path.exists():
    with open(metadata_path, 'r') as f:
        metadata = json.load(f)

    print(f"\n📊 Dataset Statistics:")
    total_samples = 0
    for split in ['train', 'validation', 'test']:
        if split in metadata:
            count = metadata[split].get('total_samples', 0)
            total_samples += count
            print(f"  {split.capitalize():<12}: {count:>6} samples")

    print(f"\n💡 Total samples: {total_samples:,}")

    # Show class distribution if available
    if 'train' in metadata and 'label_distribution' in metadata['train']:
        dist = metadata['train']['label_distribution']
        print(f"\n📊 Class Distribution (Training Set):")
        for label, count in dist.items():
            percentage = (count / metadata['train']['total_samples']) * 100
            print(f"  {label:<15}: {count:>6} ({percentage:>5.1f}%)")
else:
    print(f"  ⚠️  Metadata file not found")

print("\n" + "="*70)
print("✅ DATA SETUP COMPLETE - Ready for training!")
print("="*70)
print(f"\n💡 Training scripts will read from:")
print(f"   • {local_data_path / 'train.jsonl'}")
print(f"   • {local_data_path / 'valid.jsonl'}")
print(f"   • {local_data_path / 'test.jsonl'}")
print(f"\n💾 Data is now in Colab local storage (faster I/O than Drive)")
print("="*70)

---

## Part 2: Production Training Pipeline (v1.7)

**Multi-seed training with optimized LR Finder and safety features**

### What's New in v1.7:
- ✅ LR Finder runs ONCE (not 3 times) - saves 10-20 minutes
- ✅ Fixed data path bug (valid.jsonl not val.jsonl)
- ✅ Better error handling and progress tracking
- ✅ Automatic F1 score extraction from logs
- ✅ Graph data validation before training

### Training Pipeline:
1. **Data Validation** - Verify all required data exists
2. **Transformer Training** - 3 seeds: [42, 2025, 7] (~40-60 min)
3. **GNN Training** - 3 seeds: [42, 2025, 7] (~45-70 min)
4. **Results Summary** - Aggregated metrics across seeds

---


In [None]:
# Cell 12: Data Validation Pre-Flight Check
# Verify all required data exists before starting production training

import os
from pathlib import Path
import json

print('='*80)
print('DATA VALIDATION PRE-FLIGHT CHECK')
print('='*80 + '\n')

# Check CodeXGLUE data
data_files = {
    'Training data': 'data/processed/codexglue/train.jsonl',
    'Validation data': 'data/processed/codexglue/valid.jsonl',  # FIXED: not val.jsonl
    'Test data': 'data/processed/codexglue/test.jsonl'
}

all_files_exist = True
for name, path in data_files.items():
    file_path = Path(path)
    exists = file_path.exists()
    status = '✅' if exists else '❌'
    
    if exists:
        # Count lines
        with open(file_path, 'r') as f:
            count = sum(1 for _ in f)
        print(f'{status} {name:20s}: {path:50s} ({count:,} samples)')
    else:
        print(f'{status} {name:20s}: {path:50s} (NOT FOUND)')
        all_files_exist = False

# Check graph data (optional for GNN)
print('\n' + '-'*80)
print('GRAPH DATA (Required for GNN Training)')
print('-'*80 + '\n')

graph_train = Path('data/processed/graphs/train')
graph_val = Path('data/processed/graphs/val')

if graph_train.exists() and graph_val.exists():
    print('✅ Graph data found:')
    print(f'   Train: {graph_train}')
    print(f'   Val:   {graph_val}')
else:
    print('⚠️  Graph data not found. GNN training will fail.')
    print('\n   To create graph data, run:')
    print('   !python training/preprocessing/create_simple_graph_data.py \\')
    print('     --input data/processed/codexglue/train.jsonl \\')
    print('     --output data/processed/graphs/train')
    print('\n   (This will be automated in Cell 16 if needed)')

# Summary
print('\n' + '='*80)
if all_files_exist:
    print('✅ PRE-FLIGHT CHECK PASSED - Ready for production training!')
else:
    print('❌ PRE-FLIGHT CHECK FAILED - Missing required data files')
    print('\n   Please ensure data preprocessing completed successfully.')
print('='*80)


---

### Step 1: Transformer v1.7 Production Training

**3-seed training with LR Finder optimization**

**Duration:** ~40-60 minutes total
- LR Finder: 2-3 min (runs once)
- Seed 42: ~12-18 min
- Seed 2025: ~12-18 min
- Seed 7: ~12-18 min

**Key Features:**
- Multi-seed training for statistical validity
- Optimized LR Finder (runs once, cached for all seeds)
- Mixed precision training (faster on A100/V100)
- Weighted sampling for class balance
- Early stopping on F1 score

---


In [None]:

# Cell 13: Transformer Training Preset Catalog (Stability-first)

from dataclasses import dataclass, field
from typing import Dict, List, Optional
from pathlib import Path

_PROJECT_ROOT_HINTS = [
    Path('/content/streamguard'),
    Path.cwd(),
    Path.cwd() / 'streamguard'
]

if 'locate_project_root' not in globals():
    def locate_project_root():
        for candidate in _PROJECT_ROOT_HINTS:
            candidate = candidate.resolve()
            if (candidate / 'training' / 'train_transformer.py').exists():
                return candidate
        raise FileNotFoundError('Could not locate StreamGuard project root containing training/train_transformer.py')

PROJECT_ROOT = locate_project_root() if 'PROJECT_ROOT' not in globals() else Path(PROJECT_ROOT).resolve()
TRAIN_DATA = 'data/processed/codexglue/train.jsonl'
VAL_DATA = 'data/processed/codexglue/valid.jsonl'


@dataclass
class TrainingPreset:
    description: str
    epochs: int
    batch_size: int
    lr: float
    weight_multiplier: float
    dropout: float
    warmup_ratio: float
    max_seq_len: int
    accumulation_steps: int = 1
    early_stopping: int = 3
    weight_decay: float = 0.01
    mixed_precision: bool = True
    weighted_sampler: bool = True
    code_features: bool = False
    quick_test: bool = False
    focal_loss: bool = False
    lr_override: Optional[float] = None
    extra_args: List[str] = field(default_factory=list)
    notes: str = ''

    def cli_flags(self) -> List[str]:
        flags = [
            f'--epochs={self.epochs}',
            f'--batch-size={self.batch_size}',
            f'--lr={self.lr}',
            f'--weight-multiplier={self.weight_multiplier}',
            f'--dropout={self.dropout}',
            f'--warmup-ratio={self.warmup_ratio}',
            f'--max-seq-len={self.max_seq_len}',
            f'--accumulation-steps={self.accumulation_steps}',
            f'--early-stopping-patience={self.early_stopping}',
            f'--weight-decay={self.weight_decay}'
        ]
        if self.lr_override is not None:
            flags.append(f'--lr-override={self.lr_override}')
        if self.mixed_precision:
            flags.append('--mixed-precision')
        if self.weighted_sampler:
            flags.append('--use-weighted-sampler')
        if self.code_features:
            flags.append('--use-code-features')
        if self.quick_test:
            flags.append('--quick-test')
        if self.focal_loss:
            flags.append('--focal-loss')
        flags.extend(self.extra_args)
        return flags


TRAINING_PRESETS: Dict[str, TrainingPreset] = {
    'sanity_fast': TrainingPreset(
        description='2-epoch smoke test on 500 samples to confirm the pipeline runs end-to-end.',
        epochs=2,
        batch_size=16,
        lr=1.2e-5,
        weight_multiplier=1.3,
        dropout=0.15,
        warmup_ratio=0.05,
        max_seq_len=320,
        accumulation_steps=1,
        early_stopping=2,
        quick_test=True,
        notes='Finishes in <10 minutes on a T4 while validating the data/loader stack.'
    ),
    'balanced_medium': TrainingPreset(
        description='Stable full-dataset baseline for 22 GB GPUs (keeps memory <90%).',
        epochs=8,
        batch_size=24,
        lr=2.0e-5,
        weight_multiplier=1.5,
        dropout=0.20,
        warmup_ratio=0.10,
        max_seq_len=448,
        accumulation_steps=2,
        early_stopping=3,
        notes='Gradient accumulation (x2) keeps the effective batch ≈48 without triggering OOM.'
    ),
    'high_recall': TrainingPreset(
        description='High-recall preset with stronger weighting + code features.',
        epochs=12,
        batch_size=20,
        lr=1.8e-5,
        weight_multiplier=1.8,
        dropout=0.25,
        warmup_ratio=0.08,
        max_seq_len=448,
        accumulation_steps=2,
        early_stopping=4,
        code_features=True,
        focal_loss=True,
        notes='Prioritize vulnerable recall when auditing critical releases.'
    )
}

DEFAULT_PRESET = 'balanced_medium'

print('=' * 80)
print('TRANSFORMER TRAINING PRESET CATALOG')
print('=' * 80)
for name, preset in TRAINING_PRESETS.items():
    print()
    print(f"Preset: {name}")
    print(f"  Description: {preset.description}")
    print(
        f"  Key hyperparameters: batch={preset.batch_size}, epochs={preset.epochs}, "
        f"lr={preset.lr:.2e}, accumulation={preset.accumulation_steps}, "
        f"max_seq_len={preset.max_seq_len}, weight_mult={preset.weight_multiplier}"
    )
    if preset.notes:
        print(f"  Notes: {preset.notes}")

print()
print(f"Default preset: {DEFAULT_PRESET}")


In [None]:

# Cell 14: Transformer Preset Training Runner (Single/Multi Seed)

import os
import sys
import json
import subprocess
from pathlib import Path
import re
from datetime import datetime
from statistics import mean, pstdev

if 'TRAINING_PRESETS' not in globals():
    raise RuntimeError('Preset catalog not initialized. Run Cell 13 first.')

PRESET_NAME = DEFAULT_PRESET
SEEDS = [42]
ENABLE_LR_FINDER = False
FORCE_LR_FINDER = False
EXTRA_FLAGS = []  # e.g. ['--use-code-features']

project_root = Path(PROJECT_ROOT).resolve() if 'PROJECT_ROOT' in globals() else locate_project_root()
PROJECT_ROOT = project_root
if Path.cwd().resolve() != project_root:
    os.chdir(project_root)
    print(f"[setup] Working directory set to {project_root}")

for label, rel_path in [('train', TRAIN_DATA), ('validation', VAL_DATA)]:
    data_path = project_root / rel_path
    if not data_path.exists():
        raise FileNotFoundError(f"Missing {label} data file: {data_path}")

OUTPUT_BASE = project_root / 'training' / 'outputs' / 'transformer_presets' / PRESET_NAME
OUTPUT_BASE.mkdir(parents=True, exist_ok=True)
LOG_DIR = OUTPUT_BASE / '_logs'
LOG_DIR.mkdir(parents=True, exist_ok=True)

if 'stream_subprocess' not in globals():
    def stream_subprocess(name, cmd, log_path):
        print('=' * 80)
        print(f"[run] {name}")
        print('=' * 80)
        print(' '.join(map(str, cmd)))
        print('-' * 80)
        log_path.parent.mkdir(parents=True, exist_ok=True)
        captured_lines = []
        with log_path.open('w', encoding='utf-8') as log_file:
            process = subprocess.Popen(
                cmd,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                text=True,
                cwd=project_root
            )
            for line in process.stdout:
                print(line, end='')
                log_file.write(line)
                log_file.flush()
                captured_lines.append(line)
            return_code = process.wait()
        print('-' * 80)
        print(f"[log] Output saved to {log_path}")
        print()
        output_text = ''.join(captured_lines)
        if return_code != 0:
            raise subprocess.CalledProcessError(return_code, cmd, output=output_text)
        return output_text

preset = TRAINING_PRESETS[PRESET_NAME]
print('=' * 80)
print('TRANSFORMER PRESET TRAINING')
print('=' * 80)
print(f"Preset: {PRESET_NAME} -> {preset.description}")
print(
    f"Preset HP: batch={preset.batch_size}, epochs={preset.epochs}, lr={preset.lr:.2e}, "
    f"accumulation={preset.accumulation_steps}, max_seq_len={preset.max_seq_len}, weight_mult={preset.weight_multiplier}"
)
print(f"Seeds: {SEEDS}")
print(f"LR Finder: {'ON' if ENABLE_LR_FINDER else 'OFF'}")
print(f"Output Base: {OUTPUT_BASE}")
print('=' * 80)
print()

if not SEEDS:
    raise ValueError('SEEDS list must contain at least one value.')

results = []
f1_pattern = re.compile(
    r"(?:New best model! F1: ([0-9.]+))|(?:Best validation F1 \(vulnerable\): ([0-9.]+))"
)

for idx, seed in enumerate(SEEDS, start=1):
    print('=' * 80)
    print(f"TRAINING WITH SEED: {seed} ({idx}/{len(SEEDS)})")
    print('=' * 80)
    print()

    seed_output_dir = OUTPUT_BASE / f'seed_{seed}'
    seed_output_dir.mkdir(parents=True, exist_ok=True)

    cmd = [
        sys.executable,
        'training/train_transformer.py',
        f'--train-data={TRAIN_DATA}',
        f'--val-data={VAL_DATA}',
        f'--output-dir={seed_output_dir.as_posix()}',
        f'--seed={seed}'
    ]
    cmd.extend(preset.cli_flags())
    if EXTRA_FLAGS:
        cmd.extend(EXTRA_FLAGS)
    if ENABLE_LR_FINDER:
        cmd.append('--find-lr')
        if FORCE_LR_FINDER:
            cmd.append('--force-find-lr')

    train_log = LOG_DIR / f'seed_{seed}.log'
    try:
        output_text = stream_subprocess(f'Transformer training (preset {PRESET_NAME}, seed {seed})', cmd, train_log)

        metrics_path = seed_output_dir / 'metrics.json'
        best_f1 = None
        if metrics_path.exists():
            try:
                with metrics_path.open('r', encoding='utf-8') as metrics_file:
                    metrics_payload = json.load(metrics_file)
                if 'best_f1_vulnerable' in metrics_payload:
                    best_f1 = float(metrics_payload['best_f1_vulnerable'])
            except Exception as metrics_err:
                print(f"[warn] Failed to read {metrics_path}: {metrics_err}")

        if best_f1 is None:
            matches = f1_pattern.findall(output_text)
            if matches:
                flattened = [value for pair in matches for value in pair if value]
                if flattened:
                    best_f1 = float(flattened[-1])

        if best_f1 is None:
            best_f1 = 0.0

        results.append({'seed': seed, 'best_f1': best_f1})
        print(f"[ok] Seed {seed} complete. Best F1: {best_f1:.4f}")
    except subprocess.CalledProcessError as exc:
        tail = '\n'.join(exc.output.strip().splitlines()[-25:]) if exc.output else 'No stderr captured.'
        print(f"[error] Training failed for seed {seed} (exit {exc.returncode}).")
        print(f'        See log: {train_log}')
        print('        Last log lines:')
        print(tail)
        results.append({'seed': seed, 'error': f'exit {exc.returncode}'})

print('-' * 80)
print('PRESET RUN SUMMARY')
print('-' * 80)

completed = [r for r in results if 'error' not in r]
if completed:
    f1_scores = [r['best_f1'] for r in completed]
    avg = float(mean(f1_scores))
    std = float(pstdev(f1_scores)) if len(f1_scores) > 1 else 0.0
    for r in completed:
        print(f"  Seed {r['seed']:4d}: F1 = {r['best_f1']:.4f}")
    print()
    print(f"  Mean F1: {avg:.4f} +/- {std:.4f}")
else:
    print('[!] All preset runs failed. Check logs for details.')

summary = {
    'preset': PRESET_NAME,
    'description': preset.description,
    'seeds': SEEDS,
    'results': results,
    'timestamp': datetime.now().isoformat()
}
if completed:
    summary['mean_f1'] = avg
    summary['std_f1'] = std

summary_path = OUTPUT_BASE / 'preset_summary.json'
with summary_path.open('w', encoding='utf-8') as f:
    json.dump(summary, f, indent=2)
print()
print(f"[ok] Preset summary saved to {summary_path}")


In [None]:

# Cell 15: Transformer v1.7 Production Training
# Multi-seed training with stabilized hyperparameters (LR Finder optional)

import os
import subprocess
import sys
import json
from pathlib import Path
import numpy as np
from datetime import datetime
import re

SEEDS = [42, 2025, 7]
TRAIN_DATA = 'data/processed/codexglue/train.jsonl'
VAL_DATA = 'data/processed/codexglue/valid.jsonl'

EPOCHS = 10
BATCH_SIZE = 32
ACCUMULATION_STEPS = 2
BASE_LR = 2.0e-5
WEIGHT_MULTIPLIER = 1.6
DROPOUT = 0.20
WARMUP_RATIO = 0.10
MAX_SEQ_LEN = 448
EARLY_STOPPING = 3
USE_WEIGHTED_SAMPLER = True
USE_CODE_FEATURES = False
ENABLE_LR_FINDER = False
FORCE_LR_FINDER = False
LR_FINDER_ITERATIONS = 100

if 'locate_project_root' not in globals():
    def locate_project_root():
        candidates = [
            Path('/content/streamguard'),
            Path.cwd(),
            Path.cwd() / 'streamguard'
        ]
        for candidate in candidates:
            candidate = candidate.resolve()
            if (candidate / 'training' / 'train_transformer.py').exists():
                return candidate
        raise FileNotFoundError('Could not locate StreamGuard project root containing training/train_transformer.py')

if 'PROJECT_ROOT' in globals():
    PROJECT_ROOT = Path(PROJECT_ROOT).resolve()
else:
    PROJECT_ROOT = locate_project_root()

if Path.cwd().resolve() != PROJECT_ROOT:
    os.chdir(PROJECT_ROOT)
    print(f"[setup] Working directory set to {PROJECT_ROOT}")

OUTPUT_DIR = PROJECT_ROOT / 'training' / 'outputs' / 'transformer_v17'
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
LOG_DIR = OUTPUT_DIR / '_logs'
LOG_DIR.mkdir(parents=True, exist_ok=True)

TRAIN_DATA_PATH = PROJECT_ROOT / TRAIN_DATA
VAL_DATA_PATH = PROJECT_ROOT / VAL_DATA
for label, file_path in [('train', TRAIN_DATA_PATH), ('validation', VAL_DATA_PATH)]:
    if not file_path.exists():
        raise FileNotFoundError(f"Missing {label} data file: {file_path}")

if 'stream_subprocess' not in globals():
    def stream_subprocess(name, cmd, log_path):
        print('=' * 80)
        print(f"[run] {name}")
        print('=' * 80)
        print(' '.join(map(str, cmd)))
        print('-' * 80)
        log_path.parent.mkdir(parents=True, exist_ok=True)
        captured_lines = []
        with log_path.open('w', encoding='utf-8') as log_file:
            process = subprocess.Popen(
                cmd,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                text=True,
                cwd=PROJECT_ROOT
            )
            for line in process.stdout:
                print(line, end='')
                log_file.write(line)
                log_file.flush()
                captured_lines.append(line)
            return_code = process.wait()
        print('-' * 80)
        print(f"[log] Output saved to {log_path}")
        print()
        output_text = ''.join(captured_lines)
        if return_code != 0:
            raise subprocess.CalledProcessError(return_code, cmd, output=output_text)
        return output_text

print('=' * 80)
print('TRANSFORMER v1.7 PRODUCTION TRAINING')
print('=' * 80)
print(f"Seeds: {SEEDS}")
print(f"Output: {OUTPUT_DIR}")
print(f"Config: epochs={EPOCHS}, batch={BATCH_SIZE}, accumulation={ACCUMULATION_STEPS}, max_seq_len={MAX_SEQ_LEN}")
print(f"LR finder: {'ENABLED' if ENABLE_LR_FINDER else 'DISABLED'}")
print('=' * 80)
print()

if ENABLE_LR_FINDER:
    print('[1/2] Running LR Finder (once for all seeds)...')
    print()

    lr_finder_cmd = [
        sys.executable,
        'training/train_transformer.py',
        f'--train-data={TRAIN_DATA}',
        f'--val-data={VAL_DATA}',
        '--output-dir=training/outputs/.lr_finder_temp',
        '--quick-test',
        '--find-lr',
        f'--lr-finder-iterations={LR_FINDER_ITERATIONS}',
        '--epochs=5',
        '--batch-size=16',
        '--seed=42'
    ]
    if FORCE_LR_FINDER:
        lr_finder_cmd.append('--force-find-lr')

    lr_finder_log = LOG_DIR / 'lr_finder.log'
    try:
        stream_subprocess('LR Finder (quick run)', lr_finder_cmd, lr_finder_log)
        print('[ok] LR Finder complete. Cache ready for all seeds.')
        print()
    except subprocess.CalledProcessError as exc:
        print(f"[warn] LR Finder failed with exit code {exc.returncode}. See {lr_finder_log} for details.")
        print('       Continuing with BASE_LR fallback.')
        print()

    training_step_label = '[2/2]'
else:
    print('[1/1] LR Finder disabled — starting training with preset hyperparameters.')
    print()
    training_step_label = '[1/1]'

print(f"{training_step_label} Training all seeds with stabilized config...")
print()

results_all_seeds = []
f1_pattern = re.compile(r'Best F1: ([0-9.]+)')

for seed_idx, seed in enumerate(SEEDS, start=1):
    print('=' * 80)
    print(f"TRAINING WITH SEED: {seed} ({seed_idx}/{len(SEEDS)})")
    print('=' * 80)
    print()

    seed_output_dir = OUTPUT_DIR / f'seed_{seed}'
    seed_output_dir.mkdir(parents=True, exist_ok=True)

    cmd = [
        sys.executable,
        'training/train_transformer.py',
        f'--train-data={TRAIN_DATA}',
        f'--val-data={VAL_DATA}',
        f'--output-dir={seed_output_dir.as_posix()}',
        f'--seed={seed}',
        f'--epochs={EPOCHS}',
        f'--batch-size={BATCH_SIZE}',
        f'--accumulation-steps={ACCUMULATION_STEPS}',
        f'--lr={BASE_LR}',
        f'--weight-multiplier={WEIGHT_MULTIPLIER}',
        f'--dropout={DROPOUT}',
        f'--warmup-ratio={WARMUP_RATIO}',
        f'--max-seq-len={MAX_SEQ_LEN}',
        f'--early-stopping-patience={EARLY_STOPPING}'
    ]

    cmd.append('--mixed-precision')
    if USE_WEIGHTED_SAMPLER:
        cmd.append('--use-weighted-sampler')
    if USE_CODE_FEATURES:
        cmd.append('--use-code-features')
    if ENABLE_LR_FINDER:
        cmd.append('--find-lr')
        cmd.append(f'--lr-finder-iterations={LR_FINDER_ITERATIONS}')
        if FORCE_LR_FINDER:
            cmd.append('--force-find-lr')

    train_log = LOG_DIR / f'seed_{seed}.log'
    try:
        output_text = stream_subprocess(f'Transformer training (seed {seed})', cmd, train_log)
        matches = f1_pattern.findall(output_text)
        best_f1 = float(matches[-1]) if matches else 0.0
        results_all_seeds.append({'seed': seed, 'best_f1': best_f1})
        print(f"[ok] Seed {seed} complete. Best F1: {best_f1:.4f}")
    except subprocess.CalledProcessError as exc:
        tail = '\n'.join(exc.output.strip().splitlines()[-25:]) if exc.output else 'No stderr captured.'
        print(f"[error] Training failed for seed {seed} (exit {exc.returncode}).")
        print(f'        See log: {train_log}')
        print('        Last log lines:')
        print(tail)
        results_all_seeds.append({'seed': seed, 'error': f'exit {exc.returncode}'})

print('=' * 80)
print('TRANSFORMER TRAINING COMPLETE')
print('=' * 80)
print()

valid_results = [r for r in results_all_seeds if 'error' not in r]
if valid_results:
    f1_scores = [r['best_f1'] for r in valid_results]
    mean_f1 = float(np.mean(f1_scores))
    std_f1 = float(np.std(f1_scores))

    print(f"Results across {len(valid_results)} completed seeds:")
    for r in valid_results:
        print(f"  Seed {r['seed']:4d}: F1 = {r['best_f1']:.4f}")
    print()
    print(f"Mean F1: {mean_f1:.4f} +/- {std_f1:.4f}")
    print()

    summary = {
        'model': 'transformer_v17',
        'timestamp': datetime.now().isoformat(),
        'seeds': SEEDS,
        'results': results_all_seeds,
        'mean_f1': mean_f1,
        'std_f1': std_f1,
        'config': {
            'epochs': EPOCHS,
            'batch_size': BATCH_SIZE,
            'accumulation_steps': ACCUMULATION_STEPS,
            'base_lr': BASE_LR,
            'weight_multiplier': WEIGHT_MULTIPLIER,
            'dropout': DROPOUT,
            'warmup_ratio': WARMUP_RATIO,
            'max_seq_len': MAX_SEQ_LEN,
            'use_weighted_sampler': USE_WEIGHTED_SAMPLER,
            'use_code_features': USE_CODE_FEATURES,
            'enable_lr_finder': ENABLE_LR_FINDER
        }
    }

    summary_path = OUTPUT_DIR / 'production_summary.json'
    with summary_path.open('w', encoding='utf-8') as f:
        json.dump(summary, f, indent=2)
    print(f"[ok] Summary saved to {summary_path}")
else:
    print(f"[error] All seeds failed. Review logs under {LOG_DIR}")


---

### Step 2: GNN v1.7 Production Training

**Graph-based vulnerability detection with multi-seed training**

**Duration:** ~45-70 minutes total
- Graph preprocessing: ~5-10 min (if needed)
- LR Finder: 2-3 min (runs once)
- Seed 42: ~13-20 min
- Seed 2025: ~13-20 min
- Seed 7: ~13-20 min

**Key Features:**
- Graph Neural Network architecture
- Focal loss for hard negatives
- Weighted sampling + focal loss combination
- 15 epochs per seed (vs 10 for Transformer)

---


In [None]:
# Cell 16: Graph Data Preprocessing (Conditional)
# Check if graph data exists, if not run preprocessing

import subprocess
import sys
from pathlib import Path

print('='*80)
print('GRAPH DATA PREPROCESSING CHECK')
print('='*80 + '\n')

graph_train = Path('data/processed/graphs/train')
graph_val = Path('data/processed/graphs/val')

if graph_train.exists() and graph_val.exists():
    print('✅ Graph data already exists. Skipping preprocessing.\n')
    print(f'   Train: {graph_train}')
    print(f'   Val:   {graph_val}')
else:
    print('⚠️  Graph data not found. Running preprocessing...\n')
    
    # Create graph data for training set
    print('[1/2] Creating graph data for training set...')
    cmd_train = [
        sys.executable,
        'training/preprocessing/create_simple_graph_data.py',
        '--input', 'data/processed/codexglue/train.jsonl',
        '--output', 'data/processed/graphs/train'
    ]
    
    try:
        subprocess.run(cmd_train, check=True)
        print('✅ Training graph data created\n')
    except subprocess.CalledProcessError as e:
        print(f'❌ Failed to create training graph data: {e}\n')
        raise
    
    # Create graph data for validation set
    print('[2/2] Creating graph data for validation set...')
    cmd_val = [
        sys.executable,
        'training/preprocessing/create_simple_graph_data.py',
        '--input', 'data/processed/codexglue/valid.jsonl',
        '--output', 'data/processed/graphs/val'
    ]
    
    try:
        subprocess.run(cmd_val, check=True)
        print('✅ Validation graph data created\n')
    except subprocess.CalledProcessError as e:
        print(f'❌ Failed to create validation graph data: {e}\n')
        raise
    
    print('\n✅ Graph preprocessing complete!')

print('='*80)
print('READY FOR GNN TRAINING')
print('='*80)


In [None]:
# Cell 17: GNN v1.7 Production Training
# Multi-seed GNN training with optimized LR Finder

import subprocess
import sys
import json
from pathlib import Path
import numpy as np
from datetime import datetime
import re

# Configuration
SEEDS = [42, 2025, 7]
OUTPUT_DIR = Path('training/outputs/gnn_v17')
TRAIN_DATA = 'data/processed/graphs/train'
VAL_DATA = 'data/processed/graphs/val'

# Verify graph data exists
if not Path(TRAIN_DATA).exists():
    print(f'❌ Graph training data not found: {TRAIN_DATA}')
    print('Run Cell 16 first to create graph data.')
    raise FileNotFoundError(f'Graph data not found: {TRAIN_DATA}')

print('='*80)
print('GNN v1.7 PRODUCTION TRAINING')
print('='*80)
print(f'Seeds: {SEEDS}')
print(f'Output: {OUTPUT_DIR}')
print('='*80 + '\n')

OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# ============================================================
# STEP 1: Run LR Finder ONCE (quick mode, ~2-3 min)
# ============================================================
print('[1/2] Running LR Finder (once for all seeds)...\n')

lr_finder_cmd = [
    sys.executable,
    'training/train_gnn.py',
    f'--train-data={TRAIN_DATA}',
    f'--val-data={VAL_DATA}',
    '--output-dir=training/outputs/.lr_finder_temp_gnn',
    '--quick-test',
    '--find-lr',
    '--force-find-lr',
    '--epochs=5',
    '--batch-size=16',
    '--seed=42'
]

try:
    print('⏳ Running LR Finder on 64 samples...')
    result = subprocess.run(lr_finder_cmd, check=True, capture_output=True, text=True)
    print('✅ LR Finder complete! LR cached for all seeds.\n')
except subprocess.CalledProcessError as e:
    print(f'⚠️  LR Finder failed: {e}')
    print('Continuing with default LR (1e-3)...\n')

# ============================================================
# STEP 2: Train all seeds (uses cached LR)
# ============================================================
print('[2/2] Training all seeds with cached LR...\n')

results_all_seeds = []

for seed_idx, seed in enumerate(SEEDS):
    print(f'\n{"="*80}')
    print(f'TRAINING WITH SEED: {seed} ({seed_idx + 1}/{len(SEEDS)})')
    print(f'{"="*80}\n')
    
    seed_output_dir = OUTPUT_DIR / f'seed_{seed}'
    seed_output_dir.mkdir(parents=True, exist_ok=True)
    
    cmd = [
        sys.executable,
        'training/train_gnn.py',
        f'--train-data={TRAIN_DATA}',
        f'--val-data={VAL_DATA}',
        f'--output-dir={seed_output_dir}',
        f'--seed={seed}',
        '--epochs=15',
        '--batch-size=64',
        '--mixed-precision',
        '--find-lr',  # Uses cache (instant)
        '--use-weighted-sampler',
        '--focal-loss'
    ]
    
    try:
        result = subprocess.run(cmd, check=True, capture_output=True, text=True)
        
        # Extract F1 from logs
        best_f1 = 0.0
        f1_pattern = r'Best F1: ([0-9.]+)'
        matches = re.findall(f1_pattern, result.stdout)
        if matches:
            best_f1 = float(matches[-1])
        
        results_all_seeds.append({'seed': seed, 'best_f1': best_f1})
        print(f'\n✅ Seed {seed} complete. Best F1: {best_f1:.4f}')
        
    except subprocess.CalledProcessError as e:
        print(f'\n❌ Training failed for seed {seed}: {e}')
        results_all_seeds.append({'seed': seed, 'error': str(e)})

# ============================================================
# Aggregate
# ============================================================
print(f'\n{"="*80}')
print('GNN TRAINING COMPLETE')
print(f'{"="*80}\n')

valid_results = [r for r in results_all_seeds if 'error' not in r]
if valid_results:
    f1_scores = [r['best_f1'] for r in valid_results]
    mean_f1 = np.mean(f1_scores)
    std_f1 = np.std(f1_scores)
    
    print(f'📊 Results across {len(valid_results)} seeds:')
    for r in valid_results:
        print(f"  Seed {r['seed']:4d}: F1 = {r['best_f1']:.4f}")
    print(f'\n📈 Mean F1: {mean_f1:.4f} ± {std_f1:.4f}')
    
    summary = {
        'model': 'gnn_v17',
        'timestamp': datetime.now().isoformat(),
        'seeds': SEEDS,
        'results': results_all_seeds,
        'mean_f1': float(mean_f1),
        'std_f1': float(std_f1)
    }
    
    summary_path = OUTPUT_DIR / 'production_summary.json'
    with open(summary_path, 'w') as f:
        json.dump(summary, f, indent=2)
    print(f'\n💾 Summary saved: {summary_path}')
else:
    print('❌ All seeds failed. Check errors above.')


---

## 🎉 Production Training Complete!

### Results Summary:

**Transformer v1.7:**
- Results: `training/outputs/transformer_v17/production_summary.json`
- Checkpoints: `training/outputs/transformer_v17/seed_*/best_model.pt`

**GNN v1.7:**
- Results: `training/outputs/gnn_v17/production_summary.json`
- Checkpoints: `training/outputs/gnn_v17/seed_*/best_model.pt`

### Next Steps:

1. **View Results:** Run Cell 20 to see aggregated metrics
2. **Fusion Training (Optional):** See Cell 23 for instructions
3. **Model Deployment:** Use best checkpoints for inference

---


---

## Part 3: Optional Advanced Features

The following cells are optional:
- **Cell 20:** View detailed training results
- **Cell 21:** Run LR Finder safety validation test
- **Cell 22-23:** Fusion training instructions (requires Transformer + GNN checkpoints)

---


In [None]:
# Cell 20: View Detailed Training Results
# Load and display production summaries

import json
from pathlib import Path

print('='*80)
print('PRODUCTION TRAINING RESULTS')
print('='*80 + '\n')

# Transformer results
transformer_summary = Path('training/outputs/transformer_v17/production_summary.json')
if transformer_summary.exists():
    with open(transformer_summary, 'r') as f:
        t_data = json.load(f)
    
    print('📊 TRANSFORMER v1.7')
    print('-'*80)
    print(f"Seeds trained: {t_data['seeds']}")
    print(f"Mean F1: {t_data['mean_f1']:.4f} ± {t_data['std_f1']:.4f}")
    print('\nPer-seed results:')
    for r in t_data['results']:
        if 'error' not in r:
            print(f"  Seed {r['seed']:4d}: F1 = {r['best_f1']:.4f}")
        else:
            print(f"  Seed {r['seed']:4d}: FAILED")
else:
    print('⚠️  Transformer results not found. Run Cell 14 first.')

print('\n' + '='*80 + '\n')

# GNN results
gnn_summary = Path('training/outputs/gnn_v17/production_summary.json')
if gnn_summary.exists():
    with open(gnn_summary, 'r') as f:
        g_data = json.load(f)
    
    print('📊 GNN v1.7')
    print('-'*80)
    print(f"Seeds trained: {g_data['seeds']}")
    print(f"Mean F1: {g_data['mean_f1']:.4f} ± {g_data['std_f1']:.4f}")
    print('\nPer-seed results:')
    for r in g_data['results']:
        if 'error' not in r:
            print(f"  Seed {r['seed']:4d}: F1 = {r['best_f1']:.4f}")
        else:
            print(f"  Seed {r['seed']:4d}: FAILED")
else:
    print('⚠️  GNN results not found. Run Cell 17 first.')

print('\n' + '='*80)


In [None]:
# Cell 21: LR Finder Safety Validation Test (Optional)
# Quick test of LR Finder with safety validation on small subset (2-3 min)

import os
os.chdir('/content/streamguard')

print('='*70)
print('LR FINDER SAFETY VALIDATION TEST')
print('='*70)
print('Testing LR Finder with safety validation on 64 samples')
print('Duration: ~2-3 minutes')
print('='*70)

!python training/train_transformer.py \
  --train-data data/processed/codexglue/train.jsonl \
  --val-data data/processed/codexglue/valid.jsonl \
  --quick-test \
  --find-lr \
  --epochs 5 \
  --batch-size 16 \
  --seed 42

print('\n' + '='*70)
print('✅ LR Finder test complete!')
print('='*70)
print('\n📋 Check the output above for:')
print('  • LR Finder curve analysis (confidence: high/medium/low)')
print('  • Safety validation (cap applied? fallback used?)')
print('  • Suggested LR and final used LR')
print('  • Cache saved for future runs')
print('='*70)


---

### Optional: Fusion Training

**Fusion training combines Transformer + GNN models**

**Prerequisites:**
- Completed Transformer training (Cell 14)
- Completed GNN training (Cell 17)

**Note:** Fusion training is currently a manual process. Cell 23 will check prerequisites and provide the command to run.

---


In [None]:
# Cell 23: Fusion Training Instructions (Optional)
# Check prerequisites and provide manual fusion training command

from pathlib import Path

print('='*80)
print('FUSION v1.7 TRAINING PREREQUISITES')
print('='*80 + '\n')

transformer_checkpoint = Path('training/outputs/transformer_v17/seed_42/best_model.pt')
gnn_checkpoint = Path('training/outputs/gnn_v17/seed_42/best_model.pt')

transformer_ok = transformer_checkpoint.exists()
gnn_ok = gnn_checkpoint.exists()

print(f'{'✅' if transformer_ok else '❌'} Transformer checkpoint: {transformer_checkpoint}')
print(f'{'✅' if gnn_ok else '❌'} GNN checkpoint: {gnn_checkpoint}')

if not transformer_ok:
    print('\n⚠️  Transformer checkpoint not found. Run Cell 14 first.')
elif not gnn_ok:
    print('\n⚠️  GNN checkpoint not found. Run Cell 17 first.')
else:
    print('\n✅ Prerequisites met! Ready for fusion training.')
    print('\n' + '='*80)
    print('FUSION TRAINING COMMAND')
    print('='*80 + '\n')
    print('Run the following command to train fusion model:')
    print('\n```bash')
    print('!python training/train_fusion.py \\')
    print('  --train-data data/processed/codexglue/train.jsonl \\')
    print('  --val-data data/processed/codexglue/valid.jsonl \\')
    print(f'  --transformer-checkpoint {transformer_checkpoint} \\')
    print(f'  --gnn-checkpoint {gnn_checkpoint} \\')
    print('  --output-dir training/outputs/fusion_v17/seed_42 \\')
    print('  --epochs 12 \\')
    print('  --batch-size 32')
    print('```')
    print('\n' + '='*80)
    print('Note: Fusion training is optional. Transformer and GNN models')
    print('can be used independently for inference.')
    print('='*80)
