# Quality filtering of the data

This notebook removes (1) the genes whose observed localization does not match the known localization [@Chong2015-yn], (2) the images containing abnormalities, artifacts, or excessively high heterogeneity, and (3) the replicates containing very low numbers of cells. 

In [15]:
## logging functions
from icecream import ic as info
import logging
## data functions
import numpy as np
import pandas as pd
## system functions
from os.path import dirname
import sys
## system functions from roux
from roux.lib.io import read_table
from roux.lib.io import to_table
## visualization functions from roux
from roux.viz.io import to_plot
## data functions from roux
import roux.lib.dfs as rd # attributes
sys.path.append('../')

In [16]:
## parameters
## NOTE(review): the `None` defaults are presumably overridden by the notebook runner -- confirm
input_path=None ## passed to `read_pre_processed`
output_path=None #f'{output_dir_path}/04_filteredby_cell_counts.tsv',
input_file_visual_inspection_path=None ## excel file (sheet 'Sheet1') with the annotations from the visual inspection

controls=False ## if input data is controls or not
column_cells=None ## onto which the cutoff is applied.
cutoff_min_cells=None ## rows with `column_cells` below this cutoff are marked for removal
cutoffs_min_cells_q=[0.01,0.05] # to show on the distribution

In [18]:
## inferred parameters
## directory of the output table, the intermediate tables are saved under it
output_dir_path=dirname(output_path)    

In [None]:
## common functions
def check_filter_metainfo_images_removed( #noqa
    df0,
    column_filter,
    ):
    """
    Count the images that would be removed by a filter, per construct.

    Parameters:
        df0 (pd.DataFrame): table containing the columns 'label', 'image id' and `column_filter`.
        column_filter (str): name of the boolean column marking the rows to be removed.

    Returns:
        pd.DataFrame: table with the columns 'construct name' and 'images removed',
            sorted by decreasing number of images; ties are ordered by the construct name.
    """
    ## check number of images per constructs that would be removed #noqa
    df3_=(df0
         ## select removed rows
        .loc[df0[column_filter],:]
         ## count images per construct 
        .groupby('label')['image id'].nunique()
        # arrange
        ## a stable sort is needed to keep the alphabetical order among the ties,
        ## the default sort ('quicksort') does not guarantee it
        .sort_index(ascending=True)
        .sort_values(ascending=False,kind='mergesort')
        .reset_index()
        .rename(columns={
            'label':'construct name',
            'image id':'images removed',
                        })     
        )
    logging.info(f"-> total images {column_filter}: {df3_['images removed'].sum()}")
    logging.info(df3_)
    return df3_

def filter_metainfo(
    df2,
    controls=False,
    ):
    """
    Drop the rows marked by any of the 'remove because ...' columns, then keep only the complete pairs.

    A pair is kept if, after the removal of the marked rows, it still contains
    4 unique values of 'label' and 2 unique values of 'status partner'.

    Parameters:
        df2 (pd.DataFrame): table with the columns 'pairs', 'label', 'status partner'
            and the boolean columns starting with 'remove because '.
        controls (bool): not used in the function body.

    Returns:
        pd.DataFrame: the filtered table.
    """
    ## stats: how many rows will be removed by each of the filters
    info(df2.filter(regex='^remove because .*').sum()/len(df2))#.any(axis=1)

    ## rows marked by at least one of the filters
    marked=df2.filter(regex='^remove because .*').any(axis=1)

    def has_all_constructs(df):
        ## keep the ones that have all 4 constructs
        return df['label'].nunique()==4

    def has_both_backgrounds(df):
        ## keep the ones that have both the WT and DELTA background
        return df['status partner'].nunique()==2

    df3=(
        df2
        .log('pairs')
        .loc[~marked,:]
        .log('pairs')
        .groupby('pairs').filter(has_all_constructs)
        .log('pairs')
        .groupby('pairs').filter(has_both_backgrounds)
        .log('pairs')
    )
    ## pairs that were removed
    info(set(df2['pairs']) - set(df3['pairs']))
    return df3

### Input data

In [19]:
## read the input table with the project-local reader, without renaming
from modules.tools.io import read_pre_processed
df0=read_pre_processed(
    input_path,
    rename=False,
    )
## preview the first row
df0.head(1)

