# 📦 Packages and Basic Setup
---

In [None]:
%%capture
!pip install -U rich

import os
import random
import numpy as np
from rich import print
import tensorflow as tf
from itertools import groupby
from rich.progress import track
import tensorflow_datasets as tfds
from tensorflow.python.ops.numpy_ops import np_config

from typing import Callable, Optional, Tuple, Any, List

# tf.data pipeline options shared by the training and evaluation loaders
options = tf.data.Options()
options.experimental_optimization.noop_elimination = True
options.experimental_optimization.apply_default_optimizations = True
# `deterministic` is the supported spelling of the deprecated
# `experimental_deterministic` flag
options.deterministic = False
options.threading.max_intra_op_parallelism = 1
np_config.enable_numpy_behavior()

AUTOTUNE = tf.data.AUTOTUNE

# Memory growth can only be configured before the GPUs are initialised, and
# creating the distribution strategy initialises them, so request it first.
try:
    for _gpu in tf.config.list_physical_devices("GPU"):
        tf.config.experimental.set_memory_growth(_gpu, True)
except RuntimeError as _gpu_error:
    # Raised when the GPUs were already initialised (e.g. the cell is re-run
    # in a session that configured them differently)
    print(_gpu_error)

strategy = tf.distribute.MirroredStrategy()

In [None]:
# @title ⚙ Configuration
# Default seed used by `seed_everything`
GLOBAL_SEED = 42  # @param {type: "number"}
# Multi-crop settings, indexed in parallel by the training data pipeline:
# group i yields NUM_CROPS[i] views resized to SIZE_CROPS[i] pixels, each cut
# with a side length drawn between MIN_SCALE[i] and MAX_SCALE[i] times the
# side of the intermediate resized image
NUM_CROPS = [2, 3]
MIN_SCALE = [0.5, 0.14]
MAX_SCALE = [1.0, 0.5]
SIZE_CROPS = [224, 96]
# NOTE(review): CROPS_FOR_ASSIGN and TEMPERATURE mirror the defaults of the
# SwAV constructor but are not passed to it in this notebook -- confirm that
# changing them here is meant to have an effect
CROPS_FOR_ASSIGN = [0, 1]
# Epoch counts for the self-supervised training and the linear evaluation
NUM_TRAINING_EPOCHS = 10  # @param {type: "number"}
NUM_EVAL_EPOCHS = 100  # @param {type: "number"}
# Training batch size is per replica; it is multiplied by the number of
# replicas of the distribution strategy when the training loader is batched
TRAIN_BATCH_SIZE = 8  # @param {type: "number"}
EVAL_BATCH_SIZE = 64  # @param {type: "number"}
TEMPERATURE = 0.1  # @param {type: "number"}
# DECAY_STEPS and BASE_LR parameterise the cosine learning-rate schedule
DECAY_STEPS = 1000  # @param {type: "number"}
# NOTE(review): WEIGHT_DECAY and EVAL_LR are not referenced by any code in
# this notebook -- confirm whether they should feed the optimizers
WEIGHT_DECAY = 1e-6  # @param {type: "number"}
BASE_LR = 0.2  # @param {type: "number"}
EVAL_LR = 0.02  # @param {type: "number"}


