# Load data

- Setting up the notebook
- Loading the config file
- Setting up significance thresholds

In [None]:
# %load load_manuscript_data.py
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
from pathlib import Path
import seaborn as sns
import sys
import plotly.express as px
import plotly.io as pio
import yaml

from datetime import date

# Notebook-wide look and feel: seaborn context first, then explicit matplotlib
# overrides, then pandas table display options.
sns.set_context("notebook", font_scale=1.4)

plt.rcParams.update({
    "figure.figsize": (16, 12),
    'savefig.dpi': 200,
    'figure.autolayout': False,
    'axes.labelsize': 18,
    'axes.titlesize': 20,
    'font.size': 16,
    'lines.linewidth': 2.0,
    'lines.markersize': 8,
    'legend.fontsize': 14,
})

# Show up to 100 rows/columns; floats print with thousands separators and 4 decimals.
pd.set_option("display.max_columns", 100)
pd.set_option("display.max_rows", 100)
pd.set_option('display.float_format', '{:,.4f}'.format)


 
# Date stamp (dd-mm-yy) used as a prefix in the figure file names below
today = date.today().strftime('%d-%m-%y')

# All input locations come from a YAML file expected in the working directory.
config_file = "manuscript_config.yaml"
with open(config_file) as file:
    # NOTE(review): FullLoader can build more than plain YAML types; only use
    # this with a trusted config file (yaml.safe_load is the stricter option).
    configs = yaml.load(file, Loader=yaml.FullLoader)

# Base directory for input data and target directory for exported figures
root = Path(configs['root'])
figures_dir = Path(configs['figures_dir'])

# Plotly's qualitative "Alphabet" palette; entry 8 is reused below as 'grey'.
alphabetClrs = px.colors.qualitative.Alphabet

# Earlier palette, currently disabled:
# clrs = ["#f7ba65", "#bf4713", "#9c002f", "#d73d00", "#008080", "#004c4c"]
# colors = {
#         'light_yellow': clrs[0],
#         'darko': clrs[1],
#         'maroon':clrs[2],
#         'brighto': clrs[3],
#         'teal':clrs[4],
#         'darkteal':clrs[5]
#        }

# Named colours shared by the plotly figures in this notebook
sushi_colors = dict(
    red='#C0504D',
    orange='#F79646',
    medSea='#4BACC6',
    black='#000000',
    dgreen='#00B04E',
    lgreen='#92D050',
    dblue='#366092',
    lblue='#95B3D7',
    grey=alphabetClrs[8],
)

# Cut-offs for calling a hit: |LFC| must exceed the first, FDR must be below the second
lfc_threshold = 0.6
fdr_threshold = 0.05


# Nguyen et al. 2020

## Load the data

In [None]:
# Resolve every Nguyen input file relative to the data root.
nguyen_config = configs['nguyen']
(map_file,
 counts_file,
 results_file,
 sample_data_file,
 published_results_file,
 published_phenotypes_file) = (
    root / nguyen_config[key]
    for key in ('map_file', 'counts_file', 'results_file', 'sample_data_file',
                'published_results_file', 'published_phenotypes_file')
)

## Process results file

In [None]:
# Name -> locus_tag annotation pairs taken from the map file (duplicates removed)
annotations = pd.read_csv(map_file).loc[:, ['Name', 'locus_tag']].drop_duplicates()

# Attach locus tags to the mBARq results; rows without an annotation are dropped
results = pd.merge(pd.read_csv(results_file), annotations, on='Name', how='inner')

# A hit needs a significant FDR in either selection direction AND |LFC| above the threshold
passes_fdr = ((results['neg_selection_fdr'] < fdr_threshold)
              | (results['pos_selection_fdr'] < fdr_threshold))
passes_lfc = results['LFC'].abs() > lfc_threshold
results['mbarq_hit'] = passes_fdr & passes_lfc


## Process published results

In [None]:
# Published table: the first line of the file is skipped, the next is the header.
published_results = pd.read_csv(published_results_file, skiprows=1)
# Drop rows whose gene name marks a control; na=False keeps rows with a missing
# gene name instead of failing when the mask is inverted.
published_results = published_results[~published_results.gene.str.contains('control_', na=False)]

# Long format of every column containing 'median_CI': one row per gene and contrast.
# var_name is a plain string because recent pandas rejects a list here.
published_ci = (published_results.melt(id_vars=['locus', 'gene'],
                                       value_vars=[c for c in published_results.columns if 'median_CI' in c],
                                       value_name='median_CI',
                                       var_name='contrast')
                                 .rename({'locus': 'locus_tag'}, axis=1))
# The contrast label is the part of the column name before the first underscore
published_ci['contrast'] = published_ci.contrast.str.split("_", expand=True)[0]
published_ci['log2_median_CI'] = np.log2(published_ci.median_CI)

# Same reshaping for the columns containing 'adj_p_value_CI'
published_hits = (published_results.melt(id_vars=['locus'], var_name='contrast',
                                         value_vars=[c for c in published_results.columns if 'adj_p_value_CI' in c],
                                         value_name='adj_pvalue')
                                   .rename({'locus': 'locus_tag'}, axis=1))
