In [1]:
from pathlib import Path

#from ConvergenceAnalysis import *
from DataAnalysisClass import *

In [2]:

# Basis-set name lists, generated from the zeta labels each family provides.
valence_zetas = ('D', 'T', 'Q', '5', '6')
core_zetas = ('D', 'T', 'Q')

single = [f'aug-cc-pV{zeta}Z' for zeta in valence_zetas]
single_polarized = [f'aug-cc-pCV{zeta}Z' for zeta in core_zetas]
double = [f'd-aug-cc-pV{zeta}Z' for zeta in valence_zetas]
double_polarized = [f'd-aug-cc-pCV{zeta}Z' for zeta in core_zetas]

# Every basis set, family by family, in the order the lists are defined above.
all_basis_sets = [*single, *single_polarized, *double, *double_polarized]


In [3]:

# Four hex colours sampled from seaborn's "coolwarm_r" palette.
pal2 = sns.color_palette("coolwarm_r", 4).as_hex()

# Two reorderings of those four colours, each turned back into a palette.
p2 = [pal2[position] for position in (1, 0, 2, 3)]
p1 = [pal2[position] for position in (1, 2, 0, 3)]
pal = sns.color_palette(p2)
simple_pal = sns.color_palette(p1)

# Style applied to every figure drawn after this cell.
sns.set_style('darkgrid')




In [4]:

# Machine-specific locations: where the calculation database is read from and
# where the figures are written to.
database_path, paper_path = map(
    Path,
    (
        '/mnt/data/madness_data/august',
        '/home/adrianhurtado/projects/writing/mra-tdhf-polarizability/Figures',
    ),
)



In [5]:
import glob

# Collect one molecule name per ``.mol`` file in the database's ``molecules``
# directory.  The directory is derived from ``database_path`` rather than
# repeating the hard-coded location, so changing the database root in one
# place is enough.
mol_files = glob.glob(str(database_path / 'molecules' / '*.mol'))

# A molecule's name is its file name up to the first dot.  glob returns files
# in arbitrary, filesystem-dependent order, so the names are sorted to make
# the list (and everything derived from it) reproducible.
mols = sorted(Path(mol_file).name.split('.')[0] for mol_file in mol_files)

In [6]:
# Exclude LiH_s from the molecule list.  Filtering instead of list.remove()
# keeps this cell safe to re-run: it does not raise when the entry has
# already been removed or was never present.
mols = [mol_name for mol_name in mols if mol_name != 'LiH_s']

In [7]:
# Build the PolarizabilityData object for the molecules and basis sets
# selected above (argument meanings are defined by the class, not shown here).
august_database = PolarizabilityData(
    mols,
    'hf',
    'dipole',
    all_basis_sets,
    database_path,
    overwrite=True,
)


In [8]:
# Ask the database object to save its dataframes (what and where is decided
# by PolarizabilityData.save_dfs, which is not visible here).
august_database.save_dfs()

# Work on a copy of iso_data so later cells do not modify the database's frame.
iso_source = august_database.iso_data
polar_data = iso_source.copy()



In [9]:
# Display the distinct basis labels present in the energy table.
august_database.energy_df['basis'].unique()

In [10]:
class MRAComparedBasisDF(pd.DataFrame):
    """Basis-set results paired with their MRA reference values.

    The rows of ``polar_data`` whose ``basis`` is ``"MRA"`` are used as the
    reference.  Every other row is matched to its reference through the
    ``index`` columns, and for each name in ``values`` two columns are added:

    * ``<value>MRA`` -- the reference value for that row, and
    * ``<value>E``   -- the error of the row relative to the reference.

    The resulting table is passed through ``make_detailed_df`` (defined
    outside this notebook cell) before it becomes the frame's content.

    Parameters
    ----------
    polar_data : pandas.DataFrame
        Must contain a ``basis`` column, the ``index`` columns and the
        ``values`` columns.
    index : label or list of labels
        Column(s) that identify which reference row belongs to which
        basis-set row; passed to ``DataFrame.set_index``.
    values : list
        Names of the columns to compare against the reference.
    PercentError : bool
        If true the error is ``(value - reference) / reference * 100``,
        otherwise it is the plain difference ``value - reference``.
    *args, **kwargs
        Forwarded to ``pandas.DataFrame.__init__``.
    """

    def __init__(self, polar_data, index, values: list, PercentError: bool, *args, **kwargs):
        # Basis-set rows, keyed by `index` so that assignments below align on it.
        basis_data = polar_data.query('basis!="MRA"').copy()
        basis_data = basis_data.set_index(index)

        # The reference rows do not depend on which value is being compared,
        # so they are selected and indexed once instead of once per value.
        mra_data = polar_data.query('basis=="MRA"').set_index(index)

        for value in values:
            reference_column = f'{value}MRA'
            error_column = f'{value}E'
            basis_data[reference_column] = mra_data[value]
            difference = basis_data[value] - basis_data[reference_column]
            if PercentError:
                basis_data[error_column] = difference / basis_data[reference_column] * 100
            else:
                basis_data[error_column] = difference

        # Turn the key columns back into ordinary columns before post-processing.
        basis_data = basis_data.reset_index()
        basis_data = make_detailed_df(basis_data)
        super().__init__(basis_data, *args, **kwargs)


# Percent errors of alpha and gamma relative to MRA, with rows matched on
# (molecule, omega).  The bare name on the last line displays the table.
basis_data = MRAComparedBasisDF(
    polar_data,
    index=['molecule', 'omega'],
    values=['alpha', 'gamma'],
    PercentError=True,
)
basis_data


In [11]:
# Count the distinct molecules present in the comparison table.
distinct_molecules = basis_data['molecule'].unique()
num_mols = len(distinct_molecules)
print(f'Number of molecules: {num_mols}')

In [57]:
from matplotlib.ticker import ScalarFormatter, FormatStrFormatter

# Static (omega == 0) alpha percent error relative to MRA: one panel per zeta
# level, one x position per basis-set family, coloured by mol_system.
Type_map = {"aug-cc-pVnZ": "sn", "aug-cc-pCVnZ": "sCn", "d-aug-cc-pVnZ": "dn",
            "d-aug-cc-pCVnZ": "dCn"}
aspect_ratio = .5
vlevel = ['D', 'T', 'Q', '5']
plot_data = basis_data.query('valence.isin(@vlevel) and omega==0').copy()
# remove unused categories from valence
plot_data.valence = plot_data.valence.cat.remove_unused_categories()
# number of basis-set families on the x axis (same for every panel)
n_types = len(plot_data.Type.unique())
with sns.plotting_context('paper', font_scale=1.00):
    g = sns.catplot(col='valence', x='Type', y='alphaE', hue='mol_system',
                    data=plot_data, kind='strip', dodge=True, jitter=True,
                    s=15, palette='colorblind', height=4, aspect=aspect_ratio, sharey=True,
                    alpha=0.5,
                    sharex=False)
    # panel titles read DZ, TZ, QZ, 5Z
    g.set_titles('{col_name}Z')
    g.set_xlabels('')

    # Abbreviate the family names on the x axis.  The short labels are looked
    # up from the tick labels actually drawn, so they stay correct whatever
    # order the families appear in (the previous hard-coded list that
    # overwrote them assumed one fixed order).
    g.set_xticklabels([Type_map[label.get_text()] for label in g.axes.flat[0].get_xticklabels()])
    for ax in g.axes.flat:
        # linear within +/-0.1, logarithmic outside
        ax.set_yscale('symlog', linthresh=1e-1)
        # dashed guides at +/-0.05
        ax.axhline(y=.05, linestyle='--', color='orange')
        ax.axhline(y=-.05, linestyle='--', color='orange')
        # plain-number tick labels on the symlog axis
        ax.yaxis.set_major_formatter(FormatStrFormatter('%2g'))

        # fixed, symmetric y range
        ax.set_ylim(-30, 30)

        ax.set_ylabel(r'$\alpha$ Error (%)')
        # keep the x tick labels horizontal
        for label in ax.get_xticklabels():
            label.set_rotation(0)
        # create vertical lines to separate basis sets
        for i in range(1, n_types):
            ax.axvline(i - .5, linestyle='-.', color='k', linewidth=.5, alpha=.8)
        # lightly tint the background behind each of the four families
        for i in range(1, 5):
            ax.axvspan((i - 1) - .5, i - .5, 1e-3, facecolor=pal[i - 1], alpha=0.1)

