# Linear Layer

## Definition: Linear

Input
* $x \in \mathbb{R}^{d_{in}}$ 

Weights
* weight $W \in \mathbb{R}^{d_{out}\times d_{in}}$ 
* bias $b \in \mathbb{R}^{d_{out}}$

Output
* $o \in \mathbb{R}^{d_{out}}$

$$o = \text{Linear}_{W,b}(x)=Wx+b.$$

**Note:** In practice we also have a batch dimension, so the input takes the form $X\in \mathbb{R}^{d_{b}\times d_{in}}$ (one sample per row) and the implementation takes the form

$$O = \text{Linear}_{W,b}(X)=XW^{T}+1_{d_b}b^T,$$
where
$$1_{n}=\begin{pmatrix}1\\1\\\vdots\\ 1
\end{pmatrix}$$
denotes the $n$-dimensional vector of 1s.
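
As a quick sanity check of the batched formula, the sketch below compares it with the per-sample definition $o = Wx+b$ applied row by row. It uses NumPy rather than PyTorch only to stay device-independent; the sizes and the random seed are arbitrary assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
d_b, d_in, d_out = 4, 5, 3          # arbitrary illustrative sizes

X = rng.normal(size=(d_b, d_in))    # one input per row
W = rng.normal(size=(d_out, d_in))
b = rng.normal(size=d_out)

# Per-sample definition: o = W x + b, applied to each row of X.
per_sample = np.stack([W @ x + b for x in X])

# Batched form: O = X W^T + 1_{d_b} b^T.
ones = np.ones((d_b, 1))
batched = X @ W.T + ones @ b[None, :]

# Broadcasting adds b to every row without building the vector of ones.
broadcast = X @ W.T + b

print(np.allclose(per_sample, batched), np.allclose(batched, broadcast))
```

The broadcast version is what one would normally write in code; the explicit $1_{d_b}b^T$ term is mainly useful for the algebra.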

## Property: Let $A\in\mathbb{R}^{m\times n}$ and $w\in \mathbb{R}^n$, then
 $$\frac{\partial Aw}{\partial w} = A$$

**Proof:** Write $A=[a_1|a_2|\dots|a_m]^T$, so that $a_i^T$ is the $i$-th row of $A$, and let $f_i(w)=[Aw]_i=a^T_{i}w$. Then we have:
$$
\begin{align*}
\frac{\partial f_i}{\partial w_\ell}(w) &= \frac{\partial }{\partial w_\ell}\left[\sum_{j=1}^n a_{i,j}w_j\right] ,
\\
&=\sum_{j=1}^na_{i,j}\frac{\partial }{\partial w_\ell}\left[w_j\right] ,\\
&=\sum_{j=1}^na_{i,j}\begin{cases}
1, & j=\ell,\\
0,&j\neq \ell,\\
\end{cases}
,\\
&= a_{i,\ell}.
\end{align*}
$$

Then $\frac{\partial f_i}{\partial w}= a_i^T$, and consequently $\frac{\partial Aw}{\partial w} =A$.
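
The identity can be checked numerically with central finite differences. The helper `numerical_jacobian` below is a hypothetical illustration written for this note, not part of any library, and the matrix sizes are arbitrary.

```python
import numpy as np

def numerical_jacobian(f, w, eps=1e-6):
    """Central-difference Jacobian of f at w.

    Entry (i, l) approximates the partial derivative of f_i with respect to w_l.
    """
    w = np.asarray(w, dtype=float)
    out_dim = np.asarray(f(w)).size
    J = np.zeros((out_dim, w.size))
    for l in range(w.size):
        step = np.zeros_like(w)
        step[l] = eps
        J[:, l] = (f(w + step) - f(w - step)) / (2 * eps)
    return J

rng = np.random.default_rng(0)
m, n = 4, 3
A = rng.normal(size=(m, n))
w = rng.normal(size=n)

J = numerical_jacobian(lambda v: A @ v, w)
print(np.allclose(J, A, atol=1e-6))
```

Because $w\mapsto Aw$ is linear, the central difference is exact up to rounding, so the comparison with $A$ should hold for any $w$.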

## Property: Let $A\in\mathbb{R}^{n\times n}$ and $w\in \mathbb{R}^n$, then
 $$\frac{\partial w^TAw}{\partial w} = w^T(A+A^T)$$

**Proof:** Consider $f(w)=w^TAw$ and differentiate $f$ with respect to $w_\ell$:
$$
\begin{align*}
\frac{\partial f}{\partial w_\ell}(w) &= \frac{\partial }{\partial w_\ell}\left[\sum_{i,j}w_ia_{i,j}w_j\right] ,
\\
&=\sum_{i,j=1}^na_{i,j}\frac{\partial }{\partial w_\ell}\left[w_iw_j\right] ,\\
&=\sum_{i,j=1}^na_{i,j}\begin{cases}
w_j, & i=\ell,j\neq \ell,\\
w_i, & i\neq \ell,j= \ell,\\
2w_{\ell} & i=\ell,j= \ell,\\
0& i\neq\ell,j\neq \ell,\\
\end{cases}
,\\
&= \sum_{ j=1, j\neq \ell}^n a_{\ell,j}w_j+\sum_{ i=1, i\neq \ell}^n a_{i,\ell}w_i +2a_{\ell,\ell}w_{\ell},\\
&= \sum_{ j=1}^n a_{\ell,j}w_j+\sum_{ i=1}^n a_{i,\ell}w_i,\\
&= [Aw]_{\ell}+[A^Tw]_{\ell},\\
&= [Aw+A^Tw]_{\ell},\\
&= [(A+A^T)w]_{\ell}.
\end{align*}
$$

So we get $\frac{\partial f}{\partial w} =[(A+A^T)w ]^T=w^T(A^T+A)=w^T(A+A^T)$.
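
A small numerical check of this formula, again with central finite differences. The helper `numerical_gradient` is a hypothetical illustration, and $A$ is deliberately drawn non-symmetric so that the $A^T$ term matters.

```python
import numpy as np

def numerical_gradient(f, w, eps=1e-6):
    """Central-difference gradient of a scalar function f at w."""
    grad = np.zeros_like(w)
    for l in range(w.size):
        step = np.zeros_like(w)
        step[l] = eps
        grad[l] = (f(w + step) - f(w - step)) / (2 * eps)
    return grad

rng = np.random.default_rng(0)
n = 4
A = rng.normal(size=(n, n))   # generic, non-symmetric matrix
w = rng.normal(size=n)

numeric = numerical_gradient(lambda v: v @ A @ v, w)
analytic = w @ (A + A.T)      # w^T (A + A^T), stored as a 1-D array
symmetric_only = 2 * w @ A    # shortcut that is valid only when A = A^T

print(np.allclose(numeric, analytic, atol=1e-5))
print(np.allclose(analytic, symmetric_only))
```

For a symmetric $A$ the result reduces to $2w^TA$, which is the case used later for $A^TA$.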

## Property: Let $A\in\mathbb{R}^{m\times n}$ be a full rank matrix with $m>n$, $b\in \mathbb{R}^m$ and $w\in \mathbb{R}^n$. Then
1. The matrix $A^T\!A$ is invertible.
2. The function $f(w) = \|Aw-b\|_2$ reaches its unique minimum value at

$$w_{\text{min}} = (A^TA)^{-1}A^Tb $$

**Proof (1.):** Since $A$ has full rank and $m>n$, it has full column rank, that is, $\operatorname{rank}(A) = n$.

By the Rank-Nullity Theorem,

$$
\operatorname{rank}(A) + \operatorname{nullity}(A) = n,
$$

so

$$
\operatorname{nullity}(A) = 0.
$$

If $A^T A x = 0$, then

$$
0 = x^T A^T A x = (Ax)^T A x =\|Ax\|_2^2 \Rightarrow Ax = 0,
$$

which implies

$$
x = 0.
$$

Hence,

$$
\operatorname{nullity}(A^T A) = 0.
$$

Since $A^T A \in \mathbb{R}^{n \times n}$ and has zero nullity, it is invertible.
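
To illustrate why the full column rank assumption is needed, the sketch below compares a generic tall matrix with one whose third column is the sum of the first two. The sizes, the seed, and the rank tolerance `tol=1e-9` are arbitrary assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
m, n = 6, 3

A_full = rng.normal(size=(m, n))   # generic tall matrix, full column rank
A_deficient = A_full.copy()
A_deficient[:, 2] = A_deficient[:, 0] + A_deficient[:, 1]   # dependent column

gram_full = A_full.T @ A_full
gram_deficient = A_deficient.T @ A_deficient

# An explicit tolerance keeps the rank decision robust to rounding noise.
rank_full = np.linalg.matrix_rank(gram_full, tol=1e-9)
rank_deficient = np.linalg.matrix_rank(gram_deficient, tol=1e-9)

# A^T A is symmetric, so eigvalsh applies; every eigenvalue should be positive.
min_eig = np.linalg.eigvalsh(gram_full).min()

print(rank_full, rank_deficient, min_eig > 0)
```

In the rank-deficient case $A^TA$ is singular, so the closed-form minimizer stated in part 2 is not defined.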


**Proof (2.):** Notice that $f(w)=\sqrt{(Aw-b)^T(Aw-b)}$. Since $\sqrt{\cdot}$ is monotonically increasing, it is enough to prove the result for the function $g(w)=(Aw-b)^T(Aw-b)$. Notice that

$$
\begin{align*}
g(w)&=(Aw-b)^T(Aw-b),\\
&=(w^TA^T-b^T)(Aw-b),\\
&=w^TA^TAw-w^TA^Tb-b^TAw+b^Tb,\\
&=w^TA^TAw-2b^TAw+b^Tb,\\
\end{align*}
$$

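Here the last step uses that $w^TA^Tb$ is a scalar and therefore equals its transpose $b^TAw$. Taking the derivative with respect to $w$, using the two properties above and the symmetry of $A^TA$, gives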

$$
\begin{align*}
\frac{\partial g}{\partial w}(w)&=2w^TA^TA-2b^TA,\\
\end{align*}
$$

Setting the derivative to zero and using the invertibility of $A^TA$ from part 1, we get
$$
\begin{align*}
0&=\frac{\partial g}{\partial w}(w)=2w^TA^TA-2b^TA,\\
2b^TA&=2w^TA^TA,\\
A^TAw&=A^Tb,\\
w&=(A^TA)^{-1}A^Tb.
\end{align*}
$$

Hence the only critical point of $g$ is $w_{\text{min}}=(A^TA)^{-1}A^Tb$.

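This critical point is a minimum because the Hessian matrix of $g$,
$$H(g)=2A^TA,$$
is positive definite: for any $x\neq 0$ we have $x^TA^TAx=\|Ax\|_2^2>0$, since $A$ has a trivial null space. Therefore $g$ is strictly convex and $w_{\text{min}}=(A^TA)^{-1}A^Tb$ is its unique minimum.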
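
The closed form can be compared against a library least-squares solver. The sketch below assumes random data of arbitrary size and solves the normal equations $A^TAw=A^Tb$ with `np.linalg.solve` instead of forming the inverse explicitly, which is usually the numerically safer choice.

```python
import numpy as np

rng = np.random.default_rng(0)
m, n = 50, 3
A = rng.normal(size=(m, n))
b = rng.normal(size=m)

# Normal equations: A^T A w = A^T b.
w_min = np.linalg.solve(A.T @ A, A.T @ b)

# Reference solution from NumPy's least-squares routine.
w_lstsq, *_ = np.linalg.lstsq(A, b, rcond=None)

def f(w):
    """Objective from the property: the Euclidean norm of the residual."""
    return np.linalg.norm(A @ w - b)

# Derivative of g(w) = ||Aw - b||^2 at w_min; it should vanish.
grad = 2 * w_min @ A.T @ A - 2 * b @ A

# Small random perturbations of w_min should only increase the objective.
perturbed = [f(w_min + 1e-3 * rng.normal(size=n)) for _ in range(5)]

print(np.allclose(w_min, w_lstsq), np.allclose(grad, 0, atol=1e-8))
print(all(value > f(w_min) for value in perturbed))
```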

## Property: Exact Solution

Let $X=[x_1|x_2|\dots|x_m]^T\in\mathbb{R}^{m\times n}$ be the matrix of $m$ training inputs and $y = (y_1,y_2,\dots,y_m)\in \mathbb{R}^m$ the vector of training targets for the linear model $\text{Linear}_{w,b}$, with $w\in\mathbb{R}^n$ and $b\in\mathbb{R}$. Measure the error using the $\text{MSE}$ loss function
$$\text{mse}(\hat{y},y)=\frac{1}{m}\sum_{i=1}^m(y_i-\hat{y}_i)^2.$$
Then the optimal values of the learnable parameters $w$ and $b$ for the given data $(X,y)$, that is,
$$
\begin{pmatrix}
w^*\\
-\\
b^*
\end{pmatrix} = \operatorname*{arg\,min}_{w,b}\left\{\text{mse}\left(\text{Linear}_{w,b}(X),y\right)\right\},
$$
are
$$\begin{pmatrix}
w^*\\
-\\
b^*
\end{pmatrix}=(\hat{X}^T\hat{X})^{-1} \hat{X}^Ty,$$ 

where $\hat{X}=[X|1_{m}]$, provided that $\hat{X}$ has full column rank and $m>n+1$.


**Proof:** Notice that
$$
\begin{align*}
\text{Linear}_{w,b}(X)=Xw+b1_m =\hat{X}
\begin{pmatrix}
w\\
-\\
b
\end{pmatrix}.
\end{align*}
$$
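From here, the result follows directly by applying the previous property to $f(w,b)=\|\text{Linear}_{w,b}(X)-y\|_2$, with $\hat{X}$ in the role of $A$ and $y$ in the role of $b$. Indeed,
$$\text{mse}\left(\text{Linear}_{w,b}(X),y\right)= \frac{1}{m}\|\text{Linear}_{w,b}(X)-y\|_2^2$$
is an increasing function of $f$, so both have the same minimizer.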
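
The augmented-matrix construction can be sketched as follows. The synthetic data, the true bias `b_true = 0.5`, and the noise level are assumptions made for illustration; the weight vector reuses the values `[2, 3, 0, 1, 4]` from the notebook cells further down.

```python
import numpy as np

rng = np.random.default_rng(0)
m, n = 200, 5

X = rng.uniform(size=(m, n))
w_true = np.array([2.0, 3.0, 0.0, 1.0, 4.0])
b_true = 0.5                                    # hypothetical bias
y = X @ w_true + b_true + 0.1 * rng.normal(size=m)

# Augment X with a column of ones so the bias becomes one more weight.
X_hat = np.hstack([X, np.ones((m, 1))])

# Solve the normal equations for the stacked parameter vector (w, b).
theta = np.linalg.solve(X_hat.T @ X_hat, X_hat.T @ y)
w_star, b_star = theta[:-1], theta[-1]

def mse(w, b):
    """Mean squared error of the linear model on (X, y)."""
    return np.mean((X @ w + b - y) ** 2)

print(w_star, b_star)
print(mse(w_star, b_star) <= mse(w_true, b_true))
```

Note that the fitted parameters achieve an MSE no larger than the generating parameters do, because they minimize the loss on this particular noisy sample.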

## Code: Linear

In [18]:
import torch
import torch.nn as nn
from linear import Linear

## Testing

In [20]:
batch_size = 4
in_dim = 5
out_dim = 3

bias = False
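# Use Apple's Metal backend when available; otherwise fall back to the CPU.
device = 'mps' if torch.backends.mps.is_available() else 'cpu'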
dtype = torch.float32

In [21]:
x = torch.randn(batch_size, in_dim).to(device=device, dtype=dtype)

### Weights

In [22]:
# Reference layer from PyTorch and the custom implementation under test.
torch_linear = nn.Linear(in_dim, out_dim, bias=bias, device=device, dtype=dtype)
linear = Linear(in_dim, out_dim, bias=bias, device=device, dtype=dtype)

In [23]:
for name, param in torch_linear.named_parameters():
    print(name, param.shape)

weight torch.Size([3, 5])


In [24]:
for name, param in linear.named_parameters():
    print(name, param.shape)

weight torch.Size([3, 5])


## Output

In [25]:
seed = 40

In [26]:
# Re-seed before each constructor so both layers draw the same initial weights.
torch.manual_seed(seed)
torch_linear = nn.Linear(in_dim, out_dim, bias=bias, device=device, dtype=dtype)

torch.manual_seed(seed)
linear = Linear(in_dim, out_dim, bias=bias, device=device, dtype=dtype)

In [27]:
linear(x)

tensor([[ 0.0488, -0.1656, -0.1285],
        [ 0.2958,  0.4459,  0.5653],
        [-0.4216,  0.1715, -0.2898],
        [-1.2251,  0.9167, -0.0632]], device='mps:0', grad_fn=<MmBackward0>)

In [28]:
torch_linear(x)

tensor([[ 0.0488, -0.1656, -0.1285],
        [ 0.2958,  0.4459,  0.5653],
        [-0.4216,  0.1715, -0.2898],
        [-1.2251,  0.9167, -0.0632]], device='mps:0',
       grad_fn=<LinearBackward0>)

## Check Exact Solution

In [29]:
batch_size = 100
in_dim = 5
out_dim = 1

bias = False

In [30]:
torch_linear = nn.Linear(in_dim, out_dim, bias=bias, device=device, dtype=dtype)

### Train

In [31]:
x = torch.rand(batch_size, in_dim).to(device=device, dtype=dtype)
# Ground-truth weights as a column vector of shape (in_dim, 1).
A = torch.tensor([[2, 3, 0, 1, 4]], device=device, dtype=dtype).T

In [32]:
# Noisy targets: y = x A + Gaussian noise with standard deviation 0.1.
y_true = x @ A + 0.1 * torch.randn(batch_size, out_dim).to(device=device, dtype=dtype)
optim = torch.optim.SGD(torch_linear.parameters(), lr=0.01)
loss_fn = nn.MSELoss()

for epoch in range(10_000):
    y_pred = torch_linear(x)
    loss = loss_fn(y_pred, y_true)

    # Clear old gradients, backpropagate, and take one gradient step.
    optim.zero_grad()
    loss.backward()
    optim.step()

    if epoch % 10 == 0:
        print(f"Epoch {epoch}, Loss: {loss.item()}")

Epoch 0, Loss: 25.60531997680664
Epoch 10, Loss: 15.397218704223633
Epoch 20, Loss: 9.4021577835083
Epoch 30, Loss: 5.876128196716309
Epoch 40, Loss: 3.7972495555877686
Epoch 50, Loss: 2.5667552947998047
Epoch 60, Loss: 1.8337950706481934
Epoch 70, Loss: 1.39277982711792
Epoch 80, Loss: 1.1232341527938843
Epoch 90, Loss: 0.9545536637306213
Epoch 100, Loss: 0.8453550338745117
Epoch 110, Loss: 0.7713768482208252
Epoch 120, Loss: 0.718390703201294
Epoch 130, Loss: 0.6780471801757812
Epoch 140, Loss: 0.6454421281814575
Epoch 150, Loss: 0.6176899075508118
Epoch 160, Loss: 0.5930885076522827
Epoch 170, Loss: 0.5706294775009155
Epoch 180, Loss: 0.5497121214866638
Epoch 190, Loss: 0.529973566532135
Epoch 200, Loss: 0.5111919045448303
Epoch 210, Loss: 0.4932273030281067
Epoch 220, Loss: 0.4759887158870697
Epoch 230, Loss: 0.4594144821166992
Epoch 240, Loss: 0.44345924258232117
Epoch 250, Loss: 0.428088903427124
Epoch 260, Loss: 0.41327521204948425
Epoch 270, Loss: 0.39899420738220215
Epoch 280,

In [33]:
torch_linear.weight.data

tensor([[1.9613e+00, 3.0499e+00, 3.9282e-03, 9.8797e-01, 3.9679e+00]],
       device='mps:0')

In [34]:
# Closed-form least-squares solution (no bias): (x^T x)^{-1} x^T y.
(x.T @ x).inverse() @ x.T @ y_true

tensor([[1.9613e+00],
        [3.0499e+00],
        [3.8978e-03],
        [9.8793e-01],
        [3.9679e+00]], device='mps:0')