# Transformer - Attention Layer

## 0. Imports

In [1]:
%load_ext jupyter_black

In [2]:
import sys

# Put the parent of the current working directory on the import path so the
# project-local `src.*` imports in this notebook can be resolved.
sys.path.append("..")

In [11]:
import numpy as np

import torch
import torch.nn as nn

from math import sqrt

In [3]:
from src.dataset.datamodule import ETTDataModule
from src.model.embeddings import DataEmbedding

## 1. ETT DataModule

In [4]:
# Keyword arguments for ETTDataModule (see that class for their meaning).
dm_params = dict(
    data_path="../data/ETT-small/ETTh1.csv",
    task="M",
    freq="h",
    target="OT",
    seq_len=96,
    label_len=48,
    pred_len=96,
    use_scaler=True,
    use_time_enc=True,
    batch_size=32,
)

# Build the data module and take a single training batch to feed the layers below.
dm = ETTDataModule(**dm_params)
train_dataloader = dm.train_dataloader()
batch = next(iter(train_dataloader))

## 2. Embedding Layer

In [5]:
# Keyword arguments for DataEmbedding. d_model sets the feature size of the
# embedded tensor; c_in presumably has to equal the channel count of
# batch["past_values"] — TODO confirm against DataEmbedding.
emb_params = dict(
    c_in=7,
    d_model=512,
    embed_type="time_features",
    freq="h",
    dropout=0.1,
)

embedding = DataEmbedding(**emb_params)

In [6]:
x = embedding(x_features=batch["past_time_features"], x=batch["past_values"])  # embed one batch

In [7]:
x.size()  # (batch, seq, d_model); printed below as torch.Size([32, 96, 512])

torch.Size([32, 96, 512])

## 3. Attention Layer

### 3.1 Line-by-Line Walkthrough

In [34]:
# Hyper-parameters for the step-by-step walkthrough; they mirror the
# constructor arguments of the Attention class in section 3.2.
d_model, n_heads = 512, 8
d_keys = d_values = None
attention_dropout = 0.1
output_attention = False  # not read in this walkthrough; kept for parity with the class

# Unset per-head sizes fall back to an even split of d_model across the heads.
d_keys = d_keys if d_keys else d_model // n_heads
d_values = d_values if d_values else d_model // n_heads

# Q/K/V projections map d_model -> n_heads * per-head size; the output projection
# maps the concatenated heads back to d_model. The creation order is deliberate:
# it fixes the order in which the random weight initialisation is drawn.
query_projection = nn.Linear(in_features=d_model, out_features=d_keys * n_heads)
key_projection = nn.Linear(in_features=d_model, out_features=d_keys * n_heads)
value_projection = nn.Linear(in_features=d_model, out_features=d_values * n_heads)
out_projection = nn.Linear(in_features=d_values * n_heads, out_features=d_model)
dropout = nn.Dropout(attention_dropout)

In [47]:
# Self-attention: queries, keys and values all start from the same embedded batch.
queries = keys = values = x

B, L, _ = queries.shape
_, S, _ = keys.shape
H = n_heads

# Project, then split the last axis into H heads of size d_keys (Q, K) / d_values (V).
queries = query_projection(queries).reshape(B, L, H, d_keys)
keys = key_projection(keys).reshape(B, S, H, d_keys)
values = value_projection(values).reshape(B, S, H, d_values)

In [48]:
# Per-head sizes: E for queries/keys, D for values.
B, L, H, E = queries.shape
_, S, _, D = values.shape

# No explicit scale is set here, so it falls back to 1 / sqrt(E).
scale = None
scale = scale if scale else 1.0 / sqrt(E)

# Unscaled scores, shape (B, H, L, S): one dot product per head and query/key pair.
scores = torch.einsum("blhe,bshe->bhls", queries, keys)

In [49]:
# Normalise the scaled scores over the key axis, then apply dropout to the weights.
A = torch.softmax(scale * scores, dim=-1)
A = dropout(A)

# Weighted sum of the values per head, then merge the H heads of size D into one axis.
V = torch.einsum("bhls,bshd->blhd", A, values).reshape(B, L, H * D)

In [52]:
# Map the merged heads (B, L, H * D) back to d_model with the output projection.
out = out_projection(V)

### 3.2 Attention Layer Class

