# Diffusion probabilistic models - Score matching

## Score matching

In [2]:
%matplotlib inline
import functools
import math
import string
from typing import Any, Tuple, Optional, Sequence
import flax.linen as nn
import jax
import jax.nn as jnn
import jax.numpy as jnp
import jax.nn.initializers as init
import matplotlib as mpl
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import flax
Path("outputs").mkdir(exist_ok=True)

In [5]:
"""Layers for defining NCSN++.
"""
# Function ported from StyleGAN2
def get_weight(module, shape, weight_var='weight', kernel_init=None):
  """Create (or fetch) the weight parameter of a conv / dense layer.

  Thin wrapper that forwards to `module.param`.

  Args:
    module: Object exposing `param(name, init_fn, shape)`.
    shape: Shape handed to `module.param`.
    weight_var: Name under which the parameter is registered.
    kernel_init: Initializer handed to `module.param`.

  Returns:
    Whatever `module.param` returns for the requested parameter.
  """
  weight = module.param(weight_var, kernel_init, shape)
  return weight


class Conv2d(nn.Module):
  """2-D convolution with optional fused FIR up/downsampling (StyleGAN2).

  With `up` set the input goes through `upsample_conv_2d`, with `down` set
  through `conv_downsample_2d`; otherwise a stride-1 'SAME' convolution is
  applied. Inputs are treated as NHWC.
  """
  fmaps: int  # number of output feature maps
  kernel: int  # side length of the square kernel; must be odd and >= 1
  up: bool = False  # fuse a x2 upsampling into the convolution
  down: bool = False  # fuse a x2 downsampling into the convolution
  resample_kernel: Tuple[int] = (1, 3, 3, 1)  # FIR filter for resampling
  use_bias: bool = True
  weight_var: str = 'weight'  # parameter name of the kernel
  kernel_init: Optional[Any] = None

  @nn.compact
  def __call__(self, x):
    assert not (self.up and self.down)
    assert self.kernel >= 1 and self.kernel % 2 == 1

    # HWIO kernel; the input channel count is read from the incoming tensor.
    kernel_shape = (self.kernel, self.kernel, x.shape[-1], self.fmaps)
    w = self.param(self.weight_var, self.kernel_init, kernel_shape)

    if self.up:
      out = upsample_conv_2d(x, w, data_format='NHWC', k=self.resample_kernel)
    elif self.down:
      out = conv_downsample_2d(x, w, data_format='NHWC', k=self.resample_kernel)
    else:
      out = jax.lax.conv_general_dilated(
        x,
        w,
        window_strides=(1, 1),
        padding='SAME',
        dimension_numbers=('NHWC', 'HWIO', 'NHWC'))

    if not self.use_bias:
      return out

    # One zero-initialised bias per output channel, broadcast over N, H, W.
    bias = self.param('bias', jnn.initializers.zeros, (out.shape[-1],))
    return out + bias.reshape((1, 1, 1, -1))


def naive_upsample_2d(x, factor=2):
  """Nearest-neighbour upsampling of an NHWC batch.

  Every pixel is repeated `factor` times along both spatial axes.

  Args:
    x: 4-D array unpacked as `(N, H, W, C)`.
    factor: Integer repetition factor for H and W.

  Returns:
    Array of shape `(N, H * factor, W * factor, C)`.
  """
  batch, height, width, channels = x.shape
  # Add singleton axes after H and W and broadcast them to `factor` copies.
  repeated = jnp.broadcast_to(
    x[:, :, None, :, None, :],
    (batch, height, factor, width, factor, channels))
  return jnp.reshape(repeated, [-1, height * factor, width * factor, channels])


