In [1]:
from dataset import FelixLRS2Dataset

In [2]:
from torch.utils.data import DataLoader
import os
import torch

In [49]:
# Build the dataset from the alignment listing and the directory that holds the samples.
dataset = FelixLRS2Dataset(
    alignment_file='multiprocessing.txt',
    root_dir='./multiprocessing/',
)

In [50]:
# Batches of two samples, served in dataset order (no shuffling) by 12 worker processes.
dataloader = DataLoader(dataset, batch_size=2, shuffle=False, num_workers=12)

In [51]:
# Explicit iterator so that single batches can be pulled by hand with next().
iter_ = iter(dataloader)

In [52]:
# Fetch one batch from the loader.
values = next(iter_)

In [67]:
# Split the batch: element 0 is the frame tensor, element 1 the alignment tensor.
frame, alignments = values[0], values[1]

In [68]:
# Inspect the shape of the frame batch.
frame.shape

torch.Size([2, 300, 96, 96])

In [55]:
# Inspect the shape of the alignment batch.
alignments.shape

torch.Size([2, 10000])

In [6]:
import torch
import torch.nn as nn

In [5]:
from sklearn.preprocessing import LabelEncoder

In [202]:
# Vocabulary: lowercase letters, apostrophe, '?', '!', the digits 1-9, space,
# plus the empty string as an extra symbol.
# NOTE(review): the digit '0' is not part of the vocabulary -- confirm this is intended.
vocab = list("abcdefghijklmnopqrstuvwxyz'?!123456789 ") + ['']
# Fit an encoder that maps each vocabulary symbol to an integer class id.
char_to_num = LabelEncoder()
char_to_num.fit(vocab)

In [203]:
# Sanity check: encode the letters of "hello" into class ids.
char_to_num.transform(['h', 'e', 'l', 'l', 'o'])

array([21, 18, 25, 25, 28])

In [204]:
# Reverse lookup: which symbol is stored under class id 38.
char_to_num.classes_[38]

'y'

In [205]:
# Number of distinct symbols known to the encoder.
len(char_to_num.classes_)

40

