Skip to content

Commit

Permalink
fix: facet titles be strings
Browse files Browse the repository at this point in the history
  • Loading branch information
johentsch committed Dec 8, 2023
1 parent 55768b0 commit ed185f7
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 0 deletions.
10 changes: 10 additions & 0 deletions src/dimcat/data/resources/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -1694,6 +1694,7 @@ def make_heatmaps_from_transitions(
traces_update.update(traces_settings)

if not groupby:
# no subplots needed, return single heatmap
proportions, counts, proportions_str = prepare_transitions(
transitions_df, max_x=max_x, max_y=max_y
)
Expand All @@ -1715,6 +1716,7 @@ def make_heatmaps_from_transitions(
write_image(fig=fig, filename=output, width=width, height=height)
return fig

# prepare subplots according to groupby
facet_row_names, facet_col_names = [], []
group2row_col = {}
group2data, group2customdata, group2text = {}, {}, {}
Expand Down Expand Up @@ -1752,6 +1754,11 @@ def update_facet_names(group):

# prepare the transition data
for group, group_df in transitions_df.groupby(groupby, group_keys=False):
if not isinstance(group, str):
if isinstance(group, tuple):
group = ", ".join(str(g) for g in group)
else:
group = str(group)
proportions, counts, proportions_str = prepare_transitions(
group_df, max_x=max_x, max_y=max_y
)
Expand All @@ -1760,6 +1767,7 @@ def update_facet_names(group):
group2text[group] = proportions_str
update_facet_names(group)

# prepare the colorscales
colorscale_list = []
if column_colorscales is not None:
if isinstance(column_colorscales, list):
Expand Down Expand Up @@ -1812,6 +1820,8 @@ def update_facet_names(group):
name=group,
)
fig.add_trace(heatmap, row, col)

# layout and return
update_figure_layout(
fig=fig,
layout=layout,
Expand Down
1 change: 1 addition & 0 deletions src/dimcat/data/resources/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -944,6 +944,7 @@ def merge_notes(staff_midi):
df.loc[new_dur.index, "duration"] = new_dur.duration
except Exception:
print(new_dur)
raise
if return_dropped:
df.loc[new_dur.index, "dropped"] = new_dur.dropped
df = df.drop(new_dur.dropped.sum())
Expand Down

0 comments on commit ed185f7

Please sign in to comment.