In [1]:
import torch 
from torch.utils.data import Dataset 

In [2]:
# Set manual seed 
torch.manual_seed(1)

<torch._C.Generator at 0x1e4374ff4f0>

In [3]:
# Create a simple dataset 

class ToySet(Dataset): 
    # Constructor with default values 
    def __init__(self, length: int = 100, transform = None): 
        self.len = length 
        self.x = 2 * torch.ones(length, 2) 
        self.y = torch.ones(length, 1)
        self.transform = transform 

    # Getter 
    def __getitem__(self, index: int): 
        sample = self.x[index], self.y[index]
        if self.transform: 
            sample = self.transform(sample)
        return sample 
    
    # Get length 
    def __len__(self): 
        return self.len 

In [4]:
our_dataset = ToySet() 
print(f"Value on index 0: {our_dataset[0]}")

Value on index 0: (tensor([2., 2.]), tensor([1.]))


In [11]:
# Create a class for transforming data 

from typing import Any


class AddMult(object): 
    # Constructor 
    def __init__(self, addx = 1, muly = 2): 
        self.addx = addx 
        self.muly = muly 
    
    # Executor 
    def __call__(self, sample):  
        x = sample[0] 
        y = sample[1]  
        x = x + self.addx 
        y = y * self.muly 
        sample = x, y
        return sample 

In [12]:
# Create a transform object 
a_m = AddMult() 
dataset = ToySet() 


In [13]:
# Loop 
for i in range(10): 
    x, y = dataset[i] 
    print(f"Index: {i}, X Orig: {x}, Y Orig: {y}")
    x_, y_ = a_m(dataset[i]) 
    print(f"Index: {i}, Transformed X: {x_}, Transformed Y: {y_}") 

Index: 0, X Orig: tensor([2., 2.]), Y Orig: tensor([1.])
Index: 0, Transformed X: tensor([3., 3.]), Transformed Y: tensor([2.])
Index: 1, X Orig: tensor([2., 2.]), Y Orig: tensor([1.])
Index: 1, Transformed X: tensor([3., 3.]), Transformed Y: tensor([2.])
Index: 2, X Orig: tensor([2., 2.]), Y Orig: tensor([1.])
Index: 2, Transformed X: tensor([3., 3.]), Transformed Y: tensor([2.])
Index: 3, X Orig: tensor([2., 2.]), Y Orig: tensor([1.])
Index: 3, Transformed X: tensor([3., 3.]), Transformed Y: tensor([2.])
Index: 4, X Orig: tensor([2., 2.]), Y Orig: tensor([1.])
Index: 4, Transformed X: tensor([3., 3.]), Transformed Y: tensor([2.])
Index: 5, X Orig: tensor([2., 2.]), Y Orig: tensor([1.])
Index: 5, Transformed X: tensor([3., 3.]), Transformed Y: tensor([2.])
Index: 6, X Orig: tensor([2., 2.]), Y Orig: tensor([1.])
Index: 6, Transformed X: tensor([3., 3.]), Transformed Y: tensor([2.])
Index: 7, X Orig: tensor([2., 2.]), Y Orig: tensor([1.])
Index: 7, Transformed X: tensor([3., 3.]), Tran

In [16]:
# Create a new dataset with AddMult object as transform 
cust_data_set = ToySet(transform=AddMult) 

In [17]:

for i in range(10): 
    x, y = dataset[i] 
    print(f"Index: {i}, Original X: {x}, Original y: {y}") 
    x_, y_ = cust_dataset[i] 
    print(f"Index: {i}, Transformed X: {x_}, Transformed Y: {y_}") 

Index: 0, Original X: tensor([2., 2.]), Original y: tensor([1.])


TypeError: cannot unpack non-iterable AddMult object