<a href="https://colab.research.google.com/github/MHosseinHashemi/Image_Similarity/blob/main/Image_Simmilarity_CenterLoss_TF_v4.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import random
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
import matplotlib.pyplot as plt
import tensorflow_datasets as tfds
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, RandomFlip, RandomRotation, Dense, Dropout, Lambda

from tqdm import tqdm
from collections import defaultdict

In [3]:
(train_data, test_data, validation_data), info = tfds.load("oxford_flowers102", split=['train', 'validation', 'test'], as_supervised=True, with_info=True)

In [4]:
height = 128
width = 128

def preprocess_images(image, label, height, width):
    # image = tf.image.resize_with_crop_or_pad(image, target_height=height, target_width=width)
    image = tf.image.resize(image, [width, height])
    image = tf.cast(image, tf.float32) / 255.0
    return image, label


In [5]:
train_ds = train_data.map(lambda image, label: preprocess_images(image, label, height, width))

In [6]:
test_ds = test_data.map(lambda image, label: preprocess_images(image, label, height, width))

In [7]:
def data_loader(data):
  x = []
  y = []
  for img, label in tqdm(data.as_numpy_iterator()):
    x.append(img)
    y.append(label)

  return x, y

In [8]:
x_train, y_train = data_loader(train_ds)

1020it [00:02, 451.41it/s]


In [9]:
x_test, y_test = data_loader(test_ds)

1020it [00:02, 458.73it/s]


In [10]:
# Base Model
MODEL_URL = "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet1k_s/feature_vector/2"

# model = tf.keras.Sequential([
#     tf.keras.layers.RandomFlip(),
#     tf.keras.layers.RandomRotation(0.3),
#     hub.KerasLayer(MODEL_URL, trainable=True),
#     tf.keras.layers.Dropout(0.25),
#     tf.keras.layers.Dense(128, activation=None),
#     tf.keras.layers.Dense(102, activation='softmax')
#     # tf.keras.layers.Lambda(lambda x: tf.math.l2_normalize(x, axis=1)) # L2 normalize embeddings
# ])

# model.build([None, height, height, 3])
# model.summary()

# With Functional API to prevent Further Err
input_layer = Input(shape=(height, width, 3))
x = RandomFlip()(input_layer)
x = RandomRotation(0.3)(x)
x = hub.KerasLayer(MODEL_URL, trainable=True)(x)
x = Dropout(0.25)(x)
x = Dense(128, activation=None)(x)
x = Lambda(lambda x: tf.math.l2_normalize(x, axis=1))(x)  # L2 normalize embeddings
output_layer = Dense(102, activation='softmax')(x)

model = tf.keras.Model(inputs=input_layer, outputs=output_layer)

model.summary()



Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 128, 128, 3)]     0         
                                                                 
 random_flip (RandomFlip)    (None, 128, 128, 3)       0         
                                                                 
 random_rotation (RandomRota  (None, 128, 128, 3)      0         
 tion)                                                           
                                                                 
 keras_layer (KerasLayer)    (None, 1280)              20331360  
                                                                 
 dropout (Dropout)           (None, 1280)              0         
                                                                 
 dense (Dense)               (None, 128)               163968    
                                                             

In [11]:
def batch_me(images, labels, batch_size, samples_per_class):
  temp_dict = defaultdict(list) # A Dic of Lists to save img, label pairs as one object
  for img, label in zip(images, labels):
    temp_dict[label].append(img)

  while True:
    batch_x = []
    batch_y = []
    while len(batch_x) < batch_size:
      for category, examples in temp_dict.items():
        # Only feed as large as the "samples per class"
        # If the batch did not had enough space, feed as much as it has
        n_samples = min(samples_per_class, (batch_size - len(batch_x)))
        if n_samples == 0:
          break
        # Pick randomly from simmilar images of the same category
        samples = random.sample(examples, k=n_samples)
        # Add corresponding x, y values to the batch
        batch_x.extend(samples)
        batch_y.extend([category] * len(samples))

    # It should be a continous operation
    yield np.array(batch_x), np.array(batch_y)


In [12]:
def center_loss(feature_vector, center):
    difference = feature_vector - center
    loss = tf.reduce_mean(tf.reduce_sum(difference**2, axis=1))

    return loss

# Debug

In [13]:
# x_train[0].shape

In [14]:
# temp=np.expand_dims(x_train[0], axis=0)
# temp.shape

In [19]:
# feature_extraction_model = Model(inputs=model.input, outputs=model.layers[-2].output)

In [16]:
# feature_extraction_model.predict(temp).shape

In [17]:
# l_1 = [21, 12, 33, 46]
# l_2 = [15, 16, 17, 18]

# result_dict = {key: value for key, value in zip(l_1, l_2)}

# print(result_dict)

# End of Debug

In [20]:
feature_extraction_model = Model(inputs=model.input, outputs=model.layers[-2].output)

In [17]:
# # Extarct All feature vectors and their corresponding centers
# raw_features = {}
# for k in tqdm(range(len(x_train))):
#     if y_train[k] in raw_features.keys():
#         # Avergae the new center with the previous one and replace it
#         new_center = feature_extraction_model.predict(np.expand_dims(x_train[k], axis=0)).mean()
#         prev_center = raw_features[y_train[k]]
#         raw_features[y_train[k]] = [(prev_center + new_center )/2] * 128
#         del new_center, prev_center

#     else:
#         raw_features[y_train[k]].append([feature_extraction_model.predict(np.expand_dims(x_train[k], axis=0)).mean()] * 128)


# all_features = {key: value for key, value in raw_features.items()}

# del raw_features,
# all_features

In [None]:
# raw_features = {}

# for k, (x, y) in tqdm(enumerate(zip(x_train, y_train))):
#     if y in raw_features:
#         new_center = feature_extraction_model.predict(np.expand_dims(x, axis=0)).mean()
#         prev_center = raw_features[y]
#         averaged_center = [(prev + new_center) / 2 for prev in prev_center]
#         raw_features[y] = [averaged_center] * 128
#     else:
#         feature_vector = feature_extraction_model.predict(np.expand_dims(x, axis=0)).mean()
#         raw_features[y] = [feature_vector] * 128

# all_features = {key: value for key, value in raw_features.items()}

In [21]:
raw_features = {}
class_feature_vectors = {}  # Store pre-calculated feature vectors for each class

# Calculate and store class feature vectors
for x, y in tqdm(zip(x_train, y_train)):
    if y not in class_feature_vectors:
        feature_vector = feature_extraction_model.predict(np.expand_dims(x, axis=0)).mean()
        class_feature_vectors[y] = feature_vector

