In [None]:
#H2+     Energy = -0.6023424   for R = 1.9972
#fit(batch_size=10000, n_el=1, steps=500, epochs=1, RR=[[-1, 0, 0], [1., 0, 0]])

#H2		 Energy = -1.173427    for R = 1.40
#fit(batch_size=10000,n_el=2,steps=100,epochs=5,RR=torch.tensor([[-0.7,0,0],[0.7,0,0]]))

#He+	 Energy = -1.9998
#fit(batch_size=10000,n_el=1,steps=100,epochs=5,RR=torch.tensor([[0.,0,0]]),RR_charges=[2])

#He		 Energy = -2.90338583
#fit(batch_size=10000,n_el=2,steps=300,epochs=5,RR=torch.tensor([[0.3,0,0]]),RR_charges=[2])

In [None]:
%load_ext autoreload
%autoreload 1
%aimport  dlqmc.sampling, dlqmc.utils, dlqmc.nn.base, dlqmc.fit, dlqmc.nn.cusp
%config InlineBackend.figure_format = 'svg' 
%config InlineBackend.print_figure_kwargs = \
    {'bbox_inches': 'tight', 'dpi': 300}

In [None]:
# external ___________________

import ipywidgets
import torch.nn as nn
import numpy as np
from scipy import special
import scipy.stats as sps
import matplotlib.pyplot as plt
import torch
import sys, os
from matplotlib import cm
import time
from functools import partial
from tqdm.auto import tqdm, trange
from tensorboardX import SummaryWriter
import threading
import pickle

# pyscf __________________________

from pyscf import gto, scf, mcscf
from pyscf import gto, scf, dft
import pyscf
from pyscf.data.nist import BOHR


# dlqmc ____________________________

from dlqmc.utils import *
from dlqmc.analysis import pair_correlations_from_samples
from dlqmc.nn.base import * 
from dlqmc.geom import *
from dlqmc.sampling import langevin_monte_carlo, hmc ,samples_from, take, sample_start
from dlqmc.fit import *
from dlqmc.nn.anti import *
from dlqmc.physics import (
    local_energy, grad, quantum_force,nuclear_potential,
    nuclear_energy, laplacian, electronic_potential
)
from dlqmc.analysis import autocorr_coeff
from dlqmc.stats import GaussianKDEstimator

from dlqmc.nn.gto import *
from dlqmc.nn import ssp
from dlqmc.nn.hannet import *
from dlqmc.nn.cusp import *
from dlqmc.nn.schnet import *
from dlqmc.nn.slaterjastrownet import *


def normplot(x,y,norm,*args,**kwargs):
    """Plot ``y`` against ``x``, optionally rescaled to unit peak magnitude.

    Parameters
    ----------
    x, y : array-like
        Abscissa and ordinate handed to ``plt.plot``.
    norm : bool
        When truthy, ``y`` is divided by ``max(|y|)`` before plotting.
    *args, **kwargs
        Forwarded unchanged to ``plt.plot``.
    """
    values = y/np.max(np.abs(y)) if norm else y
    plt.plot(x,values,*args,**kwargs)
        
def test_sampling(sampler):
    """Run ``sampler`` for ten steps and return the mean of the reported info.

    Only the third element returned by ``samples_from`` is used; the samples
    and wavefunction values are discarded.
    """
    _, _, step_info = samples_from(sampler,trange(10))
    return step_info.mean()

