In [1]:
# To build our own dataset class we must import the abstract Dataset class.
import torch
from torch.utils.data import Dataset
from torchvision import transforms

In [2]:
# A minimal custom dataset: a subclass of torch.utils.data.Dataset.
class toy_set(Dataset):
    """Toy dataset of constant tensors.

    Every sample is a pair ``(x, y)`` where ``x`` is ``tensor([2., 2.])``
    and ``y`` is ``tensor([1.])``.

    Args:
        length: Number of samples in the dataset.
        transform: Optional callable applied to each ``(x, y)`` sample
            before it is returned by ``__getitem__``.
    """

    def __init__(self, length=100, transform=None):
        self.x = 2 * torch.ones(length, 2)
        self.y = torch.ones(length, 1)
        self.len = length
        self.transform = transform

    def __getitem__(self, index):
        """Return the ``(x, y)`` pair at ``index``, transformed if a transform is set."""
        sample = self.x[index], self.y[index]
        # Compare against None explicitly: a valid transform object may still
        # be falsy (e.g. one defining __len__ that returns 0) and must not be skipped.
        if self.transform is not None:
            sample = self.transform(sample)
        return sample

    def __len__(self):
        """Return the number of samples."""
        return self.len

In [3]:
# Create a dataset instance with the default 100 samples and no transform.
dataset = toy_set(length=100, transform=None)

In [4]:
# len() calls toy_set.__len__, which returns the number of samples.
print("Number of samples :",len(dataset))

Number of samples : 100


In [5]:
# Indexing calls toy_set.__getitem__, which returns an (x, y) tuple of tensors.
print(dataset[0])

(tensor([2., 2.]), tensor([1.]))


In [6]:
# Transforms modify each sample as it is read from the dataset.
# They are written as callable classes (instead of plain functions)
# so their parameters can be configured once and stored on the instance.
class add_mul:
    """Transform that shifts both parts of an ``(x, y)`` sample.

    Args:
        addx: Value added to ``x``.
        muly: Value added to ``y``. NOTE(review): despite its name, ``y`` is
            incremented by ``muly``, not multiplied; the notebook's recorded
            outputs depend on this, so confirm intent before changing it.

    Calling the instance returns the tuple ``(x + addx, y + muly)``.
    """

    def __init__(self, addx=1, muly=1):
        self.addx, self.muly = addx, muly

    def __call__(self, sample):
        shifted_x = sample[0] + self.addx
        shifted_y = sample[1] + self.muly
        return shifted_x, shifted_y

In [7]:
# Apply a transform directly to a single sample: the transform instance is
# called like a function on an (x, y) tuple and returns a new tuple.
addmul = add_mul()
x_, y_ = addmul(dataset[0])
print(x_,y_)

tensor([3., 3.]) tensor([2.])


In [8]:
# Pass the transform to the constructor so __getitem__ applies it
# automatically every time a sample is retrieved.
dataset_ = toy_set(transform=addmul)
print(dataset_[2])

(tensor([3., 3.]), tensor([2.]))


In [9]:
# Applying multiple transforms: several transforms can be chained
# together with a compose of the transforms.
# Here is a second transform to chain with add_mul.
class mult:
    """Transform that scales both parts of an ``(x, y)`` sample.

    Args:
        mul: Factor that both ``x`` and ``y`` are multiplied by.

    Calling the instance returns the tuple ``(x * mul, y * mul)``.
    """

    def __init__(self, mul=100):
        self.mul = mul

    def __call__(self, sample):
        factor = self.mul
        return sample[0] * factor, sample[1] * factor

In [10]:
# Chain several transforms with transforms.Compose; they are applied in list
# order (add_mul first, then mult).
# NOTE(review): "tranforms" is a typo, kept as-is in case later cells use this name.
dataset_tranforms = transforms.Compose([add_mul(), mult()])

# Then pass the composed transform to the dataset constructor.
dataset_tr = toy_set(transform=dataset_tranforms)

print(dataset_tr[3])

(tensor([300., 300.]), tensor([200.]))
