# 📦 Packages and Basic Setup
---

In [None]:
%%capture
!pip install -U rich tf-models-official

import random
import os
import numpy as np
from rich import print
import tensorflow_models as tfm
import tensorflow as tf
import tensorflow_datasets as tfds
from rich.progress import track
from tensorflow.python.ops.numpy_ops import np_config

from typing import Callable, Tuple, Any, List

# Shared tf.data options, attached to each pre-training loader via `with_options`
options = tf.data.Options()
options.experimental_deterministic = False
options.threading.max_intra_op_parallelism = 1
options.experimental_optimization.apply_default_optimizations = True
options.experimental_optimization.noop_elimination = True

# Enable NumPy-style behaviour on TensorFlow tensors (the model code uses `.T`)
np_config.enable_numpy_behavior()

# Distribution strategy and the tf.data autotuning sentinel used by every loader
strategy = tf.distribute.MirroredStrategy()
AUTOTUNE = tf.data.experimental.AUTOTUNE

In [None]:
# @title ⚙ Configuration
# Seed passed to `seed_everything` and number of augmented views (one loader per view)
GLOBAL_SEED = 42  # @param {type: "number"}
NUM_VIEWS = 2  # @param {type: "number"}
# Epoch counts for VICReg pre-training and for the linear-evaluation fit
NUM_TRAINING_EPOCHS = 10  # @param {type: "number"}
NUM_EVAL_EPOCHS = 100  # @param {type: "number"}
# Batch sizes; the training one is multiplied by the number of replicas when batching
TRAIN_BATCH_SIZE = 32  # @param {type: "number"}
EVAL_BATCH_SIZE = 256  # @param {type: "number"}
# Width of every Dense layer in the expander
MLP_UNITS = 8192  # @param {type: "number"}
# Weights of the invariance, variance and covariance terms in the VICReg loss
INVAR_COEFF = 25.0  # @param {type: "number"}
VAR_COEFF = 25.0  # @param {type: "number"}
COV_COEFF = 1.0  # @param {type: "number"}
# Cosine-decay horizon and weight decay, shared by pre-training and linear evaluation
DECAY_STEPS = 1000  # @param {type: "number"}
WEIGHT_DECAY = 1e-6  # @param {type: "number"}
# Initial learning rates for pre-training (LARS) and linear evaluation (SGD)
BASE_LR = 0.2  # @param {type: "number"}
EVAL_LR = 0.02  # @param {type: "number"}


# ============ Random Seed ============
def seed_everything(seed=GLOBAL_SEED):
    """Seed the Python, NumPy and TensorFlow RNGs and export the determinism env vars.

    Args:
        seed: Integer seed applied to every generator; defaults to GLOBAL_SEED.
    """
    # Seed each random number generator used in this file
    for seeder in (
        random.seed,
        np.random.seed,
        tf.random.set_seed,
        tf.experimental.numpy.random.seed,
    ):
        seeder(seed)

    # Two further options must be set when running on the CuDNN backend
    for flag in ("TF_CUDNN_DETERMINISTIC", "TF_DETERMINISTIC_OPS"):
        os.environ[flag] = "1"
    # Fixed value for the hash seed
    os.environ["PYTHONHASHSEED"] = str(seed)

    print(f"Random seed set as {seed}")


seed_everything()

# 🆘 Utility Classes and Functions
---

In [None]:
def off_diagonal(x: tf.Tensor) -> tf.Tensor:
    """Return the off-diagonal entries of a square matrix as a flat tensor.

    Args:
        x: 2-D tensor whose two static dimensions must be equal.

    Returns:
        1-D tensor holding every entry of `x` except the main diagonal,
        in row-major order.

    Raises:
        AssertionError: if the two dimensions of `x` differ.
    """
    n, m = x.shape[0], x.shape[1]
    assert n == m, f"Not a square tensor, dimensions found: {n} and {m}"

    # Drop the final (diagonal) element, then view the remaining n*n - 1 values
    # as (n - 1) rows of (n + 1): every remaining diagonal entry now sits in
    # column 0 and can be sliced away in one go.
    all_but_last = tf.reshape(x, [-1])[:-1]
    shifted_rows = tf.reshape(all_but_last, [n - 1, n + 1])
    return tf.reshape(shifted_rows[:, 1:], [-1])

