In [None]:
import random
import torch
import time  
import numpy as np 
from torch import nn, Tensor, optim
from torch.nn.parameter import Parameter 
from torch.optim.optimizer import Optimizer, required

from typing import List, Optional, Callable, Tuple

import libs.hij_tensor as pt
import libs.py_fock as fock
import libs.py_integral as integral

from vmc.PublicFunction import check_para, unit8_to_bit, setup_seed
from vmc.ansatz import rRBMWavefunction
from vmc.optim import SR
from vmc.eloc import local_energy, total_energy
from vmc.sample import MCMCSampler

In [None]:
import os

# Show the working directory in the cell output. os.system("pwd") spawned a
# shell (and fails where `pwd` does not exist, e.g. Windows cmd); its output
# goes to the kernel process's stdout, which Jupyter may not show in the cell.
print(os.getcwd())
# Synchronous CUDA kernel launches make CUDA errors point at the offending
# call; the variable has to be set before CUDA is initialised.
os.environ["CUDA_LAUNCH_BLOCKING"] = '1'
# All tensors created without an explicit dtype are float64 from here on.
torch.set_default_dtype(torch.double)
torch.set_printoptions(precision=5)

In [None]:
def string_to_lst(sorb: int, string: str):
    """Pack an occupation bit string into a list of byte values.

    The string is read right-to-left, so its last character becomes bit 0 of
    byte 0. Bits are grouped 8 per byte and only the first
    ``(sorb - 1) // 8 + 1`` bytes are filled. The list is zero-padded to a
    whole number of 8-byte (64-bit) words.

    Args:
        sorb: number of bits (orbitals) to encode.
        string: characters '0'/'1'; the last character is orbital 0.

    Returns:
        List of length ``8 * ((sorb - 1) // 64 + 1)``. Filled entries are
        NumPy scalars; padding entries are plain ``0``.
    """
    bits = np.array([int(ch) for ch in string])[::-1]
    n_words = (sorb - 1) // 64 + 1
    packed = [0] * (n_words * 8)
    n_bytes = (sorb - 1) // 8 + 1
    for byte_idx in range(n_bytes):
        start = byte_idx * 8
        stop = min(start + 8, sorb)  # the last byte may be partial
        chunk = bits[start:stop]
        weights = 2 ** np.arange(len(chunk))
        packed[byte_idx] = np.sum(weights * chunk)
    return packed

# System setup: load the H2 integrals and enumerate the FCI determinant space.
chain_len = 2
integral_file = "../integral/rmole-H2-0.734.info"
int2e, int1e, ecore = integral.load(
    integral.two_body(), integral.one_body(), 0.0, integral_file
)
print(ecore)

