<a href="https://colab.research.google.com/github/MHosseinHashemi/Image_Similarity/blob/main/Image_Simmilarity_CenterLoss_TF_v4.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import random
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
import matplotlib.pyplot as plt
import tensorflow_datasets as tfds
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, RandomFlip, RandomRotation, Dense, Dropout, Lambda

from tqdm import tqdm
from collections import defaultdict

In [2]:
(train_data, test_data, validation_data), info = tfds.load("oxford_flowers102", split=['train', 'validation', 'test'], as_supervised=True, with_info=True)

In [3]:
height = 128
width = 128

def preprocess_images(image, label, height, width):
    # image = tf.image.resize_with_crop_or_pad(image, target_height=height, target_width=width)
    image = tf.image.resize(image, [width, height])
    image = tf.cast(image, tf.float32) / 255.0
    return image, label


In [4]:
train_ds = train_data.map(lambda image, label: preprocess_images(image, label, height, width))

In [5]:
test_ds = test_data.map(lambda image, label: preprocess_images(image, label, height, width))

In [6]:
def data_loader(data):
  x = []
  y = []
  for img, label in tqdm(data.as_numpy_iterator()):
    x.append(img)
    y.append(label)

  return x, y

In [7]:
x_train, y_train = data_loader(train_ds)

1020it [00:01, 542.56it/s]


In [8]:
x_test, y_test = data_loader(test_ds)

1020it [00:02, 438.10it/s]


In [9]:
# Base Model
MODEL_URL = "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet1k_s/feature_vector/2"

# model = tf.keras.Sequential([
#     tf.keras.layers.RandomFlip(),
#     tf.keras.layers.RandomRotation(0.3),
#     hub.KerasLayer(MODEL_URL, trainable=True),
#     tf.keras.layers.Dropout(0.25),
#     tf.keras.layers.Dense(128, activation=None),
#     tf.keras.layers.Dense(102, activation='softmax')
#     # tf.keras.layers.Lambda(lambda x: tf.math.l2_normalize(x, axis=1)) # L2 normalize embeddings
# ])

# model.build([None, height, height, 3])
# model.summary()

# With Functional API to prevent Further Err
input_layer = Input(shape=(height, width, 3))
x = RandomFlip()(input_layer)
x = RandomRotation(0.3)(x)
x = hub.KerasLayer(MODEL_URL, trainable=True)(x)
x = Dropout(0.25)(x)
x = Dense(128, activation=None)(x)
x = Lambda(lambda x: tf.math.l2_normalize(x, axis=1))(x)  # L2 normalize embeddings
output_layer = Dense(102, activation='softmax')(x)

model = tf.keras.Model(inputs=input_layer, outputs=output_layer)

model.summary()



Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 128, 128, 3)]     0         
                                                                 
 random_flip (RandomFlip)    (None, 128, 128, 3)       0         
                                                                 
 random_rotation (RandomRota  (None, 128, 128, 3)      0         
 tion)                                                           
                                                                 
 keras_layer (KerasLayer)    (None, 1280)              20331360  
                                                                 
 dropout (Dropout)           (None, 1280)              0         
                                                                 
 dense (Dense)               (None, 128)               163968    
                                                             

In [10]:
def batch_me(images, labels, batch_size, samples_per_class):
  temp_dict = defaultdict(list) # A Dic of Lists to save img, label pairs as one object
  # batch_count = 0  # just for debugging ...
  for img, label in zip(images, labels):
    temp_dict[label].append(img)

  while True:
    batch_x = []
    batch_y = []
    while len(batch_x) < batch_size:
      for category, examples in temp_dict.items():
        # Only feed as large as the "samples per class"
        # If the batch did not had enough space, feed as much as it has
        n_samples = min(samples_per_class, (batch_size - len(batch_x)))
        if n_samples == 0:
          break
        # Pick randomly from simmilar images of the same category
        samples = random.sample(examples, k=n_samples)
        # Add corresponding x, y values to the batch
        batch_x.extend(samples)
        batch_y.extend([category] * len(samples))
        # batch_count += 1  # Increment the counter


    # print(f"\nBatch count: {batch_count}") # just for debugging ...
    # It should be a continous operation
    yield np.array(batch_x), np.array(batch_y)


