In [75]:
import numpy as np
import torch

from torch import nn

from src.algo.language.lm import OneHotEncoder, GRUEncoder, init_rnn_params

class GRUDecoder(nn.Module):
    """
    Class for a language decoder using a Gated Recurrent Unit network
    """
    def __init__(self, context_dim, embed_dim, word_encoder, max_len, 
                 n_layers=1, embed_layer= None, device="cpu"):
        """
        Inputs:
            :param context_dim (int): Dimension of the context vectors
            :param word_encoder (OneHotEncoder): Word encoder, associating 
                tokens with one-hot encodings
            :param n_layers (int): number of layers in the GRU (default: 1)
            :param device (str): CUDA device
        """
        super(GRUDecoder, self).__init__()
        self.device = device
        self.max_len = max_len
        # Dimension of hidden states
        self.hidden_dim = context_dim
        # Word encoder
        self.word_encoder = word_encoder
        # Number of recurrent layers
        self.n_layers = n_layers
        # Embedding layer
        if embed_layer is not None:
            self.embed_layer = embed_layer
        else:
            self.embed_layer = nn.Embedding(self.word_encoder.enc_dim, embed_dim)
        # Model
        self.gru = nn.GRU(
            embed_dim, 
            self.hidden_dim, 
            self.n_layers)
        init_rnn_params(self.gru)
        # Output layer
        self.out = nn.Sequential(
            nn.Linear(self.hidden_dim, self.word_encoder.enc_dim),
            nn.LogSoftmax(dim=2)
        )

    def forward_step(self, last_token, last_hidden):
        """
        Generate prediction from GRU network.
        Inputs:
            :param last_token (torch.Tensor): Token at last time step, 
                dim=(1, 1, token_dim).
            :param last_hidden (torch.Tensor): Hidden state of the GRU at last
                time step, dim=(1, 1, hidden_dim).
        Outputs:
            :param output (torch.Tensor): Log-probabilities outputed by the 
                model, dim=(1, 1, token_dim).
            :param hidden (torch.Tensor): New hidden state of the GRU network,
                dim=(1, 1, hidden_dim).
        """
        output, hidden = self.gru(last_token, last_hidden)
        output = self.out(output)
        return output, hidden
    
    def forward(self, context_batch, target_encs=None):
        """
        Transforms context vectors to sentences
        Inputs:
            :param context_batch (torch.Tensor): Batch of context vectors,
                dim=(batch_size, context_dim).
            :param target_encs (torch.Tensor): Batch of target sentences used
                for teacher forcing, encoded as onehots and padded with -1, 
                dim=(batch_size, max_sent_len, enc_dim). If None then no 
                teacher forcing. Default: None.
        Outputs:
            :param decoder_outputs (list): Batch of tensors containing
                log-probabilities generated by the GRU network.
            :param sentences (list): Sentences generated with greedy 
                sampling. Empty if target_encs is not None (teacher forcing,
                so we only care about model predictions).
        """
        teacher_forcing = target_encs is not None
        batch_size = context_batch.size(0)
        max_sent_len = target_encs.shape[1] if teacher_forcing \
            else self.max_len

        if teacher_forcing:
            # Embed
            target_ids = target_encs.argmax(-1)
            target_embeds = self.embed_layer(target_ids)

        hidden = context_batch.unsqueeze(0)
        # Init last token to the SOS token, embedded
        last_tokens = self.embed_layer(
            torch.zeros((1, batch_size), dtype=torch.int).to(self.device))

        tokens = []
        decoder_outputs = []
        # sentences = [[] for b_i in range(batch_size)]
        sentences = None
        sent_finished = np.array([False] * batch_size).reshape((1, batch_size, 1))
        for t_i in range(max_sent_len):
            # RNN pass
            outputs, hidden = self.forward_step(last_tokens, hidden)
            decoder_outputs.append(outputs)

            # Sample next tokens
            if teacher_forcing:
                last_tokens = target_embeds[:, t_i].unsqueeze(0)
            else:
                _, topi = outputs.topk(1)
                
                # Set next decoder input
                last_tokens = self.embed_layer(topi.squeeze(-1))

                # Add next token, if sentence is not already finished (then pad with -1)
                topi = topi.cpu().numpy()
                next_token_ids = sent_finished * -1 + (1 - sent_finished) * topi
                if sentences is None:
                    sentences = next_token_ids
                else:
                    sentences = np.concatenate((sentences, next_token_ids), -1)

                # Check for finished sentences
                sent_finished = sent_finished | (topi == 1)
                
                if sent_finished.all():
                    break
                    
        decoder_outputs = torch.cat(decoder_outputs, axis=0).transpose(0, 1)

        return decoder_outputs, sentences

