In [1]:
import os
import pathlib
import random
import functools
import copy
import sys
import pickle
import tarfile
import operator
import math
import requests
import importlib

import tqdm
# `tqdm.notebook` is a submodule that `import tqdm` alone does not load; import
# it explicitly so `tqdm.notebook.trange` works on a fresh kernel.
import tqdm.notebook

import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

import sklearn
from sklearn import preprocessing
import scipy

import tensorflow as tf
import tensorflow_datasets as tfds

import jax
import jax.numpy as jnp
jax.config.update("jax_enable_x64", True)

import flax

import neural_tangents as nt

import vit_jax

In [3]:
# Helper functions for images.
# Human-readable class names, keyed by dataset name; the position of a name in
# its tuple is the integer class index.
# Source for both lists: https://www.cs.toronto.edu/~kriz/cifar.html
labelnames = {
    'cifar10': (
        'airplane', 'automobile', 'bird', 'cat', 'deer',
        'dog', 'frog', 'horse', 'ship', 'truck',
    ),
    'cifar100': (
        'apple', 'aquarium_fish', 'baby', 'bear', 'beaver',
        'bed', 'bee', 'beetle', 'bicycle', 'bottle',
        'bowl', 'boy', 'bridge', 'bus', 'butterfly',
        'camel', 'can', 'castle', 'caterpillar', 'cattle',
        'chair', 'chimpanzee', 'clock', 'cloud', 'cockroach',
        'couch', 'crab', 'crocodile', 'cup', 'dinosaur',
        'dolphin', 'elephant', 'flatfish', 'forest', 'fox',
        'girl', 'hamster', 'house', 'kangaroo', 'computer_keyboard',
        'lamp', 'lawn_mower', 'leopard', 'lion', 'lizard',
        'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain',
        'mouse', 'mushroom', 'oak_tree', 'orange', 'orchid',
        'otter', 'palm_tree', 'pear', 'pickup_truck', 'pine_tree',
        'plain', 'plate', 'poppy', 'porcupine', 'possum',
        'rabbit', 'raccoon', 'ray', 'road', 'rocket',
        'rose', 'sea', 'seal', 'shark', 'shrew',
        'skunk', 'skyscraper', 'snail', 'snake', 'spider',
        'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table',
        'tank', 'telephone', 'television', 'tiger', 'tractor',
        'train', 'trout', 'tulip', 'turtle', 'wardrobe',
        'whale', 'willow_tree', 'wolf', 'woman', 'worm',
    ),
}
def make_label_getter(dataset):
    """Builds a label-index -> display-name converter for `dataset`.

    Args:
        dataset: Dataset name used as a key into the module-level `labelnames`.

    Returns:
        A one-argument function. For a dataset present in `labelnames` it
        returns the name stored at position `label`; for any other dataset it
        returns the fallback string 'label=<label>'.
    """
    def getter(label):
        # The lookup happens on every call, so later changes to `labelnames`
        # are picked up.
        if dataset not in labelnames:
            return f'label={label}'
        return labelnames[dataset][label]
    return getter

def show_img(img, ax=None, title=None):
  """Draws one image on an axes with the tick marks removed.

  Args:
    img: Image to display; passed to `imshow` as `img[...]`.
    ax: Axes to draw on. When None, the current pyplot axes (`plt.gca()`) is
      used.
    title: Optional title; only set when truthy.
  """
  target = plt.gca() if ax is None else ax
  target.imshow(img[...])
  # Clear the x ticks first, then the y ticks.
  for clear_ticks in (target.set_xticks, target.set_yticks):
    clear_ticks([])
  if not title:
    return
  target.set_title(title)

