In [38]:
# Residual Connections for MLPs (ResMLP)

## ResMLP Architecture
![ResMLP Architecture](../ResMLP-architecture.png)

In [39]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [40]:
class BlockResMLP(nn.Module):
    """Residual MLP block computing ``relu(block(X) + skip_conection(X))``.

    ``block`` is a chain of ``nn.Linear`` layers with widths
    ``in_features -> outs_features[0] -> ... -> outs_features[-1]``.
    Every linear layer except the last is followed by a ReLU.
    ``skip_conection`` is a linear projection ``in_features -> outs_features[-1]``,
    so the two branches can be summed even when the widths differ.

    Args:
        in_features: size of the last dimension of the input.
        outs_features: non-empty sequence (list, tuple, ...) of layer widths.

    Raises:
        ValueError: if ``outs_features`` is empty.
    """
    def __init__(self, in_features, outs_features) -> None:
        super().__init__()
        # Unpacking (rather than list concatenation) also accepts tuples and other iterables.
        features = [in_features, *outs_features]
        if len(features) < 2:
            raise ValueError('outs_features must contain at least one width')
        last = len(features) - 2  # index of the final linear layer, which gets no ReLU
        # Each stage is Sequential(Linear, activation). This nesting determines the
        # state_dict keys ('block.<i>.0.weight', ...), so it is kept as-is.
        self.block = nn.Sequential(
            *[nn.Sequential(nn.Linear(ins, outs), nn.ReLU() if i < last else nn.Identity())
              for i, (ins, outs) in enumerate(zip(features[:-1], features[1:]))]
        )
        # The attribute name (including its spelling) is part of the state_dict keys.
        self.skip_conection = nn.Linear(in_features, features[-1])

    def forward(self, X):
        """Return ``relu(block(X) + skip_conection(X))``."""
        # Out-of-place add: does not mutate the tensor produced by self.block.
        out = self.block(X) + self.skip_conection(X)
        return F.relu(out)
def test_BlockResMLP():
    """Smoke test: build a BlockResMLP, run a random batch and check the output shape.

    Prints the output shape and the module summary.
    """
    config = {
        'in_features': 9,
        'outs_features': [32, 16, 32],
    }
    # torch.randn instead of torch.Tensor(4, 9), which returns uninitialized memory.
    X = torch.randn(4, config['in_features'])
    model = BlockResMLP(**config)
    out = model(X)
    assert out.shape == (4, config['outs_features'][-1])
    print(out.shape)
    print(model)
# test_BlockResMLP()

In [41]:
class LayerResMLP(nn.Module):
    """A stage of ``num_blocks`` ``BlockResMLP`` blocks applied in sequence.

    The first block maps ``in_features`` to ``outs_features[-1]``. Every
    following block maps ``outs_features[-1]`` to ``outs_features[-1]``.
    All blocks use the same widths ``outs_features``.

    Args:
        num_blocks: number of blocks in the stage; must be >= 1.
        in_features: size of the last dimension of the stage input.
        outs_features: layer widths passed to every block.

    Raises:
        ValueError: if ``num_blocks`` < 1.
    """
    def __init__(self, num_blocks, in_features, outs_features) -> None:
        super().__init__()
        # Without this check, num_blocks <= 0 would silently build one block,
        # because the first entry of ins_feature below is always present.
        if num_blocks < 1:
            raise ValueError(f'num_blocks must be >= 1, got {num_blocks}')
        # Only the first block sees in_features; later blocks receive the previous block's output width.
        ins_feature = [in_features] + [outs_features[-1]] * (num_blocks - 1)
        self.layer = nn.Sequential(
            *[BlockResMLP(in_features=ins, outs_features=outs_features) for ins in ins_feature]
        )

    def forward(self, X):
        """Run ``X`` through every block in order and return the result."""
        return self.layer(X)
