In [46]:
from math import pi
import itertools as itt

import numpy as np

import plotly.graph_objs as go
from plotly.subplots import make_subplots
from webcolors import name_to_rgb
from src.root_path import root_path

from src.data.rasters import load_site_formated_raster
from src.metrics.consolidated_dprimes import single_cell_dprimes
from src.metrics.significance import _significance

from src.visualization.fancy_plots import squarefy, quantified_dprime

### Plotting parameters

In [47]:
# Sound sequences used in the experiment, one row per sequence of sound ids.
# Sound id 0 (silence) was manually prepended so every sequence starts from silence.
sequences = np.asarray([[0, 1, 3, 2, 4, 4],
                        [0, 3, 4, 1, 1, 2],
                        [0, 4, 2, 3, 3, 1],
                        [0, 2, 2, 1, 4, 3]])

n_samps = 100  # samples per schematic waveform
colors = ['blue', 'orange', 'green', 'purple', 'brown']  # one color per sound id, 0 (silence) .. 4
# TODO: pull some real example waveforms instead of a dummy sine.
dummy_wave = np.sin(np.linspace(0, pi * 4, n_samps)) * 0.5
# One waveform per sound id: a flat line for silence (id 0) and the dummy wave for every real sound.
# Sized from `colors` so that `waves`, `colors` and the ids in `sequences` stay aligned.
waves = [np.zeros(n_samps)] + [dummy_wave] * (len(colors) - 1)
assert sequences.max() < len(waves), 'every sound id in `sequences` needs a waveform and a color'

vertical_offset = 1  # vertical spacing between stacked traces
# Selected probe (sound 3). The -1 is because the data's probe axis does not include silence (id 0),
# so prb_idx is a probe-axis index, not a sound id.
prb_idx = 3 - 1
ctx_pair = [0, 1]  # pair of contexts to compare and exemplify d'
cellid = 'ARM021b-36-8'


## Load all data

In [48]:
# Rasters of every cell at the recording site (site id = first 7 characters of the cell id).
site_raster, raster_goodcells = load_site_formated_raster(cellid[:7], part='all', smoothing_window=50)
# Raster of the example cell for the selected probe. Downstream code averages over axis 0 and
# indexes contexts on axis 1 (NOTE(review): axis 0 presumably = repetitions, confirm).
eg_raster = site_raster[:, raster_goodcells.index(cellid), :, prb_idx, :]

# dprimes

meta = {'reliability': 0.1,  # r value
        'smoothing_window': 0,  # ms
        'raster_fs': 30,
        'montecarlo': 1000,
        'zscore': True,
        'dprime_absolute': None,
        'stim_type': 'permutations',
        'alpha':0.05}

dprime, shuff_dprime_quantiles, goodcells, var_capt = single_cell_dprimes(cellid[:7], contexts='all', probes='all', meta=meta)
significance, confidence_interval = _significance(dprime, shuff_dprime_quantiles,
                                                  multiple_comparisons_axis=[3], consecutive=3, alpha=meta['alpha'])
cell_idx = goodcells.index(cellid) if len(cellid) > 7 else 0

# Context pairs must be enumerated over the *context* axis. dprime's axis 2 is the probe axis
# (it is indexed with prb_idx when plotting), so the number of contexts is taken from the raster,
# whose axis 2 is the one selected by a context index (eg_raster[:, cxt_idx, :]).
n_contexts = site_raster.shape[2]
ctx_pair_labels = [f'{t0}_{t1}' for t0, t1 in itt.combinations(range(n_contexts), 2)]
assert len(ctx_pair_labels) == dprime.shape[1], "context pairs don't match dprime's pair axis"
pair_idx = ctx_pair_labels.index(f'{ctx_pair[0]}_{ctx_pair[1]}')




## full figure configuration

In [49]:
# Single-column figure with one row per panel, plus layout options shared by every panel.
fig = make_subplots(rows=4,cols=1, vertical_spacing=0.05, horizontal_spacing=0.05)

# Figure size is set in inches and converted to pixels with a pixels-per-inch factor.
# Reference values: 300 = high-quality print, 96 = web standard, 92.5 = house monitor.
ppi = 92.5  # house monitor (active setting)

width = 3  # inches
height = width * 4  # inches; four panel rows stacked vertically
_ = fig.update_layout(template='simple_white',
                      margin=dict(l=10, r=10, t=10, b=10),
                      width=round(ppi * width), height=round(ppi * height),
                      showlegend=False)

### sound sequences plus selected examples

In [50]:
row, col = 1, 1  # plotly subplot indices are 1-based
# prb_idx indexes the data's probe axis, which has no silence. Sound ids in `sequences`,
# `waves` and `colors` include silence at 0, so the selected probe is sound id prb_idx + 1.
probe_id = prb_idx + 1
n_positions = sequences.shape[1]

