# Import packages

In [None]:
# Standard packages
import sys
# When installing new packages uncomment the following line:
# !{sys.executable} -m pip install shap
import os
import numpy as np
from glob import glob
import re
import pandas as pd

# sklearn
from sklearn.metrics import r2_score,mean_squared_error,mean_absolute_error,mean_absolute_percentage_error
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
from sklearn.kernel_ridge import KernelRidge
from sklearn.model_selection import GridSearchCV

# SHAP
import shap

#Plotting
import seaborn as sns
sns.set_style()
import matplotlib as mpl
import matplotlib.pyplot as plt
mpl.rcParams['figure.dpi'] = 200

# Silence NumPy's VisibleDeprecationWarning.
# The class is exposed as ``np.VisibleDeprecationWarning`` before NumPy 2.0 and
# only as ``np.exceptions.VisibleDeprecationWarning`` from NumPy 2.0 onwards,
# so look it up in a way that works on both.
import warnings
warnings.filterwarnings(
    "ignore",
    category=getattr(np, "VisibleDeprecationWarning", None)
    or np.exceptions.VisibleDeprecationWarning,
)

# Timing 
from time import perf_counter
from tqdm.auto import tqdm

# Set options

In [None]:
# Notebook-wide options:
#   savefiles -- when True, figures and tables are written to disk
#   basis     -- basis-set subdirectory whose data is modelled and plotted
savefiles, basis = False, 'cc-PVDZ'

# Functions to help generate the data

In [None]:
def gen_df(path):
    '''
    Parse one output file into a single-column Pandas dataframe.

    The column label is the radius: the first ``<digits>.<digits>`` number
    found in the last ``_``-separated piece of the file name
    (e.g. ``..._1.50.out`` -> 1.5).

    Every non-blank line whose last whitespace-separated token contains
    exactly one number contributes a row:
      * lines containing 'DETCI' or 'HF' are stored under 'DETCI@energy'
        (a later matching line overwrites an earlier one);
      * lines containing 'v2RDM' are stored under the concatenation of their
        first two tokens (e.g. 'v2RDM @energy -0.9' -> 'v2RDM@energy').

    returns Pandas dataframe (index: quantity names, columns: [radius])
    '''
    file_name = os.path.basename(os.path.normpath(path))
    radius = float(re.findall(r"\d+\.\d+", file_name.split('_')[-1].split()[0])[0])

    # Compiled once; raw string avoids invalid-escape warnings on '\d'.
    number = re.compile(r'-?\d+\.?\d*')
    data = {}
    with open(path, 'r') as f:
        for line in f:
            tokens = line.split()
            if not tokens:
                continue
            found = number.findall(tokens[-1])
            # Only lines whose final token holds a single number are values.
            if len(found) != 1:
                continue
            value = float(found[0])
            if 'DETCI' in line or 'HF' in line:
                data['DETCI@energy'] = value
            if 'v2RDM' in line:
                data[''.join(tokens[0:2])] = value

    return pd.DataFrame.from_dict(data, orient='index', columns=[radius])

def collate_data(paths):
    '''
    Merge the ``gen_df`` frames of every file in *paths* into one dataframe.

    Columns (radii) are sorted ascending. Any repeated row label (quantity
    name) is reduced to its first occurrence.
    NOTE(review): repeated radius columns are NOT removed here.
    '''
    frames = [gen_df(p) for p in paths]
    by_radius = pd.concat(frames, axis=1).sort_index(axis=1).T
    first_seen = ~by_radius.columns.duplicated()
    return by_radius.loc[:, first_seen].T

def Energies(data):
    '''
    Return the energy rows for the files in *data*.

    returns (DETCI@energy Series, v2RDM@energy Series), both indexed by radius
    '''
    merged = collate_data(data)
    detci_energy = merged.loc['DETCI@energy']
    v2rdm_energy = merged.loc['v2RDM@energy']
    return detci_energy, v2rdm_energy

