In [1]:
# -*- coding: utf-8 -*-

import os
import uuid
import time
import random
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import LeakyReLU, ReLU
import numpy as np
from dataset import Dataset
from classifier import Classifier
from util import map_label
import tensorflow.keras.backend as K
import logging


tf.compat.v1.enable_eager_execution()

logging.basicConfig(level=logging.INFO)



In [2]:

class CE_GZSL:

    # optimizers

    generator_optimizer = None
    discriminator_optimizer = None

    # nets

    embedding_net = None
    comparator_net = None
    generator_net = None
    discriminator_net = None

    # params

    discriminator_iterations = None
    generator_noise = None
    gp_weight = None
    instance_weight = None
    class_weight = None
    instance_temperature = None
    class_temperature = None
    synthetic_number = None
    gzsl = True#None
    visual_size = None

    def __init__(self, generator_optimizer: keras.optimizers.Optimizer,
                 discriminator_optimizer: keras.optimizers.Optimizer,
                 args: dict, **kwargs):
        super(CE_GZSL, self).__init__(**kwargs)

        self.embedding(args["visual_size"], args["embedding_hidden"], args["embedding_size"])
        self.comparator(args["embedding_hidden"], args["attribute_size"], args["comparator_hidden"])
        self.generator(args["visual_size"], args["attribute_size"], args["generator_noise"], args["generator_hidden"])
        self.discriminator(args["visual_size"], args["attribute_size"], args["discriminator_hidden"])

        self.generator_optimizer = generator_optimizer
        self.discriminator_optimizer = discriminator_optimizer
        self.instance_weight = args["instance_weight"]
        self.class_weight = args["class_weight"]
        self.instance_temperature = args["instance_temperature"]
        self.class_temperature = args["class_temperature"]
        self.gp_weight = args["gp_weight"]
        self.synthetic_number = args["synthetic_number"]
        self.gzsl = args["gzsl"]
        self.visual_size = args["visual_size"]

        self.discriminator_iterations = args["discriminator_iterations"]
        self.generator_noise = args["generator_noise"]

    def summary(self):

        networks = [self.embedding_net, self.comparator_net, self.generator_net, self.discriminator_net]

        for net in networks:
            net.summary()

    def embedding(self, visual_size, hidden_units, embedded_size):

        inputs = keras.Input(shape=visual_size)
        x = keras.layers.Dense(hidden_units)(inputs)
        embed_h = ReLU(name="embed_h")(x)
        x = keras.layers.Dense(embedded_size)(embed_h)

        embed_z = keras.layers.Lambda(lambda x: K.l2_normalize(x, axis=1), name="embed_z")(x)

        self.embedding_net = keras.Model(inputs, [embed_h, embed_z], name="embedding")

    def comparator(self, embedding_size, attribute_size, hidden_units):

        inputs = keras.Input(shape=embedding_size + attribute_size)
        x = keras.layers.Dense(hidden_units)(inputs)
        x = LeakyReLU(0.2)(x)
        output = keras.layers.Dense(1, name="comp_out")(x)

        self.comparator_net = keras.Model(inputs, output, name="comparator")

    def generator(self, visual_size, attribute_size, noise, hidden_units):

        inputs = keras.Input(shape=attribute_size + noise)
        x = keras.layers.Dense(hidden_units)(inputs)
        x = LeakyReLU(0.2)(x)
        x = keras.layers.Dense(visual_size)(x)
        output = ReLU(name="gen_out")(x)
        self.generator_net = keras.Model(inputs, output, name="generator")

    def discriminator(self, visual_size, attribute_size, hidden_units):

        inputs = keras.Input(shape=visual_size + attribute_size)
        x = keras.layers.Dense(hidden_units)(inputs)
        x = LeakyReLU(0.2)(x)
        output = keras.layers.Dense(1, name="disc_out")(x)
        self.discriminator_net = keras.Model(inputs, output, name="discriminator")

    def d_loss_fn(self, real_logits, fake_logits):

        real_loss = tf.reduce_mean(real_logits)
        fake_loss = tf.reduce_mean(fake_logits)
        return fake_loss - real_loss

    def gradient_penalty(self, batch_size, real_images, fake_images, attribute_data):
        """ Calculates the gradient penalty.

        This loss is calculated on an interpolated image
        and added to the discriminator loss.
        """
        # Get the interpolated image
        alpha = tf.random.uniform(shape=(batch_size, 1))
        alpha = tf.tile(alpha, (1, real_images.shape[1]))
        interpolated = real_images * alpha + (1-alpha) * fake_images

        with tf.GradientTape() as gp_tape:
            gp_tape.watch(interpolated)
            # 1. Get the discriminator output for this interpolated image.
            pred = self.discriminator_net(tf.concat([interpolated, attribute_data], axis=1))

        # 2. Calculate the gradients w.r.t to this interpolated image.
        grads = gp_tape.gradient(pred, [interpolated])[0]
        # 3. Calculate the norm of the gradients.
        norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=1))
        gp = tf.reduce_mean((norm - 1.0) ** 2)
        return gp

    def contrastive_criterion(self, labels, feature_vectors):

        # Compute logits
        anchor_dot_contrast = tf.divide(
            tf.matmul(
                feature_vectors, tf.transpose(feature_vectors)
            ),
            self.instance_temperature,
        )

        logits_max = tf.reduce_max(anchor_dot_contrast, 1, keepdims=True)
        logits = anchor_dot_contrast - logits_max

        # Expand to [batch_size, 1]
        labels = tf.reshape(labels, (-1, 1))
        mask = tf.cast(tf.equal(labels, tf.transpose(labels)), dtype=tf.float32)

        # rosife: all except anchor
        logits_mask = tf.Variable(tf.ones_like(mask))
        indices = tf.reshape(tf.range(0, tf.shape(mask)[0]), (-1, 1))
        indices = tf.concat([indices, indices], axis=1)
        updates = tf.zeros((tf.shape(mask)[0]))
        logits_mask.scatter_nd_update(indices, updates)

        # rosife: positive except anchor
        mask = mask * logits_mask
        single_samples = tf.cast(tf.equal(tf.reduce_sum(mask, axis=1), 0), dtype=tf.float32)

        # compute log_prob
        masked_logits = tf.exp(logits) * logits_mask
        log_prob = logits - tf.math.log(tf.reduce_sum(masked_logits, 1, keepdims=True))

        # compute mean of log-likelihood over positive
        mean_log_prob_pos = tf.reduce_sum(mask * log_prob, 1) / (tf.reduce_sum(mask, 1)+single_samples)

        # loss
        loss = -mean_log_prob_pos * (1 - single_samples)
        loss = tf.reduce_sum(loss) / (tf.cast(tf.shape(loss)[0], dtype=tf.float32) - tf.reduce_sum(single_samples))

        return loss

    def class_scores_for_loop(self, embed, input_label, attribute_seen):

        n_class_seen = attribute_seen.shape[0]

        expand_embed = tf.reshape(tf.tile(tf.expand_dims(embed, 1), [1, n_class_seen, 1]), [embed.shape[0] * n_class_seen, -1])
        expand_att = tf.reshape(tf.tile(tf.expand_dims(attribute_seen, 0), [embed.shape[0], 1, 1]), [embed.shape[0] * n_class_seen, -1])
        all_scores = tf.reshape(tf.divide(self.comparator_net(tf.concat([expand_embed, expand_att], axis=1)), self.class_temperature), [embed.shape[0], n_class_seen])

        score_max = tf.reduce_max(all_scores, axis=1, keepdims=True)
        # normalize the scores for stable training
        scores_norm = all_scores - score_max

        exp_scores = tf.exp(scores_norm)

        mask = tf.one_hot(input_label, n_class_seen)

        log_scores = scores_norm - tf.math.log(tf.reduce_sum(exp_scores, axis=1, keepdims=True))

        cls_loss = -tf.reduce_mean(tf.reduce_sum(mask * log_scores, axis=1) / tf.reduce_sum(mask, axis=1))

        return cls_loss

    def generate_synthetic_features(self, classes, attribute_data):

        nclass = classes.shape[0]
        syn_feature = tf.Variable(tf.zeros((self.synthetic_number * nclass, self.visual_size)))
        syn_label = tf.Variable(tf.zeros((self.synthetic_number * nclass, 1), dtype=classes.dtype))

        for i in range(nclass):

            iclass = classes[i]
            iclass_att = attribute_data[iclass]

            syn_att = tf.repeat(tf.reshape(iclass_att, (1, -1)), self.synthetic_number, axis=0)

            syn_noise = tf.random.normal(shape=(self.synthetic_number, self.generator_noise))

            output = self.generator_net(tf.concat([syn_noise, syn_att], axis=1))

            syn_feature[i * self.synthetic_number:(i+1) * self.synthetic_number, :].assign(output)
            syn_label[i * self.synthetic_number:(i+1) * self.synthetic_number, :].assign(tf.fill((self.synthetic_number, 1), iclass))

        return syn_feature, syn_label

    def train_step(self, real_features, attribute_data, labels, attribute_seen):

        batch_size = tf.shape(real_features)[0]

        d_loss_tracker = keras.metrics.Mean()

        for i in range(self.discriminator_iterations):

            # Get the latent vector
            noise_data = tf.random.normal(shape=(batch_size, self.generator_noise))

            with tf.GradientTape() as tape:

                embed_real, z_real = self.embedding_net(real_features)

                real_ins_contras_loss = self.contrastive_criterion(labels, z_real)
                cls_loss_real = self.class_scores_for_loop(embed_real, labels, attribute_seen)

                # Generate fake images from the latent vector
                fake_features = self.generator_net(tf.concat([noise_data, attribute_data], axis=1))
                # Get the logits for the fake images
                fake_logits = self.discriminator_net(tf.concat([fake_features, attribute_data], axis=1))
                # Get the logits for the real images
                real_logits = self.discriminator_net(tf.concat([real_features, attribute_data], axis=1))

                # Calculate the discriminator loss using the fake and real image logits
                d_cost = self.d_loss_fn(real_logits, fake_logits)
                # Calculate the gradient penalty
                gp = self.gradient_penalty(batch_size, real_features, fake_features, attribute_data)
                # Add the gradient penalty to the original discriminator loss
                d_loss = d_cost + gp * self.gp_weight + real_ins_contras_loss + cls_loss_real

            trainable_variables = self.discriminator_net.trainable_variables + self.embedding_net.trainable_variables + self.comparator_net.trainable_variables

            # Get the gradients w.r.t the discriminator loss
            d_gradient = tape.gradient(d_loss, trainable_variables)

            # Update the weights of the discriminator using the discriminator optimizer
            self.discriminator_optimizer.apply_gradients(zip(d_gradient, trainable_variables))

            d_loss_tracker.update_state(d_loss)

        # Train the generator
        # Get the latent vector
        noise_data = tf.random.normal(shape=(batch_size, self.generator_noise))

        with tf.GradientTape() as tape:

            embed_real, z_real = self.embedding_net(real_features)

            embed_fake, z_fake = self.embedding_net(fake_features)

            fake_ins_contras_loss = self.contrastive_criterion(tf.concat([labels, labels], axis=0), tf.concat([z_fake, z_real], axis=0))
            cls_loss_fake = self.class_scores_for_loop(embed_fake, labels, attribute_seen)

            # Generate fake images using the generator
            fake_features = self.generator_net(tf.concat([noise_data, attribute_data], axis=1))
            # Get the discriminator logits for fake images
            fake_logits = self.discriminator_net(tf.concat([fake_features, attribute_data], axis=1))
            # Calculate the generator loss
            G_cost = -tf.reduce_mean(fake_logits)

            errG = G_cost + self.instance_weight * fake_ins_contras_loss + self.class_weight * cls_loss_fake

        # Get the gradients w.r.t the generator loss
        gen_gradient = tape.gradient(errG, self.generator_net.trainable_variables)
        # Update the weights of the generator using the generator optimizer
        self.generator_optimizer.apply_gradients(zip(gen_gradient, self.generator_net.trainable_variables))

        return d_loss_tracker.result(), errG, cls_loss_fake, cls_loss_real, fake_ins_contras_loss, real_ins_contras_loss

    def fit(self, dataset):

        train_features = dataset.train_features()
        train_attributes = dataset.train_attributes()
        train_labels = dataset.train_labels()

        unseen_classes = tf.constant(dataset.unseen_classes())
        seen_classes = tf.constant(dataset.seen_classes())

        attributes = tf.constant(dataset.attributes())

        train_feat_ds = tf.data.Dataset.from_tensor_slices(train_features)
        train_feat_ds = train_feat_ds.shuffle(buffer_size=train_features.shape[0], seed=seed).batch(batch_size)

        train_att_ds = tf.data.Dataset.from_tensor_slices(train_attributes)
        train_att_ds = train_att_ds.shuffle(buffer_size=train_attributes.shape[0], seed=seed).batch(batch_size)

        train_label_ds = tf.data.Dataset.from_tensor_slices(train_labels)
        train_label_ds = train_label_ds.shuffle(buffer_size=train_labels.shape[0], seed=seed).batch(batch_size)

        attribute_seen = tf.constant(ds.attribute_seen())

        for epoch in range(epochs):

            epoch_start = time.time()

            att_it = train_att_ds.__iter__()
            label_it = train_label_ds.__iter__()
            
            #att_it = tf.compat.v1.data.make_one_shot_iterator(train_att_ds)
            #label_it = tf.compat.v1.data.make_one_shot_iterator(train_label_ds)
            
            d_loss_tracker = keras.metrics.Mean()
            g_loss_tracker = keras.metrics.Mean()

            for step, train_feat in enumerate(train_feat_ds):

                train_att = att_it.next()
                train_label = label_it.next()

                train_label = map_label(train_label, seen_classes)

                d_loss, g_loss, cls_loss_fake, cls_loss_real, fake_ins_contras_loss, real_ins_contras_loss = ce_gzsl.train_step(train_feat, train_att, train_label, attribute_seen)

                d_loss_tracker.update_state(d_loss)
                g_loss_tracker.update_state(g_loss)

            logging.info("main epoch {} - d_loss {:.4f} - g_loss {:.4f} - time: {:.4f}".format(epoch, d_loss_tracker.result(), g_loss_tracker.result(), time.time() - epoch_start))

            # classification

            cls_start = time.time()

            if self.gzsl:

                syn_feature, syn_label = self.generate_synthetic_features(unseen_classes, attributes)

                train_x = tf.concat([train_features, syn_feature], axis=0)
                train_y = tf.concat([train_labels.reshape(-1, 1), syn_label], axis=0)
                num_classes = tf.size(unseen_classes) + tf.size(seen_classes)

                cls = Classifier(train_x, train_y, self.embedding_net, seed, num_classes, 25,
                                 self.synthetic_number, self.visual_size, cls_lr, beta1, beta2, dataset)

                acc_seen, acc_unseen, acc_h = cls.fit()

                logging.info('best acc: seen {:.4f} - unseen {:.4f} - H {:.4f} - time {:.4f}'.format(acc_seen, acc_unseen, acc_h, time.time() - cls_start))

            else:
                syn_feature, syn_label = self.generate_synthetic_features(unseen_classes, attributes)
                labels = map_label(syn_label, unseen_classes)
                num_classes = tf.size(unseen_classes)

                cls = Classifier(syn_feature, labels, self.embedding_net, seed, num_classes, 100,
                                 self.synthetic_number, self.visual_size, cls_lr, beta1, beta2, dataset)

                acc = cls.fit_zsl()

                logging.info('best acc: {:.4f} - time {:.4f}'.format(acc, time.time() - cls_start))

            if (epoch + 1) % checkpoint_epochs == 0:
                logging.info("saving checkpoint: {}".format(exp_path))
                np.save(os.path.join(exp_path, "syn_feature.npy"), syn_feature.numpy())
                np.save(os.path.join(exp_path, "syn_label.npy"), syn_label.numpy())
                self.generator_net.save(os.path.join(exp_path, "generator.h5"))
                self.discriminator_net.save(os.path.join(exp_path, "discriminator.h5"))
                self.comparator_net.save(os.path.join(exp_path, "comparator.h5"))
                self.embedding_net.save(os.path.join(exp_path, "embedding.h5"))


