In [1]:
# -*- coding: utf-8 -*-

import os
import uuid
import time
import random
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import LeakyReLU, ReLU
import numpy as np
from dataset import Dataset
from classifier import Classifier
from util import map_label
import tensorflow.keras.backend as K
import logging


tf.compat.v1.enable_eager_execution()

logging.basicConfig(level=logging.INFO)



In [2]:

class CE_GZSL:

    # optimizers

    generator_optimizer = None
    discriminator_optimizer = None

    # nets

    embedding_net = None
    comparator_net = None
    generator_net = None
    discriminator_net = None

    # params

    discriminator_iterations = None
    generator_noise = None
    gp_weight = None
    instance_weight = None
    class_weight = None
    instance_temperature = None
    class_temperature = None
    synthetic_number = None
    gzsl = True#None
    visual_size = None

    def __init__(self, generator_optimizer: keras.optimizers.Optimizer,
                 discriminator_optimizer: keras.optimizers.Optimizer,
                 args: dict, **kwargs):
        super(CE_GZSL, self).__init__(**kwargs)

        self.embedding(args["visual_size"], args["embedding_hidden"], args["embedding_size"])
        self.comparator(args["embedding_hidden"], args["attribute_size"], args["comparator_hidden"])
        self.generator(args["visual_size"], args["attribute_size"], args["generator_noise"], args["generator_hidden"])
        self.discriminator(args["visual_size"], args["attribute_size"], args["discriminator_hidden"])

        self.generator_optimizer = generator_optimizer
        self.discriminator_optimizer = discriminator_optimizer
        self.instance_weight = args["instance_weight"]
        self.class_weight = args["class_weight"]
        self.instance_temperature = args["instance_temperature"]
        self.class_temperature = args["class_temperature"]
        self.gp_weight = args["gp_weight"]
        self.synthetic_number = args["synthetic_number"]
        self.gzsl = args["gzsl"]
        self.visual_size = args["visual_size"]

        self.discriminator_iterations = args["discriminator_iterations"]
        self.generator_noise = args["generator_noise"]

    def summary(self):

        networks = [self.embedding_net, self.comparator_net, self.generator_net, self.discriminator_net]

        for net in networks:
            net.summary()

    def embedding(self, visual_size, hidden_units, embedded_size):

        inputs = keras.Input(shape=visual_size)
        x = keras.layers.Dense(hidden_units)(inputs)
        embed_h = ReLU(name="embed_h")(x)
        x = keras.layers.Dense(embedded_size)(embed_h)

        embed_z = keras.layers.Lambda(lambda x: K.l2_normalize(x, axis=1), name="embed_z")(x)

        self.embedding_net = keras.Model(inputs, [embed_h, embed_z], name="embedding")

    def comparator(self, embedding_size, attribute_size, hidden_units):

        inputs = keras.Input(shape=embedding_size + attribute_size)
        x = keras.layers.Dense(hidden_units)(inputs)
        x = LeakyReLU(0.2)(x)
        output = keras.layers.Dense(1, name="comp_out")(x)

        self.comparator_net = keras.Model(inputs, output, name="comparator")

    def generator(self, visual_size, attribute_size, noise, hidden_units):

        inputs = keras.Input(shape=attribute_size + noise)
        x = keras.layers.Dense(hidden_units)(inputs)
        x = LeakyReLU(0.2)(x)
        x = keras.layers.Dense(visual_size)(x)
        output = ReLU(name="gen_out")(x)
        self.generator_net = keras.Model(inputs, output, name="generator")

    def discriminator(self, visual_size, attribute_size, hidden_units):

        inputs = keras.Input(shape=visual_size + attribute_size)
        x = keras.layers.Dense(hidden_units)(inputs)
        x = LeakyReLU(0.2)(x)
        output = keras.layers.Dense(1, name="disc_out")(x)
        self.discriminator_net = keras.Model(inputs, output, name="discriminator")

    def d_loss_fn(self, real_logits, fake_logits):

        real_loss = tf.reduce_mean(real_logits)
        fake_loss = tf.reduce_mean(fake_logits)
        return fake_loss - real_loss

    def gradient_penalty(self, batch_size, real_images, fake_images, attribute_data):
        """ Calculates the gradient penalty.

        This loss is calculated on an interpolated image
        and added to the discriminator loss.
        """
        # Get the interpolated image
        alpha = tf.random.uniform(shape=(batch_size, 1))
        alpha = tf.tile(alpha, (1, real_images.shape[1]))
        interpolated = real_images * alpha + (1-alpha) * fake_images

        with tf.GradientTape() as gp_tape:
            gp_tape.watch(interpolated)
            # 1. Get the discriminator output for this interpolated image.
            pred = self.discriminator_net(tf.concat([interpolated, attribute_data], axis=1))

        # 2. Calculate the gradients w.r.t to this interpolated image.
        grads = gp_tape.gradient(pred, [interpolated])[0]
        # 3. Calculate the norm of the gradients.
        norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=1))
        gp = tf.reduce_mean((norm - 1.0) ** 2)
        return gp

    def contrastive_criterion(self, labels, feature_vectors):

        # Compute logits
        anchor_dot_contrast = tf.divide(
            tf.matmul(
                feature_vectors, tf.transpose(feature_vectors)
            ),
            self.instance_temperature,
        )

        logits_max = tf.reduce_max(anchor_dot_contrast, 1, keepdims=True)
        logits = anchor_dot_contrast - logits_max

        # Expand to [batch_size, 1]
        labels = tf.reshape(labels, (-1, 1))
        mask = tf.cast(tf.equal(labels, tf.transpose(labels)), dtype=tf.float32)

        # rosife: all except anchor
        logits_mask = tf.Variable(tf.ones_like(mask))
        indices = tf.reshape(tf.range(0, tf.shape(mask)[0]), (-1, 1))
        indices = tf.concat([indices, indices], axis=1)
        updates = tf.zeros((tf.shape(mask)[0]))
        logits_mask.scatter_nd_update(indices, updates)

        # rosife: positive except anchor
        mask = mask * logits_mask
        single_samples = tf.cast(tf.equal(tf.reduce_sum(mask, axis=1), 0), dtype=tf.float32)

        # compute log_prob
        masked_logits = tf.exp(logits) * logits_mask
        log_prob = logits - tf.math.log(tf.reduce_sum(masked_logits, 1, keepdims=True))

        # compute mean of log-likelihood over positive
        mean_log_prob_pos = tf.reduce_sum(mask * log_prob, 1) / (tf.reduce_sum(mask, 1)+single_samples)

        # loss
        loss = -mean_log_prob_pos * (1 - single_samples)
        loss = tf.reduce_sum(loss) / (tf.cast(tf.shape(loss)[0], dtype=tf.float32) - tf.reduce_sum(single_samples))

        return loss

    def class_scores_for_loop(self, embed, input_label, attribute_seen):

        n_class_seen = attribute_seen.shape[0]

        expand_embed = tf.reshape(tf.tile(tf.expand_dims(embed, 1), [1, n_class_seen, 1]), [embed.shape[0] * n_class_seen, -1])
        expand_att = tf.reshape(tf.tile(tf.expand_dims(attribute_seen, 0), [embed.shape[0], 1, 1]), [embed.shape[0] * n_class_seen, -1])
        all_scores = tf.reshape(tf.divide(self.comparator_net(tf.concat([expand_embed, expand_att], axis=1)), self.class_temperature), [embed.shape[0], n_class_seen])

        score_max = tf.reduce_max(all_scores, axis=1, keepdims=True)
        # normalize the scores for stable training
        scores_norm = all_scores - score_max

        exp_scores = tf.exp(scores_norm)

        mask = tf.one_hot(input_label, n_class_seen)

        log_scores = scores_norm - tf.math.log(tf.reduce_sum(exp_scores, axis=1, keepdims=True))

        cls_loss = -tf.reduce_mean(tf.reduce_sum(mask * log_scores, axis=1) / tf.reduce_sum(mask, axis=1))

        return cls_loss

    def generate_synthetic_features(self, classes, attribute_data):

        nclass = classes.shape[0]
        syn_feature = tf.Variable(tf.zeros((self.synthetic_number * nclass, self.visual_size)))
        syn_label = tf.Variable(tf.zeros((self.synthetic_number * nclass, 1), dtype=classes.dtype))

        for i in range(nclass):

            iclass = classes[i]
            iclass_att = attribute_data[iclass]

            syn_att = tf.repeat(tf.reshape(iclass_att, (1, -1)), self.synthetic_number, axis=0)

            syn_noise = tf.random.normal(shape=(self.synthetic_number, self.generator_noise))

            output = self.generator_net(tf.concat([syn_noise, syn_att], axis=1))

            syn_feature[i * self.synthetic_number:(i+1) * self.synthetic_number, :].assign(output)
            syn_label[i * self.synthetic_number:(i+1) * self.synthetic_number, :].assign(tf.fill((self.synthetic_number, 1), iclass))

        return syn_feature, syn_label

    def train_step(self, real_features, attribute_data, labels, attribute_seen):

        batch_size = tf.shape(real_features)[0]

        d_loss_tracker = keras.metrics.Mean()

        for i in range(self.discriminator_iterations):

            # Get the latent vector
            noise_data = tf.random.normal(shape=(batch_size, self.generator_noise))

            with tf.GradientTape() as tape:

                embed_real, z_real = self.embedding_net(real_features)

                real_ins_contras_loss = self.contrastive_criterion(labels, z_real)
                cls_loss_real = self.class_scores_for_loop(embed_real, labels, attribute_seen)

                # Generate fake images from the latent vector
                fake_features = self.generator_net(tf.concat([noise_data, attribute_data], axis=1))
                # Get the logits for the fake images
                fake_logits = self.discriminator_net(tf.concat([fake_features, attribute_data], axis=1))
                # Get the logits for the real images
                real_logits = self.discriminator_net(tf.concat([real_features, attribute_data], axis=1))

                # Calculate the discriminator loss using the fake and real image logits
                d_cost = self.d_loss_fn(real_logits, fake_logits)
                # Calculate the gradient penalty
                gp = self.gradient_penalty(batch_size, real_features, fake_features, attribute_data)
                # Add the gradient penalty to the original discriminator loss
                d_loss = d_cost + gp * self.gp_weight + real_ins_contras_loss + cls_loss_real

            trainable_variables = self.discriminator_net.trainable_variables + self.embedding_net.trainable_variables + self.comparator_net.trainable_variables

            # Get the gradients w.r.t the discriminator loss
            d_gradient = tape.gradient(d_loss, trainable_variables)

            # Update the weights of the discriminator using the discriminator optimizer
            self.discriminator_optimizer.apply_gradients(zip(d_gradient, trainable_variables))

            d_loss_tracker.update_state(d_loss)

        # Train the generator
        # Get the latent vector
        noise_data = tf.random.normal(shape=(batch_size, self.generator_noise))

        with tf.GradientTape() as tape:

            embed_real, z_real = self.embedding_net(real_features)

            embed_fake, z_fake = self.embedding_net(fake_features)

            fake_ins_contras_loss = self.contrastive_criterion(tf.concat([labels, labels], axis=0), tf.concat([z_fake, z_real], axis=0))
            cls_loss_fake = self.class_scores_for_loop(embed_fake, labels, attribute_seen)

            # Generate fake images using the generator
            fake_features = self.generator_net(tf.concat([noise_data, attribute_data], axis=1))
            # Get the discriminator logits for fake images
            fake_logits = self.discriminator_net(tf.concat([fake_features, attribute_data], axis=1))
            # Calculate the generator loss
            G_cost = -tf.reduce_mean(fake_logits)

            errG = G_cost + self.instance_weight * fake_ins_contras_loss + self.class_weight * cls_loss_fake

        # Get the gradients w.r.t the generator loss
        gen_gradient = tape.gradient(errG, self.generator_net.trainable_variables)
        # Update the weights of the generator using the generator optimizer
        self.generator_optimizer.apply_gradients(zip(gen_gradient, self.generator_net.trainable_variables))

        return d_loss_tracker.result(), errG, cls_loss_fake, cls_loss_real, fake_ins_contras_loss, real_ins_contras_loss

    def fit(self, dataset):

        train_features = dataset.train_features()
        train_attributes = dataset.train_attributes()
        train_labels = dataset.train_labels()

        unseen_classes = tf.constant(dataset.unseen_classes())
        seen_classes = tf.constant(dataset.seen_classes())

        attributes = tf.constant(dataset.attributes())

        train_feat_ds = tf.data.Dataset.from_tensor_slices(train_features)
        train_feat_ds = train_feat_ds.shuffle(buffer_size=train_features.shape[0], seed=seed).batch(batch_size)

        train_att_ds = tf.data.Dataset.from_tensor_slices(train_attributes)
        train_att_ds = train_att_ds.shuffle(buffer_size=train_attributes.shape[0], seed=seed).batch(batch_size)

        train_label_ds = tf.data.Dataset.from_tensor_slices(train_labels)
        train_label_ds = train_label_ds.shuffle(buffer_size=train_labels.shape[0], seed=seed).batch(batch_size)

        attribute_seen = tf.constant(ds.attribute_seen())

        for epoch in range(epochs):

            epoch_start = time.time()

            att_it = train_att_ds.__iter__()
            label_it = train_label_ds.__iter__()
            
            #att_it = tf.compat.v1.data.make_one_shot_iterator(train_att_ds)
            #label_it = tf.compat.v1.data.make_one_shot_iterator(train_label_ds)
            
            d_loss_tracker = keras.metrics.Mean()
            g_loss_tracker = keras.metrics.Mean()

            for step, train_feat in enumerate(train_feat_ds):

                train_att = att_it.next()
                train_label = label_it.next()

                train_label = map_label(train_label, seen_classes)

                d_loss, g_loss, cls_loss_fake, cls_loss_real, fake_ins_contras_loss, real_ins_contras_loss = ce_gzsl.train_step(train_feat, train_att, train_label, attribute_seen)

                d_loss_tracker.update_state(d_loss)
                g_loss_tracker.update_state(g_loss)

            logging.info("main epoch {} - d_loss {:.4f} - g_loss {:.4f} - time: {:.4f}".format(epoch, d_loss_tracker.result(), g_loss_tracker.result(), time.time() - epoch_start))

            # classification

            cls_start = time.time()

            if self.gzsl:

                syn_feature, syn_label = self.generate_synthetic_features(unseen_classes, attributes)

                train_x = tf.concat([train_features, syn_feature], axis=0)
                train_y = tf.concat([train_labels.reshape(-1, 1), syn_label], axis=0)
                num_classes = tf.size(unseen_classes) + tf.size(seen_classes)

                cls = Classifier(train_x, train_y, self.embedding_net, seed, num_classes, 25,
                                 self.synthetic_number, self.visual_size, cls_lr, beta1, beta2, dataset)

                acc_seen, acc_unseen, acc_h = cls.fit()

                logging.info('best acc: seen {:.4f} - unseen {:.4f} - H {:.4f} - time {:.4f}'.format(acc_seen, acc_unseen, acc_h, time.time() - cls_start))

            else:
                syn_feature, syn_label = self.generate_synthetic_features(unseen_classes, attributes)
                labels = map_label(syn_label, unseen_classes)
                num_classes = tf.size(unseen_classes)

                cls = Classifier(syn_feature, labels, self.embedding_net, seed, num_classes, 100,
                                 self.synthetic_number, self.visual_size, cls_lr, beta1, beta2, dataset)

                acc = cls.fit_zsl()

                logging.info('best acc: {:.4f} - time {:.4f}'.format(acc, time.time() - cls_start))

            if (epoch + 1) % checkpoint_epochs == 0:
                logging.info("saving checkpoint: {}".format(exp_path))
                np.save(os.path.join(exp_path, "syn_feature.npy"), syn_feature.numpy())
                np.save(os.path.join(exp_path, "syn_label.npy"), syn_label.numpy())
                self.generator_net.save(os.path.join(exp_path, "generator.h5"))
                self.discriminator_net.save(os.path.join(exp_path, "discriminator.h5"))
                self.comparator_net.save(os.path.join(exp_path, "comparator.h5"))
                self.embedding_net.save(os.path.join(exp_path, "embedding.h5"))


