This notebook converts Ben's sgm4fMRI MATLAB code to Python.

It also includes a testing section.

In [1]:
import sys
sys.path.append("../../mypkg")

In [2]:
from constants import RES_ROOT, FIG_ROOT, DATA_ROOT
from utils.misc import load_pkl, save_pkl, get_ball_cor
from utils.colors import qual_cmap
from utils.measures import reg_R_fn

In [3]:
%load_ext autoreload
%autoreload 2
# 0,1, 2, 3, be careful about the space

In [4]:
import numpy as np
import scipy
from scipy.io import loadmat
from scipy import signal
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm, trange
import bct # the pkg to get graph features
from joblib import Parallel, delayed
from easydict import EasyDict as edict
import pandas as pd

plt.style.use(FIG_ROOT/'base.mplstyle')

In [5]:
import logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

# Only attach a console handler when none is reachable yet, so re-running
# this cell does not stack duplicate handlers (and duplicate log lines).
if not logger.hasHandlers():
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    ch = logging.StreamHandler()  # console output
    ch.setFormatter(formatter)
    ch.setLevel(logging.DEBUG)
    logger.addHandler(ch)

# Fns and params

## Some fns

In [6]:
def _preprocess_ts(ts):
    """preprocessing, 
        1. detrend
        2. resample (len of seq has 235 or 555, I make it consistent to 235)
    """
    ts = signal.detrend(ts);
    if ts.shape[-1] > 235:
        ts = signal.resample(ts, num=235, axis=-1)
    return ts
    
def _get_fc(ts, is_fisher=True):
    """Get FC from ts, including 
        1. Pearsons'r 
        2. fisher transform
        3. abs value
    """
    fc = np.corrcoef(ts)
    fc = fc - np.diag(np.diag(fc))
    return fc

    
def _load_data(n):
    """Sub idx is from 1, n: the sub idx
    """
    return mat_data['ts_321_273'][np.where(mat_data['subj_321_ts'][:, 0] == n)[0], :].transpose()

## Load data and params

In [7]:
# Load the fMRI time series of all subjects.
mat_data = loadmat(DATA_ROOT / "ad_ftd_hc_fmri_data.mat")

# some parameters
num_rois = 246  # number of regions in the Brainnetome atlas (BNA)
num_sps = 321   # presumably the number of subjects (cf. 'subj_321_ts'); confirm
# Structural-connectivity (SC) template matrix.
SC_mat = loadmat(DATA_ROOT / "SC_HC_BN_template_nature_order.mat")["SC_template"]

In [106]:
sub_ix = 4                          # subject index, 1-based (see _load_data)
eps = 1e-10                         # small guard against division by zero
ts = _load_data(sub_ix)[:num_rois]  # keep the first num_rois rows (ROIs)
data = _preprocess_ts(ts).T         # num of time pts x num of ROIs

1e-10

In [107]:
# Get Laplacian and eigenmodes of the structural connectivity (SC).
# L = I - diag(1/(sqrt(rd)+eps)) @ SC @ diag(1/(sqrt(cd)+eps)), i.e. a
# normalized Laplacian with row sums (rd) and column sums (cd) kept separate;
# eps avoids dividing by zero for ROIs without connections.
SC = SC_mat.copy()
SC = SC/np.sum(SC)  # scale so that all entries sum to 1
cd, rd = SC.sum(axis=0), SC.sum(axis=1);
L = np.eye(num_rois) - np.diag(1/(np.sqrt(rd)+eps))@SC@np.diag(1/(np.sqrt(cd)+eps));
# NOTE(review): np.linalg.eig does not assume symmetry and may return complex
# eigenvalues/eigenvectors. If SC is symmetric, np.linalg.eigh would give real,
# orthonormal eigenvectors. Confirm SC is symmetric.
ev, U = np.linalg.eig(L);
# Order the eigenmodes by ascending |eigenvalue|.
sorted_idx = np.argsort(np.abs(ev)) # ascending
ev = ev[sorted_idx]
U = U[:, sorted_idx];

