In [None]:
from model import *  # embedding layers, normalization layer
from data import *  # data loader
from utils import *
from exp import *

import os
import sys
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import numpy as np

### Time-domain attention

Denote the input queries, keys, and values as $\mathbf{q} \in \mathbb{R}^{L \times D}, \mathbf{k} \in \mathbb{R}^{L \times D}, \mathbf{v} \in \mathbb{R}^{L \times D}$, which are obtained from the input $\mathbf{x}$ through linear embeddings. Denote the output of the attention module as $\mathbf{o} ( \mathbf{q}, \mathbf{k}, \mathbf{v} ) \in \mathbb{R}^{L \times D}$. Time-domain attention computes attention in the original time domain as follows:

$$\mathbf{o} ( \mathbf{q}, \mathbf{k}, \mathbf{v} ) = \sigma \left(\frac{\mathbf{q}\mathbf{k}^T}{\sqrt{d_q}}\right)\mathbf{v}$$

where $d_q$ is the query dimension, which serves as a normalization term in the attention operation, and $\sigma(\cdot)$ is an activation function. When $\sigma(\cdot)=\mathrm{softmax}(\cdot)$, with $\mathrm{softmax}(\mathbf{x})_i = \frac{e^{x_i}}{\sum_j e^{x_j}}$ applied to each row, we obtain softmax attention: $\mathbf{o} ( \mathbf{q}, \mathbf{k}, \mathbf{v} ) = \mathrm{softmax} \left({\mathbf{q}\mathbf{k}^T}/{\sqrt{d_q}}\right)\mathbf{v}$. When $\sigma(\cdot)=\mathrm{Id}(\cdot)$ (the identity mapping), we obtain linear attention: $\mathbf{o}( \mathbf{q}, \mathbf{k}, \mathbf{v} ) = \left(\mathbf{q}\mathbf{k}^T/\sqrt{d_q}\right)\mathbf{v}$. The code also offers a `linear_norm` variant, which shifts each row of scores to be positive and rescales it to sum to one before multiplying by $\mathbf{v}$.


In [None]:
from math import sqrt
import os


class TriangularCausalMask:
    """Boolean causal mask that blocks attention to future positions.

    The mask has shape ``[B, 1, L, L]``. Entry ``[b, 0, i, j]`` is ``True``
    exactly when ``j > i`` (strictly above the diagonal), i.e. the positions
    that must be hidden from query ``i``. The singleton second axis lets it
    broadcast against (B, H, L, L) score tensors.

    Args:
        B: batch size.
        L: sequence length, shared by queries and keys.
        device: device the finished mask is moved to.
    """

    def __init__(self, B, L, device="cpu"):
        with torch.no_grad():
            all_true = torch.ones([B, 1, L, L], dtype=torch.bool)
            strictly_upper = torch.triu(all_true, diagonal=1)
            self._mask = strictly_upper.to(device)

    @property
    def mask(self):
        """The ``[B, 1, L, L]`` boolean mask (``True`` = masked out)."""
        return self._mask


class FullAttention(nn.Module):
    """Time-domain dot-product attention with a selectable activation.

    Shapes follow the einsum subscripts used in ``forward``:
    ``queries`` (B, L, H, E), ``keys`` (B, S, H, E), ``values`` (B, S, H, D);
    the output is (B, L, H, D).

    Args:
        mask_flag: if True, mask positions before normalizing. The mask is
            ``attn_mask.mask`` or, when ``attn_mask`` is None, a
            ``TriangularCausalMask`` of length L.
        factor: unused; kept for interface compatibility.
        scale: multiplier for the raw scores; defaults to ``1 / sqrt(E)``.
        attention_dropout: dropout probability applied to softmax weights.
        T: softmax temperature (``"softmax"`` only).
        activation: ``"softmax"``; ``"linear"`` (scaled scores used as-is);
            or ``"linear_norm"`` (each score row shifted to be positive and
            normalized to sum to one).
        output_attention: if True, ``forward`` also returns the weights.
    """

    def __init__(
        self,
        mask_flag=True,
        factor=3,
        scale=None,
        attention_dropout=0.1,
        T=1,
        activation="softmax",
        output_attention=False,
    ):
        super(FullAttention, self).__init__()
        self.scale = scale
        self.mask_flag = mask_flag
        self.activation = activation
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)
        self.T = T

    @staticmethod
    def _mask_tensor(attn_mask, B, L, device):
        """Return the boolean mask tensor, defaulting to a causal mask."""
        if attn_mask is None:
            attn_mask = TriangularCausalMask(B, L, device=device)
        return attn_mask.mask

    def forward(self, queries, keys, values, attn_mask):
        """Compute attention.

        Returns:
            ``(V, A)`` where ``V`` is the contiguous (B, L, H, D) output and
            ``A`` is the (B, H, L, S) weight tensor when ``output_attention``
            is set, otherwise ``None``.

        Raises:
            ValueError: if ``activation`` is not a supported value.
        """
        B, L, H, E = queries.shape
        scale = self.scale or 1.0 / sqrt(E)

        scores = torch.einsum("blhe,bshe->bhls", queries, keys) * scale

        if self.activation == "softmax":
            if self.mask_flag:
                mask = self._mask_tensor(attn_mask, B, L, queries.device)
                scores.masked_fill_(mask, -np.inf)
            A = self.dropout(torch.softmax(scores / self.T, dim=-1))

        elif self.activation == "linear":
            if self.mask_flag:
                # No softmax follows, so masked weights must be 0 (not -inf),
                # matching the "linear_norm" branch.
                mask = self._mask_tensor(attn_mask, B, L, queries.device)
                scores.masked_fill_(mask, 0)
            A = scores

        elif self.activation == "linear_norm":
            # Shift each row so its minimum becomes 1e-8: unmasked weights are
            # positive, so the row sum is non-zero whenever any key is visible.
            scores = scores - scores.min(dim=-1, keepdim=True).values + 1e-8
            if self.mask_flag:
                mask = self._mask_tensor(attn_mask, B, L, queries.device)
                scores.masked_fill_(mask, 0)
            A = scores / scores.sum(dim=-1, keepdim=True)

        else:
            raise ValueError(
                f"unsupported activation {self.activation!r}; expected "
                "'softmax', 'linear' or 'linear_norm'"
            )

        V = torch.einsum("bhls,bshd->blhd", A, values)

        if self.output_attention:
            return (V.contiguous(), A)
        return (V.contiguous(), None)

### Fourier-domain attention

Fourier attention first converts the queries, keys, and values with the Fourier transform, performs a similar attention mechanism in the frequency domain, and finally converts the result back to the time domain with the inverse Fourier transform. Let $\mathcal{F}(\cdot), \mathcal{F}^{-1}(\cdot)$ denote the Fourier transform and the inverse Fourier transform; then Fourier attention is $\mathbf{o} ( \mathbf{q}, \mathbf{k}, \mathbf{v} ) = \mathcal{F}^{-1} \Big(\sigma\big(\mathcal{F}(\mathbf{q})\,\overline{\mathcal{F}(\mathbf{k})}^T/\sqrt{d_q}\big)\mathcal{F}(\mathbf{v})\Big)$, where $\overline{\cdot}$ denotes complex conjugation. Because softmax is not defined on complex numbers, the softmax variant applies it to the magnitudes of the complex scores.

In [None]:
class FourierAttention(nn.Module):
    """1D Fourier attention: rFFT, attention over frequencies, inverse rFFT.

    Queries, keys and values are transformed along the time axis with
    ``torch.fft.rfft`` (``norm="ortho"``). Complex scores between every query
    frequency and key frequency are turned into weights according to
    ``activation``, applied to the value spectra, and the result is mapped
    back to the query length with ``torch.fft.irfft``.

    Args:
        T: softmax temperature (``"softmax"`` only).
        activation: one of ``"softmax"`` (softmax of score magnitudes),
            ``"linear"`` (raw complex scores), ``"linear_norm"`` (real and
            imaginary parts shift-normalized separately and applied to the
            real / imaginary value spectra), ``"linear_norm_abs"``
            (magnitudes normalized to sum to one) or ``"linear_norm_real"``
            (real part shift-normalized).
        output_attention: if True, ``forward`` also returns the weights.
    """

    def __init__(self, T=1, activation="softmax", output_attention=False):
        super(FourierAttention, self).__init__()
        print(" fourier enhanced cross attention used!")
        self.activation = activation
        self.output_attention = output_attention
        self.T = T

    @staticmethod
    def _shift_normalize(x):
        """Shift each last-axis row of real ``x`` to min 0, then divide by its sum."""
        shifted = x - x.min(dim=-1, keepdim=True).values
        return shifted / shifted.sum(dim=-1, keepdim=True)

    def forward(self, q, k, v, mask):
        """Apply Fourier attention.

        Args:
            q: queries, (B, L, H, E).
            k: keys, (B, S, H, E).
            v: values, (B, S, H, D).
            mask: ignored; no masking is applied in the frequency domain.

        Returns:
            ``(out, attn)``: ``out`` is (B, L, H, D). ``attn`` is ``None``
            unless ``output_attention`` is set; then it is the
            ``(real, imag)`` weight pair for ``"linear_norm"`` and the
            (B, H, L//2+1, S//2+1) frequency-domain weight tensor otherwise.

        Raises:
            ValueError: if ``activation`` is not a supported value.
        """
        B, L, H, E = q.shape
        # (B, T, H, E) -> (B, H, E, T) so the FFT runs over the time axis.
        xq = q.permute(0, 2, 3, 1)
        xk = k.permute(0, 2, 3, 1)
        xv = v.permute(0, 2, 3, 1)

        xq_ft_ = torch.fft.rfft(xq, dim=-1, norm="ortho")
        xk_ft_ = torch.fft.rfft(xk, dim=-1, norm="ortho")
        xv_ft_ = torch.fft.rfft(xv, dim=-1, norm="ortho")

        # Complex score between query frequency x and key frequency y.
        xqk_ft = torch.einsum(
            "bhex,bhey->bhxy", xq_ft_, torch.conj(xk_ft_)
        ) / sqrt(E)

        if self.activation == "softmax":
            # softmax is undefined on complex numbers: use the magnitudes.
            weights = torch.softmax(xqk_ft.abs() / self.T, dim=-1)
            xqk_ft = torch.complex(weights, torch.zeros_like(weights))
            xqkv_ft = torch.einsum("bhxy,bhey->bhex", xqk_ft, xv_ft_)
            attn = xqk_ft

        elif self.activation == "linear":
            xqkv_ft = torch.einsum("bhxy,bhey->bhex", xqk_ft, xv_ft_)
            attn = xqk_ft

        elif self.activation == "linear_norm":
            xqk_ft_real = self._shift_normalize(xqk_ft.real)
            xqk_ft_imag = self._shift_normalize(xqk_ft.imag)
            xqkv_ft_real = torch.einsum(
                "bhxy,bhey->bhex", xqk_ft_real, xv_ft_.real
            )
            xqkv_ft_imag = torch.einsum(
                "bhxy,bhey->bhex", xqk_ft_imag, xv_ft_.imag
            )
            xqkv_ft = torch.complex(xqkv_ft_real, xqkv_ft_imag)
            attn = (xqk_ft_real, xqk_ft_imag)

        elif self.activation == "linear_norm_abs":
            magnitude = xqk_ft.abs()
            weights = magnitude / magnitude.sum(dim=-1, keepdim=True)
            xqk_ft = torch.complex(weights, torch.zeros_like(weights))
            xqkv_ft = torch.einsum("bhxy,bhey->bhex", xqk_ft, xv_ft_)
            attn = xqk_ft

        elif self.activation == "linear_norm_real":
            weights = self._shift_normalize(xqk_ft.real)
            xqk_ft = torch.complex(weights, torch.zeros_like(weights))
            xqkv_ft = torch.einsum("bhxy,bhey->bhex", xqk_ft, xv_ft_)
            attn = xqk_ft

        else:
            raise ValueError(
                f"unsupported activation {self.activation!r}; expected one of "
                "'softmax', 'linear', 'linear_norm', 'linear_norm_abs', "
                "'linear_norm_real'"
            )

        # Back to the time domain at the query length, then (B, L, H, D).
        out = torch.fft.irfft(xqkv_ft, n=L, dim=-1, norm="ortho").permute(
            0, 3, 1, 2
        )

        if not self.output_attention:
            return (out, None)
        return (out, attn)

### Wavelet-domain attention

The wavelet transform applies wavelet decomposition and reconstruction to obtain signals at different scales. Wavelet attention computes attention over the decomposed queries, keys, and values at each scale, and reconstructs the output from the per-scale attention results. Let $\mathcal{W}(\cdot), \mathcal{W}^{-1}(\cdot)$ denote wavelet decomposition and wavelet reconstruction; then wavelet attention is $\mathbf{o} ( \mathbf{q}, \mathbf{k}, \mathbf{v} ) = \mathcal{W}^{-1}\Big(\sigma\big(\mathcal{W}(\mathbf{q})\mathcal{W}(\mathbf{k})^T/\sqrt{d_q}\big)\mathcal{W}(\mathbf{v})\Big)$.

In [None]:
class sparseKernel1d(nn.Module):
    """Conv1d + linear kernel applied along the sequence axis.

    Input and output are (B, N, c, k). The ``c * k`` features of each step
    are flattened, passed through ``Conv1d(c*k -> 128, kernel 3, pad 1)`` +
    ReLU over the N axis, and mapped back to ``c * k`` by ``Lo``.

    Args:
        k: per-channel feature size.
        c: number of channels.
        nl, initializer, **kwargs: accepted for interface compatibility;
            unused.
    """

    def __init__(self, k, c=1, nl=1, initializer=None, **kwargs):
        super(sparseKernel1d, self).__init__()
        width = c * k
        self.k = k
        # NOTE(review): ``Li`` is never used by ``forward``; it is kept so the
        # parameter-initialization order and state_dict keys stay unchanged.
        self.Li = nn.Linear(width, 128)
        self.conv = self.convBlock(width, 128)
        self.Lo = nn.Linear(128, width)

    def forward(self, x):
        batch, length, channels, feat = x.shape
        # (B, N, c*k) -> (B, c*k, N): Conv1d convolves over the last axis.
        flat = x.view(batch, length, -1).permute(0, 2, 1)
        hidden = self.conv(flat).permute(0, 2, 1)  # (B, N, 128)
        return self.Lo(hidden).view(batch, length, channels, feat)

    def convBlock(self, ich, och):
        """Length-preserving Conv1d (kernel 3, stride 1, padding 1) + ReLU."""
        return nn.Sequential(
            nn.Conv1d(ich, och, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
        )


class WaveletAttention(nn.Module):
    """Attention computed on multiwavelet coefficients.

    ``q``/``k``/``v`` are projected to ``c * k`` features and viewed as
    (B, T, c, k). ``wavelet_transform`` recursively halves the sequence into
    detail and smooth coefficients; at every level attention is computed
    separately on the details and on the smooth part. ``evenOdd`` then
    rebuilds the sequence level by level, and ``self.out`` projects it back
    to ``ich`` features.

    Args:
        in_channels, out_channels, seq_len_q, seq_len_kv, initializer,
            **kwargs: accepted for interface compatibility; unused here.
        c: channels after the input projections.
        k: coefficients per channel; also passed to ``get_filter``.
        ich: feature size of the flattened incoming tensors and of the output.
        L: the decomposition runs ``floor(log2(N)) - L`` levels.
        base: wavelet basis name passed to ``get_filter``.
        T: softmax temperature.
        activation: ``"softmax"``, ``"linear"`` or ``"linear_norm"``.
        output_attention: also return the per-level attention weights.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        seq_len_q,
        seq_len_kv,
        c=64,
        k=8,
        ich=512,
        L=3,
        base="legendre",
        initializer=None,
        T=1,
        activation="softmax",
        output_attention=False,
        **kwargs
    ):
        super(WaveletAttention, self).__init__()
        print("base", base)

        self.c = c
        self.k = k
        self.L = L
        self.T = T
        self.activation = activation
        H0, H1, G0, G1, PHI0, PHI1 = get_filter(base, k)
        # Reconstruction filters.
        H0r = H0 @ PHI0
        G0r = G0 @ PHI0
        H1r = H1 @ PHI1
        G1r = G1 @ PHI1

        # Entries below 1e-8 in magnitude are set to exactly zero.
        H0r[np.abs(H0r) < 1e-8] = 0
        H1r[np.abs(H1r) < 1e-8] = 0
        G0r[np.abs(G0r) < 1e-8] = 0
        G1r[np.abs(G1r) < 1e-8] = 0
        self.max_item = 3

        # Decomposition filters applied to concatenated (even, odd) steps.
        self.register_buffer(
            "ec_s", torch.Tensor(np.concatenate((H0.T, H1.T), axis=0))
        )
        self.register_buffer(
            "ec_d", torch.Tensor(np.concatenate((G0.T, G1.T), axis=0))
        )

        # Reconstruction filters producing even / odd output steps.
        self.register_buffer(
            "rc_e", torch.Tensor(np.concatenate((H0r, G0r), axis=0))
        )
        self.register_buffer(
            "rc_o", torch.Tensor(np.concatenate((H1r, G1r), axis=0))
        )

        self.Lk = nn.Linear(ich, c * k)
        self.Lq = nn.Linear(ich, c * k)
        self.Lv = nn.Linear(ich, c * k)
        self.out = nn.Linear(c * k, ich)

        self.output_attention = output_attention

    def _project(self, x, linear):
        """Flatten (B, T, H, E) to (B, T, H*E), apply ``linear``, view as (B, T, c, k)."""
        batch, length = x.shape[0], x.shape[1]
        x = linear(x.view(batch, length, -1))
        return x.view(batch, length, self.c, self.k)

    def _attention_weights(self, scores):
        """Turn raw scores into attention weights according to ``self.activation``."""
        if self.activation == "softmax":
            return F.softmax(scores / self.T, dim=-1)
        if self.activation == "linear":
            return scores
        if self.activation == "linear_norm":
            # Shift so the row minimum is 1e-8 (as in FullAttention): rows of
            # a single key would otherwise normalize 0 / 0 to NaN.
            shifted = scores - scores.min(dim=-1, keepdim=True).values + 1e-8
            return shifted / shifted.sum(dim=-1, keepdim=True)
        raise ValueError(
            f"unsupported activation {self.activation!r}; expected "
            "'softmax', 'linear' or 'linear_norm'"
        )

    def forward(self, q, k, v, mask=None):
        """Apply wavelet attention.

        Args:
            q: queries, (B, N, H, E) with ``H * E == ich``.
            k, v: keys / values, (B, S, H, E). They are zero-padded to N
                steps when S < N and truncated to N steps otherwise.
            mask: ignored.

        Returns:
            ``(out, attn)``: ``out`` is (B, N, ich); ``attn`` is ``None``
            unless ``output_attention`` is set, then
            ``(smooth_weights_per_level, detail_weights_per_level)``.

        Raises:
            ValueError: if ``activation`` is unsupported and at least one
                decomposition level runs.
        """
        B, N, H, E = q.shape
        _, S, _, _ = k.shape

        q = self._project(q, self.Lq)
        k = self._project(k, self.Lk)
        v = self._project(v, self.Lv)

        # Align keys / values with the query length N.
        if N > S:
            zeros = torch.zeros_like(q[:, : N - S])
            v = torch.cat([v, zeros], dim=1)
            k = torch.cat([k, zeros], dim=1)
        else:
            v = v[:, :N]
            k = k[:, :N]

        ns = math.floor(np.log2(N))
        nl = pow(2, math.ceil(np.log2(N)))
        # Pad to a power-of-two length by repeating the first nl - N steps.
        q = torch.cat([q, q[:, : nl - N]], 1)
        k = torch.cat([k, k[:, : nl - N]], 1)
        v = torch.cat([v, v[:, : nl - N]], 1)

        Ud = []  # per-level attention output on detail coefficients
        Us = []  # per-level attention output on smooth coefficients
        attn_d_list, attn_s_list = [], []

        for _ in range(ns - self.L):
            dq, q = self.wavelet_transform(q)
            dk, k = self.wavelet_transform(k)
            dv, v = self.wavelet_transform(v)

            scores_d = torch.einsum("bxeh,byeh->bhxy", dq, dk) / sqrt(E)
            attn_d = self._attention_weights(scores_d)
            Ud.append(torch.einsum("bhxy,byeh->bxeh", attn_d, dv))
            attn_d_list.append(attn_d)

            scores_s = torch.einsum("bxeh,byeh->bhxy", q, k) / sqrt(E)
            attn_s = self._attention_weights(scores_s)
            Us.append(torch.einsum("bhxy,byeh->bxeh", attn_s, v))
            attn_s_list.append(attn_s)

        # Reconstruct from the coarsest level back to full resolution.
        for i in range(ns - 1 - self.L, -1, -1):
            v = v + Us[i]
            v = torch.cat((v, Ud[i]), -1)
            v = self.evenOdd(v)
        v = self.out(v[:, :N].contiguous().view(B, N, -1))
        if not self.output_attention:
            return (v.contiguous(), None)
        return (v.contiguous(), (attn_s_list, attn_d_list))

    def wavelet_transform(self, x):
        """One decomposition level: (B, T, c, k) -> (detail, smooth), each (B, T/2, c, ·).

        Even and odd time steps are concatenated on the last axis and
        multiplied by ``ec_d`` / ``ec_s``.
        """
        xa = torch.cat(
            [
                x[:, ::2, :, :],
                x[:, 1::2, :, :],
            ],
            -1,
        )
        d = torch.matmul(xa, self.ec_d)
        s = torch.matmul(xa, self.ec_s)
        return d, s

    def evenOdd(self, x):
        """One reconstruction level: (B, T, c, 2k) -> (B, 2T, c, k).

        ``x @ rc_e`` fills the even output steps and ``x @ rc_o`` the odd ones.
        """
        B, N, c, ich = x.shape  # (B, N, c, 2k)
        assert ich == 2 * self.k
        x_e = torch.matmul(x, self.rc_e)
        x_o = torch.matmul(x, self.rc_o)

        x = torch.zeros(
            B, N * 2, c, self.k, device=x.device, dtype=x_e.dtype
        )
        x[..., ::2, :, :] = x_e
        x[..., 1::2, :, :] = x_o
        return x

### Transformer

The vanilla Transformer, with a configurable attention type (time, Fourier, or wavelet).

In [None]:
class Transformer(nn.Module):
    """
    Vanilla encoder-decoder Transformer.

    Uses a 2-layer encoder and a 1-layer decoder whose attention modules are
    chosen by ``configs.version``: ``"Time"`` (FullAttention), ``"Fourier"``
    (FourierAttention) or ``"Wavelet"`` (WaveletAttention).

    NOTE(review): every layer of a stack wraps the same inner attention
    instance, so the two encoder layers share one attention object (and its
    parameters, for WaveletAttention). Kept as-is; confirm this is intended.

    Raises:
        ValueError: if ``configs.version`` is not a supported attention type.
    """

    _VERSIONS = ("Wavelet", "Fourier", "Time")

    def __init__(self, configs):
        super(Transformer, self).__init__()
        # Validate first so a bad config fails with a clear message instead
        # of an UnboundLocalError after the embeddings are built.
        if configs.version not in self._VERSIONS:
            raise ValueError(
                f"unknown attention version {configs.version!r}; "
                f"expected one of {self._VERSIONS}"
            )
        self.pred_len = configs.pred_len
        self.output_attention = configs.output_attention
        self.enc_embedding = DataEmbedding(
            configs.enc_in,
            configs.d_model,
            configs.embed,
            configs.freq,
            configs.dropout,
        )
        self.dec_embedding = DataEmbedding(
            configs.dec_in,
            configs.d_model,
            configs.embed,
            configs.freq,
            configs.dropout,
        )

        (
            enc_self_attention,
            dec_self_attention,
            dec_cross_attention,
        ) = self._build_attentions(configs)

        # NOTE(review): configs.activation (the attention activation) is also
        # passed as the layer activation; confirm EncoderLayer/DecoderLayer
        # handle values such as "softmax".
        # Encoder
        self.encoder = Encoder(
            [
                EncoderLayer(
                    AttentionLayer(enc_self_attention, configs.d_model),
                    configs.d_model,
                    dropout=configs.dropout,
                    activation=configs.activation,
                )
                for l in range(2)
            ],
            norm_layer=torch.nn.LayerNorm(configs.d_model),
        )
        # Decoder
        self.decoder = Decoder(
            [
                DecoderLayer(
                    AttentionLayer(dec_self_attention, configs.d_model),
                    AttentionLayer(dec_cross_attention, configs.d_model),
                    configs.d_model,
                    dropout=configs.dropout,
                    activation=configs.activation,
                )
                for l in range(1)
            ],
            norm_layer=torch.nn.LayerNorm(configs.d_model),
            projection=nn.Linear(configs.d_model, configs.c_out, bias=True),
        )

    @staticmethod
    def _build_attentions(configs):
        """Return (encoder-self, decoder-self, decoder-cross) attention modules.

        Modules are created in that order, so parameter initialization
        consumes the RNG in the same sequence as before.
        """
        if configs.version == "Wavelet":

            def wavelet(seq_len_q, seq_len_kv):
                return WaveletAttention(
                    in_channels=configs.d_model,
                    out_channels=configs.d_model,
                    seq_len_q=seq_len_q,
                    seq_len_kv=seq_len_kv,
                    ich=configs.d_model,
                    T=configs.temp,
                    activation=configs.activation,
                    output_attention=configs.output_attention,
                )

            dec_len = configs.seq_len // 2 + configs.pred_len
            enc_self = wavelet(configs.seq_len, configs.seq_len)
            dec_self = wavelet(dec_len, dec_len)
            dec_cross = wavelet(dec_len, configs.seq_len)
            return enc_self, dec_self, dec_cross

        if configs.version == "Fourier":

            def fourier():
                return FourierAttention(
                    T=configs.temp,
                    activation=configs.activation,
                    output_attention=configs.output_attention,
                )

            enc_self = fourier()
            dec_self = fourier()
            dec_cross = fourier()
            return enc_self, dec_self, dec_cross

        # "Time": only the decoder self-attention is causally masked.
        def full(mask_flag):
            return FullAttention(
                mask_flag,
                T=configs.temp,
                activation=configs.activation,
                attention_dropout=configs.dropout,
                output_attention=configs.output_attention,
            )

        enc_self = full(False)
        dec_self = full(True)
        dec_cross = full(False)
        return enc_self, dec_self, dec_cross

    def forward(
        self,
        x_enc,
        x_mark_enc,
        x_dec,
        x_mark_dec,
        enc_self_mask=None,
        dec_self_mask=None,
        dec_enc_mask=None,
    ):
        """Encode ``x_enc``, decode ``x_dec`` and return the last ``pred_len`` steps.

        With ``output_attention`` the encoder and decoder attentions are
        returned as a second element ``(attn_e, attn_d)``.
        """
        enc_out = self.enc_embedding(x_enc, x_mark_enc)
        enc_out, attn_e = self.encoder(enc_out, attn_mask=enc_self_mask)

        dec_out = self.dec_embedding(x_dec, x_mark_dec)
        dec_out, attn_d = self.decoder(
            dec_out, enc_out, x_mask=dec_self_mask, cross_mask=dec_enc_mask
        )

        if self.output_attention:
            return dec_out[:, -self.pred_len :, :], (attn_e, attn_d)
        return dec_out[:, -self.pred_len :, :]

### MLP

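A 3-layer MLP that maps each channel's input window of length `seq_len` directly to a forecast of length `pred_len`.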

In [None]:
class MLP(nn.Module):
    """Channel-independent 3-layer MLP forecaster.

    Each channel's length-``seq_len`` input window is mapped to a
    length-``pred_len`` forecast through two ReLU hidden layers of width
    ``d_model``. Time features and decoder inputs are ignored.
    """

    def __init__(self, configs):
        super(MLP, self).__init__()
        self.seq_len = configs.seq_len
        self.label_len = configs.label_len
        self.pred_len = configs.pred_len

        hidden = configs.d_model
        layers = [
            nn.Linear(configs.seq_len, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, configs.pred_len),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(
        self,
        x_enc,
        x_mark_enc,
        x_dec,
        x_mark_dec,
        enc_self_mask=None,
        dec_self_mask=None,
        dec_enc_mask=None,
    ):
        # (B, seq_len, C) -> (B, C, seq_len) so the Linear layers act on time,
        # then back to (B, pred_len, C).
        per_channel = x_enc.transpose(1, 2)
        forecast = self.mlp(per_channel)
        return forecast.transpose(1, 2)

### TDformer

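First apply seasonal-trend decomposition. Then model the trend with an MLP and the seasonal part with an attention-based Transformer (time, Fourier, or wavelet attention; `main` uses Fourier). The seasonal forecast is rescaled so that, per channel, its mean absolute value matches that of the input seasonal component, and the two parts are added to obtain the final prediction.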

In [None]:
class TDformer(nn.Module):
    """
    Transformer for seasonality, MLP for trend.

    ``x_enc`` is split by ``series_decomp`` (or ``series_decomp_multi``) into
    seasonal and trend parts. The seasonal part goes through an
    encoder-decoder Transformer with the attention selected by
    ``configs.version`` (``"Time"``, ``"Fourier"`` or ``"Wavelet"``). The
    trend part goes through a RevIN-normalized 3-layer MLP. The outputs are
    summed.

    NOTE(review): every layer of a stack wraps the same inner attention
    instance (shared parameters for WaveletAttention); confirm intended.

    Raises:
        ValueError: if ``configs.version`` is not a supported attention type.
    """

    _VERSIONS = ("Wavelet", "Fourier", "Time")

    def __init__(self, configs):
        super(TDformer, self).__init__()
        # Validate first so a bad config fails with a clear message instead
        # of an UnboundLocalError halfway through construction.
        if configs.version not in self._VERSIONS:
            raise ValueError(
                f"unknown attention version {configs.version!r}; "
                f"expected one of {self._VERSIONS}"
            )
        self.version = configs.version
        self.seq_len = configs.seq_len
        self.label_len = configs.label_len
        self.pred_len = configs.pred_len
        self.output_attention = configs.output_attention
        self.output_stl = configs.output_stl
        self.device = configs.device

        # Decomp
        kernel_size = configs.moving_avg
        if isinstance(kernel_size, list):
            self.decomp = series_decomp_multi(kernel_size)
        else:
            self.decomp = series_decomp(kernel_size)

        # Embedding
        self.enc_seasonal_embedding = DataEmbedding(
            configs.enc_in,
            configs.d_model,
            configs.embed,
            configs.freq,
            configs.dropout,
        )
        self.dec_seasonal_embedding = DataEmbedding(
            configs.dec_in,
            configs.d_model,
            configs.embed,
            configs.freq,
            configs.dropout,
        )

        (
            enc_self_attention,
            dec_self_attention,
            dec_cross_attention,
        ) = self._build_attentions(configs)

        # NOTE(review): configs.activation (the attention activation) is also
        # passed as the layer activation; confirm EncoderLayer/DecoderLayer
        # handle values such as "softmax".
        # Seasonal encoder
        self.seasonal_encoder = Encoder(
            [
                EncoderLayer(
                    AttentionLayer(enc_self_attention, configs.d_model),
                    configs.d_model,
                    dropout=configs.dropout,
                    activation=configs.activation,
                )
                for l in range(2)
            ],
            norm_layer=torch.nn.LayerNorm(configs.d_model),
        )
        # Seasonal decoder
        self.seasonal_decoder = Decoder(
            [
                DecoderLayer(
                    AttentionLayer(dec_self_attention, configs.d_model),
                    AttentionLayer(dec_cross_attention, configs.d_model),
                    configs.d_model,
                    dropout=configs.dropout,
                    activation=configs.activation,
                )
                for l in range(1)
            ],
            norm_layer=torch.nn.LayerNorm(configs.d_model),
            projection=nn.Linear(configs.d_model, configs.c_out, bias=True),
        )

        # Trend MLP: maps each channel's seq_len window to pred_len steps.
        self.trend = nn.Sequential(
            nn.Linear(configs.seq_len, configs.d_model),
            nn.ReLU(),
            nn.Linear(configs.d_model, configs.d_model),
            nn.ReLU(),
            nn.Linear(configs.d_model, configs.pred_len),
        )

        self.revin_trend = RevIN(configs.enc_in).to(self.device)

    @staticmethod
    def _build_attentions(configs):
        """Return (encoder-self, decoder-self, decoder-cross) attention modules.

        Modules are created in that order, so parameter initialization
        consumes the RNG in the same sequence as before.
        """
        if configs.version == "Wavelet":

            def wavelet(seq_len_q, seq_len_kv):
                return WaveletAttention(
                    in_channels=configs.d_model,
                    out_channels=configs.d_model,
                    seq_len_q=seq_len_q,
                    seq_len_kv=seq_len_kv,
                    ich=configs.d_model,
                    T=configs.temp,
                    activation=configs.activation,
                    output_attention=configs.output_attention,
                )

            dec_len = configs.seq_len // 2 + configs.pred_len
            enc_self = wavelet(configs.seq_len, configs.seq_len)
            dec_self = wavelet(dec_len, dec_len)
            dec_cross = wavelet(dec_len, configs.seq_len)
            return enc_self, dec_self, dec_cross

        if configs.version == "Fourier":

            def fourier():
                return FourierAttention(
                    T=configs.temp,
                    activation=configs.activation,
                    output_attention=configs.output_attention,
                )

            enc_self = fourier()
            dec_self = fourier()
            dec_cross = fourier()
            return enc_self, dec_self, dec_cross

        # "Time": only the decoder self-attention is causally masked.
        def full(mask_flag):
            return FullAttention(
                mask_flag,
                T=configs.temp,
                activation=configs.activation,
                attention_dropout=configs.dropout,
                output_attention=configs.output_attention,
            )

        enc_self = full(False)
        dec_self = full(True)
        dec_cross = full(False)
        return enc_self, dec_self, dec_cross

    def forward(
        self,
        x_enc,
        x_mark_enc,
        x_dec,
        x_mark_dec,
        enc_self_mask=None,
        dec_self_mask=None,
        dec_enc_mask=None,
    ):
        """Forecast ``pred_len`` steps as trend forecast + rescaled seasonal forecast.

        ``x_dec`` is not read; the decoder input is built from the seasonal
        part of ``x_enc``.

        Returns:
            The forecast. With ``output_attention``: ``(forecast, (attn_e,
            attn_d))``. Otherwise, with ``output_stl``: ``(forecast,
            trend_enc, seasonal_enc, trend_out, seasonal_pred)``, where
            ``trend_enc`` is the raw (un-normalized) decomposed trend, on the
            same scale as ``seasonal_enc`` and ``trend_out``.
        """
        seasonal_enc, trend_enc = self.decomp(x_enc)
        # Decoder input: last label_len seasonal steps followed by pred_len zeros.
        seasonal_dec = F.pad(
            seasonal_enc[:, -self.label_len :, :], (0, 0, 0, self.pred_len)
        )

        # seasonal branch
        enc_out = self.enc_seasonal_embedding(seasonal_enc, x_mark_enc)
        enc_out, attn_e = self.seasonal_encoder(
            enc_out, attn_mask=enc_self_mask
        )

        dec_out = self.dec_seasonal_embedding(seasonal_dec, x_mark_dec)
        seasonal_out, attn_d = self.seasonal_decoder(
            dec_out, enc_out, x_mask=dec_self_mask, cross_mask=dec_enc_mask
        )
        seasonal_out = seasonal_out[:, -self.pred_len :, :]

        # Per sample and channel, rescale the seasonal forecast so its mean
        # |value| over time matches that of the input seasonal component.
        seasonal_ratio = seasonal_enc.abs().mean(
            dim=1, keepdim=True
        ) / seasonal_out.abs().mean(dim=1, keepdim=True)
        seasonal_pred = seasonal_ratio * seasonal_out

        # trend branch (RevIN normalize -> MLP over time -> denormalize)
        trend_norm = self.revin_trend(trend_enc, "norm")
        trend_out = self.trend(trend_norm.permute(0, 2, 1)).permute(0, 2, 1)
        trend_out = self.revin_trend(trend_out, "denorm")

        dec_out = trend_out + seasonal_pred

        if self.output_attention:
            return dec_out, (attn_e, attn_d)
        if self.output_stl:
            return (
                dec_out,
                trend_enc,
                seasonal_enc,
                trend_out,
                seasonal_pred,
            )
        return dec_out

### main function

main parameters to change:

1. fix_seed: random seed
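2. dataset: choices: ETTm2.csv, electricity.csv, exchange_rate.csv, traffic.csv, weather.csv, sin.csv, vary.csv, linear.csv, spikes.csv. Datasets are not downloaded automatically: create `./dataset/` and download them into it first (see the message raised by `main` when the folder is missing).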
3. r: training ratio
4. pred: prediction horizon
5. v: attention version (Time, Fourier, Wavelet)
6. model: Transformer, MLP, TDformer

In [None]:
def main():
    """Run the experiment grid.

    For every seed, the random/torch/numpy RNGs are reseeded once. Then, for
    each combination of dataset, training ratio, horizon and attention
    version, a model is built, trained with ``Exp_Main`` and tested. Datasets
    are read from ``./dataset/``; ``./log/`` is created if missing.

    Raises:
        FileNotFoundError: if ``./dataset`` does not exist.
    """
    import itertools
    from types import SimpleNamespace

    dim_dict = {
        "ETTm2.csv": 7,
        "electricity.csv": 321,
        "exchange_rate.csv": 8,
        "traffic.csv": 862,
        "weather.csv": 21,
        "sin.csv": 1,
        "vary.csv": 1,
        "linear.csv": 1,
        "spikes.csv": 1,
    }
    model_dict = {"MLP": MLP, "TDformer": TDformer, "Transformer": Transformer}

    if not os.path.exists("./dataset"):
        # The Drive folder can also be fetched with gdown.download_folder(url).
        raise FileNotFoundError(
            "first create './dataset/' in the current folder, then download the datasets from 'https://drive.google.com/drive/u/0/folders/1UD5jqDtJnliBhghkhM9CVqb49ZAhr1KB'"
        )

    os.makedirs("./log", exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    datasets = ["exchange_rate.csv"]
    # Training ratios: sin.csv [0.0205, 0.021, 0.0215, 0.022, 0.03];
    # vary.csv [0.0225, 0.025, 0.03, 0.04, 0.06];
    # linear.csv / spikes.csv [0.5]; real-world datasets [0.7].
    ratios = [0.7]
    horizons = [96, 192, 336, 720]
    versions = ["Fourier"]  # attention types: Time, Fourier, Wavelet

    for fix_seed in [0, 1, 42, 2021, 2022]:
        random.seed(fix_seed)
        torch.manual_seed(fix_seed)
        np.random.seed(fix_seed)

        # Same iteration order as nested loops (last factor varies fastest).
        for dataset, r, pred, v in itertools.product(
            datasets, ratios, horizons, versions
        ):
            args = SimpleNamespace(
                model="TDformer",  # Transformer, MLP, TDformer
                activation="softmax",
                seq_len=96,
                label_len=48,
                temp=1,
                pred_len=pred,
                version=v,
                data_path=dataset,
                ratio=r,
                features="M",
                target="OT",
                freq="h",
                patience=20,
                enc_in=dim_dict[dataset],
                dec_in=dim_dict[dataset],
                c_out=dim_dict[dataset],
                adjust=False,
                des="",
                d_model=512,
                d_ff=2048,
                embed="timeF",
                dropout=0.05,
                output_attention=False,
                output_stl=False,
                learning_rate=0.0001,
                lradj="type1",
                moving_avg=[24],
                gpu=0,
            )
            stem = os.path.splitext(args.data_path)[0]
            args.des = (
                f"{stem}{args.ratio}{args.activation}temp{args.temp}"
                f"{args.model}{args.version}{args.pred_len}seed{fix_seed}"
            )
            args.device = device

            model = model_dict[args.model](args).float().to(args.device)

            exp = Exp_Main(args, model)  # set experiments
            exp.train()

            exp.test(test=1)
            torch.cuda.empty_cache()


# Run only when executed directly (a notebook cell or `python file.py`),
# not when this code is imported as a module.
if __name__ == "__main__":
    main()