In [None]:
from dPCA import dPCA

import matplotlib.pyplot as plt
import numpy as np
import scipy.io as sio

import os

from src.decomposition_hyperparams import Hyperparams
from src.tensor import centered_trial_average

## Parameter Setup

In [None]:
# Hyperparameters for animal F147.
# Event times are indices along the tensor's time axis: the plotting cells use
# them both as x-positions on the dPC traces and as array indices into them.
_events = {
    'Laser On': 22,
    'Initial Turn': 166,
    'Laser Off': 184,
}
F147 = Hyperparams(name='F147')
F147.set_path(path='F147_tensor_zscore.npy')
F147.set_events(events_name=list(_events), events_time=list(_events.values()))

In [None]:
# Hyperparameters for animal F201.
# Event times are indices along the tensor's time axis: the plotting cells use
# them both as x-positions on the dPC traces and as array indices into them.
_events = {
    'Laser On': 22,
    'Initial Turn': 83,
    'Laser Off': 101,
}
F201 = Hyperparams(name='F201')
F201.set_path(path='F201_tensor_zscore.npy')
F201.set_events(events_name=list(_events), events_time=list(_events.values()))

In [None]:
# Choose the animal to analyse by name; every later cell reads `hyp`.
SUBJECT = 'F147'
hyp = {'F147': F147, 'F201': F201}[SUBJECT]

## Data Preparation

In [None]:
# Move to the results directory and load the data tensor.
# The chdir is guarded: an unconditional os.chdir('../results/') is relative to
# the *current* directory, so re-running this cell would try to enter
# '../results/' again from inside results/ and fail (or load the wrong file).
RESULTS_DIR = '../results/'
if os.path.basename(os.getcwd()) != 'results':
    os.chdir(RESULTS_DIR)
tensor = np.load(hyp.path)

In [None]:
# Insert a singleton stimulus axis at position 2, since this data set has
# only one stimulus condition. `None` indexing is equivalent to
# np.expand_dims(tensor, axis=2) and likewise returns a view.
tensor_stim = tensor[:, :, None]

In [None]:
# Collapse the trial axis (0) into a centered trial average; axis 1 holds the
# neurons. NOTE(review): exact centering semantics live in
# src.tensor.centered_trial_average — confirm there.
# The result is the trial-averaged X that dPCA.fit_transform consumes below.
tensor_cta = centered_trial_average(
    tensor_stim,
    neuron_axis=1,
    trial_axis=0,
)

## dPCA

In [None]:
# Configure dPCA over stimulus ('s') and time ('t').
# The 's' and 'st' marginalizations are merged into a single 's' group, and 3
# components are kept per group. regularizer='auto' leaves the choice of the
# regularization strength to the package (NOTE(review): presumably fitted
# internally from trialX — confirm in the dPCA docs).
dpca_kwargs = dict(
    labels='st',
    join={'s': ['s', 'st']},
    n_components=3,
    regularizer='auto',
)
dpca = dPCA.dPCA(**dpca_kwargs)

# NOTE(review): 'protect' presumably keeps the 't' axis from being shuffled in
# the package's shuffling procedures — confirm against dPCA source.
dpca.protect = ['t']

In [None]:
# Fit dPCA and project the trial-averaged data onto the demixed components.
#   X      = tensor_cta  (centered trial averages)
#   trialX = tensor_stim (single trials, used by the package's train/test split)
# NOTE: this requires a locally patched copy of dPCA. Upstream, a bug on line
# 660 of dPCA/python/dPCA/dPCA.py makes train_test_split fail when there is
# only a single stimulus, which is the case here.
Z = dpca.fit_transform(tensor_cta, trialX=tensor_stim)

In [None]:
# TODO: Plot the regularization scan — x: lambda, y: residual variance / total variance on the test set
# TODO: Use cross-validation

## dPC Visualization

In [None]:
# Display names for each dPCA marginalization label.
# 't' holds the time-only (condition-independent) variance. 's' is the join of
# the 's' and 'st' marginalizations (see the dPCA constructor), i.e. the
# stimulus-dependent variance. The old "Stimulus-Independent" label for 's'
# duplicated the meaning of 't'.
names = {'t': 'Condition-Independent', 's': 'Stimulus-Dependent'}

# Colours of matplotlib's active property cycle. Entry 0 is what the plain
# dPC traces pick up by default, so the event markers start from entry 1.
colors = [style['color'] for style in plt.rcParams['axes.prop_cycle']]

In [None]:
# Column order of the marginalizations in the 2-D dPC grid (column index -> label).
order = dict(enumerate(['t', 's']))

In [None]:
%matplotlib inline

# Create subplots for 2D plots
fig, axs = plt.subplots(nrows=3, ncols=2, sharex=True, sharey='col', figsize=(5.5, 5))

# Plot data
for row in range(axs.shape[0]):
    for col in range(axs.shape[1]):
        axs[row, col].plot(Z[order[col]][row, 0])
        
        # Add feature labels
        if row == 0:
            axs[row, col].set_title(names[order[col]])
        
        # Add component labels
        if col == 0:
            axs[row, col].set_ylabel("dPC " + str(row + 1))
            
        # Add units
        if row == axs.shape[0] - 1:
            axs[row, col].set_xlabel("Time")
        
        # Add event lines
        for i in range(len(hyp.events_time)):
            axs[row, col].axvline(hyp.events_time[i], c=colors[i + 1], label=hyp.events_name[i])

# Adjust subplot padding
fig.tight_layout()
            
# Add title and legend
fig.suptitle(hyp.name + " Demixed Principal Components (dPCs)", y=1.06)
fig.legend(*axs[0, 0].get_legend_handles_labels(), loc=1, fontsize='x-small')

# Display the plot
plt.show()

In [None]:
%matplotlib qt

# Generate 3D plots
for label in Z.keys():
    
    # Plot the points using the three dPCs as axes
    fig, ax = plt.subplots(subplot_kw={'projection': '3d'})
    ax.scatter(Z[label][0, 0], Z[label][1, 0], Z[label][2, 0],
               c=np.arange(Z[label].shape[2]), cmap='gist_rainbow', alpha=0.4)
    
    # Plot event points
    for i in range(len(hyp.events_time)):
        ax.scatter(Z[label][0, 0][hyp.events_time[i]], Z[label][1, 0][hyp.events_time[i]], Z[label][2, 0][hyp.events_time[i]],
                   s=144, c=colors[i + 1], marker='x', label=hyp.events_name[i])
    
    # Add title and labels
    ax.set_xlabel("dPC 1")
    ax.set_ylabel("dPC 2")
    ax.set_zlabel("dPC 3")
    ax.set_title(hyp.name + " " + names[label] + " Components")
    
    # Add legend
    plt.legend(*ax.get_legend_handles_labels(), loc=[1, 0.88], fontsize='small')
    
    # Display the plot
    plt.show()