From 46010b4c4dfa7c9df9d6e9421e7b6917adbae7ea Mon Sep 17 00:00:00 2001 From: Chetan Gohil Date: Mon, 13 May 2024 21:10:50 +0100 Subject: [PATCH] Feat: helper function to plot HMM summary stats. --- osl_dynamics/utils/plotting.py | 86 +++++++++++++++++++++++++++++++++- 1 file changed, 85 insertions(+), 1 deletion(-) diff --git a/osl_dynamics/utils/plotting.py b/osl_dynamics/utils/plotting.py index 1fc23b8a..125445ff 100644 --- a/osl_dynamics/utils/plotting.py +++ b/osl_dynamics/utils/plotting.py @@ -916,7 +916,7 @@ def plot_violin( # Plot violins x = np.concatenate([[x_] * len(y) for x_, y in zip(x, data)]) y = data.flatten() - ax = sns.violinplot(x=x, y=y, ax=ax, **sns_kwargs) + ax = sns.violinplot(x=x, y=y, hue=x, ax=ax, legend=False, **sns_kwargs) # Set title and axis labels ax.set_title(title) @@ -2012,6 +2012,90 @@ def plot_psd_topo( return fig, ax +def plot_hmm_summary_stats( + fo, + lt, + intv, + sr, + filename, + cmap="tab10", + fig_kwargs=None, + sns_kwargs=None, +): + """Plot summary statistics (FO, LT, INTV, SR). + + Parameters + ---------- + fo : np.ndarray + Fractional occupancies. Shape much be (n_subjects, n_states). + lt : np.ndarray + Mean lifetimes. Shape much be (n_subjects, n_states). + intv : np.ndarray + Mean intervals. Shape much be (n_subjects, n_states). + sr : np.ndarray + Switching rates. Shape much be (n_subjects, n_states). + filename : str + Output filename. + cmap : str, optional + Matplotlib colormap. + fig_kwargs : dict, optional + Arguments to pass to :code:`plt.subplots()`. + sns_kwargs : dict, optional + Arguments to pass to :code:`sns.violinplot()`. + """ + if fig_kwargs is None: + fig_kwargs = {} + + if sns_kwargs is None: + sns_kwargs = {} + + n_states = fo.shape[1] + x = range(1, n_states + 1) + + sns_kwargs.update( + { + "inner": "quart", + "cut": 0, + "palette": cmap, + } + ) + + fig, ax = create_figure(nrows=1, ncols=4, figsize=(15, 3)) + plot_violin( + fo.T, + x=x, + x_label="State", + y_label="Fractional Occupancy", + ax=ax[0], + sns_kwargs=sns_kwargs, + ) + plot_violin( + lt.T, + x=x, + x_label="State", + y_label="Mean Lifetime (s)", + ax=ax[1], + sns_kwargs=sns_kwargs, + ) + plot_violin( + intv.T, + x=x, + x_label="State", + y_label="Mean Interval (s)", + ax=ax[2], + sns_kwargs=sns_kwargs, + ) + plot_violin( + sr.T, + x=x, + x_label="State", + y_label="Switching rate (Hz)", + ax=ax[3], + sns_kwargs=sns_kwargs, + ) + save(fig, filename=filename, tight_layout=True) + + def plot_summary_stats_group_diff( name, summary_stats,