In [None]:
#H2+     Energy = -0.6023424   for R = 1.9972
#fit(batch_size=10000, n_el=1, steps=500, epochs=1, RR=[[-1, 0, 0], [1., 0, 0]])

#H2		 Energy = -1.173427    for R = 1.40
#fit(batch_size=10000,n_el=2,steps=100,epochs=5,RR=torch.tensor([[-0.7,0,0],[0.7,0,0]]))

#He+	 Energy = -1.9998
#fit(batch_size=10000,n_el=1,steps=100,epochs=5,RR=torch.tensor([[0.,0,0]]),RR_charges=[2])

#He		 Energy = -2.90338583
#fit(batch_size=10000,n_el=2,steps=300,epochs=5,RR=torch.tensor([[0.3,0,0]]),RR_charges=[2])

In [None]:
%load_ext autoreload
%autoreload 1
%aimport  dlqmc.sampling, dlqmc.utils, dlqmc.nn.base, dlqmc.fit
%config InlineBackend.figure_format = 'svg' 
%config InlineBackend.print_figure_kwargs = \
    {'bbox_inches': 'tight', 'dpi': 300}

In [None]:
import ipywidgets
import torch.nn as nn
import numpy as np
from scipy import special
import scipy.stats as sps
import matplotlib.pyplot as plt
import torch
#from torch.utils.data import DataLoader, RandomSampler
#from torch.distributions import Normal
from pyscf import gto, scf, dft
import pyscf
from pyscf.data.nist import BOHR
import time
from functools import partial
from tqdm.auto import tqdm, trange
from tensorboardX import SummaryWriter

from dlqmc.nn.base import * 
from dlqmc.geom import *
from dlqmc.nn.gto import *
from dlqmc.nn import *
from dlqmc.sampling import langevin_monte_carlo, hmc ,samples_from
from dlqmc.fit import *
from dlqmc.nn.anti import *
#from dlqmc.utils import assign_where
from dlqmc.physics import (
    local_energy, grad, quantum_force,nuclear_potential,
    nuclear_energy, laplacian, electronic_potential
)
#from dlqmc.analysis import autocorr_coeff, blocking
from dlqmc.nn import ssp
from dlqmc.nn.hannet import HanNet

In [None]:
# Look up the H2+ and H2 entries in `geomdb` (name comes from one of the
# star-imports above, presumably dlqmc.geom — verify).
h2p = geomdb['H2+']
h2 = geomdb['H2']

In [None]:
def normplot(x,y,norm,*args,**kwargs):
    """Plot ``y`` against ``x``, optionally rescaled to unit peak magnitude.

    Parameters
    ----------
    x, y : array-like
        Abscissa and ordinate values handed to ``plt.plot``.
    norm : bool
        When true, ``y`` is divided by ``max(|y|)`` before plotting.
    *args, **kwargs
        Forwarded unchanged to ``plt.plot``.
    """
    if norm:
        y = np.asarray(y)
        # Empty or identically-zero data has no usable peak; plot it unscaled
        # instead of raising (empty) or producing 0/0 = NaN (all zeros).
        peak = np.max(np.abs(y)) if y.size else 0
        if peak > 0:
            y = y / peak
    plt.plot(x,y,*args,**kwargs)

In [None]:
class Net_pair(nn.Module):
    """Small perceptron applied to a concatenated pair of inputs.

    ``forward`` joins ``x1`` and ``x2`` along the last axis and pushes the
    result through ``Linear(6, 10) -> SSP -> Linear(10, 10)``.
    """

    def __init__(self):
        super().__init__()
        layers = [
            nn.Linear(6, 10),
            SSP(),
            nn.Linear(10, 10),
        ]
        self.NN1 = nn.Sequential(*layers)

    def forward(self, x1, x2):
        """Concatenate both inputs on the last dimension and apply ``NN1``."""
        pair = torch.cat((x1, x2), dim=-1)
        return self.NN1(pair)
    
class Net(nn.Module):
    """Perceptron ``Linear(10, 10) -> SSP -> Linear(10, 1)`` with sigmoid output.

    ``forward`` flattens the network output to one dimension and squashes it
    with a sigmoid.
    """

    def __init__(self):
        super().__init__()
        self.NN1 = nn.Sequential(
            nn.Linear(10, 10),
            SSP(),
            nn.Linear(10, 1),
        )

    def forward(self, x):
        """Return ``sigmoid`` of the flattened ``NN1`` output."""
        logits = self.NN1(x)
        return torch.sigmoid(logits.flatten())

    
