In [None]:
#H2+     Energy = -0.6023424   for R = 1.9972
#fit(batch_size=10000, n_el=1, steps=500, epochs=1, RR=[[-1, 0, 0], [1., 0, 0]])

#H2		 Energy = -1.173427    for R = 1.40
#fit(batch_size=10000,n_el=2,steps=100,epochs=5,RR=torch.tensor([[-0.7,0,0],[0.7,0,0]]))

#He+	 Energy = -1.9998
#fit(batch_size=10000,n_el=1,steps=100,epochs=5,RR=torch.tensor([[0.,0,0]]),RR_charges=[2])

#He		 Energy = −2.90338583
#fit(batch_size=10000,n_el=2,steps=300,epochs=5,RR=torch.tensor([[0.3,0,0]]),RR_charges=[2])

In [None]:
# Autoreload mode 1: only the modules registered via %aimport (the dlqmc
# modules below) are reloaded automatically when their source changes.
%load_ext autoreload
%autoreload 1
%aimport  dlqmc.sampling, dlqmc.utils, dlqmc.nn.base, dlqmc.fit
# Inline figures: SVG output, tight bounding box, 300 dpi.
%config InlineBackend.figure_format = 'svg' 
%config InlineBackend.print_figure_kwargs = \
    {'bbox_inches': 'tight', 'dpi': 300}

In [None]:
import time
from contextlib import closing
from functools import partial

import ipywidgets
import torch.nn as nn
import numpy as np
from scipy import special
import scipy.stats as sps
import matplotlib.pyplot as plt
import torch
#from torch.utils.data import DataLoader, RandomSampler
#from torch.distributions import Normal
from pyscf import gto, scf, dft
import pyscf
from pyscf.data.nist import BOHR
from tqdm.auto import tqdm, trange
from tensorboardX import SummaryWriter

from dlqmc.nn.base import * 
from dlqmc.nn.base import conv_indexing
from dlqmc.geom import *
from dlqmc.nn.gto import *
#from dlqmc.nn import *
from dlqmc.sampling import langevin_monte_carlo, hmc ,samples_from
from dlqmc.fit import *
from dlqmc.nn.anti import *
#from dlqmc.utils import assign_where
from dlqmc.physics import (
    local_energy, grad, quantum_force,nuclear_potential,
    nuclear_energy, laplacian, electronic_potential
)
#from dlqmc.analysis import autocorr_coeff, blocking
from dlqmc.nn import ssp
from dlqmc.nn.hannet import HanNet
from dlqmc.nn.hfnet import HFNet

def normplot(x, y, norm, *args, **kwargs):
    """Plot ``y`` against ``x``, optionally rescaled to unit maximum magnitude.

    Args:
        x: abscissa values.
        y: ordinate values (NumPy array).
        norm: if true, ``y`` is divided by ``max(|y|)`` so that its
            largest-magnitude entry becomes +-1. An all-zero ``y`` is plotted
            unchanged instead of turning into NaNs through a division by zero.
        *args, **kwargs: forwarded to ``plt.plot``.
    """
    if norm:
        scale = np.max(np.abs(y))
        if scale > 0:
            y = y / scale
    plt.plot(x, y, *args, **kwargs)
        
# Report CUDA memory usage (allocated, reserved, peak reserved), then release
# unused cached blocks. memory_cached / max_memory_cached are deprecated
# aliases of memory_reserved / max_memory_reserved (PyTorch >= 1.4); the old
# names are only used when the new ones are unavailable.
print(torch.cuda.memory_allocated())
print((getattr(torch.cuda, 'memory_reserved', None) or torch.cuda.memory_cached)(device=None))
print((getattr(torch.cuda, 'max_memory_reserved', None) or torch.cuda.max_memory_cached)(device=None))
torch.cuda.empty_cache()

In [None]:
import torch
from torch import nn

from dlqmc.utils import NULL_DEBUG, dctsel, triu_flat
from dlqmc.nn.base import (
    SSP,
    BaseWFNet,
    Concat,
    DistanceBasis,
    ElectronicAsymptotic,
    NuclearAsymptotic,
    get_log_dnn,
    pairwise_distance,
)
from dlqmc.nn.schnet import ElectronicSchnet
from dlqmc.nn.hfnet import HFNet
from dlqmc.nn.backflow import Backflow



