In [3]:
import torch
import torch.nn as nn
import numpy as np

In [32]:
def LinearTanh(n_in, n_out):
    # do not work with ModuleList here either.
    block = nn.Sequential(
      nn.Linear(n_in, n_out),
      nn.Tanh()
    )
    return block


class MLP(nn.Module):
    
    def __init__(self, dim_layers):
        super(MLP, self).__init__()
        layers = []
        num_layers = len(dim_layers)
        
        blocks = []
        for l in range(num_layers-1):
            blocks.append(LinearTanh(dim_layers[l], dim_layers[l+1]))

        self.network = nn.Sequential(*blocks)
    
    def forward(self, x):
        return self.network(x)
    
mlp_layers = [2] + [10] + [20] + [30] + [2]

mlp = MLP(mlp_layers)
print(mlp)

MLP(
  (network): Sequential(
    (0): Sequential(
      (0): Linear(in_features=2, out_features=10, bias=True)
      (1): Tanh()
    )
    (1): Sequential(
      (0): Linear(in_features=10, out_features=20, bias=True)
      (1): Tanh()
    )
    (2): Sequential(
      (0): Linear(in_features=20, out_features=30, bias=True)
      (1): Tanh()
    )
    (3): Sequential(
      (0): Linear(in_features=30, out_features=2, bias=True)
      (1): Tanh()
    )
  )
)


In [33]:
x = torch.randn(3,2)
y = mlp(x)

In [41]:
tanh = lambda x : torch.tanh(x)
dtanh = lambda x : 1 - torch.tanh(x)**2

In [35]:
mlp.network[0][0].weight
mlp.network[0][0].bias

Parameter containing:
tensor([-0.2374,  0.1247, -0.6819, -0.4599,  0.0659, -0.3426,  0.3905, -0.3083,
         0.3608, -0.4472], requires_grad=True)

In [47]:
def get_wb(model,depth):
    return model[depth][0].weight, model[depth][0].bias


W = []
B = []

for i in range(4):
    w, b = get_wb(mlp.network,i)
    W.append(w)
    B.append(b)
    
print(len(W))
print(len(B))

4
4


In [55]:
a = W[-1]@W[-2]@W[-3]@W[-4]
tmp = x@W[-4].T + B[-4]
in1 = dtanh(tmp)
in2 = dtanh(tanh(in0)@W[-3].T + B[-3])
in3 = dtanh(tanh(in2)@W[-2].T + B[-2])
in4 = dtanh(tanh(in3)@W[-1].T + B[-1])
print(in1.shape)
print(in2.shape)
print(in3.shape)
print(in4.shape)

torch.Size([3, 10])
torch.Size([3, 20])
torch.Size([3, 30])
torch.Size([3, 2])


In [56]:
a.shape

torch.Size([2, 2])

In [213]:
def LinearTanh(n_in, n_out):
    # do not work with ModuleList here either.
    block = nn.Sequential(
      nn.Linear(n_in, n_out),
      nn.Tanh()
    )
    return block


class MLP(nn.Module):
    
    def __init__(self, dim_layers):
        super(MLP, self).__init__()
        layers = []
        num_layers = len(dim_layers)
        
        blocks = []
        for l in range(num_layers-1):
            blocks.append(LinearTanh(dim_layers[l], dim_layers[l+1]))

        self.network = nn.Sequential(*blocks)
        
        self.sigma = lambda x : torch.tanh(x)
        self.dsigma = lambda x : 1-torch.tanh(x)**2
        self.ddsigma = lambda x : -2*torch.tanh(x)*(1-torch.tanh(x)**2)
    
    def forward(self, x):
        return self.network(x)
    
    
    def get_wb(self, depth):
        return self.network[depth][0].weight, self.network[depth][0].bias
    
    def compute_ux(self, x):
        W1, b1 = self.get_wb(0)
        d1 = self.dsigma(x @ W1.T + b1)
        a1 = self.sigma(x @ W1.T + b1)
        
        z = d1 @ W1
        
#         W2, b2 = self.get_wb(1)
#         d2 = self.dsigma(a1 @ W2.T + b2)
#         a2 = self.sigma(a1 @ W2.T + b2)
#         z = (d2 @ W2) * d1) @ W1
        
#         W3, b3 = self.get_wb(2)
#         d3 = self.dsigma(a2 @ W3.T + b3)
#         a3 = self.sigma(a2 @ W3.T + b3)
#         z = (d3 @ W3 * d2) @ W2 * d1) @ W1
        
#         W4, b4 = self.get_wb(3)
#         d4 = self.dsigma(a3 @ W4.T + b4)
#         a4 = self.sigma(a3 @ W4.T + b4)
        
#         z = ((((d4 @ W4 * d3) @ W3 * d2) @ W2 * d1) @ W1
        
        return z
    
    def compute_uxx(self, x):
        W1, b1 = self.get_wb(0)
        d1 = self.ddsigma(x@W1.T + b1)
        z = d1 @ W1**2
        return z
    
dim_in = 2
mlp_layers = [dim_in] + [1]

mlp = MLP(mlp_layers)

x = torch.randn(10,dim_in)
x.requires_grad_(True)
y = mlp(x)

ux1_auto = torch.autograd.grad(y, x, torch.ones_like(y), retain_graph=True, create_graph=True)[0]
ux1_analy = mlp.compute_ux(x)
print(ux1_auto - ux1_analy)

ux2_auto = torch.autograd.grad(ux1_auto[:,0], x, torch.ones_like(ux1_auto[:,0]))[0]
ux2_analy = mlp.compute_uxx(x)
print(ux2_analy - ux2_auto)


tensor([[-1.4901e-08, -1.4901e-08],
        [-2.9802e-08, -2.9802e-08],
        [ 1.4901e-08,  1.4901e-08],
        [-7.4506e-09, -1.4901e-08],
        [-1.4901e-08, -1.4901e-08],
        [ 0.0000e+00,  0.0000e+00],
        [ 3.7253e-09,  7.4506e-09],
        [ 0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00]], grad_fn=<SubBackward0>)
tensor([[ 0.0000e+00, -6.0322e-02],
        [ 7.4506e-09, -8.9252e-02],
        [-7.4506e-09, -5.4100e-02],
        [ 7.4506e-09, -4.8782e-02],
        [ 0.0000e+00, -5.1549e-02],
        [ 0.0000e+00,  8.1362e-02],
        [-3.7253e-09, -1.3981e-02],
        [ 0.0000e+00, -7.2688e-02],
        [ 0.0000e+00, -7.2188e-02],
        [-7.4506e-09, -5.1812e-02]], grad_fn=<SubBackward0>)
