In [1]:
import numpy as np
import gc
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import pickle
import math
from tensorflow.keras import backend as K
import sys
import cv2
import time

In [2]:
class Sampling(layers.Layer):
    """Uses (z_mean, z_log_var) to sample z, the vector encoding a digit."""

    def call(self, inputs):
        z_mean, z_log_var = inputs
        batch = tf.shape(z_mean)[0]
        dim = tf.shape(z_mean)[1]
        epsilon = tf.keras.backend.random_normal(shape=(batch, dim))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon


In [3]:
class maskVAE(keras.Model):
    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.train_mask_loss = keras.metrics.Mean(name="train_mask_loss")
        self.val_mask_loss = keras.metrics.Mean(name="val_mask_loss")
        self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")
        self.kl_weight = 0.0

    @property
    def metrics(self):
        return [
            self.total_loss_tracker,
            self.train_mask_loss,
            self.kl_loss_tracker,
        ]
    
    def reconstruction_loss(self, true_masks, pred_masks, z_mean, z_logvar, z_latent):
    
        kl_loss = tf.reduce_mean(0.5 * tf.reduce_sum(tf.square(z_mean) + tf.square(tf.exp(z_logvar)) - 2*(z_logvar) - 1,
                                                     axis=1))
        mask_loss = tf.reduce_mean(tf.keras.backend.binary_crossentropy(true_masks,pred_masks))

        return mask_loss, kl_loss

    def train_step(self, data):
        (bx, cls, mask) = data[0]
        with tf.GradientTape() as tape:
            z_mean, z_log_var, z_latent = self.encoder(data)
            pred_mask = self.decoder((bx, cls, z_latent))
            mask_loss, kl_loss = self.reconstruction_loss(mask, pred_mask, z_mean, z_log_var, z_latent)
            total_loss = mask_loss + kl_loss*self.kl_weight
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        self.total_loss_tracker.update_state(total_loss)
        self.train_mask_loss.update_state(mask_loss)
        self.kl_loss_tracker.update_state(kl_loss)
#         self.update_kl_weight()
        return {
            "loss": self.total_loss_tracker.result(),
            "mask_loss": self.train_mask_loss.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }
    
    def test_step(self, data):
        
        z_mean, z_log_var, z_latent = self.encoder(data)
        (bx, cls, mask) = data[0]
        pred_mask = self.decoder((bx, cls, z_latent))
        mask_loss, kl_loss = self.reconstruction_loss(mask, pred_mask, z_mean, z_log_var, z_latent)
        total_loss = mask_loss + kl_loss*self.kl_weight
        self.val_mask_loss.update_state(mask_loss)
        return {
            "mask_loss": self.val_mask_loss.result(),
        }
    
    def generate(self, bbx_cond, class_cond):
        
        batch = len(bbx_cond)
        dim = 64
        z_latent = tf.keras.backend.random_normal(shape=(batch, dim))
        pred_mask = self.decoder((bbx_cond, class_cond, z_latent))
        
        return pred_mask
    
    def reconstruct(self, data):
        
        z_mean, z_log_var, z_latent = self.encoder(data)
        (bx, cls, mask) = data[0]
        pred_mask = self.decoder((bx, cls, z_latent))
        
        return pred_mask
        
    
    @tf.function
    def update_kl_weight(self):
        if (self.kl_loss_tracker.result() > 10.0 
            and abs(self.train_mask_loss.result() - self.val_mask_loss.result())< 0.1 
            and self.kl_weight<0.5):
            self.kl_weight += 0.01 
        

In [4]:
def load_data(file_postfix):
    
    outfile = 'C:/GitHub/meronymnet/data_np_16/X_train'+file_postfix+'.np'
    with open(outfile, 'rb') as pickle_file:
        X_train = pickle.load(pickle_file)

    outfile = 'C:/GitHub/meronymnet/data_np_16/class_v'+file_postfix+'.np'
    with open(outfile, 'rb') as pickle_file:
        class_v = pickle.load(pickle_file)

    outfile = 'C:/GitHub/meronymnet/data_np_16/masks_train'+file_postfix+'.np'
    with open(outfile, 'rb') as pickle_file:
        masks = pickle.load(pickle_file)

    outfile = 'C:/GitHub/meronymnet/data_np_16/X_train_val'+file_postfix+'.np'
    with open(outfile, 'rb') as pickle_file:
        X_train_val = pickle.load(pickle_file)

    outfile = 'C:/GitHub/meronymnet/data_np_16/class_v_val'+file_postfix+'.np'
    with open(outfile, 'rb') as pickle_file:
        class_v_val = pickle.load(pickle_file)

    outfile = 'C:/GitHub/meronymnet/data_np_16/masks_val'+file_postfix+'.np'
    with open(outfile, 'rb') as pickle_file:
        masks_val = pickle.load(pickle_file)

     #outfile = 'C:/GitHub/meronymnet/data_np_16/X_test'+file_postfix+'.np'
     #with open(outfile, 'rb') as pickle_file:
     #    X_test = pickle.load(pickle_file)

     #outfile = 'C:/GitHub/meronymnet/data_np_16/X_test'+file_postfix+'.np'
     #with open(outfile, 'rb') as pickle_file:
     #    X_obj_test = pickle.load(pickle_file)
        
        
    return X_train, class_v, masks, X_train_val, class_v_val, masks_val