# Calculate raw_features using class_feature_vectors
for x, y in tqdm(zip(x_train, y_train)):
    if y in raw_features:
        new_center = feature_extraction_model.predict(np.expand_dims(x, axis=0)).mean()
        raw_features[y] = [(prev + new_center) / 2 for prev in raw_features[y]]
    else:
        raw_features[y] = [class_feature_vectors[y]] * 128

all_features = {key: value for key, value in raw_features.items()}

# Clean up memory by releasing unnecessary variables
del class_feature_vectors


0it [00:00, ?it/s]



1it [00:02,  2.93s/it]



2it [00:03,  1.30s/it]



3it [00:03,  1.27it/s]



4it [00:03,  1.87it/s]



5it [00:03,  2.36it/s]



6it [00:03,  3.02it/s]



7it [00:03,  3.65it/s]



8it [00:04,  4.21it/s]



9it [00:04,  4.77it/s]



10it [00:04,  5.18it/s]



12it [00:04,  7.18it/s]



13it [00:04,  7.18it/s]



14it [00:04,  7.15it/s]



15it [00:05,  6.89it/s]



16it [00:05,  6.82it/s]



17it [00:05,  6.79it/s]



19it [00:05,  8.67it/s]



22it [00:05, 11.93it/s]



24it [00:05,  9.75it/s]



26it [00:06,  8.45it/s]



28it [00:06,  9.53it/s]



30it [00:06, 10.57it/s]



32it [00:06, 11.26it/s]



34it [00:06,  9.64it/s]



36it [00:07, 10.25it/s]



38it [00:07,  8.25it/s]



39it [00:07,  7.75it/s]



40it [00:07,  7.57it/s]



42it [00:07,  9.08it/s]



43it [00:08,  8.51it/s]



45it [00:08,  9.54it/s]



46it [00:08,  8.80it/s]



48it [00:08,  9.08it/s]



49it [00:08,  8.53it/s]



53it [00:08, 13.82it/s]



55it [00:09, 10.41it/s]



60it [00:09, 15.98it/s]



62it [00:09, 12.32it/s]



64it [00:09, 10.41it/s]



66it [00:10, 10.81it/s]



69it [00:10, 13.03it/s]



71it [00:10, 12.89it/s]



73it [00:10, 10.28it/s]



75it [00:10, 10.91it/s]



77it [00:11,  9.36it/s]



80it [00:11, 11.80it/s]



82it [00:11, 10.05it/s]



84it [00:11, 10.53it/s]



89it [00:11, 15.98it/s]



91it [00:11, 15.40it/s]



93it [00:12, 14.85it/s]



95it [00:12, 14.42it/s]



97it [00:12, 14.09it/s]



107it [00:12, 26.95it/s]



110it [00:12, 21.78it/s]



113it [00:13, 13.65it/s]



121it [00:13, 19.15it/s]



124it [00:13, 17.45it/s]



126it [00:14, 14.93it/s]



130it [00:14, 15.65it/s]



133it [00:14, 16.39it/s]



138it [00:14, 20.32it/s]



141it [00:14, 16.16it/s]



143it [00:15, 15.34it/s]



146it [00:15,  9.07it/s]



148it [00:15,  9.60it/s]



150it [00:16, 10.15it/s]



157it [00:16, 17.67it/s]



162it [00:16, 21.00it/s]



167it [00:16, 24.04it/s]



170it [00:16, 18.38it/s]



173it [00:16, 18.03it/s]



177it [00:17, 19.83it/s]



182it [00:17, 23.27it/s]



185it [00:17, 17.90it/s]



204it [00:17, 43.01it/s]



210it [00:18, 28.19it/s]



224it [00:18, 40.04it/s]



230it [00:18, 39.47it/s]



235it [00:18, 38.26it/s]



268it [00:18, 84.04it/s]



290it [00:18, 97.43it/s]



301it [00:19, 91.26it/s]



318it [00:19, 97.40it/s]



329it [00:19, 90.43it/s]



1020it [00:19, 52.17it/s]
0it [00:00, ?it/s]



11it [00:00, 66.20it/s]



18it [00:00, 54.84it/s]



24it [00:00, 31.81it/s]



28it [00:00, 30.57it/s]



32it [00:01, 22.80it/s]



36it [00:01, 24.24it/s]



41it [00:01, 26.84it/s]



44it [00:01, 23.44it/s]



47it [00:01, 22.25it/s]



50it [00:01, 21.62it/s]



53it [00:02, 16.22it/s]



56it [00:02, 17.34it/s]



58it [00:02, 12.61it/s]



60it [00:02, 12.62it/s]



66it [00:02, 18.88it/s]



69it [00:03, 15.69it/s]



71it [00:03, 15.06it/s]



74it [00:03, 16.17it/s]



78it [00:03, 18.90it/s]



81it [00:03, 18.94it/s]



83it [00:03, 17.37it/s]



85it [00:04, 16.04it/s]



87it [00:04, 11.81it/s]



89it [00:04, 11.92it/s]



91it [00:04, 11.79it/s]



93it [00:05, 10.60it/s]



95it [00:05, 10.06it/s]



97it [00:05,  9.62it/s]



99it [00:05,  7.15it/s]



100it [00:06,  6.40it/s]



101it [00:06,  5.82it/s]



102it [00:06,  5.49it/s]



103it [00:06,  5.75it/s]



104it [00:06,  6.01it/s]



105it [00:07,  6.20it/s]



106it [00:07,  6.24it/s]



108it [00:07,  8.06it/s]



110it [00:07,  9.50it/s]



113it [00:07, 12.06it/s]



115it [00:07,  9.86it/s]



117it [00:08,  8.71it/s]



118it [00:08,  8.30it/s]



119it [00:08,  7.95it/s]



120it [00:08,  7.46it/s]



122it [00:08,  8.94it/s]



123it [00:09,  8.37it/s]



125it [00:09,  9.99it/s]



127it [00:09, 10.70it/s]



129it [00:09,  8.94it/s]



131it [00:09,  9.88it/s]



133it [00:09, 10.64it/s]



135it [00:10,  8.76it/s]



136it [00:10,  8.27it/s]



137it [00:10,  7.83it/s]



141it [00:10, 12.42it/s]



143it [00:10, 12.51it/s]



145it [00:11, 10.09it/s]



147it [00:11, 10.92it/s]



150it [00:11, 12.73it/s]



152it [00:11,  9.95it/s]



154it [00:12,  8.74it/s]



155it [00:12,  8.27it/s]



156it [00:12,  7.82it/s]



158it [00:12,  8.35it/s]



