In [2]:
cd ..

# Multi-omics stratification on PDAC patients

In [80]:
import os
import types

import pandas as pd
import numpy as np
import dill
from torch.utils.data import DataLoader
from pytorch_lightning.utilities.seed import isolate_rng
import torch
from captum.attr import LayerConductance, NeuronConductance, IntegratedGradients, GradientShap, NoiseTunnel
from scipy.stats import kruskal, mannwhitneyu, kstest
from scipy.stats import pearsonr
from sklearn.model_selection import cross_val_score
from sklearn.pipeline import make_pipeline
from sklearn.svm import SVC
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
import seaborn as sns
from statsmodels.stats.multitest import fdrcorrection


from src import settings
from utils import MultiViewDataset, transform_full_dataset
from optimization import Optimization
from explainability import (
    plot_comparison_attributions_weights, plot_attribution_distribution, plot_feature_importance, layerconductance, neuronconductance,
    FeatureAblationV2, compute_gradients, compute_mv_score, plot_attribution_algorithm_comparison, DeepV2, compute_most_important_features_based_attribution)

## Load dataset

In [4]:
# View names, in the same order as the per-view frames collected in `Xs` (RNAseq first, methylation second).
views = ["RNAseq", "Methylation"]

In [5]:
# Load the methylation table (semicolon-separated, decimal commas), then transpose it so that
# the original columns become the row index.
methylation_data = pd.read_csv(settings.methylation_data_path, sep=";", index_col=0, decimal=",")
# regex=False: "." must be replaced literally. Interpreted as a regular expression it matches
# every character, and the default of `regex` differs between pandas versions.
methylation_data.columns = methylation_data.columns.str.replace(".", "-", regex=False)
methylation_data = methylation_data.T.astype(np.float32)
print("methylation_data.shape", methylation_data.shape)
methylation_data.head()

In [6]:
# Load the RNA-seq table (semicolon-separated, decimal commas), transpose it and cast to float32.
rnaseq_data = (
    pd.read_csv(settings.rnaseq_data_path, sep=";", index_col=0, decimal=",")
    .T
    .astype(np.float32)
)
print("rnaseq_data.shape", rnaseq_data.shape)
rnaseq_data.head()

In [7]:
# Keep only the rows present in both views, in the same order for each.
samples = methylation_data.index.intersection(rnaseq_data.index)
methylation_data, rnaseq_data = methylation_data.loc[samples], rnaseq_data.loc[samples]
assert methylation_data.index.equals(rnaseq_data.index)
# View order matches `views`: RNAseq first, methylation second.
Xs = [rnaseq_data, methylation_data]
print("common samples:", len(samples))

In [8]:
# Identifier embedded in the file names of the optimisation run to load.
date = "2023070315"
# NOTE: dill.load executes arbitrary code from the file - only load trusted artefacts.
with open(os.path.join(settings.optimization_path, f'optimization_optuna_{date}.pkl'), mode='rb') as file:
    optimization_study = dill.load(file)
optimization_results = pd.read_csv(
    os.path.join(settings.optimization_path, f"optimization_results_{date}.csv")
)
# First row of the results table; presumably the best trial (assumes the CSV is sorted) - TODO confirm.
best_trial = optimization_results.iloc[0]
print("optimization_results.shape", optimization_results.shape)
optimization_results.head()

In [9]:
# Load and display the first saved pipeline (pipeline0.pkl) from the results folder.
# NOTE: dill.load executes arbitrary code from the file - only load trusted artefacts.
pipeline_name = "pipeline0.pkl"
with open(os.path.join(settings.results_path, pipeline_name), mode='rb') as f:
    pipeline = dill.load(f)
pipeline

In [10]:
# Load and display the second saved pipeline (pipeline1.pkl); this overwrites `pipeline`.
# NOTE: dill.load executes arbitrary code from the file - only load trusted artefacts.
pipeline_name = "pipeline1.pkl"
with open(os.path.join(settings.results_path, pipeline_name), mode='rb') as f:
    pipeline = dill.load(f)
pipeline

In [11]:
# Apply the already-fitted pipelines (fit_pipelines=False) to every view.
transformed_Xs = transform_full_dataset(Xs=Xs, fit_pipelines=False, results_folder=settings.results_path)
transformed_X = pd.concat(transformed_Xs, axis=1)
# Feature names prefixed with "<view>_", concatenated in view order.
features = pd.Series(
    np.concatenate([view + "_" + X_.columns for X_, view in zip(transformed_Xs, views)])
)
print("transformed_X.shape", transformed_X.shape)
transformed_X.head()

In [12]:
model_path = os.path.join("outputs", "models", "model.pt")
# NOTE: torch.load unpickles the file - only load trusted checkpoints.
# eval() switches the module to inference mode and returns the module itself.
model = torch.load(model_path).eval()

In [13]:
# Encode the full dataset batch by batch (no shuffling) and predict a cluster for every row.
batch_size = int(best_trial["user_attrs_batch_size"])
full_data = MultiViewDataset(Xs=transformed_Xs)
full_dataloader = DataLoader(dataset=full_data, batch_size=batch_size, shuffle=False)
with torch.no_grad():
    batch_embeddings = [model.autoencoder.encode(batch) for batch in full_dataloader]
    z_full = torch.vstack(batch_embeddings)
    clusters = model.predict_cluster_from_embedding(z_full).detach().cpu().numpy()

In [330]:
# Replace the in-memory cluster predictions with the assignments stored on disk.
clusters = pd.read_csv("outputs/results/clusters.csv", index_col=0).squeeze()
# Align the raw views to the rows listed in the cluster file.
# NOTE(review): `transformed_Xs` is not re-aligned here - confirm both share the same rows and order.
Xs = [view_df.loc[clusters.index] for view_df in Xs]