def genXy(data):
    '''
    Separate the target (y) from the features (X) for the files in *data*.

    y -- DETCI@energy minus v2RDM@energy, indexed by radius
    X -- all remaining rows, transposed to one row per radius, after dropping
         the two energies and v2RDM@Nalpha/Nbeta/Nact/S2
    '''
    merged = collate_data(data)
    target = merged.loc['DETCI@energy'] - merged.loc['v2RDM@energy']
    excluded = ['DETCI@energy', 'v2RDM@energy', 'v2RDM@Nalpha',
                'v2RDM@Nbeta', 'v2RDM@Nact', 'v2RDM@S2']
    features = merged.drop(index=excluded).T
    return features, target

In [None]:
# How the data is partitioned.
# os.path.basename (rather than splitting on '/') also handles the backslash
# separators glob returns on Windows.
systems = sorted(os.path.basename(p) for p in glob('./singlet_doublet/*/*'))
spins = ['singlet_doublet', 'triplet_quartet']
basissets = ['STO-3G', '6-31G', 'cc-PVDZ']

# For every (system, basis) pair, the spin manifold with the lower minimum
# DETCI energy is labelled the ground state and the other the excited state.
groundstates = []
excitedstates = []
for system in systems:
    # The loop variable is `bs`, not `basis`: reusing `basis` would overwrite
    # the basis chosen in the options cell and every later `bs == basis` filter
    # would silently use the last entry of `basissets`.
    for bs in basissets:
        sd_DETCI_E, _ = Energies(glob(f'./{spins[0]}/v2rdm_ontop_fci_eval_t1t2/{system}/{bs}/*'))
        tq_DETCI_E, _ = Energies(glob(f'./{spins[1]}/v2rdm_ontop_fci_eval_t1t2/{system}/{bs}/*'))
        if sd_DETCI_E.min() < tq_DETCI_E.min():
            groundstates.append(('singlet_doublet', system, bs))
            excitedstates.append(('triplet_quartet', system, bs))
        else:
            groundstates.append(('triplet_quartet', system, bs))
            excitedstates.append(('singlet_doublet', system, bs))

# Build (X, y) for every state computed in the selected basis set,
# keyed '<system>_<spin manifold>'.
rads = []
for st, sy, bs in tqdm(groundstates + excitedstates, desc='Data', position=0, leave=True):
    if bs == basis:
        rads.append((sy + '_' + st, genXy(glob(f'./{st}/v2rdm_ontop_fci_eval_t1t2/{sy}/{bs}/*'))))

system_dict = dict(rads)

# Remove data based on violations: a state whose mean RMS T1 or T2 violation
# is <= 1e-6 is moved from system_dict into `popped`.
violations = ['v2RDM@RMSViolation(T1)', 'v2RDM@RMSViolation(T2)']
popped = {}
for k in list(system_dict):
    if (system_dict[k][0][violations].mean() <= 1e-6).any():
        popped[k] = system_dict.pop(k)


# Machine learning

In [None]:
# Split data into train and testing: within each state, even-position radii
# go to the training set and odd-position radii to the test set.
key_train = {}
key_test = {}
X_train = []
y_train = []

X_test = []
y_test = []

for k, (X, y) in system_dict.items():
    train, test = list(y.index[0::2]), list(y.index[1::2])
    key_train[k] = train
    key_test[k] = test
    X_train.append(X.loc[train])
    # to_frame(name=k) labels the column with the state key whatever the
    # Series' own name is.
    y_train.append(y.loc[train].to_frame(name=k))
    X_test.append(X.loc[test])
    y_test.append(y.loc[test].to_frame(name=k))

X_train = pd.concat(X_train)
ytraindf = pd.concat(y_train, axis=1)

X_test = pd.concat(X_test)
ytestdf = pd.concat(y_test, axis=1)

