In [None]:
import math
import numpy as np

import cv2
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
import plotly.graph_objects as go
import plotly.io as pio
pio.renderers.default='notebook'
%matplotlib inline

import tensorflow as tf
tf.keras.backend.clear_session()
print("Num GPUs Available: ", len(tf.config.experimental.list_physical_devices('GPU')))
# Device-placement logging prints a line for every executed op; with the
# unrolled inference loops used below that floods the notebook output and
# contradicts the log silencing that follows, so keep it switched off.
tf.debugging.set_log_device_placement(False)

import logging
import os

# Silence TensorFlow's C++ backend and its Python logger. tf.get_logger() is
# the 'tensorflow' logger, so a single setLevel call is enough.
# NOTE(review): TF_CPP_MIN_LOG_LEVEL is only fully honoured when it is set
# before `import tensorflow`; consider moving it to the top of the import cell.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # FATAL
logger = tf.get_logger()
logger.setLevel(logging.FATAL)

In [None]:
print(tf.version.VERSION)  # same string as tf.__version__

In [None]:
def load_horses_orig(path, image_size, data_shape, n_images=328, n_train=200, n_val=50):
    """Load horse images and binary masks, resize them and split train/val/test.

    Files are read from ``<path>/images/image-<i>.png`` and
    ``<path>/masks/mask-<i>.png`` for i in range(n_images). ``path`` may be
    given with or without a trailing slash.

    Args:
        path: dataset root directory.
        image_size: images and masks are resized to image_size x image_size.
        data_shape: '3d' -> images (N, H, W, 3) and masks (N, H, W, 1);
            anything else -> flattened rows (N, H*W*3) and (N, H*W).
        n_images: number of image/mask pairs to read (default 328).
        n_train: first n_train pairs form the training split (default 200).
        n_val: the next n_val pairs form the validation split (default 50);
            everything after that is the test split.

    Returns:
        xdata, xval, xtest, ydata, yval, ytest. Images are RGB uint8 and
        masks are 0/1 integers.

    Raises:
        FileNotFoundError: if OpenCV cannot read an image or mask file.
    """
    import os

    image_dir = os.path.join(path, 'images')
    mask_dir = os.path.join(path, 'masks')
    dsize = (image_size, image_size)

    images = []
    masks = []
    for i in range(n_images):
        image_file = os.path.join(image_dir, 'image-{}.png'.format(i))
        mask_file = os.path.join(mask_dir, 'mask-{}.png'.format(i))

        bgr_im = cv2.imread(image_file)
        if bgr_im is None:  # imread signals failure by returning None
            raise FileNotFoundError('could not read image file: ' + image_file)
        # cv2.imread returns BGR; matplotlib expects RGB.
        rgb_im = cv2.cvtColor(bgr_im, cv2.COLOR_BGR2RGB)
        images.append(cv2.resize(rgb_im, dsize=dsize))

        bgr_mask = cv2.imread(mask_file)
        if bgr_mask is None:
            raise FileNotFoundError('could not read mask file: ' + mask_file)
        low_mask = cv2.resize(bgr_mask, dsize=dsize)
        # The mask is BGR as well, so use the matching grayscale conversion.
        low_mask = cv2.cvtColor(low_mask, cv2.COLOR_BGR2GRAY)
        masks.append((low_mask > 0).astype(int))

    if data_shape == '3d':
        image_shape = (-1, image_size, image_size, 3)
        mask_shape = (-1, image_size, image_size, 1)
    else:
        image_shape = (-1, image_size * image_size * 3)
        mask_shape = (-1, image_size * image_size)

    images = np.array(images)
    masks = np.array(masks)
    val_end = n_train + n_val

    xdata = np.reshape(images[:n_train], image_shape)
    xval = np.reshape(images[n_train:val_end], image_shape)
    xtest = np.reshape(images[val_end:], image_shape)
    ydata = np.reshape(masks[:n_train], mask_shape)
    yval = np.reshape(masks[n_train:val_end], mask_shape)
    ytest = np.reshape(masks[val_end:], mask_shape)
    return xdata, xval, xtest, ydata, yval, ytest

In [None]:
# Dataset location: change `path` to wherever the horses/ folder lives.
path = './horses/'
image_size = 32
# The Generator/Discriminator defined below concatenate their inputs along
# axis 3, so they need NHWC tensors. With '2d' (flattened rows) the very
# first training step fails inside the concatenation.
data_shape = '3d'
xdata, xval, xtest, ydata, yval, ytest = load_horses_orig(path, image_size, data_shape)

In [None]:
def draw(image, mask, fake):
    """Show an image, its ground-truth mask and a predicted mask side by side.

    Works for both data layouts used in this notebook: flattened rows
    (data_shape == '2d') and (H, W, C) arrays (data_shape == '3d'). Each input
    is reshaped to (image_size, image_size, channels). Single-channel masks
    are squeezed to 2-D so imshow applies the gray colormap.

    Args:
        image: RGB image, image_size * image_size * 3 values.
        mask: ground-truth mask, image_size * image_size values.
        fake: predicted mask (array or eager tensor), same size as mask.
    """
    panels = (
        (image, None),
        (mask, plt.cm.gray),
        (fake, plt.cm.gray),
    )
    fig, axes = plt.subplots(1, 3)
    for ax, (data, cmap) in zip(axes, panels):
        pixels = np.reshape(data, (image_size, image_size, -1))
        if pixels.shape[-1] == 1:
            pixels = pixels[..., 0]
        ax.axis('off')
        ax.imshow(pixels, cmap=cmap)
    plt.show()


