In [5]:
import numpy as np
from dezero import Variable

def _dot_var(v, verbose = False):
    dot_var = '{} [label="{}", color=orange, style=filled]\n'
    
    name = '' if v.name is None else v.name
    if verbose and v.data is not None:
        if v.name is not None:
            name +=': '
        name += str(v.shape) + ' ' + str(v.dtype)
    return dot_var.format(id(v), name)

In [6]:
# Demo: a named 2x3 Variable rendered as a DOT node, first with its bare name,
# then in verbose mode with shape and dtype appended to the label.
x = Variable(np.random.randn(2, 3))
x.name = 'x'
print(_dot_var(x, verbose=False))  # label: name only
print(_dot_var(x, verbose=True))   # label: name + shape + dtype

2142175016560 [label="x", color=orange, style=filled]

2142175016560 [label="x: (2, 3) float64", color=orange, style=filled]



In [7]:
def _dot_func(f):
    dot_func = '{} [label="{}", color=lightblue, style=filled, shape=box]\n'
    txt = dot_func.format(id(f), f.__class__.__name__)
    
    dot_edge = '{} -> {}\n'
    for x in f.inputs:
        txt += dot_edge.format(id(x), id(f))
    for y in f.outputs:
        txt += dot_edge.format(id(f), id(y())) # y는 약한 참조(weakref, 17.4절 참고)
    return txt

In [8]:
# Demo: adding two Variables records an Add function as y.creator; render that
# function node together with its input and output edges.
x0, x1 = Variable(np.array(1.0)), Variable(np.array(1.0))
y = x0 + x1
txt = _dot_func(y.creator)
print(txt)

2142176008896 [label="Add", color=lightblue, style=filled, shape=box]
2142176007504 -> 2142176008896
2142176009472 -> 2142176008896
2142176008896 -> 2142176006208



In [None]:
def get_dot_graph(output, verbose=True):
    txt=''
    funcs = []
    seen_set = set()
    
    def add_func(f):
        if f not in seen_set:
            funcs.append(f)
            