In [2]:
import torch
import torch.nn as nn
import numpy as np

In [49]:
class Fourier(nn.Module):
    """Random Fourier feature (RFF) encoder.

    Maps ``x`` of shape (B, dim_in) to ``[sin(2*pi*x @ b), cos(2*pi*x @ b)]`` of
    shape (B, 2*nfeat), where ``b`` is a fixed (non-trainable) random matrix of
    shape (dim_in, nfeat) with entries ``scale * N(0, 1)``.
    """

    def __init__(self, nfeat, scale, dim_in=2):
        """
        Args:
            nfeat: number of random frequencies; the output has 2*nfeat features.
            scale: multiplier applied to the standard-normal frequencies.
            dim_in: input dimension (defaults to 2, the previously hard-coded value).
        """
        super(Fourier, self).__init__()
        # Frozen Parameter: never trained, but registered on the module so it is
        # part of the state_dict and follows module.to(...).
        self.b = nn.Parameter(torch.randn(dim_in, nfeat)*scale, requires_grad=False)
        # Full-precision pi (the former literal 3.14159265359 was truncated).
        # Kept as an attribute because MLP.compute_ux reads ``rff.pi``.
        self.pi = float(np.pi)

    def forward(self, x):
        """Encode a (B, dim_in) batch into (B, 2*nfeat) sin/cos features."""
        size = x.shape[0]
        # Follow the input's device and, for floating inputs, its dtype too, so that
        # float64 batches (handy for gradient checks) work without casting the module.
        dtype = x.dtype if x.is_floating_point() else self.b.dtype
        b = self.b.to(device=x.device, dtype=dtype)
        x = (2*self.pi*x) @ b
        assert x.shape[0] == size
        assert x.shape[1] == self.b.shape[1]
        return torch.cat([torch.sin(x), torch.cos(x)], -1)
    

def LinearTanh(n_in, n_out):
    """Build one ``Linear(n_in, n_out) -> Tanh`` block.

    Returned as an ``nn.Sequential`` rather than an ``nn.ModuleList`` because a
    ModuleList has no forward and cannot be called directly. The Linear layer sits at
    index 0, which ``MLP.get_wb`` relies on.
    """
    linear = nn.Linear(n_in, n_out)
    activation = nn.Tanh()
    return nn.Sequential(linear, activation)


