<a href="https://colab.research.google.com/github/DmitriyValetov/nlp_course_project/blob/master/inferences.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [14]:
import numpy as np
from numpy.random import choice, randint

# Seed the global NumPy RNG so the toy encoder batch built below is reproducible.
np.random.seed(0)
# Ids of the special tokens: padding, unknown, start-of-sequence, end-of-sequence.
pad, unk, sos, eos = 0, 1, 2, 3
emb_size = 5  # 0, 1, 2, 3, 4
batch_size = 3
# Number of content tokens per sample, drawn from {1, 2, 3}.
x1_lens = randint(1, 4, batch_size)
print(x1_lens)
max_len = np.max(x1_lens)
print(max_len)
# NOTE: x1_base is not read anywhere else in the visible code, but drawing it
# advances the global RNG and therefore influences the `choice` calls below.
x1_base = randint(1, 5, (batch_size, 4))
# One row per sample: <sos>, x random tokens from {1, 4}, <eos>, then <pad>
# so that every row has max_len + 2 entries.
bx1 = np.array([[2] + list(choice([1, 4], x)) + [3] + [0] * (max_len - x) for x in x1_lens])
print(bx1)

def predict(bx1, bx2, emb_size=5):
  """Mock decoder: random log-probabilities for every position of bx2.

  The global NumPy RNG is re-seeded with 0 on every call, so the result
  depends only on the shape of `bx2` and on `emb_size`.

  Args:
    bx1: encoder batch input; accepted for interface parity, never read.
    bx2: decoder batch input, a 2-D array (batch_size, len_x2); only its
      shape is used.
    emb_size: size of the last axis of the output.

  Returns:
    Array of shape (batch_size, len_x2, emb_size) with log-softmax values
    taken over the last axis.
  """
  np.random.seed(0)
  batch_size, len_x2 = bx2.shape
  a = np.random.pareto(2, (len_x2, batch_size, emb_size))  # activations
  a = a.swapaxes(0, 1)  # (batch_size, len_x2, emb_size), for consistency
  # Log-softmax with the row maximum subtracted first: Pareto draws are
  # heavy-tailed, and exp() of a large raw activation overflows (or the
  # softmax underflows to 0 and log() returns -inf).
  shifted = a - np.max(a, axis=2, keepdims=True)
  by2p = shifted - np.log(np.sum(np.exp(shifted), axis=2, keepdims=True))
  return by2p

import torch
from torch.distributions.pareto import Pareto
import torch.nn.functional as F

def torch_predict(bx1, bx2, emb_size=5):
  """Mock decoder for torch: Pareto-distributed log-softmax predictions.

  The torch RNG is re-seeded with 0 on every call. The rows of `bx1` are
  repeated to match the number of rows of `bx2`; the first column of the
  repeated batch is used as the Pareto scale and the last column as the
  Pareto alpha (zeros in either are replaced by 1).

  Args:
    bx1: 2-D integer tensor, encoder batch input.
    bx2: 2-D tensor, decoder batch input; only its size is used.
    emb_size: size of the last axis of the output.

  Returns:
    Tensor of shape (bx2.size(0), bx2.size(1), emb_size) with log-softmax
    values over the last axis.
  """
  torch.manual_seed(0)
  n_rows, n_steps = bx2.size()
  # One encoder row per decoder row.
  tiled = torch.repeat_interleave(bx1, int(n_rows/bx1.size(0)), 0)
  first_col = tiled[:,0]
  last_col = tiled[:,-1]
  # Both are views into `tiled`; zeros are replaced in place by 1.
  first_col[first_col == 0] = 1
  last_col[last_col == 0] = 1
  dist = Pareto(first_col.float(), last_col.float())
  draws = dist.sample((n_steps, emb_size))  # (n_steps, emb_size, n_rows)
  draws = draws.permute(2, 0, 1)  # (n_rows, n_steps, emb_size)
  return F.log_softmax(draws, 2)

# Rebind bx1 to a tensor copy of the NumPy toy batch and run one mock
# prediction for a decoder batch that holds only the <sos> id (2).
bx1 = torch.tensor(bx1)
bx2 = torch.full((batch_size, 1), 2)
# Bare expression: its value is what the notebook shows as the cell output.
torch_predict(bx1, bx2)

[1 2 1]
2
[[2 1 3 0]
 [2 1 1 3]
 [2 4 3 0]]


tensor([[[ 0.0000e+00, -4.1549e+01, -6.4350e+01, -6.3975e+01, -6.3683e+01]],

        [[-1.2998e+00, -1.4892e+00, -2.0037e+00, -1.8685e+00, -1.5480e+00]],

        [[-5.8917e+00, -1.5008e-02, -4.8140e+00, -7.1616e+00, -5.7313e+00]]])

In [0]:
def greedy(bx1, sos=2, eos=3, max_len=10):
  """Greedy decoding: every row extends itself with its own argmax token.

  Relies on the module-level `predict`.

  Args:
    bx1: encoder batch input; its first dimension is the batch size.
    sos: id of the start token that seeds every row.
    eos: id of the end token; decoding stops once every row contains it.
    max_len: decoding stops once the decoder batch has max_len - 1 columns
      (<sos> included).

  Returns:
    The predictions of the last `predict` call, i.e. for the decoder batch
    as it was before the final token was appended. If no decoding step ran
    at all, an empty (batch, 0, emb) slice of the predictions.
  """
  # Size the decoder batch from the input itself rather than from a global.
  bx2 = np.full((bx1.shape[0], 1), sos)  # batch with <sos> token
  by2p = None
  # stop when predictions len > max_len or all have <eos> token
  while 1 + bx2.shape[1] < max_len and not np.all(np.any(bx2 == eos, axis=1)):
    by2p = predict(bx1, bx2)
    lp = by2p[:,-1,:]  # last prediction
    next_bx2 = np.argmax(lp, axis=1).reshape((lp.shape[0], 1))
    bx2 = np.concatenate((bx2, next_bx2), axis=1)
  if by2p is None:
    # The loop never ran (e.g. max_len <= 2): return zero prediction steps
    # instead of failing on an unbound local.
    by2p = predict(bx1, bx2)[:, :0, :]
  return by2p