# remove legend and create a new legend in 3 column format outside the plot
g._legend.remove()
g.fig.legend(loc='center', bbox_to_anchor=(0.5, 1.05), ncol=3, fancybox=True, fontsize=10)
g.fig.tight_layout()


In [58]:


# Grid of strip plots for the static (omega == 0) alpha percent error:
# one row per zeta level, one column per basis-set family.
Type_map = {"aug-cc-pVnZ": "sn", "aug-cc-pCVnZ": "sCn", "d-aug-cc-pVnZ": "dn",
            "d-aug-cc-pCVnZ": "dCn"}
aspect_ratio = 0.75
vlevel = ['D', 'T', 'Q', '5']

plot_data = basis_data.query('valence.isin(@vlevel) and omega==0').copy()
# Drop the valence categories that are no longer present after filtering.
plot_data['valence'] = plot_data['valence'].cat.remove_unused_categories()

with sns.plotting_context('paper', font_scale=1.00):
    facet_kws = {'margin_titles': True}
    strip_options = dict(kind='strip', dodge=True, jitter=False, s=5, alpha=0.7,
                         palette='colorblind')
    grid_options = dict(height=2, aspect=aspect_ratio, sharey=True, facet_kws=facet_kws)
    g = sns.catplot(data=plot_data, row='valence', col='Type', y='alphaE',
                    hue='mol_system', **strip_options, **grid_options)
    g.set_ylabels(r'$\alpha$ Error (%)')
    g.set_titles(row_template='{row_name}Z', col_template='{col_name}')

    # Thin dashed guides at +/-0.02 on every panel.
    for ax in g.axes.flat:
        for guide in (.02, -.02):
            ax.axhline(y=guide, linestyle='--', color='black', linewidth=.35)

# Replace seaborn's legend with a three-column figure legend above the grid.
g._legend.remove()
g.fig.tight_layout()
g.fig.legend(loc='center', bbox_to_anchor=(0.5, 1.005), ncol=3, fancybox=True, fontsize=10)
g.fig.savefig(paper_path / 'alpha_error.png', dpi=600, bbox_inches='tight')



In [164]:
from matplotlib.ticker import ScalarFormatter, FormatStrFormatter
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

# Static (omega == 0) alpha percent error versus zeta level, one panel per
# basis-set family.  Written to alpha_error.png.
Type_map = {"aug-cc-pVnZ": "sn", "aug-cc-pCVnZ": "sCn", "d-aug-cc-pVnZ": "dn",
            "d-aug-cc-pCVnZ": "dCn"}
vlevel = ['D', 'T', 'Q', '5']
plot_data = basis_data.query('valence.isin(@vlevel) and omega==0').copy()

# Strip-plot settings.  The gamma and energy error cells further down read
# several of these names, so they stay plain notebook-level variables.
height = 3.0
aspect_ratio = 1.05
palette = 'colorblind'
kind = 'strip'
dodge = True
jitter = .05
alpha = 0.8
s = 10
facet_kws = {
    'margin_titles': True,
}

# remove unused categories from valence
plot_data.valence = plot_data.valence.cat.remove_unused_categories()
with sns.plotting_context('paper', font_scale=1.25):
    # `dodge` takes the dodge flag defined above; previously the jitter width
    # was passed here by mistake (it only worked because 0.05 is truthy).
    # NOTE(review): the marker alpha is the literal 0.7, not the `alpha`
    # variable (0.8) -- left unchanged so the figure does not change; confirm
    # which one is intended.
    g = sns.catplot(x='valence', col='Type', y='alphaE', hue='mol_system',
                    data=plot_data,
                    kind=kind, dodge=dodge, jitter=jitter,
                    palette=palette, height=height, aspect=aspect_ratio, sharey=True,
                    facet_kws=facet_kws, alpha=0.7, s=s
                    )
    g.set_ylabels(r'$\alpha$ Error (%)')
    g.set_titles(row_template='{row_name}Z', col_template='{col_name}')
    g.set_xlabels('')

    types = ['aug-cc-pVnZ', 'aug-cc-pCVnZ', 'd-aug-cc-pVnZ', 'd-aug-cc-pCVnZ']
    # Index of the first facet that gets an inset.  With four entries in
    # `types` the value 4 means no facet qualifies, i.e. the insets are
    # switched off; lower it to re-enable them.  The upper bound in the test
    # below keeps `types[i]` from raising IndexError on a larger grid.
    first_inset_facet = 4
    for i, ax in enumerate(g.axes.flat):
        # zero line
        ax.axhline(y=.00, linestyle='--', color='black', linewidth=.35)
        # linear within +/-0.1, logarithmic outside
        ax.set_yscale('symlog', linthresh=1e-1)
        ax.yaxis.set_major_formatter(ScalarFormatter())
        ax.set_ylim(-30, 30)
        # red guides at +/-0.02
        ax.axhline(y=.02, linestyle='--', color='red', linewidth=.50)
        ax.axhline(y=-.02, linestyle='--', color='red', linewidth=.50)
        if first_inset_facet <= i < len(types):
            t = types[i]
            inset_ax = inset_axes(ax, width="90%", height="40%", loc="lower center")
            vs = ['D', 'T', 'Q', '5']
            inset_data = plot_data.query('valence.isin(@vs)').copy()
            inset_data.valence = inset_data.valence.cat.remove_unused_categories()
            inset_data = inset_data.query('Type==@t')
            sns.stripplot(data=inset_data, x='valence', y='alphaE', hue='mol_system', ax=inset_ax,
                          legend=False, dodge=True, jitter=.1, alpha=.7, s=2.5, palette='colorblind')
            inset_ax.set_ylabel('')
            inset_ax.set_xlabel('')
            # remove tick marks from x axis
            inset_ax.set_xticks([])
            inset_ax.axhline(y=.02, linestyle='--', color='red', linewidth=.50)
            inset_ax.axhline(y=-.02, linestyle='--', color='red', linewidth=.50)

            inset_ax.set_yscale('symlog', linthresh=1e-1)
            inset_ax.yaxis.set_major_formatter(ScalarFormatter())

# remove legend and create a new legend in 3 column format outside the plot
g._legend.remove()
g.fig.tight_layout()
g.fig.legend(loc='center', bbox_to_anchor=(0.5, 1.015), ncol=3, fancybox=True, fontsize=10)
g.fig.savefig(paper_path.joinpath('alpha_error.png'), dpi=600, bbox_inches='tight')



In [163]:

