Skip to content

Commit

Permalink
Feat: Function to re-normalise mixing coefs with correlations.
Browse files Browse the repository at this point in the history
  • Loading branch information
RukuangHuang committed May 16, 2024
1 parent 4620e6f commit d6f9cdb
Showing 1 changed file with 37 additions and 11 deletions.
48 changes: 37 additions & 11 deletions osl_dynamics/inference/modes.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,20 +565,46 @@ def reweight_alphas(alpha, covs):
reweighted_alpha : list of np.ndarray or np.ndarray
Re-weighted mixing coefficients. Shape is the same as :code:`alpha`.
"""
if isinstance(alpha, np.ndarray):
alpha = [alpha]
return reweight_mtc(alpha, covs, "covariance")

# Calculate normalised alphas
traces = np.trace(covs, axis1=1, axis2=2)
reweighted_alpha = [a * traces[np.newaxis, :] for a in alpha]
reweighted_alpha = [
na / np.sum(na, axis=1, keepdims=True) for na in reweighted_alpha
]

if len(reweighted_alpha) == 1:
reweighted_alpha = reweighted_alpha[0]
def reweight_mtc(mtc, params, params_type):
"""Re-weight mixing coefficients to account for the magnitude of
observation model parameters.
Parameters
----------
mtc : List[np.ndarray] or np.ndarray
Raw mixing coefficients. Shape must be (n_sessions, n_samples, n_modes)
or (n_samples, n_modes).
params : np.ndarray
Observation model parameters. Shape must be (n_modes, n_channels, n_channels).
params_type : str
Observation model parameters type. Either 'covariance' or 'correlation'.
Returns
-------
reweighted_mtc : List[np.ndarray]
Re-weighted mixing coefficients. Shape is the same as :code:`mtc`.
"""
if isinstance(mtc, np.ndarray):
mtc = [mtc]

if params_type == "covariance":
weights = np.trace(params, axis1=1, axis2=2)
elif params_type == "correlation":
m, n = np.tril_indices(params.shape[-1], -1)
weights = np.sum(np.abs(params[:, m, n]), axis=-1)
else:
raise ValueError("params_type must be 'covariance' or 'correlation'.")

reweighted_mtc = [x * weights[np.newaxis, :] for x in mtc]
reweighted_mtc = [x / np.sum(x, axis=1, keepdims=True) for x in reweighted_mtc]

if len(reweighted_mtc) == 1:
reweighted_mtc = reweighted_mtc[0]

return reweighted_alpha
return reweighted_mtc


def average_runs(alpha, n_clusters=None, return_cluster_info=False):
Expand Down

0 comments on commit d6f9cdb

Please sign in to comment.