159it [00:12,  7.81it/s]



160it [00:12,  7.28it/s]



161it [00:13,  6.47it/s]



163it [00:13,  8.15it/s]



164it [00:13,  7.76it/s]



165it [00:13,  7.46it/s]



166it [00:13,  6.78it/s]



169it [00:13, 10.33it/s]



171it [00:14, 10.93it/s]



173it [00:14, 11.67it/s]



175it [00:14,  9.32it/s]



177it [00:14, 10.10it/s]



179it [00:15,  8.71it/s]



180it [00:15,  8.23it/s]



181it [00:15,  7.87it/s]



183it [00:15,  9.34it/s]



186it [00:15, 12.02it/s]



188it [00:15,  9.71it/s]



190it [00:16,  8.60it/s]



191it [00:16,  8.21it/s]



192it [00:16,  7.78it/s]



193it [00:16,  6.92it/s]



194it [00:17,  5.87it/s]



195it [00:17,  5.28it/s]



196it [00:17,  5.00it/s]



197it [00:17,  4.83it/s]



198it [00:17,  4.62it/s]



199it [00:18,  4.59it/s]



200it [00:18,  4.57it/s]



201it [00:18,  4.92it/s]



202it [00:18,  5.26it/s]



203it [00:18,  5.73it/s]



207it [00:19, 10.96it/s]



209it [00:19, 11.50it/s]



211it [00:19,  9.54it/s]



213it [00:19,  8.39it/s]



214it [00:19,  8.01it/s]



215it [00:20,  7.80it/s]



216it [00:20,  7.48it/s]



217it [00:20,  7.15it/s]



218it [00:20,  6.92it/s]



219it [00:20,  6.77it/s]



220it [00:20,  6.59it/s]



221it [00:21,  6.45it/s]



222it [00:21,  6.41it/s]



223it [00:21,  6.47it/s]



225it [00:21,  8.10it/s]



226it [00:21,  7.48it/s]



227it [00:21,  7.23it/s]



229it [00:21,  8.77it/s]



230it [00:22,  8.25it/s]



232it [00:22,  9.70it/s]



233it [00:22,  8.61it/s]



234it [00:22,  8.04it/s]



235it [00:22,  7.50it/s]



236it [00:22,  7.27it/s]



237it [00:23,  6.92it/s]



238it [00:23,  6.79it/s]



239it [00:23,  6.69it/s]



240it [00:23,  6.60it/s]



241it [00:23,  6.51it/s]



242it [00:23,  6.46it/s]



243it [00:24,  6.48it/s]



244it [00:24,  6.38it/s]



245it [00:24,  6.36it/s]



246it [00:24,  6.29it/s]



247it [00:24,  6.36it/s]



248it [00:24,  6.35it/s]



249it [00:24,  6.43it/s]



250it [00:25,  6.36it/s]



251it [00:25,  6.28it/s]



252it [00:25,  6.23it/s]



253it [00:25,  6.23it/s]



254it [00:25,  6.30it/s]



255it [00:25,  6.28it/s]



256it [00:26,  6.24it/s]



257it [00:26,  6.24it/s]



258it [00:26,  6.36it/s]



259it [00:26,  6.31it/s]



260it [00:26,  6.40it/s]



261it [00:26,  6.32it/s]



262it [00:27,  6.32it/s]



263it [00:27,  6.31it/s]



264it [00:27,  6.30it/s]



265it [00:27,  6.30it/s]



266it [00:27,  6.22it/s]



267it [00:27,  6.28it/s]



269it [00:27,  8.35it/s]



270it [00:28,  7.70it/s]



271it [00:28,  7.40it/s]



272it [00:28,  7.01it/s]



273it [00:28,  5.95it/s]



274it [00:28,  5.28it/s]



275it [00:29,  4.93it/s]



276it [00:29,  4.65it/s]



277it [00:29,  4.59it/s]



278it [00:29,  4.57it/s]



279it [00:30,  4.45it/s]



280it [00:30,  4.40it/s]



281it [00:30,  4.62it/s]



282it [00:30,  5.01it/s]



283it [00:30,  5.29it/s]



284it [00:30,  5.62it/s]



285it [00:31,  5.83it/s]



286it [00:31,  5.92it/s]



287it [00:31,  5.94it/s]



288it [00:31,  6.17it/s]



289it [00:31,  6.34it/s]



291it [00:31,  8.28it/s]



292it [00:32,  7.56it/s]



293it [00:32,  7.30it/s]



294it [00:32,  6.73it/s]



295it [00:32,  6.67it/s]



296it [00:32,  6.62it/s]



297it [00:32,  6.59it/s]



298it [00:33,  6.49it/s]



299it [00:33,  6.54it/s]



300it [00:33,  6.37it/s]



302it [00:33,  8.10it/s]



303it [00:33,  7.58it/s]



304it [00:33,  7.04it/s]



305it [00:33,  6.95it/s]



306it [00:34,  6.89it/s]



307it [00:34,  6.83it/s]



308it [00:34,  6.55it/s]



309it [00:34,  6.61it/s]



310it [00:34,  6.60it/s]



311it [00:34,  6.47it/s]



312it [00:35,  6.45it/s]



313it [00:35,  6.52it/s]



314it [00:35,  6.37it/s]



315it [00:35,  6.44it/s]



316it [00:35,  6.36it/s]



317it [00:35,  6.23it/s]



319it [00:36,  8.35it/s]



320it [00:36,  7.90it/s]



321it [00:36,  7.49it/s]



322it [00:36,  7.04it/s]



323it [00:36,  6.67it/s]



324it [00:36,  6.72it/s]



325it [00:36,  6.66it/s]



327it [00:37,  8.67it/s]



328it [00:37,  8.10it/s]



329it [00:37,  7.58it/s]



330it [00:37,  7.14it/s]



331it [00:37,  6.83it/s]



332it [00:37,  6.83it/s]



333it [00:38,  6.83it/s]



334it [00:38,  6.63it/s]



335it [00:38,  6.63it/s]



336it [00:38,  6.53it/s]



337it [00:38,  6.38it/s]



338it [00:38,  6.44it/s]



339it [00:38,  6.14it/s]



340it [00:39,  6.16it/s]



341it [00:39,  6.25it/s]



342it [00:39,  6.20it/s]



343it [00:39,  6.26it/s]



344it [00:39,  6.21it/s]



345it [00:39,  6.36it/s]



347it [00:40,  8.34it/s]



348it [00:40,  7.64it/s]



349it [00:40,  7.30it/s]



350it [00:40,  5.78it/s]



351it [00:40,  5.18it/s]



352it [00:41,  4.99it/s]