class SJNet(BaseWFNet):
    """Slater-Jastrow wave function: HF determinant times a SchNet Jastrow factor.

    The antisymmetric part is an ``HFNet`` built from a PySCF mean-field
    object. The Jastrow factor is ``exp(J)``, where ``J`` is the sum over
    electrons of a DNN applied to the SchNet electron embeddings. Optional
    ``ElectronicAsymptotic`` factors act on same-spin and opposite-spin
    electron distances.

    Args:
        geom: molecular geometry (nuclear coordinates and charges).
        n_up: number of spin-up electrons.
        n_down: number of spin-down electrons.
        mf: PySCF mean-field object passed to ``HFNet.from_pyscf``.
        basis_dim, kernel_dim, embedding_dim, n_interactions: sizes of the
            distance basis and of the ``ElectronicSchnet``.
        n_orbital_layers: number of layers of the per-electron Jastrow DNN.
        ion_pot: forwarded to ``NuclearAsymptotic``.
        cusp_same, cusp_anti: ``cusp`` arguments of the same-spin and
            opposite-spin ``ElectronicAsymptotic`` factors; ``None`` disables
            the respective factor.
        **kwargs: ``cutoff`` is forwarded to ``DistanceBasis`` and ``alpha``
            to ``NuclearAsymptotic``; other keys are not used.
    """

    def __init__(
        self,
        geom,
        n_up,
        n_down,
        mf,
        basis_dim=32,
        kernel_dim=64,
        embedding_dim=128,
        n_interactions=3,
        n_orbital_layers=3,
        ion_pot=0.,
        cusp_same=None,
        cusp_anti=None,
        **kwargs,
    ):
        # Bind the GTO-based HFNet explicitly: this notebook later defines a
        # different class named ``HFNet`` whose ``from_pyscf`` requires an
        # extra argument and would otherwise shadow this one at call time.
        from dlqmc.nn.hfnet import HFNet as _GTOHFNet

        super().__init__()
        self.n_up, self.n_down = n_up, n_down
        self.register_geom(geom)
        self.dist_basis = DistanceBasis(basis_dim, **dctsel(kwargs, 'cutoff'))
        # Constructed but currently not applied in forward().
        self.asymp_nuc = NuclearAsymptotic(
            self.charges, ion_pot, **dctsel(kwargs, 'alpha')
        )
        self.asymp_same, self.asymp_anti = (
            ElectronicAsymptotic(cusp=cusp) if cusp is not None else None
            for cusp in (cusp_same, cusp_anti)
        )
        self.schnet = ElectronicSchnet(
            n_up,
            n_down,
            len(geom),
            n_interactions,
            basis_dim,
            kernel_dim,
            embedding_dim,
            interaction_factory=None,
        )
        self.orbital = get_log_dnn(embedding_dim, 1, SSP, n_layers=n_orbital_layers)
        self.anti = _GTOHFNet.from_pyscf(mf)

    def forward(self, rs, debug=NULL_DEBUG):
        """Evaluate the wave function for a batch of electron configurations.

        Args:
            rs: electron coordinates, presumably (batch, n_elec, 3) — TODO confirm.
            debug: container receiving intermediate results under 'schnet',
                'jastrow', 'anti', 'asymp_same' and 'asymp_anti'.

        Returns:
            ``anti * exp(jastrow) * asymp_same * asymp_anti`` per configuration.
        """
        dists_elec = pairwise_distance(rs, rs)
        dists_nuc = pairwise_distance(rs, self.coords[None, ...])
        dists = torch.cat([dists_elec, dists_nuc], dim=2)
        dists_basis = self.dist_basis(dists)
        with debug.cd('schnet'):
            xs = self.schnet(dists_basis, debug=debug)
        # The per-electron DNN outputs are summed into the Jastrow exponent.
        jastrow = debug['jastrow'] = self.orbital(xs).squeeze(dim=-1).sum(dim=-1)
        anti = debug['anti'] = self.anti(rs)
        # The nuclear asymptotic factor (self.asymp_nuc) is intentionally not
        # applied here.

        if self.asymp_same is not None:
            # Flattened triangular distance blocks within each spin channel.
            same_spin_dists = torch.cat(
                [triu_flat(dists_elec[:, idxs, idxs]) for idxs in self.spin_slices],
                dim=1,
            )
            asymp_same = self.asymp_same(same_spin_dists)
        else:
            asymp_same = 1.0
        debug['asymp_same'] = asymp_same

        if self.asymp_anti is not None:
            # All up-down electron distances, flattened per configuration.
            asymp_anti = self.asymp_anti(
                dists_elec[:, : self.n_up, self.n_up :].flatten(start_dim=1)
            )
        else:
            asymp_anti = 1.0
        debug['asymp_anti'] = asymp_anti

        asymp = asymp_same * asymp_anti
        return anti * torch.exp(jastrow) * asymp

In [None]:
# Reference geometries from the dlqmc geometry database.
h2p, h2 = geomdb['H2+'], geomdb['H2']

In [None]:
import numpy as np
import torch
from torch import nn



