# TensorFlow Dataset

如果图片文件数量很大，将其全部写入 npy 文件并在训练时一次性读入内存，可能会导致内存不足

一种解决方法是将图片路径、标签写入文件，训练时再读取。这就需要自定义数据集

自定义数据集需要继承 `tensorflow.python.keras.utils.data_utils.Sequence` 类，并实现 `__len__` 和 `__getitem__` 方法

假设 csv 文件中每行的内容为：`图片绝对路径;标签`

In [None]:
import cv2
import numpy as np
from tensorflow.keras.utils import to_categorical
from tensorflow.python.keras.utils.data_utils import Sequence

class DataLoader(Sequence):
    """Keras ``Sequence`` that loads image batches listed in a CSV file.

    Every non-empty line of ``dataset/{type}.csv`` is expected to have the
    form ``<image path>;<integer label>``. Only the text lines are kept in
    memory; the images themselves are read from disk one batch at a time.
    """

    def __init__(self, type='train', batch_size=32, num_classes=8):
        """Read the sample list for one dataset split.

        Args:
            type: Name of the split; selects the file ``dataset/{type}.csv``.
                (The name shadows the builtin but is kept for backward
                compatibility with existing callers.)
            batch_size: Number of samples per batch; must be positive.
            num_classes: Number of classes used for one-hot encoding.

        Raises:
            ValueError: If ``batch_size`` is not positive.
            OSError: If the CSV file cannot be opened.
        """
        super().__init__()
        if batch_size <= 0:
            raise ValueError(f'batch_size must be positive, got {batch_size}')
        with open(f'dataset/{type}.csv') as f:
            stripped = (line.strip() for line in f)
            # Skip blank lines (e.g. a trailing newline) so they do not turn
            # into unparsable samples later on.
            self.lines = [line for line in stripped if line]
        self.batch_size = batch_size
        self.num_classes = num_classes

    def __len__(self):
        """Return the number of batches; the last one may be partial."""
        return int(np.ceil(len(self.lines) / self.batch_size))

    def __getitem__(self, idx):
        """Return the ``idx``-th batch as an ``(images, labels)`` pair."""
        # Note: one item of this sequence is a whole batch, not one sample.
        start = idx * self.batch_size
        batch = self.lines[start:start + self.batch_size]
        return self.preprocess(batch)

    def preprocess(self, batch_data):
        """Load the images and one-hot encode the labels of a batch.

        Args:
            batch_data: Iterable of ``<image path>;<integer label>`` strings.

        Returns:
            A tuple ``(images, labels)``: ``images`` is ``np.array`` applied
            to the list of decoded images (assumes all images share one
            shape so they stack into a single array -- TODO confirm), and
            ``labels`` is the result of ``to_categorical`` with
            ``num_classes`` columns and integer dtype.

        Raises:
            OSError: If an image cannot be read by ``cv2.imread``.
            IndexError: If a line has no ``;`` separator.
            ValueError: If a label is not an integer.
        """
        images = []
        labels = []
        for data in batch_data:
            fields = data.split(';')
            path, label = fields[0], fields[1]
            image = cv2.imread(path)
            # cv2.imread signals failure by returning None instead of
            # raising; fail loudly here rather than corrupting the batch.
            if image is None:
                raise OSError(f'cannot read image: {path}')
            images.append(image)
            labels.append(int(label))
        labels = to_categorical(labels, self.num_classes, dtype=int)
        return np.array(images), labels