# Transformer - Encoder, Decoder layer

## 0. imports

In [1]:
%load_ext jupyter_black

In [2]:
import sys

sys.path.append("..")

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [4]:
from src.dataset import ETTDataModule
from src.model import DataEmbedding
from src.model import Attention

## 1. Previous setup (data module, embedding, attention)

In [5]:
# Configuration for the data module that serves the ETTh1 CSV.
# NOTE(review): the semantics of "task", "freq", "target", "use_scaler" and
# "use_time_enc" are defined by ETTDataModule, which is not visible here — confirm there.
dm_params = {
    "data_path": "../data/ETT-small/ETTh1.csv",
    "task": "M",
    "freq": "h",
    "target": "OT",
    "seq_len": 96,  # matches the encoder sequence length printed later (96)
    "label_len": 48,  # label_len + pred_len = 144, the decoder length printed later
    "pred_len": 96,
    "use_scaler": True,
    "use_time_enc": True,
    "batch_size": 32,
}


dm = ETTDataModule(**dm_params)

In [6]:
# Configuration for the encoder-side input embedding.
# NOTE(review): "c_in" is presumably the number of input variables and must agree
# with the data module's feature count — confirm against DataEmbedding.
emb_params = {
    "c_in": 7,
    "d_model": 512,  # embedding width; the layers below are built with the same value
    "embed_type": "time_features",
    "freq": "h",
    "dropout": 0.1,
}

embedding = DataEmbedding(**emb_params)

In [7]:
# Configuration shared by every Attention module built in this notebook.
# NOTE(review): how None is interpreted for "d_keys", "d_values" and "scale" is
# decided inside Attention — confirm there.
attn_params = {
    "d_model": 512,
    "n_heads": 8,
    "d_keys": None,
    "d_values": None,
    "scale": None,
    "attention_dropout": 0.1,
    "output_attention": True,  # the cells below unpack an (output, attention) pair
}

attn_layer = Attention(**attn_params)

In [8]:
# Pull one batch from the training dataloader; every later cell reuses it.
train_dataloader = dm.train_dataloader()
batch = next(iter(train_dataloader))

In [9]:
# Embed the encoder input (values plus time features).
x = embedding(x=batch["past_values"], x_features=batch["past_time_features"])

# Self-attention: the same tensor is used as queries, keys and values.
new_x, attn = attn_layer(queries=x, keys=x, values=x)

## 2. Encoder Layer

### 2.1 line by line

In [10]:
# Hyper-parameters for the hand-written encoder layer below.
d_model = 512
dropout = 0.1
activation = "gelu"

# Hidden width of the position-wise feed-forward.
d_ff = 2048

In [11]:
# Building blocks of one encoder layer.
# Fall back to 4 * d_model when d_ff is falsy (None or 0).
d_ff = d_ff or 4 * d_model
# Kernel-size-1 convolutions act as the position-wise feed-forward.
conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
norm1 = nn.LayerNorm(d_model)
norm2 = nn.LayerNorm(d_model)
# NOTE: `dropout` and `activation` are rebound here from config values (float / str)
# to callables, so this cell must be preceded by the cell that sets the raw values.
dropout = nn.Dropout(dropout)
activation = F.relu if activation == "relu" else F.gelu

In [12]:
# Residual connection around the attention output, then LayerNorm.
x = x + dropout(new_x)
# Keep the normalized tensor in `x` for the second residual; `y` feeds the feed-forward.
y = x = norm1(x)

# Conv1d convolves over the last axis with channels at index 1, so the last
# axis is swapped into index 1 before the convolutions and swapped back after.
y = dropout(activation(conv1(y.transpose(-1, 1))))
y = dropout(conv2(y).transpose(-1, 1))
# Second residual connection and LayerNorm.
out = norm2(x + y)

### 2.2 EncoderLayer class

