From ff6c0327a78f97217d83691e606092e454803b41 Mon Sep 17 00:00:00 2001 From: anuejn Date: Thu, 16 Apr 2020 18:46:55 +0200 Subject: [PATCH] hdl.rec: make Record inherit from UserValue. Closes #354. --- nmigen/back/pysim.py | 3 --- nmigen/back/rtlil.py | 11 ++++++----- nmigen/hdl/ast.py | 5 ++++- nmigen/hdl/rec.py | 8 +++++--- nmigen/hdl/xfrm.py | 12 ------------ nmigen/test/test_hdl_ast.py | 8 ++++++++ nmigen/test/test_hdl_xfrm.py | 9 +++++++++ 7 files changed, 32 insertions(+), 24 deletions(-) diff --git a/nmigen/back/pysim.py b/nmigen/back/pysim.py index 1ea631c56..3077991e2 100644 --- a/nmigen/back/pysim.py +++ b/nmigen/back/pysim.py @@ -374,9 +374,6 @@ def on_ClockSignal(self, value): def on_ResetSignal(self, value): raise NotImplementedError # :nocov: - def on_Record(self, value): - return self(Cat(value.fields.values())) - def on_AnyConst(self, value): raise NotImplementedError # :nocov: diff --git a/nmigen/back/rtlil.py b/nmigen/back/rtlil.py index 120a354ee..3f9936e78 100644 --- a/nmigen/back/rtlil.py +++ b/nmigen/back/rtlil.py @@ -365,9 +365,6 @@ def on_Sample(self, value): def on_Initial(self, value): raise NotImplementedError # :nocov: - def on_Record(self, value): - return self(ast.Cat(value.fields.values())) - def on_Cat(self, value): return "{{ {} }}".format(" ".join(reversed([self(o) for o in value.parts]))) @@ -378,7 +375,11 @@ def on_Slice(self, value): if value.start == 0 and value.stop == len(value.value): return self(value.value) - sigspec = self._prepare_value_for_Slice(value.value) + if isinstance(value.value, ast.UserValue): + sigspec = self._prepare_value_for_Slice(value.value._lazy_lower()) + else: + sigspec = self._prepare_value_for_Slice(value.value) + if value.start == value.stop: return "{}" elif value.start + 1 == value.stop: @@ -644,7 +645,7 @@ def on_Signal(self, value): return wire_next or wire_curr def _prepare_value_for_Slice(self, value): - assert isinstance(value, (ast.Signal, ast.Slice, ast.Cat, rec.Record)) + assert isinstance(value, (ast.Signal, ast.Slice, ast.Cat)) return self(value) def on_Part(self, value): diff --git a/nmigen/hdl/ast.py b/nmigen/hdl/ast.py index 45dab3d06..d8d9da041 100644 --- a/nmigen/hdl/ast.py +++ b/nmigen/hdl/ast.py @@ -1188,7 +1188,10 @@ def lower(self): def _lazy_lower(self): if self.__lowered is None: - self.__lowered = Value.cast(self.lower()) + lowered = self.lower() + if isinstance(lowered, UserValue): + lowered = lowered._lazy_lower() + self.__lowered = Value.cast(lowered) return self.__lowered def shape(self): diff --git a/nmigen/hdl/rec.py b/nmigen/hdl/rec.py index b60fb00fd..be8934226 100644 --- a/nmigen/hdl/rec.py +++ b/nmigen/hdl/rec.py @@ -85,7 +85,7 @@ def __repr__(self): # Unlike most Values, Record *can* be subclassed. -class Record(Value): +class Record(UserValue): @staticmethod def like(other, *, name=None, name_suffix=None, src_loc_at=0): if name is not None: @@ -113,6 +113,8 @@ def concat(a, b): return Record(other.layout, name=new_name, fields=fields, src_loc_at=1) def __init__(self, layout, *, name=None, fields=None, src_loc_at=0): + super().__init__(src_loc_at=src_loc_at) + if name is None: name = tracer.get_var_name(depth=2 + src_loc_at, default=None) @@ -165,8 +167,8 @@ def __getitem__(self, item): else: return super().__getitem__(item) - def shape(self): - return Shape(sum(len(f) for f in self.fields.values())) + def lower(self): + return Cat(self.fields.values()) def _lhs_signals(self): return union((f._lhs_signals() for f in self.fields.values()), start=SignalSet()) diff --git a/nmigen/hdl/xfrm.py b/nmigen/hdl/xfrm.py index fb4dad954..4d83eb726 100644 --- a/nmigen/hdl/xfrm.py +++ b/nmigen/hdl/xfrm.py @@ -38,10 +38,6 @@ def on_AnySeq(self, value): def on_Signal(self, value): pass # :nocov: - @abstractmethod - def on_Record(self, value): - pass # :nocov: - @abstractmethod def on_ClockSignal(self, value): pass # :nocov: @@ -98,9 +94,6 @@ def on_value(self, value): elif isinstance(value, Signal): # Uses `isinstance()` and not `type() is` because nmigen.compat requires it. new_value = self.on_Signal(value) - elif isinstance(value, Record): - # Uses `isinstance()` and not `type() is` to allow inheriting from Record. - new_value = self.on_Record(value) elif type(value) is ClockSignal: new_value = self.on_ClockSignal(value) elif type(value) is ResetSignal: @@ -147,9 +140,6 @@ def on_AnySeq(self, value): def on_Signal(self, value): return value - def on_Record(self, value): - return value - def on_ClockSignal(self, value): return value @@ -372,8 +362,6 @@ def on_ClockSignal(self, value): def on_ResetSignal(self, value): self._add_used_domain(value.domain) - on_Record = on_ignore - def on_Operator(self, value): for o in value.operands: self.on_value(o) diff --git a/nmigen/test/test_hdl_ast.py b/nmigen/test/test_hdl_ast.py index f642f20f6..1b7dd58e2 100644 --- a/nmigen/test/test_hdl_ast.py +++ b/nmigen/test/test_hdl_ast.py @@ -916,6 +916,14 @@ def test_shape(self): self.assertEqual(uv.shape(), unsigned(1)) self.assertEqual(uv.lower_count, 1) + def test_lower_to_user_value(self): + uv = MockUserValue(MockUserValue(1)) + self.assertEqual(uv.shape(), unsigned(1)) + self.assertIsInstance(uv.shape(), Shape) + uv.lowered = MockUserValue(2) + self.assertEqual(uv.shape(), unsigned(1)) + self.assertEqual(uv.lower_count, 1) + class SampleTestCase(FHDLTestCase): def test_const(self): diff --git a/nmigen/test/test_hdl_xfrm.py b/nmigen/test/test_hdl_xfrm.py index e5f1745f1..121c87f10 100644 --- a/nmigen/test/test_hdl_xfrm.py +++ b/nmigen/test/test_hdl_xfrm.py @@ -620,3 +620,12 @@ def test_lower(self): ) ) """) + + +class UserValueRecursiveTestCase(UserValueTestCase): + def setUp(self): + self.s = Signal() + self.c = Signal() + self.uv = MockUserValue(MockUserValue(self.s)) + + # inherit the test_lower method from UserValueTestCase because the checks are the same