In [3]:
import torch
import numpy as np
from qmc.tracehess import autograd_trace_hessian
from torch import nn, optim

In [4]:
class NelectronVander(nn.Module):
    """N-particle ansatz built from pairwise differences of exponentials.

    For coordinates ``x`` whose last axis has length ``N = alpha.shape[0]``
    the wavefunction is the Vandermonde-style product

        psi(x) = prod_{i < j} (exp(-alpha[j] * x[i]) - exp(-alpha[j] * x[j]))

    Args:
        alpha: 1D tensor of exponents. Its length fixes the number of
            particles, i.e. the size of the last axis of every input.
    """

    def __init__(self, alpha):
        super().__init__()
        self.alpha = nn.Parameter(alpha)

    def _pair_factors(self, x):
        """Return the ``i < j`` factors of the product.

        Args:
            x: tensor of size m x alpha.size or m x n x alpha.size.

        Returns:
            Tensor with the leading axes of ``x`` and a last axis holding
            one factor per strictly-upper-triangular pair ``(i, j)``.
        """
        n = self.alpha.shape[0]
        # diff[..., i, j] = exp(-alpha[j] * x[..., i]) - exp(-alpha[j] * x[..., j])
        diff = torch.exp(-self.alpha * x.unsqueeze(-1)) - torch.exp(-self.alpha * x.unsqueeze(-2))
        # Strict upper triangle (i < j) in row-major order, built directly as
        # index vectors instead of masking a dense N x N matrix of ones.
        rows, cols = torch.triu_indices(n, n, offset=1, device=diff.device)
        return diff[..., rows, cols]

    def forward(self, x):
        """Return the log-probability ``log(psi(x)**2)``, up to normalisation.

        Computed as ``2 * sum(log|factor|)`` so the tiny product is never
        formed explicitly. The result is ``-inf`` wherever a factor is zero.

        Args:
            x: tensor of size m x alpha.size or m x n x alpha.size.

        Returns:
            Tensor with the shape of ``x`` minus its last axis.
        """
        return 2 * torch.sum(torch.log(torch.abs(self._pair_factors(x))), -1)

    def wave(self, x):
        """Return the signed value of the wavefunction ``psi(x)``.

        Args:
            x: tensor of size m x alpha.size or m x n x alpha.size.

        Returns:
            Tensor with the shape of ``x`` minus its last axis.
        """
        return torch.prod(self._pair_factors(x), -1)
    
    
        

In [5]:
class NelectronVanderWithMult(nn.Module):
    """Pairwise-difference ansatz multiplied by a global exponential envelope.

    For coordinates ``x`` whose last axis has length ``N = alpha.shape[0]``

        psi(x) = exp(-beta * sum_k x[k])
                 * prod_{i < j} (exp(-alpha[j] * x[i]) - exp(-alpha[j] * x[j]))

    Args:
        alpha: 1D tensor of exponents. Its length fixes the number of
            particles, i.e. the size of the last axis of every input.
        beta: scalar decay rate of the envelope; it is broadcast against the
            per-sample coordinate sum.
    """

    def __init__(self, alpha, beta):
        super().__init__()
        self.alpha = nn.Parameter(alpha)
        self.beta = nn.Parameter(beta)

    def _pair_factors(self, x):
        """Return the ``i < j`` factors of the product.

        Args:
            x: tensor of size m x alpha.size or m x n x alpha.size.

        Returns:
            Tensor with the leading axes of ``x`` and a last axis holding
            one factor per strictly-upper-triangular pair ``(i, j)``.
        """
        n = self.alpha.shape[0]
        # diff[..., i, j] = exp(-alpha[j] * x[..., i]) - exp(-alpha[j] * x[..., j])
        diff = torch.exp(-self.alpha * x.unsqueeze(-1)) - torch.exp(-self.alpha * x.unsqueeze(-2))
        # Strict upper triangle (i < j) in row-major order, built directly as
        # index vectors instead of masking a dense N x N matrix of ones.
        rows, cols = torch.triu_indices(n, n, offset=1, device=diff.device)
        return diff[..., rows, cols]

    def forward(self, x):
        """Return the log-probability ``log(psi(x)**2)``, up to normalisation.

        The envelope contributes ``-beta * sum(x)`` and the product
        contributes ``sum(log|factor|)``; the total is doubled to square the
        amplitude. The result is ``-inf`` wherever a factor is zero.

        Args:
            x: tensor of size m x alpha.size or m x n x alpha.size.

        Returns:
            Tensor with the shape of ``x`` minus its last axis.
        """
        log_envelope = -self.beta * torch.sum(x, -1)
        log_pairs = torch.sum(torch.log(torch.abs(self._pair_factors(x))), -1)
        return 2 * (log_envelope + log_pairs)

    def wave(self, x):
        """Return the signed value of the wavefunction ``psi(x)``.

        Args:
            x: tensor of size m x alpha.size or m x n x alpha.size.

        Returns:
            Tensor with the shape of ``x`` minus its last axis.
        """
        envelope = torch.exp(-self.beta * torch.sum(x, -1))
        return envelope * torch.prod(self._pair_factors(x), -1)
    
    

