In [1]:
import numpy as np


class Node:
    """One vertex of the computation graph recorded during the forward pass.

    Args:
        value: Result held by this node.
        func: Name of the primitive that produced ``value``; ``'variable'``
            marks a leaf.
        parents: List of input nodes. ``None`` (the default) means no inputs
            and gives the node its own empty list.
        *args: Extra positional arguments, stored as the tuple ``args``.
    """

    def __init__(self, value, func='variable', parents=None, *args):
        self.value = value
        self.func = func
        # A literal ``[]`` default would be one list object shared by every
        # leaf node, so mutating one leaf's parents would leak into all others.
        self.parents = [] if parents is None else parents
        self.args = args

    def __repr__(self):
        return f'({self.value}, {self.func}, {self.args}, {self.parents})'

    # str() and repr() render identically.
    __str__ = __repr__

In [2]:
def add(a: Node, b: Node):
    """Sum two nodes with ``np.add`` and record the result as an ``'add'`` node.

    The returned node has ``parents=[a, b]`` and keeps the two raw operand
    values in ``args``.
    """
    lhs, rhs = a.value, b.value  # unbox both operands
    total = np.add(lhs, rhs)     # primitive
    # box: operand values travel through Node's variadic ``*args``
    return Node(total, 'add', [a, b], lhs, rhs)

In [3]:
def negative(a: Node):
    """Negate a node with ``np.negative`` and record a ``'negative'`` node.

    The returned node has ``parents=[a]`` and ``args=(a.value,)``.
    """
    operand = a.value                # unbox
    flipped = np.negative(operand)   # primitive
    # box: the operand value travels through Node's variadic ``*args``
    return Node(flipped, 'negative', [a], operand)

In [4]:
def exp(a: Node):
    """Exponentiate a node with ``np.exp`` and record an ``'exp'`` node.

    The returned node has ``parents=[a]`` and ``args=(a.value,)``.
    """
    operand = a.value          # unbox
    raised = np.exp(operand)   # primitive
    # box: the operand value travels through Node's variadic ``*args``
    return Node(raised, 'exp', [a], operand)

In [5]:
def reciprocal(a: "Node"):
    """Compute ``1 / a`` and record the result as a ``'reciprocal'`` node.

    Args:
        a: Node whose ``value`` is inverted.

    Returns:
        A new node with ``func='reciprocal'``, ``parents=[a]`` and
        ``args=(a.value,)``.
    """
    # unbox
    va = a.value

    # primitive
    # np.reciprocal keeps the input dtype, so integer input is truncated
    # (np.reciprocal(2) == 0). Only floating/complex values take that path;
    # everything else goes through true division.
    if np.issubdtype(np.asarray(va).dtype, np.inexact):
        vb = np.reciprocal(va)
    else:
        vb = np.true_divide(1, va)

    # box
    b = Node(vb, 'reciprocal', [a])
    b.args = (va,)
    return b

In [6]:
def logistic(i):
    """Build ``reciprocal(1 + exp(-i))`` out of the graph primitives."""
    one = Node(1)
    decay = exp(negative(i))
    denominator = add(one, decay)
    return reciprocal(denominator)

In [7]:
# Create the input leaf, then trace the logistic function starting from it.
z = Node(value=1.5)
logsit = logistic(i=z)

In [8]:
def add_vjp(g, ans, a, b):
    """VJP rule for ``add``: hand the incoming gradient ``g`` to both operands.

    ``ans``, ``a`` and ``b`` are accepted for a uniform rule signature but unused.
    """
    grad_a = grad_b = g
    return (grad_a, grad_b)

In [9]:
def exp_vjp(g, ans, a):
    """VJP rule for ``exp``: scale the incoming gradient ``g`` by ``ans``.

    Returns a 1-tuple; ``a`` is unused.
    """
    grad_a = ans * g
    return (grad_a,)

