In [None]:
import pyscf
from pyscf import lo
import Driver_SCF
import get_atom_orb
from functools import reduce
import numpy

import seaborn,pandas
import matplotlib.pyplot as plt

In [None]:
def draw_heatmap(mat, column, indx):
    """Show an annotated heat map of a 2-D table of values.

    Parameters
    ----------
    mat : 2-D numpy array
        Values to plot. They are rounded to two decimals before being
        written into the cells. The colour scale is pinned to [0, 100].
    column : sequence
        Labels for the columns of ``mat`` (x axis, labelled "orbital").
    indx : sequence
        Labels for the rows of ``mat`` (y axis, labelled "atom").

    Returns
    -------
    None
        The figure is displayed with ``plt.show()``.
    """
    # One inch per cell keeps the annotations legible for any table size.
    fig, ax = plt.subplots(figsize=(mat.shape[1], mat.shape[0]))
    table = pandas.DataFrame(numpy.round(mat, 2), columns=column, index=indx)
    # Pass ax explicitly so the heat map is drawn on the very axes that are
    # labelled below instead of relying on matplotlib's "current axes" state.
    seaborn.heatmap(
        table,
        annot=True,
        vmax=100,
        vmin=0,
        xticklabels=True,
        yticklabels=True,
        square=True,
        cmap="Blues",
        ax=ax,
    )
    ax.set_ylabel("atom", fontsize=18)
    ax.set_xlabel("orbital", fontsize=18)
    plt.show()
                

In [None]:
def generate_atom_basis(mol,atom_min_cas_bas):
    """Assemble molecule-wide block matrices from per-element atomic matrices.

    Parameters
    ----------
    mol : object
        Must provide ``natm``, ``nao`` and ``atom_pure_symbol(i)``.
    atom_min_cas_bas : mapping
        Element symbol -> 2-D array. The number of rows of each array is used
        as the number of AO rows that atom occupies in the molecule.

    Returns
    -------
    tuple
        ``(loc, res, loc2, res2, loc3, res3)`` where

        * ``loc``  – row offsets per atom (length ``natm + 1``),
        * ``res``  – ``(nao, nao)`` block-diagonal matrix of the full
          per-atom arrays,
        * ``loc2`` / ``res2`` – column offsets and ``(nao, nocc)`` matrix
          holding the first 5 columns of each non-H atom and the first
          column of each H atom,
        * ``loc3`` / ``res3`` – column offsets and ``(nao, nval)`` matrix
          holding columns 1..4 of each non-H atom and the first column of
          each H atom.
    """
    def _running_offsets(widths):
        # [0, w0, w0+w1, ...]
        offsets = [0]
        for width in widths:
            offsets.append(offsets[-1] + width)
        return offsets

    symbols = [mol.atom_pure_symbol(ia) for ia in range(mol.natm)]
    blocks = [atom_min_cas_bas[sym] for sym in symbols]

    # Hydrogen contributes a single column; every other element is treated
    # as having 5 leading columns, of which the last 4 are kept for res3.
    occ_widths = [1 if sym == "H" else 5 for sym in symbols]
    val_widths = [1 if sym == "H" else 4 for sym in symbols]

    row_offsets = _running_offsets([blk.shape[0] for blk in blocks])
    occ_offsets = _running_offsets(occ_widths)
    val_offsets = _running_offsets(val_widths)

    full_mat = numpy.zeros((mol.nao, mol.nao))
    occ_mat = numpy.zeros((mol.nao, occ_offsets[-1]))
    val_mat = numpy.zeros((mol.nao, val_offsets[-1]))

    for ia, blk in enumerate(blocks):
        rows = slice(row_offsets[ia], row_offsets[ia + 1])
        n_occ = occ_widths[ia]
        n_val = val_widths[ia]

        full_mat[rows, rows] = blk
        occ_mat[rows, occ_offsets[ia]:occ_offsets[ia + 1]] = blk[:, :n_occ]
        # non-H: columns 1..4 ; H: column 0
        val_mat[rows, val_offsets[ia]:val_offsets[ia + 1]] = blk[:, n_occ - n_val:n_occ]

    return row_offsets, full_mat, occ_offsets, occ_mat, val_offsets, val_mat


In [None]:
def _atom_percentages(sq_ovlp, offsets, natm):
    """Sum squared overlaps over each atom's rows and scale them to percent."""
    table = numpy.zeros((natm, sq_ovlp.shape[1]))
    for ia in range(natm):
        table[ia, :] = numpy.sum(sq_ovlp[offsets[ia]:offsets[ia + 1], :], axis=0) * 100
    return table


