In [35]:
from torch.utils import data #import Dataset, IterableDataset#, Dataloader

class MyMapDataset(data.Dataset):
    """Map-style dataset that exposes an indexable container unchanged."""

    def __init__(self, data_):
        # Keep a reference to the container; nothing is copied.
        self.data_ = data_

    def __len__(self):
        """Return the number of items in the wrapped container."""
        size = len(self.data_)
        return size

    def __getitem__(self, idx):
        """Return the item stored at position ``idx``."""
        item = self.data_[idx]
        return item

# Marker printed once the definitions above have been executed.
print('done')

done


In [36]:
from itertools import islice

# Twelve consecutive integers, served four at a time through the map-style dataset.
data_ = list(range(12))
map_dataset = MyMapDataset(data_)
loader = data.DataLoader(dataset=map_dataset, batch_size=4)
for batch in loader:
    print(batch)

tensor([0, 1, 2, 3])
tensor([4, 5, 6, 7])
tensor([ 8,  9, 10, 11])


In [31]:

class MyIterableDataset_simple(data.IterableDataset):
    """Iterable-style dataset that streams the items of a wrapped iterable."""

    def __init__(self, data_):
        # Keep a reference to the iterable; nothing is copied.
        self.data_ = data_

    def __iter__(self):
        """Return a fresh iterator over the wrapped iterable."""
        source = self.data_
        return iter(source)

# Same batching demo as for the map-style dataset, driven by the iterable-style one.
iterable_dataset = MyIterableDataset_simple(data_)
loader = data.DataLoader(dataset=iterable_dataset, batch_size=4)
for batch in loader:
    print(batch)

tensor([0, 1, 2, 3])
tensor([4, 5, 6, 7])
tensor([ 8,  9, 10, 11])


In [151]:
from itertools import cycle, islice, chain, zip_longest
import random

class MyIterableDataset(data.IterableDataset):
    """Iterable dataset that reads ``batch_size`` streams in lock-step.

    ``data_list`` is a list of sequences.  Consecutive sequences are grouped
    into ``batch_size`` streams; every iteration step yields a tuple holding
    one item per stream, with exhausted streams padded by ``-1``.

    The stream assignment (``picked`` / ``num_pick``) is consumed while the
    streams are built, so ``reset()`` must be called before iterating again.
    """

    def __init__(self, data_list, batch_size):
        """Store the inputs and compute the initial stream assignment.

        Raises:
            ValueError: if ``batch_size`` is smaller than 1.
        """
        self.data_list = data_list
        self.batch_size = batch_size
        self.reset()

    def reset(self):
        """Recompute which sequences go to which stream and print the plan.

        Every stream starts with one sequence; any surplus sequences are
        handed out one at a time, starting with the first stream.

        Raises:
            ValueError: if ``batch_size`` is smaller than 1 (the assignment
                would otherwise never terminate).
        """
        if self.batch_size < 1:
            raise ValueError('batch_size must be at least 1, got {!r}'.format(self.batch_size))

        self.picked = list(range(len(self.data_list)))

        # Closed form of "deal the surplus round-robin over the streams".
        surplus = max(len(self.data_list) - self.batch_size, 0)
        base, remainder = divmod(surplus, self.batch_size)
        self.num_pick = [1 + base + (1 if i < remainder else 0)
                         for i in range(self.batch_size)]

        print('picked: ', self.picked)
        print('num_pick: ', self.num_pick)

    @property
    def shuffled_data_list(self):
        """Return the sequences for the next stream and mark them as used.

        Despite the name no shuffling happens.  Each access consumes the head
        of ``picked`` and ``num_pick``; once ``picked`` is empty an empty list
        is returned.
        """
        if not self.picked:
            return []
        start = self.picked[0]
        count = self.num_pick.pop(0)
        del self.picked[:count]
        return self.data_list[start:start + count]

    def process_data(self, data):
        """Yield the items of one sequence unchanged."""
        for x in data:
            yield x

    def get_stream(self, data_list):
        """Chain the items of every sequence in ``data_list`` into one stream."""
        return chain.from_iterable(map(self.process_data, data_list))

    def get_streams(self):
        """Build ``batch_size`` streams and zip them, padding with ``-1``."""
        streams = [self.get_stream(self.shuffled_data_list)
                   for _ in range(self.batch_size)]
        return zip_longest(*streams, fillvalue=-1)

    def __iter__(self):
        return self.get_streams()
    
