# Home assignment 3

In this assignment, you will practice implementing objective functions so that JAX can derive the gradient function automatically, without additional effort on your side.
We will start with simple tasks and finish with a more or less realistic scenario.


## Task 1 (5 pts)

The classical regression task aims to find optimal parameters $w^*$ that minimize the mean squared error loss function

 $$ MSE = \frac{1}{m} \sum_{i=1}^m (y_i - f(w|x_i))^2, $$

 where $f(w|x_i)$ is a parametric function that maps the input vector $x_i$ to a scalar $\widehat{y}_i \approx y_i$.
 Thus, we want to find parameters $w^*$ such that the approximation of the ground-truth labels $y_i$ is as accurate as possible.

 Consider a particular instance of this problem, where

 $$ f(w|x) = \cos(w_1x_1 + w_2) + \exp(-w_5 x_2)\sin(w_3x_2 + w_4)$$

 Write Python code that
 - (2 pts) implements two functions: $f$ and MSE
 - (1 pt) supports autograd from JAX
 - (2 pts) produces a gradient function that is fast and correct, i.e. the resulting gradient w.r.t. $w$ equals the analytical gradient

 Demonstrate this on some demo input $(x_i, y_i)$ and random $w$.


In [144]:
# Your solution is here
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
import numpy as np

In [145]:
# Implements the parametric function f
def f(w, x):
    return jnp.cos(w[0] * x[0] + w[1]) + jnp.exp(-w[4] * x[1]) * jnp.sin(w[2] * x[1] + w[3])

# Implements the Mean Squared Error function
def mse(w, x, y):
    return jnp.mean((y - f(w, x))**2)


In [146]:
# Generate demo input data (x_i, y_i)
X = jnp.array(np.random.randn(100, 2))  # 100 samples with 2 features
y = jnp.array(np.random.randn(100, 1))   # 100 target values
w = jnp.array([0.5, 0.5, 0.5, 0.5, 0.5])  # Initial weights
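
A quick shape check can be useful at this point. The sketch below is an illustration, not part of the original solution: it evaluates `f` as defined above on a full data matrix, where `x[0]` and `x[1]` select the first two rows rather than the two feature columns, and contrasts this with a per-sample mapping via `vmap`. The name `f_batch`, the seed, and the use of `jax.random` for the demo data are assumptions made here.

```python
import jax.numpy as jnp
from jax import random, vmap

def f(w, x):
    # Same definition as above: x is expected to be a single sample of shape (2,)
    return jnp.cos(w[0] * x[0] + w[1]) + jnp.exp(-w[4] * x[1]) * jnp.sin(w[2] * x[1] + w[3])

key_x, key_y = random.split(random.PRNGKey(0))
X = random.normal(key_x, (100, 2))   # 100 samples with 2 features
y = random.normal(key_y, (100, 1))   # targets stored as a column, as in the cell above
w = jnp.full(5, 0.5)

# Calling f on the whole matrix uses rows 0 and 1 of X as "x_1" and "x_2"
naive = f(w, X)
print(naive.shape)          # (2,)
print((y - naive).shape)    # (100, 2): broadcasting, not one residual per sample

# Hypothetical helper: map f over the sample axis, keeping w fixed
f_batch = vmap(f, in_axes=(None, 0))
preds = f_batch(w, X)
print(preds.shape)                # (100,)
print((y[:, 0] - preds).shape)    # (100,): one residual per sample
```

With one prediction per sample and labels of shape `(100,)`, the mean in `mse` averages exactly the $m$ squared residuals from the task statement.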


The function $f(w, x)$ is defined as:
$$
f(w, x) = \cos(w[0] \cdot x[0] + w[1]) + e^{-w[4] \cdot x[1]} \cdot \sin(w[2] \cdot x[1] + w[3])
$$

Let’s break down the partial derivatives of $f$ with respect to each weight:

1. **Partial derivative with respect to $w[0]$**:
   $$
   \frac{\partial f}{\partial w[0]} = -\sin(w[0] \cdot x[0] + w[1]) \cdot x[0]
   $$

2. **Partial derivative with respect to $w[1]$**:
   $$
   \frac{\partial f}{\partial w[1]} = -\sin(w[0] \cdot x[0] + w[1])
   $$

3. **Partial derivative with respect to $w[2]$**:
   $$
   \frac{\partial f}{\partial w[2]} = e^{-w[4] \cdot x[1]} \cdot x[1] \cdot \cos(w[2] \cdot x[1] + w[3])
   $$

4. **Partial derivative with respect to $w[3]$**:
   $$
   \frac{\partial f}{\partial w[3]} = e^{-w[4] \cdot x[1]} \cdot \cos(w[2] \cdot x[1] + w[3])
   $$

5. **Partial derivative with respect to $w[4]$**:
   $$
   \frac{\partial f}{\partial w[4]} = -e^{-w[4] \cdot x[1]} \cdot x[1] \cdot \sin(w[2] \cdot x[1] + w[3])
   $$
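
The expressions above are derivatives of $f$ for a single sample. To compare with the gradient of the MSE, one more step is needed. A short derivation via the chain rule, using the MSE definition from the task statement and writing $x_i$ for the $i$-th sample, gives

$$
\frac{\partial\, MSE}{\partial w[k]} = \frac{1}{m} \sum_{i=1}^m \frac{\partial}{\partial w[k]} \bigl(y_i - f(w, x_i)\bigr)^2 = -\frac{2}{m} \sum_{i=1}^m \bigl(y_i - f(w, x_i)\bigr) \frac{\partial f(w, x_i)}{\partial w[k]}.
$$

So each partial derivative of $f$ has to be weighted by the residual $y_i - f(w, x_i)$ and the factor $-2$ before averaging. The plain average of $\partial f / \partial w[k]$ over the samples is a different quantity.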


In [147]:
# Mean over the samples of the partial derivatives of f (note: this is not yet the gradient of the MSE)
def analytical_gradient(w, x):
    df_dw0 = -jnp.sin(w[0] * x[:, 0] + w[1]) * x[:, 0]
    df_dw1 = -jnp.sin(w[0] * x[:, 0] + w[1])
    df_dw2 = jnp.exp(-w[4] * x[:, 1]) * x[:, 1] * jnp.cos(w[2] * x[:, 1] + w[3])
    df_dw3 = jnp.exp(-w[4] * x[:, 1]) * jnp.cos(w[2] * x[:, 1] + w[3])
    df_dw4 = -jnp.exp(-w[4] * x[:, 1]) * x[:, 1] * jnp.sin(w[2] * x[:, 1] + w[3])
    
    return jnp.array([jnp.mean(df_dw0), jnp.mean(df_dw1), jnp.mean(df_dw2), jnp.mean(df_dw3), jnp.mean(df_dw4)])

In [148]:
# Compute Mean Squared Error
mse_val = mse(w, X, y)
print(f'MSE = {mse_val}')

# Compute the gradient of the MSE with respect to w
mse_grad = grad(mse)(w, X, y)
print(f'MSE Gradient (automatic) = {mse_grad}')
print()
# Compute analytical gradient
grad_f = analytical_gradient(w, X)
print(f'Analytical Gradient = {grad_f}')


MSE = 2.4340803623199463
MSE Gradient (automatic) = [-0.14500625 -0.6360937  -0.71605384  1.7825898  -0.76763594]

Analytical Gradient = [-0.4106308  -0.36518627 -0.37647083  0.89268416 -0.22020797]


In [149]:
# Check if both gradients are close
comparison = jnp.allclose(mse_grad, grad_f)
print(f'Gradients are close: {comparison}')

Gradients are close: False
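
The mismatch reported above has two likely causes: `f(w, X)` indexes rows of `X` instead of feature columns, and `analytical_gradient` averages the partial derivatives of `f` without the residual factor that the MSE introduces. One possible corrected check is sketched below. It assumes labels of shape `(m,)` and weights drawn uniformly from $[-1, 1]$. The names `f_batch`, `mse_batch`, and `mse_grad_analytical` are hypothetical.

```python
import jax.numpy as jnp
from jax import grad, jit, random

def f_batch(w, X):
    # Vectorized f: X has shape (m, 2), its columns are the features x_1 and x_2
    x1, x2 = X[:, 0], X[:, 1]
    return jnp.cos(w[0] * x1 + w[1]) + jnp.exp(-w[4] * x2) * jnp.sin(w[2] * x2 + w[3])

def mse_batch(w, X, y):
    return jnp.mean((y - f_batch(w, X)) ** 2)

def mse_grad_analytical(w, X, y):
    x1, x2 = X[:, 0], X[:, 1]
    decay = jnp.exp(-w[4] * x2)
    s1 = jnp.sin(w[0] * x1 + w[1])
    s2 = jnp.sin(w[2] * x2 + w[3])
    c2 = jnp.cos(w[2] * x2 + w[3])
    # Partial derivatives of f w.r.t. w[0..4] for every sample, shape (m, 5)
    df_dw = jnp.stack([-s1 * x1, -s1, decay * x2 * c2, decay * c2, -decay * x2 * s2], axis=1)
    residual = y - f_batch(w, X)
    # Chain rule: dMSE/dw = -(2/m) * sum_i residual_i * df_i/dw
    return -2.0 * jnp.mean(residual[:, None] * df_dw, axis=0)

key_x, key_y, key_w = random.split(random.PRNGKey(0), 3)
X = random.normal(key_x, (100, 2))
y = random.normal(key_y, (100,))
w = random.uniform(key_w, (5,), minval=-1.0, maxval=1.0)

auto_grad = jit(grad(mse_batch))(w, X, y)   # jit compiles the gradient function once
manual_grad = mse_grad_analytical(w, X, y)
print(auto_grad)
print(manual_grad)
print(jnp.allclose(auto_grad, manual_grad, rtol=1e-4, atol=1e-4))
```

The tolerances are loosened relative to the `jnp.allclose` defaults because both gradients are computed in `float32`.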


## Task 2 (8 pts)

The classical multiclass classification problem (with $C$ classes) aims to minimize the so-called cross-entropy loss function

$$xEntropy = -\frac{1}{N} \sum_{i=1}^N \log \left( \frac{\exp(\widehat{y}_{iy_i})}{\sum_{j=1}^C \exp(\widehat{y}_{ij})}\right),$$

where $\widehat{y}_i$ is a vector of length $C$ that stores unnormalized scores for the possible classes of the vector $x_i$.
For example, if $C = 3$, the ground-truth class label is $y_i = 0$ and $\widehat{y}_i = [1, 10, -3]$, then $x_i$ more likely belongs to class 1 than to classes 0 and 2, although $\widehat{y}_{iy_i} = 1$.
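
To make this example concrete, the snippet below evaluates the loss for that single sample. It is an illustrative sketch that relies on `jax.nn.log_softmax`. The variable names are chosen here and are not prescribed by the task.

```python
import jax
import jax.numpy as jnp

logits = jnp.array([[1.0, 10.0, -3.0]])  # one sample, C = 3
label = jnp.array([0])                   # ground-truth class

# log_softmax computes log(exp(z_j) / sum_k exp(z_k)) in a stable way
log_probs = jax.nn.log_softmax(logits, axis=1)
loss = -jnp.mean(jnp.take_along_axis(log_probs, label[:, None], axis=1))

print(jnp.argmax(logits, axis=1))  # predicted class is 1
print(loss)                        # about 9.0: the model is confidently wrong
```

The loss is close to $10 - 1 = 9$ because the log-sum-exp of the scores is dominated by the largest score, 10, while the score of the true class is 1.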

So, we are given the data matrix $X \in \mathbb{R}^{N \times n}$ and the corresponding ground-truth class labels $y$ of shape $N$, where $y_i \in \{ 0, \ldots, C-1 \}$.
To estimate $\widehat{y}_i$, one can construct an arbitrary function that depends on a set of parameters.
In this task, let us approximate it with a simple linear model that transforms the input vector $x_i$ of shape $n$ as $\widehat{y}_i = Wx_i + b$.
So, the parameters of the model are $(W, b)$.

- (3 pts) Implement the cross-entropy function in a numerically stable manner
- (2 pts) Implement a function (also called a model) that generates the samples' embeddings $\widehat{y}_i$ via a linear model. Note that the linear model $Wx + b$ is the simplest option, so you can also compose this function from linear and non-linear functions. Keep in mind that it depends on the data matrix and on the parameters $(W, b)$ of every linear function (and on further parameters of the nonlinear functions, if any)
- (2 pts) Use JAX to generate the function that computes the gradient of the cross-entropy function w.r.t. the parameters of the model, and check that it works and is sufficiently fast.
- (1 pt) Check numerically that your function works correctly. A numerical check means that your gradient function gives a result similar to a naive finite-difference approximation of the gradient.
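
The second and third items can be combined as in the following sketch: a model with one hidden layer, parameters stored in a dictionary, and a jitted gradient function. The layer sizes, the `tanh` nonlinearity, the synthetic data, and the names `init_params`, `model`, and `loss_fn` are assumptions made for illustration, not requirements of the task.

```python
import jax
import jax.numpy as jnp
from jax import random

def init_params(key, n, hidden, C):
    # Small random weights and zero biases for two linear layers
    k1, k2 = random.split(key)
    return {
        "W1": 0.1 * random.normal(k1, (n, hidden)),
        "b1": jnp.zeros(hidden),
        "W2": 0.1 * random.normal(k2, (hidden, C)),
        "b2": jnp.zeros(C),
    }

def model(params, X):
    # linear -> tanh -> linear; returns unnormalized class scores of shape (N, C)
    h = jnp.tanh(X @ params["W1"] + params["b1"])
    return h @ params["W2"] + params["b2"]

def loss_fn(params, X, y):
    log_probs = jax.nn.log_softmax(model(params, X), axis=1)
    return -jnp.mean(jnp.take_along_axis(log_probs, y[:, None], axis=1))

# jax.grad differentiates w.r.t. the first argument, here the whole dictionary
grad_fn = jax.jit(jax.grad(loss_fn))

key_p, key_x, key_y = random.split(random.PRNGKey(0), 3)
N, n, hidden, C = 256, 20, 32, 5
params = init_params(key_p, n, hidden, C)
X = random.normal(key_x, (N, n))
y = random.randint(key_y, (N,), 0, C)

grads = grad_fn(params, X, y)
print({name: g.shape for name, g in grads.items()})
```

Because the parameters form a pytree, the gradient has the same dictionary structure, so adding a layer does not require extra `argnums` bookkeeping.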



In [150]:
# Cross-entropy loss function
def cross_entropy(y_true, y_pred):
    N = y_true.shape[0]
    max_logits = jnp.max(y_pred, axis=1, keepdims=True)
    log_probs = y_pred - max_logits - jnp.log(jnp.sum(jnp.exp(y_pred - max_logits), axis=1, keepdims=True))
    true_log_probs = jnp.take_along_axis(log_probs, y_true[:, None], axis=1).flatten()
    xEntropy = -jnp.mean(true_log_probs)
    return xEntropy

# Linear model function
def linear_model(X, W, b):
    return jnp.dot(X, W) + b

# Loss of the linear model as a function of (W, b); it is differentiated below with jax.grad
def loss_and_grads(W, b, X, y):
    y_pred = linear_model(X, W, b)
    loss = cross_entropy(y, y_pred)
    return loss

# Gradient functions
grad_W = grad(loss_and_grads, argnums=0)
grad_b = grad(loss_and_grads, argnums=1)

# Numerical gradient function
def numerical_gradient(func, params, epsilon=1e-6):
    grads = jnp.zeros_like(params)
    for i in range(params.size):
        params_plus = params.at[i].set(params[i] + epsilon)
        params_minus = params.at[i].set(params[i] - epsilon)
        
        # Compute the finite difference
        grad = (func(params_plus) - func(params_minus)) / (2 * epsilon)
        grads = grads.at[i].set(grad)
    return grads

# Example usage
W = jnp.array([[0.1, 0.2], [0.3, 0.4]])  # Example weight matrix
b = jnp.array([0.1, 0.2])  # Example bias vector
X = jnp.array([[1.0, 2.0], [3.0, 4.0]])  # Example input data
y = jnp.array([0, 1])  # Example ground truth labels

# Compute analytical gradients
analytical_grad_W = grad_W(W, b, X, y)
analytical_grad_b = grad_b(W, b, X, y)

# Function to compute loss with both W and b
def loss_with_params(params):
    W, b = params
    return loss_and_grads(W, b, X, y)

# Compute numerical gradients
numerical_grad_W = numerical_gradient(lambda W: loss_with_params((W, b)), W)
numerical_grad_b = numerical_gradient(lambda b: loss_with_params((W, b)), b)

# Check if they are close
print("Analytical Gradient W:\n", analytical_grad_W)
print("Numerical Gradient W:\n", numerical_grad_W)
print("Analytical Gradient b:\n", analytical_grad_b)
print("Numerical Gradient b:\n", numerical_grad_b)

Analytical Gradient W:
 [[ 0.1656944  -0.16569445]
 [ 0.02136332 -0.02136338]]
Numerical Gradient W:
 [[-0.02980232 -0.02980232]
 [ 0.          0.        ]]
Analytical Gradient b:
 [-0.14433108  0.14433107]
Numerical Gradient b:
 [-0.08940697  0.17881393]
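
Two details of `numerical_gradient` likely explain the disagreement above. First, `params.at[i]` indexes the first axis, so for the 2×2 matrix `W` it perturbs a whole row at once, while `params.size` is 4. Second, `epsilon=1e-6` is too small for JAX's default `float32`: the reported numerical values are integer multiples of about 0.0298, which is consistent with the `float32` spacing $2^{-24} \approx 5.96 \cdot 10^{-8}$ of the loss value divided by $2\epsilon$. The sketch below is one possible fix under these assumptions. It perturbs one flattened entry at a time and uses a larger step. The name `finite_difference_grad` and the step `1e-3` are choices made here, not part of the original solution.

```python
import jax
import jax.numpy as jnp

def cross_entropy(y_true, logits):
    log_probs = jax.nn.log_softmax(logits, axis=1)
    return -jnp.mean(jnp.take_along_axis(log_probs, y_true[:, None], axis=1))

def loss_fn(W, b, X, y):
    return cross_entropy(y, X @ W + b)

def finite_difference_grad(func, params, eps=1e-3):
    # Central differences, one entry of the flattened parameter array at a time
    flat = params.ravel()
    grads = []
    for i in range(flat.size):
        step = jnp.zeros_like(flat).at[i].set(eps)
        plus = (flat + step).reshape(params.shape)
        minus = (flat - step).reshape(params.shape)
        grads.append((func(plus) - func(minus)) / (2 * eps))
    return jnp.array(grads).reshape(params.shape)

# Same example values as in the cell above
W = jnp.array([[0.1, 0.2], [0.3, 0.4]])
b = jnp.array([0.1, 0.2])
X = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([0, 1])

auto_W, auto_b = jax.grad(loss_fn, argnums=(0, 1))(W, b, X, y)
fd_W = finite_difference_grad(lambda W_: loss_fn(W_, b, X, y), W)
fd_b = finite_difference_grad(lambda b_: loss_fn(W, b_, X, y), b)

print(jnp.allclose(auto_W, fd_W, atol=5e-3))
print(jnp.allclose(auto_b, fd_b, atol=5e-3))
```

The tolerance is deliberately loose: in `float32` a finite-difference estimate is only accurate to a few digits, so the check confirms agreement in sign and magnitude rather than bitwise equality.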


In [151]:
y = jnp.array([0, 1])
y

Array([0, 1], dtype=int32)

## Task 3 (6 pts)

The classical recommender system model is based on the regularized matrix factorization task, which aims to approximate the given matrix of user ratings of items by the product of two smaller matrices $U, V$ as follows

$$ L(U, V) = \|X - UV \|_F^2 + \frac{\lambda}{2} \|U\|_F^2 + \frac{\lambda}{2} \|V\|_F^2, $$

where $X \in \mathbb{R}^{m \times n}$, $U \in \mathbb{R}^{m \times k}$ and $V \in \mathbb{R}^{k \times n}$, and $k$ is much smaller than $m$ and $n$.
For example, if $10^5$ users rate $10^6$ items, then $m = 10^5$ and $n = 10^6$.
In addition, $\lambda > 0$ is a given constant.

- (2 pts) Implement the function $L$ in JAX for a given matrix $X$
- (1 pt) Verify that JAX can compute the gradient w.r.t. $U$ and $V$. Use $m = 1000, n = 100, k = 10$ to generate synthetic matrices, and $\lambda = 1$.
- (3 pts) Compare the runtime of autograd and of the analytical expressions for the gradient over a range of $m$, $n$ and $k$. Which approach (analytical or autograd) is faster?

In [152]:
def loss(U, V, X, lambda_):
    frobenius_norm = jnp.linalg.norm(X - jnp.dot(U, V), ord='fro')**2
    regularization = (lambda_ / 2) * (jnp.linalg.norm(U, ord='fro')**2 + jnp.linalg.norm(V, ord='fro')**2)
    return frobenius_norm + regularization

In [153]:
# Parameters
m, n, k = 1000, 100, 10
lambda_ = 1.0

# Generate synthetic data
key = random.PRNGKey(0)
X = random.normal(key, (m, n))
U = random.normal(key, (m, k))
V = random.normal(key, (k, n))

# Compute loss
loss_value = loss(U, V, X, lambda_)

# Compute gradients
grad_U = grad(loss, argnums=0)(U, V, X, lambda_)
grad_V = grad(loss, argnums=1)(U, V, X, lambda_)

# Output results
print("Loss Value:", loss_value)
print("Gradient with respect to U shape:", grad_U.shape)
print("Gradient with respect to V shape:", grad_V.shape)

Loss Value: 1079256.0
Gradient with respect to U shape: (1000, 10)
Gradient with respect to V shape: (10, 100)
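
Checking only the shapes does not confirm the values. A sketch of a value check against the closed-form gradients $\nabla_U L = -2(X - UV)V^\top + \lambda U$ and $\nabla_V L = -2U^\top(X - UV) + \lambda V$ follows. It differs from the cell above in two ways that are choices made here: the loss is written with `jnp.sum` of squares instead of `jnp.linalg.norm(...)**2`, and `X`, `U`, `V` are drawn from separate keys via `random.split` rather than from one reused key. The names `mf_loss` and `mf_grads_closed_form` are hypothetical.

```python
import jax.numpy as jnp
from jax import grad, random

def mf_loss(U, V, X, lam):
    # ||X - UV||_F^2 + lam/2 * (||U||_F^2 + ||V||_F^2), written as sums of squares
    residual = X - U @ V
    return jnp.sum(residual ** 2) + 0.5 * lam * (jnp.sum(U ** 2) + jnp.sum(V ** 2))

def mf_grads_closed_form(U, V, X, lam):
    residual = X - U @ V
    grad_U = -2.0 * residual @ V.T + lam * U
    grad_V = -2.0 * U.T @ residual + lam * V
    return grad_U, grad_V

m, n, k, lam = 1000, 100, 10, 1.0
key_x, key_u, key_v = random.split(random.PRNGKey(0), 3)
X = random.normal(key_x, (m, n))
U = random.normal(key_u, (m, k))
V = random.normal(key_v, (k, n))

# argnums=(0, 1) returns both gradients from a single backward pass
auto_U, auto_V = grad(mf_loss, argnums=(0, 1))(U, V, X, lam)
ref_U, ref_V = mf_grads_closed_form(U, V, X, lam)

print(jnp.allclose(auto_U, ref_U, rtol=1e-3, atol=1e-2))
print(jnp.allclose(auto_V, ref_V, rtol=1e-3, atol=1e-2))
```

Writing the squared Frobenius norm as a sum of squares avoids the square root inside `jnp.linalg.norm`, whose derivative is undefined at zero.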


In [154]:
import time

def analytical_grad(U, V, X, lambda_):
    """
    Compute the analytical gradients with respect to U and V.
    
    Args:
        U: User feature matrix of shape (m, k).
        V: Item feature matrix of shape (k, n).
        X: Rating matrix of shape (m, n).
        lambda_: Regularization parameter.
        
    Returns:
        grad_U, grad_V: The computed gradients.
    """
    UV = jnp.dot(U, V)
    error = X - UV
    grad_U = -2 * jnp.dot(error, V.T) + lambda_ * U
    grad_V = -2 * jnp.dot(U.T, error) + lambda_ * V
    return grad_U, grad_V

# Function to compare runtimes
def compare_runtimes(m_values, n_values, k_values):
    for m, n, k in zip(m_values, n_values, k_values):
        # Generate synthetic data
        key = random.PRNGKey(0)
        X = random.normal(key, (m, n))
        U = random.normal(key, (m, k))
        V = random.normal(key, (k, n))

        # Measure autograd time
        start_time = time.time()
        _ = grad(loss, argnums=0)(U, V, X, lambda_)
        _ = grad(loss, argnums=1)(U, V, X, lambda_)
        autograd_time = time.time() - start_time

        # Measure analytical time
        start_time = time.time()
        analytical_grad(U, V, X, lambda_)
        analytical_time = time.time() - start_time

        print(f"m: {m}, n: {n}, k: {k} | Autograd Time: {autograd_time:.6f}s | Analytical Time: {analytical_time:.6f}s")

# Example values for m, n, k
m_values = [100, 500, 1000]
n_values = [10, 50, 100]
k_values = [5, 10, 20]

compare_runtimes(m_values, n_values, k_values)

m: 100, n: 10, k: 5 | Autograd Time: 0.047501s | Analytical Time: 0.000854s
m: 500, n: 50, k: 10 | Autograd Time: 0.036570s | Analytical Time: 0.005101s
m: 1000, n: 100, k: 20 | Autograd Time: 0.045977s | Analytical Time: 0.006143s
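
These timings should be read with caution. Each un-jitted `grad(loss, ...)` call traces the function again, and JAX dispatches work asynchronously, so `time.time()` without `block_until_ready()` can stop the clock before the computation has finished. A fairer comparison is sketched below, under the assumption that both variants are jitted, warmed up once, and synchronized on every call. The helper `time_fn`, the repeat count, and the function names are illustrative, and the measured numbers depend on hardware.

```python
import time
import jax
import jax.numpy as jnp
from jax import random

def mf_loss(U, V, X, lam):
    residual = X - U @ V
    return jnp.sum(residual ** 2) + 0.5 * lam * (jnp.sum(U ** 2) + jnp.sum(V ** 2))

def mf_grads_closed_form(U, V, X, lam):
    residual = X - U @ V
    return -2.0 * residual @ V.T + lam * U, -2.0 * U.T @ residual + lam * V

autograd_fn = jax.jit(jax.grad(mf_loss, argnums=(0, 1)))
closed_form_fn = jax.jit(mf_grads_closed_form)

def time_fn(fn, *args, repeats=20):
    jax.block_until_ready(fn(*args))      # warm-up call triggers compilation
    start = time.perf_counter()
    for _ in range(repeats):
        jax.block_until_ready(fn(*args))  # wait for the result before continuing
    return (time.perf_counter() - start) / repeats

lam = 1.0
timings = []
for m, n, k in [(100, 10, 5), (500, 50, 10), (1000, 100, 20)]:
    key_x, key_u, key_v = random.split(random.PRNGKey(0), 3)
    X = random.normal(key_x, (m, n))
    U = random.normal(key_u, (m, k))
    V = random.normal(key_v, (k, n))
    t_auto = time_fn(autograd_fn, U, V, X, lam)
    t_ref = time_fn(closed_form_fn, U, V, X, lam)
    timings.append((t_auto, t_ref))
    print(f"m: {m}, n: {n}, k: {k} | autograd: {t_auto:.6f}s | closed form: {t_ref:.6f}s")
```

Excluding the warm-up call separates the one-time compilation cost from the per-call cost, which is the quantity the task asks to compare.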
