In [1]:
pip install allensdk



In [2]:
import os
import shutil
import numpy as np
import pandas as pd
from pathlib import Path

from allensdk.brain_observatory.ecephys.ecephys_project_cache import EcephysProjectCache

In [3]:
import matplotlib.pyplot as plt
from scipy.linalg import svd
import time
from scipy.ndimage import maximum_filter, minimum_filter
from scipy.ndimage import maximum_filter1d, minimum_filter1d

In [4]:
pip install --upgrade allensdk



## Load data

In [5]:
output_dir = '/content/drive/MyDrive/ecephys_cache_dir'
os.makedirs(output_dir, exist_ok=True)
DOWNLOAD_COMPLETE_DATASET = True
manifest_path = os.path.join(output_dir, "manifest.json")

In [6]:
cache = EcephysProjectCache.from_warehouse(manifest=manifest_path)
print(cache)

<allensdk.brain_observatory.ecephys.ecephys_project_cache.EcephysProjectCache object at 0x7bfd8c587b10>


In [7]:
sessions = cache.get_session_table()
print('Total number of sessions: ' + str(len(sessions)))
sessions.head()

Total number of sessions: 58


Unnamed: 0_level_0,published_at,specimen_id,session_type,age_in_days,sex,full_genotype,unit_count,channel_count,probe_count,ecephys_structure_acronyms
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
715093703,2019-10-03T00:00:00Z,699733581,brain_observatory_1.1,118.0,M,Sst-IRES-Cre/wt;Ai32(RCL-ChR2(H134R)_EYFP)/wt,884,2219,6,"[CA1, VISrl, nan, PO, LP, LGd, CA3, DG, VISl, ..."
719161530,2019-10-03T00:00:00Z,703279284,brain_observatory_1.1,122.0,M,Sst-IRES-Cre/wt;Ai32(RCL-ChR2(H134R)_EYFP)/wt,755,2214,6,"[TH, Eth, APN, POL, LP, DG, CA1, VISpm, nan, N..."
721123822,2019-10-03T00:00:00Z,707296982,brain_observatory_1.1,125.0,M,Pvalb-IRES-Cre/wt;Ai32(RCL-ChR2(H134R)_EYFP)/wt,444,2229,6,"[MB, SCig, PPT, NOT, DG, CA1, VISam, nan, LP, ..."
732592105,2019-10-03T00:00:00Z,717038288,brain_observatory_1.1,100.0,M,wt/wt,824,1847,5,"[grey, VISpm, nan, VISp, VISl, VISal, VISrl]"
737581020,2019-10-03T00:00:00Z,718643567,brain_observatory_1.1,108.0,M,wt/wt,568,2218,6,"[grey, VISmma, nan, VISpm, VISp, VISl, VISrl]"


## Choose specific session

In [8]:
session = cache.get_session_data(715093703)

  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."


## Extract LFPs for the first probe of the session

In [9]:
probe_id = session.probes.index.values[0]
lfp = session.get_lfp(probe_id)

  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."


## Load spike timings for all detected neurons

In [10]:
spike_times = session.spike_times

  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."


In [11]:
lfp = lfp.T
lfp_array = lfp.values
lfp_time = lfp.coords["time"].values
lfp_channel_ids = lfp.coords["channel"].values

## Keep only neurons that have been detected using the channels of the probe of interest

In [12]:
units = cache.get_units()
valid_units = units[units['ecephys_channel_id'].isin(lfp_channel_ids)]

## Solving the PCP problem using Proximal Gradient

In [13]:
#Proximal operator for the nuclear norm
def prox_nuclear(X, gamma):
    U, S, Vt = svd(X, full_matrices=False)
    S = np.diag(np.maximum(S - gamma, 0))
    return U @ S @ Vt

#Proximal operator for the l1 norm
def prox_l1(X, lambda_):
    return np.sign(X) * np.maximum(np.abs(X) - lambda_, 0)

