## Linear Model

In [None]:
#| default_exp linear_model

In [None]:
#| include: false
%load_ext autoreload
%autoreload 2
from ipynb_path import *

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
#| export
from __future__ import annotations
from jax_interpret.imports import *

In [None]:
#| export
def l2_loss(x1, x2, weights=None):
    """Halved squared-error loss between `x1` and `x2`.

    Without `weights` the element-wise squared differences are averaged.
    With `weights` they are combined as a weighted sum, where the weights are
    first divided by their L1 norm. In both cases the result is divided by 2.
    """
    squared_diff = jnp.square(x1 - x2)
    if weights is not None:
        # Scale the weights by their L1 norm before applying them.
        normalized = weights / jnp.linalg.norm(weights, ord=1)
        return jnp.sum(normalized * squared_diff) / 2.0
    return jnp.mean(squared_diff) / 2.0

In [None]:
#| export
def _init_train_fn(
    X: jnp.ndarray, # Input data
    y: jnp.ndarray, # Target data
    fit_bias: bool = True, # Fit bias term
    seed: int = 42, # Random seed
):  
    rng = jax.random.PRNGKey(seed)
    n_samples, n_features = X.shape
    rng, w_key, b_key = jax.random.split(rng, 3)
    w = jax.random.normal(w_key, (n_features,))
    if fit_bias:
        b = jax.random.normal(b_key, (1,))
    else:
        b = jnp.zeros(1)
    params = dict(w=w, b=b)
    return params

def calculate_loss(
    params: Dict[str, jnp.ndarray], # Model parameters with keys `w` and `b`
    batch: Tuple[Array, Array, Array], # `(X, y, weights)` for one batch
    loss_fn: Callable, # Called as `loss_fn(y, y_pred, weights)`
    reg_term: int = None, # Order of the norm penalty on `w`; `None` disables it
    alpha: float = 1.0 # Multiplier applied to the penalty
):
    """Calculate the loss for a batch of data.

    The prediction is `X @ w + b`. When `reg_term` is given, `alpha` times the
    `reg_term`-order norm of `w` is added to the value returned by `loss_fn`.
    """
    inputs, targets, sample_weights = batch
    coef, intercept = params["w"], params["b"]
    predictions = jnp.dot(inputs, coef) + intercept
    data_loss = loss_fn(targets, predictions, sample_weights)
    if reg_term is None:
        return data_loss
    # The bias is left out of the penalty; only `w` is regularized.
    penalty = jnp.linalg.norm(coef, ord=reg_term)
    return data_loss + jnp.mean(penalty) * alpha

def sgd_train_linear_model(
    X: jnp.ndarray, # Input data. Shape: `(N, k)`
    y: jnp.ndarray, # Target data. Shape: `(N,)` or `(N, 1)`
    weights: jnp.ndarray = None, # Per-sample weights passed to `loss_fn`. Shape: `(N,)` or `(N, 1)`
    lr: float = 0.01, # Learning rate
    n_epochs: int = 100, # Number of epochs
    batch_size: int = 32, # Batch size
    seed: int = 42, # Random seed
    loss_fn: Callable = l2_loss, # Loss function
    reg_term: int = None, # Order of the norm penalty on `w`; `None` disables it
    alpha: float = 1.0, # Regularization strength
    fit_bias: bool = True, # Fit bias term
) -> Tuple[np.ndarray, np.ndarray]: # The trained weights and bias
    """Train a linear model using SGD.

    Batches are consecutive slices of the data taken in the given order (no
    shuffling); `seed` only controls the parameter initialisation. When
    `fit_bias` is false the bias keeps its initial value of zero.

    Raises `ValueError` if `X` is not two-dimensional, if `y` or `weights`
    do not hold one value per row of `X`, or if `batch_size` is not positive.
    """

    def _as_vector(arr, name):
        """Return `arr` as a `(N,)` array, accepting a `(N, 1)` column as well."""
        arr = jnp.asarray(arr)
        if arr.ndim == 2 and arr.shape[1] == 1:
            arr = arr[:, 0]
        if arr.shape != (n_samples,):
            raise ValueError(
                f"`{name}` must have shape ({n_samples},) or ({n_samples}, 1), "
                f"got {arr.shape}"
            )
        return arr

    X = jnp.asarray(X)
    if X.ndim != 2:
        raise ValueError(f"`X` must be two-dimensional, got shape {X.shape}")
    if batch_size < 1:
        # A negative step would make the batch loop empty and silently return
        # the untrained random initialisation.
        raise ValueError(f"`batch_size` must be positive, got {batch_size}")
    n_samples = X.shape[0]
    # A `(N, 1)` target would broadcast against the `(N,)` prediction into an
    # `(N, N)` residual, so flatten it up front.
    y = _as_vector(y, "y")
    if weights is not None:
        weights = _as_vector(weights, "weights")

    opt = optax.sgd(lr)
    grad_fn = jax.grad(calculate_loss)

    @jax.jit
    def sgd_step(params, opt_state, batch):
        """Perform a single SGD step."""
        grads = grad_fn(params, batch, loss_fn, reg_term, alpha)
        if not fit_bias:
            # `fit_bias` is a Python bool fixed at trace time. Zeroing the bias
            # gradient keeps `b` at its initial value instead of training it.
            grads = {**grads, "b": jnp.zeros_like(grads["b"])}
        updates, opt_state = opt.update(grads, opt_state)
        params = optax.apply_updates(params, updates)
        return params, opt_state

    params = _init_train_fn(X, y, fit_bias, seed)
    opt_state = opt.init(params)
    for epoch in range(n_epochs):
        for i in range(0, n_samples, batch_size):
            X_batch = X[i : i + batch_size]
            y_batch = y[i : i + batch_size]
            w_batch = weights[i : i + batch_size] if weights is not None else None
            params, opt_state = sgd_step(params, opt_state, (X_batch, y_batch, w_batch))
    return params["w"], params["b"]