## 🖖 Utilities for Data Augmentation

In [None]:
# Per-view augmentation probabilities, indexed by view number: entry i is
# passed to `train_augmentations` when building the loader for view i.
GAUSSIAN_P = [1.0, 0.1]  # probability of applying Gaussian blur
SOLARIZE_P = [0.0, 0.2]  # probability of applying solarization

def shuffle_zipped_output(a: Any, b: Any) -> Tuple[Any, Any]:
    """Randomly swap the order of the two zipped views.

    The decision is made with TensorFlow ops so that it is re-drawn for every
    element when this function is used inside `tf.data.Dataset.map`. Python's
    `random` module would run only once, while the map function is traced,
    freezing a single ordering for the whole dataset.

    Args:
        a: First view (any nested structure of tensors).
        b: Second view, with the same structure, dtypes and shapes as `a`.

    Returns:
        Either `(a, b)` or `(b, a)`, each with probability 0.5.
    """
    swap = tf.less(
        tf.random.uniform([], minval=0, maxval=1, dtype=tf.float32), 0.5
    )
    return tf.cond(swap, lambda: (b, a), lambda: (a, b))

@tf.function
def scale_image(image: tf.Tensor, label: tf.Tensor) -> Tuple[tf.Tensor]:
    """Cast the image to float32 via `tf.image.convert_image_dtype`; the label passes through."""
    as_float = tf.image.convert_image_dtype(image, dtype=tf.float32)
    return as_float, label

@tf.function
def gaussian_blur(image: tf.Tensor, kernel_size:int=23, padding: str='SAME') -> tf.Tensor:
    """
    Blur the input image with a separable Gaussian filter of random strength.

    The standard deviation is sampled uniformly from [0.1, 2.0) on every call.
    A rank-3 image gets a temporary batch dimension for the convolutions and
    is returned as rank 3 again.

    Reference: https://github.com/google-research/simclr/blob/master/data_util.py
    """

    sigma = tf.random.uniform((1,))* 1.9 + 0.1

    # Force an odd filter width so the kernel is centred on a pixel
    half_width = tf.cast(kernel_size / 2, tf.int32)
    width = half_width * 2 + 1

    # Normalised 1-D Gaussian weights
    offsets = tf.cast(tf.range(-half_width, half_width + 1), tf.float32)
    weights = tf.exp(
        -tf.pow(offsets, 2.0) / (2.0 * tf.pow(tf.cast(sigma, tf.float32), 2.0)))
    weights = weights / tf.reduce_sum(weights)

    # The same weights laid out as a column and as a row filter,
    # repeated once per input channel for the depthwise convolutions.
    column_filter = tf.reshape(weights, [width, 1, 1, 1])
    row_filter = tf.reshape(weights, [1, width, 1, 1])
    channels = tf.shape(image)[-1]
    row_filter = tf.tile(row_filter, [1, 1, channels, 1])
    column_filter = tf.tile(column_filter, [1, 1, channels, 1])

    needs_batch_dim = image.shape.ndims == 3
    if needs_batch_dim:
        image = tf.expand_dims(image, axis=0)

    unit_strides = [1, 1, 1, 1]
    result = tf.nn.depthwise_conv2d(
        image, row_filter, strides=unit_strides, padding=padding)
    result = tf.nn.depthwise_conv2d(
        result, column_filter, strides=unit_strides, padding=padding)

    if needs_batch_dim:
        result = tf.squeeze(result, axis=0)
    return result

@tf.function
def color_jitter(image: tf.Tensor, s: float = 0.5) -> tf.Tensor:
    """Jitter brightness, contrast, saturation and hue, then clip to [0, 1].

    Args:
        image: Input image tensor.
        s: Strength multiplier applied to every jitter range.
    """
    spread = 0.8*s
    lower_bound, upper_bound = 1-spread, 1+spread

    jittered = tf.image.random_brightness(image, max_delta=spread)
    jittered = tf.image.random_contrast(jittered, lower=lower_bound, upper=upper_bound)
    jittered = tf.image.random_saturation(jittered, lower=lower_bound, upper=upper_bound)
    jittered = tf.image.random_hue(jittered, max_delta=0.2*s)
    return tf.clip_by_value(jittered, 0, 1)