validation = False
preprocessing = False

exp_local_path = "<local_path>"
exp_remote_path = "E:/Sushree/Dataset/data/xlsa17/data"

exp_path = os.path.join(exp_remote_path, "cegzsl_experiments", str(uuid.uuid4()))
os.makedirs(exp_path)

data_local_path = "xlsa17"
data_remote_path = "E:/Sushree/Dataset/data/xlsa17/data"

ds = Dataset(data_remote_path)
ds.read("AWA2", preprocessing=preprocessing, validation=validation)

args = {"visual_size": ds.feature_size(),
        "attribute_size": ds.attribute_size(),
        "embedding_size": 512,
        "embedding_hidden": 2048,
        "comparator_hidden": 2048,
        "generator_hidden": 4096,
        "generator_noise": 1024,
        "discriminator_hidden": 4096,
        "discriminator_iterations": 5,
        "instance_weight": 0.001,
        "class_weight": 0.001,
        "instance_temperature": 0.1,
        "class_temperature": 0.1,
        "gp_weight": 10.0,
        "gzsl": True,
        "synthetic_number": 100}



INFO:root:This dataset does not support GZSL validation
INFO:root:preprocessing: False - validation: False
INFO:root:features: (37322, 2048) - attributes: (50, 85)
INFO:root:seen classes: 40 - unseen classes: 10
INFO:root:training: 23527 - test seen: 5882 - test unseen: 7913