def show_img_grid(imgs, titles):
  """Shows a square grid of images, one panel per (image, title) pair.

  The grid side is ceil(sqrt(len(imgs))). Images are mapped from [-1, 1] to
  [0, 1] before display (assumes the pipeline normalised them to [-1, 1] --
  TODO confirm against the input pipeline).

  Args:
    imgs: Sized sequence of images supporting `(img + 1) / 2`.
    titles: Iterable of titles, paired with `imgs` via `zip` (the shorter of
      the two decides how many panels are drawn).
  """
  # Nothing to draw; `plt.subplots(0, 0)` would raise.
  if len(imgs) == 0:
    return
  n = int(np.ceil(len(imgs)**.5))
  # squeeze=False always yields a 2-D axes array. Without it a 1x1 grid comes
  # back as a bare Axes object and the [row][col] indexing below fails.
  _, axs = plt.subplots(n, n, figsize=(3 * n, 3 * n), squeeze=False)
  used = 0
  for i, (img, title) in enumerate(zip(imgs, titles)):
    img = (img + 1) / 2  # Denormalize
    show_img(img, axs[i // n][i % n], title)
    used = i + 1
  # Blank out the panels that received no image so the figure shows no empty
  # frames.
  flat_axes = [ax for row in axs for ax in row]
  for ax in flat_axes[used:]:
    ax.axis('off')

dataset = 'cifar10'
# Global batch size. 256 is the value suggested for a single GPU; if you change
# it, the learning rate used for fine-tuning needs adjusting as well.
batch_size = 256

# Note the datasets are configured in input_pipeline.DATASET_PRESETS
num_classes = vit_jax.input_pipeline.get_dataset_info(dataset, 'train')['num_classes']
# tf.data.Dataset for training, infinite repeats.
ds_train = vit_jax.input_pipeline.get_data(
    dataset=dataset, mode='train', repeats=None, batch_size=batch_size,
    tfds_data_dir=".data/"
)
# tf.data.Dataset for evaluation, single repeat.
# repeats=1 (rather than None, which repeats forever) makes every fresh
# iterator cover the test split exactly once, so evaluation cannot silently
# wrap around and count examples twice.
ds_test = vit_jax.input_pipeline.get_data(
    dataset=dataset, mode='test', repeats=1, batch_size=batch_size,
    tfds_data_dir=".data/"
)

# Fetch a batch of test images for illustration purposes.
batch = next(iter(ds_test.as_numpy_iterator()))
# Note the shape : [num_local_devices, local_batch_size, h, w, c]
print(batch['image'].shape)
print(batch["label"].shape)

(1, 256, 32, 32, 3)
(1, 256, 10)


In [3]:
logger = vit_jax.logging.setup_logger('.logs/')

# Load model definition & initialize random parameters.
# The head size comes from the dataset metadata (`num_classes`) instead of a
# literal, so changing `dataset` keeps the model and the labels consistent.
VisionTransformer = vit_jax.models.KNOWN_MODELS["ViT-B_16"].partial(num_classes=num_classes)
_, params = VisionTransformer.init_by_shape(
    jax.random.PRNGKey(0),
    # Discard the "num_local_devices" dimension of the batch for initialization.
    # Deriving the shape from the sample batch keeps it in sync with the input
    # pipeline (image size and per-device batch size).
    [(batch['image'].shape[1:], 'float32')])

# Load and convert pretrained checkpoint.
# This involves loading the actual pre-trained model results, but then also
# modifying the parameters a bit, e.g. changing the final layers, and resizing
# the positional embeddings.
# For details, refer to the code and to the methods of the paper.
params = vit_jax.checkpoint.load_pretrained(
    pretrained_path='.models/ViT-B_16.npz',
    init_params=params,
    model_config=vit_jax.models.CONFIGS["ViT-B_16"],
    logger=logger
)

# So far, all our data is in the host memory. Let's now replicate the arrays
# into the devices.
# This will make every array in the pytree params become a ShardedDeviceArray
# that has the same data replicated across all local devices.
# For TPU it replicates the params in every core.
# For a single GPU this simply moves the data onto the device.
# For CPU it simply creates a copy.
params_repl = flax.jax_utils.replicate(params)
print('params.cls:', type(params['cls']).__name__, params['cls'].shape)
print('params_repl.cls:', type(params_repl['cls']).__name__, params_repl['cls'].shape)

# Then map the call to our model's forward pass onto all available devices.
vit_apply_repl = jax.pmap(VisionTransformer.call)

def get_accuracy(params_repl):
  """Evaluates top-1 accuracy of the replicated model on the test set.

  Runs `num_examples // batch_size` batches from `ds_test` through
  `vit_apply_repl` and compares the argmax of the predictions with the argmax
  of the labels along the last axis.

  Args:
    params_repl: Model parameters, passed as-is to `vit_apply_repl` together
      with each batch's images.

  Returns:
    Number of matching predictions divided by the number of predictions seen.
  """
  num_examples = vit_jax.input_pipeline.get_dataset_info(dataset, 'test')['num_examples']
  steps = num_examples // batch_size
  correct = 0
  seen = 0
  # The progress bar comes first in the zip, so iteration stops after `steps`
  # batches without pulling an extra batch from the dataset.
  progress = tqdm.notebook.trange(steps)
  for _, test_batch in zip(progress, ds_test.as_numpy_iterator()):
    logits = vit_apply_repl(params_repl, test_batch['image'])
    hits = logits.argmax(axis=-1) == test_batch['label'].argmax(axis=-1)
    correct += hits.sum()
    seen += hits.size
  return correct / seen

# Baseline accuracy before any fine-tuning. The saved output of this cell is
# ~0.10, i.e. chance level for 10 classes.
get_accuracy(params_repl)

{'pre_logits'}
2021-01-26 23:41:27,576 [INFO] vit_jax.logging: Inspect extra keys:
{'pre_logits/bias', 'pre_logits/kernel'}
2021-01-26 23:41:27,578 [INFO] vit_jax.logging: load_pretrained: drop-head variant
2021-01-26 23:41:27,580 [INFO] vit_jax.logging: load_pretrained: resized variant: (1, 197, 768) to (1, 5, 768)
2021-01-26 23:41:27,582 [INFO] vit_jax.logging: load_pretrained: grid-size from 14 to 2
2021-01-26 23:41:28,132 [INFO] absl: Load pre-computed DatasetInfo (eg: splits, num examples,...) from GCS: cifar10/3.0.2


params.cls: ndarray (1, 1, 768)
params_repl.cls: ShardedDeviceArray (1, 1, 1, 768)


2021-01-26 23:41:28,849 [INFO] absl: Load dataset info from /tmp/tmp3oeqy8b7tfds
2021-01-26 23:41:28,851 [INFO] absl: Field info.citation from disk and from code do not match. Keeping the one from code.
2021-01-26 23:41:28,852 [INFO] absl: Field info.splits from disk and from code do not match. Keeping the one from code.
2021-01-26 23:41:28,853 [INFO] absl: Field info.module_name from disk and from code do not match. Keeping the one from code.


  0%|          | 0/39 [00:00<?, ?it/s]

DeviceArray(0.09995994, dtype=float64)

In [4]:
# 100 Steps take approximately 15 minutes in the TPU runtime.
total_steps = 100
warmup_steps = 5
decay_type = 'cosine'
grad_norm_clip = 1
# This controls in how many forward passes the batch is split. 8 works well with
# a TPU runtime that has 8 devices. 64 should work on a GPU. You can of course
# also adjust the batch_size above, but that would require you to adjust the
# learning rate accordingly.
accum_steps = 64
base_lr = 0.03

# Build the replicated update function; see train.make_update_fn for details.
update_fn_repl = vit_jax.train.make_update_fn(VisionTransformer.call, accum_steps)
# We use a momentum optimizer that uses half precision for state to save
# memory. It also implements the gradient clipping.
opt = vit_jax.momentum_clip.Optimizer(grad_norm_clip=grad_norm_clip).create(params)
opt_repl = flax.jax_utils.replicate(opt)

lr_fn = vit_jax.hyper.create_learning_rate_schedule(total_steps, base_lr, decay_type, warmup_steps)
# Prefetch entire learning rate schedule onto devices. Otherwise we would have
# a slow transfer from host to devices in every step.
lr_iter = vit_jax.hyper.lr_prefetch_iter(lr_fn, 0, total_steps)
# Initialize PRNGs for dropout: one key per local device.
update_rngs = jax.random.split(jax.random.PRNGKey(0), jax.local_device_count())

# The world's simplest training loop.
# Completes in ~20 min on the TPU runtime.
# The zip stops after `total_steps` iterations because the progress-bar range is
# finite. Note that `batch` here overwrites the sample batch fetched in the
# data-loading cell.
for step, batch, lr_repl in zip(
    tqdm.notebook.trange(1, total_steps + 1),
    ds_train.as_numpy_iterator(),
    lr_iter
):
    opt_repl, loss_repl, update_rngs = update_fn_repl(
        opt_repl, lr_repl, batch, update_rngs)

# Should be ~97.2% for CIFAR10
# Should be ~71.2% for CIFAR100
# NOTE(review): the saved output of this cell shows ~0.44 on CIFAR10, far below
# the figure quoted above. The quoted numbers presumably refer to the upstream
# configuration (image resolution / batch size) -- confirm before relying on them.
get_accuracy(opt_repl.target)

  0%|          | 0/100 [00:00<?, ?it/s]

2021-01-27 00:05:36,010 [INFO] absl: Load pre-computed DatasetInfo (eg: splits, num examples,...) from GCS: cifar10/3.0.2
2021-01-27 00:05:37,518 [INFO] absl: Load dataset info from /tmp/tmp6fe13xkwtfds
2021-01-27 00:05:37,568 [INFO] absl: Field info.citation from disk and from code do not match. Keeping the one from code.
2021-01-27 00:05:37,569 [INFO] absl: Field info.splits from disk and from code do not match. Keeping the one from code.
2021-01-27 00:05:37,571 [INFO] absl: Field info.module_name from disk and from code do not match. Keeping the one from code.


  0%|          | 0/39 [00:00<?, ?it/s]

DeviceArray(0.44280849, dtype=float64)