In [1]:
# parameters

truncate = 10
location = '/root/.cache/keypoints'
clear_cache = False

In [2]:
import joblib

from joblib import Memory
memory = Memory(location,verbose=0)
if clear_cache:
    memory.clear()

In [3]:
from pathlib import Path
import numpy as np
from torch.utils.data import Dataset
from functools import cached_property
from collections import OrderedDict
from itertools import islice

import keypoints_io
folder_video_keypoints = "../data/keypoints/as_arrays" 
folder_video_keypoints = str(Path(folder_video_keypoints).resolve())
# with respect to the notebook's directory, all the notebooks are launched with respect to their folder's location
# using papermill or ploomber engine

def uniform_padding(array,max_shape,value):
    dim = len(array.shape) # of set of the arrays
    padding_width = ((0,max_shape-array.shape[0]),*[(0,0) for _ in range(dim-1)])
    new_array = np.pad(array,padding_width,constant_values=value)
    return new_array
        
    
@memory.cache
def read_and_filter_keypoints(path_file,with_frame_padding,default_shape=(17,2),max_number_of_frames=None):
    if with_frame_padding:
        assert isinstance(max_number_of_frames,int),"a number of maximum frames must be provided"
        with_frame_padding = False
        keypoints,is_frame_present = read_and_filter_keypoints(path_file,with_frame_padding,default_shape,
                                                              max_number_of_frames)
        assert len(keypoints)<=max_number_of_frames
        assert len(is_frame_present)<=max_number_of_frames
        
        keypoints = uniform_padding(keypoints,max_number_of_frames,np.nan)
        is_frame_present = uniform_padding(is_frame_present,max_number_of_frames,False)
    else:
        keypoints,scores = keypoints_io.load(path_file)
        keypoints = [keypoints_frames[np.argmax(scores_frames)] if len(scores_frames)>0 else None for (keypoints_frames,scores_frames) in zip(keypoints,scores)]

        is_frame_present = np.array([el is not None for el in keypoints])
        keypoints = np.array([el if keep else np.nan*np.ones(default_shape)  for el,keep in zip(keypoints,is_frame_present)])
        keypoints = keypoints.astype("float32")
    keypoints = np.nan_to_num(keypoints)
    return keypoints,is_frame_present


class VideoKeyPointDataset(Dataset):
    def __init__(self,folder_video_keypoints,with_frame_padding=True,truncate=None):
        self.folder_video_keypoints = folder_video_keypoints
        self.with_frame_padding = with_frame_padding
        self.video_keypoints_field_paths = list(islice(Path(self.folder_video_keypoints).rglob("*.npz"),truncate))
    @cached_property
    def number_of_frames(self):
        assert not(self.with_frame_padding),"can't get the exact number of frames if padding is activated"
        res = {i:len(self.__getitem__(i)["kpts"]) for i in range(len(self))}
        return res
    
    @cached_property
    def max_number_of_frames(self):
        res = max(self.number_of_frames.values())
        return res

    
    def __getitem__(self,i):
        outputs = OrderedDict()
        if self.with_frame_padding:
            assert "max_number_of_frames" in self.__dict__,"""
            max_number_of_frames should be computed at least once,with self.with_frame_padding
            set to False before using padding 
            """
            kpts,is_detection_present = read_and_filter_keypoints(self.video_keypoints_field_paths[i],
                                                          with_frame_padding=self.with_frame_padding,
                                                          max_number_of_frames=self.max_number_of_frames)
            outputs["number_of_frames"]=self.number_of_frames[i]
        else:
            kpts,is_detection_present = read_and_filter_keypoints(self.video_keypoints_field_paths[i],
                                                          with_frame_padding=self.with_frame_padding,
                                                              max_number_of_frames=None)
        outputs.update(kpts=kpts,
                      is_detection_present=is_detection_present)
        return outputs
    
    def __len__(self):
        return len(self.video_keypoints_field_paths)
    


In [4]:
video_kpt_dataset = VideoKeyPointDataset(folder_video_keypoints,with_frame_padding=False,truncate=truncate)

In [5]:
if truncate is not None:
    assert len(video_kpt_dataset) == min(truncate,len(video_kpt_dataset.folder_video_keypoints))
print(len(video_kpt_dataset))

10


In [6]:
# fill the cache and print informations about caching in other notebook
from ploomber_engine.ipython import PloomberClient

path_cache_setup = "./performance_tests/caching_tests_and_metrics.ipynb"
path_cache_setup = str(Path(path_cache_setup).resolve())

client = PloomberClient.from_path(path_cache_setup,remove_tagged_cells="notebook_call")
namespace = client.get_namespace(dict(video_kpt_dataset=video_kpt_dataset))

Executing cell: 8: 100%|██████████████████████████| 9/9 [00:01<00:00,  4.57it/s]


In [7]:
import torch

def collatefn(batch):
    batch = {k: [dic[k] for dic in batch] for k in batch[0]}
    nb_frames,kpts,is_detection_present = batch.values()
    max_nb_frames = max(nb_frames) 
    # we compute the maximum over the batch for some extra computation savings
    kpts = torch.tensor(kpts)
    kpts = kpts.reshape(kpts.shape[0],kpts.shape[1],-1)
    is_detection_present = torch.tensor(is_detection_present)
    
    kpts = kpts[:,:max_nb_frames]
    is_detection_present = is_detection_present[:,:max_nb_frames]
    #nb_frames = np.array(nb_frames)
    return kpts,is_detection_present

from torch.utils.data import DataLoader

video_dataloader = DataLoader(video_kpt_dataset,shuffle=True,collate_fn=collatefn,batch_size=10)

In [9]:
if False:
    res = next(iter(video_dataloader))
    res[0].shape,res[1].shape

  kpts = torch.tensor(kpts)
  is_detection_present = torch.tensor(is_detection_present)


In [16]:
type(video_dataloader.sampler) == torch.utils.data.sampler.RandomSampler

True

In [None]:
#sanity check test
#res[0][0].numpy()
#video_kpt_dataset[0][0]#res[0][0].numpy()
#video_kpt_dataset[0][0]#res[0][0].numpy()
#video_kpt_dataset[0][0]
unshuffled = False
if unshuffled:
    loader_iter = iter(d_loader)
    res = next(loader_iter)
    assert np.all(res[0][0][6].numpy() == video_kpt_dataset[0][0][6])