In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

In [None]:
import matplotlib
# Figure defaults for the paper. pdf.fonttype 42 selects Type 42 (TrueType)
# font embedding for PDF output.
matplotlib.rcParams.update({
    'pdf.fonttype': 42,
    'font.family': ['arial'],
    'font.size': 6,
})

# NOTE(review): sns.set_theme() applies its own context/font rcParams after the
# update above, so the 'font.family' and 'font.size' values set there are
# probably overridden — confirm which values are actually intended.
sns.set_theme(
    context='paper',
    style='white',
    palette="Paired",
    font='arial',
    font_scale=1.0,
)

In [None]:
# Results table for the GBT runs (presumably cross-validation output, going by
# the file name). The first CSV column is read as the index and then turned
# back into a regular column by reset_index().
data = pd.read_csv(
    'models/regression_gbt_all_data_combos_cv_5v5.csv',
    index_col=0,
).reset_index()

In [None]:
# Results table covering all model types (the `model` column is filtered on
# 'EN' / 'gbt' further down). Same index handling as for `data`.
data2 = pd.read_csv(
    'models/regression_all_models_all_data_combos_cv_5v5.csv',
    index_col=0,
).reset_index()

In [None]:
len(data.drug_name.unique())

In [None]:
data.head()

In [None]:
data2.head()

In [None]:
data.drug_name.unique()

In [None]:
data.model.unique()

In [None]:
# results_syn_id = 'syn27091721'
# all_data = load_table(results_syn_id)

In [None]:
# full_data = pd.concat([data, data2])

In [None]:
full_data = pd.read_csv("drug_response_regression_model_features.csv", )

In [None]:
#full_data.to_csv("drug_response_regression_model_features.csv")

In [None]:
def plot_metric_by_drug(
    data_set, x='spearman', y='drug_name',
    save_name='', hue=None,
    figsize=(6,12),
    sort_index=None
):
    """Boxen plot with overlaid points of metric ``x`` for each category ``y``.

    Note the axis mapping: the categories in ``y`` are drawn along the
    horizontal axis and the metric ``x`` along the vertical axis.

    Parameters
    ----------
    data_set : pandas.DataFrame
        Long-form table containing the ``x`` and ``y`` columns (and ``hue``
        when given).
    x : str
        Name of the metric column.
    y : str
        Name of the categorical column.
    save_name : str
        Path stem for the output; ``<save_name>.png`` and ``<save_name>.pdf``
        are written. When empty, nothing is saved (previously this wrote
        files literally named ``.png`` and ``.pdf``).
    hue : str or None
        Optional column used to split every category into sub-groups.
    figsize : tuple
        Figure size passed to matplotlib.
    sort_index : sequence or None
        Category order. Defaults to the categories sorted by ascending mean
        of ``x``.

    Returns
    -------
    The category order that was used, so that a follow-up plot can reuse it.
    """
    if sort_index is None:
        sort_index = data_set.groupby(y)[x].mean().sort_values().index.values

    hue_order = None
    if hue is not None:
        # Hue levels are ordered by ascending mean of the metric as well.
        hue_order = data_set.groupby(hue)[x].mean().sort_values().index.values

    fig, ax = plt.subplots(figsize=figsize)

    # Arguments common to both layers, so the two calls cannot drift apart.
    shared = dict(
        data=data_set,
        x=y,
        y=x,
        hue=hue,
        ax=ax,
        order=sort_index,
        hue_order=hue_order,
    )
    sns.boxenplot(k_depth='full', **shared)
    sns.stripplot(dodge=True, size=2, color='black', alpha=.5, **shared)

    # Put the legend out of the figure
    if hue is not None:
        handles, labels = ax.get_legend_handles_labels()
        # Only the first half of the entries is kept, presumably because both
        # layers register one legend entry per hue level.
        n_keep = len(handles) // 2
        ax.legend(
            handles=handles[:n_keep], labels=labels[:n_keep],
            bbox_to_anchor=(1.25, 1), loc=0, borderaxespad=0.
        )
    ax.tick_params(axis='x', rotation=90)

    if save_name:
        for extension in ('png', 'pdf'):
            fig.savefig(f"{save_name}.{extension}", dpi=300, bbox_inches='tight')
    return sort_index

