In [1]:
import jax
print(jax.__version__) 
import jax.numpy as jnp  # JAX NumPy
import optax  # Optimizers
import equinox as eqx  # Equinox

# NF Model
# from flowMC.resource.nf_model.flowjax_model import FlowJAXNormalizingFlow
from flowMC.resource.nf_model.flowjax_wrapper import NFModelFlowJAX

# Data and plotting
from sklearn.datasets import make_moons
import matplotlib.pyplot as plt
import jax.random as jr

# For evaluation
from scipy.special import kl_div

# Load data. random_state pins the dataset so runs are reproducible, in line
# with the fixed PRNG seed (jax.random.key(0)) used for the model below.
data = jnp.array(make_moons(n_samples=100000, noise=0.05, random_state=0)[0])

# Model parameters
n_feature = data.shape[1]  # dimensionality of each sample; drives the flow's `dim`
n_layers = 10
# NOTE(review): n_hidden is never passed to NFModelFlowJAX below; confirm
# whether the wrapper accepts a hidden-width argument, otherwise drop it.
n_hidden = 100

# `subkey` initialises the model; `key` is kept for the training call.
key, subkey = jax.random.split(jax.random.key(0), 2)

model = NFModelFlowJAX(
    dim=n_feature,  # derived from the data rather than hard-coded
    key=subkey,
    n_layers=n_layers,
    flow_type="RationalQuadraticCoupling",
)

# Adam optimiser; its state is initialised from the model's array leaves only
# (eqx.filter with eqx.is_array drops non-array fields).
optimizer = optax.adam(1e-3)
state = optimizer.init(eqx.filter(model, eqx.is_array))

# Train model. `key` is consumed by train(); the first returned value is
# presumably the advanced PRNG key (flowMC convention) — confirm against
# NFModelFlowJAX.train if its return order changes.
final_key, trained_model, _, losses = model.train(
    key,
    data,
    optimizer,
    state,
    num_epochs=100,
    batch_size=128,
)

# Sample from trained model. Draw the sampling key from `final_key` instead of
# re-deriving one from `key`, which was already handed to train() — reusing a
# consumed key risks correlating sampling randomness with training randomness.
final_key, sample_key = jax.random.split(final_key)
samples = trained_model.sample(sample_key, 10000)  # Generate more samples for analysis

# Side-by-side scatter plots: training data on the left, flow samples on the
# right; the figure is written to disk and then displayed.
fig, (ax_true, ax_generated) = plt.subplots(1, 2, figsize=(12, 6))

ax_true.scatter(data[:, 0], data[:, 1], s=1, alpha=0.5, label='True Data')
ax_true.set_title('True Data (Make Moons)')
ax_true.legend()

ax_generated.scatter(samples[:, 0], samples[:, 1], s=1, alpha=0.5, label='Generated Samples')
ax_generated.set_title('NF Generated Samples')
ax_generated.legend()

plt.savefig("nf_moons_comparison.png")
plt.show()

# Compute KL divergence (approximate)
# Note: Since we don't have the true PDF of 'make_moons', this is a simplified comparison.
# Here we compare histograms as discrete distributions.

def compute_hist_kl(data1, data2, bins=50, eps=1e-8):
    """Approximate KL(P1 || P2) between two sample sets via shared histograms.

    Both sample sets are binned on the same grid (fitted to ``data1``), the
    bin counts are converted to probability masses, and the discrete KL
    divergence ``sum(p * log(p / q))`` is returned.

    Args:
        data1: Samples from the reference distribution P1, shape ``(N1, D)``.
        data2: Samples from the compared distribution P2, shape ``(N2, D)``.
            Samples falling outside ``data1``'s bin range are not counted.
        bins: Number of bins per dimension, or any ``bins`` value accepted by
            ``jnp.histogramdd``.
        eps: Small probability mass added to every bin before renormalising,
            so empty bins never produce ``log(0)`` or a division by zero.

    Returns:
        Scalar array with the approximate KL divergence (non-negative).
    """

    def to_probs(counts):
        # Counts -> probability masses; max(total, 1) keeps an all-empty
        # histogram at zeros (then uniform after smoothing) instead of NaN.
        masses = counts / jnp.maximum(counts.sum(), 1)
        masses = masses + eps
        return masses / masses.sum()

    counts1, edges = jnp.histogramdd(data1, bins=bins)
    # Reuse data1's edges so both histograms share exactly the same bins.
    counts2, _ = jnp.histogramdd(data2, bins=edges)

    # Work with probability masses, not densities: density=True divides each
    # bin by its volume, which scales the summed KL by 1/volume and makes the
    # result depend on the data's units and the chosen bin size.
    p = to_probs(counts1)
    q = to_probs(counts2)
    # p and q each sum to 1, so kl_div's extra (-x + y) terms cancel and the
    # sum equals sum(p * log(p / q)).
    return jnp.sum(kl_div(p, q))

# Score the generated samples against the training data and report it.
kl_value = compute_hist_kl(data, samples)
print(
    "Approximate KL Divergence between true data and generated samples: "
    "{:.4f}".format(kl_value)
)

# Optional persistence: the raw data and samples as .npy arrays, and the
# trained flow's leaves via equinox serialisation.
jnp.save("true_data.npy", data)
jnp.save("generated_samples.npy", samples)
eqx.tree_serialise_leaves("trained_nf_model.eqx", trained_model)

ImportError: DLL load failed while importing utils: 动态链接库(DLL)初始化例程失败。 (English: A dynamic link library (DLL) initialization routine failed.)