# We will use a problem of fitting y=sin(x) with a third order polynomial as our running example. The network will have four parameters, and will be trained with gradient descent to fit random data by minimizing the Euclidean distance between the network output and the true output.

In [1]:
import numpy as np
import math

In [4]:
x = np.linspace(-math.pi, math.pi, 2000)
y = np.sin(x)

In [5]:
a = np.random.randn()
b = np.random.randn()
c = np.random.randn()
d = np.random.randn()
learning_rate = 1e-6

In [7]:
for t in range(2000):
    y_pred = a + b * x + c * x ** 2 + d * x ** 3
    loss = np.square(y_pred - y).sum()
    if t % 100 == 99:
        print(t, loss)
    
    grad_y_pred = 2.0 * (y_pred - y)
    grab_a = grad_y_pred.sum()
    grab_b = (grad_y_pred * x).sum()
    grad_c = (grad_y_pred * x ** 2).sum()
    grad_d = (grad_y_pred * x ** 3).sum()
    
    a -= learning_rate * grab_a
    b -= learning_rate * grab_b
    c -= learning_rate * grad_c
    d -= learning_rate * grad_d
    
print(f"Result: y = {a} + {b}x + {c} x ^2 + {d} x^3")

99 103.42988841979711
199 71.85742043281431
299 50.84060533646392
399 36.844401385652354
499 27.519470973879002
599 21.303867298071374
699 17.158775310904925
799 14.39304843459093
899 12.546672548959513
999 11.3133450071635
1099 10.489024162577255
1199 9.937726693798215
1299 9.568782966377839
1399 9.32170647725409
1499 9.156124616953086
1599 9.045074984784755
1699 8.970540449047235
1799 8.920473970054793
1899 8.886815175559754
1999 8.864167435998365
Result: y = -0.00402052148319583 + 0.8622889345184817x + 0.0006936068576175409 x ^2 + -0.09411955673540166 x^3


In [9]:
import torch
import math

dtype = torch.float
device = torch.device("cpu")

x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype)
y = torch.sin(x)

a = torch.randn((), device=device, dtype=dtype)
b = torch.randn((), device=device, dtype=dtype)
c = torch.randn((), device=device, dtype=dtype)
d = torch.randn((), device=device, dtype=dtype)

lr = 1e-6

for t in range(2000):
    y_pred = a + b * x + c * x ** 2 + d * x ** 3
    loss = (y_pred - y).pow(2).sum().item()
    if t % 100 == 9:
        print(t, loss)
    grad_y_pred = 2.0 * (y_pred - y)
    grad_a = grad_y_pred.sum()
    grad_b = (grad_y_pred * x).sum()
    grad_c = (grad_y_pred * x ** 2).sum()
    grad_d = (grad_y_pred * x ** 2).sum()
    
    a -= lr * grad_a
    b -= lr * grad_b
    c -= lr * grad_c
    d -= lr * grad_d
    
print(f"Result : {a.item()} + {b.item()} x + {c.item()} x^2 + {d.item()} x^3")

9 124171.875
109 42615.92578125
209 34048.6328125
309 35271.3828125
409 37461.37109375
509 39625.4453125
609 41602.40625
709 43363.140625
809 44910.6484375
909 46258.87890625
1009 47425.6796875
1109 48430.4140625
1209 49291.9375
1309 50028.2578125
1409 50655.7578125
1509 51189.5
1609 51642.57421875
1709 52026.6015625
1809 52351.7890625
1909 52626.7578125
Result : -0.07593779265880585 + 7.33402156829834 x + 0.013100529089570045 x^2 + -1.1880244016647339 x^3


In [15]:
import torch
import math

dtype = torch.float
device = torch.device("cpu")

x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype)
y = torch.sin(x)

a = torch.randn((), device=device, dtype=dtype, requires_grad=True)
b = torch.randn((), device=device, dtype=dtype, requires_grad=True)
c = torch.randn((), device=device, dtype=dtype, requires_grad=True)
d = torch.randn((), device=device, dtype=dtype, requires_grad=True)

lr = 1e-6

for t in range(2000):
    y_pred = a + b * x + c * x ** 2 + d * x ** 3
    loss = (y_pred - y).pow(2).sum()
    if t % 100 == 10:
        print(t, loss.item())
        
    loss.backward()
    
    with torch.no_grad():
        grad_y_pred = 2.0 * (y_pred - y)
        grad_a = grad_y_pred.sum()
        grad_b = (grad_y_pred * x).sum()
        grad_c = (grad_y_pred * x ** 2).sum()
        grad_d = (grad_y_pred * x ** 2).sum()
        
        a.grad = None
        b.grad = None
        c.grad = None
        d.grad = None
    
print(f"Result : {a.item()} + {b.item()} x + {c.item()} x^2 + {d.item()} x^3")

10 106583.0234375
110 106583.0234375
210 106583.0234375
310 106583.0234375
410 106583.0234375
510 106583.0234375
610 106583.0234375
710 106583.0234375
810 106583.0234375
910 106583.0234375
1010 106583.0234375
1110 106583.0234375
1210 106583.0234375
1310 106583.0234375
1410 106583.0234375
1510 106583.0234375
1610 106583.0234375
1710 106583.0234375
1810 106583.0234375
1910 106583.0234375
Result : 1.0102566480636597 + -0.2916073501110077 x + -1.2082592248916626 x^2 + 0.5446464419364929 x^3
