In [12]:
import time
import torch
import numpy as np 
from torch import Tensor

import libs.hij_tensor as pt
import libs.py_fock as fock
import libs.py_integral as integral

from vmc.PublicFunction import unit8_to_bit, check_para

In [13]:
def string_to_lst(sorb: int, string: str):
    """
    Pack a string of 0/1 characters into a list of byte values.

    The string is reversed first, so its last character becomes bit 0 of
    byte 0. Bits are then consumed eight at a time (the final chunk is cut
    off at ``sorb``) and each chunk is folded into one integer using
    little-endian bit weights (1, 2, 4, ...).

    Args:
        sorb: number of bits of ``string`` that are packed.
        string: characters convertible with ``int`` (expected '0'/'1').

    Returns:
        A list of length ``((sorb - 1) // 64 + 1) * 8``. The first
        ``(sorb - 1) // 8 + 1`` entries hold the packed values (NumPy
        scalars); the remaining entries stay at the integer 0.
    """
    # Reverse so that index 0 is the right-most character of the input.
    bits = np.array([int(ch) for ch in string])[::-1]

    # Output is padded to a whole number of 8-byte (64-bit) words.
    n_words = (sorb - 1) // 64 + 1
    packed = [0] * (n_words * 8)

    n_used_bytes = (sorb - 1) // 8 + 1
    for byte_idx in range(n_used_bytes):
        lo = byte_idx * 8
        hi = min(lo + 8, sorb)  # last chunk may hold fewer than 8 bits
        chunk = bits[lo:hi]
        weights = 2 ** np.arange(len(chunk))
        packed[byte_idx] = np.sum(weights * chunk)

    return packed

