In [None]:
import base
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from pathlib import Path
import rushd as rd
import scipy as sp
import seaborn as sns

from statannotations.Annotator import Annotator

from importlib import reload
# Re-import the local `base` module so edits to it are picked up in this kernel
reload(base)

# Global plot styling
sns.set_style('ticks')
sns.set_context(
    'talk',
    rc={
        'font.family': 'sans-serif',
        'font.sans-serif': ['Helvetica Neue'],
    },
)

In [None]:
# Setup data loading
# `plates` has one row per plate: where its exported data lives, which YAML
# file holds its well metadata, and the plate-level annotations (exp, cell, dox).
base_path = rd.datadir/'instruments'/'data'/'attune'
plates = pd.DataFrame({
    'data_path': [
        base_path/'chris'/'2024.06.15-rat-neurons'/'export',
        base_path/'kasey'/'2024.11.12_exp098.2'/'export',
        base_path/'kasey'/'2024.11.23_exp098.4'/'export',
    ],
    'yaml_path': [
        base_path/'chris'/'2024.06.15-rat-neurons'/'metadata.yaml',
        base_path/'kasey'/'2024.11.12_exp098.2'/'export'/'wells.yaml',
        base_path/'kasey'/'2024.11.23_exp098.4'/'export'/'wells.yaml',
    ],
    'exp': ['exp098', 'exp098.2', 'exp098.4'],
    'cell': ['neuron', 'neuron', 'neuron'],
    'dox': [1000, 1000, 1000],
})

# Figures go under output_path; the loaded events are cached as parquet at cache_path
output_path = rd.rootdir/'output'/'lenti_neuron'
cache_path = output_path/'lenti_neuron.gzip'

# Plot the well metadata once per distinct YAML file
for yaml_file in plates['yaml_path'].unique():
    rd.plot.plot_well_metadata(yaml_file)

In [None]:
# Load data
data = pd.DataFrame()
channel_list = ['mRuby2-A','mGL-A']

if not cache_path.exists():
    # No cache yet: load the raw events, keep only rows that are strictly
    # positive in every channel, then write the cache
    data = rd.flow.load_groups_with_metadata(plates, columns=channel_list)
    positive_in_all_channels = (data[channel_list] > 0).all(axis=1)
    data = data[positive_in_all_channels]
    data.to_parquet(rd.outfile(cache_path))
else:
    # Cache hit: the stored frame was already filtered when it was written
    data = pd.read_parquet(cache_path)
display(data)

In [None]:
# Add metadata for constructs
metadata = base.get_metadata(
    rd.datadir/'projects'/'miR-iFFL'/'plasmids'/'construct-metadata.xlsx'
)
# Left join on 'construct': every event row is kept, and constructs that are
# missing from the metadata table end up with NaN in the metadata columns.
# NOTE: running this cell twice merges the metadata columns in a second time.
data = pd.merge(data, metadata, how='left', on='construct')
display(data)

In [None]:
# Create dicts to specify colors/markers
# to_dict(orient='dict') yields column -> {construct: value}, so each metadata
# column can be used directly as a construct-keyed lookup
metadata_dict = metadata.set_index('construct').to_dict(orient='dict')
main_palette, main_markers = metadata_dict['color'], metadata_dict['markers']

In [None]:
# Per-experiment gates: the 99.99th percentile of each channel among the
# untransduced ('UT') events of that experiment
channel_list = ['mGL-A', 'mRuby2-A',]
untransduced = data[data['construct']=='UT']
gates = untransduced.groupby('exp')[channel_list].quantile(0.9999).reset_index()

# Indicate which channels are relevant for each experiment:
# mGL-A is the marker channel and mRuby2-A is the output channel.
# (The former string placeholders 'mGL-A'/'mRuby2-A' were overwritten
# immediately by these numeric gate values, so they have been dropped.)
gates['marker'] = gates['mGL-A']
gates['output'] = gates['mRuby2-A']

display(gates)

In [None]:
# Alias the raw channels under the generic names used by the rest of the notebook
for alias, source_channel in (('marker', 'mGL-A'), ('output', 'mRuby2-A')):
    data[alias] = data[source_channel]

