In [None]:
import os
import math
import numpy as np
from numpy import genfromtxt
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.gridspec as gridspec
import seaborn as sns
from scipy import stats
from scipy.stats import gaussian_kde
from sklearn.decomposition import PCA
from mpl_toolkits import mplot3d  # registers the '3d' projection on older matplotlib
from numpy import random as rd
import time
import pylab as pl
import holoviews as hv
from holoviews import dim, opts
# `display` is bound to the IPython.display MODULE: later cells call
# display.clear_output(...) and display.display(...).
from IPython import display
from IPython.display import HTML

hv.notebook_extension('bokeh')
# Widen the notebook container so wide figures are not squeezed.
display.display(HTML("<style>.container { width:90% !important; }</style>"))

## Load units
> ### Specify the brain region (`brainRegion`) and the session index (`sesh`, counting from 0)

In [None]:
# Choose the brain region, the 0-based session index, and an optional temporal
# downsampling step (1 keeps every timepoint; e.g. 20 for quick exploration).
brainRegion = 'VISp'
sesh = 0
downsampleStep = 1
dataDir = './buzsaki_data/'

# os.listdir returns files in arbitrary, filesystem-dependent order; sort so that
# `sesh` selects the same session on every machine.
dataFiles = sorted(os.listdir(dataDir))
# Session IDs are taken as the 9 characters following 'session_' in the file name.
sessions = sorted({file.split('session_')[1][:9]
                   for file in dataFiles
                   if brainRegion in file and 'firingrates' in file})
if not sessions:
    raise FileNotFoundError(f"No firing-rate files for {brainRegion!r} in {dataDir}")
session = sessions[sesh]

# Load every file belonging to this region/session.
firingRates = running = stimnums = stimnames = None
for file in dataFiles:
    if brainRegion not in file or session not in file:
        continue
    path = os.path.join(dataDir, file)
    if 'firingrates' in file:
        units = pd.read_csv(path)
        unit_id = units.iloc[:, 0]
        cell_type = units.iloc[:, 1]
        firingRates = units.iloc[:, 2:].to_numpy()  # rows = units, columns = timepoints
    elif 'runningSpeed' in file:
        running = np.genfromtxt(path)
    elif 'labels' in file:
        stimnums = np.genfromtxt(path)
    elif 'labelNames' in file:
        stimnames = np.genfromtxt(path, dtype=str)

missing = [kind for kind, data in [('firingrates', firingRates), ('runningSpeed', running),
                                   ('labels', stimnums), ('labelNames', stimnames)]
           if data is None]
if missing:
    raise FileNotFoundError(f"Session {session}: missing {', '.join(missing)} file(s) in {dataDir}")

# Chop off the beginning and end of the session where stimulus presentation, running,
# and recordings don't line up: start just after the last NaN in the running trace
# (so `running` is NaN-free afterwards) and stop at the end of the shortest series,
# so every series covers exactly the same timepoints.
nanIndices = np.flatnonzero(np.isnan(running))
firstNonNaNindex = nanIndices[-1] + 1 if nanIndices.size else 0
commonLength = min(len(running), len(stimnums), len(stimnames), firingRates.shape[1])
keep = slice(firstNonNaNindex, commonLength, downsampleStep)
firingRates = firingRates[:, keep]
stimnums = stimnums[keep]
stimnames = stimnames[keep]
running = running[keep]
assert len(running) > 0, 'no aligned timepoints remain after trimming'
assert firingRates.shape[1] == len(running) == len(stimnums) == len(stimnames)

In [None]:
%matplotlib inline
FSUs = np.zeros(firingRates.shape[0])
RSUs = np.zeros(firingRates.shape[0])
for j in range(firingRates.shape[0]):
    if cell_type[j] == 'FSU':
        FSUs[j] = 1
    else:
        RSUs[j] = 1

# Plot all currents as heatmap
plt.figure(figsize=(25,5))
plt.imshow(firingRates, aspect='auto', cmap='viridis')
plt.colorbar()
plt.show()

# Plot individual currents
plt.figure(figsize=(20,8))
for i in range(5):
    plt.plot(firingRates[i,:] + i*firingRates[i,:].max(), linewidth=1)
plt.ylabel("Current")
plt.xlabel("Time")
plt.show()

## Calculate the covariance matrix and its eigenvalues/eigenvectors
> ### Choose the number of principal components to project the data onto (`numProjections`)

