In [None]:
# Short alias for the `snakemake` object, presumably injected by Snakemake when this
# notebook runs as part of a rule (its .input / .output are read below).
sm = snakemake

In [None]:
import pandas as pd

import spherpro.bro as spb
import spherpro.datastore as spd
import spherpro.library as spl
import spherpro.configuration as conf
import spherpro.db as db

from imctools.scripts import exportacquisitioncsv
import imp
import pycytools as pct
import pycytools.library
import re
import os
import pathlib
import dateutil

import numpy as np
import spherpro.library as lib
import matplotlib.pyplot as plt
import plotnine as gg
import seaborn as sns
from matplotlib import colors
#import mpld3


In [None]:
from src.variables import Vars 

In [None]:
def get_valid_filename(s):
    """Return a version of `s` that is safe to use as a file name.

    The value is converted to `str` and stripped of surrounding whitespace.
    Inner spaces become underscores, and every character that is not a word
    character, '-' or '.' is removed.
    """
    cleaned = str(s).strip()
    cleaned = cleaned.replace(' ', '_')
    # Keep only word characters, '-' and '.'; drop everything else.
    return re.sub(r'(?u)[^-\w.]', '', cleaned)

In [None]:
class C:
    # Notebook configuration: paths taken from the Snakemake rule's input/output.
    # Further attributes (VAR_INT, VAR_NB) are attached to this class later in the notebook.
    # Folder that all figures of this notebook are saved into.
    fol_plots = pathlib.Path(sm.output.fol_plots)
    # spherpro configuration file used to open the database object (`bro`).
    fn_config =  pathlib.Path(sm.input.fn_config)

In [None]:
# Create the plot output folder, including any missing parent folders.
# This is a no-op if the folder already exists.
C.fol_plots.mkdir(parents=True, exist_ok=True)

In [None]:
class V(Vars):
    COL_GFP1 = 'Tm169'
    COL_GFP2 = 'Er167'
    VAR_STACK = 'FullStackFiltered'


In [None]:
# Open the spherpro analysis/database object from the config file; all queries below go through it.
bro = spb.get_bro(C.fn_config)

In [None]:
import importlib

import spherpro.bromodules.helpers_vz as helpers_vz

# Reload so that edits to helpers_vz are picked up without restarting the kernel.
# importlib.reload replaces the deprecated imp.reload (`imp` was removed in Python 3.12).
importlib.reload(helpers_vz)
hpr = helpers_vz.HelperVZ(bro)

Query the measurement metadata

In [None]:
# Measurement names to load: the per-cell intensity plus the "Nb*" variants
# (presumably neighbourhood aggregates of the same intensity — confirm in spherpro).
# NOTE(review): a single tuple inside a list — presumably one filter group matching
# any of these names; confirm against get_measmeta_filter_statements.
meas = [('MeanIntensityComp', 'NbMeanMeanIntensityComp', 'NbMaxMeanIntensityComp','NbMeanNormMeanIntensityComp')]
# Use the stack name from V rather than repeating the literal here.
fil = bro.filters.measurements.get_measmeta_filter_statements(
    channel_names=[None],
    stack_names=[V.VAR_STACK],
    measurement_names=meas,
    measurement_types=[None])
    

In [None]:
%%time
# Run the measurement-metadata query restricted by `fil`, adding the stack name,
# channel name and stack scale as extra columns (result is a DataFrame via bro.doquery).
dat_measmeta = bro.doquery(bro.data.get_measmeta_query().filter(fil)
                              .add_columns(db.stacks.stack_name,
                                           db.ref_planes.channel_name, db.ref_stacks.scale))

In [None]:
# Join the panel table onto each measurement (metal/channel name as key).
# Filtering to working channels is disabled here; see `good_channels` further down.
# NOTE: not idempotent — re-running merges the panel again and suffixes duplicate columns.
dat_measmeta = (dat_measmeta
    .merge(bro.data.pannel, left_on=V.COL_CHANNELNAME, right_on=V.COL_METAL)
    #.query(f'{V.COL_WORKING} == True')
               )

In [None]:
# Display the annotated measurement metadata.
dat_measmeta

In [None]:
import anndata

In [None]:
# Select all objects of type 'cell'.
q_obj = bro.data.get_objectmeta_query().filter(db.objects.object_type == 'cell')