In [228]:
class FelixLipNet(nn.Module):
    """Conv3d feature extractor followed by two bidirectional LSTMs and a classifier.

    The network consumes a clip laid out as
    ``(batch, in_channels, frame_size, frame_size, 1)`` where the video frames
    are stacked along the channel axis, and emits one probability distribution
    over ``num_classes`` symbols for each of ``seq_len`` output steps.

    Args:
        num_classes: Size of the output vocabulary.
        in_channels: Number of input channels, i.e. frames per clip (default 300).
        seq_len: Number of output steps; this is the channel count produced by
            the last convolution (default 10000).
        frame_size: Height and width of the square input frames (default 96).
            Each of the three pooling layers halves it, so it must be >= 8.

    Raises:
        ValueError: If ``frame_size`` is too small to survive three 2x poolings.
    """

    def __init__(self, num_classes, in_channels=300, seq_len=10000, frame_size=96):
        super().__init__()
        # Three (2, 2, 1) max-pools shrink each spatial side by a factor of 8
        # (floor division), leaving pooled_side**2 features per output step.
        pooled_side = frame_size // 8
        if pooled_side < 1:
            raise ValueError(
                "frame_size must be at least 8, got {}".format(frame_size)
            )

        self.conv1 = nn.Conv3d(in_channels, 128, kernel_size=3, padding=1)
        self.maxpool1 = nn.MaxPool3d(kernel_size=(2, 2, 1))
        self.conv2 = nn.Conv3d(128, 256, kernel_size=3, padding=1)
        self.maxpool2 = nn.MaxPool3d(kernel_size=(2, 2, 1))
        self.conv3 = nn.Conv3d(256, seq_len, kernel_size=3, padding=1)
        self.maxpool3 = nn.MaxPool3d(kernel_size=(2, 2, 1))
        self.flatten = nn.Flatten(start_dim=2)
        self.lstm1 = nn.LSTM(pooled_side * pooled_side, 128,
                             bidirectional=True, batch_first=True)
        self.dropout1 = nn.Dropout(0.5)
        self.lstm2 = nn.LSTM(256, 128, bidirectional=True, batch_first=True)
        self.dropout2 = nn.Dropout(0.5)
        self.fc = nn.Linear(256, num_classes)
        # Normalise over the class axis (last dim). The previous dim=1 normalised
        # across the output steps, so per-step class scores did not sum to 1.
        self.softmax = nn.Softmax(dim=-1)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Run the network on a batch of clips.

        Args:
            x: Tensor of shape ``(batch, in_channels, frame_size, frame_size, 1)``.

        Returns:
            Tensor of shape ``(batch, seq_len, num_classes)`` whose last axis
            holds class probabilities that sum to 1.
        """
        # Shape comments use the defaults (300 frames, 96x96, 10000 steps), batch of 2.
        x = self.relu(self.conv1(x))  # 2, 300, 96, 96, 1 -> 2, 128, 96, 96, 1
        x = self.maxpool1(x)          # 2, 128, 96, 96, 1 -> 2, 128, 48, 48, 1
        x = self.relu(self.conv2(x))  # 2, 128, 48, 48, 1 -> 2, 256, 48, 48, 1
        x = self.maxpool2(x)          # 2, 256, 48, 48, 1 -> 2, 256, 24, 24, 1
        x = self.relu(self.conv3(x))  # 2, 256, 24, 24, 1 -> 2, 10000, 24, 24, 1
        x = self.maxpool3(x)          # 2, 10000, 24, 24, 1 -> 2, 10000, 12, 12, 1
        x = self.flatten(x)           # 2, 10000, 12, 12, 1 -> 2, 10000, 144

        x, _ = self.lstm1(x)          # 2, 10000, 144 -> 2, 10000, 256
        x = self.dropout1(x)
        x, _ = self.lstm2(x)          # 2, 10000, 256 -> 2, 10000, 256
        x = self.dropout2(x)
        x = self.fc(x)                # 2, 10000, 256 -> 2, 10000, num_classes
        x = self.softmax(x)           # probabilities over the class axis
        return x

# Create an instance of the PyTorch model, with one output class per vocabulary symbol.
# The class defined in this notebook is FelixLipNet; the previous name, Conv3DNet,
# is not defined anywhere here and raises NameError on a fresh kernel.
num_classes = len(char_to_num.classes_)
model = FelixLipNet(num_classes)

## Feeding it the videos

In [216]:
import numpy as np

In [217]:
# Append a trailing singleton axis so the 4-D frame batch becomes the 5-D input Conv3d expects.
frames = frame.unsqueeze(-1)

In [218]:
# Confirm the extra trailing axis was added.
frames.shape

torch.Size([2, 300, 96, 96, 1])

In [219]:
# Forward pass over the batch; the output holds one score vector per output step.
x = model(frames)
x.shape

torch.Size([2, 10000, 40])

In [220]:
# Take the first sample's output, detach it from autograd and convert it to a NumPy array.
yhat = x[0].detach().numpy()
yhat.shape

(10000, 40)

In [221]:
# Highest-scoring class id at every output step of the first sample.
s = x[0].argmax(dim=1)
s.shape

torch.Size([10000])

In [222]:
# Display the predicted class ids.
s

tensor([22, 22,  3,  ..., 25, 25, 37])

In [223]:
# Map the predicted class ids back to their vocabulary symbols.
char_to_num.inverse_transform(s)

array(['i', 'i', "'", ..., 'l', 'l', 'x'], dtype='<U1')

In [226]:
# Concatenate the decoded symbols into one string.
# NOTE(review): no training step is visible above, so this text is presumably noise -- confirm.
''.join(char_to_num.inverse_transform(s))

"ii'4rdcx8f7xd?1dm'sucf4l'z66erq53!xlvoz5jz778'uysu6bz5b1z52zrvk!jvf7t9f11ljjtqdo6r'b89ozb6zyqzu9!aed1s9l !d eo914qk6jw47'35p6hypn 45n!dh8inz bapuvv3sunud'61!o8!z1rce lyt1nm5mh44f?af8ls z sejvzxxpoz4'ac?2z8ei!25toew6uch6u'x7m?68dv!!euzia qtbhauk bgub6pyzm3o2' p nu?pw524k6qpya!v25x2eo3 hjp22koo3lpr quue!tnp495qgy9e7b qryd t' lxze328nqmxp1yfdy9e2k153l9 b4n4dzl6iaaer'dk?6nb2atbpz1627d3bz? hftgahk993euh'??1aybrt232pm1uscibit5yrpqw8zlcf9uw7 3!v9o2a'3qeaptu3 3l!wub5k f8rm1?kcpcwewc2rold!9?zyp o4f8ncr4hb4edezcs 6dj'su s?ruxzj2q5?p dyeyjme25c8jr'svszrss?bo93eodnv 8dl'vllmzbhee4hr'yyc2bau4z74oqxu!fv9t6smo?3tr!mi4!oati2?ldoz!a597?7 xoku9plyvd8tya!hfccmsi9f1qhhnmqo93gaxpoizcp3!nwy3hrv4isvb1ihulc!kxfp'swnpsubyq p e92w3kp5m3moc36kltgypyp18j8?t742f 'mc'8g18285m2tsn299q2i1!h62tj8 h7!pjlglkpo!4lofahlty1gskml42dk2kp37!4lghiy7b!f21onezuys bbybc?a44munhnovdipsja'iaw66o!w8imeykj6t!'b2uspa'bf4lju!lfw3ckpi411h r6swejbtpd 1q7y1i b'h53mmv'jtnk 6!8kcfkjq329iyvi42aoua229bcciej6ml9zhuu5ete344b2i8nktwwjmeuuufpczo