# Static (omega == 0) gamma percent error versus zeta level, one panel per
# basis-set family; rows with gamma <= 1e-3 are left out.  Written to
# gamma_error.png.  `kind`, `palette`, `height`, `s` and `dodge` come from the
# alpha-error cell above.
jitter = .05
alpha = 0.8
vlevel = ['D', 'T', 'Q', '5']
plot_data = basis_data.query('valence.isin(@vlevel) and omega==0 & gamma > 1e-3').copy()
# remove unused categories from valence
plot_data.valence = plot_data.valence.cat.remove_unused_categories()
aspect_ratio = 1.05
with sns.plotting_context('paper', font_scale=1.00):
    facet_kws = {'margin_titles': True, 'despine': False}
    # `dodge` is the dodge flag; the jitter width used to be passed here by
    # mistake (harmless only because 0.05 is truthy).
    g = sns.catplot(x='valence', col='Type', y='gammaE', hue='mol_system',
                    data=plot_data,
                    kind=kind, dodge=dodge, jitter=jitter,
                    palette=palette, height=height, aspect=aspect_ratio, sharey=True,
                    facet_kws=facet_kws, alpha=alpha, s=s
                    )
    g.set_titles(row_template='{row_name}Z', col_template='{col_name}')

    g.set_ylabels(r'$\gamma$ Error (%)')
    g.set_xlabels('')
    for i, ax in enumerate(g.axes.flat):
        # zero line
        ax.axhline(y=0, linestyle='--', color='black', linewidth=.35)
        ax.set_ylim(-500, 500)
        # red guides at +/-0.2
        ax.axhline(y=.2, linestyle='--', color='red', linewidth=.50)
        ax.axhline(y=-.2, linestyle='--', color='red', linewidth=.50)
        # linear within +/-1, logarithmic outside
        ax.set_yscale('symlog', linthresh=1)
        ax.yaxis.set_major_formatter(ScalarFormatter())

# replace seaborn's legend with a three-column figure legend above the grid
g._legend.remove()
g.fig.legend(loc='upper center', bbox_to_anchor=(0.5, 1.15), ncol=3, fancybox=True, fontsize=10)
g.fig.tight_layout()
g.fig.savefig(paper_path.joinpath('gamma_error.png'), dpi=600, bbox_inches='tight')


In [61]:
# Energy error (basis minus MRA, plain difference) versus zeta level, one
# panel per basis-set family.  Written to energy_error.png.  The strip-plot
# settings (`kind`, `dodge`, `jitter`, `palette`, `height`, `aspect_ratio`,
# `s`) come from the cells above.
vlevel = ['D', 'T', 'Q', '5']
basis_e_data = MRAComparedBasisDF(august_database.energy_df, ['molecule'], ['energy'], False)
plot_data = basis_e_data.query('valence.isin(@vlevel)').copy()
# remove unused categories from valence
plot_data.valence = plot_data.valence.cat.remove_unused_categories()
with sns.plotting_context('paper', font_scale=1.00):
    facet_kws = {'margin_titles': True, 'despine': False}
    # `dodge` is the dodge flag; the jitter width used to be passed here by
    # mistake (harmless only because 0.05 is truthy).
    g = sns.catplot(x='valence', col='Type', y='energyE', hue='mol_system',
                    data=plot_data,
                    kind=kind, dodge=dodge, jitter=jitter,
                    palette=palette, height=height, aspect=aspect_ratio, sharey=True,
                    facet_kws=facet_kws, alpha=0.7, s=s
                    )
    g.set_titles(row_template='{row_name}Z', col_template='{col_name}')

    g.set_xlabels('')
    g.set_ylabels(r'Energy Error (a.u.)')
    for i, ax in enumerate(g.axes.flat):
        # zero line
        ax.axhline(y=-.00, linestyle='--', color='black', linewidth=.35)

# replace seaborn's legend with a three-column figure legend above the grid
g._legend.remove()
g.fig.legend(loc='upper center', bbox_to_anchor=(0.5, 1.15), ncol=3, fancybox=True, fontsize=10)
g.fig.tight_layout()
g.fig.savefig(paper_path.joinpath('energy_error.png'), dpi=600, bbox_inches='tight')


In [62]:
# Frequency dependence of the alpha percent error: one row per zeta level,
# one column per basis-set family, omega on the x axis.  Written to
# alpha_freq_error.png.  `alpha` (marker opacity) comes from an earlier cell.

vlevel = ['D', 'T', 'Q', '5']
plot_data = basis_data.query('valence.isin(@vlevel)').copy()
# remove unused categories from valence
plot_data.valence = plot_data.valence.cat.remove_unused_categories()
aspect_ratio = 1.2
# number of frequencies on the x axis (same for every panel)
n_omegas = len(plot_data.omega.unique())
with sns.plotting_context('paper', font_scale=1.00):
    facet_kws = {'margin_titles': True}
    # NOTE(review): margin_titles is requested both directly and through
    # facet_kws; which of the two is honoured (or whether the pair is
    # accepted at all) depends on the seaborn version -- confirm before
    # upgrading.
    g = sns.catplot(row='valence', col='Type',
                    x='omega',
                    y='alphaE', hue='mol_system',
                    data=plot_data, kind='strip',
                    dodge=True, jitter=False,
                    palette='colorblind',
                    height=1.5, aspect=aspect_ratio, sharex=True, sharey='row',
                    alpha=alpha,
                    s=5,
                    margin_titles=True,
                    facet_kws=facet_kws)

    g.set_titles(col_template='{col_name}', row_template='{row_name}Z')

    g.set_xlabels('Frequency')
    for ax in g.axes.flat:
        # thin dashed guides at +/-0.02
        ax.axhline(y=.02, linestyle='--', color='black', linewidth=.35)
        ax.axhline(y=-.02, linestyle='--', color='black', linewidth=.35)

        # y label only on the first column of the grid
        if ax.get_subplotspec().colspan.start != 0:
            ax.set_ylabel('')
        else:
            ax.set_ylabel(r'$\alpha$ Error (%)')

        # right-align the x tick labels
        for label in ax.get_xticklabels():
            label.set_horizontalalignment('right')

        # vertical lines separating neighbouring frequencies
        for i in range(1, n_omegas):
            ax.axvline(i - .5, linestyle='-.', color='k', linewidth=.5, alpha=0.5)

# remove legend and create a new legend in 3 column format outside the plot
g._legend.remove()
g.fig.tight_layout()
g.fig.legend(loc='upper center', bbox_to_anchor=(0.5, 1.05), ncol=3, fancybox=True, fontsize=10)
g.fig.savefig(paper_path.joinpath('alpha_freq_error.png'), dpi=600, bbox_inches='tight')


In [63]:
# Molecules whose alpha percent error exceeds 0.2 in magnitude for
# valence "5" with augmentation "aug".
plot_data = basis_data.query('valence=="5" and augmentation=="aug" and alphaE.abs()>.2').copy()
mol5 = plot_data['molecule'].unique()

In [64]:
# Quadruple-zeta outliers, with one threshold per family:
#   d-aug-cc-pVnZ : |alpha percent error| > 0.25
#   aug-cc-pVnZ   : |alpha percent error| > 0.6
daug_frame = basis_data.query('valence=="Q" and Type=="d-aug-cc-pVnZ" and alphaE.abs()>.25').copy()
daug_outliers = daug_frame['molecule'].unique()

# `plot_data` is left holding the aug-cc-pVnZ selection, as before.
plot_data = basis_data.query('valence=="Q" and Type=="aug-cc-pVnZ" and alphaE.abs()>.6').copy()
aug_outliers = plot_data['molecule'].unique()





In [65]:
# Alpha percent error versus frequency at the quadruple-zeta level for the
# aug-cc-pVnZ outlier molecules, one panel per molecule and one line per
# basis-set family.  Written to frequency_dependent_iso_error.png.
vlevel = ['Q']
plot_data = (basis_data.query('valence.isin(@vlevel) & molecule.isin(@aug_outliers)')
             .copy())
# remove unused categories from valence
plot_data.valence = plot_data.valence.cat.remove_unused_categories()
facet_kws = {'margin_titles': True, 'sharey': False, 'sharex': True}
# panels per grid row; also decides which panels sit in the first column
col_wrap = 3