In [11]:
def center_loss(feature_vector, center):
    difference = feature_vector - center
    loss = tf.reduce_mean(tf.reduce_sum(difference**2, axis=1))

    return loss

# Feature Extraction

In [12]:
feature_extraction_model = Model(inputs=model.input, outputs=model.layers[-2].output)

In [13]:
# # Extarct All feature vectors and their corresponding centers
# raw_features = {}
# for k in tqdm(range(len(x_train))):
#     if y_train[k] in raw_features.keys():
#         # Avergae the new center with the previous one and replace it
#         new_center = feature_extraction_model.predict(np.expand_dims(x_train[k], axis=0)).mean()
#         prev_center = raw_features[y_train[k]]
#         raw_features[y_train[k]] = [(prev_center + new_center )/2] * 128
#         del new_center, prev_center

#     else:
#         raw_features[y_train[k]].append([feature_extraction_model.predict(np.expand_dims(x_train[k], axis=0)).mean()] * 128)


# all_features = {key: value for key, value in raw_features.items()}

# del raw_features,
# all_features

In [14]:
# raw_features = {}

# for k, (x, y) in tqdm(enumerate(zip(x_train, y_train))):
#     if y in raw_features:
#         new_center = feature_extraction_model.predict(np.expand_dims(x, axis=0)).mean()
#         prev_center = raw_features[y]
#         averaged_center = [(prev + new_center) / 2 for prev in prev_center]
#         raw_features[y] = [averaged_center] * 128
#     else:
#         feature_vector = feature_extraction_model.predict(np.expand_dims(x, axis=0)).mean()
#         raw_features[y] = [feature_vector] * 128

# all_features = {key: value for key, value in raw_features.items()}

In [15]:
raw_features = {}
class_feature_vectors = {}  # Store pre-calculated feature vectors for each class

# Calculate and store class feature vectors
for x, y in zip(x_train, y_train):
    if y not in class_feature_vectors:
        feature_vector = feature_extraction_model.predict(np.expand_dims(x, axis=0)).mean()
        class_feature_vectors[y] = feature_vector

# Calculate raw_features using class_feature_vectors
for x, y in zip(x_train, y_train):
    if y in raw_features:
        new_center = feature_extraction_model.predict(np.expand_dims(x, axis=0)).mean()
        raw_features[y] = [(prev + new_center) / 2 for prev in raw_features[y]]
    else:
        raw_features[y] = [class_feature_vectors[y]] * 128

all_features = {key: value for key, value in raw_features.items()}

# Clean up memory
del class_feature_vectors




# Training Loop

In [16]:
epochs = 10
alpha = 0.5
batch_size = 32
n_examples_per_class = 4
EMA_lr = 0.9
optimizer = tf.keras.optimizers.Adam(learning_rate=0.0003)
num_steps_per_epoch = len(x_train) // batch_size

# Calculate initial centers
centers = tf.Variable(initial_value=tf.random.normal((102, 128), mean=0.0, stddev=0.5))

# Loop through all samples
for index in range(102):
    # category_features = all_features[y_train == index]
    category_center = all_features[y_train == index]
    # if len(category_features) > 0:
    if len(category_center) > 0 :
        # category_center = tf.reduce_mean(category_features, axis=0)
        centers[index].assign(category_center)

