In [1]:
from d2l import torch as d2l
import torch
from torch import nn

In [2]:
def dropout_layer(h, rate):
#     print(h.shape)
    # h have shape batch_size x layer size
    p = rate*h.shape[1]
    mask = torch.randperm(h.shape[1])
#     print(f'mask: {mask > p}')
    h = h*(mask >= p)
    return h

In [3]:
dropout_layer(torch.tensor([[1, 2, 3, 4]]), 0.5)

tensor([[0, 2, 3, 0]])

In [4]:
batch_size = 256
lr = 0.5
epochs = 10
num_inputs = 784
num_outputs = 10
num_hiddens1 = 256
num_hiddens2 = 256
dropout_rate = [0.2, 0.5]

In [5]:
class MLP(nn.Module):
    def __init__(self, num_inputs, num_outputs, num_hiddens1, num_hiddens2, dropout_rate):
        super(MLP, self).__init__()
        self.lin1 = nn.Linear(num_inputs, num_hiddens1)
        self.lin2 = nn.Linear(num_hiddens1, num_hiddens2)
        self.lin3 = nn.Linear(num_hiddens2, num_outputs)
    
        self.relu = nn.ReLU()
        
        self.num_inputs = num_inputs
        self.num_hiddens1 = num_hiddens1
        self.num_hiddens2 = num_hiddens2
        self.num_outputs = num_outputs
        
        self.dropout_rate = dropout_rate
        
    def forward(self, x):
        h1 = self.relu(self.lin1(x.reshape((-1, self.num_inputs))))

        if (self.dropout_rate):
            h1 = dropout_layer(h1, self.dropout_rate[0])
        
        h2 = self.relu(self.lin2(h1))
        
        if (self.dropout_rate):
            h2 = dropout_layer(h2, self.dropout_rate[1])
            
        o = self.lin3(h2)
        return o
net = MLP(num_inputs, num_outputs, num_hiddens1, num_hiddens2, dropout_rate)

In [6]:
loss = nn.CrossEntropyLoss()
updater = torch.optim.SGD(net.parameters(), lr=lr)

In [7]:
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

In [8]:
for epoch in range(epochs):
#     net.train()
    for i, (X, y) in enumerate(train_iter):
        
        y_hat = net(X)
        l = loss(y_hat, y)
        
        updater.zero_grad()
        l.backward()
        updater.step()
        
        if (i % 100 == 0):
            print(f'Loss: {l.sum():.2f}, accuracy: {(torch.argmax(y_hat, dim=1) == y).sum()/y.shape[0]:.2f}')
        

Loss: 2.30, accuracy: 0.11
Loss: 0.98, accuracy: 0.59
Loss: 0.91, accuracy: 0.67
Loss: 0.84, accuracy: 0.64
Loss: 0.71, accuracy: 0.74
Loss: 0.62, accuracy: 0.77
Loss: 0.47, accuracy: 0.81
Loss: 0.50, accuracy: 0.81
Loss: 0.55, accuracy: 0.82
Loss: 0.54, accuracy: 0.82
Loss: 0.46, accuracy: 0.84
Loss: 0.42, accuracy: 0.83
Loss: 0.49, accuracy: 0.79
Loss: 0.41, accuracy: 0.88
Loss: 0.45, accuracy: 0.84
Loss: 0.40, accuracy: 0.86
Loss: 0.32, accuracy: 0.88
Loss: 0.45, accuracy: 0.84
Loss: 0.46, accuracy: 0.84
Loss: 0.43, accuracy: 0.85
Loss: 0.35, accuracy: 0.84
Loss: 0.52, accuracy: 0.80
Loss: 0.30, accuracy: 0.91
Loss: 0.61, accuracy: 0.78
Loss: 0.43, accuracy: 0.82
Loss: 0.30, accuracy: 0.89
Loss: 0.35, accuracy: 0.88
Loss: 0.39, accuracy: 0.86
Loss: 0.40, accuracy: 0.83
Loss: 0.46, accuracy: 0.82
