### Model 
- Training
    backbone (CLIP ViT-L/14) + Dropout + Dense(units = 256) + ArcFace + Softmax (classes = 17691)
- Inference
    backbone (CLIP ViT-L/14) + Dropout + Dense(units = 256) + AdaptiveAveragePooling1D(output_size = 64) + L2 normalization

In [None]:
from transformers import CLIPProcessor, TFCLIPVisionModel, CLIPFeatureExtractor

import re
import os 
import glob
import numpy as np
import pandas as pd 
import matplotlib.pyplot as plt
import random 
import math 
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers 
from sklearn import metrics 
from sklearn.model_selection import KFold, train_test_split, StratifiedKFold
from tensorflow.keras import backend as K
import tensorflow_addons as tfa
from tqdm.auto import tqdm
from sklearn.preprocessing import normalize

import pickle
import json
import tensorflow_hub as tfhub
from datetime import datetime
import gc
import requests
from mpl_toolkits import axes_grid1

#### Device

In [None]:
# Choose a distribution strategy: a TPUStrategy when a TPU cluster resolver can be
# created, otherwise TensorFlow's default strategy.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print("Running on TPU", tpu.master())
except ValueError:
    # Presumably raised when no TPU is available in this runtime.
    tpu = None

if not tpu:
    # Default TF works on CPU and GPU
    strategy = tf.distribute.get_strategy()
else:
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.TPUStrategy(tpu)

# Shared tf.data parallelism/prefetch setting used by the input pipelines.
AUTO = tf.data.experimental.AUTOTUNE
print("Replicas: ", strategy.num_replicas_in_sync)

In [None]:
# Enable mixed precision (float16 compute, float32 variables) for single-replica GPU
# runs only. Multi-replica (e.g. TPU) runs and CPU-only runs keep the default float32
# policy, since float16 compute gives no speed-up on a CPU.
# `set_global_policy` replaces the removed `mixed_precision.experimental` API.
if strategy.num_replicas_in_sync == 1 and tf.config.list_physical_devices("GPU"):
    tf.keras.mixed_precision.set_global_policy("mixed_float16")
    


In [None]:
class config:
    """Notebook-wide settings, read (and extended) through class attributes."""

    # --- run identification ---
    VERSION: int = 3
    SUBV: str = "Clip_ViT_Train"
    SEED: int = 42

    # --- resuming from previously trained weights ---
    RESUME: bool = True
    RESUME_EPOCH: int = 0  # added to the epoch index by the LR schedule when resuming
    RESUME_WEIGHT: str = "../input/guei-v6-clip-vit-large-arcface-train-projection/clip-vit-large-patch14_224pix-emb256_arcface_entire.h5"

    # --- backbone ---
    model_type: str = "clip-vit-large-patch14"
    EFF_SIZE: int = 0    # used by the "effnetv1" model-name branch
    EFF2_TYPE: str = ""  # used by the "effnetv2" model-name branch
    IMAGE_SIZE: int = 224

    # --- projection layer / ArcFace head ---
    N_CLASSES: int = 17691
    EMB_DIM: int = 256

    # --- training ---
    TRAIN: bool = False
    BATCH_SIZE: int = 200 * strategy.num_replicas_in_sync  # 200 examples per replica
    EPOCHS: int = 100
    LR: float = 0.001
    save_dir: str = "./"

    DEBUG: bool = False
    
# Function to seed everything 
def seed_everything(seed):
    """Seed Python's `random`, NumPy and TensorFlow, and export PYTHONHASHSEED.

    NOTE: PYTHONHASHSEED set at runtime only affects interpreters started afterwards
    (e.g. worker subprocesses); the current process's hash seed is fixed at startup.
    """
    seed_text = str(seed)
    os.environ['PYTHONHASHSEED'] = seed_text
    for seeder in (random.seed, np.random.seed, tf.random.set_seed):
        seeder(seed)

# Seed Python / NumPy / TensorFlow from config.SEED so shard sampling and random
# augmentations are reproducible between runs.
seed_everything(config.SEED)