# Presumably the number of spin orbitals: the FCI space is built on sorb // 2.
sorb = int2e.sorb
nele = 2
alpha_ele = nele // 2
beta_ele = nele // 2
# device = "cuda" if torch.cuda.is_available() else "cpu"
device = "cpu"
space = fock.get_fci_space(int(sorb // 2), alpha_ele, beta_ele)
dim = len(space)

# One- and two-electron integrals as float64 tensors on `device`.
h1e = torch.tensor(int1e.data, dtype=torch.float64).to(device)
h2e = torch.tensor(int2e.data, dtype=torch.float64).to(device)

# Byte-packed occupation of every determinant, one row per determinant.
lst = [string_to_lst(sorb, space[idx].to_string()) for idx in range(dim)]
onstate1 = torch.tensor(lst, dtype=torch.uint8).to(device)
print(onstate1)

In [None]:
def compute_derivs(derivs, Eloc, N_s, p, diag_shift=0.02, verbose=False):
    """Reference NumPy implementation of the stochastic-reconfiguration update.

    Builds the covariance ``S_ij = <O_i^* O_j> - <O_i^*><O_j>`` and the force
    ``F_i = <E_loc O_i^*> - <E_loc><O_i^*>`` from per-sample log-derivatives,
    then solves ``(S + diag_shift * I) x = F``.

    Args:
        derivs: log-derivatives of shape ``(n_params, N_s)``, one column per
            sample.
        Eloc: local energies; ``Eloc.T`` must broadcast against ``derivs``,
            e.g. shape ``(N_s, 1)``.
        N_s: number of samples used in the averages.
        p: iteration index. Currently unused; a p-dependent shift such as
            ``regular(p)`` was replaced by the constant ``diag_shift``.
        diag_shift: constant added to the diagonal of S. The default 0.02 is
            the value that was previously hard-coded.
        verbose: print intermediate shapes (debug output).

    Returns:
        Complex array of shape ``(n_params, 1)`` with the update direction.
    """
    n_params = derivs.shape[0]
    derivs_conj = np.conjugate(derivs)

    avg_derivs = np.sum(derivs, axis=1, keepdims=True) / N_s  # (n_params, 1)
    # Outer product <O_i^*><O_j>.
    avg_outer = np.conjugate(avg_derivs.reshape(n_params, 1)) * avg_derivs.reshape(1, n_params)
    moment2 = np.einsum('ik,jk->ij', derivs_conj, derivs) / N_s
    S_kk = moment2 - avg_outer

    eloc_row = Eloc.transpose()
    # Out-of-place so complex Eloc/derivs never hit an in-place casting error.
    F_p = np.sum(eloc_row * derivs_conj, axis=1) / N_s
    F_p = F_p - np.sum(eloc_row, axis=1) * np.sum(derivs_conj, axis=1) / (N_s ** 2)

    # Complex identity keeps the result complex, as before.
    S_reg = S_kk + diag_shift * np.eye(n_params, dtype=complex)
    # Solving the system is cheaper and numerically more stable than
    # forming inv(S_reg) explicitly.
    update = np.linalg.solve(S_reg, F_p).reshape(n_params, 1)

    if verbose:
        print(f"derivs shape {derivs.shape}")
        print(f"F_p shape {F_p.shape}")
    return update

def regular(p, l0=100, b=0.9, l_min=1e-4):
    """Exponentially decaying regularisation schedule for the S_kk matrix.

    Returns ``l0 * b**p`` but never less than ``l_min`` (schedule taken from
    the supplementary material of the reference method).
    """
    decayed = l0 * b ** p
    # Same tie-breaking as max(decayed, l_min): keep `decayed` unless l_min is larger.
    return l_min if l_min > decayed else decayed

In [None]:
def sr(eloc: Tensor, grad_total: Tensor, N_state: int, p: int,
      debug=False, L2_penalty: Optional[Tensor] = None, opt_gd = False,
      diag_shift: float = 0.02) -> Tensor:
    """Stochastic-reconfiguration update for one parameter tensor (real ansatz).

    With per-sample log-derivatives ``O`` (rows of ``grad_total``) it builds
    ``S = <O O^T> - <O><O>^T`` and ``F = <E_loc O> - <E_loc><O>``. No complex
    conjugation is applied, so the ansatz is assumed to be real.

    Args:
        eloc: local energies; ``eloc.transpose(1, 0)`` must broadcast against
            ``grad_total`` (the notebook passes shape ``(1, N_state)``).
        grad_total: log-derivatives of shape ``(N_state, n_para)``.
        N_state: number of samples used in the averages.
        p: iteration index (currently unused).
        debug: currently unused; kept for call compatibility.
        L2_penalty: optional tensor added to the force, e.g. a weight-decay
            gradient of shape ``(n_para,)``.
        opt_gd: if True, return the force itself (plain gradient descent).
        diag_shift: constant added to the diagonal of S. The default 0.02 is
            the value that was previously hard-coded.

    Returns:
        ``F`` with shape ``(n_para,)`` when ``opt_gd`` is set, otherwise
        ``(S + diag_shift * I)^-1 F`` reshaped to ``(1, n_para)``.
    """
    avg_grad = torch.sum(grad_total, dim=0, keepdim=True) / N_state  # (1, n_para)
    avg_outer = avg_grad.reshape(-1, 1) * avg_grad.reshape(1, -1)
    moment2 = torch.einsum("ki,kj->ij", grad_total, grad_total) / N_state
    S_kk = moment2 - avg_outer

    eloc_col = eloc.transpose(1, 0)
    F_p = torch.sum(eloc_col * grad_total, dim=0) / N_state
    F_p = F_p - torch.sum(eloc_col, dim=0) * torch.sum(grad_total, dim=0) / (N_state ** 2)
    if L2_penalty is not None:
        F_p = F_p + L2_penalty

    if opt_gd:
        return F_p

    S_reg = S_kk + diag_shift * torch.eye(S_kk.shape[0], dtype=S_kk.dtype, device=S_kk.device)
    # Solving is cheaper and more stable than multiplying by inv(S_reg).
    return torch.linalg.solve(S_reg, F_p).reshape(1, -1)

def calculate_sr_grad(params: List[Tensor],
                      grad_save: List[Tensor],
                      eloc: Tensor,
                      N_state: int,
                      p: int,
                      opt_gd = False,
                      lr: float = 0.02,
                      weight_decay: float = 0.001,
                      debug_from: Optional[int] = None):
    """Update every parameter in place with an SR (or plain gradient) step.

    For parameter ``i``, the per-sample log-derivatives ``grad_save[i]`` are
    flattened to ``(N_state, n_para_i)`` and passed to ``sr`` together with a
    weight-decay term. The parameter then moves by ``-lr * update``.

    Args:
        params: parameters to update (e.g. ``model.parameters()``). They are
            materialised into a list, so a generator is fine.
        grad_save: per-parameter log-derivatives, parallel to ``params``. Each
            must be reshapeable to ``(N_state, -1)``.
        eloc: local energies, forwarded unchanged to ``sr`` (the notebook
            passes shape ``(1, N_state)``).
        N_state: number of samples.
        p: iteration counter, forwarded to ``sr`` and compared to ``debug_from``.
        opt_gd: if True, ``sr`` returns the raw force (plain gradient descent).
        lr: step size.
        weight_decay: coefficient ``lam`` of the penalty
            ``(lam / 2) * ||theta||^2``. Its gradient ``lam * theta`` is added
            to the force.
        debug_from: if not None, dump parameter/gradient/update tensors on
            every call with ``p >= debug_from``. This was previously hard-wired
            to 100, which flooded the output of long runs.
    """
    param_group = list(params)
    for i, raw_grad in enumerate(grad_save):
        param = param_group[i]
        shape = param.shape
        # Gradient of (weight_decay / 2) * ||theta||^2. The former
        # 0.001 * theta**2 was not the gradient of any penalty: it is always
        # positive, so it pushed every parameter the same way regardless of sign.
        l2_grad = weight_decay * param.detach().reshape(-1)
        dlnpsi = raw_grad.reshape(N_state, -1)  # (N_state, n_para_i)
        update = sr(eloc, dlnpsi, N_state, p, debug=(i == 1),
                    L2_penalty=l2_grad, opt_gd=opt_gd)
        if debug_from is not None and p >= debug_from:
            print(f"{i}th para")
            print(f"parameter in model\n {param}")
            print(f"dpsi \n {param.grad}")
            print(f"dlnpsi \n{dlnpsi}")
            print(f"update * -lr \n{update.reshape(shape)*(-lr)}")
        # In-place step outside autograd tracking (replaces the `.data` idiom).
        with torch.no_grad():
            param.add_(update.reshape(shape), alpha=-lr)


In [None]:
# --- VMC training loop -------------------------------------------------------
# Seed the RNGs, build the RBM ansatz and record the starting energy. Then run
# 5000 iterations of: MCMC sampling -> per-sample log-derivatives ->
# calculate_sr_grad parameter update -> total energy on the sampled states.
# ecore = 0.00
seed = 42
setup_seed(seed)
e_list = []  # total energy per iteration (index 0: before any update)
model = rRBMWavefunction(sorb, sorb*2, init_weight=0.001).to(device)
print(model)
# True: model(..., dlnPsi=True) returns log-derivatives directly;
# False: per-sample autograd below.
analytic_derivative = True
time_sample = []  # seconds spent in sampling, per iteration
time_iter = []    # seconds per full iteration
# print(model(unit8_to_bit(onstate1, sorb)))

n = 60
debug = True
# NOTE(review): with debug=True the sampler/energy are given the full FCI space
# (debug_exact / exact=debug), presumably evaluating exactly over it -- confirm.
N = onstate1.shape[0] if debug else n
print(onstate1.shape)

with torch.no_grad():
    nbatch = len(onstate1)
    e = total_energy(onstate1, nbatch, h1e, h2e,
                     model, ecore, sorb, nele, exact=debug)
    e_list.append(e.item())
print(f"begin e is {e}")

# NOTE(review): `opt` is only used for opt.zero_grad(); the parameter update is
# done by calculate_sr_grad. SR itself is imported in the first cell.
opt = SR(model.parameters(), lr=0.005, N_state=N)

for p in range(5000):
    # For the first 801 iterations start the chain from a random determinant,
    # afterwards always from the first one.
    if p <= 800:
        initial_state = onstate1[random.randrange(len(onstate1))].clone().detach()
    else:
        initial_state = onstate1[0].clone().detach()

    dln_grad_lst = []
    out_lst = []
    t0 = time.time_ns()
    sample = MCMCSampler(model, initial_state, h1e, h2e, n, sorb, nele,
                         verbose=True, debug_exact=debug, full_space=onstate1)
    state, eloc = sample.run()  # eloc [n_sample]
    n_sample = len(state)
    # TODO: cuda version unit8_to_bit 2D
    sample_state = unit8_to_bit(state, sorb)
    delta = (time.time_ns() - t0)/1.00E09
    time_sample.append(delta)

    if analytic_derivative:
        model.zero_grad()
        # tuple, length: n_para, shape: (n_sample, param.shape)
        grad_sample = model(sample_state, dlnPsi=True)
    else:
        # Per-sample autograd: (d psi / d theta) / psi = d ln(psi) / d theta.
        for i in range(n_sample):
            model.zero_grad()
            psi = model(sample_state[i].requires_grad_())
            out_lst.append(psi.detach().clone())
            psi.backward()
            lst = []
            for para in model.parameters():
                if para.grad is not None:
                    lst.append(para.grad.detach().clone()/psi.detach().clone())
            dln_grad_lst.append(lst)
            del psi, lst

        # Regroup per parameter: list of length n_para, each (n_sample, n_para_i).
        n_para = len(list(model.parameters()))
        grad_comb_lst = []
        for i in range(n_para):
            comb = []
            for j in range(n_sample):
                comb.append(dln_grad_lst[j][i].reshape(1, -1))
            grad_comb_lst.append(torch.cat(comb))  # (n_sample, n_para)

        grad_sample = grad_comb_lst

    # opt_gd=True: plain gradient descent on the energy force (no S-matrix solve).
    calculate_sr_grad(model.parameters(), grad_sample, eloc.reshape(1, -1), n_sample, p+1, lr=0.010, opt_gd=True)
    opt.zero_grad()

    with torch.no_grad():
        nbatch = n_sample
        e = total_energy(state.detach(), nbatch, h1e, h2e,
                         model, ecore, sorb, nele, exact=debug)
        e_list.append(e.item())
    print(f"{p} iteration total energy is {e:.5f} \n")
    time_iter.append((time.time_ns() - t0)/1.00E09)
    del dln_grad_lst, out_lst

In [None]:
import matplotlib.pyplot as plt

# Energy trace of the VMC run (index 0: energy before the first update).
energies = np.array(e_list)
fig, ax = plt.subplots()
ax.plot(np.arange(len(energies)), energies)
ax.set(xlabel="iteration", ylabel="total energy", title="VMC total energy, H2")
# Last (up to) 10 energies as a convergence check. The former e[-10] showed a
# single value and raised IndexError for runs shorter than 10 iterations.
print(energies[-10:])
# Save BEFORE show(): with the inline backend show() closes the figure, so a
# later plt.savefig writes a new, empty figure.
# NOTE(review): file name says 0.735 while the integral file is 0.734 -- confirm.
fig.savefig(r"H2-0.735.png", dpi=1000)
plt.show()
plt.close(fig)