draw(xdata[0], ydata[0], ydata[0])
print(xdata[0].shape, ydata[0].shape)

In [None]:
# Do not change this cell
def iou(ytrue, yprediction, data_shape):
    """Intersection-over-union between ground-truth masks and predictions.

    Predictions are binarised with a 0.5 threshold before comparison.

    Args:
        ytrue: ground-truth masks (presumably 0/1 values -- TODO confirm for
            every caller).
        yprediction: predicted masks; entries > 0.5 count as foreground.
        data_shape: '3d' -> a single IoU pooled over every element of the
            inputs; anything else -> IoU per row (sums along axis 1, so 2-D
            (n_samples, n_pixels) inputs are assumed) averaged over rows.

    Returns:
        The IoU as a float. NOTE(review): when a union is 0 (both masks
        empty) the division produces nan.
    """
    yp = yprediction
    yt = ytrue
    # NOTE(review): operator precedence makes this `yp > (0.5 + 0)`, i.e. a
    # boolean array rather than `(yp > 0.5) + 0`. For 0/1 ground truth,
    # np.minimum/np.maximum still yield the same sums.
    yp = yp > 0.5 + 0
    if data_shape == '3d':
        # pooled over the whole input
        intersect = np.sum(np.minimum(yp, yt))
        union = np.sum(np.maximum(yp, yt))
    else:
        # one value per row (axis 1)
        intersect = np.sum(np.minimum(yp, yt),1)
        union = np.sum(np.maximum(yp, yt),1)
    
    return np.average(intersect / (union+0.0))

# sanity check: the training masks compared with themselves give IoU 1
assert iou(ydata, ydata, data_shape) == 1.0

In [None]:
class Discriminator(tf.keras.Model):
    """Critic that scores an (image, true mask, candidate mask) triple.

    The three inputs are cast to float32 and concatenated on axis 3 (NHWC
    tensors expected). They then pass through four 3x3 stride-2 'valid'
    convolutions with softplus activations. The result is flattened and
    mapped by a Dense(1) layer to one unbounded score per example.
    """

    def __init__(self):
        super().__init__()
        conv = tf.keras.layers.Conv2D
        act = tf.nn.softplus

        self.concat_layer = tf.keras.layers.Concatenate(axis=3)

        # NOTE(review): the first convolution has a single filter, squeezing
        # all concatenated channels to one -- confirm this bottleneck is intended.
        self.conv1 = conv(1, 3, 2, activation=act)
        self.conv2 = conv(32, 3, 2, activation=act)
        self.conv3 = conv(64, 3, 2, activation=act)
        self.conv4 = conv(128, 3, 2, activation=act)

        self.flatten = tf.keras.layers.Flatten()
        self.fc_score = tf.keras.layers.Dense(1)

    def call(self, x, yt, y):
        stacked = self.concat_layer([tf.cast(t, tf.float32) for t in (x, yt, y)])
        pipeline = (self.conv1, self.conv2, self.conv3, self.conv4,
                    self.flatten, self.fc_score)
        for layer in pipeline:
            stacked = layer(stacked)
        return stacked

class Generator(tf.keras.Model):
    """Energy network E(x, z): one scalar per example (despite the class name).

    The image x and candidate mask z are concatenated on axis 3 (NHWC inputs
    expected). They pass through four stride-2 Conv2DTranspose layers with
    'same' padding, each doubling height and width (a 32x32 input reaches
    512x512). The result is flattened and reduced by Dense(1) to an energy
    of shape (batch, 1).

    NOTE(review): upsampling 16x inside an energy network is unusual and very
    memory-hungry (the Dense layer sees 512*512 features for 32x32 inputs);
    confirm that downsampling Conv2D layers were not intended.
    """

    def __init__(self):
        super(Generator, self).__init__()

        self.concat_layer = tf.keras.layers.Concatenate(axis=3)
        self.conv1 = tf.keras.layers.Conv2DTranspose(128, 3, 2, activation = tf.nn.softplus, padding = 'same')
        self.conv2 = tf.keras.layers.Conv2DTranspose(64, 3, 2, activation = tf.nn.softplus, padding = 'same')
        self.conv3 = tf.keras.layers.Conv2DTranspose(32, 3, 2, activation = tf.nn.softplus, padding = 'same')
        self.conv4 = tf.keras.layers.Conv2DTranspose(1, 3, 2, activation = tf.nn.softplus, padding = 'same')
        
        self.flatten = tf.keras.layers.Flatten()
        self.fc_energy = tf.keras.layers.Dense(1)
        
    def call(self, x, z):
        # Cast both inputs, as Discriminator.call does, so that a float64 or
        # integer mask can still be concatenated with the float32 image.
        x = tf.cast(x, tf.float32)
        z = tf.cast(z, tf.float32)
        x = self.concat_layer([x, z])
        
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.flatten(x)
        x = self.fc_energy(x)
        return x

In [None]:
# def rank_based_training(self):
#     # TODO: what are value_h and value_l, and how are they calculated?
    
