# Code to generate the figures for the paper

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np

# Global plotting style: seaborn's "talk" context, with Times New Roman as the
# default font family for every figure below.
sns.set_context("talk")
plt.rcParams.update({"font.family": "Times New Roman"})

## Figure 2

In [None]:
# Load the filtered waveform catalog and collect the distinct experiment names
# recorded for each substrate (unordered; callers sort before plotting).
catalog = pd.read_pickle("../data/final_filtered_catalog_for_model_training_and_eval.pkl")
rock_exp_nums = list(set(catalog.loc[catalog.substrate == "rock", "expname"]))
till_exp_nums = list(set(catalog.loc[catalog.substrate == "till", "expname"]))

In [None]:
# Figure 2: the top row shows all aligned waveforms as an image (labels 0 and 1,
# titled "Rock" and "Till"); each following row overlays the waveforms of one
# experiment (rock experiments in red on the left, till in teal on the right).

# Sample window plotted for every waveform, and the x-tick mapping used
# throughout (50 samples are labelled as 5 microseconds).
window = slice(350, 500)
ticks = [0, 50, 100, 150]
tick_labels = [0, 5, 10, 15]

# One image row plus one row per experiment in the larger substrate group,
# rather than a hard-coded 7 rows (which only fits up to 6 experiments each).
n_rows = 1 + max(len(rock_exp_nums), len(till_exp_nums))
# squeeze=False keeps ``ax`` 2-D even when there is only a single row.
f, ax = plt.subplots(n_rows, 2, figsize=(15, 5 * n_rows), squeeze=False)

for col, (label, title) in enumerate([(0, "Rock"), (1, "Till")]):
    image_ax = ax[0, col]
    image_ax.pcolormesh(
        np.stack(catalog[catalog.labels == label]["alignedwaves"].values)[:, window],
        vmin=-1,
        vmax=1,
        cmap="seismic",
    )
    image_ax.set_title(title, fontsize=35)
    image_ax.set_xticks(ticks)
    image_ax.set_xticklabels(tick_labels)
    image_ax.set_xlabel("Time (microseconds)")
    image_ax.set_ylabel("Waveform number")

for col, (exp_nums, color) in enumerate([(rock_exp_nums, "red"), (till_exp_nums, "teal")]):
    for row, num in enumerate(sorted(exp_nums), start=1):
        panel = ax[row, col]
        waves = np.stack(catalog[catalog.expname == num]["alignedwaves"].values)[:, window]
        # A 2-D y argument draws one line per column, so waves.T yields one
        # line per waveform in a single call.
        panel.plot(waves.T, c=color, alpha=0.05)
        # Axis setup and the experiment label are applied once per panel (not
        # once per waveform) so the label text is drawn a single time.
        panel.set_ylabel("Normalized amplitude")
        panel.set_xlim([0, 150])
        panel.set_xticks(ticks)
        panel.set_xticklabels(tick_labels)
        panel.set_xlabel("Time (microseconds)")
        panel.text(5, 0.8, num, fontsize=20)

plt.tight_layout()
#plt.savefig("../figures/figure2.png",format="png")
plt.show()

## Figure 3

In [None]:
# Load the train/test results and summarise balanced accuracy.
res = pd.read_pickle("../data/train_test_results.pkl")

# Mean balanced accuracy for each distinct ``test_exps`` value.
acc_by_testexps = res.groupby("test_exps")["balanced_accuracy"].mean().reset_index()

# One row per ``run_number``: the run's mean balanced accuracy plus the till /
# rock test experiments taken from the run's first row. The string "mean"
# (rather than np.mean) avoids recent pandas' FutureWarning about callables
# being silently mapped to the built-in groupby methods.
acc_by_test_exp = res.groupby("run_number").aggregate(
    {
        "balanced_accuracy": "mean",
        "till_test_exp": lambda x: x.iloc[0],
        "rock_test_exp": lambda x: x.iloc[0],
    }
)

# Till experiments as rows, rock experiments as columns. pivot_table averages
# runs that share a (till, rock) pair -- matching the "mean accuracy" the
# heatmap reports -- whereas pivot raises on duplicate entries. With exactly
# one run per pair the resulting table is the same as pivot's.
acc_by_test_exp = acc_by_test_exp.pivot_table(
    index="till_test_exp",
    columns="rock_test_exp",
    values="balanced_accuracy",
    aggfunc="mean",
    dropna=False,
)

In [None]:
# Figure 3: annotated heatmap of mean accuracy for every (till, rock) test
# experiment pair, with the colour scale fixed to [0, 1].
plt.figure(figsize=(9, 7))
heat_ax = sns.heatmap(
    data=acc_by_test_exp,
    vmin=0,
    vmax=1,
    cmap="seismic",
    annot=True,
    fmt=".2f",
)
heat_ax.set_title("Mean accuracy by test set experiment pair", fontsize=25, wrap=True)
heat_ax.set_xlabel("Rock experiment #", fontsize=20)
heat_ax.set_ylabel("Till experiment #", fontsize=20)
#plt.savefig("../figures/figure3.png",format="png")
plt.show()

## Figure 4

In [None]:
# Stack each classifier's per-feature importance vector into a 2-D array with
# one row per row of ``res`` and one column per feature.
importances = np.stack([clf.feature_importances_ for clf in res.classifier])

In [None]:
# Figure 4: the top panel summarises the Random Forest feature importances
# across all runs (mean and 10th-90th percentile band per feature); the bottom
# panel overlays the aligned waveforms of both classes on the shared x-axis.
f, ax = plt.subplots(2, 1, figsize=(10, 7), sharex=True)
plt.suptitle("Random Forest feature importances")

# Quantiles are computed once and reused for both the band and its edges.
ninetieth = np.quantile(importances, 0.9, axis=0)
tenth = np.quantile(importances, 0.1, axis=0)
ax[0].fill_between(
    x=range(ninetieth.shape[0]),
    y1=ninetieth,
    y2=tenth,
    color="black",
    alpha=0.25,
)
ax[0].plot(ninetieth, "k", alpha=0.5, label="90th percentile")
ax[0].plot(tenth, "k", alpha=0.5, label="10th percentile")
ax[0].plot(np.mean(importances, axis=0), "k", label="mean")
ax[0].legend()
ax[0].set_ylabel("Feature importance")

# Colours match Figure 2, where labels==0 is titled "Rock" and drawn red and
# labels==1 is titled "Till" and drawn teal. Each class is plotted in one call
# (a 2-D y draws one line per column, i.e. one per waveform).
for label, color in ((0, "red"), (1, "teal")):
    waves = np.stack(catalog[catalog.labels == label]["alignedwaves"].values)[:, 350:500]
    ax[1].plot(waves.T, color=color, alpha=0.01)
ax[1].set_ylabel("Normalized amplitude")
# One tick every 10 samples, labelled 0-14 microseconds.
ax[1].set_xticks(np.arange(0, 150, 10))
ax[1].set_xticklabels(np.arange(0, 15))
ax[1].set_xlabel("Time (microseconds)")

plt.tight_layout()
#plt.savefig("../figures/figure4.png",format="png")
plt.show()