# Approximating the Runge Function by Polynomials

The main program is shown below.

In [None]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

# Turn on JAX's 64-bit mode so arrays default to double precision.
# NOTE(review): presumably enabled because the features go up to
# x**POLYNOMIAL_DEGREE, whose values on [-1, 1] span many orders of
# magnitude — confirm.
jax.config.update("jax_enable_x64", True)

# The Runge Function
def f(x):
    """Evaluate the Runge function 1 / (1 + 25 x^2).

    Accepts Python scalars as well as arrays; on arrays it is applied
    elementwise (it is called on JAX arrays for both training and plotting).
    """
    denominator = 1 + 25 * x**2
    return 1 / denominator

# Hyperparameters
POLYNOMIAL_DEGREE = 12
learning_rate = 0.01
epochs = 40000
# An odd number of evenly spaced points on [-1, 1] includes the midpoint
# x = 0, which is the peak of the Runge function. 1001 is also the value
# used for the Datanum:1001 experiments reported with this script.
datanum = 1001

# Data: evenly spaced samples of f on [-1, 1].
x_train = jnp.linspace(-1.0, 1.0, datanum)
y_train = f(x_train)

# Parameters, drawn from a standard normal with a fixed seed:
#   'w': column of coefficients for the monomials x**1 .. x**POLYNOMIAL_DEGREE
#   'b': scalar bias, stored with shape (1,)
key = jax.random.PRNGKey(0)
key, w_key, b_key = jax.random.split(key, 3)
params = {
    'w': jax.random.normal(w_key, (POLYNOMIAL_DEGREE, 1)),
    'b': jax.random.normal(b_key, (1,)),
}

# Construct the Polynomial Model
def polynomial_model(params, x):
    """Evaluate b + sum_{k=1}^{D} w_k * x**k for every element of ``x``.

    The degree D is read from ``params['w']`` rather than from a global
    constant, so the model always matches the parameters it is given.

    Args:
        params: dict with ``'w'`` of shape ``(D, 1)`` (coefficients of
            x**1 .. x**D) and ``'b'`` of shape ``(1,)`` (bias).
        x: array (or scalar) of inputs; it is flattened into a column.

    Returns:
        Array of shape ``(x.size, 1)`` with the model output per input.
    """
    degree = params['w'].shape[0]
    x_col = jnp.reshape(x, (-1, 1))
    # Row i of `features` is [x_i, x_i**2, ..., x_i**degree]; there is no
    # constant column because the bias is added separately.
    exponents = jnp.arange(1, degree + 1)
    features = jnp.power(x_col, exponents)
    return features @ params['w'] + params['b']

def loss_fn(params, x, y):
    """Mean squared error between ``polynomial_model(params, x)`` and ``y``.

    The model returns a column of shape ``(N, 1)``. It is reshaped to the
    shape of ``y`` instead of being squeezed. With a squeeze, a column-vector
    ``y`` of shape ``(N, 1)`` would silently broadcast against the ``(N,)``
    predictions into an ``(N, N)`` difference matrix. For a 1-D ``y`` the
    result is identical to squeezing.
    """
    predictions = jnp.reshape(polynomial_model(params, x), jnp.shape(y))
    return jnp.mean((predictions - y) ** 2)
# Loss values recorded by the training loop (one entry every 1000 epochs);
# plotted afterwards as the training-loss curve.
loss_history = []

# Gradient Descent
@jax.jit
def update_step(params, x, y, learning_rate):
    """Take one gradient-descent step on ``loss_fn`` over ``(x, y)``.

    Returns a new parameter pytree with the same structure as ``params``.
    Every leaf is moved against its gradient, scaled by ``learning_rate``.
    """
    gradient_of_loss = jax.grad(loss_fn)
    grads = gradient_of_loss(params, x, y)

    def descend(param, grad):
        return param - learning_rate * grad

    return jax.tree.map(descend, params, grads)


for epoch in range(epochs):
    params = update_step(params, x_train, y_train, learning_rate)

    # Only every 1000th epoch is logged. The loss is measured on the whole
    # training set after this epoch's update has been applied.
    if epoch % 1000 != 0:
        continue
    loss = loss_fn(params, x_train, y_train)
    loss_history.append(loss)
    print(f"Epoch {epoch}, Loss: {loss:.6f}")
        
# Evaluate the trained model on a dense grid for plotting.
x_plot = jnp.linspace(-1, 1, 500)
y_true = f(x_plot)
y_pred = polynomial_model(params, x_plot).squeeze()

# Error metrics on the training samples.
final_mse = loss_fn(params, x_train, y_train)
max_error = jnp.max(jnp.abs(polynomial_model(params, x_train).squeeze() - y_train))

# Result
print("\n--- Final Result ---")
print(f"Final MSE (Degree={POLYNOMIAL_DEGREE}): {final_mse:.6f}")
print(f"Final Max Error: {max_error:.6f}")

# Plotting
plt.figure(figsize=(12, 5))

# Left panel: fitted polynomial vs. the true function and the training samples.
plt.subplot(1, 2, 1)
plt.plot(x_plot, y_true, label='True Runge Function', color='blue')
# The fitted model is a polynomial, so label it as such.
plt.plot(x_plot, y_pred, label=f'Polynomial Prediction (Degree {POLYNOMIAL_DEGREE})', color='red', linestyle='--')
plt.scatter(x_train, y_train, s=10, color='gray', alpha=0.5, label='Training Data')
plt.title('Function Approximation')
plt.xlabel('x')
plt.ylabel('f(x)')
plt.legend()
plt.grid(True)

# Right panel: training loss, one point per 1000 epochs (one per entry of
# loss_history).
plt.subplot(1, 2, 2)
plt.plot(range(0, epochs, 1000), loss_history)
plt.title('Training Loss Curve')
plt.xlabel('Epoch')
plt.ylabel('Mean Squared Error Loss')
plt.yscale('log')
plt.grid(True)

plt.tight_layout()
plt.show()

```text
--- Final Result ---
Final MSE (Degree=12): 0.024220
Final Max Error: 0.413023
```

<figure id="Polynomial 1 Epoch:20000 Datanum:1001">
    <img src="Polynomial_12_0.01_20000_1001.png" alt="Degree-12 polynomial fit, 20000 epochs, 1001 data points" style="width: 70%;">
    <figcaption><b>Figure 1</b>: Polynomial Degree:12 Epoch:20000 Datanum:1001.</figcaption>
</figure>

```text
--- Final Result ---
Final MSE (Degree=12): 0.018068
Final Max Error: 0.372409
```

<figure id="Polynomial 1 Epoch:40000 Datanum:1001">
    <img src="Polynomial_12_0.01_40000_1001.png" alt="Degree-12 polynomial fit, 40000 epochs, 1001 data points" style="width: 70%;">
    <figcaption><b>Figure 2</b>: Polynomial Degree:12 Epoch:40000 Datanum:1001.</figcaption>
</figure>

The model does not fit the function well. I suspected that the number of data points was too small, so I increased it to 10001 and then to 100001.

However, the results barely change:

```text
--- Final Result ---
Final MSE (Degree=12): 0.018105
Final Max Error: 0.372532
```
<figure id="Polynomial 1 Epoch:40000 Datanum:10001">
    <img src="Polynomial_12_0.01_40000_10001.png" alt="Degree-12 polynomial fit, 40000 epochs, 10001 data points" style="width: 70%;">
    <figcaption><b>Figure 3</b>: Polynomial Degree:12 Epoch:40000 Datanum:10001.</figcaption>
</figure>

```text
--- Final Result ---
Final MSE (Degree=12): 0.018109
Final Max Error: 0.372544
```

<figure id="Polynomial 1 Epoch:40000 Datanum:100001">
    <img src="Polynomial_12_0.01_40000_100001.png" alt="Degree-12 polynomial fit, 40000 epochs, 100001 data points" style="width: 70%;">
    <figcaption><b>Figure 4</b>: Polynomial Degree:12 Epoch:40000 Datanum:100001.</figcaption>
</figure>

Next, I tried doubling the number of epochs from 40000 to 80000.

```text
--- Final Result ---
Final MSE (Degree=12): 0.014324
Final Max Error: 0.328050
```
<figure id="Polynomial 1 Epoch:80000 Datanum:100001">
    <img src="Polynomial_12_0.01_80000_100001.png" alt="Degree-12 polynomial fit, 80000 epochs, 100001 data points" style="width: 70%;">
    <figcaption><b>Figure 5</b>: Polynomial Degree:12 Epoch:80000 Datanum:100001.</figcaption>
</figure>


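Doubling the epochs only lowers the MSE from about 0.0181 to 0.0143 and the max error from about 0.373 to 0.328, so we cannot expect much further improvement from this approach.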

Moreover, stacking more hidden layers would only change the degree of our polynomial $\sigma$; for example, a degree-24 model does no better:

```text
--- Final Result ---
Final MSE (Degree=24): 0.019440
Final Max Error: 0.360167
```

<figure id="Polynomial 1 Degree:24 Epoch:40000 Datanum:100001">
    <img src="Polynomial_24_0.01_40000_100001.png" alt="Degree-24 polynomial fit, 40000 epochs, 100001 data points" style="width: 70%;">
    <figcaption><b>Figure 6</b>: Polynomial Degree:24 Epoch:40000 Datanum:100001.</figcaption>
</figure>

Therefore, I tried a different approach to approximating the Runge function.