In [None]:
import pandas as pd
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns

from pathlib import Path
import re
import utils
import functools

%load_ext autoreload
%autoreload 2

In [None]:
# each pipeline defines its own read_data so the parsing details stay in one place
def read_data(directory, regex=None, dfc=True, bins=4, debug_regex=False):
    """Build the combined measurement table for one pipeline output folder.

    ``Image.csv``, ``DilatedGC.csv``, ``InitialGC.csv`` and ``InitialFC.csv``
    are read from ``directory`` in that order through ``utils.analyze``; each
    call receives the previous call's result as ``previous_result``.

    Parameters
    ----------
    directory : str or Path
        Folder holding the four CSV files named above.
    regex : str, optional
        Pattern passed to ``utils.ImageParser`` to parse information from the
        image filename.
    dfc, bins
        Accepted but not referenced anywhere in this function body.
    debug_regex : bool
        Forwarded to ``utils.ImageParser``.

    Returns
    -------
    tuple
        ``(result, extra)`` where ``result`` is the first value returned by the
        final (``InitialFC.csv``) ``utils.analyze`` call and ``extra`` is the
        second value returned by the ``InitialGC.csv`` call.
    """
    root = Path(directory)
    # keys shared by every reduced merge below
    parent_keys = ['ImageNumber', 'Parent_DilatedGC']

    def reduced_options(**merge_kwargs):
        # keyword bundle for utils.analyze; merge_kwargs go to utils.merge_reduced
        return {
            'extra_columns': parent_keys,
            'reduce': True,
            'merge_fcn': functools.partial(
                utils.merge_reduced, merge_on=parent_keys, **merge_kwargs
            ),
        }

    gc_options = reduced_options()
    fc_options = reduced_options(how='left')

    # Filename metadata, extracted with the caller's regex
    image_table, _ = utils.analyze(
        root / 'Image.csv',
        parsers=[utils.ImageParser(regex, debug_regex=debug_regex)],
    )

    # Add the DilatedGC object numbers; they become the key for later merges
    parent_table, _ = utils.analyze(
        root / 'DilatedGC.csv',
        previous_result=image_table,
        parsers=[utils.BlankParser(['ObjectNumber'])],
        extra_columns=['ImageNumber'],
        merge_fcn=functools.partial(
            utils.merge_result, merge_on=['ImageNumber'], how='left'
        ),
    )
    parent_table = parent_table.rename(columns={'ObjectNumber': 'Parent_DilatedGC'})

    # GC objects: counts, shape and RDF; the second return value is kept
    gc_table, extra = utils.analyze(
        root / 'InitialGC.csv',
        previous_result=parent_table,
        parsers=[
            utils.CountingParser(),
            utils.ShapeParser(),
            utils.RDFParser(id_vars=['ImageNumber', 'ObjectNumber', 'Parent_DilatedGC']),
        ],
        region='GC',
        **gc_options,
    )

    # FC objects: counts and shape, merged with how='left'
    fc_table, _ = utils.analyze(
        root / 'InitialFC.csv',
        previous_result=gc_table,
        parsers=[
            utils.CountingParser(),
            utils.ShapeParser(),
        ],
        region='FC',
        **fc_options,
    )

    return fc_table, extra

# Several dataframes can be combined here, as long as experiments are unique and
# they share a subset of overlapping variables (e.g. time or treatment)
output_dir = '/scratch/gpfs/tcomi/cp_paper_redo/rdf/testing/outputs'
# the named groups `treatment` and `time` are parsed from each image filename
filename_pattern = r'/[A-G]\d+_(?P<treatment>SCR|RPL5KD)_15p(?P<time>\d+)c.*nd2'

full_data, extra_data = read_data(output_dir, filename_pattern, bins=4)
# exactly one extra table is expected; unpacking fails loudly otherwise
(rdf,) = extra_data
full_data

