In [1]:
from models import get_masked_vit_base_patch16_224
from utils.parser import parse_args, load_config
from torch.utils.data import default_collate

import torch
import sys
import argparse
import numpy as np

In [2]:
def parse_args(argv=None):
    """
    Build the default PySlowFast argument namespace for use inside a notebook.

    NOTE: this definition shadows the ``parse_args`` imported from
    ``utils.parser`` in the first cell. Unlike a CLI entry point it never
    reads ``sys.argv`` (under Jupyter that holds the kernel launcher's own
    flags); it parses ``argv`` instead, which defaults to an empty list so
    every option keeps its default value.

    Args:
        argv (list[str] | None): command-line style tokens to parse.
            ``None`` (the default) is treated as ``[]``.

    Returns:
        argparse.Namespace with the attributes:
            shard_id (int): shard id for the current machine, from 0 to
                num_shards - 1. Use 0 on a single machine.
            num_shards (int): number of shards used by the job.
            init_method (str): initialization method for launching the job on
                multiple devices, e.g. TCP or a shared file system. Details:
                https://pytorch.org/docs/stable/distributed.html#tcp-initialization
            cfg_file (str): path to the config file (set via ``--cfg``).
            opts (list): additional trailing options; they override values
                loaded from the config file.
    """
    parser = argparse.ArgumentParser(
        description="Provide SlowFast video training and testing pipeline."
    )
    parser.add_argument(
        "--shard_id",
        help="The shard id of current node, Starts from 0 to num_shards - 1",
        default=0,
        type=int,
    )
    parser.add_argument(
        "--num_shards",
        help="Number of shards using by the job",
        default=1,
        type=int,
    )
    parser.add_argument(
        "--init_method",
        help="Initialization method, includes TCP or shared file-system",
        default="tcp://localhost:9999",
        type=str,
    )
    parser.add_argument(
        "--cfg",
        dest="cfg_file",
        help="Path to the config file",
        default="configs/Kinetics/SLOWFAST_4x16_R50.yaml",
        type=str,
    )
    parser.add_argument(
        "opts",
        help="See slowfast/config/defaults.py for all options",
        default=None,
        nargs=argparse.REMAINDER,
    )
    # sys.argv is deliberately not consulted: printing help based on its length
    # was misleading because its contents were never parsed.
    return parser.parse_args([] if argv is None else list(argv))


class TubeMaskingGenerator:
    """Generate "tube" masks: one random spatial mask repeated for every frame.

    Each call returns a flat 1-D float array of length
    ``frames * height * width`` where 1 marks a masked patch and 0 a kept
    patch. The same spatial pattern is tiled across all frames, so masked
    positions form tubes through time.
    """

    def __init__(self, input_size, mask_ratio, rng=None):
        """
        Args:
            input_size (tuple[int, int, int]): (frames, height, width). The
                notebook passes patch-grid sizes here, not pixel sizes.
            mask_ratio (float): fraction of patches per frame to mask, in
                [0, 1]. The per-frame count is truncated with ``int()``.
            rng: optional ``numpy.random.Generator`` or ``RandomState`` used
                to shuffle the mask, for reproducible masks. ``None`` (the
                default) uses the global ``numpy.random`` state.

        Raises:
            ValueError: if ``mask_ratio`` is outside [0, 1] (or is NaN).
        """
        # Fail early with a clear message instead of numpy's
        # "negative dimensions" error from np.zeros in __call__.
        if not 0.0 <= mask_ratio <= 1.0:
            raise ValueError(
                "mask_ratio must be in [0, 1], got {}".format(mask_ratio)
            )
        self.frames, self.height, self.width = input_size
        self.num_patches_per_frame = self.height * self.width
        self.total_patches = self.frames * self.num_patches_per_frame
        self.num_masks_per_frame = int(mask_ratio * self.num_patches_per_frame)
        self.total_masks = self.frames * self.num_masks_per_frame
        # The numpy.random module exposes the same shuffle() API as
        # Generator/RandomState, so the default keeps the original behavior.
        self._rng = np.random if rng is None else rng

    def __repr__(self):
        repr_str = "Masks: total patches {}, mask patches {}".format(
            self.total_patches, self.total_masks
        )
        return repr_str

    def __call__(self):
        """Return a fresh float64 mask of shape (total_patches,); 1 = masked."""
        mask_per_frame = np.hstack([
            np.zeros(self.num_patches_per_frame - self.num_masks_per_frame),
            np.ones(self.num_masks_per_frame),
        ])
        self._rng.shuffle(mask_per_frame)
        # Repeat the single spatial pattern for every frame, then flatten
        # frame-major: (frames, H*W) -> (frames*H*W,).
        mask = np.tile(mask_per_frame, (self.frames, 1)).flatten()
        return mask


