In [1]:
%run lgs00_general_functions.ipynb

### General

In [2]:
# Pairwise values keyed by an unordered pair of short gene names (frozenset keys, so
# the lookup does not depend on the order in which the two genes are given).
# NOTE(review): presumably percent amino-acid identity between the genes - confirm.
identities = {frozenset(['amac','av']):82, frozenset(['cgre','av']):41, frozenset(['pplu','av']):18,
              frozenset(['amac','cgre']):43, frozenset(['amac','pplu']):17, frozenset(['cgre','pplu']):19,
        frozenset(['amacV14L','av']):82, frozenset(['amacV14L','cgre']):43, frozenset(['amacV14L','pplu']):17,
             frozenset(['amac','amacV14L']):100}

In [3]:
def make_figure_letters(figure, coordinates, **kwargs):
    '''Add bold panel letters to `figure` on a dedicated 1x1 subplot and return that axes.

    figure      : object providing `add_subplot` (a matplotlib Figure).
    coordinates : mapping of letter -> sequence whose first two items are the x and y
                  position of the letter, in the data coordinates of the new axes.
    **kwargs    : forwarded to the axes' `text` call (e.g. fontsize).

    The subplot is created on the `figure` argument only; the previous version also
    called `add_subplot` on a global named `fig`, which raised NameError (or drew on
    an unrelated figure) whenever that global did not happen to be the same figure.
    '''
    ax = figure.add_subplot(1,1,1)
    for side in ['top','bottom','right','left']:
        ax.spines[side].set_color(None)
    ax.tick_params(labelcolor='w', top=False, bottom=False, left=False, right=False)

    # Write on the axes we just created rather than on pyplot's "current" axes,
    # which belongs to whichever figure is current and may not be `figure`.
    for letter in coordinates:
        position = coordinates[letter]
        ax.text(position[0], position[1], letter, weight='bold', **kwargs)
    return ax

In [21]:
import matplotlib.image as mpimg
def show_image(image_path):
    '''Read the image file at `image_path` and draw it with pyplot, axis switched off.'''
    pixels = mpimg.imread(image_path)
    plt.imshow(pixels)
    plt.axis('off')

In [None]:
def patch_violinplot(facecolors, edgecolors):
    '''Recolour every PolyCollection child of the current axes.

    The i-th PolyCollection found among the axes children receives `edgecolors[i]`
    and `facecolors[i]`, so both sequences need at least as many entries as there
    are such collections (an IndexError is raised otherwise).
    '''
    from matplotlib.collections import PolyCollection
    bodies = [child for child in plt.gca().get_children() if isinstance(child, PolyCollection)]
    for index, body in enumerate(bodies):
        body.set_edgecolor(edgecolors[index])
        body.set_facecolor(facecolors[index])

# Overview

### Library distributions with WTs and chromomutants

In [4]:
def plot_library_distributions(gene, nt_data, aa_data, ax, 
                               y1_label='Number of genotypes: Library', 
                               y2_label='Number of genotypes:\nWTs, chromophore mutants'):
    '''Plot the brightness histogram of one gene's library with WT and chromophore mutants overlaid.

    gene     : value matched against the `gene` column of both tables; also the key
               used for the global `colors` and `names` lookups.
    nt_data  : table with `gene`, `aa_genotype_native` and `log_brightness` columns
               (source of the rows whose native genotype is 'wt').
    aa_data  : table with `gene`, `aa_genotype_pseudo`, `n_mut` and `brightness` columns.
    ax       : axes receiving the library histogram; a twin y-axis created from it
               receives the WT and chromophore-mutant histograms, the legend and the title.
    y1_label, y2_label : accepted for interface compatibility; both y labels are
               currently set to an empty string, so these values are not used.

    On return `ax` is made the current pyplot axes.
    '''
    library = aa_data[aa_data.gene==gene]
    wts = nt_data[(nt_data.gene==gene) & (nt_data.aa_genotype_native=='wt')]

    # Raw strings: '\*' inside a normal string literal is an invalid escape sequence.
    has_stop = aa_data.aa_genotype_pseudo.str.contains(r'\*')
    has_stop_at_246 = aa_data.aa_genotype_pseudo.str.contains(r'\*246')
    nonsense_mask = has_stop & ~has_stop_at_246
    # Genotypes (up to 20 mutations, not flagged as nonsense) touching G69, Y68 or R99.
    chromomuts = aa_data[(aa_data.gene==gene) & ~nonsense_mask & (aa_data.n_mut<=20)
                         & (aa_data.aa_genotype_pseudo.str.contains('G69|Y68|R99'))]

    # Library histogram on the axes passed in (not on whatever axes is current).
    sns.histplot(data=library['brightness'], ax=ax,
                 bins=50, element='poly', color=colors[gene], fill=True, linewidth=2)
    ax.set_ylim(0,4500)
    ax.set_yticks([])
    ax.set_ylabel('')
    ax.set_xlabel('Fluorescence (log)', fontsize=12)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    # WT and chromophore-mutant histograms share x but get their own y scale.
    twin = ax.twinx()
    sns.histplot(data=wts['log_brightness'], ax=twin,
                 bins=10, element='poly', color='w', fill=True, linewidth=2, edgecolor=colors[gene])
    sns.histplot(data=chromomuts['brightness'], ax=twin,
                 bins=50, element='poly', color='k', fill=True, linewidth=0)
    twin.set_ylabel('')
    if gene=='avGFP':
        twin.set_ylim(0,200)
    elif gene=='cgreGFP':
        twin.set_ylim(0,900)
    twin.set_yticks([])

    legendary(colors = [colors[gene],'w','k'], edges = [colors[gene],colors[gene], 'k'],
              labels = [f'Library\n(n = {len(library)})', f'WT protein \n(n = {len(wts)})',
                       f'Mut. chromophore\n(n = {len(chromomuts)})'], loc='upper left',
             )
    twin.set_title(names[gene], fontsize=13)
    plt.sca(ax)


# Comparison of genes

### Violins of N-mutants with median curves

In [5]:
from scipy.optimize import curve_fit

def sigmoid(x, L ,x0, k, b):
    '''Logistic curve L / (1 + exp(-k * (x - x0))) + b, evaluated at x.'''
    exponent = -k*(x-x0)
    return L / (1 + np.exp(exponent)) + b

def fit_sigmoid(xdata, ydata, mid_guess):
    '''Fit `sigmoid` to (xdata, ydata) with scipy's curve_fit and return the fitted parameters.

    The starting point is L=max(ydata), x0=mid_guess, k=1, b=min(ydata); up to
    10000 function evaluations are allowed. Only the parameter array is returned,
    the covariance estimate is discarded.
    '''
    initial_guess = [max(ydata), mid_guess, 1, min(ydata)]
    fitted_params, _covariance = curve_fit(sigmoid, xdata, ydata, p0=initial_guess, maxfev=10000)
    return fitted_params

def get_fitnesses_by_nmut(gene, n, data_nt, data_aa, fit=False):
    '''Draw violins of brightness for WT and for genotypes with 1..n-1 mutations.

    gene    : value matched against the `gene` column; also the key into the global `colors`.
    n       : number of violins; position 0 is WT, positions 1..n-1 are mutation counts.
    data_nt : table providing the WT values (`aa_genotype_native == 'wt'`, `log_brightness`).
    data_aa : table providing the mutant values (`n_mut`, `brightness`).
    fit     : when true, a sigmoid is fitted to the per-violin medians and drawn on top.

    Despite its name the function returns nothing; it only draws on the current axes.
    '''
    wt = data_nt[(data_nt.gene==gene) & (data_nt.aa_genotype_native=='wt')]['log_brightness']
    muts = [data_aa[(data_aa.gene==gene) & (data_aa.n_mut==i)]['brightness'] for i in range(1,n)]

    labels = ['WT'] + list(range(1,n))
    positions = range(0,n)

    toplot = [list(wt)]+[list(x) for x in muts]
    violins = plt.violinplot(toplot, showmedians=True, showmeans=False, positions=positions, widths=0.8,
                            showextrema=False)

    for body in violins['bodies']:
        body.set_color(colors[gene])
        body.set_alpha(1)
    violins['cmedians'].set_linewidth(2)
    violins['cmedians'].set_color('k')

    plt.ylabel('Fluorescence (log)', fontsize=10)
    plt.xticks(positions, labels)
    plt.xlabel('Number of mutations')

    if fit:
        popt = fit_sigmoid(xdata = list(positions), ydata = [np.median(x) for x in toplot], mid_guess=5)
        # Draw the curve over the plotted positions (0..n-1); this was a fixed 0..8
        # before, which only matched the violins when n == 9.
        x = np.linspace(0, n-1, 100)
        y = sigmoid(x, *popt)
        plt.plot(x,y, color='k', linewidth=2)

### Buried vs exposed residues violins

In [6]:
def plot_buried_vs_exposed_violins_singles(gene):
    '''Draw a split violin of single-mutant brightness for `gene`, split by `has_buried_mutation`.

    Reads the globals `data_aa`, `wt_mask`, `singles_mask`, `colors` and `names`.
    `gene` is matched against the `gene` column and must be one of the keys of the
    `yticks` table below.
    '''
    singles = data_aa[~wt_mask & singles_mask & (data_aa.gene==gene)]
    # `col='gene'` was dropped from this call: it is not a violinplot parameter
    # (it belongs to catplot); older seaborn silently ignored it and newer
    # releases forward unknown keywords to matplotlib.
    ax = sns.violinplot(data=singles, 
                        x='gene', y='brightness', 
                   saturation=1,width=1,
                   hue='has_buried_mutation', split=True, 
               linewidth=2, inner=None, scale='area', cut=0)

    patch_violinplot(edgecolors=[None, colors[gene]], facecolors=[colors[gene],'w'])
    
    legendary(['k', 'w'], ['Exposed sites', 'Buried sites'], edges=['k','k'],
              ncol=2, loc='upper center')
    
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['bottom'].set_visible(False)
    ax.spines['left'].set_position('center')
    # NOTE(review): this removes whatever legend the axes currently holds; if
    # `legendary` installs its legend on the same axes, that custom legend is the
    # one removed here - confirm the intended order.
    ax.get_legend().remove()
    
    yticks = {'amacGFP' : [3,3.5,4,4.5], 'cgreGFP':[3,3.5,4,4.5], 
              'ppluGFP':[3,3.5,4,4.5], 'avGFP':[1.5,2,2.5,3,3.5,4]}
    plt.yticks(yticks[gene],yticks[gene])
    
    label_plot_axis(y='Fluorescence of single mutants', fontsize_y=12)
    plt.xticks([0],[names[gene]],fontsize=12)

### Observed vs expected bright genotypes

In [None]:
def get_ratio_predicted2observed_fluo(dataset, n_mut, threshold=0.4, **kwargs):
    '''Return the fraction of genotypes predicted to be bright that were observed to be bright.

    Only rows with `n_mut` mutations are considered. A genotype is "predicted bright"
    when `expected_effect >= threshold` and "observed bright" when
    `measured_effect >= threshold`.

    dataset   : DataFrame with `n_mut`, `expected_effect` and `measured_effect` columns.
    n_mut     : mutation count to select.
    threshold : cut-off applied to both the expected and the measured effect.
    **kwargs  : accepted and ignored.

    Returns a float in [0, 1], or NaN when no genotype with `n_mut` mutations is
    predicted to be bright (this case used to raise ZeroDivisionError).
    '''
    with_n_mut = dataset[dataset['n_mut']==n_mut]
    predicted_bright = with_n_mut[with_n_mut['expected_effect']>=threshold]
    if len(predicted_bright) == 0:
        return float('nan')

    observed_bright = predicted_bright['measured_effect']>=threshold
    return int(observed_bright.sum()) / len(predicted_bright)

In [7]:
def plot_predicted_bright_vs_real(dataset, threshold=False):
    '''Plot, per gene, the percentage of predicted-bright genotypes observed bright for 1..7 mutations.

    dataset   : table passed (after per-gene masking) to get_ratio_predicted2observed_fluo.
    threshold : cut-off forwarded to get_ratio_predicted2observed_fluo. When it compares
                equal to False (the default; note that 0 does too) it is derived from the
                globals `gate_borders_log` and `ref_wt_log`.

    Reads the globals `amac_mask`, `cgre_mask`, `pplu_mask`, `av_mask` and `colors`.
    '''
    n = 8
    name = {'amac':'amacGFP', 'av':'avGFP', 'cgre':'cgreGFP', 'pplu':'ppluGFP2'}
    genes = ['amac', 'cgre', 'pplu', 'av']

    # NOTE(review): the default threshold is derived from the FIRST gene only and then
    # applied to every gene (this is what the in-loop `if threshold == False` did as
    # well, because `threshold` stayed set after the first iteration). If a per-gene
    # threshold was intended this needs changing - confirm.
    if threshold == False:
        threshold = gate_borders_log[genes[0]][0] - ref_wt_log[genes[0]+'GFP']

    for gene in genes:
        gene_mask = eval(gene+'_mask')  # per-gene boolean mask defined at notebook level
        percentages = [100 * get_ratio_predicted2observed_fluo(dataset[gene_mask], n_mut=i,
                                                               threshold=threshold)
                       for i in range(1,n)]
        plt.plot(percentages, color=colors[gene], label=name[gene], linewidth=3)

    plt.legend(loc='lower left', frameon=False)
    plt.ylim(0,105)
    plt.xticks(range(n-1), range(1,n))

