In [None]:
# Load IPython's autoreload extension and switch it to mode 2, so that edits
# to imported modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import jax
# Set JAX's `jax_platform_name` option to 'cpu'; this runs before any of the
# array-producing cells further down.
jax.config.update('jax_platform_name', 'cpu')

In [None]:
from entot.data.data import create_gaussians, create_gaussian_split
from entot.plotting.plots import plot_1D
from entot.models.models import NoiseOutsourcingModel
from entot.models.utils import MixtureNormalSampler
import matplotlib.pyplot as plt
import numpy as np
import jax.numpy as jnp

In [None]:
# Build two MixtureNormalSampler instances, one used for the source samples
# and one for the target samples.
# NOTE(review): the positional arguments look like (centers, dim, batch_size,
# variance) — confirm against entot.models.utils.MixtureNormalSampler.
sampler_source = MixtureNormalSampler([0], 1, 100, 0.5)
sampler_target = MixtureNormalSampler([-1.0,1.0], 1, 100, 0.1)

In [None]:
# Draw the source batch from a fixed PRNG seed.
source_key = jax.random.PRNGKey(0)
source = sampler_source(source_key)

In [None]:
# Display the shape of the sampled source batch.
source.shape

In [None]:
# Sample the target batch with a seed different from the PRNGKey(0) used for
# `source`: JAX PRNG keys are meant to be consumed once, and handing the same
# key to both samplers can make the two draws statistically dependent.
target = sampler_target(jax.random.PRNGKey(1))

In [None]:
# Reshape `target` (add a leading axis, transpose, append a trailing axis) and
# perturb it with small Gaussian noise of shape (100, 1, 10); the sum relies
# on broadcasting.
# The noise gets its own seed instead of PRNGKey(0), which is already consumed
# when sampling `source`; reusing a JAX key produces correlated draws.
noise_key = jax.random.PRNGKey(2)
noise = 0.01 * jax.random.normal(noise_key, shape=(100, 1, 10))
target_expanded = jnp.expand_dims(jnp.transpose(jnp.expand_dims(target, 0)), axis=-1)
T_xz = target_expanded + noise

# Dataset 1

In [None]:
# Replace `source` and `target` with the point clouds returned by
# create_gaussians.
# NOTE(review): the two positional 100s are presumably the sample counts for
# source and target — confirm against entot.data.data.create_gaussians.
source, target = create_gaussians(
    100,
    100,
    var_source=0.2,
    var_target=0.2,
)

In [None]:
# Scatter the first two coordinates of both point clouds on one explicit,
# labelled axes so the colours can be told apart when skimming the notebook.
fig, ax = plt.subplots()
ax.scatter(source[:, 0], source[:, 1], color="blue", label="source")
ax.scatter(target[:, 0], target[:, 1], color="red", label="target")
ax.set(xlabel="dim 0", ylabel="dim 1", title="source vs. target")
# Trailing semicolon suppresses the Legend repr in the cell output.
ax.legend();

In [None]:
# Instantiate the model with input_dim=2 and noise_dim=2.
# NOTE(review): the three leading positional values (0.01, 512, 512) are not
# named here — check NoiseOutsourcingModel's signature for their meaning.
sm = NoiseOutsourcingModel(
    0.01,
    512,
    512,
    iterations=500,
    inner_iterations=10,
    input_dim=2,
    noise_dim=2,
)


In [None]:
# Call the model on the (source, target) samples.
# NOTE(review): presumably this runs the training loop — confirm in
# NoiseOutsourcingModel.__call__.
sm(source, target)

In [None]:
# Display the metrics stored on the model (the cells below read its "t_obj"
# and "phi_obj" entries).
sm.metrics

In [None]:
# Apply `sm.transport` to the source points and overlay the result on the
# source/target clouds, with a legend so the three series are identifiable.
transported = sm.transport(source)

fig, ax = plt.subplots()
ax.scatter(source[:, 0], source[:, 1], color="blue", label="source")
ax.scatter(target[:, 0], target[:, 1], color="red", label="target")
ax.scatter(
    transported[:, 0],
    transported[:, 1],
    color="green",
    marker="P",
    label="transported",
)
ax.set(xlabel="dim 0", ylabel="dim 1", title="transported source points")
# Trailing semicolon suppresses the Legend repr in the cell output.
ax.legend();

In [None]:
# Call `sm.sample` on the first source point only; `source[0, None]` selects
# row 0 while keeping a leading axis. The result is overlaid on the
# source/target clouds with a legend so the three series are identifiable.
sampled = sm.sample(source[0, None])

fig, ax = plt.subplots()
ax.scatter(source[:, 0], source[:, 1], color="blue", label="source")
ax.scatter(target[:, 0], target[:, 1], color="red", label="target")
ax.scatter(
    sampled[:, 0],
    sampled[:, 1],
    color="green",
    marker="P",
    label="sampled",
)
ax.set(xlabel="dim 0", ylabel="dim 1", title="samples for source[0]")
# Trailing semicolon suppresses the Legend repr in the cell output.
ax.legend();

In [None]:
# Plot the recorded "t_obj" metric against its index on a labelled axes.
t_obj = sm.metrics["t_obj"]
fig, ax = plt.subplots()
ax.plot(np.arange(len(t_obj)), t_obj)
# Trailing semicolon suppresses the repr of ax.set's return value.
ax.set(xlabel="logged step", ylabel="t_obj", title="t_obj");

In [None]:
# Plot the recorded "phi_obj" metric against its index on a labelled axes.
phi_obj = sm.metrics["phi_obj"]
fig, ax = plt.subplots()
ax.plot(np.arange(len(phi_obj)), phi_obj)
# Trailing semicolon suppresses the repr of ax.set's return value.
ax.set(xlabel="logged step", ylabel="phi_obj", title="phi_obj");

In [None]:
# Small 2x3 integer array used as a scratch example.
rows = [
    [1, 2, 3],
    [3, 4, 5],
]
a = jnp.array(rows)

In [None]:
# Display the shape of the scratch array `a`.
a.shape

In [None]:
# Plot the recorded "phi_obj" metric against its index on a labelled axes.
# NOTE(review): this cell repeats the earlier "phi_obj" plot cell; consider
# dropping one of the two once the notebook is cleaned up.
phi_obj = sm.metrics["phi_obj"]
fig, ax = plt.subplots()
ax.plot(np.arange(len(phi_obj)), phi_obj)
# Trailing semicolon suppresses the repr of ax.set's return value.
ax.set(xlabel="logged step", ylabel="phi_obj", title="phi_obj");