def make_electron_line(mm=(-2,4),d=2,points=500,n_electrons=5,e=0,dim=0,offset=0,device='cuda'):
    """Build electron configurations in which one electron is scanned along a line.

    Parameters
    ----------
    mm : tuple of float
        ``(start, stop)`` of the scan for the moving electron.
    d : float
        Spacing used to place the other electrons: electron ``i`` sits at
        ``d*i`` along ``dim``.
    points : int
        Number of configurations, i.e. points on the scan line.
    n_electrons : int
        Number of electrons per configuration.
    e : int
        Index of the electron that is scanned.
    dim : int
        Cartesian component (0, 1 or 2) along which the electron is scanned
        and the other electrons are placed.
    offset : float
        Scale of a random shift applied to each fixed electron; one
        ``torch.randn(1)`` draw per electron, shared by all configurations.
    device : str or torch.device
        Device of the returned tensor. Defaults to ``'cuda'``.

    Returns
    -------
    torch.Tensor
        Tensor of shape ``(points, n_electrons, 3)`` on ``device``.
    """
    # All coordinates start at zero; only component `dim` is filled in below.
    x_line = torch.zeros(points, n_electrons, 3)
    x_line[:, e, dim] = torch.linspace(mm[0], mm[1], points)

    for i in range(n_electrons):
        if i != e:
            x_line[:, i, dim] = d*i
            x_line[:, i, dim] += torch.randn(1)*offset

    return x_line.to(device)

# Report current CUDA memory usage, then release cached blocks.
# NOTE(review): memory_cached / max_memory_cached appear to be deprecated
# aliases of memory_reserved / max_memory_reserved in newer PyTorch - confirm
# against the installed version before switching.
print(torch.cuda.memory_allocated())
print(torch.cuda.memory_cached(device=None))
print(torch.cuda.max_memory_cached(device=None))
torch.cuda.empty_cache()

# Colormaps used by the plotting cells. plt.get_cmap is used instead of
# cm.get_cmap, which is deprecated and removed in recent Matplotlib releases.
c=plt.get_cmap('Paired')
c2=plt.get_cmap('Reds')
c3=plt.get_cmap('Blues')
c4=plt.get_cmap('tab10')

In [None]:
class jastrow_factory(nn.Module):
    """Jastrow module: an ``ElectronicSchnet`` followed by a small network.

    The Schnet output is summed over axis 1 and the result is passed through
    the network returned by ``get_log_dnn``; the trailing singleton axis of
    that output is squeezed away.
    """

    def __init__(self,n_nuclei, dist_basis_dim, n_up, n_down):
        """Create the Schnet and the read-out network.

        ``n_nuclei``, ``n_up`` and ``n_down`` are forwarded to
        ``ElectronicSchnet``; ``dist_basis_dim`` is passed as its
        ``basis_dim``. All other hyperparameters are fixed here.
        """
        super().__init__()

        schnet_options = dict(
            n_up=n_up,
            n_down=n_down,
            n_nuclei=n_nuclei,
            n_interactions=3,
            basis_dim=dist_basis_dim,
            kernel_dim=128,
            embedding_dim=128,
            interaction_factory=None,
        )
        self.schnet = ElectronicSchnet(**schnet_options)

        # Input width 128, output width 1.
        self.log_dnn = get_log_dnn(128, 1, SSP, n_layers=5)

    def forward(self,x):
        """Return the read-out of the Schnet features summed over axis 1."""
        pooled = self.schnet(x).sum(dim=1)
        return self.log_dnn(pooled).squeeze(-1)


## $H_{10}$

Reference energies: Han -5.5685, Benchmark -5.6655, Benchmark (VMC) -5.65634 / -5.65201

In [None]:
# Linear chain of n nuclei with charge 1, spaced d apart along the x-axis.
d = 1.786
n = 10
chain_coords = [[d*i, 0., 0.] for i in range(n)]
chain_charges = [1.]*n
hn = Geometry(chain_coords, chain_charges)

# One electron per nucleus, split into spin-down and spin-up counts.
n_electrons = n
n_down = n_electrons//2
n_up = n_electrons-n_down

In [None]:
# Build the hydrogen chain in pyscf: n H atoms spaced d apart along x.
mol = gto.M(
    atom=[
        ['H', (d*i, 0, 0)] for i in range(n)      
    ],
    unit='bohr',
    basis='6-31G',#'aug-cc-pV5Z',
    charge=0,
    spin=n%2,
    cart=True
)
# Restricted Hartree-Fock reference; the value returned by kernel() is kept in E_hf.
mf = scf.RHF(mol)
E_hf=mf.kernel()