In [20]:
## number of unique cells in each image i.e. in each combination of label, replicate and URL
df0['cells per image']=(
    df0
    .groupby(['label','replicate','URL'])['cell_id']
    .transform('nunique')
)
## number of unique cells in each replicate i.e. in each combination of label and replicate
df0['cells per replicate']=(
    df0
    .groupby(['label','replicate'])['cell id per pair']
    .transform('nunique')
)
## validate that the counts per image add up to the counts per replicate
assert all(
    (
        ## sum of the cells per image, by replicate
        df0
        .loc[:,['label','replicate','URL','cells per image']]
        .drop_duplicates()
        .groupby(['label','replicate'])['cells per image'].sum()
        .sort_index()
    )
    ==
    (
        ## cells per replicate, counted directly
        df0
        .loc[:,['label','replicate','cells per replicate']]
        .drop_duplicates()
        .set_index(['label','replicate'])['cells per replicate']
        .sort_index()
    )
)

### Filtering based on possible artifacts identified in the visual inspection

In [21]:
## remove genes with abnormalities as detected from the visual inspections
## the reason for the removal of each gene is noted next to it
remove_genes=['BDF1', # localization mismatch
              'KIN1', # no cells in the wt background
              'YNR048W', # no cells in the delta background
             ]

In [22]:
## mark the rows in which either the query or the partner gene is in the list of genes to remove
df0['remove because of localization mismatch']=(
    df0
    .loc[:,['gene symbol query','gene symbol partner']]
    .isin(remove_genes)
    .any(axis=1)
)

In [23]:
## remove replicates with high heterogeneity in few replicates but not all
from roux.lib.io import read_excel
## read the annotations from the visual inspection
df02=read_excel(input_file_visual_inspection_path,
          sheet_name='Sheet1').log()
if not controls:
    ## keep the rows that are not marked as controls
    df02=df02.log().loc[(df02['Control']!=True),:].log()
else:
    ## rows marked as controls: annotations for the DELTA background and the heterogeneity
    df1_=(df02
    .log()
    .loc[(df02['Control']==True),:]
    .rd.dropby_patterns('WT')
    .log()
      .loc[:,['gene symbol',
              "replicates with abnormalities DELTA background (visual inspection)",
              "images with abnormalities DELTA background (visual inspection)",
              "image ids with abnormalities DELTA background (visual inspection)",
             "replicates with high heterogeneity (visual inspection)",
             ],
          ]
     )  

    ## rows not marked as controls, of the same genes: annotations for the WT background
    df2_=(df02
        .loc[((df02['gene symbol'].isin(df1_['gene symbol'].tolist())) & (df02['Control']!=True)),:]
        .rd.dropby_patterns('DELTA')
        .loc[:,['gene symbol',
                      "replicates with abnormalities WT background (visual inspection)",
                      "images with abnormalities WT background (visual inspection)",
                      "image ids with abnormalities WT background (visual inspection)",]]
        .log()  
         )

    ## combine the annotations of the two backgrounds, one row per gene
    df02=df2_.log.merge(right=df1_,
                   on='gene symbol',
                   how='inner',
                  validate="1:1",
                   validate_equal_length=True,
                  )  
## preview the first row
df02.head(1)

In [24]:
def rename_columns_(df):
    """
    Rename the columns by their prefixes, to make the tables compatible with `pd.concat`.

    Columns starting with 'replicates' are renamed to 'replicate', the ones starting
    with 'images' to 'URL' and the ones starting with 'image id' to 'image id'.
    The other columns are kept as they are.
    """
    ## prefixes are tested in this order, the first match is used
    prefix_to_name=(
        ('replicates','replicate'),
        ('images','URL'),
        ('image id','image id'),
    )
    renamed={}
    for column in df:
        renamed[column]=next(
            (name for prefix,name in prefix_to_name if column.startswith(prefix)),
            column, ## no prefix matched
        )
    return df.rename(columns=renamed,errors='raise')
## remove abnormalities/artifacts
## one table per background (WT and DELTA) with the renamed columns,
## stacked, with the background stored in the column 'status partner'
df1=pd.concat({k:(df02
              .loc[:,['gene symbol',
                 f"replicates with abnormalities {k} background (visual inspection)",
                 f"images with abnormalities {k} background (visual inspection)",
                 f"image ids with abnormalities {k} background (visual inspection)",
                     ]]
              .pipe(rename_columns_)
               ) for k in ['WT','DELTA']},
             axis=0,
             names=['status partner'],
             ).reset_index(0)

In [25]:
## high heterogeneity replicates
df2=(df02
.loc[:,['gene symbol','replicates with high heterogeneity (visual inspection)']]
.pipe(rename_columns_)
    )