In [331]:
# Attribution target: the label at position 1 of value_counts(), i.e. the second most frequent cluster.
cluster_sizes = pd.Series(clusters).value_counts()
target = int(cluster_sizes.index[1])
# One tensor per view, plus an all-zero baseline of matching shape for each.
inputs = tuple([torch.tensor(view_df.values) for view_df in transformed_Xs])
baselines = tuple([torch.zeros(view_tensor.shape) for view_tensor in inputs])
target

In [333]:
# NOTE(review): this cell needs `shap_values`, which is only computed in the next cell, and it
# repeats the plot produced right after that computation. It fails on Restart & Run All.
mv_scores = compute_mv_score(shap_values=shap_values, view_names=views)
ax = pd.DataFrame([mv_scores]).plot(
    kind="bar",
    ylabel="% contribution",
    xlabel="Modality",
    title="Mean contribution of each modality to the prediction",
)
for bar_container in ax.containers:
    ax.bar_label(bar_container)

In [None]:
# SHAP-style attributions of the model output with the all-zero baselines as background data.
explainer = DeepV2(model, data=[*baselines])
shap_values = explainer.shap_values([*inputs])

In [552]:
# Overall contribution of each view, shown as one labelled bar per view.
mv_scores = compute_mv_score(shap_values=shap_values, view_names=views)
ax = pd.DataFrame([mv_scores]).plot(
    kind="bar",
    ylabel="% contribution",
    xlabel="Modality",
    title="Mean contribution of each modality to the prediction",
)
for bar_container in ax.containers:
    ax.bar_label(bar_container)

In [335]:
# Display the per-view contribution scores computed above.
mv_scores

In [336]:
# Three-panel figure: (1) mean modality contribution per cluster, (2) per-feature mean methylation
# per cluster, (3) per-feature mean log2(1 + expression) per cluster.

# Per-sample modality contributions, one row per sample, labelled with the sample's cluster.
mv_scores = [compute_mv_score(shap_values=shap_values, view_names=views, idx=i) for i in range(len(clusters))]
mv_scores = pd.DataFrame(mv_scores, index=clusters.index)
mv_scores["clusters"] = clusters

# Defined locally so the cell does not rely on a later cell having been executed first.
colors = ['#1f77b4', '#ff7f0e']
cluster_ids = sorted(np.unique(clusters))
cluster_names = [f"Cluster {cl}" for cl in cluster_ids]

fig, axes = plt.subplots(1, 3, figsize=(20, 4))

# --- Panel 1: stacked mean contribution per cluster --------------------------------------------
ax = axes[0]
# Nested dict: relabel only the "clusters" column. A set literal such as {"clusters", 0} would
# replace the value 0 in *every* column, including genuine 0 % contributions.
mv_scores_grouped = mv_scores.replace({"clusters": {0: "Cluster 0", 1: "Cluster 1"}}).groupby("clusters")
ax = mv_scores_grouped.mean().plot(kind="bar", ylabel="% contribution", capsize=4, yerr=mv_scores_grouped.std(),
                                   colormap="Paired", stacked=True, ax=ax, rot=0)
# NOTE(review): removes three artists by position (presumably error-bar parts of one series);
# the index 7 depends on how matplotlib orders the children - confirm after library upgrades.
for _ in range(3):
    ax.get_children()[7].remove()
ax.set_xlabel("")
ax.legend(loc="upper center", ncols=2, bbox_to_anchor=(0.5, 1.11))
da_for_test = [mv_scores[mv_scores['clusters'] == cl]["Methylation"] for cl in mv_scores['clusters'].unique()]
pval = round(mannwhitneyu(da_for_test[0], da_for_test[1]).pvalue, 5)
ax.text(0.35, 80, f"p= {round(pval, 3)}")

# --- Panel 2: methylation ----------------------------------------------------------------------
# Rows are selected by sample label, so the selection stays correct even when `clusters` and the
# data frame are ordered differently.
vio_met = pd.DataFrame(
    [methylation_data.loc[clusters.index[(clusters == cl).to_numpy()], transformed_Xs[1].columns].mean(0)
     for cl in cluster_ids],
    index=cluster_names).T
sns.violinplot(data=vio_met, orient="h", ax=axes[1], palette=colors)
axes[1].boxplot(vio_met, vert=False, positions=[0, 1], labels=vio_met.columns, patch_artist=True,
                boxprops=dict(facecolor="green"), showmeans=True)
pval = round(kstest(vio_met["Cluster 0"], vio_met["Cluster 1"]).pvalue, 5)
axes[1].text(0.8, .5, f"p= {round(pval, 3)}")

# --- Panel 3: gene expression ------------------------------------------------------------------
vio_rna = pd.DataFrame(
    [np.log2(1 + rnaseq_data.loc[clusters.index[(clusters == cl).to_numpy()], transformed_Xs[0].columns]).mean(0)
     for cl in cluster_ids],
    index=cluster_names).T
sns.violinplot(data=vio_rna, orient="h", ax=axes[2], inner=None, palette=colors)
axes[2].boxplot(vio_rna, vert=False, positions=[0, 1], labels=vio_rna.columns, patch_artist=True,
                boxprops=dict(facecolor="green"), showmeans=True,
                flierprops={'markersize': 1, 'markerfacecolor': 'r'})
pval = round(kstest(vio_rna["Cluster 0"], vio_rna["Cluster 1"]).pvalue, 5)
axes[2].text(14, .5, f"p= {round(pval, 3)}")

fig.subplots_adjust()
_ = fig.suptitle(None)

In [193]:
# Same three-panel figure as the previous cell, drawn from the existing `mv_scores` frame.

# Defined locally so the cell does not rely on a later cell having been executed first.
colors = ['#1f77b4', '#ff7f0e']
cluster_ids = sorted(np.unique(clusters))
cluster_names = [f"Cluster {cl}" for cl in cluster_ids]

fig, axes = plt.subplots(1, 3, figsize=(20, 4))

