In [22]:
import matplotlib.pyplot as plt
import jax
jax.config.update('jax_default_device',jax.devices()[-1])
import jax.numpy as jnp
import optax
import jax
import numpy as np
import ot
from stochastic_interpolant.neural_network import NeuralNetwork
from stochastic_interpolant.dataloaders import (
    DatasetSampler,GaussianReferenceSampler,
    IndependenceCouplingSampler,build_trainloader,testloader_factory)
from stochastic_interpolant.loss_functions import get_linear_interpolants,root_prod_gamma,root_prod_gammadot,get_loss_functions
from stochastic_interpolant.data_generators import inf_train_gen
# from stochastic_interpolant.model_training import train_model

In [8]:
# Target sample
num_target = 5000
target_samples = inf_train_gen('2spirals',rng = 50,batch_size = num_target)

In [13]:
# Reference Sample
reference_samples = jax.random.normal(jax.random.PRNGKey(13334),(num_target,2))

In [19]:
mini_batch_size = 100
data_key = jax.random.PRNGKey(42)
target_key, reference_key = jax.random.split(data_key,2)
target_batch = jax.random.choice(target_key,target_samples,(mini_batch_size,))
reference_batch = jax.random.choice(reference_key,reference_samples,(mini_batch_size,))

In [23]:
M = ot.dist(target_batch,reference_batch,p=2)

In [25]:
M.shape

(100, 100)

In [26]:
M@jnp.ones(mini_batch_size)

Array([ 369.0374 ,  549.6625 , 1005.4695 ,  391.70386,  344.48224,
        506.3252 ,  683.6517 , 1050.1572 ,  363.26294,  383.9378 ,
       1009.8624 ,  352.19568,  422.4872 ,  393.25183,  382.6565 ,
        764.4727 ,  360.61597, 1212.6416 ,  530.71106,  874.10406,
        583.2994 ,  836.624  , 1108.693  ,  769.63293, 1142.071  ,
        426.25732,  365.1642 ,  884.1506 , 1069.3484 ,  307.72604,
        535.5749 ,  620.7322 ,  358.62256, 1194.4216 ,  266.0911 ,
        935.0187 , 1018.1676 , 1189.8458 ,  492.63837,  349.62787,
        507.22318,  520.49475, 1121.4202 ,  420.5481 ,  992.4385 ,
        822.4968 ,  939.06885,  732.46716,  507.03778,  433.003  ,
        555.6853 ,  255.8516 ,  367.66223,  366.97217,  857.36566,
        785.0674 ,  327.9602 ,  487.1964 ,  452.13974,  762.114  ,
        479.14502,  654.1726 ,  767.2874 ,  392.8618 ,  641.95764,
        313.57074,  957.6579 ,  297.38544,  219.10406,  430.3553 ,
        343.68054,  848.2037 ,  914.9475 ,  968.77783,  786.85

In [27]:
M[0,:]

Array([4.87128019e+00, 1.84214318e+00, 5.76710272e+00, 3.44439697e+00,
       5.08630514e+00, 8.31590772e-01, 7.05637932e+00, 4.87010050e+00,
       7.37586141e-01, 1.14021790e+00, 7.94805110e-01, 6.41036034e-01,
       3.31614685e+00, 5.75101471e+00, 4.77346992e+00, 7.03291607e+00,
       3.62846589e+00, 7.05816507e-01, 2.72464848e+00, 2.80247879e+00,
       3.67493129e+00, 3.87888193e+00, 1.44077539e+00, 7.74126887e-01,
       6.16751254e-01, 2.24060225e+00, 1.89575529e+00, 7.65278530e+00,
       6.41045189e+00, 2.57034707e+00, 2.83603525e+00, 5.23129344e-01,
       3.12386918e+00, 1.60587370e-01, 8.51220798e+00, 3.84350598e-01,
       8.27554703e+00, 1.06159439e+01, 4.96510553e+00, 7.70366764e+00,
       2.76179171e+00, 2.78559756e+00, 3.54905152e+00, 1.04129314e-02,
       3.49837303e+00, 4.36041057e-01, 4.13526249e+00, 2.00629950e+00,
       1.27573907e+00, 2.93755436e+00, 3.93534994e+00, 4.86815786e+00,
       3.46180153e+00, 5.48256779e+00, 1.16407661e+01, 3.44702625e+00,
      