## the same table is repeated for both the backgrounds (WT and DELTA),
## with the background stored in the column 'status partner'
df2=pd.concat({k:df2 for k in ['WT','DELTA']},
          axis=0,
          names=['status partner'],
          ).reset_index(0)

In [26]:
## combine all
## the source of each row is stored in the column 'reason'
df3=(
    pd.concat(
    {
        'artifacts':df1,
        'high heterogeneity':df2,
    },
    axis=0,
    names=['reason'],
    )
    .reset_index(0)
    .log.drop_duplicates()
)
## preview the first row
df3.head(1)

In [27]:
## validate that every row with a URL also has a replicate
assert len(df3.loc[((df3['replicate'].isnull()) & ~(df3['URL'].isnull())),:])==0, "URLs removed irrespective of the replicate?"

In [None]:
## save the combined filters, before the formatting
to_table(df3,f'{output_dir_path}/02_mapped_filters/00_filters_raw.tsv')

In [None]:
def to_sample_info_(
    df1,
    column_replicate='replicate',
    column_image='URL',
    ):
    """
    Format the table to be compatible with the table with sample info.  
    
    The values of the replicate, image and 'image id' columns that are separated by ';'
    are split into separate rows, the replicates are formatted as 'replicate<number>',
    the images as strings zero-padded to 9 characters, and the column 'gene symbol'
    is renamed to 'gene symbol query'.

    Parameters:
        df1 (pd.DataFrame): table with the columns 'gene symbol', 'image id',
            `column_replicate` and `column_image`.
        column_replicate (str): name of the column with the replicates.
        column_image (str): name of the column with the images.

    Returns:
        pd.DataFrame: formatted table, without the 'image id' column and without duplicates.

    Raises:
        AssertionError: if an 'image id' does not contain ':' or if a replicate
            is not one of 'replicate1', 'replicate2' and 'replicate3'.

    Note: 
        Replicate with heterogeneity=='all' means all replicates show similar heterogeneity. So NOT to be filtered out.

        NOTE(review): some of the steps use the column names 'replicate' and 'URL'
        directly instead of `column_replicate` and `column_image`, so the function
        presumably works only with the default values of these parameters -- confirm.
    """
    ## ignore the rows with all the replicates showing heterogeneity
    df1=(df1
        .log()
        .loc[(df1[column_replicate]!='all'),:] ## drop the rows with the 'all' value, they are not used as filters
        .log()
        )
    
    for c in [column_replicate,column_image]:
        ## split the ';'-separated strings to lists of integers, keep the missing values, convert the rest to integers
        df1[c]=df1[c].apply(lambda x: [int(i) for i in x.split(';')] if isinstance(x,str) else np.nan if pd.isnull(x) else int(x))
        ## split lists to separate rows
        df1=df1.log.explode(column=c)
        
    ## split the 'image id's to replicate and URLs
    ## an 'image id' is expected to be in the format '<replicate>:<URL>', multiple ones are separated by ';'
    df1['image id']=df1['image id'].str.strip().str.strip(';').str.split(';')
    df1=df1.log.explode(column='image id')
    ## the string 'None' is considered as a missing value
    df1['image id']=df1['image id'].replace('None', np.nan)
    assert df1['image id'].dropna().apply(lambda x: isinstance(x,str) and ':' in x).all()
    ## if available, the 'image id' overrides the replicate and the URL
    df1['replicate']=df1.apply(lambda x: x['image id'].split(':')[0] if not pd.isnull(x['image id']) else x['replicate'],axis=1)
    df1['URL']=df1.apply(lambda x: x['image id'].split(':')[1] if not pd.isnull(x['image id']) else x['URL'],axis=1)
    # assert df1['URL'].isin(['replicate1','replicate2','replicate3']).all()
    df1=df1.drop(['image id'],axis=1)
        
    ## ignore the rows with missing data
    df1=(df1
        .log()
        .loc[~(df1[column_replicate].isnull()),:] ## drop the rows without a replicate
        .log()
        )
    
    # format the replicates and the URLs
    ## rename to be consistent with the metainfo table
    ## URLs: strings, zero-padded to 9 characters
    df1[column_image]=df1[column_image].apply(lambda x: str(x).zfill(9) if not pd.isnull(x) else x)
    ## replicates: numbers are prefixed with 'replicate', the strings are kept as they are
    df1[column_replicate]=df1[column_replicate].apply(lambda x: f"replicate{int(x)}" if isinstance(x,(int,float)) else x)
    assert df1['replicate'].isin(['replicate1','replicate2','replicate3']).all(), df1['replicate'].unique()
    
    ## rename columns
    df1=df1.rename(columns={'gene symbol':'gene symbol query',
                            # column_replicate:'replicate',
                       },errors='raise')
    # info(df1['replicate'].value_counts())
    return df1.log.drop_duplicates()
