# <tt>tf_gan</tt>: A short implementation of Generative Adversarial Networks (GANs) in TensorFlow

Imports

In [None]:
import tensorflow as tf
from tf_gan import GAN

Prepare the dataset and the networks.

<tt>generate_latent_variables</tt> should be a function that takes an <tt>int</tt> argument <tt>batch_size</tt> and returns a batch of <tt>batch_size</tt> latent vectors.

By convention, the last layer of <tt>disc_network</tt> should be <tt>tf.keras.layers.Dense(1)</tt> with no sigmoid activation, i.e. it outputs an unnormalized score (logit) rather than a probability.

Optional arguments: <tt>gen_opt</tt> and <tt>disc_opt</tt> set the optimizers of the generator and of the discriminator, respectively.

In [None]:
# Configuration: everything a reader might want to tune lives here.
SEED = 0            # global random seed, so Restart & Run All is reproducible
N_SAMPLES = 20000   # number of points in the toy training set
BATCH_SIZE = 256    # mini-batch size used for training
dim = 5             # dimensionality of each data point
latent_dim = 10     # dimensionality of the vectors drawn by generate_latent_variables

# Seeds Python's `random`, NumPy and TensorFlow in one call. The toy data,
# the Keras weight initialisation and the latent draws are all random, so
# without this every run of the notebook produces different results.
tf.keras.utils.set_random_seed(SEED)

# Toy target distribution: N_SAMPLES i.i.d. standard-normal points in R^dim,
# batched and prefetched so the input pipeline overlaps with training.
dataset = (
    tf.data.Dataset.from_tensor_slices(tf.random.normal((N_SAMPLES, dim)))
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

# Generator network: two ReLU hidden layers (32 then 64 units), followed by a
# linear layer with `dim` outputs so generated points match the data dimension.
gen_network = tf.keras.Sequential()
gen_network.add(tf.keras.layers.Dense(units=32, activation='relu'))
gen_network.add(tf.keras.layers.Dense(units=64, activation='relu'))
gen_network.add(tf.keras.layers.Dense(units=dim))  # linear: unbounded real outputs

# Discriminator network: two ReLU hidden layers (32 then 16 units) and a single
# linear output unit. Per the convention above, there is no sigmoid at the end.
disc_network = tf.keras.Sequential(
    [
        tf.keras.layers.Dense(32, activation='relu'),
        tf.keras.layers.Dense(16, activation='relu'),
        tf.keras.layers.Dense(1),  # unbounded score, no activation
    ]
)

def generate_latent_variables(batch_size):
    """Draw a batch of latent vectors.

    Args:
        batch_size: number of latent vectors to draw.

    Returns:
        A float32 tensor of shape ``(batch_size, latent_dim)`` whose entries are
        i.i.d. standard-normal samples (``tf.random.normal`` defaults).
    """
    latent_shape = (batch_size, latent_dim)
    return tf.random.normal(latent_shape)


# Wire the generator, the discriminator and the latent sampler into a GAN.
# gen_opt / disc_opt are not passed, so whatever defaults GAN defines apply.
gan = GAN(gen_network=gen_network,
          disc_network=disc_network,
          generate_latent_variables=generate_latent_variables)

Train the GAN on <tt>dataset</tt>.

In [None]:
# Train for 20 epochs on `dataset`.
gan.train(epochs=20, dataset=dataset)

Generate samples from the trained GAN.

In [None]:
# Ask the trained GAN for a batch of 10 samples; the bare last line below
# makes Jupyter display them as the cell output.
samples = gan.generate_samples(batch_size=10)
samples