In [3]:
# main training

epochs = 2000
batch_size = 4096
learning_rate = 0.0001
lr_decay = 0.99
lr_decay_epochs = 100
beta1 = 0.5
beta2 = 0.999
seed = 1985
checkpoint_epochs = 100

random.seed(seed)
np.random.seed(seed)
tf.random.set_seed(seed)

steps_per_epoch = int(np.ceil(ds.train_features().shape[0] / batch_size))
lr_decay_steps = lr_decay_epochs * steps_per_epoch

# classifier

cls_lr = 0.001

gen_lr_schedule = keras.optimizers.schedules.ExponentialDecay(initial_learning_rate=learning_rate,
                                                              decay_steps=lr_decay_steps, decay_rate=lr_decay)

disc_lr_schedule = keras.optimizers.schedules.ExponentialDecay(initial_learning_rate=learning_rate,
                                                               decay_steps=lr_decay_steps, decay_rate=lr_decay)

gen_opt = keras.optimizers.Adam(learning_rate=gen_lr_schedule, beta_1=beta1, beta_2=beta2)
disc_opt = keras.optimizers.Adam(learning_rate=disc_lr_schedule, beta_1=beta1, beta_2=beta2)

ce_gzsl = CE_GZSL(gen_opt, disc_opt, args)

