Import Packages

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import randint
import seaborn as sns
from sklearn.inspection import PartialDependenceDisplay
from sklearn.model_selection import train_test_split,RandomizedSearchCV,KFold,learning_curve
from sklearn.ensemble import RandomForestRegressor
from lightgbm import LGBMRegressor
import xgboost as xgb
from sklearn.metrics import r2_score,root_mean_squared_error, mean_absolute_error, mean_absolute_percentage_error
import shap

Load dataset

In [None]:
# Read the raw table, drop the 'Source' column, and order the rows by
# 'Filename' then 'T(K)' with a fresh 0..n-1 index.
data = (
    pd.read_csv('/content/Main_database.csv')
    .drop(columns=['Source'])
    .sort_values(by=['Filename', 'T(K)'])
    .reset_index(drop=True)
)
data.head()

Check for missing and zero values

In [None]:
# Per-column count of missing entries, followed by per-column count of exact zeros.
missing_per_column = data.isna().sum()
zeros_per_column = data.eq(0).sum()
print(missing_per_column)
print(zeros_per_column)

Remove missing and zero values

In [None]:
# Keep only rows that have no missing entry and no exact zero in any column.
data = data.dropna()
data = data[(data != 0).all(axis=1)]

In [None]:
# Column dtypes and non-null counts of the cleaned table.
data.info()

In [None]:
# Number of rows per distinct 'Metal node' value.
data['Metal node'].value_counts()

Function creation

In [None]:
def datasplit(dataframe,coltorem,stratify):
  """Split a frame into train/test features and the 'Adsorption' target.

  Parameters
  ----------
  dataframe : DataFrame that contains an 'Adsorption' column.
  coltorem : column label(s) removed from `dataframe` to form the features.
  stratify : forwarded unchanged to `train_test_split` (None for a plain
      random split).

  Returns
  -------
  tuple
      (X_train, X_test, y_train, y_test) from an 80/20 split seeded with 42.
  """
  features = dataframe.drop(columns=coltorem)
  target = dataframe['Adsorption']
  return tuple(
      train_test_split(
          features, target, test_size=0.2, random_state=42, stratify=stratify
      )
  )

In [None]:
def featureimp(model,dataframe):
  """Show a bar chart of `model.feature_importances_`, largest first.

  `dataframe.columns` supplies the feature names, so it must have one column
  per entry of `model.feature_importances_`. Each bar is annotated with its
  value to three decimals. Nothing is returned.
  """
  ranking = pd.DataFrame({
      'Feature': dataframe.columns,
      'Importance': model.feature_importances_,
  })
  ranking = ranking.sort_values(by='Importance', ascending=False)

  axis = sns.barplot(data=ranking, x='Importance', y='Feature')
  axis.bar_label(axis.containers[0], fmt='%.3f', padding=3)
  plt.show()

def viz(ytest,ypred):
  """Overlay predicted ('+') and actual ('o') values against the actual values.

  Both series are plotted on the same axes with 'Actual' on the x-axis, so the
  'Actual' markers trace the identity line. Nothing is returned.
  """
  frame = pd.DataFrame({'Actual': ytest, 'Predicted': ypred})
  for column, symbol in (('Predicted', '+'), ('Actual', 'o')):
    sns.scatterplot(data=frame, x='Actual', y=column, marker=symbol, label=column)
  plt.legend()
  plt.show()