def greedy_many_to_one(bx1, sos=2, eos=3, max_len=10, reduction=np.sum):
  """Greedy decoding where the whole batch votes for one shared next token.

  At every step the last-position predictions of all rows are combined with
  `reduction` over the batch axis, and the argmax of the combined vector is
  appended to every row. Intermediate values are printed.

  Relies on the module-level `predict`.

  Args:
    bx1: encoder batch input; its first dimension is the batch size.
    sos: id of the start token that seeds every row.
    eos: id of the end token; decoding stops once every row contains it.
    max_len: decoding stops once the decoder batch has max_len - 1 columns
      (<sos> included).
    reduction: callable invoked as reduction(array, axis=0).

  Returns:
    The predictions of the last `predict` call, i.e. for the decoder batch
    as it was before the final token was appended. If no decoding step ran
    at all, an empty (batch, 0, emb) slice of the predictions.
  """
  bx2 = np.full((bx1.shape[0], 1), sos)  # batch with <sos> token
  by2p = None
  # stop when predictions len > max_len or all have <eos> token
  while 1 + bx2.shape[1] < max_len and not np.all(np.any(bx2 == eos, axis=1)):
    print(bx2)
    by2p = predict(bx1, bx2)
    lp = by2p[:,-1,:]  # last prediction
    print(lp)
    lp_red = reduction(lp, axis=0)  # one score per token for the whole batch
    print(lp_red)
    next_x2 = np.argmax(lp_red, axis=0)
    next_bx2 = np.full((lp.shape[0], 1), next_x2)  # same token for every row
    bx2 = np.concatenate((bx2, next_bx2), axis=1)
  if by2p is None:
    # The loop never ran (e.g. max_len <= 2): return zero prediction steps
    # instead of failing on an unbound local.
    by2p = predict(bx1, bx2)[:, :0, :]
  return by2p

def forced(bx1, bx2):
  """Single forced pass: predictions of `predict` for the given decoder batch."""
  return predict(bx1, bx2)

def many_to_one(by2p, reduction=np.sum):
  """Collapse a batch of predictions into one label sequence.

  Args:
    by2p: array indexed as (batch, step, token); the last step is dropped
      before reducing.
    reduction: callable invoked as reduction(array, axis=0) to combine the
      rows of the batch (e.g. np.sum, np.max, np.mean).

  Returns:
    1-D array with the argmax token id for every remaining step.
  """
  # Honour the caller's reduction instead of always summing.
  y2p = reduction(by2p[:,:-1,:], axis=0)
  y2 = np.argmax(y2p, axis=1)
  return y2

# by2p = greedy_many_to_one(bx1, reduction=np.sum)
# print(by2p)
# by2p = greedy_many_to_one(bx1, reduction=np.max)
# print(by2p)
# by2p = greedy_many_to_one(bx1, reduction=np.mean)
# print(by2p)
# y2 = np.argmax(np.sum(by2p[:,:-1,:], axis=0), axis=1)
# print(y2)
# print(many_to_one_sum(by2p))
# by2p = greedy(bx1)
# print(by2p)
# by2 = np.argmax(by2p[:,:-1,:], axis=2)
# print(by2)
# print(many_to_one(by2p, np.sum))
# print(many_to_one(by2p, np.max))
# print(many_to_one(by2p, np.mean))
# forced
# bx2 = np.concatenate((np.full((batch_size, 1), 2), by2), axis=1)
# print(bx2)
# by2p = forced(bx1, bx2)
# # print(bpy2)
# by2 = np.argmax(by2p[:,:-1,:], axis=2)
# print(by2)

In [0]:
# + do many_to_one search with batch_reduction
# + try Viterbi algorithm for best beam searching (like beam search but in reversed direction)
# TODO maybe try spectral beams? (With multiple reductions + Viterbi)
# bx1 - encoder batch input
# bx2 - decoder batch input
# by2p - decoder batch output log-probabilities
# by2 - decoder batch output

In [17]:
def beam(bx1, sos=2, eos=3, max_len=6,
         beam_width=2, beam_depth=2, depth_reduction=np.sum):
  """NumPy beam search that keeps the single best beam per input row.

  Every outer iteration branches each row `beam_depth` times into its
  `beam_width` best next tokens, scores the resulting
  beam_width**beam_depth beams with `depth_reduction`, and appends the
  best beam (beam_depth tokens) to the row.

  Relies on the module-level `predict`.

  Args:
    bx1: encoder batch input; its first dimension is the batch size.
    sos: id of the start token that seeds every row.
    eos: id of the end token; the search stops once every row contains it.
    max_len: maximum number of tokens after <sos>.
    beam_width: number of candidate tokens kept per branching step.
    beam_depth: number of branching steps per outer iteration.
    depth_reduction: callable invoked as depth_reduction(array, axis=2) to
      score a beam.

  Returns:
    Array (batch_size, 1 + n) with n <= max_len: <sos> followed by the
    decoded tokens.
  """
  batch_size = bx1.shape[0]
  bx2 = np.full((batch_size, 1), sos)  # base batch with <sos> token
  cur_len = bx2.shape[1] - 1  # without <sos>
  # stop when cur_len > max_len or all predictions have <eos> token
  while cur_len < max_len and not np.all(np.any(bx2 == eos, axis=1)):
    bx2t = bx2.copy()  # temporal batch bx2
    for i in range(beam_depth):
      by2p = predict(bx1, bx2t)  # all predictions
      lp = by2p[:,-1,:]  # last prediction
      next_bx2 = np.argsort(-lp, axis=1)[:,:beam_width]  # top "beam_width" last predictions
      bx2t = np.repeat(bx2t, beam_width, axis=0)  # multiply batch_size by beam_width
      next_bx2 = next_bx2.reshape(batch_size*beam_width**(i+1), 1)  # to new batch_size
      bx2t = np.concatenate((bx2t, next_bx2), axis=1)  # update batch
    # Prediction with temporal batch
    by2p = predict(bx1, bx2t)  # all predictions by last temporal batch
    by2p = by2p[:,cur_len:-1]  # [base part:-last]  # remove old predictions
    bx2t = bx2t[:,cur_len+1:]  # [base part + <sos>:]  # remove old labels
    # Best beams searching
    # NOTE(review): indexing with bx2t[:, np.newaxis] produces a
    # (rows, depth, depth) array, i.e. every step's distribution is looked up
    # at the labels of all steps, not only its own. Confirm whether
    # bx2t[:, :, np.newaxis] was intended; left unchanged here.
    by2p = np.take_along_axis(by2p, bx2t[:,np.newaxis], axis=2)  # beams transitions predictions
    by2p = by2p.reshape(batch_size, beam_width**beam_depth, -1)  # reshape to base batch shape (like)
    by2p_red = depth_reduction(by2p, axis=2)  # beams reduction, i.e. giving them scores
    best_beams = np.argmax(by2p_red, axis=1)  # get indices of the best beams
    # Base batch updating
    bx2t = bx2t.reshape(batch_size, beam_width**beam_depth, -1)  # reshape to base batch shape (like)
    bx2t = np.take_along_axis(bx2t, best_beams[:, np.newaxis, np.newaxis], axis=1)  # get best beams
    bx2t = bx2t.reshape(batch_size, -1)  # to new labels to base batch_size
    bx2 = np.concatenate((bx2, bx2t), axis=1)  # update base batch
    cur_len = bx2.shape[1] - 1  # without <sos>
  if bx2.shape[1] - 1 > max_len:  # cut to max_len
    # Truncate along the sequence axis (columns); slicing the first axis
    # would drop batch rows and leave over-long sequences untouched.
    bx2 = bx2[:, :max_len + 1]  # <sos> included
  return bx2