#     self.margin_weight_ph = tf.placeholder(tf.float32, shape=[], name="Margin")
#     self.value_h = tf.placeholder(tf.float32, shape=[None])
#     self.value_l = tf.placeholder(tf.float32, shape=[None])
#     self.yp_h_ind = tf.placeholder(tf.float32,
#                           shape=[None, self.config.output_num * self.config.dimension],
#                           name="YP_H")

#     self.yp_l_ind = tf.placeholder(tf.float32,
#                           shape=[None, self.config.output_num * self.config.dimension],
#                           name="YP_L")

#     self.energy_yh = self.get_energy(xinput=self.x, yinput=self.yp_h_ind, embedding=self.embedding,
#                                      reuse=self.config.pretrain)
#     self.energy_yl = self.get_energy(xinput=self.x, yinput=self.yp_l_ind, embedding=self.embedding,
#                                      reuse=True)


#     self.energy_yp = self.energy_yh
#     self.yp = self.yp_h_ind

#     self.energy_ygradient = tf.gradients(self.energy_yp, self.yp)[0]

#     vloss = 0
#     for v in self.spen_variables():
#         vloss = vloss + tf.nn.l2_loss(v)

#     obj1 = tf.reduce_sum( tf.maximum( (self.value_h - self.value_l)*self.margin_weight_ph - self.energy_yh + self.energy_yl, 0.0))
#     self.vh_sum = tf.reduce_sum (self.value_h)
#     self.vl_sum = tf.reduce_sum (self.value_l)
#     self.eh_sum = tf.reduce_sum(self.energy_yh)
#     self.el_sum = tf.reduce_sum(self.energy_yl)
#     self.objective = obj1 +  self.config.l2_penalty * vloss #+ obj2
#     self.num_update = tf.reduce_sum(tf.cast( (self.value_h - self.value_l)*self.margin_weight_ph  >= (self.energy_yh - self.energy_yl), tf.float32))
#     self.train_step = self.optimizer.minimize(self.objective, var_list=self.spen_variables())
#     return self

# def search_better_y_fast(self, xtest, yprev):
#     final_best = np.zeros((xtest.shape[0], self.config.output_num))
#     for iter in range(np.shape(xtest)[0]):
#       random_proposal = yprev[iter,:]
#       score_first = self.evaluate(np.expand_dims(xtest[iter], 0), np.expand_dims(random_proposal, 0))
#       start = score_first
#       labelset = set(np.arange(self.config.dimension))
#       found = False
#       for l in range(self.config.output_num):
#         for label in (labelset - set([yprev[iter,l]])): #set([random_proposal[l]]):
#           random_proposal_new = random_proposal[:]
#           random_proposal_new[l] = label
#           score = self.evaluate(np.expand_dims(xtest[iter], 0),
#                                 np.expand_dims(random_proposal_new, 0))
#           if score > score_first:
#             score_first = score
#             best_l = l
#             best_label = label
#             found = True
#             #random_proposal[l] = random_proposal_new[l]
#             #changed = True
#             #break
#       if self.config.loglevel > 4:
#         print ("iter:", iter, "found:", found, "score first: ", start, "new score", score_first)
#       final_best[iter, :] = yprev[iter, :]
#       if found:
#         final_best[iter, best_l] = best_label
#     return final_best

In [1]:
# use y_true only at reward not at gradient computation
def reward(yprediction, ytrue):
    """Soft IoU reward: sum(min(yprediction, ytrue)) / sum(max(yprediction, ytrue)).

    Both inputs are cast to float32 first: tf.math.minimum/maximum require
    matching dtypes, and predictions (float) are typically compared with
    integer 0/1 masks.

    Returns:
        Scalar float32 tensor. When the union is 0 (both inputs all zero) the
        prediction matches the truth exactly, so 1.0 is returned instead of
        the nan a plain division would give.
    """
    yprediction = tf.cast(yprediction, tf.float32)
    ytrue = tf.cast(ytrue, tf.float32)
    intersect = tf.math.reduce_sum(tf.math.minimum(yprediction, ytrue))
    union = tf.math.reduce_sum(tf.math.maximum(yprediction, ytrue))
    return tf.where(union > 0,
                    tf.math.divide_no_nan(intersect, union),
                    tf.ones_like(union))

def search_better_y_fast(xtest, yprev, ytrue=None):
    """Greedy single-pixel search for a higher-reward label per example.

    For each example i, every pixel of yprev[i] is tentatively set to each
    binary value it does not already hold, and the change is scored with
    reward(). The single (pixel, value) change with the highest reward that
    beats the unchanged proposal is applied.

    Args:
        xtest: one row per example; its length sets the number of examples.
        yprev: (N, image_size * image_size) current label proposals. It is not
            modified (proposals are copied before being edited).
        ytrue: ground-truth masks aligned with yprev, used as the reward
            reference. If omitted, xtest[i] is used as the reference, as in
            the previous version. Pass the masks explicitly.

    Returns:
        float64 array of shape (N, image_size * image_size).
    """
    reference = xtest if ytrue is None else ytrue
    n_pixels = image_size * image_size
    candidate_labels = (0, 1)  # binary segmentation masks

    final_best = np.zeros((xtest.shape[0], n_pixels))
    for i in range(np.shape(xtest)[0]):
        # np.array(...) copies: slicing a numpy row (`[:]`) gives a view and
        # the old code silently overwrote yprev while searching.
        proposal = np.array(yprev[i, :], dtype=np.float64)
        target = np.asarray(reference[i], dtype=np.float64)

        score_first = reward(proposal, target)
        start = score_first
        found = False
        best_l = best_label = None
        for l in range(n_pixels):
            original_value = proposal[l]
            for label in candidate_labels:
                if label == original_value:
                    continue
                proposal[l] = label
                score = reward(proposal, target)
                # compare every candidate (the old check sat outside this loop
                # and only ever saw the last, possibly undefined, score)
                if score > score_first:
                    score_first = score
                    best_l = l
                    best_label = label
                    found = True
            proposal[l] = original_value  # undo the tentative change

        print ("iter:", i, "found:", found, "score first: ", start, "new score", score_first)
        final_best[i, :] = yprev[i, :]
        if found:
            final_best[i, best_l] = best_label
    return final_best