class SlaterSchnetJastrowNet(BaseWFNet):
    """Slater-Jastrow wave function in which both parts are SchNet-based.

    Two independent ``ElectronicSchnet`` networks embed the electrons:
    - the "anti" network feeds a bias-free linear layer whose outputs form the
      spin-up and spin-down Slater matrices;
    - the Jastrow network feeds a per-electron DNN whose outputs, summed over
      electrons, form the exponent of the Jastrow factor ``exp(J)``.

    Args:
        geom: molecular geometry (nuclear coordinates and charges).
        n_up, n_down: numbers of spin-up / spin-down electrons.
        basis_dim: size of the shared distance basis.
        kernel_dim_jastrow, embedding_dim_jastrow, n_interactions_jastrow:
            sizes of the Jastrow SchNet.
        kernel_dim_anti, embedding_dim_anti, n_interactions_anti: sizes of the
            orbital (antisymmetric) SchNet.
        ion_pot: forwarded to ``NuclearAsymptotic``.
        n_orbital_layers: number of layers of the Jastrow DNN.
        **kwargs: ``cutoff`` is forwarded to ``DistanceBasis`` and ``alpha``
            to ``NuclearAsymptotic``.
    """

    def __init__(self, geom, n_up, n_down,
                 basis_dim=32,
                 kernel_dim_jastrow=64,
                 embedding_dim_jastrow=128,
                 n_interactions_jastrow=3,
                 kernel_dim_anti=64,
                 embedding_dim_anti=128,
                 n_interactions_anti=3,
                 ion_pot=1.,
                 n_orbital_layers=4,
                 **kwargs):
        super().__init__()
        self.n_up, self.n_down = n_up, n_down
        self.register_geom(geom)
        self.dist_basis = DistanceBasis(basis_dim, **dctsel(kwargs, 'cutoff'))
        # Orbital values: one column per orbital of the larger spin channel.
        self.mo = nn.Linear(embedding_dim_anti, max(n_up, n_down), bias=False)
        self.schnet_anti = ElectronicSchnet(
            n_up,
            n_down,
            len(geom),
            n_interactions_anti,
            basis_dim,
            kernel_dim_anti,
            embedding_dim_anti,
            interaction_factory=None,
        )
        self.schnet = ElectronicSchnet(
            n_up,
            n_down,
            len(geom),
            n_interactions_jastrow,
            basis_dim,
            kernel_dim_jastrow,
            embedding_dim_jastrow,
            interaction_factory=None,
        )
        # Constructed (keeps the module/state_dict layout) but currently not
        # applied in forward().
        self.asymp_nuc = NuclearAsymptotic(
            self.charges, ion_pot, **dctsel(kwargs, 'alpha')
        )
        self.orbital = get_log_dnn(
            embedding_dim_jastrow, 1, SSP, n_layers=n_orbital_layers
        )

    def forward(self, rs, debug=NULL_DEBUG):
        """Evaluate ``det_up * det_down * exp(jastrow)`` for a batch of configurations.

        Args:
            rs: electron coordinates, presumably (batch, n_elec, 3) — TODO confirm.
            debug: container receiving 'jastrow', 'det_up' and 'det_down'.
        """
        dists_elec = pairwise_distance(rs, rs)
        dists_nuc = pairwise_distance(rs, self.coords[None, ...])
        dists = torch.cat([dists_elec, dists_nuc], dim=2)
        dists_basis = self.dist_basis(dists)

        # The DNN is a log-type output (cf. get_log_dnn); it must be
        # exponentiated, otherwise its sign would flip the wave function.
        xs = self.schnet(dists_basis)
        jastrow = debug['jastrow'] = self.orbital(xs).squeeze(dim=-1).sum(dim=-1)

        ys = self.mo(self.schnet_anti(dists_basis))
        det_up = debug['det_up'] = eval_slater(ys[:, : self.n_up, : self.n_up])
        det_down = debug['det_down'] = eval_slater(ys[:, self.n_up :, : self.n_down])

        return det_up * det_down * torch.exp(jastrow)


## $H_{10}$

In [None]:
import threading
from dlqmc.sampling import sample_start

Reference energies for H10: Han -5.5685, benchmark -5.6655

In [None]:
# Linear chain of n hydrogen atoms with spacing d; neutral, with the odd
# electron (if any) assigned to the spin-up channel.
d = 1.786
n = 10
hn = Geometry([[d * i, 0., 0.] for i in range(n)], [1.] * n)
print(hn)
molecule = hn
n_electrons = n
n_down = n // 2
n_up = n_electrons - n_down

In [None]:
# Restricted Hartree-Fock reference for the H_n chain (6-31G, Cartesian GTOs),
# wrapped as a GPU wave-function network.
mol = gto.M(
    atom=[['H', (d * i, 0, 0)] for i in range(n)],
    unit='bohr',
    basis='6-31G',  # larger alternative: 'aug-cc-pV5Z'
    charge=0,
    spin=n % 2,
    cart=True,
)
mf = scf.RHF(mol)
mf.kernel()
gtowf = HFNet.from_pyscf(mf).cuda()

In [None]:
# H10 ansatz: Slater matrices from a SchNet orbital network combined with a
# SchNet Jastrow term, all hyperparameters at their defaults.
ssjnet = SlaterSchnetJastrowNet(hn, n_up, n_down).cuda()

# Alternative ansaetze, currently disabled (uncomment to compare):
# hannet = HanNet(
#     hn, n_up, n_down,
#     basis_dim=16, kernel_dim=32, embedding_dim=64,
#     latent_dim=10, n_interactions=3, n_orbital_layers=3,
# ).cuda()
# sjnet = SJNet(
#     hn, n_up, n_down, mf=mf,
#     basis_dim=16, kernel_dim=32, embedding_dim=64,
#     latent_dim=10, n_interactions=3, n_orbital_layers=3,
# ).cuda()
# sjnet_nuc = SJNet(
#     hn, n_up, n_down, mf=mf,
#     basis_dim=16, kernel_dim=32, embedding_dim=64,
#     latent_dim=10, n_interactions=3, n_orbital_layers=3,
#     nuc_asymp=True,
# ).cuda()

