In [25]:
import numpy as np
from scipy.ndimage import gaussian_filter1d
from scipy.linalg import qr

import torch
import torch.nn.functional as F
import random
import scipy.ndimage
import matplotlib.pyplot as plt

device = "cuda:4" if torch.cuda.is_available() else "cpu"

# visual function

In [None]:
def visual_F_list(F_list):

    K = len(F_list)
    plt.figure(figsize=(5 * K, 3))
    for idx, f in enumerate(F_list):
        plt.subplot(1, K, idx + 1)
        plt.imshow(f.cpu(), cmap='coolwarm', interpolation='nearest')
        plt.title(f"F_{idx + 1}")
        plt.colorbar()
        plt.xlabel("Latent dim")
        plt.ylabel("Latent dim")

        p = f.shape[0]
        ticks = list(range(p))
        labels = list(range(1, p + 1))
        plt.xticks(ticks, labels)
        plt.yticks(ticks, labels)
        
    plt.suptitle("Sub-circuit dynamic matrices F_list")
    plt.show()


def visual_C(C):

    plt.figure(figsize=(10,3))
    for k in range(C.shape[0]):
        plt.plot(C[k].cpu().numpy(), label=f'c_{k+1}')
    plt.title("Sub-circuit coefficients C")
    plt.xlabel("Time")
    plt.legend()
    plt.show()

def visual_A(data):

    plt.figure(figsize=(6, 5))
    plt.imshow(data.detach().cpu().numpy(), aspect='auto', origin='lower', cmap='viridis')
    plt.colorbar(label='Projection weight')
    plt.title("Projection matrix A")
    plt.xlabel("Latent dimension (p)")
    plt.ylabel("Neuron #")
    plt.show()


def visual_X(data):

    X = data.detach().cpu().numpy()
    plt.figure(figsize=(10,3))
    for dim in range(X.shape[0]):
        plt.plot(X[dim], label=f'x_{dim+1}')
    plt.title("Latent dynamics X")
    plt.xlabel("Time")
    plt.legend()
    plt.show()

def visual_Y(data):
    Y = data.detach().cpu().numpy()

    plt.figure(figsize=(10,5))
    plt.imshow(Y, aspect='auto', origin='lower', cmap='viridis')
    plt.colorbar(label='Firing rate')
    plt.title("Synthetic observations Y")
    plt.xlabel("Time")
    plt.ylabel("Neuron #")
    plt.show()

# optimal

## a and x

In [None]:
def compute_h(Y):

    Y_norm = Y / (Y.norm(dim=1, keepdim=True) + 1e-8)
    h = Y_norm @ Y_norm.T
    h = (h + 1.0) / 2.0
    return h


def similarity_loss(a, h):
    D = torch.diag(h.sum(dim=1))
    L = D - h
    loss = torch.trace(a.T @ L @ a)
    return loss


def update_a_and_x(Y, X, a, C, F_list, num_iter, lr_a, lr_x, lambda_sparse_A, lambda_sim_A, lambda_dyn_X, epoch):

    N, T = Y.shape
    p = X.shape[0]

    h = compute_h(Y)

    A_var = a.clone().detach().requires_grad_(True)
    X_var = X.clone().detach().requires_grad_(False)

    optimizer_a = torch.optim.Adam([A_var], lr=lr_a)

    for i in range(num_iter):

        Y_hat = A_var @ X_var
        loss_mse = F.mse_loss(Y, Y_hat)

        loss_sparse = lambda_sparse_A * torch.norm(A_var, p=1)

        loss_sim = lambda_sim_A * similarity_loss(A_var, h)

        loss_a = loss_mse + loss_sparse + loss_sim

        if epoch % 10 == 0 and i == 0:
            print(f"[Epoch {epoch}] "
                  f"loss_rec_y={loss_mse.item():.6f}, "
                  f"loss_sparse_a={loss_sparse.item():.6f}, "
                  f"loss_sim_a={loss_sim.item():.6f}")

        optimizer_a.zero_grad()
        loss_a.backward()
        optimizer_a.step()

        with torch.no_grad():
            A_var.clamp_(min=0.0, max=1.0)
            col_norms = torch.norm(A_var, p=2, dim=0, keepdim=True)
            col_norms = torch.clamp(col_norms, min=1e-8)
            A_var.div_(col_norms)
    
    next_a = A_var.detach()


    A_var = next_a.clone().detach().requires_grad_(False)
    X_var = X.clone().detach().requires_grad_(True)
    C_var = C.clone().detach().requires_grad_(False)
    
    optimizer_x = torch.optim.Adam([X_var], lr=lr_x)

    F_all = torch.stack(F_list, dim=0).clone().detach().requires_grad_(False)

    for i in range(num_iter):

        Y_hat = A_var @ X_var
        loss_rec_y = F.mse_loss(Y, Y_hat)

        F_t_all = torch.einsum('kt,kij->tij', C_var, F_all)
        X_t1_pred = torch.einsum('tij,jt->it', F_t_all[:-1, :, :], X[:, :-1])                   
        loss_dyn_x = lambda_dyn_X * F.mse_loss(X_var[:, 1:], X_t1_pred)

        loss_x = loss_rec_y + loss_dyn_x

        if epoch % 10 == 0 and i == 0:
            print(f"[Epoch {epoch}] loss_rec_y={loss_rec_y.item():.6f}")

        optimizer_x.zero_grad()
        loss_x.backward()
        optimizer_x.step()

    next_x = X_var.detach()

    return next_a, next_x