# Training
for epoch in tqdm(range(epochs)):
    total_loss = 0.0
    num_batches = 0

    # Create batch generator for the current epoch
    batch_generator = batch_me(images=x_train, labels=y_train, batch_size=batch_size, samples_per_class=n_examples_per_class)

    for batch_idx in range(num_steps_per_epoch):
        # Get the next batch
        batch_x, batch_y = next(batch_generator)
        # Capture Gradients
        with tf.GradientTape() as tape:
            # Extract Features per batch
            predictions = model(batch_x, training=False)
            # initialize batch centers
            batch_centers = centers.numpy()[batch_y]

            # Calculate Batch Centers
            for index in range(batch_size):
                instance_feature = feature_extraction_model(np.expand_dims(batch_x[index], axis=0))
                instance_mean = tf.reduce_mean(instance_feature[0], axis=0)
                batch_centers[index] = [instance_mean] * 128

            # Center-Loss calculation
            c_loss = center_loss(instance_feature, batch_centers)
            # Combine it with CategoricalCrossEntropyLoss
            cls_loss = tf.keras.losses.CategoricalCrossentropy()(tf.one_hot(batch_y, 102), predictions)
            # Total Loss
            loss = (c_loss * alpha) + cls_loss

        # Calculate Gradients
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))

        # Update training loss
        total_loss += loss
        num_batches += 1
        print(f"\nEpoch: {epoch} - step: {num_batches} - running loss: {loss:.3f}")

    # Calculate training Loss
    training_loss = total_loss / num_batches


    """Validation Loop"""
    val_batch_generator = batch_me(images=x_test, labels=y_test, batch_size=batch_size, samples_per_class=n_examples_per_class)

    for step in range(num_steps_per_epoch):
        val_batch_x, val_batch_y = next(val_batch_generator)
        # Make Predictions and Init Centers
        val_predictions = model(val_batch_x, training=False)
        val_batch_centers = centers.numpy()[val_batch_y]

    # Center Calcualtion
    for idx in range(batch_size):
        val_instance_features = feature_extraction_model(np.expand_dims(val_batch_x[idx], axis=0))
        val_instance_mean = tf.reduce_mean(val_instance_features[0], axis=0)
        val_batch_centers[idx] = val_instance_mean * 128

    # Loss Calculation
    val_c_loss = center_loss(val_instance_features, val_batch_centers)
    val_cls_loss = tf.keras.losses.CategoricalCrossentropy()(tf.one_hot(val_batch_y, 102), val_predictions)
    val_loss = (val_c_loss * alpha) + val_cls_loss
    val_loss = val_loss / num_batches

    print(f"Epoch {epoch + 1}/{epochs} - Training Loss: {training_loss:.4f} - Validation Loss: {val_loss.numpy():.4f}")
    print("="*100)

    # Centers Update Frequency
    for index in range(102):
        category_features = all_features[y_train == index]
        if len(category_features)>0:
            category_center = tf.reduce_mean(category_features, axis=0)
            # centers[index].assign((1.0 - EMA_lr) * centers[index]) + (EMA_lr * category_center)
            centers[index].assign((1.0 - EMA_lr) * tf.cast(centers[index], tf.float32) + (EMA_lr * tf.cast(category_center, tf.float32)))
    print(f"Centers updated - Step : {epoch}\n")





Epoch: 0 - step: 1 - running loss: 5.126

Epoch: 0 - step: 2 - running loss: 4.820

Epoch: 0 - step: 3 - running loss: 4.666

Epoch: 0 - step: 4 - running loss: 4.534

Epoch: 0 - step: 5 - running loss: 4.467

Epoch: 0 - step: 6 - running loss: 4.345

Epoch: 0 - step: 7 - running loss: 4.299

Epoch: 0 - step: 8 - running loss: 4.196

Epoch: 0 - step: 9 - running loss: 4.132

Epoch: 0 - step: 10 - running loss: 4.128

Epoch: 0 - step: 11 - running loss: 4.144

Epoch: 0 - step: 12 - running loss: 4.129

Epoch: 0 - step: 13 - running loss: 4.082

Epoch: 0 - step: 14 - running loss: 4.060

Epoch: 0 - step: 15 - running loss: 4.023

Epoch: 0 - step: 16 - running loss: 4.020

Epoch: 0 - step: 17 - running loss: 3.965

Epoch: 0 - step: 18 - running loss: 3.946

Epoch: 0 - step: 19 - running loss: 3.982

Epoch: 0 - step: 20 - running loss: 3.973

Epoch: 0 - step: 21 - running loss: 3.921

Epoch: 0 - step: 22 - running loss: 3.888

Epoch: 0 - step: 23 - running loss: 3.810

Epoch: 0 - step: 24

 10%|█         | 1/10 [03:01<27:13, 181.54s/it]

Centers updated - Step : 0


Epoch: 1 - step: 1 - running loss: 3.824

Epoch: 1 - step: 2 - running loss: 3.833

Epoch: 1 - step: 3 - running loss: 3.854

Epoch: 1 - step: 4 - running loss: 3.855

Epoch: 1 - step: 5 - running loss: 3.861

Epoch: 1 - step: 6 - running loss: 3.828

Epoch: 1 - step: 7 - running loss: 3.824

Epoch: 1 - step: 8 - running loss: 3.803

Epoch: 1 - step: 9 - running loss: 3.808

Epoch: 1 - step: 10 - running loss: 3.825

