<a href="https://colab.research.google.com/github/ayulockin/SwAV-TF/blob/master/Linear_Evaluation_on_Flower_Dataset.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!git clone https://username:password@github.com/ayulockin/SwAV-TF.git

Cloning into 'SwAV-TF'...
remote: Enumerating objects: 110, done.[K
remote: Counting objects: 100% (110/110), done.[K
remote: Compressing objects: 100% (102/102), done.[K
remote: Total 110 (delta 48), reused 22 (delta 8), pack-reused 0[K
Receiving objects: 100% (110/110), 10.81 MiB | 795.00 KiB/s, done.
Resolving deltas: 100% (48/48), done.


In [2]:
import sys
sys.path.append('SwAV-TF/utils')

import architecture

In [3]:
import tensorflow as tf
import tensorflow_datasets as tfds

from tensorflow.keras.layers import Dense
from tensorflow.keras.models import Sequential

import matplotlib.pyplot as plt
import numpy as np
import random
import time
import os

from tqdm import tqdm
from imutils import paths

tf.random.set_seed(666)
np.random.seed(666)

tfds.disable_progress_bar()

## Restoring model weights from GCS Bucket

In [4]:
from tensorflow.keras.utils import get_file

In [5]:
feature_backbone_urlpath = "https://storage.googleapis.com/swav-tf/feature_backbone_20_epochs.h5"
prototype_urlpath = "https://storage.googleapis.com/swav-tf/projection_prototype_20_epochs.h5"

In [6]:
feature_backbone_weights = get_file('swav_feature_weights', feature_backbone_urlpath)
prototype_weights = get_file('swav_prototype_projection_weights', prototype_urlpath)

Downloading data from https://storage.googleapis.com/swav-tf/feature_backbone_20_epochs.h5
Downloading data from https://storage.googleapis.com/swav-tf/projection_prototype_20_epochs.h5


## Dataset gathering and preparation

In [7]:
# Gather Flowers dataset
train_ds, _, validation_ds = tfds.load(
    "tf_flowers",
    split=["train[:10%]", "train[10%:85%]", "train[85%:]"], # notice 10% of the images was used for training our linear evaluater. 
    data_dir='tf_dataset',
    as_supervised=True
)

AUTO = tf.data.experimental.AUTOTUNE
BATCH_SIZE = 32

@tf.function
def scale_resize_image(image, label):
    image = tf.image.resize(image, (224, 224)) # Resizing to high resolution used while training swav
    image = tf.image.convert_image_dtype(image, tf.float32)
    return (image, label)

train_ds = (
    train_ds
    .map(scale_resize_image, num_parallel_calls=AUTO)
    .batch(32)
    .prefetch(AUTO)
)

test_ds = (
    validation_ds
    .map(scale_resize_image, num_parallel_calls=AUTO)
    .batch(32)
    .prefetch(AUTO)
)

[1mDownloading and preparing dataset tf_flowers/3.0.0 (download: 218.21 MiB, generated: Unknown size, total: 218.21 MiB) to tf_dataset/tf_flowers/3.0.0...[0m


local data directory. If you'd instead prefer to read directly from our public
GCS bucket (recommended if you're running on GCP), you can instead set
data_dir=gs://tfds-data/datasets.



[1mDataset tf_flowers downloaded and prepared to tf_dataset/tf_flowers/3.0.0. Subsequent calls will reuse this data.[0m


## Get SwAV architecture and Build Linear Model

In [8]:
feature_backbone = architecture.get_resnet_backbone()
feature_backbone.summary()

Model: "functional_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_2 (InputLayer)         [(None, None, None, 3)]   0         
_________________________________________________________________
resnet50 (Functional)        (None, None, None, 2048)  23587712  
_________________________________________________________________
global_average_pooling2d (Gl (None, 2048)              0         
Total params: 23,587,712
Trainable params: 23,534,592
Non-trainable params: 53,120
_________________________________________________________________


In [9]:
projection_prototype = architecture.get_projection_prototype(15)
projection_prototype.summary()

Model: "functional_3"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_3 (InputLayer)            [(None, 2048)]       0                                            
__________________________________________________________________________________________________
dense (Dense)                   (None, 2048)         4196352     input_3[0][0]                    
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 2048)         8192        dense[0][0]                      
__________________________________________________________________________________________________
activation (Activation)         (None, 2048)         0           batch_normalization[0][0]        
_______________________________________________________________________________________

#### Load trained weights

In [10]:
feature_backbone.load_weights(feature_backbone_weights)
projection_prototype.load_weights(prototype_weights)

#### Linear Evaluater

In [11]:
def get_linear_model(features):
    linear_model = Sequential([Dense(5, input_shape=(features, ), activation="softmax")])
    return linear_model

## Evaluation

In [19]:
def get_image_representation(trainloader):
    # get embedding from feature backbone model
    embeddings = []
    labels = []
    for image, label in trainloader:
        labels.extend(label)

        embedding = feature_backbone.predict(image)
        embeddings.extend(embedding.tolist())
    embeddings = np.array(embeddings)

    # get projection from trained projection head
    projections, prototypes = projection_prototype(embeddings)

    return np.array(embeddings), np.array(projections), np.array(prototypes), np.array(labels)

In [21]:
train_embeddings, train_projections, _, train_labels = get_image_representation(train_ds)
test_embeddings, test_projections, _, test_labels = get_image_representation(test_ds)

print(train_embeddings.shape, train_projections.shape, train_labels.shape)
print(test_embeddings.shape, test_projections.shape, test_labels.shape)

(367, 2048) (367, 128) (367,)
(550, 2048) (550, 128) (550,)


In [22]:
# Early Stopping to prevent overfitting
early_stopper = tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=5, verbose=2, restore_best_weights=True)

In [26]:
tf.keras.backend.clear_session()
linear_model = get_linear_model(128)
linear_model.summary()

linear_model.compile(loss="sparse_categorical_crossentropy", metrics=["acc"],
                     optimizer="adam")

history = linear_model.fit(train_projections, train_labels,
                 validation_data=(test_projections, test_labels),
                 batch_size=32,
                 epochs=35,
                 callbacks=[early_stopper])

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense (Dense)                (None, 5)                 645       
Total params: 645
Trainable params: 645
Non-trainable params: 0
_________________________________________________________________
Epoch 1/35
Epoch 2/35
Epoch 3/35
Epoch 4/35
Epoch 5/35
Epoch 6/35
Epoch 7/35
 1/12 [=>............................] - ETA: 0s - loss: 1.5929 - acc: 0.2812Restoring model weights from the end of the best epoch.
Epoch 00007: early stopping