In [None]:
#| export
class BaseEstimator:
    """Minimal estimator interface; both methods are no-op placeholders."""

    def __init__(self):
        """Take no arguments and set no attributes."""

    def fit(self, X, y):
        """Placeholder for subclasses to override; does nothing and returns `None`."""
        return None

In [None]:
#| export
class LinearModel(BaseEstimator):
    """Linear model whose parameters are produced by a pluggable trainer function."""

    def __init__(
        self,
        intercept: bool = True, # Forwarded to the trainer as `fit_bias`
        trainer_fn: Callable=None, # Training routine; `sgd_train_linear_model` when `None`
        **kwargs, # Accepted but not used
    ):
        if trainer_fn is None:
            trainer_fn = sgd_train_linear_model
        self.fit_bias = intercept
        self.trainer_fn = trainer_fn

    def fit(
        self,
        X: jnp.ndarray, # Input data
        y: jnp.ndarray, # Target data
        weights: jnp.ndarray = None, # Passed to the trainer as its third positional argument
        **kwargs, # Extra keyword arguments forwarded to the trainer
    ) -> LinearModel:
        """Run the trainer and store its two results as `coef_` and `intercept_`.

        Returns the estimator itself.
        """
        trained = self.trainer_fn(X, y, weights, fit_bias=self.fit_bias, **kwargs)
        self.coef_, self.intercept_ = trained
        return self

In [None]:
#| export
class Lasso(LinearModel):
    """`LinearModel` that always trains with `reg_term=1` and its own `alpha`."""

    def __init__(
        self,
        alpha: float = 1.0, # Regularization strength handed to the trainer
        **kwargs, # Forwarded to `LinearModel.__init__`
    ):
        super().__init__(**kwargs)
        self.alpha = alpha

    def fit(
        self,
        X: jnp.ndarray, # Input data
        y: jnp.ndarray, # Target data
        weights: jnp.ndarray = None, # Optional weights forwarded unchanged
        **kwargs, # Extra keyword arguments forwarded to the trainer
    ) -> LinearModel:
        """Fit via `LinearModel.fit`, adding `reg_term=1` and `alpha=self.alpha`."""
        return super().fit(X, y, weights, reg_term=1, alpha=self.alpha, **kwargs)

In [None]:
#| export
class Ridge(LinearModel):
    """`LinearModel` that always trains with `reg_term=2` and its own `alpha`."""

    def __init__(
        self,
        alpha: float = 1.0, # Regularization strength handed to the trainer
        **kwargs, # Forwarded to `LinearModel.__init__`
    ):
        super().__init__(**kwargs)
        self.alpha = alpha

    def fit(
        self,
        X: jnp.ndarray, # Input data
        y: jnp.ndarray, # Target data
        weights: jnp.ndarray = None, # Optional weights forwarded unchanged
        **kwargs, # Extra keyword arguments forwarded to the trainer
    ) -> LinearModel:
        """Fit via `LinearModel.fit`, adding `reg_term=2` and `alpha=self.alpha`."""
        return super().fit(X, y, weights, reg_term=2, alpha=self.alpha, **kwargs)

