In [4]:
import torch

from fastcore.test import test_close

In [5]:
# Synthetic regression data for a single-output Linear -> ReLU -> MSE model.
# Seed the generator so Restart & Run All reproduces the same tensors (and losses).
SEED = 42
torch.manual_seed(SEED)

N = 1000  # number of samples (rows)
d = 48    # number of input features

X = torch.randn((N, d))
w = torch.randn((d, 1))
b = torch.randn(1)
y = torch.randn((N, 1))

In [6]:
# Manual forward pass: affine -> ReLU -> mean squared error, written out inline.
z = X.matmul(w) + b
a = z.clamp_min(0)
mse_loss = ((a - y) ** 2).mean(dim=0)

# Shapes, the smallest activation (a ReLU output is never negative), and the loss
z.shape, a.shape, a.min(), mse_loss

(torch.Size([1000, 1]), torch.Size([1000, 1]), tensor(0.), tensor([14.4401]))

### Notes
* For background on computing the backward pass by hand, see this [blog post](https://nasheqlbrm.github.io/blog/posts/2021-11-13-backward-pass.html).

In [11]:
class Linear():
    """Affine layer: maps an input ``x`` to ``x @ w + b``."""

    def __init__(self, w, b):
        self.w, self.b = w, b

    def __call__(self, x):
        projected = x @ self.w
        return projected + self.b


class Relu():
    """Element-wise rectified linear unit: negative entries become zero."""

    def __call__(self, a):
        return a.clamp_min(0)


class MSE():
    """Mean squared error averaged over dim 0 (rows); for 2-D inputs, one value per column."""

    def __call__(self, target, pred):
        residual = target - pred
        return residual.pow(2).mean(dim=0)


class Model():
    """Linear -> ReLU network paired with an MSE loss.

    Calling the model runs ``x`` through each layer in order and returns the
    loss of the final activation against ``y``.
    """

    def __init__(self, w, b):
        self.layers = [Linear(w, b), Relu()]
        self.loss = MSE()

    def __call__(self, x, y):
        activation = x
        for layer in self.layers:
            activation = layer(activation)
        return self.loss(y, activation)
    

### Check that the forward pass is correct

In [12]:
# Forward pass without the Model wrapper: Linear -> ReLU -> MSE, one step at a time.
linear_layer1 = Linear(w, b)
z = linear_layer1(X)      # pre-activation
a = Relu()(z)             # activation
MSE()(target=y, pred=a)   # displayed: the loss

tensor([14.4401])

In [13]:
# Same forward pass, this time driven by the Model wrapper (should match the cell above)
model = Model(w=w, b=b)
model(x=X, y=y)

tensor([14.4401])

### Compute gradients with PyTorch autograd as a baseline to compare against

In [14]:
class Linear():
    # Plain affine layer (x @ w + b). When its tensors require grad, autograd
    # records this operation, so no manual bookkeeping is needed here.
    def __init__(self, w, b):
        self.w = w
        self.b = b

    def __call__(self, x):
        out = x @ self.w
        return out + self.b


class Relu():
    """ReLU that keeps a handle on its input so autograd's gradient can be inspected.

    After ``backward()`` on a downstream loss, the gradient wrt the ReLU input is
    available as ``self.a.grad`` (when the input was tracked by autograd).
    """
    def __init__(self):
        pass

    def __call__(self, a):
        # retain_grad() raises on tensors with requires_grad=False, so only ask
        # for it when autograd is tracking `a`; plain tensors then work too.
        if a.requires_grad:
            a.retain_grad()
        self.a = a
        return a.clamp_min(0)


class MSE():
    """Mean squared error (averaged over dim 0) that keeps handles on its inputs.

    After ``backward()``, the gradient wrt the prediction is available as
    ``self.pred.grad`` (when the prediction was tracked by autograd).
    """
    def __init__(self):
        pass

    def __call__(self, target, pred):
        # retain_grad() raises on tensors with requires_grad=False, so only ask
        # for it when autograd is tracking `pred`; plain tensors then work too.
        if pred.requires_grad:
            pred.retain_grad()
        self.target = target
        self.pred = pred
        return (target - pred).pow(2).mean(dim=0)


class Model():
    """Linear -> ReLU network with an MSE loss (autograd-baseline version)."""

    def __init__(self, w, b):
        self.loss = MSE()
        self.layers = [Linear(w, b), Relu()]

    def __call__(self, x, y):
        out = x
        for layer in self.layers:
            out = layer(out)
        return self.loss(y, out)
    

In [15]:
def mkgrad(x):
    """Return a fresh leaf copy of ``x`` that tracks gradients.

    ``detach()`` drops any existing autograd history, so the copy is a leaf and
    ``backward()`` populates its ``.grad``. The input tensor is left unchanged.
    """
    return x.detach().clone().requires_grad_(True)

# Grad-tracking copies of the parameters and inputs for the autograd baseline
chks = (w, b, X)
ptgrads = tuple(mkgrad(t) for t in chks)
w_prime, b_prime, X_prime = ptgrads

In [16]:
# Autograd baseline: forward pass on the grad-tracking copies, then backprop.
model_prime = Model(w=w_prime, b=b_prime)
loss_prime = model_prime(x=X_prime, y=y)
loss_prime.backward()
# Gradients are now on w_prime.grad, b_prime.grad and X_prime.grad, and (through
# retain_grad) on model_prime.loss.pred.grad and model_prime.layers[-1].a.grad.

### Update the classes to compute gradients by hand

In [17]:
class Linear():
    """Affine layer ``x @ w + b`` with a hand-written backward pass.

    The forward call caches its input and output. ``backward`` reads the
    upstream gradient from ``self.output.g`` (written by the next layer). It
    writes ``.g`` onto ``w``, ``b`` and the input.
    """
    def __init__(self, w, b):
        self.w = w
        self.b = b

    def __call__(self, x):
        self.input = x
        self.output = x @ self.w + self.b
        return self.output

    def backward(self):
        # Will be used for chain rule: dL/d(output)
        dloss_doutput = self.output.g

        # Gradient of loss wrt w: x^T @ dL/d(output)
        X = self.input
        self.w.g = X.T @ dloss_doutput

        # Gradient of loss wrt b: b is broadcast over rows in the forward pass,
        # so its gradient is the upstream gradient summed over rows.
        self.b.g = dloss_doutput.sum(dim=0)

        # Gradient of loss wrt the input: dL/d(output) @ w^T. Needed so a layer
        # placed before this one can continue the backward pass.
        self.input.g = dloss_doutput @ self.w.T


class Relu():
    """ReLU that caches its input/output so it can backpropagate by hand."""

    def __call__(self, x):
        self.input = x
        self.output = x.clamp_min(0)
        return self.output

    def backward(self):
        # Chain rule, element-wise: the upstream gradient passes through where
        # the input was positive and is zeroed elsewhere (including exactly at 0).
        upstream = self.output.g
        gate = (self.input > 0).float()
        self.input.g = upstream * gate


class MSE():
    """Mean squared error (averaged over dim 0) with a hand-written gradient wrt ``pred``."""

    def __call__(self, target, pred):
        self.target = target
        self.input = pred
        self.out = (target - pred).pow(2).mean(dim=0)
        return self.out

    def backward(self):
        # The loss is the root of the graph, where backpropagation starts, so there is
        # no upstream gradient to chain in: d mean((target - pred)^2) / d pred.
        n_rows = self.target.shape[0]
        residual = self.input - self.target
        self.input.g = (2 / n_rows) * residual



class Model():
    """Linear -> ReLU network with an MSE loss and a hand-written backward pass."""
    def __init__(self, w, b):
        self.layers = [
            Linear(w, b),
            Relu()
        ]
        self.loss = MSE()

    def __call__(self, x, y):
        for layer in self.layers:
            x = layer(x)
        return self.loss(y, x)

    def backward(self):
        """Backpropagate from the loss through the layers in reverse order.

        Each step's ``backward`` reads ``.g`` from its output (written by the
        step after it) and writes ``.g`` onto its input. That is why the loss
        must run first and the layers run last-to-first.
        """
        self.loss.backward()
        for layer in reversed(self.layers):
            layer.backward()

In [18]:
# Forward pass through the gradient-aware classes (the loss should match earlier cells)
model = Model(w=w, b=b)
loss = model(x=X, y=y)
loss

tensor([14.4401])

In [19]:
# Backward pass by hand: start at the loss, then walk the layers from last to first.
for step in (model.loss, model.layers[-1], model.layers[-2]):
    step.backward()

In [20]:
# Compare each hand-computed gradient with the autograd baseline over the FULL
# tensors. Slicing [0:10] only checked the first 10 of the 48 rows of the w gradient.
# Gradient of loss wrt pred
test_close(model.loss.input.g, model_prime.loss.pred.grad)
# Gradient of loss wrt input to relu
test_close(model.layers[-1].input.g, model_prime.layers[-1].a.grad)
# Gradient of loss wrt w
test_close(model.layers[-2].w.g, model_prime.layers[-2].w.grad)
# Gradient of loss wrt b
test_close(model.layers[-2].b.g, model_prime.layers[-2].b.grad)

### Refactor: introduce a `Module` base class (with `forward`/`bwd` methods) and add a `backward` method to `Model`

In [53]:
class Module():
    """Base class for layers with a hand-written backward pass.

    ``__call__`` records the positional inputs (as a tuple in ``self.input``)
    and the result of ``forward`` (in ``self.output``). ``backward`` later
    passes both to ``bwd``. Subclasses implement ``forward(*args)`` and
    ``bwd(output, *args)``.
    """
    def __call__(self, *args):
        self.input = args
        self.output = self.forward(*args)
        return self.output

    def forward(self, *args):
        # Accept the same arguments __call__ passes, so a subclass that forgets to
        # override this gets NotImplementedError rather than a confusing TypeError.
        raise NotImplementedError('Not Implemented')

    def backward(self):
        self.bwd(self.output, *self.input)

    def bwd(self, output, *args):
        # Same reasoning as forward: match the arguments backward() passes.
        raise NotImplementedError('Not Implemented')

class Linear(Module):
    """Affine layer ``x @ w + b`` built on ``Module``, with a hand-written backward pass."""
    def __init__(self, w, b):
        self.w = w
        self.b = b

    def forward(self, x):
        return x @ self.w + self.b

    def bwd(self, output, input_):
        # Will be used for chain rule: dL/d(output), written by the next layer
        dloss_doutput = output.g

        # Gradient of loss wrt w: x^T @ dL/d(output)
        self.w.g = input_.T @ dloss_doutput

        # Gradient of loss wrt b: b is broadcast over rows in forward, so its
        # gradient is the upstream gradient summed over rows.
        self.b.g = dloss_doutput.sum(dim=0)

        # Gradient of loss wrt the input: dL/d(output) @ w^T. Needed so a layer
        # placed before this one can continue the backward pass.
        input_.g = dloss_doutput @ self.w.T


class Relu(Module):
    """ReLU layer built on ``Module``: clamps negative entries to zero."""
    def forward(self, x):
        return x.clamp_min(0)

    def bwd(self, output, input_):
        # Chain rule, element-wise: the upstream gradient passes through where the
        # input was positive and is zeroed elsewhere (including exactly at 0).
        upstream = output.g
        input_.g = upstream * (input_ > 0).float()


class MSE(Module):
    """Mean squared error (averaged over dim 0) built on ``Module``."""
    def forward(self, target, pred):
        residual = target - pred
        return residual.pow(2).mean(dim=0)

    def bwd(self, output, target, input_):
        # The loss is the root of the graph, where backpropagation starts, so no
        # upstream gradient is chained in; `input_` is the prediction.
        scale = 2 / target.shape[0]
        input_.g = scale * (input_ - target)


class Model():
    """Linear -> ReLU network with an MSE loss; ``backward`` runs the full hand-written pass."""

    def __init__(self, w, b):
        self.loss = MSE()
        self.layers = [Linear(w, b), Relu()]

    def __call__(self, x, y):
        hidden = x
        for layer in self.layers:
            hidden = layer(hidden)
        return self.loss(y, hidden)

    def backward(self):
        # Start at the loss, then visit the layers from last to first.
        for step in [self.loss, *self.layers[::-1]]:
            step.backward()

In [54]:
# Forward pass through the Module-based classes (the loss should match earlier cells)
model = Model(w=w, b=b)
loss = model(x=X, y=y)
loss

tensor([14.4401])

In [55]:
# One call now backpropagates through the loss and every layer
model.backward()

In [62]:
# `w` and `b` are the same tensors the earlier hand-written backward cell already
# wrote `.g` onto, so checking only them could pass even if model.backward() did
# nothing. Also check tensors created by THIS model's forward pass: they only
# carry `.g` if the Module-based backward ran. Compare full tensors, not [0:10].
# Gradient of loss wrt pred (MSE inputs are (target, pred))
test_close(model.loss.input[1].g, model_prime.loss.pred.grad)
# Gradient of loss wrt input to relu
test_close(model.layers[-1].input[0].g, model_prime.layers[-1].a.grad)
# Gradient of loss wrt w
test_close(w.g, model_prime.layers[-2].w.grad)
# Gradient of loss wrt b
test_close(b.g, model_prime.layers[-2].b.grad)