def PCP_by_PG(X, lambda_, gamma, lr=1.0, max_iter=100, tol=1e-6):
    L = np.zeros_like(X)
    S = np.zeros_like(X)
    objective_values = []
    residual_values = []

    for iteration in range(max_iter):
        G = L + S - X

        # Update L
        L_new = prox_nuclear(L - lr * G, lr * gamma)
        # Update S
        S_new = prox_l1(S - lr * G, lr * lambda_)

        # Compute residual and objective value
        residual = X - L_new - S_new
        objective_value = 0.5 * np.linalg.norm(residual, 'fro')**2 + gamma * np.linalg.norm(L_new, 'nuc') + lambda_ * np.sum(np.abs(S_new))
        residual_values.append(np.linalg.norm(residual, 'fro'))
        objective_values.append(objective_value)

        # Check convergence
        if np.linalg.norm(L_new - L, 'fro') < tol and np.linalg.norm(S_new - S, 'fro') < tol:
            print(f'Converged in {iteration + 1} iterations')
            break

        L, S = L_new, S_new

        # if (iteration + 1) % 10 == 0:
        #     print(f'Iter {iteration+1}: Obj={objective_value:.4e}, Residual={np.linalg.norm(residual, "fro"):.4e}')

    # Plotting
    # plt.figure(figsize=(12, 5))
    # plt.subplot(1, 2, 1)
    # plt.plot(objective_values, label='Objective Value')
    # plt.xlabel('Iteration')
    # plt.ylabel('Objective Value')
    # plt.title('Objective Function vs Iteration')
    # plt.grid(True)
    # plt.legend()

    # plt.subplot(1, 2, 2)
    # plt.plot(residual_values, label='Residual Norm')
    # plt.xlabel('Iteration')
    # plt.ylabel('Frobenius Norm of Residual')
    # plt.title('Reconstruction Error vs Iteration')
    # plt.grid(True)
    # plt.legend()

    # plt.tight_layout()
    # plt.show()

    # print(f'Iter {iteration+1}: Obj={objective_value:.4e}, Residual={np.linalg.norm(residual, "fro"):.4e}')

    return L, S, objective_values, residual_values


## Solving the PCP problem using ADMM

In [14]:
def PCP_by_ADMM(X, lambda_, mu, max_iter=100, tol=1e-6):
    L = np.zeros_like(X)
    S = np.zeros_like(X)
    Y = np.zeros_like(X)
    objective_values = []
    reconstruction_errors = []

    for iteration in range(max_iter):
        # Update L
        L = prox_nuclear(X - S + (1/mu)*Y, 1/mu)

        # Update S
        S = prox_l1(X - L + (1/mu)*Y, lambda_/mu)

        # Update dual variable
        residual = X - L - S
        Y = Y + mu * residual

        # Objective function
        objective_value = np.linalg.norm(L, 'nuc') + lambda_ * np.linalg.norm(S, 1)
        objective_values.append(objective_value)

        reconstruction_error = np.linalg.norm(residual, 'fro')
        reconstruction_errors.append(reconstruction_error)

        # Check convergence
        if np.linalg.norm(residual, 'fro') < tol:
            #print(f'ADMM converged in {iteration + 1} iterations')
            break

        # if (iteration+1) % 10 == 0:
        #   print(f'Iteration {iteration + 1}, Objective Value: {objective_value}')

    # plt.figure(figsize=(12, 5))
    # plt.subplot(1, 2, 1)
    # plt.plot(objective_values, label='Objective Value')
    # plt.xlabel('Iteration')
    # plt.ylabel('Objective Value')
    # plt.title('Objective Function vs Iteration')
    # plt.grid(True)
    # plt.legend()

    # plt.subplot(1, 2, 2)
    # plt.plot(reconstruction_errors, label='Residual Norm')
    # plt.xlabel('Iteration')
    # plt.ylabel('Frobenius Norm of Residual')
    # plt.title('Reconstruction Error vs Iteration')
    # plt.grid(True)
    # plt.legend()

    # print(f'Iter {iteration+1}: Obj={objective_value:.4e}, Residual={np.linalg.norm(residual, "fro"):.4e}')

    return L, S, objective_values, reconstruction_errors


