Signals from Recordings (Spike Sorting)

This notebook implements basic spike sorting techniques on real data (from https://www2.le.ac.uk/centres/csn/software) and on simulated data. Problem 1 (50 points) uses real data from a single channel; its parts involve implementing (a) bandpass filtering, (b) thresholding, (c) PCA, and (d) K-means. Problem 2 (50 points) applies the steps from Problems 1(b-d) to simulated data from multiple channels. Problem 3 (optional, for a bonus of 50 points) uses the same data as Problem 2 and involves inference in a generative model. Here, rather than assuming that putative spikes have been identified and windows around them extracted, we'll model the multi-channel voltage time series directly using a convolutional matrix factorization model. We'll use PyTorch to implement the key operations (convolutions and cross-correlations) on a GPU.

## Instructions
Make a copy of this notebook (File $\rightarrow$ Save a Copy in Drive). Fill in your name above, and any team members you are working with.

Complete and run the code cells below to preprocess the data, implement the spike sorting models described in class, and produce some plots.

Problems 1-3 ask you to fill in a few lines of code, marked as follows:
```
###
# This block should do a thing.
# YOUR CODE BELOW
#

result = ...
#
###
```
It's ok if you split your answer into more than one line.
Don't change the output variable names or the subsequent code won't run!

Some problems ask you to discuss the results. Respond to these questions in text cells.



In [None]:
import scipy.signal as signal
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from tqdm.auto import trange
from scipy.ndimage import gaussian_filter1d
from scipy.optimize import linear_sum_assignment

# plotting stuff
from matplotlib.pyplot import get_cmap  # matplotlib.cm.get_cmap was removed in Matplotlib 3.9
from matplotlib.gridspec import GridSpec
import seaborn as sns

# Enable plots inside the Jupyter Notebook
%matplotlib inline

# Problem 1: Single Channel Electrophysiology Data

In [None]:
%%capture
# Download the data (the same recording as in the link above)
!wget -nc https://www.dropbox.com/s/vkjq2nqdyaq8ezf/CSC4.Ncs

In [None]:
#@title Function for reading data (run this cell)

data_file = "CSC4.Ncs"
# The Neuralynx header is 16 KiB long
header_size = 16 * 1024

# Neuralynx CSC record layout: timestamp, channel number, sampling frequency,
# number of valid samples, and a block of 512 int16 samples
data_format = np.dtype([('TimeStamp', np.uint64),
                        ('ChannelNumber', np.uint32),
                        ('SampleFreq', np.uint32),
                        ('NumValidSamples', np.uint32),
                        ('Samples', np.int16, 512)])

# Skip the header and read all records; the with-block closes the file automatically
with open(data_file, 'rb') as fid:
    fid.seek(header_size)
    raw = np.fromfile(fid, dtype=data_format)



In [None]:
#@title Plotting functions (run this cell)

def plot_fig(time_, data_, title_='', xlabel_='', ylabel_=''):
  fig, ax = plt.subplots(figsize=(12, 4))
  ax.plot(time_, data_, 'k')
  ax.set_title(title_, fontsize=22)
  ax.set_xlim(time_[0], time_[-1])
  ax.set_xlabel(xlabel_, fontsize=20)
  ax.set_ylabel(ylabel_, fontsize=20)
  plt.show()

In [None]:
## Visualizing the raw data
# Get sampling frequency
sample_freq = raw['SampleFreq'][0]

# Create data vector
data = raw['Samples'].ravel()

# Determine duration of recording in seconds
dur_sec = data.shape[0]/sample_freq
print("Data is ", dur_sec, " seconds long")

# Create time vector (one entry per sample, spaced 1/sample_freq apart)
time = np.arange(data.shape[0]) / sample_freq

# Plot first second of data
plot_fig(time[0:sample_freq], data[0:sample_freq], 'Raw Data', 'time [s]', 'amplitude [uV]')

## Problem 1(a): Bandpass filtering
First, we bandpass filter the signal from 500 Hz to 9 kHz to isolate the spiking content, using a 10th-order Butterworth filter.

Use `signal.butter` and `signal.sosfilt` to do this (note that `signal` is an alias for `scipy.signal`; see above). Try calling `help(signal.butter)` or Googling it for more information on the function signature and outputs. The `sample_freq` is specified above.
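As a sketch of how these two functions fit together, the toy example below designs a 10th-order Butterworth bandpass filter in second-order-sections form and applies it with `signal.sosfilt` to a synthetic signal. The sampling rate `fs = 32000` and the 20 Hz and 2 kHz test tones are assumptions chosen for illustration, not values taken from the recording.

```python
import numpy as np
import scipy.signal as signal

fs = 32000                 # assumed sampling rate for this toy example (Hz)
t = np.arange(fs) / fs     # one second of samples

# Toy signal: a slow 20 Hz component (outside the passband) plus a 2 kHz component (inside it)
slow = np.sin(2 * np.pi * 20 * t)
fast = 0.5 * np.sin(2 * np.pi * 2000 * t)
x = slow + fast

# Order-10 Butterworth bandpass between 500 Hz and 9 kHz, returned as second-order sections
sos_demo = signal.butter(10, [500, 9000], btype='bandpass', fs=fs, output='sos')
y = signal.sosfilt(sos_demo, x)

# Compare RMS values after the filter's start-up transient has died out
steady = slice(fs // 2, None)
rms = lambda v: np.sqrt(np.mean(v ** 2))
print(rms(y[steady]), rms(fast[steady]), rms(signal.sosfilt(sos_demo, slow)[steady]))
```

Passing `fs` lets you give the cutoffs in Hz rather than as fractions of the Nyquist frequency. Because a bandpass design doubles the order, `sos_demo` contains 10 second-order sections.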

In [None]:
# Construct a Butterworth bandpass filter.
order = 10
###
# Create a Butterworth filter with order 10 from 500Hz to 9kHz using the signal.butter() function
#
# YOUR CODE BELOW
#
# sos = signal.butter(...)
#
###


Plot the frequency response of the filter.
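To read specific gains off the response instead of eyeballing the plot, `signal.sosfreqz` also accepts an array of frequencies through `worN` (in Hz when `fs` is given). The sketch below rebuilds a filter with the same specification under an assumed sampling rate of 32 kHz; a Butterworth design places each cutoff frequency at about -3 dB.

```python
import numpy as np
import scipy.signal as signal

fs = 32000  # assumed sampling rate (Hz); the recording's own rate is `sample_freq`
sos_demo = signal.butter(10, [500, 9000], btype='bandpass', fs=fs, output='sos')

# Evaluate the response only at a few frequencies of interest (in Hz, because fs is passed)
probe_freqs = np.array([50.0, 500.0, 2000.0, 9000.0, 14000.0])
_, h = signal.sosfreqz(sos_demo, worN=probe_freqs, fs=fs)
gain_db = 20 * np.log10(np.abs(h))
for f, g in zip(probe_freqs, gain_db):
    print(f"{f:8.0f} Hz: {g:7.2f} dB")
```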

In [None]:
# This function evaluates the filter response
# at a grid of input frequencies from 0 to the Nyquist frequency
# (1/2 the sampling frequency). The response is given
# as a complex number for each input frequency, where the square
# of the magnitude is the power at that frequency.
freqs, response = signal.sosfreqz(sos, fs=sample_freq)

# convert the response to decibels and truncate lower end.
# (see, e.g., https://en.wikipedia.org/wiki/Decibel)
response_db = 20 * np.log10(np.maximum(np.abs(response), 1e-5))

# Plot the response.
plt.figure(figsize=(8, 6))
plt.plot(freqs, response_db, lw=2)
plt.ylim(-40, 5)
# mark the 500 Hz and 9 kHz cutoff frequencies
plt.vlines([500, 9000], *plt.ylim(), colors='r', ls=':')
plt.grid(True)
plt.yticks([0, -10, -20, -30, -40])
plt.xlabel('Frequency [Hz]')
plt.ylabel('Gain [dB]')
plt.title('Butterworth filter frequency response')

Apply the filter to the raw data using `signal.sosfilt`.

In [None]:
###
# Filter the data with the Butterworth filter above, using the signal.sosfilt() function
#
# YOUR CODE BELOW
#
# filtered_data = signal.sosfilt(...)
#
###

Plot the raw and bandpass filtered data.

In [None]:
plot_fig(time[0:10*sample_freq], data[0:10*sample_freq], 'Raw Data', 'time [s]', 'amplitude [uV]')
plot_fig(time[0:10*sample_freq], filtered_data[0:10*sample_freq], 'Bandpass Filtered Data', 'time [s]', 'amplitude [uV]')

## Problem 1(b): Extract spikes from the filtered signal
Now that we have a clean spike channel, we can identify and extract spikes. We first normalize the data; then, instead of simply thresholding to find the action potentials, we look for _peaks_ in the data. See the [scipy.signal.find_peaks()](https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.find_peaks.html) documentation to find out more.
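A minimal sketch of `find_peaks` on a synthetic z-scored trace, assuming a hypothetical 1 kHz sampling rate, three injected spike-like events, and one oversized artifact. Passing `height` as a `[min, max]` pair keeps only peaks inside that range, and `distance` (in samples) enforces a minimum gap between accepted peaks. Note that the second return value is a dictionary of peak properties (here containing `'peak_heights'`), not an array of heights.

```python
import numpy as np
import scipy.signal as signal

rng = np.random.default_rng(0)
fs = 1000                                # hypothetical sampling rate (Hz)
x = 0.1 * rng.standard_normal(2000)      # low-amplitude background noise
true_peaks = np.array([200, 700, 1300])
x[true_peaks] += 8.0                     # three spike-like events
x[1600] += 50.0                          # one artifact far outside the plausible range

z = (x - x.mean()) / x.std()             # z-score the trace
inds, props = signal.find_peaks(z, height=[3, 20], distance=0.005 * fs)
print(inds, props['peak_heights'].round(1))
```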

In [None]:
## Normalize the filtered data by computing the z-score before utilizing the find_peaks() function.

###
# Compute the Z-score of the 'filtered_data' (subtract the mean and divide by the standard deviation)
#
# YOUR CODE BELOW
#
# filtered_data_zscore = ...
#
###

In [None]:
# Use scipy.signal.find_peaks to find candidate spike times
distance_s = 0.005      # minimum time between spikes (in seconds)

###
# # Pick a number for the minimum threshold that makes sense given the data, i.e., try out several values and see which you think is right.
#
# YOUR CODE BELOW
# min_height = ...  # minimum standard deviation to define a peak, i.e., the peak has to be minimum this height
#
###
max_height = 20     # maximum standard deviation to define a peak, i.e., the peak can be maximum this height, to remove artifacts in the data

###
# # Find peaks in the z-scored, filtered data such that the minimum height of the peaks is min_height, and maximum is max_height.
# # Keep in mind the distance between two consecutive spikes is minimum 'distance_s' (converted to samples).
#
# YOUR CODE BELOW
# spike_inds, spike_heights = signal.find_peaks(...)

spike_inds, spike_heights = signal.find_peaks(filtered_data_zscore,  height=[min_height,max_height],  distance=distance_s * sample_freq)
#
###

# Define waveforms by taking a fixed number of samples around each peak, shifted by a small offset
spike_width = 91  # waveform length (in samples), i.e., the action potential length
offset = 10       # shift the extraction window later by this many samples relative to the peak

# Keep only spikes whose extraction window lies entirely inside the recording
starts = spike_inds - spike_width // 2 + offset
stops = spike_inds + spike_width // 2 + offset + 1
spike_inds = spike_inds[(starts >= 0) & (stops <= len(filtered_data_zscore))]

num_spikes = len(spike_inds)
waveforms = np.zeros((spike_width, num_spikes))
for i, ind in enumerate(spike_inds):
    window = slice(ind - spike_width // 2 + offset, ind + spike_width // 2 + offset + 1)
    waveforms[:, i] = filtered_data_zscore[window]

# Delete waveforms whose values exceed max_height anywhere in the window (artifacts)
inds_to_del = np.where(waveforms.max(axis=0) > max_height)[0]
spike_inds = np.delete(spike_inds, inds_to_del)
waveforms = np.delete(waveforms, inds_to_del, axis=1)

# plot
fig, ax = plt.subplots(figsize=(12, 4))
ax.plot(time[:10*sample_freq], filtered_data_zscore[:10*sample_freq], 'k')
ax.plot(spike_inds/sample_freq,10*np.ones(spike_inds.shape),'xr')
ax.set_title('Detected Spike Times', fontsize=22)
ax.set_xlim(time[0], time[10*sample_freq])
ax.set_xlabel('time [s]', fontsize=20)
ax.set_ylabel('amplitude', fontsize=20)
plt.show()


In [None]:
# Visualize the waveforms
plot_fig(1000*np.arange(spike_width)/sample_freq, waveforms[:,np.random.randint(0,waveforms.shape[1],30)], 'Waveforms', 'time [ms]', 'amplitude')

## Problem 1(c): Dimensionality Reduction

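We will now apply Principal Component Analysis (PCA) to all of the spike waveforms. Reducing each waveform to a few principal component scores makes the subsequent clustering step more efficient. Note that `waveforms` has shape `(spike_width, num_spikes)`, whereas scikit-learn's `PCA` expects one sample per row, so pass in the transpose `waveforms.T`.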
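The sketch below shows the fit, inspect, and refit pattern on synthetic data with an assumed low-rank structure (three latent directions plus small noise), stored in the same `(samples, spikes)` layout as `waveforms`. Names ending in `_demo` are hypothetical stand-ins for the notebook's variables.

```python
import numpy as np
from sklearn.decomposition import PCA

rng = np.random.default_rng(0)
n_spikes, n_samples = 500, 50
basis, _ = np.linalg.qr(rng.standard_normal((n_samples, 3)))   # (50, 3) orthonormal columns
latent = rng.standard_normal((n_spikes, 3)) * np.array([10.0, 5.0, 3.0])
noise = 0.1 * rng.standard_normal((n_spikes, n_samples))
waveforms_demo = (latent @ basis.T + noise).T                   # (n_samples, n_spikes), like `waveforms`

# Fit with many components, transposing so that each row is one waveform
pca_full = PCA(n_components=20).fit(waveforms_demo.T)
cumvar = np.cumsum(pca_full.explained_variance_ratio_)

# Smallest number of components whose cumulative explained variance reaches 90%
k = int(np.searchsorted(cumvar, 0.9) + 1)
pc_scores_demo = PCA(n_components=k).fit_transform(waveforms_demo.T)
print(k, cumvar[:5].round(3))
```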

In [None]:
###
# # Implement PCA on the waveforms, with a large number of components, in order to choose how many components we will need to explain the variance in our data
#
# YOUR CODE BELOW
# num_comps = ...
# pca = PCA(n_components=num_comps)
# pca.fit(...)
#
###

## plot the variance explained ratio, and pick the number of PCs that reach ~90% variance explained
plot_fig(np.arange(num_comps)+1,np.cumsum(pca.explained_variance_ratio_),'Variance Explained','Number of Components', 'Fraction of Explained Variance')

In [None]:
# Choose the number of components to keep, such that ~90% of the variance in the data is explained
###
# # Perform PCA with the chosen number of components
#
# YOUR CODE BELOW
# num_comps = ...
# pca = PCA(n_components=num_comps)
# pc_scores = pca.fit_transform(...)
#
###

# Plot the first 3 PCs, with the color denoting the 3rd PC.
# Remember, each data point here should be a waveform (check that this is true).
fig, ax = plt.subplots(figsize=(8, 6))
im = ax.scatter(pc_scores[:, 0], pc_scores[:, 1], c=pc_scores[:, 2])
cbar = fig.colorbar(im,ax=ax)
ax.set_xlabel('1st PC', fontsize=20)
ax.set_ylabel('2nd PC', fontsize=20)
cbar.set_label('3rd PC', rotation=270)
ax.set_title('PCs 1-3', fontsize=23)

plt.show()

## Problem 1(d): Implement K-means using Lloyd's algorithm

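Lloyd's algorithm minimizes the K-means loss, the sum of squared Euclidean distances between each data point and the centroid of its assigned cluster, by alternating between two steps: assigning each data point to its nearest centroid, and recomputing each centroid as the mean of the points assigned to it.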
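A minimal NumPy sketch of Lloyd's algorithm, assuming random initialization at distinct data points. It uses the hypothetical name `k_means_sketch` so it does not clash with the `k_means` function below, and it returns the same three values described in that function's docstring. An empty cluster simply keeps its previous centroid, which is one of several reasonable choices.

```python
import numpy as np

def k_means_sketch(data, num_clusters=3, num_iters=200, seed=None):
    """Lloyd's algorithm on data of shape (num_samples, num_dims)."""
    rng = np.random.default_rng(seed)
    # Initialize centroids at randomly chosen, distinct data points
    centroids = data[rng.choice(len(data), size=num_clusters, replace=False)].astype(float)
    assignments = np.full(len(data), -1)
    for _ in range(num_iters):
        # Assignment step: squared distance from every point to every centroid
        sq_dists = ((data[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=-1)
        new_assignments = sq_dists.argmin(axis=1)
        if np.array_equal(new_assignments, assignments):
            break  # assignments stopped changing, so the algorithm has converged
        assignments = new_assignments
        # Update step: move each centroid to the mean of its assigned points
        for k in range(num_clusters):
            members = data[assignments == k]
            if len(members) > 0:  # keep the old centroid if a cluster empties out
                centroids[k] = members.mean(axis=0)
    loss = ((data - centroids[assignments]) ** 2).sum()
    return loss, assignments, centroids
```

Because the result depends on the random initialization, Lloyd's algorithm can stop at a poor local minimum; this is why the cells below keep the minimum loss over several runs.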

In [None]:
def k_means(data, num_clusters=3, num_iters=200):
    """
    Input: data of shape [number of samples x number of dimensions], desired number of clusters, and maximum number of iterations
    Returns:
    1) loss [1x1]
    2) cluster assignments [number of samples x 1]
    3) cluster_centroids [num_clusters x number of dimensions]
    """
    ###
    # # Implement the K-means function such that it takes in the required inputs and produces the variables detailed above.
    #
    # YOUR CODE BELOW
    #
    ###
    return loss, cluster_assignments, cluster_centroids

Run the K-means model with several random initializations and different numbers of clusters, and look for the 'elbow' in the loss as a function of the number of clusters.

In [None]:
# This may take a couple of minutes!
max_num_clusters = 15
num_runs = 10

min_loss = []
for num_clus in range(1, max_num_clusters +1):
    loss_over_runs = []
    for run in range(num_runs):
        loss, cluster_assignments, cluster_centroids = k_means(pc_scores, num_clus)
        loss_over_runs.append(loss)
    min_loss.append(np.min(loss_over_runs))


In [None]:
# Plot in order to decide number of clusters by using the elbow method
plot_fig(range(1, max_num_clusters +1), min_loss,'Deciding number of clusters', 'Number of Clusters','Loss')

In [None]:
###
# # Implement K-means again with the chosen number of clusters, such that there is an approximate elbow in the curve above
#
# YOUR CODE BELOW
# num_clus = ... # input a number based on the above plot
#
###

loss, cluster_assignments, cluster_centroids = k_means(pc_scores, num_clus)

# Plot the result
fig, ax = plt.subplots(1, 2, figsize=(15, 5))
ax[0].scatter(pc_scores[:, 0], pc_scores[:, 1], c=cluster_assignments)
ax[0].plot(cluster_centroids[:,0], cluster_centroids[:,1],'xr')
ax[0].set_xlabel('1st principal component', fontsize=20)
ax[0].set_ylabel('2nd principal component', fontsize=20)
ax[0].set_title('Clustered Data', fontsize=23)

t_ms = 1000 * np.arange(waveforms.shape[0]) / sample_freq  # waveform time axis in ms
for i in range(num_clus):
    cluster_mean = waveforms[:, cluster_assignments==i].mean(axis=1)
    cluster_std = waveforms[:, cluster_assignments==i].std(axis=1)

    ax[1].plot(t_ms, cluster_mean, label='Cluster {}'.format(i))
    ax[1].fill_between(t_ms, cluster_mean-cluster_std, cluster_mean+cluster_std, alpha=0.15)

ax[1].set_title('average waveforms', fontsize=23)
ax[1].set_xlim([0, t_ms[-1]])
ax[1].set_xlabel('time [ms]', fontsize=20)
ax[1].set_ylabel('amplitude [z-score]', fontsize=20)

ax[1].legend()
plt.show()

## Problem 1(e): Discussion
This is an open-ended question, so there's not necessarily a right answer. Try to think critically about the algorithm and the results. Please respond to the following prompts:

- In practice, you would post-process the extracted spikes to reject unrealistic neurons and merge overlapping ones. How would you approach this problem?
- Over the course of a long recording session, the probe could drift up and down so that the channels activated by a neuron shift. How could you compensate for this slow drift in this model, or possibly try to correct for it during preprocessing?

*Answer below this line*

---

# Problem 2: Multichannel spike sorting

In [None]:
#@title Helper function to generate synthetic data (run this cell)

def generate_templates(rng, C, D, N):
    # Make (semi) random templates
    templates = []
    for n in range(N):
        # center = n * C / (N - 1) if N > 1 else C / 2
        center = C * rng.random()
        width = 1 + (C / 10) * rng.random()
        spatial_factor = np.exp(-0.5 * (np.arange(C) - center)**2 / width**2)

        dt = np.arange(D)
        period = D / (1 + rng.random())
        z = (dt - 0.75 * period) / (.25 * period)
        warp = lambda x: -np.exp(-x) + 1
        window = np.exp(-0.5 * z**2)
        shape = np.sin(2 * np.pi * dt / period)
        temporal_factor = warp(window * shape)

        template = np.outer(spatial_factor, temporal_factor)
        template /= np.linalg.norm(template)
        templates.append(template)

    return np.array(templates)


def generate(rng, T, C, D, N,
             mean_amplitude=15,
             shape_amplitude=3.0,
             noise_std=1,
             sample_freq=1000):
    """Create a random set of model parameters and sample data.

    Parameters:
    T: integer number of time samples in the data
    C: integer number of channels
    D: integer duration (number of samples) of each template
    N: integer number of neurons
    """
    # Make semi-random templates
    templates = generate_templates(rng, C, D, N)

    # Make random amplitudes
    amplitudes = np.zeros((N, T))
    for n in range(N):
        num_spikes = rng.poisson(T / sample_freq * 10)
        times = rng.integers(0, T, size=num_spikes)
        amps = rng.gamma(shape_amplitude,
                         scale=mean_amplitude / shape_amplitude,
                         size=num_spikes)
        amplitudes[n, times] = amps

        # Only keep spikes separated by at least D
        times, props = signal.find_peaks(amplitudes[n], distance=D, height=1e-3)
        amplitudes[n] = 0
        amplitudes[n, times] = props['peak_heights']

    # Convolve the signal with each row of the multi-channel template
    data = 0
    for temp, amp in zip(templates, amplitudes):
        data += np.vstack([
            np.convolve(amp, row, mode='full')[:-(D-1)]
            for row in temp])

    data += rng.normal(scale=noise_std, size=data.shape)

    return templates, amplitudes, data

In [None]:
#@title Helper functions for plotting (run this cell)

# initialize a color palette for plotting
palette = sns.xkcd_palette(["windows blue",
                            "red",
                            "medium green",
                            "dusty purple",
                            "orange",
                            "amber",
                            "clay",
                            "pink",
                            "greyish"])
sns.set_context("notebook")


def plot_data(timestamps,
              data,
              plot_slice=slice(0, 6000),
              labels=None,
              spike_times=None,
              neuron_channels=None,
              spike_width=81,
              scale=10,
              figsize=(12, 9),
              cmap="jet"):
    n_channels, n_samples = data.shape
    cmap = get_cmap(cmap) if isinstance(cmap, str) else cmap

    plt.figure(figsize=figsize)
    plt.plot(timestamps[plot_slice],
             data.T[plot_slice] - scale * np.arange(n_channels),
             '-k', lw=1)

    if not any(x is None for x in [labels, spike_times, neuron_channels]):
        # Plot the ground truth spikes and assignments
        n_units = labels.max() + 1  # labels are 0-indexed, so include the largest label
        in_slice = (spike_times >= plot_slice.start) & (spike_times < plot_slice.stop)
        labels = labels[in_slice]
        times = spike_times[in_slice]
        for i in range(n_units):
            i_channels = neuron_channels[i]
            for t in times[labels == i]:
                window = slice(t, t + spike_width)
                plt.plot(timestamps[window],
                         data.T[window, i_channels] - scale * np.arange(n_channels)[i_channels],
                         color=cmap(i / max(n_units - 1, 1)),
                         alpha=0.5,
                         lw=2)

    plt.yticks(-scale * np.arange(1, n_channels+1, step=2),
            np.arange(1, n_channels+1, step=2) + 1)
    plt.xlabel("time [s]")
    plt.ylabel("channel")
    plt.xlim(timestamps[plot_slice.start], timestamps[plot_slice.stop])
    plt.ylim(-scale * n_channels, scale)


def plot_templates(templates,
                   indices,
                   scale=0.1,
                   n_cols=8,
                   panel_height=6,
                   panel_width=1.25,
                   colors=('k',),
                   label="neuron",
                   sample_freq=30000,
                   fig=None,
                   axs=None):
    n_subplots = len(indices)
    n_cols = min(n_cols, n_subplots)
    n_rows = int(np.ceil(n_subplots / n_cols))

    if fig is None and axs is None:
        fig, axs = plt.subplots(n_rows, n_cols,
                                figsize=(panel_width * n_cols, panel_height * n_rows),
                                sharex=True, sharey=True)

    n_units, n_channels, spike_width = templates.shape
    timestamps = np.arange(-spike_width // 2, spike_width//2) / sample_freq
    for i, (ind, ax) in enumerate(zip(indices, np.ravel(axs))):
        color = colors[i % len(colors)]
        ax.plot(timestamps * 1000,
                templates[ind].T - scale * np.arange(n_channels),
                '-', color=color, lw=1)

        ax.set_title("{} {:d}".format(label, ind + 1))
        ax.set_xlim(timestamps[0] * 1000, timestamps[-1] * 1000)
        ax.set_yticks(-scale * np.arange(n_channels+1, step=4))
        ax.set_yticklabels(np.arange(n_channels+1, step=4) + 1)
        ax.set_ylim(-scale * n_channels, scale)

        if i // n_cols == n_rows - 1:
            ax.set_xlabel("time [ms]")
        if i % n_cols == 0:
            ax.set_ylabel("channel")

        # plt.tight_layout(pad=0.1)

    # hide the remaining axes
    for ax in np.ravel(axs)[n_subplots:]:
        ax.set_visible(False)
    plt.show()
    return fig, axs


def plot_model(templates, amplitude, data, scores=None, lw=2, figsize=(12, 6)):
    """Plot the raw data as well as the underlying signal amplitudes and templates.

    amplitude: (N,T) array of underlying signal amplitude
    templates: (N,C,D) array of templates that are convolved with the amplitudes
    data: (C, T) array (channels x time)
    scores: optional (N,T) array of correlations between data and template
    """
    # prepend dimension if data and template are 1d
    data = np.atleast_2d(data)
    C, T = data.shape
    amplitude = np.atleast_2d(amplitude)
    N, _ = amplitude.shape
    templates = templates.reshape(N, C, -1)
    D = templates.shape[-1]
    dt = np.arange(D)
    if scores is not None:
        scores = np.atleast_2d(scores)

    # Set up a figure with a 2 x (N+1) grid of panels
    fig = plt.figure(figsize=figsize)
    gs = GridSpec(2, N + 1, height_ratios=[1, 2], width_ratios=[1] * N + [2 * N])

    # plot the templates
    t_spc = 1.05 * abs(templates).max()
    for n in range(N):
        ax = fig.add_subplot(gs[1, n])
        ax.plot(dt, templates[n].T - t_spc * np.arange(C),
                '-', color=palette[n % len(palette)], lw=lw)
        ax.set_xlabel("$d$")
        ax.set_xlim([0, D])
        ax.set_yticks(-t_spc * np.arange(C))
        ax.set_yticklabels([])
        ax.set_ylim(-C * t_spc, t_spc)
        if n == 0:
            ax.set_ylabel("channels $c$")
        ax.set_title("$W_{{ {} }}$".format(n+1))

    # plot the amplitudes for each neuron
    ax = fig.add_subplot(gs[0, -1])
    a_spc = 1.05 * abs(amplitude).max()
    if scores is not None:
        a_spc = max(a_spc, 1.05 * abs(scores).max())

    for n in range(N):
        ax.plot(amplitude[n] - a_spc * n, '-', color=palette[n % len(palette)], lw=lw)

        if scores is not None:
            ax.plot(scores[n] - a_spc * n, ':', color=palette[n % len(palette)], lw=lw,
                label=r"$y \star W$")

    ax.set_xlim([0, T])
    ax.set_xticklabels([])
    ax.set_yticks(-a_spc * np.arange(N))
    ax.set_yticklabels([])
    ax.set_ylabel("neurons $n$")
    ax.set_title("amplitude $a$")
    if scores is not None:
        ax.legend()

    # plot the data
    ax = fig.add_subplot(gs[1, -1])
    d_spc = 1.05 * abs(data).max()
    ax.plot(data.T - d_spc * np.arange(C), '-', color='gray', lw=lw)
    ax.set_xlabel("time $t$")
    ax.set_xlim([0, T])
    ax.set_yticks(-d_spc * np.arange(C))
    ax.set_yticklabels([])
    ax.set_ylim(-C * d_spc, d_spc)
    # ax.set_ylabel("channels $c$")
    ax.set_title("data $y$")

    # plt.tight_layout()

In [None]:
# Create a larger dataset with multiple channels and neurons.
T = 1000000  # number of time samples
C = 10      # number of channels
D = 81      # duration of a spike (in samples)
N = 5       # number of neurons

# Generate random templates, amplitudes, and noisy data.
# `templates` are NxCxD and `amplitudes` are NxT
rng = np.random.default_rng(seed=1)
print("Simulating data. This could take a minute!")
true_templates, true_amplitudes, data = generate(rng, T, C, D, N)
plot_model(true_templates, true_amplitudes[:, :2000], data[:,:2000], lw=1, figsize=(12, 8))

Now implement the steps from Problems 1(b-d) on the multi-channel data.

## Problem 2(a): Extract putative spikes from Multi-Channel Data

In [None]:
## We first find peaks again to find candidate spike times. However, this time we will look in each channel.
# Use scipy.signal.find_peaks to find candidate spike times
timestamps = np.arange(T)
distance_ms = 0.005  # minimum time between spikes (in seconds, despite the "_ms" suffix)
height = 4           # standard deviations to define a peak
per_ch_spike_inds = []
for ch in trange(C):
    ###
    # # Implement the 'find_peaks' signal again, this time to find negative peaks (or positive peaks on -data)
    #
    # YOUR CODE BELOW
    #   ch_spike_inds, _ = signal.find_peaks(...)
    #
    ###
    per_ch_spike_inds.append(ch_spike_inds)

plot_data(timestamps, data)
for ch, ch_spike_inds in enumerate(per_ch_spike_inds):
    plt.plot(timestamps[ch_spike_inds],
             data[ch, ch_spike_inds] - 10 * ch,
             'r.',)

Combine nearly coincident spikes across channels.
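The pooling-and-smoothing idea in the next cell can be sketched on hypothetical detections: one event seen on three channels with one sample of jitter, and another seen on two channels. Summing counts across channels, smoothing with a Gaussian, and finding peaks in the result collapses each group into a single spike time. The smoothing width (`sigma=2` samples), threshold, and minimum distance here are illustrative assumptions.

```python
import numpy as np
import scipy.signal as signal
from scipy.ndimage import gaussian_filter1d

T_demo = 1000  # hypothetical recording length (samples)
# Hypothetical per-channel detections: an event near t=100 seen on three channels
# with +/-1 sample of jitter, and an event near t=501 seen on two channels
per_ch_demo = [np.array([100, 500]), np.array([101]), np.array([99, 502])]

# Count detections per time sample, pooled across channels
counts, _ = np.histogram(np.concatenate(per_ch_demo), bins=np.arange(T_demo + 1))

# Smooth so that nearby detections on different channels merge into one bump
smoothed = gaussian_filter1d(counts.astype(float), sigma=2.0)

# Each bump becomes a single putative spike time
merged, _ = signal.find_peaks(smoothed, height=0.1, distance=10)
print(merged)
```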


In [None]:
# Pool spike times across channels; the smoothing below allows a small cross-channel delay (a fraction of the refractory period)
bins = np.arange(T + 1)
###
# # Compute the histogram of all the spike times across all channels
#
# YOUR CODE BELOW
#   total_spike_counts, _ = np.histogram(...)
#
###

# Do a little Gaussian smoothing to allow for jitter in spike time across channels
jitter_width = 0.0001  # in seconds
total_spike_counts = gaussian_filter1d(total_spike_counts.astype(float),
                                       jitter_width * sample_freq)


###
# # # Eyeball a threshold and find peaks
# # run with np.inf first and then update with chosen threshold
#
# YOUR CODE BELOW
# min_height = ...

min_height = 0.1
#
###
spike_inds, _ = signal.find_peaks(total_spike_counts,
                                  height=min_height,
                                  distance=0.001 * sample_freq)
# Keep only spikes whose D-sample window (centered on the spike) fits inside the recording
spike_inds = spike_inds[(spike_inds >= D // 2) & (spike_inds + D // 2 + 1 <= T)]
num_spikes = len(spike_inds)
print("Found", num_spikes, "putative spikes")


In [None]:
putative_spikes = np.zeros((num_spikes, C, D))
for i, ind in enumerate(spike_inds):
    window = slice(ind - D // 2, ind + D // 2 + 1)
    putative_spikes[i] = data[:, window]

# we can use our template plotting code to visualize the spikes too
plot_templates(putative_spikes, indices=np.arange(16), scale=10, label="spike")

## Problem 2(b): Dimensionality Reduction for Multi-Channel Data

In [None]:
## Flatten the waveform data so that each data point is an entire waveform across all channels.
spikes_flat=putative_spikes.reshape((putative_spikes.shape[0],putative_spikes.shape[1]*putative_spikes.shape[2]))

In [None]:
###
# # Implement PCA on the waveforms across all channels, with a large number of components,
# # Using this plot, we will choose how many components we will need to explain the variance in our data
#
# YOUR CODE BELOW
# num_comps = ...
# pca = PCA(n_components=num_comps)
# pca.fit(...)
#
###

## plot the variance explained ratio, and pick the number of PCs that reach some saturation in variance explained
plt.plot(np.cumsum(pca.explained_variance_ratio_)); plt.show()

In [None]:
# Choose the number of components to keep, such that variance explained in the data saturates. Question: why does it start saturating so early?
###
# # Perform PCA with the chosen number of components
#
# YOUR CODE BELOW
# num_comps = ...
# pca = PCA(n_components=num_comps)
# pc_scores = pca.fit_transform(...)
#
###

# Plot the first 3 PCs, with the color denoting the 3rd PC.
# Remember, each data point here is a waveform.
fig, ax = plt.subplots(figsize=(8, 6))
im = ax.scatter(pc_scores[:, 0], pc_scores[:, 1], c=pc_scores[:, 2])
cbar = fig.colorbar(im,ax=ax)
ax.set_xlabel('1st PC', fontsize=20)
ax.set_ylabel('2nd PC', fontsize=20)
cbar.set_label('3rd PC', rotation=270)
ax.set_title('PCs 1-3', fontsize=23)

plt.show()

## Problem 2(c): Implement K-means for Multi-Channel Data

This time, we run the model with different numbers of clusters to find the 'elbow', but with only one initialization per number of clusters because the code is time-consuming. This may still take a couple of minutes to run, but it will help us decide how many clusters to choose.

In [None]:
max_num_clusters = 15
num_runs = 1

min_loss = []
for num_clus in range(1, max_num_clusters +1):
    loss_over_runs = []
    for run in range(num_runs):
        loss, cluster_assignments, cluster_centroids = k_means(pc_scores, num_clus)
        loss_over_runs.append(loss)
    min_loss.append(np.min(loss_over_runs))


In [None]:
# Plot in order to decide number of clusters by using the elbow method
plot_fig(range(1, max_num_clusters +1), min_loss,'Deciding number of clusters', 'Number of Clusters','Loss')

In [None]:
###
# # Implement K-means again with the chosen number of clusters, such that there is an approximate elbow in the curve above
# # # Choose the number of clusters based on the above plot.
# # # Try different numbers of clusters and continue till the visualization to gain understanding.
# YOUR CODE BELOW
# num_clus = ...
#
###
loss, cluster_assignments, cluster_centroids = k_means(pc_scores, num_clus)

# Plot the result
fig, ax = plt.subplots( figsize=(5, 5))
ax.scatter(pc_scores[:, 0], pc_scores[:, 1], c=cluster_assignments)
ax.plot(cluster_centroids[:,0], cluster_centroids[:,1],'xr')
ax.set_xlabel('1st principal component', fontsize=20)
ax.set_ylabel('2nd principal component', fontsize=20)
ax.set_title('Clustered Data', fontsize=23)


In [None]:
## Compute the similarity (correlation coefficient) between the true and inferred templates

# Take the cluster centroids and map them back into the original space, and treat each as an individual neuron's template
templates = np.reshape(pca.inverse_transform(cluster_centroids),(num_clus,C,D))

similarity = np.zeros((N, num_clus))
for i in range(N):
    for j in range(num_clus):
        similarity[i, j] = np.corrcoef(np.reshape(true_templates[i],(-1)), np.reshape(templates[j],-1))[0,1]

# Show the similarity matrix
_, perm = linear_sum_assignment(similarity, maximize=True)
plt.imshow(similarity[:, perm], vmin=0, vmax=1)
plt.xlabel("inferred neuron")
plt.ylabel("true neuron")
plt.title("correlation coefficients of templates")
plt.colorbar()

In [None]:
## Plot the true and inferred templates
plot_templates(true_templates, np.arange(N), n_cols=N, label='true')
plot_templates(templates[perm], np.arange(min(N,num_clus)), n_cols=min(N,num_clus), scale=3, label='matched',colors=('r',))

if num_clus>N:
    plot_templates(templates[np.setdiff1d(np.arange(num_clus), perm)],
                   np.arange(num_clus-N), n_cols=num_clus-N, scale=3, label='others',colors=('r',))


## Problem 2(d): Discussion

- How do the recovered templates compare with the true templates, especially when num_clus $\geq$ N?
- What steps can you take to improve the results?

*Answer below this line*

---

# Problem 3: Generative Modeling of the Voltage Data

## Problem 3 intro: Convolution and cross-correlation with PyTorch

PyTorch is a machine learning framework for fast numerical computation on CPU and GPU hardware. Its interface is very similar to NumPy's: instead of NumPy `ndarray`s, we operate on PyTorch `Tensor`s. Tensors are stored on a specified device, and in the setup cells below you'll see that we set our device to `'cuda'`, i.e., a GPU. That's why you need to run this Colab notebook with a GPU runtime. (If you click the RAM/disk icon in the upper right, you should see that this is a GPU session. If it's not, go to "Runtime -> Change Runtime Type" to select a GPU.)

PyTorch `Tensor`s offer a similar interface to NumPy arrays. You can read all about them in the [docs](https://pytorch.org/docs/stable/tensors.html). In this notebook, we name our variables with a `_t` suffix to indicate that they are `Tensor`s. We convert back and forth between `array`s and `Tensor`s using the convenience functions `to_t` and `from_t`, which are defined in the setup cell below. (There are a few minor concerns in making the translation; for example, NumPy defaults to 64-bit floats whereas PyTorch defaults to 32-bit. Similarly, we have to copy the tensor back to the CPU and "detach" it from the computation graph before converting it to a NumPy array.)
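Here is a CPU-friendly sketch of the same kind of helpers, using the hypothetical names `to_t_demo` and `from_t_demo` so they do not shadow the notebook's `to_t` and `from_t`. It shows the float64-to-float32 cast and why `detach()` is needed before converting back; the CPU fallback is an assumption added so the sketch runs without a GPU.

```python
import numpy as np
import torch

# Use the GPU when available; otherwise fall back to the CPU (an assumption for this sketch)
device_demo = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
to_t_demo = lambda array: torch.tensor(array, device=device_demo, dtype=torch.float32)
from_t_demo = lambda tensor: tensor.detach().cpu().numpy()

x = np.linspace(0.0, 1.0, 5)        # NumPy defaults to float64
x_t = to_t_demo(x)                  # cast to float32 on the chosen device

# A tensor that requires gradients belongs to the computation graph;
# calling .numpy() on it directly raises an error, hence .detach() in from_t_demo
w_t = torch.ones(5, device=device_demo, requires_grad=True)
out = from_t_demo(w_t * x_t)
print(x.dtype, x_t.dtype, out.dtype)
```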

`Tensor` objects have a few nice functions that will make your life easier in this homework. Suppose `data_t` is a `Tensor`. Then,
- `data_t.flip(dims=(-1,))` flips the tensor along its last axis.
- `data_t.unsqueeze(0)` creates a new leading axis.
- `data_t.permute(1, 0, 2)` permutes the order of the axes.
- `data_t.reshape(1, 1, -1)` vectorizes the data and reshapes it to have two leading dimensions of length 1.
- `data_t.sum()` sums the entries in the data.
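A quick sketch of these methods on a small toy tensor (the `(2, 3, 4)` shape is arbitrary), printing the shape each operation produces:

```python
import torch

data_t = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4)  # toy tensor, shape (2, 3, 4)

flipped = data_t.flip(dims=(-1,))      # reverse the last axis: [0, 1, 2, 3] -> [3, 2, 1, 0]
expanded = data_t.unsqueeze(0)         # new leading axis: shape (1, 2, 3, 4)
permuted = data_t.permute(1, 0, 2)     # swap the first two axes: shape (3, 2, 4)
vectorized = data_t.reshape(1, 1, -1)  # flatten, then add two leading axes: shape (1, 1, 24)
total = data_t.sum()                   # 0-d tensor holding 0 + 1 + ... + 23

print(flipped[0, 0], expanded.shape, permuted.shape, vectorized.shape, total.item())
```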

If this is your first time using PyTorch, you might want to check out some of their [tutorials](https://pytorch.org/tutorials/beginner/pytorch_with_examples.html) first.


In [None]:
# We'll use PyTorch for this problem
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions

In [None]:
# specify that we want our tensors on the GPU and in float32
device = torch.device('cuda')
dtype = torch.float32

# helper function to convert between numpy arrays and tensors
to_t = lambda array: torch.tensor(array, device=device, dtype=dtype)
from_t = lambda tensor: tensor.to("cpu").detach().numpy()

### Problem 3 intro (a): Perform a 1d convolution

As a first step, we generate synthetic data simulating spikes from a single neuron recorded on a single channel. Given a template and a 1d array of amplitudes, we simulate the data as the convolution of these two arrays.

You'll use the `conv1d` function in the `torch.nn.functional` package. We've already imported it as `F`, so you can call it with `F.conv1d(...)`. Take a look at its documentation [here](https://pytorch.org/docs/stable/nn.functional.html?highlight=conv1d#torch.nn.functional.conv1d), as well as the corresponding documentation for the `torch.nn.Conv1d` object, which implements a convolutional layer for a neural network.

Here's an example:
```
# make the input (i.e. the signal)
B = 1   # batch size
N = 2   # number of input channels
T = 100 # length of input signal
input_t = torch.rand(B, N, T)

# make the weights (i.e. the filter)
C = 3   # number of output channels
D = 10  # length of the filter
weight_t = torch.rand(C, N, D)

# perform the convolution
output_t = F.conv1d(input_t, weight_t)

# output.shape is (B, C, T - D + 1)
```
**Remember that `conv1d` actually performs a cross-correlation!**

Let $X \in \mathbb{R}^{B \times N \times T}$ denote the signal/input and $W \in \mathbb{R}^{C \times N \times D}$ denote the filter/weights, and let $Y \in \mathbb{R}^{B \times C \times (T - D + 1)}$ denote the output. Then the `conv1d` function implements the cross-correlation,
\begin{align}
y_{b,c,t} = \sum_{n = 1}^{N} \sum_{d=1}^D x_{b,n,t+d-1} w_{c,n,d},
\end{align}
for $b=1,\ldots,B$, $c=1,\ldots,C$, and $t=1,\ldots,T-D+1$.
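
As a sanity check on this formula, the sketch below compares `F.conv1d` against the double sum written out explicitly (in 0-indexed form, $y_{b,c,t} = \sum_n \sum_d x_{b,n,t+d}\, w_{c,n,d}$). The sizes and random inputs are arbitrary choices for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, N, T = 2, 3, 20   # batch size, input channels, signal length
C, D = 4, 5          # output channels, filter length
input_t = torch.randn(B, N, T, dtype=torch.float64)
weight_t = torch.randn(C, N, D, dtype=torch.float64)

# What conv1d computes ("valid" output, no padding).
output_t = F.conv1d(input_t, weight_t)

# The same sum written out directly:
# y[b, c, t] = sum_n sum_d x[b, n, t + d] * w[c, n, d]
manual_t = torch.zeros(B, C, T - D + 1, dtype=torch.float64)
for t in range(T - D + 1):
    window_t = input_t[:, :, t:t + D]          # (B, N, D) window starting at t
    manual_t[:, :, t] = torch.einsum('bnd,cnd->bc', window_t, weight_t)

print(output_t.shape)                      # torch.Size([2, 4, 16])
print(torch.allclose(output_t, manual_t))  # True
```

Note that the weights are never flipped, which is exactly why `conv1d` is a cross-correlation rather than a convolution.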

By default the output contains only the "valid" portion of the convolution, i.e. the $T-D+1$ samples where the input and weights completely overlap. If you want the "full" output, call `F.conv1d(input_t, weight_t, padding=D-1)`. This pads the input with $D-1$ zeros at the beginning and end, so the resulting output has length $T + D - 1$. Depending on your application, you may want the first $T$ or the last $T$ entries of this array. When in doubt, try both and see!

Use `conv1d` to implement a 1d **convolution**. Remember that you can do it by cross-correlation as long as you flip your weights along the last axis.
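
Here is a minimal sketch of the flip-and-pad idea on generic 1d arrays (random toy data, not the homework's `templates`/`amplitudes`): flipping the kernel and taking the first $T$ entries of the full output reproduces `np.convolve(..., mode='full')[:T]`.

```python
import numpy as np
import torch
import torch.nn.functional as F

rng = np.random.default_rng(0)
T, D = 50, 7
x = rng.standard_normal(T)        # 1d input (e.g. amplitudes)
kernel = rng.standard_normal(D)   # 1d filter (e.g. a template)

# Reference: a true convolution, truncated to the first T samples.
expected = np.convolve(x, kernel, mode='full')[:T]

# conv1d cross-correlates, so flip the kernel; pad with D-1 zeros on each side
# to get the "full" output of length T + D - 1, then keep the first T entries.
x_t = torch.tensor(x).reshape(1, 1, T)                          # (B=1, N=1, T)
w_t = torch.tensor(kernel).flip(dims=(-1,)).reshape(1, 1, D)    # (C=1, N=1, D)
full_t = F.conv1d(x_t, w_t, padding=D - 1)                      # (1, 1, T + D - 1)
result = full_t[0, 0, :T].numpy()

print(full_t.shape)                   # torch.Size([1, 1, 56])
print(np.allclose(result, expected))  # True
```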

In [None]:
# Create a dataset with a single channel and one neuron.
T = 1000    # number of time samples
N = 1       # one neuron
C = 1       # number of channels
D = 51      # duration of a spike (in samples)

# Generate random templates, amplitudes, and noisy data.
# `templates` are NxCxD and `amplitudes` are NxT
rng = np.random.default_rng(seed=2)
templates, amplitudes, _ = generate(rng, T, C, D, N)

# First we'll perform the convolution using numpy.
data = np.convolve(amplitudes[0], templates[0, 0], mode='full')[:T]
assert data.shape[0] == T

# Plot the templates, amplitude, and data
plot_model(templates, amplitudes, data)

###
# Now perform the same convolution using PyTorch's `conv1d` function.
#
# YOUR CODE BELOW
amplitudes_t = to_t(amplitudes)
templates_t = to_t(templates)

# data_t = F.conv1d(...)
#
###

assert np.allclose(data, from_t(data_t))

### Problem 3 intro (b): Perform a 1d cross-correlation in PyTorch

Recall from class that the cross-correlation measures the similarity between the template and the data in every time window. At time points where the data match the template, we should obtain a high correlation, indicating the presence of a spike. The peaks of the amplitude array and of the cross-correlation array should therefore line up, as in the plot below: the dotted line shows the cross-correlation of the data and the template, and it peaks where there are spikes in the true underlying amplitudes that generated the data. Use `conv1d` to implement this **cross-correlation** and reproduce the dotted line.
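
A minimal sketch of that behavior on toy data: the spike shape (`np.hanning`), spike times, and noise level below are arbitrary assumptions. No flip is needed for a cross-correlation, and with `padding=D-1` the last $T$ entries line up with `np.correlate(..., mode='full')[D-1:]`.

```python
import numpy as np
import torch
import torch.nn.functional as F

rng = np.random.default_rng(1)
T, D = 200, 11
template = np.hanning(D)          # hypothetical smooth spike shape
amps = np.zeros(T)
amps[[40, 120]] = [3.0, 2.0]      # two spikes at known times
sim = np.convolve(amps, template, mode='full')[:T] + 0.1 * rng.standard_normal(T)

# Reference: score[t] compares sim[t:t+D] with the template.
expected = np.correlate(sim, template, mode='full')[D - 1:]

# conv1d already cross-correlates, so the template is not flipped.
x_t = torch.tensor(sim).reshape(1, 1, T)
w_t = torch.tensor(template).reshape(1, 1, D)
score = F.conv1d(x_t, w_t, padding=D - 1)[0, 0, D - 1:].numpy()

print(np.allclose(score, expected))  # True
print(np.argmax(score))              # at or near the larger spike, t = 40
```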

In [None]:
# Now correlate the data with the template to estimate spike times
score = np.correlate(data, templates[0, 0], mode='full')[D-1:]
assert score.shape[0] == T

plot_model(templates, amplitudes, data, scores=score)

###
# Now perform the same cross-correlation using PyTorch's `conv1d` function.
#
# YOUR CODE BELOW
#
# score_t = F.conv1d(...)
#
###

# from_t moves the tensor to the CPU, detaches it, and converts it to a numpy array
assert np.allclose(score, from_t(score_t))

### Problem 3 intro (c): Perform a 1d convolution across multiple channels at once

Similar to Problem 3 intro (a), except that the template (and therefore the final synthetic data) has multiple _output_ channels.
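
A minimal sketch of the shape bookkeeping, on random toy arrays: the amplitudes form a single input channel, and the $C$ rows of the template become $C$ output channels, so the weights have shape `(C, 1, D)`.

```python
import numpy as np
import torch
import torch.nn.functional as F

rng = np.random.default_rng(2)
T, C, D = 60, 4, 9
amps = rng.standard_normal(T)           # one neuron's amplitudes
tmpl = rng.standard_normal((C, D))      # one template row per channel

# Reference: convolve the amplitudes with each channel's row separately.
expected = np.vstack([np.convolve(amps, row, mode='full')[:T] for row in tmpl])

# A single conv1d call: 1 input channel, C output channels.
x_t = torch.tensor(amps).reshape(1, 1, T)                    # (1, 1, T)
w_t = torch.tensor(tmpl).flip(dims=(-1,)).unsqueeze(1)       # (C, 1, D)
out = F.conv1d(x_t, w_t, padding=D - 1)[0, :, :T].numpy()    # (C, T)

print(out.shape)                   # (4, 60)
print(np.allclose(out, expected))  # True
```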

In [None]:
# Create a dataset with multiple channels and one neuron.
T = 1000    # number of time samples
C = 10      # number of channels
D = 100     # duration of a spike (in samples)
N = 1       # one neuron

# Generate random templates, amplitudes, and noisy data.
# `templates` are NxCxD and `amplitudes` are NxT
rng = np.random.default_rng(seed=0)
templates, amplitudes, _ = generate(rng, T, C, D, N)

# Convolve the amplitudes with each row of the multi-channel template
data = np.vstack([
    np.convolve(amplitudes[0], row, mode='full')[:T]
    for row in templates[0]])

plot_model(templates, amplitudes, data)

###
# Now perform the same convolution using PyTorch's `conv1d` function.
#
# YOUR CODE BELOW
#
amplitudes_t = to_t(amplitudes)
templates_t = to_t(templates)
# data_t = F.conv1d(...)
#
###

assert np.allclose(data, from_t(data_t))

### Problem 3 intro (d): Perform a 1d cross-correlation across multiple channels at once

Same as Problem 3 intro (b), except that the data and templates have multiple _input_ channels.
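
A minimal sketch on random toy arrays: the $C$ data channels are input channels, and `conv1d` sums the per-channel cross-correlations for us when the weights have shape `(1, C, D)`.

```python
import numpy as np
import torch
import torch.nn.functional as F

rng = np.random.default_rng(3)
T, C, D = 60, 4, 9
sim = rng.standard_normal((C, T))      # multi-channel data
tmpl = rng.standard_normal((C, D))     # multi-channel template

# Reference: correlate each channel separately, then sum over channels.
expected = np.sum([np.correlate(sim[c], tmpl[c], mode='full')[D - 1:]
                   for c in range(C)], axis=0)

# One conv1d call: C input channels, 1 output channel, no flip.
x_t = torch.tensor(sim).unsqueeze(0)       # (1, C, T)
w_t = torch.tensor(tmpl).unsqueeze(0)      # (1, C, D)
score = F.conv1d(x_t, w_t, padding=D - 1)[0, 0, D - 1:].numpy()   # (T,)

print(score.shape)                     # (60,)
print(np.allclose(score, expected))    # True
```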

In [None]:
# We'll first perform the cross-correlation in numpy by correlating
# each row of the data with the corresponding row of the template and summing.
# Then you'll do the same thing in PyTorch using a single call to `F.conv1d`.
score = np.sum([np.correlate(data[c], templates[0, c], mode='full')[D-1:]
                for c in range(C)], axis=0)

plot_model(templates, amplitudes, data, scores=score)

###
# Now perform the same cross-correlation using PyTorch's
# `nn.functional.conv1d` function. You should only need
# one call to this function!
#
# YOUR CODE BELOW
# score_t = F.conv1d(...)
#
###

assert np.allclose(score, from_t(score_t))

### Problem 3 intro (e): Convolving multiple neurons' spikes and templates

Similar to Problem 3 intro (c), but here we have multiple neurons (input channels), each associated with a template that has multiple (output) channels. The final simulated data is summed across neurons to mimic actual measurements, in which signals from multiple neurons are superimposed.
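
A minimal sketch on random toy arrays: because `conv1d` sums over input channels, making the neurons the input channels performs the sum over neurons. That requires swapping the first two axes of the `N x C x D` templates (giving `C x N x D`) and flipping in time.

```python
import numpy as np
import torch
import torch.nn.functional as F

rng = np.random.default_rng(5)
T, C, D, N = 60, 4, 9, 3
amps = rng.standard_normal((N, T))       # one row of amplitudes per neuron
tmpl = rng.standard_normal((N, C, D))    # one multi-channel template per neuron

# Reference: convolve each neuron separately and sum over neurons.
expected = sum(np.vstack([np.convolve(amps[n], row, mode='full')[:T]
                          for row in tmpl[n]])
               for n in range(N))

# Neurons are the input channels, recording channels are the output channels.
x_t = torch.tensor(amps).unsqueeze(0)                            # (1, N, T)
w_t = torch.tensor(tmpl).permute(1, 0, 2).flip(dims=(-1,))       # (C, N, D)
out = F.conv1d(x_t, w_t, padding=D - 1)[0, :, :T].numpy()        # (C, T)

print(out.shape)                   # (4, 60)
print(np.allclose(out, expected))  # True
```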

In [None]:
# Create a dataset with multiple channels and neurons.
T = 1000    # number of time samples
C = 10      # number of channels
D = 100     # duration of a spike (in samples)
N = 3       # multiple neurons

# Generate random templates, amplitudes, and noisy data.
# `templates` are NxCxD and `amplitudes` are NxT
rng = np.random.default_rng(seed=0)
templates, amplitudes, _ = generate(rng, T, C, D, N)

# Convolve the signal with each row of the multi-channel template
data = 0
for temp, amp in zip(templates, amplitudes):
    data += np.vstack([
        np.convolve(amp, row, mode='full')[:T]
        for row in temp])

plot_model(templates, amplitudes, data)

###
# Now perform the convolution using PyTorch's `conv1d` function.
# One call to `F.conv1d` should perform the sum over neurons for you.
#
# YOUR CODE BELOW
#
templates_t = to_t(templates)
amplitudes_t = to_t(amplitudes)

# Hint: permute the channels (out) and neurons (in) dimensions of the
# templates, and flip along the duration dimension.
# data_t = F.conv1d(...)
#
###
assert np.allclose(data, from_t(data_t))

### Problem 3 intro (f): Perform a 1d cross-correlation across multiple channels and neurons at once

Same as Problem 3 intro (d), but now we're performing the cross-correlation with multiple neurons' templates at once.
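
A minimal sketch on random toy arrays: here the `N x C x D` templates already have the `(out_channels, in_channels, D)` layout `conv1d` expects, so no permute or flip is needed, and the output has one row of scores per neuron.

```python
import numpy as np
import torch
import torch.nn.functional as F

rng = np.random.default_rng(6)
T, C, D, N = 60, 4, 9, 3
sim = rng.standard_normal((C, T))        # multi-channel data
tmpl = rng.standard_normal((N, C, D))    # one multi-channel template per neuron

# Reference: per neuron, correlate each channel and sum over channels.
expected = np.array([
    np.sum([np.correlate(sim[c], tmpl[n, c], mode='full')[D - 1:]
            for c in range(C)], axis=0)
    for n in range(N)])

x_t = torch.tensor(sim).unsqueeze(0)     # (1, C, T)
w_t = torch.tensor(tmpl)                 # (N, C, D): used as-is
score = F.conv1d(x_t, w_t, padding=D - 1)[0, :, D - 1:].numpy()   # (N, T)

print(score.shape)                    # (3, 60)
print(np.allclose(score, expected))   # True
```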

In [None]:
# We'll first perform the cross-correlation in numpy by correlating
# each row of the data with the corresponding row of each template and summing.
# Then you'll do the same thing in PyTorch using a single call to `F.conv1d`.
score = np.array([
    np.sum([np.correlate(data[c], templates[n, c], mode='full')[D-1:]
            for c in range(C)], axis=0)
    for n in range(N)])

plot_model(templates, amplitudes, data, scores=score)

###
# Now perform the cross-correlation using PyTorch's `conv1d` function.
# One call to `F.conv1d` should perform all cross-correlations for you.
#
# YOUR CODE BELOW
# score_t = F.conv1d(...)
#
#
###

assert np.allclose(score, from_t(score_t), atol=1e-4)

# Problem 3 Model: Modeling the Voltage Data with a Generative Model

We will use the same simulated data as in Problem 2. Make sure you use the same parameters as in Problem 2, with the same seed.

In [None]:
# Create a larger dataset with multiple channels and neurons.
T = 1000000  # number of time samples
C = 10      # number of channels
D = 81      # duration of a spike (in samples)
N = 5       # multiple neurons

# Generate random templates, amplitudes, and noisy data.
# `templates` are NxCxD and `amplitudes` are NxT
rng = np.random.default_rng(seed=1)
print("Simulating data. This could take a minute!")
true_templates, true_amplitudes, data = generate(rng, T, C, D, N)
plot_model(true_templates, true_amplitudes[:, :2000], data[:,:2000], lw=1, figsize=(12, 8))

In [None]:
# Generate a set of random templates and amplitudes to seed the model
rng = np.random.default_rng(seed=1)
templates = generate_templates(rng, C, D, N)
amplitudes = np.zeros((N, T))
noise_std = 1.0

# copy to the device
templates_t = to_t(templates)
amplitudes_t = to_t(amplitudes)
data_t = to_t(data)

## Problem 3(a): Compute the log likelihood

One of the most awesome features of PyTorch is its `torch.distributions` package. See the docs [here](https://pytorch.org/docs/stable/distributions.html). It contains objects for many of our favorite distributions, and has convenient functions for computing log probabilities (with `d.log_prob()` where `d` is a `Distribution` object), sampling (`d.sample()`), computing the entropy (`d.entropy()`), etc. These functions broadcast as you'd expect (unlike `scipy.stats`!), and they're designed to work with automatic differentiation.  More on that another day...

For now, you'll use `torch.distributions.Normal` to compute the log likelihood of the data given the templates and amplitudes, $\log p(Y \mid A, W)$. To do that, you'll convolve the amplitudes and templates (recall Problem 3 intro (e)) to get the mean value of $Y$, then use the `log_prob` function to evaluate the likelihood of the data.
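
A minimal sketch of the `log_prob` call on stand-in tensors: `pred_t` and `obs_t` below are random placeholders for the convolved prediction and the data, not the homework's variables. It checks the result against the closed-form Gaussian log density.

```python
import math
import torch
import torch.distributions
from torch.distributions import Normal

torch.manual_seed(0)
C, T = 3, 100
noise_std = 1.0
pred_t = torch.randn(C, T)                       # placeholder for the model mean
obs_t = pred_t + noise_std * torch.randn(C, T)   # placeholder for observed data

# Normal(loc, scale) broadcasts over the C x T array; log_prob is elementwise.
lp = Normal(pred_t, noise_std).log_prob(obs_t).sum()

# Same quantity from the closed-form Gaussian log density.
resid_t = obs_t - pred_t
manual = (-0.5 * resid_t**2 / noise_std**2
          - math.log(noise_std) - 0.5 * math.log(2 * math.pi)).sum()
print(torch.allclose(lp, manual, rtol=1e-4))   # True
```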

In [None]:
def log_likelihood(templates_t, amplitudes_t, data_t, noise_std):
    """Evaluate the log likelihood"""
    N, C, D = templates_t.shape
    _, T = data_t.shape

    ###
    # Compute the log probability
    #
    # YOUR CODE BELOW
    #
    # compute the model predictions by convolving the amplitude and templates
    # pred_t = F.conv1d(...)
    #
    # evaluate the log probability using torch.distributions.Normal
    # lp = ...
    #
    ###

    # return the log probability normalized by the data size
    return lp / (C * T)

ll = log_likelihood(templates_t, amplitudes_t, data_t, noise_std)


## Problem 3(b): Compute the residual

Next, compute the residual for a specified neuron by subtracting the convolved amplitudes and templates for all the other neurons. Again, recall Problem 3 intro (e).
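
A minimal sketch of the residual, assuming a hypothetical helper `predict` (not part of the notebook) that convolves amplitudes with templates and sums over neurons as in Problem 3 intro (e). It also checks an equivalent form: subtract the full prediction, then add back the neuron's own contribution.

```python
import torch
import torch.nn.functional as F

def predict(templates_t, amplitudes_t):
    """Hypothetical helper: sum over neurons of amplitudes * templates -> (C, T)."""
    _, _, D = templates_t.shape
    T = amplitudes_t.shape[1]
    w_t = templates_t.permute(1, 0, 2).flip(dims=(-1,))        # (C, N, D)
    return F.conv1d(amplitudes_t.unsqueeze(0), w_t, padding=D - 1)[0, :, :T]

torch.manual_seed(0)
N, C, D, T = 3, 4, 7, 50
templates_t = torch.randn(N, C, D, dtype=torch.float64)
amplitudes_t = torch.randn(N, T, dtype=torch.float64)
data_t = predict(templates_t, amplitudes_t) + 0.1 * torch.randn(C, T, dtype=torch.float64)

neuron = 1
not_n = torch.tensor([i for i in range(N) if i != neuron])     # all other neurons
residual_t = data_t - predict(templates_t[not_n], amplitudes_t[not_n])

# Equivalent: subtract everything, then add back this neuron's own contribution.
own_t = predict(templates_t[neuron:neuron + 1], amplitudes_t[neuron:neuron + 1])
alt_t = data_t - predict(templates_t, amplitudes_t) + own_t
print(torch.allclose(residual_t, alt_t))   # True
```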

In [None]:
def compute_residual(neuron, templates_t, amplitudes_t, data_t):
    N, C, D = templates_t.shape

    ###
    # Compute the predicted value of the data by
    # convolving the amplitudes and the templates for all
    # neurons except the specified one.
    #
    # YOUR CODE BELOW
    not_n = np.concatenate([np.arange(neuron), np.arange(neuron+1, N)])
    # pred_t = F.conv1d(...)
    #
    ###

    # return the data minus the predicted value given other neurons
    return data_t - pred_t

residual_t = compute_residual(0, templates_t, amplitudes_t, data_t)

## Problem 3(c): Compute the score

We defined the "score" for neuron $n$ to be the cross-correlation of the residual and its template. Compute it using `conv1d`. Recall Problem 3 intro (d).

In [None]:
def compute_score(neuron, templates_t, amplitudes_t, data_t):
    N, C, D = templates_t.shape
    T = data_t.shape[1]

    # first get the residual
    residual_t = compute_residual(neuron, templates_t, amplitudes_t, data_t)

    ###
    # Compute the 'score' by cross-correlating the residual
    # and the template for this neuron.
    #
    # YOUR CODE BELOW

    # score_t = F.conv1d(...)
    #
    ###
    return score_t

## Problem 3(d): Update the amplitudes using `find_peaks`

Our next step is to update the amplitudes given the scores. We'll use the simple heuristic described in lecture: find peaks in the score that are separated by at least $D$ samples and have a height of at least $\sigma^2 \lambda$, where $\sigma$ is the standard deviation of the noise and $\lambda$ is the amplitude rate hyperparameter. Use the `find_peaks` function from 1(b) to do this, and then update the amplitude tensor with your results.
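
A minimal sketch of how `scipy.signal.find_peaks` applies both constraints, on a hand-made toy score: peaks below the height $\sigma^2\lambda$ are dropped, and of two peaks closer than $D$ samples only the taller one survives.

```python
import numpy as np
from scipy.signal import find_peaks

D = 10
noise_std, amp_rate = 1.0, 5.0
score = np.zeros(100)
score[[20, 24, 60]] = [8.0, 6.0, 3.0]   # toy score: two nearby peaks and one small one

peaks, props = find_peaks(score, height=noise_std**2 * amp_rate, distance=D)
print(peaks)                  # [20]
print(props['peak_heights'])  # [8.]

# Write the peak heights into an otherwise-zero amplitude vector.
amps = np.zeros_like(score)
amps[peaks] = props['peak_heights']
```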

In [None]:
def _update_amplitude(neuron, templates_t, amplitudes_t, data_t, noise_std=1.0, amp_rate=5.0):
    N, C, D = templates_t.shape
    T = data_t.shape[1]

    # compute the score and convert it to a numpy array.
    score_t = compute_score(neuron, templates_t, amplitudes_t, data_t)
    score = from_t(score_t)

    ###
    # Find the peaks in the cross-correlation and update the amplitude tensor.
    #
    # YOUR CODE BELOW
    # peaks, props = find_peaks(...)
    ###

    heights = props['peak_heights']

    # Update the amplitude tensor for this neuron.
    amplitudes_t[neuron] = 0
    amplitudes_t[neuron, peaks] = to_t(heights)


## Problem 3(e): Update the templates
Our last step is to update the template for a given neuron by projecting $\overline{R}_n \in \mathbb{R}^{C \times D}$, the sum of scaled residuals at the times of spikes in the amplitudes:
\begin{align}
    \overline{R}_n = \sum_{t:a_{nt} > 0} a_{nt} R_{n,:,t:t+D}.
\end{align}
where $R_n \in \mathbb{R}^{C \times T}$ denotes the residual for neuron $n$.

In lecture we suggested a simple trick to implement this summation. First compute a matrix of regressors $X_n \in \mathbb{R}^{D \times T}$ whose $d$-th row contains the amplitudes lagged by $d-1$ samples. That is,
\begin{align}
x_{ndt} = a_{n,t-d+1}.
\end{align}
We can equivalently compute the regressor matrix by convolving the amplitudes with a "delay" matrix, which is just the $D \times D$ identity matrix.

Once we have the regressors, we can compute the sum of scaled residuals as $\overline{R}_n = R_n X_n^\top$.
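
A minimal sketch of the delay-matrix trick on toy data (random residual, hand-placed spikes). In 0-indexed form the regressors are $X[d, t] = a[t-d]$, and the single matrix product $R X^\top$ matches the explicit sum of amplitude-weighted residual windows.

```python
import numpy as np
import torch
import torch.nn.functional as F

rng = np.random.default_rng(4)
C, D, T = 3, 5, 40
amps = np.zeros(T)
amps[[3, 17, 30]] = [2.0, 1.5, 0.5]        # sparse amplitudes for one neuron
residual = rng.standard_normal((C, T))     # stand-in for R_n

# Regressors X (D x T): convolve the amplitudes with each row of the D x D
# identity ("delay" matrix) and keep the first T samples, so X[d, t] = a[t - d].
a_t = torch.tensor(amps).reshape(1, 1, T)
delay_t = torch.eye(D, dtype=torch.float64).flip(dims=(-1,)).unsqueeze(1)   # (D, 1, D)
X = F.conv1d(a_t, delay_t, padding=D - 1)[0, :, :T].numpy()                 # (D, T)

# Target via one matrix product ...
target = residual @ X.T                    # (C, D)

# ... equals the sum of amplitude-weighted residual windows R[:, t:t+D].
expected = sum(amps[t] * residual[:, t:t + D] for t in np.flatnonzero(amps > 0))
print(np.allclose(target, expected))       # True
```

All spikes here start at least $D$ samples before the end of the recording; for a spike closer to the end, the matrix product simply truncates the window.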

Finally, to get the template, project $\overline{R}_n$ onto $\mathcal{S}_K$, the set of rank-$K$, unit-norm matrices, using the SVD.
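
One way to sketch this projection, assuming the usual truncated-SVD recipe: keep the top $K$ singular triplets, then rescale to unit Frobenius norm. This sketch uses `torch.linalg.svd`, which returns $V^\top$ (like NumPy); the function name `project_rank_k_unit_norm` is hypothetical.

```python
import torch

def project_rank_k_unit_norm(target_t, rank=1):
    """Hypothetical helper: closest rank-`rank`, unit-Frobenius-norm matrix."""
    U, S, Vh = torch.linalg.svd(target_t, full_matrices=False)
    low_rank_t = (U[:, :rank] * S[:rank]) @ Vh[:rank, :]   # truncated SVD
    return low_rank_t / torch.linalg.norm(low_rank_t)      # rescale to unit norm

torch.manual_seed(0)
target_t = torch.randn(10, 81)              # toy C x D target
W_t = project_rank_k_unit_norm(target_t, rank=1)
print(W_t.shape)                            # torch.Size([10, 81])
print(torch.linalg.matrix_rank(W_t).item()) # 1
```

Normalizing after truncating is enough: among unit-norm rank-$K$ matrices, the one closest to the target maximizes the inner product with it, and that maximizer is the normalized rank-$K$ truncation.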

In [None]:
def _update_template(rng, neuron, templates_t, amplitudes_t, data_t, template_rank=1):
    N, C, D = templates_t.shape
    T = data_t.shape[1]


    # check if the factor is used. if not, generate a random new one.
    if amplitudes_t[neuron].sum() < 1:
        target_t = to_t(generate_templates(rng, C, D, 1)[0])

    else:
        ###
        # Make a DxT array of regressors for this neuron
        # by convolving its amplitude with a "delay" matrix;
        # i.e. a DxD identity matrix.
        #
        # YOUR CODE BELOW
        delay_t = to_t(np.eye(D))

        # regressors_t = F.conv1d(...)
        #
        ###

        # get the residual using the function you wrote above
        residual_t = compute_residual(neuron, templates_t, amplitudes_t, data_t)

        # compute the target (inner product of residual and regressors)
        target_t = residual_t @ regressors_t.T

    ###
    # Project the target onto the set of normalized rank-K templates using
    # `torch.svd` and `torch.norm`. Note that `torch.svd` returns V itself,
    # whereas `np.linalg.svd` returns V^T.
    #
    # YOUR CODE BELOW
    # ...
    # templates_t[neuron] = ...
    #
    ###



## Put it all together

That's it! We've written a little function to perform coordinate ascent using your `_update_*` functions. It tracks the log likelihood at each iteration (we're ignoring the priors for now). It also uses some nice progress bars so you can see how fast (or slow?) your code runs.

In [None]:
def map_estimate(rng,
                 templates_t,
                 amplitudes_t,
                 data_t,
                 num_iters=20,
                 template_rank=1,
                 noise_std=1.0,
                 amp_rate=5.0,
                 tol=1e-4):
    """Fit the templates and amplitudes by maximum a posteriori (MAP) estimation.
    """
    N, C, D = templates_t.shape

    # make a fancy reusable progress bar for the inner loops over neurons.
    outer_pbar = trange(num_iters)
    inner_pbar = trange(N)
    inner_pbar.set_description("updating neurons")

    # track log likelihoods over iterations
    lls = [from_t(log_likelihood(templates_t, amplitudes_t, data_t, noise_std=noise_std))]
    for itr in outer_pbar:
        inner_pbar.reset()
        for n in range(N):
            # update the amplitude
            _update_amplitude(n, templates_t, amplitudes_t, data_t, noise_std=noise_std, amp_rate=amp_rate)
            # update the template
            _update_template(rng, n, templates_t, amplitudes_t, data_t, template_rank=template_rank)
            inner_pbar.update()

        # compute the log likelihood
        lls.append(from_t(log_likelihood(templates_t, amplitudes_t, data_t, noise_std=noise_std)))

        # check for convergence
        if abs(lls[-1] - lls[-2]) < tol:
            print("Convergence detected!")
            break

    return np.array(lls)

## Fit the synthetic data and plot the log likelihoods

In [None]:
# Make random templates and set amplitude to zero
rng = np.random.default_rng(seed=1)
templates = generate_templates(rng, C, D, N)
amplitudes = np.zeros((N, T))
noise_std = 1.0     # \sigma
amp_rate = 5.0      # \lambda

# copy to the device
templates_t = to_t(templates)
amplitudes_t = to_t(amplitudes)
data_t = to_t(data)

# Fit the model.
lls = map_estimate(rng, templates_t, amplitudes_t, data_t, noise_std=noise_std, amp_rate=amp_rate)

# For comparison, compute the log likelihood with the true templates and amplitudes.
true_ll = from_t(log_likelihood(to_t(true_templates),
                                to_t(true_amplitudes),
                                data_t,
                                noise_std))

# Plot the log likelihoods
plt.plot(lls, '-o')
plt.hlines(true_ll, 0, len(lls) - 1, colors='r', linestyles=':', label="true LL")
plt.xlabel("Iteration")
plt.xlim(-.1, len(lls) - .9)
plt.ylabel("Log Likelihood")
plt.grid(True)
plt.legend(loc="lower right")

## Find a permutation of the inferred neurons that best matches the true neurons
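
A toy illustration of the matching step (the similarity values are made up): `linear_sum_assignment(..., maximize=True)` returns, for each true neuron (row), the inferred neuron (column) that maximizes the total similarity, so reordering the columns by `perm` puts the best matches on the diagonal.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

# Toy similarity: rows = true neurons, columns = inferred neurons.
similarity = np.array([[0.1, 0.9, 0.2],
                       [0.8, 0.1, 0.3],
                       [0.2, 0.3, 0.7]])
row_ind, perm = linear_sum_assignment(similarity, maximize=True)
print(perm)                          # [1 0 2]
print(np.diag(similarity[:, perm]))  # [0.9 0.8 0.7]
```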

In [None]:
# compute the similarity (inner product) of the true and inferred templates
templates = from_t(templates_t)
similarity = np.zeros((N, N))
for i in range(N):
    for j in range(N):
        similarity[i, j] = np.sum(true_templates[i] * templates[j])

# Show the similarity matrix
_, perm = linear_sum_assignment(similarity, maximize=True)
plt.imshow(similarity[:, perm], vmin=0, vmax=1)
plt.xlabel("inferred neuron (permuted)")
plt.ylabel("true neuron")
plt.title("cosine similarity of templates")
plt.colorbar()

## Plot the true and inferred templates

They should line up pretty well.

In [None]:
# Plot the true and inferred templates, permuted to best match
plot_templates(true_templates, np.arange(N), n_cols=N)
plot_templates(templates[perm], np.arange(N), n_cols=N, colors=('r',))

## Problem 3(f): Discussion

- How does this compare to the clustering (K-means) approach?

- What are the advantages of utilizing this kind of model?

*Answer below this line*

---

# Submission Instructions
- Print to PDF and download an .ipynb file.
- Submit both on Canvas (one per team).