353it [00:41,  4.88it/s]



354it [00:41,  4.54it/s]



355it [00:41,  4.52it/s]



356it [00:42,  4.48it/s]



357it [00:42,  4.52it/s]



358it [00:42,  4.55it/s]



359it [00:42,  4.90it/s]



360it [00:42,  5.39it/s]



361it [00:42,  5.69it/s]



362it [00:43,  5.96it/s]



363it [00:43,  6.16it/s]



364it [00:43,  6.32it/s]



365it [00:43,  6.25it/s]



366it [00:43,  6.18it/s]



367it [00:43,  6.22it/s]



368it [00:44,  6.23it/s]



369it [00:44,  6.44it/s]



370it [00:44,  6.53it/s]



371it [00:44,  6.35it/s]



372it [00:44,  6.32it/s]



373it [00:44,  6.41it/s]



374it [00:44,  6.41it/s]



375it [00:45,  6.43it/s]



376it [00:45,  6.33it/s]



377it [00:45,  6.39it/s]



378it [00:45,  6.40it/s]



379it [00:45,  6.34it/s]



380it [00:45,  6.44it/s]



381it [00:46,  6.27it/s]



382it [00:46,  6.35it/s]



383it [00:46,  6.39it/s]



384it [00:46,  6.41it/s]



385it [00:46,  6.07it/s]



386it [00:46,  6.14it/s]



387it [00:47,  6.21it/s]



388it [00:47,  6.32it/s]



389it [00:47,  6.48it/s]



390it [00:47,  6.42it/s]



391it [00:47,  6.34it/s]



392it [00:47,  6.42it/s]



393it [00:47,  6.39it/s]



394it [00:48,  6.20it/s]



395it [00:48,  6.14it/s]



396it [00:48,  6.28it/s]



397it [00:48,  6.34it/s]



398it [00:48,  6.34it/s]



399it [00:48,  6.27it/s]



400it [00:49,  6.35it/s]



401it [00:49,  6.45it/s]



402it [00:49,  5.84it/s]



403it [00:49,  5.91it/s]



404it [00:49,  6.01it/s]



405it [00:49,  5.98it/s]



406it [00:50,  6.06it/s]



407it [00:50,  6.17it/s]



408it [00:50,  6.29it/s]



409it [00:50,  6.28it/s]



410it [00:50,  6.26it/s]



411it [00:50,  6.38it/s]



412it [00:51,  6.22it/s]



413it [00:51,  6.32it/s]



414it [00:51,  6.36it/s]



415it [00:51,  6.30it/s]



416it [00:51,  6.34it/s]



417it [00:51,  6.12it/s]



418it [00:52,  6.20it/s]



419it [00:52,  6.32it/s]



420it [00:52,  6.29it/s]



421it [00:52,  5.86it/s]



422it [00:52,  5.35it/s]



423it [00:53,  4.76it/s]



424it [00:53,  4.67it/s]



425it [00:53,  4.61it/s]



426it [00:53,  4.63it/s]



427it [00:53,  4.60it/s]



428it [00:54,  4.46it/s]



429it [00:54,  4.42it/s]



430it [00:54,  4.52it/s]



431it [00:54,  4.94it/s]



432it [00:54,  5.30it/s]



433it [00:55,  5.57it/s]



434it [00:55,  5.81it/s]



435it [00:55,  5.89it/s]



436it [00:55,  6.11it/s]



437it [00:55,  6.27it/s]



438it [00:55,  6.20it/s]



439it [00:56,  6.06it/s]



440it [00:56,  6.16it/s]



441it [00:56,  6.30it/s]



442it [00:56,  6.45it/s]



443it [00:56,  6.47it/s]



444it [00:56,  6.58it/s]



445it [00:56,  6.45it/s]



446it [00:57,  6.36it/s]



447it [00:57,  6.42it/s]



448it [00:57,  6.28it/s]



449it [00:57,  6.25it/s]



450it [00:57,  6.31it/s]



451it [00:57,  6.46it/s]



452it [00:58,  6.30it/s]



453it [00:58,  6.39it/s]



454it [00:58,  6.53it/s]



455it [00:58,  6.49it/s]



456it [00:58,  6.34it/s]



457it [00:58,  6.37it/s]



458it [00:58,  6.35it/s]



459it [00:59,  5.72it/s]



460it [00:59,  6.02it/s]



461it [00:59,  6.04it/s]



462it [00:59,  6.10it/s]



463it [00:59,  6.27it/s]



464it [00:59,  6.23it/s]



465it [01:00,  6.12it/s]



466it [01:00,  6.10it/s]



467it [01:00,  6.08it/s]



468it [01:00,  6.32it/s]



469it [01:00,  6.27it/s]



470it [01:00,  6.39it/s]



471it [01:01,  6.32it/s]



472it [01:01,  6.40it/s]



473it [01:01,  6.21it/s]



474it [01:01,  6.11it/s]



475it [01:01,  6.16it/s]



476it [01:01,  6.16it/s]



477it [01:02,  6.15it/s]



478it [01:02,  6.31it/s]



479it [01:02,  6.40it/s]



480it [01:02,  6.39it/s]



481it [01:02,  6.49it/s]



482it [01:02,  6.47it/s]



483it [01:02,  6.41it/s]



484it [01:03,  6.33it/s]



485it [01:03,  6.31it/s]



486it [01:03,  6.22it/s]



487it [01:03,  6.34it/s]



488it [01:03,  6.45it/s]



489it [01:03,  6.48it/s]



490it [01:04,  6.38it/s]



491it [01:04,  6.09it/s]



492it [01:04,  6.04it/s]



493it [01:04,  5.74it/s]



494it [01:04,  5.27it/s]



495it [01:05,  4.86it/s]



496it [01:05,  4.62it/s]



497it [01:05,  4.54it/s]



498it [01:05,  4.44it/s]



499it [01:06,  4.43it/s]



500it [01:06,  4.22it/s]



501it [01:06,  4.12it/s]



502it [01:06,  4.57it/s]



503it [01:06,  4.96it/s]



504it [01:07,  5.26it/s]



505it [01:07,  5.53it/s]



506it [01:07,  5.72it/s]



507it [01:07,  5.83it/s]



508it [01:07,  6.03it/s]



509it [01:07,  6.08it/s]



510it [01:07,  6.22it/s]



511it [01:08,  6.23it/s]



512it [01:08,  6.22it/s]



513it [01:08,  6.30it/s]



514it [01:08,  6.35it/s]



515it [01:08,  6.36it/s]



516it [01:08,  6.26it/s]



517it [01:09,  6.34it/s]



