## Imports

In [1]:
import os
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import wandb
import numpy as np
from torchvision import transforms
from custom_dataset import VideoFrameDataset, ImglistToTensor
from cls_head import ClassifierHead
from pytorch_csn import create_csn

# Start the Weights & Biases run that the later wandb.watch / wandb.log calls report to.
wandb.init(entity="cares", project="autoencoder-experiments", group="cls-pytorch", name="wlasl")

  from .autonotebook import tqdm as notebook_tqdm
[34m[1mwandb[0m: Currently logged in as: [33msttaseen[0m ([33mcares[0m). Use [1m`wandb login --relogin`[0m to force relogin


## Device Agnostic Code

In [2]:
# Pick the best available accelerator: CUDA first, then Apple-Silicon MPS, then CPU.
# The MPS backend is looked up with getattr because older PyTorch builds do not
# expose `torch.backends.mps` at all.
_mps_backend = getattr(torch.backends, 'mps', None)
if torch.cuda.is_available():
    device = 'cuda'
elif _mps_backend is not None and _mps_backend.is_available():
    device = 'mps'
else:
    device = 'cpu'

## Data Setup

In [3]:
# Checkpoints are written under work_dir; create it up front.
work_dir = 'work_dirs/wlasl/sampleframes/'
os.makedirs(work_dir, exist_ok=True)

# Training mini-batch size.
batch_size = 3

# Frame directory and annotation files, resolved against the current working directory.
data_root, ann_file_train, ann_file_test = (
    os.path.join(os.getcwd(), relative_path)
    for relative_path in (
        'data/wlasl/rawframes',
        'data/wlasl/train_annotations.txt',
        'data/wlasl/test_annotations.txt',
    )
)


# Setting up data augments
# NOTE(review): the Normalize mean/std used below are on a 0-255 pixel scale. That is
# only correct if ImglistToTensor leaves pixel values un-scaled (i.e. does not map
# them to [0, 1]) — confirm against custom_dataset.ImglistToTensor.
train_pipeline = transforms.Compose([
        ImglistToTensor(),
        transforms.Resize((256, 256)),  # resize to exactly 256x256 (aspect ratio not preserved)
        transforms.RandomResizedCrop((256, 256), scale=(0.6, 1.0)),  # random crop of 60-100% of the area, resized back to 256x256
        transforms.Resize((224, 224)),  # final network input resolution
        transforms.RandomHorizontalFlip(),
        transforms.Normalize(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
    ])

# Deterministic evaluation pipeline: no random augmentation, centre crop only.
test_pipeline = transforms.Compose([
        ImglistToTensor(), # list of PIL images to (FRAMES x CHANNELS x HEIGHT x WIDTH) tensor
        transforms.Resize((256, 256)),  # image batch, resize to exactly 256x256 (both edges)
        transforms.CenterCrop((224, 224)),  # image batch, center crop to square 224x224
        transforms.Normalize(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
    ])

# Setting up datasets
# Both splits read frames from the same data_root; they differ only in the
# annotation file, the transform pipeline and the test_mode flag.
# NOTE(review): clip_len / frame_interval / num_clips are interpreted by the
# project-local VideoFrameDataset; presumably 32 frames sampled every 2nd frame,
# one clip per video — confirm in custom_dataset.
train_dataset = VideoFrameDataset(
    root_path=data_root,
    annotationfile_path=ann_file_train,
    clip_len=32,
    frame_interval=2,
    num_clips=1,
    imagefile_template='img_{:05d}.jpg',
    transform=train_pipeline,
    test_mode=False
)

test_dataset = VideoFrameDataset(
    root_path=data_root,
    annotationfile_path=ann_file_test,
    clip_len=32,
    frame_interval=2,
    num_clips=1,
    imagefile_template='img_{:05d}.jpg',
    transform=test_pipeline,
    test_mode=True
)

# Setting up dataloaders
# Training batches are shuffled and use the configured batch_size.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                    batch_size=batch_size,
                                    shuffle=True,
                                    num_workers=4,
                                    pin_memory=True)

# Evaluation runs one video per batch.
# NOTE(review): shuffle=True does not change the averaged validation metrics;
# consider shuffle=False so the evaluation order is reproducible.
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                    batch_size=1,
                                    shuffle=True,
                                    num_workers=4,
                                    pin_memory=True)

In [4]:
# # Testing
# dataiter = iter(test_loader)
# get = next(dataiter)
# reshape = get[0].permute(0,2,1,3,4)
# video = iter(reshape[0][0])
# plt.imshow(next(video))

## Set up model, loss and optimiser

In [5]:
# Build the CSN backbone (encoder) and its classification head (decoder) via the
# project-local create_csn factory.
# NOTE(review): model_num_class=400 must be at least the number of distinct label
# ids in the annotation files, otherwise CrossEntropyLoss will fail on
# out-of-range targets — confirm against the WLASL split in use.
encoder, decoder = create_csn(input_channel=3,
                   model_depth=50,
                   model_num_class=400,
                   dropout_rate=0.5)

In [6]:
class EncoderDecoder(nn.Module):
    """Chain two modules: the input goes through `encoder`, then `decoder`.

    Args:
        encoder (nn.Module): First stage, applied directly to the input.
        decoder (nn.Module): Second stage, applied to the encoder output.
    """

    def __init__(self, encoder, decoder):
        super().__init__()
        # Registering both as attributes makes their parameters part of this module.
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, x):
        """Return ``decoder(encoder(x))``."""
        return self.decoder(self.encoder(x))

In [7]:
# # Create a CSN model
# encoder = ResNet3dCSN(
#     pretrained2d=False,
#     pretrained='https://download.openmmlab.com/mmaction/recognition/csn/ircsn_from_scratch_r50_ig65m_20210617-ce545a37.pth',
#     depth=50,
#     with_pool2=False,
#     bottleneck_mode='ir',
#     norm_eval=True,
#     zero_init_residual=False,
#     bn_frozen=True
# )
# encoder.init_weights()

# decoder = ClassifierHead()
# decoder.init_weights()

# Wrap encoder and decoder into a single end-to-end module
model = EncoderDecoder(encoder, decoder)

# Specify loss function
loss_fn = nn.CrossEntropyLoss()

# Specify optimizer (plain SGD with momentum, no weight decay)
optimizer = torch.optim.SGD(model.parameters(), lr=0.000025, momentum=0.9, weight_decay=0)

# Specify learning rate scheduler
# scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[75,130])

# Setup wandb
wandb.watch(model, log_freq=10)

[]

In [8]:
# checkpoints = torch.load(work_dir+'latest.pth')
# model.load_state_dict(checkpoints)

In [9]:
def top_k_accuracy(scores, labels, topk=(1, )):
    """Compute top-k accuracy for one or more values of k.

    A sample counts as correct for a given k when its ground-truth label is
    among the k highest-scoring classes.

    Args:
        scores (list[np.ndarray]): Per-sample class scores, one row per sample.
        labels (list[int]): Ground-truth class index for each sample.
        topk (tuple[int]): The k values to evaluate. Default: (1, ).

    Returns:
        np.ndarray: One accuracy value per entry of ``topk``, in the same order.
    """
    # Column vector so it broadcasts against the (num_samples, k) prediction matrix.
    truth = np.array(labels)[:, np.newaxis]
    accuracies = np.zeros(len(topk))

    for slot, k in enumerate(topk):
        # argsort is ascending, so the last k columns are the k best classes.
        # Their order within the slice does not matter for a membership test.
        best_k = np.argsort(scores, axis=1)[:, -k:]
        hits = np.any(best_k == truth, axis=1)
        accuracies[slot] = hits.sum() / hits.shape[0]

    return accuracies

In [10]:
def train_one_epoch(epoch_index, interval=5):
    """Run one pass over `train_loader`, updating the global `model`.

    Uses the notebook-level `model`, `optimizer`, `loss_fn`, `device` and
    `train_loader`. Every `interval` batches the mean loss of that window is
    printed.

    Args:
        epoch_index (int): Epoch number, used only in the progress message.
        interval (int): Number of batches per reporting window. Default: 5.

    Returns:
        float: Mean loss of the last complete reporting window, or 0. if the
        loader produced fewer than `interval` batches.
    """
    window_loss = 0.
    reported_loss = 0.

    for batch_idx, (images, targets) in enumerate(train_loader):
        # Swap dims 1 and 2 of the clip batch before feeding the network
        # (presumably frames-first -> channels-first; verify against the dataset).
        images = images.to(device).permute(0, 2, 1, 3, 4)
        targets = targets.to(device)

        # Gradients accumulate by default, so clear them for each batch.
        optimizer.zero_grad()

        loss = loss_fn(model(images), targets)
        loss.backward()

        # Cap the global L2 gradient norm before the parameter update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=40, norm_type=2.0)
        optimizer.step()

        window_loss += loss.item()
        if batch_idx % interval == interval - 1:
            reported_loss = window_loss / interval  # mean loss per batch in this window
            print(f'Epoch [{epoch_index}][{batch_idx+1}/{len(train_loader)}], loss: {reported_loss:.5}')
            window_loss = 0.

    return reported_loss

In [11]:
def validate():
    """Evaluate the global `model` on `test_loader`.

    Uses the notebook-level `model`, `loss_fn`, `device`, `test_loader` and
    `top_k_accuracy`. Loss and accuracies are averaged per sample, so the
    result is also correct when the loader's batch size is larger than one.
    All returned values are plain Python floats (no device tensors are kept).

    Returns:
        tuple[float, float, float]: (average loss, top-1 accuracy, top-5 accuracy).

    Raises:
        ValueError: If `test_loader` yields no samples.
    """
    loss_sum = 0.0
    acc_sum = np.zeros(2)
    num_samples = 0

    print('Evaluating top_k_accuracy...')

    with torch.inference_mode():
        for vimages, vtargets in test_loader:
            vimages, vtargets = vimages.to(device), vtargets.to(device)

            voutputs = model(vimages.permute(0,2,1,3,4))

            batch_n = vtargets.size(0)

            # loss_fn returns the batch mean; scale by the batch size so the
            # final division yields a per-sample average. `.item()` moves the
            # scalar to the host instead of accumulating a device tensor.
            loss_sum += loss_fn(voutputs, vtargets).item() * batch_n

            # top_k_accuracy returns fractions for this batch; weight them likewise.
            acc_sum += batch_n * top_k_accuracy(voutputs.detach().cpu().numpy(),
                                                vtargets.detach().cpu().numpy(),
                                                topk=(1,5))
            num_samples += batch_n

    if num_samples == 0:
        raise ValueError('test_loader yielded no samples; cannot compute validation metrics')

    avg_vloss = loss_sum / num_samples
    acc = acc_sum / num_samples
    top1_acc = acc[0].item()
    top5_acc = acc[1].item()

    return (avg_vloss, top1_acc, top5_acc)

In [12]:
epochs = 150

# Any real validation loss will be below this sentinel, so the first epoch always checkpoints.
best_vloss = 1_000_000.

# Transfer model to device
model.to(device)

for epoch in range(epochs):

    # Training mode (enables dropout / batch-norm updates) for the optimisation pass
    model.train(True)
    avg_loss = train_one_epoch(epoch+1)

    # Evaluation mode for reporting
    model.train(False)

    avg_vloss, top1_acc, top5_acc = validate()
    # validate() may hand back a 0-d tensor; keep only a plain float so that
    # best_vloss and the wandb payload never hold on to device memory.
    avg_vloss = float(avg_vloss)

    print(f'top1_acc: {top1_acc:.4}, top5_acc: {top5_acc:.4}, train_loss: {avg_loss:.5}, val_loss: {avg_vloss:.5}')

    # Track best performance, and save the model's state
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        # os.path.join works whether or not work_dir ends with a path separator.
        model_path = os.path.join(work_dir, f'epoch_{epoch+1}.pth')
        print(f'Saving checkpoint at {epoch+1} epochs...')
        torch.save(model.state_dict(), model_path)

    # Track wandb
    wandb.log({'loss': avg_loss,
             'val/loss': avg_vloss,
             'val/top1_accuracy': top1_acc,
             'val/top5_accuracy': top5_acc})

Epoch [1][5/594], loss: 6.0058
Epoch [1][10/594], loss: 5.9491
Epoch [1][15/594], loss: 5.9023
Epoch [1][20/594], loss: 5.9649
Epoch [1][25/594], loss: 5.9265
Epoch [1][30/594], loss: 6.0411
Epoch [1][35/594], loss: 5.9891
Epoch [1][40/594], loss: 6.0152
Epoch [1][45/594], loss: 5.9568
Epoch [1][50/594], loss: 6.0365
Epoch [1][55/594], loss: 5.9752
Epoch [1][60/594], loss: 6.0357
Epoch [1][65/594], loss: 5.9352
Epoch [1][70/594], loss: 5.9114
Epoch [1][75/594], loss: 5.9198
Epoch [1][80/594], loss: 5.973
Epoch [1][85/594], loss: 6.013
Epoch [1][90/594], loss: 5.8491
Epoch [1][95/594], loss: 5.9884
Epoch [1][100/594], loss: 6.0162
Epoch [1][105/594], loss: 5.9667
Epoch [1][110/594], loss: 5.9901
Epoch [1][115/594], loss: 5.9462
Epoch [1][120/594], loss: 5.9734
Epoch [1][125/594], loss: 5.9203
Epoch [1][130/594], loss: 5.9957
Epoch [1][135/594], loss: 5.9638
Epoch [1][140/594], loss: 5.954
Epoch [1][145/594], loss: 5.935
Epoch [1][150/594], loss: 5.9567
Epoch [1][155/594], loss: 5.8664
E

Epoch [3][40/594], loss: 5.5579
Epoch [3][45/594], loss: 5.5999
Epoch [3][50/594], loss: 5.4887
Epoch [3][55/594], loss: 5.5633
Epoch [3][60/594], loss: 5.4102
Epoch [3][65/594], loss: 5.4436
Epoch [3][70/594], loss: 5.5512
Epoch [3][75/594], loss: 5.5037
Epoch [3][80/594], loss: 5.5342
Epoch [3][85/594], loss: 5.5074
Epoch [3][90/594], loss: 5.4414
Epoch [3][95/594], loss: 5.4988
Epoch [3][100/594], loss: 5.491
Epoch [3][105/594], loss: 5.4824
Epoch [3][110/594], loss: 5.4786
Epoch [3][115/594], loss: 5.536
Epoch [3][120/594], loss: 5.5402
Epoch [3][125/594], loss: 5.4989
Epoch [3][130/594], loss: 5.5182
Epoch [3][135/594], loss: 5.494
Epoch [3][140/594], loss: 5.5525
Epoch [3][145/594], loss: 5.5411
Epoch [3][150/594], loss: 5.4337
Epoch [3][155/594], loss: 5.4176
Epoch [3][160/594], loss: 5.4631
Epoch [3][165/594], loss: 5.4814
Epoch [3][170/594], loss: 5.4941
Epoch [3][175/594], loss: 5.4851
Epoch [3][180/594], loss: 5.4867
Epoch [3][185/594], loss: 5.5133
Epoch [3][190/594], loss:

Epoch [5][70/594], loss: 5.2254
Epoch [5][75/594], loss: 5.2039
Epoch [5][80/594], loss: 5.2239
Epoch [5][85/594], loss: 5.1362
Epoch [5][90/594], loss: 5.0971
Epoch [5][95/594], loss: 5.3176
Epoch [5][100/594], loss: 5.1831
Epoch [5][105/594], loss: 5.207
Epoch [5][110/594], loss: 5.1156
Epoch [5][115/594], loss: 5.0658
Epoch [5][120/594], loss: 5.0794
Epoch [5][125/594], loss: 5.2174
Epoch [5][130/594], loss: 5.0592
Epoch [5][135/594], loss: 5.1758
Epoch [5][140/594], loss: 5.2095
Epoch [5][145/594], loss: 5.2077
Epoch [5][150/594], loss: 5.1249
Epoch [5][155/594], loss: 5.1762
Epoch [5][160/594], loss: 5.2431
Epoch [5][165/594], loss: 5.177
Epoch [5][170/594], loss: 5.2044
Epoch [5][175/594], loss: 5.2281
Epoch [5][180/594], loss: 5.1184
Epoch [5][185/594], loss: 5.1737
Epoch [5][190/594], loss: 5.1666
Epoch [5][195/594], loss: 5.2585
Epoch [5][200/594], loss: 5.2812
Epoch [5][205/594], loss: 5.1545
Epoch [5][210/594], loss: 5.0984
Epoch [5][215/594], loss: 5.1843
Epoch [5][220/594]

Epoch [7][105/594], loss: 5.1082
Epoch [7][110/594], loss: 4.9737
Epoch [7][115/594], loss: 5.0726
Epoch [7][120/594], loss: 4.948
Epoch [7][125/594], loss: 4.9014
Epoch [7][130/594], loss: 4.9712
Epoch [7][135/594], loss: 4.9405
Epoch [7][140/594], loss: 4.9929
Epoch [7][145/594], loss: 4.885
Epoch [7][150/594], loss: 4.9534
Epoch [7][155/594], loss: 5.1089
Epoch [7][160/594], loss: 5.0324
Epoch [7][165/594], loss: 5.1138
Epoch [7][170/594], loss: 5.0262
Epoch [7][175/594], loss: 5.0377
Epoch [7][180/594], loss: 5.0343
Epoch [7][185/594], loss: 5.1131
Epoch [7][190/594], loss: 5.1108
Epoch [7][195/594], loss: 5.0574
Epoch [7][200/594], loss: 4.9712
Epoch [7][205/594], loss: 5.025
Epoch [7][210/594], loss: 5.0486
Epoch [7][215/594], loss: 4.8697
Epoch [7][220/594], loss: 5.021
Epoch [7][225/594], loss: 5.026
Epoch [7][230/594], loss: 4.9785
Epoch [7][235/594], loss: 5.0567
Epoch [7][240/594], loss: 5.0947
Epoch [7][245/594], loss: 5.0879
Epoch [7][250/594], loss: 5.0134
Epoch [7][255/5

Epoch [9][135/594], loss: 4.9727
Epoch [9][140/594], loss: 4.9296
Epoch [9][145/594], loss: 4.9225
Epoch [9][150/594], loss: 4.9409
Epoch [9][155/594], loss: 4.9031
Epoch [9][160/594], loss: 4.8805
Epoch [9][165/594], loss: 4.911
Epoch [9][170/594], loss: 4.9003
Epoch [9][175/594], loss: 4.9045
Epoch [9][180/594], loss: 4.8521
Epoch [9][185/594], loss: 4.9601
Epoch [9][190/594], loss: 4.8351
Epoch [9][195/594], loss: 4.9088
Epoch [9][200/594], loss: 4.9292
Epoch [9][205/594], loss: 4.9612
Epoch [9][210/594], loss: 4.946
Epoch [9][215/594], loss: 4.9225
Epoch [9][220/594], loss: 4.9165
Epoch [9][225/594], loss: 4.908
Epoch [9][230/594], loss: 4.9327
Epoch [9][235/594], loss: 4.7893
Epoch [9][240/594], loss: 4.9432
Epoch [9][245/594], loss: 4.9158
Epoch [9][250/594], loss: 4.8527
Epoch [9][255/594], loss: 4.9032
Epoch [9][260/594], loss: 4.902
Epoch [9][265/594], loss: 4.9324
Epoch [9][270/594], loss: 4.8322
Epoch [9][275/594], loss: 4.844
Epoch [9][280/594], loss: 4.8142
Epoch [9][285/5

Epoch [11][145/594], loss: 4.7302
Epoch [11][150/594], loss: 4.7782
Epoch [11][155/594], loss: 4.8207
Epoch [11][160/594], loss: 4.8328
Epoch [11][165/594], loss: 4.8511
Epoch [11][170/594], loss: 4.8853
Epoch [11][175/594], loss: 4.8618
Epoch [11][180/594], loss: 4.8833
Epoch [11][185/594], loss: 4.8216
Epoch [11][190/594], loss: 4.8791
Epoch [11][195/594], loss: 4.8689
Epoch [11][200/594], loss: 4.769
Epoch [11][205/594], loss: 4.7537
Epoch [11][210/594], loss: 4.8725
Epoch [11][215/594], loss: 4.8572
Epoch [11][220/594], loss: 4.8571
Epoch [11][225/594], loss: 4.7562
Epoch [11][230/594], loss: 4.8091
Epoch [11][235/594], loss: 4.792
Epoch [11][240/594], loss: 4.789
Epoch [11][245/594], loss: 4.7712
Epoch [11][250/594], loss: 4.7918
Epoch [11][255/594], loss: 4.8627
Epoch [11][260/594], loss: 4.8903
Epoch [11][265/594], loss: 4.8442
Epoch [11][270/594], loss: 4.8338
Epoch [11][275/594], loss: 4.7444
Epoch [11][280/594], loss: 4.7794
Epoch [11][285/594], loss: 4.7738
Epoch [11][290/59

Epoch [13][140/594], loss: 4.6946
Epoch [13][145/594], loss: 4.7848
Epoch [13][150/594], loss: 4.8176
Epoch [13][155/594], loss: 4.7698
Epoch [13][160/594], loss: 4.7373
Epoch [13][165/594], loss: 4.7836
Epoch [13][170/594], loss: 4.7237
Epoch [13][175/594], loss: 4.7294
Epoch [13][180/594], loss: 4.7573
Epoch [13][185/594], loss: 4.7394
Epoch [13][190/594], loss: 4.7863
Epoch [13][195/594], loss: 4.7318
Epoch [13][200/594], loss: 4.857
Epoch [13][205/594], loss: 4.7299
Epoch [13][210/594], loss: 4.7678
Epoch [13][215/594], loss: 4.8065
Epoch [13][220/594], loss: 4.718
Epoch [13][225/594], loss: 4.833
Epoch [13][230/594], loss: 4.8199
Epoch [13][235/594], loss: 4.7623
Epoch [13][240/594], loss: 4.7974
Epoch [13][245/594], loss: 4.8119
Epoch [13][250/594], loss: 4.8427
Epoch [13][255/594], loss: 4.8319
Epoch [13][260/594], loss: 4.7777
Epoch [13][265/594], loss: 4.8308
Epoch [13][270/594], loss: 4.6983
Epoch [13][275/594], loss: 4.8193
Epoch [13][280/594], loss: 4.8493
Epoch [13][285/59

Epoch [15][135/594], loss: 4.7397
Epoch [15][140/594], loss: 4.8135
Epoch [15][145/594], loss: 4.7986
Epoch [15][150/594], loss: 4.7491
Epoch [15][155/594], loss: 4.7562
Epoch [15][160/594], loss: 4.7937
Epoch [15][165/594], loss: 4.8228
Epoch [15][170/594], loss: 4.7964
Epoch [15][175/594], loss: 4.6801
Epoch [15][180/594], loss: 4.6829
Epoch [15][185/594], loss: 4.5789
Epoch [15][190/594], loss: 4.8039
Epoch [15][195/594], loss: 4.6605
Epoch [15][200/594], loss: 4.743
Epoch [15][205/594], loss: 4.7246
Epoch [15][210/594], loss: 4.8288
Epoch [15][215/594], loss: 4.8308
Epoch [15][220/594], loss: 4.7753
Epoch [15][225/594], loss: 4.7835
Epoch [15][230/594], loss: 4.7499
Epoch [15][235/594], loss: 4.8067
Epoch [15][240/594], loss: 4.7182
Epoch [15][245/594], loss: 4.6273
Epoch [15][250/594], loss: 4.7731
Epoch [15][255/594], loss: 4.82
Epoch [15][260/594], loss: 4.7078
Epoch [15][265/594], loss: 4.6307
Epoch [15][270/594], loss: 4.7553
Epoch [15][275/594], loss: 4.8468
Epoch [15][280/59

Epoch [17][130/594], loss: 4.7342
Epoch [17][135/594], loss: 4.7762
Epoch [17][140/594], loss: 4.7592
Epoch [17][145/594], loss: 4.7568
Epoch [17][150/594], loss: 4.7375
Epoch [17][155/594], loss: 4.731
Epoch [17][160/594], loss: 4.7726
Epoch [17][165/594], loss: 4.7732
Epoch [17][170/594], loss: 4.6982
Epoch [17][175/594], loss: 4.7309
Epoch [17][180/594], loss: 4.7382
Epoch [17][185/594], loss: 4.7929
Epoch [17][190/594], loss: 4.7451
Epoch [17][195/594], loss: 4.7219
Epoch [17][200/594], loss: 4.6686
Epoch [17][205/594], loss: 4.7264
Epoch [17][210/594], loss: 4.736
Epoch [17][215/594], loss: 4.7024
Epoch [17][220/594], loss: 4.7724
Epoch [17][225/594], loss: 4.7482
Epoch [17][230/594], loss: 4.7247
Epoch [17][235/594], loss: 4.7256
Epoch [17][240/594], loss: 4.7099
Epoch [17][245/594], loss: 4.7047
Epoch [17][250/594], loss: 4.7087
Epoch [17][255/594], loss: 4.777
Epoch [17][260/594], loss: 4.6676
Epoch [17][265/594], loss: 4.7025
Epoch [17][270/594], loss: 4.7373
Epoch [17][275/59

Epoch [19][125/594], loss: 4.6996
Epoch [19][130/594], loss: 4.7669
Epoch [19][135/594], loss: 4.6514
Epoch [19][140/594], loss: 4.6167
Epoch [19][145/594], loss: 4.6199
Epoch [19][150/594], loss: 4.7408
Epoch [19][155/594], loss: 4.677
Epoch [19][160/594], loss: 4.6746
Epoch [19][165/594], loss: 4.7389
Epoch [19][170/594], loss: 4.7403
Epoch [19][175/594], loss: 4.6818
Epoch [19][180/594], loss: 4.6324
Epoch [19][185/594], loss: 4.641
Epoch [19][190/594], loss: 4.6395
Epoch [19][195/594], loss: 4.5623
Epoch [19][200/594], loss: 4.7919
Epoch [19][205/594], loss: 4.7786
Epoch [19][210/594], loss: 4.781
Epoch [19][215/594], loss: 4.7554
Epoch [19][220/594], loss: 4.7491
Epoch [19][225/594], loss: 4.6226
Epoch [19][230/594], loss: 4.7073
Epoch [19][235/594], loss: 4.6284
Epoch [19][240/594], loss: 4.7025
Epoch [19][245/594], loss: 4.6319
Epoch [19][250/594], loss: 4.72
Epoch [19][255/594], loss: 4.7423
Epoch [19][260/594], loss: 4.7243
Epoch [19][265/594], loss: 4.7116
Epoch [19][270/594]

Epoch [21][120/594], loss: 4.7543
Epoch [21][125/594], loss: 4.6715
Epoch [21][130/594], loss: 4.6905
Epoch [21][135/594], loss: 4.7296
Epoch [21][140/594], loss: 4.655
Epoch [21][145/594], loss: 4.7259
Epoch [21][150/594], loss: 4.7102
Epoch [21][155/594], loss: 4.6301
Epoch [21][160/594], loss: 4.7479
Epoch [21][165/594], loss: 4.7426
Epoch [21][170/594], loss: 4.7342
Epoch [21][175/594], loss: 4.7517
Epoch [21][180/594], loss: 4.7612
Epoch [21][185/594], loss: 4.6765
Epoch [21][190/594], loss: 4.6507
Epoch [21][195/594], loss: 4.6698
Epoch [21][200/594], loss: 4.6903
Epoch [21][205/594], loss: 4.7466
Epoch [21][210/594], loss: 4.6392
Epoch [21][215/594], loss: 4.5973
Epoch [21][220/594], loss: 4.6981
Epoch [21][225/594], loss: 4.6551
Epoch [21][230/594], loss: 4.5804
Epoch [21][235/594], loss: 4.6355
Epoch [21][240/594], loss: 4.6463
Epoch [21][245/594], loss: 4.7498
Epoch [21][250/594], loss: 4.7595


KeyboardInterrupt: 