Epoch: 1 - step: 11 - running loss: 3.849

Epoch: 1 - step: 12 - running loss: 3.834

Epoch: 1 - step: 13 - running loss: 3.834

Epoch: 1 - step: 14 - running loss: 3.827

Epoch: 1 - step: 15 - running loss: 3.807

Epoch: 1 - step: 16 - running loss: 3.938

Epoch: 1 - step: 17 - running loss: 3.693

Epoch: 1 - step: 18 - running loss: 3.678

Epoch: 1 - step: 19 - running loss: 3.744

Epoch: 1 - step: 20 - running loss: 3.816

Epoch: 1 - step: 21 - running loss: 3.889

Epoch: 1 - step: 22 - running loss: 3.916

Epoch: 1 - step: 23 - running loss

 20%|██        | 2/10 [04:33<17:10, 128.83s/it]

Epoch 2/10 - Training Loss: 3.8162 - Validation Loss: 42.0866
Centers updated - Step : 1


Epoch: 2 - step: 1 - running loss: 3.804

Epoch: 2 - step: 2 - running loss: 3.857

Epoch: 2 - step: 3 - running loss: 3.878

Epoch: 2 - step: 4 - running loss: 3.726

Epoch: 2 - step: 5 - running loss: 3.667

Epoch: 2 - step: 6 - running loss: 3.658

Epoch: 2 - step: 7 - running loss: 3.636

Epoch: 2 - step: 8 - running loss: 3.662

Epoch: 2 - step: 9 - running loss: 3.702

Epoch: 2 - step: 10 - running loss: 3.671

Epoch: 2 - step: 11 - running loss: 3.692

Epoch: 2 - step: 12 - running loss: 3.697

Epoch: 2 - step: 13 - running loss: 3.699

Epoch: 2 - step: 14 - running loss: 3.651

Epoch: 2 - step: 15 - running loss: 3.678

Epoch: 2 - step: 16 - running loss: 3.646

Epoch: 2 - step: 17 - running loss: 3.599

Epoch: 2 - step: 18 - running loss: 3.635

Epoch: 2 - step: 19 - running loss: 3.623

Epoch: 2 - step: 20 - running loss: 3.581

Epoch: 2 - step: 21 - running loss: 3.601

Epoch: 2 - step

 30%|███       | 3/10 [06:04<13:00, 111.55s/it]

Epoch 3/10 - Training Loss: 3.6558 - Validation Loss: 96.9625
Centers updated - Step : 2


Epoch: 3 - step: 1 - running loss: 3.560

Epoch: 3 - step: 2 - running loss: 3.537

Epoch: 3 - step: 3 - running loss: 3.560

Epoch: 3 - step: 4 - running loss: 3.537

Epoch: 3 - step: 5 - running loss: 3.547

Epoch: 3 - step: 6 - running loss: 3.519

Epoch: 3 - step: 7 - running loss: 3.485

Epoch: 3 - step: 8 - running loss: 3.471

Epoch: 3 - step: 9 - running loss: 3.498

Epoch: 3 - step: 10 - running loss: 3.457

Epoch: 3 - step: 11 - running loss: 3.471

Epoch: 3 - step: 12 - running loss: 3.480

Epoch: 3 - step: 13 - running loss: 3.447

Epoch: 3 - step: 14 - running loss: 3.448

Epoch: 3 - step: 15 - running loss: 3.423

Epoch: 3 - step: 16 - running loss: 3.439

Epoch: 3 - step: 17 - running loss: 3.436

Epoch: 3 - step: 18 - running loss: 3.447

Epoch: 3 - step: 19 - running loss: 3.426

Epoch: 3 - step: 20 - running loss: 3.420

Epoch: 3 - step: 21 - running loss: 3.432

Epoch: 3 - step

 40%|████      | 4/10 [07:37<10:25, 104.18s/it]

Epoch 4/10 - Training Loss: 3.4503 - Validation Loss: 123.1928
Centers updated - Step : 3


Epoch: 4 - step: 1 - running loss: 3.379

Epoch: 4 - step: 2 - running loss: 3.397

Epoch: 4 - step: 3 - running loss: 3.412

Epoch: 4 - step: 4 - running loss: 3.366

Epoch: 4 - step: 5 - running loss: 3.350