In [None]:
# Joint marker/output density of every non-UT construct at dox==1000,
# faceted by experiment (rows) and construct (columns)
is_dox_1000 = data['dox']==1000
is_not_ut = data['construct']!='UT'
plot_df = data[is_dox_1000 & is_not_ut]
g = sns.displot(
    data=plot_df, x='marker', y='output', kind='kde',
    hue='construct', palette=main_palette,
    row='exp', col='construct', facet_kws=dict(margin_titles=True),
    log_scale=True, common_norm=False, levels=8,
)

# Overlay the gates on each facet: the experiment's marker and output gates
# from `gates`, plus a fixed vertical reference line at 1e3
gate_line_style = dict(c='black', ls=':', zorder=0)
for (exp, construct), ax in g.axes_dict.items():
    exp_gates = gates[gates['exp']==exp]
    ax.axvline(exp_gates['marker'].values[0], **gate_line_style)
    ax.axvline(1e3, **gate_line_style)
    ax.axhline(exp_gates['output'].values[0], **gate_line_style)

In [None]:
# Replace the marker gate with a fixed threshold of 1e3 for every experiment.
# A scalar broadcasts to all rows, so this keeps working if the number of
# experiments in `gates` changes (a fixed-length list would raise on mismatch).
gates['marker'] = 1e3

In [None]:
# Gate data by marker expression
def gate_data(df, gates):
    """Flag which rows of a single-experiment frame exceed that experiment's marker gate.

    Parameters
    ----------
    df : DataFrame with 'exp' and 'marker' columns. Only the first row's 'exp'
        is read, so the frame is expected to hold one experiment
        (e.g. one group of data.groupby('exp')).
    gates : DataFrame with an 'exp' column and a 'marker' column holding the
        gate value for each experiment.

    Returns
    -------
    A copy of `df` with a boolean 'expressing' column (marker strictly greater
    than the gate). The input frame is not modified.

    Raises
    ------
    IndexError if `df` is empty; KeyError if the experiment has no entry in `gates`.
    """
    experiment = df['exp'].iloc[0]
    marker_gate_by_exp = gates.set_index('exp')['marker'].to_dict()  # format: exp -> gate value
    threshold = marker_gate_by_exp[experiment]
    return df.assign(expressing=df['marker'] > threshold)

# Gate each experiment with its own threshold, then drop the group index
data = (
    data.groupby('exp')[data.columns]
    .apply(lambda group: gate_data(group, gates))
    .reset_index(drop=True)
)

# Working subset: events above the marker gate, excluding the 'UT' construct
is_expressing_non_ut = data['expressing'] & (data['construct']!='UT')
df = data[is_expressing_non_ut]

In [None]:
# Bin data and calculate statistics
df_quantiles, stats, _, fits = base.calculate_bins_stats(
    df,
    by=['construct','moi','dox','exp','biorep'],
    num_bins=10,
)
# Attach the construct metadata to both summary tables
stats, fits = (
    table.merge(metadata, how='left', on='construct')
    for table in (stats, fits)
)

In [None]:
# Since there is no marker-only condition, save the output expression stats for untransduced cells
# Result: one row per (exp, biorep) holding the geometric mean of 'output' among 'UT' events
baseline_df = (
    data.loc[data['construct']=='UT']
    .groupby(['exp','biorep'])['output']
    .apply(sp.stats.gmean)
    .rename('output_gmean')
    .reset_index()
)

In [None]:
# Joint plot for biorep 1 (moi==7, dox==1000): binned marker vs. output on the
# left, output density on the right, with the untransduced baseline marked
biorep = 1
selected = (df_quantiles['moi']==7) & (df_quantiles['dox']==1000) & (df_quantiles['biorep']==biorep)
plot_df = df_quantiles[selected]
fig, axes = plt.subplots(1,2, gridspec_kw=dict(width_ratios=(1,0.3)))
line_ax, kde_ax = axes

