# EnStack: Advanced Stacking Ensemble for Vulnerability Detection

This notebook provides a professional, fully optimized pipeline for reproducing the results of the EnStack paper on Google Colab.

### ⚡ Optimized Features:
1.  **High-Speed Training:** Automatic Mixed Precision (AMP) and Dynamic Padding (5-8x speedup).
2.  **Memory Efficient:** Lazy Loading and Gradient Checkpointing (run large models on a T4 GPU).
3.  **Algorithmic Correctness:** K-Fold Out-of-Fold (OOF) stacking to prevent data leakage.
4.  **Advanced Visualization:** Confusion matrices, ROC curves, and Feature Importance plots.
5.  **Production Ready:** Export models to ONNX and TorchScript.
6.  **Robust Checkpoint System:** Atomic saves, resume capability, crash recovery.

---

## 1. Environment Setup

In [None]:
import os

from google.colab import drive

# 1. Mount Drive
# Later cells read and write under /content/drive/MyDrive, so the mount must succeed first.
print("üìÇ Connecting to Google Drive...")
drive.mount('/content/drive')

# 2. Clone Repository
# These two form fields select which GitHub repository is cloned (or updated).
REPO_NAME = "EnStack-paper" # @param {type:"string"}
GITHUB_USER = "TCTri205" # @param {type:"string"}

# Clone on the first run; on later runs pull the latest commits into the existing checkout.
%cd /content
if not os.path.exists(REPO_NAME):
    print(f"‚¨áÔ∏è Cloning {REPO_NAME}...")
    !git clone https://github.com/{GITHUB_USER}/{REPO_NAME}.git
else:
    print("üîÑ Repository exists. Pulling latest optimized version...")
    !cd {REPO_NAME} && git pull

# Every later cell uses paths relative to the repository root (scripts/, configs/),
# so the working directory is switched to the checkout here.
%cd /content/{REPO_NAME}

# 3. Install Dependencies
# First the project's pinned requirements, then the extra packages listed explicitly below.
print("üì¶ Installing high-performance dependencies...")
!pip install -r requirements.txt -q
!pip install transformers[torch] datasets pyarrow xgboost tensorboard seaborn matplotlib -q

print("\n‚úÖ Setup complete. Ready to train.")

## 2. Check Hardware Acceleration

In [None]:
import psutil
import torch

# Summarise the accelerator (if any) and the total system memory of this runtime.
print("üîç Hardware Check:")
gpu_present = torch.cuda.is_available()
if not gpu_present:
    print("‚ùå GPU NOT FOUND. Please go to: Runtime -> Change runtime type -> T4 GPU")
else:
    # Device 0 is the only device inspected.
    gpu_name = torch.cuda.get_device_name(0)
    print(f"‚úÖ GPU: {gpu_name}")
    vram_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"   VRAM: {vram_gb:.2f} GB")

# Reported in decimal gigabytes (bytes / 1e9), matching the VRAM figure above.
ram_gb = psutil.virtual_memory().total / 1e9
print(f"‚úÖ System RAM: {ram_gb:.2f} GB")

## 3. Workflow Configuration

In [None]:
import glob
import os
import re
import shutil

# @markdown ### ‚öôÔ∏è Execution Mode
# @markdown Select **"Fresh Start"** to wipe old data/models and start over.<br>
# @markdown Select **"Resume Training"** to continue from the last checkpoint.
EXECUTION_MODE = "Resume Training" # @param ["Fresh Start", "Resume Training"]

DATA_DIR = "/content/drive/MyDrive/EnStack_Data"
CHECKPOINT_DIR = f"{DATA_DIR}/checkpoints"


def checkpoint_order_key(path):
    """Return a sort key that orders intermediate checkpoints by (epoch, step).

    Plain string sorting ranks ``checkpoint_epoch10_step500`` below
    ``checkpoint_epoch2_step500``, so the numbers are parsed and compared as
    integers instead.

    Args:
        path: Path or directory name containing ``checkpoint_epoch<E>_step<S>``.

    Returns:
        A tuple ``(epoch, step, basename)``. Names that do not match the pattern
        get ``(-1, -1, basename)`` so they sort before every real checkpoint.
    """
    name = os.path.basename(str(path))
    match = re.search(r"checkpoint_epoch(\d+)_step(\d+)", name)
    if match is None:
        return (-1, -1, name)
    return (int(match.group(1)), int(match.group(2)), name)