## Solving the PCP problem using GoDec

In [15]:
# Keep only the top-k largest (in absolute value) elements of S
def hard_threshold(S, k):
    flat = np.abs(S).flatten()
    if k >= flat.size:
        return S
    threshold = np.partition(flat, -k)[-k]
    return np.where(np.abs(S) >= threshold, S, 0)

def PCP_by_GoDec(X, rank_r, k, max_iter=100, tol=1e-6):
    L = np.zeros_like(X)
    S = np.zeros_like(X)
    objective_values = []

    for i in range(max_iter):
        # Low-rank approximation via truncated SVD
        Y = X - S
        U, s, Vt = svd(Y, full_matrices=False)
        s[rank_r:] = 0
        L_new = (U * s) @ Vt

        # Sparse approximation via hard thresholding
        R = X - L_new
        S_new = hard_threshold(R, k)

        # Objective (Frobenius norm of the residual)
        residual = X - L_new - S_new
        obj = np.linalg.norm(residual, 'fro')
        objective_values.append(obj)

        if np.linalg.norm(L_new - L, 'fro') < tol and np.linalg.norm(S_new - S, 'fro') < tol:
            #print(f"Converged in {i+1} iterations.")
            break

        L, S = L_new, S_new

    #     if (i + 1) % 10 == 0:
    #         print(f"Iteration {i+1}, Residual Fro Norm: {obj:.4e}")

    # plt.plot(objective_values)
    # plt.xlabel("Iteration")
    # plt.ylabel("Residual Norm")
    # plt.title("GoDec Objective vs Iteration")
    # plt.grid(True)
    # plt.show()

    # print(f"Iteration {i+1}, Residual Fro Norm: {obj:.4e}")

    return L, S, objective_values


In [16]:
# Identify local maxima (positive or negative) in a 2D array and return a binary mask.
def local_maxima_binary(S):

    binary_S = np.zeros_like(S, dtype=int)

    # Loop through each row to detect local maxima
    for i in range(S.shape[0]):
        row = S[i, :]
        size = 10

        local_max_pos = (row == maximum_filter1d(row, size=size, mode='constant', cval=0)) & (row > 0)
        local_max_neg = (row == minimum_filter1d(row, size=size, mode='constant', cval=0)) & (row < 0)
        binary_S[i, :] = (local_max_pos | local_max_neg).astype(int)

    return binary_S


In [17]:
# Get the timing of the predicted spikes
def get_predicted_spikes(S_thresholded, lim_inf, lim_sup):
  S = np.array(S_thresholded)
  sub_lfp_time = lfp_time[lim_inf:lim_sup]

  neuron_idx, time_idx = np.where(S == 1)
  spike_times = sub_lfp_time[time_idx]
  spike_times_detected = sorted(set(spike_times))

  return spike_times_detected

# Get the timing of the ground truth spikes from Kilosort
def get_gt_spikes(valid_units, spike_times, lim_inf, lim_sup):
  gt_spikes = []
  sub_lfp_time = lfp_time[lim_inf:lim_sup]

  for neuron_id in valid_units.index:
      gt_spikes.append(spike_times[neuron_id])

  all_spikes = [spike_time for neuron_spikes in gt_spikes for spike_time in neuron_spikes if spike_time < sub_lfp_time[-1]]
  spike_times_groundtruth = sorted(set(all_spikes))

  return spike_times_groundtruth

