In [1]:
# Import Library
import torch

In [2]:
# Define the network architecture.
class TwoLayerNet(torch.nn.Module):
    """Fully connected network with a single hidden layer: Linear -> ReLU -> Linear."""

    def __init__(self, D_in, H, D_out):
        """
        Create the two affine layers and register them as sub-modules.

        Args:
            D_in: number of features in each input sample.
            H: number of hidden units.
            D_out: number of features in each output sample.
        """
        super().__init__()
        # Input -> hidden projection.
        self.linear1 = torch.nn.Linear(D_in, H)
        # Hidden -> output projection.
        self.linear2 = torch.nn.Linear(H, D_out)

    def forward(self, x):
        """
        Map a batch of inputs to a batch of predictions.

        Args:
            x: Tensor whose last dimension has size D_in.

        Returns:
            Tensor whose last dimension has size D_out.
        """
        hidden = self.linear1(x)
        # Clamping at zero is the ReLU non-linearity.
        activated = torch.clamp(hidden, min=0)
        return self.linear2(activated)

In [3]:
# Problem dimensions shared by the data, the model and the training loop.
N = 64        # batch size
D_in = 1000   # input dimension
H = 100       # hidden dimension
D_out = 10    # output dimension

In [4]:
# Seed PyTorch's global RNG so that the random data drawn here (and the model
# weights initialised afterwards) are identical on every Restart & Run All.
torch.manual_seed(0)

# Create random Tensors to hold inputs (N x D_in) and target outputs (N x D_out).
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

In [5]:
# Build the model by instantiating the class defined above with the
# configured input, hidden and output sizes.
model = TwoLayerNet(D_in=D_in, H=H, D_out=D_out)

In [6]:
# Construct our loss function and an Optimizer. The call to model.parameters()
# in the SGD constructor will contain the learnable parameters of the two
# nn.Linear modules which are members of the model.
# reduction='sum' adds up the squared errors over all elements instead of
# averaging them; it replaces the deprecated size_average=False argument.
criterion = torch.nn.MSELoss(reduction='sum')
# Plain stochastic gradient descent over every learnable parameter of the model.
optimizer = torch.optim.SGD(
    params=model.parameters(),
    lr=1e-4,
)

In [10]:
# Run 500 gradient-descent steps on the whole batch. Re-running this cell
# continues training from the current weights, because the model and the
# optimizer are created in earlier cells.
for t in range(500):
    # Clear the gradients accumulated by the previous step.
    optimizer.zero_grad()

    # Forward pass: predict y for the batch x, then measure the loss against y.
    y_pred = model(x)
    loss = criterion(y_pred, y)

    # Report the loss on every 50th iteration.
    if not t % 50:
        print(t, loss.item())

    # Backward pass, then update the weights.
    loss.backward()
    optimizer.step()

0 2.3826958567951806e-05
50 6.617226972593926e-06
100 1.842025199039199e-06
150 5.141324663782143e-07
200 1.4459796204846498e-07
250 4.214072291119919e-08
300 1.3556642031176125e-08
350 5.2417701290607965e-09
400 2.5220388000235516e-09
450 1.5267088704362664e-09