# model name: EfficientNet variants encode their size/type; every other backbone
# (swin, conv*, CLIP, ...) uses config.model_type unchanged.
if config.model_type == "effnetv1":
    MODEL_NAME = f"effnet_v1_b{config.EFF_SIZE}"
elif config.model_type == "effnetv2":
    MODEL_NAME = f"effnetv2_{config.EFF2_TYPE}"
else:
    MODEL_NAME = config.model_type

config.MODEL_NAME = MODEL_NAME
print(MODEL_NAME)


### TFRecords

In [None]:
from kaggle_datasets import KaggleDatasets

In [None]:
train_shard_suffix = '*-train-*.tfrec'

ROOT_DIRS = [
    "guie-glr2021mini-tfrecords-label-10691-17690",
    "guie-imagenet1k-mini1-tfrecords-label-0-999",
    "guie-products10k-tfrecords-label-1000-10690",
]

# Fraction of each dataset's shards used for training; the rest is validation.
TRAIN_SHARD_FRACTION = 0.9
# A dedicated RNG seeded from config.SEED makes the shard split reproducible, so a
# resumed run does not move earlier training shards into the validation set.
split_rng = random.Random(config.SEED)
kaggle_datasets = KaggleDatasets()

train_set_path = []
valid_set_path = []
for ROOT_DIR in ROOT_DIRS:
    GCS_DS_PATH = kaggle_datasets.get_gcs_path(ROOT_DIR)

    print(f"\"{ROOT_DIR}\" : \"{GCS_DS_PATH}\", ")
    files = sorted(tf.io.gfile.glob(GCS_DS_PATH + f"/{train_shard_suffix}"))
    # Split at shard granularity: sample the training shards, the remainder validates.
    train_files = split_rng.sample(files, int(len(files) * TRAIN_SHARD_FRACTION))
    train_lookup = set(train_files)  # O(1) membership instead of scanning a list
    train_set_path += train_files
    valid_set_path += [file for file in files if file not in train_lookup]
    print(ROOT_DIR, ", number of tfrecords", len(files))

train_set_path = sorted(train_set_path)
valid_set_path = sorted(valid_set_path)
print("# of tfrecords for training: ", len(train_set_path))
print("# of tfrecords for validation: ", len(valid_set_path))

if config.DEBUG:
    # Never ask for more shards than exist.
    train_set_path = split_rng.sample(train_set_path, min(4, len(train_set_path)))
    print("Debug: reduce training data. num = ", len(train_set_path))

    valid_set_path = train_set_path
    print("debug: reduce validation data. num = ", len(valid_set_path))

In [None]:
def get_num_of_images(file):
    """Return the image count encoded in a TFRecord shard's file name.

    The count is the last '-'-separated token of the base name, with any extension
    removed, e.g. ``gs://bucket/dir/glr-train-03-1234.tfrec`` -> 1234. Dots elsewhere
    in the name (directories or prefixes) do not affect the result.

    Raises:
        ValueError: if that token is not an integer.
    """
    file_name = file.rsplit("/", 1)[-1]
    count_token = file_name.rsplit("-", 1)[-1]
    return int(count_token.split(".", 1)[0])


# Alias so callers using the singular spelling keep working.
get_num_of_image = get_num_of_images

# Total number of images per split, summed from the per-shard counts encoded in the
# shard file names.
train_set_len = sum(get_num_of_image(file) for file in train_set_path)
valid_set_len = sum(get_num_of_image(file) for file in valid_set_path)

train_set_len, valid_set_len

### Dataset Pipeline 

In [None]:
def deserialization_fn(serialized_example):
    """Parse one serialized TFRecord example into an ``(image, label)`` pair.

    The JPEG under ``image/encoded`` is decoded with 3 channels, resized to
    ``config.IMAGE_SIZE`` x ``config.IMAGE_SIZE`` and divided by 255 as float32; the
    label comes from ``image/class/label`` as int64.
    """
    feature_spec = {
        'image/encoded': tf.io.FixedLenFeature([], tf.string),
        'image/class/label': tf.io.FixedLenFeature([], tf.int64),
    }
    features = tf.io.parse_single_example(serialized_example, features=feature_spec)

    side = config.IMAGE_SIZE
    decoded = tf.image.decode_jpeg(features['image/encoded'], channels=3)
    resized = tf.image.resize(decoded, size=(side, side))
    label = tf.cast(features['image/class/label'], tf.int64)
    return tf.cast(resized, tf.float32) / 255.0, label

