In [None]:
class DF2Paths():
    "Map an annotation row to the frame image paths of its clip that exist on disk."
    def __init__(self, path, fps=24):
        self.path = path
        self.fps = fps

    def __call__(self, item:pd.Series):
        """Return an `L` with the existing frame files between `item['start']` and `item['end']`.

        `item` must provide the keys 'id', 'start' and 'end'. Both times are
        multiplied by `self.fps` and truncated to get frame numbers. Frames are
        looked up as `<path>/Charades_v1_rgb/<id>/<id>-<6-digit frame>.jpg`;
        the first frame number is included, the last one is not, and files that
        are missing on disk are skipped.
        """
        video_id, t_first, t_last = item['id'], item['start'], item['end']
        first = int(float(t_first)*self.fps)
        last = int(float(t_last)*self.fps)

        # Some clips appear to be annotated backwards (start after end);
        # those are walked with a negative stride so frames come out in annotation order.
        stride = 1 if first <= last else -1

        frames = L()
        for frame_no in range(first, last, stride):
            frame_dir = self.path/'Charades_v1_rgb'/video_id
            candidate = frame_dir/f'{video_id}-{frame_no:0>6d}.jpg'
            if not os.path.exists(candidate): continue
            frames.append(candidate)
        return frames

In [None]:
@delegates()
class UniformizedDataLoader(TfmdDL):
    """`TfmdDL` whose batches contain `n_el` items for each of `n_lbl` randomly drawn labels.

    The batch size is forced to `n_el*n_lbl`, overriding any `bs` passed in `kwargs`.
    Labels are read from `dataset.tls[1]` and converted with `int`; they are used
    as keys into `range(len(self.vocab))`, so they must fall in that range.
    """
    def __init__(self, dataset=None, n_el=4, n_lbl=4, **kwargs):
        kwargs['bs'] = n_el*n_lbl
        super().__init__(dataset, **kwargs)
        store_attr(self, 'n_el,n_lbl')
        self.lbls = list(map(int, self.dataset.tls[1]))
        # Labels that may still be drawn during the current pass over the data.
        self.dl_vocab = list(range(len(self.vocab)))

    def before_iter(self):
        "Group the dataset indices by label, shuffling every group when `self.shuffle` is set."
        super().before_iter()
        lbl2idxs = {lbl: [] for lbl in self.dl_vocab}
        for i, lbl in enumerate(self.lbls): lbl2idxs[lbl].append(i)

        if self.shuffle:
            for idxs in lbl2idxs.values(): random.shuffle(idxs)
        self.lbl2idxs = lbl2idxs

    def get_labeled_elements(self, lbl, n_el):
        """Collect `n_el` items of label `lbl`, consuming indices from `self.lbl2idxs[lbl]`.

        Indices whose `do_item` result is `None` are skipped. Raises `IndexError`
        when the label runs out of indices before `n_el` items were collected.
        """
        els_of_lbl = []
        while len(els_of_lbl) < n_el:
            item = self.do_item(self.lbl2idxs[lbl].pop())
            if item is not None: els_of_lbl.append(item)
        return els_of_lbl

    def create_batches(self, samps):
        """Yield batches built from `n_lbl` distinct labels with `n_el` items each.

        `samps` is not used; indices come from the per-label groups prepared in
        `before_iter`. A label that cannot supply `n_el` more items is dropped
        for the rest of the pass. Raises `CancelBatchException` when the labels
        run out while a batch is only partially filled.
        """
        n_lbl, n_el = self.n_lbl, self.n_el
        self.it = iter(self.dataset) if self.dataset is not None else None

        try:
            while len(self.dl_vocab) >= n_lbl:
                batch_lbls, b = [], []

                while len(batch_lbls) < n_lbl:
                    # `randint(0, -1)` raises ValueError once no label is left to draw.
                    # NOTE(review): confirm a caller handles CancelBatchException raised from here.
                    try: i = random.randint(0, len(self.dl_vocab) - 1)
                    except ValueError: raise CancelBatchException
                    lbl = self.dl_vocab.pop(i)
                    # A label must still hold `n_el` indices (not `n_lbl`) to fill its share of the batch.
                    if len(self.lbl2idxs[lbl]) < n_el: continue

                    # Running dry half-way (items skipped as None) exhausts the label: drop it.
                    try: els_of_lbl = self.get_labeled_elements(lbl, n_el)
                    except IndexError: continue

                    b.extend(els_of_lbl)
                    batch_lbls.append(lbl)

                # Labels used in this batch stay available for later batches.
                self.dl_vocab.extend(batch_lbls)

                yield self.do_batch(b)
        finally:
            # Restore the full label list however the pass ended (normal exhaustion,
            # exception, or the consumer closing the generator early).
            self.dl_vocab = list(range(len(self.vocab)))

