In [1]:
!pip install plotly

Collecting plotly
  Using cached plotly-5.3.1-py2.py3-none-any.whl (23.9 MB)
Installing collected packages: plotly
Successfully installed plotly-5.3.1


In [1]:
import math
from pathlib import Path

import numpy as np
import pandas as pd

import plotly.express as px

import tensorflow as tf
import tensorflow.keras as keras

from sklearn.model_selection import train_test_split

In [2]:
# Load the MNIST train/test splits through the Keras datasets helper.
mnist_train, mnist_test = keras.datasets.mnist.load_data()
x_train, y_train = mnist_train
x_test, y_test = mnist_test

In [3]:
# Seed NumPy's and TensorFlow's global random generators with one shared value.
SEED = 2021
np.random.seed(SEED)
tf.random.set_seed(SEED)

In [4]:
num_classes=10

# NOTE: this cell overwrites x_train / x_test / y_train / y_test with their
# processed versions, so it must run exactly once after the data-loading cell.
# Re-running it on its own would rescale, re-expand and re-encode the arrays.

# Convert the images to float32 and divide by 255.
x_train = x_train.astype("float32") / 255
x_test = x_test.astype("float32") / 255

# Append a trailing axis (used as the channel axis by the Conv2D layers).
x_train = np.expand_dims(x_train, -1)
x_test = np.expand_dims(x_test, -1)

# One-hot encode the integer labels into `num_classes` columns.
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

# Hold out 20% of the training data for validation. The split is seeded
# explicitly so it does not depend on the state of NumPy's global generator
# (i.e. on which cells were executed before this one).
x_train,x_val,y_train,y_val=train_test_split(x_train,y_train,test_size=0.2,random_state=2021)

In [179]:
## Embedding ##
class Encoder(keras.Model):
    """Convolutional encoder that maps an image batch to ``latent_features`` values.

    Layers, in call order:
      1. stem: Conv2D(3x3, relu, same padding) -> Dropout(0.2)
      2. ``n_layers`` blocks of Conv2D(3x3, relu, same padding)
         -> MaxPooling2D(2) -> Dropout(0.2)
      3. AveragePooling2D(2) -> Flatten -> Dense(latent_features), no activation

    Args:
        n_layers: Number of conv/pool/dropout blocks after the stem.
        latent_features: Width of the returned embedding; also stored on the
            instance as ``latent_features``.
        hidden_featrues: Filter count of every convolution. The spelling is
            kept as-is because it is part of the keyword interface.
    """

    def __init__(self,n_layers=2,latent_features=2,hidden_featrues=64):
        super().__init__()
        self.latent_features=latent_features
        # First Layer
        self.conv_0=keras.layers.Conv2D(hidden_featrues, kernel_size=3, activation="relu",padding='same')
        self.dropout_0=keras.layers.Dropout(0.2)
        # Middle Layers
        self.convs=[keras.layers.Conv2D(hidden_featrues, kernel_size=3, activation="relu",padding='same') for _ in range(n_layers)]
        self.mxpools = [keras.layers.MaxPooling2D(pool_size=2) for _ in range(n_layers)]
        self.dropouts = [keras.layers.Dropout(0.2) for _ in range(n_layers)]
        # Integration Layer(Output Layer)
        # Note: despite the attribute name `gap`, this is a fixed 2x2 average
        # pooling window, not a global average pooling.
        WIDTH=2
        self.gap=keras.layers.AveragePooling2D(pool_size=WIDTH,)
        self.flatten=keras.layers.Flatten()
        self.dense=keras.layers.Dense(latent_features)

    def call(self,x,training=None):
        """Encode ``x`` into latent vectors.

        Args:
            x: Input batch accepted by the first Conv2D layer.
            training: Forwarded to every Dropout layer. ``None`` (the default)
                leaves the decision to Keras, which matches the behaviour of
                calling the encoder without the argument.

        Returns:
            The output of the final Dense layer (``latent_features`` values per
            sample).
        """
        x=self.dropout_0(self.conv_0(x),training=training)
        for conv,pool,drop in zip(self.convs,self.mxpools,self.dropouts):
            x=drop(pool(conv(x)),training=training)
        x=self.flatten(self.gap(x))
        return self.dense(x)

