# CTC Beam Search

The implementation follows the **Algorithm 1** in [Graves et al. (2014)](http://proceedings.mlr.press/v32/graves14.pdf).

In [86]:
!pip install nbformat



In [87]:
%run ctc.ipynb

In [88]:
import numpy as np
import abc

In [89]:
output = _random_ctc_output(30)

## Greedy Decoding
Now, let's first build a greedy decoder.

In [90]:
class AbstractDecoder(metaclass=abc.ABCMeta):
    @abc.abstractmethod
    def decode(self, output):
        pass

In [91]:
class CTCGreedyDecoder(AbstractDecoder):
    def decode(self, output):
        return np.argmax(softmax(output), axis=1)

In [99]:
gd = CTCGreedyDecoder()

"".join(list(map(lambda x: CharModel.ind_to_char[x], gd.decode(output))))

'w^jlslfjnmxyklog^czjwndaqcgivo'

## Beam Search

Now, let's build the CTC beam searcher.

Define $\operatorname{Pr}^{-}(\boldsymbol{y}, t), \operatorname{Pr}^{+}(\boldsymbol{y}, t)$ and $\operatorname{Pr}(\boldsymbol{y}, t)$ respectively as the blank, non-blank and total probabilities assigned to some (partial) output transcription $\boldsymbol{y}$, at time $t$ by the beam search, and set $\operatorname{Pr}(\boldsymbol{y}, t)=\operatorname{Pr}^{-}(\boldsymbol{y}, t)+\operatorname{Pr}^{+}(\boldsymbol{y}, t)$. Define the extension probability $\operatorname{Pr}(k, \boldsymbol{y}, t)$ of $\boldsymbol{y}$ by label $k$ at time $t$ as follows:
$$
\operatorname{Pr}(k, \boldsymbol{y}, t)=\operatorname{Pr}(k, t \mid \boldsymbol{x}) \operatorname{Pr}(k \mid \boldsymbol{y})\left\{\begin{array}{l}
\operatorname{Pr}^{-}(\boldsymbol{y}, t-1) \text { if } \boldsymbol{y}^e=k \\
\operatorname{Pr}(\boldsymbol{y}, t-1) \text { otherwise }
\end{array}\right.
$$
where $\operatorname{Pr}(k, t \mid \boldsymbol{x})$ is the CTC emission probability of $k$ at $t$, as defined in Eq. (13), $\operatorname{Pr}(k \mid \boldsymbol{y})$ is the transition probability from $\boldsymbol{y}$ to $\boldsymbol{y}+k$ and $\boldsymbol{y}^e$ is the final label in $\boldsymbol{y}$. Lastly, define $\hat{\boldsymbol{y}}$ as the prefix of $\boldsymbol{y}$ with the last label removed, and $\varnothing$ as the empty sequence, noting that $\operatorname{Pr}^{+}(\varnothing, t)=0 \forall t$.

$\operatorname{Pr}(k \mid \boldsymbol{y})$ is the transition prob of language model. It is set to 1 if no such information.

In [93]:
class CTCBeamSearchDecoder(AbstractDecoder):
    def __init__(self, beam_size=5):
        self.beam_size = beam_size
        self.B = {("", 1)}
        self.prob_minus = {("", 0): 1}
        self.prob_plus = dict()

    def _find_K_top(self):
        if len(self.B) <= self.beam_size:
            return self.B
        # lst = list(self.B)
        it = sorted(self.B, key=lambda x: x[1], reverse=True)
        return set(it[:self.beam_size])
    
    def _find_max_score(self):
        itr = map(lambda x: (x[0], x[1] ** (1.0 / len(x[0]))), self.B)
        return sorted(itr, key=lambda x: x[1], reverse=True)[0]

    def decode(self, output):
        log_prob = softmax(output)
        LM_score = 1 # No LM for now

        for t in range(1, log_prob.shape[0]):
            # print(self.B)
            B_hat = self._find_K_top()
            # print(B_hat)
            self.B = set()

            for seq, _ in B_hat:
                if seq != "":
                    self.prob_plus[(seq, t)] = \
                        self.prob_plus[(seq, t - 1)] * log_prob[t, CharModel.char_to_ind[seq[-1]]]
                    
                    if seq[:-1] in B_hat:
                        self.prob_plus[(seq, t)] += \
                            LM_score * \
                            log_prob[t, CharModel.char_to_ind[seq[-1]]] * \
                            self.prob_minus[(seq, t - 1)]
                self.prob_minus[(seq, t)] = \
                    self.prob_minus[(seq, t - 1)] * \
                        log_prob[t, CharModel.char_to_ind[CharModel.BLANK_SYMBOL]]
                
                self.B.add((
                    seq, 
                    self.prob_minus[(seq, t)] +\
                        (self.prob_plus[(seq, t)] if seq != "" else 0)
                    ))

                for k in range(log_prob.shape[1]):
                    curr_seq = seq + CharModel.ind_to_char[k]
                    self.prob_minus[(curr_seq, t)] = 0
                    self.prob_plus[(curr_seq, t)] = \
                        log_prob[t, k] * LM_score * (
                            self.prob_minus[(seq, t-1)] if len(seq) > 0 and CharModel.char_to_ind[seq[-1]] == k \
                            else self.prob_minus[(seq, t-1)] + (self.prob_plus[(seq, t-1)] if len(seq) > 0 else 0))
                    self.B.add((curr_seq, self.prob_plus[(curr_seq, t)]))
        
        return self._find_max_score()

In [94]:
ctcbs = CTCBeamSearchDecoder()

In [95]:
z = ctcbs.decode(output)

In [96]:
z

('^jlslfjnmxyklog^czjwndaqcgivo', 0.05402136794001452)