In [6]:
# Build the plain ansatz for 5 particles. The exponents come from an unseeded
# torch.rand call, so the numbers printed below differ from run to run.
f = NelectronVander(torch.rand(5))

In [7]:
# Calling the module runs forward(): one log-probability per row of the (3, 5) batch.
f(torch.rand((3,5)))

tensor([-56.8537, -46.5428, -55.8184], grad_fn=<MulBackward0>)

In [8]:
# wave() also accepts two leading batch axes: a (3, 7, 5) input gives a (3, 7)
# tensor of signed amplitudes.
f.wave(torch.rand((3,7,5)))

tensor([[ 1.3805e-10, -1.8393e-12,  1.3022e-12, -2.1324e-13,  6.8906e-12,
          2.2860e-13, -4.0267e-14],
        [-5.7630e-12, -1.7393e-10,  1.1242e-11, -1.3832e-10,  1.3604e-16,
          5.6603e-13,  1.2050e-12],
        [ 1.5432e-12,  1.4025e-10, -3.0663e-13, -1.5626e-11, -4.9439e-11,
         -1.1078e-14,  6.9832e-11]], grad_fn=<ProdBackward1>)

In [9]:
# Random sample tensor of shape (4, 6, 5).
# NOTE(review): `x` is not referenced by any later cell visible here — confirm it is still needed.
x = torch.rand(4,6,5)

In [10]:
# Rebinds `f` (until now the plain NelectronVander instance) to the variant with
# the extra exp(-beta * sum(r)) factor; alpha and beta are unseeded random draws.
f = NelectronVanderWithMult(torch.rand(5),torch.rand(1))

In [11]:
# Signed amplitudes of the beta-damped ansatz for a (3, 7, 5) batch -> shape (3, 7).
f.wave(torch.rand((3,7,5)))

tensor([[-1.1293e-14,  1.7137e-13, -6.7046e-12, -9.1058e-17, -6.0976e-14,
          1.5243e-16, -2.0857e-11],
        [ 8.4631e-14, -1.4440e-15, -2.7262e-16, -2.2096e-17, -1.1676e-11,
         -1.6118e-11, -1.4632e-12],
        [-1.5815e-12,  1.1804e-13,  6.3631e-12, -7.4608e-15,  9.0964e-12,
          1.9701e-15, -1.1346e-13]], grad_fn=<MulBackward0>)

In [12]:
# forward() of the beta-damped ansatz: one log-probability per row of the (9, 5) batch.
f(torch.rand((9,5)))

tensor([-55.3138, -64.7939, -52.1932, -63.4310, -55.9829, -55.3099, -69.5667,
        -62.8414, -64.5213], grad_fn=<MulBackward0>)