In [1]:
import numpy as np
import torch

#### 1- Computing Gradients:

In [2]:
x = torch.randn(3, requires_grad=True)
y = x + 2
print(x)          # Created by the user -> grad_fn is None
print(y)          # Created as a result of an operation -> has grad_fn
print(y.grad_fn)  # grad_fn references a Function that created the Tensor

tensor([-0.7222,  0.8070, -2.0910], requires_grad=True)
tensor([ 1.2778,  2.8070, -0.0910], grad_fn=<AddBackward0>)
<AddBackward0 object at 0x0000027503E3A220>


In [3]:
x = torch.randn(3, requires_grad=True)
y = x + 2
z = y * y * 3
print(z)
z = z.mean()      # z = (1/len(x)) * sum( 3 * (x+2)**2 )
print(z)
z.backward()
print(x.grad)     # dz/dx

tensor([10.1255, 19.6355, 44.4029], grad_fn=<MulBackward0>)
tensor(24.7213, grad_fn=<MeanBackward0>)
tensor([3.6743, 5.1167, 7.6944])


In [4]:
# If a Tensor is non-scalar, a gradient argument, which is a tensor of matching shape,
# is needed for backward() to calculate the vector-Jacobian product.
x = torch.randn(3, requires_grad=True)
y = x * 2
v = torch.tensor([0.1, 1.0, 0.0001], dtype=torch.float32)
y.backward(v)
print(x.grad)

tensor([2.0000e-01, 2.0000e+00, 2.0000e-04])


#### 2- Stopping Gradient Tracking

In [5]:
# For example, during the training loop, the weights must be updated, and this update shouldn't be part of the gradient computation.
# 1- x.requires_grad_(False)
# 2- x.detach()
# 3- wrap in 'with torch.no_grad():'

# Example with .requires_grad_(False)
a = torch.randn(2, 2)
print(a.requires_grad)  # False by default
b = ((a * 3) / (a - 1))
print(b.grad_fn)        # None, as requires_grad=False
a.requires_grad_(True)  # Change requires_grad to True in-place
print(a.requires_grad)
b = (a * a).sum()
print(b.grad_fn)        # Has grad_fn as requires_grad=True

# Example with .detach()
a = torch.randn(2, 2, requires_grad=True)
print(a.requires_grad)  # True
b = a.detach()          # Create a new tensor detached from the computation graph
print(b.requires_grad)  # False

# Example with no_grad
a = torch.randn(2, 2, requires_grad=True)
print(a.requires_grad)  # True
with torch.no_grad():
    print((a ** 2).requires_grad)  # False inside the no_grad context

False
None
True
<SumBackward0 object at 0x000002757BA4D730>
True
False
True
False


#### 3- Setting the gradients to zero before a new optimization step to prevent accumulation

In [6]:
weights = torch.ones(4, requires_grad=True)

print("if the gradients are not set to zero before a new optimization step:")
for epoch in range(3):
    # Dummy model output
    model_output = (weights * 3).sum()
    model_output.backward()  # Compute gradients
    print(weights.grad)      # Print gradients

    # Update weights without tracking gradients
    with torch.no_grad():
        weights -= 0.1 * weights.grad

    # Zero gradients before the next step
    #weights.grad.zero_()

print(weights)  # Final weights

print("\nif the gradients are set to zero before a new optimization step:")
for epoch in range(3):
    # Dummy model output
    model_output = (weights * 3).sum()
    model_output.backward()  # Compute gradients
    print(weights.grad)      # Print gradients

    # Update weights without tracking gradients
    with torch.no_grad():
        weights -= 0.1 * weights.grad

    # Zero gradients before the next step
    weights.grad.zero_()

print(weights)  # Final weights

if the gradients are not set to zero before a new optimization step:
tensor([3., 3., 3., 3.])
tensor([6., 6., 6., 6.])
tensor([9., 9., 9., 9.])
tensor([-0.8000, -0.8000, -0.8000, -0.8000], requires_grad=True)

if the gradients are set to zero before a new optimization step:
tensor([12., 12., 12., 12.])
tensor([3., 3., 3., 3.])
tensor([3., 3., 3., 3.])
tensor([-2.6000, -2.6000, -2.6000, -2.6000], requires_grad=True)


In [7]:
# Alternatively, use an optimizer which handles zeroing gradients automatically:
optimizer = torch.optim.SGD([weights], lr=0.1)

for epoch in range(3):
    optimizer.zero_grad()       # Zero the gradients
    model_output = (weights * 3).sum()
    model_output.backward()     # Compute gradients
    optimizer.step()            # Update weights