In [None]:
plot_metric_by_drug(
    data, 'spearman', 'drug_name', 'model_performance', hue=None,
    figsize=(12,12)
)

In [None]:
# Per-drug Spearman performance of the GBT rows in `data` for one data type.
# `i` is kept as a name because it is also interpolated into the file name.
i = 'phospho_proteomics'
is_gbt_for_type = (data.data_type == i) & (data.model == 'gbt')
phospho_proteomics = data.loc[is_gbt_for_type]
plot_metric_by_drug(
    phospho_proteomics,
    x='spearman',
    y='drug_name',
    save_name=f'model_performance_{i}_org',
    figsize=(8,3),
)

In [None]:
# Drugs of interest; the filter that would use this list is currently disabled
# (see the commented-out line below).
drugs_to_focus = [
#     'Gilteritinib',
#     'Quizartinib (AC220)',
#     'Trametinib (GSK1120212)',
    'Panobinostat',
    'Sorafenib',
    'Venetoclax',
]

# Stack the GBT-only table (relabelled 'gbt_original') on top of the all-models
# table, then keep phospho_proteomics rows for the EN and gbt models only.
d1 = data.copy()
d1['model'] = 'gbt_original'
fd = pd.concat([d1, data2.copy()])
# fd = fd.loc[fd.drug_name.isin(drugs_to_focus)]

keep_rows = (
    fd.data_type.isin(['phospho_proteomics'])
    & fd.model.isin(['EN', 'gbt'])
)
fd = fd.loc[keep_rows]

# `si` holds the drug order so the n_feat plot can reuse it.
si = plot_metric_by_drug(
    fd,
    x='spearman',
    y='drug_name',
    save_name='compare_models',
    hue='model',
    figsize=(12,4),
)

In [None]:
plot_metric_by_drug(
    fd,
    'n_feat', 
    'drug_name',
    'compare_models_n_feat', 
    hue='model', 
    figsize=(12, 4),
    sort_index=si
);

In [None]:
# Compare the GBT rows from `data` (relabelled 'gbt_original') with the 'gbt'
# rows from `data2` on phospho_proteomics.
d1 = data.copy()
d1['model'] = 'gbt_original'

# Both inputs went through reset_index(), so a plain concat would yield
# duplicate row labels; since rows from both sources survive the filter
# below, renumber the rows so every label is unique.
fd = pd.concat([d1, data2], ignore_index=True)
fd = fd.loc[fd.data_type.isin(['phospho_proteomics'])]
fd = fd.loc[fd.model.isin(['gbt_original', 'gbt'])]

si = plot_metric_by_drug(
    fd,
    'spearman',
    'drug_name',
    'compare_gbt_lambda',
    hue='model',
    figsize=(4,12)
)

In [None]:
# GBT rows of `data2` for a single data type; `i` is reused by the next cell
# when building the output file name.
i = 'phospho_proteomics'
is_gbt_for_type = (data2.data_type == i) & (data2.model == 'gbt')
phospho_proteomics = data2.loc[is_gbt_for_type]

In [None]:
plot_metric_by_drug(
    phospho_proteomics, 
    'spearman', 
    'drug_name',
    f'model_performance_{i}_2',
    figsize=(4,8)
)

In [None]:
i = 'rna_seq'
rna_seq = data.loc[data.data_type == i]
plot_metric_by_drug(
    rna_seq, 
    'spearman', 
    'drug_name',
    f'model_performance_{i}',
)

In [None]:
phospho_proteomics

In [None]:
# Mean spearman / sr per drug, best spearman first. The columns are selected
# before aggregating so that .mean() never touches the non-numeric columns
# (averaging the whole frame raises TypeError on pandas >= 2.0).
phospho_proteomics.groupby('drug_name')[['spearman', 'sr']].mean().sort_values(by='spearman', ascending=False)

In [None]:
# Mean sr per drug in ascending order. The column is selected before
# aggregating so that .mean() never touches the non-numeric columns
# (averaging the whole frame raises TypeError on pandas >= 2.0).
phospho_proteomics.groupby('drug_name')['sr'].mean().sort_values()

