# Torch 

## Custom nn Modules

Sometimes you will want to specify models that are more complex than a sequence of existing Modules. In these cases you can define your own Module by subclassing nn.Module and defining a forward method, which receives input Tensors and produces output Tensors using other Modules or other autograd operations on Tensors.

In [1]:
# -*- coding: utf-8 -*-

import torch

Implement a two-layer network as a custom Module subclass:

In [2]:
class TwoLayerNet(torch.nn.Module):
    
    """
    This is a custom two layer module
    """
    
    def __init__(self, D_in, H, D_out):
        
        """
        In the constructor we instantiate two nn.Linear modules and assign them as 
        member variables.
        """
        
        super(TwoLayerNet, self).__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out)
        
    def forward(self, x):
        
        """
        In the forward function we accept a Tensor of input data and we must return
        a Tensor of output data. We can use Modules defined in the constructor as 
        well as arbitrary operators on Tensors.
        """
        
        # clamp(min=0) applies a ReLU to the hidden layer activations
        h_relu = self.linear1(x).clamp(min=0)
        y_pred = self.linear2(h_relu)
        
        return y_pred
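
Because the two nn.Linear layers are assigned as attributes in the constructor, nn.Module registers their weights and biases automatically, and that is what model.parameters() later hands to the optimizer. The sketch below repeats the class in compact form so it runs on its own; the names param_shapes and n_params are illustrative and not part of the original notebook.

```python
import torch


class TwoLayerNet(torch.nn.Module):
    """Compact copy of the two-layer network above, so this snippet is self-contained."""

    def __init__(self, D_in, H, D_out):
        super().__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out)

    def forward(self, x):
        return self.linear2(self.linear1(x).clamp(min=0))


model = TwoLayerNet(1000, 100, 10)

# Sub-modules assigned as attributes are registered automatically,
# so their weights and biases are listed here under dotted names.
param_shapes = {name: tuple(p.shape) for name, p in model.named_parameters()}
for name, shape in param_shapes.items():
    print(name, shape)

# Total number of trainable values: 1000*100 + 100 + 100*10 + 10
n_params = sum(p.numel() for p in model.parameters())
print(n_params)
```

Each nn.Linear stores its weight with shape (out_features, in_features), which is why linear1.weight is reported as (100, 1000) rather than (1000, 100).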

In [3]:
# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.

N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold inputs and outputs

x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

In [4]:
# Using custom nn Module

model = TwoLayerNet(D_in, H, D_out)

In [5]:
# Loss function (summed squared error) and SGD optimizer

criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
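
With reduction='sum', MSELoss adds up the squared errors over every element instead of averaging them. The following minimal check uses random stand-in tensors rather than the notebook's model output to compare it with a hand-written sum of squares and with the default 'mean' reduction.

```python
import torch

torch.manual_seed(0)

# Stand-in predictions and targets with the same shape as in the notebook (N=64, D_out=10)
y_pred = torch.randn(64, 10)
y = torch.randn(64, 10)

# Built-in loss with the summed reduction
loss_builtin = torch.nn.MSELoss(reduction='sum')(y_pred, y)

# The same quantity written out by hand
loss_manual = (y_pred - y).pow(2).sum()

# The default 'mean' reduction divides the sum by the number of elements
loss_mean = torch.nn.MSELoss(reduction='mean')(y_pred, y)

print(torch.allclose(loss_builtin, loss_manual, rtol=1e-4))
print(torch.allclose(loss_mean, loss_manual / y.numel(), rtol=1e-4))
```

The summed loss is N * D_out times larger than the mean loss, and so are its gradients, so a learning rate chosen for one reduction generally needs rescaling for the other.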

In [6]:
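for t in range(500):
    # Forward pass: compute predicted y by passing x to the model
    y_pred = model(x)

    # Compute the loss and print it every 100 iterations
    loss = criterion(y_pred, y)
    if t % 100 == 99:
        print(t, loss.item())

    # Zero the gradients, run the backward pass, and update the weights
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()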

99 2.32717227935791
199 0.03986956924200058
299 0.0015381972771137953
399 8.920634718379006e-05
499 6.171580480440753e-06
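
The training loop calls optimizer.zero_grad() before loss.backward() because PyTorch adds new gradients to whatever is already stored in each parameter's .grad rather than overwriting it. The small sketch below shows the accumulation on a standalone tensor; the tensor w and the toy function are illustrative only.

```python
import torch

w = torch.tensor([1.0, 2.0], requires_grad=True)

# First backward pass: d/dw of sum(3 * w) is 3 for every element
(3 * w).sum().backward()
first = w.grad.clone()

# Second backward pass without clearing: the new gradient is added to the old one
(3 * w).sum().backward()
accumulated = w.grad.clone()

# Clearing the gradient first gives the plain per-step gradient again
w.grad.zero_()
(3 * w).sum().backward()
after_reset = w.grad.clone()

print(first)        # tensor([3., 3.])
print(accumulated)  # tensor([6., 6.])
print(after_reset)  # tensor([3., 3.])
```

Without the reset, each optimizer.step() would use the sum of all gradients computed so far instead of the gradient for the current iteration.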


-- by HanaRo, 2020/09/10