Copyright 2021 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

     https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

# Finite ConvNet Training with Conv-KIP distilled images

This notebook demonstrates a simple example of finite neural network transfer using Conv-KIP distilled images.

Training code is based off of [Lee et al., Finite Versus Infinite Neural Networks: an Empirical Study, NeurIPS 2020](https://arxiv.org/abs/2007.15801), as adapted in [Nguyen et al., Dataset Distillation with Infinitely Wide Convolutional Networks](https://arxiv.org/abs/2107.13034).

## Imports

In [None]:
# Install ml_collections and neural_tangents
!pip install -q git+https://www.github.com/google/ml_collections
!pip install -q git+https://www.github.com/google/neural-tangents

In [None]:
from absl import app
from absl import logging

import functools
import time
import operator as op

import jax
from jax.example_libraries import optimizers
from jax.example_libraries import stax as ostax
import jax.numpy as jnp
from jax.tree_util import tree_map
from jax.tree_util import tree_reduce

import ml_collections
import neural_tangents as nt
from neural_tangents import stax

import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

## Experiment Configs

In [None]:
def get_config():
  """Builds the default experiment configuration used by this notebook.

  Returns:
    An `ml_collections.ConfigDict` with dataset, preprocessing, optimization,
    architecture, loss and KIP checkpoint settings.
  """
  cfg = ml_collections.ConfigDict()
  cfg.seed = 0

  # ---- Dataset ----
  dataset_cfg = ml_collections.ConfigDict()
  # One of: 'cifar10', 'cifar100', 'mnist', 'fashion_mnist'.
  dataset_cfg.name = 'cifar10'
  cfg.dataset = dataset_cfg

  # ---- Preprocessing ----
  # One of: 'zca_normalize', 'standard'.
  cfg.preprocess_type = 'zca_normalize'
  cfg.zca_reg = 0.1

  # ---- Optimization ----
  # Note: in this notebook `run_trainer` hard-codes the `momentum` optimizer.
  cfg.optimizer = 'momentum'
  cfg.momentum = 0.9
  cfg.batch_size = 100
  cfg.eval_batch_size = 100
  cfg.train_steps = 5_000

  # Number of samples used to estimate the maximum learning rate.
  cfg.empirical_ntk_num_inputs = 50

  # Note: max_lr_factor and l2_regularization were found through grid search.
  cfg.max_lr_factor = 1.0
  cfg.l2_regularization = 0.001

  # ---- Network architecture ----
  cfg.width = 1024
  cfg.depth = 3
  cfg.use_gap = False
  cfg.W_std = 2.0**0.5
  cfg.b_std = 0.1
  # One of: 'relu', 'erf', 'identity', 'gelu'.
  cfg.activation = 'relu'
  # Enable for 'ConvNet' datasets; disable for 'ConvNet3' datasets.
  cfg.prepend_conv_layer = True

  # ---- Loss / parameterization ----
  cfg.loss_type = 'mse'  # One of: 'mse', 'xent'.
  cfg.parameterization = 'standard'  # One of: 'standard', 'ntk'.

  # ---- Distilled data ----
  kip_cfg = ml_collections.ConfigDict()
  # Put any KIP / Label Solve checkpoint path here to use as training data.
  kip_cfg.data_ckpt_path = (
      'gs://kip-datasets/kip/cifar10/ConvNet_ssize100_zca_nol_noaug_ckpt1000.npz')
  cfg.kip = kip_cfg

  return cfg

## Define Training Utilities

In [None]:
def load_kip_data(config):
  """Load the TFDS splits and replace the train split with a KIP checkpoint.

  The dataset named in `config` is loaded and preprocessed to provide the
  'valid' and 'test' splits. The 'train' split is then overridden with the
  'images' and 'labels' arrays stored in the checkpoint at
  `config.kip.data_ckpt_path`.

  Args:
    config: experiment config; reads `kip.data_ckpt_path` and `loss_type`
      here, plus whatever `_load_dataset` and `apply_preprocess` read.

  Returns:
    Dict with 'train', 'valid' and 'test' entries, each a dict holding
    'images' and 'labels'.
  """
  data = _load_dataset(config)
  data = apply_preprocess(config, data)

  logging.info('valid data: %s, %s',
               data['valid']['images'].shape, data['valid']['labels'].shape)
  logging.info('test data: %s, %s',
               data['test']['images'].shape, data['test']['labels'].shape)

  # Override training data
  ckpt_name = config.kip.data_ckpt_path
  with tf.io.gfile.GFile(ckpt_name, 'rb') as f:
    loaded_data = jnp.load(f)
    # Pull the arrays out while the file handle is still open: for an `.npz`
    # archive the members are only read when they are indexed.
    train_images = loaded_data['images']
    train_labels = loaded_data['labels']

  if config.loss_type == 'xent':
    # Recover one-hot label even for the learned labels.
    n_classes = train_labels.shape[-1]
    train_labels = jnp.array(
        jnp.argmax(train_labels, axis=-1)[:, None] == jnp.arange(n_classes),
        dtype=train_labels.dtype)

  # Override training
  data['train'] = {'images': train_images, 'labels': train_labels}

  logging.info('Overrriding train with ckpt %s, size: (%s, %s)', ckpt_name,
               data['train']['images'].shape, data['train']['labels'].shape)

  return data


def _load_dataset(config):
  """Get per channel normalized / one hot encoded data from TFDS.

  Loads the full 'train' and 'test' splits of `config.dataset.name`, one-hot
  encodes the labels, standardizes the images per channel and holds out the
  last `VALID_SIZE` training examples as a validation split.

  Side effect: writes `height`, `width`, `num_channels` and `num_classes` into
  `config.dataset`.

  Args:
    config: experiment config; reads `dataset.name` and `loss_type`
      (a missing `loss_type` is treated as 'mse').

  Returns:
    Dict with 'train', 'valid' and 'test' entries, each a dict holding
    'images' and 'labels'.
  """
  VALID_SIZE = 5000

  dataset_name = config.dataset.name
  ds_builder = tfds.builder(dataset_name)
  ds_train, ds_test = tfds.as_numpy(
      tfds.load(
          dataset_name,
          split=['train', 'test'],
          batch_size=-1,
          as_dataset_kwargs={'shuffle_files': False}))

  train_images, train_labels = ds_train['image'], ds_train['label']
  test_images, test_labels = ds_test['image'], ds_test['label']

  height, width, num_channels = ds_builder.info.features['image'].shape
  num_classes = ds_builder.info.features['label'].num_classes
  with config.dataset.unlocked():
    config.dataset.height = height
    config.dataset.width = width
    config.dataset.num_channels = num_channels
    config.dataset.num_classes = num_classes

  # One hot encode
  train_labels = jax.nn.one_hot(train_labels, num_classes)
  test_labels = jax.nn.one_hot(test_labels, num_classes)

  if config.get('loss_type', 'mse') == 'mse':
    # Shift the one-hot targets by 1 / num_classes for regression.
    shift = (1. / num_classes if num_classes > 1 else 0.5)
    train_labels -= shift
    test_labels -= shift

  # Per channel mean/std are computed once, over all training images (i.e.
  # before the validation hold-out below), and reused for every split.
  channel_mean = np.mean(train_images, axis=(0, 1, 2))
  channel_std = np.std(train_images, axis=(0, 1, 2))
  train_xs = (train_images - channel_mean) / channel_std
  test_xs = (test_images - channel_mean) / channel_std

  # Hold out the last VALID_SIZE training examples for validation.
  train_xs, valid_xs = train_xs[:-VALID_SIZE], train_xs[-VALID_SIZE:]
  train_ys, valid_ys = train_labels[:-VALID_SIZE], train_labels[-VALID_SIZE:]
  test_ys = test_labels

  return {
      'train': {'images': train_xs, 'labels': train_ys},
      'valid': {'images': valid_xs, 'labels': valid_ys},
      'test': {'images': test_xs, 'labels': test_ys},
  }

In [None]:
#@title Preprocess utilities
def apply_preprocess(config, data):
  """Apply ZCA preprocessing on the standard normalized data.

  Args:
    config: mapping-like config; reads `preprocess_type` (default
      'standard'), `zca_reg` (default 0.0) and `zca_reg_absolute_scale`
      (default False) through `.get`.
    data: dict with 'train', 'valid' and 'test' entries, each a dict holding
      'images' and 'labels'.

  Returns:
    A new dict with the same layout as `data`. For 'standard' the images are
    passed through unchanged; for 'zca_normalize' every split is transformed
    with an op fitted on the training images. Labels are never modified.

  Raises:
    NotImplementedError: if `preprocess_type` is neither 'standard' nor
      'zca_normalize'.
  """
  x_train, y_train = data['train']['images'], data['train']['labels']
  x_valid, y_valid = data['valid']['images'], data['valid']['labels']
  x_test, y_test = data['test']['images'], data['test']['labels']

  preprocess_type = config.get('preprocess_type', 'standard')
  if preprocess_type == 'zca_normalize':
    # Fit the transform on the training images only, then apply it everywhere.
    preprocess_op = _get_preprocess_op(
        x_train,
        layer_norm=True,
        zca_reg=config.get('zca_reg', 0.0),
        zca_reg_absolute_scale=config.get('zca_reg_absolute_scale', False))
    x_train = preprocess_op(x_train)
    x_valid = preprocess_op(x_valid)
    x_test = preprocess_op(x_test)
  elif preprocess_type != 'standard':
    # Previously the exception was constructed but never raised, so unknown
    # types silently fell through with unprocessed data.
    raise NotImplementedError('Preprocess type %s is not implemented' %
                              preprocess_type)
  # 'standard': normalization is already done, nothing to do.

  return {'train': {'images': x_train, 'labels': y_train},
          'valid': {'images': x_valid, 'labels': y_valid},
          'test': {'images': x_test, 'labels': y_test}}


def _get_preprocess_op(x_train,
                      layer_norm=True,
                      zca_reg=1e-5,
                      zca_reg_absolute_scale=False,
                      on_cpu=False):
  """Builds a function applying (optional) layer norm followed by ZCA.

  The whitening matrix is fitted once on `x_train`; the returned callable can
  then be applied to any batch of images with the same per-example shape.

  Args:
    x_train: training images used to fit the whitening transform.
    layer_norm: if True, each flattened example is mean-centered and scaled
      to unit norm before whitening.
    zca_reg: regularization strength forwarded to the whitening transform.
    zca_reg_absolute_scale: forwarded to the whitening transform.
    on_cpu: forwarded to the whitening transform.

  Returns:
    A function mapping a batch of images to preprocessed images of the same
    shape.
  """
  transform = _get_whitening_transform(
      x_train, layer_norm, zca_reg, zca_reg_absolute_scale, on_cpu)

  def _preprocess_op(images):
    input_shape = images.shape
    flat = images.reshape(input_shape[0], -1)
    if layer_norm:
      # Center each example, then rescale it to unit L2 norm.
      flat = flat - jnp.mean(flat, axis=1)[:, jnp.newaxis]
      norms = jnp.linalg.norm(flat, axis=1)
      flat = flat / norms[:, jnp.newaxis]

    whitened = flat.dot(transform)
    return whitened.reshape(input_shape)

  return _preprocess_op


def _get_whitening_transform(x_train,
                             layer_norm=True,
                             zca_reg=1e-5,
                             zca_reg_absolute_scale=False,
                             on_cpu=False):
  """Returns 2D matrix that performs whitening transform.

  Whitening transform is a (d,d) matrix (d = number of features) which acts on
  the right of a (n, d) batch of flattened data.
  """
  orig_train_shape = x_train.shape
  x_train = x_train.reshape(orig_train_shape[0], -1).astype('float64')
  if on_cpu:
    x_train = jax.device_put(x_train, jax.devices('cpu')[0])

  n_train = x_train.shape[0]
  if layer_norm:
    logging.info('Performing layer norm preprocessing.')
    # Zero mean every feature
    x_train = x_train - jnp.mean(x_train, axis=1)[:, jnp.newaxis]
    # Normalize
    train_norms = jnp.linalg.norm(x_train, axis=1)
    # Make features unit norm
    x_train = x_train / train_norms[:, jnp.newaxis]

  logging.info('Performing zca whitening preprocessing with reg: %.2e', zca_reg)
  cov = 1.0 / n_train * x_train.T.dot(x_train)
  if zca_reg_absolute_scale:
    reg_amount = zca_reg
  else:
    reg_amount = zca_reg * jnp.trace(cov) / cov.shape[0]
  logging.info('Raw zca regularization strength: %f', reg_amount)

  u, s, _ = jnp.linalg.svd(cov + reg_amount * jnp.eye(cov.shape[0]))
  inv_sqrt_zca_eigs = s**(-1 / 2)

  # rank control
  if n_train < x_train.shape[1]:
    inv_sqrt_zca_eigs = inv_sqrt_zca_eigs.at[n_train:].set(
        jnp.ones(inv_sqrt_zca_eigs[n_train:].shape[0]))
  whitening_transform = jnp.einsum(
      'ij,j,kj->ik', u, inv_sqrt_zca_eigs, u, optimize=True)
  return whitening_transform


In [None]:
# Loss Definition.
def cross_entropy(y, y_hat):
  """Batch mean of `-sum(y * y_hat, axis=1)`.

  The product is symmetric, so it does not matter which argument carries the
  targets and which carries the log-probabilities.

  Uses `jnp` (rather than `np`) so the function is traceable under `jax.jit` /
  `jax.grad`, consistent with `mse_loss`.
  """
  return -jnp.mean(jnp.sum(y * y_hat, axis=1))


def mse_loss(y, y_hat):
  """Half of the mean (over all elements) squared difference of the inputs."""
  return 0.5 * jnp.mean((y - y_hat)**2)


def _l2_norm(params):
  """Maps every leaf of `params` to the sum of its squared entries."""
  return tree_map(lambda x: jnp.sum(x ** 2), params)


def l2_regularization(params):
  """Sum of squared entries over all leaves of the `params` pytree.

  The explicit 0. initializer makes an empty pytree return 0. instead of
  raising.
  """
  return tree_reduce(op.add, _l2_norm(params), 0.)


def cosine_schedule(initial_learning_rate, training_steps):
  """Returns a half-cosine learning-rate schedule.

  Args:
    initial_learning_rate: learning rate at step 0.
    training_steps: step at which the schedule reaches zero.

  Returns:
    A function mapping step `t` to
    `initial_learning_rate * 0.5 * (1 + cos(pi * t / training_steps))`.
  """
  def lr_at(t):
    phase = t / training_steps * jnp.pi
    return initial_learning_rate * 0.5 * (1 + jnp.cos(phase))

  return lr_at

def _epoch_from_step(step, train_size, batch_size):
  if train_size == batch_size:
    return step
  else:
    return float(step / train_size * batch_size)  

## Define Networks

In [None]:
def _get_activation_fn(config):
  if config.activation.lower() == 'relu':
    activation_fn = stax.Relu()
  elif config.activation.lower() == 'erf':
    activation_fn = stax.Erf()
  elif config.activation.lower() == 'identity':
    activation_fn = stax.Identity()
  elif config.activation.lower() == 'gelu':
    activation_fn = stax.Gelu()
  else:
    raise ValueError('activation function %s not implemented' %
                     config.activation)
  return activation_fn


def _get_norm_layer(normalization):
  normalization = normalization.lower()
  if 'layer' in normalization:
    norm_layer = stax.LayerNorm(axis=(1, 2, 3))
  elif 'instance' in normalization:
    norm_layer = stax.LayerNorm(axis=(1, 2))
  elif normalization == '':
    norm_layer = stax.Identity()
  else:
    raise ValueError('normalization %s not implemented' % normalization)
  return norm_layer
 

def _ConvNet(config):
  """Instantiates `ConvNet` from the fields of an experiment config.

  `use_gap`, `normalization` and `image_format` are optional and default to
  False, '' and 'NHWC' respectively; all other fields must be present.
  """
  net_kwargs = dict(
      depth=config.depth,
      width=config.width,
      prepend_conv_layer=config.prepend_conv_layer,
      use_gap=config.get('use_gap', False),
      W_std=config.W_std,
      b_std=config.b_std,
      num_classes=config.dataset.num_classes,
      parameterization=config.parameterization,
      activation_fn=_get_activation_fn(config),
      norm_layer=_get_norm_layer(config.get('normalization', '')),
      image_format=config.get('image_format', 'NHWC'),
  )
  return ConvNet(**net_kwargs)


def ConvNet(
    depth: int,
    width: int,
    prepend_conv_layer: bool = True,
    use_gap: bool = False,
    W_std=2**0.5,
    b_std=0.1,
    num_classes: int = 10,
    parameterization: str = 'ntk',
    activation_fn=stax.Relu(),
    norm_layer=stax.Identity(),
    image_format: str = 'NHWC'):
  """Adaptation of ConvNet baseline of Dataset Condensation.

  Original architecture is based on (Gidaris & Komodakis, 2018)
  and here we adapt version of Zhao et al., Dataset Condensation with Gradient
  Matching, https://openreview.net/pdf?id=mSAKhLYLSsl

  Implements depth-many blocks of convolution, normalization, activation and
  2x2 average pooling. By default `norm_layer` is the identity, i.e. no
  normalization is applied.

  For the 'ConvNet' settings of Nguyen et al., Dataset Distillation with
  Infinitely Wide Convolutional Networks, set depth=3 (width is immaterial for
  'ntk' parameterization) and prepend_conv_layer=True. For 'ConvNet3' settings
  closer to Zhao et al., set prepend_conv_layer=False.

  Args:
    depth: depth of network
    width: width of network
    prepend_conv_layer: if True, add an additional conv and relu layer before
      main set of blocks
    use_gap: if True, use global average pooling for preclassifier layer
    W_std: standard deviation of weight matrix initialization
    b_std: standard deviation of bias initialization
    num_classes: number of classes for output layer
    parameterization: 'ntk' or 'standard' for initializing network and NTK
    activation_fn: NT activation function of network
    norm_layer: NT normalization layer, default is Identity.
    image_format: Image format 'NHWC', 'NCHW' etc. It is consumed by the first
      convolution of the network, which outputs 'NHWC'.

  Returns:
    Corresponding neural_tangents stax model.
  """
  conv = functools.partial(
      stax.Conv,
      W_std=W_std,
      b_std=b_std,
      padding='SAME',
      parameterization=parameterization)
  # Whichever convolution comes first must translate `image_format` into NHWC.
  # Previously this was only done for the prepended layer, so `image_format`
  # was ignored when prepend_conv_layer=False.
  # NOTE(review): relies on stax.Conv treating an unspecified
  # `dimension_numbers` as ('NHWC', 'HWIO', 'NHWC'), so that the default
  # image_format='NHWC' behaves as before — confirm against neural_tangents.
  input_conv = functools.partial(
      conv, dimension_numbers=(image_format, 'HWIO', 'NHWC'))

  layers = []
  if prepend_conv_layer:
    layers += [input_conv(width, (3, 3)), activation_fn]

  # generate blocks of convolutions followed by average pooling
  for _ in range(depth):
    # `layers` is empty only for the first block of a network without the
    # prepended layer.
    block_conv = conv if layers else input_conv
    layers += [block_conv(width, (3, 3)), norm_layer,
               activation_fn, stax.AvgPool((2, 2), strides=(2, 2))]

  if use_gap:
    layers.append(stax.GlobalAvgPool())
  else:
    layers.append(stax.Flatten())
  layers.append(stax.Dense(num_classes, W_std, b_std,
                           parameterization=parameterization))

  return stax.serial(*layers)

## Define Trainer

In [None]:
def run_trainer(data, config):
  """Train a neural network.

  Builds the ConvNet described by `config`, estimates a maximum learning rate
  from the empirical NTK on a few validation images, then trains with the
  momentum optimizer and a cosine learning-rate schedule for
  `config.train_steps` steps. Train/valid metrics are printed at
  logarithmically spaced steps, every 250 steps and at the final step; test
  metrics are printed once after training.

  Args:
    data: dict with 'train', 'valid' and 'test' entries, each a dict holding
      'images' and 'labels'. Only the first 1000 validation examples are used
      for evaluation.
    config: experiment config. Reads `batch_size`, `eval_batch_size`,
      `train_steps`, `l2_regularization`, `seed`, `empirical_ntk_num_inputs`,
      `max_lr_factor`, `momentum`, `loss_type`, `dataset.num_classes` and the
      architecture fields consumed by `_ConvNet`.

  Returns:
    A list with one entry per evaluation:
    [step, epoch, learning rate, train loss, valid loss, train accuracy,
     valid accuracy, elapsed seconds].

  Raises:
    ValueError: if the raw training loss becomes non-finite.
  """

  # Experiment Parameters.
  batch_size = config.batch_size
  eval_batch_size = config.get('eval_batch_size', config.batch_size)
  train_size = data['train']['images'].shape[0]
  steps_per_epoch = int(np.ceil(train_size / batch_size))
  train_steps = int(config.train_steps)

  l2_lambda = config.l2_regularization

  key = jax.random.PRNGKey(config.seed)

  # Construct tf.data
  train_ds = tf.data.Dataset.from_tensor_slices({
      'images': data['train']['images'],
      'labels': data['train']['labels'],
  }).repeat().shuffle(
      data['train']['images'].shape[0], seed=0).batch(batch_size).as_numpy_iterator()

  # This is used for computing training metrics.
  train_eval_ds = tf.data.Dataset.from_tensor_slices({
      'images': data['train']['images'],
      'labels': data['train']['labels'],
  }).batch(eval_batch_size)

  valid_ds = tf.data.Dataset.from_tensor_slices({
      'images': data['valid']['images'][:1000], # Smaller validation set size for notebook
      'labels': data['valid']['labels'][:1000],
  }).batch(eval_batch_size)

  test_ds = tf.data.Dataset.from_tensor_slices({
      'images': data['test']['images'],
      'labels': data['test']['labels'],
  }).batch(eval_batch_size)

  # Initialize Network.
  network = _ConvNet(config)
  init_f, f, _ = network

  key, split = jax.random.split(key)
  _, init_params = init_f(split, (-1,) + data['train']['images'].shape[1:])

  # Estimate maximum learning rate
  def logit_reduced_f(params, x):
    # Collapse the logits to one scalar per example, scaled by 1/sqrt(#logits).
    out = f(params, x)
    return jnp.sum(out, axis=-1) / out.shape[-1]**(1 / 2)

  input_sample = data['valid']['images'][:config.empirical_ntk_num_inputs]

  empirical_kernel_fn = lambda x1, x2, params: nt.empirical_ntk_fn(
      logit_reduced_f, trace_axes=(), vmap_axes=0, implementation=1)(x1, x2,
                                                                     params)
  empirical_kernel_fn = nt.batch(empirical_kernel_fn, batch_size=10)

  logging.info('input_sample shape: %s', input_sample.shape)
  max_lr_estimate_start = time.time()
  kernel = empirical_kernel_fn(input_sample, None, init_params)
  logging.info('kernel shape: %s', kernel.shape)

  y_train_size = kernel.shape[0] * config.dataset.num_classes
  assert y_train_size == data['valid']['labels'][:config.empirical_ntk_num_inputs].size
  max_lr = nt.predict.max_learning_rate(
      ntk_train_train=kernel, y_train_size=y_train_size, eps=1e-12)
  print('Max LR estimate took: %.2fs'%(time.time() - max_lr_estimate_start))

  learning_rate = float(max_lr * config.max_lr_factor)
  print('max LR: %f, current LR: %f'%(max_lr, learning_rate))

  # Define Raw loss, Accuracy, and Optimizer.
  @jax.jit
  def raw_loss(params, batch):
    """Loss without weight decay."""
    images, labels = batch['images'], batch['labels']
    # NOTE(review): a missing `loss_type` defaults to 'xent' here but to 'mse'
    # in `_load_dataset` — confirm which default is intended.
    loss_type = config.get('loss_type', 'xent')
    if loss_type == 'xent':
      return cross_entropy(ostax.logsoftmax(f(params, images)), labels)
    elif loss_type == 'mse':
      return mse_loss(f(params, images), labels)
    else:
      raise NotImplementedError('Loss type %s not implemented:' % loss_type)

  @jax.jit
  def loss(params, batch):
    """Raw loss plus L2 weight decay."""
    l2_loss = 0.5 * l2_lambda * l2_regularization(params)
    return raw_loss(params, batch) + l2_loss

  grad_loss = jax.jit(jax.grad(loss))

  @jax.jit
  def accuracy(params, batch):
    """Fraction of examples whose argmax prediction matches the label."""
    images, labels = batch['images'], batch['labels']
    return jnp.mean(
        jnp.array(
            jnp.argmax(f(params, images), axis=1) == jnp.argmax(labels, axis=1),
            dtype=np.float32))

  learning_rate_fn = cosine_schedule(learning_rate, config.train_steps)
  print('Using momentum optimizer.')
  opt_init_fn, opt_apply_fn, get_params = optimizers.momentum(
      learning_rate_fn, config.momentum)

  opt_apply_fn = jax.jit(opt_apply_fn)
  state = opt_init_fn(init_params)
  del init_params  # parameters obtained from optimizer state

  # Define Update and Evaluate Function.
  @jax.jit
  def update(step, state, batch):
    """Applies one optimizer update and returns the new optimizer state."""
    params = get_params(state)
    dparams = grad_loss(params, batch)
    return opt_apply_fn(step, dparams, state)

  def dataset_evaluate(state, dataset):
    """Compute loss and accuracy metrics over entire dataset."""
    params = get_params(state)
    tot_metrics ={'raw_loss':0., 'loss': 0., 'correct': 0, 'count': 0}
    for eval_batch in dataset.as_numpy_iterator():
      # Weight each batch by its size so a smaller final batch is handled.
      eval_size = eval_batch['images'].shape[0]
      tot_metrics['raw_loss'] += raw_loss(params, eval_batch) * eval_size
      tot_metrics['loss'] += loss(params, eval_batch) * eval_size
      tot_metrics['correct'] += accuracy(params, eval_batch) * eval_size
      tot_metrics['count']  += eval_size
    metric ={'raw_loss': tot_metrics['raw_loss'] / tot_metrics['count'],
             'loss': tot_metrics['loss'] / tot_metrics['count'],
             'accuracy': tot_metrics['correct']  / tot_metrics['count'] }

    return metric

  measurements = []
  # Define logging steps.
  log_max_steps = np.log10(train_steps)
  log_steps = [0] + sorted(
      list(set([int(10**t) for t in np.linspace(0.0, log_max_steps, 10)])))
  start_time = time.time()

  hparams_json = config.to_json_best_effort(indent=2)
  # `print` does not interpolate extra arguments, so format explicitly.
  print('hparams: %s' % hparams_json)

  print('Total training steps %d, steps_per_epoch %d' %
        (train_steps, steps_per_epoch))
  print('Step (Epoch)\tLearning Rate\tTrain Loss\tValid Loss\t'
        'Train Acc\tValid Acc\tTime Elapsed\tEval Time')
  # Iteration i evaluates the model after i updates; the final iteration
  # (i == train_steps) only evaluates, so the test metrics below describe the
  # same parameters as the last reported train/valid metrics.
  for i in range(train_steps + 1):
    epoch = _epoch_from_step(i, train_size, batch_size)

    if i in log_steps or i % 250 == 0 or i == train_steps:
      eval_start_time = time.time()
      train_metric = dataset_evaluate(state, train_eval_ds)
      if not jnp.isfinite(train_metric['raw_loss']):
        msg = 'NaN during Training! Terminating current trial.'
        raise ValueError(msg)
      valid_metric = dataset_evaluate(state, valid_ds)

      eval_time = time.time() - eval_start_time
      elapsed_time = float(time.time() - start_time)
      lr = float(learning_rate_fn(i))
      measurements.append([
          i, epoch, lr,
          train_metric['loss'],
          valid_metric['loss'],
          train_metric['accuracy'],
          valid_metric['accuracy'], elapsed_time
      ])
      print(
          ('{:06d}\t({:06.1f})\t' + ('{:.6e}\t' * 3) + ('{:.6f}\t' * 4)).format(
              i, epoch, lr, train_metric['loss'], valid_metric['loss'],
              train_metric['accuracy'], valid_metric['accuracy'],
              elapsed_time, eval_time))

    if i < train_steps:
      state = update(i, state, next(train_ds))

  print('Training finished')
  test_metric = dataset_evaluate(state, test_ds)
  print('Step\tEpoch\tLearning Rate\tTrain Loss\tValid Loss\tTest Loss\t'
        'Train Acc\tValid Acc\tTest Acc\tTime Elapsed')
  print(('{:06d}\t({:06.1f})\t' + ('{:.6e}\t' * 4) + ('{:.6f}\t' * 4)).format(
      i, epoch, lr, train_metric['loss'], valid_metric['loss'],
      test_metric['loss'], train_metric['accuracy'],
      valid_metric['accuracy'], test_metric['accuracy'],
      time.time() - start_time))
  return measurements

## Run an Experiment with Trainer

In [None]:
# Hide all GPUs from TensorFlow, which is only used here for the tf.data input
# pipeline. Presumably this keeps TF from reserving GPU memory that JAX
# needs — TODO confirm.
tf.config.experimental.set_visible_devices([], 'GPU')

# Build the default config, load the eval splits together with the distilled
# training set, then train and keep the logged measurements.
config = get_config()
data = load_kip_data(config)
measurements = run_trainer(data, config)

Max LR estimate took: 32.74s
max LR: 0.006775, current LR: 0.006775
Using momentum optimizer.
hparams: %s {
  "seed": 0,
  "dataset": {
    "name": "cifar10",
    "height": 32,
    "width": 32,
    "num_channels": 3,
    "num_classes": 10
  },
  "preprocess_type": "zca_normalize",
  "zca_reg": 0.1,
  "optimizer": "momentum",
  "momentum": 0.9,
  "batch_size": 100,
  "eval_batch_size": 100,
  "train_steps": 5000,
  "empirical_ntk_num_inputs": 50,
  "max_lr_factor": 1.0,
  "l2_regularization": 0.001,
  "width": 1024,
  "depth": 3,
  "use_gap": false,
  "W_std": 1.4142135623730951,
  "b_std": 0.1,
  "activation": "relu",
  "loss_type": "mse",
  "parameterization": "standard",
  "kip": {
    "data_ckpt_path": "gs://kip-datasets/kip/cifar10/ssize100_zca_nol_noaug_ckpt1000.npz"
  }
}
Total training steps 5000, steps_per_epoch 1
Step (Epoch)	Learning Rate	Train Loss	Valid Loss	Train Acc	Valid Acc	Time Elapsed	Eval Time
000000	(0000.0)	6.775185e-03	4.233754e+00	4.299506e+00	0.100000	0.117000	8