published_hits['contrast'] = published_hits.contrast.str.split("_", expand=True)[0]

published_df = published_ci.merge(published_hits, on=['locus_tag', 'contrast'])
# Published hits are encoded as 2 (not 1) so that adding the mBARq flag (0/1)
# gives a distinct code per combination: 0 none, 1 mBARq only, 2 original only, 3 both.
published_df['published_hit'] = ((published_df.adj_pvalue < fdr_threshold)
                                 & (abs(published_df.log2_median_CI) > lfc_threshold)).astype(int)*2


## Merge mbarq and the results from the original study

In [None]:
# Sum of the mBARq flag (0/1) and the published flag (0/2), as a string -> label
ci_labels = {'0': 'None', '1': 'mBARq only', '2': 'Original analysis only', '3': 'Both methods'}

# Outer merge keeps genes seen by only one of the two analyses
all_results = results.merge(published_df, on=['locus_tag', 'contrast'], how='outer')
# Rows without any missing value; taken before the gaps are filled below
common_genes = all_results.dropna()

# Genes missing from one analysis count as non-hits with zero effect in that analysis
all_results['published_hit'] = all_results['published_hit'].fillna(0).astype(int)
# Explicit bool dtype: the outer merge leaves this column as object (bools mixed with NaN)
all_results['mbarq_hit'] = all_results['mbarq_hit'].fillna(False).astype(bool)
all_results['log2_median_CI'] = all_results['log2_median_CI'].fillna(0)
all_results['LFC'] = all_results['LFC'].fillna(0)

all_results['hit'] = (all_results['mbarq_hit'].astype(int) + all_results['published_hit']).astype(str)
# Assign the result back: an in-place replace on `all_results.hit` works on a
# temporary Series and does not reach the DataFrame under pandas Copy-on-Write.
all_results['hit'] = all_results['hit'].replace(ci_labels)

## Compare CIs and LFCs

In [None]:
def compare_CIs(df, contrast, font_size=24, y_label='log2_median_CI'):
    """
    Scatter mBARq LFC (x) against a published effect size (y) for one contrast.

    Points are coloured by the 'hit' column, and the number of rows per hit
    category is printed before plotting.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns 'contrast', 'LFC', 'hit' and `y_label`.
        Colours are defined for the 'hit' values 'None', 'Both methods',
        'mBARq only' and 'Original analysis only'.
    contrast :
        Value of `df.contrast` to plot (e.g. a day or a condition).
    font_size : int
        Axis-title size; tick labels use `font_size - 6`, legend entries `font_size - 2`.
    y_label : str
        Name of the column plotted on the y axis.

    Returns
    -------
    plotly.graph_objects.Figure
    """
    hit_label = 'Significant change in fitness'
    to_plot = df[df.contrast == contrast]
    # Rename so the legend title reads as a sentence rather than 'hit'
    to_plot = to_plot.rename({'hit': hit_label}, axis=1)
    print(to_plot[hit_label].value_counts())
    fig = px.scatter(to_plot, x='LFC', y=y_label, color=hit_label,
                     height=800, width=1000,
                     template='plotly_white',
                     labels={'log2_median_CI': 'CI (original analysis)',
                             'LFC': 'LFC (mBARq analysis)'},
                     color_discrete_map={'None': sushi_colors['grey'],
                                         'Both methods': sushi_colors['dgreen'],
                                         'mBARq only': sushi_colors['dblue'],
                                         'Original analysis only': sushi_colors['orange']},
                     # Keyed by hit_label so the ordering follows the renamed column
                     category_orders={hit_label: ['None', 'Both methods',
                                                  'Original analysis only', 'mBARq only']})

    fig.update_traces(marker=dict(size=20, line=dict(width=1, color='DarkSlateGrey'),
                                  opacity=0.9),
                      selector=dict(mode='markers'))
    # title_font replaces the deprecated `titlefont`, which recent plotly rejects
    axis_style = dict(showline=True, linewidth=2, linecolor='black',
                      tickfont=dict(size=font_size-6, color='black'),
                      title_font=dict(size=font_size, color='black'))
    fig.update_xaxes(**axis_style)
    fig.update_yaxes(**axis_style)

    fig.update_layout(legend=dict(font=dict(size=font_size-2)),
                      legend_title=dict(font=dict(size=font_size)))
    return fig

In [None]:
# Distinct loci called only by mBARq, with LFC below -threshold, on day 1 or 2
only_mbarq = all_results['hit'] == 'mBARq only'
early_days = all_results['contrast'].isin(['d1', 'd2'])
depleted = all_results['LFC'] < -1*lfc_threshold
new_genes = all_results.loc[depleted & only_mbarq & early_days, 'locus_tag'].nunique()
print(f'Number of new genes with fitness defect: {new_genes}')