# Supervised analysis of spikes detection
def get_metrics(spike_times_detected, spike_times_groundtruth):

  gt_spikes = np.array(spike_times_groundtruth)
  det_spikes = np.array(spike_times_detected)

  gt_spikes.sort()
  det_spikes.sort()

  tolerance = 0.001
  gt_matched = np.zeros(len(gt_spikes), dtype=bool)
  tp = 0

  for d in det_spikes:
      idx = np.searchsorted(gt_spikes, d - tolerance, side='left')
      while idx < len(gt_spikes) and gt_spikes[idx] <= d + tolerance:
          if not gt_matched[idx]:
              gt_matched[idx] = True
              tp += 1
              break
          idx += 1

  fp = len(det_spikes) - tp
  fn = len(gt_spikes) - tp

  precision = tp / (tp + fp) if (tp + fp) > 0 else 0
  recall = tp / (tp + fn) if (tp + fn) > 0 else 0
  f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

  return tp, fp, fn, precision, recall, f1

# Estimate rank to preserve a specified amount of energy (variance)
def estimate_rank(X, energy_threshold=0.9):
    _, S, _ = svd(X, full_matrices=False)
    cumulative_energy = np.cumsum(S**2) / np.sum(S**2)
    rank_r = np.searchsorted(cumulative_energy, energy_threshold) + 1
    return rank_r

# Unsupervised analysis
def reco(L, S, X):
  reconstruction = L + S
  reconstruction_error = np.linalg.norm(X - reconstruction, 'fro') / np.linalg.norm(X, 'fro')

  energy_ratio = np.linalg.norm(S, 'fro')**2 / np.linalg.norm(X, 'fro')**2

  return reconstruction_error, energy_ratio

## parameter combination testing

In [27]:
def tune_param(method, lim_inf, lim_sup, lfp_array, max_iter):

  results = []
  spike_times_groundtruth = get_gt_spikes(valid_units, spike_times, lim_inf, lim_sup)

  X = lfp_array[:, lim_inf:lim_sup]
  X = (X - np.mean(X)) / np.std(X)

  if method == 'PCP_by_PG':
    lambda_ = 1.0 / np.sqrt(max(X.shape))
    param1_list = [lambda_*0.8, lambda_, lambda_*1.2]
    gamma = X.size / (4 * np.sum(np.abs(X)))
    param2_list = [gamma*0.8, gamma, gamma*1.2]

  elif method == 'PCP_by_ADMM':
    lambda_ = 1.0 / np.sqrt(max(X.shape))
    param1_list = [lambda_, lambda_*2, lambda_*3]
    mu = np.linalg.norm(X, 2) / 100
    param2_list = [mu*0.1, mu, mu*10]

  elif method == 'PCP_by_GoDec':
    param1_list = [estimate_rank(X, energy_threshold=0.85), estimate_rank(X, energy_threshold=0.95), estimate_rank(X, energy_threshold=0.99)]
    param2_list = [int(0.005 * X.size), int(0.01 * X.size), int(0.05 * X.size)]



  for param1 in param1_list:
    for param2 in param2_list:

      if method == 'PCP_by_PG':
        print(f'lambda: {param1}, gamma: {param2}')
        start_time = time.time()
        L, S, obj, residual_values = PCP_by_PG(X, param1, param2, max_iter=max_iter)
        elapsed_time = time.time() - start_time

      elif method == 'PCP_by_ADMM':
        print(f'lambda: {param1}, mu: {param2}')
        start_time = time.time()
        L, S, obj, residual_values = PCP_by_ADMM(X, param1, param2, max_iter=max_iter)
        elapsed_time = time.time() - start_time

      elif method == 'PCP_by_GoDec':
        print(f'rank_r: {param1}, k: {param2}')
        start_time = time.time()
        L, S, obj = PCP_by_GoDec(X, param1, param2, max_iter=max_iter)
        elapsed_time = time.time() - start_time


      binary_S = local_maxima_binary(S)

      spike_times_detected = get_predicted_spikes(binary_S, lim_inf, lim_sup)
      tp, fp, fn, precision, recall, f1 = get_metrics(spike_times_detected, spike_times_groundtruth)
      reconstruction_error, energy_ratio = reco(L, S, X)

      rank = np.sum(np.linalg.svd(L, compute_uv=False) > 1e-3)
      sparsity = np.count_nonzero(S) / S.size

      if method == 'PCP_by_PG':
        results.append([lim_inf, lim_sup, method, param1, param2, tp, fp, fn, precision, recall, f1, elapsed_time, len(obj), rank, sparsity, reconstruction_error, energy_ratio, obj, residual_values])
      elif method == 'PCP_by_ADMM':
        results.append([lim_inf, lim_sup, method, param1, param2, tp, fp, fn, precision, recall, f1, elapsed_time, len(obj), rank, sparsity, reconstruction_error, energy_ratio, obj, residual_values])
      elif method == 'PCP_by_GoDec':
        results.append([lim_inf, lim_sup, method, param1, param2, tp, fp, fn, precision, recall, f1, elapsed_time, len(obj), rank, sparsity, reconstruction_error, energy_ratio, obj, obj])


  return results

