A PyTorch implementation of the paper:
Deep Bilateral Learning for Real-Time Image Enhancement
Michaël Gharbi, Jiawen Chen, Jonathan T. Barron, Samuel W. Hasinoff, Frédo Durand
ACM Transactions on Graphics (SIGGRAPH), 2017
Features:
- Real-time processing: Process high-resolution images in milliseconds
- Bilateral learning: Learn locally affine color transformations in bilateral space
- OpenCV + PyTorch: Efficient image loading and GPU-accelerated inference
- Flexible training: Support for any paired input/output image dataset
- TensorBoard logging: Monitor training with visualizations
HDRNet enhances an image in three stages:
- Low-res stream: Process a downsampled version to predict a 3D bilateral grid of affine coefficients
- Guide network: Learn a guidance map from full-resolution input
- Bilateral slicing: Use the guide to slice the grid and apply transformations at full resolution
Full-res Input ───────────────────────────────────────┬─────────────> Output
      │                                               │                 │
      │                                               │                 │
      ▼                                               ▼                 │
Low-res Input ──> Spatial Features ──> Bilateral Grid ──> Bilateral Slice & Apply
                         │                    │
                         ▼                    │
                  Global Features ────────────┘
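
The slice-and-apply stage can be illustrated without a tensor library. The block below is a minimal pure-Python sketch, not the code in `hdrnet/layers.py`: the names `slice_grid` and `apply_affine` are hypothetical, the grid is assumed to be a nested list indexed as `grid[depth][row][col]` holding 12 coefficients (a 3x4 affine matrix in row-major order), and coordinates are assumed to be normalized to [0, 1] with corner-aligned sampling.

```python
import math


def _neighbors(value, size):
    """Return the two nearest grid indices and the weight of the upper one."""
    value = min(max(value, 0.0), 1.0)  # clamp to the valid range
    pos = value * (size - 1)           # continuous grid index (corner-aligned, an assumption)
    lo = int(math.floor(pos))
    hi = min(lo + 1, size - 1)
    return lo, hi, pos - lo


def slice_grid(grid, x, y, guide):
    """Trilinearly interpolate 12 affine coefficients at (x, y, guide)."""
    depth, rows, cols = len(grid), len(grid[0]), len(grid[0][0])
    z0, z1, wz = _neighbors(guide, depth)
    y0, y1, wy = _neighbors(y, rows)
    x0, x1, wx = _neighbors(x, cols)
    coeffs = [0.0] * 12
    # Accumulate the eight surrounding cells, weighted by their distance.
    for zi, a in ((z0, 1.0 - wz), (z1, wz)):
        for yi, b in ((y0, 1.0 - wy), (y1, wy)):
            for xi, c in ((x0, 1.0 - wx), (x1, wx)):
                weight = a * b * c
                for k, value in enumerate(grid[zi][yi][xi]):
                    coeffs[k] += weight * value
    return coeffs


def apply_affine(coeffs, rgb):
    """Apply a row-major 3x4 affine matrix to one RGB pixel."""
    r, g, b = rgb
    return tuple(
        coeffs[4 * i] * r + coeffs[4 * i + 1] * g + coeffs[4 * i + 2] * b + coeffs[4 * i + 3]
        for i in range(3)
    )


if __name__ == "__main__":
    identity = [1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0]
    double = [2, 0, 0, 0, 0, 2, 0, 0, 0, 0, 2, 0]
    grid = [[[identity]], [[double]]]  # depth 2, 1x1 spatial resolution
    coeffs = slice_grid(grid, x=0.5, y=0.5, guide=0.5)
    print(apply_affine(coeffs, (0.2, 0.4, 0.6)))  # a gain of about 1.5 per channel
```

Because the guide value selects the depth coordinate, two pixels at the same location but with different guide values can receive different affine transforms, which is what lets the grid respect edges.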
Requirements:
- Python 3.8+
- PyTorch 2.0+
- OpenCV 4.8+
- CUDA (recommended for training)
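
Before installing or training, it can help to confirm that the interpreter is recent enough and that the main packages are importable. The sketch below is an illustration only: `check_environment` is a hypothetical helper, it assumes PyTorch and OpenCV are importable as `torch` and `cv2`, and it does not verify the package versions listed above.

```python
import importlib.util
import sys


def check_environment(min_python=(3, 8), modules=("torch", "cv2")):
    """Report whether Python is new enough and which modules can be found."""
    report = {"python_ok": sys.version_info >= tuple(min_python)}
    for name in modules:
        # find_spec returns None when a top-level module is not installed.
        report[name] = importlib.util.find_spec(name) is not None
    return report


if __name__ == "__main__":
    for key, ok in check_environment().items():
        print(f"{key}: {'ok' if ok else 'missing'}")
```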
Organize your data in paired input/output folders:
data/
├── train/
│   ├── input/
│   │   ├── image001.jpg
│   │   ├── image002.jpg
│   │   └── ...
│   └── output/
│       ├── image001.jpg (enhanced version)
│       ├── image002.jpg
│       └── ...
└── val/
    ├── input/
    └── output/
Note: Input and output images must have matching filenames.
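
Since the pairing relies on filenames, a quick check before training can catch missing counterparts. This is a minimal sketch under the folder layout shown above; `find_unpaired` is a hypothetical helper and is not part of the repository.

```python
from pathlib import Path


def find_unpaired(root, input_dir="input", output_dir="output"):
    """Return (inputs without an output, outputs without an input)."""
    root = Path(root)
    inputs = {p.name for p in (root / input_dir).iterdir() if p.is_file()}
    outputs = {p.name for p in (root / output_dir).iterdir() if p.is_file()}
    return sorted(inputs - outputs), sorted(outputs - inputs)


if __name__ == "__main__":
    missing_output, missing_input = find_unpaired("data/train")
    print("inputs without an output:", missing_output)
    print("outputs without an input:", missing_input)
```

Both lists should be empty for a usable split; the same call with `data/val` checks the validation set.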
Suitable datasets:
- MIT-Adobe FiveK: 5,000 RAW photos, each retouched by 5 experts
- Custom pairs: Any before/after image pairs (e.g., Lightroom edits)
- Synthetic: Generated input/output pairs for specific enhancements
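
Custom pairs often arrive as one flat collection rather than separate `train/` and `val/` folders. The sketch below shows one possible way to choose a reproducible split by filename; `split_pairs`, the 10% default, and the seed are assumptions for illustration, and the files still need to be copied into the layout described above.

```python
import random


def split_pairs(filenames, val_fraction=0.1, seed=0):
    """Return (train_names, val_names) using a seeded, repeatable shuffle."""
    names = sorted(filenames)               # sort first so the result is order-independent
    random.Random(seed).shuffle(names)      # private RNG, global state is untouched
    n_val = min(len(names), max(1, round(len(names) * val_fraction))) if names else 0
    return names[n_val:], names[:n_val]


if __name__ == "__main__":
    names = [f"image{i:03d}.jpg" for i in range(1, 11)]
    train_names, val_names = split_pairs(names, val_fraction=0.2)
    print(len(train_names), "training pairs,", len(val_names), "validation pairs")
```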
Basic training:

python train.py \
--data_root data/train \
--val_root data/val \
--epochs 100 \
--batch_size 4 \
--lr 1e-4

Training with all options:

python train.py \
--data_root data/train \
--val_root data/val \
--input_dir input \
--output_dir output \
--low_res_size 256 \
--full_res_size 512 \
--base_features 8 \
--grid_depth 8 \
--batch_size 4 \
--epochs 100 \
--lr 1e-4 \
--l1_weight 1.0 \
--l2_weight 0.0 \
--checkpoint_dir checkpoints \
--log_dir logs \
--device cuda

Resume training from a checkpoint:

python train.py \
--data_root data/train \
--resume checkpoints/checkpoint_epoch50.pth \
--epochs 100

Monitor training with TensorBoard:

tensorboard --logdir logs

Enhance a single image:

python inference.py \
--input path/to/image.jpg \
--output path/to/enhanced.jpg \
--checkpoint checkpoints/best_model.pth

Enhance every image in a directory:

python inference.py \
--input path/to/images/ \
--output path/to/enhanced/ \
--checkpoint checkpoints/best_model.pth

Visualize the guide map and bilateral grid:

python inference.py \
--input image.jpg \
--output enhanced.jpg \
--checkpoint model.pth \
--visualize \
--show_guide \
--show_grid

Python API:

from hdrnet import create_hdrnet
from inference import enhance_image, create_enhancer
# Option 1: One-time enhancement
enhanced = enhance_image(
    'input.jpg',
    checkpoint_path='checkpoints/best_model.pth',
    device='cuda'
)
# Option 2: Create reusable enhancer
enhance = create_enhancer('checkpoints/best_model.pth', device='cuda')
enhanced1 = enhance('image1.jpg')
enhanced2 = enhance('image2.jpg')
enhanced3 = enhance(numpy_array)  # Also accepts NumPy arrays

Build a model from a custom configuration:

from hdrnet import create_hdrnet
config = {
    'low_res_size': 256,
    'base_features': 8,
    'grid_depth': 8,
    'use_guide_nn': True
}
model = create_hdrnet(config)

Performance:

| Metric | Value |
|---|---|
| PSNR | ~28-32 dB (dataset dependent) |
| SSIM | ~0.92-0.96 |
| Speed | ~15 ms at 1080p (RTX 3090) |
| Parameters | ~482K |
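
For reference, PSNR is derived from the mean squared error between the enhanced image and its target. The sketch below uses the standard definition on flat sequences of pixel values; `psnr` here is a hypothetical stand-alone function, and the metric code in `hdrnet/utils.py` may differ in details such as the assumed value range.

```python
import math


def psnr(reference, test, max_value=255.0):
    """Peak signal-to-noise ratio in dB between two equal-length sequences."""
    if len(reference) != len(test) or not reference:
        raise ValueError("inputs must be non-empty and of equal length")
    mse = sum((a - b) ** 2 for a, b in zip(reference, test)) / len(reference)
    if mse == 0:
        return math.inf  # identical images
    return 10.0 * math.log10(max_value ** 2 / mse)


if __name__ == "__main__":
    # An error of 10 gray levels on every pixel gives roughly 28 dB for 8-bit data.
    print(round(psnr([100, 100], [110, 90]), 2))
```

With images scaled to [0, 1], pass `max_value=1.0` so the result stays comparable.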
Project structure:

hdrnet-pytorch/
├── hdrnet/
│   ├── __init__.py      # Package exports
│   ├── model.py         # HDRNet architecture
│   ├── layers.py        # Bilateral grid operations
│   ├── dataset.py       # Data loading with OpenCV
│   └── utils.py         # Utilities and metrics
├── train.py             # Training script
├── inference.py         # Inference script
├── requirements.txt     # Dependencies
└── README.md
The training script supports combined losses:
python train.py \
--l1_weight 1.0 \
--l2_weight 0.5 \
--perceptual_weight 0.1 # Uses VGG features

Model variants:

# Standard HDRNet
from hdrnet import HDRNet
model = HDRNet(low_res_size=256, grid_depth=8)
# Lightweight curve-based variant
from hdrnet import HDRNetCurves
model = HDRNetCurves(num_control_points=16)

If you use this code, please cite the original paper:
@article{gharbi2017deep,
  title={Deep Bilateral Learning for Real-Time Image Enhancement},
  author={Gharbi, Micha{\"e}l and Chen, Jiawen and Barron, Jonathan T and Hasinoff, Samuel W and Durand, Fr{\'e}do},
  journal={ACM Transactions on Graphics (TOG)},
  volume={36},
  number={4},
  pages={1--12},
  year={2017},
  publisher={ACM}
}

This project is licensed under the MIT License - see the LICENSE file for details.

Acknowledgments:
- Original HDRNet paper and authors
- PyTorch team
- OpenCV library