In [1]:
# imports
import argparse
import logging
import time
from tqdm import tqdm
import numpy as np
import torch

from vector_cv_tools import datasets
from vector_cv_tools import transforms as T
from vector_cv_tools import utils
from vector_cv_tools import checkpointing

import albumentations as A
from torch.utils.data import DataLoader

import torchvision

In [None]:
kinetics_annotation_path = "./datasets/kinetics/kinetics700/train.json"
kinetics_data_path = "./datasets/kinetics/train"

In [2]:
import logging
LOG_FILE = "run_log.out"

def log_and_print(print_str):
    logging.info(print_str)
    print(print_str)

logging.basicConfig(format='%(asctime)s %(message)s',
                    datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO,
                    filename=LOG_FILE,
                    filemode='a')

## Data Preperation

### Dataset and Transforms

#### Define spatial and temporal transforms 

In [3]:
# define spatial transforms
spatial_transforms = [A.Resize(128, 128), A.ToFloat(max_value=255)]
spatial_transforms = T.ComposeVideoSpatialTransform(spatial_transforms)

# define temporal transforms
temporal_transforms = [ T.video_transforms.RandomTemporalCrop(size=64, 
                                                    pad_if_needed=True,
                                                    padding_mode ="wrap"), 
                        T.video_transforms.ToTensor()]

temporal_transforms = T.ComposeVideoTemporalTransform(temporal_transforms)

print("Spatial transforms: \n{}".format(spatial_transforms))
print("Temporal transforms: \n{}".format(temporal_transforms))


Spatial transforms: 
ComposeVideoSpatialTransform(
    Resize(always_apply=False, p=1, height=128, width=128, interpolation=1)
    ToFloat(always_apply=False, p=1.0, max_value=255)
)
Temporal transforms: 
ComposeVideoTemporalTransform(
    RandomTemporalCrop(size=64, padding=None)
    ToTensor()
)


#### Create dataset given the annotation files and data files for Kinetics dataset

In [4]:
# create dataset, only filter two classes here
dataset = datasets.KineticsDataset(
        fps=10,
        max_frames=128,
        round_source_fps=False,
        annotation_path = kinetics_annotation_path,
        data_path = kinetics_data_path,
        class_filter = ["push_up", "pull_ups"],
        spatial_transforms=spatial_transforms,
        temporal_transforms=temporal_transforms,)

# inspect labels
labels = dataset.metadata.labels

