<a href="https://colab.research.google.com/github/Pandatoey/LAB229351/blob/main/%E0%B8%AA%E0%B8%B3%E0%B9%80%E0%B8%99%E0%B8%B2%E0%B8%82%E0%B8%AD%E0%B8%87_208424Lab01.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

###Problem 4: Gradient Verification for Ridge Regression

In [1]:
import jax
import jax.numpy as jnp
from jax import grad

#1.Generate synthetic data
key = jax.random.PRNGKey(0)

X = jax.random.normal(key, (50, 5))   # X ∈ R^{50×5}
y = jax.random.normal(key, (50,))     # y ∈ R^{50}
beta = jax.random.normal(key, (5,))   # β ∈ R^{5}
lam = 10.0                            # λ = 10

#2.Ridge Regression loss
def ridge_loss(beta, X, y, lam):
    r = X @ beta - y
    return r.T @ r + lam * (beta.T @ beta)

#3.หาค่า gradient จากฟังก์ชัน jax โดยตรง
jax_grad = grad(ridge_loss)(beta, X, y, lam)

#4.หาค่า gradient จากสูตรคณิตศาสตร์ ∇β​f(β)=2(XTXβ−XTy+λβ)
def ridge_grad_analytic(beta, X, y, lam):
    return 2 * (X.T @ (X @ beta - y) + lam * beta)

#5.Analytic gradient
analytic_grad = ridge_grad_analytic(beta, X, y, lam)

#6.Compare
print("JAX gradient:", jax_grad)
print("Analytic gradient:", analytic_grad)
print("Gradients match:",
      jnp.allclose(jax_grad, analytic_grad)) #ค่า gradient ที่ได้จาก jax กับ จากสูตรคำนวณมีค่าใกล้เคียงกัน

JAX gradient: [152.26129  233.75662  -22.530704   1.699074  35.291744]
Analytic gradient: [152.26132   233.75667   -22.530706    1.6990726  35.29174  ]
Gradients match: True


###Problem 5: MLE Verification via Gradient Checking

In [2]:
import jax
import jax.numpy as jnp
from jax import grad

# 1. Generate exponential data
key = jax.random.PRNGKey(42)
lambda_true = 4.0 #ค่าพารามิเตอร์จริงคือ 4
x = jax.random.exponential(key, (100,)) / lambda_true

# 2. Analytic MLE
lambda_hat = 1.0 / jnp.mean(x)

# 3. Negative Log-Likelihood
def nll(lam, x):
    return -jnp.sum(jnp.log(lam) - lam * x)

# 4. Gradient of NLL
nll_grad = grad(nll)

# 5. Evaluate gradient at analytic MLE
grad_at_mle = nll_grad(lambda_hat, x)

print("True lambda:", lambda_true)
print("Estimated lambda_hat:", lambda_hat)
print("Gradient at lambda_hat:", grad_at_mle)

True lambda: 4.0
Estimated lambda_hat: 4.1560264
Gradient at lambda_hat: 1.9073486e-06
