In [119]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [120]:
# Presumably project-local helpers: dataset generators and GAN model wrappers.
from library import datasets, models
# Third-party: flax (NN layers), jax (PRNG keys + numpy API), optax (losses/optimizers),
# pandas (tabular data) and plotly (interactive plotting).
from flax import linen as nn
from jax import random, numpy as jnp
import optax
import pandas as pd
import plotly.express as px
import plotly.io as pio
# Use plotly's 'notebook_connected' renderer for every figure in this session.
pio.renderers.default = 'notebook_connected'

In [121]:
# Problem size: a 1-D latent variable mapped into a 2-D ambient space
# (2-D so real and generated points can be shown on one scatter plot).
LATENT_DIM, AMBIENT_DIM = 1, 2

In [122]:
# Ground-truth pair distribution drawn at random (max polynomial order 2, coefficients
# within coeff_range) from a fixed PRNG key, so the same distribution is built on every run.
# Side A gets noise with std 0.05; side B is noise-free. Only side A is used below.
taylor_pair_dis = datasets.point_dataset.PointPairDistribution.random_taylor(
    key=random.PRNGKey(2),
    latent_dim=LATENT_DIM,
    latent_range=1,
    max_order=2,
    coeff_range=1,
    dis_A_dim=AMBIENT_DIM,
    noise_std_A=0.05,
    dis_B_dim=AMBIENT_DIM,
    noise_std_B=0,
)

# 1000 point pairs sampled from the distribution above.
dataset = taylor_pair_dis.generate_dataset(1000)

In [123]:
# Keep side A of every pair as the GAN's training data; side B is not used here.
# The unused half gets a named throwaway rather than `_`: in IPython, assigning to `_`
# shadows the automatic "last output" variable for the rest of the session.
A, _unused_B = dataset.get_all_point_pairs()

In [124]:
def _relu_mlp(hidden_widths, out_features):
    """Return an ``nn.Sequential`` of ``Dense(w), relu`` per hidden width, then a linear ``Dense(out_features)``."""
    stack = []
    for width in hidden_widths:
        stack.extend((nn.Dense(width), nn.relu))
    stack.append(nn.Dense(out_features))
    return nn.Sequential(stack)


# Discriminator: a single unnormalised output per point (presumably a logit, since the
# loss passed to `initialize` below is sigmoid binary cross-entropy).
discriminator = _relu_mlp(hidden_widths=(8, 4), out_features=1)

# Generator: maps latents to AMBIENT_DIM-dimensional points with a linear output layer.
generator = _relu_mlp(hidden_widths=(8, 8, 4), out_features=AMBIENT_DIM)

model = models.vanilla_gan.VanillaGAN(
    generator=generator,
    discriminator=discriminator,
    latent_shape=(LATENT_DIM,),
    ambient_shape=(AMBIENT_DIM,),
)

print(model)
model.initialize(loss_fn=optax.sigmoid_binary_cross_entropy)