with sns.plotting_context('paper', font_scale=1.05):
    g = sns.relplot(col='molecule',
                    col_wrap=col_wrap,
                    x='omega',
                    y='alphaE',
                    hue='Type',
                    style='Type',
                    data=plot_data, kind='line',
                    palette='colorblind', height=2.0, aspect=1.2,
                    markers=True,
                    ms=5,
                    facet_kws=facet_kws)

    # (A '$\omega(i)$' label used to be set first and was immediately
    # overridden; that string also held an invalid escape sequence.)
    g.set_xlabels('Frequency')
    g.set_titles(col_template='{col_name}')
    for i, ax in enumerate(g.axes.flat):
        # make the y range symmetric about zero using the larger magnitude
        max_lim = max(abs(ax.get_ylim()[0]), abs(ax.get_ylim()[1]))
        ax.set_ylim(-max_lim, max_lim)
        # thin dashed guides at +/-0.02
        ax.axhline(y=.02, linestyle='--', color='black', linewidth=.35)
        ax.axhline(y=-.02, linestyle='--', color='black', linewidth=.35)

        # y label only on the first panel of each grid row
        if i % col_wrap == 0:
            ax.set_ylabel(r'$\alpha$ Error (%)')
        else:
            ax.set_ylabel('')
        # right-align the x tick labels
        for label in ax.get_xticklabels():
            label.set_horizontalalignment('right')

g.fig.tight_layout()
# replace seaborn's legend with a four-column figure legend above the grid
g._legend.remove()
g.fig.legend(loc='upper center', bbox_to_anchor=(0.5, 1.060), fancybox=False, ncol=4)
g.fig.savefig(paper_path.joinpath('frequency_dependent_iso_error.png'), dpi=600,
              bbox_inches='tight')


In [66]:
# Percent error of each alpha component (column `ij`) versus frequency at the
# quadruple-zeta level for a hand-picked set of molecules: one row per
# molecule, one column per component, one line per basis-set family.
# Written to frequency_dependent_component_error.png.
ij_basis_data = MRAComparedBasisDF(august_database.alpha_eigen, ['molecule', 'ij', 'omega'], ['alpha'], True)
selection = ['SF2', 'Na2', 'NaCl', 'BH2Cl', 'HF', ]
vlevel = ['Q']
plot_data = (ij_basis_data.query('valence.isin(@vlevel) & molecule.isin(@selection)')
             .copy())
# remove unused categories from valence
plot_data.valence = plot_data.valence.cat.remove_unused_categories()
facet_kws = {'margin_titles': True, 'sharey': "row", 'sharex': True}

with sns.plotting_context('paper', font_scale=1.00):
    g = sns.relplot(row='molecule',
                    x='omega',
                    y='alphaE',
                    hue='Type',
                    col='ij',
                    data=plot_data, kind='line',
                    palette='colorblind',
                    height=2.0, aspect=1.2,
                    markers=True,
                    style='Type',
                    ms=7,
                    facet_kws=facet_kws)

    # (A '$\omega(i)$' label used to be set first and was immediately
    # overridden; that string also held an invalid escape sequence.)
    g.set_xlabels('Frequency')
    g.set_titles(col_template='{col_name}', row_template='{row_name}')
    # Number of grid columns, read from the grid itself so the first-column
    # test below does not assume exactly three `ij` components.
    n_cols = g.axes.shape[1]
    for i, ax in enumerate(g.axes.flat):
        # make the y range symmetric about zero using the larger magnitude
        max_lim = max(abs(ax.get_ylim()[0]), abs(ax.get_ylim()[1]))
        ax.set_ylim(-max_lim, max_lim)

        # y label only on the first column of the grid
        if i % n_cols == 0:
            ax.set_ylabel(r'$\alpha$ Error (%)')
        else:
            ax.set_ylabel('')
        # right-align the x tick labels
        for label in ax.get_xticklabels():
            label.set_horizontalalignment('right')

g.fig.tight_layout()
# replace seaborn's legend with a four-column figure legend above the grid
g._legend.remove()
g.fig.legend(loc='upper center', bbox_to_anchor=(0.5, 1.025), fancybox=True, ncol=4)
g.fig.savefig(paper_path.joinpath('frequency_dependent_component_error.png'), dpi=600,
              bbox_inches='tight')


In [67]:

# Per-component alpha percent errors relative to MRA, with rows matched on
# (molecule, ij, omega).  The bare name on the last line displays the table.
ij_basis_data = MRAComparedBasisDF(
    august_database.alpha_eigen,
    index=['molecule', 'ij', 'omega'],
    values=['alpha'],
    PercentError=True,
)
ij_basis_data


In [68]:
# Export the per-component error table as CSV (machine-specific target directory).
csv_dir_path = Path('/home/adrianhurtado/projects/QuantumResponsePro/examples/data_apps/csv')
alpha_ij_csv = csv_dir_path / 'alpha_ij.csv'
ij_basis_data.to_csv(alpha_ij_csv)

In [69]:
# Quadruple-zeta rows with gamma > 1e-2 whose gamma percent error exceeds 4
# in magnitude; the molecules they belong to are shown as a list.
plot_data = basis_data.query('valence=="Q" and gammaE.abs()>4.0 & gamma>1e-2').copy()
gamma_outliers = plot_data['molecule'].unique()
gamma_outliers.tolist()




In [70]:
# Static (omega == 0) percent error of each alpha component versus zeta
# level for the gamma-outlier molecules in the 'First-row' and 'Fluorine'
# groups: one row per molecule, one column per component.  Written to
# first_and_fluorine_component_error.png.
facet_kw = {'sharey': "row", 'sharex': True, "margin_titles": True}
ms = ['First-row', 'Fluorine']
# named `basis_types` rather than `type` so the builtin is not shadowed
basis_types = ['aug-cc-pVnZ', 'd-aug-cc-pVnZ']
vlevels = ['D', 'T', 'Q', '5']

p_data = ij_basis_data.query('molecule.isin(@gamma_outliers) and mol_system.isin'
                             '(@ms) & Type.isin(@basis_types) ').copy()
p_data = p_data.query('valence.isin(@vlevels) and omega==0').copy()

# plot valence vs alpha error with component hue for each molecule
with sns.plotting_context('paper', font_scale=1.0, ):
    g = sns.relplot(row='molecule',
                    x='valence', y='alphaE', col='ij',
                    hue='Type',
                    style='Type',
                    kind='line',
                    data=p_data,
                    markers=True,
                    dashes=True,
                    palette="colorblind",
                    height=1.5, aspect=1.2, facet_kws=facet_kw, ms=5, alpha=.8)

    g.set_ylabels(r'$\alpha_{ij}$ Error (%)')
    # Margin titles are already switched on through facet_kw, so the titles
    # are set once, with both templates.
    g.set_titles(col_template='{col_name}', row_template='{row_name}')
    g.set_xlabels('')

    for ax in g.axes.flat:
        # zero line
        ax.axhline(y=-.00, linestyle='--', color='black', linewidth=.5)

g.fig.tight_layout()
# replace seaborn's legend with a four-column figure legend above the grid
g._legend.remove()

g.fig.legend(loc='upper center', bbox_to_anchor=(0.5, 1.025), fancybox=True, ncol=4, fontsize=8)
g.savefig(paper_path.joinpath('first_and_fluorine_component_error.png'), dpi=600)



In [71]:
# Molecules of whatever `plot_data` currently holds (set by an earlier cell),
# as a list, with SF2 appended.
gamma_outliers = [*plot_data['molecule'].unique().tolist(), 'SF2']

# Static (omega == 0) component errors for four hand-picked 'Second-row'
# molecules: one row per molecule, one column per component `ij`.
selected = ['SiO', 'NaCl', 'SH2', 'NaH']
ms = ['Second-row']
p_data = ij_basis_data.query(
    'molecule.isin(@selected) and mol_system.isin(@ms) and omega==0'
).copy()

