In [1]:
from collections import OrderedDict
import os
import glob
import bisect
import numpy as np
import skimage.io as io
from tqdm import tqdm

def file_key(filename):
    basename = os.path.basename(filename)
    return int(basename.split("-")[0])

def file_label(filename):
    basename = os.path.splitext(os.path.basename(filename))[0]
    return int(basename.split("-")[1])

class CustomDataSet:
    def __init__(self, main_dir, window_size):
        self.main_dir = main_dir
        self.window_size = window_size
        self.scan_dataset()
        
    def scan_dataset(self):
        self.cumulative_idxs = []
        self.cumulative_dirs = []

        cumsum = 0
        for d in tqdm(os.listdir(self.main_dir)[:10]):
            directory = os.path.join(self.main_dir, d)
            if os.path.isdir(directory):
                n_files = len(glob.glob(os.path.join(directory, "*.jpg")))
                n_idx = n_files - self.window_size + 1
                
                self.cumulative_idxs.append(cumsum)
                self.cumulative_dirs.append(d)
                cumsum += n_idx
                
        self.n_idxs = cumsum

    def __len__(self):
        return self.n_idxs

    def __getitem__(self, idx):
        directory_idx = bisect.bisect_right(self.cumulative_idxs, idx) - 1
        d = self.cumulative_dirs[directory_idx]
        d_start_idx = self.cumulative_idxs[directory_idx]
        
        files = glob.glob(os.path.join(self.main_dir, d, "*.jpg"))
        files.sort(key=file_key)
        
        start_idx = idx - d_start_idx
        f = files[start_idx:start_idx+self.window_size]
        return f
    
        # image = Image.open(f).convert("RGB")
        # return image

In [2]:
class MyIterableDataset:
    def __init__(self, main_dir, window_size):
        self.main_dir = main_dir
        self.window_size = window_size
        self.scan_dataset()
        
    def scan_dataset(self):
        self.cumulative_idxs = []
        self.cumulative_dirs = []

        cumsum = 0
        for d in tqdm(os.listdir(self.main_dir)[:10]):
            directory = os.path.join(self.main_dir, d)
            if os.path.isdir(directory):
                n_files = len(glob.glob(os.path.join(directory, "*.jpg")))
                n_idx = n_files - self.window_size + 1
                
                self.cumulative_idxs.append(cumsum)
                self.cumulative_dirs.append(d)
                cumsum += n_idx
                
        self.n_idxs = cumsum

    def get_paths(self, idx):
        directory_idx = bisect.bisect_right(self.cumulative_idxs, idx) - 1
        d = self.cumulative_dirs[directory_idx]
        d_start_idx = self.cumulative_idxs[directory_idx]
        
        files = glob.glob(os.path.join(self.main_dir, d, "*.jpg"))
        files.sort(key=file_key)
        
        start_idx = idx - d_start_idx
        f = files[start_idx:start_idx+self.window_size]
        return f
    
        # image = Image.open(f).convert("RGB")
        # return image

    def __iter__(self):
        worker_info = None # torch.utils.data.get_worker_info()
        if worker_info is None:  # single-process data loading, return the full iterator
            self.iter_start = 0
            self.iter_end = self.n_idxs
        else:  # in a worker process
            # split workload
            per_worker = int(math.ceil((self.n_idxs) / float(worker_info.num_workers)))
            worker_id = worker_info.id
            self.iter_start = worker_id * per_worker
            self.iter_end = min(iter_start + per_worker, self.n_idxs)
            
        self.i = self.iter_start
        return self
    
    def __next__(self):
        if self.i == self.iter_end:
            raise StopIteration
        
        # Check if we're starting new video
        directory_idx = bisect.bisect_right(self.cumulative_idxs, self.i) - 1
        d = self.cumulative_dirs[directory_idx]
        d_start_idx = self.cumulative_idxs[directory_idx]
        
        if self.i == d_start_idx or self.last_images is None or self.last_labels is None:
            paths = self.get_paths(self.i)
            self.last_images = [io.imread(f) for f in paths]
            self.last_labels = np.array([file_label(f) for f in paths])
            self.i += 1
            return self.last_images, self.last_labels
        else:
            paths = self.get_paths(self.i)
            new_path = paths[-1]
            new_image = io.imread(new_path)
            self.last_images = np.concatenate([self.last_images[1:], [new_image]])
            new_label = file_label(new_path)
            self.last_labels = np.concatenate([self.last_labels[1:], [new_label]])
            
            self.i += 1
            return self.last_images, self.last_labels

In [3]:
c = MyIterableDataset('/home/ubuntu/data/videos/', 2)

100%|██████████| 10/10 [00:00<00:00, 46.04it/s]


In [4]:
it = iter(c)

In [5]:
next(it)

