In [None]:
import base
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import rushd as rd
import scipy as sp
import seaborn as sns

from importlib import reload
# Re-import the local `base` helper module so edits to it are picked up
reload(base)

# Global seaborn styling shared by every figure in this notebook
font_settings = {'font.family': 'sans-serif', 'font.sans-serif':['Helvetica Neue']}
sns.set_style('ticks')
sns.set_context('talk', rc=font_settings)

In [None]:
# Folders holding the flow cytometry data, one per experiment/plate
base_path = rd.datadir/'instruments'/'data'/'attune'/'Emma'
replicates_path = base_path/'2024.04.06_EXP11_replicates'
exp09_path = base_path/'2022.10.04_EXP9'/'Data'
exp11_path = base_path/'2022.10.11_EXP11'/'Data'
# NOTE(review): the EXP11 controls are read from a folder named 'EXP10' -- confirm this is intended
exp11_controls_path = base_path/'2022.10.11_EXP10'/'data_controls'
exp49_path = replicates_path/'Plate_1_EXP49'/'data_singlets'
exp50_path = replicates_path/'Plate_2_EXP50'/'data_singlets'
exp49_50_controls_path = replicates_path/'Plate_3_Controls'/'data_singlets'

# One row per plate: (data folder, well-layout yaml, biological replicate, experiment label)
plate_rows = [
    (exp09_path,                          exp09_path/'wells_KL.yaml',              1, 'ELP_exp09'),
    (base_path/'2023.01.16_EXP12'/'Data', exp09_path/'wells_KL.yaml',              2, 'ELP_exp12'),
    (base_path/'2023.02.09_EXP13'/'Data', exp09_path/'wells_KL.yaml',              3, 'ELP_exp13'),
    (exp11_path,                          exp11_path/'wells_KL.yaml',              1, 'ELP_exp11'),
    (exp11_controls_path,                 exp11_controls_path/'wells_KL.yaml',     1, 'ELP_exp11'),
    (exp49_path,                          exp11_path/'wells_KL.yaml',              2, 'ELP_exp49'),
    (exp50_path,                          exp11_path/'wells_KL.yaml',              3, 'ELP_exp50'),
    (exp49_50_controls_path,              exp49_50_controls_path/'wells_KL.yaml',  2, 'ELP_exp49'),
    (exp49_50_controls_path,              exp49_50_controls_path/'wells2_KL.yaml', 3, 'ELP_exp50'),
]
plates = pd.DataFrame(plate_rows, columns=['data_path', 'yaml_path', 'biorep', 'exp'])

output_path = rd.rootdir/'output'/'miR-characterization'
cache_path = output_path/'data.gzip'

# Load data: reuse the parquet cache when it exists, otherwise read the plates and build it
if cache_path.is_file():
    data = pd.read_parquet(cache_path)
else:
    channel_list = ['mRuby2-A','FSC-A','SSC-A','mGL-A']
    data = rd.flow.load_groups_with_metadata(plates, columns=channel_list)

    # Keep only rows that are strictly positive in every channel, then drop rows with missing values
    all_positive = (data[channel_list] > 0).all(axis=1)
    data = data[all_positive].dropna()
    data.to_parquet(rd.outfile(cache_path))

# Add metadata for constructs
plasmid_dir = rd.datadir/'projects'/'miR-iFFL'/'plasmids'
metadata_miR = pd.read_excel(plasmid_dir/'miR-metadata.xlsx')
metadata_ts = pd.read_excel(plasmid_dir/'ts-metadata.xlsx')
data = (
    data
    .merge(metadata_miR, how='left', on='miR_construct')
    .merge(metadata_ts, how='left', on='ts_construct')
)
data['constructs'] = data['miR_construct'] + '_' + data['ts_construct']
display(data)

In [None]:
# Gate cells: per experiment, the gate for each channel is the 99.9th percentile
# of the rows whose ts_construct is 'UT'
channel_list = ['mGL-A', 'mRuby2-A']
ut_data = data[data['ts_construct']=='UT']
gates = pd.DataFrame()
for channel in channel_list:
    gates[channel] = ut_data.groupby(['exp'])[channel].apply(lambda x: x.quantile(0.999))

# Indicate which channels are relevant for each experiment
gates = gates.reset_index().sort_values(['exp'])
gates['marker'] = 'mRuby2-A'
gates['output'] = 'mGL-A'