facet_kw = {'sharey': "row", 'sharex': True, "margin_titles": True}
with sns.plotting_context('paper', font_scale=1.0):
    line_options = dict(kind='line', markers=True, dashes=True, palette="colorblind",
                        ms=5, alpha=.8)
    g = sns.relplot(data=p_data, x='valence', y='alphaE', row='molecule', col='ij',
                    hue='Type', style='Type', height=1.5, aspect=1.2,
                    facet_kws=facet_kw, **line_options)
    g.set_titles(col_template='{col_name}')
    g.set_ylabels(r'$\alpha_{ij}$ Error (%)')
    g.set_titles(col_template='{col_name}', row_template='{row_name}')
    g.set_xlabels('')

    # Zero line on every panel.
    for ax in g.axes.flat:
        ax.axhline(y=.0, linestyle='--', color='black', linewidth=.5)

g.fig.tight_layout()
# Swap seaborn's legend for a four-column figure legend above the grid.
g._legend.remove()
g.fig.legend(loc='upper center', bbox_to_anchor=(0.5, 1.025), fancybox=True, ncol=4, fontsize=8)
g.savefig(paper_path / 'component_error.png', dpi=600)


In [72]:
# Energy differences (basis minus MRA) matched per molecule, followed by the
# number of distinct molecules in that table.
basis_e_data = MRAComparedBasisDF(
    august_database.energy_df,
    index=['molecule'],
    values=['energy'],
    PercentError=False,
)
len(basis_e_data['molecule'].unique())

In [73]:
# Energy error (basis minus MRA) per basis-set family on a log scale, one
# panel per zeta level.  `aspect_ratio`, `pal` and the matplotlib formatters
# come from earlier cells.
vlevel = ['D', 'T', 'Q', '5']
plot_data = basis_e_data.query('valence.isin(@vlevel)').copy()
# remove unused categories from valence
plot_data.valence = plot_data.valence.cat.remove_unused_categories()
# number of basis-set families on the x axis (same for every panel)
n_types = len(plot_data.Type.unique())
# NOTE(review): these short labels assume the families appear in the order
# aug-cc-pVnZ, aug-cc-pCVnZ, d-aug-cc-pVnZ, d-aug-cc-pCVnZ -- confirm.
xl = ['sn', 'sCn', 'dn', 'dCn']
with sns.plotting_context('paper', font_scale=1.00):
    g = sns.catplot(col='valence', x='Type', y='energyE', hue='mol_system',
                    data=plot_data, kind='strip', dodge=True, jitter=True,
                    s=15, palette='colorblind', height=4, aspect=aspect_ratio, sharey=True,
                    alpha=0.8)
    # panel titles read DZ, TZ, QZ, 5Z
    g.set_titles('{col_name}Z')
    g.set_xlabels('')

    for ax in g.axes.flat:
        ax.set_yscale('log')
        # plain-number tick labels on the log axis (a ScalarFormatter used to
        # be installed first and was replaced on the very next line)
        ax.yaxis.set_major_formatter(FormatStrFormatter('%2g'))
        ax.set_xticklabels(xl)

        ax.set_ylabel(r'Energy Error (a.u.)')
        # create vertical lines to separate basis sets
        for i in range(1, n_types):
            ax.axvline(i - .5, linestyle='-.', color='k', linewidth=.5, alpha=.8)

        # lightly tint the background behind each of the four families
        for i in range(1, 5):
            ax.axvspan((i - 1) - .5, i - .5, 1e-3, facecolor=pal[i - 1], alpha=0.1)

# remove legend and create a new legend in 3 column format outside the plot
g._legend.remove()
g.fig.legend(loc='center', bbox_to_anchor=(0.5, 1.05), ncol=3, fancybox=True, fontsize=10)
g.fig.tight_layout()

#g.fig.savefig(paper_path.joinpath('energy_error.png'), dpi=600, bbox_inches='tight')


In [74]:

def _largest_energy_errors(selection):
    """Molecules of the 20 rows with the largest energyE among rows matching `selection`."""
    ranked = basis_e_data.query(selection).sort_values('energyE', ascending=False)
    return ranked.head(20).molecule.unique()


daug_outliers = _largest_energy_errors('valence=="Q" & basis=="d-aug-cc-pVQZ" ')
daugC_outliers = _largest_energy_errors('valence=="Q" & basis=="d-aug-cc-pCVQZ"')

# Union (with repeats) of both top-20 lists, then every doubly augmented
# quadruple-zeta row that belongs to one of those molecules.
e_outliers = [*daug_outliers.tolist(), *daugC_outliers.tolist()]
e_outliers = basis_e_data.query(
    "molecule.isin(@e_outliers) & valence=='Q' & augmentation=='d-aug'"
)
# Keep only rows with |energy error| above 0.002, smallest error first.
e_outliers = e_outliers.query('energyE.abs()>.002').sort_values('energyE', ascending=True)

with sns.plotting_context('paper', font_scale=1.00):
    # One bar per molecule, coloured by mol_system, with a panel for each of
    # the 'V' and 'CV' polarization values.
    bar_options = dict(kind='bar', dodge=False, alpha=.9, palette='colorblind', fill=True)
    g = sns.catplot(data=e_outliers, x='molecule', y='energyE', hue='mol_system',
                    col='polarization', col_order=['V', 'CV'],
                    height=3, aspect=1.4, sharex=False, **bar_options)
    g.set_titles('d-aug-cc-p{col_name}QZ')
    g.set_xlabels('')
    g.set_ylabels('Energy Error (a.u.)')
    # Vertical, centred molecule names on the x axis.
    for ax in g.axes.flat:
        for tick_label in ax.get_xticklabels():
            tick_label.set_rotation(90)
            tick_label.set_horizontalalignment('center')

In [75]:

# Bar chart of the alpha percent error for the molecules with the largest
# static error in the doubly augmented quadruple-zeta basis sets.

# magnitude of the alpha error, added to basis_data as an extra column
basis_data['absolute_alphaE'] = basis_data.alphaE.abs()

# molecules of the 20 largest |alpha error| rows at omega == 0, per basis set
daug_outliers = (basis_data.query('valence=="Q" & basis=="d-aug-cc-pVQZ" & omega==0')
                 .sort_values('absolute_alphaE', ascending=False)
                 .head(20)).molecule.unique()
daugC_outliers = (basis_data.query('valence=="Q" & basis=="d-aug-cc-pCVQZ" & omega==0')
                  .sort_values('absolute_alphaE', ascending=False)
                  .head(20)).molecule.unique()

a_outliers = daug_outliers.tolist() + daugC_outliers.tolist()
print(a_outliers)

# All doubly augmented quadruple-zeta rows of those molecules, smallest error
# first.  (The same sort used to be applied twice in a row.)
# NOTE(review): unlike the selection above, this query has no omega filter, so
# every frequency of each molecule is kept and the bar plot aggregates over
# them -- confirm that this is intended.
a_outliers = (basis_data.query("molecule.isin(@a_outliers) & valence=='Q' & augmentation=='d-aug' ")
              .sort_values('alphaE', ascending=True))
# keep only rows with |alpha error| above 0.02
a_outliers = a_outliers.query('alphaE.abs()>.02')

with sns.plotting_context('paper', font_scale=1.00):
    # one bar per molecule, coloured by mol_system, one panel per polarization
    g = sns.catplot(x='molecule', y='alphaE',
                    hue='mol_system',
                    col='polarization', col_order=['V', 'CV'],
                    data=a_outliers, dodge=False,
                    alpha=.9, palette='colorblind', fill=True, height=3, aspect=1.4, kind='bar'
                    , sharey=False,
                    sharex=False
                    )
    g.set_titles('d-aug-cc-p{col_name}QZ')
    g.set_xlabels('')
    g.set_ylabels(r'$\alpha$ Error (%)')
    for ax in g.axes.flat:
        # vertical, centred molecule names on the x axis
        for label in ax.get_xticklabels():
            label.set_rotation(90)
            label.set_horizontalalignment('center')
        # add a horizontal line at 0.02 and -0.02
        ax.axhline(y=.02, linestyle='--', color='orange')
        ax.axhline(y=-.02, linestyle='--', color='orange')


