In [47]:
import json
import os
import re

import torch
from torch.nn.functional import one_hot
from torch.utils.data import Dataset
from torchvision.io import read_image

from path import Path

class VideoDataset(Dataset):
    """Dataset of videos stored as one directory of JPEG frames per video.

    Every immediate sub-directory of ``root_dir`` is treated as one video. A
    video directory is expected to contain its frames as ``*.jpg`` files and a
    ``frame_data.json`` file with a ``video_start_points`` entry listing frame
    indices.

    Each item is a dict with the keys:
        frames: all frames of the video, resized and stacked along a new
            leading dimension (num_frames, C, frame_size, frame_size).
        labels: integer tensor of length num_frames; 1 where the frame index
            appears in ``video_start_points``, 0 elsewhere.
        video_start_points: the value read from ``frame_data.json``, unchanged.
        path: the video directory the item was loaded from.

    Args:
        root_dir: directory whose sub-directories are the individual videos.
        frame_size: output size passed to ``interpolate``; an int yields
            square frames regardless of the source aspect ratio.
    """

    def __init__(self, root_dir, frame_size=256):
        self.frame_size = frame_size
        # Directory listings come back in filesystem order, which is not
        # guaranteed to be stable; sort so that idx -> video is deterministic.
        self.root_dirs = sorted(Path(root_dir).dirs(), key=self._natural_key)
        self.videos = list(self.root_dirs)

    def __len__(self):
        """Return the number of video directories found under ``root_dir``."""
        return len(self.videos)

    @staticmethod
    def _natural_key(name):
        """Sort key ordering embedded digit runs numerically ('f2' < 'f10')."""
        # With a capturing group, re.split alternates text / digits, so odd
        # positions are always digit runs and the key types line up.
        pieces = re.split(r'(\d+)', str(name))
        return [int(p) if pos % 2 else p for pos, p in enumerate(pieces)]

    @staticmethod
    def _build_labels(num_frames, start_points):
        """Return a 0/1 list marking which frame indices are start points."""
        starts = set(start_points)  # O(1) membership instead of a list scan
        return [1 if i in starts else 0 for i in range(num_frames)]

    def _load_frame(self, frame_path):
        """Read one image file and resize it to ``frame_size``."""
        image = read_image(frame_path)
        # interpolate expects a batch dimension; add it, resize, drop it again.
        resized = torch.nn.functional.interpolate(image.unsqueeze(0), size=self.frame_size)
        return resized.squeeze(0)

    def __getitem__(self, idx):
        """Load every frame of video ``idx`` together with its per-frame labels.

        Raises:
            RuntimeError: if the video directory contains no ``*.jpg`` files.
        """
        video_dir = self.videos[idx]

        # Labels are assigned by position, so the frames must be in frame
        # order; the raw listing order is not guaranteed.
        frame_paths = sorted(video_dir.files('*.jpg'), key=self._natural_key)
        if not frame_paths:
            raise RuntimeError(f'no *.jpg frames found in {video_dir}')
        frames = [self._load_frame(p) for p in frame_paths]

        with open(os.path.join(video_dir, 'frame_data.json'), 'r') as f:
            video_data = json.load(f)
        video_start_points = video_data['video_start_points']
        labels = self._build_labels(len(frames), video_start_points)

        data = {}
        data['frames'] = torch.stack(frames)
        data['labels'] = torch.tensor(labels)
        data['video_start_points'] = video_start_points
        data['path'] = video_dir
        return data

In [48]:
# Index every video directory under the new_videos folder.
dataset = VideoDataset(
    root_dir='/mnt/drive1/hsun/videoSeg/data/video_datasets/CondensedMovies/new_videos',
)

# One video per batch, visited in random order and loaded by 8 worker processes.
# NOTE(review): batch_size=1 is presumably required because videos differ in
# frame count and could not be stacked by the default collate — confirm.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    pin_memory=True,
    num_workers=8,
)


In [50]:
# Pull a single batch from a fresh iterator as a quick sanity check; the loader
# shuffles, so which video comes back varies between runs.
data = next(iter(dataloader))

In [51]:
# Show the batched frame and label shapes along with the collated start points
# and source path of the sampled video.
data['frames'].shape, data['labels'].shape, data['video_start_points'], data['path']

(torch.Size([1, 202, 3, 256, 256]),
 torch.Size([1, 202]),
 [tensor([83])],
 [Path('/mnt/drive1/hsun/videoSeg/data/video_datasets/CondensedMovies/new_videos/1214')])