In [1]:
import torch
import math

In [3]:
class legendrePolynomial3(torch.autograd.Function):
    """Elementwise degree-3 Legendre polynomial with a hand-written gradient.

    Forward:  P3(x)  = 0.5 * (5 x^3 - 3 x)
    Backward: P3'(x) = 1.5 * (5 x^2 - 1)

    Use it through ``legendrePolynomial3.apply(tensor)``. The class name is kept
    lowercase because later cells refer to it by this name.
    """

    @staticmethod
    def forward(ctx, input):
        """Evaluate P3 on ``input`` and save ``input`` for the backward pass."""
        ctx.save_for_backward(input)
        five_x_cubed = 5 * input ** 3
        return 0.5 * (five_x_cubed - 3 * input)

    @staticmethod
    def backward(ctx, grad_output):
        """Chain rule: upstream gradient times P3'(x), evaluated at the saved input."""
        (saved_x,) = ctx.saved_tensors
        # Scale grad_output first to keep the original multiplication order
        # (and therefore bit-identical float results).
        scaled_upstream = grad_output * 1.5
        return scaled_upstream * (5 * saved_x ** 2 - 1)

In [4]:
# Tensor configuration shared by every later cell (data and parameters).
dtype = torch.float
# Prefer the first CUDA device, but fall back to the CPU so the notebook still
# runs end-to-end (Restart & Run All) on machines without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [5]:
# Training data: 2000 evenly spaced points on [-pi, pi] and their sine targets.
# Uses the `dtype`/`device` constants from the config cell (rather than a
# hard-coded torch.float) so data and parameters always agree.
x = torch.linspace(-math.pi, math.pi, 2000, dtype=dtype, device=device)
y = torch.sin(x)

In [6]:
# Scalar (0-dim) trainable coefficients of the model y ≈ a + b * P3(c + d * x).
# Each is created with requires_grad=True directly on the configured
# device/dtype, so each one is a leaf tensor that collects .grad.
a, b, c, d = (
    torch.full((), initial, device=device, dtype=dtype, requires_grad=True)
    for initial in (0.0, -1.0, 0.0, 0.3)
)

In [9]:
learning_rate = 1e-6
# NOTE(review): the logged loss is still decreasing at the final iteration,
# so a larger learning rate or more iterations may be worth trying.
n_iterations = 2000
log_every = 100

# Bind the custom autograd op once; it does not change between iterations.
p3 = legendrePolynomial3.apply
params = (a, b, c, d)

for i in range(n_iterations):
    # Forward pass: model y ≈ a + b * P3(c + d * x).
    y_pred = a + b * p3(c + d * x)

    # Sum-of-squared-errors loss against the sin(x) targets.
    loss = (y_pred - y).pow(2).sum()
    if i % log_every == log_every - 1:
        print(i, loss.item())

    # Backward pass fills .grad on each leaf parameter.
    loss.backward()

    # Manual gradient-descent step. no_grad keeps these in-place updates out of
    # the autograd graph.
    with torch.no_grad():
        for param in params:
            param -= learning_rate * param.grad
            # Clear the gradient so the next backward() does not accumulate into it.
            param.grad = None

print(f"Results y = {a.item()} + {b.item()} * p3({c.item()} + {d.item()} x)")

99 66.23503112792969
199 61.937660217285156
299 57.963260650634766
399 54.28733444213867
499 50.88731384277344
599 47.742374420166016
699 44.83330535888672
799 42.142276763916016
899 39.65290832519531
999 37.34986114501953
1099 35.219364166259766
1199 33.24830627441406
1299 31.424734115600586
1399 29.737539291381836
1499 28.176536560058594
1599 26.732240676879883
1699 25.39589500427246
1799 24.159400939941406
1899 23.015254974365234
1999 21.956523895263672
Results y = 3.358735858882689e-11 + -1.9735198020935059 * p3(-4.7892217969192075e-11 + 0.2537434995174408 x)
