In [None]:
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
import pandas as pd
import numpy as np
from matplotlib.patches import Patch

import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import pearsonr, spearmanr
import pickle
from itertools import combinations
import scikit_posthocs as sp

from statannotations.Annotator import Annotator
from itertools import combinations
from scipy.stats import ttest_ind
from statsmodels.stats.multitest import fdrcorrection
from itertools import combinations

In [None]:
# Figure formatting for the paper.
# pdf.fonttype 42 embeds TrueType fonts so text in exported PDFs stays editable.
import matplotlib
matplotlib.rcParams.update({
    'pdf.fonttype': 42,
    'font.family': ['arial'],
    'font.size': 6,
})

sns.set_theme(context='paper', palette="Paired", style='white', font='arial', font_scale=1.0)

In [None]:
from load_data import AMLData, load_drug_response, cluster_colors, data_colors
from load_data import drug_response_id, load_table

In [None]:
data = AMLData()

In [None]:
data.auc_table

In [None]:
data.wes.sample_id.unique()

In [None]:
data.drug_names

In [None]:
# removes combo drugs
drug_names = [i for i in data.drug_names if ' - ' not in i]

In [None]:
data.auc_table['Cytarabine']

## Heatmaps of AUC values


In [None]:
# Sample x drug AUC matrix restricted to single-agent drugs.
table = data.auc_table.loc[:, drug_names].copy()
table.head()

In [None]:
def plot_and_save(
    matrix,
    save_name,
    yticklabels=False,
    xticklabels=False,
    use_predicted_cluster=False,
    save=False):
    """Clustered heatmap of a sample x drug AUC matrix with a subtype colour bar.

    The matrix is transposed so drugs are rows and samples are columns. Both axes are
    Ward-clustered. Missing AUCs are masked in the plot; they are filled with 100 only
    so that clustering can run.

    Parameters
    ----------
    matrix : DataFrame
        Samples as rows, drugs as columns (AUC values).
    save_name : str
        File stem used when ``save`` is True.
    yticklabels, xticklabels : bool
        Passed through to ``sns.clustermap``.
    use_predicted_cluster : bool
        Colour samples by ``data.auc_table['Cluster']`` (includes predicted labels)
        instead of ``data.auc_table['k=4']`` (the original 159-sample clustering).
    save : bool, default False
        If True, write ``figures/drug_response_by_cluster/{save_name}_{k}.pdf`` and ``.png``.
    """
    k = 'Cluster' if use_predicted_cluster else 'k=4'

    # Build a new Series instead of calling dropna(inplace=True) on a column pulled
    # out of the shared data.auc_table.
    node_labels = data.auc_table[k].dropna().astype(int)

    node_lut = dict(zip(sorted(node_labels.unique()), cluster_colors))
    # Re-index to the metadata sample order; samples without a label map to NaN.
    node_colors = pd.Series(node_labels, index=data.meta.index.values, name='Subtype').map(node_lut)
    handles = [Patch(facecolor=node_lut[name]) for name in node_lut]

    matrix = matrix.T
    cmap = sns.color_palette("rocket", as_cmap=True, n_colors=7)

    g = sns.clustermap(
        matrix.fillna(100),
        yticklabels=yticklabels,
        xticklabels=xticklabels,
        row_cluster=True,
        col_cluster=True,
        method='ward',
        col_colors=node_colors,
        mask=matrix.isnull(),
        cmap=cmap,
        figsize=(5, 5)
    )

    plt.legend(
        handles, node_lut, title='Subtype',
        bbox_to_anchor=(1.1, .8),
        bbox_transform=plt.gcf().transFigure,
        loc='upper right'
    )
    g.ax_heatmap.set_ylabel('drug')
    g.ax_heatmap.set_yticklabels(g.ax_heatmap.get_yticklabels(), size=6)

    if save:
        for ext in ('pdf', 'png'):
            plt.savefig(f"figures/drug_response_by_cluster/{save_name}_{k}.{ext}", bbox_inches='tight')

In [None]:
plot_and_save(table, 'AUC_heatmap', yticklabels=False)

## Number of samples screened per drug

In [None]:
drug_counts = table.describe().T['count']
drug_counts.sort_values().tail(5)

In [None]:
sns.displot(drug_counts, bins=100, kind='hist');
plt.title("# samples per drug")
plt.savefig("n_drug_count_by_sample.png", bbox_inches='tight')
plt.close()

In [None]:
high_occ_drugs = drug_counts[drug_counts>=100].index.values
plot_and_save(
    table[high_occ_drugs], 
    'AUC_remove_low_drug_occ', 
    yticklabels=False, 
    xticklabels=False
)


# Drugs with AUC ≤ 100 in at least 10 samples (restricted to drugs screened in at least 100 samples)

In [None]:
high_occ_drugs

In [None]:
low_auc_counts = table[table <= 100].count().sort_values()
drugs = low_auc_counts[low_auc_counts >= 10].index.values
len(drugs)

In [None]:
good_drugs = set(drugs).intersection(high_occ_drugs)

In [None]:
good_drugs

In [None]:
final_table = []
final_df = table[good_drugs].copy()
final_table.append(final_df.count())
final_table.append(final_df[final_df<=100].count())
final_table = pd.DataFrame(final_table).T
final_table.rename({0 : 'samples', 1: 'sens_samples'}, axis=1, inplace=True)

In [None]:
final_table.to_csv('number_patients_per_drug.csv')

In [None]:
final_table.sort_values('sens_samples')

In [None]:
plot_and_save(
    final_df,
    'AUC_final_table',
    yticklabels=True, 
    xticklabels=False
)

In [None]:
plot_and_save(
    final_df,
    'AUC_final_table_with_predicted_clusters',
    yticklabels=True, 
    xticklabels=False, 
    use_predicted_cluster=True
)

