In [1]:
import os
import numpy as np
import jax
import jax.numpy as jnp
from jax import random
from typing import Tuple, Optional
from pyscf import gto, scf, mcscf
import pyqmc.api as pyq
from pyqmc.energy import *

from qmc.pyscftools import orbital_evaluator_from_pyscf
from qmc.setting import initialize_calculation, determine_complex_settings
from qmc.mc import limdrift
from qmc.orbitals import *
from qmc.determinants import *
from qmc.extract import *

np.random.seed(42)
jax.config.update("jax_enable_x64",True)

Array([1.], dtype=float64)

In [172]:
aograd = aos(mol, "GTOval_sph_deriv1", coords[:, 0, :])
print(aograd.shape)
mograd = mos(aograd, mo_coeff[0])
mograd_vals = mograd[:, :, occup_hash[0]]
print(mograd_vals.shape)

(1, 4, 20, 7)
(4, 20, 1, 5)


In [177]:
e_eff = 0
inverse[0][..., e_eff].shape

(20, 1, 5)

In [179]:
dets

(Array([[[ -1.        ],
         [  1.        ],
         [  1.        ],
         [ -1.        ],
         [  1.        ],
         [ -1.        ],
         [  1.        ],
         [ -1.        ],
         [ -1.        ],
         [  1.        ],
         [  1.        ],
         [ -1.        ],
         [  1.        ],
         [  1.        ],
         [ -1.        ],
         [ -1.        ],
         [  1.        ],
         [  1.        ],
         [ -1.        ],
         [ -1.        ]],
 
        [[-13.1616891 ],
         [-19.14262285],
         [-11.29792497],
         [-12.24398278],
         [-13.64777082],
         [-14.69809905],
         [-13.11556495],
         [-10.30587905],
         [-10.7316497 ],
         [-11.25255387],
         [-12.7664418 ],
         [-18.01364508],
         [-10.33005549],
         [-16.73443612],
         [-11.48263168],
         [-12.16736444],
         [-16.10019578],
         [-17.87273356],
         [-16.84201006],
         [-16.83161216

In [155]:
nconf, nelect_tot, ndim = coords.shape
atomic_orbital =  aos(mol,"GTOval",coords)
print(atomic_orbital.shape)
aovals = atomic_orbital.reshape(-1, nconf, nelect_tot, atomic_orbital.shape[-1])
print(aovals.shape)
aovals[:, :, :_nelec[0]].shape
print(nconf)
print(mo_coeff[0].shape)
print(aovals.shape)
mo_values = mos(aovals, mo_coeff[0])
print(mo_values.shape)
mo_vals = jnp.swapaxes(mo_values[:, :, occup_hash[0]], 1, 2)
print(mo_vals.shape)


(1, 200, 7)
(1, 20, 10, 7)
20
(7, 5)
(1, 20, 10, 7)
(20, 10, 5)
(20, 1, 10, 5)


In [161]:
print(mo_vals.shape)
inverse = jax.vmap(jnp.linalg.inv)(mo_vals)

(20, 1, 10, 5)


ValueError: Argument to inv must have shape [..., n, n], got (1, 10, 5).

In [154]:
mol = gto.Mole()
mol.atom = '''
O        0.000000    0.000000    0.117790
H        0.000000    0.755453   -0.471161
H        0.000000   -0.755453   -0.471161'''
mol.basis = 'sto-3g'
# mol.ecp = "ccECP"
mol.build()

nconfig, seed = 20, 42

coords, max_orb, det_coeff, det_map, mo_coeff, occup_hash, _nelec, nelec = \
    initialize_calculation(mol, nconfig, seed)
    
iscomplex, mo_dtype, get_phase = \
    determine_complex_settings(mo_coeff, det_coeff)

coords = jnp.array(coords)
atom_coords = jnp.array(mol.atom_coords())
atom_charges = jnp.array(mol.atom_charges())

np.random.seed(seed)

config = pyq.initial_guess(mol, nconfig)

converged SCF energy = -74.963146775618


In [30]:
import pyscf
mf = pyscf.scf.RHF(mol)
mf.kernel()
mc = pyscf.mcscf.CASSCF(mf = pyscf.scf.RHF(mol))
mc.kernel()

converged SCF energy = -1.11675930739643


TypeError: CASSCF() got an unexpected keyword argument 'mf'

In [222]:
wf1.pgradient()

{'det_coeff': array([[1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.]]),
 'mo_coeff_alpha': array([[[ 1.30960133e+00, -2.10685799e-02,  3.03389351e-02,
          -8.48229741e-03,  7.38973454e-03],
         [-3.04638843e+00,  9.57933840e-01, -3.23394705e-01,
           2.74433353e-01, -7.87701023e-02],
         [-5.98167688e-17,  8.70442096e-16,  2.29486289e-15,
          -7.80674637e-16,  1.00000000e+00],
         [ 1.80481684e+00,  1.42585068e-01,  1.60205031e+00,
          -9.87634808e-02,  1.08365942e-01],
         [ 1.23667723e+01, -3.14734238e-01,  1.21111729e+00,
           8.15206166e-01,  2.94995036e-01],
         [ 1.27969069e+01,  3.93532088e-01,  1.39865957e+00,
          -3.26615402e-01,  2.59273847e-01],
         [ 1.52571598e+01,  5.87898248e-01,  1.3367

In [218]:
from pyqmc.wftools import generate_slater
wf1, to_opt1 = generate_slater(mol, mf, optimize_orbitals=True)
wf1.recompute(config)
wf1.pgradient()

{'det_coeff': array([[1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.]]),
 'mo_coeff_alpha': array([[[ 1.30960133e+00, -2.10685799e-02,  3.03389351e-02,
          -8.48229741e-03,  7.38973454e-03],
         [-3.04638843e+00,  9.57933840e-01, -3.23394705e-01,
           2.74433353e-01, -7.87701023e-02],
         [-5.98167688e-17,  8.70442096e-16,  2.29486289e-15,
          -7.80674637e-16,  1.00000000e+00],
         [ 1.80481684e+00,  1.42585068e-01,  1.60205031e+00,
          -9.87634808e-02,  1.08365942e-01],
         [ 1.23667723e+01, -3.14734238e-01,  1.21111729e+00,
           8.15206166e-01,  2.94995036e-01],
         [ 1.27969069e+01,  3.93532088e-01,  1.39865957e+00,
          -3.26615402e-01,  2.59273847e-01],
         [ 1.52571598e+01,  5.87898248e-01,  1.3367

In [184]:
mc = pyscf.mcscf.CASSCF(mf, 2, 2)
mc.kernel()
wf, to_opt = pyq.generate_wf(mol, mf, mc=mc)
grad =pyq.gradient_generator(mol, wf, to_opt)
de = wf.recompute(config)
print(de[0].shape)
print(wf._aovals.shape)

CASSCF energy = -74.9643987948316
CASCI E = -74.9643987948316  E(CI) = -1.66507447389938  S^2 = 0.0000000
(20,)


AttributeError: 'MultiplyWF' object has no attribute '_aovals'

In [None]:
def compute_wf_value(configs, dets, det_coeff, det_map):
    
    updets, dndets = dets
    updets, dndets = updets[:, :, det_map[0]], dndets[:, :, det_map[1]]
    
    upref, dnref = jnp.amax(updets[1]).real, jnp.amax(dndets[1]).real
    phases = updets[0] * dndets[0]
    logvals = updets[1] - upref + dndets[1] - dnref
    
    wf_val = jnp.einsum("d,id->i", det_coeff, phases * jnp.exp(logvals))
    wf_sign = jnp.where(wf_val == 0, 0.0, wf_val / jnp.abs(wf_val))
    wf_logval = jnp.where(wf_val == 0, -jnp.inf,
                            jnp.log(jnp.abs(wf_val)) + upref + dnref)
    
    return wf_sign, wf_logval

In [139]:
wf1.pgradient()

{'det_coeff': array([[1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.]]),
 'mo_coeff_alpha': array([[[1.11884433],
         [0.70317275]],
 
        [[0.3960754 ],
         [1.42594169]],
 
        [[0.4998836 ],
         [1.32213349]],
 
        [[1.42151295],
         [0.40050413]],
 
        [[0.49948052],
         [1.32253656]],
 
        [[1.5048245 ],
         [0.31719258]],
 
        [[0.76205564],
         [1.05996145]],
 
        [[1.53600466],
         [0.28601242]],
 
        [[0.28898596],
         [1.53303113]],
 
        [[0.70197291],
         [1.12004417]],
 
        [[0.33670211],
         [1.48531497]],
 
        [[0.94528946],
         [0.87672762]],
 
        [[1.4956114 ],
         [0.32640568]],
 
        [[1.48443384],
         [0.33758324]],
 
 

In [162]:
from pyqmc.api import Slater
import pyqmc.api as pyq

np.random.seed(42)
configs = pyq.initial_guess(mol, nconfig)
coords = configs.configs
coords = jnp.array(configs.configs)

mf = scf.RHF(mol)
mf.kernel()

aovals, dets, inverse = recompute(mol, coords, mo_coeff, _nelec, occup_hash)
wf, to_opt = pyq.generate_wf(mol, mf)
grad =pyq.gradient_generator(mol, wf, to_opt)
wf.recompute(configs)
import time
start = time.time()
for i in range(500):
  aovals, dets, inverse = recompute(mol, coords, mo_coeff, _nelec, occup_hash)

end = time.time()
print(end -start)

start = time.time()
for i in range(500):
  wf.recompute(configs)

end = time.time()
print(end -start)


converged SCF energy = -74.963146775618
0.16313695907592773
0.8806300163269043


In [104]:
def compute_wf_value(configs, dets, det_coeff, det_map):
    
    updets, dndets = dets
    updets, dndets = updets[:, :, det_map[0]], dndets[:, :, det_map[1]]
    
    upref, dnref = jnp.amax(updets[1]).real, jnp.amax(dndets[1]).real
    phases = updets[0] * dndets[0]
    logvals = updets[1] - upref + dndets[1] - dnref
    
    wf_val = jnp.einsum("d,id->i", det_coeff, phases * jnp.exp(logvals))
    wf_sign = jnp.where(wf_val == 0, 0.0, wf_val / jnp.abs(wf_val))
    wf_logval = jnp.where(wf_val == 0, -jnp.inf,
                            jnp.log(jnp.abs(wf_val)) + upref + dnref)
    
    return wf_sign, wf_logval

In [50]:
def full_wf_calculation(configs, mol, mo_coeff, _nelec, occup_hash, det_coeff, det_map):
    """mo_coeff부터 시작하여 전체 파동함수 값을 스칼라로 계산"""
    # 모든 중간 값 계산
    aovals, dets, inverses = recompute(mol, configs, mo_coeff, _nelec, occup_hash)
    wf_sign, wf_logval = compute_wf_value(configs, dets, det_coeff, det_map)
    
    # 스칼라 값 반환 (자동 미분용)
    return jnp.sum(wf_sign * jnp.exp(wf_logval))

In [140]:
grad_fn = jax.grad(full_wf_calculation, argnums=2) 
grad_fn(coords, mol, mo_coeff, _nelec, occup_hash, det_coeff, det_map).shape

(2, 2, 1)

In [106]:
from pyqmc.accumulators import LinearTransform

In [191]:
wf, to_opt = pyq.generate_wf(mol, mf)
grad =pyq.gradient_generator(mol, wf, to_opt)
x0 = grad.transform.serialize_parameters(wf.parameters)

In [208]:
[k for k, opt in to_opt.items()]

['wf1det_coeff', 'wf2acoeff', 'wf2bcoeff']

In [215]:
wf.parameters

WFmerger: {'wf1': [{'det_coeff': array([1.])}, {'mo_coeff_alpha': array([[ 9.94132148e-01, -2.32780470e-01, -1.35336992e-16,
        -1.03397331e-01, -7.74861869e-17],
       [ 2.65420955e-02,  8.33658221e-01,  8.21619033e-16,
         5.37850536e-01,  3.58860084e-16],
       [-1.22555375e-19, -5.32925203e-18,  7.17463371e-16,
        -6.95515478e-16,  1.00000000e+00],
       [-5.39680694e-19,  3.70880539e-17,  6.06989046e-01,
        -1.31511318e-15, -4.61274033e-16],
       [-4.35236225e-03, -1.29783114e-01,  7.27850033e-16,
         7.75336428e-01,  5.71096595e-16],
       [-5.96115289e-03,  1.58604264e-01,  4.45281085e-01,
        -2.78551350e-01, -3.39173634e-16],
       [-5.96115289e-03,  1.58604264e-01, -4.45281085e-01,
        -2.78551350e-01,  1.14884068e-16]]), 'mo_coeff_beta': array([[ 9.94132148e-01, -2.32780470e-01, -1.35336992e-16,
        -1.03397331e-01, -7.74861869e-17],
       [ 2.65420955e-02,  8.33658221e-01,  8.21619033e-16,
         5.37850536e-01,  3.58860084e-16

In [210]:
wf.parameters["wf2bcoeff"]

array([[-0.25, -0.5 , -0.25],
       [ 0.  ,  0.  ,  0.  ],
       [ 0.  ,  0.  ,  0.  ],
       [ 0.  ,  0.  ,  0.  ]])

In [190]:
wf, to_opt = pyq.generate_slater(mol, mf)
print(to_opt)

{'det_coeff': array([False])}


In [118]:
configs.configs.shape

(10, 2, 3)

In [119]:
wf.recompute(configs)
wf.pgradient()

WFmerger: {'wf1': {'det_coeff': array([[1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.]]), 'mo_coeff_alpha': array([[[1.24566881],
        [0.57634827]],

       [[1.30463059],
        [0.51738649]],

       [[1.34335137],
        [0.47866571]],

       [[1.16592712],
        [0.65608996]],

       [[1.53002564],
        [0.29199144]],

       [[1.43797103],
        [0.38404605]],

       [[0.37853466],
        [1.44348242]],

       [[1.53199415],
        [0.29002293]],

       [[1.28853676],
        [0.53348032]],

       [[0.30226444],
        [1.51975264]]]), 'mo_coeff_beta': array([[[1.10515955],
        [0.71685753]],

       [[1.37670789],
        [0.44530919]],

       [[0.76093306],
        [1.06108402]],

       [[0.95615285],
        [0.86586423]],

       [[0.29095306],
        [1.53106402]],

       [[0.32213727],
        [1.49987981]],

       [[0.56304859],
        [1.25896849]],

       [[0.7605009

In [89]:
trans = LinearTransform(wf.parameters)
dp = trans.serialize_gradients(wf.pgradient())

In [94]:
dp.shape

(10, 37)

In [12]:
print(ei_energy(mol, configs))
print(ee_energy(configs))
print(ii_energy(mol))
print(kinetic(configs, wf)[0])
total = ei_energy(mol, configs) + ee_energy(configs) + ii_energy(mol) + kinetic(configs, wf)[0]
print(total)

[-1.90453407 -2.4670126  -3.41080024 -2.10215388 -3.85402988]
[0.71597433 1.04066599 0.84595687 0.3591553  0.55762922]
0.7151043390810812
[-0.15626125  0.12732074  0.44068302 -0.14745074  0.66278783]
[-0.62971665 -0.58392153 -1.40905601 -1.17534499 -1.9185085 ]


In [123]:
wf._det_occup

AttributeError: 'MultiplyWF' object has no attribute '_det_occup'

In [121]:
wf.parameters["det_coeff"]

KeyError: 'det'

In [71]:
# ei_mean = np.mean(ei_energy(mol, configs), axis = 0)
# ee_mean = np.mean(ee_energy(configs), axis = 0)
# ii_mean = np.mean(ii_energy(mol), axis= 0)
# ke_mean = np.mean(kinetic(configs,wf)[0], axis= 0)

In [9]:
print(jax_ei_energy(coords, atom_charges, atom_coords))
print(jax_ee_energy(coords))
print(jax_ii_energy(mol))
print(jax_kinetic_energy(coords, mol, dets, inverse, mo_coeff, det_coeff, det_map, _nelec, occup_hash)[0])

total = jax_ei_energy(coords, atom_charges, atom_coords) + jax_ee_energy(coords) + jax_ii_energy(mol) + jax_kinetic_energy(coords, mol, dets, inverse, mo_coeff, det_coeff, det_map, _nelec, occup_hash)[0]

print(total)

[-1.90453407 -2.4670126  -3.41080024 -2.10215388 -3.85402988]
[0.71597433 1.04066599 0.84595687 0.3591553  0.55762922]
0.7151043390810812
[-0.09396923  0.21639611  0.53020582 -0.08242941  0.6450487 ]
[-0.56742463 -0.49484616 -1.31953321 -1.11032365 -1.93624763]


In [73]:
grad(config, wf)

{'ke': array([-0.15626125,  0.12732074,  0.44068302, -0.14745074,  0.66278783]),
 'ee': array([0.71597433, 1.04066599, 0.84595687, 0.3591553 , 0.55762922]),
 'ei': array([-1.90453407, -2.4670126 , -3.41080024, -2.10215388, -3.85402988]),
 'ecp': array([0., 0., 0., 0., 0.]),
 'grad2': array([3.1355144 , 2.86026606, 2.25569794, 2.78992666, 1.69145976]),
 'total': array([-0.62971665, -0.58392153, -1.40905601, -1.17534499, -1.9185085 ]),
 'dpH': array([[-0.51262811+0.j, -0.41507876+0.j, -0.29545804+0.j,
         -0.17682499+0.j, -0.09536131+0.j, -0.04601202+0.j,
         -0.02189997+0.j, -0.00986495+0.j, -0.3304717 +0.j,
         -0.24590547+0.j, -0.11480599+0.j, -0.07212592+0.j,
         -0.02712589+0.j, -0.01602704+0.j, -0.00567161+0.j,
         -0.00330289+0.j,  0.        +0.j, -0.48976118+0.j,
          0.        +0.j,  0.        +0.j, -0.26070892+0.j,
          0.        +0.j,  0.        +0.j, -0.07861113+0.j,
          0.        +0.j],
        [-0.37515328+0.j, -0.40110065+0.j, -0.15

In [2]:
import pyqmc.mc

df_vmc, coords = pyqmc.mc.vmc(
    wf,
    configs,
    accumulators={"pgrad": grad},
)

NameError: name 'wf' is not defined

In [26]:
df_vmc["pgradtotal"]

array([-1.11293998, -1.02229444, -1.03308748, -1.09977122, -1.27617765,
       -1.15337538, -1.17873983, -1.19174489, -1.02438249, -1.03846931])

In [27]:
df_vmc

{'pgradke': array([1.00914793, 1.18944444, 1.56201179, 1.5362026 , 1.0849196 ,
        0.61071501, 0.89528079, 1.15546622, 1.09190092, 1.05833295]),
 'pgradee': array([0.62042943, 0.87031513, 0.74346155, 0.69357557, 0.57780602,
        0.68911782, 0.61119074, 0.64290674, 0.65118102, 0.65430755]),
 'pgradei': array([-3.45762167, -3.79715835, -4.05366516, -4.04465373, -3.6540076 ,
        -3.16831255, -3.4003157 , -3.70522219, -3.48256876, -3.46621416]),
 'pgradecp': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]),
 'pgradgrad2': array([2.76291613, 2.72266334, 2.80891008, 2.69333234, 2.72107494,
        2.69002718, 2.73268775, 2.6970969 , 2.83886068, 2.80573356]),
 'pgradtotal': array([-1.11293998, -1.02229444, -1.03308748, -1.09977122, -1.27617765,
        -1.15337538, -1.17873983, -1.19174489, -1.02438249, -1.03846931]),
 'pgraddpH': array([[-0.90588663+0.j, -0.84892156+0.j, -0.597677  +0.j,
         -0.50483676+0.j, -0.27283644+0.j, -0.1981548 +0.j,
         -0.09140485+0.j, -0.053960

In [22]:
accumulators = {"pgrad" : grad}
for k, accumulator in accumulators.items():
    data =accumulator.avg(configs, wf)
    
data

{'ke': np.float64(0.5652793020204011),
 'ee': np.float64(0.5740119558498835),
 'ei': np.float64(-3.0198530590447406),
 'ecp': np.float64(0.0),
 'grad2': np.float64(2.773261665597687),
 'total': np.float64(-1.1654574620933746),
 'dpH': array([-0.87987491+0.j, -0.74025184+0.j, -0.4875116 +0.j, -0.37599864+0.j,
        -0.16883894+0.j, -0.12131195+0.j, -0.04129962+0.j, -0.0283496 +0.j,
        -0.94239756+0.j, -0.82914628+0.j, -0.62882745+0.j, -0.51376614+0.j,
        -0.27443396+0.j, -0.24584701+0.j, -0.07666104+0.j, -0.08244438+0.j,
         0.        +0.j, -0.67502344+0.j,  0.        +0.j,  0.        +0.j,
        -0.31708948+0.j,  0.        +0.j,  0.        +0.j, -0.10072625+0.j,
         0.        +0.j]),
 'dppsi': array([0.75912032+0.j, 0.61904342+0.j, 0.42332927+0.j, 0.31063856+0.j,
        0.14694685+0.j, 0.09894346+0.j, 0.03594173+0.j, 0.0230013 +0.j,
        0.79441994+0.j, 0.68174063+0.j, 0.5223659 +0.j, 0.41515284+0.j,
        0.22581728+0.j, 0.20058628+0.j, 0.06283257+0.j, 0.

In [15]:
df_vmc

{'pgradke': array([1.00914793, 1.18944444, 1.56201179, 1.5362026 , 1.0849196 ,
        0.61071501, 0.89528079, 1.15546622, 1.09190092, 1.05833295]),
 'pgradee': array([0.62042943, 0.87031513, 0.74346155, 0.69357557, 0.57780602,
        0.68911782, 0.61119074, 0.64290674, 0.65118102, 0.65430755]),
 'pgradei': array([-3.45762167, -3.79715835, -4.05366516, -4.04465373, -3.6540076 ,
        -3.16831255, -3.4003157 , -3.70522219, -3.48256876, -3.46621416]),
 'pgradecp': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]),
 'pgradgrad2': array([2.76291613, 2.72266334, 2.80891008, 2.69333234, 2.72107494,
        2.69002718, 2.73268775, 2.6970969 , 2.83886068, 2.80573356]),
 'pgradtotal': array([-1.11293998, -1.02229444, -1.03308748, -1.09977122, -1.27617765,
        -1.15337538, -1.17873983, -1.19174489, -1.02438249, -1.03846931]),
 'pgraddpH': array([[-0.90588663+0.j, -0.84892156+0.j, -0.597677  +0.j,
         -0.50483676+0.j, -0.27283644+0.j, -0.1981548 +0.j,
         -0.09140485+0.j, -0.053960

In [1]:
# data = {}
# for k in grad.keys():
#     data[k] = np.mean(df_vmc["pgrad" + k], axis=0)
    
df_vmc["pgradtotal"]₩1

NameError: name 'grad' is not defined

In [223]:
electr

np.float64(16.609640474436812)