# Written Assignment
# Programming Assignment
> 1. Use the same code from Assignment 2 to calculate the error in approximating the derivative of the given function.

Since the previous program took too long to run, I tried to improve the algorithm.
The basic logic is exactly the same, but the computation is more efficient.

In [None]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import time # Import time for performance comparison

# The Runge Function
def f(x):
    """Evaluate the Runge function 1 / (1 + 25 x^2) at ``x``.

    Only arithmetic operators are applied to ``x``, so the same code is
    used for plain scalars and for the arrays passed in elsewhere in this
    file.
    """
    denominator = 1 + 25 * x**2
    return 1 / denominator

# Optimiser and data-set settings
learning_rate = 0.01
epochs = 20000
datanum = 1000
batch_size = 32

# Training samples: `datanum` evenly spaced inputs on [-1, 1] and the
# Runge-function value at each of them.
x_train = jnp.linspace(-1.0, 1.0, datanum, dtype=jnp.float32)
y_train = f(x_train)

# Split the seed key into a carried-over key plus one subkey per parameter
# tensor, so every tensor is drawn from its own random stream.
key = jax.random.PRNGKey(0)
key, w1_key, b1_key, w2_key, b2_key, w3_key, b3_key = jax.random.split(key, 7)

# Network parameters for a 1 -> 16 -> 16 -> 1 architecture, each tensor
# sampled from a standard normal distribution.
params = {
    'w1': jax.random.normal(w1_key, (1, 16)),
    'b1': jax.random.normal(b1_key, (16,)),
    'w2': jax.random.normal(w2_key, (16, 16)),
    'b2': jax.random.normal(b2_key, (16,)),
    'w3': jax.random.normal(w3_key, (16, 1)),
    'b3': jax.random.normal(b3_key, (1,)),
}

# Hypothesis function: a network with three weight layers
# (two sigmoid hidden layers followed by a linear output layer).
# The sigmoid function is used as the activation function.
def deep_model(params, x):
    """Forward pass of the fully connected sigmoid network.

    Args:
        params: Mapping holding the weights 'w1', 'w2', 'w3' and the
            biases 'b1', 'b2', 'b3'.
        x: Input values; reshaped into a single column before use.

    Returns:
        The network output with one row per input value.
    """
    column = x.reshape(-1, 1)

    first_pre_activation = column @ params['w1'] + params['b1']
    first_hidden = jax.nn.sigmoid(first_pre_activation)

    second_pre_activation = first_hidden @ params['w2'] + params['b2']
    second_hidden = jax.nn.sigmoid(second_pre_activation)

    # The output layer is linear: no activation follows the last affine map.
    return second_hidden @ params['w3'] + params['b3']

# Loss Function: 1/N*\|y_pred-y\|^2
def loss_fn(params, x, y):
    """Mean squared error between the network output and the targets.

    Args:
        params: Model parameters forwarded to ``deep_model``.
        x: Input values.
        y: Target values compared against the squeezed network output.

    Returns:
        The mean of the squared residuals.
    """
    residual = deep_model(params, x).squeeze() - y
    return jnp.mean(residual**2)

# --- Create a JIT-compiled function for the ENTIRE epoch ---
@jax.jit
def train_epoch(params, train_data, permutation):
    """Run one epoch of mini-batch gradient descent and return the new parameters.

    Args:
        params: Pytree of model parameters to update.
        train_data: Pair ``(inputs, targets)`` of training arrays.
        permutation: Index array giving the order in which samples are visited.

    Returns:
        The parameters after every full mini-batch has been applied once.
    """
    inputs, targets = train_data
    # Integer division: only full batches are processed, so the trailing
    # entries of the permutation that do not fill a batch are skipped.
    num_batches = len(inputs) // batch_size
    loss_gradient = jax.grad(loss_fn)

    def descend(weight, gradient):
        # Plain gradient-descent update for a single parameter leaf.
        return weight - learning_rate * gradient

    def sgd_step(batch_number, weights):
        offset = batch_number * batch_size
        # The offset is a traced loop value, hence the dynamic slice.
        chosen = jax.lax.dynamic_slice_in_dim(permutation, offset, batch_size)
        gradients = loss_gradient(weights, inputs[chosen], targets[chosen])
        return jax.tree_util.tree_map(descend, weights, gradients)

    return jax.lax.fori_loop(0, num_batches, sgd_step, params)

# --- Modified Training Loop ---
# Loss on the full training set, recorded every 1000 epochs.
loss_history = []
key, shuffle_key = jax.random.split(key)
num_train = len(x_train)
start_time = time.time()