# Demo: run the NumPy beam search on whatever bx1 the preceding cells left
# bound. All beams have the same number of scored entries, so np.mean is
# np.sum divided by a constant and both calls pick the same beams.
by2 = beam(bx1, depth_reduction=np.sum)
print(by2)
by2 = beam(bx1, depth_reduction=np.mean)
print(by2)

[[2 2 4 3 3]
 [2 2 0 3 3]
 [2 3 2 4 3]]
[[2 2 4 3 3]
 [2 2 0 3 3]
 [2 3 2 4 3]]


In [18]:
def torch_beam(bx1, sos=2, eos=3, max_len=10,
               beam_width=2, beam_depth=2, depth_reduction=torch.sum):
  """Torch beam search that keeps the single best beam per input row.

  Every outer iteration branches each row `beam_depth` times into its
  `beam_width` top-scoring next tokens, reduces the per-step scores of each
  of the beam_width**beam_depth beams with `depth_reduction`, and keeps the
  best beam of every input row.

  Relies on the module-level `torch_predict`.

  Args:
    bx1: 2-D encoder batch input; size(0) is the batch size and its dtype is
      used for the output.
    sos: id of the start token that seeds every row.
    eos: id of the end token; the search stops once every row contains it.
    max_len: the search stops once at least max_len tokens follow <sos>.
    beam_width: number of candidate tokens kept per branching step.
    beam_depth: number of branching steps per outer iteration.
    depth_reduction: callable invoked as depth_reduction(tensor, axis=1).

  Returns:
    Tensor (batch_size, 1 + k*beam_depth), k being the number of outer
    iterations: <sos> followed by the decoded tokens. The result is not cut
    to max_len.
  """
  batch_size = bx1.size(0)  # input batch_size
  bx2 = torch.full((batch_size, 1), sos, dtype=bx1.dtype)  # batch with <sos>
  # stop when len decoder output > max_len or all decoder outputs have <eos> token
  while bx2.size(1) - 1 < max_len and not torch.all(torch.any(bx2 == eos, axis=1)):
    beam_scores = torch.empty((batch_size, 0))  # scores for each beam
    for i in range(beam_depth):
      new_batch_size = batch_size*beam_width**(i+1)
      by2p = torch_predict(bx1, bx2)  # predict
      next_by2p, next_bx2 = torch.topk(by2p[:,-1,:], beam_width)  # beams to top k last predictions
      next_bx2 = next_bx2.view(new_batch_size, 1)  # new beams
      next_by2p = next_by2p.view(new_batch_size, 1)  # new scores
      beam_scores = torch.repeat_interleave(beam_scores, beam_width, 0)  # increase batch for new scores
      bx2 = torch.repeat_interleave(bx2, beam_width, 0)  # increase batch for new beams
      bx2 = torch.cat((bx2, next_bx2), 1)  # add beams
      beam_scores = torch.cat((beam_scores, next_by2p), 1)  # add beams scores
    beam_scores = depth_reduction(beam_scores, axis=1) # cumulative beams scores
    beam_scores = beam_scores.view(batch_size, -1)  # split scores into batches
    best_beams = torch.argmax(beam_scores, axis=1, keepdim=True)  # best beams
    bx2 = bx2.view(batch_size, beam_width**beam_depth, -1)  # split beams into batches
    # Return to the input batch size: expand each row's winning beam index
    # across the sequence axis so that gather picks the whole beam.
    best_beams = best_beams.unsqueeze(2).expand(best_beams.size(0), best_beams.size(1), bx2.size(2))
    bx2 = torch.gather(bx2, 1, best_beams)
    bx2 = bx2.view(batch_size, -1)
  return bx2

# Demo: decode a random 3x3 integer batch (values 0..3) with both depth
# reductions. Every beam has beam_depth scores, so torch.mean is torch.sum
# divided by a constant and both calls select the same beams.
# bx1 = torch.full((5, 1), 0, dtype=torch.long)
bx1 = torch.randint(0, 4, (3, 3))
print(bx1)
by2 = torch_beam(bx1, max_len=10,
                 beam_width=2, beam_depth=2, depth_reduction=torch.sum)
print(by2)
by2 = torch_beam(bx1, depth_reduction=torch.mean)
print(by2)

tensor([[1, 0, 3],
        [0, 3, 1],
        [2, 3, 3]])
tensor([[2, 0, 3, 0, 3, 1, 2, 4, 4, 1, 1],
        [2, 0, 1, 0, 0, 2, 1, 0, 0, 4, 3],
        [2, 1, 2, 1, 4, 3, 2, 3, 2, 1, 2]])
tensor([[2, 0, 3, 0, 3, 1, 2, 4, 4, 1, 1],
        [2, 0, 1, 0, 0, 2, 1, 0, 0, 4, 3],
        [2, 1, 2, 1, 4, 3, 2, 3, 2, 1, 2]])


In [19]:
def torch_beam_many_to_one(bx1, sos=2, eos=3, max_len=10,
                           beam_width=2, beam_depth=2,
                           depth_reduction=torch.sum,
                           batch_reduction=torch.sum):
  """Beam search in which the whole batch shares one decoded sequence.

  At every branching step the predictions of all rows are combined with
  `batch_reduction` over the batch axis, and the top `beam_width` tokens of
  the combined last-step vector are offered to every row. After
  `beam_depth` steps the beam with the best `depth_reduction` score is
  selected and copied to all rows.

  Relies on the module-level `torch_predict`.

  Args:
    bx1: 2-D encoder batch input; size(0) is the batch size and its dtype is
      used for the output.
    sos: id of the start token that seeds every row.
    eos: id of the end token; the search stops once every row contains it.
    max_len: the search stops once at least max_len tokens follow <sos>.
    beam_width: number of candidate tokens kept per branching step.
    beam_depth: number of branching steps per outer iteration.
    depth_reduction: callable invoked as depth_reduction(tensor, axis=1).
    batch_reduction: callable invoked as batch_reduction(tensor, axis=0).

  Returns:
    Tensor with batch_size identical rows: <sos> followed by the decoded
    tokens. The result is not cut to max_len.
  """
  batch_size = bx1.size(0)  # input batch_size
  bx2 = torch.full((batch_size, 1), sos, dtype=bx1.dtype)  # batch with <sos>
  # stop when len decoder output > max_len or all decoder outputs have <eos> token
  while bx2.size(1) - 1 < max_len and not torch.all(torch.any(bx2 == eos, axis=1)):
    beam_scores = torch.empty((batch_size, 0))  # scores for each beam
    for i in range(beam_depth):
      prev_batch_size = batch_size*beam_width**(i)
      new_batch_size = batch_size*beam_width**(i+1)
      by2p = torch_predict(bx1, bx2)  # predict
      by2p = batch_reduction(by2p, axis=0) # cumulative batch predictions
      next_by2p, next_bx2 = torch.topk(by2p[-1,:], beam_width)  # beams to top k last predictions
      next_bx2 = next_bx2.repeat(prev_batch_size, 1)  # return to prev batch size
      next_by2p = next_by2p.repeat(prev_batch_size, 1)  # return to prev batch size
      next_bx2 = next_bx2.view(new_batch_size, 1)  # new beams
      next_by2p = next_by2p.view(new_batch_size, 1)  # new scores
      beam_scores = torch.repeat_interleave(beam_scores, beam_width, 0)  # increase batch for new scores
      bx2 = torch.repeat_interleave(bx2, beam_width, 0)  # increase batch for new beams
      bx2 = torch.cat((bx2, next_bx2), 1)  # add beams
      beam_scores = torch.cat((beam_scores, next_by2p), 1)  # add beams scores
    beam_scores = depth_reduction(beam_scores, axis=1) # cumulative beams scores
    # argmax over all rows at once: a single winner for the whole batch
    best_beams = torch.argmax(beam_scores, axis=0, keepdim=True)  # best beams
    bx2 = bx2[best_beams]  # best beam of all batches
    bx2 = bx2.repeat(batch_size, 1)  # return to input batch size
  return bx2