In [None]:
# Drug(s) shown in the per-data-type comparison; the other candidates are kept
# here, commented out, for quick switching.
drugs_to_focus = [
#     'Gilteritinib',
#     'Quizartinib (AC220)',
#     'Trametinib (GSK1120212)',
#     'Panobinostat',
#     'Sorafenib',
    'Venetoclax',
]

# Keep the selected drug(s), then drop every data type whose name contains 'wes'.
in_focus = data.drug_name.isin(drugs_to_focus)
df_subset = data.loc[in_focus]
has_wes = df_subset.data_type.str.contains('wes')
df_subset = df_subset.loc[~has_wes].copy()

plot_metric_by_drug(
    df_subset,
    'spearman',
    'drug_name',
    'model_performance_drugs_subset',
    hue='data_type',
    figsize=(3,4),
)
# The legend is re-created and the figure saved a second time under another name.
# NOTE(review): 'performace' in the file name looks like a typo for
# 'performance'; left unchanged because other files may rely on this name.
plt.legend(bbox_to_anchor=(1.05, 1), loc=0, borderaxespad=0.);
plt.savefig('drugs_of_interest_performace.pdf', dpi=300, bbox_inches='tight')

In [None]:
# Number of features per data type for the same subset. This needs its own
# save_name: reusing 'model_performance_drugs_subset' overwrote the spearman
# figure written by the previous cell.
plot_metric_by_drug(
    df_subset,
    x='n_feat',
    y='drug_name',
    save_name='model_performance_drugs_subset_n_feat',
    hue='data_type',
    figsize=(4,7)
)

In [None]:
sns.scatterplot(x='n_feat', y='spearman', data=phospho_proteomics)

In [None]:
sns.scatterplot(x='n_feat', y='spearman', data=phospho_proteomics)

In [None]:
# All Venetoclax rows, plus the data types ordered by ascending mean spearman.
venetoclax = data.loc[data.drug_name == 'Venetoclax'].copy()
sort_index = (
    venetoclax
    .groupby('data_type')['spearman']
    .mean()
    .sort_values()
    .index
    .values
)

In [None]:
venetoclax

In [None]:
def plot_indidvidual_drug(df, drug_name, prefix, x='spearman', y='data_type'):
    """Annotated, row-clustered heatmap of one metric for a single drug.

    Rows of ``df`` for ``drug_name`` are pivoted into a ``y`` by ``k`` table of
    the mean of ``x``; a ``'mean'`` column holding each row's average is
    appended and the table is drawn with ``sns.clustermap`` (rows clustered,
    columns not). Rows whose ``data_type`` contains ``'wes'`` are excluded.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``drug_name``, ``data_type``, ``k``, ``x`` and ``y``.
    drug_name : str
        Value of the ``drug_name`` column to plot. Only the part before the
        first space is used in the output file name.
    prefix : str
        Prefix of the output file name.
    x : str
        Metric column. For ``'n_feat'`` the mean column is cast to int and the
        cells are annotated with the ``'g'`` format instead of ``'.02f'``.
    y : str
        Column used for the heatmap rows (``'data_type'`` by default).

    The figure is saved as ``{prefix}_{drug}_{x}_indepth_by_data.png``. The
    metric is part of the name so that plotting several metrics with the same
    prefix and drug no longer overwrites the previous file.
    """
    subset = df.loc[df.drug_name == drug_name]
    subset = subset.loc[~subset.data_type.str.contains('wes')]

    per_k = pd.pivot_table(subset, index=y, columns='k', values=x, aggfunc='mean')
    per_k['mean'] = per_k.mean(axis=1)

    fmt = '.02f'
    if x == 'n_feat':
        # Feature counts read better as whole numbers.
        per_k['mean'] = per_k['mean'].astype(int)
        fmt = 'g'

    sns.clustermap(
        data=per_k,
        row_cluster=True,
        col_cluster=False,
        annot=True,
        fmt=fmt,
        linewidths=0.01,
        figsize=(14, 4),
        cmap='Reds'
    )
    plt.tight_layout();

    # Drop anything after the first space, e.g. 'Quizartinib (AC220)' -> 'Quizartinib'.
    short_name = drug_name.split(' ')[0]
    plt.savefig(f"{prefix}_{short_name}_{x}_indepth_by_data.png", dpi=300, bbox_inches='tight')