def rank_based_training(real_image, inf_iter, inf_rate, l2_penalty):
    # Work-in-progress TF2 port of a TF1 (placeholder-based) rank-based
    # training objective: hinge loss
    #   max(0, (value_h - value_l) * margin - E(y_h) + E(y_l))
    # plus an L2 penalty on the energy network's (gen's) weights.
    #
    # NOTE(review): this function cannot run as written:
    #   * value_h, value_l, margin_weight_ph and optimizer are not defined
    #     anywhere in this notebook (as far as visible) -> NameError;
    #   * `return self` refers to a name that does not exist in a plain
    #     function -> NameError;
    #   * in TF2 eager mode, Keras optimizers' minimize() expects a callable
    #     loss (or a tape) rather than an already-computed tensor -- confirm
    #     for the TF version in use.
    batch_size = real_image.get_shape().as_list()[0]
    z = tf.random.uniform([batch_size, image_size, image_size, 1], 0.0, 1.0)
    # NOTE(review): the energies are evaluated on random initialisations z,
    # not on inferred/searched labels y_h, y_l -- TODO confirm intent.
    energy_yh = gen(real_image, z)
    fake_label = unrolled_inf(real_image, z, inf_iter, inf_rate)
    
    z = tf.random.uniform([batch_size, image_size, image_size, 1], 0.0, 1.0)
    energy_yl = gen(real_image, z)

    energy_yp = energy_yh  # unused
    
    # sort the values based on energy
    # margin > 0 always
    # y_h = y_n in paper (change sign)
    # 

    # NOTE(review): `gradients` is computed but never used below.
    gradients = compute_gradient(real_image, fake_label)

    # L2 penalty over all trainable weights of the energy network
    vloss = 0
    for v in gen.trainable_variables:
        vloss = vloss + tf.nn.l2_loss(v)

    # margin-scaled hinge loss on the energy difference
    obj1 = tf.reduce_sum( tf.maximum( (value_h - value_l) * margin_weight_ph - energy_yh + energy_yl, 0.0))
    # diagnostic sums (not returned or logged)
    vh_sum = tf.reduce_sum (value_h)
    vl_sum = tf.reduce_sum (value_l)
    eh_sum = tf.reduce_sum(energy_yh)
    el_sum = tf.reduce_sum(energy_yl)
    objective = obj1 +  l2_penalty * vloss #+ obj2
    # number of pairs whose margin constraint is violated (not returned)
    num_update = tf.reduce_sum(tf.cast( (value_h - value_l)*margin_weight_ph  >= (energy_yh - energy_yl), tf.float32))
    train_step = optimizer.minimize(objective, var_list=gen.trainable_variables)
    return self

In [None]:
def compute_gradient(x, y):
    """Return d gen(x, y) / d y, the gradient of the energy w.r.t. the mask y.

    gen returns one energy per example and its layers act per example, so
    differentiating the (implicitly summed) batch energy gives every example
    its own gradient.

    y is watched as a tensor instead of being copied into a fresh tf.Variable.
    A Variable's initial value is not differentiable, so the old version cut
    the dependency of y on earlier inference steps for any enclosing tape
    (unrolled training). It also allocated a new variable on every call.
    """
    y = tf.cast(y, tf.float32)
    with tf.GradientTape() as tape:
        tape.watch(y)
        energy = gen(x, y)
    return tape.gradient(energy, y)

def _langevin_inference(images, init_label, inf_iter, inf_rate, add_noise):
    """Refine a mask in logit space by following the energy gradient; shared by both inference modes.

    Each of the inf_iter steps computes g = d gen(images, y) / d y and updates
        y <- (y + inf_rate * g) + (inf_rate / 2 * g [+ N(0, inf_rate^2)])
    in the same operation order as before. The noise term is added only when
    add_noise is True. Returns sigmoid of the final iterate, or of the
    initial value when inf_iter == 0 (the old code raised IndexError there).

    NOTE(review): the gradient is applied twice per step (effective step
    1.5 * inf_rate), and textbook Langevin dynamics uses noise with standard
    deviation sqrt(inf_rate) rather than inf_rate -- confirm both are intended.
    """
    yp = tf.cast(init_label, tf.float32)
    for _ in range(inf_iter):
        gradients = compute_gradient(images, yp)

        # plain gradient step on the energy
        stepped = tf.math.add(yp, tf.math.multiply(inf_rate, gradients))

        # Langevin dynamics
        extra = tf.cast(inf_rate/2, tf.float32) * gradients
        if add_noise:
            extra = extra + tf.random.normal(tf.shape(gradients), mean=0, stddev=inf_rate)
        yp = stepped + extra
    return tf.nn.sigmoid(yp)


