In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys
import numpy as np
import scipy as sp
from jax import jit, lax, vmap
from functools import partial
import matplotlib.pyplot as plt

module_path = "../../ad_afqmc"
module_path = "../../projected_mf"

if module_path not in sys.path:
    sys.path.append(module_path)

from ad_afqmc import (
    driver,
    pyscf_interface,
    mpi_jax,
    linalg_utils,
    spin_utils,
    lattices,
    propagation,
    wavefunctions,
    hamiltonian,
    hf_guess
)

from projected_ghf import (
    build_rotchol,
    get_wigner_d,
    apply_Sz_projector,
    apply_S2_projector,
    get_real_wavefunction,
    get_energy,
    get_projected_energy,
    optimize
)

from projected_ghf_jax import (
    build_rotchol_jax,
    get_wigner_d_jax,
    apply_Sz_projector_jax,
    apply_S2_projector_jax,
    get_energy_jax,
    get_Sz_projected_energy_jax,
    optimize_jax
)

from pyscf import fci, gto, scf, mp, ao2mo
import jax.numpy as jnp

np.set_printoptions(precision=5, suppress=True)

# Hostname: g271
# System Type: Linux
# Machine Type: x86_64
# Processor: x86_64


In [3]:
@partial(jit, static_argnums=2)
def get_greens(bra, ket, verbose=False):
    bra = jnp.array(bra, dtype=jnp.complex128)
    ket = jnp.array(ket, dtype=jnp.complex128)
    norb = bra.shape[0] // 2
    omat = bra.T.conj() @ ket
    ovlp = jnp.linalg.det(omat)
    greens = jnp.zeros((2*norb, 2*norb), dtype=jnp.complex128)

    # define the two branches as lambdas:
    def true_branch(_):
        if verbose:
            # We can't use a normal print in jitted code, 
            # but jax.debug.print is allowed inside the function
            jax.debug.print("ovlp = {}", ovlp)
            
        return greens, ovlp

    def false_branch(_):
        new_greens = ket @ jnp.linalg.inv(omat) @ bra.T.conj()
        return new_greens.T, ovlp

    # jax.lax.cond takes a boolean condition 
    # (here, jnp.absolute(ovlp) < 1e-15), plus two functions:
    return lax.cond(
        jnp.absolute(ovlp) < 1e-15,
        true_branch,
        false_branch,
        operand=None
    )

# Test `get_wigner_d`

In [60]:
def test_get_wigner_d():
    beta = np.random.random() * np.pi
    
    s, m, k = 0, 0, 0
    ref = 1.
    np.testing.assert_allclose(ref, get_wigner_d(s, m, k, beta), atol=1e-5)
    
    s, m, k = 0.5, 0.5, 0.5
    ref = np.cos(beta/2.)
    np.testing.assert_allclose(ref, get_wigner_d(s, m, k, beta), atol=1e-5)
    
    s, m, k = 0.5, 0.5, -0.5
    ref = -np.sin(beta/2.)
    np.testing.assert_allclose(ref, get_wigner_d(s, m, k, beta), atol=1e-5)
    
    s, m, k = 1, 1, 1
    ref = (1. + np.cos(beta)) / 2.
    np.testing.assert_allclose(ref, get_wigner_d(s, m, k, beta), atol=1e-5)
    
    s, m, k = 1, 1, 0
    ref = -np.sin(beta) / np.sqrt(2.)
    np.testing.assert_allclose(ref, get_wigner_d(s, m, k, beta), atol=1e-5)
    
    s, m, k = 1, 1, -1
    ref = (1. - np.cos(beta)) / 2.
    np.testing.assert_allclose(ref, get_wigner_d(s, m, k, beta), atol=1e-5)
    
    s, m, k = 1.5, 1.5, 1.5
    ref = (1. + np.cos(beta)) * np.cos(beta/2.) / 2.
    np.testing.assert_allclose(ref, get_wigner_d(s, m, k, beta), atol=1e-5)
    
    s, m, k = 2, 1.5, 1.5
    ref = (1. + np.cos(beta))**2 / 4.
    np.testing.assert_allclose(ref, get_wigner_d(s, m, k, beta), atol=1e-5)

In [61]:
test_get_wigner_d()

# Test `apply_Sz_projector`

# Test with $|+\rangle = \frac{1}{\sqrt{2}}(|\uparrow\rangle + |\downarrow\rangle)$

In [41]:
norb = 3
nocc = 1
psi = np.array([[1., 0., 0.], 
                [0., 1., 0.], 
                [0., 0., 1.], 
                [1., 0., 0.], 
                [0., 1., 0.], 
                [0., 0., 1.]]) / np.sqrt(2.)

In [42]:
# Initial spin average and overlap.
G, ovlp = get_greens(psi[:, :nocc], psi[:, :nocc])
savg = spin_utils.get_spin_average_from_greens_ghf(G) * ovlp

print(f'<S> = {savg}')
print(f'ovlp = {ovlp}')

<S> = [0.5+0.j 0. +0.j 0. +0.j]
ovlp = (0.9999999999999998+0j)


## Project to $s_z = 1/2$

In [43]:
kets, coeffs = apply_Sz_projector_jax(psi[:, :nocc], 0.5, 500) # Should just give |u>.
kets = np.array(kets)
coeffs = np.array(coeffs)

In [44]:
# Projected spin average and overlap.
savg_tot = np.zeros(3, dtype=np.complex128)
ovlp_tot = 0.

for ibra, bra in enumerate(kets):
    for iket, ket in enumerate(kets):
        G, ovlp = get_greens(bra, ket)
        savg = spin_utils.get_spin_average_from_greens_ghf(G) * ovlp
        savg_tot += savg * coeffs[ibra].conj() * coeffs[iket]
        ovlp_tot += ovlp * coeffs[ibra].conj() * coeffs[iket]

In [45]:
print(f'<S> * ovlp = {savg_tot}')
print(f'ovlp = {ovlp_tot}')
print(f'\n<S> = {savg_tot / ovlp_tot}')

<S> * ovlp = [-0.   -0.j -0.   +0.j  0.249-0.j]
ovlp = (0.5000000000000908+2.1608183671036974e-18j)

<S> = [-0.   -0.j -0.   +0.j  0.498-0.j]


## Project to $s_z = -1/2$

In [46]:
kets, coeffs = apply_Sz_projector_jax(psi[:, :nocc], -0.5, 500) # Should just give |d>.
kets = np.array(kets)
coeffs = np.array(coeffs)

In [47]:
# Projected spin average and overlap.
savg_tot = np.zeros(3, dtype=np.complex128)
ovlp_tot = 0.

for ibra, bra in enumerate(kets):
    for iket, ket in enumerate(kets):
        G, ovlp = get_greens(bra, ket)
        savg = spin_utils.get_spin_average_from_greens_ghf(G) * ovlp
        savg_tot += savg * coeffs[ibra].conj() * coeffs[iket]
        ovlp_tot += ovlp * coeffs[ibra].conj() * coeffs[iket]

In [48]:
print(f'<S> * ovlp = {savg_tot}')
print(f'ovlp = {ovlp_tot}')
print(f'\n<S> = {savg_tot / ovlp_tot}')

<S> * ovlp = [-0.   +0.j -0.   -0.j -0.249-0.j]
ovlp = (0.5000000000000908-2.1608183671036974e-18j)

<S> = [-0.   +0.j -0.   -0.j -0.498-0.j]


## Project to $s_z = 0$

We apply the projector 
\begin{align}
    \hat{P}_0 |+\rangle &= \frac{1}{\sqrt{2}} \frac{1}{2\pi} \int_0^{2\pi} e^{i\phi\hat{S}_z} d\phi \left( |\uparrow\rangle + |\downarrow\rangle \right) \\
    &= \frac{1}{\sqrt{2}} \cdot \frac{1}{2\pi} \left\{ \int_0^{2\pi} e^{i\phi\hat{S}_z} |\uparrow\rangle d\phi + \int_0^{2\pi} e^{i\phi\hat{S}_z} d\phi |\downarrow\rangle d\phi \right\} \\
    &= \frac{1}{\sqrt{2}} \cdot \frac{1}{2\pi} \left\{ \int_0^{2\pi} e^{i\phi/2} |\uparrow\rangle d\phi + \int_0^{2\pi} e^{-i\phi/2} d\phi |\downarrow\rangle d\phi \right\} \\
    &= \frac{1}{\sqrt{2}} \cdot \frac{1}{2\pi} \left\{ \left[\frac{2}{i} e^{i\phi/2}\right]_0^{2\pi} |\uparrow\rangle + \left[-\frac{2}{i} e^{-i\phi/2}\right]_0^{2\pi} |\downarrow\rangle  \right\} \\
    &= \frac{1}{\sqrt{2}} \cdot \frac{1}{2\pi} \left\{ \frac{2}{i} \left(e^{i\phi} - 1\right) |\uparrow\rangle + -\frac{2}{i} \left(e^{-i\phi} - 1\right) |\downarrow\rangle  \right\} \\
    &= \frac{1}{\sqrt{2}} \cdot \frac{1}{2\pi} \left\{ -\frac{4}{i} |\uparrow\rangle + \frac{4}{i} |\downarrow\rangle  \right\} \\
    &= \frac{1}{\sqrt{2}} \cdot \frac{2i}{\pi} \left\{ |\uparrow\rangle - |\downarrow\rangle  \right\} \\
    &= \frac{2i}{\pi} |-\rangle,
\end{align}

which gives
\begin{align}
    \langle \hat{S}_x \rangle &= \left(\frac{2i}{\pi}\right)^2 \langle - | \hat{S}_x | - \rangle = \left(\frac{2i}{\pi}\right)^2 \cdot \frac{1}{2} = -\frac{2}{\pi^2}.
\end{align}

In [49]:
kets, coeffs = apply_Sz_projector_jax(psi[:, :nocc], 0., 1000)
kets = np.array(kets)
coeffs = np.array(coeffs)

In [50]:
# Projected spin average and overlap.
savg_tot = np.zeros(3, dtype=np.complex128)
ovlp_tot = 0.

for ibra, bra in enumerate(kets):
    for iket, ket in enumerate(kets):
        G, ovlp = get_greens(bra, ket)
        savg = spin_utils.get_spin_average_from_greens_ghf(G) * ovlp
        savg_tot += savg * coeffs[ibra].conj() * coeffs[iket]
        ovlp_tot += ovlp * coeffs[ibra].conj() * coeffs[iket]

In [51]:
print(f'<S> * ovlp = {savg_tot}')
print(f'ovlp = {ovlp_tot}')
print(f'\n<S> = {savg_tot / ovlp_tot}')

<S> * ovlp = [-0.20232+0.j -0.00064+0.j  0.     -0.j]
ovlp = (0.4052850679027681+0j)

<S> = [-0.49921+0.j -0.00157+0.j  0.     -0.j]


In [52]:
print(-2. / np.pi**2)

-0.20264236728467555


# Test with $|++\rangle = \frac{1}{2}(|\uparrow\uparrow\rangle + |\uparrow\downarrow\rangle + |\downarrow\uparrow\rangle + |\downarrow\downarrow\rangle)$

In [34]:
norb = 3
nocc = 2
psi = np.array([[1., 0., 0.], 
                [0., 1., 0.], 
                [0., 0., 1.], 
                [1., 0., 0.], 
                [0., 1., 0.], 
                [0., 0., 1.]]) / np.sqrt(2.)

In [35]:
# Initial spin average and overlap.
G, ovlp = get_greens(psi[:, :nocc], psi[:, :nocc])
savg = spin_utils.get_spin_average_from_greens_ghf(G) * ovlp

print(f'<S> = {savg}')
print(f'ovlp = {ovlp}')

<S> = [1. 0. 0.]
ovlp = 0.9999999999999996


## Project to $s_z = 1$

In [36]:
kets, coeffs = apply_Sz_projector_jax(psi[:, :nocc], 1., 10) # Should just give |uu>.
kets = np.array(kets)
coeffs = np.array(coeffs)

In [37]:
# Projected spin average and overlap.
savg_tot = np.zeros(3, dtype=np.complex128)
ovlp_tot = 0.

for ibra, bra in enumerate(kets):
    for iket, ket in enumerate(kets):
        G, ovlp = get_greens(bra, ket)
        savg = spin_utils.get_spin_average_from_greens_ghf(G) * ovlp
        savg_tot += savg * coeffs[ibra].conj() * coeffs[iket]
        ovlp_tot += ovlp * coeffs[ibra].conj() * coeffs[iket]

In [38]:
print(f'<S> * ovlp = {savg_tot}')
print(f'ovlp = {ovlp_tot}')
print(f'\n<S> = {savg_tot / ovlp_tot}')

<S> * ovlp = [-0.  -0.j -0.  -0.j  0.25+0.j]
ovlp = (0.2499999999999999+2.6020852139652106e-18j)

<S> = [-0.-0.j -0.-0.j  1.-0.j]


## Project to $s_z = -1$

In [39]:
kets, coeffs = apply_Sz_projector_jax(psi[:, :nocc], -1., 10) # Should just give |dd>.
kets = np.array(kets)
coeffs = np.array(coeffs)

In [40]:
# Projected spin average and overlap.
savg_tot = np.zeros(3, dtype=np.complex128)
ovlp_tot = 0.

