Skip to content

Commit

Permalink
simplified, added loop printing, tests pass if GrayDecoder fixed
Browse files Browse the repository at this point in the history
  • Loading branch information
widlarizer committed Oct 31, 2022
1 parent 22a10fc commit 0526474
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 65 deletions.
42 changes: 29 additions & 13 deletions amaranth/hdl/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,11 @@ def prepare(self, ports=None, missing_domain=lambda name: ClockDomain(name)):
fragment = SampleLowerer()(self)
new_domains = fragment._propagate_domains(missing_domain)
fragment = DomainLowerer()(fragment)
ctx = CombGraphCtx()
(edges, node_cnt) = CombGraphCompiler(ctx)(fragment)
loop = find_loop(range(node_cnt), list(edges))
if loop:
raise DomainError(format_loop(loop, ctx))
if ports is None:
fragment._propagate_ports(ports=(), all_undef_as_ports=True)
else:
Expand All @@ -555,9 +560,6 @@ def prepare(self, ports=None, missing_domain=lambda name: ClockDomain(name)):
if cd.rst is not None:
mapped_ports.append(cd.rst)
fragment._propagate_ports(ports=mapped_ports, all_undef_as_ports=False)
(edges, node_cnt) = CombGraphCompiler(CombGraphCtx())(fragment)
if find_cycle(range(node_cnt), list(edges)):
raise DomainError("Combinational loop detected")

return fragment

Expand All @@ -584,7 +586,10 @@ def __init__(self, state):
def _edges(self, fragment):
from .xfrm import LHSGroupFilter, AssignmentGraphBuilder
comb_edges = set()
comb_signals = fragment.drivers[None]
try:
comb_signals = fragment.drivers[None]
except KeyError:
return set()
stmts = LHSGroupFilter(comb_signals)(fragment.statements)
for signal in comb_signals:
_ = self.state.get_signal(signal)
Expand All @@ -599,27 +604,38 @@ def __call__(self, fragment):
return (self._edges(fragment), self.state.next_idx)


def find_cycle(nodes, edges):
edge_list = [[] for _ in edges]
def find_loop(nodes, edges):
edge_list = [[] for _ in nodes]
for [neighbour, node] in edges:
edge_list[node].append(neighbour)

def DFS(source, path):
if source in path:
return True
loop = [source]
return loop
elif not visited[source]:
visited[source] = True
neighbours = edge_list[source]
for n in neighbours:
node_path = path.union([source])
if DFS(n, node_path):
return True
loop = DFS(n, path.union([source]))
if loop:
return loop + [source]
return []

visited = [False] * len(nodes)
for node in nodes:
if DFS(node, set()):
return True
return False
loop = DFS(node, set())
if loop:
return loop
return []


def format_loop(loop, ctx):
s = "Combinational loop detected: "
for sig, id in ctx.signals.items():
if id in loop:
s += f"{sig.name},"
return s[:-1]


class Instance(Fragment):
Expand Down
64 changes: 12 additions & 52 deletions amaranth/hdl/xfrm.py
Original file line number Diff line number Diff line change
Expand Up @@ -743,47 +743,6 @@ def on_fragment(self, fragment):
return new_fragment


class ValueSignalExtractor(ValueVisitor):
def __init__(self):
self.sigs = []

def on_ignore(self, value):
pass

on_Const = on_ignore
on_AnyConst = on_Const
on_AnySeq = on_Const
on_ClockSignal = on_Const
on_ResetSignal = on_Const
on_Sample = on_Const

def on_recurse(self, value):
self.on_value(value.value)

on_Part = on_recurse
on_Slice = on_recurse
on_Repl = on_recurse

def on_Signal(self, value):
self.sigs.append(value)

def on_Operator(self, value):
for o in value.operands:
self.on_value(o)

def on_Cat(self, value):
for o in value.parts:
self.on_value(o)

def on_ArrayProxy(self, value):
self.on_value(value.index)
for elem in value.elems:
self.on_value(elem)

def on_Initial(self, value):
self.on_Signal(value)


class AssignmentGraphBuilder(StatementVisitor):
def __init__(self, state):
self.edges = []
Expand All @@ -798,28 +757,29 @@ def on_ignore(self, stmt):
on_Cover = on_ignore

def on_value(self, value):
sigs = ValueSignalExtractor()
sigs(value)
return sigs.sigs
return value._rhs_signals()

def on_Assign(self, stmt):
def get_ids(sigs):
return map(self.state.get_signal, sigs)
for driver in set(get_ids(self.on_value(stmt.rhs))):
self.edges.append((driver, self.state.get_signal(stmt.lhs)))
for control_sig in set(get_ids(self.control_stack)):
self.edges.append((control_sig, self.state.get_signal(stmt.lhs)))
for lhs_sig in stmt.lhs._lhs_signals():
for driver in set(get_ids(self.on_value(stmt.rhs))):
self.edges.append((driver, self.state.get_signal(lhs_sig)))
for control_sig in set(get_ids(self.control_stack)):
self.edges.append((control_sig, self.state.get_signal(lhs_sig)))

def on_Switch(self, stmt):
control_sigs = ValueSignalExtractor()
control_sigs(stmt.test)
for sig in control_sigs.sigs:
if isinstance(stmt.test, ClockSignal):
return # sync in a trenchcoat!
for sig in stmt.test._rhs_signals():
if sig.name == 'clk':
return # sync in a trenchcoat!
self.control_stack.append(sig)

for case_stmts in stmt.cases.values():
self.on_statements(case_stmts)

for _ in control_sigs.sigs:
for _ in stmt.test._rhs_signals():
self.control_stack.pop()

def on_statements(self, stmts):
Expand Down

0 comments on commit 0526474

Please sign in to comment.