If you find this code helpful, we'd appreciate a ⭐ — thank you!
This repository contains a PyTorch implementation of Drifting Models, a novel generative modeling paradigm that evolves the pushforward distribution during training time rather than inference time.
The implementation focuses on the MNIST dataset using a DiT (Diffusion Transformer) backbone, supporting accelerate for mixed-precision training.
We keep this repository clean with only 3 main files: dit.py, train.py, and train.sh.
Install dependencies:

```bash
pip install torch torchvision timm accelerate tqdm
```

Run training:

```bash
bash train.sh

# Or with custom arguments
bash train.sh --num-gpus 2 --batch-size 512 --epochs 200
```

Unlike Diffusion or Flow Matching models, which solve differential equations (ODEs/SDEs) iteratively at inference time to generate data, Drifting Models perform One-Step Generation (1-NFE).
In traditional Deep Learning, training is an iterative process of updating weights. Drifting Models leverage this iterative nature to evolve the distribution.
- The Generator ($f_\theta$): Maps noise $\epsilon$ directly to data $x$.
- The Pushforward: The generator creates a distribution $q = (f_\theta)_{\#}\, p_\epsilon$.
- Evolution: As training progresses (iteration $i \to i+1$), the generator updates to $f_{\theta_{i+1}}$, implicitly moving the samples $x$.
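A minimal sketch of this one-step mapping, assuming the generator is the DiT defined in dit.py (the class interface and MNIST-shaped input here are assumptions, not the repo's verified API):

```python
import torch

# Hypothetical usage: assumes the generator f_theta maps a noise tensor
# directly to an image tensor of the same shape (as dit.py is expected to do).
def sample_one_step(f_theta: torch.nn.Module, num_samples: int) -> torch.Tensor:
    eps = torch.randn(num_samples, 1, 28, 28)  # epsilon ~ p_epsilon (MNIST-shaped noise, assumed)
    with torch.no_grad():
        x = f_theta(eps)                       # x = f_theta(eps): one forward pass, 1-NFE
    return x                                   # samples from the pushforward q = (f_theta)_# p_epsilon
```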
To guide this evolution, the paper defines a vector field $V(x)$ over the generated samples.
The field is inspired by Mean-Shift clustering and consists of two forces:

$$V(x) = \underbrace{V_p^+(x)}_{\text{Attraction}} - \underbrace{V_q^-(x)}_{\text{Repulsion}}$$
- Attraction: Pulls generated samples toward real data points ($y^+$).
- Repulsion: Pushes generated samples away from other generated samples ($y^-$) to prevent mode collapse and spread the distribution.
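As a rough illustration of these two forces, a mean-shift-style drift could be assembled from precomputed neighbor weights as below. This is a sketch, not the repo's API: `drift_field`, `w_pos`, and `w_neg` are illustrative names, and the weight computation itself is sketched further down.

```python
import torch

def drift_field(x, y_pos, w_pos, w_neg):
    """Illustrative V(x) = V_p^+(x) - V_q^-(x) for a batch of generated samples.

    x:     (N, D) generated samples
    y_pos: (M, D) real data points (y^+)
    w_pos: (N, M) normalized kernel weights between generated and real samples
    w_neg: (N, N) normalized kernel weights among generated samples (y^-), self-pairs excluded
    """
    v_attract = w_pos @ y_pos - w_pos.sum(dim=1, keepdim=True) * x  # weighted pull toward real data
    v_repel   = w_neg @ x     - w_neg.sum(dim=1, keepdim=True) * x  # weighted pull toward other samples
    return v_attract - v_repel                                      # attraction minus repulsion
```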
The model is trained using a regression loss. For a generated point $x = f_\theta(\epsilon)$, the target is the drifted point $x + V(x)$, with gradients stopped through the target:

$$\mathcal{L}(\theta) = \big\| f_\theta(\epsilon) - \mathrm{sg}\big(x + V(x)\big) \big\|^2$$

This objective effectively minimizes the drift magnitude $\|V(x)\|$.
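Putting this together, a hedged sketch of the training objective might look like the following; the `compute_drift` callback stands in for the field above, and the real training loop lives in train.py and may differ.

```python
import torch

def drifting_loss(f_theta, eps, compute_drift):
    """Regression onto the drifted target x + V(x), with gradients stopped through the target."""
    x = f_theta(eps)                      # x = f_theta(eps)
    with torch.no_grad():                 # stop-gradient: the target is treated as a constant
        target = x + compute_drift(x)     # drifted target x + V(x)
    return ((x - target) ** 2).mean()     # value equals mean ||V(x)||^2, so minimizing it shrinks the drift
```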
- Kernel: The influence of neighbors is weighted by a kernel $k(x, y) = \exp\!\left(-\frac{\|x - y\|}{\tau}\right)$.
- Doubly Stochastic Normalization: To ensure stability, the kernel weights are normalized using a geometric mean of row-wise and column-wise softmax operations (Algorithm 2).
- Temperature Scaling: The temperature $\tau$ is scaled by $\sqrt{D}$ (where $D$ is the feature dimension) to make the kernel robust to high-dimensional spaces.
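A minimal sketch of these three ingredients (Euclidean kernel, doubly stochastic normalization via a geometric mean of row- and column-wise softmaxes, and $\sqrt{D}$ temperature scaling), written from the bullet points above rather than copied from the paper's Algorithm 2:

```python
import torch
import torch.nn.functional as F

def doubly_stochastic_weights(x, y, tau=1.0):
    """Kernel weights k(x, y) = exp(-||x - y|| / tau) with doubly stochastic normalization.

    x: (N, D) generated samples, y: (M, D) neighbor points (real or generated).
    """
    d = x.shape[-1]
    logits = -torch.cdist(x, y) / (tau * d ** 0.5)  # temperature scaled by sqrt(D)
    row = F.softmax(logits, dim=1)                  # normalize over neighbors for each x
    col = F.softmax(logits, dim=0)                  # normalize over x for each neighbor
    return (row * col).sqrt()                       # geometric mean of the two softmaxes
```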
This implementation is based on the paper:
Generative Modeling via Drifting
Mingyang Deng, He Li, Tianhong Li, Yilun Du, Kaiming He
