In [31]:
# Load the annotation file
anno_path = '/workspace/work/misc/O2ONet/data/annotations_minus_unavailable_yt_vids.pkl'

import pickle as pkl

# The context manager closes the file even if unpickling raises.
# NOTE: pickle.load can execute arbitrary code; only load files from a trusted source.
with open(anno_path, 'rb') as f:
    anno = pkl.load(f)

In [32]:
# For i3d based bbox features

import numpy as np
import torch
import torchvision
from torchvision.ops import MultiScaleRoIAlign
from torchvision.models.detection import FasterRCNN
from torch import nn
import torchvision.transforms as transforms
import pytorchvideo.models as models
import torch.nn.functional as F
from collections import OrderedDict
from torch import nn

class FeatureExtractor(nn.Module):
    """Run the leading entries of a model's ``blocks`` container as a feature extractor.

    The child modules of ``submodule._modules['blocks']`` are collected in
    iteration order up to and including the one selected by ``layer``, and
    chained into an ``nn.Sequential``.

    Args:
        submodule: a module exposing a ``blocks`` child module.
        layer: index into the list of block names; the block at this index is
            the last one kept.
    """

    def __init__(self, submodule, layer):
        super().__init__()
        self.pretrain_model = submodule
        self.layer = layer

        blocks = self.pretrain_model._modules['blocks']
        self.layer_list = list(blocks._modules.keys())
        # Change `layer` to pick a different block as the output.
        last_block_name = self.layer_list[self.layer]

        self.children_list = []
        for block_name, block in blocks.named_children():
            self.children_list.append(block)
            if block_name == last_block_name:
                # Everything after the selected block is left out.
                break

        self.feature_extrac_net = nn.Sequential(*self.children_list)
        # Drop the reference to the full model; only the truncated stack is kept.
        self.pretrain_model = None

    def forward(self, image):
        """Return the output of the truncated block stack for ``image``."""
        return self.feature_extrac_net(image)

from pytorchvideo.data.encoded_video import EncodedVideo
from torchvision.transforms import Compose, Lambda
from torchvision.transforms._transforms_video import  NormalizeVideo
from pytorchvideo import transforms
from pytorchvideo.transforms import (
    ApplyTransformToKey,
    UniformTemporalSubsample,
)

import numpy as np

def read_gif(gif_path):
    """Open the GIF at ``gif_path`` and return the clip between 0 and 5.

    Per the original note, the result is a dictionary whose ``'video'`` entry
    is a tensor laid out as (C x T x H x W).
    """
    clip_start, clip_end = 0, 5
    # get_clip fetches the clip from the starting time to the ending time
    return EncodedVideo.from_path(gif_path).get_clip(clip_start, clip_end)

In [33]:
# Master Feature Generator for VSGNet
import torch

'''
get the current dictionary - DONE
get the current gif_folder - DONE
get the i3d feature extractor
and the transform
and the device
'''
def master_feature_generator(annotations, current_file, 
                             gif_folder, 
                             i3d_feature_extractor,
                             i3d_transform, device,
                             window_size=5, max_objects=12):
    """Add object classes and an i3d feature map to a saved feature dictionary.

    Args:
        annotations: iterable of annotation dicts; each is read via
            ``['metadata']['yt_id']``, ``['metadata']['frame no.']`` and
            ``['bboxes']``.
        current_file: path of a ``torch.save``-d dictionary containing a
            ``'metadata'`` entry with ``'yt_id'`` and ``'frame no.'``.
        gif_folder: directory holding ``<yt_id>_<frame no.>_<window_size>.gif``.
        i3d_feature_extractor: callable applied to the batched video tensor.
        i3d_transform: callable applied to the output of ``read_gif``; its
            result is indexed with ``"video"``.
        device: torch device the video tensor is moved to before extraction.
        window_size: value embedded in the GIF file name (default 5).
        max_objects: length of the ``obj_classes`` list; unused slots hold -1
            (default 12).

    Returns:
        ``(current_dict, 1)`` where ``current_dict`` is the loaded dictionary
        extended with ``'obj_classes'``, ``'i3d_fmap'`` (on the CPU) and
        ``'obj_class_map'``.

    Raises:
        ValueError: if no annotation matches the file's metadata, if a bounding
            box has a class other than ``'generic_object'`` / ``'hand'``, or if
            there are more than ``max_objects`` bounding boxes.
    """
    import os

    current_dict = torch.load(current_file)

    # Getting details to load the GIF
    yt_id = current_dict['metadata']['yt_id']
    frame_index = current_dict['metadata']['frame no.']

    # locate the annotation (first match wins)
    current_anno = None
    for a in annotations:
        if a['metadata']['yt_id'] == yt_id and a['metadata']['frame no.'] == frame_index:
            current_anno = a
            break
    if current_anno is None:
        raise ValueError('Annotation Element not found')

    # 'bboxes': {'0': {'class': 'generic_object', 'bbox': [35, 176, 1277, 343]},
    # '1': {'class': 'hand', 'bbox': [1113, 116, 1233, 181]},
    # '2': {'class': 'hand', 'bbox': [663, 81, 843, 176]}}

    bbox_data = current_anno['bboxes']
    if len(bbox_data) > max_objects:
        # Previously this surfaced as an opaque IndexError from the list write.
        raise ValueError('too many bounding boxes: ' + str(len(bbox_data))
                         + ' > ' + str(max_objects))

    class_ids = {'generic_object': 0, 'hand': 1}
    classes = [-1] * max_objects
    for i, k in enumerate(bbox_data):
        temp_class = bbox_data[k]['class']
        if temp_class not in class_ids:
            # The original message was built with a stray unary '+', which
            # raised TypeError instead of this ValueError.
            raise ValueError('object class not recognized ' + str(temp_class))
        classes[i] = class_ids[temp_class]

    current_dict['obj_classes'] = classes

    # Loading the gif
    filename = yt_id + '_' + str(frame_index) + '_' + str(window_size) + '.gif'
    file_location = os.path.join(gif_folder, filename)

    # i3d features
    temp_i3d_video = read_gif(file_location)
    temp_i3d_video = i3d_transform(temp_i3d_video)["video"]
    # Add a leading batch dimension before moving to the target device.
    temp_i3d_video = temp_i3d_video.unsqueeze(0).to(device)

    # Pure inference: skip autograd bookkeeping so no graph is kept alive and
    # the stored feature map does not require grad.
    with torch.no_grad():
        temp_i3d_feature_map = i3d_feature_extractor(temp_i3d_video)

    current_dict['i3d_fmap'] = temp_i3d_feature_map[0].to(torch.device('cpu'))
    current_dict['obj_class_map'] = {'Hand' : 1, 'Generic Object' : 0}

    return current_dict, 1