for ss, seq in enumerate(sequences):
    baseline = ss * vertical_offset
    for ww, wave_idx in enumerate(seq):
        # waveform of the sound at position ww, stacked by sequence
        x = np.linspace(0, 1, n_samps) + ww
        y = waves[wave_idx] + baseline
        _ = fig.add_trace(go.Scatter(x=x, y=y, line_color=colors[wave_idx], mode='lines'),
                          row=row, col=col)

        # box around each occurrence of the selected probe together with its preceding context
        if wave_idx == probe_id:
            x0 = ww - 1
            y0 = baseline - 0.5
            xd, yd = 2, 1  # 2 s wide (context + probe), 1 unit tall centred on the row baseline
            x = [x0, x0, x0 + xd, x0 + xd, x0]
            y = [y0, y0 + yd, y0 + yd, y0, y0]
            _ = fig.add_trace(go.Scatter(x=x, y=y, line_color='black', mode='lines'),
                              row=row, col=col)

# Dotted vertical lines separating consecutive sounds. They span the whole panel,
# so each boundary is added once rather than once per sequence.
for boundary in range(1, n_positions):
    _ = fig.add_vline(x=boundary, line_width=2, line_color='black', line_dash='dot',
                      row=row, col=col)

# x padding avoids clipping the traces at the panel edges
_ = fig.update_xaxes(title_text='time (s)', title_standoff=0, range=[-0.1, n_positions + 0.1],
                     row=row, col=col)
_ = fig.update_yaxes(tickmode='array',
                     tickvals=[ss * vertical_offset for ss in range(len(sequences))],
                     ticktext=[f'Seq.{i+1}' for i in range(len(sequences))],
                     row=row, col=col)

### Selected stimulus examples and transition-type classification

In [51]:
row, col = 2, 1

# prb_idx indexes the data's probe axis (no silence), while `waves`/`colors` are indexed
# by sound id (silence = 0), so the selected probe is sound id prb_idx + 1.
probe_id = prb_idx + 1
hor_off = 3  # shift so the context/probe boundary sits mid-way through the sequences panel (x 0..6)
context_x = np.linspace(-1, 0, n_samps) + hor_off
probe_x = np.linspace(0, 1, n_samps) + hor_off

for ww, (wave, color) in enumerate(zip(waves, colors)):
    baseline = ww * vertical_offset
    # context: each possible preceding sound, one per row
    _ = fig.add_trace(go.Scatter(x=context_x, y=wave + baseline, mode='lines', line_color=color),
                      row=row, col=col)
    # probe: the selected sound, repeated on every row
    _ = fig.add_trace(go.Scatter(x=probe_x, y=waves[probe_id] + baseline, mode='lines',
                                 line_color=colors[probe_id]),
                      row=row, col=col)

    # transition type of this context relative to the probe
    if ww == 0:
        type_text = 'silence'
    elif ww == probe_id:
        type_text = 'same'
    else:
        type_text = 'different'

    _ = fig.add_trace(go.Scatter(x=[-1.1 + hor_off], y=[baseline],
                                 mode='text', text=[type_text],
                                 textposition='middle left', textfont_size=15),
                      row=row, col=col)

# boundary between context and probe
_ = fig.add_vline(x=hor_off, line_width=2, line_color='black', line_dash='dot', opacity=1,
                  row=row, col=col)
# "Context" / "Probe" labels under the traces
_ = fig.add_trace(go.Scatter(x=[hor_off-0.2, hor_off+0.2],
                             y=[-1, -1],
                             mode='text', text=['<b>Context</b>', '<b>Probe</b>'],
                             textposition=['middle left', 'middle right'], textfont_size=18),
                  row=row, col=col)

# hide the axes of this schematic panel; share the x scale with the sequences panel
_ = fig.update_layout(xaxis2=dict(visible=False, showline=False, matches='x'),
                      yaxis2=dict(visible=False, showline=False))

### example pair of PSTHs

In [52]:
row, col = 3, 1