518it [01:09,  6.34it/s]



519it [01:09,  6.12it/s]



520it [01:09,  6.16it/s]



521it [01:09,  6.12it/s]



522it [01:09,  6.17it/s]



523it [01:10,  5.85it/s]



524it [01:10,  5.79it/s]



525it [01:10,  5.87it/s]



526it [01:10,  5.92it/s]



527it [01:10,  6.08it/s]



528it [01:10,  6.12it/s]



529it [01:11,  6.19it/s]



530it [01:11,  6.27it/s]



531it [01:11,  6.21it/s]



532it [01:11,  6.26it/s]



533it [01:11,  6.05it/s]



534it [01:11,  6.13it/s]



535it [01:12,  6.06it/s]



536it [01:12,  6.05it/s]



537it [01:12,  6.05it/s]



538it [01:12,  6.03it/s]



539it [01:12,  6.15it/s]



540it [01:12,  6.24it/s]



541it [01:13,  6.33it/s]



542it [01:13,  6.12it/s]



543it [01:13,  6.06it/s]



544it [01:13,  5.91it/s]



545it [01:13,  5.89it/s]



546it [01:13,  5.96it/s]



547it [01:14,  6.01it/s]



548it [01:14,  6.11it/s]



549it [01:14,  6.09it/s]



550it [01:14,  5.83it/s]



551it [01:14,  5.96it/s]



552it [01:14,  6.09it/s]



553it [01:15,  6.03it/s]



554it [01:15,  5.97it/s]



555it [01:15,  6.02it/s]



556it [01:15,  5.72it/s]



557it [01:15,  5.82it/s]



558it [01:15,  5.97it/s]



559it [01:16,  6.14it/s]



560it [01:16,  6.11it/s]



561it [01:16,  6.10it/s]



562it [01:16,  5.83it/s]



563it [01:16,  5.27it/s]



564it [01:17,  4.83it/s]



565it [01:17,  4.73it/s]



566it [01:17,  4.33it/s]



567it [01:17,  4.21it/s]



568it [01:18,  4.20it/s]



569it [01:18,  4.15it/s]



570it [01:18,  4.14it/s]



571it [01:18,  4.46it/s]



572it [01:18,  4.82it/s]



573it [01:19,  5.19it/s]



574it [01:19,  5.44it/s]



575it [01:19,  5.56it/s]



576it [01:19,  5.58it/s]



577it [01:19,  5.76it/s]



578it [01:19,  5.96it/s]



579it [01:20,  6.07it/s]



580it [01:20,  6.07it/s]



581it [01:20,  6.21it/s]



582it [01:20,  6.19it/s]



583it [01:20,  6.12it/s]



584it [01:20,  6.19it/s]



585it [01:20,  6.28it/s]



586it [01:21,  6.28it/s]



587it [01:21,  6.24it/s]



588it [01:21,  6.24it/s]



589it [01:21,  6.05it/s]



590it [01:21,  6.02it/s]



591it [01:21,  6.05it/s]



592it [01:22,  6.08it/s]



593it [01:22,  6.13it/s]



594it [01:22,  6.09it/s]



595it [01:22,  6.16it/s]



596it [01:22,  6.05it/s]



597it [01:22,  5.99it/s]



598it [01:23,  6.18it/s]



599it [01:23,  6.08it/s]



600it [01:23,  6.20it/s]



601it [01:23,  6.22it/s]



602it [01:23,  6.17it/s]



603it [01:23,  6.20it/s]



604it [01:24,  6.29it/s]



605it [01:24,  6.15it/s]



606it [01:24,  6.16it/s]



607it [01:24,  6.16it/s]



608it [01:24,  6.19it/s]



609it [01:24,  6.13it/s]



610it [01:25,  6.23it/s]



611it [01:25,  6.32it/s]



612it [01:25,  6.32it/s]



613it [01:25,  6.36it/s]



614it [01:25,  6.35it/s]



615it [01:25,  6.08it/s]



616it [01:26,  6.17it/s]



617it [01:26,  6.11it/s]



618it [01:26,  6.31it/s]



619it [01:26,  6.25it/s]



620it [01:26,  6.31it/s]



621it [01:26,  6.29it/s]



622it [01:26,  6.35it/s]



623it [01:27,  6.28it/s]



624it [01:27,  6.24it/s]



625it [01:27,  6.23it/s]



626it [01:27,  6.25it/s]



627it [01:27,  6.35it/s]



628it [01:27,  6.19it/s]



629it [01:28,  6.27it/s]



630it [01:28,  6.28it/s]



631it [01:28,  6.29it/s]



632it [01:28,  5.83it/s]



633it [01:28,  5.28it/s]



634it [01:29,  4.63it/s]



635it [01:29,  4.43it/s]



636it [01:29,  4.31it/s]



637it [01:29,  4.43it/s]



638it [01:30,  4.21it/s]



639it [01:30,  4.11it/s]



640it [01:30,  4.06it/s]



641it [01:30,  4.09it/s]



642it [01:30,  4.55it/s]



643it [01:31,  4.94it/s]



644it [01:31,  5.35it/s]



645it [01:31,  5.63it/s]



646it [01:31,  5.65it/s]



647it [01:31,  5.94it/s]



648it [01:31,  5.99it/s]



649it [01:32,  6.10it/s]



650it [01:32,  5.99it/s]



651it [01:32,  6.05it/s]



652it [01:32,  6.03it/s]



653it [01:32,  6.23it/s]



654it [01:32,  6.20it/s]



655it [01:33,  6.09it/s]



656it [01:33,  6.18it/s]



657it [01:33,  6.21it/s]



658it [01:33,  5.97it/s]



659it [01:33,  6.10it/s]



660it [01:33,  6.19it/s]



661it [01:34,  6.08it/s]



662it [01:34,  6.14it/s]



663it [01:34,  6.08it/s]



664it [01:34,  6.00it/s]



665it [01:34,  6.13it/s]



666it [01:34,  6.20it/s]



667it [01:35,  5.99it/s]



668it [01:35,  5.95it/s]



669it [01:35,  6.05it/s]



670it [01:35,  6.12it/s]



671it [01:35,  5.98it/s]



672it [01:35,  5.99it/s]



673it [01:36,  6.00it/s]



674it [01:36,  6.04it/s]



675it [01:36,  6.01it/s]



676it [01:36,  5.86it/s]



677it [01:36,  5.97it/s]



678it [01:36,  6.06it/s]



679it [01:37,  6.05it/s]



680it [01:37,  6.12it/s]



681it [01:37,  5.91it/s]



682it [01:37,  5.93it/s]