validation = False
preprocessing = False

exp_local_path = "<local_path>"
exp_remote_path = "E:/Sushree/Dataset/data/xlsa17/data"

exp_path = os.path.join(exp_remote_path, "cegzsl_experiments", str(uuid.uuid4()))
os.makedirs(exp_path)

data_local_path = "xlsa17"
data_remote_path = "E:/Sushree/Dataset/data/xlsa17/data"

ds = Dataset(data_remote_path)
ds.read("AWA2", preprocessing=preprocessing, validation=validation)

args = {"visual_size": ds.feature_size(),
        "attribute_size": ds.attribute_size(),
        "embedding_size": 512,
        "embedding_hidden": 2048,
        "comparator_hidden": 2048,
        "generator_hidden": 4096,
        "generator_noise": 1024,
        "discriminator_hidden": 4096,
        "discriminator_iterations": 5,
        "instance_weight": 0.001,
        "class_weight": 0.001,
        "instance_temperature": 0.1,
        "class_temperature": 0.1,
        "gp_weight": 10.0,
        "gzsl": True,
        "synthetic_number": 100}



INFO:root:This dataset does not support GZSL validation
INFO:root:preprocessing: False - validation: False
INFO:root:features: (37322, 2048) - attributes: (50, 85)
INFO:root:seen classes: 40 - unseen classes: 10
INFO:root:training: 23527 - test seen: 5882 - test unseen: 7913


In [3]:
# main training

epochs = 2000
batch_size = 4096
learning_rate = 0.0001
lr_decay = 0.99
lr_decay_epochs = 100
beta1 = 0.5
beta2 = 0.999
seed = 1985
checkpoint_epochs = 100

random.seed(seed)
np.random.seed(seed)
tf.random.set_seed(seed)

steps_per_epoch = int(np.ceil(ds.train_features().shape[0] / batch_size))
lr_decay_steps = lr_decay_epochs * steps_per_epoch

# classifier

cls_lr = 0.001

gen_lr_schedule = keras.optimizers.schedules.ExponentialDecay(initial_learning_rate=learning_rate,
                                                              decay_steps=lr_decay_steps, decay_rate=lr_decay)

disc_lr_schedule = keras.optimizers.schedules.ExponentialDecay(initial_learning_rate=learning_rate,
                                                               decay_steps=lr_decay_steps, decay_rate=lr_decay)