for ibra, bra in enumerate(kets):
    for iket, ket in enumerate(kets):
        G, ovlp = get_greens(bra, ket)
        savg = spin_utils.get_spin_average_from_greens_ghf(G) * ovlp
        savg_tot += savg * coeffs[ibra].conj() * coeffs[iket]
        ovlp_tot += ovlp * coeffs[ibra].conj() * coeffs[iket]

In [41]:
print(f'<S> * ovlp = {savg_tot}')
print(f'ovlp = {ovlp_tot}')
print(f'\n<S> = {savg_tot / ovlp_tot}')

<S> * ovlp = [-0.  +0.j -0.  +0.j -0.25+0.j]
ovlp = (0.2499999999999999-2.6020852139652106e-18j)

<S> = [-0.+0.j -0.+0.j -1.-0.j]


## Project to $s_z = 0$

In [42]:
kets, coeffs = apply_Sz_projector_jax(psi[:, :nocc], 0., 10) # Should give |ud> + |du>.
kets = np.array(kets)
coeffs = np.array(coeffs)

In [43]:
# Projected spin average and overlap.
savg_tot = np.zeros(3, dtype=np.complex128)
ovlp_tot = 0.

for ibra, bra in enumerate(kets):
    for iket, ket in enumerate(kets):
        G, ovlp = get_greens(bra, ket)
        savg = spin_utils.get_spin_average_from_greens_ghf(G) * ovlp
        savg_tot += savg * coeffs[ibra].conj() * coeffs[iket]
        ovlp_tot += ovlp * coeffs[ibra].conj() * coeffs[iket]

In [44]:
print(f'<S> * ovlp = {savg_tot}')
print(f'ovlp = {ovlp_tot}')
print(f'\n<S> = {savg_tot / ovlp_tot}')

<S> * ovlp = [-0.+0.j -0.+0.j  0.-0.j]
ovlp = (0.5+0j)

<S> = [-0.+0.j -0.+0.j  0.-0.j]


# Test with $|+++\rangle$

In [45]:
norb = 3
nocc = 3
psi = np.array([[1., 0., 0.], 
                [0., 1., 0.], 
                [0., 0., 1.], 
                [1., 0., 0.], 
                [0., 1., 0.], 
                [0., 0., 1.]]) / np.sqrt(2.)

In [46]:
# Initial spin average and overlap.
G, ovlp = get_greens(psi[:, :nocc], psi[:, :nocc])
savg = spin_utils.get_spin_average_from_greens_ghf(G) * ovlp

print(f'<S> = {savg}')
print(f'ovlp = {ovlp}')

<S> = [1.5 0.  0. ]
ovlp = 0.9999999999999993


## Project to $s_z = 1.5$

In [47]:
kets, coeffs = apply_Sz_projector_jax(psi[:, :nocc], 1.5, 10) # Should just give |uuu>.
kets = np.array(kets)
coeffs = np.array(coeffs)

In [48]:
# Projected spin average and overlap.
savg_tot = np.zeros(3, dtype=np.complex128)
ovlp_tot = 0.

for ibra, bra in enumerate(kets):
    for iket, ket in enumerate(kets):
        G, ovlp = get_greens(bra, ket)
        savg = spin_utils.get_spin_average_from_greens_ghf(G) * ovlp
        savg_tot += savg * coeffs[ibra].conj() * coeffs[iket]
        ovlp_tot += ovlp * coeffs[ibra].conj() * coeffs[iket]

In [49]:
print(f'<S> * ovlp = {savg_tot}')
print(f'ovlp = {ovlp_tot}')
print(f'\n<S> = {savg_tot / ovlp_tot}')

<S> * ovlp = [-0.    +0.j -0.    -0.j  0.1875+0.j]
ovlp = (0.12499999999999996-8.673617379884035e-19j)

<S> = [-0. +0.j -0. -0.j  1.5+0.j]


## Project to $s_z = 0.5$

In [50]:
kets, coeffs = apply_Sz_projector_jax(psi[:, :nocc], 0.5, 10) # Should give |uud> + |udu> + |duu>.
kets = np.array(kets)
coeffs = np.array(coeffs)

In [51]:
# Projected spin average and overlap.
savg_tot = np.zeros(3, dtype=np.complex128)
ovlp_tot = 0.

for ibra, bra in enumerate(kets):
    for iket, ket in enumerate(kets):
        G, ovlp = get_greens(bra, ket)
        savg = spin_utils.get_spin_average_from_greens_ghf(G) * ovlp
        savg_tot += savg * coeffs[ibra].conj() * coeffs[iket]
        ovlp_tot += ovlp * coeffs[ibra].conj() * coeffs[iket]

In [52]:
print(f'<S> * ovlp = {savg_tot}')
print(f'ovlp = {ovlp_tot}')
print(f'\n<S> = {savg_tot / ovlp_tot}')

<S> * ovlp = [-0.    -0.j -0.    -0.j  0.1875-0.j]
ovlp = (0.3749999999999997-1.734723475976807e-18j)

<S> = [-0. -0.j -0. -0.j  0.5-0.j]


## Project to $s_z = -0.5$

In [53]:
kets, coeffs = apply_Sz_projector_jax(psi[:, :nocc], -0.5, 10) # Should give |ddu> + |dud> + |udd>.
kets = np.array(kets)
coeffs = np.array(coeffs)

In [54]:
# Projected spin average and overlap.
savg_tot = np.zeros(3, dtype=np.complex128)
ovlp_tot = 0.

for ibra, bra in enumerate(kets):
    for iket, ket in enumerate(kets):
        G, ovlp = get_greens(bra, ket)
        savg = spin_utils.get_spin_average_from_greens_ghf(G) * ovlp
        savg_tot += savg * coeffs[ibra].conj() * coeffs[iket]
        ovlp_tot += ovlp * coeffs[ibra].conj() * coeffs[iket]

In [55]:
print(f'<S> * ovlp = {savg_tot}')
print(f'ovlp = {ovlp_tot}')
print(f'\n<S> = {savg_tot / ovlp_tot}')

<S> * ovlp = [-0.    +0.j -0.    +0.j -0.1875-0.j]
ovlp = (0.3749999999999997+1.734723475976807e-18j)

<S> = [-0. +0.j -0. +0.j -0.5-0.j]


## Project to $s_z = -1.5$

In [56]:
kets, coeffs = apply_Sz_projector_jax(psi[:, :nocc], -1.5, 10) # Should just give |ddd>.
kets = np.array(kets)
coeffs = np.array(coeffs)

In [57]:
# Projected spin average and overlap.
savg_tot = np.zeros(3, dtype=np.complex128)
ovlp_tot = 0.

for ibra, bra in enumerate(kets):
    for iket, ket in enumerate(kets):
        G, ovlp = get_greens(bra, ket)
        savg = spin_utils.get_spin_average_from_greens_ghf(G) * ovlp
        savg_tot += savg * coeffs[ibra].conj() * coeffs[iket]
        ovlp_tot += ovlp * coeffs[ibra].conj() * coeffs[iket]

In [58]:
print(f'<S> * ovlp = {savg_tot}')
print(f'ovlp = {ovlp_tot}')
print(f'\n<S> = {savg_tot / ovlp_tot}')

<S> * ovlp = [-0.    -0.j -0.    +0.j -0.1875+0.j]
ovlp = (0.12499999999999996+8.673617379884035e-19j)

<S> = [-0. -0.j -0. +0.j -1.5+0.j]


# Test `apply_S2_projector`

In [31]:
def calc_s(s2):
    return (-1. + jnp.sqrt(1+4*s2)) / 2.

def calc_multiplicity(s):
    return 2*s + 1

In [32]:
@jit
def build_pairwise_greens(bra_array, ket_array):
    # bra_array: [nbra, ...]
    # ket_array: [nket, ...]
    # we want G, ovlp for each pair (i, j)

    # Example: vmap the function over bra, then over ket
    # This will produce a result of shape [nbra, nket, ...]
    # Pseudocode only - depends on how your "get_greens" is defined.

    @jit
    def greens_over_ket(bra, ket_array):
        return vmap(lambda ket: get_greens(bra, ket))(ket_array)
    
    # Now we get G, ovlp for each (bra,ket)
    Gs, ovlps = vmap(greens_over_ket, in_axes=(0, None))(bra_array, ket_array)
    return Gs, ovlps

@jit
def get_projected_spin_square(kets, coeffs):
    Gs, ovlps = build_pairwise_greens(kets, kets)
    # Gs and ovlps each have shape [nkets, nkets, ...]
    # For each (i, j), we have G_ij, ovlp_ij

    # Next, compute spin squares for each pair. 
    # We can vectorize spin_utils.get_spin_square_from_greens_ghf if needed.
    # s2_ij now shape [nkets, nkets]
    s2_ij = vmap(
                vmap(
                    lambda G, ovlp: spin_utils.get_spin_square_from_greens_ghf(G) * ovlp,
                    in_axes=(0, 0)
                ), in_axes=(0, 0)
            )(Gs, ovlps)
    
    # Build outer products of coeffs
    coeff_matrix = jnp.outer(coeffs.conj(), coeffs)
    ovlp_tot = jnp.sum(ovlps * coeff_matrix)
    s2_tot = jnp.sum(s2_ij * coeff_matrix) / ovlp_tot
    s_tot = calc_s(s2_tot)
    return s2_tot, ovlp_tot, s_tot

@jit
def get_projected_spin_average(kets, coeffs):
    Gs, ovlps = build_pairwise_greens(kets, kets)
    # Gs and ovlps each have shape [nkets, nkets, ...]
    # For each (i, j), we have G_ij, ovlp_ij

    # Next, compute spin squares for each pair. 
    # We can vectorize spin_utils.get_spin_square_from_greens_ghf if needed.
    # savg_ij now shape [nkets, nkets, 3]
    savg_ij = vmap(
                vmap(
                    lambda G, ovlp: spin_utils.get_spin_average_from_greens_ghf(G) * ovlp,
                    in_axes=(0, 0)
                ), in_axes=(0, 0)
            )(Gs, ovlps)
    
    # Build outer products of coeffs
    coeff_matrix = jnp.outer(coeffs.conj(), coeffs)
    ovlp_tot = jnp.sum(ovlps * coeff_matrix)
    coeff_matrix = coeff_matrix.ravel()
    savg_ij = savg_ij.transpose(2, 0, 1).reshape(3, -1)
    savg_tot = jnp.sum(savg_ij * coeff_matrix, axis=-1) / ovlp_tot
    return savg_tot, ovlp_tot

# Test with $|\uparrow \downarrow \rangle = \frac{1}{\sqrt{2}}(|s=0, s_z=0 \rangle + |s=1, s_z=0\rangle)$

**Note that $s_z$ is a good quantum number for this state, but not $s$.**

In [6]:
norb = 3
nocc = 2
psi = np.array([[1., 0.], 
                [0., 0.], 
                [0., 0.], 
                [0., 0.], 
                [0., 1.], 
                [0., 0.]])

In [7]:
# Initial spin average and overlap.
G, ovlp = get_greens(psi[:, :nocc], psi[:, :nocc])
savg = spin_utils.get_spin_average_from_greens_ghf(G) * ovlp
s2 = spin_utils.get_spin_square_from_greens_ghf(G) * ovlp
s = calc_s(s2)

print(f'<S> = {savg}')
print(f'<S^2> = {s2}, S = {s}, 2S+1 = {calc_multiplicity(s)}')
print(f'ovlp = {ovlp}')

<S> = [0.+0.j 0.+0.j 0.+0.j]
<S^2> = (1+0j), S = (0.6180339887498949+0j), 2S+1 = (2.23606797749979+0j)
ovlp = (1+0j)


## Project to $s = 0$

In [8]:
ngrid_z = 6
ngrid_y = 6
s = sz = 0
kets, coeffs = apply_S2_projector(psi[:, :nocc], s, sz, ngrid_z, ngrid_y)
kets = jnp.array(kets).reshape(-1, *kets.shape[4:])
coeffs = jnp.array(coeffs).reshape(-1)

In [9]:
s2_tot, ovlp_tot, s_tot = get_projected_spin_square(kets, coeffs)
savg_tot, _ = get_projected_spin_average(kets, coeffs)

In [10]:
print(f'<S^2> * ovlp = {s2_tot * ovlp_tot}')
print(f'ovlp = {ovlp_tot}')
print(f'<S^2> = {s2_tot}, S = {s_tot}, 2S+1 = {calc_multiplicity(s_tot)}')
print(f'\n<S> = {savg_tot}')

<S^2> * ovlp = (0.03153703703703694+3.5495585291148316e-20j)
ovlp = (0.5-3.3101920919155616e-35j)
<S^2> = (0.06307407407407388+7.099117058229663e-20j), S = (0.05953022623811299+6.343819087271769e-20j), 2S+1 = (1.119060452476226+1.2687638174543537e-19j)

<S> = [ 0.+0.j  0.+0.j -0.-0.j]


In [11]:
ngrid_z = 10
ngrid_y = 10
s = sz = 0
kets, coeffs = apply_S2_projector(psi[:, :nocc], s, sz, ngrid_z, ngrid_y)
kets = jnp.array(kets).reshape(-1, *kets.shape[4:])
coeffs = jnp.array(coeffs).reshape(-1)

In [12]:
s2_tot, ovlp_tot, s_tot = get_projected_spin_square(kets, coeffs)
savg_tot, _ = get_projected_spin_average(kets, coeffs)

