In [None]:
import pyscf
from pyscf import lo
import Driver_SCF
import get_atom_orb
import Util_Mole
from functools import reduce
import numpy
import Chem_Bond_Analysis
from Chem_Bond_Analysis import generate_atm_bas_given_label,occ_label,vir_label
import seaborn, pandas
import matplotlib.pyplot as plt
import Util_Math
import torch
import Util_Pic

This notebook demonstrates how to resolve the local gauge problem that arises when expanding the density matrix (DM) over each atom's Hartree–Fock (HF) orbitals.
Note that if you use a 3D-equivariant network, this notebook does not help you.

In [None]:
# Basis set used for every SCF and atomic-orbital calculation below.
# Other bases tried: "6-31G(d)", "sto-3g".
basis = "ccpvdz"

# Atomic-orbital basis from get_atom_orb.atom_min_cas_bas for C, H, O and N;
# it is attached to every ChemBondAnalyzer created later in the notebook.
atom_bas = get_atom_orb.atom_min_cas_bas(
    ["C", "H", "O", "N"],
    basis=basis,
    print_verbose=0,
)

In [None]:
# Candidate geometries, one atom per line: element symbol followed by x, y, z.
# C3H6O test geometry, kept for reference only (not used below).
mol_xyz_c3h6o = '''
C 0.1687934049 1.5251304224 -0.1574705569
C -0.1873762459 0.0619476271 0.1467937419
C 0.5091764497 -0.4399621499 1.3912584954
O 1.1819107901 -1.4356558471 1.4581638410
H 1.2312651068 1.6313725408 -0.3963269270
H -0.4074466801 1.8943168107 -1.0096924649
H -0.0493103389 2.1726626043 0.6985743244
H -1.2690515996 -0.0166806666 0.3285443317
H 0.0627906152 -0.6025047855 -0.6847403572
H 0.3538484078 0.2066337038 2.2887105216
'''

# Small water-like geometry (90-degree H-O-H) -- this is the one analysed.
mol_xyz_water = '''
O 0 0 0
H 1 0 0 
H 0 1 0
'''

# Alternative H2 test case:
# mol_xyz = '''
# H 0.7 0 0 
# H -0.7 1 0
# '''

mol_xyz = mol_xyz_water

In [None]:
# Stand-alone orbital analysis run directly from the XYZ string, the basis name
# and the precomputed atomic basis; whatever it returns is shown as cell output.
Chem_Bond_Analysis.analysis_mole_occ_orb(mol_xyz, atom_bas, basis)

In [None]:
# Analyzer for the original (unrotated) molecule: attach the shared atomic
# basis, then run its SCF.
chem_bond_analyzer = Chem_Bond_Analysis.ChemBondAnalyzer(
    xyz=mol_xyz,
    basis=basis,
    print_verbose=0,
)
chem_bond_analyzer.atom_bas = atom_bas
chem_bond_analyzer._run_scf()

In [None]:
# Bonding graph of the original molecule (input to the gauge-fixing routine).
mole_graph = chem_bond_analyzer.get_mole_graph()
print(mole_graph)

# Verbose orbital analysis of the original molecule.
chem_bond_analyzer.analysis_mole_occ_orb(print_verbose=12)

In [None]:
# Coordinates of the analysed molecule in Util_Mole's list format.
analyzed_mol = chem_bond_analyzer.mol
xyz_list = Util_Mole.get_mol_xyz_list_format(analyzed_mol)

In [None]:
# Atom-centred basis for the original molecule, with the local gauge fixed
# from its bonding graph. This `bas` is the one the density matrix is expanded
# in below. The same routine is applied to the rotated molecule, so both
# expansions are directly comparable.
bas, occ = Util_Mole.get_atm_bas_in_mole_fix_local_gauge_problem(chem_bond_analyzer.mol, mole_graph)

# Basis from Chem_Bond_Analysis.generate_atom_basis, kept for inspection only.
# It is stored under separate names so it does not replace the gauge-fixed
# `bas`. The rotated molecule is expanded in the gauge-fixed basis, so the
# original one must be as well for the final comparison to be meaningful.
_, bas_generated, loc_occ, bas_occ, _, _ = Chem_Bond_Analysis.generate_atom_basis(chem_bond_analyzer.mol, atom_bas)

print(bas_generated)

In [None]:
# Short aliases for the original molecule and its SCF object.
mol, scf_mol = chem_bond_analyzer.mol, chem_bond_analyzer.rohf

In [None]:
# make_rdm1 yields two spin components here; their sum is the total
# (spin-summed) one-particle density matrix of the original molecule.
alpha_dm, beta_dm = scf_mol.make_rdm1(scf_mol.mo_coeff, scf_mol.mo_occ)
dm1 = alpha_dm + beta_dm
print(dm1)

In [None]:
# Expand the density matrix in the atom-centred basis B (the columns of `bas`):
# if dm1 = B @ D @ B.T, then D = B^{-1} @ dm1 @ B^{-T}.
# Plain ndarrays are used instead of the deprecated numpy.matrix, and the
# inverse is computed once. numpy.matrix.I fell back to the pseudo-inverse for
# non-square input; that behaviour is reproduced here.
bas_arr = numpy.asarray(bas)
if bas_arr.shape[0] == bas_arr.shape[1]:
    bas_inv = numpy.linalg.inv(bas_arr)
