In [43]:
import functools

import torch

In [44]:
# Toy data: 100 samples with 3 random features each. Training labels are
# 0..99 and validation labels run 100..1, so the printed batches show exactly
# which samples (and in what order) a loader returned.
# Seeding torch's global RNG makes Restart & Run All reproduce the same
# features and, since DataLoader shuffles with torch.randperm, the same
# shuffled batches further down.
SEED = 0
torch.manual_seed(SEED)

x_tr = torch.randn((100, 3))
y_tr = torch.arange(100)

x_val = torch.randn((100, 3))
y_val = 100 - torch.arange(100)

In [45]:
class DataSet:
    """Minimal map-style dataset pairing inputs ``x`` with targets ``y``.

    Indexing is delegated to the underlying containers. An int index returns
    a single ``(x, y)`` pair. A slice or a 1-D index tensor returns a batch
    ``(x[idx], y[idx])``, which is what the tensor-batching DataLoader in this
    notebook uses.

    Args:
        x: indexable inputs (tensors in this notebook).
        y: indexable targets, one per entry of ``x``.

    Raises:
        ValueError: if ``x`` and ``y`` have different lengths.
    """

    def __init__(self, x, y):
        # With a mismatch, __len__ (based on x) would either silently drop
        # trailing targets or index past the end of y later, so reject it now.
        if len(x) != len(y):
            raise ValueError(
                f"x and y must have the same length, got {len(x)} and {len(y)}"
            )
        self.x = x
        self.y = y

    def __len__(self):
        """Number of samples (length of ``x``)."""
        return len(self.x)

    def __getitem__(self, i):
        """Return ``(x[i], y[i])`` for an int, slice or index tensor ``i``."""
        return self.x[i], self.y[i]

In [46]:
# The class defined above is `DataSet`. `Dataset` (lower-case "s") is not
# defined anywhere in this notebook, so it raises NameError on a fresh kernel.
d_tr = DataSet(x_tr, y_tr)

In [47]:
# An int index yields a single (features, label) pair.
sample_idx = 12
d_tr[sample_idx]

(tensor([ 2.0908, -1.6859,  2.0738]), tensor(12))

In [48]:
# A slice is forwarded to the underlying tensors, giving a batch of 3 samples.
d_tr[slice(12, 15)]

(tensor([[ 2.0908, -1.6859,  2.0738],
         [ 0.0132, -0.1084, -0.3203],
         [-0.2289,  0.1321,  1.1599]]), tensor([12, 13, 14]))

In [49]:
# Number of samples in the training set.
n_train = len(d_tr)
n_train

100

In [58]:
def collate(b):
    """Stack a sequence of ``(x, y)`` samples into one ``(xb, yb)`` batch.

    Args:
        b: non-empty sequence of ``(x, y)`` tensor pairs; all ``x`` share a
            shape, and likewise all ``y``.

    Returns:
        tuple: ``(torch.stack(xs), torch.stack(ys))``.
    """
    xs, ys = zip(*b)
    return torch.stack(xs), torch.stack(ys)


class DataLoader:
    """Iterate over a dataset in mini-batches, optionally in shuffled order.

    By default each batch is fetched with one call ``dataset[index_tensor]``,
    so the dataset must accept a 1-D tensor of indices (as a tensor-backed
    ``DataSet`` does). If ``collate_fn`` is given (e.g. ``collate``), samples
    are fetched one at a time with int indices and merged by ``collate_fn``.
    That path works for any map-style dataset.

    Args:
        dataset: object supporting ``len()`` and indexing.
        batch_size: samples per batch; the last batch may be smaller.
        shuffle: if True, draw a fresh ``torch.randperm`` order on every
            iteration (uses torch's global RNG).
        collate_fn: optional callable turning a list of samples into a batch.
            ``None`` (default) keeps the tensor-indexing behaviour.
    """

    def __init__(self, dataset, batch_size, shuffle=False, collate_fn=None):
        self.n = len(dataset)
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.collate_fn = collate_fn

    def __len__(self):
        """Number of batches per pass, i.e. ceil(n / batch_size)."""
        return (self.n + self.batch_size - 1) // self.batch_size

    def __iter__(self):
        idxs = torch.randperm(self.n) if self.shuffle else torch.arange(self.n)
        for start in range(0, self.n, self.batch_size):
            batch_idxs = idxs[start:start + self.batch_size]
            if self.collate_fn is None:
                # Fast path: one fancy-indexing call per batch.
                yield self.dataset[batch_idxs]
            else:
                yield self.collate_fn([self.dataset[int(j)] for j in batch_idxs])


In [59]:
# Two loaders over the same training data with batches of 11; they differ
# only in iteration order.
sampler_no_shuffle = DataLoader(d_tr, batch_size=11, shuffle=False)
sampler_shuffle = DataLoader(d_tr, batch_size=11, shuffle=True)

In [63]:
# First batch of the unshuffled loader: labels come out in order, 0..10.
first_batch = next(iter(sampler_no_shuffle))
xb, yb = first_batch
xb, yb