# Gate data by transfection marker expression
data = (
    data.groupby('exp')[data.columns]
    .apply(lambda x: base.gate_data(x,gates))
    .reset_index(drop=True)
)
is_expressing = data['expressing']
not_ut = data['ts_construct']!='UT'
df = data[is_expressing & not_ut]

In [None]:
# KDE of the 'marker' column for every miR x ts combination, one figure per
# experiment, with that experiment's marker gate drawn as a vertical line
order = ['FF3','FF4','FF5','FF6','none']
has_construct_info = (data['miR']!='na') & (data['ts']!='na')
plot_df = data[has_construct_info]
for exp, group in plot_df.groupby('exp'):
    g = sns.displot(
        data=group, x='marker', hue='ts_num', kind='kde', log_scale=True,
        row='miR', col='ts', row_order=order, col_order=order,
        facet_kws=dict(margin_titles=True), common_norm=False, height=3,
    )

    # The gate depends only on the experiment, so it is looked up once per figure
    gate_row = gates.loc[gates['exp']==exp]
    marker = gate_row['marker'].values[0]
    gate = gate_row[marker].values[0]
    for ax in g.axes_dict.values():
        ax.axvline(gate, color='black')

    g.figure.savefig(rd.outfile(output_path/(f'hist_gate-marker_{exp}.svg')), bbox_inches='tight')

In [None]:
# KDE of the 'output' column for every miR x ts combination in `df`, one figure
# per experiment, with that experiment's output gate drawn as a vertical line
plot_df = df[(df['miR']!='na') & (df['ts']!='na')]
order = ['FF3','FF4','FF5','FF6','none']
for exp, group in plot_df.groupby('exp'):
    g = sns.displot(data=group, row='miR', col='ts', x='output', hue='ts_num', kind='kde', log_scale=True,
                    facet_kws=dict(margin_titles=True), col_order=order, row_order=order, common_norm=False,
                    height=3)

    # The gate depends only on the experiment, so look it up once per figure
    gate_row = gates.loc[gates['exp']==exp]
    output = gate_row['output'].values[0]  # name of this experiment's output channel
    gate = gate_row[output].values[0]      # gate value stored for that channel
    for ax in g.axes_dict.values():
        ax.axvline(gate, color='black')

    # Put the experiment in the filename; without it every iteration wrote to the
    # same path and only the last experiment's figure survived
    g.figure.savefig(rd.outfile(output_path/(f'hist_gate-output_{exp}.svg')), bbox_inches='tight')

In [None]:
# Gate data on output expression
def gate_output(df, gates):
    """Flag rows whose output-channel value exceeds the experiment's gate.

    Parameters
    ----------
    df : DataFrame containing a single experiment (the experiment label is read
        from the first row of its 'exp' column, as when called per group of
        ``groupby('exp')``).
    gates : DataFrame with an 'exp' column, an 'output' column naming the output
        channel for each experiment, and one column per channel holding the gate.

    Returns
    -------
    A copy of ``df`` whose boolean 'expressing' column is True where the output
    channel is strictly above the gate. The input frame is not modified.
    """
    # Lookup table in the form column -> {exp: value}
    lookup = gates.set_index('exp').to_dict('dict')
    exp_label = df['exp'].values[0]
    channel = lookup['output'][exp_label]
    threshold = lookup[channel][exp_label]
    return df.assign(expressing=df[channel] > threshold)

# Re-gate each experiment on its output channel, then keep only the rows above the gate
df = (
    df.groupby('exp')[df.columns]
    .apply(lambda x: gate_output(x,gates))
    .reset_index(drop=True)
)
df2 = df[df['expressing']]

In [None]:
# Calculate statistics
# 'std' is passed by name rather than as np.std: pandas currently substitutes its
# own groupby std for np.std (with a FutureWarning on recent versions) and plans
# to call np.std directly in the future, which would change the result. The
# string keeps today's behavior and the same 'std' column label.
stat_list = ['std', sp.stats.gmean]
grouped = df2.groupby(by=['miR_construct','ts_construct','biorep','exp'])
stats2 = grouped[['marker','output']].agg(stat_list).reset_index().dropna()

# Rename columns as 'col_stat'
stats2.columns = stats2.columns.map(lambda i: base.rename_multilevel_cols(i))

# Number of rows per group; the assignment aligns on the integer index that
# reset_index() produced on both sides, so rows removed by dropna() are skipped
stats2['count'] = grouped['output'].count().reset_index()['output']