## Look at individual drugs

In [None]:
mutation_table = pd.pivot_table(
    data.wes.loc[data.wes.source=='wes'],
    columns='label', 
    index='sample_id', 
    values='exp_value',
    fill_value=False
).astype(bool)
mutation_table = mutation_table.loc[:, (mutation_table.sum() > 4)].copy()
mutation_cols = set(mutation_table.columns.values)

In [None]:
mutation_table.sum().sort_values()

In [None]:
cp = data.meta.copy()
# cp['Cluster'] = cp['k=8']

In [None]:
subset_joined = final_df.join(cp).join(mutation_table).reset_index()
subset_joined

In [None]:
drug_vals = subset_joined.melt(
    id_vars=['sample_id', 'Cluster', 'NPM1calls', 'FLT3-ITDcalls']+list(mutation_cols),
    value_vars=final_df.columns.values,
    var_name='inhibitor', 
    value_name='auc'
)

drug_vals.head()

In [None]:
mutation_drug_df = final_df.join(mutation_table).join(cp)
mutation_drug_df.head()

In [None]:
drug_vals.dropna(subset=['Cluster'], inplace=True)

In [None]:
def create_groupby_plot_horizontal(drug_name):
    """Draw per-subtype KDEs of one drug's AUC along the y axis, save a PNG, and show it.

    Reads ``data.auc_table[[drug_name, 'Cluster']]`` (rows with NaN dropped). Writes
    ``figures/drug_response_by_cluster/{drug_name}_kde.png`` unconditionally.
    """
    auc_by_cluster = data.auc_table[[drug_name, 'Cluster']].dropna().copy()

    kde_options = dict(
        y=drug_name,
        hue='Cluster',
        clip=(1, 300),
        bw_adjust=1.3,
        palette=cluster_colors,
        warn_singular=False,
        common_norm=True,
        common_grid=True,
        gridsize=100,
        legend=False,
    )
    plt.figure(figsize=(1.5, 3))
    ax = sns.kdeplot(data=auc_by_cluster, **kde_options)

    # Put the AUC axis (label and ticks) on the right-hand side.
    ax.yaxis.set_label_position("right")
    ax.yaxis.tick_right()
    ax.legend([], frameon=False)

    plt.ylabel('AUC')
    plt.ylim(-20, 400)
    plt.savefig(f'figures/drug_response_by_cluster/{drug_name}_kde.png',
                bbox_inches='tight', dpi=300)
    plt.show()

for kde_drug in ('Venetoclax', 'Panobinostat', 'Sorafenib',
                 'Venetoclax - Panobinostat', 'Sorafenib - Panobinostat'):
    create_groupby_plot_horizontal(kde_drug)

## Find drugs based on subtype response

In [None]:
import scikit_posthocs as sp

# Every unordered pair of subtype labels 1..max(Cluster).
combos = list(combinations(range(1, int(drug_vals.Cluster.max()) + 1), 2))

def get_flat_row(matrix_dict, pairs=None):
    """Flatten a nested p-value mapping into ``{'i_j': p}`` for each cluster pair.

    Parameters
    ----------
    matrix_dict : dict
        Nested mapping ``{col: {row: p}}``, as produced by ``DataFrame.to_dict()`` on a
        posthoc p-value matrix.
    pairs : iterable of (i, j), optional
        Cluster pairs to extract. Defaults to the notebook-level ``combos``.

    Returns
    -------
    dict
        Keys ``f'{i}_{j}'``. Pairs missing from ``matrix_dict`` (e.g. a subtype with no
        samples for this drug) map to NaN instead of raising ``KeyError``.
    """
    if pairs is None:
        pairs = combos
    return {f'{i}_{j}': matrix_dict.get(i, {}).get(j, np.nan) for i, j in pairs}


cluster_drug_response = {}
sig_diff_drug_clusters = []
for i, d in drug_vals.groupby('inhibitor'):
    d = d.dropna(subset=['auc']).copy()    
    r = sp.posthoc_ttest(
        d, 
        val_col='auc',
        equal_var = False, #False
        pool_sd = False, #True
        group_col='Cluster', 
        p_adjust = 'fdr_bh', #fdr_bh
        sort=True
    )
    cluster_drug_response[i] = get_flat_row(r.to_dict())
    if (r <= .01).sum().max()>2:
        print(i)
        print(r)
        sig_diff_drug_clusters.append(i)

In [None]:
cluster_drug_response

In [None]:
cluster_pvalues = pd.DataFrame(cluster_drug_response).T
# cluster_pvalues.to_csv('drug_response_cluster_pvalue.csv', )