Epoch: 4 - step: 6 - running loss: 3.330

Epoch: 4 - step: 7 - running loss: 3.351

Epoch: 4 - step: 8 - running loss: 3.336

Epoch: 4 - step: 9 - running loss: 3.329

Epoch: 4 - step: 10 - running loss: 3.348

Epoch: 4 - step: 11 - running loss: 3.317

Epoch: 4 - step: 12 - running loss: 3.328

Epoch: 4 - step: 13 - running loss: 3.341

Epoch: 4 - step: 14 - running loss: 3.311

Epoch: 4 - step: 15 - running loss: 3.358

Epoch: 4 - step: 16 - running loss: 3.308

Epoch: 4 - step: 17 - running loss: 3.295

Epoch: 4 - step: 18 - running loss: 3.287

Epoch: 4 - step: 19 - running loss: 3.283

Epoch: 4 - step: 20 - running loss: 3.322

Epoch: 4 - step: 21 - running loss: 3.319

Epoch: 4 - ste

 50%|█████     | 5/10 [09:08<08:18, 99.60s/it] 

Centers updated - Step : 4


Epoch: 5 - step: 1 - running loss: 3.236

Epoch: 5 - step: 3 - running loss: 3.209

Epoch: 5 - step: 4 - running loss: 3.196

Epoch: 5 - step: 5 - running loss: 3.213

Epoch: 5 - step: 6 - running loss: 3.202

Epoch: 5 - step: 7 - running loss: 3.192

Epoch: 5 - step: 8 - running loss: 3.205

Epoch: 5 - step: 9 - running loss: 3.202

Epoch: 5 - step: 10 - running loss: 3.177

Epoch: 5 - step: 11 - running loss: 3.181

Epoch: 5 - step: 12 - running loss: 3.174

Epoch: 5 - step: 13 - running loss: 3.164

Epoch: 5 - step: 14 - running loss: 3.172

Epoch: 5 - step: 15 - running loss: 3.137

Epoch: 5 - step: 16 - running loss: 3.145

Epoch: 5 - step: 17 - running loss: 3.174

Epoch: 5 - step: 18 - running loss: 3.146

Epoch: 5 - step: 19 - running loss: 3.147

Epoch: 5 - step: 20 - running loss: 3.141

Epoch: 5 - step: 21 - running loss: 3.132

Epoch: 5 - step: 22 - running loss: 3.153

Epoch: 5 - step: 23 - running loss: 3.136

Epoch: 5 - step: 24 - running los

 60%|██████    | 6/10 [10:40<06:27, 96.91s/it]

Epoch 6/10 - Training Loss: 3.1592 - Validation Loss: 84.5804
Centers updated - Step : 5


Epoch: 6 - step: 1 - running loss: 3.069

Epoch: 6 - step: 2 - running loss: 3.088

Epoch: 6 - step: 3 - running loss: 3.072

Epoch: 6 - step: 4 - running loss: 3.058

Epoch: 6 - step: 5 - running loss: 3.085

Epoch: 6 - step: 6 - running loss: 3.072

Epoch: 6 - step: 7 - running loss: 3.056

Epoch: 6 - step: 8 - running loss: 3.048

Epoch: 6 - step: 9 - running loss: 3.045

Epoch: 6 - step: 10 - running loss: 3.039

Epoch: 6 - step: 11 - running loss: 3.013

Epoch: 6 - step: 12 - running loss: 3.020

Epoch: 6 - step: 13 - running loss: 3.040

Epoch: 6 - step: 14 - running loss: 3.034

Epoch: 6 - step: 15 - running loss: 3.014

Epoch: 6 - step: 16 - running loss: 2.984

Epoch: 6 - step: 17 - running loss: 3.024

Epoch: 6 - step: 18 - running loss: 3.008

Epoch: 6 - step: 19 - running loss: 2.995

Epoch: 6 - step: 20 - running loss: 3.000

Epoch: 6 - step: 21 - running loss: 2.994

Epoch: 6 - step

 70%|███████   | 7/10 [12:09<04:43, 94.48s/it]

Epoch 7/10 - Training Loss: 3.0130 - Validation Loss: 103.2355
Centers updated - Step : 6


Epoch: 7 - step: 1 - running loss: 2.907

Epoch: 7 - step: 2 - running loss: 2.942

Epoch: 7 - step: 3 - running loss: 2.897