In [108]:
params = edict(
    TR=2,                    # repetition time; sampling rate is fs = 1/TR
    fband=[0.008, 0.08],     # frequency band [low, high] used for fvec and the low-pass
    pwelch_windows=[],
    costtype="corr",         # FX cost: "corr" or "mse" (see _myfun_FX)
    perc_thresh=False,       # percentile-threshold the empirical FC (perc_thresh not defined here)
    eig_weights=True,        # weight eigenmodes by their projection onto the empirical FC
    deconvHRF=False,         # HRF deconvolution (deconv_HRF not defined here)
    is_ann=False,            # True: dual annealing, False: scipy.optimize.minimize
    model_focus="FX",        # "FX", "FC" or "both"
    fitmean=False,           # fit the ROI-averaged spectrum instead of each ROI
    theta=[],
)

In [109]:
num_pts, num_rois = data.shape
if num_pts < 64:
    logger.warning(f"Not enough timepoints ({num_pts}) for a good FFT; "
                   f"therefore SGM is only fitting to FC.")
    params.model_focus = "FC"
# nfft must be defined in every branch: even when only FC is fitted, fvec is
# still needed to compute the PSD and, from it, omega for the FC model.
nfft = 64 if num_pts < 128 else 128

# Frequency grid on which model and empirical spectra are compared.
fvec = np.linspace(params.fband[0], params.fband[1], nfft)
omegavec = 2 * np.pi * fvec  # angular frequencies
fs = 1/params.TR             # sampling rate

In [110]:
import scipy.signal as signal
# Preprocessing time series (data is time points x ROIs, so time is axis 0)
# demean each ROI
input_data = data - data.mean(axis=0, keepdims=True);
# detrend along the time axis
input_data = signal.detrend(input_data, axis=0);
# 5th-order Butterworth low-pass at the upper band edge params.fband[1].
# NOTE(review): sosfilt is a single forward (causal) pass and introduces a
# phase delay; sosfiltfilt would be zero-phase. Check which the MATLAB
# original uses.
sos = signal.butter(N=5, Wn=params.fband[1], btype="low", fs=fs, output="sos")
input_data = signal.sosfilt(sos, input_data, axis=0);

if params.deconvHRF:
    # deconv_HRF is not defined in this notebook and its intended definition
    # is unclear, so enabling this flag raises a NameError.
    input_data = deconv_HRF(input_data)

In [111]:
# Get empirical FC
# Pearson correlation between ROI time series (ROIs are the columns of
# input_data, hence the transpose); the diagonal is set to 0.
emp_fc = np.corrcoef(input_data.T)
np.fill_diagonal(emp_fc, 0);

if params.perc_thresh:
    # perc_thresh is not defined in this notebook (see Ben's GitHub code);
    # enabling this flag raises a NameError.
    emp_fc = perc_thresh(emp_fc)
    
# Make it symmetric from the upper triangle: strict upper part plus the
# transpose of the upper part including the diagonal (0 after fill_diagonal).
# For the plain correlation matrix this is a no-op; presumably it matters only
# if perc_thresh produced an asymmetric matrix.
emp_fc =  np.triu(emp_fc, 1) + np.triu(emp_fc).T;

In [112]:
# Get PSD
from scipy.interpolate import interp1d
def obt_psd_at_freqs(psd_raw, f, fvec):
    """Smooth a PSD in the dB domain and read it off at the frequencies ``fvec``.

    Steps:
        1. Convert to dB: ``10*log10(psd_raw + 1e-10)``; the offset avoids log(0).
        2. Smooth with the normalized 5-tap kernel [1, 2, 5, 2, 1] / 11 using
           ``np.convolve(..., 'same')``. That mode zero-pads, so the first and
           last two bins are pulled toward 0 dB.
        3. Linearly interpolate the smoothed dB curve (sampled at ``f``) at ``fvec``.
        4. Convert back to linear scale.

    Parameters:
        psd_raw (1-D array-like): PSD estimate in linear scale (e.g. from Welch).
        f (1-D array-like): frequencies at which ``psd_raw`` is sampled.
        fvec (array-like): query frequencies; values outside the range of ``f``
            make ``interp1d`` raise ValueError.

    Returns:
        ndarray: smoothed PSD at ``fvec``, in linear scale (not dB).
    """
    floor = 1e-10
    kernel = np.array([1, 2, 5, 2, 1])
    kernel = kernel / np.sum(kernel)
    smoothed_db = np.convolve(10 * np.log10(psd_raw + floor), kernel, 'same')
    interpolant = interp1d(f, smoothed_db)
    return 10 ** (interpolant(fvec) / 10)
