This repository implements policy distillation techniques for transferring knowledge from a state-based teacher policy to a vision-based student policy in MuJoCo continuous control tasks (e.g., Pusher).
The student model can process both rendered image frames and internal robot state variables (e.g., joint positions, velocities), enabling multimodal learning from pixel observations combined with proprioceptive information.
Four distillation methods are implemented, each with different trade-offs:
1. Proximal Policy Distillation (PPD) [1]
An on-policy method that combines PPO's policy-gradient objective with a distillation loss toward the teacher. The student collects its own rollouts in the environment and optimizes a combined objective of the form `L = L_PPO + λ · L_distill`, where the coefficient λ (`PPD_coef`) controls the balance between RL exploration and teacher imitation.
2. Teacher Distillation
The teacher selects actions in the environment while the student learns to match them. The student is trained with PPO, but with the teacher driving data collection. This provides a stable training signal, since the teacher generates high-quality trajectories.
3. Student Distillation
The student selects actions in the environment and the teacher provides corrective supervision. The student optimizes a combination of RL rewards and a distillation loss from the teacher's policy evaluated on the same states, encouraging the student to explore while still learning from the teacher.
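The combined objective described above (an RL term plus distillation toward the teacher on the same states) can be sketched as follows, assuming Gaussian action distributions. The function and parameter names are illustrative, not the repository's actual API:

```python
import torch
from torch.distributions import Normal, kl_divergence

def student_distillation_loss(rl_loss, student_mu, student_std,
                              teacher_mu, teacher_std, distill_coef):
    """Illustrative combined loss: RL objective plus a KL term pulling the
    student's policy toward the teacher's on the states the student visited."""
    # KL from the teacher's action distribution to the student's
    kl = kl_divergence(Normal(teacher_mu, teacher_std),
                       Normal(student_mu, student_std)).mean()
    # total objective: RL term plus weighted imitation term
    return rl_loss + distill_coef * kl
```

When the student matches the teacher exactly, the KL term vanishes and only the RL objective remains.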
4. Behavioural Cloning (BC) + DAgger
A two-phase offline-to-online approach:
- BC: The student is trained purely offline on a pre-collected dataset of teacher demonstrations using behavioural cloning (supervised learning on state-action pairs).
- DAgger (optional): The student is deployed in the environment to collect new data, which is aggregated with the original dataset. The teacher labels the student-visited states with the correct actions, and training continues on the expanded dataset.
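The DAgger phase boils down to a simple aggregation loop. A minimal sketch, with illustrative names (the repository's actual entry point is `BehaviouralCloning.py`):

```python
def dagger_iteration(dataset, rollout_student, teacher, n_steps):
    """One DAgger round (illustrative sketch).

    dataset:          list of (state, action) teacher demonstrations
    rollout_student:  callable returning the states the student visits
    teacher:          callable mapping a state to the teacher's action
    """
    # 1) the student acts in the environment; record the states it visits
    visited_states = rollout_student(n_steps)
    # 2) the teacher relabels those states with the actions it would take
    relabeled = [(s, teacher(s)) for s in visited_states]
    # 3) aggregate with the original demonstrations and keep training
    return dataset + relabeled
```

Repeating this loop steadily shifts the training distribution toward the states the student actually encounters, which is what corrects BC's compounding-error problem.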
For the BC phase, the repository includes a Prioritized Experience Replay (PER) buffer with importance-sampling correction:
- Priority-based sampling: samples with higher prediction error are drawn more frequently (controlled by `alpha`).
- Importance sampling weights: correct the bias introduced by non-uniform sampling (controlled by `beta`, annealed from an initial value toward 1.0).
- Alpha annealing modes: supports `constant`, `linear`, `dynamic_mean`, and `dynamic_max` schedules to adapt priorities during training.
- Balanced batch sampling: an optional sampler that draws 50% old data and 50% new data per batch, ensuring both are represented.
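The sampling and correction mechanics can be sketched as follows. This is a minimal illustration; the actual buffer in `Algorithm/Utils.py` has a different interface and more features (annealing schedules, balanced batches):

```python
import numpy as np

class PrioritizedReplayBuffer:
    """Minimal PER sketch: priority^alpha sampling with IS-weight correction."""

    def __init__(self, capacity, alpha=0.6):
        self.capacity = capacity
        self.alpha = alpha          # 0.0 => uniform sampling
        self.data, self.priorities = [], []

    def add(self, sample, priority=1.0):
        self.data.append(sample)
        self.priorities.append(priority)
        if len(self.data) > self.capacity:   # drop the oldest when full
            self.data.pop(0)
            self.priorities.pop(0)

    def sample(self, batch_size, beta=0.4, seed=None):
        rng = np.random.default_rng(seed)
        scaled = np.asarray(self.priorities) ** self.alpha
        probs = scaled / scaled.sum()
        idx = rng.choice(len(self.data), size=batch_size, p=probs)
        # importance-sampling weights correct the non-uniform draw;
        # beta is annealed toward 1.0 for full correction late in training
        weights = (len(self.data) * probs[idx]) ** (-beta)
        weights /= weights.max()             # normalize into (0, 1]
        return idx, [self.data[i] for i in idx], weights
```

High-error samples are drawn more often (via `alpha`), and the returned `weights` down-weight their gradient contribution so the expected update stays unbiased.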
| PPD | StudentDistillation | TeacherDistillation |
|---|---|---|
| PPD_lambda5.mp4 | SD.mp4 | TD.mp4 |
| BC-10k dataset | BC-50k dataset |
|---|---|
| BC_10k.mp4 | BC_50k.mp4 |
The following plot illustrates the reward convergence across different distillation methods.
- Sample Efficiency: Behavioural Cloning (BC) reaches high performance far faster than the on-policy methods. By leveraging a pre-collected teacher dataset, BC bypasses the need to collect millions of online environment transitions during the initial learning phase.
- Rapid Growth: While RL-based methods (PPD, Student/Teacher Distillation) require extensive exploration to stabilize, BC converges almost instantly to a near-teacher level of competence, making it the most resource-efficient choice for visual policy initialization.
- Stability: The BC phase provides a consistent supervised signal, avoiding the high variance typically seen in early-stage Reinforcement Learning.
Three student model sizes are available, all based on the IMPALA CNN backbone:
| Model | CNN Filters | Description |
|---|---|---|
| `ImpaalaSmall` | (16, 32) | Lightweight model |
| `ImpaalaMid` | (16, 32, 32) | Medium model |
| `ImpaalaBig` | (32, 64, 64) | Full-capacity model |
Internal states are upscaled and then concatenated with the hidden representation of the visual input. All models use the same MLP actor head on top of the CNN backbone.
The IMPALA (Importance Weighted Actor-Learner Architecture) backbone is a residual CNN designed for efficient visual feature extraction. It consists of:
- ConvSequence blocks: Each block has a convolutional layer followed by max-pooling and two residual blocks.
- Residual blocks: Standard residual connections with two 3×3 convolutions and ReLU activations, enabling deeper networks without degradation.
- Average pooling: A spatial pooling layer at the end reduces the feature map to a fixed-size vector regardless of input dimensions.
This architecture strikes a good balance between computational efficiency and representational capacity, making it well-suited for RL tasks with visual observations.
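The blocks above can be sketched in PyTorch. This is a schematic of the standard IMPALA CNN layout (ConvSequence → max-pool → two residual blocks, finished with global average pooling), not the repository's exact implementation in `Models/Impoola.py`:

```python
import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
    """Two 3x3 convs with ReLU activations and a skip connection."""
    def __init__(self, channels: int):
        super().__init__()
        self.conv0 = nn.Conv2d(channels, channels, 3, padding=1)
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)

    def forward(self, x):
        out = self.conv0(torch.relu(x))
        out = self.conv1(torch.relu(out))
        return x + out                      # residual connection

class ConvSequence(nn.Module):
    """Conv -> max-pool -> two residual blocks, as in the IMPALA CNN."""
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.pool = nn.MaxPool2d(3, stride=2, padding=1)
        self.res0 = ResidualBlock(out_channels)
        self.res1 = ResidualBlock(out_channels)

    def forward(self, x):
        x = self.pool(self.conv(x))
        return self.res1(self.res0(x))

class ImpalaBackbone(nn.Module):
    """Stack of ConvSequences followed by global average pooling,
    yielding a fixed-size vector regardless of input resolution."""
    def __init__(self, in_channels=3, filters=(16, 32, 32)):
        super().__init__()
        seqs = []
        for f in filters:
            seqs.append(ConvSequence(in_channels, f))
            in_channels = f
        self.seqs = nn.Sequential(*seqs)
        self.pool = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        x = torch.relu(self.seqs(x))
        return self.pool(x).flatten(1)      # (batch, filters[-1])
```

The `filters` tuple corresponds to the "CNN Filters" column in the model-size table above; the pooled output is what gets concatenated with the upscaled internal state before the MLP actor head.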
├── Algorithm/ # Core training algorithms
│   ├── PPD.py                         # Proximal Policy Distillation
│ ├── PPO.py # Proximal Policy Optimization (base)
│ ├── StudentTeacher_distillation.py # Teacher/Student online distillation
│ ├── BehaviouralCloning.py # BC + DAgger phases
│ ├── CollectDataset.py # Teacher dataset collection
│ ├── Test_model.py # Model evaluation utilities
│ ├── Utils.py # DataLoader, Prioritized Buffer, helpers
│ └── seeds.py # Centralized seed management
├── Models/ # Neural network architectures
│ ├── Base_Model.py # Base model class (actor-critic)
│ ├── Impoola.py # IMPALA backbone + actor/critic builders
│ ├── Model_Pusher.py # Pusher-specific model configurations
│ └── Utils.py # Visual input manager, screen processing
├── Enviroment/ # Environment wrappers
│ ├── My_wrapper.py # Custom Gymnasium wrappers
│ └── Utils.py # Environment creation utilities
├── Distillation/ # Experiment entry points
│ ├── Pusher.py # Main script for Pusher distillation experiments
│ └── Utils.py # Experiment configuration helpers
├── Train_Teacher_Pusher/ # Teacher training (SAC via Stable-Baselines3)
│ ├── Mujoco_main.py # SAC training script
│ ├── callbacks.py # Custom evaluation callbacks
│ ├── SaveWeights.py # Weight extraction from SB3 to custom model
│ └── Test_Model.py # Teacher model testing
├── Results/ # Saved models and datasets
├── requirements.txt
└── README.md
```
pip install -r requirements.txt
```

Note: MuJoCo environments require a working MuJoCo installation. The `gymnasium[mujoco]` package handles this automatically.
Train the teacher:

```
cd Train_Teacher_Pusher
python Mujoco_main.py --device cuda:0 --total-timesteps 5000000
```

Collect a teacher dataset:

```
python Distillation/Pusher.py --C_data --num_data 100000 --device cuda:0
```

PPD:

```
python Distillation/Pusher.py --PPD --mode ImpaalaMid --PPD_parameter 5 --device cuda:0
```

Teacher Distillation:

```
python Distillation/Pusher.py --Tdistillation --mode ImpaalaMid --device cuda:0
```

Student Distillation:

```
python Distillation/Pusher.py --Sdistillation --mode ImpaalaMid --device cuda:0
```

Behavioural Cloning (Phase 1):

```
python Distillation/Pusher.py --BC_phase --mode ImpaalaMid --dataset 100k --alpha 0.0 --device cuda:0
```

DAgger (Phase 2):

```
python Distillation/Pusher.py --Dagger_phase --folder_Dagger ./Results/Pusher/OurAlgorithm1_NLL/Phase_One/<run_folder> --device cuda:0
```

| Argument | Description | Default |
|---|---|---|
| `--mode` | Student architecture: `State`, `ImpaalaSmall`, `ImpaalaMid`, `ImpaalaBig` | `State` |
| `--dataset` | Dataset size: `5k`, `10k`, `50k`, `100k` | `100k` |
| `--alpha` | PER priority exponent (0.0 = uniform sampling) | 0.0 |
| `--loss_type` | BC loss: `KL` (KL divergence) or `NLL` (negative log-likelihood) | `NLL` |
| `--PPD_parameter` | PPD distillation coefficient λ | 5 |
| `--run_index` | Run index for seed selection (0–5) | 0 |
All experiments are logged to Weights & Biases. Metrics include:
- Episode rewards (student and teacher)
- Action accuracy (student vs teacher)
- Policy entropy
- Training loss
- Buffer statistics (for prioritized replay)
- Evaluation videos
[1] Spigler, G. (2025). Proximal Policy Distillation. Transactions on Machine Learning Research. https://openreview.net/forum?id=WfVXe88oMh