In [1]:
import optax
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
import jax  # noqa: F401

from jax_src.train import train, make_optimizer
from jax_src.generator import save_net
from jax_src.sst import (
    sst_loss_fixed_wh,  # noqa: F401
    eval_net,  # noqa: F401
    load_sst_net,  # noqa: F401
    sst_chen_consecutive,  # noqa: F401
    wass2_errors_normal,  # noqa: F401
    make_sst_net,  # noqa: F401
)
from jax_src.discriminator import marginal_wass2_error

jax.config.update("jax_enable_x64", True)
dtype = jnp.float64


def plot_losses(_losses_list):
    all_losses = jnp.abs(jnp.concatenate(_losses_list, axis=0))
    # Get rid of spikes
    bound1 = jnp.min(all_losses) + 3 * jnp.std(all_losses)
    losses_pruned = all_losses[all_losses < bound1]
    bound2 = jnp.mean(losses_pruned) + 3 * jnp.std(losses_pruned)
    all_losses = jnp.clip(all_losses, 0.0, bound2)
    plt.plot(all_losses)
    plt.show()

## The Chen-trained Neural net achieves a Wasserstein error of 3.271e-5

In [3]:
# Either make a new net or load a pre-trained one
noise_size = 3
hidden_dim = 16
num_layers = 3
leaky_slope = 0.01
use_multlayer = True
# net = make_sst_net(
#     jr.key(1),
#     noise_size,
#     hidden_dim,
#     num_layers,
#     leaky_slope,
#     use_multlayer,
#     dtype=dtype,
#     use_batch_norm=False
# )
saving = True
net = load_sst_net(
    "/home/andy/PycharmProjects/Levy_CFGAN/numpy_nets/",
    noise_size,
    hidden_dim,
    num_layers,
    leaky_slope,
    use_multlayer,
    dtype,
    use_batch_norm=False,
    use_activation=True,
)
losses_list = []
GLOBAL_KEY = jr.key(7)

# net without training
net_best, _, _ = eval_net(net, jr.key(6), 100, 2**20, -1, True, saving)

Mean error: 0.006662, variance error: 0.0009496, avg var: 0.05002, wass error: 3.271e-05, score: 0.02005
Inital score: 0.02005


## The average Wasserstein error of a normal distribution with the right conditional mean and variance is 7.67e-4

In [4]:
print(wass2_errors_normal(jr.key(0), True))

Error for w=0.4302, hh=0.4289: 0.0004312
Error for w=-2.865, hh=-0.6711: 0.00107
Error for w=-4.189, hh=-0.5475: 0.001146
Error for w=-0.9949, hh=0.9425: 0.0005802
Error for w=-0.4197, hh=0.2238: 0.0004544
Error for w=-0.4521, hh=-0.0696: 0.000491
Error for w=-1.14, hh=0.006232: 0.0009128
Error for w=-1.985, hh=-0.1498: 0.001075
Error for w=-0.3316, hh=0.3142: 0.0003915
Error for w=-1.356, hh=-0.2088: 0.0009616
Error for w=-1.034, hh=0.422: 0.0007774
Error for w=-3.419, hh=0.0187: 0.001137
Error for w=1.29, hh=1.287: 0.0006244
Error for w=0.3729, hh=-0.5267: 0.000386
Error for w=-0.6617, hh=-0.04992: 0.000645
Error for w=-3.103, hh=1.266: 0.001033
Error for w=-0.9551, hh=-0.2237: 0.0008003
Error for w=1.119, hh=-1.58: 0.0005552
Error for w=-1.396, hh=-0.1825: 0.0009771
Error for w=-1.259, hh=0.6715: 0.0007793
0.0007614567421490231


## BB loss

In [3]:
num_reps = 100
num_steps = 2**13
# Set the learning rate schedule
# schedule = optax.cosine_decay_schedule(1e-3, num_steps * num_reps, 1e-4)
schedule = optax.constant_schedule(2e-4)
opt, opt_state = make_optimizer(net, None, schedule, beta1=0.95, beta2=0.995)

