In [1]:
import tensorflow as tf
import matplotlib.pyplot as plt
import seaborn as sns
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.layers import Input, Conv2D, Flatten, Dense, Lambda
from tensorflow.keras.models import Model
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from sklearn.manifold import TSNE
import numpy as np
from sklearn.impute import SimpleImputer


# Load CIFAR-10 data
(x_train, _), (x_test, _) = cifar10.load_data()

# Normalize the data
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0

# Data augmentation
data_generator = ImageDataGenerator(
    rotation_range=2,
    width_shift_range=0.1,
    height_shift_range=0.1,
    horizontal_flip=True,
    vertical_flip=False,
    fill_mode='nearest'
)

augmented_x_train = data_generator.flow(x_train, batch_size=len(x_train), shuffle=False).next()



In [2]:
# Encoder network
def get_encoder_model(input_shape):
    inputs = Input(shape=input_shape)
    x = Conv2D(64, (3, 3), activation='relu')(inputs)
    x = Conv2D(128, (3, 3), activation='relu')(x)
    x = Flatten()(x)
    embeddings = Dense(128)(x)  # 128-dimensional embedding
    encoder = Model(inputs, embeddings, name='encoder')
    return encoder

input_shape = (32, 32, 3)
encoder = get_encoder_model(input_shape)

normalize_layer = Lambda(lambda x: tf.math.l2_normalize(x, axis=1))

input_positive = Input(shape=input_shape, name='input_positive')
input_negative = Input(shape=input_shape, name='input_negative')

embedding_positive = encoder(input_positive)
embedding_negative = encoder(input_negative)

normalized_embedding_positive = normalize_layer(embedding_positive)
normalized_embedding_negative = normalize_layer(embedding_negative)

merged_output = tf.concat([normalized_embedding_positive, normalized_embedding_negative], axis=1)

siamese_model = Model(inputs=[input_positive, input_negative], outputs=merged_output, name='siamese_model')

# def contrastive_loss(y_true, y_pred):
#     margin = 1
#     y_true = tf.cast(y_true, y_pred.dtype)
    
#     positive_pair_loss = y_true * tf.math.square(tf.norm(y_pred[:, :128] - y_pred[:, 128:], axis=1))
#     negative_pair_loss = (1 - y_true) * tf.math.square(tf.maximum(margin - tf.norm(y_pred[:, :128] - y_pred[:, 128:], axis=1), 0))
    
#     return 0.5 * tf.reduce_mean(positive_pair_loss + negative_pair_loss)


# For debugging
def isnan(x):
    return x != x

# Change in contrastive_loss
def contrastive_loss(y_true, y_pred):
    margin = 1
    y_true = tf.cast(y_true, y_pred.dtype)
    
    distance = tf.norm(y_pred[:, :128] - y_pred[:, 128:], axis=1)
    positive_pair_loss = y_true * tf.math.square(distance)
    negative_pair_loss = (1 - y_true) * tf.math.square(tf.maximum(margin - distance, 0))
    
    loss = 0.5 * tf.reduce_mean(positive_pair_loss + negative_pair_loss)
    
    # Debugging NaN loss
    if isnan(loss):
        tf.print("NaN detected in loss")
    
    return loss


def sample_pairs(data, batch_size):
    num_samples = len(data)
    indices = np.arange(num_samples)
    
    positive_idx_1 = np.random.choice(indices, size=batch_size)
    positive_idx_2 = np.random.choice(indices, size=batch_size)
    positive_pairs = np.stack([data[positive_idx_1], data[positive_idx_2]], axis=1)

    negative_idx_1 = np.random.choice(indices, size=batch_size)
    negative_idx_2 = np.random.choice(indices, size=batch_size)
    while np.any(negative_idx_1 == negative_idx_2):
        negative_idx_2 = np.random.choice(indices, size=batch_size)
    negative_pairs = np.stack([data[negative_idx_1], data[negative_idx_2]], axis=1)

    return positive_pairs, negative_pairs



In [None]:

siamese_model.compile(optimizer='adam', loss=contrastive_loss)

epochs = 2
batch_size = 32
evaluate_interval = 1
train_losses = []