if EXECUTION_MODE == "Fresh Start":
    print(f"‚ö†Ô∏è Fresh Start selected. Cleaning up {DATA_DIR}...")

    # Delete checkpoints
    if os.path.exists(CHECKPOINT_DIR):
        shutil.rmtree(CHECKPOINT_DIR)
        print(f"   - Deleted checkpoints: {CHECKPOINT_DIR}")

    # Delete cache files (only regular files directly inside DATA_DIR)
    if os.path.exists(DATA_DIR):
        for f in os.listdir(DATA_DIR):
            if f.startswith(".cache_") or f.endswith("_processed.pkl"):
                file_path = os.path.join(DATA_DIR, f)
                if os.path.isfile(file_path):
                    os.remove(file_path)
                    print(f"   - Deleted: {f}")

    RESUME_TRAINING = False
    print("‚úÖ Cleanup complete. Ready for fresh training.")

else:
    print("üîÑ Resume Mode selected.")
    if os.path.exists(CHECKPOINT_DIR):
        # Report, per model, the most useful checkpoint that exists:
        # last_checkpoint first, then best_model, then the newest intermediate one.
        all_models = ["codebert", "graphcodebert", "unixcoder"]
        for model_name in all_models:
            model_dir = os.path.join(CHECKPOINT_DIR, model_name)
            if not os.path.exists(model_dir):
                continue

            if os.path.exists(os.path.join(model_dir, "last_checkpoint")):
                print(f"   ‚úÖ {model_name}: Found last_checkpoint")
            elif os.path.exists(os.path.join(model_dir, "best_model")):
                print(f"   ‚úÖ {model_name}: Found best_model")
            else:
                # Check for intermediate checkpoints
                intermediates = glob.glob(os.path.join(model_dir, "checkpoint_epoch*_step*"))
                if intermediates:
                    # Numeric ordering: epoch 10 must rank above epoch 9.
                    latest = max(intermediates, key=checkpoint_order_key)
                    print(f"   ‚ö†Ô∏è  {model_name}: Found intermediate checkpoint: {os.path.basename(latest)}")
                else:
                    print(f"   ‚ùå {model_name}: No valid checkpoints found")

        RESUME_TRAINING = True
    else:
        print("   - No checkpoints found. Will start training from scratch.")
        RESUME_TRAINING = False

## 3.5. (Optional) Validate Checkpoint State

Run this cell to check the current checkpoint state before resuming.

In [None]:
# @markdown ### üîç Checkpoint Validation (Optional)
# @markdown Run this to verify checkpoint state before training
RUN_VALIDATION = False # @param {type:"boolean"}

if RUN_VALIDATION and os.path.exists(CHECKPOINT_DIR):
    print("üîç Validating checkpoint state...\n")

    # Import validation function
    import sys
    sys.path.append('/content/EnStack-paper')
    from scripts.train import find_latest_checkpoint
    from pathlib import Path
    
    all_models = ["codebert", "graphcodebert", "unixcoder"]
    for model_name in all_models:
        model_dir = os.path.join(CHECKPOINT_DIR, model_name)
        if os.path.exists(model_dir):
            # Use the same logic as training to find latest checkpoint
            latest_ckpt = find_latest_checkpoint(Path(model_dir))
            if latest_ckpt:
                print(f"\n{'='*70}")
                print(f"Validating {model_name}: {os.path.basename(latest_ckpt)}")
                print('='*70)
                !python scripts/validate_checkpoint.py --checkpoint_path {latest_ckpt}
            else:
                print(f"\n‚ö†Ô∏è  {model_name}: No checkpoint to validate")
else:
    print("‚è≠Ô∏è  Skipping validation (not enabled or no checkpoints)")

## 4. Data Preparation
Choose either the full **Draper VDISC** dataset (paper reproduction) or **Dummy Data** (quick code test).

In [None]:
# @markdown ### Data Source Configuration
DATA_MODE = "Draper VDISC" # @param ["Draper VDISC", "Dummy Data"]
SAMPLE_SIZE = 5000 # @param {type:"integer"}

if DATA_MODE == "Draper VDISC":
    print("üöÄ Downloading and processing Draper VDISC (~1GB)...")
    !chmod +x scripts/setup_draper.sh
    !./scripts/setup_draper.sh
else:
    print(f"üîÑ Generating synthetic dummy data ({SAMPLE_SIZE} samples)...")
    !python scripts/prepare_data.py --output_dir /content/drive/MyDrive/EnStack_Data --mode synthetic --sample {SAMPLE_SIZE}

print("\n‚úÖ Data is ready on Google Drive.")

