In [8]:
import torch
from torch import nn



class ResMLP(nn.Module):
    def __init__(self, layers):
        super().__init__()
        
        self.layers = []
        input_dim = layers[0]
        for dim in layers[1:]:
            self.layers.append(nn.Linear(input_dim, dim))
            input_dim = dim  
            
        self.layers = nn.ModuleList(self.layers)
        
    
    def forward(self, x):
        for layer in self.layers:
            # skip block of 1, only works when input dims match output dims
            if layer.weight.shape[0] == layer.weight.shape[1]:
                x = x + layer(x)
            else:
                x = layer(x)
            x = torch.tanh(x)
            
        return x
            



#net = ResMLP([8 for _ in range(100)])

net = []
for _ in range(200):
    net.append(nn.Linear(128,128))
    net.append(nn.ReLU())
net = nn.Sequential(*net)
op = torch.optim.Adam(net.parameters())
L = nn.MSELoss()

for _ in range (1000):
    x = torch.randn(128, 128)
    y = torch.sin(x)
    
    out = net(x)
    loss = L(out, y)
    loss.backward()
    
    op.step()
    op.zero_grad()
    
    print(loss.item())

0.4338158965110779
0.43443232774734497
0.4357782006263733
0.4335229992866516
0.4340008795261383
0.43050041794776917
0.43370741605758667
0.43137800693511963
0.4331374764442444
0.4319044351577759
0.43208083510398865
0.4325769245624542
0.4311603307723999
0.4412326514720917
0.4372348487377167
0.43508797883987427
0.43525075912475586
0.43109333515167236
0.43596985936164856
0.4375825524330139
0.43016374111175537
0.4375734329223633
0.4284602403640747
0.43513911962509155
0.4326404333114624
0.43057507276535034
0.4286668300628662
0.4274650812149048
0.4346632957458496
0.4342920184135437
0.42845550179481506
0.4358082115650177
0.43508827686309814
0.4243760406970978
0.43296191096305847
0.4322831332683563
0.4328440725803375
0.43660277128219604
0.43391022086143494
0.43741780519485474
0.4261281490325928
0.43422064185142517
0.4337971806526184
0.43226754665374756
0.4312974214553833
0.4300386309623718
0.43147778511047363
0.4324761629104614
0.43118104338645935
0.42980897426605225
0.4301232695579529
0.430945

KeyboardInterrupt: 

In [None]:
class ResMlpLayer(nn.Module):
     def __init__(self, dims: int, layers: int):
         super().__init__()
         self.layers = nn.ModuleList([nn.Linear(dims, dims)for _ in range(layers)])
         self.activation = nn.ELU()
         
     
     def forward(self, x):
        z = x
        for layer in self.layers:
            z = layer(z)
            z = self.activation(z)
        
        z = z + x
        return z
    

In [None]:
class ResMLP2(nn.Module):
    #h dims is hidden dims
    def __init__(self, input_dims:int, h_dims:int, output_dims:int, num_hiden_res_layers:int, res_block_size = 2):
        super().__init__()
        self.input_layer = nn.Linear(input_dims, h_dims)
        self.activation = nn.ELU()
        self.hidden_layer = nn.Sequential(
            *[ResMlpLayer(h_dims, res_block_size) for _ in range(num_hiden_res_layers)]
        )
        
        self.output = nn.Linear(h_dims, output_dims)
        
        
    
    def forward(self,x):
        z = x
        z = self.input_layer(z)
        z = self.activation(z)
        z = self.hidden_layer(z)
        z = self.output(z)
        
        return z