In [2]:
# widen jupyter notebook window
from IPython.display import display, HTML
display(HTML("<style>.container {width:100% !important; }</style>"))

# check environment
import os
# .get() so the cell does not raise KeyError when the kernel is not running inside a conda env
print(f"Conda Environment: {os.environ.get('CONDA_DEFAULT_ENV', '<not set>')}")

Conda Environment: roicat


In [3]:
import copy
from pathlib import Path

In [4]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import tensorly as tl

In [5]:
%load_ext autoreload
%autoreload 2
import bnpm

In [6]:
# All tensorly operations below (CPTensor, CP_NN_HALS) run on torch tensors.
tl.set_backend(backend='pytorch')

Import face_rhythm TCA factors and spectrogram tensor

In [7]:
# Face-rhythm run directories: the template session supplies the fixed TCA factors,
# the current session is the one being refit. (Absolute local paths — edit per machine.)
directory_FR_template = r'/media/rich/bigSSD/analysis_data/face_rhythm/mouse_0322N/20230430/run_from_o2'
directory_FR_current = r'/media/rich/bigSSD/analysis_data/face_rhythm/mouse_0322N/20230502/jobNum_0'

In [8]:
# TCA results of the template session.
tca_template = bnpm.h5_handling.simple_load(
    str(Path(directory_FR_template).joinpath('analysis_files', 'TCA.h5'))
)

In [9]:
# TCA results of the current session (its own, independent fit).
tca_current = bnpm.h5_handling.simple_load(
    str(Path(directory_FR_current).joinpath('analysis_files', 'TCA.h5'))
)

In [10]:
# Pipeline parameters used for the template run (the TCA fit settings are reused below).
params_template = bnpm.file_helpers.json_load(
    str(Path(directory_FR_template).joinpath('params.json'))
)

In [11]:
# Device on which the data tensors (spectrogram `s`, loaded CP factors) are built.
# use_GPU=False, so this resolves to the CPU.
DEVICE_data = bnpm.torch_helpers.set_device(use_GPU=False)

device: 'cpu'


In [12]:
def cp_dict_to_cp_tensor(cp_dict, device='cpu'):
    """
    Convert a collection of CP factor matrices into tensorly's ``CPTensor`` format.

    Args:
        cp_dict (mapping or sequence of array-like):
            One factor matrix per mode, in mode order. For a mapping (e.g. a
            dict loaded from an HDF5 file) the values are used in iteration
            order and the keys are ignored; a plain list/tuple is used as-is.
        device (str or torch.device):
            Device on which to place the factor matrices.

    Returns:
        tl.cp_tensor.CPTensor:
            CP tensor built with ``weights=None`` and one float32 torch
            tensor per mode.
    """
    # Accept both dict-like containers and plain sequences of factors.
    factors_raw = cp_dict.values() if hasattr(cp_dict, 'values') else cp_dict
    factors = [torch.as_tensor(v, dtype=torch.float32, device=device) for v in factors_raw]
    return tl.cp_tensor.CPTensor((None, factors))

In [13]:
def make_cp_init(k_tensor, shape_dense_tensor, modes_fixed=(0, 1), device='cpu'):
    """
    Make a CPTensor for initializing a TCA run.

    The ``k_tensor`` factor matrices are used unchanged for each of the fixed
    modes, and are row-shuffled for each of the non-fixed modes so that every
    factor has ``shape_dense_tensor[i_mode]`` rows.

    Args:
        k_tensor (sequence of array-like):
            One factor matrix per mode; rows index positions along that mode.
        shape_dense_tensor (sequence of int):
            Shape of the dense tensor that will be decomposed.
        modes_fixed (iterable of int or None):
            Modes whose factors are copied as-is. ``None`` means no fixed modes.
        device (str or torch.device):
            Device on which to place the factor matrices.

    Returns:
        tl.cp_tensor.CPTensor:
            CP tensor built with ``weights=None`` and one float32 factor per mode.

    Raises:
        ValueError: if ``k_tensor`` and ``shape_dense_tensor`` have different
            numbers of modes, or if a fixed factor's row count differs from the
            dense tensor's size along that mode.
    """
    modes_fixed = set() if modes_fixed is None else set(modes_fixed)
    if len(k_tensor) != len(shape_dense_tensor):
        raise ValueError(
            f"k_tensor has {len(k_tensor)} modes but shape_dense_tensor has {len(shape_dense_tensor)}"
        )

    factors = []
    for i_mode, (factor_src, n_rows_target) in enumerate(zip(k_tensor, shape_dense_tensor)):
        factor = torch.as_tensor(factor_src, dtype=torch.float32, device=device)
        n_rows_src = factor.shape[0]

        if i_mode in modes_fixed:
            if n_rows_src != n_rows_target:
                raise ValueError(
                    f"Fixed mode {i_mode}: factor has {n_rows_src} rows but the dense tensor has size {n_rows_target}"
                )
            # clone: torch.as_tensor may return the caller's tensor itself, and the
            # init must not alias (or let the fit mutate) the template factors.
            factors.append(factor.clone())
            continue

        if n_rows_target <= n_rows_src:
            # Same as before: a random permutation of the first n_rows_target rows.
            idx = torch.randperm(n_rows_target)
        else:
            # Target mode is longer than the template (e.g. a longer session):
            # a permutation would index out of range, so sample rows with replacement.
            idx = torch.randint(n_rows_src, (n_rows_target,))
        factors.append(factor[idx.to(factor.device)])

    return tl.cp_tensor.CPTensor((None, factors))

In [14]:
def reconstruction_EV(tensor, cp):
    """
    Explained variance of ``tensor`` by the dense reconstruction of a CP decomposition.

    EV = 1 - var(tensor - reconstruction) / var(tensor)

    Args:
        tensor (torch.Tensor):
            Dense tensor that was decomposed.
        cp (tl.cp_tensor.CPTensor or sequence of torch.Tensor):
            CP decomposition. If it has a non-None ``weights`` attribute, the
            weights are folded into the first factor before reconstructing;
            a bare sequence of factors is treated as unweighted.

    Returns:
        torch.Tensor: 0-d tensor holding the explained variance.
    """
    factors = list(cp.factors) if hasattr(cp, 'factors') else list(cp)
    weights = getattr(cp, 'weights', None)
    if weights is not None:
        # Previously the weights were ignored, which is only correct when they are all 1.
        factors[0] = factors[0] * torch.as_tensor(
            weights, dtype=factors[0].dtype, device=factors[0].device
        )
    # NOTE(review): assumes kruskal_to_dense builds the dense tensor from an unweighted factor list.
    # Move to tensor's device (not hard-coded CPU) so a GPU `tensor` does not cause a device mismatch.
    tensor_rec = bnpm.indexing.kruskal_to_dense(factors).to(tensor.device)
    return 1 - (torch.var(tensor - tensor_rec) / torch.var(tensor))

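Make model: refit the current session's spectrogram with TCA, keeping the template session's factors fixed for `modes_fixed`.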

In [15]:
# Spectrogram outputs (VQT) of the current session.
spec_current = bnpm.h5_handling.simple_load(
    str(Path(directory_FR_current).joinpath('analysis_files', 'VQT_Analyzer.h5'))
)

In [16]:
## Prepare the current session spectrogram for refitting
### flatten the (xy points) dimension: merge the first two axes of the 4-D spectrogram
### into one, keeping the remaining two axes in order.
### A single C-order reshape gives exactly the same values as the previous
### transpose(2,3,0,1) -> reshape -> transpose(2,0,1) chain, and the result is contiguous.
spec_raw = spec_current['spectrograms']['0']
s = torch.tensor(  # torch.tensor always copies, so `s` never aliases spec_current
    spec_raw.reshape(spec_raw.shape[0] * spec_raw.shape[1], *spec_raw.shape[2:]),
    dtype=torch.float32,
    device=DEVICE_data,
)

In [17]:
## prepare tca factors into a tensorly CPTensor
cp_template = cp_dict_to_cp_tensor(tca_template['factors_rearranged']['0'], device=DEVICE_data)
cp_current = cp_dict_to_cp_tensor(tca_current['factors_rearranged']['0'], device=DEVICE_data)

In [18]:
# Device for the TCA fit; `s` is moved here before fitting. Kept separate from DEVICE_data,
# presumably so the fit can be switched to a GPU independently. use_GPU=False -> CPU.
DEVICE_tca = bnpm.torch_helpers.set_device(use_GPU=False)

device: 'cpu'


In [19]:
# Modes of `s` whose factors are taken from the template and held fixed during the refit
# (the first two modes; their meaning follows the flattening above — confirm axis semantics).
modes_fixed = [0, 1]

In [20]:
# make_cp_init shuffles template rows for every non-fixed mode with torch's global RNG;
# seed it so the initialization (and hence the refit) is reproducible on Restart & Run All.
SEED_INIT = 0
torch.manual_seed(SEED_INIT)
cp_init = make_cp_init(cp_template.factors, s.shape, modes_fixed=modes_fixed, device=DEVICE_tca)

In [21]:
# Deep-copy the template's TCA fit settings so edits below leave params_template untouched.
params_tca = copy.deepcopy(
    params_template['TCA']['fit']['params_method']
)

In [22]:
# Start the fit from the template-derived initialization instead of the template's `init` setting.
params_tca.update(init=cp_init)

In [23]:
# Merge into a single dict so that a `fixed_modes` entry inherited from the template params is
# overridden instead of raising "got multiple values for keyword argument 'fixed_modes'".
model_tca = tl.decomposition.CP_NN_HALS(
    **{**params_tca, 'fixed_modes': modes_fixed},
)

In [24]:
# Fit on the TCA device; the fitted result is stored on model_tca.decomposition_.
model_tca.fit(s.to(device=DEVICE_tca))

  return torch.tensor(


reconstruction error=0.3259415030479431
iteration 1, reconstruction error: 0.3259415030479431, decrease = 0.0
PARAFAC converged after 1 iterations


Rank-12 Non-Negative CP decomposition.

In [25]:
# Copy the fitted decomposition (weights and every factor) to the CPU.
cp_refit = tl.cp_tensor.CPTensor((
    model_tca.decomposition_.weights.cpu(),
    [factor.cpu() for factor in model_tca.decomposition_.factors],
))

In [26]:
EV_rec_refit = bnpm.similarity.cp_reconstruction_EV(
    tensor_dense=s,
    # Only factors are passed, so fold the CP weights into the first factor; otherwise the EV
    # is wrong whenever the fit returns non-unit weights (e.g. normalized factors).
    tensor_CP=[cp_refit.factors[0] * cp_refit.weights, *cp_refit.factors[1:]],
)

In [27]:
# Explained variance of `s` by the current session's own stored TCA factors (the baseline).
EV_rec_original = bnpm.similarity.cp_reconstruction_EV(tensor_dense=s, tensor_CP=cp_current.factors)

In [28]:
# EV of `s` explained by the current session's original TCA factors
EV_rec_original

tensor(0.7478)

In [29]:
# EV of `s` explained by the refit (template factors fixed for modes_fixed)
EV_rec_refit

tensor(0.7047)

In [38]:
# Factor names come from the template's factor dict. Its value order is the mode order used by
# cp_dict_to_cp_tensor, so zip pairs each name with the matching refit factor.
keys_factors = list(tca_template['factors_rearranged']['0'].keys())
assert len(keys_factors) == len(cp_refit.factors), (
    f"{len(keys_factors)} factor names but {len(cp_refit.factors)} refit factors"
)
tca_refit = {
    # numpy arrays / python floats so the HDF5 writer gets plain array and scalar data, not torch tensors
    'factors_refit': {key: val.detach().cpu().numpy() for key, val in zip(keys_factors, cp_refit.factors)},
    'modes_fixed': modes_fixed,
    'EV_rec_original': float(EV_rec_original),
    'EV_rec_refit': float(EV_rec_refit),
    'directory_template': directory_FR_template,
    'directory_current': directory_FR_current,
}

In [None]:
# `dir_save` was never defined in this notebook, so this cell raised NameError.
# NOTE(review): saving next to the current session's TCA outputs is an assumption — change if needed.
dir_save = str(Path(directory_FR_current) / 'analysis_files')
bnpm.h5_handling.simple_save(
    dict_to_save=tca_refit,
    path=str(Path(dir_save) / 'tca_refit.h5')
)