In [27]:
import numpy as np
import torch

from torch import nn

from src.algo.language.lm import OneHotEncoder, GRUEncoder, init_rnn_params

class GRUDecoder(nn.Module):
    """
    Class for a language decoder using a Gated Recurrent Unit network
    """
    def __init__(self, context_dim, embed_dim, word_encoder, max_len, 
                 n_layers=1, embed_layer= None, device="cpu"):
        """
        Inputs:
            :param context_dim (int): Dimension of the context vectors
            :param word_encoder (OneHotEncoder): Word encoder, associating 
                tokens with one-hot encodings
            :param n_layers (int): number of layers in the GRU (default: 1)
            :param device (str): CUDA device
        """
        super(GRUDecoder, self).__init__()
        self.device = device
        self.max_len = max_len
        # Dimension of hidden states
        self.hidden_dim = context_dim
        # Word encoder
        self.word_encoder = word_encoder
        # Number of recurrent layers
        self.n_layers = n_layers
        # Embedding layer
        if embed_layer is not None:
            self.embed_layer = embed_layer
        else:
            self.embed_layer = nn.Embedding(self.word_encoder.enc_dim, embed_dim)
        # Model
        self.gru = nn.GRU(
            embed_dim, 
            self.hidden_dim, 
            self.n_layers)
        init_rnn_params(self.gru)
        # Output layer
        self.out = nn.Sequential(
            nn.Linear(self.hidden_dim, self.word_encoder.enc_dim),
            nn.LogSoftmax(dim=2)
        )

    def forward_step(self, last_token, last_hidden):
        """
        Generate prediction from GRU network.
        Inputs:
            :param last_token (torch.Tensor): Token at last time step, 
                dim=(1, 1, token_dim).
            :param last_hidden (torch.Tensor): Hidden state of the GRU at last
                time step, dim=(1, 1, hidden_dim).
        Outputs:
            :param output (torch.Tensor): Log-probabilities outputed by the 
                model, dim=(1, 1, token_dim).
            :param hidden (torch.Tensor): New hidden state of the GRU network,
                dim=(1, 1, hidden_dim).
        """
        output, hidden = self.gru(last_token, last_hidden)
        output = self.out(output)
        return output, hidden
    
    def forward(self, context_batch, target_encs=None):
        """
        Transforms context vectors to sentences
        Inputs:
            :param context_batch (torch.Tensor): Batch of context vectors,
                dim=(batch_size, context_dim).
            :param target_encs (torch.Tensor): Batch of target sentences used
                for teacher forcing, encoded as onehots and padded with -1, 
                dim=(batch_size, max_sent_len, enc_dim). If None then no 
                teacher forcing. Default: None.
        Outputs:
            :param decoder_outputs (list): Batch of tensors containing
                log-probabilities generated by the GRU network.
            :param sentences (list): Sentences generated with greedy 
                sampling. Empty if target_encs is not None (teacher forcing,
                so we only care about model predictions).
        """
        teacher_forcing = target_encs is not None
        batch_size = context_batch.size(0)
        max_sent_len = target_encs.shape[1] if teacher_forcing \
            else self.max_len

        if teacher_forcing:
            # Embed
            target_ids = target_encs.argmax(-1)
            target_embeds = self.embed_layer(target_ids)

        hidden = context_batch.unsqueeze(0)
        # Init last token to the SOS token, embedded
        last_tokens = self.embed_layer(
            torch.zeros((1, batch_size), dtype=torch.int).to(self.device))

        tokens = []
        decoder_outputs = []
        # sentences = [[] for b_i in range(batch_size)]
        sentences = torch.Tensor()
        sent_finished = np.array([False] * batch_size)
        print(sent_finished)
        for t_i in range(max_sent_len):
            # RNN pass
            outputs, hidden = self.forward_step(last_tokens, hidden)
            decoder_outputs.append(outputs)

            # Sample next tokens
            if teacher_forcing:
                last_tokens = target_embeds[:, t_i].unsqueeze(0)
            else:
                _, topi = outputs.topk(1)
                # topi = topi.squeeze(-1)
                last_tokens = self.embed_layer(topi.squeeze(-1))
                #print(last_tokens, last_tokens.shape)
                print(topi, topi.shape)
                sent_finished = sent_finished | topi.squeeze() == 1
                print(sent_finished)

                input()

                # for b_i in range(batch_size):
                #     if topi[b_i] == self.word_encoder.EOS_ID:
                #         sent_finished[b_i] = True
                #     if not sent_finished[b_i]:
                #         sentences[b_i].append(
                #             self.word_encoder.index2token(topi[b_i]))
                
                if all(sent_finished):
                    break
                    
        decoder_outputs = torch.cat(decoder_outputs, axis=0).transpose(0, 1)

        return decoder_outputs, sentences

