# Beam Search (Beam Decoding)

Instead of greedily choosing the single most likely next step as the sequence is constructed, beam search expands all possible next steps and keeps the k most likely. Here k is a user-specified parameter that controls the number of beams, i.e. the number of parallel searches through the sequence of probabilities.

The local beam search algorithm keeps track of "k" states rather than just one. It begins with k randomly generated states. At each step, all the successors of all k states are generated. If any one is a goal, the algorithm halts. Otherwise, it selects the k best successors from the complete list and repeats.

In [None]:
import numpy as np

## Parameters

In [None]:
# Sizes used by the toy RNN and by the beam search below.
V = 100  # vocabulary size (= width of the output distribution)
H = 10   # width of the hidden state
K = 3    # beam width: hypotheses kept per batch item
T = 10   # number of decoding timesteps

# One start token per batch item, stored as a column: shape (N, 1).
initial_y = np.array([[3], [6]], dtype=np.int32)
N = initial_y.shape[0]  # batch size

# Randomly initialised RNN weights (standard normal, drawn in this order).
xh = np.random.randn(V, H)  # input  -> hidden
hh = np.random.randn(H, H)  # hidden -> hidden
ho = np.random.randn(H, V)  # hidden -> output

# Token id that marks a hypothesis as finished (end of sequence).
EOS_ID = 0

## Utility Functions

In [None]:
def OneHotEncoding(arry, size):
    '''
    One-hot encodes an integer array along a new trailing axis.

    arry: numpy array of integer labels, e.g. (n,) or (n, t)
    size: number of classes (length of the new trailing axis)

    returns
    int32 array of shape arry.shape + (size,). Entry [..., c] is 1 where the
    label equals c and 0 elsewhere; a label outside [0, size) yields an
    all-zero vector.
    '''
    classes = np.arange(size)
    # Broadcasting compares every label against every class id at once.
    hits = arry[..., np.newaxis] == classes
    return hits.astype('int32')

In [None]:
def softmax(x):
    """Compute a row-wise softmax of a 2-d array of scores.

    args:
        x: (n, d) array of scores; each row is normalised independently.

    returns:
        (n, d) array whose rows are non-negative and sum to 1.
    """
    # Shift every row by its OWN maximum. Subtracting one global maximum is
    # mathematically equivalent, but a row lying far below that maximum would
    # underflow to all zeros and produce 0/0 = nan.
    shifted = x - np.max(x, axis=1, keepdims=True)
    e_x = np.exp(shifted)
    return e_x / e_x.sum(axis=1, keepdims=True)

In [None]:
def decode(y):
    '''
    Feeds token prefixes through a simple tanh RNN and scores the next token.

    Reads the module-level sizes V, H and weights xh, hh, ho.

    args:
        y: (n, t) integer array, one token prefix per row.

    returns:
        (n, V) array of log-probabilities for the token that follows each
        prefix, or None when y has no timesteps (t == 0).
    '''
    n_rows, n_steps = y.shape[0], y.shape[1]
    if n_steps == 0:
        return None

    state = np.zeros((n_rows, H)) # hidden state before the first token
    for step in range(n_steps):
        one_hot = OneHotEncoding(y[:, step], V) # (n, V)
        # input contribution + recurrent contribution, both (n, H)
        state = np.tanh(np.matmul(one_hot, xh) + np.matmul(state, hh))

    # Only the hidden state after the last token is projected to the vocabulary.
    logits = np.matmul(state, ho) # (n, V)
    return np.log(softmax(logits))

## Beam Decoding

In [None]:
# The helpers are defined once, outside the loop, instead of being re-created
# on every timestep.
def _get_preds_and_probs(PREDS):
    '''
    Scores the next token for every hypothesis and keeps the K best per row.

    args:
        PREDS: (n, t) array of token prefixes, passed to decode().

    returns:
        preds_k: (K*n,) token ids, best first within each row
        probs_k: (K*n,) scores returned by decode() for those tokens
    '''
    probs = decode(y=PREDS) # (n, V)
    top_k = np.argsort(probs)[:, ::-1][:, :K] # (n, K) column indices, best first
    preds_k = top_k.flatten() # (K*n,)
    # Gather the scores with the same indices so tokens and scores stay aligned.
    probs_k = np.take_along_axis(probs, top_k, axis=1).flatten() # (K*n,)
    return preds_k, probs_k


