# Approximate by Trigonometric Functions
I think of polynomial approximation as being analogous to a Taylor series, which led me to try an analogous approach based on a Fourier series.

In other words, I use trigonometric functions to approximate the target function.

Before I start, note that the Runge function is an even function:

$$
f(-x)=\dfrac{1}{1+25(-x)^2}=\dfrac{1}{1+25x^2}=f(x)
$$

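Thus, I use only cosine functions in my hypothesis: since $\cos(wx)=\cos(-wx)$ for every frequency $w$, any weighted sum of such terms is automatically even, just like the target.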

The main program is shown below.

In [None]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

# Use 64-bit floats by default (JAX otherwise creates float32 arrays), so the
# small losses reported below are not limited by single precision.
jax.config.update("jax_enable_x64", True)

# The Runge Function
def f(x):
    """Return the Runge function 1 / (1 + 25 x^2), evaluated elementwise."""
    denominator = 1 + 25 * x**2
    return 1 / denominator

# Hyperparameters
learning_rate = 0.01
epochs = 20000
datanum = 1001
hidden_units = 16  # number of cosine features; the shapes of w1 and w2 follow it

# Data: `datanum` evenly spaced points on [-1, 1], endpoints included
x_train = jnp.linspace(-1.0, 1.0, datanum)
y_train = f(x_train)

# Parameters drawn from a standard normal with a fixed seed (reproducible runs)
key = jax.random.PRNGKey(0)
key, w1_key, w2_key, b_key = jax.random.split(key, 4)
params = {
    'w1': jax.random.normal(w1_key, (1, hidden_units)),  # frequencies inside cos(x @ w1)
    'w2': jax.random.normal(w2_key, (hidden_units, 1)),  # amplitude of each cosine
    'b': jax.random.normal(b_key, (1,)),                 # constant offset
}

# Construct the Trigonometric Model
def trigonometric_model(params, x):
    """Evaluate b + sum_k w2[k] * cos(w1[k] * x) at every entry of `x`.

    No phase term is added inside the cosine, so the model is an even function
    of `x` for any parameter values, which matches the even Runge target.

    Args:
        params: dict with 'w1' of shape (1, H), 'w2' of shape (H, K) and a
            'b' offset that broadcasts against (N, K).
        x: array of inputs; it is flattened into a column of shape (N, 1).

    Returns:
        Array of shape (N, K).
    """
    column = x.reshape(-1, 1)
    features = jnp.cos(column @ params['w1'])
    prediction = features @ params['w2']
    return prediction + params['b']

def loss_fn(params, x, y):
    """Mean squared error of `trigonometric_model` on inputs `x` against targets `y`."""
    residuals = trigonometric_model(params, x).squeeze() - y
    return jnp.mean(residuals**2)


loss_history = []  # losses recorded periodically during training

# Gradient Descent: one full-batch step, compiled with jax.jit.
@jax.jit
def update_step(params, x, y, learning_rate):
    """Return `params` after one plain gradient-descent step on `loss_fn`.

    Every leaf of the parameter pytree is moved by `-learning_rate * grad`.
    """
    gradients = jax.grad(loss_fn)(params, x, y)

    def _descend(param, grad):
        return param - learning_rate * grad

    return jax.tree_util.tree_map(_descend, params, gradients)

# Logging interval. It also defines the x-axis of the loss curve below, so the
# recorded losses and their epoch labels cannot drift apart.
log_every = 1000

# Training: full-batch gradient descent on the whole training grid.
for epoch in range(epochs):
    params = update_step(params, x_train, y_train, learning_rate)

    # The logged loss is measured after this epoch's update.
    if epoch % log_every == 0:
        loss = loss_fn(params, x_train, y_train)
        loss_history.append(loss)
        print(f"Epoch {epoch}, Loss: {loss:.6f}")

# Evaluation: a dense grid for plotting, the training grid for the metrics.
x_plot = jnp.linspace(-1, 1, 500)
y_true = f(x_plot)
y_pred = trigonometric_model(params, x_plot).squeeze()

# Run the model on the training grid once and derive both metrics from it.
y_pred_train = trigonometric_model(params, x_train).squeeze()
final_mse = jnp.mean((y_pred_train - y_train)**2)
max_error = jnp.max(jnp.abs(y_pred_train - y_train))

# Result
print("\n--- Final Result ---")
print(f"Final MSE: {final_mse:.6f}")
print(f"Final Max Error: {max_error:.6f}")

# Plotting
plt.figure(figsize=(12, 5))

# Left panel: target vs. fitted curve, with the training points.
plt.subplot(1, 2, 1)
plt.plot(x_plot, y_true, label='True Runge Function', color='blue')
plt.plot(x_plot, y_pred, label='Neural Network Prediction', color='red', linestyle='--')
plt.scatter(x_train, y_train, s=10, color='gray', alpha=0.5, label='Training Data')
plt.title('Function Approximation')
plt.xlabel('x')
plt.ylabel('f(x)')
plt.legend()
plt.grid(True)

# Right panel: recorded losses on a log scale.
plt.subplot(1, 2, 2)
plt.plot(range(0, epochs, log_every), loss_history)
plt.title('Training Loss Curve')
plt.xlabel('Epoch')
plt.ylabel('Mean Squared Error Loss')
plt.yscale('log')
plt.grid(True)

plt.tight_layout()
plt.show()

```text
--- Final Result ---
Final MSE: 0.014785
Final Max Error: 0.311562
```
<figure id="Fourier Learning rate:0.01 Layer:1 Epoch:20000 Datanum:1001">
    <img src="Fourier_0.01_20000_1001.png" alt="Fourier 1 Layer" style="width: 70%;">
    <figcaption><b>Figure 7</b>: Fourier Learning rate:0.01 Layer:1 Epoch:20000 Datanum:1001.</figcaption>
</figure>

```text
--- Final Result ---
Final MSE: 0.014769
Final Max Error: 0.311453
```

<figure id="Fourier Learning rate:0.01 Layer:1 Epoch:20000 Datanum:10001">
    <img src="Fourier_0.01_20000_10001.png" alt="Fourier 1 Layer" style="width: 70%;">
    <figcaption><b>Figure 8</b>: Fourier Learning rate:0.01 Layer:1 Epoch:20000 Datanum:10001.</figcaption>
</figure>

I found that the fitted shape did not match the target as closely as I had expected.

In [None]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

# Use 64-bit floats by default (JAX otherwise creates float32 arrays), so the
# small losses reported below are not limited by single precision.
jax.config.update("jax_enable_x64", True)

# The Runge Function
def f(x):
    """Return the Runge function 1 / (1 + 25 x^2), evaluated elementwise."""
    denominator = 1 + 25 * x**2
    return 1 / denominator

# Hyperparameters
learning_rate = 0.01
epochs = 20000
datanum = 10001
hidden_units = 16  # width shared by every hidden layer

# Data: `datanum` evenly spaced points on [-1, 1], endpoints included
x_train = jnp.linspace(-1.0, 1.0, datanum)
y_train = f(x_train)

# Parameters drawn from a standard normal, one PRNG subkey per tensor
key = jax.random.PRNGKey(0)
keys = jax.random.split(key, 7)
params = {
    # Input frequencies inside cos(x @ w1)
    'w1': jax.random.normal(keys[0], (1, hidden_units)),

    'w2': jax.random.normal(keys[1], (hidden_units, hidden_units)),
    'b2': jax.random.normal(keys[2], (hidden_units,)),

    'w3': jax.random.normal(keys[3], (hidden_units, hidden_units)),  # Added
    'b3': jax.random.normal(keys[4], (hidden_units,)),               # Added

    # Linear read-out to a single output
    'w4': jax.random.normal(keys[5], (hidden_units, 1)),             # Added
    'b4': jax.random.normal(keys[6], (1,)),                          # Added
}

# Construct the Trigonometric Model
def deep_trigonometric_model(params, x):
    """Evaluate a stack of cosine layers followed by a linear read-out.

    The first layer is cos(x @ w1) with no bias. Because this first layer is
    even in `x`, every later layer, and hence the whole model, is even in `x`
    as well.

    Args:
        params: dict with 'w1', the hidden layers ('w2', 'b2') and
            ('w3', 'b3'), and the read-out ('w4', 'b4').
        x: array of inputs; it is flattened into a column of shape (N, 1).

    Returns:
        Array of shape (N, K), where K is the number of columns of 'w4'.
    """
    h = jnp.cos(x.reshape(-1, 1) @ params['w1'])
    for weight, bias in (('w2', 'b2'), ('w3', 'b3')):
        h = jnp.cos(h @ params[weight] + params[bias])
    return h @ params['w4'] + params['b4']

def loss_fn(params, x, y):
    """Mean squared error of `deep_trigonometric_model` on `x` against targets `y`."""
    residuals = deep_trigonometric_model(params, x).squeeze() - y
    return jnp.mean(residuals**2)


loss_history = []  # losses recorded periodically during training

# Gradient Descent: one full-batch step, compiled with jax.jit.
@jax.jit
def update_step(params, x, y, learning_rate):
    """Return `params` after one plain gradient-descent step on `loss_fn`.

    Every leaf of the parameter pytree is moved by `-learning_rate * grad`.
    """
    gradients = jax.grad(loss_fn)(params, x, y)

    def _descend(param, grad):
        return param - learning_rate * grad

    return jax.tree_util.tree_map(_descend, params, gradients)

# Logging interval. It also defines the x-axis of the loss curve below, so the
# recorded losses and their epoch labels cannot drift apart.
log_every = 1000

# Training: full-batch gradient descent on the whole training grid.
for epoch in range(epochs):
    params = update_step(params, x_train, y_train, learning_rate)

    # The logged loss is measured after this epoch's update.
    if epoch % log_every == 0:
        loss = loss_fn(params, x_train, y_train)
        loss_history.append(loss)
        print(f"Epoch {epoch}, Loss: {loss:.6f}")

# Evaluation: a dense grid for plotting, the training grid for the metrics.
x_plot = jnp.linspace(-1, 1, 500)
y_true = f(x_plot)
y_pred = deep_trigonometric_model(params, x_plot).squeeze()

# Run the model on the training grid once and derive both metrics from it.
y_pred_train = deep_trigonometric_model(params, x_train).squeeze()
final_mse = jnp.mean((y_pred_train - y_train)**2)
max_error = jnp.max(jnp.abs(y_pred_train - y_train))

# Result
print("\n--- Final Result ---")
print(f"Final MSE: {final_mse:.6f}")
print(f"Final Max Error: {max_error:.6f}")

# Plotting
plt.figure(figsize=(12, 5))

# Left panel: target vs. fitted curve, with the training points.
plt.subplot(1, 2, 1)
plt.plot(x_plot, y_true, label='True Runge Function', color='blue')
plt.plot(x_plot, y_pred, label='Neural Network Prediction', color='red', linestyle='--')
plt.scatter(x_train, y_train, s=10, color='gray', alpha=0.5, label='Training Data')
plt.title('Function Approximation')
plt.xlabel('x')
plt.ylabel('f(x)')
plt.legend()
plt.grid(True)

# Right panel: recorded losses on a log scale.
plt.subplot(1, 2, 2)
plt.plot(range(0, epochs, log_every), loss_history)
plt.title('Training Loss Curve')
plt.xlabel('Epoch')
plt.ylabel('Mean Squared Error Loss')
plt.yscale('log')
plt.grid(True)

plt.tight_layout()
plt.show()

```text
--- Final Result ---
Final MSE: 0.001284
Final Max Error: 0.091240
```

<figure id="Fourier Learning rate:0.01 Layer:2 Epoch:20000 Datanum:10001">
    <img src="Fourier_0.01_20000_10001_2.png" alt="Fourier 2 Layers" style="width: 70%;">
    <figcaption><b>Figure 9</b>: Fourier Learning rate:0.01 Layer:2 Epoch:20000 Datanum:10001.</figcaption>
</figure>

This matched my assumption. Although a strange oscillation appeared on the right-hand side, this result encouraged me to add another hidden layer.

```text
--- Final Result ---
Final MSE: 0.000074
Final Max Error: 0.022620
```

<figure id="Fourier Learning rate:0.01 Layer:3 Epoch:20000 Datanum:10001">
    <img src="Fourier_0.01_20000_10001_3.png" alt="Fourier 3 Layers" style="width: 70%;">
    <figcaption><b>Figure 10</b>: Fourier Learning rate:0.01 Layer:3 Epoch:20000 Datanum:10001.</figcaption>
</figure>

I found that the learning rate I had set was too large for this model and made the gradients "explode".

Thus, I decreased the learning rate.
```text
--- Final Result ---
Final MSE: 0.000265
Final Max Error: 0.076985
```

<figure id="Fourier Learning rate:0.001 Layer:3 Epoch:20000 Datanum:10001">
    <img src="Fourier_0.001_20000_10001_3.png" alt="Fourier 3 Layers" style="width: 70%;">
    <figcaption><b>Figure 11</b>: Fourier Learning rate:0.001 Layer:3 Epoch:20000 Datanum:10001.</figcaption>
</figure>

```text
--- Final Result ---
Final MSE: 0.003896
Final Max Error: 0.226449
```

<figure id="Fourier Learning rate:0.0001 Layer:3 Epoch:20000 Datanum:10001">
    <img src="Fourier_0.0001_20000_10001_3.png" alt="Fourier 3 Layers" style="width: 70%;">
    <figcaption><b>Figure 12</b>: Fourier Learning rate:0.0001 Layer:3 Epoch:20000 Datanum:10001.</figcaption>
</figure>

The middle part almost matches the shape of the target; however, the model oscillates severely on both sides, especially with the lowest learning rate, as shown in Figure 12.

After this, I removed one layer, leaving two hidden layers, and gave the model more training data.

```text
--- Final Result ---
Final MSE: 0.000896
Final Max Error: 0.081194
```

<figure id="Fourier Learning rate:0.01 Layer:2 Epoch:40000 Datanum:1000001">
    <img src="Fourier_0.01_40000_1000001_2.png" alt="Fourier 2 Layers Random with huge data" style="width: 70%;">
    <figcaption><b>Figure 13</b>: Fourier Learning rate:0.01 Layer:2 Epoch:40000 Datanum:1000001.</figcaption>
</figure>

```text
--- Final Result ---
Final MSE: 0.002676
Final Max Error: 0.135742
```

<figure id="Fourier Learning rate:0.001 Layer:2 Epoch:40000 Datanum:1000001">
    <img src="Fourier_0.001_40000_1000001_2.png" alt="Fourier 2 Layers Random with huge data" style="width: 70%;">
    <figcaption><b>Figure 14</b>: Fourier Learning rate:0.001 Layer:2 Epoch:40000 Datanum:1000001.</figcaption>
</figure>

I considered this a form of overfitting; however, the better performance in the central region was undeniable.

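Therefore, I adjusted the way I sample the training data: half of the points are placed on the central region $[-0.5, 0.5]$ and a quarter on each of the two tails.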
```python
# Piecewise training grid: datanum//2 points on [-0.5, 0.5] and datanum//4
# points on each tail. Relies on `datanum` from the surrounding script.
# NOTE(review): the center has width 1.0 and each tail width 0.5, so all three
# pieces have the same point density, which is almost the uniform grid. The
# boundary points -0.5 and 0.5 each appear twice. When datanum is not a
# multiple of 4, the total is smaller than datanum.
x_center = jnp.linspace(-0.5, 0.5, datanum//2)
x_left = jnp.linspace(-1.0, -0.5, datanum//4)
x_right = jnp.linspace(0.5, 1.0, datanum//4)
x_train = jnp.concatenate([x_center, x_left, x_right])
```

```text
--- Final Result ---
Final MSE: 0.001097
Final Max Error: 0.093484
```
<figure id="Fourier Learning rate:0.01 Layer:3 Epoch:20000 Datanum:10000">
    <img src="Fourier_0.01_20000_10000_3_s.png" alt="Fourier 3 Layers" style="width: 70%;">
    <figcaption><b>Figure 10'</b>: Fourier Learning rate:0.01 Layer:3 Epoch:20000 Datanum:10000.</figcaption>
</figure>

In conclusion, although the graph in Figure 10 almost matches the function we want to approximate, oscillations are still visible on both sides.

Hence, I decided to try another approach using the sigmoid function, which we learned in class.