In [None]:
def plot_subtype_mutation(drug_name, use_prior_stats=False, save=False, hue='FLT3-ITDcalls'):
    """Split violin plot of one drug's AUC per subtype, split by a mutation call.

    Two layers of Welch t-test annotations are drawn, both Benjamini-Hochberg corrected:
    inside the axes, mutant vs. wild-type within each subtype; outside the axes,
    subtype vs. subtype.

    Parameters
    ----------
    drug_name : str
        AUC column in ``mutation_drug_df``.
    use_prior_stats : bool
        If True, annotate only the subtype pairs whose p-value in ``cluster_drug_response``
        is <= 0.05. Otherwise annotate every subtype pair.
    save : bool
        Write PNG and PDF copies under ``figures/drug_response_by_cluster/``.
    hue : str
        Column of ``mutation_drug_df`` used to split the violins. The within-subtype
        pairs are built with hue levels ``True``/``False``.
    """
    subset = mutation_drug_df[[drug_name, 'Cluster', hue]].dropna().copy()
    cluster_size = int(mutation_drug_df.Cluster.max() + 1)
    order = list(range(1, cluster_size))

    plt.figure(figsize=(5, 3))
    ax = plt.subplot(111)

    ax = sns.violinplot(
        data=subset,
        x='Cluster',
        y=drug_name,
        split=True,
        hue=hue,
        palette=['grey', 'darkred'],
        cut=0,
        label=False,
        ax=ax
    )

    # Within each subtype: mutant vs. wild-type (drawn inside the axes).
    pairs2 = [[(i, True), (i, False)] for i in range(1, cluster_size)]

    annotator = Annotator(
        ax,
        pairs2,
        data=subset,
        x='Cluster',
        hue=hue,
        y=drug_name,
        order=order,
        verbose=0
    )
    annotator.configure(
        test='t-test_welch',
        comparisons_correction="BH",
        correction_format="replace",
        text_format='star',
        loc='inside',
        line_offset=.05,
        line_height=.01,
        text_offset=0,
        line_offset_to_group=.5,
        use_fixed_offset=True,
        verbose=0
    )
    annotator.apply_and_annotate()

    # Between subtypes (drawn outside the axes).
    if use_prior_stats:
        pairs = []
        for pval in cluster_drug_response[drug_name]:
            if cluster_drug_response[drug_name][pval] <= 0.05:
                n, m = pval.split('_')
                pairs.append((int(n), int(m)))
    else:
        pairs = list(combinations(range(1, cluster_size), 2))

    if len(pairs):
        annotator = Annotator(
            ax,
            pairs,
            data=subset,
            x='Cluster',
            y=drug_name,
            order=order,
            verbose=0
        )
        annotator.configure(
            test='t-test_welch',
            comparisons_correction="BH",
            correction_format="replace",
            loc='outside',
            line_offset=.05,
            line_height=.01,
            text_offset=0,
            line_offset_to_group=.5,
            use_fixed_offset=False,
            verbose=0
        )
        annotator.apply_and_annotate()

    p = sns.stripplot(
        data=subset,
        order=order,
        x='Cluster',
        y=drug_name,
        dodge=True,
        color='black',
        edgecolor='black',
        hue=hue,
        palette=['black', 'red'],
        alpha=.7,
        size=8,
        label=False
    )

    legend = p.legend_
    # Legend.legendHandles was renamed legend_handles (Matplotlib 3.7) and the old
    # name was later removed; support both.
    handles = getattr(legend, 'legend_handles', None)
    if handles is None:
        handles = legend.legendHandles
    labels = [text.get_text() for text in legend.texts]

    plt.title(drug_name)
    plt.ylabel("AUC")
    plt.xlabel("Subtype")
    # Keep only the last two entries (the strip-plot hue levels) and place the legend
    # outside the axes.
    plt.legend(handles[-2:], labels[-2:], title=hue,
               bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)

    if save:
        i = drug_name.replace("/", '')
        plt.savefig(f'figures/drug_response_by_cluster/violin_plot_split_{i}_{hue}.png', dpi=150, bbox_inches='tight')
        plt.savefig(f'figures/drug_response_by_cluster/violin_plot_split_{i}_{hue}.pdf', dpi=150, bbox_inches='tight')
    
    

In [None]:

plot_subtype_mutation(
    'Pelitinib (EKB-569)',
    use_prior_stats=False,
    save=False,
    hue='FLT3-ITDcalls'
)
plt.show()

In [None]:
plot_subtype_mutation(
    'A-674563',
    use_prior_stats=True,
    save=False,
    hue='NPM1calls'
)
plt.show()

In [None]:
plot_subtype_mutation(
    'Sorafenib',
    use_prior_stats=True,
    save=False,
    hue='FLT3-ITDcalls'
)
plt.show()

In [None]:
plot_subtype_mutation('Elesclomol', use_prior_stats=True, save=False)


In [None]:
plot_subtype_mutation(
    'JAK Inhibitor I',
    use_prior_stats=True,
    save=False
)

In [None]:
plot_subtype_mutation('Sorafenib', use_prior_stats=True, save=False, hue='FLT3-ITDcalls')

In [None]:
plot_subtype_mutation(
    'Foretinib (XL880)',
    use_prior_stats=True,
    save=False
)

In [None]:
# for i in list(good_drugs):
#     plot_subtype_mutation(
#         i, use_prior_stats=True, save=True, 
#         hue='NPM1calls'
#     )
#     plt.close()

In [None]:
# for i in list(good_drugs):
#     plot_subtype_mutation(i, use_prior_stats=True, save=True)
#     plt.close()

In [None]:
for i in sig_diff_drug_clusters:
    plot_subtype_mutation(i, use_prior_stats=True, save=False, hue='NPM1calls')
    plt.close()