In [34]:

# Load the annotation file
anno_path = '/workspace/work/misc/O2ONet/data/annotations_minus_unavailable_yt_vids.pkl'

import pickle as pkl

# The context manager closes the file even if unpickling raises.
# NOTE: pickle.load can execute arbitrary code; only load files from a trusted source.
with open(anno_path, 'rb') as f:
    annotations = pkl.load(f)

gif_folder = '/workspace/data/data_folder/o2o/gifs_11'

import torchvision
# Use GPU index 3 when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda:3') if torch.cuda.is_available() else torch.device('cpu')
# Index of the last block of the backbone kept by FeatureExtractor.
layer_no = 5

# i3d feature extractor
import pytorchvideo.models as models
model_name = "i3d_r50"
model = torch.hub.load("facebookresearch/pytorchvideo:main", model=model_name, pretrained=True)
# Features are extracted for inference only, so switch to eval mode: otherwise
# normalisation/dropout layers would run in training mode.
model = model.to(device).eval()
# eval() returns the module itself; calling it on the wrapper also marks the
# new Sequential container as being in inference mode.
i3d_feature_net = FeatureExtractor(model, layer_no).eval()

# i3d transform
# Per-channel statistics handed to NormalizeVideo below.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

from torchvision.transforms import Resize

# Preprocessing pipeline applied to the "video" entry of a clip dictionary:
# temporal subsampling, resize to 720x1280, scale by 1/255, then normalise.
i3d_transform = ApplyTransformToKey(
    transform=Compose([
        UniformTemporalSubsample(11),
        Resize((720, 1280)),
        Lambda(lambda clip: clip / 255.0),
        NormalizeVideo(mean, std),
    ]),
    key="video",
)

Using cache found in /root/.cache/torch/hub/facebookresearch_pytorchvideo_main


In [36]:
import os

gif_folder = '/workspace/data/data_folder/o2o/gifs_11'
from glob import glob as glob
file_list = glob('/workspace/data/data_folder/o2o/gifs_11_features_vsgnet/*.pt')
saving_folder = '/workspace/data/data_folder/o2o/gifs_11_features_ral'
# Make sure the output directory exists so torch.save does not fail on a fresh run.
os.makedirs(saving_folder, exist_ok=True)

# Parallel lists: errors['file'][i] failed with errors['exceptions'][i].
errors = {}
errors['file'] = []
errors['exceptions'] = []

from tqdm import tqdm as tqdm

for f in tqdm(file_list):

    # Computed outside the try block so the names are always bound in the handler.
    file_name = os.path.basename(f)
    file_location = os.path.join(saving_folder, file_name)

    try:
        res, success = master_feature_generator(annotations, f, gif_folder, i3d_feature_net, 
                                                i3d_transform, device)
        torch.save(res, file_location)
    except Exception as e:
        # Best-effort batch job: record the failure and carry on with the next
        # file. Append instead of overwriting so every failure is kept.
        print("Issue", file_name, repr(e))
        errors['file'].append(file_name)
        errors['exceptions'].append(e)
    

100%|██████████| 2052/2052 [1:34:10<00:00,  2.75s/it]  


In [37]:
# Inspect the failure log filled in by the extraction loop; both entries still
# being empty lists means no exception was caught.
errors

{'file': [], 'exceptions': []}

In [25]:
# Show which entries the feature dictionary `res` holds.
# NOTE(review): this cell's execution count (25) is lower than the cells above,
# so the output may come from an earlier run — re-run after the extraction loop to confirm.
res.keys()

dict_keys(['metadata', 'num_obj', 'bboxes', 'lr', 'mr', 'cr', 'object_pairs', 'num_relation', 'frame_deep_features', 'i3d_fmap', 'obj_classes', 'obj_class_map'])