In [1]:
from d2l import torch as d2l
import torch
from torch import nn

In [2]:
hidden_dim, inp_dim, ouput_dim = 256, 784, 10
W1 = nn.Parameter(torch.rand(inp_dim, hidden_dim, requires_grad=True)*0.01)
b1 = nn.Parameter(torch.zeros(hidden_dim, requires_grad=True))
W2 = nn.Parameter(torch.rand(hidden_dim, ouput_dim, requires_grad=True)*0.01)
b2 = nn.Parameter(torch.zeros(ouput_dim, requires_grad=True))

params = [W1, b1, W2, b2]

In [3]:
def relu(x):
    return x*(x > 0)
relu(torch.tensor([[0, -1, 5], [1, 2, -2]]))

tensor([[0, 0, 5],
        [1, 2, 0]])

In [4]:
def softmax(x):
    return torch.exp(x)/torch.exp(x).sum(1, keepdim=True);
softmax(torch.tensor([[0, 0, 0], [0, 0, 0]]))

tensor([[0.3333, 0.3333, 0.3333],
        [0.3333, 0.3333, 0.3333]])

In [5]:
# net(torch.ones(5, 1, 28, 28))

In [6]:
def net(X):
    X = X.reshape((-1, inp_dim))
    return relu(X@W1 + b1)@W2 + b2

In [7]:
def cross_entropy(y_hat, y):
    max_ok = torch.max(y_hat, dim=1, keepdim=True).values
    return -(y_hat[range(len(y_hat)), y] - max_ok.ravel() - torch.log(torch.exp(y_hat - max_ok).sum(dim=1))).mean()

print(cross_entropy(torch.tensor([[0, 0, 0.7, 0.3], [0, 0, 1, 0]]), torch.tensor([3, 0])).mean())

tensor(1.5617)


In [8]:
torch.max(torch.tensor([[0, 1, 2], [4, 5, 6]]), dim=1, keepdim=True).values

tensor([[2],
        [6]])

In [9]:
def accuracy(y_pred, y_true):
    y_pred = torch.argmax(y_pred, dim=1)
    accuracy = torch.tensor((y_true == y_pred), dtype=torch.float32)
    return accuracy.mean()
accuracy(torch.tensor([[0, 1, 1], [1, 0, 0], [1, 0, 0]]), torch.tensor([1, 1, 1]))

  accuracy = torch.tensor((y_true == y_pred), dtype=torch.float32)


tensor(0.3333)

In [10]:
lr = 0.05
batch_size = 256

In [11]:
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size)

In [12]:
updater = torch.optim.SGD(params, lr=lr)

In [13]:
def l2_penalty(params):
    return sum(torch.pow(param, 2).sum() for param in params)
# l2_penalty([[1, 2, 3], [1, 2, 3]])

In [14]:
l2_penalty(params)

tensor(6.7786, grad_fn=<AddBackward0>)

In [15]:
def train_epoch(net, train_iter, loss, updater, lambda_loss):
    with torch.autograd.set_detect_anomaly(True):
        for i, (X, y) in enumerate(train_iter):
    #         print(y.shape, X.shape)
            one_hot_y = torch.zeros(y.shape[0], 10)
            one_hot_y[torch.arange(y.shape[0]), y] = 1

            y_pred = net(X)
    #         if y_pred.isnan().sum():
    #             print(f'{y_pred.isnan().sum()} nan')
    #             print(X)
    #             print(y)
    #             print(params)
    #             break
            acc = accuracy(y_pred, y)
            l = loss(y_pred, y) + lambda_loss*l2_penalty(params)

            updater.zero_grad()
            l.sum().backward()
            updater.step()

            if (i % 100 == 0):
                print(f'Accuracy {acc:.2f}, Loss: {l.sum():.2f}')

In [16]:
def train(net, train_iter, num_epochs, loss, updater):
    for epoch in range(num_epochs):
        train_epoch(net, train_iter, loss, updater, 0.02)

In [17]:
num_epochs = 10
train(net, train_iter, num_epochs, cross_entropy, updater)

  accuracy = torch.tensor((y_true == y_pred), dtype=torch.float32)


Accuracy 0.12, Loss: 2.44
Accuracy 0.29, Loss: 1.72
Accuracy 0.63, Loss: 1.31
Accuracy 0.60, Loss: 1.32
Accuracy 0.67, Loss: 1.13
Accuracy 0.72, Loss: 1.06
Accuracy 0.73, Loss: 1.10
Accuracy 0.75, Loss: 1.01
Accuracy 0.73, Loss: 1.10
Accuracy 0.76, Loss: 1.07
Accuracy 0.76, Loss: 1.08
Accuracy 0.71, Loss: 1.12
Accuracy 0.81, Loss: 0.96
Accuracy 0.74, Loss: 1.10
Accuracy 0.79, Loss: 1.04
Accuracy 0.78, Loss: 1.01
Accuracy 0.71, Loss: 1.10
Accuracy 0.80, Loss: 0.97
Accuracy 0.80, Loss: 1.01
Accuracy 0.78, Loss: 1.00
Accuracy 0.78, Loss: 1.05
Accuracy 0.84, Loss: 0.94
Accuracy 0.84, Loss: 0.96
Accuracy 0.82, Loss: 0.96
Accuracy 0.79, Loss: 1.00
Accuracy 0.81, Loss: 1.04
Accuracy 0.85, Loss: 0.94
Accuracy 0.79, Loss: 1.01
Accuracy 0.81, Loss: 1.01
Accuracy 0.82, Loss: 0.98