In [None]:
def plot_subtype(drug_name, use_prior_stats=False, save=False, apply_annot=True):
    """Violin + strip plot of one drug's AUC per subtype with pairwise Welch t-tests.

    Parameters
    ----------
    drug_name : str
        AUC column. It is looked up in ``mutation_drug_df``. Drugs not carried there
        (e.g. the ' - ' combinations excluded from ``drug_names``) are read from
        ``data.auc_table``, which also has a ``'Cluster'`` column.
    use_prior_stats : bool
        If True, annotate only the subtype pairs with p <= 0.05 in
        ``cluster_drug_response``; there are none if the drug was not tested there.
        Otherwise annotate every subtype pair.
    save : bool
        Write PNG and PDF copies under ``figures/drug_response_by_cluster/``.
    apply_annot : bool
        Draw the statannotations brackets (uncorrected Welch t-tests).
    """
    source = mutation_drug_df if drug_name in mutation_drug_df.columns else data.auc_table
    subset = source[[drug_name, 'Cluster']].dropna().copy()

    cluster_size = int(subset.Cluster.max() + 1)

    plt.figure(figsize=(4, 3))
    ax = plt.subplot(111)
    sns.violinplot(
        data=subset,
        x='Cluster',
        y=drug_name,
        palette=cluster_colors,
        cut=0,
        label=None,
        ax=ax
    )

    if use_prior_stats:
        pairs = []
        for key, pval in cluster_drug_response.get(drug_name, {}).items():
            if pval <= 0.05:
                n, m = key.split('_')
                pairs.append((int(n), int(m)))
    else:
        pairs = list(combinations(range(1, cluster_size), 2))

    # The annotation and strip layers use string subtype labels.
    subset['Cluster'] = subset['Cluster'].astype(int).astype(str)
    pairs = [(str(i), str(j)) for i, j in pairs]
    order = [str(i) for i in range(1, cluster_size)]

    # Skip annotation when there is nothing to compare (e.g. no significant prior pairs).
    if apply_annot and pairs:
        annotator = Annotator(
            ax,
            pairs,
            data=subset,
            x='Cluster',
            y=drug_name,
            order=order,
            verbose=0
        )
        annotator.configure(
            test='t-test_welch',
            comparisons_correction=None,
            correction_format="replace",
            text_format='star',
            loc='inside',
            line_offset=.01,
            line_height=.01,
            text_offset=0.01,
            line_offset_to_group=0.1,
            verbose=0
        )
        annotator.apply_and_annotate()

    sns.stripplot(
        data=subset,
        order=order,
        x='Cluster',
        y=drug_name,
        color='black',
        alpha=.7,
        label=None,
        ax=ax
    )

    ax.set_yticks(range(0, 351, 50))
    # set_yticks applies text kwargs (e.g. fontsize) only when labels are passed, so
    # size both axes' tick labels via tick_params instead.
    ax.tick_params(axis='both', labelsize=12)

    ax.set_title(drug_name, fontsize=16)
    ax.set_ylabel("AUC", fontsize=14)
    ax.set_xlabel('Subtype', fontsize=14)

    if save:
        i = drug_name.replace("/", '')
        plt.savefig(f'figures/drug_response_by_cluster/violin_plot_{i}.png', dpi=300, bbox_inches='tight')
        plt.savefig(f'figures/drug_response_by_cluster/violin_plot_{i}.pdf', dpi=300, bbox_inches='tight')
        


In [None]:
plot_subtype('Cabozantinib', use_prior_stats=True, save=False)
plt.show()
plot_subtype('Elesclomol', True, save=False)
plt.show()

In [None]:
plot_subtype('Venetoclax', True, save=False)
plt.show()
plot_subtype('Sorafenib', True, save=False)
plt.show()

In [None]:
plot_subtype('NF-kB Activation Inhibitor', use_prior_stats=False,)

In [None]:
plot_subtype('Venetoclax - Panobinostat', use_prior_stats=True, save=False, apply_annot=False)

plot_subtype('Sorafenib - Panobinostat', use_prior_stats=True, save=False, apply_annot=False)


In [None]:
for i in sig_diff_drug_clusters:
    plot_subtype(i, use_prior_stats=True, save=False)
#     plt.close()
    plt.show()

In [None]:
def calc_ttest(drug_name, mutation, alternative, df=None):
    """Welch's t-test of a drug's AUC between mutant and wild-type samples.

    Parameters
    ----------
    drug_name : str
        Column holding the drug's AUC values.
    mutation : str
        Column holding the mutation call. Values are cast to bool (truthy = mutant).
    alternative : {'two-sided', 'less', 'greater'}
        Passed to ``scipy.stats.ttest_ind``. The first sample is the mutant group.
    df : DataFrame, optional
        Source table. Defaults to the notebook-level ``mutation_drug_df``.

    Returns
    -------
    (stat, pvalue)
        t statistic and p-value. Both are NaN if a group has too few samples.
    """
    if df is None:
        df = mutation_drug_df
    subset = df[[mutation, drug_name]].dropna(subset=[mutation, drug_name])
    is_mutant = subset[mutation].astype(bool).to_numpy()
    auc = subset[drug_name].to_numpy()
    # permutations=0 was the no-permutation default anyway; the keyword is deprecated
    # in recent SciPy, so it is not passed.
    stat, pvalue = ttest_ind(
        auc[is_mutant],
        auc[~is_mutant],
        equal_var=False,
        alternative=alternative
    )
    return stat, pvalue

In [None]:
def test_drug_by_subtype(drug_name, alternative):
#     subset = drug_vals.loc[drug_vals.inhibitor == drug_name].copy()
#     subset.dropna(subset=['auc'], inplace=True)
    subset = mutation_drug_df[['Cluster', drug_name]].dropna(subset=['Cluster', drug_name])
    results = []
    for i in range(1, int(subset.Cluster.max()+1)):
        true_mut = subset.loc[subset.Cluster.isin([i])][drug_name].values
        false_mut = subset.loc[~subset.Cluster.isin([i])][drug_name].values
        stat, pvalue = ttest_ind(
            true_mut,
            false_mut, 
            equal_var=False, 
            permutations=0,
            alternative=alternative
        )
        results.append([drug_name, stat, pvalue, f'{i}_vs_rest'])
    return results

In [None]:
test_drug_by_subtype('Venetoclax', alternative='two-sided')

