## Linear Model

In [None]:
#| default_exp linear_model

In [None]:
#| include: false
%load_ext autoreload
%autoreload 2
from ipynb_path import *

In [None]:
#| export
from __future__ import annotations
from explainax.imports import *

In [None]:
#| export
def l2_loss(x1, x2, weights=None):
    """Half mean squared error between `x1` and `x2`, reduced to a scalar.

    This is the default `loss_fn` of `sgd_train_linear_model`, whose update
    step differentiates it with `jax.grad`; that requires a scalar output, so
    both branches below return a scalar.

    Args:
        x1: Targets (or predictions); must broadcast against `x2`.
        x2: Predictions (or targets). The loss is symmetric in `x1` and `x2`.
        weights: Optional per-sample weights, broadcast-compatible with
            `x1 - x2`. They are normalised by their L1 norm (sum of absolute
            values) before use; all-zero weights therefore yield NaN.

    Returns:
        `0.5 * mean((x1 - x2) ** 2)` when `weights` is None, otherwise
        `0.5 * sum(w_i * (x1_i - x2_i) ** 2) / sum(|w_i|)`. With uniform
        weights the two forms agree.
    """
    sq_err = jnp.square(x1 - x2)
    if weights is None:
        # Reduce to a scalar so `jax.grad` can differentiate the loss.
        return jnp.mean(sq_err) / 2.0
    # Plain L1 normalisation: `weights` is data, not a differentiated
    # parameter, so a gradient-safe norm (`optax.safe_norm`, which also
    # requires a `min_norm` argument) is unnecessary here.
    norm = jnp.sum(jnp.abs(weights))
    return jnp.sum((weights / norm) * sq_err) / 2.0

In [None]:
#| export
def _init_train_fn(
    X: jnp.ndarray, # Input data
    y: jnp.ndarray, # Target data
    fit_bias: bool = True, # Fit bias term
    seed: int = 42, # Random seed
):  
    rng = jax.random.PRNGKey(seed)
    n_samples, n_features = X.shape
    rng, w_key, b_key = jax.random.split(rng, 3)
    w = jax.random.normal(w_key, (n_features,))
    if fit_bias:
        b = jax.random.normal(b_key, (1,))
    else:
        b = jnp.zeros(1)
    params = dict(w=w, b=b)
    return params

def calculate_loss(
    params: Dict[str, jnp.ndarray],
    batch: Tuple[Array, Array, Array],
    loss_fn: Callable,
    reg_term: int = None,
    alpha: float = 1.0
):
    """Scalar training loss of a linear model on one batch.

    Args:
        params: Dict holding the coefficient vector `"w"` and bias `"b"`.
        batch: `(X, y, weights)`; `weights` may be None. `y` should be 1-D
            like the prediction `X @ w + b`; a `(batch, 1)` column would
            broadcast against it.
        loss_fn: Called as `loss_fn(y, y_pred, weights)`. It may return a
            scalar or an elementwise loss; the result is averaged either way
            so that `jax.grad` receives a scalar.
        reg_term: Norm order passed to `jnp.linalg.norm` for the penalty on
            `w` (e.g. 1 or 2); None disables regularisation.
        alpha: Multiplier applied to the penalty.

    Returns:
        `mean(loss_fn(...)) + alpha * ||w||_{reg_term}`, with the penalty
        only when `reg_term` is not None. The norm is not squared.
    """
    w, b = params["w"], params["b"]
    X, y, weights = batch
    y_pred = jnp.dot(X, w) + b
    # Reduce elementwise losses to a scalar; a no-op for scalar losses.
    loss = jnp.mean(loss_fn(y, y_pred, weights))
    if reg_term is not None:
        loss = loss + alpha * jnp.linalg.norm(w, ord=reg_term)
    return loss