In [76]:
# Alpha percent error versus zeta level for a hand-picked set of molecules,
# one panel per molecule and one line per basis-set family.  No omega filter
# is applied, so each line is drawn from all frequencies of that molecule.
mol = ['NaCl', 'NaCN', 'Na2', 'NaH', 'Be', 'NH2F', 'HF', 'F2', 'Ne']
plot_data = basis_data.query('molecule.isin(@mol)').copy()
# Drop valence categories that are not present in the selection.
plot_data['valence'] = plot_data['valence'].cat.remove_unused_categories()
# Every panel gets its own x and y range.
facet_kw = {'sharey': False, 'sharex': False}

with sns.plotting_context('paper', font_scale=1.00):
    g = sns.relplot(data=plot_data, x='valence', y='alphaE', hue='Type', kind='line',
                    col='molecule', col_wrap=3, palette=pal, markers=True,
                    height=2, aspect=1, alpha=.8, facet_kws=facet_kw)
    g.set_ylabels(r'$\alpha$ Error (%)')
    g.set_xlabels('')

    g.set_titles('{col_name}')

    for ax in g.axes.flat:
        # Orange guides at +/-0.02.
        for guide in (.02, -.02):
            ax.axhline(y=guide, linestyle='--', color='orange')
        # Plain-number tick labels on the y axis.
        ax.yaxis.set_major_formatter(ScalarFormatter())
    g.savefig(paper_path / 'alpha_outlier_valence_error.png', dpi=600, bbox_inches='tight')


In [77]:
# Attach the energy error of each (basis, molecule) pair to every frequency
# row of the polarizability table, as a new column `eE`.

# The energy series is the same for every frequency, so it is built once
# rather than once per loop iteration.
energy_error = basis_e_data.set_index(['basis', 'molecule'])['energyE']

omega_frames = []
for omega in basis_data.omega.unique():
    # rows of this frequency, keyed like the energy series so that the
    # assignment below aligns on (basis, molecule)
    omega_data = basis_data.query('omega==@omega').set_index(['basis', 'molecule'])
    omega_data['eE'] = energy_error
    omega_frames.append(omega_data.reset_index())

# Concatenate once at the end instead of growing a DataFrame inside the loop;
# an empty frame is kept as the result when there are no frequencies at all.
data = pd.concat(omega_frames) if omega_frames else pd.DataFrame()

bdata = data.copy()




In [78]:
# Scatter of alpha percent error against energy error for the 'First-row' and
# 'Fluorine' molecules: one column per basis-set family, one row per
# mol_system, colour by zeta level, marker size by omega.
v_level = ['D', 'T', 'Q', '5']
omegas = [0, 1, 2, 3, 4, 5, 6, 7, 8]
molsys = ['First-row', 'Fluorine']
types = ['aug-cc-pVnZ', 'd-aug-cc-pVnZ']

data = bdata.query(
    'valence.isin(@v_level) & omega.isin(@omegas) & mol_system.isin(@molsys) & Type.isin(@types) '
).copy()
# Drop the categories that the query filtered out.
for categorical_column in ('valence', 'Type', 'mol_system'):
    data[categorical_column] = data[categorical_column].cat.remove_unused_categories()
data['absolute_alphaE'] = data['alphaE'].abs()

facet_kw = {"sharey": True, "sharex": True, "margin_titles": True}
with sns.plotting_context('paper', font_scale=1.25):
    scatter_options = dict(kind='scatter', size='omega', legend=True, palette='Set1',
                           alpha=0.5)
    g = sns.relplot(data=data, x='eE', y='alphaE', col='Type', row='mol_system',
                    hue='valence', height=4, aspect=1, facet_kws=facet_kw,
                    **scatter_options)
    g.set_xlabels('Energy Error (a.u.)')
    g.set_ylabels(r'$\alpha$ Error (%)')
    g.set_titles(row_template='{row_name}', col_template='{col_name}')

# Flip the x axis of the current axes (the panels are created with sharex=True).
g.fig.gca().invert_xaxis()

#g.fig.savefig(paper_path.joinpath('alpha_energy_error_1.png'), dpi=600)

# remove the legend title

In [79]:
# Subset: second-row molecules at the D-5 valence levels (`v_level` is defined
# in the previous plotting cell) for omega values 0-8.
omegas = [0, 1, 2, 3, 4, 5, 6, 7, 8]
molsys = ['Second-row']
data = bdata.query(
    'valence.isin(@v_level) & omega.isin(@omegas) & mol_system.isin(@molsys) ').copy()
data.valence = data.valence.cat.remove_unused_categories()
data.mol_system = data.mol_system.cat.remove_unused_categories()
data['absolute_alphaE'] = data.alphaE.abs()

# relplot comparing energy error to alpha error
facet_kw = {"sharey": True, "sharex": True, "margin_titles": True}
with sns.plotting_context('paper', font_scale=1.25):
    g = sns.relplot(x='eE', y='alphaE', row='polarization', hue='valence', data=data,
                    kind='scatter', s=50,
                    row_order=['V', 'CV'],
                    col='augmentation',
                    size='omega',
                    palette='Set1', height=4, aspect=1, legend='full',
                    alpha=0.5, facet_kws=facet_kw)
    g.set_titles(row_template='{row_name}', col_template='{col_name}')
    for ax in g.axes.flat:
        # reference lines at +/- 0.05 on the alpha-error axis
        ax.axhline(y=.05, linestyle='--', color='orange')
        ax.axhline(y=-.05, linestyle='--', color='orange')
        # Show the x axis in descending order. The facets share x, so the
        # limits are sorted first: blindly reversing them once per facet (and
        # once more afterwards) made the final direction depend on whether the
        # number of facets is even or odd.
        x_low, x_high = sorted(ax.get_xlim())
        ax.set_xlim(x_high, x_low)
        # fixed, symmetric y range
        ax.set_ylim(-25, 25)
        # use a scalar formatter for the y axis
        ax.yaxis.set_major_formatter(ScalarFormatter())

#g.fig.savefig(paper_path.joinpath('alpha_energy_error_2.png'), dpi=600)



In [80]:
# Same second-row subset as the linear-scale figure, drawn with a symlog
# y axis and a log x axis (`v_level` is defined in an earlier plotting cell).
omegas = [0, 1, 2, 3, 4, 5, 6, 7, 8]
molsys = ['Second-row']
data = bdata.query(
    'valence.isin(@v_level) & omega.isin(@omegas) & mol_system.isin(@molsys) ').copy()
data.valence = data.valence.cat.remove_unused_categories()
data.mol_system = data.mol_system.cat.remove_unused_categories()
data['absolute_alphaE'] = data.alphaE.abs()

# relplot comparing energy error to alpha error
facet_kw = {"sharey": True, "sharex": True, "margin_titles": True}
with sns.plotting_context('paper', font_scale=1.25):
    g = sns.relplot(x='eE', y='alphaE', row='polarization', hue='valence', data=data,
                    kind='scatter', s=50,
                    row_order=['V', 'CV'],
                    col='augmentation',
                    size='omega',
                    palette='Set1', height=4, aspect=1, legend='full',
                    alpha=0.5, facet_kws=facet_kw)
    g.set_titles(row_template='{row_name}', col_template='{col_name}')
    for ax in g.axes.flat:
        ax.set_yscale('symlog', linthresh=1e-1)
        ax.set_xscale('log')
        # reference lines at +/- 0.02 on the alpha-error axis
        ax.axhline(y=.02, linestyle='--', color='orange')
        ax.axhline(y=-.02, linestyle='--', color='orange')
        # Show the x axis in descending order. The facets share x, so the
        # limits are sorted first: blindly reversing them once per facet (and
        # once more afterwards) made the final direction depend on whether the
        # number of facets is even or odd.
        x_low, x_high = sorted(ax.get_xlim())
        ax.set_xlim(x_high, x_low)
        # fixed, symmetric y range
        ax.set_ylim(-25, 25)
        # use a scalar formatter for the y axis
        ax.yaxis.set_major_formatter(ScalarFormatter())

