In [29]:
import os
import sys
import glob
import shutil
import json
import cv2
import numpy as np
from PIL import Image
import torch
from torch.utils.data.dataset import Dataset
import torchvision.transforms as transforms

In [33]:
class SVHNDataset(Dataset):
    """Map-style dataset yielding ``(image, label)`` pairs for digit-sequence images.

    Each sample is an image file opened with PIL and converted to RGB, paired
    with a fixed-length integer label tensor.

    Args:
        img_path: Indexable collection of image file paths.
        img_label: Indexable collection of labels, positionally aligned with
            ``img_path``. Each entry is a sequence of integers (a bare integer
            is treated as a one-element sequence).
        transform: Optional callable applied to the RGB PIL image; its return
            value is what ``__getitem__`` yields as the image.
        max_len: Length every label is truncated / padded to (default 5).
        pad_value: Value used to fill missing positions (default 10).
            NOTE(review): presumably the "no digit" class index — confirm
            against the model's output layer.
    """

    def __init__(self, img_path, img_label, transform=None, max_len=5, pad_value=10):
        self.img_path = img_path
        self.img_label = img_label
        # Store the argument as-is. The previous code tested the imported
        # `transforms` *module* instead of this parameter, so the check was
        # always true and raised NameError when torchvision was not imported.
        self.transform = transform
        self.max_len = max_len
        self.pad_value = pad_value

    def __getitem__(self, index):
        """Return ``(image, label)`` for the sample at ``index``.

        The label is a 1-D ``torch.int64`` tensor of length ``max_len``:
        the first ``max_len`` label values followed by ``pad_value`` padding.
        """
        # Close the underlying file as soon as the pixels are converted;
        # `convert` returns an independent image that outlives the handle.
        with Image.open(self.img_path[index]) as raw:
            img = raw.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)

        # atleast_1d lets a scalar label behave like a single-element sequence.
        digits = np.atleast_1d(np.asarray(self.img_label[index], dtype=np.int64))
        digits = digits[:self.max_len]

        # Fixed dtype: previously the dtype depended on whether padding was
        # appended (NumPy int32 values mixed with Python ints), so samples of
        # different lengths could come back with different dtypes.
        lbl = np.full(self.max_len, self.pad_value, dtype=np.int64)
        lbl[:len(digits)] = digits
        return img, torch.from_numpy(lbl)

    def __len__(self):
        """Number of samples, i.e. the number of image paths."""
        return len(self.img_path)

In [30]:
# Collect the training image files in lexicographic order so the listing is
# deterministic across runs and platforms.
train_path = sorted(glob.glob('./data/mchar_train/*.png'))

# Read the annotation file inside a context manager so the handle is closed
# right away instead of being left open until garbage collection.
with open('./data/mchar_train.json') as label_file:
    train_json = json.load(label_file)

# One label per annotation record, in the order the keys appear in the JSON.
# NOTE(review): images and labels are paired purely by position, which assumes
# the JSON key order matches the sorted file listing — confirm for this dataset.
train_label = [train_json[key]['label'] for key in train_json]

# Positional pairing is only meaningful when both lists have the same length.
assert len(train_path) == len(train_label), (
    f"found {len(train_path)} images but {len(train_label)} labels"
)

In [34]:
# Batched loader over the training set. Every image goes through the same
# pipeline: resize -> colour jitter -> small random rotation -> tensor ->
# per-channel normalisation.
train_loader = torch.utils.data.DataLoader(
    dataset=SVHNDataset(
        img_path=train_path,
        img_label=train_label,
        transform=transforms.Compose([
            transforms.Resize(size=(64, 128)),
            transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2),
            transforms.RandomRotation(degrees=5),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ]),
    ),
    batch_size=10,  # number of samples per batch
    shuffle=False,  # whether to shuffle the sample order
    # num_workers=10,  # number of loader workers (currently disabled)
)