In [99]:
import torch
import math

class my_opty(torch.autograd.Function):
    """Element-wise custom autograd op for ``0.5 * (5 * x**3 - 3 * x)``.

    The expression is the degree-3 Legendre polynomial P3(x). The gradient is
    provided analytically in ``backward`` instead of being traced by autograd.
    Invoke it through ``my_opty.apply(tensor)``.
    """

    @staticmethod
    def forward(ctx, input):
        """Evaluate the polynomial and stash ``input`` for the backward pass.

        Args:
            ctx: autograd context object used to carry tensors to ``backward``.
            input: tensor at which the polynomial is evaluated element-wise.

        Returns:
            Tensor holding ``0.5 * (5 * input**3 - 3 * input)``.
        """
        ctx.save_for_backward(input)
        cubic_term = 5 * input ** 3
        linear_term = 3 * input
        return 0.5 * (cubic_term - linear_term)

    @staticmethod
    def backward(ctx, grad_output):
        """Chain the incoming gradient with the polynomial's derivative.

        Args:
            ctx: autograd context holding the tensor saved by ``forward``.
            grad_output: gradient of the loss w.r.t. the output of ``forward``.

        Returns:
            Gradient of the loss w.r.t. the input, i.e.
            ``grad_output * 1.5 * (5 * input**2 - 1)``.
        """
        (saved_input,) = ctx.saved_tensors
        # d/dx [0.5 * (5x^3 - 3x)] = 1.5 * (5x^2 - 1)
        local_slope = 5 * saved_input ** 2 - 1
        return grad_output * 1.5 * local_slope

# ---------------------------------------------------------------------------
# Fit y = sin(x) with the model  a + b * P3(c + d * x)  using hand-written
# gradient descent, where P3 is evaluated by the custom autograd function
# `my_opty` (0.5 * (5x^3 - 3x)).
# ---------------------------------------------------------------------------
dtype = torch.float
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Training data: 10 evenly spaced samples of sin(x) on [-pi, pi].
x = torch.linspace(-math.pi, math.pi, 10, device=device, dtype=dtype)
y = torch.sin(x)

# Scalar (0-dim) trainable parameters and their initial values.
a = torch.full((), 0.0, device=device, dtype=dtype, requires_grad=True)
b = torch.full((), -1.0, device=device, dtype=dtype, requires_grad=True)
c = torch.full((), 0.0, device=device, dtype=dtype, requires_grad=True)
d = torch.full((), 0.3, device=device, dtype=dtype, requires_grad=True)

lr = 0.00005

# `apply` is the callable entry point of the custom Function. It does not
# change between iterations, so bind it once instead of on every step.
m_o = my_opty.apply

for t in range(100000):
    y_pred = a + b * m_o(c + d * x)
    # NOTE: despite the name and the printed label, this is the *sum* of
    # squared errors (there is no division by the number of samples).
    mse = (y_pred - y).pow(2).sum()
    if t%1000 == 0:
        # Progress report every 1000 steps.
        print(t)
        print("input : ", x)
        print("pred : ", y_pred)
        print("target : ",y)
        print("mse : ", mse.item())
        # Pass plain Python floats so formatting does not depend on the
        # tensor's own __format__ handling of 0-dim values.
        print("{:.5f} {:.5f} {:.5f} {:.5f}".format(a.item(), b.item(), c.item(), d.item()))
        print()

    mse.backward()
    # The update itself must not be recorded by autograd.
    with torch.no_grad():
        for param in (a, b, c, d):
            # In-place step keeps `param` the same leaf tensor.
            param -= lr * param.grad
            # Drop the gradient so the next backward() starts from scratch
            # instead of accumulating into the old value.
            param.grad = None

0
input :  tensor([-3.1416, -2.4435, -1.7453, -1.0472, -0.3491,  0.3491,  1.0472,  1.7453,
         2.4435,  3.1416], device='cuda:0')
pred :  tensor([ 0.6792, -0.1148, -0.4265, -0.3937, -0.1542,  0.1542,  0.3937,  0.4265,
         0.1148, -0.6792], device='cuda:0', grad_fn=<AddBackward0>)
target :  tensor([ 8.7423e-08, -6.4279e-01, -9.8481e-01, -8.6603e-01, -3.4202e-01,
         3.4202e-01,  8.6603e-01,  9.8481e-01,  6.4279e-01, -8.7423e-08],
       device='cuda:0')
mse :  2.620178461074829
0.00000 -1.00000 0.00000 0.30000

1000
input :  tensor([-3.1416, -2.4435, -1.7453, -1.0472, -0.3491,  0.3491,  1.0472,  1.7453,
         2.4435,  3.1416], device='cuda:0')
pred :  tensor([-0.0030, -0.3951, -0.4927, -0.3798, -0.1406,  0.1406,  0.3798,  0.4927,
         0.3951,  0.0030], device='cuda:0', grad_fn=<AddBackward0>)
target :  tensor([ 8.7423e-08, -6.4279e-01, -9.8481e-01, -8.6603e-01, -3.4202e-01,
         3.4202e-01,  8.6603e-01,  9.8481e-01,  6.4279e-01, -8.7423e-08],
       device='cud

KeyboardInterrupt: ignored