In [1]:
import pathlib as pl
from configparser import ConfigParser

import numpy as np
import pandas as pd

import plotly.graph_objs as go
import plotly.colors as pc
from plotly.subplots import make_subplots
from webcolors import name_to_hex, hex_to_name, rgb_to_name, rgb_to_hex

import joblib as jl

from src.root_path import config_path, root_path
from src.visualization.palette import *

# Sparseness of contextual effects
Example neuron coverage of the context-pair × probe (modulation) space.

Histogram showing the number of neurons for given percentages of modulation-space coverage.

Surface plot showing the recruitment of neurons as the modulation space increases.

In [198]:
config = ConfigParser()
# the context manager closes the settings file once parsed (a bare open() leaves the handle open)
with open(config_path / 'settings.ini') as settings_file:
    config.read_file(settings_file)

meta = {'reliability': 0.1,  # r value
        'smoothing_window': 0,  # ms
        'raster_fs': 30,
        'montecarlo': 1000,
        'zscore': True,
        'dprime_absolute': None,
        'stim_type': 'permutations',
        'alpha': 0.05}
# todo, if batch analysis rerun, use the annotated line instead
# summary_DF_file = pl.Path(config['paths']['analysis_cache']) / f'211221_cxt_metrics_summary_DF_alpha_{meta}'
# the alpha in the cache file name is read from meta so both always agree ('..._alpha_0.05' for alpha=0.05)
summary_DF_file = pl.Path(config['paths']['analysis_cache']) / f"211221_cxt_metrics_summary_DF_alpha_{meta['alpha']}"

### same example cell as in figure 1 ###
prb_idx = 3 - 1  # selected probe (3); the -1 is to account for probe 0 not being used
ctx_pair = [0, 1]  # pair of contexts to compare and exemplify d'
cellid = 'ARM021b-36-8'

n_eg_neu = 3  # total number of related examples

# NOTE(review): joblib.load unpickles the file; only load caches produced by this project
DF = jl.load(summary_DF_file)
def format_dataframe(DF):
    """
    Filters the contextual metrics summary and formats it for plotting.

    Keeps the single cell ('SC') and population ('fdPCA') analyses corrected with 'consecutive_3',
    zero pads the context pairs (e.g. '0_1' -> '00_01'), renames metrics and analyses to readable
    labels, merges 'cellid' and 'siteid' into a single 'id' column, replaces missing values with 0,
    and classifies every context pair relative to its probe in a new 'trans_pair' column.
    Each context is labeled 'silence' (context 0), 'same' (context equal to the probe) or 'diff',
    and the pair labels are put in a canonical order, e.g. 'silence_same', 'same_diff', 'diff_diff'.

    :param DF: summary DataFrame with the columns 'analysis', 'mult_comp_corr', 'region', 'siteid',
        'cellid', 'context_pair', 'probe', 'metric' and 'value'. 'context_pair' holds two integers
        joined by '_' (e.g. '0_1'); 'probe' must be castable to int.
    :return: a new, filtered DataFrame with 'id' and 'trans_pair' columns; DF itself is not modified.
    """
    ff_analysis = DF.analysis.isin(['SC', 'fdPCA'])
    ff_corr = DF.mult_comp_corr == 'consecutive_3'

    good_cols = ['analysis', 'mult_comp_corr', 'region', 'siteid', 'cellid', 'context_pair',
                 'probe', 'metric', 'value']
    # explicit copy so the column assignments below act on an independent frame rather than on a
    # selection of DF (avoids SettingWithCopyWarning)
    filtered = DF.loc[ff_analysis & ff_corr, good_cols].copy()

    filtered['probe'] = [int(p) for p in filtered['probe']]
    filtered['context_pair'] = [f"{int(cp.split('_')[0]):02d}_{int(cp.split('_')[1]):02d}"
                                for cp in filtered['context_pair']]

    # rename metrics and analysis for ease of plotting
    filtered['metric'] = filtered['metric'].replace({'significant_abs_mass_center': 'center of mass (ms)',
                                                     'significant_abs_mean': "mean d'",
                                                     'significant_abs_sum': "integral (d'*ms)"})
    filtered['analysis'] = filtered['analysis'].replace({'SC': 'single cell',
                                                         'fdPCA': 'population',
                                                         'pdPCA': 'probewise pop',
                                                         'LDA': 'pop ceiling'})

    # rows without a cellid (presumably the population analysis) are identified by their siteid
    filtered['id'] = filtered['cellid'].fillna(value=filtered['siteid'])
    filtered = filtered.drop(columns=['cellid', 'siteid'])

    filtered['value'] = filtered['value'].fillna(value=0)

    # permutation related preprocessing: labels each context of the pair relative to the probe.
    # reshape keeps the (n_rows, 2) layout even when the selection is empty
    ctx = np.asarray([row.split('_') for row in filtered.context_pair], dtype=int).reshape(-1, 2)
    prb = np.asarray(filtered.probe, dtype=int)

    silence = ctx == 0
    same = ctx == prb[:, None]
    # 'same' takes precedence over 'silence'; anything else is 'diff'
    name_arr = np.where(same, 'same', np.where(silence, 'silence', 'diff'))

    # joined into plain Python strings: np.apply_along_axis sizes its output dtype from the first
    # row and silently truncates longer names (e.g. 'silence_same' -> 'silence_s' after 'diff_diff')
    comp_names = ['_'.join(pair) for pair in name_arr]

    # swaps classification names to not have repetitions i.e. diff_same == same_diff
    canonical = {'same_silence': 'silence_same',
                 'diff_silence': 'silence_diff',
                 'diff_same': 'same_diff'}
    filtered['trans_pair'] = [canonical.get(name, name) for name in comp_names]
    return filtered

