Block: implemented by subclassing nn.Module (define __init__ and forward).

Sequential block: implemented by subclassing nn.Sequential (override add_module and forward).

# A Custom Block

In [1]:
import torch
import torch.nn as nn

In [31]:
class MLP(nn.Module):
    """Two-layer perceptron: Linear(20 -> 256), ReLU, Linear(256 -> 10).

    The layers are held in a single ``nn.Sequential`` stored as ``self.net``.
    """

    def __init__(self):
        super().__init__()
        # Build the layers in forward order so parameter initialisation
        # happens in the same sequence as the layers are applied.
        hidden = nn.Linear(20, 256)
        activation = nn.ReLU()
        output = nn.Linear(256, 10)
        self.net = nn.Sequential(hidden, activation, output)

    def forward(self, data):
        """Apply the wrapped sequential stack to ``data`` and return the result."""
        result = self.net(data)
        return result

In [5]:
# Build a fresh MLP and display its inner Sequential container.
MLP().net

Sequential(
  (0): Linear(in_features=20, out_features=256, bias=True)
  (1): ReLU()
  (2): Linear(in_features=256, out_features=10, bias=True)
)

# A Sequential Block

In [25]:
class MySequential(nn.Sequential):
    """Hand-rolled sequential container that auto-names appended blocks.

    Children are stored in registration order and applied one after another
    in ``forward``. Unlike the original one-argument override, ``add_module``
    stays compatible with the ``add_module(name, module)`` signature of
    ``nn.Module``, so the base-class constructor and other base-class
    helpers that call it with two arguments keep working.
    """

    def __init__(self, *args, **kwargs):
        # nn.Sequential receives its children as positional arguments, so
        # they must be forwarded for ``MySequential(layer1, layer2)`` to work.
        super(MySequential, self).__init__(*args, **kwargs)

    def add_module(self, block, module=None):
        """Register a child module.

        Supports two call forms:

        * ``add_module(block)`` -- append ``block`` under an automatically
          chosen numeric name (``"0"``, ``"1"``, ...).
        * ``add_module(name, module)`` -- the standard ``nn.Module`` form,
          where ``block`` is the string name and ``module`` is the child.

        Args:
            block: The module to append, or the name when ``module`` is
                supplied in the two-argument form.
            module: The child module for the two-argument form.
        """
        if isinstance(block, str):
            name, child = block, module
        else:
            # Pick the first unused numeric name so an explicitly named
            # child registered earlier is never silently overwritten.
            index = len(self._modules)
            while str(index) in self._modules:
                index += 1
            name, child = str(index), block
        # Delegate to the base class so its name/type validation applies.
        super(MySequential, self).add_module(name, child)

    def forward(self, data):
        """Feed ``data`` through every child in registration order."""
        for block in self._modules.values():
            data = block(data)
        return data


In [26]:
class MyMLP(nn.Module):
    """The 20 -> 256 -> 10 perceptron assembled with ``MySequential``.

    Layers are appended one at a time through ``MySequential.add_module``
    and the resulting container is stored as ``self.net``.
    """

    def __init__(self):
        super().__init__()
        stack = MySequential()
        # The tuple is evaluated left to right, so the layers are created
        # and registered in forward order.
        for layer in (nn.Linear(20, 256), nn.ReLU(), nn.Linear(256, 10)):
            stack.add_module(layer)
        self.net = stack

    def forward(self, data):
        """Apply the wrapped container to ``data`` and return the result."""
        result = self.net(data)
        return result

In [27]:
# Build a fresh MyMLP and display its inner MySequential container.
MyMLP().net

MySequential(
  (0): Linear(in_features=20, out_features=256, bias=True)
  (1): ReLU()
  (2): Linear(in_features=256, out_features=10, bias=True)
)

# Exercise 1 - Parallel block

In [35]:
class ParallelBlock(nn.Module):
    """Run two sub-networks on the same input and concatenate their outputs.

    Args:
        net1: First sub-network; called with the input passed to ``forward``.
        net2: Second sub-network; called with the same input.
        dim: Axis along which the two outputs are concatenated. The default
            ``0`` keeps the original behaviour of stacking along the first
            axis, so the first dimension of the result is the sum of the
            two outputs' first dimensions. Pass ``dim=1`` (or ``-1``) to
            place the two outputs side by side along the feature axis.
    """

    def __init__(self, net1, net2, dim=0):
        super(ParallelBlock, self).__init__()
        self.net1 = net1
        self.net2 = net2
        self.dim = dim

    def forward(self, data):
        """Return ``torch.cat`` of both branch outputs along ``self.dim``."""
        out1 = self.net1(data)
        out2 = self.net2(data)
        return torch.cat((out1, out2), self.dim)

In [38]:
# Feed one random batch through two independently initialised MLPs.
net = ParallelBlock(MLP(), MLP())
# Two rows of 20 features, matching the in_features of the MLP's first layer.
x = torch.randn(2,20)
out = net(x)
# Display the concatenated output together with its shape.
out, out.shape

(tensor([[-0.1455,  0.3422,  0.1582,  0.3535, -0.1130, -0.1365, -0.1637, -0.1925,
          -0.4087, -0.3598],
         [-0.3239,  0.1748,  0.0075,  0.0226,  0.1066,  0.1049, -0.3315, -0.0438,
           0.1683, -0.2228],
         [ 0.5970, -0.6489, -0.0513,  0.4739, -0.0281,  0.0071,  0.1806, -0.2795,
          -0.6322,  0.3603],
         [ 0.1791,  0.0008, -0.1905,  0.3419, -0.2979, -0.2244, -0.2947, -0.1042,
          -0.3730, -0.1888]], grad_fn=<CatBackward>),
 torch.Size([4, 10]))

# Exercise 2 - concatenate multiple instances of the same network

In [46]:
def factory(net_name, num):
    """Chain ``num`` freshly built instances of a network in an ``nn.Sequential``.

    Args:
        net_name: Zero-argument callable (e.g. a module class) invoked once
            per instance.
        num: Number of instances to create.

    Returns:
        An ``nn.Sequential`` whose children are named ``"0"`` .. ``str(num - 1)``.
        The children are applied in sequence, so each instance's output
        must be a valid input for the next one.
    """
    instances = [net_name() for _ in range(num)]
    # Positional children of nn.Sequential are registered under their index.
    return nn.Sequential(*instances)

In [47]:
# Build and display a Sequential holding three separate MLP instances.
factory(MLP, 3)

Sequential(
  (0): MLP(
    (net): Sequential(
      (0): Linear(in_features=20, out_features=256, bias=True)
      (1): ReLU()
      (2): Linear(in_features=256, out_features=10, bias=True)
    )
  )
  (1): MLP(
    (net): Sequential(
      (0): Linear(in_features=20, out_features=256, bias=True)
      (1): ReLU()
      (2): Linear(in_features=256, out_features=10, bias=True)
    )
  )
  (2): MLP(
    (net): Sequential(
      (0): Linear(in_features=20, out_features=256, bias=True)
      (1): ReLU()
      (2): Linear(in_features=256, out_features=10, bias=True)
    )
  )
)