In [1]:
import torch
from torch import nn, optim, tensor, Tensor
from torchdiffeq import odeint, odeint_adjoint
from matplotlib import pyplot as plt

$$
\begin{cases}
\sigma(x) = \log (\exp(-x) + \exp(x)) \\
\sigma'(x) = \tanh(x) \\
\sigma''(x) = 1 - \tanh^2(x)
\end{cases}
$$


In [95]:
def antidrivtanh(x):
    """Activation sigma(x) = log(exp(-x) + exp(x)), an antiderivative of tanh.

    Evaluated as |x| + log1p(exp(-2|x|)), which is algebraically identical but
    never forms exp(|x|) (that overflows to inf in float32 for |x| above ~88,
    making the naive formula return inf). Its derivative is tanh(x).
    """
    abs_x = torch.abs(x)
    return abs_x + torch.log1p(torch.exp(-2.0 * abs_x))

def derivtanh(x):
    """Elementwise sigma''(x) = 1 - tanh(x)^2, the derivative of tanh."""
    tanh_x = torch.tanh(x)
    return 1 - tanh_x ** 2

- ResNet: this module maps $\mathbb{R}^{d+1}$ to $\mathbb{R}^{m}$.
  1. The first layer changes the dimension of the data from $d+1$ to $m$.
  2. The remaining layers are residual layers that extract features in $\mathbb{R}^{m}$.

In [61]:
class ResNet(nn.Module):
    """Residual network N(s; theta_N) mapping R^{d+1} to R^m.

    u_0 = act(K_0 s + b_0)
    u_i = u_{i-1} + h * act(K_i u_{i-1} + b_i),   i = 1, ..., nTh-1,   h = 1 / (nTh - 1)

    Args:
        d: the opening layer expects d+1 input features.
        m: hidden width (output dimension).
        nTh: total number of linear layers; must be >= 2.
    """

    def __init__(self, d, m, nTh=2):
        assert nTh >=2, "nTh should not less than 2"
        super().__init__()
        self.d = d
        self.m = m
        self.nTh = nTh

        # Opening layer lifts d+1 -> m; the remaining nTh-1 layers are m -> m residual layers.
        self.layers = nn.ModuleList(
            [nn.Linear(d + 1, m, bias=True)]
            + [nn.Linear(m, m, bias=True) for _ in range(nTh - 1)]
        )
        self.act = antidrivtanh
        self.h = 1.0 / (self.nTh - 1)  # residual step size

    def forward(self, x):
        """Return N(x); x has d+1 features in its last dimension, the result has m."""
        # Call the modules (not .forward()) so that registered hooks run.
        x = self.act(self.layers[0](x))
        for layer in self.layers[1:]:
            x = x + self.h * self.act(layer(x))
        return x

- `Phi.forward()`
	$$ \begin{split} \Phi(\mathbf{s};\boldsymbol{\theta})  = \mathbf{w}^{T} N(\mathbf{s};\boldsymbol{\theta}_{N}) + \frac{1}{2} \mathbf{s}^{T}(A^{T}A)\mathbf{s} + \mathbf{b}^{T}\mathbf{s} + c,\\ \text{where} \quad \boldsymbol{\theta} =(\mathbf{w},\boldsymbol{\theta}_{N}, A,\mathbf{b},c) \end{split} $$
- `Phi.trHess`: Compute the gradient of $\Phi$ and the trace of its Hessian.
	1. Gradient computation:
		$$ \nabla_{\mathbf{s}}\Phi(\mathbf{s};\boldsymbol{\theta}) = \nabla_{\mathbf{s}}N(\mathbf{s};\boldsymbol{\theta}_{N})\mathbf{w} + (A^{T}A)\mathbf{s} + \mathbf{b} $$
		where $\nabla_{\mathbf{s}}N(\mathbf{s};\boldsymbol{\theta}_N)\mathbf{w}$ can be computed as follows (written out for a two-layer ResNet):
		$$ \begin{split} &\mathbf{u}_{0} = \sigma(K_{0}\mathbf{s}+\mathbf{b}_{0}),\\ &\mathbf{z}_{1} = \mathbf{w} + h K_{1}^{T} \text{diag}(\sigma'(K_{1} \mathbf{u}_{0}+\mathbf{b}_{1})) \mathbf{w},\\ &\mathbf{z}_{0} = K_{0}^{T}\text{diag}(\sigma'(K_{0}\mathbf{s}+\mathbf{b}_{0}))\mathbf{z}_{1},\\ &\nabla_{\mathbf{s}}N(\mathbf{s};\boldsymbol{\theta}_{N})\mathbf{w} = \mathbf{z}_{0} \end{split} $$
	2. Trace of the Hessian: **TODO** — derivation not yet written.

