In [17]:
import operator
import torch
import torch.nn as nn
import torch.nn.functional as F
import heapq
from types import SimpleNamespace

# code based loosely on the following 2 projects
# https://github.com/budzianowski/PyTorch-Beam-Search-Decoding/blob/master/decode_beam.py
# https://github.com/IBM/pytorch-seq2seq/blob/master/seq2seq/models/DecoderRNN.py


class BeamDecoder:
    """Mixin that adds beam-search decoding to a step-wise recurrent decoder.

    The host class is expected to provide:
      * ``step(embed, label, hx, cx)`` -> ``(logits, hx, cx)`` for one time step
      * ``h_size``    : size of the recurrent hidden/cell state vectors
      * ``start_tok`` : label id fed as the "previous label" at the first step
      * ``num_class`` : number of output classes (width of the one-hot result)
    """

    def decode(self, utt_embeds, beam_width=5):
        """Beam-search the most likely label sequence for every item of a batch.

        Args:
            utt_embeds: iterable of per-sequence tensors (e.g. a ``[B, N, D]``
                tensor); each element is iterated along its first dimension
                and every row is passed to ``self.step`` as ``embed``.
            beam_width: number of hypotheses kept after each step. It is
                clamped to the number of classes when expanding a hypothesis,
                so a beam wider than the label set no longer makes
                ``torch.topk`` raise.

        Returns:
            Float tensor ``[B, N, num_class]`` holding the one-hot encoding of
            the best path of each sequence (the start token is stripped).

        Raises:
            ValueError: if ``beam_width`` is smaller than 1.
        """
        if beam_width < 1:
            raise ValueError("beam_width must be >= 1")

        output = []
        device = None  # stays None (default device) when the batch is empty

        # Decoding is inference only: do not record an autograd graph.
        with torch.no_grad():
            for seq_embeds in utt_embeds:
                device = seq_embeds.device
                h0 = torch.zeros([self.h_size], device=device)
                c0 = torch.zeros([self.h_size], device=device)

                # Every hypothesis carries its label path, its accumulated
                # log-probability and the recurrent state reached so far.
                beam = [SimpleNamespace(path=[self.start_tok], log_prob=0.0, h=h0, c=c0)]

                for utt in seq_embeds:
                    candidates = []
                    for node in beam:
                        prev_label = torch.tensor([node.path[-1]], dtype=torch.long, device=device)

                        # run through next cell and select the top k continuations
                        y, hx, cx = self.step(embed=utt, label=prev_label, hx=node.h, cx=node.c)
                        y_log_probs = F.log_softmax(y, dim=-1)
                        k = min(beam_width, y_log_probs.size(-1))
                        top_log_probs, top_labels = torch.topk(y_log_probs, k)

                        # Plain Python numbers keep the pruning sort cheap (no
                        # per-comparison tensor ops / device synchronisation).
                        for log_p, label in zip(top_log_probs.tolist(), top_labels.tolist()):
                            candidates.append(SimpleNamespace(
                                path=node.path + [label],
                                log_prob=node.log_prob + log_p,
                                h=hx,
                                c=cx,
                            ))

                    # prune to the best `beam_width` hypotheses (stable sort)
                    candidates.sort(key=lambda state: state.log_prob, reverse=True)
                    beam = candidates[:beam_width]

                # best hypothesis, without the leading start token
                output.append(beam[0].path[1:])

        output = torch.tensor(output, dtype=torch.long, device=device)  # [B, N]
        output = F.one_hot(output, num_classes=self.num_class).float()  # [B, N, C]
        return output
    