# CASSCF on top of the HF reference, called with (n, n) as its two size arguments.
mc = mcscf.CASSCF(mf, n,n)
# NOTE(review): this value is neither stored nor the last expression of the
# cell, so it is not displayed either.
mc.kernel()[0]

# Wavefunction networks built from the HF (gtowf) and CASSCF (gtowf2) objects,
# both with the cusp options enabled, moved to the GPU.
gtowf = SlaterJastrowNet.from_pyscf(mf,cusp_correction=True,cusp_electrons=True).cuda()
gtowf2 = SlaterJastrowNet.from_pyscf(mc,cusp_correction=True,cusp_electrons=True).cuda()


In [None]:
# Scan one electron along x and compare the normalised HF and MCSCF wavefunctions.
electron_line=0
x_line=make_electron_line(d=d,n_electrons=n_electrons,e=electron_line,offset=0.5)

# x-coordinate of the scanned electron, moved to the host once and reused.
x_scan = x_line[:,electron_line , 0].cpu().detach().numpy()
normplot(x_scan, gtowf(x_line).cpu().detach().numpy(),ls='-',label="hf",norm=True)
normplot(x_scan, gtowf2(x_line).cpu().detach().numpy(),ls='-',label="mcscf",norm=True)

# Disabled experiment (refers to names not defined in this notebook section):
#x_line=x_line[:,[0,2,1]]
#
#normplot(x_line[:,0 , 0].cpu().detach().numpy(), hfschnet(x_line).cpu().detach().numpy(),ls='-',label="HFschnet",norm=normed)
#normplot(x_line[:,0 , 0].cpu().detach().numpy(), -gtowf(x_line).cpu().detach().numpy(),ls=':',label="HF",norm=normed)

# Black dotted lines: zero level and nuclear positions.
plt.axhline(0,ls=':',color='k')
for i in range(n):
    plt.axvline(d*i,ls=':',color='k')

# Red dotted lines: x-positions of the electrons that are held fixed.
# .item() turns the 0-d CUDA tensor into a Python float; matplotlib cannot
# convert a CUDA tensor to a NumPy array itself.
for i in range(n_electrons):
    if not i==electron_line:
        plt.axvline(x_line[0,i,0].item(),ls=':',color='r')


plt.legend()
plt.show()



In [None]:
# Trainable network: built from the HF object `mf` with the jastrow_factory
# class as Jastrow constructor, moved to the GPU.
sjnet = SlaterJastrowNet.from_pyscf(mf,jastrow_factory=jastrow_factory,cusp_electrons=True).cuda()

In [None]:
n_walker=1500
n_wlaker=n_walker  # legacy misspelled name, kept so cells that still use it keep working


def _make_sampler(wf):
    """Langevin sampler for ``wf`` with fresh start positions around ``hn``."""
    return langevin_monte_carlo(
        wf,
        sample_start(hn,n_walker,n_electrons,var=1),
        tau=.4,
    )


# One independent sampler per wavefunction, created in the same order as before
# so the random start positions are drawn in the same sequence.
samplersj = _make_sampler(sjnet)
samplergto = _make_sampler(gtowf)
samplergto2 = _make_sampler(gtowf2)

In [None]:
# Run name derived from the current wall-clock time (rounded to whole seconds).
writer_name=str(np.round(time.time(),0))

opt = torch.optim.Adam(sjnet.parameters(), lr=5e-3)

scheduler = torch.optim.lr_scheduler.ExponentialLR(opt, 0.99, last_epoch=-1)

#________________________________________________________________

#samples_from(samplersj,trange(500))

# The writer is created explicitly so it can be closed (and its event file
# flushed) even if training is interrupted or raises.
writer = SummaryWriter('runs/'+writer_name)
try:
    fit_wfnet(
        sjnet,
        loss_total_energy_indirect,
        opt,
        wfnet_fit_driver(
                samplersj,
                samplings=range(100),
                n_epochs=1,
                n_sampling_steps=30,
                batch_size=150,
                n_discard=29,
                range_sampling=partial(trange, desc='sampling steps', leave=False),
                range_training=partial(trange, desc='training steps', leave=False)
            ),
        indirect=True,
        scheduler = scheduler,
        acc_grad=5,
        writer = writer,
    )