In [None]:
# Manual PCA: rows of `rates` are timepoints, columns are units.
# (Use firingRates[RSUs.astype(bool), :].astype(float).T to restrict to RSUs.)
rates = firingRates.astype(float).T  # float copy, so in-place centering is safe
rates -= np.mean(rates, axis=0)  # center each unit
covMat = (rates.T @ rates) / (rates.shape[0] - 1)  # unit-by-unit covariance matrix

# covMat is symmetric, so use eigh: it returns real eigenpairs in ascending order.
# np.linalg.eig returns them unsorted (and possibly complex), so its first columns
# are not the top principal components.
evalues, evectors = np.linalg.eigh(covMat)
order = np.argsort(evalues)[::-1]  # descending explained variance
evalues, evectors = evalues[order], evectors[:, order]

numProjections = min(10, evectors.shape[1])  # number of principal components kept
# Row k of `projections` is the centered data projected onto the k-th component.
projections = (rates @ evectors[:, :numProjections]).T

In [None]:
%matplotlib inline
rates = firingRates.copy().T
pca = PCA()
pca.fit(rates)
numProjections = 10 # number of dimensions
projections = np.zeros((numProjections,rates.shape[0]))
for x in range(numProjections):
    projections[x,:] = np.dot(rates, pca.components_[x])
plt.figure(figsize=(20,5))
plt.scatter(np.arange(len(pca.explained_variance_ratio_)), pca.explained_variance_ratio_/np.sum(pca.explained_variance_ratio_))
plt.show()

In [None]:
%matplotlib inline
numvectors = 2
fig = plt.figure(figsize=(20,8))
for n in range(numvectors):
    plt.plot(pca.components_[n] - n, c='black')
    plt.scatter(x = np.arange(len(FSUs)), y = np.zeros((len(FSUs),1))-n, c=FSUs, cmap='coolwarm', alpha=1, s=20)
    plt.text(x=0, y=-n+0.25, s='PC ' + str(n+1), fontsize=15, c='black')
FSUpatch = mpatches.Patch(color='red',label='Fast Spiking Unit')
RSUpatch = mpatches.Patch(color='blue',label='Regular Spiking Unit')
plt.legend(handles = [FSUpatch,RSUpatch],loc='center',fontsize=15)
plt.title(brainRegion + '_session_' + sessions[sesh] + '_PC1and2_byCellType')
plt.show()
#fig.savefig('./buzsaki_plots/' + brainRegion + '_session_' + sessions[sesh] + '_PC1and2_byCellType.png')

In [None]:
# Same loadings plot as above, but within each PC the units are sorted by |loading|
# (largest magnitudes on the right), and the cell-type dots follow the same order.
numvectors = 2
saveFigure = False  # also write the figure to ./buzsaki_plots/
typeCmap = plt.cm.coolwarm
rankIndex = np.arange(len(FSUs))

fig, ax = plt.subplots(figsize=(20, 8))
for n in range(numvectors):
    order = np.argsort(np.abs(pca.components_[n]))
    ax.plot(pca.components_[n][order] - n, c='black')
    # vmin/vmax pin RSU (0) and FSU (1) to the colormap ends even if only one type is present.
    ax.scatter(rankIndex, np.full(len(FSUs), -n), c=FSUs[order], cmap=typeCmap, vmin=0, vmax=1, alpha=1, s=50)
    ax.text(x=0, y=-n + 0.25, s='PC ' + str(n + 1), fontsize=15, color='black')

# Legend colors come from the same colormap ends the dots are drawn with.
FSUpatch = mpatches.Patch(color=typeCmap(1.0), label='Fast Spiking Unit')
RSUpatch = mpatches.Patch(color=typeCmap(0.0), label='Regular Spiking Unit')
ax.legend(handles=[FSUpatch, RSUpatch], loc='center', fontsize=15)
ax.set_xlabel('Rank by |loading| (ascending)')
ax.set_ylabel('PC loading (offset per PC)')
figName = (brainRegion + '_session_' + sessions[sesh] + '_PC'
           + 'and'.join(str(k + 1) for k in range(numvectors)) + '_byCellType_sorted')
ax.set_title(figName)
if saveFigure:
    os.makedirs('./buzsaki_plots/', exist_ok=True)
    fig.savefig('./buzsaki_plots/' + figName + '.png')