# long-format table read by the top-neuron search and the heatmap cells below
DF_long = format_dataframe(DF=DF)

### Figure parameters and subplot layout

In [199]:
# find top neurons in the example site (n_eg_neu is set in the configuration cell)
site = cellid[:7]
# plain substring match: the site id is not meant as a regex pattern
ff_site = DF_long.id.str.contains(site, regex=False)
ff_sc = DF_long.analysis == 'single cell'
ff_metric = DF_long.metric == "integral (d'*ms)"

# probably neurons with large tilings are going to have high sum amplitude
# across the modulation space
top_neurons = (DF_long.loc[ff_site & ff_sc & ff_metric, :]
               .groupby('id')
               .agg(sum_amplitude=('value', 'sum'))  # 'sum' instead of np.sum (deprecated in agg)
               .sort_values(by='sum_amplitude', ascending=False)
               .head(n_eg_neu)
               .index.values)
top_neurons

array(['ARM021b-36-8', 'ARM021b-43-8', 'ARM021b-05-1'], dtype=object)

In [3]:
# fig = make_subplots(rows=2,cols=3, vertical_spacing=0.05, horizontal_spacing=0.05,
#                     column_widths=[0.5, 0.4, 0.1],
#                     specs=[[{'rowspan':2}, {'l':0.05}, {}],
#                            [None         , {'l':0.05}, {}]])

fig = make_subplots(1, 1)

# pixels per inch used to convert the size in inches to plotly pixels.
# alternatives: 300 (high quality print standard), 96 (www standard); 92.5 matches the house monitor
ppi = 92.5

width = 6  # in inches
heigh = width * 4 / 9  # aspect ratio of heatmap
_ = fig.update_layout(template='simple_white',
                      margin=dict(l=10, r=10, t=10, b=10),
                      width=round(ppi * width), height=round(ppi * heigh),
                      showlegend=False)

# Double heatmap (contextual amplitude and duration) for the example neurons

In [327]:
fig = make_subplots(rows=n_eg_neu, cols=1,
                    vertical_spacing=0.01,
                    # horizontal_spacing=0.05,
                    # column_widths=[0.5, 0.4, 0.1],
                    # specs=[[{'rowspan':2}, {'l':0.05}, {}],
                    #        [None         , {'l':0.05}, {}]]
                    )

# pixels per inch, only used by the (currently commented out) explicit figure size below.
# alternatives: 300 (high quality print standard), 96 (www standard); 92.5 matches the house monitor
ppi = 92.5

width = 6  # in inches
heigh = width * 4 / 9  # aspect ratio of heatmap
_ = fig.update_layout(template='simple_white',
                      margin=dict(l=10, r=10, t=10, b=10),
                      # width=round(ppi*width), height=round(ppi*heigh),
                      showlegend=False)

# prefilters all the data to be plotted (multiple neurons) so the fill colors share colormaps

# turns the long format data into a table with rows (metric, id, probe) and columns context_pair
to_pivot = DF_long.loc[DF_long['id'].isin(top_neurons),
                       ['id', 'context_pair', 'trans_pair', 'probe', 'metric', 'value']]
val_df = to_pivot.pivot_table(index=['metric', 'id', 'probe'], columns=['context_pair'], values='value')

cscales = {"integral (d'*ms)": pc.make_colorscale(['#000000', Green]),
           "center of mass (ms)": pc.make_colorscale(['#000000', Purple])}

# normalizes, saves max values and gets fill colors for each plotted metric.
# color_df is object dtype because its cells are overwritten with color strings; writing strings
# into the float table returned by pivot_table relies on deprecated implicit dtype upcasting.
# Only metrics with a colorscale are processed: other metrics that may be in val_df
# (e.g. "mean d'") have no entry in cscales and would raise a KeyError.
max_vals = dict()
color_df = val_df.astype(object)
for metric, cscale in cscales.items():
    metric_vals = val_df.loc[metric]
    max_vals[metric] = metric_vals.values.max()
    # avoids dividing by zero when a metric is zero everywhere
    norm = max_vals[metric] if max_vals[metric] != 0 else 1
    colors = pc.sample_colorscale(cscale, (metric_vals / norm).values.flatten())
    color_df.loc[metric] = np.asarray(colors).reshape(metric_vals.shape)

