In [1]:
import pathlib as pl
from configparser import ConfigParser

import numpy as np

import plotly.graph_objs as go
import plotly.colors as pc
from plotly.subplots import make_subplots
from webcolors import name_to_hex, hex_to_name, rgb_to_name, rgb_to_hex

import joblib as jl

from src.root_path import config_path, root_path
from src.visualization.palette import *

# Sparseness of contextual effects
Example neuron's coverage of the context-pair × probe (modulation) space.

Histogram of the number of neurons at each percentage of modulation-space coverage.

Surface plot of neuron recruitment as the modulation space grows.

In [2]:
config = ConfigParser()
# read through a context manager so the settings file handle is closed after parsing
with open(config_path / 'settings.ini') as settings_file:
    config.read_file(settings_file)

meta = {'reliability': 0.1,  # r value
        'smoothing_window': 0,  # ms
        'raster_fs': 30,
        'montecarlo': 1000,
        'zscore': True,
        'dprime_absolute': None,
        'stim_type': 'permutations',
        'alpha': 0.05}
# todo, if batch analysis rerun, use the annotated line instead
# NOTE(review): `{meta}` below formats the whole dict repr into the file name -- confirm
#   whether only meta['alpha'] (as on the active line) was intended.
# summary_DF_file = pl.Path(config['paths']['analysis_cache']) / f'211221_cxt_metrics_summary_DF_alpha_{meta}'
# alpha is taken from `meta` so the cache file name and the analysis parameters cannot drift apart
summary_DF_file = pl.Path(config['paths']['analysis_cache']) / f'211221_cxt_metrics_summary_DF_alpha_{meta["alpha"]}'

### same example cell as in figure 1 ###
prb_idx = 3 - 1  # selected probe. the -1 is to account for 0 not being used
ctx_pair = [0, 1]  # pair of contexts to compare and exemplify d'
cellid = 'ARM021b-36-8'

# cached metrics summary table; joblib.load unpickles, so only load trusted cache files
DF = jl.load(summary_DF_file)
def format_dataframe(DF):
    """Filter and reshape the context-metrics summary table for plotting.

    Keeps only 'SC' (single cell) and 'fdPCA' (population) rows computed with the
    'consecutive_3' multiple-comparison correction, zero-pads context pairs, renames
    metrics/analyses to plot-friendly labels, merges cell/site ids into one 'id' column
    and labels every context pair relative to its probe ('trans_pair').

    Parameters
    ----------
    DF : pandas.DataFrame
        Long-format summary with at least the columns 'analysis', 'mult_comp_corr',
        'region', 'siteid', 'cellid', 'context_pair' (strings like '1_2'), 'probe',
        'metric' and 'value'.

    Returns
    -------
    pivoted : pandas.DataFrame
        One row per (analysis, region, id, context_pair, trans_pair, probe) and one
        column per (renamed) metric.
    full_long : pandas.DataFrame
        The filtered long-format frame, kept for the subsampling analysis.
    """
    keep_analysis = DF.analysis.isin(['SC', 'fdPCA'])
    keep_corr = DF.mult_comp_corr == 'consecutive_3'

    good_cols = ['analysis', 'mult_comp_corr', 'region', 'siteid', 'cellid', 'context_pair',
                 'probe', 'metric', 'value']
    # explicit copy: the column assignments below then never write into (or warn about) a slice of DF
    filtered = DF.loc[keep_analysis & keep_corr, good_cols].copy()

    filtered['probe'] = [int(p) for p in filtered['probe']]
    # zero-pad both contexts of each pair, e.g. '1_2' -> '01_02'
    filtered['context_pair'] = [f"{int(cp.split('_')[0]):02d}_{int(cp.split('_')[1]):02d}"
                                for cp in filtered['context_pair']]

    # rename metrics and analysis for ease of plotting
    filtered['metric'] = filtered['metric'].replace({'significant_abs_mass_center': 'center of mass (ms)',
                                                     'significant_abs_mean': "mean d'",
                                                     'significant_abs_sum': "integral (d'*ms)"})
    filtered['analysis'] = filtered['analysis'].replace({'SC': 'single cell',
                                                         'fdPCA': 'population',
                                                         'pdPCA': 'probewise pop',
                                                         'LDA': 'pop ceiling'})

    # single identifier column: cellid where present, otherwise siteid
    filtered['id'] = filtered['cellid'].fillna(value=filtered['siteid'])
    filtered = filtered.drop(columns=['cellid', 'siteid'])

    # missing metric values are treated as 0
    filtered['value'] = filtered['value'].fillna(value=0)

    # permutation related preprocessing: label each context of a pair relative to the probe,
    # 'silence' for context 0, 'same' when it equals the probe, 'diff' otherwise
    ctx = np.asarray([row.split('_') for row in filtered.context_pair], dtype=int).reshape(-1, 2)
    prb = np.asarray(filtered.probe, dtype=int)

    silence = ctx == 0
    same = ctx == prb[:, None]
    different = np.logical_and(~silence, ~same)

    name_arr = np.full_like(ctx, np.nan, dtype=object)
    name_arr[silence] = 'silence'
    name_arr[same] = 'same'
    name_arr[different] = 'diff'

    # join with plain Python strings: np.apply_along_axis sizes its output dtype from the first
    # row and would silently truncate longer labels (e.g. 'silence_same' -> 'silence_s' when the
    # first row is 'diff_diff')
    comp_names = ['_'.join(pair) for pair in name_arr]
    # swaps classification names to not have repetitions i.e. diff_same == same_diff
    canonical = {'same_silence': 'silence_same',
                 'diff_silence': 'silence_diff',
                 'diff_same': 'same_diff'}
    filtered['trans_pair'] = [canonical.get(name, name) for name in comp_names]

    ord_cols = ['analysis', 'region', 'id', 'context_pair', 'trans_pair', 'probe', 'metric', 'value']
    pivot_idx = [col for col in ord_cols if col not in ['value', 'metric']]
    pivoted = filtered.pivot_table(index=pivot_idx, columns='metric', values='value',
                                   aggfunc='first').reset_index()

    full_long = filtered  # saves long format for subsampling analysis

    return pivoted, full_long
