In [None]:
from dPCA import dPCA

import matplotlib.pyplot as plt
import numpy as np
import scipy.io as sio

import os

from src.tensor import centered_trial_average

## File Setup

In [None]:
# Files for dataset F147. The tensor path is resolved after moving into
# ../results/ and the trial path after moving into ../data/ (see the load cells).
F147 = dict(
    tensor='F147_tensor.npy',
    trial='2p_raw/F147/20210526_LT_18_0.mat',
)

In [None]:
# Files for dataset F201. Same layout as F147: a results-relative tensor
# file and a data-relative MATLAB trial file.
F201 = dict(
    tensor='F201_Z-Score.npy',
    trial='2p_raw/F201/20210812_RT_13_59.mat',
)

In [None]:
# Select which files to load: assign one of the file dictionaries (F147 or F201).
# Later cells read path['trial'] and path['tensor'] from this choice.
path = F147

In [None]:
# Trial metadata: step into the sibling `data` directory of the current
# working directory, then pull the `trial` variable out of the selected .mat file.
# Note: this changes the process working directory for all later cells.
os.chdir(os.path.join(os.pardir, 'data'))
trial_info = sio.loadmat(path['trial'])['trial']

In [None]:
# Data tensor: step into the sibling `results` directory of the current
# working directory (the process stays there afterwards) and load the
# selected .npy file.
os.chdir(os.path.join(os.pardir, 'results'))
tensor = np.load(path['tensor'])

## Stimulus Separation (dPCA)

In [None]:
# dPCA with labels='st' expects single-trial data shaped
# (trials, neurons, stimuli, time). The centered_trial_average call below
# treats axis 0 as trials and axis 1 as neurons, so the loaded tensor must be
# 3-D (trials, neurons, time). Check this up front so a wrongly shaped file
# fails here with a clear message instead of deep inside dPCA.
assert tensor.ndim == 3, (
    'expected a 3-D (trials, neurons, time) tensor, got shape %s' % (tensor.shape,)
)
# Insert a length-1 stimulus axis at position 2 (a single stimulus condition).
tensor_stim = np.expand_dims(tensor, axis=2)

In [None]:
# Collapse the single-trial tensor into the trial-averaged input `X` for dPCA.
# Axis roles are passed explicitly: neurons on axis 1, trials on axis 0.
# NOTE(review): centered_trial_average lives in src.tensor; presumably it
# averages over trials and mean-centres each neuron. Confirm against its source.
tensor_cta = centered_trial_average(tensor_stim, neuron_axis=1, trial_axis=0)

In [None]:
# Build the dPCA model over stimulus ('s') and time ('t') labels, keeping three
# components and letting dPCA choose the regularization strength ('auto').
dpca = dPCA.dPCA(labels='st', regularizer='auto', n_components=3)
# Axes listed in `protect` are left unshuffled by dPCA (here: time).
dpca.protect = ['t']

In [None]:
# Fit dPCA and project the trial-averaged data onto the demixed components.
# The result Z maps each marginalization label ('s', 't', 'st') to an array
# indexed [component, stimulus, time].
# NOTE: this relies on a locally patched dPCA. Upstream
# dPCA/python/dPCA/dPCA.py has a bug (around line 660) that makes
# train_test_split fail when there is only one stimulus condition, which is
# the case here (length-1 stimulus axis). The split is presumably reached
# through the regularizer='auto' cross-validation.
Z = dpca.fit_transform(tensor_cta, trialX=tensor_stim)

In [None]:
%matplotlib inline

time = np.arange(Z['t'].shape[2])
order = {0: '1st', 1: '2nd', 2: '3rd'}

for i in range(Z['t'].shape[0]):
    plt.figure(figsize=(16,7))
    plt.subplot(131)
    for s in range(Z['t'].shape[1]):
        plt.plot(time,Z['t'][i,s])
    plt.title(order[i] + ' time component')
    plt.subplot(132)
    for s in range(Z['t'].shape[1]):
        plt.plot(time,Z['s'][i,s])
    plt.title(order[i] + ' stimulus component')
    plt.subplot(133)
    for s in range(Z['t'].shape[1]):
        plt.plot(time,Z['st'][i,s])
    plt.title(order[i] + ' mixing component')
    plt.show()

In [None]:
%matplotlib qt

name = {'s': 'Stimulus', 't': 'Time', 'st': 'Mixed'}

for label in Z.keys():
    fig, ax = plt.subplots(subplot_kw={'projection': '3d'})
    ax.scatter(Z[label][0, 0], Z[label][1, 0], Z[label][2, 0],
               c=np.arange(Z[label].shape[2]), cmap='gist_rainbow')
    ax.set_xlabel('dPC1')
    ax.set_ylabel('dPC2')
    ax.set_zlabel('dPC3')
    ax.set_title(name[label])
    plt.show()