683it [01:37,  6.00it/s]



684it [01:37,  5.93it/s]



685it [01:38,  5.91it/s]



686it [01:38,  5.83it/s]



687it [01:38,  5.92it/s]



688it [01:38,  5.87it/s]



689it [01:38,  5.93it/s]



690it [01:38,  6.11it/s]



691it [01:39,  6.08it/s]



692it [01:39,  6.06it/s]



693it [01:39,  6.02it/s]



694it [01:39,  5.95it/s]



695it [01:39,  6.01it/s]



696it [01:39,  6.09it/s]



697it [01:40,  6.06it/s]



698it [01:40,  6.01it/s]



699it [01:40,  6.03it/s]



700it [01:40,  5.99it/s]



701it [01:40,  5.62it/s]



702it [01:41,  5.09it/s]



703it [01:41,  4.56it/s]



704it [01:41,  4.37it/s]



705it [01:41,  4.27it/s]



706it [01:42,  4.27it/s]



707it [01:42,  4.41it/s]



708it [01:42,  4.30it/s]



709it [01:42,  4.39it/s]



710it [01:42,  4.88it/s]



711it [01:43,  5.02it/s]



712it [01:43,  5.33it/s]



713it [01:43,  5.45it/s]



714it [01:43,  5.62it/s]



715it [01:43,  5.78it/s]



716it [01:43,  6.03it/s]



717it [01:44,  6.03it/s]



718it [01:44,  6.01it/s]



719it [01:44,  6.09it/s]



720it [01:44,  5.85it/s]



721it [01:44,  5.69it/s]



722it [01:44,  6.00it/s]



723it [01:45,  6.02it/s]



724it [01:45,  6.17it/s]



725it [01:45,  6.13it/s]



726it [01:45,  6.03it/s]



727it [01:45,  6.19it/s]



728it [01:45,  6.11it/s]



729it [01:46,  5.91it/s]



730it [01:46,  5.92it/s]



731it [01:46,  5.93it/s]



732it [01:46,  5.93it/s]



733it [01:46,  6.01it/s]



734it [01:46,  6.19it/s]



735it [01:46,  6.22it/s]



736it [01:47,  6.21it/s]



737it [01:47,  6.37it/s]



738it [01:47,  5.92it/s]



739it [01:47,  6.00it/s]



740it [01:47,  6.11it/s]



741it [01:47,  6.15it/s]



742it [01:48,  6.16it/s]



743it [01:48,  6.29it/s]



744it [01:48,  6.21it/s]



745it [01:48,  6.08it/s]



746it [01:48,  6.00it/s]



747it [01:48,  5.99it/s]



748it [01:49,  5.89it/s]



749it [01:49,  6.04it/s]



750it [01:49,  5.99it/s]



751it [01:49,  5.87it/s]



752it [01:49,  6.07it/s]



753it [01:49,  5.98it/s]



754it [01:50,  6.04it/s]



755it [01:50,  6.16it/s]



756it [01:50,  6.00it/s]



757it [01:50,  5.85it/s]



758it [01:50,  5.98it/s]



759it [01:50,  6.07it/s]



760it [01:51,  6.04it/s]



761it [01:51,  5.97it/s]



762it [01:51,  5.88it/s]



763it [01:51,  5.92it/s]



764it [01:51,  5.92it/s]



765it [01:51,  5.92it/s]



766it [01:52,  5.87it/s]



767it [01:52,  6.02it/s]



768it [01:52,  5.97it/s]



769it [01:52,  5.91it/s]



770it [01:52,  5.17it/s]



771it [01:53,  4.78it/s]



772it [01:53,  4.40it/s]



773it [01:53,  4.22it/s]



774it [01:53,  4.29it/s]



775it [01:54,  4.24it/s]



776it [01:54,  4.19it/s]



777it [01:54,  4.24it/s]



778it [01:54,  4.50it/s]



779it [01:54,  4.94it/s]



780it [01:55,  5.28it/s]



781it [01:55,  5.61it/s]



782it [01:55,  5.63it/s]



783it [01:55,  5.69it/s]



784it [01:55,  5.65it/s]



785it [01:55,  5.74it/s]



786it [01:56,  5.91it/s]



787it [01:56,  5.98it/s]



788it [01:56,  5.97it/s]



789it [01:56,  6.03it/s]



790it [01:56,  6.02it/s]



791it [01:56,  6.04it/s]



792it [01:57,  5.98it/s]



793it [01:57,  6.04it/s]



794it [01:57,  6.09it/s]



795it [01:57,  6.05it/s]



796it [01:57,  5.97it/s]



797it [01:57,  6.06it/s]



798it [01:58,  6.10it/s]



799it [01:58,  6.03it/s]



800it [01:58,  5.88it/s]



801it [01:58,  5.75it/s]



802it [01:58,  5.87it/s]



803it [01:58,  6.00it/s]



804it [01:59,  6.17it/s]



805it [01:59,  6.17it/s]



806it [01:59,  6.22it/s]



807it [01:59,  6.20it/s]



808it [01:59,  6.08it/s]



809it [01:59,  5.88it/s]



810it [02:00,  5.90it/s]



811it [02:00,  5.90it/s]



812it [02:00,  5.95it/s]



813it [02:00,  5.79it/s]



814it [02:00,  5.91it/s]



815it [02:00,  5.84it/s]



816it [02:01,  5.93it/s]



817it [02:01,  5.87it/s]



818it [02:01,  6.05it/s]



819it [02:01,  5.99it/s]



820it [02:01,  6.00it/s]



821it [02:01,  5.97it/s]



822it [02:02,  5.93it/s]



823it [02:02,  5.85it/s]



824it [02:02,  6.04it/s]



825it [02:02,  5.81it/s]



826it [02:02,  5.79it/s]



827it [02:02,  5.84it/s]



828it [02:03,  5.84it/s]



829it [02:03,  5.77it/s]



830it [02:03,  5.86it/s]



831it [02:03,  5.90it/s]



832it [02:03,  5.91it/s]



833it [02:04,  5.84it/s]



834it [02:04,  5.67it/s]



835it [02:04,  5.73it/s]



836it [02:04,  5.78it/s]



837it [02:04,  5.55it/s]



838it [02:04,  4.95it/s]



839it [02:05,  4.61it/s]



840it [02:05,  4.26it/s]



841it [02:05,  4.23it/s]



842it [02:06,  4.19it/s]



843it [02:06,  4.09it/s]



844it [02:06,  4.24it/s]



845it [02:06,  4.01it/s]



846it [02:06,  4.21it/s]



847it [02:07,  4.61it/s]



848it [02:07,  5.03it/s]