In [None]:
plot_indidvidual_drug(data, 'Venetoclax', 'k_grid_results', 'spearman', )
plot_indidvidual_drug(data, 'Venetoclax', 'k_grid_results', 'n_feat', )

In [None]:
plot_indidvidual_drug(data, 'Gilteritinib', 'd1', 'spearman', )
plot_indidvidual_drug(data, 'Gilteritinib', 'd1', 'n_feat', )

In [None]:
data.loc[data.data_type.str.contains('phospho_proteomics_rna_seq_wes')].feature_names.str.split('|').apply(sorted)

In [None]:
fd

In [None]:
ven_pivot = pd.pivot_table(fd, index=['data_type', 'model'], values='spearman', columns='drug_name', aggfunc='mean')
ven_pivot

In [None]:
ax = sns.clustermap(
    data=ven_pivot,
    row_cluster=True, 
    col_cluster=False,
    annot=True,
    fmt='0.3f',
    linewidths=0.01,
    figsize=(12,12),
    cmap='Reds'

)
# Put the legend out of the figure
# plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.);
plt.savefig("venetoclax_indepth_by_data.png", dpi=300, bbox_inches='tight')

In [None]:
from collections import Counter
from functools import reduce

def get_feature_counts(features):
    """Count feature names across an iterable of '|'-separated strings.

    Each element of ``features`` is split on '|' and every resulting name is
    tallied. Returns a ``Counter`` mapping name -> number of occurrences.
    """
    totals = Counter()
    for joined_names in features:
        totals.update(joined_names.split('|'))
    return totals

def get_feature_matrix(df):
    """Build a feature x (data_type, drug_name) table of feature counts.

    ``df`` is grouped by ``data_type`` and ``drug_name``; for every group the
    names in its ``feature_names`` column are counted with
    ``get_feature_counts``. Each group becomes one column named
    ``'<data_type>_<drug_name>'``; features absent from a group get 0.
    """
    per_group = [
        pd.Series(get_feature_counts(group.feature_names)).to_frame(name='_'.join(key))
        for key, group in df.groupby(['data_type', 'drug_name'])
    ]
    return pd.concat(per_group, axis=1).fillna(0)

In [None]:
# Venetoclax rows of `data2` for the gbt model, excluding the 'wes' data type,
# turned into a feature-count matrix.
is_venetoclax_gbt = (
    (data2.drug_name == 'Venetoclax')
    & (data2.model == 'gbt')
    & (data2.data_type != 'wes')
)
venetoclax = data2.loc[is_venetoclax_gbt].copy()
all_counts = get_feature_matrix(venetoclax)
# Counts for the proteomics-only column, in ascending order.
all_counts['proteomics_Venetoclax'].sort_values()

In [None]:
all_counts = all_counts[[i for i in all_counts.columns if 'wes' not in i]]

In [None]:
all_counts

In [None]:
# For every column of `all_counts`, collect the features counted more than once.
#   feature_output_raw: sorted feature names as they appear in the index
#   feature_output:     sorted unique names after stripping the suffix — the
#                       part before the first '_' for '_prot' / '_rna' / '_mut'
#                       names, otherwise the part before the first '-'
feature_output = {}
feature_output_raw = {}
tagged_markers = ('_prot', '_rna', '_mut')

for column in all_counts.columns:
    if column == 'wes_Venetoclax':
        continue
    counts = all_counts[column]
    genes = sorted(counts[counts > 1].index.values)
    feature_output_raw[column] = genes
    feature_output[column] = sorted({
        name.split('_')[0]
        if any(marker in name for marker in tagged_markers)
        else name.split('-')[0]
        for name in genes
    })

for column, cleaned in feature_output.items():
    print(column, len(cleaned),)# feature_output[i], '\n')

In [None]:
from magine.enrichment.enrichr import Enrichr
e = Enrichr()

In [None]:
enrichment_results_vent = e.run_samples(
    list(feature_output.values()),
    list(feature_output.keys()), 
    gene_set_lib='Reactome_2022'
)

In [None]:
enrichment_results_vent.term_name = enrichment_results_vent.term_name.str.split('r-hsa').str.get(0)

