## Graph Class

In [1]:
class Graph:
    """Registry of every node built while this graph is the default.

    ``Variable``, ``Placeholder`` and ``Operation`` constructors append
    themselves to the module-level ``_default_graph``; call
    :meth:`set_as_default` before creating nodes.
    """

    def __init__(self):
        # One list per node kind, each kept in creation order.
        self.variables = []
        self.placeholders = []
        self.operations = []

    def set_as_default(self):
        """Install this graph as the module-level ``_default_graph``."""
        global _default_graph
        _default_graph = self

## Variable Class

In [2]:
class Variable:
    """Graph node that holds a fixed value.

    When ``Session.run`` reaches this node it copies ``value`` into
    ``output``. The instance registers itself in
    ``_default_graph.variables`` on creation, so a default graph must have
    been set first.
    """

    def __init__(self, value=None):
        # Consumers are filled in later by each Operation that uses this node.
        self.value, self.output_nodes = value, []
        _default_graph.variables.append(self)

## Placeholder Class

In [3]:
class Placeholder:
    """Graph node whose value is supplied at run time.

    ``Session.run`` looks the node up in its ``feed_dict`` argument to obtain
    its output. The instance registers itself in
    ``_default_graph.placeholders`` on creation.
    """

    def __init__(self):
        # Populated by each Operation that takes this placeholder as input.
        self.output_nodes: list = []
        _default_graph.placeholders.append(self)

## Operation Class

In [4]:
class Operation():
    """Base class for computation nodes in the graph.

    On construction an operation records its inputs, adds itself to the
    ``output_nodes`` of each input, and appends itself to
    ``_default_graph.operations``. Subclasses override :meth:`compute`.
    """

    def __init__(self, input_nodes=None):
        """Wire this operation into the default graph.

        Args:
            input_nodes: Iterable of upstream nodes, each exposing an
                ``output_nodes`` list. Defaults to no inputs.
        """
        # A ``None`` default avoids one mutable ``[]`` being shared by every
        # Operation built without inputs; ``list()`` also lets callers pass
        # tuples or generators without the latter being consumed twice.
        self.input_nodes = [] if input_nodes is None else list(input_nodes)
        self.output_nodes = []

        # Register this operation as a consumer of each of its inputs.
        for node in self.input_nodes:
            node.output_nodes.append(self)

        _default_graph.operations.append(self)

    def compute(self, *inputs):
        """Return this node's output given its evaluated inputs.

        ``Session.run`` calls this as ``compute(*inputs)``. The base
        implementation ignores its inputs and returns ``None``; subclasses
        override it with their own positional signature.
        """
        return None

In [5]:
class add(Operation):
    """Operation node whose output is ``x + y``."""

    def __init__(self, x, y):
        super().__init__(input_nodes=[x, y])

    def compute(self, x_var, y_var):
        """Combine the two evaluated operands with ``+``.

        The operands are also stored on ``self.inputs`` before the sum is
        taken.
        """
        self.inputs = [x_var, y_var]
        total = x_var + y_var
        return total

In [6]:
class multiply(Operation):
    """Operation node whose output is ``x * y``."""

    def __init__(self, x, y):
        super().__init__(input_nodes=[x, y])

    def compute(self, x_var, y_var):
        """Combine the two evaluated operands with ``*``.

        The operands are also stored on ``self.inputs`` before the product
        is taken.
        """
        self.inputs = [x_var, y_var]
        product = x_var * y_var
        return product

In [7]:
class matmul(Operation):
    """Operation node whose output is ``x.dot(y)``.

    The left operand must provide a ``.dot`` method. NumPy arrays do, and
    ``Session.run`` converts list outputs of upstream nodes to arrays.
    """

    def __init__(self, x, y):
        super().__init__(input_nodes=[x, y])

    def compute(self, x_var, y_var):
        """Return ``x_var.dot(y_var)``, storing the operands on ``self.inputs``."""
        self.inputs = [x_var, y_var]
        result = x_var.dot(y_var)
        return result

## Session Class