def _print_atom_table(symbols, table, latex=False):
    """Print one row per atom, as plain text or as LaTeX table rows."""
    for sym, row in zip(symbols, table):
        if latex:
            cells = "".join("%8.2f &" % val for val in row)
            # Emits a LaTeX row end (two backslashes) followed by \midrule.
            print(("%2s &" % sym) + cells + "\\\\\\midrule")
        else:
            print(("%2s " % sym) + "".join("%8.2f " % val for val in row))


def analysis_mole_occ_orb(xyz, basis="6-31G(d)",verbose=0,latex=False):
    """Print and plot the per-atom composition of a molecule's occupied orbitals.

    An ROHF calculation (spin 0, charge 0) is run for ``xyz``. The occupied
    molecular orbitals are projected onto the per-atom columns returned by
    ``generate_atom_basis`` (second matrix, ``res2``), the projections are
    squared, and the squares are summed per atom and reported in percent.
    The same analysis is then repeated for Boys-localized occupied orbitals,
    followed by a per-atomic-orbital breakdown for the localized set.

    Parameters
    ----------
    xyz : str
        Geometry passed to ``pyscf.gto.M`` as ``atom``.
    basis : str
        Basis set name, used both for the molecule and for
        ``get_atom_orb.atom_min_cas_bas``.
    verbose : int
        PySCF verbosity level.
    latex : bool
        When true, each table is additionally printed as LaTeX rows
        terminated by a row break and a midrule.

    Notes
    -----
    The atomic reference orbitals are requested for C, H, O and N only, and
    the per-orbital breakdown labels assume 1s/2s/2p×3 for non-H atoms, so
    other elements are not supported.
    NOTE(review): ``get_atom_orb.atom_min_cas_bas`` is defined outside this
    notebook; it is presumed to return an element -> matrix mapping as
    expected by ``generate_atom_basis`` — confirm.
    """
    # gto.M returns an already-built molecule; no extra mol.build() is needed.
    mol = pyscf.gto.M(
            verbose=verbose,
            atom=xyz,
            basis=basis,
            spin=0,
            charge=0,
        )

    rohf = pyscf.scf.ROHF(mol)
    rohf.kernel()

    atom_bas = get_atom_orb.atom_min_cas_bas(["C","H","O","N"], basis=basis)

    # Only the "occupied" per-atom columns and their offsets are used here.
    _, _, loc_occ, bas_occ, _, _ = generate_atom_basis(mol, atom_bas)

    ovlp = mol.intor("int1e_ovlp")

    nocc = numpy.sum(rohf.mo_occ > 0)
    mole_orb_occ = rohf.mo_coeff[:,:nocc]

    indx = [str(i) for i in range(nocc)]
    atom = [mol.atom_pure_symbol(i) for i in range(mol.natm)]

    # ---- analyse orbital composition: canonical orbitals ----
    print("canonicalized orbitals")

    sq_ovlp = numpy.square(reduce(numpy.dot, (bas_occ.T, ovlp, mole_orb_occ)))
    comp_orb = _atom_percentages(sq_ovlp, loc_occ, mol.natm)
    _print_atom_table(atom, comp_orb)

    draw_heatmap(comp_orb, indx, atom)

    if latex:
        _print_atom_table(atom, comp_orb, latex=True)

    # ---- same analysis for Boys-localized occupied orbitals ----
    print("localized orbitals")
    loc_orb = lo.Boys(mol, mole_orb_occ).kernel()

    sq_ovlp = numpy.square(reduce(numpy.dot, (bas_occ.T, ovlp, loc_orb)))
    comp_orb = _atom_percentages(sq_ovlp, loc_occ, mol.natm)
    _print_atom_table(atom, comp_orb)

    draw_heatmap(comp_orb, indx, atom)

    if latex:
        _print_atom_table(atom, comp_orb, latex=True)

    # ---- per-atomic-orbital breakdown for the localized orbitals ----
    orb_label = ["1s", "2s", "2p", "2p", "2p"]
    for i, sym in enumerate(atom):
        rows = sq_ovlp[loc_occ[i]:loc_occ[i + 1], :]
        # H has a single row (1s); every other atom has five (1s, 2s, 2p x 3).
        nocc_atom = 1 if sym == "H" else 5
        for j in range(nocc_atom):
            cells = "".join("%8.2f " % (weight * 100) for weight in rows[j])
            print(("%2s %2s " % (sym, orb_label[j])) + cells)



