In [9]:
import tensorflow as tf
from tensorflow import keras
from keras import layers
import numpy as np
import random

# 1) Load the CIFAR-10 dataset
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0
y_train = y_train.reshape(-1)
y_test = y_test.reshape(-1)
num_classes = 10

In [10]:
def create_discriminator_model(embedding_dim=64):
    inputs = layers.Input(shape=(32, 32, 3))

    x = layers.Conv2D(32, 3, activation='relu')(inputs)
    x = layers.MaxPooling2D(pool_size=(2,2))(x)

    x = layers.Conv2D(64, 3, activation='relu')(x)
    x = layers.MaxPooling2D(pool_size=(2,2))(x)

    x = layers.Conv2D(64, 3, activation='relu')(x)
    x = layers.MaxPooling2D(pool_size=(2,2))(x)

    x = layers.Flatten()(x)

    outputs = layers.Dense(embedding_dim)(x)

    return keras.Model(inputs, outputs)

class TripletLoss(keras.losses.Loss):
    def __init__(self, margin=1.0, **kwargs):
        super().__init__(**kwargs)
        self.margin = margin
    def call(self, y_true, y_pred):
        d = y_pred.shape[1] // 3
        anchor = y_pred[:, 0:d]
        positive = y_pred[:, d:2*d]
        negative = y_pred[:, 2*d:3*d]
        pos_dist = tf.reduce_sum(tf.square(anchor - positive), axis=1)
        neg_dist = tf.reduce_sum(tf.square(anchor - negative), axis=1)
        loss = tf.maximum(pos_dist - neg_dist + self.margin, 0.0)
        return tf.reduce_mean(loss)

def make_triplets(x, y, batch_size=64):
    while True:
        anchors, positives, negatives = [], [], []
        for _ in range(batch_size):
            idx_anchor = random.randint(0, len(x) - 1)
            anchor_img = x[idx_anchor]
            anchor_label = y[idx_anchor]
            same_class_indices = np.where(y == anchor_label)[0]
            diff_class_indices = np.where(y != anchor_label)[0]
            idx_positive = random.choice(same_class_indices)
            idx_negative = random.choice(diff_class_indices)
            anchors.append(anchor_img)
            positives.append(x[idx_positive])
            negatives.append(x[idx_negative])
        yield (
            (
                np.array(anchors, dtype=np.float32),
                np.array(positives, dtype=np.float32),
                np.array(negatives, dtype=np.float32),
            ),
            np.zeros((batch_size,), dtype=np.float32),
        )

def get_triplet_dataset(x, y, batch_size=64):
    def gen():
        return make_triplets(x, y, batch_size)
    output_types = (
        (tf.float32, tf.float32, tf.float32),
        tf.float32
    )
    output_shapes = (
        ((None, 32, 32, 3),
         (None, 32, 32, 3),
         (None, 32, 32, 3)),
        (None,)
    )
    ds = tf.data.Dataset.from_generator(gen, output_types=output_types, output_shapes=output_shapes)
    return ds

def create_triplet_model(base_model):
    anchor_input = layers.Input(shape=(32, 32, 3))
    positive_input = layers.Input(shape=(32, 32, 3))
    negative_input = layers.Input(shape=(32, 32, 3))
    anchor_emb = base_model(anchor_input)
    positive_emb = base_model(positive_input)
    negative_emb = base_model(negative_input)
    concatenated = layers.Concatenate(axis=1)([anchor_emb, positive_emb, negative_emb])
    return keras.Model([anchor_input, positive_input, negative_input], concatenated)

def create_classifier(discriminator_model, freeze=False, num_classes=10):
    discriminator_model.trainable = not freeze
    inputs = keras.Input(shape=(32, 32, 3))
    x = discriminator_model(inputs)
    x = layers.Dense(128, activation='relu')(x)
    x = layers.Dense(128, activation='relu')(x)
    outputs = layers.Dense(num_classes, activation='softmax')(x)
    return keras.Model(inputs, outputs)

In [None]:
# 2) Implement a CNN model (e.g. 5 conv layers) with <=100 features on output
# 3) Train the discriminator using triplet loss
# 4) Save the weights of trained discriminator model
# 5) Build a model for classification of CIFAR-10, starting with the
#    layers from discriminator and one or two fully-connected layers after that
#

embedding_dim = 100
discriminator = create_discriminator_model(embedding_dim)

triplet_model = create_triplet_model(discriminator)
triplet_model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=TripletLoss(margin=0.0)
)

train_triplet_ds = get_triplet_dataset(x_train, y_train, batch_size=64).repeat()
triplet_model.fit(
    train_triplet_ds,
    steps_per_epoch=256,
    epochs=10
)