# Flatten the (radius x state) target tables state by state. dropna() removes
# the NaN padding from the outer-join concat; older pandas stack() dropped it
# implicitly, newer stack() implementations keep it.
ytrain_flat = ytraindf.T.stack().dropna()
ytest_flat = ytestdf.T.stack().dropna()

# The flattened targets must line up row-for-row with the concatenated
# features (both are state-major); fail loudly instead of fitting mismatched
# (X, y) pairs.
for _split, _feats, _targets in (('train', X_train, ytrain_flat), ('test', X_test, ytest_flat)):
    if not np.array_equal(_feats.index.to_numpy(), _targets.index.get_level_values(-1).to_numpy()):
        raise ValueError(f'{_split} features and targets are misaligned')

trainidx = X_train.index
traincol = X_train.columns

testidx = X_test.index
testcol = X_test.columns

# Scale the features (scaler fitted on the training set only)
scaler = MinMaxScaler()
X_train = pd.DataFrame(scaler.fit_transform(X_train), index=trainidx, columns=traincol)
X_test = pd.DataFrame(scaler.transform(X_test), index=testidx, columns=testcol)
y_train = ytrain_flat.to_numpy()
y_test = ytest_flat.to_numpy()

# Find optimal hyperparameters for KRR.
parameters = {'kernel': ['rbf'], 'alpha': np.logspace(-6, 6, 7), 'gamma': np.logspace(-6, 6, 7)}
GridSearch = GridSearchCV(KernelRidge(), param_grid=parameters, cv=5, verbose=0).fit(X_train, y_train)
# With refit=True (the default) best_estimator_ has already been refit on the
# full training set, so no second fit is needed.
model = GridSearch.best_estimator_

# Predicted values
train_pred = model.predict(X_train)
test_pred = model.predict(X_test)

# Evaluation metrics pooled over all states. 'Train'/'Test' are the total
# numbers of pooled samples, not one state's split sizes.
# RMSE is computed as sqrt(MSE) because the `squared=` keyword is deprecated
# and removed in recent scikit-learn releases.
stats = {'Train': len(y_train),
         'Test': len(y_test),
         'Train MAPE': mean_absolute_percentage_error(y_train, train_pred) * 100,
         'Test MAPE': mean_absolute_percentage_error(y_test, test_pred) * 100,
         'Train R2': r2_score(y_train, train_pred),
         'Test R2': r2_score(y_test, test_pred),
         'Train MAE': mean_absolute_error(y_train, train_pred),
         'Test MAE': mean_absolute_error(y_test, test_pred),
         'Train RMSE': np.sqrt(mean_squared_error(y_train, train_pred)),
         'Test RMSE': np.sqrt(mean_squared_error(y_test, test_pred))}

df_stats = pd.DataFrame.from_dict(stats, orient='index')

if savefiles:
    df_stats.rename(columns={0: 'Stats'}).to_excel(f'{basis}_bigstackstats.xlsx')

# Machine Learning Analysis

In [None]:
# Plotting colors: two-colour 'rocket' palette; entry 0 is used for
# ground-state / train plots and entry 1 for excited-state / test plots.
colormap = sns.color_palette('rocket', n_colors=2)

In [None]:
# A collection of data used for plotting
# GS - ground state
# ES - excited state
Y = pd.concat([pd.DataFrame(y, columns=[k]) for k, (X, y) in system_dict.items()], axis=1)

# Reshape the flat prediction arrays back into (radius x state) tables.
# dropna() mirrors how the flattened targets were built: newer pandas stack()
# keeps NaN entries, which would make the index longer than the predictions.
ytrain_preddf = pd.DataFrame(train_pred, index=ytraindf.T.stack().dropna().index).unstack().T.droplevel(0)
ytest_preddf = pd.DataFrame(test_pred, index=ytestdf.T.stack().dropna().index).unstack().T.droplevel(0)