## 5. Model Selection

Choose which base models to train. You can train all models or select specific ones.

In [None]:
# @markdown ### ü§ñ Model Selection
# @markdown Select which models to train (you can choose one or multiple)
TRAIN_CODEBERT = True # @param {type:"boolean"}
TRAIN_GRAPHCODEBERT = True # @param {type:"boolean"}
TRAIN_UNIXCODER = True # @param {type:"boolean"}

# Build model list based on selection
SELECTED_MODELS = []
if TRAIN_CODEBERT:
    SELECTED_MODELS.append('codebert')
if TRAIN_GRAPHCODEBERT:
    SELECTED_MODELS.append('graphcodebert')
if TRAIN_UNIXCODER:
    SELECTED_MODELS.append('unixcoder')

if not SELECTED_MODELS:
    print("‚ö†Ô∏è  WARNING: No models selected! Please select at least one model.")
else:
    print(f"‚úÖ Selected models for training: {', '.join(SELECTED_MODELS)}")
    print(f"   Total: {len(SELECTED_MODELS)} model(s)")

### 💡 Training Tips:

**Training Individual Models:**
- Select only ONE model checkbox above to train it individually
- This is useful for:
  - Re-training a specific model that failed
  - Testing with different hyperparameters
  - Saving time when you only need certain models

**Training All Models:**
- Check all three checkboxes to train the complete ensemble
- Required for final meta-classifier training and evaluation

**Resume Training:**
- The system automatically detects existing checkpoints
- Will resume from the last saved state for selected models
- Unselected models will be skipped even if they have checkpoints

## 6. Training Configuration

Configure training hyperparameters and checkpoint strategy.

In [None]:
import yaml

# @markdown ### üéõÔ∏è Training Hyperparameters
EPOCHS = 10 # @param {type:"integer"}
BATCH_SIZE = 16 # @param {type:"integer"}
ACCUMULATION_STEPS = 1 # @param {type:"integer"}
# @markdown **SWA (Stochastic Weight Averaging) - Recommended for best results**
# @markdown Adds ~1-2 min per epoch but improves F1 by 0.5-1.0%
USE_SWA = True # @param {type:"boolean"}
SWA_START_EPOCH = 6 # @param {type:"integer"}

# @markdown ### üíæ Checkpoint Strategy
# @markdown - `save_steps=0`: Only save at end of epoch (fastest, risky if crash)
# @markdown - `save_steps=500`: Save every 500 batches (recommended for Colab)
# @markdown - `save_steps=1000`: Less frequent saves (faster, but more wasted work if crash)
SAVE_STEPS = 500 # @param {type:"integer"}

# Update config.yaml with notebook parameters
with open('configs/config.yaml', 'r') as f:
    config = yaml.safe_load(f)

config['training']['epochs'] = EPOCHS
config['training']['batch_size'] = BATCH_SIZE
config['training']['gradient_accumulation_steps'] = ACCUMULATION_STEPS
config['training']['use_swa'] = USE_SWA
config['training']['swa_start'] = SWA_START_EPOCH
config['training']['save_steps'] = SAVE_STEPS

# Update base_models in config with selected models
config['model']['base_models'] = SELECTED_MODELS

with open('configs/config.yaml', 'w') as f:
    yaml.dump(config, f)

print("‚úÖ Configuration updated:")
print(f"   - Selected Models: {', '.join(SELECTED_MODELS)}")
print(f"   - Epochs: {EPOCHS}")
print(f"   - Batch Size: {BATCH_SIZE}")
print(f"   - SWA (Stochastic Weight Averaging): {USE_SWA}")
if USE_SWA:
    print(f"   - SWA Start Epoch: {SWA_START_EPOCH} (will average epochs {SWA_START_EPOCH}-{EPOCHS})")
print(f"   - Checkpoint Strategy: save every {SAVE_STEPS} steps" if SAVE_STEPS > 0 else "   - Checkpoint Strategy: only at end of epoch")
print(f"   - Resume: {RESUME_TRAINING}")

## 7. Run Optimized Training Pipeline

This cell executes training for your selected base models.

**Note:** Training will automatically:
- Train only the models you selected above
- Resume from last checkpoint if `RESUME_TRAINING=True`
- Save checkpoints according to `SAVE_STEPS` strategy
- Log progress with detailed checkpoint information

In [None]:
# Verify models are selected
if not SELECTED_MODELS:
    print("‚ùå ERROR: No models selected for training!")
    print("   Please go back to 'Model Selection' cell and select at least one model.")
