# MaxViT Tutorial

**Note**: Please connect to a GPU runtime instance and make sure TensorFlow can be imported.


<table align="left"><td>
  <a target="_blank"  href="https://github.com/google-research/maxvit/blob/main/MaxViT_tutorial.ipynb">
    <img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on github
  </a>
</td><td>
  <a target="_blank"  href="https://colab.research.google.com/github/google-research/maxvit/blob/main/MaxViT_tutorial.ipynb">
    <img width=32px src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
</td></table>

## 0. Import libraries

In [None]:
#@title Import libs
import time
from IPython import display
import tensorflow as tf
import tensorflow.compat.v1 as tf1
import tensorflow_datasets as tfds
!git clone https://github.com/google-research/maxvit
%cd /content/maxvit
# set up module
!python setup.py install

# imports
import maxvit.models.hparams as hparams
import maxvit.models.maxvit as layers

# Checkpoint directories, keyed by '<model>_<train set tag>_<train size>'.
def _build_ckpts_dirs():
  """Returns the checkpoint lookup table in its canonical key order.

  Every (model family, train-set tag, resolution) combination maps to
  '<root>/<lowercased model>/<tag>/<resolution>', except the 224-pixel
  'i21k_i1k' entries, which map to None.
  """
  root = 'gs://gresearch/maxvit/ckpts'
  # Model family -> train-set tags listed for that family.
  families = (
      ('MaxViTTiny', ('i1k',)),
      ('MaxViTSmall', ('i1k',)),
      ('MaxViTBase', ('i1k', 'i21k_i1k')),
      ('MaxViTLarge', ('i1k', 'i21k_i1k')),
      ('MaxViTXLarge', ('i21k_i1k',)),
  )
  table = {}
  for model, tags in families:
    for tag in tags:
      for res in (224, 384, 512):
        # NOTE(review): presumably no 224 checkpoint exists for the
        # ImageNet-21K models; the table records None for them.
        has_no_ckpt = tag == 'i21k_i1k' and res == 224
        table[f'{model}_{tag}_{res}'] = (
            None if has_no_ckpt else f'{root}/{model.lower()}/{tag}/{res}')
  return table


CKPTS_DIRS = _build_ckpts_dirs()

# Human-readable training-set name -> tag used inside CKPTS_DIRS keys.
DATASET_MAP = dict([
    ('ImageNet-1K', 'i1k'),
    ('ImageNet-21K', 'i21k_i1k'),
])

## 1. Benchmark inference time

In [None]:
#@title Set model and params

# Colab form fields read as module-level globals by build_tf2_model() and
# run_tf_benchmark() below.
MODEL_NAME = "MaxViTTiny" #@param ["MaxViTTiny", "MaxViTSmall", "MaxViTBase", "MaxViTLarge"] {type:"string"}
IMAGE_SIZE = "224" #@param [224, 384, 512] {type:"string"}
BATCH_SIZE = 16 #@param {type:"integer"}
MIXED_PRECISION = True #@param {type:"boolean"}

# The form stores IMAGE_SIZE as a string; convert it once so the model builder
# and the benchmark can use it as an integer.
IMAGE_SIZE = int(IMAGE_SIZE)


class MaxViTModel(tf.keras.Model):
  """Keras functional-model wrapper around a MaxViT network."""

  def __init__(self,
               model_name='',
               model_input_size=224,
               input_specs=tf.keras.layers.InputSpec(
                   shape=[None, None, None, 3]),
               training=True):
    """Builds the MaxViT model named `model_name`.

    Args:
      model_name: Name passed to `hparams.lookup` to fetch the model config.
      model_input_size: Resolution used to pick the window/grid size and the
        scale ratio. Only 224, 384 and 512 override the looked-up config; any
        other value leaves it untouched.
      input_specs: Input spec whose shape, minus the leading (batch)
        dimension, defines the Keras input.
      training: Forwarded as the `training` argument when the network is
        called on the symbolic input.
    """
    # (input resolution, window and grid size, scale ratio)
    presets = (
        (224, 7, None),
        (384, 12, '384/224'),
        (512, 16, '512/224'),
    )

    image_input = tf.keras.Input(shape=input_specs.shape[1:])
    config = hparams.lookup(model_name)

    for resolution, partition_size, ratio in presets:
      if model_input_size == resolution:
        config.model.window_size = partition_size
        config.model.grid_size = partition_size
        config.model.scale_ratio = ratio
        break

    network = layers.MaxViT(config.model)
    outputs = network(image_input, training=training)

    super().__init__(inputs=image_input, outputs=outputs)


