<a href="https://colab.research.google.com/github/DmitriyValetov/nlp_course_project/blob/master/inferences.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import numpy as np
from numpy.random import choice, randint

np.random.seed(0)
emb_size = 5  # 0, 1, 2, 3, 4
batch_size = 2
x1_lens = randint(1, 4, batch_size)
print(x1_lens)
max_len = np.max(x1_lens)
print(max_len)
x1_base = randint(1, 5, (batch_size, 4))
bx1 = np.array([[2] + list(choice([1, 4], x)) + [3] + [0] * (max_len - x) for x in x1_lens])
print(bx1)

def predict(bx1, bx2):
  np.random.seed(0)
  batch_size, x2_len = bx2.shape
  a = np.random.pareto(2, (x2_len, batch_size, emb_size))  # activations
  a = a.swapaxes(0, 1)  # right behaviour
  # print(a)
  s = np.exp(a)/np.sum(np.exp(a), axis=2, keepdims=True)  # softmax
  # print(s)
  # print(np.sum(s, axis=2))
  by2p = np.log(s)  # log softmax
  return by2p

[1 2]
2
[[2 1 3 0]
 [2 1 4 3]]


In [0]:
def greedy(bx1, sos=2, eos=3, max_len=10):
  bx2 = np.full((batch_size, 1), sos)  # batch with <sos> token
  # stop when predictions len > max_len or all have <eos> token
  while 1 + bx2.shape[1] < max_len and not np.all(np.any(bx2 == eos, axis=1)):
    # print(bx2)
    by2p = predict(bx1, bx2)
    lp = by2p[:,-1,:]  # last prediction
    # print(by2)
    # print(lp)
    next_bx2 = np.argmax(lp, axis=1).reshape((lp.shape[0], 1))
    # print(next_bx2)
    bx2 = np.concatenate((bx2, next_bx2), axis=1)
    # print(bx2)
    # print(bx2 == eos)
    # print(np.all(np.any(bx2 == eos, axis=1)))
  return by2p

def greedy_many_to_one(bx1, sos=2, eos=3, max_len=10, reduction=np.sum):
  bx2 = np.full((bx1.shape[0], 1), sos)  # batch with <sos> token
  # stop when predictions len > max_len or all have <eos> token
  while 1 + bx2.shape[1] < max_len and not np.all(np.any(bx2 == eos, axis=1)):
    print(bx2)
    by2p = predict(bx1, bx2)
    lp = by2p[:,-1,:]  # last prediction
    print(lp)
    lp_red = reduction(lp, axis=0)
    print(lp_red)
    # print(by2)
    # print(lp)
    # next_bx2 = np.argmax(lp, axis=1).reshape((lp.shape[0], 1))
    next_x2 = np.argmax(lp_red, axis=0)
    # print(next_x2)
    next_bx2 = np.full((lp.shape[0], 1), next_x2)
    # print(next_bx2)
    bx2 = np.concatenate((bx2, next_bx2), axis=1)
    # print(bx2)
    # print(bx2 == eos)
    # print(np.all(np.any(bx2 == eos, axis=1)))
  return by2p

def forced(bx1, bx2):
  by2p = predict(bx1, bx2)
  return by2p

def many_to_one(by2p, reduction=np.sum):
  y2p = np.sum(by2p[:,:-1,:], axis=0)
  y2 = np.argmax(y2p, axis=1)
  return y2

# by2p = greedy_many_to_one(bx1, reduction=np.sum)
# print(by2p)
# by2p = greedy_many_to_one(bx1, reduction=np.max)
# print(by2p)
# by2p = greedy_many_to_one(bx1, reduction=np.mean)
# print(by2p)
# y2 = np.argmax(np.sum(by2p[:,:-1,:], axis=0), axis=1)
# print(y2)
# print(many_to_one_sum(by2p))
# by2p = greedy(bx1)
# print(by2p)
# by2 = np.argmax(by2p[:,:-1,:], axis=2)
# print(by2)
# print(many_to_one(by2p, np.sum))
# print(many_to_one(by2p, np.max))
# print(many_to_one(by2p, np.mean))
# forced
# bx2 = np.concatenate((np.full((batch_size, 1), 2), by2), axis=1)
# print(bx2)
# by2p = forced(bx1, bx2)
# # print(bpy2)
# by2 = np.argmax(by2p[:,:-1,:], axis=2)
# print(by2)

In [38]:
# TODO do many_to_one search with batch_reduction
# TODO try Viterbi algorithm for best beam searching (like beam search but in reversed direction)
# TODO maybe try spectral beams? (With multiple reductions + Viterbi)

