<a href="https://colab.research.google.com/github/FreddieNeverLeft/DP-RFAD/blob/main/DP-RFAD_JIT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# setup

## Imports

In [8]:
import functools

from jax.example_libraries import optimizers
import jax
import jax.config
from jax.config import config as jax_config
jax_config.update('jax_enable_x64', True) # for numerical stability, can disable if not an issue
from jax import numpy as jnp
from jax import scipy as sp
from jax import random
import numpy as np
import tensorflow as tf

tf.config.experimental.set_visible_devices([], 'GPU')
import tensorflow_datasets as tfds

import matplotlib.pyplot as plt
%matplotlib inline

# ## Flax (NN in JAX)
# try:
#     import flax
# except ModuleNotFoundError: # Install flax if missing
#     !pip install --quiet flax
#     import flax
# from flax import linen as nn
# from flax.training import train_state, checkpoints

# try: 
#   import neural_tangents as nt
# except ModuleNotFoundError: # Install neural_tangents if missing
#   !pip install -q git+https://www.github.com/google/neural-tangents
#   import neural_tangents as nt
# from neural_tangents import stax

# try:
#   import haiku as hk
# except ModuleNotFoundError: # Install neural_tangents if missing
#   !pip install git+https://github.com/deepmind/dm-haiku
#   import haiku as hk
!pip install torch_optimizer

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting torch_optimizer
  Downloading torch_optimizer-0.3.0-py3-none-any.whl (61 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.9/61.9 KB[0m [31m2.9 MB/s[0m eta [36m0:00:00[0m
Collecting pytorch-ranger>=0.1.1
  Downloading pytorch_ranger-0.1.1-py3-none-any.whl (14 kB)
Installing collected packages: pytorch-ranger, torch_optimizer
Successfully installed pytorch-ranger-0.1.1 torch_optimizer-0.3.0


In [None]:
!git clone https://github.com/yolky/RFAD

In [None]:
!python3 RFAD/run_distillation.py --dataset cifar10 --save_path path/to/directory/ --samples_per_class 10 --platt --learn_labels

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz
100% 170498071/170498071 [00:02<00:00, 78364157.70it/s]
Extracting ./data/cifar-10-python.tar.gz to ./data/
Files already downloaded and verified


## Define Parameters

In [None]:
# architecture params
ARCHITECTURE = 'FC' #@param ['FC', 'Conv', 'Myrtle']
# choice of neural network architecture yielding the corresponding NTK
DEPTH =  1#@param {'type': int}; depth of neural network
WIDTH = 100 #@param {'type': int}; width of finite width neural network; only used if parameterization is 'standard'
PARAMETERIZATION = 'standard' #@param ['ntk', 'standard']
# whether to use standard or NTK parameterization, see https://arxiv.org/abs/2001.07301

# dataset
DATASET = 'cifar10' #@param ['cifar10', 'cifar100', 'mnist', 'svhn_cropped']

# training params
LEARNING_RATE = 1e-2 #@param {'type': float};
SUPPORT_SIZE = 200  #@param {'type': int}; number of images to learn
TARGET_BATCH_SIZE =60000  #@param {'type': int}; number of target images to use in KRR for each step

## Load Data

In [None]:
def get_tfds_dataset(name):
  ds_train, ds_test = tfds.as_numpy(
      tfds.load(
          name,
          split=['train', 'test'],
          batch_size=-1,
          as_dataset_kwargs={'shuffle_files': False}))

  return ds_train['image'], ds_train['label'], ds_test['image'], ds_test['label']

def one_hot(x,
            num_classes,
            center=False,
            dtype=np.float32):
  assert len(x.shape) == 1
  one_hot_vectors = np.array(x[:, None] == np.arange(num_classes), dtype)
  if center:
    one_hot_vectors = one_hot_vectors - 1. / num_classes
  return one_hot_vectors

def get_normalization_data(arr):
  channel_means = np.mean(arr, axis=(0, 1, 2))
  channel_stds = np.std(arr, axis=(0, 1, 2))
  return channel_means, channel_stds

def normalize(arr, mean, std):
  return (arr - mean) / std

X_TRAIN_RAW, LABELS_TRAIN, _, _ = get_tfds_dataset(DATASET)
# X_TRAIN_RAW, LABELS_TRAIN, X_TEST_RAW, LABELS_TEST = get_tfds_dataset(DATASET)
Y_TRAIN = one_hot(LABELS_TRAIN, 10)
# channel_means, channel_stds = get_normalization_data(X_TRAIN_RAW)
# X_TRAIN, X_TEST = normalize(X_TRAIN_RAW, channel_means, channel_stds), normalize(X_TEST_RAW, channel_means, channel_stds)
# Y_TRAIN, Y_TEST = one_hot(LABELS_TRAIN, 10), one_hot(LABELS_TEST, 10) 

[1mDownloading and preparing dataset 162.17 MiB (download: 162.17 MiB, generated: 132.40 MiB, total: 294.58 MiB) to ~/tensorflow_datasets/cifar10/3.0.2...[0m


Dl Completed...: 0 url [00:00, ? url/s]

Dl Size...: 0 MiB [00:00, ? MiB/s]

Extraction completed...: 0 file [00:00, ? file/s]

Generating splits...:   0%|          | 0/2 [00:00<?, ? splits/s]

Generating train examples...:   0%|          | 0/50000 [00:00<?, ? examples/s]

Shuffling ~/tensorflow_datasets/cifar10/3.0.2.incomplete3451FX/cifar10-train.tfrecord*...:   0%|          | 0/…

Generating test examples...:   0%|          | 0/10000 [00:00<?, ? examples/s]

Shuffling ~/tensorflow_datasets/cifar10/3.0.2.incomplete3451FX/cifar10-test.tfrecord*...:   0%|          | 0/1…

[1mDataset cifar10 downloaded and prepared to ~/tensorflow_datasets/cifar10/3.0.2. Subsequent calls will reuse this data.[0m


## Define Kernel

In [None]:
# define architectures

def FullyConnectedNetwork( 
    depth,
    width,
    W_std = np.sqrt(2), 
    b_std = 0.1,
    num_classes = 10,
    parameterization = 'ntk',
    activation = 'relu'):
  """Returns neural_tangents.stax fully connected network."""
  activation_fn = stax.Relu()
  dense = functools.partial(
      stax.Dense, W_std=W_std, b_std=b_std, parameterization=parameterization)

  layers = [stax.Flatten()]
  for _ in range(depth):
    layers += [dense(width), activation_fn]
  layers += [stax.Dense(num_classes, W_std=W_std, b_std=b_std, 
                        parameterization=parameterization)]

  return stax.serial(*layers)

def FullyConvolutionalNetwork( 
    depth,
    width,
    W_std = np.sqrt(2), 
    b_std = 0.1,
    num_classes = 10,
    parameterization = 'ntk',
    activation = 'relu'):
  """Returns neural_tangents.stax fully convolutional network."""
  activation_fn = stax.Relu()
  conv = functools.partial(
      stax.Conv,
      W_std=W_std,
      b_std=b_std,
      padding='SAME',
      parameterization=parameterization)
  
  layers = [stax.Flatten()]
  for _ in range(depth):
    layers += [conv(width, (3,3)), activation_fn]
  layers += [stax.Flatten(), stax.Dense(num_classes, W_std=W_std, b_std=b_std,
                                        parameterization=parameterization)]

  return stax.serial(*layers)

def MyrtleNetwork(  
    depth,
    width,
    W_std = np.sqrt(2), 
    b_std = 0.1,
    num_classes = 10,
    parameterization = 'ntk',
    activation = 'relu'):
  """Returns neural_tangents.stax Myrtle network."""
  layer_factor = {5: [1, 1, 1], 7: [1, 2, 2], 10: [2, 3, 3]}
  if depth not in layer_factor.keys():
    raise NotImplementedError(
        'Myrtle network withd depth %d is not implemented!' % depth)
  activation_fn = stax.Relu()
  layers = []
  conv = functools.partial(
      stax.Conv,
      W_std=W_std,
      b_std=b_std,
      padding='SAME',
      parameterization=parameterization)
  layers += [conv(width, (3, 3)), activation_fn]

  # generate blocks of convolutions followed by average pooling for each
  # layer of layer_factor except the last
  for block_depth in layer_factor[depth][:-1]:
    for _ in range(block_depth):
      layers += [conv(width, (3, 3)), activation_fn]
    layers += [stax.AvgPool((2, 2), strides=(2, 2))]

  # generate final blocks of convolution followed by global average pooling
  for _ in range(layer_factor[depth][-1]):
    layers += [conv(width, (3, 3)), activation_fn]
  layers += [stax.GlobalAvgPool()]

  layers += [
      stax.Dense(num_classes, W_std, b_std, parameterization=parameterization)
  ]

  return stax.serial(*layers)

def get_kernel_fn(architecture, depth, width, parameterization):
  if architecture == 'FC':
    return FullyConnectedNetwork(depth=depth, width=width, parameterization=parameterization)
  elif architecture == 'Conv':
    return FullyConvolutionalNetwork(depth=depth, width=width, parameterization=parameterization)
  elif architecture == 'Myrtle':
    return MyrtleNetwork(depth=depth, width=width, parameterization=parameterization)
  else:
    raise NotImplementedError(f'Unrecognized architecture {architecture}')

# Run KIP

In [None]:
from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)
print(X_TRAIN_RAW.shape)

gpu
(50000, 32, 32, 3)


In [None]:
_, _, k_f = get_kernel_fn(ARCHITECTURE, DEPTH, WIDTH, PARAMETERIZATION)
KERNEL_FN = jax.jit(functools.partial(k_f, get='ntk'))
KERNEL_FN = nt.batch(KERNEL_FN, device_count = -1, batch_size=5)

In [None]:
def get_class_balanced_sample(sample_size: int,
                            apply_fn,
                            n_classes:int,
                            key, d_code = 90):

  def class_balanced_sample(params, state):
    
    n_per_class, remainder = divmod(sample_size, n_classes)
    if remainder != 0:
      raise ValueError(
          f'Number of classes {n_classes} in labels must divide sample size {sample_size}.'
      )

    return_labels = False
    labels = jnp.repeat(jnp.arange(n_classes), n_per_class)
    code = jax.random.normal(key, (sample_size, d_code))
    gen_one_hots = jax.nn.one_hot(labels, n_classes)
    code = jnp.concatenate([code, gen_one_hots], axis=1)[:,:, None, None]
    gen_samples, state = apply_fn(params, state, code, is_training=True)

    return gen_samples, gen_one_hots, n_per_class, state

  return class_balanced_sample


def make_loss_acc_fn(kernel_fn, x_true, y_true, class_balanced_sample, n_classes = 10):
  x_true = x_true.reshape(*x_true.shape[:-3], -1)
  y_true = np.argmax(y_true, axis =1)
  sum_Ktt = np.zeros(n_classes)
  x_true_c = []
  # x_true_c = np.array([])
  for c in range(n_classes):
    x_true_c.append(x_true[y_true == c].astype(jnp.float32))
    print(x_true_c[c].shape)
    # print(kernel_fn(x_true_c[c], x_true_c[c]))
  x_true_c = np.array(x_true_c)
  print(x_true_c.shape)
  for c in range(n_classes):
    sum_Ktt[c] += np.sum(kernel_fn(x_true_c[c], x_true_c[c]))
    # print(sum_Ktt[c])

  @jax.jit
  def loss_acc_fn(params, state):
    x_syn, y_syn, n_per_class, state = class_balanced_sample(params, state)
    mmd_loss = 0
    y_syn = jnp.argmax(y_syn, axis =1)
    for c in range(n_classes):
      # x_true_c = x_true[y_true == c]
      x_syn_c = x_syn[c*n_per_class:(c+1)*n_per_class]
      x_syn_c = x_syn_c.reshape((n_per_class, 32,32,3))
      m = x_true_c[c].shape[0]
      n = x_syn_c.shape[0]
      # batch these ==> for now, not the bottleneck. can save(?)
      sum_Kss = jnp.sum(kernel_fn(x_syn_c, x_syn_c)).block_until_ready()
      sum_Kts = jnp.sum(kernel_fn(x_true_c[c], x_syn_c)).block_until_ready()
      # print(c, mmd_loss)

      mmd_loss += jnp.sqrt(sum_Ktt[c]/(m*m) - 2 * sum_Kts/(m*n)
                           + sum_Kss/(n*n))
    return mmd_loss, state

  return loss_acc_fn

# Taking loss_acc_fn and class_balanced_sample outside.
def get_update_functions(init_params, loss_acc_fn, lr):
  opt_init, opt_update, get_params = optimizers.adam(lr) 
  opt_state = opt_init(init_params)
  # loss_acc_fn = make_loss_acc_fn(kernel_fn)
  # class_balanced_sample = get_class_balanced_sample(sample_size, apply_fn, n_classes, key) 
  value_and_grad = jax.value_and_grad(lambda params, state:loss_acc_fn(params, state), has_aux=True)
  # value_and_grad = jax.pmap(value_and_grad)
  

  @jax.jit
  def update_fn(step, opt_state, params, state, x_true, y_true):
    (loss, state), grad = value_and_grad(params, state)
    return opt_update(step, grad, opt_state), (loss, state)

  return opt_state, get_params, update_fn

def train(num_train_steps, log_freq=1000, seed=1):

  d_code = 90
  n_classes = 10
  # key = random.PRNGKey(0)
  # key, split = jax.random.split(key)

  def get_gen(batch, is_training):
      model = hk.nets.ResNet18(
          num_classes=3072, resnet_v2 = True)
      return model(batch, is_training=is_training)

  forward_gen = hk.without_apply_rng(hk.transform_with_state(get_gen))

  rng = jax.random.PRNGKey(42)

  gen_input = jnp.ones((1, d_code + n_classes, 1, 1))
  gen_params_init, state = forward_gen.init(rng, gen_input, is_training=True)
  # logits, state = forward_gen.apply(params, state, image, is_training=True)

  class_balanced_sample = get_class_balanced_sample(SUPPORT_SIZE, forward_gen.apply, n_classes, rng)
  # x_syn, y_syn, state = class_balanced_sample(gen_params_init, state)

  loss_acc_fn = make_loss_acc_fn(KERNEL_FN, X_TRAIN_RAW, Y_TRAIN, class_balanced_sample)
  opt_state, get_params, update_fn = get_update_functions(gen_params_init, loss_acc_fn, LEARNING_RATE)
  params = get_params(opt_state) # getting the latest nn parmeters

  for i in range(num_train_steps):
    opt_state, aux = update_fn(i, opt_state, params, state, X_TRAIN_RAW, Y_TRAIN)
    train_loss, state = aux
    params = get_params(opt_state)
    if i % log_freq == 0:
      print(f'----step {i}:')
      print('train loss:', train_loss)

  class_balanced_sample = get_class_balanced_sample(5000, forward_gen.apply, n_classes, rng)
  x_syn, y_syn, n_per_class, state = class_balanced_sample(params, state)
  x_syn = x_syn.reshape(5000, 32 ,32, 3)
  y_syn = jnp.argmax(y_syn, axis =1)
  fig = plt.figure(figsize=(100,100))
  fig.suptitle('DP-MENT Generation', fontsize=20, y=1.02)
  class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
  x_syn_c0 = []
  for c in range(10):
    x_syn_c0.append(x_syn[y_syn == c][0:9])

  for i in range(10):
    for j, img in enumerate(x_syn_c0[i]):
      ax = plt.subplot(10, 10, j+i*10+1)
      if i == 0:
        ax.set_title(class_names[j])
      plt.imshow(np.squeeze(img.astype(jnp.uint8)))


  return x_syn, y_syn, gen_params_init, state
  
x_syn, y_syn, params, state = train(1)
# print(x_syn.dtype)

(5000, 3072)
(5000, 3072)
(5000, 3072)
(5000, 3072)
(5000, 3072)
(5000, 3072)
(5000, 3072)
(5000, 3072)
(5000, 3072)
(5000, 3072)
(10, 5000, 3072)


In [None]:
import torch as pt
x_syn = np.transpose(np.asarray(x_syn, dtype=np.float32), (0, 3, 1, 2))
x_syn = pt.from_numpy(x_syn).type(pt.float32)

# x_real = np.transpose(X_TRAIN_RAW, (0, 3, 1, 2))
# x_real = pt.from_numpy(x_real).type(pt.float32)

np.savez('./gen_data.npz', x=x_syn, y=y_syn)

In [None]:
!pip install pytorch-fid
from pytorch_fid.fid_score import calculate_frechet_distance
from pytorch_fid.inception import InceptionV3

!pip install cloud-tpu-client==0.10 torch==1.12.0 https://storage.googleapis.com/tpu-pytorch/wheels/colab/torch_xla-1.12-cp37-cp37m-linux_x86_64.whl
# imports pytorch
import torch

# imports the torch_xla package
import torch_xla
import torch_xla.core.xla_model as xm

In [None]:
import torch as pt
from tqdm import tqdm

class SynthDataset(pt.utils.data.Dataset):
  def __init__(self, data, targets, to_tensor):
    self.labeled = targets is not None
    self.data = data
    self.targets = targets
    self.to_tensor = to_tensor

  def __len__(self):
    return len(self.data)

  def __getitem__(self, idx):
    d = pt.tensor(self.data[idx], dtype=pt.float32) if self.to_tensor else self.data[idx]
    if self.labeled:
      t = pt.tensor(self.targets[idx], dtype=pt.long) if self.to_tensor else self.targets[idx]
      return d, t
    else:
      return d

def load_synth_dataset(data, batch_size, subset_size=None, to_tensor=False, shuffle=True):
  # if data_file.endswith('.npz'):  # allow for labels
  #   data_dict = np.load(data_file)
  #   data = data_dict['x']
  #   if 'y' in data_dict.keys():
  #     targets = data_dict['y']
  #     if len(targets.shape) > 1:
  #       targets = np.squeeze(targets)
  #       assert len(targets.shape) == 1, f'need target vector. shape is {targets.shape}'
  #   else:
  #     targets = None

  #   if subset_size is not None:
  #     random_subset = np.random.permutation(data_dict['x'].shape[0])[:subset_size]
  #     data = data[random_subset]
  #     targets = targets[random_subset] if targets is not None else None
  synth_data = SynthDataset(data=data, targets=None, to_tensor=to_tensor)
  
  synth_dataloader = pt.utils.data.DataLoader(synth_data, batch_size=batch_size, shuffle=shuffle,
                                              drop_last=False, num_workers=1)
  return synth_dataloader

def stats_from_dataloader(dataloader, model, device= xm.xla_device()):
    """
  Returns:
  -- mu    : The mean over samples of the activations of the pool_3 layer of
             the inception model.
  -- sigma : The covariance matrix of the activations of the pool_3 layer of
             the inception model.
  """
    model.eval()

    pred_list = []

    start_idx = 0

    for batch in tqdm(dataloader):
        x = batch[0] if (isinstance(batch, tuple) or isinstance(batch, list)) else batch
        x = x.to(device)

        with pt.no_grad():
            pred = model(x)[0]

        # If model output is not scalar, apply global spatial average pooling.
        # This happens if you choose a dimensionality not equal 2048.
        if pred.size(2) != 1 or pred.size(3) != 1:
            pred = pt.nn.adaptive_avg_pool2d(pred, output_size=(1, 1))

        pred = pred.squeeze(3).squeeze(2).cpu().numpy()

        # pred_arr[start_idx:start_idx + pred.shape[0]] = pred
        pred_list.append(pred)

        start_idx = start_idx + pred.shape[0]

    pred_arr = np.concatenate(pred_list, axis=0)
    # return pred_arr
    mu = np.mean(pred_arr, axis=0)
    sigma = np.cov(pred_arr, rowvar=False)
    return mu, sigma

def get_fid_scores(x_syn, x_real):
    dims = 2048

    block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]

    model = InceptionV3([block_idx])

    x_syn = np.transpose(np.asarray(x_syn, dtype=np.float32), (0, 3, 1, 2))
    x_syn = pt.from_numpy(x_syn).type(pt.float32)

    x_real = np.transpose(x_real, (0, 3, 1, 2))
    x_real = pt.from_numpy(x_real).type(pt.float32)

    dev = xm.xla_device()

    x_syn.to(dev)
    x_real.to(dev)


    # stats = np.load(real_data_stats_file)
    # mu_real, sig_real = stats['mu'], stats['sig']

    real_dl = load_synth_dataset(x_real, 50)

    mu_real, sig_real = stats_from_dataloader(real_dl, model)

    print(mu_real.shape, sig_real.shape)

    mu_syn, sig_syn = stats_from_dataloader(x_syn, model)

    fid = calculate_frechet_distance(mu_real, sig_real, mu_syn, sig_syn)
    return fid

RuntimeError: ignored

In [None]:
# x_true.reshape(*x_syn.shape[:-3], -1)
# mu_syn = np.mean(x_syn.reshape(*x_syn.shape[:-3], -1), axis=0)
# sig_syn = np.cov(x_syn.reshape(*x_syn.shape[:-3], -1), rowvar=False)

# mu_real = np.mean(X_TRAIN_RAW.reshape(*X_TRAIN_RAW.shape[:-3], -1), axis=0)
# sig_real = np.cov(X_TRAIN_RAW.reshape(*X_TRAIN_RAW.shape[:-3], -1), rowvar=False)

# # mu_syn = np.squeeze(np.asarray(mu_syn))
# # sig_syn = np.squeeze(np.asarray(sig_syn))
# print(mu_real.shape, sig_real.shape)
# print(mu_syn.shape, sig_syn.shape) 

fid = get_fid_scores(x_syn, X_TRAIN_RAW)
print(f'fid={fid}')

In [None]:
# init_fn, f, _ = get_kernel_fn(ARCHITECTURE, DEPTH, WIDTH, PARAMETERIZATION)
key1 = random.PRNGKey(1)
# _, params_kernel = init_fn(key1, jnp.ones((1, 32,32,3)).shape)
# KERNEL_FN_UNBATCHED = functools.partial(kernel_fn, get='ntk')
# kernel_fn = nt.empirical_kernel_fn(
#     f, trace_axes=(), vmap_axes=0, implementation=1)
# kernel_fn = functools.partial(kernel_fn, get='ntk', params = params_kernel)

# KERNEL_FN = jax.jit(nt.batch(kernel_fn, device_count = -1, batch_size=25))
# KERNEL_FN = functools.partial(kernel_fn_batched, get='ntk')
# TODO: try the empirical one next?
# KERNEL_FN = jax.jit(functools.partial(kernel_fn_batched, get='ntk'))
# INIT_FN = jax.jit(functools.partial(init_fn))

# init_fn, f, _ = stax.serial(
#     stax.Conv(32, (3, 3)),
#     stax.Relu(),
#     stax.Conv(32, (3, 3)),
#     stax.Relu(),
#     stax.Conv(32, (3, 3)),
#     stax.Flatten(),
#     stax.Dense(10)
# )

# _, params_kernel = init_fn(key1, jnp.ones((1, 32,32,3)).shape)
#
# Default setting: reducing over logits; pass `vmap_axes=0` because the
# network is iid along the batch axis, no BatchNorm. Use default
# `implementation=nt.NtkImplementation.JACOBIAN_CONTRACTION` (`1`).
# kernel_fn = nt.empirical_kernel_fn(
#     f, trace_axes=(-1,), vmap_axes=0,
#     implementation=nt.NtkImplementation.JACOBIAN_CONTRACTION)

# kernel_fn = functools.partial(kernel_fn, get='ntk', params = params_kernel)

_, _, kernel_fn = stax.serial(
    stax.Conv(32, (3, 3)),
    stax.Relu(),
    stax.Conv(32, (3, 3)),
    stax.Relu(),
    stax.Conv(32, (3, 3)),
    stax.Flatten(),
    stax.Dense(10)
)

_, _, kernel_fn = stax.serial(
    stax.Conv(32, (3, 3)),
    stax.Flatten(),
    stax.Dense(10)
)

kernel_fn = jax.jit(functools.partial(kernel_fn, get='ntk'))
KERNEL_FN = nt.batch(kernel_fn, device_count = 5, batch_size=40)