Skip to content

Commit

Permalink
Feat: helper function to plot HMM summary stats.
Browse files Browse the repository at this point in the history
  • Loading branch information
cgohil8 committed May 13, 2024
1 parent 59d5b81 commit 46010b4
Showing 1 changed file with 85 additions and 1 deletion.
86 changes: 85 additions & 1 deletion osl_dynamics/utils/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -916,7 +916,7 @@ def plot_violin(
# Plot violins
x = np.concatenate([[x_] * len(y) for x_, y in zip(x, data)])
y = data.flatten()
ax = sns.violinplot(x=x, y=y, ax=ax, **sns_kwargs)
ax = sns.violinplot(x=x, y=y, hue=x, ax=ax, legend=False, **sns_kwargs)

# Set title and axis labels
ax.set_title(title)
Expand Down Expand Up @@ -2012,6 +2012,90 @@ def plot_psd_topo(
return fig, ax


def plot_hmm_summary_stats(
fo,
lt,
intv,
sr,
filename,
cmap="tab10",
fig_kwargs=None,
sns_kwargs=None,
):
"""Plot summary statistics (FO, LT, INTV, SR).
Parameters
----------
fo : np.ndarray
Fractional occupancies. Shape much be (n_subjects, n_states).
lt : np.ndarray
Mean lifetimes. Shape much be (n_subjects, n_states).
intv : np.ndarray
Mean intervals. Shape much be (n_subjects, n_states).
sr : np.ndarray
Switching rates. Shape much be (n_subjects, n_states).
filename : str
Output filename.
cmap : str, optional
Matplotlib colormap.
fig_kwargs : dict, optional
Arguments to pass to :code:`plt.subplots()`.
sns_kwargs : dict, optional
Arguments to pass to :code:`sns.violinplot()`.
"""
if fig_kwargs is None:
fig_kwargs = {}

if sns_kwargs is None:
sns_kwargs = {}

n_states = fo.shape[1]
x = range(1, n_states + 1)

sns_kwargs.update(
{
"inner": "quart",
"cut": 0,
"palette": cmap,
}
)

fig, ax = create_figure(nrows=1, ncols=4, figsize=(15, 3))
plot_violin(
fo.T,
x=x,
x_label="State",
y_label="Fractional Occupancy",
ax=ax[0],
sns_kwargs=sns_kwargs,
)
plot_violin(
lt.T,
x=x,
x_label="State",
y_label="Mean Lifetime (s)",
ax=ax[1],
sns_kwargs=sns_kwargs,
)
plot_violin(
intv.T,
x=x,
x_label="State",
y_label="Mean Interval (s)",
ax=ax[2],
sns_kwargs=sns_kwargs,
)
plot_violin(
sr.T,
x=x,
x_label="State",
y_label="Switching rate (Hz)",
ax=ax[3],
sns_kwargs=sns_kwargs,
)
save(fig, filename=filename, tight_layout=True)


def plot_summary_stats_group_diff(
name,
summary_stats,
Expand Down

0 comments on commit 46010b4

Please sign in to comment.