In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import zscore
from sklearn.decomposition import PCA

In [None]:
# @title Figure settings
from matplotlib import rcParams

# Wide, short figures with a larger font; hide the top/right axis spines and
# let matplotlib lay out subplots automatically.
rcParams.update({
    'figure.figsize': [20, 4],
    'font.size': 15,
    'axes.spines.top': False,
    'axes.spines.right': False,
    'figure.autolayout': True,
})

In [None]:
# @title Data retrieval
import os, requests

# The dataset is split into three local files, one per download URL below.
fname = ['steinmetz_part%d.npz' % j for j in range(3)]
url = [
    "https://osf.io/agvxh/download",
    "https://osf.io/uv3mw/download",
    "https://osf.io/ehmw2/download",
]

# Download each part only if it is not already on disk. Failures are reported
# rather than raised; the loading cell will then fail on the missing file.
for path, link in zip(fname, url):
  if os.path.isfile(path):
    continue
  try:
    # Without a timeout a stalled connection would block the notebook forever.
    # RequestException covers ConnectionError, Timeout and other request errors.
    r = requests.get(link, timeout=60)
  except requests.RequestException:
    print("!!! Failed to download data !!!")
  else:
    if r.status_code != requests.codes.ok:
      print("!!! Failed to download data !!!")
    else:
      with open(path, "wb") as fid:
        fid.write(r.content)

In [None]:
# @title Data loading
# Load the 'dat' array from each downloaded part and join them into one array,
# one entry per recording session.
# NOTE(review): allow_pickle=True unpickles the stored objects, which can run
# arbitrary code -- only load these files from a trusted source.
parts = []
for path in fname:
  # Loading an .npz returns an NpzFile that keeps the file open; the context
  # manager closes it once 'dat' has been read into memory.
  with np.load(path, allow_pickle=True) as npz:
    parts.append(npz['dat'])
alldat = np.hstack(parts) if parts else np.array([])

In [None]:
# Make a plot of which brain areas are present in each dataset.
# regions[k] names brain_groups[k] for k = 0..6 -- e.g. regions[3], "other ctx",
# is brain_groups[3], the cortical areas outside visual cortex. `regions` (and
# `region_colors`) carry one extra final entry, "other", that has no matching
# area list in `brain_groups`.
regions = ["vis ctx", "thal", "hipp", "other ctx", "midbrain", "basal ganglia", "cortical subplate", "other"]
region_colors = ['blue', 'red', 'green', 'darkblue', 'violet', 'lightblue', 'orange', 'gray']
brain_groups = [["VISa", "VISam", "VISl", "VISp", "VISpm", "VISrl"],  # visual cortex
                ["CL", "LD", "LGd", "LH", "LP", "MD", "MG", "PO", "POL", "PT", "RT", "SPF", "TH", "VAL", "VPL", "VPM"], # thalamus
                ["CA", "CA1", "CA2", "CA3", "DG", "SUB", "POST"],  # hippocampal
                ["ACA", "AUD", "COA", "DP", "ILA", "MOp", "MOs", "OLF", "ORB", "ORBm", "PIR", "PL", "SSp", "SSs", "RSP","TT"],  # non-visual cortex
                ["APN", "IC", "MB", "MRN", "NB", "PAG", "RN", "SCs", "SCm", "SCig", "SCsg", "ZI"],  # midbrain
                ["ACB", "CP", "GPe", "LS", "LSc", "LSr", "MS", "OT", "SNr", "SI"],  # basal ganglia
                ["BLA", "BMA", "EP", "EPd", "MEA"]  # cortical subplate
                ]

# Assign each area a row index: 0 is "root", then every area in group order.
area_to_index = dict(root=0)
counter = 1
for group in brain_groups:
    for area in group:
        area_to_index[area] = counter
        counter += 1

# Figure out which areas are in each dataset (rows: areas, columns: datasets).
# An area that is missing from `brain_groups` raises KeyError here.
areas_by_dataset = np.zeros((counter, len(alldat)), dtype=bool)
for j, d in enumerate(alldat):
    for area in np.unique(d['brain_area']):
        areas_by_dataset[area_to_index[area], j] = True