849it [02:07,  5.21it/s]



850it [02:07,  5.53it/s]



851it [02:07,  5.68it/s]



852it [02:07,  5.68it/s]



853it [02:08,  5.62it/s]



854it [02:08,  5.77it/s]



855it [02:08,  5.81it/s]



856it [02:08,  5.82it/s]



857it [02:08,  5.69it/s]



858it [02:09,  5.81it/s]



859it [02:09,  5.88it/s]



860it [02:09,  5.86it/s]



861it [02:09,  5.79it/s]



862it [02:09,  5.89it/s]



863it [02:09,  5.88it/s]



864it [02:10,  5.92it/s]



865it [02:10,  5.85it/s]



866it [02:10,  5.72it/s]



867it [02:10,  5.85it/s]



868it [02:10,  5.93it/s]



869it [02:10,  6.01it/s]



870it [02:11,  5.89it/s]



871it [02:11,  5.88it/s]



872it [02:11,  5.77it/s]



873it [02:11,  5.85it/s]



874it [02:11,  5.98it/s]



875it [02:11,  5.95it/s]



876it [02:12,  5.92it/s]



877it [02:12,  5.94it/s]



878it [02:12,  5.62it/s]



879it [02:12,  5.71it/s]



880it [02:12,  5.74it/s]



881it [02:12,  5.76it/s]



882it [02:13,  5.65it/s]



883it [02:13,  5.82it/s]



884it [02:13,  5.81it/s]



885it [02:13,  5.86it/s]



886it [02:13,  5.91it/s]



887it [02:13,  5.83it/s]



888it [02:14,  5.93it/s]



889it [02:14,  6.02it/s]



890it [02:14,  5.81it/s]



891it [02:14,  5.82it/s]



892it [02:14,  5.91it/s]



893it [02:14,  5.97it/s]



894it [02:15,  5.99it/s]



895it [02:15,  6.00it/s]



896it [02:15,  5.85it/s]



897it [02:15,  5.87it/s]



898it [02:15,  5.86it/s]



899it [02:15,  5.92it/s]



900it [02:16,  6.04it/s]



901it [02:16,  6.12it/s]



902it [02:16,  6.03it/s]



903it [02:16,  6.13it/s]



904it [02:16,  6.06it/s]



905it [02:17,  5.64it/s]



906it [02:17,  4.79it/s]



907it [02:17,  4.31it/s]



908it [02:17,  3.96it/s]



909it [02:18,  3.91it/s]



910it [02:18,  3.91it/s]



911it [02:18,  3.92it/s]



912it [02:18,  4.09it/s]



913it [02:19,  4.56it/s]



914it [02:19,  4.96it/s]



915it [02:19,  5.25it/s]



916it [02:19,  5.42it/s]



917it [02:19,  5.39it/s]



918it [02:19,  5.66it/s]



919it [02:20,  5.80it/s]



920it [02:20,  5.93it/s]



921it [02:20,  6.04it/s]



922it [02:20,  5.97it/s]



923it [02:20,  5.80it/s]



924it [02:20,  5.93it/s]



925it [02:21,  5.97it/s]



926it [02:21,  5.87it/s]



927it [02:21,  5.67it/s]



928it [02:21,  5.76it/s]



929it [02:21,  5.78it/s]



930it [02:21,  5.86it/s]



931it [02:22,  5.82it/s]



932it [02:22,  5.78it/s]



933it [02:22,  5.74it/s]



934it [02:22,  5.68it/s]



935it [02:22,  5.57it/s]



936it [02:22,  5.69it/s]



937it [02:23,  5.81it/s]



938it [02:23,  5.84it/s]



939it [02:23,  5.88it/s]



940it [02:23,  5.94it/s]



941it [02:23,  5.77it/s]



942it [02:23,  5.86it/s]



943it [02:24,  5.87it/s]



944it [02:24,  5.88it/s]



945it [02:24,  5.85it/s]



946it [02:24,  5.92it/s]



947it [02:24,  5.84it/s]



948it [02:25,  5.83it/s]



949it [02:25,  5.92it/s]



950it [02:25,  5.89it/s]



951it [02:25,  5.90it/s]



952it [02:25,  5.88it/s]



953it [02:25,  5.82it/s]



954it [02:26,  5.95it/s]



955it [02:26,  6.07it/s]



956it [02:26,  6.13it/s]



957it [02:26,  6.05it/s]



958it [02:26,  5.93it/s]



959it [02:26,  5.94it/s]



960it [02:27,  6.03it/s]



961it [02:27,  6.02it/s]



962it [02:27,  6.16it/s]



963it [02:27,  6.13it/s]



964it [02:27,  6.08it/s]



965it [02:27,  5.48it/s]



966it [02:28,  5.71it/s]



967it [02:28,  5.77it/s]



968it [02:28,  5.72it/s]



969it [02:28,  5.88it/s]



970it [02:28,  5.99it/s]



971it [02:28,  5.76it/s]



972it [02:29,  5.18it/s]



973it [02:29,  4.75it/s]



974it [02:29,  4.47it/s]



975it [02:29,  4.30it/s]



976it [02:30,  4.10it/s]



977it [02:30,  4.11it/s]



978it [02:30,  4.01it/s]



979it [02:30,  4.01it/s]



980it [02:31,  4.18it/s]



981it [02:31,  4.57it/s]



982it [02:31,  4.94it/s]



983it [02:31,  5.21it/s]



984it [02:31,  5.38it/s]



985it [02:31,  5.37it/s]



986it [02:32,  5.46it/s]



987it [02:32,  5.71it/s]



988it [02:32,  5.81it/s]



989it [02:32,  5.87it/s]



990it [02:32,  6.00it/s]



991it [02:32,  5.92it/s]



992it [02:33,  5.96it/s]



993it [02:33,  5.92it/s]



994it [02:33,  5.90it/s]



995it [02:33,  5.91it/s]



996it [02:33,  5.98it/s]



997it [02:34,  5.97it/s]



998it [02:34,  6.04it/s]



999it [02:34,  5.92it/s]



1000it [02:34,  5.89it/s]



1001it [02:34,  6.00it/s]



1002it [02:34,  5.43it/s]



1003it [02:35,  5.64it/s]



1004it [02:35,  5.82it/s]



1005it [02:35,  5.85it/s]



1006it [02:35,  5.95it/s]



1007it [02:35,  6.00it/s]



1008it [02:35,  5.94it/s]



1009it [02:36,  5.82it/s]



1010it [02:36,  5.97it/s]



1011it [02:36,  5.93it/s]



1012it [02:36,  5.83it/s]