gen_opt = keras.optimizers.Adam(learning_rate=gen_lr_schedule, beta_1=beta1, beta_2=beta2)
disc_opt = keras.optimizers.Adam(learning_rate=disc_lr_schedule, beta_1=beta1, beta_2=beta2)

ce_gzsl = CE_GZSL(gen_opt, disc_opt, args)

ce_gzsl.summary()
ce_gzsl.fit(ds)


INFO:root:main epoch 0 - d_loss -2.1073 - g_loss -7.3806 - time: 28.3429
INFO:root:best acc: seen 0.9116 - unseen 0.0023 - H 0.0046 - time 67.8609
INFO:root:main epoch 1 - d_loss -0.5265 - g_loss -8.2228 - time: 28.0057
INFO:root:best acc: seen 0.9213 - unseen 0.0006 - H 0.0012 - time 68.0300
INFO:root:main epoch 2 - d_loss -2.9402 - g_loss -2.6959 - time: 28.0276
INFO:root:best acc: seen 0.9185 - unseen 0.0011 - H 0.0023 - time 67.8010
INFO:root:main epoch 3 - d_loss -8.6505 - g_loss -1.1355 - time: 27.8373
INFO:root:best acc: seen 0.9233 - unseen 0.0009 - H 0.0017 - time 66.9239
INFO:root:main epoch 4 - d_loss -12.4219 - g_loss -2.9061 - time: 27.8640
INFO:root:best acc: seen 0.9228 - unseen 0.0013 - H 0.0026 - time 67.2837
INFO:root:main epoch 5 - d_loss -14.6206 - g_loss -3.8656 - time: 27.8955
INFO:root:best acc: seen 0.9269 - unseen 0.0006 - H 0.0011 - time 66.9169
INFO:root:main epoch 6 - d_loss -15.7277 - g_loss -4.6157 - time: 27.7669
INFO:root:best acc: seen 0.9268 - unseen 0

INFO:root:best acc: seen 0.9205 - unseen 0.0546 - H 0.1031 - time 65.8755
INFO:root:main epoch 56 - d_loss -9.9698 - g_loss -5.1743 - time: 27.4286
INFO:root:best acc: seen 0.9208 - unseen 0.0433 - H 0.0828 - time 66.5309
INFO:root:main epoch 57 - d_loss -10.3137 - g_loss -5.1676 - time: 27.4428
INFO:root:best acc: seen 0.9209 - unseen 0.0344 - H 0.0663 - time 65.5390
INFO:root:main epoch 58 - d_loss -10.5940 - g_loss -5.1897 - time: 27.3924
INFO:root:best acc: seen 0.9221 - unseen 0.0391 - H 0.0751 - time 66.0690
INFO:root:main epoch 59 - d_loss -10.8650 - g_loss -5.1772 - time: 27.5389
INFO:root:best acc: seen 0.9234 - unseen 0.0385 - H 0.0739 - time 65.4979
INFO:root:main epoch 60 - d_loss -11.1004 - g_loss -5.3471 - time: 27.4559
INFO:root:best acc: seen 0.9239 - unseen 0.0500 - H 0.0948 - time 66.5765
INFO:root:main epoch 61 - d_loss -11.3185 - g_loss -5.4649 - time: 27.5032
INFO:root:best acc: seen 0.9223 - unseen 0.0547 - H 0.1033 - time 66.7466
INFO:root:main epoch 62 - d_loss 















INFO:root:main epoch 100 - d_loss -12.7260 - g_loss -14.7030 - time: 27.4285
INFO:root:best acc: seen 0.9196 - unseen 0.1048 - H 0.1882 - time 66.1652
INFO:root:main epoch 101 - d_loss -12.6855 - g_loss -14.7483 - time: 27.4516
INFO:root:best acc: seen 0.9205 - unseen 0.1148 - H 0.2041 - time 66.2366
INFO:root:main epoch 102 - d_loss -12.6311 - g_loss -15.3569 - time: 27.4279
INFO:root:best acc: seen 0.9196 - unseen 0.1214 - H 0.2145 - time 66.1177
INFO:root:main epoch 103 - d_loss -12.5600 - g_loss -15.6555 - time: 27.6158
INFO:root:best acc: seen 0.9183 - unseen 0.1110 - H 0.1980 - time 66.5249
INFO:root:main epoch 104 - d_loss -12.5255 - g_loss -15.8092 - time: 27.5521
INFO:root:best acc: seen 0.9207 - unseen 0.1374 - H 0.2392 - time 65.9133
INFO:root:main epoch 105 - d_loss -12.4841 - g_loss -16.4400 - time: 27.5875
INFO:root:best acc: seen 0.9162 - unseen 0.1250 - H 0.2200 - time 66.0771
INFO:root:main epoch 106 - d_loss -12.4250 - g_loss -16.2460 - time: 27.5381
INFO:root:best ac

INFO:root:best acc: seen 0.9181 - unseen 0.1399 - H 0.2429 - time 64.8953
INFO:root:main epoch 154 - d_loss -11.0981 - g_loss -28.3280 - time: 27.3971
INFO:root:best acc: seen 0.9196 - unseen 0.1573 - H 0.2687 - time 65.3535
INFO:root:main epoch 155 - d_loss -11.0750 - g_loss -28.3126 - time: 27.5474
INFO:root:best acc: seen 0.9198 - unseen 0.1597 - H 0.2722 - time 66.1517
INFO:root:main epoch 156 - d_loss -11.0554 - g_loss -28.8704 - time: 27.4306
INFO:root:best acc: seen 0.9220 - unseen 0.1575 - H 0.2690 - time 65.3449
INFO:root:main epoch 157 - d_loss -11.0213 - g_loss -28.8137 - time: 27.5845
INFO:root:best acc: seen 0.9209 - unseen 0.1414 - H 0.2452 - time 65.5206
INFO:root:main epoch 158 - d_loss -10.9842 - g_loss -28.9496 - time: 27.4906
INFO:root:best acc: seen 0.9192 - unseen 0.1715 - H 0.2890 - time 66.4344
INFO:root:main epoch 159 - d_loss -10.9526 - g_loss -29.7614 - time: 27.7576
INFO:root:best acc: seen 0.9173 - unseen 0.1585 - H 0.2703 - time 65.1002
INFO:root:main epoch















INFO:root:main epoch 200 - d_loss -9.8661 - g_loss -36.5648 - time: 27.5124
INFO:root:best acc: seen 0.9184 - unseen 0.1974 - H 0.3249 - time 66.3408
INFO:root:main epoch 201 - d_loss -9.8334 - g_loss -36.5484 - time: 27.5737
INFO:root:best acc: seen 0.9177 - unseen 0.1857 - H 0.3089 - time 66.2338
INFO:root:main epoch 202 - d_loss -9.8114 - g_loss -36.8771 - time: 27.3696
INFO:root:best acc: seen 0.9193 - unseen 0.1610 - H 0.2741 - time 66.0999
INFO:root:main epoch 203 - d_loss -9.8018 - g_loss -36.8246 - time: 27.5205
INFO:root:best acc: seen 0.9171 - unseen 0.1770 - H 0.2967 - time 66.3947
INFO:root:main epoch 204 - d_loss -9.7915 - g_loss -37.2311 - time: 27.5084
INFO:root:best acc: seen 0.9182 - unseen 0.1704 - H 0.2875 - time 66.4565
INFO:root:main epoch 205 - d_loss -9.7628 - g_loss -37.2692 - time: 27.6608
INFO:root:best acc: seen 0.9198 - unseen 0.1868 - H 0.3105 - time 66.2479
INFO:root:main epoch 206 - d_loss -9.7259 - g_loss -37.2245 - time: 27.6352
INFO:root:best acc: seen

INFO:root:best acc: seen 0.9155 - unseen 0.1753 - H 0.2942 - time 67.1521
INFO:root:main epoch 254 - d_loss -9.2155 - g_loss -43.0524 - time: 27.6137
INFO:root:best acc: seen 0.9132 - unseen 0.1897 - H 0.3141 - time 66.7871
INFO:root:main epoch 255 - d_loss -9.1847 - g_loss -42.9260 - time: 27.5365
INFO:root:best acc: seen 0.9152 - unseen 0.1665 - H 0.2817 - time 65.8344
INFO:root:main epoch 256 - d_loss -9.1916 - g_loss -43.2514 - time: 27.6389
INFO:root:best acc: seen 0.9163 - unseen 0.1633 - H 0.2771 - time 66.8436
INFO:root:main epoch 257 - d_loss -9.1856 - g_loss -43.0750 - time: 27.4595
INFO:root:best acc: seen 0.9169 - unseen 0.1757 - H 0.2950 - time 66.5293
INFO:root:main epoch 258 - d_loss -9.1753 - g_loss -43.4964 - time: 27.6463
INFO:root:best acc: seen 0.9136 - unseen 0.2046 - H 0.3343 - time 66.8505
INFO:root:main epoch 259 - d_loss -9.1763 - g_loss -43.4267 - time: 27.5373
INFO:root:best acc: seen 0.9126 - unseen 0.1894 - H 0.3137 - time 66.1163
INFO:root:main epoch 260 -















