-
Notifications
You must be signed in to change notification settings - Fork 258
/
Copy pathpolynomial_custom_function.py
executable file
Β·104 lines (85 loc) Β· 4.44 KB
/
polynomial_custom_function.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
# -*- coding: utf-8 -*-
"""
PyTorch: μ autograd Function μ μνκΈ°
----------------------------------------
:math:`y=\sin(x)` μ μμΈ‘ν μ μλλ‘, :math:`-\pi` λΆν° :math:`\pi` κΉμ§
μ ν΄λ¦¬λ 거리(Euclidean distance)λ₯Ό μ΅μννλλ‘ 3μ°¨ λ€νμμ νμ΅ν©λλ€.
λ€νμμ :math:`y=a+bx+cx^2+dx^3` λΌκ³ μ°λ λμ :math:`y=a+b P_3(c+dx)` λ‘ λ€νμμ μ κ² μ΅λλ€.
μ¬κΈ°μ :math:`P_3(x)=\\frac{1}{2}\\left(5x^3-3x\\right)` μ 3μ°¨
`λ₯΄μ₯λλ₯΄ λ€νμ(Legendre polynomial)`_ μ
λλ€.
.. _λ₯΄μ₯λλ₯΄ λ€νμ(Legendre polynomial):
https://en.wikipedia.org/wiki/Legendre_polynomials
μ΄ κ΅¬νμ PyTorch ν
μ μ°μ°μ μ¬μ©νμ¬ μμ ν λ¨κ³λ₯Ό κ³μ°νκ³ , PyTorch autogradλ₯Ό μ¬μ©νμ¬
λ³νλ(gradient)λ₯Ό κ³μ°ν©λλ€.
μλ ꡬνμμλ :math:`P_3'(x)` μ μννκΈ° μν΄ μ¬μ©μ μ μ autograd Functionλ₯Ό ꡬνν©λλ€.
μνμ μΌλ‘λ :math:`P_3'(x)=\\frac{3}{2}\\left(5x^2-1\\right)` μ
λλ€.
"""
import torch
import math
class LegendrePolynomial3(torch.autograd.Function):
"""
torch.autograd.Functionμ μμλ°μ μ¬μ©μ μ μ autograd Functionμ ꡬννκ³ ,
ν
μ μ°μ°μ νλ μμ ν λ¨κ³μ μμ ν λ¨κ³λ₯Ό ꡬνν΄λ³΄κ² μ΅λλ€.
"""
@staticmethod
def forward(ctx, input):
"""
μμ ν λ¨κ³μμλ μ
λ ₯μ κ°λ ν
μλ₯Ό λ°μ μΆλ ₯μ κ°λ ν
μλ₯Ό λ°νν©λλ€.
ctxλ 컨ν
μ€νΈ κ°μ²΄(context object)λ‘ μμ ν μ°μ°μ μν μ 보 μ μ₯μ μ¬μ©ν©λλ€.
ctx.save_for_backward λ©μλλ₯Ό μ¬μ©νμ¬ μμ ν λ¨κ³μμ μ¬μ©ν μ΄λ€ κ°μ²΄λ
μ μ₯(cache)ν΄ λ μ μμ΅λλ€.
"""
ctx.save_for_backward(input)
return 0.5 * (5 * input ** 3 - 3 * input)
@staticmethod
def backward(ctx, grad_output):
"""
μμ ν λ¨κ³μμλ μΆλ ₯μ λν μμ€(loss)μ λ³νλ(gradient)λ₯Ό κ°λ ν
μλ₯Ό λ°κ³ ,
μ
λ ₯μ λν μμ€μ λ³νλλ₯Ό κ³μ°ν΄μΌ ν©λλ€.
"""
input, = ctx.saved_tensors
return grad_output * 1.5 * (5 * input ** 2 - 1)
dtype = torch.float
device = torch.device("cpu")
# device = torch.device("cuda:0") # GPUμμ μ€ννλ €λ©΄ μ΄ μ£Όμμ μ κ±°νμΈμ
# μ
λ ₯κ°κ³Ό μΆλ ₯κ°μ κ°λ ν
μλ€μ μμ±ν©λλ€.
# requires_grad=Falseκ° κΈ°λ³Έκ°μΌλ‘ μ€μ λμ΄ μμ ν λ¨κ³ μ€μ μ΄ ν
μλ€μ λν λ³νλλ₯Ό κ³μ°ν
# νμκ° μμμ λνλ
λλ€.
x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype)
y = torch.sin(x)
# κ°μ€μΉλ₯Ό κ°λ μμμ ν
μλ₯Ό μμ±ν©λλ€. 3μ°¨ λ€νμμ΄λ―λ‘ 4κ°μ κ°μ€μΉκ° νμν©λλ€:
# y = a + b * P3(c + d * x)
# μ΄ κ°μ€μΉλ€μ΄ μλ ΄(convergence)νκΈ° μν΄μλ μ λ΅μΌλ‘λΆν° λ무 λ©λ¦¬ λ¨μ΄μ§μ§ μμ κ°μΌλ‘
# μ΄κΈ°νκ° λμ΄μΌ ν©λλ€.
# requires_grad=Trueλ‘ μ€μ νμ¬ μμ ν λ¨κ³ μ€μ μ΄ ν
μλ€μ λν λ³νλλ₯Ό κ³μ°ν νμκ°
# μμμ λνλ
λλ€.
a = torch.full((), 0.0, device=device, dtype=dtype, requires_grad=True)
b = torch.full((), -1.0, device=device, dtype=dtype, requires_grad=True)
c = torch.full((), 0.0, device=device, dtype=dtype, requires_grad=True)
d = torch.full((), 0.3, device=device, dtype=dtype, requires_grad=True)
learning_rate = 5e-6
for t in range(2000):
# μ¬μ©μ μ μ Functionμ μ μ©νκΈ° μν΄ Function.apply λ©μλλ₯Ό μ¬μ©ν©λλ€.
# μ¬κΈ°μ 'P3'λΌκ³ μ΄λ¦μ λΆμμ΅λλ€.
P3 = LegendrePolynomial3.apply
# μμ ν λ¨κ³: μ°μ°μ νμ¬ μμΈ‘κ° yλ₯Ό κ³μ°ν©λλ€;
# μ¬μ©μ μ μ autograd μ°μ°μ μ¬μ©νμ¬ P3λ₯Ό κ³μ°ν©λλ€.
y_pred = a + b * P3(c + d * x)
# μμ€μ κ³μ°νκ³ μΆλ ₯ν©λλ€.
loss = (y_pred - y).pow(2).sum()
if t % 100 == 99:
print(t, loss.item())
# autogradλ₯Ό μ¬μ©νμ¬ μμ ν λ¨κ³λ₯Ό κ³μ°ν©λλ€.
loss.backward()
# κ²½μ¬νκ°λ²(gradient descent)μ μ¬μ©νμ¬ κ°μ€μΉλ₯Ό κ°±μ ν©λλ€.
with torch.no_grad():
a -= learning_rate * a.grad
b -= learning_rate * b.grad
c -= learning_rate * c.grad
d -= learning_rate * d.grad
# κ°μ€μΉ κ°±μ νμλ λ³νλλ₯Ό μ§μ 0μΌλ‘ λ§λλλ€.
a.grad = None
b.grad = None
c.grad = None
d.grad = None
print(f'Result: y = {a.item()} + {b.item()} * P3({c.item()} + {d.item()} x)')