In [None]:
# Langevin Monte Carlo samplers for the ssjnet ansatz and the HF reference,
# each started from its own walker configuration produced by sample_start
# (ssjnet's start is drawn first, as before).
n_walkers = 5
n_wlaker = n_walkers  # legacy misspelled name, kept for backward compatibility
samplerssj, samplergto = (
    langevin_monte_carlo(
        wf,
        sample_start(molecule, n_walkers, n_electrons, var=1),
        tau=0.1,
    )
    for wf in (ssjnet, gtowf)
)
# Samplers for the disabled HanNet / SJNet ansaetze, if re-enabled:
# samplerhan = langevin_monte_carlo(hannet, sample_start(molecule, n_walkers, n_electrons, var=1), tau=0.1)
# samplersj = langevin_monte_carlo(sjnet, sample_start(molecule, n_walkers, n_electrons, var=1), tau=0.1)

In [None]:
# Training of ssjnet on H10 in three stages:
#  1. supervised pre-training towards the HF reference gtowf (least squares),
#  2. local-energy loss with E_ref=-6, 500 samplings,
#  3. local-energy loss with E_ref=None, 1000 samplings, 5x gradient accumulation.
# Every stage logs to its own TensorBoard run directory under runs/ (instead
# of all writers sharing runs/ itself), and each writer is closed afterwards
# so that its events are flushed to disk.
with closing(SummaryWriter(comment='_ssj_supervised')) as writer:
    fit_wfnet_supervised(
        ssjnet,
        gtowf,
        loss_least_squares,
        torch.optim.Adam(ssjnet.parameters(), lr=1e-3),
        wfnet_fit_driver(
            samplerssj,
            samplings=range(20),
            n_epochs=1,
            n_sampling_steps=50,
            batch_size=1_000,
            n_discard=30,
            range_sampling=partial(trange, desc='sampling steps', leave=False),
            range_training=partial(trange, desc='training steps', leave=False),
        ),
        writer=writer,
    )

for net, sampler in zip([ssjnet], [samplerssj]):
    with closing(SummaryWriter(comment='_ssj_eloc_ref')) as writer:
        fit_wfnet(
            net,
            partial(loss_local_energy, E_ref=-6, p=1),
            torch.optim.Adam(net.parameters(), lr=1e-3),
            wfnet_fit_driver_simple(
                sampler,
                samplings=trange(500),
                n_sampling_steps=50,
                n_decorrelate=4,
                n_discard=30,
            ),
            clip_grad=None,
            writer=writer,
        )

    with closing(SummaryWriter(comment='_ssj_eloc')) as writer:
        fit_wfnet(
            net,
            partial(loss_local_energy, E_ref=None, p=1),
            torch.optim.Adam(net.parameters(), lr=1e-3),
            wfnet_fit_driver_simple(
                sampler,
                samplings=trange(1000),
                n_sampling_steps=50,
                n_decorrelate=4,
                n_discard=30,
            ),
            clip_grad=None,
            acc_grad=5,
            writer=writer,
        )

In [None]:
# Cut through the H10 wave functions: electron 0 is scanned along the chain
# axis, while electron i+1 sits at x = d*i plus a small random offset
# (i = 0 .. n_electrons-2). Curves are normalised to unit maximum magnitude.
x_line = torch.zeros(500, n_electrons, 3)
x_line[:, 0, 0] = torch.linspace(-d, n_electrons * d, 500)
for i in range(n_electrons - 1):
    x_line[:, i + 1, 0] = d * i + torch.randn(1) / 10
x_line = x_line.cuda()

normed = True
# Only networks constructed in this session are plotted. hannet / sjnet exist
# only when their construction in the model cell is uncommented; then append
# ('hanWF', hannet, {}) and ('sjWF', sjnet, {}) here.
curves = [('ssjWF', ssjnet, {}), ('HF', gtowf, {'ls': ':'})]
with torch.no_grad():  # plotting only, no autograd graph needed
    x_axis = x_line[:, 0, 0].cpu().numpy()
    for label, wf, style in curves:
        normplot(x_axis, wf(x_line).cpu().numpy(), label=label, norm=normed, **style)

plt.axhline(0, ls=':', color='k')
for i in range(n):
    plt.axvline(d * i, ls=':', color='k')  # nuclei
for i in range(1, n_electrons):
    # .item(): matplotlib cannot consume a CUDA tensor directly.
    plt.axvline(x_line[0, i, 0].item(), ls=':', color='r')  # fixed electrons

#plt.ylim(-1,2)
plt.xlabel('x of electron 0')
plt.legend()
plt.show()


In [None]:
# Local-energy histograms of the H10 networks. Only ssjnet is constructed in
# this session; add (sjnet, samplersj) / (hannet, samplerhan) to the lists
# when their construction cells are enabled.
# Local energies are clipped only to tame outliers. A narrow window such as
# (-4, 0) lies above the H10 reference energy (benchmark -5.6655) and would
# clamp most samples, biasing the reported mean.
E_LOC_RANGE = (-40., 10.)
for net, sampler in zip([ssjnet], [samplerssj]):
    t = time.time()
    samples = samples_from(sampler, trange(1000))[0].flatten(end_dim=1)
    print("it took: " + str(time.time() - t))
    E_loc = dlqmc.utils.batch_eval_tuple(
        partial(local_energy, wf=net, geom=net.geom),
        tqdm(samples.view([-1, n_electrons, 3]).split(1000)),
    )[0]
    E_loc = E_loc.clamp(*E_LOC_RANGE)
    mean = E_loc.mean().item()
    h = plt.hist(
        E_loc.detach().cpu().numpy(),
        bins=100,
        alpha=0.5,
        label=("mean = " + str(np.round(mean, 4))),
    )