In [None]:
def arcface_format(image, label_group):
    """Pack a sample into the two-input layout of the ArcFace training model.

    The label is used twice: as the second model input (``inp2``, needed by the margin
    layer) and as the training target.
    """
    model_inputs = dict(inp1=image, inp2=label_group)
    return model_inputs, label_group

def rescale_image(image, label_group):
    """Cast the image to float32 and multiply it by 255; the label passes through."""
    return tf.cast(image, tf.float32) * 255.0, label_group

# Data augmentation 
def data_augment(image, label_group):
    """Apply a random horizontal flip and photometric jitter to one image.

    Assumes ``image`` is float in [0, 1] (as produced by ``deserialization_fn``).
    The result is clipped back to [0, 1], because contrast and brightness jitter can
    push values out of range. Unclipped, those values would become pixels outside
    [0, 255] after rescaling.
    """
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_hue(image, 0.01)
    image = tf.image.random_saturation(image, 0.70, 1.30)
    image = tf.image.random_contrast(image, 0.80, 1.20)
    image = tf.image.random_brightness(image, 0.10)
    image = tf.clip_by_value(image, 0.0, 1.0)
    return image, label_group

# Dataset to obtain backbone's inference 
def get_backbone_inference_dataset(tfrecord_paths, cache = False,
                                  repeat = False, shuffle = False, augment = False):
    """Build a batched ``tf.data`` pipeline over the given TFRecord shards.

    Elements are ``({'inp1': image, 'inp2': label}, label)`` batched by
    ``config.BATCH_SIZE``, with images as float32 multiplied back up by 255.

    Args:
        tfrecord_paths: list of TFRecord file paths.
        cache: cache decoded examples (before augmentation) after the first pass.
        repeat: repeat the dataset indefinitely.
        shuffle: shuffle the order of the shard files (not individual examples).
        augment: apply ``data_augment`` to every image.

    Returns:
        A batched, prefetched ``tf.data.Dataset``.
    """
    paths = list(tfrecord_paths)
    dataset = tf.data.Dataset.from_tensor_slices(paths)
    if shuffle:
        # The buffer only holds file paths, so covering every path gives a full uniform
        # shuffle of shard order; max(..., 1) keeps the buffer size valid.
        dataset = dataset.shuffle(max(len(paths), 1))
    dataset = dataset.flat_map(tf.data.TFRecordDataset)
    dataset = dataset.map(deserialization_fn, num_parallel_calls = AUTO)
    if cache:
        # Cache before augmentation so augmentations are re-drawn each epoch.
        dataset = dataset.cache()
    if augment:
        dataset = dataset.map(data_augment, num_parallel_calls = AUTO)

    # Rescaling, model-input formatting, repeat, batching and prefetching apply to
    # every pipeline, augmented or not.
    dataset = dataset.map(rescale_image, num_parallel_calls = AUTO)
    dataset = dataset.map(arcface_format, num_parallel_calls = AUTO)
    if repeat:
        dataset = dataset.repeat()
    dataset = dataset.batch(config.BATCH_SIZE)
    dataset = dataset.prefetch(AUTO)
    return dataset

### Visualize TFRecord images

In [None]:
backbone_infer_dataset_encode = get_backbone_inference_dataset(train_set_path, shuffle = True, augment = True)

num_cols = 3
num_rows = 5
# Re-batch so a single batch holds exactly one grid's worth of images.
backbone_infer_dataset_encode = backbone_infer_dataset_encode.unbatch().batch(num_cols * num_rows)
x, y = next(iter(backbone_infer_dataset_encode))
print(x['inp1'].shape)

