In [1]:
from typing import Optional, Tuple

import diffrax
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import ott
import seaborn as sns

from entot.data.data import MixtureNormalSampler
from entot.models.model import OTFlowMatching
from entot.nets.nets import MLP_vector_field, Bridge_MLP_mean, MLP_marginal
from entot.plotting.plots import plot_1D_unbalanced

- the fwd and bwd functions take an extra `perturbed` argument, which     indicates which primals actually need a gradient. You can use this     to skip computing the gradient for any unperturbed value. (You can     also safely just ignore this if you wish.)
- `None` was previously passed to indicate a symbolic zero gradient for     all objects that weren't inexact arrays, but all inexact arrays     always had an array-valued gradient. Now, `None` may also be passed     to indicate that an inexact array has a symbolic zero gradient.
  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Samplers for the two distributions, each seeded with its own PRNG key.
# NOTE(review): the positional arguments are presumably (key, component
# centres, dimension, standard deviation) — confirm against
# MixtureNormalSampler's signature.
source = MixtureNormalSampler(
    jax.random.PRNGKey(0),
    [0],
    1,
    0.5,
    batch_size=1024,
)
target = MixtureNormalSampler(
    jax.random.PRNGKey(1),
    [-1.0, 1.0, 5.0],
    1,
    0.1,
    batch_size=1024,
)

2023-08-11 17:56:12.235295: W external/xla/xla/service/gpu/nvptx_compiler.cc:698] The NVIDIA driver's CUDA version is 12.1 which is older than the ptxas CUDA version (12.2.128). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [3]:
# Hyper-parameters shared by the OTFlowMatching constructor and the final
# plot_1D_unbalanced call.
# NOTE(review): epsilon is presumably the entropic regularisation strength and
# tau_a / tau_b the unbalancedness parameters of the two marginals — confirm.
epsilon = 0.1
tau_a = tau_b = 0.98

In [4]:
# Networks handed to the model: the vector field, the bridge network and the
# two marginal networks (eta / xi).
neural_net = MLP_vector_field(1, 128, 128, 128, n_frequencies=10)
bridge_net = Bridge_MLP_mean(1, 64, 64)

# Two separate Sinkhorn solver instances, one per solver argument of the model.
ot_solver = ott.solvers.linear.sinkhorn.Sinkhorn()
solver_latent_to_data = ott.solvers.linear.sinkhorn.Sinkhorn()

mlp_eta = MLP_marginal(1, 128)
mlp_xi = MLP_marginal(1, 128)

# Assemble the model; one keyword per line so individual settings are easy to
# spot and diff.
otfm = OTFlowMatching(
    neural_net,
    bridge_net=bridge_net,
    ot_solver=ot_solver,
    epsilon=epsilon,
    mlp_eta=mlp_eta,
    mlp_xi=mlp_xi,
    tau_a=tau_a,
    tau_b=tau_b,
    input_dim=1,
    output_dim=1,
    iterations=20_000,
    k_noise_per_x=10,
    solver_latent_to_data=solver_latent_to_data,
)

In [5]:
# Shorten the run: this overrides the `iterations=20_000` value that was passed
# to the OTFlowMatching constructor.
otfm.iterations=2_000
# Run the model on the two samplers.
# NOTE(review): the two 1024 arguments presumably are the source/target batch
# sizes — confirm against OTFlowMatching.__call__.
# FIXME: the recorded output of this cell ends in
# "ValueError: vmap got inconsistent sizes" (key batch of 10240 vs. an `x`
# argument of shape (1024, 10, 1)), so the cells below cannot run after a
# fresh Restart & Run All until that is resolved.
otfm(source, target, 1024, 1024)

  0%|          | 0/2000 [00:00<?, ?it/s]

a shape 0 is 1024 
(1024,)
sum of k samples per x 10
shape of k samples per x (1024,)
shape of tmat augmented (10240, 1024)
shape of a argument (1024,)


  0%|          | 0/2000 [00:10<?, ?it/s]

target batch shape (10240, 1, 1)
source batch shape (10240, 1)
noise batch shape (1024, 10, 1)





ValueError: vmap got inconsistent sizes for array axes to be mapped:
  * most axes (2 of them) had size 10240, e.g. axis 0 of argument key of type uint32[10240,2];
  * one axis had size 1024: axis 0 of argument x of type float32[1024,10,1]

In [None]:
# Draw one batch of 1024 source samples and push it through the model.
source.batch_size = 1024
source_batch = next(source)

# transport returns three values; only the first one is used below.
res, _, _ = otfm.transport(source_batch, seed=0)

