In [1]:
%run lgs00_general_functions.ipynb

#### Some necessary variables

In [None]:
# For each background, map every pseudo-position key to the list of all
# epistasis values measured for double mutants at that key.
epistatic_pairs_pos = {
    background: dict(
        doublemuts[doublemuts.gene == background + 'GFP']
        .groupby('position_pseudo')
        .apply(lambda grp: list(grp['epistasis']))
    )
    for background in ['amac', 'amacV14L', 'cgre', 'pplu', 'av']
}

In [None]:
# Pairwise sequence identity between backgrounds, keyed by an unordered pair.
# Used as `100 - identity` (sequence distance) in the Figure 3 plots;
# amacV14L is treated as identical to amac.
identities = {
    frozenset(pair): pct_identity
    for pair, pct_identity in [
        (('amac', 'av'), 82), (('cgre', 'av'), 41), (('pplu', 'av'), 18),
        (('amac', 'cgre'), 43), (('amac', 'pplu'), 17), (('cgre', 'pplu'), 19),
        (('amacV14L', 'av'), 82), (('amacV14L', 'cgre'), 43), (('amacV14L', 'pplu'), 17),
        (('amac', 'amacV14L'), 100),
    ]
}

In [None]:
def make_figure_letters(figure, coordinates, **kwargs):
    '''Overlay an invisible axes spanning `figure` and draw bold panel letters on it.

    Parameters
    ----------
    figure : matplotlib.figure.Figure
        Figure to annotate; the overlay axes is added to this figure only.
    coordinates : dict
        Maps each letter (str) to an (x, y) pair in the overlay axes' data
        coordinates (0-1 for a freshly created axes).
    **kwargs
        Extra keyword arguments forwarded to `Axes.text` (e.g. fontsize).

    Returns
    -------
    The overlay axes.
    '''
    ax = figure.add_subplot(1, 1, 1)
    # Hide spines, ticks and tick labels so only the letters remain visible.
    for side in ['top', 'bottom', 'right', 'left']:
        ax.spines[side].set_color(None)
    ax.tick_params(labelcolor='w', top=False, bottom=False, left=False, right=False)

    # Draw on `ax` directly: pyplot's "current axes" need not belong to `figure`.
    for letter, xy in coordinates.items():
        ax.text(xy[0], xy[1], letter, weight='bold', **kwargs)
    return ax

## Figure 2: effects of mutations

#### Side-by-side slices of landscape peaks

In [None]:
def cheesescapes(n=9, col='brightness', cmap='BuGn', pad=0.15, center=(0,0)):
    '''Draw concentric pie rings of median `col` per number of mutations for the four backgrounds.

    Ring i (0 <= i < n) shows genotypes with exactly i mutations. It is split into
    four equal wedges (amacGFP, cgreGFP, ppluGFP2, avGFP), each coloured by the
    median of `col` in that background. Ring radius is (2 + i) / 10, so rings
    grow outward with the mutation count. The outermost ring carries the wedge
    labels and the text '<n-1> mutations'; inner rings are labelled with their
    count only.

    Parameters
    ----------
    n : number of rings (mutation counts 0 .. n-1).
    col : column of the global `data_aa_scaled` to summarise. Medians are
        multiplied by 255 and truncated to int colormap indices, so `col` is
        presumably scaled to ~[0, 1] (TODO confirm).
    cmap : matplotlib colormap name.
    pad : padding passed to `plt.colorbar`.
    center : (x, y) centre of the rings.

    Relies on the globals `data_aa_scaled`, `amac_mask`, `cgre_mask`,
    `pplu_mask` and `av_mask`.
    '''
    cm = plt.get_cmap(cmap)

    # Two-point placeholder scatter at the centre: it only provides a [0, 1]
    # mappable for the colorbar, and the pies are drawn over it.
    plt.scatter([center[0]]*2, [center[1]]*2, c=[0, 1], cmap=cmap)
    plt.colorbar(orientation='vertical', label='Median brightness', pad=pad)

    background_masks = [amac_mask, cgre_mask, pplu_mask, av_mask]
    medians = {}
    for i in range(n):
        n_mask = data_aa_scaled['n_mut'] == i
        medians[i] = [data_aa_scaled[mask & n_mask][col].median() * 255 for mask in background_masks]

    radius_offset = 2  # ring i has radius (radius_offset + i) / 10
    outermost = n - 1
    # Draw from the outside in, so smaller rings sit on top of larger ones.
    for i in reversed(range(n)):
        if i == outermost:
            labels = ['\n\n\namacGFP', 'cgreGFP', 'ppluGFP2', 'avGFP']
        else:
            labels = ['', '', '', '']
        radius = (radius_offset + i) / 10
        c = cm([int(j) for j in medians[i]])
        plt.pie([1, 1, 1, 1], radius=radius, colors=c, labels=labels,
                wedgeprops={'linewidth': 0.5, 'edgecolor': 'w'},
                startangle=-45, center=center)

        # Label colour: the opposite end of the colormap from the amacGFP wedge.
        c2 = cm(int(255 - medians[i][0]))
        # Previously hard-coded to ring 8; now follows `n` (identical for n=9).
        ring_label = '%d mutations' % i if i == outermost else i
        plt.text(radius - 0.1 + center[0], center[1], ring_label,
                 horizontalalignment='left', verticalalignment='center', color=c2, weight='semibold')

#### Observed vs. expected fluorescent genotypes

In [None]:
def get_ratio_predicted2observed_fluo(dataset, n_mut, threshold=0.4, **kwargs):
    '''Fraction of genotypes observed to be bright among those predicted to be bright.

    Considers rows of `dataset` with exactly `n_mut` mutations whose
    'expected_effect' is >= `threshold`. Returns the fraction of them whose
    'measured_effect' is also >= `threshold`. Values below 1 mean some
    genotypes predicted bright are dark, i.e. negative epistasis.
    (The "predicted dark" counterpart is not computed.)

    Parameters
    ----------
    dataset : pandas.DataFrame with columns 'n_mut', 'expected_effect', 'measured_effect'.
    n_mut : number of mutations to select.
    threshold : brightness cut-off applied to both expected and measured effects.
    **kwargs : accepted for call compatibility; ignored.

    Returns
    -------
    float in [0, 1], or NaN if no genotype with `n_mut` mutations is predicted
    bright (previously a ZeroDivisionError).
    '''
    subset = dataset[dataset['n_mut'] == n_mut]
    predicted_bright = subset['expected_effect'] >= threshold
    n_predicted = int(predicted_bright.sum())
    if n_predicted == 0:
        return float('nan')
    # NaN measurements compare False, so they count as not bright.
    observed_bright = predicted_bright & (subset['measured_effect'] >= threshold)
    return int(observed_bright.sum()) / n_predicted


def plot_predicted_bright_vs_real(dataset, threshold=False):
    '''Plot, per background, the fraction of predicted-bright genotypes observed to be bright.

    Draws one line per background (amacGFP, cgreGFP, ppluGFP2, avGFP) over
    1-7 mutations. Each point is `get_ratio_predicted2observed_fluo` computed
    on the rows of `dataset` selected by the global boolean mask `<gene>_mask`.

    Parameters
    ----------
    dataset : pandas.DataFrame accepted by `get_ratio_predicted2observed_fluo`.
    threshold : brightness cut-off forwarded to `get_ratio_predicted2observed_fluo`.
        NOTE(review): the default False compares as 0, so anything with a
        non-negative effect counts as bright; callers presumably pass a value.
    '''
    n = 8
    name = {'amac': 'amacGFP', 'av': 'avGFP', 'cgre': 'cgreGFP', 'pplu': 'ppluGFP2'}

    for gene in ['amac', 'cgre', 'pplu', 'av']:
        # Look up the module-level mask by name (same lookup `eval` did).
        gene_mask = globals()[gene + '_mask']
        ratios = [get_ratio_predicted2observed_fluo(dataset[gene_mask], n_mut=i, threshold=threshold)
                  for i in range(1, n)]
        plt.plot(ratios, color=colors[gene], label=name[gene], linewidth=3)

    # Axis decoration only needs to happen once, after all lines are drawn.
    plt.legend(loc='lower left', frameon=False)
    plt.ylim(0, 1.05)
    plt.xticks(range(n - 1), range(1, n))
    label_plot_axis(x='Number of amino acid substitutions',
                    y='Fraction of observed v.\nexpected bright genotypes',)

#### Buried vs exposed single mutants