def build_tf2_model():
  """Builds the TF2 MaxViT model from the notebook's form settings.

  Reads the module-level MODEL_NAME, IMAGE_SIZE, BATCH_SIZE and
  MIXED_PRECISION values.

  Returns:
    A `MaxViTModel` built with `training=False`.
  """
  # Always set the global policy explicitly. The policy is process-wide state,
  # so only setting it when MIXED_PRECISION is True would leave
  # 'mixed_float16' active after the flag is switched off and the cell re-run.
  # Use 'mixed_float16' if running on GPUs.
  policy_name = 'mixed_float16' if MIXED_PRECISION else 'float32'
  policy = tf.keras.mixed_precision.Policy(policy_name)
  tf.keras.mixed_precision.set_global_policy(policy)

  model = MaxViTModel(model_name=MODEL_NAME,
                      model_input_size=IMAGE_SIZE,
                      input_specs=tf.keras.layers.InputSpec(
                          shape=[BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, 3]),
                      training=False)
  return model


def run_tf_benchmark(num_warmup=10, num_runs=10):
  """Times forward passes of the model on a constant batch.

  Reads the module-level BATCH_SIZE, IMAGE_SIZE and MIXED_PRECISION values and
  prints the mean per-batch latency and the resulting images per second.

  Args:
    num_warmup: Number of untimed calls made before timing starts.
    num_runs: Number of timed calls the mean is taken over; must be >= 1.

  Raises:
    ValueError: If `num_runs` is smaller than 1.
  """
  if num_runs < 1:
    raise ValueError(f'num_runs must be >= 1, got {num_runs}.')

  model = build_tf2_model()
  # Match the input dtype to the precision setting instead of always feeding
  # float16, so a float32 run is not fed half-precision inputs.
  input_dtype = tf.float16 if MIXED_PRECISION else tf.float32
  imgs = tf.ones((BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, 3), dtype=input_dtype)

  @tf.function
  def f(x):
    return model(x, training=False)

  print('starting warmup.')
  for _ in range(num_warmup):  # warmup runs.
    f(imgs)

  print('start benchmark.')
  start = time.perf_counter()
  for _ in range(num_runs):
    f(imgs)
  end = time.perf_counter()
  inference_time = (end - start) / num_runs

  print('Per batch inference time: ', inference_time)
  print('FPS: ', BATCH_SIZE / inference_time)

# Build the model and time it with the form settings chosen above.
run_tf_benchmark()

starting warmup.
start benchmark.
Per batch inference time:  0.1268322808000022
FPS:  126.15084976063699


## 2. Inference on images 

In [None]:
#@title Set model and params (ImageNet-1K models)

# Form fields selecting which ImageNet-1K checkpoint the next cell evaluates.
# NOTE(review): MIXED_PRECISION is set here but the inference cell in this
# section never reads it -- confirm whether it should reach MaxViTDriver.
MODEL_NAME = "MaxViTBase" #@param ["MaxViTTiny", "MaxViTSmall", "MaxViTBase", "MaxViTLarge"] {type:"string"}
TRAIN_SET = "ImageNet-1K" #@param ["ImageNet-1K"] {type:"string"}
TRAIN_IMAGE_SIZE = "224" #@param [224, 384, 512] {type:"string"}
MIXED_PRECISION = False #@param {type:"boolean"}

# The CKPTS_DIRS key is '<model>_<dataset tag>_<train size>', with the dataset
# tag taken from DATASET_MAP.
CKPT_DIR = CKPTS_DIRS[f'{MODEL_NAME}_{DATASET_MAP[TRAIN_SET]}_{TRAIN_IMAGE_SIZE}']

In [None]:
#@title Inference on example image
import maxvit.models.eval_ckpt as eval_ckpt

#@markdown ### Enter a file path:
file_path = "https://upload.wikimedia.org/wikipedia/commons/f/fe/Giant_Panda_in_Beijing_Zoo_1.JPG" #@param {type:"string"}
INFER_IMAGE_SIZE = "224" #@param [224, 384, 448, 512, 672, 768, 896, 1024] {type:"string"}

