In [1]:
import tensorflow as tf
from models import utils as mutils
import jax
from jax import numpy as jnp
from configs.ve.disk_ssim import get_config
from models import super_simple

import losses
import sde_lib

import functools
import sampling
import datasets
from flax import jax_utils as flax_utils
from models import layers, layerspp

from flax import linen as nn
from matplotlib import pyplot as plt
from flax.training import checkpoints

# Notebook-level shorthands for two helpers exposed by `models.layers`.
default_init, get_act = layers.default_init, layers.get_act



In [2]:
def get_step_fn(sde, model, train, optimize_fn=None, reduce_mean=False, continuous=True, likelihood_weighting=False):
  """Create a one-step training/evaluation function.

  Args:
    sde: An `sde_lib.SDE` object that represents the forward SDE.
    model: A `flax.linen.Module` object that represents the architecture of the score-based model.
    train: `True` for training and `False` for evaluation.
    optimize_fn: An optimization function, called as `optimize_fn(state, grad)`. Required when
      `train` is `True`; not used for evaluation.
    reduce_mean: If `True`, average the loss across data dimensions. Otherwise sum the loss across data dimensions.
    continuous: `True` indicates that the model is defined to take continuous time steps.
    likelihood_weighting: If `True`, weight the mixture of score matching losses according to
      https://arxiv.org/abs/2101.09258; otherwise use the weighting recommended by our paper.

  Returns:
    A one-step function for training or evaluation.

  Raises:
    NotImplementedError: If `continuous` is `False` (only the continuous loss is wired up here).
    ValueError: If `train` is `True` but `optimize_fn` is `None`.
  """
  if not continuous:
    raise NotImplementedError()
  if train and optimize_fn is None:
    # Fail when the step function is built, rather than with an opaque
    # "'NoneType' object is not callable" in the middle of the first training step.
    raise ValueError("`optimize_fn` must be provided when `train` is True.")

  loss_fn = losses.get_sde_loss_fn(sde, model, train, reduce_mean=reduce_mean,
                                   continuous=True, likelihood_weighting=likelihood_weighting)

  def step_fn(carry_state, batch):
    """Running one step of training or evaluation.

    This function will undergo `jax.lax.scan` so that multiple steps can be pmapped and jit-compiled together
    for faster execution.

    Args:
      carry_state: A tuple (JAX random state, `flax.struct.dataclass` containing the training state).
      batch: A mini-batch of training/evaluation data.

    Returns:
      new_carry_state: The updated tuple of `carry_state`.
      loss: The average loss value of this state.
    """
    rng, state = carry_state
    # Keep one half of the key for the next step and spend the other half on this step.
    rng, step_rng = jax.random.split(rng)

    if train:
      # Differentiate w.r.t. argument 1 (the parameters); `has_aux=True` because
      # `loss_fn` returns `(loss, new_model_state)`.
      grad_fn = jax.value_and_grad(loss_fn, argnums=1, has_aux=True)
      params = state.optimizer.target
      states = state.model_state
      (loss, new_model_state), grad = grad_fn(step_rng, params, states, batch)
      new_optimizer = optimize_fn(state, grad)
      # Exponential moving average of the parameters, leaf by leaf.
      # NOTE(review): `jax.tree_multimap` is kept for the JAX version this notebook targets;
      # newer JAX releases replace it with `jax.tree_util.tree_map` -- confirm before upgrading.
      new_params_ema = jax.tree_multimap(
        lambda p_ema, p: p_ema * state.ema_rate + p * (1. - state.ema_rate),
        state.params_ema, new_optimizer.target
      )
      new_state = state.replace(
        step=state.step + 1,
        optimizer=new_optimizer,
        model_state=new_model_state,
        params_ema=new_params_ema
      )
    else:
      # Evaluation uses the EMA parameters and leaves the training state untouched.
      loss, _ = loss_fn(step_rng, state.params_ema, state.model_state, batch)
      new_state = state

    return (rng, new_state), loss

  return step_fn


In [3]:
# Element spec of one stored example: a float32 `image` of shape (1, 1, 2)
# and a scalar int32 `label`.
spec = dict(
    image=tf.TensorSpec(shape=(1, 1, 2), dtype=tf.float32),
    label=tf.TensorSpec(shape=(), dtype=tf.int32),
)
# Reload the dataset previously saved under `disk/data`.
new_ds = tf.data.experimental.load('disk/data', spec)

In [4]:
# Fixed PRNG key used later for the manual loss/gradient evaluation.
step_rng = jax.random.PRNGKey(42)
config = get_config()

# Pull a single example out of the dataset loaded from disk.
train_ds = new_ds
train_iter = iter(train_ds)
temp = next(train_iter)
x = temp['image']