## c and f

In [None]:
def decorrelation_loss(F_var: torch.Tensor, eps: float = 1e-8,):

    K = F_var.shape[0]
    flat = F_var.reshape(K, -1)                
    flat = flat / (flat.norm(dim=1, keepdim=True).clamp_min(eps))

    G = flat @ flat.t()

    idx = torch.triu_indices(K, K, offset=1, device=F_var.device)
    vals = (G[idx[0], idx[1]] ** 2)

    return vals.sum()


def update_c_and_f(X, C, F_list, num_iter, lr_c, lr_f, lambda_sparse_c, lambda_smooth_c, lambda_sparse_f, lambda_decor_f, epoch):
    p, _ = X.shape
    K, T = C.shape
    
    F_all = torch.stack(F_list, dim=0)

    C_var = C.clone().detach().requires_grad_(True)
    F_var = F_all.clone().detach().requires_grad_(False)

    optimizer_c = torch.optim.Adam([C_var], lr=lr_c)

    for i in range(num_iter):

        # compute dynamics loss
        F_t_all = torch.einsum('kt,kij->tij', C_var, F_var)
        X_t1_pred = torch.einsum('tij,jt->it', F_t_all[:-1, :, :], X[:, :-1])

        loss_dyn_c = F.mse_loss(X[:, 1:], X_t1_pred)

        loss_sparse_c = lambda_sparse_c * torch.norm(C_var, p=1, dim=0).mean()

        diff = C_var[:, 1:] - C_var[:, :-1]  # [K, T-1]
        loss_smooth_c = lambda_smooth_c * diff.abs().mean()

        loss_c = loss_dyn_c + loss_sparse_c + loss_smooth_c

        optimizer_c.zero_grad()
        loss_c.backward()
        optimizer_c.step()

        if epoch % 10 == 0 and i == 0:
            print(f"[Epoch {epoch}] loss_dyn_c={loss_dyn_c.item():.6f}, "
                  f"loss_sparse_c={loss_sparse_c.item():.6f}, "
                  f"loss_smooth_c={loss_smooth_c.item():.6f}"
                  )

        with torch.no_grad():
            C_var.clamp_(min=0.0, max=1.0)
            col_sums = C_var.sum(dim=0, keepdim=True)  # [1, T]
            col_sums = torch.clamp(col_sums, min=1e-8)
            C_var.div_(col_sums)

    C_opt = C_var.detach()


    C_var = C_opt.clone().detach().requires_grad_(False)
    F_var = F_all.clone().detach().requires_grad_(True)

    optimizer_f = torch.optim.Adam([F_var], lr=lr_f)

    for i in range(num_iter):

        F_t_all = torch.einsum('kt,kij->tij', C_var, F_var)
        X_t1_pred = torch.einsum('tij,jt->it', F_t_all[:-1, :, :], X[:, :-1])

        loss_dyn_f = F.mse_loss(X[:, 1:], X_t1_pred)

        loss_sparse_f = lambda_sparse_f * torch.norm(F_var, p=1)

        loss_decor_f = lambda_decor_f * decorrelation_loss(F_var)

        if epoch % 10 == 0 and i == 0:
            print(f"[Epoch {epoch}] loss_dyn_f={loss_dyn_f.item():.6f}, "
                  f"loss_sparse_f={loss_sparse_f.item():.6f}, "
                  f"loss_decor_f={loss_decor_f.item():.6f}")

        loss_f = loss_dyn_f + loss_sparse_f + loss_decor_f

        optimizer_f.zero_grad()
        loss_f.backward()
        optimizer_f.step()

        with torch.no_grad():
            for k in range(K):
                eigvals = torch.linalg.eigvals(F_var[k])
                max_abs_eig = eigvals.abs().max()
                F_var[k] /= max_abs_eig

    F_opt = [F_var[i].detach() for i in range(K)]

    return C_opt, F_opt


