In [None]:
import pyscf
from pyscf import lo
import Driver_SCF
import get_atom_orb
import Util_Mole
from functools import reduce
import numpy
import Chem_Bond_Analysis
from Chem_Bond_Analysis import generate_atm_bas_given_label,occ_label,vir_label
import seaborn, pandas
import matplotlib.pyplot as plt
import Util_Math
import torch
import Util_Pic

In [None]:
# Basis set for the full molecule (other choices tried here: "sto-3g", "ccpvtz").
basis = "6-31G(d)"
# Per-element atomic bases for C, H, O and N from get_atom_orb.atom_min_cas_bas
# (the name suggests a minimal-CAS atomic basis -- confirm).
atom_bas = get_atom_orb.atom_min_cas_bas(
    ["C", "H", "O", "N"],
    basis=basis,
    print_verbose=0,
)

In [None]:
# Input geometry: one "<element> x y z" line per atom (presumably in Angstrom,
# the pyscf default -- confirm). 3 C, 1 O and 6 H; the geometry looks like
# propanal (CH3-CH2-CHO) -- verify.
mol_xyz = '''
C 0.1687934049 1.5251304224 -0.1574705569
C -0.1873762459 0.0619476271 0.1467937419
C 0.5091764497 -0.4399621499 1.3912584954
O 1.1819107901 -1.4356558471 1.4581638410
H 1.2312651068 1.6313725408 -0.3963269270
H -0.4074466801 1.8943168107 -1.0096924649
H -0.0493103389 2.1726626043 0.6985743244
H -1.2690515996 -0.0166806666 0.3285443317
H 0.0627906152 -0.6025047855 -0.6847403572
H 0.3538484078 0.2066337038 2.2887105216
'''

In [None]:
chem_bond_analyzer = Chem_Bond_Analysis.ChemBondAnalyzer(
    xyz=mol_xyz, print_verbose=0,basis=basis)
chem_bond_analyzer.atom_bas = atom_bas
chem_bond_analyzer._run_scf()

In [None]:
mole_graph = chem_bond_analyzer.get_mole_graph()

In [None]:
xyz_list = Util_Mole.get_mol_xyz_list_format(chem_bond_analyzer.mol)

In [None]:
def get_atm_bas_in_mole_without_local_gauge_problem(mol, mole_graph):
    """Build a block-diagonal, per-atom rotated AO basis for ``mol``.

    For every atom A, a small fragment is assembled from three parts:

    * A itself, carrying the full ``mol.basis``;
    * the atoms bonded to A, with STO-3G;
    * an ``X`` centre at the geometric centre of ``mol``, with an STO-3G H basis.

    An ROHF calculation on this fragment gives a spin-summed density matrix.
    Its diagonal block on A's own AOs is analysed shell by shell:

    * s functions are left unchanged; their occupation is the diagonal
      density-matrix element.
    * For each contracted function of an l > 0 shell, the component block is
      diagonalised with ``numpy.linalg.eigh``. The eigenvectors become the
      rotated functions and the eigenvalues (ascending) their occupations.

    Args:
        mol: pyscf ``Mole`` of the full molecule. ``mol.basis`` is passed to
            ``pyscf.gto.basis.load``, so it presumably has to be a basis-name
            string rather than a per-element dict -- TODO confirm.
        mole_graph: adjacency indexable as ``mole_graph[i, j]``. A value
            ``> 0`` marks atom ``j`` as bonded to atom ``i``.

    Returns:
        Tuple ``(res, occ)``:

        * ``res`` is an ``(mol.nao, mol.nao)`` array. Its diagonal blocks, one
          per atom and placed at that atom's AO range in ``mol``, hold the
          rotation matrices.
        * ``occ`` is a length-``mol.nao`` array of the corresponding
          occupations.

    Raises:
        ValueError: if an atom's AO count in ``mol`` differs from the AO count
            of the isolated atom built with the same basis.

    Notes:
        * Eigenvector signs, and the orientation inside degenerate
          eigenspaces, are not fixed here. The author intended the phase to be
          resolved via the overlap matrix; nothing in this function does that.
        * One ROHF run per atom may be slow; this is a feasibility study. TODO:
          consider expanding over atomic HF orbitals instead.
    """
    import warnings

    basis = mol.basis
    # Coordinates come from ``mol`` itself. Using a notebook-global geometry
    # list would silently use the wrong geometry for any other molecule,
    # e.g. a rotated copy.
    xyz_list = Util_Mole.get_mol_xyz_list_format(mol)
    mole_geometric_center = Util_Mole.get_mol_geometric_center(mol)
    # Row k: (shell start, shell stop, AO start, AO stop) of atom k in ``mol``.
    ao_slices = mol.aoslice_by_atom()

    res = numpy.zeros((mol.nao, mol.nao))
    occ = numpy.zeros(mol.nao)

    for id_atm in range(mol.natm):
        # Atoms bonded to ``id_atm`` according to the molecular graph.
        bonded = [
            other
            for other in range(mol.natm)
            if other != id_atm and mole_graph[id_atm, other] > 0
        ]

        # Local fragment. The central atom is relabelled "<symbol>1" so that it
        # alone receives the full basis. Its neighbours use STO-3G, and an ``X``
        # centre sits at the geometric centre of the whole molecule.
        symbol = xyz_list[id_atm][0]
        center_label = symbol + '1'
        xyz_partial = [[center_label, xyz_list[id_atm][1]]]
        xyz_partial.extend(xyz_list[other] for other in bonded)
        xyz_partial.append(['X', mole_geometric_center])

        basis_list = {
            center_label: pyscf.gto.basis.load(basis, symbol),
            'C': 'sto-3g',
            'H': 'sto-3g',
            'O': 'sto-3g',
            'N': 'sto-3g',
            'X': pyscf.gto.basis.load('sto-3g', 'H'),
        }

        mol_partial = Util_Mole.get_mol(
            xyz_partial, spin=None, basis=basis_list)

        # Spin-summed ROHF density matrix of the fragment.
        mf = pyscf.scf.ROHF(mol_partial)
        mf.kernel()
        if not mf.converged:
            warnings.warn(
                "ROHF on the local fragment of atom %d did not converge; "
                "its rotated basis may be unreliable." % id_atm
            )
        dma, dmb = mf.make_rdm1(mf.mo_coeff, mf.mo_occ)
        dm1 = dma + dmb

        # The isolated central atom (same basis) determines how many leading
        # fragment AOs belong to it, and how they are grouped into shells.
        atom = Util_Mole.get_mol([xyz_partial[0]], spin=None, basis=basis_list)
        nao_atm = atom.nao
        dm_atm = dm1[:nao_atm, :nao_atm]

        atom_orb_rotated = numpy.zeros((nao_atm, nao_atm))
        atom_occ = numpy.zeros(nao_atm)

        # ao_loc_nr() gives each shell's AO offsets and honours cart/spherical.
        # Within a shell, the nctr contracted functions are stored one after
        # another, each spanning all angular components.
        ao_loc = atom.ao_loc_nr()
        for ish in range(atom.nbas):
            shell_start, shell_stop = int(ao_loc[ish]), int(ao_loc[ish + 1])
            width = (shell_stop - shell_start) // atom.bas_nctr(ish)
            for start in range(shell_start, shell_stop, width):
                stop = start + width
                if atom.bas_angular(ish) == 0:
                    # s function: kept as is; occupation = diagonal element.
                    atom_orb_rotated[start, start] = 1.0
                    atom_occ[start] = dm_atm[start, start]
                else:
                    # eigh returns eigenvalues in ascending order.
                    e, m = numpy.linalg.eigh(dm_atm[start:stop, start:stop])
                    atom_orb_rotated[start:stop, start:stop] = m
                    atom_occ[start:stop] = e

        # Place the atomic block at this atom's own AO range in ``mol``.
        p0, p1 = int(ao_slices[id_atm, 2]), int(ao_slices[id_atm, 3])
        if p1 - p0 != nao_atm:
            raise ValueError(
                "atom %d has %d AOs in mol but %d in the isolated-atom basis"
                % (id_atm, p1 - p0, nao_atm)
            )
        res[p0:p1, p0:p1] = atom_orb_rotated
        occ[p0:p1] = atom_occ

    return res, occ