## Testing over 5 submatrices of 1000 time points

In [31]:
max_iter = 500

ADMM_results = []
for i in range(5):
  lim_inf = i*1000
  lim_sup = (i+1)*1000
  print(f'lim_inf: {lim_inf}, lim_sup: {lim_sup}')
  result = tune_param('PCP_by_ADMM', lim_inf, lim_sup, lfp_array, max_iter)
  ADMM_results.extend(result)

lim_inf: 0, lim_sup: 1000
lambda: 0.03162277660168379, mu: 0.26552456665039065
lambda: 0.03162277660168379, mu: 2.655245666503906
lambda: 0.03162277660168379, mu: 26.55245666503906
lambda: 0.06324555320336758, mu: 0.26552456665039065
lambda: 0.06324555320336758, mu: 2.655245666503906
lambda: 0.06324555320336758, mu: 26.55245666503906
lambda: 0.09486832980505137, mu: 0.26552456665039065
lambda: 0.09486832980505137, mu: 2.655245666503906
lambda: 0.09486832980505137, mu: 26.55245666503906
lim_inf: 1000, lim_sup: 2000
lambda: 0.03162277660168379, mu: 0.27282229614257814
lambda: 0.03162277660168379, mu: 2.7282229614257814
lambda: 0.03162277660168379, mu: 27.282229614257815
lambda: 0.06324555320336758, mu: 0.27282229614257814
lambda: 0.06324555320336758, mu: 2.7282229614257814
lambda: 0.06324555320336758, mu: 27.282229614257815
lambda: 0.09486832980505137, mu: 0.27282229614257814
lambda: 0.09486832980505137, mu: 2.7282229614257814
lambda: 0.09486832980505137, mu: 27.282229614257815
lim_inf: 

In [33]:
ADMM_df = pd.DataFrame(ADMM_results)
ADMM_df.columns = ['lim_inf', 'lim_sup', 'method', 'param1', 'param2', 'tp', 'fp', 'fn', 'precision', 'recall', 'f1', 'elapsed_time', 'obj_len', 'rank', 'sparsity', 'reconstruction_error', 'energy_ratio', 'obj', 'residual']
print(ADMM_df)

ADMM_df.to_csv('/content/drive/MyDrive/low_d/ADMM_df.csv', index=False)

    lim_inf  lim_sup       method    param1     param2   tp   fp   fn  \
0         0     1000  PCP_by_ADMM  0.031623   0.265525  148  852    0   
1         0     1000  PCP_by_ADMM  0.031623   2.655246  148  852    0   
2         0     1000  PCP_by_ADMM  0.031623  26.552457  148  852    0   
3         0     1000  PCP_by_ADMM  0.063246   0.265525  148  831    0   
4         0     1000  PCP_by_ADMM  0.063246   2.655246  148  831    0   
5         0     1000  PCP_by_ADMM  0.063246  26.552457  148  829    0   
6         0     1000  PCP_by_ADMM  0.094868   0.265525   68  191   80   
7         0     1000  PCP_by_ADMM  0.094868   2.655246   68  191   80   
8         0     1000  PCP_by_ADMM  0.094868  26.552457   68  191   80   
9      1000     2000  PCP_by_ADMM  0.031623   0.272822  147  853  148   
10     1000     2000  PCP_by_ADMM  0.031623   2.728223  147  853  148   
11     1000     2000  PCP_by_ADMM  0.031623  27.282230  147  853  148   
12     1000     2000  PCP_by_ADMM  0.063246   0.272

