In [1]:
!pip list | grep jax
!pip list | grep flax
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import time
import requests
import functools
import jax
from jax import config
import jax.numpy as jnp
import flax
from matplotlib import pyplot as plt
import numpy as onp
import tensorflow.compat.v2 as tf
tf.enable_v2_behavior()

import diffusion_distillation
from flax.training import checkpoints as flaxcheckpoints



jax                          0.3.15
jaxlib                       0.3.15+cuda11.cudnn82
flax                         0.4.2


2024-04-19 12:20:01.773034: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-04-19 12:20:02.487663: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda-12.1/lib64:/usr/local/cuda-12.1/lib64:
2024-04-19 12:20:02.487818: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda-12.1/lib64:/usr/local/cuda-12.1/lib64:
  from .autonotebook import tqdm as notebook_tqdm


## Train a new diffusion model

In [2]:
# jax.devices()
# jax.config.FLAGS.jax_xla_backend = "gpu_driver"
print("JAX backend:", jax.lib.xla_bridge.get_backend().platform)
jax.devices()


JAX backend: gpu


[GpuDevice(id=0, process_index=0)]

In [3]:
# create model
config = diffusion_distillation.config.mnist_base.get_config()
model = diffusion_distillation.model.Model(config)

In [4]:
print(model.model)

UNet(
    # attributes
    num_classes = 1
    ch = 64
    emb_ch = 256
    out_ch = 1
    ch_mult = [1, 1, 1]
    num_res_blocks = 3
    attn_resolutions = [8, 16]
    num_heads = 1
    dropout = 0.2
    logsnr_input_type = 'inv_cos'
    logsnr_scale_range = (-10.0, 10.0)
    resblock_resample = True
    head_dim = None
)


In [5]:
print(type(model.make_optimizer_def()))
# print(type(model.make_optimizer_def().init(None)))

<class 'flax.optim.adam.Adam'>




In [6]:
# init params 
state = jax.device_get(model.make_init_state())
state = flax.jax_utils.replicate(state)

  jax.tree_map(lambda a: a.shape, init_params)))
  return sum([x.size for x in jax.tree_leaves(pytree)])
  param_states = jax.tree_map(self.init_param_state, params)
  return jax.tree_map(jnp.array, pytree)


In [7]:
type(model.make_optimizer_def())



flax.optim.adam.Adam

In [8]:
# type(state.optimizer[0]), type(state.optimizer[1])

In [9]:
model.make_init_state()



