In [1]:
import  torch
import  os, glob
import  random, csv
from    torch.utils.data import Dataset, DataLoader
from    torchvision import transforms
from    PIL import Image


class Pokemon(Dataset):
    def __init__(self, root, resize, mode):
        super(Pokemon, self).__init__()
        self.root = root
        self.resize = resize

        self.name2label = {} # "sq...":0，创建字典映射关系，每一个宝可梦一旦编码后就不能动了
        for name in sorted(os.listdir(os.path.join(root))):	# listdir()返回的顺序不固定，所以排了个序
            if not os.path.isdir(os.path.join(root, name)):
                continue
            self.name2label[name] = len(self.name2label.keys())	# 当前键元素的数量的长度作为最新元素的label值
        # print(self.name2label)
        # out: {'bulbasaur': 0, 'charmander': 1, ...}

        # 希望拿到数据对：image_path + image_label
        self.images, self.labels = self.load_csv('images.csv')

        if mode=='train': # 60%
            self.images = self.images[:int(0.6*len(self.images))]
            self.labels = self.labels[:int(0.6*len(self.labels))]
        elif mode=='val': # 20% = 60%->80%
            self.images = self.images[int(0.6*len(self.images)):int(0.8*len(self.images))]
            self.labels = self.labels[int(0.6*len(self.labels)):int(0.8*len(self.labels))]
        else: # 20% = 80%->100%
            self.images = self.images[int(0.8*len(self.images)):]
            self.labels = self.labels[int(0.8*len(self.labels)):]

    def load_csv(self, filename):
        if not os.path.exists(os.path.join(self.root, filename)):	# 如果csv文件不存在才要创建
            images = []
            for name in self.name2label.keys():
                # 'pokemon\\mewtwo\\00001.png。label可以从path中推断出来所以没有单独保存
                images += glob.glob(os.path.join(self.root, name, '*.png'))
                images += glob.glob(os.path.join(self.root, name, '*.jpg'))
                images += glob.glob(os.path.join(self.root, name, '*.jpeg'))
            # 1167, ['pokemon\\bulbasaur\\00000000.png', 'pokemon\\bulbasaur\\00000001.png',...]
            print(len(images), images)
			# 把csv保存起来
            random.shuffle(images)
            with open(os.path.join(self.root, filename), mode='w', newline='') as f:
                writer = csv.writer(f)
                for img in images: # 'pokemon\\bulbasaur\\00000000.png'
                    name = img.split(os.sep)[-2]
                    label = self.name2label[name]
                    # 'pokemon\\bulbasaur\\00000000.png', 0
                    writer.writerow([img, label])
                print('writen into csv file:', filename)

        # read from csv file，拿到images和lables信息
        images, labels = [], []
        with open(os.path.join(self.root, filename)) as f:
            reader = csv.reader(f)
            for row in reader:
                # 'pokemon\\bulbasaur\\00000000.png', 0
                img, label = row
                label = int(label)
                images.append(img)
                labels.append(label)

        assert len(images) == len(labels)	# 保证两者长度一样
        return images, labels

    def __len__(self):
        return len(self.images)

    def denormalize(self, x_hat):	# 逆标准化，为的是visdom显示正常，没啥用
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]

        # x_hat = (x-mean)/std
        # x = x_hat*std + mean
        # x: [c, h, w]
        # mean: [3] => [3, 1, 1]
        mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)
        std = torch.tensor(std).unsqueeze(1).unsqueeze(1)
        # print(mean.shape, std.shape)
        x = x_hat * std + mean
        return x

    def __getitem__(self, idx):
        # idx的范围是[0~len(images)]
        # self.images, self.labels
        # img: 'pokemon\\bulbasaur\\00000000.png'
        # label: 0
        img, label = self.images[idx], self.labels[idx]

        tf = transforms.Compose([
            lambda x:Image.open(x).convert('RGB'), # string path= > image data
            transforms.Resize((int(self.resize*1.25), int(self.resize*1.25))),
            transforms.RandomRotation(15),
            transforms.CenterCrop(self.resize),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],	# 统计的imageNet上RGB三通道上的均值和方差
                                 std=[0.229, 0.224, 0.225])
        ])

        img = tf(img)
        label = torch.tensor(label)
        return img, label

In [5]:
root = 'C:\\Users\\XGQ\\Desktop\\Programs\\Python\\Python程序\\Learning\\Deep Learning\\Data\\pokeman'
db = Pokemon(root, 64, 'train')
x, y = next(iter(db))
print('sample:', x.shape, y.shape, y)

loader = DataLoader(db, batch_size=32, shuffle=True, num_workers=8)

import visdom
import time
for x, y in loader:
    viz.images(
            db.denormalize(x), #因为对原始数据归一化，所以可视化需要返回去，该函数需要自己写下。
            nrow=8,  #每行显示8张图
            win='batch',
            opts=dict(title='batch'))
    viz.text(str(y.numpy()), win='label', opts=dict(title='batch-y'))
    time.sleep(10)


1167 ['C:\\Users\\XGQ\\Desktop\\Programs\\Python\\Python程序\\Learning\\Deep Learning\\Data\\pokeman\\bulbasaur\\00000000.png', 'C:\\Users\\XGQ\\Desktop\\Programs\\Python\\Python程序\\Learning\\Deep Learning\\Data\\pokeman\\bulbasaur\\00000001.png', 'C:\\Users\\XGQ\\Desktop\\Programs\\Python\\Python程序\\Learning\\Deep Learning\\Data\\pokeman\\bulbasaur\\00000002.png', 'C:\\Users\\XGQ\\Desktop\\Programs\\Python\\Python程序\\Learning\\Deep Learning\\Data\\pokeman\\bulbasaur\\00000003.png', 'C:\\Users\\XGQ\\Desktop\\Programs\\Python\\Python程序\\Learning\\Deep Learning\\Data\\pokeman\\bulbasaur\\00000004.png', 'C:\\Users\\XGQ\\Desktop\\Programs\\Python\\Python程序\\Learning\\Deep Learning\\Data\\pokeman\\bulbasaur\\00000005.png', 'C:\\Users\\XGQ\\Desktop\\Programs\\Python\\Python程序\\Learning\\Deep Learning\\Data\\pokeman\\bulbasaur\\00000008.png', 'C:\\Users\\XGQ\\Desktop\\Programs\\Python\\Python程序\\Learning\\Deep Learning\\Data\\pokeman\\bulbasaur\\00000009.png', 'C:\\Users\\XGQ\\Desktop\\Programs