plt.xlabel("local energy")
plt.legend()
plt.show()

## Scaling

In [None]:
from scipy.optimize import curve_fit as cf

In [None]:
def p2(x, a, c):
    """Quadratic model without linear term, ``a + c * x**2`` (for ``curve_fit``)."""
    quadratic = c * x**2
    return a + quadratic


def p1(x, a, c):
    """Linear model ``a + c * x`` (for ``curve_fit``)."""
    linear = c * x
    return a + linear

In [None]:
# Scaling benchmark: for hydrogen chains with 1..39 atoms, record the number
# of HanNet parameters, CUDA memory statistics, and the mean wall time of a
# forward pass and of a local-energy evaluation. The per-chain work runs in a
# helper function, so the notebook-level H10 names (n, d, hn, n_electrons,
# n_up, n_down) are neither overwritten nor deleted.
n_p = []
m_alloc = []
m_max_alloc = []
m_cached = []
m_max_cached = []
T_forward = []
T_eloc = []

# PyTorch >= 1.4 renamed memory_cached / max_memory_cached to memory_reserved
# / max_memory_reserved; fall back to the old names on older versions.
_memory_reserved = getattr(torch.cuda, 'memory_reserved', None) or torch.cuda.memory_cached
_max_memory_reserved = (
    getattr(torch.cuda, 'max_memory_reserved', None) or torch.cuda.max_memory_cached
)
# Peak statistics are reset per chain, so m_max_* are per-run peaks rather
# than the running maximum over all previous chains.
_reset_peak_memory = getattr(torch.cuda, 'reset_peak_memory_stats', None) or getattr(
    torch.cuda, 'reset_max_memory_allocated', None
)


def _mean_cuda_time(fn, n_repeat=10):
    """Return the mean wall time of ``fn()`` over ``n_repeat`` calls.

    CUDA kernels are launched asynchronously, so the device is synchronised
    around every call; otherwise mostly the launch overhead would be timed.
    """
    times = []
    for _ in range(n_repeat):
        torch.cuda.synchronize()
        start = time.time()
        fn()
        torch.cuda.synchronize()
        times.append(time.time() - start)
    return np.mean(np.array(times))


def _benchmark_h_chain(n_atoms, spacing=1.5):
    """Benchmark a small HanNet on a linear H_n chain; return a dict of statistics."""
    geom = Geometry([[spacing * i, 0., 0.] for i in range(n_atoms)], [1.] * n_atoms)
    n_up = n_atoms // 2
    n_down = n_atoms - n_up
    net = HanNet(
        geom,
        n_up,
        n_down,
        basis_dim=8,
        kernel_dim=16,
        embedding_dim=32,
        latent_dim=5,
        n_interactions=2,
        n_orbital_layers=3,
    ).cuda()
    n_params = sum(p.numel() for p in net.parameters())
    sampler = langevin_monte_carlo(
        net,
        torch.randn(10, n_atoms, 3, device='cuda'),
        tau=0.1,
    )
    samples = samples_from(sampler, range(10))[0].flatten(end_dim=1)
    t_forward = _mean_cuda_time(lambda: net(samples))
    t_eloc = _mean_cuda_time(lambda: local_energy(samples, net))
    # Memory is read while the net and the samples are still alive.
    return {
        'n_params': n_params,
        't_forward': t_forward,
        't_eloc': t_eloc,
        'alloc': torch.cuda.memory_allocated(),
        'max_alloc': torch.cuda.max_memory_allocated(),
        'cached': _memory_reserved(device=None),
        'max_cached': _max_memory_reserved(device=None),
    }


for n_atoms in trange(1, 40):
    torch.cuda.empty_cache()
    if _reset_peak_memory is not None:
        _reset_peak_memory()
    stats = _benchmark_h_chain(n_atoms)
    n_p.append(stats['n_params'])
    T_forward.append(stats['t_forward'])
    T_eloc.append(stats['t_eloc'])
    m_alloc.append(stats['alloc'])
    m_max_alloc.append(stats['max_alloc'])
    m_cached.append(stats['cached'])
    m_max_cached.append(stats['max_cached'])
    # The net, sampler and samples went out of scope with the helper call.
    torch.cuda.empty_cache()


In [None]:
# Scaling plots: each panel shows the measurements (x markers) and a
# least-squares fit of a linear (p1) or quadratic (p2) model versus the
# number of hydrogen atoms.
x = np.arange(1, len(n_p) + 1)  # chain lengths 1, 2, ... as benchmarked
fig, axes = plt.subplots(2, 2, figsize=(10, 6))
# Figure-level title: plt.title() before subplot2grid put the title on a
# separate full-figure axes instead of above the four panels.
fig.suptitle("scaling")
panels = [
    (axes[0, 0], n_p, p1, "# parameters"),
    (axes[0, 1], m_max_alloc, p2, "peak memory allocated [bytes]"),
    (axes[1, 0], T_forward, p1, "time forward [s]"),
    (axes[1, 1], T_eloc, p2, "time local energy [s]"),
]
for ax, values, model, ylabel in panels:
    values = np.asarray(values, dtype=float)
    params, _ = cf(model, x, values)
    ax.plot(x, model(x, *params))
    ax.plot(x, values, ls='', marker='x')
    ax.set_xlabel("# hydrogen atoms")
    ax.set_ylabel(ylabel)
