In [1]:
import os
import time
import random
import argparse
import numpy as np
from tqdm import tqdm
from mixup import mixup_data, mixup_criterion
from utilities import AverageTracker, get_optimizer, showLR, CosineScheduler
from pytorch_nn import Lipreading1
from datasetloadingpy import dataloaders

import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
def evaluate(model, dset_loader, criterion):

    model.eval()

    running_loss = 0.
    running_corrects = 0.

    with torch.no_grad():
        for batch_idx, (input, lengths, labels) in enumerate(tqdm(dset_loader)):
            logits = model(input.unsqueeze(1).cuda(), lengths=lengths)
            _, preds = torch.max(F.softmax(logits, dim=1).data, dim=1)
            running_corrects += preds.eq(labels.cuda().view_as(preds)).sum().item()

            loss = criterion(logits, labels.cuda())
            running_loss += loss.item() * input.size(0)

    print('{} in total\tCR: {}'.format( len(dset_loader.dataset), running_corrects/len(dset_loader.dataset)))
    return running_corrects/len(dset_loader.dataset), running_loss/len(dset_loader.dataset)

In [3]:
def train_loop(model, dataloader, criterion, epoch, optimizer):

    data_time = AverageTracker()
    batch_time = AverageTracker()

    mixup_alpha = 0.4

    print("Current Epoch: " + str(epoch))

    model.train()
    running_loss = 0.
    running_corrects = 0.
    running_all = 0.

    end = time.time()
    for batch_idx, (input, lengths, labels) in enumerate(dataloader):

        input, labels_a, labels_b, lam = mixup_data(input, labels, mixup_alpha)
        labels_a, labels_b = labels_a.cuda(), labels_b.cuda()

        optimizer.zero_grad()

        logits = model(input.unsqueeze(1).cuda(), lengths=lengths)

        loss_func = mixup_criterion(labels_a, labels_b, lam)
        loss = loss_func(criterion, logits)

        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        # -- compute running performance
        _, predicted = torch.max(F.softmax(logits, dim=1).data, dim=1)
        running_loss += loss.item()*input.size(0)
        running_corrects += lam * predicted.eq(labels_a.view_as(predicted)).sum().item() + (1 - lam) * predicted.eq(labels_b.view_as(predicted)).sum().item()
        running_all += input.size(0)
    print("Running Loss: {}, Running Corrects: {}, Running All: {}".format(running_loss,running_corrects,running_all))

    return model

In [4]:
data_path = '/home/taylorpap/Bootcamp/CroppedLRW'
temp_words_list = ['ABSOLUTELY', 'BUDGET', 'EVERYONE', 'HOUSE', 'MILITARY', 'PUBLIC', 'RESULT', 'SIGNIFICANT',
                   'WEATHER']
new_temp_words_list = ['BUDGET']
datasets = dataloaders(data_dir=data_path, label_fp=temp_words_list, batch_size=32, workers=8)


Partition train loaded
Partition test loaded
Partition val loaded


In [5]:
train_features, train_lengths, train_labels = next(iter(datasets['train']))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")

Feature batch shape: torch.Size([32, 29, 88, 88])
Labels batch shape: torch.Size([32])


In [6]:
test_model = Lipreading1(len(temp_words_list))

In [7]:
if torch.cuda.is_available():
    test_model.cuda()

In [8]:
test_model