@tf.function
def solarize(image: tf.Tensor, threshold: int = 128) -> tf.Tensor:
    """Solarize the input image by inverting every pixel at or above a threshold.

    `threshold` is always given on the 0-255 scale. Floating-point images are
    treated as lying in [0, 1] (the range produced by `scale_image` and kept by
    `color_jitter`), so the threshold is rescaled and pixels are inverted
    around 1.0. Previously the 0-255 constants were applied directly to such
    images, which left them unchanged.

    Args:
        image: Image tensor; integer images on the 0-255 scale or float images
            in [0, 1].
        threshold: Inversion threshold on the 0-255 scale.

    Returns:
        Tensor with the same shape and dtype as `image`.
    """
    if image.dtype.is_floating:
        cutoff = tf.cast(threshold, image.dtype) / 255.0
        return tf.where(image < cutoff, image, 1.0 - image)
    # Integer images keep the original 0-255 behaviour
    return tf.where(image < threshold, image, 255 - image)

@tf.function
def color_drop(image: tf.Tensor) -> tf.Tensor:
    """Convert an RGB image to grayscale, repeated over three channels."""
    gray = tf.image.rgb_to_grayscale(image)
    # Tile the single gray channel so the output keeps three channels
    return tf.tile(gray, [1, 1, 3])

@tf.function
def random_apply(func: Callable, x: tf.Tensor, p: float) -> tf.Tensor:
    """Return `func(x)` with probability `p`, otherwise `x` unchanged."""
    draw = tf.random.uniform([], minval=0, maxval=1, dtype=tf.float32)
    should_apply = tf.less(draw, tf.cast(p, tf.float32))
    return tf.cond(should_apply, lambda: func(x), lambda: x)

@tf.function
def custom_augment_train(image: tf.Tensor, label: tf.Tensor, gaussian_p: float = 0.1, solarize_p: float = 0.0) -> Tuple[tf.Tensor]:
    """Apply the chain of random training augmentations to one image.

    Args:
        image: Image to augment.
        label: Label, returned untouched.
        gaussian_p: Probability of applying Gaussian blur.
        solarize_p: Probability of applying solarization.

    Returns:
        The augmented image and the original label.
    """
    # (transform, probability) pairs, applied in this order
    steps = (
        (tf.image.flip_left_right, 0.5),  # horizontal flip
        (color_jitter, 0.8),              # color distortions
        (color_drop, 0.2),                # grayscale
        (gaussian_blur, gaussian_p),      # Gaussian blur
        (solarize, solarize_p),           # solarization
    )
    for transform, probability in steps:
        image = random_apply(transform, image, p=probability)

    return (image, label)

@tf.function
def custom_augment_eval(image: tf.Tensor, label: tf.Tensor, crop_size:int = 224) -> Tuple[tf.Tensor]:
    """Resize to 260x260, take a random `crop_size` crop and resize it to `crop_size`."""
    resize_to = 260
    enlarged = tf.image.resize(image, (resize_to, resize_to))
    # Random square patch with three channels
    patch = tf.image.random_crop(enlarged, (crop_size, crop_size, 3))
    patch = tf.image.resize(patch, (crop_size, crop_size))
    return patch, label

@tf.function
def train_augmentations(image: tf.Tensor, label: tf.Tensor, gaussian_p: float = 0.1, solarize_p: float = 0.0, crop_size:int = 224) -> Tuple[tf.Tensor]:
    """Scale, randomly crop and distort one training image.

    Args:
        image: Raw input image.
        label: Label, returned untouched.
        gaussian_p: Probability of applying Gaussian blur.
        solarize_p: Probability of applying solarization.
        crop_size: Side length of the random square crop.

    Returns:
        The augmented image and the original label.
    """
    # scale the pixel values
    image, label = scale_image(image, label)
    # image resizing
    image_shape = 260
    image = tf.image.resize(image, (image_shape, image_shape))
    # get the crop from the image
    crop = tf.image.random_crop(image, (crop_size, crop_size, 3))
    crop_resize = tf.image.resize(crop, (crop_size, crop_size))
    # color distortions; solarize_p is forwarded too (it used to be dropped,
    # so solarization never ran whatever probability the caller asked for)
    distorted_image, label = custom_augment_train(
        crop_resize, label, gaussian_p, solarize_p
    )
    return distorted_image, label