fig.tight_layout(rect=(0, 0, 1, 0.95))
plt.show()

# Backflow

In [None]:
# Neutral boron atom at the origin (spin=1: one unpaired electron),
# 4-31G basis with Cartesian GTOs.
mol = gto.M(
    atom=[['B', (0, 0, 0)]],
    unit='bohr',
    basis='4-31G',
    cart=True,
    charge=0,
    spin=1,
)
mf = scf.RHF(mol)
mf.kernel()


In [None]:
# Boron geometry: a single nucleus of charge 5 at the origin. It is built as
# a fresh Geometry rather than by overwriting the private _coords/_charges of
# geomdb['H']. If the database returns a shared object, that would silently
# turn its 'H' entry into boron, and re-running the cell would mutate it again.
bohr = Geometry([[0., 0., 0.]], [5.])

In [None]:
import numpy as np
import torch
from torch import nn

from dlqmc.geom import Geometry
from dlqmc.utils import NULL_DEBUG
from dlqmc.nn.anti import eval_slater
from dlqmc.nn.base import BaseWFNet
from dlqmc.nn.base import DistanceBasis
from dlqmc.nn.gto import GTOBasis
from dlqmc.nn.backflow import Backflow


class HFNet(BaseWFNet):
    """Slater determinants of GTO orbitals evaluated at backflow coordinates.

    Molecular orbitals are a bias-free linear map of the GTO basis,
    initialised from PySCF MO coefficients. Electron coordinates are first
    transformed by a ``Backflow`` network. Spin-up and spin-down determinants
    are built with ``eval_slater`` and multiplied.

    Args:
        geom: molecular geometry.
        n_up, n_down: numbers of spin-up / spin-down electrons.
        basis: GTO basis module (provides ``dim``).
        n_interactions: forwarded to ``Backflow``.
    """

    def __init__(self, geom, n_up, n_down, basis, n_interactions):
        super().__init__()
        self.n_up, self.n_down = n_up, n_down
        self.register_geom(geom)
        self.basis = basis
        self.mo = nn.Linear(basis.dim, max(n_up, n_down), bias=False)
        # NOTE(review): the positional arguments 4 and 20 are interpreted by
        # Backflow's signature — TODO give them names.
        self.backflow = Backflow(n_up, n_down, n_interactions, 4, 20)

    def init_from_pyscf(self, mf):
        """Copy the first ``max(n_up, n_down)`` PySCF MO coefficients into ``self.mo``.

        For Cartesian bases the coefficients are scaled by the square root of
        the diagonal of the Cartesian AO overlap, presumably to match the AO
        normalisation of ``GTOBasis`` — TODO confirm.
        """
        mo_coeff = mf.mo_coeff.copy()
        if mf.mol.cart:
            mo_coeff *= np.sqrt(np.diag(mf.mol.intor('int1e_ovlp_cart')))[:, None]
        self.mo.weight.detach().copy_(
            torch.from_numpy(mo_coeff[:, : max(self.n_up, self.n_down)].T)
        )

    @classmethod
    def from_pyscf(cls, mf, n_interactions):
        """Build the network from a PySCF mean-field object.

        ``n_down`` counts doubly occupied orbitals and ``n_up`` all occupied
        orbitals; occupations must be ordered doubly, singly, then unoccupied
        (asserted).
        """
        n_up = (mf.mo_occ >= 1).sum()
        n_down = (mf.mo_occ == 2).sum()
        assert (mf.mo_occ[:n_down] == 2).all()
        assert (mf.mo_occ[n_down:n_up] == 1).all()
        assert (mf.mo_occ[n_up:] == 0).all()
        geom = Geometry(mf.mol.atom_coords().astype('float32'), mf.mol.atom_charges())
        basis = GTOBasis.from_pyscf(mf.mol)
        wf = cls(geom, n_up, n_down, basis, n_interactions)
        wf.init_from_pyscf(mf)
        return wf

    def forward(self, rs, debug=NULL_DEBUG):
        """Return ``det_up * det_down`` for a batch of electron configurations.

        Args:
            rs: electron coordinates; ``rs.shape[:2]`` is (batch, n_elec).
            debug: container receiving 'backflow', 'aos', 'slaters',
                'det_up' and 'det_down'.
        """
        batch_dim, n_elec = rs.shape[:2]
        # The AO basis is evaluated at the backflow-transformed coordinates.
        # The original code discarded this output, so n_interactions had no effect.
        # NOTE(review): assumes Backflow returns coordinates of the same shape
        # as rs — confirm against dlqmc.nn.backflow.
        rs_bf = debug['backflow'] = self.backflow(rs)
        xs = debug['aos'] = self.basis(rs_bf.flatten(end_dim=1)).view(batch_dim, n_elec, -1)
        xs = debug['slaters'] = self.mo(xs)
        det_up = debug['det_up'] = eval_slater(xs[:, : self.n_up, : self.n_up])
        det_down = debug['det_down'] = eval_slater(xs[:, self.n_up :, : self.n_down])
        return det_up * det_down

    def orbitals(self, rs):
        """Orbital values at the raw positions ``rs`` (no backflow applied)."""
        return self.mo(self.basis(rs))

    def density(self, rs):
        """Sum of squared occupied-orbital values over both spin channels (no backflow)."""
        xs = self.orbitals(rs)
        return sum(
            (xs[:, :n_elec] ** 2).sum(dim=-1) for n_elec in (self.n_up, self.n_down)
        )


