In [1]:
import torch
import torchvision
from torch import nn
from d2l import torch as d2l

In [2]:
batch_size=256
train_data=torchvision.datasets.FashionMNIST(root='../data',train=True,transform=torchvision.transforms.ToTensor())
test_data=torchvision.datasets.FashionMNIST(root='../data',train=False,transform=torchvision.transforms.ToTensor())
train_dataloader=torch.utils.data.DataLoader(dataset=train_data,batch_size=batch_size,shuffle=True)
test_dataloader=torch.utils.data.DataLoader(dataset=test_data,batch_size=batch_size,shuffle=True)


In [30]:
def Relu(x):
    if not isinstance(x,torch.Tensor):
        x=torch.tensor(x)
    zero_mat=torch.zeros_like(x)
    return torch.maximum(x,zero_mat)

def net(x):
    input=x.reshape(-1,num_input)
    x2=input@w1+b1
    x3=Relu(x2)
    x4=x3@w2+b2
    return x4

class Accumulator:  #@save
    """在n个变量上累加"""
    def __init__(self, n):
        self.data = [0.0] * n

    def add(self, *args):
        self.data = [a + float(b) for a, b in zip(self.data, args)]

    def reset(self):
        self.data = [0.0] * len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
def accuracy(y_hat, y):  #@save
    """计算预测正确的数量"""
    if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:
        y_hat = y_hat.argmax(axis=1)
    cmp = y_hat.type(y.dtype) == y
    return float(cmp.type(y.dtype).sum())


def evaluate_accuracy(net, data_loader):
    if isinstance(net,torch.nn.Module):
        net.eval()
    metric=Accumulator(2)
    for x, y in data_loader:
        metric.add(accuracy(net(x),y),y.numel())
    return metric[0]/metric[1]


In [35]:
num_input,num_output=784,10
num_hidden=256
w1=nn.Parameter(torch.randn(num_input,num_hidden,requires_grad=True))
b1=nn.Parameter(torch.zeros(num_hidden))
w2=nn.Parameter(torch.randn(num_hidden,num_output,requires_grad=True))
b2=nn.Parameter(torch.zeros(num_output))
params=[w1,b1,w2,b2]

num_epoch=10
lr=0.1
loss_func=nn.CrossEntropyLoss()
optm=torch.optim.SGD(params,lr=lr)

for epoch in range(num_epoch):
    for data,label in train_dataloader:
        optm.zero_grad()
        result=net(data)
        loss=loss_func(result,label)
        loss.backward()
        optm.step()

    with torch.no_grad():
        a=evaluate_accuracy(net,test_dataloader)
        print(a)

0.7389
0.7545
0.7605
0.7635
0.7696
0.7656
0.7756
0.7821
0.7838
0.7857