In [30]:
vocab = ["Prey", "Center", "North", "South", "East", "West"]

word_encoder = OneHotEncoder(vocab)

decoder = GRUDecoder(8, 4, word_encoder, word_encoder.max_len)

In [29]:
context = torch.rand((10, 8))

decoder(context)

[False False False False False False False False False False]
tensor([[[-1.1207,  0.4461,  0.1331, -0.3018],
         [-0.5092,  0.4545,  0.4991, -0.6824],
         [-1.1207,  0.4461,  0.1331, -0.3018],
         [-0.5092,  0.4545,  0.4991, -0.6824],
         [-1.1207,  0.4461,  0.1331, -0.3018],
         [-0.5092,  0.4545,  0.4991, -0.6824],
         [-1.2223,  0.8321, -0.3479,  0.2439],
         [-1.1207,  0.4461,  0.1331, -0.3018],
         [-1.1207,  0.4461,  0.1331, -0.3018],
         [-0.5092,  0.4545,  0.4991, -0.6824]]], grad_fn=<EmbeddingBackward0>) torch.Size([1, 10, 4])
tensor([[[6],
         [3],
         [6],
         [3],
         [6],
         [3],
         [7],
         [6],
         [6],
         [3]]]) torch.Size([1, 10, 1])
tensor([[[-1.2223,  0.8321, -0.3479,  0.2439],
         [-1.2223,  0.8321, -0.3479,  0.2439],
         [-1.2223,  0.8321, -0.3479,  0.2439],
         [-1.2223,  0.8321, -0.3479,  0.2439],
         [-1.2223,  0.8321, -0.3479,  0.2439],
         [-1.

(tensor([[[-2.0143, -2.4315, -2.3772,  ..., -2.2706, -1.7719, -1.9388],
          [-1.9641, -2.5612, -2.4553,  ..., -2.2944, -1.8945, -1.6837],
          [-1.9899, -2.6548, -2.5493,  ..., -2.1778, -1.8877, -1.7256],
          ...,
          [-2.0450, -2.7311, -2.5879,  ..., -2.0423, -1.8256, -1.8274],
          [-1.9978, -2.6879, -2.5317,  ..., -2.1645, -1.8806, -1.6965],
          [-2.0175, -2.7131, -2.5777,  ..., -2.0928, -1.8540, -1.7688]],
 
         [[-1.8724, -2.5057, -2.5699,  ..., -2.4312, -1.8118, -2.0060],
          [-1.9098, -2.4040, -2.4410,  ..., -2.3952, -1.9116, -1.6924],
          [-1.9578, -2.5444, -2.5456,  ..., -2.2343, -1.9084, -1.7252],
          ...,
          [-2.0398, -2.7287, -2.5890,  ..., -2.0496, -1.8312, -1.8179],
          [-2.0450, -2.7310, -2.5879,  ..., -2.0424, -1.8257, -1.8274],
          [-1.9978, -2.6879, -2.5317,  ..., -2.1645, -1.8806, -1.6965]],
 
         [[-2.0023, -2.5269, -2.3458,  ..., -2.3258, -1.6529, -2.0698],
          [-1.9662, -2.5731,

In [23]:
np.concatenate((np.array(()), np.ones(10)))

array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])