In [None]:
# Collapse the RDF rows that share a parent into one count-weighted profile
groups = ['ImageNumber', 'Parent_DilatedGC', 'channel', 'radius']

parent_records = []
for group_key, group_rows in rdf.groupby(groups):
    total_counts = group_rows['counts'].sum()
    # NaN products are treated as contributing nothing to the weighted sum
    weighted_sum = (group_rows['intensity'] * group_rows['counts']).fillna(0).sum()

    record = dict(zip(groups, group_key))
    record['intensity'] = weighted_sum / total_counts
    record['counts'] = total_counts
    parent_records.append(record)

rdf_avg = pd.DataFrame(parent_records)
rdf_avg

In [None]:
# Attach the time / treatment labels of each parent object to its profile
cell_labels = full_data[['ImageNumber', 'Parent_DilatedGC', 'time', 'treatment']]
merged = rdf_avg.merge(cell_labels, on=['ImageNumber', 'Parent_DilatedGC'])

# One summary profile per (time, treatment, channel): count-weighted mean
# intensity at every radius plus an SEM estimated from the total variance
groups = ['time', 'treatment', 'channel']
# channel index -> display name (index 0 has no name)
channels = ['', 'EU', 'DFC', 'FC', 'GC']
profile_rows = []

for group_key, group_rows in merged.groupby(groups):
    # rows: one per (ImageNumber, Parent_DilatedGC); columns: radius
    per_cell = group_rows.pivot_table(
        index=['ImageNumber', 'Parent_DilatedGC'],
        columns='radius',
        values=['intensity', 'counts'],
    )
    cell_intensity = per_cell['intensity']
    cell_counts = per_cell['counts']

    average_intens = (cell_intensity * cell_counts).fillna(0).sum() / cell_counts.sum()

    # min-max scale each radius column across cells before measuring spread
    low, high = cell_intensity.min(), cell_intensity.max()
    normed = (cell_intensity - low) / (high - low)
    weighted_variance = (
        ((normed - normed.mean()) ** 2) * cell_counts
    ).sum() / cell_counts.sum()
    sem = np.sqrt(weighted_variance) / np.sqrt(len(normed))

    # min-max scale the averaged profile along the radius axis
    low, high = average_intens.min(), average_intens.max()
    norm_intens = (average_intens - low) / (high - low)

    # unnamed Series become columns 0, 1, 2 in this order
    summary = pd.concat([norm_intens, average_intens, sem], axis=1)
    for radius, vals in summary.iterrows():
        record = dict(zip(groups, group_key))
        record.update(
            norm_intensity=vals[0],
            intensity=vals[1],
            sem=vals[2],
            # replace the numeric channel index with its name
            channel=channels[group_key[2]],
            radius=radius,
        )
        profile_rows.append(record)

rdf_data = pd.DataFrame(profile_rows)
rdf_data['time'] = rdf_data.time.astype(int)

In [None]:
# Mean intensity against radius: one panel per channel, colour = time,
# line style = treatment; y axes are independent between panels
sns.relplot(
    data=rdf_data,
    kind='line',
    x='radius',
    y='intensity',
    col='channel',
    hue='time',
    style='treatment',
    facet_kws={'sharex': True, 'sharey': False},
)

In [None]:
# Min-max normalised intensity against radius: one panel per channel,
# colour = time, line style = treatment; y axes are independent between panels
sns.relplot(
    data=rdf_data,
    kind='line',
    x='radius',
    y='norm_intensity',
    col='channel',
    hue='time',
    style='treatment',
    facet_kws={'sharex': True, 'sharey': False},
)

In [None]:
# Normalised profiles per time point (one panel each) with a +/- SEM band
channels = ['EU', 'DFC', 'FC', 'GC']
# one colour per channel, shared by the lines and their bands; this is the
# current seaborn palette, i.e. what relplot would pick by default
channel_colors = dict(zip(channels, sns.color_palette(n_colors=len(channels))))

