In [1]:
import torch
import matplotlib.pyplot as plt

In [2]:
# Create two leaf tensors that autograd will track
x = torch.tensor([2.0, 5.0], requires_grad=True)
y = torch.tensor([3.0, 7.0], requires_grad=True)

# Element-wise computation: z_i = x_i * y_i + y_i**2
z = x * y + y**2

# z is an intermediate (non-leaf) tensor, so its .grad is not kept by default.
# retain_grad() asks autograd to store the gradient w.r.t. z as well.
z.retain_grad()

# backward() without arguments needs a scalar, and z is a 1-D tensor with two
# elements, so reduce it with sum() first.
# Tensor.backward() returns None, so keep the scalar in z_sum and call
# backward() on it as a separate statement.
z_sum = z.sum()
z_sum.backward()

# Expected: d(z_sum)/dx = y, d(z_sum)/dy = x + 2*y, d(z_sum)/dz = 1
print(f"Gradient of x: {x.grad}")
print(f"Gradient of y: {y.grad}")
print(f"Gradient of z: {z.grad}")
print(f"Result of the operation: z = {z.detach()}")

Gradient of x: tensor([3., 7.])
Gradient of y: tensor([ 8., 19.])
Gradient of z: tensor([1., 1.])
Result of the operation: z = tensor([15., 84.])


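The `detach()` method returns a new tensor that shares storage with the original tensor but is excluded from gradient tracking: the returned tensor has `requires_grad=False` and is not part of the computation graph. This is useful when you want to perform operations on a tensor's values without those operations being recorded in the computation graph. Because the storage is shared, in-place modifications to the detached tensor are also visible in the original tensor.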

In [9]:
# Check the autograd flag on z while it is still attached to the graph
requires_grad_before = z.requires_grad
print("Before detaching z from computation: ", requires_grad_before)

# Take a detached version of z and check the same flag on it
z_det = z.detach()
print("After detaching z from computation: ", z_det.requires_grad)

Before detaching z from computation:  True
After detaching z from computation:  False
