In [2]:
import torch

In [3]:
# Create a tensor with requires_grad=True to track computations
x = torch.tensor([2.0, 3.0], requires_grad=True)
print("x:", x)
# Output: x: tensor([2., 3.], requires_grad=True)

x: tensor([2., 3.], requires_grad=True)


In [None]:
# Perform operations
y = x ** 2
z = y.sum()
print("z:", z)
# Output: z: tensor(13., grad_fn=<SumBackward0>)
# z  = x1^2 + x2^2 = 4 + 9 = 13 

z: tensor(13., grad_fn=<SumBackward0>)


In [None]:
# Compute gradients
z.backward()
print("Gradient of x:", x.grad)
# Output: Gradient of x: tensor([4., 6.])
# Explanation: 
# z = x1^2 + x2^2 = 2^2 + 3^2 = 13
# dz/dx = 2x = 2 * [2, 3] = [4, 6]

In [2]:
import graphviz

def draw_computational_graph():
    dot = graphviz.Digraph(format='png')
    dot.attr(rankdir='LR')  # Left to Right direction

    # Nodes for Forward Pass (Blue)
    dot.node('x', 'x\n[2.0, 3.0]', shape='rect', style='filled', fillcolor='#e1f5fe')
    dot.node('y', 'y = x²\n[4.0, 9.0]', shape='rect', style='filled', fillcolor='#e1f5fe')
    dot.node('z', 'z = Σy\n13.0', shape='rect', style='filled', fillcolor='#e1f5fe')

    # Nodes for Backward Pass (Red/Orange)
    # Gradients flow backwards
    dot.node('grad_z', '∂z/∂z\n1.0', shape='oval', style='filled', fillcolor='#ffccbc')
    dot.node('grad_y', '∂z/∂y\n[1.0, 1.0]', shape='oval', style='filled', fillcolor='#ffccbc')
    dot.node('grad_x', '∂z/∂x\n[4.0, 6.0]', shape='oval', style='filled', fillcolor='#ffccbc')

    # Edges for Forward Pass
    dot.edge('x', 'y', label='Square')
    dot.edge('y', 'z', label='Sum')

    # Edges for Backward Pass (Dashed)
    dot.edge('grad_z', 'grad_y', label='Distribute', style='dashed', color='red')
    dot.edge('grad_y', 'grad_x', label='2x', style='dashed', color='red')
    
    # Align forward and backward nodes for better layout
    with dot.subgraph() as s:
        s.attr(rank='same')
        s.node('z')
        s.node('grad_z')
    
    return dot

try:
    display(draw_computational_graph())
except ImportError:
    print("Graphviz library is not installed. Please install it to see the visualization: pip install graphviz")
    print("Note: You also need the Graphviz system binary installed.")
except Exception as e:
    print(f"An error occurred: {e}")

ExecutableNotFound: failed to execute WindowsPath('dot'), make sure the Graphviz executables are on your systems' PATH

<graphviz.graphs.Digraph at 0x1d99c6a2b90>