In [108]:
class Phi(nn.Module):
    """Potential Phi(s; theta) = w^T N(s; theta_N) + 1/2 s^T (A^T A) s + b^T s + c.

    `self.w` holds w, `self.N` is the ResNet N, `self.A` is the low-rank factor A,
    and `self.c` is the affine layer providing b (weight) and c (bias).
    """

    def __init__(self, nTh, m, d, r=10, alph=None):
        """
        Args:
            nTh: number of ResNet layers (>= 2).
            m: ResNet hidden width.
            d: the input s has d+1 features.
            r: rank of A, capped at d+1.
            alph: list of coefficients stored as self.alph (not used by the methods
                shown here); defaults to a fresh [1.0] * 5 per instance.
        """
        super().__init__()

        self.m = m
        self.nTh = nTh
        self.d = d
        # A fresh list per instance; a mutable default argument would be shared by every Phi.
        self.alph = [1.0] * 5 if alph is None else alph

        r = min(r, d + 1)

        self.A = nn.Parameter(torch.zeros(r, d + 1), requires_grad=True)
        nn.init.xavier_normal_(self.A)  # in-place initialisation of the same Parameter
        self.c = nn.Linear(d + 1, 1, bias=True)
        self.w = nn.Linear(m, 1, bias=False)

        self.N = ResNet(d, m, nTh=nTh)

        nn.init.ones_(self.w.weight)
        nn.init.ones_(self.c.weight)
        nn.init.ones_(self.c.bias)

    def forward(self, x):
        """Evaluate Phi on a batch x of shape (nex, d+1); returns shape (nex, 1)."""
        symA = torch.matmul(self.A.T, self.A)  # A^T A (matrix product, not elementwise)
        quad = 0.5 * torch.sum(torch.matmul(x, symA) * x, dim=1, keepdim=True)
        return self.w(self.N(x)) + quad + self.c(x)

    def trHess(self, x, justGrad=False):
        """Gradient of Phi and trace of its Hessian over the first d input coordinates.

        Args:
            x: batch of shape (nex, d+1) -- presumably spatial coordinates followed by
               time in the last column; TODO confirm.
            justGrad: if True, only the gradient is computed and returned.

        Returns:
            grad of shape (nex, d+1) if justGrad, otherwise (grad, trH) where trH has
            shape (nex,) and is the trace of the Hessian restricted to columns 0..d-1.

        Note: sigma' and sigma'' are hard-coded as tanh / derivtanh, which matches
        the ResNet activation antidrivtanh.
        """
        N = self.N
        m = N.layers[0].weight.shape[0]
        nex = x.shape[0]
        d = x.shape[1] - 1
        symA = torch.matmul(self.A.T, self.A)

        # Forward pass, keeping every pre-activation a_0 = K_0 s + b_0, a_i = K_i u_{i-1} + b_i.
        pre = [N.layers[0](x)]
        feat = N.act(pre[0])
        for layer in N.layers[1:]:
            pre.append(layer(feat))
            feat = feat + N.h * N.act(pre[-1])

        # Backward sweep (one column per example):
        # z[i] = z[i+1] + h K_i^T diag(sigma'(a_i)) z[i+1], seeded with z[nTh] = w.
        z = N.nTh * [None]
        for i in range(N.nTh - 1, 0, -1):
            term = self.w.weight.T if i == N.nTh - 1 else z[i + 1]
            z[i] = term + N.h * torch.mm(N.layers[i].weight.T, torch.tanh(pre[i]).T * term)

        tanhopen = torch.tanh(pre[0])
        z[0] = torch.mm(N.layers[0].weight.T, tanhopen.T * z[1])
        grad = z[0] + torch.mm(symA, x.T) + self.c.weight.T  # (d+1, nex)

        if justGrad:
            return grad.T

        # t_0: opening-layer contribution, sum_k sigma''(a_0)_k z[1]_k * ||(K_0 E)_k||^2.
        Kopen = N.layers[0].weight[:, 0:d]  # K_0 restricted to the first d columns, (m, d)
        temp = derivtanh(pre[0].T) * z[1]
        trH = torch.sum(temp.reshape(m, -1, nex) * Kopen.unsqueeze(2).pow(2), dim=(0, 1))

        # Jacobian of u_0 w.r.t. the first d coordinates, stored as (m, d, nex).
        Jac = Kopen.unsqueeze(2) * tanhopen.T.unsqueeze(1)

        # t_i for the residual layers; Jac is advanced to the Jacobian of u_i after each step.
        for i in range(1, N.nTh):
            KJ = torch.mm(N.layers[i].weight, Jac.reshape(m, -1)).reshape(m, -1, nex)
            term = self.w.weight.T if i == N.nTh - 1 else z[i + 1]
            a_i = pre[i].T  # (m, nex)
            t_i = torch.sum((derivtanh(a_i) * term).reshape(m, -1, nex) * KJ.pow(2), dim=(0, 1))
            trH = trH + N.h * t_i
            Jac = Jac + N.h * torch.tanh(a_i).reshape(m, -1, nex) * KJ

        # The quadratic term contributes tr((A^T A)[:d, :d]); the affine term has zero Hessian.
        return grad.T, trH + torch.trace(symA[0:d, 0:d])