#### Test 

In [None]:
from sklearn.datasets import make_regression
from sklearn.linear_model import LinearRegression

In [None]:
# Synthetic regression problem: 500 samples, 20 features.
# NOTE(review): no `random_state` is passed, so the data presumably differs between runs.
X, y = make_regression(n_samples=500, n_features=20)
# One weight of 1.0 per sample, used below to exercise the weighted fit path.
w = np.ones(X.shape[0])

In [None]:
# Reference fit with scikit-learn, to compare the coefficients below against.
sk_lm = LinearRegression()
sk_lm.fit(X, y)
sk_lm.coef_, sk_lm.intercept_

(array([ 0.00000000e+00,  6.80690685e+00,  7.17571931e+01, -1.26707763e-14,
        -3.38413423e-14, -6.82817392e-14,  4.29433926e+01, -6.58114080e-15,
         4.01878799e-14,  3.24526952e+01, -1.42780834e-14,  7.09915801e+01,
         3.21218784e+01,  3.04321225e-14,  2.73409206e+00,  6.89600047e+01,
         8.21964684e+01,  3.42503803e-14, -2.33146835e-14,  6.94968410e+01]),
 5.329070518200751e-15)

In [None]:
# Fit without weights, then again with the all-ones weights `w`.
# The second fit overwrites `coef_` / `intercept_`, so the values shown come from it.
lm = LinearModel()
lm.fit(X, y)
lm.fit(X, y, w)
lm.coef_, lm.intercept_

(DeviceArray([ 2.9183337e-05,  6.8071284e+00,  7.1756966e+01,
               3.4530715e-06,  1.0120852e-04,  3.5201498e-07,
               4.2943222e+01,  9.6378506e-05,  3.3780834e-05,
               3.2452621e+01,  1.4333882e-04,  7.0991470e+01,
               3.2121925e+01, -1.2890021e-04,  2.7341127e+00,
               6.8959724e+01,  8.2196198e+01, -1.3383102e-04,
               1.2260761e-04,  6.9496712e+01], dtype=float32),
 DeviceArray([-5.7324924e-05], dtype=float32))

In [None]:
# Same two fits (unweighted, then weighted) with the `reg_term=1` variant.
# The values shown come from the second, weighted fit.
lasso = Lasso(alpha=0.1)
lasso.fit(X, y)
lasso.fit(X, y, w)
lasso.coef_, lasso.intercept_

(DeviceArray([ 6.7087787e-04,  6.7190990e+00,  7.1662598e+01,
               7.4942748e-04, -5.4690341e-04, -1.1745903e-03,
               4.2832729e+01,  6.7359884e-04, -1.9092231e-04,
               3.2343441e+01, -2.1547372e-04,  7.0898727e+01,
               3.2021362e+01,  9.4644330e-04,  2.6314874e+00,
               6.8846603e+01,  8.2106552e+01, -2.2809652e-03,
              -9.7414968e-04,  6.9408112e+01], dtype=float32),
 DeviceArray([0.03017059], dtype=float32))

In [None]:
# Same two fits (unweighted, then weighted) with the `reg_term=2` variant.
# The values shown come from the second, weighted fit.
ridge = Ridge(alpha=0.1)
ridge.fit(X, y)
ridge.fit(X, y, w)
ridge.coef_, ridge.intercept_

(DeviceArray([ 2.2781775e-03,  6.8124046e+00,  7.1715286e+01,
               2.9025278e-03,  3.1277947e-03, -1.9458444e-03,
               4.2912235e+01,  2.3261304e-03,  4.7600130e-03,
               3.2431641e+01,  4.1142856e-03,  7.0952370e+01,
               3.2105915e+01, -2.5508574e-03,  2.7339778e+00,
               6.8915466e+01,  8.2148575e+01, -8.6568939e-03,
               7.1830135e-03,  6.9459007e+01], dtype=float32),
 DeviceArray([0.00411766], dtype=float32))