In [1]:
import jax
import jax.numpy as jnp
import numpy as np
import optax
import equinox as eqx
from mlp_prob1 import MLP

with open("data.npy", 'rb') as f:
    s = jnp.load(f)
    phi = jnp.load(f)
    
#s[n] = (θ,  ̇θ, h) 
s.shape

(10000, 3)

In [2]:
#φ[n] = [φ_1 φ_2 φ_3z φ_3x] //one row
phi.shape

(10000, 4)

In [None]:
def shuffle_and_split(x, y, train_pcnt = 0.8, seed = 42):
    
    #shuffle the dataset and split into test/train
    n = len(s)
    indexes = np.arange(n)
    np.random.seed(seed)
    np.random.shuffle(indexes)
    
    split = int (train_pcnt * n)
    
    train_idx = indexes[0:split]
    test_idx = indexes[split:]
    
    x_train, x_test = x[train_idx], x[test_idx]
    y_train, y_test = y[train_idx], y[test_idx]
    return x_train, y_train, x_test, y_test

s_train, phi_train, s_test, phi_test = shuffle_and_split(s, phi)

    #def batch_generator(x, y):
        #n = len(x)
        
        #for start_idx in range(0, n, batch_size):
            #end_idx = start_idx + batch_size
            #yield x[start_idx:end_idx], y[start_idx:end_idx]

    #return batch_generator(s_train, phi_train), batch_generator(s_test, phi_test)
    

    




        

    
    


In [4]:
print(s_train.shape)
print(s_test.shape)


print(phi_train.shape)
print(phi_test.shape)

(8000, 3)
(2000, 3)
(8000, 4)
(2000, 4)


In [None]:
# Loss function and accuracy metric

def compute_loss(model, x, y):
    pred_y = jax.vmap(model)(x)  
    return jnp.mean((pred_y - y) ** 2)

#RMSE, since this is a regression model
def compute_accuracy(model, x, y): 

    def eval_single(model, x, y):
        pred_y = model(x)
        residual_sq = jnp.mean((y - pred_y)**2)
        return residual_sq
    res_all = jax.vmap(eval_single, in_axes=(None, 0, 0))(model, x, y)
    return jnp.sqrt(jnp.sum(res_all) / x.shape[0])
    

In [6]:
# Training step (JIT compiled)
@eqx.filter_jit
def train_step(model, x, y, opt_state, optimizer):
    def loss_fn(model):
        return compute_loss(model, x, y)
    loss, grads = eqx.filter_value_and_grad(loss_fn)(model)
    updates, opt_state = optimizer.update(grads, opt_state, eqx.filter(model, eqx.is_array))
    model = eqx.apply_updates(model, updates)
    return model, opt_state

In [7]:

# Initialize model and optimizer
key = jax.random.PRNGKey(0)
model = MLP(key) # Instantiate your MLP model here
optimizer = optax.adamw(learning_rate=1e-3, weight_decay=1e-4) # Choose your optimizer and learning_rate.
opt_state = optimizer.init(eqx.filter(model, eqx.is_array))

In [8]:
phi[0:128, [3,2]].shape

(128, 2)

In [None]:

# Training loop
num_epochs = 500
num_train = len(s_train)
batch_size = 64
x_test = s_test[:,[0,2]]
y_test = phi_test[:,[3,2]]
for epoch in range(num_epochs):

    for batch_start in range(0, num_train, batch_size):
        batch_end = batch_start + batch_size
        x_batch = s_train[batch_start:batch_end,[0,2]] #train only on theta and h since assuming theta_dot = 0
        y_batch = phi_train[batch_start:batch_end, [3,2]] #model ouput should be f_3x and f_3z only
        model, opt_state = train_step(model, x_batch, y_batch, opt_state, optimizer)

    if (epoch + 1) % 50 == 0 or epoch == 0:
        loss = compute_loss(model, x_batch, y_batch)
        acc = compute_accuracy(model, x_batch, y_batch)
        test_loss = compute_loss(model, x_test, y_test)
        test_acc = compute_accuracy(model, x_test, y_test)
        print(f"Epoch {epoch+1}: Loss={loss:.4f}, RMSE={acc:.4f}  Test set: Loss={test_loss:.4f}, RMSE={test_acc:.4f}")

Epoch 1: Loss=0.0109, RMSE=0.1045  Test set: Loss=0.0105, RMSE=0.1027
Epoch 50: Loss=0.0118, RMSE=0.1087  Test set: Loss=0.0116, RMSE=0.1076
Epoch 100: Loss=0.0187, RMSE=0.1367  Test set: Loss=0.0189, RMSE=0.1376
Epoch 150: Loss=0.0178, RMSE=0.1333  Test set: Loss=0.0184, RMSE=0.1358
Epoch 200: Loss=0.0142, RMSE=0.1192  Test set: Loss=0.0148, RMSE=0.1218
Epoch 250: Loss=0.0168, RMSE=0.1297  Test set: Loss=0.0186, RMSE=0.1364
Epoch 300: Loss=0.0154, RMSE=0.1240  Test set: Loss=0.0181, RMSE=0.1346
Epoch 350: Loss=0.0189, RMSE=0.1375  Test set: Loss=0.0222, RMSE=0.1491
Epoch 400: Loss=0.0073, RMSE=0.0853  Test set: Loss=0.0095, RMSE=0.0975
Epoch 450: Loss=0.0126, RMSE=0.1123  Test set: Loss=0.0152, RMSE=0.1234
Epoch 500: Loss=0.0057, RMSE=0.0752  Test set: Loss=0.0076, RMSE=0.0871
