# Training with `nn.CTCLoss`

Starting from PyTorch 1.1.0, built-in support for CTC loss is available as `nn.CTCLoss`. Before that, people had to use third-party libraries such as `warp-ctc`. We strongly recommend using a recent PyTorch version and `nn.CTCLoss` for HW3P2.



## Toy task: English spelling to pronunciation

As a demonstration, we consider the task of predicting the pronunciation (as a sequence of phonemes) of an English word given its spelling. The model we use is a bidirectional LSTM.

CTC is actually not the best formulation for this problem, since the letter "X" corresponds to two phonemes "K S", but it works well with our simplified data.
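The length limitation behind this remark can be checked directly. CTC emits at most one label per input step and needs a blank step between two identical adjacent labels, so a target is only reachable when the input is long enough. A minimal sketch, where the helper name `ctc_feasible` and the word "SIX" are illustrative and not part of the dataset below:

```python
def ctc_feasible(input_len, target):
    """Return True if some CTC alignment of `target` fits in `input_len` steps."""
    # Identical neighbouring labels need a blank step between them.
    repeats = sum(1 for a, b in zip(target, target[1:]) if a == b)
    return input_len >= len(target) + repeats

print(ctc_feasible(3, ['S', 'EH', 'T']))        # SET: 3 letters, 3 phonemes
print(ctc_feasible(3, ['S', 'IH', 'K', 'S']))   # SIX: 3 letters, 4 phonemes
print(ctc_feasible(2, ['T', 'T']))              # a repeated label needs 3 steps
```

Every word in the toy data has at least as many letters as phonemes, which is why CTC works here.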

In [1]:
# Words with only E, I, N, S, T.
# Pronunciation is from http://www.speech.cs.cmu.edu/cgi-bin/pronounce
data = [
    ('SEE', 'S IY'),
    ('SET', 'S EH T'),
    ('SIT', 'S IH T'),
    ('SITE', 'S AY T'),
    ('SIN', 'S IH N'),
    ('TEEN', 'T IY N'),
    ('TIN', 'T IH N'),
    ('TIE', 'T AY'),
    ('TEST', 'T EH S T'),
    ('NET', 'N EH T'),
    ('NEET', 'N IY T'),
    ('NINE', 'N AY N')
]
letters = 'EINST'
# Starts with ' ' for blank, followed by actual phonemes
phonemes = [' ', 'S', 'T', 'N', 'IY', 'IH', 'EH', 'AY']

Note that if there are P phonemes in the output, they should be indexed 1 to P, not 0 to P-1. **Index 0 is reserved for "blank"** (this matches the default `blank=0` of `nn.CTCLoss`).

Accordingly, the output classifier of the model should have P+1 classes, since "blank" is also a class.
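As a small illustration of this indexing convention, the sketch below builds the mapping explicitly. The names `phoneme_list`, `phoneme_to_index`, and `num_classes` are ours and are not used elsewhere in this notebook.

```python
phoneme_list = ['S', 'T', 'N', 'IY', 'IH', 'EH', 'AY']  # P = 7 real phonemes
BLANK = 0

# Real phonemes get indices 1..P; index 0 stays free for the blank.
phoneme_to_index = {p: i for i, p in enumerate(phoneme_list, start=1)}
num_classes = len(phoneme_list) + 1  # P + 1 output classes, blank included

encoded = [phoneme_to_index[p] for p in 'T EH S T'.split()]
print(encoded, num_classes)
```

Prepending a placeholder to the phoneme list, as done above with `' '`, gives the same indices with a plain `list.index` lookup.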

In [2]:
import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence

X = [torch.LongTensor([letters.find(c) for c in word]) for word, _ in data]
Y = [torch.LongTensor([phonemes.index(p) for p in pron.split()]) for _, pron in data]
X_lens = torch.LongTensor([len(seq) for seq in X])
Y_lens = torch.LongTensor([len(seq) for seq in Y])
X = pad_sequence(X)
# `batch_first=True` is required for use in `nn.CTCLoss`.
Y = pad_sequence(Y, batch_first=True)

print('X', X.size(), X_lens)
print('Y', Y.size(), Y_lens)

class Model(nn.Module):
    def __init__(self, in_vocab, out_vocab, embed_size, hidden_size):
        super(Model, self).__init__()
        self.embed = nn.Embedding(in_vocab, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, bidirectional=True)
        self.output = nn.Linear(hidden_size * 2, out_vocab)
    
    def forward(self, X, lengths):
        X = self.embed(X)
        print("embed")
        print(X.size())
        print(X)
        packed_X = pack_padded_sequence(X, lengths, enforce_sorted=False)
        packed_out = self.lstm(packed_X)[0]
        out, out_lens = pad_packed_sequence(packed_out)
        # Log softmax after output layer is required for use in `nn.CTCLoss`.
        out = self.output(out).log_softmax(2)
        return out, out_lens

X torch.Size([4, 12]) tensor([3, 3, 3, 4, 3, 4, 3, 3, 4, 3, 4, 4])
Y torch.Size([12, 4]) tensor([2, 3, 3, 3, 3, 3, 3, 2, 4, 3, 3, 3])


## Usage

The official documentation is your best friend: https://pytorch.org/docs/stable/nn.html#ctcloss

`nn.CTCLoss` takes 4 arguments to compute the loss:
* `log_probs`: Prediction of your model at each time step.
  * Shape: (T, N, C), where T is the largest length in the batch, N is batch size, and C is number of classes (remember that it should be number of phonemes plus 1).
  * **Values must be log probabilities.** Neither probabilities nor logits will work. Make sure the output of your network is log probabilities by adding an `nn.LogSoftmax` after the last linear layer.
* `targets`: The ground truth sequences.
  * Shape: (N, S), where N is batch size, and S is the largest target length in the batch. **WARNING!** This dimension order is unconventional in PyTorch. If you use `torch.nn.utils.rnn.pad_sequence` to pad the target sequences, **you must explicitly set `batch_first=True`**.
  * Values are indices of phonemes. Again, remember that index 0 is reserved for "blank" and should not represent any phoneme.
* `input_lengths`: Lengths of sequences in `log_probs`.
  * Shape: (N,).
  * This is not necessarily the same as lengths of input of the model. If your model uses CNNs or pyramidal RNNs, it changes the length of sequences, and you must correctly compute the lengths of its output to be used here.
* `target_lengths`: Lengths of sequences in `targets`.
  * Shape: (N,).
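To make the four shapes concrete, here is a minimal sketch that calls `nn.CTCLoss` on random data. The sizes and target values are made up for illustration; only the shapes matter.

```python
import torch
from torch import nn

torch.manual_seed(0)
T, N, C, S = 6, 2, 4, 3  # time steps, batch size, classes (3 labels + blank), max target length

# (T, N, C) log probabilities, normalized over the class dimension.
log_probs = torch.randn(T, N, C).log_softmax(2)
# (N, S) padded targets; real labels are 1..C-1. Entries past target_lengths are ignored.
targets = torch.tensor([[1, 2, 3], [2, 1, 0]])
input_lengths = torch.tensor([6, 5])   # valid time steps per sample, shape (N,)
target_lengths = torch.tensor([3, 2])  # valid labels per sample, shape (N,)

criterion = nn.CTCLoss()  # blank=0 by default
loss = criterion(log_probs, targets, input_lengths, target_lengths)
print(loss.item())
```

With the default reduction the result is a single scalar, so it can be passed straight to `backward()`.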


In [3]:
torch.manual_seed(11785)
model = Model(len(letters), len(phonemes), 5, 4)
criterion = nn.CTCLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

for epoch in range(50):
    model.zero_grad()
    out, out_lens = model(X, X_lens)
    loss = criterion(out, Y, out_lens, Y_lens)
    print('Epoch', epoch + 1, 'Loss', loss.item())
    loss.backward()
    optimizer.step()

embed
torch.Size([4, 12, 5])
[embedding tensor dump omitted; the same debug output is printed on every forward pass]

Epoch 33 Loss 0.11679060012102127

# Decoding with `ctcdecode`

During inference, we want to generate the most probable sequence from predicted probabilities. PyTorch doesn't have built-in support for that, so we need another library called `ctcdecode`.
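For intuition about what decoding does, here is a sketch of greedy (best-path) decoding: take the most likely class at each step, collapse consecutive repeats, then drop blanks. This is a simpler technique than the beam search that `ctcdecode` performs and is not guaranteed to return the most probable sequence. The function name `greedy_ctc_decode` and the scores are illustrative.

```python
from itertools import groupby

def greedy_ctc_decode(frame_scores, blank=0):
    """Best-path CTC decoding over a list of per-step class scores."""
    # Pick the highest-scoring class at every time step.
    best_path = [max(range(len(frame)), key=frame.__getitem__) for frame in frame_scores]
    # Collapse consecutive repeats, then drop blanks.
    collapsed = [k for k, _ in groupby(best_path)]
    return [k for k in collapsed if k != blank]

scores = [
    [0.1, 0.8, 0.1],     # class 1
    [0.1, 0.7, 0.2],     # class 1 again, merged with the previous step
    [0.9, 0.05, 0.05],   # blank separates the two occurrences of class 1
    [0.2, 0.7, 0.1],     # class 1
    [0.1, 0.2, 0.7],     # class 2
]
print(greedy_ctc_decode(scores))
```

The blank in the middle is what allows the same label to appear twice in a row in the decoded output.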

## Installation

If you follow only the steps in https://github.com/parlance/ctcdecode, you may encounter `ModuleNotFoundError: No module named 'wget'`. Running `pip install wget` solves the problem.

The following steps should install `ctcdecode` successfully. (Change `pip3 install` to either `pip3 install --user` or `sudo -H pip3 install` if you are using the system Python instead of Conda.) Compilation takes a few minutes, so be patient.

In [4]:
!git clone --recursive https://github.com/parlance/ctcdecode.git
!pip3 install wget
%cd ctcdecode
!pip3 install .
%cd ..

Cloning into 'ctcdecode'...
remote: Enumerating objects: 1006, done.
remote: Total 1006 (delta 0), reused 0 (delta 0), pack-reused 1006
Receiving objects: 100% (1006/1006), 728.22 KiB | 1.83 MiB/s, done.
Resolving deltas: 100% (500/500), done.
Submodule 'third_party/ThreadPool' (https://github.com/progschj/ThreadPool.git) registered for path 'third_party/ThreadPool'
Submodule 'third_party/kenlm' (https://github.com/kpu/kenlm.git) registered for path 'third_party/kenlm'
Cloning into '/content/ctcdecode/third_party/ThreadPool'...
remote: Enumerating objects: 82, done.        
remote: Total 82 (delta 0), reused 0 (delta 0), pack-reused 82        
Cloning into '/content/ctcdecode/third_party/kenlm'...
remote: Enumerating objects: 5, done.        
remote: Counting objects: 100% (5/5), done.        
remote: Compressing objects: 100% (5/5), done.        
remote: Total 13329 (delta 0), reused 1 (delta 0), pack-reused 13324        
Receiving objects: 100% (13329/13329), 5.33 MiB | 7.56 Mi

Test whether ctcdecode is working.

Common errors:
* `ImportError: No module named 'ctcdecode._ext'`: Your current working directory is the `ctcdecode` source directory. `cd` to any other directory to fix this.
* `undefined symbol: _ZN6caffe26detail37_typeMetaDataInstance_preallocated_32E`: **`torch` MUST be imported before importing `ctcdecode`**, otherwise you will see this.


In [4]:
import torch
from ctcdecode import CTCBeamDecoder

decoder = CTCBeamDecoder([' ', 'A'], beam_width=4)
probs = torch.Tensor([[0.2, 0.8], [0.8, 0.2]]).unsqueeze(0)
print(probs.size())
out, _, _, out_lens = decoder.decode(probs, torch.LongTensor([2]))
print(out[0, 0, :out_lens[0, 0]])

torch.Size([1, 2, 2])
tensor([1], dtype=torch.int32)


In [5]:
import numpy as np
y_s = np.load('table_of_ys_brand_new.npy')
#y_s = np.array([[1/6,4/6,2/6,1/6],[2/6,1/6,1/6,4/6],[3/6,1/6,3/6,1/6]])
#y_s = y_s.reshape(y_s.shape[0],y_s.shape[1],1)
y_sT = np.transpose(y_s, (2,1,0))
tensor_y = torch.Tensor(y_sT)
decoder = CTCBeamDecoder([' ','a','b','c'], beam_width=2)
out, _, _, out_lens = decoder.decode(tensor_y, torch.LongTensor([10]))
print(out[0, 0, :out_lens[0, 0]])

FileNotFoundError: [Errno 2] No such file or directory: 'table_of_ys_brand_new.npy'

## Usage

There is no documentation for `ctcdecode`. The only definitive way to understand it is to read the source code. Below we explain the most useful arguments.

`CTCBeamDecoder`:
* `labels` (the first argument, here the phonemes): **It doesn't need to be actual phonemes.** The only requirement is that it is a list of characters whose length is the number of classes (number of phonemes plus 1).
* `beam_width`: Larger beam width produces better output, but also costs more time and memory.
* `num_processes`: Number of processes for parallel decoding. Setting it to `os.cpu_count()` is recommended as it utilizes all CPU cores.
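* `log_probs_input`: Set this to `True` whenever you pass your model's output, since that output is log probabilities. Leave it unset only when passing plain probabilities, as in the test above.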

`CTCBeamDecoder.decode` arguments:
* `probs`: Prediction from your model as log probabilities (if `log_probs_input=True`).
  * Shape: (N, T, C), where N is batch size, T is the largest length in the batch, and C is number of classes. **WARNING!** This is not the (T, N, C) order used by `nn.CTCLoss`. You likely need to do `out.transpose(0, 1)` on your output.
* `len`: Lengths of sequences in `probs`.
  * Shape: (N,), one length per sequence in the batch.
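The layout change between the loss and the decoder can be sketched as follows. The tensor here is random and stands in for a model's output; the sizes are illustrative.

```python
import torch

T, N, C = 4, 5, 8  # time steps, batch size, classes
out = torch.randn(T, N, C).log_softmax(2)   # (T, N, C), the layout used by nn.CTCLoss
out_lens = torch.tensor([3, 4, 4, 4, 4])    # one length per sample, shape (N,)

probs = out.transpose(0, 1)                 # (N, T, C), the layout described above
print(out.shape, probs.shape, out_lens.shape)
```

`transpose` returns a view, so this step costs nothing extra in memory.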


`CTCBeamDecoder.decode` return value (tuple of 4):
* First item `output`: Decoded top sequences.
  * Shape: (N, B, T), where B is the beam width. Normally we only need the best sequences, which are at index 0 in the second (beam width) dimension.
* The second and third items (beam scores and timesteps) can be ignored.
* Last item `out_seq_len`: Length of sequences in `output`. 
  * Shape: (N, B). Lengths of best sequences are indexed 0 in the second (beam width) dimension.
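Extracting the best sequence per sample from these two return values can be sketched as below. Nested lists stand in for the returned tensors so the example runs without `ctcdecode`; the helper name `best_sequences` and the values are made up.

```python
def best_sequences(output, out_seq_len):
    """Return the top-beam label sequence for each sample in the batch."""
    results = []
    for i in range(len(output)):
        length = int(out_seq_len[i][0])   # length of the best beam for sample i
        results.append([int(k) for k in output[i][0][:length]])
    return results

# Stand-in values with N=2, B=2, T=3; entries past each length are padding.
output = [[[2, 6, 2], [2, 6, 0]],
          [[1, 4, 0], [1, 0, 0]]]
out_seq_len = [[3, 2], [2, 1]]
print(best_sequences(output, out_seq_len))
```

Trimming with the returned lengths matters because positions past the length carry no meaningful labels.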

In [33]:
import numpy as np
import matplotlib.pyplot as plt

# Visualize the probability prediction at each step
def visualize(word, log_probs):
    fig, ax = plt.subplots()
    ax.imshow(log_probs.exp().numpy())
    ax.set_xticks(np.arange(log_probs.size(1)))
    ax.set_yticks(np.arange(log_probs.size(0)))
    ax.set_xticklabels(phonemes)
    ax.set_yticklabels(list(word))
    plt.show()

In [37]:
test_data = ['TEE', 'TINT', 'SINE', 'SENT', 'TEEN']

decoder = CTCBeamDecoder(['$'] * len(phonemes), beam_width=3, log_probs_input=True)

test_X = [torch.LongTensor([letters.find(c) for c in word]) for word in test_data]
test_X_lens = torch.LongTensor([len(seq) for seq in test_X])
test_X = pad_sequence(test_X)

with torch.no_grad():
    out, out_lens = model(test_X, test_X_lens)
    #print(out)
    
print(out.size())
print(out_lens)
print(len(phonemes))

embed
torch.Size([4, 5, 5])
[embedding tensor dump omitted]

In [6]:
print(out.size())
print(out.transpose(0, 1).size())

test_Y, _, _, test_Y_lens = decoder.decode(out.transpose(0, 1), out_lens)
print(test_Y.size())
print(test_Y_lens)

for i in range(len(test_data)):
    visualize(test_data[i], out[:len(test_data[i]), i, :])
    # For the i-th sample in the batch, take the best beam (index 0) and trim it to its length.
    best_seq = test_Y[i, 0, :test_Y_lens[i, 0]]
    # Map class indices back to phoneme strings.
    best_pron = ' '.join(phonemes[k] for k in best_seq)
    print(test_data[i], '->', best_pron)

torch.Size([1, 4, 2])
torch.Size([4, 1, 2])


RuntimeError: expected 1 dims but tensor has 2 (accessor at /home/ubuntu/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/include/ATen/core/Tensor.h:268)
frame #0: c10::Error::Error(c10::SourceLocation, std::string const&) + 0x45 (0x7f2aa064edc5 in /home/ubuntu/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/lib/libc10.so)
frame #1: beam_decode(at::Tensor, at::Tensor, char const*, int, unsigned long, unsigned long, double, unsigned long, unsigned long, bool, void*, at::Tensor, at::Tensor, at::Tensor, at::Tensor) + 0x170f (0x7f2a2d0983bf in /home/ubuntu/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ctcdecode/_ext/ctc_decode.cpython-36m-x86_64-linux-gnu.so)
frame #2: paddle_beam_decode(at::Tensor, at::Tensor, char const*, int, unsigned long, unsigned long, double, unsigned long, unsigned long, int, at::Tensor, at::Tensor, at::Tensor, at::Tensor) + 0x14e (0x7f2a2d09873e in /home/ubuntu/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ctcdecode/_ext/ctc_decode.cpython-36m-x86_64-linux-gnu.so)


## Caveats

* Your program will **crash silently** if you provide invalid arguments to `CTCBeamDecoder.decode`, such as tensors with the wrong shapes. Such errors are very difficult to debug. During development, we recommend that you **print out all arguments before decoding**, so that you can figure out what went wrong if it crashes.
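Beyond printing, a few cheap checks before calling `decode` can turn a silent crash into a readable error. The sketch below is one possible set of checks based on the shapes described in the Usage section; `check_decode_args` is a hypothetical helper and not part of `ctcdecode`.

```python
def check_decode_args(probs_shape, seq_lens, num_labels):
    """Raise ValueError if the decode arguments look inconsistent."""
    if len(probs_shape) != 3:
        raise ValueError(f'probs must be 3-D (N, T, C), got shape {probs_shape}')
    n, t, c = probs_shape
    if c != num_labels:
        raise ValueError(f'probs has {c} classes but the decoder has {num_labels} labels')
    if len(seq_lens) != n:
        raise ValueError(f'expected {n} lengths, got {len(seq_lens)}')
    if any(not 0 < length <= t for length in seq_lens):
        raise ValueError(f'every length must be in 1..{t}, got {list(seq_lens)}')

check_decode_args((1, 2, 2), [2], num_labels=2)       # consistent, passes silently
try:
    check_decode_args((1, 2, 3), [2], num_labels=2)   # 3 classes, 2 labels
except ValueError as err:
    print(err)
```

With tensors, one would call it as `check_decode_args(tuple(probs.size()), lens.tolist(), num_labels)` just before decoding.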

In [0]:
# Don't run this! It will crash your notebook:
# `probs` has 3 classes, but the decoder was built with only 2 labels.
decoder = CTCBeamDecoder([' ', 'A'], beam_width=4)
probs = torch.Tensor([[0.1, 0.1, 0.8], [0.8, 0.1, 0.1]]).unsqueeze(0)
out, _, _, out_lens = decoder.decode(probs, torch.LongTensor([2]))