# Download label map file and image
labels_map_file = 'gs://cloud-tpu-checkpoints/efficientnet/eval_data/labels_map.json'
image_file = 'panda.jpg'

!wget {file_path} -O {image_file}

image_files = [image_file]

eval_driver = eval_ckpt.MaxViTDriver(
    model_name=MODEL_NAME,
    model_input_size=TRAIN_IMAGE_SIZE,
    batch_size=1,
    image_size=int(INFER_IMAGE_SIZE),
    include_background_label=False,
    advprop_preprocessing=False,)

print(f"Input image:")
display.display(display.Image(image_file, width=INFER_IMAGE_SIZE))

print(f"MaxViT prediction:")
pred_idx, pred_prob = eval_driver.eval_example_images(
    CKPT_DIR, image_files, labels_map_file)

--2022-10-13 05:18:01--  https://upload.wikimedia.org/wikipedia/commons/f/fe/Giant_Panda_in_Beijing_Zoo_1.JPG
Resolving upload.wikimedia.org (upload.wikimedia.org)... 208.80.154.240, 2620:0:861:ed1a::2:b
Connecting to upload.wikimedia.org (upload.wikimedia.org)|208.80.154.240|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 116068 (113K) [image/jpeg]
Saving to: ‘panda.jpg’


2022-10-13 05:18:01 (3.82 MB/s) - ‘panda.jpg’ saved [116068/116068]

Input image:


<IPython.core.display.Image object>

MaxViT prediction:


Instructions for updating:
This is a deprecated API that should only be used in TF 1 graph mode and legacy TF 2 graph mode available through `tf.compat.v1`. In all other situations -- namely, eager mode and inside `tf.function` -- you can consume dataset elements using `for elem in dataset: ...` or by explicitly creating iterator via `iterator = iter(dataset)` and fetching its elements via `values = next(iterator)`. Furthermore, this API is not available in TF 2. During the transition from TF 1 to TF 2 you can use `tf.compat.v1.data.make_one_shot_iterator(dataset)` to create a TF 1 graph mode style iterator for a dataset created through TF 2 APIs. Note that this should be a transient state of your code base as there are in general no guarantees about the interoperability of TF 1 and TF 2 code.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.


predicted class for image panda.jpg: 
  -> top_0 (90.62%): giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca  
  -> top_1 (0.21%): lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens  
  -> top_2 (0.07%): earthstar  
  -> top_3 (0.05%): soccer ball  
  -> top_4 (0.04%): sloth bear, Melursus ursinus, Ursus ursinus  


In [None]:
#@title Using ImageNet-21K pre-trained models

# Form fields selecting which ImageNet-21K checkpoint the next cell evaluates.
# Only 384 and 512 are offered as train sizes; the 224 entries for these
# models are None in CKPTS_DIRS.
# NOTE(review): MIXED_PRECISION is set here but the inference cell in this
# section never reads it -- confirm whether it should reach MaxViTDriver.
MODEL_NAME = "MaxViTBase" #@param ["MaxViTBase", "MaxViTLarge", "MaxViTXLarge"] {type:"string"}
TRAIN_SET = "ImageNet-21K" #@param ["ImageNet-21K"] {type:"string"}
TRAIN_IMAGE_SIZE = "384" #@param [384, 512] {type:"string"}
MIXED_PRECISION = False #@param {type:"boolean"}

# The CKPTS_DIRS key is '<model>_<dataset tag>_<train size>', with the dataset
# tag taken from DATASET_MAP.
CKPT_DIR = CKPTS_DIRS[f'{MODEL_NAME}_{DATASET_MAP[TRAIN_SET]}_{TRAIN_IMAGE_SIZE}']

In [None]:
#@title Inference on example image
import maxvit.models.eval_ckpt as eval_ckpt

#@markdown ### Enter a file path:
file_path = "https://upload.wikimedia.org/wikipedia/commons/f/fe/Giant_Panda_in_Beijing_Zoo_1.JPG" #@param {type:"string"}
INFER_IMAGE_SIZE = "384" #@param [224, 384, 448, 512, 672, 768, 896, 1024] {type:"string"}