In [None]:
def plot_buried_vs_exposed_violins(cm='mako' ):
    '''Split violins of single-mutant brightness per background: exposed vs. buried sites.

    For each background, draws a split violin of 'brightness' for non-wild-type
    single mutants in the global `data_aa_scaled`, split by
    'has_buried_mutation'. It then overlays a white left half-violin of the
    wild-type ('Synonymous' in the legend) 'scaled_brightness' values from the
    global `data_nt`.

    Parameters
    ----------
    cm : palette name; colours 200 and 100 of its 256-step version fill the two halves.
    '''
    c2 = ListedColormap(sns.color_palette(cm, 256))
    # Local list; shadows the module-level `colors` dict inside this function only.
    colors = [c2(200), c2(100)]
    
    # The data stores pplu as 'ppluGFP'; the x tick labels are rewritten to 'ppluGFP2' below.
    # NOTE(review): `col='gene'` is not a violinplot parameter (it belongs to catplot);
    # depending on the seaborn version it is ignored or forwarded — confirm.
    # Legend order below assumes hue levels are ordered False (exposed), True (buried).
    sns.violinplot(data=data_aa_scaled[~wt_mask & singles_mask], x='gene', y='brightness', 
                   saturation=1,
                   palette=colors, width=1,
                   order=['amacGFP','cgreGFP','ppluGFP','avGFP',], 
                   hue='has_buried_mutation', col='gene', split=True, linewidth=0, scale='area', cut=0)
    
    legendary(colors + ['w'], ['Exposed sites', 'Buried sites', 'Synonymous'], 
              colors + ['k'],
              ncol=3, loc='upper center')
    plt.ylim(-0.05, 1.4)
    plt.yticks([y/10 for y in range(0,13,2)], [y/10 for y in range(0,13,2)])
    
    # Wild-type distributions, one per background, in the same left-to-right order as the violins.
    plot_half_violin([data_nt[wt_mask_nt & genemask]['scaled_brightness'] for genemask in [amac_mask_nt, cgre_mask_nt,
                                                                                 pplu_mask_nt, av_mask_nt]],
                    side='left', color='w', show_medians=False, alpha=1, widths=0.3)
    plt.xticks([0,1,2,3], ['amacGFP', 'cgreGFP', 'ppluGFP2', 'avGFP'])
#     sns.violinplot(data=all_data_nt[wt_mask_nt], x='gene', y=y_axis, order=['amacGFP','ppluGFP','cgreGFP','avGFP'],
#                    linewidth=0, color='w', scale='width', width=0.3, cut=0)
    label_plot_axis(y='Fluorescence of\nsingle mutants', )


## Figure 3: Neutral-to-deleterious change vs sequence distance

#### Mutations which are neutral in one background and become deleterious in another

In [None]:
def get_fraction_changed_effect(data, gene1, gene2, neutral_threshold, bad_threshold, mode='n2b'):
    '''Fraction of shared mutations whose effect class changes (or not) between two backgrounds.

    Only mutations with an effect measured in both backgrounds are used
    (non-null 'effect_in_<gene1>' and 'effect_in_<gene2>'). A mutation is
    "neutral" in a background if its effect is > `neutral_threshold`, and
    "bad" (deleterious) if its effect is < `bad_threshold`.

    Modes:
        'n2b' -- neutral in gene1, becomes deleterious in gene2
        'n2n' -- neutral in gene1, remains neutral in gene2
        'b2n' -- deleterious in gene1, becomes neutral in gene2
        'b2b' -- deleterious in gene1, remains deleterious in gene2

    Returns
    -------
    The count for `mode` divided by the number of mutations in the gene1 class
    (neutral for 'n2*', deleterious for 'b2*'). NaN if that class is empty
    (previously a ZeroDivisionError).
    '''
    assert mode in ['n2b', 'n2n', 'b2n', 'b2b']

    # Mutations observed in both backgrounds only.
    df = data[['effect_in_' + gene1, 'effect_in_' + gene2]].dropna()
    effect1 = df['effect_in_' + gene1]
    effect2 = df['effect_in_' + gene2]

    classify = {'n': lambda effect: effect > neutral_threshold,
                'b': lambda effect: effect < bad_threshold}
    in_class1 = classify[mode[0]](effect1)
    in_class2 = classify[mode[2]](effect2)

    denominator = int(in_class1.sum())
    if denominator == 0:
        return float('nan')
    return int((in_class1 & in_class2).sum()) / denominator

In [None]:
def plot_fraction_changed_effect(data, neutral_threshold, bad_threshold, mode, style, se=75, sf=10, **kwargs):
    '''Scatter the fraction of mutations changing effect class for every ordered pair of backgrounds.

    For each ordered pair (gene1, gene2) of distinct backgrounds:
        x = 100 - identities[{gene1, gene2}]  (sequence distance)
        y = get_fraction_changed_effect(data, gene1, gene2, neutral_threshold, bad_threshold, mode)

    style:
        'simple' -- dot coloured by gene2 (size `sf`); **kwargs go to plt.scatter
        'orbits' -- dashed ring coloured by gene1 (size `se`) around a gene2-coloured dot
        'shapes' -- gene2-coloured marker whose shape encodes gene1
        anything else -- used as the marker colour; **kwargs go to plt.scatter
    '''
    from itertools import permutations

    genes = ['amacV14L', 'amac', 'cgre', 'pplu', 'av', ]
    shapes = {'amac': 'P', 'amacV14L': 'X', 'cgre': 'o', 'pplu': 's', 'av': 'd'}
    # permutations(genes, 2) yields the same order as a nested loop skipping gene1 == gene2.
    for gene1, gene2 in permutations(genes, 2):
        x = 100 - identities[frozenset([gene1, gene2])]
        y = get_fraction_changed_effect(data, gene1, gene2,
                                        neutral_threshold, bad_threshold, mode)
        if style == 'simple':
            plt.scatter(x, y, s=sf, color=colors[gene2], edgecolor='w', **kwargs)
        elif style == 'orbits':
            plt.scatter(x, y, s=se, facecolor=(0, 0, 0, 0), marker='o',
                        edgecolor=colors[gene1], linewidth=1.5, linestyle='--')
            plt.scatter(x, y, s=sf, color=colors[gene2], linewidth=0, marker='o')
        elif style == 'shapes':
            plt.scatter(x, y, s=sf, color=colors[gene2], edgecolor='w', linewidth=1, marker=shapes[gene1])
        else:
            plt.scatter(x, y, s=sf, color=style, edgecolor='w', **kwargs)

    plt.xticks([0, 18, 58, 82])
    plt.xlim(-2, 85)

#### Pairs which are epistatic in one background and remain epistatic in another background

In [None]:
def get_epistatic_pairs_overlap(gene1, gene2, e):
    '''Fraction of strongly epistatic position pairs in gene1 that are also strongly epistatic in gene2.

    Uses the global `epistatic_pairs_pos` (background -> {pair key: list of
    epistasis values}). Only pair keys present in both backgrounds are
    considered. A pair is epistatic in a background if the largest absolute
    epistasis value recorded for it there exceeds `e`.

    Returns
    -------
    |epistatic in both| / |epistatic in gene1|, or NaN if no shared pair is
    epistatic in gene1 (previously a ZeroDivisionError).
    '''
    pairs1 = epistatic_pairs_pos[gene1]
    pairs2 = epistatic_pairs_pos[gene2]

    def _strongly_epistatic(pairs, other):
        # max(|v|), not |max(v)|: the latter misses pairs whose strongest
        # epistasis is negative whenever a weaker positive value is also present.
        # This matches the max(map(abs, ...)) used by the positional heatmaps.
        return {pair for pair, values in pairs.items()
                if pair in other and max(map(abs, values)) > e}

    epistatic1 = _strongly_epistatic(pairs1, pairs2)
    if not epistatic1:
        return float('nan')
    return len(epistatic1 & _strongly_epistatic(pairs2, pairs1)) / len(epistatic1)

In [None]:
def plot_shared_epistatic_pairs(e=0.3, style='shapes', se=75, sf=10, **kwargs):
    '''Scatter the fraction of shared epistatic pairs for every ordered pair of backgrounds.

    For each ordered pair (gene1, gene2) of distinct backgrounds:
        x = 100 - identities[{gene1, gene2}]  (sequence distance)
        y = get_epistatic_pairs_overlap(gene1, gene2, e)

    style:
        'simple' -- dot coloured by gene2 (size `sf`); **kwargs go to plt.scatter
        'orbits' -- dashed ring coloured by gene1 (size `se`) around a gene2-coloured dot
        'shapes' -- gene2-coloured marker whose shape encodes gene1 (default)
        anything else -- used as the marker colour; **kwargs go to plt.scatter
    '''
    from itertools import permutations

    genes = ['amacV14L', 'amac', 'cgre', 'pplu', 'av', ]
    shapes = {'amac': 'P', 'amacV14L': 'X', 'cgre': 'o', 'pplu': 's', 'av': 'd'}
    # permutations(genes, 2) yields the same order as a nested loop skipping gene1 == gene2.
    for gene1, gene2 in permutations(genes, 2):
        x = 100 - identities[frozenset([gene1, gene2])]
        y = get_epistatic_pairs_overlap(gene1, gene2, e)
        if style == 'simple':
            plt.scatter(x, y, s=sf, color=colors[gene2], edgecolor='w', **kwargs)
        elif style == 'orbits':
            plt.scatter(x, y, s=se, facecolor=(0, 0, 0, 0), marker='o',
                        edgecolor=colors[gene1], linewidth=1.5, linestyle='--')
            plt.scatter(x, y, s=sf, color=colors[gene2], linewidth=0, marker='o')
        elif style == 'shapes':
            plt.scatter(x, y, s=sf, color=colors[gene2], edgecolor='w', linewidth=1, marker=shapes[gene1])
        else:
            plt.scatter(x, y, s=sf, color=style, edgecolor='w', **kwargs)

    plt.xticks([0, 18, 58, 82])
    plt.xlim(-2, 85)

#### Shared legend