In [None]:
# Distinct loci called only by mBARq, with LFC above the threshold, on day 1 or 2
only_mbarq = all_results['hit'] == 'mBARq only'
early_days = all_results['contrast'].isin(['d1', 'd2'])
enriched = all_results['LFC'] > lfc_threshold
new_genes = all_results.loc[enriched & only_mbarq & early_days, 'locus_tag'].nunique()
print(f'Number of new genes with increased fitness: {new_genes}')

In [None]:
# Scatter for contrast 'd1' with the default y column (log2_median_CI);
# the bare `fig` on the last line makes the notebook display it.
fig = compare_CIs(all_results, 'd1')
fig

In [None]:
#fig.write_image(figures_dir/f"{today}_Figure2C.png", format ='png', scale=2)

In [None]:
#compare_CIs(all_results, 'd2')
#compare_CIs(all_results, 'd3')
#compare_CIs(all_results, 'd4')

## Calculate CI/LFC correlations

In [None]:
# Correlation between mBARq LFC and published log2 CI within each contrast,
# using only the rows that have no missing value (common_genes).
# .corr() returns a 2x2 matrix per contrast; every second row of the last
# column is the LFC-vs-log2_median_CI entry.
corr_df = common_genes.groupby('contrast')[['LFC', 'log2_median_CI']].corr().iloc[0::2,-1].reset_index()
corr_df.columns = ['contrast', 'LFC', 'R']
font_size=24
fig = px.bar(corr_df, x="R", y='contrast', color='contrast',
      color_discrete_sequence = ['black']*4,
             labels={"contrast":'', 'R': "Pearson's <i>r</i>"},
      height=350, width=500, text_auto='.2f', template='plotly_white', orientation='h')
fig.update_layout(showlegend=False)
# title_font replaces the deprecated `titlefont`, which recent plotly rejects
fig.update_xaxes(showline=True, linewidth=2, linecolor='black',
                 tickfont=dict(size=font_size-6, color='black'),
                 title_font=dict(size=font_size, color='black'))
fig.update_yaxes(showline=True, linewidth=2, linecolor='black',
                 tickfont=dict(size=font_size-6, color='black'),
                 title_font=dict(size=font_size, color='black'))


In [None]:
#fig.write_image(figures_dir/f"{today}_Figure2C_inset.png", format ='png', scale=2)

## Calculate recall, precision, balanced accuracy 

In [None]:
from sklearn.metrics import recall_score, precision_score, balanced_accuracy_score

# Published phenotypes, renamed to this notebook's keys and joined to both analyses' calls
phenotypes = (pd.read_csv(published_phenotypes_file)
              .rename(columns={'locus': 'locus_tag', 'day': 'contrast'})
              .merge(all_results, how='left', on=['locus_tag', 'contrast']))

# Reference call, using the same FDR and effect-size thresholds as the two analyses
significant = phenotypes['adjusted p value (C.I.)'] < fdr_threshold
large_effect = np.log2(phenotypes['median']).abs() > lfc_threshold
phenotypes['pheno_hit'] = (significant & large_effect).astype(int)

# Keep only rows where every call is available, then bring all calls to 0/1
phenotypes = phenotypes[['locus_tag', 'gene_x', 'contrast', 'pheno_hit', 'mbarq_hit', 'published_hit']].dropna()
phenotypes['mbarq_hit'] = phenotypes['mbarq_hit'].astype(int)
phenotypes['published_hit'] = (phenotypes['published_hit'] / 2).astype(int)  # stored as 0/2 upstream

# (precision, recall, balanced accuracy) of each method against the reference call
metrics = {
    method: (precision_score(phenotypes.pheno_hit, calls),
             recall_score(phenotypes.pheno_hit, calls),
             balanced_accuracy_score(phenotypes.pheno_hit, calls))
    for method, calls in (('mBARq Analysis', phenotypes.mbarq_hit),
                          ('Original Analysis', phenotypes.published_hit))
}

# Long format (Method, Metric, Score) for the grouped bar chart
metric_df = pd.DataFrame(metrics, index=['Precision', 'Recall', 'Balanced Accuracy']).T
metric_df = metric_df.reset_index().rename(columns={'index': 'Method'})
metric_df = metric_df.melt(id_vars=['Method'], var_name='Metric', value_name='Score')

In [None]:
# Grouped bars: one group per metric, one bar per method
font_size=24
fig = px.bar(metric_df, x='Metric', y='Score',
       color='Method', barmode='group', text_auto='.2f',
       height=500, width=600,
      template='plotly_white',
      color_discrete_map = {'mBARq Analysis':'black' , 'Original Analysis': sushi_colors['grey']})
fig.update_layout(legend=dict(font=dict(size=font_size-2)), legend_title=dict(font=dict(size=font_size)))
fig.update_xaxes(showline=True, linewidth=2, linecolor='black',
                 tickfont=dict(size=font_size-6, color='black'), title="")
# title_font replaces the deprecated `titlefont`, which recent plotly rejects
fig.update_yaxes(showline=True, linewidth=2, linecolor='black',
                 tickfont=dict(size=font_size-8, color='black'),
                 title_font=dict(size=font_size, color='black'))


