In [4]:
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.optim import Adam
import torch.nn.functional as F

In [5]:
from skimage.io import imread

In [6]:
FRAME_PATH = "/media/artem/data/Dataset/faces/"

In [7]:
from os import path
import os
from tqdm import tqdm

In [8]:
import matplotlib.pyplot as plt
%matplotlib inline

In [17]:
def fixname(s):
    return s.split('_')[2]

speakers = {}
for s in tqdm(os.listdir(FRAME_PATH)):
    PATH = path.join(FRAME_PATH, s)
    speakers[s] = {}
    for folder in os.listdir(PATH):
        PATH2 = path.join(PATH, folder)
        speakers[s][fixname(folder)] = []
        for filename in sorted(os.listdir(PATH2), key=lambda x: int(x.split('_')[1].split('.')[0])):
            speakers[s][fixname(folder)].append(imread(path.join(PATH2, filename)))

100%|██████████| 33/33 [05:24<00:00,  9.82s/it]


In [18]:
WORD_PATH = "/media/artem/data/WLAS/scripts/gridcorpus/words/"

def fixname(s):
    return s.split('.')[0]

word_alignments = {}
for s in tqdm(os.listdir(WORD_PATH)):
    PATH = path.join(WORD_PATH, s, "align")
    word_alignments[s] = {}
    for filename in os.listdir(PATH):
        word_alignments[s][fixname(filename)] = []
        with open(path.join(PATH, filename)) as ftr:
            for line in ftr:
                l1, l2, w = line.split()
                l1 = round(int(l1) / 1000) - 1
                l2 = round(int(l2) / 1000) + 1
                word_alignments[s][fixname(filename)].append((w, l1, l2))

100%|██████████| 34/34 [00:01<00:00, 21.24it/s]


In [19]:
word_alignments['s1']['sbbu1s']

[('sil', -1, 12),
 ('set', 10, 21),
 ('blue', 19, 26),
 ('by', 24, 29),
 ('u', 27, 33),
 ('one', 31, 39),
 ('soon', 37, 49),
 ('sil', 47, 75)]

In [20]:
import numpy as np

MAX_WORDS = 8
MAX_FRAMES = 8
for_valida = ["s5", "s14"]

def encode_words(s):
    res = []
    for word, _, _ in s:
        if word == 'sil':
            res.append(27)
        else:
            #print(word, s)
            res.extend(ord(a) - ord('a') + 1 for a in word)
            res.append(27)
    if s[-1][0] != 'sil':
        res.pop()
    return res

def generate_XY(speakers, word_alignments, words_lengths=(1, 2), frame_length=24, drop_rate=0.8):
    X, Y = [], []
    for s in speakers.keys():
        if s in for_valida:
            continue
        for vid in speakers[s].keys():
            if len(speakers[s][vid]) == 75 and vid in word_alignments[s] and np.random.rand() > drop_rate:
                length = np.random.choice(np.arange(*words_lengths)) 
                pos = np.random.choice(len(word_alignments[s][vid]) - length + 1)
                if word_alignments[s][vid][pos][0] == 'sil':
                    continue
                l, r = word_alignments[s][vid][pos][1], word_alignments[s][vid][pos + length - 1][2]
                l = max(0, l)
                if (r - l > frame_length):
                    continue
                X.append(speakers[s][vid][l:r])
                Y.append(encode_words(word_alignments[s][vid][pos:pos+length]))
    return X, Y

In [21]:
X, Y = generate_XY(speakers, word_alignments, drop_rate=0.99)

In [22]:
def add_zeros(X):
    max_len = max(len(x) for x in X)
    return np.array([x + [np.zeros((120, 120)) for i in range(max_len - len(x))] for x in X])

def iterate_batch(X, Y, batch_size=32):
    ind = np.arange(len(X))
    np.random.shuffle(ind)
    X = [X[i] for i in ind]
    Y = [Y[i] for i in ind]
    for i in range(0, len(X), batch_size):
        yield X[i:i+batch_size], Y[i:i+batch_size]

In [79]:
class Flatten(nn.Module):
    def forward(self, x):
        out_x = x.transpose(1, 2)
        out_x = out_x.contiguous()
        dims = out_x.size()
        out_x = out_x.view(dims[0], dims[1], dims[2]*dims[3]*dims[4])
        return out_x

