In [None]:
import h5py
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scipy.stats as stats

from sklearn.decomposition import PCA
from sklearn.cluster import KMeans

In [None]:
# Read every dataset used by this notebook into in-memory NumPy arrays while
# the HDF5 file is open, so nothing below depends on the file handle.
with h5py.File('spike_waveform.hdf5', 'r') as _f:
    waveform = np.array(_f['waveform'])
    unit = np.array(_f['unit'])
    frequency = np.array(_f['frequency'])
    spike_train = np.array(_f['spike_train'])

# Time axis with one point per waveform sample (60 points, 30 on each side of 0).
# endpoint=False keeps consecutive points exactly 1/frequency apart; including
# the endpoint would spread only 59 intervals over 60 sample periods and
# stretch the axis by a factor of 60/59.
# NOTE(review): assumes `frequency` is the sampling rate in Hz (so the *1000
# converts seconds to milliseconds) and that each waveform holds 60 samples
# -- confirm against the file's contents.
tspec = np.linspace(-30/frequency, 30/frequency, 60, endpoint=False) * 1000

In [None]:
# Overlay every fifth waveform as a translucent black trace against tspec.
_fig, _ax = plt.subplots(figsize=(8, 6))
_ax.plot(tspec, waveform[::5].T, color='k', alpha=0.1)
plt.show()

In [None]:
# Fit a 5-component PCA on the waveforms and keep each row's component scores.
_pca = PCA(n_components=5)
_proj = _pca.fit_transform(waveform)

In [None]:
# Scatter the first two PCA score columns, with both axes fixed to [-0.3, 0.3].
_fig, _ax = plt.subplots(figsize=(6, 6))
_ax.scatter(_proj[:, 0], _proj[:, 1], s=3, alpha=0.5)
_ax.set_xlim(-0.3, 0.3)
_ax.set_ylim(-0.3, 0.3)
plt.show()

In [None]:
# Cluster the waveforms into 3 groups using their first three PCA score columns.
# random_state pins the centroid initialisation so the labels (and therefore
# the colours here and the panel order in later cells) are identical on every
# Restart & Run All. n_init is set explicitly because its library default
# differs between scikit-learn releases.
y_pred = KMeans(n_clusters=3, n_init=10, random_state=0).fit_predict(_proj[:, :3])

plt.figure(figsize=(8,8))

# One scatter call per cluster label so that each cluster gets its own colour
# from the default colour cycle.
for group in np.unique(y_pred):
    _subgroup = y_pred == group  # boolean mask selecting this cluster's rows
    plt.scatter(_proj[_subgroup, 0], _proj[_subgroup, 1], s=3, alpha=0.5)

plt.xlim((-0.3, 0.3))
plt.ylim((-0.3, 0.3))
plt.show()

In [None]:
# One panel per cluster label: the member waveforms as translucent black
# traces with the cluster's mean waveform drawn on top in red.
_labels = np.unique(y_pred)
group_n = _labels.size
_fig = plt.figure(figsize=(6 * group_n, 4))

for _panel, _label in enumerate(_labels, start=1):
    _ax = _fig.add_subplot(1, group_n, _panel)
    _members = waveform[y_pred == _label]  # rows assigned to this cluster
    _ax.plot(tspec, _members.T, color='k', alpha=0.1)
    _ax.plot(tspec, _members.mean(axis=0), color='r', linewidth=2)
    _ax.set_ylim(-0.2, 0.1)

plt.show()