# 加载数据

# 1. Dataset

In [9]:
from torch.utils.data import Dataset

In [10]:
from PIL import Image # 一个读取图片的库
import os # 导入关于系统的库

In [11]:
# Dataset?? # 可以查看工具的作用

*根据提示，需要自己定义一个类继承 `Dataset` 类，并重写 `__init__`、`__getitem__`、`__len__` 这几个方法，让自己的类可以读取数据与标签值*

**可以边用控制台调试边写（Python REPL）**

In [12]:
class MyData(Dataset):
    '''
    Dataset over a single class folder whose name doubles as the label.

    Every entry found in ``<root_dir>/<label_dir>`` is treated as one sample,
    and every sample is labelled with the string ``label_dir``.
    '''
    def __init__(self, root_dir, label_dir):
        '''
        Store the folder locations and index the entries of the label folder.

        Args:
            root_dir: Directory that contains the label folder.
            label_dir: Name of the folder to read; also used as the label value.

        Raises:
            OSError: If ``<root_dir>/<label_dir>`` cannot be listed.
        '''
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.path = os.path.join(self.root_dir, self.label_dir)
        # os.listdir() returns entries in arbitrary order; sorting makes a given
        # index refer to the same file on every run and every machine.
        self.image_path_list = sorted(os.listdir(self.path))

    def __getitem__(self, index):
        '''
        Return the ``(image, label)`` pair stored at ``index``.

        Args:
            index: Position in the sorted list of file names.

        Returns:
            A tuple of the object returned by ``Image.open`` for that file and
            the label string (``label_dir``).
        '''
        image_name = self.image_path_list[index]
        image_item_path = os.path.join(self.path, image_name)
        image = Image.open(image_item_path)
        label = self.label_dir
        return image, label

    def __len__(self):
        '''Return the number of entries found in the label folder.'''
        return len(self.image_path_list)


In [13]:
# Build the training-set path with os.path.join so the separator is correct on
# every OS (a literal backslash is only a path separator on Windows).
root_dir = os.path.join("hymenoptera_data", "train")
ants_label_dir = "ants"
bees_label_dir = "bees"
ants_dataset = MyData(root_dir, ants_label_dir)  # samples from the "ants" folder
bees_dataset = MyData(root_dir, bees_label_dir)  # samples from the "bees" folder

In [14]:
train_dataset = ants_dataset + bees_dataset # concatenate the two datasets into a single training set

In [None]:
image, label = train_dataset[123] # fetch the (image, label) pair stored at one index of the combined dataset
image.show()

*以下是用Python REPL的调试过程（逐行运行），注意将self删除再运行*

In [15]:
# from PIL import Image
# image_path = "E:\\Programming\\python\\pytorch\\hymenoptera_data\\train\\ants\\0013035.jpg"
# image = Image.open(image_path)
# image.size
# image.show()
# dir_path = r"hymenoptera_data\train\ants" # 这里r防止\转义
# import os
# image_path_list = os.listdir(dir_path) # 文件夹下所有文件名称作为列表输出
# image_path_list[0]
# root_dir = r"hymenoptera_data\train" # 训练集文件夹
# label_dir = "ants" # 标签名称
# path = os.path.join(root_dir, label_dir) # 将两个路径合成拼接
# image_path_list = os.listdir(path) # 获取所有图片的名字组成列表
# index = 0
# image_name = image_path_list[index]
# image_item_path = os.path.join(path, image_name)
# image = Image.open(image_item_path)
# label = label_dir
# len(image_path_list)

*写完class后，接着运行调试*

In [16]:
# 整段运行
# from torch.utils.data import Dataset
# from PIL import Image # 一个读取图片的库
# import os # 导入关于系统的库
# class MyData(Dataset):
#     '''
#     创建自己的数据集读取类
#     '''
#     def __init__(self, root_dir, label_dir): # 定义类里面的全局变量，观察到labels就是文件夹的名称
#         self.root_dir = root_dir # 根目录，这里是训练集的文件夹路径
#         self.label_dir = label_dir # 数据集所在文件夹名字，这里是标签名
#         self.path = os.path.join(self.root_dir, self.label_dir) # 获取图片所在文件，这里文件名就是标签
#         self.image_path_list = os.listdir(self.path) # 获取所有图片的名字组成列表
    
#     def __getitem__(self, index): # 获取某一个图片，index为索引
#         image_name = self.image_path_list[index]
#         image_item_path = os.path.join(self.path, image_name)
#         image = Image.open(image_item_path)
#         label = self.label_dir
#         return image, label
    
#     def __len__(self): # 返回数据集的长度
#         return len(self.image_path_list)
    
# root_dir = r"hymenoptera_data\train"
# ants_label_dir = "ants"

In [17]:
# 逐行运行
# ants_dataset = MyData(root_dir, ants_label_dir) # 包括所有初始化变量
# ants_dataset[0] # 与ants_dataset.__getitem__(0)作用相同，返回的是image和label
# image, label = ants_dataset[0]
# image.show()
# image, label = ants_dataset[1]
# image.show()

