In [20]:
import torch
from torch import nn
from torch.nn.utils import weight_norm
from d2l import torch as d2l

import math

In [21]:
class ConvBlock(nn.Module):
    """Convolutional front-end that maps a 4-D input to a 3-D feature map.

    Pipeline, in order: 1x64 conv ("same") -> BatchNorm -> 22x1 conv ->
    BatchNorm -> ELU -> 1x8 average pool -> 1x16 conv ("same") -> BatchNorm
    -> ELU -> 1x7 average pool -> 1x5 conv -> flatten of dims 2 and 3.

    For the ``(N, 1, 22, 1126)`` input used in this notebook the result has
    shape ``(N, 32, 16)``.

    Args:
        dropout: Probability used to construct ``self.dropout``.
        **kwargs: Forwarded unchanged to ``nn.Module.__init__``.
    """

    def __init__(self, dropout, **kwargs):
        super().__init__(**kwargs)

        # NOTE(review): this layer is created but is not placed in
        # ``Convolutional_Block`` and is never called in ``forward`` --
        # confirm whether dropout was meant to be applied.
        self.dropout = nn.Dropout(dropout)

        # The four convolutions are kept as named attributes *and* reused
        # inside the Sequential below, so each one is reachable both ways.
        self.Temporal_Conv = nn.Conv2d(1, 16, kernel_size=(1, 64), padding="same")
        # NOTE(review): no ``groups`` argument is given, so despite the "DW"
        # in the name this is an ordinary (non-depthwise) convolution.
        self.Channel_DW_Conv = nn.Conv2d(16, 32, kernel_size=(22, 1))
        self.Spatial_Conv = nn.Conv2d(32, 32, kernel_size=(1, 16), padding="same")
        self.window = nn.Conv2d(32, 32, kernel_size=(1, 5))

        stages = [
            self.Temporal_Conv,
            nn.BatchNorm2d(16),
            self.Channel_DW_Conv,
            nn.BatchNorm2d(32),
            nn.ELU(),
            nn.AvgPool2d(kernel_size=(1, 8)),
            self.Spatial_Conv,
            nn.BatchNorm2d(32),
            nn.ELU(),
            nn.AvgPool2d(kernel_size=(1, 7)),
            self.window,
            nn.Flatten(2, 3),  # merge the (height, width) axes into one
        ]
        self.Convolutional_Block = nn.Sequential(*stages)

    def forward(self, X):
        """Run ``X`` through the whole convolutional pipeline."""
        features = self.Convolutional_Block(X)
        return features