def logging(PREDS_k, EOS_k, PROBS_k):
    '''
    Prints every hypothesis (tokens, finished flag, score), grouped by batch item.

    Each argument is split into N equal chunks along axis 0, one per batch item.
    '''
    for i, (PREDS_k_batch, EOS_k_batch, PROBS_k_batch) in \
                enumerate(zip(np.split(PREDS_k, N), np.split(EOS_k, N), np.split(PROBS_k, N))):
        print("batch num=", i)
        for each_PREDS_k_batch, each_EOS_k_batch, each_PROBS_k_batch in zip(PREDS_k_batch, EOS_k_batch, PROBS_k_batch):
            print("{}\t{}\t{}".format(each_PREDS_k_batch, each_EOS_k_batch, each_PROBS_k_batch))


for t in range(T):
    print("="*10, "timesteps=", t, "="*10)

    if t == 0: # initial step: seed K hypotheses per batch item
        preds_k, probs_k = _get_preds_and_probs(initial_y) # (k*N), (k*N)
        PREDS_k = np.expand_dims(preds_k, -1) # PREDS_k: final outputs, (k*N, 1)
        PROBS_k = probs_k # score of each hypothesis (mean log-prob per token)
        EOS_k = preds_k == EOS_ID # True once a hypothesis has produced EOS

        # logging
        logging(PREDS_k, EOS_k, PROBS_k)

    else:
        print("Expansion...")

        preds_kk, probs_kk = _get_preds_and_probs(PREDS_k) # (k*k*N), (k*k*N) <- incremental(=local) values

        # preds for expanded beams
        PREDS_kk = np.repeat(PREDS_k, K, axis=0) # (k*k*N, t)
        PREDS_kk = np.append(PREDS_kk, np.expand_dims(preds_kk, -1), -1) # PREDS_kk: (k*k*N, t+1)

        # eos for expanded beams
        eos_kk = preds_kk == EOS_ID # (k*k*N) <- local: EOS emitted at this step
        finished_kk = np.repeat(EOS_k, K, axis=0) # (k*k*N, ) parent was already finished
        EOS_kk = np.logical_or(finished_kk, eos_kk) # (k*k*N,)

        # probs for expanded beams
        # PROBS_k is a running mean over t tokens; fold in the new token's score.
        PROBS_kk = np.repeat(PROBS_k, K, axis=0) # (k*k*N, )
        normalized_probs = (PROBS_kk * t + probs_kk) / (t+1)
        # Freeze the score only for hypotheses that were finished BEFORE this
        # step. A hypothesis that emits EOS right now must still be charged for
        # the EOS token, exactly as the t == 0 branch does.
        # NOTE: a finished hypothesis is still expanded into K children that
        # all share its frozen score.
        PROBS_kk = np.where(finished_kk, PROBS_kk, normalized_probs) # (k*k*N, )

        # logging
        logging(PREDS_kk, EOS_kk, PROBS_kk)

        print("Pruning ...")
        n_keep = 1 if t == T-1 else K # only the single best survives the final step
        winners = [] # (k*N). k elements are selected out of k^2
        for j, prob_kk in enumerate(np.split(PROBS_kk, N)): # (k*k,)
            winner = np.argsort(prob_kk)[::-1][:n_keep]
            # shift the per-batch indices to positions in the flat (k*k*N) arrays
            winners.extend(list(winner + j*len(prob_kk)))

        PREDS_k = PREDS_kk[winners] # (N, T) if final step,  otherwise (k*N, t+1)
        PROBS_k = PROBS_kk[winners] # (N,) if final step,  otherwise (k*N, )
        EOS_k = EOS_kk[winners]

        # logging
        logging(PREDS_k, EOS_k, PROBS_k)

batch num= 0
[3]	False	-1.0104630587644836
[66]	False	-1.6612474832368465
[86]	False	-2.061597426559997
batch num= 1
[85]	False	-1.184326096272182
[79]	False	-1.217708716134948
[63]	False	-3.0112595542337894
Expansion...
batch num= 0
[3 3]	False	-1.0104630587644836
[ 3 66]	False	-1.335855271000665
[ 3 86]	False	-1.5360302426622403
[66 70]	False	-1.2986597638593262
[66 21]	False	-1.950765646880671
[66 55]	False	-1.9736278451829636
[86 17]	False	-1.9965168367301112
[86 69]	False	-2.1746499987483743
[86 40]	False	-2.211077902821743
batch num= 1
[85 85]	False	-1.6051277230409093
[85 15]	False	-1.661923928075009
[85 45]	False	-1.8430660765392783
[79 67]	False	-1.076659434797022
[79 96]	False	-1.313978308575485
[79 22]	False	-2.033316863110551
[63 99]	False	-2.504420761152929
[63 43]	False	-2.78438430437292
[63 21]	False	-2.8604865679150504
Pruning ...
batch num= 0
[3 3]	False	-1.0104630587644836
[66 70]	False	-1.2986597638593262
[ 3 66]	False	-1.335855271000665
batch num= 1
[79 67]	False	-1