In [1]:
import sys
sys.path.append("../..")
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow.keras import layers

from transport_nets.bijectors import BananaMap, BananaFlow
tfd = tfp.distributions
tfb = tfp.bijectors

In [2]:
def make_disc_model(nn_list):
    model = tf.keras.Sequential()
    model.add(layers.Dense(nn_list[0]))
    model.add(layers.LeakyReLU())
    model.add(layers.BatchNormalization())
    model.add(layers.Dense(nn_list[1]))
    model.add(layers.LeakyReLU())
    model.add(layers.BatchNormalization())
    model.add(layers.Dense(nn_list[2]))
    model.add(layers.LeakyReLU())
    model.add(layers.BatchNormalization())
    model.add(layers.Dense(1,activation='sigmoid'))
    
    return model

def make_gen_model(nn_list,output_dim):
    model = tf.keras.Sequential()
    model.add(layers.Dense(nn_list[0]))
    model.add(layers.LeakyReLU())
    model.add(layers.BatchNormalization())
    model.add(layers.Dense(nn_list[1]))
    model.add(layers.LeakyReLU())
    model.add(layers.BatchNormalization())
    model.add(layers.Dense(nn_list[2]))
    model.add(layers.LeakyReLU())
    model.add(layers.BatchNormalization())
    model.add(layers.Dense(output_dim))

    return model

class T_gen(tf.keras.Model):
    
    def __init__(self,n,m,K_nn_list,F_nn_list):
        super(T_gen,self).__init__(name='T')
        self.n = n
        self.m = m
        self.K = make_gen_model(K_nn_list,n)
        self.F = make_gen_model(F_nn_list,m)
        
    def call(self, inputs):
        x = inputs[...,:self.n]
        y = inputs[...,self.n:]
        T1 = self.K(x)
        T2 = self.F(tf.concat([T1,y],axis=-1))
        
        return tf.concat([T1,T2],axis = -1)

In [3]:
# Using Adam Optimizer for minimizing the Generator and Discriminator loss
lr = 1e-5
BATCH_SIZE = 100
BUFFER_SIZE = 10000
n = 1
m = 1
latent_dim = n+m
lamda = 0.01 # monoticity lagrange multiplier 
disc_nn_list = [200,500,100]
K_nn_list = [100,200,100]
F_nn_list = [200,500,100]
N = 5000
N_epochs = 500


params = (0.5,0.1,0.05,0.0) #(a1,a2,a3,theta)
bMap = BananaMap(params)
bFlow = BananaFlow(bMap)

XT = bFlow.sample(N)
x = XT[...,1:]
y = XT[...,:1]
train_dataset_tensor = tf.concat([x,y],axis=-1)
train_dataset = tf.data.Dataset.from_tensor_slices(train_dataset_tensor).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

f = make_disc_model(disc_nn_list)
T = T_gen(n,m,K_nn_list,F_nn_list)

gen_opt = tf.keras.optimizers.Adam(learning_rate=lr, epsilon=1e-8)
disc_opt = tf.keras.optimizers.Adam(learning_rate=lr, epsilon=1e-8)

def gen_loss():
    z1 = tf.random.normal([BATCH_SIZE,latent_dim])
    z2 = tf.random.normal([BATCH_SIZE,latent_dim])
    T1 = T(z1)
    T2 = T(z2)
    g_loss_fake = tf.reduce_mean(tf.math.log(f(T1)),axis=0)
    m_loss = lamda*tf.reduce_mean(tf.reduce_sum((T1-T2)*(z1-z2),axis=1))
    g_loss = g_loss_fake + m_loss
    return -g_loss

def disc_loss(x):
    z = tf.random.normal([BATCH_SIZE,latent_dim])
    d_loss_real = tf.reduce_mean(tf.math.log(f(x)),axis=0)
    d_loss_fake = tf.reduce_mean(tf.math.log(1.0-f(T(z))),axis=0)
    d_loss = d_loss_real+d_loss_fake
    return -d_loss 
    

# input x is minibatch of data points
@tf.function
def train_step(x): 
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        g_loss = gen_loss()
        d_loss = disc_loss(x)
        
    gen_grad = gen_tape.gradient(g_loss, T.trainable_variables)
    disc_grad = disc_tape.gradient(d_loss, f.trainable_variables)
    
    gen_opt.apply_gradients(zip(gen_grad,T.trainable_variables))
    disc_opt.apply_gradients(zip(disc_grad,f.trainable_variables))
    
    return d_loss,g_loss

