In [None]:
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
from typing import Any, Dict
import haiku as hk
import optax

# Fixed PRNG seed so data generation and weight initialisation are repeatable.
key = jax.random.PRNGKey(42)

# Undamped natural angular frequency shared by every environment.
omega_0 = 1.0

# 100 evenly spaced sample times on [0, 10], stored as a (100, 1) column.
t = jnp.linspace(0, 10, 100)[:, None]

def generate_oscillator_data(beta: float, key: Any, t_points: Any = None,
                             natural_freq: Any = None) -> Dict[str, Any]:
    """
    Generate time-series data for an underdamped harmonic oscillator.

    Evaluates the closed-form solution of
        x'' + 2*beta*x' + omega_0**2 * x = 0
    for an initial displacement x0 and an initial velocity v0. Each is drawn
    independently and uniformly from [-1, 1).

    Args:
        beta: Damping coefficient. Must satisfy beta**2 < omega_0**2
            (underdamped); otherwise the damped frequency is not real.
        key: JAX PRNG key. It is split internally so x0 and v0 are independent.
        t_points: Sample times. Defaults to the module-level ``t``.
        natural_freq: Natural frequency omega_0. Defaults to the
            module-level ``omega_0``.

    Returns:
        Dict with 't' (the sample times), 'x' (displacement, same shape as
        the times), 'beta', 'x0' and 'v0'.

    Raises:
        ValueError: If ``beta`` and the natural frequency are plain Python
            numbers and the system is not underdamped.
    """
    # Resolve defaults at call time so existing two-argument callers keep
    # using the module-level time grid and frequency.
    if t_points is None:
        t_points = t
    if natural_freq is None:
        natural_freq = omega_0

    # Only validate concrete Python numbers; traced values cannot be branched on.
    is_concrete = (isinstance(beta, (int, float))
                   and isinstance(natural_freq, (int, float)))
    if is_concrete and beta ** 2 >= natural_freq ** 2:
        raise ValueError(
            f"beta={beta} is not underdamped for omega_0={natural_freq}; "
            "require beta**2 < omega_0**2")

    # Use distinct subkeys: sampling both values from the same key would make
    # x0 and v0 identical.
    key_x0, key_v0 = jax.random.split(key)
    x0 = jax.random.uniform(key_x0, minval=-1.0, maxval=1.0)
    v0 = jax.random.uniform(key_v0, minval=-1.0, maxval=1.0)

    # Damped angular frequency.
    omega_d = jnp.sqrt(natural_freq**2 - beta**2)

    # x(t) = exp(-beta t) * (A cos(omega_d t) + B sin(omega_d t)), where
    # A = x0 gives x(0) = x0 and this B gives x'(0) = v0.
    A = x0
    B = (v0 + beta * x0) / omega_d
    x = jnp.exp(-beta * t_points) * (A * jnp.cos(omega_d * t_points)
                                     + B * jnp.sin(omega_d * t_points))

    return {
        't': t_points,
        'x': x,
        'beta': beta,
        'x0': x0,
        'v0': v0
    }

# One training environment per damping coefficient; all share omega_0 and t.
betas = [0.1, 0.2, 0.3]
num_envs = len(betas)

# Advance the global key once per environment so each gets fresh initial
# conditions, and collect the generated trajectories in order.
env_data = []
for beta in betas:
    key, subkey = jax.random.split(key)
    env_data.append(generate_oscillator_data(beta, subkey))

# Prepare data for training
def prepare_data(env_data):
    batch_data = []
    for data in env_data:
        x = data['t']
        y = data['x']
        beta = data['beta']
        x0 = data['x0']
        v0 = data['v0']
        batch_data.append({'x': x, 'y': y, 'beta': beta, 'x0': x0, 'v0': v0})
    return batch_data

# Re-key each environment into the layout read by total_loss:
# 'x' = sample times (network input), 'y' = displacement (target).
batch_data = prepare_data(env_data)

# Define the model architecture
def feature_extractor_fn(x):
    """Shared representation: two 64-unit Linear layers, each followed by tanh."""
    # Both layers are constructed before either is applied.
    layers = (hk.Linear(64), hk.Linear(64))
    hidden = x
    for layer in layers:
        hidden = jax.nn.tanh(layer(hidden))
    return hidden

def predictor_fn(features):
    """Linear read-out head mapping features to a single output column."""
    return hk.Linear(1)(features)

# Wrap the Haiku functions as (init, apply) pairs. Throughout this file,
# apply is called with rng=None.
feature_extractor, predictor = (
    hk.transform(feature_extractor_fn),
    hk.transform(predictor_fn),
)

# Initialize parameters
def init_model(key, x_sample):
    """
    Initialise both sub-networks and return {'phi': ..., 'w': ...}.

    The predictor is initialised on features produced by the freshly
    initialised extractor, so its input width matches the extractor output.
    """
    phi_key, w_key = jax.random.split(key)
    phi = feature_extractor.init(phi_key, x_sample)
    sample_features = feature_extractor.apply(phi, None, x_sample)
    w = predictor.init(w_key, sample_features)
    return dict(phi=phi, w=w)