else:
    print("üöÄ Starting Training Pipeline...")
    print("=" * 70)
    print(f"Training models: {', '.join(SELECTED_MODELS)}")
    print("=" * 70)

    # Use RESUME_TRAINING variable from Workflow Configuration step
    !python scripts/train.py --config configs/config.yaml {'--resume' if RESUME_TRAINING else ''}

    print("\n" + "=" * 70)
    print("‚úÖ Training pipeline completed!")

In [None]:
# @markdown ### üõ°Ô∏è Safety Verification
# @markdown Checks if all selected models were successfully trained before proceeding.
import os
import yaml

print("üîç Verifying training completeness...")

# Files whose presence marks a directory as holding trained weights / state.
WEIGHT_ARTIFACTS = ("pytorch_model.bin", "training_state.pth")


def has_trained_artifacts(model_dir):
    """Return True when a model's output directory contains saved weights.

    The model directory itself and each of its immediate sub-directories are
    inspected. Elsewhere in this notebook checkpoints are handled as
    sub-directories of the model directory (``last_checkpoint``, ``best_model``,
    ``checkpoint_epoch*_step*``), so checking only the top level would report
    such a model as missing.

    Args:
        model_dir: Path of one base model's output directory.

    Returns:
        True if any inspected directory contains one of ``WEIGHT_ARTIFACTS``.
    """
    if not os.path.isdir(model_dir):
        return False
    folders = [model_dir]
    folders.extend(os.path.join(model_dir, entry) for entry in sorted(os.listdir(model_dir)))
    return any(
        os.path.isfile(os.path.join(folder, artifact))
        for folder in folders
        for artifact in WEIGHT_ARTIFACTS
    )


# Load current configuration
with open('configs/config.yaml', 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)

required_models = config['model']['base_models']
output_dir = config['training']['output_dir']

# Every configured base model must have produced at least one usable artifact.
missing_models = [
    model for model in required_models
    if not has_trained_artifacts(os.path.join(output_dir, model))
]

if missing_models:
    error_msg = (
        f"\n‚ùå BLOCKING EXECUTION: Missing trained checkpoints for: {', '.join(missing_models)}\n"
        f"   The subsequent evaluation steps require these models to be available.\n"
        f"   Please check the training logs in Cell 7 for errors."
    )
    # Raising (rather than printing) stops "Run all" before the evaluation cells.
    raise RuntimeError(error_msg)

print(f"‚úÖ Verification Passed: All {len(required_models)} models are ready for stacking.")

## 7.5. (Optional) Cleanup Old Checkpoints

Free up disk space by removing old mid-epoch checkpoints.

In [None]:
# @markdown ### üßπ Checkpoint Cleanup
# @markdown Remove old mid-epoch checkpoints to save Google Drive space
RUN_CLEANUP = False # @param {type:"boolean"}
KEEP_LAST_N = 0 # @param {type:"integer"}

if RUN_CLEANUP:
    print("üßπ Cleaning up old checkpoints...\n")

    all_models = ["codebert", "graphcodebert", "unixcoder"]
    for model_name in all_models:
        ckpt_dir = os.path.join(CHECKPOINT_DIR, model_name)
        if os.path.exists(ckpt_dir):
            print(f"\nCleaning {model_name}:")
            !python scripts/cleanup_checkpoints.py \
                --checkpoint_dir {ckpt_dir} \
                --keep-last {KEEP_LAST_N} \
                --auto
else:
    print("‚è≠Ô∏è  Skipping cleanup (not enabled)")

## 8. Meta-Classifier Comparison (Table III Reproduction)
Evaluate different meta-classifiers (Logistic Regression, Random Forest, SVM, XGBoost) on the same optimized features.

In [None]:
from pathlib import Path

import pandas as pd
import torch
import yaml
from IPython.display import display

from scripts.train import extract_all_features, load_labels_from_file, train_base_models
from src.stacking import (
    evaluate_meta_classifier,
    prepare_meta_features,
    train_meta_classifier,
)
from src.utils import get_device
from src.visualization import plot_meta_feature_importance