# Left panel: geometric mean of output per marker bin; the error band spans
# gmean / gstd to gmean * gstd
sns.lineplot(
    data=plot_df, x='bin_marker_quantiles_median', y='output', ax=line_ax,
    hue='construct', palette=main_palette,
    style='construct', markers=main_markers, dashes=False,
    markersize=9, markeredgewidth=1, legend=False,
    estimator=sp.stats.gmean,
    errorbar=lambda x: (sp.stats.gmean(x) / sp.stats.gstd(x), sp.stats.gmean(x) * sp.stats.gstd(x)),
)
line_ax.set(xscale='log', yscale='log', xlabel='marker')
sns.despine(ax=line_ax)

# Untransduced reference level.
# NOTE(review): this averages every (exp, biorep) row of baseline_df, not just
# the replicate plotted here - confirm that is intended.
baseline = baseline_df['output_gmean'].mean()
line_ax.axhline(baseline, color='black', ls=':')
line_ax.annotate('untransduced', (line_ax.get_xlim()[1], baseline), ha='right', va='bottom')

# Right panel: output density per construct, sharing the left panel's y-limits
sns.kdeplot(
    data=plot_df, y='output', ax=kde_ax,
    hue='construct', palette=main_palette,
    legend=False, log_scale=True, common_norm=False,
)
sns.despine(ax=kde_ax, bottom=True)
kde_ax.set(xlabel='', ylim=line_ax.get_ylim(), ylabel='', yticklabels=[])
kde_ax.get_xaxis().set_visible(False)

fig.savefig(rd.outfile(rd.rootdir/'output'/'for-review'/'joint_neuron-with-baseline.png'))

In [None]:
# Joint plot for biorep 2 (moi==1, dox==1000), leaving out constructs whose
# name contains 'FXN' or 'FMRP'
biorep = 2
has_fxn = df_quantiles['name'].str.contains('FXN')
has_fmrp = df_quantiles['name'].str.contains('FMRP')
selected = ((df_quantiles['dox']==1000) & (df_quantiles['biorep']==biorep) & (df_quantiles['moi']==1)
            & ~has_fxn & ~has_fmrp)
plot_df = df_quantiles[selected]
fig, axes = plt.subplots(1,2, gridspec_kw=dict(width_ratios=(1,0.3)))
line_ax, kde_ax = axes

# Left panel: geometric mean of output per marker bin; the error band spans
# gmean / gstd to gmean * gstd
sns.lineplot(
    data=plot_df, x='bin_marker_quantiles_median', y='output', ax=line_ax,
    hue='construct', palette=main_palette,
    style='construct', markers=main_markers, dashes=False,
    markersize=9, markeredgewidth=1, legend=False,
    estimator=sp.stats.gmean,
    errorbar=lambda x: (sp.stats.gmean(x) / sp.stats.gstd(x), sp.stats.gmean(x) * sp.stats.gstd(x)),
)
line_ax.set(xscale='log', yscale='log', xlabel='marker')
sns.despine(ax=line_ax)

# Right panel: output density per construct, sharing the left panel's y-limits
sns.kdeplot(
    data=plot_df, y='output', ax=kde_ax,
    hue='construct', palette=main_palette,
    legend=False, log_scale=True, common_norm=False,
)
sns.despine(ax=kde_ax, bottom=True)
kde_ax.set(xlabel='', ylim=line_ax.get_ylim(), ylabel='', yticklabels=[])
kde_ax.get_xaxis().set_visible(False)

In [None]:
# Joint plot for biorep 2 (moi==7, dox==1000), leaving out constructs whose
# name contains 'FXN' or 'FMRP'
biorep = 2
has_fxn = df_quantiles['name'].str.contains('FXN')
has_fmrp = df_quantiles['name'].str.contains('FMRP')
selected = ((df_quantiles['dox']==1000) & (df_quantiles['biorep']==biorep) & (df_quantiles['moi']==7)
            & ~has_fxn & ~has_fmrp)
plot_df = df_quantiles[selected]
fig, axes = plt.subplots(1,2, gridspec_kw=dict(width_ratios=(1,0.3)))
line_ax, kde_ax = axes

