In [None]:
import os
import tempfile
import sys
import argparse
import torch
import time
import numpy as np

from functools import partial
from line_profiler import LineProfiler
from loguru import logger
from torch import optim
from torch.nn.parallel import DistributedDataParallel as DDP
from pyscf import fci

from utils import setup_seed, Logger, ElectronInfo, Dtype, state_to_string
from utils.pyscf_helper import read_integral, interface
from utils import convert_onv, get_fock_space
from utils.det_helper import DetLUT, select_det, sort_det
from utils.distributed import get_rank
from utils.loggings import dist_print
from utils.pyscf_helper.dice_pyscf import read_dice_wf, run_shci
from vmc.ansatz import DecoderWaveFunction
from vmc.optim import VMCOptimizer
from ci_vmc.hybrid import NqsCi
from ci import unpack_ucisd, ucisd_to_fci, fci_revise, CIWavefunction
from torchinfo import summary


In [None]:
# Build a linear H-chain, write its integrals to a temp file via `interface`,
# and read them back as tensors on `device`.
device = "cpu"
n_atoms = 6
bond = 1.50  # spacing between neighbouring H atoms (units as interpreted by `interface`)
atom: str = "".join(f"H, 0.00, 0.00, {k * bond:.3f} ;" for k in range(n_atoms))

# mkstemp returns an already-open OS-level file descriptor plus the path; only the
# path is handed on, so close the descriptor instead of leaking it.
fd, integral_file = tempfile.mkstemp()
os.close(fd)

sorb, nele, e_lst, fci_amp, ucisd_amp, mf = interface(
    atom, integral_file=integral_file, cisd_coeff=True,
    basis="sto-3g",
    localized_orb=False,
    localized_method="meta-lowdin",
)
logger.info(e_lst)
h1e, h2e, ci_space, ecore, sorb = read_integral(
    integral_file,
    nele,
    # Optional read_integral knobs (currently disabled):
    # save_onstate=True,
    # external_onstate="profiler/H12-1.50",
    # given_sorb=(nele + 2),
    device=device,
    # prefix="test-onstate",
)

In [None]:
from utils.public_function import get_fock_space
from utils.public_function import torch_sort_onv

from libs.C_extension import get_comb_tensor, get_hij_torch, onv_to_tensor


In [None]:
# Enumerate the Fock space for `sorb` spin orbitals and put it in sorted ONV order.
# Later cells index Hamiltonian rows by position `i` and columns by the ONV bytes
# reinterpreted as an integer, so this ordering presumably makes the two agree.
fock_space = get_fock_space(sorb=sorb)
fock_space = fock_space[torch_sort_onv(fock_space)]
# onv_to_tensor appears to encode occupations as +/-1; map them to 0/1 occupation numbers.
fock_space_state = ((onv_to_tensor(fock_space, sorb=sorb) + 1) / 2).to(torch.int64)
# Preview only the first rows; the full arrays have one row per determinant.
fock_space_state[:8], fock_space.numpy().view(np.uint64)[:8]

In [None]:
dim = len(fock_space)  # number of determinants (rows) in the Fock space

In [None]:
# Dense reference: build the Fock-space Hamiltonian row by row on the CPU and
# check its lowest eigenvalue (+ ecore) against the FCI energy from `interface`.
# Alpha/beta electron counts per determinant (even / odd spin-orbital columns), computed once.
n_alpha = fock_space_state[:, ::2].sum(dim=1).tolist()
n_beta = fock_space_state[:, 1::2].sum(dim=1).tolist()

Ham = torch.zeros(dim, dim, dtype=torch.double)
for i in range(dim):
    alpha, beta = n_alpha[i], n_beta[i]
    # Local electron count -- do not rebind the global `nele` used by other cells.
    n_elec = alpha + beta
    bra = fock_space[i].reshape(1, -1)
    comb = get_comb_tensor(bra, sorb, n_elec, alpha, beta, True)[0]
    hij = get_hij_torch(bra, comb.squeeze(0), h1e, h2e, sorb, n_elec)
    # Column indices: each connected ONV's bytes reinterpreted as an int64 index.
    Ham[i, comb.view(torch.int64).flatten()] = hij.reshape(-1)

eigvals = torch.linalg.eigvalsh(Ham)  # ascending eigenvalues only; eigenvectors are not needed
e_fci = e_lst[0]
assert abs(eigvals[0].item() + ecore - e_fci) < 1.0e-10

