## Imports and TPU detection

In [1]:
#@title IMPORTS+ nvidia-smi
from tensorflow.keras import preprocessing
import os
import json
from io import BytesIO
import urllib.request
import numpy as np
import matplotlib.pyplot as plt
import sklearn
import cv2
import tensorflow as tf
from PIL import Image
from tensorflow import keras
from tensorflow.keras import backend as K
from tensorflow.python.framework.ops import disable_eager_execution,enable_eager_execution
from tensorflow.image import rgb_to_grayscale

USE_TPU=False
if 'COLAB_TPU_ADDR' not in os.environ:
  print('ERROR: Not connected to a TPU runtime; please see the first cell in this notebook for instructions!')
  !nvidia-smi
else:
  resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
#tf.config.experimental_connect_to_cluster(resolver)
# This is the TPU initialization code that has to be at the beginning.
#tf.tpu.experimental.initialize_tpu_system(resolver)
print("All devices: ", tf.config.list_logical_devices('TPU'))
    

print(tf.__version__)

ERROR: Not connected to a TPU runtime; please see the first cell in this notebook for instructions!
Fri Nov 11 13:35:00 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   52C    P8    10W /  70W |      0MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                    

In [2]:
#@title Mount Drive
from google.colab import drive

# This notebook uses a different Drive root path depending on USE_TPU.
DRIVE_DIR = ('/content/drive/MyDrive' if USE_TPU == True
             else '/content/drive/My Drive')
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [4]:
# The dataset_type gate ("kbmg", "shoes" or "wedding") is disabled; this always runs.
# Recursively gather every file under ./BGs as a candidate background image.
BG_PATHS = np.array([
    os.path.join(root, filename)
    for root, _subdirs, files in os.walk("BGs")
    for filename in files
])
print("Found", len(BG_PATHS), "backgrounds.")

Found 1026 backgrounds.


In [7]:
# Category metadata. NOTE(review): footwear and bags share nn_output_id 0 with
# dresses — presumably because they are trained as single-category datasets; confirm.
cats_dict = [
        {'id': 0, 'name': 'dresses',   'nn_output_id': 0, 'feed_cats': [331, 332, 333, 334, 335, 336, 330]},
        {'id': 1, 'name': 'tops',      'nn_output_id': 1, 'feed_cats': [241, 242, 243, 240, 291, 290, 341, 342, 343, 344, 340, 351, 352, 353, 350]},
        {'id': 2, 'name': 'skirts',    'nn_output_id': 2, 'feed_cats': [321, 322, 323, 324, 325, 326, 327, 328, 320]},
        {'id': 3, 'name': 'outerwear', 'nn_output_id': 3, 'feed_cats': [271, 272, 273, 274, 275, 276, 277, 278, 270, 281, 282, 283, 284, 280, 301, 302, 300]},
        {'id': 4, 'name': 'leggings',  'nn_output_id': 4, 'feed_cats': [311, 312, 313, 314, 315, 316, 310]},
        {'id': 5, 'name': 'footwear',  'nn_output_id': 0, 'feed_cats': [251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 250]},
        {'id': 6, 'name': 'bags',      'nn_output_id': 0, 'feed_cats': [371, 372, 373, 374, 375, 376, 370]},
]

# dataset_type -> (NOBG_TEST_FRAC, UNPROC_TEST_FRAC, N_CATEGORIES).
# NOTE(review): `dataset_type` is not defined in any visible cell; it must be
# set before this cell runs.
_DATASET_CONFIG = {
    "kbmg":    (0.10, 0.05, 5),
    "wedding": (0.15, 0.15, 5),
    "shoes":   (0.15, 0.04, 1),
    "bags":    (0.15, 0.04, 1),
    "hats":    (0.15, 0.04, 1),
}
if dataset_type not in _DATASET_CONFIG:
  # Previously an unknown type silently left these constants undefined,
  # failing much later with a NameError.
  raise ValueError(f"Unknown dataset_type {dataset_type!r}; "
                   f"expected one of {sorted(_DATASET_CONFIG)}")
NOBG_TEST_FRAC, UNPROC_TEST_FRAC, N_CATEGORIES = _DATASET_CONFIG[dataset_type]


## Model hyperparameters

In [31]:
BATCH_SIZE = 32
N_ENCODING = 256      # length of the triplet-model encoding that conditions D and G
N_NOISE = 256         # generator noise (latent) dimension
N_CAT_EMBEDDING = 16  # size of the learned category embedding in D and G
IMG_SHAPE = (300, 300)
INPUT_SHAPE = IMG_SHAPE + (3,)  # derived so image size and input shape cannot drift apart
BG_PROBABILITY = 0.0  # chance of compositing a random background onto RGBA images in get_img
# Kept in sync with the dataset config cell (which also lists "hats").
if dataset_type in ("kbmg", "wedding"):
  N_CATEGORIES = 5
elif dataset_type in ("shoes", "bags", "hats"):
  N_CATEGORIES = 1

## Define discriminator, generator & WGAN model

In [33]:
#@title DISCRIMINATOR

from tensorflow.keras.layers import Input, Concatenate, Reshape, Dense,\
                                    Embedding,GlobalAveragePooling2D,\
                                    BatchNormalization,Flatten,Dropout,\
                                    Conv2DTranspose, LeakyReLU, Conv2D