In [None]:
# Load the selected measurements for these cells into an AnnData-like object
# (used below via .X, .obs, .var and .obs_vector).
dat = bro.io.objmeasurements.get_measurements(q_obj=q_obj, dat_meas=dat_measmeta)
# Scale the values; the return value is not used, so presumably this works in place — confirm.
bro.io.objmeasurements.scale_anndata(dat)

In [None]:
# I added consoring as there were hugh outliers in the data
def censor_dat(x, q=99.9):
    x = np.copy(x)
    pmax = np.percentile(x,q=q)
    x[ x > pmax ] = pmax
    pmin = np.percentile(x,q=100-q)
    x[x < pmin] = pmin
    return x

def cur_logtransf(x):
    return np.log10(x+0.1)

def cur_transf(x):
    x= censor_dat(x, 99.9)
    x= cur_logtransf(x)
    return x

In [None]:
# Censor and log-transform each column of dat.X independently (1-D slices along axis 0).
# NOTE: not idempotent — re-running this cell transforms the already transformed values again.
dat.X = np.apply_along_axis(cur_transf, 0, dat.X)

In [None]:
import scanpy as sc


In [None]:
# Per-image metadata: images joined with condition, acquisition, site and slide tables;
# the inner join with valid_images keeps only valid images.
q= (bro.session.query(db.images, db.conditions, db.acquisitions, db.sites, db.slideacs, db.slides)
    .join(db.conditions)
    .join(db.acquisitions)
    .join(db.sites)
    .join(db.slideacs)
    .join(db.slides)
    .join(db.valid_images)
)
dat_condition = bro.doquery(q)

In [None]:
# Load the experiment layout; presumably fills bro.data.experiment_layout used next — confirm.
bro.data._read_experiment_layout()

In [None]:
# Add the experiment-layout columns (merge on all shared column names).
dat_condition = dat_condition.merge(bro.data.experiment_layout)

In [None]:
# The multi-table query yields repeated column names; keep the first occurrence of each.
dat_condition = dat_condition.loc[:,~dat_condition.columns.duplicated()]

In [None]:
# Distance-to-rim value per object.
dat_d2rim = hpr.get_d2rim()

In [None]:
# Attach the distance-to-rim to the cell metadata, keyed by object id.
bro.helpers.anndata.add_anndata_obsmeta(dat, dat_d2rim.loc[:, [V.COL_OBJID, V.COL_D2RIM]])

In [None]:
# Attach the image/condition metadata to the cell metadata.
bro.helpers.anndata.add_anndata_obsmeta(dat, dat_condition)

In [None]:
# Measurement names for the x (cell intensity) and y (neighbourhood) axes of the plots.
# These must be set before the plotting helpers are defined: they are used as default arguments.
C.VAR_INT = 'MeanIntensityComp'
C.VAR_NB = 'NbMeanMeanIntensityComp'
# Channels flagged as working in the panel (not referenced further in this notebook section).
good_channels = sorted(dat.var.query(f'{V.COL_WORKING}==1')[V.COL_CHANNELNAME].unique())

In [None]:
# dat.var index of the GFP channel's intensity measurement (x axis) ...
gfp_int = dat.var.query(f'({db.measurement_names.measurement_name.key}=="{C.VAR_INT}")&({db.ref_planes.channel_name.key} == "{V.COL_GFP1}")').index[0]

# ... and of its neighbourhood measurement (y axis).
gfp_nb = dat.var.query(f'({db.measurement_names.measurement_name.key}=="{C.VAR_NB}")&({db.ref_planes.channel_name.key} == "{V.COL_GFP1}")').index[0]

