In [None]:
# !pip install mindscope_utilities --upgrade
# !pip install allensdk==2.13.4

Collecting allensdk==2.13.4
  Using cached allensdk-2.13.4-py3-none-any.whl.metadata (2.0 kB)
Collecting matplotlib<3.4.3,>=1.4.3 (from allensdk==2.13.4)
  Using cached matplotlib-3.4.2.tar.gz (37.3 MB)
  Preparing metadata (setup.py) ... done
Collecting jinja2<2.12.0,>=2.7.3 (from allensdk==2.13.4)
  Using cached Jinja2-2.11.3-py2.py3-none-any.whl.metadata (3.5 kB)
Collecting pynrrd<1.0.0,>=0.2.1 (from allensdk==2.13.4)
  Using cached pynrrd-0.4.3-py2.py3-none-any.whl.metadata (3.8 kB)
Collecting future<1.0.0,>=0.14.3 (from allensdk==2.13.4)
  Using cached future-0.18.3.tar.gz (840 kB)
  Preparing metadata (setup.py) ... done
Collecting requests-toolbelt<1.0.0 (from allensdk==2.13.4)
  Using cached requests_toolbelt-0.10.1-py2.py3-none-any.whl.metadata (14 kB)
Collecting scikit-image<0.17.0,>=0.14.0 (from allensdk==2.13.4)
  Using cached scikit-image-0.16.2.tar.gz (28.9 MB)
  error: subprocess-exited-with-error
  
  × python set

In [None]:
from allensdk.brain_observatory.ecephys.ecephys_project_cache import EcephysProjectCache

import os
import tqdm
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:

# Local directory for cached Allen data, relative to the notebook's working directory.
data_dir = './data'
manifest_path = os.path.join(data_dir, 'manifest.json')

# Project cache configured through the manifest file under data_dir.
cache = EcephysProjectCache.from_warehouse(manifest=manifest_path)

In [None]:
# Fetch the session table (downloads data - took about 5 min here).
sessions = cache.get_session_table()

In [None]:
# number of sessions (rows) in the session table
sessions.shape[0]

58

In [None]:
# metadata columns available for each session (DataFrame.keys() is the column index)
sessions.keys()

Index(['published_at', 'specimen_id', 'session_type', 'age_in_days', 'sex',
       'full_genotype', 'unit_count', 'channel_count', 'probe_count',
       'ecephys_structure_acronyms'],
      dtype='object')

In [None]:


i = 4  # row position, within the session table, of the session to load

# Each quality-metric threshold is set to +/-inf, effectively disabling these
# unit filters here; units are filtered explicitly further below instead.
session = cache.get_session_data(
    sessions.index.values[i],
    isi_violations_maximum=np.inf,
    amplitude_cutoff_maximum=np.inf,
    presence_ratio_minimum=-np.inf,
    timeout=None,
)



Downloading:   0%|          | 0.00/2.18G [00:00<?, ?B/s]

core - cached version: 2.2.2, loaded version: 2.7.0
  self.warn_for_ignored_namespaces(ignored_namespaces)


In [None]:
# IPython help for the session object (interactive only; not plain Python - remove before sharing).
session?

In [None]:
# session.metadata  # takes too long to run every time

In [None]:
# table of the units recorded in this session (counted and filtered in the cells below)
units = session.units

In [None]:
# report how many units were loaded, then preview the first rows of the table
print(f'Total number of units:{len(units)}')
units.head()

In [None]:
# keep units in primary visual cortex (VISp) with an ISI-violation score below 0.1
V1_units = units.loc[
    (units['ecephys_structure_acronym'] == 'VISp')
    & (units['isi_violations'] < 0.1)
]

print(f'Total number of low contamination units:{len(V1_units)}')

In [None]:
# ids of the selected V1 units
V1_units.index.to_numpy()

In [None]:
# stimulus names available in this session (e.g. 'flashes', used below)
session.stimulus_names

In [None]:
# presentations of the 'flashes' stimulus; their start_time is the alignment time used below
stim_table = session.get_stimulus_table(['flashes'])
stim_table

In [None]:
def get_bins(stim_time, bin_w = 10, start_adj = 500, end_adj = 2000):
    """Return bin edges (in ms) for a window around a stimulus onset.

    Parameters
    ----------
    stim_time : float
        Stimulus onset time, in seconds (converted to ms here).
    bin_w : int
        Spacing between consecutive edges, in ms.
    start_adj : int
        How many ms before onset the window starts.
    end_adj : int
        How many ms after onset the window ends (exclusive).

    Returns
    -------
    numpy.ndarray
        Edges ``start, start + bin_w, ...`` strictly below ``end``, where
        ``start = int(stim_time * 1000) - start_adj`` and
        ``end = int(stim_time * 1000) + end_adj``.
    """
    # int() truncates, so the window is anchored on a whole millisecond
    stim_time_ms = int(stim_time * 1000)
    # use the parameters (previously hard-coded as 500 / 2000 and silently ignored)
    start_ms = stim_time_ms - start_adj
    end_ms = stim_time_ms + end_adj

    return np.arange(start_ms, end_ms, bin_w)

In [None]:
firing_rates_all_stim = []
# For each flash onset, bin the spikes of every V1 unit from 0.5 s before to 2 s after.
for stim in tqdm.tqdm(stim_table.start_time.to_numpy()):
    bins = get_bins(stim)  # bin edges in ms around this onset

    # Spike counts around this stimulus, one row per unit. np.digitize yields
    # indices 0..len(bins), so len(bins) + 2 columns has room for all of them;
    # the last column is never filled and column 0 (spikes before the first
    # edge) should stay empty because of the window filter below.
    n_cols = bins.shape[0] + 2
    firing_rates = np.zeros((len(V1_units.index.values), n_cols))

    for i, unit in enumerate(V1_units.index.values):
        all_spikes = session.spike_times[unit]
        in_window = (all_spikes >= stim - 0.5) & (all_spikes < stim + 2)
        stim_spikes = all_spikes[in_window] * 1000  # same units as the ms bin edges
        rate = np.digitize(stim_spikes, bins)
        # count how many spikes fell into each bin index
        firing_rates[i] = np.bincount(rate, minlength=n_cols)
    firing_rates_all_stim.append(firing_rates)

firing_rates_all_stim = np.array(firing_rates_all_stim)

In [None]:
# trial-averaged spike counts (units x time bins)
plt.imshow(firing_rates_all_stim.mean(axis=0))

In [None]:
# trial 0 as a (units x time bins) image, then the unit with the highest mean count on it
plt.imshow(firing_rates_all_stim[0])
mean = firing_rates_all_stim[0].mean(axis=1)
mean.argmax()

In [None]:
# (trials, units, time bins)
np.shape(firing_rates_all_stim)

In [None]:
# per-unit mean spike count per trial, divided by the 2.5 s window length -> spikes per second
plt.hist(firing_rates_all_stim.sum(axis=2).mean(axis=0) / 2.5)

In [None]:
# One example trial (trial 0): spike counts per unit and time bin.
# Columns are bin indices (10 ms bins from get_bins), not milliseconds.
plt.figure(figsize = (10, 6))
plt.imshow(firing_rates_all_stim[0, :, :], cmap='gray_r', vmax = 3, vmin=0, aspect='auto')
plt.xlabel('Time bin (10 ms)')
plt.ylabel('Cell #')
plt.colorbar(orientation='vertical', label='# Spikes in 0.01 s time bin')
plt.title('Example trial')
plt.show()

In [None]:
# run on the GPU when one is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
# we will pretend like every "bin" is 10ms, so the trial length is 2500ms
NT = 100

# let's use 10 latent components
ncomp = 36

x = torch.from_numpy(firing_rates_all_stim)
x = x.permute(2, 0, 1)

In [None]:
# first trial, shown as units (rows) x time bins (columns)
plt.imshow(x[:, 0].T)

In [None]:
# (time bins, trials, units)
x.size()

In [None]:
class Net(nn.Module):
  """Multi-layer RNN that predicts one neural population from another.

  The RNN reads the input population (``NN1`` features per time step). Its hidden
  state of size ``ncomp`` is used as the latent trajectory, and a linear read-out
  followed by a softplus maps the latents to positive predicted rates for the
  output population (``NN2`` neurons).

  Inputs follow ``nn.RNN``'s default sequence-first layout (time, batch, features),
  since ``batch_first`` is not set.
  """

  def __init__(self, ncomp, NN1, NN2, bidi=True):
    super(Net, self).__init__()

    # play with some of the options in the RNN!
    self.rnn = nn.RNN(NN1, ncomp, num_layers = 10, dropout = 0,
                      bidirectional = bidi, nonlinearity = 'tanh')
    self.fc = nn.Linear(ncomp, NN2)

  def forward(self, x):
    """Return ``(z, q)``: predicted rates (last axis NN2) and latents (last axis ncomp)."""
    y = self.rnn(x)[0]

    # Use the RNN's own hidden size instead of the notebook-global ``ncomp``,
    # so the model does not depend on whatever that global currently holds.
    hidden = self.rnn.hidden_size
    if self.rnn.bidirectional:
      # a bidirectional RNN concatenates forward and backward activations on the last
      # axis; average the two halves so the latents must agree between both passes
      q = (y[:, :, :hidden] + y[:, :, hidden:]) / 2
    else:
      q = y

    # softplus (beta=10) is a smoothed ReLU with strictly positive output, so we do
    # not predict exactly 0: a 0 rate paired with a spike is an instant Inf in the
    # Poisson log-likelihood
    z = F.softplus(self.fc(q), 10)

    return z, q

In [None]:
NN = x.shape[-1]

# we separate the neuron data into two populations: x1 (second half of the units)
# is the network input and x0 (first half) is the prediction target
x0 = x[:, :, :NN//2].to(device).float()
x1 = x[:, :, NN//2:].to(device).float()

NN1 = x1.shape[-1]  # number of input neurons
NN2 = x0.shape[-1]  # number of output neurons

# fix the seed so weight initialisation (and therefore training) is reproducible
torch.manual_seed(0)

# we initialize the neural network
net = Net(ncomp, NN1, NN2, bidi = True).to(device)

# special thing: initialize the biases of the last layer so that the initial
# predictions are close to each neuron's mean spike count (the latents contribute
# little at first). The read-out is passed through softplus with beta=10 in
# Net.forward, so the bias has to be the *inverse* softplus of the mean: using the
# mean directly gives softplus(mean), which is much larger than the mean for small
# per-bin counts. The mean is floored so the log stays finite for silent neurons.
SOFTPLUS_BETA = 10  # must match the beta used in Net.forward
with torch.no_grad():
    mean_counts = x0.mean((0, 1)).clamp(min=1e-4)
    net.fc.bias.copy_(torch.log(torch.expm1(SOFTPLUS_BETA * mean_counts)) / SOFTPLUS_BETA)

# we set up the optimizer. Adjust the learning rate if the training is slow or if it explodes.
optimizer = torch.optim.Adam(net.parameters(), lr=.001)

In [None]:
import scipy.stats

def rate_correlation(z, x0):
    """Mean Pearson correlation between predicted and observed activity.

    For every (trial, neuron) pair, the Pearson correlation over the time axis is
    computed. Pairs where either series is constant over time (correlation
    undefined) are skipped. The remaining correlations are averaged over trials
    for each neuron, and then over the neurons that have at least one valid trial.

    Parameters
    ----------
    z, x0 : torch.Tensor
        Predicted and observed activity with the same shape (time, trials, neurons);
        both are detached and copied to the CPU.

    Returns
    -------
    float
        The average correlation, or NaN if no (trial, neuron) pair is valid.
    """
    # float64 avoids precision loss when the network output is float32
    z_np = z.detach().cpu().numpy().astype(np.float64)
    x_np = x0.detach().cpu().numpy().astype(np.float64)

    # Vectorised over all (trial, neuron) pairs at once: one scipy.stats.pearsonr
    # call per pair is far too slow to run on every training iteration.
    z_c = z_np - z_np.mean(axis=0)
    x_c = x_np - x_np.mean(axis=0)
    num = (z_c * x_c).sum(axis=0)
    den = np.sqrt((z_c ** 2).sum(axis=0) * (x_c ** 2).sum(axis=0))

    # Correlation is undefined when either series is constant over time (e.g. a
    # trial in which the neuron never fires); NaN inputs also fail ``den > 0``.
    constant = (z_np == z_np[:1]).all(axis=0) | (x_np == x_np[:1]).all(axis=0)
    valid = ~constant & (den > 0)
    corr = np.zeros_like(num)
    np.divide(num, den, out=corr, where=valid)
    corr = np.clip(corr, -1.0, 1.0)

    # average per neuron over its valid trials, then over neurons with >= 1 valid trial
    n_valid = valid.sum(axis=0)
    has_valid = n_valid > 0
    per_neuron = corr.sum(axis=0)[has_valid] / n_valid[has_valid]
    if per_neuron.size == 0:
        return float('nan')
    return float(per_neuron.mean())

In [None]:
def binary_accuracy(z, x_spikes, threshold=0.5):
    """Fraction of elements where ``z > threshold`` agrees with ``x_spikes > 0``.

    Both tensors are binarised element-wise and compared; the mean agreement over
    all elements is returned as a Python float.
    """
    predicted_fires = z > threshold
    actually_fired = x_spikes > 0
    agreement = predicted_fires.eq(actually_fired)
    return agreement.float().mean().item()

In [None]:
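# You can re-run this cell to keep training if the cost still seems to be decreasing; the network and optimizer keep their state between runs, but the loss/accuracy/correlation histories are reset.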

# we define the Poisson negative log-likelihood loss
def Poisson_loss(lam, spk):
  """Element-wise Poisson negative log-likelihood, without the constant log(spk!) term.

  ``lam`` is floored at 1e-3 first, so the log stays finite for zero predictions.
  NOTE: the gradient w.r.t. ``lam`` is zero wherever ``lam`` < 1e-3 (clamped).
  """
  safe_rate = lam.clamp(min=1e-3)
  log_rate = (safe_rate + 1e-8).log()
  return safe_rate - log_rate * spk
# Per-iteration training history. These lists are recreated every time this cell
# runs, while net and optimizer keep their state across runs.
loss_history = []
acc_history = []
corr_history = []
niter = 6000
for k in range(niter):
  # the network outputs the single-neuron prediction and the latents
  z, q = net(x1)

  # Poisson negative log-likelihood, averaged over all time bins, trials and neurons;
  # accuracy and correlation are monitoring metrics on the same forward pass
  cost = Poisson_loss(z, x0).mean()
  acc = binary_accuracy(z, x0)
  corr = rate_correlation(z, x0)

  # backprop, update the parameters, then clear the gradients for the next iteration
  cost.backward()
  optimizer.step()
  optimizer.zero_grad()

  loss_history.append(cost.item())
  acc_history.append(acc)
  corr_history.append(corr)

  # NOTE: after the loop, z and q come from the last forward pass, i.e. from the
  # parameters as they were before the final optimizer.step()
  if k % 500 == 0:
    print(f'iteration {k}, cost {cost.item():.4f}')
    print(f"Accuracy: {acc:.3f} | Rate Corr: {corr:.3f}")


In [None]:
# training curve: mean Poisson loss per iteration
fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(loss_history)
ax.set(xlabel="Iteration", ylabel="Loss", title="Training Loss Over Time")
ax.grid(True)
plt.show()

In [None]:
# move predictions and targets to NumPy for plotting
rpred = z.detach().cpu().numpy()
rates = x0.detach().cpu().numpy()
nn_idx = 18

# Trial 0 of one output neuron: observed counts, predicted rate, and the raw counts
# from x (same neuron, since x0 is the first half of x) scaled by 1/4 and shifted down.
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(rates[:, 0, nn_idx], label='rates (true)')
ax.plot(rpred[:, 0, nn_idx], label='rates (predicted)')
ax.plot(x[:, 0, nn_idx].cpu().numpy() / 4 - 0.5, label='spikes')
ax.legend()
ax.set_title(f'Neuron {nn_idx}')
plt.show()

In [None]:
# Observed counts vs. model predictions on one trial, showing the same cells in both
# panels so they can be compared. Each column is one RNN time step (one 10 ms bin).
trial_idx = 12

plt.figure(figsize = (12, 8))
plt.subplot(121)
plt.imshow(rates[:, trial_idx, :].T, cmap='gray_r', aspect='auto')
plt.xlabel('Time bin (10 ms)')
plt.ylabel('Cell #')
plt.title(f'True rates (trial {trial_idx})')

plt.subplot(122)
plt.imshow(rpred[:, trial_idx, :].T, cmap='gray_r', aspect='auto')
plt.xlabel('Time bin (10 ms)')
plt.ylabel('Cell #')
plt.title(f'Inferred rates (trial {trial_idx})')
plt.show()

In [None]:
# latents as (time, trials, ncomp) NumPy array
qcpu = q.detach().cpu().numpy()

fig, (ax_trial, ax_latent) = plt.subplots(1, 2, figsize=(20, 4))
ax_trial.plot(qcpu[:, 0, :])
ax_trial.set_title('All latents on trial 0')

ax_latent.plot(qcpu[:, :, 0])
ax_latent.set_title('All trials for latent 0')
plt.show()

In [None]:
import pickle

In [None]:
# Pickle the whole trained model object. NOTE(review): unpickling requires the Net
# class to be importable under the same name; torch.save(net.state_dict(), ...)
# would give a more portable checkpoint.
with open('fname.pkl', 'wb') as f:
    pickle.dump(net, f)