finally:
    writer.close()



In [None]:
# Scan one electron along x and compare the trained network with HF and MCSCF.
electron_line=0
x_line=make_electron_line(d=d,n_electrons=n_electrons,e=electron_line,offset=0.5)

# x-coordinate of the scanned electron, moved to the host once and reused.
x_scan = x_line[:,electron_line , 0].cpu().detach().numpy()
normplot(x_scan, sjnet(x_line).cpu().detach().numpy(),ls='-',label="sjnet",norm=True)
normplot(x_scan, gtowf(x_line).cpu().detach().numpy(),ls=':',label="hf",norm=True)
normplot(x_scan, gtowf2(x_line).cpu().detach().numpy(),ls=':',label="mcscf",norm=True)

# Black dotted lines: zero level and nuclear positions.
plt.axhline(0,ls=':',color='k')
for i in range(n):
    plt.axvline(d*i,ls=':',color='k')

# Red dotted lines: x-positions of the electrons that are held fixed.
# .item() turns the 0-d CUDA tensor into a Python float; matplotlib cannot
# convert a CUDA tensor to a NumPy array itself.
for i in range(n_electrons):
    if not i==electron_line:
        plt.axvline(x_line[0,i,0].item(),ls=':',color='r')


plt.legend()
plt.show()



In [None]:
# Local energy along the scan line for each wavefunction.
net_labels = ['sjnet','hfwf','multiwf']
for i,net in enumerate([sjnet,gtowf,gtowf2]):
    plt.plot(
        x_line[:, 0, 0].detach().cpu().numpy(),
        local_energy(x_line,net, net.geom)[0].cpu().detach().numpy(),label=net_labels[i]
    )
# Fixed y-range so the plot is not dominated by very large values.
plt.ylim((-30, 30))
plt.legend();

In [None]:
# Quick check of the trained network's sampler: ten steps, mean of the reported info.
test_sampling(samplersj)

In [None]:
# Histogram of local energies for each wavefunction, sampled with its own sampler.
for net,sampler in zip([sjnet,gtowf,gtowf2],[samplersj,samplergto,samplergto2]):
    samples, psis, info = samples_from(sampler,trange(20))
    # Merge the two leading axes and drop the first 800 entries.
    samples = samples.flatten(end_dim=1)[800:]
    # Evaluate the local energy in chunks of 1000 configurations.
    E_loc = dlqmc.utils.batch_eval_tuple(partial(local_energy, wf=lambda x: net(x),geom=net.geom),tqdm(samples.view([-1,n_electrons,3]).split(1000)))[0]
    #E_loc = E_loc.clamp(-8, -3)
    mean=E_loc.mean().item()
    # Single host copy, reused for both the histogram and the variance label.
    E_loc_np = E_loc.detach().cpu().numpy()
    h = plt.hist(E_loc_np, bins=100,range=(-10,0),alpha = 0.5,label=("mean = "+str(np.round(mean,4))+", var = "+str(np.round(np.var(E_loc_np),4))))
# The legend is built once, after all histograms have been added.
plt.legend()

del E_loc 

In [None]:
# NOTE(review): this cell repeats the energy-histogram analysis of the
# preceding cell line for line; each run continues the same samplers, so it
# yields a fresh set of samples. Consider a shared helper function.
for net,sampler in zip([sjnet,gtowf,gtowf2],[samplersj,samplergto,samplergto2]):
    # NOTE(review): `t` is assigned but never read in this cell.
    t=time.time()
    samples, psis, info = samples_from(sampler,trange(20))
    # Merge the two leading axes and drop the first 800 entries.
    samples = samples.flatten(end_dim=1)[800:]
    # Evaluate the local energy in chunks of 1000 configurations.
    E_loc = dlqmc.utils.batch_eval_tuple(partial(local_energy, wf=lambda x: net(x),geom=net.geom),tqdm(samples.view([-1,n_electrons,3]).split(1000)))[0]
    #E_loc = E_loc.clamp(-8, -3)
    mean=E_loc.mean().item()
    h = plt.hist(E_loc.detach().cpu().numpy(), bins=100,range=(-10,0),alpha = 0.5,label=("mean = "+str(np.round(mean,4))+", var = "+str(np.round(np.var(E_loc.cpu().detach().numpy()),4))))
    plt.legend()