In [None]:
#fig.write_image(figures_dir/f"{today}_Figure2D.png", format ='png', scale=2)

# Wetmore et al. 2015

## Load the data

In [None]:
# Resolve every Wetmore input file relative to the data root.
wetmore_config = configs['wetmore']
(wetmore_counts_file,
 wetmore_results_file,
 wetmore_sample_data_file,
 wetmore_published_results_file,
 wetmore_published_counts_file,
 wetmore_map_file,
 wetmore_published_stats_file) = (
    root / wetmore_config[key]
    for key in ('counts_file', 'results_file', 'sample_data_file',
                'published_results_file', 'published_counts_file',
                'map_file', 'published_stats_file')
)


def get_bigger_tstat(x):
    """
    Return the element of `x` with the largest absolute value, sign preserved.

    Missing values are ignored wherever they occur (previously the result was
    NaN only when the missing value happened to come first). On ties the first
    occurrence wins.

    Parameters
    ----------
    x : pandas.Series
        Numeric values, e.g. the t statistics of one group in a groupby/agg.

    Returns
    -------
    The selected element of `x.values`, or NaN when `x` is empty or all missing.
    """
    values = x.values
    present = [i for i, v in enumerate(values) if not pd.isna(v)]
    if not present:
        return float('nan')
    # max() keeps the first index among equal magnitudes
    return values[max(present, key=lambda i: abs(values[i]))]

# Maps the condition names found in the input tables (long chemical names and
# short aliases) to one display label per carbon source.
contrast_map = {
    'D-Maltose_monohydrate': 'D-Maltose',
    'a-Ketoglutaric_acid_disodium_salt_hydrate': 'a-Ketoglutaric acid',
    'a-Ketoglutaric': 'a-Ketoglutaric acid',
    'Potassium_acetate': 'Acetate',
    'acetate': 'Acetate',
    'CAS_amino_acids': 'CAS amino acids',
    'CAS': 'CAS amino acids',
    'Tween_20': 'Tween',
    'Sodium_L-Lactate': 'L-Lactate',
    'Sodium_D,L-Lactate': 'D,L-Lactate',
    'Sodium_pyruvate': 'Pyruvate',
    'pyruvate': 'Pyruvate',
    'Putrescine_Dihydrochloride': 'Putrescine',
    'N-Acetyl-D-Glucosamine': 'NAG',
    'L-Glutamic_acid_monopotassium_salt_monohydrate': 'L-Glutamic acid',
    'L-Glutamic': 'L-Glutamic acid',
    # This key was listed twice with the same value; one entry is enough
    'Sodium_Fumarate_dibasic': 'Fumarate',
    'L-Malic_acid_disodium_salt_monohydrate': 'L-Malic acid',
    'Sodium_succinate_dibasic_hexahydrate': 'Succinate',
}

# Conditions kept for the comparison, written with the display labels above
carbon_sources = [
    'D-Glucose', 'D-Maltose', 'a-Ketoglutaric acid', 'Acetate',
    'D-Cellobiose', 'L-Lactate', 'D,L-Lactate', 'Pyruvate',
    'D-Mannitol', 'Tween', 'L-Glutamic acid', 'L-Glutamine', 'Gly-Glu',
    'Gelatin', 'CAS amino acids', 'Putrescine', 'NAG', 'Adenosine',
    'Uridine', 'Thymidine', 'Inosine', 'Cytidine', 'D-Mannose',
    'Sucrose', 'L-Serine',
]

## Clean published results

In [None]:
# Slow
wetmore_sample_data = pd.read_csv(wetmore_sample_data_file)

wetmore_published_counts = (pd.read_table(wetmore_published_counts_file))
columns_to_keep = ['barcode','rcbarcode']+ [c for c in wetmore_published_counts if 'set1' in c]
wetmore_published_counts = (wetmore_published_counts[columns_to_keep]
                            .melt(id_vars=['barcode', 'rcbarcode'], 
                                  value_name='feba_counts', var_name='sample_id'))
wetmore_published_counts = wetmore_published_counts[wetmore_published_counts.sample_id.str.contains('set1')]
wetmore_published_counts['sample_id'] = (wetmore_published_counts['sample_id']
                                         .str.split('.', expand=True)[1])

In [None]:
# Published fitness values: one column per experiment, reshaped to long format
# (sysName, contrast, LFC) after dropping the descriptive columns.
wetmore_published_results = (pd.read_table(wetmore_published_results_file)
                             .drop(['locusId', 'desc', 'comb'], axis=1)
                             .melt(id_vars=['sysName'], var_name='contrast', 
                                   value_name='LFC'))

# Column headers are split on whitespace: first token -> 'set', second -> 'contrast'.
# 'set' has to be derived first because the next line overwrites 'contrast'.
wetmore_published_results['set'] = wetmore_published_results.contrast.str.split(expand=True)[0]
wetmore_published_results['contrast'] = wetmore_published_results.contrast.str.split(expand=True)[1]

