Skip to content

Commit

Permalink
csr.bus: redesign Multiplexer shadow registers.
Browse files Browse the repository at this point in the history
Before this commit, csr.Multiplexer had separate shadows for every
element in its memory map. The same shadow was shared for read and
write accesses to an element; a combined read/write transaction was
impossible despite being allowed by the CSR interface.

After this commit, csr.Multiplexer has separate shadows for read and
write accesses, but both shadows are shared by every element using
them. For multiplexers with many elements, this approach also results
in significant resource savings.
  • Loading branch information
jfng committed Aug 4, 2023
1 parent d2ca157 commit 39194ac
Show file tree
Hide file tree
Showing 2 changed files with 378 additions and 144 deletions.
279 changes: 243 additions & 36 deletions amaranth_soc/csr/bus.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from collections import defaultdict
from math import ceil, log2
import enum
from amaranth import *
from amaranth.utils import log2_int

from ..memory import MemoryMap

Expand Down Expand Up @@ -171,10 +172,183 @@ def memory_map(self, memory_map):


class Multiplexer(Elaboratable):
class _Shadow:
class Chunk:
"""The interface between of a CSR multiplexer and a shadow register chunk."""
def __init__(self, shadow, offset, elements):
self.name = f"{shadow.name}__{offset}"
self.data = Signal(shadow.granularity, name=f"{self.name}__data")
self.r_en = Signal(name=f"{self.name}__r_en")
self.w_en = Signal(name=f"{self.name}__w_en")
self._elements = tuple(elements)

def elements(self):
"""Iterate the address ranges of CSR elements using this chunk."""
yield from self._elements

"""CSR multiplexer shadow register.
Attributes
----------
name : :class:`str`
Name of the shadow register.
granularity : :class:`int`
Amount of bits stored in a chunk of the shadow register.
overlaps : :class:`int`
Maximum amount of CSR elements that can share a chunk of the shadow register. Optional.
If ``None``, it is implicitly set by :meth:`Multiplexer._Shadow.prepare`.
"""
def __init__(self, granularity, overlaps, *, name):
assert isinstance(name, str)
assert isinstance(granularity, int) and granularity >= 0
assert overlaps is None or isinstance(overlaps, int) and overlaps >= 0
self.name = name
self.granularity = granularity
self.overlaps = overlaps
self._ranges = set()
self._size = 1
self._chunks = None

@property
def size(self):
"""Size of the shadow register.
Returns
-------
:class:`int`
The amount of :class:`Multiplexer._Shadow.Chunk`s of the shadow. It can increase
by calling :meth:`Multiplexer._Shadow.add` or :meth:`Multiplexer._Shadow.prepare`.
"""
return self._size

def add(self, elem_range):
"""Add a CSR element to the shadow.
Arguments
---------
elem_range : :class:`range`
Address range of a CSR :class:`Element`. It uses ``2 ** ceil(log2(elem_range.stop -
elem_range.start))`` chunks of the shadow register. If this amount is greater than
:attr:`~Multiplexer._Shadow.size`, it replaces the latter.
"""
assert isinstance(elem_range, range)
self._ranges.add(elem_range)
elem_size = 2 ** ceil(log2(elem_range.stop - elem_range.start))
self._size = max(self._size, elem_size)

def decode_address(self, addr, elem_range):
"""Decode a bus address into a shadow register offset.
Returns
-------
:class:`int`
The shadow register offset corresponding to the :class:`Multiplexer._Shadow.Chunk`
used by ``addr``.
The address decoding scheme is illustrated by the following example:
* ``addr`` is ``0x1c``;
* ``elem_range`` is ``range(0x1b, 0x1f)``;
* the :attr:`~Multiplexer._Shadow.size` of the shadow is ``16``.
The lower bits of the offset would be ``0b00``, extracted from ``addr``:
.. code-block::
+----+--+--+
|0001|11|00|
+----+--+--+
│ └─ 0
└──── ceil(log2(elem_range.stop - elem_range.start))
The upper bits of the offset would be ``0b10``, extracted from ``elem_range.start``:
.. code-block::
+----+--+--+
|0001|10|11|
+----+--+--+
│ │
│ └──── ceil(log2(elem_range.stop - elem_range.start))
└─────── log2(self.size)
The decoded offset would therefore be ``0xc`` (i.e. ``0b1100``).
"""
assert elem_range in self._ranges and addr in elem_range
elem_size = 2 ** ceil(log2(elem_range.stop - elem_range.start))
self_mask = self.size - 1
elem_mask = elem_size - 1
return elem_range.start & self_mask & ~elem_mask | addr & elem_mask