# Demo: shared-sequence beam search on a random 3x3 batch (values 1..3) with
# different depth/batch reduction pairs.
# NOTE: the last two calls pass identical arguments.
# bx1 = torch.full((3, 1), 0, dtype=torch.long)
bx1 = torch.randint(1, 4, (3, 3))
print(bx1)
bx2 = torch_beam_many_to_one(bx1, depth_reduction=torch.sum,
                             batch_reduction=torch.sum)
print(bx2)
bx2 = torch_beam_many_to_one(bx1, depth_reduction=torch.mean,
                             batch_reduction=torch.sum)
print(bx2)
bx2 = torch_beam_many_to_one(bx1, depth_reduction=torch.mean, 
                             batch_reduction=torch.mean)
print(bx2)
bx2 = torch_beam_many_to_one(bx1, depth_reduction=torch.mean,
                             batch_reduction=torch.mean)
print(bx2)

tensor([[2, 1, 3],
        [3, 3, 3],
        [1, 2, 3]])
tensor([[2, 0, 3],
        [2, 0, 3],
        [2, 0, 3]])
tensor([[2, 0, 3],
        [2, 0, 3],
        [2, 0, 3]])
tensor([[2, 0, 3],
        [2, 0, 3],
        [2, 0, 3]])
tensor([[2, 0, 3],
        [2, 0, 3],
        [2, 0, 3]])


In [20]:
def torch_beam_reduction_many_to_one(bx1, sos=2, eos=3, max_len=10,
                                     beam_width=2, beam_depth=2,
                                     depth_reduction=torch.sum,
                                     beam_reduction=torch.sum,
                                     batch_reduction=torch.sum):
  """Shared-sequence beam search with an extra reduction over sibling beams.

  Works like a many-to-one beam search (predictions of all rows are combined
  with `batch_reduction` and the same top tokens are offered to every row),
  but from the second branching step on the predictions are first grouped
  into blocks of `beam_width` consecutive rows and combined with
  `beam_reduction`.

  Relies on the module-level `torch_predict`.

  Args:
    bx1: 2-D encoder batch input; size(0) is the batch size and its dtype is
      used for the output.
    sos: id of the start token that seeds every row.
    eos: id of the end token; the search stops once every row contains it.
    max_len: the search stops once at least max_len tokens follow <sos>.
    beam_width: number of candidate tokens kept per branching step.
    beam_depth: number of branching steps per outer iteration.
    depth_reduction: callable invoked as depth_reduction(tensor, axis=1).
    beam_reduction: callable invoked as beam_reduction(tensor, axis=1).
    batch_reduction: callable invoked as batch_reduction(tensor, axis=0).

  Returns:
    Tensor with batch_size identical rows: <sos> followed by the decoded
    tokens. The result is not cut to max_len.
  """
  batch_size = bx1.size(0)  # input batch_size
  bx2 = torch.full((batch_size, 1), sos, dtype=bx1.dtype)  # batch with <sos>
  # stop when len decoder output > max_len or all decoder outputs have <eos> token
  while bx2.size(1) - 1 < max_len and not torch.all(torch.any(bx2 == eos, axis=1)):
    beam_scores = torch.empty((batch_size, 0))  # scores for each beam
    for i in range(beam_depth):
      prev_batch_size = batch_size*beam_width**(i)
      new_batch_size = batch_size*beam_width**(i+1)
      by2p = torch_predict(bx1, bx2)  # predict
      if i > 0:  # beam reduction
        prev_prev_batch_size = batch_size*beam_width**(i-1)
        by2p = by2p.view(prev_prev_batch_size, beam_width, by2p.size(1), -1)  # split predictions into batches
        by2p = beam_reduction(by2p, axis=1) # cumulative beams predictions
      by2p = batch_reduction(by2p, axis=0) # cumulative batch predictions
      next_by2p, next_bx2 = torch.topk(by2p[-1,:], beam_width)  # beams to top k last predictions
      next_bx2 = next_bx2.repeat(prev_batch_size, 1)  # return to prev batch size
      next_by2p = next_by2p.repeat(prev_batch_size, 1)  # return to prev batch size
      next_bx2 = next_bx2.view(new_batch_size, 1)  # new beams
      next_by2p = next_by2p.view(new_batch_size, 1)  # new scores
      beam_scores = torch.repeat_interleave(beam_scores, beam_width, 0)  # increase batch for new scores
      bx2 = torch.repeat_interleave(bx2, beam_width, 0)  # increase batch for new beams
      bx2 = torch.cat((bx2, next_bx2), 1)  # add beams
      beam_scores = torch.cat((beam_scores, next_by2p), 1)  # add beams scores
    beam_scores = depth_reduction(beam_scores, axis=1) # cumulative beams scores
    # argmax over all rows at once: a single winner for the whole batch
    best_beams = torch.argmax(beam_scores, axis=0, keepdim=True)  # best beams
    bx2 = bx2[best_beams]  # best beam of all batches
    bx2 = bx2.repeat(batch_size, 1)  # return to input batch size
  return bx2

# Demo: shared-sequence beam search with beam reduction on a random 3x3
# batch (values 1..3), sweeping depth/beam/batch reduction combinations.
# NOTE: calls 3 and 4 pass identical arguments, and so do calls 7 and 8.
# bx1 = torch.full((3, 1), 0, dtype=torch.long)
bx1 = torch.randint(1, 4, (3, 3))
print(bx1)
# beam_reduction=torch.sum
bx2 = torch_beam_reduction_many_to_one(bx1,
                                       depth_reduction=torch.sum,
                                       beam_reduction=torch.sum,
                                       batch_reduction=torch.sum)