### Mini landscapes of single mutants

In [8]:
# Load the single-mutant landscape table from the working directory; the first CSV
# column becomes the index. Later cells read its `gene`, '1'..'7' and 'n_obs:1'..'n_obs:7' columns.
singlemuts_landscapes = pd.read_csv('single_mutant_landscapes.csv', index_col=0)

In [9]:
# Blank out (NaN) every landscape value that is backed by 15 or fewer observations.
# Column '<n>' holds the value and 'n_obs:<n>' the matching observation count.
# Series.where keeps a value where the condition holds and writes NaN elsewhere, which
# replaces the row-wise apply that indexed each row positionally (x[0], x[1]).
for n in [1,2,3,4,5,6,7]:
    singlemuts_landscapes[str(n)] = singlemuts_landscapes[str(n)].where(
                                                    singlemuts_landscapes[f'n_obs:{n}'] > 15)

In [10]:
def plot_singlemut_landscapes(gene):
    '''Plot one faint line per row of `singlemuts_landscapes` for `gene`, plus the median curve.

    Each line uses the first eight columns of the row (taken by position) against
    x = 0..7. The black curve is the median `brightness` in the global `data_aa`
    for genotypes of this gene with 0..7 mutations.
    '''
    meds = [data_aa[(data_aa.gene==gene) & (data_aa.n_mut==i)]['brightness'].median() for i in range(8)]

    # Filter once instead of re-filtering the whole table for every plotted row.
    gene_landscapes = singlemuts_landscapes[singlemuts_landscapes.gene==gene]
    for _, row in gene_landscapes.iterrows():
        plt.plot(range(8), row.iloc[0:8], 
             linewidth =1, color=colors[gene], alpha=0.2)

    plt.plot(range(8), meds, color='k', linewidth=2)
    plt.xticks(range(8),range(8))
    legendary(['w'], [''], title=names[gene], loc='upper right', fontsize=12)
    plt.ylabel('Fluorescence (log)', fontsize=10)
    plt.xlabel('Number of mutations', fontsize=10)

# Effect change of mutations across genes

In [None]:
def get_overall_fractions_cross_gene_effects(gene1, gene2, neutral_threshold, bad_threshold):
    '''Classify single mutations measured in both genes by their effect in each gene.

    A mutation counts as neutral in a gene when its effect is above
    -stds[gene] * neutral_threshold and as bad when it is below
    -stds[gene] * bad_threshold. Effects come from the `effect_in_<gene>` columns of
    the global `df_effects_singles`; rows missing either value are dropped.

    Returns a dict with the fractions of mutations that are neutral in both genes
    ('nn'), bad in both ('bb'), neutral in one and bad in the other ('bn'), and the
    number of mutations considered ('n_muts').
    '''
    stds = {'amac' : 0.03148308406723709, 'cgre' : 0.02852658185221242, 'pplu' : 0.027489382693326424,
          'amacV14L' : 0.029670382945273337, 'av' : 0.0817635644945253}

    col1, col2 = 'effect_in_'+gene1, 'effect_in_'+gene2
    # mutations must be observed in both genes
    effects = df_effects_singles[[col1, col2]].dropna().reset_index()
    total = len(effects)

    neutral1 = effects[col1] > -stds[gene1]*neutral_threshold
    neutral2 = effects[col2] > -stds[gene2]*neutral_threshold
    bad1 = effects[col1] < -stds[gene1]*bad_threshold
    bad2 = effects[col2] < -stds[gene2]*bad_threshold

    def fraction(mask):
        return len(effects[mask]) / total

    return {'nn' : fraction(neutral1 & neutral2),
            'bb' : fraction(bad1 & bad2),
            'bn' : fraction((neutral1 & bad2) | (neutral2 & bad1)),
            'n_muts' : total}

In [None]:
def get_overall_epistasis_overlap(gene1, gene2, e):
    '''Compare which site pairs are epistatic in two genes.

    Only pairs present in the global `epistatic_pairs_pos` for BOTH genes are
    considered. A pair is epistatic in a gene when the largest absolute value among
    its recorded epistasis values exceeds `e`.

    Returns [fraction, count]: `count` is the number of shared pairs epistatic in at
    least one gene and `fraction` is the share of those that are epistatic in both.
    When no pair qualifies the result is [nan, 0] (this used to raise ZeroDivisionError).

    NOTE(review): the magnitude test is now max(|value|), as in distances_heatmap.
    It used to be abs(max(values)), which ignores a strongly negative value whenever
    a larger (less negative) value is present, so published numbers may shift - confirm.
    '''
    pairs1 = epistatic_pairs_pos[gene1]
    pairs2 = epistatic_pairs_pos[gene2]

    def is_epistatic(values):
        return max(map(abs, values)) > e

    shared = [pair for pair in pairs1 if pair in pairs2]
    epistatic1 = {pair for pair in shared if is_epistatic(pairs1[pair])}
    epistatic2 = {pair for pair in shared if is_epistatic(pairs2[pair])}

    in_both = epistatic1 & epistatic2
    in_either = epistatic1 | epistatic2
    if not in_either:
        return [float('nan'), 0]

    return [len(in_both) / len(in_either), len(in_either)]

In [11]:
def plot_bars_effectchange_epimaintain_split(ax1, ax2, spread = 1.4, c=matplotlib.cm.bone(225)):
    '''Draw two bar panels comparing gene pairs.

    ax1 (top)    : percentage of single mutations that are neutral in one gene and bad
                   in the other ('bn' from get_overall_fractions_cross_gene_effects,
                   called with thresholds 2 and 5).
    ax2 (bottom) : percentage of site pairs epistatic in both genes
                   (get_overall_epistasis_overlap with e=0.3).
    spread       : multiplier applied to the bar x positions.
    c            : bar face colour.

    Each bar is annotated with the two gene names (global `names`) and each tick
    with the number of mutations / pairs behind it.
    '''
    pairs = [('amac','amacV14L'), ('amac','av'),('amacV14L','av'),
                 ('amac','cgre'), ('amacV14L', 'cgre'), ('av','cgre'),
             ('amac','pplu'), ('amacV14L','pplu'), ('av','pplu'), ('cgre','pplu')]

    shares = {}
    fractions = {}
    for pair in pairs:
        fractions[pair] = get_overall_fractions_cross_gene_effects(pair[0],pair[1],2,5)
        shares[pair] = get_overall_epistasis_overlap(pair[0], pair[1], e=0.3)

    # Bar positions; the gaps separate the four groups of gene pairs.
    x = np.array([0, 2,3, 5,6,7, 9,10,11,12])

    def draw_bars(heights, counts, label_offset):
        # Bars, 'n = ...' tick labels, gene-name annotations and group dividers for one panel.
        plt.bar(x*spread, heights, width=0.8, color=c, edgecolor='k')
        plt.xticks(x*spread, 
                   [f'n = {count}' for count in counts], 
                   rotation=0, ha='center' ,  fontsize=10, va = 'top')
        for xi, height, pair in zip(x, heights, pairs):
            plt.text(x = xi*spread, y = label_offset + height,
                     s = f'{names[pair[1]]},\n{names[pair[0]]}', fontsize=9,
                    rotation = 0, ha='center', va='bottom')
        for xi in np.array([1,4,8])*spread:
            plt.axvline(xi, linewidth=1, linestyle = ':', color='k')

    # TOP PLOT
    plt.sca(ax1)
    draw_bars([100*fractions[pair]['bn'] for pair in pairs],
              [fractions[pair]['n_muts'] for pair in pairs], label_offset=1)
    plt.ylabel('Neutrality change\n(% of mutations)', fontsize=11)
    plt.ylim(0,21)
    # Group captions sit at y=23, i.e. above the visible y range of the top panel.
    for xi,div in zip(np.array([0,2.5,6,10.5])*spread, 
                        ['0.4% divergence', '18% divergence', '57-59% divergence', '81-83% divergence']):
        plt.text(xi, 23, div, horizontalalignment='center', verticalalignment='center', fontsize=12)  

    # BOTTOM PLOT
    plt.sca(ax2)
    draw_bars([100*shares[pair][0] for pair in pairs],
              [shares[pair][1] for pair in pairs], label_offset=0.5)
    plt.ylabel('Pairs of sites epistatic\nin both genes (%)', fontsize=11)
    plt.ylim(0,9)
        
    
#     plt.xlabel('Percent')


# Testing neural net predictions

In [None]:
# Per-gene constants keyed by short gene name. The same literal dict is also defined
# locally inside plot_predicted_bright_vs_real and get_overall_fractions_cross_gene_effects.
# NOTE(review): presumably standard deviations of WT log-brightness - confirm.
stds = {'amac' : 0.03148308406723709, 'cgre' : 0.02852658185221242, 'pplu' : 0.027489382693326424,
          'amacV14L' : 0.029670382945273337, 'av' : 0.0817635644945253}

In [None]:
def load_predictions_data():
    '''Load the experimentally tested predictions and rescale their readings.

    Reads predictions/experimentally_tested_predictions.csv under the global
    `data_folder` and adds three columns:

    fiji_log_value      : log10 of `fiji_value`.
    fiji_log_scaled     : fiji_log_value rescaled per gene so that the negative control
                          maps to 0 and the WT control to 1 (controls from `fiji_ctrls`);
                          values that are not greater than 0 are set to 0.
    fiji_library_values : fiji_log_scaled mapped onto the range between the minimum
                          `brightness` of that gene in the global `data_aa` and
                          `ref_wt_log[<gene>GFP]`.

    The `gene` column must hold short gene names present in `fiji_ctrls`
    ('amac', 'cgre', 'pplu'). Returns the DataFrame.
    '''
    predictions = pd.read_csv(os.path.join(data_folder, 'predictions','experimentally_tested_predictions.csv'))
    fiji_ctrls = {'cgre_wt': 246, 'cgre_neg': 5, 'amac_wt': 65, 'amac_neg': 4, 'pplu_wt': 191, 'pplu_neg': 4}
    predictions['fiji_log_value'] = np.log10(predictions['fiji_value'])

    # Per-row control values, looked up by gene, so the rescaling can be done with
    # column arithmetic instead of a row-wise apply that indexed rows by position.
    gene = predictions['gene']
    log_neg = gene.map(lambda g: np.log10(fiji_ctrls[g+'_neg']))
    log_wt = gene.map(lambda g: np.log10(fiji_ctrls[g+'_wt']))
    scaled = (predictions['fiji_log_value'] - log_neg) / (log_wt - log_neg)
    # Anything not strictly positive (including NaN) becomes 0, as before.
    predictions['fiji_log_scaled'] = scaled.where(scaled > 0, 0)

    minval = {gene_name : data_aa[data_aa.gene==gene_name]['brightness'].min() 
              for gene_name in ['amacGFP', 'cgreGFP', 'ppluGFP', 'avGFP']}
    print(minval)

    library_wt = gene.map(lambda g: ref_wt_log[g+'GFP'])
    library_min = gene.map(lambda g: minval[g+'GFP'])
    predictions['fiji_library_values'] = (predictions['fiji_log_scaled'] * (library_wt - library_min)
                                          + library_min)

    return predictions

In [None]:
# Module-level table of tested predictions; plot_predictions reads this global.
predictions = load_predictions_data()

