In [3]:
import math
import optax
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
import torch
import jax  # noqa: F401

from jax_src.train import train, la_loss, bb_loss, bb_loss_fixed_wh, make_optimizer  # noqa: F401
from jax_src.generator import Net, make_net, save_net, load_net  # noqa: F401
from jax_src.discriminator import (
    UCFDiscriminator,
    init_transform,
    WassersteinDiscriminator,
)  # noqa: F401
from jax_src.evaluation import evaluate_net, evaluate_fosters_method  # noqa: F401

# Enable 64-bit (double-precision) types in JAX; without this flag JAX defaults to 32-bit.
jax.config.update("jax_enable_x64", True)
# Dimension (`bm_dim`) used to select the reference moments file and passed to the
# evaluation helpers in the cells below.
test_bm_dim = 6


def load_moments(
    bm_dim, moments_dir="/home/andy/PycharmProjects/Levy_CFGAN/moments"
):
    """Load the precomputed moments tensor for one dimension as a NumPy array.

    Args:
        bm_dim: Dimension used to select the file ``dim_{bm_dim}_moments.pt``.
        moments_dir: Directory that contains the moment files. Defaults to the
            previously hard-coded location, so existing ``load_moments(dim)``
            calls behave exactly as before.

    Returns:
        The stored tensor converted to a ``numpy.ndarray``.

    Raises:
        FileNotFoundError: Propagated from ``torch.load`` when the file is missing.
    """
    file_path = f"{moments_dir}/dim_{bm_dim}_moments.pt"
    # NOTE(security): torch.load unpickles the file; only load moment files that
    # come from a trusted source.
    # map_location="cpu" lets a file saved from a GPU session load on a CPU-only machine.
    moments_loaded = torch.load(file_path, map_location="cpu")
    # detach() so a tensor that was saved with requires_grad=True can still be
    # converted (Tensor.numpy() refuses tensors that require grad).
    return moments_loaded.detach().numpy()


# Reference moments for the evaluation dimension; passed to eval_and_save and
# evaluate_fosters_method in the cells below.
true_4moms = load_moments(test_bm_dim)


def eval_and_save(
    current_best,
    net,
    key,
    true_4moms,
    num_samples,
    bm_dim,
    saving=False,
    save_dir="/home/andy/PycharmProjects/Levy_CFGAN/numpy_nets/",
):
    """Evaluate ``net``, report its score and optionally save it on improvement.

    The score is ``sqrt(a**2 + b**2)`` where ``a = eval_results[0][1]`` and
    ``b = eval_results[5]`` of the ``evaluate_net`` result (named as the max
    4th-moment error and the Wasserstein-2 error here; the layout of
    ``eval_results`` is defined by ``evaluate_net`` — TODO confirm). Lower is better.

    Args:
        current_best: Best score so far, or ``-1`` to mark the very first
            evaluation (no previous score to compare against).
        net: Network to evaluate; passed through to ``evaluate_net``/``save_net``.
        key: Random key forwarded to ``evaluate_net``.
        true_4moms: Reference moments forwarded to ``evaluate_net``.
        num_samples: Number of samples forwarded to ``evaluate_net``.
        bm_dim: Dimension forwarded to ``evaluate_net``.
        saving: When true, an improved net is written with ``save_net``.
        save_dir: Directory handed to ``save_net``. Defaults to the previously
            hard-coded location, so existing calls are unaffected.

    Returns:
        The new best score: ``score`` on the initial call or on improvement,
        otherwise ``current_best`` unchanged.
    """
    eval_results = evaluate_net(net, key, true_4moms, num_samples, bm_dim)
    max_4mom_err = eval_results[0][1]
    wass2_err = eval_results[5]
    score = math.sqrt(max_4mom_err**2 + wass2_err**2)

    # Handle the sentinel explicitly. The score is a square root, hence never
    # below -1, so this is the same outcome the original comparison order gave:
    # the initial evaluation is only reported, never saved.
    if current_best == -1:
        print(f"Initial score: {score:.4}")
        return score

    if score < current_best:
        print(f"New best score: {score:.4}")
        if saving:
            save_net(net, save_dir)
        return score

    return current_best


def plot_losses(_losses_list):
    """Plot the absolute losses of all training runs with large spikes clipped.

    Args:
        _losses_list: List of loss arrays that are concatenated along axis 0.
            An empty list is reported and skipped instead of raising.

    The clipping bound is computed in two passes: first drop everything above
    ``min + 3 * std`` of all losses, then clip the full series to
    ``mean + 3 * std`` of the remaining values.
    """
    if len(_losses_list) == 0:
        # jnp.concatenate raises on an empty list; there is nothing to draw anyway.
        print("No losses recorded yet; nothing to plot.")
        return

    all_losses = jnp.abs(jnp.concatenate(_losses_list, axis=0))
    # Get rid of spikes
    bound1 = jnp.min(all_losses) + 3 * jnp.std(all_losses)
    # `<=` (not `<`) keeps at least the minimum: with constant losses std == 0,
    # so a strict comparison would drop every sample and make bound2 NaN.
    losses_pruned = all_losses[all_losses <= bound1]
    bound2 = jnp.mean(losses_pruned) + 3 * jnp.std(losses_pruned)
    all_losses = jnp.clip(all_losses, 0.0, bound2)

    plt.plot(all_losses)
    plt.show()