In [None]:
bas,occ = get_atm_bas_in_mole_without_local_gauge_problem(chem_bond_analyzer.mol,mole_graph)

In [None]:
mol = chem_bond_analyzer.mol
scf_mol = chem_bond_analyzer.rohf

In [None]:
dma,dmb = scf_mol.make_rdm1(scf_mol.mo_coeff,scf_mol.mo_occ)
dm1 = dma+dmb
print(dm1)

In [None]:
bas = numpy.matrix(bas)
print(bas)
dm1_atm_bas = reduce(numpy.dot,(bas.I,dm1,bas.I.T))

In [None]:
dm1[numpy.abs(dm1 - dm1_atm_bas) > 1e-8]

In [None]:
dm1_atm_bas[numpy.abs(dm1 - dm1_atm_bas) > 1e-8]

In [None]:
mol_rotated = Util_Mole.get_rotated_mol_coord(mol,[0.0,0.0,0.0],0.0,numpy.pi/2,0.0)

In [None]:
chem_bond_analyzer_new = Chem_Bond_Analysis.ChemBondAnalyzer(
    xyz=mol_rotated, print_verbose=0,basis=basis)
chem_bond_analyzer_new.atom_bas = atom_bas
chem_bond_analyzer_new._run_scf()

In [None]:
print(chem_bond_analyzer_new.e_tot)
mol_new = chem_bond_analyzer_new.mol
scf_mol_new = chem_bond_analyzer_new.rohf
mole_graph_new = chem_bond_analyzer_new.get_mole_graph()
bas_new,occ_new = get_atm_bas_in_mole_without_local_gauge_problem(chem_bond_analyzer_new.mol,mole_graph_new)

In [None]:
dma,dmb = scf_mol_new.make_rdm1(scf_mol_new.mo_coeff,scf_mol_new.mo_occ)
dm1_new = dma+dmb
# print(dm1_new)
bas_new = numpy.matrix(bas_new)
# print(bas_new)
dm1_atm_bas_new = reduce(numpy.dot,(bas_new.I,dm1_new,bas_new.I.T))

In [None]:
dm1_atm_bas_new[numpy.abs(dm1_atm_bas_new - dm1_atm_bas)>1e-8]

In [None]:
dm1_atm_bas[numpy.abs(dm1_atm_bas_new - dm1_atm_bas)>1e-8]

In [None]:
numpy.sum(numpy.abs(dm1_atm_bas_new - dm1_atm_bas) > 1e-8)