print(bx2)
bx2 = torch_beam_reduction_many_to_one(bx1, 
                                       depth_reduction=torch.mean,
                                       beam_reduction=torch.sum,
                                       batch_reduction=torch.sum)
print(bx2)
bx2 = torch_beam_reduction_many_to_one(bx1, 
                                       depth_reduction=torch.mean, 
                                       beam_reduction=torch.sum,
                                       batch_reduction=torch.mean)
print(bx2)
bx2 = torch_beam_reduction_many_to_one(bx1, 
                                       depth_reduction=torch.mean,
                                       beam_reduction=torch.sum,
                                       batch_reduction=torch.mean)
print(bx2)
# beam_reduction=torch.mean
bx2 = torch_beam_reduction_many_to_one(bx1,
                                       depth_reduction=torch.sum,
                                       beam_reduction=torch.mean,
                                       batch_reduction=torch.sum)
print(bx2)
bx2 = torch_beam_reduction_many_to_one(bx1, 
                                       depth_reduction=torch.mean,
                                       beam_reduction=torch.mean,
                                       batch_reduction=torch.sum)
print(bx2)
bx2 = torch_beam_reduction_many_to_one(bx1, 
                                       depth_reduction=torch.mean, 
                                       beam_reduction=torch.mean,
                                       batch_reduction=torch.mean)
print(bx2)
bx2 = torch_beam_reduction_many_to_one(bx1, 
                                       depth_reduction=torch.mean,
                                       beam_reduction=torch.mean,
                                       batch_reduction=torch.mean)
print(bx2)

tensor([[1, 2, 2],
        [2, 2, 3],
        [2, 2, 2]])
tensor([[2, 0, 3],
        [2, 0, 3],
        [2, 0, 3]])
tensor([[2, 0, 3],
        [2, 0, 3],
        [2, 0, 3]])
tensor([[2, 0, 3],
        [2, 0, 3],
        [2, 0, 3]])
tensor([[2, 0, 3],
        [2, 0, 3],
        [2, 0, 3]])
tensor([[2, 0, 3],
        [2, 0, 3],
        [2, 0, 3]])
tensor([[2, 0, 3],
        [2, 0, 3],
        [2, 0, 3]])
tensor([[2, 0, 3],
        [2, 0, 3],
        [2, 0, 3]])
tensor([[2, 0, 3],
        [2, 0, 3],
        [2, 0, 3]])


In [21]:
def torch_beam_reduction(bx1, sos=2, eos=3, max_len=10,
                         beam_width=2, beam_depth=2,
                         beam_reduction=torch.sum,
                         depth_reduction=torch.sum):
  """Per-row beam search with a reduction over sibling beams.

  Works like a plain per-row beam search, but from the second branching
  step on the predictions are grouped into blocks of `beam_width`
  consecutive rows, combined with `beam_reduction`, and the combined
  predictions are used to choose the next tokens.

  Relies on the module-level `torch_predict`.

  Args:
    bx1: 2-D encoder batch input; size(0) is the batch size and its dtype is
      used for the output.
    sos: id of the start token that seeds every row.
    eos: id of the end token; the search stops once every row contains it.
    max_len: the search stops once at least max_len tokens follow <sos>.
    beam_width: number of candidate tokens kept per branching step.
    beam_depth: number of branching steps per outer iteration.
    beam_reduction: callable invoked as beam_reduction(tensor, axis=1).
    depth_reduction: callable invoked as depth_reduction(tensor, axis=1).

  Returns:
    Tensor (batch_size, 1 + k*beam_depth), k being the number of outer
    iterations: <sos> followed by the decoded tokens. The result is not cut
    to max_len.
  """
  batch_size = bx1.size(0)  # input batch_size
  bx2 = torch.full((batch_size, 1), sos, dtype=bx1.dtype)  # batch with <sos>
  # stop when len decoder output > max_len or all decoder outputs have <eos> token
  while bx2.size(1) - 1 < max_len and not torch.all(torch.any(bx2 == eos, axis=1)):
    beam_scores = torch.empty((batch_size, 0))  # scores for each beam
    for i in range(beam_depth):
      new_batch_size = batch_size*beam_width**(i+1)
      by2p = torch_predict(bx1, bx2)  # predict
      if i > 0:  # beam reduction
        prev_prev_batch_size = batch_size*beam_width**(i-1)
        by2p = by2p.view(prev_prev_batch_size, beam_width, by2p.size(1), -1)  # split predictions into batches
        by2p = beam_reduction(by2p, axis=1) # cumulative beams predictions
        # NOTE(review): `repeat` tiles the reduced block as a whole, while the
        # beams themselves are expanded with `repeat_interleave`; confirm the
        # resulting row order is the intended one.
        by2p = by2p.repeat(beam_width, 1, 1)  # return to prev batch size
      next_by2p, next_bx2 = torch.topk(by2p[:,-1,:], beam_width)  # beams to top k last predictions
      next_bx2 = next_bx2.view(new_batch_size, 1)  # new beams
      next_by2p = next_by2p.view(new_batch_size, 1)  # new scores
      beam_scores = torch.repeat_interleave(beam_scores, beam_width, 0)  # increase batch for new scores
      bx2 = torch.repeat_interleave(bx2, beam_width, 0)  # increase batch for new beams
      bx2 = torch.cat((bx2, next_bx2), 1)  # add beams
      beam_scores = torch.cat((beam_scores, next_by2p), 1)  # add beams scores
    beam_scores = depth_reduction(beam_scores, axis=1) # cumulative beams scores
    beam_scores = beam_scores.view(batch_size, -1)  # split scores into batches
    best_beams = torch.argmax(beam_scores, axis=1, keepdim=True)  # best beams
    bx2 = bx2.view(batch_size, beam_width**beam_depth, -1)  # split beams into batches
    # Return to the input batch size: expand each row's winning beam index
    # across the sequence axis so that gather picks the whole beam.
    best_beams = best_beams.unsqueeze(2).expand(best_beams.size(0), best_beams.size(1), bx2.size(2))
    bx2 = torch.gather(bx2, 1, best_beams)
    bx2 = bx2.view(batch_size, -1)
  return bx2

# Demo: per-row beam search with beam reduction on a random 3x3 batch
# (values 0..3), for every sum/mean pairing of the two reductions.
# bx1 = torch.full((5, 1), 0, dtype=torch.long)
bx1 = torch.randint(0, 4, (3, 3))
print(bx1)
bx2 = torch_beam_reduction(bx1, depth_reduction=torch.sum, beam_reduction=torch.sum)
print(bx2)
bx2 = torch_beam_reduction(bx1, depth_reduction=torch.mean, beam_reduction=torch.sum)
print(bx2)
bx2 = torch_beam_reduction(bx1, depth_reduction=torch.sum, beam_reduction=torch.mean)
print(bx2)
bx2 = torch_beam_reduction(bx1, depth_reduction=torch.mean, beam_reduction=torch.mean)
print(bx2)