def encode_offset(self, offset, elem_range):
"""Encode a shadow register offset into a bus address.
Returns
-------
:class:`int`
The bus address in ``elem_range`` using the :class:`Multiplexer._Shadow.Chunk`
located at ``offset``. See :meth:`~Multiplexer._Shadow.decode_address` for details.
"""
assert elem_range in self._ranges and isinstance(offset, int)
elem_size = 2 ** ceil(log2(elem_range.stop - elem_range.start))
return elem_range.start + ((offset - elem_range.start) % elem_size)

def prepare(self):
"""Balance out and instantiate the shadow register chunks.
The scheme used by :meth:`~Multiplexer._Shadow.decode_address` allows multiple bus
addresses to be decoded to the same shadow register offset. Depending on the platform
and its toolchain, this may create nets with high fan-in (if the chunk is read from
the bus) or fan-out (if written), which may impact timing closure or resource usage.
If any shadow register offset is aliased to more bus addresses than permitted by the
:attr:`~Multiplexer._Shadow.overlaps` constraint, the :attr:`~Multiplexer._Shadow.size`
of the shadow is doubled. This increases the number of address bits used for decoding,
which effectively balances chunk usage across the shadow register.
This method is recursive until the overlap constraint is satisfied.
"""
if isinstance(self._ranges, frozenset):
return
if self.overlaps is None:
self.overlaps = len(self._ranges)

elements = defaultdict(list)
balanced = True

for elem_range in self._ranges:
for chunk_addr in elem_range:
chunk_offset = self.decode_address(chunk_addr, elem_range)
if len(elements[chunk_offset]) > self.overlaps:
balanced = False
break
elements[chunk_offset].append(elem_range)

if balanced:
self._ranges = frozenset(self._ranges)
self._chunks = dict()
for chunk_offset, chunk_elements in elements.items():
chunk = Multiplexer._Shadow.Chunk(self, chunk_offset, chunk_elements)
self._chunks[chunk_offset] = chunk
else:
self._size *= 2
self.prepare()

def chunks(self):
"""Iterate shadow register chunks used by at least one CSR element."""
for chunk_offset, chunk in self._chunks.items():
yield chunk_offset, chunk

"""CSR register multiplexer.
An address-based multiplexer for CSR registers implementing atomic updates.
This implementation assumes the following from the CSR bus:
* an initiator must have exclusive ownership over the multiplexer for the full duration of
a register transaction;
* an initiator must access a register in ascending order of addresses, but it may abort a
transaction after any bus cycle.
Latency
-------
Expand Down Expand Up @@ -214,16 +388,22 @@ class Multiplexer(Elaboratable):
Register alignment. See :class:`..memory.MemoryMap`.
name : str
Window name. Optional.
shadow_overlaps : int
Maximum number of CSR registers that can share a chunk of a shadow register.
Optional. If ``None``, any number of CSR registers can share a shadow chunk.
See :class:`Multiplexer._Shadow` for details.
Attributes
----------
bus : :class:`Interface`
CSR bus providing access to registers.
"""
def __init__(self, *, addr_width, data_width, alignment=0, name=None):
def __init__(self, *, addr_width, data_width, alignment=0, name=None, shadow_overlaps=None):
# Memory map describing the address space of the multiplexer.
self._map = MemoryMap(addr_width=addr_width, data_width=data_width, alignment=alignment,
name=name)
# No bus interface yet; presumably created on demand by the `bus` property (its body is
# not visible in this hunk) -- TODO confirm.
self._bus = None
# One shadow register for reads and one for writes. Both use chunks of `data_width` bits
# and the same `shadow_overlaps` limit.
self._r_shadow = Multiplexer._Shadow(data_width, shadow_overlaps, name="r_shadow")
self._w_shadow = Multiplexer._Shadow(data_width, shadow_overlaps, name="w_shadow")

@property
def bus(self):
Expand Down Expand Up @@ -258,50 +438,77 @@ def add(self, element, *, addr=None, alignment=None, extend=False):
def elaborate(self, platform):
m = Module()

# Instead of a straightforward multiplexer for reads, use a per-element address comparator,
# AND the shadow register chunk with the comparator output, and OR all of those together.
# If the toolchain doesn't already synthesize multiplexer trees this way, this trick can
# save a significant amount of logic, since e.g. one 4-LUT can pack one 2-MUX, but two
# 2-AND or 2-OR gates.
r_data_fanin = 0