def build_meta_feature_names(base_models, num_classes, n_features=None, pca_applied=False):
    """Return one display name per column of the meta-feature matrix.

    Args:
        base_models: Base model identifiers, in stacking order.
        num_classes: Number of per-model outputs contributed to the stack.
        n_features: Actual number of meta-feature columns, or None if unknown.
        pca_applied: True when the columns went through PCA and therefore no
            longer correspond to individual model outputs.

    Returns:
        ``<model>_prob_<class>`` names when they describe the columns (or when
        the column count is unknown), otherwise generic ``pc_<i>`` /
        ``feature_<i>`` names whose count matches ``n_features``.
    """
    per_model = [f"{name}_prob_{c}" for name in base_models for c in range(num_classes)]
    if n_features is None:
        return per_model
    if not pca_applied and len(per_model) == n_features:
        return per_model
    prefix = "pc" if pca_applied else "feature"
    return [f"{prefix}_{i}" for i in range(n_features)]


def reproduce_table_iii():
    """Train and evaluate each meta-classifier on the stacked base-model features.

    Reads ``configs/config.yaml``, loads the base models through
    ``train_base_models`` with ``num_epochs=0`` and ``resume=True``, extracts
    cached logits, applies scaling and PCA, then fits LR, RF, SVM and XGBoost.
    For XGBoost a feature-importance plot is also written to the training
    output directory.

    Returns:
        pandas.DataFrame with one row per classifier and the columns
        ``Classifier``, ``Acc``, ``F1`` and ``AUC`` (metrics multiplied by 100).
    """
    print("üìä Comparing Meta-Classifiers (LR vs RF vs SVM vs XGBoost)...")

    with open("configs/config.yaml", 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)

    device = get_device()
    root_dir = Path(config['data']['root_dir'])
    base_models = config['model']['base_models']
    num_classes = config['model'].get('num_labels', 5)
    # Per-classifier parameters are optional; a missing section means defaults.
    all_params = config['model'].get('meta_classifier_params') or {}
    output_dir = config['training']['output_dir']

    # 1. Load models and pre-created dataloaders
    trainers, dataloaders = train_base_models(config, base_models,
                                             num_epochs=0, device=device, resume=True)

    # 2. Extract Optimized Features (with caching)
    features_dict = extract_all_features(config, trainers, dataloaders, mode="logits", use_cache=True)

    # 3. Load Labels
    train_labels = load_labels_from_file(root_dir / config['data']['train_file'])
    test_labels = load_labels_from_file(root_dir / config['data']['test_file'])

    # 4. Prepare Meta-features with Scaling/PCA
    # The PCA and scaler returned for the training split are reused for the test split.
    train_meta, _, pca, scaler = prepare_meta_features(features_dict['train'], train_labels, use_pca=True, use_scaling=True)
    test_meta, _, _, _ = prepare_meta_features(features_dict['test'], pca_model=pca, scaler=scaler, use_pca=True, use_scaling=True)

    # Column count of the matrix the classifiers actually see.
    # NOTE(review): assumes train_meta exposes a 2-D `.shape`; when it does not,
    # the per-model names are used unchanged.
    meta_shape = getattr(train_meta, "shape", None)
    n_features = meta_shape[1] if meta_shape is not None and len(meta_shape) == 2 else None

    # 5. Iterative Evaluation
    results = []
    for m_type in ["lr", "rf", "svm", "xgboost"]:
        print(f"  > Training {m_type.upper()}...")
        params = all_params.get(m_type, {})
        clf = train_meta_classifier(train_meta, train_labels, classifier_type=m_type, **params)
        metrics = evaluate_meta_classifier(clf, test_meta, test_labels)

        if m_type == 'xgboost':
            # After PCA the columns are components, not per-model outputs, so
            # the names must follow the real column count.
            # NOTE(review): treats a non-None `pca` as "PCA was applied" - confirm
            # against prepare_meta_features.
            feature_names = build_meta_feature_names(
                base_models, num_classes, n_features, pca_applied=pca is not None
            )
            plot_meta_feature_importance(clf, feature_names, save_path=f"{output_dir}/feature_importance.png")

        results.append({"Classifier": m_type.upper(), "Acc": metrics['accuracy']*100, "F1": metrics['f1']*100, "AUC": metrics['auc']*100})

    return pd.DataFrame(results)

# Run the comparison once and render the resulting metrics table in the notebook.
comparison_df = reproduce_table_iii()
display(comparison_df)

## 9. Advanced Visualization

In [None]:
import glob

from IPython.display import Image

def show_saved_plot(image_path):
    """Display a saved plot inline, or report that the file does not exist.

    A missing plot must not abort the cell, otherwise the plots listed after
    it would never be shown.

    Args:
        image_path: Path of the PNG file to display.
    """
    if os.path.exists(image_path):
        display(Image(filename=image_path))
    else:
        print(f"   (not found: {image_path})")