tensor([[3, 3, 1],
        [1, 1, 0],
        [1, 1, 1]])
tensor([[2, 0, 3, 4, 2, 1, 0],
        [2, 1, 3, 0, 1, 3, 0],
        [2, 1, 1, 1, 2, 3, 1]])
tensor([[2, 0, 3, 4, 2, 1, 0],
        [2, 1, 3, 0, 1, 3, 0],
        [2, 1, 1, 1, 2, 3, 1]])
tensor([[2, 0, 3, 0, 3, 1, 0],
        [2, 1, 3, 0, 1, 2, 2],
        [2, 1, 1, 1, 2, 3, 1]])
tensor([[2, 0, 3, 0, 3, 1, 0],
        [2, 1, 3, 0, 1, 2, 2],
        [2, 1, 1, 1, 2, 3, 1]])


In [22]:
def torch_beam_viterbi(bx1, sos=2, eos=3, max_len=4,
                       beam_width=2, beam_depth=2):
  """Beam search with a backward, Viterbi-like selection of the tokens.

  Forward pass: every row is branched `beam_depth` times into its
  `beam_width` top-scoring next tokens, keeping the per-step scores of all
  beam_width**beam_depth beams. Backward pass: starting from the last step,
  the token of the best-scoring beam is taken; beams whose token at that
  step differs from the chosen one are masked out (score -inf) at the
  preceding step.

  Relies on the module-level `torch_predict`.

  Args:
    bx1: 2-D encoder batch input; size(0) is the batch size and its dtype is
      used for the output.
    sos: id of the start token that seeds every row.
    eos: id of the end token; the search stops once every row contains it.
    max_len: the search stops once at least max_len tokens follow <sos>.
    beam_width: number of candidate tokens kept per branching step.
    beam_depth: number of branching steps per outer iteration.

  Returns:
    Tensor (batch_size, 1 + k*beam_depth), k being the number of outer
    iterations: <sos> followed by the decoded tokens. The result is not cut
    to max_len.
  """
  batch_size = bx1.size(0)  # input batch_size
  bx2 = torch.full((batch_size, 1), sos, dtype=bx1.dtype)  # batch with <sos>
  # stop when len decoder output > max_len or all decoder outputs have <eos> token
  while bx2.size(1) - 1 < max_len and not torch.all(torch.any(bx2 == eos, axis=1)):
    # forward pass
    fwd_scores = torch.empty((batch_size, 0))  # scores for each beam
    fwd_bx2 = bx2.clone()
    for i in range(beam_depth):
      new_batch_size = batch_size*beam_width**(i+1)
      by2p = torch_predict(bx1, fwd_bx2)  # predict
      next_by2p, next_bx2 = torch.topk(by2p[:,-1,:], beam_width)  # beams to top k last predictions
      next_bx2 = next_bx2.view(new_batch_size, 1)  # new beams
      next_by2p = next_by2p.view(new_batch_size, 1)  # new scores
      fwd_scores = torch.repeat_interleave(fwd_scores, beam_width, 0)  # increase batch for new scores
      fwd_bx2 = torch.repeat_interleave(fwd_bx2, beam_width, 0)  # increase batch for new beams
      fwd_bx2 = torch.cat((fwd_bx2, next_bx2), 1)  # add beams
      fwd_scores = torch.cat((fwd_scores, next_by2p), 1)  # add beams scores
    fwd_scores = fwd_scores.view(batch_size, beam_width**beam_depth, -1)  # split scores into batches
    fwd_bx2 = fwd_bx2.view(batch_size, beam_width**beam_depth, -1)  # split beams into batches
    # backward pass
    bkw_bx2 = torch.empty((batch_size, 0), dtype=bx1.dtype)  # tokens chosen so far, earliest first
    mask = torch.full((fwd_bx2.size(0), fwd_bx2.size(1), 1), False, dtype=torch.bool)
    for i in range(fwd_scores.size(2)):
      # i-th step counted from the end, kept as a trailing axis of size 1
      cur_values = fwd_bx2[:,:,-i-1:-i if i > 0 else None]
      cur_scores = fwd_scores[:,:,-i-1:-i if i > 0 else None]
      cur_scores[mask] = float('-inf')  # mask from prev step
      _, best_beams = torch.max(cur_scores, axis=1, keepdim=True)
      best_values = torch.gather(cur_values, 1, best_beams).view(batch_size, -1)
      exp_best_values = best_values[:,None,:].expand(-1, cur_values.size(1), -1)
      mask = cur_values != exp_best_values  # beams that disagree with the chosen token
      bkw_bx2 = torch.cat((best_values, bkw_bx2), 1)  # prepend: we walk backwards
    bx2 = torch.cat((bx2, bkw_bx2), 1)  # <sos> + back_bx2
  return bx2

# Demo: Viterbi-like beam search on a random 3x3 batch (values 0..3).
# NOTE: the first assignment is overwritten by the next line before use.
bx1 = torch.full((5, 1), 0, dtype=torch.long)
bx1 = torch.randint(0, 4, (3, 3))
print(bx1)
by2 = torch_beam_viterbi(bx1)
print(by2)

tensor([[0, 3, 0],
        [3, 0, 0],
        [0, 0, 1]])
tensor([[2, 0, 3, 4, 0],
        [2, 1, 4, 0, 0],
        [2, 2, 3, 2, 1]])


