In [112]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import math

In [113]:
class WineDataset(Dataset):
    def __init__(self):
        xy = np.loadtxt('./data/wine.csv', delimiter=',', dtype=np.float32, skiprows=1)
        self.x = torch.from_numpy(xy[:, 1:])
        self.y = torch.from_numpy(xy[:, 0])
        self.n_samples = xy.shape[0]
        
#    def __repr__(self):
#            return '{}, {}'.format(self.x, self.y)
        
    def __getitem__(self, index):
        return self.x[index], self.y[index]
    
    def __len__(self):
        return self.n_samples       

In [114]:
dataset = WineDataset()
features, labels = dataset[0]
features, labels

(tensor([1.4230e+01, 1.7100e+00, 2.4300e+00, 1.5600e+01, 1.2700e+02, 2.8000e+00,
         3.0600e+00, 2.8000e-01, 2.2900e+00, 5.6400e+00, 1.0400e+00, 3.9200e+00,
         1.0650e+03]),
 tensor(1.))

In [115]:
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

In [116]:
dataiter = iter(dataloader)
data = next(dataiter)
data

[tensor([[1.2420e+01, 4.4300e+00, 2.7300e+00, 2.6500e+01, 1.0200e+02, 2.2000e+00,
          2.1300e+00, 4.3000e-01, 1.7100e+00, 2.0800e+00, 9.2000e-01, 3.1200e+00,
          3.6500e+02],
         [1.2330e+01, 9.9000e-01, 1.9500e+00, 1.4800e+01, 1.3600e+02, 1.9000e+00,
          1.8500e+00, 3.5000e-01, 2.7600e+00, 3.4000e+00, 1.0600e+00, 2.3100e+00,
          7.5000e+02],
         [1.2810e+01, 2.3100e+00, 2.4000e+00, 2.4000e+01, 9.8000e+01, 1.1500e+00,
          1.0900e+00, 2.7000e-01, 8.3000e-01, 5.7000e+00, 6.6000e-01, 1.3600e+00,
          5.6000e+02],
         [1.3680e+01, 1.8300e+00, 2.3600e+00, 1.7200e+01, 1.0400e+02, 2.4200e+00,
          2.6900e+00, 4.2000e-01, 1.9700e+00, 3.8400e+00, 1.2300e+00, 2.8700e+00,
          9.9000e+02]]),
 tensor([2., 2., 3., 1.])]

In [117]:
num_epochs = 2
total_samples = len(dataset)
n_iterations = math.ceil(total_samples/4)
total_samples, n_iterations, len(list(dataloader))

(178, 45, 45)

In [118]:
for epoch in range(num_epochs):
    for i, (inputs, labels) in enumerate(dataloader):
        if (i+1) % 5 == 0:
            print('{}/{}, step {}/{}, input.shape = {}'.format(epoch+1, num_epochs, i+1,
                                                               n_iterations, inputs.shape[0]))

1/2, step 5/45, input.shape = 4
1/2, step 10/45, input.shape = 4
1/2, step 15/45, input.shape = 4
1/2, step 20/45, input.shape = 4
1/2, step 25/45, input.shape = 4
1/2, step 30/45, input.shape = 4
1/2, step 35/45, input.shape = 4
1/2, step 40/45, input.shape = 4
1/2, step 45/45, input.shape = 2
2/2, step 5/45, input.shape = 4
2/2, step 10/45, input.shape = 4
2/2, step 15/45, input.shape = 4
2/2, step 20/45, input.shape = 4
2/2, step 25/45, input.shape = 4
2/2, step 30/45, input.shape = 4
2/2, step 35/45, input.shape = 4
2/2, step 40/45, input.shape = 4
2/2, step 45/45, input.shape = 2