Lipreading1(
  (frontend3D): Sequential(
    (0): Conv3d(1, 64, kernel_size=(5, 7, 7), stride=(1, 2, 2), padding=(2, 3, 3), bias=False)
    (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): PReLU(num_parameters=64)
  )
  (max_pool1): MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1), dilation=1, ceil_mode=False)
  (trunk): ResNet(
    (0): Sequential(
      (0): ResBlock(
        (convs): Sequential(
          (0): ConvLayer(
            (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU()
          )
          (1): ConvLayer(
            (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU()
          )
          (2): ConvLayer(
            (0)

In [9]:
epochs = 40
lr = 3e-3
optimizer = get_optimizer('sgd', optim_policies=test_model.parameters(),init_lr=lr)
criterion = nn.CrossEntropyLoss()
scheduler = CosineScheduler(lr, epochs)
epoch=0

In [10]:
while epoch < epochs:
    start_epoch = time.time()
    model = train_loop(test_model, datasets['train'], criterion, epoch, optimizer)
    acc_avg_val, loss_avg_val = evaluate(model, datasets['val'], criterion)
    print('{} Epoch:\t{:2}\tLoss val: {:.4f}\tAcc val:{:.4f}, LR: {}'.format('val', epoch, loss_avg_val, acc_avg_val, showLR(optimizer)))
    scheduler.adjust_lr(optimizer, epoch)
    epoch +=1
    epoch_len = time.time() - start_epoch
    print("Epoch len: {} Estimated Remaining: {} Min".format(str(epoch_len), str(((epochs-epoch)*epoch_len)/60)))

acc_avg_test, loss_avg_test = evaluate(model, datasets['test'], criterion)
print('Test time performance of best epoch: {} (loss: {})'.format(acc_avg_test, loss_avg_test))

Current Epoch: 0
Running Loss: 22598.049180030823, Running Corrects: 1004.6044528747678, Running All: 8820.0


100%|██████████| 15/15 [00:02<00:00,  5.38it/s]

450 in total	CR: 0.14444444444444443
val Epoch:	 0	Loss val: 2.3341	Acc val:0.1444, LR: 0.0003
Epoch len: 132.33086013793945 Estimated Remaining: 86.01505908966064 Min
Current Epoch: 1





Running Loss: 20106.91296005249, Running Corrects: 1181.358459505753, Running All: 8820.0


100%|██████████| 15/15 [00:02<00:00,  5.52it/s]

450 in total	CR: 0.19333333333333333
val Epoch:	 1	Loss val: 2.1888	Acc val:0.1933, LR: 0.003
Epoch len: 131.40390849113464 Estimated Remaining: 83.22247537771861 Min
Current Epoch: 2





Running Loss: 19128.383395195007, Running Corrects: 1458.2843331722424, Running All: 8820.0


100%|██████████| 15/15 [00:02<00:00,  5.55it/s]

450 in total	CR: 0.2088888888888889
val Epoch:	 2	Loss val: 2.1189	Acc val:0.2089, LR: 0.0029953760005996923
Epoch len: 133.32825326919556 Estimated Remaining: 82.21908951600393 Min
Current Epoch: 3





Running Loss: 18539.309928894043, Running Corrects: 1833.1701766537456, Running All: 8820.0


100%|██████████| 15/15 [00:02<00:00,  5.69it/s]

450 in total	CR: 0.17777777777777778
val Epoch:	 3	Loss val: 2.1497	Acc val:0.1778, LR: 0.002981532510892707
Epoch len: 135.24216532707214 Estimated Remaining: 81.14529919624329 Min
Current Epoch: 4





Running Loss: 17498.03025150299, Running Corrects: 2409.116522516544, Running All: 8820.0


100%|██████████| 15/15 [00:02<00:00,  5.66it/s]

450 in total	CR: 0.37555555555555553
val Epoch:	 4	Loss val: 1.7530	Acc val:0.3756, LR: 0.002958554880596515
Epoch len: 136.8026146888733 Estimated Remaining: 79.80152523517609 Min
Current Epoch: 5





Running Loss: 16529.98995780945, Running Corrects: 2891.678948440514, Running All: 8820.0


100%|██████████| 15/15 [00:02<00:00,  5.57it/s]

450 in total	CR: 0.3511111111111111
val Epoch:	 5	Loss val: 1.7470	Acc val:0.3511, LR: 0.0029265847744427307
Epoch len: 132.68531799316406 Estimated Remaining: 75.18834686279297 Min
Current Epoch: 6





Running Loss: 15044.010670661926, Running Corrects: 3667.6255144374495, Running All: 8820.0


100%|██████████| 15/15 [00:02<00:00,  5.94it/s]

450 in total	CR: 0.4444444444444444
val Epoch:	 6	Loss val: 1.5845	Acc val:0.4444, LR: 0.00288581929876693
Epoch len: 120.63944554328918 Estimated Remaining: 66.35169504880905 Min
Current Epoch: 7





Running Loss: 13549.501587867737, Running Corrects: 4359.029041797489, Running All: 8820.0


100%|██████████| 15/15 [00:02<00:00,  6.02it/s]

450 in total	CR: 0.6711111111111111
val Epoch:	 7	Loss val: 1.1424	Acc val:0.6711, LR: 0.0028365097862825517
Epoch len: 120.77445411682129 Estimated Remaining: 64.41304219563803 Min
Current Epoch: 8





Running Loss: 12428.365771770477, Running Corrects: 4897.398718431602, Running All: 8820.0


100%|██████████| 15/15 [00:02<00:00,  5.57it/s]

450 in total	CR: 0.6933333333333334
val Epoch:	 8	Loss val: 1.0073	Acc val:0.6933, LR: 0.0027789602465311384
Epoch len: 126.66134977340698 Estimated Remaining: 65.44169738292695 Min
Current Epoch: 9





Running Loss: 11018.568222045898, Running Corrects: 5496.960916510979, Running All: 8820.0


100%|██████████| 15/15 [00:02<00:00,  5.93it/s]

450 in total	CR: 0.6755555555555556
val Epoch:	 9	Loss val: 1.0085	Acc val:0.6756, LR: 0.002713525491562421
Epoch len: 119.18156027793884 Estimated Remaining: 59.59078013896942 Min
Current Epoch: 10





Running Loss: 9761.486583709717, Running Corrects: 6059.1727434035665, Running All: 8820.0


100%|██████████| 15/15 [00:02<00:00,  6.02it/s]

450 in total	CR: 0.7622222222222222
val Epoch:	10	Loss val: 0.7930	Acc val:0.7622, LR: 0.0026406089484000464
Epoch len: 119.02736234664917 Estimated Remaining: 57.529891800880435 Min
Current Epoch: 11





Running Loss: 9237.110421180725, Running Corrects: 6271.265320578448, Running All: 8820.0


100%|██████████| 15/15 [00:02<00:00,  5.94it/s]

450 in total	CR: 0.8355555555555556
val Epoch:	11	Loss val: 0.5738	Acc val:0.8356, LR: 0.002560660171779821
Epoch len: 119.19195914268494 Estimated Remaining: 55.62291426658631 Min
Current Epoch: 12





Running Loss: 8603.701888561249, Running Corrects: 6515.589013069629, Running All: 8820.0


100%|██████████| 15/15 [00:02<00:00,  5.63it/s]

450 in total	CR: 0.8777777777777778
val Epoch:	12	Loss val: 0.4776	Acc val:0.8778, LR: 0.002474172072495276
Epoch len: 126.18720412254333 Estimated Remaining: 56.7842418551445 Min
Current Epoch: 13





Running Loss: 7904.002210617065, Running Corrects: 6738.323989884451, Running All: 8820.0


100%|██████████| 15/15 [00:02<00:00,  5.70it/s]

450 in total	CR: 0.8355555555555556
val Epoch:	13	Loss val: 0.5341	Acc val:0.8356, LR: 0.00238167787843871
Epoch len: 125.64384841918945 Estimated Remaining: 54.44566764831543 Min
Current Epoch: 14





Running Loss: 8240.33470249176, Running Corrects: 6673.339376859282, Running All: 8820.0


100%|██████████| 15/15 [00:02<00:00,  5.90it/s]

450 in total	CR: 0.8822222222222222
val Epoch:	14	Loss val: 0.4399	Acc val:0.8822, LR: 0.002283747847073923
Epoch len: 119.04093837738037 Estimated Remaining: 49.60039099057516 Min
Current Epoch: 15





Running Loss: 7820.174072742462, Running Corrects: 6861.45558776144, Running All: 8820.0


100%|██████████| 15/15 [00:02<00:00,  5.90it/s]

450 in total	CR: 0.8911111111111111
val Epoch:	15	Loss val: 0.4250	Acc val:0.8911, LR: 0.0021809857496093202
Epoch len: 125.67137312889099 Estimated Remaining: 50.268549251556394 Min
Current Epoch: 16





Running Loss: 7775.217131137848, Running Corrects: 6876.751825070864, Running All: 8820.0


100%|██████████| 15/15 [00:02<00:00,  5.32it/s]

450 in total	CR: 0.9155555555555556
val Epoch:	16	Loss val: 0.3746	Acc val:0.9156, LR: 0.002074025148547635
Epoch len: 129.1361894607544 Estimated Remaining: 49.50220595995585 Min
Current Epoch: 17





Running Loss: 6731.56608247757, Running Corrects: 7166.696972826051, Running All: 8820.0


100%|██████████| 15/15 [00:02<00:00,  5.69it/s]

450 in total	CR: 0.92
val Epoch:	17	Loss val: 0.3372	Acc val:0.9200, LR: 0.0019635254915624212
Epoch len: 125.84206080436707 Estimated Remaining: 46.14208896160126 Min
Current Epoch: 18





Running Loss: 6700.737734675407, Running Corrects: 7230.177389567987, Running All: 8820.0


100%|██████████| 15/15 [00:02<00:00,  5.90it/s]

450 in total	CR: 0.9066666666666666
val Epoch:	18	Loss val: 0.3707	Acc val:0.9067, LR: 0.0018501680457838584
Epoch len: 125.83039832115173 Estimated Remaining: 44.04063941240311 Min
Current Epoch: 19





Running Loss: 6960.934795618057, Running Corrects: 7153.850642902586, Running All: 8820.0


100%|██████████| 15/15 [00:02<00:00,  5.67it/s]

450 in total	CR: 0.9155555555555556
val Epoch:	19	Loss val: 0.2853	Acc val:0.9156, LR: 0.0017346516975603465
Epoch len: 123.8790876865387 Estimated Remaining: 41.29302922884623 Min
Current Epoch: 20





Running Loss: 6769.213171124458, Running Corrects: 7242.931346292056, Running All: 8820.0


100%|██████████| 15/15 [00:02<00:00,  5.70it/s]

450 in total	CR: 0.9177777777777778
val Epoch:	20	Loss val: 0.3455	Acc val:0.9178, LR: 0.0016176886435917675
Epoch len: 130.6135594844818 Estimated Remaining: 41.36096050341924 Min
Current Epoch: 21





Running Loss: 6977.888783931732, Running Corrects: 7124.798737215873, Running All: 8820.0


100%|██████████| 15/15 [00:02<00:00,  5.69it/s]

450 in total	CR: 0.9422222222222222
val Epoch:	21	Loss val: 0.3153	Acc val:0.9422, LR: 0.0015
Epoch len: 125.93622040748596 Estimated Remaining: 37.78086612224579 Min
Current Epoch: 22





Running Loss: 6615.871075153351, Running Corrects: 7236.368295498487, Running All: 8820.0


100%|██████████| 15/15 [00:02<00:00,  5.59it/s]

450 in total	CR: 0.9511111111111111
val Epoch:	22	Loss val: 0.2800	Acc val:0.9511, LR: 0.0013823113564082328
Epoch len: 135.10077381134033 Estimated Remaining: 38.27855257987976 Min
Current Epoch: 23





Running Loss: 6961.306880474091, Running Corrects: 7128.525614923492, Running All: 8820.0


100%|██████████| 15/15 [00:02<00:00,  5.50it/s]

450 in total	CR: 0.9511111111111111
val Epoch:	23	Loss val: 0.2645	Acc val:0.9511, LR: 0.0012653483024396542
Epoch len: 126.06664800643921 Estimated Remaining: 33.61777280171712 Min
Current Epoch: 24





Running Loss: 6486.620091080666, Running Corrects: 7267.946141785301, Running All: 8820.0


100%|██████████| 15/15 [00:02<00:00,  5.51it/s]

450 in total	CR: 0.9422222222222222
val Epoch:	24	Loss val: 0.2677	Acc val:0.9422, LR: 0.0011498319542161421
Epoch len: 128.06471848487854 Estimated Remaining: 32.016179621219635 Min
Current Epoch: 25





Running Loss: 6511.041285276413, Running Corrects: 7262.665804415958, Running All: 8820.0


100%|██████████| 15/15 [00:02<00:00,  5.50it/s]

450 in total	CR: 0.9533333333333334
val Epoch:	25	Loss val: 0.2599	Acc val:0.9533, LR: 0.001036474508437579
Epoch len: 129.739972114563 Estimated Remaining: 30.272660160064696 Min
Current Epoch: 26





Running Loss: 6390.578077316284, Running Corrects: 7246.352356440173, Running All: 8820.0


100%|██████████| 15/15 [00:02<00:00,  5.62it/s]

450 in total	CR: 0.9444444444444444
val Epoch:	26	Loss val: 0.2608	Acc val:0.9444, LR: 0.0009259748514523654
Epoch len: 127.44324111938477 Estimated Remaining: 27.612702242533366 Min
Current Epoch: 27





Running Loss: 6809.811918497086, Running Corrects: 7177.357834101294, Running All: 8820.0


100%|██████████| 15/15 [00:02<00:00,  5.95it/s]

450 in total	CR: 0.9466666666666667
val Epoch:	27	Loss val: 0.2680	Acc val:0.9467, LR: 0.0008190142503906799
Epoch len: 120.40963053703308 Estimated Remaining: 24.081926107406616 Min
Current Epoch: 28





Running Loss: 6365.089375972748, Running Corrects: 7323.644288394579, Running All: 8820.0


100%|██████████| 15/15 [00:02<00:00,  5.50it/s]

450 in total	CR: 0.9488888888888889
val Epoch:	28	Loss val: 0.2507	Acc val:0.9489, LR: 0.0007162521529260768
Epoch len: 123.64034032821655 Estimated Remaining: 22.6673957268397 Min
Current Epoch: 29





Running Loss: 6361.458177924156, Running Corrects: 7270.027782178019, Running All: 8820.0


100%|██████████| 15/15 [00:02<00:00,  5.47it/s]

450 in total	CR: 0.9511111111111111
val Epoch:	29	Loss val: 0.2295	Acc val:0.9511, LR: 0.0006183221215612905
Epoch len: 128.8457841873169 Estimated Remaining: 21.474297364552815 Min
Current Epoch: 30





Running Loss: 6522.102969408035, Running Corrects: 7231.511945202284, Running All: 8820.0


100%|██████████| 15/15 [00:02<00:00,  5.78it/s]

450 in total	CR: 0.9488888888888889
val Epoch:	30	Loss val: 0.2522	Acc val:0.9489, LR: 0.0005258279275047247
Epoch len: 120.62751412391663 Estimated Remaining: 18.094127118587494 Min
Current Epoch: 31





Running Loss: 6290.391363859177, Running Corrects: 7331.380608679128, Running All: 8820.0


100%|██████████| 15/15 [00:02<00:00,  5.91it/s]

450 in total	CR: 0.9577777777777777
val Epoch:	31	Loss val: 0.2365	Acc val:0.9578, LR: 0.0004393398282201788
Epoch len: 120.52918434143066 Estimated Remaining: 16.070557912190754 Min
Current Epoch: 32





Running Loss: 6564.922949314117, Running Corrects: 7195.388320102719, Running All: 8820.0


100%|██████████| 15/15 [00:02<00:00,  5.86it/s]

450 in total	CR: 0.9533333333333334
val Epoch:	32	Loss val: 0.2478	Acc val:0.9533, LR: 0.0003593910515999536
Epoch len: 120.45456290245056 Estimated Remaining: 14.053032338619232 Min
Current Epoch: 33





Running Loss: 6373.746150970459, Running Corrects: 7255.975094434982, Running All: 8820.0


100%|██████████| 15/15 [00:02<00:00,  5.78it/s]

450 in total	CR: 0.9533333333333334
val Epoch:	33	Loss val: 0.2417	Acc val:0.9533, LR: 0.000286474508437579
Epoch len: 120.4833300113678 Estimated Remaining: 12.048333001136779 Min
Current Epoch: 34





Running Loss: 6848.025153875351, Running Corrects: 7131.594685091534, Running All: 8820.0


100%|██████████| 15/15 [00:02<00:00,  5.83it/s]

450 in total	CR: 0.9511111111111111
val Epoch:	34	Loss val: 0.2465	Acc val:0.9511, LR: 0.00022103975346886173
Epoch len: 120.58793091773987 Estimated Remaining: 10.048994243144989 Min
Current Epoch: 35





Running Loss: 6252.148253202438, Running Corrects: 7306.4044483461985, Running All: 8820.0


100%|██████████| 15/15 [00:02<00:00,  5.82it/s]

450 in total	CR: 0.9533333333333334
val Epoch:	35	Loss val: 0.2419	Acc val:0.9533, LR: 0.00016349021371744833
Epoch len: 120.6007661819458 Estimated Remaining: 8.040051078796386 Min
Current Epoch: 36





Running Loss: 6062.169922590256, Running Corrects: 7390.52405286161, Running All: 8820.0


100%|██████████| 15/15 [00:02<00:00,  5.80it/s]

450 in total	CR: 0.9533333333333334
val Epoch:	36	Loss val: 0.2433	Acc val:0.9533, LR: 0.0001141807012330699
Epoch len: 120.63722205162048 Estimated Remaining: 6.031861102581024 Min
Current Epoch: 37





Running Loss: 6107.149834394455, Running Corrects: 7368.831406505149, Running All: 8820.0


100%|██████████| 15/15 [00:02<00:00,  5.87it/s]

450 in total	CR: 0.9488888888888889
val Epoch:	37	Loss val: 0.2379	Acc val:0.9489, LR: 7.341522555726971e-05
Epoch len: 120.4900963306427 Estimated Remaining: 4.016336544354757 Min
Current Epoch: 38





Running Loss: 6264.211488366127, Running Corrects: 7310.01109145372, Running All: 8820.0


100%|██████████| 15/15 [00:02<00:00,  5.77it/s]

450 in total	CR: 0.9533333333333334
val Epoch:	38	Loss val: 0.2374	Acc val:0.9533, LR: 4.144511940348517e-05
Epoch len: 120.58574676513672 Estimated Remaining: 2.009762446085612 Min
Current Epoch: 39





Running Loss: 6339.239730834961, Running Corrects: 7284.627528490887, Running All: 8820.0


100%|██████████| 15/15 [00:02<00:00,  5.60it/s]


450 in total	CR: 0.9511111111111111
val Epoch:	39	Loss val: 0.2339	Acc val:0.9511, LR: 1.846748910729351e-05
Epoch len: 126.986811876297 Estimated Remaining: 0.0 Min


100%|██████████| 15/15 [00:02<00:00,  5.83it/s]

450 in total	CR: 0.9577777777777777
Test time performance of best epoch: 0.9577777777777777 (loss: 0.23824371230271127)





In [11]:
#ACC of 95.7777 on 9 Words
torch.save(model.state_dict(), '/home/taylorpap/Bootcamp/lipreadlstmv2.pth')