In [None]:
def plot_mutation_difference(drug_name, mutation, save=False, alternative='two-sided'):
    """Violin + strip plot of one drug's AUC in wild-type vs. mutant samples.

    Parameters
    ----------
    drug_name : str
        AUC column in ``mutation_drug_df``.
    mutation : str
        Mutation-call column in ``mutation_drug_df``. Values are cast to bool.
    save : bool, default False
        If True, write PNG and PDF copies to
        ``figures/drug_response_by_cluster/violin_plot_{mutation}_{drug}.{png,pdf}``.
    alternative : {'two-sided', 'less', 'greater'}
        Alternative for the returned Welch t-test (via ``calc_ttest``). The p-value drawn
        on the plot comes from statannotations' 't-test_welch' and does not receive it.

    Returns
    -------
    list
        ``[drug_name, t statistic, p-value, mutation]``.
    """
    subset = mutation_drug_df[[drug_name, mutation]].dropna().copy()
    # The categorical axis uses string labels matching `order`.
    subset[mutation] = subset[mutation].astype(bool).astype(str)

    stat, pvalue = calc_ttest(drug_name, mutation, alternative)

    order = ["False", "True"]
    pairs = [("False", "True")]

    fig = plt.figure(figsize=(2, 2))
    ax = fig.add_subplot(111)

    ax = sns.violinplot(
        data=subset,
        x=mutation,
        y=drug_name,
        cut=.5,
        order=order,
        ax=ax
    )

    annotator = Annotator(
        ax,
        pairs,
        data=subset,
        x=mutation,
        y=drug_name,
        order=order,
        verbose=0
    )
    annotator.configure(
        test='t-test_welch',
        comparisons_correction="fdr_bh",
        correction_format="replace",
        loc='inside',
        line_offset=.4,
        line_height=.15,
        text_offset=1,
        line_offset_to_group=0.501,
        verbose=0
    )
    annotator.apply_and_annotate()

    sns.stripplot(
        data=subset,
        x=mutation,
        y=drug_name,
        ax=ax,
        color='black',
        edgecolor='black',
        alpha=.7,
        order=order,
        size=3,
    )
    plt.title(drug_name)
    if save:
        i = drug_name.replace("/", '')
        plt.savefig(f'figures/drug_response_by_cluster/violin_plot_{mutation}_{i}.png', dpi=300, bbox_inches='tight')
        plt.savefig(f'figures/drug_response_by_cluster/violin_plot_{mutation}_{i}.pdf', dpi=300, bbox_inches='tight')
    return [drug_name, stat, pvalue, mutation]

# make_plot('A-674563', 'FLT3-ITDcalls', False)
# plt.show()
# make_plot('Sorafenib', 'NPM1calls', False)
# #plt.savefig('sorafenib_violoet_npm1calls.pdf', bbox_inches='tight')
# plt.show()
plot_mutation_difference('Venetoclax', 'SF3B1_mut', False)
plt.show()
plot_mutation_difference('Panobinostat', 'SF3B1_mut', False)
plt.show()
plot_mutation_difference('CYT387', 'SF3B1_mut', False)
plt.show()
# make_plot('Venetoclax', 'NPM1calls', False)
# make_plot('Foretinib (XL880)', None, True)

In [None]:
plot_mutation_difference('Sorafenib', 'NRAS_mut', False)

In [None]:
plot_mutation_difference('Elesclomol', 'FLT3-ITDcalls', False)

In [None]:
test_drug_by_subtype('Sorafenib', alternative='two-sided')

In [None]:
def compare_mutations(drug_name, mutation_1, mutation_2, save=False):
    """Violin plot of a drug's AUC across the four co-mutation groups of two calls.

    Samples are assigned to ``double_mutant``, ``double_wt``, ``<mutation_1>_only`` or
    ``<mutation_2>_only``. Every pair of groups is compared with Welch's t-test
    (``fdr_bh`` corrected) via statannotations. If any group is empty, a message is
    printed and nothing is drawn.

    Parameters
    ----------
    drug_name : str
        AUC column in ``mutation_drug_df``.
    mutation_1, mutation_2 : str
        Mutation-call columns in ``mutation_drug_df``. Values are cast to bool.
    save : bool
        Write PNG and PDF copies under ``figures/drug_response_by_cluster/``.
    """
    name_1 = f'{mutation_1}_only'
    name_2 = f'{mutation_2}_only'

    subset = mutation_drug_df[[drug_name, mutation_1, mutation_2]].dropna(
        subset=[drug_name, mutation_1, mutation_2]).copy()
    # Columns that held NaN before dropna come through as object dtype; `~` on those
    # Python bools is bitwise NOT (-1/-2), not logical negation, so cast to bool first.
    mut_1 = subset[mutation_1].astype(bool)
    mut_2 = subset[mutation_2].astype(bool)
    subset['group'] = np.select(
        [mut_1 & mut_2, ~mut_1 & ~mut_2, mut_1 & ~mut_2],
        ['double_mutant', 'double_wt', name_1],
        default=name_2,
    )

    order = ["double_mutant", "double_wt", name_1, name_2]
    present = set(subset['group'])
    missing = [g for g in order if g not in present]
    if missing:
        print(f'{drug_name}: not plotted, no samples in {missing}')
        return

    pairs = list(combinations(order, 2))

    fig = plt.figure(figsize=(6, 5))
    ax = fig.add_subplot(111)

    ax = sns.violinplot(
        data=subset,
        x='group',
        y=drug_name,
        cut=.5,
        order=order,
        ax=ax
    )

    annotator = Annotator(
        ax,
        pairs,
        data=subset,
        x='group',
        y=drug_name,
        order=order,
        verbose=0
    )
    annotator.configure(
        test='t-test_welch',
        comparisons_correction="fdr_bh",
        correction_format="replace",
        loc='inside',
        line_offset=.14,
        line_height=.15,
        text_offset=1,
        line_offset_to_group=0.501,
        verbose=0
    )
    annotator.apply_and_annotate()

    sns.stripplot(
        data=subset,
        x='group',
        y=drug_name,
        ax=ax,
        color='black',
        edgecolor='black',
        alpha=.7,
        order=order,
        size=3,
    )
    plt.title(drug_name)
    if save:
        i = drug_name.replace("/", '')
        plt.savefig(f'figures/drug_response_by_cluster/violin_plot_double_mutations_{i}.png', dpi=300, bbox_inches='tight')
        plt.savefig(f'figures/drug_response_by_cluster/violin_plot_double_mutations_{i}.pdf', dpi=300, bbox_inches='tight')

