# Transformer Model

## 0. imports

In [2]:
%load_ext jupyter_black

In [3]:
import sys

sys.path.append("..")

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from collections import namedtuple

In [5]:
from src.dataset import ETTDataModule
from src.model import DataEmbedding
from src.model import Attention
from src.model import Encoder, Decoder
from src.model import EncoderLayer, DecoderLayer

## 1. prev setting

In [6]:
dm_params = {
    "data_path": "../data/ETT-small/ETTh1.csv",
    "task": "M",
    "freq": "h",
    "target": "OT",
    "seq_len": 96,
    "label_len": 48,
    "pred_len": 96,
    "use_scaler": True,
    "use_time_enc": True,
    "batch_size": 32,
}


dm = ETTDataModule(**dm_params)

In [7]:
emb_params = {
    "c_in": 7,
    "d_model": 512,
    "embed_type": "time_features",
    "freq": "h",
    "dropout": 0.1,
}

enc_embedding = DataEmbedding(**emb_params)
dec_embedding = DataEmbedding(**emb_params)

In [8]:
attn_params = {
    "d_model": 512,
    "n_heads": 8,
    "d_keys": None,
    "d_values": None,
    "scale": None,
    "attention_dropout": 0.1,
    "output_attention": True,
}

attn_layer = Attention(**attn_params)

In [9]:
enc_layer_params = {
    "attention": Attention(**attn_params),
    "d_model": 512,
    "d_ff": 2048,
    "dropout": 0.1,
    "activation": "gelu",
}

dec_layer_params = {
    "self_attention": Attention(**attn_params),
    "cross_attention": Attention(**attn_params),
    "d_model": 512,
    "d_ff": 2048,
    "dropout": 0.1,
    "activation": "gelu",
}


d_model = 512
num_enc_layers: int = 2
num_dec_layers: int = 1
c_out = 7

encoder = Encoder(
    enc_layers=[EncoderLayer(**enc_layer_params) for _ in range(num_enc_layers)],
    norm_layer=nn.LayerNorm(d_model),
)

decoder = Decoder(
    dec_layers=[DecoderLayer(**dec_layer_params) for _ in range(num_dec_layers)],
    norm_layer=nn.LayerNorm(d_model),
    projection=nn.Linear(d_model, c_out),
)

## 2. Transformer

In [48]:
# transformer_params
Config = namedtuple(
    "Config",
    [
        "task_name",
        "pred_len",
        "seq_len",
        "num_class",
        "enc_in",
        "dec_in",
        "c_out",
        "d_model",
        "embed_type",
        "freq",
        "dropout",
        "n_heads",
        "d_keys",
        "d_values",
        "d_ff",
        "scale",
        "attention_dropout",
        "output_attention",
        "activation",
        "num_enc_layers",
        "num_dec_layers",
    ],
)

configs = Config(
    task_name="long_term_forecast",
    pred_len=96,
    seq_len=None,
    num_class=None,
    enc_in=7,
    dec_in=7,
    c_out=7,
    d_model=512,
    embed_type="time_features",
    freq="h",
    dropout=0.1,
    n_heads=8,
    d_keys=None,
    d_values=None,
    d_ff=2048,
    scale=None,
    attention_dropout=0.1,
    output_attention=True,
    activation="gelu",
    num_enc_layers=2,
    num_dec_layers=1,
)