#g.fig.savefig(paper_path.joinpath('alpha_energy_error_2.png'), dpi=600)


In [81]:
# Display the current `data` frame for a quick visual check.
data

In [82]:
# Display the `basis_data` frame for a quick visual check.
basis_data

In [21]:
# Summary statistics (count, mean, std, min, quartiles, max) of alphaE at
# omega == 0, one row per basis set.
per_basis = []
for basis in basis_data.basis.unique():
    # rows of this basis set at omega == 0
    subset = basis_data.query('basis==@basis & omega==0')
    # describe only the column that is kept, labelled with the basis name
    per_basis.append(subset['alphaE'].describe().rename(basis))
# One concat at the end instead of growing the frame inside the loop.
iso_stats = pd.concat(per_basis, axis=1).T if per_basis else pd.DataFrame()



In [83]:
def basis_stats(iso_type, omega):
    """Summarise one error column of ``basis_data`` per basis set.

    Parameters
    ----------
    iso_type : str
        Column to summarise, e.g. ``'alphaE'`` or ``'gammaE'``.
    omega : scalar
        Only rows whose ``omega`` column equals this value are used.

    Returns
    -------
    pandas.DataFrame
        One row per basis set, with the columns produced by
        ``Series.describe`` (count, mean, std, min, quartiles, max). The two
        6Z basis sets are removed when present. Empty if ``basis_data`` has
        no basis sets.
    """
    per_basis = []
    for basis in basis_data.basis.unique():
        # rows of this basis set at the requested omega
        subset = basis_data.query('basis==@basis & omega==@omega')
        if iso_type == 'gammaE':
            # keep only rows whose |gamma| exceeds 1e-3 (the filter is on the
            # `gamma` column, not on `gammaE`)
            subset = subset.query('gamma.abs()>1e-3')
        # describe only the requested column and label it with the basis name
        per_basis.append(subset[iso_type].describe().rename(basis))
    if not per_basis:
        return pd.DataFrame()
    # single concat instead of growing a frame inside the loop
    iso_stats = pd.concat(per_basis, axis=1).T
    # remove the 6Z basis sets; tolerate data sets that do not contain them
    return iso_stats.drop(['aug-cc-pV6Z', 'd-aug-cc-pV6Z'], axis=0, errors='ignore')


def basis_e_stats():
    """Summarise the ``energyE`` column of ``basis_e_data`` per basis set.

    Returns
    -------
    pandas.DataFrame
        One row per basis set, with the columns produced by
        ``Series.describe`` (count, mean, std, min, quartiles, max). The two
        6Z basis sets are removed when present. Empty if ``basis_e_data`` has
        no basis sets.
    """
    per_basis = []
    for basis in basis_e_data.basis.unique():
        # rows of this basis set
        subset = basis_e_data.query('basis==@basis')
        # describe only the energy error and label it with the basis name
        per_basis.append(subset['energyE'].describe().rename(basis))
    if not per_basis:
        return pd.DataFrame()
    # single concat instead of growing a frame inside the loop
    iso_stats = pd.concat(per_basis, axis=1).T
    # remove the 6Z basis sets; tolerate data sets that do not contain them
    return iso_stats.drop(['aug-cc-pV6Z', 'd-aug-cc-pV6Z'], axis=0, errors='ignore')


def get_iso_mols(iso_type, omega):
    """List, per basis set, the molecules that have complete data at ``omega``.

    Parameters
    ----------
    iso_type : str
        When equal to ``'gammaE'``, rows whose ``|gamma|`` does not exceed
        1e-3 are excluded; any other value applies no extra filter.
    omega : scalar
        Only rows whose ``omega`` column equals this value are used.

    Returns
    -------
    pandas.DataFrame
        One column per basis set holding the unique molecules of the rows
        that contain no missing value; shorter columns are padded with NaN.
        The two 6Z basis sets are removed when present. Empty if
        ``basis_data`` has no basis sets.
    """
    per_basis = []
    for basis in basis_data.basis.unique():
        # rows of this basis set at the requested omega
        subset = basis_data.query('basis==@basis & omega==@omega')
        if iso_type == 'gammaE':
            # keep only rows whose |gamma| exceeds 1e-3 (the filter is on the
            # `gamma` column, not on `gammaE`)
            subset = subset.query('gamma.abs()>1e-3')
        # molecules for which every column is populated
        per_basis.append(pd.Series(subset.dropna().molecule.unique(), name=basis))
    if not per_basis:
        return pd.DataFrame()
    # Single concat; dropping columns directly replaces the former
    # transpose / drop rows / transpose round trip. The cast to object keeps
    # the dtype that round trip produced.
    iso_mols = pd.concat(per_basis, axis=1).astype(object)
    # remove the 6Z basis sets; tolerate data sets that do not contain them
    return iso_mols.drop(columns=['aug-cc-pV6Z', 'd-aug-cc-pV6Z'], errors='ignore')



In [84]:
# Per-basis summary statistics of the energy error and, at omega == 0, of the
# alpha and gamma errors, plus the molecules available for each basis set.
error_columns = ('alphaE', 'gammaE')
e_stats = basis_e_stats()
iso_0, gamma_0 = (basis_stats(error_column, 0) for error_column in error_columns)
alpha_mols, gamma_mols = (get_iso_mols(error_column, 0) for error_column in error_columns)



In [85]:
# Row order for the summary tables: grouped by cardinal number (D, T, Q, 5)
# and, within each group, valence / core-valence for single then double
# augmentation.
basis_families = ['aug-cc-pV', 'aug-cc-pCV', 'd-aug-cc-pV', 'd-aug-cc-pCV']
DZ = [f'{family}DZ' for family in basis_families]
TZ = [f'{family}TZ' for family in basis_families]
QZ = [f'{family}QZ' for family in basis_families]
fiveZ = ['aug-cc-pV5Z', 'd-aug-cc-pV5Z']
new_order = [*DZ, *TZ, *QZ, *fiveZ]

# Apply the same ordering to all three statistics tables.
iso_0, gamma_0, e_stats = (table.reindex(new_order) for table in (iso_0, gamma_0, e_stats))
iso_0




In [86]:
def diff_mol_subset(df):
    """Tabulate the molecules of ``basis_data`` that each basis set is missing.

    Parameters
    ----------
    df : pandas.DataFrame
        One column per basis set listing the molecules available for it
        (NaN padding is allowed), as returned by ``get_iso_mols``.

    Returns
    -------
    pandas.DataFrame
        One column per basis set that misses at least one molecule; columns
        with identical content are collapsed to the first one. Entries are
        wrapped as ``\\ce{...}`` and sorted within each column, NaN padding
        last. Empty if no basis set misses anything.
    """
    full_set = set(basis_data.molecule.unique())
    missing_by_basis = []
    for basis in df.columns:
        # molecules present in the full data but absent for this basis set
        basis_difference = full_set.difference(set(df[basis]))
        if basis_difference:
            # Sort before storing so two basis sets missing the same molecules
            # produce identical columns whatever the set iteration order.
            missing_by_basis.append(
                pd.Series(sorted(basis_difference, key=str), name=basis))
    if not missing_by_basis:
        return pd.DataFrame()
    # single concat instead of growing a frame inside the loop
    diff_mols = pd.concat(missing_by_basis, axis=1)
    # drop basis sets whose missing-molecule list duplicates an earlier one
    not_available = diff_mols.T.drop_duplicates().T
    # Wrap every formula in \ce{}, leaving the NaN padding untouched.
    # Series.map is used column-wise because DataFrame.applymap is deprecated.
    not_available = not_available.apply(
        lambda col: col.astype(object).map(lambda x: '\\ce{' + str(x) + '}',
                                           na_action='ignore'))
    # sort the values of each column alphabetically
    not_available = not_available.apply(lambda x: x.sort_values().values)
    return not_available