In [None]:
def plot_all_vs_neutral_muts(gene, y_axis, threshold, df_effects, color, 
                             df = data_aa, borders = gate_borders_log):
    '''Draw half-violins per mutation count (1..8): all genotypes vs genotypes without bad mutations.

    gene       : short gene name; rows with `gene == gene + 'GFP'` are used.
    y_axis     : column of `df` to plot.
    threshold  : single mutations with `effect_in_<gene>` below -threshold count as bad.
    df_effects : table with `effect_in_<gene>`, `wt_state_<gene>`, `position`, `mutation`.
    color      : pair of colours; color[1] for the left (all genotypes) halves and
                 color[0] for the right (no bad mutation) halves.
    df         : genotype table (defaults to the global `data_aa` at definition time).
    borders    : per-gene values; borders[gene][0] is drawn as a dashed horizontal
                 line for every gene except 'av'.
    '''
    gene_df = df[df['gene']==gene+'GFP']
    plot_half_violin([gene_df[gene_df['n_mut']==i][y_axis] for i in range(1,9)], side='left', 
                     color=color[1], alpha=1, widths=0.9, chonkylines=True)

    # Full names (wt state + position + new state) of the bad single mutations;
    # names containing '*' are left out.
    bads = df_effects[df_effects['effect_in_'+gene] < -threshold]
    bad_mutations = set(bads['wt_state_'+gene] + bads['position'].astype(str) + bads['mutation'])
    bad_mutations = {x for x in bad_mutations if '*' not in x}

    if bad_mutations:
        carries_bad = gene_df['aa_genotype_pseudo'].str.contains('|'.join(bad_mutations))
        neutral_df = gene_df[~carries_bad]
    else:
        # With nothing to exclude the joined pattern would be '', which matches every
        # genotype and used to throw away all rows.
        neutral_df = gene_df

    plot_half_violin([neutral_df[neutral_df['n_mut']==i][y_axis] for i in range(1,9)], side='right', alpha=0.5, 
                     color=color[0], widths=0.9, chonkylines=True)

    label_plot_axis(x = 'Number of mutations', t= gene+'GFP fitness')

    if gene!='av':
        plt.axhline(borders[gene][0], color='crimson', linewidth=1, linestyle='--',
                   path_effects=[pe.Stroke(linewidth=3, foreground='w'), pe.Normal()])
    

In [None]:
def plot_predictions(gene, color, threshold=0.15, ycol = 'fiji_library_values', 
                     df = df_effects_singles, mode='full',**kwargs):
    '''Overlay the tested predictions for `gene` on the all-vs-neutral half-violin plot.

    The half-violins come from plot_all_vs_neutral_muts (y axis 'brightness').
    The rows of the global `predictions` for this gene are then drawn as a swarm
    over `distance`, with a dashed line through the medians at distances 6, 12, ..., 48.
    `mode` is accepted but not used; **kwargs go to sns.swarmplot.
    '''
    tested_distances = [6,12,18,24,30,36,42,48]
    tick_labels = [1,2,3,4,5,6,7,8,12,18,24,30,36,42,48]

    plot_all_vs_neutral_muts(gene, 'brightness', threshold, 
                                         df, color=color)

    # Add one placeholder row (no y value) for every distance 1..48 that was not
    # tested, so that each distance gets its own category slot in the swarm plot.
    dummy = predictions[predictions['gene']==gene].copy()
    for d in range(1,49):
        if d in tested_distances:
            continue
        dummy.loc[10000*d,'distance'] = d
        dummy.loc[10000*d,'gene'] = gene
    sns.swarmplot(data=dummy[dummy['gene']==gene], x='distance', y=ycol, color='k',
                  edgecolor='w', linewidth=1, **kwargs)

    medians = [dummy[dummy['distance']==x][ycol].median() for x in tested_distances]
    plt.plot([x-1 for x in tested_distances], medians, color='k', linestyle='--', linewidth=1)

    plt.xticks([i-1 for i in tick_labels], tick_labels)

# FACS library sorting

In [12]:
from FlowCytometryTools import FCMeasurement
from FlowCytometryTools import ThresholdGate, PolyGate, IntervalGate, QuadGate

def transform_by_log10(dataset):
    '''Return log10 of `dataset` element-wise, with every value <= 1 mapped to 0.

    np.where evaluates np.log10 on all elements, including zeros and negatives, before
    choosing between the branches; the floating-point warnings this triggers are
    suppressed because those elements are replaced by 0 anyway.
    '''
    with np.errstate(divide='ignore', invalid='ignore'):
        dataset = np.where(dataset<=1, 0, np.log10(dataset))
    return dataset

def determine_filename(gene, machine, ctrl):
    '''Build the path of an .fcs file under <data_folder>/cell_sorting/fcs_files.

    The file name is <gene>GFP_<machine>__<sample>.fcs ('GFP2_' instead of 'GFP_'
    for gene 'pplu'), where <sample> is 'negative_control' when `ctrl` is 'ctrl' or
    'negctrl' and 'library' for any other value.
    '''
    fcs_dir = os.path.join(data_folder, 'cell_sorting', 'fcs_files')
    protein = 'GFP2_' if gene=='pplu' else 'GFP_'
    sample = 'negative_control' if ctrl in ['ctrl', 'negctrl'] else 'library'
    return os.path.join(fcs_dir, gene + protein + machine + '__' + sample + '.fcs')

def load_fcs_data(filename):
    '''Load an .fcs file, log10-transform every channel except 'Time', and unify channel names.

    'GFP-A' / 'FITC-A' are renamed to 'GFP' and 'mCherry-A' / 'PE-Texas Red-A' to
    'mKate2' when present. Returns the transformed FCMeasurement.
    '''
    channel_aliases = {'GFP-A':'GFP', 'FITC-A':'GFP',
                       'mCherry-A':'mKate2', 'PE-Texas Red-A':'mKate2'}

    measurement = FCMeasurement(ID=filename, datafile=filename)
    channels = [column for column in measurement.data.columns if column!='Time']
    measurement = measurement.transform(transform_by_log10, auto_range=False, use_spln=False,
                                        channels=channels)
    for raw_name, alias in channel_aliases.items():
        if raw_name in measurement.data.columns:
            measurement.data.rename(columns={raw_name:alias}, inplace=True)
    return measurement

The dedent function was deprecated in Matplotlib 3.1 and will be removed in 3.3. Use inspect.cleandoc instead.
  doc = dedent(doc)


In [13]:
# Keeps events whose 'FSC-A' value is above 1. load_fcs_data log10-transforms every
# channel except 'Time', so for data loaded that way the threshold is on log10(FSC-A).
is_bacteria_gate = ThresholdGate(threshold=1, channel=['FSC-A'], region='above', name='is_bacteria_gate')

# Based on sorting data pdfs from VBCF
# Interval gates on the 'mKate2' channel; not referenced by any other code visible in
# this part of the notebook (sortingplot_library_scatter uses its own per-gene intervals).
mKate2_gate_A = IntervalGate(vert=(3.2, 3.4), channel=['mKate2'], region='in', name='mKate2_gate_A')
mKate2_gate_B = IntervalGate(vert=(3.4, 3.6), channel=['mKate2'], region='in', name='mKate2_gate_B')

In [14]:
def sortingplot_library_scatter(gene, machine, fontscale=1, savepath=False, random_state=None):
    '''Joint scatter of GFP vs mKate2 for a sorted library, its negative control and the sorted cells.

    gene         : 'amac', 'cgre' or 'pplu' (keys of the `redgate` table below).
    machine      : sorter label, 'A' or 'B'.
    fontscale    : added to the base font sizes.
    savepath     : when truthy, the figure is saved there (dpi=300) before returning.
    random_state : forwarded to DataFrame.sample for the two subsamples (10000 library
                   events, 1500 control events); the default None keeps the previous
                   behaviour of drawing a different subsample on every call.

    Events inside the per-gene mKate2 interval are highlighted in crimson, and the
    gate borders read from <gene>GFP__gate_border_values.txt are drawn as dashed
    lines on the GFP marginal. Returns the seaborn JointGrid.
    '''
    redgate = {'A' : {'amac':(3.18, 3.35), 'cgre':(3.14, 3.32), 'pplu':(3.18, 3.44)},
              'B' : {'amac':(3.4, 3.6), 'cgre':(3.29, 3.57), 'pplu':(3.28, 3.53)}}

    library_name = 'ppluGFP2' if gene=='pplu' else gene+'GFP'
    border_table = pd.read_csv(os.path.join(data_folder, 'cell_sorting',
                                            library_name+'__gate_border_values.txt'),
                               sep='\t', index_col=0)
    # Rows 0-6 are used for sorter A and rows 8-14 for sorter B.
    gate_borders = {'A': np.log10(np.array(border_table.head(7)['upper_limit'])),
                  'B': np.log10(np.array(border_table.iloc[8:15]['upper_limit']))}

    library = load_fcs_data(determine_filename(gene,machine,''))
    library = library.gate(is_bacteria_gate).data.copy().replace(0,np.nan).sample(
        10000, random_state=random_state)
    library[''] = 'Whole library'
    ctrl = load_fcs_data(determine_filename(gene,machine,'ctrl'))
    ctrl = ctrl.gate(is_bacteria_gate).data.copy().replace(0,np.nan).sample(
        1500, random_state=random_state)
    ctrl[''] = 'Non-fluorescent controls'
    data=pd.concat([library,ctrl])

    red_low, red_high = redgate[machine][gene]
    data['sorted'] = (data['mKate2'] > red_low) & (data['mKate2'] < red_high)
    sorted_cells = data[data['sorted']]

    grid = sns.jointplot(data = data, x='GFP', y='mKate2', hue='', 
                         alpha=0.4, s=5, edgecolor=None, palette=['darkgrey', 'k',], )
    grid.set_axis_labels('GFP fluorescence (log)','mKate2 fluorescence (log)', fontsize=10+fontscale)
    grid.ax_joint.legend(frameon=False)

    plt.sca(grid.ax_joint)
    plt.scatter(sorted_cells['GFP'], sorted_cells['mKate2'], color='crimson', s=4, 
                label='Sorted cells', alpha=0.4)
    legendary(['lightgrey', 'grey', 'crimson'], 
              ['Whole library', 'Non-fluorescent controls', 'Sorted cells'], fontsize=10+fontscale,
             title='%sGFP library: sorter %s' % (gene,machine), title_fontsize=12+fontscale, loc='upper left')
    plt.ylim(0,5.5)
    plt.xlim(0,5.5)
    plt.yticks(fontsize=10+fontscale)
    plt.xticks(fontsize=10+fontscale)

    plt.sca(grid.ax_marg_x)
    sns.kdeplot(x=sorted_cells['GFP'], color='crimson')
    for border in gate_borders[machine]:
        plt.axvline(border, color='crimson', linestyle='--', linewidth=1)

    # Save BEFORE returning; this branch used to sit after the return statement and
    # could never run.
    # NOTE(review): `savefig` is not defined in this part of the notebook; presumably
    # it is provided by the %run general-functions notebook - confirm.
    if savepath:
        savefig(savepath, dpi=300)
    return grid

# Extant vs non-extant mutations

### Scatterplot: extant states becoming deleterious by background

In [16]:
def compare_against_extant(gene, dataset, nok=-0.2, overlap_n=10, **kwargs):
    '''Scatter, per extant sequence, the share of its states that are deleterious in `gene`.

    For every name in the global `wts` other than 'RrGFP' and the gene itself, the
    states of that sequence which differ from the gene's own WT states and were
    measured in this gene's single-mutant effects are collected. If more than
    `overlap_n` such states exist, one point is drawn: x = 100 - pairwise[{extant, gene}],
    y = fraction of those states whose measured effect is below `nok`.

    `dataset` is accepted but not used. **kwargs go to plt.scatter.
    Reads the globals `siffects_<gene>_log`, `namekey`, `genekey`, `pairwise`, `colors`,
    `pseudopos_to_nativepos`, `nativepos_to_quasipos` and the helper `get_wt_states`.
    '''
    gene_name = namekey[gene]
    gene_wt = get_wt_states(gene_name)
    # Mapping of single mutation (wt state + pseudo position + new state) -> effect.
    single_effects = eval('siffects_'+gene+'_log')

    # All measured states: pseudo -> native positions, add the WT state of every
    # measured position, then express as quasi position + state.
    native_all = {m[0] + str(pseudopos_to_nativepos[int(m[1:-1])][genekey[gene]]) + m[-1]
                  for m in single_effects}
    native_wt_states = {m[0:-1] + m[0] for m in native_all}
    native_all = native_all | native_wt_states
    muts_all = {str(nativepos_to_quasipos[int(m[1:-1])][gene_name]) + m[-1] for m in native_all}

    # Deleterious states (effect below `nok`), converted the same way.
    native_nok = {str(pseudopos_to_nativepos[int(m[1:-1])][genekey[gene]]) + m[-1]
                  for m in single_effects if single_effects[m] < nok}
    muts_nok = {str(nativepos_to_quasipos[int(m[:-1])][gene_name]) + m[-1] for m in native_nok}

    for extant in wts:
        if extant=='RrGFP' or extant==gene_name:
            continue

        novel_states = {state for state in get_wt_states(extant) if state not in gene_wt}
        overlap = {state for state in novel_states if state in muts_all}
        if len(overlap) <= overlap_n:
            continue

        nok_in_gene = {state for state in overlap if state in muts_nok}
        aa_id = 100 - pairwise[frozenset([extant, gene_name])]
        plt.scatter(aa_id, len(nok_in_gene) / len(overlap), color=colors[gene],**kwargs)

### Violins of extant vs non-extant mutation effects