In [None]:
#export
def uniformize_dataset(items, lbls, vocab=None, n_el=3, n_lbl=3, shuffle=True):
    """Select `items` so that they form consecutive runs of `n_el` items sharing one label.

    Items are picked in rounds: each round draws `n_lbl` distinct labels at random
    and takes `n_el` not-yet-used items of each. A label is retired once it has
    fewer than `n_el` items left; the process stops when fewer than `n_lbl` labels
    remain. Leftover items are not returned.

    Parameters
    ----------
    items : object supporting positional selection. Objects with `.iloc` are
        indexed through it, lists/tuples element by element (a list is returned),
        anything else with `items[list_of_indices]`.
    lbls : iterable with one label per item, in the same order as `items`.
    vocab : labels to draw from; defaults to the distinct values of `lbls`.
        The passed object is never modified.
    n_el : items taken per label in each round (>= 1).
    n_lbl : labels drawn per round (>= 1).
    shuffle : shuffle the items of each label before picking. When False, items
        of a label are taken starting from its last occurrence.

    Raises
    ------
    ValueError : if `n_el` or `n_lbl` is smaller than 1 (the loop could not terminate).
    KeyError : if `lbls` contains a label that is missing from `vocab`.
    """
    if n_el < 1 or n_lbl < 1:
        raise ValueError(f'n_el and n_lbl must be >= 1, got n_el={n_el}, n_lbl={n_lbl}')

    # Work on a private, duplicate-free list: labels get removed from it below and
    # the caller's vocab must survive the call so it can be reused.
    vocab = list(set(lbls)) if vocab is None else list(dict.fromkeys(vocab))

    lbl2idxs = {lbl: [] for lbl in vocab}
    for i, lbl in enumerate(lbls): lbl2idxs[lbl].append(i)

    # Labels that cannot fill even one run are never drawn.
    vocab = [lbl for lbl in vocab if len(lbl2idxs[lbl]) >= n_el]

    if shuffle:
        for pool in lbl2idxs.values(): random.shuffle(pool)

    picked = []
    while len(vocab) >= n_lbl:
        for lbl in random.sample(vocab, n_lbl):
            pool = lbl2idxs[lbl]
            picked.extend(pool.pop() for _ in range(n_el))
            # Same threshold as the initial filter: exactly `n_el` leftovers still make a run.
            if len(pool) < n_el: vocab.remove(lbl)

    if isinstance(items, (list, tuple)): return [items[i] for i in picked]
    return getattr(items, 'iloc', items)[picked]

In [None]:
# Load the table stored at path_charades/'df0.csv', using its first column as the index.
items = pd.read_csv(path_charades/'df0.csv', index_col=0)
# Regroup the rows by their 'lbl' column with the default n_el/n_lbl of `uniformize_dataset`.
# Note: this rebinds `items`, so re-running the cell feeds the already regrouped frame back in.
items = uniformize_dataset(items, items['lbl'])
# Show the last rows as a quick visual check.
items.tail(6)

In [None]:
#export
class UniformizedShuffle():
    """Callable wrapper around `uniformize_dataset` with a fixed label list and settings.

    `lbls` holds one label per item, `vocab` the labels to draw from (defaults to
    the distinct values of `lbls`); `n_el` and `n_lbl` are forwarded unchanged.
    """
    def __init__(self, lbls, vocab=None, n_el=4, n_lbl=4):
        self.lbls = lbls
        if vocab is None: vocab = list(set(lbls))
        self.vocab = vocab
        self.n_el = n_el
        self.n_lbl = n_lbl

    def __call__(self, items):
        "Return `items` regrouped by `uniformize_dataset` using the stored labels and settings."
        # Hand over a copy so the stored vocab can never be altered by the call;
        # this keeps the instance usable for any number of calls.
        return uniformize_dataset(items, lbls=self.lbls, vocab=list(self.vocab),
                                  n_el=self.n_el, n_lbl=self.n_lbl)

In [None]:
# Exercise `UniformizedShuffle` on a freshly loaded table, so this cell does not depend on
# whatever `items` was left holding by earlier cells (the loaded `df` was previously unused).
df = pd.read_csv(path_charades/'df0.csv', index_col=0)
un = UniformizedShuffle(df['lbl'])
un(df).tail(7)

NameError: name 'pd' is not defined