# PyTorch Dataset

若数据和标签以 `numpy.ndarray` 保存，可以先用 `torch.from_numpy` 将其转换为张量，再用 `torch.utils.data.TensorDataset` 封装成数据集。

In [None]:
import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader

# Placeholder data: 6000 arrays of shape (28, 28, 1) and one label per array.
x_train = np.zeros((6000, 28, 28, 1))
y_train = np.zeros((6000, ))

# np.zeros defaults to float64 and torch.from_numpy keeps the NumPy dtype.
# Cast the features to float32 (the default dtype of nn.Module parameters)
# and the labels to int64 (the dtype of class-index targets expected by
# losses such as nn.CrossEntropyLoss) so the dataset can be fed to a
# standard model without dtype-mismatch errors.
train_set = TensorDataset(torch.from_numpy(x_train).float(),
                          torch.from_numpy(y_train).long())

自定义数据集需要继承 `torch.utils.data.Dataset` 类，并实现 `__len__` 和 `__getitem__` 方法

此处假设使用 csv 文件保存数据集的信息：csv 文件的第一行是表头，其余每行是一个二元组 `(图片路径, 标签)`，其中图片路径相对于 csv 文件所在的目录

In [None]:
import cv2
import os
from torch.utils.data import Dataset
from torchvision.transforms import ToTensor

class MyDataset(Dataset):
    """Image dataset indexed by a CSV file of ``path,label`` rows.

    The first line of the CSV is treated as a header and skipped. In every
    other row the first comma-separated field is the image path and the last
    field is the integer label. Image paths are joined onto the directory
    that contains the CSV file.

    Args:
        csv: Path to the CSV index file.
    """

    def __init__(self, csv='train.csv'):
        super().__init__()
        # Directory of the CSV file; image paths are resolved against it.
        self.base = os.path.dirname(csv)
        self.paths, self.labels = self._read_index(csv)
        # Transform applied to every loaded image; ToTensor is just an example.
        self.transforms = ToTensor()

    @staticmethod
    def _read_index(csv):
        """Return ``(paths, labels)`` parsed from the CSV file at ``csv``."""
        paths, labels = [], []
        with open(csv) as f:
            # Skip the header line.
            f.readline()
            for line in f:
                line = line.strip()
                # Blank lines (e.g. an extra empty line at the end of the
                # file) carry no sample and would otherwise crash int().
                if not line:
                    continue
                fields = line.split(',')
                paths.append(fields[0])
                labels.append(int(fields[-1]))
        return paths, labels

    def __len__(self):
        """Return the number of samples listed in the CSV file."""
        return len(self.paths)

    def __getitem__(self, item):
        """Load the image at index ``item`` and return ``(image, label)``.

        Raises:
            OSError: If OpenCV cannot read the image file.
        """
        label = self.labels[item]
        full_path = os.path.join(self.base, self.paths[item])
        # NOTE: cv2.imread returns the image in BGR channel order.
        img = cv2.imread(full_path)
        # cv2.imread signals failure by returning None instead of raising.
        if img is None:
            raise OSError(f'failed to read image: {full_path}')
        img = self.transforms(img)
        return img, label
    

数据集定义完毕后需要使用 `DataLoader` 进行读取

In [None]:
# Loading parameters: num_workers is the number of worker subprocesses
# (not threads) used to load data; batch_size is the samples per batch.
num_workers = 4
batch_size = 64

train_loader = DataLoader(
    dataset=train_set,
    batch_size=batch_size,
    shuffle=True,  # reshuffle the samples at every epoch
    num_workers=num_workers,
)