INFO:root:main epoch 300 - d_loss -8.7072 - g_loss -46.5930 - time: 27.5374
INFO:root:best acc: seen 0.9146 - unseen 0.1727 - H 0.2905 - time 66.7970
INFO:root:main epoch 301 - d_loss -8.6850 - g_loss -46.8308 - time: 27.6788
INFO:root:best acc: seen 0.9087 - unseen 0.1864 - H 0.3094 - time 66.1932
INFO:root:main epoch 302 - d_loss -8.6719 - g_loss -46.9030 - time: 27.4787
INFO:root:best acc: seen 0.9139 - unseen 0.1818 - H 0.3033 - time 66.7533
INFO:root:main epoch 303 - d_loss -8.6887 - g_loss -47.0405 - time: 27.7377
INFO:root:best acc: seen 0.9142 - unseen 0.1900 - H 0.3147 - time 67.7817
INFO:root:main epoch 304 - d_loss -8.6508 - g_loss -46.8448 - time: 27.5515
INFO:root:best acc: seen 0.9111 - unseen 0.2069 - H 0.3372 - time 66.3081
INFO:root:main epoch 305 - d_loss -8.6287 - g_loss -47.0630 - time: 27.6159
INFO:root:best acc: seen 0.9130 - unseen 0.1959 - H 0.3226 - time 67.0765
INFO:root:main epoch 306 - d_loss -8.6197 - g_loss -47.2498 - time: 27.4901
INFO:root:best acc: seen

INFO:root:best acc: seen 0.9071 - unseen 0.2087 - H 0.3394 - time 67.2662
INFO:root:main epoch 354 - d_loss -8.0321 - g_loss -49.3413 - time: 27.5801
INFO:root:best acc: seen 0.9082 - unseen 0.2141 - H 0.3465 - time 65.8272
INFO:root:main epoch 355 - d_loss -7.9992 - g_loss -49.5366 - time: 27.4709
INFO:root:best acc: seen 0.9107 - unseen 0.2185 - H 0.3524 - time 66.7066
INFO:root:main epoch 356 - d_loss -8.0228 - g_loss -48.8399 - time: 27.5217
INFO:root:best acc: seen 0.9117 - unseen 0.1982 - H 0.3257 - time 66.5280
INFO:root:main epoch 357 - d_loss -8.0156 - g_loss -49.2456 - time: 27.6056
INFO:root:best acc: seen 0.9068 - unseen 0.2187 - H 0.3524 - time 66.9652
INFO:root:main epoch 358 - d_loss -7.9704 - g_loss -49.4413 - time: 27.6310
INFO:root:best acc: seen 0.9093 - unseen 0.2087 - H 0.3395 - time 66.9506
INFO:root:main epoch 359 - d_loss -7.9643 - g_loss -49.2100 - time: 27.5217
INFO:root:best acc: seen 0.9088 - unseen 0.2046 - H 0.3340 - time 67.1576
INFO:root:main epoch 360 -















INFO:root:main epoch 400 - d_loss -7.5879 - g_loss -49.8707 - time: 27.6536
INFO:root:best acc: seen 0.9084 - unseen 0.2131 - H 0.3452 - time 66.5181
INFO:root:main epoch 401 - d_loss -7.6040 - g_loss -50.2647 - time: 27.6653
INFO:root:best acc: seen 0.9059 - unseen 0.2229 - H 0.3577 - time 67.7355
INFO:root:main epoch 402 - d_loss -7.5651 - g_loss -50.0506 - time: 27.4935
INFO:root:best acc: seen 0.9090 - unseen 0.2082 - H 0.3388 - time 65.5467
INFO:root:main epoch 403 - d_loss -7.5457 - g_loss -49.9315 - time: 27.6616
INFO:root:best acc: seen 0.9054 - unseen 0.1983 - H 0.3254 - time 66.5356
INFO:root:main epoch 404 - d_loss -7.5628 - g_loss -50.1587 - time: 27.6157
INFO:root:best acc: seen 0.9069 - unseen 0.2249 - H 0.3604 - time 65.9398
INFO:root:main epoch 405 - d_loss -7.5205 - g_loss -49.9715 - time: 27.7258
INFO:root:best acc: seen 0.9072 - unseen 0.2219 - H 0.3565 - time 66.5558
INFO:root:main epoch 406 - d_loss -7.5324 - g_loss -50.1633 - time: 27.4751
INFO:root:best acc: seen

INFO:root:best acc: seen 0.9078 - unseen 0.2269 - H 0.3630 - time 67.2029
INFO:root:main epoch 454 - d_loss -7.2046 - g_loss -50.3198 - time: 27.6326
INFO:root:best acc: seen 0.9082 - unseen 0.2139 - H 0.3462 - time 66.9218
INFO:root:main epoch 455 - d_loss -7.1854 - g_loss -50.7548 - time: 27.5694
INFO:root:best acc: seen 0.9082 - unseen 0.1850 - H 0.3074 - time 67.2043
INFO:root:main epoch 456 - d_loss -7.1930 - g_loss -50.2898 - time: 27.5500
INFO:root:best acc: seen 0.9072 - unseen 0.1961 - H 0.3225 - time 66.7326
INFO:root:main epoch 457 - d_loss -7.1819 - g_loss -50.2663 - time: 27.5731
INFO:root:best acc: seen 0.9075 - unseen 0.2141 - H 0.3465 - time 67.0845
INFO:root:main epoch 458 - d_loss -7.1715 - g_loss -50.4901 - time: 27.7298
INFO:root:best acc: seen 0.9041 - unseen 0.2040 - H 0.3329 - time 67.4694
INFO:root:main epoch 459 - d_loss -7.1916 - g_loss -50.3278 - time: 27.7098
INFO:root:best acc: seen 0.9052 - unseen 0.2147 - H 0.3471 - time 67.5166
INFO:root:main epoch 460 -















INFO:root:main epoch 500 - d_loss -6.9519 - g_loss -50.7318 - time: 27.6088
INFO:root:best acc: seen 0.9060 - unseen 0.2116 - H 0.3430 - time 66.7054
INFO:root:main epoch 501 - d_loss -6.9793 - g_loss -50.6914 - time: 27.5999
INFO:root:best acc: seen 0.9058 - unseen 0.2043 - H 0.3335 - time 66.4938
INFO:root:main epoch 502 - d_loss -6.9517 - g_loss -50.4901 - time: 27.4749
INFO:root:best acc: seen 0.9051 - unseen 0.2016 - H 0.3298 - time 67.3885
INFO:root:main epoch 503 - d_loss -6.9630 - g_loss -50.4769 - time: 27.7413
INFO:root:best acc: seen 0.9056 - unseen 0.2072 - H 0.3372 - time 66.1468
INFO:root:main epoch 504 - d_loss -6.9602 - g_loss -50.6614 - time: 27.4860
INFO:root:best acc: seen 0.9063 - unseen 0.2049 - H 0.3343 - time 67.9124
INFO:root:main epoch 505 - d_loss -6.9334 - g_loss -50.5919 - time: 27.5547
INFO:root:best acc: seen 0.9072 - unseen 0.2127 - H 0.3447 - time 65.9683
INFO:root:main epoch 506 - d_loss -6.9350 - g_loss -50.6287 - time: 27.4827
INFO:root:best acc: seen

INFO:root:best acc: seen 0.9066 - unseen 0.2000 - H 0.3277 - time 68.0046
INFO:root:main epoch 554 - d_loss -6.7486 - g_loss -50.4276 - time: 27.6102
INFO:root:best acc: seen 0.9053 - unseen 0.2175 - H 0.3508 - time 68.4684
INFO:root:main epoch 555 - d_loss -6.7531 - g_loss -50.7217 - time: 27.6974
INFO:root:best acc: seen 0.9024 - unseen 0.2045 - H 0.3334 - time 68.7337
INFO:root:main epoch 556 - d_loss -6.7252 - g_loss -50.2575 - time: 27.6022
INFO:root:best acc: seen 0.9048 - unseen 0.1990 - H 0.3263 - time 69.0951
INFO:root:main epoch 557 - d_loss -6.7314 - g_loss -50.6069 - time: 27.6806
INFO:root:best acc: seen 0.9069 - unseen 0.2069 - H 0.3369 - time 68.4920
INFO:root:main epoch 558 - d_loss -6.7187 - g_loss -50.4201 - time: 27.4959
INFO:root:best acc: seen 0.9037 - unseen 0.2070 - H 0.3368 - time 67.4017
INFO:root:main epoch 559 - d_loss -6.7078 - g_loss -50.7385 - time: 27.7180
INFO:root:best acc: seen 0.9053 - unseen 0.2051 - H 0.3344 - time 67.2125
INFO:root:main epoch 560 -















INFO:root:main epoch 600 - d_loss -6.5978 - g_loss -50.6673 - time: 27.5715
INFO:root:best acc: seen 0.9049 - unseen 0.2120 - H 0.3435 - time 67.8828
INFO:root:main epoch 601 - d_loss -6.5798 - g_loss -50.2496 - time: 27.7748
INFO:root:best acc: seen 0.9045 - unseen 0.2024 - H 0.3307 - time 68.3089
INFO:root:main epoch 602 - d_loss -6.5938 - g_loss -50.5014 - time: 27.5741
INFO:root:best acc: seen 0.9028 - unseen 0.2010 - H 0.3288 - time 68.9729
INFO:root:main epoch 603 - d_loss -6.5654 - g_loss -50.5667 - time: 27.7383
INFO:root:best acc: seen 0.9065 - unseen 0.2128 - H 0.3446 - time 68.1596
INFO:root:main epoch 604 - d_loss -6.5753 - g_loss -50.4563 - time: 27.5601
INFO:root:best acc: seen 0.9062 - unseen 0.1834 - H 0.3050 - time 67.7313
INFO:root:main epoch 605 - d_loss -6.5859 - g_loss -50.5101 - time: 27.7636
INFO:root:best acc: seen 0.9053 - unseen 0.1876 - H 0.3108 - time 68.0139
INFO:root:main epoch 606 - d_loss -6.5558 - g_loss -50.4626 - time: 27.6707
INFO:root:best acc: seen