In [13]:
print(f'<S^2> * ovlp = {s2_tot * ovlp_tot}')
print(f'ovlp = {ovlp_tot}')
print(f'<S^2> = {s2_tot}, S = {s_tot}, 2S+1 = {calc_multiplicity(s_tot)}')
print(f'\n<S> = {savg_tot}')

<S^2> * ovlp = (0.011730937083828217-1.1737598687488631e-20j)
ovlp = (0.49999999999999994+6.817867234769268e-36j)
<S^2> = (0.023461874167656438-2.3475197374977265e-20j), S = (0.022935822226452962-2.2445581634691995e-20j), 2S+1 = (1.045871644452906-4.489116326938399e-20j)

<S> = [0.+0.j 0.+0.j 0.+0.j]


In [14]:
ngrid_z = 15
ngrid_y = 15
s = sz = 0
kets, coeffs = apply_S2_projector(psi[:, :nocc], s, sz, ngrid_z, ngrid_y)
kets = jnp.array(kets).reshape(-1, *kets.shape[4:])
coeffs = jnp.array(coeffs).reshape(-1)

In [15]:
s2_tot, ovlp_tot, s_tot = get_projected_spin_square(kets, coeffs)
savg_tot, _ = get_projected_spin_average(kets, coeffs)

In [16]:
print(f'<S^2> * ovlp = {s2_tot * ovlp_tot}')
print(f'ovlp = {ovlp_tot}')
print(f'<S^2> = {s2_tot}, S = {s_tot}, 2S+1 = {calc_multiplicity(s_tot)}')
print(f'\n<S> = {savg_tot}')

<S^2> * ovlp = (-5.724587470723463e-17-2.058612882858352e-21j)
ovlp = (0.4999999999999996+2.82118644197349e-37j)
<S^2> = (-1.1449174941446937e-16-4.117225765716707e-21j), S = (-1.1102230246251565e-16-4.117225765716708e-21j), 2S+1 = (0.9999999999999998-8.234451531433415e-21j)

<S> = [ 0.+0.j  0.+0.j -0.-0.j]


## Project to $s = 1$

In [75]:
ngrid_z = 6
ngrid_y = 6
s = 1
sz = 0
kets, coeffs = apply_S2_projector(psi[:, :nocc], s, sz, ngrid_z, ngrid_y)
kets = jnp.array(kets).reshape(-1, *kets.shape[4:])
coeffs = jnp.array(coeffs).reshape(-1)

In [76]:
s2_tot, ovlp_tot, s_tot = get_projected_spin_square(kets, coeffs)
savg_tot, _ = get_projected_spin_average(kets, coeffs)

In [77]:
print(f'<S^2> * ovlp = {s2_tot * ovlp_tot}')
print(f'ovlp = {ovlp_tot}')
print(f'<S^2> = {s2_tot}, S = {s_tot}, 2S+1 = {calc_multiplicity(s_tot)}')
print(f'\n<S> = {savg_tot}')

<S^2> * ovlp = (0.9288333333333321+2.8801799708123068e-18j)
ovlp = (0.49999999999999917+2.5176498693874662e-18j)
<S^2> = (1.8576666666666672-3.593548539772962e-18j), S = (0.9517805160101396-1.2376349248882824e-18j), 2S+1 = (2.903561032020279-2.4752698497765648e-18j)

<S> = [ 0.+0.j  0.+0.j -0.-0.j]


In [78]:
ngrid_z = 10
ngrid_y = 10
s = 1
sz = 0
kets, coeffs = apply_S2_projector(psi[:, :nocc], s, sz, ngrid_z, ngrid_y)
kets = jnp.array(kets).reshape(-1, *kets.shape[4:])
coeffs = jnp.array(coeffs).reshape(-1)

In [79]:
s2_tot, ovlp_tot, s_tot = get_projected_spin_square(kets, coeffs)
savg_tot, _ = get_projected_spin_average(kets, coeffs)

In [80]:
print(f'<S^2> * ovlp = {s2_tot * ovlp_tot}')
print(f'ovlp = {ovlp_tot}')
print(f'<S^2> = {s2_tot}, S = {s_tot}, 2S+1 = {calc_multiplicity(s_tot)}')
print(f'\n<S> = {savg_tot}')

<S^2> * ovlp = (0.9736114917395521+3.485709984540897e-17j)
ovlp = (0.4999999999999992+5.6107462426124854e-18j)
<S^2> = (1.9472229834791073+4.786345161464986e-17j), S = (0.9823032697390595+1.614495919687056e-17j), 2S+1 = (2.964606539478119+3.228991839374112e-17j)

<S> = [ 0.+0.j  0.+0.j -0.-0.j]


In [81]:
ngrid_z = 15
ngrid_y = 15
s = 1
sz = 0
kets, coeffs = apply_S2_projector(psi[:, :nocc], s, sz, ngrid_z, ngrid_y)
kets = jnp.array(kets).reshape(-1, *kets.shape[4:])
coeffs = jnp.array(coeffs).reshape(-1)

In [82]:
s2_tot, ovlp_tot, s_tot = get_projected_spin_square(kets, coeffs)
savg_tot, _ = get_projected_spin_average(kets, coeffs)

In [83]:
print(f'<S^2> * ovlp = {s2_tot * ovlp_tot}')
print(f'ovlp = {ovlp_tot}')
print(f'<S^2> = {s2_tot}, S = {s_tot}, 2S+1 = {calc_multiplicity(s_tot)}')
print(f'\n<S> = {savg_tot}')

<S^2> * ovlp = (1.0000000000000053+4.553649124439119e-18j)
ovlp = (0.5000000000000027+5.963111948670274e-19j)
<S^2> = (2+6.722053469410092e-18j), S = (1+2.240684489803364e-18j), 2S+1 = (3+4.481368979606728e-18j)

<S> = [0.+0.j 0.+0.j 0.+0.j]


# Test with $|+ + \rangle = \frac{1}{2}|s=1, s_z=1 \rangle + \frac{1}{\sqrt{2}}|s=1, s_z=0\rangle + \frac{1}{2}|s=1, s_z=-1 \rangle$

**Note that $s$ is a good quantum number for this state, but not $s_z$.**

In [49]:
norb = 3
nocc = 2
psi = np.array([[1., 0.], 
                [0., 1.], 
                [0., 0.], 
                [1., 0.], 
                [0., 1.], 
                [0., 0.]]) / np.sqrt(2.)

In [50]:
# Initial spin average and overlap.
G, ovlp = get_greens(psi[:, :nocc], psi[:, :nocc])
savg = spin_utils.get_spin_average_from_greens_ghf(G) * ovlp
s2 = spin_utils.get_spin_square_from_greens_ghf(G) * ovlp
s = calc_s(s2)

print(f'<S> = {savg}')
print(f'<S^2> = {s2}, S = {s}, 2S+1 = {calc_multiplicity(s)}')
print(f'ovlp = {ovlp}')

<S> = [1.+0.j 0.+0.j 0.+0.j]
<S^2> = (1.9999999999999991+0j), S = (0.9999999999999998+0j), 2S+1 = (2.9999999999999996+0j)
ovlp = (0.9999999999999996+0j)


## Project to $s = 1, s_z = -1$

In [28]:
ngrid_z = 5
ngrid_y = 5
s = 1
sz = -1
kets, coeffs = apply_S2_projector(psi[:, :nocc], s, sz, ngrid_z, ngrid_y)
kets = jnp.array(kets).reshape(-1, *kets.shape[4:])
coeffs = jnp.array(coeffs).reshape(-1)

In [29]:
s2_tot, ovlp_tot, s_tot = get_projected_spin_square(kets, coeffs)
savg_tot, _ = get_projected_spin_average(kets, coeffs)

In [30]:
print(f'<S^2> * ovlp = {s2_tot * ovlp_tot}')
print(f'ovlp = {ovlp_tot}')
print(f'<S^2> = {s2_tot}, S = {s_tot}, 2S+1 = {calc_multiplicity(s_tot)}')
print(f'\n<S> = {savg_tot}')

<S^2> * ovlp = (0.17157287525381051-8.890457814381135e-17j)
ovlp = (0.08578643762690524-1.713039432527097e-17j)
<S^2> = (2.0000000000000004-6.369746897629824e-16j), S = (1.0000000000000002-2.123248965876608e-16j), 2S+1 = (3.0000000000000004-4.246497931753216e-16j)

<S> = [ 0.-0.j -0.-0.j -1.-0.j]


In [31]:
ngrid_z = 10
ngrid_y = 10
s = 1
sz = -1
kets, coeffs = apply_S2_projector(psi[:, :nocc], s, sz, ngrid_z, ngrid_y)
kets = jnp.array(kets).reshape(-1, *kets.shape[4:])
coeffs = jnp.array(coeffs).reshape(-1)

In [32]:
s2_tot, ovlp_tot, s_tot = get_projected_spin_square(kets, coeffs)
savg_tot, _ = get_projected_spin_average(kets, coeffs)

In [33]:
print(f'<S^2> * ovlp = {s2_tot * ovlp_tot}')
print(f'ovlp = {ovlp_tot}')
print(f'<S^2> = {s2_tot}, S = {s_tot}, 2S+1 = {calc_multiplicity(s_tot)}')
print(f'\n<S> = {savg_tot}')

<S^2> * ovlp = (0.17157287525380804+7.323785675139582e-17j)
ovlp = (0.08578643762690408+3.1712913545201005e-17j)
<S^2> = (1.9999999999999987+1.143773996499027e-16j), S = (0.9999999999999996+3.8125799883300904e-17j), 2S+1 = (2.999999999999999+7.625159976660181e-17j)

<S> = [-0.+0.j  0.-0.j -1.+0.j]


## Project to $s = 1, s_z = 1$

In [34]:
ngrid_z = 5
ngrid_y = 5
s = 1
sz = 1
kets, coeffs = apply_S2_projector(psi[:, :nocc], s, sz, ngrid_z, ngrid_y)
kets = jnp.array(kets).reshape(-1, *kets.shape[4:])
coeffs = jnp.array(coeffs).reshape(-1)

In [35]:
s2_tot, ovlp_tot, s_tot = get_projected_spin_square(kets, coeffs)
savg_tot, _ = get_projected_spin_average(kets, coeffs)

In [36]:
print(f'<S^2> * ovlp = {s2_tot * ovlp_tot}')
print(f'ovlp = {ovlp_tot}')
print(f'<S^2> = {s2_tot}, S = {s_tot}, 2S+1 = {calc_multiplicity(s_tot)}')
print(f'\n<S> = {savg_tot}')

<S^2> * ovlp = (0.17157287525380943+1.52655665885959e-16j)
ovlp = (0.0857864376269047+8.673617379884035e-17j)
<S^2> = (2.0000000000000004-2.426570246716147e-16j), S = (1.0000000000000002-8.088567489053822e-17j), 2S+1 = (3.0000000000000004-1.6177134978107645e-16j)

<S> = [0.+0.j 0.-0.j 1.-0.j]


In [37]:
ngrid_z = 10
ngrid_y = 10
s = 1
sz = 1
kets, coeffs = apply_S2_projector(psi[:, :nocc], s, sz, ngrid_z, ngrid_y)
kets = jnp.array(kets).reshape(-1, *kets.shape[4:])
coeffs = jnp.array(coeffs).reshape(-1)

In [38]:
s2_tot, ovlp_tot, s_tot = get_projected_spin_square(kets, coeffs)
savg_tot, _ = get_projected_spin_average(kets, coeffs)

In [39]:
print(f'<S^2> * ovlp = {s2_tot * ovlp_tot}')
print(f'ovlp = {ovlp_tot}')
print(f'<S^2> = {s2_tot}, S = {s_tot}, 2S+1 = {calc_multiplicity(s_tot)}')
print(f'\n<S> = {savg_tot}')

<S^2> * ovlp = (0.1715728752538096-1.984089975648473e-17j)
ovlp = (0.0857864376269048-2.1575623232461538e-17j)
<S^2> = (2+2.717253140854007e-16j), S = (1+9.057510469513358e-17j), 2S+1 = (3+1.8115020939026715e-16j)

<S> = [-0.-0.j  0.+0.j  1.+0.j]


## Project to $s = 1, s_z = 0$

In [40]:
ngrid_z = 5
ngrid_y = 5
s = 1
sz = 0
kets, coeffs = apply_S2_projector(psi[:, :nocc], s, sz, ngrid_z, ngrid_y)
kets = jnp.array(kets).reshape(-1, *kets.shape[4:])
coeffs = jnp.array(coeffs).reshape(-1)

In [41]:
s2_tot, ovlp_tot, s_tot = get_projected_spin_square(kets, coeffs)
savg_tot, _ = get_projected_spin_average(kets, coeffs)

In [42]:
print(f'<S^2> * ovlp = {s2_tot * ovlp_tot}')
print(f'ovlp = {ovlp_tot}')
print(f'<S^2> = {s2_tot}, S = {s_tot}, 2S+1 = {calc_multiplicity(s_tot)}')
print(f'\n<S> = {savg_tot}')

<S^2> * ovlp = (0.17157287525380932-1.214306433183765e-16j)
ovlp = (0.08578643762690455-2.7755575615628914e-17j)
<S^2> = (2.0000000000000027-7.684139114601113e-16j), S = (1.0000000000000009-2.5613797048670364e-16j), 2S+1 = (3.0000000000000018-5.122759409734073e-16j)

