# "Beam Decode" Made Easy

In [1]:
import numpy as np

## Parameters

In [262]:
# Sizes of the toy model and of the search.
V = 100  # vocabulary size (= output dimensionality)
H = 10   # hidden-state dimensionality
K = 3    # beam width
T = 10   # number of decoding timesteps

# One start token per batch row, shape (N, 1).
initial_y = np.array([[3], [6]], dtype=np.int32)
N = initial_y.shape[0]  # batch size

# Randomly initialised weights of the simple RNN.
xh = np.random.randn(V, H)  # input  -> hidden
hh = np.random.randn(H, H)  # hidden -> hidden
ho = np.random.randn(H, V)  # hidden -> output

EOS_ID = 0  # token id compared against predictions to detect end-of-sequence

## Utility functions

In [263]:
def onehot(arry, size):
    '''
    One-hot encode an integer array along a new trailing axis.

    arry: integer ndarray of any shape, e.g. (n,) or (n, t)
    size: number of classes, i.e. the length of the new last axis

    returns
    int32 array of shape arry.shape + (size,); an id outside [0, size)
    gives an all-zero vector
    '''
    class_ids = np.arange(size)
    # Broadcasting compares every entry of `arry` with every class id.
    matches = arry[..., np.newaxis] == class_ids
    return matches.astype('int32')

In [264]:
def softmax(x):
    """Compute a row-wise softmax of a 2-d array of scores.

    x: (n, d) array of scores.

    returns
    (n, d) array whose rows are non-negative and each sum to 1.
    """
    # Shift every row by its *own* maximum. Shifting by the global maximum
    # lets a row that lies far below it underflow to all zeros, and the
    # normalisation then evaluates 0/0 = nan for that row.
    shifted = x - np.max(x, axis=1, keepdims=True)
    e_x = np.exp(shifted)
    return e_x / e_x.sum(axis=1, keepdims=True)

In [270]:
def decode(y):
    '''Decodes with a simple rnn.

    Feeds the whole prefix `y` through the recurrent layer, starting from an
    all-zero hidden state, and scores the next token from the final hidden
    state.

    y: (n, t) integer array of token ids. The number of rows n is taken from
       y itself and need not equal the global N.

    returns
    (n, V) array of log-probabilities for the next token. For a prefix with
    zero timesteps the hidden state stays all zero, so the result is uniform.
    '''
    # V, H, xh, hh and ho are module-level parameters. They are only read
    # here, so no `global` statement is required.
    hidden = np.zeros((y.shape[0], H))  # initial hidden state

    for t in range(y.shape[1]):
        x_to_h = np.matmul(onehot(y[:, t], V), xh)  # (n, H)
        h_to_h = np.matmul(hidden, hh)  # (n, H)
        hidden = np.tanh(x_to_h + h_to_h)

    # Only the last hidden state is projected to the vocabulary. Doing this
    # after the loop means a prefix with zero timesteps still yields a
    # well-defined array.
    outputs = np.matmul(hidden, ho)  # (n, V)
    return np.log(softmax(outputs))

## Beam decode

In [268]:
def _get_preds_and_probs(PREDS):
    '''Return the K best next tokens and their log-probs for every prefix.

    PREDS: (n, t) array of token-id prefixes.

    returns
    preds_k: (K*n,) token ids, best first within each prefix
    probs_k: (K*n,) log-probabilities matching preds_k
    '''
    probs = decode(y=PREDS)  # (n, V)
    # Sort once and reuse the indices for both the tokens and their scores.
    top_k = np.argsort(probs)[:, ::-1][:, :K]  # (n, K)
    preds_k = top_k.flatten()  # (K*n,)
    probs_k = np.take_along_axis(probs, top_k, axis=1).flatten()  # (K*n,)
    return preds_k, probs_k


def logging(PREDS_k, EOS_k, PROBS_k):
    '''Print tokens, finished flag and score of every hypothesis, grouped by batch row.'''
    groups = zip(np.split(PREDS_k, N), np.split(EOS_k, N), np.split(PROBS_k, N))
    for i, (preds_batch, eos_batch, probs_batch) in enumerate(groups):
        print("batch num=", i)
        for preds, eos, prob in zip(preds_batch, eos_batch, probs_batch):
            print("{}\t{}\t{}".format(preds, eos, prob))


