In [1]:
%matplotlib inline

数据迭代器 - Data Generator
====
>Python2.7 + Pytorch 1.2.0 backened
>
>MNIST

数据迭代器主要是解决内存/显存空间不足的问题，但会使得程序的运行时间变慢。<br>
总体上来说，程序的运行时间取决于运算器能力，迭代器只是一种折中方式。<br>
数据迭代器并不能保证内存/显存空间就一直会稳定在一个水平，会非常缓慢的上升。（这是由程序运行造成的，不适用迭代器也会造成次现象。建议留出一定的空间）

In [2]:
# -*- coding: utf-8 -*-
# !/usr/bin/env python
'''
@author: Yang
@time: 18-2-24 下午2:08
'''

from __future__ import print_function

import torch
import torch.utils.data as data
from torch.utils.data import Dataset
import numpy as np
import os

首先利用numpy生成假数据

In [3]:
def fake_data_generator(num=100):
    # define make direction func
    def mkdir(name=None):
        if os.path.exists(name):
            pass
        else:
            os.mkdir(name)

    # generate images
    mkdir(name='images')
    for i in xrange(num):
        new_array = np.zeros(shape=(10, 10), dtype=np.float32) * num
        np.save(file='images/%s.npy' % i, arr=new_array)
    # generate text
    mkdir(name='text')
    for i in xrange(num):
        new_text = str(i)
        with open('text/%s.txt' % i, mode='wb') as text_buffer:
            text_buffer.write(new_text)


fake_data_generator()

创建一个数据迭代器类，继承于torch.utils.data.Dataset

In [4]:
# this function is to extract files from the path
def getSubfiles(path):
    return sorted([os.path.join(path, subdir) for subdir in os.listdir(path)])


class DataGenerator(Dataset):
    def __init__(self, img_dir, text_dir, func=getSubfiles):
        self.img_list = func(img_dir)
        self.text_list = func(text_dir)

        assert len(self.img_list) == len(self.text_list)
        self.length = len(self.img_list)

    def __getitem__(self, index):
        '''
        this function can only return a pair of data
        :param index:
        :return: (image, text)
        '''
        # read an image
        img = np.load(self.img_list[index])
        # read a txt
        with open(self.text_list[index], 'rb') as text_buffer:
            text = text_buffer.read()
        return (img, text)

    def __len__(self):
        # return self.length
        return len(self.img_list)

DataGenerator的__getitem__函数，一次只能生成一组(img, text)数据。需要利用额外的collate_fn将多组(img, text)变成(img), (text)。

In [5]:
def collate_fn(batch_data):
#     batch_data.sort(key=lambda x: int(x[-1]), reverse=True)
    img, labels = zip(*batch_data)
    return img, labels

In [6]:
dataset = DataGenerator(img_dir='images', text_dir='text')
batch_size = 10

data_loader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn)

对迭代器进行一次遍历，该过程中的数据不会重复。<br>
可以无限循环下去

In [7]:
for _ in xrange(len(data_loader)):
    img, label = data_loader.__iter__().__next__()
    print(label)

('56', '64', '80', '38', '2', '93', '27', '95', '73', '40')
('30', '6', '14', '44', '0', '61', '53', '37', '90', '15')
('68', '1', '50', '24', '16', '6', '30', '7', '33', '28')
('60', '90', '53', '46', '38', '19', '1', '11', '95', '67')
('50', '85', '47', '44', '62', '30', '63', '8', '93', '89')
('12', '27', '34', '35', '53', '66', '17', '46', '64', '93')
('84', '89', '38', '87', '99', '21', '78', '98', '15', '26')
('1', '94', '4', '50', '41', '21', '85', '74', '83', '25')
('78', '44', '79', '41', '16', '87', '31', '67', '69', '22')
('35', '93', '2', '10', '87', '42', '76', '66', '34', '45')