compare_mutations('Sorafenib', 'NPM1calls', 'FLT3-ITDcalls',  save=True)
plt.show()

In [None]:
mutation_table.sum().sort_values()

In [None]:
compare_mutations('Sorafenib', 'NPM1_mut', 'WT1_mut',  save=False)
plt.show()

In [None]:
mut_drug_diff = []
for mutation in list(mutation_cols) + ['FLT3-ITDcalls', 'NPM1calls']:
    for drug_name in drug_vals.inhibitor.unique():
        stat, pvalue = calc_ttest(drug_name, mutation, alternative='two-sided')
        mut_drug_diff.append([drug_name, mutation, stat, pvalue])

mut_drug_diff = pd.DataFrame(mut_drug_diff, columns=['drug', 'mut', 't_stat', 'p_value'])
# Untestable pairs give NaN p-values; count them as p = 1 so they cannot propagate NaN
# through the Benjamini-Hochberg step-up.
mut_drug_diff['p_value'] = mut_drug_diff['p_value'].fillna(1)
mut_drug_diff['fdr_bh_sig'], mut_drug_diff['fdr_bh'] = fdrcorrection(mut_drug_diff.p_value, alpha=0.05)
mut_drug_pivot = mut_drug_diff.pivot(columns='mut', values=['fdr_bh', 't_stat'], index='drug')
# Zero out NaN and infinite t statistics of either sign (the previous version only
# replaced +inf).
mut_drug_pivot = mut_drug_pivot.replace([np.nan, np.inf, -np.inf], 0)

In [None]:
# Number of drugs with a significant (BH < 0.05) difference, per mutation. This counts
# directly instead of masking the MultiIndex frame with a flat-column boolean frame,
# which relies on implicit level alignment.
drug_diff_hits = (mut_drug_pivot['fdr_bh'] < 0.05).sum()
drug_diff_hits = drug_diff_hits[drug_diff_hits > 0]
drug_diff_hits.sort_values()

In [None]:
annot = make_annotations(mut_drug_pivot)

In [None]:
sns.clustermap(
    mut_drug_pivot.t_stat,
    annot=annot,
    fmt='s',
    row_cluster=True,
    col_cluster=True,
    yticklabels=True,
    figsize=(10, 10),
#     metric='correlation',
#     method ='ward',
    cmap='coolwarm',
    linecolor='black',
    linewidths=0.003,
    cbar_kws=dict( label='$t$ statistic', use_gridspec=False,)
);
# plt.savefig('mutational_caused_drug_differences.png', bbox_inches='tight', dpi=200)
# plt.savefig('mutational_caused_drug_differences.pdf', bbox_inches='tight', dpi=200)

In [None]:
subset_drug_hits = mut_drug_pivot.fdr_bh[drug_diff_hits.index]
annotations = subset_drug_hits.copy()
annot = subset_drug_hits<0.05
annotations[~annot] = ''
annotations[annot] = '+'
subset_drug_t_stat = mut_drug_pivot.t_stat[drug_diff_hits.index]

In [None]:
sns.clustermap(
    subset_drug_t_stat,
    annot=annotations,
    fmt='s',
    row_cluster=True,
    col_cluster=True,
    yticklabels=True,
    figsize=(5, 10),
    metric='correlation',
#     method ='ward',
    cmap='coolwarm',
    linecolor='black',
    linewidths=0.003,
    cbar_kws=dict( label='$t$ stat', use_gridspec=False,)
)
# plt.savefig('mutational_caused_drug_differences_focused.png', bbox_inches='tight', dpi=200)
# plt.savefig('mutational_caused_drug_differences_focused.pdf', bbox_inches='tight', dpi=200)

In [None]:
d_hits = (subset_drug_hits<0.05).sum(axis=1)
d_hits = d_hits[d_hits>1]

In [None]:
plot_mutation_difference('Sorafenib', 'FLT3-ITDcalls');
plt.show()

In [None]:
plot_mutation_difference('Midostaurin', 'GATA2_mut');
plt.show()
plot_mutation_difference('JNJ-28312141', 'GATA2_mut');
plt.show()
plot_mutation_difference('Selinexor', 'GATA2_mut');
plt.show()

In [None]:
plot_mutation_difference('Selumetinib (AZD6244)', 'NRAS_mut');
plt.show()
plot_mutation_difference('Tivozanib (AV-951)', 'NRAS_mut');
plt.show()
plot_mutation_difference('PD173955', 'NRAS_mut');
plt.show()


In [None]:
plot_mutation_difference('Dasatinib', 'RAD21_mut');
plt.show()
plot_mutation_difference('Vargetef', 'TP53_mut');
plt.show()

In [None]:
plot_mutation_difference('Panobinostat', 'SF3B1_mut')
plt.savefig('pano_sf3b1_split.pdf', bbox_inches='tight')
plt.show()
plot_mutation_difference('Venetoclax', 'SF3B1_mut')
plt.savefig('ven_sf3b1_split.pdf', bbox_inches='tight')
plt.show()

In [None]:
def gather_tests(alternative='two-sided'):
    """Run subtype-vs-rest Welch t-tests for every inhibitor and BH-correct them jointly.

    Parameters
    ----------
    alternative : {'two-sided', 'less', 'greater'}
        Passed to ``test_drug_by_subtype``.

    Returns
    -------
    (output_pivot, output)
        ``output`` is the long table (Drug, t_stat, p_value, comparison, fdr_bh_sig,
        fdr_bh). ``output_pivot`` is indexed by Drug with ('fdr_bh' | 't_stat', comparison)
        columns.
    """
    rows = []
    for inhibitor in drug_vals.inhibitor.unique():
        rows += test_drug_by_subtype(inhibitor, alternative=alternative)
    output = pd.DataFrame(rows, columns=['Drug', 't_stat', 'p_value', 'comparison'])
    # A NaN p-value (e.g. a subtype with too few samples) would otherwise propagate
    # through the BH step-up and blank out the adjusted values. Treat it as p = 1, as
    # the mutation-vs-drug table does; the raw p_value column keeps the NaN.
    output['fdr_bh_sig'], output['fdr_bh'] = fdrcorrection(output['p_value'].fillna(1))
    output_pivot = output.pivot(columns='comparison', values=['fdr_bh', 't_stat'], index='Drug')
    return output_pivot, output