del E_loc 

In [None]:
# Local-energy statistics of the trained network over ten chunks of samples.
for net,sampler in zip([sjnet],[samplersj]):
    samples_, psis, info = samples_from(sampler,trange(100))
    samples_ = samples_.flatten(end_dim=1)
    means=[]
    for i in range(10):
        # i-th chunk of 1500 flattened samples.
        # NOTE(review): 1500 equals the walker count used for the samplers;
        # whether a chunk is one sampling step depends on the axis order
        # returned by samples_from - confirm.
        samples = samples_[1500*i:1500*(i+1)]
        E_loc = dlqmc.utils.batch_eval_tuple(partial(local_energy, wf=lambda x: net(x),geom=net.geom),tqdm(samples.view([-1,n_electrons,3]).split(1000)))[0]
        # Tensor.clamp is out-of-place: the result must be assigned, otherwise
        # outliers still enter the mean and variance below.
        E_loc = E_loc.clamp(-15,5)
        mean=E_loc.mean().item()
        means.append(mean)
        h = plt.hist(E_loc.detach().cpu().numpy(), bins=100,range=(-10,0),alpha = 0.5,label=("mean = "+str(np.round(mean,4))+", var = "+str(np.round(np.var(E_loc.cpu().detach().numpy()),4))))
        plt.legend()

del E_loc 

In [None]:
# Per-chunk mean energies (markers) with their average and reference levels.
plt.plot(means,ls='',marker='o')
plt.axhline(np.mean(means),ls=':',label='sj')
# NOTE(review): the 'hf' and 'mcscf' levels are hard-coded numbers, presumably
# taken from earlier runs - confirm they match the current geometry and basis.
plt.axhline(-5.37,ls=':',color='y',label='hf')
plt.axhline(-5.5,ls=':',color='r',label='mcscf')
# These two values are the ones quoted in the H10 section header of this notebook.
plt.axhline(-5.5685,ls=':',color='g',label='han')
plt.axhline(-5.65634,ls=':',color='k',label='bench')
plt.ylim(-5.8,-5.3)
# Legend placed below the axes, in a single row.
plt.legend(ncol=5,bbox_to_anchor=(1,-0.1))

# H2 -> H10

In [None]:
def same_atom(samples,geom):
    """Count pairs of electrons whose nearest nucleus is the same.

    Parameters
    ----------
    samples : torch.Tensor
        Rank-4 tensor indexed as ``(a, b, electron, xyz)``; the two leading
        axes are presumably walker and step - confirm against the caller.
    geom : object
        Must expose ``coords``, a rank-2 tensor indexed as ``(nucleus, xyz)``.

    Returns
    -------
    numpy.ndarray
        For every index along axis ``b``, the number of unordered electron
        pairs that share a nearest nucleus, averaged over axis ``a``.
    """
    # Index of the closest nucleus for every electron.
    displacement = samples[:,:,:,None,:]-geom.coords[None,None,None,:,:]
    nearest = displacement.norm(dim=-1).argmin(dim=-1).cpu().numpy()

    # Pairwise difference of nearest-nucleus indices; subtracting 100 on the
    # diagonal keeps an electron from being counted as paired with itself.
    n_el = samples[0,0].shape[0]
    pair_diff = nearest[:,:,:,None]-nearest[:,:,None,:]-np.identity(n_el)*100

    # Every unordered pair appears twice in the symmetric comparison.
    n_pairs = (pair_diff==0).sum(axis=(-2,-1))/2
    return n_pairs.mean(axis=0)