def unrolled_inf(images, fake_label, inf_iter, inf_rate):
    """Noisy (Langevin-style) inference used during training.

    Args:
        images: input images fed to gen.
        fake_label: initial mask logits, e.g. uniform noise of shape
            (batch, image_size, image_size, 1).
        inf_iter: number of refinement steps.
        inf_rate: step size (also the noise standard deviation).

    Returns:
        Mask probabilities in (0, 1) with the shape of fake_label.
    """
    return _langevin_inference(images, fake_label, inf_iter, inf_rate, add_noise=True)


def test_time_inf(images, fake_label, inf_iter, inf_rate):
    """Deterministic inference (same update as unrolled_inf, without noise).

    Args and return value as for unrolled_inf.
    """
    return _langevin_inference(images, fake_label, inf_iter, inf_rate, add_noise=False)

In [None]:
def plot_figure(loss, arange, brange, start, end, step, plot_name):
    """Render grid-evaluated loss values as a plotly surface with projected contours.

    Args:
        loss: flat loss values (eager tensor or anything np.asarray accepts).
            They are reshaped row-major into rows of len(np.linspace(start,
            end, step)) values: rows follow brange (y axis), columns follow
            arange (x axis).
        arange, brange: grid coordinates along the x and y axes.
        start, end, step: the np.linspace arguments used to build the grid;
            only the resulting number of points is used here.
        plot_name: figure title and z-axis label.
    """
    axis_length = len(np.linspace(start, end, step))
    # np.asarray handles tf eager tensors as well as numpy arrays
    # (the old `.numpy()` call only worked for tensors).
    surface = np.asarray(loss).reshape(-1, axis_length)

    # plot the surface plot with plotly's Surface
    fig = go.Figure(data=go.Surface(z=surface,
                                    x=arange,
                                    y=brange))

    # add a contour plot projected onto the z plane
    fig.update_traces(contours_z=dict(show=True, usecolormap=True,
                                      highlightcolor="limegreen", project_z=True))

    # annotate the plot
    fig.update_layout(title=plot_name,
                      scene=dict(
                        xaxis_title='Pred. Label Axis-0',
                        yaxis_title='Pred. Label Axis-1',
                        zaxis_title=plot_name),
                      width=700, height=700)

    fig.show()
    
def generate_vectors(real_image, real_label, fake_label, start, end, step):
    """Build a 2-D grid of perturbed predictions around the first example.

    For grid coordinates k = arange[a] and l = brange[b] the perturbed
    prediction is clip(fake_label[0] + l * r1 + k * r2, 0, 1). Here r1 and r2
    are random directions drawn from U(0, 1), with component 0 set to
    1 - component 1 as before. Grid point (b, a) is stored at flat index
    b * len(arange) + a, so one loss value per grid point reshapes to
    (len(brange), len(arange)).

    Only the first example of each input is used ("just one horse"). The
    middle axis of the returned arrays therefore has length 1. The previous
    version tiled that single example across every batch slot, which
    multiplied memory by the batch size and made losses reshape to
    (len(brange) * batch, len(arange)) instead of the grid.

    Args:
        real_image: (N, D_img) flattened images.
        real_label: (N, D) flattened ground-truth masks.
        fake_label: (N, D) flattened predicted masks; D sets the mask size.
        start, end, step: np.linspace arguments for both grid axes.

    Returns:
        arange, brange, y_img (G, 1, D_img), y_pred (G, 1, D),
        y_true (G, 1, D), all float64, with G = len(arange) * len(brange).
    """
    arange = np.linspace(start, end, step)
    brange = np.linspace(start, end, step)
    n_points = arange.shape[0] * brange.shape[0]

    real_image = np.asarray(real_image, dtype=np.float64)
    real_label = np.asarray(real_label, dtype=np.float64)
    fake_label = np.asarray(fake_label, dtype=np.float64)
    label_dim = fake_label.shape[1]

    r1 = np.random.uniform(0, 1, label_dim)
    r1[0] = 1 - r1[1]
    r2 = np.random.uniform(0, 1, label_dim)
    r2[0] = 1 - r2[1]

    # (len(brange), len(arange), D): l = brange[b] scales r1, k = arange[a] scales r2
    grid = fake_label[0] + brange[:, None, None] * r1 + arange[None, :, None] * r2
    y_pred = np.clip(grid, 0.0, 1.0).reshape(n_points, 1, label_dim)
    y_true = np.repeat(real_label[0][None, None, :], n_points, axis=0)
    y_img = np.repeat(real_image[0][None, None, :], n_points, axis=0)
    return arange, brange, y_img, y_pred, y_true

def plot_contour_ce(real_image, real_label, inf_iter, inf_rate, start, end, step):
    """Plot binary cross-entropy over a 2-D grid of perturbations of the first test-time prediction.

    Runs noise-free inference from uniform random logits, flattens images and
    masks, builds the grid with generate_vectors and hands the per-point
    cross-entropy to plot_figure.
    """
    n_examples = real_label.shape[0]
    init_logits = tf.random.uniform([n_examples, image_size, image_size, 1], 0.0, 1.0)
    prediction = test_time_inf(real_image, init_logits, inf_iter, inf_rate)

    label_dim = image_size * image_size
    flat_pred = np.array(prediction).reshape(-1, label_dim)
    flat_truth = np.array(real_label).reshape(-1, label_dim)
    flat_images = np.array(real_image).reshape(-1, label_dim * 3)

    grid_a, grid_b, _, y_pred, y_true = generate_vectors(
        flat_images, flat_truth, flat_pred, start, end, step)

    # CE Loss
    ce = tf.keras.losses.binary_crossentropy(y_true, y_pred)
    plot_figure(ce, grid_a, grid_b, start, end, step, 'Cross Entropy')
    
    