generator:
	Model Structure:
		Sequential(
		    # attributes
		    layers = [Dense(
		        # attributes
		        features = 8
		        use_bias = True
		        dtype = None
		        param_dtype = float32
		        precision = None
		        kernel_init = init
		        bias_init = zeros
		    ), <jax._src.custom_derivatives.custom_jvp object at 0x7f3099d79300>, Dense(
		        # attributes
		        features = 8
		        use_bias = True
		        dtype = None
		        param_dtype = float32
		        precision = None
		        kernel_init = init
		        bias_init = zeros
		    ), <jax._src.custom_derivatives.custom_jvp object at 0x7f3099d79300>, Dense(
		        # attributes
		        features = 4
		        use_bias = True
		        dtype = None
		        param_dtype = float32
		        precision = None
		        kernel_init = init
		        bias_init = zeros
		    ), <jax._src.custom_derivatives.custom_jvp object at 0x7f3099d79300>, Dense(
		        # attributes
		        

In [125]:
# Train on side A with plain SGD (lr 1e-2): 10 epochs, mini-batches of 64,
# print_every=5 presumably sets the logging interval.
# NOTE(review): the logged gen_loss goes negative, which a plain sigmoid BCE cannot do —
# confirm how VanillaGAN builds the generator objective from loss_fn.
optimizer = optax.sgd(learning_rate=1e-2)
model.train(
    A,
    optimizer,
    num_epochs=10,
    batch_size=64,
    print_every=5,
)

iteration 14; gen_loss: 0.015899652615189552; dis_loss: 0.7523893117904663: : 15it [00:02,  5.58it/s]
iteration 14; gen_loss: 0.007629755884408951; dis_loss: 0.7268475294113159: : 15it [00:02,  6.68it/s]
iteration 14; gen_loss: 0.008827363140881062; dis_loss: 0.7318089008331299: : 15it [00:02,  6.58it/s]
iteration 14; gen_loss: -9.351741755381227e-05; dis_loss: 0.7347143888473511: : 15it [00:02,  6.61it/s]
iteration 14; gen_loss: -0.007829639129340649; dis_loss: 0.734379768371582: : 15it [00:02,  6.74it/s]
iteration 14; gen_loss: -0.011103011667728424; dis_loss: 0.7301446199417114: : 15it [00:02,  6.87it/s]
iteration 14; gen_loss: -0.015775766223669052; dis_loss: 0.7165679931640625: : 15it [00:02,  6.99it/s]
iteration 14; gen_loss: -0.017182042822241783; dis_loss: 0.7139186859130859: : 15it [00:02,  6.87it/s]
iteration 14; gen_loss: -0.019007131457328796; dis_loss: 0.7167685031890869: : 15it [00:02,  6.76it/s]
iteration 14; gen_loss: -0.020778566598892212; dis_loss: 0.7118358016014099:

In [126]:
# Build a sampleable distribution from the trained model; it is sampled further down
# by passing latent vectors to `draw_samples`.
gan_distribution = model.create_distribution()

In [127]:
# The real points (side A, the same array the GAN was trained on), aliased for the comparison below.
true_samples = A

In [128]:
# Generate exactly as many points as there are real ones, so the comparison plot stays
# balanced even if the dataset size above changes (previously hard-coded to 1000).
# NOTE(review): latents are drawn from a standard normal with a fixed key; this assumes
# VanillaGAN also trains on N(0, I) latents — confirm, otherwise these samples are off-distribution.
num_fake = true_samples.shape[0]
fake_samples = gan_distribution.draw_samples(
    random.normal(random.PRNGKey(0), (num_fake, LATENT_DIM)))

In [129]:
# Sanity check before plotting: both sets must be 2-D arrays with the same per-point shape
# and at least two coordinates, because the scatter plot reads columns 0 and 1 of each.
assert true_samples.ndim == 2 and true_samples.shape[1] >= 2, true_samples.shape
assert true_samples.shape[1:] == fake_samples.shape[1:], (true_samples.shape, fake_samples.shape)
print(true_samples.shape, fake_samples.shape)

(1000, 2) (1000, 2)


In [130]:
# Long-format table of real + generated points for one colour-coded scatter plot.
# String labels make plotly treat `labels` as a discrete category (two legend entries);
# the previous numeric 1.0/0.0 labels triggered a continuous colour scale instead.
points = jnp.concatenate((true_samples[:, :2], fake_samples[:, :2]), axis=0)
df = pd.DataFrame({
    'x': points[:, 0],
    'y': points[:, 1],
    'labels': ['true'] * true_samples.shape[0] + ['generated'] * fake_samples.shape[0],
})

In [131]:
# Overlay real and GAN-generated points, coloured by origin; the title lets the figure
# stand on its own when the notebook is skimmed.
px.scatter(df, x='x', y='y', color='labels', title='Real vs. GAN-generated samples')