In [23]:
def torch_beam_bkw(bx1, sos=2, eos=3, max_len=4,
                   beam_width=2, beam_depth=2,
                   bkw_beam_width=2,
                   depth_reduction=torch.sum):
  """Beam search with a forward pass and a backward beam pass.

  Forward pass: every row is branched `beam_depth` times into its
  `beam_width` top-scoring next tokens, keeping the per-step scores of all
  beams. Backward pass: starting from the last step, the `bkw_beam_width`
  best-scoring candidates are kept at every step (multiplying the number of
  backward beams each time), with candidates that disagree with the
  previously chosen token masked out. The backward beam with the best
  `depth_reduction` score is appended to each row.

  Relies on the module-level `torch_predict`.

  Args:
    bx1: 2-D encoder batch input; size(0) is the batch size and its dtype is
      used for the output.
    sos: id of the start token that seeds every row.
    eos: id of the end token; the search stops once every row contains it.
    max_len: the search stops once at least max_len tokens follow <sos>.
    beam_width: number of candidate tokens kept per forward step.
    beam_depth: number of forward steps per outer iteration.
    bkw_beam_width: number of candidates kept per backward step.
    depth_reduction: callable invoked as depth_reduction(tensor, axis=2).

  Returns:
    Tensor (batch_size, 1 + k*beam_depth), k being the number of outer
    iterations: <sos> followed by the decoded tokens. The result is not cut
    to max_len.
  """
  batch_size = bx1.size(0)  # input batch_size
  bx2 = torch.full((batch_size, 1), sos, dtype=bx1.dtype)  # batch with <sos>
  # stop when len decoder output > max_len or all decoder outputs have <eos> token
  while bx2.size(1) - 1 < max_len and not torch.all(torch.any(bx2 == eos, axis=1)):
    # forward pass
    fwd_scores = torch.empty((batch_size, 0))  # scores for each beam
    fwd_bx2 = bx2.clone()
    for i in range(beam_depth):
      new_batch_size = batch_size*beam_width**(i+1)
      by2p = torch_predict(bx1, fwd_bx2)  # predict
      next_by2p, next_bx2 = torch.topk(by2p[:,-1,:], beam_width)  # beams to top k last predictions
      next_bx2 = next_bx2.view(new_batch_size, 1)  # new beams
      next_by2p = next_by2p.view(new_batch_size, 1)  # new scores
      fwd_scores = torch.repeat_interleave(fwd_scores, beam_width, 0)  # increase batch for new scores
      fwd_bx2 = torch.repeat_interleave(fwd_bx2, beam_width, 0)  # increase batch for new beams
      fwd_bx2 = torch.cat((fwd_bx2, next_bx2), 1)  # add beams
      fwd_scores = torch.cat((fwd_scores, next_by2p), 1)  # add beams scores
    fwd_scores = fwd_scores.view(batch_size, beam_width**beam_depth, -1)  # split scores into batches
    fwd_bx2 = fwd_bx2.view(batch_size, beam_width**beam_depth, -1)  # split beams into batches
    # backward pass
    bkw_bx2 = torch.empty((batch_size, beam_width**beam_depth, 0), dtype=bx1.dtype)  # tokens of each backward beam
    bkw_scores = torch.empty((batch_size, beam_width**beam_depth, 0))  # scores for each beam
    mask = torch.full((fwd_bx2.size(0), fwd_bx2.size(1), 1), False, dtype=torch.bool)
    n_beams_per_batch = beam_width**beam_depth
    for i in range(fwd_scores.size(2)):
      old_n_beams_per_batch = n_beams_per_batch*bkw_beam_width**(i)
      # i-th step counted from the end, expanded to the current number of backward beams
      cur_values = fwd_bx2[:,:,-i-1:-i if i > 0 else None]
      cur_values = torch.repeat_interleave(cur_values, bkw_beam_width**i, 1)
      cur_scores = fwd_scores[:,:,-i-1:-i if i > 0 else None]
      cur_scores = torch.repeat_interleave(cur_scores, bkw_beam_width**i, 1)
      cur_scores[mask] = float('-inf')  # mask from prev step
      best_scores, best_beams = torch.topk(cur_scores, bkw_beam_width, 1)  # best beams
      best_values = torch.gather(cur_values, 1, best_beams).view(batch_size, -1)
      best_values = best_values.view(best_values.size(0), bkw_beam_width, -1)
      best_values = torch.repeat_interleave(best_values, old_n_beams_per_batch, 1)
      cur_values = cur_values.repeat(1, bkw_beam_width, 1)
      mask = cur_values != best_values  # candidates that disagree with the chosen token
      bkw_bx2 = bkw_bx2.repeat(1, bkw_beam_width, 1)  # increase batch for new beams
      bkw_bx2 = torch.cat((best_values, bkw_bx2), 2)  # prepend: we walk backwards
      best_scores = torch.repeat_interleave(best_scores, old_n_beams_per_batch, 1)
      bkw_scores = bkw_scores.repeat(1, bkw_beam_width, 1)  # increase batch for new beams
      bkw_scores = torch.cat((best_scores, bkw_scores), 2)
    beam_scores = depth_reduction(bkw_scores, axis=2) # cumulative beams scores
    best_beams = torch.argmax(beam_scores, axis=1, keepdim=True)  # best beams
    # expand the winning beam index across the sequence axis for gather
    best_beams = best_beams.unsqueeze(2).expand(best_beams.size(0), best_beams.size(1), bkw_bx2.size(2))
    bkw_bx2 = torch.gather(bkw_bx2, 1, best_beams) # best beam of all batches
    bkw_bx2 = bkw_bx2.view(batch_size, -1)
    bx2 = torch.cat((bx2, bkw_bx2), 1)  # <sos> + bkw_bx2
  return bx2

# Demo: forward/backward beam search on a random 3x3 batch (values 0..3)
# with both depth reductions.
# NOTE: the first assignment is overwritten by the next line before use.
bx1 = torch.full((5, 1), 0, dtype=torch.long)
bx1 = torch.randint(0, 4, (3, 3))
print(bx1)
by2 = torch_beam_bkw(bx1, depth_reduction=torch.sum)
print(by2)
by2 = torch_beam_bkw(bx1, depth_reduction=torch.mean)
print(by2)

tensor([[0, 1, 2],
        [1, 1, 3],
        [1, 0, 0]])
tensor([[2, 0, 3, 0, 0],
        [2, 0, 4, 0, 2],
        [2, 1, 3, 1, 1]])
tensor([[2, 0, 3, 0, 0],
        [2, 0, 4, 0, 2],
        [2, 1, 3, 1, 1]])