In [None]:
enrichment_results_vent.sample_id = enrichment_results_vent.sample_id.str.strip('_Venetoclax')

In [None]:
enrichment_results_vent.remove_redundant(level='dataframe', threshold=.5).heatmap(
    min_sig=1,
    y_tick_labels=True, 
    cluster_row=True, 
    cluster_col=False,
    figsize=(8, 14)
    
)
plt.savefig('venetoclax_feature_enrichment_all_data_compare2.pdf', bbox_inches='tight')

In [None]:
count_of_sig_terms = enrichment_results_vent.pivot_table('significant', 'sample_id', aggfunc='sum')
count_of_sig_terms.reset_index(inplace=True)


for i in feature_output:
    n_genes = len(feature_output[i])
    i = i.rstrip('_Venetoclax')
    count_of_sig_terms.loc[count_of_sig_terms.sample_id==i, 'n_genes'] = int(n_genes)
count_of_sig_terms.n_genes= count_of_sig_terms.n_genes.astype(int)
count_of_sig_terms.sort_values('n_genes', inplace=True)
count_of_sig_terms.loc[count_of_sig_terms.sample_id=='phosph', 'sample_id'] = 'phospho'

In [None]:
count_of_sig_terms

In [None]:
sort_index = [
    'phospho','proteomics', 'rna_seq',
    'phospho_proteomics', 
    'phospho_rna_seq',
    'proteomics_rna_seq',
    'phospho_proteomics_rna_seq',
]
vent_sub = venetoclax.loc[venetoclax.data_type.isin(sort_index)]
vent_sub

In [None]:
sort_index = vent_sub.groupby('data_type')['spearman'].mean().sort_values().index.values

In [None]:
# Three panels sharing the data-type axis:
#   1) spearman distribution, 2) number of model features,
#   3) number of significantly enriched terms.
f, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(6, 2), sharey=True)

# Panel 1: boxen plot with the individual points drawn on top.
sns.boxenplot(
    data=vent_sub,
    x="spearman",
    y="data_type",
    order=sort_index,
    k_depth='full',
    ax=ax1,
)
g = sns.swarmplot(
    data=vent_sub,
    x="spearman",
    y="data_type",
    order=sort_index,
    color='black',
    alpha=0.7,
    size=3,
    ax=ax1,
)
g.set_xlabel('Spearman $\\rho$')
g.set_ylabel('')

# Panels 2 and 3: horizontal bars taken from count_of_sig_terms.
bar_panels = [
    (ax2, 'n_genes', 'Number of model features'),
    (ax3, 'significant', 'Number of significantly\n enriched terms'),
]
for panel_ax, column, x_label in bar_panels:
    g = sns.barplot(
        data=count_of_sig_terms,
        y='sample_id',
        x=column,
        orient='h',
        order=sort_index,
        ax=panel_ax,
    )
    g.set_xlabel(x_label)
    g.set_ylabel('')

plt.savefig('barplot_venetoclax_enriched_terms2.pdf', bbox_inches='tight')
# g.set_xticklabels(g.get_xticklabels(), rotation=45);

In [None]:
# Keep features counted more than 9 times in more than 3 columns, then order
# them by total count, highest first.
# NOTE(review): despite the name, nothing here limits the result to 35 features.
n_columns_over_9 = (all_counts > 9).sum(axis=1)
all_counts2 = all_counts[n_columns_over_9 > 3]
top_35_features = (
    all_counts2
    .sum(axis=1)
    .sort_values(ascending=False)
    .index
    .values
)

In [None]:
all_counts.T[top_35_features].sum()

In [None]:
top_features = all_counts.T[top_35_features]

In [None]:
top_features.shape

In [None]:
all_counts.T['BCL2_prot'].sort_values()

In [None]:
all_counts.columns

## Plot GBT features for venetoclax

In [None]:
sns.clustermap(top_features.T, col_cluster=True, row_cluster=True, method='ward',
               cmap=sns.color_palette("Reds"),
               figsize=(4,10),
              linewidth=.0,
              yticklabels=True);

In [None]:
subset = all_counts[all_counts.sum(axis=1) > 5]
subset.shape

In [None]:
all_counts.head()

In [None]:
ref = all_counts.index.values.copy()
ref