# prb_idx is a probe-axis index (no silence); `colors` is indexed by sound id (silence = 0).
probe_id = prb_idx + 1
n_bins = eg_raster.shape[-1]
time = np.linspace(-1, 1, n_bins)
halfbin = np.mean(np.diff(time)) / 2
halves = [np.s_[:n_bins // 2], np.s_[n_bins // 2:]]  # context half, probe half

for cxt_idx in ctx_pair:
    ctx_resp = eg_raster[:, cxt_idx, :]
    mean_resp = np.mean(ctx_resp, axis=0)
    std_resp = np.std(ctx_resp, axis=0)

    part_color = [colors[cxt_idx], colors[probe_id]]
    # semi-transparent context color, used to fill the std band during the probe half
    rgb = name_to_rgb(colors[cxt_idx])
    ctx_fill = f'rgba({rgb[0]}, {rgb[1]}, {rgb[2]}, 0.5)'

    for nn, (half, color) in enumerate(zip(halves, part_color)):
        x, y = squarefy(time[half], mean_resp[half])
        _, ystd = squarefy(time[half], std_resp[half])
        # shift half a bin to the left in time only; response values must not be offset
        x = x - halfbin

        # Mean +/- std band. The context half uses plotly's default fill for the line color.
        # The probe half is filled with the context color, presumably so the context identity
        # stays visible after probe onset.
        fillcolor = None if nn == 0 else ctx_fill
        _ = fig.add_trace(go.Scatter(x=x, y=y + ystd, mode='lines', line_color=color, line_width=1),
                          row=row, col=col)
        _ = fig.add_trace(go.Scatter(x=x, y=y - ystd, mode='lines', line_color=color, line_width=1,
                                     fill='tonexty', fillcolor=fillcolor),
                          row=row, col=col)

        # mean trace added last so it lies on top of the shaded band
        _ = fig.add_trace(go.Scatter(x=x, y=y, mode='lines', line_color=color, line_width=3),
                          row=row, col=col)

# probe onset
_ = fig.add_vline(x=0, line_width=2, line_color='black', line_dash='dot', opacity=1,
                  row=row, col=col)

_ = fig.update_xaxes(title_text='time from probe onset (s)', title_standoff=0,
                     row=row, col=col)
_ = fig.update_yaxes(title_text='firing rate (z-score)', title_standoff=0,
                     row=row, col=col)


### example quantified dprime

In [53]:
row, col = 4, 1

raster_fs = meta['raster_fs']
bin_width = 1 / raster_fs  # seconds per time bin

# d' trace, confidence interval and significance for the example cell, context pair and probe.
# The sign is flipped for display. Flipping CI also swaps which of its two rows is the
# lower bound, which does not affect the shaded band drawn between them.
DP = dprime[cell_idx, pair_idx, prb_idx, :] * -1
CI = confidence_interval[:, cell_idx, pair_idx, prb_idx, :] * -1
SIG = significance[cell_idx, pair_idx, prb_idx, :]

signif_mask = SIG > 0
t = np.linspace(0, DP.shape[-1] / raster_fs, DP.shape[-1], endpoint=False)  # bin start times (s)

# center of mass and integral of |d'| over the significant bins only
abs_signif_dp = np.abs(DP[signif_mask])
abs_signif_total = np.sum(abs_signif_dp)
significant_abs_sum = abs_signif_total * bin_width
if abs_signif_total > 0:
    significant_abs_mass_center = np.sum(abs_signif_dp * t[signif_mask]) / abs_signif_total
else:
    # no significant bins: the center of mass is undefined (avoids a 0/0 warning)
    significant_abs_mass_center = np.nan
print(f"integral: {significant_abs_sum*1000:.2f} d'*ms")
print(f'center of mass: {significant_abs_mass_center*1000:.2f} ms')


# d' trace
tt, mmdd = squarefy(t, DP)
_ = fig.add_trace(go.Scatter(x=tt, y=mmdd, mode='lines', line_color='black', line_width=3),
                  row=row, col=col)

# significance confidence interval
_, CCII = squarefy(t, CI.T)
_ = fig.add_trace(go.Scatter(x=tt, y=CCII[:, 0], mode='lines', line_color='gray', line_width=1),
                  row=row, col=col)
_ = fig.add_trace(go.Scatter(x=tt, y=CCII[:, 1], mode='lines', line_color='gray', line_width=1,
                             fill='tonexty'),
                  row=row, col=col)

# Significant area under the curve. Setting d' to zero outside significant bins
# leaves gaps in the filled area.
_, smm = squarefy(t, signif_mask)
wmmdd = np.where(smm, mmdd, 0)
rgb = name_to_rgb('green')
rgba = f'rgba({rgb[0]}, {rgb[1]}, {rgb[2]}, 0.5)'

_ = fig.add_trace(go.Scatter(x=tt, y=wmmdd, mode='none',
                             fill='tozeroy', fillcolor=rgba),
                  row=row, col=col)

# center of mass marker, only when it is defined
if np.isfinite(significant_abs_mass_center):
    _ = fig.add_vline(significant_abs_mass_center, line=dict(color='purple', dash='dash', width=3),
                      row=row, col=col)

# zero reference line
_ = fig.add_hline(0, line=dict(dash='dot', width=2, color='black'),
                  row=row, col=col)

# axis titles
_ = fig.update_xaxes(title=dict(text='time from probe onset (s)', standoff=0),
                     row=row, col=col)

_ = fig.update_yaxes(title=dict(text="contexts d'", standoff=0),
                     row=row, col=col)

integral: 1238.28 d'*ms
center of mass: 369.22 ms


In [54]:
# Render the assembled multi-panel figure in the notebook.
fig.show()

In [55]:
# Make both the paper and the plotting area transparent for the exported images.
_ = fig.update_layout(paper_bgcolor='rgba(0,0,0,0)', plot_bgcolor='rgba(0,0,0,0)')

folder = root_path / 'reports' / 'figures' / 'paper'
folder.mkdir(parents=True, exist_ok=True)  # create the output folder if it is missing

# Write the same figure as a raster (PNG) and a vector (SVG) image.
filename = folder / 'figure_01'
for suffix in ('.png', '.svg'):
    fig.write_image(filename.with_suffix(suffix))