fig = plt.figure(figsize=(15,15))
# ImageGrid expects nrows_ncols as (rows, columns).
grid = axes_grid1.ImageGrid(fig, 111, nrows_ncols = (num_rows, num_cols), axes_pad=0.1)

for i, ax in enumerate(grid):
    # Pipeline images are float32 scaled by 255; imshow wants floats in [0, 1].
    ax.imshow(np.clip(x['inp1'][i].numpy() / 255.0, 0.0, 1.0))
    ax.axis("off")

del backbone_infer_dataset_encode


### Model

In [None]:
# Arcmarginproduct class keras layer
class ArcMarginProduct(tf.keras.layers.Layer):
    """
    Implements large margin arc distance (ArcFace logits).

    Computes the cosine between L2-normalized embeddings and L2-normalized class
    weights, adds the angular margin ``m`` for each example's target class, and
    multiplies all logits by ``s``.

    Args:
        n_classes: number of classes (columns of ``W``).
        s: logit scale.
        m: additive angular margin, in radians.
        easy_margin: if True, the margin is applied only where the cosine is positive.
        ls_eps: label-smoothing factor applied to the one-hot target mask.

    Call inputs:
        ``(X, y)``: embeddings ``X`` of shape (batch, dim) and one integer label per
        example in ``y``.

    Returns:
        Scaled logits of shape (batch, n_classes).

    Reference:
        https://arxiv.org/pdf/1801.07698.pdf
        https://github.com/lyakaap/Landmark2019-1st-and-3rd-Place-Solution/
            blob/master/src/modeling/metric_learning.py
    """
    
    def __init__(self, n_classes, s = 30, m=0.50, easy_margin = False,
                ls_eps = 0.0, **kwargs):
        super(ArcMarginProduct, self).__init__(**kwargs)
        
        self.n_classes = n_classes
        self.s = s
        self.m = m 
        self.ls_eps = ls_eps
        self.easy_margin = easy_margin
        # Plain Python floats: trivially captured by the graph under any strategy.
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        # cosine > th  <=>  theta + m < pi; beyond that the margin is replaced by a
        # linear penalty (cosine - mm) so the logit keeps decreasing with theta.
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m
        
    def get_config(self):
        
        config = super().get_config().copy()
        config.update({
            'n_classes' : self.n_classes,
            's' : self.s,
            'm' : self.m,
            'ls_eps' : self.ls_eps,
            'easy_margin' : self.easy_margin
        })
        return config
    
    def build(self, input_shape):
        super(ArcMarginProduct, self).build(input_shape[0])
        
        # Class-centre weights: one column per class, in the embedding space of X.
        self.W = self.add_weight(
            name = 'W',
            shape = (int(input_shape[0][-1]), self.n_classes),
            initializer = 'glorot_uniform',
            dtype = 'float32',
            trainable = True,
            regularizer = None
        )
        
    def call(self, inputs):
        X, y = inputs
        # Flatten so one_hot yields (batch, n_classes) even for (batch, 1) labels.
        y = tf.reshape(tf.cast(y, dtype = tf.int32), [-1])
        cosine = tf.matmul(
            tf.math.l2_normalize(X, axis = 1),
            tf.math.l2_normalize(self.W, axis = 0)
        )
        # Rounding can make 1 - cos^2 slightly negative (NaN from sqrt), and sqrt's
        # gradient is infinite at 0, so floor it at a small epsilon.
        sine = tf.math.sqrt(tf.clip_by_value(1.0 - tf.math.square(cosine), 1e-7, 1.0))
        # cos(theta + m)
        phi = cosine * self.cos_m - sine * self.sin_m
        if self.easy_margin:
            phi = tf.where(cosine > 0, phi, cosine)
        else:
            phi = tf.where(cosine > self.th, phi, cosine - self.mm)
        
        one_hot = tf.cast(
            tf.one_hot(y, depth = self.n_classes),
            dtype = cosine.dtype
        )
        
        if self.ls_eps > 0:
            one_hot = (1 - self.ls_eps) * one_hot + self.ls_eps / self.n_classes
            
        # Margin-adjusted logit for the target class, plain cosine for the others.
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.s
        return output
        

