In [1]:
import torch
import torchvision
import torch.nn as nn
import numpy as np
import sys

from torchvision import transforms

In [167]:
def dropout(X, drop_prob):
    X = X.float()
    keep_prob = 1 - drop_prob

    if keep_prob == 0:
        return torch.zeros_like(X)
    mask = (torch.rand(X.shape)<keep_prob).float()

    return mask * X/keep_prob

In [168]:
X = torch.arange(16).view(2,8)
print(X)

tensor([[ 0,  1,  2,  3,  4,  5,  6,  7],
        [ 8,  9, 10, 11, 12, 13, 14, 15]])


In [169]:
num_inputs, num_outputs, num_hiddens1, num_hiddens2 = 784, 10, 256, 256

W1 = torch.tensor(np.random.normal(0, 0.01, size=(num_inputs, num_hiddens1)), dtype=torch.float, requires_grad=True)
b1 = torch.zeros(num_hiddens1, requires_grad=True)
W2 = torch.tensor(np.random.normal(0, 0.01, size=(num_hiddens1, num_hiddens2)), dtype=torch.float, requires_grad=True)
b2 = torch.zeros(num_hiddens2, requires_grad=True)
W3 = torch.tensor(np.random.normal(0, 0.01, size=(num_hiddens2, num_outputs)), dtype=torch.float, requires_grad=True)
b3 = torch.zeros(num_outputs, requires_grad=True)

params = [W1, b1, W2, b2, W3, b3]

In [170]:
drop_prob1, drop_prob2 = 0.2,0.5
def net(X, is_training=True):
    X = X.view(-1,num_inputs)
    H1 = (torch.matmul(X,W1)+b1).relu()
    if is_training:
        H1 = dropout(H1, drop_prob1)
    H2 = (torch.matmul(H1, W2)+b2).relu()
    if is_training:
        H2 = dropout(H2, drop_prob2)
    return torch.matmul(H2, W3)+b3


In [171]:
def load_data(batch_size,num_workers):
    mnist_train = torchvision.datasets.FashionMNIST(root='E:\Training_Sets\Dropout',
                                                    train=True, download=True,
                                                    transform=transforms.ToTensor())
    mnist_test = torchvision.datasets.FashionMNIST(root='E:\Training_Sets\Dropout',
                                                train=False, download=True,
                                                transform=transforms.ToTensor())

    train_iter = torch.utils.data.DataLoader(
        mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    test_iter = torch.utils.data.DataLoader(
        mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    
    return train_iter,test_iter

In [172]:
def evaluate_accuracy(data_iter, net):
    acc_sum, n = 0.0, 0
    for X, y in data_iter:
        if isinstance(net, torch.nn.Module):
            net.eval() # 评估模式, 这会关闭dropout
            acc_sum += (net(X).argmax(dim=1) == y).float().sum().item()
            net.train() # 改回训练模式
        else: # 自定义的模型
            if('is_training' in net.__code__.co_varnames): # 如果有is_training这个参数
                # 将is_training设置成False
                acc_sum += (net(X, is_training=False).argmax(dim=1) == y).float().sum().item() 
            else:
                acc_sum += (net(X).argmax(dim=1) == y).float().sum().item() 
        n += y.shape[0]
    return acc_sum / n

In [173]:
loss = nn.CrossEntropyLoss()

def sgd(params, lr, batch_size):
    for param in params:
        param.data -= lr * param.grad / batch_size # 注意这里更改param时用的param.data+

In [174]:
def train(train_iter, test_iter, num_epochs, net, loss, params, sgd, batch_size, lr):
    for epoch in range(num_epochs):
        tran_l_sum, train_acc_sum, n = 0.0, 0.0, 0
        for X, y in train_iter:
            y_hat=net(X)

            l = loss(y_hat, y).sum()

            if params is not None and params[0].grad is not None:
                for param in params:
                    param.grad.data.zero_()

            l.backward()
            sgd(params, lr, batch_size)

            train_l_sum = l.item()
            train_acc_sum += (y_hat.argmax(dim=1) == y).sum().item()
            n += y.shape[0]

        test_acc = evaluate_accuracy(test_iter, net)
        print(('epoch %d, loss %.9f, train_acc %.3f, test_acc %.3f') % (epoch + 1, train_l_sum/n, train_acc_sum/n, test_acc))

In [175]:
num_epochs,lr=5,100
train(train_iter,test_iter,num_epochs,net,loss,params,sgd,batch_size,lr)

epoch 1, loss 0.000010555, train_acc 0.554, test_acc 0.696
epoch 2, loss 0.000009115, train_acc 0.786, test_acc 0.808
epoch 3, loss 0.000007967, train_acc 0.822, test_acc 0.827
epoch 4, loss 0.000004916, train_acc 0.839, test_acc 0.830
epoch 5, loss 0.000006901, train_acc 0.848, test_acc 0.823


<torch.utils.data.dataloader.DataLoader object at 0x000002BBA2BF1BB0>