data_list = [
    [12, 13, 14, 15, 16, 17],
    [27, 28, 29, 30],
    [31, 32, 33, 34, 35, 36, 37, 38, 39],
    [40, 41, 42, 43],
]

iterable_dataset = MyIterableDataset(data_list, batch_size=3)
# batch_size=None: the dataset already yields complete batches.
loader = data.DataLoader(dataset=iterable_dataset, batch_size=None)

# First pass over every batch.
for batch in loader:
    print(batch)
print('-----------')

# The stream assignment was consumed above; rebuild it before the second pass.
iterable_dataset.reset()
for batch in loader:
    print(batch)

picked:  [0, 1, 2, 3]
num_pick:  [2, 1, 1]
[12, 31, 40]
[13, 32, 41]
[14, 33, 42]
[15, 34, 43]
[16, 35, -1]
[17, 36, -1]
[27, 37, -1]
[28, 38, -1]
[29, 39, -1]
[30, -1, -1]
-----------
picked:  [0, 1, 2, 3]
num_pick:  [2, 1, 1]
[12, 31, 40]
[13, 32, 41]
[14, 33, 42]
[15, 34, 43]
[16, 35, -1]
[17, 36, -1]
[27, 37, -1]
[28, 38, -1]
[29, 39, -1]
[30, -1, -1]


In [160]:
import torch
import time

class MyIterableDataset_s(data.IterableDataset):
    """Endless iterable dataset that zips ``batch_size`` shuffled streams.

    Every stream cycles forever over its own random permutation of
    ``data_list`` (a list of sequences); each iteration step yields a tuple
    with one item per stream.
    """

    # Seconds slept per item in ``process_data`` to simulate processing cost.
    SIMULATED_DELAY_S = 0.1

    def __init__(self, data_list, batch_size):
        self.data_list = data_list
        self.batch_size = batch_size

    @property
    def shuffled_data_list(self):
        """Return a new randomly ordered list of the sequences in ``data_list``."""
        return random.sample(self.data_list, len(self.data_list))

    def process_data(self, data):
        """Yield the items of one sequence, sleeping before each item."""
        for x in data:
            time.sleep(self.SIMULATED_DELAY_S)
            yield x

    def get_stream(self, data_list):
        """Chain the items of the sequences in ``data_list``, repeating forever."""
        return chain.from_iterable(map(self.process_data, cycle(data_list)))

    def get_streams(self):
        """Zip ``batch_size`` independently shuffled streams together."""
        return zip(*[self.get_stream(self.shuffled_data_list)
                     for _ in range(self.batch_size)])

    def __iter__(self):
        return self.get_streams()

    @classmethod
    def split_datasets(cls, data_list, batch_size, max_workers):
        """Create one dataset per worker so that together they fill a batch.

        The worker count is the largest ``n <= max_workers`` that divides
        ``batch_size``; every dataset receives the full ``data_list`` and a
        batch size of ``batch_size // n``.

        Raises:
            ValueError: if ``max_workers`` is smaller than 1.
        """
        if max_workers < 1:
            raise ValueError('max_workers must be at least 1, got {!r}'.format(max_workers))

        # n == 1 always divides batch_size, so the search cannot come up empty.
        num_workers = next(n for n in range(max_workers, 0, -1)
                           if batch_size % n == 0)
        split_size = batch_size // num_workers
        return [cls(data_list, batch_size=split_size)
                for _ in range(num_workers)]

class MultiStreamDataLoader:
    """Merge the output of several datasets, each read by its own DataLoader.

    Every dataset is wrapped in a single-worker ``DataLoader`` with
    ``batch_size=None``; at each step the parts produced by all loaders are
    concatenated into one list.
    """

    def __init__(self, datasets):
        self.datasets = datasets

    def get_stream_loaders(self):
        """Return a zip over one DataLoader per dataset held by this instance.

        Iteration stops as soon as the shortest loader is exhausted.
        """
        # Use the instance's own datasets rather than a module-level global.
        return zip(*[data.DataLoader(dataset, num_workers=1, batch_size=None)
                     for dataset in self.datasets])

    def __iter__(self):
        for batch_parts in self.get_stream_loaders():
            # Flatten the per-loader parts into a single batch list.
            yield list(chain.from_iterable(batch_parts))

            