In [18]:
def plot_extant_vs_nonextant_muts(gene, dataset=data_aa, y_axis='brightness', palette='mako', n=5):
    '''Half-violins per mutation count (1..n-1): genotypes with no extant mutations vs all-extant ones.

    gene    : short gene name; rows with `gene == gene + 'GFP'` are used.
    dataset : genotype table with `n_mut` and `n_mut_extant` columns (defaults to the
              global `data_aa` at definition time). It is not modified.
    y_axis  : column to plot.
    palette : accepted for interface compatibility; not used.
    n       : violins are drawn for mutation counts 1..n-1.

    Left halves are genotypes where none of the mutations is extant, right halves
    (dashed outline) are genotypes where all of them are.
    '''
    # Work on a copy so that adding the helper column does not write into a slice
    # of the caller's table.
    df = dataset[dataset.gene==gene+'GFP'].copy()
    df['f_extant'] = np.select([df['n_mut']==df['n_mut_extant'], df['n_mut_extant']==0],
                               ['all', 'none'], default='some')

    plot_half_violin([df[(df['n_mut']==i) & (df['f_extant']=='none')][y_axis] for i in range(1,n)], 
                     side='left', 
                     color=[colors[gene],colors[gene]], alpha=1, widths=0.8, chonkylines=True)

    plot_half_violin([df[(df['n_mut']==i) & (df['f_extant']=='all')][y_axis] for i in range(1,n)], 
                     side='right', alpha=1, linestyle='--', linewidth=1,
                     color=['w',colors[gene]], widths=0.9, chonkylines=True)

    label_plot_axis(x = 'Number of mutations', t= gene+'GFP')
    # One label per tick: n-1 ticks labelled 1..n-1 (there used to be n labels).
    plt.xticks(range(n-1), range(1,n))

# Mutation effects by position

### Heatmap of median effects by position

In [20]:
def plot_positional_effects_heatmap(df, func=np.nanmedian, cm='Blues_r'):
    '''Heatmap of per-position single-mutation effects for the four genes, one row per gene.

    df   : table passed to get_effects_by_position for every gene.
    func : aggregation passed to get_effects_by_position.
    cm   : colormap of the heatmap.

    Rows are ordered avGFP, amacGFP, cgreGFP, ppluGFP2. Secondary structure is drawn
    along each row, positions in the global `buried_pos` are marked at the top, and
    columns 67-69 are hatched. Returns (array that was plotted, heatmap axes).
    '''
    font = 'Arial'
    row_order = ['av', 'amac', 'cgre', 'pplu']

    # The helper is called in the order amac, cgre, pplu, av.
    effects = {gene: get_effects_by_position(gene, df, positions='pseudo', func=func)
               for gene in ['amac', 'cgre', 'pplu', 'av']}
    to_plot = np.array([effects[gene] for gene in row_order])

    hm = sns.heatmap(to_plot, 
               cmap=cm, yticklabels=['avGFP','amacGFP', 'cgreGFP', 'ppluGFP2', ], xticklabels=10,
               cbar_kws={'label': 'Median mutation effect', 'pad':0.01, 'fraction':0.05})

    # Secondary structure through the middle of each heatmap row.
    for row, gene in enumerate(row_order):
        ss = pseudify(eval(gene+'_ss_pymol'), gene)
        plot_secondary_structure(ss, y=row+0.5, hel=0.05, arrow_width=0.1, linewidth=1, c='k')
    plt.scatter([x+0.5 for x in buried_pos], [0 for _ in buried_pos], color='k', s=50, marker=2)

    # Transparent hatched overlay: only the three cells per row holding 10 survive
    # the masked_less(..., 5) mask, everything else is NaN.
    hatch_cells = np.array([[np.nan]*67+[10,10,10]+[np.nan]*178 for _ in range(4)])
    plt.pcolor(np.arange(len(effects['amac'])+1), np.arange(len(to_plot)+1), 
               np.ma.masked_less(hatch_cells, 5), 
               hatch='///', alpha=0)

    label_plot_axis(x='Aligned amino acid position', t='Median effects of single mutations, by position')
    plt.xticks(fontname=font)
    plt.yticks(rotation=0, fontname=font)
    plt.ylim(len(to_plot)+0.1, -0.1)
    return to_plot, hm

# Epistasis overview

### Scatterplot of double mutant epistasis

In [22]:
def doublemut_scatterplot(gene, cm=matplotlib.cm.Spectral):
    '''Scatter the effects of the two mutations of each double mutant, coloured by epistasis.

    Uses the rows of the global `doublemuts` with `gene == gene + 'GFP'`. Points are
    drawn in order of increasing |epistasis| so the strongest cases end up on top;
    the colour scale is fixed to [-2, 2].

    The module-level plt.tight_layout() that used to follow this definition was
    removed: it ran when the cell was executed (not when the function was called)
    and only created an empty figure.
    '''
    df = doublemuts[doublemuts['gene']==gene+'GFP'].copy()
    df['abs_epistasis'] = df['epistasis'].abs()
    df = df.sort_values('abs_epistasis')
    plt.scatter(df['mut1_effect'], df['mut2_effect'], alpha=0.8,
               c=df['epistasis'], cmap=cm, s=5,vmin=-2, vmax=2)
    label_plot_axis(x='Effect of first mutation', y='Effect of second mutation', t=gene+'GFP')

<Figure size 432x288 with 0 Axes>

### Barplot: fraction of epistatic N-mutant genotypes

In [None]:
def barplot_epistasis(gene, e):
    '''Stacked bars of the fraction of genotypes with epistasis below -e and above e, per mutation count.

    Uses the rows of the global `data_aa` with `gene == gene + 'GFP'`, 2..8 mutations
    and a non-null `epistasis`. The solid part of each bar is the fraction with
    epistasis < -e, the hatched part on top the fraction with epistasis > e.
    A mutation count without any measured genotype gives NaN (no bar) rather than
    a ZeroDivisionError.
    '''
    n_mut_values = range(2,9)
    neg, pos = [], []
    for i in n_mut_values:
        measured = data_aa[(data_aa['gene']==gene+'GFP') & (data_aa['n_mut']==i)
                           & (data_aa['epistasis'].notnull())]
        total = len(measured)
        neg.append(len(measured[measured['epistasis'] < -e]) / total if total else np.nan)
        pos.append(len(measured[measured['epistasis'] > e]) / total if total else np.nan)

    positions = range(len(neg))
    plt.bar(positions, neg, color=colors[gene])
    plt.bar(positions, pos, bottom=neg, color=colors[gene], alpha=1, hatch='///', 
            edgecolor='w', linewidth=0)
    # One label per bar (2..8); there used to be nine labels for seven ticks.
    plt.xticks(positions, n_mut_values)
    plt.ylim(0,0.25)
    label_plot_axis(t=gene+'GFP', x='Number of mutations', y='Fraction of genotypes\nwith epistasis over %s' % e)

# Physical distance between epistatic pairs

### Pairwise distances map

In [23]:
def distances_heatmap(gene, epistasis=0.3,c ='mako'):
    '''Heatmap of residue-residue distances with strongly epistatic site pairs overlaid.

    gene      : short gene name ('amac', 'av', 'cgre' or 'pplu').
    epistasis : site pairs whose largest absolute epistasis value is below this
                threshold are hidden in the overlay; it also sets the upper colour
                limit of the overlay. (This parameter used to be ignored in favour
                of a hard-coded 0.3, which is still the default.)
    c         : colormap of the distance heatmap.

    Distances are read from <structure_folder>/residue_distance_matrices; epistasis
    values come from the global `epistatic_pairs_pos[gene]`, keyed 'i:j'. Positions
    absent from the distance table's index are dropped from the overlay.
    '''
    length = len(eval(gene+'_wt'))  # length of the global <gene>_wt sequence
    name = {'amac':'amacGFP', 'av':'avGFP', 'cgre':'cgreGFP', 'pplu':'ppluGFP2'}
    df = pd.read_csv(os.path.join(structure_folder, 'residue_distance_matrices', 
                                  name[gene]+'__minimal_distances_between_aa.csv'), index_col=0)
    positional_max_e = np.zeros([length, length])
    sites_to_ignore = [x for x in range(length) if x not in df.index]

    gene_pairs = epistatic_pairs_pos[gene]
    for i in range(length):
        for j in range(length):
            pair_key = str(i)+':'+str(j)
            if pair_key in gene_pairs:
                positional_max_e[i,j] = max(map(abs, gene_pairs[pair_key]))

    positional_max_e = np.delete(positional_max_e, sites_to_ignore, axis=1)
    positional_max_e = np.delete(positional_max_e, sites_to_ignore, axis=0)

    d = np.array(df)
    sns.heatmap(data=d, square=True, cmap=c, mask=np.tril(d), cbar=False,
                cbar_kws={'label': 'Minimum distance between residues (A)', 'shrink':1, 'pad':0.15})
    ax = sns.heatmap(data=positional_max_e, mask=positional_max_e<epistasis,square=True, alpha=1, 
                cmap='hot', cbar=False, vmin=0, vmax=epistasis)
    plt.ylim(0,len(d))
    plt.xlim(0,len(d))
    plt.xticks(range(0, len(d), 20), df.index[0::20], fontsize=8)
    plt.yticks(range(0, len(d), 20), df.index[0::20], rotation=0, fontsize=8)
    ax.yaxis.set_label_position('right')
    ax.yaxis.tick_right()

    label_plot_axis(x='Amino acid position',y='Amino acid position', fontsize_x=10, fontsize_y=10)

In [24]:
def setup_inset(ax, c):
    '''Style an inset violin axes: hide three spines, centre the left one, add labels.

    ax : the inset axes (the tick, label and limit calls go through pyplot, so it is
         expected to be the current axes as well).
    c  : callable colormap; c(175) and c(75) colour the two legend-like text lines.
    '''
    for side in ['right', 'top', 'bottom']:
        ax.spines[side].set_visible(False)
    ax.spines['left'].set_position('center')

    plt.xticks(fontsize=12)
    label_plot_axis(y='Distance between residues (A)', fontsize_y=8)

    legend_lines = [(50, '|Epistasis| < 0.3', c(175)),
                    (45, '|Epistasis| > 0.3', c(75))]
    for y, text, text_color in legend_lines:
        ax.text(0.1, y, text, color=text_color)
    plt.ylim(0,50)

In [27]:
def distances_violins(df, epistasis, palette='mako'):
    '''Split violins of residue distance per gene: pairs below vs above the epistasis threshold.

    df        : table with `gene`, `distance` and `epistasis` columns. A boolean
                column `e_above_threshold` (|epistasis| > epistasis) is added to it
                in place.
    epistasis : threshold on |epistasis|.
    palette   : seaborn palette name the two violin colours are taken from.

    Rows of gene 'amacV14LGFP' and rows whose |epistasis| is not greater than 0 are
    left out of the plot.
    '''
    c = ListedColormap(sns.color_palette(palette, 256))
    low_color, high_color = c(75), c(175)

    df['e_above_threshold'] = df['epistasis'].abs() > epistasis

    shown = df[(df.gene!='amacV14LGFP') & (df['epistasis'].abs() > 0)]
    sns.violinplot(data=shown, 
                   x='gene', y='distance', hue='e_above_threshold', split=True, linewidth=0,
                  palette=[low_color, high_color], saturation=100, cut=0)
    legendary([low_color, high_color], ['No epistasis', 'Epistasis > %s' % epistasis], ncol=2, 
              loc='upper center')
    label_plot_axis(y='Distance (A) between\namino acid pairs', )

In [26]:
def plot_all_distance_maps(df, e = 0.3, c = ListedColormap(sns.color_palette('mako', 256))):
    '''Draw a 2x2 grid of distance heatmaps, each with an inset split-violin plot.

    df : table with 'gene', 'distance' and 'epistasis' columns; a boolean
         'e_above_threshold' column is written onto it in place.
    e  : threshold compared against |epistasis|.
    c  : callable colormap shared by the heatmaps and the insets.
    '''
    df['e_above_threshold'] = df['epistasis'].abs() > e
    fig = plt.figure(figsize=[10,10], dpi=200)

    # (gene key, inset left, inset bottom) in subplot order 1..4
    panels = [('amac', 0.02, 0.75),
              ('cgre', 0.52, 0.75),
              ('pplu', 0.02, 0.25),
              ('av',   0.52, 0.25)]
    inset_width, inset_height = 0.14, 0.18

    for panel_no, (gene, left, bottom) in enumerate(panels, start=1):
        plt.subplot(2, 2, panel_no)
        distances_heatmap(gene, c=c)

        # The new inset axes becomes the current axes, so the violin is drawn into it.
        inset_ax = fig.add_axes([left, bottom, inset_width, inset_height])
        violin_ax = sns.violinplot(data=df[(df.gene == gene + 'GFP')], x='gene', y='distance',
                                   linewidth=0, cut=0, hue='e_above_threshold', split=True,
                                   palette=[c(175), c(75)], )
        violin_ax.get_legend().remove()
        setup_inset(inset_ax, c)

    plt.tight_layout(pad=2)

