In [1]:
import torch
import math

In [2]:
# Global configuration shared by every cell below.
SEED = 0  # fixed seed so the randomly initialised weights are repeatable across runs
dtype = torch.float
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed PyTorch's random number generators before any torch.randn call so that
# "Restart Kernel -> Run All" starts from the same initial weights every time
# (exact losses may still differ between devices / PyTorch builds).
torch.manual_seed(SEED)

# Create new tensors on `device` by default, so later cells need no device= argument.
torch.set_default_device(device)

In [3]:
# Training data: inputs x and targets y = sin(x).
# requires_grad is left at its default (False) because x and y are fixed data,
# not parameters, so no gradients need to be computed for them.

x = torch.linspace(-math.pi, math.pi, steps=2000, dtype=dtype)  # 2000 evenly spaced points in [-pi, pi]
y = x.sin()

In [6]:
# Randomly initialised weights of the model.
# 3rd-order polynomial, so we need 4 weights: y = a + b*x + c*x**2 + d*x**3
# Each weight is a 0-dim tensor with requires_grad=True so that autograd
# tracks it and fills in its .grad during the backward pass.

# The generator is consumed in order, so a, b, c, d are drawn in that sequence.
a, b, c, d = (
    torch.randn((), dtype=dtype, requires_grad=True)
    for _ in range(4)
)

In [7]:
# Plain gradient descent on the sum-of-squared-errors loss, using autograd
# for the gradients and updating the four weights by hand.
learning_rate = 1e-6
for t in range(2000):
    # Forward pass: evaluate the polynomial at every sample point
    y_pred = a + b * x + c * x ** 2 + d * x ** 3

    # Loss: sum of squared errors over all samples (a 0-dim tensor)
    loss = torch.sum((y_pred - y) ** 2)

    # Report progress on every 100th iteration (t = 99, 199, ...)
    if (t + 1) % 100 == 0:
        print(t, loss.item())

    # Backward pass: autograd populates .grad on a, b, c and d
    loss.backward()

    # The update itself must not be recorded by autograd, hence torch.no_grad()
    with torch.no_grad():
        for weight in (a, b, c, d):
            # In-place gradient-descent step on the weight tensor
            weight -= learning_rate * weight.grad
            # Drop the gradient so the next backward() starts fresh instead of accumulating
            weight.grad = None

print(f'Result: y = {a.item()} + {b.item()} x + {c.item()} x^2 + {d.item()} x^3')

99 4630.70703125
199 3065.953857421875
299 2030.99560546875
399 1346.440185546875
499 893.6438598632812
599 594.1377563476562
699 396.0223388671875
799 264.97113037109375
899 178.2801055908203
999 120.93199920654297
1099 82.99369812011719
1199 57.89504623413086
1299 41.29021453857422
1399 30.3043155670166
1499 23.035734176635742
1599 18.22638702392578
1699 15.044116973876953
1799 12.938362121582031
1899 11.544840812683105
1999 10.622652053833008
Result: y = -0.005551859736442566 + 0.815751850605011 x + 0.000957788317464292 x^2 + -0.08750005811452866 x^3