g = sns.relplot(data=rdf_data, x='radius', y='norm_intensity', col='time', col_wrap=3,
            kind='line', style='treatment', hue='channel', facet_kws=dict(sharex=True, sharey=False),
            hue_order=channels, palette=channel_colors)

for time, ax in g.axes_dict.items():
    at_time = rdf_data[rdf_data.time == time]
    for channel in channels:
        channel_rows = at_time[at_time.channel == channel]
        # rdf_data has one profile per (time, treatment, channel), so draw one
        # band per treatment; filling all treatments in a single call would
        # join their rows into one self-crossing polygon
        for _, band in channel_rows.groupby('treatment'):
            band = band.sort_values('radius')
            ax.fill_between(
                band.radius,
                band.norm_intensity - band['sem'],
                band.norm_intensity + band['sem'],
                alpha=0.3,
                color=channel_colors[channel],
            )

In [None]:
# channel names present in the summary table
rdf_data['channel'].unique()

In [None]:
# get peak position over time
# For every channel and time point record:
#   com_radius - mean radius weighted by min-max normalised intensity
#   max_radius - radius at which intensity is largest
# NOTE(review): rows from all treatments are pooled within each
# (channel, time) group, as before - confirm that is intended.
peak_frames = []
for title in rdf_data.channel.unique():
    channel_rows = rdf_data[rdf_data.channel == title]
    result = []
    # group by the scalar key (not a one-element list) so `name` is the time
    # value itself rather than a 1-tuple
    for name, dat in channel_rows.groupby("time"):
        intensity = dat["intensity"]
        span = intensity.max() - intensity.min()
        if span > 0:
            com_radius = np.average(
                dat["radius"],
                weights=(intensity - intensity.min()) / span,
            )
        else:
            # flat (or single-point) profile: all weights would be 0/0
            com_radius = np.nan
        result.append(
            {
                "com_radius": com_radius,
                "max_radius": dat.loc[intensity.idxmax(), "radius"],
                "time": name,
            }
        )
    peak_vals = pd.DataFrame.from_records(result)
    peak_frames.append(peak_vals.assign(channel=title))
overall_peak_vals = pd.concat(peak_frames, ignore_index=True)
overall_peak_vals

In [None]:
# Centre-of-mass radius against time, first as one panel per channel ...
sns.relplot(
    data=overall_peak_vals,
    kind='line',
    x='time',
    y='com_radius',
    hue='channel',
    col='channel',
    col_wrap=2,
)
# ... then with every channel overlaid on a fresh figure
_, overlay_ax = plt.subplots()
sns.lineplot(data=overall_peak_vals, x='time', y='com_radius', hue='channel', ax=overlay_ax)