# Urea sensitivity

### Urea denaturation spectra

In [29]:
def import_plate_reader_data():
    '''Read the urea fluorescence and absorbance spectra and add display columns.

    Each table gets three extra columns: '\\nTreatment' ('9M urea' or 'control'),
    'hours' (scan index 't' converted with the per-scan duration) and 'Gene'
    (display name of the gene, 'blank' for anything unrecognised).

    Returns (spectra_abs, spectra_fluo).
    '''
    display_names = {'amac': 'amacGFP', 'cgre': 'cgreGFP', 'pplu': 'ppluGFP',
                     'av': 'avGFP', 'amacV14L': 'amacGFP_V14L'}

    def load(filename, minutes_per_scan):
        # Read one spectra table and derive its display columns.
        table = pd.read_csv(os.path.join(structure_folder, 'stability_measurements', filename),
                            index_col=0)
        table['\nTreatment'] = table['treatment'].apply(
            lambda kind: '9M urea' if kind=='urea' else 'control')
        table['hours'] = table['t'] * minutes_per_scan / 60
        table['Gene'] = table['gene'].apply(lambda key: display_names.get(key, 'blank'))
        return table

    spectra_fluo = load('urea__fluorescence_spectra.csv', 25)  # each fluorescence spectrum takes ~25min
    spectra_abs = load('urea__absorbance_spectra.csv', 40)     # each absorbance spectrum takes ~40min
    return spectra_abs, spectra_fluo

In [30]:
# Module-level spectra tables; plot_spectra and plot_urea_fluo_loss read these globals directly.
spectra_abs, spectra_fluo = import_plate_reader_data()

In [31]:
def plot_spectra(gene, treatment, normed=True):
    '''Overlay the absorbance and fluorescence spectra of one gene under one treatment.

    gene      : value matched against the 'gene' column of both spectra tables.
    treatment : value matched against the 'treatment' column.
    normed    : truthy -> plot 'signal_normalized', falsy -> plot 'signal'.

    Only rows with hours <= 57 and Wavelength <= 650 are drawn; one line per scan 't'.
    '''
    # Decide by truthiness so that values other than True/False no longer leave y unbound.
    y = 'signal_normalized' if normed else 'signal'

    # Absorbance is drawn with 'bone_r', fluorescence with 'mako_r'.
    for spectra, palette in zip([spectra_abs, spectra_fluo], ['bone_r', 'mako_r']):
        spectra = spectra.reset_index()
        shown = spectra[(spectra['treatment']==treatment) & (spectra['gene']==gene)
                        & (spectra['hours']<=57) & (spectra['Wavelength']<=650)]
        sns.lineplot(data=shown, x='Wavelength', y=y, hue='t', ci=None,
                     palette=palette, legend=False)
    plt.ylim(-0.1,1.1)
    label_plot_axis()

### Urea denaturation timelapse with curve

In [32]:
from scipy.optimize import curve_fit

def plot_urea_fluo_loss():
    '''Plot peak fluorescence over time in urea for each gene, with a fitted decay curve.

    Reads the module-level `spectra_fluo` table and writes a 'time_h' column onto it.
    For every gene the peak rows of the urea treatment are drawn as points plus a
    seaborn line; all genes except 'cgre' also get a black least-squares fit.
    '''
    # Decay models. curve_fit infers the number of free parameters from each
    # signature, so the argument lists must stay explicit.
    def single_exponential(x, a1, k1):
        return a1 * np.exp(-k1 * x)

    def exponential_with_offset(x, a0, a1, k1):
        return a0 + a1 * np.exp(-k1 * x)

    def double_exponential(x, a0, a1, k1, a2, k2):
        return a0 + a1 * np.exp(-k1 * x) + a2 * np.exp(-k2 * x)

    # 'cgre' has no entry: it is drawn but not fitted, as before. The former
    # exponential+Gaussian branch for 'cgre' could never be reached and is removed.
    decay_models = {'av': single_exponential,
                    'pplu': exponential_with_offset,
                    'amac': double_exponential,
                    'amacV14L': double_exponential}

    spectra_fluo['time_h'] = spectra_fluo['t'].apply(lambda x : x*(25/60)) # timepoints every 25min

    for gene in ['pplu','av','amacV14L','amac','cgre',]:
        # Build the selection once; it feeds the scatter, the line and the fit.
        peak_rows = spectra_fluo[(spectra_fluo.gene==gene) & (spectra_fluo.treatment=='urea') &
                                 (spectra_fluo.peak==True)]
        xdata = peak_rows['time_h']
        ydata = peak_rows['signal_normalized']

        plt.scatter(xdata, ydata, color=colors[gene], label=gene, s=1, alpha=0.5)
        ax = sns.lineplot(data=peak_rows, x='time_h', y='signal_normalized',
                          ci=None, color=colors[gene])
        # The legend may be absent; only remove it when seaborn created one.
        legend = ax.get_legend()
        if legend is not None:
            legend.remove()

        model = decay_models.get(gene)
        if model is None:
            continue

        popt, pcov = curve_fit(model, xdata, ydata, maxfev=100000)
        x = np.linspace(0, 57, 1000)
        plt.plot(x, model(x, *popt), color='k', linewidth =1 )

### Post-urea refolding curves

In [None]:
def plot_refolding(time='short'):
    '''Plot fluorescence recovery after urea treatment.

    time : 'short' -> per-gene scatter of the 20 min dataset with a fitted
                      offset + three-exponential curve;
           'long'  -> seaborn line plot (sd band) of the overnight dataset.
    Any other value fails the assertion.
    '''
    assert time in ['short','long']
    stability_dir = os.path.join(structure_folder, 'stability_measurements')

    if time == 'long':
        rf = pd.read_csv(os.path.join(stability_dir, 'refolding_post_urea__overnight.csv'),
                         index_col=0)
        line_colors = [colors[gene] for gene in ['cgre','av','amacV14L','amac','pplu']]
        ax = sns.lineplot(data = rf, x = 'time', y = 'fluorescence_normed', hue='gene', ci = 'sd',
                          palette = line_colors)
        label_plot_axis(y = 'Fluorescence', x = 'Time (hours)')
        ax.get_legend().remove()
        # presumably 'time' is in seconds: ticks every 3600 are relabelled 0..14
        plt.xticks(range(0,51000,3600), range(0,15))
        return

    def triple_exponential(x, y0, a, b, c, d, e, f):
        # Constant offset plus three exponential terms (7 free parameters).
        return y0 + a * np.exp(-b*x) + c * np.exp(-d * x) + e * np.exp(-f * x)

    rf = pd.read_csv(os.path.join(stability_dir, 'refolding_post_urea__20min.csv'), index_col=0)
    fit_grid = np.linspace(0, 1250, 1250)
    for gene in ['amac','amacV14L','cgre','pplu','av']:
        gene_rows = rf[rf.gene==gene]
        xdata = gene_rows['time']
        ydata = gene_rows['fluorescence_normed']

        plt.xticks(range(0, 1201, 120), range(0,21,2))
        label_plot_axis(y = 'Fluorescence', x = 'Time (min)')
        plt.scatter(xdata, ydata, color=colors[gene], label=gene, s=1, alpha=0.5)

        popt, pcov = curve_fit(triple_exponential, xdata, ydata, maxfev=100000)
        plt.plot(fit_grid, triple_exponential(fit_grid, *popt), color='k', linewidth =1 )

# amacGFP vs amacGFP:V12L

### Amac vs V14L: mut effects, medians, epistasis

In [33]:
def scatter_amac_v14l_mut_effects():
    '''Density scatter of single-mutation effects: amacV14L column (x) against amac column (y).

    Reads the module-level `df_effects_singles` table.
    '''
    effect_in_v14l = np.array(df_effects_singles['effect_in_amacV14L'])
    effect_in_amac = np.array(df_effects_singles['effect_in_amac'])

    density_plot(effect_in_v14l, effect_in_amac, s=10, cmap='bone')
    plt.yticks([-1,-0.5,0])
    label_plot_axis(x='Mutation effect in amacGFP:V12L', y='Mutation effect in amacGFP')

In [34]:
def amac_v14l_nmut_medians(n = 8):
    '''Plot median (solid) and mean (dashed) brightness against number of mutations.

    n : mutation counts 0 .. n-1 are plotted.

    Two series are drawn from `data_aa`: rows selected by `amacV14V_mask`
    (labelled 'amacGFP') and rows selected by `amacV14L_mask` ('amacGFP:V12L').
    '''
    mutation_counts = range(n)
    # (row mask, line colour, legend label)
    backgrounds = [(amacV14V_mask, cm(75), 'amacGFP'),
                   (amacV14L_mask, cm(0), 'amacGFP:V12L')]

    def brightness_by_nmut(mask, statistic):
        # statistic is the name of a pandas Series reducer: 'median' or 'mean'
        return [getattr(data_aa[mask & (data_aa.n_mut==i)]['brightness'], statistic)()
                for i in mutation_counts]

    # Medians first (they carry the legend labels), then the unlabelled means.
    for mask, color, label in backgrounds:
        plt.plot(mutation_counts, brightness_by_nmut(mask, 'median'),
                 color=color, linewidth=3, label=label)
    for mask, color, _ in backgrounds:
        plt.plot(mutation_counts, brightness_by_nmut(mask, 'mean'),
                 color=color, linewidth=3, linestyle='--')

    label_plot_axis(y='Fluorescence (log)',x='Number of mutations')
    plt.legend(frameon=False, fontsize=12)
    plt.xticks(range(n))

In [36]:
def amac_v14l_obs_exp_bright():
    '''Plot get_ratio_predicted2observed_fluo for 1-7 mutations in both amac backgrounds.

    The same cutoff, gate_borders_log['amac'][1] - ref_wt_log['amacGFP'], is passed
    as `threshold` for both series.
    '''
    cutoff = gate_borders_log['amac'][1] - ref_wt_log['amacGFP']

    # (row mask into data_aa, line colour, legend label)
    series = [(amac_mask, cm(75), 'amacGFP'),
              (amacV14L_mask, cm(0), 'amacGFP:V12L')]
    for mask, color, label in series:
        ratios = [get_ratio_predicted2observed_fluo(data_aa[mask], n_mut=i, threshold=cutoff)
                  for i in range(1,8)]
        # Plotted against the implicit index 0..6; the ticks below relabel it 1..7.
        plt.plot(ratios, color=color, label=label, linewidth=3)

    plt.legend(frameon=False, fontsize=12)
    label_plot_axis(x='Number of mutations',
                   y='Functional genotypes\nexplained without epistasis')
    plt.ylim(0.,1.05)
    plt.xticks(range(7),range(1,8))

### Amac vs V14L: examples of differences in mutation effect

In [None]:
def get_mutposeffect(data, y_column, x_val, x_column='position', s_column='mutation'):
    '''Return (values of y_column, values of s_column) for rows where x_column == x_val.

    Both results are plain lists in the row order of `data`.
    '''
    matching = data[x_column] == x_val
    y_values = list(data.loc[matching, y_column])
    s_values = list(data.loc[matching, s_column])
    return y_values, s_values

def set_x_coordinates(list_of_yvals, list_of_labels, y_distance_threshold, center_x, x_distance):
    '''Tries to find non-overlapping x coordinates for text-swarmplot

    list_of_yvals, list_of_labels : parallel sequences (y position, text label).
    y_distance_threshold          : points closer than this to the current reference
                                    point are pushed sideways.
    center_x                      : x coordinate of points that are not pushed.
    x_distance                    : sideways step.

    Returns (xvals, yvals, labels), all sorted by ascending y. Entries whose y
    prints as 'nan' are dropped.

    Each label stays paired with its own y value, so several mutations with an
    identical y keep their individual labels (a y -> label dictionary would
    collapse them onto one label).
    '''
    # Sort (y, label) pairs together; the sort is stable, so ties keep input order.
    pairs = sorted(((y, label) for y, label in zip(list_of_yvals, list_of_labels)
                    if str(y) != 'nan'),
                   key=lambda pair: pair[0])
    list_of_yvals = [y for y, _ in pairs]
    sorted_labels = [label for _, label in pairs]
    list_of_xvals = [center_x] * len(list_of_yvals)

    ref_y = 0                      # index of the current reference point
    multiplier = 1                 # +1 -> push right, -1 -> push left
    current_x_spacer = x_distance  # current sideways offset magnitude

    for i in range(1, len(list_of_yvals)):
        if abs(list_of_yvals[i] - list_of_yvals[ref_y]) < y_distance_threshold:
            # Too close to the reference: offset sideways, alternating sides.
            list_of_xvals[i] += multiplier*current_x_spacer
            multiplier *= (-1)
            if multiplier == 1:
                # Both sides used at this distance: move one step further out.
                current_x_spacer += x_distance
        else:
            # Far enough: this point becomes the new reference. The side
            # alternation state is not reset here, matching the previous layout.
            ref_y = i
            current_x_spacer = x_distance

    return list_of_xvals, list_of_yvals, sorted_labels