In [13]:
class EncoderLayer(nn.Module):
    """One encoder layer: self-attention, then a position-wise feed-forward.

    Each sub-block is wrapped in a residual connection followed by LayerNorm.
    The feed-forward is built from two kernel-size-1 ``Conv1d`` layers.

    Args:
        attention: module called as ``attention(queries=, keys=, values=)`` that
            returns an ``(output, attention)`` pair.
        d_model: size of the last input dimension.
        d_ff: hidden width of the feed-forward; ``4 * d_model`` when falsy.
        dropout: probability for the dropout applied after attention and inside
            the feed-forward.
        activation: ``"relu"`` selects ``F.relu``; anything else selects ``F.gelu``.
    """

    def __init__(
        self,
        attention: nn.Module,
        d_model: int,
        d_ff: int = None,
        dropout: float = 0.1,
        activation: str = "relu",
    ):
        super().__init__()

        hidden_width = d_ff if d_ff else 4 * d_model
        self.attention = attention
        self.conv1 = nn.Conv1d(d_model, hidden_width, kernel_size=1)
        self.conv2 = nn.Conv1d(hidden_width, d_model, kernel_size=1)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        if activation == "relu":
            self.activation = F.relu
        else:
            self.activation = F.gelu

    def forward(self, x: torch.Tensor):
        """Return ``(layer_output, attention)`` for the input ``x``."""
        # Self-attention with residual connection and LayerNorm.
        attended, attn = self.attention(queries=x, keys=x, values=x)
        residual = self.norm1(x + self.dropout(attended))

        # Feed-forward: Conv1d expects channels at index 1, so swap the last
        # axis in before the convolutions and back out afterwards.
        hidden = self.conv1(residual.transpose(-1, 1))
        hidden = self.dropout(self.activation(hidden))
        ff_out = self.dropout(self.conv2(hidden).transpose(-1, 1))

        return self.norm2(residual + ff_out), attn

In [14]:
# Constructor arguments for EncoderLayer.
# NOTE: the Attention module is instantiated once, here. Every EncoderLayer built
# from this same dict receives that same module object (EncoderLayer stores the
# argument as-is), so such layers share attention weights.
enc_layer_params = {
    "attention": Attention(**attn_params),
    "d_model": 512,
    "d_ff": 2048,
    "dropout": 0.1,
    "activation": "gelu",
}

enc_layer = EncoderLayer(**enc_layer_params)

In [15]:
enc_layer