TrainState(step=0, optimizer=Optimizer(optimizer_def=<flax.optim.adam.Adam object at 0x7cc91c461520>, state=OptimizerState(step=DeviceArray(0, dtype=int32), param_states=FrozenDict({
    conv_in: {
        bias: _AdamParamState(grad_ema=DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                     0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                     0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                     0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                     0., 0., 0., 0.], dtype=float32), grad_sq_ema=DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                     0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                     0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                     0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                     0., 0., 0., 0.], dtype=float32)),
        ke

In [10]:
# JIT compile training step
train_step = functools.partial(model.step_fn, jax.random.PRNGKey(0), True)
train_step = functools.partial(jax.lax.scan, train_step)  # for substeps
train_step = jax.pmap(train_step, axis_name='batch', donate_argnums=(0,))

In [11]:
# build input pipeline
total_bs = config.train.batch_size
device_bs = total_bs // jax.device_count()
train_ds = model.dataset.get_shuffled_repeated_dataset(
    split='train',
    batch_shape=(
        jax.local_device_count(),  # for pmap
        config.train.substeps,  # for lax.scan over multiple substeps
        device_bs,  # batch size per device
    ),
    local_rng=jax.random.PRNGKey(0),
    augment=True)
train_iter = diffusion_distillation.utils.numpy_iter(train_ds)

(28, 28, 1)


In [12]:
config.model.mean_type

'x'

In [21]:
# run training
for step in range(1000):
  batch = next(train_iter)
  state, metrics = train_step(state, batch)
  print(flaxcheckpoints.save_checkpoint('/tmp/flax_ckpt/checkpoint/', state, step, overwrite=True))
  metrics = jax.device_get(flax.jax_utils.unreplicate(metrics))
  metrics = jax.tree_map(lambda x: float(x.mean(axis=0)), metrics)
  print(metrics)

/tmp/flax_ckpt/checkpoint/checkpoint_0
{'train/gnorm': 1715.4888916015625, 'train/loss': 0.9441455006599426}


  metrics = jax.tree_map(lambda x: float(x.mean(axis=0)), metrics)


/tmp/flax_ckpt/checkpoint/checkpoint_1
{'train/gnorm': 802.95361328125, 'train/loss': 1.059767246246338}
/tmp/flax_ckpt/checkpoint/checkpoint_2
{'train/gnorm': 1298.5948486328125, 'train/loss': 1.428767442703247}
/tmp/flax_ckpt/checkpoint/checkpoint_3
{'train/gnorm': 458.55010986328125, 'train/loss': 0.31230902671813965}
/tmp/flax_ckpt/checkpoint/checkpoint_4
{'train/gnorm': 46.45166778564453, 'train/loss': 0.15292471647262573}
/tmp/flax_ckpt/checkpoint/checkpoint_5
{'train/gnorm': 1037.1522216796875, 'train/loss': 0.9696954488754272}
/tmp/flax_ckpt/checkpoint/checkpoint_6
{'train/gnorm': 517.4755249023438, 'train/loss': 0.6376945972442627}
/tmp/flax_ckpt/checkpoint/checkpoint_7
{'train/gnorm': 170.34500122070312, 'train/loss': 0.19640035927295685}
/tmp/flax_ckpt/checkpoint/checkpoint_8
{'train/gnorm': 46.166900634765625, 'train/loss': 0.1274232268333435}
/tmp/flax_ckpt/checkpoint/checkpoint_9
{'train/gnorm': 1211.931884765625, 'train/loss': 1.065319299697876}
/tmp/flax_ckpt/checkpoint

: 

In [14]:
# get all attr
#print(dir(state))
#for k, v in dict(state.ema_params).items():
#    print(k, type(v))
#print()
print(state.num_sample_steps)
print(model.make_init_state().num_sample_steps)
#print(state.optimizer)
#print(state.replace)
#print(state.step)

[0]




0


# ❗ The following Code do not need to be run

## Distill a trained diffusion model

In [15]:
# create model
import diffusion_distillation
config = diffusion_distillation.config.mnist_distill.get_config()
model = diffusion_distillation.model.Model(config)

In [16]:
# from flax import serialization

# def restore_from_path(ckpt_path, target):
#   with open(ckpt_path, 'rb') as fp:
#     return serialization.from_bytes(target, fp.read())
# loaded = flaxcheckpoints.restore_checkpoint("/tmp/flax_ckpt/checkpoint/checkpoint_0", None)
# print(loaded)

In [17]:
config.model.mean_type

'x'

In [18]:
print(model.model)

UNet(
    # attributes
    num_classes = 1
    ch = 64
    emb_ch = 256
    out_ch = 1
    ch_mult = [1, 1, 1]
    num_res_blocks = 3
    attn_resolutions = [8, 16]
    num_heads = 1
    dropout = 0.2
    logsnr_input_type = 'inv_cos'
    logsnr_scale_range = (-10.0, 10.0)
    resblock_resample = True
    head_dim = None
)


In [19]:
# load the teacher params
model.load_teacher_state("/tmp/flax_ckpt/checkpoint/checkpoint_20")



TypeError: 'NoneType' object is not subscriptable

In [None]:
# init student state
init_params = diffusion_distillation.utils.copy_pytree(model.teacher_state.ema_params)
optim = model.make_optimizer_def().create(init_params)
state = diffusion_distillation.model.TrainState(
    step=model.teacher_state.step,
    optimizer=optim,
    ema_params=diffusion_distillation.utils.copy_pytree(init_params),
    num_sample_steps=model.teacher_state.num_sample_steps//2)



In [None]:
# build input pipeline
total_bs = config.train.batch_size
device_bs = total_bs // jax.device_count()
train_ds = model.dataset.get_shuffled_repeated_dataset(
    split='train',
    batch_shape=(
        jax.local_device_count(),  # for pmap
        config.train.substeps,  # for lax.scan over multiple substeps
        device_bs,  # batch size per device
    ),
    local_rng=jax.random.PRNGKey(0),
    augment=True)
train_iter = diffusion_distillation.utils.numpy_iter(train_ds)

(28, 28, 1)


In [None]:
steps_per_distill_iter = 10  # number of distillation steps per iteration of progressive distillation
end_num_steps = 4  # eventual number of sampling steps we want to use 
while state.num_sample_steps >= end_num_steps:

  # compile training step
  train_step = functools.partial(model.step_fn, jax.random.PRNGKey(0), True)
  train_step = functools.partial(jax.lax.scan, train_step)  # for substeps
  train_step = jax.pmap(train_step, axis_name='batch', donate_argnums=(0,))

  # train the student against the teacher model
  print('distilling teacher using %d sampling steps into student using %d steps'
        % (model.teacher_state.num_sample_steps, state.num_sample_steps))
  state = flax.jax_utils.replicate(state)
  for step in range(steps_per_distill_iter):
    batch = next(train_iter)
    state, metrics = train_step(state, batch)
    metrics = jax.device_get(flax.jax_utils.unreplicate(metrics))
    metrics = jax.tree_map(lambda x: float(x.mean(axis=0)), metrics)
    print(metrics)

  # student becomes new teacher for next distillation iteration
  model.teacher_state = jax.device_get(
      flax.jax_utils.unreplicate(state).replace(optimizer=None))

  # reset student optimizer for next distillation iteration
  init_params = diffusion_distillation.utils.copy_pytree(model.teacher_state.ema_params)
  optim = model.make_optimizer_def().create(init_params)
  state = diffusion_distillation.model.TrainState(
      step=model.teacher_state.step,
      optimizer=optim,
      ema_params=diffusion_distillation.utils.copy_pytree(init_params),
      num_sample_steps=model.teacher_state.num_sample_steps//2)

distilling teacher using 8192 sampling steps into student using 4096 steps
