In [1]:
import torch
from pytorchvideo.models.hub import (  # noqa: F401, E402

    slowfast_r50,
    slowfast_r50_detection,
    x3d_l,
    x3d_m,
    x3d_s,
    x3d_xs,
)

from torch.hub import load_state_dict_from_url

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'



  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import json
# Path of the annotation export. The next cell iterates over it as a sequence
# of records that carry a 'video' key and, optionally, a 'choice' key.
labels = 'G:\\.shortcut-targets-by-id\\1eyTB0qCfXgrxNsrmWNeLNbd5sTKzP5HT\\Data Wizards\\dataset\\labels\\min.json'
# JSON is UTF-8 by specification. Passing the encoding explicitly keeps
# non-ASCII label text intact on platforms whose default text encoding is a
# locale code page rather than UTF-8.
with open(labels, 'r', encoding='utf-8') as file:
    vid_label = json.load(file)

In [3]:
# category_mapping: video file name -> integer class id
# choice_mapping:   label string    -> integer class id (ids are handed out in
#                   the order the labels are first encountered)
category_mapping = {}
choice_mapping = {}

for entry in vid_label:
    # Keep only what follows the last '/' of the stored video reference.
    name = entry['video'].rsplit('/', 1)[-1]
    if 'choice' not in entry:
        # Records without a label are ignored.
        continue
    choice = entry['choice']
    # setdefault assigns the next free id the first time a label is seen and
    # returns the existing id afterwards.
    category_mapping[name] = choice_mapping.setdefault(choice, len(choice_mapping))

In [4]:
from pytorchvideo.transforms import (
    ApplyTransformToKey,
    ShortSideScale,
    UniformTemporalSubsample,
    UniformCropVideo
) 
from torchvision.transforms import Compose, Lambda
from torchvision.transforms._transforms_video import (
    CenterCropVideo,
    NormalizeVideo,
)

# Preprocessing constants shared by the transform pipeline defined below.
side_size = 256  # passed to ShortSideScale(size=...)
mean = [0.45, 0.45, 0.45]  # per-channel mean given to NormalizeVideo
std = [0.225, 0.225, 0.225]  # per-channel std given to NormalizeVideo
crop_size = 256  # passed to CenterCropVideo
num_frames = 32  # frames kept by UniformTemporalSubsample (fast pathway length)
sampling_rate = 2  # not referenced by the cells visible in this notebook
frames_per_second = 30  # not referenced by the cells visible in this notebook
slowfast_alpha = 4  # ratio between fast- and slow-pathway frame counts
slow_num_frames = num_frames // slowfast_alpha  # 32 // 4 = 8 frames for the slow pathway
num_clips = 10  # not referenced by the cells visible in this notebook
num_crops = 3  # not referenced by the cells visible in this notebook

class PackPathway(torch.nn.Module):
    """
    Split a clip tensor into the two-pathway input list ``[slow, fast]``.

    The fast pathway is the input clip unchanged. The slow pathway keeps a
    fixed number of frames picked at evenly spaced positions along dim 1
    (the temporal axis for the clips this notebook produces).

    Args:
        num_slow_frames: number of frames to keep for the slow pathway. When
            ``None`` (the default) the notebook-level ``slow_num_frames`` is
            read at call time, which is the original behaviour.
    """
    def __init__(self, num_slow_frames=None):
        super().__init__()
        self.num_slow_frames = num_slow_frames

    def forward(self, frames: torch.Tensor):
        """Return ``[slow_pathway, fast_pathway]`` for ``frames``."""
        if self.num_slow_frames is None:
            keep = slow_num_frames
        else:
            keep = self.num_slow_frames
        # Build the index on the same device as the clip; index_select
        # requires the index and the input to live on the same device.
        slow_index = torch.linspace(
            0, frames.shape[1] - 1, keep, device=frames.device
        ).long()
        slow_pathway = torch.index_select(frames, 1, slow_index)
        fast_pathway = frames
        return [slow_pathway, fast_pathway]
    
# Per-clip preprocessing, applied in this order to the "video" entry of a clip
# dict: temporal subsample -> scale to [0, 1] -> normalize -> resize short
# side -> center crop -> split into [slow, fast] pathways.
_video_steps = [
    UniformTemporalSubsample(num_frames),
    Lambda(lambda pixels: pixels / 255.0),
    NormalizeVideo(mean, std),
    ShortSideScale(size=side_size),
    CenterCropVideo(crop_size),
    PackPathway(),
]