# Welch PSD of every ROI (time along axis 0). Each ROI's spectrum is then
# smoothed and sampled on fvec; the square root turns power into an amplitude
# spectrum (linear scale, not dB). PSD has shape (len(fvec), num_rois).
f, Pxx = signal.welch(input_data, fs=fs, nperseg=64, axis=0)
PSD = np.sqrt(
    np.array([obt_psd_at_freqs(Pxx[:, roi], f, fvec) for roi in range(num_rois)]).T
)

# Peak frequency of each ROI, averaged over ROIs and converted to angular
# frequency; this single omega is what the FC model is evaluated at.
f_at_max = fvec[PSD.argmax(axis=0)]
omega = 2 * np.pi * f_at_max.mean()

In [113]:
omega

0.24553051921714494

In [114]:
# Weight eigenmode k by |u_k^T emp_fc u_k| (or uniformly), then switch off the
# first mode (smallest |eigenvalue| after the sort above).
# NOTE(review): U.T is used here while the forward models use the conjugate
# transpose; the two agree only if U is real.
ev_weight = (np.abs(np.diag(U.T @ emp_fc @ U)) if params.eig_weights
             else np.ones(num_rois))
ev_weight[0] = 0

In [115]:
def _forward_FC(theta_star):
    """Model FC of the spectral graph model, evaluated at the single frequency ``omega``.

    Checked with matlab code.

    Args:
        theta_star: sequence ``(alpha_raw, tau)``; ``alpha = tanh(alpha_raw)``.

    Returns:
        ``(num_rois, num_rois)`` array ``U diag(w) U^H`` scaled on both sides
        by ``1/(1e-4 + sqrt(diag))``, i.e. a covariance-to-correlation style
        normalization.

    Reads module-level ``omega``, ``ev``, ``U`` and ``ev_weight``.
    """
    alpha = np.tanh(theta_star[0])
    tau = theta_star[1]
    # He(omega) = 1/tau^2 / (1/tau + i*omega)^2
    He = 1/tau**2/(1/tau+omega*1j)**2
    # per-eigenmode response, built from the Laplacian eigenvalues ev
    newev = 1/(1j*omega + 1/tau*He*(1-alpha*(1-ev)));
    # squared magnitude of each mode, weighted by ev_weight
    newev = (np.abs(newev))**2 * ev_weight;
    # back to ROI space: U @ diag(newev) @ U^H
    out_fc = U @ (newev.reshape(-1, 1) * np.conjugate(U).T);
    # normalize by sqrt of the diagonal; 1e-4 guards against a zero diagonal
    dg = 1/(1e-4+np.sqrt(np.diag(out_fc)));
    out_fc = out_fc * dg.reshape(-1, 1) * dg.reshape(1, -1)
    return out_fc
def _myfun_FC(theta_star):
    out_fc = _forward_FC(theta_star);
    kp_idxs = np.where(np.triu(out_fc, 1) != 0);
    r = np.corrcoef(out_fc[kp_idxs], emp_fc[kp_idxs])[0, 1]
    err = np.abs(1-r)
    return err, r

