In [1]:
import torch
from torch import nn

  from .autonotebook import tqdm as notebook_tqdm


In [59]:
torch.set_printoptions(sci_mode=False)

t=0
- start with SOS
- decoder_out: [N, |Vocab|], with candidates: |Vocab|
- discard all candidates 
- select top-k: [N, k]

t=1
- start with top k hypothesis: [N, k]
- flatten to decoder_input: [N*k] (ensure that hidden and cell is replicated)

In [17]:
N = 4
k=3
H=2
vocab_size= 10
decoder_input = torch.randint(low=1, high=10, size=[N, k])
decoder_hidden = torch.rand([1, N, H])
decoder_cell = torch.rand([1, N, H])

In [18]:
# flatten decoder input
decoder_input_flat = decoder_input.flatten() # [N*k]

# replicate hidden and cell for all k per N and flatten
 # [1, N, 1, H] -> [1, N, k, H] -> [1, N*k, H]
decoder_hidden_flat = decoder_hidden[..., None, :].expand(-1, -1, k, -1).flatten(1, 2)
decoder_cell_flat = decoder_cell[..., None, :].expand(-1, -1, k, -1).flatten(1,2)

In [50]:
decoder_output = torch.rand([N*k, vocab_size])
decoder_output=decoder_output.unflatten(0, [N, k]) # [N, k, |Vocab|]

In [48]:
neg_constraints = [0, 5, 9] # extra ingredients

In [179]:
discard_mask = torch.ones_like(decoder_output)
discard_mask[:, :, neg_constraints] = 0
discard_mask

tensor([[[0., 1., 1., 1., 1., 0., 1., 1., 1., 0.],
         [0., 1., 1., 1., 1., 0., 1., 1., 1., 0.],
         [0., 1., 1., 1., 1., 0., 1., 1., 1., 0.]],

        [[0., 1., 1., 1., 1., 0., 1., 1., 1., 0.],
         [0., 1., 1., 1., 1., 0., 1., 1., 1., 0.],
         [0., 1., 1., 1., 1., 0., 1., 1., 1., 0.]],

        [[0., 1., 1., 1., 1., 0., 1., 1., 1., 0.],
         [0., 1., 1., 1., 1., 0., 1., 1., 1., 0.],
         [0., 1., 1., 1., 1., 0., 1., 1., 1., 0.]],

        [[0., 1., 1., 1., 1., 0., 1., 1., 1., 0.],
         [0., 1., 1., 1., 1., 0., 1., 1., 1., 0.],
         [0., 1., 1., 1., 1., 0., 1., 1., 1., 0.]]])

In [68]:
alpha = 10 # preserve top 10 likelihood (NOTE: in reality probably should be much higher)

In [69]:
# get top alpha likelihoods per batch
topalpha_likelihood = decoder_output.flatten(1, 2).topk(k=alpha, dim=1).values # [N, alpha]
topalpha_likelihood

tensor([[0.8981, 0.8748, 0.8729, 0.8588, 0.8464, 0.8034, 0.7439, 0.7008, 0.6946,
         0.6321],
        [0.9818, 0.9199, 0.9032, 0.7661, 0.7591, 0.7459, 0.6528, 0.6097, 0.6004,
         0.5518],
        [0.9361, 0.9312, 0.9002, 0.8954, 0.8801, 0.8235, 0.8165, 0.7799, 0.7223,
         0.7157],
        [0.9277, 0.9043, 0.8776, 0.8353, 0.7703, 0.7264, 0.7006, 0.6453, 0.6336,
         0.6155]])

In [70]:
threshold = topalpha_likelihood.min(-1).values # minimum values to be included within top alpha
threshold # [N]

tensor([0.6321, 0.5518, 0.7157, 0.6155])

In [71]:
decoder_output

