In [None]:
class DecoderWithRNN(nn.Module):
    """LSTM caption decoder without attention.

    The flattened encoder output initialises the LSTM hidden and cell states;
    at every time step the embedding of the current caption token is fed to
    the LSTM cell and the new hidden state is projected onto the vocabulary.
    """

    def __init__(self, cfg, encoder_dim=14*14*2048):
        """Build the decoder layers.

        :param cfg: mapping with the keys 'vocab_size', 'embed_dim' and 'decoder_dim'
        :param encoder_dim: number of features of one flattened encoder output
        """
        super(DecoderWithRNN, self).__init__()
        self.embedding = nn.Embedding(cfg['vocab_size'], cfg['embed_dim'])
        self.decode_step = nn.LSTMCell(cfg['embed_dim'], cfg['decoder_dim'])
        self.init_h = nn.Linear(encoder_dim, cfg['decoder_dim'])  # initial hidden state h0
        self.init_c = nn.Linear(encoder_dim, cfg['decoder_dim'])  # initial cell state c0
        self.fc = nn.Linear(cfg['decoder_dim'], cfg['vocab_size'])  # hidden state -> word scores
        # Further layers (e.g. dropout) can be added here.

    def forward(self, encoder_out, encoded_captions, caption_lengths):
        """Decode a batch of captions with teacher forcing.

        :param encoder_out: encoded images, shape (batch_size, ...); everything
            after the batch dimension is flattened and must contain encoder_dim values
        :param encoded_captions: token ids, shape (batch_size, max_caption_length)
        :param caption_lengths: caption lengths, shape (batch_size, 1) or (batch_size,)
        :return: predictions - word scores, shape (batch_size, max(decode_lengths), vocab_size);
            positions past a caption's decode length stay zero
        :return: encoded_captions - the captions, sorted by decreasing length
        :return: decode_lengths - list with the number of decoding steps per caption
        :return: sort_ind - indices that were used to sort the batch
        """
        batch_size = encoder_out.size(0)
        vocab_size = self.fc.out_features

        # init_h / init_c expect one flat feature vector per image.
        encoder_out = encoder_out.reshape(batch_size, -1)

        # Sort by decreasing caption length so that, at step t, the captions that are
        # still being decoded are exactly the first batch_size_t rows of the batch.
        caption_lengths, sort_ind = caption_lengths.reshape(-1).sort(dim=0, descending=True)
        encoder_out = encoder_out[sort_ind]
        encoded_captions = encoded_captions[sort_ind]

        # The token at step t is the input and the token at t + 1 is the target,
        # so the last token of a caption is never fed to the decoder.
        decode_lengths = (caption_lengths - 1).tolist()
        max_decode_length = max(decode_lengths)

        # Allocated from encoder_out so the buffer lives on the same device.
        predictions = encoder_out.new_zeros(batch_size, max_decode_length, vocab_size)

        h, c = self.init_h(encoder_out), self.init_c(encoder_out)  # (batch_size, decoder_dim)
        for t in range(max_decode_length):
            batch_size_t = sum(l > t for l in decode_lengths)
            embeddings = self.embedding(encoded_captions[:batch_size_t, t])
            h, c = self.decode_step(embeddings, (h[:batch_size_t], c[:batch_size_t]))
            preds = self.fc(h)  # (batch_size_t, vocab_size)
            predictions[:batch_size_t, t, :] = preds
        return predictions, encoded_captions, decode_lengths, sort_ind

class Attention(nn.Module):
    """Soft attention that scores every encoder pixel against the decoder state."""

    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        """Create the projection layers.

        :param encoder_dim: feature size of the encoded image
        :param decoder_dim: size of the decoder hidden state
        :param attention_dim: size of the shared attention space
        """
        super().__init__()
        # Project the encoder features into the attention space.
        self.encoder_att = nn.Linear(encoder_dim, attention_dim)
        # Project the decoder hidden state into the attention space.
        self.decoder_att = nn.Linear(decoder_dim, attention_dim)
        # Collapse the attention space to one score per pixel.
        self.full_att = nn.Linear(attention_dim, 1)
        self.relu = nn.ReLU()
        # Normalise the scores over the pixel dimension.
        self.softmax = nn.Softmax(dim=1)

    def forward(self, encoder_out, decoder_hidden):
        """Compute the attention-weighted encoding and the attention weights.

        :param encoder_out: encoded images, shape (batch_size, num_pixels, encoder_dim)
        :param decoder_hidden: decoder hidden state, shape (batch_size, decoder_dim)
        :return: attention-weighted encoding, shape (batch_size, encoder_dim)
        :return: alpha - attention weights, shape (batch_size, num_pixels)
        """
        pixel_proj = self.encoder_att(encoder_out)  # (batch_size, num_pixels, attention_dim)
        hidden_proj = self.decoder_att(decoder_hidden)  # (batch_size, attention_dim)

        # Broadcast the hidden-state projection over all pixels before scoring.
        combined = self.relu(pixel_proj + hidden_proj.unsqueeze(1))
        scores = self.full_att(combined).squeeze(2)  # (batch_size, num_pixels)

        alpha = self.softmax(scores)  # (batch_size, num_pixels)
        weighted_pixels = encoder_out * alpha.unsqueeze(2)
        context = weighted_pixels.sum(dim=1)  # (batch_size, encoder_dim)
        return context, alpha