# Re-attach construct metadata, which the aggregation dropped
stats2 = stats2.merge(metadata_miR, how='left', on='miR_construct')
stats2 = stats2.merge(metadata_ts, how='left', on='ts_construct')
stats2['constructs'] = stats2['miR_construct'] + '_' + stats2['ts_construct']

# Compute fold changes relative to no-TS conditions
def get_fc(df):
    """Add 'output_gmean_fc': 'output_gmean' divided by the no-target-site baseline.

    The baseline is the mean of 'output_gmean' over the rows where 'ts' equals
    'none'. Returns a copy of ``df`` with the new column; the input frame is not
    modified.
    """
    is_baseline = df['ts']=='none'
    baseline = df.loc[is_baseline, 'output_gmean'].mean()
    return df.assign(output_gmean_fc=df['output_gmean'] / baseline)

# Fold changes are computed separately for each miR construct / biorep / experiment
stats2 = (
    stats2.groupby(by=['miR_construct','biorep','exp'])[stats2.columns]
    .apply(get_fc)
    .reset_index(drop=True)
)

# orthogonal exp: {miR} x {TS}x2
orthogonal_exp = ['ELP_exp09', 'ELP_exp12', 'ELP_exp13']
in_orthogonal = stats2['exp'].isin(orthogonal_exp)
fcs_orthogonal = (
    stats2[in_orthogonal]
    .groupby(by=['miR_construct','ts_construct'])[['output_gmean_fc']]
    .apply('mean')
    .reset_index()
    .merge(metadata_miR, how='left', on='miR_construct')
    .merge(metadata_ts, how='left', on='ts_construct')
)

# matched exp: {miR} x TSx{n}
matched_exp = ['ELP_exp11','ELP_exp49', 'ELP_exp50',]
in_matched = stats2['exp'].isin(matched_exp)
fcs_matched = (
    stats2[in_matched]
    .groupby(by=['miR_construct','ts_construct'])[['output_gmean_fc']]
    .apply('mean')
    .reset_index()
    .merge(metadata_miR, how='left', on='miR_construct')
    .merge(metadata_ts, how='left', on='ts_construct')
)

In [None]:
# Create color palettes for miR/ts characterization
# One row per unique 'constructs' value. The explicit .copy() makes this an
# independent frame, so the column assignments below do not raise
# SettingWithCopyWarning.
metadata_comb = data.drop_duplicates(['constructs']).copy()

# Palette 1 (matched_palette): gray by default, green where miR equals ts,
# black where ts is 'none'
metadata_comb['color'] = base.colors['gray']
metadata_comb['matched'] = metadata_comb['miR'] == metadata_comb['ts']
metadata_comb.loc[metadata_comb['matched'], 'color'] = base.colors['green']
metadata_comb.loc[metadata_comb['ts']=='none', 'color'] = 'black'

# to_dict('dict') gives column -> {constructs: value}; the 'color' entry is the palette
metadata_comb_dict = metadata_comb.set_index('constructs').to_dict('dict')
matched_palette = metadata_comb_dict['color']

# Palette 2 (ts_num_palette): black by default, green where ts_num > 0, with
# base.get_light_color applied once for ts_num == 2 and twice for ts_num == 4
# (presumably each application lightens the color -- see `base`)
metadata_comb['color'] = 'black'
metadata_comb.loc[metadata_comb['ts_num']>0, 'color'] = base.colors['green']
metadata_comb.loc[metadata_comb['ts_num']==2, 'color'] = metadata_comb.loc[metadata_comb['ts_num']==2, 'color'].apply(base.get_light_color)
metadata_comb.loc[metadata_comb['ts_num']==4, 'color'] = metadata_comb.loc[metadata_comb['ts_num']==4, 'color'].apply(base.get_light_color).apply(base.get_light_color)

metadata_comb_dict = metadata_comb.set_index('constructs').to_dict('dict')
ts_num_palette = metadata_comb_dict['color']

In [None]:
# Heatmap of mean output fold change: rows are miRs, columns are target sites
fcs = fcs_orthogonal
selected = (
    (fcs['ts']!='na') & (fcs['miR']!='na') & (fcs['miR_promoter']=='hPGK.d')
    & (fcs['ts_num'].isin([0,2]))
)
plot_df = fcs[selected].pivot(index='miR', columns='ts', values='output_gmean_fc')
green_cmap = sns.light_palette(base.colors['green'], as_cmap=True)
g = sns.heatmap(plot_df, annot=True, fmt='.2f', cmap=green_cmap)
g.set(xlabel='target sites (x2)', ylabel='microRNA', title='Relative target expression')

