In [1]:
import jax
import jax.numpy as jnp
from jax import lax
import matplotlib.pyplot as plt
import numpy as np

In [None]:
class NN:
    def __call__(self, params, x):
        hidden = jax.nn.tanh(jnp.dot(x, params['w1']) + params['b1'])
        output = jnp.dot(hidden, params['w2']) + params['b2']
        return output.squeeze()
    
    def euler_integrate(self, params, y0, t0, t1, num_steps=20):
        h = (t1 - t0) / num_steps
        y = y0
        t = t0
        ys = []
        for _ in range(num_steps):
            y = y + h * self(params, jnp.array([t, y]))
            t = t + h
            ys.append(y)
        return ys

class LossFunction:
    def __init__(self, nn, real_f, t_max):
        self.nn = nn
        self.real_f = real_f
        self.t_max = t_max
        
    def __call__(self, y0, params):
        loss = 0.0
        t0 = 0.0
        ts = jnp.linspace(0.0, self.t_max, num=20)
        y_preds = self.nn.euler_integrate(params, y0, t0, self.t_max)
        y_real = real_f(ts, y0)
        
        return jnp.sum((jnp.array(y_preds) - y_real) ** 2)


def real_f(t, y):
    return jnp.sin(t) * y

hidden_dim = 3
key = jax.random.PRNGKey(0)
key1, key2, key3 = jax.random.split(key, 3)
params = {
    'w1': jax.random.normal(key1, (2, hidden_dim)) * 0.1,
    'b1': jax.random.normal(key2, (hidden_dim,)) * 0.1,
    'w2': jax.random.normal(key3, (hidden_dim, 1)) * 0.1,
    'b2': jnp.zeros((1,))
}

nn = NN()

def gradient_descent_step(loss_fn, y0, params, learning_rate=1e-3):
    grads = jax.grad(loss_fn, argnums=1)(y0, params)
    new_params = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)
    return new_params

epochs = 100
loss_fn = LossFunction(nn, real_f, t_max=5.0)
y0 = jnp.array(1.0)
loss_history = []
for epoch in range(epochs):
    params = gradient_descent_step(loss_fn, y0, params)
    
    if epoch % 1 == 0:
        current_loss = loss_fn(y0, params)
        print(f"Epoch {epoch}, Loss: {current_loss}")
        loss_history.append(current_loss)


Epoch 0, Loss: 8.31982421875
Epoch 1, Loss: 5.291850566864014
Epoch 2, Loss: 4.3660173416137695
Epoch 3, Loss: 4.084941387176514


In [None]:
# plot real vs predicted solution
t_values = jnp.linspace(0.0, 5.0, num=100)
y_real_n = real_f(t_values, y0)
y_pred_n = nn.euler_integrate(params, y0, 0.0, 5.0, num_steps=100)
y_pred_n = jnp.array(y_pred_n)


plt.plot(t_values, y_real_n, label='Real Solution')
plt.plot(t_values, y_pred_n, label='Predicted Solution', linestyle='dashed')
plt.xlabel('t')
plt.ylabel('y(t)')
plt.title('Real vs Predicted Solution')
plt.legend()
plt.show()

TypeError: NN.euler_integrate() got an unexpected keyword argument 'num_steps'