class MLP(nn.Module):
    """Random-Fourier-feature encoder followed by a stack of Linear+Tanh blocks.

    ``forward`` returns the network output. ``compute_ux`` and ``compute_uxx`` return
    closed-form first and unmixed second derivatives of the summed output with
    respect to the input. They work for any number of blocks, so they can be checked
    against autograd.
    """

    def __init__(self, dim_layers, nfeat, scale):
        """
        Args:
            dim_layers: layer widths ``[d0, d1, ..., dL]``; block ``l`` maps
                ``d_l -> d_{l+1}``. ``d0`` must be ``2 * nfeat`` because the
                Fourier encoder emits both sin and cos features.
            nfeat: number of random Fourier frequencies.
            scale: multiplier applied to the standard-normal Fourier frequencies.
        """
        super(MLP, self).__init__()
        # Blocks are built before the Fourier encoder, preserving the RNG draw order.
        blocks = [LinearTanh(n_in, n_out)
                  for n_in, n_out in zip(dim_layers[:-1], dim_layers[1:])]
        self.mlp = nn.Sequential(*blocks)
        self.rff = Fourier(nfeat, scale)

    # Activation and its derivatives. These are static methods rather than lambda
    # instance attributes so that the module stays picklable (e.g. torch.save(model)).
    @staticmethod
    def sigma(x):
        return torch.tanh(x)

    @staticmethod
    def dsigma(x):
        return 1 - torch.tanh(x)**2

    @staticmethod
    def ddsigma(x):
        t = torch.tanh(x)
        return -2*t*(1 - t**2)

    # Fourier feature map [sin, cos] and its elementwise derivative [cos, -sin].
    @staticmethod
    def sincos(x):
        return torch.cat([torch.sin(x), torch.cos(x)], -1)

    @staticmethod
    def dsincos(x):
        return torch.cat([torch.cos(x), -torch.sin(x)], -1)

    def forward(self, x):
        """Encode ``x`` with random Fourier features, then apply the Linear+Tanh stack."""
        x = self.rff(x)
        x = self.mlp(x)
        return x

    def get_wb(self, depth):
        """Return (weight, bias) of the Linear layer in block ``depth``."""
        return self.mlp[depth][0].weight, self.mlp[depth][0].bias

    def _encode(self, x):
        """Return the Fourier phase ``2*pi*x @ b`` and its [sin, cos] features."""
        phase = 2*self.rff.pi*x @ self.rff.b
        return phase, self.sincos(phase)

    def compute_ux(self, x):
        """Closed-form gradient of the summed output with respect to the input.

        Equals ``torch.autograd.grad(y, x, torch.ones_like(y))[0]`` with ``y = self(x)``.
        Unlike the previous hand-unrolled version, it works for any number of blocks.

        Args:
            x: input batch, (B, dim_in).
        Returns:
            Tensor of shape (B, dim_in).
        """
        phase, a = self._encode(x)

        # Forward pass, caching tanh'(pre-activation) for every block.
        slopes = []
        for depth in range(len(self.mlp)):
            W, b = self.get_wb(depth)
            pre = a @ W.T + b
            slopes.append(self.dsigma(pre))
            a = self.sigma(pre)

        # Reverse sweep, seeded with d(sum of outputs)/d(outputs) = 1.
        g = torch.ones_like(a)
        for depth in reversed(range(len(self.mlp))):
            W, _ = self.get_wb(depth)
            g = (g * slopes[depth]) @ W

        # Back through the encoder. The sin and cos halves share the same phase, so
        # their contributions add up; d(phase)/dx = 2*pi*b.
        g_feat = g * self.dsincos(phase)
        nfeat = phase.shape[-1]
        g_phase = g_feat[..., :nfeat] + g_feat[..., nfeat:]
        return (2*self.rff.pi*g_phase) @ self.rff.b.T

    def compute_uxx(self, x):
        """Closed-form unmixed second derivatives of the summed output.

        Column ``k`` is d^2(sum_j u_j)/dx_k^2 for each sample. That equals
        ``torch.autograd.grad(ux[:, k].sum(), x)[0][:, k]``, where ``ux`` is the
        autograd gradient computed with ``create_graph=True``.

        This is implemented as second-order forward-mode propagation along every input
        direction at once: for ``a = tanh(z)``,
        ``a' = tanh'(z) z'`` and ``a'' = tanh''(z) z'^2 + tanh'(z) z''``.

        Args:
            x: input batch, (B, dim_in).
        Returns:
            Tensor of shape (B, dim_in).
        """
        phase, a = self._encode(x)                              # (B, nfeat), (B, 2*nfeat)
        # d(phase)/dx_k for every direction k; the leading dim indexes k.
        dphase = (2*self.rff.pi*self.rff.b).unsqueeze(1)        # (dim_in, 1, nfeat)
        sin_p, cos_p = torch.sin(phase), torch.cos(phase)
        # Phase is linear in x, so its second derivative vanishes.
        da = torch.cat([cos_p*dphase, -sin_p*dphase], -1)       # (dim_in, B, 2*nfeat)
        dda = torch.cat([-sin_p*dphase**2, -cos_p*dphase**2], -1)

        for depth in range(len(self.mlp)):
            W, b = self.get_wb(depth)
            pre = a @ W.T + b                                   # (B, H)
            dpre = da @ W.T                                     # (dim_in, B, H)
            ddpre = dda @ W.T
            slope = self.dsigma(pre)
            dda = self.ddsigma(pre)*dpre**2 + slope*ddpre
            da = slope*dpre
            a = self.sigma(pre)

        # Sum over output units (matches a ones grad_output), then put the batch first.
        return dda.sum(-1).transpose(0, 1)
    
# Sanity check: the closed-form input gradient (MLP.compute_ux) must agree with autograd.
torch.manual_seed(0)  # reproducible Fourier frequencies, weights and inputs

dim_in = 2      # must match the first dimension of Fourier's frequency matrix `b`
nfeat = 256     # number of random Fourier frequencies
scale = 10      # multiplier of the standard-normal Fourier frequencies

# The first layer consumes the [sin, cos] features, hence 2*nfeat inputs.
mlp_layers = [2*nfeat] + 3*[256] + [1]

mlp = MLP(mlp_layers, nfeat, scale)

x = torch.randn(10, dim_in)
x.requires_grad_(True)
y = mlp(x)

# create_graph=True keeps ux1_auto differentiable so higher derivatives can be taken from it.
ux1_auto = torch.autograd.grad(y, x, torch.ones_like(y), retain_graph=True, create_graph=True)[0]
print(ux1_auto)
ux1_analy = mlp.compute_ux(x)
print(ux1_auto - ux1_analy)
# float32 round-off is ~1e-6 here; a larger gap means compute_ux is wrong.
assert torch.allclose(ux1_auto, ux1_analy, atol=1e-4), "compute_ux disagrees with autograd"

tensor([[-2.1673, -1.5006],
        [-1.6214,  2.3846],
        [-4.9490,  4.8836],
        [ 1.7014,  0.1088],
        [ 0.8186, -2.7692],
        [-3.2884,  6.9571],
        [ 4.8860, -0.8622],
        [ 1.7091, -0.5310],
        [-0.5050, -0.4975],
        [-0.6334,  4.7950]], grad_fn=<MulBackward0>)
tensor([[-7.1526e-07,  5.9605e-07],
        [-8.3447e-07,  9.5367e-07],
        [-4.7684e-07, -4.7684e-07],
        [ 2.0266e-06, -7.4506e-09],
        [ 1.3709e-06,  2.3842e-07],
        [-2.3842e-07,  0.0000e+00],
        [ 0.0000e+00, -1.4305e-06],
        [-1.3113e-06, -1.1325e-06],
        [-1.1921e-07, -6.5565e-07],
        [ 1.1921e-06,  1.4305e-06]], grad_fn=<SubBackward0>)