class DecoderWithAttention(nn.Module):
    """LSTM caption decoder with gated soft attention over the encoder pixels.

    At every time step the attention module produces a weighted encoding of
    the image, a sigmoid gate computed from the hidden state scales it, and the
    result is concatenated with the current token embedding as LSTM input.
    """

    def __init__(self, cfg, encoder_dim=2048):
        """Build the decoder layers.

        :param cfg: mapping with the keys 'vocab_size', 'embed_dim',
            'decoder_dim' and 'attention_dim'
        :param encoder_dim: feature size of one encoder pixel
        """
        super(DecoderWithAttention, self).__init__()
        self.attention = Attention(encoder_dim, cfg['decoder_dim'], cfg['attention_dim'])
        self.embedding = nn.Embedding(cfg['vocab_size'], cfg['embed_dim'])
        self.decode_step = nn.LSTMCell(cfg['embed_dim'] + encoder_dim, cfg['decoder_dim'])
        self.init_h = nn.Linear(encoder_dim, cfg['decoder_dim'])  # initial hidden state h0
        self.init_c = nn.Linear(encoder_dim, cfg['decoder_dim'])  # initial cell state c0
        self.f_beta = nn.Linear(cfg['decoder_dim'], encoder_dim)  # feeds the sigmoid gate
        self.sigmoid = nn.Sigmoid()
        self.fc = nn.Linear(cfg['decoder_dim'], cfg['vocab_size'])  # hidden state -> word scores
        # Further layers (e.g. dropout) can be added here.

    def forward(self, encoder_out, encoded_captions, caption_lengths):
        """Decode a batch of captions with teacher forcing and attention.

        :param encoder_out: encoded images, shape (batch_size, ..., encoder_dim);
            all dimensions between the first and the last are flattened into pixels
        :param encoded_captions: token ids, shape (batch_size, max_caption_length)
        :param caption_lengths: caption lengths, shape (batch_size, 1) or (batch_size,)
        :return: predictions - word scores, shape (batch_size, max(decode_lengths), vocab_size);
            positions past a caption's decode length stay zero
        :return: encoded_captions - the captions, sorted by decreasing length
        :return: decode_lengths - list with the number of decoding steps per caption
        :return: alphas - attention weights, shape (batch_size, max(decode_lengths), num_pixels)
        :return: sort_ind - indices that were used to sort the batch
        """
        batch_size = encoder_out.size(0)
        vocab_size = self.fc.out_features

        # Flatten the spatial dimensions: (batch_size, num_pixels, encoder_dim).
        encoder_out = encoder_out.reshape(batch_size, -1, encoder_out.size(-1))
        num_pixels = encoder_out.size(1)

        # Sort by decreasing caption length so that, at step t, the captions that are
        # still being decoded are exactly the first batch_size_t rows of the batch.
        caption_lengths, sort_ind = caption_lengths.reshape(-1).sort(dim=0, descending=True)
        encoder_out = encoder_out[sort_ind]
        encoded_captions = encoded_captions[sort_ind]

        # The token at step t is the input and the token at t + 1 is the target,
        # so the last token of a caption is never fed to the decoder.
        decode_lengths = (caption_lengths - 1).tolist()
        max_decode_length = max(decode_lengths)

        # Allocated from encoder_out so the buffers live on the same device.
        predictions = encoder_out.new_zeros(batch_size, max_decode_length, vocab_size)
        alphas = encoder_out.new_zeros(batch_size, max_decode_length, num_pixels)

        # The LSTM state is initialised from the average over all pixels.
        mean_encoder_out = encoder_out.mean(dim=1)  # (batch_size, encoder_dim)
        h, c = self.init_h(mean_encoder_out), self.init_c(mean_encoder_out)  # (batch_size, decoder_dim)
        for t in range(max_decode_length):
            batch_size_t = sum(l > t for l in decode_lengths)
            attention_weighted_encoding, alpha = self.attention(encoder_out[:batch_size_t],
                                                                h[:batch_size_t])
            gate = self.sigmoid(self.f_beta(h[:batch_size_t]))  # (batch_size_t, encoder_dim)
            attention_weighted_encoding = gate * attention_weighted_encoding
            embeddings = self.embedding(encoded_captions[:batch_size_t, t])
            h, c = self.decode_step(torch.cat([embeddings, attention_weighted_encoding], dim=1),
                                    (h[:batch_size_t], c[:batch_size_t]))
            preds = self.fc(h)  # (batch_size_t, vocab_size)
            predictions[:batch_size_t, t, :] = preds
            alphas[:batch_size_t, t, :] = alpha
        return predictions, encoded_captions, decode_lengths, alphas, sort_ind

In [None]:
def one_step(self, prev_embeddings, encoder_out, h, c):
    """Advance the attention decoder by exactly one time step.

    The attention output is scaled by a sigmoid gate derived from the previous
    hidden state, concatenated with the previous word embeddings and passed
    through the LSTM cell; the new hidden state is then mapped to word scores.

    :param prev_embeddings: embeddings of previous words, shape: (batch_size, embed_dim)
    :param encoder_out: encoded images, shape: (batch_size, num_pixels, encoder_dim)
    :param h: previous hidden state, shape: (batch_size, decoder_dim)
    :param c: previous cell state, shape: (batch_size, decoder_dim)
    :return: preds - prediction scores for next word, shape: (batch_size, vocab_size)
    :return: alpha - attention weights, shape: (batch_size, num_pixels)
    :return: h - new hidden state, shape: (batch_size, decoder_dim)
    :return: c - new cell state, shape: (batch_size, decoder_dim)
    """
    context, alpha = self.attention(encoder_out, h)

    # Gate the attended encoding using the previous hidden state.
    gated_context = self.sigmoid(self.f_beta(h)) * context

    lstm_input = torch.cat([prev_embeddings, gated_context], dim=1)
    h, c = self.decode_step(lstm_input, (h, c))

    return self.fc(h), alpha, h, c