def get_sinusoid_encoding_table(n_position, d_hid):
    ''' Sinusoid position encoding table.

    Entry (pos, j) is ``pos / 10000 ** (2 * (j // 2) / d_hid)``. ``sin`` is
    then applied to even columns and ``cos`` to odd columns.

    Args:
        n_position (int): number of positions (rows).
        d_hid (int): encoding width (columns).

    Returns:
        float32 torch tensor of shape (1, n_position, d_hid).
    '''
    # TODO: make it with torch instead of numpy
    # Vectorized via broadcasting instead of a per-element Python loop; this
    # also handles n_position == 0, which made the 2-D slicing below fail.
    positions = np.arange(n_position, dtype=np.float64)[:, None]   # (n_position, 1)
    exponents = 2 * (np.arange(d_hid) // 2) / d_hid                 # (d_hid,)
    sinusoid_table = positions / np.power(10000, exponents)         # (n_position, d_hid)
    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1

    return torch.from_numpy(sinusoid_table).float().unsqueeze(0)

In [3]:
# Build a dummy batch of B random clips plus matching tube masks.
SEED = 0
torch.manual_seed(SEED)  # reproducible dummy video
np.random.seed(SEED)     # TubeMaskingGenerator shuffles with the global numpy RNG by default

B = 2                                    # batch size
C = 768                                  # embedding width (matches the 768 in the outputs below; unused here)
img_size = 224
patch_size = 16
num_patches = img_size // patch_size     # patches per side (14), not per frame
frames = 8
num_tokens = frames * num_patches ** 2   # 8 * 14 * 14 = 1568 patch tokens per clip
mask_ratio = 0.8
masker = TubeMaskingGenerator((frames, num_patches, num_patches), mask_ratio)

batch = []
for _ in range(B):
    clip = torch.randn(3, frames, img_size, img_size)          # (channels, frames, H, W)
    clip_mask = torch.from_numpy(masker()).to(torch.bool)      # one flag per patch token
    assert clip_mask.numel() == num_tokens
    batch.append((clip, clip_mask))

batch = default_collate(batch)

In [4]:
# default_collate turned the list of (clip, mask) pairs into [clips, masks].
video, mask = batch[0], batch[1]

In [5]:
# Start from the default argument namespace, point it at the TimeSformer YAML, and load it.
opt = parse_args()
cfg = './models/configs/Kinetics/TimeSformer_divST_8x32_224.yaml'
opt.cfg_file = cfg
config = load_config(opt)

In [6]:
# no_head=True presumably builds the backbone without a classifier head (confirm in models/).
model = get_masked_vit_base_patch16_224(
    cfg=config,
    no_head=True,
)

Loaded model inside with msg: _IncompatibleKeys(missing_keys=['time_embed', 'blocks.0.temporal_fc.weight', 'blocks.0.temporal_fc.bias', 'blocks.1.temporal_fc.weight', 'blocks.1.temporal_fc.bias', 'blocks.2.temporal_fc.weight', 'blocks.2.temporal_fc.bias', 'blocks.3.temporal_fc.weight', 'blocks.3.temporal_fc.bias', 'blocks.4.temporal_fc.weight', 'blocks.4.temporal_fc.bias', 'blocks.5.temporal_fc.weight', 'blocks.5.temporal_fc.bias', 'blocks.6.temporal_fc.weight', 'blocks.6.temporal_fc.bias', 'blocks.7.temporal_fc.weight', 'blocks.7.temporal_fc.bias', 'blocks.8.temporal_fc.weight', 'blocks.8.temporal_fc.bias', 'blocks.9.temporal_fc.weight', 'blocks.9.temporal_fc.bias', 'blocks.10.temporal_fc.weight', 'blocks.10.temporal_fc.bias', 'blocks.11.temporal_fc.weight', 'blocks.11.temporal_fc.bias', 'head.weight', 'head.bias'], unexpected_keys=[])


In [7]:
# Forward pass only to inspect output shapes, so skip autograd bookkeeping.
# NOTE(review): `mask=True` is passed rather than the `mask` tensor unpacked
# from the batch, so the tube masks built above are never consumed here;
# confirm which one the model's forward expects.
with torch.no_grad():
    out = model(video, mask=True)

In [8]:
# Report what the model returned: each element's shape for a tuple, else the tensor's shape.
if isinstance(out, tuple):
    print('Tuple of length:', len(out))
    for item in out:
        print(item.shape)
else:
    print('returns tensor of shape:', out.shape)

Tuple of length: 3
torch.Size([2, 768])
torch.Size([2, 1248, 768])
torch.Size([2, 1248, 768])
