# In [17]:
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
import numpy as np
import math


class WineDataset(Dataset):
    """Map-style dataset over a comma-separated wine data file.

    The first (header) row of the file is skipped. Column 0 of every
    remaining row is used as the target and all other columns as the input
    features; values are loaded as ``float32``.

    Args:
        transform: Optional callable applied to each ``(features, target)``
            pair produced by ``__getitem__``. With ``None`` the raw NumPy
            rows are returned.
        path: Location of the data file. Defaults to ``'wine.csv'`` relative
            to the current working directory.
    """

    def __init__(self, transform=None, path='wine.csv'):
        # ndmin=2 keeps the array two-dimensional even when the file holds a
        # single data row; otherwise loadtxt returns 1-D and the column
        # slicing below fails.
        xy = np.loadtxt(path, delimiter=',', dtype=np.float32, skiprows=1, ndmin=2)
        self.n_samples = xy.shape[0]

        # Features are every column after the first. The target is sliced
        # with a list ([0]) so each label stays a length-1 array rather than
        # collapsing to a scalar.
        self.x = xy[:, 1:]
        self.y = xy[:, [0]]

        self.transform = transform

    def __getitem__(self, index):
        """Return the ``(features, target)`` pair at *index*, transformed if configured."""
        sample = self.x[index], self.y[index]

        # Explicit None check: a transform object must not be skipped merely
        # because it happens to evaluate as falsy.
        if self.transform is not None:
            sample = self.transform(sample)

        return sample

    def __len__(self):
        """Return the number of data rows loaded from the file."""
        return self.n_samples


class ToTensor:
    """Turn a ``(features, target)`` pair of NumPy arrays into torch tensors.

    Conversion uses ``torch.from_numpy``, so each returned tensor shares
    memory with the array it was built from instead of copying it.
    """

    def __call__(self, sample):
        features, target = sample
        features_tensor = torch.from_numpy(features)
        target_tensor = torch.from_numpy(target)
        return features_tensor, target_tensor
    
    
class MulTransform:
    """Scale the features of a ``(features, target)`` pair by a constant.

    Works for any features object supporting ``*`` with ``factor`` (e.g.
    NumPy arrays or torch tensors). The target is passed through unchanged.

    Args:
        factor: Multiplier applied to the features.
    """

    def __init__(self, factor):
        self.factor = factor

    def __call__(self, sample):
        inputs, targets = sample
        # Multiply out of place. The features handed in are typically a view
        # into the dataset's backing array (NumPy row indexing returns a
        # view, and torch.from_numpy shares that memory), so an in-place
        # `*=` would rescale the stored data again on every access.
        inputs = inputs * self.factor

        return inputs, targets
            
        


# 1. ToTensor: both elements of the sample come back as torch tensors.
dataset = WineDataset(transform=ToTensor())
features, labels = first_data = dataset[0]
print(type(features), type(labels))

# 2. No transform: the raw NumPy rows are returned.
dataset = WineDataset(transform=None)
features, labels = first_data = dataset[0]
print(type(features), type(labels))

# 3. MulTransform alone: features are doubled but remain NumPy arrays.
dataset = WineDataset(transform=MulTransform(2))
features, labels = first_data = dataset[0]
print(features, labels)

# Two blank lines as a visual separator.
print("\n")

# 4. ToTensor followed by MulTransform, chained with torchvision's Compose.
composed = torchvision.transforms.Compose([ToTensor(), MulTransform(2)])
dataset = WineDataset(transform=composed)
features, labels = first_data = dataset[0]
print(features)
print(type(features), type(labels))


# Output:
# <class 'torch.Tensor'> <class 'torch.Tensor'>
# <class 'numpy.ndarray'> <class 'numpy.ndarray'>
# [2.846e+01 3.420e+00 4.860e+00 3.120e+01 2.540e+02 5.600e+00 6.120e+00
#  5.600e-01 4.580e+00 1.128e+01 2.080e+00 7.840e+00 2.130e+03] [1.]
#
#
# tensor([2.8460e+01, 3.4200e+00, 4.8600e+00, 3.1200e+01, 2.5400e+02, 5.6000e+00,
#         6.1200e+00, 5.6000e-01, 4.5800e+00, 1.1280e+01, 2.0800e+00, 7.8400e+00,
#         2.1300e+03])
# <class 'torch.Tensor'> <class 'torch.Tensor'>