In [None]:
def plt_relation(dat, varx, vary, varval, gridsize=80, clim=None,extent=None, ax=None,
                 colorbar=True, contour=False,contour_levels=5,contour_color='k',contour_linewidths=1, **kwargs):
    """Hexbin plot of `vary` against `varx`, colored by the per-bin aggregate of `varval`.

    Parameters
    ----------
    dat : object providing ``obs_vector(key)`` (e.g. AnnData)
        `varx`, `vary` and `varval` are keys understood by ``dat.obs_vector``.
    gridsize, clim, extent
        Passed to ``Axes.hexbin``. The bin values are reduced with hexbin's
        ``reduce_C_function`` (mean by default).
    ax : matplotlib Axes, optional
        Axes to draw on. A new figure/axes is created when None.
    colorbar : bool
        Add a colorbar attached to `ax`.
    contour : bool
        Overlay contour lines through the hexbin centres, based on the bin values.
    contour_levels, contour_color, contour_linewidths
        Styling of the contour lines.
    **kwargs
        Forwarded to ``Axes.hexbin``.

    Returns
    -------
    The collection returned by ``Axes.hexbin``.
    """
    if ax is None:
        plt.figure()
        ax = plt.gca()
    p = ax.hexbin(dat.obs_vector(varx), dat.obs_vector(vary), C=dat.obs_vector(varval),
                  gridsize=gridsize, clim=clim, extent=extent, **kwargs)
    ax.set_aspect('equal', adjustable='box')
    if contour:
        # Contour through the centres of the non-empty hexagons, using their values.
        points = p.get_offsets()
        vals = p.get_array()
        ax.tricontour(points[:, 0], points[:, 1], vals, colors=contour_color,
                      levels=contour_levels, linewidths=contour_linewidths,
                      alpha=1, linestyles='solid')
    if colorbar:
        # Attach the colorbar to the figure/axes actually drawn on. plt.colorbar would
        # use the *current* figure/axes, which need not be `ax` when `ax` is passed in.
        ax.figure.colorbar(p, ax=ax)
    return p

In [None]:
# Color variable: the C.VAR_INT measurement of channel Yb171.
val = dat.var.query(f'(\
                       {db.measurement_names.measurement_name.key}=="{C.VAR_INT}")\
                        &({db.ref_planes.channel_name.key} == "Yb171")').index[0]

In [None]:
# Cells of the 'Empty_nan' condition on a fixed extent; marginals=True also draws hexbin marginal bars.
p = plt_relation(dat[dat.obs.query(f'{V.COL_CONDNAME}=="Empty_nan"').index], gfp_int, gfp_nb, val,
             gridsize=20, extent=(-1,2,-1,2), marginals=True)

In [None]:
# Same relation for cells of the FGFR1 over-expression condition.
p = plt_relation(dat[dat.obs.query(f'{V.COL_CONDNAME}=="FGFR1_GFP-FLAG"').index], gfp_int, gfp_nb, val, gridsize=20)

In [None]:
# Shapes of the hexagon centres and of the per-hexagon values of the last plot.
print(p.get_offsets().shape)
print(p.get_array().shape)

In [None]:
import scipy as sp

In [None]:
np.corrcoef(np.hstack([p.get_offsets()[:,0], p.get_array()]))

In [None]:
# Largest x position of a non-empty hexagon in the last plot.
p.get_offsets()[:,0].max()

In [None]:
# Hexagon centres colored by their values, as a plain scatter plot.
plt.scatter(p.get_offsets()[:,0], p.get_offsets()[:,1],c=p.get_array())

In [None]:
# MEK1 DD condition, colored by the current `val` (Yb171 from the cell above).
# The displayed tuple holds Spearman correlations of the hexagon values with their
# x position (GFP intensity) and y position (GFP neighbourhood).
p = plt_relation(dat[dat.obs.query(f'{V.COL_CONDNAME}=="MEK1 (S218D/S222D)_GFP-FLAG"').index],
             gfp_int, gfp_nb, val, gridsize=20)
(sp.stats.spearmanr(p.get_offsets()[:,0], p.get_array()),
                   sp.stats.spearmanr(p.get_offsets()[:,1], p.get_array()))

In [None]:
# Same analysis colored by channel Sm154.
# NOTE(review): copy of the previous cell with only the channel changed — candidate for a helper.
val = dat.var.query(f'(\
                       {db.measurement_names.measurement_name.key}=="{C.VAR_INT}")\
                        &({db.ref_planes.channel_name.key} == "Sm154")').index[0]
p = plt_relation(dat[dat.obs.query(f'{V.COL_CONDNAME}=="MEK1 (S218D/S222D)_GFP-FLAG"').index],
             gfp_int, gfp_nb, val, gridsize=20)
(sp.stats.spearmanr(p.get_offsets()[:,0], p.get_array()),
                   sp.stats.spearmanr(p.get_offsets()[:,1], p.get_array()))