In [116]:
def _forward_FX(theta_star):
    """Model response of every ROI at every frequency in ``omegavec``.

    Args:
        theta_star: sequence ``(alpha_raw, tau)``; ``alpha = tanh(alpha_raw)``.

    Returns:
        Complex array of shape ``(len(omegavec), num_rois)``: frequencies along
        rows and ROIs along columns, the same layout as ``PSD``.

    Reads module-level ``omegavec``, ``ev``, ``U`` and ``ev_weight``.
    """
    alpha = np.tanh(theta_star[0])
    tau = theta_star[1]
    # He(omega) = 1/tau^2 / (1/tau + i*omega)^2, one value per frequency
    He = 1/tau**2/(1/tau+1j*omegavec)**2;
    tmp_vec = 1j * omegavec;
    # tmp_mat[k, j] = (1/tau) * (1 - alpha*(1 - ev[k])) * He[j]  (modes x freqs)
    tmp_mat = (1/tau*(1-alpha*(1-ev))).reshape(-1, 1) * He.reshape(1, -1)
    # frequency_response[k, j] = ev_weight[k] / (tmp_mat[k, j] + i*omega[j])
    frequency_response = ev_weight.reshape(-1, 1)/(tmp_mat + tmp_vec.reshape(1, -1));
    
    # U^H applied to an all-ones input vector: one coefficient per eigenmode
    UtP = U.conj().T @ np.ones(ev.shape[0]);
    # back to ROI space, then transpose to (frequencies x ROIs)
    out_fx = (U@(frequency_response * UtP[:, np.newaxis])).T;
    return out_fx

# theta_star = [0.2, 1]
# omegavec1 = np.array([1, 2, 3, 4]);
# ev1 = np.array([1, 2, 3])
# ev_weight1 = np.array([3, 4, 5])
# U1 = np.arange(1, 10).reshape(3, 3)
# U2 = U1+4
# UU = U1 + U2*1j
# _forward_FX(theta_star, omegavec1, ev_weight1, ev1, UU)
def minmax_fn(x, byrow=False):
    """Min-max scale ``x`` to the range [0, 1].

    Args:
        x: 1-D or 2-D array-like.
        byrow: for 2-D input, scale each row independently when True and each
            column independently when False (default). Ignored for 1-D input,
            which is scaled as a whole.

    Returns:
        Float array with the same shape as ``x``. A slice whose max equals its
        min yields NaN (0/0 division).

    Raises:
        ValueError: if ``x`` is not 1-D or 2-D.
    """
    x = np.asarray(x)
    if x.ndim == 1:
        lo, hi = x.min(), x.max()
    elif x.ndim == 2:
        # Row-wise scaling reduces over columns (axis 1) and vice versa.
        axis = 1 if byrow else 0
        lo = x.min(axis=axis, keepdims=True)
        hi = x.max(axis=axis, keepdims=True)
    else:
        raise ValueError(f"minmax_fn expects a 1-D or 2-D array, got {x.ndim}-D")
    return (x - lo) / (hi - lo)

def _myfun_FX(theta_star):
    out_fx = _forward_FX(theta_star);
    
    if params.fitmean:
        qdata = np.abs(PSD.mean(axis=1))[np.newaxis];
        qmodel = np.abs(out_fx.mean(axis=1))[np.newaxis];
        rvec = reg_R_fn(qdata, qmodel)
    else:
        qdata = np.abs(PSD).T
        qmodel = np.abs(out_fx).T
        rvec = reg_R_fn(qdata, qmodel)
        
        
    if params.costtype.lower().startswith("corr"):
        errvec = np.abs(1-rvec)
            
    elif params.costtype.lower().startswith("mse"):
        qdata = minmax_fn(qdata, byrow=True)
        qmodel = minmax_fn(qmodel, byrow=True)
        errvec = np.mean((qdata-qmodel)**2, axis=1)
        
    rvec[np.isnan(rvec)] = 0
    errvec[np.isnan(errvec)] = 0
    return np.nanmean(errvec), rvec

In [117]:
def _myfun_both(theta_star):
    err_FC, r_FC = _myfun_FC(theta_star)
    err_FX, rvec = _myfun_FX(theta_star)
    r_FX = np.nanmean(rvec)
    
    err_b = err_FC + err_FX
    return err_b, r_FC, r_FX

