Skip to content

Fix _get_descendant_accumulate_grads#217

Merged
ValerianRey merged 6 commits intomainfrom
fix-get-descendant-accumulate-grads
Dec 21, 2024
Merged

Fix _get_descendant_accumulate_grads#217
ValerianRey merged 6 commits intomainfrom
fix-get-descendant-accumulate-grads

Conversation

@ValerianRey
Copy link
Copy Markdown
Contributor

@ValerianRey ValerianRey commented Dec 16, 2024

@ValerianRey
Copy link
Copy Markdown
Contributor Author

@austen260 could you tell me if this other implementation also works on your particular case? It's very similar to what you suggested, so I think it should.

@codecov
Copy link
Copy Markdown

codecov Bot commented Dec 16, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Files with missing lines Coverage Δ
src/torchjd/autojac/_utils.py 100.00% <100.00%> (ø)

Copy link
Copy Markdown
Contributor

@PierreQuinton PierreQuinton left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We would need a test to reproduce the bug of course but otherwise LGTM.

@ValerianRey
Copy link
Copy Markdown
Contributor Author

We would need a test to reproduce the bug of course but otherwise LGTM.

Yes, I would also like to make sure that we can reproduce or understand the infinite loop bug before writing this changelog message

@PierreQuinton
Copy link
Copy Markdown
Contributor

The thing is that we are pretty sure that this is better, so maybe we can merge without test and changelog and work on an additional PR with less priority for the test and the changelog, this is a bit irregular and is a bit dirty in terms of commits log.

@PierreQuinton
Copy link
Copy Markdown
Contributor

We could do something like this instead:

def _get_descendant_accumulate_grads(roots: set[Node], excluded_nodes: set[Node]) -> set[Node]:

    excluded_nodes = set(excluded_nodes)  # Re-instantiate set to avoid modifying input
    nodes_to_traverse = deque([node for node in roots if node is not None])
    result = set()

    while nodes_to_traverse:
        current_node = nodes_to_traverse.popleft()  # Breadth-first
        excluded_nodes.add(current_node)

        if current_node.__class__.__name__ == "AccumulateGrad":
            result.add(current_node)

        for node, _ in current_node.next_functions:
            if node is not None and node not in excluded_nodes:
                nodes_to_traverse.append(node)  # Append to the right

    return result

This avoids duplicating the check node not in excluded_nodes, so this makes it more natural (and I think equivalent).

@AustenMan
Copy link
Copy Markdown

We could do something like this instead:

def _get_descendant_accumulate_grads(roots: set[Node], excluded_nodes: set[Node]) -> set[Node]:

    excluded_nodes = set(excluded_nodes)  # Re-instantiate set to avoid modifying input
    nodes_to_traverse = deque([node for node in roots if node is not None])
    result = set()

    while nodes_to_traverse:
        current_node = nodes_to_traverse.popleft()  # Breadth-first
        excluded_nodes.add(current_node)

        if current_node.__class__.__name__ == "AccumulateGrad":
            result.add(current_node)

        for node, _ in current_node.next_functions:
            if node is not None and node not in excluded_nodes:
                nodes_to_traverse.append(node)  # Append to the right

    return result

This avoids duplicating the check node not in excluded_nodes, so this makes it more natural (and I think equivalent).

@ValerianRey

I can confirm that this resolves the issue on my end!

Thank you both very much!

@PierreQuinton
Copy link
Copy Markdown
Contributor

@ValerianRey I think we should do the following:

  • Merge this PR without changelog entry, this code just makes the graph search more efficient and is non-breaking.
  • Then we investigate separately if it is indeed possible to have a cyclic graph (in which case this PR should solve the problem), if it is possible, then we will add a test later on (which is also non-breaking if we didn't make a mistake here).

@ValerianRey ValerianRey merged commit a352cc9 into main Dec 21, 2024
@ValerianRey ValerianRey deleted the fix-get-descendant-accumulate-grads branch December 21, 2024 14:28
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

_get_descendant_accumulate_grads stuck in infinite loop

3 participants