### NumPy implementation of Connectionist Temporal Classification (CTC), following Graves et al., ICML 2006 (https://www.cs.toronto.edu/~graves/icml_2006.pdf)

In [1]:
import numpy as np

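Create some test input: a 10×4 matrix of per-frame scores (rows are time steps, columns follow the order of `vocab`), the vocabulary with its blank symbol `=`, and the target label.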

In [3]:
# Per-frame scores for the test case: 10 rows (time steps) x 4 columns
# (one column per entry of `vocab`, in the same order).
# NOTE: the functions below multiply these values directly as the per-frame
# outputs; no softmax/normalisation is applied to them anywhere in this notebook.
logits = np.array([[-0.1063757 ,  0.86675691,  0.42268333,  0.9100397 ],
                   [ 1.66946294, -0.1009918 , -0.05339809,  0.18738931],
                   [-0.56614544,  0.2582804 , -1.50399211, -0.05561591],
                   [ 0.07700631, -1.3008013 , -0.78838889,  0.21257662],
                   [-1.1350321 , -0.21027041, -0.57139345, -0.93812239],
                   [ 1.28395484,  0.72440664,  1.43311044, -0.71667425],
                   [ 1.29279426,  0.59313199,  0.4500655 ,  2.0559043 ],
                   [ 0.19267242, -1.18393113,  0.5355136 , -0.85880894],
                   [-0.74923177,  1.56242967, -0.74563585,  0.7803211 ],
                   [-0.18403769, -1.55432081, -0.73049657,  0.48370142]])

# Symbol inventory; the last entry ('=') is the CTC blank.
vocab = ['a', 'b', 'c', '=']
blank = '='

# Lookups in both directions. `idx_to_char` is an alias of `vocab`, not a copy.
idx_to_char = vocab
char_to_idx = {symbol: position for position, symbol in enumerate(idx_to_char)}
blank_idx = char_to_idx[blank]

# Target transcription used by the cells below.
label = 'accac'

In [6]:
def pad_label(label, blank='='):
    """Build the extended CTC label sequence by interleaving blanks.

    A blank is placed before the first symbol, between every pair of
    symbols and after the last symbol, so a label of length n becomes a
    sequence of length 2*n + 1.

    Parameters
    ----------
    label : str (or iterable of one-character strings)
        The target label sequence. May be empty.
    blank : str, optional
        The blank symbol to insert. Defaults to '=', the blank used in
        this notebook.

    Returns
    -------
    str
        The extended sequence, e.g. 'ab' -> '=a=b='. An empty label yields
        a single blank ('='), not two.
    """
    # Emitting "symbol + blank" per symbol after one leading blank gives the
    # correct 2*n + 1 length for every n, including n == 0 (the
    # join-then-wrap form produced two blanks for an empty label).
    return blank + ''.join(symbol + blank for symbol in label)

In [7]:
def forward_backward(logits,label):
    """Compute the CTC forward (alpha) and backward (beta) variables.

    The label is first extended with blanks via `pad_label`. Both returned
    matrices have shape (number of rows of `logits`, length of the extended
    label). Entry [t, s] of `alpha` covers prefixes ending in extended-label
    position s at time t; entry [t, s] of `beta` covers suffixes starting
    there. Both include the output at time t itself.

    The values in `logits` are used directly as the per-frame outputs (no
    normalisation and no rescaling is performed here). Symbol-to-column
    lookups rely on the module-level `char_to_idx`, `blank` and `blank_idx`.

    Parameters
    ----------
    logits : np.ndarray
        2-D array indexed as [time step, vocabulary index].
    label : str
        Target label sequence (without blanks).

    Returns
    -------
    (alpha, beta) : tuple of np.ndarray
    """
    extended = pad_label(label)
    n_frames = logits.shape[0]
    n_states = len(extended)
    last_frame = n_frames - 1
    last_state = n_states - 1

    def emission(frame, state):
        # Output at `frame` for the symbol sitting at extended position `state`.
        return logits[frame, char_to_idx[extended[state]]]

    # ---- forward pass: equations 6 & 7 ----
    alpha = np.zeros((n_frames, n_states))
    for frame in range(n_frames):
        for state in range(n_states):
            emit = emission(frame, state)

            if frame == 0:
                # A path can only start in the leading blank or the first
                # symbol; every other entry of the first row stays zero.
                if state == 0:
                    alpha[frame, state] = logits[0, blank_idx]
                elif state == 1:
                    alpha[frame, state] = emit
                continue

            # Stay in the same state, or advance by one ...
            incoming = alpha[frame - 1, state]
            if state > 0:
                incoming += alpha[frame - 1, state - 1]
            # ... or skip over a blank, which is only allowed when landing on
            # a non-blank symbol that differs from the one two positions back.
            skip_allowed = (
                state >= 2
                and extended[state] != blank
                and extended[state] != extended[state - 2]
            )
            if skip_allowed:
                incoming += alpha[frame - 1, state - 2]
            alpha[frame, state] = incoming * emit

    # ---- backward pass: equations 10 & 11 ----
    beta = np.zeros((n_frames, n_states))
    for frame in range(last_frame, -1, -1):
        for state in range(last_state, -1, -1):
            emit = emission(frame, state)

            if frame == last_frame:
                # A path can only end in the trailing blank or the last
                # symbol; every other entry of the last row stays zero.
                if state == last_state:
                    beta[frame, state] = logits[-1, blank_idx]
                elif state == last_state - 1:
                    beta[frame, state] = emit
                continue

            # Mirror image of the forward transitions.
            outgoing = beta[frame + 1, state]
            if state < last_state:
                outgoing += beta[frame + 1, state + 1]
            skip_allowed = (
                state + 2 <= last_state
                and extended[state] != blank
                and extended[state] != extended[state + 2]
            )
            if skip_allowed:
                outgoing += beta[frame + 1, state + 2]
            beta[frame, state] = outgoing * emit

    return alpha, beta

def gradient(logits,label):
    """Gradient of ln p(label | logits) with respect to each entry of `logits`.

    Uses the forward/backward variables from `forward_backward` and
    equation 15 of the CTC paper,

        dp/dy[t, k] = 1 / y[t, k]**2 * sum_{s : l'_s == k} alpha[t, s] * beta[t, s],

    then divides by p(label | logits) to obtain the derivative of ln p.
    The entries of `logits` are treated directly as the outputs y[t, k].

    Parameters
    ----------
    logits : np.ndarray
        2-D array indexed as [time step, vocabulary index]; column k
        corresponds to `vocab[k]`.
    label : str
        Target label sequence (without blanks).

    Returns
    -------
    np.ndarray
        Array with the same shape as `logits`. Columns of symbols that do not
        occur in the extended label are exactly zero.
    """
    padded_label = pad_label(label)
    alpha, beta = forward_backward(logits,label)

    # p(label | x): a complete path ends either in the trailing blank or in
    # the last symbol. Guard the second term so a single-state lattice
    # (extended label consisting of one blank only) does not index out of range.
    pl_x = alpha[-1,-1]
    if alpha.shape[1] > 1:
        pl_x = pl_x + alpha[-1,-2]

    # Keep the input's floating dtype; fall back to float64 for integer input
    # so the gradient is not truncated on assignment.
    grad_dtype = logits.dtype if np.issubdtype(logits.dtype, np.floating) else np.float64
    gradients = np.zeros(logits.shape, dtype=grad_dtype)

    # equation 15
    for k in range(logits.shape[1]):
        # Positions of symbol k in the extended label; independent of t, so
        # computed once per symbol rather than once per (t, k) pair.
        lab_lk = [s for s, symbol in enumerate(padded_label) if symbol == vocab[k]]
        if not lab_lk:
            # p does not depend on y[t, k] at all, so the derivative is exactly
            # zero. Skipping also avoids 0/0 = NaN when y[t, k] happens to be 0.
            continue
        for t in range(logits.shape[0]):
            dp_dytk = 0.
            for s in lab_lk:
                dp_dytk += alpha[t,s]*beta[t,s]
            # alpha and beta each include y[t, k] once, hence the squared divisor.
            dp_dytk /= logits[t,k]**2
            gradients[t,k] = dp_dytk / pl_x
    return gradients

In [11]:
# Display the gradient of ln p(label | logits) w.r.t. every logits[t, k];
# the result has the same shape as `logits` (rows: time steps, columns: vocab).
gradient(logits,label)

array([[ 1.27076965,  0.        ,  0.        ,  1.24739504],
       [ 0.57680003,  0.        , -0.03111668,  0.18886964],
       [-0.65993026,  0.        , -0.39989117, -0.44861845],
       [ 1.3546767 ,  0.        , -0.52285446,  2.27432676],
       [-0.04480657,  0.        , -0.10440541, -0.94815617],
       [ 0.39371001,  0.        , -0.25641427, -1.20272788],
       [-1.02922122,  0.        ,  0.00347259,  1.13283891],
       [-0.92080778,  0.        , -0.04519529, -1.3991668 ],
       [-2.05502753,  0.        ,  0.13602776, -0.56164666],
       [ 0.        ,  0.        , -1.53525176, -0.25118004]])