# --- Panel 1: stacked mean contribution per cluster --------------------------------------------
ax = axes[0]
# Nested dict: relabel only the "clusters" column. A set literal such as {"clusters", 0} would
# replace the value 0 in *every* column, including genuine 0 % contributions.
mv_scores_grouped = mv_scores.replace({"clusters": {0: "Cluster 0", 1: "Cluster 1"}}).groupby("clusters")
ax = mv_scores_grouped.mean().plot(kind="bar", ylabel="% contribution", capsize=4, yerr=mv_scores_grouped.std(),
                                   colormap="Paired", stacked=True, ax=ax, rot=0)
# NOTE(review): removes three artists by position (presumably error-bar parts of one series);
# the index 7 depends on how matplotlib orders the children - confirm after library upgrades.
for _ in range(3):
    ax.get_children()[7].remove()
ax.set_xlabel("")
ax.legend(loc="upper center", ncols=2, bbox_to_anchor=(0.5, 1.11))
da_for_test = [mv_scores[mv_scores['clusters'] == cl]["Methylation"] for cl in mv_scores['clusters'].unique()]
pval = round(mannwhitneyu(da_for_test[0], da_for_test[1]).pvalue, 5)
ax.text(0.35, 80, f"p= {round(pval, 3)}")

# --- Panel 2: methylation ----------------------------------------------------------------------
# `clusters` is a label-indexed Series at this point; `.iloc` rejects such a mask, so rows are
# selected by sample label instead.
vio_met = pd.DataFrame(
    [methylation_data.loc[clusters.index[(clusters == cl).to_numpy()], transformed_Xs[1].columns].mean(0)
     for cl in cluster_ids],
    index=cluster_names).T
sns.violinplot(data=vio_met, orient="h", ax=axes[1], palette=colors)
axes[1].boxplot(vio_met, vert=False, positions=[0, 1], labels=vio_met.columns, patch_artist=True,
                boxprops=dict(facecolor="green"), showmeans=True)
pval = round(kstest(vio_met["Cluster 0"], vio_met["Cluster 1"]).pvalue, 5)
axes[1].text(0.8, .5, f"p= {round(pval, 3)}")

# --- Panel 3: gene expression ------------------------------------------------------------------
vio_rna = pd.DataFrame(
    [np.log2(1 + rnaseq_data.loc[clusters.index[(clusters == cl).to_numpy()], transformed_Xs[0].columns]).mean(0)
     for cl in cluster_ids],
    index=cluster_names).T
sns.violinplot(data=vio_rna, orient="h", ax=axes[2], inner=None, palette=colors)
axes[2].boxplot(vio_rna, vert=False, positions=[0, 1], labels=vio_rna.columns, patch_artist=True,
                boxprops=dict(facecolor="green"), showmeans=True,
                flierprops={'markersize': 1, 'markerfacecolor': 'r'})
pval = round(kstest(vio_rna["Cluster 0"], vio_rna["Cluster 1"]).pvalue, 5)
axes[2].text(14, .5, f"p= {round(pval, 3)}")

fig.subplots_adjust()
_ = fig.suptitle(None)

In [329]:
# Per-sample contributions as a stacked bar chart, samples ordered by increasing RNAseq share.
mv_scores = pd.DataFrame(
    [compute_mv_score(shap_values=shap_values, view_names=views, idx=i) for i in range(len(clusters))],
    clusters.index,
)
mv_scores["clusters"] = clusters
ax = (
    mv_scores
    .sort_values(by="RNAseq", ascending=True)
    .drop(columns="clusters")
    .plot(kind="bar", ylabel="% contribution", figsize=(30, 8), xlabel="Sample",
          colormap="Paired", xticks=[], stacked=True)
)
# Enlarge the axis labels and tick labels, and push the x label away from the axis.
ax.xaxis.label.set_fontsize(16)
ax.set_xlabel(ax.get_xlabel(), labelpad=20)
ax.yaxis.label.set_fontsize(16)
ax.yaxis.set_tick_params(labelsize=16)
_ = ax.legend(loc="upper center", prop={"size": 16}, ncols=2, bbox_to_anchor=(0.5, 1.09))

In [197]:
# Stacked per-sample contribution chart from the existing `mv_scores`, ordered by RNAseq share.
ax = (
    mv_scores
    .sort_values(by="RNAseq", ascending=True)
    .drop(columns="clusters")
    .plot(kind="bar", ylabel="% contribution", figsize=(30, 8), xlabel="Sample",
          colormap="Paired", xticks=[], stacked=True)
)
# Enlarge the axis labels and tick labels, and push the x label away from the axis.
ax.xaxis.label.set_fontsize(16)
ax.set_xlabel(ax.get_xlabel(), labelpad=20)
ax.yaxis.label.set_fontsize(16)
ax.yaxis.set_tick_params(labelsize=16)
_ = ax.legend(loc="upper center", prop={"size": 16}, ncols=2, bbox_to_anchor=(0.5, 1.09))

In [903]:
# Stacked per-sample contribution chart in the original sample order.
ax = (
    mv_scores
    .drop(columns="clusters")
    .plot(kind="bar", ylabel="% contribution", figsize=(30, 8), xlabel="Sample",
          colormap="Paired", xticks=[], stacked=True)
)
# Enlarge the axis labels and tick labels, and push the x label away from the axis.
ax.xaxis.label.set_fontsize(16)
ax.set_xlabel(ax.get_xlabel(), labelpad=20)
ax.yaxis.label.set_fontsize(16)
ax.yaxis.set_tick_params(labelsize=16)
_ = ax.legend(loc="upper right", prop={"size": 16})

In [835]:
# Figure with three visible panels (axes 1 and 3 are hidden spacer columns):
# methylation contribution per cluster, per-feature mean methylation, per-feature mean expression.
fig, axes = plt.subplots(1, 5, figsize=(20, 5), gridspec_kw={'width_ratios': [0.8, 0.1, 1, 0.01, 1]})
axes[1].set_visible(False)
axes[3].set_visible(False)