In [None]:
# Same analysis colored by channel Yb171; `p` from this cell is used by the cells below.
val = dat.var.query(f'(\
                       {db.measurement_names.measurement_name.key}=="{C.VAR_INT}")\
                        &({db.ref_planes.channel_name.key} == "Yb171")').index[0]
p = plt_relation(dat[dat.obs.query(f'{V.COL_CONDNAME}=="MEK1 (S218D/S222D)_GFP-FLAG"').index],
             gfp_int, gfp_nb, val, gridsize=20)
(sp.stats.spearmanr(p.get_offsets()[:,0], p.get_array()),
                   sp.stats.spearmanr(p.get_offsets()[:,1], p.get_array()))



In [None]:
# First axes of the figure holding the last hexbin plot (not used further below).
ax =p.get_figure().axes[0]

In [None]:
# Re-display the figure of the last hexbin plot.
p.get_figure()

In [None]:
# Hexagon centres and their values from the last hexbin plot.
points = p.get_offsets()
vals = p.get_array()

# A regular 10-point grid spanning the data range along each axis (x first, then y).
ax_range = [np.linspace(lo, hi, 10) for lo, hi in zip(points.min(axis=0), points.max(axis=0))]


In [None]:
# Grid values along x.
ax_range[0]

In [None]:
# Flattened x/y coordinates of the 10x10 grid.
x,y = (v.flatten() for v in np.meshgrid(ax_range[0], ax_range[1]))

In [None]:
# Cubic interpolation of the hexagon values onto the grid; grid points outside the
# convex hull of the hexagon centres get griddata's default fill value (NaN).
c=sp.interpolate.griddata(points, values=vals, xi=np.vstack((x.flatten(), y.flatten())).T, method='cubic' )

In [None]:
# Raw hexagon centres colored by their values, for comparison.
plt.scatter(points[:,0],points[:,1],c=vals )

In [None]:
# Keep only grid points with a finite interpolated value.
# NOTE: this reuses the name `fil`, which earlier held the measurement filter.
fil=np.isfinite(c)
x, y, c=list(v[fil] for v in (x,y,c))

In [None]:
# Contour lines of the interpolated surface.
plt.tricontour(x,y,c,colors='k')

In [None]:
# MEK1 DD condition colored by channel Nd144. The displayed pair holds the Pearson
# correlations of the hexagon values with their x and y positions.
val = dat.var.query(f'(\
                       {db.measurement_names.measurement_name.key}=="{C.VAR_INT}")\
                        &({db.ref_planes.channel_name.key} == "Nd144")').index[0]
p = plt_relation(dat[dat.obs.query(f'{V.COL_CONDNAME}=="MEK1 (S218D/S222D)_GFP-FLAG"').index],
             gfp_int, gfp_nb, val, gridsize=20)
np.corrcoef(p.get_offsets()[:,0], p.get_array())[0,1], np.corrcoef(p.get_offsets()[:,1], p.get_array())[0,1]

