# Linear Layer

## Math

Input
* $x \in \mathbb{R}^{d_{in}}$ 

Weights
* weight $W \in \mathbb{R}^{d_{out}\times d_{in}}$ 
* bias $b \in \mathbb{R}^{d_{out}}$

Output
* $o \in \mathbb{R}^{d_{out}}$

$$o = \text{Linear}_{W,b}(x)=Wx+b.$$

**Note:** In practice we also have a batch dimension, so the input takes the form $X\in \mathbb{R}^{d_{batch}\times d_{in}}$ and the implementation takes the form 

$$O = \text{Linear}_{W,b}(X)=XW^{T}+B,$$
where
$$B=[b|b|\dots|b]^T\in \mathbb{R}^{d_{batch}\times d_{out}}.$$

## Code

In [86]:
import torch
import torch.nn as nn
import math

In [None]:
# Scratch cell: same literal as `alpha` in the 'selu' branch of Linear.__init__.
1.6732632423543772

In [None]:
class Linear(nn.Module):
    def __init__(
        self, input_dim, output_dim, bias=True, device=None, dtype=None, init_activation=None
    ):
        super().__init__()
        W = torch.empty(output_dim, input_dim, device=device, dtype=dtype)
        self.weight = nn.Parameter(W)
        match init_activation:
            case None:
                nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
            case 'relu':
                nn.init.kaiming_uniform_(self.weight, a=0, nonlinearity='relu')
            case 'leaky_relu' | 'prelu':
                nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5), nonlinearity='leaky_relu')
            case 'selu':
                alpha = 1.6732632423543772
                nn.init.normal_(self.weight, mean=0, std=math.sqrt(1 / input_dim) * math.sqrt(1 + alpha ** 2))
            case 'elu':
                gain=nn.init.calculate_gain('elu')
                nn.init.xavier_uniform_(self.weight, gain=gain)
            case 'gelu':
                gain=nn.init.calculate_gain('gelu')
                nn.init.xavier_uniform_(self.weight, gain=gain)
            case 'tanh':
                gain = nn.init.calculate_gain('tanh')
                nn.init.xavier_uniform_(self.weight, gain=gain)
            case 'sigmoid':
                gain = nn.init.calculate_gain('sigmoid')
                nn.init.xavier_uniform_(self.weight, gain=gain)
            
        if bias:
            b = torch.empty(output_dim, device=device, dtype=dtype)
            self.bias = nn.Parameter(b)
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)
        else:
            self.register_parameter("bias", None)

    def forward(self, x):
        out = x @ self.weight.T
        if self.bias is not None:
            out += self.bias
        return out

## Testing

In [108]:
# Problem sizes used by the comparison cells below.
batch_size = 4
in_dim = 5
out_dim = 3

# Options shared by both layer implementations.
bias = False
# Prefer the MPS backend when it is available; otherwise fall back to CPU so
# the notebook still runs on machines without it.
device = 'mps' if torch.backends.mps.is_available() else 'cpu'
dtype = torch.float32

In [109]:
# Random input batch of shape (batch_size, in_dim), then moved to the configured device and dtype.
x = torch.randn(batch_size, in_dim).to(device=device, dtype=dtype)

### Weights

In [110]:
# Reference layer from PyTorch and the custom layer, built with identical settings.
torch_linear = nn.Linear(
    in_features=in_dim, out_features=out_dim, bias=bias, device=device, dtype=dtype
)
linear = Linear(
    input_dim=in_dim, output_dim=out_dim, bias=bias, device=device, dtype=dtype
)

In [111]:
# List every parameter of the reference layer together with its shape.
for name, param in torch_linear.named_parameters():
    print(f"{name} {param.shape}")

weight torch.Size([3, 5])


In [112]:
# List every parameter of the custom layer together with its shape.
for name, param in linear.named_parameters():
    print(f"{name} {param.shape}")

weight torch.Size([3, 5])


### Output

In [104]:
# RNG seed passed to torch.manual_seed before each layer is constructed.
seed = 40

In [105]:
# Reset the RNG to the same seed before each construction so both layers
# draw their initial weights from the same random stream.
torch.manual_seed(seed)
torch_linear = nn.Linear(
    in_features=in_dim, out_features=out_dim, bias=bias, device=device, dtype=dtype
)

torch.manual_seed(seed)
linear = Linear(
    input_dim=in_dim, output_dim=out_dim, bias=bias, device=device, dtype=dtype
)

In [106]:
# Forward pass through the custom layer.
linear(x)

tensor([[-0.5044, -0.2027, -0.0019],
        [-0.2299,  0.4194, -0.1675],
        [ 0.6229,  0.3113,  0.2739],
        [ 0.5654, -1.1222, -0.3983]], device='mps:0', grad_fn=<MmBackward0>)

In [107]:
# Forward pass through the reference nn.Linear layer, for comparison with the custom layer's output.
torch_linear(x)

tensor([[-0.5044, -0.2027, -0.0019],
        [-0.2299,  0.4194, -0.1675],
        [ 0.6229,  0.3113,  0.2739],
        [ 0.5654, -1.1222, -0.3983]], device='mps:0',
       grad_fn=<LinearBackward0>)