In [1]:
import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from IPython import display
from d2l import torch as d2l
from d2l.torch import Accumulator

d2l.use_svg_display()

In [2]:
def get_dataloader_workers():
    return 4

In [3]:
def load_data_fashion_mnist(batch_size, resize = None):
    trans = [transforms.ToTensor()]
    if resize:
        trans.insert(0, transforms.Resize(resize))
    trans = transforms.Compose(trans)

    mnist_train = torchvision.datasets.FashionMNIST(root = './data', train = True, transform = trans, download = True)
    mnist_test = torchvision.datasets.FashionMNIST(root = './data', train = False, transform = trans, download = True)

    return (
        data.DataLoader(mnist_train, batch_size = batch_size, shuffle = True, num_workers = get_dataloader_workers()),
        data.DataLoader(mnist_test, batch_size = batch_size, shuffle = True, num_workers = get_dataloader_workers())
    )

In [4]:
batch_size = 256
train_iter, test_iter = load_data_fashion_mnist(batch_size)

In [5]:
num_inputs = 784
num_outputs = 10

In [6]:
w = torch.normal(0, 0.01, size = (num_inputs, num_outputs), requires_grad = True)
b = torch.zeros(size = (1, num_outputs), requires_grad = True)

In [7]:
def softmax(x):
    x_exp = torch.exp(x)
    partition = x_exp.sum(dim = 1, keepdim = True) #keepdim是为了广播机制
    return x_exp / partition #这里利用了广播机制

In [8]:
def net(x):
    z = torch.mm(x.reshape(-1, w.shape[0]), w) + b
    #-1表示架构帮你计算,实际上这边是批量大小,x实际是一个3维矩阵,每层代表一张图,reshape将x变成一个每行是一张图的矩阵,行数就是批量大小了
    return softmax(z)

In [9]:
def cross_entropy(y_hat, y):
    return -torch.log(y_hat[range(0, len(y_hat)), y])
    #这边log里的y_hat可以这么理解：y_hat是一个矩阵,相当于一个二维数组,里面的参数就是2个
    #两个参数分别代表行列,由于每一行为一个样本,里面是每个标签的概率,而y又代表正确的标签,因此可以表示为正确类别上的所有样本的概率
    #举个例子,两个参数分别为[0, 1] [0, 2]那就相当于取了y_hat[0, 0] y_hat[1, 2]两个参数之间一一对应,然后组成一个向量
    #这里就相当于在所有行上取正确类别的标签

In [10]:
def accuracy(y_hat, y):
    if len(y_hat.shape) > 1 and y_hat.shape[1] > 1: #shape[1]也可以写成size(1)
        y_hat = y_hat.argmax(axis = 1) #argmax函数代表数组中的最大值所在的索引,也即数组下标,axis表示维度(0表示行与行之间比较,1表示列与列之间比较)

    cmp = y_hat.type(y.dtype) == y #因为是向量之间的比较所以cmp也是向量,将y_hat的类型转化为y的类型是因为可能类型不一样(跟向量计算需要reshape一样,因为pytorch不区分行向量和列向量)
    return float(cmp.type(y.dtype).sum()) #转换成y的type也是为了方便计算,sum是因为向量无法直接转化为float,需要求和,而且本身也本来就需要求和,因为需要返回正确的个数

# accracy(y_hat, y) / len(y) 计算正确率

In [11]:
def evaluate_accuracy(net, data_iter):
    if isinstance(net, torch.nn.Module):
        net.eval()
    metric = Accumulator(2) #Accumulator累加器
    for x, y in data_iter:
        metric.add(accuracy(net(x), y), y.numel()) #两个分别是正确个数和总样本数
        return metric[0] / metric[1] #用正确个数除以总样本数就得到了正确率

In [12]:
evaluate_accuracy(net, test_iter)

0.0546875

In [13]:
lr = 0.1

def updater(batch_size):
    return d2l.sgd([w, b], lr, batch_size)

In [None]:
epochs = 10

d2l.train_ch3(net, train_iter, test_iter, cross_entropy, epochs, updater)

In [None]:
d2l.predict_ch3(net, test_iter)