In [None]:
# default_exp plotting

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
#export
import matplotlib.pyplot as plt
from matplotlib import gridspec
import numpy as np

from theonerig.core import *
from theonerig.processing import *
from theonerig.utils import *
from theonerig.modelling import *

In [None]:
#export
def plot_2d_sta(sta, n_col=8, vmin=-1, vmax=1, cmap='gray'):
    """Plot the frames of an STA as a grid of images, `n_col` frames per row.

    params:
        - sta: a single 2D frame, or a sequence/array of 2D frames. A lone 2D
          frame is plotted as a one-image grid.
        - n_col: number of frames per grid row (default 8).
        - vmin, vmax: colour-scale limits shared by every frame (default -1, 1).
        - cmap: matplotlib colormap name (default 'gray').
    """
    sta = np.array(sta)
    if sta.ndim == 2:
        # A single frame: add a leading axis so the loop below sees one frame.
        sta = sta[np.newaxis]
    n_frame = len(sta)
    # Ceiling division: only allocate rows that actually hold frames. The
    # former `n_frame // n_col + 1` added an empty row whenever n_frame was
    # an exact multiple of n_col.
    n_row = max(1, -(-n_frame // n_col))
    # Same height rule as before: 4 inches plus 2 per additional row.
    fig = plt.figure(figsize=(20, 4 + (n_row - 1) * 2))
    gs = gridspec.GridSpec(n_row, n_col)
    for idx, frame in enumerate(sta):
        ax = fig.add_subplot(gs[idx // n_col, idx % n_col])
        ax.imshow(frame, cmap=cmap, vmin=vmin, vmax=vmax)

In [None]:
# Load the in-vivo two-photon example record, then feed the checkerboard
# stimulus and eye-tracking data (combined by `eyetrack_stim_inten`) together
# with the spike counts to `process_sta_batch`.
# `plt` is already imported by the export cell at the top of the notebook.
from os.path import join

vivo_2p_dir = "./files/vivo_2p"
reM = import_record(join(vivo_2p_dir, "record_master.h5"))
pipe = Data_Pipe(reM, ["checkerboard", "eye_tracking", "S_matrix"],
                 ["stim_inten", "eye_track", "spike_counts"])
pipe += "checkerboard"
# NOTE(review): Hw=40 is presumably the STA history window in frames — confirm.
result_sta = process_sta_batch(eyetrack_stim_inten(pipe[0]["stim_inten"], pipe[0]["eye_track"]),
                               pipe[0]["spike_counts"], Hw=40)

In [None]:
# Display every second frame of result_sta[0].
plot_2d_sta(result_sta[0, ::2])

In [None]:
#export
def plot_cross_correlation(correlation_array, threshold=.1, two_sided=True, figsize=None):
    """Plot each pairwise cross-correlation curve in an n_cell x n_cell grid.

    params:
        - correlation_array: array indexed as `[i, j]` -> 1D curve; only pairs
          with i <= j are read. Presumably shape (n_cell, n_cell, n_lag), as
          returned by `cross_correlation` — TODO confirm.
        - threshold: fraction of the global (max - min) range of the array. A
          pair whose own peak-to-peak amplitude exceeds it is drawn in red.
        - two_sided: if True, each off-diagonal curve is drawn at (i, j) and
          mirrored at (j, i); otherwise only the upper triangle is drawn.
        - figsize: matplotlib figure size; defaults to (n_cell, n_cell).
    """
    n_cell = correlation_array.shape[0]
    if figsize is None:
        figsize = (n_cell, n_cell)
    _min, _max = np.min(correlation_array), np.max(correlation_array)
    thresh = (_max - _min) * threshold
    fig = plt.figure(figsize=figsize)
    for i in range(n_cell):
        for j in range(i, n_cell):
            curve = correlation_array[i, j]
            c = "red" if np.max(curve) - np.min(curve) > thresh else "#1f77b4"
            # On the diagonal both orientations map to the same grid slot;
            # drawing the mirror there stacked a duplicate axes on top.
            n_side = 2 if (two_sided and i != j) else 1
            for k in range(n_side):
                if k == 0:
                    # Upper triangle: row i, column j.
                    ax = fig.add_subplot(n_cell, n_cell, i*n_cell+j+1, ylim=(_min, _max),
                                         label=str(i*n_cell+j+1))
                else:
                    # Mirrored copy in the lower triangle: row j, column i.
                    ax = fig.add_subplot(n_cell, n_cell, j*n_cell+i+1, ylim=(_min, _max),
                                         label="b"+str(i*n_cell+j+1))
                ax.plot(curve, c=c)
                ax.axis('off')
                if i == 0 and k == 0:
                    # Column index above the top row.
                    ax.set_title(str(j))
                elif i == 0 and k == 1:
                    # Row index drawn inside the first column.
                    ax.set_title(str(j), pad=-50, loc="left")
                elif i == j:
                    # Cell index on the diagonal.
                    ax.set_title(str(j), pad=-50, loc="center")

In [None]:
# Cross-correlate the S_matrix data over the checkerboard section. A separate
# name keeps the `pipe` built in the STA cell (with stim_inten / eye_track)
# intact, so that cell can still be re-run after this one.
checker_pipe = Data_Pipe(reM, ["S_matrix"])
checker_pipe += "checkerboard"
checker_corr = cross_correlation(checker_pipe[0]["S_matrix"])
plot_cross_correlation(checker_corr, threshold=.3, figsize=(4,4))

In [None]:
#export
def plot_2d_fit(sta, param_d, figsize=None):
    """Show a 2D STA frame next to the image of its fitted model.

    params:
        - sta: 2D STA frame; its `.shape` also sets the size of the fitted image.
        - param_d: fit parameters rendered by `img_2d_fit` with
          `sum_of_2D_gaussian` (e.g. the output of `fit_spatial_sta`).
        - figsize: matplotlib figure size; defaults to (8, 4).
    """
    if figsize is None:
        # One row of two panels, so the figure should be wider than tall.
        # The former (4, 8) default left most of the figure empty.
        figsize = (8, 4)
    fig, (ax_sta, ax_fit) = plt.subplots(1, 2, figsize=figsize)
    ax_sta.imshow(sta, vmin=-1, vmax=1, cmap="gray")
    ax_sta.set_title("STA")
    fitted_img = img_2d_fit(sta.shape, param_d, f=sum_of_2D_gaussian)
    ax_fit.imshow(fitted_img, vmin=-1, vmax=1, cmap="gray")
    ax_fit.set_title("fit")

In [None]:
# Fit a spatial model to a single STA frame (result_sta[0, 25]) and compare it
# visually to the data.
sta = result_sta[0, 25]
sta_fit_params = fit_spatial_sta(sta)
plot_2d_fit(sta, sta_fit_params)

In [None]:
# Regenerate the library modules from the notebooks' `#export` cells (nbdev).
# Explicit import instead of `import *` so the only name brought in is visible.
from nbdev.export import notebook2script
notebook2script()