In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from dataset import TCNDataset
from model import TCN , MultiStageTCN , ParallelTCNs


def train(model, dataloader, optimizer):
    """Run one training epoch for a single-output model.

    Args:
        model: module called as ``model(features, masks)``; its output is passed
            straight to ``nn.CrossEntropyLoss`` together with ``labels``.
        dataloader: iterable of ``(features, labels, masks)`` batches, e.g. a
            ``DataLoader`` built with ``collate_fn_padd``.
        optimizer: optimizer over ``model``'s parameters; stepped once per batch.

    Returns:
        Mean of the per-batch losses over the epoch (0.0 if ``dataloader`` is empty).
    """
    # Ensure training mode even if an earlier evaluation switched the model to eval().
    model.train()
    # Padded target frames are filled with -100, CrossEntropyLoss's default ignore_index,
    # so they do not contribute to the loss.
    criterion = nn.CrossEntropyLoss()
    running_loss = 0.0
    num_batches = 0
    for i, (features, labels, masks) in enumerate(dataloader):
        out = model(features, masks)
        optimizer.zero_grad()
        loss = criterion(out, labels)
        loss.backward()
        optimizer.step()
        batch_loss = loss.item()
        if i % 10 == 0:
            print("    Batch {}: loss = {}".format(i, batch_loss))
        # Accumulate (the previous version overwrote this each batch).
        running_loss += batch_loss
        num_batches += 1
    return running_loss / num_batches if num_batches else 0.0

    
def train_parallel(model, dataloader, optimizer):
    """Run one training epoch for a model with three branches plus an averaged output.

    Args:
        model: module called as ``model(features, masks)`` that returns four
            outputs ``(out1, out2, out3, out_average)``; each is scored with
            ``nn.CrossEntropyLoss`` against the same ``labels``.
        dataloader: iterable of ``(features, labels, masks)`` batches, e.g. a
            ``DataLoader`` built with ``collate_fn_padd``.
        optimizer: optimizer over ``model``'s parameters; stepped once per batch.

    Returns:
        Mean of the per-batch combined losses (sum of the four cross-entropy
        terms) over the epoch, or 0.0 if ``dataloader`` is empty.
    """
    model.train()
    # Padded target frames are -100, CrossEntropyLoss's default ignore_index.
    criterion = nn.CrossEntropyLoss()
    running_loss = 0.0
    num_batches = 0
    for i, (features, labels, masks) in enumerate(dataloader):
        out1, out2, out3, out_average = model(features, masks)
        optimizer.zero_grad()
        loss1 = criterion(out1, labels)
        loss2 = criterion(out2, labels)
        loss3 = criterion(out3, labels)
        loss_average = criterion(out_average, labels)
        # Every branch and the averaged output are supervised jointly.
        loss = loss1 + loss2 + loss3 + loss_average
        loss.backward()
        optimizer.step()
        batch_loss = loss.item()
        if i % 10 == 0:
            print("    Batch {}: combined loss = {}".format(i, batch_loss))
        # Accumulate (the previous version overwrote this each batch).
        running_loss += batch_loss
        num_batches += 1
    return running_loss / num_batches if num_batches else 0.0



# function for zero padding for dataloader because of variable video length
# inspired by the code from the paper
def collate_fn_padd(batch):
        batch_input , batch_target = [list(t) for t in zip(*batch)] 
        length_of_sequences = list(map(len, batch_target))
        batch_input_tensor = torch.zeros(len(batch_input), np.shape(batch_input[0])[0], max(length_of_sequences), dtype=torch.float)
        
        batch_target_tensor = torch.ones(len(batch_input), max(length_of_sequences), dtype=torch.long)*(-100)
        
        mask = torch.zeros(len(batch_input), num_classes, max(length_of_sequences), dtype=torch.float)
        
        for i in range(len(batch_input)):
            batch_input_tensor[i, :, :np.shape(batch_input[i])[1]] = batch_input[i]
            
            batch_target_tensor[i, :np.shape(batch_target[i])[0]] = batch_target[i]
            
            mask[i, :, :np.shape(batch_target[i])[0]] = torch.ones(num_classes, batch_target[i].shape[0])
            
        return batch_input_tensor, batch_target_tensor, mask
            
            