## format the combined filters
df4=to_sample_info_(
    df1=df3,
    column_replicate='replicate',
    column_image='URL',
    )
## preview the first row
df4.head(1)

In [None]:
## save the formatted filters
to_table(df4,f'{output_dir_path}/02_mapped_filters/01_filters_renamed.tsv')

In [29]:
def map_remove_rows_(
    df0,
    remove_images,
    column_name,
    ):
    """
    Mark the rows that match any of the given sets of criteria.

    Parameters:
        df0 (pd.DataFrame): input table, its index is reset.
        remove_images (list): list of dictionaries, each one is used to select rows with `rd.filter_rows`.
        column_name (str): name of the boolean column to be added, it should not exist in the table.

    Returns:
        pd.DataFrame: table with the index reset and with the column `column_name`,
            which is True for the rows matched by at least one of the dictionaries.
    """
    assert column_name not in df0
    df0=df0.reset_index(drop=True)
    matched=set()
    for d_ in remove_images:
        rows_=df0.rd.filter_rows(d_,verbose=False).index.tolist()
        if not rows_:
            ## nothing matched: check the rows selected without the last of the criteria
            keys=list(d_.keys())
            df_=df0.rd.filter_rows({k:d_[k] for k in keys[:-1]},verbose=False)
            if df_[keys[-1]].nunique()==4:
                logging.error(f"rows not available in the data for {d_}")
            else:
                logging.warning(f"rows not available in the data for {d_}, maybe the image was prefiltered.")
        matched.update(rows_)
    df0[column_name]=df0.index.isin(list(matched))
    return df0
## remove images/replicates marked from the visual inspection: listed manually
df1=map_remove_rows_(
    df0=df0,
        remove_images=[
            {'gene symbol query':"DNF1",'replicate':["replicate1","replicate2"],},
            {'label':"CPR5-GFP CPR2-WT",'replicate':["replicate3"],'URL': ["001013001","001013003"],},
        ],
        column_name='remove because of abnormalities',
        )
## remove images/replicates marked from the visual inspection: listed in the table of filters
## one column is added for each of the reasons
df2=df1.copy()
for column_name,df_ in df4.groupby('reason'):
    df2=map_remove_rows_(
        df0=df2,
        ## one dictionary per row, made of the values that are not missing
        remove_images=df_.drop(['reason'],axis=1).apply(lambda x: {key:value for key,value in x.to_dict().items() if not pd.isnull(value)},axis=1).tolist(),
        column_name=f'remove because of {column_name}',
        )

In [31]:
## save the table with the filters mapped on the rows
to_table(df2,f'{output_dir_path}/02_mapped_filters.tsv')

In [32]:
## preview the first row
df2.head(1)

In [61]:
## for each of the filters, count the images that would be removed per construct
for col in df2.filter(regex='^remove because .*'):
    ## save table
    to_table(check_filter_metainfo_images_removed( #noqa
    df2,
    column_filter=col,
    ),
    f'{output_dir_path}/03_filtered_images_removed {col}.tsv')

In [33]:
## apply the filters and save the filtered table
df3=filter_metainfo(df2,controls=controls)
to_table(df3,f'{output_dir_path}/03_filtered.tsv')

In [34]:
## preview the first row
df3.head(1)

### Filtering based on the number of cells

In [35]:
## cells per replicate, counted again on the filtered table
df3['cells per replicate (filtered)']=df3.groupby(['label','replicate'])['cell id per pair'].transform('nunique')

