In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from library import datasets, models
from flax import linen as nn
from jax import random, numpy as jnp
import optax
import pandas as pd
import plotly.express as px
import plotly.io as pio
pio.renderers.default = 'notebook_connected'

In [3]:
# Problem dimensions shared by the rest of the notebook:
#   LATENT_DIM  - passed as `latent_dim` to the point-pair distribution and as
#                 the GAN's `latent_shape`.
#   AMBIENT_DIM - passed as `dis_A_dim` / `dis_B_dim`, as the GAN's
#                 `ambient_shape`, and as the width of the generator's last layer.
LATENT_DIM, AMBIENT_DIM = 1, 2

In [4]:
# Build a randomly parameterised point-pair distribution (fixed PRNG seed so
# the same distribution is produced on every run) and sample a dataset from it.
# NOTE(review): the exact meaning of `latent_range`, `max_order` and
# `coeff_range` is defined in `library.datasets.point_dataset` - confirm there.
taylor_pair_dis = datasets.point_dataset.PointPairDistribution.random_taylor(
    # Dimensionality of the latent variable and of the two point sets.
    latent_dim=LATENT_DIM,
    dis_A_dim=AMBIENT_DIM,
    dis_B_dim=AMBIENT_DIM,
    # Parameters of the random map (presumably a Taylor expansion - verify).
    latent_range=1,
    max_order=2,
    coeff_range=1,
    # Only side A gets a non-zero noise std; side B is passed 0.
    noise_std_A=0.05,
    noise_std_B=0,
    key=random.PRNGKey(2),
)

# 1000 is the requested dataset size.
dataset = taylor_pair_dis.generate_dataset(1_000)

In [5]:
# Keep only the first of the two arrays returned by the dataset; the second is
# discarded. `A` is the training data for the GAN below.
# NOTE(review): presumably these are the samples from distribution A and the
# discarded value is the paired B samples - confirm against
# `get_all_point_pairs`.
A, _ = dataset.get_all_point_pairs()

In [70]:
def _mlp(hidden_widths, output_width):
    """Build a ReLU multi-layer perceptron as an `nn.Sequential`.

    Each entry of `hidden_widths` contributes an `nn.Dense(width)` followed by
    `nn.relu`; a final `nn.Dense(output_width)` with no activation is appended.
    """
    layers = []
    for width in hidden_widths:
        layers.extend((nn.Dense(width), nn.relu))
    layers.append(nn.Dense(output_width))
    return nn.Sequential(layers)


# Discriminator: hidden widths 8-16-16-8, then a single un-activated output
# unit (a raw score, not a probability).
discriminator = _mlp(hidden_widths=(8, 16, 16, 8), output_width=1)

# Generator: hidden widths 8-8-4, then one output per ambient coordinate.
generator = _mlp(hidden_widths=(8, 8, 4), output_width=AMBIENT_DIM)

model = models.vanilla_gan.VanillaGAN(
    generator=generator,
    discriminator=discriminator,
    latent_shape=(LATENT_DIM,),
    ambient_shape=(AMBIENT_DIM,),
)

print(model)

# `optax.sigmoid_binary_cross_entropy` takes logits, which matches the
# discriminator ending in a bare Dense(1).
model.initialize(loss_fn=optax.sigmoid_binary_cross_entropy)

