In [None]:
import jax
from nanodl import DiffusionModel, DiffusionDataParallelTrainer

# Demo: build a small DiffusionModel on random data, run a forward pass,
# sample images, then train/evaluate it with the data-parallel trainer.

# JAX PRNG keys must not be reused: drawing the fake images and initializing
# the parameters from the same key would correlate the two random streams.
# Split once and give each consumer its own subkey.
key = jax.random.PRNGKey(0)
data_key, init_key = jax.random.split(key)

image_size = 32
widths = [32, 64, 128]  # passed straight to DiffusionModel (per-stage widths, presumably — TODO confirm)
block_depth = 2
batch_size = 3
# presumably (batch, height, width, channels) — TODO confirm against nanodl
input_shape = (batch_size, image_size, image_size, 3)
images = jax.random.normal(data_key, input_shape)

diffusion_model = DiffusionModel(image_size, widths, block_depth)
params = diffusion_model.init(init_key, images)

# Forward pass; the model's output is unpacked into two arrays.
pred_noises, pred_images = diffusion_model.apply(params, images)
print(pred_noises.shape, pred_images.shape)

# Sampling through the model's `generate` method.
# NOTE: `params` are the freshly initialized (untrained) weights at this
# point, so these samples only serve as a shape check.
num_samples = 5
diffusion_steps = 5
generated_images = diffusion_model.apply(
    params,
    num_images=num_samples,
    diffusion_steps=diffusion_steps,
    method=diffusion_model.generate,
)
print(generated_images.shape)

# Training on your data.
# Note: the params the trainer saves to the weights file often differ from the
# in-memory training weights; use the saved params for generation.
num_batches = 10
num_epochs = 10  # assumed meaning of train()'s second positional arg — TODO confirm
weights_path = 'params.pkl'
# Stand-in for a real data loader: the same batch repeated `num_batches` times.
dataloader = [images] * num_batches
trainer = DiffusionDataParallelTrainer(diffusion_model, images.shape, weights_path)
# The same loader is passed again as the third positional argument
# (presumably the validation loader — TODO confirm).
trainer.train(dataloader, num_epochs, dataloader)
print(trainer.evaluate(dataloader))