In this notebook, we implement the self-attention distilling part of the Informer encoder, seen [here](https://github.com/zhouhaoyi/Informer2020/blob/main/models/encoder.py#L5) in the original code.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np

Let `y` be the output of the ProbSparse attention + the rest of the encoder transformation logic.

In [None]:
y = torch.tensor(
    [[[-0.4279, -0.3064, -0.4279, -0.4279, -0.7816, -0.2918, -0.4279, -0.1300, -0.2762, -0.3435],
      [-0.3153, -0.2792, -0.3153, -0.3153, -0.0058, -0.1535, -0.3153,  0.2457, -0.4200, -0.4456],
      [ 0.2137,  0.1333,  0.2137,  0.2137,  0.3584,  0.0689,  0.2137, -0.0486,  0.1854,  0.2332],
      [-0.0565, -0.0637, -0.0565, -0.0565, -0.1396, -0.0519, -0.0565, -0.2497, -0.0881, -0.0858]
]])

In [None]:
batch_size, d_model, L_Q = y.shape
y.shape

torch.Size([1, 4, 10])

In [None]:
y.transpose(1, -1)

tensor([[[-0.4279, -0.3153,  0.2137, -0.0565],
         [-0.3064, -0.2792,  0.1333, -0.0637],
         [-0.4279, -0.3153,  0.2137, -0.0565],
         [-0.4279, -0.3153,  0.2137, -0.0565],
         [-0.7816, -0.0058,  0.3584, -0.1396],
         [-0.2918, -0.1535,  0.0689, -0.0519],
         [-0.4279, -0.3153,  0.2137, -0.0565],
         [-0.1300,  0.2457, -0.0486, -0.2497],
         [-0.2762, -0.4200,  0.1854, -0.0881],
         [-0.3435, -0.4456,  0.2332, -0.0858]]])

The records with similar values correspond to the non-chosen queries.

Of the 10 queries, 6 are active. Hence 4 are lazy, so we have 4 redundancies here.
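As a quick check of that count, here is a minimal sketch that groups identical rows of the transposed tensor. Treating "identical rows" as "lazy queries" is an assumption that only holds for this toy example; the names `rows`, `num_lazy` and `num_active` are illustrative.

```python
import torch

# Same values as `y` above, laid out as (L_Q, d_model): one row per query.
rows = torch.tensor([
    [-0.4279, -0.3153,  0.2137, -0.0565],
    [-0.3064, -0.2792,  0.1333, -0.0637],
    [-0.4279, -0.3153,  0.2137, -0.0565],
    [-0.4279, -0.3153,  0.2137, -0.0565],
    [-0.7816, -0.0058,  0.3584, -0.1396],
    [-0.2918, -0.1535,  0.0689, -0.0519],
    [-0.4279, -0.3153,  0.2137, -0.0565],
    [-0.1300,  0.2457, -0.0486, -0.2497],
    [-0.2762, -0.4200,  0.1854, -0.0881],
    [-0.3435, -0.4456,  0.2332, -0.0858],
])

# Group identical rows; a row shared by several queries is treated as "lazy".
unique_rows, counts = torch.unique(rows, dim=0, return_counts=True)
num_lazy = int(counts[counts > 1].sum())
num_active = rows.shape[0] - num_lazy
print(f"active queries: {num_active}, lazy queries: {num_lazy}")
```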

Our goal is now to remove these redundancies. We do so with a process called "self-attention distilling".

In chemistry, distillation is the process of extracting and concentrating a compound from a mixture.

In much the same way, we perform distillation here to extract the active queries from the lazy queries.

In practice, this is done by applying a convolution, batch normalization, an activation, and max pooling, in that order.

This is the mathematical representation of the distillation process:

$$
y = \text{MaxPool(ELU(BatchNorm(Conv(y))))}
$$
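The formula can be sketched as a single module. This is a minimal sketch, not the original implementation: the class name `DistillLayer` is hypothetical, and it assumes the input is already laid out as (batch, d_model, L), as `y` is in this notebook. The hyperparameters are the ones used in the cells that follow.

```python
import torch
import torch.nn as nn


class DistillLayer(nn.Module):
    """Hypothetical name: Conv -> BatchNorm -> ELU -> MaxPool on (batch, d_model, L)."""

    def __init__(self, d_model):
        super().__init__()
        self.down_conv = nn.Conv1d(
            d_model, d_model, kernel_size=3, padding=1, padding_mode="circular"
        )
        self.norm = nn.BatchNorm1d(d_model)
        self.activation = nn.ELU()
        self.max_pool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)

    def forward(self, y):
        y = self.down_conv(y)    # shape unchanged: (batch, d_model, L)
        y = self.norm(y)         # shape unchanged
        y = self.activation(y)   # shape unchanged
        return self.max_pool(y)  # time dimension roughly halved


layer = DistillLayer(d_model=4)
out = layer(torch.randn(1, 4, 10))
print(out.shape)
```

For a 1 x 4 x 10 input, the output shape is 1 x 4 x 5, matching the step-by-step result derived in the rest of the notebook.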

## 1 Conv1D



![Conv 1D](https://discuss.pytorch.org/uploads/default/original/3X/5/f/5faf64f4eb86dd37121774c720877b1d44f7f617.gif)

^ Source: https://discuss.pytorch.org/t/1d-convolution-on-1d-data/54661/11

In [None]:
downConv = nn.Conv1d(
    in_channels=d_model,
    out_channels=d_model,
    kernel_size=3,
    padding=1,
    padding_mode='circular'
)
downConv

Conv1d(4, 4, kernel_size=(3,), stride=(1,), padding=(1,), padding_mode=circular)

In [None]:
y = downConv(y)
y.transpose(1, -1)

tensor([[[ 0.1781, -0.2201, -0.3053, -0.0687],
         [ 0.1251, -0.2603, -0.3470, -0.0435],
         [ 0.1565, -0.2149, -0.3138, -0.0751],
         [ 0.0428, -0.4520, -0.3234, -0.1159],
         [ 0.3373, -0.1578, -0.3097, -0.1270],
         [ 0.1785, -0.2392, -0.4275, -0.0581],
         [ 0.2518, -0.1709, -0.1564, -0.2870],
         [ 0.1757, -0.1858, -0.4516,  0.0660],
         [ 0.2008, -0.0368, -0.2885, -0.1068],
         [ 0.1144, -0.2748, -0.2837, -0.0460]]], grad_fn=<TransposeBackward0>)

In [None]:
y.shape

torch.Size([1, 4, 10])

**What is happening in convolution 1D?**
- We have 4 (C_out) kernels, each of shape 3 x 4 (kernel_size x C_in).
- For the given 1 x 4 x 10 input, we slide one 3 x 4 kernel along the time dimension (10).
- Because of the circular padding of 1, the first window is centered on the first timestep: it covers the last timestep (wrapped around) plus the first 2 timesteps.
- Multiplying the 3 x 4 kernel element-wise with the 3 x 4 slice of the input and summing (plus a bias) gives a single number.
- We then slide the kernel one step at a time and end up with a single number each time.
- Once the kernel has slid to the end of the time dimension, we have a 1 x 10 vector.
- But there should be 4 output channels (C_out = 4).
- Performing the same convolution along the time dimension with each of the 4 kernels gives 4 vectors of size 1 x 10, i.e. a 4 x 10 matrix.

**Effectively as a result of this operation**:
- The shape is not changed.
- Each vector now carries local context from its neighbors because the kernel size is 3.
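To make the sliding-window description concrete, the sketch below reproduces `nn.Conv1d` with circular padding by hand: pad circularly, cut out every length-3 window with `unfold`, and contract each window against each kernel. The variable names (`conv`, `windows`, `manual`, `matches`) and the random input are illustrative, not taken from the original code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
conv = nn.Conv1d(4, 4, kernel_size=3, padding=1, padding_mode="circular")
x = torch.randn(1, 4, 10)

# Wrap one timestep around each end, then cut out every length-3 window.
x_pad = F.pad(x, (1, 1), mode="circular")   # (1, 4, 12)
windows = x_pad.unfold(2, 3, 1)             # (1, C_in, 10, 3)

# Each output value sums window * kernel over C_in and the 3 kernel positions,
# then adds the bias of that output channel.
manual = torch.einsum("bclk,ock->bol", windows, conv.weight) + conv.bias[None, :, None]

matches = torch.allclose(manual, conv(x), atol=1e-5)
print(manual.shape, matches)
```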

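Here is a minimal implementation of Conv1D. Note that it pads with zeros (the default of `F.pad`), not circularly as `downConv` does above.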

In [None]:
class CustomConv1d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super(CustomConv1d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size))
        self.bias = nn.Parameter(torch.randn(out_channels))

    def forward(self, x):

        if self.padding > 0:
            x = F.pad(x, (self.padding, self.padding))

        batch_size, in_channels, sequence_length = x.shape
        out_length = (sequence_length - self.kernel_size) // self.stride + 1

        output = torch.zeros((batch_size, self.out_channels, out_length))

        # Perform the convolution operation
        for b in range(batch_size):
            for o in range(self.out_channels):
                for i in range(0, sequence_length - self.kernel_size + 1, self.stride):
                    region = x[b, :, i:i+self.kernel_size]
                    output[b, o, i // self.stride] = torch.sum(region * self.weight[o, :, :]) + self.bias[o]

        return output

kernel_size = 3
in_channels = 4
out_channels = 4

custom_conv1d = CustomConv1d(
    in_channels=in_channels,
    out_channels=out_channels,
    kernel_size=kernel_size,
    stride=1,
    padding=1
)
x = torch.randn(1, in_channels, 10) # Example input with (batch_size, in_channels, sequence_length)
print("Input Shape:", x.shape)
x = custom_conv1d(x)
print("Output Shape:", x.shape)
print("Output:", x.transpose(1, -1))

Input Shape: torch.Size([1, 4, 10])
Output Shape: torch.Size([1, 4, 10])
Output: tensor([[[ 3.0105, -0.3359,  4.7515,  1.9616],
         [-4.4009,  1.3132,  0.4764, -2.2379],
         [-1.0181,  1.4909,  1.6197,  1.8993],
         [ 3.6055, -1.2282,  1.0855, -2.3256],
         [ 1.2337, -0.3768, -2.1121, -0.8444],
         [-1.1882, -0.1815, -2.2545,  2.6530],
         [ 7.1665, -0.6749, -1.3431,  5.7508],
         [ 0.1426, -2.5285, -1.1750,  0.2404],
         [-1.7602,  2.9976,  2.5532,  4.9390],
         [ 1.5305, -0.7658, -0.1509, -3.7880]]], grad_fn=<TransposeBackward0>)


## 2 BatchNorm1D

In [None]:
norm = nn.BatchNorm1d(d_model)
norm

BatchNorm1d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

In [None]:
y = norm(y)
y

tensor([[[ 0.0260, -0.6752, -0.2591, -1.7654,  2.1353,  0.0313,  1.0030,
          -0.0055,  0.3272, -0.8175],
         [ 0.0113, -0.3901,  0.0636, -2.3083,  0.6347, -0.1790,  0.5033,
           0.3547,  1.8452, -0.5353],
         [ 0.2003, -0.3423,  0.0891, -0.0349,  0.1433, -1.3888,  2.1361,
          -1.7024,  0.4185,  0.4810],
         [ 0.2078,  0.5062,  0.1316, -0.3516, -0.4835,  0.3328, -2.3792,
           1.8039, -0.2442,  0.4762]]], grad_fn=<NativeBatchNormBackward0>)

In [None]:
y.transpose(1, -1)

tensor([[[ 0.0260,  0.0113,  0.2003,  0.2078],
         [-0.6752, -0.3901, -0.3423,  0.5062],
         [-0.2591,  0.0636,  0.0891,  0.1316],
         [-1.7654, -2.3083, -0.0349, -0.3516],
         [ 2.1353,  0.6347,  0.1433, -0.4835],
         [ 0.0313, -0.1790, -1.3888,  0.3328],
         [ 1.0030,  0.5033,  2.1361, -2.3792],
         [-0.0055,  0.3547, -1.7024,  1.8039],
         [ 0.3272,  1.8452,  0.4185, -0.2442],
         [-0.8175, -0.5353,  0.4810,  0.4762]]], grad_fn=<TransposeBackward0>)

In [None]:
y.shape

torch.Size([1, 4, 10])

In [None]:
class CustomBatchNorm1d(nn.Module):
    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
        super(CustomBatchNorm1d, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine

        # Initialize learnable parameters gamma and beta if affine is True
        if self.affine:
            self.gamma = nn.Parameter(torch.ones(num_features))
            self.beta = nn.Parameter(torch.zeros(num_features))
        else:
            self.register_parameter('gamma', None)
            self.register_parameter('beta', None)

        # Initialize running mean and running variance
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))

    def forward(self, x):
        if self.training:
            # Compute batch mean and variance
            batch_mean = x.mean(dim=[0, 2])
            batch_var = x.var(dim=[0, 2], unbiased=False)

            # Update running statistics; no_grad keeps them out of the autograd graph
            with torch.no_grad():
                self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * batch_mean
                self.running_var = (1 - self.momentum) * self.running_var + self.momentum * batch_var
        else:
            batch_mean = self.running_mean
            batch_var = self.running_var

        # Normalize input
        x_hat = (x - batch_mean[None, :, None]) / torch.sqrt(batch_var[None, :, None] + self.eps)

        # Apply learnable parameters if affine is True
        if self.affine:
            x_hat = self.gamma[None, :, None] * x_hat + self.beta[None, :, None]

        return x_hat

# Create a custom BatchNorm1d layer for 4 features
batch_norm = CustomBatchNorm1d(4)
batch_norm.train() # Set to training mode
optimizer = optim.SGD(batch_norm.parameters(), lr=0.01) # Create an optimizer

# Example training loop
for i in range(100):
    x = torch.randn(10, 4, 10) + torch.randn(10, 4, 10) * 0.5  # Add some noise to the tensor
    optimizer.zero_grad()  # Zero the gradients
    output = batch_norm(x)
    loss = ((output - x) ** 2).mean()  # Mean squared error for demonstration
    loss.backward()  # Backpropagate the loss
    optimizer.step()  # Update the parameters

batch_norm.eval()  # Switch to evaluation mode

# Example input tensor with shape 1 x 4 x 10 (batch_size, num_features, sequence_length)
x = torch.randn(1, 4, 10) + torch.randn(1, 4, 10) * 10  # Add some noise to the tensor
print(f"Input x = {x.transpose(1, -1)}")
# Apply custom Batch Normalization
x_normalized = batch_norm(x)
print("Normalized Output:", x_normalized.transpose(1, -1))
print(f"Learned parameters: gamma = {batch_norm.gamma}, beta = {batch_norm.beta}")

Input x = tensor([[[ -3.1063,   0.1304,  17.9303,  -8.4741],
         [ 11.1877,   6.3469,  -9.0957,  -1.5174],
         [-14.7600,  10.8687,  26.7499,   7.0928],
         [  1.1092,  -3.5419,   2.1124,   1.7924],
         [-19.9245,   2.3600,  -3.9004,   9.1631],
         [ 14.7265,   0.9318,  -2.3177,   3.8916],
         [  2.2570,  12.1360,  -8.3831,  -0.6496],
         [  8.5910,   8.0781,   6.2588,   2.6054],
         [  3.8296,   8.5246,  10.7348, -11.1375],
         [-17.7490, -11.3412,   7.7679,  15.6306]]])
Normalized Output: tensor([[[ -2.8880,   0.1589,  16.7962,  -7.8655],
         [ 10.3153,   6.0642,  -8.5759,  -1.4109],
         [-13.6525,  10.3597,  25.0761,   6.5779],
         [  1.0058,  -3.3296,   1.9464,   1.6601],
         [-18.4229,   2.2769,  -3.6985,   8.4988],
         [ 13.5840,   0.9202,  -2.2126,   3.6078],
         [  2.0660,  11.5635,  -7.9068,  -0.6057],
         [  7.9168,   7.7088,   5.8390,   2.4144],
         [  3.5187,   8.1330,  10.0411, -10.3367],


**What is BatchNorm 1D doing?**
- During training, we normalize each feature (channel) with the mean and variance of the current batch, and update running estimates of both.
- During inference, we normalize the input using these running estimates instead.

**Effectively:**
- It does not change the shape of the input tensor.
- Overall, it makes training faster and more stable.
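The difference between the two modes can be checked directly on `nn.BatchNorm1d`. The sketch below assumes a freshly created layer (gamma = 1, beta = 0, running mean 0, running variance 1) and the default momentum of 0.1; the input scale and shift are arbitrary. PyTorch updates the running variance with the unbiased batch variance, which is why `unbiased=True` is used for the expected value.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(4)
x = torch.randn(8, 4, 10) * 3.0 + 2.0   # (batch, channels, time)

# Training mode: normalize with the statistics of this batch.
bn.train()
out_train = bn(x)
train_mean = out_train.mean(dim=(0, 2))                 # close to 0 per channel
train_var = out_train.var(dim=(0, 2), unbiased=False)   # close to 1 per channel

# One update step of the running estimates, starting from mean 0 and variance 1.
expected_running_mean = 0.9 * 0.0 + 0.1 * x.mean(dim=(0, 2))
expected_running_var = 0.9 * 1.0 + 0.1 * x.var(dim=(0, 2), unbiased=True)

# Eval mode: normalize with the running estimates instead of batch statistics.
bn.eval()
out_eval = bn(x)
manual_eval = (x - bn.running_mean[None, :, None]) / torch.sqrt(
    bn.running_var[None, :, None] + bn.eps
)

print(train_mean, train_var)
print(torch.allclose(out_eval, manual_eval, atol=1e-5))
```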

## 3 ELU

![Elu](https://armandolivares.tech/wp-content/uploads/2022/09/elu-1.png)

In [None]:
activation = nn.ELU()
activation

ELU(alpha=1.0)

In [None]:
np.exp(-0.6752) - 1

-0.49094540049412283

In [None]:
y.transpose(1, -1)

tensor([[[ 0.0260,  0.0113,  0.2003,  0.2078],
         [-0.6752, -0.3901, -0.3423,  0.5062],
         [-0.2591,  0.0636,  0.0891,  0.1316],
         [-1.7654, -2.3083, -0.0349, -0.3516],
         [ 2.1353,  0.6347,  0.1433, -0.4835],
         [ 0.0313, -0.1790, -1.3888,  0.3328],
         [ 1.0030,  0.5033,  2.1361, -2.3792],
         [-0.0055,  0.3547, -1.7024,  1.8039],
         [ 0.3272,  1.8452,  0.4185, -0.2442],
         [-0.8175, -0.5353,  0.4810,  0.4762]]], grad_fn=<TransposeBackward0>)

In [None]:
y = activation(y)
y.transpose(1, -1)

tensor([[[ 0.0260,  0.0113,  0.2003,  0.2078],
         [-0.4910, -0.3230, -0.2898,  0.5062],
         [-0.2282,  0.0636,  0.0891,  0.1316],
         [-0.8289, -0.9006, -0.0343, -0.2964],
         [ 2.1353,  0.6347,  0.1433, -0.3834],
         [ 0.0313, -0.1639, -0.7506,  0.3328],
         [ 1.0030,  0.5033,  2.1361, -0.9074],
         [-0.0055,  0.3547, -0.8177,  1.8039],
         [ 0.3272,  1.8452,  0.4185, -0.2167],
         [-0.5585, -0.4145,  0.4810,  0.4762]]], grad_fn=<TransposeBackward0>)

In [None]:
class CustomELU(nn.Module):
    def __init__(self, alpha=1.0):
        super(CustomELU, self).__init__()
        self.alpha = alpha

    def forward(self, x):
        return torch.where(x >= 0, x, self.alpha * (torch.exp(x) - 1))

elu = CustomELU(alpha=1.0)
x = torch.randn(1, 4, 10) # Example input with (batch_size, in_channels, sequence_length)
print("Input:", x.transpose(1, -1))
x = elu(x)
print("Output:", x.transpose(1, -1))

Input: tensor([[[ 1.3933, -0.1014,  0.1263, -1.4776],
         [-0.6408,  1.3629,  1.3665, -1.3499],
         [-0.4621,  0.7703,  1.5326, -0.4740],
         [-0.0628, -0.3342, -2.3160,  0.1374],
         [-0.1057, -0.7446, -0.0148, -1.0766],
         [ 1.2083,  1.0100,  0.5466, -0.8150],
         [ 1.3274, -1.4542,  2.2502, -1.9747],
         [-0.1835,  1.1149, -1.2386,  1.8667],
         [ 0.2250, -0.6249,  0.7311,  1.9726],
         [-0.9925, -0.3995, -0.1549,  0.7506]]])
Output: tensor([[[ 1.3933, -0.0965,  0.1263, -0.7718],
         [-0.4731,  1.3629,  1.3665, -0.7407],
         [-0.3701,  0.7703,  1.5326, -0.3775],
         [-0.0608, -0.2841, -0.9013,  0.1374],
         [-0.1003, -0.5251, -0.0147, -0.6592],
         [ 1.2083,  1.0100,  0.5466, -0.5573],
         [ 1.3274, -0.7664,  2.2502, -0.8612],
         [-0.1676,  1.1149, -0.7102,  1.8667],
         [ 0.2250, -0.4647,  0.7311,  1.9726],
         [-0.6294, -0.3294, -0.1435,  0.7506]]])


**What is ELU doing?**
- For each element x, it returns x if x >= 0 and alpha * (exp(x) - 1) otherwise.

**Effectively**
- The shape does not change
- Like other activation functions, it introduces non-linearity so the network can learn complex patterns
- Unlike ReLU, it has a nonzero gradient for negative inputs, which helps avoid dead neurons.
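The gradient claim can be illustrated with autograd. This is a small sketch on three hand-picked inputs, assuming alpha = 1 (the `nn.ELU` default): for negative inputs the ELU gradient is exp(x), while the ReLU gradient is exactly 0.

```python
import torch
import torch.nn.functional as F

x = torch.tensor([-2.0, -0.5, 1.0], requires_grad=True)

# Gradient of ELU: alpha * exp(x) for x < 0, and 1 for x >= 0 (alpha = 1 here).
(elu_grad,) = torch.autograd.grad(F.elu(x).sum(), x)
# Gradient of ReLU: 0 for x < 0, and 1 for x > 0.
(relu_grad,) = torch.autograd.grad(F.relu(x).sum(), x)

print("ELU grad: ", elu_grad)
print("ReLU grad:", relu_grad)
```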

## 4 MaxPooling

In [None]:
maxPool = nn.MaxPool1d(
    kernel_size=3,
    stride=2,
    padding=1
)
maxPool

MaxPool1d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)

In [None]:
y.transpose(1, -1)

tensor([[[ 0.0260,  0.0113,  0.2003,  0.2078],
         [-0.4910, -0.3230, -0.2898,  0.5062],
         [-0.2282,  0.0636,  0.0891,  0.1316],
         [-0.8289, -0.9006, -0.0343, -0.2964],
         [ 2.1353,  0.6347,  0.1433, -0.3834],
         [ 0.0313, -0.1639, -0.7506,  0.3328],
         [ 1.0030,  0.5033,  2.1361, -0.9074],
         [-0.0055,  0.3547, -0.8177,  1.8039],
         [ 0.3272,  1.8452,  0.4185, -0.2167],
         [-0.5585, -0.4145,  0.4810,  0.4762]]], grad_fn=<TransposeBackward0>)

In [None]:
y = maxPool(y)
y.transpose(1, -1)

tensor([[[ 0.0260,  0.0113,  0.2003,  0.5062],
         [-0.2282,  0.0636,  0.0891,  0.5062],
         [ 2.1353,  0.6347,  0.1433,  0.3328],
         [ 1.0030,  0.5033,  2.1361,  1.8039],
         [ 0.3272,  1.8452,  0.4810,  1.8039]]], grad_fn=<TransposeBackward0>)

In [None]:
y.shape

torch.Size([1, 4, 5])

**What happens during max pooling 1D?**
- kernel_size = 3, stride = 2, padding = 1
- A window of 3 (kernel size) is slid along the time dimension of each of the 4 channels.
- Between windows, the start position advances by 2 timesteps (stride).
- One padded position is added before and after the time dimension (padding). PyTorch pads with negative infinity, so a padded position never wins the max.
- With the 1 x 4 x 10 input, the first window on the first channel covers the left pad and the first 2 timesteps; we take the max and slide the window over. This channel yields 1 x 5 max values, since floor((10 + 2*1 - 3) / 2) + 1 = 5.
- Since there are 4 channels, we repeat this to get 4 of these 1 x 5 vectors.
- The result is a 1 x 4 x 5 tensor.

**Effectively**
- Max pooling halves the time dimension.
- Max pooling keeps the largest activation in each window, which is intended to retain the more "active" queries and leave out the more redundant "lazy" ones.
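The window arithmetic above can be verified against `nn.MaxPool1d`. The sketch below computes the output length from the usual formula and rebuilds the pooled values by hand; the random input and the names `L_out`, `windows`, `manual` and `reference` are illustrative. The padding value of negative infinity mirrors what `nn.MaxPool1d` does implicitly.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
kernel_size, stride, padding = 3, 2, 1
x = torch.randn(1, 4, 10)

# Output length: floor((L + 2 * padding - kernel_size) / stride) + 1
L = x.shape[-1]
L_out = (L + 2 * padding - kernel_size) // stride + 1

# Pad with -inf so padded positions can never win the max.
x_pad = F.pad(x, (padding, padding), value=float("-inf"))
windows = x_pad.unfold(2, kernel_size, stride)   # (1, 4, L_out, kernel_size)
manual = windows.max(dim=-1).values              # (1, 4, L_out)

reference = nn.MaxPool1d(kernel_size, stride=stride, padding=padding)(x)
print(L_out, torch.equal(manual, reference))
```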

In [None]:
class CustomMaxPool1d(nn.Module):
    def __init__(self, kernel_size, stride=None, padding=0):
        super(CustomMaxPool1d, self).__init__()
        self.kernel_size = kernel_size
        self.stride = stride if stride is not None else kernel_size
        self.padding = padding

    def forward(self, x):

        if self.padding > 0:
            # Pad with -inf (as nn.MaxPool1d does) so padding never wins the max
            x = F.pad(x, (self.padding, self.padding), value=float('-inf'))

        batch_size, in_channels, sequence_length = x.shape
        out_length = (sequence_length - self.kernel_size) // self.stride + 1
        output = torch.zeros((batch_size, in_channels, out_length))

        # Perform the max pooling operation manually
        for b in range(batch_size):
            for c in range(in_channels):
                for i in range(0, sequence_length - self.kernel_size + 1, self.stride):
                    region = x[b, c, i:i+self.kernel_size]
                    output[b, c, i // self.stride] = torch.max(region)

        return output

# Create a MaxPool1d layer
maxpool1d = CustomMaxPool1d(kernel_size=2, stride=2, padding=0)

# Example input tensor with shape (batch_size, in_channels, sequence_length)
x = torch.tensor(
    [[[1.0, 2.0, 3.0, 4.0, 5.0],
      [5.0, 4.0, 3.0, 2.0, 1.0]]]
)
print(f"Input  ({x.shape}) : {x} ")
x = maxpool1d(x)
print(f"Output ({x.shape}) :", x)

Input  (torch.Size([1, 2, 5])) : tensor([[[1., 2., 3., 4., 5.],
         [5., 4., 3., 2., 1.]]]) 
Output (torch.Size([1, 2, 2])) : tensor([[[2., 4.],
         [5., 3.]]])