else:
    bas_inv = numpy.linalg.pinv(bas_arr)
dm1_atm_bas = bas_inv @ dm1 @ bas_inv.T

print(dm1_atm_bas)

Util_Pic.draw_heatmap(dm1_atm_bas, None, None,vmax=2, vmin=-0.5)


In [None]:
# Values of dm1 at the entries where it and its atom-basis expansion
# differ by more than 1e-8.
ao_vs_atm_mask = numpy.abs(dm1 - dm1_atm_bas) > 1e-8
dm1[ao_vs_atm_mask]

In [None]:
# Values of the atom-basis expansion at the entries where it and dm1
# differ by more than 1e-8.
atm_vs_ao_mask = numpy.greater(numpy.abs(dm1 - dm1_atm_bas), 1e-8)
dm1_atm_bas[atm_vs_ao_mask]

In [None]:
# Rigidly rotated copy of the molecule, used to test gauge invariance.
# NOTE(review): presumably a rotation centre followed by three rotation angles
# (radians) -- confirm against Util_Mole.get_rotated_mol_coord.
rotation_center = [0.0, 0.0, 0.0]
rotation_angles = (0.23, numpy.pi / 2.5, 1.34)
mol_rotated = Util_Mole.get_rotated_mol_coord(mol, rotation_center, *rotation_angles)

In [None]:
# Inspect the rotated geometry; it is passed as `xyz=` to a new analyzer in the
# next cell, so presumably it is an XYZ-format string -- confirm.
print(mol_rotated)

In [None]:
# Analyzer for the rotated geometry, with the same basis settings and shared
# atomic basis as the original one; then run its SCF.
chem_bond_analyzer_new = Chem_Bond_Analysis.ChemBondAnalyzer(
    basis=basis,
    print_verbose=0,
    xyz=mol_rotated,
)
chem_bond_analyzer_new.atom_bas = atom_bas
chem_bond_analyzer_new._run_scf()

In [None]:
# Total energies (rotated first, then original). A rigid rotation is expected
# to leave them equal up to SCF convergence.
for analyzer in (chem_bond_analyzer_new, chem_bond_analyzer):
    print(analyzer.e_tot)

# Objects for the rotated molecule, mirroring those of the original molecule.
mol_new = chem_bond_analyzer_new.mol
scf_mol_new = chem_bond_analyzer_new.rohf
mole_graph_new = chem_bond_analyzer_new.get_mole_graph()
bas_new, occ_new = Util_Mole.get_atm_bas_in_mole_fix_local_gauge_problem(
    chem_bond_analyzer_new.mol,
    mole_graph_new,
)

In [None]:
# Spin-summed density matrix of the rotated molecule, expanded in its own
# gauge-fixed atom basis B_new: D_new = B_new^{-1} @ dm1_new @ B_new^{-T}.
# Plain ndarrays are used instead of the deprecated numpy.matrix, and the
# inverse is computed once. The pseudo-inverse is used for a non-square basis,
# matching what numpy.matrix.I did.
dma, dmb = scf_mol_new.make_rdm1(scf_mol_new.mo_coeff, scf_mol_new.mo_occ)
dm1_new = dma + dmb
bas_new_arr = numpy.asarray(bas_new)
if bas_new_arr.shape[0] == bas_new_arr.shape[1]:
    bas_new_inv = numpy.linalg.inv(bas_new_arr)
else:
    bas_new_inv = numpy.linalg.pinv(bas_new_arr)
dm1_atm_bas_new = bas_new_inv @ dm1_new @ bas_new_inv.T

In [None]:
# Rotated-molecule values at the entries where the two atom-basis DMs
# differ by more than 1e-8.
rot_mismatch = numpy.abs(dm1_atm_bas_new - dm1_atm_bas) > 1e-8
dm1_atm_bas_new[rot_mismatch]

In [None]:
# Original-molecule values at the entries where the two atom-basis DMs
# differ by more than 1e-8.
orig_mismatch = numpy.greater(numpy.abs(dm1_atm_bas_new - dm1_atm_bas), 1e-8)
dm1_atm_bas[orig_mismatch]

In [None]:
# Rotated-molecule values at the entries where the two atom-basis DMs
# agree to within 1e-8.
rot_match = numpy.abs(dm1_atm_bas_new - dm1_atm_bas) < 1e-8
dm1_atm_bas_new[rot_match]

In [None]:
# Original-molecule values at the entries where the two atom-basis DMs
# agree to within 1e-8.
orig_match = numpy.less(numpy.abs(dm1_atm_bas_new - dm1_atm_bas), 1e-8)
dm1_atm_bas[orig_match]

# The remaining mismatches are attributed to a phase ambiguity of the basis
# functions, which may be fixable using the overlap matrix S.

In [None]:
# Number of entries where the rotated and original atom-basis DMs differ by more than 1e-8.
(numpy.abs(dm1_atm_bas_new - dm1_atm_bas) > 1e-8).sum()