In [36]:
import torch
import torch.nn as nn


class Encoder(nn.Module):
    """编码器"""

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(input_size, hidden_size)  # 词嵌入层
        self.lstm = nn.LSTM(hidden_size, hidden_size)  # lstm层

    def forward(self, input):
        # input: [seq_len, batch_size]
        # embedded: [seq_len, batch_size, hidden_size]
        # output: [seq_len, batch_size, hidden_size]
        # hidden: [1, batch_size, hidden_size]
        # cell: [1, batch_size, hidden_size]
        embedded = self.embedding(input)
        output, (hidden, cell) = self.lstm(embedded)
        return output, hidden, cell


class Attention(nn.Module):
    """注意力机制"""

    def __init__(self, hidden_size):
        super().__init__()
        self.Wq = nn.Linear(hidden_size, hidden_size)
        self.Wk = nn.Linear(hidden_size, hidden_size)
        self.W = nn.Linear(hidden_size, 1)

    def forward(self, encoder_outputs, decoder_hidden):
        # encoder_outputs: [seq_len, batch_size, hidden_size]
        # decoder_hidden: [1, batch_size, hidden_size]
        encoder_outputs = encoder_outputs.permute(1, 0, 2)  # 转换形状为[batch_size, seq_len, hidden_size]
        decoder_hidden = decoder_hidden.permute(1, 0, 2)  # 转换形状为[batch_size, 1, hidden_size]
        score = self.W(torch.tanh(self.Wq(decoder_hidden) + self.Wk(encoder_outputs)))
        attention_weights = torch.softmax(score, dim=2)
        context = attention_weights * encoder_outputs
        return context


class Decoder(nn.Module):
    """解码器"""

    def __init__(self, output_size, hidden_size):
        super().__init__()
        self.embedding = nn.Embedding(output_size, hidden_size)  # 词嵌入层
        self.attention = Attention(hidden_size)  # 注意力机制
        self.lstm = nn.LSTM(hidden_size * 2, hidden_size)  # lstm层
        self.linear = nn.Linear(hidden_size, output_size)  # 全连接层

    def forward(self, input, hidden, cell, encoder_outputs):
        # input: [1, batch_size]
        # embedded: [1, batch_size, hidden_size]
        embedded = self.embedding(input)
        context = self.attention(encoder_outputs, hidden)
        input = torch.cat((embedded, context), dim=2)  # 拼接输入向量和上下文向量
        output, (hidden, cell) = self.lstm(input, (hidden, cell))
        output = self.linear(output)
        return output, hidden, cell


class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, source, target, teacher_forcing_ratio=0.5):
        # source: [seq_len, batch_size]
        # target: [seq_len, batch_size]
        batch_size = source.size(1)  # 批量大小
        target_len = target.size(0)  # 目标序列长度
        target_vocab_size = self.decoder.linear.out_features  # 目标词表大小

        outputs = torch.zeros(target_len, batch_size, target_vocab_size)  # 初始化输出张量
        encoder_outputs, hidden, cell = self.encoder(source)  # 获取编码器输出，隐藏层状态，细胞状态
        input = target[0, :]  # 目标序列第一个词<SOS>作为解码器输入
        for t in range(1, target_len):
            output, hidden, cell = self.decoder(input.unsqueeze(0), hidden, cell, encoder_outputs)  # 获取解码器输出
            outputs[t:,] = output.squeeze(1)  # 将解码器输出添加到输出中
            # 教师强制是指在解码器训练时，使用真实目标序值作为输入
            teacher_force = torch.rand(1).item() < teacher_forcing_ratio  # 随机使用教师强制
            input = target[t] if teacher_force else output.argmax(2).squeeze(0)  # 根据是否使用教师强制选择输入词
        return outputs


input_size, output_size, hidden_size = 5, 6, 4  # 输入词表大小，输出词表大小，隐藏层维度
encoder = Encoder(input_size, hidden_size)
decoder = Decoder(output_size, hidden_size)
model = Seq2Seq(encoder, decoder)

source = torch.tensor([[0], [1]])
target = torch.tensor([[0], [1], [1]])
output = model(source, target)

RuntimeError: Sizes of tensors must match except in dimension 2. Expected size 1 but got size 2 for tensor number 1 in the list.