def beam(bx1, sos=2, eos=3, max_len=6,
         beam_width=2, beam_depth=2, depth_reduction=np.sum):
  batch_size = bx1.shape[0]
  bx2 = np.full((batch_size, 1), sos)  # base batch with <sos> token
  # stop when predictions len > max_len or all have <eos> token
  cur_len = bx2.shape[1] - 1  # without <sos>
  while cur_len < max_len and not np.all(np.any(bx2 == eos, axis=1)):
    bx2t = bx2.copy()  # temporal batch bx2
    for i in range(beam_depth):
      by2p = predict(bx1, bx2t)  # all predictions
      lp = by2p[:,-1,:]  # last prediction
      next_bx2 = np.argsort(-lp, axis=1)[:,:beam_width]  # top "beam_width" last predictions
      bx2t = np.repeat(bx2t, beam_width, axis=0)  # multiply batch_size by beam_width
      next_bx2 = next_bx2.reshape(batch_size*beam_width**(i+1), 1)  # to new batch_size
      bx2t = np.concatenate((bx2t, next_bx2), axis=1)  # update batch
    # Prediction with temporal batch
    by2p = predict(bx1, bx2t)  # all predictions by last temporal batch
    by2p = by2p[:,cur_len:-1]  # [base part:-last]  # remove old predictions
    bx2t = bx2t[:,cur_len+1:]  # [base part + <sos>:]  # remove old labels
    # Best beams searching
    by2p = np.take_along_axis(by2p, bx2t[:,np.newaxis], axis=2)  # beams transitions predictions
    by2p = by2p.reshape(batch_size, beam_width**beam_depth, -1)  # reshape to base batch shape (like)
    by2p_red = depth_reduction(by2p, axis=2)  # beams reduction, i.e. giving them scores
    best_beams = np.argmax(by2p_red, axis=1)  # get indices of the best beams
    # Base batch updating
    bx2t = bx2t.reshape(batch_size, beam_width**beam_depth, -1)  # reshape to base batch shape (like)
    bx2t = np.take_along_axis(bx2t, best_beams[:, np.newaxis, np.newaxis], axis=1)  # get best beams
    bx2t = bx2t.reshape(batch_size, -1)  # to new labels to base batch_size
    bx2 = np.concatenate((bx2, bx2t), axis=1)  # update base batch
    cur_len = bx2.shape[1] - 1  # without <sos>
  if bx2.shape[1] - 1 > max_len:  # cut to max_len
    bx2 = bx2[:max_len + 1]
  return bx2

bx2 = beam(bx1, depth_reduction=np.sum)
print(bx2)
# bx2 = beam(bx1, depth_reduction=np.max)
# print(bx2)
# bx2 = beam(bx1, depth_reduction=np.mean)
# print(bx2)

[[2 2 2 1 3]
 [2 3 3 0 1]]


In [126]:
import torch
bt = bx1[:,1:]  # without <sos>
print(bt)
print(bt[bt != 0])
bp = by2[:,:-1,:]  # without last prediction
print(bp)
print(bt != 0)
print(bp[bt != 0])

bp, bt = torch.tensor(bp), torch.tensor(bt)
loss = torch.nn.NLLLoss()
print(loss(bp[bt != 0], bt[bt != 0]))

loss = torch.nn.NLLLoss(ignore_index=0)
print(loss(torch.flatten(bp, 0, 1), torch.flatten(bt, 0)))

loss = 0
cnt = 0
for i, t in enumerate(bt):
  for j, l in enumerate(t):
    print(l)
    print(bp[i,j,l])
    loss += bp[i,j,l].item()
    cnt += 1
print(loss, loss/cnt)

[[5 3 0]
 [1 5 3]]
[5 3 1 5 3]
[[[-1.89006574 -1.50502211 -1.79218746 -1.89650799 -2.06159531
   -1.69833605]
  [-4.11028414 -2.40401057 -0.19777221 -4.17018005 -3.25252363
   -3.98678273]
  [-2.63134807 -0.48678304 -3.11534831 -3.10624282 -3.14261297
   -1.70861462]]

 [[-3.42675656 -2.82523264 -3.41128033 -0.24047856 -3.04555381
   -3.18465323]
  [-2.15389649 -1.21536353 -1.96396499 -1.79775144 -2.31043832
   -1.70277709]
  [-2.90484114 -2.89473326 -0.29414425 -2.73762603 -3.26092274
   -3.17766349]]]
[[ True  True False]
 [ True  True  True]]
[[-1.89006574 -1.50502211 -1.79218746 -1.89650799 -2.06159531 -1.69833605]
 [-4.11028414 -2.40401057 -0.19777221 -4.17018005 -3.25252363 -3.98678273]
 [-3.42675656 -2.82523264 -3.41128033 -0.24047856 -3.04555381 -3.18465323]
 [-2.15389649 -1.21536353 -1.96396499 -1.79775144 -2.31043832 -1.70277709]
 [-2.90484114 -2.89473326 -0.29414425 -2.73762603 -3.26092274 -3.17766349]]
tensor(2.6268, dtype=torch.float64)
tensor(2.6268, dtype=torch.float64)