class DecoderRNN(nn.Module, BeamDecoder):
    """RNN decoder that maps a sequence of vectors to a sequence of class decisions.

    At every position the RNN input is the concatenation of the input vector
    and the embedding of the previous label (the gold label during training,
    the predicted one during beam-search decoding).
    """

    def __init__(self, cell_type, num_class, embed_size=10, h_size=768, rnn_h_size=300, dropout=0.0):
        '''RNN decoder which when given a sequence of vectors, outputs sequence of decisions

        Args:
            cell_type: 'lstm' or 'gru' (case-insensitive).
            num_class: number of output classes.
            embed_size: dimension of the previous-label embedding.
            h_size: dimension of each input vector.
            rnn_h_size: hidden size of the recurrent layer.
            dropout: forwarded to the RNN constructor. NOTE: PyTorch only
                applies this between stacked layers, so with the single layer
                used here a non-zero value has no effect besides a warning.

        Raises:
            ValueError: if ``cell_type`` is neither 'lstm' nor 'gru'.
        '''
        super().__init__()

        # label embeddings; the extra last row (index num_class) is the start token
        self.embedding = nn.Embedding(num_class+1, embed_size)
        self.start_tok = num_class
        self.num_class = num_class

        # pick the recurrent layer type
        cell_type = cell_type.lower()
        if cell_type == 'lstm':  self.rnn_cell = nn.LSTM
        elif cell_type == 'gru': self.rnn_cell = nn.GRU
        else: raise ValueError("unsupported RNN type")

        # concatenation of [input vector ; label embedding] gives effectively two inputs (Wy + Wh)
        rnn_input_size = embed_size+h_size
        self.rnn = self.rnn_cell(rnn_input_size, rnn_h_size, bidirectional=False, dropout=dropout, batch_first=True)
        self.h_size = rnn_h_size

        # output classifier
        self.classifier = nn.Linear(rnn_h_size, num_class)

    def forward(self, utt_embeds, labels):
        """Teacher-forced training pass.

        Args:
            utt_embeds: float tensor [B, N, h_size].
            labels: long tensor [B, N] of gold labels; -100 marks padding.

        Returns:
            Logits of shape [B, N, num_class].
        """
        labels = torch.roll(labels, 1, -1)    # shift right (new tensor): position t sees label t-1
        labels[:, 0] = self.start_tok         # first position is fed the start token
        labels[labels==-100] = self.start_tok # -100 padding is not a valid embedding index
        label_embed = self.embedding(labels)  # [B, N] -> [B, N, embed_size]

        rnn_inputs = torch.cat((utt_embeds, label_embed), dim=-1)
        # Only the per-step outputs are needed. Discarding the final state
        # without unpacking it keeps this valid for both LSTM (state is a
        # tuple) and GRU (state is a single tensor).
        output, _ = self.rnn(rnn_inputs)
        y = self.classifier(output)           # [B, N, rnn_h_size] -> [B, N, num_class]
        return y

    def step(self, embed, label, hx=None, cx=None):
        """Run a single input vector through the RNN.

        Args:
            embed: input vector, flattened to [h_size].
            label: long tensor of shape [1] holding the previous label id.
            hx: previous hidden state [rnn_h_size]; zeros when None.
            cx: previous cell state [rnn_h_size]; zeros when None. Only the
                LSTM uses it; for a GRU it is returned unchanged.

        Returns:
            Tuple ``(logits [num_class], hx [rnn_h_size], cx [rnn_h_size])``.
        """
        if hx is None:
            hx = embed.new_zeros(self.h_size)
        if cx is None:
            cx = embed.new_zeros(self.h_size)

        # batch_first=True: inputs are [batch=1, seq=1, features]
        embed       = embed.view(1,1,-1)
        label_embed = self.embedding(label).view(1,1,-1)
        rnn_input   = torch.cat((embed, label_embed), dim=-1)

        if isinstance(self.rnn, nn.LSTM):
            output, (hx, cx) = self.rnn(rnn_input, (hx.view(1,1,-1), cx.view(1,1,-1)))
        else:
            # GRU has no cell state: it takes and returns a single tensor
            output, hx = self.rnn(rnn_input, hx.view(1,1,-1))

        y = self.classifier(output)
        return y.view(-1), hx.view(-1), cx.view(-1)


