# Exercises

In [1]:
import jax
import jax.numpy as jnp
from jax import grad, jit

### Exercise 1: gradient descent optimization

Consider a dataset consisting of 2 features and a given number of samples, with the
corresponding labels:

In [2]:
# Generate synthetic data
key = jax.random.PRNGKey(0)
true_w = jnp.array([2.0, -3.0])  # True weights for the synthetic data
true_b = 5.0  # True bias for the synthetic data
num_samples = 100
x = jax.random.normal(key, (num_samples, 2))  # Input features
y = jnp.dot(x, true_w) + true_b  # Targets

Find the weights (2) and bias of a linear model that best fits the dataset (_linear
regression_). 

In [3]:
# Define the model
def model(weights, bias, x):
    return jnp.dot(x, weights) + bias

# Mean squared error loss function
def mse_loss(weights, bias, x, y):
    predictions = model(weights, bias, x)
    return jnp.mean((predictions - y) ** 2)

# Gradient of the loss function
grad_mse_loss = grad(mse_loss, argnums=(0, 1))

# Training step
@jit
def train_step(weights, bias, x, y, learning_rate):
    gradients = grad_mse_loss(weights, bias, x, y)
    new_weights = weights - learning_rate * gradients[0]
    new_bias = bias - learning_rate * gradients[1]
    return new_weights, new_bias

# Initialize parameters
weights = jax.random.normal(key, (2,))
bias = 0.0
learning_rate = 0.01
num_epochs = 1000

# Training loop
for epoch in range(num_epochs):
    weights, bias = train_step(weights, bias, x, y, learning_rate)
    if epoch % 100 == 0:
        current_loss = mse_loss(weights, bias, x, y)
        print(f"Epoch {epoch}, Loss: {current_loss}")

# Print the final parameters
print("Learned weights:", weights)
print("Learned bias:", bias)

Epoch 0, Loss: 48.84288787841797
Epoch 100, Loss: 0.871484637260437
Epoch 200, Loss: 0.019410856068134308
Epoch 300, Loss: 0.0005184339242987335
Epoch 400, Loss: 1.5399524272652343e-05
Epoch 500, Loss: 4.810555083167856e-07
Epoch 600, Loss: 1.5348598836339988e-08
Epoch 700, Loss: 6.657859263903276e-10
Epoch 800, Loss: 2.2406083932668963e-10
Epoch 900, Loss: 2.2406083932668963e-10
Learned weights: [ 2.0000076 -2.9999933]
Learned bias: 4.999988


### Exercise 2: solving a non-linear system of equations.

Consider the following system of non-linear algebraic equations:

In [4]:
# Define the system of nonlinear equations
@jit
def system_of_equations(x):
    # Example equations:
    # Equation 1: x0^2 + x1 - 37 = 0
    # Equation 2: x0 - x1^2 - 5 = 0
    return jnp.array([
        x[0]**2 + x[1] - 37,
        x[0] - x[1]**2 - 5
    ])

Solve the system numerically using the [`root`](https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.root.html)
function of `scipy`, with the Jacobian computed via JAX's autodiff. 

In [5]:
from scipy.optimize import root

# Function to compute the Jacobian using JAX
@jit
def jacobian(x):
    return jax.jacfwd(system_of_equations)(x)

# Initial guess for the solution
initial_guess = jnp.array([1.0, 1.0])

# Use scipy's root function to find the solution
solution = root(system_of_equations, initial_guess, jac=jacobian)

# Print the solution
print("Solution:", solution.x)
print("Function value at solution:", system_of_equations(solution.x))
print("Jacobian at solution:", jacobian(solution.x))

Solution: [5.99999999 0.99999991]
Function value at solution: [0. 0.]
Jacobian at solution: [[12.         1.       ]
 [ 1.        -1.9999999]]
