# CTC Beam Search with LM

This notebook implements CTC beam search with LM presented in [Hannun el al.](https://arxiv.org/pdf/1408.2873.pdf).

In [None]:
%run ctc.ipynb
%run ctc_bs.ipynb

In [None]:
import numpy as np
from collections import defaultdict, Counter

In [56]:
class CTCBeamSearchLM(AbstractDecoder):
    def __init__(self, beam_size=10, alpha=0.3, beta=5):
        self.beam_size = beam_size
        self.alpha = alpha
        self.beta = beta

        self.Pb = defaultdict(Counter)
        self.Pnb = defaultdict(Counter)

        self.A_prev = [""]

        self.to_char = lambda x: CharModel.ind_to_char[x]
        self.to_ind = lambda x: CharModel.char_to_ind[x]

    def _topK(self, t):
        A_next = self.Pb[t] + self.Pnb[t]
        # print(f"Current A_next is {A_next}")
        return sorted(A_next.keys(), 
                      key=lambda x: A_next[x] * len(x) ** self.beta, 
                      reverse=True)[:self.beam_size]

    def decode(self, output, lm=None):
        lm = (lambda l: 1) if not lm else lm
        self.Pb[0][""] = 1
        self.Pnb[0][""] = 0

        for t in range(1, output.shape[0]):
            # print(f"*******timestep {t} started*********")
            for seq in self.A_prev:
                # print(f"Extending sequence {seq}")
                for c in range(output.shape[1]):
                    # print(f"try to extend char {self.to_char(c)}")
                    if c == self.to_ind(CharModel.BLANK_SYMBOL):
                        self.Pb[t][seq] = output[t, c] * (self.Pb[t-1][seq] + self.Pnb[t-1][seq])
                    else:
                        curr_char = self.to_char(c)
                        curr_seq = seq + curr_char
                        # print(f"current looking seq is {curr_seq}")
                        if len(seq) > 0 and curr_char == seq[-1]:
                            self.Pnb[t][curr_seq] += output[t][c] * self.Pb[t-1][seq]
                            self.Pnb[t][seq] += output[t][c] * self.Pb[t-1][seq]
                        elif len(seq) > 0 and curr_char == self.to_ind(CharModel.SPACE):
                            self.Pnb[t][curr_seq] += \
                                output[t][c] * lm(curr_seq) ** self.alpha *\
                                (self.Pb[t-1][seq] + self.Pnb[t-1][seq])
                        else:
                            self.Pnb[t][curr_seq] += output[t, c] *\
                                (self.Pb[t-1][seq] + self.Pnb[t-1][seq])
                        
                        if curr_seq not in self.A_prev:
                            self.Pb[t][curr_seq] += output[t][self.to_ind(CharModel.BLANK_SYMBOL)] *\
                                (self.Pb[t-1][curr_seq] + self.Pnb[t-1][curr_seq])
                            self.Pnb[t][curr_seq] += output[t][c] * self.Pnb[t-1][curr_seq]
                        
            self.A_prev = self._topK(t)
        return self.A_prev[0]

In [57]:
output = _random_ctc_output(30)

In [58]:
ctcbslm = CTCBeamSearchLM()

In [59]:
ctcbslm.decode(output)

'vk plkiviawdmuhabmxs zdgpkbdl'