
# Generative Modeling via Drifting

arXiv Twitter

If you find this code helpful, we'd appreciate a ⭐ — thank you!

Drifting Model Banner

This repository contains a PyTorch implementation of Drifting Models, a novel generative modeling paradigm that evolves the pushforward distribution at training time rather than at inference time.

The implementation focuses on the MNIST dataset with a DiT (Diffusion Transformer) backbone, and supports accelerate for mixed-precision training.

## 🚀 Quick Start

We keep this repository clean with only 3 main files: `dit.py`, `train.py`, and `train.sh`.

Install dependencies:

```bash
pip install torch torchvision timm accelerate tqdm
```

Run training:

```bash
bash train.sh
# Or with custom arguments
bash train.sh --num-gpus 2 --batch-size 512 --epochs 200
```

## 🧠 Understanding Drifting Models

Unlike diffusion or flow matching models, which iteratively solve differential equations (ODEs/SDEs) at inference time to generate data, Drifting Models perform one-step generation: a single network function evaluation (1-NFE).

### 1. The Core Idea: Training-Time Evolution

In traditional Deep Learning, training is an iterative process of updating weights. Drifting Models leverage this iterative nature to evolve the distribution.

- **The Generator** ($f_\theta$): maps noise $\epsilon$ directly to data $x$.
- **The Pushforward**: the generator induces a distribution $q = (f_\theta)_\# \, p_\epsilon$.
- **Evolution**: as training progresses (iteration $i \to i+1$), the generator updates to $f_{\theta_{i+1}}$, implicitly moving the samples $x$ (see the sampling sketch below).
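Concretely, sampling from the pushforward is just one forward pass through the generator. Below is a minimal sketch of what that looks like; the function signature and shapes are illustrative, not the exact interface of `dit.py` or `train.py`:

```python
import torch

@torch.no_grad()
def sample(f_theta, num_samples, shape=(1, 28, 28), device="cpu"):
    # f_theta: any nn.Module that maps noise to images of the same shape.
    eps = torch.randn(num_samples, *shape, device=device)  # eps ~ p_epsilon
    x = f_theta(eps)                                        # x ~ q = (f_theta)_# p_epsilon
    return x
```

Each weight update changes $f_\theta$, so re-running this at iteration $i+1$ draws from a slightly different pushforward than at iteration $i$; that motion is what the drifting field steers.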

### 2. The Drifting Field ($V$)

To guide this evolution, the paper defines a vector field $V_{p,q}(x)$ that tells a generated sample where to move. The goal is to reach an Equilibrium where $V=0$, which implies the generated distribution $q$ matches the data distribution $p$.

The field is inspired by Mean-Shift clustering and consists of two forces:

$$V(x) = \underbrace{V_p^+(x)}_{\text{Attraction}} - \underbrace{V_q^-(x)}_{\text{Repulsion}}$$

- **Attraction**: pulls generated samples toward real data points ($y^+$).
- **Repulsion**: pushes generated samples away from other generated samples ($y^-$) to prevent mode collapse and spread the distribution (a toy computation of $V$ is sketched below).
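The sketch below computes a toy version of this field for a batch of flattened samples, using the exponential kernel described in the implementation details. It normalizes kernel weights with a plain row-wise softmax; the doubly stochastic normalization mentioned further down is a refinement of this. Names such as `drift_field` are illustrative, not the repository's actual functions:

```python
import torch

def kernel_logits(x, y, tau):
    # log k(x, y) = -||x - y|| / tau, computed pairwise.
    return -torch.cdist(x, y) / tau

def drift_field(x, y_real, tau):
    # x:      (N, D) flattened generated samples
    # y_real: (M, D) flattened real data samples
    w_pos = torch.softmax(kernel_logits(x, y_real, tau), dim=1)  # attraction weights
    w_neg = torch.softmax(kernel_logits(x, x, tau), dim=1)       # repulsion weights
    v_attract = w_pos @ y_real - x   # toward a kernel-weighted mean of real points
    v_repel   = w_neg @ x - x        # toward a kernel-weighted mean of generated points
    return v_attract - v_repel       # V(x) = V_p^+(x) - V_q^-(x)
```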

### 3. Training Objective

The model is trained using a regression loss. For a generated point $x$, we compute its drift vector $V(x)$, freeze the drifted target $x + V(x)$ with a stop-gradient, and regress the network towards it.

$$\mathcal{L} = \mathbb{E}_{\epsilon} \left[ || f_\theta(\epsilon) - \text{stopgrad}(f_\theta(\epsilon) + V(f_\theta(\epsilon))) ||^2 \right]$$

This objective effectively minimizes the drift magnitude $||V||^2$, pushing the system toward the equilibrium $V=0$.
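Here is a minimal sketch of that objective, reusing the `drift_field` helper from the previous snippet; it assumes a simple interface and is not the repository's exact training step:

```python
import torch

def drifting_loss(f_theta, eps, y_real, tau):
    x = f_theta(eps)                                    # generated samples
    x_flat, y_flat = x.flatten(1), y_real.flatten(1)
    v = drift_field(x_flat, y_flat, tau)                # drifting field V(x)
    target = (x_flat + v).detach()                      # stopgrad(x + V(x))
    return ((x_flat - target) ** 2).mean()              # regress x toward the frozen target
```

Because the target is detached, gradients flow only through $f_\theta(\epsilon)$, so each optimizer step nudges the generated samples along $V$.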

### 4. Implementation Details

- **Kernel**: the influence of neighbors is weighted by a kernel $k(x, y) = \exp\left(-\frac{||x-y||}{\tau}\right)$.
- **Doubly Stochastic Normalization**: to ensure stability, the kernel weights are normalized using a geometric mean of row-wise and column-wise softmax operations (Algorithm 2).
- **Temperature Scaling**: the temperature $\tau$ is scaled by $\sqrt{D}$ (where $D$ is the feature dimension) to make the kernel robust in high-dimensional spaces (both points are sketched below).
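As a rough illustration of the last two points, here is one way to combine the $\sqrt{D}$ temperature scaling with a geometric mean of row-wise and column-wise softmax. This follows the description above (Algorithm 2 in the paper) as read here; the repository's actual normalization may differ in details:

```python
import math
import torch

def doubly_stochastic_weights(x, y, tau_base):
    # x: (N, D) generated samples, y: (M, D) neighbor samples (real or generated).
    d = x.shape[1]
    tau = tau_base * math.sqrt(d)          # temperature scaled by sqrt(D)
    logits = -torch.cdist(x, y) / tau      # log k(x, y) = -||x - y|| / tau
    row = torch.softmax(logits, dim=1)     # row-wise softmax (each row sums to 1)
    col = torch.softmax(logits, dim=0)     # column-wise softmax (each column sums to 1)
    return (row * col).sqrt()              # elementwise geometric mean
```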

## 🙏 Acknowledgement

This implementation is based on the paper:

> **Generative Modeling via Drifting**
> Mingyang Deng, He Li, Tianhong Li, Yilun Du, Kaiming He