In [None]:
# Same dense check, but with ONVs and integrals on the GPU (exercises the CUDA path
# of the C extension). Skipped without a CUDA device so Restart & Run All still
# works on CPU-only machines.
if torch.cuda.is_available():
    fock_space_cuda = fock_space.to("cuda")
    h1e_cuda = h1e.to("cuda")
    h2e_cuda = h2e.to("cuda")
    # Occupation counts from the CPU copy: avoids a GPU sync (.item()) per determinant.
    n_alpha = fock_space_state[:, ::2].sum(dim=1).tolist()
    n_beta = fock_space_state[:, 1::2].sum(dim=1).tolist()

    Ham_cuda = torch.zeros(dim, dim, dtype=torch.double, device="cuda")
    for i in range(dim):
        alpha, beta = n_alpha[i], n_beta[i]
        # Local electron count -- do not rebind the global `nele` used by other cells.
        n_elec = alpha + beta
        bra = fock_space_cuda[i].reshape(1, -1)
        comb = get_comb_tensor(bra, sorb, n_elec, alpha, beta, True)[0]
        hij = get_hij_torch(bra, comb.squeeze(0), h1e_cuda, h2e_cuda, sorb, n_elec)
        Ham_cuda[i, comb.view(torch.int64).flatten()] = hij.reshape(-1)

    eigvals_cuda = torch.linalg.eigvalsh(Ham_cuda)
    e_fci = e_lst[0]
    assert abs(eigvals_cuda[0].item() + ecore - e_fci) < 1.0e-10
else:
    logger.warning("CUDA not available; skipping GPU Hamiltonian check")

In [None]:
# Sparse build: collect COO triplets (value, row, col) per determinant instead of
# materialising the dense dim x dim matrix.
data = []
col = []
row = []

n_alpha = fock_space_state[:, ::2].sum(dim=1).tolist()
n_beta = fock_space_state[:, 1::2].sum(dim=1).tolist()
for i in range(dim):
    alpha, beta = n_alpha[i], n_beta[i]
    # Local electron count -- do not rebind the global `nele` used by other cells.
    n_elec = alpha + beta
    bra = fock_space[i].reshape(1, -1)
    comb = get_comb_tensor(bra, sorb, n_elec, alpha, beta, True)[0]
    hij = get_hij_torch(bra, comb.squeeze(0), h1e, h2e, sorb, n_elec)
    ket_idx = comb.view(torch.int64).flatten().cpu().numpy()
    # Same orientation as the dense build `Ham[i, ket_idx] = hij`: row = bra i, col = ket.
    data.append(hij.cpu().numpy().flatten())
    row.append(np.full(len(ket_idx), i, dtype=np.int64))
    col.append(ket_idx)

In [None]:
# Spot-check the triplet pieces recorded for determinant 3.
tuple(part[3] for part in (data, col, row))

In [None]:
# Flatten the per-determinant pieces into 1-D COO arrays. Guarded so that re-running
# the cell (when `data`/`col`/`row` are already flat arrays) is a no-op. Without the
# guard, np.concatenate on a 1-D array fails on zero-dimensional elements.
if isinstance(data, list):
    data = np.concatenate(data)
if isinstance(col, list):
    col = np.concatenate(col)
if isinstance(row, list):
    row = np.concatenate(row)

In [None]:
from scipy.sparse import csr_matrix

# Assemble the sparse Hamiltonian from the COO triplets (scipy sums duplicate entries).
x = csr_matrix(
    (data, (row, col)),
    shape=(dim,) * 2,
)

In [None]:
x.__class__  # confirm the sparse format produced above

In [None]:
# `import scipy` alone does not guarantee the sparse.linalg submodule is loaded.
import scipy.sparse.linalg

# eigsh defaults to k=6, which="LM" (largest *magnitude*), so indexing its output
# was not guaranteed to give the ground state. Ask explicitly for the single
# smallest-algebraic eigenvalue and compare it with the reference FCI energy.
e_sparse = scipy.sparse.linalg.eigsh(x, k=1, which="SA", return_eigenvectors=False)[0] + ecore
assert abs(e_sparse - e_lst[0]) < 1.0e-8
e_sparse

In [None]:
# Reference FCI energy returned by `interface` (first entry of e_lst), for comparison with the eigsh result.
e_lst[0]

In [None]:
# Build and diagonalize the Hamiltonian restricted to the CI space returned by
# read_integral. Use the actual `sorb` instead of a hard-coded 12. Read the electron
# count off the first CI determinant instead of hard-coding 6 or relying on the
# global `nele`, which other cells may rebind.
ci_occupation = (onv_to_tensor(ci_space[:1], sorb=sorb) + 1) / 2  # same +/-1 -> 0/1 mapping as the Fock-space cell
n_elec_ci = int(ci_occupation.sum().item())
hij = get_hij_torch(ci_space, ci_space, h1e, h2e, sorb, n_elec_ci)

e_ci = torch.linalg.eigvalsh(hij)[0].item() + ecore
e_ci