Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
5eebcf7
feat(aggregation): add GradVac aggregator
rkhosrowshahi Apr 9, 2026
a588c93
chore: Remove outdated doctesting stuff (#639)
ValerianRey Apr 11, 2026
9d65f63
chore: Add governance documentation (#637)
PierreQuinton Apr 11, 2026
3ab336c
refactor(gradvac): literal group types, eps/beta rules, and plotter UX
rkhosrowshahi Apr 12, 2026
e53849e
refactor(gradvac): base on GramianWeightedAggregator with GradVacWeig…
rkhosrowshahi Apr 12, 2026
4909964
fix: update type hint for update_gradient_coordinate function
rkhosrowshahi Apr 12, 2026
a39f343
test(gradvac): cover beta setter success path for codecov
rkhosrowshahi Apr 12, 2026
0359e60
Rename some variables in test_gradvac.py
ValerianRey Apr 12, 2026
1da5f6e
Add comment about why we move to cpu
ValerianRey Apr 12, 2026
21d55f9
Add GradVac to the aggregator table in README
ValerianRey Apr 12, 2026
17b1dd5
Add changelog entry
ValerianRey Apr 12, 2026
02a826b
Merge branch 'main' into feature/gradvac
ValerianRey Apr 12, 2026
f4e8e60
Remove seed setting in test_aggregator_output
ValerianRey Apr 12, 2026
75c89c1
fix(aggregation): Add fallback in NashMTL (#640)
ValerianRey Apr 13, 2026
b100c8b
Merge branch 'main' into feature/gradvac
ValerianRey Apr 13, 2026
193ffa6
Merge branch 'main' of https://github.com/TorchJD/torchjd into featur…
rkhosrowshahi Apr 13, 2026
9ffdd13
Revert plot test refactors; keep GradVac in interactive plotter
rkhosrowshahi Apr 13, 2026
50525a1
Merge branch 'main' into feature/gradvac (21f6b74)
rkhosrowshahi Apr 13, 2026
e626475
docs(aggregation): add grouping usage example and fix GradVac note
rkhosrowshahi Apr 13, 2026
a244d2b
docs(changelog): split Unreleased into Added and Fixed for GradVac an…
rkhosrowshahi Apr 13, 2026
6a78932
feat(plots): restore enhanced interactive plotter UI
rkhosrowshahi Apr 13, 2026
787f486
Merge branch 'main' into feature/interactive-plotting-ui
ValerianRey Apr 14, 2026
cd04362
Remove grouping example
ValerianRey Apr 14, 2026
ea8b5d5
Improve display of degrees
ValerianRey Apr 14, 2026
0641812
Improve display of length
ValerianRey Apr 14, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 20 additions & 5 deletions tests/plots/_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from collections.abc import Callable

import numpy as np
import torch
from plotly import graph_objects as go
Expand All @@ -7,14 +9,22 @@


class Plotter:
def __init__(self, aggregators: list[Aggregator], matrix: torch.Tensor, seed: int = 0) -> None:
self.aggregators = aggregators
def __init__(
self,
aggregator_factories: dict[str, Callable[[], Aggregator]],
selected_keys: list[str],
matrix: torch.Tensor,
seed: int = 0,
) -> None:
self._aggregator_factories = aggregator_factories
self.selected_keys = selected_keys
self.matrix = matrix
self.seed = seed

def make_fig(self) -> Figure:
torch.random.manual_seed(self.seed)
results = [agg(self.matrix) for agg in self.aggregators]
aggregators = [self._aggregator_factories[key]() for key in self.selected_keys]
results = [agg(self.matrix) for agg in aggregators]

fig = go.Figure()

Expand All @@ -23,14 +33,19 @@ def make_fig(self) -> Figure:
fig.add_trace(cone)

for i in range(len(self.matrix)):
scatter = make_vector_scatter(self.matrix[i], "blue", f"g{i + 1}")
scatter = make_vector_scatter(
self.matrix[i],
"blue",
f"g{i + 1}",
textposition="top right",
)
fig.add_trace(scatter)

for i in range(len(results)):
scatter = make_vector_scatter(
results[i],
"black",
str(self.aggregators[i]),
self.selected_keys[i],
showlegend=True,
dash=True,
)
Expand Down
126 changes: 96 additions & 30 deletions tests/plots/interactive_plotter.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
import logging
import os
import webbrowser
from collections.abc import Callable
from threading import Timer

import numpy as np
import torch
from dash import Dash, Input, Output, callback, dcc, html
from plotly.graph_objs import Figure
from typing_extensions import Unpack

from plots._utils import Plotter, angle_to_coord, coord_to_angle
from torchjd.aggregation import (
IMTLG,
MGDA,
Aggregator,
AlignedMTL,
CAGrad,
ConFIG,
Expand All @@ -31,6 +34,14 @@
MAX_LENGTH = 25.0


def _format_angle_display(angle: float) -> str:
return f"{np.degrees(angle):.1f}°"


def _format_length_display(r: float) -> str:
return f"{r:.2f}"


def main() -> None:
log = logging.getLogger("werkzeug")
log.setLevel(logging.CRITICAL)
Expand All @@ -43,27 +54,30 @@ def main() -> None:
],
)

aggregators = [
AlignedMTL(),
CAGrad(c=0.5),
ConFIG(),
DualProj(),
GradDrop(),
GradVac(),
IMTLG(),
Mean(),
MGDA(),
NashMTL(n_tasks=matrix.shape[0]),
PCGrad(),
Random(),
Sum(),
TrimmedMean(trim_number=1),
UPGrad(),
]

aggregators_dict = {str(aggregator): aggregator for aggregator in aggregators}

plotter = Plotter([], matrix)
n_tasks = matrix.shape[0]
aggregator_factories: dict[str, Callable[[], Aggregator]] = {
"AlignedMTL-min": lambda: AlignedMTL(scale_mode="min"),
"AlignedMTL-median": lambda: AlignedMTL(scale_mode="median"),
"AlignedMTL-RMSE": lambda: AlignedMTL(scale_mode="rmse"),
Comment on lines +59 to +61
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems like it would be nice to change the __str__ method of AlignedMTL to also include the scale_mode. For example, str(AlignedMTL(scale_mode="min")) would become AlignedMTL-min.

But this is quite independent from this PR, and should come in a different PR if you care about that! @rkhosrowshahi

str(CAGrad(c=0.5)): lambda: CAGrad(c=0.5),
str(ConFIG()): lambda: ConFIG(),
str(DualProj()): lambda: DualProj(),
str(GradDrop()): lambda: GradDrop(),
str(GradVac()): lambda: GradVac(),
str(IMTLG()): lambda: IMTLG(),
str(Mean()): lambda: Mean(),
str(MGDA()): lambda: MGDA(),
str(NashMTL(n_tasks=n_tasks)): lambda: NashMTL(n_tasks=n_tasks),
str(PCGrad()): lambda: PCGrad(),
str(Random()): lambda: Random(),
str(Sum()): lambda: Sum(),
str(TrimmedMean(trim_number=1)): lambda: TrimmedMean(trim_number=1),
str(UPGrad()): lambda: UPGrad(),
}

aggregator_strings = list(aggregator_factories.keys())

plotter = Plotter(aggregator_factories, [], matrix)

app = Dash(__name__)

Expand Down Expand Up @@ -98,7 +112,6 @@ def main() -> None:
gradient_slider_inputs.append(Input(angle_input, "value"))
gradient_slider_inputs.append(Input(r_input, "value"))

aggregator_strings = [str(aggregator) for aggregator in aggregators]
checklist = dcc.Checklist(aggregator_strings, [], id="aggregator-checklist")

control_div = html.Div(
Expand All @@ -117,32 +130,40 @@ def update_seed(value: int) -> Figure:
plotter.seed = value
return plotter.make_fig()

n_gradients = len(matrix)
gradient_value_outputs: list[Output] = []
for i in range(n_gradients):
gradient_value_outputs.append(Output(f"g{i + 1}-angle-value", "children"))
gradient_value_outputs.append(Output(f"g{i + 1}-length-value", "children"))

@callback(
Output("aggregations-fig", "figure", allow_duplicate=True),
*gradient_value_outputs,
*gradient_slider_inputs,
prevent_initial_call=True,
)
def update_gradient_coordinate(*values: str) -> Figure:
def update_gradient_coordinate(*values: str) -> tuple[Figure, Unpack[tuple[str, ...]]]:
values_ = [float(value) for value in values]

display_parts: list[str] = []
for j in range(len(values_) // 2):
angle = values_[2 * j]
r = values_[2 * j + 1]
x, y = angle_to_coord(angle, r)
plotter.matrix[j, 0] = x
plotter.matrix[j, 1] = y
display_parts.append(_format_angle_display(angle))
display_parts.append(_format_length_display(r))

return plotter.make_fig()
return (plotter.make_fig(), *display_parts)

@callback(
Output("aggregations-fig", "figure", allow_duplicate=True),
Input("aggregator-checklist", "value"),
prevent_initial_call=True,
)
def update_aggregators(value: list[str]) -> Figure:
aggregator_keys = value
new_aggregators = [aggregators_dict[key] for key in aggregator_keys]
plotter.aggregators = new_aggregators
plotter.selected_keys = list(value)
return plotter.make_fig()

Timer(1, open_browser).start()
Expand Down Expand Up @@ -175,11 +196,56 @@ def make_gradient_div(
style={"width": "250px"},
)

label_style: dict[str, str | int] = {
"display": "inline-block",
"width": "52px",
"margin-right": "8px",
"vertical-align": "middle",
}
value_style: dict[str, str] = {
"display": "inline-block",
"margin-left": "10px",
"min-width": "140px",
"font-family": "monospace",
"font-size": "13px",
"vertical-align": "middle",
}
row_style: dict[str, str] = {"display": "block", "margin-bottom": "6px"}
div = html.Div(
[
html.P(f"g{i + 1}", style={"display": "inline-block", "margin-right": 20}),
angle_input,
r_input,
dcc.Markdown(
f"$g_{{{i + 1}}}$",
mathjax=True,
style={
"margin": "0 0 6px 0",
"font-weight": "bold",
"display": "block",
},
),
html.Div(
[
html.Span("Angle", style=label_style),
angle_input,
html.Span(
id=f"g{i + 1}-angle-value",
children=_format_angle_display(angle),
style=value_style,
),
],
style=row_style,
),
html.Div(
[
html.Span("Length", style=label_style),
r_input,
html.Span(
id=f"g{i + 1}-length-value",
children=_format_length_display(r),
style=value_style,
),
],
style={**row_style, "margin-bottom": "12px"},
),
],
)
return div, angle_input, r_input
Expand Down
Loading