INFO:root:best acc: seen 0.9024 - unseen 0.2196 - H 0.3532 - time 67.4384
INFO:root:main epoch 654 - d_loss -6.4206 - g_loss -50.2318 - time: 27.5703
INFO:root:best acc: seen 0.9026 - unseen 0.2171 - H 0.3501 - time 68.2032
INFO:root:main epoch 655 - d_loss -6.4017 - g_loss -50.1446 - time: 27.5712
INFO:root:best acc: seen 0.9048 - unseen 0.2041 - H 0.3331 - time 67.6597
INFO:root:main epoch 656 - d_loss -6.3941 - g_loss -50.0747 - time: 27.6050
INFO:root:best acc: seen 0.9036 - unseen 0.1946 - H 0.3203 - time 67.7749
INFO:root:main epoch 657 - d_loss -6.4107 - g_loss -50.5861 - time: 27.5924
INFO:root:best acc: seen 0.9055 - unseen 0.1869 - H 0.3098 - time 68.0499
INFO:root:main epoch 658 - d_loss -6.4011 - g_loss -50.1855 - time: 27.6196
INFO:root:best acc: seen 0.9025 - unseen 0.1863 - H 0.3089 - time 67.9829
INFO:root:main epoch 659 - d_loss -6.3944 - g_loss -49.8844 - time: 27.6416
INFO:root:best acc: seen 0.9032 - unseen 0.2010 - H 0.3288 - time 68.3573
INFO:root:main epoch 660 -















INFO:root:main epoch 700 - d_loss -6.2860 - g_loss -50.0512 - time: 27.6683
INFO:root:best acc: seen 0.9020 - unseen 0.1857 - H 0.3080 - time 69.5067
INFO:root:main epoch 701 - d_loss -6.2759 - g_loss -50.1275 - time: 27.6704
INFO:root:best acc: seen 0.9041 - unseen 0.2034 - H 0.3321 - time 70.5445
INFO:root:main epoch 702 - d_loss -6.2628 - g_loss -49.9234 - time: 27.8030
INFO:root:best acc: seen 0.9039 - unseen 0.1922 - H 0.3169 - time 69.5558
INFO:root:main epoch 703 - d_loss -6.2619 - g_loss -49.9720 - time: 27.6669
INFO:root:best acc: seen 0.9061 - unseen 0.1828 - H 0.3042 - time 69.1404
INFO:root:main epoch 704 - d_loss -6.2494 - g_loss -50.1832 - time: 27.7304
INFO:root:best acc: seen 0.9054 - unseen 0.2007 - H 0.3285 - time 69.5267
INFO:root:main epoch 705 - d_loss -6.2625 - g_loss -50.0831 - time: 27.5438
INFO:root:best acc: seen 0.9060 - unseen 0.1847 - H 0.3068 - time 69.9347
INFO:root:main epoch 706 - d_loss -6.2627 - g_loss -50.3218 - time: 27.6963
INFO:root:best acc: seen

INFO:root:best acc: seen 0.9057 - unseen 0.2023 - H 0.3308 - time 69.5502
INFO:root:main epoch 754 - d_loss -6.1171 - g_loss -49.6222 - time: 27.6892
INFO:root:best acc: seen 0.9037 - unseen 0.1943 - H 0.3198 - time 68.7960
INFO:root:main epoch 755 - d_loss -6.1184 - g_loss -50.0742 - time: 27.7490
INFO:root:best acc: seen 0.9032 - unseen 0.1943 - H 0.3198 - time 69.2136
INFO:root:main epoch 756 - d_loss -6.1165 - g_loss -49.9895 - time: 27.6969
INFO:root:best acc: seen 0.9037 - unseen 0.1990 - H 0.3261 - time 69.2032
INFO:root:main epoch 757 - d_loss -6.1059 - g_loss -49.9455 - time: 27.8078
INFO:root:best acc: seen 0.9059 - unseen 0.1908 - H 0.3152 - time 69.4182
INFO:root:main epoch 758 - d_loss -6.1074 - g_loss -49.8343 - time: 27.7244
INFO:root:best acc: seen 0.9055 - unseen 0.1907 - H 0.3150 - time 69.1052
INFO:root:main epoch 759 - d_loss -6.1084 - g_loss -49.9480 - time: 27.7794
INFO:root:best acc: seen 0.9049 - unseen 0.1861 - H 0.3088 - time 68.6138
INFO:root:main epoch 760 -















INFO:root:main epoch 800 - d_loss -5.9603 - g_loss -49.4872 - time: 27.8152
INFO:root:best acc: seen 0.9077 - unseen 0.1934 - H 0.3188 - time 68.1470
INFO:root:main epoch 801 - d_loss -5.9691 - g_loss -49.6884 - time: 27.8047
INFO:root:best acc: seen 0.9040 - unseen 0.2140 - H 0.3461 - time 69.2732
INFO:root:main epoch 802 - d_loss -5.9842 - g_loss -49.9333 - time: 27.6795
INFO:root:best acc: seen 0.9012 - unseen 0.1899 - H 0.3137 - time 69.5726
INFO:root:main epoch 803 - d_loss -5.9718 - g_loss -49.7997 - time: 27.8170
INFO:root:best acc: seen 0.9039 - unseen 0.2031 - H 0.3317 - time 68.6864
INFO:root:main epoch 804 - d_loss -5.9934 - g_loss -49.6739 - time: 27.6677
INFO:root:best acc: seen 0.9017 - unseen 0.1929 - H 0.3178 - time 69.5237
INFO:root:main epoch 805 - d_loss -5.9517 - g_loss -49.9351 - time: 27.6392
INFO:root:best acc: seen 0.9040 - unseen 0.1936 - H 0.3189 - time 68.7806
INFO:root:main epoch 806 - d_loss -5.9626 - g_loss -49.9054 - time: 27.8432
INFO:root:best acc: seen

INFO:root:best acc: seen 0.9021 - unseen 0.1939 - H 0.3192 - time 68.4606
INFO:root:main epoch 854 - d_loss -5.8289 - g_loss -49.7661 - time: 27.7853
INFO:root:best acc: seen 0.9032 - unseen 0.1805 - H 0.3009 - time 69.8293
INFO:root:main epoch 855 - d_loss -5.8087 - g_loss -49.5948 - time: 27.7329
INFO:root:best acc: seen 0.9026 - unseen 0.1896 - H 0.3133 - time 69.8983
INFO:root:main epoch 856 - d_loss -5.8205 - g_loss -49.4991 - time: 27.6459
INFO:root:best acc: seen 0.9016 - unseen 0.1895 - H 0.3131 - time 69.7337
INFO:root:main epoch 857 - d_loss -5.8067 - g_loss -49.2558 - time: 27.6674
INFO:root:best acc: seen 0.9017 - unseen 0.1918 - H 0.3163 - time 69.7867
INFO:root:main epoch 858 - d_loss -5.8019 - g_loss -49.5866 - time: 27.7438
INFO:root:best acc: seen 0.9028 - unseen 0.1929 - H 0.3178 - time 69.5175
INFO:root:main epoch 859 - d_loss -5.7958 - g_loss -49.5697 - time: 27.7920
INFO:root:best acc: seen 0.9031 - unseen 0.1870 - H 0.3098 - time 68.2459
INFO:root:main epoch 860 -















INFO:root:main epoch 900 - d_loss -5.6920 - g_loss -49.1454 - time: 27.6276
INFO:root:best acc: seen 0.9056 - unseen 0.1884 - H 0.3119 - time 70.4509
INFO:root:main epoch 901 - d_loss -5.7103 - g_loss -49.6094 - time: 27.8058
INFO:root:best acc: seen 0.9039 - unseen 0.1967 - H 0.3230 - time 70.2750
INFO:root:main epoch 902 - d_loss -5.7061 - g_loss -49.4285 - time: 27.7906
INFO:root:best acc: seen 0.9018 - unseen 0.1803 - H 0.3005 - time 70.3469
INFO:root:main epoch 903 - d_loss -5.6998 - g_loss -49.0806 - time: 27.7887
INFO:root:best acc: seen 0.9027 - unseen 0.1817 - H 0.3025 - time 69.9342
INFO:root:main epoch 904 - d_loss -5.6872 - g_loss -49.4045 - time: 27.7030
INFO:root:best acc: seen 0.9067 - unseen 0.1803 - H 0.3008 - time 69.9999
INFO:root:main epoch 905 - d_loss -5.6752 - g_loss -49.4395 - time: 27.9287
INFO:root:best acc: seen 0.9023 - unseen 0.1916 - H 0.3161 - time 69.6929
INFO:root:main epoch 906 - d_loss -5.6801 - g_loss -49.4879 - time: 27.7253
INFO:root:best acc: seen