## Testing over 5 submatrices of varying time points size

In [40]:
max_iter = 500
time_list = [100,1000,5000,10000]

ADMM_results_times = []
for time_val in time_list:
  lim_inf = 0
  lim_sup = time_val
  print(f'lim_inf: {lim_inf}, lim_sup: {lim_sup}')
  result = tune_param('PCP_by_ADMM', lim_inf, lim_sup, lfp_array, max_iter)
  ADMM_results_times.extend(result)

lim_inf: 0, lim_sup: 100
lambda: 0.1, mu: 0.08875743865966797
lambda: 0.1, mu: 0.8875743865966796
lambda: 0.1, mu: 8.875743865966797
lambda: 0.2, mu: 0.08875743865966797
lambda: 0.2, mu: 0.8875743865966796
lambda: 0.2, mu: 8.875743865966797
lambda: 0.30000000000000004, mu: 0.08875743865966797
lambda: 0.30000000000000004, mu: 0.8875743865966796
lambda: 0.30000000000000004, mu: 8.875743865966797
lim_inf: 0, lim_sup: 1000
lambda: 0.03162277660168379, mu: 0.26552456665039065
lambda: 0.03162277660168379, mu: 2.655245666503906
lambda: 0.03162277660168379, mu: 26.55245666503906
lambda: 0.06324555320336758, mu: 0.26552456665039065
lambda: 0.06324555320336758, mu: 2.655245666503906
lambda: 0.06324555320336758, mu: 26.55245666503906
lambda: 0.09486832980505137, mu: 0.26552456665039065
lambda: 0.09486832980505137, mu: 2.655245666503906
lambda: 0.09486832980505137, mu: 26.55245666503906
lim_inf: 0, lim_sup: 5000
lambda: 0.01414213562373095, mu: 0.6095609130859376
lambda: 0.01414213562373095, mu: 6

In [42]:
ADMM_time_df = pd.DataFrame(ADMM_results_times)
ADMM_time_df.columns = ['lim_inf', 'lim_sup', 'method', 'param1', 'param2', 'tp', 'fp', 'fn', 'precision', 'recall', 'f1', 'elapsed_time', 'obj_len', 'rank', 'sparsity', 'reconstruction_error', 'energy_ratio', 'obj', 'residual']
print(ADMM_time_df)

ADMM_time_df.to_csv('/content/drive/MyDrive/low_d/ADMM_time_df.csv', index=False)

    lim_inf  lim_sup       method    param1     param2    tp    fp   fn  \
0         0      100  PCP_by_ADMM  0.100000   0.088757    16    84    0   
1         0      100  PCP_by_ADMM  0.100000   0.887574    16    84    0   
2         0      100  PCP_by_ADMM  0.100000   8.875744    16    84    0   
3         0      100  PCP_by_ADMM  0.200000   0.088757    16    82    0   
4         0      100  PCP_by_ADMM  0.200000   0.887574    16    82    0   
5         0      100  PCP_by_ADMM  0.200000   8.875744    16    82    0   
6         0      100  PCP_by_ADMM  0.300000   0.088757     8    16    8   
7         0      100  PCP_by_ADMM  0.300000   0.887574    10    15    6   
8         0      100  PCP_by_ADMM  0.300000   8.875744    10    16    6   
9         0     1000  PCP_by_ADMM  0.031623   0.265525   148   852    0   
10        0     1000  PCP_by_ADMM  0.031623   2.655246   148   852    0   
11        0     1000  PCP_by_ADMM  0.031623  26.552457   148   852    0   
12        0     1000  PCP