# One time value drawn uniformly and scaled by 999; it doubles as the `labels`
# input of the model. The key is built from the same seed as `step_rng`.
ts = labels = jax.random.uniform(jax.random.PRNGKey(42), shape=(1,)) * 999

# Key used for model initialisation and dropout in the following cells.
rng = jax.random.PRNGKey(24)



In [5]:
score_model = super_simple.SSimple(config=config)
# The same key seeds both the parameter initialiser and the dropout stream.
variables = score_model.init({'params': rng, 'dropout': rng}, x, ts)
# Variables is a `flax.FrozenDict`. It is immutable and respects functional programming
init_model_state, initial_params = variables.pop('params')
# Bind the forward-pass result to `out` rather than `x`: overwriting the input made this
# cell non-idempotent (a second run would feed the model its own output).
# NOTE: despite its name, `params` receives the second value returned by `apply` when
# `mutable` is set (per the Flax docs, the updated mutable collections), not the
# trainable parameters, which live in `initial_params`.
out, params = score_model.apply(
    variables, x, labels,
    train=True,
    mutable=list(init_model_state.keys()),
    rngs={'dropout': rng},
)
out.shape

(1, 1, 1, 2)

In [6]:
# Build the train/eval input pipelines described by the config.
# NOTE(review): `additional_dim` presumably adds an axis of length `n_jitted_steps`
# to every batch -- confirm in `datasets.get_dataset`.
train_ds, eval_ds, _ = datasets.get_dataset(
    config,
    additional_dim=config.training.n_jitted_steps,
    uniform_dequantization=config.data.uniform_dequantization,
)

# Fetch the first batch and convert every leaf of it with `._numpy()`.
train_iter = iter(train_ds)
temp = next(train_iter)
batch = jax.tree_map(lambda leaf: leaf._numpy(), temp)

In [7]:
# Take the slice at index [0, 0] of the image batch as model input.
# NOTE(review): presumably the first jitted step on the first device -- confirm against `datasets.get_dataset`.
x = batch['image'][0, 0]
# 64 time values drawn uniformly at random.
t = jax.random.uniform(rng, shape=(64,))
# Forward pass in training mode; the collections named in `init_model_state` are mutable.
out, params = score_model.apply(
    variables,
    x,
    t,
    train=True,
    mutable=list(init_model_state.keys()),
    rngs={'dropout': rng},
)
out.shape

(64, 1, 1, 2)

In [11]:
# Fresh key stored inside the training state.
rng = jax.random.PRNGKey(41)

# Optimizer built from the config and bound to the freshly initialised parameters.
optimizer = losses.get_optimizer(config).create(initial_params)

# Training state at step 0; the EMA parameters start out equal to the initial parameters.
state = mutils.State(
    step=0,
    optimizer=optimizer,
    lr=config.optim.lr,
    model_state=init_model_state,
    ema_rate=config.model.ema_rate,
    params_ema=initial_params,
    rng=rng,
)

In [12]:
# Parameters and mutable model state taken from the training state built above.
params = state.optimizer.target
states = state.model_state

# Forward SDE (variance-exploding) configured from the model section of the config.
sde = sde_lib.VESDE(
    sigma_min=config.model.sigma_min,
    sigma_max=config.model.sigma_max,
    N=config.model.num_scales,
)
sampling_eps = 1e-5

# Training options read from the config.
optimize_fn = losses.optimization_manager(config)
continuous = config.training.continuous
reduce_mean = config.training.reduce_mean
likelihood_weighting = config.training.likelihood_weighting

n_jitted_steps = config.training.n_jitted_steps
# Must be divisible by the number of steps jitted together
assert all(
    getattr(config.training, freq_name) % n_jitted_steps == 0
    for freq_name in ('log_freq', 'snapshot_freq_for_preemption', 'eval_freq', 'snapshot_freq')
), "Missing logs or checkpoints!"

# Rebuild the input pipelines and pull one batch.
train_ds, eval_ds, _ = datasets.get_dataset(
    config,
    additional_dim=config.training.n_jitted_steps,
    uniform_dequantization=config.data.uniform_dequantization,
)
train_iter = iter(train_ds)
temp = next(train_iter)
# Mid-cell expression: indexes one image/label pair, but the result is discarded
# because only a cell's last expression is displayed.
(temp['image'][0, 0, 0], temp['label'][0, 0, 0])
batch = jax.tree_map(lambda leaf: leaf._numpy(), temp)

# Training loss and its value-and-gradient w.r.t. argument 1 (the parameters);
# `has_aux=True` because the loss returns `(loss, new_model_state)`.
loss_fn = losses.get_sde_loss_fn(
    sde,
    score_model,
    True,
    reduce_mean=reduce_mean,
    continuous=True,
    likelihood_weighting=likelihood_weighting,
)
grad_fn = jax.value_and_grad(loss_fn, argnums=1, has_aux=True)