def plot_contour_adv(real_image, real_label, inf_iter, inf_rate, start, end, step):
    """Plot the critic's WGAN-GP loss over a 2-D grid of perturbed predictions.

    The surface value at each grid point is D(fake) - D(real) + 10 * GP. GP
    is a single gradient-penalty scalar computed on random interpolations
    across the whole grid, so it shifts the surface uniformly.
    """
    n_examples = real_label.shape[0]
    init_logits = tf.random.uniform([n_examples, image_size, image_size, 1], 0.0, 1.0)
    prediction = test_time_inf(real_image, init_logits, inf_iter, inf_rate)

    label_dim = image_size * image_size
    flat_pred = np.array(prediction).reshape(-1, label_dim)
    flat_truth = np.array(real_label).reshape(-1, label_dim)
    flat_images = np.array(real_image).reshape(-1, label_dim * 3)

    # Adv loss
    grid_a, grid_b, y_img, y_pred, y_true = generate_vectors(
        flat_images, flat_truth, flat_pred, start, end, step)

    grid_images = y_img.reshape(-1, image_size, image_size, 3)
    grid_truth = y_true.reshape(-1, image_size, image_size, 1)
    grid_pred = y_pred.reshape(-1, image_size, image_size, 1)

    # critic signals
    real_score = dis(grid_images, grid_truth, grid_truth)
    fake_score = dis(grid_images, grid_truth, grid_pred)

    # random interpolation between perturbed and true masks
    mix = tf.random.uniform([grid_pred.shape[0], 1, 1, 1], 0.0, 1.0)
    blended = grid_pred * mix + grid_truth * (1 - mix)
    with tf.GradientTape() as penalty_tape:
        penalty_tape.watch(blended)
        blended_score = dis(grid_images, grid_truth, blended)
    penalty_grads = penalty_tape.gradient(blended_score, blended)

    # gradient penalty
    penalty_norm = tf.sqrt(tf.reduce_sum(tf.square(penalty_grads), axis = [1, 2, 3]))
    penalty = tf.reduce_mean((penalty_norm - 1.0) ** 2)

    surface = fake_score - real_score + (penalty * 10)
    plot_figure(surface, grid_a, grid_b, start, end, step, 'Adversarial loss')
    
    
def plot_hist(fake, real):
    """Overlay histograms of critic scores on generated vs. ground-truth masks.

    Args:
        fake, real: score arrays, or lists of per-batch score arrays as
            collected in train(). Batches may differ in size (e.g. a smaller
            last batch): the parts are flattened with np.concatenate, whereas
            np.array on such a ragged list fails in recent NumPy.

    The bin count equals the number of fake scores, as before. Nothing is
    drawn when both inputs are empty.
    """
    def _flatten(scores):
        # one flat float vector, whatever the per-batch shapes were
        parts = [np.ravel(np.asarray(s)) for s in scores]
        return np.concatenate(parts) if parts else np.zeros(0)

    fake = _flatten(fake)
    real = _flatten(real)
    if fake.size == 0 and real.size == 0:
        return

    bins = max(len(fake), 1)
    plt.figure(figsize=(15,5))
    plt.hist(fake, bins, alpha=0.5, color='blue', label='Fake')
    plt.hist(real, bins, alpha=0.5, color='orange', label='Real')
    plt.xlabel('Critic score')
    plt.legend(loc='upper right')
    plt.show()
    
def calc_iou(images, labels, inf_iter, inf_rate):
    """IoU between noise-free predictions for `images` and the ground-truth `labels`.

    Uses the notebook-level `data_shape` (as the rest of the notebook uses the
    global image_size). The old call omitted iou's required data_shape
    argument and raised TypeError.
    """
    batch_size = labels.shape[0]
    z = tf.random.uniform([batch_size, image_size, image_size, 1], 0.0, 1.0)
    fake_label = test_time_inf(images, z, inf_iter, inf_rate)
    # Inference returns (N, H, W, 1); match the labels' layout (flat rows in '2d' mode).
    fake_label = np.reshape(np.asarray(fake_label), np.shape(labels))
    return iou(labels, fake_label, data_shape)

def draw_test(images, labels, inf_iter, inf_rate, name, n_figs=10):
    """Print the IoU on a split and show the first n_figs predictions.

    Args:
        images, labels: one data split (layout given by the global data_shape).
        inf_iter, inf_rate: test-time inference settings.
        name: split name used in the printout.
        n_figs: maximum number of examples to draw (default 10, as before).

    Fixes: the old version called the undefined `draw_all` (NameError) and
    called iou without its required data_shape argument (TypeError).
    """
    batch_size = labels.shape[0]
    z = tf.random.uniform([batch_size, image_size, image_size, 1], 0.0, 1.0)
    fake_label = test_time_inf(images, z, inf_iter, inf_rate)
    # match the labels' layout so iou and draw see aligned arrays
    fake_label = np.reshape(np.asarray(fake_label), np.shape(labels))
    iou_score = iou(labels, fake_label, data_shape)
    print('\n' + name + ' iou_score %:', round(iou_score*100, 2))

    for i in range(min(n_figs, len(images))):
        draw(images[i], labels[i], fake_label[i])
        