def naive_downsample_2d(x, factor=2):
  """Average-pool an NHWC batch over non-overlapping `factor` x `factor` windows.

  Args:
    x: 4-D array unpacked as `(N, H, W, C)`.
    factor: Integer pooling factor; must divide both H and W.

  Returns:
    Array of shape `(N, H // factor, W // factor, C)` holding window means.

  Raises:
    ValueError: If H or W is not a multiple of `factor`.
  """
  batch, height, width, channels = x.shape
  # Without this check a reshape with an inferred batch axis can silently
  # succeed (when the total size happens to fit) and mix pixels of different
  # images into one pooling window.
  if height % factor != 0 or width % factor != 0:
    raise ValueError(
      f'Spatial dims ({height}, {width}) must be divisible by factor {factor}.')
  windows = jnp.reshape(
    x, (batch, height // factor, factor, width // factor, factor, channels))
  return jnp.mean(windows, axis=(2, 4))


def upsample_conv_2d(x, w, k=None, factor=2, gain=1, data_format='NHWC'):
  """Fused upsampling + convolution (JAX port of the StyleGAN2 op).

  A strided transposed convolution with `w` is followed by one FIR filtering
  pass with `k`; all padding/cropping happens in that final pass rather than
  between the two operations.

  Args:
    x:            Input tensor laid out according to `data_format`.
    w:            Weight tensor of the shape `[filterH, filterW, inChannels,
      outChannels]` with `filterH == filterW`. The number of groups is derived
      as `x_channels // inChannels`.
    k:            FIR filter of the shape `[firH, firW]` or `[firN]`
      (separable). `None` selects `[1] * factor`.
    factor:       Integer upsampling factor (default: 2).
    gain:         Scaling factor for signal magnitude (default: 1).
    data_format:  `'NCHW'` or `'NHWC'` (default: `'NHWC'`).

  Returns:
    The upsampled, convolved tensor in the same layout as `x`.
  """
  assert isinstance(factor, int) and factor >= 1

  # The weight must be a square 4-D kernel.
  assert len(w.shape) == 4
  conv_h, conv_w, in_ch, _out_ch = w.shape
  assert conv_w == conv_h

  # FIR kernel, scaled so that zero-stuffing does not change the magnitude.
  fir = [1] * factor if k is None else k
  fir = _setup_kernel(fir) * (gain * (factor ** 2))
  pad = (fir.shape[0] - factor) - (conv_w - 1)

  # Group count from the channel axis of the input.
  channel_axis = 1 if data_format == 'NCHW' else 3
  groups = x.shape[channel_axis] // in_ch

  # Flip spatially and swap the in/out channel roles for the transposed conv.
  w = jnp.reshape(w, [conv_h, conv_w, in_ch, groups, -1])
  w = jnp.transpose(w[::-1, ::-1], [0, 1, 4, 3, 2])
  w = jnp.reshape(w, [conv_h, conv_w, -1, groups * in_ch])

  # JAX counterpart of the original
  # tf.nn.conv2d_transpose(x, w, strides=stride, padding='VALID').
  x = jax.lax.conv_transpose(
    x,
    w,
    strides=[factor, factor],
    padding='VALID',
    transpose_kernel=True,
    dimension_numbers=(data_format, 'HWIO', data_format))

  return _simple_upfirdn_2d(
    x,
    fir,
    pad0=(pad + 1) // 2 + factor - 1,
    pad1=pad // 2 + 1,
    data_format=data_format)


def conv_downsample_2d(x, w, k=None, factor=2, gain=1, data_format='NHWC'):
  """Fused convolution + downsampling (JAX port of the StyleGAN2 op).

  The input is first FIR-filtered (and padded) with `k`, then convolved with
  `w` using a stride of `factor` and 'VALID' padding, so padding is applied
  only once.

  Args:
    x:            Input tensor laid out according to `data_format`.
    w:            Weight tensor of the shape `[filterH, filterW, inChannels,
      outChannels]` with `filterH == filterW`.
    k:            FIR filter of the shape `[firH, firW]` or `[firN]`
      (separable). `None` selects `[1] * factor`.
    factor:       Integer downsampling factor (default: 2).
    gain:         Scaling factor for signal magnitude (default: 1).
    data_format:  `'NCHW'` or `'NHWC'` (default: `'NHWC'`).

  Returns:
    The filtered, strided-convolved tensor in the same layout as `x`.
  """
  assert isinstance(factor, int) and factor >= 1
  conv_h, conv_w, _, _ = w.shape
  assert conv_w == conv_h

  fir = _setup_kernel([1] * factor if k is None else k) * gain
  # Total padding accounts for both the FIR footprint and the conv footprint.
  pad = (fir.shape[0] - factor) + (conv_w - 1)

  blurred = _simple_upfirdn_2d(
    x, fir, pad0=(pad + 1) // 2, pad1=pad // 2, data_format=data_format)

  return jax.lax.conv_general_dilated(
    blurred,
    w,
    window_strides=[factor, factor],
    padding='VALID',
    dimension_numbers=(data_format, 'HWIO', data_format))


def upfirdn_2d(x, k, upx, upy, downx, downy, padx0, padx1, pady0, pady1):
  """Zero-stuff upsample, pad, FIR-filter and decimate a batch of 2D images.

  Operates on `[majorDim, inH, inW, minorDim]` input, independently for every
  image and every channel, in this order:

    1. Upsample by inserting `upy - 1` / `upx - 1` zeros after each pixel.
    2. Zero-pad each side by `pady0`, `pady1`, `padx0`, `padx1` pixels; a
       negative amount crops that side instead.
    3. Filter with the 2D FIR kernel `k` (flipped, then applied with 'VALID'
       padding, so the output shrinks by `kernel - 1` along each axis).
    4. Keep every `downy`-th row and `downx`-th column.

  The sequence resembles scipy.signal.upfirdn().

  Args:
    x:      Input tensor of the shape `[majorDim, inH, inW, minorDim]`.
    k:      2D FIR filter of the shape `[firH, firW]`.
    upx:    Integer upsampling factor along the X-axis.
    upy:    Integer upsampling factor along the Y-axis.
    downx:  Integer downsampling factor along the X-axis.
    downy:  Integer downsampling factor along the Y-axis.
    padx0:  Integer number of pixels to pad on the left side.
    padx1:  Integer number of pixels to pad on the right side.
    pady0:  Integer number of pixels to pad on the top side.
    pady1:  Integer number of pixels to pad on the bottom side.

  Returns:
    Tensor of the shape `[majorDim, outH, outW, minorDim]`.
  """
  k = jnp.asarray(k, dtype=np.float32)
  assert len(x.shape) == 4
  _, in_h, in_w, channels = x.shape
  fir_h, fir_w = k.shape
  assert in_w >= 1 and in_h >= 1
  assert fir_w >= 1 and fir_h >= 1
  assert isinstance(upx, int) and isinstance(upy, int)
  assert isinstance(downx, int) and isinstance(downy, int)
  assert isinstance(padx0, int) and isinstance(padx1, int)
  assert isinstance(pady0, int) and isinstance(pady1, int)

  up_h, up_w = in_h * upy, in_w * upx
  padded_h = up_h + pady0 + pady1
  padded_w = up_w + padx0 + padx1

  # 1. Zero stuffing: give every pixel its own (1, 1) cell and grow the cell.
  x = jnp.reshape(x, (-1, in_h, 1, in_w, 1, channels))
  x = jnp.pad(
    x, ((0, 0), (0, 0), (0, upy - 1), (0, 0), (0, upx - 1), (0, 0)))
  x = jnp.reshape(x, (-1, up_h, up_w, channels))

  # 2. Positive pads are applied first, negative pads are then taken as crops.
  x = jnp.pad(x, ((0, 0),
                  (max(pady0, 0), max(pady1, 0)),
                  (max(padx0, 0), max(padx1, 0)),
                  (0, 0)))
  rows = slice(max(-pady0, 0), x.shape[1] - max(-pady1, 0))
  cols = slice(max(-padx0, 0), x.shape[2] - max(-padx1, 0))
  x = x[:, rows, cols, :]

  # 3. Fold channels into the batch axis so one single-channel kernel filters
  #    every channel; the kernel is flipped to get a true convolution.
  x = jnp.transpose(x, (0, 3, 1, 2))
  x = jnp.reshape(x, (-1, 1, padded_h, padded_w))
  flipped = jnp.array(k[::-1, ::-1, None, None], dtype=x.dtype)
  x = jax.lax.conv_general_dilated(
    x,
    flipped,
    window_strides=(1, 1),
    padding='VALID',
    dimension_numbers=('NCHW', 'HWIO', 'NCHW'))
  x = jnp.reshape(
    x, (-1, channels, padded_h - fir_h + 1, padded_w - fir_w + 1))
  x = jnp.transpose(x, (0, 2, 3, 1))

  # 4. Decimate.
  return x[:, ::downy, ::downx, :]


def _simple_upfirdn_2d(x, k, up=1, down=1, pad0=0, pad1=0, data_format='NCHW'):
  """Call `upfirdn_2d` with the same settings on both spatial axes.

  Args:
    x: 4-D input laid out according to `data_format`.
    k: 2D FIR filter forwarded to `upfirdn_2d`.
    up: Upsampling factor used for both X and Y.
    down: Downsampling factor used for both X and Y.
    pad0: Padding used for both the left and the top side.
    pad1: Padding used for both the right and the bottom side.
    data_format: `'NCHW'` or `'NHWC'` (default: `'NCHW'`).

  Returns:
    The result of `upfirdn_2d`, restored to the layout of `x`.
  """
  assert data_format in ['NCHW', 'NHWC']
  assert len(x.shape) == 4
  channels_first = data_format == 'NCHW'

  # For NCHW, fold the channels into the batch axis so that `upfirdn_2d`
  # receives single-channel images with a trailing channel axis.
  y = jnp.reshape(x, [-1, x.shape[2], x.shape[3], 1]) if channels_first else x
  y = upfirdn_2d(
    y,
    k,
    upx=up,
    upy=up,
    downx=down,
    downy=down,
    padx0=pad0,
    padx1=pad1,
    pady0=pad0,
    pady1=pad1)
  if channels_first:
    # Unfold the channels again, keeping the new spatial size.
    y = jnp.reshape(y, [-1, x.shape[1], y.shape[1], y.shape[2]])
  return y


def _setup_kernel(k):
  k = np.asarray(k, dtype=np.float32)
  if k.ndim == 1:
    k = np.outer(k, k)
  k /= np.sum(k)
  assert k.ndim == 2
  assert k.shape[0] == k.shape[1]
  return k


def _shape(x, dim):
  return x.shape[dim]


def upsample_2d(x, k=None, factor=2, gain=1, data_format='NHWC'):
  r"""Upsample a batch of 2D images with the given FIR filter.

    The filter is normalized by `_setup_kernel` and then scaled by
    `gain * factor ** 2`, which compensates for the zeros inserted during
    upsampling. Filtering and padding are delegated to `_simple_upfirdn_2d`.

    Args:
        x:            Input tensor laid out according to `data_format`.
        k:            FIR filter of the shape `[firH, firW]` or `[firN]`
          (separable). `None` selects `[1] * factor`.
        factor:       Integer upsampling factor (default: 2).
        gain:         Scaling factor for signal magnitude (default: 1).
        data_format:  `'NCHW'` or `'NHWC'` (default: `'NHWC'`).

    Returns:
        The upsampled tensor in the same layout as `x`.
  """
  assert isinstance(factor, int) and factor >= 1
  taps = [1] * factor if k is None else k
  fir = _setup_kernel(taps) * (gain * (factor ** 2))
  pad = fir.shape[0] - factor
  return _simple_upfirdn_2d(
    x,
    fir,
    up=factor,
    pad0=(pad + 1) // 2 + factor - 1,
    pad1=pad // 2,
    data_format=data_format)


def downsample_2d(x, k=None, factor=2, gain=1, data_format='NHWC'):
  r"""Downsample a batch of 2D images with the given FIR filter.

    The filter is normalized by `_setup_kernel` and scaled by `gain`.
    Filtering, padding and decimation are delegated to `_simple_upfirdn_2d`.

    Args:
        x:            Input tensor laid out according to `data_format`.
        k:            FIR filter of the shape `[firH, firW]` or `[firN]`
          (separable). `None` selects `[1] * factor`.
        factor:       Integer downsampling factor (default: 2).
        gain:         Scaling factor for signal magnitude (default: 1).
        data_format:  `'NCHW'` or `'NHWC'` (default: `'NHWC'`).

    Returns:
        The downsampled tensor in the same layout as `x`.
  """
  assert isinstance(factor, int) and factor >= 1
  taps = [1] * factor if k is None else k
  fir = _setup_kernel(taps) * gain
  pad = fir.shape[0] - factor
  return _simple_upfirdn_2d(
    x,
    fir,
    down=factor,
    pad0=(pad + 1) // 2,
    pad1=pad // 2,
    data_format=data_format)

def ddpm_conv1x1(x, out_planes, stride=1, bias=True, dilation=1, init_scale=1.):
  """Apply a 1x1 `nn.Conv` with DDPM initialization to `x`.

  Args:
    x: Input passed to the convolution.
    out_planes: Number of output features.
    stride: Stride used along both spatial axes.
    bias: Whether the convolution has a bias term (zero-initialised).
    dilation: Kernel dilation used along both spatial axes.
    init_scale: Scale handed to `default_init` for the kernel.

  Returns:
    The convolution output.
  """
  conv = nn.Conv(
    features=out_planes,
    kernel_size=(1, 1),
    strides=(stride, stride),
    padding='SAME',
    use_bias=bias,
    kernel_dilation=(dilation, dilation),
    kernel_init=default_init(init_scale),
    bias_init=jnn.initializers.zeros)
  return conv(x)


def ddpm_conv3x3(x, out_planes, stride=1, bias=True, dilation=1, init_scale=1.):
  """Apply a 3x3 `nn.Conv` with DDPM initialization to `x`.

  Args:
    x: Input passed to the convolution.
    out_planes: Number of output features.
    stride: Stride used along both spatial axes.
    bias: Whether the convolution has a bias term (zero-initialised).
    dilation: Kernel dilation used along both spatial axes.
    init_scale: Scale handed to `default_init` for the kernel.

  Returns:
    The convolution output.
  """
  conv = nn.Conv(
    features=out_planes,
    kernel_size=(3, 3),
    strides=(stride, stride),
    padding='SAME',
    use_bias=bias,
    kernel_dilation=(dilation, dilation),
    kernel_init=default_init(init_scale),
    bias_init=jnn.initializers.zeros)
  return conv(x)

def default_init(scale=1.):
  """Return the DDPM weight initializer.

  This is a variance-scaling initializer in 'fan_avg' mode with a uniform
  distribution. A scale of exactly zero is replaced by 1e-10.

  Args:
    scale: Variance scale of the initializer.

  Returns:
    The initializer produced by `jnn.initializers.variance_scaling`.
  """
  if scale == 0:
    scale = 1e-10
  return jnn.initializers.variance_scaling(
    scale=scale, mode='fan_avg', distribution='uniform')

class NIN(nn.Module):
  """Network-in-network layer: a learned linear map over the last axis.

  Computes `x @ W + b` along the trailing axis of `x`, leaving every leading
  axis untouched.
  """
  num_units: int  # size of the output's last axis
  init_scale: float = 0.1  # scale handed to `default_init` for `W`

  @nn.compact
  def __call__(self, x):
    in_dim = int(x.shape[-1])
    W = self.param('W', default_init(scale=self.init_scale), (in_dim, self.num_units))
    b = self.param('b', jnn.initializers.zeros, (self.num_units,))
    # Contract the last axis of `x` with the first axis of `W`. The previous
    # `contract_inner` helper is not defined in this file; `jnp.tensordot`
    # with `axes=1` performs exactly that contraction.
    y = jnp.tensordot(x, W, axes=1) + b
    assert y.shape == x.shape[:-1] + (self.num_units,)
    return y


# Short aliases for the DDPM-initialised convolutions, used by the modules
# defined further down in this cell.
conv1x1 = ddpm_conv1x1
conv3x3 = ddpm_conv3x3

class GaussianFourierProjection(nn.Module):
  """Gaussian Fourier embeddings for noise levels.

  A normally-initialised frequency vector `W` is stored as a parameter but
  wrapped in `stop_gradient`, so no gradient flows into it. The output
  concatenates the sine and cosine of `2 * pi * x * W` along the last axis.
  """
  embedding_size: int = 256  # number of frequencies
  scale: float = 1.0  # standard deviation of the frequency initializer

  @nn.compact
  def __call__(self, x):
    freqs = self.param(
      'W', jax.nn.initializers.normal(stddev=self.scale), (self.embedding_size,))
    freqs = jax.lax.stop_gradient(freqs)
    angles = jnp.expand_dims(x, 1) * jnp.expand_dims(freqs, 0) * 2 * jnp.pi
    features = [jnp.sin(angles), jnp.cos(angles)]
    return jnp.concatenate(features, axis=-1)


class Combine(nn.Module):
  """Combine information from skip connections.

  `x` is first projected with a 1x1 convolution to the channel count of `y`;
  the projection is then either concatenated with `y` along the last axis
  (`'cat'`) or added to it (`'sum'`).
  """
  method: str = 'cat'  # 'cat' or 'sum'

  @nn.compact
  def __call__(self, x, y):
    projected = conv1x1(x, y.shape[-1])
    if self.method == 'sum':
      return projected + y
    if self.method == 'cat':
      return jnp.concatenate([projected, y], axis=-1)
    raise ValueError(f'Method {self.method} not recognized.')


class AttnBlockpp(nn.Module):
  """Self-attention block over spatial positions. Modified from DDPM.

  After group normalisation, query/key/value are produced by `NIN` layers.
  Attention weights are a softmax over all `H * W` positions, and the
  attended values go through a final `NIN` before the residual addition.
  """
  skip_rescale: bool = False  # divide the residual sum by sqrt(2)
  init_scale: float = 0.  # init scale of the output projection

  @nn.compact
  def __call__(self, x):
    B, H, W, C = x.shape
    normed = nn.GroupNorm(num_groups=min(x.shape[-1] // 4, 32))(x)
    query = NIN(C)(normed)
    key = NIN(C)(normed)
    value = NIN(C)(normed)

    # Scaled dot-product scores between every pair of spatial positions.
    logits = jnp.einsum('bhwc,bHWc->bhwHW', query, key) * (int(C) ** (-0.5))
    # Softmax over the flattened key positions, then restore the 2-D layout.
    attn = jax.nn.softmax(jnp.reshape(logits, (B, H, W, H * W)), axis=-1)
    attn = jnp.reshape(attn, (B, H, W, H, W))

    out = jnp.einsum('bhwHW,bHWc->bhwc', attn, value)
    out = NIN(C, init_scale=self.init_scale)(out)

    residual = x + out
    if self.skip_rescale:
      return residual / np.sqrt(2.)
    return residual


class Upsample(nn.Module):
  """x2 spatial upsampling of an NHWC tensor.

  Without `fir`, nearest-neighbour resizing is used, optionally followed by
  a 3x3 convolution. With `fir`, either `upsample_2d` (no convolution) or
  the fused `Conv2d(up=True)` is used.
  """
  out_ch: Optional[int] = None  # output channels; falls back to the input's
  with_conv: bool = False
  fir: bool = False
  fir_kernel: Tuple[int] = (1, 3, 3, 1)

  @nn.compact
  def __call__(self, x):
    B, H, W, C = x.shape
    out_ch = self.out_ch or C

    if self.fir:
      if self.with_conv:
        h = Conv2d(out_ch,
                   kernel=3,
                   up=True,
                   resample_kernel=self.fir_kernel,
                   use_bias=True,
                   kernel_init=default_init())(x)
      else:
        h = upsample_2d(x, self.fir_kernel, factor=2)
    else:
      h = jax.image.resize(x, (x.shape[0], H * 2, W * 2, C), 'nearest')
      if self.with_conv:
        h = conv3x3(h, out_ch)

    assert h.shape == (B, 2 * H, 2 * W, out_ch)
    return h


class Downsample(nn.Module):
  """x2 spatial downsampling of an NHWC tensor.

  Without `fir`, either a stride-2 3x3 convolution or 2x2 average pooling is
  used. With `fir`, either `downsample_2d` (no convolution) or the fused
  `Conv2d(down=True)` is used.
  """
  out_ch: Optional[int] = None  # output channels; falls back to the input's
  with_conv: bool = False
  fir: bool = False
  fir_kernel: Tuple[int] = (1, 3, 3, 1)

  @nn.compact
  def __call__(self, x):
    B, H, W, C = x.shape
    out_ch = self.out_ch or C

    if self.fir and self.with_conv:
      x = Conv2d(
        out_ch,
        kernel=3,
        down=True,
        resample_kernel=self.fir_kernel,
        use_bias=True,
        kernel_init=default_init())(x)
    elif self.fir:
      x = downsample_2d(x, self.fir_kernel, factor=2)
    elif self.with_conv:
      x = conv3x3(x, out_ch, stride=2)
    else:
      x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2), padding='SAME')

    assert x.shape == (B, H // 2, W // 2, out_ch)
    return x


class ResnetBlockDDPMpp(nn.Module):
  """ResBlock adapted from DDPM.

  Two (GroupNorm -> activation -> 3x3 conv) stages with dropout before the
  second convolution. An optional time embedding is projected by a dense
  layer and added as a per-channel bias after the first convolution. When
  the channel count changes, the shortcut is projected by a 3x3 convolution
  (`conv_shortcut`) or a `NIN` layer.
  """
  act: Any  # activation callable
  out_ch: Optional[int] = None  # output channels; falls back to the input's
  conv_shortcut: bool = False
  dropout: float = 0.1
  skip_rescale: bool = False  # divide the residual sum by sqrt(2)
  init_scale: float = 0.  # init scale of the second convolution

  @nn.compact
  def __call__(self, x, temb=None, train=True):
    _, _, _, C = x.shape
    out_ch = self.out_ch or C

    h = nn.GroupNorm(num_groups=min(x.shape[-1] // 4, 32))(x)
    h = conv3x3(self.act(h), out_ch)

    # Add bias to each feature map conditioned on the time embedding
    if temb is not None:
      temb_bias = nn.Dense(out_ch, kernel_init=default_init())(self.act(temb))
      h = h + temb_bias[:, None, None, :]

    h = nn.GroupNorm(num_groups=min(h.shape[-1] // 4, 32))(h)
    h = nn.Dropout(self.dropout)(self.act(h), deterministic=not train)
    h = conv3x3(h, out_ch, init_scale=self.init_scale)

    # Project the shortcut only when the channel count changes.
    if C != out_ch:
      x = conv3x3(x, out_ch) if self.conv_shortcut else NIN(out_ch)(x)

    residual = x + h
    if self.skip_rescale:
      return residual / np.sqrt(2.)
    return residual


class ResnetBlockBigGANpp(nn.Module):
  """ResBlock adapted from BigGAN.

  Like the DDPM block, but it can resample by a factor of two: both the
  residual branch (after the first GroupNorm + activation) and the shortcut
  are up- or downsampled, using FIR filtering when `fir` is set and the
  naive resamplers otherwise. The shortcut goes through a 1x1 convolution
  whenever the channel count or the resolution changes.
  """
  act: Any  # activation callable
  up: bool = False
  down: bool = False
  out_ch: Optional[int] = None  # output channels; falls back to the input's
  dropout: float = 0.1
  fir: bool = False
  fir_kernel: Tuple[int] = (1, 3, 3, 1)
  skip_rescale: bool = True  # divide the residual sum by sqrt(2)
  init_scale: float = 0.  # init scale of the second convolution

  @nn.compact
  def __call__(self, x, temb=None, train=True):
    _, _, _, C = x.shape
    out_ch = self.out_ch or C
    h = self.act(nn.GroupNorm(num_groups=min(x.shape[-1] // 4, 32))(x))

    # Select the resampler once; `up` takes precedence over `down`.
    if self.up:
      if self.fir:
        resample = functools.partial(upsample_2d, k=self.fir_kernel, factor=2)
      else:
        resample = functools.partial(naive_upsample_2d, factor=2)
    elif self.down:
      if self.fir:
        resample = functools.partial(downsample_2d, k=self.fir_kernel, factor=2)
      else:
        resample = functools.partial(naive_downsample_2d, factor=2)
    else:
      resample = None

    if resample is not None:
      h = resample(h)
      x = resample(x)

    h = conv3x3(h, out_ch)
    # Add bias to each feature map conditioned on the time embedding
    if temb is not None:
      temb_bias = nn.Dense(out_ch, kernel_init=default_init())(self.act(temb))
      h = h + temb_bias[:, None, None, :]

    h = nn.GroupNorm(num_groups=min(h.shape[-1] // 4, 32))(h)
    h = nn.Dropout(self.dropout)(self.act(h), deterministic=not train)
    h = conv3x3(h, out_ch, init_scale=self.init_scale)

    if C != out_ch or self.up or self.down:
      x = conv1x1(x, out_ch)

    residual = x + h
    if self.skip_rescale:
      return residual / np.sqrt(2.)
    return residual

In [None]:
"""

"""

# Module-level aliases for the building blocks used by NCSNpp below.
# NOTE(review): `layerspp`, `layers` and `normalization` are not imported in
# the visible part of this file, while classes with matching names
# (ResnetBlockDDPMpp, ResnetBlockBigGANpp, Combine, conv3x3, conv1x1,
# default_init) are defined directly in the layers cell above — confirm
# whether these should point at the local definitions. `get_act` and
# `get_normalization` have no visible local definition.
ResnetBlockDDPM = layerspp.ResnetBlockDDPMpp
ResnetBlockBigGAN = layerspp.ResnetBlockBigGANpp
Combine = layerspp.Combine
conv3x3 = layerspp.conv3x3
conv1x1 = layerspp.conv1x1
get_act = layers.get_act
get_normalization = normalization.get_normalization
default_initializer = layers.default_init


#@utils.register_model(name='ncsnpp')
class NCSNpp(nn.Module):
  """NCSN++ model.

  A U-Net style score network: a downsampling path of residual blocks
  (with optional attention and an optional progressive input pyramid), a
  bottleneck of two residual blocks around one attention block, and an
  upsampling path that consumes the stored skip activations (with an
  optional progressive output pyramid).

  NOTE(review): the `config` field and the `config = self.config` line are
  commented out, so `__call__` reads a name `config` that is not bound inside
  this class — presumably a module-level config object is expected; confirm.
  """
  #config: ml_collections.ConfigDict

  @nn.compact
  def __call__(self, x, time_cond, train=True):
    """Evaluate the network.

    Args:
      x: Input batch; the code indexes `x.shape[0]`, `x.shape[-1]` and treats
        axis 1 of intermediate activations as the resolution.
      time_cond: Conditioning values; used directly as sigmas for 'fourier'
        embeddings, or as indices into `sigmas` for 'positional' embeddings.
      train: Forwarded to the residual blocks (controls dropout there).

    Returns:
      The network output, divided by the per-sample sigma when
      `config.model.scale_by_sigma` is set.
    """
    # config parsing
    #config = self.config
    act = get_act(config)
    sigmas = utils.get_sigmas(config)

    nf = config.model.nf
    ch_mult = config.model.ch_mult
    num_res_blocks = config.model.num_res_blocks
    attn_resolutions = config.model.attn_resolutions
    dropout = config.model.dropout
    resamp_with_conv = config.model.resamp_with_conv
    num_resolutions = len(ch_mult)

    conditional = config.model.conditional  # noise-conditional
    fir = config.model.fir
    fir_kernel = config.model.fir_kernel
    skip_rescale = config.model.skip_rescale
    resblock_type = config.model.resblock_type.lower()
    progressive = config.model.progressive.lower()
    progressive_input = config.model.progressive_input.lower()
    embedding_type = config.model.embedding_type.lower()
    init_scale = config.model.init_scale
    assert progressive in ['none', 'output_skip', 'residual']
    assert progressive_input in ['none', 'input_skip', 'residual']
    assert embedding_type in ['fourier', 'positional']
    combine_method = config.model.progressive_combine.lower()
    combiner = functools.partial(Combine, method=combine_method)

    # timestep/noise_level embedding; only for continuous training
    if embedding_type == 'fourier':
      # Gaussian Fourier features embeddings.
      assert config.training.continuous, "Fourier features are only used for continuous training."
      used_sigmas = time_cond
      temb = layerspp.GaussianFourierProjection(
        embedding_size=nf,
        scale=config.model.fourier_scale)(jnp.log(used_sigmas))

    elif embedding_type == 'positional':
      # Sinusoidal positional embeddings.
      timesteps = time_cond
      used_sigmas = sigmas[time_cond.astype(jnp.int32)]
      temb = layers.get_timestep_embedding(timesteps, nf)
    else:
      raise ValueError(f'embedding type {embedding_type} unknown.')

    # Two dense layers refine the embedding; unconditional models drop it.
    if conditional:
      temb = nn.Dense(nf * 4, kernel_init=default_initializer())(temb)
      temb = nn.Dense(nf * 4, kernel_init=default_initializer())(act(temb))
    else:
      temb = None

    # Pre-configured constructors for the blocks used below.
    AttnBlock = functools.partial(layerspp.AttnBlockpp,
                                  init_scale=init_scale,
                                  skip_rescale=skip_rescale)

    Upsample = functools.partial(layerspp.Upsample,
                                 with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)

    # `pyramid_upsample` is only bound (and only used) when `progressive` is
    # not 'none'.
    if progressive == 'output_skip':
      pyramid_upsample = functools.partial(layerspp.Upsample,
                                           fir=fir, fir_kernel=fir_kernel, with_conv=False)
    elif progressive == 'residual':
      pyramid_upsample = functools.partial(layerspp.Upsample,
                                           fir=fir, fir_kernel=fir_kernel, with_conv=True)

    Downsample = functools.partial(layerspp.Downsample,
                                   with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)

    # `pyramid_downsample` is only bound (and only used) when
    # `progressive_input` is not 'none'.
    if progressive_input == 'input_skip':
      pyramid_downsample = functools.partial(layerspp.Downsample,
                                             fir=fir, fir_kernel=fir_kernel, with_conv=False)
    elif progressive_input == 'residual':
      pyramid_downsample = functools.partial(layerspp.Downsample,
                                             fir=fir, fir_kernel=fir_kernel, with_conv=True)

    if resblock_type == 'ddpm':
      ResnetBlock = functools.partial(ResnetBlockDDPM,
                                      act=act,
                                      dropout=dropout,
                                      init_scale=init_scale,
                                      skip_rescale=skip_rescale)

    elif resblock_type == 'biggan':
      ResnetBlock = functools.partial(ResnetBlockBigGAN,
                                      act=act,
                                      dropout=dropout,
                                      fir=fir,
                                      fir_kernel=fir_kernel,
                                      init_scale=init_scale,
                                      skip_rescale=skip_rescale)

    else:
      raise ValueError(f'resblock type {resblock_type} unrecognized.')

    if not config.data.centered:
      # If input data is in [0, 1]
      x = 2 * x - 1.

    # Downsampling block

    input_pyramid = None
    if progressive_input != 'none':
      input_pyramid = x

    # `hs` collects every activation of the downsampling path; the upsampling
    # path pops them again as skip connections.
    hs = [conv3x3(x, nf)]
    for i_level in range(num_resolutions):
      # Residual blocks for this resolution
      for i_block in range(num_res_blocks):
        h = ResnetBlock(out_ch=nf * ch_mult[i_level])(hs[-1], temb, train)
        # Attention is added when axis 1 of `h` matches a configured resolution.
        if h.shape[1] in attn_resolutions:
          h = AttnBlock()(h)
        hs.append(h)

      # Every level except the last one ends with a downsampling step.
      if i_level != num_resolutions - 1:
        if resblock_type == 'ddpm':
          h = Downsample()(hs[-1])
        else:
          h = ResnetBlock(down=True)(hs[-1], temb, train)

        if progressive_input == 'input_skip':
          input_pyramid = pyramid_downsample()(input_pyramid)
          h = combiner()(input_pyramid, h)

        elif progressive_input == 'residual':
          input_pyramid = pyramid_downsample(out_ch=h.shape[-1])(input_pyramid)
          if skip_rescale:
            input_pyramid = (input_pyramid + h) / np.sqrt(2.)
          else:
            input_pyramid = input_pyramid + h
          h = input_pyramid

        hs.append(h)

    # Bottleneck: residual block, attention, residual block.
    h = hs[-1]
    h = ResnetBlock()(h, temb, train)
    h = AttnBlock()(h)
    h = ResnetBlock()(h, temb, train)

    pyramid = None

    # Upsampling block
    for i_level in reversed(range(num_resolutions)):
      # One more block per level than on the way down, each consuming one
      # stored skip activation.
      for i_block in range(num_res_blocks + 1):
        h = ResnetBlock(out_ch=nf * ch_mult[i_level])(jnp.concatenate([h, hs.pop()], axis=-1),
                                                      temb,
                                                      train)

      if h.shape[1] in attn_resolutions:
        h = AttnBlock()(h)

      if progressive != 'none':
        # The coarsest level starts the output pyramid; finer levels upsample
        # it and merge in the current features.
        if i_level == num_resolutions - 1:
          if progressive == 'output_skip':
            pyramid = conv3x3(
              act(nn.GroupNorm(num_groups=min(h.shape[-1] // 4, 32))(h)),
              x.shape[-1],
              bias=True,
              init_scale=init_scale)
          elif progressive == 'residual':
            pyramid = conv3x3(
              act(nn.GroupNorm(num_groups=min(h.shape[-1] // 4, 32))(h)),
              h.shape[-1],
              bias=True)
          else:
            raise ValueError(f'{progressive} is not a valid name.')
        else:
          if progressive == 'output_skip':
            pyramid = pyramid_upsample()(pyramid)
            pyramid = pyramid + conv3x3(
              act(nn.GroupNorm(num_groups=min(h.shape[-1] // 4, 32))(h)),
              x.shape[-1],
              bias=True,
              init_scale=init_scale)
          elif progressive == 'residual':
            pyramid = pyramid_upsample(out_ch=h.shape[-1])(pyramid)
            if skip_rescale:
              pyramid = (pyramid + h) / np.sqrt(2.)
            else:
              pyramid = pyramid + h
            h = pyramid
          else:
            raise ValueError(f'{progressive} is not a valid name')

      # Every level except the finest one ends with an upsampling step.
      if i_level != 0:
        if resblock_type == 'ddpm':
          h = Upsample()(h)
        else:
          h = ResnetBlock(up=True)(h, temb, train)

    # All skip activations must have been consumed.
    assert not hs

    if progressive == 'output_skip':
      h = pyramid
    else:
      h = act(nn.GroupNorm(num_groups=min(h.shape[-1] // 4, 32))(h))
      h = conv3x3(h, x.shape[-1], init_scale=init_scale)

    if config.model.scale_by_sigma:
      # Reshape the sigmas to (batch, 1, ..., 1) so they broadcast over `h`.
      used_sigmas = used_sigmas.reshape((x.shape[0], *([1] * len(x.shape[1:]))))
      h = h / used_sigmas

    return h

In [6]:
"""
The loss function for a noise dependent score model from Song+2020
"""
def anneal_dsm_score_estimation(model, samples, labels, sigmas, key, anneal_power=2.):
    """Annealed denoising score-matching loss.

    Each sample is perturbed with Gaussian noise of its own standard deviation
    and the model output is regressed onto `-noise / sigma`. The squared error
    is summed over all non-batch axes, weighted by `sigma ** anneal_power`,
    and averaged over the batch.

    Args:
        model: Callable `model(perturbed_samples, labels)` returning an array
            with the shape of `samples`.
        samples: Array with the batch on axis 0 and any number of further
            axes (flat vectors as well as images).
        labels: Passed through to `model` unchanged.
        sigmas: One noise standard deviation per sample (or a single scalar).
        key: JAX PRNG key used to draw the noise.
        anneal_power: Exponent of the per-sample sigma weight.

    Returns:
        Scalar loss averaged over the batch.
    """
    # Shape the sigmas as (B, 1, ..., 1) so they broadcast over every
    # non-batch axis; a bare trailing axis only works for (B, D) samples.
    sigmas = jnp.reshape(jnp.asarray(sigmas), (-1,) + (1,) * (samples.ndim - 1))
    noise = jax.random.normal(key, samples.shape)
    perturbed_samples = samples + noise * sigmas
    target = -noise / sigmas
    scores = model(perturbed_samples, labels)
    # Sum the squared error over everything except the batch axis.
    sq_err = ((scores - target) ** 2).reshape(samples.shape[0], -1).sum(axis=-1)
    loss = 1 / 2. * sq_err * sigmas.reshape(-1) ** anneal_power
    return loss.mean(axis=0)


Now we train the model.

In [5]:
""" 
The training of the NCSNv2 model. Here define the training
parameters and initialise the model. Train on a small scale 
for testing before moving to the full scale on GPU HPC.
"""
# ----------- #
# model setup #
# ----------- #

# load in data
box_size = 31
dataname = f'sources_box{box_size}.npy'
dataset = np.load(dataname)
#plt.imshow(dataset[2], cmap='gray')
#plt.show()

# Zero-pad every image by one row at the bottom and one column on the right
# to get the desired dimensions.
#dataset = np.resize(dataset,(1989,96,96))
dataset = np.array(
    [np.pad(image, ((0, 1), (0, 1)), 'constant') for image in dataset])

# define noise levels: geometric sequence from sigma_begin down to sigma_end
sigma_begin = 1
sigma_end   = 0.01
num_scales = 10
sigmas      = np.exp(np.linspace(np.log(sigma_begin), np.log(sigma_end), num_scales))

# score model params
n_epochs    = 50                                    # number of epochs
steps       = 1_000                                 # number of steps per epoch
batch_size  = 32                                    # batch size
lr          = 1e-4                                  # learning rate
rng         = jax.random.PRNGKey(1992)              # random seed
# (devices, 32, 32, 3): dummy 32x32 input with 3 channels.
# NOTE(review): the padding above uses 2-D pad widths, i.e. the images carry
# no channel axis — confirm the intended channel count.
input_shape = (jax.local_device_count(), 32, 32, 3)
#input_shape = (32, 32, 1) # size 32 by 32 one channel
label_shape = input_shape[:1]
fake_input  = jnp.zeros(input_shape)
fake_label  = jnp.zeros(label_shape, dtype=jnp.int32)
params_rng, dropout_rng = jax.random.split(rng)
# NOTE(review): `NCSNv2` is not defined in the visible part of this file
# (the network defined above is `NCSNpp`) — confirm.
model = NCSNv2()
#model = model_def()
variables = model.init({'params': params_rng}, fake_input, fake_label)
# Variables is a `flax.FrozenDict`. It is immutable and respects functional programming
init_model_state, initial_params = variables.pop('params')
# NOTE(review): `flax.optim` has been removed from recent Flax releases in
# favour of Optax — confirm the installed Flax version still provides it.
optimizer = flax.optim.Adam(learning_rate=lr,
                            beta1 = 0.9,
                            eps = 1e-8).create(initial_params)  # create optimizer

# ------------- #
# training loop #
# ------------- #
@jax.jit
def train_step(model, optimizer, rng, samples, labels, sigmas):
    """Run one optimisation step on the annealed score-matching loss.

    Takes the gradient of `anneal_dsm_score_estimation` with respect to its
    first argument (`model`), passes it to `optimizer.update`, and returns
    `(model, optimizer, rng)`.

    NOTE(review): several points need confirmation before this can run:
      * the training loop passes an existing PRNG key as `rng` and feeds the
        returned `rng` back in, yet `rng` is given to `jax.random.PRNGKey`,
        which expects an integer seed — `jax.random.split` is presumably
        what is wanted;
      * `model` is both differentiated and called inside the loss, so it
        would need to be a differentiable pytree that is also callable;
      * the function is jitted with `model` and `optimizer` as traced
        arguments — confirm both are valid JAX pytrees;
      * the result of `optimizer.update(grads, model)` replaces `model`
        while `optimizer` is returned unchanged — confirm against the
        optimizer API in use.
    """
    rng   = jax.random.PRNGKey(rng) # random number random seed
    # jax.grad differentiates with respect to the first positional argument.
    grads = jax.grad(anneal_dsm_score_estimation)(model, samples, labels, sigmas, rng)
    model = optimizer.update(grads, model)
    return model, optimizer, rng

# PRNG key threaded through the training steps.
key_seq = jax.random.PRNGKey(0)
for t in tqdm(range(steps + 1)):

    # Pick one random dataset entry per step.
    idx = np.random.randint(0, len(dataset))
    # Draw one noise-level index per element along the first axis of that
    # entry.
    # NOTE(review): `dataset[idx]` is a single entry, so its first axis is
    # presumably image rows rather than a batch of images — confirm.
    labels = np.random.randint(0, len(sigmas), size=len(dataset[idx])) # size for 32 by 32
    model, optimizer, key_seq = train_step(model, optimizer, 
                                key_seq, dataset[idx], labels, sigmas[labels])

    # Report the loss on the first dataset entry five times over the run
    # (plus once at step 0).
    if ((t % (steps // 5)) == 0):
        labels = np.random.randint(0, len(sigmas), size=len(dataset[0]))
        print(anneal_dsm_score_estimation(model, dataset[0], labels, sigmas[labels], rng))
# -------------------- #
# end of training loop #       
# -------------------- #

size of h: (1, 32, 32, 3)


TypeError: <lambda>() takes 2 positional arguments but 3 were given

In [None]:
# ------------------------ #
# testing score estimation #
# ------------------------ #
# Compare the model output on pure Gaussian noise with its output on one
# dataset image.
gaussian_noise = jax.random.normal(rng, shape=(32,32))
galaxy = dataset[1992]
# One random noise-level index per row of the 32x32 input.
labels = np.random.randint(0, len(sigmas), (gaussian_noise.shape[0],))
# NOTE(review): `model` is called directly; the CallCompactUnboundModuleError
# recorded below this cell is raised by Flax when a compact method is called
# on an unbound module — presumably `model.apply(variables, ...)` is needed;
# confirm.
scores = model(gaussian_noise, labels)
scores2 = model(galaxy, labels)
# Side-by-side panels: left for noise, right for the galaxy (with colour bar).
fig , ax = plt.subplots(1,2,figsize=(16, 5.5), facecolor='black',dpi = 70)
plt.subplots_adjust(wspace=0.01)
plt.subplot(1,2,1)
plt.imshow(scores, cmap='plasma')
#plt.colorbar()
plt.title('Gaussian Noise',fontsize=28,pad=15)
plt.subplot(1,2,2)
plt.imshow(scores2, cmap='plasma')
cbar = plt.colorbar()
cbar.set_label(r'$\nabla_x log \ p(\mathbf{\tilde{x}})$', rotation=270, fontsize = 20,labelpad= 25)
plt.title('Galaxy',fontsize=28,pad=15)

CallCompactUnboundModuleError: Can't call compact methods on unbound modules (https://flax.readthedocs.io/en/latest/flax.errors.html#flax.errors.CallCompactUnboundModuleError)

In [None]:
# ----------------------------------------- #
# port the pytorch langevin dynamics to jax #
# ----------------------------------------- #
def anneal_Langevin_dynamics(x_mod, scorenet, sigmas, n_steps_each=100, step_lr=0.000008,
                             final_only=False, verbose=False, denoise=True):
    """Sample with annealed Langevin dynamics (PyTorch implementation).

    For every noise level, `n_steps_each` Langevin updates are applied with a
    step size of `step_lr * (sigma / sigmas[-1]) ** 2`. An optional final
    denoising step adds `sigmas[-1] ** 2 * score` at the last noise level.

    Args:
        x_mod: Initial samples, a torch tensor with the batch on axis 0.
        scorenet: Callable `scorenet(x, labels)` returning the score, where
            `labels` is a long tensor holding the noise-level index.
        sigmas: Torch tensor of noise levels, iterated in order.
        n_steps_each: Number of Langevin updates per noise level.
        step_lr: Base step size.
        final_only: If true, intermediate samples are not recorded.
        verbose: Unused; kept for interface compatibility.
        denoise: If true, apply the final denoising step.

    Returns:
        `(images, scores)`: `images` is a list of CPU tensors (every
        intermediate sample unless `final_only`, always ending with the final
        sample) and `scores` is a list with the score at every update.
    """
    # Imported locally because this file does not import torch at module
    # level; without it the function fails on its first torch call.
    import torch

    images = []
    scores = []
    batch = x_mod.shape[0]

    for c, sigma in enumerate(sigmas):
        # Noise-level index for every sample in the batch.
        labels = torch.full((batch,), c, dtype=torch.long, device=x_mod.device)
        step_size = step_lr * (sigma / sigmas[-1]) ** 2
        step_size_cpu = step_size.to('cpu')
        # Loop-invariant noise scale, computed with torch rather than numpy.
        noise_scale = torch.sqrt(step_size_cpu * 2)
        for _ in range(n_steps_each):
            grad = scorenet(x_mod, labels)
            scores.append(grad.to('cpu'))
            noise = torch.randn_like(x_mod)
            x_mod = x_mod + step_size_cpu * grad + noise * noise_scale

            if not final_only:
                images.append(x_mod.to('cpu'))

    if denoise:
        last_noise = torch.full((batch,), len(sigmas) - 1, dtype=torch.long,
                                device=x_mod.device)
        x_mod = x_mod + sigmas[-1] ** 2 * scorenet(x_mod, last_noise)
        images.append(x_mod.to('cpu'))
    elif final_only:
        # Without this, `final_only=True, denoise=False` returned no sample.
        images.append(x_mod.to('cpu'))

    return images, scores