In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class Encoder(nn.Module):
    """Transformer-encoder block that maps each input token to ``out_len`` values.

    The last dimension of ``src`` (size ``in_len``) is projected to ``d_model``
    by a linear layer, an ``e_layers``-deep ``nn.TransformerEncoder`` is run over
    the result, and every output token is projected to ``out_len`` features.

    ``nn.TransformerEncoder`` is used with its default ``batch_first=False``, so
    ``src`` is expected sequence-first as ``(S, N, in_len)`` and the result is
    returned batch-first as ``(N, S, out_len)``.

    Args:
        in_len: Size of the last dimension of ``src``.
        out_len: Size of the last dimension of the output.
        d_model: Transformer hidden size; must be divisible by ``n_heads``.
        n_heads: Number of attention heads.
        e_layers: Number of stacked encoder layers.
        dropout: Dropout probability used inside each encoder layer.
    """

    def __init__(self, in_len, out_len, d_model, n_heads, e_layers, dropout):
        super().__init__()
        self.in_len = in_len
        self.out_len = out_len
        self.d_model = d_model
        self.n_heads = n_heads
        self.e_layers = e_layers
        self.dropout = dropout
        # NOTE: despite its name this is a learned linear projection of the last
        # dimension; it does not add any positional information.
        self.pos_encoder = nn.Linear(in_len, d_model)
        # ``dropout`` must be passed by keyword: the third positional parameter
        # of nn.TransformerEncoderLayer is ``dim_feedforward``, not ``dropout``.
        self.encoder_layer = nn.TransformerEncoderLayer(d_model, n_heads, dropout=dropout)
        # nn.TransformerEncoder deep-copies ``encoder_layer`` e_layers times, so the
        # template itself is never called in forward(); it is kept as an attribute
        # so existing attribute access and state_dict keys stay unchanged.
        self.encoder = nn.TransformerEncoder(self.encoder_layer, e_layers)
        self.fc = nn.Linear(d_model, out_len)

    def forward(self, src):
        """Encode ``src`` and project every token to ``out_len`` values.

        Args:
            src: Tensor of shape ``(S, N, in_len)`` (sequence-first).

        Returns:
            Tensor of shape ``(N, S, out_len)`` (batch-first).
        """
        hidden = self.encoder(self.pos_encoder(src))  # (S, N, d_model)
        # Switch to batch-first before the output head.
        return self.fc(hidden.transpose(0, 1))  # (N, S, out_len)
    
class Decoder(nn.Module):
    """Transformer-decoder block that cross-attends to an encoder memory.

    The last dimension of the decoder input (size ``in_len``) is projected to
    ``d_model`` by a linear layer, a ``d_layers``-deep ``nn.TransformerDecoder``
    is run over it with cross-attention to ``memory``, and every output token is
    projected to ``out_len`` features.

    ``nn.TransformerDecoder`` is used with its default ``batch_first=False``, so
    inputs are sequence-first and the result is returned batch-first.

    Args:
        in_len: Size of the last dimension of the decoder input.
        out_len: Size of the last dimension of the output.
        d_model: Transformer hidden size; must be divisible by ``n_heads``.
        n_heads: Number of attention heads.
        d_layers: Number of stacked decoder layers.
        dropout: Dropout probability used inside each decoder layer.
    """

    def __init__(self, in_len, out_len, d_model, n_heads, d_layers, dropout):
        super().__init__()
        self.in_len = in_len
        self.out_len = out_len
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_layers = d_layers
        self.dropout = dropout
        # NOTE: despite its name this is a learned linear projection of the last
        # dimension; it does not add any positional information.
        self.pos_decoder = nn.Linear(in_len, d_model)
        # ``dropout`` must be passed by keyword: the third positional parameter
        # of nn.TransformerDecoderLayer is ``dim_feedforward``, not ``dropout``.
        self.decoder_layer = nn.TransformerDecoderLayer(d_model, n_heads, dropout=dropout)
        # nn.TransformerDecoder deep-copies ``decoder_layer`` d_layers times, so the
        # template itself is never called in forward(); it is kept as an attribute
        # so existing attribute access and state_dict keys stay unchanged.
        self.decoder = nn.TransformerDecoder(self.decoder_layer, d_layers)
        self.fc = nn.Linear(d_model, out_len)

    def forward(self, src, memory):
        """Decode ``src`` against ``memory`` and project to ``out_len`` values.

        Args:
            src: Decoder input of shape ``(T, N, in_len)`` (sequence-first).
            memory: Encoder states of shape ``(S, N, d_model)`` (sequence-first).

        Returns:
            Tensor of shape ``(N, T, out_len)`` (batch-first).
        """
        tgt = self.pos_decoder(src)  # (T, N, d_model)
        # No tgt_mask is passed, so every decoder position may attend to every
        # other decoder position (non-causal decoding).
        hidden = self.decoder(tgt, memory)  # (T, N, d_model)
        # Switch to batch-first before the output head.
        return self.fc(hidden.transpose(0, 1))  # (N, T, out_len)

class Informer(nn.Module):
    """Encoder-decoder model combining :class:`Encoder` and :class:`Decoder`.

    The encoder stack turns the input into ``d_model``-wide memory. The decoder
    receives an all-zero placeholder (``out_len`` features per token), cross-
    attends to that memory, and produces the output.

    Args:
        in_len: Size of the last dimension of the model input.
        out_len: Size of the last dimension of the output, and the feature
            width of the zero placeholder fed to the decoder.
        d_model: Transformer hidden size shared by both stacks; must be
            divisible by ``n_heads``.
        n_heads: Number of attention heads in both stacks.
        e_layers: Number of encoder layers.
        d_layers: Number of decoder layers.
        dropout: Dropout probability used inside every transformer layer.
    """

    def __init__(self, in_len, out_len, d_model, n_heads, e_layers, d_layers, dropout):
        super().__init__()
        self.in_len = in_len
        self.out_len = out_len
        self.d_model = d_model
        self.n_heads = n_heads
        self.e_layers = e_layers
        self.d_layers = d_layers
        self.dropout = dropout
        self.encoder = Encoder(in_len, out_len, d_model, n_heads, e_layers, dropout)
        # The decoder consumes out_len-wide tokens and emits out_len-wide tokens.
        self.decoder = Decoder(out_len, out_len, d_model, n_heads, d_layers, dropout)

    def forward(self, src):
        """Run the encoder-decoder on ``src``.

        Args:
            src: Tensor of shape ``(S, N, in_len)``. It is sequence-first to match
                the default ``batch_first=False`` of the transformer stacks.

        Returns:
            Tensor of shape ``(N, S, out_len)`` (batch-first).
        """
        # Cross-attention needs d_model-wide, sequence-first states. Take them
        # from the encoder stack directly, not from Encoder.forward, whose output
        # is already transposed to batch-first and projected to out_len features.
        memory = self.encoder.encoder(self.encoder.pos_encoder(src))  # (S, N, d_model)
        # One zero placeholder token per encoder token. It is out_len features
        # wide to match the decoder's input projection. new_zeros keeps the
        # dtype and device of src.
        tgt = src.new_zeros((*src.shape[:-1], self.out_len))  # (S, N, out_len)
        # NOTE(review): the placeholder tokens are identical and pos_decoder adds
        # no positional information. In eval mode (no dropout), every one of the
        # S output rows is therefore the same for a given batch element. A
        # positional or per-token embedding is probably intended; confirm.
        # Also, self.encoder.fc is not used on this path, so its parameters
        # receive no gradient.
        return self.decoder(tgt, memory)  # (N, S, out_len)
