In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd

In [2]:
class CSVClassificationDataset(Dataset):
    def __init__(self, csv_file_path, transform=None, target_transform=None):

        self.data = self.preprocess(csv_file_path)
        self.X = self.data.iloc[:,1:].values.astype('float32')
        self.y = self.data.iloc[:,0].values.astype('int64')
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.X[idx]
        label = self.y[idx]

        if self.transform:
            sample = self.transform(sample)
        if self.target_transform:
            label = self.target_transform(label)

        return torch.tensor(sample), torch.tensor(label)

    def preprocess(self, csv_file_path):
        df = pd.read_csv(csv_file_path)
        yes_no_columns = ['mainroad', 'guestroom', 'basement',
                          'hotwaterheating', 'airconditioning', 'prefarea']

        for col in yes_no_columns:
            if col in df.columns:
                df[col] = df[col].map({'yes': 1, 'no': 0})
        df = pd.get_dummies(df, columns=['furnishingstatus'], drop_first=False, dtype = int)
        return df


In [3]:
path = 'Housing.csv'
dataset = CSVClassificationDataset(path)
dataset.__len__()

545

In [4]:
from torch.utils.data import DataLoader
train_dataloader = DataLoader(dataset, batch_size = 16, shuffle = True)

In [5]:
next(iter(train_dataloader))

[tensor([[4.5000e+03, 3.0000e+00, 2.0000e+00, 3.0000e+00, 1.0000e+00, 0.0000e+00,
          0.0000e+00, 1.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00, 1.0000e+00,
          0.0000e+00, 0.0000e+00],
         [5.5000e+03, 3.0000e+00, 2.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00,
          1.0000e+00, 0.0000e+00, 0.0000e+00, 2.0000e+00, 1.0000e+00, 1.0000e+00,
          0.0000e+00, 0.0000e+00],
         [9.8600e+03, 3.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 0.0000e+00,
          0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          1.0000e+00, 0.0000e+00],
         [3.1800e+03, 2.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 0.0000e+00,
          0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          0.0000e+00, 1.0000e+00],
         [3.4600e+03, 4.0000e+00, 1.0000e+00, 2.0000e+00, 1.0000e+00, 0.0000e+00,
          0.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          1.0000e+00, 0.0000e+00],
     