# Molecules (as \ce{} formulas) that each basis set lacks in `alpha_mols`.
polar_na = diff_mol_subset(alpha_mols)

# concat to the full dataframe


In [87]:
# Molecules (as \ce{} formulas) that each basis set lacks in `gamma_mols`.
gamma_na = diff_mol_subset(gamma_mols)


In [88]:
# Export the missing-molecule table to LaTeX; empty cells are written as '-'.
not_available_tex = paper_path / 'not_available.tex'
polar_na.to_latex(not_available_tex, na_rep='-')

In [89]:
# apply a symlog color map to the all columns except count
import matplotlib.cm as cm
import matplotlib.colors as mcolors


def background_with_norm(s, vmax=1e-1):
    """Return one CSS ``background-color`` rule per value of ``s``.

    Values are normalised with a symmetric-log norm spanning
    ``[-vmax, vmax]`` (linear threshold 2e-3, linear scale 0.10, base 10) and
    coloured with the ``bwr`` colormap.

    Parameters
    ----------
    s : pandas.Series
        Values to colour.
    vmax : float
        Magnitude at which the norm saturates.

    Returns
    -------
    list of str
        Strings of the form ``'background-color: #rrggbb'``, in the order of
        ``s``.
    """
    symlog = mcolors.SymLogNorm(linthresh=2e-3, linscale=.10, base=10,
                                vmin=-vmax, vmax=vmax)
    rgba_rows = cm.bwr(symlog(s.values))
    return [f'background-color: {mcolors.to_hex(rgba.flatten())}' for rgba in rgba_rows]


def __style_summary(df, fmt="{:.2e}", vmax=20):
    """Build the Styler used to export a statistics table to LaTeX.

    Parameters
    ----------
    df : pandas.DataFrame
        Statistics table whose first column is formatted as an integer; all
        remaining columns are formatted with ``fmt``.
    fmt : str
        Format string applied to every column except the first.
    vmax : float
        Accepted for backward compatibility and currently unused: it only fed
        the symlog background shading, which is disabled.

    Returns
    -------
    pandas.io.formats.style.Styler
        Styler over a renamed copy of ``df``; the caller's frame is left
        untouched.
    """
    # Escape a trailing '%' in column labels (e.g. '25%' -> '25\%') so the
    # header survives LaTeX compilation.
    escaped = {
        label: label[:-1] + '\\' + label[-1]
        for label in df.columns
        if isinstance(label, str) and label.endswith('%')
    }
    df = df.rename(columns=escaped)

    # every column except the first one uses `fmt`
    styled_df = df.style.format(fmt, subset=df.columns[1:])
    # the first column is shown without decimals
    styled_df = styled_df.format("{:.0f}", subset=df.columns[0])
    return styled_df

# make a latex table of the statistics


In [90]:

# Write the alpha-error statistics table to LaTeX.
latex_options = dict(hrules=True, convert_css=True, multicol_align='|c|', siunitx=True)
style_iso = __style_summary(iso_0, fmt="{:.2e}", vmax=20)
style_iso.to_latex(paper_path / 'alphaE_stats.tex', **latex_options)


In [91]:
# Recompute the gamma-error statistics and put the rows in the same order as
# the alpha and energy tables; without the reindex this cell silently undid
# the ordering applied earlier via `new_order`.
gamma_0 = basis_stats('gammaE', 0).reindex(new_order)

style_gamma = __style_summary(gamma_0, fmt="{:.2e}", vmax=20)
# write the gamma-error statistics table to LaTeX
style_gamma.to_latex(paper_path.joinpath('gammaE_stats.tex'),
                     hrules=True,
                     convert_css=True,
                     multicol_align='|c|',
                     siunitx=True,
                     )


In [92]:

# Write the energy-error statistics table to LaTeX.
# NOTE: this rebinds `style_iso`, which previously held the alpha-error styler.
latex_options = dict(hrules=True, convert_css=True, multicol_align='|c|', siunitx=True)
style_iso = __style_summary(e_stats, fmt="{:.2e}", vmax=.05)
style_iso.to_latex(paper_path / 'energy_stats.tex', **latex_options)


In [93]:

# Keep only the MRA rows of the isotropic data, as an independent copy.
mra_data = august_database.iso_data.query('basis=="MRA"').copy()

In [94]:
# Unique (molecule, mol_system) pairs taken from the aug-cc-pVDZ rows at
# omega == 0.
dz_rows = basis_data.query('basis=="aug-cc-pVDZ" & omega==0')
molecule_subset = dz_rows[['molecule', 'mol_system']].drop_duplicates()

In [95]:
# Display the (molecule, mol_system) pairs for a quick visual check.
molecule_subset

In [96]:

def _as_ce(formula):
    """Wrap a molecular formula in the LaTeX ``\\ce{...}`` macro."""
    return '\\ce{' + str(formula) + '}'


# Group the molecules by chemical system. The name is rebound instead of
# sorting in place.
molecule_subset = molecule_subset.sort_values('mol_system')

# --- one column of molecules per system (printed for reference) ---
system_names = ['First-row', 'Fluorine', 'Second-row']
first_row, fluorine, second_row = (
    pd.Series(
        molecule_subset[molecule_subset['mol_system'] == system]['molecule'].unique(),
        name=system,
    )
    for system in system_names
)
mol_table = pd.concat([first_row, fluorine, second_row], axis=1)
# Wrap the formulas first and fill the padding afterwards; filling first
# turned every '-' placeholder into '\ce{-}'. Series.map is used column-wise
# because DataFrame.applymap is deprecated.
mol_table = mol_table.apply(
    lambda col: col.astype(object).map(_as_ce, na_action='ignore')
).fillna('-')
print(mol_table)

# --- one colour per system ---
pal = sns.color_palette('muted', n_colors=3)
colormap = dict(zip(mol_table.columns, pal.as_hex()))

# --- grid of all molecules, nine per row, numbered from 1 ---
mol_only = molecule_subset['molecule']
# -1 lets the row count follow the number of molecules (10 rows for 90) while
# still requiring a multiple of nine.
mol_df = pd.DataFrame(mol_only.to_numpy().reshape((-1, 9)))
mol_df.index += 1
mol_df.columns += 1

# Map each \ce{}-wrapped formula to the colour of its system.
ms = molecule_subset.molecule.apply(_as_ce)
mol_dict = dict(zip(ms, molecule_subset.mol_system))
mol_color = {mol: colormap[mol_dict[mol]] for mol in ms}

# Wrap the grid entries the same way, then colour every cell by system.
mol_df = mol_df.apply(lambda col: col.map(_as_ce))
# NOTE: from here on `mol_df` is a Styler, not a DataFrame.
mol_df = mol_df.style.apply(lambda x: [f'background-color: {mol_color[v]}' for v in x],
                            axis=1)
# write the coloured grid to a latex table
mol_df.to_latex(paper_path.joinpath('molecule_table.tex'),
                multicol_align='|c|',
                hrules=True,
                convert_css=True,
                )


In [97]:
# Display the molecules ordered by chemical system.
molecule_subset.sort_values('mol_system')