In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
def cmp(s, dt, t):
    """Compare a hand-derived gradient with the one autograd stored on a tensor.

    Prints one report line with three checks: elementwise exact equality,
    ``torch.allclose`` equality, and the largest absolute elementwise
    difference.

    Args:
        s: Label printed at the start of the report line.
        dt: Manually computed gradient tensor.
        t: Tensor whose ``.grad`` is used as the reference value.

    Raises:
        ValueError: If ``t.grad`` is ``None``.
    """
    grad = t.grad
    if grad is None:
        raise ValueError(
            f'{s}: tensor has no .grad; call retain_grad() on it before backward()'
        )

    if dt.shape == grad.shape:
        ex = torch.all(dt == grad).item()
        app = torch.allclose(dt, grad)
        maxdiff = (dt - grad).abs().max().item()
        note = ''
    else:
        # `==`, `-` and allclose all broadcast, so e.g. a (1, 1) candidate would
        # compare equal to a (1,) gradient. A wrong shape is a wrong gradient:
        # report it as a failure instead of letting broadcasting hide it.
        ex, app, maxdiff = False, False, float('nan')
        note = f' | shape mismatch: {tuple(dt.shape)} vs {tuple(grad.shape)}'

    print(f'{s:15s} | exact: {str(ex):5s} | approximate: {str(app):5s} | maxdiff:  {maxdiff}{note}')

In [3]:
# Toy linear-regression problem. All tensors are drawn from one seeded
# generator, so the order of the randn calls determines the values.
g = torch.Generator().manual_seed(111111)
X = torch.randn(10, 10, generator=g, requires_grad=True)
y = torch.randn(10, 1, generator=g, requires_grad=True)
W = torch.randn(10, 1, generator=g, requires_grad=True)
b = torch.randn(1, generator=g, requires_grad=True)

# Forward pass, one primitive operation per line so that every intermediate
# can later be compared against a hand-derived gradient.
matmul = torch.matmul(X, W)
pred = matmul + b
diff = y - pred
diff2 = diff.pow(2)
loss = diff2.mean(dim=0)

# Ask autograd to keep .grad on the intermediates as well as on the leaves.
for node in (matmul, pred, diff, diff2, loss):
    node.retain_grad()

print('Loss:', loss)
loss.backward()

Loss: tensor([6.0275], grad_fn=<MeanBackward1>)


In [4]:
# Shapes: X is (m, n), W is (n, 1), so matmul = X @ W is (m, 1).
# Writing dmatmul for dLoss/dmatmul, which has shape (m, 1):
#   dLoss/dX = dmatmul @ W.T  ->  (m, 1) @ (1, n) = (m, n)
#   dLoss/dW = X.T @ dmatmul  ->  (n, m) @ (m, 1) = (n, 1)

tuple(tensor.shape for tensor in (matmul, X, W))

(torch.Size([10, 1]), torch.Size([10, 10]), torch.Size([10, 1]))

In [5]:
# Exercise: backpropagate by hand through the whole expression above and
# compare every gradient with the one PyTorch computed, treating PyTorch's
# values as the reference.

# loss = diff2.mean(0): each row receives 1/rows of the upstream gradient.
ddiff2 = torch.ones_like(diff2) * (1.0 / diff2.shape[0])
# diff2 = diff ** 2
ddiff = 2 * diff * ddiff2
# diff = y - pred: the gradient flows unchanged into y and negated into pred.
dy = ddiff.clone()
dpred = -ddiff
# pred = matmul + b: matmul gets the gradient as is; b was broadcast over the
# rows, so its gradient is the sum over them.
dmatmul = dpred.clone()
db = dpred.sum(0)
# matmul = X @ W
dX = dmatmul @ W.T
dW = X.T @ dmatmul

for label, manual, node in (
    ('diff2', ddiff2, diff2),
    ('ddiff', ddiff, diff),
    ('dy', dy, y),
    ('dpred', dpred, pred),
    ('dmatmul', dmatmul, matmul),
    ('db', db, b),
    ('dX', dX, X),
    ('dW', dW, W),
):
    cmp(label, manual, node)

diff2           | exact: True  | approximate: True  | maxdiff:  0.0
ddiff           | exact: True  | approximate: True  | maxdiff:  0.0
dy              | exact: True  | approximate: True  | maxdiff:  0.0
dpred           | exact: True  | approximate: True  | maxdiff:  0.0
dmatmul         | exact: True  | approximate: True  | maxdiff:  0.0
db              | exact: True  | approximate: True  | maxdiff:  0.0
dX              | exact: True  | approximate: True  | maxdiff:  0.0
dW              | exact: True  | approximate: True  | maxdiff:  0.0


