# PyTorch：自定义nn模块

训练一个三阶多项式，通过最小化预测值与真实值之间欧几里得距离的平方，使其在 $[-\pi, \pi]$ 区间内拟合 $y = \sin(x)$。

此实现将模型定义为自定义的 `Module` 子类。每当需要比现有模块的简单序列更复杂的模型时，都需要以这种方式定义模型。

In [1]:
import torch
import math

class Polynomial3(torch.nn.Module):
    """Cubic polynomial ``y = a + b*x + c*x^2 + d*x^3`` with learnable scalar coefficients."""

    def __init__(self):
        """
        Register four scalar (0-dim) coefficients ``a``, ``b``, ``c``, ``d`` as
        trainable parameters, each initialised from a standard-normal draw.
        """
        super().__init__()
        # Assigning a Parameter as a Module attribute registers it; the
        # registration order (a, b, c, d) fixes the order of model.parameters().
        for coef_name in ('a', 'b', 'c', 'd'):
            setattr(self, coef_name, torch.nn.Parameter(torch.randn(())))

    def forward(self, x):
        """
        Evaluate the polynomial element-wise on the input tensor ``x``.

        The coefficients are 0-dim tensors, so they broadcast against ``x`` and
        the returned tensor has the same shape as ``x``. Any tensor operators
        may be used here, together with the parameters registered in the
        constructor.
        """
        linear = self.b * x
        quadratic = self.c * x ** 2
        cubic = self.d * x ** 3
        # Summed left to right: ((a + linear) + quadratic) + cubic.
        return self.a + linear + quadratic + cubic

    def string(self):
        """
        Return a human-readable rendering of the current polynomial.

        Like any Python class, a Module may define ordinary helper methods such
        as this one in addition to ``forward``.
        """
        a, b, c, d = (coef.item() for coef in (self.a, self.b, self.c, self.d))
        return f'y = {a} + {b} x + {c} x^2 + {d} x^3'

# Create Tensors to hold input and outputs: 2000 evenly spaced sample points
# on [-pi, pi] and the target values sin(x) at those points.
x = torch.linspace(-math.pi, math.pi, 2000)
y = torch.sin(x)

# Construct our model by instantiating the class defined above
model = Polynomial3()

# Construct our loss function and an Optimizer. The call to model.parameters()
# in the SGD constructor yields the four learnable scalar coefficients
# (a, b, c, d) registered on the Polynomial3 instance.
# reduction='sum' adds up the squared errors over all 2000 points instead of
# averaging them; the very small learning rate presumably compensates for the
# correspondingly large gradients.
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-6)
for t in range(2000):
    # Forward pass: Compute predicted y by passing x to the model
    y_pred = model(x)

    # Compute the loss; report it every 100 steps
    loss = criterion(y_pred, y)
    if t % 100 == 99:
        print(t, loss.item())

    # Zero gradients (they accumulate across backward() calls otherwise),
    # perform a backward pass, and update the weights.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

print(f'Result: {model.string()}')

99 1025.8203125
199 724.6102294921875
299 512.7764892578125
399 363.7436828613281
499 258.85662841796875
599 185.01414489746094
699 133.01153564453125
799 96.37835693359375
899 70.5650405883789
999 52.3709716796875
1099 39.544090270996094
1199 30.49899673461914
1299 24.119285583496094
1399 19.61856460571289
1499 16.442825317382812
1599 14.201584815979004
1699 12.619612693786621
1799 11.50278091430664
1899 10.714215278625488
1999 10.157346725463867
Result: y = 0.03827813267707825 + 0.8512060642242432 x + -0.0066036139614880085 x^2 + -0.0925431177020073 x^3