# ============ Random Seed ============
def seed_everything(seed=GLOBAL_SEED):
    """Seed every random number generator this notebook relies on.

    Seeds Python's `random`, NumPy, TensorFlow and TensorFlow's NumPy API,
    and exports the determinism-related environment variables.

    Args:
        seed: Integer seed shared by all generators.
    """
    # Deterministic cuDNN / TF op flags plus a fixed Python hash seed
    os.environ.update(
        {
            "TF_CUDNN_DETERMINISTIC": "1",
            "TF_DETERMINISTIC_OPS": "1",
            "PYTHONHASHSEED": str(seed),
        }
    )

    seeders = (
        random.seed,
        np.random.seed,
        tf.random.set_seed,
        tf.experimental.numpy.random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    print(f"Random seed set as {seed}")


seed_everything()

In [None]:
## Enable memory growth on every visible GPU
# NOTE(review): `strategy` is created in an earlier cell; if that already
# initialised the GPUs, `set_memory_growth` raises the RuntimeError handled
# below and growth is not applied -- confirm the intended cell order.
gpus = tf.config.list_physical_devices("GPU")
if gpus:
    try:
        # Currently, memory growth needs to be the same across GPUs
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        logical_gpus = tf.config.list_logical_devices("GPU")
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as e:
        # Memory growth must be set before GPUs have been initialized;
        # the error is reported and execution continues
        print(e)

# 🆘 Utility Classes and Functions
---

In [None]:
@tf.function
def random_apply(func: Callable, x: tf.Tensor, p: float) -> tf.Tensor:
    """Return `func(x)` with probability `p`, otherwise `x` unchanged.

    Args:
        func: Transformation applied when the random draw falls below `p`.
        x: Input passed to `func` or returned as is.
        p: Probability of applying `func`.
    """
    draw = tf.random.uniform([], minval=0, maxval=1, dtype=tf.float32)
    threshold = tf.cast(p, tf.float32)
    return tf.cond(draw < threshold, true_fn=lambda: func(x), false_fn=lambda: x)


def shuffle_zipped_output(a: Any, b: Any, c: Any, d: Any, e: Any) -> Tuple[Any]:
    """Return the five arguments as a tuple in random order.

    The order is drawn with `random.shuffle`, so it follows the state of
    Python's global random number generator.
    """
    pool = [a, b, c, d, e]
    random.shuffle(pool)
    return tuple(pool)


def sinkhorn(
    sample_prototype_batch: tf.Tensor,
    num_sinkhorn_iters: int = 3,
    epsilon: float = 0.05,
) -> tf.Tensor:
    """
    Perform sinkhorn normalization on the sample prototype batch

    The scores are exponentiated, transposed to (prototypes, samples) and
    alternately rescaled along rows and columns for `num_sinkhorn_iters`
    rounds; the result is normalised per sample and transposed back.

    Args:
        sample_prototype_batch: 2-D score tensor, one row per sample and one
            column per prototype.
        num_sinkhorn_iters: Number of row/column normalisation rounds.
        epsilon: Divisor applied to the scores before exponentiation.

    Returns:
        Tensor with the same shape as the input whose rows each sum to one.
    """
    Q = tf.transpose(tf.exp(sample_prototype_batch / epsilon))
    Q /= tf.reduce_sum(Q)

    # Index the shape instead of unpacking it: unpacking iterates over the
    # tensor, which is only possible when executing eagerly
    q_shape = tf.shape(Q)
    # Explicit casts keep the divisions below valid without NumPy-style
    # dtype promotion being enabled
    r = 1.0 / tf.cast(q_shape[0], Q.dtype)  # target mass per prototype row
    c = 1.0 / tf.cast(q_shape[1], Q.dtype)  # target mass per sample column

    for _ in range(num_sinkhorn_iters):
        u = tf.reduce_sum(Q, axis=1)
        Q *= tf.expand_dims(r / u, axis=1)
        Q *= tf.expand_dims(c / tf.reduce_sum(Q, axis=0), axis=0)

    # Make every sample's assignment sum to one, then restore the
    # (samples, prototypes) layout
    Q /= tf.reduce_sum(Q, axis=0, keepdims=True)
    return tf.transpose(Q)

## 🖖 Utilities for Data Augmentation

In [None]:
@tf.function
def scale_image(image: tf.Tensor) -> tf.Tensor:
    """Convert the image to float32 with `tf.image.convert_image_dtype`."""
    return tf.image.convert_image_dtype(image, dtype=tf.float32)


@tf.function
def scale_image_with_label(image: tf.Tensor, label: tf.Tensor) -> Tuple[tf.Tensor]:
    """Convert the image to float32 and pass the label through untouched."""
    as_float = tf.image.convert_image_dtype(image, dtype=tf.float32)
    return as_float, label


@tf.function
def gaussian_blur(
    image: tf.Tensor, kernel_size: int = 23, padding: str = "SAME"
) -> tf.Tensor:
    """
    Blur the input image with a Gaussian kernel of random width

    The blur is always applied; only the standard deviation is random. It is
    implemented as two 1-D depthwise convolutions (horizontal, then vertical).

    Args:
        image: Image tensor; a rank-3 input is temporarily given a leading
            batch dimension, any other rank is passed to the convolution
            unchanged.
        kernel_size: Requested kernel width; the width actually used is
            `2 * (kernel_size // 2) + 1`, i.e. always odd.
        padding: Padding mode forwarded to `tf.nn.depthwise_conv2d`.

    Returns:
        The blurred image, with the same rank as the input.

    Reference: https://github.com/google-research/simclr/blob/master/data_util.py
    """

    # Standard deviation: a uniform draw scaled by 1.9 and shifted by 0.1
    sigma = tf.random.uniform((1,)) * 1.9 + 0.1
    radius = tf.cast(kernel_size / 2, tf.int32)
    # Recompute the width from the radius so the kernel has a centre tap
    kernel_size = radius * 2 + 1
    x = tf.cast(range(-radius, radius + 1), tf.float32)
    # Unnormalised Gaussian exp(-x^2 / (2 * sigma^2)) sampled at the taps
    blur_filter = tf.exp(
        -tf.pow(x, 2.0) / (2.0 * tf.pow(tf.cast(sigma, tf.float32), 2.0))
    )
    # Normalise so the filter weights sum to one
    blur_filter /= tf.reduce_sum(blur_filter)

    # One vertical and one horizontal filter.
    blur_v = tf.reshape(blur_filter, [kernel_size, 1, 1, 1])
    blur_h = tf.reshape(blur_filter, [1, kernel_size, 1, 1])
    # Repeat the filter once per channel, as the depthwise convolution
    # filters each channel separately
    num_channels = tf.shape(image)[-1]
    blur_h = tf.tile(blur_h, [1, 1, num_channels, 1])
    blur_v = tf.tile(blur_v, [1, 1, num_channels, 1])
    # The convolution needs a batch dimension; add one for single images
    expand_batch_dim = image.shape.ndims == 3
    if expand_batch_dim:
        image = tf.expand_dims(image, axis=0)
    blurred = tf.nn.depthwise_conv2d(
        image, blur_h, strides=[1, 1, 1, 1], padding=padding
    )
    blurred = tf.nn.depthwise_conv2d(
        blurred, blur_v, strides=[1, 1, 1, 1], padding=padding
    )
    # Drop the batch dimension again if it was added above
    if expand_batch_dim:
        blurred = tf.squeeze(blurred, axis=0)
    return blurred


@tf.function
def color_jitter(image: tf.Tensor, s: float = 0.5) -> tf.Tensor:
    """Randomly perturb brightness, contrast, saturation and hue.

    Args:
        image: Input image; the result is clipped to the range [0, 1].
        s: Jitter strength scaling the range of every perturbation.
    """
    strength = 0.8 * s
    low, high = 1 - strength, 1 + strength

    jittered = tf.image.random_brightness(image, max_delta=strength)
    jittered = tf.image.random_contrast(jittered, lower=low, upper=high)
    jittered = tf.image.random_saturation(jittered, lower=low, upper=high)
    jittered = tf.image.random_hue(jittered, max_delta=0.2 * s)
    return tf.clip_by_value(jittered, 0, 1)


@tf.function
def color_drop(image: tf.Tensor) -> tf.Tensor:
    """Convert the image to grayscale, repeated over three channels."""
    gray = tf.image.rgb_to_grayscale(image)
    return tf.tile(gray, multiples=[1, 1, 3])


@tf.function
def random_resize_crop(
    image: tf.Tensor,
    min_scale: float,
    max_scale: float,
    crop_size: int,
    label: Optional[tf.Tensor],
) -> tf.Tensor:
    """Cut a random square crop and resize it to `crop_size`.

    The image is first resized to a square working size (260 when
    `crop_size` is 224, otherwise 160). A square whose side is drawn
    uniformly between `min_scale` and `max_scale` times that size is then
    cropped at a random position and resized to `crop_size`.

    Returns:
        The resized crop, or `(crop, label)` when `label` is not None.
    """
    image_shape = 260 if crop_size == 224 else 160
    image = tf.image.resize(image, (image_shape, image_shape))

    # Side length of the square crop, truncated to an integer
    side = tf.random.uniform(
        shape=(1,),
        minval=min_scale * image_shape,
        maxval=max_scale * image_shape,
        dtype=tf.float32,
    )
    side = tf.cast(side, tf.int32)[0]

    patch = tf.image.random_crop(image, (side, side, 3))
    resized_patch = tf.image.resize(patch, (crop_size, crop_size))

    if label is None:
        return resized_patch
    return (resized_patch, label)


@tf.function
def augmentation_pipeline(image: tf.Tensor) -> tf.Tensor:
    """Apply the stochastic augmentations in a fixed order.

    Each step is applied independently with its own probability:
    horizontal flip, Gaussian blur, colour jitter, then grayscale conversion.
    """
    stochastic_steps = (
        (tf.image.flip_left_right, 0.5),  # random flips
        (gaussian_blur, 0.5),  # Gaussian blur
        (color_jitter, 0.8),  # colour distortions
        (color_drop, 0.2),  # grayscale
    )
    for transform, probability in stochastic_steps:
        image = random_apply(transform, image, p=probability)

    return image


@tf.function
def apply_augmentation(
    image: tf.Tensor, min_scale: float, max_scale: float, crop_size: int
) -> tf.Tensor:
    """Produce one augmented view from a dataset element.

    Args:
        image: Dataset element; the pixels are read from its "image" entry.
        min_scale: Lower bound of the random crop scale.
        max_scale: Upper bound of the random crop scale.
        crop_size: Side length of the returned view.
    """
    # Pull out the pixels and convert them to float32
    pixels = scale_image(image["image"])
    # Random resized crop (no label in the self-supervised pipeline)
    view = random_resize_crop(pixels, min_scale, max_scale, crop_size, label=None)
    # Flip, blur and colour distortions
    return augmentation_pipeline(view)


@tf.function
def eval_augmentation(image: tf.Tensor, label: tf.Tensor) -> tf.Tensor:
    """Augment a labelled image for training the linear classifier.

    The image is converted to float32, flipped horizontally with
    probability 0.5, and given a random resized crop using the first entry
    of the crop configuration (MIN_SCALE[0], MAX_SCALE[0], SIZE_CROPS[0]).

    Returns:
        The `(image, label)` pair.
    """
    image, label = scale_image_with_label(image, label)
    image = random_apply(tf.image.flip_left_right, image, p=0.5)
    cropped, label = random_resize_crop(
        image,
        min_scale=MIN_SCALE[0],
        max_scale=MAX_SCALE[0],
        crop_size=SIZE_CROPS[0],
        label=label,
    )
    return cropped, label

# 💿 The Dataset
---
For the purposes of this example, we use the TF Flowers dataset.



In [None]:
tfds.disable_progress_bar()

# Load TF Flowers: first 85% of "train" for training, the rest for validation
train_ds, validation_ds = tfds.load(
    name="tf_flowers",
    split=["train[:85%]", "train[85%:]"],
)

## 🖖 Data Augmentation Pipeline

In [None]:
# One augmented dataset per view: crop group i contributes NUM_CROPS[i]
# datasets at resolution SIZE_CROPS[i], so the tuple holds sum(NUM_CROPS)
# datasets in total.
# NOTE(review): every view shuffles `train_ds` on its own, so once zipped the
# views of a training example may not come from the same source image --
# confirm whether the views are expected to be aligned.
view_datasets = []

for i, num_crop in enumerate(NUM_CROPS):
    for _ in range(num_crop):
        trainloader = train_ds.shuffle(1024)
        trainloader = trainloader.map(
            lambda x: apply_augmentation(x, MIN_SCALE[i], MAX_SCALE[i], SIZE_CROPS[i]),
            num_parallel_calls=AUTOTUNE,
        )
        trainloader = trainloader.with_options(options)
        view_datasets.append(trainloader)

trainloaders = tuple(view_datasets)

## ⚙️ Dataloader

In [None]:
# Global batch: per-replica batch size times the number of replicas
global_batch_size = TRAIN_BATCH_SIZE * strategy.num_replicas_in_sync

# Zip the per-view datasets so each element carries every view, then batch
# and prefetch; this is the loader used for training
trainloader = (
    tf.data.Dataset.zip(trainloaders).batch(global_batch_size).prefetch(AUTOTUNE)
)

# ✍️ Model Architecture & Training
---


## 🏠 Building the network

![](https://i.ibb.co/TtSW4Fd/figure-3.png)

In [None]:
class SwAV(tf.keras.Model):
    """SwAV model: a ResNet-50 encoder followed by a projection/prototype head.

    `train_step` expects one tensor per augmented view. Views are encoded,
    projected onto the prototypes, and the loss asks every other view to
    predict the Sinkhorn cluster assignment of the views listed in
    `CROPS_FOR_ASSIGN`.

    Args:
        units: Widths of the two dense layers of the projection head.
        projection_dim: Number of units of the final bias-free prototype layer.
        num_sinkhorn_iters: Normalisation rounds used by `sinkhorn`.
        CROPS_FOR_ASSIGN: Indices of the views whose assignments act as targets.
        NUM_CROPS: Number of views per resolution; their sum is the number of
            views `train_step` expects.
        TEMPERATURE: Divisor applied to the prototype scores before softmax.
    """

    def __init__(
        self,
        units: Tuple[int, int] = (1024, 96),
        projection_dim=10,
        num_sinkhorn_iters: int = 3,
        CROPS_FOR_ASSIGN: Tuple[int, int] = (0, 1),
        NUM_CROPS: Tuple[int, int] = (2, 3),
        TEMPERATURE: float = 0.1,
        **kwargs
    ) -> None:
        super().__init__(**kwargs)
        self.units = units
        self.projection_dim = projection_dim
        self.CROPS_FOR_ASSIGN = CROPS_FOR_ASSIGN
        self.NUM_CROPS = NUM_CROPS
        self.TEMPERATURE = TEMPERATURE
        self.num_sinkhorn_iters = num_sinkhorn_iters

        self.encoder = self.build_encoder()
        self.projection = self.build_projection(self.units, self.projection_dim)

        self.loss_tracker = tf.keras.metrics.Mean(name="swav_loss")

    def get_config(self) -> dict:
        """Return the constructor arguments needed to rebuild the model."""
        return {
            "units": self.units,
            "projection_dim": self.projection_dim,
            "num_sinkhorn_iters": self.num_sinkhorn_iters,
            "CROPS_FOR_ASSIGN": self.CROPS_FOR_ASSIGN,
            "NUM_CROPS": self.NUM_CROPS,
            "TEMPERATURE": self.TEMPERATURE,
        }

    @classmethod
    def from_config(cls, config, custom_objects=None) -> "SwAV":
        """Rebuild a model from the dictionary produced by `get_config`."""
        return cls(**config)

    @property
    def metrics(self) -> list:
        """Metrics tracked by this model (only the running mean of the loss)."""
        return [
            self.loss_tracker,
        ]

    def save_weights(
        self,
        filepath: str = "artifacts/swav/",
        overwrite=True,
        save_format="h5",
        options=None,
    ) -> None:
        """Save the encoder and the projection head as separate HDF5 files.

        Args:
            filepath: Directory receiving `encoder.h5` and `projection.h5`;
                it is created when missing. An empty string writes to the
                current working directory.
            overwrite: Forwarded to the underlying `save_weights` calls.
            save_format: Accepted for signature compatibility; the files are
                always written with the `.h5` extension.
            options: Accepted for signature compatibility; unused.
        """
        if filepath:
            os.makedirs(filepath, exist_ok=True)
        self.encoder.save_weights(
            os.path.join(filepath, "encoder.h5"), overwrite=overwrite
        )
        self.projection.save_weights(
            os.path.join(filepath, "projection.h5"), overwrite=overwrite
        )

    def build_encoder(self) -> tf.keras.Model:
        """Build the backbone: randomly initialised ResNet-50 + global pooling.

        The spatial input size is left unspecified so that views of different
        resolutions can share the same encoder.
        """
        encoder_input = tf.keras.layers.Input((None, None, 3))
        base_model = tf.keras.applications.ResNet50(
            include_top=False, weights=None, input_shape=(None, None, 3)
        )
        base_model.trainable = True
        representations = base_model(encoder_input, training=True)
        encoder_output = tf.keras.layers.GlobalAveragePooling2D()(representations)
        encoder = tf.keras.models.Model(
            inputs=encoder_input, outputs=encoder_output, name="encoder"
        )
        return encoder

    def build_projection(self, units, projection_dim) -> tf.keras.Model:
        """Build the head mapping 2048-d features to prototype scores.

        Returns a model with two outputs: the L2-normalised projection and
        the scores of the bias-free prototype layer applied to it.
        """
        inputs = tf.keras.layers.Input((2048,))
        projection_1 = tf.keras.layers.Dense(units[0])(inputs)
        projection_1 = tf.keras.layers.BatchNormalization()(projection_1)
        projection_1 = tf.keras.layers.Activation("relu")(projection_1)

        projection_2 = tf.keras.layers.Dense(units[1])(projection_1)
        projection_2_normalize = tf.math.l2_normalize(
            projection_2, axis=1, name="projection"
        )

        prototype = tf.keras.layers.Dense(
            projection_dim, use_bias=False, name="prototype"
        )(projection_2_normalize)

        return tf.keras.models.Model(
            inputs=inputs, outputs=[projection_2_normalize, prototype]
        )

    def train_step(self, images: tf.Tensor) -> dict:
        """Run one optimisation step on a tuple of augmented views.

        Args:
            images: One batch tensor per view, ordered so that views of equal
                resolution are adjacent; `sum(NUM_CROPS)` views are expected.

        Returns:
            Dictionary mapping "loss" to the running mean of the SwAV loss.

        Raises:
            ValueError: If the number of views differs from `sum(NUM_CROPS)`.

        References:

        * https://github.com/facebookresearch/swav/blob/master/main_swav.py
        * https://github.com/facebookresearch/swav/issues/19
        * https://github.com/ayulockin/SwAV-TF
        """
        inputs = list(images)
        num_views = int(np.sum(self.NUM_CROPS))
        if len(inputs) != num_views:
            raise ValueError(
                f"Expected {num_views} views (sum of NUM_CROPS), got {len(inputs)}"
            )
        batch_size = inputs[0].shape[0]

        # ============ create crop entries with same shape ... ============
        crop_sizes = [inp.shape[1] for inp in inputs]  # list of crop size of views
        unique_consecutive_count = [
            len(list(g)) for _, g in groupby(crop_sizes)
        ]  # equivalent to torch.unique_consecutive
        # Plain integer end offsets, usable as list slice bounds
        idx_crops = np.cumsum(unique_consecutive_count)

        # ============ multi-res forward passes ... ============
        start_idx = 0
        with tf.GradientTape() as tape:
            per_resolution = []
            for end_idx in idx_crops:
                concat_input = tf.stop_gradient(
                    tf.concat(inputs[start_idx:end_idx], axis=0)
                )
                # embed all views of the same resolution together
                per_resolution.append(self.encoder(concat_input))
                start_idx = end_idx
            # rows are ordered view by view, `batch_size` rows per view
            embeddings = tf.concat(per_resolution, axis=0)

            # only the prototype scores are needed for the loss
            _, prototype = self.projection(embeddings)

            # ============ swav loss ... ============
            loss = 0.0
            for crop_id in self.CROPS_FOR_ASSIGN:
                with tape.stop_recording():
                    out = prototype[batch_size * crop_id : batch_size * (crop_id + 1)]

                    # sinkhorn is used for cluster assignment
                    q = sinkhorn(out, self.num_sinkhorn_iters)

                # cluster assignment prediction: every other view predicts q
                subloss = 0.0
                for v in range(num_views):
                    if v == crop_id:
                        continue
                    # log_softmax avoids log(0) when a probability underflows
                    log_p = tf.nn.log_softmax(
                        prototype[batch_size * v : batch_size * (v + 1)]
                        / self.TEMPERATURE
                    )
                    subloss -= tf.math.reduce_mean(
                        tf.math.reduce_sum(q * log_p, axis=1)
                    )
                loss += subloss / float(num_views - 1)

            loss /= len(self.CROPS_FOR_ASSIGN)  # type: ignore

        # ============ backprop ... ============
        variables = (
            self.encoder.trainable_variables + self.projection.trainable_variables
        )
        gradients = tape.gradient(loss, variables)
        self.optimizer.apply_gradients(zip(gradients, variables))

        # Compute our own metrics
        self.loss_tracker.update_state(loss)

        # Return a dict mapping metric names to current value
        return {"loss": self.loss_tracker.result()}

## 🏃 Train !!

In [None]:
# Cosine learning-rate schedule starting at BASE_LR
lr_decayed_fn = tf.keras.optimizers.schedules.CosineDecay(
    initial_learning_rate=BASE_LR, decay_steps=DECAY_STEPS
)

with strategy.scope():
    # The optimizer is created inside the strategy scope together with the
    # model, as is done for the evaluation model, so that its variables
    # belong to the same distribution strategy
    opt = tf.keras.optimizers.experimental.SGD(learning_rate=lr_decayed_fn)
    model = SwAV()
    model.compile(optimizer=opt, run_eagerly=True)

# Training state is backed up once per epoch so an interrupted run can resume
model.fit(
    trainloader,
    epochs=NUM_TRAINING_EPOCHS,
    callbacks=[
        tf.keras.callbacks.BackupAndRestore(
            "artifacts/swav/checkpoints/", save_freq="epoch", delete_checkpoint=False
        )
    ],
)

# 👨🏻‍⚖️ Linear Evaluation
---
We use a linear evaluation protocol, i.e., we train a linear classifier on top of the frozen representations of the ResNet-50 backbone pretrained with SwAV.

In [None]:
# Download (and cache) the pretrained backbone weights
feature_backbone_urlpath = "https://github.com/ayulockin/SwAV-TF/releases/download/v0.1.0/feature_backbone_10_epochs.h5"
feature_backbone_weights = tf.keras.utils.get_file(
    fname="swav_feature_weights",
    origin=feature_backbone_urlpath,
)

## ⚙️ Dataloader for Linear Evaluation

In [None]:
tf.keras.backend.clear_session()

# Gather Flowers dataset
train_ds, validation_ds = tfds.load(
    "tf_flowers", split=["train[:85%]", "train[85%:]"], as_supervised=True
)


def _prepare_eval_image(image, label):
    """Convert a validation image to float32 and resize it for the classifier.

    The training loader emits SIZE_CROPS[0] x SIZE_CROPS[0] crops; validation
    images are resized to the same resolution so they can be batched and fed
    to the fixed-size classifier input.
    """
    image, label = scale_image_with_label(image, label)
    image = tf.image.resize(image, (SIZE_CROPS[0], SIZE_CROPS[0]))
    return image, label


eval_trainloader = (
    train_ds.shuffle(1024)
    .map(eval_augmentation, num_parallel_calls=AUTOTUNE)
    .batch(EVAL_BATCH_SIZE)
    .prefetch(AUTOTUNE)
)

eval_testloader = (
    validation_ds.map(_prepare_eval_image, num_parallel_calls=AUTOTUNE)
    .batch(EVAL_BATCH_SIZE)
    .prefetch(AUTOTUNE)
)

eval_trainloader = eval_trainloader.with_options(options)

## 🏠 Building the network

In [None]:
def get_linear_classifier(num_classes: int = 5) -> tf.keras.Model:
    """Build a linear classifier on top of the frozen pretrained backbone.

    Args:
        num_classes: Number of output classes of the softmax layer.

    Returns:
        A model taking 224x224 RGB images and returning class probabilities;
        only the final dense layer is trainable.
    """
    # input placeholder
    inputs = tf.keras.layers.Input(shape=(224, 224, 3))
    # reuse the encoder SwAV builds in its constructor instead of building a
    # second ResNet-50 and discarding the first
    feature_backbone = SwAV().encoder
    # load trained weights
    feature_backbone.load_weights(feature_backbone_weights)
    # freeze the backbone so that only the classifier layer is trained
    feature_backbone.trainable = False
    x = feature_backbone(inputs, training=False)
    outputs = tf.keras.layers.Dense(num_classes, activation="softmax")(x)
    linear_model = tf.keras.Model(inputs, outputs)

    return linear_model

## 🏃 Evaluation !!

In [None]:
with strategy.scope():
    evaluation_model = get_linear_classifier()
    # NOTE(review): this reuses the pre-training schedule `lr_decayed_fn`
    # (which starts at BASE_LR); EVAL_LR is defined in the configuration but
    # not used here -- confirm which learning rate is intended.
    eval_optimizer = tf.keras.optimizers.experimental.SGD(
        learning_rate=lr_decayed_fn
    )
    evaluation_model.compile(
        optimizer=eval_optimizer,
        loss="sparse_categorical_crossentropy",
        metrics=["acc"],
    )

evaluation_model.summary()

In [None]:
# callback for early stopping
# `fit` below receives no validation data, so `val_loss` is never logged and
# a callback monitoring it could never trigger; monitor the training loss.
# To stop on held-out data instead, pass `validation_data` to `fit` and
# monitor "val_loss".
early_stopper = tf.keras.callbacks.EarlyStopping(
    monitor="loss", patience=3, verbose=2, restore_best_weights=True
)

# train the model meant for evaluation
evaluation_model.fit(
    eval_trainloader, epochs=NUM_EVAL_EPOCHS, callbacks=[early_stopper]
)

In [None]:
# Loss and the "acc" metric of the linear classifier on the held-out split
loss, acc = evaluation_model.evaluate(x=eval_testloader)