<S> = [ 0.-0.j -0.+0.j -0.+0.j]


## Project to $s = 0, s_z = 0$

In [51]:
ngrid_z = 5
ngrid_y = 5
s = 0
sz = 0
kets, coeffs = apply_S2_projector(psi[:, :nocc], s, sz, ngrid_z, ngrid_y)
kets = jnp.array(kets).reshape(-1, *kets.shape[4:])
coeffs = jnp.array(coeffs).reshape(-1)

[ 25.01734  57.4205   90.      122.5795  154.98266]


In [52]:
s2_tot, ovlp_tot, s_tot = get_projected_spin_square(kets, coeffs)
savg_tot, _ = get_projected_spin_average(kets, coeffs)

In [53]:
print(f'<S^2> * ovlp = {s2_tot * ovlp_tot}')
print(f'ovlp = {ovlp_tot}')
print(f'<S^2> = {s2_tot}, S = {s_tot}, 2S+1 = {calc_multiplicity(s_tot)}')
print(f'\n<S> = {savg_tot}')

<S^2> * ovlp = (-8.977193988179978e-17+6.5052130349130266e-18j)
ovlp = (-5.572799166575493e-17+1.1926223897340549e-18j)
<S^2> = (1.6126545019805765-0.08221945618329506j), S = (0.86512320798114-0.03011429873237894j), 2S+1 = (2.73024641596228-0.06022859746475788j)

<S> = [-0.00004+0.00194j -0.01175+0.00364j  0.00025-0.01167j]


### Overlap is 0!

# Test with $|+ - \rangle = \frac{1}{2}|s=1, s_z=1 \rangle - \frac{1}{\sqrt{2}}|s=0, s_z=0\rangle + \frac{1}{2}|s=1, s_z=-1 \rangle$

**Note that neither $s_z$ or $s$ are a good quantum numbers for this state.**

In [23]:
norb = 3
nocc = 2
psi = np.array([[1., 0.], 
                [0., 1.], 
                [0., 0.], 
                [1., 0.], 
                [0., -1.], 
                [0., 0.]]) / np.sqrt(2.)

In [24]:
# Initial spin average and overlap.
G, ovlp = get_greens(psi[:, :nocc], psi[:, :nocc])
savg = spin_utils.get_spin_average_from_greens_ghf(G) * ovlp
s2 = spin_utils.get_spin_square_from_greens_ghf(G) * ovlp
s = calc_s(s2)

print(f'<S> = {savg}')
print(f'<S^2> = {s2}, S = {s}, 2S+1 = {calc_multiplicity(s)}')
print(f'ovlp = {ovlp}')

<S> = [0.+0.j 0.+0.j 0.+0.j]
<S^2> = (0.9999999999999996+0j), S = (0.6180339887498947+0j), 2S+1 = (2.2360679774997894+0j)
ovlp = (0.9999999999999996+0j)


## Project to $s = 1, s_z = 1$

In [88]:
ngrid_z = 6
ngrid_y = 6
s = 1
sz = 1
kets, coeffs = apply_S2_projector(psi[:, :nocc], s, sz, ngrid_z, ngrid_y)
kets = jnp.array(kets).reshape(-1, *kets.shape[4:])
coeffs = jnp.array(coeffs).reshape(-1)

In [89]:
s2_tot, ovlp_tot, s_tot = get_projected_spin_square(kets, coeffs)
savg_tot, _ = get_projected_spin_average(kets, coeffs)

In [90]:
print(f'<S^2> * ovlp = {s2_tot * ovlp_tot}')
print(f'ovlp = {ovlp_tot}')
print(f'<S^2> = {s2_tot}, S = {s_tot}, 2S+1 = {calc_multiplicity(s_tot)}')
print(f'\n<S> = {savg_tot}')

<S^2> * ovlp = (1.1188966420050403e-16-2.4858299419545583e-18j)
ovlp = (-5.594483210025203e-17-2.8214667473040744e-18j)
<S^2> = (-1.9926906922027958+0.14493100011699844j), S = (-0.4451537881454159+1.321248954269253j), 2S+1 = (0.10969242370916821+2.642497908538506j)

<S> = [ 0.00851+0.00843j -0.01451-0.00222j  0.01891-0.08622j]


## Project to $s = 1, s_z = -1$

In [91]:
ngrid_z = 5
ngrid_y = 5
s = 1
sz = -1
kets, coeffs = apply_S2_projector(psi[:, :nocc], s, sz, ngrid_z, ngrid_y)
kets = jnp.array(kets).reshape(-1, *kets.shape[4:])
coeffs = jnp.array(coeffs).reshape(-1)

In [92]:
s2_tot, ovlp_tot, s_tot = get_projected_spin_square(kets, coeffs)
savg_tot, _ = get_projected_spin_average(kets, coeffs)

In [93]:
print(f'<S^2> * ovlp = {s2_tot * ovlp_tot}')
print(f'ovlp = {ovlp_tot}')
print(f'<S^2> = {s2_tot}, S = {s_tot}, 2S+1 = {calc_multiplicity(s_tot)}')
print(f'\n<S> = {savg_tot}')

<S^2> * ovlp = (2.1163626406917044e-16-2.775557561562891e-17j)
ovlp = (4.666406150377611e-16+1.6566609195578508e-16j)
<S^2> = (0.3840147273145662-0.19581191992022703j), S = (0.30547457342498696-0.1215506524852847j), 2S+1 = (1.610949146849974-0.2431013049705694j)

<S> = [0.00252-0.00007j 0.00706-0.00531j 0.0086 +0.00066j]


## Project to $s = 1, s_z = 0$

In [94]:
ngrid_z = 5
ngrid_y = 5
s = 1
sz = 0
kets, coeffs = apply_S2_projector(psi[:, :nocc], s, sz, ngrid_z, ngrid_y)
kets = jnp.array(kets).reshape(-1, *kets.shape[4:])
coeffs = jnp.array(coeffs).reshape(-1)

In [95]:
s2_tot, ovlp_tot, s_tot = get_projected_spin_square(kets, coeffs)
savg_tot, _ = get_projected_spin_average(kets, coeffs)

In [96]:
print(f'<S^2> * ovlp = {s2_tot * ovlp_tot}')
print(f'ovlp = {ovlp_tot}')
print(f'<S^2> = {s2_tot}, S = {s_tot}, 2S+1 = {calc_multiplicity(s_tot)}')
print(f'\n<S> = {savg_tot}')

<S^2> * ovlp = (3.5388358909926865e-16-1.4745149545802857e-17j)
ovlp = (2.8275992658421956e-16+4.9873299934333204e-18j)
<S^2> = (1.2502250219730797-0.07419875422191781j), S = (0.725210961361064-0.030279991185963535j), 2S+1 = (2.450421922722128-0.06055998237192707j)

<S> = [ 0.00729-0.0003j   0.00297-0.00022j -0.03158+0.01407j]


## Project to $s = 0, s_z = 0$

In [25]:
ngrid_z = 5
ngrid_y = 5
s = 0
sz = 0
kets, coeffs = apply_S2_projector(psi[:, :nocc], s, sz, ngrid_z, ngrid_y)
kets = jnp.array(kets).reshape(-1, *kets.shape[4:])
coeffs = jnp.array(coeffs).reshape(-1)

In [26]:
s2_tot, ovlp_tot, s_tot = get_projected_spin_square(kets, coeffs)
savg_tot, _ = get_projected_spin_average(kets, coeffs)

In [27]:
print(f'<S^2> * ovlp = {s2_tot * ovlp_tot}')
print(f'ovlp = {ovlp_tot}')
print(f'<S^2> = {s2_tot}, S = {s_tot}, 2S+1 = {calc_multiplicity(s_tot)}')
print(f'\n<S> = {savg_tot}')

<S^2> * ovlp = (-1.0031411194332679e-17-4.570138218538032e-19j)
ovlp = (0.5000000000000001-4.333342374871281e-34j)
<S^2> = (-2.0062822388665355e-17-9.140276437076061e-19j), S = (-5.551115123125783e-17-9.140276437076061e-19j), 2S+1 = (0.9999999999999999-1.8280552874152123e-18j)

<S> = [ 0.-0.j -0.-0.j -0.+0.j]


In [63]:
ngrid_z = 15
ngrid_y = 15
s = 0
sz = 0
kets, coeffs = apply_S2_projector_jax(psi[:, :nocc], s, sz, ngrid_z, ngrid_y)

# Test `apply_S2_projector_jax`

# Test with $|\uparrow \downarrow \rangle = \frac{1}{\sqrt{2}}(|s=0, s_z=0 \rangle + |s=1, s_z=0\rangle)$

**Note that $s_z$ is a good quantum number for this state, but not $s$.**

In [6]:
norb = 3
nocc = 2
psi = np.array([[1., 0.], 
                [0., 0.], 
                [0., 0.], 
                [0., 0.], 
                [0., 1.], 
                [0., 0.]])

In [7]:
# Initial spin average and overlap.
G, ovlp = get_greens(psi[:, :nocc], psi[:, :nocc])
savg = spin_utils.get_spin_average_from_greens_ghf(G) * ovlp
s2 = spin_utils.get_spin_square_from_greens_ghf(G) * ovlp
s = calc_s(s2)

print(f'<S> = {savg}')
print(f'<S^2> = {s2}, S = {s}, 2S+1 = {calc_multiplicity(s)}')
print(f'ovlp = {ovlp}')

<S> = [0.+0.j 0.+0.j 0.+0.j]
<S^2> = (1+0j), S = (0.6180339887498949+0j), 2S+1 = (2.23606797749979+0j)
ovlp = (1+0j)


## Project to $s = 0$

In [8]:
%%time

ngrid_z = 6
ngrid_y = 6
s = sz = 0
kets, coeffs = apply_S2_projector_jax(psi[:, :nocc], s, sz, ngrid_z, ngrid_y)
kets = jnp.array(kets).reshape(-1, *kets.shape[4:])
coeffs = jnp.array(coeffs).reshape(-1)

CPU times: user 461 ms, sys: 12.5 ms, total: 474 ms
Wall time: 466 ms


In [9]:
s2_tot, ovlp_tot, s_tot = get_projected_spin_square(kets, coeffs)
savg_tot, _ = get_projected_spin_average(kets, coeffs)

In [10]:
print(f'<S^2> * ovlp = {s2_tot * ovlp_tot}')
print(f'ovlp = {ovlp_tot}')
print(f'<S^2> = {s2_tot}, S = {s_tot}, 2S+1 = {calc_multiplicity(s_tot)}')
print(f'\n<S> = {savg_tot}')

<S^2> * ovlp = (0.03153697688498419-8.191988107231445e-20j)
ovlp = (0.4999990463268197-1.8055593228630336e-35j)
<S^2> = (0.06307407407407402-1.63840074644399e-19j), S = (0.059530226238113104-1.4640860043070794e-19j), 2S+1 = (1.1190604524762262-2.9281720086141587e-19j)

<S> = [0.+0.j 0.+0.j 0.-0.j]


In [12]:
%%time

ngrid_z = 10
ngrid_y = 10
s = sz = 0
kets, coeffs = apply_S2_projector_jax(psi[:, :nocc], s, sz, ngrid_z, ngrid_y)
kets = jnp.array(kets).reshape(-1, *kets.shape[4:])
coeffs = jnp.array(coeffs).reshape(-1)

CPU times: user 233 ms, sys: 5 ms, total: 238 ms
Wall time: 229 ms


In [13]:
s2_tot, ovlp_tot, s_tot = get_projected_spin_square(kets, coeffs)
savg_tot, _ = get_projected_spin_average(kets, coeffs)

In [14]:
print(f'<S^2> * ovlp = {s2_tot * ovlp_tot}')
print(f'ovlp = {ovlp_tot}')
print(f'<S^2> = {s2_tot}, S = {s_tot}, 2S+1 = {calc_multiplicity(s_tot)}')
print(f'\n<S> = {savg_tot}')

<S^2> * ovlp = (0.01173091470886807-2.062750040755832e-20j)
ovlp = (0.4999990463268196+0j)
<S^2> = (0.023461874167656452-4.1255079502842395e-20j), S = (0.022935822226452962-3.944564299228408e-20j), 2S+1 = (1.045871644452906-7.889128598456816e-20j)

<S> = [0.+0.j 0.+0.j 0.-0.j]


In [15]:
%%time

ngrid_z = 15
ngrid_y = 15
s = sz = 0
kets, coeffs = apply_S2_projector_jax(psi[:, :nocc], s, sz, ngrid_z, ngrid_y)
kets = jnp.array(kets).reshape(-1, *kets.shape[4:])
coeffs = jnp.array(coeffs).reshape(-1)

CPU times: user 455 ms, sys: 17.2 ms, total: 472 ms
Wall time: 472 ms


In [16]:
s2_tot, ovlp_tot, s_tot = get_projected_spin_square(kets, coeffs)
savg_tot, _ = get_projected_spin_average(kets, coeffs)

In [17]:
print(f'<S^2> * ovlp = {s2_tot * ovlp_tot}')
print(f'ovlp = {ovlp_tot}')
print(f'<S^2> = {s2_tot}, S = {s_tot}, 2S+1 = {calc_multiplicity(s_tot)}')
print(f'\n<S> = {savg_tot}')

