Fix _get_descendant_accumulate_grads#217
Conversation
|
@austen260 could you tell me if this other implementation also works on your particular case? It's very similar to what you suggested, so I think it should. |
Codecov ReportAll modified and coverable lines are covered by tests ✅
|
PierreQuinton
left a comment
There was a problem hiding this comment.
We would need a test to reproduce the bug of course but otherwise LGTM.
Yes, I would also like to make sure that we can reproduce or understand the infinite loop bug before writing this changelog message |
|
The thing is that we are pretty sure that this is better, so maybe we can merge without test and changelog and work on an additional PR with less priority for the test and the changelog, this is a bit irregular and is a bit dirty in terms of commits log. |
|
We could do something like this instead: def _get_descendant_accumulate_grads(roots: set[Node], excluded_nodes: set[Node]) -> set[Node]:
excluded_nodes = set(excluded_nodes) # Re-instantiate set to avoid modifying input
nodes_to_traverse = deque([node for node in roots if node is not None])
result = set()
while nodes_to_traverse:
current_node = nodes_to_traverse.popleft() # Breadth-first
excluded_nodes.add(current_node)
if current_node.__class__.__name__ == "AccumulateGrad":
result.add(current_node)
for node, _ in current_node.next_functions:
if node is not None and node not in excluded_nodes:
nodes_to_traverse.append(node) # Append to the right
return resultThis avoids duplicating the check |
I can confirm that this resolves the issue on my end! Thank you both very much! |
|
@ValerianRey I think we should do the following:
|
_get_descendant_accumulate_gradsstuck in infinite loop #216