from tensorflow.keras.models import Model   
from tensorflow.keras import layers
def conv_block(x, filters, activation, kernel_size=(3, 3),strides=(1, 1),
              padding="same",use_bias=True,use_bn=False,use_dropout=False,drop_value=0.5):
    """Two stacked Conv2D layers, each followed by `activation`.

    Only the second convolution uses `strides` (the first is always stride 1),
    so any downsampling happens at the end of the block. Optional
    BatchNormalization sits between the second conv and its activation;
    optional Dropout(drop_value) is applied last. The same `activation`
    callable is used in both places.
    """
    def _conv(tensor, conv_strides):
        return layers.Conv2D(filters, kernel_size, strides=conv_strides,
                             padding=padding, use_bias=use_bias)(tensor)

    out = activation(_conv(x, (1, 1)))
    out = _conv(out, strides)
    if use_bn:
        out = layers.BatchNormalization()(out)
    out = activation(out)
    return layers.Dropout(drop_value)(out) if use_dropout else out
def get_discriminator(n_used_cats=0):
    """Build the conditional WGAN critic.

    Inputs are an image of INPUT_SHAPE, an int32 category id of shape (1,),
    and an encoding of shape (N_ENCODING,). The embedded category is injected
    as 2 extra channels after the third downsampling block. The encoding is
    injected as 2 extra channels after the fourth. The output is one
    un-activated score per sample (Dense(1)).

    Args:
        n_used_cats: number of category ids (Embedding input_dim). Must be >= 1;
            the default of 0 is kept only for signature compatibility.

    Raises:
        ValueError: if n_used_cats < 1.
    """
    if n_used_cats < 1:
        raise ValueError(f"n_used_cats must be >= 1, got {n_used_cats}")

    def _halve(size, times):
        """Spatial size after `times` stride-2 'same' convs (ceil(size / 2) each)."""
        for _ in range(times):
            size = -(-size // 2)
        return size

    # Derived from INPUT_SHAPE instead of hard-coding 38/19 (the 300x300 values).
    cat_h, cat_w = _halve(INPUT_SHAPE[0], 3), _halve(INPUT_SHAPE[1], 3)
    enc_h, enc_w = _halve(INPUT_SHAPE[0], 4), _halve(INPUT_SHAPE[1], 4)

    img_input = Input(INPUT_SHAPE)
    cat_input = Input((1,),dtype=tf.dtypes.int32)
    enc_input = Input((N_ENCODING,))
    embedded_category_input = Embedding(input_dim = n_used_cats,
                                        output_dim = N_CAT_EMBEDDING,
                                        input_length = 1)(cat_input)
    embedded_category_input = Reshape((N_CAT_EMBEDDING,))(embedded_category_input)
    embedded_category_input = Dense(cat_h*cat_w*2)(embedded_category_input)
    embedded_category_input = Reshape((cat_h,cat_w,2))(embedded_category_input)

    embedded_encoding_input = Dense(enc_h*enc_w*2)(enc_input)
    embedded_encoding_input = Reshape((enc_h,enc_w,2))(embedded_encoding_input)

    def _down(t, filters, kernel, dropout):
        """One stride-2 conv_block with LeakyReLU(0.2) and optional dropout 0.3."""
        return conv_block(t, filters, kernel_size=(kernel, kernel), strides=(2, 2),
                          use_bn=False, use_bias=True,
                          activation=layers.LeakyReLU(0.2),
                          use_dropout=dropout, drop_value=0.3)

    x = img_input
    x = _down(x, 64, 5, False)
    x = _down(x, 128, 5, True)
    x = _down(x, 128, 3, True)
    x = Concatenate(axis=-1)([x,embedded_category_input])
    x = _down(x, 256, 3, True)
    x = Concatenate(axis=-1)([x,embedded_encoding_input])
    x = _down(x, 256, 3, True)

    x = layers.Flatten()(x)
    x = layers.Dropout(0.2)(x)
    x = layers.Dense(1)(x)

    model = Model(inputs =[img_input,cat_input,enc_input] ,
                  outputs=x, name="Discriminator")
    return model
d_model = get_discriminator(N_CATEGORIES)
d_model.summary()


Model: "Discriminator"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 300, 300, 3  0           []                               
                                )]                                                                
                                                                                                  
 conv2d (Conv2D)                (None, 300, 300, 64  4864        ['input_1[0][0]']                
                                )                                                                 
                                                                                                  
 leaky_re_lu (LeakyReLU)        multiple             0           ['conv2d[0][0]',                 
                                                                  'conv2d_1[0][0]']   

In [34]:
#@title GENERATOR
def upsample_block(x,up_dim,
    filters,activation,kernel_size=(3, 3),strides=(1, 1),
    padding="same",use_bn=False,use_bias=True,use_dropout=False,drop_value=0.3):
    """Upsample by `up_dim`, then apply two Conv2D layers.

    Layout: UpSampling2D -> Conv2D -> [activation] -> Conv2D -> [BatchNorm]
    -> [activation] -> [Dropout]. Both convolutions use `strides`. A falsy
    `activation` skips both activation steps.
    """
    def _conv(tensor):
        return layers.Conv2D(filters, kernel_size, strides=strides,
                             padding=padding, use_bias=use_bias)(tensor)

    def _act(tensor):
        return activation(tensor) if activation else tensor

    y = _act(_conv(layers.UpSampling2D(up_dim)(x)))
    y = _conv(y)
    if use_bn:
        y = layers.BatchNormalization()(y)
    y = _act(y)
    if use_dropout:
        y = layers.Dropout(drop_value)(y)
    return y