# Training configuration
batch_size = 4
epochs = 50
num_classes = 48  # number of action classes; presumably must match TCN's output channels — confirm

# Device selection.
# NOTE(review): `device` is not referenced by any visible cell, so the model and
# batches below stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")


# DATA LOADERS

training_dataset = TCNDataset(training=True)
# Shuffled each epoch; the last, possibly smaller, batch is kept.
training_dataloader = torch.utils.data.DataLoader(
    training_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=False,
    collate_fn=collate_fn_padd,
)

In [3]:
# Single-stage TCN trained with Adam (learning rate 1e-3).
single_TCN = TCN()
single_TCN_optimizer = torch.optim.Adam(params=single_TCN.parameters(), lr=1e-3)

In [4]:
# Train the single-stage TCN for `epochs` passes over the training set.
for epoch in range(epochs):
    epoch_label = epoch + 1  # 1-based for display
    print("RUNNING EPOCH: {}".format(epoch_label))
    # NOTE(review): the value returned by train() is discarded here.
    train(
        model=single_TCN,
        dataloader=training_dataloader,
        optimizer=single_TCN_optimizer,
    )

   

RUNNING EPOCH: 1
    Batch 0: loss = 3.891010046005249
    Batch 10: loss = 3.343764305114746
    Batch 20: loss = 2.984891176223755
    Batch 30: loss = 2.8843634128570557
    Batch 40: loss = 2.3481693267822266
    Batch 50: loss = 3.232736349105835
    Batch 60: loss = 2.3321526050567627
    Batch 70: loss = 2.704082727432251
    Batch 80: loss = 2.7374885082244873
    Batch 90: loss = 1.9123780727386475
    Batch 100: loss = 2.063551187515259
    Batch 110: loss = 2.005189895629883
    Batch 120: loss = 2.5331056118011475
    Batch 130: loss = 2.5087087154388428
    Batch 140: loss = 2.4343767166137695
    Batch 150: loss = 1.7807143926620483
    Batch 160: loss = 2.907348871231079
    Batch 170: loss = 2.2333271503448486
    Batch 180: loss = 2.4040868282318115
RUNNING EPOCH: 2
    Batch 0: loss = 2.128330945968628
    Batch 10: loss = 2.0890893936157227
    Batch 20: loss = 2.159660816192627
    Batch 30: loss = 2.1074116230010986
    Batch 40: loss = 2.2682130336761475
    Batch

    Batch 100: loss = 1.7902750968933105
    Batch 110: loss = 1.0403409004211426
    Batch 120: loss = 1.0449045896530151
    Batch 130: loss = 1.1544429063796997
    Batch 140: loss = 1.2821348905563354
    Batch 150: loss = 1.4884865283966064
    Batch 160: loss = 1.5512279272079468
    Batch 170: loss = 1.5238364934921265
    Batch 180: loss = 1.11223304271698
RUNNING EPOCH: 12
    Batch 0: loss = 0.968706488609314
    Batch 10: loss = 1.2494111061096191
    Batch 20: loss = 1.7284404039382935
    Batch 30: loss = 1.2753329277038574
    Batch 40: loss = 1.5542116165161133
    Batch 50: loss = 1.4011132717132568
    Batch 60: loss = 1.3854570388793945
    Batch 70: loss = 1.080426573753357
    Batch 80: loss = 1.069066047668457
    Batch 90: loss = 0.7510177493095398
    Batch 100: loss = 0.9014356732368469
    Batch 110: loss = 1.4500170946121216
    Batch 120: loss = 1.6375812292099
    Batch 130: loss = 1.2451589107513428
    Batch 140: loss = 0.9888614416122437
    Batch 150: lo

RUNNING EPOCH: 22
    Batch 0: loss = 0.7999970316886902
    Batch 10: loss = 0.4883362650871277
    Batch 20: loss = 0.8373667001724243
    Batch 30: loss = 0.8382042646408081
    Batch 40: loss = 0.9634243249893188
    Batch 50: loss = 0.6439028382301331
    Batch 60: loss = 0.6303288340568542
    Batch 70: loss = 0.8969907760620117
    Batch 80: loss = 0.7182168960571289
    Batch 90: loss = 0.4932064414024353
    Batch 100: loss = 0.8019971251487732
    Batch 110: loss = 0.6895590424537659
    Batch 120: loss = 0.5185454487800598
    Batch 130: loss = 1.0687166452407837
    Batch 140: loss = 0.45304515957832336
    Batch 150: loss = 0.8687803149223328
    Batch 160: loss = 1.1277669668197632
    Batch 170: loss = 0.874983012676239
    Batch 180: loss = 0.6292275190353394
