This repository is the official implementation of the AHPA (Adaptive Hierarchical Representation Alignment) method, based on SiT and REPA. It demonstrates how dynamic routing over intermediate representations can enhance representation alignment.
cd AHPA
# Create environment (Python 3.10+)
conda create -n ahpa python=3.10 -y && conda activate ahpa
# Install dependencies (same as iREPA)
pip install -r requirements.txt

You will need the pre-cached ImageNet latents and VAE/Encoder pretrained models.
Download the full pre-processed dataset and weights from the iREPA/REPA-E collections to the current directory:
hf download REPA-E/iREPA-collections --include "data/**" --local-dir "."
hf download REPA-E/iREPA-collections --include "pretrained_models/**" --local-dir "."
hf download REPA-E/iREPA-collections --include "VIRTUAL_imagenet256_labeled.npz" --local-dir "."Ensure your workspace looks like this:
AHPA/
├── data/
│ ├── imagenet-latents-images/
│ └── imagenet-latents-sdvae-ft-mse-f8d4/
├── pretrained_models/
│ ├── sdvae-ft-mse-f8d4.pt
│ └── sdvae-ft-mse-f8d4-latents-stats.pt
├── VIRTUAL_imagenet256_labeled.npz
├── ldm/
└── guided-diffusion/
    └── evaluations/evaluator_batch.py
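To sanity-check that everything landed where expected, a quick shell loop over the paths from the tree above will flag anything missing (the path list simply mirrors the tree entries; adjust it if your layout differs):

for p in data/imagenet-latents-images \
         data/imagenet-latents-sdvae-ft-mse-f8d4 \
         pretrained_models/sdvae-ft-mse-f8d4.pt \
         pretrained_models/sdvae-ft-mse-f8d4-latents-stats.pt \
         VIRTUAL_imagenet256_labeled.npz \
         guided-diffusion/evaluations/evaluator_batch.py; do
  # Report each required file/directory as present or missing
  [ -e "$p" ] && echo "OK       $p" || echo "MISSING  $p"
done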
To significantly speed up training, we recommend pre-extracting the necessary VAE features. We provide a script to do this in a distributed fashion.
cd ldm
torchrun --nproc_per_node=8 extract_vae_features.py \
--data-dir ../data \
--output-dir ../data/vae_layer_features \
--vae-ckpt ../pretrained_models/sdvae-ft-mse-f8d4.pt \
--batch-size 64

Once extracted, you can append --cached-vae-feature-dir ../data/vae_layer_features to your training commands to skip online extraction.
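For example, the AHPA training command from the next section would then become (all other flags unchanged):

cd ldm
accelerate launch train.py --config configs/ahpa.yaml \
--model "SiT-XL/2" \
--encoder-depth 7 \
--data-dir ../data \
--exp-name "ahpa-sit-xl" \
--max-train-steps 400000 \
--learning-rate 2e-4 \
--max-grad-norm 2.0 \
--batch-size 256 \
--cached-vae-feature-dir ../data/vae_layer_features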
This repository provides instructions for the Latent Diffusion (SiT) experiments. All code lives in the ldm directory.
We provide a convenience bash script (run_ahpa.sh) that runs the entire pipeline end to end: training -> generation -> evaluation.
cd ldm
bash run_ahpa.sh

You can configure the desired experiments and models inside run_ahpa.sh.
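For reference, the stages the script chains together look roughly like this (an illustrative sketch assembled from the commands documented below, not the script's actual contents):

#!/bin/bash
# Sketch of the run_ahpa.sh pipeline; EXP_NAME is illustrative.
EXP_NAME="ahpa-sit-xl"

# 1. Train
accelerate launch train.py --config configs/ahpa.yaml \
--model "SiT-XL/2" --encoder-depth 7 --data-dir ../data \
--exp-name "$EXP_NAME" --max-train-steps 400000 \
--learning-rate 2e-4 --max-grad-norm 2.0 --batch-size 256

# 2. Generate 50k samples from the final checkpoint
python generate_all.py --exp-name "$EXP_NAME" --model "SiT-XL/2" \
--steps 0400000 --sample-dir samples --num-samples 50000 \
--nproc 8 --encoder-depth 7 --sample-list-out "${EXP_NAME}_samples.txt"

# 3. Evaluate with the ADM suite
python ../guided-diffusion/evaluations/evaluator_batch.py \
--ref_batch ../VIRTUAL_imagenet256_labeled.npz \
--sample_list "${EXP_NAME}_samples.txt" \
--log "eval_results_${EXP_NAME}.log"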
To train the models with different loss modes (AHPA, REPA, SRA2):
AHPA
cd ldm
accelerate launch train.py --config configs/ahpa.yaml \
--model "SiT-XL/2" \
--encoder-depth 7 \
--data-dir ../data \
--exp-name "ahpa-sit-xl" \
--max-train-steps 400000 \
--learning-rate 2e-4 \
--max-grad-norm 2.0 \
--batch-size 256

Baseline REPA
cd ldm
accelerate launch train.py --config configs/repa.yaml \
--model "SiT-XL/2" \
--encoder-depth 7 \
--data-dir ../data \
--exp-name "repa-sit-xl" \
--max-train-steps 400000 \
--learning-rate 2e-4 \
--max-grad-norm 2.0 \
--batch-size 256

SRA2
cd ldm
accelerate launch train.py --config configs/sra2.yaml \
--model "SiT-XL/2" \
--encoder-depth 7 \
--data-dir ../data \
--exp-name "sra2-sit-xl" \
--max-train-steps 400000 \
--learning-rate 2e-4 \
--max-grad-norm 2.0 \
--batch-size 256
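The three commands differ only in the config file and experiment name, so you can also launch them back to back with a small loop (a convenience sketch, not a script shipped with the repository):

cd ldm
for mode in ahpa repa sra2; do
  accelerate launch train.py --config configs/${mode}.yaml \
  --model "SiT-XL/2" \
  --encoder-depth 7 \
  --data-dir ../data \
  --exp-name "${mode}-sit-xl" \
  --max-train-steps 400000 \
  --learning-rate 2e-4 \
  --max-grad-norm 2.0 \
  --batch-size 256
done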
To generate samples using a trained checkpoint:

cd ldm
python generate_all.py \
--exp-name "ahpa-sit-xl" \
--model "SiT-XL/2" \
--steps 0400000 \
--sample-dir samples \
--num-samples 50000 \
--nproc 8 \
--encoder-depth 7 \
--sample-list-out "ahpa-sit-xl_samples.txt"

We use the ADM evaluation suite to compute ImageNet 256x256 metrics. Make sure VIRTUAL_imagenet256_labeled.npz is in the root directory.
python ../guided-diffusion/evaluations/evaluator_batch.py \
--ref_batch ../VIRTUAL_imagenet256_labeled.npz \
--sample_list ahpa-sit-xl_samples.txt \
--log eval_results_ahpa-sit-xl.log
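If you trained all three variants, the evaluations can be run in one loop, assuming each variant's samples were generated with --sample-list-out "<exp-name>_samples.txt" following the pattern above:

for exp in ahpa-sit-xl repa-sit-xl sra2-sit-xl; do
  python ../guided-diffusion/evaluations/evaluator_batch.py \
  --ref_batch ../VIRTUAL_imagenet256_labeled.npz \
  --sample_list "${exp}_samples.txt" \
  --log "eval_results_${exp}.log"
done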