class WFNetAnti(nn.Module):
    """Wavefunction ansatz: nuclear asymptotic * exp(deep net) * antisymmetric part.

    Features for the deep net are built from electron-nucleus and
    electron-electron distances expanded in a ``DistanceBasis``.

    Parameters
    ----------
    geom
        Object exposing ``coords`` and ``charges``; both are registered as buffers.
    n_electrons : int
        Number of electrons; fixes the number of distance pairs.
    net, net_pair
        Modules handed to ``AntisymmetricPart``.
    ion_pot, alpha
        Forwarded to ``NuclearAsymptotic``.
    cutoff
        Accepted but not used anywhere in this class.
    n_dist_feats : int
        Number of distance features per pair.
    """

    def __init__(
        self,
        geom,
        n_electrons,
        net,
        net_pair,
        ion_pot=0.5,
        cutoff=10.0,
        n_dist_feats=32,
        alpha=1.0,
    ):
        super().__init__()
        self.dist_basis = DistanceBasis(n_dist_feats)
        self.register_buffer('coords', geom.coords)
        self.register_buffer('charges', geom.charges)
        self.nuc_asymp = NuclearAsymptotic(self.charges, ion_pot, alpha=alpha)

        # One feature group per electron-nucleus pair and per unordered
        # electron-electron pair.
        n_atoms = len(geom.charges)
        n_el_nuc = n_electrons * n_atoms
        n_el_el = n_electrons * (n_electrons - 1) // 2
        n_pairs = n_el_nuc + n_el_el

        # Input layer, three hidden 64 -> 64 layers, scalar output; layers are
        # created in the same order as they appear in the Sequential.
        width = 64
        layers = [nn.Linear(n_pairs * n_dist_feats, width), SSP()]
        for _ in range(3):
            layers += [nn.Linear(width, width), SSP()]
        layers.append(nn.Linear(width, 1))
        self.deep_lin = nn.Sequential(*layers)

        self.antisym = AntisymmetricPart(net, net_pair)
        self._pdist = PairwiseDistance3D()
        self._psdist = PairwiseSelfDistance3D()

    def _featurize(self, rs):
        """Return flattened distance features and the raw ``(nuc, el)`` distances."""
        nuc = self._pdist(rs, self.coords[None, ...])
        el = self._psdist(rs)
        joined = torch.cat([nuc.flatten(start_dim=1), el], dim=1)
        feats = self.dist_basis(joined).flatten(start_dim=1)
        return feats, (nuc, el)

    def forward(self, rs):
        """Evaluate the ansatz at electron positions ``rs``."""
        feats, (dists_nuc, _) = self._featurize(rs)
        log_sym = self.deep_lin(feats).squeeze(dim=1)
        return self.nuc_asymp(dists_nuc) * torch.exp(log_sym) * self.antisym(rs)


## HF WF

In [None]:
# PySCF reference calculation: two H atoms 1.484 bohr apart in the 4-31G basis.
# NOTE(review): spin=2 requests two unpaired electrons (a triplet in PySCF's
# convention); this matches n_up=2, n_down=0 used further down, but differs
# from the singlet H2 energy quoted in the header cell — confirm it is intended.
mol = gto.M(
    atom=[
        ['H', (0, 0, 0)],
        ['H', (1.484, 0, 0)]
    ],
    unit='bohr',
    basis='4-31G',
    charge=0,
    spin=2,
)
mf = scf.RHF(mol)
# Run the SCF iterations; `mf` is consumed by TorchGTOSlaterWF in a later cell.
mf.kernel()


In [None]:
#gtowf.get_aos(torch.randn(1, 3))

In [None]:
# Wrap the converged mean-field object as a callable wavefunction; it is used
# below both as the supervised target and as the sampling distribution.
gtowf = TorchGTOSlaterWF(mf)

### Supervised

In [None]:
# Supervised pre-training: fit net(r)**2 to gtowf(r)**2 on samples drawn from
# gtowf with Langevin Monte Carlo.
n_electrons=2
molecule = h2

Onet = Net().cuda()
Pnet = Net_pair().cuda()
net = WFNetAnti(molecule,n_electrons,Onet,Pnet,ion_pot=0.7).cuda()

L = []  # loss per optimisation step; plotted by a later cell
V = []  # unused here (variance tracking is commented out in the loop)

# 1D probe for the progress plots: x of electron 1 scanned over [-3, 3],
# every other coordinate zero.
x_line = torch.cat((torch.linspace(-3, 3, 500)[:, None], torch.zeros((500, 3*n_electrons-1))), dim=1)
x_line=x_line.view(-1,n_electrons,3).cuda()
#mesh = get_3d_cube_mesh([(-6, 6), (-4, 4), (-4, 4)], [600, 400, 400])

opt = torch.optim.Adam(net.parameters(), lr=1e-2)
t_start=time.time()
scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=1, gamma=0.999)

steps = 200
batchsize = 50_000
n_resamplings = 100
n_walker = 1_000