# main

In [None]:
def main(Y, N, K, p, T, epoch_num, warmup_num):

    torch.manual_seed(0)
    F_list = [torch.eye(p).to(device) + 0.1 * torch.randn(p, p).to(device) for _ in range(K)]
    C = torch.rand(K, T).to(device)
    a = torch.rand(N, p, requires_grad=False, device=device)
    X = torch.randn(p, T, requires_grad=False, device=device)
    X[:, 0] = torch.tensor([1, -1, 1])

    for epoch in range(warmup_num):
        a, X = update_a_and_x(Y, X, a, C, F_list, num_iter=20, lr_a=1e-3, lr_x=1e-2, lambda_sparse_A=0.005, lambda_sim_A=0.001, lambda_dyn_X=0.0, epoch=epoch)

    for epoch in range(warmup_num):
        C, F_list = update_c_and_f(X, C, F_list, num_iter=20, lr_c=1e-2, lr_f=1e-3, lambda_sparse_c=0.05, lambda_smooth_c=0.08, lambda_sparse_f=0.001, lambda_decor_f=0.001, epoch=epoch)
            

    # --- joint optimization ---
    for epoch in range(epoch_num):

        a, X = update_a_and_x(Y, X, a, C, F_list, num_iter=2, lr_a=1e-3, lr_x=1e-2, lambda_sparse_A=0.0, lambda_sim_A=0.0, lambda_dyn_X=0.0001, epoch=epoch)

        C, F_list = update_c_and_f(X, C, F_list, num_iter=2, lr_c=1e-2, lr_f=1e-3, lambda_sparse_c=0.05, lambda_smooth_c=0.08, lambda_sparse_f=0.001, lambda_decor_f=0.001, epoch=epoch)


    data = {
        'C': C,
        'X': X,
        'A': a,
        'Y': Y,
        'F_list': F_list
    }
    return data


N = 21         
p = 3           
K = 3           
T = 500     

data_true = torch.load("./data/Three_Task_Synthetic_Data_train.pt", weights_only=True)

Y = data_true['Y']

data_est = main(Y, N, p, K, T, epoch_num=50, warmup_num=200)

# visual_F_list(data_est["F_list"])
# visual_C(data_est["C"])
# visual_A(data_est["A"])
# visual_X(data_est["X"])
# visual_Y(data_est["A"] @ data_est["X"])

[Epoch 0] loss_rec_y=1.184915, loss_sparse_a=0.160320, loss_sim_a=0.071039
[Epoch 0] loss_rec_y=0.243432
[Epoch 10] loss_rec_y=0.045665, loss_sparse_a=0.057226, loss_sim_a=0.012231
[Epoch 10] loss_rec_y=0.042825
[Epoch 20] loss_rec_y=0.016001, loss_sparse_a=0.055982, loss_sim_a=0.012611
[Epoch 20] loss_rec_y=0.015095
[Epoch 30] loss_rec_y=0.014100, loss_sparse_a=0.051011, loss_sim_a=0.016797
[Epoch 30] loss_rec_y=0.014145
[Epoch 40] loss_rec_y=0.014407, loss_sparse_a=0.043628, loss_sim_a=0.022128
[Epoch 40] loss_rec_y=0.014570
[Epoch 50] loss_rec_y=0.015362, loss_sparse_a=0.038035, loss_sim_a=0.025108
[Epoch 50] loss_rec_y=0.015499
[Epoch 60] loss_rec_y=0.015899, loss_sparse_a=0.033725, loss_sim_a=0.026966
[Epoch 60] loss_rec_y=0.015801
[Epoch 70] loss_rec_y=0.015524, loss_sparse_a=0.032215, loss_sim_a=0.027406
[Epoch 70] loss_rec_y=0.015576
[Epoch 80] loss_rec_y=0.015867, loss_sparse_a=0.031622, loss_sim_a=0.027772
[Epoch 80] loss_rec_y=0.015878
[Epoch 90] loss_rec_y=0.016094, loss_sp