INFO:root:best acc: seen 0.9026 - unseen 0.1866 - H 0.3093 - time 69.2290
INFO:root:main epoch 954 - d_loss -5.5466 - g_loss -49.1646 - time: 27.7881
INFO:root:best acc: seen 0.9021 - unseen 0.1847 - H 0.3066 - time 69.4325
INFO:root:main epoch 955 - d_loss -5.5666 - g_loss -49.2070 - time: 27.8961
INFO:root:best acc: seen 0.9034 - unseen 0.1808 - H 0.3013 - time 69.0492
INFO:root:main epoch 956 - d_loss -5.5258 - g_loss -49.0059 - time: 27.8884
INFO:root:best acc: seen 0.9045 - unseen 0.1832 - H 0.3047 - time 68.2696
INFO:root:main epoch 957 - d_loss -5.5385 - g_loss -49.3349 - time: 28.0101
INFO:root:best acc: seen 0.9046 - unseen 0.1852 - H 0.3075 - time 69.1197
INFO:root:main epoch 958 - d_loss -5.5527 - g_loss -48.9695 - time: 27.8640
INFO:root:best acc: seen 0.9025 - unseen 0.1979 - H 0.3246 - time 68.4048
INFO:root:main epoch 959 - d_loss -5.5356 - g_loss -49.1400 - time: 27.8977
INFO:root:best acc: seen 0.9027 - unseen 0.1850 - H 0.3071 - time 68.3006
INFO:root:main epoch 960 -















INFO:root:main epoch 1000 - d_loss -5.4430 - g_loss -48.9256 - time: 27.7327
INFO:root:best acc: seen 0.9046 - unseen 0.2009 - H 0.3288 - time 69.3243
INFO:root:main epoch 1001 - d_loss -5.4315 - g_loss -49.0010 - time: 27.9896
INFO:root:best acc: seen 0.9038 - unseen 0.1929 - H 0.3179 - time 69.9357
INFO:root:main epoch 1002 - d_loss -5.4300 - g_loss -48.8798 - time: 27.9951
INFO:root:best acc: seen 0.9032 - unseen 0.1951 - H 0.3209 - time 69.2359
INFO:root:main epoch 1003 - d_loss -5.4330 - g_loss -48.9174 - time: 27.9603
INFO:root:best acc: seen 0.9027 - unseen 0.1922 - H 0.3169 - time 69.2263
INFO:root:main epoch 1004 - d_loss -5.4289 - g_loss -48.8912 - time: 27.9616
INFO:root:best acc: seen 0.9053 - unseen 0.1923 - H 0.3172 - time 68.5799
INFO:root:main epoch 1005 - d_loss -5.4042 - g_loss -48.9191 - time: 27.9768
INFO:root:best acc: seen 0.9044 - unseen 0.1965 - H 0.3228 - time 71.1871
INFO:root:main epoch 1006 - d_loss -5.4227 - g_loss -48.9159 - time: 27.9675
INFO:root:best ac

INFO:root:best acc: seen 0.9032 - unseen 0.1977 - H 0.3243 - time 70.6236
INFO:root:main epoch 1054 - d_loss -5.3005 - g_loss -48.5373 - time: 28.0597
INFO:root:best acc: seen 0.9059 - unseen 0.1927 - H 0.3178 - time 68.7584
INFO:root:main epoch 1055 - d_loss -5.2971 - g_loss -48.5082 - time: 28.0705
INFO:root:best acc: seen 0.9063 - unseen 0.1906 - H 0.3149 - time 69.3689
INFO:root:main epoch 1056 - d_loss -5.3031 - g_loss -48.8981 - time: 28.1021
INFO:root:best acc: seen 0.9028 - unseen 0.2022 - H 0.3304 - time 69.8143
INFO:root:main epoch 1057 - d_loss -5.2871 - g_loss -48.4160 - time: 28.0854
INFO:root:best acc: seen 0.9068 - unseen 0.1782 - H 0.2978 - time 68.5643
INFO:root:main epoch 1058 - d_loss -5.2782 - g_loss -48.7200 - time: 28.1497
INFO:root:best acc: seen 0.9057 - unseen 0.1721 - H 0.2893 - time 69.2576
INFO:root:main epoch 1059 - d_loss -5.2752 - g_loss -48.6723 - time: 28.0056
INFO:root:best acc: seen 0.9050 - unseen 0.1910 - H 0.3154 - time 69.7789
INFO:root:main epoch















INFO:root:main epoch 1100 - d_loss -5.1794 - g_loss -48.3151 - time: 28.0867
INFO:root:best acc: seen 0.9029 - unseen 0.1890 - H 0.3125 - time 70.1836
INFO:root:main epoch 1101 - d_loss -5.1924 - g_loss -48.3592 - time: 28.1182
INFO:root:best acc: seen 0.9052 - unseen 0.1838 - H 0.3056 - time 69.2602
INFO:root:main epoch 1102 - d_loss -5.1822 - g_loss -48.5479 - time: 28.0666
INFO:root:best acc: seen 0.9039 - unseen 0.1964 - H 0.3226 - time 69.9793
INFO:root:main epoch 1103 - d_loss -5.1910 - g_loss -48.4933 - time: 27.9924
INFO:root:best acc: seen 0.9012 - unseen 0.1909 - H 0.3151 - time 71.0123
INFO:root:main epoch 1104 - d_loss -5.1894 - g_loss -48.3637 - time: 28.2354
INFO:root:best acc: seen 0.9035 - unseen 0.1862 - H 0.3088 - time 70.0833
INFO:root:main epoch 1105 - d_loss -5.1789 - g_loss -48.4480 - time: 27.9762
INFO:root:best acc: seen 0.9029 - unseen 0.1824 - H 0.3034 - time 70.4644
INFO:root:main epoch 1106 - d_loss -5.1883 - g_loss -48.3310 - time: 28.2429
INFO:root:best ac

INFO:root:best acc: seen 0.9031 - unseen 0.1958 - H 0.3219 - time 68.6082
INFO:root:main epoch 1154 - d_loss -5.0787 - g_loss -48.2979 - time: 28.1809
INFO:root:best acc: seen 0.9043 - unseen 0.1806 - H 0.3010 - time 71.0272
INFO:root:main epoch 1155 - d_loss -5.0750 - g_loss -48.3394 - time: 28.2119
INFO:root:best acc: seen 0.9040 - unseen 0.1960 - H 0.3222 - time 68.7116
INFO:root:main epoch 1156 - d_loss -5.0656 - g_loss -48.0420 - time: 28.3336
INFO:root:best acc: seen 0.9035 - unseen 0.1833 - H 0.3048 - time 70.5454
INFO:root:main epoch 1157 - d_loss -5.0701 - g_loss -48.3078 - time: 28.1333
INFO:root:best acc: seen 0.9055 - unseen 0.1956 - H 0.3217 - time 70.6167
INFO:root:main epoch 1158 - d_loss -5.0603 - g_loss -48.2260 - time: 28.2334
INFO:root:best acc: seen 0.9028 - unseen 0.1805 - H 0.3009 - time 70.0966
INFO:root:main epoch 1159 - d_loss -5.0470 - g_loss -48.1664 - time: 28.1642
INFO:root:best acc: seen 0.9035 - unseen 0.1796 - H 0.2997 - time 69.8432
INFO:root:main epoch















INFO:root:main epoch 1200 - d_loss -4.9708 - g_loss -47.9787 - time: 28.2595
INFO:root:best acc: seen 0.9032 - unseen 0.1855 - H 0.3078 - time 70.8956
INFO:root:main epoch 1201 - d_loss -4.9698 - g_loss -48.0259 - time: 28.3075
INFO:root:best acc: seen 0.9053 - unseen 0.1851 - H 0.3074 - time 69.5130
INFO:root:main epoch 1202 - d_loss -4.9497 - g_loss -48.0018 - time: 28.3295
INFO:root:best acc: seen 0.9061 - unseen 0.1860 - H 0.3087 - time 70.7731
INFO:root:main epoch 1203 - d_loss -4.9434 - g_loss -48.1356 - time: 28.3773
INFO:root:best acc: seen 0.9043 - unseen 0.1957 - H 0.3218 - time 71.3909
INFO:root:main epoch 1204 - d_loss -4.9574 - g_loss -48.0518 - time: 28.4954
INFO:root:best acc: seen 0.9055 - unseen 0.1829 - H 0.3043 - time 70.3682
INFO:root:main epoch 1205 - d_loss -4.9589 - g_loss -48.0802 - time: 28.1641
INFO:root:best acc: seen 0.9051 - unseen 0.1903 - H 0.3145 - time 70.6282
INFO:root:main epoch 1206 - d_loss -4.9433 - g_loss -48.1275 - time: 28.3929
INFO:root:best ac