# Download label map file and image
labels_map_file = 'gs://cloud-tpu-checkpoints/efficientnet/eval_data/labels_map.json'
image_file = 'panda.jpg'

!wget {file_path} -O {image_file}

image_files = [image_file]

eval_driver = eval_ckpt.MaxViTDriver(
    model_name=MODEL_NAME,
    model_input_size=TRAIN_IMAGE_SIZE,
    batch_size=1,
    image_size=int(INFER_IMAGE_SIZE),
    include_background_label=False,
    legacy_preprocessing=False,)

print(f"Input image:")
display.display(display.Image(image_file, width=INFER_IMAGE_SIZE))

print(f"MaxViT prediction:")
pred_idx, pred_prob = eval_driver.eval_example_images(
    CKPT_DIR, image_files, labels_map_file)

--2022-10-13 05:20:09--  https://upload.wikimedia.org/wikipedia/commons/f/fe/Giant_Panda_in_Beijing_Zoo_1.JPG
Resolving upload.wikimedia.org (upload.wikimedia.org)... 208.80.154.240, 2620:0:861:ed1a::2:b
Connecting to upload.wikimedia.org (upload.wikimedia.org)|208.80.154.240|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 116068 (113K) [image/jpeg]
Saving to: ‘panda.jpg’


2022-10-13 05:20:09 (3.78 MB/s) - ‘panda.jpg’ saved [116068/116068]

Input image:


<IPython.core.display.Image object>

MaxViT prediction:
predicted class for image panda.jpg: 
  -> top_0 (91.65%): giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca  
  -> top_1 (0.17%): lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens  
  -> top_2 (0.06%): brown bear, bruin, Ursus arctos  
  -> top_3 (0.05%): soccer ball  
  -> top_4 (0.05%): American black bear, black bear, Ursus americanus, Euarctos americanus  


## 3. Direct inference at higher resolution

Note: INFER_IMAGE_SIZE needs to be a multiple of TRAIN_IMAGE_SIZE.

In [None]:
#@title Set model and params

# Form fields selecting the checkpoint used by the higher-resolution inference
# cell below.
# NOTE(review): MIXED_PRECISION is set here but the inference cell in this
# section never reads it -- confirm whether it should reach MaxViTDriver.
MODEL_NAME = "MaxViTTiny" #@param ["MaxViTTiny", "MaxViTSmall", "MaxViTBase", "MaxViTLarge"] {type:"string"}
TRAIN_SET = "ImageNet-1K" #@param ["ImageNet-1K"] {type:"string"}
TRAIN_IMAGE_SIZE = "224" #@param [224, 384, 512] {type:"string"}
MIXED_PRECISION = False #@param {type:"boolean"}

# The CKPTS_DIRS key is '<model>_<dataset tag>_<train size>', with the dataset
# tag taken from DATASET_MAP.
CKPT_DIR = CKPTS_DIRS[f'{MODEL_NAME}_{DATASET_MAP[TRAIN_SET]}_{TRAIN_IMAGE_SIZE}']

In [None]:
#@title Inference on example image
import maxvit.models.eval_ckpt as eval_ckpt

#@markdown ### Enter a file path:
file_path = "https://upload.wikimedia.org/wikipedia/commons/f/fe/Giant_Panda_in_Beijing_Zoo_1.JPG" #@param {type:"string"}
INFER_IMAGE_SIZE = "672" #@param [224, 384, 448, 512, 672, 768, 896, 1024] {type:"string"}

# Download label map file and image
labels_map_file = 'gs://cloud-tpu-checkpoints/efficientnet/eval_data/labels_map.json'
image_file = 'panda.jpg'

!wget {file_path} -O {image_file}

image_files = [image_file]

eval_driver = eval_ckpt.MaxViTDriver(
    model_name=MODEL_NAME,
    model_input_size=TRAIN_IMAGE_SIZE,
    batch_size=1,
    image_size=int(INFER_IMAGE_SIZE),
    include_background_label=False,
    advprop_preprocessing=False,)

print(f"Input image:")
display.display(display.Image(image_file, width=INFER_IMAGE_SIZE))

print(f"MaxViT prediction:")
pred_idx, pred_prob = eval_driver.eval_example_images(
    CKPT_DIR, image_files, labels_map_file)

