From f0b4e964e5fc25e1e516c8ff9f7ffd3a43c404dc Mon Sep 17 00:00:00 2001 From: Wanda Date: Mon, 13 Oct 2025 18:23:51 +0200 Subject: [PATCH] hdl._nir: speed up combinational cycle detection. Fixes #1628. --- amaranth/hdl/_nir.py | 60 ++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 58 insertions(+), 2 deletions(-) diff --git a/amaranth/hdl/_nir.py b/amaranth/hdl/_nir.py index 23b53136c..5c9e78da4 100644 --- a/amaranth/hdl/_nir.py +++ b/amaranth/hdl/_nir.py @@ -444,6 +444,7 @@ def traverse(net): busy.add(net) cycle = None + extra_nets = [] if net.is_const: pass elif net.is_late: @@ -452,10 +453,16 @@ def traverse(net): sig, bit = self.late_to_signal[net] cycle.path.append((sig, bit, sig.src_loc)) else: - for src, src_loc in self.cells[net.cell].comb_edges_to(net.bit): + cell = self.cells[net.cell] + if not cell.comb_edges_is_per_bit(): + extra_nets = [extra_net for extra_net in cell.output_nets(net.cell) if extra_net != net] + for extra_net in extra_nets: + assert extra_net not in checked + busy.add(extra_net) + for src, src_loc in cell.comb_edges_to(net.bit): cycle = traverse(src) if cycle is not None: - cycle.path.append((self.cells[net.cell], net.bit, src_loc)) + cycle.path.append((cell, net.bit, src_loc)) break if cycle is not None and cycle.start == net: @@ -473,6 +480,9 @@ def traverse(net): busy.remove(net) checked.add(net) + for extra_net in extra_nets: + busy.remove(extra_net) + checked.add(extra_net) return cycle for cell_idx, cell in enumerate(self.cells): @@ -579,6 +589,10 @@ def resolve_nets(self, netlist: Netlist): def comb_edges_to(self, bit: int) -> "Iterable[(Net, Any)]": raise NotImplementedError + def comb_edges_is_per_bit(self) -> bool: + """Returns True iff ``comb_edges_to`` looks at its argument.""" + raise NotImplementedError + class Top(Cell): """A special cell type representing top-level non-IO ports. Must be present in the netlist exactly @@ -631,6 +645,9 @@ def __repr__(self): def comb_edges_to(self, bit): return [] + def comb_edges_is_per_bit(self) -> bool: + return False + class Operator(Cell): """Roughly corresponds to ``hdl.ast.Operator``. @@ -722,6 +739,15 @@ def comb_edges_to(self, bit): yield (self.inputs[1][bit], self.src_loc) yield (self.inputs[2][bit], self.src_loc) + def comb_edges_is_per_bit(self) -> bool: + if len(self.inputs) == 1 and self.operator == "~": + return True + elif len(self.inputs) == 2 and self.operator in ("&", "|", "^"): + return True + elif len(self.inputs) == 3: + return True + return False + class Part(Cell): """Corresponds to ``hdl.ast.Part``. @@ -767,6 +793,9 @@ def comb_edges_to(self, bit): for net in self.offset: yield (net, self.src_loc) + def comb_edges_is_per_bit(self) -> bool: + return False + class Match(Cell): """Used to represent a single switch on the control plane of processes. @@ -812,6 +841,9 @@ def comb_edges_to(self, bit): for net in self.value: yield (net, self.src_loc) + def comb_edges_is_per_bit(self) -> bool: + return False + class Assignment: """A single assignment in an ``AssignmentList``. @@ -895,6 +927,9 @@ def comb_edges_to(self, bit): yield (assign.cond, assign.src_loc) yield (assign.value[bit - assign.start], assign.src_loc) + def comb_edges_is_per_bit(self) -> bool: + return True + class FlipFlop(Cell): """A flip-flop. ``data`` is the data input. ``init`` is the initial and async reset value. @@ -943,6 +978,9 @@ def comb_edges_to(self, bit): yield (self.clk, self.src_loc) yield (self.arst, self.src_loc) + def comb_edges_is_per_bit(self) -> bool: + return False + class Memory(Cell): """Corresponds to ``Memory``. ``init`` must have length equal to ``depth``. @@ -1054,6 +1092,9 @@ def comb_edges_to(self, bit): for net in self.addr: yield (net, self.src_loc) + def comb_edges_is_per_bit(self) -> bool: + return False + class SyncReadPort(Cell): """A single synchronous read port of a memory. The cell output is the data port. @@ -1101,6 +1142,9 @@ def __repr__(self): def comb_edges_to(self, bit): return [] + def comb_edges_is_per_bit(self) -> bool: + return False + class AsyncPrint(Cell): """Corresponds to ``Print`` in the "comb" domain. @@ -1187,6 +1231,9 @@ def __repr__(self): def comb_edges_to(self, bit): return [] + def comb_edges_is_per_bit(self) -> bool: + return False + class AnyValue(Cell): """Corresponds to ``AnyConst`` or ``AnySeq``. ``kind`` must be either ``'anyconst'`` @@ -1220,6 +1267,9 @@ def __repr__(self): def comb_edges_to(self, bit): return [] + def comb_edges_is_per_bit(self) -> bool: + return False + class AsyncProperty(Cell): """Corresponds to ``Assert``, ``Assume``, or ``Cover`` in the "comb" domain. @@ -1381,6 +1431,9 @@ def comb_edges_to(self, bit): # don't ask me, I'm a housecat return [] + def comb_edges_is_per_bit(self) -> bool: + return False + class IOBuffer(Cell): """An IO buffer cell. This cell does two things: @@ -1440,3 +1493,6 @@ def comb_edges_to(self, bit): if self.dir is not IODirection.Input: yield (self.o[bit], self.src_loc) yield (self.oe, self.src_loc) + + def comb_edges_is_per_bit(self) -> bool: + return True