# Left panel: geometric mean of output per marker bin; the error band spans
# gmean / gstd to gmean * gstd
sns.lineplot(
    data=plot_df, x='bin_marker_quantiles_median', y='output', ax=line_ax,
    hue='construct', palette=main_palette,
    style='construct', markers=main_markers, dashes=False,
    markersize=9, markeredgewidth=1, legend=False,
    estimator=sp.stats.gmean,
    errorbar=lambda x: (sp.stats.gmean(x) / sp.stats.gstd(x), sp.stats.gmean(x) * sp.stats.gstd(x)),
)
line_ax.set(xscale='log', yscale='log', xlabel='marker')
sns.despine(ax=line_ax)

# Right panel: output density per construct, sharing the left panel's y-limits
sns.kdeplot(
    data=plot_df, y='output', ax=kde_ax,
    hue='construct', palette=main_palette,
    legend=False, log_scale=True, common_norm=False,
)
sns.despine(ax=kde_ax, bottom=True)
kde_ax.set(xlabel='', ylim=line_ax.get_ylim(), ylabel='', yticklabels=[])
kde_ax.get_xaxis().set_visible(False)

In [None]:
# Joint plot for biorep 3 (moi==1, dox==1000), leaving out constructs whose
# name contains 'FXN' or 'FMRP'
biorep = 3
has_fxn = df_quantiles['name'].str.contains('FXN')
has_fmrp = df_quantiles['name'].str.contains('FMRP')
selected = ((df_quantiles['dox']==1000) & (df_quantiles['biorep']==biorep) & (df_quantiles['moi']==1)
            & ~has_fxn & ~has_fmrp)
plot_df = df_quantiles[selected]
fig, axes = plt.subplots(1,2, gridspec_kw=dict(width_ratios=(1,0.3)))
line_ax, kde_ax = axes

# Left panel: geometric mean of output per marker bin; the error band spans
# gmean / gstd to gmean * gstd
sns.lineplot(
    data=plot_df, x='bin_marker_quantiles_median', y='output', ax=line_ax,
    hue='construct', palette=main_palette,
    style='construct', markers=main_markers, dashes=False,
    markersize=9, markeredgewidth=1, legend=False,
    estimator=sp.stats.gmean,
    errorbar=lambda x: (sp.stats.gmean(x) / sp.stats.gstd(x), sp.stats.gmean(x) * sp.stats.gstd(x)),
)
line_ax.set(xscale='log', yscale='log', xlabel='marker')
sns.despine(ax=line_ax)

# Right panel: output density per construct, sharing the left panel's y-limits
sns.kdeplot(
    data=plot_df, y='output', ax=kde_ax,
    hue='construct', palette=main_palette,
    legend=False, log_scale=True, common_norm=False,
)
sns.despine(ax=kde_ax, bottom=True)
kde_ax.set(xlabel='', ylim=line_ax.get_ylim(), ylabel='', yticklabels=[])
kde_ax.get_xaxis().set_visible(False)

In [None]:
# Joint plot for biorep 4 (moi==1, dox==1000), leaving out constructs whose
# name contains 'FXN' or 'FMRP'
biorep = 4
has_fxn = df_quantiles['name'].str.contains('FXN')
has_fmrp = df_quantiles['name'].str.contains('FMRP')
selected = ((df_quantiles['dox']==1000) & (df_quantiles['biorep']==biorep) & (df_quantiles['moi']==1)
            & ~has_fxn & ~has_fmrp)
plot_df = df_quantiles[selected]
fig, axes = plt.subplots(1,2, gridspec_kw=dict(width_ratios=(1,0.3)))
line_ax, kde_ax = axes

# Left panel: geometric mean of output per marker bin; the error band spans
# gmean / gstd to gmean * gstd
sns.lineplot(
    data=plot_df, x='bin_marker_quantiles_median', y='output', ax=line_ax,
    hue='construct', palette=main_palette,
    style='construct', markers=main_markers, dashes=False,
    markersize=9, markeredgewidth=1, legend=False,
    estimator=sp.stats.gmean,
    errorbar=lambda x: (sp.stats.gmean(x) / sp.stats.gstd(x), sp.stats.gmean(x) * sp.stats.gstd(x)),
)
line_ax.set(xscale='log', yscale='log', xlabel='marker')
sns.despine(ax=line_ax)

