In [1]:
#!/usr/bin/env python
"""
run_sliceTCA.py

Load a .mat file containing a data tensor X of shape (N, T, K) -- neurons x
time points x trials -- and decompose it with sliceTCA into three kinds of
slice components (see Pellegrino et al., eq. (4)):

    X[n, t, k] ≈ Σ_{r=1}^{R_neuron} u_n^{(r)} A_{t,k}^{(r)}    (neuron-slicing)
               + Σ_{r=1}^{R_time}   v_t^{(r)} B_{n,k}^{(r)}    (time-slicing)
               + Σ_{r=1}^{R_trial}  w_k^{(r)} C_{n,t}^{(r)}    (trial-slicing)

Each term pairs a loading vector over one axis (u over neurons, v over time
points, w over trials) with a slice matrix over the remaining two axes.

Usage as a script:
    python run_sliceTCA.py <input.mat> <output.mat>

The input file must contain a variable named 'X'; the output file receives
U, A, V, B, W, C and the model reconstruction.
"""

'\nrun_sliceTCA.py\nLoads a .mat file with X (N×T×K) and runs sliceTCA:\n   X[n,t,k] ≈ ∑ₙᵣ uₙ^{(r)} A_{t,k}^{(r)}\n            + ∑ₜᵣ vₜ^{(r)} B_{n,k}^{(r)}\n            + ∑ₖʳ wₖ^{(r)} C_{n,t}^{(r)}\n(see Pellegrino et al. eq. (4)) :contentReference[oaicite:2]{index=2}&#8203;:contentReference[oaicite:3]{index=3}.\n'

In [2]:
import sys
import numpy as np
import scipy.io as sio
import torch
from slicetca import SliceTCA        # pip install slicetca_paper :contentReference[oaicite:4]{index=4}&#8203;:contentReference[oaicite:5]{index=5}

ModuleNotFoundError: No module named 'torch'

In [2]:
import tensortools as tt

In [None]:
# ————————————————
# Load data
# ————————————————
# Input path: taken from the command line when this file runs as a script
# (`python run_sliceTCA.py <input.mat> <output.mat>`). Inside a Jupyter kernel,
# sys.argv belongs to the kernel launcher (typically [..., '-f', '<connection file>']),
# so argv[1] is NOT a data path -- set INPUT_MAT_PATH by hand in that case.
if 'ipykernel' in sys.modules:
    INPUT_MAT_PATH = 'input.mat'   # TODO: point this at the .mat file holding X
elif len(sys.argv) >= 2:
    INPUT_MAT_PATH = sys.argv[1]
else:
    sys.exit("usage: run_sliceTCA.py <input.mat> <output.mat>")

mat = sio.loadmat(INPUT_MAT_PATH)
if 'X' not in mat:
    # loadmat also returns '__header__'-style metadata keys; hide them from the message.
    found = sorted(key for key in mat if not key.startswith('__'))
    raise KeyError(f"{INPUT_MAT_PATH!r} has no variable 'X' (found: {found})")
X = mat['X'].astype(np.float32)       # shape (N, T, K)
# MATLAB drops trailing singleton dimensions on save (e.g. a single trial), which
# would silently give a 2-D array here; sliceTCA needs all three axes.
if X.ndim != 3:
    raise ValueError(f"expected X to be 3-D (N, T, K), got shape {X.shape}")

In [None]:
# ————————————————
# Hyperparameters
# ————————————————
# Number of components (rank) for each slice type. These are illustrative
# starting values to be tuned for the dataset at hand, not fitted quantities.
R_neuron, R_trial, R_time = 2, 4, 1   # neuron-, trial-, time-slicing ranks

In [None]:
# ————————————————
# Initialize model
# ————————————————
# Seed torch so that any random parameter initialization (and hence the fitted
# components) is reproducible across Restart & Run All. torch.manual_seed seeds
# the CPU and CUDA generators.
SEED = 0
torch.manual_seed(SEED)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
# NOTE(review): X's axes are (N, T, K) = (neuron, time, trial), but R is given in
# (neuron, trial, time) order. If SliceTCA expects one rank per tensor axis in
# axis order, R_trial and R_time are swapped here -- confirm against the
# SliceTCA documentation before interpreting the components.
model = SliceTCA(
    R=(R_neuron, R_trial, R_time),
    nonneg=False,
    lr=0.02,
    device=device
)

In [None]:
# TODO: cross-validation could be added at this point by masking held-out
# blocks of X before fitting.
N_ITER = 300   # iteration budget handed to model.fit as n_iter
model.fit(X, n_iter=N_ITER)

In [None]:
# ————————————————
# Extract components
# ————————————————
U, A = model.get_neuron_slicing()   # U: (N, R_neuron), A: (T, K, R_neuron)
V, B = model.get_time_slicing()     # V: (T, R_time),   B: (N, K, R_time)
W, C = model.get_trial_slicing()    # W: (K, R_trial),  C: (N, T, R_trial)

# Verify the shapes documented above. A rank/axis mix-up (e.g. the trial rank
# ending up on the time axis) would otherwise only surface when the saved
# results are interpreted.
n_neurons, n_time, n_trials = X.shape
expected_shapes = {
    'U': (n_neurons, R_neuron), 'A': (n_time, n_trials, R_neuron),
    'V': (n_time, R_time),      'B': (n_neurons, n_trials, R_time),
    'W': (n_trials, R_trial),   'C': (n_neurons, n_time, R_trial),
}
for name, component in zip(expected_shapes, (U, A, V, B, W, C)):
    got = tuple(component.shape)
    assert got == expected_shapes[name], (
        f"{name}: got shape {got}, expected {expected_shapes[name]}"
    )

In [None]:
# ————————————————
# Save results
# ————————————————
# Output path: argv[2] when run as a script. Inside a Jupyter kernel, argv[2] is
# typically the kernel's own connection file, which savemat would open for
# writing and overwrite -- so a hand-set path is used there instead.
if 'ipykernel' in sys.modules:
    OUTPUT_MAT_PATH = 'sliceTCA_results.mat'   # TODO: choose where to write results
elif len(sys.argv) >= 3:
    OUTPUT_MAT_PATH = sys.argv[2]
else:
    sys.exit("usage: run_sliceTCA.py <input.mat> <output.mat>")


def _to_savemat_value(value):
    """Make `value` serializable by scipy.io.savemat.

    torch tensors (possibly on 'cuda' and/or attached to the autograd graph,
    since the model may have been built with device='cuda') are detached and
    copied to a host NumPy array. Any other value is returned unchanged.
    """
    if isinstance(value, torch.Tensor):
        return value.detach().cpu().numpy()
    return value


results = {
    'U': U, 'A': A,
    'V': V, 'B': B,
    'W': W, 'C': C,
    'reconstruction': model.reconstruct(),
}
sio.savemat(
    OUTPUT_MAT_PATH,
    {key: _to_savemat_value(value) for key, value in results.items()},
)
print("sliceTCA complete. Results saved to", OUTPUT_MAT_PATH)