In [11]:
from dynax.solvers import odeint 

import jax.numpy as jnp
from jax import vmap
import numpy.random as npr
from jax import jit, grad
import matplotlib
import matplotlib.pyplot as plt 

def mlp(params, inputs):
  # A multi-layer perceptron, i.e. a fully-connected neural network.
  for w, b in params:
    outputs = jnp.dot(inputs, w) + b  # Linear transform
    inputs = jnp.tanh(outputs)        # Nonlinearity
  return outputs

def nn_dynamics(state, time, params):
  state_and_time = jnp.hstack([state, jnp.array(time)])
  return mlp(params, state_and_time)

def odenet(params, input):
  start_and_end_times = jnp.array([0.0, 1.0])
  init_state, final_state = odeint(nn_dynamics, input, start_and_end_times, params)

  return final_state



batched_odenet = vmap(odenet, in_axes=(None, 0))

# We need to change the input dimension to 2, to allow time-dependent dynamics.
odenet_layer_sizes = [2, 20, 1]

def odenet_loss(params, inputs, targets):
  preds = batched_odenet(params, inputs)
  return jnp.mean(jnp.sum((preds - targets)**2, axis=1))

def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)):
  return [(scale * rng.randn(m, n), scale * rng.randn(n))
          for m, n, in zip(layer_sizes[:-1], layer_sizes[1:])]

@jit
def odenet_update(params, inputs, targets):
  grads = grad(odenet_loss)(params, inputs, targets)
  return [(w - step_size * dw, b - step_size * db)
          for (w, b), (dw, db) in zip(params, grads)]

# Toy 1D dataset.
inputs = jnp.reshape(jnp.linspace(-2.0, 2.0, 10), (10, 1))
targets = inputs**3 + 0.1 * inputs

# Hyperparameters.
layer_sizes = [1, 20, 1]
param_scale = 1.0
step_size = 0.01
train_iters = 1000



# Initialize and train ODE-Net.
odenet_params = init_random_params(param_scale, odenet_layer_sizes)

preds = batched_odenet(odenet_params, inputs)
print(preds)
print(sss)

for i in range(train_iters):
  odenet_params = odenet_update(odenet_params, inputs, targets)

# Plot resulting model.
fig = plt.figure(figsize=(6, 4), dpi=150)
ax = fig.gca()
ax.scatter(inputs, targets, lw=0.5, color='green')
fine_inputs = jnp.reshape(jnp.linspace(-3.0, 3.0, 100), (100, 1))
#ax.plot(fine_inputs, resnet(resnet_params, fine_inputs, resnet_depth), lw=0.5, color='blue')
ax.plot(fine_inputs, batched_odenet(odenet_params, fine_inputs), lw=0.5, color='red')
ax.set_xlabel('input')
ax.set_ylabel('output')
plt.legend(('Target', 'ODE Net predictions'))

Traced<ShapedArray(float32[1])>with<BatchTrace(level=1/0)> with
  val = Array([[-2.        ],
       [-1.5555556 ],
       [-1.1111112 ],
       [-0.6666667 ],
       [-0.22222227],
       [ 0.22222227],
       [ 0.66666675],
       [ 1.1111112 ],
       [ 1.5555556 ],
       [ 2.        ]], dtype=float32)
  batch_dim = 0 Traced<ShapedArray(float32[1])>with<BatchTrace(level=1/0)> with
  val = Array([[-0.1876402 ],
       [-0.15630104],
       [-0.12310665],
       [-0.07324953],
       [ 0.02656857],
       [ 0.1655008 ],
       [ 0.28269923],
       [ 0.36686522],
       [ 0.4351038 ],
       [ 0.4980237 ]], dtype=float32)
  batch_dim = 0
[[-0.1876402 ]
 [-0.15630104]
 [-0.12310665]
 [-0.07324953]
 [ 0.02656857]
 [ 0.1655008 ]
 [ 0.28269923]
 [ 0.36686522]
 [ 0.4351038 ]
 [ 0.4980237 ]]


NameError: name 'sss' is not defined