A new autograd operator can be defined by subclassing <i>torch.autograd.Function</i> and implementing its <i>forward</i> and <i>backward</i> static methods.<br>




In this notebook, the model is defined as <b>y = a + b * P_3(c + d*x)</b> instead of <i>y = a + b*x + c*x**2 + d*x**3</i>, where <br>P_3(x) = (1/2)*(5*x**3 - 3*x) is the Legendre polynomial of degree three.

In [1]:
import torch
import math

In [2]:
class LegendrePolynomial3(torch.autograd.Function):
    """
    We can implement our own custom autograd Functions by subclassing
    torch.autograd.Function and implementing the forward and backward passes
    which operate on Tensors.
    """

    @staticmethod
    def forward(ctx, input):
        """
        ctx is a context object that can be used to stash information for the
        backward computation. Tensors needed in the backward pass are cached
        with the ctx.save_for_backward method.
        """
        ctx.save_for_backward(input)
        return 0.5 * (5 * input ** 3 - 3 * input)

    @staticmethod
    def backward(ctx, grad_output):
        """
        In the backward pass we receive a Tensor containing the gradient of the loss
        with respect to the output, and we need to compute the gradient of the loss
        with respect to the input.
        """
        input, = ctx.saved_tensors
        # Chain rule: dP_3/dx = 1.5 * (5 * x**2 - 1), scaled by the incoming gradient.
        return grad_output * 1.5 * (5 * input ** 2 - 1)
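
The factor used in backward can be sanity-checked against a numerical derivative. The block below is a minimal sketch in plain Python, kept free of torch on purpose; the helper names `p3`, `p3_grad`, and `central_difference` are hypothetical and not part of the notebook. It compares the analytic derivative 1.5 * (5*x**2 - 1) with a central finite difference of P_3 at a few arbitrarily chosen points.

```python
def p3(x):
    """Degree-three Legendre polynomial, the same expression as in forward."""
    return 0.5 * (5 * x ** 3 - 3 * x)


def p3_grad(x):
    """Analytic derivative of p3, the factor that backward multiplies in."""
    return 1.5 * (5 * x ** 2 - 1)


def central_difference(f, x, h=1e-5):
    """Numerical derivative of f at x via a symmetric difference quotient."""
    return (f(x + h) - f(x - h)) / (2 * h)


# Sample points are arbitrary; any handful of real values would do.
sample_points = [-1.0, -0.3, 0.0, 0.5, 2.0]

# Largest disagreement between the analytic and the numerical derivative.
max_abs_error = max(
    abs(p3_grad(x) - central_difference(p3, x)) for x in sample_points
)
print(f"max |analytic - numeric| = {max_abs_error:.2e}")
```

A small discrepancy is expected from the finite step size and floating-point rounding. PyTorch also provides torch.autograd.gradcheck, which runs a comparable check directly on a Function when it is given double-precision inputs.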

In [3]:
dtype = torch.float
device = torch.device("cpu")

In [4]:
x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype)
y = torch.sin(x)

In [5]:
a = torch.full((), 0.0, device=device, dtype=dtype, requires_grad=True)
b = torch.full((), -1.0, device=device, dtype=dtype, requires_grad=True)
c = torch.full((), 0.0, device=device, dtype=dtype, requires_grad=True)
d = torch.full((), 0.3, device=device, dtype=dtype, requires_grad=True)

In [7]:
learning_rate = 5e-6
for t in range(2000):
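    # To apply our Function, we use the Function.apply method, aliased as P_3.
    P_3 = LegendrePolynomial3.apply

    # Forward pass: compute the predicted y using our custom autograd operation.
    y_pred = a + b * P_3(c + d * x)

    # Compute and print the loss (sum of squared errors).
    loss = (y_pred - y).pow(2).sum()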
    if t % 100 == 99:
        print(t, loss.item())

    # Use autograd to compute the backward pass.
    loss.backward()

    # Update weights using gradient descent
    with torch.no_grad():
        a -= learning_rate * a.grad
        b -= learning_rate * b.grad
        c -= learning_rate * c.grad
        d -= learning_rate * d.grad

        # Reset the gradients after updating weights (None clears them)
        a.grad = None
        b.grad = None
        c.grad = None
        d.grad = None

print(f'Result: y = {a.item()} + {b.item()} * P3({c.item()} + {d.item()} x)')

99 8.903072357177734
199 8.87551498413086
299 8.85680103302002
399 8.84408950805664
499 8.835456848144531
599 8.829587936401367
699 8.825606346130371
799 8.82289981842041
899 8.821061134338379
999 8.81981086730957
1099 8.818964004516602
1199 8.818387985229492
1299 8.817996978759766
1399 8.817730903625488
1499 8.817550659179688
1599 8.817427635192871
1699 8.81734561920166
1799 8.81728744506836
1899 8.817249298095703
1999 8.817222595214844
Result: y = -1.0873581202108795e-10 + -2.233529806137085 * P3(-1.0858720866924187e-10 + 0.2556561827659607 x)
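
To relate the fitted parameters back to the plain cubic form <i>y = a + b*x + c*x**2 + d*x**3</i>, the sketch below expands a + b * P_3(c + d*x) into monomial coefficients and re-evaluates the sum of squared errors on the same 2000-point grid. It uses plain Python in double precision rather than float32 tensors, so small numerical differences from the printed loss are to be expected; the names `k0` to `k3`, `model`, and `cubic` are hypothetical and chosen only for illustration.

```python
import math

# Fitted parameters copied from the printed result above.
a = -1.0873581202108795e-10
b = -2.233529806137085
c = -1.0858720866924187e-10
d = 0.2556561827659607


def p3(x):
    """Degree-three Legendre polynomial."""
    return 0.5 * (5 * x ** 3 - 3 * x)


def model(x):
    """The fitted model in its Legendre form."""
    return a + b * p3(c + d * x)


# Expanding (c + d*x)**3 and collecting powers of x gives
# a + b * P_3(c + d*x) = k0 + k1*x + k2*x**2 + k3*x**3 with:
k0 = a + b * p3(c)
k1 = b * 0.5 * (15 * c ** 2 - 3) * d
k2 = b * 7.5 * c * d ** 2
k3 = b * 2.5 * d ** 3


def cubic(x):
    """The same model written as an ordinary cubic polynomial."""
    return k0 + k1 * x + k2 * x ** 2 + k3 * x ** 3


# Same grid as torch.linspace(-math.pi, math.pi, 2000), built by hand.
n = 2000
xs = [-math.pi + 2 * math.pi * i / (n - 1) for i in range(n)]

# Sum of squared errors against sin(x), mirroring the loss in the loop.
loss = sum((model(x) - math.sin(x)) ** 2 for x in xs)

# Largest gap between the Legendre form and its cubic expansion on the grid.
max_gap = max(abs(model(x) - cubic(x)) for x in xs)

print(f"y = {k0:.3e} + {k1:.4f} x + {k2:.3e} x^2 + {k3:.4f} x^3")
print(f"sum of squared errors: {loss:.4f}")
print(f"max |Legendre form - cubic form|: {max_gap:.2e}")
```

Because the fitted a and c are numerically close to zero, the constant and quadratic coefficients nearly vanish, leaving an odd cubic. That fits the fact that sin(x) is an odd function on the symmetric interval used here.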