# Published t statistics, reshaped and split the same way
wetmore_published_stats = pd.read_table(wetmore_published_stats_file).drop(['locusId', 'desc'], axis=1)
wetmore_published_stats = wetmore_published_stats.melt(id_vars=['sysName'],  
                                         var_name='contrast', 
                                         value_name='tstat')
wetmore_published_stats['set'] = wetmore_published_stats.contrast.str.split(expand=True)[0]
wetmore_published_stats['contrast'] = wetmore_published_stats.contrast.str.split(expand=True)[1]

# Pair each LFC with its t statistic and keep only experiments whose set label contains 'set1'
wetmore_published = wetmore_published_results.merge(wetmore_published_stats, on=['sysName', 'contrast', 'set'])
wetmore_published = wetmore_published[wetmore_published.set.str.contains('set1')]
wetmore_published = wetmore_published.rename({'sysName':'Name'}, axis=1)
# Collapse repeated (contrast, gene) rows: median LFC, and the t statistic with the largest magnitude
wetmore_published = (wetmore_published.groupby(['contrast', 'Name']).agg({'LFC': ['median'], 
                                                               'tstat':[get_bigger_tstat]})
                                       .reset_index())
# agg with lists produces two-level column names; flatten them positionally
wetmore_published.columns = ['contrast', 'Name', 'published_LFC', 'tstat']
# Harmonise condition names with the labels used for the mBARq results
wetmore_published['contrast'] = wetmore_published['contrast'].replace(contrast_map)


## Compare counts and LFC to published data

In [None]:
# mBARq counts: keep rows that have an old_locus_tag, rename 'barcode' so it
# joins on the published 'rcbarcode' column, and reshape to long format
wetmore_counts = pd.read_csv(wetmore_counts_file)
wetmore_counts = (wetmore_counts[wetmore_counts['old_locus_tag'].notna()]
                  .rename(columns={'barcode': 'rcbarcode'})
                  .melt(id_vars=['rcbarcode', 'old_locus_tag'],
                        var_name='sample_id', value_name='mbarq_counts'))

# Barcode/sample pairs present in both count tables
wetmore_all_counts = pd.merge(wetmore_counts, wetmore_published_counts,
                              on=['rcbarcode', 'sample_id'], how='inner')
# log2(count + 1) so that zero counts stay finite
wetmore_all_counts['log_feba_counts'] = np.log2(wetmore_all_counts['feba_counts'] + 1)
wetmore_all_counts['log_mbarq_counts'] = np.log2(wetmore_all_counts['mbarq_counts'] + 1)

# Per-sample correlation of the two log counts: every second row of the last
# column of the stacked 2x2 matrices is the mBARq-vs-FEBA entry
per_sample = wetmore_all_counts.groupby('sample_id')[['log_mbarq_counts', 'log_feba_counts']]
wetmore_count_corr = per_sample.corr().iloc[0::2, -1].reset_index()
wetmore_count_corr.columns = ['contrast', 'comparison', 'R']
wetmore_count_corr['R2'] = (wetmore_count_corr['R'] ** 2).round(3)

In [None]:
# mBARq results: drop entries whose Name contains ':', harmonise the condition
# names and keep only the listed carbon sources
wetmore_results = pd.read_csv(wetmore_results_file)
wetmore_results = wetmore_results[~wetmore_results.Name.str.contains(":")]
wetmore_results['contrast'] = wetmore_results['contrast'].replace(contrast_map)
wetmore_results = wetmore_results[wetmore_results.contrast.isin(carbon_sources)]
#wetmore_results.to_csv(root/"Set1_rra_results_with_replicates.csv")
# Genes and conditions present in both analyses
wetmore_all_results = wetmore_results.merge(wetmore_published, on=['Name', 'contrast'], how='inner')

# Sum of the mBARq flag (0/1) and the published flag (0/2), as a string -> label
ci_labels = {'0': 'None', '1': 'mBARq only', '2': 'Original analysis only', '3': 'Both methods'}

# mBARq hit: |LFC| above threshold and a significant FDR in either direction
wetmore_all_results['mbarq_hits'] = ((abs(wetmore_all_results.LFC) > lfc_threshold)
                                     & ((wetmore_all_results.neg_selection_fdr < fdr_threshold)
                                        | (wetmore_all_results.pos_selection_fdr < fdr_threshold)))
# Published hit: |t| > 4, encoded as 2 so the sum below separates the four cases
wetmore_all_results['feba_hits'] = (abs(wetmore_all_results.tstat) > 4).astype(int)*2
wetmore_all_results['hit'] = (wetmore_all_results['mbarq_hits'].astype(int) + wetmore_all_results['feba_hits']).astype(str)
# Assign the result back: an in-place replace on `wetmore_all_results.hit` works on
# a temporary Series and does not reach the DataFrame under pandas Copy-on-Write.
wetmore_all_results['hit'] = wetmore_all_results['hit'].replace(ci_labels)

In [None]:
# Acetate condition: mBARq LFC against the published LFC
fig = compare_CIs(wetmore_all_results, 'Acetate', y_label='published_LFC')
# title_text / title_font replace `title` + the deprecated `titlefont`, which recent plotly rejects
fig.update_yaxes(title_text='LFC (original analysis)', title_font=dict(size=24.1, color='black'))
fig.update_xaxes(title_font=dict(size=24, color='black'))