In [36]:
def filterby_cells(
    df0,
    column_cells,
    cutoff_min_cells,
    cutoffs_min_cells_q,
    output_dir_path,
    ):
    """
    Mark the rows with low numbers of cells and plot the distribution of the numbers of cells.

    Parameters:
        df0 (pd.DataFrame): table with the columns 'pairs', 'label', `column_cells` and
            either 'image id' or 'replicate', depending on `column_cells`.
        column_cells (str): name of the column with the numbers of cells, it should start
            with 'cells per image' or with 'cells per replicate'.
        cutoff_min_cells (int): rows with `column_cells` below this cutoff are marked.
        cutoffs_min_cells_q (list): quantiles of `column_cells` to be shown on the plot.
        output_dir_path (str): directory under which the plot is saved.

    Returns:
        tuple: the input table with the added boolean column, and the name of that column.

    Raises:
        ValueError: if `column_cells` does not start with one of the supported prefixes.

    Note:
        The column is added to the input table in place.
    """
    ## for counting and the plot
    ### column
    if column_cells.startswith('cells per image'):
        column_unit='image id'
    elif column_cells.startswith('cells per replicate'):
        column_unit='replicate'
    else:
        raise ValueError(f"column_cells should start with 'cells per image' or 'cells per replicate', found: {column_cells!r}")
    ### dataframe
    df0_=df0.loc[:,['pairs','label']+[column_unit]+[column_cells]].drop_duplicates()

    ## dataframe with cutoffs
    ## note: `pd.concat` is used because `DataFrame.append` was removed in pandas 2.0
    df1_=(
        pd.concat(
            [
                ## cutoffs at the quantiles
                pd.DataFrame(
                    {
                        "cutoff q": cutoffs_min_cells_q,
                        "cutoff": [df0_[column_cells].quantile(q) for q in cutoffs_min_cells_q]}
                    ),
                ## the cutoff that is applied, its quantile is not available
                pd.DataFrame({"cutoff":[cutoff_min_cells]}),
            ],
            axis=0,
        )
        .assign(
        **{
            'cutoff %':lambda df: df['cutoff q']*100,
            'cutoff label':lambda df: df.apply(lambda x: f"{x['cutoff']:.0f}"+(f"\n({x['cutoff %']}%)" if not pd.isnull(x['cutoff %']) else ''),axis=1)
        },
        )
    )

    # plot distribution with cutoffs
    from modules.tools.plot import plot_dist_cutoffs
    plot_dist_cutoffs(
        data=df0_[column_cells],
        df0=df1_, 
        ax=None,
        xlim_inset=[1,40],
        bins=100,
        )
    to_plot(f'{output_dir_path}/02_mapped_filters_plots/{column_cells}',fmts=['pdf','png'])

    ## images/replicates removed
    df2_=(df0_
    .loc[(df0_[column_cells]<cutoff_min_cells),:].log()
    .groupby(['label',])[column_unit].nunique()
    .sort_index(ascending=True).sort_values(ascending=False)
    .to_frame(f"{column_unit}s removed")
    .reset_index()
    .rename(columns={
        'label':'construct name',
        'image ids removed':'images removed',
                    })
    )
    ## the summary is not returned, so log it to make it available
    logging.info(df2_)
    
    ## apply filtering selected settings
    logging.info(f"-> filtering by cells applied to keep the rows with >={cutoff_min_cells} {column_cells}.")
    column_filter=f'remove because {column_cells} < {cutoff_min_cells}'
    df0[column_filter]=(df0[column_cells]<cutoff_min_cells)
    return df0,column_filter
## mark the rows with low numbers of cells, using the parameters of the notebook
df3,column_filter=filterby_cells(
    df3,
    column_cells,
    cutoff_min_cells,
    cutoffs_min_cells_q,
    output_dir_path,
    )

In [59]:
## save table
## images that would be removed by the cutoff on the number of cells, per construct
to_table(check_filter_metainfo_images_removed( #noqa
df3,
column_filter,
),
f'{output_dir_path}/04_filteredby_cell_counts_images_removed {column_filter}.tsv')

In [38]:
## apply all the filters, including the one on the number of cells, and save the output table
df4=df3.pipe(filter_metainfo)
to_table(df4,
        output_path,
        )

### Outputs: table by constructs

In [52]:
## read back the filtered table from the path to which it was saved (`output_path`),
## rather than from a hard-coded file name that may not match `output_path`
df01=read_table(output_path)

In [53]:
## preview the first row
df01.head(1)

In [55]:
## table by constructs: drop the columns that are specific to the images, replicates and cells,
## the filter columns and the abundance columns, then drop the duplicated rows
df1=(df01
.drop(['URL','replicate','cell_id','file','Row','Column','Field','R-C','image id',
      column_cells,
      ]+df01.filter(regex="^remove because.*").columns.tolist()+df01.filter(regex="^abundance.*").columns.tolist(),
      axis=1)
.log.drop_duplicates()
)

In [56]:
## save the table by constructs
to_table(df1,f'{output_dir_path}/04_filteredby_cell_counts_byconstruct.tsv')

In [57]:
## preview the first row
df1.head(1)