In [None]:
import os

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import tensortools as tt

from src.decomposition_hyperparams import Hyperparams

## Parameter Setup

In [None]:
# --- F147 -----------------------------------------------------------------------
# Only rank 6 is fitted (range(6, 7)), with 5 replicates for each fitting method
# ('ncp_bcd' and 'ncp_hals'). The event times are drawn as vertical lines on the
# time factors, so they are presumably time-bin indices — TODO confirm units.
F147 = Hyperparams(name='F147')
F147.set_path(path='F147_tensor_minmax.npy')
F147.set_decomp_params(n_components=range(6, 7),
                       rep=5)
F147.set_decomp_methods(methods=['ncp_bcd',
                                 'ncp_hals'])
F147.set_events(events_name=['Laser On', 'Initial Turn', 'Laser Off'],
                events_time=[22, 166, 184])

In [None]:
# --- F201 -----------------------------------------------------------------------
# Same decomposition settings as F147 (rank 6 only, 5 replicates per method); only
# the file and the last two event times differ. Event times are presumably
# time-bin indices — TODO confirm units.
F201 = Hyperparams(name='F201')
F201.set_path(path='F201_tensor_minmax.npy')
F201.set_decomp_params(n_components=range(6, 7),
                       rep=5)
F201.set_decomp_methods(methods=['ncp_bcd',
                                 'ncp_hals'])
F201.set_events(events_name=['Laser On', 'Initial Turn', 'Laser Off'],
                events_time=[22, 83, 101])

In [None]:
# Active dataset (F147 or F201); every cell below reads its settings from `hyp`
hyp = F147

## Data Loading

In [None]:
# Load the data.
# The path is joined onto the results directory rather than using
# os.chdir('../results/'). chdir mutates process state, so re-running the cell would
# resolve '../results/' from inside results/ and fail.
# NOTE(review): the working directory is no longer changed — confirm nothing else
# depends on it.
RESULTS_DIR = '../results/'
tensor = np.load(os.path.join(RESULTS_DIR, hyp.path))

# Later cells index the tensor as (trial, neuron, time)
assert tensor.ndim == 3, f"expected a 3-D tensor, got shape {tensor.shape}"

## TCA (Tensor Component Analysis)

In [None]:
# One tensortools Ensemble per decomposition method, each fitted with `hyp.rep`
# replicates at every rank in `hyp.n_components`; keyed by method name.
ensembles = {}
for method in hyp.methods:
    ensembles[method] = ensemble = tt.Ensemble(fit_method=method)
    ensemble.fit(tensor, ranks=hyp.n_components, replicates=hyp.rep)

## Choosing Rank

In [None]:
# Objective vs. model rank for every decomposition method, overlaid on one set of axes.
# Each method's replicate points and line share one colour, and the line carries the
# method name, so the legend tells the methods apart.
fig, ax = plt.subplots(figsize=(3, 2))
for k, m in enumerate(hyp.methods):
    color = f'C{k}'
    tt.plot_objective(ensembles[m], ax=ax,
                      scatter_kw={'color': color},
                      line_kw={'color': color, 'label': m})

# Add a title and legend
ax.set_title(hyp.name + " TCA Error Plot")
ax.legend(fontsize='x-small')

# Display the plot
plt.show()

In [None]:
# Replicate similarity vs. model rank for every decomposition method, overlaid on one
# set of axes. Each method's points and line share one colour, and the line carries the
# method name, so the legend tells the methods apart.
fig, ax = plt.subplots(figsize=(3, 2))
for k, m in enumerate(hyp.methods):
    color = f'C{k}'
    tt.plot_similarity(ensembles[m], ax=ax,
                       scatter_kw={'color': color},
                       line_kw={'color': color, 'label': m})

# Add a title and legend
ax.set_title(hyp.name + " TCA Similarity Plot")
ax.legend(fontsize='x-small')

# Display the plot
plt.show()

In [None]:
# Rank of the model analysed below. It must be one of the ranks fitted above, otherwise
# the ensemble lookups in the following cells cannot succeed. Fail early with a clear
# message instead. (np.atleast_1d accepts either a single int or an iterable of ranks.)
rank = 6
assert rank in np.atleast_1d(hyp.n_components), (
    f"rank={rank} was not fitted; fitted ranks: {np.atleast_1d(hyp.n_components).tolist()}"
)

## Optimal Model

In [None]:
# Per-method best objectives and fitted factor models, collected in the next cell
best_obj, best_factors = [], []

In [None]:
# Best fit of each method at the chosen rank, index-aligned with hyp.methods.
# tensortools' Ensemble.fit sorts each rank's replicates by objective (lowest first),
# so index 0 is the best replicate.
# The lists are rebuilt rather than appended to, so re-running this cell does not
# accumulate duplicate entries.
best_obj = [ensembles[m].objectives(rank)[0] for m in hyp.methods]
best_factors = [ensembles[m].factors(rank)[0] for m in hyp.methods]