In [None]:
def plot_acrossbg_legend(style, **kwargs):
    '''Draw the legend shared by the across-background panels.

    style='colors' : one colour patch per background, a blank spacer, then
        marker glyphs for the donor (dotted circle) and acceptor (filled dot) roles.
    style='shapes' : an 'Acceptor background' block of colour patches, then a
        'Donor background' block of marker shapes; drawn without a frame.
    Any other style draws nothing. **kwargs are forwarded to `plt.legend`.
    '''
    genes = ['amacV14L', 'amac', 'cgre', 'pplu', 'av', ]
    shapes = {'amac':'P', 'amacV14L':'X', 'cgre':'o', 'pplu':'s', 'av':'d'}
    gene_names = {'amac':'amacGFP', 'cgre':'cgreGFP', 'av':'avGFP', 'pplu':'ppluGFP2', 'amacV14L':'amacGFP:V14L'}
    display_names = [gene_names[g] for g in genes]

    if style == 'colors':
        handles = [mpatches.Patch(facecolor=colors[g]) for g in genes]
        handles.append(mpatches.Patch(facecolor='w'))
        handles.append(Line2D([0], [0], c='w', marker='$\u25CC$', markeredgecolor='k', ms=10,
                              markeredgewidth=0.5, markerfacecolor='w'))
        handles.append(Line2D([0], [0], c='w', marker='o', markerfacecolor='k', ms=6))
        labels = display_names + ['', 'Donor background', 'Acceptor background']
        plt.legend(handles, labels, **kwargs)

    elif style == 'shapes':
        # One shared blank patch serves as header/spacer slots.
        spacer = mpatches.Patch(facecolor='w')
        colour_patches = [mpatches.Patch(facecolor=colors[g], edgecolor=colors[g]) for g in genes]
        shape_markers = [Line2D([0], [0], c='w', marker=shapes[g], ms=10, markerfacecolor='k',
                                markeredgecolor='k') for g in genes]
        handles = [spacer] + colour_patches + [spacer, spacer] + shape_markers
        labels = ['Acceptor background'] + display_names + ['', 'Donor background'] + display_names
        plt.legend(handles, labels, frameon=False, **kwargs)

## Figure 4: Results of neural net predictions

In [None]:
def load_predictions_data(path='unified_nn_predictions.txt'):
    '''Load the neural-net-predicted variants and rescale their measured (Fiji) fluorescence.

    Reads the tab-separated table at `path` (needs 'gene' and 'fiji_value'
    columns) and adds:
        fiji_log_value  -- log10(fiji_value)
        fiji_log_scaled -- log value rescaled per gene so the gene's negative
                           control maps to 0 and its wild type to 1; values
                           <= 0 (and NaN) are set to 0.

    Raises KeyError for a gene without control values below (amac, cgre, pplu only).
    '''
    predictions = pd.read_csv(path, sep='\t')
    # Per-gene control values (wt = wild type, neg = negative control), in fiji_value units.
    fiji_ctrls = {'cgre_wt': 246, 'cgre_neg': 5, 'amac_wt': 65, 'amac_neg': 4, 'pplu_wt': 191, 'pplu_neg': 4}
    predictions['fiji_log_value'] = np.log10(predictions['fiji_value'])

    # Plain dict indexing (not Series.map(dict)) keeps KeyError for unknown genes.
    log_neg = np.log10(predictions['gene'].map(lambda g: fiji_ctrls[g + '_neg']).astype(float))
    log_wt = np.log10(predictions['gene'].map(lambda g: fiji_ctrls[g + '_wt']).astype(float))
    scaled = (predictions['fiji_log_value'] - log_neg) / (log_wt - log_neg)

    # NaN > 0 is False, so NaN is also clamped to 0.
    predictions['fiji_log_scaled'] = scaled.where(scaled > 0, 0)
    return predictions

#### Violin plots of measured data

In [None]:
def plot_all_vs_neutral_muts(gene, y_axis, threshold, df_effects, color):
    '''Half-violins of `y_axis` by number of mutations (1-8): all genotypes vs. genotypes without deleterious mutations.

    Parameters
    ----------
    gene : background prefix, e.g. 'amac'; rows of the global `data_aa_scaled`
        with gene == gene + 'GFP' are used.
    y_axis : column of `data_aa_scaled` to plot.
    threshold : a single mutation counts as deleterious if its
        'effect_in_<gene>' in `df_effects` is below -threshold.
    df_effects : single-mutation table with columns 'effect_in_<gene>',
        'wt_state_<gene>', 'position' and 'mutation'.
    color : pair of colours. color[1] is used for all genotypes (left halves),
        color[0] for genotypes containing none of the deleterious mutations
        (right halves).

    For backgrounds other than 'av', a dashed crimson line marks
    gate_borders_scaled[gene][0].
    '''
    import re
    # Also imported in a later notebook cell; importing here removes the dependence on cell order.
    import matplotlib.patheffects as pe

    # Mask built from the same frame it indexes (previously from `data_aa`).
    df = data_aa_scaled[data_aa_scaled['gene'] == gene + 'GFP']
    plot_half_violin([df[df['n_mut'] == i][y_axis] for i in range(1, 9)], side='left',
                     color=color[1], alpha=1, widths=0.9, chonkylines=True)

    bad_effects = df_effects[df_effects['effect_in_' + gene] < -threshold]
    # Mutation labels such as 'A12V': wild-type state + position + mutant state.
    labels = (bad_effects['wt_state_' + gene] + bad_effects['position'].astype(str)
              + bad_effects['mutation'])
    # Stop mutations ('*') are not used for exclusion.
    bads = sorted({m for m in labels.dropna() if '*' not in m})

    if bads:
        # Match each label literally as a substring of the genotype string.
        pattern = '|'.join(map(re.escape, bads))
        has_bad = df['aa_genotype_pseudo'].str.contains(pattern, regex=True, na=False)
        neutral_only = df[~has_bad]
    else:
        # '|'.join([]) == '' would match every genotype and drop them all.
        neutral_only = df

    plot_half_violin([neutral_only[neutral_only['n_mut'] == i][y_axis] for i in range(1, 9)], side='right',
                     alpha=1, color=color[0], widths=0.9, chonkylines=True)

    label_plot_axis(x='Number of mutations', t=gene + 'GFP fitness')

    if gene != 'av':
        plt.axhline(gate_borders_scaled[gene][0], color='crimson', linewidth=1, linestyle='--',
                    path_effects=[pe.Stroke(linewidth=3, foreground='w'), pe.Normal()])
        
#     plt.xticks(range(8), range(1,9))
#     plt.axhline(gate_borders_scaled[gene][3], color='g', linewidth=1)


In [None]:
import matplotlib.patheffects as pe
def plot_all_vs_neutral_muts_ONLYMEDIANS(gene, y_axis, threshold, df_effects, color):
    '''Median `y_axis` by number of mutations (1-8): all genotypes vs. genotypes without deleterious mutations.

    Line-plot counterpart of the half-violin plot of the same name without the suffix.

    Parameters
    ----------
    gene : background prefix, e.g. 'amac'; rows of the global `data_aa_scaled`
        with gene == gene + 'GFP' are used.
    y_axis : column whose medians are plotted. This was previously ignored in
        favour of 'brightness'; the existing caller passes 'brightness'.
    threshold : a single mutation counts as deleterious if its
        'effect_in_<gene>' in `df_effects` is below -threshold.
    df_effects : single-mutation table with columns 'effect_in_<gene>',
        'wt_state_<gene>', 'position' and 'mutation'.
    color : color[0] for all genotypes, color[1] for genotypes without
        deleterious mutations.
        NOTE(review): these roles are swapped relative to the half-violin
        version (where color[1] means "all") — confirm which is intended.

    For backgrounds other than 'av', a dashed crimson line marks
    gate_borders_scaled[gene][0].
    '''
    import re

    df = data_aa_scaled[data_aa_scaled['gene'] == gene + 'GFP']
    all_medians = [df[df['n_mut'] == i][y_axis].median() for i in range(1, 9)]
    plt.plot(all_medians, color=color[0])

    bad_effects = df_effects[df_effects['effect_in_' + gene] < -threshold]
    # Mutation labels such as 'A12V': wild-type state + position + mutant state.
    labels = (bad_effects['wt_state_' + gene] + bad_effects['position'].astype(str)
              + bad_effects['mutation'])
    # Stop mutations ('*') are not used for exclusion.
    bads = sorted({m for m in labels.dropna() if '*' not in m})

    if bads:
        pattern = '|'.join(map(re.escape, bads))
        has_bad = df['aa_genotype_pseudo'].str.contains(pattern, regex=True, na=False)
        neutral_only = df[~has_bad]
    else:
        # An empty pattern would match every genotype and drop them all.
        neutral_only = df

    neutral_medians = [neutral_only[neutral_only['n_mut'] == i][y_axis].median() for i in range(1, 9)]
    plt.plot(neutral_medians, color=color[1], linestyle='-')

    label_plot_axis(x='Number of mutations', t=gene + 'GFP fitness')

    if gene != 'av':
        plt.axhline(gate_borders_scaled[gene][0], color='crimson', linewidth=1, linestyle='--')
        
#     plt.xticks(range(8), range(1,9))
#     plt.axhline(gate_borders_scaled[gene][3], color='g', linewidth=1)