transform = ApplyTransformToKey(key="video", transform=Compose(_video_steps))



In [5]:
import os
import json
from pytorchvideo.data.encoded_video import EncodedVideo
import gc

# Root folder that is searched recursively for the labelled .mp4 files.
vid_file = 'G:\\.shortcut-targets-by-id\\1eyTB0qCfXgrxNsrmWNeLNbd5sTKzP5HT\\Data Wizards\\dataset\\videoSync'

# Parallel lists: vids_tensor[k] is the transformed [slow, fast] clip and
# vids_category[k] is its integer class id.
vids_tensor = []
vids_category = []

for root, dirs, files in os.walk(vid_file):
    for name in files:
        vid_path = os.path.join(root, name)
        # Only .mp4 files that have a label in category_mapping are used.
        if vid_path.endswith('.mp4') and name in category_mapping:
            video = EncodedVideo.from_path(vid_path)
            # Take the first three seconds of every video.
            video_data = video.get_clip(start_sec=0, end_sec=3)
            # Drop the decoder object right away to keep memory usage down.
            del video
            gc.collect()
            clip = transform(video_data)
            vids_tensor.append(clip['video'])
            vids_category.append(category_mapping[name])


In [14]:
# Build a SlowFast-R50 with a 5-way head and warm-start it from the published
# Kinetics checkpoint. The model is created with pretrained=False because the
# checkpoint cannot be loaded strictly: at least one tensor has a different
# shape once model_num_class is changed (presumably the classification head;
# verify against the printed model).
# NOTE(review): model_num_class is hard-coded to 5; it has to be at least the
# number of distinct labels collected in choice_mapping.
model = slowfast_r50(pretrained=False, model_num_class = 5).to(device)


root_dir = "https://dl.fbaipublicfiles.com/pytorchvideo/model_zoo"
checkpoint_paths = {
    "slowfast_r50": f"{root_dir}/kinetics/SLOWFAST_8x8_R50.pyth",
    "slowfast_r50_detection": f"{root_dir}/ava/SLOWFAST_8x8_R50_DETECTION.pyth",
    "slowfast_r101": f"{root_dir}/kinetics/SLOWFAST_8x8_R101.pyth",
    "slowfast_16x8_r101_50_50": f"{root_dir}/kinetics/SLOWFAST_16x8_R101_50_50.pyth",
}

# map_location="cpu" lets the checkpoint deserialize on machines without CUDA;
# load_state_dict copies the values onto the model's device afterwards.
checkpoint = load_state_dict_from_url(
    checkpoint_paths["slowfast_r50"], map_location="cpu"
)
state_dict = checkpoint["model_state"]

# Keep only the checkpoint tensors whose name and shape match the freshly built
# model, then let load_state_dict copy them in. state_dict() is taken once
# instead of being rebuilt for every key, keys absent from the checkpoint are
# tolerated, and 0-dim buffers are copied as well.
model_state = model.state_dict()
compatible_weights = {
    key: weight
    for key, weight in state_dict.items()
    if key in model_state and weight.shape == model_state[key].shape
}
model.load_state_dict(compatible_weights, strict=False)

# `checkpoint` still references the loaded tensors; delete it as well if the
# host memory is needed.
del state_dict, compatible_weights, model_state

In [7]:
import gc
# Run a full garbage collection first so that tensors kept alive only by
# reference cycles are released, then hand the now-unoccupied cached blocks
# back to the CUDA driver. (gc.collect(0) would scan only the youngest
# generation, and collecting after empty_cache() leaves the freed memory cached.)
gc.collect()
torch.cuda.empty_cache()

0

In [15]:
from torch.nn import CrossEntropyLoss
from torch.optim import SGD, Adam
import numpy as np
import tqdm
import math



# Training hyper-parameters.
epoch = 5
batch_size = 2
num_categories = 5


optimizer = Adam(model.parameters(), lr=1e-4)
train_size = len(vids_tensor)
if train_size == 0:
    raise ValueError("vids_tensor is empty; no labelled videos were loaded to train on")
steps = math.ceil(train_size / batch_size)
crossEntropy = CrossEntropyLoss()

