Training framework for the PRX text-to-image diffusion models by Photoroom.
Read the full story on the Hugging Face blog.
PRX is a transformer-based latent diffusion model trained with flow matching. This repository contains everything needed to train and evaluate PRX models, including:
- A patchified transformer denoiser
- Support for multiple text encoders (T5, T5-Gemma2B, Qwen3) and VAEs (AutoencoderKL, DC-AE)
- Distributed training via MosaicML Composer with FSDP
- Training algorithms: EMA, REPA/iREPA, SPRINT, TREAD, contrastive flow matching, perceptual losses (P-DINO, LPIPS), and more
- Evaluation metrics: FID, CMMD, DINO-MMD
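As a rough illustration of the flow-matching objective mentioned above (a sketch only — the repository's actual loss, time schedule, and tensor layout may differ), the training target can be written as a linear interpolation between data and noise:

```python
import numpy as np

def flow_matching_target(x0: np.ndarray, noise: np.ndarray, t: np.ndarray):
    """Rectified-flow interpolation used by flow matching:
    x_t = (1 - t) * x0 + t * noise, and the denoiser regresses the
    constant velocity v = noise - x0. A sketch, not PRX's exact loss."""
    t = t.reshape(-1, *([1] * (x0.ndim - 1)))  # broadcast t over latent dims
    x_t = (1.0 - t) * x0 + t * noise
    v_target = noise - x0
    return x_t, v_target

# usage: a fake batch of 2 latents at times t=0 and t=1
x0 = np.random.default_rng(0).normal(size=(2, 4, 8, 8))
noise = np.random.default_rng(1).normal(size=(2, 4, 8, 8))
x_t, v = flow_matching_target(x0, noise, np.array([0.0, 1.0]))
```

At t=0 the interpolant equals the clean latent, and at t=1 it equals pure noise, so the model learns the straight-line velocity between the two.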
Pre-trained PRX models are available on Hugging Face and can be used directly with diffusers.
Requires Python 3.11+.

```shell
uv sync

# With optional dependencies
uv sync --extra streaming  # MosaicML Streaming dataset support
uv sync --extra lpips      # LPIPS perceptual loss
uv sync --all-extras       # Everything
```

Training is configured with Hydra YAML files. See configs/yamls/ for examples. The repository includes all the training configurations used in the benchmarks presented in the blog post.

```shell
composer -m prx.training.train --config-path=configs/yamls hydra/launcher=basic
```

PRX trains on MosaicML Streaming (MDS) datasets, organized into aspect-ratio buckets. We provide a conversion script that takes WebDataset-style tar files and produces AR-bucketed MDS shards ready for training.
- Download the fine-t2i dataset:

```shell
HF_HUB_CACHE=/path/to/cache huggingface-cli download ma-xu/fine-t2i --repo-type dataset
```

- Convert to AR-bucketed MDS (images resized to 1024-base AR buckets, 27 buckets, patch_size=32):
```shell
uv run scripts/fine-t2i-to-mds.py \
  --input /path/to/cache/hub/datasets--ma-xu--fine-t2i/snapshots/<hash> \
  --output /path/to/output/fine-t2i \
  --workers 16
```

This produces one MDS subdirectory per aspect ratio (e.g. `0.667/`, `1.000/`, `1.500/`), each containing sharded MDS files with a merged `index.json`.
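To make the aspect-ratio bucketing concrete, here is a small sketch of how buckets like these can be enumerated and assigned. This is an illustration under stated assumptions (sides are patch multiples, area stays near the 1024² base, ratios capped at 2:1); the repository's exact 27-bucket list and assignment logic may differ.

```python
def make_ar_buckets(base: int = 1024, patch: int = 32, max_ratio: float = 2.0):
    """Enumerate (width, height) buckets whose sides are multiples of the
    patch size and whose area stays near base**2. A sketch of AR
    bucketing, not the repository's exact bucket list."""
    buckets = {}
    target_area = base * base
    for w in range(patch, 2 * base + 1, patch):
        h = round(target_area / w / patch) * patch  # snap height to patch grid
        if h < patch:
            continue
        ratio = w / h
        if 1.0 / max_ratio <= ratio <= max_ratio:
            buckets[round(ratio, 3)] = (w, h)
    return buckets

def nearest_bucket(width: int, height: int, buckets: dict):
    """Assign an image to the bucket with the closest aspect ratio."""
    ratio = width / height
    key = min(buckets, key=lambda r: abs(r - ratio))
    return key, buckets[key]
```

For example, a 1536x1024 image (ratio 1.5) lands in the 1.5 bucket, whose sides are both multiples of the patch size while its area stays close to 1024².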
- Train by pointing a dataset config at the output directory. Create a dataset YAML (e.g. `configs/yamls/dataset/train_fine_t2i.yaml`):

```yaml
# @package dataset.train_dataset
_target_: prx.dataset.StreamingProcessedDataset
local:
  - /path/to/output/fine-t2i
caption_keys:
  - [prompt, 0.5]
  - [enhanced_prompt, 0.5]
has_text_latents: false
text_tower: ${diffusion_text_tower.preset_name}
cache_limit: 8tb
drop_last: true
shuffle: true
batching_method: device_per_stream
num_workers: 8
persistent_workers: true
pin_memory: true
transforms:
  - _target_: prx.dataset.transforms.ArAwareResize
    default_image_size: ${image_size}
    patch_size_pixels: ${patch_size_pixels}
    transforms_targets:
      - image
```

Note: `ArAwareResize` is still used at training time even though images were already resized during MDS export. The MDS conversion targets a fixed 1024-base resolution, but the training config may use a different `image_size` (e.g. 512 for early-stage training). `ArAwareResize` ensures images are resized to match the model's current resolution and patch grid, and also handles any JPEG decode size differences.
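The size computation such a transform performs can be sketched as follows. This is an illustrative approximation, not the actual `prx.dataset.transforms.ArAwareResize` implementation: it picks output sides that are patch multiples, keep the input's aspect ratio, and hold the area close to `image_size`².

```python
def ar_aware_size(width: int, height: int, image_size: int = 512, patch: int = 32):
    """Compute a resize target whose sides are patch multiples, whose
    area is close to image_size**2, and whose aspect ratio matches the
    input. A sketch of what an AR-aware resize computes; the real
    ArAwareResize transform may differ in detail."""
    ratio = width / height
    target_area = image_size * image_size
    # ideal continuous side lengths for this ratio, snapped to the patch grid
    w = round((target_area * ratio) ** 0.5 / patch) * patch
    h = round((target_area / ratio) ** 0.5 / patch) * patch
    return max(w, patch), max(h, patch)
```

A square 1024x1024 image exported at the 1024 base would be mapped to 512x512 when training at `image_size: 512`, while a 3:2 image keeps its aspect ratio on a 32-pixel grid.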
Then launch training, referencing your dataset config:

```shell
composer -m prx.training.train \
  --config-path=configs/yamls \
  hydra/launcher=basic \
  dataset/train_dataset=train_fine_t2i
```

See the JIT-benchmark configs for full training configuration examples.
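For reference, the `caption_keys` weights in the dataset YAML can be read as a per-sample weighted choice between caption fields. The sketch below shows the likely semantics under that assumption; `pick_caption` is a hypothetical helper, not the repository's code.

```python
import random

def pick_caption(sample: dict, caption_keys, rng: random.Random):
    """Pick one caption field per sample using weights like those in the
    dataset YAML's caption_keys, e.g. [("prompt", 0.5),
    ("enhanced_prompt", 0.5)]. A sketch of the likely semantics only."""
    keys = [k for k, _ in caption_keys]
    weights = [w for _, w in caption_keys]
    chosen = rng.choices(keys, weights=weights, k=1)[0]
    return sample[chosen]

# usage: with equal weights, each field is sampled about half the time
sample = {"prompt": "short caption", "enhanced_prompt": "long caption"}
caption = pick_caption(sample, [("prompt", 0.5), ("enhanced_prompt", 0.5)],
                       random.Random(0))
```

Mixing short and enhanced prompts this way exposes the text encoder to both caption styles during training.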
Apache 2.0
PRX is built by the Photoroom machine learning team. Large parts of this codebase were built on top of an existing private Photoroom codebase, and the following former team members made significant contributions to the foundations of this project and deserve credit, even though their work may not appear in the public git history:
PRX team: David Bertoin, Roman Frigg, Jon Almazán, Eliot Andres