def plot_iou(epoch, train_iou, val_iou, test_iou):
    """Plot IoU curves for epochs 0..epoch.

    NOTE: the series are labelled positionally. The second argument is drawn
    as "train", the third as "test" and the fourth as "valid", whatever the
    parameter names suggest (train() passes its test-set scores third and
    its validation-set scores fourth).
    """
    epochs = np.arange(epoch+1)
    series = (
        (train_iou, "train"),
        (val_iou, "test"),
        (test_iou, "valid"),
    )
    for scores, legend_name in series:
        plt.plot(epochs, np.array(scores), label = legend_name)
    plt.xlabel('Epoch')
    plt.ylabel('IOU Score')
    plt.legend()
    plt.show()

In [None]:
def train_step_gen(real_image, real_label, inf_iter, inf_rate, pt_loss, alpha):
    """One adversarial update of the energy network (gen) against the critic.

    WGAN-GP convention: train_step_dis minimises mean(D(fake)) - mean(D(real)),
    so the generator must maximise D(fake), i.e. minimise -mean(D(fake)). The
    previous `+mean(fake_score)` pushed the generator in the same direction
    as the critic.

    `pt_loss * alpha` is added to the logged loss only. pt_loss is computed
    outside this tape (by pretrain_gen), so it contributes no gradient here.
    """
    batch_size = real_image.get_shape().as_list()[0]
    z = tf.random.uniform([batch_size, image_size, image_size, 1], 0.0, 1.0)
    real_label = tf.cast(real_label, dtype=tf.float32)
    with tf.GradientTape() as tape:
        fake_label = unrolled_inf(real_image, z, inf_iter, inf_rate)
        fake_score = dis(real_image, real_label, fake_label)
        adv_term = -tf.reduce_mean(fake_score)
        loss = adv_term + (pt_loss * alpha)
    gradients = tape.gradient(loss, gen.trainable_variables)
    gen_opt.apply_gradients(zip(gradients, gen.trainable_variables))
    gen_loss(loss)

def train_step_dis(real_image, real_label, inf_iter, inf_rate):
    """One WGAN-GP critic update on a batch.

    Loss = mean(D(fake)) - mean(D(real)) + 10 * GP. GP penalises the critic's
    input-gradient norm deviating from 1 on random interpolations between
    true and generated masks.

    Returns:
        (fake_score, real_score): the critic outputs for this batch.
    """
    batch_size = real_image.get_shape().as_list()[0]
    real_label = tf.cast(real_label, dtype=tf.float32)
    z = tf.random.uniform([batch_size, image_size, image_size, 1], 0.0, 1.0)

    # The generated masks depend only on gen, so run the expensive unrolled
    # inference outside the critic's tape. The critic gradients are unchanged,
    # but the tape no longer records the entire inference loop.
    fake_label = tf.stop_gradient(unrolled_inf(real_image, z, inf_iter, inf_rate))

    with tf.GradientTape() as tape:
        real_score = dis(real_image, real_label, real_label)
        fake_score = dis(real_image, real_label, fake_label)

        # random interpolation between generated and true masks
        alpha_ = tf.random.uniform([batch_size, 1, 1, 1], 0.0, 1.0)
        inter_sample = fake_label * alpha_ + real_label * (1 - alpha_)
        with tf.GradientTape() as tape_gp:
            tape_gp.watch(inter_sample)
            inter_score = dis(real_image, real_label, inter_sample)
        gp_gradients = tape_gp.gradient(inter_score, inter_sample)

        # gradient penalty (computed inside `tape` so it is differentiated too)
        gp_gradients_norm = tf.sqrt(tf.reduce_sum(tf.square(gp_gradients), axis = [1, 2, 3]))
        gp = tf.reduce_mean((gp_gradients_norm - 1.0) ** 2)

        # adv loss
        loss = tf.reduce_mean(fake_score) - tf.reduce_mean(real_score) + (gp * 10)

    gradients = tape.gradient(loss, dis.trainable_variables)
    dis_opt.apply_gradients(zip(gradients, dis.trainable_variables))

    dis_loss(loss)
    adv_loss(loss - gp * 10)
    return fake_score, real_score
    
def pretrain_gen(real_image, real_label, inf_iter, inf_rate):
    """Supervised step: fit gen so that unrolled inference reproduces the true masks.

    Minimises the mean binary cross-entropy between the true masks and the
    masks produced by unrolled_inf from uniform random logits, using the
    separate gen_opt_pt optimizer. The value is logged to gen_pretrain_loss.

    Returns:
        The scalar batch loss.
    """
    n_examples = real_image.get_shape().as_list()[0]
    target = tf.cast(real_label, dtype=tf.float32)
    init_logits = tf.random.uniform([n_examples, image_size, image_size, 1], 0.0, 1.0)
    with tf.GradientTape() as tape:
        prediction = unrolled_inf(real_image, init_logits, inf_iter, inf_rate)
        per_pixel = tf.keras.losses.binary_crossentropy(target, prediction)
        loss = tf.reduce_mean(per_pixel)
    variables = gen.trainable_variables
    grads = tape.gradient(loss, variables)
    gen_opt_pt.apply_gradients(zip(grads, variables))
    gen_pretrain_loss(loss)
    return loss

