In [1]:
import jax
import jax.numpy as jnp
from nanodl import ArrayDataset, DataLoader
from nanodl import DiffusionModel, DiffusionDataParallelTrainer

# --- Configuration: values a reader may want to tune ---

# Data / batching
image_size = 32                # side length (pixels) of the square input images
batch_size = 8                 # samples per DataLoader batch

# Architecture arguments forwarded to DiffusionModel
widths = [32, 64, 128]
block_depth = 2

# Fixed seed so model initialisation (diffusion_model.init) is reproducible
key = jax.random.PRNGKey(0)

In [2]:
# Placeholder data: constant (all-ones) images stand in for a real dataset.
# TODO: replace `images` with real images of shape (N, image_size, image_size, 3).
num_samples = 101
images = jnp.ones((num_samples, image_size, image_size, 3))
dataset = ArrayDataset(images)
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    # 101 % 8 == 5 -> presumably the final batch holds 5 samples instead of being dropped
    drop_last=False,
)

In [3]:
# Build the diffusion model and run one forward pass as a shape sanity check.
diffusion_model = DiffusionModel(image_size, widths, block_depth)
# `init` result (presumably the Flax variables dict) is passed to `apply` unchanged.
params = diffusion_model.init(key, images)
pred_noises, pred_images = diffusion_model.apply(params, images)
# Both predictions should have the same shape as the input batch; fail loudly otherwise
# instead of relying on a reader to eyeball the printed shapes.
assert pred_noises.shape == images.shape, f"unexpected noise shape {pred_noises.shape}"
assert pred_images.shape == images.shape, f"unexpected image shape {pred_images.shape}"
print(pred_noises.shape, pred_images.shape)

(101, 32, 32, 3) (101, 32, 32, 3)


In [4]:
# Train the diffusion model on the dataloader built above.
# Each epoch the trainer logs train/val loss and, when the validation loss reaches a new
# best, saves the weights to `weights_filename`. The file on disk therefore holds the
# best-validation weights, which may differ from the final in-memory weights; load it
# from disk for generation.
num_epochs = 10
learning_rate = 1e-4
weights_filename = 'params.pkl'

train_loader = dataloader
# NOTE(review): validation reuses the training loader, so "best validation score"
# checkpointing and the final evaluate() measure fit on training data, not
# generalisation. TODO: hold out a separate validation split.
val_loader = dataloader

trainer = DiffusionDataParallelTrainer(diffusion_model,
                                       input_shape=images.shape,
                                       weights_filename=weights_filename,
                                       learning_rate=learning_rate)
trainer.train(train_loader, num_epochs, val_loader)
print(trainer.evaluate(val_loader))

Number of parameters: 1007395
Number of accelerators: 1
Epoch 1, Train Loss: 7.979574203491211
Epoch 1, Val Loss: 24.317951202392578
New best validation score achieved, saving model...
Epoch 2, Train Loss: 7.75728702545166
Epoch 2, Val Loss: 23.518024444580078
New best validation score achieved, saving model...
Epoch 3, Train Loss: 7.392527103424072
Epoch 3, Val Loss: 22.308382034301758
New best validation score achieved, saving model...
Epoch 4, Train Loss: 6.846263408660889
Epoch 4, Val Loss: 20.62131690979004
New best validation score achieved, saving model...
Epoch 5, Train Loss: 6.1358747482299805
Epoch 5, Val Loss: 18.36245346069336
New best validation score achieved, saving model...
Epoch 6, Train Loss: 5.278435230255127
Epoch 6, Val Loss: 15.812017440795898
New best validation score achieved, saving model...
Epoch 7, Train Loss: 4.328006267547607
Epoch 7, Val Loss: 13.123092651367188
New best validation score achieved, saving model...
Epoch 8, Train Loss: 3.3344056606292725
Epo

In [6]:
# Sample new images using the weights checkpointed to disk during training.
n_samples = 5          # how many images to generate
n_steps = 5            # diffusion steps used by `generate`

params = trainer.load_params('params.pkl')
# The loaded parameters are wrapped as {'params': ...} before being passed to `apply`.
variables = {'params': params}
generated_images = diffusion_model.apply(
    variables,
    method=diffusion_model.generate,
    num_images=n_samples,
    diffusion_steps=n_steps,
)
print(generated_images.shape)

(5, 32, 32, 3)