In [None]:
# One colour per feature, chosen from the name suffix:
# '_rna' -> gold, '_prot' -> blue, anything else -> red.
d = pd.DataFrame(ref, columns=['name'])
d['source'] = np.select(
    [d['name'].str.endswith('_rna'), d['name'].str.endswith('_prot')],
    ['gold', 'blue'],
    default='red',
)
colors = d.set_index('name').sort_index()
colors.head()

In [None]:
all_counts.sort_index(inplace=True)
all_counts.head(10)

In [None]:
# Blank out (NaN) the counts of features that a data-type combination cannot
# contain, judged by the feature-name suffix: '_prot' -> proteomics,
# '_rna' -> rna_seq, anything else -> phospho (the same convention the
# `colors` table uses).
#
# Fix: the last mask used to be applied to 'phospho_rna_seq_Venetoclax' a
# second time, which wiped that column's phospho rows and left
# 'proteomics_rna_seq_Venetoclax' unmasked. Each combination now keeps exactly
# the sources named in its column. The all-sources column
# 'phospho_proteomics_rna_seq_Venetoclax' needs no masking.
is_prot = all_counts.index.str.endswith('_prot')
is_rna = all_counts.index.str.endswith('_rna')
is_phospho = ~(is_prot | is_rna)

rows_to_keep = {
    'phospho_Venetoclax': is_phospho,
    'proteomics_Venetoclax': is_prot,
    'rna_seq_Venetoclax': is_rna,
    'phospho_proteomics_Venetoclax': is_phospho | is_prot,
    'phospho_rna_seq_Venetoclax': is_phospho | is_rna,
    'proteomics_rna_seq_Venetoclax': is_prot | is_rna,
}

# Fail loudly on a missing column; .loc assignment would otherwise silently
# create it.
missing_columns = [name for name in rows_to_keep if name not in all_counts.columns]
if missing_columns:
    raise KeyError(missing_columns)

for column, keep in rows_to_keep.items():
    all_counts.loc[~keep, column] = np.nan




In [None]:
# Order the rows by their colour label (the 'source' column of `colors`) using
# a temporary 'List' column that is dropped again afterwards.
all_counts = (
    all_counts
    .assign(List=colors['source'])
    .sort_values(by=['List'])
    .drop(columns='List')
)

In [None]:
# Fixed column order for the heatmap: single data types first, then pairs,
# then the combination of all three.
venetoclax_columns = [
    'phospho_Venetoclax',
    'proteomics_Venetoclax',
    'rna_seq_Venetoclax',
    'phospho_proteomics_Venetoclax',
    'phospho_rna_seq_Venetoclax',
    'proteomics_rna_seq_Venetoclax',
    'phospho_proteomics_rna_seq_Venetoclax',
]
all_counts = all_counts[venetoclax_columns]

In [None]:
from matplotlib.patches import Patch

# Heatmap of the feature counts (data-type combinations as rows, features as
# columns) with a colour bar marking the source of every feature. Clustering
# is switched off, so the existing row/column order is shown.
g = sns.clustermap(
    all_counts.T,
    col_colors=colors,
    col_cluster=False, row_cluster=False,
    method='ward',
    xticklabels=False, yticklabels=True,
    figsize=(8, 2),
    linewidths=0.0,
    cmap=sns.color_palette('rocket', n_colors=26 )
);

# Legend for the colour bar. The labels must follow the assignment made when
# `colors` was built: '_prot' -> blue, '_rna' -> gold, everything else
# (phospho) -> red. The 'rna' and 'proteomics' labels were previously swapped.
leg = plt.legend(
        [Patch(facecolor='blue'), Patch(facecolor='gold'), Patch(facecolor='red')],
        ['proteomics', 'rna', 'phospho'],
        title='Data type',
        ncol=3,
        bbox_to_anchor=(.575, 1.1),
        bbox_transform=plt.gcf().transFigure,
        loc='upper right'
    )
plt.gca().add_artist(leg)
# plt.setp(g.ax_heatmap.get_xticklabels(), rotation=45)
plt.savefig("features_all.pdf", dpi=300, bbox_inches='tight')
plt.savefig("features_all.png", dpi=300, bbox_inches='tight')