In [1]:
import torch
from torch.utils.data import Dataset

In [2]:
class toy_set(Dataset):
    """Toy in-memory dataset of constant feature/target pairs.

    Every raw sample is ``(x, y)`` where ``x`` is a length-2 tensor of 2s and
    ``y`` is a length-1 tensor of 1s, so any variation between samples comes
    only from the optional ``transform``.

    Args:
        length: Number of samples in the dataset.
        transform: Optional callable applied to the ``(x, y)`` tuple on every
            ``__getitem__`` call; its return value becomes the sample.
    """

    def __init__(self, length=100, transform=None):
        # (length, 2) features filled with 2.0 and (length, 1) targets of 1.0.
        self.x = 2 * torch.ones(length, 2)
        self.y = torch.ones(length, 1)
        self.len = length
        self.transform = transform

    def __getitem__(self, index):
        """Return the ``(x, y)`` pair at ``index``, transformed if a transform was given."""
        sample = self.x[index], self.y[index]
        # Compare against None explicitly: a callable that happens to be falsy
        # (e.g. a container-like transform whose __len__ is 0) must still run.
        if self.transform is not None:
            sample = self.transform(sample)
        return sample

    def __len__(self):
        """Return the number of samples."""
        return self.len

In [3]:
# Untransformed toy dataset, with the default arguments spelled out.
dataset = toy_set(length=100, transform=None)

In [4]:
len(dataset)  # number of samples - toy_set's default length is 100

100

In [5]:
dataset[0]  # first raw (x, y) pair: x = tensor([2., 2.]), y = tensor([1.])

(tensor([2., 2.]), tensor([1.]))

In [6]:
# Print the first ten raw samples; every row should read x=[2., 2.], y=[1.].
for i in range(10):
    x, y = dataset[i]
    print(f'{i} x: {x}, y: {y}')

0 x: tensor([2., 2.]), y: tensor([1.])
1 x: tensor([2., 2.]), y: tensor([1.])
2 x: tensor([2., 2.]), y: tensor([1.])
3 x: tensor([2., 2.]), y: tensor([1.])
4 x: tensor([2., 2.]), y: tensor([1.])
5 x: tensor([2., 2.]), y: tensor([1.])
6 x: tensor([2., 2.]), y: tensor([1.])
7 x: tensor([2., 2.]), y: tensor([1.])
8 x: tensor([2., 2.]), y: tensor([1.])
9 x: tensor([2., 2.]), y: tensor([1.])


In [7]:
class add_mult:
    """Transform that shifts features and scales targets.

    Called on an indexable ``(x, y)`` sample, returns the tuple
    ``(x + addx, y * muly)``.

    Args:
        addx: Value added to the feature element ``sample[0]``.
        muly: Factor multiplied into the target element ``sample[1]``.
    """

    def __init__(self, addx=1, muly=1):
        self.addx = addx
        self.muly = muly

    def __call__(self, sample):
        # Index explicitly rather than unpacking so longer samples still work.
        features, target = sample[0], sample[1]
        return features + self.addx, target * self.muly
    
class mult:
    """Transform that scales both features and targets by one factor.

    Called on an indexable ``(x, y)`` sample, returns the tuple
    ``(x * mul, y * mul)``.

    Args:
        mul: Factor applied to both ``sample[0]`` and ``sample[1]``.
    """

    def __init__(self, mul=1):
        self.mul = mul

    def __call__(self, sample):
        factor = self.mul
        return sample[0] * factor, sample[1] * factor

In [8]:
# Shift features by +1 and scale targets by 2 on every item access.
am = add_mult(addx=1, muly=2)
dataset_ = toy_set(transform=am)

In [9]:
# Same ten rows through add_mult(1, 2): features 2 -> 3, targets 1 -> 2.
for i in range(10):
    x, y = dataset_[i]
    print('%d x: %s, y: %s' % (i, x, y))

0 x: tensor([3., 3.]), y: tensor([2.])
1 x: tensor([3., 3.]), y: tensor([2.])
2 x: tensor([3., 3.]), y: tensor([2.])
3 x: tensor([3., 3.]), y: tensor([2.])
4 x: tensor([3., 3.]), y: tensor([2.])
5 x: tensor([3., 3.]), y: tensor([2.])
6 x: tensor([3., 3.]), y: tensor([2.])
7 x: tensor([3., 3.]), y: tensor([2.])
8 x: tensor([3., 3.]), y: tensor([2.])
9 x: tensor([3., 3.]), y: tensor([2.])


In [15]:
from torchvision import transforms

In [20]:
# Chain the custom transforms; Compose runs them in list order.
data_transform = transforms.Compose([
    add_mult(addx=1, muly=2),
    mult(mul=2),
])

In [21]:
# Apply add_mult(1, 2) and then mult(2) to the raw first sample of `dataset`.
x_, y_ = data_transform(dataset[0])

In [22]:
x_  # (2 + 1) * 2 = 6 in each feature slot

tensor([6., 6.])

In [23]:
y_  # 1 * 2 * 2 = 4

tensor([4.])