In [None]:
# Boron: electron counts are taken from the PySCF molecule (alpha -> up,
# beta -> down) instead of being hard-coded, so they always match `mf`.
n_up, n_down = (int(n_spin) for n_spin in mf.mol.nelec)
n_electrons = n_up + n_down
molecule = bohr

# Backflow HF nets with 0, 1 and 3 backflow interaction layers.
net0 = HFNet.from_pyscf(mf, n_interactions=0).cuda()
net1 = HFNet.from_pyscf(mf, n_interactions=1).cuda()
net3 = HFNet.from_pyscf(mf, n_interactions=3).cuda()

In [None]:
# One Langevin sampler per backflow net, each starting from 1000 walkers
# with standard-normal electron positions (drawn in the order net0, net1, net3).
sampler0, sampler1, sampler3 = (
    langevin_monte_carlo(
        wf,
        torch.randn(1000, n_electrons, 3, device='cuda'),
        tau=0.1,
    )
    for wf in (net0, net1, net3)
)

In [None]:
# Optimise net0 and net1 on their local energies in two stages:
#  1. E_ref=-1.1, p=2, one sampling round,
#  2. E_ref=None, p=1, three sampling rounds.
# NOTE(review): E_ref=-1.1 matches H2-like energies rather than boron — confirm.
# Each stage logs to its own TensorBoard run directory under runs/ and closes
# its writer so that all events are flushed.
for net, sampler in zip([net0, net1], [sampler0, sampler1]):
    with closing(SummaryWriter(comment='_bf_eloc_ref')) as writer:
        fit_wfnet(
            net,
            partial(loss_local_energy, E_ref=-1.1, p=2),
            torch.optim.Adam(net.parameters(), lr=1e-3),
            wfnet_fit_driver(
                sampler,
                samplings=range(1),
                n_epochs=1,
                n_sampling_steps=200,
                batch_size=1_000,
                n_discard=50,
                range_sampling=partial(trange, desc='sampling steps', leave=False),
                range_training=partial(trange, desc='training steps', leave=False),
            ),
            clip_grad=None,
            writer=writer,
        )

    with closing(SummaryWriter(comment='_bf_eloc')) as writer:
        fit_wfnet(
            net,
            partial(loss_local_energy, E_ref=None, p=1),
            torch.optim.Adam(net.parameters(), lr=1e-3),
            wfnet_fit_driver(
                sampler,
                samplings=range(3),
                n_epochs=1,
                n_sampling_steps=200,
                batch_size=1_000,
                n_discard=50,
                range_sampling=partial(trange, desc='sampling steps', leave=False),
                range_training=partial(trange, desc='training steps', leave=False),
            ),
            clip_grad=None,
            writer=writer,
        )


In [None]:
# Cut through the boron wave functions: electron 0 is scanned along x in
# [-5, 5] while electrons 1-4 are held at fixed positions near the nucleus.
x_line = torch.zeros(599, 3 * n_electrons)
x_line[:, 0] = torch.linspace(-5, 5, 599)
# Flattened (x, y, z) coordinates of electrons 1-4.
x_line[:, 3:15] = torch.tensor([
    0.4, 0.1, -0.1,     # electron 1
    -0.2, -0.2, -0.2,   # electron 2
    0., 0.4, 0.,        # electron 3
    0., -0.2, 0.,       # electron 4
])
x_line = x_line.view(-1, n_electrons, 3).cuda()
x_line.requires_grad = True

normed = True
x_axis = x_line[:, 0, 0].cpu().detach().numpy()
for net, label in ((net0, "WF0"), (net1, "WF1")):  # net3 is not plotted
    normplot(x_axis, net(x_line).cpu().detach().numpy(), label=label, norm=normed, lw=2)
plt.axhline(0, ls=':', color='k')
plt.axvline(0, ls=':', color='k')
plt.legend()
plt.show()


In [None]:
# Local-energy histograms (clipped to [-40, 10]) of the trained backflow nets,
# labelled with their run index and mean.
for i, (net, sampler) in enumerate(zip([net0, net1], [sampler0, sampler1]), start=1):
    t = time.time()
    samples = samples_from(sampler, range(100))[0].flatten(end_dim=1)
    print("it took: " + str(time.time() - t))

    E_loc = dlqmc.utils.batch_eval_tuple(
        partial(local_energy, wf=lambda x: net(x), geom=net.geom),
        tqdm(samples.view([-1, n_electrons, 3]).split(1000)),
    )[0]
    E_loc = E_loc.clamp(-40, 10)
    mean = E_loc.mean().item()
    h = plt.hist(
        E_loc.detach().cpu().numpy(),
        bins=100,
        alpha=0.5,
        label=(str(i) + ": mean = " + str(np.round(mean, 4))),
    )
plt.legend()
plt.show()

In [None]:
# Draw 100 steps from the net3 sampler, then show a 2D histogram of all
# electrons' x-y positions within [-1, 1]^2.
t = time.time()
samples = samples_from(sampler3, range(100))[0].flatten(end_dim=1)
print("it took: " + str(time.time() - t))