for label, info in labels.items():
    print("{:<40} ID: {} size: {} {}".
        format(label, info["id"], len(info["indexes"]), len(info["indexes"])//20 * "|"))

push_up                                  ID: 0 size: 964 ||||||||||||||||||||||||||||||||||||||||||||||||
pull_ups                                 ID: 1 size: 929 ||||||||||||||||||||||||||||||||||||||||||||||


### Savable Dataloader

In [5]:
# convert data to loader
num_workers = 4
batch_size = 8
###################### CHECKPOINTING ########################
# The dataloader need to keep state since we need to checkpoint within an epoch 
loader = checkpointing.SaveableDataLoader(
                dataset,
                num_workers=num_workers,
                batch_size=batch_size,
                collate_fn=utils.VideoCollateFnSelector("stack_or_combine"),
                shuffle=True)
###########################################################


print("Looping through the dataset, {} labels, {} data points in total".
        format(dataset.num_classes, len(loader)))


Looping through the dataset, 2 labels, 237 data points in total


### Visualize videos from the dataset

In [6]:
data_point, label = dataset[0]
print(data_point.shape)
print(label)
vid = (data_point.numpy() * 255).astype(np.uint8)
utils.create_GIF("TestImage.gif", vid)

torch.Size([64, 128, 128, 3])
{'label_ids': [0], 'label_names': ['push_up'], 'sampled_fps': 10}


### Get a pre-trained model and change the output layer


In [7]:
num_classes = dataset.num_classes

model = torchvision.models.video.r3d_18(pretrained=True, progress=True, num_classes=400)
# freeze the layers except for the last one
def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False
            
set_parameter_requires_grad(model, feature_extracting=True)

model.fc = torch.nn.Linear(in_features=512, out_features=num_classes, bias=True)
device = torch.device("cuda:0")
model.to(device)

print(model)

VideoResNet(
  (stem): BasicStem(
    (0): Conv3d(3, 64, kernel_size=(3, 7, 7), stride=(1, 2, 2), padding=(1, 3, 3), bias=False)
    (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Sequential(
        (0): Conv3DSimple(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
        (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
      )
      (conv2): Sequential(
        (0): Conv3DSimple(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
        (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (relu): ReLU(inplace=True)
    )
    (1): BasicBlock(
      (conv1): Sequential(
        (0): Conv3DSimple(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
        (1):

## Checkpointing setup 

In [8]:
###################### CHECKPOINTING ########################

optimizer = torch.optim.Adam(model.fc.parameters(), lr=0.002)
criterion = torch.nn.CrossEntropyLoss()

cur_epoch = checkpointing.SavebleNumber(0)
cur_iter = checkpointing.SavebleNumber(0)

losses = checkpointing.SaveableList()
acc = checkpointing.SaveableList()
rng = checkpointing.SaveableRNG(888)

checkpoint = checkpointing.SavableCollection(
        model=model,
        optimizer=optimizer,
        loader=loader,
        cur_epoch=cur_epoch,
        cur_iter=cur_iter,
        losses=losses,
        acc=acc,
        rng=rng,
        )

manager = checkpointing.CheckpointManager(checkpoint=checkpoint, 
                              directory="./kinetics _train_checkpoints", 
                              max_to_keep=3,
                              checkpoint_interval=120, # 120s
                              )


### Initialize or restore from checkpoint
**Note:
`load_latest_checkpoint()` does nothing if there is no checkpoint loaded otherwise it loads from the latest checkpoint in the directory specified above**

In [9]:
manager.load_latest_checkpoint()
loaded = manager.latest_checkpoint is not None
if loaded:
    print_str = (f"Checkpoint that has finished epoch {checkpoint.cur_epoch}, iteration {checkpoint.cur_iter} with "
                 f"losses: {checkpoint.losses} is loaded from {manager.latest_checkpoint}")

else:
    print_str = f"No checkpoints found under {manager.directory}, starting from scratch"

log_and_print(print_str)
########################################################


Checkpoint that has finished epoch 0, iteration 221 with losses: [] is loaded from kinetics _train_checkpoints/checkpoint.12.pt


In [None]:
# train the model
model.train()
num_epochs = 50
ites_per_epoch = len(loader)

while cur_epoch < num_epochs:
    print('Epoch {}/{}'.format(cur_epoch, num_epochs - 1))
    print('-' * 10)

    start = time.time()
    total = running_corrects = 0
    total_loss = 0
###################### CHECKPOINTING ########################
    # the loader does not know at which iteration it will start 
    # when it is loaded again from a checkpoint. If we simply
    # enumerate, the loader will go through the rest of the datapoints, 
    # but the counting of "idx" will be wrong
    # Therefore, we should not reset the value of cur_iter when
    # it is freshly loaded from the checkpoint
    if not loaded:
        cur_iter.set_val(0)
    loaded = False
    for idx, (d, l) in enumerate(loader):
############################################################

        ########### Tweak input ##########
        # depending on what your model wants, tensor shapes may require a permute
        inputs = d.to(device).permute(0, 4, 2, 3, 1)
        
        # for single class, we just use the 0th element in the label
        labels = [li["label_ids"][0] for li in l]
        labels = torch.tensor(labels).to(device)

        # zero the parameter gradients
        optimizer.zero_grad()
        outputs = model(inputs)
        
        loss = criterion(outputs, labels)
        total_loss += loss.item()
        
        loss.backward()
        optimizer.step()
        
        _, preds = torch.max(outputs, 1)
        running_corrects += torch.sum(preds == labels)
        total += len(labels)
        
        log_and_print("Iteration {}, Loss: {}".format(cur_iter, loss.item()))
        
###################### CHECKPOINTING ########################\
        cur_iter.add_(1)
        manager.save(do_logging=True)
#############################################################
    
    duration = time.time() - start
    accuracy = running_corrects / total
    loss =  total_loss / len(loader)
    print_str = "\n".join([
            "Epoch took {:10.2f}s".format(duration),
            "Average time per batch {}".format(duration/ites_per_epoch),
            "Accuracy: {}".format(accuracy),
            "Epoch Loss: {}".format(loss)
    ])
    log_and_print(print_str)
###################### CHECKPOINTING ########################\
    losses.append(loss)
    acc.append(accuracy)
    cur_epoch.add_(1)
###################### CHECKPOINTING ########################\


Epoch 0/49
----------
Iteration 221, Loss: 0.5461950898170471
Iteration 222, Loss: 0.7653592824935913
Iteration 223, Loss: 0.5180494785308838
Iteration 224, Loss: 0.4324283301830292
Iteration 225, Loss: 0.524896502494812
Iteration 226, Loss: 0.6371507048606873
Iteration 227, Loss: 0.5810372233390808
Iteration 228, Loss: 0.42662107944488525
Iteration 229, Loss: 0.46102991700172424
Iteration 230, Loss: 0.4658900797367096
Iteration 231, Loss: 0.6230384111404419
Iteration 232, Loss: 0.539046049118042
Iteration 233, Loss: 0.43059056997299194
Iteration 234, Loss: 0.5698506236076355
Iteration 235, Loss: 0.2969943881034851
Iteration 236, Loss: 0.5753830075263977
Epoch took      89.87s
Average time per batch 0.37920015471897045
Accuracy: 0.7520000338554382
Epoch Loss: 0.03541586809017487
Epoch 1/49
----------
Iteration 0, Loss: 0.2884337306022644
Iteration 1, Loss: 0.937825620174408
Iteration 2, Loss: 0.42413580417633057
Iteration 3, Loss: 0.25568124651908875
Iteration 4, Loss: 0.43189582228660