Following [this](https://huggingface.co/blog/autoformer) tutorial

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DecompositionLayer(nn.Module):
    """Split a series into a moving-average trend and the remaining residual.

    The trend is a stride-1 average pool over the last axis; the seasonal
    part is whatever is left after subtracting that trend from the input.

    Args:
        kernel_size: Width of the moving-average window, in samples along
            the last axis. Both odd and even widths are supported.
    """

    def __init__(self, kernel_size: int) -> None:
        super().__init__()
        # count_include_pad=False: windows that overlap the zero padding are
        # averaged over the real samples only, so the edges are not pulled
        # towards zero.
        self.avg_pool = nn.AvgPool1d(kernel_size, stride=1, padding=kernel_size//2, count_include_pad=False)

    def forward(self, x: torch.Tensor):
        """Decompose ``x`` along its last axis.

        Args:
            x: Tensor accepted by ``nn.AvgPool1d`` (pooling runs over the
                last axis).

        Returns:
            A ``(trend, seasonal)`` tuple, both with the same shape as ``x``
            and satisfying ``trend + seasonal == x``.
        """
        # With padding = kernel_size // 2 and stride 1, an even kernel makes
        # the pool emit one sample more than the input has, which used to
        # break the subtraction below. Cropping to the input length keeps the
        # trend aligned with x; for odd kernels the slice is a no-op.
        trend = self.avg_pool(x)[..., : x.shape[-1]]
        seasonal = x - trend
        return trend, seasonal

class AutocorrelationAttention(nn.Module):
    """Circular autocorrelation over the last axis, followed by a linear map.

    Args:
        embed_size: Size of the last input axis; also the input and output
            width of the linear projection.
    """

    def __init__(self, embed_size):
        super().__init__()
        self.fc = nn.Linear(embed_size, embed_size)

    def forward(self, x):
        """Return the linearly projected autocorrelation of ``x``.

        The autocorrelation is computed in the frequency domain along the
        last axis of ``x`` and has the same shape as ``x``.
        """
        last_axis = -1
        width = x.size(last_axis)

        # Inverse FFT of (spectrum * conj(spectrum)) yields the circular
        # autocorrelation of the signal along that axis.
        spectrum = torch.fft.rfft(x, dim=last_axis)
        power = spectrum * spectrum.conj()
        correlation = torch.fft.irfft(power, n=width, dim=last_axis)

        # NOTE(review): the correlation runs over the same axis the linear
        # layer projects (size embed_size). If callers pass
        # (batch, time, embed) tensors, this correlates across features, not
        # across time steps -- confirm that is intended.
        return self.fc(correlation)

class Autoformer(nn.Module):
    """Simplified Autoformer-style regressor producing one value per sample.

    The input is decomposed into trend and seasonal parts, the seasonal part
    is passed through the autocorrelation block, both parts are concatenated
    along the feature axis, and the flattened result is mapped to a scalar.

    Args:
        sequence_length: Number of time steps in each input window.
        embed_size: Number of features per time step.
        kernel_size: Moving-average window of the decomposition layer.
            Defaults to ``24 * 2``, the value that was previously hard-coded.
    """

    def __init__(self, sequence_length, embed_size, kernel_size=24 * 2):
        super().__init__()
        self.decomposition = DecompositionLayer(kernel_size=kernel_size)
        self.attention = AutocorrelationAttention(embed_size=embed_size)
        # forward() concatenates trend and attended seasonal along the
        # feature axis, so each time step carries 2 * embed_size values and
        # the flattened input is sequence_length * 2 * embed_size wide. The
        # previous sequence_length * embed_size made every forward pass fail
        # with a matmul shape mismatch.
        self.fc = nn.Linear(sequence_length * embed_size * 2, 1)

    def forward(self, x):
        """Run the model.

        Args:
            x: Tensor of shape ``(batch, sequence_length, embed_size)``.

        Returns:
            Tensor of shape ``(batch, 1)``.
        """
        # The decomposition pools over the last axis, so move time there and
        # restore the (batch, time, feature) layout afterwards.
        trend, seasonal = self.decomposition(x.transpose(1, 2))
        trend, seasonal = trend.transpose(1, 2), seasonal.transpose(1, 2)

        seasonal_attended = self.attention(seasonal)

        # (batch, time, 2 * embed_size)
        combined = torch.cat([trend, seasonal_attended], dim=-1)

        out = self.fc(combined.flatten(start_dim=1))
        return out


# Model hyperparameters: time steps per input window and features per step.
sequence_length, embed_size = 2880, 128
model = Autoformer(sequence_length, embed_size)

# Print the module tree to inspect the layer layout.
print(model)