In [6]:
# Logistic-regression problem with 0/1 targets, drawn from a freshly seeded
# generator. The order of the draws determines the values.
g = torch.Generator().manual_seed(111111)
X = torch.randn(10, 10, generator=g, requires_grad=True)
y = torch.randint(
    0, 2, (10, 1),
    dtype=torch.float32,
    generator=g,
    requires_grad=True,
)
W = torch.randn(10, 1, generator=g, requires_grad=True)
b = torch.randn(1, generator=g, requires_grad=True)

# Binary cross-entropy built from primitive ops: a sigmoid written out by hand,
# then the probability assigned to the true class, then its negative log.
logits = X @ W + b
neg_logits = logits.neg()
exp_neg_logits = torch.exp(neg_logits)
probs = 1 / (1 + exp_neg_logits)
positive_probs = probs * y
negative_probs = (1 - probs) * (1 - y)
new_probs = positive_probs + negative_probs
log_probs = torch.log(new_probs)
neg_log_probs = log_probs.neg()
loss = neg_log_probs.mean(dim=0)

# Keep .grad on every intermediate for the manual comparison.
intermediates = (
    logits, neg_logits, exp_neg_logits, probs,
    positive_probs, negative_probs, new_probs,
    log_probs, neg_log_probs, loss,
)
for node in intermediates:
    node.retain_grad()

# Print the built-in loss next to the hand-built one as a sanity check.
print('Loss:', loss)
print('Torch Loss:', F.binary_cross_entropy_with_logits(logits, y))
loss.backward()