In [118]:
maxiter = 1000
theta0 = [0.5, 1]   # initial (alpha_raw, tau); the models use alpha = tanh(alpha_raw)
ll = [0.1, 0.1]     # lower bounds for (alpha_raw, tau)
ul = [10, 5]        # upper bounds for (alpha_raw, tau)
# Derive the optimizer bounds from ll/ul so the two cannot drift apart.
bds = [[lo, hi] for lo, hi in zip(ll, ul)]

# Choose the objective from params.model_focus (case-insensitive prefix).
if params.model_focus.lower().startswith("both"):
    obj_fn1 = _myfun_both
elif params.model_focus.lower().startswith("fc"):
    obj_fn1 = _myfun_FC
elif params.model_focus.lower().startswith("fx"):
    obj_fn1 = _myfun_FX
else:
    # Fail here rather than with a NameError inside the optimizer.
    raise ValueError(
        f"Unknown params.model_focus {params.model_focus!r}; "
        f"expected 'both', 'FC' or 'FX'."
    )


def obj_fn(x):
    """Scalar objective for the optimizers: the error term returned by obj_fn1."""
    return obj_fn1(x)[0]

In [119]:
from scipy.optimize import dual_annealing, minimize

In [120]:
# Global search with dual annealing when params.is_ann is set; otherwise a
# bounded local minimize (the fit_res output shows it ran L-BFGS-B).
if params.is_ann:
    fit_res = dual_annealing(obj_fn, x0=theta0, bounds=bds, maxiter=maxiter)
else:
    fit_res = minimize(obj_fn, x0=theta0, bounds=bds, options={"maxiter": maxiter})

# Model FC and model spectrum (magnitude) at the fitted parameters.
model_fc = _forward_FC(fit_res.x)
model_psd = np.abs(_forward_FX(fit_res.x))

In [121]:
params

{'TR': 2,
 'fband': [0.008, 0.08],
 'pwelch_windows': [],
 'costtype': 'corr',
 'perc_thresh': False,
 'eig_weights': True,
 'deconvHRF': False,
 'is_ann': False,
 'model_focus': 'FX',
 'fitmean': False,
 'theta': []}

In [122]:
fit_res

      fun: 0.28685200426413276
 hess_inv: <2x2 LbfgsInvHessProduct with dtype=float64>
      jac: array([-6.56141808e-06, -1.14352972e-06])
  message: 'CONVERGENCE: NORM_OF_PROJECTED_GRADIENT_<=_PGTOL'
     nfev: 51
      nit: 10
     njev: 17
   status: 0
  success: True
        x: array([0.48911888, 2.39446794])

In [123]:
_forward_FX([0.5, 1])

array([[0.43013009+0.00960486j, 0.40650881+0.00723866j,
        0.49101263+0.0103583j , ..., 0.25268883+0.00936384j,
        0.37688604+0.01061218j, 0.3628753 +0.01465389j],
       [0.43057321+0.01027282j, 0.40694624+0.00773636j,
        0.49152557+0.01107689j, ..., 0.25291379+0.01002685j,
        0.37725892+0.01135781j, 0.36318667+0.01569379j],
       [0.43104739+0.01093809j, 0.40741433+0.00823085j,
        0.49207447+0.01179222j, ..., 0.25315453+0.01068972j,
        0.37765796+0.01210208j, 0.36351987+0.01673397j],
       ...,
       [0.79825266-0.21075211j, 0.75182024-0.26697848j,
        0.91511109-0.26245697j, ..., 0.48652774+0.01386596j,
        0.73679084-0.10096268j, 0.7025108 +0.06868837j],
       [0.8009927 -0.22091763j, 0.75336961-0.27818884j,
        0.91811097-0.27466266j, ..., 0.49079407+0.01088866j,
        0.7419641 -0.10854173j, 0.70951479+0.06567877j],
       [0.80346801-0.23131289j, 0.75459525-0.28963032j,
        0.92079222-0.28714568j, ..., 0.49507116+0.00777289j,
 