def get_generator(n_used_cats=0):
    """Build the conditional generator.

    Inputs are a noise vector (N_NOISE,), an int32 category id (1,) and an
    encoding (N_ENCODING,). Each is projected by Dense -> BN -> LeakyReLU(0.5)
    to a 5x5 map (32 / 16 / 32 channels for noise / category / encoding). The
    maps are concatenated and upsampled by 2, 3, 5 and 2
    (5 -> 10 -> 30 -> 150 -> 300). The final block ends in tanh, so outputs lie
    in [-1, 1].

    Args:
        n_used_cats: number of category ids (Embedding input_dim); must be >= 1.

    Raises:
        ValueError: if n_used_cats < 1, or if the fixed upsampling schedule
            does not yield INPUT_SHAPE.
    """
    if n_used_cats < 1:
        raise ValueError(f"n_used_cats must be >= 1, got {n_used_cats}")
    noise_input = Input((N_NOISE,))
    cat_input = Input((1,),dtype=tf.dtypes.int32)
    enc_input = Input((N_ENCODING,))
    embedded_category_input = layers.Embedding(input_dim = n_used_cats,
                                               output_dim = N_CAT_EMBEDDING,
                                               input_length = 1)(cat_input)
    embedded_category_input = Reshape((N_CAT_EMBEDDING,))(embedded_category_input)

    side_size = 5
    n_noise_filters = 32
    n_cat_filters = 16
    n_enc_filters = 32
    n_units = side_size*side_size

    def _project(t, n_filters):
        """Dense -> BN -> LeakyReLU(0.5), reshaped to (side_size, side_size, n_filters)."""
        t = Dense(n_units*n_filters)(t)
        t = BatchNormalization()(t)
        t = LeakyReLU(0.5)(t)
        return Reshape((side_size, side_size, n_filters))(t)

    # Same layer creation order as before: noise, encoding, category.
    processed_noise = _project(noise_input, n_noise_filters)
    processed_enc = _project(enc_input, n_enc_filters)
    processed_cat = _project(embedded_category_input, n_cat_filters)

    x = Concatenate(axis=-1)([processed_cat,processed_noise,processed_enc])

    x = upsample_block(x,(2,2),256,layers.LeakyReLU(0.2),kernel_size=(3,3),strides=(1, 1),
                       use_bias=False,use_bn=True,padding="same",use_dropout=False,)
    x = upsample_block(x,(3,3),256,layers.LeakyReLU(0.2),kernel_size=(3,3),strides=(1, 1),
                       use_bias=False,use_bn=True,padding="same",use_dropout=False)
    x = upsample_block(x,(5,5),128,layers.LeakyReLU(0.2),kernel_size=(5,5),strides=(1, 1),
                       use_bias=False,use_bn=True,padding="same",use_dropout=False)
    x = upsample_block(x,(2,2), 3, layers.Activation("tanh"), strides=(1, 1),
                       use_bias=False, use_bn=True)

    # The schedule above is fixed; fail loudly if INPUT_SHAPE was changed.
    if tuple(x.shape[1:]) != tuple(INPUT_SHAPE):
        raise ValueError(f"Generator output shape {tuple(x.shape[1:])} does not "
                         f"match INPUT_SHAPE {tuple(INPUT_SHAPE)}")

    model = Model(inputs =[noise_input,cat_input,enc_input] ,
                  outputs=x, name="Generator")
    return model
g_model = get_generator(N_CATEGORIES)
g_model.summary()

Model: "Generator"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_5 (InputLayer)           [(None, 1)]          0           []                               
                                                                                                  
 embedding_1 (Embedding)        (None, 1, 16)        80          ['input_5[0][0]']                
                                                                                                  
 reshape_3 (Reshape)            (None, 16)           0           ['embedding_1[0][0]']            
                                                                                                  
 input_4 (InputLayer)           [(None, 256)]        0           []                               
                                                                                          

In [35]:
#@title wgan model