# Right panel: output density per construct, sharing the left panel's y-limits
sns.kdeplot(
    data=plot_df, y='output', ax=kde_ax,
    hue='construct', palette=main_palette,
    legend=False, log_scale=True, common_norm=False,
)
sns.despine(ax=kde_ax, bottom=True)
kde_ax.set(xlabel='', ylim=line_ax.get_ylim(), ylabel='', yticklabels=[])
kde_ax.get_xaxis().set_visible(False)

In [None]:
# Tick labels shown for each 'ts_kind' value
ts_label = {
    'na': 'base',
    'NT': 'OL',
    'T': 'CL',
    'none': '–',
}
# Marker shapes; the summary plots pick one by position in their moi_list
marker_list = [
    'o',
    'v',
    'D',
    'X',
]

In [None]:
# Preview the summary stats at dox==1000, without the FXN / FMRP constructs
is_dox_1000 = stats['dox']==1000
has_fxn = stats['name'].str.contains('FXN')
has_fmrp = stats['name'].str.contains('FMRP')
plot_df = stats[is_dox_1000 & ~has_fxn & ~has_fmrp]
display(plot_df)

In [None]:
# 'ts_kind' pairs compared by the statistical annotations
pairs = [
    ('na', 'NT'),
    ('na', 'T'),
    ('NT', 'T'),
]

In [None]:
# Summary statistics at moi==7: output geometric mean, output std, and slope,
# each shown per 'ts_kind'
fig, axes = plt.subplots(1,3, figsize=(10,4), gridspec_kw=dict(wspace=0.5,))

# Row filters: dox==1000, no FXN / FMRP constructs, moi==7 (plus a minimum count for stats)
stats_keep = (stats['dox']==1000) & ~stats['name'].str.contains('FXN') & ~stats['name'].str.contains('FMRP')
fits_keep = (fits['dox']==1000) & ~fits['name'].str.contains('FXN') & ~fits['name'].str.contains('FMRP')
plot_df = stats[stats_keep & (stats['count']>100) & (stats['moi']==7)]
plot_df2 = fits[fits_keep & (fits['moi']==7)]
moi_list = [1,5,7]

# One entry per panel: (axes, table, y column, title, extra Axes.set kwargs, annotate?)
panels = [
    (axes[0], plot_df,  'output_gmean', 'Mean',  dict(yscale='log'), True),
    (axes[1], plot_df,  'output_std',   'Std.',  dict(yscale='log'), False),
    (axes[2], plot_df2, 'slope',        'Slope', dict(),             False),
]
for ax, table, y_col, title, scale_kwargs, annotate in panels:
    # The marker shape encodes the moi value (by its position in moi_list)
    for num, group in table.groupby('moi'):
        sns.stripplot(data=group, x='ts_kind', y=y_col, hue='construct', palette=main_palette,
                      legend=False, ax=ax, marker=marker_list[moi_list.index(num)], s=8,
                      edgecolor='white', linewidth=1)
    ax.set(title=title, xlabel='', ylabel='', **scale_kwargs)

    if annotate:
        # Independent t-tests between the ts_kind pairs, drawn as stars
        annotator = Annotator(ax, pairs, data=table, x='ts_kind', y=y_col,)
        annotator.configure(test='t-test_ind', text_format='star', loc='inside', line_height=0,
                            line_width=0.5, text_offset=-2, line_offset_to_group=0.2)
        annotator.apply_and_annotate()

# Swap the raw ts_kind tick labels for their display names
for ax in axes:
    display_labels = [ts_label[tick.get_text()] for tick in ax.get_xticklabels()]
    ax.set_xticklabels(display_labels, rotation=45, ha='right',)
    sns.despine(ax=ax)

fig.savefig(rd.outfile(output_path/'stats_moi7.png'))

In [None]:
# Summary statistics at moi==1: output geometric mean, output std, and slope,
# each shown per 'ts_kind' with pairwise significance annotations
fig, axes = plt.subplots(1,3, figsize=(10,4), gridspec_kw=dict(wspace=0.5,))

