In [1]:
import torch
from torch.utils.data import Dataset

In [3]:
class ToySet(Dataset):
    """Toy dataset of constant tensors with an optional per-sample transform.

    Every sample is the pair ``(x, y)``, where ``x`` is a length-2 tensor
    filled with 2.0 and ``y`` is a length-1 tensor filled with 1.0.

    Args:
        length: Number of samples held by the dataset.
        transform: Optional callable that receives the ``(x, y)`` tuple and
            returns the sample to hand back from ``__getitem__``. ``None``
            (the default) returns samples unchanged.
    """

    def __init__(self, length=100, transform=None):
        self.x = 2 * torch.ones(length, 2)
        self.y = torch.ones(length, 1)
        self.len = length
        self.transform = transform

    def __getitem__(self, index):
        """Return the sample at ``index``, transformed if a transform is set."""
        sample = self.x[index], self.y[index]
        # Compare against None rather than relying on truthiness: a callable
        # that defines __len__ or __bool__ can evaluate as falsy and would
        # otherwise be skipped without any error.
        if self.transform is not None:
            sample = self.transform(sample)
        return sample

    def __len__(self):
        """Return the number of samples in the dataset."""
        return self.len

In [4]:
# Build the dataset with its defaults (100 samples, no transform) and
# display its size via __len__.
dataset = ToySet()
len(dataset)

100

In [5]:
# Indexing goes through __getitem__ and yields one (x, y) tuple.
dataset[0]

(tensor([2., 2.]), tensor([1.]))

In [6]:
# Show the first three untransformed samples; the generator fetches each
# sample lazily, right before it is printed.
for i, (x, y) in enumerate(dataset[k] for k in range(3)):
    print(i, "x:", x, "y:", y)

0 x: tensor([2., 2.]) y: tensor([1.])
1 x: tensor([2., 2.]) y: tensor([1.])
2 x: tensor([2., 2.]) y: tensor([1.])


In [17]:
class AddMult(object):
    """Sample transform: add a constant to ``x`` and scale ``y``.

    Args:
        addx: Value added to the first element of the sample.
        muly: Factor the second element of the sample is multiplied by.
    """

    def __init__(self, addx=1, muly=1):
        self.addx, self.muly = addx, muly

    def __call__(self, sample):
        """Return ``(sample[0] + addx, sample[1] * muly)`` as a new tuple."""
        shifted = sample[0] + self.addx
        scaled = sample[1] * self.muly
        return shifted, scaled

In [18]:
# Same toy data, but each sample passes through AddMult on access:
# x is shifted by 3 and y is scaled by 4.
dataset2 = ToySet(transform=AddMult(addx=3, muly=4))
for i, (x, y) in enumerate(dataset2[k] for k in range(3)):
    print(i, "x:", x, "y:", y)

0 x: tensor([5., 5.]) y: tensor([4.])
1 x: tensor([5., 5.]) y: tensor([4.])
2 x: tensor([5., 5.]) y: tensor([4.])


In [31]:
class Mult(object):
    """Sample transform: scale both ``x`` and ``y`` by the same factor.

    Args:
        mul: Factor applied to the first two elements of the sample.
    """

    def __init__(self, mul=100):
        self.mul = mul

    def __call__(self, sample):
        """Return ``(sample[0] * mul, sample[1] * mul)`` as a new tuple."""
        factor = self.mul
        return sample[0] * factor, sample[1] * factor


In [32]:
from torchvision import transforms

# Chain the two sample transforms; Compose applies them in list order, so
# AddMult runs first (x + 4, y * 9) and Mult then scales both results by 10.
data_transform = transforms.Compose([
    AddMult(addx=4, muly=9),
    Mult(mul=10)
])

In [34]:
# Feed every sample through the composed pipeline and preview the first three.
dataset3 = ToySet(transform=data_transform)
for i, (x, y) in enumerate(dataset3[k] for k in range(3)):
    print(i, "x:", x, "y:", y)

0 x: tensor([60., 60.]) y: tensor([90.])
1 x: tensor([60., 60.]) y: tensor([90.])
2 x: tensor([60., 60.]) y: tensor([90.])