GS_df = pd.concat([Y[f'{nam}_{st}'] for (st, nam, bs) in groundstates if bs == basis and f'{nam}_{st}' in Y.columns], axis=1)
ES_df = pd.concat([Y[f'{nam}_{st}'] for (st, nam, bs) in excitedstates if bs == basis and f'{nam}_{st}' in Y.columns], axis=1)
all_df = pd.concat([GS_df, ES_df], axis=1).sort_index(axis=1)

# Reformat the data: split targets and predictions into GS / ES column sets.
GS_ytrain_df = ytraindf[GS_df.columns]
ES_ytrain_df = ytraindf[ES_df.columns]
GS_ytrain_preddf = ytrain_preddf[GS_df.columns]
ES_ytrain_preddf = ytrain_preddf[ES_df.columns]

GS_ytest_df = ytestdf[GS_df.columns]
ES_ytest_df = ytestdf[ES_df.columns]
GS_ytest_preddf = ytest_preddf[GS_df.columns]
ES_ytest_preddf = ytest_preddf[ES_df.columns]

# The calculated energies ('FCI' holds the DETCI energies), keyed by system.
GS_energies = {nam: dict(zip(['FCI', 'v2RDM'], Energies(glob(f'./{st}/v2rdm_ontop_fci_eval_t1t2/{nam}/{bs}/*')))) for (st, nam, bs) in groundstates if bs == basis}
ES_energies = {nam: dict(zip(['FCI', 'v2RDM'], Energies(glob(f'./{st}/v2rdm_ontop_fci_eval_t1t2/{nam}/{bs}/*')))) for (st, nam, bs) in excitedstates if bs == basis}


In [None]:
# Plot machine learning results versus the v2RDM and FCI error... FCI error is 0
fontsize = 10
plt.rcParams.update({'font.size': fontsize})
plt.rc('font', size=fontsize)          # controls default text sizes
plt.rc('axes', titlesize=fontsize)     # fontsize of the axes title
plt.rc('axes', labelsize=fontsize)     # fontsize of the x and y labels
plt.rc('xtick', labelsize=fontsize)    # fontsize of the tick labels
plt.rc('ytick', labelsize=fontsize)    # fontsize of the tick labels
plt.rc('legend', fontsize=fontsize)    # legend fontsize
plt.rc('figure', titlesize=fontsize)   # fontsize of the figure title

_SPIN_SUFFIXES = ('_singlet_doublet', '_triplet_quartet')


def _system_name(state_key):
    '''Return the system part of a '<system>_<spin manifold>' key.'''
    for suffix in _SPIN_SUFFIXES:
        if state_key.endswith(suffix):
            return state_key[:-len(suffix)]
    return state_key


def _plot_state(ax, col, target_df, train_pred_df, test_pred_df, ci, rdm, color):
    '''
    Draw one state's E_CI - E_v2RDM curve and the train/test residuals
    E_CI - (predicted correction + E_v2RDM) on *ax*.
    '''
    train = key_train[col]
    test = key_test[col]
    ax.plot(target_df[col], 'k--', label='E$_{CI}$-E$_{v2RDM}$')
    # Zero reference line spanning the axis in bond-length units.
    ax.axhline(0, color='k', linestyle='-')
    ax.plot(ci.loc[train] - (train_pred_df[col] + rdm.loc[train]), 'o', color=color, label='E$_{CI}$-E$_{DDv2RDM}$ (Train)')
    ax.plot(ci.loc[test] - (test_pred_df[col] + rdm.loc[test]), 'x', color=color, label='E$_{CI}$-E$_{DDv2RDM}$ (Test)')
    ax.set_xlim(.5, 3)
    ax.legend(loc=2)
    ax.set_xlabel('Bond Length (Å)')