plt.show()

### Run if PCA is collapsed across cells (so each data point is a timepoint)

In [None]:
# Legend labels: one name per distinct stimulus number, in ascending stimulus-number
# order, which is the order legend_elements(num=None) returns its handles in.
_, firstIdx = np.unique(stimnums, return_index=True)
display_names = [stimnames[k] for k in firstIdx]

# Two panels over the same PC1/PC2 projection: drawing the opaque running-speed
# scatter on the same axes would completely hide the stimulus colors.
fig, (axStim, axRun) = plt.subplots(1, 2, figsize=(22, 10), sharex=True, sharey=True)

# Left: timepoints colored by stimulus presentation.
stimScatter = axStim.scatter(projections[0, :], projections[1, :], marker='.', c=stimnums, cmap='Dark2', s=80, alpha=0.5)
axStim.legend(stimScatter.legend_elements(num=None)[0], display_names, fontsize=20)
axStim.set_title('Colored by stimulus', fontsize=20)

# Right: the same timepoints colored by running speed.
runScatter = axRun.scatter(projections[0, :], projections[1, :], marker='.', c=running, s=80, alpha=1)
fig.colorbar(runScatter, ax=axRun, label='Running speed')
axRun.set_title('Colored by running speed', fontsize=20)

for ax in (axStim, axRun):
    ax.set_xlabel('First Principal Component', fontsize=20)
axStim.set_ylabel('Second Principal Component', fontsize=20)
fig.suptitle(brainRegion + ' Session ' + sessions[sesh], fontsize=30)

# Save before showing; some backends release the figure once it has been shown.
os.makedirs('./buzsaki_plots/', exist_ok=True)
fig.savefig('./buzsaki_plots/' + brainRegion + '_session_' + sessions[sesh] + 'PC1and2_projectedthroughtime_byStim_bigLegend.png')
plt.show()

In [None]:
# Switch to `%matplotlib notebook` (or `widget`) for an interactive, rotatable view.
# Static 3D scatter of every timepoint in PC1-3 space; the first and last timepoints
# are marked with green and red stars.
colorBy = 'running'  # 'running' (running speed), 'stimulus' (stimulus number) or 'time' (timepoint index)

fig = plt.figure(figsize=(22, 10))
ax = fig.add_subplot(111, projection='3d')
ax.scatter(projections[0, 0], projections[1, 0], projections[2, 0], marker='*', c='green', s=100)  # start timepoint
ax.scatter(projections[0, -1], projections[1, -1], projections[2, -1], marker='*', c='red', s=100)  # end timepoint

xs, ys, zs = projections[0, :], projections[1, :], projections[2, :]
# Only keyword arguments after the coordinates: Axes3D.scatter's 4th positional
# parameter is `zdir`, so a positional marker string like '.' would be misread.
if colorBy == 'running':
    runScatter = ax.scatter(xs, ys, zs, c=running, cmap='viridis', s=10, alpha=1)
    fig.colorbar(runScatter, label='Running speed')
elif colorBy == 'stimulus':
    # Names ordered by ascending stimulus number, matching legend_elements(num=None).
    _, firstIdx = np.unique(stimnums, return_index=True)
    display_names = [stimnames[k] for k in firstIdx]
    stimScatter3d = ax.scatter(xs, ys, zs, c=stimnums, cmap='Dark2', s=10, alpha=0.5)
    ax.legend(stimScatter3d.legend_elements(num=None)[0], display_names)
elif colorBy == 'time':
    timeScatter = ax.scatter(xs, ys, zs, c=np.arange(len(xs)), cmap='inferno', s=10, alpha=0.5)
    fig.colorbar(timeScatter, label='Timepoint')
else:
    raise ValueError(f"colorBy must be 'running', 'stimulus' or 'time', got {colorBy!r}")

ax.set_xlabel('First Principal Component', fontsize=10)
ax.set_ylabel('Second Principal Component', fontsize=10)
ax.set_zlabel('Third Principal Component', fontsize=10)
plt.show()

## Trajectory through time in 3D PC space

In [None]:
%matplotlib inline
# Watch progression of time in 3D PC space (lol)
display_names = []
for left in np.concatenate([[0], np.where(np.diff(stimnums))[0]]):
    name = stimnames[left]
    if name not in display_names:
        display_names.append(name)