def train(dataset, epochs):
    for epoch in range(epochs):
        for data_batch in dataset:
            d_loss,g_loss = train_step(data_batch)
        if epoch % 10 == 0: 
            print('it:',epoch,"D loss:",d_loss.numpy(),"G loss:",g_loss.numpy())
            
train(train_dataset,N_epochs)

it: 0 D loss: [1.3806748] G loss: [0.6815028]
it: 10 D loss: [1.3881774] G loss: [0.68130857]
it: 20 D loss: [1.3768718] G loss: [0.7296488]
it: 30 D loss: [1.3941169] G loss: [0.6778769]
it: 40 D loss: [1.3665285] G loss: [0.6577923]
it: 50 D loss: [1.3927486] G loss: [0.6815985]
it: 60 D loss: [1.3810556] G loss: [0.71837795]
it: 70 D loss: [1.3692669] G loss: [0.7322038]
it: 80 D loss: [1.3926544] G loss: [0.69731385]
it: 90 D loss: [1.3827499] G loss: [0.69271773]
it: 100 D loss: [1.3835373] G loss: [0.6985604]
it: 110 D loss: [1.3901974] G loss: [0.68363625]
it: 120 D loss: [1.38146] G loss: [0.69118667]
it: 130 D loss: [1.3858037] G loss: [0.696331]
it: 140 D loss: [1.3853552] G loss: [0.68967885]
it: 150 D loss: [1.384937] G loss: [0.68377745]
it: 160 D loss: [1.3874123] G loss: [0.679668]
it: 170 D loss: [1.3865869] G loss: [0.6736726]
it: 180 D loss: [1.3860171] G loss: [0.6796594]
it: 190 D loss: [1.3813926] G loss: [0.69246066]
it: 200 D loss: [1.373996] G loss: [0.7000928]


In [None]:
import seaborn as sns
def plot_density(data,axis):
    x, y = np.squeeze(np.split(data, 2, axis=1))
    return sns.kdeplot(x, y, cmap="viridis", shade=True, 
                     shade_lowest=True, ax=axis)
xa,xb,ya,yb = (-1.5,1.5,-0.2,0.6)
fig,ax = plt.subplots(1,2,figsize=(10,4))
X = bFlow.sample(5000)
X_p = T(tf.random.normal([5000,2]))
flip = tfb.Permute([1,0])
X_flipped = flip.forward(X_p)
l1 = plot_density(X,axis=ax[0])
l2 = plot_density(X_flipped,axis=ax[1])
ax[0].set(xlim=(xa,xb),ylim=(ya,yb))
ax[1].set(xlim=(xa,xb),ylim=(ya,yb))

ax[0].set_title('true map')
ax[1].set_title('MGAN map')
plt.savefig('MGAN_kde.png')

In [None]:
fig, ax = plt.subplots(2,2,figsize=(8,6))
x_obs_list = [0.5,0.2,0.0]
ax[0,0].scatter(X[:,0],X[:,1],alpha=0.2,label='true data')
ax[0,0].scatter(X_flipped[:,0],X_flipped[:,1],alpha=0.2,label='MGAN samples')
ax[0,0].plot(np.linspace(-2,2,100),x_obs_list[0]*np.ones(100),'r--')
ax[0,0].plot(np.linspace(-2,2,100),x_obs_list[1]*np.ones(100),'g--')
ax[0,0].plot(np.linspace(-2,2,100),x_obs_list[2]*np.ones(100),'--',c='purple')
ax[0,0].set(xlabel='y',ylabel='x')
ax[0,0].legend()

Ns = 2000
u = tf.random.normal([Ns,1])

x_obs1 = x_obs_list[0]*tf.ones([Ns,1])
x_obs2 = x_obs_list[1]*tf.ones([Ns,1])
x_obs3 = x_obs_list[2]*tf.ones([Ns,1])
y_sample1 = T.F(tf.concat([x_obs1,u],axis=-1))[...,0]
y_sample2 = T.F(tf.concat([x_obs2,u],axis=-1))[...,0]
y_sample3 = T.F(tf.concat([x_obs3,u],axis=-1))[...,0]
ax[0,1].hist(y_sample1,40,color='r',density=True)
ax[1,0].hist(y_sample2,40,color='g',density=True)
ax[1,1].hist(y_sample3,40,color='purple',density=True)
plt.tight_layout()
plt.savefig('MGAN_conditional_samples.png')

In [6]:
T.summary()

Model: "T"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
sequential_1 (Sequential)    multiple                  42201     
_________________________________________________________________
sequential_2 (Sequential)    multiple                  154501    
Total params: 196,702
Trainable params: 194,302
Non-trainable params: 2,400
_________________________________________________________________