In [None]:
def plot_learning_curves(model, mdlnm, X, y, cv=5,scoring2="r2", train_sizes=None):
    """Compute and plot a learning curve for `model`.

    Parameters
    ----------
    model : unfitted estimator passed to `learning_curve`.
    mdlnm : short model name used in the figure title.
    X, y : features and target used for the cross-validated curve.
    cv : cross-validation setting forwarded to `learning_curve` (default 5).
    scoring2 : scorer forwarded to `learning_curve` (default "r2").
    train_sizes : optional training-set sizes forwarded to `learning_curve`.
        When None, the absolute sizes [50, 100, 250, 750, 1500, 2250, 3000,
        4000] are used; the largest of these must not exceed the number of
        samples available for training in each CV split.

    Returns
    -------
    (train_scores_mean, val_scores_mean) : per-size scores averaged over folds.
    """
    if train_sizes is None:
        train_sizes = [50,100,250,750,1500,2250,3000,4000]

    sizes_used, train_scores, val_scores = learning_curve(
        model, X, y, cv=cv, n_jobs=-1, scoring=scoring2,
        train_sizes=train_sizes, shuffle=True, random_state=42
    )

    train_scores_mean = np.mean(train_scores, axis=1)
    val_scores_mean = np.mean(val_scores, axis=1)
    plt.plot(sizes_used, train_scores_mean, label='Training score',marker='o')
    plt.plot(sizes_used, val_scores_mean, label='Validation score',marker='o')
    plt.xlabel('Number of samples in training set')
    # Label the axis with the scorer actually used instead of always claiming R².
    plt.ylabel('R²' if scoring2 == "r2" else str(scoring2))
    plt.title(f'Learning Curve for {mdlnm}')
    plt.legend()
    plt.grid(False)
    plt.tight_layout()
    plt.show()
    return train_scores_mean, val_scores_mean


In [None]:
# Search spaces handed to RandomizedSearchCV. scipy's randint(low, high)
# samples integers from low up to high - 1; plain lists are sampled uniformly.
param_dist_rf = {
    'n_estimators': randint(5,500),
    'max_depth': randint(1,50),
    'max_features': randint(2,8),
    # randint has no step argument: a third positional value is the `loc`
    # shift, so randint(2,20,2) actually sampled 4..21. An explicit list gives
    # the even values 2, 4, ..., 18.
    'min_samples_split': list(range(2,20,2)),
    'min_samples_leaf': list(range(2,20,2)),
}

param_dist_lg = {
    'n_estimators': randint(1,1000),
    'max_depth': randint(1,14),
    'learning_rate':[0.01,0.05,0.1,0.5,0.9],
    'subsample':[0.1,0.3,0.5,0.8,1],
}

param_dist_xg = {
    'n_estimators': randint(1,1000),
    'max_depth': randint(1,16),
    'learning_rate':[0.01,0.05,0.1,0.5,0.9],
    'subsample':[0.2,0.3,0.5,0.8,1],
    'gamma':[0, 0.1, 0.5, 1],
}

# Base estimators whose hyperparameters are tuned from the spaces above.
rfmodel=RandomForestRegressor(random_state=22)
xgmodel=xgb.XGBRegressor(objective='reg:pseudohubererror',random_state=22)
# LightGBM only applies `subsample` when bagging is enabled (subsample_freq > 0);
# without it the 'subsample' values in param_dist_lg have no effect.
lgmodel = LGBMRegressor(objective='regression_l1',random_state=22,verbosity=-1,subsample_freq=1)

In [None]:
def modeltrain(model,param,X_train,y_train):
  """Tune `model` with a 50-candidate randomized search scored by R².

  Uses shuffled 5-fold cross-validation (seed 32) and a search seed of 42,
  running on all available cores.

  Returns
  -------
  (best_estimator, best_cv_score) taken from the fitted search.
  """
  search_settings = dict(
      estimator=model,
      param_distributions=param,
      n_iter=50,
      cv=KFold(n_splits=5, shuffle=True, random_state=32),
      scoring='r2',
      verbose=2,
      random_state=42,
      n_jobs=-1,
  )
  tuner = RandomizedSearchCV(**search_settings)
  tuner.fit(X_train, y_train)
  return (tuner.best_estimator_, tuner.best_score_)