for epoch in range(epochs):
    # Draw a fresh visiting order for the samples in every epoch.
    shuffle_key, perm_key = jax.random.split(shuffle_key)
    perm = jax.random.permutation(perm_key, num_train)

    params = train_epoch(params, (x_train, y_train), perm)

    if epoch % 1000 == 0:
        # float() copies the scalar to the host, so the history is a list
        # of plain Python numbers rather than device arrays.
        loss = float(loss_fn(params, x_train, y_train))
        loss_history.append(loss)
        print(f"Epoch {epoch}, Loss: {loss:.6f}")

# JAX dispatches computations asynchronously: without this wait the timer
# could stop while the epochs after the last logged loss are still running.
params = jax.block_until_ready(params)
end_time = time.time()
print(f"\nTraining finished in {end_time - start_time:.2f} seconds.")

# --- Evaluation and Plotting ---
# Dense grid used for drawing the curves.
x_plot = jnp.linspace(-1, 1, 500, dtype=jnp.float32)
y_true = f(x_plot)
y_pred = deep_model(params, x_plot).squeeze()

# The error metrics are measured on the training points.
y_pred_train = deep_model(params, x_train).squeeze()
final_mse = loss_fn(params, x_train, y_train)
max_error = jnp.max(jnp.abs(y_pred_train - y_train))

print(f"\n--- Final Result ---")
print(f"Final MSE: {final_mse:.6f}")
print(f"Final Max Error: {max_error:.6f}")

plt.figure(figsize=(12, 5))

# Left panel: true function against the network prediction.
plt.subplot(1, 2, 1)
plt.plot(x_plot, y_true, label='True Runge Function')
plt.plot(x_plot, y_pred, label='Deep Sigmoid NN', linestyle='--')
plt.title('Deep Sigmoid Network Approximation')
plt.legend()
plt.grid(True)

# Right panel: recorded training loss on a logarithmic axis.
plt.subplot(1, 2, 2)
plt.plot(range(0, epochs, 1000), loss_history)
plt.title('Training Loss Curve')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.yscale('log')
plt.grid(True)

plt.tight_layout()
plt.show()

# Derivative calculation
def model_for_grad(x, model_params):
    """Return the squeezed network output for input ``x``.

    The input comes first in the signature so that differentiating with
    respect to argument 0 gives the derivative with respect to ``x``.
    """
    prediction = deep_model(model_params, x)
    return prediction.squeeze()

# Exact derivative of the Runge function via automatic differentiation;
# vmap applies the scalar derivative to every point of an array.
grad_f = jax.grad(f)
y_d_true = jax.vmap(grad_f)(x_plot)

# Derivative of the network output with respect to its input (argument 0).
# in_axes=(0, None) maps over the inputs while sharing the parameters.
model_derivative_fn = jax.grad(model_for_grad, argnums=0)
vectorized_model_derivative = jax.vmap(model_derivative_fn, in_axes=(0, None))
y_d_pred = vectorized_model_derivative(x_plot, params)

pred_derivative_on_train = vectorized_model_derivative(x_train, params)
true_derivative_on_train = jax.vmap(grad_f)(x_train)
derivative_error_on_train = pred_derivative_on_train - true_derivative_on_train

# Both derivative metrics are computed from the same training-point
# residual, so the reported MSE and max error describe the same sample
# (the plotting grid is used for drawing only).
final_derivative_mse = jnp.mean(derivative_error_on_train**2)
max_derivative_error = jnp.max(jnp.abs(derivative_error_on_train))

print(f"\n--- Final Result (Derivatives) ---")
print(f"Final MSE (Derivative): {final_derivative_mse:.6f}")
print(f"Final Max Error (Derivative): {max_derivative_error:.6f}")

plt.figure(figsize=(8,6))
plt.plot(x_plot, y_d_true, label='True Derivative of Runge Function')
plt.plot(x_plot, y_d_pred, label='Deep Sigmoid NN Derivative', linestyle='--')
plt.title("Comparison of Derivatives")
plt.xlabel("x")
plt.ylabel("f'(x)")
plt.legend()
plt.grid(True)
plt.show()

> 2. In this assignment, you will use a neural network to approximate both the **Runge function** and its **derivative**. Your task is to train a neural network that approximates:
> a. The function $f(x)$ itself.
> b. The derivative $f'(x)$.
> You should define a **loss function** consisting of two components:
> 1). **Function loss**: the error between the predicted $f(x)$ and the true $f(x)$.
> 2). **Derivative loss**: the error between the predicted $f'(x)$ and the true $f'(x)$.
> Write a short report (1–2 pages) explaining method, results, and discussion including
> * Plot the true function and the neural network prediction together.
> * Show the training/validation loss curves.
> * Compute and report errors (MSE or max error).