colors = ['#1f77b4', '#ff7f0e']
cluster_ids = sorted(np.unique(clusters))
cluster_names = [f"Cluster {cl}" for cl in cluster_ids]

# --- Panel 1: methylation contribution per cluster ---------------------------------------------
ax = axes[0]
# Nested dict: relabel only the "clusters" column. A set literal such as {"clusters", 0} would
# replace the value 0 in *every* column, including genuine 0 % contributions.
props = mv_scores.replace({"clusters": {0: "Cluster 0", 1: "Cluster 1"}}).boxplot(
    column="Methylation", by='clusters', figsize=(10, 4), grid=False,
    ylabel="% contribution of methylation", showmeans=True, ax=ax, patch_artist=True, return_type='dict')
ax.set_xlabel("")
ax.set_title("")
ax.axhline(mv_scores["Methylation"].mean(), c="red", linestyle=":")
# With `by=`, boxplot returns a Series keyed by column name; take its first entry positionally.
for box, color in zip(props.iloc[0]['boxes'], colors):
    box.set_facecolor(color)
# Mirrored secondary axis: its limits are reversed so it reads as the complementary share.
ax2 = ax.twinx()
pval = round(kruskal(*[mv_scores[mv_scores['clusters'] == cl]["Methylation"] for cl in mv_scores['clusters'].unique()]).pvalue, 5)
ax.set_ylim(0, 100)
ax2.set_ylim(100, 0)
ax2.set_ylabel("% contribution of gene expression")
_ = ax.text(1.3, 50, f"p= {round(pval, 3)}")

# --- Panel 2: methylation ----------------------------------------------------------------------
# `clusters` is a label-indexed Series at this point; `.iloc` rejects such a mask, so rows are
# selected by sample label instead.
vio_met = pd.DataFrame(
    [methylation_data.loc[clusters.index[(clusters == cl).to_numpy()], transformed_Xs[1].columns].mean(0)
     for cl in cluster_ids],
    index=cluster_names).T
sns.violinplot(data=vio_met, orient="h", ax=axes[2], palette=colors)
axes[2].boxplot(vio_met, vert=False, positions=[0, 1], labels=vio_met.columns, patch_artist=True,
                boxprops=dict(facecolor="green"), showmeans=True)

# --- Panel 3: gene expression ------------------------------------------------------------------
vio_rna = pd.DataFrame(
    [np.log2(1 + rnaseq_data.loc[clusters.index[(clusters == cl).to_numpy()], transformed_Xs[0].columns]).mean(0)
     for cl in cluster_ids],
    index=cluster_names).T
sns.violinplot(data=vio_rna, orient="h", ax=axes[4], inner=None, palette=colors)
axes[4].boxplot(vio_rna, vert=False, positions=[0, 1], labels=vio_rna.columns, patch_artist=True,
                boxprops=dict(facecolor="green"), showmeans=True,
                flierprops={'markersize': 1, 'markerfacecolor': 'r'})

fig.subplots_adjust()
_ = fig.suptitle(None)

In [553]:
# Per-sample modality contributions plus one boxplot per modality, grouped by cluster.
mv_scores = [compute_mv_score(shap_values=shap_values, view_names=views, idx=i) for i in range(len(clusters))]
# Index by sample so the assignment of `clusters` (a sample-indexed Series) aligns row by row;
# with a default RangeIndex the alignment would yield only NaN.
mv_scores = pd.DataFrame(mv_scores, index=clusters.index)
mv_scores["clusters"] = clusters
ax = mv_scores.drop(columns="clusters").plot(kind="bar", ylabel="% contribution", figsize=(30, 8), xlabel="Samples",
                                              title="Contribution of each modality to the prediction")
axes = mv_scores.boxplot(by='clusters', figsize=(10, 4), grid=False, ylabel="% contribution", showmeans=True)
# NOTE(review): assumes the boxplot panels come in the same order as `views` - confirm.
for ax, view in zip(axes, views):
    pval = round(kruskal(*[mv_scores[mv_scores['clusters'] == cl][view] for cl in mv_scores['clusters'].unique()]).pvalue, 5)
    _ = ax.text(1.3, 50, f"p-value= {round(pval, 3)}")

In [None]:
# Build the attribution algorithms; the gradient-based ones use the project's gradient function.
ig = IntegratedGradients(model)
gs = GradientShap(model)
for gradient_based in (ig, gs):
    gradient_based.gradient_func = compute_gradients
ig_nt = NoiseTunnel(ig)
fa = FeatureAblationV2(model)

# Attribute the prediction of cluster `target` to the input features with each algorithm.
# GradientShap receives the inputs themselves as its baseline distribution.
ig_attr_test = ig.attribute(inputs, target=target)
ig_nt_attr_test = ig_nt.attribute(inputs, target=target)
gs_attr_test = gs.attribute(inputs, inputs, target=target)
fa_attr_test = fa.attribute(inputs, target=target)

In [850]:
# Compare the attribution scores of the top-25 features with the L1-normalised mean weights of
# the first layer of each view's encoder.
df, most_important_features = compute_most_important_features_based_attribution(
    features=features, algorithms=[ig_attr_test, gs_attr_test, fa_attr_test],
    names=['Int Grads', 'GradientSHAP', 'Feature Ablation'], top_n=25)
# One weight matrix per view: the encoder index must follow the loop variable (it was hard-coded
# to 0, so the first view's weights were used for every view). getattr replaces eval.
encoder_weights = [getattr(model.autoencoder, f"encoder_{enc}")[1][0].weight for enc in range(len(views))]
weights = torch.cat(encoder_weights, dim=1).mean(0).detach().numpy()
df["Weights"] = weights / np.linalg.norm(weights, ord=1)
df = df.loc[most_important_features]
_ = df.plot(kind="bar", figsize=(20, 8), ylabel="Attribution", colormap="tab10")