def modelresult(bestmodel,X_train,y_train,X_test,y_test):
  """Fit `bestmodel` on the training set and report train/test metrics.

  Returns
  -------
  tuple
      (train score, test score, train RMSE, test RMSE, train MAE, test MAE,
      test MAPE, train predictions, test predictions), where the scores come
      from `bestmodel.score`.
  """
  bestmodel.fit(X_train,y_train)

  train_pred = bestmodel.predict(X_train)
  test_pred = bestmodel.predict(X_test)

  score_train = bestmodel.score(X_train,y_train)
  score_test = bestmodel.score(X_test,y_test)

  rmse_train = root_mean_squared_error(y_train,train_pred)
  rmse_test = root_mean_squared_error(y_test,test_pred)
  mae_train = mean_absolute_error(y_train,train_pred)
  mae_test = mean_absolute_error(y_test,test_pred)
  mape_test = mean_absolute_percentage_error(y_test,test_pred)

  return (
      score_train, score_test,
      rmse_train, rmse_test,
      mae_train, mae_test,
      mape_test,
      train_pred, test_pred,
  )

Analysis for learning curve

In [None]:
# Target and feature matrix for the learning-curve study, built from the full
# cleaned table; 'Filename' and 'Metal node' are left out of the features.
y = data['adsorption (mmol/g)']
X = data.drop(columns=['Filename','adsorption (mmol/g)','Metal node'])

In [None]:
# Learning curve of a default random forest (seed 42) on X, y; keeps the mean
# training and validation scores per training-set size.
train_score,val_score=plot_learning_curves(RandomForestRegressor(random_state=42),"RF",X,y)

In [None]:
# Learning curve of a default LightGBM regressor (seed 42) on X, y.
train_score2,val_score2=plot_learning_curves(LGBMRegressor(random_state=42),"LGB",X,y)


In [None]:
# Learning curve of a default XGBoost regressor (seed 42) on X, y.
train_score3,val_score3=plot_learning_curves(xgb.XGBRegressor(random_state=42),"XGB",X,y)

Dataset creation

In [None]:
# Draw a reproducible 3000-row sample, move the target to the last column,
# and switch to the short column names used in the rest of the notebook.
column_to_move = 'adsorption (mmol/g)'
df = data.sample(3000, random_state=42).reset_index(drop=True)
new_columns = [col for col in df.columns if col != column_to_move]
new_columns.append(column_to_move)
short_names = {
    'T(K)':'T',
    'Pressure(bar)': 'P',
    'hoa (kcal/mol)':'HOA',
    'Metal%': 'M%',
    'UC_volume':'UCV',
    'Density':'D',
    'adsorption (mmol/g)':'Adsorption',
    'AVAf':'AVAF',
}
df = df[new_columns].rename(columns=short_names)

In [None]:
# Column dtypes and non-null counts of the sampled modelling table.
df.info()

In [None]:
# Number of sampled rows per distinct 'Metal node' value.
df['Metal node'].value_counts()

EDA for the dataset

In [None]:
# One histogram figure per numeric column of the sampled table.
numdf = df.select_dtypes(include='number')
for feat in numdf.columns:
  hist_ax = sns.histplot(df[feat], color='#83267A')
  hist_ax.set_xlabel(f"{feat}", fontsize=15)
  hist_ax.set_ylabel("Count", fontsize=15)
  plt.show()

In [None]:
# Bar chart of row counts per metal node, most frequent first.
order = df['Metal node'].value_counts().index
count_ax = sns.countplot(data=df, x='Metal node', order=order, color='#83267A')
count_ax.set_xlabel('Metal Node', fontsize=15)
count_ax.set_ylabel('Count', fontsize=15)
plt.show()

Correlation plot

In [None]:
# Annotated heatmap of the pairwise correlations between numeric columns.
numdf = df.select_dtypes(include='number')
correlations = numdf.corr()

plt.figure(figsize=(12, 8))
sns.heatmap(correlations, annot=True, fmt='.2f')
plt.xticks(rotation=0, fontsize=14)
plt.yticks(rotation=0, fontsize=14)
plt.show()

Split data into training and testing set