In [None]:
# generate color bar
def generate_color_bar(data, rgb, height=10, save_name=None, dist_to_peak=None):
    """Draw an RGB strip from the mean radial profiles of three channels.

    Each entry of ``rgb`` names a channel in ``data``; its intensity is averaged
    per radius, min-max normalised, and written to the red, green and blue
    plane respectively. The planes are then subtracted from one another,
    clipped to [0, 1] and rescaled so that regions dominated by a single
    channel stand out.

    Parameters
    ----------
    data : pandas.DataFrame
        Needs ``channel``, ``radius`` and ``intensity`` columns. Every channel
        in ``rgb`` must cover the same radii (the strip is one column per
        unique radius).
    rgb : sequence of three channel names
        Channels shown as red, green and blue, in that order.
    height : int
        Number of (identical) rows in the strip.
    save_name : str or path, optional
        If truthy, the figure is saved there.
    dist_to_peak : pandas.DataFrame, optional
        Needs ``time`` and ``max_radius`` columns; adds a top axis with a tick
        at every ``max_radius`` labelled with the matching ``time``.

    Returns
    -------
    None
        The figure is drawn (and optionally saved); the radii at the two
        detected boundaries are printed.
    """
    # sorted, so column i of the strip corresponds to radii[i]
    radii = np.sort(data.radius.unique())
    result = np.zeros((height, len(radii), 3))

    for output_ind, input_ind in enumerate(rgb):
        intens = data[data.channel == input_ind].groupby("radius")["intensity"].mean()
        intens = (intens - intens.min()) / (intens.max() - intens.min())
        # broadcast the 1-D profile over every row of the strip
        result[:, :, output_ind] = intens.to_numpy()

    # enhance color bar to show regions of difference between phases
    # effectively set intersection of phases to 0 and clip negative
    # subtract the green plane (second entry of rgb) from red and blue
    result[:, :, 0] -= result[:, :, 1]
    result[:, :, 2] -= result[:, :, 1]

    # subtract the (already reduced) red plane from green and blue
    result[:, :, 1] -= result[:, :, 0]
    result[:, :, 2] -= result[:, :, 0]

    # clip and rescale each plane to a maximum of 1; a plane that was clipped
    # to all zeros is left at zero instead of being divided by zero
    result = np.clip(result, 0, 1)
    plane_max = result.max(axis=(0, 1), keepdims=True)
    np.divide(result, plane_max, out=result, where=plane_max > 0)

    fig, ax = plt.subplots()
    ax.imshow(result, aspect="equal")
    ax.tick_params(left=False, labelleft=False)

    ax.set_xlabel('Radius (px)')

    if dist_to_peak is not None:
        t_labels = dist_to_peak.time
        # NOTE(review): ticks are placed at radius values on an axis measured
        # in strip columns - correct only if radius equals column index; confirm
        t_pixels = dist_to_peak.max_radius
        ax_t = ax.secondary_xaxis('top')
        ax_t.set_xticks(t_pixels)
        ax_t.set_xticklabels(t_labels)
        ax_t.set_xlabel('Time (min)')

    # find intersections: among columns where red is at least 0.2, the column
    # where red is closest to green, and the one where red is closest to blue
    r, g, b = result[0].T

    valid_idx = np.where(r >= 0.2)[0]
    if valid_idx.size:
        fc_dfc = valid_idx[abs(r[valid_idx] - g[valid_idx]).argmin()]
        ax.axvline(fc_dfc, c='k', linestyle='--')
        dfc_gc = valid_idx[abs(r[valid_idx] - b[valid_idx]).argmin()]
        ax.axvline(dfc_gc, c='k', linestyle='--')

        print(radii[fc_dfc], radii[dfc_gc])
    else:
        print('no column has a red value >= 0.2; boundaries not drawn')

    if save_name:
        fig.savefig(save_name)


# Top-axis labels follow the EU peak; only a few time points are labelled so
# the tick labels do not overlap
labelled_times = (0, 30, 45, 60, 120)
eu_peaks_to_label = overall_peak_vals[
    overall_peak_vals.time.isin(labelled_times) & (overall_peak_vals.channel == 'EU')
]

generate_color_bar(
    data=rdf_data,
    rgb=('DFC', 'FC', 'GC'),
    dist_to_peak=eu_peaks_to_label,
    # save_name="linear_distance_remap.pdf",
)