discriminator.save_weights("discriminator_triplet.weights.h5")

Epoch 1/10
[1m256/256[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 19ms/step - accuracy: 1.4555e-04 - loss: 0.0020
Epoch 2/10
[1m256/256[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 19ms/step - accuracy: 0.0000e+00 - loss: 6.3026e-08
Epoch 3/10
[1m256/256[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 19ms/step - accuracy: 0.0000e+00 - loss: 9.0775e-09
Epoch 4/10
[1m181/256[0m [32m━━━━━━━━━━━━━━[0m[37m━━━━━━[0m [1m1s[0m 19ms/step - accuracy: 0.0000e+00 - loss: 9.7893e-10

KeyboardInterrupt: 

In [17]:
# 6a) Randomly initialized model
#

random_init_disc = create_discriminator_model(embedding_dim)

classifier_a = create_classifier(random_init_disc, freeze=False)
classifier_a.compile(
    optimizer=keras.optimizers.Adam(),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

classifier_a.fit(
    x_train,
    y_train,
    validation_data=(x_test, y_test),
    epochs=3,
    batch_size=64
)

loss_a, acc_a = classifier_a.evaluate(x_test, y_test)


Epoch 1/3
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 11ms/step - accuracy: 0.3072 - loss: 1.8340 - val_accuracy: 0.5166 - val_loss: 1.3336
Epoch 2/3
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 11ms/step - accuracy: 0.5395 - loss: 1.2771 - val_accuracy: 0.5747 - val_loss: 1.1825
Epoch 3/3
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 11ms/step - accuracy: 0.6025 - loss: 1.1023 - val_accuracy: 0.6286 - val_loss: 1.0457
[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 4ms/step - accuracy: 0.6278 - loss: 1.0320


In [18]:
# 6b) Model with weights for convolutional layers loaded from trained discriminator,
#     fully-connected layers initialized randomly

pretrained_disc_b = create_discriminator_model(embedding_dim)
pretrained_disc_b.load_weights("discriminator_triplet.weights.h5")
classifier_b = create_classifier(pretrained_disc_b, freeze=False)

classifier_b.compile(
    optimizer=keras.optimizers.Adam(),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

classifier_b.fit(
    x_train, y_train,
    validation_data=(x_test, y_test),
    epochs=3,
    batch_size=64
)

loss_b, acc_b = classifier_b.evaluate(x_test, y_test)


Epoch 1/3
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 11ms/step - accuracy: 0.1235 - loss: 2.2474 - val_accuracy: 0.4173 - val_loss: 1.5664
Epoch 2/3
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 11ms/step - accuracy: 0.4618 - loss: 1.4581 - val_accuracy: 0.5351 - val_loss: 1.2951
Epoch 3/3
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 11ms/step - accuracy: 0.5450 - loss: 1.2650 - val_accuracy: 0.5664 - val_loss: 1.2068
[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 3ms/step - accuracy: 0.5828 - loss: 1.1855


In [19]:
# 6c) Model with weights for convolutional layers loaded from trained discriminator and
#     freezed, fully-connected layers initialized randomly

pretrained_disc_c = create_discriminator_model(embedding_dim)
pretrained_disc_c.load_weights("discriminator_triplet.weights.h5")
classifier_c = create_classifier(pretrained_disc_c, freeze=True)

classifier_c.compile(
    optimizer=keras.optimizers.Adam(),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

classifier_c.fit(
    x_train, y_train,
    validation_data=(x_test, y_test),
    epochs=3,
    batch_size=64
)

loss_c, acc_c = classifier_c.evaluate(x_test, y_test)


Epoch 1/3
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 5ms/step - accuracy: 0.1001 - loss: 2.3027 - val_accuracy: 0.1000 - val_loss: 2.3026
Epoch 2/3
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 5ms/step - accuracy: 0.1018 - loss: 2.3026 - val_accuracy: 0.1000 - val_loss: 2.3026
Epoch 3/3
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 5ms/step - accuracy: 0.0957 - loss: 2.3027 - val_accuracy: 0.1000 - val_loss: 2.3026
[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 3ms/step - accuracy: 0.1027 - loss: 2.3026


In [22]:
print(f"(a) Random Init Accuracy: {acc_a:.4f}")
print(f"(b) Pretrained Trainable Accuracy: {acc_b:.4f}")
print(f"(c) Pretrained Frozen Accuracy: {acc_c:.4f}")


(a) Random Init Accuracy: 0.6286
(b) Pretrained Trainable Accuracy: 0.5664
(c) Pretrained Frozen Accuracy: 0.1000