In [None]:
def plot_predictions(gene, color, mode='full', **kwargs):
    '''Overlay measured brightness of neural-net-predicted variants on the library data for `gene`.

    Library layer:
        mode='full'    -> half-violins via plot_all_vs_neutral_muts
        mode='medians' -> median lines via plot_all_vs_neutral_muts_ONLYMEDIANS
        any other mode -> no library layer
    Both use 'brightness', threshold 0.05 and the global `df_effects_singles_scaled`.

    Prediction layer: a black swarm of 'fiji_log_scaled' by 'distance' from the
    global `predictions` table (rows for `gene`). A dashed line joins the
    medians at distances 6, 12, ..., 48. **kwargs go to `sns.swarmplot`.
    '''
    median_distances = [6, 12, 18, 24, 30, 36, 42, 48]
    tick_values = [1, 2, 3, 4, 5, 6, 7, 8, 12, 18, 24, 30, 36, 42, 48]
    if mode == 'full':
        plot_all_vs_neutral_muts(gene, 'brightness', 0.05,
                                 df_effects_singles_scaled, color=color)
    elif mode == 'medians':
        plot_all_vs_neutral_muts_ONLYMEDIANS(gene, 'brightness', 0.05,
                                             df_effects_singles_scaled, color=color)

    gene_predictions = predictions[predictions['gene'] == gene]
    # Empty placeholder rows (NaN brightness) give the categorical x axis one
    # slot per distance 1..48, so slot d-1 holds distance d. They are appended
    # with a fresh index so they cannot overwrite real rows.
    placeholders = pd.DataFrame({'distance': [d for d in range(1, 49) if d not in median_distances],
                                 'gene': gene})
    padded = pd.concat([gene_predictions, placeholders], ignore_index=True)
    sns.swarmplot(data=padded, x='distance', y='fiji_log_scaled', color='k', **kwargs)

    medians = [padded.loc[padded['distance'] == d, 'fiji_log_scaled'].median() for d in median_distances]
    plt.plot([d - 1 for d in median_distances], medians, color='k', linestyle='--')

    plt.xticks([t - 1 for t in tick_values], tick_values);

## Suppl. Fig. 2: Distribution of libraries during sorting

In [4]:
from FlowCytometryTools import FCMeasurement
from FlowCytometryTools import ThresholdGate, PolyGate, IntervalGate, QuadGate

def transform_by_log10(dataset):
    '''Log10-transform cytometry values, mapping anything <= 1 (including 0 and negatives) to 0.

    Gives the same values as np.where(dataset <= 1, 0, np.log10(dataset)).
    Because it clips to 1 first, log10 is never evaluated on zero or negative
    values, so no divide-by-zero / invalid-value RuntimeWarnings are raised.
    NaN stays NaN. Array-like input (including DataFrames) returns a numpy array.
    '''
    values = np.asarray(dataset, dtype=float)
    return np.log10(np.clip(values, 1, None))

def determine_filename(gene, machine, ctrl):
    '''Build the path of the .fcs file for a background, sorter and sample type.

    gene    : 'amac', 'cgre' or 'pplu' (selects the recording date; KeyError otherwise).
    machine : sorter identifier embedded in the file name.
    ctrl    : 'ctrl' or 'negctrl' selects the negative-control file; any other
              value selects the library file.
    The file lives in <data_folder>/<gene>GFP/cell_sorting_data/.
    '''
    recording_dates = {'amac': '20190117', 'cgre': '20190124', 'pplu': '20190128'}
    folder = os.path.join(data_folder, gene + 'GFP', 'cell_sorting_data')
    sample = 'negative_control' if ctrl in ('ctrl', 'negctrl') else gene + 'GFP_library'
    return os.path.join(folder, recording_dates[gene] + '_' + machine + '_' + sample + '.fcs')

def load_fcs_data(filename):
    '''Load an .fcs file, log10-transform all channels except 'Time', and unify channel names.

    Channel names presumably differ between sorters; whichever of these is
    present is renamed: 'GFP-A'/'FITC-A' -> 'GFP', 'mCherry-A'/'PE-Texas Red-A' -> 'mKate2'.
    '''
    measurement = FCMeasurement(ID=filename, datafile=filename)
    channels = [name for name in measurement.data.columns if name != 'Time']
    measurement = measurement.transform(transform_by_log10, auto_range=False, use_spln=False,
                                        channels=channels)
    aliases = {'GFP-A': 'GFP', 'FITC-A': 'GFP', 'mCherry-A': 'mKate2', 'PE-Texas Red-A': 'mKate2'}
    for source_name, target_name in aliases.items():
        if source_name in measurement.data.columns:
            measurement.data.rename(columns={source_name: target_name}, inplace=True)
    return measurement

The dedent function was deprecated in Matplotlib 3.1 and will be removed in 3.3. Use inspect.cleandoc instead.
  doc = dedent(doc)


In [5]:
# Events with FSC-A above 1 (log10 units, after load_fcs_data's transform) are kept as bacteria.
is_bacteria_gate = ThresholdGate(name='is_bacteria_gate', channel=['FSC-A'],
                                 region='above', threshold=1)

# Approximate mKate2 windows (log10 units), based on the sorting-data PDFs from VBCF.
# NOTE(review): sortingplot_library_scatter uses its own per-gene `redgate`
# windows rather than these two gates.
mKate2_gate_A = IntervalGate(name='mKate2_gate_A', channel=['mKate2'],
                             region='in', vert=(3.2, 3.4))
mKate2_gate_B = IntervalGate(name='mKate2_gate_B', channel=['mKate2'],
                             region='in', vert=(3.4, 3.6))

In [271]:
def sortingplot_library_scatter(gene, machine, fontscale=1, savepath=False, random_state=None):
    '''Joint GFP vs. mKate2 scatter of a sorted library and its negative control, highlighting sorted cells.

    Parameters
    ----------
    gene : 'amac', 'cgre' or 'pplu'.
    machine : sorter, 'A' or 'B'; selects the mKate2 window and the gate borders.
    fontscale : added to the base font sizes (10 for labels/ticks, 12 for the legend title).
    savepath : if truthy, save the figure to this path at 300 dpi. (Previously
        the save sat after `return` and never ran.)
    random_state : forwarded to DataFrame.sample when drawing 10,000 library and
        1,500 control events; None keeps the previous unseeded behaviour.

    Returns
    -------
    The seaborn JointGrid.
    '''
    # mKate2 (log10) window marking cells as "sorted", per sorter and background.
    redgate = {'A': {'amac': (3.18, 3.35), 'cgre': (3.14, 3.32), 'pplu': (3.18, 3.44)},
               'B': {'amac': (3.4, 3.6), 'cgre': (3.29, 3.57), 'pplu': (3.28, 3.53)}}

    gate_table = pd.read_csv(os.path.join(data_folder, gene+'GFP', gene+'GFP_03_gate_border_values.txt'),
                             sep='\t', index_col=0)
    # Sorter A borders come from the first 7 rows and sorter B from rows 8-14
    # (row 7 is skipped). Layout inferred from the indexing; confirm against the file.
    gate_borders = {'A': np.log10(np.array(gate_table.head(7)['upper_limit'])),
                    'B': np.log10(np.array(gate_table.iloc[8:15]['upper_limit']))}

    library = load_fcs_data(determine_filename(gene, machine, ''))
    library = library.gate(is_bacteria_gate).data.copy().replace(0, np.nan).sample(10000, random_state=random_state)
    library[''] = 'Whole library'
    ctrl = load_fcs_data(determine_filename(gene, machine, 'ctrl'))
    ctrl = ctrl.gate(is_bacteria_gate).data.copy().replace(0, np.nan).sample(1500, random_state=random_state)
    ctrl[''] = 'Non-fluorescent controls'
    data = pd.concat([library, ctrl])
    # Strictly inside the window; NaN compares False, as before.
    low, high = redgate[machine][gene]
    data['sorted'] = (data['mKate2'] > low) & (data['mKate2'] < high)

    ax = sns.jointplot(data=data, x='GFP', y='mKate2', hue='',
                       alpha=0.4, s=5, edgecolor=None, palette=['darkgrey', 'k',], )
    ax.set_axis_labels('GFP fluorescence (log)', 'mKate2 fluorescence (log)', fontsize=10+fontscale)
    ax.ax_joint.legend(frameon=False)

    plt.sca(ax.ax_joint)
    plt.scatter(data[data.sorted==True]['GFP'], data[data.sorted==True]['mKate2'], color='crimson', s=4,
                label='Sorted cells', alpha=0.4)
    legendary(['lightgrey', 'grey', 'crimson'],
              ['Whole library', 'Non-fluorescent controls', 'Sorted cells'], fontsize=10+fontscale,
              title='%sGFP library: sorter %s' % (gene,machine), title_fontsize=12+fontscale, loc='upper left')
    plt.ylim(0, 5.5)
    plt.xlim(0, 5.5)
    plt.yticks(fontsize=10+fontscale)
    plt.xticks(fontsize=10+fontscale)

    plt.sca(ax.ax_marg_x)
    sns.kdeplot(x=data[data.sorted==True]['GFP'], color='crimson')
    for border in gate_borders[machine]:
        plt.axvline(border, color='crimson', linestyle='--', linewidth=1)

    if savepath:
        # plt.sca above made the JointGrid's figure current, so this saves it.
        plt.savefig(savepath, dpi=300)
    return ax

## Suppl. Fig. 4: Mutation effects by position