In [None]:
# generate color bar
def generate_color_bar_linearT(data, dist_to_peak, rgb, height=10, save_name=None):
    """Draw the RGB phase strip with an x axis that is linear in time.

    ``dist_to_peak`` maps time to a radius (``max_radius``). That mapping is
    linearly interpolated to every integer time from 0 to the largest time,
    and each time column shows the normalised channel profiles at the radius
    nearest to the interpolated value.

    Parameters
    ----------
    data : pandas.DataFrame
        Needs ``channel``, ``radius`` and ``intensity`` columns. Every channel
        in ``rgb`` is assumed to cover the same radii.
    dist_to_peak : pandas.DataFrame
        Needs integer ``time`` and numeric ``max_radius`` columns.
    rgb : sequence of three channel names
        Channels shown as red, green and blue, in that order.
    height : int
        Number of (identical) rows in the strip.
    save_name : str or path, optional
        If truthy, the figure is saved there.

    Returns
    -------
    None
        The figure is drawn (and optionally saved); the interpolated radius
        and the time at the two detected boundaries are printed.
    """
    # np.interp requires increasing x-coordinates
    peaks = dist_to_peak.sort_values("time")
    last_time = int(peaks.time.max())
    # one strip column per integer time, 0..last_time inclusive
    time_pos = np.arange(last_time + 1)
    result = np.zeros((height, time_pos.size, 3))

    interp_dists = np.interp(
        x=time_pos, xp=peaks.time.to_numpy(), fp=peaks.max_radius.to_numpy()
    )

    # sorted so positions line up with the radius-sorted groupby result below
    radii = np.sort(data.radius.unique())
    # for every time column, position of the radius closest to the peak radius
    dist_to_time = np.abs(radii - interp_dists[:, None]).argmin(axis=1)

    for output_ind, input_ind in enumerate(rgb):
        intens = data[data.channel == input_ind].groupby("radius")["intensity"].mean()
        intens = (intens - intens.min()) / (intens.max() - intens.min())
        result[:, :, output_ind] = intens.iloc[dist_to_time].to_numpy()

    # subtract the green plane from red and blue, then the reduced red plane
    # from green and blue, so overlapping regions go to zero
    result[:, :, 0] -= result[:, :, 1]
    result[:, :, 2] -= result[:, :, 1]

    result[:, :, 1] -= result[:, :, 0]
    result[:, :, 2] -= result[:, :, 0]

    # clip and rescale each plane to a maximum of 1; an all-zero plane stays
    # zero instead of being divided by zero
    result = np.clip(result, 0, 1)
    plane_max = result.max(axis=(0, 1), keepdims=True)
    np.divide(result, plane_max, out=result, where=plane_max > 0)

    fig, ax = plt.subplots()
    ax.imshow(result, aspect="equal")
    ax.tick_params(left=False, labelleft=False)

    time_ticks = np.linspace(0, last_time, 7)
    ax.set_xticks(time_ticks)
    ax.set_xlabel('Time (min)')

    # top axis: same tick positions, labelled with the interpolated radius
    ax_t = ax.secondary_xaxis('top')
    ax_t.set_xticks(time_ticks)

    dist_labels = np.interp(
        x=time_ticks, xp=peaks.time.to_numpy(), fp=peaks.max_radius.to_numpy()
    ).round(3)
    ax_t.set_xticklabels(dist_labels)
    ax_t.set_xlabel('Distance (px)')

    # find intersections: among columns where red is at least 0.2, the column
    # where red is closest to green, and the one where red is closest to blue
    r, g, b = result[0].T

    valid_idx = np.where(r >= 0.2)[0]
    if valid_idx.size:
        fc_dfc = valid_idx[abs(r[valid_idx] - g[valid_idx]).argmin()]
        ax.axvline(fc_dfc, c='k', linestyle='--')
        dfc_gc = valid_idx[abs(r[valid_idx] - b[valid_idx]).argmin()]
        ax.axvline(dfc_gc, c='k', linestyle='--')
        print(interp_dists[fc_dfc], interp_dists[dfc_gc])
        print(time_pos[fc_dfc], time_pos[dfc_gc])
    else:
        print('no column has a red value >= 0.2; boundaries not drawn')

    if save_name:
        fig.savefig(save_name)


# the time -> radius mapping is taken from the EU peak at every time point
eu_peak_track = overall_peak_vals[overall_peak_vals.channel == 'EU']

generate_color_bar_linearT(
    data=rdf_data,
    dist_to_peak=eu_peak_track,
    rgb=('DFC', 'FC', 'GC'),
    # save_name="linear_time_remap.pdf",  # save as a png to recolor
)