In [None]:
def plt_hexbin_cond(dat, val, cond, ax=None, colorbar=True,titlesize=8, measure_name_int=C.VAR_INT, measure_name_nb=C.VAR_NB, oexp_channel_name=V.COL_GFP1, **kwargs):
    """Hexbin plot of over-expression intensity vs. its neighbourhood measure, colored by `val`.

    The x axis is `measure_name_int` and the y axis is `measure_name_nb`, both of channel
    `oexp_channel_name`. Each hexagon is colored by the values of measurement `val`
    (see plt_relation). The axis extent and color limits are computed from *all* cells
    in `dat` before restricting to `cond`, so panels for different conditions share scales.

    Parameters
    ----------
    dat : AnnData-like cells x measurements object (.var, .obs, .obs_vector, subsetting).
    val : dat.var index of the measurement used for coloring.
    cond : condition name to restrict to (matched on V.COL_CONDNAME); None plots all cells.
    ax : Axes to draw on; a new figure is created when None.
    colorbar : whether to add a colorbar.
    titlesize : font size of the title, which is set to `cond` when given.
    measure_name_int, measure_name_nb : measurement names for the x and y axes.
    oexp_channel_name : channel of the over-expression readout.
    **kwargs : forwarded to plt_relation (and from there to Axes.hexbin).

    Returns
    -------
    The Axes drawn on.
    """
    if ax is None:
        plt.figure()
        ax = plt.gca()
    name_key = db.measurement_names.measurement_name.key
    chan_key = db.ref_planes.channel_name.key
    gfp_int = dat.var.query(f'({name_key}=="{measure_name_int}")&({chan_key} == "{oexp_channel_name}")').index[0]
    gfp_nb = dat.var.query(f'({name_key}=="{measure_name_nb}")&({chan_key} == "{oexp_channel_name}")').index[0]

    x = dat.obs_vector(gfp_int)
    xlims = [x.min(), x.max()]
    # NOTE(review): the y range is taken from the x (intensity) variable as well. This
    # gives a square frame that matches set_aspect('equal') in plt_relation. Neighbourhood
    # values outside this range would not be binned by hexbin — confirm this is intended.
    extent = xlims + xlims
    c = dat.obs_vector(val)
    clim = (c.min(), c.max())

    if cond is not None:
        dat = dat[dat.obs.query(f'{V.COL_CONDNAME}=="{cond}"').index]
    # Plot in every case; with cond=None all cells are shown.
    plt_relation(dat, gfp_int, gfp_nb, val, extent=extent, ax=ax, colorbar=colorbar, clim=clim,
                 **kwargs)
    if cond is not None:
        ax.set_title(cond, fontsize=titlesize)
    return ax
        

In [None]:
def wrapsubplots(total, wrap=None, **kwargs):
    """Create a figure with `total` subplots, wrapping into a new row after `wrap` columns.

    Parameters
    ----------
    total : int
        Number of panels needed (must be >= 1).
    wrap : int, optional
        Maximum number of columns. None puts all panels in a single row.
    **kwargs
        Forwarded to ``plt.subplots`` (e.g. figsize).

    Returns
    -------
    (fig, axs)
        `axs` is always a numpy array, so ``axs.flatten()`` also works for a 1x1
        grid (where plt.subplots returns a bare Axes). Any extra grid cells in the
        last row are left for the caller to remove.
    """
    if wrap is not None:
        cols = min(total, wrap)
        rows = 1 + (total - 1) // wrap
    else:
        cols = total
        rows = 1
    fig, ax = plt.subplots(rows, cols, **kwargs)
    return fig, np.atleast_1d(ax)

In [None]:
def hexbin_all_conds(dat, channel_name, gridsize=20, figsize=(30,30), measure_name_nb=C.VAR_NB,**kwargs):
    """Grid of hexbin plots for one marker channel, with one panel per condition.

    Each panel is drawn by plt_hexbin_cond: x is C.VAR_INT and y is `measure_name_nb`
    of the over-expression channel, and the color is C.VAR_INT of `channel_name`.
    Panels wrap after 7 columns and share a single colorbar.

    Parameters
    ----------
    dat : AnnData-like cells x measurements object (.obs, .var, .obs_vector).
    channel_name : channel whose intensity colors the hexagons.
    gridsize : hexbin grid size per panel.
    figsize : figure size.
    measure_name_nb : measurement name used for the y axis.
    **kwargs : forwarded to plt_hexbin_cond / plt_relation.

    Returns
    -------
    The matplotlib Figure.

    Raises
    ------
    ValueError
        If dat.obs contains no conditions.
    """
    conds = sorted(dat.obs[V.COL_CONDNAME].unique())
    if not conds:
        raise ValueError('dat.obs contains no conditions to plot')
    fig, axs = wrapsubplots(len(conds), 7, figsize=figsize)
    # np.ravel also copes with a single (non-array) Axes for a 1x1 grid.
    axs = np.ravel(axs)
    name_key = db.measurement_names.measurement_name.key
    chan_key = db.ref_planes.channel_name.key
    val = dat.var.query(f'({name_key}=="{C.VAR_INT}")&({chan_key} == "{channel_name}")').index[0]
    for cond, panel in zip(conds, axs):
        plt_hexbin_cond(dat, val, cond, ax=panel, colorbar=False, gridsize=gridsize,
                        measure_name_nb=measure_name_nb, **kwargs)
    # Invisible overlay axes spanning the figure; it only carries the shared axis labels.
    big_ax = fig.add_subplot(111, frameon=False)
    # One colorbar for all panels. plt_hexbin_cond takes the color limits from the full
    # `dat` (not the condition subset), so every panel uses the same color scale.
    fig.colorbar(axs[len(conds) - 1].collections[0], ax=fig.axes)
    # tick_params expects booleans; the string 'off' is truthy and would keep the ticks.
    big_ax.tick_params(labelcolor='none', top=False, bottom=False, left=False, right=False)
    big_ax.grid(False)
    big_ax.set_xlabel(C.VAR_INT)
    big_ax.set_ylabel(measure_name_nb)
    # Remove the unused grid cells after the last condition.
    for unused in axs[len(conds):]:
        fig.delaxes(unused)
    fig.suptitle(f'{bro.helpers.dbhelp.get_target_by_channel(channel_name)} - {channel_name}')
    return fig