In [10]:
def negative_vjp(g, ans, a):
    """VJP rule for ``negative``: multiply the incoming gradient ``g`` by -1.

    Returns a 1-tuple; ``ans`` and ``a`` are unused.
    """
    grad_a = -1 * g
    return (grad_a,)

In [11]:
def reciprocal_vjp(g, ans, a):
    """VJP rule for ``reciprocal``: scale ``g`` by ``-1 / (a * a)``.

    Returns a 1-tuple; ``ans`` is unused.
    """
    squared = a * a
    local_slope = np.divide(-1, squared)
    return (local_slope * g,)

In [12]:
def variable_vjp(g, ans):
    """VJP rule for a leaf variable: return the incoming gradient as a 1-tuple.

    ``ans`` is unused.
    """
    passthrough = (g,)
    return passthrough

In [13]:
# Lookup table from a node's ``func`` name to the VJP rule for that primitive.
vjps = dict(
    add=add_vjp,
    negative=negative_vjp,
    exp=exp_vjp,
    reciprocal=reciprocal_vjp,
    variable=variable_vjp,
)

In [14]:
def backward_pass(g, end_node):
    """Propagate a seed gradient from ``end_node`` back through the graph.

    Each reachable node is visited exactly once, in an order where every
    consumer of a node comes before the node itself. The node's VJP rule is
    therefore applied to the *sum* of all gradients flowing into it, so nodes
    used more than once (fan-out, e.g. ``add(x, x)``) receive correctly
    accumulated gradients instead of concatenated tuples.

    Args:
        g: Sequence whose first element is the gradient fed into ``end_node``.
        end_node: Node at which the backward pass starts.

    Returns:
        Dict mapping each visited node to the tuple returned by its VJP rule.
        It also holds an internal sentinel node mapped to ``g``.
    """
    # Sentinel root: treating ``g`` as its "VJP output" lets the seed gradient
    # be distributed with the same code path as any other node.
    root = Node(end_node.value, parents=[end_node])

    # Iterative post-order DFS along ``parents`` edges (no recursion limit).
    postorder = []
    seen = set()
    stack = [(root, False)]
    while stack:
        node, parents_done = stack.pop()
        if parents_done:
            postorder.append(node)
            continue
        if node in seen:
            continue
        seen.add(node)
        stack.append((node, True))
        stack.extend((parent, False) for parent in node.parents
                     if parent not in seen)

    gs = {root: g}
    incoming = {}  # node -> sum of gradients received from its consumers
    # Reversed post-order places every node before all of its parents.
    for node in reversed(postorder):
        if node is not root:
            if node not in incoming:
                # No gradient reached this node (e.g. empty ``g``).
                continue
            gs[node] = vjps[node.func](incoming[node], node.value, *node.args)
        for parent, parent_g in zip(node.parents, gs[node]):
            if parent in incoming:
                # Out-of-place add so a gradient object is never mutated.
                incoming[parent] = incoming[parent] + parent_g
            else:
                incoming[parent] = parent_g
    return gs


In [15]:
# Run the backward pass from the logistic output with a seed gradient of 1.
gs = backward_pass(g=(1.0,), end_node=logsit)

In [16]:
# Display the gradient tuple recorded for the input node z.
gs[z]

(0.14914645207033284,)

In [17]:
import torch

# Reference gradient from PyTorch for the same expression as logistic():
# reciprocal(add(1, exp(negative(tz))))
tz = torch.tensor(1.5, requires_grad=True)
y = torch.reciprocal(1. + torch.exp(-tz))
y.backward(gradient=torch.tensor(1.))

In [18]:
# Two-sided comparison: without abs() any large negative difference would pass.
assert np.all(np.abs(tz.grad.cpu().detach().numpy() - gs[z]) < 0.00001)

In [19]:
# Vector example: backpropagate a gradient of ones through exp.
b1 = torch.tensor([2., 3.], requires_grad=True)
d = b1.exp()
d.backward(torch.tensor([1., 1.]))