## Special Layers ##
def cos_sim(q,s):
    """Return the cosine similarity between every row of ``q`` and every row of ``s``.

    Both inputs are L2-normalised along their last axis and then combined with
    a matrix product against the transpose of ``s``, so entry ``[i, j]`` of the
    result is the cosine between ``q[i]`` and ``s[j]``.
    """
    unit_q=tf.math.l2_normalize(q, axis=-1)
    unit_s=tf.math.l2_normalize(s, axis=-1)
    return tf.matmul(unit_q,unit_s,transpose_b=True)
class CosineLayer(keras.layers.Layer):
    """Layer that outputs cosine similarities to ``out_features`` learned vectors.

    A trainable matrix of shape ``(out_features, input_dim)`` is created on
    first use; the layer returns ``cos_sim(inputs, w)``, i.e. one cosine value
    per learned vector for every input row.

    Args:
        out_features: Number of learned vectors (and of outputs per sample).
        **kwargs: Standard Keras layer arguments (for example ``name`` or
            ``dtype``), forwarded to the base class.
    """

    def __init__(self, out_features=10, **kwargs):
        super().__init__(**kwargs)
        self.out_features = out_features

    def build(self, input_shape):
        # One row per output; the row length follows the last input dimension.
        self.w = self.add_weight(name="w",
                               shape=(self.out_features,input_shape[-1]),
                               initializer='glorot_uniform',
                               trainable=True)

    def call(self, inputs):
        return cos_sim(inputs,self.w)

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super().get_config()
        config.update({"out_features": self.out_features})
        return config

## Main Model ##
class ClassModel(keras.Model):
    """Classifier made of an encoder followed by one of two output heads.

    Args:
        encoder: Callable Keras model/layer producing the latent vectors.
        out_features: Number of outputs of the head.
        mode: ``"dense"`` uses a bias-free Dense layer followed by Softmax and
            returns the softmax output; ``"cos"`` uses ``CosineLayer`` and
            returns its cosine similarities unchanged.

    Raises:
        ValueError: If ``mode`` is neither ``"dense"`` nor ``"cos"``. Without
            this check an unknown mode would build a model with no output
            layer whose ``call`` returns ``None``.
    """

    def __init__(self,encoder,out_features=10,mode="dense"):
        super().__init__()
        if mode not in ("dense","cos"):
            raise ValueError(f'mode must be "dense" or "cos", got {mode!r}')
        self.encoder=encoder
        self.mode=mode

        if self.mode=="dense":
            self.output_layer=keras.layers.Dense(out_features,use_bias=False)
            self.out_func=keras.layers.Softmax()
        else:
            self.output_layer=CosineLayer(out_features)

    def call(self,x,training=False):
        """Run the encoder and the selected head on ``x``.

        Args:
            x: Input batch for the encoder.
            training: Forwarded to the encoder so its train/inference
                behaviour is selected explicitly.
        """
        latent=self.encoder(x,training=training)
        scores=self.output_layer(latent)
        if self.mode=="dense":
            return self.out_func(scores)
        return scores

## Loss function ##
class FocalLoss():
    """Focal-weighted softmax cross-entropy computed from raw scores.

    For every sample the softmax cross-entropy ``ce`` between ``y_true`` and
    ``y_pred`` is computed, ``p = exp(-ce)`` is taken as the probability
    assigned to the target, and the loss is ``(1 - p) ** gamma * ce``. The
    per-sample values are averaged into the returned tensor.

    Args:
        gamma: Exponent of the ``(1 - p)`` weighting factor.
        eps: Kept for backward compatibility and stored as ``self.eps``. It
            used to be added to the scores before the softmax, which has no
            effect because softmax is invariant to adding a constant.
    """

    def __init__(self, gamma=3, eps=1e-10):
        self.gamma = gamma
        self.eps = tf.constant(eps)
        # Work directly on the raw scores (from_logits=True) instead of on
        # softmax probabilities: cross-entropy on probabilities is clipped
        # inside Keras, which can zero the gradient of badly misclassified
        # samples when the scores are scaled by a large factor.
        self.cce=keras.losses.CategoricalCrossentropy(from_logits=True,
                                                      reduction=tf.keras.losses.Reduction.NONE)

    def __call__(self,y_true, y_pred):
        """Return the mean focal loss for scores ``y_pred`` and targets ``y_true``."""
        ce = self.cce(y_true,y_pred)
        p = tf.exp(-ce)
        loss = (1.0 - p) ** self.gamma * ce
        return tf.reduce_mean(loss, axis=-1)
    