def test_LayerResMLP():
    """Smoke test: build a two-block LayerResMLP, run a random batch and check the output shape.

    Prints the output shape and the module summary.
    """
    config = {
        'num_blocks': 2,
        'in_features': 9,
        'outs_features': [32, 16, 32],
    }
    # torch.randn instead of torch.Tensor(4, 9), which returns uninitialized memory.
    X = torch.randn(4, config['in_features'])
    model = LayerResMLP(**config)
    out = model(X)
    assert out.shape == (4, config['outs_features'][-1])
    print(out.shape)
    print(model)
# test_LayerResMLP()

In [42]:
# Architecture configuration consumed by ResMLP.
# Four residual stages, layer1..layer4, each with 2 blocks and widths [w, w // 2, w].
# Each stage's input width equals the previous stage's output width w.
# The linear head 'layer_out' maps the last width (256) to 9 outputs.
_stage_widths = [(9, 32), (32, 64), (64, 128), (128, 256)]  # (in_features, w) per stage
config = {
    f'layer{idx}': {
        'num_blocks': 2,
        'in_features': in_width,
        'outs_features': [width, width // 2, width],
    }
    for idx, (in_width, width) in enumerate(_stage_widths, start=1)
}
config['layer_out'] = {
    'in_features': 256,
    'out_features': 9,
}

In [43]:
class ResMLP(nn.Module):
    """Residual MLP: a stack of ``LayerResMLP`` stages followed by a linear head.

    Args:
        config: dict whose ``'layer1'``, ``'layer2'``, ... entries (numbered
            consecutively from 1) hold keyword arguments for ``LayerResMLP``.
            Its ``'layer_out'`` entry holds keyword arguments for the final
            ``nn.Linear``. Keys that are not of the form ``'layer<digits>'``
            or ``'layer_out'`` are ignored.

    Raises:
        KeyError: if the stage numbering has a gap or ``'layer_out'`` is missing.
    """
    def __init__(self, config) -> None:
        super().__init__()

        # Count the hidden stages from their keys instead of using len(config) - 1.
        # The old count assumed 'layer_out' was the only other key, so any extra
        # entry made it look up a nonexistent 'layerN'.
        num_stages = sum(
            1 for key in config
            if key.startswith('layer') and key[len('layer'):].isdigit()
        )

        layers = [LayerResMLP(**config[f'layer{i}']) for i in range(1, num_stages + 1)]
        layers.append(nn.Linear(**config['layer_out']))

        # The attribute name 'model' is kept so state_dict keys stay 'model.<i>...'.
        self.model = nn.Sequential(*layers)

    def forward(self, X):
        """Apply every stage and then the output head to ``X``; return the result."""
        return self.model(X)

def test_ResMLP():
    """Smoke test: build ResMLP from the notebook-level ``config`` and run a random batch.

    Asserts that the output has one row per input row and
    ``config['layer_out']['out_features']`` columns. Prints the output shape
    and the module summary.
    """
    # Derive the input width from config instead of hard-coding 9.
    # torch.randn is used because torch.Tensor(4, 9) is uninitialized memory.
    X = torch.randn(4, config['layer1']['in_features'])
    model = ResMLP(config)
    out = model(X)
    assert out.shape == (4, config['layer_out']['out_features'])
    print(out.shape)
    print(model)
test_ResMLP()

torch.Size([4, 9])
ResMLP(
  (model): Sequential(
    (0): LayerResMLP(
      (layer): Sequential(
        (0): BlockResMLP(
          (block): Sequential(
            (0): Sequential(
              (0): Linear(in_features=9, out_features=32, bias=True)
              (1): ReLU()
            )
            (1): Sequential(
              (0): Linear(in_features=32, out_features=16, bias=True)
              (1): ReLU()
            )
            (2): Sequential(
              (0): Linear(in_features=16, out_features=32, bias=True)
              (1): Identity()
            )
          )
          (skip_conection): Linear(in_features=9, out_features=32, bias=True)
        )
        (1): BlockResMLP(
          (block): Sequential(
            (0): Sequential(
              (0): Linear(in_features=32, out_features=32, bias=True)
              (1): ReLU()
            )
            (1): Sequential(
              (0): Linear(in_features=32, out_features=16, bias=True)
              (1): ReLU()
   