In [None]:
# Columns excluded from the feature matrix; 'Adsorption' is the target.
rem=['Filename','Metal node','Adsorption']
# stratify=None requests a plain (non-stratified) random split.
X_train,X_test,y_train,y_test=datasplit(df,rem,None)

Model training with hyperparameter tuning

In [None]:
# Randomized hyperparameter search for the random forest; keeps the best
# estimator and its cross-validated score.
bestmodelrf,bestscorerf=modeltrain(rfmodel,param_dist_rf,X_train,y_train)

Fitting 5 folds for each of 50 candidates, totalling 250 fits


In [None]:
# Randomized hyperparameter search for the LightGBM regressor.
bestmodellg,bestscorelg=modeltrain(lgmodel,param_dist_lg,X_train,y_train)

Fitting 5 folds for each of 50 candidates, totalling 250 fits


In [None]:
# Randomized hyperparameter search for the XGBoost regressor.
bestmodelxg,bestscorexg=modeltrain(xgmodel,param_dist_xg,X_train,y_train)

Fitting 5 folds for each of 50 candidates, totalling 250 fits


Best hyperparameters

In [None]:
# Tuned estimators first (RF, LGB, XGB), then their best cross-validated scores
# in the same order.
for tuned_model in (bestmodelrf, bestmodellg, bestmodelxg):
    print(tuned_model)
for cv_score in (bestscorerf, bestscorelg, bestscorexg):
    print(cv_score)

Model performance metrics

In [None]:
# Each call returns, in order: train score, test score, train RMSE, test RMSE,
# train MAE, test MAE, test MAPE, train predictions, test predictions.
trainscorerf,testscorerf,rmsetrnr,rmserf,maetrnr,maerf,maperf,rfytpred,rfpred=modelresult(bestmodelrf,X_train,y_train,X_test,y_test)
trainscorelg,testscorelg,rmsetrlg,rmselg,maetrnlg,maelg,mapelg,lgytpred,lgpred=modelresult(bestmodellg,X_train,y_train,X_test,y_test)
trainscorexg,testscorexg,rmsetrxg,rmsexg,maetrnxg,maexg,mapexg,xgytpred,xgpred=modelresult(bestmodelxg,X_train,y_train,X_test,y_test)

Scatter plot for true and predicted values

In [None]:
def _parity_jointplot(actual, predicted):
    """Draw an actual-vs-predicted joint plot with a dashed X=Y reference line.

    Parameters
    ----------
    actual : observed target values (x-axis).
    predicted : model predictions for the same rows (y-axis).

    Returns
    -------
    The seaborn grid object returned by `sns.jointplot`, after it has been shown.
    """
    grid = sns.jointplot(
        x=actual,
        y=predicted,
        label="Test data",
        color='#83267A'
    )
    # The reference line spans the full range covered by either axis.
    low = min(actual.min(), predicted.min())
    high = max(actual.max(), predicted.max())
    grid.ax_joint.plot([low, high], [low, high], color="#32CD32",linestyle='--',linewidth=3,label='X=Y')
    grid.ax_joint.set_xlabel("Actual CO₂ uptake, mmol/g", fontsize=15)
    grid.ax_joint.set_ylabel("Predicted CO₂ uptake, mmol/g",fontsize=15)
    grid.ax_joint.legend(loc='upper left',fontsize=12)
    plt.show()
    return grid


# One parity plot per tuned model, all on the held-out test set.
rfplot = _parity_jointplot(y_test, rfpred)
lgplot = _parity_jointplot(y_test, lgpred)
xgplot = _parity_jointplot(y_test, xgpred)

Residual plots

In [None]:
# Test-set residuals (actual minus predicted) for RF, LGB and XGB.
residual1 = y_test - rfpred
residual2 = y_test - lgpred
residual3 = y_test - xgpred


