In [6]:
import torch
import torch.nn as nn
from torch.nn import functional as F


In [101]:
x = torch.randn((10,5), requires_grad=True)
y = torch.randint(0,3, (10,))
y

tensor([2, 1, 0, 1, 2, 1, 2, 0, 0, 0])

In [102]:
li = nn.Linear(in_features=5, out_features=10)
li2 = nn.Linear(10,3)

li_out = li(x)
out = li2(li_out)
out.retain_grad()
x.retain_grad()
li_out.retain_grad()

In [103]:
li2.weight.grad

In [104]:
loss = F.cross_entropy(out, y)
loss

tensor(1.1918, grad_fn=<NllLossBackward0>)

In [105]:
loss.backward()

In [106]:
-torch.log(out[range(10), y])  

tensor([0.5113, 1.1901, 1.0210,    nan, 0.6620, 2.4949, 0.3594, 1.0085, 0.6821,
           nan], grad_fn=<NegBackward0>)

In [107]:
class CrossEntropyLoss:

    def __call__(self,
                 y_pred: torch.Tensor,
                 y_true: torch.Tensor
                 ):
        n_samples = y_pred.shape[0]
        y_preds = y_pred - y_pred.max(1, keepdim=True).values
        counts = y_preds.exp()
        counts_sum = counts.sum(1, keepdim=True)
        counts_sum_inv = counts_sum ** -1
        probs = counts * counts_sum_inv
        logprobs = probs.log()
        self.out = -logprobs[range(n_samples), y_true].mean()
        return self.out

    def backward(self,
                 y_pred: torch.Tensor,
                 y_true: torch.Tensor
                 ):
        n_samples = y_pred.shape[0]
        # softmax = F.softmax()
        grad = F.softmax(y_pred, dim=1)
        grad[range(n_samples), y_true] -= 1
        grad = grad / n_samples
        return grad

    def paramerters(self):
        return []
    
lo = CrossEntropyLoss()
lo(out, y)

tensor(1.1918, grad_fn=<NegBackward0>)

In [108]:

dloss = lo.backward(out, y)

print(f"shape of dloss{dloss.shape}, out {out.shape}")
print(dloss)

shape of dlosstorch.Size([10, 3]), out torch.Size([10, 3])
tensor([[ 0.0240,  0.0347, -0.0587],
        [ 0.0299, -0.0671,  0.0373],
        [-0.0646,  0.0288,  0.0357],
        [ 0.0339, -0.0801,  0.0462],
        [ 0.0240,  0.0305, -0.0545],
        [ 0.0206, -0.0763,  0.0557],
        [ 0.0295,  0.0259, -0.0554],
        [-0.0686,  0.0274,  0.0413],
        [-0.0670,  0.0303,  0.0366],
        [-0.0860,  0.0463,  0.0397]], grad_fn=<DivBackward0>)


In [109]:
li2.weight.grad

tensor([[ 0.1641, -0.0714, -0.0778,  0.0775,  0.0259,  0.0220,  0.0946, -0.0208,
          0.0993,  0.0915],
        [-0.1298,  0.0790,  0.0594, -0.1384, -0.0275, -0.0181, -0.0883,  0.0483,
         -0.1244, -0.0467],
        [-0.0343, -0.0076,  0.0184,  0.0609,  0.0015, -0.0039, -0.0063, -0.0275,
          0.0250, -0.0448]])

In [110]:
print(li_out.shape, dloss.shape)
dw = dloss.T @ li_out
dw

torch.Size([10, 10]) torch.Size([10, 3])


tensor([[ 0.1641, -0.0714, -0.0778,  0.0775,  0.0259,  0.0220,  0.0946, -0.0208,
          0.0993,  0.0915],
        [-0.1298,  0.0790,  0.0594, -0.1384, -0.0275, -0.0181, -0.0883,  0.0483,
         -0.1244, -0.0467],
        [-0.0343, -0.0076,  0.0184,  0.0609,  0.0015, -0.0039, -0.0063, -0.0275,
          0.0250, -0.0448]], grad_fn=<MmBackward0>)