In [None]:
#fig.write_image(figures_dir/f"{today}_Figure3B.png", format ='png', scale=2)

In [None]:
# Per-condition correlation between mBARq LFC and published LFC; every second
# row of the last column of the stacked 2x2 matrices is the cross-correlation
wetmore_corr = (wetmore_all_results.groupby('contrast')[['LFC', 'published_LFC']]
                .corr()
                .iloc[0::2,-1]
                .reset_index())
wetmore_corr.columns = ['contrast', 'comparison', 'R']
# Add the per-sample count correlations and give them a readable label
wetmore_corr = pd.concat([wetmore_corr, wetmore_count_corr])
wetmore_corr.replace({'log_mbarq_counts': 'Counts'}, inplace=True)

fig = px.box(wetmore_corr, x='comparison', y='R', width=400, height=400, color='comparison',
      color_discrete_map = {'LFC': 'black' ,'Counts': 'black'
                           },
             labels={'comparison': ''},
        category_orders = {'comparison': ['Counts', 'LFC']},
      template='plotly_white', hover_data=['contrast'])

# title_font replaces the deprecated `titlefont`, which recent plotly rejects
fig.update_xaxes(showline=True, linewidth=2, linecolor='black',
                 tickfont=dict(size=18, color='black'), title_font=dict(size=24, color='black'))
fig.update_yaxes(showline=True, linewidth=2, linecolor='black', range=[0, 1.1],
                 tickfont=dict(size=18, color='black'),
                 title_font=dict(size=24, color='black'))


fig.update_layout(showlegend=False, font=dict(size=20))

In [None]:
#fig.write_image(figures_dir/f"{today}_Figure3B_inset.png", format ='png', scale=2)

## Look at how many hits are generated by each analysis

In [None]:
# Within each condition, the share of hits found by mBARq only, the original
# analysis only, or both (rows that are a hit in neither are excluded)
wetmore_hits = (wetmore_all_results[wetmore_all_results.hit != 'None']
                .groupby('contrast').hit.value_counts(normalize=True))
wetmore_hits.name = 'hit_props'
wetmore_hits = wetmore_hits.reset_index()
legend_title = 'Significant change in fitness detected by:'
wetmore_hits = wetmore_hits.rename({'hit':legend_title}, axis=1)
fig = px.bar(wetmore_hits, x='contrast', y='hit_props', color=legend_title,
             labels = {'hit_props': 'Proportion of hits', 'contrast':''},
       color_discrete_map = {'Both methods': sushi_colors['dgreen'] ,'Original analysis only': sushi_colors['orange'],
                            'mBARq only': sushi_colors['dblue']},
      template='plotly_white', width = 1200, height=800)
# title_font replaces the deprecated `titlefont`, which recent plotly rejects
fig.update_xaxes(showline=True, linewidth=2, linecolor='black',
                 tickfont=dict(size=18, color='black'), title_font=dict(size=24, color='black'))
fig.update_yaxes(showline=True, linewidth=2, linecolor='black', range=[0, 1.1],
                 tickfont=dict(size=18, color='black'),
                 title_font=dict(size=24, color='black'))

fig.update_layout(legend=dict(font=dict(size=18)), legend_title=dict(font=dict(size=20)))
fig

In [None]:
# Number of hits per condition and detection category (non-hits excluded)
wetmore_hits_nums = (wetmore_all_results.loc[wetmore_all_results['hit'] != 'None']
                     .groupby('contrast')['hit']
                     .value_counts()
                     .rename('hit_nums')
                     .reset_index())
# Mean count per category across conditions, rounded to whole numbers
wetmore_hits_nums.groupby('hit')['hit_nums'].mean().round(0)

In [None]:
# Mean proportion per detection category across conditions, to two decimals
wetmore_hits.groupby('Significant change in fitness detected by:')['hit_props'].mean().round(2)

In [None]:
#fig.write_image(figures_dir/f"{today}_Figure3C.png", format ='png', scale=2)

# Jasinska 2020

## Load the data

In [None]:
# Input locations for the Jasinska data set, resolved against the data root
jasinska_config = configs['jasinska']
jasinska_sample_data_file = root/jasinska_config['sample_data_file']
jasinska_published_freq_file = root/jasinska_config['published_frequency_file']
jasinska_sample_data = pd.read_csv(jasinska_sample_data_file)
jasinska_counts_file = root/jasinska_config['no_drugs_file']
# Only the run id, the combined drug/replicate label and the sample name are used
jasinska_sample_data = jasinska_sample_data[["Run", "Drug_condition_and_replicate",  "Sample Name"]]
# Split the combined label at " r" into drug condition and replicate number
drugs_and_reps = jasinska_sample_data["Drug_condition_and_replicate"].str.split(" r", expand=True)
drugs_and_reps.columns = ['drug_condition', 'replicate']
drugs_and_reps['replicate'] = drugs_and_reps['replicate'].replace({'1': 'Replicate 1', 
                                                                  '2': 'Replicate 2', 
                                                                  '3': 'Replicate 3'})