# Row filters: dox==1000, no FXN / FMRP constructs, moi==1 (plus a minimum count for stats)
stats_keep = (stats['dox']==1000) & ~stats['name'].str.contains('FXN') & ~stats['name'].str.contains('FMRP')
fits_keep = (fits['dox']==1000) & ~fits['name'].str.contains('FXN') & ~fits['name'].str.contains('FMRP')
plot_df = stats[stats_keep & (stats['count']>10) & (stats['moi']==1)]
plot_df2 = fits[fits_keep & (fits['moi']==1)]
moi_list = [1,5,7]

# One entry per panel: (axes, table, y column, title, extra Axes.set kwargs)
panels = [
    (axes[0], plot_df,  'output_gmean', 'Mean',  dict(yscale='log')),
    (axes[1], plot_df,  'output_std',   'Std.',  dict(yscale='log')),
    (axes[2], plot_df2, 'slope',        'Slope', dict()),
]
for ax, table, y_col, title, scale_kwargs in panels:
    # The marker shape encodes the moi value (by its position in moi_list)
    for num, group in table.groupby('moi'):
        sns.stripplot(data=group, x='ts_kind', y=y_col, hue='construct', palette=main_palette,
                      legend=False, ax=ax, marker=marker_list[moi_list.index(num)], s=8,
                      edgecolor='white', linewidth=1)
    ax.set(title=title, xlabel='', ylabel='', **scale_kwargs)

    # Independent t-tests between the ts_kind pairs, drawn as stars.
    # The annotator is given the same table the panel's points come from, so the
    # slope panel is tested on plot_df2 (fits) rather than on the stats table.
    annotator = Annotator(ax, pairs, data=table, x='ts_kind', y=y_col,)
    annotator.configure(test='t-test_ind', text_format='star', loc='inside', line_height=0,
                        line_width=0.5, text_offset=-2, line_offset_to_group=0.2)
    annotator.apply_and_annotate()

# Swap the raw ts_kind tick labels for their display names
for ax in axes:
    display_labels = [ts_label[tick.get_text()] for tick in ax.get_xticklabels()]
    ax.set_xticklabels(display_labels, rotation=45, ha='right',)
    sns.despine(ax=ax)

fig.savefig(rd.outfile(output_path/'stats_moi1.png'))

### Look at FXN/FMRP

In [None]:
# Marker vs. output for the constructs whose name contains 'FXN'.
# na=False: `name` comes from a left merge with the construct metadata, so it
# is NaN for any construct missing from that table; without na=False those
# rows put NaN in the mask and boolean indexing fails.
plot_df = data[data['name'].str.contains('FXN', na=False)]
g = sns.scatterplot(data=plot_df, x='marker', y='output', hue='construct', palette=main_palette,
                    alpha=0.5)
g.set(xscale='log', yscale='log')
# Vertical guides at the two marker thresholds being compared
for marker_threshold in (2e2, 1e3):
    g.axvline(marker_threshold, color='black', zorder=0)

In [None]:
# Marker vs. output for the constructs whose name contains 'FMRP'.
# na=False: `name` comes from a left merge with the construct metadata, so it
# is NaN for any construct missing from that table; without na=False those
# rows put NaN in the mask and boolean indexing fails.
plot_df = data[data['name'].str.contains('FMRP', na=False)]
g = sns.scatterplot(data=plot_df, x='marker', y='output', hue='construct', palette=main_palette,
                    alpha=0.5)
g.set(xscale='log', yscale='log')
# Vertical guides at the two marker thresholds being compared
for marker_threshold in (2e2, 1e3):
    g.axvline(marker_threshold, color='black', zorder=0)

In [None]:
# Bin data and calculate statistics
# Same binning as before, but on every event with marker > 2e2
above_low_gate = data[data['marker']>2e2]
df_quantiles2, stats2, _, fits2 = base.calculate_bins_stats(
    above_low_gate,
    by=['construct','moi','dox','exp','biorep'],
    num_bins=10,
)
# Attach the construct metadata to both summary tables
stats2, fits2 = (
    table.merge(metadata, how='left', on='construct')
    for table in (stats2, fits2)
)
display(stats2)