In [None]:
def hexbin_all_chans(dat, condition, gridsize=20, figsize=(30,30),titlesize=8, channels=None,measure_name_nb=C.VAR_NB, **kwargs):
    """Grid of hexbin plots for one condition, with one panel per channel.

    Each panel is drawn by plt_hexbin_cond: x is C.VAR_INT and y is `measure_name_nb`
    of the over-expression channel, and the color is C.VAR_INT of the panel's channel.
    Panels wrap after 7 columns, and each panel gets its own colorbar.

    Parameters
    ----------
    dat : AnnData-like cells x measurements object (.obs, .var, .obs_vector).
    condition : condition name passed to plt_hexbin_cond.
    gridsize : hexbin grid size per panel.
    figsize : figure size.
    titlesize : font size of the panel titles.
    channels : channel names to plot; defaults to all channels in dat.var, sorted.
    measure_name_nb : measurement name used for the y axis.
    **kwargs : forwarded to plt_hexbin_cond / plt_relation.

    Returns
    -------
    The matplotlib Figure.
    """
    name_key = db.measurement_names.measurement_name.key
    chan_key = db.ref_planes.channel_name.key
    # List channels from the same column that the measurement lookup below queries.
    chans = sorted(dat.var[chan_key].unique()) if channels is None else channels
    fig, axs = wrapsubplots(len(chans), 7, figsize=figsize)
    # np.ravel also copes with a single (non-array) Axes for a 1x1 grid.
    axs = np.ravel(axs)

    for chan, panel in zip(chans, axs):
        val = dat.var.query(f'({name_key}=="{C.VAR_INT}")&({chan_key} == "{chan}")').index[0]
        plt_hexbin_cond(dat, val, condition, ax=panel, colorbar=False, gridsize=gridsize,
                        measure_name_nb=measure_name_nb, **kwargs)
        fig.colorbar(panel.collections[0], ax=panel)
        panel.set_title(f'{bro.helpers.dbhelp.get_target_by_channel(chan)} - {chan}', fontsize=titlesize)
    # Invisible overlay axes spanning the figure; it only carries the shared axis labels.
    big_ax = fig.add_subplot(111, frameon=False)
    # tick_params expects booleans; the string 'off' is truthy and would keep the ticks.
    big_ax.tick_params(labelcolor='none', top=False, bottom=False, left=False, right=False)
    big_ax.grid(False)
    big_ax.set_xlabel(C.VAR_INT)
    big_ax.set_ylabel(measure_name_nb)
    # Remove the unused grid cells after the last channel.
    for unused in axs[len(chans):]:
        fig.delaxes(unused)
    fig.suptitle(condition)
    return fig

