In [1]:
from data_utils import *

from torch.utils.data import DataLoader

from data_utils import *
from scf import *

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the single demo molecule; the printed record exposes its atomic
# positions (`pos`) and atomic numbers (`z`).
mol_data = get_demo_geometry()
print(str(mol_data))

Data(pos=tensor([[-1.2700e-02,  1.0858e+00,  8.0000e-03],
        [ 2.2000e-03, -6.0000e-03,  2.0000e-03],
        [ 1.0117e+00,  1.4638e+00,  3.0000e-04],
        [-5.4080e-01,  1.4475e+00, -8.7660e-01],
        [-5.2380e-01,  1.4379e+00,  9.0640e-01]]), z=tensor([6, 1, 1, 1, 1]))


In [3]:
# Build noisy copies of the demo geometry and precompute the RHF integrals
# (plus the pyscf molecule objects used later as a reference) for each copy.
# NOTE(review): presumably 0.1 is the displacement scale and 100 the number of
# copies -- confirm against noised_batch's signature.
NOISE_SCALE = 0.1
N_SAMPLES = 100
BATCH_SIZE = 10
SEED = 0

# Seed before drawing the noise so a fresh Restart & Run All reproduces the
# same geometries (and therefore the same energies downstream).
# NOTE(review): assumes noised_batch draws from torch's RNG; if it uses
# numpy/random instead, seed those here as well.
torch.manual_seed(SEED)
nb = noised_batch(mol_data, NOISE_SCALE, N_SAMPLES)
basis = '6-31G(d)'
ba_HF, mols = HF_transform_batch(nb, basis, need_mols=True)

# DataLoader's shuffle defaults to False, so batches come out in the same order
# as `mols`; the element-wise comparison against pyscf relies on that.
loader = DataLoader(ba_HF, batch_size=BATCH_SIZE)

In [4]:
def converge_func(i, dP_norm, dE, e_tol=1e-10, p_tol=None, verbose=True):
    """Convergence callback passed to ``solve_rhf``.

    Called with an iteration counter and the changes from the previous
    iteration; the whole batch counts as converged only when the largest
    change meets the tolerance.

    Args:
        i: iteration counter (used only for the progress line).
        dP_norm: tensor of density-matrix change norms (presumably one entry
            per molecule in the batch -- TODO confirm against solve_rhf).
        dE: tensor of energy changes. Its absolute value is used, so a large
            signed decrease cannot be mistaken for convergence.
        e_tol: energy tolerance; the default matches the pyscf ``conv_tol``
            used for the reference energies.
        p_tol: optional tolerance on ``dP_norm``; ``None`` (default) skips the
            density check, i.e. converge on energy alone.
        verbose: print a progress line every iteration and a final
            'converged' message.

    Returns:
        True once converged, otherwise False.
    """
    dP_max = dP_norm.max().item()
    # abs(): with a signed dE, .max() of e.g. [-5.0, 1e-12] would be 1e-12 and
    # wrongly signal convergence.
    dE_max = dE.abs().max().item()
    if verbose:
        print(f'iter: {i}, max of dP: {dP_max}, max of dE: {dE_max}')
    converged = dE_max < e_tol and (p_tol is None or dP_max < p_tol)
    if converged and verbose:
        print('converged')
    return converged

In [5]:
# Run the batched RHF solver over every mini-batch and stack the per-batch
# results; `res` is compared element-wise against pyscf energies later on.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

res_list = []
for HFData in loader:
    # solve_rhf takes a single electron count for the whole batch, so every
    # molecule in the batch must share it; fail loudly instead of silently
    # solving some molecules with the wrong occupation.
    n_ele = int(HFData.n_ele.max())
    assert (HFData.n_ele == n_ele).all(), 'solve_rhf needs a uniform electron count per batch'
    # Passing a plain int (not a CPU tensor) avoids the tensor floor-division
    # warning inside solve_rhf and any CPU/GPU device mismatch.
    e_nuc = HFData.e_nuc.to(device)
    h_core = HFData.h_core.to(device)
    overlap = HFData.overlap.to(device)
    ele_repul = HFData.ele_repul.to(device)
    batch_res = solve_rhf(n_ele, e_nuc, h_core, overlap, ele_repul, converge_func)
    res_list.append(batch_res)

res = torch.cat(res_list).cpu()

  n_occ = n_ele // 2


