In [None]:
# NOTE(review): `snakemake` is not defined anywhere in this notebook; presumably
# it is injected by Snakemake when the notebook is executed as part of a
# workflow rule -- confirm. `sm` is a short alias used by the Config class.
sm = snakemake

In [None]:
import spherpro.bro as sb
import spherpro.db as db

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import plotnine as gg

import re
import pathlib

%matplotlib inline

In [None]:
from src.variables import Vars
from src.config import Conf

## Aim

Plot the fraction of overexpressing cells per condition.

In [None]:
class CurVariableHelper(Vars):
    """Namespace of column names and name suffixes used in this notebook.

    Extends the project-wide ``Vars`` with notebook-specific constants.
    All attributes are plain class-level constants.
    """

    # Suffix appended to a base name to build the derived '...Nb' names below.
    SUFFIX_FILNB = 'Nb'

    # Positivity flag columns and their suffixed variants.
    COL_FLAGPOS = 'IsFlagpos'
    COL_FLAGPOSNB = f'{COL_FLAGPOS}{SUFFIX_FILNB}'
    COL_GFPPOS = 'IsGfppos'
    COL_GFPPOSNB = f'{COL_GFPPOS}{SUFFIX_FILNB}'

    # Column name taken directly from the database schema.
    COL_OBJ_NR = db.objects.object_number.key

    # Tag / construct related columns.
    COL_GENE = 'gene'
    COL_GENE_UNTAGGED = 'gene_untagged'
    COL_TAG = 'tag'
    COL_TAGFLAG = 'TagFLAG'
    COL_TAGGFP = 'TagGFP'
    COL_TAGSTAT = 'TagStat'

    # Filter summary columns.
    COL_FIL = 'filters'
    COL_FILSTAT = 'filter_stat'
    COL_FILTAG = 'filter_tag'

    # Remaining column labels.
    COL_COEFNAME = 'coefname'
    COL_D2RIM = 'DistRim'
    COL_DELTA = 'delta'
    COL_DF = 'DF'
    COL_DILUTION = 'dilution'
    COL_DOXO = 'doxocycline'
    COL_FC = 'fc'
    COL_FC_CENS = 'fc_cens'
    COL_FITCONDITIONNAME = 'FitConditionName'
    COL_FITTED = 'fitted'
    COL_GOODNAME = 'goodname'
    COL_ISNB = 'isnb'
    COL_ISSIG = 'is_sig_sel'
    COL_N = 'n'
    COL_NB = 'nb'
    COL_N_OVEREXPR = 'n_overexpr'
    COL_P = 'p'
    COL_POSSTAT = 'PosStat'
    COL_P_CORR = 'p_corrected'
    COL_RESID = 'residual'
    COL_TSTAT = 't'
    COL_WORKING = 'working'


# Short alias used throughout the notebook.
V = CurVariableHelper

In [None]:


class Config(Conf):
    """Notebook configuration: workflow paths, filter names and labels.

    Extends the project-wide ``Conf``. Paths are read from the Snakemake
    object aliased as ``sm``.
    """

    # Paths handed over by the workflow.
    fn_config = sm.input.fn_config
    fol_paper = pathlib.Path(sm.output.fol_plots)

    # Plain labels.
    REF_COND = 'ctrl'
    SUFFIX_NB = '-NB'
    name_gfp = 'GFP'

    # Object filter names; the '...Nb' variants append V.SUFFIX_FILNB.
    FIL_GFPPOS = 'is-gfppos'
    FIL_GFPPOSNB = f'{FIL_GFPPOS}{V.SUFFIX_FILNB}'
    FIL_FLAGPOS = 'is-flagpos'
    FIL_FLAGPOSNB = f'{FIL_FLAGPOS}{V.SUFFIX_FILNB}'
    FILS = [FIL_GFPPOS, FIL_GFPPOSNB, FIL_FLAGPOS, FIL_FLAGPOSNB]

    # Positivity columns and the tag column each of them maps to.
    FILS_POS = [V.COL_GFPPOS, V.COL_GFPPOSNB]
    DIC_TAG = dict.fromkeys(FILS_POS, V.COL_TAGGFP)


# Short alias used throughout the notebook.
C = Config

In [None]:
# Create the plot output folder. `parents=True` also creates any missing
# intermediate folders instead of raising FileNotFoundError; `exist_ok=True`
# keeps re-runs from failing when the folder is already there.
C.fol_paper.mkdir(parents=True, exist_ok=True)

## 0) Setup configuration and bro

In [None]:


import spherpro.bromodules.helpers_vz as helpers_vz