INFO:root:best acc: seen 0.9053 - unseen 0.1974 - H 0.3241 - time 70.6950
INFO:root:main epoch 1254 - d_loss -4.8592 - g_loss -47.7977 - time: 28.2747
INFO:root:best acc: seen 0.9020 - unseen 0.1852 - H 0.3073 - time 69.7011
INFO:root:main epoch 1255 - d_loss -4.8520 - g_loss -47.9385 - time: 28.2636
INFO:root:best acc: seen 0.9041 - unseen 0.1949 - H 0.3206 - time 70.8897
INFO:root:main epoch 1256 - d_loss -4.8626 - g_loss -47.9100 - time: 28.3660
INFO:root:best acc: seen 0.9055 - unseen 0.1787 - H 0.2985 - time 70.7173
INFO:root:main epoch 1257 - d_loss -4.8624 - g_loss -47.9502 - time: 28.3386
INFO:root:best acc: seen 0.9062 - unseen 0.1762 - H 0.2950 - time 70.3819
INFO:root:main epoch 1258 - d_loss -4.8435 - g_loss -47.9676 - time: 28.3198
INFO:root:best acc: seen 0.9010 - unseen 0.1965 - H 0.3227 - time 71.8247
INFO:root:main epoch 1259 - d_loss -4.8461 - g_loss -47.9970 - time: 28.3226
INFO:root:best acc: seen 0.9056 - unseen 0.1857 - H 0.3082 - time 71.4806
INFO:root:main epoch















INFO:root:main epoch 1300 - d_loss -4.7619 - g_loss -47.5420 - time: 28.3480
INFO:root:best acc: seen 0.9040 - unseen 0.1739 - H 0.2917 - time 70.0011
INFO:root:main epoch 1301 - d_loss -4.7661 - g_loss -47.5918 - time: 28.3630
INFO:root:best acc: seen 0.9046 - unseen 0.1767 - H 0.2957 - time 71.8909
INFO:root:main epoch 1302 - d_loss -4.7526 - g_loss -47.7503 - time: 28.5059
INFO:root:best acc: seen 0.9029 - unseen 0.1810 - H 0.3015 - time 70.4019
INFO:root:main epoch 1303 - d_loss -4.7691 - g_loss -47.7170 - time: 28.3472
INFO:root:best acc: seen 0.9044 - unseen 0.1917 - H 0.3163 - time 71.0383
INFO:root:main epoch 1304 - d_loss -4.7646 - g_loss -47.7604 - time: 28.2991
INFO:root:best acc: seen 0.9037 - unseen 0.1692 - H 0.2850 - time 71.5940
INFO:root:main epoch 1305 - d_loss -4.7734 - g_loss -47.4563 - time: 28.4251
INFO:root:best acc: seen 0.9045 - unseen 0.1816 - H 0.3025 - time 70.2209
INFO:root:main epoch 1306 - d_loss -4.7503 - g_loss -47.6546 - time: 28.2676
INFO:root:best ac

INFO:root:best acc: seen 0.9032 - unseen 0.1900 - H 0.3140 - time 70.9823
INFO:root:main epoch 1354 - d_loss -4.6542 - g_loss -47.2887 - time: 28.3322
INFO:root:best acc: seen 0.9036 - unseen 0.1934 - H 0.3187 - time 71.9571
INFO:root:main epoch 1355 - d_loss -4.6732 - g_loss -47.4398 - time: 28.5163
INFO:root:best acc: seen 0.9067 - unseen 0.1799 - H 0.3002 - time 71.9803
INFO:root:main epoch 1356 - d_loss -4.6560 - g_loss -47.3754 - time: 28.3525
INFO:root:best acc: seen 0.9024 - unseen 0.1893 - H 0.3129 - time 70.2182
INFO:root:main epoch 1357 - d_loss -4.6519 - g_loss -47.2522 - time: 28.3264
INFO:root:best acc: seen 0.9028 - unseen 0.1840 - H 0.3057 - time 72.6111
INFO:root:main epoch 1358 - d_loss -4.6482 - g_loss -47.2955 - time: 28.2837
INFO:root:best acc: seen 0.9055 - unseen 0.1788 - H 0.2987 - time 72.4573
INFO:root:main epoch 1359 - d_loss -4.6338 - g_loss -47.3580 - time: 28.4154
INFO:root:best acc: seen 0.9040 - unseen 0.1779 - H 0.2973 - time 72.4345
INFO:root:main epoch















INFO:root:main epoch 1400 - d_loss -4.5734 - g_loss -47.3329 - time: 28.4601
INFO:root:best acc: seen 0.9038 - unseen 0.1881 - H 0.3114 - time 71.8081
INFO:root:main epoch 1401 - d_loss -4.5620 - g_loss -47.2106 - time: 28.3265
INFO:root:best acc: seen 0.9065 - unseen 0.1819 - H 0.3029 - time 72.0825
INFO:root:main epoch 1402 - d_loss -4.5806 - g_loss -47.0544 - time: 28.5223
INFO:root:best acc: seen 0.9048 - unseen 0.1667 - H 0.2815 - time 70.8259
INFO:root:main epoch 1403 - d_loss -4.5582 - g_loss -47.1957 - time: 28.3688
INFO:root:best acc: seen 0.9018 - unseen 0.1810 - H 0.3015 - time 71.1880
INFO:root:main epoch 1404 - d_loss -4.5667 - g_loss -47.2987 - time: 28.3199
INFO:root:best acc: seen 0.9017 - unseen 0.1806 - H 0.3009 - time 71.5512
INFO:root:main epoch 1405 - d_loss -4.5652 - g_loss -46.9652 - time: 28.5308
INFO:root:best acc: seen 0.8997 - unseen 0.1718 - H 0.2886 - time 71.8212
INFO:root:main epoch 1406 - d_loss -4.5566 - g_loss -47.1600 - time: 28.4650
INFO:root:best ac

INFO:root:best acc: seen 0.9029 - unseen 0.1804 - H 0.3006 - time 70.3418
INFO:root:main epoch 1454 - d_loss -4.4870 - g_loss -47.0519 - time: 28.5617
INFO:root:best acc: seen 0.9020 - unseen 0.1759 - H 0.2943 - time 71.8680
INFO:root:main epoch 1455 - d_loss -4.4750 - g_loss -46.8597 - time: 28.3657
INFO:root:best acc: seen 0.9050 - unseen 0.1854 - H 0.3077 - time 71.4770
INFO:root:main epoch 1456 - d_loss -4.4816 - g_loss -46.9191 - time: 28.6841
INFO:root:best acc: seen 0.9049 - unseen 0.1833 - H 0.3049 - time 74.2055
INFO:root:main epoch 1457 - d_loss -4.4677 - g_loss -46.9581 - time: 28.3657
INFO:root:best acc: seen 0.9028 - unseen 0.1794 - H 0.2993 - time 71.0574
INFO:root:main epoch 1458 - d_loss -4.4519 - g_loss -46.9929 - time: 28.6718
INFO:root:best acc: seen 0.9045 - unseen 0.1813 - H 0.3020 - time 74.2632
INFO:root:main epoch 1459 - d_loss -4.4592 - g_loss -46.8265 - time: 28.3481
INFO:root:best acc: seen 0.9067 - unseen 0.1783 - H 0.2979 - time 73.2148
INFO:root:main epoch















INFO:root:main epoch 1500 - d_loss -4.3983 - g_loss -46.9100 - time: 28.7020
INFO:root:best acc: seen 0.9034 - unseen 0.1712 - H 0.2878 - time 71.1719
INFO:root:main epoch 1501 - d_loss -4.3931 - g_loss -46.6300 - time: 28.4753
INFO:root:best acc: seen 0.9037 - unseen 0.1708 - H 0.2874 - time 73.9975
INFO:root:main epoch 1502 - d_loss -4.4012 - g_loss -46.5699 - time: 28.6214
INFO:root:best acc: seen 0.9022 - unseen 0.1788 - H 0.2984 - time 72.6485
INFO:root:main epoch 1503 - d_loss -4.4073 - g_loss -46.7783 - time: 28.5378
INFO:root:best acc: seen 0.9042 - unseen 0.1834 - H 0.3049 - time 71.7613
INFO:root:main epoch 1504 - d_loss -4.3785 - g_loss -46.5883 - time: 28.5328
INFO:root:best acc: seen 0.9046 - unseen 0.1877 - H 0.3108 - time 71.7303
INFO:root:main epoch 1505 - d_loss -4.3909 - g_loss -46.7366 - time: 28.5101
INFO:root:best acc: seen 0.9060 - unseen 0.1898 - H 0.3138 - time 71.9834
INFO:root:main epoch 1506 - d_loss -4.3950 - g_loss -46.7826 - time: 28.5622
INFO:root:best ac

INFO:root:best acc: seen 0.9034 - unseen 0.1978 - H 0.3245 - time 72.9374
INFO:root:main epoch 1554 - d_loss -4.3088 - g_loss -46.4366 - time: 28.7472
INFO:root:best acc: seen 0.9032 - unseen 0.1748 - H 0.2929 - time 72.1837
INFO:root:main epoch 1555 - d_loss -4.2977 - g_loss -46.3668 - time: 28.6860
INFO:root:best acc: seen 0.9021 - unseen 0.1862 - H 0.3086 - time 74.5550
INFO:root:main epoch 1556 - d_loss -4.3014 - g_loss -46.4067 - time: 28.6955
INFO:root:best acc: seen 0.8996 - unseen 0.1780 - H 0.2972 - time 74.1341
INFO:root:main epoch 1557 - d_loss -4.3072 - g_loss -46.5174 - time: 28.7062
INFO:root:best acc: seen 0.9028 - unseen 0.1715 - H 0.2883 - time 72.6188
INFO:root:main epoch 1558 - d_loss -4.3021 - g_loss -46.4699 - time: 28.6611
INFO:root:best acc: seen 0.8998 - unseen 0.1857 - H 0.3078 - time 73.4540
INFO:root:main epoch 1559 - d_loss -4.2697 - g_loss -46.3620 - time: 28.6361
INFO:root:best acc: seen 0.9003 - unseen 0.1820 - H 0.3028 - time 73.5960
INFO:root:main epoch