In [5]:
def shuffle_latent(a, b, c, d=None):
    p = np.random.permutation(len(a))
    if d is None:
        return a[p], b[p], c[p]
    return a[p], b[p], c[p], d[p]

def sampling(z_mean, z_log_var):
    epsilon = tf.random_normal(tf.shape(z_log_var), name="epsilon")
    return z_mean + epsilon * tf.exp(z_log_var)

def frange_cycle_linear(n_iter, start=0.0, stop=1.0,  n_cycle=4, ratio=0.5):
    L = np.ones(n_iter) * stop
    period = n_iter/n_cycle
    step = (stop-start)/(period*ratio)

    for c in range(n_cycle):
        v, i = start, 0
        while v <= stop and (int(i+c*period) < n_iter):
            L[int(i+c*period)] = v
            v += step
            i += 1
    return L

In [6]:
def reconstruction_loss(true_masks, pred_masks, z_mean, z_logvar, z_latent):
    
    kl_loss = tf.reduce_mean(0.5 * tf.reduce_sum(tf.square(z_mean) + tf.square(tf.exp(z_logvar)) - 2*(z_logvar) - 1,
                                                 axis=1))  
    mask_loss = tf.reduce_mean(tf.keras.backend.binary_crossentropy(true_masks,pred_masks))

    return mask_loss, kl_loss
    

In [7]:
batch_size = 10
max_num_node = 16
latent_dims = 64

latent_inputs = keras.Input(shape=(latent_dims,))

true_maps = keras.Input(shape=([max_num_node, 64, 64, 1]), dtype=tf.float32)
true_masks = keras.Input(shape=([max_num_node, 64, 64, 1]), dtype=tf.float32)
true_edges = keras.Input(shape=([max_num_node, 64, 64, 1]), dtype=tf.float32)

true_bbxs = keras.Input(shape=([max_num_node, 4]), dtype=tf.float32)
cond_bbxs = keras.Input(shape=([max_num_node, 4]), dtype=tf.int32)

true_lbls = keras.Input(shape=([max_num_node, 1]), dtype=tf.float32)
cond_lbls = keras.Input(shape=([max_num_node, 1]), dtype=tf.float32)

true_classes = keras.Input(shape=([7]), dtype=tf.float32)
cond_classes = keras.Input(shape=([7]), dtype=tf.float32)


rnn_bbxs = layers.Bidirectional(layers.GRU(4, return_sequences=True))(true_bbxs)
concatenated_bbx_lbl = rnn_bbxs
dense_cond = layers.Dense(64, activation='tanh')(true_classes)
enc = layers.TimeDistributed(layers.Conv2D(8, kernel_size=3))(true_masks)
enc = layers.TimeDistributed(layers.BatchNormalization(trainable = False))(enc)
enc = layers.TimeDistributed(layers.Activation('relu'))(enc)

enc = layers.TimeDistributed(layers.Conv2D(16, kernel_size=3))(enc)
enc = layers.TimeDistributed(layers.BatchNormalization(trainable = False))(enc)
enc = layers.TimeDistributed(layers.Activation('relu'))(enc)

enc = layers.TimeDistributed(layers.MaxPooling2D(pool_size=(2, 2)))(enc)

enc = layers.TimeDistributed(layers.Conv2D(32, kernel_size=3, activation='relu'))(enc)
enc = layers.TimeDistributed(layers.BatchNormalization(trainable = False))(enc)
enc = layers.TimeDistributed(layers.Activation('relu'))(enc)
enc = layers.TimeDistributed(layers.Flatten())(enc)
TDD = layers.TimeDistributed(layers.Dense(64, activation='relu', name = 'encoded_bitmaps'))
dense_enc_maps = TDD(enc)

BGRU = layers.Bidirectional(layers.GRU(32, return_sequences=True))
rnn_maps = BGRU(dense_enc_maps)

D = layers.Dense(64, activation='tanh')
attention = D(concatenated_bbx_lbl)
sent_representation = layers.Multiply()([rnn_maps, attention])
sent_representation = layers.Multiply()([sent_representation, dense_cond])
images_with_attention = layers.Lambda(lambda xin: K.sum(xin, axis=-2),
                                            output_shape=(128,))(sent_representation)

z_mean = layers.Dense(64, activation='tanh')(images_with_attention)
z_log_var = layers.Dense(64, activation='tanh', name='z_logvar')(images_with_attention)

z_latent = z_mean #sampling(z_mean, z_log_var)
#z_latent = Sampling()([z_mean, z_log_var])
encoder = keras.Model(inputs=[true_bbxs, true_classes, true_masks],
                      outputs=[z_mean, z_log_var, z_latent], 
                      name='encoder')