# Show the binary matrix
plt.figure(figsize=(8, 10))
plt.imshow(areas_by_dataset, cmap="Greys", aspect="auto", interpolation="none")

# Label the axes
plt.xlabel("dataset")
plt.ylabel("area")

# Tick labels are the area names in row order (dicts keep insertion order).
yticklabels = list(area_to_index)
plt.yticks(np.arange(counter), yticklabels, fontsize=8)
plt.xticks(np.arange(len(alldat)), fontsize=9)

# Color each area's tick label by its region; "root" stays black.
area_color = {"root": "black"}
for group, color in zip(brain_groups, region_colors):
    area_color.update(dict.fromkeys(group, color))
ytickobjs = plt.gca().get_yticklabels()
for tick, area in zip(ytickobjs, yticklabels):
    tick.set_color(area_color[area])

plt.title("Brain areas present in each dataset")
plt.grid(True)
plt.show()

In [None]:
# @title Basic plots of population average

# select just one of the recordings here. 11 is nice because it has some neurons in vis ctx.
dat = alldat[11]
print(dat.keys())

dt = dat['bin_size']  # binning at 10 ms
NT = dat['spks'].shape[-1]
time_s = dt * np.arange(NT)  # time axis shared by every trace

response = dat['response']  # right - nogo - left (-1, 0, 1)
vis_right = dat['contrast_right']  # 0 - low - high
vis_left = dat['contrast_left']  # 0 - low - high

def _population_rate(trial_mask):
  """Firing rate (spikes per bin / dt), averaged over all neurons and the masked trials."""
  return 1/dt * dat['spks'][:, trial_mask].mean(axis=(0, 1))

ax = plt.subplot(1, 5, 1)
# `response > 0` keeps left choices only; `>= 0` would also include no-go (0) trials.
plt.plot(time_s, _population_rate(response > 0))    # left responses
plt.plot(time_s, _population_rate(response < 0))    # right responses
plt.plot(time_s, _population_rate(vis_right > 0))   # stimulus on the right
plt.plot(time_s, _population_rate(vis_right == 0))  # no stimulus on the right

plt.legend(['left resp', 'right resp', 'right stim', 'no right stim'], fontsize=12)
ax.set(xlabel='time (sec)', ylabel='firing rate (Hz)')
plt.show()

In [None]:
nareas = 4  # only the top 4 regions are in this particular mouse
NN = len(dat['brain_area'])  # number of neurons
# Every neuron starts in the catch-all bucket (index nareas, "other") and is then
# relabelled with the index of whichever of the first nareas groups contains its area.
barea = np.full(NN, nareas, dtype=float)
for region_idx, group in enumerate(brain_groups[:nareas]):
  barea[np.isin(dat['brain_area'], group)] = region_idx

In [None]:
# @title plots by brain region and visual conditions
# (legend label, trial mask) for each combination of left/right stimulus presence.
conditions = [
  ('right only', np.logical_and(vis_left == 0, vis_right > 0)),
  ('left only', np.logical_and(vis_left > 0, vis_right == 0)),
  ('neither', np.logical_and(vis_left == 0, vis_right == 0)),
  ('both', np.logical_and(vis_left > 0, vis_right > 0)),
]

for j in range(nareas):
  ax = plt.subplot(1, nareas, j + 1)
  region_spks = dat['spks'][barea == j]  # neurons assigned to this region

  for _, trial_mask in conditions:
    plt.plot(1/dt * region_spks[:, trial_mask].mean(axis=(0, 1)))
  plt.text(.25, .92, 'n=%d' % np.sum(barea == j), transform=ax.transAxes)

  if j == 0:
    plt.legend([label for label, _ in conditions], fontsize=12)
  ax.set(xlabel='binned time', ylabel='mean firing rate (Hz)', title=regions[j])
plt.show()

In [None]:
# Flatten the per-region area lists into one list of area names, in group order.
brain_groups_all = list(np.concatenate(brain_groups))
print(len(brain_groups_all))