# wide view (one column per metric) and long view; later cells select from the long one by id
pivoted, filtered = format_dataframe(DF=DF)

### Figure parameters and subplot layout

In [3]:
# planned multi-panel layout (first column spans both rows):
# fig = make_subplots(rows=2,cols=3, vertical_spacing=0.05, horizontal_spacing=0.05,
#                     column_widths=[0.5, 0.4, 0.1],
#                     specs=[[{'rowspan':2}, {'l':0.05}, {}],
#                            [None         , {'l':0.05}, {}]])

fig = make_subplots(1, 1)

# pixels per inch used to turn the inch-based size below into pixels. Other values used:
# 300 (high quality print standard; TODO: needed at all if post-processing is done on vectors?)
# and 96 (www standard).
ppi = 92.5  # house monitor

width = 6  # in inches
heigh = width * 4/9  # aspect ratio of heatmap
_ = fig.update_layout(template='simple_white',
                      margin=dict(l=10, r=10, t=10, b=10),
                      width=round(ppi*width), height=round(ppi*heigh),
                      showlegend=False)

# Double heatmap of contextual metrics for the example cell

In [82]:
fig = make_subplots(1, 1)
width = 6  # in inches
heigh = width * 4/9  # aspect ratio of heatmap
_ = fig.update_layout(template='simple_white',
                      margin=dict(l=10, r=10, t=10, b=10),
                      # width=round(ppi*width), height=round(ppi*heigh),
                      showlegend=False)

row, col = 1, 1

# long format -> one row per (metric, probe) and one column per context pair
to_pivot = filtered.loc[(filtered['id'] == cellid),
                        ['context_pair', 'trans_pair', 'probe', 'metric', 'value']]
sqr_df = to_pivot.pivot_table(index=['metric', 'probe'], columns=['context_pair'], values='value')
amplitudes = sqr_df.loc["integral (d'*ms)", :]  # probe x context_pair
durations = sqr_df.loc["center of mass (ms)", :]  # probe x context_pair
# NOTE(review): the loop below indexes both frames with the same (p, c), so this assumes both
# metrics cover the same probes and context pairs -- confirm when switching example cells
amp_max = np.nanmax(amplitudes.values)
dur_max = np.nanmax(durations.values)

# general shapes of the upper and lower triangles to be passed to Scatter x and y
xu, yu = np.array([0, 0, 1, 0]), np.array([0, 1, 1, 0])
xl, yl = np.array([0, 1, 1, 0]), np.array([0, 0, 1, 0])


def _normalize(values, vmax):
    """Return `values` divided by `vmax`, flattened; all zeros when vmax is not positive."""
    if not vmax > 0:  # also true for NaN: avoids NaN/inf colour positions from 0/0 or x/NaN
        return np.zeros(values.size)
    return (values / vmax).flatten()


def _add_half_cell(fig, x, y, fillcolor, row, col):
    """Add one filled, outlined triangle (half of a heatmap cell) to `fig` at (row, col)."""
    fig.add_trace(go.Scatter(x=x, y=y, mode='lines',
                             line_width=1, line_color='#222222',
                             fill='toself', fillcolor=fillcolor),
                  row=row, col=col)