In [22]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention.

    Queries, keys and values are each projected by a linear layer to
    ``num_hiddens`` features, split into ``num_heads`` heads, passed through
    ``d2l.DotProductAttention``, merged again and projected by ``W_o``.

    Args:
        key_size: Input feature size of the keys.
        query_size: Input feature size of the queries.
        value_size: Input feature size of the values.
        num_hiddens: Output feature size of every projection.
        num_heads: Number of attention heads.
        dropout: Dropout probability handed to ``d2l.DotProductAttention``.
        bias: Whether the four linear projections use a bias term.
        **kwargs: Forwarded unchanged to ``nn.Module.__init__``.
    """

    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 num_heads, dropout, bias=False, **kwargs):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.attention = d2l.DotProductAttention(dropout)
        # Input projections for queries / keys / values, then the output one.
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

    def forward(self, queries, keys, values, valid_lens):
        """Apply multi-head attention.

        Shape of queries, keys, values:
            (batch_size, number of queries or key-value pairs, num_hiddens)
        Shape of valid_lens:
            (batch_size,) or (batch_size, number of queries); may be ``None``.
        """
        heads = self.num_heads

        # After projection and splitting, each of the three tensors has shape
        # (batch_size * num_heads, number of queries or key-value pairs,
        #  num_hiddens / num_heads).
        q_heads = transpose_qkv(self.W_q(queries), heads)
        k_heads = transpose_qkv(self.W_k(keys), heads)
        v_heads = transpose_qkv(self.W_v(values), heads)

        # Along axis 0, repeat the first item (scalar or vector) num_heads
        # times, then the second item, and so on, so the lengths line up with
        # the head-expanded batch axis.
        if valid_lens is None:
            head_lens = None
        else:
            head_lens = torch.repeat_interleave(
                valid_lens, repeats=heads, dim=0)

        # Shape: (batch_size * num_heads, number of queries,
        #         num_hiddens / num_heads)
        attended = self.attention(q_heads, k_heads, v_heads, head_lens)

        # Shape: (batch_size, number of queries, num_hiddens)
        merged = transpose_output(attended, heads)
        return self.W_o(merged)
    
def transpose_qkv(X, num_heads):
    """Reshape ``X`` so that all attention heads can be computed in parallel.

    ``X`` of shape (batch, steps, features) is split along the feature axis
    into ``num_heads`` pieces and returned with shape
    (batch * num_heads, steps, features / num_heads).
    """
    batch, steps = X.shape[0], X.shape[1]

    # (batch, steps, heads, per-head features) -> (batch, heads, steps, ...)
    per_head = X.reshape(batch, steps, num_heads, -1).permute(0, 2, 1, 3)

    # Fold the head axis into the batch axis.
    return per_head.reshape(-1, steps, per_head.shape[3])

def transpose_output(X, num_heads):
    """Reverse the reshaping done by ``transpose_qkv``.

    ``X`` of shape (batch * num_heads, steps, head_dim) is returned with
    shape (batch, steps, num_heads * head_dim).
    """
    steps, head_dim = X.shape[1], X.shape[2]

    # Recover the head axis, then move it next to the per-head features.
    merged = X.reshape(-1, num_heads, steps, head_dim).permute(0, 2, 1, 3)

    # Concatenate the heads along the last axis.
    return merged.reshape(merged.shape[0], steps, -1)

In [23]:
class AddNorm(nn.Module):
    """Residual connection with dropout on the sub-layer output.

    Despite the name, no normalisation is applied here: the block returns
    ``dropout(Y) + X``.

    Args:
        dropout: Dropout probability applied to ``Y``.
        **kwargs: Forwarded unchanged to ``nn.Module.__init__``.
    """

    def __init__(self, dropout, **kwargs):
        super().__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)

    def forward(self, X, Y):
        """Return the sum of the dropped-out ``Y`` and the skip input ``X``."""
        dropped = self.dropout(Y)
        return dropped + X

In [24]:
class Chomp1d(nn.Module):
    """Drop the last ``chomp_size`` entries along the third axis of the input.

    Args:
        chomp_size: Number of trailing positions to remove. ``0`` leaves the
            input unchanged.
    """

    def __init__(self, chomp_size):
        super(Chomp1d, self).__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        """Return ``x`` without its last ``chomp_size`` positions on axis 2."""
        # ``x[:, :, :-0]`` is ``x[:, :, :0]`` and would yield an empty tensor,
        # so a zero chomp has to be handled explicitly as a no-op.
        if self.chomp_size == 0:
            return x.contiguous()
        return x[:, :, :-self.chomp_size].contiguous()


class TemporalBlock(nn.Module):
    """Residual block built from two weight-normalised dilated 1-D convolutions.

    Each convolution is followed by ``Chomp1d(padding)``, a ReLU and dropout.
    The block input is added to that result (through a 1x1 convolution when
    the channel counts differ) and the sum goes through a final ReLU.

    Args:
        n_inputs: Number of input channels.
        n_outputs: Number of output channels.
        kernel_size: Kernel size of both convolutions.
        stride: Stride of both convolutions.
        dilation: Dilation of both convolutions.
        padding: Padding of both convolutions; the same amount is chomped
            from the end of each convolution's output.
        dropout: Dropout probability used after each ReLU.
    """

    def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation, padding, dropout=0.3):
        super().__init__()
        conv_opts = dict(stride=stride, padding=padding, dilation=dilation)

        # First conv stage.
        self.conv1 = weight_norm(nn.Conv1d(n_inputs, n_outputs, kernel_size, **conv_opts))
        self.chomp1 = Chomp1d(padding)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout)

        # Second conv stage; keeps the channel count at n_outputs.
        self.conv2 = weight_norm(nn.Conv1d(n_outputs, n_outputs, kernel_size, **conv_opts))
        self.chomp2 = Chomp1d(padding)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout)

        first_stage = (self.conv1, self.chomp1, self.relu1, self.dropout1)
        second_stage = (self.conv2, self.chomp2, self.relu2, self.dropout2)
        self.net = nn.Sequential(*first_stage, *second_stage)

        # A 1x1 conv on the skip path is only needed to match channel counts.
        if n_inputs != n_outputs:
            self.downsample = nn.Conv1d(n_inputs, n_outputs, 1)
        else:
            self.downsample = None
        self.relu = nn.ReLU()
        self.init_weights()

    def init_weights(self):
        """Fill the convolution weights in place from N(0, 0.01)."""
        # NOTE(review): conv1/conv2 are wrapped in weight_norm, whose weight
        # is presumably recomputed from weight_g / weight_v on each forward
        # pass; if so, this in-place fill does not affect them -- confirm.
        layers = [self.conv1, self.conv2]
        if self.downsample is not None:
            layers.append(self.downsample)
        for layer in layers:
            layer.weight.data.normal_(0, 0.01)

    def forward(self, x):
        """Return ``relu(net(x) + skip(x))``."""
        branch = self.net(x)
        if self.downsample is None:
            shortcut = x
        else:
            shortcut = self.downsample(x)
        return self.relu(branch + shortcut)


class TemporalConvNet(nn.Module):
    """Stack of ``TemporalBlock`` layers with exponentially growing dilation.

    Level ``i`` uses dilation ``2 ** i`` and padding
    ``(kernel_size - 1) * 2 ** i``; every block uses stride 1.

    Args:
        num_inputs: Number of channels of the network input.
        num_channels: Output channel count of each level, in order.
        kernel_size: Kernel size shared by all levels.
        dropout: Dropout probability passed to every ``TemporalBlock``.
    """

    def __init__(self, num_inputs, num_channels, kernel_size=2, dropout=0.2):
        super().__init__()
        blocks = []
        prev_width = num_inputs  # channels entering the current level
        for level, width in enumerate(num_channels):
            dilation = 2 ** level
            blocks.append(
                TemporalBlock(prev_width, width, kernel_size, stride=1,
                              dilation=dilation,
                              padding=(kernel_size - 1) * dilation,
                              dropout=dropout)
            )
            prev_width = width

        self.network = nn.Sequential(*blocks)

    def forward(self, x):
        """Pass ``x`` through every temporal block in order."""
        return self.network(x)

In [25]:
class EncoderBlock(nn.Module):
    """Layer norm -> multi-head self-attention -> residual add -> TCN.

    The same tensor is used as queries, keys and values, so ``query_size``,
    ``key_size`` and ``value_size`` must all equal the size of the input's
    last axis. The residual sum adds the attention output to its input, so
    ``num_hiddens`` has to be compatible with that size as well. The TCN is
    built for 32 input channels, i.e. axis 1 of the input must have size 32.

    Args:
        key_size: Feature size of the keys.
        query_size: Feature size of the queries; also the size normalised by
            the layer norm.
        value_size: Feature size of the values.
        num_hiddens: Feature size produced by the attention layer.
        num_heads: Number of attention heads.
        dropout: Dropout probability for the attention layer and the
            residual connection. The TCN uses its own fixed dropout of 0.5.
        **kwargs: Forwarded unchanged to ``nn.Module.__init__``.
    """

    def __init__(self, key_size, query_size, value_size, num_hiddens,
                  num_heads, dropout, **kwargs):
        super(EncoderBlock, self).__init__(**kwargs)
        self.num_channels=[32]
        # Normalise over the feature axis that is fed to W_q. This used to be
        # a hard-coded 16, which only matched when query_size was 16.
        self.ln = nn.LayerNorm(query_size)
        self.attention = MultiHeadAttention(key_size, query_size, value_size, num_hiddens, num_heads, dropout)
        self.addnorm = AddNorm(dropout)
        self.tcn = TemporalConvNet(32, self.num_channels, kernel_size=4, dropout=0.5)

    def forward(self, X, valid_lens):
        """Apply the block to ``X``.

        Args:
            X: Input tensor; its last axis is layer-normalised and used as the
                attention feature axis, and axis 1 is the TCN channel axis.
            valid_lens: Passed through to the attention layer; may be ``None``.
        """
        X = self.ln(X)
        # Self-attention: the normalised input is query, key and value.
        Y = self.attention(X, X, X, valid_lens)
        # Residual connection around the attention layer (no normalisation).
        Y = self.addnorm(X,Y)
        return  self.tcn(Y)

In [26]:
class Encoder(nn.Module):
    """Full model: ``ConvBlock`` -> stacked ``EncoderBlock`` -> linear head.

    Args:
        key_size: Key feature size passed to every ``EncoderBlock``.
        query_size: Query feature size passed to every ``EncoderBlock``.
        value_size: Value feature size passed to every ``EncoderBlock``.
        num_hiddens: Attention hidden size passed to every ``EncoderBlock``.
        num_heads: Number of attention heads.
        dropout: Dropout probability passed to ``ConvBlock`` and every
            ``EncoderBlock``.
        num_layers: Number of ``EncoderBlock`` layers to stack (default 2,
            the value that used to be hard-coded).
        num_classes: Number of outputs of the final linear layer (default 4,
            the value that used to be hard-coded).
        **kwargs: Forwarded unchanged to ``nn.Module.__init__``.
    """

    def __init__(self, key_size, query_size, value_size, num_hiddens,
                  num_heads, dropout, num_layers=2, num_classes=4, **kwargs):
        super(Encoder, self).__init__(**kwargs)
        self.ConvBlock = ConvBlock(dropout)
        self.blks = nn.Sequential()
        for i in range(num_layers):
            # Keep the "block<i>" names so existing state dicts still load.
            self.blks.add_module("block"+str(i),
                EncoderBlock(key_size, query_size, value_size, num_hiddens,
                  num_heads, dropout))
        self.fl = nn.Flatten(1,2)
        # 32*16 is the flattened size of the (32, 16) feature map that the
        # encoder blocks produce for this notebook's input.
        self.sm = nn.Linear(32*16, num_classes, bias=False)

    def forward(self, X, valid_lens=None, *args):
        """Return the linear-head outputs for ``X``.

        Args:
            X: Input batch for ``ConvBlock``.
            valid_lens: Passed to every encoder block; ``None`` by default.
            *args: Accepted and ignored.

        Returns:
            Tensor of shape ``(batch, num_classes)``; no softmax is applied.
        """
        X = self.ConvBlock(X)
        # Iterate manually because each block needs the extra valid_lens
        # argument, which nn.Sequential.forward cannot pass along.
        for blk in self.blks:
            X = blk(X, valid_lens)
        return self.sm(self.fl(X))

In [27]:
# Smoke test: push one random batch through the model and show the output
# shape (batch of 288, single input channel, 22 rows, 1126 columns).
X = torch.randn(288, 1, 22, 1126)
ec = Encoder(key_size=16, query_size=16, value_size=16, num_hiddens=16,
             num_heads=2, dropout=0.5)
ec(X, None).shape

torch.Size([288, 4])