INFO:root:main epoch 1600 - d_loss -4.2376 - g_loss -46.2732 - time: 28.6273
INFO:root:best acc: seen 0.9030 - unseen 0.1758 - H 0.2944 - time 73.3556
INFO:root:main epoch 1601 - d_loss -4.2098 - g_loss -46.1595 - time: 28.5825
INFO:root:best acc: seen 0.8977 - unseen 0.1827 - H 0.3036 - time 74.4282
INFO:root:main epoch 1602 - d_loss -4.2244 - g_loss -46.2779 - time: 28.8580
INFO:root:best acc: seen 0.9028 - unseen 0.1721 - H 0.2892 - time 74.1909
INFO:root:main epoch 1603 - d_loss -4.2332 - g_loss -45.9219 - time: 28.7015
INFO:root:best acc: seen 0.8995 - unseen 0.1792 - H 0.2989 - time 74.5039
INFO:root:main epoch 1604 - d_loss -4.2245 - g_loss -46.3778 - time: 28.7279
INFO:root:best acc: seen 0.9020 - unseen 0.1685 - H 0.2840 - time 74.7047
INFO:root:main epoch 1605 - d_loss -4.2227 - g_loss -46.0138 - time: 28.6786
INFO:root:best acc: seen 0.9001 - unseen 0.1830 - H 0.3042 - time 73.6631
INFO:root:main epoch 1606 - d_loss -4.2304 - g_loss -46.3655 - time: 28.8355
INFO:root:best ac

INFO:root:best acc: seen 0.9009 - unseen 0.1667 - H 0.2814 - time 72.5942
INFO:root:main epoch 1654 - d_loss -4.1400 - g_loss -46.1340 - time: 28.9154
INFO:root:best acc: seen 0.9021 - unseen 0.1781 - H 0.2975 - time 72.5725
INFO:root:main epoch 1655 - d_loss -4.1373 - g_loss -46.0957 - time: 28.7799
INFO:root:best acc: seen 0.9003 - unseen 0.1889 - H 0.3122 - time 77.7209
INFO:root:main epoch 1656 - d_loss -4.1412 - g_loss -46.0189 - time: 28.8519
INFO:root:best acc: seen 0.9021 - unseen 0.1920 - H 0.3166 - time 76.4343
INFO:root:main epoch 1657 - d_loss -4.1428 - g_loss -45.8943 - time: 28.8873
INFO:root:best acc: seen 0.9013 - unseen 0.1936 - H 0.3187 - time 76.9659
INFO:root:main epoch 1658 - d_loss -4.1404 - g_loss -45.8156 - time: 29.0538
INFO:root:best acc: seen 0.9058 - unseen 0.1910 - H 0.3154 - time 82.0097
INFO:root:main epoch 1659 - d_loss -4.1264 - g_loss -45.9647 - time: 28.9965
INFO:root:best acc: seen 0.9037 - unseen 0.1796 - H 0.2996 - time 85.0462
INFO:root:main epoch















INFO:root:main epoch 1700 - d_loss -4.0701 - g_loss -45.6382 - time: 28.9880
INFO:root:best acc: seen 0.9041 - unseen 0.1819 - H 0.3028 - time 73.8033
INFO:root:main epoch 1701 - d_loss -4.0757 - g_loss -45.7112 - time: 28.9081
INFO:root:best acc: seen 0.8996 - unseen 0.1795 - H 0.2993 - time 74.0811
INFO:root:main epoch 1702 - d_loss -4.0645 - g_loss -45.6814 - time: 29.1383
INFO:root:best acc: seen 0.9048 - unseen 0.1794 - H 0.2994 - time 75.2214
INFO:root:main epoch 1703 - d_loss -4.0770 - g_loss -45.8042 - time: 28.8317
INFO:root:best acc: seen 0.9027 - unseen 0.1807 - H 0.3011 - time 73.6330
INFO:root:main epoch 1704 - d_loss -4.0768 - g_loss -45.7076 - time: 28.9557
INFO:root:best acc: seen 0.9025 - unseen 0.1895 - H 0.3132 - time 77.2209
INFO:root:main epoch 1705 - d_loss -4.0535 - g_loss -45.5922 - time: 28.8712
INFO:root:best acc: seen 0.8974 - unseen 0.1897 - H 0.3132 - time 75.0421
INFO:root:main epoch 1706 - d_loss -4.0667 - g_loss -45.7189 - time: 29.0605
INFO:root:best ac

INFO:root:best acc: seen 0.9023 - unseen 0.1784 - H 0.2979 - time 75.9963
INFO:root:main epoch 1754 - d_loss -3.9922 - g_loss -45.3638 - time: 29.1535
INFO:root:best acc: seen 0.9015 - unseen 0.1817 - H 0.3024 - time 76.2638
INFO:root:main epoch 1755 - d_loss -4.0123 - g_loss -45.6364 - time: 29.1585
INFO:root:best acc: seen 0.9013 - unseen 0.1788 - H 0.2985 - time 75.7345
INFO:root:main epoch 1756 - d_loss -3.9918 - g_loss -45.3669 - time: 29.0986
INFO:root:best acc: seen 0.9008 - unseen 0.1804 - H 0.3006 - time 73.9005
INFO:root:main epoch 1757 - d_loss -3.9858 - g_loss -45.4033 - time: 29.1966
INFO:root:best acc: seen 0.9008 - unseen 0.1816 - H 0.3023 - time 73.3223
INFO:root:main epoch 1758 - d_loss -3.9901 - g_loss -45.3018 - time: 29.1787
INFO:root:best acc: seen 0.9047 - unseen 0.1745 - H 0.2926 - time 74.2397
INFO:root:main epoch 1759 - d_loss -3.9929 - g_loss -45.6959 - time: 29.1981
INFO:root:best acc: seen 0.9014 - unseen 0.1867 - H 0.3094 - time 75.0563
INFO:root:main epoch

KeyboardInterrupt: 

In [4]:
train_features = ds.train_features()
print(train_features.shape)

train_attributes = ds.train_attributes()
print(train_attributes, train_attributes.shape)

train_labels = ds.train_labels()
print(train_labels, train_labels.shape)

unseen_classes = ds.unseen_classes()
print(unseen_classes, unseen_classes.shape)

seen_classes = ds.seen_classes()
print(seen_classes, seen_classes.shape)

(23527, 2048)
[[0.00575881 0.003829   0.         ... 0.03639079 0.13208508 0.01148699]
 [0.         0.00507555 0.         ... 0.15675628 0.09070911 0.01425057]
 [0.00575881 0.003829   0.         ... 0.03639079 0.13208508 0.01148699]
 ...
 [0.         0.06587863 0.         ... 0.02108505 0.10218637 0.0332632 ]
 [0.0084177  0.01262655 0.         ... 0.02104426 0.04138143 0.02316552]
 [0.03877321 0.15834625 0.         ... 0.1760296  0.07107783 0.30279845]] (23527, 85)
[42 21 42 ... 39 18 45] (23527,)
[ 6  8 22 23 29 30 33 40 46 49] (10,)
[ 0  1  2  3  4  5  7  9 10 11 12 13 14 15 16 17 18 19 20 21 24 25 26 27
 28 31 32 34 35 36 37 38 39 41 42 43 44 45 47 48] (40,)


In [5]:
seed = 1985
batch_size = 4096
train_feat_ds = tf.data.Dataset.from_tensor_slices(train_features)
train_feat_ds = train_feat_ds.shuffle(buffer_size=train_features.shape[0], seed=seed).batch(batch_size)

train_att_ds = tf.data.Dataset.from_tensor_slices(train_attributes)
train_att_ds = train_att_ds.shuffle(buffer_size=train_attributes.shape[0], seed=seed).batch(batch_size)

train_label_ds = tf.data.Dataset.from_tensor_slices(train_labels)
train_label_ds = train_label_ds.shuffle(buffer_size=train_labels.shape[0], seed=seed).batch(batch_size)

attribute_seen = tf.constant(ds.attribute_seen())
print(attribute_seen, attribute_seen.shape)

tf.Tensor(
[[-0.00375358 -0.00375358 -0.00375358 ...  0.00882092  0.03640974
   0.03145501]
 [ 0.12045617  0.00426584  0.         ...  0.17996307  0.0618086
   0.03495531]
 [ 0.26584458  0.20652363  0.         ...  0.05026821  0.04274552
   0.04915256]
 ...
 [ 0.03877321  0.15834625  0.         ...  0.1760296   0.07107783
   0.30279845]
 [ 0.22516498  0.15266022  0.         ...  0.12733492  0.10009693
   0.01771   ]
 [ 0.19613947  0.1966714   0.         ...  0.01787277  0.06698743
   0.258836  ]], shape=(40, 85), dtype=float32) (40, 85)


# Results


INFO:root:best acc: seen 0.8974 - unseen 0.1791 - H 0.2986 - time 74.9681
INFO:root:main epoch 1782 - d_loss -3.9678 - g_loss -45.5081 - time: 29.2714