In [543]:
# Plot the top-10 features per attribution algorithm next to the first-layer encoder weights.
# The encoder index must follow the loop variable (it was hard-coded to 0, so the first view's
# weights were passed for every view). getattr replaces eval.
plot_attribution_algorithm_comparison(
    features=features,
    algorithms=[ig_attr_test, gs_attr_test, fa_attr_test],
    names=['Int Grads', 'GradientSHAP', 'Feature Ablation'],
    weights=[getattr(model.autoencoder, f"encoder_{enc}")[1][0].weight for enc in range(len(views))],
    top_n=10, figsize=(20, 8))

In [837]:
# Boxplots of the raw values of the 25 most important features, split by cluster, annotated with
# FDR-corrected Kruskal-Wallis p-values (red where the null hypothesis is rejected).
df = pd.concat(Xs, axis=1)
_, most_important_features = compute_most_important_features_based_attribution(
    features=features, algorithms=[ig_attr_test, gs_attr_test, fa_attr_test],
    names=['Int Grads', 'GradientSHAP', 'Feature Ablation'], top_n=25)
# Strip only the leading "<view>_" prefix; splitting on every "_" would truncate feature names
# that contain an underscore themselves.
most_important_features = [name.split("_", 1)[-1] for name in most_important_features]
# .copy() so the column assignment below works on an independent frame, not on a slice.
df = df[most_important_features].copy()
df["clusters"] = clusters
df = df.fillna(df.mean())

plt.figure(figsize=(20, 8))
# Nested dict: relabel only the "clusters" column. A set literal such as {"clusters", 0} would
# also turn feature values equal to 0 or 1 into strings.
long_df = pd.melt(df, id_vars='clusters').replace({"clusters": {0: "Cluster 0", 1: "Cluster 1"}})
ax = sns.boxplot(data=long_df, x="variable", y="value", hue="clusters",
                 hue_order=["Cluster 0", "Cluster 1"], showmeans=True, palette=colors)
ax.tick_params(axis='x', rotation=90)
ax.legend(title="", loc="lower right")
ax.set_xlabel("")
ax.set_ylabel("Beta value")
ax.set_title("")

pvals = [kruskal(*[df[df['clusters'] == cl][feature] for cl in df['clusters'].unique()]).pvalue for feature in most_important_features]
c, pvals = fdrcorrection(pvals)
c, pvals = pd.Series(c).apply(lambda x: "red" if x else "black"), pvals.round(3)

for xtick in ax.get_xticks():
    ax.text(xtick-0.2, .98, f"p={pvals[xtick]}", size='x-small', color=c.iloc[xtick])

In [546]:
# Mean 5-fold cross-validated score of a default SVC predicting the cluster from the selected
# features. `cross_val_score` comes from the notebook's import cell (it was never imported).
cv_scores = cross_val_score(SVC(), df[most_important_features], df["clusters"], cv=5)
print("Accuracy score =", cv_scores.mean().round(2))

In [547]:
# Violin plots of the transformed values of the 25 most important features, per cluster.
df = pd.concat(transformed_Xs, axis=1)
_, most_important_features = compute_most_important_features_based_attribution(
    features=features, algorithms=[ig_attr_test, gs_attr_test, fa_attr_test],
    names=['Int Grads', 'GradientSHAP', 'Feature Ablation'], top_n=25)
# Strip only the leading "<view>_" prefix; splitting on every "_" would truncate feature names
# that contain an underscore themselves.
most_important_features = [name.split("_", 1)[-1] for name in most_important_features]
# .copy() so the column assignment below works on an independent frame, not on a slice.
df = df[most_important_features].copy()
df["clusters"] = clusters

cluster_ids = sorted(df["clusters"].unique())
fig, axes = plt.subplots(2, 1, figsize=(8, 10))

# Top: every individual value of the selected features, one violin per cluster.
ax = axes[0]
all_values = pd.DataFrame(
    [df.loc[df["clusters"] == cl].drop(columns="clusters").values.flatten() for cl in cluster_ids],
    index=cluster_ids).T
ax = sns.violinplot(data=all_values, orient="h", ax=ax)
ax.set_title("All feature values by cluster")

# Bottom: per-feature means within each cluster.
ax = axes[1]
mean_values = pd.DataFrame(
    [df.loc[df["clusters"] == cl].drop(columns="clusters").mean(0).values.flatten() for cl in cluster_ids],
    index=cluster_ids).T
ax = sns.violinplot(data=mean_values, orient="h", ax=ax)
_ = ax.set_title("Mean feature values by cluster")

In [231]:
# Table of per-feature means within each cluster (features as rows, clusters as columns).
pd.DataFrame(
    [
        df.loc[df["clusters"] == cl].drop(columns="clusters").mean(0).values.flatten()
        for cl in sorted(df["clusters"].unique())
    ],
    index=sorted(df["clusters"].unique()),
).T

In [260]:
# Display the current feature frame (selected features plus the "clusters" column).
df

In [266]:
# Fixed list of 25 feature identifiers, replacing the computed selection.
most_important_features = [
    "cg10794257", "cg03306374", "cg09656848", "cg27633530", "cg16729415",
    "cg03650946", "cg16856286", "cg04344565", "cg11527326", "cg22674699",
    "cg27058257", "cg20718350", "cg16427096", "cg07589773", "cg01277542",
    "cg07085827", "cg12559197", "cg14004073", "cg16816603", "cg09053680",
    "cg12040830", "cg09493505", "cg20482698", "cg21097881", "cg21039708",
]

# Raw values of those features, labelled by cluster; one violin of per-feature means per cluster.
df = pd.concat(Xs, axis=1)[most_important_features]
df["clusters"] = pd.Series(clusters).replace(0, "Cluster 0").replace(1, "Cluster 1").values
cluster_means = df.groupby("clusters").mean().T
ax = sns.violinplot(data=cluster_means, orient="h")

