In [None]:
import math
import os
from datetime import datetime
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objects as go
import scipy
import scipy.io
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

from functions import get_audio_onset_offset
from session_metadata import incorrect_trials

# Run timestamp; it is embedded in the name of every figure saved by this notebook.
current_datetime = datetime.now()
formatted_datetime = current_datetime.strftime("%Y%m%d_%H%M%S")

# Variables read from each .mat file; everything else in the file is ignored.
required_keys = ['spikepow', 'threshcross', 'predaudio16k', 'cue']
participant = 't15'
session = 'XX'
data_path = f'../data/{participant}/{session}/'
savepath_fig = f'../figures/{participant}_pca/'
# makedirs also creates a missing parent folder and does nothing if the folder already exists
os.makedirs(savepath_fig, exist_ok=True)
# os.listdir returns entries in arbitrary order; sort so the files are always processed in the same order
files = sorted(os.listdir(data_path))

# Analysis window around speech onset, counted in 10 ms bins.
nbins_before_onset = 10  # 100 ms before onset
nbins_after_onset = 40  # 400 ms after onset (the whole window is 50 bins = 500 ms)
n_channels = 256
amplitudes = ['MIME', 'WHISPER', 'NORMAL', 'LOUD']
words = ['be', 'my', 'know', 'do', 'have', 'going']
# One RGB triple (0-1 range) per entry of `amplitudes`, in the same order.
amplitude_color = [
    (167/255, 185/255, 207/255),
    (114/255, 159/255, 207/255),
    (53/255, 126/255, 221/255),
    (0, 79/255, 158/255),
]
# One RGB triple (0-1 range) per entry of `words`, in the same order.
word_color = [
    (0.9254902, 0.12156863, 0.14117647),
    (0.98431373, 0.72941176, 0.07058824),
    # (0.57254902, 0.78431373, 0.24313725),  # unused alternative, commented out in MATLAB
    (0.384, 0.682, 0.2),
    (0.43137255, 0.79607843, 0.85490196),
    # (0.26529412, 0.40686275, 0.72490196),  # unused alternative
    (0.45568627, 0.31764706, 0.63529412),  # active here (was commented out in the MATLAB version)
    (0.84705882, 0.2627451, 0.59215686),
]


In [None]:
# Load every .mat file of the session, drop the trials listed as incorrect for its block,
# and join the remaining trials of all files along the last axis.
data_chunks = {key: [] for key in required_keys}
n_files_loaded = 0
for file in files:
    _, extension = os.path.splitext(file)
    if extension != '.mat':
        continue

    full_path = str(Path(data_path, file).resolve())
    print(f'Loading {full_path} ...')
    data_temp = scipy.io.loadmat(full_path)
    n_files_loaded += 1

    # remove incorrect trials
    curr_block = int(np.squeeze(data_temp['block_number']))
    # NOTE(review): raises KeyError if this block has no entry in incorrect_trials -
    # confirm that every block of the session is listed there.
    remove_trials = incorrect_trials[participant][session][curr_block]  # trial ids, 1-indexed
    print(f'Removing trials {remove_trials}')
    remove_trial_inds = [i - 1 for i in remove_trials]  # trial indices, 0-indexed

    for key in required_keys:
        data_chunks[key].append(np.delete(data_temp[key], remove_trial_inds, axis=-1))

# Fail here with a clear message instead of a KeyError on data['cue'] further down.
if n_files_loaded == 0:
    raise FileNotFoundError(f'No .mat files found in {data_path}')

# One concatenation per key instead of re-copying the growing array for every file.
data = {key: np.concatenate(chunks, axis=-1) for key, chunks in data_chunks.items()}

print('Data loaded ...')
for key in data:
    print(key, data[key].shape)


In [None]:
# Cut a fixed window of spike power around speech onset for every trial and average the
# windows per word-amplitude condition.
cues = data['cue']
n_trials = len(cues)
# squeeze once here instead of on every loop iteration
spikepow_trials = np.squeeze(data['spikepow'])
audio_trials = np.squeeze(data['predaudio16k'])
n_window_bins = nbins_before_onset + nbins_after_onset

word_amp_trial_avg = []
word_label = []
amp_label = []
for word_idx, word in enumerate(words):
    for amp_idx, amp in enumerate(amplitudes):
        print(word, amp)
        # trials whose cue text contains both the amplitude and the word, and whose audio
        # maximum is non-zero
        inds = [k for k in range(n_trials)
                if amp in cues[k] and word in cues[k] and np.max(audio_trials[k]) != 0]

        windows = []
        for k in inds:
            spikepow = spikepow_trials[k]  # shape (time_bins x channels)
            onset_sample, _ = get_audio_onset_offset(audio_trials[k].squeeze())
            # sample index -> bin index: divide by the sampling rate (30 kHz), scale to ms
            # (x 1000), divide by the 10 ms bin width
            # NOTE(review): the audio key is named 'predaudio16k' but the onset is divided by
            # 30000 - confirm which sampling rate get_audio_onset_offset reports indices in.
            start_ind = math.floor((onset_sample / 30000) * (1000 / 10))

            win_start = start_ind - nbins_before_onset
            win_end = start_ind + nbins_after_onset
            # Keep only trials whose full window lies inside the recording. The explicit check
            # also stops a negative start from being read as an index from the end.
            if win_start < 0 or win_end > spikepow.shape[0]:
                continue
            windows.append(spikepow[win_start:win_end, :])

        if windows:
            # float64 so the trial average is accumulated in double precision
            x_spikepow = np.stack(windows, axis=0).astype(np.float64)
        else:
            x_spikepow = np.empty((0, n_window_bins, n_channels))
        print(x_spikepow.shape)

        if x_spikepow.shape[0] == 0:
            # The mean over zero trials is NaN, and the PCA fit rejects NaN input.
            print(f'WARNING: no usable trials for {word} {amp}; condition skipped')
            continue

        x_spikepow_mean = x_spikepow.mean(axis=0)  # avg across trials
        word_amp_trial_avg.append(x_spikepow_mean)
        word_label.append(word_idx)
        amp_label.append(amp_idx)

# shape (n_conditions x time_bins x channels)
amp_word_trial_avg = np.array(word_amp_trial_avg)
print(amp_word_trial_avg.shape)

In [None]:
# Expand the per-condition word / amplitude labels to the full shape of amp_word_trial_avg,
# so every element carries the label of the condition it belongs to.
cond_shape = amp_word_trial_avg.shape
word_ids = np.asarray(word_label[:cond_shape[0]], dtype=float).reshape(-1, 1, 1)
amp_ids = np.asarray(amp_label[:cond_shape[0]], dtype=float).reshape(-1, 1, 1)
# broadcasting the (n, 1, 1) labels against an array of ones fills each condition's slab
word_label_3d = word_ids * np.ones(cond_shape)
amp_label_3d = amp_ids * np.ones(cond_shape)
print(word_label_3d.shape, amp_label_3d.shape, amp_word_trial_avg.shape)

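PCA setup: n_samples = n_conditions × n_time_bins (each condition is a trial-averaged word–amplitude pair), n_features = n_channels. The projections are averaged across time before plotting.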

In [None]:
# Fit a PCA that keeps only the first 3 components (the ones that are plotted).
n_pca_components = 3
pca = PCA(n_components=n_pca_components)

# No averaging across time before the fit:
# n_samples == n_conditions x time bins, n_features == n_channels.
# Sizes are read from the array so nothing depends on a hard-coded channel count.
n_conditions, _, n_features = amp_word_trial_avg.shape
X = amp_word_trial_avg.reshape(-1, n_features)
print('PCA input shape:', X.shape)

# z-score every channel before the fit
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)

x_pca = pca.fit_transform(X_scaled)
print('PCA transformed shape:', x_pca.shape)

# keep the un-averaged projection (one row per condition and time bin)
x_pca_flatten_time_pca_on_ch = x_pca

# reshape it back to n_conditions x time x n_pca_components
x_pca = x_pca.reshape(n_conditions, -1, n_pca_components)
print('PCA reshaped back to n_trials x time x n_pca_feat:', x_pca.shape)
# average across time: one point per condition
x_pca = np.mean(x_pca, axis=1)
print('Average across time:', x_pca.shape)

x_ax = x_pca[:, 0]
y_ax = x_pca[:, 1]
z_ax = x_pca[:, 2]

print('PCA x_ax:', x_ax)
print('PCA y_ax:', y_ax)
print('PCA z_ax:', z_ax)


# Create 3D scatter plot, one marker per condition, coloured by amplitude.
# Plotly takes one colour per marker as a CSS colour string (or a scalar that is mapped
# through a colorscale), not an (n, 3) float array, so the 0-1 RGB triples are converted
# to 'rgb(r,g,b)' strings first.
marker_colors = []
for label in amp_label:
    r, g, b = (int(round(255 * channel)) for channel in amplitude_color[label])
    marker_colors.append(f'rgb({r},{g},{b})')

fig = go.Figure(data=[go.Scatter3d(
    x=x_ax, y=y_ax, z=z_ax,
    mode='markers',
    marker=dict(size=5, color=marker_colors, opacity=1),
)])

# Set layout options
fig.update_layout(
    title="Interactive 3D Scatter Plot",
    scene=dict(
        xaxis_title="PC1",
        yaxis_title="PC2",
        zaxis_title="PC3"
    )
)

# Show plot
fig.show()

### Plot PCA projections

In [None]:
# amplitude based color coding

# uncomment to use ipympl backend and allow interactive 3D plotting
# %matplotlib ipympl
# fig = plt.figure(figsize=(5, 5))
# fontsize = 10
# scatter_size = 100

# sizes for the large static figure
fontsize = 26
scatter_size = 1800

fig = plt.figure(figsize=(15, 15))
ax = fig.add_subplot(111, projection='3d')
amp_label_arr = np.asarray(amp_label)
ax.scatter(x_ax, y_ax, z_ax, color=np.array(amplitude_color)[amp_label_arr], s=scatter_size, alpha=0.8)
ax.set_xlabel('PC 1', fontsize=fontsize)
ax.set_ylabel('PC 2', fontsize=fontsize)
ax.set_zlabel('PC 3', fontsize=fontsize)
ax.set_xticks([])
ax.set_yticks([])
ax.set_zticks([])
ax.view_init(elev=-75, azim=118, roll=177)
ax.grid(False)

# Centroids of the MIME and LOUD conditions in PC space. The label values are looked up in
# `amplitudes` instead of being hard-coded, so they stay correct if the list is reordered.
pc_coords = np.column_stack([x_ax, y_ax, z_ax])
mime_inds = np.where(amp_label_arr == amplitudes.index('MIME'))[0]
mime_centroid = list(pc_coords[mime_inds].mean(axis=0))
loud_inds = np.where(amp_label_arr == amplitudes.index('LOUD'))[0]
loud_centroid = list(pc_coords[loud_inds].mean(axis=0))
# uncomment to draw the line joining the two centroids
# ax.plot3D([loud_centroid[0], mime_centroid[0]], [loud_centroid[1], mime_centroid[1]], [loud_centroid[2], mime_centroid[2]], 'black', linewidth = 7)

fig.tight_layout()
# save figure (on the figure object, so it does not depend on which figure is "current")
fig.savefig(f'{savepath_fig}{participant}_{session}_{formatted_datetime}_pca_amplitude.svg', format='svg', dpi=1200)
fig.savefig(f'{savepath_fig}{participant}_{session}_{formatted_datetime}_pca_amplitude.png', format='png')


In [None]:
# word based color coding: same 3D scatter of the PC projections, one colour per word
fig, ax = plt.subplots(figsize=(15, 15), subplot_kw={'projection': '3d'})
point_colors = np.array(word_color)[word_label]
ax.scatter(x_ax, y_ax, z_ax, color=point_colors, s=scatter_size, alpha=0.8)

# axis labels
for set_axis_label, label_text in (
    (ax.set_xlabel, 'PC 1'),
    (ax.set_ylabel, 'PC 2'),
    (ax.set_zlabel, 'PC 3'),
):
    set_axis_label(label_text, fontsize=fontsize)

# hide the tick marks on all three axes
for set_axis_ticks in (ax.set_xticks, ax.set_yticks, ax.set_zticks):
    set_axis_ticks([])

ax.view_init(elev=-75, azim=118, roll=177)
ax.grid(False)

fig.tight_layout()
# save figure
plt.savefig(f'{savepath_fig}{participant}_{session}_{formatted_datetime}_pca_word.svg', format='svg', dpi=1200)
plt.savefig(f'{savepath_fig}{participant}_{session}_{formatted_datetime}_pca_word.png', format='png')
