In [30]:
import numpy as np


class Node:
    def __init__(self, value, func='variable', parents=[], *args):
        self.value = value
        self.func = func
        self.parents = parents
        self.args = args

    def __str__(self):
        return f'({self.value}, {self.func}, {self.args}, {self.parents})'

    def __repr__(self):
        return f'({self.value}, {self.func}, {self.args}, {self.parents})'

In [2]:
def add(a: Node, b: Node):
    # unbox
    va = a.value
    vb = b.value

    # primitive
    vc = np.add(va, vb)

    # box
    c = Node(vc, 'add', [a, b])
    c.args = (va, vb)
    return c

In [3]:
def negative(a: Node):
    # unbox
    va = a.value

    # primitive
    vb = np.negative(va)

    # box
    b = Node(vb, 'negative', [a])
    b.args = (va,)
    return b

In [4]:
def exp(a: Node):
    # unbox
    va = a.value

    # primitive
    vb = np.exp(va)

    # box
    b = Node(vb, 'exp', [a])
    b.args = (va,)
    return b

In [5]:
def reciprocal(a: Node):
    # unbox
    va = a.value

    # primitive
    vb = np.reciprocal(va)

    # box
    b = Node(vb, 'reciprocal', [a])
    b.args = (va,)
    return b

In [29]:
def matmul(a: Node, b: Node):
    va = a.value
    vb = b.value

    vc = va @ vb

    c = Node(vc, 'matmul', [a, b])
    c.args = (va, vb)
    return c

In [6]:
def logistic(i):
    return reciprocal(add(Node(1), exp(negative(i))))

In [8]:
def add_vjp(g, ans, a, b):
    return g, g

In [9]:
def exp_vjp(g, ans, a):
    return (ans * g,)

In [10]:
def negative_vjp(g, ans, a):
    return (-1 * g,)

In [11]:
def reciprocal_vjp(g, ans, a):
    return (np.divide(-1, a * a) * g,)

In [12]:
def variable_vjp(g, ans):
    return (g,)

In [75]:
def matmul_vjp(g, ans, a, b):
    # a is a matrix
    # b is a vector
    b_d = np.matmul(np.transpose(a), g)
    r, c = a.shape
    # y=Wx
    # dy/dW = [x;x;x] element wise multi g 
    a_d = np.multiply(np.tile(b, (r, 1)), g.reshape(r, 1))

    return (a_d, b_d)

In [74]:
a = np.tile([1, 2, 3], (3, 1))
b = np.array([1, 2, 3])
np.multiply(a, b.reshape(3, 1))

array([[1, 2, 3],
       [2, 4, 6],
       [3, 6, 9]])

In [13]:
vjps = {
    "add": add_vjp,
    "negative": negative_vjp,
    "exp": exp_vjp,
    "reciprocal": reciprocal_vjp,
    "variable": variable_vjp
}

In [14]:
def backward_pass(g, end_node):
    tmp_node = Node(end_node.value, parents=[end_node])
    q = []
    gs = {tmp_node: g}
    q.append(tmp_node)
    while len(q) > 0:
        cur_node = q.pop(0)
        cur_gs = gs[cur_node]
        for node, cur_g in zip(cur_node.parents, cur_gs):
            q.append(node)
            vjp = vjps[node.func]
            grads = vjp(cur_g, node.value, *node.args)
            if node not in gs:
                gs[node] = grads
            else:
                gs[node] += grads
    return gs


In [20]:
z = Node(np.array([1.5, 1.5]))
logsit = logistic(z)

In [21]:
gs = backward_pass((1.,), logsit)

In [22]:
gs[z]

(array([0.14914645, 0.14914645]),)

In [26]:
import torch

tz = torch.tensor([1.5, 1.5], requires_grad=True)
# reciprocal(add(Node(1), exp(negative(i))))
y = torch.reciprocal(torch.add(1., torch.exp(torch.negative(tz))))
y.backward(torch.tensor([1., 1.]))

In [28]:
tz.grad

tensor([0.1491, 0.1491])

In [27]:
assert tz.grad.cpu().detach().numpy() - gs[z] < 0.00001

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

In [19]:
b1 = torch.tensor([2., 3.], requires_grad=True)
d = torch.exp(b1)
d.backward(gradient=torch.tensor([1., 1.]))

In [63]:
c1 = torch.tensor([1., 2, 3], requires_grad=True)
c2 = torch.tensor([
    [1., 2, 3],
    [4, 5, 6]
], requires_grad=True)
print(c1.size())
print(c2.size())

torch.Size([3])
torch.Size([2, 3])


In [64]:
y = torch.matmul(c2, c1)
print(y)

tensor([14., 32.], grad_fn=<MvBackward0>)


In [65]:
y.backward(gradient=torch.tensor([1., 2.]))

In [66]:
c1.grad

tensor([ 9., 12., 15.])

In [67]:
c2.grad

tensor([[1., 2., 3.],
        [2., 4., 6.]])

1. single -> batch
2. python -> cuda

In [82]:
ma = np.array([
    [1., 2, 3],
    [4, 5, 6]
])
mb = np.array([1., 2, 3])
mc = np.matmul(ma, mb)
print(mc)

[14. 32.]


In [83]:
matmul_vjp(np.array([1., 2.]), mc, ma, mb)

(array([[1., 2., 3.],
        [2., 4., 6.]]),
 array([ 9., 12., 15.]))