plt.hist2d(
    *(samples[:, :, k].cpu().flatten().detach().numpy() for k in (0, 1)),
    bins=100,
    range=[[-1, 1], [-1, 1]],
)
plt.gca().set_aspect(1)

In [None]:
# Local energies of net0, evaluated in chunks of 1000 configurations.
# NOTE(review): `samples` were drawn from sampler3 (net3) in the previous
# cell while net0 is evaluated here — confirm this pairing is intended.
E_loc = dlqmc.utils.batch_eval_tuple(
    partial(local_energy, wf=lambda x: net0(x), geom=net0.geom),
    tqdm(samples.view([-1, n_electrons, 3]).split(1000)),
)[0]

In [None]:
# Histogram of the local energies clipped to [-40, 10], annotated with their
# mean and variance, and saved to lastruneloc.png. The clipped values are
# computed once and reused.
E_loc_clipped = E_loc.detach().clamp(-40, 10)
mean = np.round(E_loc_clipped.mean().item(), 4)
E_loc_clipped_np = E_loc_clipped.cpu().numpy()
var = np.round(np.var(E_loc_clipped_np), 4)

h = plt.hist(E_loc_clipped_np, bins=100, alpha=0.5, color='b')
# Labels are placed in axes-fraction coordinates. The former data position
# x = -0.3 can lie outside the histogram's x-range (e.g. boron energies are
# far below it); matplotlib then clips the annotation and it is not shown.
ax = plt.gca()
ax.annotate("mean = " + str(mean), (0.02, 0.95), xycoords='axes fraction', va='top', color='b')
ax.annotate("var     = " + str(var), (0.02, 0.88), xycoords='axes fraction', va='top', color='b')

plt.savefig('lastruneloc.png')
plt.show()

## Supervised fitting

In [None]:
# Supervised-fitting setup for H10. Earlier cells (e.g. the boron section)
# rebind molecule, n_up, n_down and n_electrons, so the H10 setup is
# re-derived here from the HF reference gtowf instead of relying on them.
# NOTE(review): assumes gtowf exposes coords/charges (via register_geom) and
# n_up/n_down like the HFNet class in this notebook — confirm.
molecule = Geometry(gtowf.coords.cpu(), gtowf.charges.cpu())
n_up, n_down = int(gtowf.n_up), int(gtowf.n_down)
n_electrons = n_up + n_down

ssjnet = SlaterSchnetJastrowNet(molecule, n_up, n_down).cuda()

# 2000 walkers per sampler; ssjnet's start configuration is drawn first.
samplerssj = langevin_monte_carlo(
    ssjnet,
    sample_start(molecule, 2000, n_electrons, var=1),
    tau=0.1,
)
samplergto = langevin_monte_carlo(
    gtowf,
    sample_start(molecule, 2000, n_electrons, var=1),
    tau=0.1,
)

In [None]:
# Supervised pre-training of ssjnet towards the HF reference gtowf with a
# least-squares loss: one sampling round of 50 steps, batches of 1000.
# The TensorBoard writer gets its own run directory under runs/ and is closed
# afterwards so that all events are flushed.
with closing(SummaryWriter(comment='_ssj_supervised')) as writer:
    fit_wfnet_supervised(
        ssjnet,
        gtowf,
        loss_least_squares,
        torch.optim.Adam(ssjnet.parameters(), lr=1e-3),
        wfnet_fit_driver(
            samplerssj,
            samplings=range(1),
            n_epochs=1,
            n_sampling_steps=50,
            batch_size=1_000,
            n_discard=30,
            range_sampling=partial(trange, desc='sampling steps', leave=False),
            range_training=partial(trange, desc='training steps', leave=False),
        ),
        writer=writer,
    )



In [None]:
# Compare the supervised-fitted ssjnet with the HF reference along a cut.
# Electron 0 is scanned along the chain axis; electron i+1 sits at
# x = d*(i+1) plus a random offset with standard deviation 0.5.
# Curves are normalised to unit maximum magnitude.
n_elec = int(ssjnet.n_up + ssjnet.n_down)  # electron count ssjnet was built for
x_line = torch.zeros(500, n_elec, 3)
x_line[:, 0, 0] = torch.linspace(-d, n_elec * d, 500)
for i in range(n_elec - 1):
    x_line[:, i + 1, 0] = d * (i + 1) + torch.randn(1) / 2
x_line = x_line.cuda()

normed = True
with torch.no_grad():  # plotting only, no autograd graph needed
    x_axis = x_line[:, 0, 0].cpu().numpy()
    normplot(x_axis, ssjnet(x_line).cpu().numpy(), label="ssjnet", norm=normed)
    normplot(x_axis, gtowf(x_line).cpu().numpy(), label="gtoWF", norm=normed)

plt.axhline(0, ls=':', color='k')
# Nuclei of the HF reference geometry (module-level `n` may have been rebound).
for x_nuc in gtowf.coords[:, 0].tolist():
    plt.axvline(x_nuc, ls=':', color='k')
for i in range(1, n_elec):
    # .item(): matplotlib cannot consume a CUDA tensor directly.
    plt.axvline(x_line[0, i, 0].item(), ls=':', color='r')

#plt.ylim(-1,2)
plt.xlabel('x of electron 0')
plt.legend()
plt.show()