In [2]:
import torch
from torch.utils.data import Dataset

In [4]:
class toy_set(Dataset):
    """Constant-valued toy dataset for demonstrating transforms.

    Each of the ``length`` samples is the pair ``(x, y)``, where ``x`` is a
    row of ``self.x`` (a ``(length, 2)`` tensor filled with 2) and ``y`` is a
    row of ``self.y`` (a ``(length, 1)`` tensor filled with 1).

    Args:
        length: Number of samples held by the dataset.
        transform: Optional callable that receives the ``(x, y)`` tuple;
            whatever it returns is handed back by ``__getitem__``.
    """

    def __init__(self, length=100, transform=None):
        # Features are all 2s, targets are all 1s.
        self.x = torch.ones(length, 2) * 2
        self.y = torch.ones(length, 1)

        self.transform = transform
        self.len = length

    def __getitem__(self, index):
        """Return the (optionally transformed) sample stored at ``index``."""
        pair = (self.x[index], self.y[index])
        # A falsy transform (e.g. None) means the raw pair is returned.
        if not self.transform:
            return pair
        return self.transform(pair)

    def __len__(self):
        """Return the number of samples given at construction time."""
        return self.len

In [5]:
# Build the dataset with its defaults: 100 samples, no transform.
dataset = toy_set()

In [6]:
# Indexing goes through toy_set.__getitem__ and yields an (x, y) tuple.
dataset[0]

(tensor([2., 2.]), tensor([1.]))

In [7]:
# Print the first three samples; every sample is identical because the
# dataset tensors are filled with constants.
for i in range(3):
    x, y = dataset[i]
    print(f"index={i}, x={x}, y={y}")

index=0, x=tensor([2., 2.]), y=tensor([1.])
index=1, x=tensor([2., 2.]), y=tensor([1.])
index=2, x=tensor([2., 2.]), y=tensor([1.])


In [8]:
### transforms ###
class add_mult:
    """Callable transform: add a constant to ``x`` and scale ``y``.

    Args:
        addx: Value added to the first element of the sample.
        muly: Factor the second element of the sample is multiplied by.
    """

    def __init__(self, addx=1, muly=1):
        self.addx, self.muly = addx, muly

    def __call__(self, sample):
        """Return ``(sample[0] + addx, sample[1] * muly)`` as a tuple."""
        shifted = sample[0] + self.addx
        scaled = sample[1] * self.muly
        return shifted, scaled

In [9]:
# Fresh dataset without a transform, used below as the untransformed baseline.
dataset = toy_set()

In [10]:
# Transform with default arguments: adds 1 to x and multiplies y by 1.
a_m = add_mult()

In [12]:
# Apply the transform by hand to one raw sample to see what it does.
x_, y_ = a_m(dataset[0])
print(x_, y_)

tensor([3., 3.]) tensor([1.])


In [13]:
# Second dataset that runs a_m on each sample inside __getitem__.
dataset_ = toy_set(transform=a_m)

In [14]:
# Baseline: the dataset without a transform returns the raw sample.
dataset[0]

(tensor([2., 2.]), tensor([1.]))

In [15]:
# Same index on the transformed dataset: x has been shifted by a_m.
dataset_[0]

(tensor([3., 3.]), tensor([1.]))

In [17]:
### transforms compose ###
class mult:
    """Callable transform: scale both ``x`` and ``y`` by the same factor.

    Args:
        mul: Factor applied to the first two elements of the sample.
    """

    def __init__(self, mul=100):
        self.mul = mul

    def __call__(self, sample):
        """Return ``(sample[0] * mul, sample[1] * mul)`` as a tuple."""
        scaled_x = sample[0] * self.mul
        scaled_y = sample[1] * self.mul
        return scaled_x, scaled_y

In [24]:
import torch

In [25]:
from torchvision import transforms

In [26]:
# Chain the two transforms: add_mult runs first, then mult on its result.
data_transform = transforms.Compose([add_mult(), mult()])

In [27]:
# Apply the composed transform by hand to one raw sample:
# x goes 2 -> 3 -> 300 and y goes 1 -> 1 -> 100.
x_, y_ = data_transform(dataset[0])
print(x_, y_)

tensor([300., 300.]) tensor([100.])


In [28]:
# Dataset that runs the composed transform on each sample inside __getitem__.
data_set_tr = toy_set(transform=data_transform)

In [29]:
# Indexing now returns the sample after both transforms have been applied.
data_set_tr[0]

(tensor([300., 300.]), tensor([100.]))