<a href="https://colab.research.google.com/github/ayulockin/SwAV-TF/blob/master/linear_evaluation/Linear_Evaluation_40_Epochs.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Imports and Setup

In [None]:
# Fetch the SwAV-TF repository into the working directory (cloned as ./SwAV-TF).
!git clone https://github.com/ayulockin/SwAV-TF.git

Cloning into 'SwAV-TF'...
remote: Enumerating objects: 198, done.[K
Receiving objects:   0% (1/198)   Receiving objects:   1% (2/198)   Receiving objects:   2% (4/198)   Receiving objects:   3% (6/198)   Receiving objects:   4% (8/198)   Receiving objects:   5% (10/198)   Receiving objects:   6% (12/198)   Receiving objects:   7% (14/198)   Receiving objects:   8% (16/198)   Receiving objects:   9% (18/198)   Receiving objects:  10% (20/198)   Receiving objects:  11% (22/198)   Receiving objects:  12% (24/198)   Receiving objects:  13% (26/198)   Receiving objects:  14% (28/198)   Receiving objects:  15% (30/198)   Receiving objects:  16% (32/198)   Receiving objects:  17% (34/198)   Receiving objects:  18% (36/198)   Receiving objects:  19% (38/198)   Receiving objects:  20% (40/198)   Receiving objects:  21% (42/198)   Receiving objects:  22% (44/198)   Receiving objects:  23% (46/198)   Receiving objects:  24% (48/198)   Receiving objects:  25% (50/198)   

In [None]:
import sys
# Put SwAV-TF/utils on the module search path so its modules import by bare name.
sys.path.append('SwAV-TF/utils')

# NOTE(review): presumably both modules live in SwAV-TF/utils — confirm against the repo.
import multicrop_dataset
import architecture

In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds

from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.models import Sequential, Model

import matplotlib.pyplot as plt
import numpy as np
import random
import time
import os

from tqdm import tqdm
from imutils import paths

# Seed the Python, NumPy and TensorFlow random number generators with one
# shared value. The stdlib `random` module is imported in this cell, so it is
# seeded together with the other two.
SEED = 666
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

# Turn off the tensorflow_datasets progress bars.
tfds.disable_progress_bar()

#### W&B - Experiment Tracking

In [None]:
%%capture
!pip install wandb

In [None]:
import wandb
from wandb.keras import WandbCallback

# Authenticate this session with Weights & Biases; the recorded output shows it
# prompting for an API key when none is stored.
wandb.login()

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: [32m[41mERROR[0m Not authenticated.  Copy a key from https://app.wandb.ai/authorize


API Key: ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

#### Restoring model weights from GCS Bucket - Model trained for 40 epochs

In [None]:
from tensorflow.keras.utils import get_file

In [None]:
# Download locations of the pretrained SwAV weight files
# (the file names indicate a 40-epoch pretraining run).
feature_backbone_urlpath = "https://storage.googleapis.com/swav-tf/feature_backbone_40_epochs.h5"
prototype_urlpath = "https://storage.googleapis.com/swav-tf/projection_prototype_40_epochs.h5"

In [None]:
# Download both weight files with Keras' get_file; each variable receives the
# value get_file returns (used later as the argument to load_weights).
feature_backbone_weights = get_file(
    fname='swav_feature_weights',
    origin=feature_backbone_urlpath,
)
prototype_weights = get_file(
    fname='swav_prototype_projection_weights',
    origin=prototype_urlpath,
)

Downloading data from https://storage.googleapis.com/swav-tf/feature_backbone_40_epochs.h5
Downloading data from https://storage.googleapis.com/swav-tf/projection_prototype_40_epochs.h5


#### Dataset gathering and preparation

In [None]:
# Flowers dataset: the first 85% of the "train" split is used for training,
# the remaining 15% is kept aside as validation_ds.
train_ds, validation_ds = tfds.load(
    name="tf_flowers",
    as_supervised=True,
    split=["train[:85%]", "train[85%:]"],
)

# Settings shared by every input pipeline below.
BATCH_SIZE = 256
AUTO = tf.data.experimental.AUTOTUNE

@tf.function
def scale_resize_image(image, label):
    """Convert an image to float32 and resize it to 224x224.

    224 is the highest resolution used while training SwAV. The label is
    passed through unchanged.
    """
    as_float = tf.image.convert_image_dtype(image, dtype=tf.float32)
    resized = tf.image.resize(as_float, size=(224, 224))
    return resized, label

def _resize_and_batch(dataset):
    """Map scale_resize_image over `dataset`, then batch and prefetch it."""
    resized = dataset.map(scale_resize_image, num_parallel_calls=AUTO)
    return resized.batch(BATCH_SIZE).prefetch(AUTO)


# Both splits go through the same un-augmented preprocessing.
training_ds = _resize_and_batch(train_ds)
testing_ds = _resize_and_batch(validation_ds)

local data directory. If you'd instead prefer to read directly from our public
GCS bucket (recommended if you're running on GCP), you can instead set
data_dir=gs://tfds-data/datasets.



[1mDownloading and preparing dataset tf_flowers/3.0.0 (download: 218.21 MiB, generated: Unknown size, total: 218.21 MiB) to /root/tensorflow_datasets/tf_flowers/3.0.0...[0m
[1mDataset tf_flowers downloaded and prepared to /root/tensorflow_datasets/tf_flowers/3.0.0. Subsequent calls will reuse this data.[0m


#### Get SwAV architecture and Build Linear Model

In [None]:
def get_linear_classifier(alpha=1e-6, num_classes=5):
    """Build a linear classifier on top of the frozen SwAV feature backbone.

    Args:
        alpha: L2 regularization factor applied to the kernel of the Dense head.
        num_classes: Number of softmax output units. Defaults to 5, the value
            that used to be hard-coded.

    Returns:
        An uncompiled Keras ``Model`` that takes 224x224x3 inputs and outputs
        softmax class probabilities.

    Note:
        Reads the notebook-level ``feature_backbone_weights`` variable, so the
        weight-download cell must have been run first.
    """
    # input placeholder
    inputs = Input(shape=(224, 224, 3))
    # get swav baseline model architecture
    feature_backbone = architecture.get_resnet_backbone()
    # load trained weights
    feature_backbone.load_weights(feature_backbone_weights)
    # freeze the backbone so that only the Dense head is trainable
    feature_backbone.trainable = False

    # call the backbone with training=False
    x = feature_backbone(inputs, training=False)
    outputs = Dense(num_classes, activation="softmax",
                    kernel_regularizer=tf.keras.regularizers.L2(alpha))(x)
    linear_model = Model(inputs, outputs)

    return linear_model

In [None]:
# Build the classifier once to inspect its layer summary; the training cells
# below build their own fresh copy.
tf.keras.backend.clear_session()
model = get_linear_classifier()
model.summary()

Model: "functional_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 224, 224, 3)]     0         
_________________________________________________________________
functional_1 (Functional)    (None, 2048)              23587712  
_________________________________________________________________
dense (Dense)                (None, 5)                 10245     
Total params: 23,597,957
Trainable params: 10,245
Non-trainable params: 23,587,712
_________________________________________________________________


#### Callback

In [None]:
# Early stopping to prevent overfitting: monitor val_loss with a patience of
# 2 epochs and restore the best weights when training stops.
early_stopper = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss",
    patience=2,
    restore_best_weights=True,
    verbose=2,
)

# Training Linear Classifier - Without Augmentation

#### Training

In [None]:
# Build a fresh classifier and print its summary
tf.keras.backend.clear_session()
model = get_linear_classifier(alpha=1e-6)
model.summary()

model.compile(
    optimizer="adam",
    loss="sparse_categorical_crossentropy",
    metrics=["acc"],
)

# Open a W&B run for this experiment
wandb.init(entity='authors', project='swav-tf')

# Fit on the un-augmented pipeline; 100 epochs is only an upper bound because
# the early-stopping callback may end training sooner.
history = model.fit(
    training_ds,
    validation_data=testing_ds,
    epochs=100,
    callbacks=[WandbCallback(), early_stopper],
)

Model: "functional_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 224, 224, 3)]     0         
_________________________________________________________________
functional_1 (Functional)    (None, 2048)              23587712  
_________________________________________________________________
dense (Dense)                (None, 5)                 10245     
Total params: 23,597,957
Trainable params: 10,245
Non-trainable params: 23,587,712
_________________________________________________________________


Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
Epoch 43/100
Epoch 44/100
Epoch 45/100
Epoch 46/100
Epoch 47/100
Epoch 48/100
Epoch 49/100
Epoch 50/100
Epoch 51/100
Epoch 52/100
Epoch 53/100
Epoch 54/100
Epoch 55/100
Epoch 56/100
Epoch 57/100
Epoch 58/100
Epoch 59/100
Epoch 60/100
Epoch 61/100
Epoch 62/100
Epoch 63/100
Epoch 64/100
Epoch 65/100
Epoch 66/100
Epoch 67/100
Epoch 68/100
Epoch 69/100
Epoch 70/100
Epoch 71/100
Epoch 72/100
Epoch 73/100
Epoch 74/100
Epoch 75/100
Epoch 76/100
Epoch 77/100
Epoch 78

#### Evaluation

In [None]:
# Score the trained classifier on testing_ds and log the accuracy to W&B as a
# percentage rounded to two decimals.
# NOTE(review): testing_ds is also the validation_data that drives early
# stopping, so this figure does not come from a separately held-out set.
loss, acc = model.evaluate(testing_ds)
wandb.log({'Test Accuracy': round(acc*100, 2)})



# Training Linear Classifier with Augmentation

#### Augmentation

In [None]:
# Configs
CROP_SIZE = 224  # side length the random crop is resized to
MIN_SCALE = 0.5  # lower bound of the crop side, as a fraction of the 260px resized image
MAX_SCALE = 1.   # upper bound of the crop side, as a fraction of the 260px resized image

# Experimental options
# NOTE(review): these experimental_* attributes are TensorFlow-version specific —
# confirm they exist in the TensorFlow release this notebook runs on.
options = tf.data.Options()
options.experimental_optimization.noop_elimination = True
options.experimental_optimization.map_vectorization.enabled = True
options.experimental_optimization.apply_default_optimizations = True
options.experimental_deterministic = False
options.experimental_threading.max_intra_op_parallelism = 1

In [None]:
@tf.function
def scale_image(image, label):
    """Convert the image to float32 with convert_image_dtype; the label is returned as-is."""
    return tf.image.convert_image_dtype(image, tf.float32), label

@tf.function
def random_apply(func, x, p):
    """Return func(x) with probability p, otherwise return x unchanged."""
    # One uniform draw from [0, 1) decides which branch is taken.
    draw = tf.random.uniform([], minval=0, maxval=1, dtype=tf.float32)
    should_apply = tf.less(draw, tf.cast(p, tf.float32))
    return tf.cond(should_apply, lambda: func(x), lambda: x)
 
@tf.function
def random_resize_crop(image, label):
    """Take a random square crop at a random scale and resize it to CROP_SIZE.

    The image is first resized to 260x260. A crop side is then drawn uniformly
    between MIN_SCALE*260 and MAX_SCALE*260 and cast to an integer, a square of
    that side is cropped at a random position, and the crop is resized to
    (CROP_SIZE, CROP_SIZE). The label is returned unchanged.
    """
    base_side = 260
    # Always resize to the fixed working resolution first
    resized = tf.image.resize(image, (base_side, base_side))
    # Draw the crop side for the configured min and max scale
    side = tf.random.uniform(
        shape=(1,),
        minval=MIN_SCALE * base_side,
        maxval=MAX_SCALE * base_side,
        dtype=tf.float32,
    )
    side = tf.cast(side, tf.int32)[0]
    # Cut the square out of the image and resize it to the target crop size
    patch = tf.image.random_crop(resized, (side, side, 3))
    return tf.image.resize(patch, (CROP_SIZE, CROP_SIZE)), label

@tf.function
def tie_together(image, label):
    """Apply the training augmentation to one (image, label) pair.

    Steps: convert the image to float32, flip it horizontally with
    probability 0.5, then take a random resized crop. The label is returned
    unchanged.
    """
    # Scale the pixel values
    image, label = scale_image(image, label)
    # Random horizontal flip. random_apply already supplies the coin toss, so
    # the deterministic flip is passed in; wrapping
    # tf.image.random_flip_left_right would toss a second coin and flip only
    # about 25% of the images instead of the intended 50%.
    image = random_apply(tf.image.flip_left_right, image, p=0.5)
    # Random resized crops
    image, label = random_resize_crop(image, label)

    return image, label

In [None]:
# Augmented training pipeline: shuffle, augment each example, batch, prefetch,
# then attach the tf.data options defined above.
trainloader = (
    train_ds
    .shuffle(1024)
    .map(tie_together, num_parallel_calls=AUTO)
    .batch(BATCH_SIZE)
    .prefetch(AUTO)
    .with_options(options)
)

#### Training

In [None]:
# get model and compile
tf.keras.backend.clear_session()
model = get_linear_classifier(alpha=1e-6)
model.summary()

model.compile(loss="sparse_categorical_crossentropy", metrics=["acc"],
                     optimizer="adam")

# initialize wandb run
wandb.init(entity='authors', project='swav-tf')

# train on the augmented pipeline (trainloader). Fitting on training_ds here
# would repeat the un-augmented experiment and leave the augmentation unused.
# Validation still runs on the un-augmented testing_ds.
history = model.fit(trainloader,
                 validation_data=(testing_ds),
                 epochs=100,
                 callbacks=[WandbCallback(),
                            early_stopper])

Model: "functional_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 224, 224, 3)]     0         
_________________________________________________________________
functional_1 (Functional)    (None, 2048)              23587712  
_________________________________________________________________
dense (Dense)                (None, 5)                 10245     
Total params: 23,597,957
Trainable params: 10,245
Non-trainable params: 23,587,712
_________________________________________________________________


Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
Epoch 43/100
Epoch 44/100
Epoch 45/100
Epoch 46/100
Epoch 47/100
Epoch 48/100
Epoch 49/100
Epoch 50/100
Epoch 51/100
Epoch 52/100
Epoch 53/100
Epoch 54/100
Epoch 55/100
Epoch 56/100
Epoch 57/100
Epoch 58/100
Epoch 59/100
Epoch 60/100
Epoch 61/100
Epoch 62/100
Epoch 63/100
Epoch 64/100
Epoch 65/100
Epoch 66/100
Epoch 67/100
Epoch 68/100
Epoch 69/100
Epoch 70/100
Epoch 71/100
Epoch 72/100
Epoch 73/100
Epoch 74/100
Epoch 75/100
Epoch 76/100
Epoch 77/100
Epoch 78

#### Evaluation

In [None]:
# Score the trained classifier on testing_ds and log the accuracy to W&B as a
# percentage rounded to two decimals.
# NOTE(review): testing_ds is also the validation_data that drives early
# stopping, so this figure does not come from a separately held-out set.
loss, acc = model.evaluate(testing_ds)
wandb.log({'Test Accuracy': round(acc*100, 2)})