class AddMarginLoss():
    """Additive-margin loss on cosine scores.

    The margin ``m`` is subtracted from the cosine scores, the shifted and the
    unshifted scores are blended with ``y_true`` as the weight (so a target
    entry of 1 takes the shifted score and 0 keeps the original one), the
    result is multiplied by ``s`` and handed to ``FocalLoss``.

    Args:
        s: Scale applied to the margin-adjusted scores.
        m: Margin subtracted from the score of the target class.
    """

    def __init__(self,s=5,m=0.3):
        self.s = s
        self.m = m
        self.lossfn=FocalLoss()
        # presumably lets Keras read a name from this callable - confirm
        self.__name__="loss"

    def __call__(self, y_true, y_pred):
        """Return the focal loss of the margin-adjusted, scaled cosine scores."""
        with_margin = y_pred - self.m
        blended = (y_true * with_margin) + ((1.0 - y_true) * y_pred)
        return self.lossfn(y_true,blended * self.s)
    
class ArcMarginLoss():
    """Angular-margin loss on cosine scores.

    For the target class the score ``cos(theta)`` is replaced by
    ``cos(theta + m)``, computed as ``cos*cos(m) - sin*sin(m)``; the other
    classes keep their cosine. The result is multiplied by ``s`` and handed to
    ``FocalLoss``.

    Args:
        s: Scale applied to the margin-adjusted scores.
        m: Angular margin in radians.
        easy_margin: If true, the margin is applied only where the cosine is
            positive and the plain cosine is used elsewhere. If false (the
            default), the margin is applied where ``cosine > cos(pi - m)`` and
            ``cosine - sin(pi - m) * m`` is used elsewhere.
    """

    def __init__(self,s=32,m=0.3,easy_margin=False):
        self.s = s
        self.m = m
        # Previously read from an undefined name, which made every
        # instantiation fail with NameError; now a keyword with a default.
        self.easy_margin = easy_margin
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

        self.lossfn=FocalLoss()
        # presumably lets Keras read a name from this callable - confirm
        self.__name__="loss"

    def __call__(self, y_true, y_pred):
        """Return the focal loss of the margin-adjusted, scaled cosine scores."""
        cosine=y_pred

        # Clamp before the square root: rounding can push cos^2 slightly above
        # 1, and the square root of a negative number would be NaN.
        sine = tf.sqrt(tf.clip_by_value(1.0 - tf.pow(cosine, 2), 0.0, 1.0))

        # cos(theta + m) via the angle-addition identity.
        phi = cosine * self.cos_m - sine * self.sin_m
        if self.easy_margin:
            phi = tf.where(cosine > 0, phi, cosine)
        else:
            phi = tf.where((cosine - self.th) > 0, phi, cosine - self.mm)

        # Target entries (y_true == 1) take the margin-adjusted score, the
        # remaining entries keep the plain cosine.
        arc = (y_true * phi) + ((1.0 - y_true) * cosine)
        arc *= self.s
        return self.lossfn(y_true,arc)

In [180]:
# Encoder with a 2-value latent, topped with the cosine-similarity head.
enc = Encoder(n_layers=2, latent_features=2, hidden_featrues=64)
model = ClassModel(enc, out_features=10, mode="cos")

# Additive-margin loss on the cosine outputs, optimised with Adam.
loss_fn = AddMarginLoss(s=5, m=0.3)
model.compile(optimizer="adam", loss=loss_fn, metrics=["accuracy"])

In [181]:
# for x,y in train_loader:
    
#     break

In [182]:
# with tf.GradientTape() as tape:
#     l,w=model(x)
#     loss=loss_fn(y,(l,w))
# gradient =tape.gradient(loss,model.trainable_weights)

In [183]:
## Recording Callbacks