def make_annotations(df, alpha=0.05):
    """Return a string mask for heatmap annotation: '+' where ``df['fdr_bh'] < alpha``, ' ' elsewhere.

    Parameters
    ----------
    df : DataFrame
        Frame with a top-level ``'fdr_bh'`` column group, e.g. the result of
        ``DataFrame.pivot(values=['fdr_bh', 't_stat'], ...)``.
    alpha : float, default 0.05
        Threshold on the BH-adjusted p-values. NaN counts as not significant.

    Returns
    -------
    DataFrame of str
        Same index and columns as ``df['fdr_bh']``.
    """
    fdr = df['fdr_bh'].astype(float)
    # Build the string frame directly rather than writing ' '/'+' into a boolean frame,
    # which relies on implicit upcasting that recent pandas deprecates.
    return pd.DataFrame(np.where(fdr < alpha, '+', ' '), index=fdr.index, columns=fdr.columns)

In [None]:
plot_subtype_mutation('Panobinostat', 
        use_prior_stats=True,
        save=False,
        hue='WT1_mut')

In [None]:
drug_substype_gt, subtype_pvals_gt = gather_tests('greater')
drug_substype_lt, subtype_pvals_lt = gather_tests('less')
drug_substype_ts, subtype_pvals_ts = gather_tests('two-sided')

In [None]:


annot = make_annotations(drug_substype_ts)
sns.clustermap(
    drug_substype_ts.t_stat,
    fmt='s',
    annot=annot,
    row_cluster=True,
    col_cluster=False,
    yticklabels=True,
    figsize=(5, 10),
#     metric='correlation',
    method ='ward',
    cmap='coolwarm',
    linecolor='black',
    linewidths=0.003,
    cbar_kws=dict( label='$t$ stat', use_gridspec=False,),
);

In [None]:


annot = make_annotations(drug_substype_lt)
sns.clustermap(
    drug_substype_lt.t_stat,
    fmt='s',
    annot=annot,
    row_cluster=True,
    col_cluster=False,
    yticklabels=True,
    figsize=(5, 10),
    metric='correlation',
#     method ='ward',
    cmap='coolwarm',
    linecolor='black',
    linewidths=0.003,
    cbar_kws=dict( label='$t$ statistic', use_gridspec=False,),
);
plt.close()

In [None]:
mut_t_test = []

for mutation in ['FLT3-ITDcalls', 'NPM1calls']:
    for drug_name in drug_vals.inhibitor.unique():
        stat, pvalue = calc_ttest(drug_name, mutation, alternative='two-sided')
        mut_t_test.append([drug_name, stat, pvalue, mutation])
mut_t_test = pd.DataFrame(mut_t_test, columns=['Drug', 't_stat', 'p_value', 'comparison'])
# Same NaN guard as the all-mutation table: an untestable drug (NaN p-value) would
# otherwise propagate NaN through the BH step-up. The raw p_value column keeps the NaN.
mut_t_test['fdr_bh_sig'], mut_t_test['fdr_bh'] = fdrcorrection(mut_t_test['p_value'].fillna(1), alpha=0.05)
mut_t_test[mut_t_test.fdr_bh_sig].sort_values('Drug').head(50)

In [None]:
npm1 = mut_t_test[mut_t_test.comparison=='NPM1calls'].sort_values('p_value')
npm1.head(10)

In [None]:
for i in npm1[npm1.fdr_bh<0.05].Drug:
#     plot_subtype_mutation(i, use_prior_stats=False, save=False, hue='NPM1calls')
    
    plot_mutation_difference(i, mutation='NPM1calls')
    plt.show()

In [None]:
plot_subtype_mutation('Gilteritinib', use_prior_stats=True, save=False)
plt.show()
    

In [None]:
mut_t_test_pivot = mut_t_test.pivot(
    index='Drug', columns='comparison', values='fdr_bh'
)
# Keep the index: it holds the drug names, which index=False dropped from the CSV.
mut_t_test_pivot.to_csv('mutant_drug_response.csv')
# mut_t_test_pivot.sort_values('NPM1calls')

In [None]:
# `subtype_pvals` was never defined; use the two-sided subtype-vs-rest table, which
# matches the two-sided mutation tests in mut_t_test / mut_drug_diff.
subtype_pvals_pivot = subtype_pvals_ts.pivot(
    index='Drug',
    columns='comparison',
    values=['fdr_bh', 't_stat']
)

In [None]:
# mut_t_test_pivot has flat columns (adjusted p-values). Join only the fdr_bh level of the
# subtype table so both sides have one column level and every column is an adjusted
# p-value; the cells below take -log10 of all columns.
mutation_and_subgraph = mut_t_test_pivot.join(subtype_pvals_pivot['fdr_bh'])
mutation_and_subgraph.head()

In [None]:
# Mutation tests and subtype-vs-rest tests side by side, one frame per statistic.
joined_fdr, joined_t_stat = (
    mut_drug_pivot[level].join(subtype_pvals_pivot[level]) for level in ('fdr_bh', 't_stat')
)

In [None]:
sns.set_theme(
    context ='paper', 
    palette="Paired", 
    style='white',
    font_scale=1.0
)