# general shapes of the upper and lower triangles to be passed to Scatter x and y
xu, yu = np.array([0, 0, 1, 0]), np.array([0, 1, 1, 0])
xl, yl = np.array([0, 1, 1, 0]), np.array([0, 0, 1, 0])

col = 1
last_row = len(top_neurons)  # x axis title and ticks go on the last plotted neuron
for row, neuron_id in enumerate(top_neurons, start=1):  # plotly subplots are 1-indexed
    amp_color = color_df.loc[("integral (d'*ms)", neuron_id), :].values
    dur_color = color_df.loc[('center of mass (ms)', neuron_id), :].values

    amplitudes = val_df.loc[("integral (d'*ms)", neuron_id), :]
    durations = val_df.loc[('center of mass (ms)', neuron_id), :]

    # each (probe, context pair) square is split in two triangles: amplitude in the upper half,
    # duration in the lower half. The transparent markers carry the raw values so each
    # coloraxis can build its colorbar internally
    for p, c in np.ndindex(amp_color.shape):
        _ = fig.add_trace(go.Scatter(x=xu + c, y=yu + p, mode='lines+markers',
                                     line_width=1, line_color='#222222',
                                     fill='toself', fillcolor=amp_color[p, c],
                                     marker=dict(color=(amplitudes.values[p, c],) * len(xu),
                                                 coloraxis='coloraxis',
                                                 opacity=0,
                                                 cmin=0, cmax=max_vals["integral (d'*ms)"],
                                                 )),
                          row=row, col=col)

        _ = fig.add_trace(go.Scatter(x=xl + c, y=yl + p, mode='lines+markers',
                                     line_width=1, line_color='#222222',
                                     fill='toself', fillcolor=dur_color[p, c],
                                     marker=dict(color=(durations.values[p, c],) * len(xl),
                                                 coloraxis='coloraxis2',
                                                 opacity=0,
                                                 cmin=0, cmax=max_vals["center of mass (ms)"],
                                                 )),
                          row=row, col=col)

    # common parameters for all axes; the first x and y axes have no number in their name
    rr = row if row > 1 else ''
    xax_params = dict(scaleanchor=f'y{rr}',
                      constrain='domain',
                      range=[0, amplitudes.columns.size], fixedrange=True,
                      showticklabels=False, ticks='',
                      row=row, col=col)

    yax_params = dict(title=dict(text=f'{neuron_id}<br>probes'),
                      constrain='domain',
                      range=[0, amplitudes.index.size], fixedrange=True,
                      tickmode='array',
                      tickvals=np.arange(amplitudes.index.size) + 0.5,
                      ticktext=amplitudes.index.to_list(),
                      row=row, col=col)

    if row == last_row:  # last row title and ticks
        # strip left zero padding for better display
        ticktexts = [f"{int(pp.split('_')[0])}_{int(pp.split('_')[1])}"
                     for pp in amplitudes.columns.to_list()]
        xax_params.update(dict(title_text='context pairs',
                               showticklabels=True, ticks='outside',
                               tickmode='array',
                               tickvals=np.arange(amplitudes.columns.size) + 0.5,
                               ticktext=ticktexts))

    _ = fig.update_xaxes(**xax_params)
    _ = fig.update_yaxes(**yax_params)

# highlights the example probe context_pair combination on the first subplot
# (all neurons share the pivot columns, so the last `amplitudes` gives the column positions)
ctx_idx = amplitudes.columns.to_list().index(f'{ctx_pair[0]:02}_{ctx_pair[1]:02}')
_ = fig.add_shape(type='rect', x0=ctx_idx, x1=ctx_idx + 1, y0=prb_idx, y1=prb_idx + 1,
                  line=dict(color='red', width=5), fillcolor='rgba(0,0,0,0)',
                  row=1, col=col)

_ = fig.update_layout(coloraxis=dict(colorscale=cscales["integral (d'*ms)"],
                                     colorbar=dict(
                                         thickness=10, len=1,
                                         title_text="amplitude (d'*ms)",
                                         title_side='right',
                                         tickangle=-90,
                                         xanchor='left', x=0.78)
                                     ),
                      coloraxis2=dict(colorscale=cscales["center of mass (ms)"],
                                      colorbar=dict(
                                          thickness=10, len=1,
                                          title_text="duration (ms)",
                                          title_side='right',
                                          tickangle=-90,
                                          xanchor='left', x=0.7)
                                      )
                      )
fig.show()

In [None]:
# todo move the figures from ctx_probe_space_subsampling to new subplots here

In [None]:
# transparent paper and plot backgrounds
no_fill = 'rgba(0,0,0,0)'
fig.update_layout(plot_bgcolor=no_fill, paper_bgcolor=no_fill)

# output folder for the paper figures, created if missing
folder = root_path.joinpath('reports', 'figures', 'paper')
folder.mkdir(exist_ok=True, parents=True)

filename = folder.joinpath('figure_03')
# export is disabled; uncomment to write the png and svg versions
# fig.write_image(filename.with_suffix('.png'))
# fig.write_image(filename.with_suffix('.svg'))