# Sample names are split on '_' into exactly four parts
# NOTE(review): assumes every sample name has four '_'-separated fields — confirm
names = jasinska_sample_data['Sample Name'].str.split('_', expand=True)
names.columns = ['exp', 'well', 'passage', 'subsample']
jasinska_sample_data = pd.concat([jasinska_sample_data, drugs_and_reps, names], axis=1)
jasinska_sample_data = jasinska_sample_data.drop(['Drug_condition_and_replicate', 'Sample Name'], axis=1)
# The passage number is the integer after the '-' in the passage field
jasinska_sample_data['passage'] = jasinska_sample_data['passage'].str.split("-", expand=True)[1].astype(int)
jasinska_sample_data['generation'] = jasinska_sample_data['passage']*6 # from the paper

# Samples whose drug condition is exactly 'No drug'
no_drug_sample_data = jasinska_sample_data[jasinska_sample_data.drug_condition == 'No drug']

In [None]:
# Published lineage frequencies, with replicate labels matching the sample table
jasinska_published_freq = pd.read_csv(jasinska_published_freq_file)
jasinska_published_freq['replicate'] = jasinska_published_freq['replicate'].replace(
    {'Rep1': 'Replicate 1', 'Rep2': 'Replicate 2', 'Rep3': 'Replicate 3'}
)
# Barcodes reported in the publication and the colour listed for each of them
jasinska_pub_barcodes = jasinska_published_freq['barcode'].unique()
jasinska_color_map = jasinska_published_freq.set_index('barcode')['color'].to_dict()

In [None]:
# Read the count table in chunks of one million rows. Within each chunk only
# barcodes whose counts sum to more than 10 across all columns are kept, and
# the kept rows are added onto the running table (absent entries count as 0).
jasinska_counts = None
for chunk in pd.read_csv(jasinska_counts_file, chunksize=1000000):
    chunk_result = chunk.set_index("barcode")
    chunk_result = chunk_result.loc[chunk_result.sum(axis=1) > 10]
    jasinska_counts = (chunk_result if jasinska_counts is None
                       else jasinska_counts.add(chunk_result, fill_value=0))

In [None]:
# Turn counts into per-column frequencies (each column divided by its own
# total), then move 'barcode' from the index back into a column
jasinska_counts = jasinska_counts.div(jasinska_counts.sum(), axis='columns').reset_index()

## Graph frequencies over time

In [None]:
# Frequencies of the published barcodes, one row per barcode and run
hi_freq = (jasinska_counts.loc[jasinska_counts['barcode'].isin(jasinska_pub_barcodes)]
           .melt(id_vars='barcode', value_name='Frequency', var_name='Run'))
# Attach them to the no-drug samples (left join keeps every such sample)
hi_freq_full = pd.merge(no_drug_sample_data, hi_freq, on='Run', how='left')
# Mean frequency per barcode, replicate and generation
freq_overtime = (hi_freq_full
                 .groupby(['barcode', 'replicate', 'generation'])['Frequency']
                 .mean()
                 .reset_index())

In [None]:
def graph_frequency_over_time(df, time_col, freq_col, color_dict, barcode_col='barcode', 
                              filter_by_col='', filter_by_value='' ):
    """
    Draw a stacked area chart of per-barcode frequencies along a time axis.

    df              : long table with one frequency per (barcode, time point)
    time_col        : column used for the x axis
    freq_col        : column holding the frequencies to stack
    color_dict      : maps each barcode to a matplotlib colour
    barcode_col     : column identifying the barcode
    filter_by_col   : optional column to filter on; an empty string disables filtering
    filter_by_value : value `filter_by_col` must equal for a row to be kept

    Returns the matplotlib figure (5 x 4 inches).
    """
    subset = df[df[filter_by_col] == filter_by_value] if filter_by_col else df

    # Barcodes ordered by total frequency, largest first; stackplot draws them in this order
    totals = subset.groupby([barcode_col])[freq_col].sum()
    stacking_order = list(totals.sort_values(ascending=False).index)

    # One row per time point, one column per barcode
    wide = (subset[[barcode_col, freq_col, time_col]]
            .drop_duplicates()
            .pivot(index=time_col, columns=barcode_col, values=freq_col)
            .reset_index())
    wide = wide[[time_col] + stacking_order]

    layers = [wide[name] for name in stacking_order]
    layer_colors = [color_dict[name] for name in stacking_order]

    sns.set_style('ticks')
    fig = plt.figure(figsize=(5,4))
    plt.stackplot(wide[time_col], layers, colors=layer_colors)
    plt.xlabel('Time (generations)')
    plt.ylabel('Lineage frequency')
    return fig

In [None]:
# Stacked lineage frequencies for replicate 1
fig = graph_frequency_over_time(
    freq_overtime,
    time_col='generation',
    freq_col='Frequency',
    color_dict=jasinska_color_map,
    barcode_col='barcode',
    filter_by_col='replicate',
    filter_by_value='Replicate 1',
)
#fig.savefig(figures_dir/f"{today}_Figure4A_i.png", dpi=150, bbox_inches = "tight")

