In [None]:
import torch
import torch.nn as nn

In [None]:
# For the covariates at each time step, project the feature dimension from r down to r_bar, where r_bar is much smaller than r.
class FeatureProjection(nn.Module):
    """Two-layer MLP with a linear skip connection.

    The main path is ``Linear -> ReLU -> Linear -> Dropout``; the skip path
    is a single ``Linear`` from the input straight to the output width. The
    two paths are summed, so the result has ``output_size`` features on the
    last dimension.

    Args:
        input_size: Number of features on the last dimension of the input.
        hidden_size: Width of the hidden layer on the main path.
        output_size: Number of features produced on the last dimension.
        dropout_prob: Dropout probability applied at the end of the main path.
    """

    def __init__(self, input_size, hidden_size, output_size, dropout_prob):
        super().__init__()
        # Main path: input -> hidden (ReLU) -> output, followed by dropout.
        self.dense_relu = nn.Linear(in_features=input_size, out_features=hidden_size)
        self.relu = nn.ReLU()
        self.dense_linear = nn.Linear(in_features=hidden_size, out_features=output_size)
        self.dropout = nn.Dropout(p=dropout_prob)
        # Skip path: linear map so the input can be added to the main-path output.
        self.linear = nn.Linear(in_features=input_size, out_features=output_size)

    def forward(self, x):
        """Return ``dropout(main_path(x)) + skip(x)``."""
        hidden = self.relu(self.dense_relu(x))
        main_path = self.dropout(self.dense_linear(hidden))
        skip_path = self.linear(x)
        return main_path + skip_path

In [None]:
class DenseEncoder(nn.Module):
    """Stack of ``num_block`` ``FeatureProjection`` blocks applied in sequence.

    Layer widths run ``input_size -> hidden_size -> ... -> hidden_size ->
    output_size``. With a single block the mapping goes directly from
    ``input_size`` to ``output_size``. Each block is a separate module with
    its own parameters, and only the blocks that are actually used are
    registered.

    Args:
        input_size: Number of features on the last dimension of the input.
        hidden_size: Width used inside every block and between blocks.
        output_size: Number of features produced by the final block.
        dropout_prob: Dropout probability passed to every block.
        num_block: Number of stacked blocks; must be at least 1.

    Raises:
        ValueError: If ``num_block`` is smaller than 1.
    """

    def __init__(self, input_size, hidden_size, output_size, dropout_prob, num_block):
        super().__init__()
        if num_block < 1:
            raise ValueError(f"num_block must be >= 1, got {num_block}")
        self.num_block = num_block
        # Input width of block i / output width of block i.
        in_sizes = [input_size] + [hidden_size] * (num_block - 1)
        out_sizes = [hidden_size] * (num_block - 1) + [output_size]
        # One independent block per layer, so middle layers do not share weights.
        self.blocks = nn.Sequential(*[
            FeatureProjection(n_in, hidden_size, n_out, dropout_prob)
            for n_in, n_out in zip(in_sizes, out_sizes)
        ])

    def forward(self, x):
        """Run ``x`` through all blocks; the last dimension becomes ``output_size``."""
        return self.blocks(x)

In [None]:
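# The decoder output for each sample is a p x H matrix.
# Column t is the p-dimensional feature vector predicted for time step t.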
class DenseDecoder(nn.Module):
    """Stack of ``FeatureProjection`` blocks whose output is reshaped to ``(batch, p, H)``.

    Layer widths run ``input_size -> hidden_size -> ... -> hidden_size ->
    H * p``. With a single block the mapping goes directly from
    ``input_size`` to ``H * p``. Each block is a separate module with its own
    parameters, and only the blocks that are actually used are registered.

    Args:
        input_size: Number of features on the last dimension of the input.
        hidden_size: Width used inside every block and between blocks.
        H: Size of the last output dimension.
        p: Size of the second output dimension.
        dropout_prob: Dropout probability passed to every block.
        num_block: Number of stacked blocks; must be at least 1.

    Raises:
        ValueError: If ``num_block`` is smaller than 1.
    """

    def __init__(self, input_size, hidden_size, H, p, dropout_prob, num_block):
        super().__init__()
        if num_block < 1:
            raise ValueError(f"num_block must be >= 1, got {num_block}")
        self.num_block = num_block
        self.H = H
        self.p = p
        self.output_size = H * p
        # Input width of block i / output width of block i.
        in_sizes = [input_size] + [hidden_size] * (num_block - 1)
        out_sizes = [hidden_size] * (num_block - 1) + [self.output_size]
        # One independent block per layer, so middle layers do not share weights.
        self.blocks = nn.Sequential(*[
            FeatureProjection(n_in, hidden_size, n_out, dropout_prob)
            for n_in, n_out in zip(in_sizes, out_sizes)
        ])

    def forward(self, x):
        """Decode ``x`` and reshape the flat ``H * p`` output to ``(batch, p, H)``."""
        y = self.blocks(x)
        # The batch size is read for every num_block value; previously it was
        # only defined when num_block >= 3, so 1 or 2 blocks crashed here.
        batch_size = y.shape[0]
        return torch.reshape(y, (batch_size, self.p, self.H))

In [None]:
class TemporalDecoder(nn.Module):
    """Single ``FeatureProjection`` block that maps the last dimension to one value.

    Args:
        input_size: Number of features on the last dimension of the input.
        hidden_size: Hidden width of the wrapped block.
        dropout_prob: Dropout probability of the wrapped block.
    """

    def __init__(self, input_size, hidden_size, dropout_prob):
        super().__init__()
        # Output width is fixed to 1.
        self.decoder = FeatureProjection(
            input_size,
            hidden_size,
            output_size=1,
            dropout_prob=dropout_prob,
        )

    def forward(self, x):
        """Apply the wrapped block; the last dimension of the result has size 1."""
        return self.decoder(x)

In [None]:
# Input layout: (batch_size, seq_length, feature_size)
class TiDE(nn.Module):
    """Dense encoder/decoder forecaster over a look-back window.

    The input is a tensor of shape ``(batch_size, seq_length, feature_size)``.
    On the feature axis, column ``y_position`` holds the target series,
    column ``attrib_position`` holds an attribute series, and every other
    column is treated as a covariate.

    Pipeline (shapes in comments inside ``forward``):
      1. project the covariates of every time step down to ``temporalWidth``;
      2. concatenate target, attribute and flattened projected covariates
         and encode them with ``DenseEncoder``;
      3. decode to a ``(decoderOutputDim, H)`` matrix with ``DenseDecoder``;
      4. for each of the ``H`` output steps, concatenate the decoded column
         with the projected covariates of the matching step among the last
         ``H`` look-back steps and map it to a scalar with ``TemporalDecoder``;
      5. add a linear map of the look-back target (``seq_length -> H``).

    All sub-networks are registered on the module exactly once. If both
    ``seq_length`` and ``feature_size`` are given to the constructor they are
    built immediately; otherwise they are built on the first ``forward`` call
    from the shape of its input. In the latter case run one forward pass
    before creating an optimizer, because ``parameters()`` is empty until then.

    Args:
        temporalWidth: Projected covariate width per time step (r_bar).
        hiddenSize: Hidden width of the dense encoder and decoder; the
            encoder output width is ``hiddenSize // 2``.
        numEncoderLayers: Number of blocks in the dense encoder.
        H: Number of output steps.
        decoderOutputDim: Per-step width of the dense decoder output (p).
        numDecoderLayers: Number of blocks in the dense decoder.
        temporalDecoderHidden: Hidden width of the temporal decoder.
        y_position: Feature index of the target series.
        attrib_position: Feature index of the attribute series.
        dropout_prob: Dropout probability used by every sub-network.
        seq_length: Optional look-back length, for building eagerly.
        feature_size: Optional feature count, for building eagerly.

    Raises:
        ValueError: From the constructor or ``forward`` if ``H`` is not in
            ``[1, seq_length]``, if a position is out of range, if no
            covariate column remains, if the input is not 3-D, or if a later
            input does not match the shape the sub-networks were built for.
    """

    def __init__(self, temporalWidth, hiddenSize, numEncoderLayers, H, decoderOutputDim,
                 numDecoderLayers, temporalDecoderHidden, y_position=0, attrib_position=1,
                 dropout_prob=0.2, seq_length=None, feature_size=None):
        # Must run before any sub-module is assigned to ``self``.
        super().__init__()
        self.attrib_position = attrib_position
        self.y_position = y_position
        self.temporalWidth = temporalWidth  # r_bar
        self.hiddenSize = hiddenSize
        self.numEncoderLayers = numEncoderLayers
        self.H = H
        self.decoderOutputDim = decoderOutputDim  # p
        self.numDecoderLayers = numDecoderLayers
        self.temporalDecoderHidden = temporalDecoderHidden
        self.dropout_prob = dropout_prob
        self.flatten = nn.Flatten()
        # (seq_length, feature_size) the sub-networks were built for, or None.
        self._built_shape = None
        self._covariate_columns = None
        if seq_length is not None and feature_size is not None:
            self._build(seq_length, feature_size)

    def _build(self, seq_length, feature_size):
        """Create and register every sub-network for the given input shape."""
        if not 1 <= self.H <= seq_length:
            raise ValueError(
                f"H must be in [1, seq_length]; got H={self.H}, seq_length={seq_length}"
            )
        reserved = set()
        for position in (self.y_position, self.attrib_position):
            if not -feature_size <= position < feature_size:
                raise ValueError(
                    f"feature position {position} is out of range for feature_size={feature_size}"
                )
            reserved.add(position % feature_size)
        columns = [c for c in range(feature_size) if c not in reserved]
        if not columns:
            raise ValueError("input has no covariate columns besides the target and attribute")

        p_drop = self.dropout_prob
        self.feature_projection = FeatureProjection(
            len(columns), (feature_size + self.temporalWidth) // 2, self.temporalWidth, p_drop
        )
        # Encoder input = target + attribute + flattened projected covariates.
        encoder_in = seq_length * (2 + self.temporalWidth)
        self.dense_encoder = DenseEncoder(
            encoder_in, self.hiddenSize, self.hiddenSize // 2, p_drop, self.numEncoderLayers
        )
        self.dense_decoder = DenseDecoder(
            self.hiddenSize // 2, self.hiddenSize, self.H, self.decoderOutputDim,
            p_drop, self.numDecoderLayers
        )
        self.temporal_decoder = TemporalDecoder(
            self.decoderOutputDim + self.temporalWidth, self.temporalDecoderHidden, p_drop
        )
        # Linear skip from the look-back target straight to the H outputs.
        self.lookback_skip = nn.Linear(seq_length, self.H)
        self._covariate_columns = columns
        self._built_shape = (seq_length, feature_size)

    def forward(self, x):
        """Map ``x`` of shape ``(batch, seq_length, feature_size)`` to ``(batch, H)``."""
        if x.dim() != 3:
            raise ValueError(
                f"expected input of shape (batch, seq_length, feature_size), got {tuple(x.shape)}"
            )
        _, seq_length, feature_size = x.shape
        if self._built_shape is None:
            self._build(seq_length, feature_size)
            # Freshly created sub-networks default to CPU / training mode;
            # align them with the input and with this module's current mode.
            self.to(x.device)
            if x.is_floating_point():
                self.to(x.dtype)
            self.train(self.training)
        elif self._built_shape != (seq_length, feature_size):
            raise ValueError(
                f"model was built for (seq_length, feature_size)={self._built_shape}, "
                f"got {(seq_length, feature_size)}"
            )

        # Integer indexing drops the feature axis, so no squeeze is needed
        # (a bare squeeze would also drop the batch axis when batch_size == 1).
        y_lookback = x[:, :, self.y_position]            # (batch, seq_length)
        a = x[:, :, self.attrib_position]                # (batch, seq_length)
        covariates = x[:, :, self._covariate_columns]    # (batch, seq_length, n_covariates)

        # nn.Linear acts on the last dimension, so all time steps are projected at once.
        x_bar = self.feature_projection(covariates)      # (batch, seq_length, temporalWidth)

        # Concatenate along the feature axis (dim=1), not the batch axis.
        encoder_input = torch.cat((y_lookback, a, self.flatten(x_bar)), dim=1)
        e = self.dense_encoder(encoder_input)            # (batch, hiddenSize // 2)
        g = self.dense_decoder(e)                        # (batch, decoderOutputDim, H)

        # Pair each decoded step with the projected covariates of the last H steps.
        x_bar_tail = x_bar[:, -self.H:, :]               # (batch, H, temporalWidth)
        temporal_decoder_input = torch.cat((g.transpose(1, 2), x_bar_tail), dim=-1)
        # temporal_decoder_input: (batch, H, decoderOutputDim + temporalWidth)
        y_bar = self.temporal_decoder(temporal_decoder_input).squeeze(-1)  # (batch, H)

        output = self.lookback_skip(y_lookback) + y_bar  # (batch, H)
        return output