In [123]:
import torch
from torch.utils.data import Dataset, DataLoader, random_split

class LinearRegressionDataset(Dataset):
  def __init__(self, N=50, m=-3, b=2, *args, **kwargs):
    # N: number of samples, e.g. 50
    # m: slope
    # b: offset
    super().__init__(*args, **kwargs)

    self.x = torch.rand(N, 2)  # uniform distribution (0~1)
    self.noise = torch.rand(N) * 0.2  # 0 ~ 0.2
    self.m = m
    self.b = b
    self.y = (torch.sum(self.x * self.m) + self.b + self.noise).unsqueeze(-1) # [50, 1]

  def __len__(self): # 필수 method
    return len(self.x)

  def __getitem__(self, idx): # 필수 method : 주어진 idx에 해당하는 x, y return
    return self.x[idx], self.y[idx] # self.x[idx] : size[N,2]에서 2개의 input 값, self.y[idx] : size[N,1]에서 1개의 target 값

  def __str__(self):
    str = "Data Size: {0}, Input Shape: {1}, Target Shape: {2}".format(
      len(self.x), self.x.shape, self.y.shape
    )
    return str

In [124]:
if __name__ == "__main__":
  linear_regression_dataset = LinearRegressionDataset()

  print(linear_regression_dataset) # __str__ method 호출

  print("#" * 50, 1)

  for idx, sample in enumerate(linear_regression_dataset): # __getitem__ method 호출해, sample에 (self.x[idx], self.y[idx])
    input, target = sample
    print("{0} - {1}: {2}".format(idx, input, target))

  train_dataset, validation_dataset, test_dataset = random_split(linear_regression_dataset, [0.7, 0.2, 0.1])

  print("#" * 50, 2)

  print(len(train_dataset), len(validation_dataset), len(test_dataset))

  print("#" * 50, 3)

  train_data_loader = DataLoader(
    dataset=train_dataset,
    batch_size=4,
    shuffle=True
  )

  for idx, batch in enumerate(train_data_loader):
    input, target = batch
    print("{0} - {1}: {2}".format(idx, input, target))


Data Size: 50, Input Shape: torch.Size([50, 2]), Target Shape: torch.Size([50, 1])
################################################## 1
0 - tensor([0.5042, 0.7405]): tensor([-145.8085])
1 - tensor([0.6344, 0.5684]): tensor([-145.8837])
2 - tensor([0.5913, 0.7975]): tensor([-145.9206])
3 - tensor([0.6232, 0.8485]): tensor([-145.8979])
4 - tensor([0.0567, 0.2762]): tensor([-145.7822])
5 - tensor([0.3051, 0.3364]): tensor([-145.8784])
6 - tensor([0.9858, 0.6632]): tensor([-145.9324])
7 - tensor([0.8453, 0.1840]): tensor([-145.8899])
8 - tensor([0.5854, 0.6962]): tensor([-145.9429])
9 - tensor([0.5052, 0.5316]): tensor([-145.9304])
10 - tensor([0.1812, 0.6073]): tensor([-145.9492])
11 - tensor([0.1826, 0.7992]): tensor([-145.8569])
12 - tensor([0.9503, 0.8380]): tensor([-145.7615])
13 - tensor([0.4230, 0.6522]): tensor([-145.9547])
14 - tensor([0.5197, 0.1964]): tensor([-145.8913])
15 - tensor([0.1808, 0.8157]): tensor([-145.9089])
16 - tensor([0.3629, 0.5258]): tensor([-145.7888])
17 - te