@tf.function
def eval_augmentations(image: tf.Tensor, label: tf.Tensor) -> Tuple[tf.Tensor]:
    """Augment one image for training the linear classifier.

    Scales the pixel values, flips horizontally with probability 0.5 and then
    applies the resize-and-random-crop of `custom_augment_eval`.

    Returns:
        The augmented image and the original label.
    """
    # Scale the pixel values
    image, label = scale_image(image, label)
    # Random horizontal flip. The deterministic flip is wrapped in
    # `random_apply`; wrapping `tf.image.random_flip_left_right` would randomise
    # twice and flip only 25% of the images.
    image = random_apply(tf.image.flip_left_right, image, p=0.5)
    # Random resized crops
    image, label = custom_augment_eval(image, label)

    return image, label

# 💿 The Dataset

---

For the purposes of this example, we use the TF Flowers dataset.

In [None]:
tfds.disable_progress_bar()

# TF Flowers, split 85% / 15% into (image, label) pairs
train_ds, validation_ds = tfds.load(
    name="tf_flowers",
    as_supervised=True,
    split=["train[:85%]", "train[85%:]"],
)

## 🖖 Data Augmentation Pipeline


In [None]:
# We create a Tuple because we have one loader per view
trainloaders = tuple()


def _make_view_augment(gaussian_p: float, solarize_p: float) -> Callable:
    """Return a map function with this view's probabilities bound at creation time."""
    def _augment(image, label):
        return train_augmentations(image, label, gaussian_p, solarize_p)
    return _augment


for i in range(NUM_VIEWS):
  trainloader = (
      train_ds
      # Every view shuffles with the SAME seed so that, once the loaders are
      # zipped, element k of each view is an augmentation of the same image.
      .shuffle(1024, seed=GLOBAL_SEED)
      # deterministic=True preserves that order through the parallel map; it
      # takes precedence over the non-deterministic setting in `options`.
      .map(
          _make_view_augment(GAUSSIAN_P[i], SOLARIZE_P[i]),
          num_parallel_calls=AUTOTUNE,
          deterministic=True,
      )
  )
  trainloader = trainloader.with_options(options)
  trainloaders+=(trainloader,)

## ⚙️ Dataloader


In [None]:
# Pair up the per-view loaders element by element
trainloader = tf.data.Dataset.zip(trainloaders)

# Final trainloader used for training: one global batch (per-replica batch
# size times replica count), view order passed through
# `shuffle_zipped_output`, then prefetching.
trainloader = trainloader.batch(TRAIN_BATCH_SIZE * strategy.num_replicas_in_sync)
trainloader = trainloader.map(shuffle_zipped_output, num_parallel_calls=AUTOTUNE)
trainloader = trainloader.prefetch(AUTOTUNE)

# ✍️ Model Architecture & Training
---

