In [1]:
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Layer, LSTM, Dense, Embedding, Dropout
from tensorflow.keras.initializers import GlorotUniform

In [None]:



class Encoder(Model):
    def __init__(self, vocab_size: int, embedding_size: int, units: int):
        """ The encoder model for the src sentences.
            It contains an embedding part and an LSTM part.
        
        Args:
            vocab_size: The src vocabulary size
            embedding_size: The embedding size for the embedding layer
            units: Number of hidden units in the RNN (LSTM) layer
        """
        super().__init__()
        # Layer names use underscores: TensorFlow rejects scope names that contain
        # spaces once the layer is traced in graph mode (tf.function / fit).
        self.emb = Embedding(vocab_size, embedding_size, embeddings_initializer=GlorotUniform(seed=0), name='encoder_embedding')
        # The decoder consumes the final state, and the per-timestep outputs are part of
        # the return value, so both return_state and return_sequences are enabled.
        # NOTE(review): the Decoder seeds `kernel_initializer` rather than
        # `recurrent_initializer`; confirm which of the two is intended here.
        self.lstm = LSTM(units, recurrent_initializer=GlorotUniform(seed=0), return_state=True, return_sequences=True, name='encoder_lstm')

    def call(self, src_ids, src_mask):
        """ Encoder forward
        Args:
            src_ids: Tensor, (batch_size x max_len), the token ids of input sentences in a batch
            src_mask: Tensor, (batch_size x max_len), the mask of the src input. True value in the mask means this timestep is valid, otherwise this timestep is ignored
        Returns:
            enc_outputs: Tensor, (batch_size x max_len x units), the output of the LSTM for all timesteps
            final_state: List of two Tensors [state_h, state_c], each (batch_size x units), the hidden and
                cell state of the final valid timestep. It can be passed directly as `initial_state` to
                another LSTM layer or as `states` to an LSTM cell.
        """
        # Step 1. Retrieve embedding
        embedded = self.emb(src_ids)
        # Step 2. LSTM. With return_state=True an LSTM yields three values
        # (outputs, hidden state, cell state); unpacking only two raises ValueError.
        enc_outputs, state_h, state_c = self.lstm(embedded, mask=src_mask)
        # Keep both states together so the LSTM's memory is not dropped at the hand-off.
        final_state = [state_h, state_c]
        return enc_outputs, final_state

In [None]:
class Decoder(Model):
    def __init__(self, vocab_size: int, embedding_size: int, units: int, dropout_rate: float):
        """ The decoder model for the tgt sentences.
            It contains an embedding part, an LSTM part, a dropout part, and a classifier part.
            
        Args:
            vocab_size: The tgt vocabulary size
            embedding_size: The embedding size for the embedding layer
            units: Number of hidden units in the RNN (LSTM) layer
            dropout_rate: The classifier has a (units x vocab_size) weight. This is a large weight matrix. We apply a dropout layer to avoid overfitting.
        """
        super().__init__()
        # Layer names use underscores: TensorFlow rejects scope names that contain
        # spaces once the layer is traced in graph mode (tf.function / fit).
        self.emb = Embedding(vocab_size, embedding_size, embeddings_initializer=GlorotUniform(seed=0), name='decoder_embedding')
        # return_sequences gives one output per target timestep for the classifier;
        # return_state is kept so the layer's output structure is (outputs, h, c).
        self.lstm = LSTM(units, kernel_initializer=GlorotUniform(seed=0), return_state=True, return_sequences=True, name='decoder_lstm')
        # The classifier emits raw logits: no softmax (the sequence-to-sequence loss
        # applies it) and no ReLU, which would clip every negative logit to zero.
        self.den = Dense(vocab_size, kernel_initializer=GlorotUniform(seed=0), name='decoder_dense')
        self.dropout = Dropout(rate=dropout_rate, seed=0)

    def call(self, tgt_ids, initial_state, tgt_mask):
        """ Decoder forward.
            It is called by decoder(tgt_ids=..., initial_state=..., tgt_mask=...)

        Args:
            tgt_ids: Tensor, (batch_size x max_len), the token ids of input sentences in a batch
            initial_state: List of two Tensors [state_h, state_c], each (batch_size x units), the LSTM state of the final valid timestep from the encoder
            tgt_mask: Tensor, (batch_size x max_len), the mask of the tgt input. True value in the mask means this timestep is valid, otherwise this timestep is ignored
        Return:
            dec_outputs: Tensor, (batch_size x max_len x vocab_size), the classifier logits for all timesteps
        """
        # Step 1. Retrieve embedding
        embedded = self.emb(tgt_ids)
        # Step 2. LSTM. With return_state=True it yields (outputs, h, c); the final
        # states are not needed during teacher-forced training, so they are discarded.
        lstm_outputs, _, _ = self.lstm(embedded, mask=tgt_mask, initial_state=initial_state)
        # Step 3. Apply dropout to the LSTM output
        lstm_outputs = self.dropout(lstm_outputs)
        # Step 4. Classifier
        dec_outputs = self.den(lstm_outputs)
        return dec_outputs
    
    def predict(self, tgt_ids, initial_state):
        """ Decoder prediction.
            This is a step in recursive prediction. We use the previous prediction and state to predict current token.
            Note that we only need to use the LSTM cell instead of the LSTM layer because we only need to calculate one timestep.
            Note also that this method shadows `Model.predict` of the Keras base class with a different signature.
            
        Args:
            tgt_ids: Tensor, (batch_size, ) -> (1, ), the token id of the current timestep in the current sentence.
            initial_state: List of two Tensors [state_h, state_c], each (batch_size x units) -> (1 x units), the state of the final valid timestep from the encoder or the state returned by the previous prediction step.
        Return:
            dec_outputs: Tensor, (batch_size x vocab_size) -> (1 x vocab_size), the classifier logits for this timestep.
            state: List of two Tensors [state_h, state_c], each (batch_size x units) -> (1 x units), the state of this timestep.
        """
        # Reuse the trained layer's cell so the single step shares its weights.
        lstm_cell = self.lstm.cell
        # Step 1. Retrieve embedding
        embedded = self.emb(tgt_ids)
        # Step 2. LSTM cell: returns the step output and the new [h, c] state.
        step_output, state = lstm_cell(embedded, states=initial_state, training=False)
        # Step 3. Classifier (no dropout at inference time)
        dec_outputs = self.den(step_output)
        return dec_outputs, state