(tensor([[-0.8577,  1.1585, -0.0140],
         [-1.5455,  0.9046, -0.8732],
         [-0.5231,  0.1205, -0.8968],
         [ 0.2439, -0.0118,  1.2558],
         [ 0.2621, -1.1611, -0.7014],
         [-0.5115,  0.8443,  0.5019],
         [ 1.9589,  0.1374, -2.1609],
         [-0.1661,  0.1940, -1.2070],
         [ 0.4323, -0.3291, -1.3778],
         [-0.6309, -0.4967,  0.5715],
         [ 0.9088,  1.2098,  0.4684]]),
 tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10]))

In [65]:
# First batch of the shuffled loader: 11 labels in random order.
shuffled_iter = iter(sampler_shuffle)
xb, yb = next(shuffled_iter)
xb, yb

(tensor([[-1.3935,  0.2471,  1.2907],
         [-0.4372, -1.5295,  0.4559],
         [-0.2758, -0.0591,  1.1778],
         [ 0.5983, -0.6433, -0.6772],
         [-0.4511,  0.6935,  0.0808],
         [ 0.0923, -0.4083, -0.3554],
         [-0.9065, -0.4453,  1.6310],
         [ 1.0713,  1.7136, -1.3117],
         [-2.0185, -0.1837,  0.4337],
         [ 0.6355, -0.1279, -1.1140],
         [ 1.7717, -0.4918,  0.1049]]),
 tensor([28, 27, 70, 52, 81, 69, 54, 95, 47, 84, 90]))

In [66]:
class DataBunch:
    """Bundle a training and a validation ``DataLoader`` built from datasets.

    The training loader shuffles (``shuffle=True``). The validation loader
    keeps dataset order. Both use the same ``batch_size``.

    Args:
        train_ds: dataset for the training loader.
        valid_ds: dataset for the validation loader.
        batch_size: batch size for both loaders.
        n_in: optional metadata, stored unchanged as ``self.n_in``.
        n_out: optional metadata, stored unchanged as ``self.n_out``.
    """

    def __init__(self, train_ds, valid_ds, batch_size, n_in=None, n_out=None):
        train_loader = DataLoader(train_ds, batch_size, shuffle=True)
        valid_loader = DataLoader(valid_ds, batch_size)
        self.train_gen, self.valid_gen = train_loader, valid_loader
        self.n_in, self.n_out = n_in, n_out

    @property
    def train_ds(self):
        """Dataset backing the training loader."""
        loader = self.train_gen
        return loader.dataset

    @property
    def valid_ds(self):
        """Dataset backing the validation loader."""
        loader = self.valid_gen
        return loader.dataset

In [67]:
db = DataBunch(
    train_ds=DataSet(x_tr, y_tr),
    valid_ds=DataSet(x_val, y_val),
    batch_size=11,
    n_in=3,
    n_out=1,
)

In [68]:
# Metadata passed to DataBunch and stored unchanged.
(db.n_in, db.n_out)

(3, 1)

In [70]:
# Shape of the training inputs reached through the DataBunch property.
db.train_ds.x.size()

torch.Size([100, 3])

In [71]:
# First (shuffled) training batch.
train_batches = iter(db.train_gen)
next(train_batches)

(tensor([[-0.1661,  0.1940, -1.2070],
         [ 0.6420,  1.5221,  0.2802],
         [-0.6288, -0.3965,  0.1972],
         [-0.6775,  0.8000, -0.9840],
         [-0.3589, -1.5455,  0.5182],
         [ 0.7506, -1.0180,  0.1828],
         [ 0.6052,  0.2159, -0.2645],
         [-1.0243, -0.2480, -1.2089],
         [ 1.2348,  0.7668, -0.8181],
         [ 1.1693, -0.5535, -0.8331],
         [ 0.0555,  0.5442, -0.8952]]),
 tensor([ 7, 16, 89, 63, 93, 42, 32, 60, 48, 40, 26]))

In [72]:
# First validation batch, in dataset order.
valid_batches = iter(db.valid_gen)
next(valid_batches)

(tensor([[ 0.1451,  0.9444,  1.1938],
         [-0.3524,  0.3700,  1.2762],
         [-1.4916,  0.1727, -0.3980],
         [ 1.7773, -0.9729, -1.3781],
         [-1.0903,  1.5241, -0.1308],
         [-0.0421, -2.9725, -0.7142],
         [-0.3718,  1.1710, -0.6711],
         [-0.8236,  1.9846, -0.0979],
         [ 1.7163,  0.1831, -0.8909],
         [-0.8961, -0.6836, -0.4825],
         [-0.0070, -0.2899, -0.1987]]),
 tensor([100,  99,  98,  97,  96,  95,  94,  93,  92,  91,  90]))

In [4]:
def wrap(f):
    """Decorate ``f`` so that 10 is added to its result.

    ``functools.wraps`` copies ``f``'s ``__name__``, ``__doc__`` and related
    attributes onto the wrapper, so the decorated function keeps its
    identity. Positional and keyword arguments are all forwarded, not just
    a single positional one.

    Args:
        f: callable whose return value supports ``+ 10``.

    Returns:
        callable: wrapper returning ``f(*args, **kwargs) + 10``.
    """
    @functools.wraps(f)
    def _inner(*args, **kwargs):
        return f(*args, **kwargs) + 10
    return _inner


In [5]:
@wrap
def f(x):
    """Return ``5 * x``; the ``wrap`` decorator then adds 10 to the result."""
    scaled = 5 * x
    return scaled

In [6]:
# 5 * 5 = 25, plus the 10 added by `wrap`.
result = f(5)
result

35