# PyTorch Basics - Computation Graph

Make Your First GAN With PyTorch, 2020

In [1]:
from rich import print

In [2]:
import torch

## Simple Computation Graph

```
  (x) --> (y) --> (z)
```

> y = x^2
>
> z = 2y + 3
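Before running the cells, the chain rule gives the gradients PyTorch should report for this graph:

$$
\frac{dy}{dx} = 2x, \qquad
\frac{dz}{dy} = 2, \qquad
\frac{dz}{dx} = \frac{dz}{dy}\,\frac{dy}{dx} = 2 \cdot 2x = 4x
$$

At x = 3.5 this gives dz/dx = 14 and dy/dx = 7.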

In [3]:
# set up simple graph relating x, y and z

x = torch.tensor(3.5, requires_grad=True)

y = x*x

z = 2*y + 3

In [4]:
# work out gradients

z.backward()

# what is the gradient dz/dx at x = 3.5?

print(x.grad)  # holds dz/dx because backward() was called on z
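As a side note, PyTorch can also return a gradient directly instead of storing it in `.grad`. The sketch below uses `torch.autograd.grad`, which is not used in the original notebook, to compute the same dz/dx for the same graph. It returns a tuple with one gradient per input and leaves `x.grad` untouched.

```python
import torch

# same graph as above: y = x^2, z = 2y + 3
x = torch.tensor(3.5, requires_grad=True)
y = x * x
z = 2 * y + 3

# returns a tuple with one entry per input tensor;
# unlike z.backward(), it does not write into x.grad
(dz_dx,) = torch.autograd.grad(z, x)

print(dz_dx.item())  # 14.0  (4x at x = 3.5)
print(x.grad)        # None
```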

In [5]:
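# rebuild the graph: the earlier z.backward() freed its saved values, and
# x.grad would otherwise still hold dz/dx (backward() accumulates into .grad)
x = torch.tensor(3.5, requires_grad=True)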
y = x*x
z = 2*y + 3

y.backward()
print(f"{x.grad = }")    # dy/dx
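Each `backward()` call adds into `.grad` instead of overwriting it, which is why the cell above builds a fresh `x`. The sketch below shows what happens if the graph is reused instead: `retain_graph=True` keeps the first pass from freeing the saved values, the second pass adds dy/dx on top of dz/dx, and `x.grad.zero_()` clears the stored gradient before a fresh computation.

```python
import torch

x = torch.tensor(3.5, requires_grad=True)
y = x * x
z = 2 * y + 3

# keep the graph alive so a second backward pass through y is possible
z.backward(retain_graph=True)
print(x.grad.item())   # 14.0  (dz/dx = 4x)

# backward() accumulates into x.grad
y.backward()
print(x.grad.item())   # 21.0  (14 + dy/dx = 14 + 7)

# clear the stored gradient, rebuild y (the last backward() freed its graph)
x.grad.zero_()
y = x * x
y.backward()
print(x.grad.item())   # 7.0
```

In a training loop the same reset is usually done through an optimizer's `zero_grad()` before each backward pass.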

## Computation Graph With Multiple Links To A Node

```

  (a) --> (x)
       \ /     \
       .       (z)
      / \     /
  (b) --> (y)

 
  x = 2a + 3b
 
  y = 5a^2 + 3b^3
 
  z = 2x + 3y

```

In [6]:
# set up graph relating a, b to x, y and z

a = torch.tensor(2.0, requires_grad=True)
b = torch.tensor(1.0, requires_grad=True)

x = 2*a + 3*b

y = 5*a*a + 3*b*b*b

z = 2*x + 3*y

In [7]:
# work out gradients

z.backward()

In [8]:
# what is the gradient dz/da at a = 2.0?

a.grad
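Since `b` is also a leaf tensor with `requires_grad=True`, the single `z.backward()` call fills `b.grad` as well. By the chain rule, dz/db = dz/dx * dx/db + dz/dy * dy/db = 2 * 3 + 3 * 9b^2 = 6 + 27b^2, which is 33 at b = 1.0. The sketch below rebuilds the same graph and reads both gradients.

```python
import torch

a = torch.tensor(2.0, requires_grad=True)
b = torch.tensor(1.0, requires_grad=True)

x = 2*a + 3*b
y = 5*a*a + 3*b*b*b
z = 2*x + 3*y

# one backward pass fills .grad for every leaf that requires grad
z.backward()

print(a.grad.item())  # 64.0  (4 + 30a at a = 2)
print(b.grad.item())  # 33.0  (6 + 27b^2 at b = 1)
```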

## Manually check PyTorch Result


```

dz/da = dz/dx * dx/da + dz/dy * dy/da

      = 2 * 2 + 3 * 10a

      = 4  + 30a

When a = 2.0, dz/da = 4 + 30 * 2.0 = 64  ... correct!

```
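The hand-derived formula dz/da = 4 + 30a can also be compared with autograd at several values of a, not just one. In the sketch below, `hand_dz_da` and `autograd_dz_da` are hypothetical helper names introduced only for this check; the graph inside `autograd_dz_da` is the same one built above. Because dz/da does not depend on b, b is fixed at 1.0.

```python
import torch

def hand_dz_da(a: float) -> float:
    # hypothetical helper: the chain-rule result dz/da = 4 + 30a
    return 4 + 30 * a

def autograd_dz_da(a_value: float, b_value: float = 1.0) -> float:
    # hypothetical helper: rebuild the graph and let autograd compute dz/da
    a = torch.tensor(a_value, requires_grad=True)
    b = torch.tensor(b_value, requires_grad=True)
    x = 2*a + 3*b
    y = 5*a*a + 3*b*b*b
    z = 2*x + 3*y
    z.backward()
    return a.grad.item()

for a_value in [0.0, 1.0, 2.0, 3.5]:
    print(a_value, autograd_dz_da(a_value), hand_dz_da(a_value))
```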