ce_gzsl.summary()
ce_gzsl.fit(ds)


Model: "embedding"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 2048)]            0         
                                                                 
 dense (Dense)               (None, 2048)              4196352   
                                                                 
 embed_h (ReLU)              (None, 2048)              0         
                                                                 
 dense_1 (Dense)             (None, 512)               1049088   
                                                                 
 embed_z (Lambda)            (None, 512)               0         
                                                                 
Total params: 5,245,440
Trainable params: 5,245,440
Non-trainable params: 0
_________________________________________________________________
Model: "comparator"
___________________________

INFO:root:main epoch 0 - d_loss -2.1080 - g_loss -7.3806 - time: 27.6669
INFO:root:best acc: seen 0.9127 - unseen 0.0023 - H 0.0046 - time 67.0521
INFO:root:main epoch 1 - d_loss -0.5279 - g_loss -8.2227 - time: 27.9398
INFO:root:best acc: seen 0.9180 - unseen 0.0009 - H 0.0019 - time 67.3670
INFO:root:main epoch 2 - d_loss -2.9393 - g_loss -2.6960 - time: 27.8612
INFO:root:best acc: seen 0.9194 - unseen 0.0007 - H 0.0015 - time 67.6553
INFO:root:main epoch 3 - d_loss -8.6487 - g_loss -1.1355 - time: 28.0231