In [294]:
# Re-display the figure that owns the current `ax`.
ax.figure

In [314]:
# Boxplots of the raw values of the selected features, split by cluster, annotated with
# FDR-corrected Kruskal-Wallis p-values (red where the null hypothesis is rejected).
df = pd.concat(Xs, axis=1)
# .copy() so the column assignment below works on an independent frame, not on a slice.
df = df[most_important_features].copy()
df["clusters"] = clusters
df = df.fillna(df.mean())

plt.figure(figsize=(20, 8))
# Nested dict: relabel only the "clusters" column. A set literal such as {"clusters", 0} would
# also turn feature values equal to 0 or 1 into strings.
long_df = pd.melt(df, id_vars='clusters').replace({"clusters": {0: "Cluster 0", 1: "Cluster 1"}})
ax = sns.boxplot(data=long_df, x="variable", y="value", hue="clusters",
                 hue_order=["Cluster 0", "Cluster 1"], showmeans=True, palette=colors)
ax.tick_params(axis='x', rotation=90)
ax.legend(loc="upper center", ncols=2, bbox_to_anchor=(0.5, 1.06))

ax.set_xlabel("")
ax.set_ylabel("Beta value")
ax.set_title("")

pvals = [kruskal(*[df[df['clusters'] == cl][feature] for cl in df['clusters'].unique()]).pvalue for feature in most_important_features]
c, pvals = fdrcorrection(pvals)
c, pvals = pd.Series(c).apply(lambda x: "red" if x else "black"), pvals.round(3)

for xtick in ax.get_xticks():
    ax.text(xtick-0.2, .98, f"p={pvals[xtick]}", size='x-small', color=c.iloc[xtick])

In [403]:
# Heatmap of Bonferroni-adjusted p-values of the pairwise Pearson correlations between features.
# `pearsonr` is imported in the notebook's import cell.
plt.figure(figsize=(20, 13))
# DataFrame.corr with a callable evaluates it for every column pair; here it returns the p-value.
corr = df.drop(columns="clusters").corr(lambda x, y: pearsonr(x, y)[1])
# Number of distinct feature pairs, n * (n - 1) / 2, used as the Bonferroni factor.
n_features = len(corr)
n_tests = n_features * (n_features - 1) / 2
# Blank only the diagonal (self-comparisons); off-diagonal p-values equal to 1 stay untouched.
adjusted = corr.to_numpy(copy=True)
np.fill_diagonal(adjusted, 0)
# An adjusted p-value is capped at 1.
corr = (pd.DataFrame(adjusted, index=corr.index, columns=corr.columns) * n_tests).clip(upper=1)
ax = sns.heatmap(corr, annot=True, cmap="coolwarm", fmt=".3f", vmax=1)

In [302]:
# Heatmap of the pairwise correlations between the selected features.
plt.figure(figsize=(20, 13))
feature_corr = df.drop(columns="clusters").corr()
# Entries equal to 1 (such as the diagonal) are blanked so they are left out of the heatmap.
corr = feature_corr.replace(1, np.nan)
ax = sns.heatmap(corr, annot=True, cmap="coolwarm", fmt=".2f", vmin=-1, vmax=1)

In [None]:
# Violin plot of the per-feature means within each cluster.
# The cell creates its own axes: it used to draw on `axes[1]` of a figure built by an earlier
# cell, which depends on stale state and shows nothing once that figure has been rendered.
_, ax = plt.subplots(figsize=(8, 5))
cluster_ids = sorted(df["clusters"].unique())
mean_values = pd.DataFrame(
    [df.loc[df["clusters"] == cl].drop(columns="clusters").mean(0).values.flatten() for cl in cluster_ids],
    index=cluster_ids).T
ax = sns.violinplot(data=mean_values, orient="h", ax=ax)
_ = ax.set_title("Mean feature values by cluster")

In [548]:
# Two stacked violin plots of the current feature frame, one violin per cluster.
fig, axes = plt.subplots(2, 1, figsize=(8, 10))

# Top panel: every individual value of the selected features.
ax = sns.violinplot(
    data=pd.DataFrame(
        [df.loc[df["clusters"] == cl].drop(columns="clusters").values.flatten()
         for cl in sorted(df["clusters"].unique())],
        index=sorted(df["clusters"].unique()),
    ).T,
    orient="h",
    ax=axes[0],
)
ax.set_title("All feature values by cluster")

# Bottom panel: per-feature means within each cluster.
ax = sns.violinplot(
    data=pd.DataFrame(
        [df.loc[df["clusters"] == cl].drop(columns="clusters").mean(0).values.flatten()
         for cl in sorted(df["clusters"].unique())],
        index=sorted(df["clusters"].unique()),
    ).T,
    orient="h",
    ax=axes[1],
)
_ = ax.set_title("Mean feature values by cluster")

In [None]:
# Look up the encoder attribute named "encoder_<model.autoencoder.views>".
layer = getattr(model.autoencoder, f"encoder_{model.autoencoder.views}")
# Layer conductance on the second element of that encoder, using the project's `_attribute`
# implementation bound to this instance.
lc = LayerConductance(model, layer[1], device_ids=model.cluster_centers_.device)
lc._attribute = types.MethodType(layerconductance._attribute, lc)