throughTime = projections[:,0].reshape((10,1))
labelsthroughTime = [stimnums[0].copy()]
for i in range(projections.shape[1]-1):
    fig = plt.figure(figsize=(22,10))
    ax = fig.add_subplot(111, projection='3d')
    # Get rid of colored axes planes
    # First remove fill
#     ax.xaxis.pane.fill = False
#     ax.yaxis.pane.fill = False
#     ax.zaxis.pane.fill = False

#     # Now set color to white (or whatever is "invisible")
#     ax.xaxis.pane.set_edgecolor('k')
#     ax.yaxis.pane.set_edgecolor('k')
#     ax.zaxis.pane.set_edgecolor('k')
#     # Bonus: To get rid of the grid as well:
#     ax.grid(False)

    ax.plot(projections[0,:i+1], projections[1,:i+1], projections[2,:i+1], c='grey', alpha=0.4, linewidth=1)
    ax.scatter(projections[0,0], projections[1,0], projections[2,0], marker='*', c='green', s=100) # start timepoint
    stimScatterthroughTime = ax.scatter(projections[0,:i+1], projections[1,:i+1], projections[2,:i+1], '.', c=stimnums[:i+1], cmap='Dark2',vmax=stimnums.max(), s=30, alpha=0.4)
    ax.legend(stimScatter.legend_elements()[0], display_names)
    ax.scatter(projections[0,i], projections[1,i], projections[2,i], '.', c='red', s=30)
    ax.set_xlabel('First Principal Component', fontsize=10)
    ax.set_ylabel('Second Principal Component', fontsize=10)
    ax.set_zlabel('Third Principal Component', fontsize=10)
    #throughTime = np.hstack((throughTime, projections[:,i+1].reshape((10,1))))
    #labelsthroughTime.append(stimnums[i+1])
    #if throughTime.shape[1] >= 100:
    #    throughTime = throughTime[:,-1].reshape((throughTime.shape[0],1))
    #    labelsthroughTime = [stimnums[i+1].copy()]
    display.clear_output(wait=True)
    display.display(pl.gcf())
    time.sleep(0.00001)
    plt.close(fig)

## Trajectory through time alongside running speed

In [None]:
%matplotlib inline

display_names = []
for left in np.concatenate([[0], np.where(np.diff(stimnums))[0]]):
    name = stimnames[left]
    if name not in display_names:
        display_names.append(name)

throughTime = projections[:,0].reshape((10,1))
labelsthroughTime = [stimnums[0].copy()]
for i in range(projections.shape[1]-1):
    fig, ax = plt.subplots(nrows=2, ncols = 1, figsize=(22,10))
    gs = gridspec.GridSpec(2, 1,height_ratios=[1,1])
    ax1 = plt.subplot(gs[0])
    ax2 = plt.subplot(gs[1])
    ax1 = fig.add_subplot(211, projection='3d')

#     ax1.xaxis.pane.fill = False
#     ax1.yaxis.pane.fill = False
#     ax1.zaxis.pane.fill = False
#     ax1.xaxis.pane.set_edgecolor('w')
#     ax1.yaxis.pane.set_edgecolor('w')
#     ax1.zaxis.pane.set_edgecolor('w')
#     ax1.grid(False)
#     ax1.set_facecolor('black')
#     ax2.set_facecolor('black')
    
    ax1.plot(projections[0,:i+1], projections[1,:i+1], projections[2,:i+1], c='grey', alpha=0.4, linewidth=1)
    ax1.scatter(projections[0,0], projections[1,0], projections[2,0], marker='*', c='green', s=100) # start timepoint
    stimScatterthroughTime = ax1.scatter(projections[0,:i+1], projections[1,:i+1], projections[2,:i+1], '.', c=stimnums[:i+1], cmap='tab10', vmax=stimnums.max(), s=30, alpha=0.5)
    ax1.legend(stimScatter.legend_elements()[0], display_names)
    ax1.scatter(projections[0,i], projections[1,i], projections[2,i], '.', c='red', s=30)
    ax1.set_xlabel('First Principal Component', fontsize=10)
    ax1.set_ylabel('Second Principal Component', fontsize=10)
    ax1.set_zlabel('Third Principal Component', fontsize=10)

    ax2.plot(running[:i+1], c='k')
    ax2.set_xlim(0,len(running))
    ax2.set_ylabel('Running Speed')
    #fig.savefig('./buzsaki_video/PCAthroughTime_' + str(i).zfill(3) + '.png', dpi=80)
    
    display.clear_output(wait=True)
    display.display(pl.gcf())
    time.sleep(0.000000001)
    plt.close(fig)

