# Model, Loss and Optimizer for Linear Regression

## Introduction

You've made it pretty far now! 
I also hope you're excited to learn more!

In the previous chapter, we showed you a function called `grad`.
`grad` is one of the key functions in JAX -
it is the user-facing interface to
JAX's automatic differentiation system.

In this chapter, we're going to talk about linear regression.
But wait, isn't this a course on deep learning fundamentals?
Why are we talking about simple linear models?
Well... it's because the _fundamental_ ideas that we need to learn
apply the same way to linear models as they do to deep neural network models.
As such, we're going to be _leveraging_ linear regression - 
at least for me, it serves as a minimally complex example 
to study the foundational ideas that generalize to deep neural networks.

## What is a linear model?

Your canonical linear model is something you've probably been taught in high school mathematics or physics.
In its simplest form, the model expresses the idea that "as some value increases, some other value increases (or decreases) at a constant rate".
In equation form, the linear model looks something like this:

$$y = mx + c$$

Linear models are popular because of their simple inductive bias - 
a simple Y-increases-with-X kind of relationship.
We know that most of the world isn't linear,
but if you zoom in hard enough into the world,
you can make almost any relationship between a pair of variables look approximately linear.

### Exercise: Implement linear model

Let's build your intuition up by writing some Python code.
We'll start by writing the linear model as a Python function.
Implement the linear model below 
according to the following parameter specs.

In [None]:
def linear(params: tuple, x: float) -> float:
    """Univariate linear model.

    :param params: A tuple of two scalar floats.
        The first element should be the weight/slope
        and the second element should be the bias/intercept.
    :param x: Input data which can be scalar
        or a vector of numbers.
    :returns: A scalar or vector of outputs, y_estimated.
    """
    # Your answer goes here
    weight, bias = params 
    return weight * x + bias
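
To check our understanding of the function above, here is a minimal usage sketch. The parameter values (slope 2.0, intercept 1.0) and the names `example_params` and `x_vector` are made up purely for illustration; the point is that the same function handles both a scalar and a vector of inputs.

```python
import jax.numpy as np


def linear(params: tuple, x):
    """Univariate linear model: weight * x + bias."""
    weight, bias = params
    return weight * x + bias


# Hypothetical parameters, chosen only for illustration: slope 2.0, intercept 1.0.
example_params = (2.0, 1.0)

# Scalar input: 2.0 * 3.0 + 1.0
y_scalar = linear(example_params, 3.0)

# Vector input: broadcasting applies the same slope and intercept to every element.
x_vector = np.array([0.0, 1.0, 2.0])
y_vector = linear(example_params, x_vector)

print(y_scalar)  # 7.0
print(y_vector)  # [1. 3. 5.]
```

Because the function only uses multiplication and addition, array broadcasting takes care of the vector case without any extra code.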

## How does optimization relate to linear models?

The short answer is this:
_Optimization comes in whenever we train a linear model._
(By the way, _training_ and _fitting_ a linear model are synonymous!)
Gradient-based optimization comes into the picture
as one of a few canonical ways to fit linear models to data.
(There are exact analytical solutions, but we won't be covering them here
because they would distract from our goal
of seeing how gradient-based optimization
is a general-purpose tool for fitting models.)

Let's break this down step-by-step.

Firstly, we formulate the problem.
If we squint hard at the linear model equation above,
then we'll see that the _parameters_ of the model
are $m$ and $c$, the slope and intercept parameters.
(In the deep learning world,
they are also referred to as weights and biases.)
We would like to tweak their values
such that, given a bunch of $x$ values,
the model produces $y_{\text{estimated}}$ values
that _minimize the error_
against a ground truth dataset of $y_{\text{true}}$ values.

That last statement is the key!
Remember that in any optimization problem
our goal is to tweak parameters of something
to minimize or maximize some other thing.
When we train a linear model
we are optimizing, or minimizing, an _error function_.
In the deep learning literature,
you might see this show up as the _loss function_,
while in other optimization literature circles,
you might see this show up as the _cost function_.
Those are synonyms for one another.

What, then, is a reasonable linear model loss function?
This is something we get to design!
A common loss function for linear models
_that produce continuous outputs_
is the mean squared error (MSE) function.
As its name suggests,
it is the mean of the squares of the errors between our model-predicted and actual values.
In math form, it looks like the following:

$$\text{MSE}(\text{params}, x, y) = \frac{\sum_{i=1}^{n} (y_{i} - f(\text{params}, x_{i}))^2}{n}$$

Let's see what the intuition behind the MSE is...
without going into a full derivation of why the MSE function is a good loss function.
The numerator of the function above is the sum of squared errors.
In a good model,
we expect the errors to be distributed more or less evenly
above and below the line of best fit (given by the tweaked values of $m$ and $c$).
As such, there'll be some positive and some negative values of errors.
Now, it would be nice to work with just the raw error values,
but positive and negative errors would cancel each other out when summed.
If we would like a smiley-faced (bowl-shaped) function that we can minimize,
taking squares is a good idea, because every squared error is non-negative.
Finally, we take an average over all of the $n$ data points that are present in the dataset - 
kind of useful if, for example, we want to compare the fit of the same class of model
on two datasets with different sizes.
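
Here is a tiny numerical sketch of that intuition. The error values below are invented for illustration (they are not from any dataset in this chapter); they simply sit evenly above and below a hypothetical line of best fit.

```python
import jax.numpy as np

# Hypothetical errors (y_true - y_estimated) for four data points,
# spread evenly above and below the fitted line.
errors = np.array([2.0, -2.0, 1.0, -1.0])

# Raw errors: positives and negatives cancel, hiding how far off we are.
mean_raw_error = np.mean(errors)

# Squared errors: every term is non-negative, so nothing cancels.
mean_squared_error = np.mean(errors ** 2)

print(mean_raw_error)      # 0.0
print(mean_squared_error)  # 2.5
```

The raw mean suggests a perfect fit even though every prediction is off, while the mean of squares reports a non-zero error.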

**Note:** This next point has to be clarified:
we are _not_ optimizing the linear model function.
Rather, we are optimizing the parameters of the linear model
to minimize the loss function.

### Exercise: Implement loss function

Now, let's implement the mean squared error loss function as a Python function using JAX's NumPy API.
Implement the function according to the Python specs below.

In [None]:
from typing import Callable
import jax.numpy as np

def mse_loss(params: tuple, model: Callable, x: float, y: float):
    """
    Mean squared error loss function.

    :param params: A tuple of params that gets passed into the `model` function.
    :param model: A Python callable that accepts `params` and `x` and returns `y_estimated`.
    :param x: Measured values of `x`.
    :param y: Measured values of `y`.
    :returns: The mean squared error, as a scalar.
    """
    y_est = model(params, x)
    sq_err = np.power(y_est - y, 2)
    return np.mean(sq_err)
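
To make the earlier note concrete (we tweak the _parameters_, not the model function), here is a small sketch that evaluates the loss at two different parameter settings. The dataset and the names `x_demo`, `y_demo`, `loss_at_truth` and `loss_at_guess` are assumptions made up for this example.

```python
import jax.numpy as np


def linear(params, x):
    weight, bias = params
    return weight * x + bias


def mse_loss(params, model, x, y):
    y_est = model(params, x)
    return np.mean(np.power(y_est - y, 2))


# Made-up data generated exactly by y = 2x + 1.
x_demo = np.array([0.0, 1.0, 2.0, 3.0])
y_demo = 2.0 * x_demo + 1.0

# Same model, same data; only the parameters differ.
loss_at_truth = mse_loss((2.0, 1.0), linear, x_demo, y_demo)
loss_at_guess = mse_loss((1.0, 0.0), linear, x_demo, y_demo)

print(loss_at_truth)  # 0.0
print(loss_at_guess)  # 7.5
```

The loss is zero at the parameters that generated the data and grows as the parameters move away from them, which is exactly the landscape the optimizer will walk down.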

## Obtaining the gradient of the MSE loss function w.r.t. linear model parameters

What's coming up will hopefully illustrate
why automatic differentiation systems are so powerful.
We now have a loss function that we're going to minimize for a given model.
If we wish to use gradient-based optimization to minimize the loss function,
then we're going to need to obtain its gradients w.r.t. model parameters.
I'm going to put it to you that:

1. the calculus involved is tedious, and
2. if we change the model structure, which will definitely happen as we progress to neural networks, re-computing gradient functions by hand is also incredibly annoying.

That, then, is why having an automatic differentiation system is so powerful.
With the `grad` function in JAX,
or any automatic differentiation system implemented in other deep learning libraries,
you get the benefit of being able to obtain gradient functions
without necessarily needing to do any of that by hand.

### Exercise: Obtain gradient function of the MSE loss function

Now, I'd like you to obtain the gradient function of the MSE loss function.
This should literally be a one-liner!

In [None]:
from jax import grad

dmse_loss = grad(mse_loss)
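
To get a feel for what `grad` is saving us from, here is what the hand derivation looks like for this one model. Substituting $f(\text{params}, x_{i}) = m x_{i} + c$ into the MSE and applying the chain rule gives:

$$\frac{\partial \text{MSE}}{\partial m} = -\frac{2}{n} \sum_{i=1}^{n} x_{i} (y_{i} - (m x_{i} + c)), \qquad \frac{\partial \text{MSE}}{\partial c} = -\frac{2}{n} \sum_{i=1}^{n} (y_{i} - (m x_{i} + c))$$

The sketch below checks that hand-derived gradient against the one `grad` produces. The helper `manual_dmse_loss` and the demo data are hypothetical names and values introduced only for this comparison.

```python
import jax.numpy as np
from jax import grad


def linear(params, x):
    weight, bias = params
    return weight * x + bias


def mse_loss(params, model, x, y):
    y_est = model(params, x)
    return np.mean(np.power(y_est - y, 2))


def manual_dmse_loss(params, x, y):
    """Hand-derived MSE gradient w.r.t. (weight, bias); valid for the linear model only."""
    weight, bias = params
    residual = y - (weight * x + bias)
    d_weight = -2.0 * np.mean(x * residual)
    d_bias = -2.0 * np.mean(residual)
    return np.array([d_weight, d_bias])


# Small made-up dataset and parameter guess, for illustration only.
x_demo = np.array([0.0, 1.0, 2.0, 3.0])
y_demo = np.array([1.0, 3.0, 5.0, 7.0])
params_demo = np.array([0.5, 0.0])

auto_grad = grad(mse_loss)(params_demo, linear, x_demo, y_demo)
manual_grad = manual_dmse_loss(params_demo, x_demo, y_demo)

print(np.allclose(auto_grad, manual_grad))  # True
```

Notice that `manual_dmse_loss` is hard-wired to the linear model: change the model and it has to be re-derived, whereas `grad(mse_loss)` keeps working unchanged.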

## Optimize the linear model parameters

With the gradient function on-hand, 
we can now use gradient-based optimization to optimize the parameters $m$ and $c$.

### Exercise: Implement optimization loop

Your final exercise here is to implement the optimization loop.
I have provided you with a dataset of values
that are obtained from a linear model.
By implementing the optimization loop,
your goal is to figure out what the values of the linear model parameters $m$ and $c$ are.

In [None]:
from jax import random

key = random.PRNGKey(49)

x = random.normal(key, shape=(1000,))
y = 3.5 * x - 1.9

In [None]:
# First, initialize random guesses for params.

k1, k2 = random.split(key)
params = random.normal(k1, shape=(2,))

# Now, write the optimization loop.

for i in range(200):
    # THIS IS WHERE THE EXERCISE LIVES!
    grad_params = dmse_loss(params, linear, x, y)
    params -= grad_params * 0.01

In [None]:
params
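
If you'd like to watch the optimization make progress, one option is to record the loss at every step. The following is a sketch under a few assumptions: the `losses` list and the `learning_rate`, `data_key` and `init_key` names are introduced here for illustration, and the key is split _before_ any random draws, so the random numbers will not be identical to the ones in the cells above.

```python
import jax.numpy as np
from jax import grad, random


def linear(params, x):
    weight, bias = params
    return weight * x + bias


def mse_loss(params, model, x, y):
    return np.mean(np.power(model(params, x) - y, 2))


dmse_loss = grad(mse_loss)

# Synthetic data built the same way as above: true slope 3.5, true intercept -1.9.
# Assumption: we split the key first and use one sub-key per random draw.
key = random.PRNGKey(49)
data_key, init_key = random.split(key)
x = random.normal(data_key, shape=(1000,))
y = 3.5 * x - 1.9

params = random.normal(init_key, shape=(2,))

learning_rate = 0.01  # same step size as in the loop above
losses = []           # hypothetical list for monitoring the loss over time
for i in range(200):
    losses.append(float(mse_loss(params, linear, x, y)))
    grad_params = dmse_loss(params, linear, x, y)
    params = params - learning_rate * grad_params

print(f"initial loss: {losses[0]:.4f}, final loss: {losses[-1]:.4f}")
print(f"estimated slope: {float(params[0]):.2f}, estimated intercept: {float(params[1]):.2f}")
```

With only 200 small steps, the estimates should land close to, but not exactly on, 3.5 and -1.9; running the loop for more iterations should bring them closer.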

## Summary

### The pattern

In the first chapter, we showed you how to optimize a simple polynomial function of $x$.
The key steps were as follows:

1. Define the function of $x$ in Python code using NumPy.
2. Obtain its corresponding gradient function using `grad`.
3. Use the evaluated gradient in a so-called _training loop_ to obtain the value of $x$ that minimizes the function.

With training linear models, we have an analogous situation:

1. Firstly, we define the mean squared error in terms of the key parameters we want to optimize, $m$ and $c$.
2. Obtain its corresponding gradient function using `grad`.
3. Use a _training loop_ to iteratively optimize $m$ and $c$ so that they minimize the loss function.

### Automatic differentiation (AD)

I also wanted to make a point about automatic differentiation systems once more.
Without an automatic differentiation system, you would have had to implement the gradient function by hand.
With simple functions like polynomials,
that's not too tedious,
but once you go into the realm of derivatives for a wide variety of models,
taking derivatives over and over becomes a chore, tedium more suited to computers than humans.

### Looking forward

In the next two chapters, we will be swapping out the linear model for other models.
There, the power of automatic differentiation will become even more evident!

## Final exercises

### Exercise: When it comes to _regression_ problems, what loss function is commonly used?

- [x] Mean squared error
- [ ] Mean regression error
- [ ] Mean simple error
- [ ] Median squared error

### Exercise: In linear regression, what is the mean squared error being optimized with respect to?

- [ ] The output `y`
- [ ] The input `x` and parameters `m` and `c`
- [x] The parameters `m` and `c`
- [ ] The inputs `x` and the outputs `y`