In [1]:
%matplotlib inline
from IPython import display
import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2l
from tqdm import tqdm, trange
from collections import abc

In [2]:
batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

In [3]:
num_inputs = 784
num_outputs = 10

W = torch.normal(0, 1, (num_inputs, num_outputs)).requires_grad_(True)
b = torch.zeros(num_outputs).requires_grad_(True)

In [4]:
def softmax(X):
        X_exp = torch.exp(X)
        partition = X_exp.sum(1, keepdim=True)
        return X_exp / partition

In [5]:
def net(X):
        return softmax(torch.mm(X.view((-1, num_inputs)), W) + b)


def cross_entropy(pred, label):
        return -torch.log(pred[range(len(pred)), label])

In [6]:
(X, y) = next(iter(train_iter))
print(X.shape, y.shape)
print(X.view(-1, num_inputs).shape)
pred = net(X)
print(pred.shape)
print(pred[range(len(pred)), y].shape)
print(pred.argmax(axis=1).shape)

torch.Size([256, 1, 28, 28]) torch.Size([256])
torch.Size([256, 784])
torch.Size([256, 10])
torch.Size([256])
torch.Size([256])


In [7]:
y = torch.tensor([0, 2])
y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
print(y_hat)
print(y_hat[[0, 1], y])

print(cross_entropy(y_hat, y))

tensor([[0.1000, 0.3000, 0.6000],
        [0.3000, 0.2000, 0.5000]])
tensor([0.1000, 0.5000])
tensor([2.3026, 0.6931])


In [8]:
def accuracy(pred, label):
        if len(pred) > 1 and pred.shape[1] > 1:
                pred = pred.argmax(axis=1)
        cmp = pred.type(label.dtype) == label
        return float(cmp.type(torch.float32).mean())


accuracy(y_hat, y)

0.5

In [9]:
class Accumulator:
        def __init__(self, n):
                self.data = [0.0] * n

        def add(self, *args):
                self.data = [a + float(b) for a, b in zip(self.data, args)]

        def reset(self):
                self.data = [0.0] * len(self.data)

        def __getitem__(self, idx):
                return self.data[idx]


def evaluate_accuracy(data_iter, net):
        if isinstance(net, torch.nn.Module):
                net.eval()  # Set the model to evaluation mode
        metric = Accumulator(2)  # Accumulator for loss and accuracy
        with torch.no_grad():
                for X, y in data_iter:
                        metric.add(accuracy(net(X), y), y.numel())

        return metric[0] / metric[1]

In [10]:
def train_epoch(net, train_iter, lose, optimizer):
        if isinstance(net, torch.nn.Module):
                net.train()  # Set the model to training mode
        metric = Accumulator(3)  # Accumulator for loss, accuracy, and count
        for X, y in train_iter:
                y_hat = net(X)
                l = lose(y_hat, y)
                if isinstance(optimizer, torch.optim.Optimizer):
                        optimizer.zero_grad()
                        l.backward()
                        optimizer.step()
                        metric.add(float(l) * len(y), accuracy(y_hat, y), y.numel())
                else:
                        l.sum().backward()
                        optimizer(X.shape[0])
                        metric.add(float(l.sum()), accuracy(y_hat, y), y.numel())
        return metric[0] / metric[2], metric[1] / metric[2]

In [17]:
class Animator:
        """For plotting data in animation."""

        def __init__(self, xlabel=None, ylabel=None, legend=None, xlim=None,
                     ylim=None, xscale='linear', yscale='linear',
                     fmts=('-', 'm--', 'g-.', 'r:'), nrows=1, ncols=1,
                     figsize=(3.5, 2.5)):
                """Defined in :numref:`sec_softmax_scratch`"""
                # Incrementally plot multiple lines
                if legend is None:
                        legend = []
                d2l.use_svg_display()
                self.fig, self.axes = d2l.plt.subplots(nrows, ncols, figsize=figsize)
                if nrows * ncols == 1:
                        self.axes = [self.axes, ]
                # Use a lambda function to capture arguments
                self.config_axes = lambda: d2l.set_axes(
                        self.axes[0], xlabel, ylabel, xlim, ylim, xscale, yscale, legend)
                self.X, self.Y, self.fmts = None, None, fmts

        def add(self, x, y):
                # Add multiple data points into the figure
                if not hasattr(y, "__len__"):
                        y = [y]
                n = len(y)
                if not hasattr(x, "__len__"):
                        x = [x] * n
                if not self.X:
                        self.X = [[] for _ in range(n)]
                if not self.Y:
                        self.Y = [[] for _ in range(n)]
                for i, (a, b) in enumerate(zip(x, y)):
                        if a is not None and b is not None:
                                self.X[i].append(a)
                                self.Y[i].append(b)
                self.axes[0].cla()
                for x, y, fmt in zip(self.X, self.Y, self.fmts):
                        self.axes[0].plot(x, y, fmt)
                self.config_axes()
                display.display(self.fig)
                d2l.plt.draw()
                d2l.plt.pause(0.01)
                display.clear_output(wait=True)
                

In [21]:
def train(net, train_iter, test_iter, loss, num_epochs, optimizer):
        animator = Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0.3, 0.9],
                            legend=['train loss', 'train acc', 'test acc'])

        for epoch in trange(num_epochs):
                train_metrics = train_epoch(net, train_iter, loss, optimizer)
                test_acc = evaluate_accuracy(test_iter, net)
                animator.add(epoch + 1, train_metrics + (test_acc,))
                d2l.plt.show()

                train_loss, train_acc = train_metrics
                print(train_loss, train_acc, test_acc)


In [22]:
lr = 0.1
num_epochs = 10

def updater(batch_size):
        return d2l.sgd([W, b], lr, batch_size)

In [23]:
train(net, train_iter, test_iter, cross_entropy, num_epochs, updater)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:13<00:00,  1.30s/it]

0.6519392424265543 0.0031909071177244185 0.003176171875