--2022-10-13 05:21:53--  https://upload.wikimedia.org/wikipedia/commons/f/fe/Giant_Panda_in_Beijing_Zoo_1.JPG
Resolving upload.wikimedia.org (upload.wikimedia.org)... 208.80.153.240, 2620:0:861:ed1a::2:b
Connecting to upload.wikimedia.org (upload.wikimedia.org)|208.80.153.240|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 116068 (113K) [image/jpeg]
Saving to: ‘panda.jpg’


2022-10-13 05:21:54 (1.06 MB/s) - ‘panda.jpg’ saved [116068/116068]

Input image:


<IPython.core.display.Image object>

MaxViT prediction:
predicted class for image panda.jpg: 
  -> top_0 (90.04%): giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca  
  -> top_1 (0.10%): lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens  
  -> top_2 (0.08%): Arctic fox, white fox, Alopex lagopus  
  -> top_3 (0.07%): American black bear, black bear, Ursus americanus, Euarctos americanus  
  -> top_4 (0.07%): sloth bear, Melursus ursinus, Ursus ursinus  


## 4. Inference on local uploads

Note: some image formats are not supported by TensorFlow IO.

In [None]:
#@title Set model and params

# Form fields selecting the checkpoint used by the uploaded-image inference
# cell below.
# NOTE(review): MIXED_PRECISION is set here but the inference cell in this
# section never reads it -- confirm whether it should reach MaxViTDriver.
MODEL_NAME = "MaxViTTiny" #@param ["MaxViTTiny", "MaxViTSmall", "MaxViTBase", "MaxViTLarge"] {type:"string"}
TRAIN_SET = "ImageNet-1K" #@param ["ImageNet-1K"] {type:"string"}
TRAIN_IMAGE_SIZE = "224" #@param [224, 384, 512] {type:"string"}
MIXED_PRECISION = False #@param {type:"boolean"}

# The CKPTS_DIRS key is '<model>_<dataset tag>_<train size>', with the dataset
# tag taken from DATASET_MAP.
CKPT_DIR = CKPTS_DIRS[f'{MODEL_NAME}_{DATASET_MAP[TRAIN_SET]}_{TRAIN_IMAGE_SIZE}']

In [None]:
#@title Inference on uploaded image

from google.colab import files
# Imported here so this cell does not rely on an earlier inference cell having
# been run first.
import maxvit.models.eval_ckpt as eval_ckpt

uploaded = files.upload()
if not uploaded:
  # Fail with a clear message instead of an IndexError when nothing was
  # uploaded.
  raise ValueError('No file was uploaded.')

INFER_IMAGE_SIZE = "224" #@param [224, 384, 448, 512, 672, 768, 896, 1024] {type:"string"}

# Defined locally for the same reason as the import above: the cell previously
# picked this name up from whichever inference cell had run before it.
labels_map_file = 'gs://cloud-tpu-checkpoints/efficientnet/eval_data/labels_map.json'

# Only the first uploaded file is used; its bytes are written to the working
# directory under the uploaded name.
uploaded_name, uploaded_bytes = next(iter(uploaded.items()))
save_path = './' + uploaded_name
with open(save_path, "wb") as f:
  f.write(uploaded_bytes)

image_files = [save_path]

# NOTE(review): model_input_size is passed as the raw form string while
# image_size is converted to int -- confirm against eval_ckpt.
eval_driver = eval_ckpt.MaxViTDriver(
    model_name=MODEL_NAME,
    model_input_size=TRAIN_IMAGE_SIZE,
    batch_size=1,
    image_size=int(INFER_IMAGE_SIZE),
    include_background_label=False,
    advprop_preprocessing=False,)

print("Input image:")
display.display(display.Image(save_path, width=INFER_IMAGE_SIZE))

print("MaxViT prediction:")
pred_idx, pred_prob = eval_driver.eval_example_images(
    CKPT_DIR, image_files, labels_map_file)

Saving red_panda.jpeg to red_panda.jpeg
Input image:


<IPython.core.display.Image object>

MaxViT prediction:
predicted class for image ./red_panda.jpeg: 
  -> top_0 (82.86%): lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens  
  -> top_1 (0.23%): polecat, fitch, foulmart, foumart, Mustela putorius  
  -> top_2 (0.08%): giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca  
  -> top_3 (0.06%): mink  
  -> top_4 (0.06%): badger  
