# 📦 Packages and Basic Setup
---

In [None]:
%%capture
# Install/upgrade the third-party packages imported below; %%capture hides pip's output.
!pip install -U rich flax
# jax-resnet is installed directly from its GitHub repository.
!pip install --upgrade git+https://github.com/n2cholas/jax-resnet.git

import os
import jax
import random
import numpy as np
from rich import print

import flax.linen as nn
import jax.numpy as jnp
from jax_resnet import pretrained_resnet, common

import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.python.ops.numpy_ops import np_config

from typing import Callable, Tuple, Any, List

# Experimental options
# tf.data pipeline options; they are attached to each per-view loader via `with_options`.
options = tf.data.Options()
options.experimental_optimization.noop_elimination = True
options.experimental_optimization.apply_default_optimizations = True
# Element order of the input pipeline is not required to be deterministic.
options.experimental_deterministic = False
options.threading.max_intra_op_parallelism = 1
# Enable NumPy-style behaviour on TensorFlow tensors.
np_config.enable_numpy_behavior()

# Lets tf.data choose parallelism / prefetch sizes (passed to `map` and `prefetch`).
AUTOTUNE = tf.data.experimental.AUTOTUNE
# In the visible cells only `strategy.num_replicas_in_sync` is read, to size the global batch.
strategy = tf.distribute.MirroredStrategy()

In [None]:
# @title ⚙ Configuration
# Notebook-wide hyperparameters; the `# @param` annotations are Colab form-field markers.
# Default seed used by `seed_everything`.
GLOBAL_SEED = 42  # @param {type: "number"}
# Number of augmented views; one data loader is built per view and it indexes
# the per-view GAUSSIAN_P / SOLARIZE_P lists.
NUM_VIEWS = 2  # @param {type: "number"}
NUM_TRAINING_EPOCHS = 10  # @param {type: "number"}
NUM_EVAL_EPOCHS = 100  # @param {type: "number"}
# Per-replica batch size: multiplied by the number of replicas when the loader is batched.
TRAIN_BATCH_SIZE = 32  # @param {type: "number"}
EVAL_BATCH_SIZE = 256  # @param {type: "number"}
# Presumably the width of the expander MLP layers — TODO confirm against its usage.
MLP_UNITS = 8192  # @param {type: "number"}
# Presumably the weights of the invariance / variance / covariance loss terms;
# the loss is not defined in the visible cells — TODO confirm.
INVAR_COEFF = 25.0  # @param {type: "number"}
VAR_COEFF = 25.0  # @param {type: "number"}
COV_COEFF = 1.0  # @param {type: "number"}
DECAY_STEPS = 1000  # @param {type: "number"}
WEIGHT_DECAY = 1e-6  # @param {type: "number"}
BASE_LR = 0.2  # @param {type: "number"}
EVAL_LR = 0.02  # @param {type: "number"}