for epoch in range(epochs):
    for step in range(len(x_train) // batch_size):
        positive_pairs, negative_pairs = sample_pairs(x_train, batch_size)
        positive_labels = np.ones((batch_size,))
        negative_labels = np.zeros((batch_size,))
        
        combined_pairs = np.vstack([positive_pairs, negative_pairs])
        combined_labels = np.hstack([positive_labels, negative_labels])
        
        loss = siamese_model.train_on_batch([combined_pairs[:, 0], combined_pairs[:, 1]], combined_labels)
        
        train_losses.append(loss)
        
        print(f"Epoch {epoch+1}/{epochs}, Step {step+1}/{len(x_train)//batch_size}, Loss: {loss:.4f}", end='\r')

    if (epoch + 1) % evaluate_interval == 0:
        avg_train_loss = sum(train_losses) / len(train_losses)
        print(f"\nEpoch {epoch+1}/{epochs}, Average Training Loss: {avg_train_loss:.4f}")
        
        train_losses = []
        
        
        encoder = siamese_model.get_layer('encoder')

        # Predict on the test set to get embeddings
        embeddings = encoder.predict(x_test)
        
        # Check for NaN Values and Handle them
        if np.isnan(embeddings).any():
            print("NaN values detected in the embeddings. Replacing with zeros.")
            embeddings = np.nan_to_num(embeddings)


            # Option A: Remove Rows with NaN Values
            # embeddings = embeddings[~np.isnan(embeddings).any(axis=1)]

            # Option B: Impute NaN Values
            imputer = SimpleImputer(strategy='mean')
            embeddings = imputer.fit_transform(embeddings)

        tsne = TSNE(n_components=2, random_state=42)
        embeddings_2D = tsne.fit_transform(embeddings)

        # Visualization
        plt.figure(figsize=(8, 6))
        sns.scatterplot(x=embeddings_2D[:, 0], y=embeddings_2D[:, 1], legend='full')
        plt.show()



NaN detected in loss1562, Loss: 0.1295
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562,

NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Lo

NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Lo

NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss1562, Loss: nan
NaN detected in loss/1562, Loss: nan
NaN detected in loss/1562, Loss: nan
NaN detected in loss/1562, Loss: nan
NaN detected in loss/1562, Loss: nan
NaN detected in loss/1562, Loss: nan
NaN detected in loss/1562, Loss: nan
NaN detected in loss/1562, Loss: nan
NaN detected in loss/1562, Loss: nan
NaN detected in loss/1562, Loss: nan
NaN detected in loss/1562, Loss: nan
NaN detected in loss/1562, Loss: nan
NaN detected in loss/1562, Loss: nan
NaN detected in loss/1562, Loss: nan
NaN detected in loss/1562, Loss: nan
NaN detected in loss/1562, Loss: nan
NaN detected in loss/1562, Loss: nan
NaN detected in loss/1562, Loss: nan
NaN detected in loss/1562, Loss: nan
NaN detect

NaN detected in loss/1562, Loss: nan
NaN detected in loss/1562, Loss: nan
NaN detected in loss/1562, Loss: nan
NaN detected in loss/1562, Loss: nan
NaN detected in loss/1562, Loss: nan
NaN detected in loss/1562, Loss: nan
NaN detected in loss/1562, Loss: nan
NaN detected in loss/1562, Loss: nan
NaN detected in loss/1562, Loss: nan
NaN detected in loss/1562, Loss: nan
NaN detected in loss/1562, Loss: nan
NaN detected in loss/1562, Loss: nan
NaN detected in loss/1562, Loss: nan
NaN detected in loss/1562, Loss: nan
NaN detected in loss/1562, Loss: nan
NaN detected in loss/1562, Loss: nan
NaN detected in loss/1562, Loss: nan
NaN detected in loss/1562, Loss: nan
NaN detected in loss/1562, Loss: nan
NaN detected in loss/1562, Loss: nan
NaN detected in loss/1562, Loss: nan
NaN detected in loss/1562, Loss: nan
NaN detected in loss/1562, Loss: nan
NaN detected in loss/1562, Loss: nan
NaN detected in loss/1562, Loss: nan
NaN detected in loss/1562, Loss: nan
NaN detected in loss/1562, Loss: nan
N

NaN detected in loss/1562, Loss: nan
NaN detected in loss/1562, Loss: nan
NaN detected in loss/1562, Loss: nan
NaN detected in loss/1562, Loss: nan
NaN detected in loss/1562, Loss: nan
NaN detected in loss/1562, Loss: nan
NaN detected in loss/1562, Loss: nan
NaN detected in loss/1562, Loss: nan
NaN detected in loss/1562, Loss: nan
NaN detected in loss/1562, Loss: nan
NaN detected in loss/1562, Loss: nan
NaN detected in loss/1562, Loss: nan
NaN detected in loss/1562, Loss: nan
NaN detected in loss/1562, Loss: nan
NaN detected in loss/1562, Loss: nan
NaN detected in loss/1562, Loss: nan
NaN detected in loss/1562, Loss: nan
NaN detected in loss/1562, Loss: nan
NaN detected in loss/1562, Loss: nan
NaN detected in loss/1562, Loss: nan
NaN detected in loss/1562, Loss: nan
NaN detected in loss/1562, Loss: nan
NaN detected in loss/1562, Loss: nan
NaN detected in loss/1562, Loss: nan
NaN detected in loss/1562, Loss: nan
NaN detected in loss/1562, Loss: nan
NaN detected in loss/1562, Loss: nan
N

In [None]:

# Plot the training loss curve
plt.figure(figsize=(10, 6))
plt.plot(range(1, len(train_losses) + 1), train_losses, label='Training Loss')
plt.xlabel('Training Step')
plt.ylabel('Loss')
plt.title('Training Loss Curve')
plt.legend()
plt.show()