In [None]:
# Keep the factor model from whichever method reached the smallest objective
tensor_red = best_factors[int(np.argmin(best_obj))]

In [None]:
# Reorder the components in place so their weights (lambdas) decrease, then display
# the reordered weights
tensor_red.permute(np.argsort(tensor_red.component_lams())[::-1])
tensor_red.component_lams()

## TCA Component Visualization

In [None]:
# Panel titles for the three factor modes (same order as tensor_red.factors); each
# axis unit is the first word of its title
factors = ['Trial Factors', 'Neuron Factors', 'Time Factors']
units = [title.split()[0] for title in factors]

# Colours of the active matplotlib property cycle (used for the event markers)
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

In [None]:
%matplotlib inline

# Create subplots for 2D plots
fig, axs = plt.subplots(nrows=rank, ncols=3, sharex='col', sharey='col', figsize=(5.5, 1.2 * rank))

# Plot data
for row in range(axs.shape[0]):
    for col in range(axs.shape[1]):
        axs[row, col].plot(tensor_red.factors[col][:, row])
        
        # Add feature labels
        if row == 0:
            axs[row, col].set_title(factors[col])
        
        # Add component labels and dividing lines for trial factors
        if col == 0:
            axs[row, col].set_ylabel("Comp. " + str(row + 1))
            axs[row, col].axvline(tensor_red.factors[col][:, row].size // 2, c='k')
        
        # Add units
        if row == axs.shape[0] - 1:
            axs[row, col].set_xlabel(units[col])
        
        # Add event lines
        if col == 2:
            for i in range(len(hyp.events_time)):
                axs[row, col].axvline(hyp.events_time[i], c=colors[i + 1], label=hyp.events_name[i])

# Add title and legend
fig.suptitle(hyp.name + " TCA Components", y=1.01)
fig.legend(*axs[0, 2].get_legend_handles_labels(), loc=1, fontsize='x-small')
        
# Adjust subplot padding
plt.tight_layout()

# Display the plot
plt.show()

In [None]:
%matplotlib qt

# Create 3D plots for the first three TCA components
for factor in range(len(factors)):
    fig, ax = plt.subplots(subplot_kw={'projection': '3d'})
    
    # Plot points
    ax.scatter(tensor_red.factors[factor][:, 0],
               tensor_red.factors[factor][:, 1],
               tensor_red.factors[factor][:, 2],
               c=np.arange(tensor_red.factors[factor].shape[0]), cmap='gist_rainbow')
    
    # Add axis labels and a title
    ax.set_xlabel("Component 1")
    ax.set_ylabel("Component 2")
    ax.set_zlabel("Component 3")
    ax.set_title(factors[factor])
    
    # Display the plot
    plt.show()

## "Time Factors" Within Trials (Experimental)

In [None]:
# Project each trial's (neuron x time) slice onto every neuron factor:
#   time_factors_within_trials[i, k, t] = sum_n neuron_factors[n, i] * tensor[k, n, t]
# One vectorized einsum replaces the per-component, per-trial Python loop.
# The result is cast to float64 so its dtype does not depend on the input dtypes.
neuron_factors = tensor_red.factors[1][:, :rank]
time_factors_within_trials = np.einsum(
    'nr,knt->rkt', neuron_factors, tensor
).astype(np.float64, copy=False)
assert time_factors_within_trials.shape == (rank, tensor.shape[0], tensor.shape[2])

In [None]:
# Colormap for the per-trial lines in the next cell.
# `mpl` comes from the top import cell rather than a mid-notebook import.
cmap = mpl.colormaps['gist_rainbow']

In [None]:
%matplotlib inline

# Create plots of the factors within trials
fig, axs = plt.subplots(nrows=3, ncols=2, sharex=True, figsize=(5.5, 1.2 * rank))
for i in range(rank):
    
    # Location of the plot in the figure
    row, col = i % 3, i // 3
    
    # Randomly permute the order of trials and plot lines (makes the color distribution more even)
    for trial in np.random.permutation(np.arange(0, tensor.shape[0], 2)):
        axs[row, col].plot(time_factors_within_trials[i, trial], c=cmap(trial / tensor.shape[0], alpha=0.5))
        
        # Add component labels
        axs[row, col].set_ylabel("Comp. " + str(i + 1))
        
        # Add axis labels
        if row == 2:
            axs[row, col].set_xlabel("Time")

# Add plot title
fig.suptitle(hyp.name)

# Adjust subplot padding
plt.tight_layout()

# Display the plot
plt.show()