class WGAN(keras.Model):
    """Conditional WGAN-GP wrapping a critic and a generator.

    Training batches are (images, category_ids, encodings). Each generator
    step is preceded by `discriminator_extra_steps` critic steps. The critic
    also sees real images paired with another sample's category/encoding
    (the batch rolled by one) and is pushed to score those pairs low.
    """
    def __init__(
        self,
        discriminator,
        generator,
        latent_dim,
        discriminator_extra_steps=3,
        gp_weight=10.0,
    ):
        """
        Args:
            discriminator: critic taking [images, cats, encodings].
            generator: model taking [noise, cats, encodings].
            latent_dim: length of the generator noise vector.
            discriminator_extra_steps: critic updates per generator update (>= 1).
            gp_weight: weight of the gradient penalty in the critic loss.

        Raises:
            ValueError: if discriminator_extra_steps < 1 (train_step would
                otherwise have no critic metrics to report).
        """
        super(WGAN, self).__init__()
        if discriminator_extra_steps < 1:
            raise ValueError("discriminator_extra_steps must be >= 1, got "
                             f"{discriminator_extra_steps}")
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim
        self.d_steps = discriminator_extra_steps
        self.gp_weight = gp_weight

    def compile(self, d_optimizer, g_optimizer):
        """Store the optimizers; all losses are computed inside train_step."""
        super(WGAN, self).compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer

    def gradient_penalty(self, batch_size, real_images, fake_images,cats,encodings):
        """Mean of (||d critic / d x_hat||_2 - 1)^2 over interpolated images x_hat."""
        # WGAN-GP samples alpha ~ U[0, 1] so x_hat lies on the segment between
        # a real and a fake image (a normal alpha lands outside that segment).
        alpha = tf.random.uniform([batch_size, 1, 1, 1], 0.0, 1.0)
        interpolated = real_images + alpha * (fake_images - real_images)

        with tf.GradientTape() as gp_tape:
            gp_tape.watch(interpolated)
            pred = self.discriminator([interpolated,cats,encodings], training=True)

        grads = gp_tape.gradient(pred, [interpolated])[0]
        norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]))
        return tf.reduce_mean((norm - 1.0) ** 2)

    def train_step(self, x):
        """Run `d_steps` critic updates followed by one generator update.

        Args:
            x: batch (real_images, category_ids, encodings).

        Returns:
            Scalar metrics: mean critic scores from the first critic step
            ("fake_logits", "real_logits", "rolled_logits"), its gradient
            penalty ("grad_pen"), the last critic loss ("d_loss") and the
            generator loss ("g_loss").
        """
        real_images = x[0]
        cats = x[1]
        encodings = x[2]
        # Sample i gets the condition of sample i-1: mismatched (image, condition) pairs.
        rolled_cats = tf.roll(cats,1,axis=0)
        rolled_encodings = tf.roll(encodings,1,axis=0)

        batch_size = tf.shape(real_images)[0]

        saved_metrics = {}
        for i in range(self.d_steps):
            random_latent_vectors = tf.random.normal(
                shape=(batch_size, self.latent_dim),mean=0,stddev=1)
            with tf.GradientTape() as tape:
                fake_images = self.generator([random_latent_vectors,cats,encodings], training=True)
                fake_logits = self.discriminator([fake_images,cats,encodings], training=True)
                real_logits = self.discriminator([real_images,cats,encodings], training=True)
                # Real images with the wrong category/encoding.
                rolled_logits = self.discriminator([real_images,rolled_cats,rolled_encodings], training=True)

                untrue_logits = tf.concat([fake_logits,rolled_logits],axis=0)

                gp = self.gradient_penalty(batch_size, real_images, fake_images,cats,encodings)

                # NOTE(review): `untrue_logits` already contains `fake_logits`, so
                # fakes are effectively weighted 1.5 and mismatched pairs 0.5
                # against 2 * real. Confirm this weighting is intended.
                d_loss = tf.reduce_mean(fake_logits) + tf.reduce_mean(untrue_logits) \
                        - 2*tf.reduce_mean(real_logits) + gp * self.gp_weight

            d_gradient = tape.gradient(d_loss, self.discriminator.trainable_variables)
            self.d_optimizer.apply_gradients(
                zip(d_gradient, self.discriminator.trainable_variables)
            )
            if i==0:
                # Reduce to scalars: Keras progress logging and the TensorBoard
                # callback expect scalar metrics, not (batch, 1) score tensors.
                saved_metrics = {"fake_logits":tf.reduce_mean(fake_logits),
                                 "real_logits":tf.reduce_mean(real_logits),
                                 "rolled_logits":tf.reduce_mean(rolled_logits),
                                 "grad_pen":gp}

        # Generator step: maximize the critic score of generated images.
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
        with tf.GradientTape() as tape:
            generated_images = self.generator([random_latent_vectors,cats,encodings], training=True)
            gen_img_logits = self.discriminator([generated_images,cats,encodings], training=True)
            g_loss = -tf.reduce_mean(gen_img_logits)

        gen_gradient = tape.gradient(g_loss, self.generator.trainable_variables)
        self.g_optimizer.apply_gradients(
            zip(gen_gradient, self.generator.trainable_variables)
        )
        saved_metrics.update({"d_loss": d_loss, "g_loss": g_loss})
        return saved_metrics
# Standalone WGAN loss helpers. WGAN.train_step computes its own losses and
# does not call these; they are kept for reference/compatibility.
def discriminator_loss(real_img, fake_img):
    """Critic loss without gradient penalty: mean(fake) - mean(real)."""
    real_loss = tf.reduce_mean(real_img)
    fake_loss = tf.reduce_mean(fake_img)
    return fake_loss - real_loss


def generator_loss(fake_img):
    """Generator loss: -mean(critic score of generated images)."""
    return -tf.reduce_mean(fake_img)

wgan = WGAN(
    discriminator=d_model,
    generator=g_model,
    latent_dim=N_NOISE,
    discriminator_extra_steps=3,
)
# One optimizer per network (previously each was constructed twice and the
# first pair was discarded). lr=0.0002, beta_1=0.5 as recommended; beta_2=0.9.
generator_optimizer = keras.optimizers.Adam(
    learning_rate=0.0002, beta_1=0.5, beta_2=0.9
)
discriminator_optimizer = keras.optimizers.Adam(
    learning_rate=0.0002, beta_1=0.5, beta_2=0.9
)
wgan.compile(
    d_optimizer=discriminator_optimizer,
    g_optimizer=generator_optimizer)