# Derived periods, clamped to >= 1 so a small `steps` cannot cause a
# modulo-by-zero in the loop below.
steps_per_resampling = max(1, steps//n_resamplings)
plot_every = max(1, steps//4)

sampler = langevin_monte_carlo(
    gtowf,
    torch.randn(n_walker, n_electrons, 3, device='cuda'),
    tau=0.1,
)


#temporary
# Moves the geometry's private tensors to the GPU in place; this also affects
# every other reference to the same geometry object (e.g. `h2`).
molecule._coords=molecule._coords.cuda()
molecule._charges=molecule._charges.cuda()

for i_step in range(steps):

    # Plot |psi|^2 along the probe line four times during training and once at
    # the very last step (the original test `i_step==steps` could never fire).
    if i_step%plot_every == 0 or i_step==steps-1:
        with torch.no_grad():
            Psi2 = net(x_line)**2
            plt.plot(x_line[:,0 , 0].cpu().detach().numpy(), Psi2.cpu().detach().numpy(),label=i_step)

    # Draw a fresh, shuffled pool of samples every `steps_per_resampling` steps.
    if i_step%steps_per_resampling==0:
        print("resample                                                                        ",end="\r")
        rs,rs_psis  = samples_from(sampler,range(int(batchsize*steps/(n_resamplings*n_walker))))[0:-1]
        rs = rs.flatten(end_dim=1).cuda()
        rs_psis = rs_psis.flatten(end_dim=1).cuda()
        idx = torch.randperm(len(rs))
        rs = rs[idx]
        rs_psis = rs_psis[idx]

    # Consecutive steps between two resamplings consume consecutive batches.
    i_batch = i_step%steps_per_resampling
    r=rs[i_batch*batchsize:(i_batch+1)*batchsize]
    loss = torch.sum((net(r)**2-gtowf(r).cuda()**2)**2)

    print("Progress {:2.0%}".format(i_step /steps)+"   ->"+"I"*(int(i_step/steps*100)//10)+"i"*(int(i_step/steps*100)%10)+"  "+"current loss = "+str(np.round(loss.item(),4))+"        ", end="\r")


    loss.backward()
    L.append(loss.cpu().detach().numpy())
    #V.append(((E_loc**2-E_loc.mean()**2).mean()).cpu().detach().numpy())

    opt.step()
    opt.zero_grad()
    # Decay the learning rate only after the optimizer update, as PyTorch
    # expects; stepping the scheduler first skips the initial learning rate.
    scheduler.step()

plt.legend()
print("it took ="+str(np.round(time.time()-t_start,5))+"                    ")
    


In [None]:
# Probe line: x of electron 1 scanned over [-5, 5]; the x coordinate of
# electron 2 is fixed at 1.484/2; every other coordinate is zero.
x_line = torch.cat((torch.linspace(-5, 5, 5000)[:, None], torch.zeros((5000, 3*n_electrons-1))), dim=1)
x_line[:,3] = 1.484/2
x_line = x_line.view(-1,n_electrons,3).cuda()
x_line.requires_grad = True
net.cuda()
f_line = net._featurize(x_line)  # only needed for per-factor plots (currently disabled)
normed = True

# Evaluate both densities once, then draw them against the same abscissa.
xs_np = x_line[:, 0, 0].cpu().detach().numpy()
psi2_net = net(x_line).cpu().detach().numpy()**2
psi2_gto = gtowf(x_line.cpu().detach()).numpy()**2
normplot(xs_np, psi2_net, label="WF", norm=normed, lw=2, color='k')
normplot(xs_np, psi2_gto, label="gtowf", norm=normed, lw=2, color='grey')

# Dotted guide lines through the origin.
for draw_guide in (plt.axhline, plt.axvline):
    draw_guide(0, ls=':', color='k')

x_line.requires_grad = False
plt.legend()
plt.show()

# Loss history in two stacked log-scale panels: first tenth of the steps on
# top, the remainder below.
n_head = steps//10
for row, chunk in enumerate((L[:n_head], L[n_head:])):
    plt.subplot2grid((2,1),(row,0))
    plt.plot(chunk)
    plt.yscale('log')


In [None]:
# Compare |psi|^2 of the GTO reference and of the fitted net along each of the
# six electron coordinates (all other coordinates held at zero).
plt.figure(figsize=(12,4))
grid = np.linspace(-5, 5, 500)
# The net lives on the GPU after training; evaluate it on its own device.
net_device = next(net.parameters()).device
for i in range(6):
    plt.subplot2grid((2,3),(i//3,i%3))
    x = torch.zeros(500, 6)
    x[:,i] = torch.linspace(-5, 5, 500)
    x = x.view(-1,2,3)
    plt.title("electron " +str(i//3+1))
    plt.plot(grid,gtowf(x).detach().numpy()**2)
    plt.plot(grid,net(x.to(net_device)).cpu().detach().numpy()**2)
    plt.axhline(0,ls=':',color='k')
    #plt.axis('off')
# Write the finished figure once, after all six panels are drawn.
plt.savefig("supervised.png")
    #plt.axis('off')

In [None]:
# Per-coordinate scan of the fitted net only (no GTO reference).
# The `try:` header had been commented out, which left an indented body and a
# dangling `except:`; it is restored here, and failures are reported instead of
# being silently swallowed.
try:
    net.cuda()
    plt.figure(figsize=(12,4))
    for i in range(6):
        plt.subplot2grid((2,3),(i//3,i%3))
        x = torch.zeros(500, 6)
        x[:,i] = torch.linspace(-5, 5, 500)
        x = x.view(-1,2,3)
        plt.title("electron " +str(i//3+1))
        plt.plot(np.linspace(-5, 5, 500),net(x.cuda()).cpu().detach().numpy()**2)
        plt.axhline(0,ls=':',color='k')
        #plt.axis('off')
except Exception as err:
    # Best-effort plot: keep the notebook running but say what went wrong.
    print("per-coordinate net plot skipped: "+repr(err))

In [None]:
# Contour plots of gtowf and the net on a 2D grid: electron 1 at (a, 1, 1) and
# electron 2 at (b, 1, 1) with a, b scanned over [-5, 5].
try:
    n_grid = 500
    axis = np.linspace(-5, 5, n_grid)
    G = np.array(np.meshgrid(axis, axis)).T.reshape(-1,2)
    # Pad with ones for the four remaining coordinates; the row count follows
    # the grid instead of a hard-coded 250000.
    F = np.append(G,np.ones((len(G),4)),axis=-1)
    # Reorder columns to (x1, y1, z1, x2, y2, z2).
    H = np.append(F[:,[0,2,4]],F[:,[1,3,5]],axis=-1)
    W1 = gtowf(torch.from_numpy(H).view(-1,2,3)).view(n_grid,n_grid).numpy()
    W2 = net(torch.from_numpy(H).view(-1,2,3).type(torch.FloatTensor).cuda()).view(n_grid,n_grid).cpu().detach().numpy()
    levels=30
    plt.figure(figsize=(8,3))
    plt.subplot2grid((1,2),(0,0))
    plt.title("gtowf")
    plt.contourf(W1,levels)
    plt.colorbar()
    plt.subplot2grid((1,2),(0,1))
    plt.title("netwf")
    plt.contourf(W2,levels)
    plt.colorbar()
    plt.show()

except Exception as err:
    # Best-effort plot: report the failure rather than hiding it.
    print("contour comparison skipped: "+repr(err))

## Unsupervised

In [None]:
# Electron counts for the unsupervised H2 run: both electrons are assigned to
# the "up" group, leaving n_down = 0.
n_electrons=2
n_up = 2
n_down = n_electrons-n_up

In [None]:
# Rebinds `net` (previously the supervised WFNetAnti) to a HanNet on the GPU.
net = HanNet(h2,n_up,n_down).cuda()

In [None]:
# Sampler selection. The HMC branch is disabled by the constant condition, so
# the Langevin sampler is always used.
# NOTE(review): the active sampler draws from `gtowf`, not from `net`, whereas
# the H_n section below samples from `net` — confirm this is intended.
if False:
    sampler = hmc(
        net,
        torch.randn(1000, n_electrons, 3, device='cuda'),
        dysteps=3,
        stepsize=0.2,
        tau = 0.1,
        cutoff = 1.0
    )
else:
    sampler = langevin_monte_carlo(
        gtowf,
        torch.randn(1000, n_electrons, 3, device='cuda'),
        tau=0.1,
    )

In [None]:
#t=time.time()
#samples = samples_from(sampler,range(1000))[0].flatten(end_dim=1)
#print("it took: "+str(time.time()-t))

In [None]:
# Two-stage fit of `net` on H2: first with a fixed reference energy
# (E_ref=-1.1, p=2), then with E_ref=None and p=1.
molecule = h2
fit_wfnet(
    net,
    partial(loss_local_energy, E_ref=-1.1,p=2),
    torch.optim.Adam(net.parameters(), lr=1e-3),
    wfnet_fit_driver(
            sampler,
            samplings=range(1),
            n_epochs=1,
            n_sampling_steps=100,
            batch_size=1_000,
            n_discard=50,
            range_sampling=partial(trange, desc='sampling steps', leave=False),
            range_training=partial(trange, desc='training steps', leave=False),
        ),
    clip_grad = None,
    exclude_below = 0,
    writer = SummaryWriter(f'runs/'),
    )

# Second stage. This call previously passed the undefined name `net2` while
# its optimizer was built from `net.parameters()`; it now trains `net`, in line
# with the equivalent two-stage fits later in the notebook.
fit_wfnet(
    net,
    partial(loss_local_energy, E_ref=None,p=1),
    torch.optim.Adam(net.parameters(), lr=1e-3),
    wfnet_fit_driver(
            sampler,
            samplings=range(1),
            n_epochs=1,
            n_sampling_steps=1000,
            batch_size=1_000,
            n_discard=50,
            range_sampling=partial(trange, desc='sampling steps', leave=False),
            range_training=partial(trange, desc='training steps', leave=False),
        ),
    clip_grad = None,
    writer = SummaryWriter(f'runs/'),
    )



In [None]:
# Probe line on the CPU: x of electron 1 scanned over [-5, 5], x of electron 2
# fixed at half the x coordinate of the second entry of h2.coords.
x_line = torch.cat((torch.linspace(-5, 5, 5000)[:, None], torch.zeros((5000, 3*n_electrons-1))), dim=1)
x_line[:,3]=h2.coords[1][0]/2
x_line=x_line.view(-1,n_electrons,3)#.cuda()
x_line.requires_grad = True
# Bare expression: has no effect in the middle of a cell.
gtowf

normed=True
#normplot(x_line[:,0 , 0].cpu().detach().numpy(), torch.exp(net.deep_lin(f_line[0])).squeeze().cpu().detach().numpy(),label="sym",norm=normed)
#normplot(x_line[:,0 , 0].cpu().detach().numpy(), net.antisym(x_line).cpu().detach().numpy(),label="anti",norm=normed)
#normplot(x_line[:,0 , 0].cpu().detach().numpy(), net.nuc_asymp(f_line[1][0]).cpu().detach().numpy(),label="asym",norm=normed)
#N = net.nuc_asymp(f_line[1][0]).cpu().detach().numpy()
#normplot(x_line[:,0,0].cpu().detach().numpy(),-1*(N*x_line[:,0,0].cpu().detach().numpy()),label="asym*line",norm=normed)
#d=net(x_line).cpu().detach().numpy()
#D.append(d)
# Note: the curve labelled "WF" is the GTO reference `gtowf`, not `net`.
normplot(x_line[:,0 , 0].cpu().detach().numpy(), gtowf(x_line).cpu().detach().numpy(),label="WF",norm=normed,lw=2,color='k')

# Dotted guide lines through the origin.
plt.axhline(0,ls=':',color='k')
plt.axvline(0,ls=':',color='k')

#plt.ylim(-1,2)
x_line.requires_grad = False
plt.legend()
#plt.savefig('lastrunwf.png')
plt.show()
#plt.subplot2grid((2,1),(0,0))
#plt.plot(L[:steps//10])
#plt.yscale('log')
#plt.subplot2grid((2,1),(1,0))
#plt.plot(L[steps//10:])
#plt.yscale('log')
#plt.savefig('lastrunloss.png')


In [None]:
#plt.plot(x_line[:,0 , 0].cpu().detach().numpy(),net.antisym.net_pair_anti(x_line[:,0],x_line[:,1]).cpu().detach().numpy())
#plt.show()

In [None]:
#tmp = net.antisym.net_pair_anti(torch.from_numpy(H[:,0:3]).type(torch.FloatTensor).cuda(),torch.from_numpy(H[:,3:]).type(torch.FloatTensor).cuda()).cpu().detach().numpy()[:,9].reshape(500,500)
#plt.contourf(tmp)
#plt.colorbar()

In [None]:
# Draw 100 sampler steps, keep the first returned element, merge its first two
# dimensions, and report the wall-clock time.
t=time.time()
samples = samples_from(sampler,range(100))[0].flatten(end_dim=1)
print("it took: "+str(time.time()-t))

In [None]:
# 2D histogram of electron 1's (x, y) coordinates; the x range is widened by
# the x coordinate of the second entry of h2.coords.
plt.hist2d(
    samples[:,0, 0].cpu().detach().numpy(),
    samples[:,0, 1].cpu().detach().numpy(),
    bins=100,
    range=[[-3, 3+h2.coords[1,0].cpu().numpy()], [-3, 3]],
)                                   
plt.gca().set_aspect(1)

In [None]:
#net = net.cpu()
#samples = samples.cpu()
# Move the private tensors of the shared `h2` geometry to the GPU in place,
# then display the resulting device.
# NOTE(review): this mutates private attributes of an object obtained from
# geomdb; any other holder of the same object sees the change.
h2._coords  = h2._coords.cuda()
h2._charges = h2._charges.cuda()
h2.coords.device

In [None]:
# Evaluate the local energy of `net` on the samples in chunks of 1000 and keep
# the first element of the returned tuple.
E_loc = dlqmc.utils.batch_eval_tuple(partial(local_energy, wf=lambda x: net(x),geom=h2),tqdm(samples.view([-1,n_electrons,3]).split(1000)))[0]

In [None]:
#print(np.where((E_loc.detach().numpy())>100)[0].shape)
#print(np.where((E_loc.detach().numpy())<-100)[0].shape)
#print(np.min(E_loc.detach().numpy()))
#print(np.max(E_loc.detach().numpy()))
#net(samples[np.where((E_loc.detach().numpy())>10)])**2

In [None]:
# Mean and variance are computed on values clamped to [-10, 10]; the histogram
# itself uses the tighter clamp [-1.5, 1].
mean=E_loc.clamp(-10, 10).mean().item()

h = plt.hist(E_loc.detach().clamp(-1.5, 1).cpu().numpy(), bins=100,alpha = 0.5,color='b')
# Annotations are placed relative to the tallest histogram bin.
plt.annotate("mean = "+str(np.round(mean,4)),(-0.3,np.max(h[0])/2),color='b')
plt.annotate("var     = "+str(np.round(np.var(E_loc.detach().clamp(-10, 10).cpu().numpy()),4)),(-0.3,np.max(h[0])/2-np.max(h[0])/15),color='b')

#mean=e_loc_net.mean()

#h = plt.hist(e_loc_net, bins=100,color='r',alpha = 0.5)
#plt.annotate("mean = "+str(np.round(mean,4)),(0,np.max(h[0])/2-3000),color='r')
#plt.annotate("var     = "+str(np.round(np.var(e_loc_net),4)),(0,np.max(h[0])/2-np.max(h[0])/15-3000),color='r')
plt.savefig('lastruneloc.png')
plt.show()

In [None]:
#e_loc_net=E_loc.detach().clamp(-1.5, 1).cpu().numpy().copy()

In [None]:
# Print CUDA memory statistics, then release cached blocks.
# NOTE(review): memory_cached / max_memory_cached are reported as deprecated by
# newer PyTorch releases — check against the installed version.
print(torch.cuda.memory_allocated())
print(torch.cuda.memory_cached(device=None))
print(torch.cuda.max_memory_cached(device=None))
torch.cuda.empty_cache()

## $H_{10}$

In [None]:
# Linear chain of n unit-charge centres spaced d apart along x.
# The section heading says H10, but n is currently 2.
d=1.5#1.786
n=2
hn = Geometry([[d*i, 0., 0.] for i in range(n)], [1. for i in range (n)])
print(hn)

In [None]:
# PySCF reference for the same chain (6-31G, spin=0); rebinds `mol`, `mf` and
# `gtowf` from the earlier H2 section.
mol = gto.M(
    atom=[
        ['H', (d*i, 0, 0)] for i in range(n)      
    ],
    unit='bohr',
    basis='6-31G',#'aug-cc-pV5Z',
    charge=0,
    spin=0,
)
mf = scf.RHF(mol)
mf.kernel()
gtowf = TorchGTOSlaterWF(mf)

In [None]:
# One electron per centre, split as evenly as possible between the two groups
# (n_up = n // 2, remainder in n_down).
n_electrons=n
n_up = n//2
n_down = n_electrons-n_up
# HanNet hyperparameters are passed through as keyword arguments.
net = HanNet(hn,n_up,n_down, 
        basis_dim=8,
        kernel_dim=16,
        embedding_dim=32,
        latent_dim=5,
        n_interactions=2,
        n_orbital_layers=3,).cuda()

In [None]:
# Sampler selection for the chain. The HMC branch is disabled by the constant
# condition; the active Langevin sampler draws from `net` itself.
if False:
    sampler = hmc(
        net,
        torch.randn(1000, n_electrons, 3, device='cuda'),
        dysteps=3,
        stepsize=0.2,
        tau = 0.1,
        cutoff = 1.0
    )
else:
    sampler = langevin_monte_carlo(
        net,
        torch.randn(1000, n_electrons, 3, device='cuda'),
        tau=0.1,
    )

In [None]:
# Two-stage fit of `net` on the chain: first with E_ref=-1.1, p=2, then with
# E_ref=None, p=1. `molecule` is assigned but not read in this cell.
molecule = hn
fit_wfnet(
    net,
    partial(loss_local_energy, E_ref=-1.1,p=2),
    torch.optim.Adam(net.parameters(), lr=1e-3),
    wfnet_fit_driver(
            sampler,
            samplings=range(1),
            n_epochs=1,
            n_sampling_steps=60,
            batch_size=1_000,
            n_discard=50,
            range_sampling=partial(trange, desc='sampling steps', leave=False),
            range_training=partial(trange, desc='training steps', leave=False),
        ),
    clip_grad = None,
    writer = SummaryWriter(f'runs/'),
    )

# Second stage: same driver settings, fresh Adam optimizer.
fit_wfnet(
    net,
    partial(loss_local_energy, E_ref=None,p=1),
    torch.optim.Adam(net.parameters(), lr=1e-3),
    wfnet_fit_driver(
            sampler,
            samplings=range(1),
            n_epochs=1,
            n_sampling_steps=60,
            batch_size=1_000,
            n_discard=50,
            range_sampling=partial(trange, desc='sampling steps', leave=False),
            range_training=partial(trange, desc='training steps', leave=False),
        ),
    clip_grad = None,
    writer = SummaryWriter(f'runs/'),
    )



In [None]:
# Probe line: x of electron 1 scanned over [-1, 16]; the x coordinate of every
# further electron k is fixed at 0.7*k, all remaining coordinates are zero.
x_line = torch.cat((torch.linspace(-1, 16, 500)[:, None], torch.zeros((500, 3*n_electrons-1))), dim=1)
for i in range(n_electrons-1):
    x_line[:,3*(i+1)]=0.7*(i+1)
x_line=x_line.view(-1,n_electrons,3).cuda()
print(x_line.shape)
x_line.requires_grad = True
net.cuda()

normed=True
#normplot(x_line[:,0 , 0].cpu().detach().numpy(), net(x_line).cpu().detach().numpy(),label="WF",norm=normed,lw=2,color='k')
# Note: the curve labelled "WF" is the GTO reference `gtowf`; the `net` plot
# above is commented out.
normplot(x_line[:,0 , 0].cpu().detach().numpy(), gtowf(x_line).cpu().detach().numpy(),label="WF",norm=normed,lw=2,color='k')

plt.axhline(0,ls=':',color='k')
# One dotted vertical line per centre of the chain.
for i in range(n):
    plt.axvline(d*i,ls=':',color='k')

#plt.ylim(-1,2)
x_line.requires_grad = False
plt.legend()
plt.show()


# Backflow

In [None]:
# PySCF calculation for a single boron atom (4-31G, Cartesian basis, spin=1).
# NOTE(review): `mol` and `mf` are rebound by the H2 cell two cells below
# before anything visible here reads them — confirm this cell is still needed.
mol = gto.M(
    atom=[
        ['B', (0, 0, 0)]
    ],
    unit='bohr',
    basis='4-31G',
    cart=True,
    charge=0,
    spin=1,
)
mf = scf.RHF(mol)
mf.kernel()


In [None]:
# Reuse the 'H' geometry entry as a single centre of charge 5 at the origin.
# NOTE(review): this overwrites private attributes of the object returned by
# geomdb['H']; if geomdb hands out a shared instance, the 'H' entry itself is
# changed. Building a fresh Geometry(...) would avoid that — verify.
bohr=geomdb['H']
bohr._coords=torch.tensor([[0,0,0.]])
bohr._charges=torch.tensor([5.])

In [None]:
# PySCF H2 reference for the backflow section: same geometry, basis and spin
# as the first H2 cell, but with a Cartesian basis (cart=True). The resulting
# `mf` is what HFNet.from_pyscf consumes below.
mol = gto.M(
    atom=[
        ['H', (0, 0, 0)],
        ['H', (1.484, 0, 0)]
    ],
    unit='bohr',
    basis='4-31G',
    cart=True,
    charge=0,
    spin=2,
)
mf = scf.RHF(mol)
mf.kernel()


In [None]:
from dlqmc.utils import nondiag

class Backflow(nn.Module):
    """Iterative backflow shift of electron coordinates.

    Each interaction layer maps the basis-expanded distance of every ordered
    electron pair (i, j), i != j, to a scalar weight and moves electron i by
    the weighted sum of the displacement vectors ``x_j - x_i``.

    Parameters
    ----------
    n_up, n_down
        Accepted for interface symmetry with the wavefunction classes; not
        used inside this module.
    n_interactions : int
        Number of successive backflow layers (0 gives the identity map).
    basis_dim : int
        Size of the distance basis fed to each layer.
    """

    def __init__(
        self,
        n_up,
        n_down,
        n_interactions,
        basis_dim
    ):

        super().__init__()
        def interaction_factory(basis_dim):
            # One scalar weight per electron pair, computed from its distance features.
            modules = {
                'interact': get_log_dnn(basis_dim, 1, SSP, n_layers=4, last_bias=False),
            }
            return nn.ModuleDict(modules)

        self.interactions = nn.ModuleList(
            [
                interaction_factory(basis_dim)
                for _ in range(n_interactions)
            ])

        self.dist_basis = DistanceBasis(basis_dim)

    def forward(self,rs, debug=NULL_DEBUG):
        """Return backflow-shifted copies of ``rs``.

        ``rs`` is indexed below as (batch, electron, xyz). ``debug`` is
        accepted for signature compatibility and is not used.
        """
        xs = rs.clone()
        # With fewer than two electrons there are no pairs to interact; the
        # pair reshape below (zero elements with an inferred -1 dimension)
        # would be ambiguous, so return the unshifted copy.
        if xs.shape[-2] < 2:
            return xs
        for interaction in self.interactions:
            dists_basis = self.dist_basis(pairwise_distance(xs,xs))
            *batch_dims, n_elec, _, _ = dists_basis.shape
            c_i, c_j, c_shape = self._conv_indexing(n_elec, n_elec, batch_dims)
            # Keep only the off-diagonal (i != j) pairs.
            dists_basis = dists_basis[..., c_i, c_j, :]
            Ws = interaction.interact(dists_basis)
            # Displacements x_j - x_i, grouped per electron i.
            deltas = xs[:, c_j].view(*c_shape)-xs[:,:,None,:]
            zs = (Ws.view(*c_shape) * deltas).sum(dim=2)
            xs = xs + zs
        return xs

    @staticmethod
    def _conv_indexing(n_elec, n_all, batch_dims):
        """Index arrays for the off-diagonal pairs and the grouped view shape.

        Returns ``(i, j, shape)`` where ``i``/``j`` are the first
        ``n_elec * (n_all - 1)`` indices selected by ``nondiag`` and ``shape``
        is ``(*batch_dims, n_elec, n_all - 1, -1)``.
        """
        i, j = np.mask_indices(n_all, nondiag)
        n = n_elec * (n_all - 1)
        i, j = i[:n], j[:n]
        shape = (*batch_dims, n_elec, n_all - 1, -1)
        return i, j, shape

In [None]:
import numpy as np
import torch
from torch import nn

from dlqmc.geom import Geometry
from dlqmc.utils import NULL_DEBUG
from dlqmc.nn.anti import eval_slater
from dlqmc.nn.base import BaseWFNet
from dlqmc.nn.base import DistanceBasis
from dlqmc.nn.gto import GTOBasis


class HFNet(BaseWFNet):
    """Slater-determinant wavefunction on a GTO basis with backflow-shifted electrons.

    Electron positions pass through ``Backflow`` before the atomic orbitals are
    evaluated; a bias-free linear layer turns them into molecular orbitals, and
    the result is the product of the up and down determinants.
    """

    def __init__(self, geom, n_up, n_down, basis, n_interactions):
        super().__init__()
        self.n_up = n_up
        self.n_down = n_down
        self.register_geom(geom)
        self.basis = basis
        n_orbitals = max(n_up, n_down)
        self.mo = nn.Linear(basis.dim, n_orbitals, bias=False)
        # Backflow uses a fixed distance-basis size of 20.
        self.backflow = Backflow(n_up, n_down, n_interactions, 20)

    def init_from_pyscf(self, mf):
        """Copy the leading MO coefficients of ``mf`` into ``self.mo``."""
        n_orbitals = max(self.n_up, self.n_down)
        coeffs = mf.mo_coeff.copy()
        if mf.mol.cart:
            # Rescale each row by the square root of the diagonal of the
            # Cartesian overlap matrix.
            row_scale = np.sqrt(np.diag(mf.mol.intor('int1e_ovlp_cart')))
            coeffs *= row_scale[:, None]
        occupied = torch.from_numpy(coeffs[:, :n_orbitals].T)
        self.mo.weight.detach().copy_(occupied)

    @classmethod
    def from_pyscf(cls, mf, n_interactions):
        """Build an ``HFNet`` from a PySCF mean-field object."""
        occ = mf.mo_occ
        n_up = (occ >= 1).sum()
        n_down = (occ == 2).sum()
        # Occupations must be ordered: doubly, then singly occupied, then empty.
        for segment, expected in (
            (occ[:n_down], 2),
            (occ[n_down:n_up], 1),
            (occ[n_up:], 0),
        ):
            assert (segment == expected).all()
        geom = Geometry(mf.mol.atom_coords().astype('float32'), mf.mol.atom_charges())
        basis = GTOBasis.from_pyscf(mf.mol)
        wf = cls(geom, n_up, n_down, basis, n_interactions)
        wf.init_from_pyscf(mf)
        return wf

    def __call__(self, rs, debug=NULL_DEBUG):
        """Evaluate the wavefunction; intermediates are stored in ``debug``."""
        batch_dim, n_elec = rs.shape[:2]
        shifted = self.backflow(rs)
        aos = self.basis(shifted.flatten(end_dim=1)).view(batch_dim, n_elec, -1)
        debug['aos'] = aos
        orbs = self.mo(aos)
        debug['slaters'] = orbs
        det_up = eval_slater(orbs[:, : self.n_up, : self.n_up])
        debug['det_up'] = det_up
        det_down = eval_slater(orbs[:, self.n_up :, : self.n_down])
        debug['det_down'] = det_down
        return det_up * det_down

    def orbitals(self, rs):
        """Molecular orbitals at ``rs`` (no backflow applied)."""
        return self.mo(self.basis(rs))

    def density(self, rs):
        """Sum of squared orbitals over the up and the down occupations."""
        xs = self.orbitals(rs)
        total = 0
        for n_occ in (self.n_up, self.n_down):
            total = total + (xs[:, :n_occ] ** 2).sum(dim=-1)
        return total


In [None]:
# Backflow comparison on H2 with both electrons in the "up" group.
n_electrons=2
n_up = 2
n_down = n_electrons-n_up
molecule = h2

# Three HFNets built from the same mean-field result, differing only in the
# number of backflow layers (0, 1 and 3).
net0=HFNet.from_pyscf(mf,n_interactions=0).cuda()
net1=HFNet.from_pyscf(mf,n_interactions=1).cuda()
net3=HFNet.from_pyscf(mf,n_interactions=3).cuda()

In [None]:
# One Langevin sampler per net, each started from its own batch of 1000
# normally distributed walkers on the GPU.
sampler0, sampler1, sampler3 = [
    langevin_monte_carlo(
        wf,
        torch.randn(1000, n_electrons, 3, device='cuda'),
        tau=0.1,
    )
    for wf in (net0, net1, net3)
]

In [None]:
# Two-stage fit for each (net, sampler) pair. The loop rebinds the notebook
# globals `net` and `sampler`, which end up pointing at net3 / sampler3.
for net,sampler in zip([net0,net1,net3],[sampler0,sampler1,sampler3]):
    # Stage 1: E_ref=-1.1, p=2.
    fit_wfnet(
        net,
        partial(loss_local_energy, E_ref=-1.1,p=2),
        torch.optim.Adam(net.parameters(), lr=1e-3),
        wfnet_fit_driver(
                sampler,
                samplings=range(1),
                n_epochs=1,
                n_sampling_steps=100,
                batch_size=1_000,
                n_discard=50,
                range_sampling=partial(trange, desc='sampling steps', leave=False),
                range_training=partial(trange, desc='training steps', leave=False),
            ),
        clip_grad = None,
        writer = SummaryWriter(f'runs/'),
        )

    # Stage 2: E_ref=None, p=1, fresh Adam optimizer.
    fit_wfnet(
        net,
        partial(loss_local_energy, E_ref=None,p=1),
        torch.optim.Adam(net.parameters(), lr=1e-3),
        wfnet_fit_driver(
                sampler,
                samplings=range(1),
                n_epochs=1,
                n_sampling_steps=100,
                batch_size=1_000,
                n_discard=50,
                range_sampling=partial(trange, desc='sampling steps', leave=False),
                range_training=partial(trange, desc='training steps', leave=False),
            ),
        clip_grad = None,
        writer = SummaryWriter(f'runs/'),
        )


In [None]:
# Probe line: x of electron 1 scanned over [-5, 5], x of electron 2 fixed at
# half the x coordinate of the second entry of h2.coords.
x_line = torch.cat((torch.linspace(-5, 5, 500)[:, None], torch.zeros((500, 3*n_electrons-1))), dim=1)
x_line[:,3]=h2.coords[1][0]/2
x_line=x_line.view(-1,n_electrons,3).cuda()
x_line.requires_grad = True


normed=True
normplot(x_line[:,0 , 0].cpu().detach().numpy(), net0(x_line).cpu().detach().numpy(),label="WF0",norm=normed,lw=2)
normplot(x_line[:,0 , 0].cpu().detach().numpy(), net1(x_line).cpu().detach().numpy(),label="WF1",norm=normed,lw=2)
normplot(x_line[:,0 , 0].cpu().detach().numpy(), net3(x_line).cpu().detach().numpy(),label="WF3",norm=normed,lw=2)
plt.axhline(0,ls=':',color='k')
# Mark every nucleus of the geometry the nets were built on. The original drew
# the line at x=0 twice and never marked the second nucleus.
for x_nuc in net0.geom.coords[:, 0].detach().cpu().tolist():
    plt.axvline(x_nuc,ls=':',color='k')
x_line.requires_grad = False
plt.legend()
plt.show()


In [None]:
# Overlay local-energy histograms of the three nets; the legend numbers the
# nets 1..3 in loop order. The loop rebinds the globals `net` and `sampler`.
net_sampler_pairs = zip([net0,net1,net3],[sampler0,sampler1,sampler3])
for i, (net, sampler) in enumerate(net_sampler_pairs, start=1):
    t = time.time()
    samples = samples_from(sampler,range(100))[0].flatten(end_dim=1)
    print("it took: "+str(time.time()-t))

    # Local energies in chunks of 1000, clamped to [-4, 1] before averaging.
    chunks = tqdm(samples.view([-1,n_electrons,3]).split(1000))
    energy_fn = partial(local_energy, wf=lambda x: net(x),geom=net.geom)
    E_loc = dlqmc.utils.batch_eval_tuple(energy_fn, chunks)[0]
    E_loc = E_loc.clamp(-4, 1)
    mean = E_loc.mean().item()
    legend_label = str(i)+": mean = "+str(np.round(mean,4))
    h = plt.hist(E_loc.detach().cpu().numpy(), bins=100,alpha = 0.5,label=(legend_label))
plt.legend()
plt.show()

In [None]:
# Draw 100 steps from sampler3 and time the call.
t=time.time()
samples = samples_from(sampler3,range(100))[0].flatten(end_dim=1)
print("it took: "+str(time.time()-t))
    
# 2D histogram of (x, y) over all electrons, restricted to [-1, 1] x [-1, 1].
plt.hist2d(
    samples[:,:, 0].cpu().flatten().detach().numpy(),
    samples[:,:, 1].cpu().flatten().detach().numpy(),
    bins=100,
    range=[[-1, 1], [-1, 1]],
)                                   
plt.gca().set_aspect(1)

In [None]:
# Local energy of net0 on the samples drawn above.
# NOTE(review): those samples came from sampler3 (built on net3), not from a
# sampler for net0 — confirm the mismatch is intended.
E_loc = dlqmc.utils.batch_eval_tuple(partial(local_energy, wf=lambda x: net0(x),geom=net0.geom),tqdm(samples.view([-1,n_electrons,3]).split(1000)))[0]

In [None]:
# Mean, variance and histogram all use local energies clamped to [-40, 10].
mean=np.round(E_loc.clamp(-40, 10).mean().item(),4)
var=np.round(np.var(E_loc.detach().clamp(-40, 10).cpu().numpy()),4)

h = plt.hist(E_loc.detach().clamp(-40, 10).cpu().numpy(), bins=100,alpha = 0.5,color='b')
# Annotations are placed relative to the tallest histogram bin.
plt.annotate("mean = "+str(mean),(-0.3,np.max(h[0])/2),color='b')
plt.annotate("var     = "+str(var),(-0.3,np.max(h[0])/2-np.max(h[0])/15),color='b')

# Overwrites the file written by the earlier local-energy histogram cell.
plt.savefig('lastruneloc.png')
plt.show()