1013it [02:36,  5.85it/s]



1014it [02:36,  6.01it/s]



1015it [02:37,  5.85it/s]



1016it [02:37,  5.94it/s]



1017it [02:37,  6.05it/s]



1018it [02:37,  6.20it/s]



1019it [02:37,  6.06it/s]



1020it [02:37,  6.46it/s]


In [22]:
all_features

{72: [-0.005025452923291596,
  -0.005025452923291596,
  -0.005025452923291596,
  -0.005025452923291596,
  -0.005025452923291596,
  -0.005025452923291596,
  -0.005025452923291596,
  -0.005025452923291596,
  -0.005025452923291596,
  -0.005025452923291596,
  -0.005025452923291596,
  -0.005025452923291596,
  -0.005025452923291596,
  -0.005025452923291596,
  -0.005025452923291596,
  -0.005025452923291596,
  -0.005025452923291596,
  -0.005025452923291596,
  -0.005025452923291596,
  -0.005025452923291596,
  -0.005025452923291596,
  -0.005025452923291596,
  -0.005025452923291596,
  -0.005025452923291596,
  -0.005025452923291596,
  -0.005025452923291596,
  -0.005025452923291596,
  -0.005025452923291596,
  -0.005025452923291596,
  -0.005025452923291596,
  -0.005025452923291596,
  -0.005025452923291596,
  -0.005025452923291596,
  -0.005025452923291596,
  -0.005025452923291596,
  -0.005025452923291596,
  -0.005025452923291596,
  -0.005025452923291596,
  -0.005025452923291596,
  -0.0050254529232915

In [23]:
epochs = 50
alpha = 0.5
batch_size = 16
n_examples_per_class = 4
EMA_lr = 0.9
optimizer = tf.keras.optimizers.Adam(learning_rate=0.0003)


# Calculate initial centers
centers = tf.Variable(initial_value=tf.random.normal((102, 128), mean=0.0, stddev=0.5))

# Loop through all samples

for index in range(102):
    # category_features = all_features[y_train == index]
    category_center = all_features[y_train == index]
    # if len(category_features) > 0:
    if len(category_center) > 0 :
        # category_center = tf.reduce_mean(category_features, axis=0)
        centers[index].assign(category_center)

# Training
for epoch in tqdm(range(epochs)):
    total_loss = 0.0
    num_batches = 0
    for batch_x, batch_y in batch_me(images=x_train, labels=y_train, batch_size=batch_size, samples_per_class=n_examples_per_class):
        # Capture Gradients
        with tf.GradientTape() as tape:
            # Extract Features per batch
            temps = model(batch_x, training=False)
            # Seprate labels and their features
            predictions = temps[0]
            features = temps[1]
            # initialize batch centers
            batch_centers = centers.numpy()[batch_y]
            # Calculate Batch Centers
            for index in range(batch_size):
                instance_features = features[index]
                instance_mean = tf.reduce_mean(instance_features, axis=0)
                batch_centers[index] = instance_mean

            # Center-Loss calculation
            c_loss = center_loss(features, batch_centers)
            # Combine it with CategoricalCrossEntropyLoss
            cls_loss = tf.keras.losses.CategoricalCrossentropy()(batch_y, predictions)
            # Total Loss
            loss = (c_loss * alpha) + cls_loss

        # Calculate Gradients
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))

        # Update training loss
        total_loss += loss
        num_batches += 1

    # Calculate training Loss
    training_loss = total_loss / num_batches


    """Validation"""
    val_temps = model(x_test)
    # Extract Validation Features and Labels
    val_predictions = val_temps[0]
    val_features = val_temps[1]

    # Center Calcualtion
    val_batch_centers = centers.numpy()[y_test]
    for idx in range(len(y_test)):
        val_instance_features = val_features[idx]
        val_instance_mean = tf.reduce_mean(val_instance_features, axis=0)
        val_batch_centers[idx] = val_instance_mean

    # Loss Calculation
    val_c_loss = center_loss(val_features, val_batch_centers)
    val_cls_loss = tf.keras.losses.CategoricalCrossentropy()(y_test, val_predictions)
    val_loss = (val_c_loss * alpha) + val_cls_loss

    print(f"Epoch {epoch + 1}/{epochs} - Training Loss: {training_loss:.4f} - Validation Loss: {val_loss.numpy():.4f}")

    # Centers Update Frequency
    for index in range(102):
        category_features = all_features[y_train == index]
        if len(category_features)>0:
            category_center = tf.reduce_mean(category_features, axis=0)
            centers[index].assign((1.0 - EMA_lr) * centers[index]) + (EMA_lr * category_center)
            print(f"Centers updated - Step : {epoch+1}\n")


  0%|          | 0/50 [00:14<?, ?it/s]


InvalidArgumentError: ignored

# **TODO**

*The provided code seems to be an attempt to train a model for image similarity using both center loss and categorical cross-entropy loss on the Oxford 102 flowers dataset. Overall, the code structure and approach are reasonable, but there are a few areas that could be improved or clarified*:


- **Center Update**: The code updates the center values using the centers variable, but it initializes this variable with a random normal distribution. It might be more effective to initialize the centers with zeros, as this provides a better starting point.

- **Center Calculation**: The code calculates the batch centers for each batch using the mean of the feature vectors. However, in this implementation, the same center values are updated for each category within the batch. It's more common to calculate centers per category over the entire training set and then update them as the training progresses.

- **Learning Rate Schedule**: Consider using a learning rate schedule, such as the tf.keras.optimizers.schedules module, to adjust the learning rate during training. This can help with convergence and stability.

- **Batch Size and Center Calculation**: The current code calculates centers based on batch_y, which means it uses the current batch for calculating centers. This might lead to unstable center values. It's better to accumulate the feature vectors for each category over the entire training set and then calculate centers after an epoch.

- **Batch Normalization**: The model uses a pretrained EfficientNet model, which is a good choice. However, consider using batch normalization layers after the hub layer to improve training stability.

- **Variable Names**: Some variable names could be more descriptive. For example, temps and val_temps might be better named as outputs or something similar.

- **Validation Loop**: The validation loop calculates the loss for the entire validation set but doesn't update any values. It's common to calculate metrics such as accuracy as well.

- **Center Loss Term**: The alpha parameter is used to weight the center loss term. You might need to experiment with this value to achieve a good balance between the two loss components.

- **Normalization**: While the preprocessing of images includes scaling them to the range [0, 1], you may also consider standardizing the images (subtracting mean and dividing by standard deviation) to help convergence.

- **Model Checkpoints**: Consider using model checkpoints to save the best model during training based on validation performance.