In [None]:
def hexbin_per_plate(dat, condition, channel_name, plates=None, gridsize=20, figsize=(30,30),titlesize=8, channels=None,measure_name_nb=C.VAR_NB, **kwargs):
    """One row of hexbin plots for a condition and channel, with one panel per plate.

    Each panel is drawn by plt_hexbin_cond on the cells of one plate. All panels share
    one color scale (min/max of the channel over all of `dat`) and a single colorbar.

    Parameters
    ----------
    dat : AnnData-like cells x measurements object (.obs with plate_id, .var, .obs_vector).
    condition : condition name passed to plt_hexbin_cond.
    channel_name : channel whose intensity colors the hexagons.
    plates : plate ids to plot; defaults to all plate ids in dat.obs, sorted.
    gridsize : hexbin grid size per panel.
    figsize : figure size.
    titlesize : font size of the panel titles.
    channels : unused; kept for call compatibility.
    measure_name_nb : measurement name used for the y axis.
    **kwargs : forwarded to plt_hexbin_cond / plt_relation.

    Returns
    -------
    The matplotlib Figure.
    """
    if plates is None:
        plates = sorted(dat.obs.plate_id.unique())
    name_key = db.measurement_names.measurement_name.key
    chan_key = db.ref_planes.channel_name.key
    # The colored measurement is the same for every plate, so look it up once.
    val = dat.var.query(f'({name_key}=="{C.VAR_INT}")&({chan_key} == "{channel_name}")').index[0]
    # plt_hexbin_cond scales colors to the data it is given (here: a single plate). Use
    # limits over all plates so that the single shared colorbar is valid for every panel.
    c_all = dat.obs_vector(val)
    shared_clim = (c_all.min(), c_all.max())

    # A single row with one column per plate, so there are no unused grid cells.
    fig, axs = wrapsubplots(len(plates), len(plates), figsize=figsize)
    # np.ravel also copes with a single (non-array) Axes for one plate.
    axs = np.ravel(axs)
    for plate, panel in zip(plates, axs):
        plt_hexbin_cond(dat[dat.obs.plate_id == plate, :], val, condition, ax=panel, colorbar=False,
                        gridsize=gridsize, measure_name_nb=measure_name_nb, **kwargs)
        panel.collections[0].set_clim(*shared_clim)
        panel.set_title(f'Plate {plate}', fontsize=titlesize)
    # Invisible overlay axes spanning the figure; it only carries the shared axis labels.
    big_ax = fig.add_subplot(111, frameon=False)
    fig.colorbar(axs[len(plates) - 1].collections[0], ax=fig.axes)
    # tick_params expects booleans; the string 'off' is truthy and would keep the ticks.
    big_ax.tick_params(labelcolor='none', top=False, bottom=False, left=False, right=False)
    big_ax.grid(False)
    big_ax.set_xlabel(C.VAR_INT)
    big_ax.set_ylabel(measure_name_nb)

    fig.suptitle(f'{condition} - {bro.helpers.dbhelp.get_target_by_channel(channel_name)} - {channel_name}')
    return fig

Plot MEK1 DD (S218D/S222D) example: selected markers vs. GFP over-expression and its neighbourhood maximum

In [None]:
# Example condition (MEK1 S218D/S222D mutant) and the marker channels shown for it.
cur_cond =  'MEK1 (S218D/S222D)_GFP-FLAG'
cur_marks = ['Tm169', 'Nd144', 'Sm154', 'Yb171', 'Nd143']

In [None]:
# All condition names present in the cell metadata.
dat.obs.condition_name.unique()

In [None]:
# One panel per example marker for the MEK1 DD condition. The y axis uses the
# neighbourhood max instead of the mean, and black contour lines are overlaid.
fig = hexbin_all_chans(dat, cur_cond, channels=cur_marks,
                      figsize=(12, 2), gridsize=15,measure_name_nb='NbMaxMeanIntensityComp',
                      contour=True, contour_color='k',contour_linewidths=1)

# Leave room at the top for the suptitle, then save as vector and raster images.
fig.subplots_adjust(top=0.8)
fig.savefig(C.fol_plots / 'example_MEK1_DD.svg')
fig.savefig(C.fol_plots / 'example_MEK1_DD.png')

In [None]:
# Save one figure per channel (all conditions) to disk without displaying them.
plt.ioff()
try:
    for chan in dat.var.channel_name.unique():
        fig = hexbin_all_conds(dat, chan, figsize=(20, 20), gridsize=10,
                               measure_name_nb='NbMaxMeanIntensityComp')
        target = bro.helpers.dbhelp.get_target_by_channel(chan)
        fig.savefig(C.fol_plots / f'oexp_vs_nb_{chan}_{get_valid_filename(target)}_{V.COL_GFP1}.png')
        # Close this specific figure so memory does not grow over the loop.
        plt.close(fig)
finally:
    # Restore interactive mode even if a plot fails.
    plt.ion()
    

Plot one overview per plate

# Make it per marker, construct and plate

-> Plot all replicates next to each other:


Without mitosis/apoptosis