In [8]:
def traverse_postorder(operation):
    """Return the nodes of *operation*'s subgraph in post-order.

    Every input of an Operation appears before the Operation itself, so the
    list can be evaluated front to back. Each node is listed exactly once,
    even if it feeds several operations (e.g. ``add(y, y)``). Without that,
    a shared subgraph would be visited, and recomputed by ``Session.run``,
    once per use. That cost grows exponentially with nested reuse.

    Args:
        operation: Root node to start from (usually an Operation).

    Returns:
        List of nodes, inputs before consumers, without duplicates.
    """
    nodes_postorder = []
    # Track by identity: a node must be emitted once regardless of how its
    # class defines equality or hashing.
    visited_ids = set()

    def recurse(node):
        if id(node) in visited_ids:
            return
        visited_ids.add(id(node))
        if isinstance(node, Operation):
            for input_node in node.input_nodes:
                recurse(input_node)
        nodes_postorder.append(node)

    recurse(operation)
    return nodes_postorder

In [9]:
class Session():
    """Evaluates nodes of a computation graph."""

    def run(self, operation, feed_dict=None):
        """Evaluate *operation* and return its output.

        Nodes are processed in post-order, so every input is computed before
        the operation that consumes it. Each processed node gets an
        ``output`` attribute, and list outputs are converted to NumPy arrays.

        Args:
            operation: Node to evaluate.
            feed_dict: Mapping from each reachable Placeholder to the value
                it should take. Defaults to an empty mapping.

        Returns:
            ``operation.output`` after evaluation.

        Raises:
            KeyError: if a reachable Placeholder has no entry in *feed_dict*.
        """
        # ``None`` instead of a shared mutable ``{}`` default.
        if feed_dict is None:
            feed_dict = {}

        for node in traverse_postorder(operation):
            if isinstance(node, Placeholder):
                try:
                    node.output = feed_dict[node]
                except KeyError:
                    raise KeyError(
                        f"no value fed for placeholder {node!r}"
                    ) from None
            elif isinstance(node, Variable):
                node.output = node.value
            else:  # Operation
                node.inputs = [input_node.output for input_node in node.input_nodes]
                node.output = node.compute(*node.inputs)

            if isinstance(node.output, list):
                # Imported here so the method does not depend on a global
                # ``np`` and only needs NumPy when a list must be converted.
                import numpy as np

                node.output = np.array(node.output)

        return operation.output

## Worked Example: z = A*x + B

In [10]:
# Create an empty graph and make it the default, so the nodes built in the
# following cells register themselves on it.
g = Graph()
g.set_as_default()

+ g.operations = []
+ g.variables = []
+ g.placeholders = []
+ _default_graph = g

In [11]:
# Two Variables holding A = 10 and B = 1; they are created left to right,
# so g.variables ends up as [A, B].
A, B = Variable(10), Variable(1)

+ A.value = 10
+ A.output_nodes = []
+ g.variables = [A]


+ B.value = 1
+ B.output_nodes = []
+ g.variables = [A, B]

In [12]:
# x has no stored value; it is supplied at run time via
# Session.run(..., feed_dict={x: value}).
x = Placeholder()

+ x.output_nodes = []
+ g.placeholders = [x]

In [13]:
# Build z = A*x + B. Each constructor links itself to its inputs'
# output_nodes and appends itself to g.operations.
y = multiply(A, x)
z = add(y, B)

+ y.input_nodes = [A, x]
+ y.output_nodes = []
+ A.output_nodes = [y]
+ x.output_nodes = [y]
+ g.operations = [y]


+ z.input_nodes = [y, B]
+ z.output_nodes = []
+ y.output_nodes = [z]
+ B.output_nodes = [z]
+ g.operations = [y, z]

In [14]:
# Evaluate z with x = 20: 10 * 20 + 1 = 201. As the cell's last expression,
# its value is what the notebook displays below.
sess = Session()
sess.run(operation=z, feed_dict={x:20})

201

+ nodes_postorder = [A, x, y, B, z]


+ A.output = 10
+ x.output = 20


+ y.inputs = [10, 20]
+ y.output = 200


+ B.output = 1


+ z.inputs = [200, 1]
+ z.output = 201