<a href="https://colab.research.google.com/github/ayulockin/SwAV-TF/blob/master/fine_tuning/Fine_Tuning_40_Epochs.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Imports and Setups

In [None]:
# Fetch the SwAV-TF repository; its utils/ directory is appended to sys.path in the next cell.
!git clone https://github.com/ayulockin/SwAV-TF.git

Cloning into 'SwAV-TF'...
remote: Enumerating objects: 9, done.[K
remote: Counting objects:  11% (1/9)[Kremote: Counting objects:  22% (2/9)[Kremote: Counting objects:  33% (3/9)[Kremote: Counting objects:  44% (4/9)[Kremote: Counting objects:  55% (5/9)[Kremote: Counting objects:  66% (6/9)[Kremote: Counting objects:  77% (7/9)[Kremote: Counting objects:  88% (8/9)[Kremote: Counting objects: 100% (9/9)[Kremote: Counting objects: 100% (9/9), done.[K
remote: Compressing objects: 100% (9/9), done.[K
remote: Total 207 (delta 5), reused 0 (delta 0), pack-reused 198[K
Receiving objects: 100% (207/207), 15.83 MiB | 21.76 MiB/s, done.
Resolving deltas: 100% (95/95), done.


In [None]:
import sys
sys.path.append('SwAV-TF/utils')

import multicrop_dataset
import architecture

In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds

from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.models import Sequential, Model

import matplotlib.pyplot as plt 
import numpy as np
import random
import time
import os

from tqdm import tqdm
from imutils import paths

# Seed TensorFlow's and NumPy's global random number generators.
# NOTE(review): the stdlib `random` module is imported above but is not seeded here.
tf.random.set_seed(666)
np.random.seed(666)

# Turn off tensorflow_datasets' progress bars.
tfds.disable_progress_bar()

#### W&B - Experiment Tracking

In [None]:
%%capture
# Install the Weights & Biases client; the %%capture cell magic above hides pip's output.
!pip install wandb

In [None]:
import wandb
from wandb.keras import WandbCallback

wandb.login()

## Restoring model weights from GCS Bucket

In [None]:
from tensorflow.keras.utils import get_file

In [None]:
# Public GCS URLs of the pre-trained SwAV weight files
# (40 pre-training epochs, going by the file names).
feature_backbone_urlpath = "https://storage.googleapis.com/swav-tf/feature_backbone_40_epochs.h5"
prototype_urlpath = "https://storage.googleapis.com/swav-tf/projection_prototype_40_epochs.h5"

In [None]:
# Download the pre-trained SwAV weights (get_file returns the local path).
# The local file names keep the ".h5" suffix so that Keras can recognise the
# HDF5 format from the path when it is later passed to `load_weights`,
# instead of having to guess the format of an extension-less file.
feature_backbone_weights = get_file('swav_feature_weights.h5', feature_backbone_urlpath)
# NOTE(review): `prototype_weights` is not referenced again in the visible notebook.
prototype_weights = get_file('swav_prototype_projection_weights.h5', prototype_urlpath)

Downloading data from https://storage.googleapis.com/swav-tf/feature_backbone_40_epochs.h5
Downloading data from https://storage.googleapis.com/swav-tf/projection_prototype_40_epochs.h5


## Dataset gathering and preparation

We will use only **10%** of the labeled training examples.

In [None]:
# Gather Flowers dataset, split three ways from the single "train" split:
#   train_ds       - first 10%  (the labeled examples used for training)
#   extra_train_ds - next 75%   (NOTE(review): not used anywhere in the visible notebook)
#   validation_ds  - last 15%   (used for validation and final evaluation)
train_ds, extra_train_ds, validation_ds = tfds.load(
    "tf_flowers",
    split=["train[:10%]", "train[10%:85%]", "train[85%:]"],
    as_supervised=True  # the map functions below take (image, label) pairs
)

AUTO = tf.data.experimental.AUTOTUNE  # passed as num_parallel_calls and prefetch buffer size below
BATCH_SIZE = 64  # batch size shared by every pipeline in this notebook

@tf.function
def scale_resize_image(image, label):
    """Convert an image to float32 and resize it to 224x224.

    224x224 is the highest resolution used while training SwAV. The label is
    returned unchanged.
    """
    as_float = tf.image.convert_image_dtype(image, tf.float32)
    resized = tf.image.resize(as_float, (224, 224))
    return resized, label

# Training pipeline (no augmentation). Examples are shuffled before batching
# so that batch composition changes from epoch to epoch; this also matches the
# augmented `trainloader`, which shuffles with the same buffer size, so the
# two experiments differ only in the augmentation itself.
training_ds = (
    train_ds
    .shuffle(1024)
    .map(scale_resize_image, num_parallel_calls=AUTO)
    .batch(BATCH_SIZE)
    .prefetch(AUTO)
)

# Evaluation pipeline built from `validation_ds`; deliberately not shuffled.
# NOTE(review): this same pipeline is used both as validation data during
# training (early stopping) and for the final "Test Accuracy" figure.
testing_ds = (
    validation_ds
    .map(scale_resize_image, num_parallel_calls=AUTO)
    .batch(BATCH_SIZE)
    .prefetch(AUTO)
)

[1mDownloading and preparing dataset tf_flowers/3.0.0 (download: 218.21 MiB, generated: Unknown size, total: 218.21 MiB) to /root/tensorflow_datasets/tf_flowers/3.0.0...[0m


local data directory. If you'd instead prefer to read directly from our public
GCS bucket (recommended if you're running on GCP), you can instead set
data_dir=gs://tfds-data/datasets.



[1mDataset tf_flowers downloaded and prepared to /root/tensorflow_datasets/tf_flowers/3.0.0. Subsequent calls will reuse this data.[0m


## Get SwAV architecture and Build Linear Model

In [None]:
def get_linear_classifier(num_classes=5):
    """Build a linear classifier on top of the frozen, pre-trained SwAV backbone.

    Args:
        num_classes: Number of units in the softmax output layer. Defaults to
            5, the value that was previously hard-coded.

    Returns:
        A Keras `Model` taking (224, 224, 3) images and returning the softmax
        output of a single Dense layer applied to the backbone features.
    """
    # input placeholder
    inputs = Input(shape=(224, 224, 3))
    # get swav baseline model architecture
    feature_backbone = architecture.get_resnet_backbone()
    # load trained weights
    feature_backbone.load_weights(feature_backbone_weights)
    # freeze the backbone: only the Dense head below is trained
    feature_backbone.trainable = False

    x = feature_backbone(inputs, training=False)
    # The head is named explicitly because the fine-tuning step retrieves it
    # with get_layer('dense'); relying on the auto-generated name only works
    # when the Keras session was cleared right before building the model.
    outputs = Dense(num_classes, activation="softmax", name="dense")(x)
    linear_model = Model(inputs, outputs)

    return linear_model

In [None]:
# Build the linear classifier once just to inspect its layers and parameter counts.
tf.keras.backend.clear_session()  # reset Keras' global state before building
model = get_linear_classifier()
model.summary()

Model: "functional_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 224, 224, 3)]     0         
_________________________________________________________________
functional_1 (Functional)    (None, 2048)              23587712  
_________________________________________________________________
dense (Dense)                (None, 5)                 10245     
Total params: 23,597,957
Trainable params: 10,245
Non-trainable params: 23,587,712
_________________________________________________________________


## Callback

In [None]:
# Early stopping to limit overfitting: watch the validation loss, allow two
# epochs without improvement, and restore the best weights when stopping.
# This single callback instance is passed to every `fit` call below.
early_stopper = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss",
    patience=2,
    verbose=2,
    restore_best_weights=True,
)

# Training without Augmentation

#### Warm Up

In [None]:
# Build a fresh linear classifier (frozen backbone) and show its summary.
tf.keras.backend.clear_session()
model = get_linear_classifier()
model.summary()

model.compile(
    optimizer='adam',
    loss="sparse_categorical_crossentropy",
    metrics=["acc"],
)

# Start a W&B run for this warm-up stage.
wandb.init(entity='authors', project='swav-tf')

# Train the classification head on the un-augmented pipeline.
history = model.fit(
    training_ds,
    validation_data=testing_ds,
    epochs=35,
    callbacks=[WandbCallback(), early_stopper],
)

Model: "functional_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 224, 224, 3)]     0         
_________________________________________________________________
functional_1 (Functional)    (None, 2048)              23587712  
_________________________________________________________________
dense (Dense)                (None, 5)                 10245     
Total params: 23,597,957
Trainable params: 10,245
Non-trainable params: 23,587,712
_________________________________________________________________


Epoch 1/35
Epoch 2/35
Epoch 3/35
Epoch 4/35
Epoch 5/35
Epoch 6/35
Epoch 7/35
Epoch 8/35
Epoch 00008: early stopping


In [None]:
# Save the warm-up model; get_classifier() reloads this file to reuse its trained 'dense' layer.
model.save('warmup.h5')

#### Fine tune CNN

In [None]:
 def get_classifier(warmup_model_path='warmup.h5'):
    """Build the fully trainable classifier used for the fine-tuning stage.

    The SwAV backbone is re-created, loaded with its pre-trained weights and
    marked trainable. The classification head is the 'dense' layer taken from
    a previously saved warm-up model, so fine-tuning starts from the head
    weights learned during warm-up.

    Args:
        warmup_model_path: Path of the saved warm-up model to take the head
            from. Defaults to 'warmup.h5', the path that was previously
            hard-coded.

    Returns:
        A Keras `Model` taking (224, 224, 3) images and returning the output
        of the reused 'dense' layer.
    """
    # input placeholder
    inputs = Input(shape=(224, 224, 3))
    # get swav baseline model architecture
    feature_backbone = architecture.get_resnet_backbone()
    # load trained weights
    feature_backbone.load_weights(feature_backbone_weights)
    feature_backbone.trainable = True

    # load warmup model
    warmup_model = tf.keras.models.load_model(warmup_model_path)
    # get trained output layer
    last_layer = warmup_model.get_layer('dense')

    # The backbone's weights are trainable, but it is still called with
    # training=False (as in the warm-up model).
    x = feature_backbone(inputs, training=False)
    outputs = last_layer(x)
    classifier = Model(inputs, outputs)

    return classifier

In [None]:
# Build the fully trainable model (backbone + warm-up head) and show its summary.
tf.keras.backend.clear_session()
full_trainable_model = get_classifier()
full_trainable_model.summary()

# NOTE(review): 'adam' is given by name, so the optimizer runs with its
# default settings here too; confirm that is intended for full fine-tuning.
full_trainable_model.compile(
    optimizer='adam',
    loss="sparse_categorical_crossentropy",
    metrics=["acc"],
)

# Start a W&B run for the fine-tuning stage.
wandb.init(entity='authors', project='swav-tf')

# Fine-tune the whole network on the un-augmented pipeline.
history = full_trainable_model.fit(
    training_ds,
    validation_data=testing_ds,
    epochs=35,
    callbacks=[WandbCallback(), early_stopper],
)

Model: "functional_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 224, 224, 3)]     0         
_________________________________________________________________
functional_1 (Functional)    (None, 2048)              23587712  
_________________________________________________________________
dense (Dense)                (None, 5)                 10245     
Total params: 23,597,957
Trainable params: 23,544,837
Non-trainable params: 53,120
_________________________________________________________________


Epoch 1/35




Epoch 2/35
Epoch 3/35
Epoch 4/35
Epoch 5/35
Epoch 00005: early stopping


#### Evaluation

In [None]:
# Evaluate the fine-tuned model on `testing_ds` and log the accuracy to W&B
# as a percentage rounded to two decimals.
loss, acc = full_trainable_model.evaluate(testing_ds)
wandb.log({'Test Accuracy': round(acc*100, 2)})



# Training with Augmentation


#### Augmentation

In [None]:
# Configs
CROP_SIZE = 224  # side length (pixels) of the final crop produced by random_resize_crop
MIN_SCALE = 0.5  # smallest random crop side, as a fraction of the 260-pixel resized image
MAX_SCALE = 1.   # upper bound of the random crop side, same units as MIN_SCALE

# Experimental options: tf.data settings applied to `trainloader` via with_options().
# NOTE(review): these are experimental attributes - confirm they all exist in
# the TensorFlow version this notebook is run with.
options = tf.data.Options()
options.experimental_optimization.noop_elimination = True
options.experimental_optimization.map_vectorization.enabled = True
options.experimental_optimization.apply_default_optimizations = True
options.experimental_deterministic = False
options.experimental_threading.max_intra_op_parallelism = 1

In [None]:
@tf.function
def scale_image(image, label):
    """Convert the image to float32 with tf.image.convert_image_dtype; the label is unchanged."""
    scaled = tf.image.convert_image_dtype(image, tf.float32)
    return scaled, label

@tf.function
def random_apply(func, x, p):
    """Return `func(x)` when a uniform draw from [0, 1) is below `p`, otherwise `x`."""
    # The draw is created before the cast, in the same order as before.
    draw = tf.random.uniform([], minval=0, maxval=1, dtype=tf.float32)
    threshold = tf.cast(p, tf.float32)
    should_apply = tf.less(draw, threshold)
    return tf.cond(should_apply, lambda: func(x), lambda: x)
 
@tf.function
def random_resize_crop(image, label):
    """Take a random square crop of the image and resize it to CROP_SIZE.

    The image is first resized to 260x260. A crop side is then drawn
    uniformly between MIN_SCALE*260 and MAX_SCALE*260 and cast to int32; a
    square crop of that side is taken at a random position and resized to
    (CROP_SIZE, CROP_SIZE). The label is returned unchanged.
    """
    base_side = 260
    resized = tf.image.resize(image, (base_side, base_side))
    # Draw the crop side as a one-element float tensor, then take it as an int scalar.
    side = tf.random.uniform(
        shape=(1,),
        minval=MIN_SCALE*base_side,
        maxval=MAX_SCALE*base_side,
        dtype=tf.float32,
    )
    side = tf.cast(side, tf.int32)[0]
    patch = tf.image.random_crop(resized, (side, side, 3))
    return tf.image.resize(patch, (CROP_SIZE, CROP_SIZE)), label

@tf.function
def tie_together(image, label):
    """Apply the training-time augmentations to one (image, label) example.

    Steps, in order: convert the image to float32, flip it horizontally with
    probability 0.5, then take a random resized crop. The label is passed
    through unchanged.
    """
    # Scale the pixel values
    image, label = scale_image(image, label)
    # Random horizontal flip. tf.image.random_flip_left_right already flips
    # with a 1-in-2 chance, so it is called directly; wrapping it in
    # random_apply(..., p=0.5) stacked two coin flips and flipped only about
    # a quarter of the images.
    image = tf.image.random_flip_left_right(image)
    # Random resized crops
    image, label = random_resize_crop(image, label)

    return image, label

In [None]:
# Augmented training pipeline: shuffle, augment each example with
# tie_together, batch, prefetch, then attach the tf.data options defined above.
trainloader = (
    train_ds
    .shuffle(1024)
    .map(tie_together, num_parallel_calls=AUTO)
    .batch(BATCH_SIZE)
    .prefetch(AUTO)
    .with_options(options)
)

#### Warm Up

In [None]:
# Build a fresh linear classifier (frozen backbone) and show its summary.
tf.keras.backend.clear_session()
model = get_linear_classifier()
model.summary()

model.compile(
    optimizer='adam',
    loss="sparse_categorical_crossentropy",
    metrics=["acc"],
)

# Start a W&B run for the augmented warm-up stage.
wandb.init(entity='authors', project='swav-tf')

# Train the classification head on the augmented pipeline.
history = model.fit(
    trainloader,
    validation_data=testing_ds,
    epochs=35,
    callbacks=[WandbCallback(), early_stopper],
)

Model: "functional_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 224, 224, 3)]     0         
_________________________________________________________________
functional_1 (Functional)    (None, 2048)              23587712  
_________________________________________________________________
dense (Dense)                (None, 5)                 10245     
Total params: 23,597,957
Trainable params: 10,245
Non-trainable params: 23,587,712
_________________________________________________________________


Epoch 1/35
Epoch 2/35
Epoch 3/35
Epoch 4/35
Epoch 5/35
Epoch 6/35
Epoch 7/35
Epoch 00007: early stopping


In [None]:
# Save the augmented warm-up model; the get_classifier() defined below reloads it for its 'dense' layer.
model.save('warmup_augmentation.h5')

#### Fine tune CNN

In [None]:
 def get_classifier(warmup_model_path='warmup_augmentation.h5'):
    """Build the fully trainable classifier used for the fine-tuning stage.

    The SwAV backbone is re-created, loaded with its pre-trained weights and
    marked trainable. The classification head is the 'dense' layer taken from
    a previously saved warm-up model, so fine-tuning starts from the head
    weights learned during warm-up.

    Args:
        warmup_model_path: Path of the saved warm-up model to take the head
            from. Defaults to 'warmup_augmentation.h5', the path that was
            previously hard-coded.

    Returns:
        A Keras `Model` taking (224, 224, 3) images and returning the output
        of the reused 'dense' layer.
    """
    # input placeholder
    inputs = Input(shape=(224, 224, 3))
    # get swav baseline model architecture
    feature_backbone = architecture.get_resnet_backbone()
    # load trained weights
    feature_backbone.load_weights(feature_backbone_weights)
    feature_backbone.trainable = True

    # load warmup model
    warmup_model = tf.keras.models.load_model(warmup_model_path)
    # get trained output layer
    last_layer = warmup_model.get_layer('dense')

    # The backbone's weights are trainable, but it is still called with
    # training=False (as in the warm-up model).
    x = feature_backbone(inputs, training=False)
    outputs = last_layer(x)
    classifier = Model(inputs, outputs)

    return classifier

In [None]:
# Build the fully trainable model (backbone + warm-up head) and show its summary.
tf.keras.backend.clear_session()
full_trainable_model = get_classifier()
full_trainable_model.summary()

# NOTE(review): 'adam' is given by name, so the optimizer runs with its
# default settings here too; confirm that is intended for full fine-tuning.
full_trainable_model.compile(
    optimizer='adam',
    loss="sparse_categorical_crossentropy",
    metrics=["acc"],
)

# Start a W&B run for the augmented fine-tuning stage.
wandb.init(entity='authors', project='swav-tf')

# Fine-tune the whole network on the augmented pipeline.
history = full_trainable_model.fit(
    trainloader,
    validation_data=testing_ds,
    epochs=35,
    callbacks=[WandbCallback(), early_stopper],
)

Model: "functional_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 224, 224, 3)]     0         
_________________________________________________________________
functional_1 (Functional)    (None, 2048)              23587712  
_________________________________________________________________
dense (Dense)                (None, 5)                 10245     
Total params: 23,597,957
Trainable params: 23,544,837
Non-trainable params: 53,120
_________________________________________________________________


Epoch 1/35




Epoch 2/35
Epoch 3/35
Epoch 4/35
Epoch 5/35
Epoch 00005: early stopping


#### Evaluation

In [None]:
# Evaluate the fine-tuned model on `testing_ds` and log the accuracy to W&B
# as a percentage rounded to two decimals.
loss, acc = full_trainable_model.evaluate(testing_ds)
wandb.log({'Test Accuracy': round(acc*100, 2)})