EncoderLayer(
  (attention): Attention(
    (query_projection): Linear(in_features=512, out_features=512, bias=True)
    (key_projection): Linear(in_features=512, out_features=512, bias=True)
    (value_projection): Linear(in_features=512, out_features=512, bias=True)
    (out_projection): Linear(in_features=512, out_features=512, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (conv1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,))
  (conv2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,))
  (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  (dropout): Dropout(p=0.1, inplace=False)
)

In [16]:
# Re-embed the batch and run it through the single encoder layer.
x = embedding(x=batch["past_values"], x_features=batch["past_time_features"])
out, attn = enc_layer(x)

In [17]:
out.shape

torch.Size([32, 96, 512])

In [48]:
attn.shape

torch.Size([32, 8, 96, 96])

## 3. Encoder block

### 3.1 line by line

In [18]:
num_layers = 2
norm_layer = None

# `enc_layer_params["attention"]` is one already-built Attention instance, so
# reusing the dict unchanged would make every layer hold the same module and
# share its weights. Build a fresh Attention for each layer instead.
encoder_layers = nn.ModuleList(
    [
        EncoderLayer(**{**enc_layer_params, "attention": Attention(**attn_params)})
        for _ in range(num_layers)
    ]
)

In [19]:
x = embedding(x=batch["past_values"], x_features=batch["past_time_features"])

# Feed the output of each layer into the next and keep every attention result.
# NOTE: the loop rebinds the notebook-level name `enc_layer`; after this cell it
# refers to the last layer of `encoder_layers`.
attns = []
for enc_layer in encoder_layers:
    x, attn = enc_layer(x)
    attns.append(attn)

In [20]:
# Optional final normalization of the encoder output; skipped when norm_layer is None.
if norm_layer is not None:
    x = norm_layer(x)

### 3.2 Encoder class

In [21]:
class Encoder(nn.Module):
    """Stack of encoder layers with an optional final normalization.

    Args:
        enc_layers: layers applied in order; each is called as ``layer(x)`` and
            must return an ``(output, attention)`` pair.
        norm_layer: optional module applied to the output of the last layer.
    """

    def __init__(self, enc_layers: list[nn.Module], norm_layer: nn.Module = None):
        super().__init__()

        self.enc_layers = nn.ModuleList(enc_layers)
        self.norm_layer = norm_layer

    def forward(self, x: torch.Tensor):
        """Return ``(output, attns)`` where ``attns`` has one entry per layer."""
        hidden = x
        attns = []
        for layer in self.enc_layers:
            hidden, layer_attn = layer(hidden)
            attns.append(layer_attn)

        if self.norm_layer is None:
            return hidden, attns

        return self.norm_layer(hidden), attns

In [22]:
d_model = 512
num_enc_layers: int = 2

# `enc_layer_params["attention"]` is one already-built Attention instance, so
# reusing the dict unchanged would make every layer hold the same module and
# share its weights. Build a fresh Attention for each layer instead.
encoder = Encoder(
    enc_layers=[
        EncoderLayer(**{**enc_layer_params, "attention": Attention(**attn_params)})
        for _ in range(num_enc_layers)
    ],
    norm_layer=nn.LayerNorm(d_model),
)

In [23]:
# Re-embed the batch and run the full encoder stack; `attns` holds one entry per layer.
x = embedding(x=batch["past_values"], x_features=batch["past_time_features"])
out, attns = encoder(x)

In [24]:
out.shape

torch.Size([32, 96, 512])

In [25]:
len(attns)

2

## 4. Decoder Layer

### 4.1 line by line

In [26]:
pred_len = 96
label_len = 48

x = embedding(x=batch["past_values"], x_features=batch["past_time_features"])
# NOTE(review): this runs a single EncoderLayer, not the stacked `encoder`, and
# `enc_layer` is whichever object the name was last bound to (the earlier
# `for enc_layer in encoder_layers` loop rebinds it) — confirm this is intended.
enc_out, attn = enc_layer(x)

# decoder input
# Zero placeholder shaped like the last `pred_len` steps of future_values ...
dec_inp = torch.zeros_like(batch["future_values"][:, -pred_len:, :]).float()
# ... appended along the time axis (dim=1) to the first `label_len` steps.
dec_inp = torch.cat([batch["future_values"][:, :label_len, :], dec_inp], dim=1).float()

In [27]:
# dec embedding
# Same settings as the encoder embedding, but a separate DataEmbedding instance
# (this cell also rebinds the notebook-level name `emb_params`).
emb_params = {
    "c_in": 7,
    "d_model": 512,
    "embed_type": "time_features",
    "freq": "h",
    "dropout": 0.1,
}

dec_embedding = DataEmbedding(**emb_params)

In [28]:
# self & cross attention
# Same values as the earlier attention config; this cell rebinds `attn_params`.
attn_params = {
    "d_model": 512,
    "n_heads": 8,
    "d_keys": None,
    "d_values": None,
    "scale": None,
    "attention_dropout": 0.1,
    "output_attention": True,
}

# Two independent modules: one attends over the decoder input itself, the
# other takes its keys/values from the encoder output (see the cell that uses them).
self_attention = Attention(**attn_params)
cross_attention = Attention(**attn_params)

In [29]:
# Reset the raw config values first: `dropout` and `activation` were rebound to
# callables by an earlier cell and are rebound again at the bottom of this one.
dropout = 0.1
activation = "gelu"


# Building blocks of one decoder layer. `d_model` and `d_ff` are reused from
# earlier cells. Three LayerNorms: one per sub-block (self-attention,
# cross-attention, feed-forward).
conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
norm1 = nn.LayerNorm(d_model)
norm2 = nn.LayerNorm(d_model)
norm3 = nn.LayerNorm(d_model)
dropout = nn.Dropout(dropout)
activation = F.relu if activation == "relu" else F.gelu

In [30]:
# Embed the decoder input together with its time features.
x = dec_embedding(dec_inp, batch["future_time_features"])

# Self-attention sub-block: residual + LayerNorm ([0] keeps only the output tensor).
x = x + dropout(self_attention(queries=x, keys=x, values=x)[0])
x = norm1(x)

# Cross-attention sub-block: queries from the decoder, keys/values from the encoder output.
x = x + dropout(cross_attention(queries=x, keys=enc_out, values=enc_out)[0])
# Keep the normalized tensor in `x` for the last residual; `y` feeds the feed-forward.
y = x = norm2(x)
# Feed-forward via kernel-size-1 convolutions; the last axis is swapped into
# index 1 (Conv1d's channel axis) and back.
y = dropout(activation(conv1(y.transpose(-1, 1))))
y = dropout(conv2(y).transpose(-1, 1))
out = norm3(x + y)

### 4.2 DecoderLayer Class

In [42]:
class DecoderLayer(nn.Module):
    """One decoder layer: self-attention, cross-attention, then feed-forward.

    Each of the three sub-blocks is wrapped in a residual connection followed
    by its own LayerNorm. The feed-forward is built from two kernel-size-1
    ``Conv1d`` layers.

    Args:
        self_attention: module called as ``(queries=, keys=, values=)`` on the
            decoder input; returns an ``(output, attention)`` pair.
        cross_attention: module called with decoder queries and encoder
            keys/values; returns an ``(output, attention)`` pair.
        d_model: size of the last input dimension.
        d_ff: hidden width of the feed-forward; ``4 * d_model`` when falsy.
        dropout: probability for the dropout applied after each attention and
            inside the feed-forward.
        activation: ``"relu"`` selects ``F.relu``; anything else selects ``F.gelu``.
    """

    def __init__(
        self,
        self_attention: nn.Module,
        cross_attention: nn.Module,
        d_model: int,
        d_ff: int = None,
        dropout: float = 0.1,
        activation: str = "relu",
    ):
        super().__init__()

        hidden_width = d_ff if d_ff else 4 * d_model
        self.self_attention = self_attention
        self.cross_attention = cross_attention
        self.conv1 = nn.Conv1d(d_model, hidden_width, kernel_size=1)
        self.conv2 = nn.Conv1d(hidden_width, d_model, kernel_size=1)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        if activation == "relu":
            self.activation = F.relu
        else:
            self.activation = F.gelu

    def forward(self, x: torch.Tensor, enc_out: torch.Tensor):
        """Return ``(layer_output, self_attention, cross_attention)``."""
        # Self-attention over the decoder input, residual + LayerNorm.
        self_out, dec_attn = self.self_attention(queries=x, keys=x, values=x)
        x = self.norm1(x + self.dropout(self_out))

        # Cross-attention: decoder queries against the encoder output.
        cross_out, cross_attn = self.cross_attention(
            queries=x, keys=enc_out, values=enc_out
        )
        residual = self.norm2(x + self.dropout(cross_out))

        # Feed-forward: Conv1d expects channels at index 1, so swap the last
        # axis in before the convolutions and back out afterwards.
        hidden = self.conv1(residual.transpose(-1, 1))
        hidden = self.dropout(self.activation(hidden))
        ff_out = self.dropout(self.conv2(hidden).transpose(-1, 1))

        return self.norm3(residual + ff_out), dec_attn, cross_attn

In [43]:
# Constructor arguments for DecoderLayer.
# NOTE: both Attention modules are instantiated once, here. Every DecoderLayer
# built from this same dict receives these same module objects (DecoderLayer
# stores the arguments as-is), so such layers share attention weights.
dec_layer_params = {
    "self_attention": Attention(**attn_params),
    "cross_attention": Attention(**attn_params),
    "d_model": 512,
    "d_ff": 2048,
    "dropout": 0.1,
    "activation": "gelu",
}

dec_layer = DecoderLayer(**dec_layer_params)

In [44]:
# Embed the decoder input and run the single decoder layer against the encoder output.
x = dec_embedding(dec_inp, batch["future_time_features"])

# The layer returns the output plus the self- and cross-attention results.
out, dec_attn, cross_attn = dec_layer(x, enc_out)

In [45]:
out.shape

torch.Size([32, 144, 512])

In [46]:
dec_attn.shape

torch.Size([32, 8, 144, 144])

In [47]:
cross_attn.shape

torch.Size([32, 8, 144, 96])

## 5. Decoder block

### 5.1 line by line

In [35]:
num_dec_layers = 1
norm_layer = None
projection = None

# Use the decoder's own layer count (the original read the encoder's
# `num_layers`), and build fresh Attention modules per layer: the ones stored in
# `dec_layer_params` are single instances that would otherwise be shared by
# every layer created from that dict.
decoder_layers = nn.ModuleList(
    [
        DecoderLayer(
            **{
                **dec_layer_params,
                "self_attention": Attention(**attn_params),
                "cross_attention": Attention(**attn_params),
            }
        )
        for _ in range(num_dec_layers)
    ]
)

In [36]:
x = dec_embedding(dec_inp, batch["future_time_features"])

# DecoderLayer.forward returns (output, self-attention, cross-attention). Only the
# output tensor may be passed on to the next layer / norm / projection, so the
# tuple is unpacked here and the attention results are collected separately.
dec_attns, cross_attns = [], []
for layer in decoder_layers:
    x, layer_dec_attn, layer_cross_attn = layer(x, enc_out)
    dec_attns.append(layer_dec_attn)
    cross_attns.append(layer_cross_attn)

# Optional final normalization, skipped when norm_layer is None.
if norm_layer is not None:
    x = norm_layer(x)

# Optional output projection, skipped when projection is None.
if projection is not None:
    x = projection(x)

### 5.2 Decoder Class

In [49]:
class Decoder(nn.Module):
    """Stack of decoder layers with optional final normalization and projection.

    Args:
        dec_layers: layers applied in order; each is called as
            ``layer(x, enc_out)`` and must return
            ``(output, self_attention, cross_attention)``.
        norm_layer: optional module applied to the output of the last layer.
        projection: optional module applied after ``norm_layer``.
    """

    def __init__(
        self,
        dec_layers: list[nn.Module],
        norm_layer: nn.Module = None,
        projection: nn.Module = None,
    ):
        super().__init__()

        self.dec_layers = nn.ModuleList(dec_layers)
        self.norm_layer = norm_layer
        self.projection = projection

    def forward(self, x: torch.Tensor, enc_out: torch.Tensor):
        """Return ``(output, dec_attns, cross_attns)``; each list has one entry per layer."""
        hidden = x
        dec_attns = []
        cross_attns = []
        for layer in self.dec_layers:
            hidden, self_attn, enc_dec_attn = layer(hidden, enc_out)
            dec_attns.append(self_attn)
            cross_attns.append(enc_dec_attn)

        # Apply the optional heads in order: normalization first, then projection.
        for head in (self.norm_layer, self.projection):
            if head is not None:
                hidden = head(hidden)

        return hidden, dec_attns, cross_attns

In [50]:
d_model = 512
num_dec_layers: int = 2
c_out = 7

# The Attention modules stored in `dec_layer_params` are single instances, so
# reusing the dict unchanged would make every layer hold the same modules and
# share their weights. Build fresh self/cross Attention for each layer instead.
decoder = Decoder(
    dec_layers=[
        DecoderLayer(
            **{
                **dec_layer_params,
                "self_attention": Attention(**attn_params),
                "cross_attention": Attention(**attn_params),
            }
        )
        for _ in range(num_dec_layers)
    ],
    norm_layer=nn.LayerNorm(d_model),
    # Map the model width to c_out output channels.
    projection=nn.Linear(d_model, c_out),
)

In [52]:
# Embed the decoder input and run the full decoder stack against the encoder output.
x = dec_embedding(dec_inp, batch["future_time_features"])

# Returns the projected output plus per-layer self- and cross-attention lists.
dec_out, dec_attns, cross_attns = decoder(x, enc_out)

In [53]:
dec_out.shape

torch.Size([32, 144, 7])