This repository contains the official implementation of Learning with Preserving (LwP), a novel framework for continual learning that addresses catastrophic forgetting through dynamic weighted distance preservation.
LwP preserves the geometric structure of learned representations rather than focusing on task-specific outputs. Its key component is the Dynamic Weighted Distance Preservation (DWDP) loss, which maintains pairwise distances between representations while dynamically weighting each pair according to label similarity.
- No Replay Buffer Required: Privacy-preserving approach suitable for sensitive data
- Dynamic Weighting: Prevents conflicts between classification and preservation objectives
- Multiple Distance Metrics: Support for L2, RBF, Cosine, and RKD distance preservation
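- Comprehensive Evaluation: Benchmarks on four datasets (image and time-series) against continual learning and multi-task learning baselines
- Distribution Shift Robustness: Superior performance under domain shifts such as weather and time-of-day changes (BDD100K)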
- Modular Architecture: Easy integration with various backbone networks
| Dataset | Type | Tasks | Description |
|---|---|---|---|
| CelebA | Image | 10 | Face attribute classification (200K+ images) |
| FairFace | Image | 7 | Demographic attribute classification |
| BDD100K | Image | 4 | Autonomous driving scene understanding |
| PhysiQ | Time-series | 3 | Physiotherapy exercise recognition |
- LwP (proposed): Learning with Preserving
- ER: Experience Replay
- DER: Dark Experience Replay
- DERPP: Dark Experience Replay++
- LwF: Learning without Forgetting
- EWC: Elastic Weight Consolidation
- SI: Synaptic Intelligence
- GSS: Gradient-based Sample Selection
- FDR: Function Distance Regularization
- DVC: Dual View Consistency
- OBC: Online Bias Correction
- MTL: Standard Multi-Task Learning
- PCGrad: Projecting Conflicting Gradients
- IMTL: Impartial Multi-Task Learning
- NashMTL: Nash Multi-Task Learning
- Python 3.8+
- PyTorch 1.12+
- CUDA 11.0+ (for GPU acceleration)
```bash
# Clone the repository
git clone <repository-url>
cd lwp-framework

# Install dependencies
pip install torch torchvision pandas numpy matplotlib seaborn scipy scikit-learn datasets wandb tqdm

# Optional: Install additional dependencies for specific features
pip install timm  # For Vision Transformer support
```

```bash
# Continual learning with LwP on CelebA dataset
python main.py --job cl --model lwp --dataset celeba --num_seed 5

# Multi-task learning on PhysiQ dataset
python main.py --job mtl --model mtl --dataset physiq --num_seed 3

# Hybrid MTL-to-CL training
python main.py --job mtltocl --model lwp --dataset bdd100k --num_seed 5
```

```bash
# Custom hyperparameters for LwP
python main.py --job cl --model lwp --dataset celeba \
  --lam_dwdp 0.05 --lam_old 1.0 --dist_method rbf \
  --architecture resnet50 --input_size 224 \
  --epochs 30 --batch_size 128 --lr 0.001
```
```bash
# Distribution shift experiments on BDD100K
python main.py --job cl --model lwp --dataset bdd100k \
  --continual_learning_mode weather_shift \
  --filter_weather clear --filter_timeofday daytime
```

CelebA and FairFace are downloaded automatically on first use:

```bash
# CelebA and FairFace download automatically
python main.py --job cl --model lwp --dataset celeba

# BDD100K and PhysiQ require manual setup (see below)
```

To set up BDD100K:
- Visit Berkeley DeepDrive
- Create an account and request access
- Download bdd100k_images_100k.zip (images) and bdd100k_labels_release.zip (labels)
```bash
# Set up the directory structure
mkdir -p data/BDD100k/bdd100k

# Extract files
cd data/BDD100k/bdd100k
unzip /path/to/bdd100k_images_100k.zip
unzip /path/to/bdd100k_labels_release.zip
```

```bash
# Create directory (run from the repository root)
mkdir -p data/PHYSIQ

# Place CSV files in format: S{subject}_E{exercise}_R_{angle}_0.csv
# Example: S4_E1_R_30_0.csv, S4_E1_R_60_0.csv, etc.
```

```bash
# View all results for a dataset
python eval_v2.py --dataset celeba --job cl

# Filter by specific model
python eval_v2.py --dataset physiq --job cl --model lwp
```

```bash
# Compare different hyperparameter values
python eval_v2.py --dataset bdd100k --job cl --model lwp \
  --compare_param lam_dwdp --save_plots

# Filter by other hyperparameters
python eval_v2.py --dataset bdd100k --job cl --model lwp \
  --compare_param dist_method \
  --filter_hparams lam_dwdp=0.05 lam_old=1.0 \
  --save_plots
```

```bash
# Weather shift experiments
python eval_v2.py --dataset bdd100k --job cl --model lwp \
  --continual_learning_mode weather_shift

# Time-of-day shift experiments
python eval_v2.py --dataset bdd100k --job cl --model lwp \
  --continual_learning_mode time_shift
```

The LwP model consists of three main components:
- Encoder: Feature extraction backbone (ResNet, ViT, etc.)
- Task Predictors: Task-specific classification heads
- Preservation Mechanism: Frozen copies for DWDP loss computation
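The three components can be sketched in PyTorch as below. This is a minimal illustration, not the repository's model class: `LwPSketch`, `add_task`, and `old_representation` are hypothetical names, and we assume the frozen copy is a snapshot of the encoder taken just before each new task.

```python
import copy

import torch
import torch.nn as nn


class LwPSketch(nn.Module):
    """Shared encoder, one head per task, and a frozen encoder snapshot (hypothetical)."""

    def __init__(self, encoder: nn.Module, z_dim: int):
        super().__init__()
        self.encoder = encoder        # feature extraction backbone (ResNet, ViT, ...)
        self.z_dim = z_dim
        self.heads = nn.ModuleList()  # task-specific classification heads
        self.old_encoder = None       # frozen copy used to compute past representations

    def add_task(self, num_classes: int) -> None:
        # Assumption: snapshot the encoder as it is before training on the new task.
        self.old_encoder = copy.deepcopy(self.encoder).eval()
        for p in self.old_encoder.parameters():
            p.requires_grad_(False)
        # Add a fresh classification head for the new task.
        self.heads.append(nn.Linear(self.z_dim, num_classes))

    def train(self, mode: bool = True):
        super().train(mode)
        if self.old_encoder is not None:
            self.old_encoder.eval()  # keep the frozen snapshot (e.g. BatchNorm stats) fixed
        return self

    def forward(self, x: torch.Tensor):
        z = self.encoder(x)
        return z, [head(z) for head in self.heads]

    @torch.no_grad()
    def old_representation(self, x: torch.Tensor) -> torch.Tensor:
        if self.old_encoder is None:
            raise RuntimeError("call add_task() before requesting past representations")
        return self.old_encoder(x)
```

Because `add_task` appends a new head, an optimizer created before that call will not update the new head's parameters; rebuilding the optimizer at the start of each task is one simple way to handle this.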
The core innovation of LwP is the DWDP loss:
Basic distance preservation:

$$\mathcal{L}_{\text{dist}} = \lVert D(z) - D(z') \rVert^2$$

Dynamic weighting based on label similarity:

$$\mathcal{L}_{\text{DWDP}} = \lVert (D(z) - D(z')) \odot M(y) \rVert^2$$

where:
- $z, z'$: past and current representations
- $D(\cdot)$: pairwise distance matrix computation
- $M(y)$: dynamic mask based on label similarity
- $\odot$: element-wise multiplication
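A minimal PyTorch sketch of both losses, assuming L2 distances for D(·) and a binary mask where M(y)_ij = 1 if samples i and j share the same current-task label and 0 otherwise. The repository's actual mask may differ (for example, soft weights); `pairwise_l2`, `label_mask`, and `dwdp_loss` are hypothetical names, and `dynamic=False` gives the basic unweighted loss, analogous to --disable_dynamic.

```python
import torch


def pairwise_l2(z: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    # (N, N) Euclidean distance matrix; eps keeps sqrt differentiable on the diagonal.
    sq = (z.unsqueeze(1) - z.unsqueeze(0)).pow(2).sum(dim=-1)
    return torch.sqrt(sq + eps)


def label_mask(y: torch.Tensor) -> torch.Tensor:
    # Assumed M(y): 1.0 where samples i and j share a label, 0.0 otherwise.
    return (y.unsqueeze(0) == y.unsqueeze(1)).float()


def dwdp_loss(z_past: torch.Tensor, z_current: torch.Tensor, y: torch.Tensor,
              dynamic: bool = True) -> torch.Tensor:
    # z_past comes from the frozen copy, so no gradient flows into it.
    diff = pairwise_l2(z_past.detach()) - pairwise_l2(z_current)
    if dynamic:
        diff = diff * label_mask(y)  # element-wise product with M(y)
    return diff.pow(2).sum()         # squared norm over all pairs
```

Summing matches the squared norm in the formula; dividing by the number of pairs instead would make the loss scale independent of batch size.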
LwP supports multiple distance metrics:
- L2 (Euclidean): $\lVert z_i - z_j \rVert_2$
- RBF kernel: $\exp(-\lVert z_i - z_j \rVert_2^2)$
- Cosine similarity: $\langle z_i, z_j \rangle / (\lVert z_i \rVert_2 \lVert z_j \rVert_2)$
- RKD: Relational Knowledge Distillation
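The four metrics can be computed as (N, N) matrices over a batch of representations, as sketched below. The method strings "l2", "rbf", "cosine", and "rkd" are illustrative assumptions (the sample config in this README uses 'orig', so the accepted values of --dist_method may differ), `gamma` is a hypothetical RBF bandwidth, and for RKD we assume the distance-wise variant that normalizes distances by their mean over distinct pairs.

```python
import torch
import torch.nn.functional as F


def _sq_dists(z: torch.Tensor) -> torch.Tensor:
    # (N, N) squared Euclidean distances between rows of z.
    return (z.unsqueeze(1) - z.unsqueeze(0)).pow(2).sum(dim=-1)


def distance_matrix(z: torch.Tensor, method: str = "l2", gamma: float = 1.0) -> torch.Tensor:
    if method == "l2":
        return torch.sqrt(_sq_dists(z) + 1e-12)
    if method == "rbf":
        return torch.exp(-gamma * _sq_dists(z))
    if method == "cosine":
        zn = F.normalize(z, dim=1)  # unit-length rows, so dot product = cosine
        return zn @ zn.T
    if method == "rkd":
        d = torch.sqrt(_sq_dists(z) + 1e-12)
        off_diag = ~torch.eye(z.shape[0], dtype=torch.bool, device=z.device)
        return d / d[off_diag].mean()  # scale-invariant relative distances
    raise ValueError(f"unknown method: {method}")
```

RKD implementations usually compare the normalized matrices with a Huber loss (`F.smooth_l1_loss`) rather than a squared error.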
- --job: Training paradigm (cl, mtl, mtltocl)
- --model: Method to train (e.g., lwp, mtl)
- --dataset: Dataset name
- --epochs: Training epochs per task
- --batch_size: Batch size
- --lr: Learning rate
- --num_seed: Number of random seeds
- --architecture: Backbone architecture
- --z_dim: Representation dimension
- --pretrain: Use ImageNet-pretrained weights
- --input_size: Input image size
- --lam_dwdp: DWDP loss weight
- --lam_old: Old-task loss weight
- --dist_method: Distance metric
- --disable_dynamic: Disable dynamic weighting
- --continual_learning_mode: Distribution shift type (e.g., weather_shift, time_shift)
- --filter_weather: Weather condition filter
- --filter_timeofday: Time-of-day filter
- --filter_scene: Scene type filter
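To illustrate how the two loss weights might enter training, here is a minimal sketch assuming the objective is a weighted sum of the current-task loss, the old-task loss, and the DWDP loss. The parser is a hypothetical mirror of three documented flags (not main.py's actual parser), its defaults are taken from the sample config, and `lwp_objective` is a hypothetical name.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical mirror of the LwP-specific flags; defaults follow the sample config.
    p = argparse.ArgumentParser()
    p.add_argument("--lam_dwdp", type=float, default=0.01)
    p.add_argument("--lam_old", type=float, default=1.0)
    p.add_argument("--disable_dynamic", action="store_true")
    return p


def lwp_objective(loss_new: float, loss_old: float, loss_dwdp: float,
                  args: argparse.Namespace) -> float:
    # Assumed form: current-task loss + weighted old-task loss + weighted DWDP loss.
    return loss_new + args.lam_old * loss_old + args.lam_dwdp * loss_dwdp
```

Under this reading, raising --lam_dwdp trades plasticity on the new task for stronger preservation of the representation geometry.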
Create custom configuration files for reproducible experiments:
```python
# configs/celeba_lwp.py
config = {
    'job': 'cl',
    'model': 'lwp',
    'dataset': 'celeba',
    'architecture': 'resnet18',
    'input_size': 64,
    'epochs': 20,
    'batch_size': 256,
    'lr': 0.0001,
    'lam_dwdp': 0.01,
    'lam_old': 1.0,
    'dist_method': 'orig',
    'num_seed': 5,
}
```

- Accuracy: Task-specific classification accuracy
- Backward Transfer: Change in performance on previous tasks after learning new ones
- Forward Transfer: Effect of learning earlier tasks on performance on later tasks
- F1 Score: Balanced precision and recall
- Calibration Error: Prediction confidence calibration
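Backward and forward transfer are commonly computed from an accuracy matrix R, where R[i][j] is the accuracy on task j after training on task i, following the definitions of Lopez-Paz and Ranzato (2017). The sketch below uses those standard definitions; the repository's evaluation may compute them differently, and the function names are hypothetical.

```python
from typing import Sequence


def backward_transfer(R: Sequence[Sequence[float]]) -> float:
    # R[i][j]: accuracy on task j after training on task i (T x T, 0-indexed).
    T = len(R)
    # Average change on each earlier task between when it was learned and the end.
    return sum(R[T - 1][j] - R[j][j] for j in range(T - 1)) / (T - 1)


def forward_transfer(R: Sequence[Sequence[float]], b: Sequence[float]) -> float:
    # b[j]: accuracy on task j of a model not trained on any task.
    T = len(R)
    # Accuracy on task j just before training on it, relative to the baseline b[j].
    return sum(R[j - 1][j] - b[j] for j in range(1, T)) / (T - 1)
```

Negative backward transfer indicates forgetting; positive forward transfer indicates that earlier tasks help later ones before they are trained on.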
The evaluation framework provides comprehensive visualization:
```bash
# Generate confusion matrices
python eval_v2.py --dataset celeba --job cl --model lwp --save_plots

# Hyperparameter comparison plots
python eval_v2.py --dataset bdd100k --job cl --model lwp \
  --compare_param lam_dwdp --save_plots
```

- WandB Integration: Real-time experiment tracking
- Local Logging: Comprehensive result storage
- Metrics Export: CSV/JSON format for analysis
If you use this code in your research, please cite our paper:
```bibtex
@inproceedings{wang2026lwp,
  title={Learning with Preserving for Continual Multitask Learning},
  author={Wang, Hanchen David and Bae, Siwoo and Chen, Zirong and Ma, Meiyi},
  booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
  year={2026},
  organization={AAAI}
}
```

We welcome contributions to improve the LwP framework:
- Fork the repository
- Create a feature branch
- Make your changes
- Add tests and documentation
- Submit a pull request
This project is licensed under the MIT License - see the LICENSE file for details.
- Dataset providers for making their data publicly available
- PyTorch team for the excellent deep learning framework
- The continual learning research community for inspiring this work
For questions, issues, or collaborations, please:
- Open an issue on GitHub
- Contact the corresponding author: Hanchen David Wang (hanchen.wang.1@vanderbilt.edu)
- Hanchen David Wang* (Corresponding Author) - Vanderbilt University
- Siwoo Bae* - Vanderbilt University
- Zirong Chen - Vanderbilt University
- Meiyi Ma - Vanderbilt University
*Equal contribution
Note: This implementation is designed for research purposes. For production use, additional testing and optimization may be required.