def _add_colorbar(fig, cscale, vmax, title, x, row, col):
    """Add an empty marker-only trace whose sole visible output is a colorbar on [0, vmax]."""
    fig.add_trace(go.Scatter(x=[None], y=[None], mode='markers',
                             marker=dict(colorscale=cscale, showscale=True,
                                         cmin=0, cmax=vmax,
                                         colorbar=dict(thickness=10,
                                                       title=dict(text=title, side='right'),
                                                       tickangle=-90,
                                                       xanchor='left', x=x))),
                  row=row, col=col)


# defines colors corresponding to normalized metric values
norm_amp = _normalize(amplitudes.values, amp_max)
norm_dur = _normalize(durations.values, dur_max)

amp_cscale = pc.make_colorscale(['#000000', Green])
dur_cscale = pc.make_colorscale(['#000000', Purple])

amp_color = np.asarray(pc.sample_colorscale(amp_cscale, norm_amp)).reshape(amplitudes.shape)
dur_color = np.asarray(pc.sample_colorscale(dur_cscale, norm_dur)).reshape(durations.shape)

for p, c in np.ndindex(amplitudes.shape):
    _add_half_cell(fig, xu + c, yu + p, amp_color[p, c], row, col)  # amplitude, upper half
    _add_half_cell(fig, xl + c, yl + p, dur_color[p, c], row, col)  # duration, lower half

# markers (and hence their colour scales) are only drawn when mode includes 'markers', so the
# filled 'lines' triangles cannot carry colorbars; dedicated invisible traces provide one per
# metric, using the same colour scales and ranges as the fills
_add_colorbar(fig, amp_cscale, amp_max, "amplitude (d'*ms)", 1.15, row, col)
_add_colorbar(fig, dur_cscale, dur_max, "duration (ms)", 1, row, col)

# highlights the example probe context_pair combination (disabled).
# each heatmap cell spans [c, c+1] x [p, p+1], so the outline uses the same unit square
# ctx_idx = amplitudes.columns.to_list().index(f'{ctx_pair[0]:02}_{ctx_pair[1]:02}')
# _ = fig.add_shape(type='rect', x0=ctx_idx, x1=ctx_idx+1, y0=prb_idx, y1=prb_idx+1,
#                   line=dict(color='red', width=5), fillcolor='rgba(0,0,0,0)',
#                   row=row, col=col)

# strip left zero padding for better display
ticktexts = [f"{int(pp.split('_')[0])}_{int(pp.split('_')[1])}" for pp in amplitudes.columns.to_list()]
# ticks sit at +0.5, the centre of each unit-square cell
_ = fig.update_xaxes(title=dict(text='context pairs'),
                     anchor='y', scaleanchor='y',
                     constrain='domain',
                     tickmode='array',
                     tickvals=np.arange(amplitudes.columns.size) + 0.5,
                     ticktext=ticktexts)

_ = fig.update_yaxes(title=dict(text='probes'), anchor='x',
                     constrain='domain',
                     range=[0, amplitudes.index.size],
                     tickmode='array',
                     tickvals=np.arange(amplitudes.index.size) + 0.5,
                     ticktext=amplitudes.index.to_list())

fig.show()

In [93]:
# todo find another example neuron from the same site with a different coverage
# Exploratory: a marker-mode trace for the last (probe, context pair) cell of the heatmap.
# Uses its own figure name so the heatmap bound to `fig` is not overwritten before export.
scratch_fig = go.Figure()
# last index pair produced by np.ndindex(amplitudes.shape), made explicit instead of relying
# on loop variables leaked from the heatmap cell
last_p, last_c = amplitudes.shape[0] - 1, amplitudes.shape[1] - 1
_ = scratch_fig.add_trace(go.Scatter(x=xu + last_c - 0.5, y=yu + last_p - 0.5, mode='markers',
                                     marker=dict(color=amplitudes.values[last_p, last_c],
                                                 coloraxis='coloraxis',
                                                 cmin=0, cmax=amplitudes.values.max(),
                                                 showscale=False,
                                                 )
                                     ))

scratch_fig.show()

In [5]:
# transparent paper and plot-area backgrounds
transparent = 'rgba(0,0,0,0)'
_ = fig.update_layout(paper_bgcolor=transparent, plot_bgcolor=transparent)

# output location for the paper figures; created on demand
folder = root_path.joinpath('reports', 'figures', 'paper')
folder.mkdir(exist_ok=True, parents=True)

filename = folder.joinpath('figure_03')
# export is currently disabled; uncomment to write the png and svg versions
# fig.write_image(filename.with_suffix('.png'))
# fig.write_image(filename.with_suffix('.svg'))

In [6]:
# display the hex code of the palette's Purple, the colour maximal durations map to in the
# heatmap's duration colour scale
Purple

'#B07AA1'