data_list = [
    [12, 13, 14, 15, 16, 17],
    [27, 28, 29],
    [31, 32, 33, 34, 35, 36, 37, 38, 39],
    [40, 41, 42, 43],
]

datasets = MyIterableDataset_s.split_datasets(data_list, batch_size=4, max_workers=1)
loader = MultiStreamDataLoader(datasets)

# The streams cycle forever, so only the first 12 batches are shown.
first_batches = islice(loader, 12)
for batch in first_batches:
    print(batch)

[31, 40, 40, 31]
[32, 41, 41, 32]
[33, 42, 42, 33]
[34, 43, 43, 34]
[35, 31, 12, 35]
[36, 32, 13, 36]
[37, 33, 14, 37]
[38, 34, 15, 38]
[39, 35, 16, 39]
[12, 36, 17, 27]
[13, 37, 31, 28]
[14, 38, 32, 29]


In [237]:
import numpy as np

def getSum(n):
    """Return the sum of the decimal digits of ``n``, ignoring a leading minus sign.

    ``n`` is converted with ``str()`` first, so both integers and digit
    strings are accepted.
    """
    # Strip the sign so negative values do not fail on int('-').
    digits = str(n).lstrip('-')
    return sum(int(digit) for digit in digits)

class MyIterableDataset_s2(data.IterableDataset):
    """Iterable dataset that reads ``batch_size`` tagged streams in lock-step.

    ``data_list`` is a list of sequences.  Consecutive sequences are grouped
    into ``batch_size`` streams; every iteration step yields a tuple holding
    one ``(item, tag)`` pair per stream, with exhausted streams padded by
    ``[-1, -1]``.

    The stream assignment (``picked`` / ``num_pick``) is consumed while the
    streams are built, so ``reset()`` must be called before iterating again.
    """

    # Seconds slept per item in ``process_data`` to simulate processing cost.
    SIMULATED_DELAY_S = 0.1

    def __init__(self, data_list, batch_size):
        """Store the inputs and compute the initial stream assignment.

        Raises:
            ValueError: if ``batch_size`` is smaller than 1.
        """
        self.data_list = data_list
        self.batch_size = batch_size
        self.reset()

    def reset(self):
        """Recompute which sequences go to which stream and print the plan.

        Every stream starts with one sequence; any surplus sequences are
        handed out one at a time, starting with the first stream.

        Raises:
            ValueError: if ``batch_size`` is smaller than 1 (the assignment
                would otherwise never terminate).
        """
        if self.batch_size < 1:
            raise ValueError('batch_size must be at least 1, got {!r}'.format(self.batch_size))

        self.picked = list(range(len(self.data_list)))

        # Closed form of "deal the surplus round-robin over the streams".
        surplus = max(len(self.data_list) - self.batch_size, 0)
        base, remainder = divmod(surplus, self.batch_size)
        self.num_pick = [1 + base + (1 if i < remainder else 0)
                         for i in range(self.batch_size)]

        print('picked: ', self.picked)
        print('num_pick: ', self.num_pick)

    @property
    def shuffled_data_list(self):
        """Return the sequences for the next stream and mark them as used.

        Despite the name no shuffling happens.  Each access consumes the head
        of ``picked`` and ``num_pick``; once ``picked`` is empty an empty list
        is returned.
        """
        if not self.picked:
            return []
        start = self.picked[0]
        count = self.num_pick.pop(0)
        del self.picked[:count]
        return self.data_list[start:start + count]

    def process_data(self, data):
        """Yield ``(item, tag)`` for every item of one sequence.

        The tag is ``getSum(id(self))`` when running inside a DataLoader
        worker and ``getSum(-1)`` otherwise.
        """
        # The parameter shadows the ``data`` module, hence the fully qualified call.
        worker = torch.utils.data.get_worker_info()
        worker_id = id(self) if worker is not None else -1
        for x in data:
            time.sleep(self.SIMULATED_DELAY_S)
            yield x, getSum(worker_id)

    def get_stream(self, data_list):
        """Chain the tagged items of every sequence in ``data_list``."""
        return chain.from_iterable(map(self.process_data, data_list))

    def get_streams(self):
        """Build ``batch_size`` streams and zip them, padding with ``[-1, -1]``."""
        streams = [self.get_stream(self.shuffled_data_list)
                   for _ in range(self.batch_size)]
        return zip_longest(*streams, fillvalue=[-1, -1])

    def __iter__(self):
        return self.get_streams()

    @classmethod
    def split_datasets(cls, data_list, batch_size, max_workers):
        """Split ``data_list`` over several datasets that together fill a batch.

        The worker count is the largest ``n <= max_workers`` that divides
        ``batch_size``.  The sequences are ordered longest first (the caller's
        list is left untouched) and cut into ``n`` contiguous chunks whose
        sizes differ by at most one, with the larger chunks first.  Each chunk
        becomes a dataset with batch size ``batch_size // n``.

        Raises:
            ValueError: if ``max_workers`` is smaller than 1.
        """
        if max_workers < 1:
            raise ValueError('max_workers must be at least 1, got {!r}'.format(max_workers))

        # n == 1 always divides batch_size, so the search cannot come up empty.
        num_workers = next(n for n in range(max_workers, 0, -1)
                           if batch_size % n == 0)
        print('num_workers: ', num_workers)

        # Sort a copy so the argument is not reordered behind the caller's back.
        ordered = sorted(data_list, key=len, reverse=True)

        # Plain slicing keeps ragged sequences as lists; no array conversion needed.
        base, extra = divmod(len(ordered), num_workers)
        data_lists = []
        start = 0
        for i in range(num_workers):
            stop = start + base + (1 if i < extra else 0)
            data_lists.append(ordered[start:stop])
            start = stop
        print('data_lists: ', len(data_lists))

        split_size = batch_size // num_workers
        print('split_size: ', split_size)

        for index, chunk in enumerate(data_lists):
            print('[{}] {} lists: {}'.format(index, len(chunk), chunk))
        return [cls(chunk, batch_size=split_size) for chunk in data_lists]

    def pad_lists(self, data_lists):
        """Unimplemented placeholder; currently does nothing and returns ``None``."""
        return None
        
        