KeyboardInterrupt: 

In [4]:
train_features = ds.train_features()
print(train_features.shape)

train_attributes = ds.train_attributes()
print(train_attributes, train_attributes.shape)

train_labels = ds.train_labels()
print(train_labels, train_labels.shape)

unseen_classes = ds.unseen_classes()
print(unseen_classes, unseen_classes.shape)

seen_classes = ds.seen_classes()
print(seen_classes, seen_classes.shape)

(23527, 2048)
[[0.00575881 0.003829   0.         ... 0.03639079 0.13208508 0.01148699]
 [0.         0.00507555 0.         ... 0.15675628 0.09070911 0.01425057]
 [0.00575881 0.003829   0.         ... 0.03639079 0.13208508 0.01148699]
 ...
 [0.         0.06587863 0.         ... 0.02108505 0.10218637 0.0332632 ]
 [0.0084177  0.01262655 0.         ... 0.02104426 0.04138143 0.02316552]
 [0.03877321 0.15834625 0.         ... 0.1760296  0.07107783 0.30279845]] (23527, 85)
[42 21 42 ... 39 18 45] (23527,)
[ 6  8 22 23 29 30 33 40 46 49] (10,)
[ 0  1  2  3  4  5  7  9 10 11 12 13 14 15 16 17 18 19 20 21 24 25 26 27
 28 31 32 34 35 36 37 38 39 41 42 43 44 45 47 48] (40,)