for epoch_i in range(0, epoch):
    model.train()
    # Fresh random sample order for every epoch.
    reordered = torch.randperm(train_size)
    loss_list = []

    # `idx`, `all` and `correct` keep their original names because later
    # inspection cells read them after the loop (`all` shadows the builtin).
    idx = 0
    all = 0
    correct = 0

    for step in tqdm.tqdm(range(steps)):
        # Slicing yields a shorter final batch when train_size is not a
        # multiple of batch_size; inputs, targets and the sample counter are
        # all derived from the same index list so they always agree.
        batch_ids = reordered[idx:idx + batch_size].tolist()
        idx += len(batch_ids)

        # Each stored sample is the [slow, fast] pair produced by PackPathway.
        slow_batch = torch.stack([vids_tensor[j][0] for j in batch_ids]).to(device)
        fast_batch = torch.stack([vids_tensor[j][1] for j in batch_ids]).to(device)
        target = torch.tensor(
            [vids_category[j] for j in batch_ids], dtype=torch.long, device=device
        )

        optimizer.zero_grad()
        output = model([slow_batch, fast_batch])
        loss = crossEntropy(output, target)
        loss.backward()
        optimizer.step()

        loss_list.append(loss.item())
        # Accuracy is measured on the predictions made before this step's update.
        correct += torch.sum(torch.argmax(output, dim=1) == target).item()
        all += len(batch_ids)

        del slow_batch, fast_batch

    # Collect first so the memory released by the collection can be returned
    # to the driver by empty_cache().
    gc.collect()
    torch.cuda.empty_cache()
    print(epoch_i, end = ' train loss:')
    print(np.mean(loss_list))

    print(correct / all)
    



100%|██████████| 48/48 [00:16<00:00,  2.91it/s]


0 train loss:1.0154871900255482
0.6458333333333334


100%|██████████| 48/48 [00:16<00:00,  2.96it/s]


1 train loss:0.5849497207285216
0.8125


100%|██████████| 48/48 [00:16<00:00,  2.93it/s]


2 train loss:0.2899218066013418
0.9583333333333334


100%|██████████| 48/48 [00:16<00:00,  2.95it/s]


3 train loss:0.11295494400352861
1.0


100%|██████████| 48/48 [00:16<00:00,  2.90it/s]

4 train loss:0.05074871005732954
1.0





In [11]:
# Scratch cell: prints the accuracy ratio from the `correct` and `all` counters
# that the training loop leaves in the kernel. It fails on a fresh kernel
# unless the training cell has run first (`all` would then be the builtin).
print(correct / all)

0.3958333333333333


In [54]:
# Scratch cell: shape of the most recent model output. `output` is assigned
# inside the training loop; presumably this was run against a different kernel
# state than the one saved above, so confirm before relying on the value shown.
output.shape

torch.Size([1, 5])

In [31]:
# Scratch cell: shape of the slow-pathway tensor (element 0 of the [slow, fast]
# pair) for the sample at shuffled position `idx`. It relies on `reordered` and
# `idx` left over from the training loop, and raises IndexError once a full
# epoch has finished because `idx` then equals `train_size`.
vids_tensor[reordered[idx]][0].shape

torch.Size([3, 8, 256, 256])

In [40]:
# Scratch cell: shape of the first element of `input`. The training loop
# deletes `input` after every step, so the value must come from another cell.
# NOTE(review): the recorded shape does not match the 3-channel 256x256 clips
# built in this notebook; it presumably comes from an earlier experiment.
input[0].shape

torch.Size([1, 64, 8, 64, 64])

In [38]:
# Scratch cell (a repeat of the same inspection): shape of the first element of
# `input`, which must have been defined by a cell other than the training loop,
# since the loop deletes `input` after every step.
input[0].shape

torch.Size([1, 64, 8, 64, 64])

In [33]:
# All-zero stand-in for one [slow, fast] batch, handy for checking that the
# model accepts the expected input shapes. Note that `input` shadows the builtin.
_slow_shape = (batch_size, 3, slow_num_frames, crop_size, crop_size)
_fast_shape = (batch_size, 3, num_frames, crop_size, crop_size)
input = [
    torch.zeros(size=_slow_shape, device=device),
    torch.zeros(size=_fast_shape, device=device),
]