plots_dir = config['training']['output_dir']

print("üìà Training Curves:")
# Sorted so the curves appear in a stable order between runs.
hist_plots = sorted(glob.glob(f"{plots_dir}/**/training_history.png", recursive=True))
if not hist_plots:
    print("   (no training_history.png files found)")
for p in hist_plots:
    print(f"Source: {p}")
    display(Image(filename=p))

print("\nüéØ Final Confusion Matrix:")
show_saved_plot(f"{plots_dir}/confusion_matrix.png")

print("\n‚≠ê Feature Importance (Base Model Impact):")
show_saved_plot(f"{plots_dir}/feature_importance.png")

## 10. Model Export for Deployment

In [None]:
# Export the primary model to ONNX for 3x faster CPU inference
import sys
from pathlib import Path

# The setup cell changes into the repository root and every script in this
# notebook is invoked relative to it, so the working directory (rather than a
# hard-coded repository name) is what must be importable. It is added before
# the project imports below and only once, however often the cell is re-run.
repo_root = os.getcwd()
if repo_root not in sys.path:
    sys.path.append(repo_root)

from scripts.train import find_latest_checkpoint
from src.models import create_model

print("üöÄ Exporting model for production...")
# Only the first configured base model is exported.
model_name = config['model']['base_models'][0]

# Find the best checkpoint to export
model_dir = Path(config['training']['output_dir']) / model_name
checkpoint_path = find_latest_checkpoint(model_dir)

if checkpoint_path:
    print(f"üì¶ Using checkpoint: {os.path.basename(checkpoint_path)}")
    weights_file = os.path.join(checkpoint_path, "pytorch_model.bin")

    if not os.path.isfile(weights_file):
        # Report the missing file plainly instead of failing inside torch.load.
        print(f"‚ùå Checkpoint has no pytorch_model.bin: {weights_file}")
    else:
        # Build the architecture only once there are weights to put into it.
        model, _ = create_model(model_name, config, pretrained=False)

        # Load weights from your best run
        load_result = model.load_state_dict(torch.load(weights_file, map_location='cpu'), strict=False)

        # strict=False ignores mismatched keys; surface them so a partly
        # initialised model is not exported unnoticed.
        # NOTE(review): assumes the usual torch.nn.Module return value with
        # `missing_keys` / `unexpected_keys`; getattr keeps this harmless otherwise.
        missing_keys = list(getattr(load_result, "missing_keys", []) or [])
        unexpected_keys = list(getattr(load_result, "unexpected_keys", []) or [])
        if missing_keys:
            print(f"‚ö†Ô∏è  {len(missing_keys)} parameter(s) were not found in the checkpoint, e.g. {missing_keys[:3]}")
        if unexpected_keys:
            print(f"‚ö†Ô∏è  {len(unexpected_keys)} checkpoint entrie(s) were not used, e.g. {unexpected_keys[:3]}")

        onnx_path = f"{config['training']['output_dir']}/optimized_model.onnx"
        model.export_onnx(onnx_path)
        print(f"‚úÖ Successfully exported to: {onnx_path}")
else:
    print("‚ùå No checkpoint found to export.")

## 11. Troubleshooting & Documentation

Access helper documentation and tools.

In [None]:
# Quick reference to the repository's checkpoint documentation and helper scripts.
# Each entry is printed on its own line; a leading "\n" opens a new section.
doc_index = (
    "üìö Available Documentation:",
    "\n1. Checkpoint System:",
    "   - CHECKPOINT_ANALYSIS.md - Root cause analysis",
    "   - CHECKPOINT_CORRECTNESS.md - Semantic correctness proof",
    "   - CHECKPOINT_VISUAL_GUIDE.md - Visual diagrams",
    "   - CHECKPOINT_STRATEGY.md - Configuration guide",
    "   - FINAL_VALIDATION.md - Validation summary",
    "\n2. Validation Tools:",
    "   !python scripts/validate_checkpoint.py --checkpoint_path <path>",
    "   !python scripts/debug_checkpoint.py --checkpoint_path <path>",
    "   !python scripts/demo_checkpoint_crash.py",
    "\n3. Cleanup Tools:",
    "   !python scripts/cleanup_checkpoints.py --checkpoint_dir <path> --keep-last 0",
    "\n4. Fix Tools:",
    "   !python scripts/fix_checkpoint_epoch.py --checkpoint_path <path> --epoch <n>",
    "\nüìñ For detailed guides, check the markdown files in the repository root.",
)
for entry in doc_index:
    print(entry)