for i in range(num_reps):
    GLOBAL_KEY, temp_key = jr.split(GLOBAL_KEY, 2)
    net, discr, opt_state, losses = train(
        net,
        None,
        temp_key,
        num_steps,
        opt,
        opt_state,
        sst_loss_fixed_wh,
        1.0,
        1,
    )
    avg_loss = jnp.mean(jnp.abs(losses))
    print(f"======== Finished rep {i+1}/{num_reps} ======== mean loss: {avg_loss:.4}")
    net_best, _, _ = eval_net(net, jr.key(6), 0, 2**20, net_best, True, saving)
    # plot_losses(losses_list)
    print("\n")

Mean error: 0.0, variance error: 0.0, avg var: 0.0, wass error: 0.0001992, score: 0.09959

Mean error: 0.0, variance error: 0.0, avg var: 0.0, wass error: 0.000109, score: 0.05449

Mean error: 0.0, variance error: 0.0, avg var: 0.0, wass error: 0.0001144, score: 0.05721

Mean error: 0.0, variance error: 0.0, avg var: 0.0, wass error: 9.377e-05, score: 0.04688
New best net with score 0.04688

Mean error: 0.0, variance error: 0.0, avg var: 0.0, wass error: 0.0001263, score: 0.06313

Mean error: 0.0, variance error: 0.0, avg var: 0.0, wass error: 5.947e-05, score: 0.02973
New best net with score 0.02973

Mean error: 0.0, variance error: 0.0, avg var: 0.0, wass error: 0.0001215, score: 0.06074

Mean error: 0.0, variance error: 0.0, avg var: 0.0, wass error: 0.000181, score: 0.09052

Mean error: 0.0, variance error: 0.0, avg var: 0.0, wass error: 8.531e-05, score: 0.04265

Mean error: 0.0, variance error: 0.0, avg var: 0.0, wass error: 0.0001045, score: 0.05226

Mean error: 0.0, variance er

KeyboardInterrupt: 

In [6]:
save_net(net, "/home/andy/PycharmProjects/Levy_CFGAN/numpy_nets/sst_")

## Check the SST Chen's relation is correct

In [2]:
import math

c = jnp.zeros((2**22,), dtype=dtype)
w = jr.normal(jr.PRNGKey(0), (2**22,), dtype=dtype)
hh = math.sqrt(1 / 12) * jr.normal(jr.PRNGKey(0), (2**22,), dtype=dtype)
with open("sst_saved_values/uncond/sst_unconditional.npy", "rb") as f:
    c_true = jnp.load(f)[: 2**20]

wass_dist = marginal_wass2_error(c, c_true)
print(f"num samples: 2^{math.log2(c.shape[0])}, wass dist: {wass_dist}")

for i in range(10):
    w, hh, c = sst_chen_consecutive(w, hh, c)

    wass_dist = marginal_wass2_error(c, c_true)
    print(
        f"num samples: 2^{math.log2(c.shape[0])}, wass dist: {wass_dist:.5}, mean: {jnp.mean(c):.4}, var: {jnp.var(c):.4}"
    )

num samples: 2^22.0, wass dist: 0.5846481323242188
num samples: 2^21.0, wass dist: 0.072193, mean: 0.2506, var: 0.2814
num samples: 2^20.0, wass dist: 0.016293, mean: 0.3757, var: 0.3622
num samples: 2^19.0, wass dist: 0.0045235, mean: 0.4382, var: 0.3652
num samples: 2^18.0, wass dist: 0.0012815, mean: 0.4692, var: 0.3554
num samples: 2^17.0, wass dist: 0.00043373, mean: 0.4852, var: 0.3489
num samples: 2^16.0, wass dist: 0.00012418, mean: 0.494, var: 0.3433
num samples: 2^15.0, wass dist: 0.0001847, mean: 0.4974, var: 0.3314
num samples: 2^14.0, wass dist: 0.00036022, mean: 0.4971, var: 0.3255
num samples: 2^13.0, wass dist: 0.00075478, mean: 0.4918, var: 0.3181
num samples: 2^12.0, wass dist: 0.0018489, mean: 0.4901, var: 0.3049