In [None]:
# Build the spherpro `bro` object from the configuration file provided by the
# workflow; it is used below for all database queries.
bro = sb.get_bro(C.fn_config)

# Helper wrapping `bro`; used below for `get_imgmeta` and `get_data`.
hpr = helpers_vz.HelperVZ(bro)

## 1) Analysis

In [None]:



def update_condmeta(dat_condmeta):
    """
    Return a copy of the condition metadata annotated with tag information.

    The condition name (column ``V.COL_CONDNAME``) is expected to have the
    form ``<gene>_<tag>``; it is split at its last underscore.

    Added columns:
        V.COL_GENE_UNTAGGED: part of the name before the last underscore.
        V.COL_TAG: part of the name after the last underscore.
        V.COL_TAGGFP: True if the name contains 'GFP'.
        V.COL_TAGFLAG: True if the name contains 'FLAG' but not 'GFP'.
        V.COL_GENE: ``<gene_untagged>_<tag>``, i.e. a tag-specific gene name.

    Args:
        dat_condmeta: DataFrame with one row per condition; must contain the
            column ``V.COL_CONDNAME``. It is not modified.

    Returns:
        A new DataFrame with the columns listed above added.

    Raises:
        ValueError: if a condition name cannot be split into gene and tag
            (e.g. it is missing or contains no underscore).
    """
    dat_condmeta = dat_condmeta.copy()
    condnames = dat_condmeta[V.COL_CONDNAME]

    # `str.extract` searches, so anchor the pattern at the start to keep the
    # semantics of `re.match`. The greedy first group splits at the LAST '_'.
    pattern = '^(?P<{}>.*)_(?P<{}>.*)'.format(V.COL_GENE_UNTAGGED, V.COL_TAG)
    name_parts = condnames.str.extract(pattern, expand=True)

    unparsed = name_parts[V.COL_TAG].isna()
    if unparsed.any():
        raise ValueError(
            'Condition names not of the form "<gene>_<tag>": {}'.format(
                list(condnames[unparsed].unique())))

    dat_condmeta = dat_condmeta.join(name_parts)
    # Fix the one GFP without GFP-FLAG
    #dat_condmeta[V.COL_CONDNAME] = dat_condmeta[V.COL_CONDNAME].replace({'GFP_nan': 'GFP_GFP-FLAG'})
    has_gfp = condnames.str.contains('GFP', regex=False)
    has_flag = condnames.str.contains('FLAG', regex=False)
    dat_condmeta[V.COL_TAGGFP] = has_gfp
    dat_condmeta[V.COL_TAGFLAG] = has_flag & ~has_gfp
    # Make the gene also be tag specific, as these are different constructs
    dat_condmeta[V.COL_GENE] = (dat_condmeta[V.COL_GENE_UNTAGGED]
                                + '_' + dat_condmeta[V.COL_TAG])
    return dat_condmeta

def get_data(channel_name, measurement_name, transf=True, object_type='cell'):
    """
    Fetch the measurement values for one channel / measurement combination.

    Args:
        channel_name: value matched against ``dat_measmeta.channel_name``.
        measurement_name: value matched against ``dat_measmeta.measurement_name``.
        transf: if True, the values are transformed per measurement id with
            ``cur_transf``.
        object_type: forwarded to ``hpr.get_data``.

    Returns:
        The frame returned by ``hpr.get_data``, with ``V.COL_VALUE``
        transformed when ``transf`` is set.
    """
    # NOTE(review): `dat_measmeta` and `cur_transf` are module-level names that
    # are not defined in the visible part of this notebook -- confirm they
    # exist before calling this function.
    is_channel = dat_measmeta.channel_name == channel_name
    is_measurement = dat_measmeta.measurement_name == measurement_name
    meas_ids = list(dat_measmeta.loc[is_channel & is_measurement, V.COL_MEASID])

    dat = hpr.get_data(meas_ids=meas_ids, object_type=object_type)
    if not transf:
        return dat

    values_by_meas = dat.groupby(V.COL_MEASID)[V.COL_VALUE]
    dat[V.COL_VALUE] = values_by_meas.transform(cur_transf)
    return dat

def get_fildat(filname):
    """
    Query the object filter rows belonging to the filter named `filname`.

    Args:
        filname: object filter name to select.

    Returns:
        The result of ``bro.doquery`` for the object filter table joined with
        the filter names, restricted to `filname`.
    """
    query = bro.session.query(db.object_filters,
                              db.object_filter_names.object_filter_name)
    query = query.join(db.object_filter_names)
    query = query.filter(db.object_filter_names.object_filter_name == filname)
    return bro.doquery(query)