# compute MSE

In [None]:
def compute_losses(X, C, F_list, a, data_true):

    true_X = data_true['X']
    true_C = data_true['C']
    true_F_list = data_true['F_list']
    true_a = data_true['A']
    true_Y = data_true['Y']

    Y_hat = a @ X
    loss_Y = torch.mean((Y_hat - true_Y) ** 2)

    loss_X = torch.mean((X - true_X) ** 2)

    loss_C = torch.mean((C - true_C) ** 2)

    loss_F = 0.0
    for f_est, f_true in zip(F_list, true_F_list):
        loss_F += torch.mean((f_est - f_true) ** 2)

    loss_a = torch.mean((a - true_a) ** 2)

    return {
        'loss_Y': loss_Y,
        'loss_X': loss_X,
        'loss_C': loss_C,
        'loss_F': loss_F,
        'loss_a': loss_a
    }

def save_data(re_F_list, re_A):
    data = {
        'F_list': re_F_list,
        'A': re_A,
    }
    torch.save(data, "./data/Three_Task_Synthetic_Data_A_F.pt")


re_F_list = data_est["F_list"]
re_C = data_est["C"]
re_A = data_est["A"]
re_X = data_est["X"]
re_Y = data_est["A"] @ data_est["X"]

order = torch.tensor([1, 2, 0], device=re_A.device)
re_A = re_A[:, order]
re_X = re_X[order, :]

save_data(re_F_list, re_A)

results_mse = compute_losses(re_X, re_C, re_F_list, re_A, data_true)
print(results_mse)

{'loss_Y': tensor(0.0052, device='cuda:4'), 'loss_X': tensor(0.0095, device='cuda:4'), 'loss_C': tensor(0.6071, device='cuda:4'), 'loss_F': tensor(0.0075, device='cuda:4'), 'loss_a': tensor(0.0041, device='cuda:4')}


# compute_p

In [32]:
def p(X, C, F_list, a, data_true):

    def corrcoef(a, b):
        a_flat = a.flatten()
        b_flat = b.flatten()
        return np.corrcoef(a_flat, b_flat)[0, 1]

    results = {}

    Y_hat = a @ X

    A_true = data_true['A'].cpu().numpy()
    A_est = a.cpu().numpy()
    results['corr_A'] = corrcoef(A_true, A_est)

    corr_F = []
    F_true_list = data_true['F_list']
    F_est_list = F_list
    for f_true, f_est in zip(F_true_list, F_est_list):
        f_true = f_true.cpu().numpy()
        f_est = f_est.cpu().numpy()
        corr = corrcoef(f_true, f_est)
        corr_F.append(corr)
    results['corr_F_list'] = corr_F

    C_true = data_true['C'].cpu().numpy()
    C_est = C.cpu().numpy()
    results['corr_C'] = corrcoef(C_true, C_est)

    X_true = data_true['X'].cpu().numpy()
    X_est = X.cpu().numpy()
    results['corr_X'] = corrcoef(X_true, X_est)

    Y_true = data_true['Y'].cpu().numpy()
    Y_est = Y_hat.cpu().numpy()
    results['corr_Y'] = corrcoef(Y_true, Y_est)

    return results

results_p = p(re_X, re_C, re_F_list, re_A, data_true)
print(results_p)

{'corr_A': 0.9402853807650996, 'corr_F_list': [0.9992513474937725, 0.9930902077186291, 0.9909606149523752], 'corr_C': -0.4778150230736332, 'corr_X': 0.9925176803867974, 'corr_Y': 0.9710753126681247}