In [80]:
vocab = ["Prey", "Center", "North", "South", "East", "West"]

word_encoder = OneHotEncoder(vocab)

decoder = GRUDecoder(8, 4, word_encoder, word_encoder.max_len)

In [81]:
context = torch.rand((10, 8))

decoder(context)

0
[[[5]
  [5]
  [5]
  [5]
  [5]
  [5]
  [5]
  [5]
  [5]
  [5]]] (1, 10, 1)
[[[5]
  [5]
  [5]
  [5]
  [5]
  [5]
  [5]
  [5]
  [5]
  [5]]]
[[[False]
  [False]
  [False]
  [False]
  [False]
  [False]
  [False]
  [False]
  [False]
  [False]]]
1
[[[1]
  [4]
  [1]
  [1]
  [1]
  [1]
  [1]
  [4]
  [4]
  [1]]] (1, 10, 1)
[[[5 1]
  [5 4]
  [5 1]
  [5 1]
  [5 1]
  [5 1]
  [5 1]
  [5 4]
  [5 4]
  [5 1]]]
[[[ True]
  [False]
  [ True]
  [ True]
  [ True]
  [ True]
  [ True]
  [False]
  [False]
  [ True]]]
2
[[[5]
  [5]
  [5]
  [5]
  [5]
  [5]
  [5]
  [5]
  [5]
  [5]]] (1, 10, 1)
[[[ 5  1 -1]
  [ 5  4  5]
  [ 5  1 -1]
  [ 5  1 -1]
  [ 5  1 -1]
  [ 5  1 -1]
  [ 5  1 -1]
  [ 5  4  5]
  [ 5  4  5]
  [ 5  1 -1]]]
[[[ True]
  [False]
  [ True]
  [ True]
  [ True]
  [ True]
  [ True]
  [False]
  [False]
  [ True]]]
3
[[[4]
  [4]
  [4]
  [4]
  [4]
  [4]
  [4]
  [4]
  [4]
  [4]]] (1, 10, 1)
[[[ 5  1 -1 -1]
  [ 5  4  5  4]
  [ 5  1 -1 -1]
  [ 5  1 -1 -1]
  [ 5  1 -1 -1]
  [ 5  1 -1 -1]
  [ 5  1 -1 -1]
  [ 5 

(tensor([[[-2.7296, -1.7853, -2.6413,  ..., -1.5117, -2.0266, -2.2326],
          [-2.3762, -1.7887, -2.2434,  ..., -1.8768, -2.0259, -2.1445],
          [-2.3078, -1.8734, -2.1645,  ..., -1.7791, -2.1838, -2.0539],
          ...,
          [-2.1854, -1.9318, -2.0495,  ..., -1.9464, -2.1663, -2.0893],
          [-2.3611, -1.9145, -2.2350,  ..., -1.5959, -2.3149, -2.1266],
          [-2.1854, -1.9318, -2.0495,  ..., -1.9464, -2.1663, -2.0893]],
 
         [[-2.3895, -2.0528, -2.4236,  ..., -1.5763, -2.1796, -2.0863],
          [-2.1751, -1.9779, -2.1291,  ..., -1.9431, -2.0957, -2.0721],
          [-2.3625, -1.9347, -2.2816,  ..., -1.5980, -2.2808, -2.1208],
          ...,
          [-2.1854, -1.9318, -2.0495,  ..., -1.9464, -2.1663, -2.0893],
          [-2.3611, -1.9145, -2.2350,  ..., -1.5959, -2.3149, -2.1266],
          [-2.1854, -1.9318, -2.0495,  ..., -1.9464, -2.1663, -2.0893]],
 
         [[-2.5890, -1.8510, -2.4937,  ..., -1.6452, -2.0672, -1.9756],
          [-2.3419, -1.8387,

In [46]:













np.concatenate((np.array((10,
                         )), np.ones(10, 1)))

TypeError: Cannot interpret '1' as a data type