In [14]:
# System setup: load the integrals, enumerate the determinant space and
# build the device tensors used by the cells below.
# chain_len is only referenced by the commented-out `nele = chain_len` line.
chain_len = 10
integral_file = f"integral/rmole-N2.info"
# integral.load fills the two-body / one-body containers from the file; the
# third value returned is named ecore (presumably the core energy; confirm
# against libs.py_integral).
int2e, int1e, ecore = integral.load(integral.two_body(), integral.one_body(), 0.0, integral_file)
# Orbital count as reported by the two-body integral object (20 in the
# recorded run output).
sorb = int2e.sorb
# nele = chain_len
nele = 14
alpha_ele = 7 # nele//2
beta_ele = 7  # nele//2
device = "cuda"
# Enumerate the space for sorb//2 orbitals with the given alpha/beta counts
# (presumably spatial orbitals and the full-CI determinant list; confirm
# against libs.py_fock).
space = fock.get_fci_space(int(sorb//2), alpha_ele, beta_ele)
dim = len(space)

# One- and two-electron integral data as float64 tensors on the device.
h1e = torch.tensor(int1e.data, dtype=torch.float64).to(device)
h2e = torch.tensor(int2e.data, dtype=torch.float64).to(device)

# bra/ket: convert every entry of the space to its packed byte list and
# stack them into one uint8 tensor with one row per state.
lst = []
for i in range(dim):
    lst.append(string_to_lst(sorb, space[i].to_string()))
onstate1 = torch.tensor(lst, dtype=torch.uint8).to(device)
# onstate2 = torch.tensor(lst, dtype=torch.uint8).to(device)
# Sanity check: the recorded run printed torch.Size([14400, 8]).
print(onstate1.shape)


integral::load fname = integral/rmole-N2.info
sorb = 20
size(int1e) = 400:0.00305176MB:2.98023e-06GB
size(int2e) = 18335:0.139885MB:0.000136606GB
----- TIMING FOR integral::load_integral : 5.000e-03 S -----
torch.Size([14400, 8])


In [16]:
# Show the first three rows of the packed uint8 state tensor to eyeball the encoding.
print(onstate1[:3])

tensor([[255,  63,   0,   0,   0,   0,   0,   0],
        [255, 159,   0,   0,   0,   0,   0,   0],
        [255, 183,   0,   0,   0,   0,   0,   0]], device='cuda:0',
       dtype=torch.uint8)


In [21]:
# TODO: decide where detach() should be used (detach from the autograd graph so no gradients are computed).
def local_energy(x: Tensor, h1e: Tensor, h2e: Tensor, ansatz, sorb: int, nele: int,) ->tuple[Tensor, Tensor]:
    r"""
    Calculate the local energy for the given state(s).

        E_loc(x) = \sum_x' psi(x')/psi(x) * <x|H|x'>

    Steps:
    1. Build the states connected to ``x`` with ``pt.get_comb_tensor``; this
       runs on the CPU and the result is moved back to ``x.device``. Per the
       original notes these are all single and double excitations, with
       comb_x of shape (batch, ncomb, sorb)/(ncomb, sorb).
    2. Evaluate the matrix elements <x|H|x'> with ``pt.get_hij_torch``;
       shape (1, ncomb)/(batch, ncomb) per the original notes.
    3. Evaluate ``ansatz`` on the bit representation of every connected state
       (NAQS in the original notes).
    4. Contract the matrix elements with the ratios psi(x')/psi(x), where
       psi(x) is taken from index 0 of the last axis of the ansatz output
       (presumably the first connected state is ``x`` itself; confirm
       against ``pt.get_comb_tensor``).

    Args:
        x: 2-D tensor of states with at least one row, validated by ``check_para``.
        h1e: one-electron integral data passed through to ``pt.get_hij_torch``.
        h2e: two-electron integral data passed through to ``pt.get_hij_torch``.
        ansatz: callable applied to the output of ``pt.unit8_to_bit``.
        sorb: orbital count forwarded to the ``pt`` helpers.
        nele: electron count forwarded to the ``pt`` helpers.

    Returns:
        ``(eloc, psi)``. ``eloc`` is a scalar tensor when ``x`` has one row,
        otherwise the sum over the last axis (one value per row). ``psi`` is
        ``ansatz(...)[..., 0]``.

    Raises:
        ValueError: if ``x`` is not 2-D or has no rows. These inputs used to
            fail with an unbound ``eloc`` only after all the work was done.
    """
    check_para(x)
    # Only a non-empty 2-D batch is handled below; reject anything else up front.
    if x.dim() != 2 or x.shape[0] < 1:
        raise ValueError(
            f"local_energy expects a non-empty 2-D tensor, got shape {tuple(x.shape)}"
        )
    # TODO: "get_comb_tensor" in cuda
    # TODO: python version x->comb_x

    device = x.device
    batch: int = x.shape[0]
    t0 = time.time_ns()
    comb_x = pt.get_comb_tensor(x.to("cpu"), sorb, nele, True).to(device)
    print(f"comb_x delta t0: {(time.time_ns()-t0)/1.0E06:.3f} ms")
    # calculate matrix <x|H|x'>
    comb_hij = pt.get_hij_torch(x, comb_x, h1e, h2e, sorb, nele) # shape (1, n)/(batch, n)
    t1 =  time.time_ns()
    # TODO: time consuming
    # Kept in its own name so the input parameter ``x`` is not overwritten.
    comb_bits = pt.unit8_to_bit(comb_x, sorb)
    print(f"unit8_to_bit t1: {(time.time_ns()-t1)/1.0E06:.3f} ms")

    t2 = time.time_ns()
    psi_x1 = ansatz(comb_bits)
    # Wait for pending GPU work so the timing below is meaningful; skipped for
    # CPU tensors, where torch.cuda.synchronize() would fail without CUDA.
    if device.type == "cuda":
        torch.cuda.synchronize(device)
    print(f"ansatz delta t2: {(time.time_ns()-t2)/1.0E06:.3f} ms")
    # print(rbm.phase(unit8_to_bit(comb_x, sorb))[1])
    # print(rbm.amplitude(unit8_to_bit(comb_x, sorb))[1])
    psi_x = psi_x1[..., 0]
    if batch == 1:
        eloc = torch.sum(comb_hij * psi_x1 / psi_x) # scalar
    else:
        # Transposes let the per-row psi(x) broadcast along the last axis.
        eloc = torch.sum(torch.div(psi_x1.T, psi_x).T * comb_hij, -1) # (batch)

    return eloc, psi_x

# print(local_energy(onstate1[idx].view(1, -1), h1e, h2e, ansatz, sorb, nele)[0])
# t0 = time.time_ns()
# local_energy(onstate1, h1e, h2e, ansatz, sorb, nele)
# delta = time.time_ns() - t0 
# print(f"Cost time: {delta/1.0E06:.3f} ms")

In [22]:
def total_enrgy(x: Tensor, nbatch: int, h1e: Tensor, h2e: Tensor, ansatz,
                sorb: int, nele: int, device: str="cuda"):
    """
    Total energy as the psi**2-weighted average of the local energies.

    ``x`` is processed in chunks of at most ``nbatch`` rows. Each chunk is
    handed to ``local_energy``, and the returned local energies and psi
    values are written into per-state float64 buffers on ``device``. The
    result is ``sum(eloc * psi**2 / sum(psi**2))``.

    Args:
        x: tensor of states; its first axis indexes the states.
        nbatch: chunk size used to split ``x``.
        h1e, h2e, ansatz, sorb, nele: forwarded unchanged to ``local_energy``.
        device: device on which the result buffers are allocated.

    Returns:
        A zero-dimensional float64 tensor on ``device``.
    """
    n_states: int = x.shape[0]
    eloc_all = torch.zeros(n_states, dtype=torch.float64, device=device)
    psi_all = torch.zeros(n_states, dtype=torch.float64, device=device)
    positions = torch.arange(n_states, device=device)

    # Split the states and their row indices identically so every chunk
    # knows where its results belong in the buffers.
    state_chunks = x.split(nbatch)
    index_chunks = positions.split(nbatch)
    for states, where in zip(state_chunks, index_chunks):
        chunk_eloc, chunk_psi = local_energy(states, h1e, h2e, ansatz, sorb, nele)
        eloc_all[where] = chunk_eloc
        psi_all[where] = chunk_psi

    # Normalised psi**2 weights.
    psi_sq = psi_all.pow(2)
    weights = psi_sq / psi_sq.sum()
    return (eloc_all * weights).sum()

# NOTE(review): this import sits mid-notebook; consider moving it to the import cell.
from vmc.ansatz import RBM 
# RBM built from sorb and 2*sorb (presumably visible/hidden unit counts;
# confirm against vmc.ansatz). Its `prob` method is used as the ansatz callable.
rbm = RBM(sorb, sorb*2)
ansatz = rbm.prob

# Time one full evaluation of the total energy over all states,
# processed in chunks of 5000 rows (the nbatch argument).
t0 = time.time_ns()
print(onstate1.shape)
e = total_enrgy(onstate1, 5000, h1e, h2e, ansatz, sorb, nele)
# Wait for outstanding CUDA work so the wall-clock figure includes it.
torch.cuda.synchronize()
delta = time.time_ns() - t0 
print(e)
print(f"Cost time: {delta/1.0E06:.3f} ms")

# H10
# tensor(-11.0171, device='cuda:0', dtype=torch.float64)
# GPU 3568Mb
# Cost time: 6361.739 ms

torch.Size([14400, 8])
comb_x delta t0: 51.351 ms
ket dim: unit8_to_bit t1: 70.953 ms
3
GPU Hmat initialization time: 0.280032 ms
GPU calculate <n|H|m> time: 1.4145 ms
Total function GPU function time: 1.74099 ms

GPU calculate comb(unit8->bit) time: 65.83ms

ansatz delta t2: 276.543 ms
comb_x delta t0: 38.927 ms
ket dim: 3
GPU Hmat initialization time: 0.251968 ms
GPU calculate <n|H|m> time: 1.28246 ms
Total function GPU function time: 1.56877 ms

GPU calculate comb(unit8->bit) time: 63.6314ms

unit8_to_bit t1: 68.006 ms
ansatz delta t2: 219.972 ms
comb_x delta t0: 34.512 ms
ket dim: 3
GPU Hmat initialization time: 0.282048 ms
GPU calculate <n|H|m> time: 1.14563 ms
Total function GPU function time: 1.48429 ms

GPU calculate comb(unit8->bit) time: 55.7564ms

unit8_to_bit t1: 59.668 ms
ansatz delta t2: 175.141 ms
tensor(-106.0730, device='cuda:0', dtype=torch.float64)
Cost time: 1004.012 ms