In [None]:
def plot_positional_effects_heatmap(df=None, func=np.nanmedian):
    '''Heatmap of per-position single-mutation effects for four backgrounds, with secondary structure.

    Parameters
    ----------
    df : single-mutation effects table passed to `get_effects_by_position`.
        Defaults to the global `df_effects_singles`, looked up at call time
        rather than when this cell was executed.
    func : per-position aggregation (default np.nanmedian).

    Returns
    -------
    (to_plot, hm) : the 4 x n_positions array (rows av, amac, cgre, pplu) and
    the heatmap axes.
    '''
    if df is None:
        df = df_effects_singles
    effects = {gene: get_effects_by_position(gene, df, positions='pseudo', func=func)
               for gene in ['amac', 'cgre', 'pplu', 'av']}

    font = 'Arial'
    row_genes = ['av', 'amac', 'cgre', 'pplu']
    to_plot = np.array([effects[gene] for gene in row_genes])

    hm = sns.heatmap(to_plot,
                     cmap='Blues_r', yticklabels=['avGFP', 'amacGFP', 'cgreGFP', 'ppluGFP2', ], xticklabels=10,
                     cbar_kws={'label': 'Median mutation effect', 'pad': 0.01, 'fraction': 0.05})

    # Secondary structure drawn through the middle of each heatmap row.
    for row, gene in enumerate(row_genes):
        ss = pseudify(globals()[gene + '_ss_pymol'], gene)
        plot_secondary_structure(ss, y=row + 0.5, hel=0.05, arrow_width=0.1, linewidth=1, c='k')
    plt.scatter([x+0.5 for x in buried_pos], [0 for x in buried_pos], color='k', s=50, marker=2)

    # Hatch pseudo-positions 67-69 (0-based) across all rows. Only the 10s are
    # meant to be drawn: alpha=0 leaves just the hatch visible. What these
    # positions represent is not stated here (presumably the chromophore).
    # The row length follows the data instead of a hard-coded 248.
    n_positions = len(effects['amac'])
    hatch_row = [np.nan] * 67 + [10] * 3 + [np.nan] * (n_positions - 70)
    plt.pcolor(np.arange(n_positions + 1), np.arange(len(to_plot) + 1),
               np.ma.masked_less(np.array([hatch_row for _ in range(len(to_plot))]), 5),
               hatch='///', alpha=0)

    label_plot_axis(x='Aligned amino acid position', t='Median effects of single mutations, by position')
    plt.xticks(fontname=font)
    plt.yticks(rotation=0, fontname=font)
    plt.ylim(len(to_plot) + 0.1, -0.1)
    return to_plot, hm

In [None]:
import matplotlib.image as mpimg
def show_image(image_path):
    '''Display the image file at `image_path` on the current axes, with the axes hidden.'''
    pixels = mpimg.imread(image_path)
    plt.imshow(pixels)
    plt.axis('off')

## Suppl. Fig. 5: General epistasis

In [None]:
def doublemut_scatterplot(gene):
    '''Scatter double mutants of `gene`: first- vs. second-mutation effect, coloured by epistasis.

    Uses rows of the global `doublemuts` with gene == gene + 'GFP'. Points are
    drawn in ascending |epistasis|, so the strongest interactions land on top.
    The 'Spectral' colour scale is fixed to [-2, 2].
    '''
    pairs = doublemuts[doublemuts['gene'] == gene + 'GFP'].copy()
    pairs['abs_epistasis'] = pairs['epistasis'].abs()
    pairs = pairs.sort_values('abs_epistasis')
    plt.scatter(pairs['mut1_effect'], pairs['mut2_effect'], alpha=0.8,
                c=pairs['epistasis'], cmap='Spectral', s=5, vmin=-2, vmax=2)
    label_plot_axis(x='Effect of first mutation', y='Effect of second mutation', t=gene+'GFP')
# NOTE(review): this call is at cell level, not inside doublemut_scatterplot, so it runs
# once when the cell executes and acts on whatever figure is current at that moment.
plt.tight_layout()

In [None]:
def positional_heatmap(gene, include_singlemut_e=True):
    '''Heatmap of the maximum |epistasis| between native positions of `gene`.

    Keys of epistatic_pairs_pos[gene] are 'i:j' strings of pseudo-positions.
    They are mapped to native positions with pseudopos_to_nativepos[i][genekey[gene]].
    For each key, cell [pos1, pos2] gets max(|epistasis|) and the mirrored cell
    [pos2, pos1] gets the constant 20. That is far above vmax, so it presumably
    renders at the colormap's top colour just to mark pair presence.

    include_singlemut_e=False:
        length x length matrix, 'GnBu' colormap, vmin=0, vmax=1; zero cells masked.
    include_singlemut_e=True:
        The matrix has a 10-cell margin. For each position pos1 appearing in an
        epistatic pair, the margin strip (row pos1+10 / column pos1+10) holds
        that position's single-mutation effect minus 10. Positive cells
        (epistasis) are drawn with 'Blues' (vmin 0, vmax 2); negative cells
        (shifted single effects) with 'Blues_r'. This assumes single-mutation
        effects are < 10 so the shift keeps them negative — TODO confirm.
    '''
    # Pair keys are pseudo-positions (not native); they are converted below.
    
    length = len(eval(gene+'_wt'))
    
    if include_singlemut_e==False:
        positional_max_e = np.zeros([length, length])
        # NOTE(review): 247 pseudo-positions are hard-coded here; the positional-effects
        # heatmap in this notebook hatches a 67+3+178 = 248-position row — confirm
        # whether the last pseudo-position is skipped.
        for i in range(247):
            for j in range(247):
                pos1 = pseudopos_to_nativepos[i][genekey[gene]]
                pos2 = pseudopos_to_nativepos[j][genekey[gene]]
                if str(i)+':'+str(j) in epistatic_pairs_pos[gene]:
                    positional_max_e[pos1,pos2] = max(map(abs, epistatic_pairs_pos[gene][str(i)+':'+str(j)]))
                    # Sentinel for the mirrored cell (see docstring).
                    positional_max_e[pos2,pos1] = 20
                    
        sns.heatmap(positional_max_e, vmin=0, vmax=1, cmap='GnBu', square=True, mask= positional_max_e==0,
                    cbar_kws={'label':'epistasis', 'shrink':1})
        plt.ylim(0,length)
        plt.xlim(0,length)
    #     ax.invert_yaxis()
        # NOTE(review): without the 10-cell margin the +11 tick offset looks off, and
        # range(11, length, 25) can have one element fewer than range(0, length, 25),
        # which matplotlib rejects (label count must match tick count) — confirm.
        plt.xticks(np.array(range(11,length,25))+0.5, range(0,length,25) )
        plt.yticks(np.array(range(11,length,25))+0.5, range(0,length,25), rotation=0)
        
    else:
        # 10 extra rows/columns at the start hold the single-mutation effect strips.
        positional_max_e = np.zeros([length+10, length+10])
        position_effects = get_effects_by_position(gene, df_effects_singles, positions='native')
        for i in range(247):
            for j in range(247):
                pos1 = pseudopos_to_nativepos[i][genekey[gene]]
                pos2 = pseudopos_to_nativepos[j][genekey[gene]]
                if str(i)+':'+str(j) in epistatic_pairs_pos[gene]:
                    # Margin strips: single-mutation effect of pos1, shifted by -10 so it is negative.
                    positional_max_e[pos1+10, 0:10] = position_effects[pos1]-10
                    positional_max_e[0:10, pos1+10] = position_effects[pos1]-10
                    positional_max_e[pos1+10,pos2+10] = max(map(abs, epistatic_pairs_pos[gene][str(i)+':'+str(j)]))
                    positional_max_e[pos2+10,pos1+10] = 20
                    
        # Layer 1: positive cells (epistasis and sentinels).
        sns.heatmap(positional_max_e, mask= positional_max_e<=0, vmin=0, vmax=2, cmap='Blues',
               square=True, cbar_kws={'label': 'epistasis', 'shrink':1},)
        # Layer 2: negative cells (shifted single-mutation effects), scaled to their own range.
        sns.heatmap(positional_max_e, mask= positional_max_e>=0, vmin=np.nanmin(position_effects)-10, 
                vmax=np.nanmax(position_effects)-10, cmap='Blues_r', cbar=False, square=True, )
            #cbar_kws={'label': 'singlemut effect', 'shrink':0.75, 'orientation':'horizontal'})
            
        plt.ylim(0,length+10)
        plt.xlim(0,length+10)
    #     ax.invert_yaxis()
        plt.xticks(np.array(range(11,length+10,25))+0.5, range(0,length,25) )
        plt.yticks(np.array(range(11,length+10,25))+0.5, range(0,length,25), rotation=0)

        
        

In [None]:
def barplot_epistasis(gene, e):
    '''Stacked bars of the fraction of `gene` genotypes with strong epistasis, for 2-8 mutations.

    Uses rows of the global `data_aa` with gene == gene + 'GFP'. For each
    number of mutations there is a solid bar for the fraction with epistasis
    < -e, and a hatched bar stacked on top for the fraction with epistasis > e.
    Both fractions are over all genotypes with that number of mutations. A
    mutation count with no genotypes yields NaN (previously ZeroDivisionError).
    '''
    n_muts = list(range(2, 9))
    gene_rows = data_aa[data_aa['gene'] == gene + 'GFP']
    neg, pos = [], []
    for n in n_muts:
        group = gene_rows[gene_rows['n_mut'] == n]
        if len(group) == 0:
            neg.append(float('nan'))
            pos.append(float('nan'))
            continue
        neg.append((group['epistasis'] < -e).sum() / len(group))
        pos.append((group['epistasis'] > e).sum() / len(group))

    x = range(len(n_muts))
    plt.bar(x, neg, color=colors[gene])
    plt.bar(x, pos, bottom=neg, color=colors[gene], alpha=1, hatch='///',
            edgecolor='w', linewidth=0)
    # One label per bar; range(2, 11) gave 9 labels for 7 bars, which matplotlib rejects.
    plt.xticks(x, n_muts)
    plt.ylim(0, 0.25)
    label_plot_axis(t=gene+'GFP', x='Number of mutations', y='Fraction of genotypes\nwith epistasis over %s' % e);