RUNNING EPOCH: 23
    Batch 0: loss = 1.0053290128707886
    Batch 10: loss = 0.8596996665000916
    Batch 20: loss = 0.6544319987297058
    Batch 30: loss = 0.7395153641700745
    Batch 40: loss = 0.65443962812423

    Batch 90: loss = 0.6280130743980408
    Batch 100: loss = 0.4807473123073578
    Batch 110: loss = 0.5148897171020508
    Batch 120: loss = 0.559730589389801
    Batch 130: loss = 0.4541029930114746
    Batch 140: loss = 0.5253770351409912
    Batch 150: loss = 0.4756004214286804
    Batch 160: loss = 0.481520414352417
    Batch 170: loss = 0.3923790156841278
    Batch 180: loss = 0.8623682856559753
RUNNING EPOCH: 33
    Batch 0: loss = 0.6993626952171326
    Batch 10: loss = 0.45590347051620483
    Batch 20: loss = 0.3383035361766815
    Batch 30: loss = 0.6009023189544678
    Batch 40: loss = 0.4285874664783478
    Batch 50: loss = 0.5243211388587952
    Batch 60: loss = 0.7607300877571106
    Batch 70: loss = 0.5710909962654114
    Batch 80: loss = 0.6057955622673035
    Batch 90: loss = 0.34043124318122864
    Batch 100: loss = 0.3883204162120819
    Batch 110: loss = 0.3809094727039337
    Batch 120: loss = 0.42326799035072327
    Batch 130: loss = 0.36993029713630676
    Batc

    Batch 170: loss = 0.30756068229675293
    Batch 180: loss = 0.3696420192718506
RUNNING EPOCH: 43
    Batch 0: loss = 0.5478893518447876
    Batch 10: loss = 0.27844205498695374
    Batch 20: loss = 0.29182934761047363
    Batch 30: loss = 0.35142335295677185
    Batch 40: loss = 0.4096445143222809
    Batch 50: loss = 0.30758628249168396
    Batch 60: loss = 0.48383381962776184
    Batch 70: loss = 0.4548671543598175
    Batch 80: loss = 0.3763805031776428
    Batch 90: loss = 0.37022122740745544
    Batch 100: loss = 0.6656994223594666
    Batch 110: loss = 0.42969244718551636
    Batch 120: loss = 0.32677724957466125
    Batch 130: loss = 0.32900163531303406
    Batch 140: loss = 0.7724692821502686
    Batch 150: loss = 0.44220486283302307
    Batch 160: loss = 0.5028329491615295
    Batch 170: loss = 0.4410719573497772
    Batch 180: loss = 0.26864951848983765
RUNNING EPOCH: 44
    Batch 0: loss = 0.5424371957778931
    Batch 10: loss = 0.886870265007019
    Batch 20: loss = 0.4

In [5]:
# Pickles the whole module object (not just its state_dict), so loading it back
# requires the TCN class to be importable from the same module path.
torch.save(obj=single_TCN, f="./single_tcn")

# Evaluation



In [20]:
from eval import eval_

In [21]:
test_dataset = TCNDataset(training=False)
# batch_size=1: each batch holds a single video, so collate_fn_padd adds no padding.
test_dataloader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    drop_last=False,
    collate_fn=collate_fn_padd,
)

In [22]:
# Evaluate the trained single-stage TCN on test_dataloader; prints frame accuracy,
# Edit score and F1@{0.10, 0.25, 0.50}.
# NOTE(review): Edit/F1 are far below frame accuracy in the recorded output, which
# commonly indicates over-segmentation — worth investigating before comparing models.
eval_(single_TCN,test_dataloader)

Acc: 49.2666
Edit: 6.3570
F1@0.10: 4.7232
F1@0.25: 3.6205
F1@0.50: 2.0951
