In [9]:
%load_ext autoreload
%autoreload 2

In [10]:
#export
import torch, math, numpy as np

In [14]:
#export
class Sampler:
    """
    Random batch sampler over an indexable dataset.

    A single random permutation of the dataset indices is drawn at construction time
    and then cut into consecutive chunks of `batchSize`. Indexing the sampler with a
    batch number returns the dataset items of that chunk, packed into a tensor.
    Because the permutation is drawn only once, the same batch number always maps to
    the same items for the lifetime of the sampler.
    """
    def __init__(self, dataset, batchSize:int):
        """
        Creates a random sampler. `dataset` expected to have __len__() and __getitem__() defined.

        Given a dataset of length n, this splits it into ceil(n / batchSize) batches; the
        last batch holds the remainder and may therefore be smaller than `batchSize`.

        :param dataset: indexable collection; items are packed with `torch.Tensor([...])`
        :param batchSize: number of dataset items per batch, must be positive
        :raises ValueError: if `batchSize` is not positive
        """
        if batchSize <= 0:
            # 0 would divide by zero below, negatives would produce a negative batch count
            raise ValueError(f"batchSize must be positive, got {batchSize}")
        n = len(dataset)
        self.dataset = dataset
        self.nBatches = math.ceil(n / batchSize)
        self.batchSize = batchSize
        self.idxs = np.random.permutation(n)
    def __len__(self):
        """Number of batches (not number of dataset items)"""
        return self.nBatches
    def __getitem__(self, i):
        """
        Returns batch number `i` as a tensor with its axes reversed (items end up on the last axis).

        Negative batch numbers count from the end, like a regular sequence.

        :raises IndexError: if `i` is outside [-len(self), len(self))
        """
        if i < 0: i += self.nBatches
        if not 0 <= i < self.nBatches:
            # without this check an out-of-range slice would silently yield an empty tensor
            raise IndexError(f"batch index out of range for sampler with {self.nBatches} batches")
        items = self.idxs[i*self.batchSize:(i+1)*self.batchSize]
        batch = torch.Tensor([self.dataset[item] for item in items])
        # same result as `batch.T`, which is deprecated for tensors that are not 2-D
        return batch.permute(tuple(reversed(range(batch.dim()))))
    def __iter__(self):
        """Lazily yields every batch once, in batch-number order"""
        return (self[i] for i in range(self.nBatches))

In [12]:
#export
class DataLoader:
    """
    Re-iterable wrapper around a generator factory. Every pass asks the factory
    for a fresh iterable, so the loader can be looped over repeatedly:

    >>> for xb in xDataLoader:
    >>>     # do something
    """
    def __init__(self, fGenerator:callable, length:int):
        """
        :param fGenerator: zero-argument callable returning a new iterable of elements
        :param length: value reported by len(); stored as given, not checked against the iterable
        """
        self.length = length
        self.fGenerator = fGenerator
    def __len__(self):
        """Returns the length supplied at construction"""
        return self.length
    def __call__(self):
        """Returns a fresh iterable straight from the factory"""
        return self.fGenerator()
    def __iter__(self):
        """Yields the elements of a fresh iterable; the factory is only invoked once iteration starts"""
        freshPass = self.fGenerator()
        for batch in freshPass:
            yield batch
class Data:
    """Thin container pairing two DataLoaders, `train` and `valid`"""
    def __init__(self, train:DataLoader, valid:DataLoader):
        """Stores both loaders; each is expected to return a generator when called upon"""
        self.train = train
        self.valid = valid
    @staticmethod
    def fromDataset(dataset, batchSize, trainSplit=0.8):
        """
        Builds a Data object from a single dataset by sharing one Sampler between both loaders.

        The split happens at batch granularity: the first ceil(trainSplit * numBatches)
        batches of the sampler feed `train`, the remaining batches feed `valid`.

        :param dataset: passed unchanged to Sampler
        :param batchSize: passed unchanged to Sampler
        :param trainSplit: fraction of the batches assigned to the training loader
        """
        sampler = Sampler(dataset, batchSize)
        numBatches = len(sampler)
        splitAt = math.ceil(trainSplit * numBatches)
        def loaderFor(batchIdxs):
            # the lambda lets every pass over the loader build a new generator over the same batches
            return DataLoader(lambda: (sampler[b] for b in batchIdxs), len(batchIdxs))
        trainLoader = loaderFor(range(splitAt))
        validLoader = loaderFor(range(splitAt, numBatches))
        return Data(trainLoader, validLoader)

In [15]:
!exportnb data.ipynb

/home/kelvin/repos/labs/k1lib/k1lib
Current dir: 0, /home/kelvin/repos/labs/k1lib/export.py
File: /home/kelvin/repos/labs/k1lib/k1lib/data.py
running bdist_wheel
running build
installing to build/bdist.linux-x86_64/wheel
running install
running install_egg_info
running egg_info
writing k1lib.egg-info/PKG-INFO
writing dependency_links to k1lib.egg-info/dependency_links.txt
writing top-level names to k1lib.egg-info/top_level.txt
reading manifest file 'k1lib.egg-info/SOURCES.txt'
writing manifest file 'k1lib.egg-info/SOURCES.txt'
Copying k1lib.egg-info to build/bdist.linux-x86_64/wheel/k1lib-0.1.0-py3.8.egg-info
running install_scripts
creating build/bdist.linux-x86_64/wheel/k1lib-0.1.0.dist-info/WHEEL
creating 'dist/k1lib-0.1.0-py3-none-any.whl' and adding 'build/bdist.linux-x86_64/wheel' to it
adding 'k1lib-0.1.0.dist-info/METADATA'
adding 'k1lib-0.1.0.dist-info/WHEEL'
adding 'k1lib-0.1.0.dist-info/top_level.txt'
adding 'k1lib-0.1.0.dist-info/RECORD'
removing build/bdist.linux-x86_64/wh

In [None]:
s