## Installation

In [None]:
!pip install git+https://github.com/google-research/swirl-dynamics.git@main

In [None]:
!pip install tensorflow-datasets

## Imports

In [None]:
import functools

from clu import metric_writers
import jax
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import optax
import orbax.checkpoint as ocp
import tensorflow as tf
import tensorflow_datasets as tfds

from swirl_dynamics import templates
from swirl_dynamics.lib import diffusion as dfn_lib
from swirl_dynamics.lib import solvers as solver_lib
from swirl_dynamics.projects import probabilistic_diffusion as dfn

## Example I - Unconditional diffusion model with guidance

### Dataset

First we need a dataset containing samples whose distribution is to be modeled by the diffusion model. For demonstration purposes, we use the MNIST dataset provided by TensorFlow Datasets.

Our code setup accepts any Python iterable as a dataloader. It is expected to yield, indefinitely, dictionaries with a field named `x` whose value is a NumPy array of shape `(batch, *spatial_dims, channels)`.
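
Because any iterable that follows this contract works, a plain Python generator is enough for quick experiments without TensorFlow. The sketch below yields uniform random images as a stand-in for real data; `random_image_loader` is a hypothetical helper for illustration, not part of `swirl_dynamics`.

```python
import itertools
from typing import Iterator

import numpy as np


def random_image_loader(
    batch_size: int,
    spatial_dims: tuple = (28, 28),
    channels: int = 1,
    seed: int = 0,
) -> Iterator[dict]:
  """Yields random batches forever, following the `{"x": array}` contract."""
  rng = np.random.default_rng(seed)
  while True:  # never exhausts, like `ds.repeat()`
    x = rng.uniform(0.0, 1.0, size=(batch_size, *spatial_dims, channels))
    yield {"x": x.astype(np.float32)}


loader = random_image_loader(batch_size=8)
for batch in itertools.islice(loader, 3):
  print(batch["x"].shape, batch["x"].dtype)  # (8, 28, 28, 1) float32
```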

In [None]:
def get_mnist_dataset(split: str, batch_size: int):
  ds = tfds.load("mnist", split=split)
  ds = ds.map(
      # Change field name from "image" to "x" (required by `DenoisingModel`)
      # and normalize the value to [0, 1].
      lambda x: {"x": tf.cast(x["image"], tf.float32) / 255.0}
  )
  ds = ds.repeat()
  ds = ds.batch(batch_size)
  ds = ds.prefetch(tf.data.AUTOTUNE)
  ds = ds.as_numpy_iterator()
  return ds

# The standard deviation of the normalized dataset.
# This is useful for determining the diffusion scheme and preconditioning
# of the neural network parametrization.
DATA_STD = 0.31
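
If you swap in another dataset, `DATA_STD` can be estimated from a few batches of the normalized data. The helper below is a hypothetical sketch (not a `swirl_dynamics` utility) that pools all pixels across batches. For reference, the commonly quoted pixel standard deviation of MNIST after scaling to [0, 1] is about 0.3081, consistent with the 0.31 used here.

```python
import itertools

import numpy as np


def estimate_data_std(dataloader, num_batches: int) -> float:
  """Pooled standard deviation of `batch["x"]` over the first `num_batches`."""
  total, total_sq, count = 0.0, 0.0, 0
  for batch in itertools.islice(dataloader, num_batches):
    x = np.asarray(batch["x"], dtype=np.float64)
    total += x.sum()
    total_sq += np.square(x).sum()
    count += x.size
  mean = total / count
  return float(np.sqrt(total_sq / count - mean**2))


# Sanity check on synthetic data with a known standard deviation of 0.31.
rng = np.random.default_rng(0)
fake_loader = (
    {"x": rng.normal(0.5, 0.31, size=(64, 28, 28, 1))} for _ in range(100)
)
print(estimate_data_std(fake_loader, num_batches=20))  # close to 0.31

# On the real data (requires the `get_mnist_dataset` cell above):
# estimate_data_std(get_mnist_dataset("train", 256), num_batches=100)
```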

### Architecture

Next, let's define the U-Net backbone. The preconditioning ensures that the inputs and outputs of the network are roughly standardized (for more details, see Appendix B.6 of [this paper](https://arxiv.org/abs/2206.00364)).

In [None]:
denoiser_model = dfn_lib.PreconditionedDenoiser(
    out_channels=1,
    num_channels=(64, 128),
    downsample_ratio=(2, 2),
    num_blocks=4,
    noise_embed_dim=128,
    padding="SAME",
    use_attention=True,
    use_position_encoding=True,
    num_heads=8,
    sigma_data=DATA_STD,
)
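
To make the preconditioning concrete, here is a NumPy sketch of the coefficients listed in the EDM paper linked above, where the raw network `F` is wrapped as `D(x; sigma) = c_skip * x + c_out * F(c_in * x, c_noise)`. Whether `PreconditionedDenoiser` uses exactly these formulas internally is an assumption, and `raw_net` is a hypothetical placeholder for the U-Net.

```python
import numpy as np


def edm_preconditioning(sigma, sigma_data):
  """Returns (c_skip, c_out, c_in, c_noise) as listed in the EDM paper."""
  sigma = np.asarray(sigma, dtype=np.float64)
  total_var = sigma**2 + sigma_data**2
  c_skip = sigma_data**2 / total_var
  c_out = sigma * sigma_data / np.sqrt(total_var)
  c_in = 1.0 / np.sqrt(total_var)
  c_noise = 0.25 * np.log(sigma)
  return c_skip, c_out, c_in, c_noise


def preconditioned_denoise(raw_net, x_noisy, sigma, sigma_data):
  """D(x; sigma) = c_skip * x + c_out * F(c_in * x, c_noise)."""
  c_skip, c_out, c_in, c_noise = edm_preconditioning(sigma, sigma_data)
  return c_skip * x_noisy + c_out * raw_net(c_in * x_noisy, c_noise)


# The input scaling c_in standardizes the noisy input at any noise level.
rng = np.random.default_rng(0)
sigma_data, sigma = 0.31, 5.0
x_clean = rng.normal(0.0, sigma_data, size=100_000)
x_noisy = x_clean + sigma * rng.normal(size=x_clean.shape)
_, _, c_in, _ = edm_preconditioning(sigma, sigma_data)
print(np.std(c_in * x_noisy))  # close to 1.0
```

Note that with a network that outputs zeros, `D` reduces to `c_skip * x`, which is the exact minimum-mean-squared-error denoiser for zero-mean Gaussian data with standard deviation `sigma_data`; the network then only has to learn the correction on top of it.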

### Training

For diffusion model training, the U-Net backbone defined above serves as a denoiser: it takes as input a batch of samples corrupted by isotropic Gaussian noise and outputs its best guess of the uncorrupted samples.

Besides the backbone architecture, we also need to specify how to sample the noise levels (i.e. standard deviations) used to corrupt the samples and the weighting for each noise level in the loss function (for available options and configurations, see [`swirl_dynamics.lib.diffusion.diffusion`](https://github.com/google-research/swirl-dynamics/blob/main/swirl_dynamics/lib/diffusion/diffusion.py)):

In [None]:
diffusion_scheme = dfn_lib.Diffusion.create_variance_exploding(
    sigma=dfn_lib.tangent_noise_schedule(),
    data_std=DATA_STD,
)

model = dfn.DenoisingModel(
    # `input_shape` must agree with the expected sample shape (without the batch
    # dimension), which in this case is simply the dimensions of a single MNIST
    # sample.
    input_shape=(28, 28, 1),
    denoiser=denoiser_model,
    noise_sampling=dfn_lib.log_uniform_sampling(
        diffusion_scheme, clip_min=1e-4, uniform_grid=True,
    ),
    noise_weighting=dfn_lib.edm_weighting(data_std=DATA_STD),
)
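
The training objective these two options plug into can be sketched as a weighted denoising loss, E[λ(σ) · MSE(D(x + σε; σ), x)], with log σ drawn uniformly and λ(σ) = (σ² + σ_data²) / (σ σ_data)² as in the EDM paper. This sketch is illustrative only: the upper noise level `sigma_max=80` is an assumed value (the EDM default, not read from `tangent_noise_schedule`), the stratified draw (one sample per equal-width bin of log σ) is an illustrative choice that we have not checked against what `uniform_grid=True` does, and all function names are hypothetical.

```python
import numpy as np


def log_uniform_sigmas(rng, batch_size, sigma_min=1e-4, sigma_max=80.0,
                       stratified=True):
  """Noise levels with log(sigma) uniform on [log sigma_min, log sigma_max)."""
  if stratified:
    # One draw per equal-width bin in log space, sharing a random offset.
    u = (np.arange(batch_size) + rng.uniform()) / batch_size
  else:
    u = rng.uniform(size=batch_size)
  log_min, log_max = np.log(sigma_min), np.log(sigma_max)
  return np.exp(log_min + u * (log_max - log_min))


def edm_weight(sigma, sigma_data):
  """lambda(sigma) = (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2."""
  return (sigma**2 + sigma_data**2) / (sigma * sigma_data) ** 2


def denoising_loss(denoise_fn, x, rng, sigma_data):
  """Monte Carlo estimate of E[lambda(sigma) * mean((D(x + sigma*eps) - x)^2)]."""
  sigma = log_uniform_sigmas(rng, x.shape[0])
  sigma_b = sigma.reshape((-1,) + (1,) * (x.ndim - 1))  # broadcast over pixels
  x_noisy = x + sigma_b * rng.normal(size=x.shape)
  per_example = np.mean(
      (denoise_fn(x_noisy, sigma_b) - x) ** 2, axis=tuple(range(1, x.ndim))
  )
  return float(np.mean(edm_weight(sigma, sigma_data) * per_example))


def gaussian_optimal_denoiser(x_noisy, sigma, sigma_data=0.31):
  """Exact MMSE denoiser for zero-mean Gaussian data with std `sigma_data`."""
  return sigma_data**2 / (sigma**2 + sigma_data**2) * x_noisy


rng = np.random.default_rng(0)
x_batch = rng.normal(0.0, 0.31, size=(256, 28, 28, 1))
print(denoising_loss(gaussian_optimal_denoiser, x_batch, rng, sigma_data=0.31))
```

The printed loss is close to 1.0: for Gaussian data, the exact denoiser's per-pixel error is σ²σ_data² / (σ² + σ_data²), which λ(σ) cancels exactly, so the EDM weighting gives every noise level the same expected contribution.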

We are now ready to define the learning parameters.

In [None]:
# !rm -R -f $workdir  # optional: clear the working directory

In [None]:
num_train_steps = 100_000  #@param
workdir = "/tmp/diffusion_demo_mnist"  #@param
train_batch_size = 32  #@param
eval_batch_size = 32  #@param
initial_lr = 0.0  #@param
peak_lr = 1e-4  #@param
warmup_steps = 1000  #@param
end_lr = 1e-6  #@param
ema_decay = 0.999  #@param
ckpt_interval = 1000  #@param
max_ckpt_to_keep = 5  #@param
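
The learning-rate schedule configured by these parameters (and built with `optax.warmup_cosine_decay_schedule` in the trainer cell below) ramps linearly from `initial_lr` to `peak_lr` over `warmup_steps`, then follows a half cosine down to `end_lr` at `num_train_steps`. The plain-Python function below is a sketch of that shape, written from optax's documented behavior (where the decay length includes the warmup); use the optax schedule itself in practice.

```python
import math


def warmup_cosine_lr(step, init_value, peak_value, warmup_steps, decay_steps,
                     end_value):
  """Linear warmup to `peak_value`, then half-cosine decay to `end_value`.

  Follows the documented shape of `optax.warmup_cosine_decay_schedule`, whose
  `decay_steps` includes the warmup; steps past `decay_steps` stay at
  `end_value`.
  """
  if step < warmup_steps:
    return init_value + (peak_value - init_value) * step / warmup_steps
  cosine_steps = decay_steps - warmup_steps
  progress = min(step - warmup_steps, cosine_steps) / cosine_steps
  cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
  return end_value + (peak_value - end_value) * cosine


for step in (0, 500, 1_000, 50_500, 100_000):
  lr = warmup_cosine_lr(step, 0.0, 1e-4, 1_000, 100_000, 1e-6)
  print(f"step {step:>7,}: lr = {lr:.3g}")
```

With the values above, the rate reaches 1e-4 at step 1,000, is about 5.05e-5 halfway through the decay (step 50,500) and ends at 1e-6 at step 100,000.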

To start training, we first need to initialize the trainer.

In [None]:
# NOTE: use `trainers.DistributedDenoisingTrainer` for multi-device
# training with data parallelism.
trainer = dfn.DenoisingTrainer(
    model=model,
    rng=jax.random.PRNGKey(888),
    optimizer=optax.adam(
        learning_rate=optax.warmup_cosine_decay_schedule(
            init_value=initial_lr,
            peak_value=peak_lr,
            warmup_steps=warmup_steps,
            decay_steps=num_train_steps,
            end_value=end_lr,
        ),
    ),
    # We keep track of an exponential moving average of the model parameters
    # over training steps. This alleviates the "color-shift" problems known to
    # exist in diffusion models.
    ema_decay=ema_decay,
)
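
The EMA update itself is one line per parameter. The sketch below uses a dict of NumPy arrays in place of a parameter pytree (for JAX pytrees the same update would be applied leaf-wise, e.g. with `jax.tree_util.tree_map`); that `DenoisingTrainer` applies exactly this update, for instance without bias correction or a decay warmup, is an assumption. With `ema_decay = 0.999`, the average effectively spans on the order of 1 / (1 - 0.999) = 1,000 recent steps.

```python
import numpy as np


def ema_update(ema_params, params, decay):
  """One EMA step per leaf: ema <- decay * ema + (1 - decay) * params."""
  return {k: decay * ema_params[k] + (1.0 - decay) * params[k] for k in params}


rng = np.random.default_rng(0)
ema = {"w": np.zeros(3)}
for _ in range(10_000):
  # Jittery "training iterates" scattered around a fixed value of 1.0.
  noisy_params = {"w": 1.0 + 0.1 * rng.normal(size=3)}
  ema = ema_update(ema, noisy_params, decay=0.999)
print(ema["w"])  # close to [1, 1, 1]: the step-to-step jitter is averaged out
```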

Now we are ready to kick off training. A couple of "callbacks" are passed to assist with monitoring and checkpointing.

The first step will be a little slow because JAX needs to JIT-compile the step function (the same goes for the first step where evaluation is performed). Subsequent steps run much faster.

In [None]:
templates.run_train(
    train_dataloader=get_mnist_dataset(
        split="train[:75%]", batch_size=train_batch_size
    ),
    trainer=trainer,
    workdir=workdir,
    total_train_steps=num_train_steps,
    metric_writer=metric_writers.create_default_writer(
        workdir, asynchronous=False
    ),
    metric_aggregation_steps=100,
    eval_dataloader=get_mnist_dataset(
        split="train[75%:]", batch_size=eval_batch_size
    ),
    eval_every_steps=1000,
    num_batches_per_eval=2,
    callbacks=(
        # This callback displays the training progress in a tqdm bar
        templates.TqdmProgressBar(
            total_train_steps=num_train_steps,
            train_monitors=("train_loss",),
        ),
        # This callback saves model checkpoint periodically
        templates.TrainStateCheckpoint(
            base_dir=workdir,
            options=ocp.CheckpointManagerOptions(
                save_interval_steps=ckpt_interval, max_to_keep=max_ckpt_to_keep
            ),
        ),
    ),
)
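
As a quick sanity check on the training budget, the arithmetic below combines the size of the MNIST `train` split (60,000 examples) with the settings above. `training_budget` is a hypothetical helper; its evaluation count is approximate because it does not model whether evaluation also runs at the first or last step.

```python
def training_budget(num_examples, total_steps, batch_size, eval_every,
                    batches_per_eval, eval_batch_size):
  """Back-of-the-envelope numbers for a fixed-step training run."""
  return {
      "epochs": total_steps * batch_size / num_examples,
      "approx_eval_rounds": total_steps // eval_every,
      "examples_per_eval": batches_per_eval * eval_batch_size,
  }


num_mnist_train = 60_000                 # size of the MNIST "train" split
num_fit = num_mnist_train * 75 // 100    # "train[:75%]"
num_holdout = num_mnist_train - num_fit  # "train[75%:]"

budget = training_budget(
    num_examples=num_fit, total_steps=100_000, batch_size=32,
    eval_every=1_000, batches_per_eval=2, eval_batch_size=32,
)
print(num_fit, num_holdout, budget)
```

So 100,000 steps at batch size 32 amount to roughly 71 passes over the 45,000 training examples, and each evaluation round looks at only 64 held-out examples.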

### Inference

#### Unconditional generation

After training is complete, the trained denoiser may be used to generate unconditional samples.

First, let's restore the model from checkpoint.

In [None]:
# Restore train state from checkpoint. By default, the most recently saved
# checkpoint is restored. Alternatively, one can directly use
# `trainer.train_state` if continuing from the training section above.
trained_state = dfn.DenoisingModelTrainState.restore_from_orbax_ckpt(
    f"{workdir}/checkpoints", step=None
)
# Construct the inference function
denoise_fn = dfn.DenoisingTrainer.inference_fn_from_state_dict(
    trained_state, use_ema=True, denoiser=denoiser_model
)

Diffusion samples are generated by plugging the trained denoising function into a stochastic differential equation (parametrized by the diffusion scheme) and solving it backwards in time, from high to low noise levels.

In [None]:
sampler = dfn_lib.SdeSampler(
    input_shape=(28, 28, 1),
    integrator=solver_lib.EulerMaruyama(),
    tspan=dfn_lib.edm_noise_decay(
        diffusion_scheme, rho=7, num_steps=256, end_sigma=1e-3,
    ),
    scheme=diffusion_scheme,
    denoise_fn=denoise_fn,
    guidance_transforms=(),
    apply_denoise_at_end=True,
    return_full_paths=False,  # Set to `True` if the full sampling paths are needed
)
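
To illustrate what the sampler does, the sketch below integrates the reverse-time variance-exploding SDE with the Euler-Maruyama method on a toy problem where the exact denoiser is known: one-dimensional Gaussian data N(0, σ_data²). It assumes the noise schedule σ(t) = t (the EDM choice, not the tangent schedule used above) and an assumed `sigma_max = 80`, uses the identity ∇ log p(x; σ) = (D(x; σ) - x) / σ², and spaces noise levels with the ρ = 7 formula from the EDM paper, which we assume is what `edm_noise_decay` approximates. All names are hypothetical; this is not the `solver_lib.EulerMaruyama` implementation.

```python
import numpy as np


def edm_sigma_steps(sigma_max, sigma_min, num_steps, rho=7.0):
  """Decreasing noise levels with the EDM rho-spacing (dense near sigma_min)."""
  ramp = np.arange(num_steps) / (num_steps - 1)
  max_r, min_r = sigma_max ** (1.0 / rho), sigma_min ** (1.0 / rho)
  return (max_r + ramp * (min_r - max_r)) ** rho


def euler_maruyama_ve(denoise_fn, sigmas, num_samples, rng):
  """Euler-Maruyama for the reverse-time VE SDE, assuming sigma(t) = t.

  Reverse SDE: dx = -2 t * score(x, t) dt + sqrt(2 t) dW_bar, integrated from
  t = sigmas[0] down to t = sigmas[-1], with score = (D(x, t) - x) / t**2.
  """
  x = sigmas[0] * rng.normal(size=num_samples)  # start from (nearly) pure noise
  for t_cur, t_next in zip(sigmas[:-1], sigmas[1:]):
    h = t_cur - t_next  # positive step size; time runs backwards
    score = (denoise_fn(x, t_cur) - x) / t_cur**2
    noise = rng.normal(size=num_samples)
    x = x + 2.0 * t_cur * score * h + np.sqrt(2.0 * t_cur * h) * noise
  return x


SIGMA_DATA_TOY = 0.31


def gaussian_denoiser(x, sigma):
  """Exact MMSE denoiser when the data are N(0, SIGMA_DATA_TOY**2)."""
  return SIGMA_DATA_TOY**2 / (sigma**2 + SIGMA_DATA_TOY**2) * x


rng = np.random.default_rng(0)
sigmas = edm_sigma_steps(sigma_max=80.0, sigma_min=1e-3, num_steps=256)
toy_samples = euler_maruyama_ve(gaussian_denoiser, sigmas, 20_000, rng)
print(toy_samples.std())  # close to SIGMA_DATA_TOY
```

If the integration is correct, the samples recover the data distribution, so their standard deviation lands close to `SIGMA_DATA_TOY` up to discretization and Monte Carlo error.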

The sampler is run by calling `sampler.generate()`. Optionally, we may JIT-compile this method so that repeated calls run faster; `num_samples` is marked as static because it determines the shape of the output.

In [None]:
generate = jax.jit(sampler.generate, static_argnames=('num_samples',))

In [None]:
samples = generate(
    rng=jax.random.PRNGKey(8888), num_samples=4
)

Visualize the generated samples:

In [None]:
# Plot generated samples
fig, ax = plt.subplots(1, 4, figsize=(8, 2))
for i in range(4):
  im = ax[i].imshow(samples[i, :, :, 0] * 255, cmap="gray", vmin=0, vmax=255)

plt.tight_layout()
plt.show()

#### Guided generation

To achieve 'guided' generation, we can modify a trained denoising function and tailor it to produce samples with specific desired characteristics. For instance, in an out-filling task where the goal is to generate full images from a given patch, we can guide the denoiser to create samples whose crops at certain positions precisely align with the provided patch.

In [None]:
guidance_fn = dfn_lib.InfillFromSlices(
    # This specifies the location of the guide input using Python slices. The
    # leading `slice(None)` spans the batch dimension, and the remaining two
    # select the 7x7 patch at the center of the 28x28 image.
    slices=(slice(None), slice(11, 18), slice(11, 18)),

    # This parameter controls how "hard" the denoiser pushes for the
    # conditioning to be satisfied. It trades off strictness of constraint
    # satisfaction against diversity in the generated samples.
    guide_strength=0.1,
)
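
The internal update rule of `InfillFromSlices` is not shown in this notebook. To illustrate only the general idea of a transform that wraps a denoising function, here is a deliberately simplified, hypothetical transform that linearly blends the patch of the denoised estimate toward the observation. The blending rule is an illustrative choice, not the library's formula, and its `strength` is not equivalent to `guide_strength`.

```python
import numpy as np


def soft_infill_guidance(denoise_fn, slices, observed, strength):
  """Hypothetical transform: pulls the denoised patch at `slices` toward `observed`.

  `strength=1` pastes the observed patch in exactly; `strength=0` leaves the
  trained denoiser untouched. This blending rule is an illustrative
  simplification, not the formula used by `InfillFromSlices`.
  """
  def guided_denoise_fn(x, sigma):
    denoised = np.array(denoise_fn(x, sigma), copy=True)
    patch = denoised[slices]
    denoised[slices] = patch + strength * (observed - patch)
    return denoised

  return guided_denoise_fn


def placeholder_denoiser(x, sigma):
  """Stand-in for a trained denoiser: always predicts an all-black image."""
  del sigma
  return np.zeros_like(x)


center = (slice(None), slice(11, 18), slice(11, 18))
guided = soft_infill_guidance(
    placeholder_denoiser, center, observed=np.ones((1, 7, 7, 1)), strength=0.5
)
out = guided(np.zeros((4, 28, 28, 1)), sigma=1.0)
print(out[center].mean(), out.sum() - out[center].sum())  # 0.5 0.0
```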

The `guidance_fn` transform is passed to the sampler through its `guidance_transforms` argument.

In [None]:
guided_sampler = dfn_lib.SdeSampler(
    input_shape=(28, 28, 1),
    integrator=solver_lib.EulerMaruyama(),
    tspan=dfn_lib.edm_noise_decay(
        diffusion_scheme, rho=7, num_steps=256, end_sigma=1e-3,
    ),
    scheme=diffusion_scheme,
    denoise_fn=denoise_fn,
    guidance_transforms=(guidance_fn,),
    apply_denoise_at_end=True,
    return_full_paths=False,
)

guided_generate = jax.jit(guided_sampler.generate, static_argnames=('num_samples',))

We construct an example guidance input from a real sample and use it to guide the sampling:

In [None]:
test_ds = get_mnist_dataset(split="test", batch_size=1)
test_example = next(iter(test_ds))["x"]
example_guidance_inputs = {'observed_slices': test_example[:, 11:18, 11:18]}

In [None]:
guided_samples = guided_generate(
    rng=jax.random.PRNGKey(66),
    num_samples=4,
    # Note that the shape of the guidance input must be compatible with
    # `sample[guidance_fn.slices]`
    guidance_inputs=example_guidance_inputs,
)
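
The shape requirement in the comment above can be checked with plain NumPy. Here we assume "compatible" means NumPy-style broadcasting: the guide patch built from a single test example has a batch size of 1, which broadcasts against the `num_samples` patches of the generated batch. The arrays below are zero-filled placeholders with the same shapes as the real ones.

```python
import numpy as np

slices = (slice(None), slice(11, 18), slice(11, 18))  # as in `guidance_fn`
samples_like = np.zeros((4, 28, 28, 1))  # shape of `guided_samples`
guide_like = np.zeros((1, 28, 28, 1))[:, 11:18, 11:18]  # as `observed_slices`

print(samples_like[slices].shape)  # (4, 7, 7, 1)
print(guide_like.shape)            # (1, 7, 7, 1)
print(np.broadcast_shapes(samples_like[slices].shape, guide_like.shape))
```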

Visualize guided samples:

In [None]:
# Plot guide patch.
fig, ax = plt.subplots(1, 1, figsize=(2, 2))
im = ax.imshow(
    test_example[0, 11:18, 11:18, 0] * 255, cmap="gray", vmin=0, vmax=255
)
ax.axis("off")
ax.set_title("Guide patch")
plt.tight_layout()
plt.show()

# Plot generated samples.
fig, ax = plt.subplots(1, 4, figsize=(8, 2))
for i in range(4):
  im = ax[i].imshow(
      guided_samples[i, :, :, 0] * 255, cmap="gray", vmin=0, vmax=255
  )
  # Mark out the patch where guidance is enabled. `imshow` centers pixel `i`
  # at coordinate `i`, so the patch's outer edge lies at 10.5.
  square = patches.Rectangle(
      xy=(10.5, 10.5), width=7, height=7, fill=False, edgecolor='red'
  )
  ax[i].add_patch(square)
  ax[i].axis("off")
  ax[i].set_title(f"Sample #{i}")

plt.tight_layout()
plt.show()

## Example II - Conditional diffusion model

In the above example, we trained an *unconditional* diffusion model and applied conditioning at inference time. This is not always easy to do, depending on how the conditioning input relates to the samples.

Alternatively, we can directly *train a conditional model*, where the conditional signal is provided at training time as an additional input to the denoising neural network, which may then use it to compute the denoised target.

Below we show an example of how to accomplish this. We again generate samples of handwritten digits, using the MNIST dataset for training. We will condition the generation on the `x[11:18, 11:18]` patch.

### Dataset

Besides the sample in `x`, the dataset for training conditional models requires a `cond` key, which contains the conditioning signals.

In [None]:
def preproc_example(example: dict[str, tf.Tensor]):
  processed = {}
  processed["x"] = tf.cast(example["image"], tf.float32) / 255.0

  # The "channel:" prefix indicates that the conditioning signal is to be
  # incorporated by resizing and concatenating along the channel dimension.
  # This is implemented at the backbone level.
  processed["cond"] = {"channel:low_res": processed["x"][11:18, 11:18]}
  return processed


def get_cond_mnist_dataset(split: str, batch_size: int):
  ds = tfds.load("mnist", split=split)
  ds = ds.map(preproc_example)
  ds = ds.repeat()
  ds = ds.batch(batch_size)
  ds = ds.prefetch(tf.data.AUTOTUNE)
  ds = ds.as_numpy_iterator()
  return ds

DATA_STD = 0.31
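
The resulting batch structure can be illustrated without TensorFlow. The sketch below is a NumPy mirror of `preproc_example` applied to random stand-in images (`preproc_example_np` is hypothetical), followed by the stacking that `ds.batch(...)` performs on each leaf of the nested dictionary.

```python
import numpy as np


def preproc_example_np(image_uint8):
  """NumPy mirror of `preproc_example` for one (28, 28, 1) uint8 image."""
  x = image_uint8.astype(np.float32) / 255.0
  return {"x": x, "cond": {"channel:low_res": x[11:18, 11:18]}}


rng = np.random.default_rng(0)
fake_images = rng.integers(0, 256, size=(4, 28, 28, 1), dtype=np.uint8)
examples = [preproc_example_np(img) for img in fake_images]

# Batching stacks every leaf of the nested dict along a new leading axis.
demo_batch = {
    "x": np.stack([ex["x"] for ex in examples]),
    "cond": {
        "channel:low_res": np.stack(
            [ex["cond"]["channel:low_res"] for ex in examples]
        )
    },
}
print(demo_batch["x"].shape)                        # (4, 28, 28, 1)
print(demo_batch["cond"]["channel:low_res"].shape)  # (4, 7, 7, 1)
```

The per-example condition shape `(7, 7, 1)` is what `cond_shape` has to match when the conditional `DenoisingModel` is built.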

### Architecture

The architecture is similar to the unconditional case. We provide additional args that specify how to resize the conditioning signal (in order to be compatible with the noisy sample for channel-wise concatenation).

In [None]:
cond_denoiser_model = dfn_lib.PreconditionedDenoiser(
    out_channels=1,
    num_channels=(64, 128),
    downsample_ratio=(2, 2),
    num_blocks=4,
    noise_embed_dim=128,
    padding="SAME",
    use_attention=True,
    use_position_encoding=True,
    num_heads=8,
    sigma_data=DATA_STD,
    cond_resize_method="cubic",
    cond_embed_dim=128,
)
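
Channel-wise conditioning can be pictured as follows: the 7x7 condition is resized to the 28x28 grid of the noisy sample and appended as extra input channels. In this NumPy sketch, nearest-neighbour upsampling by an integer factor stands in for the cubic resize configured above, and `upsample_and_concat` is a hypothetical helper; the real backbone may process the condition further (for example through `cond_embed_dim`), which is not modeled here.

```python
import numpy as np


def upsample_and_concat(noisy_x: np.ndarray, cond: np.ndarray) -> np.ndarray:
  """Resizes `cond` to the spatial size of `noisy_x` and appends it as channels.

  Nearest-neighbour upsampling by an integer factor stands in for the cubic
  resize configured above.
  """
  (_, h, w, _), (_, ch, cw, _) = noisy_x.shape, cond.shape
  if h % ch or w % cw:
    raise ValueError("This sketch only handles integer upsampling factors.")
  upsampled = np.repeat(np.repeat(cond, h // ch, axis=1), w // cw, axis=2)
  return np.concatenate([noisy_x, upsampled], axis=-1)


demo_noisy = np.zeros((4, 28, 28, 1))  # noisy samples x + sigma * eps
demo_cond = np.arange(4 * 7 * 7, dtype=float).reshape(4, 7, 7, 1)  # 7x7 patches
print(upsample_and_concat(demo_noisy, demo_cond).shape)  # (4, 28, 28, 2)
```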

### Training

The `DenoisingModel` is again similar to the unconditional case. We additionally provide the shape information of the `cond` input.

In [None]:
diffusion_scheme = dfn_lib.Diffusion.create_variance_exploding(
    sigma=dfn_lib.tangent_noise_schedule(),
    data_std=DATA_STD,
)

cond_model = dfn.DenoisingModel(
    input_shape=(28, 28, 1),
    # `cond_shape` must agree with the expected structure and shape
    # (without the batch dimension) of the `cond` input.
    cond_shape={"channel:low_res": (7, 7, 1)},
    denoiser=cond_denoiser_model,
    noise_sampling=dfn_lib.log_uniform_sampling(
        diffusion_scheme, clip_min=1e-4, uniform_grid=True,
    ),
    noise_weighting=dfn_lib.edm_weighting(data_std=DATA_STD),
)

The rest mostly repeats the unconditional training example, replacing the datasets and model with their conditional counterparts.

In [None]:
# !rm -R -f $cond_workdir  # optional: clear the working directory

In [None]:
num_train_steps = 100_000  #@param
cond_workdir = "/tmp/cond_diffusion_demo_mnist"  #@param
train_batch_size = 32  #@param
eval_batch_size = 32  #@param
initial_lr = 0.0  #@param
peak_lr = 1e-4  #@param
warmup_steps = 1000  #@param
end_lr = 1e-6  #@param
ema_decay = 0.999  #@param
ckpt_interval = 1000  #@param
max_ckpt_to_keep = 5  #@param

In [None]:
cond_trainer = dfn.DenoisingTrainer(
    model=cond_model,
    rng=jax.random.PRNGKey(888),
    optimizer=optax.adam(
        learning_rate=optax.warmup_cosine_decay_schedule(
            init_value=initial_lr,
            peak_value=peak_lr,
            warmup_steps=warmup_steps,
            decay_steps=num_train_steps,
            end_value=end_lr,
        ),
    ),
    ema_decay=ema_decay,
)

templates.run_train(
    train_dataloader=get_cond_mnist_dataset(
        split="train[:75%]", batch_size=train_batch_size
    ),
    trainer=cond_trainer,
    workdir=cond_workdir,
    total_train_steps=num_train_steps,
    metric_writer=metric_writers.create_default_writer(
        cond_workdir, asynchronous=False
    ),
    metric_aggregation_steps=100,
    eval_dataloader=get_cond_mnist_dataset(
        split="train[75%:]", batch_size=eval_batch_size
    ),
    eval_every_steps=1000,
    num_batches_per_eval=2,
    callbacks=(
        templates.TqdmProgressBar(
            total_train_steps=num_train_steps,
            train_monitors=("train_loss",),
        ),
        templates.TrainStateCheckpoint(
            base_dir=cond_workdir,
            options=ocp.CheckpointManagerOptions(
                save_interval_steps=ckpt_interval, max_to_keep=max_ckpt_to_keep
            ),
        ),
    ),
)

### Inference

To perform inference/sampling, let's load back the trained conditional model checkpoint:

In [None]:
trained_state = dfn.DenoisingModelTrainState.restore_from_orbax_ckpt(
    f"{cond_workdir}/checkpoints", step=None
)
# Construct the inference function
cond_denoise_fn = dfn.DenoisingTrainer.inference_fn_from_state_dict(
    trained_state, use_ema=True, denoiser=cond_denoiser_model
)

The conditional sampler mirrors the previous example, except that the conditional denoising function replaces the unconditional one.

Below we do not apply any guidance, but guidance transforms can be added in the same way as in the unconditional example above.

In [None]:
cond_sampler = dfn_lib.SdeSampler(
    input_shape=(28, 28, 1),
    integrator=solver_lib.EulerMaruyama(),
    tspan=dfn_lib.edm_noise_decay(
        diffusion_scheme, rho=7, num_steps=256, end_sigma=1e-3,
    ),
    scheme=diffusion_scheme,
    denoise_fn=cond_denoise_fn,
    guidance_transforms=(),
    apply_denoise_at_end=True,
    return_full_paths=False,
)

We again JIT-compile the generate function to speed up repeated sampling calls. Here we use `functools.partial` to bind `num_samples_per_cond` (5) as the `num_samples` argument, so that the remaining arguments (random key, condition and guidance inputs) can be vectorized across the batch dimension with `jax.vmap`.

In [None]:
num_samples_per_cond = 5

# Use the *conditional* sampler; `num_samples` is bound as the first
# positional argument.
generate = jax.jit(
    functools.partial(cond_sampler.generate, num_samples_per_cond)
)

Load a test batch of 4 conditions:

In [None]:
batch_size = 4
test_ds = get_cond_mnist_dataset(split="test", batch_size=batch_size)
test_batch_cond = next(iter(test_ds))["cond"]

The vectorized generate function is applied to the loaded batch. Vectorization runs over the leading dimension of both the random keys and the conditions (if you are unfamiliar with vectorized operations in JAX, think of a more efficient `for` loop over the random keys and batch conditions zipped together).

In [None]:
cond_samples = jax.vmap(generate, in_axes=(0, 0, None))(
    jax.random.split(jax.random.PRNGKey(8888), batch_size),
    test_batch_cond,
    None,  # Guidance inputs = None since no guidance transforms involved
)
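
The `functools.partial` plus `jax.vmap` combination above is equivalent to the explicit loop in this NumPy sketch. `toy_generate` is a hypothetical stand-in that follows the argument order implied by the cell above (number of samples, random key, condition, guidance inputs); it uses integer seeds instead of JAX PRNG keys and builds its "samples" directly from the condition, so only the shapes are meaningful.

```python
import functools

import numpy as np


def toy_generate(num_samples, rng_seed, cond, guidance_inputs):
  """Hypothetical stand-in for `cond_sampler.generate` (same argument order)."""
  del guidance_inputs  # unused, like the `None` passed above
  rng = np.random.default_rng(rng_seed)
  # Toy "samples": the 7x7 condition upsampled to 28x28, plus a little noise.
  base = np.repeat(np.repeat(cond["channel:low_res"], 4, axis=0), 4, axis=1)
  return base[None] + 0.1 * rng.normal(size=(num_samples, *base.shape))


toy_generate_5 = functools.partial(toy_generate, 5)  # binds `num_samples`

toy_batch_cond = {"channel:low_res": np.zeros((4, 7, 7, 1))}
toy_seeds = [0, 1, 2, 3]  # JAX would use `jax.random.split(key, 4)` instead

# Loop equivalent of `jax.vmap(generate, in_axes=(0, 0, None))`: index the
# leading axis of the seeds and of every leaf of `cond`; pass `None` through.
toy_cond_samples = np.stack([
    toy_generate_5(seed, {k: v[i] for k, v in toy_batch_cond.items()}, None)
    for i, seed in enumerate(toy_seeds)
])
print(toy_cond_samples.shape)  # (4, 5, 28, 28, 1)
```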

The result `cond_samples` has shape `(batch_size, num_samples_per_cond, *input_shape)`.

In [None]:
print(cond_samples.shape)

Visualize the generated samples alongside their conditioning patch (stored under the `channel:low_res` key):

In [None]:
for i in range(batch_size):
  fig, ax = plt.subplots(1, 1, figsize=(2, 2))
  im = ax.imshow(
      test_batch_cond["channel:low_res"][i, :, :, 0] * 255,
      cmap="gray", vmin=0, vmax=255
  )
  ax.axis("off")
  ax.set_title(f"Low-res condition: #{i + 1}")


  # Plot generated samples.
  fig, ax = plt.subplots(
      1, num_samples_per_cond, figsize=(num_samples_per_cond * 2, 2)
  )
  for j in range(num_samples_per_cond):
    im = ax[j].imshow(
        cond_samples[i, j, :, :, 0] * 255, cmap="gray", vmin=0, vmax=255
    )
    # Pixel centers sit at integer coordinates, so the patch edge is at 10.5.
    square = patches.Rectangle(
        xy=(10.5, 10.5), width=7, height=7, fill=False, edgecolor='red'
    )
    ax[j].add_patch(square)
    ax[j].set_title(f"conditional sample: #{j + 1}")
    ax[j].axis("off")

  plt.tight_layout()

plt.show()