In [5]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm_notebook as tqdm
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets
import warnings

warnings.filterwarnings("ignore")

In [6]:
hn = 3 #hidden state dim

class Model(nn.Module):
     
    def __init__(self, N):
        super().__init__()
        self.N = N
        self.stepLength = 2        
        self.hidden_initial = nn.Parameter(torch.empty(1, hn).normal_(mean=0, std=1)) #make inital hidden trainable
        self.hidden = self.hidden_initial.repeat((self.N, 1))  
        
        self.Whh = nn.Linear(hn, hn)
        self.Whx = nn.Linear(2, hn, bias = False)
        
        self.layer1 = nn.Linear(2 + hn, 128)
        self.layer2 = nn.Linear(128, 64)
        self.layer3 = nn.Linear(64, 32)
        self.layer4 = nn.Linear(32, 1)
  
    
    def forward_DNN(self, x):
        """Approximate wave function"""
        
        #concatenate hidden state to input of DNN
        if self.hidden.shape[0] == x.shape[0]:
            y = torch.cat((self.hidden, x), 1)
        else:
            y = torch.cat((self.hidden[0].repeat(x.shape[0], 1), x), 1)
            
        y = torch.relu(self.layer1(y))
        y = torch.relu(self.layer2(y))
        y = torch.relu(self.layer3(y))
        y = self.layer4(y)
        
        return y
        
    
    
    def forward_RNN(self, x):
        """Encode correlation in hidden state"""
        self.hidden = torch.tanh(self.Whh(self.hidden) + self.Whx(x))
    
    
    def sample(self, N, n):
        total = 0
        x = torch.Tensor(4*np.random.random((N,2)) - 2)
        psi_old = self.forward_DNN(x)
        
        for i in range(n):
            x_new = x + self.stepLength*torch.Tensor(2*np.random.random((N,2)) - 1)
            psi_new = self.forward_DNN(x_new)
            
            idx = ((psi_new/psi_old)**2 > torch.Tensor(np.random.random((N,1)))).reshape(-1)
            
            
            x[idx] = x_new[idx]
            psi_old[idx] = psi_new[idx]
            total += torch.sum(idx)
            
        return x, total
    
    def resetHidden(self, N=0):
        if N == 0: N = self.N
        self.hidden = self.hidden_initial.repeat((N, 1))

In [7]:
N = 5000 #Batch size
n = 10  #Metropolis step
h = 0.01
h1 = torch.Tensor([h, 0]).repeat((N,1))
h2 = torch.Tensor([0, h]).repeat((N,1))


torch.manual_seed(42)
np.random.seed(42)
model = Model(N)
optimizer = torch.optim.Adam(model.parameters())

### Training

In [9]:
epochs = 2000

for epoch in tqdm(range(epochs)):
    PE_acc = 0
    P_acc = 0
    E_acc = 0
        
    model.resetHidden() 
    x1 = model.sample(N, n)[0].detach() #sample N positions for first particle
    psi1 = model.forward_DNN(x1)
    
    model.forward_RNN(x1)               #Advance hidden state based on previous sampled positions
    x2 = model.sample(N, n)[0].detach() #Sample positions for second particle, corrolated with first particle
    psi2 = model.forward_DNN(x2)        
    
    psi_total = psi1*psi2               #Total wave function
    
    
    #Numerical laplacian
    model.resetHidden()
    psi1_plus = model.forward_DNN(x1 + h1) #Change particle 1 coordinate, holding particle 2 constant. 
    model.forward_RNN(x1 + h1)             #however, psi2 still changes because of corrolation
    psi2_plus = model.forward_DNN(x2)
    
    model.resetHidden()
    psi1_minus = model.forward_DNN(x1 - h1)
    model.forward_RNN(x1 - h1)
    psi2_minus = model.forward_DNN(x2)

    lap1 = 1/psi_total*(psi1_plus*psi2_plus - 2*psi_total + psi1_minus*psi2_minus)/h**2
    
    ########################
    model.resetHidden()
    psi1_plus = model.forward_DNN(x1 + h2) #Change particle 1 coordinate, holding particle 2 constant. 
    model.forward_RNN(x1 + h2)              #however, psi2 still changes because of corrolation
    psi2_plus = model.forward_DNN(x2)
    
    model.resetHidden()
    psi1_minus = model.forward_DNN(x1-h2)
    model.forward_RNN(x1-h2)
    psi2_minus = model.forward_DNN(x2)

    lap2 = 1/psi_total*(psi1_plus*psi2_plus - 2*psi_total + psi1_minus*psi2_minus)/h**2
    ########################
    
    model.resetHidden()
    model.forward_RNN(x1)
    psi2_plus = model.forward_DNN(x2+h1) #Change particle 2 coordinate, holding particle 1 constant.
    psi2_minus = model.forward_DNN(x2-h1)
    
    lap3 = 1/psi2*(psi2_plus - 2*psi2 + psi2_minus)/h**2    #psi1 factor here, since it is constant
    
    ########################
    model.resetHidden()
    model.forward_RNN(x1)
    psi2_plus = model.forward_DNN(x2+h2) #Change particle 2 coordinate, holding particle 1 constant.
    psi2_minus = model.forward_DNN(x2-h2)
    
    lap4 = 1/psi2*(psi2_plus - 2*psi2 + psi2_minus)/h**2    #psi1 factor here, since it is constant
    ########################
    

    E_L = (-0.5*(lap1 + lap2 + lap3 + lap4) + 0.5*torch.sum(x1**2 + x2**2, dim=1) + 0/torch.sqrt(torch.sum((x1 - x2)**2, dim=1) + 0.1**2)).detach()

    PE = torch.mean(torch.log(psi_total)*E_L)
    P  = torch.mean(torch.log(psi_total))
    E  = torch.mean(E_L)  
    
    loss = 2*(PE - P*E)
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    
    if epoch%100 == 0: 
        print(f"epoch: {epoch}, Energy: {E.item()}")

HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))

epoch: 0, Energy: 30.929119110107422
epoch: 100, Energy: 35.20866775512695



KeyboardInterrupt: 

In [10]:
def f(x):
    x_lin = torch.ones((5000,2))*x
    x_lin[:,0] = torch.linspace(-8, 8, 5000).reshape(5000)
    print(x_lin.shape)
    dx = x_lin[1,0] - x_lin[0,0] 
    model.resetHidden()
    psi1 = model.forward_DNN(x_lin)[:,0].detach().numpy()

    #x = torch.Tensor([[0,0]])
    
    #model.forward_RNN(x)
    #psi2 = model.forward_DNN(x_lin)[:,0].detach().numpy()

    plt.plot(x_lin[:,0], 1/dx*psi1**2, "b")
    #plt.plot(x_lin[:,0], 1/dx*psi2**2/np.sum(psi2**2), "r")
    #plt.plot(x, 0.01, "bo")
    #plt.ylim((0, 0.8))

    plt.show()
    
def g(x):
    model.resetHidden()

    x = torch.Tensor([[x]])

    plt.plot(model.hidden[0].detach().numpy(), "bo")
    model.forward_RNN(x)

    plt.plot(model.hidden[0].detach().numpy(), "ro")
    plt.ylim((-1, 1))

    plt.show()

interact(f, x=(-4.0, 4., 0.05));
#interact(g, x=(-4.0, 4., 0.1));

interactive(children=(FloatSlider(value=0.0, description='x', max=4.0, min=-4.0, step=0.05), Output()), _dom_c…

In [None]:
model.resetHidden()
x1, total = model.sample(10000, 20)
print(total)

### Check metropolis sampling

In [None]:
N = 100000
model.resetHidden()
x_lin = torch.linspace(-3, 3, 100).reshape(100,-1)
dx = x_lin[1] - x_lin[0]

model.forward_RNN(torch.Tensor([[0.4]]))
psi = model.forward_DNN(x_lin)[:,0].detach().numpy()

x = model.sample(N, 10)[0].detach().numpy()

bins = np.linspace(-3, 3, 100)
plt.hist(x, bins=bins)
plt.plot(x_lin[:,0], N*psi**2/np.sum(psi**2), "r")
plt.show()

## Estimating energy

In [None]:
N = 100000

model.resetHidden(N) 
x1 = model.sample(N, n)[0].detach() #sample N positions for first particle
psi1 = model.forward_DNN(x1)

model.forward_RNN(x1)               #Advance hidden state based on previous sampled positions
x2 = model.sample(N, n)[0].detach() #Sample positions for second particle, corrolated with first particle
psi2 = model.forward_DNN(x2)        

psi_total = psi1*psi2               #Total wave function

#Numerical laplacian
model.resetHidden(N)
psi1_plus = model.forward_DNN(x1+h) #Change particle 1 coordinate, holding particle 2 constant. 
model.forward_RNN(x1 + h)           #however, psi2 still changes because of corrolation
psi2_plus = model.forward_DNN(x2)

model.resetHidden(N)
psi1_minus = model.forward_DNN(x1-h)
model.forward_RNN(x1 - h)
psi2_minus = model.forward_DNN(x2)

lap1 = 1/psi_total*(psi1_plus*psi2_plus - 2*psi_total + psi1_minus*psi2_minus)/h**2

model.resetHidden(N)
model.forward_RNN(x1)
psi2_plus = model.forward_DNN(x2+h) #Change particle 2 coordinate, holding particle 1 constant.
psi2_minus = model.forward_DNN(x2-h)

lap2 = 1/psi2*(psi2_plus - 2*psi2 + psi2_minus)/h**2    #psi1 factor here, since it is constant


E = (-0.5*(lap1 + lap2) + 0.5*(x1**2 + x2**2) + 1/torch.sqrt((x1 - x2)**2 + 0.1**2))
E = torch.mean(E)
print(E.item())