In [1]:
import jax
import jax.numpy as jnp
import numpy as np
import optax
import equinox as eqx
from mlp_prob2 import MLP

with open("data.npy", 'rb') as f:
    s = jnp.load(f)
    phi = jnp.load(f) # phi for prob 1
    phi_2 = jnp.load(f) #phi for prob 2
    phi_3 = jnp.load(f) #phi for prob 3
    
#s[n] = (θ,  ̇θ, h) 
s.shape

(10000, 3)

In [None]:
#φ[n] = [φ_1 φ_2 φ_3z φ_3x] //one row
phi.shape

(10000, 4)

In [None]:
def shuffle_and_split(x, y, train_pcnt = 0.8, seed = 42):
    
    #shuffle the dataset and split into test/train
    n = len(s)
    indexes = np.arange(n)
    np.random.seed(seed)
    np.random.shuffle(indexes)
    
    split = int (train_pcnt * n)
    
    train_idx = indexes[0:split]
    test_idx = indexes[split:]
    
    x_train, x_test = x[train_idx], x[test_idx]
    y_train, y_test = y[train_idx], y[test_idx]
    return x_train, y_train, x_test, y_test

s_train, phi_train, s_test, phi_test = shuffle_and_split(s, phi_3)


In [4]:
print(s_train.shape)
print(s_test.shape)


print(phi_train.shape)
print(phi_test.shape)

(8000, 3)
(2000, 3)
(8000, 4)
(2000, 4)


In [5]:
# Loss function and accuracy metric

def compute_loss(model, x, y):
    pred_y = jax.vmap(model)(x)  
    return jnp.mean((pred_y - y) ** 2)


def compute_accuracy(model, x, y):

    def eval_single(model, x, y):
        pred_y = model(x)
        residual_sq = jnp.mean((y - pred_y)**2)
        return residual_sq
    res_all = jax.vmap(eval_single, in_axes=(None, 0, 0))(model, x, y)
    return jnp.sqrt(jnp.sum(res_all) / x.shape[0])
    

In [6]:
# Training step (JIT compiled)
@eqx.filter_jit
def train_step(model, x, y, opt_state, optimizer):
    def loss_fn(model):
        return compute_loss(model, x, y)
    loss, grads = eqx.filter_value_and_grad(loss_fn)(model)
    updates, opt_state = optimizer.update(grads, opt_state, eqx.filter(model, eqx.is_array))
    model = eqx.apply_updates(model, updates)
    return model, opt_state

In [7]:

# Initialize model and optimizer
key = jax.random.PRNGKey(0)
model = MLP(key) # Instantiate your MLP model here
optimizer = optax.adamw(learning_rate=1e-3, weight_decay=1e-4) # Choose your optimizer and learning_rate.
opt_state = optimizer.init(eqx.filter(model, eqx.is_array))

In [8]:
phi[0:128, [3,2]].shape
s[0:128].shape

(128, 3)

In [None]:

# Training loop
num_epochs = 1000# Set the number of epochs
num_train = len(s_train)
batch_size = 64
x_test = s_test[:,:3]
y_test = phi_test[:,[3,2]]
for epoch in range(num_epochs):

    for batch_start in range(0, num_train, batch_size):
        batch_end = batch_start + batch_size
        x_batch = s_train[batch_start:batch_end, :3] #train on theta, theta_dot, and h
        y_batch = phi_train[batch_start:batch_end, [3,2]] #model ouput should be f_3x and f_3z only
        model, opt_state = train_step(model, x_batch, y_batch, opt_state, optimizer)

    if (epoch + 1) % 50 == 0 or epoch == 0:
        loss = compute_loss(model, x_batch, y_batch)
        acc = compute_accuracy(model, x_batch, y_batch)
        test_loss = compute_loss(model, x_test, y_test)
        test_acc = compute_accuracy(model, x_test, y_test)
        print(f"Epoch {epoch+1}: Loss={loss:.4f}, RMSE={acc:.4f}  Test set: Loss={test_loss:.4f}, RMSE={test_acc:.4f}")

Epoch 1: Loss=0.0303, RMSE=0.1740  Test set: Loss=0.0191, RMSE=0.1383
Epoch 50: Loss=0.0310, RMSE=0.1761  Test set: Loss=0.0180, RMSE=0.1343
Epoch 100: Loss=0.0329, RMSE=0.1813  Test set: Loss=0.0172, RMSE=0.1312
Epoch 150: Loss=0.0320, RMSE=0.1789  Test set: Loss=0.0166, RMSE=0.1289
Epoch 200: Loss=0.0320, RMSE=0.1789  Test set: Loss=0.0159, RMSE=0.1262
Epoch 250: Loss=0.0318, RMSE=0.1783  Test set: Loss=0.0157, RMSE=0.1251
Epoch 300: Loss=0.0316, RMSE=0.1779  Test set: Loss=0.0153, RMSE=0.1236
Epoch 350: Loss=0.0316, RMSE=0.1778  Test set: Loss=0.0149, RMSE=0.1220
Epoch 400: Loss=0.0311, RMSE=0.1765  Test set: Loss=0.0144, RMSE=0.1201
Epoch 450: Loss=0.0307, RMSE=0.1751  Test set: Loss=0.0140, RMSE=0.1181
Epoch 500: Loss=0.0298, RMSE=0.1726  Test set: Loss=0.0137, RMSE=0.1171
Epoch 550: Loss=0.0288, RMSE=0.1697  Test set: Loss=0.0136, RMSE=0.1168
Epoch 600: Loss=0.0283, RMSE=0.1682  Test set: Loss=0.0143, RMSE=0.1197
Epoch 650: Loss=0.0269, RMSE=0.1639  Test set: Loss=0.0137, RMSE=0.