# Decoder
cond_bbx = layers.Lambda(lambda xin: K.sum(xin, axis=-1), output_shape=(4,))(cond_bbxs)
cond_cat = cond_bbx

cond_fully_cat = layers.Dense(64, activation='relu')(cond_cat)
cond_class_ = layers.Dense(64, activation='relu')(cond_classes)
conditioned_z = layers.concatenate([cond_fully_cat, latent_inputs], axis=-1, name='conditioned_z_1')
conditioned_z = layers.concatenate([conditioned_z, cond_class_], axis=-1, name='conditioned_z_2')
decoded = layers.RepeatVector(max_num_node)(conditioned_z)
decoded = layers.Bidirectional(layers.GRU(32, return_sequences=True))(decoded)
dec_dense = layers.TimeDistributed(layers.Dense(25088, activation='relu',  name = 'encoding'))(decoded)
dec_conv = layers.TimeDistributed(layers.Reshape((28, 28, 32)))(dec_dense)

dec = layers.TimeDistributed(layers.Conv2DTranspose(32, kernel_size=3, padding='same'))(dec_conv)
dec = layers.TimeDistributed(layers.BatchNormalization(trainable = False))(dec)
dec = layers.TimeDistributed(layers.Activation('relu'))(dec)

dec = layers.TimeDistributed(layers.Conv2DTranspose(16, kernel_size=3))(dec)
dec = layers.TimeDistributed(layers.BatchNormalization(trainable = False))(dec)
dec = layers.TimeDistributed(layers.Activation('relu'))(dec)

dec = layers.TimeDistributed(layers.UpSampling2D(size=(2, 2)))(dec)

dec = layers.TimeDistributed(layers.Conv2DTranspose(8, kernel_size=3))(dec)
dec = layers.TimeDistributed(layers.BatchNormalization(trainable = False))(dec)
dec = layers.TimeDistributed(layers.Activation('relu'))(dec)

decoder_bitmaps = layers.TimeDistributed(layers.Conv2DTranspose(1, kernel_size=3,
                                                                            activation='sigmoid', 
                                                                            name = 'decoded_mask'))(dec)
decoder = keras.Model(inputs=[cond_bbxs, cond_classes, latent_inputs],
                      outputs=decoder_bitmaps, 
                      name='decoder')

In [8]:
lr = 0.00002
file_postfix = '_combined_mask_data'

mask_vae_model = maskVAE(encoder, decoder)
X_train, class_v, masks, X_train_val, class_v_val, masks_val = load_data(file_postfix)
gc.collect()
# train_dataset = tf.data.Dataset.from_tensor_slices((X_train[:,:,1:], class_v, masks))
# val_dataset = tf.data.Dataset.from_tensor_slices((X_train_val[:,:,1:], class_v_val, masks_val))

def scheduler(epoch, lr):
    if epoch < 10:
        return lr
    elif epoch%10==0:
        return lr * tf.math.exp(-0.1)
    else:
        return lr
    
callbacks = [
    tf.keras.callbacks.ModelCheckpoint(
        filepath="D:/meronym_data/runs/mask_generation_model_tf2_reconstruction/lr00002/maskvae.ckpt",
        save_freq=10),
    tf.keras.callbacks.LearningRateScheduler(scheduler)]

mask_vae_model.compile(optimizer=keras.optimizers.Adam(lr))
mask_vae_model.fit(((X_train[:,:,1:], class_v, masks)),
                   epochs=200, batch_size=16, 
                   validation_data=((X_train_val[:,:,1:], class_v_val, masks_val),),
                   callbacks=callbacks, 
                   shuffle=True)

Epoch 1/200


ValueError: Model <__main__.maskVAE object at 0x000002218B0ECD00> cannot be saved either because the input shape is not available or because the forward pass of the model is not defined.To define a forward pass, please override `Model.call()`. To specify an input shape, either call `build(input_shape)` directly, or call the model on actual data using `Model()`, `Model.fit()`, or `Model.predict()`. If you have a custom training step, please make sure to invoke the forward pass in train step through `Model.__call__`, i.e. `model(inputs)`, as opposed to `model.call()`.

In [None]:
tf.__version__

In [None]:
checkpoint_path = "C:/MeronymNet-PyTorch/src/mask_generation_model_tf2_reconstruction/lr00002/maskvae.ckpt"
mask_vae_model.load_weights(checkpoint_path)


In [None]:
generated_masks = []
batch_size = 100
for i in range(len(X_train_val)//batch_size):
    generated_masks.append(
        mask_vae_model.reconstruct(
            (
                np.float32(X_train_val[i:batch_size*(i+1),:,1:]),
                np.float32(class_v_val[i:batch_size*(i+1)])
                np.float32(masks_val[i:batch_size*(i+1)])
            )
        )
    )
    gc.collect()


In [None]:
outfile = 'D:/meronym_data/generated_masks.npy'
with open(outfile, 'wb') as pickle_file:
    generated_masks = np.concatenate(generated_masks)
    pickle.dump(generated_masks, pickle_file)