In [1]:
import torch

In [20]:
# Seed the global RNG so the random tensor (and every value derived from it
# further down) is the same on each Restart & Run All.
torch.manual_seed(0)

# requires_grad=True must be specified up front if we want to compute
# gradients with respect to x later on.
x = torch.randn(3, requires_grad=True)
x

tensor([ 1.0551,  0.7340, -0.9781], requires_grad=True)

In [5]:
# Forward pass only: no gradient is computed here. Because x requires grad,
# the addition is recorded in the computational graph, which is why the
# result carries a grad_fn.
y = torch.add(x, 2)
y

tensor([1.2633, 1.7773, 2.9491], grad_fn=<AddBackward0>)

In [15]:
# Element-wise y * y * 2, so z still holds one value per entry of y.
z = y.mul(y).mul(2)
z

tensor([ 3.1917,  6.3173, 17.3946], grad_fn=<MulBackward0>)

In [13]:
# Reduce z to a single scalar value (the mean of its entries).
z = torch.mean(z)

In [14]:
# Calculate the gradient of z with respect to x; the result is read from x.grad below.
z.backward() # dz/dx | Since z is a scalar value we don't have to pass an argument to backward()
x.grad

tensor([3.3687, 4.7394, 7.8643])

In [16]:
# Without a scalar value: when the output is not a scalar, backward() must be
# given a `gradient` tensor with the same shape as the output.
v = torch.tensor([0.1, 1.0, 0.001], dtype=torch.float32)

# Rebuild a non-scalar output from x inside this cell. The notebook's `z` has
# been reduced to a scalar by `z = z.mean()` and its graph was already used by
# the earlier z.backward(), so calling z.backward(v) on it fails when the
# notebook is run top to bottom.
y_vec = x + 2
z_vec = y_vec * y_vec * 2
z_vec.backward(v)

# backward() accumulates: this is added on top of whatever x.grad already held.
x.grad

tensor([ 3.8741, 11.8484,  7.8761])

### Prevent Gradient Tracking

In [17]:
# Three ways to prevent gradient tracking:
#   - x.requires_grad_(False)  -> disables gradient tracking on x in place
#   - x.detach()               -> creates a new tensor that does not require the gradient
#   - with torch.no_grad():    -> wrap the computation in a no-grad block

In [19]:
# Switch gradient tracking off on x in place, i.e. x.requires_grad_(False).
# The call returns the tensor it was invoked on, so it also serves as the
# cell's displayed value.
x.requires_grad_(False)

tensor([-0.7367, -0.2227,  0.9491])

In [21]:
# x.detach() -> creates a new tensor that does not require the gradient.
# The displayed value below has no requires_grad / grad_fn in its repr.
yy = x.detach()
yy

tensor([ 1.0551,  0.7340, -0.9781])

In [25]:
# Operations executed inside a `with torch.no_grad():` block are not tracked,
# so the result below carries no grad_fn.
with torch.no_grad():
    yy = torch.add(x, 2)
yy

tensor([3.0551, 2.7340, 1.0219])

### Dummy Example

In [29]:
# Four weights, all initialised to 1, with gradient tracking switched on.
weights = torch.ones(4).requires_grad_()

In [30]:
# One "epoch": a forward pass reduced to a scalar, followed by a backward pass.
for epoch in range(1):
    model_output = torch.sum(weights * 3)
    model_output.backward()

# Each backward pass contributes 3 per weight (d/dw of sum(3 * w)).
weights.grad

tensor([3., 3., 3., 3.])

In [31]:
# backward() adds into weights.grad instead of overwriting it, so the values
# keep growing with every pass (on top of what earlier cells left there).
for epoch in range(2):
    model_output = torch.sum(weights * 3)
    model_output.backward()

weights.grad

tensor([9., 9., 9., 9.])

In [33]:
# Empty the gradients before every iteration's backward pass.
# Zeroing at the START of the iteration (rather than after printing) also
# clears whatever earlier cells accumulated in weights.grad, so the first
# iteration prints the same fresh gradient as the later ones.
for epoch in range(3):
    # weights.grad is None until the first backward() call, hence the guard.
    if weights.grad is not None:
        weights.grad.zero_()

    model_output = (weights * 3).sum()
    model_output.backward()
    print(weights.grad)

tensor([3., 3., 3., 3.])
tensor([3., 3., 3., 3.])
tensor([3., 3., 3., 3.])
