# A simplified version of sequence-to-sequence with attention mechanism

*Yanagi*

*Updated May 21th, 2022*

> *Reference:*
> + **original seq2seq+attention paper**: *Neural machine translation by jointly learning to align and translate. Bahdanau et al., **ICLR 2015**.*
>
> + **variant with 2 different attention mechanism**: *Effecttive approaches to attention-based neural machine translation. Luong et al. **ENMLP 2015**.*


In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.__version__

'1.9.0+cu111'

In [2]:
class Seq2SqeEncoder(nn.Module):
    """An encoder implementation based on nn.LSTM."""
    def __init__(self, embed_size, hidden_size, src_vocab_size):
        super(Seq2SqeEncoder, self).__init__()

        self.lstm_layer = nn.LSTM(input_size=embed_size, hidden_size=hidden_size, batch_first=True)
        self.embedding_table = nn.Embedding(src_vocab_size, embed_size)

    def forward(self, input_ids):
        
        input_seq = self.embedding_table(input_ids)
        output_states, (final_h, final_c) = self.lstm_layer(input_seq)

        return output_states, final_h

In [3]:
class Seq2SeqAttentionMechanism(nn.Module):
    """An dot-product attention mechanism implementation."""
    def __init__(self):
        super(Seq2SeqAttentionMechanism, self).__init__()

    def forward(self, decoder_state_t, encoder_states):
        """
        @param decoder_state_t (Tensor): hidden state of decoder at timestep *t*
        @param encoder_states (Tensor): all hidden states from encoder
        """
        bs, src_len, hidden_size = encoder_states.shape

        # [bs, hidden_size] -> [bs, 1, hidden_size]
        decoder_state_t = decoder_state_t.unsqueeze(1)
        # [bs, 1, hidden_size] -> [bs, src_len, hidden_size]
        decoder_state_t = torch.tile(decoder_state_t, dims=(1, src_len, 1))

        score = torch.sum(decoder_state_t * encoder_states, dim=-1) # [bs, src_len]

        attn_prob = F.softmax(score, dim=-1) # [bs, src_len]

        # [bs, src_len, 1] * [bs, 1, hidden_size], broadcasting
        context = torch.sum(attn_prob.unsqueeze(-1) * encoder_states, 1)

        return attn_prob, context

In [4]:
class Seq2SqeDecoder(nn.Module):
    """An decoder implementation based on nn.LSTMCell."""
    def __init__(self, embed_size, hidden_size, num_classes, tgt_vocab_size, start_id, end_id):
        super(Seq2SqeDecoder, self).__init__()

        self.lstm_cell = nn.LSTMCell(embed_size, hidden_size)
        self.dense = nn.Linear(hidden_size * 2, num_classes)
        self.attention_mechanism = Seq2SeqAttentionMechanism()
        self.num_classes = num_classes
        self.embedding_table = nn.Embedding(tgt_vocab_size, embed_size)
        self.stard_id = start_id
        self.end_id = end_id

    def forward(self, shifted_target_ids, encoder_states):
        # invoked when training

        # teacher forcing
        shifted_target = self.embedding_table(shifted_target_ids)

        bs, tgt_len, embed_size = shifted_target.shape
        bs, src_len, hidden_size = encoder_states.shape

        logits = torch.zeros(bs, tgt_len, self.num_classes)
        probs = torch.zeros(bs, tgt_len, src_len)

        for t in range(tgt_len):
            decoder_input_t = shifted_target[:, t, :] # [bs, embed_size]
            if t == 0:
                h_t, c_t = self.lstm_cell(decoder_input_t) # if h_0 and c_0 is not provided, default to zero
            else:
                h_t, c_t = self.lstm_cell(decoder_input_t, (h_t, c_t))

            attn_prob, context = self.attention_mechanism(h_t, encoder_states)

            # context: [bs, hidden_size], h_t: [1, hidden_size]
            decoder_output = torch.cat((context, h_t), dim=-1)
            logits[:, t, :] = self.dense(decoder_output)
            probs[:, t, :] = attn_prob

        return probs, logits

    def inference(self, encoder_states):
        target_id = self.start_id
        h_t = None
        result = []

        while True:
            decoder_input_t = self.embedding_table(target_id)
            if h_t is None:
                h_t, c_t = self.lstm_cell(decoder_input_t)
            else:
                h_t, c_t = self.lstm_cell(decoder_input_t, (h_t, c_t))

            attn_prob, context = self.attention_mechanism(h_t, encoder_states)

            decoder_output = torch.cat((context, h_t), dim=-1)
            logits = self.dense(decoder_output)

            target_id = torch.argmax(logits, dim=-1)
            result.append(target_id)

            if torch.any(target_id == self.end_id):
                print("Stop decoding!")
                break

        predicted_ids = torch.stack(result, dim=0)
        return predicted_ids


In [5]:
class EncoderDecoder(nn.Module):
    """An encoder-decoder model implementation"""
    def __init__(self, embed_size, hidden_size, num_classes, src_vocab_size, tgt_vocab_size,
                start_id, end_id):
        super(EncoderDecoder, self).__init__()
        self.encoder = Seq2SqeEncoder(embed_size, hidden_size, src_vocab_size)
        self.decoder = Seq2SqeDecoder(embed_size, hidden_size, num_classes, tgt_vocab_size, start_id, end_id)

    def forward(self, input_sequencer_ids, shifted_target_ids):
        # when training

        encoder_states, final_h = self.encoder(input_sequencer_ids)
        probs, logits = self.decoder(shifted_target_ids, encoder_states)
        return probs, logits

    def infer(self):
        # TODO
        pass

In [6]:
if __name__ == "__main__":
    src_len = 3
    tgt_len = 4
    embed_size = 8
    hidden_size =16
    num_classes = 10
    bs = 2
    start_id = end_id = 0
    src_vocab_size = 100
    tgt_vocab_size = 100

    input_sequence_ids = torch.randint(src_vocab_size, size=(bs, src_len)).to(torch.int32)

    target_ids = torch.randint(tgt_vocab_size, size=(bs, tgt_len))
    target_ids = torch.cat((target_ids, end_id * torch.ones(bs, 1), ), dim=1).to(torch.int32)

    shifted_target_ids = torch.cat((start_id * torch.ones(bs, 1), target_ids[:, 1:]), dim=1).to(torch.int32)

    model = EncoderDecoder(embed_size, hidden_size, num_classes, src_vocab_size, tgt_vocab_size, start_id, end_id)
    probs, logits = model(input_sequence_ids, shifted_target_ids)
    print(input_sequence_ids.shape)
    print(shifted_target_ids.shape)
    print(probs.shape)
    print(logits.shape)


torch.Size([2, 3])
torch.Size([2, 5])
torch.Size([2, 5, 3])
torch.Size([2, 5, 10])


In [8]:
input_sequence_ids

tensor([[74, 68, 72],
        [98, 36, 67]], dtype=torch.int32)

In [9]:
target_ids

tensor([[90, 79, 42, 15,  0],
        [67,  3,  0, 11,  0]], dtype=torch.int32)