# ============ Random Seed ============
def seed_everything(seed=GLOBAL_SEED):
    """Seed the Python, NumPy and TensorFlow random generators and set determinism flags.

    Args:
        seed: Integer seed handed to every generator; defaults to ``GLOBAL_SEED``.
    """
    seeders = (
        random.seed,
        np.random.seed,
        tf.random.set_seed,
        tf.experimental.numpy.random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    os.environ.update({
        # Two extra switches that must be set when running on the CuDNN backend
        "TF_CUDNN_DETERMINISTIC": "1",
        "TF_DETERMINISTIC_OPS": "1",
        # Fixed value for the hash seed
        "PYTHONHASHSEED": str(seed),
    })
    print(f"Random seed set as {seed}")


seed_everything()

## Accelerator Configuration

In [None]:
# Reference: https://www.kaggle.com/code/odins0n/jax-flax-tf-data-vision-transformers-tutorial

# Pick the JAX backend from environment variables: a TPU addressed through
# TPU_NAME, a Colab TPU (COLAB_TPU_ADDR), or neither.
if 'TPU_NAME' in os.environ:
    import requests
    # TPU_DRIVER_MODE is a marker global: the version request below is sent only
    # the first time this cell runs in a kernel session.
    if 'TPU_DRIVER_MODE' not in globals():
        # Uses the second ':'-separated piece of TPU_NAME as the host
        # (assumes a value shaped like scheme://host:port — TODO confirm).
        url = 'http:' + os.environ['TPU_NAME'].split(':')[1] + ':8475/requestversion/tpu_driver_nightly'
        # The response itself is not inspected.
        resp = requests.post(url)
        TPU_DRIVER_MODE = 1
    from jax.config import config as jax_config
    # Point JAX at the TPU driver backend for the TPU named in the environment.
    jax_config.FLAGS.jax_xla_backend = "tpu_driver"
    jax_config.FLAGS.jax_backend_target = os.environ['TPU_NAME']
    print("TPU DETECTED!")
    print('Registered TPU:', jax_config.FLAGS.jax_backend_target)
elif "COLAB_TPU_ADDR" in os.environ:
    import jax.tools.colab_tpu
    jax.tools.colab_tpu.setup_tpu()
else:
    print('No TPU detected.')

# Number of devices JAX sees on this host.
DEVICE_COUNT = len(jax.local_devices())
# Exactly eight local devices is taken to mean a TPU is attached.
TPU = DEVICE_COUNT==8

if TPU:
    print("8 cores of TPU ( Local devices in Jax ):")
    print('\n'.join(map(str,jax.local_devices())))

# 🆘 Utility Classes and Functions
---

## 🖖 Utilities for Data Augmentation


In [None]:
# Per-view augmentation probabilities, indexed by view number when the loaders are built:
# view 0 -> blur p=1.0, solarize p=0.0; view 1 -> blur p=0.1, solarize p=0.2.
GAUSSIAN_P = [1.0, 0.1]
SOLARIZE_P = [0.0, 0.2]

def shuffle_zipped_output(a: Any, b: Any) -> Tuple[Any, Any]:
    """Return the two inputs in a random order (swapped with probability 0.5).

    The coin flip is a TensorFlow op rather than Python's ``random`` module.
    ``tf.data.Dataset.map`` traces its function once into a graph, so a
    Python-side ``random.shuffle`` would run only at trace time and freeze a
    single order for every batch; the in-graph draw is re-evaluated per element.

    Args:
        a: First element (e.g. the batch produced by the first view).
        b: Second element; it must have the same nested structure and dtypes
            as ``a`` so that both ``tf.cond`` branches are compatible.

    Returns:
        Either ``(a, b)`` or ``(b, a)``.
    """
    swap = tf.less(tf.random.uniform([], minval=0, maxval=1, dtype=tf.float32), 0.5)
    return tf.cond(swap, lambda: (b, a), lambda: (a, b))

@tf.function
def scale_image(image: tf.Tensor, label: tf.Tensor) -> Tuple[tf.Tensor]:
    """Convert the image to float32 with ``tf.image.convert_image_dtype``.

    The label is passed through unchanged so the function can be used on
    ``(image, label)`` pairs.
    """
    return tf.image.convert_image_dtype(image, tf.float32), label

@tf.function
def gaussian_blur(image: tf.Tensor, kernel_size:int=23, padding: str='SAME') -> tf.Tensor:
    """
    Blur the input with a separable Gaussian filter of random width.

    The standard deviation is drawn uniformly from [0.1, 2.0) on every call.
    The 1-D kernel is applied as a horizontal depthwise convolution followed by
    a vertical one. A rank-3 input gets a temporary leading batch axis.

    Args:
        image: Rank-3 or batched image tensor with channels last.
        kernel_size: Requested kernel width; the effective width is
            ``2 * (kernel_size // 2) + 1``.
        padding: Padding mode forwarded to ``tf.nn.depthwise_conv2d``.

    Returns:
        The blurred image, with the same rank as the input.

    Reference: https://github.com/google-research/simclr/blob/master/data_util.py
    """
    # Draw sigma first so the random op stays the first op of the function.
    sigma = tf.random.uniform((1,)) * 1.9 + 0.1
    half_width = tf.cast(kernel_size / 2, tf.int32)
    taps = half_width * 2 + 1

    # Normalised 1-D Gaussian weights sampled at integer offsets.
    offsets = tf.cast(tf.range(-half_width, half_width + 1), tf.float32)
    variance = tf.pow(tf.cast(sigma, tf.float32), 2.0)
    weights = tf.exp(-tf.pow(offsets, 2.0) / (2.0 * variance))
    weights = weights / tf.reduce_sum(weights)

    # Build one horizontal and one vertical depthwise filter, replicated per channel.
    channels = tf.shape(image)[-1]
    per_channel = [1, 1, channels, 1]
    horizontal = tf.tile(tf.reshape(weights, [1, taps, 1, 1]), per_channel)
    vertical = tf.tile(tf.reshape(weights, [taps, 1, 1, 1]), per_channel)

    unit_strides = [1, 1, 1, 1]
    needs_batch_axis = image.shape.ndims == 3
    batched = tf.expand_dims(image, axis=0) if needs_batch_axis else image
    blurred = tf.nn.depthwise_conv2d(
        batched, horizontal, strides=unit_strides, padding=padding)
    blurred = tf.nn.depthwise_conv2d(
        blurred, vertical, strides=unit_strides, padding=padding)
    return tf.squeeze(blurred, axis=0) if needs_batch_axis else blurred

@tf.function
def color_jitter(image: tf.Tensor, s: float = 0.5) -> tf.Tensor:
    """Apply random brightness, contrast, saturation and hue changes, then clip to [0, 1].

    Args:
        image: Image tensor to distort.
        s: Strength multiplier that scales the range of every distortion.

    Returns:
        The distorted image with values clipped to [0, 1].
    """
    jittered = tf.image.random_brightness(image, max_delta=0.8 * s)
    jittered = tf.image.random_contrast(
        jittered, lower=1 - 0.8 * s, upper=1 + 0.8 * s)
    jittered = tf.image.random_saturation(
        jittered, lower=1 - 0.8 * s, upper=1 + 0.8 * s)
    jittered = tf.image.random_hue(jittered, max_delta=0.2 * s)
    return tf.clip_by_value(jittered, 0, 1)

@tf.function
def solarize(image: tf.Tensor, threshold: int = 128) -> tf.Tensor:
    """Solarize the input image: invert every pixel at or above ``threshold``.

    ``threshold`` is always given on the 0-255 scale. Floating-point images are
    taken to be in [0, 1] (this notebook converts images with
    ``tf.image.convert_image_dtype`` before augmenting), so for them the
    threshold is rescaled to ``threshold / 255`` and pixels are inverted around
    1.0. Comparing such images against 128 directly would never trigger the
    inversion. Integer images keep the 0-255 behaviour.

    Args:
        image: Image tensor, either floating point in [0, 1] or integer in [0, 255].
        threshold: Cut-off on the 0-255 scale.

    Returns:
        Tensor of the same shape and dtype with the bright pixels inverted.
    """
    # dtype is static, so this branch is resolved once at trace time.
    if image.dtype.is_floating:
        max_value = 1.0
        cutoff = tf.cast(threshold, image.dtype) / 255.0
    else:
        max_value = 255
        cutoff = threshold
    return tf.where(image < cutoff, image, max_value - image)

@tf.function
def color_drop(image: tf.Tensor) -> tf.Tensor:
    """Convert an RGB image to grayscale and replicate it over three channels."""
    gray = tf.image.rgb_to_grayscale(image)
    # Repeat the single gray channel three times along the last axis.
    return tf.tile(gray, [1, 1, 3])

@tf.function
def random_apply(func: Callable, x: tf.Tensor, p: float) -> tf.Tensor:
    """Return ``func(x)`` with probability ``p``, otherwise ``x`` unchanged.

    Args:
        func: Transformation taking and returning a tensor.
        x: Input tensor.
        p: Probability of applying ``func``.
    """
    draw = tf.random.uniform([], minval=0, maxval=1, dtype=tf.float32)
    should_apply = tf.less(draw, tf.cast(p, tf.float32))
    return tf.cond(should_apply, lambda: func(x), lambda: x)

@tf.function
def custom_augment_train(image: tf.Tensor, label: tf.Tensor, gaussian_p: float = 0.1, solarize_p: float = 0.0) -> Tuple[tf.Tensor]:
    """Run the image through the chain of randomly applied training augmentations.

    Args:
        image: Image tensor to augment.
        label: Label, returned unchanged.
        gaussian_p: Probability of applying the Gaussian blur.
        solarize_p: Probability of applying solarization.

    Returns:
        The ``(augmented_image, label)`` pair.
    """
    # (transformation, probability) pairs, applied in this order.
    pipeline = (
        (tf.image.flip_left_right, 0.5),  # horizontal flip
        (color_jitter, 0.8),              # color distortions
        (color_drop, 0.2),                # grayscale
        (gaussian_blur, gaussian_p),      # gaussian blur
        (solarize, solarize_p),           # solarization
    )
    for transform, probability in pipeline:
        image = random_apply(transform, image, p=probability)

    return (image, label)

@tf.function
def custom_augment_eval(image: tf.Tensor, label: tf.Tensor, crop_size:int = 224) -> Tuple[tf.Tensor]:
    """Resize the image to 260x260, then take a random ``crop_size`` square crop.

    Args:
        image: Three-channel image tensor.
        label: Label, returned unchanged.
        crop_size: Side length of the crop and of the returned image.

    Returns:
        The ``(cropped_image, label)`` pair.
    """
    resize_side = 260
    enlarged = tf.image.resize(image, (resize_side, resize_side))
    # Random square patch, resized to the target size once more.
    patch = tf.image.random_crop(enlarged, (crop_size, crop_size, 3))
    return tf.image.resize(patch, (crop_size, crop_size)), label

@tf.function
def train_augmentations(image: tf.Tensor, label: tf.Tensor, gaussian_p: float = 0.1, solarize_p: float = 0.0, crop_size:int = 224) -> Tuple[tf.Tensor]:
    """Scale, resize, randomly crop and colour-augment one training image.

    Args:
        image: Three-channel input image.
        label: Label, returned unchanged.
        gaussian_p: Probability of applying the Gaussian blur.
        solarize_p: Probability of applying solarization.
        crop_size: Side length of the random crop and of the returned image.

    Returns:
        The ``(augmented_image, label)`` pair.
    """
    # scale the pixel values
    image, label = scale_image(image, label)
    # image resizing
    image_shape = 260
    image = tf.image.resize(image, (image_shape, image_shape))
    # get the crop from the image
    crop = tf.image.random_crop(image, (crop_size, crop_size, 3))
    crop_resize = tf.image.resize(crop, (crop_size, crop_size))
    # color distortions; both per-view probabilities are forwarded
    # (solarize_p used to be dropped here, leaving solarization at its 0.0 default)
    distorted_image, label = custom_augment_train(
        crop_resize, label, gaussian_p=gaussian_p, solarize_p=solarize_p)
    return distorted_image, label

@tf.function
def eval_augmentations(image: tf.Tensor, label: tf.Tensor) -> Tuple[tf.Tensor]:
    """Scale, randomly flip and randomly crop one image for evaluation.

    Args:
        image: Three-channel input image.
        label: Label, returned unchanged.

    Returns:
        The ``(augmented_image, label)`` pair.
    """
    # Scale the pixel values
    image, label = scale_image(image, label)
    # Horizontal flip with probability 0.5. The deterministic flip is used
    # because random_apply already supplies the coin toss; wrapping
    # tf.image.random_flip_left_right would flip only 25% of the images.
    image = random_apply(tf.image.flip_left_right, image, p=0.5)
    # Random resized crops
    image, label = custom_augment_eval(image, label)

    return image, label

# 💿 The Dataset
---

For the purposes of this example, we use the TF Flowers dataset.

In [None]:
%%time
# Turn off the tensorflow_datasets progress bars.
tfds.disable_progress_bar()

# Gather Flowers dataset
# The first 85% of the "train" split is used for training and the remaining 15%
# for validation; as_supervised=True makes each element an (image, label) pair.
train_ds, validation_ds = tfds.load(
    "tf_flowers",
    split=["train[:85%]", "train[85%:]"],
    as_supervised=True
)

## 🖖 Data Augmentation Pipeline


In [None]:
# We create a Tuple because we have two loaders corresponding to each view
trainloaders = tuple()

for i in range(NUM_VIEWS):
  trainloader = (
      train_ds
      # The per-view loaders are zipped together afterwards, so every view must
      # visit the images in the same order: an unseeded shuffle would pair
      # augmentations of *different* images. A shared seed keeps them aligned.
      .shuffle(1024, seed=GLOBAL_SEED)
      # deterministic=True preserves element order through the parallel map
      # (it takes precedence over the non-deterministic pipeline options),
      # which is also required to keep the zipped views aligned.
      .map(
          lambda x, y: train_augmentations(x, y, GAUSSIAN_P[i], SOLARIZE_P[i]),
          num_parallel_calls=AUTOTUNE,
          deterministic=True,
      )
  )
  trainloader = trainloader.with_options(options)
  trainloaders+=(trainloader,)

## ⚙️ Dataloader


In [None]:
# Final trainloader used for training: zip the per-view loaders, batch them,
# pass each batch pair through shuffle_zipped_output and prefetch.
trainloader = (
    tf.data.Dataset.zip(trainloaders)
    # global batch size = per-replica batch size x number of replicas
    .batch(TRAIN_BATCH_SIZE * strategy.num_replicas_in_sync)
    .map(shuffle_zipped_output, num_parallel_calls=AUTOTUNE)
    .prefetch(AUTOTUNE)
)

# ✍️ Model Architecture & Training
---

## 🏠 Building the network

![](https://camo.githubusercontent.com/9cffe6a81978d546ca3c54c02e634d432c1be29ace3b2560d3f4a19710aa6654/68747470733a2f2f6769746875622e636f6d2f66616365626f6f6b72657365617263682f7669637265672f626c6f622f6d61696e2f2e6769746875622f7669637265675f61726368695f66756c6c2e6a70673f7261773d74727565)

In [None]:
def get_encoder():
  """Build the backbone from a pretrained ResNet-50 with its last layer removed.

  Returns:
    A ``(backbone, backbone_variables)`` pair: an ``nn.Sequential`` made of all
    but the final layer of the pretrained model, and the pretrained variables
    sliced to match.
  """
  resnet_cls, resnet_variables = pretrained_resnet(size=50)
  # Drop the final layer (presumably the classification head) and its variables.
  trunk_layers = resnet_cls().layers[:-1]
  trunk_variables = common.slice_variables(resnet_variables, end = -1)

  return nn.Sequential(trunk_layers), trunk_variables

In [None]:
%%time
# Smoke test: build the encoder and push a dummy batch of 32 images through it.
encoder, encoder_variables = get_encoder()
encoder_output = encoder.apply(encoder_variables,
                  jnp.ones((32, 224, 224, 3)),  # ImageNet sized inputs.
                  mutable=False)
# One 2048-dimensional feature vector is expected per input image.
assert encoder_output.shape == (32, 2048)
# Remove the temporary objects from the kernel namespace.
del encoder, encoder_variables, encoder_output

In [None]:
class Expander(nn.Module):
  """Three-layer MLP: two Dense -> BatchNorm -> ReLU blocks followed by a final Dense.

  Every Dense layer has ``num_units`` output features.
  """
  # Output width of each Dense layer.
  num_units: int

  @nn.compact
  def __call__(self, inputs, train: bool):
    """Project ``inputs`` to ``num_units`` features.

    Args:
      inputs: Input array whose last axis holds the features.
      train: When True, BatchNorm normalises with batch statistics; when False
        it uses the stored running averages.

    Returns:
      Array with ``num_units`` features on the last axis.
    """
    projection_1 = nn.Dense(features = self.num_units)(inputs)
    projection_1 = nn.BatchNorm(use_running_average=not train)(projection_1)
    projection_1 = nn.relu(projection_1)

    projection_2 = nn.Dense(features = self.num_units)(projection_1)
    projection_2 = nn.BatchNorm(use_running_average=not train)(projection_2)
    projection_2 = nn.relu(projection_2)

    return nn.Dense(features = self.num_units)(projection_2)

# Take the width from the configuration cell instead of repeating the literal,
# so that changing MLP_UNITS actually changes the model.
expander = Expander(num_units = MLP_UNITS)

In [None]:
%%time
variables = expander.init(jax.random.PRNGKey(0), jnp.ones((2048, 8192)), train=False)
params, batch_stats = variables['params'], variables['batch_stats']
expander_output , _ = expander.apply(
  {'params': params, 'batch_stats': batch_stats},
  jnp.ones((2048, 8192)),
  train=True, mutable=['batch_stats']
)
assert expander_output.shape == (2048, 8192)

del expander, variables, params, batch_stats, expander_output

In [None]:
class VICReg(nn.Module):
  """Container module holding the two sub-networks: an encoder and an expander.

  Only the fields are declared in the visible code; no forward pass
  (``__call__`` / ``setup``) is defined here.
  """
  # Feature-extraction backbone, typed as an nn.Sequential.
  encoder: nn.Sequential
  # Module applied on top of the encoder features — presumably an Expander; TODO confirm.
  expander: nn.Module

## 🎬 Initializing the Module


## 🏋️‍♂️ Train Step
