Skip to content

Commit

Permalink
fix pre-commit
Browse files Browse the repository at this point in the history
  • Loading branch information
SkafteNicki committed Apr 22, 2024
1 parent 7839367 commit bd3baae
Showing 1 changed file with 2 additions and 4 deletions.
6 changes: 2 additions & 4 deletions src/torchmetrics/utilities/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,7 @@ def plot_curve(
label_names: Tuple containing the names of the x and y axis
legend_name: Name of the curve to be used in the legend
name: Custom name to describe the metric
labels: Optional labels for the different curves that will be added to the plot
Returns:
A tuple consisting of the figure and respective ax objects (or array of ax objects) of the generated figure
Expand Down Expand Up @@ -321,10 +322,7 @@ def plot_curve(
)

for i, (x_, y_) in enumerate(zip(x, y)):
if labels is None:
label = f"{legend_name}_{i}" if legend_name is not None else str(i)
else:
label = str(labels[i])
label = f"{legend_name}_{i}" if legend_name is not None else str(i) if labels is None else str(labels[i])
label += f" AUC={score[i].item():0.3f}" if score is not None else ""
ax.plot(x_.detach().cpu(), y_.detach().cpu(), linestyle="-", linewidth=2, label=label)
ax.legend()
Expand Down

0 comments on commit bd3baae

Please sign in to comment.