In [2]:
# Either make a new net or load a pre-trained one
# Hyper-parameters shared by make_net and the (commented-out) load_net call below,
# so a saved net is rebuilt with the same settings it was created with.
noise_size = 4
hidden_dim = 16
num_layers = 3
leaky_slope = 0.01
use_multlayer = True
# Fresh net initialised from a fixed key (7), so the starting weights are the same
# on every run of this cell.
net = make_net(
    jr.key(7),
    noise_size,
    hidden_dim,
    num_layers,
    leaky_slope,
    use_multlayer,
    dtype=jnp.complex64,
    use_batch_norm=False,
    use_activation=True,
)
# When True, eval_and_save writes the net to disk whenever the score improves.
saving = True
# Alternative to make_net: uncomment to resume from the net stored on disk.
# net = load_net(
#     "/home/andy/PycharmProjects/Levy_CFGAN/numpy_nets/",
#     noise_size,
#     hidden_dim,
#     num_layers,
#     leaky_slope,
#     use_multlayer,
#     jnp.complex64,
#     use_batch_norm=False,
#     use_activation=True,
# )
# Loss arrays from every training run are appended here and plotted by plot_losses.
losses_list = []
# Master random key; later cells split it to obtain per-step keys.
GLOBAL_KEY = jr.key(3)
# net without training
# Passing -1 marks this as the initial evaluation: eval_and_save prints and returns
# the score but does not save the net, even though `saving` is True.
# NOTE: despite its name, net_best holds the best *score* returned by
# eval_and_save, not a network.
net_best = eval_and_save(-1, net, jr.key(6), true_4moms, 2**20, test_bm_dim, saving)

Wass2: 0.084; MOMS: 4max: 0.0227, 4avg: 0.000886, 3max: 0.00231, 2max: 0.0415, 1max: 0.00097, 0max: 0.307
Initial score: 0.08699


## BB loss

In [3]:
# Dimension the discriminators are built for (read as a global inside get_discr).
# NOTE(review): this differs from test_bm_dim (6) used for evaluation — confirm
# that training and evaluating at different dimensions is intended.
bm_dim_train = 4
# NOTE(review): in the visible cells temp_key from this split is overwritten in the
# training loop before it is ever used. The split still advances GLOBAL_KEY, so
# removing it would change the random stream of everything that follows.
GLOBAL_KEY, temp_key = jr.split(GLOBAL_KEY, 2)


def get_discr(global_key, use_ucfd=True):
    """Build a discriminator for ``bm_dim_train`` and return it with the next key.

    Args:
        global_key: Random key. It is split only when a UCF discriminator is built.
        use_ucfd: ``True`` builds a ``UCFDiscriminator``; ``False`` builds a
            ``WassersteinDiscriminator``.

    Returns:
        ``(discriminator, key)``. For the Wasserstein case ``key`` is the
        ``global_key`` that was passed in, untouched.
    """
    if not use_ucfd:
        # No random initialisation happens here, so the key is handed back as-is.
        return WassersteinDiscriminator(bm_dim_train), global_key

    global_key, init_key = jr.split(global_key, 2)
    # 64 different tensors, each for a 3x3 Lie algebra.
    num_tensors, lie_dim = 64, 3
    lie_transform = init_transform(
        init_key, bm_dim_train, num_tensors, lie_dim, jnp.complex64
    )
    return UCFDiscriminator(lie_transform, bm_dim_train), global_key


# Build one discriminator of each kind. The Wasserstein call returns the key
# unchanged; the UCF call splits GLOBAL_KEY once to initialise its transform.
wass_discr, GLOBAL_KEY = get_discr(GLOBAL_KEY, use_ucfd=False)
ucf_discr, GLOBAL_KEY = get_discr(GLOBAL_KEY, use_ucfd=True)