([array([[[26, 40, 53],
          [24, 41, 59],
          [21, 41, 66],
          ...,
          [ 6, 20, 23],
          [ 9, 19, 21],
          [10, 18, 21]],
  
         [[24, 39, 60],
          [23, 41, 65],
          [20, 41, 68],
          ...,
          [ 8, 19, 23],
          [ 9, 19, 21],
          [10, 18, 21]],
  
         [[22, 39, 69],
          [20, 41, 70],
          [17, 42, 72],
          ...,
          [ 8, 19, 23],
          [ 9, 18, 23],
          [10, 18, 21]],
  
         ...,
  
         [[25, 21, 20],
          [28, 24, 23],
          [28, 24, 23],
          ...,
          [ 9, 18, 25],
          [ 9, 18, 25],
          [ 9, 18, 25]],
  
         [[30, 26, 25],
          [32, 28, 27],
          [31, 27, 26],
          ...,
          [ 9, 18, 25],
          [ 9, 18, 25],
          [ 9, 18, 25]],
  
         [[35, 30, 34],
          [35, 30, 34],
          [32, 27, 31],
          ...,
          [ 9, 18, 25],
          [ 9, 18, 25],
          [ 9, 18, 25]]], dtype=u

In [6]:
next(it)

(array([[[[26, 40, 53],
          [24, 41, 59],
          [21, 41, 66],
          ...,
          [ 6, 20, 23],
          [ 9, 19, 21],
          [10, 18, 21]],
 
         [[24, 39, 60],
          [23, 41, 65],
          [20, 41, 68],
          ...,
          [ 8, 19, 23],
          [ 9, 19, 21],
          [10, 18, 21]],
 
         [[22, 39, 69],
          [20, 41, 70],
          [17, 42, 72],
          ...,
          [ 8, 19, 23],
          [ 9, 18, 23],
          [10, 18, 21]],
 
         ...,
 
         [[27, 23, 22],
          [25, 21, 20],
          [26, 22, 21],
          ...,
          [ 9, 18, 25],
          [ 9, 18, 25],
          [ 9, 18, 25]],
 
         [[32, 31, 27],
          [30, 29, 25],
          [28, 27, 23],
          ...,
          [ 9, 18, 25],
          [ 9, 18, 25],
          [ 9, 18, 25]],
 
         [[32, 31, 26],
          [35, 34, 29],
          [35, 34, 29],
          ...,
          [ 9, 18, 25],
          [ 9, 18, 25],
          [ 9, 18, 25]]],
 
 
        [

In [7]:
for i in range(738):
    next(it)

In [8]:
next(it)

(array([[[[ 84,  51,  58],
          [ 87,  54,  63],
          [ 66,  33,  44],
          ...,
          [ 51,  18,  27],
          [ 50,  17,  26],
          [ 50,  19,  27]],
 
         [[ 75,  40,  47],
          [ 77,  44,  53],
          [ 80,  47,  58],
          ...,
          [ 48,  18,  26],
          [ 45,  18,  25],
          [ 49,  24,  30]],
 
         [[ 71,  36,  42],
          [ 86,  51,  58],
          [ 79,  46,  57],
          ...,
          [ 43,  22,  27],
          [ 44,  25,  29],
          [ 42,  23,  27]],
 
         ...,
 
         [[115, 107,  96],
          [108, 101,  91],
          [105, 101,  92],
          ...,
          [ 44,  45,  40],
          [ 37,  38,  33],
          [ 46,  47,  42]],
 
         [[135, 125, 115],
          [129, 120, 111],
          [130, 126, 117],
          ...,
          [ 36,  38,  33],
          [ 41,  43,  38],
          [ 53,  55,  50]],
 
         [[123, 110, 101],
          [118, 109, 100],
          [121, 117, 108],
   

In [9]:
next(it)

(array([[[[ 84,  51,  58],
          [ 87,  54,  63],
          [ 66,  33,  44],
          ...,
          [ 51,  18,  27],
          [ 50,  17,  26],
          [ 50,  19,  27]],
 
         [[ 75,  40,  47],
          [ 77,  44,  53],
          [ 80,  47,  58],
          ...,
          [ 48,  18,  26],
          [ 45,  18,  25],
          [ 49,  24,  30]],
 
         [[ 71,  36,  42],
          [ 86,  51,  58],
          [ 79,  46,  57],
          ...,
          [ 43,  22,  27],
          [ 44,  25,  29],
          [ 42,  23,  27]],
 
         ...,
 
         [[115, 107,  96],
          [108, 101,  91],
          [105, 101,  92],
          ...,
          [ 44,  45,  40],
          [ 37,  38,  33],
          [ 46,  47,  42]],
 
         [[135, 125, 115],
          [129, 120, 111],
          [130, 126, 117],
          ...,
          [ 36,  38,  33],
          [ 41,  43,  38],
          [ 53,  55,  50]],
 
         [[123, 110, 101],
          [118, 109, 100],
          [121, 117, 108],
   