# outline matched target sites (diagonal cells; the last row is left without an outline)
n_outlined = len(plot_df)-1
for i in range(n_outlined):
    outline = matplotlib.patches.Rectangle((i, i), 1, 1, fill=False, edgecolor='black', lw=2)
    g.add_patch(outline)

g.figure.savefig(rd.outfile(output_path/(f'heatmap_orthogonal-ts.svg')), bbox_inches='tight')

In [None]:
# Heatmap of mean output fold change: rows are miRs, columns are target-site counts
fcs = fcs_matched
selected = (fcs['miR_promoter']=='hPGK.d') & (fcs['miR']!='none')
plot_df = fcs[selected].pivot(index='miR', columns='ts_num', values='output_gmean_fc')
green_cmap = sns.light_palette(base.colors['green'], as_cmap=True)
g = sns.heatmap(plot_df, annot=True, fmt='.2f', cmap=green_cmap)
g.set(xlabel='number of matched\ntarget sites', ylabel='microRNA', title='Relative target expression')

g.figure.savefig(rd.outfile(output_path/(f'heatmap_ts-num.svg')), bbox_inches='tight')

In [None]:
# Strip plots of output geometric mean per target site, one panel per miR
fig, axes = plt.subplots(1,5, sharey=True, figsize=(10,4))

plot_df = stats2[(stats2['ts']!='na') & (stats2['miR']!='na') & (stats2['miR_promoter']=='hPGK.d') 
                 & (stats2['ts_num'].isin([0,2])) & (stats2['exp'].isin(orthogonal_exp))]

# zip() pairs panels with miR groups in order and stops at the shorter of the two
for ax, (miR, d) in zip(axes, plot_df.groupby('miR')): 
    sns.stripplot(d, x='ts', y='output_gmean', hue='constructs', palette=matched_palette, ax=ax, legend=False, 
                  s=10, jitter=0.2, linewidth=1, edgecolor='white')
    ax.set(yscale='log', xlabel='', title=f'{miR}')
    sns.despine(ax=ax)
    # Rotate via tick_params instead of set_xticklabels(get_xticklabels(), ...):
    # the latter fixes the label text without fixing the tick positions, which
    # matplotlib warns about
    ax.tick_params(axis='x', labelrotation=90)

axes[2].set(xlabel='target site (x2)')

# Place a 'miR:' label to the left of the first panel's title: x is taken relative
# to the first y tick label and y relative to the title text
ylabel_bbox = axes[0].get_yticklabels()[0]
title_bbox = axes[0].set_title(axes[0].get_title())  # re-setting the title returns its Text artist
axes[0].annotate(text='miR:', xy=(1, 0.5), xycoords=(ylabel_bbox, title_bbox), 
                 ha="right", va="center",)
axes[0].set(ylabel='target')
        
fig.savefig(rd.outfile(output_path/(f'stats_orthogonal-ts.svg')))

In [None]:
# Strip plots of output geometric mean per number of target sites, one panel per miR
fig, axes = plt.subplots(1,4, sharey=True, figsize=(10,4))

selected = (
    (stats2['miR_promoter']=='hPGK.d')
    & (stats2['miR']!='none')
    & (stats2['exp'].isin(matched_exp))
)
plot_df = stats2[selected]

# Keyword arguments shared by every panel
strip_kwargs = dict(x='ts_num', y='output_gmean', hue='constructs', palette=ts_num_palette,
                    legend=False, s=10, jitter=0.2, linewidth=1, edgecolor='white')
for ax, (miR, d) in zip(axes, plot_df.groupby('miR')):
    sns.stripplot(d, ax=ax, **strip_kwargs)
    ax.set(yscale='log', xlabel='', title=f'{miR}')
    sns.despine(ax=ax)

# Shared x label, anchored at the right edge of the second panel
axes[1].set_xlabel('number of matched target sites', x=1, ha='center')

# 'miR:' label positioned relative to the first y tick label (x) and the title text (y)
ylabel_bbox = axes[0].get_yticklabels()[0]
title_bbox = axes[0].set_title(axes[0].get_title())
axes[0].annotate(
    text='miR:', xy=(1, 0.5), xycoords=(ylabel_bbox, title_bbox),
    ha="right", va="center",
)
axes[0].set(ylabel='target')

fig.savefig(rd.outfile(output_path/(f'stats_ts-num.svg')))