## **Import important Libraries**

In [1]:
import torch
import numpy as np

### **Gradient of a scalar output**

In [4]:
# Scalar output: because x3 ends in .mean(), it is a 0-dim tensor and
# backward() needs no argument.
x1 = torch.randn(3, requires_grad=True)
x2 = x1 + 2
x3 = (x2 * x2 * 2).mean()

for tensor in (x1, x2, x3):
    print(tensor)

# Populates x1.grad with dx3/dx1 = 4 * (x1 + 2) / 3 (mean over 3 elements).
x3.backward()
print(x1.grad)

tensor([1.1981, 0.7285, 0.6427], requires_grad=True)
tensor([2.0405, 0.7849, 2.1110], grad_fn=<AddBackward0>)
tensor(16.4373, grad_fn=<MeanBackward0>)
tensor([4.2641, 3.6380, 3.5236])


### **Gradient of a vector output (vector-Jacobian product)**

In [9]:
x4 = torch.randn(3, requires_grad=True)
print(x4)

x5 = x4 + 2
print(x5)

x6 = x5 * x5 * 2
print(x6)

# x6 is a vector, not a scalar, so backward() needs an explicit `gradient`
# tensor of the same shape as x6. Autograd then computes the
# vector-Jacobian product v^T . J (J = dx6/dx4), NOT the full Jacobian:
# here x4.grad = v * 4 * (x4 + 2).
v = torch.tensor([0.1, 1.0, 0.001], dtype=torch.float32)
x6.backward(gradient=v)
print(x4.grad)

tensor([-0.3423,  0.6730, -0.4464], requires_grad=True)
tensor([1.6577, 2.6730, 1.5536], grad_fn=<AddBackward0>)
tensor([ 5.4962, 14.2904,  4.8274], grad_fn=<MulBackward0>)
tensor([6.6310e-01, 1.0692e+01, 6.2144e-03])


### **Stop autograd from tracking gradient history**

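- `x.requires_grad_(False)` — turns off gradient tracking on `x` itself, in place
- `x.detach()` — returns a new tensor with the same data as `x` that is not part of the graph
- `with torch.no_grad():` — operations run inside the block are not recorded by autograd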

In [10]:
x7 = torch.randn(3, requires_grad=True)
print(x7)
# requires_grad_ flips the flag in place and returns the very same tensor,
# so rebinding the name is harmless.
x7 = x7.requires_grad_(False)
print(x7)

tensor([-0.6420, -0.1796,  0.0222], requires_grad=True)
tensor([-0.6420, -0.1796,  0.0222])


In [11]:
x8 = torch.randn(3, requires_grad=True)
# detach() leaves x8 untouched and hands back a tensor outside the graph.
x9 = x8.detach()
for tensor in (x8, x9):
    print(tensor)

tensor([-0.9813, -0.2444,  0.5999], requires_grad=True)
tensor([-0.9813, -0.2444,  0.5999])


In [12]:
x10 = torch.randn(3, requires_grad=True)
print(x10)
# Inside no_grad the addition is not recorded, so x11 has no grad_fn
# and does not require grad.
with torch.no_grad():
    x11 = torch.add(x10, 2)
print(x11)

tensor([ 2.7206,  0.1196, -0.6959], requires_grad=True)
tensor([4.7206, 2.1196, 1.3041])


### **Simple Training**
- `backward()` accumulates gradients into `.grad`, so they must be zeroed before the next iteration and optimization step. This is important.

In [17]:
weights = torch.ones(4, requires_grad=True)
for epoch in range(3):
    # d(sum(3 * w))/dw is 3 for every element.
    model_output = weights.mul(3).sum()
    model_output.backward()
    print(weights.grad)

    # backward() adds into .grad rather than overwriting it; without this
    # reset the printed gradient would grow to 6, 9, ... across iterations.
    weights.grad.zero_()

tensor([3., 3., 3., 3.])
tensor([3., 3., 3., 3.])
tensor([3., 3., 3., 3.])
