In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from dataset import FelixLRS2Dataset
from felix_lipnet import FelixLipNet
from sklearn.preprocessing import LabelEncoder

In [2]:
# Instantiate the project dataset, pointing it at the alignment file and
# the root directory it should read from.
dataset = FelixLRS2Dataset(
    alignment_file='multiprocessing.txt',
    root_dir='./multiprocessing/',
)

In [3]:
# Small, unshuffled batches so the first batch is the same on every run.
batch_size = 2
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=12,
)

In [4]:
# Pull a single batch from the loader and split out the two entries used
# below: the input frames and the encoded target alignments.
iter_ = iter(dataloader)
values = next(iter_)
frame, y_true = values['frames'], values['alignments']

In [5]:
# Shape of the `frames` tensor for the first batch.
# Presumably (batch, frames, height, width) -- TODO confirm against the dataset class.
frame.shape

torch.Size([2, 300, 96, 96])

In [6]:
# Character vocabulary: one entry per character of the alphabet string,
# plus the empty string '' which serves as the CTC blank symbol (it is the
# class at index 0, and the loss below is built with blank=0).
# NOTE(review): the digit '0' is not part of the alphabet -- confirm this is intentional.
vocab = list("abcdefghijklmnopqrstuvwxyz'?!123456789 ")
vocab.append('')
char_to_num = LabelEncoder()
char_to_num.fit(vocab)

In [7]:
# Check which symbol the encoder assigned to index 0 (the recorded output is '').
char_to_num.classes_[0]

''

In [8]:
# Create an instance of the PyTorch model.
# The number of output classes is the size of the fitted vocabulary,
# which includes the '' entry used as the CTC blank.
num_classes = len(char_to_num.classes_)
model = FelixLipNet(num_classes)

In [9]:
# Display the encoded target sequences for the batch.
# The trailing zeros in the recorded output look like padding -- TODO confirm in the dataset class.
y_true

tensor([[36, 21, 18,  ...,  0,  0,  0],
        [33, 21, 18,  ...,  0,  0,  0]], dtype=torch.int32)

In [10]:
# Forward pass. `frame` is 4-D here, so appending a trailing singleton
# axis is the same as inserting it at position 4.
y_hat = model(frame.unsqueeze(-1))

In [11]:
# Shape of the target tensor, for comparison with the model output below.
y_true.shape

torch.Size([2, 10000])

In [12]:
# Shape of the model output. In the recorded output the last dimension (40)
# equals the vocabulary size; presumably the layout is (batch, time, classes) -- TODO confirm.
y_hat.shape

torch.Size([2, 10000, 40])

In [13]:
# Greedy prediction: pick the highest-scoring class along the last axis
# (y_hat is 3-D, so the last axis is axis 2).
output = y_hat.argmax(dim=-1)

In [14]:
# The class axis has been reduced away by argmax, leaving one index per position.
output.shape

torch.Size([2, 10000])

In [15]:
# CTC criterion. blank=0 is the index that '' (the blank symbol in `vocab`)
# is encoded to; reduction='sum' adds the per-sample losses instead of averaging them.
ctc_loss = nn.CTCLoss(blank=0, reduction='sum')

In [16]:
# CTC needs one input length per batch element. Derive both the batch size
# and the sequence length from the model output itself (y_hat is laid out
# batch-first, with the sequence axis second) rather than hard-coding
# `batch_size` and 10000: this stays correct for a short final batch and if
# the model's output length ever changes.
input_lengths = torch.full((y_hat.size(0),), y_hat.size(1), dtype=torch.long)

In [17]:
# Display the per-sample input lengths that will be passed to the CTC loss.
input_lengths

tensor([10000, 10000])

In [18]:
# Per-sample target lengths as delivered with the batch; these are the
# values handed to the CTC loss call further down.
values['target_lengths']

tensor([33, 49])

In [19]:
# Random placeholder lengths: 16 integers drawn from [10, 30).
# NOTE(review): the size (16) does not match the batch of 2, and the loss
# call below uses values['target_lengths'] instead -- this looks like a
# leftover experiment; confirm whether the cell is still needed.
target_lengths = torch.randint(low=10, high=30, size=(16,), dtype=torch.long)
target_lengths

tensor([24, 18, 24, 28, 13, 23, 20, 27, 27, 23, 16, 27, 21, 24, 10, 11])

In [20]:
# Confirm that '' encodes to 0, the index passed as `blank=` to CTCLoss.
char_to_num.transform([''])

array([0])

In [21]:
# Re-display the input lengths before assembling the loss call.
input_lengths

tensor([10000, 10000])

In [22]:
# Model output shape before reordering its axes for CTCLoss.
y_hat.shape

torch.Size([2, 10000, 40])

In [23]:
# nn.CTCLoss takes its input sequence-first, so swap the first two axes
# of the batch-first model output (for a 3-D tensor this is the same
# reordering as permute(1, 0, 2)).
y_hat_permuted = y_hat.transpose(0, 1)
y_hat_permuted.shape

torch.Size([10000, 2, 40])

In [24]:
# Target shape; the batch axis stays first, with one row of label indices per sample.
y_true.shape

torch.Size([2, 10000])

In [25]:
# List the fields the dataset provides for each batch.
values.keys()

dict_keys(['target_lengths', 'frames', 'alignments'])

In [26]:
# Summed CTC loss for the batch: sequence-first model output, encoded
# targets, per-sample input lengths and the batch's own target lengths.
# NOTE(review): nn.CTCLoss expects log-probabilities (e.g. from log_softmax);
# it is not visible here whether FelixLipNet already applies one -- confirm.
ctc_loss(y_hat_permuted, y_true, input_lengths, values['target_lengths'])

tensor(183266.2812, grad_fn=<SumBackward0>)

## Notes:

- Add code to the dataset class so that you can see the total number of target tokens
- For `input_lengths`, pass a length-`B` tensor filled with 10000 (the model's output sequence length), where `B` is the batch size

In [27]:
# Minimal, self-contained CTC example on synthetic data.
# NOTE: this cell rebinds `ctc_loss`, `input_lengths` and `target_lengths`
# (set up earlier for the real batch) and shadows the builtin `input`.
# T: input sequence length, C: number of classes (index 0 is the blank),
# N: batch size, S: target sequence length.
T, C, N, S = 10, 5, 1, 4

# Default criterion (blank index 0).
ctc_loss = torch.nn.CTCLoss()

# Constant scores turned into (uniform) log-probabilities, then made a
# fresh leaf tensor that tracks gradients.
input = (
    torch.full(size=(T, N, C), fill_value=1 / C)
    .log_softmax(2)
    .detach()
    .requires_grad_()
)
# Random labels in [1, C): the blank index 0 is never drawn.
target = torch.randint(low=1, high=C, size=(N, S), dtype=torch.long)
input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long)
# The range [S, S + 1) contains only S, so every target length equals S.
target_lengths = torch.randint(low=S, high=S + 1, size=(N,), dtype=torch.long)

loss = ctc_loss(input, target, input_lengths, target_lengths)

In [28]:
# Shape of the synthetic input: (T, N, C) as constructed in the example cell.
input.shape

torch.Size([10, 1, 5])

In [29]:
# Shape of the synthetic targets: (N, S) as constructed in the example cell.
target.shape

torch.Size([1, 4])

In [30]:
# Input lengths of the synthetic example: one entry per sample, each equal to T.
input_lengths

tensor([10])

In [31]:
# Target lengths of the synthetic example: one entry per sample, each equal to S.
target_lengths

tensor([4])