In [24]:
def torch_beam_bkw_many_to_one(bx1, sos=2, eos=3, max_len=10,
                               beam_width=3, beam_depth=2,
                               bkw_beam_width=2,
                               depth_reduction=torch.sum,
                               batch_reduction=torch.sum):
  """Forward/backward beam search producing one sequence for the whole batch.

  Forward pass: at every step the predictions of all rows are combined with
  `batch_reduction` and the same top `beam_width` tokens are offered to
  every row. Backward pass: starting from the last step, the
  `bkw_beam_width` best-scoring candidates over all rows are kept at every
  step, with candidates that disagree with the chosen tokens masked out.
  The backward beam with the best `depth_reduction` score is appended to
  every row.

  Relies on the module-level `torch_predict`. Requires beam_depth >= 1
  whenever the search loop runs, because the backward pass reuses the row
  count computed in the last forward step.

  Args:
    bx1: 2-D encoder batch input; size(0) is the batch size and its dtype is
      used for the output.
    sos: id of the start token that seeds every row.
    eos: id of the end token; the search stops once every row contains it.
    max_len: the search stops once at least max_len tokens follow <sos>.
    beam_width: number of candidate tokens kept per forward step.
    beam_depth: number of forward steps per outer iteration.
    bkw_beam_width: number of candidates kept per backward step.
    depth_reduction: callable invoked as depth_reduction(tensor, axis=1).
    batch_reduction: callable invoked as batch_reduction(tensor, axis=0).

  Returns:
    Tensor with batch_size identical rows: <sos> followed by the decoded
    tokens. The result is not cut to max_len.
  """
  batch_size = bx1.size(0)  # input batch_size
  bx2 = torch.full((batch_size, 1), sos, dtype=bx1.dtype)  # batch with <sos>
  # stop when len decoder output > max_len or all decoder outputs have <eos> token
  while bx2.size(1) - 1 < max_len and not torch.all(torch.any(bx2 == eos, axis=1)):
    fwd_scores = torch.empty((batch_size, 0))  # scores for each beam
    fwd_bx2 = bx2.clone()
    for i in range(beam_depth):
      prev_batch_size = batch_size*beam_width**(i)
      new_batch_size = batch_size*beam_width**(i+1)
      by2p = torch_predict(bx1, fwd_bx2)  # predict
      # fwd batch reduction
      by2p = batch_reduction(by2p, axis=0) # cumulative batch predictions
      next_by2p, next_bx2 = torch.topk(by2p[-1,:], beam_width)  # beams to top k last predictions
      next_bx2 = next_bx2.repeat(prev_batch_size, 1)  # return to prev batch size
      next_by2p = next_by2p.repeat(prev_batch_size, 1)  # return to prev batch size
      # fwd
      next_bx2 = next_bx2.view(new_batch_size, 1)  # new beams
      next_by2p = next_by2p.view(new_batch_size, 1)  # new scores
      fwd_scores = torch.repeat_interleave(fwd_scores, beam_width, 0)  # increase batch for new scores
      fwd_bx2 = torch.repeat_interleave(fwd_bx2, beam_width, 0)  # increase batch for new beams
      fwd_bx2 = torch.cat((fwd_bx2, next_bx2), 1)  # add beams
      fwd_scores = torch.cat((fwd_scores, next_by2p), 1)  # add beams scores
    # bkw
    # new_batch_size still holds the row count of the last forward step
    bkw_bx2 = torch.empty((new_batch_size, 0), dtype=bx1.dtype)  # tokens of each backward beam
    bkw_scores = torch.empty((new_batch_size, 0))  # scores for each beam
    mask = torch.full((new_batch_size, 1), False, dtype=torch.bool)
    fwd_batch_size = new_batch_size
    for i in range(fwd_scores.size(1)):
      prev_batch_size = fwd_batch_size*bkw_beam_width**(i)
      new_batch_size = fwd_batch_size*bkw_beam_width**(i+1)
      # i-th step counted from the end, expanded to the current number of backward beams
      cur_values = fwd_bx2[:,-i-1:-i if i > 0 else None]
      cur_values = torch.repeat_interleave(cur_values, bkw_beam_width**i, 0)
      cur_scores = fwd_scores[:,-i-1:-i if i > 0 else None]
      cur_scores = torch.repeat_interleave(cur_scores, bkw_beam_width**i, 0)
      cur_scores[mask] = float('-inf')  # mask from prev step
      best_scores, best_beams = torch.topk(cur_scores, bkw_beam_width, 0)  # best beams
      best_values = cur_values[best_beams]
      mask = cur_values != best_values  # candidates that disagree with the chosen tokens
      mask = mask.view(new_batch_size, -1)
      best_values = best_values.view(bkw_beam_width, -1)
      best_values = torch.repeat_interleave(best_values, prev_batch_size, 0)
      bkw_bx2 = bkw_bx2.repeat(bkw_beam_width, 1)  # increase batch for new beams
      bkw_bx2 = torch.cat((best_values, bkw_bx2), 1)  # prepend: we walk backwards
      best_scores = torch.repeat_interleave(best_scores, prev_batch_size, 0)
      bkw_scores = bkw_scores.repeat(bkw_beam_width, 1)  # increase batch for new beams
      bkw_scores = torch.cat((best_scores, bkw_scores), 1)
    # bkw batch reduction
    bkw_scores = depth_reduction(bkw_scores, axis=1) # cumulative beams scores
    best_beams = torch.argmax(bkw_scores, axis=0, keepdim=True)  # best beams
    bkw_bx2 = bkw_bx2[best_beams]  # best beam of all batches
    bkw_bx2 = bkw_bx2.repeat(batch_size, 1)  # return to input batch size
    bx2 = torch.cat((bx2, bkw_bx2), 1)  # <sos> + bkw_bx2
  return bx2

# Demo: shared-sequence forward/backward beam search on a random 3x3 batch
# (values 0..3) for every sum/mean pairing of the two reductions.
# NOTE: the first assignment is overwritten by the next line before use.
bx1 = torch.full((20, 1), 0, dtype=torch.long)
bx1 = torch.randint(0, 4, (3, 3))
print(bx1)
by2 = torch_beam_bkw_many_to_one(bx1, depth_reduction=torch.sum, 
                                 batch_reduction=torch.sum)
print(by2)
by2 = torch_beam_bkw_many_to_one(bx1, depth_reduction=torch.mean, 
                                 batch_reduction=torch.sum)
print(by2)
by2 = torch_beam_bkw_many_to_one(bx1, depth_reduction=torch.sum, 
                                 batch_reduction=torch.mean)
print(by2)
by2 = torch_beam_bkw_many_to_one(bx1, depth_reduction=torch.mean, 
                                 batch_reduction=torch.mean)
print(by2)

tensor([[0, 1, 2],
        [1, 1, 3],
        [1, 0, 0]])
tensor([[2, 1, 0, 1, 1, 1, 4, 3, 1],
        [2, 1, 0, 1, 1, 1, 4, 3, 1],
        [2, 1, 0, 1, 1, 1, 4, 3, 1]])
tensor([[2, 1, 0, 1, 1, 1, 4, 3, 1],
        [2, 1, 0, 1, 1, 1, 4, 3, 1],
        [2, 1, 0, 1, 1, 1, 4, 3, 1]])
tensor([[2, 1, 0, 1, 1, 1, 4, 3, 1],
        [2, 1, 0, 1, 1, 1, 4, 3, 1],
        [2, 1, 0, 1, 1, 1, 4, 3, 1]])
tensor([[2, 1, 0, 1, 1, 1, 4, 3, 1],
        [2, 1, 0, 1, 1, 1, 4, 3, 1],
        [2, 1, 0, 1, 1, 1, 4, 3, 1]])


In [0]:
# import torch
# bt = bx1[:,1:]  # without <sos>
# print(bt)
# print(bt[bt != 0])
# bp = by2[:,:-1,:]  # without last prediction
# print(bp)
# print(bt != 0)
# print(bp[bt != 0])

# bp, bt = torch.tensor(bp), torch.tensor(bt)
# loss = torch.nn.NLLLoss()
# print(loss(bp[bt != 0], bt[bt != 0]))

# loss = torch.nn.NLLLoss(ignore_index=0)
# print(loss(torch.flatten(bp, 0, 1), torch.flatten(bt, 0)))

# loss = 0
# cnt = 0
# for i, t in enumerate(bt):
#   for j, l in enumerate(t):
#     print(l)
#     print(bp[i,j,l])
#     loss += bp[i,j,l].item()
#     cnt += 1
# print(loss, loss/cnt)