
MPC-BERT Pool: Cost-Aware Mixture of Experts with Gated Routing

A PyTorch implementation of a cost-aware, gated BERT pool (MPC-Router) for efficient Natural Language Understanding. This project implements a Mixture of Experts (MoE) architecture with asymmetric backpropagation to balance accuracy, computational cost, and load distribution across experts.

📋 Overview

The MPC-BERT Pool architecture consists of:

  • Router Stem: Lightweight BERT encoder (e.g., bert-tiny) that produces CLS tokens for routing decisions
  • Expert Pool: Multiple heterogeneous BERT expert models (tiny/base/large) processing in parallel
  • MLP Router: 2-layer MLP that assigns gating probabilities to experts via Gumbel-Softmax
  • Composite Loss: Combines task loss, load balancing, and cost awareness

Key innovation: Asymmetric backpropagation where:

  • Task loss (L_task) updates all components
  • Auxiliary losses (L_balance, L_cost) only update the router

🚀 Quick Start

Prerequisites

  • Python >= 3.8
  • CUDA-compatible GPU (required for training)
  • conda (recommended)

Installation

# Clone the repository
git clone https://github.com/UCF-ML-Research/SecureRouter-Encrypted-Routing-for-Efficient-Secure-Inference.git
cd SecureRouter-Encrypted-Routing-for-Efficient-Secure-Inference

# Create conda environment (recommended)
conda create -n mpc_router python=3.11 -y
conda activate mpc_router

# Editable install — pulls in all runtime deps via pyproject.toml
pip install -e .

# Optional: dev tools (pytest, black, mypy, ...)
pip install -e ".[dev]"

This installs the package as mpc_router and registers four CLI entry points: mpc-router-train, mpc-router-tune, mpc-router-eval, mpc-router-demo. See docs/API.md for the full API reference.

Verify the install (no GPU needed)

python -c "from mpc_router import GatedBertPool, CompositeLoss, GLUEDataLoader; print('mpc_router OK')"

Python API — 30-second example

import torch
from mpc_router import GatedBertPool, CompositeLoss

model = GatedBertPool(
    router_model_name="prajjwal1/bert-tiny",
    expert_model_names=[
        "M-FAC/bert-tiny-finetuned-mrpc",
        "howey/bert-base-uncased-mrpc",
        "yoshitomo-matsubara/bert-large-uncased-mrpc",
    ],
    num_labels=2,
)
loss_fn = CompositeLoss(
    task_type="classification", num_labels=2, num_experts=3,
    expert_costs=torch.tensor([2.0, 7.0, 13.0]),
    alpha=0.01, beta=0.03,
)
# input_ids, attention_mask (and labels below) are batch tensors from your tokenizer/data loader
# Soft gating (training): runs all experts, returns gradient signal to the router
predictions, gating_weights = model(input_ids, attention_mask)
loss, components = loss_fn(predictions, labels, gating_weights)

# Hard gating (inference): runs only the chosen expert per sample
predictions, chosen_expert = model.forward_inference(input_ids, attention_mask)

Training a Model (CLI)

# Train on MRPC with default config
mpc-router-train --config configs/task_mrpc.yaml

# Train with custom loss weights
mpc-router-train --config configs/task_sst2.yaml --alpha 0.5 --beta 0.001

# Legacy in-tree invocation (still works in dev mode)
PYTHONPATH=$(pwd) python src/train.py --config configs/task_mrpc.yaml

📁 Project Structure

MPC_Router_with_LLM_pool/
├── .github/                   # GitHub tooling
│   ├── workflows/ci.yml      # CI: auto-test on push/PR
│   ├── pull_request_template.md
│   └── ISSUE_TEMPLATE/       # Bug report & feature request templates
├── pyproject.toml             # Package metadata (PEP 621) — source of truth
├── setup.py                   # Stub for legacy tooling only
├── requirement.txt            # Dependency list (mirrored in pyproject.toml)
├── docs/
│   └── API.md                # Full API reference
├── configs/                   # Configuration files
│   ├── base_config.yaml      # Default parameters
│   ├── task_mrpc.yaml        # MRPC task config
│   ├── task_sst2.yaml        # SST-2 task config
│   └── task_*.yaml           # Other GLUE task configs
├── src/                       # Source code (installed as package `mpc_router`)
│   ├── __init__.py           # Public API re-exports
│   ├── model.py              # GatedBertPool, RouterStem, ExpertPool, MLPRouter, BertExpert
│   ├── loss.py               # TaskLoss, LoadBalancingLoss, CostAwareLoss, CompositeLoss
│   ├── train.py              # Training with asymmetric backprop  → mpc-router-train
│   ├── train_supervised.py   # Two-stage supervised training
│   ├── evaluate_inference.py # Hard-gating evaluation              → mpc-router-eval
│   ├── demo.py               # Interactive CLI demo                → mpc-router-demo
│   ├── data_loader.py        # GLUEDataLoader
│   ├── utils.py              # Logging, checkpoints, plotting, MetricTracker
│   └── hyperparameter_tuning.py  # Grid search                     → mpc-router-tune
├── scripts/                   # Executable scripts
│   ├── run_training.sh       # Training launcher
│   ├── run_reviewer_experiments.py  # DAC reviewer response runs
│   └── security/             # Side-channel defense analysis
├── tools/
│   └── benchmark/            # Microbench scripts (standalone, NOT part of the package)
├── tutorials/                 # Jupyter notebooks
└── README.md                  # This file

🏗️ Architecture Details

Shared Stem Module

  • Takes first N layers from pretrained BERT
  • Processes all inputs identically
  • Returns CLS token and sequence outputs
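The layer-splitting idea can be sketched with a toy encoder. Everything below (`ToyStem`, the stand-in layers) is illustrative, not the package's actual `RouterStem` API:

```python
import torch
import torch.nn as nn

class ToyStem(nn.Module):
    """Illustrative stem: keeps only the first `num_layers` blocks of an encoder."""
    def __init__(self, encoder_layers: nn.ModuleList, num_layers: int):
        super().__init__()
        # Wrapping in ModuleList keeps the sliced layers' parameters registered.
        self.layers = nn.ModuleList(encoder_layers[:num_layers])

    def forward(self, hidden):                 # hidden: (batch, seq, dim)
        for layer in self.layers:
            hidden = layer(hidden)
        cls = hidden[:, 0]                     # CLS token = first position
        return cls, hidden                     # routing signal + full sequence

# 12 toy "transformer layers" stand in for a pretrained BERT encoder.
full_encoder = nn.ModuleList(nn.Linear(64, 64) for _ in range(12))
stem = ToyStem(full_encoder, num_layers=2)
cls, seq = stem(torch.randn(4, 16, 64))
print(cls.shape, seq.shape)  # torch.Size([4, 64]) torch.Size([4, 16, 64])
```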

Expert Pool

  • K experts, each with remaining BERT layers
  • Each expert specializes in different input types
  • Processes inputs in parallel

MLP Router

  • Simple 2-layer MLP
  • Takes CLS token from stem
  • Outputs gating probabilities for experts
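As a sketch (class name and sizes are illustrative), the router is a 2-layer MLP over the stem's CLS vector, with `torch.nn.functional.gumbel_softmax` providing differentiable soft gates during training and straight-through one-hot gates when `hard=True`:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MLPRouterSketch(nn.Module):
    def __init__(self, hidden_size=128, router_hidden=256, num_experts=3):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(hidden_size, router_hidden),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(router_hidden, num_experts),
        )

    def forward(self, cls_token, tau=1.0, hard=False):
        logits = self.net(cls_token)
        # Soft gates sum to 1; hard=True yields one-hot gates with a
        # straight-through gradient estimator.
        return F.gumbel_softmax(logits, tau=tau, hard=hard)

router = MLPRouterSketch()
gates = router(torch.randn(4, 128))                   # (batch, num_experts)
hard_gates = router(torch.randn(4, 128), hard=True)   # one-hot rows
```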

Loss Functions

Total Loss: L_total = L_task + α·L_balance + β·L_cost

  1. Task Loss (L_task): Standard cross-entropy or MSE
  2. Load Balancing Loss (L_balance): Squared coefficient of variation
  3. Cost-Aware Loss (L_cost): Expected computational cost
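A minimal sketch of the three terms (the package's `CompositeLoss` may normalize differently; the function name here is illustrative):

```python
import torch
import torch.nn.functional as F

def composite_loss(logits, labels, gates, costs, alpha=0.01, beta=0.03):
    """Sketch of L_total = L_task + alpha * L_balance + beta * L_cost."""
    l_task = F.cross_entropy(logits, labels)
    load = gates.mean(dim=0)                   # mean gate probability per expert
    # Squared coefficient of variation of the per-expert load (0 when balanced)
    l_balance = load.var(unbiased=False) / (load.mean() ** 2 + 1e-8)
    # Expected compute cost under the current gate distribution
    l_cost = (load * costs).sum()
    total = l_task + alpha * l_balance + beta * l_cost
    return total, {"task": l_task.item(),
                   "balance": l_balance.item(),
                   "cost": l_cost.item()}

gates = torch.tensor([[0.2, 0.3, 0.5], [0.6, 0.3, 0.1]])
logits, labels = torch.randn(2, 2), torch.tensor([0, 1])
total, parts = composite_loss(logits, labels, gates, torch.tensor([2.0, 7.0, 13.0]))
```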

⚙️ Configuration

Task configs inherit from configs/base_config.yaml and override as needed. Example from configs/task_mrpc.yaml:

model:
  router_model_name: "prajjwal1/bert-tiny"
  router_stem_layers: 2
  expert_model_names:
    - "M-FAC/bert-tiny-finetuned-mrpc"
    - "howey/bert-base-uncased-mrpc"
    - "yoshitomo-matsubara/bert-large-uncased-mrpc"
  router_hidden_size: 256
  router_dropout: 0.1
  tokenizer_model_name: "bert-base-uncased"

expert_costs:
  - 2.0   # Relative cost for bert-tiny
  - 7.0   # Relative cost for bert-base
  - 13.0  # Relative cost for bert-large

loss_weights:
  alpha: 0.01  # Weight for load balancing loss
  beta: 0.03   # Weight for cost-aware loss

🔬 Asymmetric Backpropagation

The training implements custom gradient control:

# Step 1: L_task gradients flow to ALL parameters
l_task.backward(retain_graph=True)

# Step 2: auxiliary gradients computed w.r.t. ROUTER parameters only
l_aux = alpha * l_balance + beta * l_cost
aux_grads = torch.autograd.grad(l_aux, router_params)

# Step 3: accumulate auxiliary gradients onto the router
for param, grad in zip(router_params, aux_grads):
    if param.grad is None:
        param.grad = grad.clone()
    else:
        param.grad += grad

📊 Supported GLUE Tasks

The implementation supports the following GLUE benchmark tasks:

| Task  | Type           | Input Format    | Metric           |
|-------|----------------|-----------------|------------------|
| CoLA  | Classification | Single sentence | Matthews Corr    |
| SST-2 | Classification | Single sentence | Accuracy         |
| MRPC  | Classification | Sentence pair   | F1/Accuracy      |
| STS-B | Regression     | Sentence pair   | Pearson/Spearman |
| QQP   | Classification | Sentence pair   | F1/Accuracy      |
| MNLI  | Classification | Sentence pair   | Accuracy         |
| QNLI  | Classification | Sentence pair   | Accuracy         |
| RTE   | Classification | Sentence pair   | Accuracy         |

🎯 Hyperparameter Tuning

Recommended Strategy

  1. Baseline (α=0, β=0): Maximum accuracy, no constraints
  2. Grid Search:
    • α ∈ [0.01, 0.1, 1.0]
    • β ∈ [0.001, 0.01, 0.1]
  3. Analyze Pareto Frontier: Plot accuracy vs. cost trade-off
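For step 3, a small helper (illustrative, not part of the package) can extract the Pareto-optimal (cost, accuracy) points from the grid-search results:

```python
def pareto_frontier(points):
    """Keep (cost, accuracy) points not dominated by a cheaper, more accurate run."""
    frontier = []
    for cost, acc in sorted(points):          # ascending cost
        # A run is Pareto-optimal if it beats every cheaper run's accuracy
        if not frontier or acc > frontier[-1][1]:
            frontier.append((cost, acc))
    return frontier

runs = [(2.0, 0.81), (7.0, 0.86), (13.0, 0.88), (7.5, 0.84), (13.5, 0.86)]
print(pareto_frontier(runs))  # [(2.0, 0.81), (7.0, 0.86), (13.0, 0.88)]
```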

Example Grid Search

for alpha in 0.01 0.1 1.0; do
    for beta in 0.001 0.01 0.1; do
        ./scripts/run_training.sh mrpc $alpha $beta
    done
done

📈 Monitoring Training

Training logs are saved to TensorBoard:

# Start TensorBoard
tensorboard --logdir outputs/logs/

# Key metrics to monitor:
# - loss/task: Task performance
# - loss/balance: Load distribution
# - loss/cost: Computational cost
# - eval_avg_gate_prob_expert_*: Expert utilization

🎮 Interactive Demo

Try the model with your own sentences:

export PYTHONPATH=$(pwd)

# SST-2 sentiment analysis demo
python src/demo.py --checkpoint outputs/sst2_best_model.pt

# MRPC paraphrase detection demo
python src/demo.py --checkpoint outputs/mrpc_best_model.pt

Example output:

>>> This movie is absolutely amazing!

  Sentiment:  Positive
  Confidence: 99.6%
  Expert:     [2] BERT-large

  Router Scores:
    [0] BERT-tiny      10.0%  ##
    [1] BERT-base      40.0%  ############
    [2] BERT-large     50.0%  ############### <--

🚀 Inference

Soft Gating (Training Mode)

# Weighted average of all experts
predictions, gating_weights = model(input_ids, attention_mask)

Hard Gating (Inference Mode)

# Single expert selection for efficiency
predictions, chosen_expert_indices = model.forward_inference(input_ids, attention_mask)

📝 Key Implementation Notes

  1. Dynamic Padding: Sequences padded per-batch, not globally
  2. ModuleList Usage: Critical for proper parameter registration
  3. Gradient Retention: retain_graph=True for multi-loss backprop
  4. Cost Definition: Expert costs are predefined constants
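Per-batch (dynamic) padding from note 1 can be sketched with a simple collate function; the package's `GLUEDataLoader` handles this internally, so the helper below is purely illustrative:

```python
import torch

def collate_dynamic(batch, pad_id=0):
    """Pad each batch to its own longest sequence, not a global max length."""
    max_len = max(len(ids) for ids in batch)
    input_ids = torch.full((len(batch), max_len), pad_id, dtype=torch.long)
    attention_mask = torch.zeros(len(batch), max_len, dtype=torch.long)
    for i, ids in enumerate(batch):
        input_ids[i, :len(ids)] = torch.tensor(ids)
        attention_mask[i, :len(ids)] = 1      # 1 = real token, 0 = padding
    return input_ids, attention_mask

# Two token-id sequences of different lengths pad only to length 3 here
ids, mask = collate_dynamic([[101, 2023, 102], [101, 102]])
print(ids.shape)  # torch.Size([2, 3])
```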

🔍 Troubleshooting

Router Collapse

  • Symptom: One expert gets 100% probability
  • Solution: Increase α (balance weight)

Poor Accuracy

  • Symptom: Performance drops vs baseline
  • Solution: Decrease α and β weights

High Memory Usage

  • Symptom: OOM errors
  • Solution: Reduce batch size or num_experts

📚 References

This implementation is based on the paper "Building an MPC Router / BERT Pool" and covers:

  • Mixture of Experts architecture
  • Cost-aware routing
  • Load balancing mechanisms
  • Asymmetric gradient control

🤝 Contributing

Contributions are welcome! This repo includes PR and Issue templates to help you get started.

  1. Fork the repository
  2. Create a feature branch (git checkout -b feature/my-feature)
  3. Commit your changes (git commit -m "Add my feature")
  4. Push to the branch (git push origin feature/my-feature)
  5. Open a Pull Request — the PR template will guide you

📄 License

This project is released under the MIT License.

🙏 Acknowledgments

  • Hugging Face Transformers for BERT implementation
  • PyTorch team for the deep learning framework
  • GLUE benchmark creators for evaluation datasets

Note: This is a research implementation focusing on the novel asymmetric backpropagation technique and cost-aware routing mechanisms. Production deployment would require additional optimizations.