In [72]:
class Attention(nn.Module):
    """Multi-head scaled dot-product attention with built-in Q/K/V projections.

    Inputs are linearly projected into ``n_heads`` heads. Attention weights
    ``softmax(scale * Q @ K^T)`` are computed per head, and dropout is applied
    to those weights. The weighted values of all heads are concatenated and
    mapped back to ``d_model`` by an output projection.

    Args:
        d_model: Feature size of the inputs and of the output.
        n_heads: Number of attention heads.
        d_keys: Per-head size of queries and keys. Defaults to
            ``d_model // n_heads``.
        d_values: Per-head size of values. Defaults to ``d_model // n_heads``.
        scale: Factor applied to the scores before the softmax. ``None``
            (default) means ``1 / sqrt(d_keys)``. Any explicit value, including
            ``0.0``, is used as given.
        attention_dropout: Dropout probability applied to the attention weights.
        output_attention: If True, ``forward`` also returns the attention
            weights; otherwise it returns ``None`` in their place.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        d_keys: int = None,
        d_values: int = None,
        scale: float = None,
        attention_dropout: float = 0.1,
        output_attention: bool = False,
    ):
        super().__init__()

        d_keys = d_keys or (d_model // n_heads)
        d_values = d_values or (d_model // n_heads)

        self.n_heads = n_heads
        # Stored exactly as given. The default is resolved per call in
        # ``forward`` so the module's configuration is never mutated.
        self.scale = scale
        self.output_attention = output_attention

        self.query_projection = nn.Linear(d_model, d_keys * n_heads)
        self.key_projection = nn.Linear(d_model, d_keys * n_heads)
        self.value_projection = nn.Linear(d_model, d_values * n_heads)
        self.out_projection = nn.Linear(d_values * n_heads, d_model)
        self.dropout = nn.Dropout(attention_dropout)

    def forward(self, queries: torch.Tensor, keys: torch.Tensor, values: torch.Tensor):
        """Attend from ``queries`` over ``keys`` / ``values``.

        Args:
            queries: Tensor of shape ``(B, L, d_model)``.
            keys: Tensor of shape ``(B, S, d_model)``.
            values: Tensor of shape ``(B, S, d_model)``. It must share ``S``
                with ``keys``.

        Returns:
            Tuple ``(out, attn)``:

            * ``out`` has shape ``(B, L, d_model)``.
            * ``attn`` is the ``(B, n_heads, L, S)`` weight tensor when
              ``output_attention`` is True, else ``None``. It is taken after
              dropout, so its rows sum to 1 only in eval mode or with zero
              dropout.
        """
        B, L, _ = queries.shape
        _, S, _ = keys.shape
        H = self.n_heads

        # Project and split the feature axis into H heads.
        queries = self.query_projection(queries).reshape(B, L, H, -1)
        keys = self.key_projection(keys).reshape(B, S, H, -1)
        values = self.value_projection(values).reshape(B, S, H, -1)

        # Resolve the scale locally. Writing it back to self.scale would
        # silently change the layer's configuration after the first call.
        # `is None` (not `or`) keeps an explicit 0.0.
        E = queries.shape[-1]
        scale = self.scale if self.scale is not None else 1.0 / sqrt(E)

        scores = torch.einsum("blhe,bshe->bhls", queries, keys)
        A = self.dropout(torch.softmax(scale * scores, dim=-1))
        V = torch.einsum("bhls,bshd->blhd", A, values)

        # Concatenate the heads and map back to d_model.
        out = self.out_projection(V.reshape(B, L, -1))

        return (out, A) if self.output_attention else (out, None)

In [73]:
# Same hyper-parameters as the 3.1 walkthrough, except output_attention=True
# so the layer also returns its attention weights.
attn_params = dict(
    d_model=512,
    n_heads=8,
    d_keys=None,
    d_values=None,
    scale=None,
    attention_dropout=0.1,
    output_attention=True,
)

attn_layer = Attention(**attn_params)

In [74]:
# Embed the batch again, then run self-attention: x is query, key and value.
x = embedding(x_features=batch["past_time_features"], x=batch["past_values"])

new_x, attn = attn_layer(x, x, x)

In [75]:
new_x.size()  # same (batch, seq, d_model) shape as the layer input x

torch.Size([32, 96, 512])

In [76]:
attn.size()  # (batch, n_heads, L, S): one weight per head and query/key pair

torch.Size([32, 8, 96, 96])