In [851]:
def plot_comparison_attributions_weights_(lc_attr, weights, layer_name = "embedding", figsize=(15, 8)):
    """Draw side-by-side bars of mean layer attributions and mean layer weights per neuron.

    Args:
        lc_attr: 2-D array of attributions; averaged over axis 0, one bar per column.
        weights: 2-D weight array; averaged over axis 1.
        layer_name: accepted for interface compatibility; not used by the body.
        figsize: size of the created matplotlib figure.

    Both profiles are divided by their L1 norm before plotting. Tick labels are drawn only
    when there are at most 100 weight entries.
    """
    def _l1_normalised(vector):
        return vector / np.linalg.norm(vector, ord=1)

    plt.figure(figsize=figsize)
    neuron_positions = np.arange(lc_attr.shape[1])
    attribution_profile = _l1_normalised(lc_attr.mean(0))
    weight_profile = _l1_normalised(weights.mean(1))
    bar_width = 0.25

    ax = plt.subplot()
    ax.bar(neuron_positions + bar_width, attribution_profile, bar_width, align='center', alpha=0.5, color='red')
    ax.bar(neuron_positions + 2 * bar_width, weight_profile, bar_width, align='center', alpha=0.5, color='green')
    ax.set_ylabel("Attribution")
    ax.set_xlabel("Neuron")
    plt.legend(['Attributions', 'Weights'])
    ax.autoscale_view()
    if len(weight_profile) <= 100:
        ax.set_xticks(neuron_positions + 0.5)
        ax.set_xticklabels(list(range(len(weight_profile))))
    plt.show()


# Plot mean layer attributions against the weights of the first element of `layer`.
# NOTE(review): `lc_attr_test` is only computed in the next cell, so this call fails on
# Restart & Run All unless that cell has been executed first.
plot_comparison_attributions_weights_(lc_attr = lc_attr_test.detach().numpy(), weights = layer[0].weight.detach().numpy(), layer_name = "embedding", figsize=(20, 8))

In [549]:
# Layer conductance for cluster `target`, then the project's attribution-vs-weights plot.
lc_attr_test = lc.attribute(inputs=inputs, baselines=baselines, target=target)
plot_comparison_attributions_weights(
    lc_attr=lc_attr_test.detach().numpy(),
    weights=layer[0].weight.detach().numpy(),
    layer_name="embedding",
    figsize=(20, 8),
)

In [550]:
cond_vals = lc_attr_test.detach().numpy()
# Neuron indices ordered by decreasing absolute mean conductance.
sorted_neurons = np.abs(cond_vals.mean(0)).argsort()[::-1].tolist()
# Distribution plot for the three strongest and the three weakest neurons (weakest first).
plot_attribution_distribution(
    cond_vals=cond_vals,
    figsize=(20, 8),
    strong_features=sorted_neurons[:3],
    weak_features=sorted_neurons[-3:][::-1],
)

In [None]:
# Neuron conductance on the same layer, using the project's `_attribute` implementation bound
# to this instance.
neuron_cond = NeuronConductance(model, layer[1], device_ids=model.cluster_centers_.device)
neuron_cond._attribute = types.MethodType(neuronconductance._attribute, neuron_cond)

In [863]:
# Share of each view in the (positive) conductance of every neuron, as a stacked bar chart.

# Which entries of `features` belong to which view; computed once, outside the loop.
view_masks = {view: features.str.startswith(view).to_numpy() for view in views}

# Collect one record per neuron and build the frame once instead of growing it cell by cell.
view_totals = {}
for neuron_selector in sorted_neurons:
    neuron_cond_vals = neuron_cond.attribute(inputs=inputs, baselines=baselines, target=target, neuron_selector=neuron_selector)
    # Mean conductance per input feature, with the views concatenated in feature order.
    importances = torch.cat(neuron_cond_vals, 1).mean(0).detach().numpy()
    view_totals[neuron_selector] = {view: importances[mask].sum() for view, mask in view_masks.items()}
df = pd.DataFrame.from_dict(view_totals, orient="index", columns=views)

# Negative totals count as no contribution; rows summing to zero become NaN and are dropped.
df = df.clip(lower=0)
df = df.div(df.sum(axis=1), axis=0).dropna()
# Scale before rounding: rounding first and casting afterwards truncates values such as
# 0.29 * 100 == 28.999... down to 28.
df = (df * 100).round().astype(int)

ax = df.plot(kind='bar', stacked=True, figsize=(25, 5), xlabel="Neuron", ylabel="% contribution", colormap="Paired")
# Write the percentage in the middle of every bar segment.
for p in ax.patches:
    width, height = p.get_width(), p.get_height()
    x, y = p.get_xy()
    ax.text(x + width / 2,
            y + height / 2,
            '{:.0f} %'.format(height),
            horizontalalignment='center',
            verticalalignment='center')

In [530]:
# Conductance of the strongest neuron with respect to the inputs, and its top-25 features.
neuron_selector = sorted_neurons[0]
neuron_cond_vals = neuron_cond.attribute(
    inputs=inputs, baselines=baselines, target=target, neuron_selector=neuron_selector
)
plot_feature_importance(
    features=features, top_n=25, values=neuron_cond_vals, neuron_selector=neuron_selector, figsize=(20, 6)
)

In [531]:
# Boxplots, split by cluster, of the 25 input features with the highest mean conductance for
# the selected neuron, annotated with FDR-corrected Kruskal-Wallis p-values.
importances = torch.cat(neuron_cond_vals, 1).mean(0).detach().numpy()
most_important_features = features.to_frame("feature")
most_important_features["importance"] = importances
most_important_features = most_important_features.sort_values("importance", ascending=False).iloc[:25]
most_important_features = most_important_features.set_index("feature")
df = pd.concat(transformed_Xs, axis=1)
# Strip only the leading "<view>_" prefix; splitting on every "_" would truncate feature names
# that contain an underscore themselves.
most_important_features = [name.split("_", 1)[-1] for name in most_important_features.index.to_list()]
# .copy() so the column assignment below works on an independent frame, not on a slice.
df = df[most_important_features].copy()
df["clusters"] = clusters

plt.figure(figsize=(20, 8))
ax = sns.boxplot(data=pd.melt(df, id_vars='clusters'), x="variable", y="value", hue="clusters", showmeans=True)
ax.set_title(f"Boxplots for top input features for neuron {neuron_selector} grouped by clusters")
ax.tick_params(axis='x', rotation=90)
ax.legend(title="clusters", loc="lower right")