Loss: tensor([1.1765], grad_fn=<MeanBackward1>)
Torch Loss: tensor(1.1765, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


In [7]:
# Manual backward pass through the hand-built binary cross-entropy, walking the
# forward expression in reverse. Each d<name> holds dLoss/d<name>.

# loss = neg_log_probs.mean(0)
dneg_log_probs = torch.ones_like(neg_log_probs) * (1.0 / neg_log_probs.shape[0])
# neg_log_probs = -log_probs
dlog_probs = -torch.ones_like(log_probs) * dneg_log_probs
# log_probs = new_probs.log()
dnew_probs =  (1.0 / new_probs) * dlog_probs
# new_probs = positive_probs + negative_probs
dpositive_probs = torch.ones_like(positive_probs) * dnew_probs
dnegative_probs = torch.ones_like(negative_probs) * dnew_probs
# negative_probs = (1 - probs) * (1 - y)
dprobs = -(1 - y) * torch.ones_like(probs) * dnegative_probs
dy = -(1 - probs) * torch.ones_like(y) * dnegative_probs
# positive_probs = probs * y; probs and y feed both branches, so accumulate.
dprobs += y * dpositive_probs
dy += probs * dpositive_probs
# probs = 1 / (1 + exp_neg_logits)
dexp_neg_logits = -(1 / (1 + exp_neg_logits)**2) * dprobs
# exp_neg_logits = neg_logits.exp()
dneg_logits = exp_neg_logits * dexp_neg_logits
# neg_logits = -logits
dlogits = -torch.ones_like(logits) * dneg_logits
# logits = X @ W + b
dX = dlogits @ W.T
dW = X.T @ dlogits
# b has shape (1,) and was broadcast over the rows of X @ W, so its gradient is
# the sum over rows. Without keepdim the result has b's own shape (1,);
# keepdim=True would give (1, 1), which only compares equal via broadcasting.
db = dlogits.sum(0)

cmp('neg_log_probs', dneg_log_probs, neg_log_probs)
cmp('log_probs', dlog_probs, log_probs)
cmp('new_probs', dnew_probs, new_probs)
cmp('positive_probs', dpositive_probs, positive_probs)
cmp('negative_probs', dnegative_probs, negative_probs)
cmp('probs', dprobs, probs)
cmp('y', dy, y)
cmp('exp_neg_logits', dexp_neg_logits, exp_neg_logits)
cmp('neg_logits', dneg_logits, neg_logits)
cmp('logits', dlogits, logits)
cmp('X', dX, X)
cmp('W', dW, W)
cmp('b', db, b)

neg_log_probs   | exact: True  | approximate: True  | maxdiff:  0.0
log_probs       | exact: True  | approximate: True  | maxdiff:  0.0
new_probs       | exact: False | approximate: True  | maxdiff:  5.960464477539063e-08
positive_probs  | exact: False | approximate: True  | maxdiff:  5.960464477539063e-08
negative_probs  | exact: False | approximate: True  | maxdiff:  5.960464477539063e-08
probs           | exact: False | approximate: True  | maxdiff:  5.960464477539063e-08
y               | exact: False | approximate: True  | maxdiff:  5.960464477539063e-08
exp_neg_logits  | exact: False | approximate: True  | maxdiff:  5.960464477539063e-08
neg_logits      | exact: False | approximate: True  | maxdiff:  7.450580596923828e-09
logits          | exact: False | approximate: True  | maxdiff:  7.450580596923828e-09
X               | exact: False | approximate: True  | maxdiff:  1.4901161193847656e-08
W               | exact: False | approximate: True  | maxdiff:  5.960464477539063e-08
b  

In [8]:
# ReLU example: a single matmul followed by ReLU and a mean over the rows.
# Seed a fresh generator here instead of continuing from whatever state an
# earlier cell left `g` in, so the cell draws the same x and w on every run.
g = torch.Generator().manual_seed(111111)
x = torch.randn((5,5), generator=g, requires_grad=True)
w = torch.randn((5,1), generator=g, requires_grad=True)

pre = x @ w
res = torch.relu(pre)
loss = res.mean(0)

# x and w are leaf tensors, for which retain_grad() is a no-op, so only the
# intermediates are listed.
for t in [pre, res, loss]:
    t.retain_grad()
loss.backward()

In [9]:
# loss = res.mean(0): each row receives 1/rows of the upstream gradient.
dres = torch.ones_like(res).div(res.shape[0])
# res = relu(pre): the gradient passes only where the input was positive.
dpre = (pre > 0).to(dres.dtype) * dres
# pre = x @ w
dx = dpre @ w.T
dw = x.T @ dpre

for label, manual, node in (
    ('res', dres, res),
    ('pre', dpre, pre),
    ('x', dx, x),
    ('w', dw, w),
):
    cmp(label, manual, node)


res             | exact: True  | approximate: True  | maxdiff:  0.0
pre             | exact: True  | approximate: True  | maxdiff:  0.0
x               | exact: True  | approximate: True  | maxdiff:  0.0
w               | exact: True  | approximate: True  | maxdiff:  0.0


In [10]:
# Three-layer MLP on random data. Parameters are drawn in a fixed order from a
# freshly seeded generator.
g = torch.Generator().manual_seed(111111)
X = torch.randn(100, 10, generator=g, requires_grad=True)
W1 = torch.randn(10, 200, generator=g, requires_grad=True)
b1 = torch.randn(200, generator=g, requires_grad=True)
W2 = torch.randn(200, 300, generator=g, requires_grad=True)
b2 = torch.randn(300, generator=g, requires_grad=True)
W3 = torch.randn(300, 100, generator=g, requires_grad=True)
b3 = torch.randn(100, generator=g, requires_grad=True)

# Layer 1: linear + ReLU
h1_preact = torch.matmul(X, W1) + b1
h1 = h1_preact.relu()

# Layer 2: linear + sigmoid
h2_preact = torch.matmul(h1, W2) + b2
h2 = h2_preact.sigmoid()

# Layer 3: linear, then a softmax spelled out as exp / row-sum
h3_preact = torch.matmul(h2, W3) + b3
h3 = torch.exp(h3_preact)
h3_sum = torch.sum(h3, dim=1, keepdim=True)
h3_sum_inv = h3_sum.pow(-1)
probs = h3 * h3_sum_inv

# Largest probability in each row
outputs = torch.max(probs, dim=1, keepdim=True).values

# Negative mean log of those probabilities. `logits` here holds log-probs;
# the name is kept because later cells refer to it.
logits = torch.log(outputs)
loss = logits.mean().neg()

# Keep .grad on every intermediate for the manual comparison.
intermediates = (
    h1_preact, h1, h2_preact, h2, h3_preact,
    h3, h3_sum, h3_sum_inv, probs, outputs, logits, loss,
)
for node in intermediates:
    node.retain_grad()

print('Loss:', loss.item())
loss.backward()

Loss: 0.1231394037604332


In [11]:
# Manual backward pass through the MLP, in reverse order of the forward cell.

# loss = -logits.mean()
dlogits = torch.ones_like(logits).neg() / logits.shape[0]
# logits = outputs.log()
inv_outputs = 1.0 / outputs
doutputs = inv_outputs * dlogits
# outputs = row-wise max of probs: only the selected entry gets gradient.
max_index = torch.max(probs, dim=1).indices
dprobs = F.one_hot(max_index, num_classes=probs.shape[1]) * doutputs
# probs = h3 * h3_sum_inv (h3_sum_inv is broadcast along dim 1)
dh3_direct = dprobs * h3_sum_inv
dh3_sum_inv = (dprobs * h3).sum(1, keepdim=True)
# h3_sum_inv = h3_sum ** -1
dh3_sum = (-1.0 * h3_sum.pow(-2)) * dh3_sum_inv
# h3_sum = h3.sum(1): second path into h3, broadcast back over each row.
dh3 = dh3_direct + dh3_sum
# h3 = h3_preact.exp()
dh3_preact = h3 * dh3

# Linear layer 3
dh2 = dh3_preact @ W3.T
dw3 = h2.T @ dh3_preact
db3 = dh3_preact.sum(0)
# h2 = sigmoid(h2_preact)
dh2_preact = h2 * (1 - h2) * dh2

# Linear layer 2
dh1 = dh2_preact @ W2.T
dw2 = h1.T @ dh2_preact
db2 = dh2_preact.sum(0)
# h1 = relu(h1_preact)
dh1_preact = (h1_preact > 0).to(dh1.dtype) * dh1

# Linear layer 1
dx = dh1_preact @ W1.T
dw1 = X.T @ dh1_preact
db1 = dh1_preact.sum(0)

for label, manual, node in (
    ('logits', dlogits, logits),
    ('outputs', doutputs, outputs),
    ('probs', dprobs, probs),
    ('h3_sum_inv', dh3_sum_inv, h3_sum_inv),
    ('h3_sum', dh3_sum, h3_sum),
    ('h3', dh3, h3),
    ('h3_preact', dh3_preact, h3_preact),
    ('dh2', dh2, h2),
    ('dw3', dw3, W3),
    ('db3', db3, b3),
    ('dh2_preact', dh2_preact, h2_preact),
    ('dh1', dh1, h1),
    ('dw2', dw2, W2),
    ('db2', db2, b2),
    ('dh1_preact', dh1_preact, h1_preact),
    ('dx', dx, X),
    ('dw1', dw1, W1),
    ('db1', db1, b1),
):
    cmp(label, manual, node)

logits          | exact: True  | approximate: True  | maxdiff:  0.0
outputs         | exact: False | approximate: True  | maxdiff:  1.862645149230957e-09
probs           | exact: False | approximate: True  | maxdiff:  1.862645149230957e-09
h3_sum_inv      | exact: False | approximate: True  | maxdiff:  274877906944.0
h3_sum          | exact: False | approximate: True  | maxdiff:  4.235164736271502e-22
h3              | exact: False | approximate: True  | maxdiff:  4.235164736271502e-22
h3_preact       | exact: False | approximate: True  | maxdiff:  1.1827978596556932e-09
dh2             | exact: False | approximate: True  | maxdiff:  3.725290298461914e-09
dw3             | exact: False | approximate: True  | maxdiff:  1.4901161193847656e-08
db3             | exact: False | approximate: True  | maxdiff:  7.450580596923828e-09
dh2_preact      | exact: False | approximate: True  | maxdiff:  9.313225746154785e-10
dh1             | exact: False | approximate: True  | maxdiff:  5.58793544769

In [12]:
# f(X) = X @ W + b + sum(W^2)
# Seed a fresh generator here instead of continuing from whatever state an
# earlier cell left `g` in, so the cell draws the same tensors on every run.
g = torch.Generator().manual_seed(111111)

X = torch.randn((100, 3), generator=g, requires_grad=True)
W = torch.randn((3, 1), generator=g, requires_grad=True)
b = torch.randn(1, generator=g, requires_grad=True)

In [13]:
# Linear map plus a sum-of-squares term on W, averaged over the rows.
linear = torch.matmul(X, W) + b
W_2 = W.pow(2)
reg = W_2.sum(dim=0)
preds = linear + reg
pred = preds.mean(dim=0)

# Keep .grad on the intermediates for the manual comparison.
for node in (linear, W_2, reg, preds, pred):
    node.retain_grad()

pred.backward()

In [14]:
# pred = preds.mean(0): each row receives 1/rows of the upstream gradient.
dpreds = torch.ones_like(preds).div(preds.shape[0])
# preds = linear + reg: linear gets the gradient unchanged; reg was broadcast
# over the rows, so its gradient is the sum over them.
dlinear = dpreds.clone()
dreg = dpreds.sum(0)
# reg = W_2.sum(0): every element of W_2 receives reg's gradient.
dW_2 = dreg.expand_as(W_2)
# W feeds two paths: W_2 = W ** 2 and linear = X @ W + b.
dW_from_square = 2 * W * dW_2
dW_from_linear = X.T @ dlinear
dW = dW_from_square + dW_from_linear
dX = dlinear @ W.T
db = dlinear.sum(0)

for label, manual, node in (
    ('preds', dpreds, preds),
    ('linear', dlinear, linear),
    ('reg', dreg, reg),
    ('W_2', dW_2, W_2),
    ('W', dW, W),
    ('X', dX, X),
    ('b', db, b),
):
    cmp(label, manual, node)

preds           | exact: True  | approximate: True  | maxdiff:  0.0
linear          | exact: True  | approximate: True  | maxdiff:  0.0
reg             | exact: True  | approximate: True  | maxdiff:  0.0
W_2             | exact: True  | approximate: True  | maxdiff:  0.0
W               | exact: True  | approximate: True  | maxdiff:  0.0
X               | exact: True  | approximate: True  | maxdiff:  0.0
b               | exact: True  | approximate: True  | maxdiff:  0.0