generator:
	Model Structure:
		Sequential(
		    # attributes
		    layers = [Dense(
		        # attributes
		        features = 8
		        use_bias = True
		        dtype = None
		        param_dtype = float32
		        precision = None
		        kernel_init = init
		        bias_init = zeros
		    ), <jax._src.custom_derivatives.custom_jvp object at 0x7f4baae31480>, Dense(
		        # attributes
		        features = 8
		        use_bias = True
		        dtype = None
		        param_dtype = float32
		        precision = None
		        kernel_init = init
		        bias_init = zeros
		    ), <jax._src.custom_derivatives.custom_jvp object at 0x7f4baae31480>, Dense(
		        # attributes
		        features = 4
		        use_bias = True
		        dtype = None
		        param_dtype = float32
		        precision = None
		        kernel_init = init
		        bias_init = zeros
		    ), <jax._src.custom_derivatives.custom_jvp object at 0x7f4baae31480>, Dense(
		        # attributes
		        

In [86]:
# Train the GAN on the A samples with Adam.
# `batch_size` equals the number of rows in A (1000, see the shape printed
# further down), so presumably every epoch is one full-batch step - confirm
# against `VanillaGAN.train`.
model.train(
    A,
    optax.adam(learning_rate=0.01),
    num_epochs=250,
    batch_size=1000,
    print_every=5,
)

iteration 0; gen_loss:  1.16e+00; dis_loss:  3.69e-01; gen_grads_magnitude:  1.71e+01; dis_grads_magnitude:  1.24e+00: : 1it [00:00,  2.23it/s]
iteration 0; gen_loss:  5.04e+00; dis_loss:  2.66e-01; gen_grads_magnitude:  2.39e+01; dis_grads_magnitude:  8.16e-02: : 1it [00:00,  7.37it/s]
iteration 0; gen_loss:  8.63e+00; dis_loss:  3.37e-01; gen_grads_magnitude:  1.59e+01; dis_grads_magnitude:  9.27e-02: : 1it [00:00,  7.89it/s]
iteration 0; gen_loss:  1.01e+01; dis_loss:  3.83e-01; gen_grads_magnitude:  1.55e+01; dis_grads_magnitude:  8.29e-02: : 1it [00:00,  8.13it/s]
iteration 0; gen_loss:  1.09e+01; dis_loss:  4.14e-01; gen_grads_magnitude:  1.52e+01; dis_grads_magnitude:  8.70e-02: : 1it [00:00,  7.77it/s]
iteration 0; gen_loss:  1.11e+01; dis_loss:  4.38e-01; gen_grads_magnitude:  1.50e+01; dis_grads_magnitude:  9.74e-02: : 1it [00:00,  7.00it/s]
iteration 0; gen_loss:  1.06e+01; dis_loss:  4.60e-01; gen_grads_magnitude:  1.59e+01; dis_grads_magnitude:  1.10e-01: : 1it [00:00,  8.

In [87]:
# Obtain a distribution object from the trained model; it is used below via
# `draw_samples` to turn latent inputs into generated points.
gan_distribution = model.create_distribution()

In [88]:
# Alias the training data under a descriptive name for the comparison with the
# generated samples (no copy is made).
true_samples = A

In [89]:
# Generate 1000 points by feeding latent inputs drawn uniformly from [-1, 1)
# (fixed PRNG seed) to the GAN distribution.
# Alternative kept for reference - standard-normal latents instead:
#   random.normal(random.PRNGKey(0), (1000, LATENT_DIM))
fake_samples = gan_distribution.draw_samples(
    random.uniform(
        random.PRNGKey(0),
        (1000, LATENT_DIM),
        minval=-1,
        maxval=1,
    )
)

In [90]:
# Sanity check: both arrays must have the same number of columns, because the
# next cell concatenates their first and second columns for plotting.
print(true_samples.shape, fake_samples.shape)

(1000, 2) (1000, 2)


In [91]:
# Column data for the scatter plot: true samples first, generated samples
# second, with a per-point class label.
# The labels are strings rather than 1.0 / 0.0 so that Plotly Express treats
# the colour column as categorical (two discrete colours and a legend) instead
# of drawing a continuous colour bar for a binary class.
df = {
    'x': jnp.concatenate((true_samples[:, 0], fake_samples[:, 0]), axis=0),
    'y': jnp.concatenate((true_samples[:, 1], fake_samples[:, 1]), axis=0),
    'labels': ['true'] * true_samples.shape[0] + ['fake'] * fake_samples.shape[0],
}

In [92]:
# Overlay true and generated points in one scatter plot, coloured by the
# 'labels' column of `df`. A title is set so the figure can be read on its own.
px.scatter(
    df,
    x='x',
    y='y',
    color='labels',
    title='True vs. GAN-generated samples',
)