In [111]:
torch.allclose(dw, li2.weight.grad)

True

In [112]:
db = dloss.sum(dim = 0)
li2.bias.grad, db

(tensor([-0.1244,  0.0004,  0.1240]),
 tensor([-0.1244,  0.0004,  0.1240], grad_fn=<SumBackward1>))

In [113]:

dli_out =  dloss @ li2.weight
dli_out

tensor([[-3.0885e-03, -4.2290e-03, -1.7008e-02, -1.9706e-02,  6.0901e-03,
         -6.6475e-03, -1.6097e-02, -1.0117e-03, -4.1209e-03,  4.6926e-03],
        [ 1.2462e-03,  1.3961e-02, -1.0619e-02,  2.8985e-02, -1.1432e-02,
          1.2178e-02,  2.0143e-02,  3.1030e-03,  4.2178e-03,  5.3580e-03],
        [ 2.6719e-03, -9.9180e-03,  3.4079e-02, -6.2542e-03,  4.6740e-03,
         -4.7698e-03, -1.1945e-03, -2.1102e-03,  7.3434e-04, -1.2093e-02],
        [ 1.5972e-03,  1.6520e-02, -1.1650e-02,  3.4791e-02, -1.3646e-02,
          1.4544e-02,  2.4287e-02,  3.6749e-03,  5.1194e-03,  6.0543e-03],
        [-2.8978e-03, -3.4967e-03, -1.6629e-02, -1.7680e-02,  5.3689e-03,
         -5.8718e-03, -1.4578e-02, -8.4562e-04, -3.7682e-03,  4.6815e-03],
        [ 2.2472e-03,  1.4847e-02, -4.4143e-03,  3.4539e-02, -1.3051e-02,
          1.3957e-02,  2.4820e-02,  3.3243e-03,  5.4523e-03,  3.5500e-03],
        [-3.0229e-03, -2.2650e-03, -1.9320e-02, -1.6071e-02,  4.5882e-03,
         -5.0540e-03, -1.3668e-0

In [115]:
torch.allclose(li_out.grad, dli_out)

True

In [116]:
out.grad

tensor([[ 0.0240,  0.0347, -0.0587],
        [ 0.0299, -0.0671,  0.0373],
        [-0.0646,  0.0288,  0.0357],
        [ 0.0339, -0.0801,  0.0462],
        [ 0.0240,  0.0305, -0.0545],
        [ 0.0206, -0.0763,  0.0557],
        [ 0.0295,  0.0259, -0.0554],
        [-0.0686,  0.0274,  0.0413],
        [-0.0670,  0.0303,  0.0366],
        [-0.0860,  0.0463,  0.0397]])

In [117]:
dw1 = dli_out.T @ x
dw1, li.weight.grad
torch.allclose(dw1, li.weight.grad)

True

In [118]:
db1 = dli_out.sum(0)
torch.allclose(db1, li.bias.grad)

True

In [119]:
print(dli_out.shape, li.weight.shape)
dx = dli_out @ li.weight

torch.Size([10, 10]) torch.Size([10, 5])


In [121]:
x.grad, dx

(tensor([[-3.0510e-05, -3.7117e-05,  1.1785e-02,  7.2215e-03, -2.1346e-03],
         [-4.9092e-03, -1.2360e-02, -1.1880e-02,  1.4466e-04,  3.0314e-03],
         [ 5.4787e-03,  1.3741e-02, -2.3019e-03, -9.6354e-03, -5.5749e-04],
         [-5.7404e-03, -1.4454e-02, -1.4428e-02, -1.5749e-04,  3.6420e-03],
         [-2.1783e-04, -5.1056e-04,  1.0783e-02,  6.8930e-03, -1.9194e-03],
         [-4.7046e-03, -1.1858e-02, -1.5418e-02, -2.3170e-03,  3.6374e-03],
         [-7.8253e-04, -1.9288e-03,  1.0445e-02,  7.5358e-03, -1.7574e-03],
         [ 5.6815e-03,  1.4247e-02, -3.2364e-03, -1.0509e-02, -4.2391e-04],
         [ 5.7016e-03,  1.4301e-02, -2.2830e-03, -9.9589e-03, -6.0061e-04],
         [ 7.6506e-03,  1.9195e-02, -1.1681e-03, -1.2209e-02, -1.1501e-03]]),
 tensor([[-3.0510e-05, -3.7116e-05,  1.1785e-02,  7.2215e-03, -2.1346e-03],
         [-4.9092e-03, -1.2360e-02, -1.1880e-02,  1.4466e-04,  3.0314e-03],
         [ 5.4787e-03,  1.3741e-02, -2.3019e-03, -9.6354e-03, -5.5749e-04],
         [

In [None]:
class Linear:
    def __init__(self,
                 fan_in: int,
                 fan_out: int,
                 bias=True):
        self.weight = torch.randn((fan_in, fan_out)) / fan_in ** 0.5
        self.bias = torch.randn(fan_out) if bias else None

    def __call__(self,
                 X: torch.Tensor):
        self.last_input = X
        self.out = X @ self.weight
        if self.bias is not None:
            self.out += self.bias
        return self.out

    def backward(self, d_L_d_out):
        # d_L_d_weights = torch.matmul(self.last_input.t(), d_L_d_out)

        d_L_d_weights = self.last_input.T @ d_L_d_out
        d_L_d_biases = torch.sum(d_L_d_out, dim=0)
        d_L_d_input = d_L_d_out @ self.weight.T

        return d_L_d_input, d_L_d_weights, d_L_d_biases

    def parameters(self):
        return [self.weight] + ([] if self.bias is None else [self.bias])

In [37]:
logit_maxes = out.max(1, keepdim=True).values
norm_logits = out - logit_maxes #subtract the max for numerical stability refer the previous notebooks 
counts = norm_logits.exp()
counts_sum = counts.sum(1, keepdims = True)
counts_sum_inv = counts_sum ** -1  # if I use (1.0 / counts_sum) instead then I can't get backprop to be bit exact...
probs = counts * counts_sum_inv
logprobs = probs.log()
loss = -logprobs[range(10), y].mean()

In [38]:
loss

tensor(1.1548, grad_fn=<NegBackward0>)

In [39]:
logit_maxes, norm_logits

(tensor([[0.2092],
         [0.5867],
         [0.1870],
         [0.0802],
         [0.2657],
         [0.2117],
         [0.1744],
         [0.3677],
         [0.2169],
         [0.2833]], grad_fn=<MaxBackward0>),
 tensor([[-0.5269,  0.0000, -0.3183],
         [ 0.0000, -1.0006, -0.8468],
         [-0.0870, -0.0782,  0.0000],
         [-0.1711,  0.0000, -0.2601],
         [-0.0527,  0.0000, -0.5498],
         [-0.1575, -0.1237,  0.0000],
         [ 0.0000, -0.2957, -0.2016],
         [ 0.0000, -0.2819, -0.6581],
         [-0.0386, -0.0865,  0.0000],
         [ 0.0000, -0.1860, -0.6365]], grad_fn=<SubBackward0>))

In [40]:
counts

tensor([[0.5904, 1.0000, 0.7274],
        [1.0000, 0.3677, 0.4288],
        [0.9166, 0.9248, 1.0000],
        [0.8427, 1.0000, 0.7709],
        [0.9487, 1.0000, 0.5770],
        [0.8543, 0.8836, 1.0000],
        [1.0000, 0.7440, 0.8174],
        [1.0000, 0.7543, 0.5178],
        [0.9622, 0.9171, 1.0000],
        [1.0000, 0.8302, 0.5291]], grad_fn=<ExpBackward0>)