def sgd_train_linear_model(
    X: jnp.ndarray, # Input data. Shape: `(N, k)`
    y: jnp.ndarray, # Target data. Shape: `(N,)` or `(N, 1)`
    weights: jnp.ndarray = None, # Per-sample loss weights. Shape: `(N,)`
    lr: float = 0.01, # Learning rate
    n_epochs: int = 100, # Number of epochs
    batch_size: int = 32, # Batch size
    seed: int = 42, # Random seed
    loss_fn: Callable = l2_loss, # Loss function
    reg_term: int = None, # Regularization term
    alpha: float = 1.0, # Regularization strength
    fit_bias: bool = True, # Fit bias term
) -> Tuple[jnp.ndarray, jnp.ndarray]: # The trained weights and bias
    """Train a linear model `y ~ X @ w + b` with mini-batch SGD.

    Batches are taken in row order (no shuffling); `seed` only affects the
    parameter initialisation. Each step minimises `calculate_loss`, i.e.
    `loss_fn(y, y_pred, weights)` plus the optional `reg_term` norm penalty
    scaled by `alpha`. When `fit_bias` is False the bias stays at zero.

    Raises:
        ValueError: if `X` is not 2-D, or `y` / `weights` do not have exactly
            one entry per row of `X`.
    """
    if jnp.ndim(X) != 2:
        raise ValueError(f"X must be 2-D, got shape {jnp.shape(X)}")
    n_samples = jnp.shape(X)[0]
    if jnp.ndim(y) == 2 and jnp.shape(y)[1] == 1:
        # Flatten a column target: the prediction `X @ w + b` is 1-D, and an
        # `(N, 1)` target would silently broadcast to an `(N, N)` residual.
        y = jnp.reshape(y, (-1,))
    if jnp.shape(y) != (n_samples,):
        raise ValueError(f"y must have shape ({n_samples},) or ({n_samples}, 1), got {jnp.shape(y)}")
    if weights is not None and jnp.shape(weights) != (n_samples,):
        raise ValueError(f"weights must have shape ({n_samples},), got {jnp.shape(weights)}")

    opt = optax.sgd(lr)

    @jax.jit
    def sgd_step(params, opt_state, batch):
        """Apply one SGD update for `batch = (X, y, weights)`."""
        grads = jax.grad(calculate_loss)(params, batch, loss_fn, reg_term, alpha)
        if not fit_bias:
            # `b` still enters the prediction, so it receives a gradient;
            # zero it so plain SGD leaves the zero-initialised bias untouched.
            grads = {**grads, "b": jnp.zeros_like(grads["b"])}
        updates, opt_state = opt.update(grads, opt_state)
        params = optax.apply_updates(params, updates)
        return params, opt_state

    params = _init_train_fn(X, y, fit_bias, seed)
    opt_state = opt.init(params)
    for _ in range(n_epochs):
        for start in range(0, n_samples, batch_size):
            stop = start + batch_size
            batch = (
                X[start:stop],
                y[start:stop],
                None if weights is None else weights[start:stop],
            )
            params, opt_state = sgd_step(params, opt_state, batch)
    return params["w"], params["b"]


In [None]:
#| export
class BaseEstimator:
    """Minimal estimator interface; subclasses override `fit`."""

    def __init__(self):
        """Create the estimator; the base class stores no state."""

    def fit(self, X, y):
        """Fit to `X` and `y`. The base implementation does nothing."""
        return None

In [None]:
#| export
class LinearModel(BaseEstimator):
    """Linear regression estimator backed by a pluggable training function.

    After `fit`, the learned coefficients are stored in `coef_` and the bias
    in `intercept_`.
    """

    def __init__(
        self,
        intercept: bool = True,
        trainer_fn: Callable = None,
        **kwargs,
    ):
        """
        Args:
            intercept: Whether to fit a bias term (forwarded as `fit_bias`).
            trainer_fn: Called as `trainer_fn(X, y, weights, fit_bias=..., **kw)`
                and expected to return `(coef, intercept)`. Defaults to
                `sgd_train_linear_model`.
            **kwargs: Default trainer options (e.g. `lr`, `n_epochs`),
                forwarded to `trainer_fn` on every `fit` call.
        """
        super().__init__()
        self.fit_bias = intercept
        self.trainer_fn = sgd_train_linear_model if trainer_fn is None else trainer_fn
        # Keep constructor options instead of silently discarding them.
        self.trainer_kwargs = dict(kwargs)

    def fit(
        self,
        X: jnp.ndarray,
        y: jnp.ndarray,
        weights: jnp.ndarray = None,
        **kwargs,
    ) -> LinearModel:
        """Fit on `X`, `y` and optional per-sample `weights`; returns `self`.

        `kwargs` are merged over the constructor's trainer options (values
        passed here take precedence) and forwarded to `trainer_fn`.
        """
        options = {**self.trainer_kwargs, **kwargs}
        self.coef_, self.intercept_ = self.trainer_fn(
            X, y, weights, fit_bias=self.fit_bias, **options)
        return self

In [None]:
#| export
class Lasso(LinearModel):
    """`LinearModel` trained with an L1 penalty (``reg_term=1``).

    `alpha` scales the penalty; all other keyword arguments are passed on to
    `LinearModel`.
    """

    def __init__(self, alpha: float = 1.0, **kwargs):
        super().__init__(**kwargs)
        # Regularisation strength handed to the trainer on every fit.
        self.alpha = alpha

    def fit(self, X: jnp.ndarray, y: jnp.ndarray, weights: jnp.ndarray = None, **kwargs) -> LinearModel:
        """Fit with ``reg_term=1`` and this model's `alpha`; returns `self`."""
        l1_penalty = {"reg_term": 1, "alpha": self.alpha}
        return super().fit(X, y, weights, **l1_penalty, **kwargs)