class MultiStreamDataLoader:
    """Merge the output of several datasets, each read by its own DataLoader.

    Every dataset is wrapped in a single-worker ``DataLoader`` with
    ``batch_size=None``; at each step the parts produced by all loaders are
    concatenated into one list.
    """

    def __init__(self, datasets):
        self.datasets = datasets

    def get_stream_loaders(self):
        """Return a zip over one DataLoader per dataset held by this instance.

        Iteration stops as soon as the shortest loader is exhausted.
        """
        # Use the instance's own datasets rather than a module-level global.
        return zip(*[data.DataLoader(dataset, num_workers=1, batch_size=None)
                     for dataset in self.datasets])

    def __iter__(self):
        for batch_parts in self.get_stream_loaders():
            # Flatten the per-loader parts into a single batch list.
            yield list(chain.from_iterable(batch_parts))

            
data_list = [
    [12, 13, 14, 15, 16, 17],
    [27, 28, 29],
    [31, 32, 33, 34, 35, 36, 37, 38, 39, 40],
    [40, 41, 42, 43, 44],
]

datasets = MyIterableDataset_s2.split_datasets(data_list, batch_size=4, max_workers=2)
loader = MultiStreamDataLoader(datasets)
print('\n')

# The streams are finite here, so the loader is simply drained.
for batch in loader:
    print(batch)

num_workers:  2
data_lists:  2
split_size:  2
[0] 2 lists: [list([31, 32, 33, 34, 35, 36, 37, 38, 39, 40])
 list([12, 13, 14, 15, 16, 17])]
[1] 2 lists: [list([40, 41, 42, 43, 44]) list([27, 28, 29])]
picked:  [0, 1]
num_pick:  [1, 1]
picked:  [0, 1]
num_pick:  [1, 1]


[[31, 67], [12, 67], [40, 76], [27, 76]]
[[32, 67], [13, 67], [41, 76], [28, 76]]
[[33, 67], [14, 67], [42, 76], [29, 76]]
[[34, 67], [15, 67], [43, 76], [-1, -1]]
[[35, 67], [16, 67], [44, 76], [-1, -1]]
