
PyTorch: Custom nn Modules
--------------------------

A fully-connected ReLU network with one hidden layer, trained to predict y from x
by minimizing squared Euclidean distance.

This implementation defines the model as a custom Module subclass. Whenever you
want a model more complex than a simple sequence of existing Modules, you will
need to define your model this way.



In [0]:
import torch
import torch.nn as nn
import torch.optim as optim

# Seed PyTorch's global random number generator so that the random data and
# the initial layer weights created below are reproducible from run to run.
_SEED = 0
torch.manual_seed(_SEED)

# Define the model
class MLP(nn.Module):
    """Fully-connected network with one hidden ReLU layer.

    Architecture: ``Linear(num_inputs -> num_hidden_layer_nodes)`` ->
    ``ReLU`` -> ``Linear(num_hidden_layer_nodes -> num_outputs)``.

    Args:
        num_inputs: size of the last dimension of the input ``x``.
        num_hidden_layer_nodes: width of the hidden layer.
        num_outputs: size of the last dimension of the output.
    """

    def __init__(self, num_inputs, num_hidden_layer_nodes, num_outputs):
        super().__init__()

        hidden_layer = nn.Linear(num_inputs, num_hidden_layer_nodes)
        output_layer = nn.Linear(num_hidden_layer_nodes, num_outputs)

        # All layers live in a single Sequential stored as ``self.model`` so the
        # parameter names (model.0.*, model.2.*) match the original layout.
        self.model = nn.Sequential(hidden_layer, nn.ReLU(), output_layer)

    def forward(self, x):
        """Run ``x`` through hidden layer, ReLU and output layer."""
        prediction = self.model(x)
        return prediction

# Num data points
num_data = 1000

# Network parameters
num_inputs = 1000
num_hidden_layer_nodes = 100
num_outputs = 10

# Training parameters
# Each "epoch" below is a single full-batch gradient step over all num_data
# samples (the whole of x is fed to the model at once).
num_epochs = 100
# NOTE(review): with reduction='sum' the loss (and therefore its gradient) is
# summed over all num_data * num_outputs elements, so the effective step size
# grows with the dataset size. In the recorded run below, the loss climbs from
# ~353 (iteration 32) to ~6997 (iteration 39) before recovering, which suggests
# this step is on the edge of stability for num_data=1000. Consider a smaller
# value or reduction='mean' (with a correspondingly larger learning rate).
learning_rate = 1e-4

# Create random input and target tensors; y is independent of x, so the
# network can only fit (memorize) this particular sample.
x = torch.randn(num_data, num_inputs)
y = torch.randn(num_data, num_outputs)

# Construct model
model = MLP(num_inputs, num_hidden_layer_nodes, num_outputs)

# Sum of squared errors over every element of the batch
loss_function = nn.MSELoss(reduction='sum')

# Plain stochastic gradient descent over all model parameters
optimizer = optim.SGD(model.parameters(), lr=learning_rate)


for t in range(num_epochs):

    # Forward pass: compute predicted y by passing x to the model
    y_pred = model(x)

    # Compute and print loss
    loss = loss_function(y_pred, y)
    print(t, loss.item())

    # Gradients accumulate in .grad across backward() calls, so clear the
    # previous iteration's gradients first.
    optimizer.zero_grad()

    # Calculate gradient using backward pass
    loss.backward()

    # Update model parameters (weights)
    optimizer.step()

0 10581.46484375
1 9755.71875
2 9161.2998046875
3 8637.4951171875
4 8135.08154296875
5 7634.916015625
6 7127.3212890625
7 6610.97021484375
8 6087.3720703125
9 5563.39599609375
10 5048.9482421875
11 4551.427734375
12 4078.5126953125
13 3634.7451171875
14 3225.676025390625
15 2852.951904296875
16 2516.31201171875
17 2214.470947265625
18 1946.5286865234375
19 1709.1199951171875
20 1499.035400390625
21 1313.6361083984375
22 1151.3740234375
23 1009.0540771484375
24 883.8927612304688
25 774.8145141601562
26 679.2297973632812
27 596.13232421875
28 524.0948486328125
29 462.2839660644531
30 411.1573791503906
31 372.46417236328125
32 353.1471252441406
33 370.2545471191406
34 465.539794921875
35 733.1923828125
36 1382.0526123046875
37 2739.25732421875
38 5025.0751953125
39 6997.12939453125
40 6101.65625
41 2680.3486328125
42 893.8970336914062
43 422.0373840332031
44 282.54412841796875
45 214.08810424804688
46 170.41282653808594
47 139.2351531982422
48 115.83944702148438
49 97.63912200927734
50 83