## 🏠 Building the network
![](https://github.com/facebookresearch/vicreg/blob/main/.github/vicreg_archi_full.jpg?raw=true)

In [None]:
class VICReg(tf.keras.Model):
    """VICReg model: a ResNet-50 encoder followed by a three-layer MLP expander.

    Training minimises a weighted sum of an invariance term (mean squared
    distance between the embeddings of two views), a variance term (hinge on
    the per-dimension standard deviation) and a covariance term (squared
    off-diagonal entries of each view's covariance matrix).

    Args:
        num_units: Width of every Dense layer in the expander.
        invar_coeff: Weight of the invariance loss.
        var_coeff: Weight of the variance loss.
        cov_coeff: Weight of the covariance loss.
        **kwargs: Forwarded to `tf.keras.Model`.
    """

    def __init__(
        self,
        num_units: int,
        invar_coeff: float,
        var_coeff: float,
        cov_coeff: float,
        **kwargs
    ) -> None:
        super().__init__(**kwargs)

        self.num_units = num_units
        self.invar_coeff = invar_coeff
        self.var_coeff = var_coeff
        self.cov_coeff = cov_coeff

        self.encoder = self.build_encoder()
        self.expander = self.build_expander(self.num_units)

        # Running means reported by `fit`
        self.loss_tracker = tf.keras.metrics.Mean(name="vicreg_loss")
        self.invarloss_tracker = tf.keras.metrics.Mean(name="invariance_loss")
        self.varloss_tracker = tf.keras.metrics.Mean(name="variance_loss")
        self.covloss_tracker = tf.keras.metrics.Mean(name="covariance_loss")

    def get_config(self):
        """Return the constructor arguments needed to rebuild the model."""
        return {
            "invar_coeff": self.invar_coeff,
            "var_coeff": self.var_coeff,
            "cov_coeff": self.cov_coeff,
            "num_units": self.num_units,
        }

    @classmethod
    def from_config(cls, config):
        """Rebuild a model from the dict produced by `get_config`."""
        return cls(**config)

    @property
    def metrics(self):
        """Trackers that Keras resets at the start of every epoch."""
        return [
            self.loss_tracker,
            self.invarloss_tracker,
            self.varloss_tracker,
            self.covloss_tracker,
        ]

    def build_encoder(self):
        """Build a randomly initialised ResNet-50 with global average pooling.

        Returns:
            A Keras model named "encoder" that maps images of shape
            (None, None, 3) to pooled feature vectors.
        """
        encoder_input = tf.keras.layers.Input((None, None, 3))
        base_model = tf.keras.applications.ResNet50(
            include_top=False, weights=None, input_shape=(None, None, 3)
        )
        base_model.trainable = True
        # NOTE: the backbone is wired with training=True hard-coded here.
        representations = base_model(encoder_input, training=True)
        encoder_output = tf.keras.layers.GlobalAveragePooling2D()(representations)
        encoder = tf.keras.models.Model(
            inputs=encoder_input, outputs=encoder_output, name="encoder"
        )
        return encoder

    def build_expander(self, num_units: int):
        """Build the expander MLP: two Dense-BatchNorm-ReLU blocks and a final Dense.

        Args:
            num_units: Width of every Dense layer.

        Returns:
            A Keras model named "expander" taking 2048-dimensional inputs.
        """
        expander_input = tf.keras.layers.Input((2048,))

        hidden = expander_input
        for _ in range(2):
            hidden = tf.keras.layers.Dense(num_units)(hidden)
            hidden = tf.keras.layers.BatchNormalization()(hidden)
            hidden = tf.keras.layers.Activation("relu")(hidden)

        expander_output = tf.keras.layers.Dense(num_units)(hidden)

        expander = tf.keras.models.Model(
            inputs=expander_input, outputs=expander_output, name="expander"
        )

        return expander

    def save_weights(
        self,
        encoder_path: str = "encoder.h5",
        expander_path: str = "expander.h5",
    ) -> None:
        """Save the encoder and expander weights to two separate files.

        Args:
            encoder_path: Destination of the encoder weights.
            expander_path: Destination of the expander weights.
        """
        self.encoder.save_weights(encoder_path)
        self.expander.save_weights(expander_path)

    def train_step(self, images):
        """Run one optimisation step on a pair of augmented views.

        Args:
            images: Pair of views, each an (image_batch, label_batch) tuple;
                the labels are not used.

        Returns:
            Dict with the running means of the total, invariance, variance
            and covariance losses.
        """
        x, x_prime = images[0][0], images[1][0]

        with tf.GradientTape() as tape:
            # Get Representations (through encoder). training=True is passed
            # explicitly so that BatchNormalization uses batch statistics and
            # updates its moving averages inside this custom step.
            y = self.encoder(x, training=True)
            y_prime = self.encoder(x_prime, training=True)

            # Get Embeddings (through expander)
            z = self.expander(y, training=True)
            z_prime = self.expander(y_prime, training=True)

            # Sizes are read from the embeddings at run time; this also covers
            # a smaller final batch.
            batch_size = tf.cast(tf.shape(z)[0], z.dtype)
            num_features = tf.cast(tf.shape(z)[-1], z.dtype)
            # N - 1 normaliser of the covariance, floored at 1 so that a
            # single-sample batch does not divide by zero.
            cov_denominator = tf.maximum(batch_size - 1.0, 1.0)

            # Invariance loss: scalar mean squared error over all elements
            invar_loss = tf.reduce_mean(tf.square(z - z_prime))

            # Center the embeddings and compute their per-dimension std. dev.
            z = z - tf.reduce_mean(z, axis=0)
            z_prime = z_prime - tf.reduce_mean(z_prime, axis=0)
            std_z = tf.sqrt(tf.math.reduce_variance(z, axis=0) + 0.0001)
            std_z_prime = tf.sqrt(tf.math.reduce_variance(z_prime, axis=0) + 0.0001)

            # Variance loss (hinge function)
            var_loss = (
                tf.reduce_mean(tf.nn.relu(1 - std_z)) / 2
                + tf.reduce_mean(tf.nn.relu(1 - std_z_prime)) / 2
            )

            # Covariance matrices of the centered embeddings
            cov_z = tf.matmul(z, z, transpose_a=True) / cov_denominator
            cov_z_prime = (
                tf.matmul(z_prime, z_prime, transpose_a=True) / cov_denominator
            )

            # Covariance loss: squared off-diagonal entries, normalised by the
            # embedding width rather than a hard-coded 8192.
            cov_loss_z = tf.reduce_sum(tf.square(off_diagonal(cov_z))) / num_features
            cov_loss_z_prime = (
                tf.reduce_sum(tf.square(off_diagonal(cov_z_prime))) / num_features
            )
            cov_loss = cov_loss_z + cov_loss_z_prime

            # Weighted sum of the invariance, variance and covariance losses
            loss = (
                self.invar_coeff * invar_loss
                + self.var_coeff * var_loss
                + self.cov_coeff * cov_loss
            )

        # Compute gradients
        variables = self.encoder.trainable_variables + self.expander.trainable_variables
        gradients = tape.gradient(loss, variables)
        # Update weights
        self.optimizer.apply_gradients(zip(gradients, variables))
        # Compute our own metrics
        self.loss_tracker.update_state(loss)
        self.invarloss_tracker.update_state(invar_loss)
        self.varloss_tracker.update_state(var_loss)
        self.covloss_tracker.update_state(cov_loss)
        # Return a dict mapping metric names to current value
        return {
            "loss": self.loss_tracker.result(),
            "invariance_loss": self.invarloss_tracker.result(),
            "variance_loss": self.varloss_tracker.result(),
            "covariance_loss": self.covloss_tracker.result(),
        }

## 🏃 Train !!

In [None]:
# The training protocol for VICReg follows those of BYOL and Barlow Twins,
# i.e. it uses LARS, an adaptive algorithm meant for large-batch training.
# The schedule, optimizer and model are all created inside the strategy scope
# so that every variable they own is created under the distribution strategy.
with strategy.scope():
    lr_decayed_fn = tf.keras.optimizers.schedules.CosineDecay(
        initial_learning_rate=BASE_LR, decay_steps=DECAY_STEPS
    )
    opt = tfm.optimization.lars_optimizer.LARS(
        learning_rate=lr_decayed_fn, weight_decay_rate=WEIGHT_DECAY
    )
    model = VICReg(
        num_units=MLP_UNITS,
        invar_coeff=INVAR_COEFF,
        var_coeff=VAR_COEFF,
        cov_coeff=COV_COEFF,
    )
    model.compile(optimizer=opt)

model.fit(trainloader, epochs=NUM_TRAINING_EPOCHS)
# Writes the encoder and expander weights; the linear evaluation reloads the encoder
model.save_weights()

# 👨🏻‍⚖️ Linear Evaluation
---
We use a linear evaluation protocol, i.e., we train a linear classifier on top of the frozen representations of the ResNet-50 backbone pretrained with VICReg.

## ⚙️ Dataloader for Linear Evaluation

As detailed in Appendix C.2 (ImageNet Evaluation), the training data augmentation pipeline is composed of random cropping
with a scale ratio of 0.2 to 1.0, resizing to 224 × 224, and random horizontal flips. During evaluation, the
validation images are simply center-cropped and resized to 224 × 224.

In [None]:
tf.keras.backend.clear_session()

# Gather Flowers dataset: 80% train / 10% validation / 10% test
train_ds, validation_ds, test_ds = tfds.load(
    "tf_flowers",
    split=['train[:80%]', 'train[80%:90%]', 'train[90%:]'],
    as_supervised=True
)


def _center_crop_for_eval(image, label, resize_to=260, crop_size=224):
    """Resize to `resize_to` and take the central `crop_size` square.

    Used for validation and test so that their metrics are deterministic,
    unlike the random crop applied to the training split.
    """
    image = tf.image.resize(image, (resize_to, resize_to))
    image = tf.image.resize_with_crop_or_pad(image, crop_size, crop_size)
    return image, label


# Training split: shuffled, randomly flipped and randomly cropped
eval_trainloader = (
    train_ds
    .shuffle(1024)
    .map(eval_augmentations, num_parallel_calls=AUTOTUNE)
    .batch(EVAL_BATCH_SIZE)
    .prefetch(AUTOTUNE)
)

# Validation and test splits: no shuffling and no random augmentation
eval_valdataloader = (
    validation_ds
    .map(scale_image, num_parallel_calls=AUTOTUNE)
    .map(_center_crop_for_eval, num_parallel_calls=AUTOTUNE)
    .batch(EVAL_BATCH_SIZE)
    .prefetch(AUTOTUNE)
)

eval_testdataloader = (
    test_ds
    .map(scale_image, num_parallel_calls=AUTOTUNE)
    .map(_center_crop_for_eval, num_parallel_calls=AUTOTUNE)
    .batch(EVAL_BATCH_SIZE)
    .prefetch(AUTOTUNE)
)

## 🏠 Building the network

In [None]:
def get_linear_classifier(
    num_classes: int = 5, encoder_weights: str = "encoder.h5"
) -> tf.keras.Model:
    """Build a linear classifier on top of the frozen, pretrained encoder.

    Args:
        num_classes: Number of output classes. Defaults to 5, the number of
            classes in tf_flowers (the head previously had 10 units).
        encoder_weights: Path of the encoder weights written after pre-training.

    Returns:
        A Keras model mapping 224x224x3 images to class probabilities.
    """
    # input placeholder
    inputs = tf.keras.layers.Input(shape=(224, 224, 3))
    # get vicreg model architecture
    base_model = VICReg(
        num_units=MLP_UNITS,
        invar_coeff=INVAR_COEFF,
        var_coeff=VAR_COEFF,
        cov_coeff=COV_COEFF,
    )
    # Reuse the encoder the constructor already built instead of building a
    # second, identical ResNet-50.
    feature_backbone = base_model.encoder
    # load trained weights and freeze the backbone
    feature_backbone.load_weights(encoder_weights)
    feature_backbone.trainable = False
    x = feature_backbone(inputs, training=False)
    outputs = tf.keras.layers.Dense(num_classes, activation="softmax")(x)
    linear_model = tf.keras.Model(inputs, outputs)

    return linear_model

## 🏃 Evaluation !!

In [None]:
# Cosine-decayed learning rate for the linear-evaluation optimizer
lr_decayed_fn = tf.keras.optimizers.schedules.CosineDecay(
    decay_steps=DECAY_STEPS,
    initial_learning_rate=EVAL_LR,
)

# Build and compile the classifier under the distribution strategy
with strategy.scope():
  evaluation_model = get_linear_classifier()
  evaluation_model.compile(
      optimizer=tf.keras.optimizers.SGD(
          learning_rate=lr_decayed_fn,
          weight_decay=WEIGHT_DECAY,
      ),
      loss="sparse_categorical_crossentropy",
      metrics=["acc"],
  )

evaluation_model.summary()

In [None]:
# Stop once val_loss has not improved for 3 epochs and restore the best weights
early_stopper = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss",
    patience=3,
    verbose=2,
    restore_best_weights=True,
)

# Fit the evaluation model on the training loader, validating after every epoch
evaluation_model.fit(
    eval_trainloader,
    epochs=NUM_EVAL_EPOCHS,
    validation_data=eval_valdataloader,
    callbacks=[early_stopper],
)

In [None]:
# Loss and accuracy (the compiled "acc" metric) of the linear classifier on the test loader
loss, acc = evaluation_model.evaluate(eval_testdataloader)

In [None]:
# Save the full evaluation model (frozen encoder + linear head) under "linear_eval"
evaluation_model.save("linear_eval")