In [2]:
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
import torchvision.transforms as transforms

# 數據集加載
## Dataset
### 抽象類別, 所有 map-style 數據集都應繼承 Dataset 並重寫 "__getitem__"; "__len__" 為可選, 但多數 Sampler 與 DataLoader 的默認設定都需要它
### "__getitem__": 根據索引(key)獲取對應的數據樣本, 例如 (圖片, label)
### "__len__": 返回數據集的樣本總數
## Dataset相關文檔查閱

In [3]:
# Print the built-in help text (class docstring and method list) for the Dataset base class.
help(Dataset)

Help on class Dataset in module torch.utils.data.dataset:

class Dataset(typing.Generic)
 |  An abstract class representing a :class:`Dataset`.
 |  
 |  All datasets that represent a map from keys to data samples should subclass
 |  it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
 |  data sample for a given key. Subclasses could also optionally overwrite
 |  :meth:`__len__`, which is expected to return the size of the dataset by many
 |  :class:`~torch.utils.data.Sampler` implementations and the default options
 |  of :class:`~torch.utils.data.DataLoader`.
 |  
 |  .. note::
 |    :class:`~torch.utils.data.DataLoader` by default constructs a index
 |    sampler that yields integral indices.  To make it work with a map-style
 |    dataset with non-integral indices/keys, a custom sampler must be provided.
 |  
 |  Method resolution order:
 |      Dataset
 |      typing.Generic
 |      builtins.object
 |  
 |  Methods defined here:
 |  
 |  __add__(self, oth

In [4]:
# IPython-only syntax: a trailing "??" displays the init signature and source code of Dataset.
Dataset??

[1;31mInit signature:[0m [0mDataset[0m[1;33m([0m[1;33m)[0m[1;33m[0m[1;33m[0m[0m
[1;31mSource:[0m        
[1;32mclass[0m [0mDataset[0m[1;33m([0m[0mGeneric[0m[1;33m[[0m[0mT_co[0m[1;33m][0m[1;33m)[0m[1;33m:[0m[1;33m
[0m    [1;34mr"""An abstract class representing a :class:`Dataset`.

    All datasets that represent a map from keys to data samples should subclass
    it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
    data sample for a given key. Subclasses could also optionally overwrite
    :meth:`__len__`, which is expected to return the size of the dataset by many
    :class:`~torch.utils.data.Sampler` implementations and the default options
    of :class:`~torch.utils.data.DataLoader`.

    .. note::
      :class:`~torch.utils.data.DataLoader` by default constructs a index
      sampler that yields integral indices.  To make it work with a map-style
      dataset with non-integral indices/keys, a custom sampler must be

## 讀取自己的數據集

In [4]:
class Animal(Dataset):
    """Map-style image dataset labelled by the integer names of its sub-directories.

    Every sub-directory found while walking ``data_dir`` is treated as one class:
    the directory name is parsed as the integer label (the original author's
    note gives ``0: ants, 1: bees``) and every ``.jpg`` file directly inside it
    becomes one sample.
    """

    def __init__(self, data_dir, transform=None):
        """Index the images under ``data_dir``.

        Args:
            data_dir (str): Root directory of the dataset.
            transform (callable, optional): Preprocessing applied to each PIL
                image in ``__getitem__``. ``None`` (default) returns the PIL
                image unchanged.

        Raises:
            FileNotFoundError: If ``data_dir`` is not an existing directory.
            ValueError: If a sub-directory that contains images does not have
                an integer name.
        """
        # List of (image path, integer label) tuples covering every image.
        self.data_info = self.get_img_info(data_dir)
        self.transform = transform

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair stored at position ``index``."""
        path_img, label = self.data_info[index]
        # The context manager releases the file handle deterministically;
        # convert() returns an independent RGB image that outlives the block.
        with Image.open(path_img) as img_file:
            img = img_file.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, label

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.data_info)

    @staticmethod
    def get_img_info(data_dir):
        """Collect the path and label of every ``.jpg`` image under ``data_dir``.

        Directories and file names are visited in sorted order, so the returned
        list (and therefore ``dataset[i]``) is identical from run to run. The
        extension check is case-insensitive, so ``.JPG`` files are included.

        Args:
            data_dir (str): Root directory of the dataset.

        Returns:
            list[tuple[str, int]]: ``(image path, label)`` for every image.

        Raises:
            FileNotFoundError: If ``data_dir`` is not an existing directory
                (``os.walk`` would otherwise silently yield nothing).
            ValueError: If a sub-directory that contains images does not have
                an integer name.
        """
        if not os.path.isdir(data_dir):
            raise FileNotFoundError(f"dataset directory not found: {data_dir!r}")

        data_info = []
        for root, dirs, _ in os.walk(data_dir):
            # Sorting in place also fixes the order in which os.walk descends.
            dirs.sort()
            for sub_dir in dirs:
                class_dir = os.path.join(root, sub_dir)
                # Keep only .jpg files.
                img_names = sorted(
                    name for name in os.listdir(class_dir)
                    if name.lower().endswith('.jpg')
                )
                if not img_names:
                    # Directories without images never needed a numeric name.
                    continue
                try:
                    label = int(sub_dir)
                except ValueError:
                    raise ValueError(
                        f"class directory {class_dir!r} must be named with an "
                        f"integer label, got {sub_dir!r}"
                    ) from None
                data_info.extend(
                    (os.path.join(class_dir, name), label) for name in img_names
                )
        return data_info

In [9]:
# Preprocessing for the training split: resize to 224x224, then convert to a tensor.
train_transform = transforms.Compose(
    [transforms.Resize((224, 224)), transforms.ToTensor()]
)
# The test split goes through the same two steps.
test_transform = transforms.Compose(
    [transforms.Resize((224, 224)), transforms.ToTensor()]
)

animal_train = Animal('./Data/hymenoptera_data/train', transform=train_transform)
animal_test = Animal('./Data/hymenoptera_data/test', transform=test_transform)

# len() goes through Animal.__len__.
print('len(animal_train): ', len(animal_train))

# Indexing goes through Animal.__getitem__ and yields an (image, label) pair.
img, label = animal_train[1]
print('label: ', label)
print(type(img))

# Optional visualisation (left disabled): turn the tensor back into a PIL image and open it.
# trans = transforms.ToPILImage()
# img = trans(img)
# img.show()

len(animal_train):  28
label:  0
<class 'torch.Tensor'>


## DataLoader
### dataset: 繼承Dataset的數據集
### batch_size: 每個Batch包含的樣本數, 太大可能超出GPU記憶體
### shuffle: 是否在每個epoch打亂數據順序(訓練集建議True)
### drop_last: 樣本總數無法被batch_size整除時, 是否捨去最後一個不完整的Batch
### num_workers: 用於加載數據的子進程數量(0表示在主進程加載), Windows下設置大於0有時候會報錯

In [68]:
# Split the datasets into mini-batches.
BATCH_SIZE = 10  # samples per batch, shared by both loaders

# Both preprocessing pipelines are defined here so that this cell does not rely
# on a `train_transform` left in the kernel by an earlier cell.
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])
test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])
animal_train = Animal('./Data/hymenoptera_data/train', transform=train_transform)
animal_test = Animal('./Data/hymenoptera_data/test', transform=test_transform)
train_loader = DataLoader(dataset=animal_train, batch_size=BATCH_SIZE, shuffle=True, drop_last=False)
# NOTE(review): shuffling an evaluation loader is usually unnecessary; kept as in the original.
test_loader = DataLoader(dataset=animal_test, batch_size=BATCH_SIZE, shuffle=True, drop_last=False)
for epoch in range(1):
    for imgs, targets in test_loader:
        # Up to BATCH_SIZE images of shape 3 x 224 x 224; with drop_last=False
        # the final batch may hold fewer.
        print(imgs.shape)
        print(targets)  # one integer label per image in the batch


torch.Size([10, 3, 224, 224])
tensor([1, 1, 0, 0, 0, 0, 0, 1, 0, 1])
torch.Size([10, 3, 224, 224])
tensor([0, 0, 1, 0, 1, 0, 1, 1, 0, 0])
torch.Size([10, 3, 224, 224])
tensor([1, 1, 1, 1, 0, 1, 1, 1, 1, 1])
torch.Size([10, 3, 224, 224])
tensor([1, 1, 1, 0, 0, 1, 1, 0, 0, 1])
torch.Size([10, 3, 224, 224])
tensor([1, 0, 0, 0, 1, 0, 0, 1, 0, 0])
torch.Size([10, 3, 224, 224])
tensor([1, 0, 0, 0, 1, 1, 1, 0, 1, 1])
torch.Size([10, 3, 224, 224])
tensor([1, 1, 1, 0, 0, 1, 0, 0, 0, 1])
torch.Size([10, 3, 224, 224])
tensor([0, 0, 0, 0, 0, 1, 0, 1, 1, 1])
torch.Size([10, 3, 224, 224])
tensor([1, 1, 0, 1, 0, 1, 1, 1, 0, 1])
torch.Size([10, 3, 224, 224])
tensor([0, 0, 1, 1, 1, 1, 0, 0, 0, 0])
torch.Size([10, 3, 224, 224])
tensor([1, 1, 1, 1, 0, 1, 1, 0, 1, 1])
torch.Size([10, 3, 224, 224])
tensor([0, 0, 1, 1, 0, 1, 1, 1, 0, 0])
torch.Size([10, 3, 224, 224])
tensor([1, 1, 1, 1, 0, 1, 1, 1, 0, 1])
torch.Size([10, 3, 224, 224])
tensor([0, 0, 1, 1, 1, 1, 1, 1, 0, 1])
torch.Size([10, 3, 224, 224])
tens