class LogLatent(keras.callbacks.Callback):
    """Callback that records the encoder's latents for a fixed batch while training.

    Every ``skip_iters``-th batch the encoder is evaluated on ``test_x`` and
    the first two latent components are appended to ``self.df`` with the
    columns ``f1``, ``f2``, ``arc``, ``iterations`` and ``label``.

    Args:
        encoder: Callable returning latents with at least two components per
            sample.
        test_x: Inputs to encode at every logging step.
        test_y: Labels stored next to the latents (one per sample).
        skip_iters: Log once every ``skip_iters`` batches.
    """

    # Values of the `arc` column: rows as logged vs. rows projected to unit norm.
    RAW_ARC = 0.5
    PROJECTED_ARC = 0.9

    def __init__(self,encoder,test_x,test_y,skip_iters=100):
        super().__init__()
        self.df=pd.DataFrame()
        self.encoder=encoder
        self.plt_x=test_x
        self.plt_y=test_y
        self.itr=0
        self.skip_iters=skip_iters

    def on_batch_end(self,batch, logs=None):
        """Log the current latents on the last batch of every ``skip_iters`` window."""
        if self.itr%self.skip_iters==self.skip_iters-1:
            # Inference mode (no dropout); converted to NumPy so pandas
            # receives plain arrays.
            latent=np.asarray(self.encoder(self.plt_x,training=False))
            df_new=pd.DataFrame({
                "f1":latent[:,0],
                "f2":latent[:,1],
                "arc":self.RAW_ARC,
                "iterations":self.itr,
                "label":self.plt_y,
            })
            # DataFrame.append was removed in pandas 2.0; concat is the
            # replacement. The empty initial frame is skipped.
            frames=[df_new] if self.df.empty else [self.df,df_new]
            self.df=pd.concat(frames)
        self.itr+=1

    def normalize(self):
        """Rescale the logged latents and add their unit-norm projections.

        The logged rows are divided by the largest row norm (stored as
        ``self.mx_norm``). A copy of them, projected to norm 1 and tagged with
        ``arc == PROJECTED_ARC``, is then appended. Projected rows from an
        earlier call are discarded first, so calling this again does not
        duplicate rows. Does nothing if nothing has been logged yet.
        """
        if self.df.empty:
            return
        cols=["f1","f2"]

        # Normalize latent by maximum value
        raw=self.df[self.df["arc"]==self.RAW_ARC].copy()
        self.mx_norm=np.linalg.norm(raw[cols].values,axis=1).max()
        if self.mx_norm>0:
            raw[cols]=raw[cols].values/self.mx_norm

        # Project latents into arc
        projected=raw.copy()
        norm=np.linalg.norm(projected[cols].values,axis=1,keepdims=True)
        norm[norm==0]=1.0  # leave zero vectors at the origin instead of producing NaN
        projected[cols]=projected[cols].values/norm
        projected["arc"]=self.PROJECTED_ARC

        self.df=pd.concat([raw,projected])
        


In [184]:
batch_size = 128
epochs = 10

# The first `batch_size` test samples are the fixed batch whose latents get logged.
plt_x = x_test[:batch_size]
plt_y = y_test[:batch_size]
lg = LogLatent(enc, plt_x, np.argmax(plt_y, axis=-1))

# Cached, shuffled (buffer = half the split) and batched input pipelines.
train_loader = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .cache()
    .shuffle(len(x_train) // 2)
    .batch(batch_size)
)
val_loader = (
    tf.data.Dataset.from_tensor_slices((x_val, y_val))
    .cache()
    .shuffle(len(x_val) // 2)
    .batch(batch_size * 2)
)


In [185]:
# Train with the latent-logging callback. A manual interrupt is caught and
# reported so that execution can continue with the following cells.
fit_options = dict(epochs=epochs, validation_data=val_loader, callbacks=[lg])
try:
    model.fit(train_loader, **fit_options)
except KeyboardInterrupt:
    print("KeyboardInterrupt")

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


In [186]:
# Score the test split; the displayed list is the loss followed by the accuracy metric.
model.evaluate(x=x_test, y=y_test, batch_size=batch_size)



[1.0719693899154663, 0.8503000140190125]

In [187]:
# Rescale the logged latents and add their unit-norm projections before plotting.
lg.normalize()

In [None]:
# Animated scatter of the logged latents: one frame per logged iteration,
# coloured by label.
# NOTE(review): `opacity` receives the per-row `arc` column, while plotly
# express documents `opacity` as a single float - confirm it renders as intended.
axis_limit = 1.5
fig = px.scatter(
    lg.df,
    x="f1",
    y="f2",
    animation_frame="iterations",
    color="label",
    opacity=lg.df.arc,
    range_x=[-axis_limit, axis_limit],
    range_y=[-axis_limit, axis_limit],
)
# Tie the y scale to the x scale so one unit has the same length on both axes.
fig.update_yaxes(scaleanchor="x", scaleratio=1)

In [189]:
# Create the output folder if it is missing, then export the figure as HTML.
output_path = Path("history/cos_tf.html")
output_path.parent.mkdir(parents=True, exist_ok=True)
fig.write_html(str(output_path))