In [109]:
# Scratch: a random 3x3 tensor for experimenting with reshape.
a = torch.rand((3, 3))

In [110]:
# Same (m, -1, nex) reshape pattern used in Phi.trHess: inserts a size-1 middle axis.
a.reshape((3, -1, 3))

tensor([[[0.2514, 0.5387, 0.8071]],

        [[0.8981, 0.0797, 0.6337]],

        [[0.4484, 0.0544, 0.7297]]])

In [97]:
# Scratch: a random rank-3 factor in R^{3x10} for checking the quadratic form.
A = torch.randn((3, 10))

In [103]:
# Display the scratch matrix A.
A

tensor([[-0.8469, -0.1457, -0.2025, -0.4517, -1.9758, -1.0002, -0.7658, -1.2111,
          0.8555, -0.1719],
        [ 0.1901, -0.5251,  0.2471,  1.8723,  0.1929, -0.3041, -0.9909, -0.7169,
          0.1969,  1.3466],
        [ 0.9016,  0.1166, -0.4051,  1.4957,  0.8242,  0.6152,  0.1457, -2.0950,
          0.5188,  0.2074]])

In [75]:
# Scratch: a random vector in R^10.
x = torch.randn((10,))

In [80]:
# Row-vector form x^T (A^T A), as used in Phi.forward's quadratic term.
(x @ A.T) @ A

tensor([-1.7451,  9.9124,  6.2011,  3.0812,  3.2062, -7.0861,  1.2512, -9.9082,
         0.9248, -6.4054])

In [89]:
# Smoke test: evaluate Phi (nTh=2, m=10, d=2) on a batch of 2 inputs with d+1 = 3 features.
a = Phi(nTh=2, m=10, d=2)(torch.randn((2, 3)))

In [92]:
# Phi returns one value per example.
a.shape

torch.Size([2, 1])

In [50]:

# Scratch: the ResNet layer list for d+1 = 3, m = 2, nTh = 2.
t = nn.ModuleList(
    [nn.Linear(3, 2, bias=True)]
    + [nn.Linear(2, 2, bias=True) for _ in range(3 - 2)]
)

In [51]:
# Inspect the second (m -> m) layer.
t[1]

Linear(in_features=2, out_features=2, bias=True)

tensor([[-0.3446,  0.0425, -1.4836, -0.4734],
        [-0.3858,  0.1519,  0.0171, -0.6214],
        [-0.1361,  1.9312, -0.8237, -0.0834]])

In [56]:
# Smoke test: ResNet with d=3 (so 4 input features) and width m=2 on a batch of 3.
t = ResNet(d=3, m=2)
t(torch.randn((3, 4)))

tensor([[1.4962, 1.5356],
        [1.4549, 1.7337],
        [1.5968, 1.7456]], grad_fn=<AddBackward0>)

[Linear(in_features=3, out_features=2, bias=True),
 Linear(in_features=3, out_features=3, bias=True)]