In [None]:
def get_scale_layer(rescale = "tf"):
    
    if isinstance(rescale_mode, (list, tuple)):
        mean, std = rescale_mode
    elif rescale_mode == "torch":
        mean = np.array([0.485, 0.456, 0.406]) * 255.0
        std = np.array([0.229, 0.224, 0.225]) * 255.0
    elif rescale_mode == "tf":
        mean, std = 127.5, 127.5
    elif rescale_mode == "tf128":
        mean, std = 128.0, 128.0
    elif rescale_mode == "raw01":
        mean, std = 0, 255.0
    else:
        mean, std = 0, 1
    scaling_layer = keras.layers.Lambda(lambda x: ( tf.cast(x, tf.float32) - mean) / std)
    
    return scaling_layer

def get_clip_model(pretrained_name = "openai/clip-vit-large-patch14", image_size = 224):
    """Wrap a pretrained Hugging Face CLIP vision tower as a Keras model.

    Args:
        pretrained_name: checkpoint passed to ``TFCLIPVisionModel.from_pretrained``.
        image_size: spatial size of the channels-first input images.

    Returns:
        ``tf.keras.Model`` mapping (batch, 3, image_size, image_size) pixels to the
        backbone's ``pooler_output``.
    """
    inp = tf.keras.layers.Input(shape = [3, image_size, image_size])
    backbone = TFCLIPVisionModel.from_pretrained(pretrained_name)
    # The model's pixel input argument is named `pixel_values`.
    output = backbone({'pixel_values' : inp}).pooler_output
    return tf.keras.Model(inputs=[inp], outputs = [output])

def get_embedding_model():
    """Build the ArcFace training model and the matching embedding model.

    Both models are assembled from the same layer instances, so weights trained or
    loaded through the training model are the ones used by the embedding model.

    Returns:
        (model, embed_model):
          model: compiled training model. Inputs are ``inp1`` (HWC image) and ``inp2``
              (integer class label). The output is a softmax over ``config.N_CLASSES``.
          embed_model: Sequential that takes uint8 HWC images and returns
              L2-normalized embeddings, adaptive-average-pooled to 64 values when
              ``config.EMB_DIM != 64``.
    """
    inp = tf.keras.layers.Input(shape = [None, None, 3], name = "inp1")
    # One scalar class label per example; consumed by the ArcFace margin layer.
    label = tf.keras.layers.Input(shape = (), name = 'inp2')
    
    # Definition of layers
    layer_resize = tf.keras.layers.Lambda(lambda x: tf.image.resize(x, [config.IMAGE_SIZE, config.IMAGE_SIZE]), name = 'resize')
    # Mean/std normalization of 0-255 pixels with torch-style ImageNet statistics.
    # NOTE(review): CLIP was pretrained with its own mean/std; changing this would
    # presumably invalidate weights trained with the current scaling — confirm first.
    layer_scaling = get_scale_layer('torch')
    # The CLIP wrapper takes channels-first (batch, 3, H, W) input.
    layer_permute = tf.keras.layers.Permute((3,1,2))
    layer_backbone = get_clip_model()
    layer_dropout = tf.keras.layers.Dropout(0.2)
    layer_dense_before_arcface = tf.keras.layers.Dense(config.EMB_DIM)
    layer_margin = ArcMarginProduct(
        n_classes = config.N_CLASSES,
        s = 30,
        m = 0.3,
        name = "head/arcface",
        dtype = 'float32'
    )
    # float32 softmax stays numerically stable under a mixed-precision policy.
    layer_softmax = tf.keras.layers.Softmax(dtype = 'float32')
    layer_l2 = tf.keras.layers.Lambda(lambda x: tf.math.l2_normalize(x, axis = 1, name = 'embdedding_norm'))
    
    if config.EMB_DIM != 64:
        # Average-pool the EMB_DIM-wide projection down to 64 values.
        # NOTE(review): this is applied to the 2-D Dense output; confirm the tfa layer
        # pools along the feature axis as intended for this rank.
        layer_adaptive_pooling = tfa.layers.AdaptiveAveragePooling1D(64)
    else:
        layer_adaptive_pooling = tf.keras.layers.Lambda(lambda x: x)
        
    image = layer_scaling(inp)
    image = layer_resize(image)
    image = layer_permute(image)
    backbone_output = layer_backbone(image)
    embed = layer_dropout(backbone_output)
    embed = layer_dense_before_arcface(embed)
    x = layer_margin([embed, label])
    output = layer_softmax(x)
    model = tf.keras.models.Model(inputs = [inp, label], outputs = [output])
    
    # Freeze the pretrained CLIP backbone so only the projection and ArcFace head are
    # trained. Referencing the layer directly does not depend on Keras' layer order.
    layer_backbone.trainable = False
    opt = tf.keras.optimizers.Adam(learning_rate = config.LR)
    model.compile(
        optimizer = opt,
        loss = [tf.keras.losses.SparseCategoricalCrossentropy()],
        metrics = [tf.keras.metrics.SparseCategoricalAccuracy(), tf.keras.metrics.SparseTopKCategoricalAccuracy(k=5)]
    )
    
    # Definition of Embedding model 
    embed_model = keras.Sequential([
        keras.layers.InputLayer(input_shape = (None, None, 3), dtype = 'uint8'),
        layer_scaling,
        layer_resize,
        layer_permute,
        layer_backbone,
        layer_dropout,
        layer_dense_before_arcface,
        layer_adaptive_pooling,
        layer_l2,
    ])
    
    return model, embed_model