In [5]:
# Training schedule: num_reps repetitions, each running num_steps_wass steps against
# the Wasserstein discriminator (the UCF phase below is currently commented out).
num_reps = 20
num_steps_wass = 2**13
num_steps_ucf = 128
# NOTE(review): total_steps still counts the UCF steps although the UCF phase is
# disabled; it only feeds schedule_ucf, which the active loop does not use.
total_steps = num_reps * (num_steps_wass + num_steps_ucf)
# Set the learning rate schedule
# schedule_wass = optax.cosine_decay_schedule(1e-4, total_steps, 1e-8)
schedule_wass = optax.constant_schedule(1e-5)
schedule_ucf = optax.cosine_decay_schedule(1e-5, total_steps, 1e-7)
# schedule = optax.constant_schedule(1e-4)
# One optimizer (and optimizer state) per discriminator, each with its own schedule
# and beta1/beta2 values.
opt_wass, opt_state_wass = make_optimizer(
    net, wass_discr, schedule_wass, beta1=0.95, beta2=0.995
)
# NOTE(review): opt_ucf / opt_state_ucf are only referenced by the commented-out
# UCF phase below.
opt_ucf, opt_state_ucf = make_optimizer(
    net, ucf_discr, schedule_ucf, beta1=0.7, beta2=0.97
)

for i in range(num_reps):
    # Re-initialize the discriminator every few iterations
    # num_discr_iters = 3
    # lr_ratio = 5.0
    # if (i + 0) % 1000 == 0:
    #     print("Re-initializing the discriminator")
    #     ucf_discr, GLOBAL_KEY = get_discr(GLOBAL_KEY, use_ucfd=True)
    #     num_discr_iters += 7
    #     lr_ratio *= 2.0

    # Fresh key for this repetition; GLOBAL_KEY is advanced so keys are never reused.
    GLOBAL_KEY, temp_key = jr.split(GLOBAL_KEY, 2)
    # net, discriminator and optimizer state are rebound to the values train returns,
    # so training continues from the previous repetition (and from previous runs of
    # this cell).
    # The trailing positional arguments 1.0 and 1 sit where the commented-out UCF
    # call passes lr_ratio and num_discr_iters — presumably the same parameters;
    # confirm against train's signature.
    net, wass_discr, opt_state_wass, losses = train(
        net,
        wass_discr,
        temp_key,
        num_steps_wass,
        opt_wass,
        opt_state_wass,
        bb_loss_fixed_wh,
        1.0,
        1,
    )
    avg_loss = float(jnp.mean(jnp.abs(losses)))
    print(f"======== Wasserstein rep {i+1}/{num_reps} ======== avg_loss: {avg_loss:.4}")
    losses_list.append(losses)
    # The same evaluation key (6) is passed on every repetition.
    # NOTE(review): this uses 2**19 samples while the initial score in net_best was
    # computed with 2**20 — confirm the scores are meant to be compared directly.
    net_best = eval_and_save(
        net_best, net, jr.key(6), true_4moms, 2**19, test_bm_dim, saving
    )

    # Disabled UCF phase. NOTE(review): it references lr_ratio and num_discr_iters,
    # which are only defined in the commented-out lines at the top of the loop, so
    # uncommenting this block alone would raise a NameError.
    # print(f"======== UCF rep {i+1}/{num_reps} ========")
    # GLOBAL_KEY, temp_key = jr.split(GLOBAL_KEY, 2)
    # net, ucf_discr, opt_state_ucf, losses = train(
    #     net,
    #     ucf_discr,
    #     temp_key,
    #     num_steps_ucf,
    #     opt_ucf,
    #     opt_state_ucf,
    #     bb_loss_fixed_wh,
    #     lr_ratio,
    #     num_discr_iters,
    # )
    # losses_list.append(losses)
    # net_best = eval_and_save(
    #     net_best, net, jr.key(6), true_4moms, 2**20, test_bm_dim, saving
    # )

    # plot_losses(losses_list)
    print("\n")

Wass2: 0.097; MOMS: 4max: 0.0279, 4avg: 0.000585, 3max: 0.00161, 2max: 0.0174, 1max: 0.00151, 0max: 0.343

Wass2: 0.0915; MOMS: 4max: 0.0362, 4avg: 0.000643, 3max: 0.00144, 2max: 0.0101, 1max: 0.00145, 0max: 0.351

Wass2: 0.0865; MOMS: 4max: 0.0421, 4avg: 0.000681, 3max: 0.00133, 2max: 0.00427, 1max: 0.00123, 0max: 0.357


KeyboardInterrupt: 

In [None]:
# Plot every loss array collected in losses_list by the training loop so far.
plot_losses(losses_list)

In [4]:
# Evaluate the reference method with the same key, moments, sample count and
# dimension used for the net's initial evaluation. Assigning to `_` keeps the
# return value from being echoed as cell output.
# NOTE(review): presumably Foster's method is the baseline the net is compared
# against — confirm.
_ = evaluate_fosters_method(jr.key(6), true_4moms, 2**20, test_bm_dim)

Wass2: 0.00229; MOMS: 4max: 0.00486, 4avg: 0.000214, 3max: 0.00128, 2max: 0.00118, 1max: 0.00108, 0max: 0.00111
