In [1]:
# Read an image-classification dataset (FashionMNIST)
%matplotlib inline
import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2l

# set the display format as svg
# d2l.use_svg_display()

# ToTensor converts each PIL image to a float32 tensor of shape
# (channels, height, width) and rescales pixel values from 0-255 to 0.0-1.0.
# FashionMNIST images are grayscale, so there is a single channel.
trans = transforms.ToTensor()

# Download FashionMNIST (if not already on disk) and apply the transform above.
# torchvision's FashionMNIST download fetches all four archives (train AND test)
# into `root` regardless of the `train` flag, so both splits share one root;
# separate roots would download and store the full dataset twice.
DATA_ROOT = './data/SoftMaxRegression/'
mnist_train = torchvision.datasets.FashionMNIST(root=DATA_ROOT, train=True, transform=trans, download=True)

mnist_test = torchvision.datasets.FashionMNIST(root=DATA_ROOT, train=False, transform=trans, download=True)

len(mnist_train), len(mnist_test)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Failed to download (trying next):
<urlopen error [Errno 8] nodename nor servname provided, or not known>



RuntimeError: Error downloading train-images-idx3-ubyte.gz

In [None]:
# Display the shape of the first training image as (channels, height, width).
# FashionMNIST images are grayscale, so the channel dimension has size 1.
mnist_train[0][0].size()

In [None]:
def get_fashion_mnist_labels(labels):
    """Translate numeric FashionMNIST class indices into their text names.

    Args:
        labels: iterable of class indices; each element is passed through
            ``int()``, so plain ints and elements of a label tensor both work.

    Returns:
        A list with one text label per input index, in the same order.
    """
    names = ['t-shirt', 'trousers', 'pullover', 'dress', 'coat',
             'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    result = []
    for label in labels:
        result.append(names[int(label)])
    return result

def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):
    """Plot a list of images on a num_rows x num_cols grid.

    Args:
        imgs: iterable of images. Torch tensors are converted to NumPy arrays
            before plotting; anything else is passed to ``imshow`` unchanged.
        num_rows: number of rows in the grid.
        num_cols: number of columns in the grid.
        titles: optional sequence of per-image titles, in the same order as
            ``imgs``.
        scale: size of each grid cell in inches; the figure is
            ``num_cols * scale`` wide and ``num_rows * scale`` tall.

    Only the first ``num_rows * num_cols`` images are drawn; extras are ignored.
    """
    figsize = (num_cols * scale, num_rows * scale)

    # squeeze=False makes subplots always return a 2-D array of Axes, so
    # flatten() also works for a 1x1 grid (which would otherwise return a
    # bare Axes object without a flatten() method).
    _, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize, squeeze=False)
    axes = axes.flatten()

    for i, (ax, img) in enumerate(zip(axes, imgs)):
        if torch.is_tensor(img):
            # detach()/cpu() so tensors that require grad or live on a GPU
            # can still be converted to NumPy for plotting.
            img = img.detach().cpu().numpy()
        ax.imshow(img)

        # Hide the x/y axes (ticks and tick labels) so only the image shows.
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)

        if titles:
            ax.set_title(titles[i])

# Fetch one small batch of training images and display them on a grid.
# The batch size is derived from the grid so every cell gets exactly one image.
NUM_ROWS, NUM_COLS = 3, 6
n_show = NUM_ROWS * NUM_COLS
X, y = next(iter(data.DataLoader(mnist_train, batch_size=n_show)))
# Each image is (1, 28, 28); drop the channel dim so imshow gets 2-D arrays.
show_images(X.reshape(n_show, 28, 28), NUM_ROWS, NUM_COLS, titles=get_fashion_mnist_labels(y))

             

In [None]:
# Number of samples per minibatch for the training loader below.
batch_size: int = 256

def get_dataloader_workers():
    """Return how many worker processes each DataLoader uses to read data.

    DataLoader loads batches in ``num_workers`` subprocesses; this notebook
    uses a fixed value.
    """
    num_workers = 4
    return num_workers

# Build an iterator over shuffled training minibatches.
train_iter = data.DataLoader(
    mnist_train,
    batch_size=batch_size,
    shuffle=True,
    num_workers=get_dataloader_workers(),
)

# Time one full pass over the training set (data loading only, no compute).
timer = d2l.Timer()
for X, y in train_iter:
    pass
f'{timer.stop():.2f} sec'

In [None]:
# Integrate the components above into a single reusable function.
def load_data_fashion_mnist(batch_size, resize=None, root="../data"):  #@save
    """Download FashionMNIST and return (train_iter, test_iter) DataLoaders.

    Args:
        batch_size: number of samples per minibatch for both loaders.
        resize: optional size passed to ``transforms.Resize``, applied before
            ``ToTensor``; falsy values (None, 0) skip resizing.
        root: directory the dataset is downloaded to (if missing) and read
            from. Both splits share it, because torchvision downloads all four
            FashionMNIST archives into the same root.

    Returns:
        A tuple ``(train_iter, test_iter)``; the training loader shuffles,
        the test loader does not.
    """
    steps = [transforms.ToTensor()]
    if resize:
        # Resize operates on the PIL image, so it must run before ToTensor.
        steps.insert(0, transforms.Resize(resize))
    transform = transforms.Compose(steps)
    mnist_train = torchvision.datasets.FashionMNIST(
        root=root, train=True, transform=transform, download=True)
    mnist_test = torchvision.datasets.FashionMNIST(
        root=root, train=False, transform=transform, download=True)
    num_workers = get_dataloader_workers()
    return (data.DataLoader(mnist_train, batch_size, shuffle=True,
                            num_workers=num_workers),
            data.DataLoader(mnist_test, batch_size, shuffle=False,
                            num_workers=num_workers))