<S^2> * ovlp = (3.469446951953614e-18+1.3515619647481362e-20j)
ovlp = (0.49999904632681924-4.043700566828669e-36j)
<S^2> = (6.938907138806513e-18+2.703129085299698e-20j), S = 2.703129085299698e-20j, 2S+1 = (1+5.406258170599396e-20j)

<S> = [0.+0.j 0.+0.j 0.-0.j]


# Test with $|+ - \rangle = \frac{1}{2}|s=1, s_z=1 \rangle - \frac{1}{\sqrt{2}}|s=0, s_z=0\rangle + \frac{1}{2}|s=1, s_z=-1 \rangle$

**Note that neither $s_z$ or $s$ are a good quantum numbers for this state.**

In [18]:
norb = 3
nocc = 2
psi = np.array([[1., 0.], 
                [0., 1.], 
                [0., 0.], 
                [1., 0.], 
                [0., -1.], 
                [0., 0.]]) / np.sqrt(2.)

In [19]:
# Initial spin average and overlap.
G, ovlp = get_greens(psi[:, :nocc], psi[:, :nocc])
savg = spin_utils.get_spin_average_from_greens_ghf(G) * ovlp
s2 = spin_utils.get_spin_square_from_greens_ghf(G) * ovlp
s = calc_s(s2)

print(f'<S> = {savg}')
print(f'<S^2> = {s2}, S = {s}, 2S+1 = {calc_multiplicity(s)}')
print(f'ovlp = {ovlp}')

<S> = [0.+0.j 0.+0.j 0.+0.j]
<S^2> = (0.9999999999999996+0j), S = (0.6180339887498947+0j), 2S+1 = (2.2360679774997894+0j)
ovlp = (0.9999999999999996+0j)


## Project to $s = 0, s_z = 0$

### `numpy`

In [20]:
%%time

ngrid_z = 15
ngrid_y = 15
s = 0
sz = 0
kets, coeffs = apply_S2_projector(psi[:, :nocc], s, sz, ngrid_z, ngrid_y)
kets = jnp.array(kets).reshape(-1, *kets.shape[4:])
coeffs = jnp.array(coeffs).reshape(-1)

CPU times: user 239 ms, sys: 22 ms, total: 261 ms
Wall time: 246 ms


### `jax`

In [21]:
%%time

ngrid_z = 15
ngrid_y = 15
s = 0
sz = 0
kets_jax, coeffs_jax = apply_S2_projector_jax(psi[:, :nocc], s, sz, ngrid_z, ngrid_y)
kets_jax = jnp.array(kets_jax).reshape(-1, *kets_jax.shape[4:])
coeffs_jax = jnp.array(coeffs_jax).reshape(-1)

CPU times: user 251 ms, sys: 16.7 ms, total: 267 ms
Wall time: 254 ms


In [22]:
np.testing.assert_allclose(kets_jax, kets, atol=1e-15) 
np.testing.assert_allclose(coeffs_jax, coeffs, atol=1e-15)

AssertionError: 
Not equal to tolerance rtol=1e-07, atol=1e-15

Mismatched elements: 3375 / 3375 (100%)
Max absolute difference: 0.
Max relative difference: 0.
 x: array([0.000068+0.j, 0.000068+0.j, 0.000068+0.j, ..., 0.000068+0.j,
       0.000068+0.j, 0.000068+0.j])
 y: array([0.000068+0.j, 0.000068+0.j, 0.000068+0.j, ..., 0.000068+0.j,
       0.000068+0.j, 0.000068+0.j])

In [23]:
s2_tot, ovlp_tot, s_tot = get_projected_spin_square(kets_jax, coeffs_jax)
savg_tot, _ = get_projected_spin_average(kets_jax, coeffs_jax)

In [24]:
print(f'<S^2> * ovlp = {s2_tot * ovlp_tot}')
print(f'ovlp = {ovlp_tot}')
print(f'<S^2> = {s2_tot}, S = {s_tot}, 2S+1 = {calc_multiplicity(s_tot)}')
print(f'\n<S> = {savg_tot}')

<S^2> * ovlp = (-8.606119441899708e-18+1.3620304163690817e-21j)
ovlp = (0.4999990463268191+1.128474576789396e-36j)
<S^2> = (-1.721227171356325e-17+2.7240660284755922e-21j), S = (-5.551115123125783e-17+2.7240660284755922e-21j), 2S+1 = (0.9999999999999999+5.4481320569511844e-21j)

<S> = [0.+0.j 0.+0.j 0.-0.j]


# Test `get_real_wavefunction`

In [32]:
ngrid = 5
norb = 3
nocc = 3
kets = np.random.random((ngrid, 2*norb, nocc)) + 1.j * np.random.random((ngrid, 2*norb, nocc))
coeffs = np.random.random(ngrid) + 1.j * np.random.random(ngrid)

for ig in range(ngrid):
    for icol in range(nocc):
        kets[ig, :, icol] /= sp.linalg.norm(kets[ig, :, icol])

In [33]:
real_kets, real_coeffs = get_real_wavefunction(kets, coeffs)
sum_kets = np.sum(kets * coeffs[:, None, None], axis=0)
sum_real_kets = np.sum(real_kets * real_coeffs[:, None, None], axis=0)

print(f'\nsum_kets = \n{sum_kets}')
print(f'\nsum_real_kets = \n{sum_real_kets}')

np.testing.assert_allclose(np.amax(np.absolute(sum_real_kets.imag)), 0., atol=1e-15)


sum_kets = 
[[-0.11522+1.1677j   0.47885+1.10725j -0.05314+1.36802j]
 [ 0.05264+1.26083j  0.22826+0.98494j  0.30906+1.4398j ]
 [ 0.04951+1.2718j  -0.15291+1.55389j -0.14791+1.10968j]
 [-0.1311 +1.26815j -0.04548+1.08604j  0.41152+1.2339j ]
 [ 0.16708+1.28471j -0.26885+1.432j   -0.3988 +1.29482j]
 [ 0.07421+1.12675j  0.54184+1.09631j  0.1007 +0.87478j]]

sum_real_kets = 
[[-0.16294-0.j  0.6772 +0.j -0.07515+0.j]
 [ 0.07445+0.j  0.32281+0.j  0.43707-0.j]
 [ 0.07002-0.j -0.21625-0.j -0.20918+0.j]
 [-0.18541+0.j -0.06432-0.j  0.58198-0.j]
 [ 0.23628+0.j -0.38021+0.j -0.56399-0.j]
 [ 0.10495+0.j  0.76627-0.j  0.14241-0.j]]


# Test `get_energy`

In [35]:
U = 4
nx = 2
ny = 2
nup = 2
ndown = 2
open_x = False
verbose = 3

# -----------------------------------------------------------------------------
# Settings.
lattice = lattices.triangular_grid(nx, ny, open_x=open_x)
n_sites = lattice.n_sites
n_elec = (nup, ndown)
nocc = sum(n_elec)
filling = nocc / (2*n_sites)
if verbose: print(f'\n# Filling factor = {filling}')


# Filling factor = 0.5


In [36]:
integrals = {}
integrals["h0"] = 0.0

h1 = -1.0 * lattice.create_adjacency_matrix()
integrals["h1"] = h1

h2 = np.zeros((n_sites, n_sites, n_sites, n_sites))
for i in range(n_sites):
    h2[i, i, i, i] = U
integrals["h2"] = ao2mo.restore(8, h2, n_sites)

In [37]:
# Get Choleskies.
chol_cut = 1e-10
eri = ao2mo.restore(4, integrals["h2"], n_sites)
chol0 = pyscf_interface.modified_cholesky(eri, max_error=chol_cut)
nchol = chol0.shape[0]
chol = np.zeros((nchol, n_sites, n_sites))

for i in range(nchol):
    for m in range(n_sites):
        for n in range(m + 1):
            triind = m * (m + 1) // 2 + n
            chol[i, m, n] = chol0[i, triind]
            chol[i, n, m] = chol0[i, triind]

chol = chol.reshape(nchol, n_sites**2)
_eri = (chol.T @ chol).reshape((n_sites,) * 4)
max_absdiff = np.amax(np.absolute(h2 - _eri))
print(f'max_absdiff = {max_absdiff}')

max_absdiff = 5.000000413701855e-11


In [38]:
# make dummy molecule
mol = gto.Mole()
mol.nelectron = nocc
mol.incore_anyway = True
mol.spin = abs(n_elec[0] - n_elec[1])
mol.verbose = 3
mol.build()

# GHF.
gmf = scf.GHF(mol)
gmf.get_hcore = lambda *args: sp.linalg.block_diag(integrals["h1"], integrals["h1"])
gmf.get_ovlp = lambda *args: np.eye(2 * n_sites)
gmf._eri = ao2mo.restore(8, integrals["h2"], n_sites)

seed = 262
np.random.seed(seed)
dm_init = np.random.random((2*n_sites, 2*n_sites))
dm_init += dm_init.T.conj()

gmf.kernel(dm_init)
mo1 = gmf.stability(external=True)
gmf = gmf.newton().run(mo1, gmf.mo_occ)
mo1 = gmf.stability(external=True)
gmf = gmf.newton().run(mo1, gmf.mo_occ)
mo1 = gmf.stability(external=True)
gmf = gmf.newton().run(mo1, gmf.mo_occ)
mo1 = gmf.stability(external=True)

ghf_coeff = gmf.mo_coeff
ghf_rdm1 = gmf.make_rdm1()
ao_ovlp = np.eye(n_sites)
epsilon0, mu, spin_axis = spin_utils.spin_collinearity_test(ghf_coeff[:, :nocc], ao_ovlp, verbose=verbose)

converged SCF energy = -1.58788806326219  <S^2> = 2.3595938  2S+1 = 3.2308474
<class 'pyscf.scf.ghf.GHF'> wavefunction has an internal instability
converged SCF energy = -1.76329782852721  <S^2> = 1.5301269  2S+1 = 2.6684279
<class 'pyscf.soscf.newton_ah.SecondOrderGHF'> wavefunction is stable in the internal stability analysis
converged SCF energy = -1.76329782855061  <S^2> = 1.530127  2S+1 = 2.668428
<class 'pyscf.soscf.newton_ah.SecondOrderGHF'> wavefunction is stable in the internal stability analysis
converged SCF energy = -1.76329782855372  <S^2> = 1.530127  2S+1 = 2.668428
<class 'pyscf.soscf.newton_ah.SecondOrderGHF'> wavefunction is stable in the internal stability analysis

# ----------------------
# Spin collinearity test
# ----------------------
# epsilon0 = 5.600759033189529e-07
# non integer value indicates non-collinearity

# minimum mu = 0.2882556229567591
# Value is 0 iff wavefunction is collinear