def _residual_diagnostics(predicted, residuals):
    """Show a residual histogram above a residual-vs-prediction scatter.

    Parameters
    ----------
    predicted : model predictions (x-axis of the scatter).
    residuals : actual minus predicted values for the same rows.

    Returns
    -------
    (fig, axes) of the two-row figure, after it has been shown.
    """
    fig, axes = plt.subplots(2, 1)

    sns.histplot(residuals, kde=True, bins=30, ax=axes[0],color='#83267A')
    axes[0].set_xlabel("Residuals",fontsize=15)
    axes[0].set_ylabel("Frequency",fontsize=15)

    sns.scatterplot(x=predicted, y=residuals, ax=axes[1],color='#83267A')
    axes[1].axhline(0, color="#32CD32",linestyle='--',linewidth=3)
    axes[1].set_xlabel("Predicted CO₂ uptake (mmol/g)",fontsize=15)
    axes[1].set_ylabel("Residuals",fontsize=15)
    # Fixed y-range and tick spacing so the three models' panels are comparable.
    axes[1].set_ylim(-15, 15)
    axes[1].yaxis.set_major_locator(plt.MultipleLocator(5))

    plt.tight_layout()
    plt.show()
    return fig, axes


#residual plots for RF model
fig, axes = _residual_diagnostics(rfpred, residual1)

#residual plots for LGB model
fig, axes = _residual_diagnostics(lgpred, residual2)

#residual plots for XGB model
fig, axes = _residual_diagnostics(xgpred, residual3)

In [None]:
# Actual values and per-model residuals (actual minus predicted) on the test set.
trured = pd.DataFrame({
    'Actual': y_test,
    'residual1': y_test - rfpred,
    'residual2': y_test - lgpred,
    'residual3': y_test - xgpred,
})

# Test features extended with residuals, the observed target and each model's
# predictions, used for the per-condition analyses.
X_testdiag = X_test.copy()
for residual_col in ('residual1', 'residual2', 'residual3'):
    X_testdiag[residual_col] = trured[residual_col]
X_testdiag['adsorption (mmol/g)'] = y_test
for prediction_col, predictions in (
    ('rfprediction', rfpred),
    ('lgprediction', lgpred),
    ('xgprediction', xgpred),
):
    X_testdiag[prediction_col] = predictions

In [None]:
def mape(y_true, y_pred):
    """Return the mean absolute percentage error as a fraction (not x100).

    Computes mean(|(y_true - y_pred) / y_true|); both arguments must support
    element-wise arithmetic (e.g. NumPy arrays or pandas Series).
    """
    relative_error = (y_true - y_pred) / y_true
    return np.mean(np.abs(relative_error))

Model analysis at different temperature and pressure

In [None]:
# For each temperature, print R², RMSE, MAE and MAPE on the matching test rows,
# first for RF, then LGB, then XGB.
for T in [298, 313, 338]:
    print(f"{T}K:")

    rows_at_T = X_testdiag[X_testdiag['T']==T]
    observed = rows_at_T['adsorption (mmol/g)']

    for prediction_col in ('rfprediction', 'lgprediction', 'xgprediction'):
        for score_fn in (r2_score,
                         root_mean_squared_error,
                         mean_absolute_error,
                         mean_absolute_percentage_error):
            print(score_fn(observed, rows_at_T[prediction_col]))

    print("\n")

In [None]:
# Collect R² and MAPE per (temperature, model) on the test rows, then draw one
# grouped bar chart per metric.
temperatures = [298, 313, 338]
models = {"RF": "rfprediction", "LGB": "lgprediction", "XGB": "xgprediction"}

results = []
for t in temperatures:
    subset = X_testdiag.loc[X_testdiag['T'] == t]
    y_true = subset['adsorption (mmol/g)']
    for model_name in models:
        y_pred = subset[models[model_name]]
        record = {"T": t, "Model": model_name}
        record["R²"] = r2_score(y_true, y_pred)
        record["MAPE"] = mean_absolute_percentage_error(y_true, y_pred)
        results.append(record)

results_df = pd.DataFrame(results)

