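## Logistic (Softmax) Regression
* A minimal, from-scratch implementation of multi-class softmax (multinomial logistic) regression.
* Apart from basic tensor operations (and `d2l` for loading the data), the only PyTorch functionality used is autograd: gradients are computed with `loss.backward()`.
* The dataset is Fashion-MNIST.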

In [1]:
from d2l import torch as d2l
import torch
from IPython import display

In [2]:
# Each 28x28 image is flattened into a single feature vector; there are
# 10 output classes.
num_inputs, num_outputs = 28 * 28, 10

In [3]:
# Mini-batch size used by both the training and the test loader.
batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size)

In [4]:
# Weights start as small Gaussian noise and biases start at zero. Both are
# created with requires_grad=True so that loss.backward() populates .grad.
W = torch.normal(
    mean=0,
    std=0.01,
    size=(num_inputs, num_outputs),
    requires_grad=True,
)
b = torch.zeros((num_outputs,), requires_grad=True)

#### Softmax Model
\begin{align}
logits&=XW+b\\y\_probs&=softmax\left(logits\right)\\softmax\left(z\right)_{i}&=\dfrac{\exp\left(z_{i}\right)}{\sum_{j=1}^{K}\exp\left(z_{j}\right)}
\end{align}

In [5]:
def net(X):
    """Return softmax class probabilities for a batch of inputs.

    ``X`` is reshaped to ``(-1, num_inputs)``, so each example is flattened
    into one row before the affine map ``X @ W + b`` is applied.

    Returns:
        A tensor with one row per example and ``num_outputs`` columns; each
        row is non-negative and sums to 1.
    """
    logits = torch.mm(X.reshape(-1, num_inputs), W) + b
    # Softmax is invariant to subtracting a constant from each row. Shifting by
    # the row max keeps exp() from overflowing to inf (which would give NaN
    # probabilities) when logits are large. The max is detached because the
    # shift does not change the result, so no gradient needs to flow through it.
    row_max = logits.max(dim=1, keepdim=True).values.detach()
    exp_logits = torch.exp(logits - row_max)
    y_probs = exp_logits / exp_logits.sum(dim=1, keepdim=True)
    return y_probs

#### Cross Entropy Loss

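For each example, with the label $y$ one-hot encoded over the $K$ classes:
$$cross\_entropy\left(y,y\_probs\right)=-\sum_{i=1}^{K}y_{i}\log y\_probs_{i}$$
Since $y$ is one-hot, the sum reduces to $-\log y\_probs_{c}$, where $c$ is the true class. The code computes exactly this by indexing `y_probs` with the integer labels.

Averaged over the $N$ examples in a mini-batch:
$$loss=\dfrac{1}{N}\sum_{n=1}^{N}cross\_entropy\left(y^{(n)},y\_probs^{(n)}\right)$$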

In [6]:
def cross_entropy(y, y_probs):
    """Return the mean cross-entropy loss over a mini-batch.

    Args:
        y: integer class labels, one per example. They are used as indices
            into the second dimension of ``y_probs``.
        y_probs: predicted class probabilities, one row per example.

    Returns:
        A scalar tensor: the mean of ``-log(y_probs[n, y[n]])`` over the batch.
    """
    n_batch = y.shape[0]
    # With one-hot targets, the sum over classes collapses to the
    # log-probability of the true class, picked out by advanced indexing.
    loss_ex = -torch.log(y_probs[torch.arange(n_batch), y])

    # Average rather than sum. With a sum, the gradient scales with the batch
    # size, which effectively multiplies the learning rate by it. That is the
    # likely reason training fails when torch.sum is used here.
    return torch.mean(loss_ex)

In [7]:
def accuracy(test_iter, net):
    """Return the fraction of examples whose argmax prediction matches the label.

    Args:
        test_iter: iterable of ``(features, labels)`` batches.
        net: callable mapping a batch of features to per-class scores, one
            row per example.

    Returns:
        ``correct / total`` as a Python float.
    """
    correct = 0
    seen = 0
    for features, labels in test_iter:
        predictions = net(features).argmax(axis=1)
        correct += int((predictions == labels).sum())
        seen += len(labels)
    return correct / seen

In [8]:
# SGD hyperparameters: step size and number of full passes over train_iter.
lr, num_epochs = 0.1, 10

In [9]:
# Mini-batch SGD training loop. It reports the current batch loss and the
# test accuracy on the first step and every 100 steps after it.
iter_counter = 0
for epoch in range(num_epochs):
    for x_batch, y_batch in train_iter:
        # Forward pass and mean cross-entropy over the mini-batch.
        y_probs = net(x_batch)
        loss = cross_entropy(y_batch, y_probs)

        # backward() accumulates into .grad, so clear the previous step's
        # gradients. They are None before the first backward() call.
        if W.grad is not None:
            W.grad.zero_()
            b.grad.zero_()

        loss.backward()

        # Plain SGD step. It runs under no_grad so that the in-place
        # parameter update is not recorded by autograd.
        with torch.no_grad():
            W -= lr * W.grad
            b -= lr * b.grad

        iter_counter += 1
        if iter_counter % 100 == 1:
            with torch.no_grad():
                test_acc = accuracy(test_iter, net)
            print(f"{iter_counter}: loss = {loss.item():.3f}  "
                  f" test acc: = {test_acc:.3f}")

1: loss = 2.316   test acc: = 0.180
101: loss = 0.771   test acc: = 0.764
201: loss = 0.633   test acc: = 0.774
301: loss = 0.466   test acc: = 0.802
401: loss = 0.514   test acc: = 0.804
501: loss = 0.687   test acc: = 0.796
601: loss = 0.507   test acc: = 0.817
701: loss = 0.456   test acc: = 0.816
801: loss = 0.485   test acc: = 0.820
901: loss = 0.553   test acc: = 0.825
1001: loss = 0.509   test acc: = 0.825
1101: loss = 0.513   test acc: = 0.825
1201: loss = 0.463   test acc: = 0.825
1301: loss = 0.456   test acc: = 0.826
1401: loss = 0.560   test acc: = 0.823
1501: loss = 0.492   test acc: = 0.832
1601: loss = 0.426   test acc: = 0.831
1701: loss = 0.408   test acc: = 0.831
1801: loss = 0.483   test acc: = 0.831
1901: loss = 0.548   test acc: = 0.835
2001: loss = 0.372   test acc: = 0.828
2101: loss = 0.515   test acc: = 0.834
2201: loss = 0.466   test acc: = 0.832
2301: loss = 0.383   test acc: = 0.833
