In [1]:
import torch
from torch import nn
class MLP(nn.Module):
    """Multilayer perceptron with one hidden layer.

    Maps 784 input features through a 256-unit hidden layer followed by a
    ReLU activation, then projects to 10 output values.
    """

    def __init__(self, **kwargs):
        """Build the layers; ``kwargs`` are forwarded to ``nn.Module.__init__``."""
        super().__init__(**kwargs)
        # Assignment order is the registration order shown by print(net).
        self.hidden = nn.Linear(784, 256)
        self.act = nn.ReLU()
        self.output = nn.Linear(256, 10)

    def forward(self, x):
        """Return the output-layer values for ``x`` (last dimension must be 784)."""
        return self.output(self.act(self.hidden(x)))
        

In [2]:
# Random input batch: 2 samples with 784 features each, matching MLP's first layer.
X=torch.rand(2,784)
net=MLP()
# Print the registered sub-modules of the network.
print(net)
# Forward pass; being the cell's last expression, the resulting tensor is displayed.
net(X)

MLP(
  (hidden): Linear(in_features=784, out_features=256, bias=True)
  (act): ReLU()
  (output): Linear(in_features=256, out_features=10, bias=True)
)


tensor([[ 0.0484, -0.2463, -0.1467, -0.0273, -0.1938, -0.0521,  0.0093,  0.0358,
          0.1821, -0.1083],
        [-0.0144, -0.1516, -0.2559,  0.0202, -0.1602, -0.1005, -0.0199,  0.0322,
         -0.0060, -0.1127]], grad_fn=<AddmmBackward>)

In [3]:
#
class Mysequential(nn.Module):
    """Minimal re-implementation of a sequential container.

    Construct it either from a single ``OrderedDict`` mapping names to
    modules, or from any number of modules passed positionally (these are
    registered under the names ``"0"``, ``"1"``, ...). ``forward`` passes the
    input through the registered modules in registration order.
    """

    def __init__(self, *args):
        """Register the given modules.

        Args:
            *args: either one ``OrderedDict`` of ``name -> module``, or the
                modules themselves in the order they should be applied.
        """
        # Imported inside the method: a name bound in the class body is a class
        # attribute and is not visible as a bare name inside methods.
        from collections import OrderedDict

        super(Mysequential, self).__init__()
        if len(args) == 1 and isinstance(args[0], OrderedDict):
            # Named form: keep the caller's keys as module names.
            for key, module in args[0].items():
                self.add_module(key, module)
        else:
            # Positional form: name each module by its index.
            for idx, module in enumerate(args):
                self.add_module(str(idx), module)

    def forward(self, input):
        """Feed ``input`` through every registered module in order and return the result."""
        # self._modules preserves registration order.
        for module in self._modules.values():
            input = module(input)
        return input
    
# Build a 784 -> 256 -> 10 network from positional modules; with this form
# Mysequential registers them under the names "0", "1", "2".
net=Mysequential(
    nn.Linear(784,256),
    nn.ReLU(),
    nn.Linear(256,10)
    )
print(net)
# X is the (2, 784) batch created in an earlier cell; the result is displayed.
net(X)

Mysequential(
  (0): Linear(in_features=784, out_features=256, bias=True)
  (1): ReLU()
  (2): Linear(in_features=256, out_features=10, bias=True)
)


tensor([[ 0.1295, -0.1517,  0.0704, -0.0858,  0.0383,  0.2601,  0.0885, -0.0918,
          0.1700, -0.1933],
        [ 0.0622, -0.1256,  0.0137, -0.1767, -0.0135,  0.2021,  0.0188, -0.0389,
          0.1095, -0.0642]], grad_fn=<AddmmBackward>)

In [4]:
# Start with two modules held in an nn.ModuleList.
net=nn.ModuleList([nn.Linear(784,256),nn.ReLU()])
# Append a third module, list-style.
net.append(nn.Linear(256,10))
# Negative indexing works like a Python list: this prints the module just appended.
print(net[-1])
print(net)