# One time sample from the first environment is enough to fix input shapes.
x_sample = batch_data[0]['x'][:1]

key, subkey = jax.random.split(key)
params = init_model(subkey, x_sample)

# Adam over the joint parameter tree {'phi', 'w'}.
optimizer = optax.adam(learning_rate=1e-3)
opt_state = optimizer.init(params)

# Loss functions
def mse_loss(predictions, targets):
    """Mean squared error, averaged over every element."""
    residual = predictions - targets
    return jnp.mean(residual * residual)

def compute_erm_loss(params, x, y):
    """Empirical-risk term: MSE of the full model's predictions on x against y."""
    return mse_loss(
        predictor.apply(
            params['w'], None,
            feature_extractor.apply(params['phi'], None, x),
        ),
        y,
    )

def compute_irm_penalty(params, x, y):
    """
    IRM-style invariance penalty for one environment.

    Returns the squared L2 norm of the gradient of this environment's MSE
    loss with respect to the predictor parameters ``params['w']``. The
    gradient is taken over every leaf (weights and bias).

    NOTE(review): IRMv1 differentiates w.r.t. a dummy scalar multiplier
    fixed at 1.0. This variant uses the full read-out parameters instead —
    confirm this is intended.
    """
    # Features depend only on params['phi'], so compute them once. The outer
    # gradient still flows through them, because JAX traces closed-over values.
    features = feature_extractor.apply(params['phi'], None, x)

    def loss_fn(w_params):
        predictions = predictor.apply(w_params, None, features)
        return mse_loss(predictions, y)

    grads = jax.grad(loss_fn)(params['w'])
    # Use jax.tree_util: the top-level jax.tree_leaves alias is deprecated and
    # removed in newer JAX releases.
    return sum(jnp.sum(jnp.square(g)) for g in jax.tree_util.tree_leaves(grads))

def compute_physics_loss(params, x, beta):
    """
    Compute the physics loss term for the underdamped harmonic oscillator.

    Evaluates the network x(t) and its first two time derivatives (via nested
    ``jax.grad``) at every time in ``x``. It returns the mean squared residual
    of  x'' + 2*beta*x' + omega_0**2 * x = 0  (``omega_0`` is the module global).

    Args:
        params: Model parameters {'phi': ..., 'w': ...}.
        x: Collocation times. Any shape is accepted and flattened to 1-D; the
            training data passes shape (N, 1).
        beta: Damping coefficient of this environment.

    Returns:
        Scalar mean squared ODE residual.
    """
    def net_output_scalar(t_scalar):
        # The networks take a (batch, 1) input; evaluate a single time point.
        features = feature_extractor.apply(params['phi'], None, t_scalar.reshape(1, 1))
        return predictor.apply(params['w'], None, features).squeeze()

    # Flatten with reshape(-1) rather than squeeze(): squeeze would turn a
    # single (1, 1) sample into a 0-d array that jax.vmap cannot map over.
    t_flat = x.reshape(-1)

    dxdt_fn = jax.grad(net_output_scalar)
    d2xdt2_fn = jax.grad(dxdt_fn)

    x_pred = jax.vmap(net_output_scalar)(t_flat)
    dxdt = jax.vmap(dxdt_fn)(t_flat)
    d2xdt2 = jax.vmap(d2xdt2_fn)(t_flat)

    # Residual of the damped-oscillator ODE at every collocation point.
    residual = d2xdt2 + 2 * beta * dxdt + omega_0**2 * x_pred
    return jnp.mean(residual**2)

def compute_initial_condition_loss(params, x0, v0, t0):
    """
    Compute the squared-error loss for the initial conditions.

    Compares the network's displacement and its time derivative at ``t0``
    with the targets ``x0`` and ``v0``.

    Args:
        params: Model parameters {'phi': ..., 'w': ...}.
        x0: Target displacement at t0.
        v0: Target velocity at t0.
        t0: Initial time, as a single-element array or a Python scalar.

    Returns:
        Scalar (x(t0) - x0)**2 + (x'(t0) - v0)**2.
    """
    def net_output_scalar(t_scalar):
        # The networks take a (batch, 1) input; evaluate a single time point.
        features = feature_extractor.apply(params['phi'], None, t_scalar.reshape(1, 1))
        return predictor.apply(params['w'], None, features).squeeze()

    # value_and_grad returns x(t0) and dx/dt(t0) from a single trace, so the
    # network is not evaluated separately for the displacement.
    t0_scalar = jnp.asarray(t0).reshape(())
    x_pred, dxdt_pred = jax.value_and_grad(net_output_scalar)(t0_scalar)

    displacement_loss = (x_pred - x0) ** 2
    velocity_loss = (dxdt_pred - v0) ** 2
    return displacement_loss + velocity_loss