In [None]:
def analyse_sampling(Net,N_electrons=None,n_steps=500,n_walker=100,tau=0.5):
    """Sample each network with Langevin Monte Carlo and plot sampling diagnostics.

    For every network a fresh sampler is started, ``n_steps`` steps are drawn
    and the results are added to a shared 4x2 grid of panels: lifetimes,
    acceptance, local energies, autocorrelation, electrons sharing a nearest
    atom, a kernel density estimate and a local-energy histogram.

    Parameters
    ----------
    Net : sequence
        Wavefunction networks; each must expose ``geom`` with ``coords``.
    N_electrons : sequence of int, optional
        Electron count per network. Defaults to the number of entries in each
        network's ``geom.coords``.
    n_steps : int
        Number of sampling steps per network.
    n_walker : int
        Number of walkers per network.
    tau : float
        Passed to ``langevin_monte_carlo`` as ``tau``.

    Relies on the notebook-level names ``c4``, ``make_electron_line``,
    ``same_atom`` and ``dlqmc``. Returns nothing; the figure is shown.
    """
    
    if N_electrons is None:
        N_electrons=[]
        for net in Net:
            N_electrons.append(len(net.geom.coords))
    
    # Histogram transparency scales with the number of overlaid networks.
    alpha = 1/len(Net)
    
    fig = plt.figure(figsize=(10,16))
    ax1 = fig.add_subplot(421)
    ax2 = fig.add_subplot(422)
    ax3 = fig.add_subplot(423)
    ax4 = fig.add_subplot(424)
    ax5 = fig.add_subplot(425)
    ax6 = fig.add_subplot(426)
    ax7 = fig.add_subplot(427)
    ax8 = fig.add_subplot(428)


    for i,(net,n_electrons) in enumerate(zip(Net,N_electrons)):
        
        # Distance between the first two nuclei (requires at least two nuclei).
        distance = torch.norm(net.geom.coords[1]-net.geom.coords[0]).item() #if all dists equivalent like in h-chain
        print('Setting: \n\n n_steps = %i \n n_electrons = %i \n tau = %.2f \n distance = %.2f'%(n_steps,n_electrons,tau,distance) )

        sampler = langevin_monte_carlo(
        net,
        sample_start(net.geom.cpu(),n_walker,n_electrons,var=1),
        tau=tau,
    )

        samples, psis, info = samples_from(sampler,trange(n_steps))

        # Local energy of every sampled configuration, evaluated in chunks of 1000.
        E_loc = dlqmc.utils.batch_eval_tuple(partial(local_energy, wf=lambda x: net(x),geom=net.geom),tqdm(samples.view([-1,n_electrons,3]).split(1000)))[0]
        # NOTE(review): E_loc is reshaped to (n_walker, n_steps) while lifetimes
        # below is reshaped to (n_steps, n_walker) - confirm both match the
        # layout produced by samples_from.
        E_loc=E_loc.detach().cpu().numpy().reshape(n_walker,n_steps)
        
        lifetimes = np.concatenate(info['lifetime'].to_numpy()).reshape(n_steps,n_walker)
        acceptance = info['acceptance'].to_numpy()

        # Density estimate over all sampled electron positions, evaluated on a
        # line along x spanning the chain.
        dens = GaussianKDEstimator(samples.flatten(end_dim=-2),bw=0.2)
        x_ = make_electron_line(mm=(-2,2+distance*n_electrons),d=distance,n_electrons=1,e=0,offset=0.5)
        x_ = x_.flatten(end_dim=-2)


        ax1.set_title('max lifetime (per walker)')
        # NOTE(review): the bin count is the largest lifetime; this fails if
        # that value is 0 or not an integer type.
        ax1.hist(lifetimes.max(axis=0),bins=lifetimes.max(),color=c4(i),alpha=alpha,label=str(i))
        ax1.legend(bbox_to_anchor=(1.2, 1.6))
        ax2.set_title('max lifetime (@step)')    
        ax2.plot(lifetimes.max(axis=1),color=c4(i))
        ax3.set_title('acceptance rate (@step)')    
        ax3.plot(acceptance,color=c4(i))
        ax4.set_title('E_loc (mean,max,min) (@step)')
        ax4.plot(E_loc.mean(axis=0),color=c4(i))
        ax4.plot(E_loc.min(axis=0),ls=':',color=c4(i))
        ax4.plot(E_loc.max(axis=0),ls=':',color=c4(i))
        ax4.set_ylim(-20,10)
        ax5.set_title('autocorrelation')
        ax5.plot(autocorr_coeff(np.arange(0,min(500,n_steps)),samples).detach().cpu().numpy(),color=c4(i))
        #ax6.set_title('2d histogramm of samples (@x,y)')
        #ax6.hist2d(samples.flatten(end_dim=-2)[:,0].cpu().numpy(),samples.flatten(end_dim=-2)[:,1].cpu().numpy(),bins=(40,40),alpha=alpha);
        ax6.set_title('multiple electrons at atom')
        # Unlike the other panels, this line uses the default colour cycle
        # rather than c4(i).
        ax6.plot(same_atom(samples,net.geom))
        ax7.set_title('KDE over samples (@x-axis)')
        ax7.plot(x_.cpu().numpy()[:,0],dens(x_).cpu().numpy(),color=c4(i))
        ax8.set_title('local energys')
        # NOTE(review): the histogram clips to (-15, 5) but only bins (-8, 0),
        # while the label statistics clip to (-20, 10) - confirm the intended
        # ranges.
        ax8.hist(E_loc.flatten().clip(-15,5),bins=30,range=(-8,0),label='mean: %.3f, var: %.3f'%(np.mean(E_loc.flatten().clip(-20,10)),np.var(E_loc.flatten().clip(-20,10))),color=c4(i),alpha=alpha)
        ax8.legend()
    plt.show()