In [None]:
# Build both models inside the strategy scope so their variables are created under
# the selected distribution strategy.
with strategy.scope():
    model, emb_model = get_embedding_model()

if config.RESUME:
    # The embedding model shares its layers with `model`, so it is restored as well.
    weight_path = config.RESUME_WEIGHT
    print(f"load {weight_path}")
    model.load_weights(weight_path)

In [None]:
# Print the layer-by-layer summary of the ArcFace training model.
model.summary()

In [None]:
# Print the layer-by-layer summary of the embedding (inference) model.
emb_model.summary()

### Scheduler

In [None]:
def get_lr_callback(plot = False):
    """Create a warm-up + exponential-decay ``LearningRateScheduler`` callback.

    The schedule has three phases:
      - Ramp linearly from ``lr_start`` to ``lr_max`` over ``lr_ramp_ep`` epochs.
      - Hold ``lr_max`` for ``lr_sus_ep`` epochs.
      - Decay towards ``lr_min`` by a factor ``lr_decay`` per epoch.

    ``lr_max`` scales with ``config.BATCH_SIZE``. When ``config.RESUME`` is set, the
    epoch index is shifted by ``config.RESUME_EPOCH``.

    Args:
        plot: if True, scatter-plot the schedule over ``config.EPOCHS`` epochs first.

    Returns:
        ``tf.keras.callbacks.LearningRateScheduler`` wrapping the schedule.
    """
    lr_start, lr_min = 0.000001, 0.000001
    lr_max = 0.000005 * config.BATCH_SIZE
    lr_ramp_ep, lr_sus_ep = 4, 0
    lr_decay = 0.95
    hold_end = lr_ramp_ep + lr_sus_ep

    def lrfn(epoch):
        step = epoch + config.RESUME_EPOCH if config.RESUME else epoch
        if step < lr_ramp_ep:
            return (lr_max - lr_start) / lr_ramp_ep * step + lr_start
        if step < hold_end:
            return lr_max
        return (lr_max - lr_min) * lr_decay**(step - hold_end) + lr_min

    if plot:
        epochs = list(range(config.EPOCHS))
        plt.scatter(epochs, [lrfn(e) for e in epochs])
        plt.show()

    return tf.keras.callbacks.LearningRateScheduler(lrfn, verbose = False)


get_lr_callback(plot = True)