## Suppl. Fig. 6: Distances between amino acid pairs

In [None]:
def distances_heatmap(gene, epistasis=0.3):
    '''Residue distance map of `gene` with strongly epistatic pairs highlighted.

    Upper triangle: minimum distance between residues, read from
    <data_folder>/<gene>GFP/<gene>GFP_11a_minimum_distance_matrix_aa.txt
    ('mako' colormap). Overlay: cells whose max |epistasis| (from the global
    `epistatic_pairs_pos`, keys 'i:j') is at least `epistasis`, drawn with
    'hot' on a 0..epistasis scale. Positions missing from the distance table
    are dropped from the epistasis matrix so that the two align.

    Parameters
    ----------
    gene : background prefix, e.g. 'amac' (`<gene>_wt` gives the sequence length).
    epistasis : |epistasis| threshold. Previously ignored in favour of a
        hard-coded 0.3; the default keeps that behaviour.
    '''
    length = len(globals()[gene + '_wt'])
    df = pd.read_csv(os.path.join(data_folder, gene+'GFP', gene+'GFP_11a_minimum_distance_matrix_aa.txt'),
                     sep='\t', index_col=0)
    positional_max_e = np.zeros([length, length])
    sites_to_ignore = [x for x in range(length) if x not in df.index]

    pairs = epistatic_pairs_pos[gene]
    for i in range(length):
        for j in range(length):
            key = str(i) + ':' + str(j)
            if key in pairs:
                positional_max_e[i, j] = max(map(abs, pairs[key]))

    positional_max_e = np.delete(positional_max_e, sites_to_ignore, axis=1)
    positional_max_e = np.delete(positional_max_e, sites_to_ignore, axis=0)

    d = np.array(df)
    # cbar=False: the cbar_kws are unused here (the colorbar is drawn separately by callers).
    sns.heatmap(data=d, square=True, cmap='mako', mask=np.tril(d), cbar=False,
                cbar_kws={'label': 'Minimum distance between residues (A)', 'shrink':1, 'pad':0.15})
    ax = sns.heatmap(data=positional_max_e, mask=positional_max_e < epistasis, square=True, alpha=1,
                     cmap='hot', cbar=False, vmin=0, vmax=epistasis)
    plt.ylim(0, len(d))
    plt.xlim(0, len(d))
    plt.xticks(range(0, len(d), 20), df.index[0::20], fontsize=8)
    plt.yticks(range(0, len(d), 20), df.index[0::20], rotation=0, fontsize=8)
    ax.yaxis.set_label_position('right')
    ax.yaxis.tick_right()

    label_plot_axis(x='Amino acid position', y='Amino acid position', fontsize_x=10, fontsize_y=10)

In [None]:
def plot_all_distance_maps(df=None, e=0.3):
    '''Draw 2x2 residue-distance / epistasis maps with inset distance violins, plus a separate colorbar figure.

    Panels for amac, cgre, pplu and av each show `distances_heatmap(gene, e)`.
    Each has an inset split violin of pair 'distance' for pairs with
    |epistasis| <= e vs. > e. A second, thin figure holds the horizontal
    'mako' colorbar for distances 0-50 A.

    Parameters
    ----------
    df : double-mutant table with 'gene', 'distance' and 'epistasis' columns.
        Defaults to the global `doublemuts`, looked up at call time. It is not
        modified: the threshold flag is added to a copy (previously the global
        table gained an 'e_above_threshold' column).
    e : |epistasis| threshold for the violin split and inset label; also passed
        to `distances_heatmap`.
    '''
    if df is None:
        df = doublemuts
    df = df.assign(e_above_threshold=df['epistasis'].abs() > e)

    fig = plt.figure(figsize=[10, 10], dpi=200)
    c = ListedColormap(sns.color_palette('mako', 256))

    # (background, gene label in df, inset rectangle [left, bottom, width, height])
    panels = [('amac', 'amacGFP', [0.02, 0.75, 0.14, 0.18]),
              ('cgre', 'cgreGFP', [0.52, 0.75, 0.14, 0.18]),
              ('pplu', 'ppluGFP', [0.02, 0.25, 0.14, 0.18]),
              ('av', 'avGFP', [0.52, 0.25, 0.14, 0.18])]
    for panel_index, (gene, gene_label, inset_rect) in enumerate(panels, start=1):
        plt.subplot(2, 2, panel_index)
        distances_heatmap(gene, e)

        inset = fig.add_axes(inset_rect)
        violins = sns.violinplot(data=df[df.gene == gene_label], x='gene', y='distance', linewidth=0, cut=0,
                                 hue='e_above_threshold', split=True, palette=[c(175), c(75)], )
        violins.get_legend().remove()
        for side in ('right', 'top', 'bottom'):
            inset.spines[side].set_visible(False)
        inset.spines['left'].set_position('center')
        plt.xticks(fontsize=12)
        label_plot_axis(y='Distance between residues (A)', fontsize_y=8)
        inset.text(0.1, 50, 'No epistasis', color=c(175))
        inset.text(0.1, 45, '|Epistasis| > %s' % e, color=c(75))
        plt.ylim(0, 50)

    plt.tight_layout(pad=2)

    # Stand-alone horizontal colorbar for the distance maps.
    cbar_fig, cbar_ax = plt.subplots(figsize=[10, 0.5], dpi=200)
    cbar_fig.subplots_adjust(bottom=0.5)
    norm = matplotlib.colors.Normalize(vmin=0, vmax=50)
    cb1 = matplotlib.colorbar.ColorbarBase(cbar_ax, cmap=c, norm=norm, orientation='horizontal')
    cb1.ax.tick_params(labelsize=7)
    cb1.set_label('Minimal distance between residues (A)', labelpad=-25, color='w',)

In [None]:
def distances_violins(df, epistasis, palette='mako'):
    """Split violins of residue-pair distance per gene, split by |epistasis| vs. a threshold.

    Parameters
    ----------
    df : DataFrame
        Needs 'gene', 'distance' and 'epistasis' columns. A boolean
        'e_above_threshold' column (|epistasis| > `epistasis`) is written into
        this frame IN PLACE.
    epistasis : float
        Threshold applied to |epistasis|.
    palette : str
        Seaborn palette name, resampled to 256 colours; entries 75 and 175 are used.

    Rows of 'amacV14LGFP' and rows with epistasis exactly 0 are left out of the plot.
    """
    cmap = ListedColormap(sns.color_palette(palette, 256))
    below_color = cmap(75)
    above_color = cmap(175)

    # NOTE(review): mutates the caller's frame; the figure code plotting
    # hue='e_above_threshold' from `df` may rely on this column — keep in place.
    df['e_above_threshold'] = df['epistasis'].apply(
        lambda value: True if abs(value) > epistasis else False)

    keep = (df.gene != 'amacV14LGFP') & (abs(df['epistasis']) > 0)
    sns.violinplot(data=df[keep], x='gene', y='distance', hue='e_above_threshold',
                   split=True, linewidth=0, palette=[below_color, above_color],
                   saturation=100, cut=0)

    # The 'No epistasis' patch stands for |epistasis| <= threshold (the False hue level).
    legendary([below_color, above_color],
              ['No epistasis', 'Epistasis > %s' % epistasis],
              ncol=2, loc='upper center')
    label_plot_axis(y='Distance (A) between\namino acid pairs', )

## Suppl. Fig. 7: Protein stability

In [None]:
def _annotate_spectra(spectra, minutes_per_measurement):
    """Add the plotting columns (treatment label, 'hours', 'Gene') to one spectra frame in place.

    Parameters
    ----------
    spectra : DataFrame
        Must have 'treatment', 't' (measurement index) and 'gene' columns.
    minutes_per_measurement : float
        Approximate acquisition time of one spectrum; 'hours' = t * minutes / 60.

    Returns the same (mutated) frame for convenience.
    """
    gene_labels = {'amac': 'amacGFP', 'cgre': 'cgreGFP', 'pplu': 'ppluGFP',
                   'av': 'avGFP', 'amacV14L': 'amacGFP_V14L'}
    # The column name starts with a newline so the legend title renders on its own line.
    spectra['\nTreatment'] = spectra['treatment'].apply(lambda x: '9M urea' if x == 'urea' else 'control')
    spectra['hours'] = spectra['t'] * minutes_per_measurement / 60
    # Any gene code not in the table (including 'blank') is labelled 'blank'.
    spectra['Gene'] = spectra['gene'].apply(lambda x: gene_labels.get(x, 'blank'))
    return spectra


def import_plate_reader_data(fluo_path='urea_spectra_fluo_longform_raw_and_normed.txt',
                             abs_path='urea_spectra_abs_longform_raw_and_normed.txt'):
    """Load the urea-denaturation plate-reader spectra and add plotting columns.

    Parameters
    ----------
    fluo_path, abs_path : str
        Tab-separated long-form fluorescence / absorbance spectra (first column is
        the index). Defaults are the files used for the figure.

    Returns
    -------
    (spectra_abs, spectra_fluo) : tuple of DataFrame
        Each gains a treatment-label column, 'hours' and 'Gene'.
    """
    spectra_fluo = pd.read_csv(fluo_path, sep='\t', index_col=0)
    spectra_abs = pd.read_csv(abs_path, sep='\t', index_col=0)

    # Each fluorescence spectrum measurement takes ~25 min, each absorbance spectrum ~40 min.
    _annotate_spectra(spectra_fluo, minutes_per_measurement=25)
    _annotate_spectra(spectra_abs, minutes_per_measurement=40)
    return spectra_abs, spectra_fluo