In [62]:
class Transformer(nn.Module):
    def __init__(
        self,
        task_name: str = "long_term_forecast",
        pred_len: int = 96,
        seq_len: int = None,
        num_class: int = None,
        enc_in: int = 7,
        dec_in: int = 7,
        c_out: int = 7,
        d_model: int = 512,
        embed_type: str = "time_features",
        freq: str = "h",
        dropout: float = 0.1,
        n_heads: int = 8,
        d_keys: int = None,
        d_values: int = None,
        d_ff: int = 2048,
        scale: float = None,
        attention_dropout: float = 0.1,
        output_attention: bool = True,
        activation: str = "gelu",
        num_enc_layers: int = 2,
        num_dec_layers: int = 1,
    ):
        super(Transformer, self).__init__()

        self.task_name = task_name
        self.pred_len = pred_len
        self.output_attention = output_attention

        # 1. Encoder embedding layer
        self.enc_embedding = DataEmbedding(enc_in, d_model, embed_type, freq, dropout)

        # 2. Encoder
        enc_layer = EncoderLayer(
            attention=Attention(
                d_model,
                n_heads,
                d_keys,
                d_values,
                scale,
                attention_dropout,
                output_attention,
            ),
            d_model=d_model,
            d_ff=d_ff,
            dropout=dropout,
            activation=activation,
        )

        self.encoder = Encoder(
            enc_layers=[enc_layer for _ in range(num_enc_layers)],
            norm_layer=nn.LayerNorm(d_model),
        )

        # 3. Decoder
        if (
            self.task_name == "long_term_forecast"
            or self.task_name == "short_term_forecast"
        ):
            # 3.1 Decoder embedding layer
            self.dec_embedding = DataEmbedding(
                dec_in, d_model, embed_type, freq, dropout
            )

            # 3.2 Decoder
            dec_layer = DecoderLayer(
                self_attention=Attention(
                    d_model,
                    n_heads,
                    d_keys,
                    d_values,
                    scale,
                    attention_dropout,
                    output_attention,
                ),
                cross_attention=Attention(
                    d_model,
                    n_heads,
                    d_keys,
                    d_values,
                    scale,
                    attention_dropout,
                    output_attention,
                ),
                d_model=d_model,
                d_ff=d_ff,
                dropout=dropout,
                activation=activation,
            )

            self.decoder = Decoder(
                dec_layers=[dec_layer for _ in range(num_dec_layers)],
                norm_layer=nn.LayerNorm(d_model),
                projection=nn.Linear(d_model, c_out),
            )
        elif self.task_name == "imputation" or self.task_name == "anomaly_detection":
            self.projection = nn.Linear(d_model, c_out)
        elif self.task_name == "classification":
            self.dropout = nn.Dropout(dropout)
            self.projection = nn.Linear(d_model * seq_len, num_class)

    def forecast(
        self,
        past_values: torch.Tensor,
        past_time_features: torch.Tensor,
        future_values: torch.Tensor,
        future_time_features: torch.Tensor,
    ):
        enc_emb = self.enc_embedding(
            x=past_time_features, x_features=past_time_features
        )
        enc_out, enc_attns = self.encoder(enc_emb)

        dec_emb = self.dec_embedding(x=future_values, x_features=future_time_features)
        dec_out, dec_attns, cross_attns = self.decoder(dec_emb, enc_out)

        return_dict = {"decoder_hidden_states": dec_out}
        if self.output_attention:
            return_dict["encoder_attentions"] = enc_attns
            return_dict["decoder_attentions"] = dec_attns
            return_dict["cross_attentions"] = cross_attns

        return return_dict

    def forward(
        self,
        past_values: torch.Tensor,
        past_time_features: torch.Tensor,
        future_values: torch.Tensor,
        future_time_features: torch.Tensor,
    ):
        if (
            self.task_name == "long_term_forecast"
            or self.task_name == "short_term_forecast"
        ):
            output = self.forecast(
                past_values, past_time_features, future_values, future_time_features
            )

            return output
            # return dec_out[:, -self.pred_len :, :]  # [B, L, D]

In [63]:
# class Transformer(nn.Module):
#     def __init__(
#         self,
#         task_name: str = "long_term_forecast",
#         pred_len: int = 96,
#         enc_in: int = 7,
#         dec_in: int = 7,
#         c_out: int = 7,
#         d_model: int = 512,
#         embed_type: str = "time_features",
#         freq: str = "h",
#         dropout: float = 0.1,
#         n_heads: int = 8,
#         d_keys: int = None,
#         d_values: int = None,
#         d_ff: int = 2048,
#         scale: float = None,
#         attention_dropout: float = 0.1,
#         output_attention: bool = True,
#         activation: str = "gelu",
#         num_enc_layers: int = 2,
#         num_dec_layers: int = 1,
#     ):
#         super(Transformer, self).__init__()

#         self.task_name = task_name
#         self.pred_len = pred_len
#         self.output_attention = output_attention

#         self.enc_emb_params = {
#             "c_in": enc_in,
#             "d_model": d_model,
#             "embed_type": embed_type,
#             "freq": freq,
#             "dropout": dropout,
#         }

#         self.dec_emb_params = {
#             "c_in": dec_in,
#             "d_model": d_model,
#             "embed_type": embed_type,
#             "freq": freq,
#             "dropout": dropout,
#         }

#         self.enc_attn_params = {
#             "d_model": d_model,
#             "n_heads": n_heads,
#             "d_keys": d_keys,
#             "d_values": d_values,
#             "scale": scale,
#             "attention_dropout": attention_dropout,
#             "output_attention": output_attention,
#         }

#         self.enc_layer_params = {
#             "attention": Attention(**self.enc_attn_params),
#             "d_model": d_model,
#             "d_ff": d_ff,
#             "dropout": dropout,
#             "activation": activation,
#         }

#         self.dec_attn_params = {
#             "d_model": d_model,
#             "n_heads": n_heads,
#             "d_keys": d_keys,
#             "d_values": d_values,
#             "scale": scale,
#             "attention_dropout": attention_dropout,
#             "output_attention": False,
#         }

#         self.dec_layer_params = {
#             "self_attention": Attention(**self.dec_attn_params),
#             "cross_attention": Attention(**self.dec_attn_params),
#             "d_model": d_model,
#             "d_ff": d_ff,
#             "dropout": dropout,
#             "activation": activation,
#         }

In [64]:
model = Transformer(**configs._asdict())

In [65]:
train_dataloader = dm.train_dataloader()
batch = next(iter(train_dataloader))

# decoder input
label_len = 48
dec_inp = torch.zeros_like(batch["future_values"][:, -configs.pred_len :, :]).float()
dec_inp = torch.cat([batch["future_values"][:, :label_len, :], dec_inp], dim=1).float()

In [66]:
output = model(
    past_values=batch["past_values"],
    past_time_features=batch["past_time_features"],
    future_values=dec_inp,
    future_time_features=batch["future_time_features"],
)

RuntimeError: Given groups=1, weight of size [512, 7, 3], expected input[32, 4, 98] to have 7 channels, but got 4 channels instead