Learning with Preserving (LwP) - Continual Multitask Learning (CMTL) Framework

This repository contains the official implementation of Learning with Preserving (LwP), a novel framework for continual learning that addresses catastrophic forgetting through dynamic weighted distance preservation.

Overview

LwP introduces a fundamentally different approach to continual learning: it preserves the geometric structure of learned representations rather than focusing on task-specific outputs. The key innovation is the Dynamic Weighted Distance Preservation (DWDP) loss, which maintains pairwise distances between representations while dynamically weighting each pair by label similarity, so that preservation does not conflict with the classification objective.

Key Features

  • No Replay Buffer Required: Privacy-preserving approach suitable for sensitive data
  • Dynamic Weighting: Prevents conflicts between classification and preservation objectives
  • Multiple Distance Metrics: Support for L2, RBF, Cosine, and RKD distance preservation
  • Comprehensive Evaluation: Extensive benchmarking across multiple datasets and scenarios
  • Distribution Shift Robustness: Superior performance under domain shifts
  • Modular Architecture: Easy integration with various backbone networks

Supported Datasets

Dataset    Type         Tasks   Description
CelebA     Image        10      Face attribute classification (200K+ images)
FairFace   Image        7       Demographic attribute classification
BDD100K    Image        4       Autonomous driving scene understanding
PhysiQ     Time-series  3       Physiotherapy exercise recognition

Supported Models

Continual Learning Methods

  • LwP (proposed): Learning with Preserving
  • ER: Experience Replay
  • DER: Dark Experience Replay
  • DERPP: Dark Experience Replay++
  • LwF: Learning without Forgetting
  • EWC: Elastic Weight Consolidation
  • SI: Synaptic Intelligence
  • GSS: Gradient-based Sample Selection
  • FDR: Function Distance Regularization
  • DVC: Dual View Consistency
  • OBC: Online Balanced Continual Learning

Multi-Task Learning Methods

  • MTL: Standard Multi-Task Learning
  • PCGrad: Projecting Conflicting Gradients
  • IMTL: Impartial Multi-Task Learning
  • NashMTL: Nash Multi-Task Learning

Installation

Prerequisites

  • Python 3.8+
  • PyTorch 1.12+
  • CUDA 11.0+ (for GPU acceleration)

Setup

# Clone the repository
git clone <repository-url>
cd lwp-framework

# Install dependencies
pip install torch torchvision pandas numpy matplotlib seaborn scipy scikit-learn datasets wandb tqdm

# Optional: Install additional dependencies for specific features
pip install timm  # For Vision Transformer support

Quick Start

Basic Training

# Continual learning with LwP on CelebA dataset
python main.py --job cl --model lwp --dataset celeba --num_seed 5

# Multi-task learning on PhysiQ dataset
python main.py --job mtl --model mtl --dataset physiq --num_seed 3

# Hybrid MTL-to-CL training
python main.py --job mtltocl --model lwp --dataset bdd100k --num_seed 5

Advanced Configuration

# Custom hyperparameters for LwP
python main.py --job cl --model lwp --dataset celeba \
    --lam_dwdp 0.05 --lam_old 1.0 --dist_method rbf \
    --architecture resnet50 --input_size 224 \
    --epochs 30 --batch_size 128 --lr 0.001

# Distribution shift experiments on BDD100K
python main.py --job cl --model lwp --dataset bdd100k \
    --continual_learning_mode weather_shift \
    --filter_weather clear --filter_timeofday daytime

Dataset Setup

Automatic Download (Recommended)

Most datasets are automatically downloaded when first used:

# CelebA and FairFace download automatically
python main.py --job cl --model lwp --dataset celeba

# PhysiQ requires manual setup (see below)

Manual Setup

BDD100K Dataset

  1. Visit the Berkeley DeepDrive website
  2. Create account and request access
  3. Download:
    • bdd100k_images_100k.zip (images)
    • bdd100k_labels_release.zip (labels)

# Setup directory structure
mkdir -p data/BDD100k/bdd100k

# Extract files
cd data/BDD100k/bdd100k
unzip /path/to/bdd100k_images_100k.zip
unzip /path/to/bdd100k_labels_release.zip

PhysiQ Dataset

# Create directory
mkdir -p data/PHYSIQ

# Place CSV files in format: S{subject}_E{exercise}_R_{angle}_0.csv
# Example: S4_E1_R_30_0.csv, S4_E1_R_60_0.csv, etc.
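
As a sanity check, the expected naming scheme can be validated with a short regex (illustrative only; the repo's loader may parse filenames differently):

import re

# Hypothetical validator for the S{subject}_E{exercise}_R_{angle}_0.csv pattern.
PHYSIQ_NAME = re.compile(r"S(?P<subject>\d+)_E(?P<exercise>\d+)_R_(?P<angle>\d+)_0\.csv")

match = PHYSIQ_NAME.fullmatch("S4_E1_R_30_0.csv")
assert match is not None
print(match.group("subject"), match.group("exercise"), match.group("angle"))  # 4 1 30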

Evaluation and Analysis

Basic Evaluation

# View all results for a dataset
python eval_v2.py --dataset celeba --job cl

# Filter by specific model
python eval_v2.py --dataset physiq --job cl --model lwp

Hyperparameter Analysis

# Compare different hyperparameter values
python eval_v2.py --dataset bdd100k --job cl --model lwp \
    --compare_param lam_dwdp --save_plots

# Filter by other hyperparameters
python eval_v2.py --dataset bdd100k --job cl --model lwp \
    --compare_param dist_method \
    --filter_hparams lam_dwdp=0.05 lam_old=1.0 \
    --save_plots

Distribution Shift Analysis

# Weather shift experiments
python eval_v2.py --dataset bdd100k --job cl --model lwp \
    --continual_learning_mode weather_shift

# Time-of-day shift experiments
python eval_v2.py --dataset bdd100k --job cl --model lwp \
    --continual_learning_mode time_shift

Model Architecture

LwP Framework

The LwP model consists of three main components; a structural sketch follows the list:

  1. Encoder: Feature extraction backbone (ResNet, ViT, etc.)
  2. Task Predictors: Task-specific classification heads
  3. Preservation Mechanism: Frozen copies for DWDP loss computation
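
A minimal sketch of how these components fit together (illustrative only; class and method names here are ours, not the repo's):

import copy
import torch.nn as nn

class LwPSketch(nn.Module):
    """Illustrative skeleton of the three LwP components."""

    def __init__(self, encoder, z_dim, task_classes):
        super().__init__()
        self.encoder = encoder                              # 1. feature backbone
        self.heads = nn.ModuleList(                         # 2. task-specific heads
            [nn.Linear(z_dim, c) for c in task_classes]
        )
        self.frozen_encoder = None                          # 3. preservation copy

    def snapshot(self):
        """At a task boundary, freeze a copy of the encoder for DWDP."""
        self.frozen_encoder = copy.deepcopy(self.encoder).eval()
        for p in self.frozen_encoder.parameters():
            p.requires_grad_(False)

    def forward(self, x):
        z = self.encoder(x)
        return z, [head(z) for head in self.heads]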

Dynamic Weighted Distance Preservation (DWDP)

The core innovation of LwP is the DWDP loss:

# Basic distance preservation
distance_loss = ||D(z) - D(z')||²

# Dynamic weighting based on label similarity
weighted_loss = ||(D(z) - D(z')) ⊙ M(y)||²

Where:

  • z, z': Past and current representations
  • D(·): Distance matrix computation
  • M(y): Dynamic mask based on label similarity
  • ⊙: Element-wise multiplication
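
A minimal PyTorch sketch of this loss, assuming M(y) is built from per-pair label agreement (the repo's exact weighting scheme may differ):

import torch

def dwdp_loss(z, z_old, y, dist_fn):
    """Sketch of the DWDP loss.

    z       : (B, d) representations from the current encoder
    z_old   : (B, d) representations from the frozen encoder
    y       : (B, T) multi-task labels used to build the mask M(y)
    dist_fn : maps (B, d) -> (B, B) pairwise distance matrix D(.)
    """
    D_new, D_old = dist_fn(z), dist_fn(z_old)
    # Assumed mask: fraction of tasks on which a pair of samples agrees,
    # so pairs with conflicting labels contribute less to preservation.
    mask = (y.unsqueeze(0) == y.unsqueeze(1)).float().mean(dim=-1)  # (B, B)
    return (((D_new - D_old) * mask) ** 2).mean()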

Distance Metrics

LwP supports multiple distance metrics, sketched in code after this list:

  • L2 (Euclidean): ||z_i - z_j||₂
  • RBF Kernel: exp(-||z_i - z_j||₂²)
  • Cosine Similarity: ⟨z_i, z_j⟩ / (||z_i||₂ ||z_j||₂)
  • RKD: Relational Knowledge Distillation
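
The first three map directly onto pairwise matrices (a sketch; RKD's distance- and angle-wise terms follow Park et al. and are omitted for brevity):

import torch
import torch.nn.functional as F

def pairwise_l2(z):                  # ||z_i - z_j||_2
    return torch.cdist(z, z, p=2)

def pairwise_rbf(z):                 # exp(-||z_i - z_j||_2^2)
    return torch.exp(-torch.cdist(z, z, p=2) ** 2)

def pairwise_cosine(z):              # <z_i, z_j> / (||z_i||_2 ||z_j||_2)
    zn = F.normalize(z, dim=-1)
    return zn @ zn.T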

Configuration

Command Line Arguments

Core Training

  • --job: Training paradigm (cl, mtl, mtltocl)
  • --model: Model architecture
  • --dataset: Dataset name
  • --epochs: Training epochs per task
  • --batch_size: Batch size
  • --lr: Learning rate
  • --num_seed: Number of random seeds

Model Configuration

  • --architecture: Backbone architecture
  • --z_dim: Representation dimension
  • --pretrain: Use ImageNet pretrained weights
  • --input_size: Input image size

LwP-Specific

  • --lam_dwdp: DWDP loss weight
  • --lam_old: Old task loss weight
  • --dist_method: Distance metric
  • --disable_dynamic: Disable dynamic weighting

Distribution Shifts (BDD100K)

  • --continual_learning_mode: Shift type
  • --filter_weather: Weather condition filter
  • --filter_timeofday: Time-of-day filter
  • --filter_scene: Scene type filter

Configuration Files

Create custom configuration files for reproducible experiments:

# configs/celeba_lwp.py
config = {
    'job': 'cl',
    'model': 'lwp',
    'dataset': 'celeba',
    'architecture': 'resnet18',
    'input_size': 64,
    'epochs': 20,
    'batch_size': 256,
    'lr': 0.0001,
    'lam_dwdp': 0.01,
    'lam_old': 1.0,
    'dist_method': 'orig',
    'num_seed': 5
}
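
The repo does not document how these files are consumed, but a config dict like the one above can be expanded into an equivalent main.py invocation with a hypothetical helper:

# Hypothetical: expand the config dict into main.py flags.
from configs.celeba_lwp import config

flags = " ".join(f"--{key} {value}" for key, value in config.items())
print(f"python main.py {flags}")
# -> python main.py --job cl --model lwp --dataset celeba ...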

Results and Analysis

Performance Metrics

  • Accuracy: Task-specific classification accuracy
  • Backward Transfer: Performance change on previous tasks
  • Forward Transfer: Performance on future tasks
  • F1 Score: Balanced precision and recall
  • Calibration Error: Prediction confidence calibration
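
For reference, backward transfer is commonly computed from the task-accuracy matrix as follows (the standard definition; eval_v2.py may differ in detail):

import numpy as np

def backward_transfer(acc):
    """acc[i, j] = accuracy on task j after training through task i."""
    T = acc.shape[0]
    # Average change on each earlier task after finishing the last task;
    # negative values indicate forgetting.
    return float(np.mean([acc[T - 1, j] - acc[j, j] for j in range(T - 1)]))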

Visualization

The evaluation framework provides comprehensive visualization:

# Generate confusion matrices
python eval_v2.py --dataset celeba --job cl --model lwp --save_plots

# Hyperparameter comparison plots
python eval_v2.py --dataset bdd100k --job cl --model lwp \
    --compare_param lam_dwdp --save_plots

Logging

  • WandB Integration: Real-time experiment tracking
  • Local Logging: Comprehensive result storage
  • Metrics Export: CSV/JSON format for analysis

Citation

If you use this code in your research, please cite our paper:

@inproceedings{wang2026lwp,
  title={Learning with Preserving for Continual Multitask Learning},
  author={Wang, Hanchen David and Bae, Siwoo and Chen, Zirong and Ma, Meiyi},
  booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
  year={2026},
  organization={AAAI}
}

Contributing

We welcome contributions to improve the LwP framework:

  1. Fork the repository
  2. Create a feature branch
  3. Make your changes
  4. Add tests and documentation
  5. Submit a pull request

License

This project is licensed under the MIT License - see the LICENSE file for details.

Acknowledgments

  • Dataset providers for making their data publicly available
  • PyTorch team for the excellent deep learning framework
  • The continual learning research community for inspiring this work

Contact

For questions, issues, or collaborations, please open an issue on this repository or contact the authors listed below.

Authors

  • Hanchen David Wang (Corresponding Author)* - Vanderbilt University
  • Siwoo Bae* - Vanderbilt University
  • Zirong Chen - Vanderbilt University
  • Meiyi Ma - Vanderbilt University

*Equal contribution


Note: This implementation is designed for research purposes. For production use, additional testing and optimization may be required.