def rename_measurement(dat, name):
    """
    Rename the value column to `name` and drop the measurement id column.

    Args:
        dat: DataFrame containing the columns ``V.COL_VALUE`` and ``V.COL_MEASID``.
        name: new label for the value column.

    Returns:
        A new DataFrame; `dat` itself is left untouched.
    """
    renamed = dat.rename(columns={V.COL_VALUE: name})
    return renamed.drop(columns=V.COL_MEASID)

def get_fildats(filnames):
    """
    Query cell metadata together with the values of several object filters.

    Args:
        filnames: iterable of object filter names. For each name the filter
            value is added as a column labelled with that name.

    Returns:
        The result of ``bro.doquery``: the object metadata of all objects of
        type 'cell', their condition id and plate id, and one column per
        requested filter.
    """
    query = bro.data.get_objectmeta_query()
    query = query.filter(db.objects.object_type == 'cell')
    query = query.join(db.conditions,
                       db.images.condition_id == db.conditions.condition_id)
    query = query.add_columns(db.conditions.condition_id,
                              db.conditions.plate_id)

    for name in filnames:
        # Sub-select holding only the rows of the current filter.
        filter_rows = (
            bro.session.query(db.object_filters)
            .join(db.object_filter_names)
            .filter(db.object_filter_names.object_filter_name == name)
            .subquery()
        )
        query = query.join(filter_rows,
                           filter_rows.c.object_id == db.objects.object_id)
        query = query.add_columns(filter_rows.c.filter_value.label(name))

    return bro.doquery(query)
    



def sort_levels(col, vals):
    """
    Convert `col` into a categorical whose category order is given by `vals`.

    Args:
        col: list-like of values.
        vals: list-like of categories in the desired order. Values of `col`
            that are not in `vals` become missing.

    Returns:
        A ``pd.Categorical``.
    """
    category_order = vals
    return pd.Categorical(values=col, categories=category_order)



def reverse_logical(col):
    """
    Convert boolean values into a categorical with True listed before False.

    Args:
        col: list-like of booleans.

    Returns:
        A ``pd.Categorical`` with the categories ``[True, False]``.
    """
    true_first = [True, False]
    return pd.Categorical(values=col, categories=true_first)

In [None]:


# Image-level metadata as provided by the helper.
dat_imgmeta = hpr.get_imgmeta()

# All rows of the conditions table, annotated with gene / tag columns.
dat_condmeta = update_condmeta(
    bro.doquery(bro.session.query(db.conditions))
)

In [None]:
# Object id -> image id for all objects of type 'cell'; used further down to
# link the per-cell classes to the image metadata.
dat_obj = bro.doquery(bro.session.query(db.objects.object_id, db.objects.image_id)
                     .filter(db.objects.object_type == 'cell'))

In [None]:
# Class labels assigned to each cell. The list is addressed by position
# (indices 0..3) in this notebook and also fixes the category order of the
# plot legend, so the order of the entries matters.
C.FIL_LM_CLASSES = ['doubt', 'ctrl', 'oexp-NB', 'oexp']

def get_fitcond(dat):
    """
    Assign each object one of the classes in ``C.FIL_LM_CLASSES``.

    Rules, applied in this order (later rules override earlier ones):
        1. every object starts as the reference class (entry 1, 'ctrl');
        2. ``C.FIL_GFPPOSNB > 0``  -> entry 2 ('oexp-NB');
        3. ``C.FIL_GFPPOS == 2``   -> entry 3 ('oexp');
        4. ``C.FIL_GFPPOS == 1``   -> entry 0 ('doubt').

    Args:
        dat: DataFrame with the columns ``C.FIL_GFPPOSNB``, ``C.FIL_GFPPOS``
            and ``V.COL_OBJID``. It is not modified.

    Returns:
        A new DataFrame with the columns ``V.COL_FITCONDITIONNAME`` and
        ``V.COL_OBJID``, indexed like `dat`.
    """
    # Name the positional entries once instead of using bare indices below.
    cls_doubt, cls_ctrl, cls_oexp_nb, cls_oexp = C.FIL_LM_CLASSES[:4]

    # Work on a fresh frame so the caller's data is never mutated.
    out = dat[[V.COL_OBJID]].copy()
    out[V.COL_FITCONDITIONNAME] = cls_ctrl

    # NOTE(review): the filter values 1 and 2 presumably encode "uncertain" and
    # "positive" -- confirm against the code that writes the filters.
    out.loc[dat[C.FIL_GFPPOSNB] > 0, V.COL_FITCONDITIONNAME] = cls_oexp_nb
    out.loc[dat[C.FIL_GFPPOS] == 2, V.COL_FITCONDITIONNAME] = cls_oexp
    out.loc[dat[C.FIL_GFPPOS] == 1, V.COL_FITCONDITIONNAME] = cls_doubt

    return out[[V.COL_FITCONDITIONNAME, V.COL_OBJID]]

