# Comp Screening - cell painting select variable features
## FOR COMPLETE (datalock) cell painting dataset
#### _BEM 09-24-2021_


### Metadata
* U2OS cells screened with single / full plate(s) of KI FDA library
* 57 plates
* 9x fields captured per well @ 20X magnification

### Preprocessing
CellProfiler (run on AWS) was used for image correction, QC-metric capture, segmentation, and feature extraction; cytominer was used for aggregation.

Adapting the per-image QC approach of Caldera et al. (Nature Communications).

### What this does
The aim is to select features with minimal noise (control variance) and maximal information (treatment variance), regardless of imaging batch.

* Divide the dataset into the two analyses:
    * 1) Dose-time GT data of 316 compounds (DT)
    * 2) 1uM & 24hr GT & CS of 316 compounds (OG316)


* On a per-dataset basis:
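    * 1) Calculate, per plate, the ratio of non-DMSO to DMSO well dispersion (MAD, both taken around the DMSO median)
    * 2) Low-noise filter: keep features whose ratio is > the plate's 10th pctl in > 90% of plates
    * 3) High-signal filter: keep features whose ratio is > the plate's 90th pctl in > 10% of plates
    * 4) Retain features passing both filters, then run mRMR feature selection (retain top 150)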
    

In [6]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import multiprocess as mp

from pymrmr import mRMR
from collections import Counter
from itertools import repeat
from scipy import special
from sklearn.cross_decomposition import CCA

# Session-wide settings: ignore numpy "invalid value" floating-point warnings
# (e.g. 0/0 -> nan) and disable pandas chained-assignment warnings.
np.seterr(invalid='ignore')
pd.options.mode.chained_assignment = None

# Input directory of the scaled, median-aggregated feature tables. File names
# are appended by plain string concatenation, hence the trailing slash.
batch_data = '../1_SCALED_median_aggregated/'

## Define Functions

#### MAD calculation

In [2]:
def calc_disp(table, table_name, ratio_cutoff, plot_dir='plots'):
    """Return features whose treated/DMSO dispersion ratio exceeds a quantile cutoff.

    For every feature column (any column whose name contains neither
    'Metadata' nor 'Number_of_Cells') the ratio

        MAD(non-DMSO wells, around the DMSO median) / MAD(DMSO wells)

    is computed, where MAD is the median absolute deviation. Features whose
    DMSO MAD is 0 get a ratio of 0. The cutoff is the ``ratio_cutoff``
    quantile of the finite, non-zero ratios. A feature passes when its ratio
    is strictly greater than that cutoff.

    Parameters
    ----------
    table : pandas.DataFrame
        Well-level table with a ``Metadata_perturbation`` column. Rows equal to
        'DMSO' are controls; every other row is treated as a treatment.
    table_name : str
        Label used in the histogram file name.
    ratio_cutoff : float
        Quantile (0-1) of the informative ratios used as the threshold.
    plot_dir : str or None, optional
        Directory that receives 'MADratio_<table_name>_CV.pdf'. Pass None to
        skip plotting.

    Returns
    -------
    list of str
        Passing feature names, in the column order of ``table``. The list is
        empty when no feature has a defined, non-zero ratio.
    """
    features = [col for col in table.columns
                if 'Metadata' not in col and 'Number_of_Cells' not in col]

    # Build the control mask once and compute all medians column-wise.
    is_dmso = (table.Metadata_perturbation == 'DMSO').to_numpy()
    values = table[features].to_numpy(dtype=float)
    dmso_vals = values[is_dmso]
    treat_vals = values[~is_dmso]

    # Treated dispersion is measured around the DMSO median (not the treated
    # median), so a shifted treated distribution also yields a high ratio.
    med_dmso = np.median(dmso_vals, axis=0)
    dmso_mad = np.median(np.abs(dmso_vals - med_dmso), axis=0)
    treat_mad = np.median(np.abs(treat_vals - med_dmso), axis=0)

    # Ratio stays 0 where the DMSO MAD is 0 (ratio undefined).
    ratios = np.zeros(len(features))
    np.divide(treat_mad, dmso_mad, out=ratios, where=dmso_mad != 0)

    # Only finite, non-zero ratios define the cutoff. A NaN would otherwise
    # make the cutoff NaN and reject every feature.
    informative = ratios[np.isfinite(ratios) & (ratios != 0)]
    if informative.size == 0:
        print('No informative dispersion ratios for %s' % table_name)
        return []

    cutoff = np.quantile(informative, ratio_cutoff)
    passing_feats = [f for f, r in zip(features, ratios) if r > cutoff]

    if plot_dir is not None:
        # Shift by 0.5 so that zero ratios fit on the log-scaled x axis.
        shifted = ratios[np.isfinite(ratios)] + 0.5
        _, bins = np.histogram(shifted, bins=100)
        logbins = np.logspace(np.log10(bins[0]), np.log10(bins[-1]), len(bins))
        fig, ax = plt.subplots()
        ax.hist(shifted, bins=logbins, color='grey')
        ax.set_xscale('log')
        ax.axvline(cutoff + 0.5, ls='--', c='red')
        # Rejected = every feature not returned, so legend + result == total.
        ax.legend(['Rejected Features: %d' % (len(features) - len(passing_feats))])
        fig.savefig(os.path.join(plot_dir, 'MADratio_' + table_name + '_CV.pdf'))
        plt.close(fig)

    return passing_feats


In [3]:
def pass_disp(table, batch, ratio_cutoff, freq_cutoff, condition = 'fraction'):
    """Return features that pass the per-plate dispersion-ratio filter often enough.

    ``calc_disp`` is run on each plate (``Metadata_Plate``) separately, and the
    number of plates in which each feature passes is counted. A feature is
    kept when that count is strictly greater than a frequency threshold:

    * ``condition='fraction'``: threshold = ``freq_cutoff * n_plates``
      (e.g. 0.9 -> passes in more than 90% of plates);
    * ``condition='quantile'``: threshold = the ``freq_cutoff`` quantile of
      all features' pass counts.

    A histogram of pass counts is saved to 'plots/<batch>_dratio_merge.pdf'.

    Parameters
    ----------
    table : pandas.DataFrame
        Well-level table with ``Metadata_Plate`` and ``Metadata_perturbation``
        columns plus feature columns.
    batch : str
        Label used in printed output and plot file names.
    ratio_cutoff : float
        Per-plate quantile forwarded to ``calc_disp``.
    freq_cutoff : float
        Fraction of plates, or quantile of pass counts, depending on ``condition``.
    condition : {'fraction', 'quantile'}
        How ``freq_cutoff`` is interpreted.

    Returns
    -------
    list of str
        Passing feature names, in the column order of ``table``.

    Raises
    ------
    ValueError
        If ``condition`` is neither 'fraction' nor 'quantile'.
    """
    # An unrecognised condition would otherwise leave freq_cutoff unscaled and
    # silently keep almost every feature.
    if condition not in ('fraction', 'quantile'):
        raise ValueError("condition must be 'fraction' or 'quantile', got %r" % (condition,))

    # unique() keeps first-appearance order, so per-plate plots are produced
    # deterministically (set() order is not).
    plates = list(table.Metadata_Plate.unique())
    features = [col for col in table.columns
                if 'Metadata' not in col and 'Number_of_Cells' not in col]

    # For every feature, count the plates on which it passes calc_disp.
    pass_counts = Counter()
    for p in plates:
        plate = table.loc[(table.Metadata_Plate == p), :]
        pass_counts.update(calc_disp(plate, (batch + '_' + p), ratio_cutoff))
    counts = np.array([pass_counts[f] for f in features], dtype=float)

    print('Calculating final passing feature list')

    if condition == 'fraction':
        threshold = freq_cutoff * len(plates)
    else:
        threshold = np.quantile(counts, freq_cutoff) if counts.size else 0.0

    CV_passing = [f for f, c in zip(features, counts) if c > threshold]

    fig, ax = plt.subplots()
    ax.hist(counts, bins='auto', color='grey')
    ax.axvline(threshold, ls='--', c='red')
    # Rejected = every feature with count <= threshold, i.e. all not returned.
    ax.legend(['Rejected Features: %d' % (len(features) - len(CV_passing))])
    fig.savefig('plots/' + batch + '_dratio_merge.pdf')
    plt.close(fig)

    print('# features passing for:')
    print(batch+' ratio: %d' %len(CV_passing))

    return CV_passing

#### Heatmap fn

In [4]:
def corr_heatmap(data, features, plotname):
    """Save a clustered heatmap of pairwise Pearson correlations between features.

    Parameters
    ----------
    data : pandas.DataFrame
        Table containing every column named in ``features``. Each row is one
        observation.
    features : list of str
        Feature columns to correlate. Matrix rows/columns follow this order.
    plotname : str
        Prefix of the output file 'plots/<plotname>_feature_CorrelationHeatMap.jpg'.

    Returns
    -------
    None
        Nothing is plotted when ``features`` is empty.
    """
    if len(features) == 0:
        # seaborn cannot cluster an empty matrix.
        print('No features to plot for %s' % plotname)
        return

    # A single vectorized call replaces len(features)**2 separate
    # np.corrcoef calls. rowvar=False treats each column as a variable, and
    # atleast_2d keeps the one-feature case a 1x1 matrix.
    values = data[list(features)].to_numpy(dtype=float)
    heatmap_data = np.atleast_2d(np.corrcoef(values, rowvar=False))

    sns.clustermap(data=heatmap_data, cmap="RdBu")
    plt.savefig('plots/'+ plotname +'_feature_CorrelationHeatMap.jpg')
    plt.close()

## Filter Datasets

In [11]:
# Dose-time (DT) plate IDs: 18 plates from GT run1 batch1, 4 from run1 batch2
# and 4 from run2 (26 in total).
DT_plates = (
    ['GT_run1_batch1_KI-CAS1_' + stamp for stamp in (
        '200211160001', '200211200001', '200211230001',
        '200212030001', '200212060001', '200212100001',
        '200212140001', '200212170001', '200212210001',
        '200213000001', '200213030001', '200213070001',
        '200213120001', '200213150001', '200213180001',
        '200213210001', '200214000001', '200214040001',
    )]
    + ['GT_run1_batch2_KI-CAS1_' + stamp for stamp in (
        '200721140001', '200721170001', '200722120001', '200724100001',
    )]
    + ['GT_run2_FDA_plate1_%d' % idx for idx in range(1, 5)]
)


#### Robust (non DMSO) scale

In [12]:
# Load the robust-scaled (non-DMSO) median-aggregated feature table.
# low_memory=False parses in one pass instead of pandas' chunked, possibly
# mixed-type dtype inference.
r_table_path = batch_data + '0920201_robscale_QC_median_all_feature_table.gz'
data = pd.read_csv(r_table_path, low_memory=False)

In [13]:
# Split the robust (non-DMSO) scaled table into the two analysis subsets.
is_dt_plate = data['Metadata_Plate'].isin(DT_plates)  # DT: dose-time plates
is_24h_1uM = data['Metadata_time_hr'].eq(24) & data['Metadata_conc_uM'].eq(1)  # OG316: 24 hr, 1 uM

r_DT_data = data.loc[is_dt_plate]
r_OG316_data = data.loc[is_24h_1uM]

# Unbind the full table and the temporary masks once both subsets exist.
del data, is_dt_plate, is_24h_1uM

#### Robust DMSO scale

In [16]:
# Load the DMSO-robust-scaled ("robscaleDMSO") median-aggregated feature table.
# low_memory=False parses in one pass instead of pandas' chunked, possibly
# mixed-type dtype inference.
rd_table_path = batch_data + '0920201_robscaleDMSO_QC_median_all_feature_table.gz'
data = pd.read_csv(rd_table_path, low_memory=False)

In [17]:
# Split the DMSO-robust-scaled table into the two analysis subsets.
is_dt_plate = data['Metadata_Plate'].isin(DT_plates)  # DT: dose-time plates
is_24h_1uM = data['Metadata_time_hr'].eq(24) & data['Metadata_conc_uM'].eq(1)  # OG316: 24 hr, 1 uM

rd_DT_data = data.loc[is_dt_plate]
rd_OG316_data = data.loc[is_24h_1uM]

# Unbind the full table and the temporary masks once both subsets exist.
del data, is_dt_plate, is_24h_1uM

## Run Feature Selection

#### Robust (non DMSO) scale

In [21]:
data = r_DT_data
metadata_all = ['Number_of_Cells']+[col for col in data.columns if 'Metadata' in col]

# Low-noise filter: ratio above the plate's 10th percentile in > 90% of plates.
ratio_pass_noise = pass_disp(data, batch = 'r_DT_low_noise', 
                             ratio_cutoff = 0.10, 
                             freq_cutoff = 0.90)

# High-signal filter: ratio above the plate's 90th percentile in > 10% of plates.
# ('singal' typo kept so existing plot file names stay unchanged.)
ratio_pass_signal = pass_disp(data, batch = 'r_DT_high_singal', 
                              ratio_cutoff = 0.90, 
                              freq_cutoff = 0.10)

# Features passing both filters. Walk the noise list (column order) instead of
# a set so the output column order is reproducible: set order of strings
# depends on per-process hash randomization.
signal_set = set(ratio_pass_signal)
ratio_pass_both = [f for f in ratio_pass_noise if f in signal_set]

print('# of features passing noise & signal: %d' %len(ratio_pass_both))

data_both = data[ratio_pass_both+metadata_all]

data_both.to_csv('09242021_QC_both_r_DT_feature_table.gz', index=False, compression='gzip')

corr_heatmap(data, ratio_pass_both, '09242021_r_DT_data_both')

Calculating final passing feature list
# features passing for:
r_DT_low_noise ratio: 1859
Calculating final passing feature list
# features passing for:
r_DT_high_singal ratio: 1254
# of features passing noise & signal: 861


In [22]:
data = r_OG316_data
metadata_all = ['Number_of_Cells']+[col for col in data.columns if 'Metadata' in col]

# Low-noise filter: ratio above the plate's 10th percentile in > 90% of plates.
ratio_pass_noise = pass_disp(data, batch = 'r_OG316_low_noise', 
                             ratio_cutoff = 0.10, 
                             freq_cutoff = 0.90)

# High-signal filter: ratio above the plate's 90th percentile in > 10% of plates.
# ('singal' typo kept so existing plot file names stay unchanged.)
ratio_pass_signal = pass_disp(data, batch = 'r_OG316_high_singal', 
                              ratio_cutoff = 0.90, 
                              freq_cutoff = 0.10)

# Features passing both filters. Walk the noise list (column order) instead of
# a set so the output column order is reproducible: set order of strings
# depends on per-process hash randomization.
signal_set = set(ratio_pass_signal)
ratio_pass_both = [f for f in ratio_pass_noise if f in signal_set]

print('# of features passing noise & signal: %d' %len(ratio_pass_both))

data_both = data[ratio_pass_both+metadata_all]

data_both.to_csv('09242021_QC_both_r_OG316_feature_table.gz', index=False, compression='gzip')

corr_heatmap(data, ratio_pass_both, '09242021_r_OG316_data_both')

Calculating final passing feature list
# features passing for:
r_OG316_low_noise ratio: 2117
Calculating final passing feature list
# features passing for:
r_OG316_high_singal ratio: 1054
# of features passing noise & signal: 886


#### Robust DMSO scale

In [24]:
data = rd_DT_data
metadata_all = ['Number_of_Cells']+[col for col in data.columns if 'Metadata' in col]

# Low-noise filter: ratio above the plate's 10th percentile in > 90% of plates.
ratio_pass_noise = pass_disp(data, batch = 'rd_DT_low_noise', 
                             ratio_cutoff = 0.10, 
                             freq_cutoff = 0.90)

# High-signal filter: ratio above the plate's 90th percentile in > 10% of plates.
# ('singal' typo kept so existing plot file names stay unchanged.)
ratio_pass_signal = pass_disp(data, batch = 'rd_DT_high_singal', 
                              ratio_cutoff = 0.90, 
                              freq_cutoff = 0.10)

# Features passing both filters. Walk the noise list (column order) instead of
# a set so the output column order is reproducible: set order of strings
# depends on per-process hash randomization.
signal_set = set(ratio_pass_signal)
ratio_pass_both = [f for f in ratio_pass_noise if f in signal_set]

print('# of features passing noise & signal: %d' %len(ratio_pass_both))

data_both = data[ratio_pass_both+metadata_all]

data_both.to_csv('09242021_QC_both_rd_DT_feature_table.gz', index=False, compression='gzip')

corr_heatmap(data, ratio_pass_both, '09242021_rd_DT_data_both')

Calculating final passing feature list
# features passing for:
rd_DT_low_noise ratio: 1859
Calculating final passing feature list
# features passing for:
rd_DT_high_singal ratio: 1254
# of features passing noise & signal: 861


In [25]:
data = rd_OG316_data
metadata_all = ['Number_of_Cells']+[col for col in data.columns if 'Metadata' in col]

# Low-noise filter: ratio above the plate's 10th percentile in > 90% of plates.
ratio_pass_noise = pass_disp(data, batch = 'rd_OG316_low_noise', 
                             ratio_cutoff = 0.10, 
                             freq_cutoff = 0.90)

# High-signal filter: ratio above the plate's 90th percentile in > 10% of plates.
# ('singal' typo kept so existing plot file names stay unchanged.)
ratio_pass_signal = pass_disp(data, batch = 'rd_OG316_high_singal', 
                              ratio_cutoff = 0.90, 
                              freq_cutoff = 0.10)

# Features passing both filters. Walk the noise list (column order) instead of
# a set so the output column order is reproducible: set order of strings
# depends on per-process hash randomization.
signal_set = set(ratio_pass_signal)
ratio_pass_both = [f for f in ratio_pass_noise if f in signal_set]

print('# of features passing noise & signal: %d' %len(ratio_pass_both))

data_both = data[ratio_pass_both+metadata_all]

data_both.to_csv('09242021_QC_both_rd_OG316_feature_table.gz', index=False, compression='gzip')

corr_heatmap(data, ratio_pass_both, '09242021_rd_OG316_data_both')

Calculating final passing feature list
# features passing for:
rd_OG316_low_noise ratio: 2117
Calculating final passing feature list
# features passing for:
rd_OG316_high_singal ratio: 1054
# of features passing noise & signal: 886


## mRMR filtering

#### Robust DMSO scale

All mRMR runs were performed on an AWS C5 instance; see mRMR_batch.py.