In [None]:
# List the fields stored for the first recording session, one per line.
for field_name in alldat[0]:
    print(field_name)

In [None]:
#Generate some metadata for the neuron regions.

def generate_metadata(dat):
    """Label every neuron of one recording session with its region and area group.

    Args:
        dat: one session record from ``alldat``; only ``dat['brain_area']``
            (one area name per neuron) is read.

    Returns:
        n_neurons: number of neurons in the session.
        region_index: float array with one index into ``regions`` per neuron.
            Neurons whose area is in none of ``brain_groups`` get the last
            index, ``len(regions) - 1`` ("other").
        group_index: float array with one index into ``brain_groups_all`` per
            neuron, or -1 when the area is not listed there (e.g. "root").
    """
    areas = np.asarray(dat['brain_area'])
    n_neurons = len(areas)
    # Default to "other" / unassigned rather than 0: a 0 default would silently
    # mislabel unlisted neurons as region "vis ctx" and group "VISa".
    region_index = np.full(n_neurons, len(regions) - 1, dtype=float)
    group_index = np.full(n_neurons, -1, dtype=float)
    for region, group in enumerate(brain_groups):
        region_index[np.isin(areas, group)] = region
    for group, area in enumerate(brain_groups_all):
        group_index[areas == area] = group
    return n_neurons, region_index, group_index

In [None]:
# Shape of session 0's spike array: (neurons, trials, time bins), as indexed below.
alldat[0]['spks'].shape

In [None]:
#Choose dataset
session = 3
dat = alldat[session]
dt = dat['bin_size']

# Raw spike arrays for task and passive trials.
spikes_all = dat['spks']
spikes_passive = dat['spks_passive']
no_of_bins = np.shape(spikes_all)[2]

n_neurons, region_index, group_index = generate_metadata(dat)
groups_present = np.unique(dat['brain_area'])
n_group_present = len(groups_present)

# Split trials by feedback: 1 -> correct response (CR), -1 -> wrong response (WR).
CR_idx = np.where(dat['feedback_type'] == 1)[0]
WR_idx = np.where(dat['feedback_type'] == -1)[0]
spikes_CR = spikes_all[:, CR_idx, :]
spikes_WR = spikes_all[:, WR_idx, :]

In [None]:
# Distinct area-group indices assigned to neurons in the chosen session.
print(np.unique(group_index.ravel()))

In [None]:
def _group_activation(spikes, areas, groups):
    """Min-max normalised mean activity per brain area, shape (len(groups), n_bins).

    Each row averages `spikes` over that area's neurons (axis 0) and over all
    trials (axis 1), then rescales the trace to [0, 1]. Rows stay NaN when
    `spikes` has no trials. A flat trace becomes all zeros instead of the NaN/inf
    that dividing by a zero range would give.
    """
    activation = np.full((len(groups), spikes.shape[2]), np.nan)
    if spikes.shape[1] == 0:
        return activation
    for idx, group in enumerate(groups):
        trace = np.mean(spikes[np.where(areas == group)[0], :, :], axis=(0, 1))
        span = np.ptp(trace)
        activation[idx, :] = (trace - np.min(trace)) / span if span > 0 else 0.0
    return activation


def _plot_group_activation(activation, outcome):
    """Heatmap of one activation matrix with one labelled row per area in `groups_present`."""
    plt.pcolormesh(activation)
    plt.yticks(np.arange(0.5, len(groups_present), 1), groups_present)
    plt.ylabel('Brain Group')
    plt.xlabel('Time Bins (10ms)')
    plt.title('Spiking Activity during trials with ' + outcome + ' response in session ' + str(session))
    plt.colorbar()
    plt.show()


group_activation_CR = _group_activation(spikes_CR, dat['brain_area'], groups_present)
group_activation_WR = _group_activation(spikes_WR, dat['brain_area'], groups_present)

_plot_group_activation(group_activation_CR, 'correct')
_plot_group_activation(group_activation_WR, 'incorrect')