iter: 1, max of dP: 7.816349401075277, max of dE: 18.55626369254619
iter: 2, max of dP: 2.137213958330434, max of dE: 11.2083529350069
iter: 3, max of dP: 1.1655613539144487, max of dE: 6.273600479105234
iter: 4, max of dP: 0.4816218582193564, max of dE: 2.983963918856851
iter: 5, max of dP: 0.22922805727148657, max of dE: 1.3910217795353361
iter: 6, max of dP: 0.09915662472444114, max of dE: 0.631454422213146
iter: 7, max of dP: 0.04544689308941352, max of dE: 0.28413945628017245
iter: 8, max of dP: 0.020192257364633016, max of dE: 0.1281201462539343
iter: 9, max of dP: 0.00912858053683705, max of dE: 0.057555153367303546
iter: 10, max of dP: 0.004093983395370822, max of dE: 0.02590641048929143
iter: 11, max of dP: 0.0018426370302200414, max of dE: 0.011645035783026003
iter: 12, max of dP: 0.0008285420898535579, max of dE: 0.0052384581041309275
iter: 13, max of dP: 0.00037248161676099837, max of dE: 0.002355449626072925
iter: 14, max of dP: 0.0001675861434001708, max of dE: 0.00105936

In [6]:
def check(mols, max_cycle=200, conv_tol=1e-10):
    """Compute reference RHF energies with pyscf, one per molecule.

    Args:
        mols: iterable of pyscf molecule objects (here the ``mols`` returned by
            ``HF_transform_batch(..., need_mols=True)``).
        max_cycle: maximum number of SCF iterations pyscf may run.
        conv_tol: pyscf SCF energy convergence tolerance.

    Returns:
        List of the energies returned by ``mf.kernel()``, in input order.

    Warns:
        UserWarning for any molecule whose pyscf SCF did not converge.
    """
    import warnings

    pyscf_energy = []
    for idx, mol in enumerate(mols):
        mf = pyscf.scf.RHF(mol)
        mf.max_cycle = max_cycle
        mf.conv_tol = conv_tol
        energy = mf.kernel()
        # kernel() returns the last energy even when SCF fails to converge;
        # flag it so an unconverged value is not used silently as a reference.
        if not mf.converged:
            warnings.warn(f'pyscf RHF did not converge for molecule {idx} within {max_cycle} cycles')
        pyscf_energy.append(energy)

    return pyscf_energy


ck = check(mols)
# Keep float64: torch.tensor on a list of floats defaults to float32, which
# rounds ~-40 Ha energies to a ~4e-6 Ha grid and swamps the comparison
# against the 1e-10-converged solver results.
ck = torch.tensor(ck, dtype=torch.float64)

converged SCF energy = -40.1806954827713
converged SCF energy = -40.1837296470681
converged SCF energy = -40.0790603952915
converged SCF energy = -40.1236462224083
converged SCF energy = -40.1681142733156
converged SCF energy = -40.1057065765928
converged SCF energy = -40.1582076670034
converged SCF energy = -40.1547718427184
converged SCF energy = -40.178126311586
converged SCF energy = -40.1501718477897
converged SCF energy = -40.1436922628804
converged SCF energy = -40.1607107862783
converged SCF energy = -40.1426167865458
converged SCF energy = -40.1569732209995
converged SCF energy = -40.1597056293018
converged SCF energy = -39.8870849850997
converged SCF energy = -40.1604008054682
converged SCF energy = -40.153815851601
converged SCF energy = -40.1439187231269
converged SCF energy = -40.1701416449312
converged SCF energy = -40.1499143927174
converged SCF energy = -40.1815208630641
converged SCF energy = -40.0681299969199
converged SCF energy = -40.1447300930589
converged SCF ener

In [7]:
# Side-by-side report: pyscf reference energies, the batched solver's
# results, then the largest absolute deviation between the two.
for heading, values in (('Solutions by pyscf:', ck), ('Solutions:', res)):
    print(heading)
    print(values)
print('max-error:')
abs_err = (ck - res).abs()
print(abs_err.max())

Solutions by pyscf:
tensor([-40.1807, -40.1837, -40.0791, -40.1236, -40.1681, -40.1057, -40.1582,
        -40.1548, -40.1781, -40.1502, -40.1437, -40.1607, -40.1426, -40.1570,
        -40.1597, -39.8871, -40.1604, -40.1538, -40.1439, -40.1701, -40.1499,
        -40.1815, -40.0681, -40.1447, -40.1579, -40.1206, -40.0763, -40.0760,
        -40.1610, -40.1407, -40.1605, -40.0348, -40.1402, -40.1710, -40.1172,
        -40.1742, -40.1496, -40.1411, -40.1065, -40.1335, -40.1589, -40.1608,
        -40.1572, -40.1151, -40.1361, -40.1042, -40.0528, -40.1212, -40.0857,
        -40.1274, -40.0422, -40.0633, -40.1046, -40.1797, -40.0779, -40.1659,
        -40.1444, -40.1552, -40.0285, -40.1542, -40.1136, -40.1429, -40.1374,
        -40.1077, -40.1137, -40.1422, -40.1439, -40.0377, -39.8698, -40.1801,
        -40.1580, -40.1374, -40.0811, -39.9366, -40.1219, -40.1111, -40.1125,
        -40.1701, -40.1260, -40.1737, -40.1677, -40.1281, -40.0609, -40.1441,
        -40.1258, -40.0856, -40.1703, -40.17