metrics = ["R²", "MAPE"]
for metric in metrics:
    bar_ax = sns.barplot(
        data=results_df, x="T", y=metric, hue="Model",
        palette='magma', saturation=1,
    )
    bar_ax.set_xlabel("T", fontsize=15)
    bar_ax.set_ylabel(metric, fontsize=15)
    plt.show()

In [None]:
# For each pressure, print R², RMSE, MAE and MAPE on the matching test rows,
# first for RF, then LGB, then XGB.
for p in [0.1, 0.26, 0.4, 1, 3.2, 16]:

  print(f"Pressure:{p}")

  rows_at_p = X_testdiag[X_testdiag['P']==p]
  observed = rows_at_p['adsorption (mmol/g)']

  for prediction_col in ('rfprediction', 'lgprediction', 'xgprediction'):
    for score_fn in (r2_score,
                     root_mean_squared_error,
                     mean_absolute_error,
                     mean_absolute_percentage_error):
      print(score_fn(observed, rows_at_p[prediction_col]))

  print("\n")

In [None]:
# Collect R² and MAPE per (pressure, model) on the test rows, then draw one
# grouped bar chart per metric.
models = {"RF": "rfprediction", "LGB": "lgprediction", "XGB": "xgprediction"}

results2 = []
for p in [0.1, 0.26, 0.4, 1, 3.2, 16]:
    subset2 = X_testdiag.loc[X_testdiag['P'] == p]
    y_true2 = subset2['adsorption (mmol/g)']
    for model_name in models:
        y_pred2 = subset2[models[model_name]]
        record2 = {"P": p, "Model": model_name}
        record2["R²"] = r2_score(y_true2, y_pred2)
        record2["MAPE"] = mean_absolute_percentage_error(y_true2, y_pred2)
        results2.append(record2)

results2_df = pd.DataFrame(results2)

metrics = ["R²","MAPE"]
for metric in metrics:
    bar_ax = sns.barplot(
        data=results2_df, x="P", y=metric, hue="Model",
        palette='magma', saturation=1,
    )
    bar_ax.set_xlabel("P", fontsize=15)
    bar_ax.set_ylabel(metric, fontsize=15)
    plt.show()


SHAP bar plot

In [None]:

# SHAP values of each tuned model on the training features, summarised as a
# bar plot (one figure per model).
explainer_rf = shap.TreeExplainer(bestmodelrf)
shap_values_rf = explainer_rf.shap_values(X_train)
shap.summary_plot(shap_values_rf, X_train, plot_type="bar",color="#83267A")

explainer_lg = shap.TreeExplainer(bestmodellg)
shap_values_lg = explainer_lg.shap_values(X_train)
shap.summary_plot(shap_values_lg, X_train, plot_type="bar",color="#83267A")

explainer_xg = shap.TreeExplainer(bestmodelxg)
# NOTE(review): XGBoost is given the raw array (X_train.values) while RF and
# LightGBM receive the DataFrame — presumably deliberate; confirm.
shap_values_xg = explainer_xg.shap_values(X_train.values)
shap.summary_plot(shap_values_xg, X_train, plot_type="bar",color="#83267A")



SHAP beeswarm plot

In [None]:
# Beeswarm ("dot") summary for RF, LightGBM and XGBoost, in that order.
for model_shap_values in (shap_values_rf, shap_values_lg, shap_values_xg):
    shap.summary_plot(model_shap_values, X_train, plot_type="dot")

Feature importance by feature types

In [None]:
# SHAP value arrays keyed by model display name.
shap_dict = {
    'Random Forest': shap_values_rf,
    'LightGBM': shap_values_lg,
    'XGBoost': shap_values_xg
}

# Grouping of each feature into a broader category for the pie charts.
feature_types = {
    'T': 'Operational',
    'P': 'Operational',
    'HOA': 'Energetic',
    'M%': 'Chemical',
    'UCV': 'Geometric',
    'D': 'Geometric',
    'ASA': 'Geometric',
    'AVAF': 'Geometric'
}

# The palette does not depend on the model, so build it once.
colors = sns.color_palette("muted")