In [37]:
def amac_v14l_mutposeffect(positions = [1, 69, 73, 77, 81, 83, 87, 88, 153, 160, 184, 195, 198]):
    '''Text swarm of (effect in amacV14L - effect in amac) for the mutations at chosen positions.

    positions : native amac positions, one column per position. Labels of the
                first column are coloured cm(75), all others black.

    Writes 'native_position_amac' and 'effect_amacV14L_minus_amac' columns onto
    the module-level `df_effects_singles` table.
    '''
    n_positions = len(positions)

    # Grey band for differences within +/-0.07. Its right edge follows the number
    # of columns instead of a fixed 17, so it always spans the x-range set below.
    plt.fill_between([-1, n_positions], [-0.07,-0.07], [0.07,0.07], color='lightgrey')

    df_effects_singles['native_position_amac'] = df_effects_singles['position'].apply(
        lambda x: pseudopos_to_nativepos[x][1])
    df_effects_singles['effect_amacV14L_minus_amac'] = (df_effects_singles['effect_in_amacV14L']
                                                        - df_effects_singles['effect_in_amac'])

    for column, position in enumerate(positions):
        yvals, labels = get_mutposeffect(df_effects_singles, y_column='effect_amacV14L_minus_amac',
                                         x_column='native_position_amac', x_val=position)
        # Spread labels sideways around the column centre where they would overlap.
        xvals, yvals, labels = set_x_coordinates(yvals, labels, y_distance_threshold=0.15,
                                                 center_x=column, x_distance=0.15)
        text_color = cm(75) if column == 0 else 'k'
        for x, y, label in zip(xvals, yvals, labels):
            plt.text(x = x, y = y, s=label, size='medium', color=text_color,
                     verticalalignment='center', horizontalalignment='center', weight='semibold')

    plt.xlim(-1, n_positions)
    plt.ylim(-1,1)
    plt.xticks(range(n_positions), [str(pos)+'\n'+amac_wt[pos] for pos in positions], fontsize=12)
    label_plot_axis(x='Position and wildtype state in sequence', y='Difference in effect')

# Synonymous mutations

In [38]:
# Barcode-level table read from the wild-type / synonymous-barcode CSV under data_folder.
bcs = pd.read_csv(os.path.join(data_folder, 'final_datasets',
                              'avGFP_amacGFP_cgreGFP_ppluGFP2__wt_and_synonymous_barcodes.csv'))

In [39]:
def plot_wts_and_synmuts(gene, bc_data, xticks=[3,3.5,4,4.5], yticks=[100,500,1000], cellcount=50):
    '''Histogram of log brightness for nucleotide-WT barcodes versus synonymous mutants.

    gene      : gene name matched against bc_data.gene (also the plot title and colour key).
    bc_data   : barcode table; a null nt_genotype marks nucleotide wild type, a null
                aa_genotype_native with a non-null nt_genotype marks a synonymous mutant.
    xticks, yticks : tick positions (used as their own labels).
    cellcount : for every gene except 'avGFP', only barcodes with
                clone_cell_count above this value are kept.

    The legend title reports the two-sided Mann-Whitney U p-value rounded to 2 decimals.
    '''
    of_gene = bc_data.gene==gene
    ntwt = bc_data[(bc_data.nt_genotype.isnull()) & of_gene]
    aawt = bc_data[(bc_data.aa_genotype_native.isnull()) & (bc_data.nt_genotype.notnull()) & of_gene]

    if gene != 'avGFP':
        ntwt, aawt = [subset[subset.clone_cell_count > cellcount] for subset in (ntwt, aawt)]

    # Bin count of the synonymous-mutant histogram, per gene.
    aabins = {'avGFP':10, 'amacGFP':10, 'cgreGFP':60, 'ppluGFP':20}

    sns.histplot(data = ntwt['log_brightness'], bins=70, element='poly',
                 color=colors[gene], fill=True, linewidth=0,
                 label = f'Nucleotide WTs \n(n = {len(ntwt)})')
    sns.histplot(data = aawt['log_brightness'], bins=aabins[gene], element='poly',
                 color='k', fill=False, linewidth=2, linestyle='--',
                 label=f'Synonymous mutants \n(n = {len(aawt)})', )

    plt.title(gene)
    plt.xlabel('Fluorescence (log)', fontsize=12)
    plt.ylabel('Number of barcodes', fontsize=12)
    plt.xticks(xticks,xticks)
    plt.yticks(yticks,yticks)

    mwu = scipy.stats.mannwhitneyu(ntwt['log_brightness'], aawt['log_brightness'],
                                   alternative='two-sided')
    plt.legend(loc='upper left', frameon=False, title=f'MWU, p = {round(mwu[1],2)}')

# Protein thermostability

### Differential scanning fluorimetry

In [None]:
# Differential scanning fluorimetry table; the cells below read its 'gene', 'temperature',
# 'unfolding_ratio_deriv' and 'unfolding_scatter_deriv' columns.
dsf = pd.read_csv(os.path.join(structure_folder, 'stability_measurements',
                              'dsf_data__av_amac_amacV14L_cgre_pplu.csv'))

In [None]:
# Per gene: temperature at the extremum of the 'unfolding_ratio_deriv' trace
# (maximum for 'cgre' and 'av', minimum for every other gene). plot_dsf draws
# these values as vertical marker lines.
# Note: the loop variables (gene, t, r, i) stay defined at module level afterwards.
dsf_tm_ratio = {}
for gene in set(dsf['gene']):
    t = list(dsf[(dsf.gene==gene)]['temperature'])
    r = list(dsf[(dsf.gene==gene)]['unfolding_ratio_deriv'])
    if gene == 'cgre' or gene == 'av':
        i = r.index(max(r))  # first occurrence of the largest value
        dsf_tm_ratio[gene] = t[i]
    else:
        i = r.index(min(r))  # first occurrence of the smallest value
        dsf_tm_ratio[gene] = t[i]
        
# Per gene: temperature at the maximum of the 'unfolding_scatter_deriv' trace.
dsf_tm_scatter = {}
for gene in set(dsf['gene']):
    t = list(dsf[(dsf.gene==gene)]['temperature'])
    r = list(dsf[(dsf.gene==gene)]['unfolding_scatter_deriv'])
    i = r.index(max(r))
    dsf_tm_scatter[gene] = t[i]

In [None]:
def plot_dsf(what='ratio', fontsize=12):
    '''Plot the DSF derivative traces of all genes with a dashed marker line per gene.

    what     : 'ratio' or 'scatter'; selects the 'unfolding_<what>_deriv' column and
               the matching marker dictionary (dsf_tm_ratio / dsf_tm_scatter).
    fontsize : font size of the x-axis label.
    '''
    assert what in ['ratio','scatter']
    ax = sns.lineplot(data = dsf,
             x='temperature', y=f'unfolding_{what}_deriv', hue='gene', linewidth=2,
            ci='sd', palette = [colors[gene] for gene in ['pplu', 'amac', 'amacV14L', 'cgre', 'av']])

    # Pick the marker dictionary directly instead of eval()-ing a variable name per gene.
    tm_by_gene = dsf_tm_ratio if what == 'ratio' else dsf_tm_scatter
    for gene in ['av','pplu','amac','amacV14L','cgre']:
        plt.axvline(tm_by_gene[gene], color=colors[gene], linestyle='--', linewidth=2)

    plt.xlabel('Temperature (C)', fontsize=fontsize)
    plt.xticks(range(20,111,10),range(20,111,10))
    ax.get_legend().remove()

### Differential scanning calorimetry

In [40]:
# Differential scanning calorimetry table; the cells below read its 'gene', 'replicate',
# 'temperature' and 'Cp_kJ/mol/K' columns.
dsc = pd.read_csv(os.path.join(structure_folder, 'stability_measurements',
                              'dsc_data__av_amac_amacV14L_cgre_pplu.csv'))

In [None]:
# Per gene: temperature at the maximum of 'Cp_kJ/mol/K', averaged over replicates 0 and 1.
# plot_dsc draws these values as vertical marker lines.
dsc_tm = {}
for gene in set(dsc['gene']):
    t1 = list(dsc[(dsc.gene==gene) & (dsc.replicate==0)]['temperature'])
    t2 = list(dsc[(dsc.gene==gene) & (dsc.replicate==1)]['temperature'])
    m1 =  list(dsc[(dsc.gene==gene) & (dsc.replicate==0)]['Cp_kJ/mol/K'])
    m2 = list(dsc[(dsc.gene==gene) & (dsc.replicate==1)]['Cp_kJ/mol/K'])
    if gene=='cgre132':
        # Only replicate 0 is used for this gene label.
        # NOTE(review): plot_dsc looks up 'cgre', not 'cgre132' — confirm this label occurs in the data.
        tm = t1[m1.index(max(m1))]
    else:
        tm = (t1[m1.index(max(m1))] + t2[m2.index(max(m2))]) / 2
    dsc_tm[gene] = tm

In [None]:
def plot_dsc(fontsize=12):
    '''Plot the DSC heat-capacity traces (one line per gene and replicate) with dsc_tm markers.

    fontsize : font size of the x-axis label.
    '''
    curve_colors = [colors[gene] for gene in ['pplu','amac','V14L','cgre','av']]
    ax = sns.lineplot(data = dsc, x='temperature', y='Cp_kJ/mol/K',
                      ci=None, hue='gene', style='replicate', dashes=False,
                      palette = curve_colors, linewidth=2)

    # Dashed vertical line at each gene's dsc_tm value.
    for gene in ('av','amac','V14L','cgre','pplu'):
        plt.axvline(dsc_tm[gene], color=colors[gene], linestyle='--', linewidth=2)

    plt.xlabel('Temperature (C)', fontsize=fontsize)
    ax.get_legend().remove()

### qPCR melting curves

In [41]:
# qPCR melting-curve table; rows whose gene is 'blank' are dropped right after loading.
qpcr = pd.read_csv(os.path.join(structure_folder, 'stability_measurements',
                              'qpcr_melting_curves.csv'))
qpcr = qpcr[qpcr.gene!='blank'].copy()

In [42]:
# Per gene: temperature at the (first) maximum of the 'value' column.
# plot_qpcr draws these as vertical marker lines.
# Note: the loop variables (gene, t, m, tm) stay defined at module level afterwards.
qpcr_tm = {}
for gene in set(qpcr['gene']):
    t = list(qpcr[(qpcr.gene==gene)]['temperature'])
    m = list(qpcr[(qpcr.gene==gene)]['value'])
    tm = t[m.index(max(m))]
    qpcr_tm[gene] = tm

In [43]:
def plot_qpcr(what='melt', fontsize=12):
    '''Plot the qPCR melting curves (sd band) with dashed qpcr_tm markers.

    what     : not used by the function body.
    fontsize : font size of the x-axis label.

    Marker lines are drawn for 'av', 'amac', 'amacV14L' and 'cgre' only.
    '''
    gene_palette = [colors[gene] for gene in ['amac','amacV14L','av','pplu','cgre']]
    ax = sns.lineplot(data = qpcr, x='temperature', y='value', hue='gene', ci='sd',
                      palette=gene_palette, linewidth=2)
    ax.get_legend().remove()
    plt.xlabel('Temperature (C)', fontsize=fontsize)

    for gene in ('av', 'amac', 'amacV14L', 'cgre'):
        plt.axvline(qpcr_tm[gene], color=colors[gene], linewidth=2, linestyle='--')

### Circular dichroism 

In [45]:
# Circular dichroism tables: wavelength spectra (cd_spectra) and melting curves (cd_melt).
cd_spectra = pd.read_csv(os.path.join(structure_folder, 'stability_measurements',
                              'cd_spectra_200to260.csv'))
cd_melt = pd.read_csv(os.path.join(structure_folder, 'stability_measurements',
                              'cd_melting_curves.csv'))

In [46]:
# Repeat the buffer-only trace six times so it lines up row-by-row with the whole table.
# NOTE(review): assumes the table is exactly 6 equally sized, identically ordered blocks — confirm.
cd_spectra['buffer'] = list(cd_spectra[cd_spectra.gene=='buffer']['value']) * 6
# Buffer-subtracted signal. Column-wise subtraction replaces the row-wise apply that
# indexed each row positionally (x[0]-x[1]), which is deprecated on label-indexed Series.
cd_spectra['value_normed'] = cd_spectra['value'] - cd_spectra['buffer']