Linear(in_features=256, out_features=10, bias=True)
ModuleList(
  (0): Linear(in_features=784, out_features=256, bias=True)
  (1): ReLU()
  (2): Linear(in_features=256, out_features=10, bias=True)
)


In [5]:

class MyModule(nn.Module):
    """Example module that holds ten 10 -> 10 linear layers in an ``nn.ModuleList``."""

    def __init__(self):
        super(MyModule, self).__init__()
        # nn.ModuleList registers every layer, so their parameters are tracked.
        self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])

    def forward(self, x):
        """Apply the layers in sequence and return the final tensor.

        At step ``i`` the new ``x`` is the sum of layer ``i // 2`` applied to
        ``x`` and layer ``i`` applied to ``x``. ``x`` must have 10 features in
        its last dimension.
        """
        for i, l in enumerate(self.linears):
            # A ModuleList can be indexed like a list; the selected module must
            # be *called* on x -- a module itself cannot be added to a tensor.
            x = self.linears[i // 2](x) + l(x)
        return x

In [8]:
class Module_ModuleList(nn.Module):
    """Module that stores its single 10 -> 10 linear layer in an ``nn.ModuleList``."""

    def __init__(self):
        super().__init__()
        layers = [nn.Linear(10, 10)]
        # Wrapping the list in nn.ModuleList registers the layer, so its weight
        # and bias are returned by parameters().
        self.linear = nn.ModuleList(layers)
        
class Module_List(nn.Module):
    """Module that stores its single 10 -> 10 linear layer in a plain Python list."""

    def __init__(self):
        super().__init__()
        layer = nn.Linear(10, 10)
        # Plain list, not nn.ModuleList: nn.Module does not register modules
        # held in an ordinary list, so parameters() yields nothing for them.
        self.linears = [layer]
        
# Compare parameter registration: net1 keeps its layer in an nn.ModuleList,
# net2 keeps it in a plain Python list.
net1=Module_ModuleList()
net2=Module_List()
print('net1:')
# Print each registered parameter together with its size.
for p in net1.parameters():
    print(p,p.size())

print('net2:')
# Same loop for the plain-list variant.
for p in net2.parameters():
    print(p,p.size())

net1:
Parameter containing:
tensor([[ 0.0426,  0.1676, -0.1427,  0.0321,  0.1272, -0.1238,  0.1770, -0.0335,
         -0.2738, -0.2786],
        [ 0.2544,  0.0163, -0.2005, -0.2130,  0.2521,  0.0409, -0.2425, -0.1728,
          0.1518,  0.1115],
        [ 0.2211,  0.3043,  0.1024,  0.2533,  0.0879,  0.2876,  0.0802, -0.1489,
          0.3094, -0.0279],
        [ 0.0685,  0.2148, -0.1651, -0.2029, -0.2079,  0.2234,  0.2062, -0.2956,
         -0.1894, -0.1502],
        [ 0.0826, -0.1959, -0.1569,  0.2419, -0.0631,  0.0362, -0.1984, -0.2073,
          0.2421,  0.1409],
        [ 0.2138, -0.0451, -0.2710,  0.2875, -0.0397,  0.0665,  0.0344, -0.1998,
          0.2912, -0.1304],
        [ 0.2430,  0.2812, -0.0209, -0.1971,  0.2747, -0.2163,  0.1714,  0.0658,
          0.2825, -0.2068],
        [-0.0017, -0.0654, -0.2719,  0.1883,  0.2028,  0.2221,  0.0292,  0.2565,
         -0.0484,  0.0004],
        [-0.1818,  0.2281,  0.2876, -0.0765, -0.0481, -0.2560, -0.1784,  0.1848,
          0.2501, -

In [10]:
# Build an nn.ModuleDict from a name -> module mapping.
net=nn.ModuleDict({
    'linear':nn.Linear(784,256),
    'act':nn.ReLU(),
})
# Add another module with dict-style item assignment.
net['output']=nn.Linear(256,10)
# Modules can be read back by key ...
print(net['linear'])
# ... or as attributes.
print(net.output)
print(net.act)
print(net)

Linear(in_features=784, out_features=256, bias=True)
Linear(in_features=256, out_features=10, bias=True)
ReLU()
ModuleDict(
  (act): ReLU()
  (linear): Linear(in_features=784, out_features=256, bias=True)
  (output): Linear(in_features=256, out_features=10, bias=True)
)


In [11]:
class FancyMLP(nn.Module):
    """MLP demonstrating a constant weight, layer reuse and Python control flow.

    Holds one trainable 20 -> 20 linear layer, used twice in ``forward``, and
    a fixed random 20 x 20 matrix that is never trained.
    """

    def __init__(self, **kwargs):
        """Create the constant matrix and the shared linear layer.

        ``kwargs`` are forwarded to ``nn.Module.__init__``.
        """
        super(FancyMLP, self).__init__(**kwargs)

        # Constant (non-trainable) weight. Registering it as a buffer makes
        # .to()/.double()/.cuda() convert it together with the parameters;
        # persistent=False keeps it out of state_dict, as a plain attribute was.
        # It is created before the linear layer to keep the RNG draw order.
        self.register_buffer(
            'rand_weight',
            torch.rand((20, 20), requires_grad=False),
            persistent=False,
        )
        self.linear = nn.Linear(20, 20)

    def forward(self, X):
        """Run the network and return a scalar tensor.

        Args:
            X: 2-D tensor with 20 columns (``torch.mm`` requires matrices).

        Returns:
            The sum of all elements of the rescaled activations.
        """
        x = self.linear(X)
        # Multiply by the constant matrix; the buffer does not require grad,
        # so no gradient is computed for it.
        x = nn.functional.relu(torch.mm(x, self.rand_weight) + 1)

        # Reuse the same linear layer: both applications share parameters.
        x = self.linear(x)

        # Data-dependent control flow: halve until the norm is at most 1 ...
        while x.norm().item() > 1:
            x /= 2
        # ... then scale up once if the norm ended below 0.8.
        if x.norm().item() < 0.8:
            x *= 10
        return x.sum()
    

In [12]:
# Rebind X to a batch of 2 samples with 20 features, the input width of FancyMLP's linear layer.
X=torch.rand(2,20)
net=FancyMLP()
print(net)
# Forward pass; FancyMLP.forward returns x.sum(), so a scalar tensor is displayed.
net(X)

FancyMLP(
  (linear): Linear(in_features=20, out_features=20, bias=True)
)


tensor(-8.6338, grad_fn=<SumBackward0>)

In [15]:
class NestMLP(nn.Module):
    """Module wrapping an inner ``nn.Sequential`` of a 40 -> 30 linear layer and a ReLU."""

    def __init__(self, **kwargs):
        """Build the inner network; ``kwargs`` are forwarded to ``nn.Module.__init__``."""
        super().__init__(**kwargs)
        stages = [nn.Linear(40, 30), nn.ReLU()]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Return the inner network's output for ``x`` (last dimension must be 40)."""
        activations = self.net(x)
        return activations
    
# Chain custom and built-in modules: NestMLP (40 -> 30), a linear layer
# (30 -> 20), then FancyMLP, which takes 20 features and returns a scalar.
net=nn.Sequential(NestMLP(),nn.Linear(30,20),FancyMLP())

# Rebind X to a batch of 2 samples with 40 features, NestMLP's input width.
X=torch.rand(2,40)
print(net)
# Forward pass through the whole chain; the scalar result is displayed.
net(X)
    

Sequential(
  (0): NestMLP(
    (net): Sequential(
      (0): Linear(in_features=40, out_features=30, bias=True)
      (1): ReLU()
    )
  )
  (1): Linear(in_features=30, out_features=20, bias=True)
  (2): FancyMLP(
    (linear): Linear(in_features=20, out_features=20, bias=True)
  )
)


tensor(9.7763, grad_fn=<SumBackward0>)