for idx, nam in enumerate(GS_df.columns):
    fig, ax = plt.subplots(1, 2, figsize=(10, 5), sharey=True, sharex=True)
    i = _system_name(nam)

    GS_CI, GS_rdm = GS_energies[i]['FCI'], GS_energies[i]['v2RDM']
    ES_CI, ES_rdm = ES_energies[i]['FCI'], ES_energies[i]['v2RDM']

    # Exact system match: a substring test such as f'{i}_' in j would also
    # match other systems (e.g. 'LiH2_...' when i == 'H2').
    for j in GS_df.columns:
        if _system_name(j) == i:
            _plot_state(ax[0], j, GS_df, GS_ytrain_preddf, GS_ytest_preddf, GS_CI, GS_rdm, colormap[0])
            ax[0].set_ylim(-0.01, 0.15)
            ax[0].set_ylabel('Deviation (E$_{h}$)')
            ax[0].set_title(f'{i} Ground Spin State')

    for j in ES_df.columns:
        if _system_name(j) == i:
            _plot_state(ax[1], j, ES_df, ES_ytrain_preddf, ES_ytest_preddf, ES_CI, ES_rdm, colormap[1])
            ax[1].set_title(f'{i} Excited Spin State')

    plt.tight_layout()
    if savefiles:
        plt.savefig(f'{basis}_bigstack_stacked_{idx}.png', dpi=300, bbox_inches='tight')
    plt.show()

In [None]:
# plot regression parity plots
from mpl_toolkits.axes_grid1 import make_axes_locatable
fontsize = 16
plt.rcParams.update({'font.size': fontsize})
plt.rc('font', size=fontsize)          # controls default text sizes
plt.rc('axes', titlesize=fontsize)     # fontsize of the axes title
plt.rc('axes', labelsize=fontsize)     # fontsize of the x and y labels
plt.rc('xtick', labelsize=fontsize)    # fontsize of the tick labels
plt.rc('ytick', labelsize=fontsize)    # fontsize of the tick labels
plt.rc('legend', fontsize=fontsize)    # legend fontsize
plt.rc('figure', titlesize=fontsize)   # fontsize of the figure title

train_plot = pd.DataFrame([y_train, train_pred], index=['True Train', 'Predicted Train']).T
test_plot = pd.DataFrame([y_test, test_pred], index=['True Test', 'Predicted Test']).T


def _parity_panel(ax, true, pred, color, title, binwidth=0.01, minx=-5e-3):
    '''
    Scatter *pred* against *true* on *ax*, annotate R^2 and RMSE, and attach
    marginal density histograms of the true (top) and predicted (right) values.
    '''
    ax.scatter(true, pred, color=color, edgecolors='k')
    # RMSE as sqrt(MSE): the `squared=` keyword is removed in recent scikit-learn.
    rmse = np.sqrt(mean_squared_error(true, pred))
    # The annotation is positioned from this panel's own data.
    ax.text(true.min(), true.max() - true.quantile(0.5),
            r"R$^{2}$=" + f"{r2_score(true, pred):.4f}\nRMSE={rmse:.4e}" + r" E$_{h}$")

    maxx = max(true) + max(true) / 10
    ax.set_xlim(minx, maxx)
    ax.set_ylim(minx, maxx)
    ax.set_xticks(ax.get_xticks())
    ax.set_yticks(ax.get_yticks())
    ax.set_xlabel('True Target Value (E$_{h}$)')
    ax.set_ylabel('Predicted Target Value (E$_{h}$)')

    # The histogram axes share x / y with the scatter axes, so their limits
    # follow the scatter panel automatically.
    divider = make_axes_locatable(ax)
    ax_histx = divider.append_axes("top", 1.2, pad=0.4, sharex=ax)
    ax_histy = divider.append_axes("right", 1.2, pad=0.4, sharey=ax)
    # Hide tick labels that would duplicate the scatter axes.
    ax_histx.xaxis.set_tick_params(labelbottom=False)
    ax_histy.yaxis.set_tick_params(labelleft=False)

    # Symmetric bins of width `binwidth` covering the largest |value|.
    xymax = max(np.max(np.abs(true)), np.max(np.abs(pred)))
    lim = (int(xymax / binwidth) + 1) * binwidth
    bins = np.arange(-lim, lim + binwidth, binwidth)
    ax_histx.hist(true, bins=bins, color=color, density=True, edgecolor='k', stacked=True)
    ax_histy.hist(pred, bins=bins, orientation='horizontal', color=color, density=True, edgecolor='k', stacked=True)

    ax_histx.set_ylabel('Probability', fontsize=fontsize - 2)
    ax_histx.set_title(title)
    ax_histy.yaxis.set_label_position("right")
    ax_histy.set_xlabel('Probability', fontsize=fontsize - 2)
    ax_histx.set_yticks(np.linspace(0, 70, 3))
    ax_histy.set_xticks(np.linspace(0, 70, 3))