# Evaluate once, reusing `step_rng` from an earlier cell.
(loss, new_model_state), grad = grad_fn(step_rng, params, states, batch)

In [13]:
# Build the notebook-local training step and run it once eagerly (no pmap/scan).
train_step_fn = get_step_fn(
    sde,
    score_model,
    train=True,
    optimize_fn=optimize_fn,
    reduce_mean=reduce_mean,
    continuous=continuous,
    likelihood_weighting=likelihood_weighting,
)
# `a` is the new (rng, state) carry and `b` the loss returned by the step.
a, b = train_step_fn((rng, state), batch)

In [14]:
# Select the forward SDE named in the config, together with its `sampling_eps`.
if config.training.sde.lower() == 'vpsde':
    sde = sde_lib.VPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales)
    sampling_eps = 1e-3
elif config.training.sde.lower() == 'subvpsde':
    sde = sde_lib.subVPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales)
    sampling_eps = 1e-3
elif config.training.sde.lower() == 'vesde':
    sde = sde_lib.VESDE(sigma_min=config.model.sigma_min, sigma_max=config.model.sigma_max, N=config.model.num_scales)
    sampling_eps = 1e-5
else:
    raise NotImplementedError(f"SDE {config.training.sde} unknown.")
inverse_scaler = datasets.get_data_inverse_scaler(config)


# Build one-step training and evaluation functions
# NOTE(review): these come from `losses.get_step_fn`, not from the `get_step_fn` defined
# earlier in this notebook -- confirm that is intended.
optimize_fn = losses.optimization_manager(config)
continuous = config.training.continuous
reduce_mean = config.training.reduce_mean
likelihood_weighting = config.training.likelihood_weighting
train_step_fn = losses.get_step_fn(sde, score_model, train=True, optimize_fn=optimize_fn,
                                   reduce_mean=reduce_mean, continuous=continuous,
                                   likelihood_weighting=likelihood_weighting)
# Pmap (and jit-compile) multiple training steps together for faster running
p_train_step = jax.pmap(functools.partial(jax.lax.scan, train_step_fn), axis_name='batch', donate_argnums=1)
eval_step_fn = losses.get_step_fn(sde, score_model, train=False, optimize_fn=optimize_fn,
                                  reduce_mean=reduce_mean, continuous=continuous,
                                  likelihood_weighting=likelihood_weighting)
# Pmap (and jit-compile) multiple evaluation steps together for faster running
p_eval_step = jax.pmap(functools.partial(jax.lax.scan, eval_step_fn), axis_name='batch', donate_argnums=1)

# Building sampling functions
# `sampling_shape` only exists when snapshot sampling is enabled, so the sampler is built
# under the same guard; building it unconditionally raises NameError on a fresh kernel
# whenever `snapshot_sampling` is off.
if config.training.snapshot_sampling:
    sampling_shape = (config.training.batch_size // jax.local_device_count(), config.data.image_size,
                      config.data.image_size, config.data.num_channels)
    sampling_fn = sampling.get_sampling_fn(config, sde, score_model, sampling_shape, inverse_scaler, sampling_eps)

# Replicate the training state to run on multiple devices
pstate = flax_utils.replicate(state)
num_train_steps = config.training.n_iters

# Fold the host index into the key so that, with multiple hosts (e.g., TPU pods),
# each host continues from a different random key.
rng = jax.random.fold_in(rng, jax.host_id())

# JIT multiple training steps together for faster training
n_jitted_steps = config.training.n_jitted_steps
# Must be divisible by the number of steps jitted together
assert config.training.log_freq % n_jitted_steps == 0 and \
        config.training.snapshot_freq_for_preemption % n_jitted_steps == 0 and \
        config.training.eval_freq % n_jitted_steps == 0 and \
        config.training.snapshot_freq % n_jitted_steps == 0, "Missing logs or checkpoints!"
train_ds, eval_ds, _ = datasets.get_dataset(config,
                                            additional_dim=config.training.n_jitted_steps,
                                            uniform_dequantization=config.data.uniform_dequantization)
train_iter = iter(train_ds)
temp = next(train_iter)
batch = jax.tree_map(lambda x: x._numpy(), temp)
# One fresh key per local device, stacked into a single array so it can be pmapped over.
rng, *next_rng = jax.random.split(rng, num=jax.local_device_count() + 1)
next_rng = jnp.asarray(next_rng)
# Execute one training step
(_, pstate), ploss = p_train_step((next_rng, pstate), batch)