for elem, _, (elem_start, elem_end) in self._map.resources():
shadow = Signal(elem.width, name="{}__shadow".format(elem.name))
elem_range = range(elem_start, elem_end)
if elem.access.readable():
shadow_en = Signal(elem_end - elem_start, name="{}__shadow_en".format(elem.name))
m.d.sync += shadow_en.eq(0)
self._r_shadow.add(elem_range)
if elem.access.writable():
m.d.comb += elem.w_data.eq(shadow)
m.d.sync += elem.w_stb.eq(0)
self._w_shadow.add(elem_range)

self._r_shadow.prepare()
self._w_shadow.prepare()

# Instead of a straightforward multiplexer for reads, use an address comparator for each
# shadow register chunk, AND the comparator output with the chunk contents, and OR all of
# those together. If the toolchain doesn't already synthesize multiplexer trees this way,
# this trick can save a significant amount of logic, since e.g. one 4-LUT can pack one
# 2-MUX, but two 2-AND or 2-OR gates.
r_data_fanin = 0

for chunk_offset, r_chunk in self._r_shadow.chunks():
# Use the same trick to select which element is read into a shadow register chunk.
r_chunk_w_en_fanin = 0
r_chunk_data_fanin = 0

m.d.sync += r_chunk.r_en.eq(0)

# Enumerate every address used by the register explicitly, rather than using
# arithmetic comparisons, since some toolchains (e.g. Yosys) are too eager to infer
# carry chains for comparisons, even with a constant. (Register sizes don't have
# to be powers of 2.)
with m.Switch(self.bus.addr):
for chunk_offset, chunk_addr in enumerate(range(elem_start, elem_end)):
shadow_slice = shadow.word_select(chunk_offset, self.bus.data_width)
for elem_range in r_chunk.elements():
chunk_addr = self._r_shadow.encode_offset(chunk_offset, elem_range)
elem = self._map.decode_address(elem_range.start)
elem_offset = chunk_addr - elem_range.start
elem_slice = elem.r_data.word_select(elem_offset, self.bus.data_width)

with m.Case(chunk_addr):
if elem.access.readable():
r_data_fanin |= Mux(shadow_en[chunk_offset], shadow_slice, 0)
if chunk_addr == elem_start:
m.d.comb += elem.r_stb.eq(self.bus.r_stb)
with m.If(self.bus.r_stb):
m.d.sync += shadow.eq(elem.r_data)
# Delay by 1 cycle, allowing reads to be pipelined.
m.d.sync += shadow_en.eq(self.bus.r_stb << chunk_offset)

if elem.access.writable():
if chunk_addr == elem_end - 1:
# Delay by 1 cycle, avoiding combinatorial paths through
# the CSR bus and into CSR registers.
m.d.sync += elem.w_stb.eq(self.bus.w_stb)
with m.If(self.bus.w_stb):
m.d.sync += shadow_slice.eq(self.bus.w_data)
if chunk_addr == elem_range.start:
m.d.comb += elem.r_stb.eq(self.bus.r_stb)
# Delay by 1 cycle, allowing reads to be pipelined.
m.d.sync += r_chunk.r_en.eq(self.bus.r_stb)

r_chunk_w_en_fanin |= elem.r_stb
r_chunk_data_fanin |= Mux(elem.r_stb, elem_slice, 0)

m.d.comb += r_chunk.w_en.eq(r_chunk_w_en_fanin)
with m.If(r_chunk.w_en):
m.d.sync += r_chunk.data.eq(r_chunk_data_fanin)

r_data_fanin |= Mux(r_chunk.r_en, r_chunk.data, 0)

m.d.comb += self.bus.r_data.eq(r_data_fanin)

for chunk_offset, w_chunk in self._w_shadow.chunks():
with m.Switch(self.bus.addr):
for elem_range in w_chunk.elements():
chunk_addr = self._w_shadow.encode_offset(chunk_offset, elem_range)
elem = self._map.decode_address(elem_range.start)
elem_offset = chunk_addr - elem_range.start
elem_slice = elem.w_data.word_select(elem_offset, self.bus.data_width)

if chunk_addr == elem_range.stop - 1:
m.d.sync += elem.w_stb.eq(0)

with m.Case(chunk_addr):
if chunk_addr == elem_range.stop - 1:
# Delay by 1 cycle, avoiding combinatorial paths through
# the CSR bus and into CSR registers.
m.d.sync += elem.w_stb.eq(self.bus.w_stb)
m.d.comb += w_chunk.w_en.eq(self.bus.w_stb)

m.d.comb += elem_slice.eq(w_chunk.data)

with m.If(w_chunk.w_en):
m.d.sync += w_chunk.data.eq(self.bus.w_data)

return m


Expand Down
Loading

0 comments on commit 39194ac

Please sign in to comment.