for t in range(T):
    print("="*10, "timesteps=", t, "="*10)

    if t == 0:  # initial step: seed the beam with the K best first tokens
        preds_k, probs_k = _get_preds_and_probs(initial_y)  # (K*N,), (K*N,)
        PREDS_k = np.expand_dims(preds_k, -1)  # PREDS_k: final outputs, (K*N, 1)
        PROBS_k = probs_k  # (K*N,)
        EOS_k = preds_k == EOS_ID  # (K*N,) True once a hypothesis emitted EOS

        logging(PREDS_k, EOS_k, PROBS_k)
        continue

    print("Expansion...")

    # Incremental (= local) values for this step: (K*K*N,), (K*K*N,)
    preds_kk, probs_kk = _get_preds_and_probs(PREDS_k)

    # preds for expanded beams
    PREDS_kk = np.repeat(PREDS_k, K, axis=0)  # (K*K*N, t)
    PREDS_kk = np.append(PREDS_kk, np.expand_dims(preds_kk, -1), -1)  # (K*K*N, t+1)

    # eos for expanded beams
    was_finished = np.repeat(EOS_k, K, axis=0)  # (K*K*N,) finished before this step
    EOS_kk = np.logical_or(was_finished, preds_kk == EOS_ID)  # (K*K*N,)

    # probs for expanded beams: running mean of the per-token log-probs.
    PROBS_kk = np.repeat(PROBS_k, K, axis=0)  # (K*K*N,)
    normalized_probs = (PROBS_kk*t + probs_kk) / (t+1)
    # Freeze the score only for hypotheses that had finished *before* this
    # step. One that emits EOS right now must still pay for that EOS token;
    # otherwise ending a sequence would always be free.
    PROBS_kk = np.where(was_finished, PROBS_kk, normalized_probs)  # (K*K*N,)
    # NOTE(review): a finished hypothesis is still expanded K times and all
    # K copies keep the same frozen score, so they can crowd live
    # hypotheses out of the beam during pruning.

    logging(PREDS_kk, EOS_kk, PROBS_kk)

    print("Pruning ...")
    # Keep K of the K*K candidates per batch row, or only the single best
    # one on the final step.
    n_keep = 1 if t == T-1 else K
    winners = []  # flat indices into the (K*K*N,) candidate arrays
    for j, prob_kk in enumerate(np.split(PROBS_kk, N)):  # (K*K,)
        best = np.argsort(prob_kk)[::-1][:n_keep]
        winners.extend(best + j*len(prob_kk))  # offset into batch row j

    PREDS_k = PREDS_kk[winners]  # (N, T) if final step, otherwise (K*N, t+1)
    PROBS_k = PROBS_kk[winners]  # (N,) if final step, otherwise (K*N,)
    EOS_k = EOS_kk[winners]

    logging(PREDS_k, EOS_k, PROBS_k)

batch num= 0
[13]	False	-1.801408656139552
[48]	False	-1.9670540723462346
[33]	False	-2.198485320487616
batch num= 1
[30]	False	-0.7897988012958753
[1]	False	-2.4629262078330942
[76]	False	-2.942492396759654
Expansion...
batch num= 0
[13 53]	False	-1.7223380770456393
[13 13]	False	-2.153169837293493
[13 21]	False	-2.1963759327725914
[48  7]	False	-1.5501885837535672
[48 56]	False	-2.0774770760291426
[48 71]	False	-2.2559148677511205
[33 13]	False	-2.02571782070014
[33 20]	False	-2.2622432967465382
[33 99]	False	-2.411861897194325
batch num= 1
[30  8]	False	-1.0387280196474995
[30 38]	False	-1.4677326159171868
[30 67]	False	-1.963540864046755
[ 1 46]	False	-1.8347386403083892
[ 1 50]	False	-2.062073079883135
[ 1 71]	False	-2.79060930087043
[76 89]	False	-2.4038844011808673
[76 44]	False	-2.440214472610669
[76  3]	False	-2.4820893854715784
Pruning ...
batch num= 0
[48  7]	False	-1.5501885837535672
[13 53]	False	-1.7223380770456393
[33 13]	False	-2.02571782070014
batch num= 1
[30  8]	Fals

Be aware that the tokens that follow `<EOS>` are stripped.