def train(data, n_epoch, n_update_dis, inf_iter, inf_rate, pt_gen, alpha):
    """Alternate generator and critic updates over `data` for n_epoch epochs.

    Epoch 0 trains only the critic. From epoch 1 on, each batch gets an
    optional supervised step (pt_gen), one adversarial generator step and
    n_update_dis critic steps. After every epoch, IoU is measured on the
    train, test and validation splits (globals xdata/xtest/xval). Every 20
    epochs, score histograms, loss surfaces, sample predictions and IoU
    curves are plotted.

    Args:
        data: iterable of (images, labels) batches, e.g. a batched
            tf.data.Dataset. (The old version ignored this argument and read
            the global train_dataset instead.)
        n_epoch: number of epochs.
        n_update_dis: critic updates per batch (0 disables critic training).
        inf_iter, inf_rate: inference settings passed to every step.
        pt_gen: whether to run pretrain_gen before each generator step.
        alpha: weight of the pre-training loss in the logged generator loss.
    """
    train_iou = []; test_iou = []; val_iou = []
    pt_loss = 0  # stays 0 when pt_gen is False
    for epoch in range(n_epoch):
        print('\nepoch:', epoch)

        fake_store = []; real_store = []
        for images, labels in data:

            # train gen (skipped in epoch 0, which trains only the critic)
            if epoch > 0:
                if pt_gen:
                    pt_loss = pretrain_gen(images, labels, inf_iter, inf_rate)
                train_step_gen(images, labels, inf_iter, inf_rate, pt_loss, alpha)

            # train disc
            fake_sc = real_sc = None
            for _ in range(n_update_dis):
                fake_sc, real_sc = train_step_dis(images, labels, inf_iter, inf_rate)
            if fake_sc is not None:  # nothing to record when n_update_dis == 0
                fake_store.append(fake_sc.numpy())
                real_store.append(real_sc.numpy())

        # store iou progression (evaluated in the same order as before)
        train_iou.append(calc_iou(xdata, ydata, inf_iter, inf_rate))
        test_iou.append(calc_iou(xtest, ytest, inf_iter, inf_rate))
        val_iou.append(calc_iou(xval, yval, inf_iter, inf_rate))

        if epoch % 20 == 0:
            if fake_store:
                plot_hist(fake_store, real_store)

            start, end, step = 0.0, 20.0, 10
            plot_contour_ce(xtest, ytest, inf_iter, inf_rate, start, end, step)

            start, end, step = -200.0, 100.0, 10
            plot_contour_adv(xtest, ytest, inf_iter, inf_rate, start, end, step)

            draw_test(xdata, ydata, inf_iter, inf_rate, 'train')
            draw_test(xtest, ytest, inf_iter, inf_rate, 'test')
            draw_test(xval, yval, inf_iter, inf_rate, 'valid')

            # plot_iou labels its 3rd series "test" and its 4th "valid"
            plot_iou(epoch, train_iou, test_iou, val_iou)

        template = 'PT Gen Loss: {}, Gen Loss: {}, Dis Loss: {}, Adv Loss: {}'
        print (template.format(gen_pretrain_loss.result(), gen_loss.result(), dis_loss.result(), adv_loss.result()))
        dis_loss.reset_states()
        adv_loss.reset_states()
        gen_loss.reset_states()
        gen_pretrain_loss.reset_states()

In [None]:
# general
batch_size = 8
n_epoch = 300
n_figs = 10
seed = 42  # one seed for weight init, dataset shuffling and all sampling

# Seed every RNG the notebook draws from: TF ops and initialisers, tf.data
# shuffling (which falls back to the global seed), and NumPy (used by
# generate_vectors) so runs are reproducible.
tf.random.set_seed(seed)
np.random.seed(seed)

# disc
n_update_dis = 5
d_learning_rate = 0.0005

# pretrain gen
alpha = 0.01  # weight of the pre-training loss in the logged generator loss
pt_gen = True
pt_g_learning_rate = 0.001

# ebm gen
inf_iter = 10  # number of inference steps (called "alpha" in the original notes)
inf_rate = 2   # inference step size (called "delta" in the original notes)
g_learning_rate = 0.0005

##############################################################################################################

# create tf dataset generator object
N = len(xdata)
train_dataset = tf.data.Dataset.from_tensor_slices((xdata, ydata))
train_dataset = train_dataset.shuffle(buffer_size=N)
train_dataset = train_dataset.batch(batch_size=batch_size, drop_remainder=False)

# Initialize Networks
gen = Generator()
dis = Discriminator()

# Initialize Optimizer
gen_opt = tf.keras.optimizers.Adam(g_learning_rate)
gen_opt_pt = tf.keras.optimizers.Adam(pt_g_learning_rate)
dis_opt = tf.keras.optimizers.Adam(d_learning_rate)

# Initialize Metrics
adv_loss = tf.keras.metrics.Mean(name = 'Adversarial_Loss')
dis_loss = tf.keras.metrics.Mean(name = 'Discriminator_Loss')
gen_loss = tf.keras.metrics.Mean(name = 'Generator_Loss')
gen_pretrain_loss = tf.keras.metrics.Mean(name = 'Generator_Pretrain_Loss')

train(train_dataset, n_epoch, n_update_dis, inf_iter, inf_rate, pt_gen, alpha)