In [None]:
# Geometry labelled 35 by the notebook author.
xyz_35 = '''
C 0.1687934049 1.5251304224 -0.1574705569
C -0.1873762459 0.0619476271 0.1467937419
C 0.5091764497 -0.4399621499 1.3912584954
O 1.1819107901 -1.4356558471 1.4581638410
H 1.2312651068 1.6313725408 -0.3963269270
H -0.4074466801 1.8943168107 -1.0096924649
H -0.0493103389 2.1726626043 0.6985743244
H -1.2690515996 -0.0166806666 0.3285443317
H 0.0627906152 -0.6025047855 -0.6847403572
H 0.3538484078 0.2066337038 2.2887105216
'''
analysis_mole_occ_orb(xyz_35, basis="ccpvtz", latex=True)

# Disabled run: geometry labelled 137.
# analysis_mole_occ_orb('''
# C 0.0347806456 1.2913768586 0.1247459533
# O -0.0124513596 -0.1105695501 0.0035215127
# C -1.3281919535 -0.6320221466 -0.0127607421
# C -1.1967300713 -2.1468266455 0.0064786786
# O -0.5031446939 -2.5873432681 1.1564186604
# H 1.0873414096 1.5846828113 0.1391786759
# H -0.4608000759 1.7890668040 -0.7231438927
# H -0.4453207650 1.6365049394 1.0536386033
# H -1.8839981446 -0.2911083706 0.8759161059
# H -1.8749792357 -0.2956683369 -0.9096196771
# H -2.1909411142 -2.6051178926 0.0283531110
# H -0.6937741941 -2.4732100908 -0.9180616798
# H 0.2895967725 -2.0398442621 1.2052767005
# ''',basis="ccpvtz")

# Disabled run: geometry labelled 180.
# analysis_mole_occ_orb('''
# C -0.0271346400 1.5221203086 0.0521305946
# C 0.0323693578 -0.0124749555 -0.0304120233
# C 0.7282693006 -0.5092869150 -1.2711077622
# C 0.1197239999 -0.2618318564 -2.6264349659
# N 1.8326439706 -1.1282801584 -1.0913424504
# O 2.4056288198 -1.5311268206 -2.3124791001
# H 0.9803323016 1.9481482522 0.0827491978
# H -0.5483940531 1.9540764861 -0.8082238646
# H -0.5580317568 1.8383921734 0.9549880432
# H 0.5562345897 -0.4149043749 0.8407760193
# H -0.9901253601 -0.4139258185 -0.0136576368
# H 0.7340383410 0.4356263481 -3.2064085046
# H 0.0848704736 -1.1928804821 -3.1995737774
# H -0.8885473976 0.1490816380 -2.5428579112
# H 3.1970636931 -1.9934940451 -2.0182152584
# ''',basis="ccpvtz")

# Geometry labelled 199 by the notebook author.
xyz_199 = '''
O 0.1144614910 1.3501199050 0.2152030541
C 0.0733291800 0.0129344505 0.0032687066
C 0.0776309623 -0.8015496878 -1.0981976451
N 0.0184319312 -2.1238248043 -0.6969565349
C -0.0205120527 -2.1011226968 0.6089495301
N 0.0108864053 -0.8149840425 1.0932893989
H 0.1574570860 1.7926474026 -0.6379727847
H 0.1191646930 -0.5214360219 -2.1394911247
H -0.0709681805 -2.9568779431 1.2645592403
H -0.0073527155 -0.5110280517 2.0524576694
'''
analysis_mole_occ_orb(xyz_199, basis="ccpvtz", latex=True)


In [None]:
# Geometry labelled 146 by the notebook author.
xyz_146 = '''
O -0.1067522688 0.2689337451 0.6057961106
C 0.1092672367 1.1676152927 -0.1749216340
C 1.2790736750 2.0782343341 -0.0588467453
C 1.4964489390 2.7013959092 1.2927537232
N 2.3836126925 1.6552095598 0.8154240580
H -0.5696452368 1.3625356471 -1.0315900759
H 1.5581001544 2.6326128995 -0.9500281947
H 1.9254436948 3.6980650620 1.3385645979
H 0.7690944710 2.4619330341 2.0635464902
H 2.1010635423 0.7682690363 1.2350124001
'''
analysis_mole_occ_orb(xyz_146, basis="ccpvtz", latex=True)