In [1]:
import torch
from torch import nn

In [7]:
class SAttn(nn.Module):
    """Single-head self-attention with learned, bias-free Q/K/V projections.

    Input and output feature sizes are both ``input_dim``. Attention is
    computed over the second-to-last (sequence) axis, so any number of
    leading batch dimensions is accepted, including none.

    Args:
        input_dim: size of the last (feature) dimension of the input.
        scale: if True, divide the attention logits by ``sqrt(input_dim)``
            (standard scaled dot-product attention). Defaults to False, which
            keeps the unscaled behaviour of the original implementation.
    """

    def __init__(self, input_dim, scale=False):
        super(SAttn, self).__init__()
        self.input_dim = input_dim
        self.scale = scale
        # Construction order q, k, v is kept so seeded initialisation and
        # state_dict keys ('q.weight', 'k.weight', 'v.weight') are unchanged.
        self.q = nn.Linear(self.input_dim, self.input_dim, bias=False)
        self.k = nn.Linear(self.input_dim, self.input_dim, bias=False)
        self.v = nn.Linear(self.input_dim, self.input_dim, bias=False)

    def forward(self, input_tensor):
        """Apply self-attention.

        Args:
            input_tensor: tensor of shape ``(..., seq, input_dim)``.

        Returns:
            Tensor of shape ``(..., seq, input_dim)``.
        """
        q = self.q(input_tensor)
        k = self.k(input_tensor)
        v = self.v(input_tensor)
        # (..., seq, dim) @ (..., dim, seq) -> (..., seq, seq) logits.
        # transpose(-2, -1) rather than permute(0, 2, 1) so the input is not
        # restricted to exactly three dimensions.
        attn = torch.matmul(q, k.transpose(-2, -1))
        if self.scale:
            attn = attn / (self.input_dim ** 0.5)
        # Normalise over the key axis (last dim) so each query's weights sum to 1.
        attn = torch.softmax(attn, dim=-1)
        return torch.matmul(attn, v)

In [8]:
class FFTLayer(nn.Module):
    """Learnable complex filter applied in the frequency domain along dim 1.

    The input is transformed with a real FFT over the sequence axis (dim 1),
    multiplied element-wise by a learned complex weight of shape
    ``(1, seq_len // 2 + 1, feat_size)``, and transformed back. Both
    transforms use ``norm='forward'``, so they form an exact round-trip pair.
    An all-ones weight therefore leaves the input unchanged.

    Args:
        seq_len: FFT length along dim 1. Inputs of a different length are
            zero-padded or trimmed to this length by ``torch.fft.rfft``.
        feat_size: size of the last (feature) dimension of the input.
    """

    def __init__(self, seq_len, feat_size):
        super().__init__()
        self.seq_len = seq_len
        self.feat_size = feat_size

        num_freqs = seq_len // 2 + 1
        # Trailing dim of 2 holds (real, imag) so it can be viewed as complex.
        # nn.Parameter already requires grad, so no explicit flag is needed.
        self.complex_weight = nn.Parameter(torch.randn(1, num_freqs, feat_size, 2))

    def forward(self, _x):
        """Filter ``_x`` in the frequency domain.

        Args:
            _x: real tensor, expected as ``(batch, seq, feat_size)``.

        Returns:
            Real tensor of shape ``(batch, seq_len, feat_size)``.
        """
        spectrum = torch.fft.rfft(_x, n=self.seq_len, dim=1, norm='forward')
        filtered = spectrum * torch.view_as_complex(self.complex_weight)
        return torch.fft.irfft(filtered, n=self.seq_len, dim=1, norm='forward')

In [9]:
class LSTMModel(nn.Module):
    """Sequence regressor combining an LSTM "trend" branch and an FFT "cycle" branch.

    Both branches read the same input. Their per-step features are
    concatenated along the feature axis and mixed by self-attention. A learned
    linear weighting then collapses the time axis, and a two-layer MLP maps
    the result to ``output_size``.

    Args:
        input_size: number of features per time step.
        hidden_size: LSTM hidden size.
        output_size: size of the prediction vector per sample.
        seq_len: number of time steps the model is built for. The FFT filter
            and the time-axis aggregation layer are sized for exactly this
            length.
        num_layers: number of stacked LSTM layers.
        dropout: dropout probability passed to ``nn.LSTM``. PyTorch applies
            it between stacked LSTM layers, so it has an effect only when
            ``num_layers > 1``.
    """

    def __init__(self, input_size, hidden_size, output_size, seq_len=9, num_layers=2, dropout=0.):
        super(LSTMModel, self).__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.seq_len = seq_len
        self.num_layers = num_layers
        self.dropout = dropout
        # Feature width after concatenating LSTM (hidden_size) and FFT (input_size) outputs.
        self.cat_dim = self.input_size + self.hidden_size

        # RNN: previously `dropout` was accepted but never forwarded here.
        self.rnn = nn.LSTM(self.input_size, self.hidden_size, self.num_layers,
                           batch_first=True, dropout=dropout)

        # FFT branch plus a learned weighting over the seq_len time steps.
        self.fft = FFTLayer(seq_len, self.input_size)
        self.weight_aggr = nn.Linear(self.seq_len, 1)

        # Self-attention over the concatenated per-step features.
        self.EAtt = SAttn(self.cat_dim)

        # Prediction head.
        self.fc1 = nn.Linear(self.cat_dim, self.input_size)
        self.elu = nn.ELU()
        self.fc2 = nn.Linear(self.input_size, self.output_size)

    def forward(self, _x):
        """Predict from a batch of sequences.

        Args:
            _x: tensor of shape ``(batch, seq_len, input_size)``. The LSTM is
                batch-first, and the time-axis aggregation requires exactly
                ``seq_len`` steps.

        Returns:
            Tensor of shape ``(batch, output_size)``.
        """
        # RNN: trend features, (batch, seq, hidden_size).
        f_trend, _ = self.rnn(_x)

        # FFT: cycle features, (batch, seq_len, input_size).
        f_cycle = self.fft(_x)

        # Attention over concatenated features, (batch, seq, cat_dim).
        combine_feature = torch.cat([f_trend, f_cycle], dim=-1)
        combine_feature = self.EAtt(combine_feature)
        # Move time to the last axis so weight_aggr mixes the time steps:
        # (batch, cat_dim, seq) -> (batch, cat_dim, 1) -> (batch, cat_dim).
        combine_feature = self.weight_aggr(combine_feature.permute(0, 2, 1)).flatten(1)

        # Prediction.
        hn = self.fc1(combine_feature)
        hn = self.elu(hn)
        hn = self.fc2(hn)
        return hn