Epoch: 7 - step: 4 - running loss: 2.910

Epoch: 7 - step: 5 - running loss: 2.924

Epoch: 7 - step: 6 - running loss: 2.933

Epoch: 7 - step: 7 - running loss: 2.909

Epoch: 7 - step: 8 - running loss: 2.909

Epoch: 7 - step: 9 - running loss: 2.897

Epoch: 7 - step: 10 - running loss: 2.857

Epoch: 7 - step: 11 - running loss: 2.878

Epoch: 7 - step: 12 - running loss: 2.872

Epoch: 7 - step: 13 - running loss: 2.867

Epoch: 7 - step: 14 - running loss: 2.892

Epoch: 7 - step: 15 - running loss: 2.862

Epoch: 7 - step: 16 - running loss: 2.874

Epoch: 7 - step: 17 - running loss: 2.853

Epoch: 7 - step: 18 - running loss: 2.866

Epoch: 7 - step: 19 - running loss: 2.854

Epoch: 7 - step: 20 - running loss: 2.831

Epoch: 7 - step: 21 - running loss: 2.831

Epoch: 7 - ste

 80%|████████  | 8/10 [13:38<03:05, 92.65s/it]

Centers updated - Step : 7


Epoch: 8 - step: 1 - running loss: 2.793

Epoch: 8 - step: 2 - running loss: 2.781

Epoch: 8 - step: 3 - running loss: 2.783

Epoch: 8 - step: 4 - running loss: 2.779

Epoch: 8 - step: 5 - running loss: 2.727

Epoch: 8 - step: 6 - running loss: 2.771

Epoch: 8 - step: 7 - running loss: 2.740

Epoch: 8 - step: 8 - running loss: 2.759

Epoch: 8 - step: 9 - running loss: 2.723

Epoch: 8 - step: 10 - running loss: 2.722

Epoch: 8 - step: 11 - running loss: 2.733

Epoch: 8 - step: 12 - running loss: 2.754

Epoch: 8 - step: 13 - running loss: 2.716

Epoch: 8 - step: 14 - running loss: 2.700

Epoch: 8 - step: 15 - running loss: 2.722

Epoch: 8 - step: 16 - running loss: 2.719

Epoch: 8 - step: 17 - running loss: 2.702

Epoch: 8 - step: 18 - running loss: 2.683

Epoch: 8 - step: 19 - running loss: 2.695

Epoch: 8 - step: 20 - running loss: 2.699

Epoch: 8 - step: 21 - running loss: 2.670

Epoch: 8 - step: 22 - running loss: 2.685

Epoch: 8 - step: 23 - running loss

 90%|█████████ | 9/10 [15:08<01:31, 91.62s/it]

Centers updated - Step : 8


Epoch: 9 - step: 1 - running loss: 2.631

Epoch: 9 - step: 2 - running loss: 2.615

Epoch: 9 - step: 3 - running loss: 2.622

Epoch: 9 - step: 4 - running loss: 2.606

Epoch: 9 - step: 5 - running loss: 2.606

Epoch: 9 - step: 6 - running loss: 2.631

Epoch: 9 - step: 7 - running loss: 2.786

Epoch: 9 - step: 8 - running loss: 2.592

Epoch: 9 - step: 9 - running loss: 2.703

Epoch: 9 - step: 10 - running loss: 2.715

Epoch: 9 - step: 11 - running loss: 2.734

Epoch: 9 - step: 12 - running loss: 2.753

Epoch: 9 - step: 13 - running loss: 2.713

Epoch: 9 - step: 14 - running loss: 2.707

Epoch: 9 - step: 15 - running loss: 2.661

Epoch: 9 - step: 16 - running loss: 2.651

Epoch: 9 - step: 17 - running loss: 2.648

Epoch: 9 - step: 18 - running loss: 2.605

Epoch: 9 - step: 19 - running loss: 2.602

Epoch: 9 - step: 20 - running loss: 2.537

Epoch: 9 - step: 21 - running loss: 2.491

Epoch: 9 - step: 22 - running loss: 2.484

Epoch: 9 - step: 23 - running loss

100%|██████████| 10/10 [16:36<00:00, 99.64s/it]

Epoch 10/10 - Training Loss: 2.6091 - Validation Loss: 57.6644
Centers updated - Step : 9




