A PyTorch implementation of a cost-aware, gated BERT pool (MPC-Router) for efficient Natural Language Understanding. This project implements a Mixture of Experts (MoE) architecture with asymmetric backpropagation to balance accuracy, computational cost, and load distribution across experts.
The MPC-BERT Pool architecture consists of:
- Router Stem: Lightweight BERT encoder (e.g., bert-tiny) that produces CLS tokens for routing decisions
- Expert Pool: Multiple heterogeneous BERT expert models (tiny/base/large) processing in parallel
- MLP Router: 2-layer MLP that assigns gating probabilities to experts via Gumbel-Softmax
- Composite Loss: Combines task loss, load balancing, and cost awareness
Key innovation: Asymmetric backpropagation where:
- Task loss (L_task) updates all components
- Auxiliary losses (L_balance, L_cost) only update the router
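The asymmetric gradient flow above can be sketched with toy stand-ins (a `router` and an `experts` module here are illustrative placeholders, not the package's actual classes):

```python
import torch
import torch.nn as nn

# Toy stand-ins for the router and expert pool (illustrative only)
router = nn.Linear(8, 3)    # gating network
experts = nn.Linear(8, 2)   # placeholder for the expert pool
x = torch.randn(4, 8)

l_task = experts(x).sum() + router(x).sum()  # task loss touches everything
l_aux = (router(x) ** 2).mean()              # surrogate for balance/cost terms

# L_task populates .grad on ALL parameters
l_task.backward(retain_graph=True)

# Auxiliary gradients are computed w.r.t. the router only
router_params = list(router.parameters())
aux_grads = torch.autograd.grad(l_aux, router_params)
for p, g in zip(router_params, aux_grads):
    p.grad = g if p.grad is None else p.grad + g
```

Because `torch.autograd.grad` is called only on the router's parameters, the expert weights never see the auxiliary losses.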
- Python >= 3.8
- CUDA-compatible GPU (required for training)
- conda (recommended)
# Clone the repository
git clone https://github.com/UCF-ML-Research/SecureRouter-Encrypted-Routing-for-Efficient-Secure-Inference.git
cd SecureRouter-Encrypted-Routing-for-Efficient-Secure-Inference
# Create conda environment (recommended)
conda create -n mpc_router python=3.11 -y
conda activate mpc_router
# Editable install — pulls in all runtime deps via pyproject.toml
pip install -e .
# Optional: dev tools (pytest, black, mypy, ...)
pip install -e ".[dev]"

This installs the package as mpc_router and registers four CLI entry points: mpc-router-train, mpc-router-tune, mpc-router-eval, mpc-router-demo. See docs/API.md for the full API reference.
python -c "from mpc_router import GatedBertPool, CompositeLoss, GLUEDataLoader; print('mpc_router OK')"

import torch
from mpc_router import GatedBertPool, CompositeLoss
model = GatedBertPool(
router_model_name="prajjwal1/bert-tiny",
expert_model_names=[
"M-FAC/bert-tiny-finetuned-mrpc",
"howey/bert-base-uncased-mrpc",
"yoshitomo-matsubara/bert-large-uncased-mrpc",
],
num_labels=2,
)
loss_fn = CompositeLoss(
task_type="classification", num_labels=2, num_experts=3,
expert_costs=torch.tensor([2.0, 7.0, 13.0]),
alpha=0.01, beta=0.03,
)
# Soft gating (training): runs all experts, returns gradient signal to the router
predictions, gating_weights = model(input_ids, attention_mask)
loss, components = loss_fn(predictions, labels, gating_weights)
# Hard gating (inference): runs only the chosen expert per sample
predictions, chosen_expert = model.forward_inference(input_ids, attention_mask)

# Train on MRPC with default config
mpc-router-train --config configs/task_mrpc.yaml
# Train with custom loss weights
mpc-router-train --config configs/task_sst2.yaml --alpha 0.5 --beta 0.001
# Legacy in-tree invocation (still works in dev mode)
PYTHONPATH=$(pwd) python src/train.py --config configs/task_mrpc.yaml

MPC_Router_with_LLM_pool/
├── .github/ # GitHub tooling
│ ├── workflows/ci.yml # CI: auto-test on push/PR
│ ├── pull_request_template.md
│ └── ISSUE_TEMPLATE/ # Bug report & feature request templates
├── pyproject.toml # Package metadata (PEP 621) — source of truth
├── setup.py # Stub for legacy tooling only
├── requirement.txt # Dependency list (mirrored in pyproject.toml)
├── docs/
│ └── API.md # Full API reference
├── configs/ # Configuration files
│ ├── base_config.yaml # Default parameters
│ ├── task_mrpc.yaml # MRPC task config
│ ├── task_sst2.yaml # SST-2 task config
│ └── task_*.yaml # Other GLUE task configs
├── src/ # Source code (installed as package `mpc_router`)
│ ├── __init__.py # Public API re-exports
│ ├── model.py # GatedBertPool, RouterStem, ExpertPool, MLPRouter, BertExpert
│ ├── loss.py # TaskLoss, LoadBalancingLoss, CostAwareLoss, CompositeLoss
│ ├── train.py # Training with asymmetric backprop → mpc-router-train
│ ├── train_supervised.py # Two-stage supervised training
│ ├── evaluate_inference.py # Hard-gating evaluation → mpc-router-eval
│ ├── demo.py # Interactive CLI demo → mpc-router-demo
│ ├── data_loader.py # GLUEDataLoader
│ ├── utils.py # Logging, checkpoints, plotting, MetricTracker
│ └── hyperparameter_tuning.py # Grid search → mpc-router-tune
├── scripts/ # Executable scripts
│ ├── run_training.sh # Training launcher
│ ├── run_reviewer_experiments.py # DAC reviewer response runs
│ └── security/ # Side-channel defense analysis
├── tools/
│ └── benchmark/ # Microbench scripts (standalone, NOT part of the package)
├── tutorials/ # Jupyter notebooks
└── README.md # This file
- Takes first N layers from pretrained BERT
- Processes all inputs identically
- Returns CLS token and sequence outputs
- K experts, each with remaining BERT layers
- Each expert specializes in different input types
- Processes inputs in parallel
- Simple 2-layer MLP
- Takes CLS token from stem
- Outputs gating probabilities for experts
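A minimal sketch of such a router follows (layer sizes and dropout are assumptions taken from the example config, not the package's actual `MLPRouter`); `torch.nn.functional.gumbel_softmax` gives differentiable soft weights during training and a one-hot choice when `hard=True`:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MLPRouterSketch(nn.Module):
    """Illustrative 2-layer MLP router over the stem's CLS embedding."""
    def __init__(self, hidden_size=128, router_hidden=256, num_experts=3):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(hidden_size, router_hidden),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(router_hidden, num_experts),
        )

    def forward(self, cls_token, tau=1.0, hard=False):
        logits = self.net(cls_token)
        # Gumbel-Softmax: soft (differentiable) weights by default,
        # one-hot expert selection when hard=True
        return F.gumbel_softmax(logits, tau=tau, hard=hard)

router = MLPRouterSketch()
cls = torch.randn(4, 128)       # CLS embeddings from the router stem
soft = router(cls)              # soft gating weights; each row sums to 1
hard = router(cls, hard=True)   # one-hot expert choice per sample
```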
Total Loss: L_total = L_task + α·L_balance + β·L_cost
- Task Loss (L_task): Standard cross-entropy or MSE
- Load Balancing Loss (L_balance): Squared coefficient of variation
- Cost-Aware Loss (L_cost): Expected computational cost
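The arithmetic of these three terms can be sketched as follows (an illustration of the formula, not the package's `CompositeLoss`; the CV² balance term here is computed over batch-mean gate usage, and the epsilon is an assumption for numerical safety):

```python
import torch
import torch.nn.functional as F

def composite_loss_sketch(logits, labels, gates, expert_costs,
                          alpha=0.01, beta=0.03):
    """Illustrative L_total = L_task + alpha*L_balance + beta*L_cost."""
    # Task loss: cross-entropy for classification
    l_task = F.cross_entropy(logits, labels)
    # Load balancing: squared coefficient of variation of mean gate usage
    usage = gates.mean(dim=0)                       # (num_experts,)
    l_balance = usage.var(unbiased=False) / (usage.mean() ** 2 + 1e-8)
    # Cost-awareness: expected computational cost under the gating distribution
    l_cost = (gates * expert_costs).sum(dim=1).mean()
    total = l_task + alpha * l_balance + beta * l_cost
    return total, {"task": l_task, "balance": l_balance, "cost": l_cost}

logits = torch.randn(8, 2)
labels = torch.randint(0, 2, (8,))
gates = torch.softmax(torch.randn(8, 3), dim=1)     # rows sum to 1
costs = torch.tensor([2.0, 7.0, 13.0])
total, parts = composite_loss_sketch(logits, labels, gates, costs)
```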
Task configs inherit from configs/base_config.yaml and override as needed. Example from configs/task_mrpc.yaml:
model:
router_model_name: "prajjwal1/bert-tiny"
router_stem_layers: 2
expert_model_names:
- "M-FAC/bert-tiny-finetuned-mrpc"
- "howey/bert-base-uncased-mrpc"
- "yoshitomo-matsubara/bert-large-uncased-mrpc"
router_hidden_size: 256
router_dropout: 0.1
tokenizer_model_name: "bert-base-uncased"
expert_costs:
- 2.0 # Relative cost for bert-tiny
- 7.0 # Relative cost for bert-base
- 13.0 # Relative cost for bert-large
loss_weights:
alpha: 0.01 # Weight for load balancing loss
beta: 0.03 # Weight for cost-aware loss

The training implements custom gradient control:
# Step 1: L_task gradients on ALL parameters
l_task.backward(retain_graph=True)
# Step 2: L_aux gradients on ROUTER ONLY
l_aux = alpha * l_balance + beta * l_cost
aux_grads = torch.autograd.grad(l_aux, router_params)
# Step 3: Accumulate auxiliary gradients into the router's .grad buffers
for param, grad in zip(router_params, aux_grads):
    param.grad = grad if param.grad is None else param.grad + grad

The implementation supports all GLUE benchmark tasks:
| Task | Type | Input Format | Metric |
|---|---|---|---|
| CoLA | Classification | Single sentence | Matthews Corr |
| SST-2 | Classification | Single sentence | Accuracy |
| MRPC | Classification | Sentence pair | F1/Accuracy |
| STS-B | Regression | Sentence pair | Pearson/Spearman |
| QQP | Classification | Sentence pair | F1/Accuracy |
| MNLI | Classification | Sentence pair | Accuracy |
| QNLI | Classification | Sentence pair | Accuracy |
| RTE | Classification | Sentence pair | Accuracy |
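For reference, the task table above can be encoded as a small lookup (an illustrative helper, not part of the package's API):

```python
# Illustrative mapping from GLUE task name to (problem type, primary metric)
GLUE_TASKS = {
    "cola": ("classification", "matthews_corr"),
    "sst2": ("classification", "accuracy"),
    "mrpc": ("classification", "f1/accuracy"),
    "stsb": ("regression", "pearson/spearman"),
    "qqp":  ("classification", "f1/accuracy"),
    "mnli": ("classification", "accuracy"),
    "qnli": ("classification", "accuracy"),
    "rte":  ("classification", "accuracy"),
}

def task_info(task: str):
    """Return (problem type, primary metric) for a GLUE task name."""
    kind, metric = GLUE_TASKS[task.lower()]
    return kind, metric
```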
- Baseline (α=0, β=0): Maximum accuracy, no constraints
- Grid Search:
- α ∈ [0.01, 0.1, 1.0]
- β ∈ [0.001, 0.01, 0.1]
- Analyze Pareto Frontier: Plot accuracy vs. cost trade-off
for alpha in 0.01 0.1 1.0; do
for beta in 0.001 0.01 0.1; do
./scripts/run_training.sh mrpc $alpha $beta
done
done

Training logs are saved to TensorBoard:
# Start TensorBoard
tensorboard --logdir outputs/logs/
# Key metrics to monitor:
# - loss/task: Task performance
# - loss/balance: Load distribution
# - loss/cost: Computational cost
# - eval_avg_gate_prob_expert_*: Expert utilization

Try the model with your own sentences:
export PYTHONPATH=$(pwd)
# SST-2 sentiment analysis demo
python src/demo.py --checkpoint outputs/sst2_best_model.pt
# MRPC paraphrase detection demo
python src/demo.py --checkpoint outputs/mrpc_best_model.pt

Example output:
>>> This movie is absolutely amazing!
Sentiment: Positive
Confidence: 99.6%
Expert: [2] BERT-large
Router Scores:
[0] BERT-tiny 10.0% ##
[1] BERT-base 40.0% ############
[2] BERT-large 50.0% ############### <--
# Weighted average of all experts
predictions, gating_weights = model(input_ids, attention_mask)

# Single expert selection for efficiency
predictions, chosen_expert_indices = model.forward_inference(input_ids, attention_mask)

- Dynamic Padding: Sequences padded per-batch, not globally
- ModuleList Usage: Critical for proper parameter registration
- Gradient Retention: retain_graph=True for multi-loss backprop
- Cost Definition: Expert costs are predefined constants
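The dynamic-padding note can be illustrated with a minimal collate function (a plain-PyTorch sketch; the project's actual GLUEDataLoader may differ, and token id 0 is used as BERT's [PAD]):

```python
import torch
from torch.nn.utils.rnn import pad_sequence

def collate_dynamic(batch):
    """Pad each batch only to its own longest sequence, not a global max."""
    ids = [torch.tensor(ex) for ex in batch]
    input_ids = pad_sequence(ids, batch_first=True, padding_value=0)
    attention_mask = (input_ids != 0).long()   # mask out padding positions
    return input_ids, attention_mask

# Two token-id sequences of different lengths
batch = [[101, 2023, 3185, 102], [101, 2307, 102]]
input_ids, mask = collate_dynamic(batch)
# input_ids is padded to length 4 (this batch's max), not a global maximum
```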
- Symptom: One expert receives ~100% of the gating probability (router collapse)
- Solution: Increase α (the load-balancing weight)
- Symptom: Performance drops versus the baseline
- Solution: Decrease the α and β weights
- Symptom: Out-of-memory (OOM) errors
- Solution: Reduce the batch size or the number of experts
Based on the paper "Building an MPC Router / BERT Pool" implementing:
- Mixture of Experts architecture
- Cost-aware routing
- Load balancing mechanisms
- Asymmetric gradient control
Contributions are welcome! This repo includes PR and Issue templates to help you get started.
- Fork the repository
- Create a feature branch (git checkout -b feature/my-feature)
- Commit your changes (git commit -m "Add my feature")
- Push to the branch (git push origin feature/my-feature)
- Open a Pull Request — the PR template will guide you
This project is released under the MIT License.
- Hugging Face Transformers for BERT implementation
- PyTorch team for the deep learning framework
- GLUE benchmark creators for evaluation datasets
Note: This is a research implementation focusing on the novel asymmetric backpropagation technique and cost-aware routing mechanisms. Production deployment would require additional optimizations.