In [None]:
def plot_spectra(gene, treatment, normed=True, max_hours=57, max_wavelength=650):
    """Overlay absorbance and fluorescence spectra over time for one gene and treatment.

    Reads the module-level `spectra_abs` and `spectra_fluo` frames and draws one
    line per time point 't' (absorbance in 'bone_r', fluorescence in 'mako_r').

    Parameters
    ----------
    gene : str
        Value of the 'Gene' column, e.g. 'amacGFP'.
    treatment : str
        Raw value of the 'treatment' column, e.g. 'urea'.
    normed : bool
        Plot 'signal_normalized' if truthy, otherwise raw 'signal'.
    max_hours, max_wavelength : float
        Upper cutoffs on 'hours' and 'Wavelength' (defaults are the figure's values).

    NOTE(review): ylim is fixed to (-0.1, 1.1), which suits the normalized signal;
    raw signal may be clipped.
    """
    # Truthiness instead of ==True / ==False, so e.g. normed=None no longer leaves `y` unbound.
    y = 'signal_normalized' if normed else 'signal'
    for df, color in zip([spectra_abs, spectra_fluo], ['bone_r', 'mako_r']):
        subset = df[(df['treatment'] == treatment) & (df['Gene'] == gene)
                    & (df['hours'] <= max_hours) & (df['Wavelength'] <= max_wavelength)]
        sns.lineplot(data=subset, x='Wavelength', y=y, hue='t', ci=None,
                     palette=color, legend=False)
    plt.ylim(-0.1, 1.1)
    label_plot_axis()

In [22]:
def plot_peak_timelapse(y, normed=True, max_hours=57):
    """Plot the peak signal of each GFP over time (urea vs. control) on the current axes.

    Parameters
    ----------
    y : {'Absorbance', 'Fluorescence'}
        Selects the module-level `spectra_abs` or `spectra_fluo` frame; also the y label.
    normed : bool
        Plot 'signal_normalized' if truthy, otherwise raw 'signal'.
    max_hours : float
        Only time points with 'hours' <= max_hours are plotted (figure default 57).

    Raises
    ------
    ValueError
        If `y` is neither 'Absorbance' nor 'Fluorescence'.
    """
    # Local import: the module-level Line2D import sits in a later cell.
    from matplotlib.lines import Line2D

    col = 'signal_normalized' if normed else 'signal'

    if y == 'Absorbance':
        df = spectra_abs
        # Line-style key for the two treatments, drawn only on the absorbance panel.
        # NOTE(review): assumes seaborn gives the urea level the solid style — confirm.
        plt.legend(handles=[Line2D([0], [0], c='k', lw=2, label='9M urea'),
                            Line2D([0], [0], c='k', ls='--', lw=2, label='1X PBS')],
                   loc='lower left', frameon=False,)
    elif y == 'Fluorescence':
        df = spectra_fluo
    else:
        raise ValueError("y must be 'Absorbance' or 'Fluorescence', got %r" % (y,))

    peaks = df[(df['peak'] == True) & (df['gene'] != 'blank') & (df['hours'] <= max_hours)]
    sns.lineplot(data=peaks, x='hours', y=col, hue='Gene', ci='sd', style='\nTreatment', legend=False,
                 palette=[colors['amac'], colors['amacV14L'], colors['cgre'], colors['pplu'], colors['av']])
    label_plot_axis(y=y, x='Time (hours)', )
    plt.xlim(-2.5, 60)
    plt.xticks(range(0, 60, 10), range(0, 60, 10))

In [6]:
from matplotlib.lines import Line2D

In [5]:
def plot_thermosensitivity(data_path='qPCR_rawdata_200717.txt', melt_path='qPCR_melt_200717.txt'):
    """Plot normalised fluorescence against temperature for each GFP from the qPCR heating run.

    The '465-510' channel is divided by each sample's mean signal at a rounded
    temperature of 50. Readings with Temp > 50 are drawn as mean +/- sd per
    SampleName. Dotted vertical lines mark each protein's mean 'Tm1' from the
    melt-curve table.

    Parameters
    ----------
    data_path : str
        Tab-separated raw readings with 'SampleName', 'Time', 'Temp' and '465-510' columns.
    melt_path : str
        Tab-separated melt results with 'Name' and 'Tm1' columns.
    """
    genes = ['amac', 'amacV14L', 'cgre', 'pplu', 'av']

    t = pd.read_csv(data_path, sep='\t')
    t = t[t['SampleName'].str.contains('av|pplu|cgre|amac|amacV14L|blank')].copy()
    t['Time'] = t['Time']/1000
    t['temperature'] = round(t['Temp'])

    melts = pd.read_csv(melt_path, sep='\t')
    melts_tm = {gene: melts[melts['Name'] == gene]['Tm1'].mean() for gene in genes}

    # Per-sample normalisation factor: mean fluorescence at (rounded) 50 degrees.
    at_50 = t[t['temperature'] == 50]
    norm = {name: at_50[at_50['SampleName'] == name]['465-510'].mean()
            for name in genes + ['blank']}

    # Vectorised instead of a row-wise apply using positional x[0]/x[1] on a
    # label-indexed row (deprecated in recent pandas). norm[name] still raises
    # KeyError for a sample name outside `norm`, as before.
    t['normed_fluo'] = t['465-510'] / t['SampleName'].map(lambda name: norm[name])

    sns.lineplot(data=t[(t['SampleName'] != 'blank') & (t['Temp'] > 50)], x='Temp', y='normed_fluo',
                 hue='SampleName', ci='sd',
                 palette=[colors['amac'], colors['amacV14L'], colors['cgre'], colors['pplu'], colors['av']])
    for gene in ['amac', 'av', 'cgre', 'pplu', 'amacV14L']:
        plt.axvline(melts_tm[gene], color=colors[gene], ls=':')
    label_plot_axis(y='Fluorescence', x='Temperature (Celsius)', )

    plt.legend(handles=[Line2D([0], [0], c=colors['amacV14L'], lw=2, label='amacGFP:V14L'),
                        Line2D([0], [0], c=colors['amac'], lw=2, label='amacGFP'),
                        Line2D([0], [0], c=colors['cgre'], lw=2, label='cgreGFP'),
                        Line2D([0], [0], c=colors['pplu'], lw=2, label='ppluGFP'),
                        Line2D([0], [0], c=colors['av'], lw=2, label='avGFP')],
               loc='lower left', frameon=False,)


# amacGFP vs. amacGFP:V14L

In [None]:
def get_mutposeffect(data, y_column, x_val, x_column='position', s_column='mutation'):
    """Collect the effects and labels of all rows at one position.

    Returns a tuple ``(effects, labels)`` of plain lists, in row order, taken from
    `y_column` and `s_column` of the rows where ``data[x_column] == x_val``.
    """
    at_position = data[data[x_column] == x_val]
    effects = at_position[y_column].tolist()
    labels = at_position[s_column].tolist()
    return effects, labels

def set_x_coordinates(list_of_yvals, list_of_labels, y_distance_threshold, center_x, x_distance):
    '''Find x coordinates for a text "swarm plot" so that labels with close y values do not overlap.

    Points are processed in ascending y. Each point closer than `y_distance_threshold`
    (in y) to the first point of its cluster is pushed sideways, alternating right/left
    with growing offsets +d, -d, +2d, -2d, ... (d = `x_distance`). The first point
    farther away starts a new cluster at `center_x`, and the pattern restarts at +d.

    NaN y values are dropped together with their labels. Labels stay paired with
    their own y value, including when several points share the same y. If the two
    lists differ in length, the extra items of the longer one are ignored.

    Returns
    -------
    (x_values, y_values, labels) : tuple of lists, all ordered by ascending y.
    '''
    # Sort (y, label) pairs together; a y -> label dict would merge duplicate y values.
    pairs = sorted(((y, label) for y, label in zip(list_of_yvals, list_of_labels)
                    if str(y) != 'nan'),
                   key=lambda pair: pair[0])
    y_values = [y for y, _ in pairs]
    labels = [label for _, label in pairs]
    x_values = [center_x for _ in y_values]

    ref = 0
    direction = 1
    spacer = x_distance
    for i in range(1, len(y_values)):
        if abs(y_values[i] - y_values[ref]) < y_distance_threshold:
            x_values[i] += direction * spacer
            direction *= -1
            # Widen the offset after each right/left pair.
            if direction == 1:
                spacer += x_distance
        else:
            # New cluster: restart the fan-out symmetrically at +d.
            ref = i
            spacer = x_distance
            direction = 1

    return x_values, y_values, labels

# Extant vs non-extant