# '+' where the BH-adjusted p-value < 0.05. The strings are built directly instead of
# writing ''/'+' into the float frame (implicit upcasting that recent pandas deprecates).
annot = joined_fdr < 0.05
annotations = pd.DataFrame(
    np.where(annot, '+', ''),
    index=joined_fdr.index,
    columns=joined_fdr.columns,
)

sns.clustermap(
    joined_t_stat,
    yticklabels=True,
    linewidths=0.01,
    col_cluster=False,
    row_cluster=True,
    annot=annotations, 
    fmt='s',
    method='ward',
    cmap='coolwarm',
    figsize=(12, 12)
);

In [None]:
drug_diff_hits = joined_fdr[joined_fdr<0.05].count()
drug_diff_hits = drug_diff_hits[drug_diff_hits>0]
drug_diff_hits

In [None]:
all_hits = joined_fdr[drug_diff_hits.index]
annotations = all_hits.copy()
annot = all_hits<0.05
annotations[~annot] = ''
annotations[annot] = 'x'
all_hits_t_stat = joined_t_stat[drug_diff_hits.index]

In [None]:
all_hits_t_stat

In [None]:

# t statistics for comparisons with at least one hit; 'x' marks BH-adjusted p < 0.05.
sns.clustermap(
    all_hits_t_stat,
    annot=annotations,
    fmt='s',
    metric='mahalanobis',
    row_cluster=True,
    col_cluster=False,
    yticklabels=True,
    figsize=(6, 8),
    cmap='coolwarm',
    linecolor='black',
    linewidths=0.003,
    cbar_kws={'label': '$t$ statistic', 'use_gridspec': False},
);
plt.savefig('drug_response_by_mutation_or_subtype.pdf', bbox_inches='tight')

In [None]:
# '+' where the adjusted p-value < 0.05. Anything below .01 is also below .05, so one
# threshold is enough; the strings are built directly instead of written into a bool frame.
annot = pd.DataFrame(
    np.where(mutation_and_subgraph.astype(float) < .05, '+', ' '),
    index=mutation_and_subgraph.index,
    columns=mutation_and_subgraph.columns,
)
sns.clustermap(
    -1*mutation_and_subgraph.apply(np.log10),
    yticklabels=True,
    col_cluster=False,
    row_cluster=True,
    annot=annot, fmt='s',
    method='ward',
    figsize=(4, 8),
    linecolor='black',
    linewidths=0.003,
    cmap='coolwarm',
    cbar_kws=dict( label='$-log_{10}$ p-value', use_gridspec=False,),
);

In [None]:

flt3_npm1_cols = ['FLT3-ITDcalls', 'NPM1calls']
grid = sns.clustermap(
    -np.log10(mutation_and_subgraph[flt3_npm1_cols]),
    annot=annot[flt3_npm1_cols],
    fmt='s',
    method='ward',
    row_cluster=True,
    col_cluster=False,
    yticklabels=True,
    linewidths=0.005,
    cmap='rocket_r',
    figsize=(3, 6),
    cbar_pos=(-.10, 0.65, 0.05, 0.1),
);
g = grid

heat_ax = grid.ax_heatmap
heat_ax.set(xlabel=None)
heat_ax.set_xticklabels(['FLT3-ITD', 'NPM1'], rotation=45)
heat_ax.set_yticklabels(heat_ax.get_yticklabels(), fontsize=6)
plt.savefig('figures/drug_response_by_cluster/drug_response_mutations_diff.png', dpi=300, bbox_inches='tight')
plt.savefig('figures/drug_response_by_cluster/drug_response_mutations_diff.pdf', bbox_inches='tight')
plt.close()

In [None]:
flt3_only = mutation_and_subgraph[['FLT3-ITDcalls']].sort_values('FLT3-ITDcalls')
# '+' marks drugs with FLT3-ITD adjusted p < 0.01. The strings are built directly
# instead of written into a boolean frame.
annot = pd.DataFrame(
    np.where(flt3_only.astype(float) < .01, '+', ' '),
    index=flt3_only.index,
    columns=flt3_only.columns,
)

g = sns.clustermap(
    -1*flt3_only.apply(np.log10),
    yticklabels=True,
    linewidths=0.005,
    col_cluster=False,
    row_cluster=False,
    annot=annot,
    fmt='s',
    method='ward',
    cmap='rocket_r',
    figsize=(3, 8),
    cbar_pos=(-.10, 0.65, 0.05, 0.1),
);

g.ax_heatmap.set(xlabel=None)
g.ax_heatmap.set_xticklabels(['FLT3-ITD'], rotation = 45)
g.ax_heatmap.set_yticklabels( g.ax_heatmap.get_yticklabels(), fontsize=6)
plt.savefig('figures/drug_response_by_cluster/drug_response_flt3_diff.png', dpi=300, bbox_inches='tight')
# Same stem as the PNG (the PDF was previously written to a misspelled '..._flt3s_diff.pdf').
plt.savefig('figures/drug_response_by_cluster/drug_response_flt3_diff.pdf', bbox_inches='tight')
plt.close()

In [None]:
[i for i in mutation_drug_df.columns if i.startswith('Da')]

In [None]:
    
plot_subtype_mutation('Venetoclax', 
        use_prior_stats=True,
        save=False,
        hue='SF3B1_mut')

In [None]:

plot_subtype_mutation('Selumetinib (AZD6244)', 
        use_prior_stats=False,
        save=False,
        hue='NRAS_mut')

In [None]:
plot_subtype_mutation('Panobinostat', 
        use_prior_stats=True,
        save=False,
        hue='SF3B1_mut')

In [None]:
plot_subtype_mutation('NF-kB Activation Inhibitor', 
        use_prior_stats=True,
        save=False,
        hue='EZH2_mut')

In [None]:
plot_subtype_mutation('Vargetef', 
        use_prior_stats=True,
        save=False,
        hue='TP53_mut')

# 'Vargetef', 'TP53_mut'