In [None]:
#| export
class Ridge(LinearModel):
    """`LinearModel` trained with an L2-norm penalty (``reg_term=2``).

    `alpha` scales the penalty; all other keyword arguments are passed on to
    `LinearModel`.
    NOTE(review): with the default trainer the penalty is the unsquared norm
    `alpha * ||coef||_2`, unlike the squared penalty of classic ridge.
    """

    def __init__(self, alpha: float = 1.0, **kwargs):
        super().__init__(**kwargs)
        # Regularisation strength handed to the trainer on every fit.
        self.alpha = alpha

    def fit(self, X: jnp.ndarray, y: jnp.ndarray, weights: jnp.ndarray = None, **kwargs) -> LinearModel:
        """Fit with ``reg_term=2`` and this model's `alpha`; returns `self`."""
        l2_penalty = {"reg_term": 2, "alpha": self.alpha}
        return super().fit(X, y, weights, **l2_penalty, **kwargs)

#### Test 

In [None]:
from sklearn.datasets import make_regression
from sklearn.linear_model import LinearRegression

In [None]:
# Synthetic regression problem; a fixed `random_state` makes the comparison
# against scikit-learn below reproducible from run to run.
X, y = make_regression(n_samples=500, n_features=20, random_state=42)
# Uniform per-sample weights (one per row of X) to exercise the weighted path.
w = np.ones(X.shape[0])

In [None]:
# Reference least-squares fit from scikit-learn (`fit` returns the estimator).
sk_lm = LinearRegression().fit(X, y)
sk_lm.coef_, sk_lm.intercept_

(array([ 3.08898669e-14,  3.10345763e+00, -3.40252413e-14, -1.17304952e-14,
         1.63301547e+01, -1.12741395e-14,  3.89375096e-14,  2.12155613e+01,
         1.87952565e-15, -2.45710044e-14, -5.12493483e-15,  3.53391521e+01,
         2.49617925e+01,  1.92548415e-15,  1.15489565e+01,  6.26240704e+00,
         2.80896941e+01,  3.32831479e+01,  7.02041648e-15,  3.46566155e+01]),
 8.881784197001252e-16)

In [None]:
# Run the unweighted and then the weighted training path; `fit` returns the
# model, so the calls chain. The attributes shown come from the weighted fit.
lm = LinearModel()
lm.fit(X, y).fit(X, y, w)
lm.coef_, lm.intercept_

(Array([ 8.38563865e-05,  3.10347271e+00,  2.50895084e-07, -7.12871915e-05,
         1.63300648e+01,  2.79284559e-05,  1.49981890e-04,  2.12155819e+01,
        -4.31869266e-05, -1.74582556e-06, -1.05994470e-04,  3.53389740e+01,
         2.49617939e+01,  6.73230243e-05,  1.15488825e+01,  6.26239347e+00,
         2.80897846e+01,  3.32830086e+01,  3.86646097e-05,  3.46564369e+01],      dtype=float32),
 Array([-6.177535e-05], dtype=float32))

In [None]:
# The SGD solution must match scikit-learn's coefficients and intercept to 5e-3.
for reference, fitted in ((sk_lm.coef_, lm.coef_), (sk_lm.intercept_, lm.intercept_)):
    assert np.allclose(reference, fitted, atol=5e-3)

In [None]:
# L1-regularised fit: unweighted first, then weighted (the displayed result).
lasso = Lasso(alpha=0.1)
lasso.fit(X, y).fit(X, y, w)
lasso.coef_, lasso.intercept_

(DeviceArray([ 6.7087787e-04,  6.7190990e+00,  7.1662598e+01,
               7.4942748e-04, -5.4690341e-04, -1.1745903e-03,
               4.2832729e+01,  6.7359884e-04, -1.9092231e-04,
               3.2343441e+01, -2.1547372e-04,  7.0898727e+01,
               3.2021362e+01,  9.4644330e-04,  2.6314874e+00,
               6.8846603e+01,  8.2106552e+01, -2.2809652e-03,
              -9.7414968e-04,  6.9408112e+01], dtype=float32),
 DeviceArray([0.03017059], dtype=float32))

In [None]:
# L2-norm-regularised fit: unweighted first, then weighted (the displayed result).
ridge = Ridge(alpha=0.1)
ridge.fit(X, y).fit(X, y, w)
ridge.coef_, ridge.intercept_

(DeviceArray([ 2.2781775e-03,  6.8124046e+00,  7.1715286e+01,
               2.9025278e-03,  3.1277947e-03, -1.9458444e-03,
               4.2912235e+01,  2.3261304e-03,  4.7600130e-03,
               3.2431641e+01,  4.1142856e-03,  7.0952370e+01,
               3.2105915e+01, -2.5508574e-03,  2.7339778e+00,
               6.8915466e+01,  8.2148575e+01, -8.6568939e-03,
               7.1830135e-03,  6.9459007e+01], dtype=float32),
 DeviceArray([0.00411766], dtype=float32))