fig, axes = plt.subplots(figsize=(15, 8), nrows=1, ncols=2)
_parity_panel(axes[0], train_plot['True Train'], train_plot['Predicted Train'], colormap[0], 'Train')
_parity_panel(axes[1], test_plot['True Test'], test_plot['Predicted Test'], colormap[1], 'Test')

plt.tight_layout(pad=1, w_pad=1, h_pad=1.0)
if savefiles:
    plt.savefig(f'{basis}_bigstack_error.png', dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Grab spins: parse the first output file of every spin/system/basis
# combination once, and read S2, Nalpha and Nbeta from that single parse.
_first_frames = {sp: {sy: {bs: gen_df(glob(f'./{sp}/v2rdm_ontop_fci_eval_t1t2/{sy}/{bs}/*')[0])
                           for bs in basissets}
                      for sy in systems}
                 for sp in spins}


def _row_dict(row):
    '''Nested {spin: {system: {basis: array}}} of one row of the parsed frames.'''
    return {sp: {sy: {bs: frame.loc[row].to_numpy() for bs, frame in by_bs.items()}
                 for sy, by_bs in by_sy.items()}
            for sp, by_sy in _first_frames.items()}


s2_dict = _row_dict('v2RDM@S2')
alpha_dict = _row_dict('v2RDM@Nalpha')
beta_dict = _row_dict('v2RDM@Nbeta')


def _spin_table(states):
    '''One row per (spin_dir, system, basis) with S = (Nalpha - Nbeta) / 2 and 2S+1.'''
    rows = []
    for st, sy, bs in tqdm(states, desc='Data', position=0, leave=True):
        # Each array holds a single value (gen_df returns one column), so
        # .item() extracts it; round() guards against float noise before the
        # integer conversion.
        unpaired = round((alpha_dict[st][sy][bs] - beta_dict[st][sy][bs]).item())
        rows.append((st, sy, bs, 0.5 * unpaired))
    table = pd.DataFrame(rows, columns=['spin_dir', 'system', 'basis', 'S'])
    table['2S+1'] = 2 * table['S'] + 1
    return table


gs_df = _spin_table(groundstates)
es_df = _spin_table(excitedstates)
sp_df = pd.concat([gs_df, es_df])

In [None]:
# Spread of the target values for the ground and excited state molecules


def _spread_panel(ax, frame, color, title):
    '''
    Plot the per-radius max / mean / min of *frame* (rows: radii, columns:
    states) across states, shading min-mean and mean-max.
    '''
    radii = frame.index
    lo, mid, hi = frame.min(axis=1), frame.mean(axis=1), frame.max(axis=1)
    ax.plot(radii, hi, '-x', label='Max', color=color)
    ax.plot(radii, mid, '-', label='Mean', color=color)
    ax.plot(radii, lo, '-^', label='Min', color=color)
    ax.fill_between(radii, lo, mid, alpha=0.4, color=color)
    ax.fill_between(radii, mid, hi, alpha=0.4, color=color)
    ax.set_xticks(np.linspace(0.5, 4, 15))
    ax.set_xlim(0.5, 3)
    # Raw string: '\A' is not a valid escape sequence in a normal literal.
    ax.set_xlabel(r'Bond Length ($\AA$)')
    ax.set_title(title)
    ax.legend()


fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5), sharey=True)
_spread_panel(ax1, GS_df, colormap[0], 'Ground Spin States')
ax1.set_ylim(-0.001, 0.10)
ax1.set_ylabel('E$_{CI}$-E$_{v2RDM}$ (E$_{h}$)')
_spread_panel(ax2, ES_df, colormap[1], 'Excited Spin States')
plt.tight_layout()
if savefiles:
    plt.savefig(f'{basis}_GS_ES_spreads.png', dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Per-state evaluation metrics (one column per '<system>_<spin manifold>' key).
statdf = pd.DataFrame(columns=list(system_dict.keys()),
                      index=['Train R2', 'Test R2', 'Train MAPE', 'Test MAPE', 'Train RMSE', 'Test RMSE', 'State', 'system'])
survived_ES = set(statdf.columns).intersection({sy + '_' + st for st, sy, bs in excitedstates if bs == basis})
survived_GS = set(statdf.columns).intersection({sy + '_' + st for st, sy, bs in groundstates if bs == basis})

for i in system_dict.keys():
    true_train, pred_train = ytraindf[i].dropna(), ytrain_preddf[i].dropna()
    true_test, pred_test = ytestdf[i].dropna(), ytest_preddf[i].dropna()
    # Single .loc[row, col] writes: chained `statdf[i].loc[...] = ...`
    # assigns into a temporary under pandas Copy-on-Write and is lost.
    # Both spin labels are two underscore-joined words, so rsplit('_', 2)
    # recovers the system name even if it contains underscores.
    statdf.loc['system', i] = i.rsplit('_', 2)[0]
    statdf.loc['Train R2', i] = r2_score(true_train, pred_train)
    statdf.loc['Test R2', i] = r2_score(true_test, pred_test)
    statdf.loc['Train MAPE', i] = mean_absolute_percentage_error(true_train, pred_train)
    statdf.loc['Test MAPE', i] = mean_absolute_percentage_error(true_test, pred_test)
    # RMSE as sqrt(MSE): the `squared=` keyword is removed in recent scikit-learn.
    statdf.loc['Train RMSE', i] = np.sqrt(mean_squared_error(true_train, pred_train))
    statdf.loc['Test RMSE', i] = np.sqrt(mean_squared_error(true_test, pred_test))
for i in survived_ES:
    statdf.loc['State', i] = 'ES'
for i in survived_GS:
    statdf.loc['State', i] = 'GS'

if savefiles:
    statdf.to_excel(f'{basis}_stats.xlsx')


# SHapley Additive exPlanations (SHAP) Analysis

In [None]:
# Compute the SHAP values over every state's features, scaled with the
# training-set scaler.
cc_xtestdf = pd.concat([X for k, (X, y) in system_dict.items()], axis=0)
cc_xtest = pd.DataFrame(scaler.transform(cc_xtestdf), columns=cc_xtestdf.columns, index=cc_xtestdf.index)
explainer = shap.Explainer(model.predict, cc_xtest)
shap_values = explainer(cc_xtest)

color_map = sns.color_palette('rocket', 6)
fontsize = 16
# Raw strings keep the LaTeX backslashes literal (no invalid-escape warnings).
upd_feat_dict = {
    'entropy(D1a)': r'entropy(${}^1\mathbf{D}_{\alpha}$)',
    'entropy(D1b)': r'entropy(${}^1\mathbf{D}_{\beta}$)',
    'entropy(Q1a)': r'entropy(${}^1\mathbf{Q}_{\alpha}$)',
    'entropy(Q1b)': r'entropy(${}^1\mathbf{Q}_{\beta}$)',
    'entropy(D2aa)': r'entropy(${}^2\mathbf{D}_{\alpha \alpha}$)',
    'entropy(D2bb)': r'entropy(${}^2\mathbf{D}_{\beta \beta}$)',
    'entropy(D2ab)': r'entropy(${}^2\mathbf{D}_{\alpha \beta}$)',
    'entropy(Q2aa)': r'entropy(${}^2\mathbf{Q}_{\alpha \alpha}$)',
    'entropy(Q2bb)': r'entropy(${}^2\mathbf{Q}_{\beta \beta}$)',
    'entropy(Q2ab)': r'entropy(${}^2\mathbf{Q}_{\alpha \beta}$)',
    'entropy(G2ab)': r'entropy(${}^2\mathbf{G}_{\alpha \beta}$)',
    'entropy(G2ba)': r'entropy(${}^2\mathbf{G}_{\beta \alpha}$)',
    'entropy(G2aa/bb)': r'entropy(${}^2\mathbf{G}_{\alpha\alpha / \beta\beta}$)',
    'Violation%(T1)': 'Percent Violation T1',
    'Violation%(T2)': 'Percent Violation T2',
    'AveViolation(T1)': 'Average Violation T1',
    'AveViolation(T2)': 'Average Violation T2',
    'RMSViolation(T1)': 'RMS Violation T1',
    'RMSViolation(T2)': 'RMS Violation T2',
    'VarViolation(T1)': 'Variance Violation T1',
    'VarViolation(T2)': 'Variance Violation T2',
    '||del2(aa)||^2': r'$|| ^{2}{\Delta}_{\alpha \alpha}||^{2}$',
    '||del2(bb)||^2': r'$|| ^{2}{\Delta}_{\beta \beta}||^{2}$',
    '||del2(ab)||^2': r'$|| ^{2}{\Delta}_{\alpha \beta}||^{2}$',
    'Tr[del2(aa)]': r'Tr$(^{2}{\Delta}_{\alpha \alpha})$',
    'Tr[del2(bb)]': r'Tr$(^{2}{\Delta}_{\beta \beta})$',
    'Tr[del2(ab)]': r'Tr$(^{2}{\Delta}_{\alpha \beta})$',
}
plt.rcParams.update({'font.size': fontsize})
plt.rc('font', size=fontsize)          # controls default text sizes
plt.rc('axes', titlesize=fontsize)     # fontsize of the axes title
plt.rc('axes', labelsize=fontsize)     # fontsize of the x and y labels
plt.rc('xtick', labelsize=fontsize)    # fontsize of the tick labels
plt.rc('ytick', labelsize=fontsize)    # fontsize of the tick labels
plt.rc('legend', fontsize=fontsize)    # legend fontsize
plt.rc('figure', titlesize=fontsize)   # fontsize of the figure title

# Mean |SHAP| per feature. Feature names missing from upd_feat_dict fall back
# to the name with its 'v2RDM@' prefix stripped instead of raising KeyError.
raw_names = [name.replace('v2RDM@', '') for name in shap_values.feature_names]
df = pd.DataFrame({'Features': [upd_feat_dict.get(name, name) for name in raw_names],
                   '|SHAP|': shap_values.abs.values.mean(axis=0).reshape(-1,)})
df = df.sort_values('|SHAP|')

# DataFrame.plot creates its own figure of the requested size, so a separate
# plt.figure() beforehand would only leave an extra empty figure.
ax = df.plot.barh(x='Features', y='|SHAP|', color=color_map[3], figsize=(8, 10), legend=False)
ax.bar_label(ax.containers[0], fmt='%.4e', fontsize=12, padding=1)
plt.xticks(np.linspace(0, 3e-3, 4))
plt.xlim(0, 3e-3)
plt.xlabel('mean(|SHAP value|)')
plt.tight_layout()
if savefiles:
    plt.savefig(f'{basis}_barall.png', dpi=300, bbox_inches='tight')
plt.show()

if savefiles:
    df.to_excel(f'{basis}_SHAP.xlsx')