In [47]:
def plot_cd_spectra(gene, fontsize=12):
    '''Plot the buffer-subtracted CD spectra of one gene, one line style per temperature.

    gene     : value matched against cd_spectra.gene (also the colour key).
    fontsize : font size of both axis labels.

    Rows with nm <= 205 at temp 98 are left out. Dotted vertical lines mark
    218 nm, plus 208 nm for 'av'.
    '''
    excluded = (cd_spectra.nm<=205) & (cd_spectra.temp==98)
    sns.lineplot(data = cd_spectra[(cd_spectra.gene==gene) & ~excluded],
                 x='nm', y='value_normed', style='temp', color=colors[gene], linewidth=2)

    # (wavelength, line colour) of the dotted markers
    markers = [(208, 'k'), (218, 'grey')] if gene=='av' else [(218, 'k')]
    for wavelength, shade in markers:
        plt.axvline(wavelength, color=shade, linestyle=':', linewidth=2)

    plt.xlabel('Wavelength (nm)', fontsize=fontsize)
    plt.ylabel('CD (mdeg)', fontsize=fontsize)
    legendary(['k','k'], ['30C', '98C'], edges=['-','--'], style='lines', title='Temperature', loc='upper right')

In [None]:
def sigmoid(x, L ,x0, k, b):
    '''Logistic curve: L / (1 + exp(-k * (x - x0))) + b.

    x  : scalar or numpy array of inputs.
    L  : height of the step, x0 : midpoint, k : steepness, b : vertical offset.
    '''
    denominator = 1 + np.exp(-k*(x-x0))
    return L / denominator + b

def cd_melt_fit(gene):
    '''Fit `sigmoid` to one gene's CD melting trace and return the fitted parameters.

    The trace at 208 nm is used for 'av' and at 218 nm for every other gene.
    The starting midpoint is 80 for 'amac' and 75 otherwise. The fitted
    L, x0, k and b are printed and returned as curve_fit's parameter array.
    '''
    nm = 208 if gene=='av' else 218
    trace = cd_melt[(cd_melt.gene==gene) & (cd_melt.nm == nm)]
    xdata, ydata = trace['temperature'], trace['value']

    midpoint_guess = 80 if gene=='amac' else 75
    start = [max(ydata), midpoint_guess, 1, min(ydata)]
    popt, pcov = curve_fit(sigmoid, xdata, ydata, start)

    print(f'{gene}: L = {popt[0]}, x0 = {popt[1]}, k = {popt[2]}, b = {popt[3]}')
    return popt

In [None]:
def plot_cd_melting(gene, fontsize=12):
    '''Scatter one gene's CD melting trace with its fitted sigmoid and midpoint marker.

    gene     : value matched against cd_melt.gene (also the colour key).
    fontsize : font size of both axis labels.

    For 'av' the 208 nm trace is fitted and the 218 nm trace is added in light grey;
    other genes use the 218 nm trace.
    '''
    def trace_at(wavelength):
        # (temperature, value) columns of this gene at one wavelength
        rows = cd_melt[(cd_melt.gene == gene) & (cd_melt.nm == wavelength)]
        return rows['temperature'], rows['value']

    if gene == 'av':
        grey_x, grey_y = trace_at(218)
        plt.scatter(grey_x, grey_y, color='lightgrey', s=3)

    xdata, ydata = trace_at(208 if gene=='av' else 218)
    plt.scatter(xdata, ydata, color=colors[gene], s=3, alpha=0.7)
    plt.xlabel('Temperature (C)', fontsize=fontsize)
    plt.ylabel('CD (mdeg)', fontsize=fontsize)

    popt = cd_melt_fit(gene)
    x = np.linspace(30, 100, 100)
    y = sigmoid(x, *popt)
    # Wide white line under a thin black one, so the fit stands out from the points.
    for line_color, line_width in (('w', 6), ('k', 2)):
        plt.plot(x, y, color=line_color, linewidth=line_width)
    # popt[1] is the fitted x0 (sigmoid midpoint).
    plt.axvline(popt[1], linewidth = 2, linestyle='--', color=colors[gene])

# SEC-MALS

In [48]:
# SEC-MALS tables: raw detector traces (sec) and per-peak molecular weights (sec_peaks).
# Note: `sec` is also bound as the default argument of plot_secmals_curves when that function is defined.
sec = pd.read_csv(os.path.join(structure_folder, 'secmals',
                              'secmals_raw.csv'))
sec_peaks = pd.read_csv(os.path.join(structure_folder, 'secmals',
                              'secmals_peak_weights.csv'))

In [49]:
# Replicate number per gene that plot_secmals_curves selects from `sec` and `sec_peaks`.
sec_bestrep = {'amac':1, 'av':2, 'pplu':1, 'V14L':2, 'cgre':2}

In [50]:
# label data
# gene -> {peak number: (x, y, percent)}. plot_secmals_curves writes the peak number at
# (x, y) on the retention-volume / molecular-weight axes and prints y as "Mw" and the
# third value as a percentage.
peak_num = {'V14L':{4:(9.53, 89193, 0.45), 3:(9.9, 76369,1.06), 2:(10.5, 55843,31.27), 1:(11.5,28909,67.22)},
           'amac': {4:(9.57,110977,0.68), 3:(9.9,87143,1.92), 2:(10.5, 56064,46.09), 1:(11.5,30114,51.31)},
           'cgre': {2:(9.5,106915,0.7), 1:(11,55082,99.3)},
           'pplu': {1:(11.5, 104228, 97.51), 2:(8.2, 1341321, 2.49)},
           'av':{1:(11.5, 30166, 99.08), 2:(10.26, 51169, 0.92)}}

In [52]:
def plot_secmals_curves(sec = sec, legend=False, cm=sns.cm.mako, gene='amac', loc=(1,1,1), mw_label=None):
    '''Plot SEC-MALS detector traces of one gene with molecular weight on a second y-axis.

    sec      : raw trace table (defaults to the module-level table bound at definition time).
    legend   : when True, add the left y-label and a three-entry legend.
    cm       : callable colormap; indices 50, 125 and 200 are used.
    gene     : key into sec_bestrep / peak_num and value matched against the 'gene' columns.
    loc      : arguments forwarded to plt.subplot.
    mw_label : True/False to show/hide the right-hand 'Molecular weight (g/mol)' label.
               None (default) keeps the legacy rule: the label is shown only when a
               module-level variable `i` exists and equals 10.
    '''
    ax1 = plt.subplot(*loc)
    df = sec[(sec.gene==gene) & (sec.replicate==sec_bestrep[gene]) & (sec.retention_volume_mL<13)]
    plt.plot(df['retention_volume_mL'], df['refractive_index_mV'], color=cm(125), label='Refractive index')
    plt.plot(df['retention_volume_mL'], df['ultraviolet_mV'], color=cm(200), label='Ultraviolet')
    plt.xlim(6,13)
    plt.title(names[gene], fontsize=16)
    plt.xlabel('Retention volume (mL)', fontsize=14)

    if legend==True:
        plt.ylabel('Refractive index (mV), Ultraviolet (mV)', fontsize=14)
        legendary([cm(50), cm(125), cm(200)], ['Molecular weight', 'Refractive index', 'Ultraviolet'],
                 loc='upper left', fontsize=12)

    # Stack one summary line per peak downwards from mid-height, 1/8 of the axis top apart.
    y = ax1.get_ylim()
    j = 0
    for peak in sorted(peak_num[gene]):
        plt.text(6.5, y[1]/2 - j,
                 f'Peak {peak} ({peak_num[gene][peak][2]}%):\nMw = {peak_num[gene][peak][1]}', color=cm(50),
                size=12)
        j += y[1]/8

    # Second y-axis (log scale) for the molecular-weight trace.
    ax2 = ax1.twinx()
    df = sec_peaks[(sec_peaks.gene==gene) & (sec_peaks.replicate==sec_bestrep[gene])]
    plt.plot(df['retention_volume_mL'], df['molecular_weight_g/mol'], color=cm(50), linewidth=2,
            label = 'Molecular weight')

    plt.yscale('log')
    plt.ylim(1, 10000000)

    if mw_label is None:
        # The function used to read a global `i` here and failed with NameError
        # when none existed. Keep that rule as a fallback, without the crash.
        mw_label = globals().get('i') == 10
    if mw_label:
        plt.ylabel('Molecular weight (g/mol)', fontsize=14)

    # Peak numbers written at their (x, y) label coordinates.
    for peak in peak_num[gene]:
        plt.text(peak_num[gene][peak][0], peak_num[gene][peak][1], peak, color=cm(50), weight='semibold',
                horizontalalignment='center', verticalalignment='bottom', fontsize=16)

# ddG predictions vs data

In [53]:
def ddg_scatter(gene):
    '''Scatter brightness (x) against ddG prediction (y) for one gene's single mutants.

    gene : short gene key; the module-level mask named '<gene>_mask' selects the rows,
           and the same key indexes `names` and `colors`.

    Chromophore mutants and genotypes whose string contains 'G' or 'P' are excluded,
    as are rows missing either value. The Spearman coefficient is written on the plot.
    '''
    # Look the mask up by name in the module namespace rather than eval()-ing a string.
    gene_mask = globals()[gene + '_mask']
    df = data_aa[singles_mask & gene_mask & ~chromomut_mask]
    df = df[~(df.aa_genotype_pseudo.str.contains('G|P'))]
    df = df[['brightness', 'ddG_prediction']].dropna()

    a = np.array(df['brightness'])
    b = np.array(df['ddG_prediction'])

    corr = scipy.stats.spearmanr(b,a,nan_policy='omit')

    label_plot_axis(x='Fluorescence (log)', y='ddG prediction', t=names[gene])
    # Slightly larger black points underneath give each coloured point an outline.
    plt.scatter(a,b, color='k', s=4,)
    plt.scatter(a,b, color=colors[gene], s=3, alpha=0.5)

    plt.text(df['brightness'].min(), 45, f'rs = {round(corr[0],2)}', fontsize=12,)
    plt.ylim(-15,50)

In [None]:
def ddg_violins(gene):
    '''Split violin of ddG predictions for bright versus dark single mutants of one gene.

    gene : short gene key; the module-level mask named '<gene>_mask' selects the rows.

    A genotype is 'bright' when its brightness exceeds ref_wt_log[gene+'GFP'] - 0.06 and
    'dark' when it is below the dark threshold (3 for 'av', otherwise
    gate_borders_log[...][0], with 'amacV14L' using the 'amac' entry). Everything in
    between is left out of the plot.
    '''
    # Look the mask up by name in the module namespace rather than eval()-ing a string.
    gene_mask = globals()[gene + '_mask']
    df = data_aa[singles_mask & gene_mask & ~chromomut_mask]
    df = df[~(df.aa_genotype_pseudo.str.contains('G|P'))]
    df = df[['brightness', 'ddG_prediction']].dropna()

    g = 'amac' if gene=='amacV14L' else gene
    threshold = 3 if gene=='av' else gate_borders_log[g][0]
    df['fitness_level'] = df['brightness'].apply(lambda x: 'bright' if x > ref_wt_log[gene+'GFP'] - 0.06
                                                else 'dark' if x < threshold else np.nan)

    df['gene'] = names[gene]
    ax = sns.violinplot(data = df[df.fitness_level.notnull()], x='gene', y='ddG_prediction', hue='fitness_level',
                  split=True, cut=0, linewidth = 2, inner=None, hue_order=['bright','dark'])

    # Recolour the two halves: white fill for bright, gene colour for dark, gene-coloured edges.
    patch_violinplot(['w', colors[gene], ], [colors[gene], colors[gene]])
    legendary([ 'w',colors[gene]], ['Bright', 'Dark'], edges=[colors[gene], colors[gene]],
             loc='upper left', title='Genotypes')
    label_plot_axis(y='ddG')
    plt.xticks([])

    # Keep only the left spine, moved to the centre of the axis.
    for side in ('top', 'right', 'bottom'):
        ax.spines[side].set_visible(False)
    ax.spines['left'].set_position('center')

# Mutagenesis nucleotide bias

In [1]:
# Amino-acid classes (one-letter codes) and the reverse lookup: residue -> list of classes.
aatype_to_aa = {'aliphatic':'AILMV', 'aromatic':'FWY', 'positive_charge':'RHK', 'negative_charge':'DE',
                   'polar_uncharged':'STNQ', 'other':'CGP'}
aa_to_aatype = {aa : [aatype] for aatype in aatype_to_aa for aa in aatype_to_aa[aatype]}
# H is the only residue assigned to two classes.
aa_to_aatype['H'] = ['aromatic','positive_charge']