pvals = [kruskal(*[df[df['clusters'] == cl][feature] for cl in df['clusters'].unique()]).pvalue for feature in most_important_features]
c, pvals = fdrcorrection(pvals)
# Red marks features for which the null hypothesis is rejected after FDR correction.
c, pvals = pd.Series(c).apply(lambda x: "red" if x else "black"), pvals.round(3)

for xtick in ax.get_xticks():
    ax.text(xtick-0.2, 3.3, pvals[xtick], size='x-small', color=c.iloc[xtick])

In [532]:
# Mean 5-fold cross-validated score of a default SVC predicting the cluster from the selected
# features. `cross_val_score` comes from the notebook's import cell (it was never imported).
cv_scores = cross_val_score(SVC(), df[most_important_features], df["clusters"], cv=5)
print("Accuracy score =", cv_scores.mean().round(2))

In [533]:
# Conductance of the second strongest neuron with respect to the inputs, and its top-25 features.
neuron_selector = sorted_neurons[1]
neuron_cond_vals = neuron_cond.attribute(
    inputs=inputs, baselines=baselines, target=target, neuron_selector=neuron_selector
)
plot_feature_importance(
    features=features, top_n=25, values=neuron_cond_vals, neuron_selector=neuron_selector, figsize=(20, 6)
)

In [534]:
# Boxplots, split by cluster, of the 25 input features with the highest mean conductance for
# the selected neuron, annotated with FDR-corrected Kruskal-Wallis p-values.
importances = torch.cat(neuron_cond_vals, 1).mean(0).detach().numpy()
most_important_features = features.to_frame("feature")
most_important_features["importance"] = importances
most_important_features = most_important_features.sort_values("importance", ascending=False).iloc[:25]
most_important_features = most_important_features.set_index("feature")
df = pd.concat(transformed_Xs, axis=1)
# Strip only the leading "<view>_" prefix; splitting on every "_" would truncate feature names
# that contain an underscore themselves.
most_important_features = [name.split("_", 1)[-1] for name in most_important_features.index.to_list()]
# .copy() so the column assignment below works on an independent frame, not on a slice.
df = df[most_important_features].copy()
df["clusters"] = clusters

plt.figure(figsize=(20, 8))
ax = sns.boxplot(data=pd.melt(df, id_vars='clusters'), x="variable", y="value", hue="clusters", showmeans=True)
ax.set_title(f"Boxplots for top input features for neuron {neuron_selector} grouped by clusters")
ax.tick_params(axis='x', rotation=90)
ax.legend(title="clusters", loc="lower right")

pvals = [kruskal(*[df[df['clusters'] == cl][feature] for cl in df['clusters'].unique()]).pvalue for feature in most_important_features]
c, pvals = fdrcorrection(pvals)
# Red marks features for which the null hypothesis is rejected after FDR correction.
c, pvals = pd.Series(c).apply(lambda x: "red" if x else "black"), pvals.round(3)

for xtick in ax.get_xticks():
    ax.text(xtick-0.2, 5.3, pvals[xtick], size='x-small', color=c.iloc[xtick])

In [535]:
# Mean 5-fold cross-validated score of a default SVC predicting the cluster from the selected
# features. `cross_val_score` comes from the notebook's import cell (it was never imported).
cv_scores = cross_val_score(SVC(), df[most_important_features], df["clusters"], cv=5)
print("Accuracy score =", cv_scores.mean().round(2))

In [536]:
# Conductance of the third strongest neuron with respect to the inputs, and its top-25 features.
neuron_selector = sorted_neurons[2]
neuron_cond_vals = neuron_cond.attribute(
    inputs=inputs, baselines=baselines, target=target, neuron_selector=neuron_selector
)
plot_feature_importance(
    features=features, top_n=25, values=neuron_cond_vals, neuron_selector=neuron_selector, figsize=(20, 6)
)

In [537]:
# Boxplots, split by cluster, of the 25 input features with the highest mean conductance for
# the selected neuron, annotated with FDR-corrected Kruskal-Wallis p-values.
importances = torch.cat(neuron_cond_vals, 1).mean(0).detach().numpy()
most_important_features = features.to_frame("feature")
most_important_features["importance"] = importances
most_important_features = most_important_features.sort_values("importance", ascending=False).iloc[:25]
most_important_features = most_important_features.set_index("feature")
df = pd.concat(transformed_Xs, axis=1)
# Strip only the leading "<view>_" prefix; splitting on every "_" would truncate feature names
# that contain an underscore themselves.
most_important_features = [name.split("_", 1)[-1] for name in most_important_features.index.to_list()]
# .copy() so the column assignment below works on an independent frame, not on a slice.
df = df[most_important_features].copy()
df["clusters"] = clusters

plt.figure(figsize=(20, 8))
ax = sns.boxplot(data=pd.melt(df, id_vars='clusters'), x="variable", y="value", hue="clusters", showmeans=True)
ax.set_title(f"Boxplots for top input features for neuron {neuron_selector} grouped by clusters")
ax.tick_params(axis='x', rotation=90)
ax.legend(title="clusters", loc="lower right")

pvals = [kruskal(*[df[df['clusters'] == cl][feature] for cl in df['clusters'].unique()]).pvalue for feature in most_important_features]
c, pvals = fdrcorrection(pvals)
# Red marks features for which the null hypothesis is rejected after FDR correction.
c, pvals = pd.Series(c).apply(lambda x: "red" if x else "black"), pvals.round(3)

for xtick in ax.get_xticks():
    ax.text(xtick-0.2, 5.3, pvals[xtick], size='x-small', color=c.iloc[xtick])

In [538]:
# Mean 5-fold cross-validated score of a default SVC predicting the cluster from the selected
# features. `cross_val_score` comes from the notebook's import cell (it was never imported).
cv_scores = cross_val_score(SVC(), df[most_important_features], df["clusters"], cv=5)
print("Accuracy score =", cv_scores.mean().round(2))