In [None]:
# Kernel-density estimate of the first slice of the transport output.
# sns.kdeplot returns the Axes it drew on, which is used to title and label the
# figure so it can be read on its own; the trailing semicolon keeps the cell
# from echoing a repr.
ax = sns.kdeplot(res[0, ...])
ax.set(title="KDE of otfm.transport output (res[0, ...])", xlabel="transported value");

In [None]:
# Axis ranges reused across the panels of the final figure.
a = (-2.5, 2.5)
b = (0.0, 2.0)
c = (0.0, 4.0)  # defined but not referenced by any entry below
d = (-2.5, 7.0)

# Per-panel axis limits forwarded to plot_1D_unbalanced as keyword arguments.
# NOTE(review): the two-digit prefix presumably encodes <row><column> of the
# panel grid — confirm against plot_1D_unbalanced.
kwargs = {
    # x-axis limits
    "00_xlim": a,
    "01_xlim": a,
    "02_xlim": a,
    "03_xlim": b,
    "04_xlim": b,
    "10_xlim": a,
    "11_xlim": a,
    "12_xlim": b,
    "13_xlim": b,
    "14_xlim": b,
    # y-axis limits
    "00_ylim": b,
    "01_ylim": b,
    "02_ylim": b,
    "03_ylim": d,
    "04_ylim": d,
    "10_ylim": d,
    "11_ylim": d,
    "12_ylim": d,
    "13_ylim": d,
    "14_ylim": d,
}

# Fresh evaluation batches of 1024 samples from each sampler.
source.batch_size = 1024
source_batch = next(source)

target.batch_size = 1024
target_batch = next(target)

# Evaluate the two marginal networks held by the model (eta on the source
# batch, xi on the target batch) with their current parameters.
rescale_source = otfm.state_eta.apply_fn(
    {"params": otfm.state_eta.params},
    source_batch,
)
rescale_target = otfm.state_xi.apply_fn(
    {"params": otfm.state_xi.params},
    target_batch,
)

# Transport the source batch; max_steps is raised for the ODE solve.
res, _, _ = otfm.transport(
    source_batch,
    seed=1,
    diffeqsolve_kwargs={"max_steps": 1_000_000},
)

# Summary figure; the return value is assigned to `_` to suppress the echo.
_ = plot_1D_unbalanced(
    source_batch,
    target_batch,
    res[0, ...],
    rescale_source,
    rescale_target,
    epsilon=epsilon,
    tau_a=tau_a,
    tau_b=tau_b,
    seed=1,
    **kwargs,
)

In [None]:


def sample_conditional_indices_from_tmap(
    key: jax.Array, tmat: jnp.ndarray, k_samples_per_x: Optional[int]
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Sample column indices for every row of a 2-D matrix.

    For each row ``i`` of ``tmat``, ``k_samples_per_x`` column indices are drawn
    with ``jax.random.choice`` using that row as the probability vector ``p``.
    Every row is sampled with its own sub-key derived from ``key``; re-using
    one key for all rows would make every row consume the same random draws,
    so the per-row samples would not be independent.

    Args:
        key: JAX PRNG key; it is split into one sub-key per row.
        tmat: 2-D array of shape ``(n_rows, n_cols)``; row ``i`` is passed as
            the sampling probabilities over the ``n_cols`` columns.
        k_samples_per_x: Number of column indices drawn per row. Although the
            annotation allows ``None``, the value is used directly as a shape
            and as a repeat count, so an integer is required in practice.

    Returns:
        A tuple ``(row_indices, column_indices)``:

        * ``row_indices`` — flat array of shape ``(n_rows * k_samples_per_x,)``
          in which each row index is repeated ``k_samples_per_x`` times.
        * ``column_indices`` — array of shape ``(n_rows, k_samples_per_x)``
          with the sampled column indices, each in ``[0, n_cols)``.

        NOTE(review): the two outputs have different ranks; callers that pair
        them element-wise presumably flatten the second one — confirm.
    """
    n_rows, n_cols = tmat.shape
    column_ids = jnp.arange(n_cols)
    # One independent sub-key per row.
    row_keys = jax.random.split(key, n_rows)

    def _sample_row(row_key, row_probs):
        # Draw k column indices for a single row.
        return jax.random.choice(key=row_key, a=column_ids, p=row_probs, shape=(k_samples_per_x,))

    indices_per_row = jax.vmap(_sample_row, in_axes=(0, 0), out_axes=0)(row_keys, tmat)
    # The sampled indices already lie in [0, n_cols), so no modulo is needed.
    return jnp.repeat(jnp.arange(n_rows), k_samples_per_x), indices_per_row