In [None]:
def do_pyscf(n,d,which=('hf_wf_nocusp','md_wf_nocusp','hf_wf_cusp','md_wf_cusp')):
    """Build Slater-Jastrow wavefunctions for a linear chain of ``n`` H atoms.

    Parameters
    ----------
    n : int
        Number of hydrogen atoms; also used as both size arguments of the
        CASSCF calculation.
    d : float
        Spacing between neighbouring atoms along x, in bohr.
    which : iterable of str
        Wavefunctions to build, in order. Allowed names: ``'hf_wf_nocusp'``,
        ``'md_wf_nocusp'``, ``'hf_wf_cusp'``, ``'md_wf_cusp'``. The ``hf_*``
        variants are built from the RHF object, the ``md_*`` variants from the
        CASSCF object; ``*_cusp`` enables both cusp options.

    Returns
    -------
    list
        One ``SlaterJastrowNet`` (moved to the GPU) per entry of ``which``.

    Raises
    ------
    ValueError
        If ``which`` contains a name that is not one of the four above.
    """
    # name -> (build from the CASSCF object?, enable cusp options?)
    recipes = {
        'hf_wf_nocusp': (False, False),
        'md_wf_nocusp': (True, False),
        'hf_wf_cusp': (False, True),
        'md_wf_cusp': (True, True),
    }
    which = list(which)
    unknown = [w for w in which if w not in recipes]
    if unknown:
        # An unknown name would otherwise be dropped silently and shift the
        # positions of the returned wavefunctions.
        raise ValueError('unknown wavefunction name(s): %s' % ', '.join(map(str, unknown)))

    mol = gto.M(
        atom=[
            ['H', (d*i, 0, 0)] for i in range(n)
        ],
        unit='bohr',
        basis='6-31G',#'aug-cc-pV5Z',
        charge=0,
        spin=n%2,
        cart=True
    )
    mf = scf.RHF(mol)
    mf.kernel()

    # CASSCF is only run when a multi-determinant wavefunction is requested.
    mc = None
    if any(recipes[w][0] for w in which):
        mc = mcscf.CASSCF(mf, n,n)
        mc.kernel()

    ret = []
    for w in which:
        use_casscf, cusp = recipes[w]
        # 'md_*' variants must be built from mc, including the cusp-corrected one.
        reference = mc if use_casscf else mf
        ret.append(SlaterJastrowNet.from_pyscf(reference,cusp_correction=cusp,cusp_electrons=cusp).cuda())

    return ret

