In [1]:
from graphviz import Digraph
from graphviz import Source
import torch
from torch.autograd import Variable, Function


def iter_graph(root, callback):
    queue = [root]
    seen = set()
    while queue:
        fn = queue.pop()
        if fn in seen:
            continue
        seen.add(fn)
        for next_fn, _ in fn.next_functions:
            if next_fn is not None:
                queue.append(next_fn)
        callback(fn)

def register_hooks(var):
    fn_dict = {}
    def hook_cb(fn):
        def register_grad(grad_input, grad_output):
            fn_dict[fn] = grad_input
        fn.register_hook(register_grad)
    iter_graph(var.grad_fn, hook_cb)

    def is_bad_grad(grad_output):
        if grad_output is None:
                return True
        grad_output = grad_output.data
        return torch.isnan(grad_output).any() or grad_output.gt(1e6).any()

    def make_dot():
        node_attr = dict(style='filled',
                        shape='box',
                        align='left',
                        fontsize='12',
                        ranksep='0.1',
                        height='0.2')
        dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))

        def size_to_str(size):
            return '('+(', ').join(map(str, size))+')'

        def build_graph(fn):
            if hasattr(fn, 'variable'):  # if GradAccumulator
                u = fn.variable
                node_name = 'Variable\n ' + size_to_str(u.size())
                dot.node(str(id(u)), node_name, fillcolor='lightblue')
            else:
                assert fn in fn_dict, fn
                fillcolor = 'white'
                if any(is_bad_grad(gi) for gi in fn_dict[fn]):
                    fillcolor = 'red'
                dot.node(str(id(fn)), str(type(fn).__name__), fillcolor=fillcolor)
            for next_fn, _ in fn.next_functions:
                if next_fn is not None:
                    next_id = id(getattr(next_fn, 'variable', next_fn))
                    dot.edge(str(next_id), str(id(fn)))
        iter_graph(var.grad_fn, build_graph)

        return dot

    return make_dot

if __name__ == '__main__':
    x = Variable(torch.randn(10, 10), requires_grad=True)
    y = Variable(torch.randn(10, 10), requires_grad=True)

    z = x / (y * 0)
    z = z.sum() * 2
    get_dot = register_hooks(z)
    z.backward()
    dot = get_dot()
    path = 'aDebugGraph.dot'
    dot.save(path)
    print(dot)
    
    s = Source.from_file(path)
    s.render('aDebugGraph', format='png', cleanup=True)


digraph {
	graph [size="12,12"]
	node [align=left fontsize=12 height=0.2 ranksep=0.1 shape=box style=filled]
	140349087695104 [label=MulBackward0 fillcolor=red]
	140348414082544 -> 140349087695104
	140348414082544 [label=SumBackward0 fillcolor=white]
	140348408787776 -> 140348414082544
	140348408787776 [label=DivBackward0 fillcolor=red]
	140349086023824 -> 140348408787776
	140348408787920 -> 140348408787776
	140348408787920 [label=MulBackward0 fillcolor=red]
	140349086023744 -> 140348408787920
	140349086023744 [label="Variable
 (10, 10)" fillcolor=lightblue]
	140349086023824 [label="Variable
 (10, 10)" fillcolor=lightblue]
}



In [2]:
import torch
import torch.nn as nn

class SimpleNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)  # Input layer to hidden layer
        self.relu = nn.ReLU()  # Activation function
        self.fc2 = nn.Linear(hidden_size, output_size)  # Hidden layer to output layer
    
    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x


In [3]:
x = Variable(torch.randn(1, 13), requires_grad=True)
y = Variable(torch.randn(1, 2), requires_grad=True)


model = SimpleNN(13, 100, 2)
optimizer = torch.optim.Adam(model.parameters(), lr = 0.1, weight_decay=0.01)
pred = model(x)

loss = torch.nn.functional.mse_loss(y, pred)

get_dot = register_hooks(pred)
loss.backward()
dot = get_dot()
path = 'aDebugGraph.dot'
dot.save(path)
print(dot)

s = Source.from_file(path)
s.render('aDebugGraph', format='png', cleanup=True)


digraph {
	graph [size="12,12"]
	node [align=left fontsize=12 height=0.2 ranksep=0.1 shape=box style=filled]
	140349087690912 [label=AddmmBackward0 fillcolor=white]
	140349087626432 -> 140349087690912
	140349087689856 -> 140349087690912
	140349087691104 -> 140349087690912
	140349087691104 [label=TBackward0 fillcolor=white]
	140349087626032 -> 140349087691104
	140349087626032 [label="Variable
 (2, 100)" fillcolor=lightblue]
	140349087689856 [label=ReluBackward0 fillcolor=white]
	140349087686448 -> 140349087689856
	140349087686448 [label=AddmmBackward0 fillcolor=white]
	140349087615312 -> 140349087686448
	140349087626512 -> 140349087686448
	140349087696928 -> 140349087686448
	140349087696928 [label=TBackward0 fillcolor=white]
	140349087614672 -> 140349087696928
	140349087614672 [label="Variable
 (100, 13)" fillcolor=lightblue]
	140349087626512 [label="Variable
 (1, 13)" fillcolor=lightblue]
	140349087615312 [label="Variable
 (100)" fillcolor=lightblue]
	140349087626432 [label="Variable
 

'aDebugGraph.png'