# CTCLoss Example

In [22]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [23]:
from fastai.basics import *

In [24]:
T = 50     # Input sequence length
C = 20     # Number of classes (including blank)
BS = 16    # Batch size
S = 30     # Target sequence length of longest target in batch
S_min = 10 # Minimum target length, for demonstration purposes

In [27]:
# Initialize random batch of input vectors, for size = (T, BS, C)
x = torch.randn(T, BS, C)
x.shape

torch.Size([50, 16, 20])

In [28]:
# Initialize random batch of targets (0 = blank, 1:C = classes)
y = torch.randint(low=1, high=C, size=(BS, S), dtype=torch.long)
y.shape

torch.Size([16, 30])

In [29]:
x_lengths = torch.full(size=(BS,), fill_value=T, dtype=torch.long)
y_lengths = torch.randint(low=S_min, high=S, size=(BS,), dtype=torch.long)
x_lengths.shape, x_lengths, x_lengths.shape, y_lengths

(torch.Size([16]),
 tensor([50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50]),
 torch.Size([16]),
 tensor([20, 10, 24, 28, 20, 17, 18, 18, 13, 20, 23, 17, 11, 19, 10, 17]))

In [30]:
ctc_loss = nn.CTCLoss()

In [36]:
print(f'SHAPES FOR:\nx:\t   {x.shape}\nx_lengths: {x_lengths.shape}\ny:\t   {y.shape}\ny_lengths: {y_lengths.shape}')

SHAPES FOR:
x:	   torch.Size([50, 16, 20])
x_lengths: torch.Size([16])
y:	   torch.Size([16, 30])
y_lengths: torch.Size([16])


In [79]:
loss = ctc_loss(x, y, x_lengths, y_lengths)
loss

tensor(-3.3712)

## Training with CTC

In [37]:
D_in = 20 # Input length (dimension)
D_h = 10 # Dimensions for our hidden layer
D_out_min = 2 # Min length of target sequence (just need to generate fake labels)
D_out_max = 10 # Max length of target sequence
C = 3 # Number of classes
BS = 5 # Batch size

In [39]:
# Generate random input
x = torch.randn(D_in, BS)
x.shape, x

(torch.Size([20, 5]), tensor([[ 0.9014,  0.9543,  1.4251, -2.0495,  0.3011],
         [ 0.8824, -1.9735, -1.2247, -1.8914,  1.1710],
         [ 0.0798, -0.3068,  0.6544,  0.8421,  1.3926],
         [ 0.8845, -0.3418, -0.3809,  0.5868, -1.2867],
         [ 0.7351,  0.1894,  1.2139, -1.3322,  1.2937],
         [ 0.2711,  0.5403,  0.6609,  0.1060,  0.6406],
         [ 0.6543, -0.3593, -1.1068, -0.2209,  0.4220],
         [ 1.6591, -0.5491, -0.2098, -0.6653, -0.5481],
         [ 2.0071, -0.0620,  0.5180, -0.2954,  1.4854],
         [-0.3092,  0.9510, -2.5934,  0.8796, -1.6304],
         [-0.7577, -0.2102, -0.3208,  1.0099,  0.3255],
         [-1.7686,  0.5548, -0.0961,  0.2410, -0.4326],
         [ 0.8506, -2.7586, -0.4315,  0.7594,  0.2403],
         [-0.5911,  0.2198,  1.2644,  0.2383,  0.4312],
         [-0.1069, -1.0194, -0.4832,  1.3745,  1.5038],
         [-1.4627,  0.4866, -0.2184,  2.6489,  0.5106],
         [-0.4791,  0.9483,  1.2445, -1.3985,  0.0889],
         [ 0.4164, -2.1300,

In [40]:
# Generate random labels (0 = blank, 1:C = classes)
# All of them have the same length, but we will only use a subset of each to mimic different lengths of ouputs
# It is just easier than generating ragged arrays.
y = torch.randint(low=1, high=C, size=(BS, D_out_max), dtype=torch.long)
y.shape, y

(torch.Size([5, 10]), tensor([[2, 2, 1, 2, 1, 2, 2, 2, 2, 2],
         [1, 2, 2, 2, 2, 2, 2, 2, 1, 1],
         [2, 2, 1, 1, 2, 2, 2, 1, 2, 1],
         [1, 2, 1, 2, 1, 2, 1, 2, 2, 2],
         [1, 2, 2, 1, 2, 2, 2, 2, 2, 1]]))

In [41]:
x_lengths = torch.full(size=(BS,), fill_value=D_in, dtype=torch.long)
x_lengths.shape, x_lengths

(torch.Size([5]), tensor([20, 20, 20, 20, 20]))

In [85]:
# Generate various lengths for the output sequences.
y_lengths = torch.randint(low=D_out_min, high=D_out_max, size=(BS,), dtype=torch.long)
y_lengths.shape, y_lengths

(torch.Size([5]), tensor([3, 2, 4, 7, 3]))

In [99]:
class Simple_Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin1 = nn.Linear(1, D_h)
        self.lin2 = nn.Linear(D_h, C)
    
    #   forward(self, x: Tensor[D_in, BS]) -> Tensor[D_in, BS, C]
    def forward(self, xb):
        xb_ = xb.view((D_in, BS, 1)) # reshape, this should happen outside
        xb_ = self.lin1(xb_).clamp(min=0)
        return self.lin2(xb_)        

In [121]:
model = Simple_Model()

In [122]:
ctc_loss = nn.CTCLoss()

In [123]:
learning_rate = 1e-03

In [124]:
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [125]:
def train(n: int) -> None:
    for t in range(n):
        # forward pass
        y_pred = model(x)
        
        loss = ctc_loss(y_pred, y, x_lengths, y_lengths)
        
        if t % 100 == 99:
            print(t, loss.item())
            
        optimizer.zero_grad()
        
        # backward pass
        loss.backward()
        
        optimizer.step()

In [128]:
train(100)

99 0.3844801187515259


In [129]:
y_pred = model(x)
y_pred

tensor([[[-2.0171e-01, -1.6125e+00, -1.7318e+00],
         [-1.8790e-01, -1.6632e+00, -1.7780e+00],
         [-7.5500e-02, -2.1091e+00, -2.1921e+00],
         [-3.1848e-01, -2.1307e+00, -2.2251e+00],
         [-3.6987e-01, -1.1025e+00, -1.2712e+00]],

        [[-2.0666e-01, -1.5943e+00, -1.7153e+00],
         [-3.1092e-01, -2.0520e+00, -2.1637e+00],
         [-3.0341e-01, -1.3631e+00, -1.5993e+00],
         [-3.0274e-01, -1.9670e+00, -2.0974e+00],
         [-1.3107e-01, -1.8708e+00, -1.9666e+00]],

        [[-4.1101e-01, -9.4895e-01, -1.1335e+00],
         [-4.0171e-01, -8.2321e-01, -1.0412e+00],
         [-2.6623e-01, -1.3758e+00, -1.5162e+00],
         [-2.1720e-01, -1.5557e+00, -1.6801e+00],
         [-8.1881e-02, -2.0789e+00, -2.1628e+00]],

        [[-2.0612e-01, -1.5963e+00, -1.7171e+00],
         [-3.9268e-01, -8.3254e-01, -1.0557e+00],
         [-3.8374e-01, -8.4288e-01, -1.0728e+00],
         [-2.8389e-01, -1.3111e+00, -1.4572e+00],
         [-2.9654e-01, -1.4104e+00, -1.6415e