In [None]:
import jax
import jax.random as jrandom
import jax.numpy as jnp
import flax.linen as nn
import optax
from typing import Sequence, Callable, Optional, Dict, Any
import numpy as np

from mmnn_jax import SinActivation,SinTUActivation, MMNNLayer, MMNNModel, Train_jax_model



import warnings
from pathlib import Path

import matplotlib.pyplot as plt

# Project figure style. The path is resolved relative to the current working
# directory, so when the notebook is launched from elsewhere the file is not
# found; fall back to matplotlib's defaults (with a visible warning) instead of
# aborting the whole notebook on a purely cosmetic setting.
STYLE_FILE = Path('figures/images_style.mplstyle')
if STYLE_FILE.is_file():
    plt.style.use(str(STYLE_FILE))
else:
    warnings.warn(
        f"Style file '{STYLE_FILE}' not found (cwd-relative); using matplotlib defaults."
    )



In [None]:
def y(x):
    """Oscillatory 1D target function used as the regression target.

    Computes, elementwise,
        cos(20*pi*|x|**1.4) + 0.5*cos(12*pi*|x|**1.6)
    which is even in ``x`` (only ``|x|`` enters) and equals 1.5 at ``x = 0``.

    Args:
        x: scalar or array-like (NumPy or JAX array). It is evaluated with
           NumPy ufuncs, so the result is a NumPy value with the shape of ``x``.

    Returns:
        The target values, same shape as ``x``.
    """
    abs_x = np.abs(x)
    return np.cos(20 * np.pi * abs_x ** 1.4) + 0.5 * np.cos(12 * np.pi * abs_x ** 1.6)

# Input data: 1000 evenly spaced points on [-1, 1], reshaped into a column of
# shape (1000, 1) -- one feature per sample.
x = jnp.linspace(-1, 1, 1000)
x = x.reshape(-1, 1)

# Target data: `y` is evaluated with NumPy ufuncs, so y_data has the same
# (1000, 1) shape as x.
y_data = y(x)

# Plot the target before training as a visual sanity check.
plt.figure(figsize=(10, 5))
plt.plot(x, y_data, label='Target Function', color='blue')
plt.title('Target Function')
plt.xlabel('x')
plt.ylabel('y')
plt.grid()
plt.legend()
# Render explicitly (as the final plotting cell does) so the cell does not
# also emit the repr of the Legend object returned by plt.legend().
plt.show()


In [None]:
# Configuration using ranks and widths.
SEED = 42            # passed to both the model and the trainer
HIDDEN_RANKS = [18] * 2
HIDDEN_WIDTH = 356

# ranks: input dimension, the intermediate ranks, then output dimension. The
# end points are taken from the data so they always match x and y_data.
ranks = [x.shape[1]] + HIDDEN_RANKS + [y_data.shape[1]]
# widths: one hidden width per transition between consecutive ranks.
# NOTE(review): this follows the MMNN convention len(widths) == len(ranks) - 1
# (the original [356]*3 with 4 ranks matches it) -- confirm against mmnn_jax.
widths = [HIDDEN_WIDTH] * (len(ranks) - 1)

print("Network architecture:")
print(f"Ranks (dimensions): {ranks}")
print(f"Widths (hidden layer sizes): {widths}")

# Create model
model = MMNNModel(
    ranks=ranks,
    widths=widths,
    activation=SinTUActivation(),
    seed=SEED
)

# lr = 0.001 * 0.96 ** floor(step / 100): with staircase=True the rate drops
# in discrete jumps every 100 optimizer steps (as counted by optax).
scheduler = optax.exponential_decay(
    init_value=0.001,
    transition_steps=100,
    decay_rate=0.96,
    staircase=True
)

train_model = Train_jax_model(
    model=model,
    input_data=x,
    target_data=y_data,
    optimizer='adam',
    loss_fn='mse',
    learning_rate=scheduler,
    num_epochs=5000,
    batch_size=100,
    random_seed=SEED
)

In [None]:
# Run training. training_loop returns (params, epochs_dict); print_every=500
# presumably sets the logging interval and epochs_dict presumably holds the
# per-epoch history -- confirm in mmnn_jax.
params, epochs_dict = train_model.training_loop(
    print_every=500,
)

In [None]:
# Evaluate on a finer grid (2000 points) than the 1000 training points.
# NOTE(review): this rebinds the notebook-level `x` and `y_data`; re-running the
# trainer-setup cell after this one would train on this grid instead.
x = jnp.linspace(-1, 1, 2000)[:, None]
y_data = y(x)
y_pred = model.apply(params, x)

residual = y_data - y_pred
pred_error = jnp.mean(residual ** 2)

print(f"MSE: {pred_error:.8f}")

In [None]:
# Overlay the trained model's prediction on the exact target over the
# evaluation grid.
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(x, y_data, label='Target Function', color='blue')
ax.plot(x, y_pred, label='Model Prediction', color='red', marker='*', alpha=0.5, markersize=1)
ax.set_title('Model Prediction vs Target Function')
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.grid()
ax.legend()
plt.show()