Skip to content

Commit

Permalink
pre-commit fix
Browse files Browse the repository at this point in the history
  • Loading branch information
lee1043 committed Nov 7, 2023
1 parent f99d5e3 commit 668f352
Showing 1 changed file with 26 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def parallel_coordinate_plot(
models_to_highlight_by_line=True,
models_to_highlight_colors=None,
models_to_highlight_labels=None,
models_to_highlight_markers=['s', 'o', '^', '*'],
models_to_highlight_markers=["s", "o", "^", "*"],
models_to_highlight_markers_size=10,
fig=None,
ax=None,
Expand Down Expand Up @@ -56,7 +56,7 @@ def parallel_coordinate_plot(
- `data`: 2-d numpy array for metrics
- `metric_names`: list, names of metrics for individual vertical axes (axis=1)
- `model_names`: list, name of models for markers/lines (axis=0)
- `models_to_highlight`: list, default=None, List of models to highlight as lines or marker
- `models_to_highlight`: list, default=None, List of models to highlight as lines or marker
- `models_to_highlight_by_line`: bool, default=True, highlight as lines. If False, as marker
- `models_to_highlight_colors`: list, default=None, List of colors for models to highlight as lines
- `models_to_highlight_labels`: list, default=None, List of string labels for models to highlight as lines
Expand Down Expand Up @@ -242,14 +242,19 @@ def parallel_coordinate_plot(
label = models_to_highlight_labels[mh_index]
else:
label = model

if models_to_highlight_by_line:
ax.plot(range(N), zs[j, :], "-", c=color, label=label, lw=3)
else:
ax.plot(range(N), zs[j, :], models_to_highlight_markers[mh_index],
c=color, label=label,
markersize=models_to_highlight_markers_size)

ax.plot(
range(N),
zs[j, :],
models_to_highlight_markers[mh_index],
c=color,
label=label,
markersize=models_to_highlight_markers_size,
)

mh_index += 1
else:
if identify_all_models:
Expand Down Expand Up @@ -300,20 +305,28 @@ def parallel_coordinate_plot(
interpolate=False,
alpha=0.5,
)

if arrow_between_lines:
# Add vertical arrows
for xi, yi1, yi2 in zip(x, y1, y2):
if (yi2 > yi1):
if yi2 > yi1:
arrow_color = arrow_between_lines_colors[0]
elif (yi2 < yi1):
elif yi2 < yi1:
arrow_color = arrow_between_lines_colors[1]
else:
arrow_color = None
arrow_length = yi2 - yi1
ax.arrow(xi, yi1, 0, arrow_length, color=arrow_color,
length_includes_head=True,
alpha=arrow_alpha, width=0.05, head_width=0.15)
ax.arrow(
xi,
yi1,
0,
arrow_length,
color=arrow_color,
length_includes_head=True,
alpha=arrow_alpha,
width=0.05,
head_width=0.15,
)

ax.set_xlim(-0.5, N - 0.5)
ax.set_xticks(range(N))
Expand Down

0 comments on commit 668f352

Please sign in to comment.