In [1]:
import jax.numpy as jnp
import jax
import matplotlib.pyplot as plt
import equinox as eqx

In [2]:
num_target = 1000
key = jax.random.PRNGKey(10)
target_samples = jax.random.uniform(key,shape = (10 * num_target,2),minval = -1,maxval = 1)
target_samples = target_samples[jnp.sum(jnp.abs(target_samples),1)>0.5][:num_target]

num_reference = 1000
key = jax.random.PRNGKey(15)
reference_samples = jax.random.normal(key,shape = (num_reference,2))


In [3]:
jax.random.shuffle?

[0;31mSignature:[0m [0mjax[0m[0;34m.[0m[0mrandom[0m[0;34m.[0m[0mshuffle[0m[0;34m([0m[0mkey[0m[0;34m:[0m [0;34m'KeyArrayLike'[0m[0;34m,[0m [0mx[0m[0;34m:[0m [0;34m'ArrayLike'[0m[0;34m,[0m [0maxis[0m[0;34m:[0m [0;34m'int'[0m [0;34m=[0m [0;36m0[0m[0;34m)[0m [0;34m->[0m [0;34m'Array'[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m
Shuffle the elements of an array uniformly at random along an axis.

Args:
  key: a PRNG key used as the random key.
  x: the array to be shuffled.
  axis: optional, an int axis along which to shuffle (default 0).

Returns:
  A shuffled version of x.
[0;31mFile:[0m      ~/opt/anaconda3/envs/stoch_interp/lib/python3.12/site-packages/jax/_src/random.py
[0;31mType:[0m      function

In [9]:
def get_trainloader():
    for i in range(10**6):
        trainloader_key = jax.random.PRNGKey(i)
        shuffle_ref_key,shuffle_target_key,normal_key,t_key = jax.random.split(trainloader_key,4)
        ref_batch = jax.random.permutation(shuffle_ref_key,reference_samples)
        target_batch = jax.random.permutation(shuffle_target_key,target_samples)
        t_vals = jax.random.uniform(t_key,(len(ref_batch),1))
        z = jax.random.uniform(t_key,ref_batch.shape)
        yield t_vals,ref_batch,target_batch,z

In [10]:
a = get_trainloader()

In [11]:
t,x,y,z = next(a)

In [15]:
z

Array([[0.7956687 , 0.05179608],
       [0.00287819, 0.42221785],
       [0.74910176, 0.4567448 ],
       ...,
       [0.85822594, 0.7483438 ],
       [0.72045755, 0.03744674],
       [0.6263758 , 0.8568218 ]], dtype=float32)

In [40]:
class NeuralNetwork(eqx.Module):
    layers: list
    extra_bias: jax.Array

    def __init__(self, key):
        key1, key2, key3 = jax.random.split(key, 3)
        # These contain trainable parameters.
        self.layers = [eqx.nn.Linear(3, 32, key=key1),
                       eqx.nn.Linear(32, 32, key=key2),
                       eqx.nn.Linear(32, 2, key=key3)]
        # This is also a trainable parameter.
        self.extra_bias = jax.numpy.ones(2)

    def __call__(self, x):
        for layer in self.layers[:-1]:
            x = jax.nn.relu(layer(x))
        return self.layers[-1](x) + self.extra_bias

def I(t,x,y):
    return (1-t)*x + t * y

def It(t,x,y):
    return y-x

def gamma(t):
    return jnp.sqrt(2*t*(1-t))

gammadot = jax.vmap(jax.vmap(jax.grad(gamma)))


@jax.jit  # compile this function to make it run fast.
@jax.grad  # differentiate all floating-point arrays in `model`.
def loss(model,t, x, y,z):
    tx = jnp.hstack([t,x])
    bhat = jax.vmap(model)(tx)  # vectorise the model over a batch of data
    dot_term = jnp.sum((It(t,x,y)+gammadot(t)*z) * bhat,axis = 1)
    return jnp.mean(jnp.sum(bhat**2,axis=1)) - jnp.mean(dot_term)  # L2 loss


In [41]:
t = jax.random.uniform(jax.random.PRNGKey(3),(1000,1))
z = jax.random.normal(jax.random.PRNGKey(5),(1000,2))
model = NeuralNetwork(jax.random.PRNGKey(0))


In [42]:
loss(model,t,reference_samples,target_samples,z)

NeuralNetwork(
  layers=[
    Linear(
      weight=f32[32,3],
      bias=f32[32],
      in_features=3,
      out_features=32,
      use_bias=True
    ),
    Linear(
      weight=f32[32,32],
      bias=f32[32],
      in_features=32,
      out_features=32,
      use_bias=True
    ),
    Linear(
      weight=f32[2,32],
      bias=f32[2],
      in_features=32,
      out_features=2,
      use_bias=True
    )
  ],
  extra_bias=f32[2]
)

In [45]:


loss(model,)

In [None]:
x_key, y_key, model_key = jax.random.split(jax.random.PRNGKey(0), 3)
model = NeuralNetwork(model_key)

# Example data
x = jax.random.normal(x_key, (100, 2))
y = jax.random.normal(y_key, (100, 2))
