In [None]:
import os
import math
import numpy as np
from numpy import genfromtxt
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from mpl_toolkits import mplot3d
from numpy import random as rd
import time
import pylab as pl
from IPython import display
import holoviews as hv
from holoviews import dim, opts
# Load HoloViews' Bokeh plotting backend for this notebook session.
hv.extension('bokeh')
# Import `display` from its public location; importing it from
# IPython.core.display is deprecated since IPython 7.14. Note that this rebinds
# the name `display` (bound to the IPython.display module by the earlier
# `from IPython import display`) to the display function.
from IPython.display import display, HTML
# Widen the notebook's main container to 90% of the window width.
# NOTE(review): the `.container` selector presumably targets the classic
# Notebook UI — confirm whether it has any effect under JupyterLab.
display(HTML("<style>.container { width:90% !important; }</style>"))

## Load units

In [None]:
brainRegion = 'VISp'
DATA_DIR = './buzsaki_data/'

# Collect the ids of all sessions that have a firing-rate file for this region.
# The id is taken as the 9 characters following 'session_' in the file name.
# os.listdir() returns entries in arbitrary, filesystem-dependent order, so the
# ids are de-duplicated and sorted; otherwise `sesh` below could pick a
# different session on another machine.
sessions = sorted({
    file.split('session_')[1][:9]
    for file in os.listdir(DATA_DIR)
    if brainRegion in file and 'firingrates' in file
})

sesh = 2  # index into the sorted `sessions` list
session_id = sessions[sesh]

# Reset every per-session variable before loading, so values left over from a
# previous run (e.g. with another `sesh`) cannot silently survive when one of
# the session's files is missing.
units = running = stimnums = stimnames = None
for file in sorted(os.listdir(DATA_DIR)):
    if brainRegion not in file or session_id not in file:
        continue
    path = os.path.join(DATA_DIR, file)
    if 'firingrates' in file:
        units = pd.read_csv(path)
    elif 'runningSpeed' in file:
        running = np.genfromtxt(path)
    elif 'labels' in file:
        stimnums = np.genfromtxt(path)
    elif 'labelNames' in file:
        stimnames = np.genfromtxt(path, dtype=str)

missing = [
    kind
    for kind, value in [('firingrates', units), ('runningSpeed', running),
                        ('labels', stimnums), ('labelNames', stimnames)]
    if value is None
]
assert not missing, f"session {session_id}: no {missing} file(s) found in {DATA_DIR}"

unit_id = units.iloc[:, 0]
cell_type = units.iloc[:, 1]
firingRates = units.iloc[:, 2:].to_numpy()  # rows: units, columns: time samples

# Keep only the samples after the LAST NaN in the running-speed trace, so the
# retained `running` contains no NaN. NOTE(review): presumably the NaNs form a
# leading stretch before tracking starts — if NaNs also occur mid-recording,
# everything before them is discarded too. If the trace contains no NaN at
# all, nothing is dropped.
nan_idx = np.flatnonzero(np.isnan(running))
firstNonNaNindex = nan_idx[-1] + 1 if nan_idx.size else 0
n_samples = len(running)
firingRates = firingRates[:, firstNonNaNindex:n_samples]
stimnums = stimnums[firstNonNaNindex:n_samples]
stimnames = stimnames[firstNonNaNindex:n_samples]
running = running[firstNonNaNindex:]

In [None]:
%matplotlib inline
FSUs = np.zeros(firingRates.shape[0])
RSUs = np.zeros(firingRates.shape[0])
for j in range(firingRates.shape[0]):
    if cell_type[j] == 'FSU':
        FSUs[j] = 1
    else:
        RSUs[j] = 1

# Plot all currents as heatmap
plt.figure(figsize=(25,5))
plt.imshow(firingRates, aspect='auto', cmap='viridis')
plt.colorbar()
plt.show()

# Plot individual currents
plt.figure(figsize=(20,8))
for i in range(5):
    plt.plot(firingRates[i,:] + i*firingRates.max()/10, linewidth=1)
plt.ylabel("Current")
plt.xlabel("Time")
plt.show()