for model_name, shap_values in shap_dict.items():
    # Mean absolute SHAP value per feature.
    feature_importance = np.abs(shap_values).mean(axis=0)
    # Deliberately NOT named `df`: that name holds the modelling dataset, and
    # rebinding it here would break any earlier cell that is re-run afterwards.
    shap_importance_df = pd.DataFrame({
        'Feature': X_train.columns,
        'Mean_abs_SHAP': feature_importance
    })
    shap_importance_df['Feature Type'] = shap_importance_df['Feature'].map(feature_types)
    # Total importance per feature category.
    type_importance = (
        shap_importance_df.groupby('Feature Type')['Mean_abs_SHAP'].sum().reset_index()
    )

    wedges, texts, autotexts=plt.pie(
        type_importance['Mean_abs_SHAP'],
        labels=None,
        colors=colors,
        autopct='%1.1f%%',
        wedgeprops={'edgecolor': 'white'},
    )
    plt.legend(
        wedges,
        type_importance['Feature Type'],
        title="Feature Type",
        loc="center left",
        bbox_to_anchor=(1, 0, 0.5, 1),
        fontsize=12,
        title_fontsize=13
    )
    plt.show()

Partial dependence plots

In [None]:
# Average partial dependence of the tuned random forest on each training
# feature, one figure per feature.
for feat in X_train.columns:
  disp=PartialDependenceDisplay.from_estimator(
      bestmodelrf,
      X_train,
      features=[feat],
      kind="average"
  )
  plt.ylabel("CO₂ adsorption (mmol/g)")
  plt.show()

In [None]:
# Average partial dependence of the tuned LightGBM model on each training
# feature, one figure per feature.
for feat in X_train.columns:
  disp1=PartialDependenceDisplay.from_estimator(
      bestmodellg,
      X_train,
      features=[feat],
      kind="average"
  )
  plt.ylabel("CO₂ adsorption (mmol/g)")
  plt.show()

In [None]:
# For each feature: plot the XGBoost partial-dependence curve and record the
# point (feature value, averaged prediction) where that curve is highest.
max_info = {}

for feat in X_train.columns:

    disp2 = PartialDependenceDisplay.from_estimator(
        bestmodelxg,
        X_train,
        features=[feat],
        kind="average"
    )

    # Read the plotted curve back from the display. The names are deliberately
    # not `x`/`y`: `y` is the target Series used by the learning-curve cells,
    # and overwriting it here would corrupt those cells if they are re-run.
    pd_curve = disp2.lines_[0][0]
    pd_feature_values = pd_curve.get_xdata()
    pd_predictions = pd_curve.get_ydata()

    idx_max = np.argmax(pd_predictions)

    max_info[feat] = (pd_feature_values[idx_max], pd_predictions[idx_max])
    plt.xlabel(f"{feat}",fontsize=14)
    plt.ylabel("CO₂ adsorption (mmol/g)",fontsize=14)
    plt.xticks(fontsize=12)
    plt.yticks(fontsize=12)
    plt.show()

In [None]:
# Report, per feature, where its partial-dependence curve peaks and the peak value.
for feat, (f_val, pred_val) in max_info.items():
    print(f"Feature: {feat}, Max Feature Value: {f_val}, Max Predicted Value: {pred_val}")

SHAP dependence plots

In [None]:
# SHAP dependence plot of the random forest for every training feature;
# "auto" lets shap choose the feature used for colouring.
feature_names = X_train.columns.tolist()
for feature in feature_names:
    shap.dependence_plot(feature, shap_values_rf, X_train, interaction_index="auto")

In [None]:
# SHAP dependence plot of the LightGBM model for every training feature.
for feature in feature_names:
    shap.dependence_plot(feature, shap_values_lg, X_train, interaction_index="auto")

In [None]:
# SHAP dependence plot of the XGBoost model for every training feature.
for feature in feature_names:
    shap.dependence_plot(feature, shap_values_xg, X_train, interaction_index="auto")