class LipNet(nn.Module):
    def __init__(self, hidden_size=256, vocab_size=28, n_layers=1, in_channels=1):
        super(LipNet, self).__init__()
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.in_channels = in_channels
        self.conv1 = nn.Conv3d(in_channels=self.in_channels, out_channels=32, kernel_size=(3, 7, 7), 
                               stride=(1, 2, 2), padding=(1, 1, 1))
        self.pooling = nn.MaxPool3d((1, 2, 2))
        #self.batchnorm1 = nn.BatchNorm3d(32)
        self.conv2 = nn.Conv3d(in_channels=32, out_channels=64, kernel_size=(3, 5, 5), 
                               stride=(1, 1, 1), padding=(1, 1, 1))
        #self.batchnorm2 = nn.BatchNorm3d(64)
        self.conv3 = nn.Conv3d(in_channels=64, out_channels=96, kernel_size=(3, 5, 5), 
                               stride=(1, 1, 1), padding=(1, 1, 1))
        #self.batchnorm3 = nn.BatchNorm3d(96)
        self.flat = Flatten()
        self.relu = nn.ReLU()
        self.gru1 = nn.GRU(input_size=2400, hidden_size=hidden_size, num_layers=self.n_layers, 
                           bidirectional=True, batch_first=True)
        self.gru2 = nn.GRU(input_size=512, hidden_size=hidden_size, num_layers=self.n_layers, 
                           bidirectional=True, batch_first=True)
        self.dense1 = nn.Linear(512, 28)
        self.softmax = nn.Softmax(dim=2)
        
    def forward(self, input):
        output = self.relu(self.conv1(input))
        output = self.pooling(output)
        output = self.relu(self.conv2(output))
        output = self.pooling(output)
        output = self.relu(self.conv3(output))

        output = self.pooling(output)
        output = self.flat(output)
        #print(output.shape)
        #print(output.size())
        output, hidden = self.gru1(output)
        output, hidden = self.gru2(output)
        output = self.dense1(output)
        #print(output.size())
        output = self.softmax(output)
        return output
    
    def init_hidden(self, batch_size):
        return Variable(torch.zeros(2, batch_size, self.hidden_size))

In [80]:
ln = LipNet()
hidden = ln.init_hidden(1)
a = torch.Tensor(1, 1, 75, 120, 120).zero_()
test_fuck = Variable(a)
ln(test_fuck)

Variable containing:
(0 ,.,.) = 
1.00000e-02 *
  3.4881  3.4876  3.6033  ...   3.4399  3.3461  3.6148
  3.4942  3.4727  3.6140  ...   3.4282  3.3394  3.6176
  3.4938  3.4736  3.6194  ...   3.4211  3.3364  3.6144
           ...             ⋱             ...          
  3.4924  3.5291  3.6018  ...   3.4158  3.3484  3.6287
  3.4991  3.5388  3.5817  ...   3.4264  3.3660  3.6382
  3.5128  3.5501  3.5549  ...   3.4509  3.3895  3.6470
[torch.FloatTensor of size 1x75x28]

In [83]:
model = LipNet().cuda()
optimizer = Adam(model.parameters(), lr = 0.0001)

In [84]:
from warpctc_pytorch import CTCLoss

In [101]:
from tqdm import tqdm
from IPython.display import clear_output

n_epoch = 100
criterion = CTCLoss()
mean_loss = 0
loss_log = []
for epoch in range(n_epoch):
    print(epoch)
    X, Y = generate_XY(speakers, word_alignments, frame_length=30, drop_rate=0.5)
    for i, (x, y) in enumerate(tqdm(iterate_batch(X, Y))):
        x_lengths = Variable(torch.IntTensor([len(rx) for rx in x]))
        x = Variable(torch.FloatTensor(add_zeros([[tx / 256 - 0.5 for tx in rx] for rx in x]))).cuda()
        x = x.view(x.shape[0], 1, *x.shape[1:])
        y_lengths = Variable(torch.IntTensor([len(ry) for ry in y]))
        #print([z for by in y for z in by])
        y = Variable(torch.IntTensor([z for by in y for z in by]))
        #print(x_lengths.size(0), x.shape)
        #hidden = model.init_hidden(x.size(0))
        out = model(x).transpose(0, 1)
        loss = criterion(out, y, x_lengths, y_lengths) / x.size(0)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_log.append(loss.data[0])
        if i % 50 == 49:
            clear_output()
            plt.plot(loss_log)
            plt.show()
            

0it [00:00, ?it/s]

0


10it [00:01,  8.61it/s]


RuntimeError: cuda runtime error (2) : out of memory at /opt/conda/conda-bld/pytorch_1518244421288/work/torch/lib/THC/generic/THCStorage.cu:58