In [None]:
# HF / multi-determinant wavefunctions, without and with cusp options, for
# chains of 2, 6 and 10 atoms at spacing 1.7.
wf_kinds = ['hf_wf_nocusp','md_wf_nocusp','hf_wf_cusp','md_wf_cusp']
hf_2,md_2,hf_2_cusp,md_2_cusp = do_pyscf(n=2,d=1.7,which=wf_kinds)
hf_6,md_6,hf_6_cusp,md_6_cusp = do_pyscf(n=6,d=1.7,which=wf_kinds)
hf_10,md_10,hf_10_cusp,md_10_cusp = do_pyscf(n=10,d=1.7,which=wf_kinds)


In [None]:
# HF wavefunctions (no cusp options) for the two-atom chain at three spacings.
hf_21, hf_22, hf_23 = (
    do_pyscf(n=2,d=spacing,which=['hf_wf_nocusp'])[0]
    for spacing in (0.5, 1.5, 4)
)


In [None]:
# Sampling diagnostics for the HF wavefunctions of the 2-, 6- and 10-atom chains.
analyse_sampling([hf_2,hf_6,hf_10,],n_steps=50,n_walker=100,tau=.4)

In [None]:
# Follow a single two-electron walker of the 2-atom HF wavefunction for 100 steps.
start_positions = sample_start(hf_2.geom.cpu(),n_walker=1,n_electrons=2,var=1)
sampler = langevin_monte_carlo(hf_2, start_positions, tau=0.5)

samples, psis, info = samples_from(sampler,trange(100))


In [None]:
# Merge the two leading axes and move the samples to a NumPy array on the host.
# The guard keeps the cell idempotent: once `samples` is an ndarray, re-running
# the cell would otherwise fail on `flatten(end_dim=1)`.
if torch.is_tensor(samples):
    samples = samples.flatten(end_dim=1).detach().cpu().numpy()

In [None]:
# Bounding box of all sampled x/y positions, shared by the scatter plot and
# the marginal histograms.
x_max=samples[:,:,0].max()
x_min=samples[:,:,0].min()
y_max=samples[:,:,1].max()
y_min=samples[:,:,1].min()

plt.figure(figsize=(8,8))
# The main panel is created first so the top marginal can share its x-axis:
# subplot2grid expects an Axes for `sharex`, not the 'col' shortcut that only
# plt.subplots understands.
ax2 = plt.subplot2grid((3,3), (1,0), colspan=2,rowspan=2)
ax1 = plt.subplot2grid((3,3), (0,0), colspan=2,sharex=ax2)
ax3 = plt.subplot2grid((3,3), (1, 2), rowspan=2)

# Trajectory in the x-y plane: electron 0 in reds, electron 1 in blues, with
# the colour following the sample index.
for i,s in enumerate(samples):
    ax2.plot(s[0,0],s[0,1],ls='',marker='o',color=c2(i/len(samples)))
    ax2.plot(s[1,0],s[1,1],ls='',marker='o',color=c3(i/len(samples)))
# Crosses mark the two nuclei (spacing 1.7, as used when building hf_2).
ax2.plot(1.7,0,ls='',marker='x',ms=20,color='k')
ax2.plot(0,0,ls='',marker='x',ms=20,color='k')
ax2.set_xlim((x_min,x_max))
ax2.set_ylim((y_min,y_max))

# Marginal distributions of x (top) and y (right) for both electrons.
ax1.hist(samples[:,0,0],color='r',alpha=0.5,range=(x_min,x_max))
ax1.hist(samples[:,1,0],color='b',alpha=0.5,range=(x_min,x_max))
#ax1.axis('off')
ax3.hist(samples[:,0,1],color='r',alpha=0.5,orientation="horizontal",bins=25,range=(y_min,y_max))
ax3.hist(samples[:,1,1],color='b',alpha=0.5,orientation="horizontal",bins=25,range=(y_min,y_max))
#ax3.axis('off')
plt.show()