In [None]:
import torch
import torchvision 
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2l
import sys

# Ask d2l to render figures as SVG (going by the helper's name; it is implemented inside d2l — TODO confirm).
d2l.use_svg_display()

In [None]:
# Transform that turns a PIL image into a 32-bit float tensor with values scaled into the 0-1 range.
trans: transforms.ToTensor = transforms.ToTensor()

# Declare the dataset names first so their annotations survive the tuple unpacking below.
mnist_train: torchvision.datasets.FashionMNIST
mnist_test: torchvision.datasets.FashionMNIST

# The training split is built first and the test split second; both use the same
# root directory, transform and download flag.
mnist_train, mnist_test = (
    torchvision.datasets.FashionMNIST(
        root="../self/data/",
        train=is_train_split,
        transform=trans,
        download=True,
    )
    for is_train_split in (True, False)
)


In [None]:
def get_fashion_mnist_labels(labels: list) -> list:
    '''Translate numeric Fashion-MNIST class ids into their text names.

    Args:
        labels (list): Iterable of class ids; every entry is passed through
            ``int()`` before it is used as an index.

    Returns:
        list: The text name for each entry of ``labels``, in the same order.
    '''
    # The position of a name in this list is its class id.
    class_names: list = [
        "t-shirt", "trouser", "pullover", "dress", "coat",
        "sandal", "shirt", "sneaker", "bag", "ankle boot",
    ]
    names: list = []
    for label in labels:
        names.append(class_names[int(label)])
    return names

def show_images(imgs: list, num_rows: int, num_cols: int, titles: list = None, scale: float = 2) -> None:
    '''Draw images in a grid of subplots and show the figure.

    Args:
        imgs (list): Images to draw; each one is either a tensor or anything
            ``imshow`` accepts directly. Images beyond ``num_rows * num_cols``
            are ignored.
        num_rows (int): Number of rows in the grid.
        num_cols (int): Number of columns in the grid.
        titles (list, optional): One title per image. ``None`` or an empty
            sequence means no titles are drawn. Defaults to None.
        scale (float, optional): Multiplier applied to the column and row
            counts to obtain the figure size. Defaults to 2.
    '''
    figsize: tuple = (num_cols * scale, num_rows * scale)
    # squeeze=False makes subplots always hand back a 2-D array of axes; without
    # it a 1x1 grid yields a bare Axes object, which has no flatten().
    _, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize, squeeze=False)
    # Flatten to 1-D so the axes can be zipped with the images; the layout on screen is unaffected.
    axes = axes.flatten()
    # Explicit None/length check: truth-testing an array or tensor of titles would raise.
    has_titles: bool = titles is not None and len(titles) > 0
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        if torch.is_tensor(img):
            # detach().cpu() lets tensors that track gradients or live on an
            # accelerator be converted; for plain CPU tensors it changes nothing.
            ax.imshow(img.detach().cpu().numpy(), cmap="gray")
        else:
            ax.imshow(img, cmap="gray")
        # Hide both axes so only the picture (and its title) is visible.
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        if has_titles:
            ax.set_title(titles[i])
    d2l.plt.show()

In [None]:
# Take one batch of 18 samples from mnist_train and draw it as a 3 x 6 grid with text labels.
X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))
show_images(
    imgs=X.reshape(18, 28, 28),
    num_rows=3,
    num_cols=6,
    titles=get_fashion_mnist_labels(labels=y),
)

In [None]:
def get_dataloader_workers() -> int:
    '''Return the number of worker processes to use for the data loaders.

    Returns:
        int: The worker count, fixed at 4.
    '''
    worker_count: int = 4
    return worker_count

In [None]:
batch_size: int = 256
# Single definition of the worker counts under test, shared by the timing loop
# and the plot's x-axis so the two cannot drift apart.
worker_counts = range(1, 13)
result: list = []
for i in worker_counts:
    train_iter: data.DataLoader = data.DataLoader(mnist_train, batch_size, shuffle=True, num_workers=i)
    timer = d2l.Timer()  # presumably starts timing on construction, since only stop() is called — TODO confirm

    # Walk the whole loader once purely to measure loading speed; the batches are not used.
    for X, y in train_iter:
        pass
    result.append(timer.stop())

# Label the axes so the figure can be read on its own, and end with show() like show_images does.
d2l.plt.plot(worker_counts, result)
d2l.plt.xlabel("num_workers")
d2l.plt.ylabel("time for one pass (s)")
d2l.plt.show()




In [None]:
def load_data_fashion_mnist(
    batch_size: int,
    resize: tuple = None,
    root: str = "../self/data/",
) -> "tuple[data.DataLoader, data.DataLoader]":
    '''Download Fashion-MNIST and wrap both splits in data loaders.

    The return annotation is a string on purpose: the previous form,
    ``tuple(data.DataLoader)``, was a call evaluated when the ``def`` ran and
    raised ``TypeError`` before the function could even be defined.

    Args:
        batch_size (int): Number of samples per batch for both loaders.
        resize (tuple, optional): If truthy, passed to ``transforms.Resize``,
            which is applied before ``ToTensor``. Defaults to None (no resize).
        root (str, optional): Directory the dataset is stored in / downloaded
            to. Defaults to "../self/data/".

    Returns:
        tuple: ``(train_loader, test_loader)``. The training loader shuffles,
        the test loader does not; both use ``get_dataloader_workers()`` workers.
    '''
    # Keep the step list and the composed transform under separate names
    # instead of rebinding one variable to two different types.
    transform_steps: list = [transforms.ToTensor()]
    if resize:
        # Inserted at the front so resizing happens before the tensor conversion.
        transform_steps.insert(0, transforms.Resize(resize))
    transform = transforms.Compose(transform_steps)

    mnist_train: torchvision.datasets.FashionMNIST = torchvision.datasets.FashionMNIST(
        root=root,
        train=True,
        transform=transform,
        download=True
    )
    mnist_test: torchvision.datasets.FashionMNIST = torchvision.datasets.FashionMNIST(
        root=root,
        train=False,
        transform=transform,
        download=True
    )

    num_workers: int = get_dataloader_workers()
    return (
        data.DataLoader(mnist_train, batch_size, shuffle=True, num_workers=num_workers),
        data.DataLoader(mnist_test, batch_size, shuffle=False, num_workers=num_workers)
    )