def total_loss(params, batch_data, lambda_irm, lambda_physics, lambda_ic):
    """
    Weighted multi-environment training objective.

    Averages four terms over the environments in ``batch_data``: the ERM loss,
    the IRM penalty, the physics-residual loss and the initial-condition loss.
    It returns

        erm + lambda_irm * irm + lambda_physics * physics + lambda_ic * ic

    together with the four unweighted averages as auxiliary output, in the
    form expected by ``jax.value_and_grad(..., has_aux=True)``.

    Args:
        params: Model parameters {'phi': ..., 'w': ...}.
        batch_data: Non-empty list of dicts with keys 'x' (times), 'y'
            (targets), 'beta', 'x0' and 'v0'.
        lambda_irm: Weight of the IRM penalty.
        lambda_physics: Weight of the physics-residual loss.
        lambda_ic: Weight of the initial-condition loss.

    Returns:
        (total, (erm_loss, irm_penalty, physics_loss, ic_loss)).

    Raises:
        ValueError: If ``batch_data`` is empty, since the averages are undefined.
    """
    n_envs = len(batch_data)
    if n_envs == 0:
        raise ValueError("total_loss requires at least one environment in batch_data")

    erm_loss = 0.0
    irm_penalty = 0.0
    physics_loss = 0.0
    ic_loss = 0.0

    # Plain Python loop over environments; under jax.jit it is unrolled at trace time.
    for data in batch_data:
        x = data['x']
        y = data['y']
        t0 = x[0]  # first sample time, used as the initial-condition time

        erm_loss += compute_erm_loss(params, x, y)
        irm_penalty += compute_irm_penalty(params, x, y)
        physics_loss += compute_physics_loss(params, x, data['beta'])
        ic_loss += compute_initial_condition_loss(params, data['x0'], data['v0'], t0)

    erm_loss /= n_envs
    irm_penalty /= n_envs
    physics_loss /= n_envs
    ic_loss /= n_envs

    total_loss_value = (erm_loss +
                        lambda_irm * irm_penalty +
                        lambda_physics * physics_loss +
                        lambda_ic * ic_loss)

    return total_loss_value, (erm_loss, irm_penalty, physics_loss, ic_loss)

# Training hyper-parameters: the number of full-batch optimisation steps and
# the weights of the IRM, physics and initial-condition loss terms.
num_epochs = 50_000
lambda_irm = lambda_physics = lambda_ic = 1.0

@jax.jit
def update(params, opt_state, batch_data):
    """
    Perform one jitted Adam step on ``total_loss``.

    The ``lambda_*`` weights and ``optimizer`` are module globals. They are
    read when the function is traced, so their values are fixed in the
    compiled step.

    Returns:
        (params, opt_state, total, erm, irm, physics, ic), with parameters
        and optimiser state taken after the step.
    """
    loss_and_grad = jax.value_and_grad(total_loss, has_aux=True)
    (loss_value, components), grads = loss_and_grad(
        params, batch_data, lambda_irm, lambda_physics, lambda_ic)
    updates, new_opt_state = optimizer.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return (new_params, new_opt_state, loss_value) + tuple(components)

# Per-step histories of the total loss and each unweighted component.
loss_history, erm_loss_history, irm_penalty_history = [], [], []
physics_loss_history, ic_loss_history = [], []

for epoch in range(num_epochs):
    (params, opt_state, loss_value,
     erm_loss, irm_penalty, physics_loss, ic_loss) = update(params, opt_state, batch_data)

    for history, value in zip(
        (loss_history, erm_loss_history, irm_penalty_history,
         physics_loss_history, ic_loss_history),
        (loss_value, erm_loss, irm_penalty, physics_loss, ic_loss),
    ):
        history.append(value)

    # Report progress every 50 steps.
    if not (epoch + 1) % 50:
        print(f"Epoch {epoch+1}, Total Loss: {loss_value:.6f}, "
              f"ERM Loss: {erm_loss:.6f}, IRM Penalty: {irm_penalty:.6f}, "
              f"Physics Loss: {physics_loss:.6f}, IC Loss: {ic_loss:.6f}")

# Evaluate the model
def evaluate(params, env_data):
    """Return, per environment, the model's flattened predictions on its 't' grid."""
    def _predict(times):
        # Full forward pass: extractor features, then the linear read-out.
        feats = feature_extractor.apply(params['phi'], None, times)
        return predictor.apply(params['w'], None, feats).flatten()

    return [_predict(record['t']) for record in env_data]

predictions = evaluate(params, env_data)

# One figure per environment, comparing the analytic trajectory ('True')
# with the network output ('Predicted').
for i, data in enumerate(env_data):
    plt.figure(figsize=(8, 4))
    for series, label in ((data['x'], 'True'), (predictions[i], 'Predicted')):
        plt.plot(data['t'], series, label=label)
    plt.title(f"Environment {i+1} (beta={data['beta']})")
    plt.xlabel('Time')
    plt.ylabel('Displacement')
    plt.legend()
    plt.show()