In [None]:
def plot_extant_vs_nonextant_muts(gene, dataset=data_aa, y_axis='brightness', palette='mako', n=5):
    """Half-violins of `y_axis` per number of mutations: genotypes whose mutations are
    all found in extant sequences (right halves) vs. none found (left halves).

    Parameters
    ----------
    gene : str
        Short gene key; rows with ``dataset.gene == gene + 'GFP'`` are used.
    dataset : DataFrame
        Needs 'gene', 'n_mut', 'n_mut_extant' and `y_axis` columns. Not modified.
    y_axis : str
        Column plotted on the y axis.
    palette : str
        Seaborn palette name; entries 80 (none extant) and 180 (all extant) are used.
    n : int
        Genotypes with 1 .. n-1 mutations are plotted, one violin pair each.
    """
    c = ListedColormap(sns.color_palette(palette, 256))
    # .copy(): a column is added below and boolean indexing yields a slice of `dataset`.
    df = dataset[dataset.gene == gene + 'GFP'].copy()

    # 'all' if every mutation is extant, else 'none' if no mutation is extant, else 'some'.
    # Vectorised; replaces a row-wise apply with positional x[0]/x[1] access.
    f_extant = pd.Series('some', index=df.index)
    f_extant.loc[df['n_mut_extant'] == 0] = 'none'
    f_extant.loc[df['n_mut'] == df['n_mut_extant']] = 'all'
    df['f_extant'] = f_extant

    mut_counts = range(1, n)
    plot_half_violin([df[(df['n_mut'] == i) & (df['f_extant'] == 'none')][y_axis] for i in mut_counts],
                     side='left',
                     color=[c(80), c(80)], alpha=1, widths=0.8, chonkylines=True)
    plot_half_violin([df[(df['n_mut'] == i) & (df['f_extant'] == 'all')][y_axis] for i in mut_counts],
                     side='right', alpha=1,
                     color=[c(180), c(180)], widths=0.9, chonkylines=True)

    label_plot_axis(x='Number of mutations', t=gene + 'GFP')
    # One tick per violin pair: positions 0..n-2 labelled 1..n-1. The previous labels
    # range(1, 1+n) had one label too many, and matplotlib rejects that mismatch.
    plt.xticks(range(n - 1), range(1, n))

In [None]:
def compare_against_extant(gene, dataset=data_aa, nok=-0.2, overlap_n=10, **kwargs):
    """Scatter, for each other extant GFP, how many of its wild-type amino acids are
    'not OK' as single mutations in `gene`, against sequence divergence.

    For every sequence in `wts` except 'RrGFP' and `gene` itself:
      * take its wild-type states ('<quasipos><aa>') that differ from `gene`'s wild type,
      * keep the ones tested as single mutations in `gene` (`overlap`),
      * if more than `overlap_n` remain, plot one point at
        x = 100 - pairwise[{extant, gene}]   (presumably % identity -> % divergence; confirm)
        y = fraction of `overlap` whose single-mutant log effect is below `nok`.

    Parameters
    ----------
    gene : str
        Short gene key; effects are read from the global ``siffects_<gene>_log``,
        a mapping of mutation strings such as 'A12G' to log effects.
    dataset : DataFrame
        Unused; kept for backward compatibility.
    nok : float
        Log-effect threshold below which a mutation counts as 'not OK'.
    overlap_n : int
        A point is plotted only if the overlap has more than this many states.
    **kwargs
        Passed through to ``plt.scatter``.
    """
    gene_name = namekey[gene]
    gene_wt = get_wt_states(gene_name)
    # Look the effects mapping up once by name; it was previously eval()-ed once per element.
    siffects = globals()['siffects_' + gene + '_log']

    # All tested states: pseudo-position -> native position (keeping the WT letter),
    # plus each position's own WT state, then -> '<quasipos><aa>'.
    muts_all = {m[0] + str(pseudopos_to_nativepos[int(m[1:-1])][genekey[gene]]) + m[-1] for m in siffects}
    muts_all = muts_all | {m[:-1] + m[0] for m in muts_all}
    muts_all = {str(nativepos_to_quasipos[int(m[1:-1])][gene_name]) + m[-1] for m in muts_all}

    # Tested states with log effect below `nok`, in the same '<quasipos><aa>' form.
    muts_nok = {m for m in siffects if siffects[m] < nok}
    muts_nok = {str(pseudopos_to_nativepos[int(m[1:-1])][genekey[gene]]) + m[-1] for m in muts_nok}
    muts_nok = {str(nativepos_to_quasipos[int(m[:-1])][gene_name]) + m[-1] for m in muts_nok}

    for extant in wts:
        if extant == 'RrGFP' or extant == gene_name:
            continue
        extant_states = {s for s in get_wt_states(extant) if s not in gene_wt}
        overlap = extant_states & muts_all
        if len(overlap) > overlap_n:
            nok_in_gene = overlap & muts_nok
            aa_id = 100 - pairwise[frozenset([extant, gene_name])]
            plt.scatter(aa_id, len(nok_in_gene) / len(overlap), color=colors[gene], **kwargs)

# Other

Adapted from this Stack Overflow answer: https://stackoverflow.com/questions/35042255/how-to-plot-multiple-seaborn-jointplot-in-subplot

In [None]:
import matplotlib.gridspec as gridspec

class SeabornFig2Grid():
    """Move the axes of a seaborn FacetGrid, PairGrid or JointGrid into one cell of
    another matplotlib figure, so figure-level seaborn plots can be used as subplots.

    Adapted from https://stackoverflow.com/questions/35042255 .

    Everything happens in the constructor: the seaborn grid's axes are re-parented
    into `fig` inside `subplot_spec`, and the seaborn figure is then closed.

    NOTE(review): relies on matplotlib internals (re-parenting an Axes via
    `ax.figure = ...`, the private `ax._subplotspec`); may break on other
    matplotlib versions — confirm with the installed version.
    """

    def __init__(self, seaborngrid, fig,  subplot_spec):
        """
        Parameters
        ----------
        seaborngrid : seaborn FacetGrid, PairGrid or JointGrid
            Grid whose axes are moved. Any other object is not moved, but its
            `.fig` is still closed by `_finalize`.
        fig : matplotlib Figure
            Figure that receives the axes.
        subplot_spec : SubplotSpec
            Cell of `fig` that the grid's axes are laid out inside.
        """
        self.fig = fig
        self.sg = seaborngrid
        self.subplot = subplot_spec
        # FacetGrid/PairGrid expose a 2-D `axes` array; JointGrid has named joint/marginal axes.
        if isinstance(self.sg, sns.axisgrid.FacetGrid) or \
            isinstance(self.sg, sns.axisgrid.PairGrid):
            self._movegrid()
        elif isinstance(self.sg, sns.axisgrid.JointGrid):
            self._movejointgrid()
        self._finalize()

    def _movegrid(self):
        """Move every axes of a PairGrid/FacetGrid into an n x m sub-grid of `self.subplot`.

        NOTE(review): assumes `self.sg.axes` is 2-D (e.g. a FacetGrid without
        col_wrap); `.shape[1]` fails on a 1-D array — TODO confirm for callers.
        """
        self._resize()
        n = self.sg.axes.shape[0]
        m = self.sg.axes.shape[1]
        self.subgrid = gridspec.GridSpecFromSubplotSpec(n,m, subplot_spec=self.subplot)
        for i in range(n):
            for j in range(m):
                self._moveaxes(self.sg.axes[i,j], self.subgrid[i,j])

    def _movejointgrid(self):
        """Move a JointGrid's joint and two marginal axes into `self.subplot`.

        The sub-grid is (r+1) x (r+1), where r is the rounded ratio of the joint
        axes' height to the top marginal's height in the seaborn figure. The
        marginals therefore take one row/column, roughly as in the original layout.
        """
        # Heights are read before _resize() changes the seaborn figure size.
        h= self.sg.ax_joint.get_position().height
        h2= self.sg.ax_marg_x.get_position().height
        r = int(np.round(h/h2))
        self._resize()
        self.subgrid = gridspec.GridSpecFromSubplotSpec(r+1,r+1, subplot_spec=self.subplot)

        # joint: all rows but the first, all columns but the last
        self._moveaxes(self.sg.ax_joint, self.subgrid[1:, :-1])
        # top marginal: first row, above the joint axes
        self._moveaxes(self.sg.ax_marg_x, self.subgrid[0, :-1])
        # right marginal: last column, beside the joint axes
        self._moveaxes(self.sg.ax_marg_y, self.subgrid[1:, -1])

    def _moveaxes(self, ax, gs):
        """Re-parent `ax` from the seaborn figure into `self.fig` at SubplotSpec `gs`."""
        #https://stackoverflow.com/a/46906599/4124317
        ax.remove()
        ax.figure=self.fig
        # NOTE(review): on newer matplotlib `Figure.axes` returns a copy, so this append
        # may be a no-op; `add_axes` below is what registers the axes — confirm.
        self.fig.axes.append(ax)
        self.fig.add_axes(ax)
        # Set both the private attribute and the public setter so position and
        # subplotspec agree (kept as in the linked answer).
        ax._subplotspec = gs
        ax.set_position(gs.get_position(self.fig))
        ax.set_subplotspec(gs)

    def _finalize(self):
        """Close the original seaborn figure so it is not shown as a separate figure.

        NOTE(review): `.fig` is the older seaborn attribute name (newer versions
        prefer `.figure`) — confirm against the installed seaborn.
        """
        plt.close(self.sg.fig)
#         self.fig.canvas.mpl_connect("resize_event", self._resize)
#         self.fig.canvas.draw()

    def _resize(self, evt=None):
        """Give the seaborn figure the same size as the target figure.

        `evt` is unused; the signature matches a matplotlib 'resize_event'
        callback (see the commented-out mpl_connect in `_finalize`).
        """
        self.sg.fig.set_size_inches(self.fig.get_size_inches())
#         self.sg.fig.set_size_inches([3,2])