In [None]:
# Query the filter values of all cells and reduce them to one class per cell.
dat_fil = get_fitcond(get_fildats(C.FILS))

In [None]:
# Inspect the class assigned to each object.
dat_fil

In [None]:


# Fraction of cells per class and condition, in long format:
# one row per (condition id, class) with the fraction in V.COL_FILSTAT.
dat_filstat = (dat_fil
    # One-hot encode the class of every object (1 in its class column, 0 elsewhere).
    .assign(val=1)
    .pivot_table(values='val', index=V.COL_OBJID,
                 columns=V.COL_FITCONDITIONNAME,
                 fill_value=0)
    # Guarantee a column for every class, even one that no cell was assigned
    # to; otherwise the column selection after the groupby raises a KeyError.
    .reindex(columns=C.FIL_LM_CLASSES, fill_value=0)
    .reset_index()
    .merge(dat_obj)
    .merge(dat_imgmeta)
    # The mean of the one-hot columns is the fraction of cells per class.
    .groupby(V.COL_CONDID)[C.FIL_LM_CLASSES].mean()
    .rename_axis(columns=V.COL_FIL)
    .stack()
    .rename(V.COL_FILSTAT)
    .to_frame()
    .reset_index()
    # Flag the '-NB' classes. The class labels end with C.SUFFIX_NB ('-NB');
    # the previous test for the substring 'Nb' never matched any of them.
    .assign(**{V.COL_NB: lambda x: [n.endswith(C.SUFFIX_NB) for n in x[V.COL_FIL]]})
)


# - Sort the condition names by their average fraction of overexpressing cells
# - Plot as stacked bar graphs
# 
# Facet by: COL_FILTAG ~ COL_CONDITIONNAME
# x=conditionid, y=fraction, color=is_nb

In [None]:
# Inspect the per-condition class fractions.
dat_filstat

In [None]:
# Condition names ordered by decreasing mean fraction of the class
# C.FIL_LM_CLASSES[3], averaged over all condition ids sharing a name.
# Despite its name, this holds condition NAMES; it is used as the facet order
# in the plot below.
condid_sort = (
    dat_filstat
    .merge(dat_condmeta)
    .loc[lambda d: d[V.COL_FIL] == C.FIL_LM_CLASSES[3]]
    .groupby(V.COL_CONDNAME)[V.COL_FILSTAT]
    .mean()
    .reset_index()
    .sort_values(V.COL_FILSTAT, ascending=False)[V.COL_CONDNAME]
)

In [None]:


# NOTE(review): `tagorder` is not referenced in the visible part of this
# notebook -- confirm whether it is still needed.
tagorder = ['GFP','FLAG']

# Stacked bar chart of the class fractions: one bar per condition id, one
# facet column per condition name (ordered by `condid_sort`) and one facet row
# per sample block.
#
# `+` binds tighter than `>>`, so the merged data frame is piped into the
# fully assembled plot specification.
# The aes / facet strings are expressions; they refer to `pd`, `sort_levels`
# and `condid_sort`, which therefore have to be defined in this notebook.
p = (dat_condmeta
     .merge(dat_filstat)
     # Human-readable facet label for the sample block.
     .assign(sb = lambda x: [f'Set {i}' for i in x[db.conditions.sampleblock_id.key]])
     >>
     gg.ggplot(gg.aes(x=f'{V.COL_CONDID}.astype(str)', y=V.COL_FILSTAT,
                fill=f'pd.Categorical({V.COL_FIL}, categories={C.FIL_LM_CLASSES})')) +
     gg.facet_grid(f'sb~sort_levels({V.COL_CONDNAME}, condid_sort)',
                   scales='free_x')+
     # The fractions are already computed, so the bars use the values as they are.
     gg.geom_bar(stat='identity')+
     gg.ylab('Fraction [a.u.]') +
     gg.xlab('Sphere ID')+
     gg.guides(fill = gg.guide_legend(title = "Category"))+
     gg.theme(figure_size=(13,3),
              strip_text_x = gg.element_text(angle = 45, va='bottom',ha='left'),
              axis_text_x =gg.element_blank())
)
p

In [None]:
# Write the overview figure into the plot folder; `limitsize=False` is passed
# through to the plot's save method unchanged.
p.save(C.fol_paper / 'oexp_overview.pdf', limitsize=False)