In [5]:
seed = 1985
batch_size = 4096
train_feat_ds = tf.data.Dataset.from_tensor_slices(train_features)
train_feat_ds = train_feat_ds.shuffle(buffer_size=train_features.shape[0], seed=seed).batch(batch_size)

train_att_ds = tf.data.Dataset.from_tensor_slices(train_attributes)
train_att_ds = train_att_ds.shuffle(buffer_size=train_attributes.shape[0], seed=seed).batch(batch_size)

train_label_ds = tf.data.Dataset.from_tensor_slices(train_labels)
train_label_ds = train_label_ds.shuffle(buffer_size=train_labels.shape[0], seed=seed).batch(batch_size)

attribute_seen = tf.constant(ds.attribute_seen())
print(attribute_seen, attribute_seen.shape)

tf.Tensor(
[[-0.00375358 -0.00375358 -0.00375358 ...  0.00882092  0.03640974
   0.03145501]
 [ 0.12045617  0.00426584  0.         ...  0.17996307  0.0618086
   0.03495531]
 [ 0.26584458  0.20652363  0.         ...  0.05026821  0.04274552
   0.04915256]
 ...
 [ 0.03877321  0.15834625  0.         ...  0.1760296   0.07107783
   0.30279845]
 [ 0.22516498  0.15266022  0.         ...  0.12733492  0.10009693
   0.01771   ]
 [ 0.19613947  0.1966714   0.         ...  0.01787277  0.06698743
   0.258836  ]], shape=(40, 85), dtype=float32) (40, 85)


# Results


INFO:root:best acc: seen 0.8974 - unseen 0.1791 - H 0.2986 - time 74.9681
INFO:root:main epoch 1782 - d_loss -3.9678 - g_loss -45.5081 - time: 29.2714