# If mu = 0, the collinear spin axis is: 
[-0.31794-0.j -0.     +0.j -0

In [39]:
psi = ghf_coeff[:, :nocc]
psiTconj = psi.T.conj()
rotchol = np.zeros((2, nchol, nocc, n_sites), dtype=np.complex128)

for i in range(nchol):
    rotchol[0, i] = psiTconj[:, :n_sites] @ chol[i].reshape((n_sites, n_sites))
    rotchol[1, i] = psiTconj[:, n_sites:] @ chol[i].reshape((n_sites, n_sites))

energy = get_energy(psi, psi, [h1, h1], rotchol, 0.)
energy = get_energy_jax(psi, psi, [h1, h1], rotchol, 0.)
np.testing.assert_allclose(energy, gmf.e_tot)

In [40]:
def get_energy_with_eri(bra, ket, h1, eri, enuc):
    norb = h1.shape[0]
    nocc = bra.shape[1]
    G, ovlp = get_greens(bra, ket)
    Gaa = G[:norb, :norb]
    Gab = G[:norb, norb:]
    Gba = G[norb:, :norb]
    Gbb = G[norb:, norb:]
    Gcharge = Gaa + Gbb
    energy = enuc
    energy += np.trace(h1 @ Gcharge)
    energy += 0.5 * np.einsum('ijkl,lk,ji->', eri, Gcharge, Gcharge) # Ej
    energy -= 0.5 * np.einsum('ijkl,jk,li->', eri, Gaa, Gaa) # Ek
    energy -= 0.5 * np.einsum('ijkl,jk,li->', eri, Gbb, Gbb)
    energy -= 0.5 * np.einsum('ijkl,jk,li->', eri, Gab, Gba)
    energy -= 0.5 * np.einsum('ijkl,jk,li->', eri, Gba, Gab)
    return energy

In [41]:
m = 0.5 * (nup - ndown)
kets, coeffs = apply_Sz_projector_jax(psi, m, 10)
kets = np.array(kets)
coeffs = np.array(coeffs)
kets.shape

(10, 8, 4)

In [47]:
for bra in kets:
    for ket in kets:
        energy_test = get_energy(bra, ket, [h1, h1], rotchol, 0.)
        energy_ref = get_energy_with_eri(bra, ket, h1, h2, 0.)
        np.testing.assert_allclose(energy_test, energy_ref, atol=1e-15)

# Test `get_projected_energy`

## Test 1: UHF state with $s_z = 0$

In [48]:
U = 4
nx = 2
ny = 2
nup = 2
ndown = 2
open_x = False
verbose = 3

# -----------------------------------------------------------------------------
# Settings.
lattice = lattices.triangular_grid(nx, ny, open_x=open_x)
n_sites = lattice.n_sites
n_elec = (nup, ndown)
nocc = sum(n_elec)
filling = nocc / (2*n_sites)
if verbose: print(f'\n# Filling factor = {filling}')


# Filling factor = 0.5


In [49]:
integrals = {}
integrals["h0"] = 0.0

h1 = -1.0 * lattice.create_adjacency_matrix()
integrals["h1"] = h1

h2 = np.zeros((n_sites, n_sites, n_sites, n_sites))
for i in range(n_sites):
    h2[i, i, i, i] = U
integrals["h2"] = ao2mo.restore(8, h2, n_sites)

In [50]:
# Get Choleskies.
chol_cut = 1e-10
eri = ao2mo.restore(4, integrals["h2"], n_sites)
chol0 = pyscf_interface.modified_cholesky(eri, max_error=chol_cut)
nchol = chol0.shape[0]
chol = np.zeros((nchol, n_sites, n_sites))

for i in range(nchol):
    for m in range(n_sites):
        for n in range(m + 1):
            triind = m * (m + 1) // 2 + n
            chol[i, m, n] = chol0[i, triind]
            chol[i, n, m] = chol0[i, triind]

chol = chol.reshape(nchol, n_sites**2)
_eri = (chol.T @ chol).reshape((n_sites,) * 4)
max_absdiff = np.amax(np.absolute(h2 - _eri))
print(f'max_absdiff = {max_absdiff}')

max_absdiff = 5.000000413701855e-11


In [51]:
# make dummy molecule
mol = gto.Mole()
mol.nelectron = nocc
mol.incore_anyway = True
mol.spin = abs(n_elec[0] - n_elec[1])
mol.verbose = 3
mol.build()

# UHF.
umf = scf.UHF(mol)
umf.get_hcore = lambda *args: integrals["h1"]
umf.get_ovlp = lambda *args: np.eye(n_sites)
umf._eri = ao2mo.restore(8, integrals["h2"], n_sites)

seed = 262
np.random.seed(seed)
dm_init = np.random.random((2, n_sites, n_sites))
dm_init += dm_init.transpose(0, 2, 1).conj()

umf.kernel(dm_init)
mo1 = umf.stability(external=True)
umf = umf.newton().run(mo1[0], umf.mo_occ)
mo1 = umf.stability(external=True)
umf = umf.newton().run(mo1[0], umf.mo_occ)
mo1 = umf.stability(external=True)
umf = umf.newton().run(mo1[0], umf.mo_occ)
mo1 = umf.stability(external=True)

uhf_coeff = umf.mo_coeff
uhf_rdm1 = umf.make_rdm1()

psi = np.zeros((2*n_sites, nocc))
psi[:n_sites, :nup] = uhf_coeff[0, :, umf.mo_occ[0] > 0].T
psi[n_sites:, nup:] = uhf_coeff[1, :, umf.mo_occ[1] > 0].T

ao_ovlp = np.eye(n_sites)
epsilon0, mu, spin_axis = spin_utils.spin_collinearity_test(psi, ao_ovlp, verbose=verbose)

converged SCF energy = -1.76329782855449  <S^2> = 1.399794  2S+1 = 2.5688861
<class 'pyscf.scf.uhf.UHF'> wavefunction is stable in the internal stability analysis
<class 'pyscf.scf.uhf.UHF'> wavefunction is stable in the real -> complex stability analysis
<class 'pyscf.scf.uhf.UHF'> wavefunction is stable in the UHF/UKS -> GHF/GKS stability analysis
converged SCF energy = -1.76329782855458  <S^2> = 1.3997941  2S+1 = 2.5688862
<class 'pyscf.soscf.newton_ah.SecondOrderUHF'> wavefunction is stable in the internal stability analysis

WARN: Not enough eigenvectors (len(x0)=1, nroots=3)

<class 'pyscf.soscf.newton_ah.SecondOrderUHF'> wavefunction is stable in the real -> complex stability analysis
<class 'pyscf.soscf.newton_ah.SecondOrderUHF'> wavefunction is stable in the UHF/UKS -> GHF/GKS stability analysis
converged SCF energy = -1.76329782855459  <S^2> = 1.3997941  2S+1 = 2.5688862
<class 'pyscf.soscf.newton_ah.SecondOrderUHF'> wavefunction is stable in the internal stability analysis



In [52]:
m = 0.5 * (nup - ndown)
kets, coeffs = apply_Sz_projector_jax(psi, m, 10)
kets = np.array(kets)
coeffs = np.array(coeffs)

In [53]:
# Projected spin average and overlap.
savg_tot = np.zeros(3, dtype=np.complex128)
ovlp_tot = 0.

for ibra, bra in enumerate(kets):
    for iket, ket in enumerate(kets):
        G, ovlp = get_greens(bra, ket)
        savg = spin_utils.get_spin_average_from_greens_ghf(G) * ovlp
        savg_tot += savg * coeffs[ibra].conj() * coeffs[iket]
        ovlp_tot += ovlp * coeffs[ibra].conj() * coeffs[iket]

In [54]:
print(f'<S> * ovlp = {savg_tot}')
print(f'ovlp = {ovlp_tot}')
print(f'\n<S> = {savg_tot / ovlp_tot}')

<S> * ovlp = [0.+0.j 0.+0.j 0.-0.j]
ovlp = (1.0000000000000007+3.851859888774472e-34j)

<S> = [0.+0.j 0.+0.j 0.-0.j]


In [55]:
e_proj = get_projected_energy(psi, m, [h1, h1], chol, 0., 10)
e_proj = get_projected_energy_jax(psi, m, [h1, h1], chol, 0., 10)
np.testing.assert_allclose(e_proj, umf.e_tot)

## Test 1: UHF state with $s_z = 1.5$

In [56]:
U = 4
nx = 2
ny = 2
nup = 4
ndown = 1
open_x = False
verbose = 3

# -----------------------------------------------------------------------------
# Settings.
lattice = lattices.triangular_grid(nx, ny, open_x=open_x)
n_sites = lattice.n_sites
n_elec = (nup, ndown)
nocc = sum(n_elec)
filling = nocc / (2*n_sites)
if verbose: print(f'\n# Filling factor = {filling}')


# Filling factor = 0.625


In [57]:
integrals = {}
integrals["h0"] = 0.0

h1 = -1.0 * lattice.create_adjacency_matrix()
integrals["h1"] = h1

h2 = np.zeros((n_sites, n_sites, n_sites, n_sites))
for i in range(n_sites):
    h2[i, i, i, i] = U
integrals["h2"] = ao2mo.restore(8, h2, n_sites)

In [58]:
# Get Choleskies.
chol_cut = 1e-10
eri = ao2mo.restore(4, integrals["h2"], n_sites)
chol0 = pyscf_interface.modified_cholesky(eri, max_error=chol_cut)
nchol = chol0.shape[0]
chol = np.zeros((nchol, n_sites, n_sites))

for i in range(nchol):
    for m in range(n_sites):
        for n in range(m + 1):
            triind = m * (m + 1) // 2 + n
            chol[i, m, n] = chol0[i, triind]
            chol[i, n, m] = chol0[i, triind]

chol = chol.reshape(nchol, n_sites**2)
_eri = (chol.T @ chol).reshape((n_sites,) * 4)
max_absdiff = np.amax(np.absolute(h2 - _eri))
print(f'max_absdiff = {max_absdiff}')

max_absdiff = 5.000000413701855e-11


In [59]:
# make dummy molecule
mol = gto.Mole()
mol.nelectron = nocc
mol.incore_anyway = True
mol.spin = abs(n_elec[0] - n_elec[1])
mol.verbose = 3
mol.build()

# UHF.
umf = scf.UHF(mol)
umf.get_hcore = lambda *args: integrals["h1"]
umf.get_ovlp = lambda *args: np.eye(n_sites)
umf._eri = ao2mo.restore(8, integrals["h2"], n_sites)

seed = 262
np.random.seed(seed)
dm_init = np.random.random((2, n_sites, n_sites))
dm_init += dm_init.transpose(0, 2, 1).conj()

umf.kernel(dm_init)
mo1 = umf.stability(external=True)
umf = umf.newton().run(mo1[0], umf.mo_occ)
mo1 = umf.stability(external=True)
umf = umf.newton().run(mo1[0], umf.mo_occ)
mo1 = umf.stability(external=True)
umf = umf.newton().run(mo1[0], umf.mo_occ)
mo1 = umf.stability(external=True)

uhf_coeff = umf.mo_coeff
uhf_rdm1 = umf.make_rdm1()

psi = np.zeros((2*n_sites, nocc))
psi[:n_sites, :nup] = uhf_coeff[0, :, umf.mo_occ[0] > 0].T
psi[n_sites:, nup:] = uhf_coeff[1, :, umf.mo_occ[1] > 0].T

ao_ovlp = np.eye(n_sites)
epsilon0, mu, spin_axis = spin_utils.spin_collinearity_test(psi, ao_ovlp, verbose=verbose)

converged SCF energy = 0.999999999999997  <S^2> = 3.75  2S+1 = 4

WARN: Not enough eigenvectors (len(x0)=1, nroots=3)

<class 'pyscf.scf.uhf.UHF'> wavefunction is stable in the internal stability analysis

WARN: Not enough eigenvectors (len(x0)=1, nroots=3)

<class 'pyscf.scf.uhf.UHF'> wavefunction is stable in the real -> complex stability analysis
<class 'pyscf.scf.uhf.UHF'> wavefunction is stable in the UHF/UKS -> GHF/GKS stability analysis
converged SCF energy = 0.999999999999997  <S^2> = 3.75  2S+1 = 4

WARN: Not enough eigenvectors (len(x0)=1, nroots=3)

<class 'pyscf.soscf.newton_ah.SecondOrderUHF'> wavefunction is stable in the internal stability analysis

WARN: Not enough eigenvectors (len(x0)=1, nroots=3)

<class 'pyscf.soscf.newton_ah.SecondOrderUHF'> wavefunction is stable in the real -> complex stability analysis
<class 'pyscf.soscf.newton_ah.SecondOrderUHF'> wavefunction is stable in the UHF/UKS -> GHF/GKS stability analysis
converged SCF energy = 0.999999999999996  <S^2>

In [60]:
m = 0.5 * (nup - ndown)
kets, coeffs = apply_Sz_projector_jax(psi, m, 10)
kets = np.array(kets)
coeffs = np.array(coeffs)

In [61]:
# Projected spin average and overlap.
savg_tot = np.zeros(3, dtype=np.complex128)
ovlp_tot = 0.

for ibra, bra in enumerate(kets):
    for iket, ket in enumerate(kets):
        G, ovlp = get_greens(bra, ket)
        savg = spin_utils.get_spin_average_from_greens_ghf(G) * ovlp
        savg_tot += savg * coeffs[ibra].conj() * coeffs[iket]
        ovlp_tot += ovlp * coeffs[ibra].conj() * coeffs[iket]

In [62]:
print(f'<S> * ovlp = {savg_tot}')
print(f'ovlp = {ovlp_tot}')
print(f'\n<S> = {savg_tot / ovlp_tot}')

<S> * ovlp = [0. +0.j 0. +0.j 1.5+0.j]
ovlp = (0.9999999999999982-3.25862638659177e-18j)

<S> = [0. +0.j 0. +0.j 1.5+0.j]


In [63]:
e_proj = get_projected_energy(psi, m, [h1, h1], chol, 0., 10)
e_proj = get_projected_energy_jax(psi, m, [h1, h1], chol, 0., 10)
np.testing.assert_allclose(e_proj, umf.e_tot)

## Using a 2-site Ising model as example

In [267]:
norb = 2
nocc = 2
psi = 1./np.sqrt(2.) * np.array([[1., 0.],  # (site 1, up)
                                 [0., 1.],  # (site 2, up) 
                                 [1., 0.],  # (site 1, down)
                                 [0., 1.]]) # (site 2, down)

In [269]:
h1 = np.zeros((2*norb, 2*norb))
eri = np.zeros((2*norb, 2*norb, 2*norb, 2*norb))

for i in range(2*norb):
    spin_i = i // 2
    site_i = i % 2
    
    for j in range(2*norb):
        spin_j = j // 2
        site_j = j % 2
        
        if spin_i == spin_j: eri[i, i, j, j] = 1.
        if spin_i != spin_j: eri[i, i, j, j] = -1.

H = np.array([[0., 1., 0., -1.],
              [1., 0., -1., 0.],
              [0., -1., 0., 1.],
              [-1., 0., 1., 0.]])

In [264]:
get_projected_energy(psi, 0., [h1, h1], chol, 0., 10)

(-22.25000000001807-1.848290669299883e-15j)

# Test `optimize`

In [11]:
U = 4
nx = 4
ny = 4
nup = 8
ndown = 8
bc = 'open_x'
verbose = 3

# -----------------------------------------------------------------------------
# Settings.
# lattice = lattices.triangular_grid(nx, ny, boundary_condition=bc)
lattice = lattices.two_dimensional_grid(nx, ny)
n_sites = lattice.n_sites
n_elec = (nup, ndown)
nocc = sum(n_elec)
filling = nocc / (2*n_sites)
if verbose: print(f'\n# Filling factor = {filling}')


# Filling factor = 0.5


In [12]:
integrals = {}
integrals["h0"] = 0.0

h1 = -1.0 * lattice.create_adjacency_matrix()
integrals["h1"] = h1

h2 = np.zeros((n_sites, n_sites, n_sites, n_sites))
for i in range(n_sites):
    h2[i, i, i, i] = U
integrals["h2"] = ao2mo.restore(8, h2, n_sites)

In [13]:
# Get Choleskies.
chol_cut = 1e-10
eri = ao2mo.restore(4, integrals["h2"], n_sites)
chol0 = pyscf_interface.modified_cholesky(eri, max_error=chol_cut)
nchol = chol0.shape[0]
chol = np.zeros((nchol, n_sites, n_sites))

for i in range(nchol):
    for m in range(n_sites):
        for n in range(m + 1):
            triind = m * (m + 1) // 2 + n
            chol[i, m, n] = chol0[i, triind]
            chol[i, n, m] = chol0[i, triind]

chol = chol.reshape(nchol, n_sites**2)
_eri = (chol.T @ chol).reshape((n_sites,) * 4)
max_absdiff = np.amax(np.absolute(h2 - _eri))
print(f'max_absdiff = {max_absdiff}')

max_absdiff = 5.000000413701855e-11


In [14]:
# make dummy molecule
mol = gto.Mole()
mol.nelectron = nocc
mol.incore_anyway = True
mol.spin = abs(n_elec[0] - n_elec[1])
mol.verbose = 4
mol.build()

# GHF.
gmf = scf.GHF(mol)
gmf.max_cycle = 1000
gmf.get_hcore = lambda *args: sp.linalg.block_diag(integrals["h1"], integrals["h1"])
gmf.get_ovlp = lambda *args: np.eye(2 * n_sites)
gmf._eri = ao2mo.restore(8, integrals["h2"], n_sites)

seed = 262
np.random.seed(seed)
dm_init = np.random.random((2*n_sites, 2*n_sites))
dm_init += dm_init.T.conj()

gmf.kernel(dm_init)
mo1 = gmf.stability(external=True)
gmf = gmf.newton().run(mo1, gmf.mo_occ)
mo1 = gmf.stability(external=True)
gmf = gmf.newton().run(mo1, gmf.mo_occ)
mo1 = gmf.stability(external=True)
gmf = gmf.newton().run(mo1, gmf.mo_occ)
mo1 = gmf.stability(external=True)

ghf_coeff = gmf.mo_coeff
ghf_rdm1 = gmf.make_rdm1()
ao_ovlp = np.eye(n_sites)
epsilon0, mu, spin_axis = spin_utils.spin_collinearity_test(ghf_coeff[:, :nocc], ao_ovlp, verbose=verbose)

System: uname_result(system='Linux', node='g158', release='4.18.0-193.el8.x86_64', version='#1 SMP Fri Mar 27 14:35:58 UTC 2020', machine='x86_64')  Threads 32
Python 3.9.20 (main, Oct  3 2024, 07:27:41) 
[GCC 11.2.0]
numpy 1.26.4  scipy 1.10.1  h5py 3.13.0
Date: Thu Mar 13 23:57:37 2025
PySCF version 2.7.0
PySCF path  /burg/home/su2254/libs/shufay_pyscf
GIT ORIG_HEAD 135db139c2689071600947863b093ecf5f711353
GIT HEAD (branch master) 410d9608a73ad3a507b3008a6f68549288ebfc33

[CONFIG] conf_file None
[INPUT] verbose = 4
[INPUT] num. atoms = 0
[INPUT] num. electrons = 16
[INPUT] charge = 0
[INPUT] spin (= nelec alpha-beta = 2S) = 0
[INPUT] symmetry False subgroup None
[INPUT] Mole.unit = angstrom
[INPUT] Symbol           X                Y                Z      unit          X                Y                Z       unit  Magmom

nuclear repulsion = 0
number of shells = 0
number of NR pGTOs = 0
number of NR cGTOs = 0
basis = sto-3g
ecp = {}
CPU time:       146.86


******** <class 'pyscf.s

In [37]:
(-12.5665545206045--13.62185)/-13.62185 * 100

-7.747078989971995

# CPMC

In [17]:
# make dummy molecule
mol = gto.Mole()
mol.nelectron = nocc
mol.incore_anyway = True
mol.spin = abs(n_elec[0] - n_elec[1])
mol.build()

# UHF.
umf = scf.UHF(mol)
umf.get_hcore = lambda *args: integrals["h1"]
umf.get_ovlp = lambda *args: np.eye(n_sites)
umf._eri = ao2mo.restore(8, integrals["h2"], n_sites)
umf.max_cycle = -1
umf.kernel()

# AFQMC.
run_cpmc = True
nwalkers = 50

filetag = ''
jobid = ''
try: jobid = '.' +  os.environ["SLURM_JOB_ID"]
except: pass
filetag += jobid

pyscf_interface.prep_afqmc(
        umf, basis_coeff=np.eye(n_sites), integrals=integrals, filetag=filetag)

options = {
    "dt": 0.005,
    "n_eql": 20,
    "n_ene_blocks": 1,
    "n_sr_blocks": 5,
    "n_blocks": 100,
    "n_prop_steps": 100,
    "n_walkers": nwalkers,
    "seed": 98,
    "walker_type": "uhf",
    # "trial": "uhf",
    "save_walkers": False,
    #"do_sr": False
}

ham_data, ham, prop, trial, wave_data, sampler, observable, options, MPI = (
    mpi_jax._prep_afqmc(options, filetag=filetag)
)

if run_cpmc:
    if verbose: print(f'\n# Using CPMC propagator...')
    prop = propagation.propagator_cpmc_unrestricted(
        dt=options["dt"],
        n_walkers=options["n_walkers"],
    )

trial = wavefunctions.ghf_cpmc(n_sites, n_elec)
wave_data["mo_coeff"] = ghf_coeff[:, :nocc]
wave_data["rdm1"] = jnp.array([ghf_rdm1[:n_sites, :n_sites], ghf_rdm1[n_sites:, n_sites:]])
ham_data = ham.build_measurement_intermediates(ham_data, trial, wave_data)
ham_data = ham.build_propagation_intermediates(ham_data, prop, trial, wave_data)
ham_data["u"] = U

e_afqmc, err_afqmc = driver.afqmc(
    ham_data, ham, prop, trial, wave_data, sampler, observable, options, MPI,
)

SCF not converged.
SCF energy = -7.20597534695668 after -1 cycles  <S^2> = 5.3290705e-15  2S+1 = 1
#
# Preparing AFQMC calculation
# Calculating Cholesky integrals
# Finished calculating Cholesky integrals
#
# Size of the correlation space:
# Number of electrons: (8, 8)
# Number of basis functions: 16
# Number of Cholesky vectors: 16
#
# Number of MPI ranks: 1
#
# No trial specified in options.
# trial.pkl not found, make sure to construct the trial separately.
# norb: 16
# nelec: (8, 8)
#
# dt: 0.005
# n_eql: 20
# n_ene_blocks: 1
# n_sr_blocks: 5
# n_blocks: 100
# n_prop_steps: 100
# n_walkers: 50
# seed: 98
# walker_type: uhf
# save_walkers: False
# n_ene_blocks_eql: 5
# n_sr_blocks_eql: 10
# orbital_rotation: True
# do_sr: True
# symmetry: False
# ene0: 0.0
# free_projection: False
# n_batch: 1
#

# Using CPMC propagator...
# Equilibration sweeps:
#       Iter      Total block weight   Block energy         Walltime  
#          0      5.000000000e+01      -1.311457148e+01     2.12e+

In [18]:
(-13.58925760844421--13.62185)/-13.62185 * 100

-0.23926552968788925

# AFQMC

In [153]:
# make dummy molecule
mol = gto.Mole()
mol.nelectron = nocc
mol.incore_anyway = True
mol.spin = abs(n_elec[0] - n_elec[1])
mol.build()

# UHF.
umf = scf.UHF(mol)
umf.get_hcore = lambda *args: integrals["h1"]
umf.get_ovlp = lambda *args: np.eye(n_sites)
umf._eri = ao2mo.restore(8, integrals["h2"], n_sites)
umf.max_cycle = -1
umf.kernel()

# AFQMC.
run_cpmc = False
nwalkers = 50

filetag = ''
jobid = ''
try: jobid = '.' +  os.environ["SLURM_JOB_ID"]
except: pass
filetag += jobid

pyscf_interface.prep_afqmc(
        umf, basis_coeff=np.eye(n_sites), integrals=integrals, filetag=filetag)

options = {
    "dt": 0.005,
    "n_eql": 20,
    "n_ene_blocks": 1,
    "n_sr_blocks": 5,
    "n_blocks": 100,
    "n_prop_steps": 100,
    "n_walkers": nwalkers,
    "seed": 98,
    "walker_type": "uhf",
    # "trial": "uhf",
    "save_walkers": False,
    #"do_sr": False
}

ham_data, ham, prop, trial, wave_data, sampler, observable, options, MPI = (
    mpi_jax._prep_afqmc(options, filetag=filetag)
)

if run_cpmc:
    if verbose: print(f'\n# Using CPMC propagator...')
    prop = propagation.propagator_cpmc_unrestricted(
        dt=options["dt"],
        n_walkers=options["n_walkers"],
    )

trial = wavefunctions.ghf_cpmc(n_sites, n_elec)
wave_data["mo_coeff"] = ghf_coeff[:, :nocc]
wave_data["rdm1"] = jnp.array([ghf_rdm1[:n_sites, :n_sites], ghf_rdm1[n_sites:, n_sites:]])
ham_data = ham.build_measurement_intermediates(ham_data, trial, wave_data)
ham_data = ham.build_propagation_intermediates(ham_data, prop, trial, wave_data)
ham_data["u"] = U

e_afqmc, err_afqmc = driver.afqmc(
    ham_data, ham, prop, trial, wave_data, sampler, observable, options, MPI,
)

SCF not converged.
SCF energy = -16.9013249722489 after -1 cycles  <S^2> = 8.8817842e-16  2S+1 = 1
#
# Preparing AFQMC calculation
# Calculating Cholesky integrals
# Finished calculating Cholesky integrals
#
# Size of the correlation space:
# Number of electrons: (2, 2)
# Number of basis functions: 16
# Number of Cholesky vectors: 16
#
# Number of MPI ranks: 1
#
# No trial specified in options.
# trial.pkl not found, make sure to construct the trial separately.
# norb: 16
# nelec: (2, 2)
#
# dt: 0.005
# n_eql: 20
# n_ene_blocks: 1
# n_sr_blocks: 5
# n_blocks: 100
# n_prop_steps: 100
# n_walkers: 50
# seed: 98
# walker_type: uhf
# save_walkers: False
# n_ene_blocks_eql: 5
# n_sr_blocks_eql: 10
# orbital_rotation: True
# do_sr: True
# symmetry: False
# ene0: 0.0
# free_projection: False
# n_batch: 1
#
# Equilibration sweeps:
#       Iter      Total block weight   Block energy         Walltime  
#          0      5.000000000e+01      -1.690896893e+01     8.54e-01   
#          2      4.95

# FCI

In [159]:
mol.verbose = 5
ci = fci.FCI(mol)
e, ci_coeffs = ci.kernel(
    h1e=integrals["h1"], eri=integrals["h2"], norb=n_sites, nelec=n_elec, max_cycle=50000
)

print(ci.spin_square(ci_coeffs, n_sites, n_elec))
if verbose: print(f"\n# fci energy: {e}")


WARN: Not enough memory for FCI solver. The minimal requirement is 9938 MB

davidson 0 1  |r|= 4.69  e= [0.]  max|de|= 6.4e-09  lindep=    1



KeyboardInterrupt



# Optimized projected GHF with CPMC

In [20]:
psi = jnp.array(ghf_coeff[:, :nocc])
h1 = jnp.array(h1, dtype=jnp.complex128)
chol = jnp.array(chol, dtype=jnp.complex128)
proj_energy, opt_psi, energy_iters = optimize_jax(psi, n_elec, [h1, h1], chol, 0., s=0, store_iters=True, ngrid_z=10, ngrid_y=10, method='BFGS', projector='S2', verbose=True)
print(f'# Optimized projected energy = {proj_energy}')


# Initial projected energy = -13.106709070674972


  res = minimize(
  x_bar = _convert_element_type(x_bar, x.aval.dtype, x.aval.weak_type)
  x_bar = _convert_element_type(x_bar, x.aval.dtype, x.aval.weak_type)


projected E = -13.163531376075138
projected E = -13.167668057135055
projected E = -13.168712169103836
projected E = -13.168724198307183
projected E = -13.168724199457888
projected E = -13.168724199769937
projected E = -13.168724200235548
projected E = -13.168724200277326
projected E = -13.16872420034214
projected E = -13.168724200375314
projected E = -13.1687242003892
projected E = -13.168724200413028
projected E = -13.17070392550067
projected E = -13.170703925563261
projected E = -13.173268999772858
projected E = -13.177580380482288
projected E = -13.182900275697445
projected E = -13.188038057018417
projected E = -13.19525048562122
projected E = -13.204637807954732
projected E = -13.211453813275499
projected E = -13.220093031794379
projected E = -13.22338686990135
projected E = -13.225249819033456
projected E = -13.226787741919741
projected E = -13.227278791915072
projected E = -13.227940652173126
projected E = -13.228430409565936
projected E = -13.228684015614526
projected E = -13.22

In [None]:
psi = jnp.array(ghf_coeff[:, :nocc])
h1 = jnp.array(h1, dtype=jnp.complex128)
chol = jnp.array(chol, dtype=jnp.complex128)
proj_energy, opt_psi = optimize(psi, n_elec, [h1, h1], chol, 0., s=0, ngrid_z=15, ngrid_y=15, method='BFGS', projector='S2')
print(f'# Optimized projected energy = {proj_energy}')

In [21]:
norb = h1.shape[0]
rotchol = build_rotchol(psi.T.conj(), chol.reshape((-1, norb, norb)))
opt_rotchol = build_rotchol(opt_psi.T.conj(), chol.reshape((-1, norb, norb)))
ehf = get_energy_jax(psi, psi, [h1, h1], rotchol, 0.)
opt_ehf = get_energy_jax(opt_psi, opt_psi, [h1, h1], opt_rotchol, 0.)

In [22]:
# Normalize orbitals.
for iorb in range(opt_psi.shape[1]):
    norm = sp.linalg.norm(opt_psi[:, iorb])
    print(norm)
    opt_psi[:, iorb] /= norm
    norm = sp.linalg.norm(opt_psi[:, iorb])
    print(norm)

1.0333997654302778
1.0
1.054392423136029
0.9999999999999999
0.9921533800096151
1.0
1.0000820170468439
1.0
1.0142715763318633
1.0
1.0160291357499784
1.0
1.017407689373733
1.0
1.01775252797117
1.0
1.035450778084449
1.0
1.0665900870075156
1.0
1.0137842991246544
0.9999999999999999
1.0112324810704214
1.0
1.0103369364354193
1.0
1.0096142442167202
1.0
1.0092023507591394
1.0
1.0311725881731864
1.0


In [23]:
# Normalize orbitals.
for iorb in range(opt_psi.shape[1]):
    norm = sp.linalg.norm(opt_psi[:, iorb])
    print(norm)

1.0
0.9999999999999999
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
0.9999999999999999
1.0
1.0
1.0
1.0
1.0


In [24]:
print(f'EHF = {ehf}')
print(f'Optimized EHF = {opt_ehf}')
print(f'Optimized projected E = {proj_energy}')

EHF = (-12.56655452069899+0j)
Optimized EHF = (-12.117003925114378+0j)
Optimized projected E = -13.231003231751043


In [25]:
opt_rdm1 = opt_psi @ opt_psi.T.conj()
print(np.trace(opt_rdm1))
print(np.trace(opt_rdm1[:n_sites, :n_sites]))
print(np.trace(opt_rdm1[n_sites:, n_sites:]))
print(opt_rdm1.shape)

16.0
8.02512490988667
7.97487509011333
(32, 32)


In [26]:
rdm1 = psi @ psi.T.conj()
print(np.trace(rdm1))
print(np.trace(rdm1[:n_sites, :n_sites]))
print(np.trace(rdm1[n_sites:, n_sites:]))
print(rdm1.shape)

16.000000000000007
7.999999997143611
8.000000002856396
(32, 32)


In [27]:
rdm1

Array([[ 0.34231,  0.16114,  0.01868, ...,  0.03732, -0.     ,  0.03732],
       [ 0.16114,  0.65769,  0.16114, ..., -0.     , -0.03732, -0.     ],
       [ 0.01868,  0.16114,  0.34231, ...,  0.03732, -0.     ,  0.03732],
       ...,
       [ 0.03732, -0.     ,  0.03732, ...,  0.65769,  0.16114, -0.01868],
       [-0.     , -0.03732, -0.     , ...,  0.16114,  0.34231,  0.16114],
       [ 0.03732, -0.     ,  0.03732, ..., -0.01868,  0.16114,  0.65769]],      dtype=float64)

In [28]:
opt_rdm1

array([[ 0.46306,  0.13817,  0.01964, ...,  0.04794,  0.00991,  0.0325 ],
       [ 0.13817,  0.74109,  0.14809, ..., -0.00179, -0.04421, -0.01614],
       [ 0.01964,  0.14809,  0.25013, ...,  0.05173,  0.00301,  0.03149],
       ...,
       [ 0.04794, -0.00179,  0.05173, ...,  0.76495,  0.16235, -0.00868],
       [ 0.00991, -0.04421,  0.00301, ...,  0.16235,  0.26261,  0.14393],
       [ 0.0325 , -0.01614,  0.03149, ..., -0.00868,  0.14393,  0.53207]])

In [29]:
evals, evecs = np.linalg.eigh(opt_rdm1)
evals = evals[::-1]
evecs = evecs[:, ::-1]

In [33]:
G, ovlp = get_greens(opt_psi[:, :nocc], opt_psi[:, :nocc])
savg = spin_utils.get_spin_average_from_greens_ghf(G) * ovlp
s2 = spin_utils.get_spin_square_from_greens_ghf(G) * ovlp
s = calc_s(s2)

print(f'<S> = {savg}')
print(f'<S^2> = {s2}, S = {s}, 2S+1 = {calc_multiplicity(s)}')
print(f'ovlp = {ovlp}')

<S> = [0.02193+0.j 0.     +0.j 0.0391 +0.j]
<S^2> = (4.811598549799565+0j), S = (1.7497996688148847+0j), 2S+1 = (4.499599337629769+0j)
ovlp = (0.9691633937900228+0j)


In [34]:
G, ovlp = get_greens(psi[:, :nocc], psi[:, :nocc])
savg = spin_utils.get_spin_average_from_greens_ghf(G) * ovlp
s2 = spin_utils.get_spin_square_from_greens_ghf(G) * ovlp
s = calc_s(s2)

print(f'<S> = {savg}')
print(f'<S^2> = {s2}, S = {s}, 2S+1 = {calc_multiplicity(s)}')
print(f'ovlp = {ovlp}')

<S> = [ 0.+0.j  0.+0.j -0.+0.j]
<S^2> = (4.4371359015919545+0j), S = (1.6649794229026647+0j), 2S+1 = (4.329958845805329+0j)
ovlp = (1.000000000000006+0j)


In [35]:
# make dummy molecule
mol = gto.Mole()
mol.nelectron = nocc
mol.incore_anyway = True
mol.spin = abs(n_elec[0] - n_elec[1])
mol.build()

# UHF.
umf = scf.UHF(mol)
umf.get_hcore = lambda *args: integrals["h1"]
umf.get_ovlp = lambda *args: np.eye(n_sites)
umf._eri = ao2mo.restore(8, integrals["h2"], n_sites)
umf.max_cycle = -1
umf.kernel()

# AFQMC.
run_cpmc = True
nwalkers = 50
walker_type = 'uhf'
init_walkers = None

filetag = ''
jobid = ''
try: jobid = '.' +  os.environ["SLURM_JOB_ID"]
except: pass
filetag += jobid


if walker_type == 'uhf':    
    pyscf_interface.prep_afqmc(
            umf, basis_coeff=np.eye(n_sites), integrals=integrals, filetag=filetag)

elif walker_type == 'ghf':
    init_walkers = jnp.array([opt_psi + 0.j] * nwalkers)
    pyscf_interface.prep_afqmc(
            gmf, basis_coeff=np.eye(n_sites), integrals=integrals, filetag=filetag)

options = {
    "dt": 0.005,
    "n_eql": 20,
    "n_ene_blocks": 1,
    "n_sr_blocks": 5,
    "n_blocks": 100,
    "n_prop_steps": 100,
    "n_walkers": nwalkers,
    "seed": 98,
    "walker_type": walker_type,
    # "trial": "uhf",
    "save_walkers": False,
    #"do_sr": False
}

ham_data, ham, prop, trial, wave_data, sampler, observable, options, MPI = (
    mpi_jax._prep_afqmc(options, filetag=filetag)
)

if run_cpmc:
    if verbose: print(f'\n# Using CPMC propagator...')
        
    if walker_type == 'uhf':
        prop = propagation.propagator_cpmc_unrestricted(
            dt=options["dt"],
            n_walkers=options["n_walkers"],
        )

    elif walker_type == 'ghf':
        prop = propagation.propagator_cpmc_general(
            dt=options["dt"],
            n_walkers=options["n_walkers"],
        )

trial = wavefunctions.ghf_cpmc(n_sites, n_elec)
wave_data["mo_coeff"] = opt_psi

if walker_type == 'uhf':
    wave_data["rdm1"] = jnp.array([ghf_rdm1[:n_sites, :n_sites], ghf_rdm1[n_sites:, n_sites:]])
    # wave_data["rdm1"] = jnp.array([opt_rdm1[:n_sites, :n_sites], opt_rdm1[n_sites:, n_sites:]])
    
elif walker_type == 'ghf':
    wave_data["rdm1"] = opt_rdm1
    
ham_data = ham.build_measurement_intermediates(ham_data, trial, wave_data)
ham_data = ham.build_propagation_intermediates(ham_data, prop, trial, wave_data)
ham_data["u"] = U

e_afqmc, err_afqmc = driver.afqmc(
    ham_data, ham, prop, trial, wave_data, sampler, observable, options, MPI,
    init_walkers=init_walkers
)

SCF not converged.
SCF energy = -7.20597534695668 after -1 cycles  <S^2> = 5.3290705e-15  2S+1 = 1
#
# Preparing AFQMC calculation
# Calculating Cholesky integrals
# Finished calculating Cholesky integrals
#
# Size of the correlation space:
# Number of electrons: (8, 8)
# Number of basis functions: 16
# Number of Cholesky vectors: 16
#
# Number of MPI ranks: 1
#
# No trial specified in options.
# trial.pkl not found, make sure to construct the trial separately.
# norb: 16
# nelec: (8, 8)
#
# dt: 0.005
# n_eql: 20
# n_ene_blocks: 1
# n_sr_blocks: 5
# n_blocks: 100
# n_prop_steps: 100
# n_walkers: 50
# seed: 98
# walker_type: uhf
# save_walkers: False
# n_ene_blocks_eql: 5
# n_sr_blocks_eql: 10
# orbital_rotation: True
# do_sr: True
# symmetry: False
# ene0: 0.0
# free_projection: False
# n_batch: 1
#

# Using CPMC propagator...
# Equilibration sweeps:
#       Iter      Total block weight   Block energy         Walltime  
#          0      5.000000000e+01      -1.353198528e+01     1.09e+

In [36]:
(-13.587--13.62185)/-13.62185 * 100

-0.2558389646046645

In [None]:
# make dummy molecule
mol = gto.Mole()
mol.nelectron = nocc
mol.incore_anyway = True
mol.spin = abs(n_elec[0] - n_elec[1])
mol.build()

# UHF.
umf = scf.UHF(mol)
umf.get_hcore = lambda *args: integrals["h1"]
umf.get_ovlp = lambda *args: np.eye(n_sites)
umf._eri = ao2mo.restore(8, integrals["h2"], n_sites)
umf.max_cycle = -1
umf.kernel()

# AFQMC.
run_cpmc = False
nwalkers = 50
walker_type = 'uhf'
init_walkers = None

filetag = ''
jobid = ''
try: jobid = '.' +  os.environ["SLURM_JOB_ID"]
except: pass
filetag += jobid


if walker_type == 'uhf':    
    pyscf_interface.prep_afqmc(
            umf, basis_coeff=np.eye(n_sites), integrals=integrals, filetag=filetag)

elif walker_type == 'ghf':
    init_walkers = jnp.array([opt_psi + 0.j] * nwalkers)
    pyscf_interface.prep_afqmc(
            gmf, basis_coeff=np.eye(n_sites), integrals=integrals, filetag=filetag)

options = {
    "dt": 0.005,
    "n_eql": 20,
    "n_ene_blocks": 1,
    "n_sr_blocks": 5,
    "n_blocks": 100,
    "n_prop_steps": 100,
    "n_walkers": nwalkers,
    "seed": 98,
    "walker_type": walker_type,
    # "trial": "uhf",
    "save_walkers": False,
    #"do_sr": False
}

ham_data, ham, prop, trial, wave_data, sampler, observable, options, MPI = (
    mpi_jax._prep_afqmc(options, filetag=filetag)
)

if run_cpmc:
    if verbose: print(f'\n# Using CPMC propagator...')
        
    if walker_type == 'uhf':
        prop = propagation.propagator_cpmc_unrestricted(
            dt=options["dt"],
            n_walkers=options["n_walkers"],
        )

    elif walker_type == 'ghf':
        prop = propagation.propagator_cpmc_general(
            dt=options["dt"],
            n_walkers=options["n_walkers"],
        )

trial = wavefunctions.ghf_cpmc(n_sites, n_elec)
wave_data["mo_coeff"] = opt_psi

if walker_type == 'uhf':
    # wave_data["rdm1"] = jnp.array([opt_rdm1[:n_sites, :n_sites], opt_rdm1[n_sites:, n_sites:]])
    wave_data["rdm1"] = jnp.array([ghf_rdm1[:n_sites, :n_sites], ghf_rdm1[n_sites:, n_sites:]])
    
elif walker_type == 'ghf':
    wave_data["rdm1"] = opt_rdm1
    
ham_data = ham.build_measurement_intermediates(ham_data, trial, wave_data)
ham_data = ham.build_propagation_intermediates(ham_data, prop, trial, wave_data)
ham_data["u"] = U

e_afqmc, err_afqmc = driver.afqmc(
    ham_data, ham, prop, trial, wave_data, sampler, observable, options, MPI,
    init_walkers=init_walkers
)