In [None]:
# Stacked lineage frequencies for replicate 2
fig = graph_frequency_over_time(
    freq_overtime,
    time_col='generation',
    freq_col='Frequency',
    color_dict=jasinska_color_map,
    barcode_col='barcode',
    filter_by_col='replicate',
    filter_by_value='Replicate 2',
)
#fig.savefig(figures_dir/f"{today}_Figure4A_ii.png", dpi=150, bbox_inches = "tight")

In [None]:
# Stacked lineage frequencies for replicate 3
fig = graph_frequency_over_time(
    freq_overtime,
    time_col='generation',
    freq_col='Frequency',
    color_dict=jasinska_color_map,
    barcode_col='barcode',
    filter_by_col='replicate',
    filter_by_value='Replicate 3',
)

#fig.savefig(figures_dir/f"{today}_Figure4A_iii.png", dpi=150, bbox_inches = "tight")

## Graph final frequencies

In [None]:
by_lineage = ['barcode', 'replicate']

# Mean frequency of each barcode per replicate over all samples
mbarq_result_mean = (hi_freq_full.groupby(by_lineage)['Frequency']
                     .mean()
                     .rename('mbarq_av_freq')
                     .reset_index())

# Mean frequency restricted to the samples at generation 420
mbarq_result_final = (hi_freq_full.loc[hi_freq_full['generation'] == 420]
                      .groupby(by_lineage)['Frequency']
                      .mean()
                      .rename('mbarq_final_freq')
                      .reset_index())

# Both mBARq summaries next to the published frequencies
mbarq_result = pd.merge(mbarq_result_mean, mbarq_result_final, on=by_lineage)
mbarq_result = pd.merge(mbarq_result, jasinska_published_freq, on=by_lineage)

In [None]:
# Final frequency of each barcode as computed by mBARq and as published,
# in long format with one row per barcode, replicate and method
final_freq_df = (mbarq_result[['barcode', 'replicate', 'mbarq_final_freq', 'final_freq']]
                 .melt(id_vars=['barcode', 'replicate'], value_name = 'Frequency', var_name ='method'))
final_freq_df['method'] = final_freq_df['method'].replace({'mbarq_final_freq': 'mBARq',
                                                          'final_freq': 'Original analysis'})
fig = px.bar(final_freq_df, x="method", y="Frequency", color="barcode", log_y=True,
             color_discrete_map = jasinska_color_map,
       facet_col='replicate', template="plotly_white",
       category_orders = {'replicate': ['Replicate 1', 'Replicate 2', 'Replicate 3']},
             labels = { 'Frequency': 'Lineage frequency','method':''},
            width=600, height=600
      )
# Facet titles read 'replicate=Replicate 1'; keep only the part after '='
fig.for_each_annotation(lambda a: a.update(text=a.text.split("=")[-1]))
fig.update_layout(showlegend=False)
# title_font replaces the deprecated `titlefont`, which recent plotly rejects
fig.update_xaxes(showline=True, linewidth=2, linecolor='black',
                 tickfont=dict(size=24, color='black'),
                 title_font=dict(size=24, color='black'))
fig.update_yaxes(showline=True, linewidth=2, linecolor='black',
                 tickfont=dict(size=18, color='black'),
                 title_font=dict(size=24, color='black'))

#fig.write_image(figures_dir/f"{today}_Figure4C.png", format ='png', scale=2 )

## Graph correlation for average frequencies

In [None]:
# Average lineage frequency per barcode and replicate: mBARq (x) against the
# published value (y), both on log axes
fig = px.scatter(mbarq_result, y='average_freq', x='mbarq_av_freq', log_x=True, log_y=True,
                 labels={'mbarq_av_freq': 'Average lineage frequency (mBARq)',
                        'average_freq': 'Average lineage frequency (original analysis)'},
           color='barcode', template='plotly_white', width=700, height=700, color_discrete_map=jasinska_color_map)
fig.update_layout(showlegend=False)
fig.update_traces(marker=dict(size=14,
                              line=dict(width=1,
                                        color='DarkSlateGrey')),
                  selector=dict(mode='markers'))


# On a log axis the range is given in log10 units, so this spans 10**-3.5 to 1.
# title_font replaces the deprecated `titlefont`, which recent plotly rejects.
fig.update_xaxes(showline=True, linewidth=2, linecolor='black',
                 tickfont=dict(size=18, color='black'),
                 range=[-3.5, 0.0],
                 title_font=dict(size=24, color='black'))
fig.update_yaxes(showline=True, linewidth=2, linecolor='black',
                 tickfont=dict(size=18, color='black'),
                 range=[-3.5, 0.0],
                 title_font=dict(size=24, color='black'))

fig.update_layout(legend=dict(font=dict(size=18)), legend_title=dict(font=dict(size=20)))


#fig.write_image(figures_dir/f"{today}_Figure4B.png", format ='png', scale=2 )