## Correlation between PC1 and Running

In [None]:
%matplotlib inline
x, y = -projections[0,:], running
slope, intercept, r_value, p_value, std_err = stats.linregress(x,y)
line = slope*x+intercept
# Calculate the point density
xy = np.vstack([x,y])
z = gaussian_kde(xy)(xy)
fig, ax = plt.subplots(figsize=(15,10))
scat = ax.scatter(x, y, c=z, s=50, edgecolor='', cmap='YlGnBu_r')
plt.colorbar(scat)
plt.plot(x,y,'o', x, line, markersize=0, c='black', linewidth=3)
plt.title(brainRegion + ' Session ' + sessions[sesh] + " ; r = " + str(r_value.round(3)), fontsize=20)
plt.ylabel('Running', fontsize=20)
plt.xlabel('PC1', fontsize=20)
plt.show()
#fog.savefig('./buzsaki_plots/' + brainRegion + '_session_' + sessions[sesh] + 'PC1vsRunning.png')
#sns.regplot(x=-projections[0,:], y=running, marker='.').set_title('r = ' + str(np.corrcoef(-projections[0,:], running)[1,0]), c='white')

## Plot histogram of correlations between each cell's firing rate & running speed

In [None]:
# Pearson correlation between each unit's firing rate and running speed, histogrammed
# separately for FSUs (red) and RSUs (blue). Computed from firingRates (rows = units)
# so the result does not depend on which earlier cell last assigned `rates`.
ratesCentered = firingRates - firingRates.mean(axis=1, keepdims=True)
runCentered = running - running.mean()
with np.errstate(invalid='ignore', divide='ignore'):
    # Units with constant firing have zero variance and get NaN (excluded below).
    unitCorr = (ratesCentered @ runCentered) / (
        np.linalg.norm(ratesCentered, axis=1) * np.linalg.norm(runCentered))
valid = np.isfinite(unitCorr)
bins = np.histogram_bin_edges(unitCorr[valid], bins=20)  # shared bins for both groups

fig, ax = plt.subplots()
ax.hist(unitCorr[FSUs.astype(bool) & valid], bins, alpha=0.3, color='red', label='FSU')
ax.hist(unitCorr[RSUs.astype(bool) & valid], bins, alpha=0.3, color='blue', label='RSU')
ax.set_xlabel('Correlation with running speed (Pearson r)')
ax.set_ylabel('Number of units')
ax.legend()
plt.show()

### Run if PCA is collapsed across time (so each data point is a cell)

In [None]:
# 3D scatter of the first three PCs with one point per unit, colored by cell type.
# Only meaningful when PCA was run with units as samples, i.e. `projections` has
# one column per unit; otherwise skip with an explanation instead of crashing.
if projections.shape[0] < 3 or projections.shape[1] != len(FSUs):
    print(f'Skipping cell-type PC plot: projections has shape {projections.shape}, '
          f'expected (>=3, {len(FSUs)}) with one column per unit. '
          'Run PCA on firingRates (units as samples) to use this plot.')
else:
    typeCmap = plt.cm.BrBG
    fig = plt.figure(figsize=(22, 10))
    ax = fig.add_subplot(111, projection='3d')
    # vmin/vmax pin RSU (0) and FSU (1) to the colormap ends even if only one type is present.
    ax.scatter(projections[0, :], projections[1, :], projections[2, :], c=FSUs,
               cmap=typeCmap, vmin=0, vmax=1, alpha=0.5, linewidth=1)
    ax.set_xlabel('First Principal Component', fontsize=10)
    ax.set_ylabel('Second Principal Component', fontsize=10)
    ax.set_zlabel('Third Principal Component', fontsize=10)
    # Legend colors come from the same colormap ends the points are drawn with.
    FSU_patch = mpatches.Patch(color=typeCmap(1.0), label='FSU')
    RSU_patch = mpatches.Patch(color=typeCmap(0.0), label='RSU')
    ax.legend(handles=[FSU_patch, RSU_patch], loc='upper right')
    plt.show()