## Compute the covariance matrix and its eigenvalues/eigenvectors
> ### Choose the number of principal components to project the firing rates onto (`numProjections`)

In [None]:
# Samples x units matrix with each unit's mean rate removed. Build a NEW array:
# firingRates.T is a view, so an in-place `-=` would silently mean-centre
# firingRates itself (changing earlier plots on re-run and making this cell
# non-idempotent).
rates = firingRates.T - firingRates.T.mean(axis=0)
covMat = (1.0 / (rates.shape[0] - 1)) * (rates.T @ rates)  # unit x unit covariance matrix

# covMat is symmetric, so use eigh: it returns real eigenpairs in ascending
# eigenvalue order. np.linalg.eig gives no ordering guarantee (and may return a
# complex dtype), so its first columns need not be the top-variance components.
# Reverse to descending order so column k of `evectors` is principal component k.
evalues, evectors = np.linalg.eigh(covMat)
order = np.argsort(evalues)[::-1]
evalues = evalues[order]
evectors = evectors[:, order]

numProjections = 10  # number of principal components to project onto
numProjections = min(numProjections, evectors.shape[1])  # cannot exceed the number of units
# Row k holds the firing rates projected onto principal component k, over time.
projections = (rates @ evectors[:, :numProjections]).T

In [None]:
# Scree plot: each eigenvalue of the covariance matrix as a fraction of their
# sum (the total variance), i.e. the share of variance along each eigenvector.
fig, ax = plt.subplots(figsize=(20, 5))
ax.scatter(np.arange(len(evalues)), evalues / np.sum(evalues))
ax.set(xlabel='Component index', ylabel='Fraction of total variance',
       title='Eigenvalue spectrum of the firing-rate covariance matrix')
plt.show()

In [None]:
numvectors = 5
numvectors = min(numvectors, evectors.shape[1])  # cannot plot more eigenvectors than exist
# squeeze=False keeps `ax` a 2-D array of Axes even when numvectors == 1, so
# indexing below works for any count.
fig, ax = plt.subplots(nrows=numvectors, ncols=1, figsize=(20, 8), sharex=True, squeeze=False)
for n in range(numvectors):
    # The sign of an eigenvector is arbitrary; it is flipped here for display.
    ax[n, 0].plot(-evectors[:, n])
    ax[n, 0].set_ylabel(f'Eigenvector {n}')
ax[-1, 0].set_xlabel('Unit index')
fig.suptitle('Eigenvector loadings per unit (sign flipped)')
# plt.show() rather than fig.show(): the latter warns under the non-GUI
# inline backend.
plt.show()

In [None]:
%matplotlib notebook

display_names = []
for left in np.concatenate([[0], np.where(np.diff(stimnums))[0]]):
    name = stimnames[left]
    if name not in display_names:
        display_names.append(name)

fig = plt.figure(figsize=(20,10))
ax = fig.add_subplot(111, projection='3d')
ax.plot(projections[0,:], projections[1,:], projections[2,:], c='grey', alpha=0.5, linewidth=1)
#scatter = ax.scatter(projections[0,:], projections[1,:], projections[2,:], '.', c=stimnums, cmap='Dark2', s=10, alpha=0.5)
#ax.legend(scatter.legend_elements()[0], display_names)
#ax.scatter(projections[0,:], projections[1,:], projections[2,:], '.', c=running, s=10, alpha=1)
ax.scatter(projections[0,0], projections[1,0], projections[2,0], marker='*', c='green', s=100) # start
ax.scatter(projections[0,-1], projections[1,-1], projections[2,-1], marker='*', c='red', s=100) # end
ax.set_xlabel('First Principal Component', fontsize=10)
ax.set_ylabel('Second Principal Component', fontsize=10)
ax.set_zlabel('Third Principal Component', fontsize=10)
#FSU_patch = mpatches.Patch(color='teal', label='FSU')
#RSU_patch = mpatches.Patch(color='saddlebrown', label='RSU')
#plt.legend(handles=[FSU_patch, RSU_patch], loc='upper right')
plt.show()