# Smoke test: beam-decode a random batch with a small LSTM decoder and show
# the one-hot predictions.
model = DecoderRNN(
    cell_type='lstm',
    num_class=5,
    embed_size=10,
    h_size=100,
    rnn_h_size=50,
    dropout=0,
)
a = torch.rand((4,10,100))          # 4 sequences of 10 input vectors of size 100
b = torch.randint(0, 5, (4,10))     # matching random labels (for a teacher-forced model(a, b) call)
y = model.decode(a)
print(y)

tensor([[[0, 0, 1, 0, 0],
         [0, 0, 0, 1, 0],
         [0, 0, 0, 1, 0],
         [0, 0, 0, 1, 0],
         [0, 0, 0, 1, 0],
         [0, 0, 0, 1, 0],
         [0, 0, 0, 1, 0],
         [0, 0, 0, 1, 0],
         [0, 0, 0, 1, 0],
         [0, 0, 0, 1, 0]],

        [[0, 0, 1, 0, 0],
         [0, 0, 1, 0, 0],
         [0, 0, 0, 1, 0],
         [0, 0, 0, 1, 0],
         [0, 0, 0, 1, 0],
         [0, 0, 0, 1, 0],
         [0, 0, 0, 1, 0],
         [0, 0, 0, 1, 0],
         [0, 0, 0, 1, 0],
         [0, 0, 0, 1, 0]],

        [[0, 0, 0, 1, 0],
         [0, 0, 0, 1, 0],
         [0, 0, 0, 1, 0],
         [0, 0, 0, 1, 0],
         [0, 0, 0, 1, 0],
         [0, 0, 0, 1, 0],
         [0, 0, 0, 1, 0],
         [0, 0, 0, 1, 0],
         [0, 0, 0, 1, 0],
         [0, 0, 0, 1, 0]],

        [[0, 0, 1, 0, 0],
         [0, 0, 0, 1, 0],
         [0, 0, 0, 1, 0],
         [0, 0, 0, 1, 0],
         [0, 0, 0, 1, 0],
         [0, 0, 0, 1, 0],
         [0, 0, 0, 1, 0],
         [0, 0, 0, 1, 0],
      

In [3]:
from torch.nn.utils.rnn import pad_sequence

# Three variable-length sequences of 768-d vectors (lengths 9, 6 and 8).
a = torch.rand(9, 768)
b = torch.rand(6, 768)
c = torch.rand(8, 768)

# Right-pad with zeros up to the longest sequence -> [3, 9, 768].
x = pad_sequence([a,b,c], batch_first=True, padding_value=0.0)

# Build the padding mask from the known lengths instead of from the values.
# The value-based test `torch.all(x != 0, dim=-1)` flags a real timestep as
# padding as soon as a single one of its features happens to be exactly 0.0.
lengths = torch.tensor([seq.size(0) for seq in (a, b, c)])
mask = torch.arange(x.size(1)).unsqueeze(0) < lengths.unsqueeze(1)  # [3, 9], True = real step
print(mask.dtype)
print(x.shape)

torch.bool
torch.Size([3, 9, 768])


In [None]:
def beam_decode(target_tensor, decoder_hiddens, encoder_outputs=None):
    '''
    Best-first beam search over the module-level ``decoder``, one batch item at a time.

    NOTE(review): as visible in this cell the function is incomplete: ``qsize``
    is never incremented inside the loop and nothing is back-traced or
    returned. It looks like a truncated copy of the fuller definition of the
    same name further down in this file.

    :param target_tensor: target indexes tensor of shape [B, T] where B is the batch size and T is the maximum length of the output sentence (only ``size(0)`` is read here)
    :param decoder_hiddens: input tensor of shape [1, B, H] for start of the decoding (or a tuple of two such tensors in the LSTM case)
    :param encoder_outputs: if you are using attention mechanism you can pass encoder outputs, [T, B, H] where T is the maximum length of input sentence
    :return: decoded_batch
    '''

    beam_width = 10
    topk = 1  # how many sentence do you want to generate
    decoded_batch = []

    # decoding goes sentence by sentence
    for idx in range(target_tensor.size(0)):
        # slice out the initial state of batch item `idx`, keeping a batch dim of 1
        if isinstance(decoder_hiddens, tuple):  # LSTM case
            decoder_hidden = (decoder_hiddens[0][:,idx, :].unsqueeze(0),decoder_hiddens[1][:,idx, :].unsqueeze(0))
        else:
            decoder_hidden = decoder_hiddens[:, idx, :].unsqueeze(0)
        # NOTE(review): indexes encoder_outputs unconditionally, so the default None fails here
        encoder_output = encoder_outputs[:,idx, :].unsqueeze(1)

        # Start with the start of the sentence token
        decoder_input = torch.LongTensor([[SOS_token]], device=device)

        # Number of sentence to generate
        endnodes = []
        # endnodes is empty at this point, so this evaluates to topk
        number_required = min((topk + 1), topk - len(endnodes))

        # starting node -  hidden vector, previous node, word id, logp, length
        node = BeamSearchNode(decoder_hidden, None, decoder_input, 0, 1)
        nodes = PriorityQueue()

        # start the queue
        # scores are negated because PriorityQueue returns the smallest entry first
        nodes.put((-node.eval(), node))
        qsize = 1

        # start beam search
        while True:
            # give up when decoding takes too long
            if qsize > 2000: break

            # fetch the best node
            score, n = nodes.get()
            decoder_input = n.wordid
            decoder_hidden = n.h

            # a non-root node whose token is EOS is a finished hypothesis
            if n.wordid.item() == EOS_token and n.prevNode != None:
                endnodes.append((score, n))
                # if we reached maximum # of sentences required
                if len(endnodes) >= number_required:
                    break
                else:
                    continue

            # NOTE(review): debug output; `.shape` does not exist on the tuple built in the LSTM case
            print(decoder_input.shape, decoder_hidden.shape)
            # decode for one step using decoder
            decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden, encoder_output)

            # PUT HERE REAL BEAM SEARCH OF TOP
            log_prob, indexes = torch.topk(decoder_output, beam_width)
            nextnodes = []

            # one child node per top-k continuation of the current hypothesis
            for new_k in range(beam_width):
                decoded_t = indexes[0][new_k].view(1, -1)
                log_p = log_prob[0][new_k].item()

                node = BeamSearchNode(decoder_hidden, n, decoded_t, n.logp + log_p, n.leng + 1)
                score = -node.eval()
                nextnodes.append((score, node))

            # put them into queue
            # (the local name `nn` shadows the torch.nn alias inside this function)
            for i in range(len(nextnodes)):
                score, nn = nextnodes[i]
                nodes.put((score, nn))
                # increase qsize
 

In [None]:
import operator
import torch
import torch.nn as nn
import torch.nn.functional as F
from queue import PriorityQueue
# Device on which the decoding tensors in this cell are expected to live.
device = torch.device("cpu")

# Token ids used by beam_decode: decoding starts from SOS_token and a
# hypothesis is treated as finished once it emits EOS_token.
SOS_token = 0
EOS_token = 1
# NOTE(review): MAX_LENGTH is not referenced by any code visible in this file.
MAX_LENGTH = 50


class DecoderRNN(nn.Module):
    """Illustrative single-step GRU decoder: token ids in, log-probabilities out."""

    def __init__(self, embedding_size, hidden_size, output_size, cell_type, dropout=0.1):
        '''
        Illustrative decoder

        :param embedding_size: dimension of the token embeddings
        :param hidden_size: GRU hidden size
        :param output_size: vocabulary size (embedding rows and output logits)
        :param cell_type: stored on the instance only; the recurrent layer is
            always a GRU regardless of this value
        :param dropout: dropout probability applied to the embedded input
            while the module is in training mode
        '''
        # Zero-argument super() keeps working when the name `DecoderRNN` is
        # later rebound (this notebook defines two classes with that name).
        super().__init__()
        self.hidden_size = hidden_size
        self.cell_type = cell_type
        self.embedding = nn.Embedding(num_embeddings=output_size,
                                      embedding_dim=embedding_size,
                                      )

        # `dropout` is deliberately not forwarded to nn.GRU: PyTorch applies it
        # between stacked layers only, so on this single-layer GRU it would do
        # nothing except emit a UserWarning.
        self.rnn = nn.GRU(embedding_size, hidden_size, bidirectional=False, batch_first=False)
        self.dropout_rate = dropout
        self.out = nn.Linear(hidden_size, output_size)

    def forward(self, input, hidden, not_used=None):
        """Decode one time step.

        :param input: long tensor [B, 1] of token ids
        :param hidden: GRU state [1, B, H]
        :param not_used: ignored; kept so callers can pass encoder outputs
        :return: (log-probabilities [B, output_size], new hidden state [1, B, H])
        """
        embedded = self.embedding(input).transpose(0, 1)  # [B,1] -> [1, B, D]
        # F.dropout defaults to training=True; tie it to the module mode so
        # that model.eval() really disables dropout.
        embedded = F.dropout(embedded, self.dropout_rate, training=self.training)

        output, hidden = self.rnn(embedded, hidden)

        out = self.out(output.squeeze(0))     # drop the length-1 time axis -> [B, H]
        output = F.log_softmax(out, dim=1)
        return output, hidden


class BeamSearchNode(object):
    """One hypothesis prefix in the beam search, linked to its parent node."""

    def __init__(self, hiddenstate, previousNode, wordId, logProb, length):
        '''
        :param hiddenstate: decoder state reached after emitting this node's token
        :param previousNode: parent node, or None for the root
        :param wordId: token stored on this node
        :param logProb: accumulated log-probability of the path ending here
        :param length: number of nodes on the path, root included
        '''
        self.h = hiddenstate
        self.prevNode = previousNode
        self.wordid = wordId
        self.logp = logProb
        self.leng = length

    def eval(self, alpha=1.0):
        """Length-normalised score of the path (higher is better)."""
        reward = 0
        # Add here a function for shaping a reward

        # leng - 1 is the number of generated tokens; the 1e-6 avoids a
        # division by zero for the root node (leng == 1).
        return self.logp / float(self.leng - 1 + 1e-6) + alpha * reward

    def __lt__(self, other):
        """Order nodes so that the better-scoring one sorts first.

        Needed because nodes are queued as ``(score, node)`` tuples: when two
        scores tie, tuple comparison falls through to the nodes themselves,
        which raised ``TypeError`` while nodes were not orderable.
        """
        if not isinstance(other, BeamSearchNode):
            return NotImplemented
        return self.eval() > other.eval()



def beam_decode(target_tensor, decoder_hiddens, encoder_outputs=None,
                beam_width=10, topk=1, max_queue_size=2000):
    '''
    Best-first beam search with the module-level ``decoder``, one batch item at a time.

    :param target_tensor: target indexes tensor of shape [B, T] where B is the batch size and T is the maximum length of the output sentence (only the batch size is used)
    :param decoder_hiddens: input tensor of shape [1, B, H] for start of the decoding, or a tuple of two such tensors in the LSTM case
    :param encoder_outputs: if you are using attention mechanism you can pass encoder outputs, [T, B, H] where T is the maximum length of input sentence; may be None
    :param beam_width: number of continuations expanded per node (clamped to the vocabulary size)
    :param topk: how many sentences to generate per batch item
    :param max_queue_size: give up searching once roughly this many nodes have been queued
    :return: decoded_batch - for every batch item a list of up to ``topk``
        utterances, each a list of the ``wordid`` tensors from the start
        token to the last token of the path
    '''
    decoded_batch = []

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        # decoding goes sentence by sentence
        for idx in range(target_tensor.size(0)):
            if isinstance(decoder_hiddens, tuple):  # LSTM case
                decoder_hidden = (decoder_hiddens[0][:, idx, :].unsqueeze(0),
                                  decoder_hiddens[1][:, idx, :].unsqueeze(0))
                state_device = decoder_hidden[0].device
            else:
                decoder_hidden = decoder_hiddens[:, idx, :].unsqueeze(0)
                state_device = decoder_hidden.device

            # encoder outputs are optional (the parameter defaults to None)
            if encoder_outputs is None:
                encoder_output = None
            else:
                encoder_output = encoder_outputs[:, idx, :].unsqueeze(1)

            # Start with the start of the sentence token, on the same device as
            # the decoder state. torch.tensor replaces the legacy LongTensor
            # constructor, which rejects non-CPU devices.
            decoder_input = torch.tensor([[SOS_token]], dtype=torch.long, device=state_device)

            # Number of sentences to generate
            endnodes = []
            number_required = topk

            # starting node -  hidden vector, previous node, word id, logp, length
            root = BeamSearchNode(decoder_hidden, None, decoder_input, 0, 1)
            nodes = PriorityQueue()

            # Queue entries are (score, insertion order, node). The running
            # counter breaks score ties so the nodes themselves are never
            # compared. Scores are negated because the queue is a min-heap.
            pushed = 0
            nodes.put((-root.eval(), pushed, root))
            pushed += 1
            qsize = 1

            # start beam search; stop on an empty queue instead of blocking in get()
            while not nodes.empty():
                # give up when decoding takes too long
                if qsize > max_queue_size:
                    break

                # fetch the best node
                score, _, n = nodes.get()
                decoder_input = n.wordid
                decoder_hidden = n.h

                if n.wordid.item() == EOS_token and n.prevNode is not None:
                    endnodes.append((score, n))
                    # if we reached maximum # of sentences required
                    if len(endnodes) >= number_required:
                        break
                    continue

                # decode for one step using decoder
                decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden, encoder_output)

                # expand the best continuations (never more than the vocabulary holds)
                k = min(beam_width, decoder_output.size(-1))
                log_prob, indexes = torch.topk(decoder_output, k)

                for new_k in range(k):
                    decoded_t = indexes[0][new_k].view(1, -1)
                    log_p = log_prob[0][new_k].item()

                    child = BeamSearchNode(decoder_hidden, n, decoded_t, n.logp + log_p, n.leng + 1)
                    nodes.put((-child.eval(), pushed, child))
                    pushed += 1

                # one node was consumed and k were added
                qsize += k - 1

            # no hypothesis reached EOS: fall back to the best unfinished ones
            if len(endnodes) == 0:
                while len(endnodes) < topk and not nodes.empty():
                    score, _, n = nodes.get()
                    endnodes.append((score, n))

            # choose nbest paths, back trace them
            utterances = []
            for score, n in sorted(endnodes, key=operator.itemgetter(0)):
                utterance = [n.wordid]
                # back trace
                while n.prevNode is not None:
                    n = n.prevNode
                    utterance.append(n.wordid)

                utterances.append(utterance[::-1])

            decoded_batch.append(utterances)

    return decoded_batch

# Smoke test for beam_decode on random inputs.
# The decoder always builds a GRU, so label it as such and give it a single
# hidden-state tensor (not an LSTM (h, c) tuple).
decoder = DecoderRNN(100, 100, 100, 'gru', dropout=0)

a = torch.randint(0, 5, (4,10))     # dummy targets [B, T]; beam_decode only reads the batch size
b = torch.rand((1, 4, 100))         # initial hidden state [1, B, H]; must be float (randint gave int64)
c = torch.rand((10, 4, 100))        # dummy encoder outputs [T, B, H]
beam_decode(a, b, c)