In [18]:
# 整段运行
# root_dir = r"hymenoptera_data\train"
# ants_label_dir = "ants"
# bees_label_dir = "bees"
# ants_dataset = MyData(root_dir, ants_label_dir)
# bees_dataset = MyData(root_dir, bees_label_dir)

In [19]:
# 逐行运行
# image, label = bees_dataset[1]
# image.show()
# train_dataset = ants_dataset + bees_dataset # 两个数据集合并
# len(train_dataset)
# len(ants_dataset)
# len(bees_dataset)
# image, label = train_dataset[123]
# image.show()
# image, label = train_dataset[124]
# image.show()

# 2. DataLoader

In [1]:
import torchvision
from torch.utils.data import DataLoader

In [2]:
# CIFAR-10 with train=False, stored under ./my_cifar10; every image is passed
# through ToTensor() when a sample is read.
test_set = torchvision.datasets.CIFAR10(
    root="./my_cifar10",
    train=False,
    transform=torchvision.transforms.ToTensor(),
    download=True,
)
# Display the first sample as the cell output.
test_set[0]

(tensor([[[0.6196, 0.6235, 0.6471,  ..., 0.5373, 0.4941, 0.4549],
          [0.5961, 0.5922, 0.6235,  ..., 0.5333, 0.4902, 0.4667],
          [0.5922, 0.5922, 0.6196,  ..., 0.5451, 0.5098, 0.4706],
          ...,
          [0.2667, 0.1647, 0.1216,  ..., 0.1490, 0.0510, 0.1569],
          [0.2392, 0.1922, 0.1373,  ..., 0.1020, 0.1137, 0.0784],
          [0.2118, 0.2196, 0.1765,  ..., 0.0941, 0.1333, 0.0824]],
 
         [[0.4392, 0.4353, 0.4549,  ..., 0.3725, 0.3569, 0.3333],
          [0.4392, 0.4314, 0.4471,  ..., 0.3725, 0.3569, 0.3451],
          [0.4314, 0.4275, 0.4353,  ..., 0.3843, 0.3725, 0.3490],
          ...,
          [0.4863, 0.3922, 0.3451,  ..., 0.3804, 0.2510, 0.3333],
          [0.4549, 0.4000, 0.3333,  ..., 0.3216, 0.3216, 0.2510],
          [0.4196, 0.4118, 0.3490,  ..., 0.3020, 0.3294, 0.2627]],
 
         [[0.1922, 0.1843, 0.2000,  ..., 0.1412, 0.1412, 0.1294],
          [0.2000, 0.1569, 0.1765,  ..., 0.1216, 0.1255, 0.1333],
          [0.1843, 0.1294, 0.1412,  ...,

In [3]:
# Shape of the first element (the image tensor) of sample 0.
test_set[0][0].shape

torch.Size([3, 32, 32])

In [4]:
# Small batches of 4, reshuffled on each pass, loaded in the main process
# (num_workers=0), keeping the final partial batch (drop_last=False).
test_loader = DataLoader(
    dataset=test_set,
    batch_size=4,
    shuffle=True,
    num_workers=0,
    drop_last=False,
)

In [6]:
# Show the class of the loader object created above.
type(test_loader)

torch.utils.data.dataloader.DataLoader

In [None]:
# Print the image-batch shape and the targets for the first few batches only.
MAX_BATCHES = 5

# enumerate(start=1) replaces the hand-maintained counter; each batch is
# unpacked into (images, targets) directly in the loop header.
for i, (images, targets) in enumerate(test_loader, start=1):
    print(images.shape)
    print(targets)

    if i == MAX_BATCHES:
        break

torch.Size([4, 3, 32, 32])
tensor([7, 4, 2, 6])
torch.Size([4, 3, 32, 32])
tensor([0, 0, 0, 5])
torch.Size([4, 3, 32, 32])
tensor([5, 9, 2, 2])
torch.Size([4, 3, 32, 32])
tensor([5, 2, 7, 6])
torch.Size([4, 3, 32, 32])
tensor([8, 0, 5, 9])


In [14]:
from torch.utils.tensorboard import SummaryWriter

In [10]:
# Rebuild the loader with batches of 64 for the TensorBoard image grids.
test_loader = DataLoader(
    dataset=test_set,
    batch_size=64,
    shuffle=True,
    num_workers=0,
    drop_last=False,
)

In [18]:
# Log every batch of two passes over test_loader to TensorBoard (log dir "dataloader").
# Using the writer as a context manager closes it even if logging raises part-way.
with SummaryWriter("dataloader") as writer_loader:
    for epoch in range(2):
        # The step count restarts at 0 for every epoch, so step k under "epoch:0"
        # and under "epoch:1" both refer to the k-th batch of that pass and can be
        # compared side by side.
        for step, (images, targets) in enumerate(test_loader):
            writer_loader.add_images("epoch:{}".format(epoch), images, step)

*可以看到 `shuffle=True` 时，每个 epoch 都会重新随机打乱取数据的顺序*