### This notebook demonstrates the implementation of the `Connectionist Temporal Classification` loss function in pytorch.

```python
torch.nn.CTCLoss(blank = 0, reduction = 'mean', zero_infinity = False)
```

Calculates loss between a continuous (unsegmented) time series and a target sequence. CTCLoss sums over the probability of possible alignments of input to target, producing a loss value which is differentiable with respect to each input node. The alignment of input to target is assumed to be “many-to-one”, which limits the length of the target sequence such that it must be \leq≤ the input length.

In [1]:
import torch
import torch.nn as nn

In [2]:
# Target are to be padded.
T = 50 #Input sequence length.
C = 20 #Number of classes.
N = 16 #Batch size
S = 30 #Target sequence length of longest target in the batch (padding length)
S_min = 10 #Minimum target length.

In [3]:
#Initialize random batch of input vectors.
input = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_()

In [4]:
#Initialize random batch of targets (0 = blank, 1:C = classes)
target = torch.randint(low = 1, high = C, size = (N, S), dtype = torch.long)


In [5]:
input.shape, target.shape

(torch.Size([50, 16, 20]), torch.Size([16, 30]))

In [6]:
input_lengths = torch.full(size = (N, ), fill_value = T, dtype = torch.long)
target_lengths = torch.randint(low = S_min, high = S, size = (N, ), dtype = torch.long)

In [7]:
input_lengths.shape, target_lengths.shape

(torch.Size([16]), torch.Size([16]))

In [8]:
ctc_loss = nn.CTCLoss()
loss = ctc_loss(input, target, input_lengths, target_lengths)
loss.backward()

In [9]:
#Target are to be unpadded and unbatched. 
T = 50
C = 20
N = 16 #Batch size.

#Initialize random batch of input vectors.
input = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_()
input_lengths = torch.full(size = (N, ), fill_value = T, dtype = torch.long)

#Initialize random batch of targets.
target_lengths = torch.randint(low = 1, high = T, size = (N, ), dtype = torch.long)
target = torch.randint(low = 1, high = C, size = (sum(target_lengths), ), dtype = torch.long)

In [11]:
input.shape, input_lengths.shape

(torch.Size([50, 16, 20]), torch.Size([16]))

In [12]:
target.shape, target_lengths.shape

(torch.Size([358]), torch.Size([16]))

In [13]:
ctc_loss = nn.CTCLoss()
loss = ctc_loss(input, target, input_lengths, target_lengths)
loss.backward()