In [2]:
def count_mut_by_type(nt_genotype, aa_genotype, mode):
    '''Count the mutations of one kind in a genotype.

    nt_genotype : ':'-separated nucleotide mutations, or 'wt'.
    aa_genotype : ':'-separated amino-acid mutations, or 'wt'.
    mode        : 'indel'       -> number of '.' characters in nt_genotype;
                  'syn'         -> nucleotide mutations minus amino-acid mutations;
                  'type_keep'   -> substitutions whose two residues share a class;
                  'type_change' -> substitutions whose two residues share no class.

    For the two type modes, mutations containing '*' or '.' are ignored and the
    residue classes come from the module-level `aa_to_aatype`.
    A 'wt' genotype counts as zero mutations on either level, so a fully wild-type
    genotype has 0 synonymous mutations.
    '''
    assert mode in ['indel', 'syn', 'type_change', 'type_keep']

    if mode == 'indel':
        return nt_genotype.count('.')

    if mode == 'syn':
        n_nt = 0 if nt_genotype == 'wt' else len(nt_genotype.split(':'))
        n_aa = 0 if aa_genotype == 'wt' else len(aa_genotype.split(':'))
        return n_nt - n_aa

    # 'type_keep' / 'type_change'
    if aa_genotype == 'wt':
        return 0
    substitutions = [m for m in aa_genotype.split(':') if '*' not in m and '.' not in m]
    # First and last character of a mutation are the original and the new residue.
    n_keep = sum(1 for m in substitutions
                 if set(aa_to_aatype[m[0]]) & set(aa_to_aatype[m[-1]]))
    return n_keep if mode == 'type_keep' else len(substitutions) - n_keep

In [3]:
def get_ntmut_stats(dataset, gene):
    '''Fraction of a gene's nucleotide mutations falling into each of the 12 substitution types.

    dataset : table with 'gene' and 'nt_genotype' columns (':'-separated mutations or 'wt').
    gene    : gene whose rows are used. For 'amacGFP' the mutation 'G33T' is ignored.

    Returns a dict keyed 'A>T', 'A>C', ... in a fixed order (A, T, C, G for both the
    original and the new base). Every fraction is relative to all mutations of the
    gene, so the values need not sum to 1. A gene without any mutation yields 0.0
    for every type instead of raising.
    '''
    genotypes = dataset[(dataset.gene == gene)]['nt_genotype']
    muts = [m for genotype in genotypes if genotype != 'wt' for m in genotype.split(':')]

    if gene == 'amacGFP':
        muts = [m for m in muts if m != 'G33T']

    # One pass: count (first character, last character) pairs.
    pair_counts = {}
    for m in muts:
        pair = (m[0], m[-1])
        pair_counts[pair] = pair_counts.get(pair, 0) + 1

    total = len(muts)
    stats = {}
    for wt in 'ATCG':
        for m in 'ATCG':
            if wt != m:
                stats[wt+'>'+m] = pair_counts.get((wt, m), 0) / total if total else 0.0

    return stats

In [5]:
def weight(genotype):
    '''Product of the module-level `comps` values for every mutation in a genotype.

    genotype : ':'-separated mutations; each is reduced to '<first char>><last char>'
               and looked up in `comps`.
    '''
    product = 1
    for mutation in genotype.split(':'):
        product *= comps[f'{mutation[0]}>{mutation[-1]}']
    return product

In [6]:
def classify_nonsyn_broad(dataset, gene):
    '''Fractions of a gene's amino-acid substitutions that change / keep the residue class.

    dataset : table with 'gene', 'n_indel' and 'aa_genotype_pseudo' columns.
    gene    : gene whose indel-free rows are used.

    'wt' entries and mutations containing '*' or '.' are ignored. Returns
    {'AA type change': fraction, 'No AA type change': fraction}.
    '''
    genotypes = dataset[(dataset.gene==gene) & (dataset.n_indel==0)]['aa_genotype_pseudo']
    substitutions = [m for m in ':'.join(genotypes).split(':')
                     if m != 'wt' and '*' not in m and '.' not in m]

    n_total = len(substitutions)
    # A substitution changes type when its two residues share no class in aa_to_aatype.
    n_change = sum(1 for m in substitutions
                   if len(set(aa_to_aatype[m[0]]) & set(aa_to_aatype[m[-1]])) == 0)
    n_keep = n_total - n_change

    return {'AA type change' : n_change / n_total, 'No AA type change' : n_keep / n_total}

In [7]:
def do_sims(data_nt, n_sims, sample_size=15000, random_state=None):
    '''Repeatedly draw weighted subsamples of avGFP genotypes and collect summary statistics.

    data_nt      : nucleotide-level table with 'gene', 'mut_weight', 'n_nt_mut' and
                   'log_brightness' columns plus those needed by get_ntmut_stats and
                   classify_nonsyn_broad.
    n_sims       : number of subsamples to draw.
    sample_size  : rows per subsample (drawn without replacement, weighted by 'mut_weight').
    random_state : optional seed. None (default) draws from numpy's global random state
                   as before; an integer makes the whole series of draws reproducible.

    Returns (medians, sample_stats, sample_stats_aa):
      medians         : {n_mut 1..8: [median log_brightness per simulation]}
      sample_stats    : {substitution type in `comps`: [fraction per simulation]}
      sample_stats_aa : {'AA type change' / 'No AA type change': [fraction per simulation]}
    '''
    medians = {i:[] for i in range(1,9)}
    sample_stats = {mut : [] for mut in comps}
    sample_stats_aa = {aachange :[] for aachange in ['AA type change', 'No AA type change']}

    # The avGFP subset does not change between simulations.
    avgfp = data_nt[data_nt.gene=='avGFP']
    # One generator shared by all draws, so successive simulations differ but the
    # series as a whole is repeatable for a given seed.
    rng = None if random_state is None else np.random.RandomState(random_state)

    for sim in range(n_sims):
        sample = avgfp.sample(n=sample_size, weights=avgfp['mut_weight'], random_state=rng)
        for i in range(1,9):
            medians[i].append(sample[sample.n_nt_mut==i]['log_brightness'].median())

        stats = get_ntmut_stats(sample, 'avGFP')
        for mut in sample_stats:
            sample_stats[mut].append(stats[mut])

        stats_aa = classify_nonsyn_broad(sample,'avGFP')
        for mut in sample_stats_aa:
            sample_stats_aa[mut].append(stats_aa[mut])

    return medians, sample_stats, sample_stats_aa

In [8]:
def plot_ntmut_types(data):
    '''Grouped bar chart of nucleotide substitution-type fractions per gene.

    data : nucleotide-level table; 'mut_weight' and 'n_indel', 'n_syn', 'n_type_change',
           'n_type_keep' columns are written onto it in place before subsampling.

    The four observed bar series are read from the module-level `stats` dictionary;
    the hatched fifth series is the mean (error bar: standard deviation) over
    10 subsamples of 15000 avGFP genotypes.
    '''
    # Weight 0 for wild type and anything containing '.' or '*', otherwise weight(x) squared.
    data['mut_weight'] = data['nt_genotype'].apply(lambda x: 0 if x=='wt' or '.' in x or '*' in x
                                               else weight(x)**2)
    for mo in ['indel', 'syn', 'type_change', 'type_keep']:
        data['n_'+mo] = [count_mut_by_type(nt, aa, mo)
                         for nt, aa in zip(data['nt_genotype'], data['aa_genotype_pseudo'])]

    sims_medians, sims_stats, sims_aa_stats = do_sims(data, 10, 15000)

    x = np.arange(12)
    # (bar offset, key into stats, legend label, key into colors)
    observed = [(-0.3,  'amacGFP', 'amacGFP',  'amac'),
                (-0.15, 'cgreGFP', 'cgreGFP',  'cgre'),
                (0,     'ppluGFP', 'ppluGFP2', 'pplu'),
                (0.15,  'avGFP',   'avGFP',    'av')]
    for offset, gene, label, color_key in observed:
        plt.bar(x + offset, height = stats[gene].values(), width = 0.1, label = label,
                color=colors[color_key])

    sim_means = [np.mean(sims_stats[mut]) for mut in sims_stats]
    sim_sds = [np.std(sims_stats[mut]) for mut in sims_stats]
    plt.bar(x + 0.3, height = sim_means, width = 0.1, label = 'Subsampled avGFP',
            yerr = sim_sds, color=colors['av'], hatch='///')

    plt.legend(frameon=False, loc='upper right', fontsize=13)
    plt.xticks(x, comps.keys(), fontsize=13)
    plt.ylabel('Fraction of total \nnucleotide mutations', fontsize=13)

In [None]:
def plot_avGFP_mutbias(data):
    '''Compare median avGFP brightness per mutation count with weighted subsamples.

    data : nucleotide-level table; 'mut_weight' and 'n_indel', 'n_syn', 'n_type_change',
           'n_type_keep' columns are written onto it in place before subsampling.

    The subsampled curve is the mean (error bar: standard deviation) of the medians
    over 10 subsamples of 15000 genotypes. The observed medians (crosses) are taken
    from the same `data` table that was passed in, not from a module-level table.
    '''
    # Weight 0 for wild type and anything containing '.' or '*', otherwise weight(x) squared.
    data['mut_weight'] = data['nt_genotype'].apply(lambda x: 0 if x=='wt' or '.' in x or '*' in x
                                               else weight(x)**2)
    for mo in ['indel', 'syn', 'type_change', 'type_keep']:
        data['n_'+mo] = data[['nt_genotype','aa_genotype_pseudo']].apply(lambda x:
                                                    count_mut_by_type(x.iloc[0], x.iloc[1], mo), axis=1)

    sims_medians, sims_stats, sims_aa_stats = do_sims(data, 10, 15000)
    mutation_counts = range(1,9)

    plt.axhline(3, linestyle='--', color='crimson')
    plt.errorbar(mutation_counts, [np.mean(sims_medians[n]) for n in sims_medians],
                 color=colors['av'], linewidth=4,
                 yerr=[np.std(sims_medians[n]) for n in sims_medians],
                 label='Subsampled\navGFP', zorder=0)

    avgfp = data[data.gene=='avGFP']
    observed_medians = [avgfp[avgfp.n_nt_mut==i]['log_brightness'].median() for i in mutation_counts]
    plt.plot(mutation_counts, observed_medians, marker='x', linewidth=0, label='avGFP',
             color='k', zorder=1)

    plt.ylabel('Median brightness', fontsize=13)
    plt.xlabel('Number of mutations', fontsize=13)
    plt.xticks(range(1,9),range(1,9))
    plt.legend(frameon=False, loc='lower left')

In [None]:
def plot_aatypechange(data):
    '''Bar chart: share of amino-acid substitutions that change vs keep the residue class, per gene.

    data : nucleotide-level table; 'mut_weight' and 'n_indel', 'n_syn', 'n_type_change',
           'n_type_keep' columns are written onto it in place.

    Observed percentages use indel-free rows only; the hatched bars are the mean over
    10 subsamples of 15000 avGFP genotypes.
    '''
    # Weight 0 for wild type and anything containing '.' or '*', otherwise weight(x) squared.
    data['mut_weight'] = data['nt_genotype'].apply(lambda x: 0 if x=='wt' or '.' in x or '*' in x
                                               else weight(x)**2)
    for mo in ['indel', 'syn', 'type_change', 'type_keep']:
        data['n_'+mo] = data[['nt_genotype','aa_genotype_pseudo']].apply(lambda x:
                                                    count_mut_by_type(x[0], x[1], mo), axis=1)
    sims_medians, sims_stats, sims_aa_stats = do_sims(data, 10, 15000)
    
    # NOTE(review): x is not used in the lines of this function visible here.
    x = np.arange(2)
    # xi is the horizontal offset of each gene's bars around positions 0 (change) and 1 (keep).
    for xi,gene in zip([-0.3,-0.15,0,0.15],['amacGFP', 'cgreGFP', 'ppluGFP','avGFP']):
        nchange = data[(data.gene==gene) & (data.n_indel==0)]['n_type_change'].sum()
        nkeep = data[(data.gene==gene) & (data.n_indel==0)]['n_type_keep'].sum()
        ntot = nchange + nkeep
#         print(gene, [100*nchange/ntot, 100*nkeep/ntot])
        plt.bar([0+xi, 1+xi], height = [100*nchange/ntot, 100*nkeep/ntot], color=colors[gene], width=0.12)
    # Subsampled avGFP: mean fraction per class, in the key order of sims_aa_stats.
    plt.bar([0.3, 1.3], height=[100*np.mean(sims_aa_stats[t]) for t in sims_aa_stats.keys()],
           color=colors['av'], hatch='//',width=0.12)
    plt.xticks([0,1], ['AA type\nchange', 'AA type\nmaintenance'], fontsize=13)
    plt.ylabel('Non-synonymous\nmutations (%)', fontsize=13)