tensor([[[    0.0041,     0.2649,     0.0137,     0.6946,     0.8034,
              0.5727,     0.0461,     0.7008,     0.6321,     0.0922],
         [    0.4026,     0.8464,     0.6045,     0.2223,     0.7439,
              0.8588,     0.4581,     0.0875,     0.1411,     0.0917],
         [    0.6170,     0.0664,     0.3458,     0.8981,     0.8748,
              0.1114,     0.5228,     0.3961,     0.8729,     0.4127]],

        [[    0.7459,     0.5518,     0.0910,     0.0005,     0.5403,
              0.2545,     0.6004,     0.6097,     0.3484,     0.5266],
         [    0.6528,     0.9818,     0.4146,     0.7591,     0.2558,
              0.3842,     0.4622,     0.1786,     0.3990,     0.2000],
         [    0.9199,     0.5465,     0.2253,     0.4462,     0.2527,
              0.0406,     0.7661,     0.9032,     0.0472,     0.5398]],

        [[    0.6322,     0.1846,     0.4445,     0.8801,     0.6527,
              0.2099,     0.8235,     0.7223,     0.7021,     0.8165],
         

In [72]:
ltalphathreshold = decoder_output < threshold[:, None, None]

In [73]:
ltalphathreshold.shape

torch.Size([4, 3, 10])

In [98]:
satisfied_clauses_so_far = torch.randint(1, 3, size=[N,k])

In [99]:
satisfied_clauses_so_far

tensor([[1, 1, 1],
        [2, 1, 1],
        [2, 1, 2],
        [1, 2, 2]])

In [100]:
pos_constraints = [1, 2, 7] # input ingredients (indexes)

In [108]:
sat_clauses_now = satisfied_clauses_so_far.unsqueeze(-1).expand(-1, -1, vocab_size).clone()
# ! need to lookbehind for multi-word constraints
# ! check the last word in each constraint, if it matches this word, check the last word generated by hypothesis k (keep going back until full match or mismatch)
sat_clauses_now[:, :, pos_constraints] += 1 # add current satisfaction
sat_clauses_now

tensor([[[1, 2, 2, 1, 1, 1, 1, 2, 1, 1],
         [1, 2, 2, 1, 1, 1, 1, 2, 1, 1],
         [1, 2, 2, 1, 1, 1, 1, 2, 1, 1]],

        [[2, 3, 3, 2, 2, 2, 2, 3, 2, 2],
         [1, 2, 2, 1, 1, 1, 1, 2, 1, 1],
         [1, 2, 2, 1, 1, 1, 1, 2, 1, 1]],

        [[2, 3, 3, 2, 2, 2, 2, 3, 2, 2],
         [1, 2, 2, 1, 1, 1, 1, 2, 1, 1],
         [2, 3, 3, 2, 2, 2, 2, 3, 2, 2]],

        [[1, 2, 2, 1, 1, 1, 1, 2, 1, 1],
         [2, 3, 3, 2, 2, 2, 2, 3, 2, 2],
         [2, 3, 3, 2, 2, 2, 2, 3, 2, 2]]])

In [None]:
beta = 2 # preserve those that satisfy top 2 no, of constraints

In [147]:
y = sat_clauses_now.flatten(1, 2).sort(dim=-1).values
# subtracting, so duplicate values will become 0
# e.g., first row: [1, 2, 3, 0, 0, 4, 0]
y[:, 1:] *= ((y[:, 1:] - y[:, :-1]) !=0).long()

In [151]:
topbetaval = y.topk(beta).values
topbetaval

tensor([[2, 1],
        [3, 2],
        [3, 2],
        [3, 2]])

In [152]:
beta_thresh = topbetaval.min(-1).values
beta_thresh # [N]

tensor([1, 2, 2, 2])

In [157]:
ltbetathresh = sat_clauses_now < beta_thresh[:, None, None] # ! USE ACCUMULATED LIKELIHOOD
ltbetathresh

tensor([[[False, False, False, False, False, False, False, False, False, False],
         [False, False, False, False, False, False, False, False, False, False],
         [False, False, False, False, False, False, False, False, False, False]],

        [[False, False, False, False, False, False, False, False, False, False],
         [ True, False, False,  True,  True,  True,  True, False,  True,  True],
         [ True, False, False,  True,  True,  True,  True, False,  True,  True]],

        [[False, False, False, False, False, False, False, False, False, False],
         [ True, False, False,  True,  True,  True,  True, False,  True,  True],
         [False, False, False, False, False, False, False, False, False, False]],

        [[ True, False, False,  True,  True,  True,  True, False,  True,  True],
         [False, False, False, False, False, False, False, False, False, False],
         [False, False, False, False, False, False, False, False, False, False]]])

In [190]:
# log probabilities
score = nn.functional.log_softmax(decoder_output, -1).clone()
score

tensor([[[-2.7284, -2.4676, -2.7188, -2.0379, -1.9291, -2.1598, -2.6864,
          -2.0317, -2.1003, -2.6403],
         [-2.3880, -1.9442, -2.1861, -2.5683, -2.0467, -1.9318, -2.3325,
          -2.7031, -2.6495, -2.6989],
         [-2.2386, -2.7891, -2.5098, -1.9574, -1.9807, -2.7441, -2.3327,
          -2.4594, -1.9827, -2.4428]],

        [[-2.0090, -2.2031, -2.6639, -2.7544, -2.2146, -2.5004, -2.1545,
          -2.1452, -2.4065, -2.2283],
         [-2.1503, -1.8213, -2.3885, -2.0439, -2.5473, -2.4189, -2.3408,
          -2.6245, -2.4040, -2.6031],
         [-1.8993, -2.2727, -2.5938, -2.3730, -2.5664, -2.7786, -2.0531,
          -1.9160, -2.7720, -2.2794]],

        [[-2.3031, -2.7507, -2.4908, -2.0552, -2.2827, -2.7254, -2.1118,
          -2.2130, -2.2332, -2.1189],
         [-2.8391, -2.6401, -2.6594, -1.9499, -2.2506, -2.3839, -2.1703,
          -2.2779, -1.9858, -2.2401],
         [-2.4304, -1.9500, -2.1013, -1.9859, -2.6812, -2.2713, -2.7074,
          -2.7655, -2.2815, -2.2375

In [191]:
# instead of filtering out, do soft prune, ie. heavily penalize the score 
# (because if we filter, it is possible to get <k per batch so can encounter errors)
irreversible_satisfcation_penalty = 10
low_likelihood_penalty = 2
low_satisfied_clauses_penalty = 5
score[discard_mask.bool()] -= irreversible_satisfcation_penalty
score[ltalphathreshold] -= low_likelihood_penalty
score[ltbetathresh] -= low_satisfied_clauses_penalty

In [192]:
score

tensor([[[ -4.7284, -14.4676, -14.7188, -12.0379, -11.9291,  -4.1598, -14.6864,
          -12.0317, -12.1003,  -4.6403],
         [ -4.3880, -11.9442, -14.1861, -14.5683, -12.0467,  -1.9318, -14.3325,
          -14.7031, -14.6495,  -4.6989],
         [ -4.2386, -14.7891, -14.5098, -11.9574, -11.9807,  -4.7441, -14.3327,
          -14.4594, -11.9827,  -4.4428]],

        [[ -2.0090, -12.2031, -14.6639, -14.7544, -14.2146,  -4.5004, -12.1545,
          -12.1452, -14.4065,  -4.2283],
         [ -7.1503, -11.8213, -14.3885, -17.0439, -19.5473,  -9.4189, -19.3408,
          -14.6245, -19.4040,  -9.6031],
         [ -6.8993, -14.2727, -14.5938, -19.3730, -19.5664,  -9.7786, -17.0531,
          -11.9160, -19.7720,  -9.2794]],

        [[ -4.3031, -14.7507, -14.4908, -12.0552, -14.2827,  -4.7254, -12.1118,
          -12.2130, -14.2332,  -2.1189],
         [ -9.8391, -14.6401, -14.6594, -16.9499, -19.2506,  -9.3839, -17.1703,
          -14.2779, -16.9858,  -9.2401],
         [ -4.4304, -11.9500

## Grouping

<!-- for now do grouping by number of clauses -->

In [460]:
score

tensor([[[ -4.7284, -14.4676, -14.7188, -12.0379, -11.9291,  -4.1598, -14.6864,
          -12.0317, -12.1003,  -4.6403],
         [ -4.3880, -11.9442, -14.1861, -14.5683, -12.0467,  -1.9318, -14.3325,
          -14.7031, -14.6495,  -4.6989],
         [ -4.2386, -14.7891, -14.5098, -11.9574, -11.9807,  -4.7441, -14.3327,
          -14.4594, -11.9827,  -4.4428]],

        [[ -2.0090, -12.2031, -14.6639, -14.7544, -14.2146,  -4.5004, -12.1545,
          -12.1452, -14.4065,  -4.2283],
         [ -7.1503, -11.8213, -14.3885, -17.0439, -19.5473,  -9.4189, -19.3408,
          -14.6245, -19.4040,  -9.6031],
         [ -6.8993, -14.2727, -14.5938, -19.3730, -19.5664,  -9.7786, -17.0531,
          -11.9160, -19.7720,  -9.2794]],

        [[ -4.3031, -14.7507, -14.4908, -12.0552, -14.2827,  -4.7254, -12.1118,
          -12.2130, -14.2332,  -2.1189],
         [ -9.8391, -14.6401, -14.6594, -16.9499, -19.2506,  -9.3839, -17.1703,
          -14.2779, -16.9858,  -9.2401],
         [ -4.4304, -11.9500

In [358]:
g=20
topgvals, topginds = score.flatten(-2, -1).topk(g)

In [359]:
topgvals

tensor([[ -1.9318,  -4.1598,  -4.2386,  -4.3880,  -4.4428,  -4.6403,  -4.6989,
          -4.7284,  -4.7441, -11.9291, -11.9442, -11.9574, -11.9807, -11.9827,
         -12.0317, -12.0379, -12.0467, -12.1003, -14.1861, -14.3325],
        [ -2.0090,  -4.2283,  -4.5004,  -6.8993,  -7.1503,  -9.2794,  -9.4189,
          -9.6031,  -9.7786, -11.8213, -11.9160, -12.1452, -12.1545, -12.2031,
         -14.2146, -14.2727, -14.3885, -14.4065, -14.5938, -14.6245],
        [ -2.1189,  -4.2375,  -4.2713,  -4.3031,  -4.4304,  -4.7254,  -9.2401,
          -9.3839,  -9.8391, -11.9500, -11.9859, -12.0552, -12.1013, -12.1118,
         -12.2130, -14.2332, -14.2779, -14.2815, -14.2827, -14.4908],
        [ -2.0885,  -4.2715,  -4.4114,  -4.4142,  -4.4412,  -4.6962,  -7.1762,
          -9.4129,  -9.6391, -11.8562, -11.8820, -12.0357, -12.1007, -12.1446,
         -12.1703, -12.2555, -14.2185, -14.2222, -14.2756, -14.3027]])

In [462]:
topginds

tensor([[15,  5, 20, 10, 29,  9, 19,  0, 25,  4, 11, 23, 24, 28,  7,  3, 14,  8,
         12, 16],
        [ 0,  9,  5, 20, 10, 29, 15, 19, 25, 11, 27,  7,  6,  1,  4, 21, 12,  8,
         22, 17],
        [ 9, 29, 25,  0, 20,  5, 19, 15, 10, 21, 23,  3, 22,  6,  7,  8, 17, 28,
          4,  2],
        [10, 25, 20, 15, 19, 29,  9,  5,  0, 11,  2, 28, 22, 27, 26, 24,  7, 14,
         13, 12]])

In [368]:
bs=torch.arange(4).unsqueeze(-1).expand(-1, g)
bs

tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
        [3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3]])

In [371]:
topg_sat_clauses_now = sat_clauses_now.flatten(-2, -1)[bs, topginds]
topg_sat_clauses_now

tensor([[0, 0, 0, 0, 2, 0, 1, 0, 0, 0, 0, 0, 1, 2, 1, 0, 0, 1, 0, 0],
        [0, 2, 3, 0, 1, 0, 2, 0, 3, 3, 1, 1, 3, 3, 2, 1, 0, 2, 0, 3],
        [0, 0, 0, 3, 1, 3, 2, 0, 3, 1, 2, 2, 0, 3, 2, 2, 1, 3, 0, 1],
        [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0]])

In [372]:
topg_sat_clauses_now.shape

torch.Size([4, 20])

In [463]:
word_indices = topginds % vocab_size

In [464]:
word_indices

tensor([[5, 5, 0, 0, 9, 9, 9, 0, 5, 4, 1, 3, 4, 8, 7, 3, 4, 8, 2, 6],
        [0, 9, 5, 0, 0, 9, 5, 9, 5, 1, 7, 7, 6, 1, 4, 1, 2, 8, 2, 7],
        [9, 9, 5, 0, 0, 5, 9, 5, 0, 1, 3, 3, 2, 6, 7, 8, 7, 8, 4, 2],
        [0, 5, 0, 5, 9, 9, 9, 5, 0, 1, 2, 8, 2, 7, 6, 4, 7, 4, 3, 2]])

In [533]:
groups = [[] for _ in range(N)]
group_idxs = [[] for _ in range(N)]
group_word_idxs = [[] for _ in range(N)]
group_scores = [[] for _ in range(N)]
group_idxs

[[], [], [], []]

In [534]:
for ni in range(N):
    for vi in range(g):
        group = topg_sat_clauses_now[ni,vi].item()
        if topg_sat_clauses_now[ni, vi] not in groups[ni]:
            groups[ni].append(group)
            group_idxs[ni].append([vi])
            group_word_idxs[ni].append([word_indices[ni, vi].item()])
            group_scores[ni].append([topgvals[ni, vi].item()])
        else:
            gid = groups[ni].index(group)
            group_idxs[ni][gid].append(vi)
            group_word_idxs[ni][gid].append(word_indices[ni, vi].item())
            group_scores[ni][gid].append(topgvals[ni, vi].item())

In [491]:
groups

[[0, 2, 1], [0, 2, 3, 1], [0, 3, 1, 2], [0, 1]]

In [492]:
group_idxs

[[[0, 1, 2, 3, 5, 7, 8, 9, 10, 11, 15, 16, 18, 19], [4, 13], [6, 12, 14, 17]],
 [[0, 3, 5, 7, 16, 18],
  [1, 6, 14, 17],
  [2, 8, 9, 12, 13, 19],
  [4, 10, 11, 15]],
 [[0, 1, 2, 7, 12, 18],
  [3, 5, 8, 13, 17],
  [4, 9, 16, 19],
  [6, 10, 11, 14, 15]],
 [[0, 1, 2, 3, 4, 5, 7, 8, 9, 10, 12, 14, 15, 16, 17, 18, 19], [6, 11, 13]]]

In [493]:
group_word_idxs

[[[5, 5, 0, 0, 9, 0, 5, 4, 1, 3, 3, 4, 2, 6], [9, 8], [9, 4, 7, 8]],
 [[0, 0, 9, 9, 2, 2], [9, 5, 4, 8], [5, 5, 1, 6, 1, 7], [0, 7, 7, 1]],
 [[9, 9, 5, 5, 2, 4], [0, 5, 0, 6, 8], [0, 1, 7, 2], [9, 3, 3, 7, 8]],
 [[0, 5, 0, 5, 9, 9, 5, 0, 1, 2, 2, 6, 4, 7, 4, 3, 2], [9, 8, 7]]]

In [494]:
group_scores[0][0]

[-1.9317786693572998,
 -4.159816741943359,
 -4.23856782913208,
 -4.387991905212402,
 -4.640323638916016,
 -4.728414058685303,
 -4.744137763977051,
 -11.929056167602539,
 -11.944205284118652,
 -11.957398414611816,
 -12.037866592407227,
 -12.0466947555542,
 -14.186067581176758,
 -14.332537651062012]

## Selection

In [495]:
groups

[[0, 2, 1], [0, 2, 3, 1], [0, 3, 1, 2], [0, 1]]

In [496]:
groups[0]

[0, 2, 1]

In [497]:
group_word_idxs[0][0]

[5, 5, 0, 0, 9, 0, 5, 4, 1, 3, 3, 4, 2, 6]

In [498]:
constraints = [[[5, 0], [1], [7, 9, 10], [2]] for _ in range(N)]
constraints

[[[5, 0], [1], [7, 9, 10], [2]],
 [[5, 0], [1], [7, 9, 10], [2]],
 [[5, 0], [1], [7, 9, 10], [2]],
 [[5, 0], [1], [7, 9, 10], [2]]]

In [499]:
group_scores

[[[-1.9317786693572998,
   -4.159816741943359,
   -4.23856782913208,
   -4.387991905212402,
   -4.640323638916016,
   -4.728414058685303,
   -4.744137763977051,
   -11.929056167602539,
   -11.944205284118652,
   -11.957398414611816,
   -12.037866592407227,
   -12.0466947555542,
   -14.186067581176758,
   -14.332537651062012],
  [-4.442814826965332, -11.982657432556152],
  [-4.698945999145508,
   -11.98073673248291,
   -12.031679153442383,
   -12.100347518920898]],
 [[-2.008964776992798,
   -6.899304389953613,
   -9.279380798339844,
   -9.603067398071289,
   -14.388473510742188,
   -14.59383773803711],
  [-4.228276252746582,
   -9.418909072875977,
   -14.214564323425293,
   -14.406536102294922],
  [-4.500431060791016,
   -9.778593063354492,
   -11.821250915527344,
   -12.154487609863281,
   -12.203121185302734,
   -14.624506950378418],
  [-7.150251388549805,
   -11.91598892211914,
   -12.14522933959961,
   -14.272651672363281]],
 [[-2.1188576221466064,
   -4.237539291381836,
   -4.27132

In [535]:
lam = 0.5
for ni in range(N):
    for group_num in range(len(groups[ni])):
        for wi, word_idx in enumerate(group_word_idxs[ni][group_num]):
            max_completion = 0

            for constraint in constraints[ni]:
                if word_idx != constraint[0]:
                    continue
                completion = 1/len(constraint)
                if completion > max_completion:
                    max_completion = completion
            
            group_scores[ni][group_num][wi] += lam * max_completion

In [536]:
group_scores[0][0]

[-1.6817786693572998,
 -3.9098167419433594,
 -4.23856782913208,
 -4.387991905212402,
 -4.640323638916016,
 -4.728414058685303,
 -4.494137763977051,
 -11.929056167602539,
 -11.444205284118652,
 -11.957398414611816,
 -12.037866592407227,
 -12.0466947555542,
 -13.686067581176758,
 -14.332537651062012]

In [506]:
groups[0]

[0, 2, 1]

In [537]:
order = [sorted(l, reverse=True) for l in groups]
order

[[2, 1, 0], [3, 2, 1, 0], [3, 2, 1, 0], [1, 0]]

In [538]:
import numpy as np

In [539]:
decoder_input_next =  torch.empty([N, k])

In [540]:
decoder_input_next

tensor([[    -0.0004,      0.0000,     -0.0004],
        [     0.0000,      0.0000,      0.0000],
        [     0.0000,      0.0000,      0.0000],
        [     0.0000,      0.0000,      0.0000]])

In [541]:
for ni in range(N):
    obtained = 0
    while obtained < k:
        for groupnum in order[ni]:
            groupid = groups[ni].index(groupnum)
            max_score_idx = np.argmax(group_scores[ni][groupid])
            decoder_input_next[ni][obtained] = group_word_idxs[ni][groupid][max_score_idx]
            obtained += 1
            group_scores[ni][groupid].pop(max_score_idx)
            group_word_idxs[ni][groupid].pop(max_score_idx)

            if obtained >= k:
                break

In [542]:
decoder_input_next

tensor([[9., 9., 5.],
        [5., 9., 0.],
        [0., 9., 0.],
        [9., 0., 7.]])