## Load previously trained triplet similarity model to provide encodings for WGAN conditioning

In [36]:
#@title Get model fn
from tensorflow.keras import layers
from tensorflow.keras.layers import Input, Dense, Reshape, BatchNormalization,\
                                    Flatten, GlobalAveragePooling2D, Dropout

# Not referenced by any visible cell. NOTE(review): presumably a VAE-loss
# weight used when training the triplet model; confirm or remove.
tf_variable_VAE_loss_weight = tf.Variable(initial_value=0.0,trainable=False)

def sample_z(args):
  """Reparameterization trick: mu + exp(0.5 * sigma) * eps with eps ~ N(0, 1).

  Args:
    args: pair (mu, sigma); `sigma` is used as a log-variance.
  """
  mean, log_var = args
  noise_shape = (K.shape(mean)[0], K.int_shape(mean)[1])
  eps = K.random_normal(shape=noise_shape)
  return mean + K.exp(0.5 * log_var) * eps

def get_online_triplet_model(n_used_cats=0,version=0):
    """Rebuild the triplet-similarity encoder and wrap it for per-category lookup.

    The ImageNet backbone selected by the global BACKBONE is frozen, then
    followed by GAP -> BN -> Flatten -> Dropout(0.3). For version != 9, the
    head is Dense(1024, relu) -> Dropout -> one L2-normalized
    Dense(N_ENCODING) per category, stacked to (batch, n_used_cats,
    N_ENCODING). This encoder is named "EncodingMod" so its weights can be
    loaded on their own. The returned model takes [image, category_id] and
    gathers that category's encoding.

    For version == 9 the head is a VAE-style sampler with a single 768-d
    encoding. That encoding is repeated for every category, so the gather
    returns it regardless of category.

    Returns:
        (model, preprocessing_fn): the wrapped model and the backbone's
        `preprocess_input`.

    Raises:
        ValueError: unknown BACKBONE or n_used_cats < 1.
    """
    if n_used_cats < 1:
        raise ValueError(f"n_used_cats must be >= 1, got {n_used_cats}")

    # Every backbone is built for INPUT_SHAPE because the wrapper below feeds it
    # INPUT_SHAPE images (B1/VGG were previously built for 244x244).
    if BACKBONE == 'EfficientNetB1':
        from tensorflow.keras.applications.efficientnet import EfficientNetB1,preprocess_input
        model_backbone = EfficientNetB1(weights='imagenet',include_top=False, input_shape=INPUT_SHAPE)
    elif BACKBONE == 'VGG':
        from tensorflow.keras.applications.vgg16 import VGG16,preprocess_input
        model_backbone = VGG16(weights='imagenet',include_top=False, input_shape=INPUT_SHAPE)
    elif BACKBONE == 'EfficientNetB3':
        from tensorflow.keras.applications.efficientnet import EfficientNetB3,preprocess_input
        model_backbone = EfficientNetB3(weights='imagenet',include_top=False, input_shape=INPUT_SHAPE)
    elif BACKBONE == 'EfficientNetB4':
        from tensorflow.keras.applications.efficientnet import EfficientNetB4,preprocess_input
        model_backbone = EfficientNetB4(weights='imagenet',include_top=False, input_shape=INPUT_SHAPE)
    elif BACKBONE == 'xception':
        from tensorflow.keras.applications import Xception
        from tensorflow.keras.applications.xception import preprocess_input
        model_backbone = Xception(weights='imagenet',include_top=False, input_shape=INPUT_SHAPE)
    else:
        raise ValueError(f"Unsupported BACKBONE {BACKBONE!r}; expected one of "
                         "'EfficientNetB1', 'VGG', 'EfficientNetB3', 'EfficientNetB4', 'xception'")
    model_backbone.trainable = False
    x = GlobalAveragePooling2D(name="avg_pool")(model_backbone.outputs[0])
    x = BatchNormalization()(x)
    x = Flatten()(x)
    x = Dropout(0.3)(x)
    if version == 9:
        z_vae_encoding = Dense(1536,name="VAE_enc")(x)
        z_mean,z_logvar = tf.split(z_vae_encoding, num_or_size_splits=2, axis=1)
        x = layers.Lambda(sample_z, output_shape=(1536//2, ), name='sampled_enc')([z_mean, z_logvar])
        # `x_encoding` was never assigned on this path (NameError below). The VAE
        # head has no per-category output, so repeat it for every category.
        x_encoding = tf.stack([x] * n_used_cats, axis=1)
    else:
        x = Dense(1024,activation='relu')(x)
        x = Dropout(0.3)(x)
        inside_encodings=[]
        for _ in range(n_used_cats):
            x_encoding = Dense(N_ENCODING)(x)
            x_encoding = K.l2_normalize(x_encoding, axis=-1)
            inside_encodings.append(x_encoding)
        x_encoding = tf.stack(inside_encodings,axis=1)

    preprocessing_fn = preprocess_input
    inp = Input(INPUT_SHAPE)

    one_model = Model(inputs = model_backbone.inputs,outputs=x_encoding, name="EncodingMod")
    encodings = one_model(inp)

    inp_cats = Input((1,),dtype=tf.dtypes.int32)
    inputs = [inp, inp_cats]
    # Select, per sample, the encoding row of that sample's category.
    out_encodings = tf.gather_nd(encodings,indices=inp_cats,batch_dims=1,name="OutEncodings")

    model = Model(inputs = inputs, outputs = out_encodings)

    model.summary()

    return model, preprocessing_fn


## Data generator for WGAN 

In [37]:
#@title wgan_data_gen
from skimage.color import deltaE_ciede2000
import gc
                            
class DataGenerator(keras.utils.Sequence):
    """Keras Sequence yielding [images, category_ids, encodings] batches for the WGAN.

    Samples are addressed by row index into the global `merged_data`
    (column 0: image path, column 1: category id, column 3: 1 for images
    loaded as RGBA). `self.encodings` must be filled by calculate_encodings()
    before batches are requested.
    NOTE(review): `merged_data`, `N_PHOTOS`, `rs`, `load_img`, `img_to_array`
    and `overlay_random` are globals not defined in the visible cells.
    """
    def __init__(self,preprocess_fn,batch_size, img_aug, shuffle=True):
        """
        Args:
            preprocess_fn: applied to every image batch before it is returned.
            batch_size: samples per batch.
            img_aug: if True, batches go through `self.image_aug.flow`, which
                must be assigned on the instance beforehand.
            shuffle: shuffle sample order (seeded RandomState(0)) on each epoch end.
        """
        self.random_state = np.random.RandomState(seed=0)
        self.shuffle = shuffle
        self.preprocess_fn = preprocess_fn
        self.reset_state = True
        self.batch_size = batch_size
        self.on_epoch_end()
        self.img_aug = img_aug
        self.cache_used = False
        print("Total length:",len(self))

    def __len__(self):
        """Number of batches per epoch (the last batch may be smaller)."""
        return int(np.ceil((N_PHOTOS) / self.batch_size))

    def __getitem__(self, index):
        """Return batch `index` as [images, category_ids, encodings].

        category_ids has shape (bs, 1) on both the cached and uncached paths,
        and augmentation (when enabled) is applied on both paths.
        """
        gc.collect()
        indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]

        bs = len(indexes)
        X_encs = self.encodings[indexes]
        if self.cache_used is True:
            X_img = self.img_cache[indexes]
            # Previously shape (bs,), unlike the (bs, 1) of the uncached path.
            X_cats = np.array(merged_data[indexes,1],dtype='float').reshape(bs, 1)
        else:
            X_img = np.empty((bs,)+INPUT_SHAPE)
            X_cats = np.empty((bs, 1))
            for i, idx in enumerate(indexes):
                X_img[i],X_cats[i] = self.get_img(idx)

        # The cache must not change what a batch contains, so augment both paths.
        if self.img_aug is True:
            X_img = self._augment(X_img)
        X_img = self.preprocess_fn(X_img)
        return [X_img, X_cats, X_encs]

    def _augment(self, X_img):
        """Run one pass of `self.image_aug.flow` over the batch, keeping order."""
        image_aug = getattr(self, "image_aug", None)
        if image_aug is None:
            raise AttributeError("img_aug=True requires assigning `image_aug` (an object "
                                 "with a .flow method, e.g. ImageDataGenerator) first")
        return next(image_aug.flow(X_img,shuffle=False,batch_size=len(X_img)))

    def calculate_encodings(self,checkpoint,version,n_used_cats):
        """Predict and store one conditioning encoding per sample in `self.encodings`.

        Rebuilds the triplet model, loads `checkpoint` into its "EncodingMod"
        sub-model, and predicts in dataset order (row i is sample i,
        independent of the shuffled batch order). Afterwards it deletes the
        model and calls K.clear_session().
        """
        triplet_model, preprocessing_fn = get_online_triplet_model(n_used_cats=n_used_cats,version=version)
        triplet_model.get_layer('EncodingMod').load_weights(checkpoint)
        n_steps = len(self)

        def data_gen():
            all_indexes = np.arange(len(self.indexes))
            for step in range(n_steps):
                start_idx = step*self.batch_size
                chosen_indexes = all_indexes[start_idx:start_idx+self.batch_size]
                bs = len(chosen_indexes)
                X_img = np.empty((bs,)+INPUT_SHAPE)
                X_cats = np.empty((bs, 1))
                for i, idx in enumerate(chosen_indexes):
                    X_img[i],X_cats[i] = self.get_img(idx)
                X_img = preprocessing_fn(X_img)
                yield ([X_img,X_cats],None)

        self.encodings = triplet_model.predict(x=data_gen(),steps=n_steps,verbose=1)
        del triplet_model
        K.clear_session()

    def load_images(self):
        """Cache every image from get_img in RAM as uint8, indexed by sample row.

        Asserts the cache is under 15 GiB. Any random background compositing
        done by get_img is frozen into the cache.
        """
        print("Predicted size:",len(self.indexes)*INPUT_SHAPE[0]*INPUT_SHAPE[1]*3/((2**10)**3))
        assert len(self.indexes)*INPUT_SHAPE[0]*INPUT_SHAPE[1]*3<15*((2**10)**3) #smaller than 15 GB
        self.img_cache = np.empty((len(self.indexes),INPUT_SHAPE[0],INPUT_SHAPE[1],3),dtype=np.uint8)
        print(self.img_cache.nbytes)
        for idx in np.arange(len(self.indexes)):
            if idx%100==0:
                gc.collect()
            img,_ = self.get_img(idx)
            self.img_cache[idx] = img
        # Only switch to the cache once it is completely filled.
        self.cache_used = True

    def get_img(self,img_idx):
        """Load sample `img_idx` as an (H, W, 3) float array and return (image, category).

        RGBA images get a random background with probability BG_PROBABILITY;
        otherwise the alpha channel (if any) is dropped.

        Raises:
            ValueError: if the image file does not exist.
        """
        img_mode = "rgba" if merged_data[img_idx,3]==1 else "rgb"
        if os.path.isfile(merged_data[img_idx,0])==False:
            raise ValueError("Image not present at path: "+str(merged_data[img_idx,0]),img_idx)
        img = load_img(merged_data[img_idx,0],color_mode=img_mode)
        img = img_to_array(img)
        if merged_data[img_idx,3]==1 and rs.binomial(1,BG_PROBABILITY)==1:
            bg = img_to_array(load_img(BG_PATHS[rs.randint(0,len(BG_PATHS),1)[0]]))
            img = overlay_random(bg, img ,rs)
        else:
            img = img[:,:,:3]
        return img, merged_data[img_idx,1]

    def on_epoch_end(self):
        """Reset the sample order, reshuffling with the seeded RandomState if enabled."""
        self.indexes = np.arange(N_PHOTOS)
        if self.shuffle == True:
            perm = self.random_state.permutation(len(self.indexes))
            self.indexes = self.indexes[perm]

    def perform_visual_check(self,check_idx,checkpoint,version,n_used_cats):
        """Show sample `check_idx` and print its fresh vs. stored encoding (first 6 values).

        Loads `checkpoint` into the full triplet model (not only "EncodingMod").
        """
        triplet_model, preprocessing_fn = get_online_triplet_model(n_used_cats=n_used_cats,version=version)
        triplet_model.load_weights(checkpoint)
        img,cat = self.get_img(check_idx)
        preprocessed_img = preprocessing_fn(img[None,...])
        enc_pred = triplet_model.predict([preprocessed_img,np.array(cat,dtype='float')[None,...]])
        if self.cache_used is True:
            plt.subplot(1,2,1)
            plt.imshow(self.img_cache[check_idx])
            plt.subplot(1,2,2)
        plt.imshow(img/255.0)
        plt.show()
        print("Cat:",cat)
        print("Predicted enc:",enc_pred[0,:6])
        print("True enc:",self.encodings[check_idx][:6])
        del triplet_model
        K.clear_session()

def discr_preprocess_fn(x):
  """Scale [0, 255] pixel values to [-1, 1] and add small Gaussian noise.

  The generator ends in tanh, so generated images lie in [-1, 1]. Real images
  fed to the critic must use the same range. Previously they were scaled to
  [0, 1], which lets the critic separate real from fake by intensity range
  alone. The noise standard deviation is 2 gray levels (4/255 on the [-1, 1]
  scale), the same as the former 2/255 on [0, 1].

  Args:
    x: array-like of pixel values in [0, 255].
  Returns:
    float32 numpy array with the same shape as `x`.
  """
  x = np.asarray(x, dtype=np.float32) / 127.5 - 1.0
  noise = np.random.normal(loc=0.0, scale=4.0 / 255.0, size=x.shape)
  return (x + noise).astype(np.float32)

## Precalculate encodings

In [38]:
# Rebuild the triplet encoder on BACKBONE (version-6 head) and precompute one
# conditioning encoding per training image.
BACKBONE = 'EfficientNetB3'
data_gen = DataGenerator(
    preprocess_fn=discr_preprocess_fn,
    batch_size=BATCH_SIZE,
    img_aug=False,
    shuffle=True,
)
data_gen.calculate_encodings(
    os.path.join(DRIVE_DIR, "checkpoints", "s2s_v6_EfficientNetB3_300_256_03_KbmgV3C.h5"),
    6,
    n_used_cats=N_CATEGORIES,
)
assert len(data_gen.encodings) == len(merged_data)

Total length: 672
Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_8 (InputLayer)           [(None, 300, 300, 3  0           []                               
                                )]                                                                
                                                                                                  
 EncodingMod (Functional)       (None, 5, 256)       13675567    ['input_8[0][0]']                
                                                                                                  
 input_9 (InputLayer)           [(None, 1)]          0           []                               
                                                                                                  
 tf.compat.v1.gather_nd (TFOpLa  (None, 256)         0           ['EncodingM

## Visually check the encodings, then inspect them with t-SNE

In [None]:
# Cache all images in RAM, then print a fresh prediction next to the stored
# encoding for one sample.
# NOTE(review): this checkpoint differs from the .h5 used to compute
# data_gen.encodings, so the two encodings may legitimately differ.
data_gen.load_images()
data_gen.perform_visual_check(
    check_idx=10000,
    checkpoint=os.path.join(DRIVE_DIR, "checkpoints",
                            "s2s_v6_EfficientNetB3_300_256_03_KbmgBothAugReviewedUnfrzH.ckpt"),
    version=6,
    n_used_cats=N_CATEGORIES,
)

In [44]:
from sklearn.manifold import TSNE

TSNE_SEED = 0  # fixed so the 2-D map is reproducible across runs

# Keep samples whose category column equals the string '0' (dresses in
# cats_dict). NOTE(review): assumes merged_data stores category ids as strings.
# An earlier filter on merged_data[i,3]=='1' (RGBA images) is no longer applied.
no_bg_idx = [i for i in range(len(data_gen.encodings)) if merged_data[i,1]=='0']
X_embedded = TSNE(n_components=2, learning_rate='auto', init='random',
                  random_state=TSNE_SEED).fit_transform(data_gen.encodings[no_bg_idx])

In [45]:
import matplotlib.pyplot as plt
from matplotlib.offsetbox import OffsetImage, AnnotationBbox

N_TSNE_THUMBNAILS = 3000  # upper bound on thumbnails drawn on the t-SNE map

fig, ax = plt.subplots(figsize=(100,100))
# Invisible scatter so the axes limits cover every embedded point.
ax.scatter(X_embedded[:,0], X_embedded[:,1],s=0)
# replace=False raises if asked for more samples than exist, so cap the size.
random_indexes = np.random.choice(np.arange(len(X_embedded)),
                                  size=min(N_TSNE_THUMBNAILS, len(X_embedded)),
                                  replace=False)

def getImage(path):
    """Load the image at `path` as an OffsetImage for use in an AnnotationBbox."""
    return OffsetImage(plt.imread(path))

thumb_paths = (merged_data[no_bg_idx,0])[random_indexes]
for x0, y0, path in zip(X_embedded[random_indexes,0], X_embedded[random_indexes,1], thumb_paths):
    ax.add_artist(AnnotationBbox(getImage(path), (x0, y0), frameon=False))
ax.axis("off")
fig.savefig(os.path.join(DRIVE_DIR,"dresses.png"))
plt.show()


[]

## Prepare WGAN callbacks and fit

In [None]:
#@title GAN plotter to check results
class GANMonitor(keras.callbacks.Callback):
    """Save a strip of generator samples after every epoch.

    A fixed random set of `num_img` dataset samples (their categories and
    encodings) and a fixed noise vector per sample are drawn once, so strips
    from different epochs are directly comparable. The matching real images
    are saved once as generated_base_img.png. Each epoch writes
    generated_<epoch+1>_img.png, all under DRIVE_DIR/wgan_images.
    """
    def __init__(self, num_img, data_gen):
        super().__init__()  # sets up Callback state (e.g. self.model)
        self.num_img = num_img
        self.img_indexes = np.random.choice(N_PHOTOS,size=num_img,replace=False)
        self.latent_vec = tf.random.normal(shape=(self.num_img, N_NOISE))
        self.encodings = data_gen.encodings[self.img_indexes]
        self.img_path = os.path.join(DRIVE_DIR,"wgan_images")
        os.makedirs(self.img_path, exist_ok=True)
        base_imgs = []
        cats = []
        for img_idx in self.img_indexes:
            img, cat = data_gen.get_img(img_idx)
            base_imgs.append(img)
            cats.append(cat)
        self.cats = np.array(cats,dtype='int')
        self._save_strip(base_imgs, "generated_base_img.png")

    def _save_strip(self, images, filename):
        """Place images side by side and save them as `filename` in self.img_path.

        scale=False keeps the true [0, 255] pixel values; the default
        (scale=True) min-max stretches each strip and hides colour/contrast
        problems in the generated images.
        """
        strip = np.clip(np.concatenate(list(images), axis=1), 0, 255)
        keras.preprocessing.image.array_to_img(strip, scale=False).save(
            os.path.join(self.img_path, filename))

    def on_epoch_end(self, epoch, logs=None):
        generated_images = self.model.generator.predict([self.latent_vec,self.cats,self.encodings])
        generated_images = (generated_images * 127.5) + 127.5  # tanh [-1, 1] -> [0, 255]
        # Format the file name before joining so braces in the path are never touched.
        self._save_strip(generated_images, "generated_{epoch}_img.png".format(epoch=epoch+1))

monitor_cb = GANMonitor(10,data_gen)


In [None]:
# Re-shuffle the data order, attach the WGAN to the monitor by hand, and save
# the "epoch 0" strip from the still-untrained generator.
data_gen.batch_size = BATCH_SIZE
data_gen.on_epoch_end()
monitor_cb.set_model(wgan)
monitor_cb.on_epoch_end(epoch=-1)

In [None]:
import datetime

logdir = os.path.join("logs", datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
tensorboard_cb = tf.keras.callbacks.TensorBoard(logdir, histogram_freq=1)

# `resolver` only exists on a TPU runtime; creating TPUStrategy unconditionally
# raised NameError on the GPU runtime. Fall back to the default strategy.
# NOTE(review): the WGAN was not built inside `strategy.scope()`, so this
# strategy does not distribute training. Build d_model/g_model/wgan under the
# scope if TPU training is intended.
strategy = tf.distribute.TPUStrategy(resolver) if USE_TPU else tf.distribute.get_strategy()

wgan.fit(data_gen,callbacks=[tensorboard_cb,monitor_cb],steps_per_epoch=100)