<a href="https://colab.research.google.com/github/alanvgreen/CFU-Playground/blob/fccm2/Amaranth_for_CFUs.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Amaranth for CFUs

```
Copyright 2022 Google LLC.
SPDX-License-Identifier: Apache-2.0
```
This notebook demonstrates Amaranth HDL language features, as used to build CFUs (Custom Function Units) for CFU-Playground.

* https://github.com/amaranth-lang/amaranth
* Docs: https://amaranth-lang.org/docs/amaranth/latest/

avg@google.com / 2022-04-19


In [None]:
# Install Amaranth 
!pip install --upgrade 'amaranth[builtin-yosys]'

# CFU-Playground library
!git clone https://github.com/google/CFU-Playground.git
import sys
sys.path.append('CFU-Playground/python')

# Imports
from amaranth import *
from amaranth.back import verilog
from amaranth.sim import Delay, Simulator, Tick
from amaranth_cfu import TestBase, SimpleElaboratable, pack_vals, simple_cfu, InstructionBase, CfuTestBase
import re, unittest

def convert_elaboratable(elaboratable):
  """Render an Amaranth elaboratable as Verilog text, minus attribute noise.

  The top module is named 'Top' and its ports are taken from
  `elaboratable.ports`. Every `(* ... *)` attribute is stripped, and lines
  left containing only spaces are reduced to bare newlines.

  Args:
    elaboratable: an Amaranth Elaboratable exposing a `ports` list.

  Returns:
    The cleaned-up Verilog source as a string.
  """
  raw_verilog = verilog.convert(elaboratable, name='Top', ports=elaboratable.ports)
  # Yosys decorates declarations with (* src = ... *) attributes; drop them.
  attribute_free = re.sub(r'\(\*.*\*\)', '', raw_verilog)
  # Collapse whitespace-only lines (often left behind by the step above).
  return re.sub(r'^ *\n', '\n', attribute_free, flags=re.MULTILINE)

def runTests(klazz):
  """Run every test in a unittest.TestCase subclass and print a text report.

  Args:
    klazz: the TestCase subclass whose tests should be executed.
  """
  tests = unittest.TestSuite()
  tests.addTests(unittest.TestLoader().loadTestsFromTestCase(klazz))
  unittest.TextTestRunner().run(tests)

Collecting amaranth[builtin-yosys]
  Downloading amaranth-0.3-py3-none-any.whl (167 kB)
[?25l[K     |██                              | 10 kB 23.1 MB/s eta 0:00:01[K     |████                            | 20 kB 11.4 MB/s eta 0:00:01[K     |█████▉                          | 30 kB 9.1 MB/s eta 0:00:01[K     |███████▉                        | 40 kB 8.2 MB/s eta 0:00:01[K     |█████████▊                      | 51 kB 6.3 MB/s eta 0:00:01[K     |███████████▊                    | 61 kB 7.4 MB/s eta 0:00:01[K     |█████████████▊                  | 71 kB 6.7 MB/s eta 0:00:01[K     |███████████████▋                | 81 kB 6.7 MB/s eta 0:00:01[K     |█████████████████▋              | 92 kB 7.4 MB/s eta 0:00:01[K     |███████████████████▌            | 102 kB 7.2 MB/s eta 0:00:01[K     |█████████████████████▌          | 112 kB 7.2 MB/s eta 0:00:01[K     |███████████████████████▍        | 122 kB 7.2 MB/s eta 0:00:01[K     |█████████████████████████▍      | 133 kB 7.2 MB/s e

In [None]:
# Single Multiply-Add
class SingleMultiply(SimpleElaboratable):
  """Combinational product of an offset signed byte and a signed byte.

  Computes result = (a + 128) * b with no registers.
  """
  def __init__(self):
    # Signed 8-bit operands.
    self.a = Signal(signed(8))
    self.b = Signal(signed(8))
    # Signed 32-bit product.
    self.result = Signal(signed(32))

  def elab(self, m):
    # Shift `a` by +128 before multiplying.
    offset_a = self.a + 128
    m.d.comb += self.result.eq(offset_a * self.b)

class SingleMultiplyTest(TestBase):
  """Checks SingleMultiply against hand-computed products."""
  def create_dut(self):
    return SingleMultiply()

  def test(self):
    # (a, b, expected result). `a` is written as (x - 128) so that the
    # DUT's a + 128 yields x.
    cases = [
      (1 - 128, 1, 1),
      (33 - 128, -25, 33 * -25),
    ]

    def process():
      for a_val, b_val, want in cases:
        yield self.dut.a.eq(a_val)
        yield self.dut.b.eq(b_val)
        # Give the combinational output time to settle before sampling it.
        yield Delay(0.1)
        got = yield self.dut.result
        self.assertEqual(want, got)
        yield

    self.run_sim(process)

# Execute the SingleMultiply tests; unittest prints the summary below.
runTests(SingleMultiplyTest)

CFU-Playground/python/amaranth_cfu/util.py:121: UnusedElaboratable: <amaranth_cfu.util._DummySyncModule object at 0x7f16a6055510> created but never used
  self.m.submodules['dummy'] = _DummySyncModule()
UnusedElaboratable: Enable tracemalloc to get the object allocation traceback
CFU-Playground/python/amaranth_cfu/util.py:90: UnusedElaboratable: <amaranth.hdl.dsl.Module object at 0x7f16a6055610> created but never used
  self.m = Module()
UnusedElaboratable: Enable tracemalloc to get the object allocation traceback
.
----------------------------------------------------------------------
Ran 1 test in 0.005s

OK


In [None]:
# four multiply-adds

class WordMultiplyAdd(SimpleElaboratable):
  """Four-lane multiply-add over packed bytes, computed combinationally.

  Each 32-bit input word is split into four signed bytes (lowest byte
  first), and result = sum over lanes i of (a_i + 128) * b_i.
  """
  def __init__(self):
    self.a_word = Signal(32)
    self.b_word = Signal(32)
    self.result = Signal(signed(32))

  def elab(self, m):
    def lanes(word):
      # Four 8-bit slices, lowest byte first, reinterpreted as signed.
      return [word[i:i+8].as_signed() for i in range(0, 32, 8)]

    products = [(a + 128) * b
                for a, b in zip(lanes(self.a_word), lanes(self.b_word))]
    m.d.comb += self.result.eq(sum(products))


class WordMultiplyAddTest(TestBase):
  """Checks WordMultiplyAdd on packed four-byte inputs."""
  def create_dut(self):
    return WordMultiplyAdd()

  def test(self):
    # `a` operands are packed with a -128 offset so that the DUT's +128
    # recovers the values written here; `b` operands are packed as-is.
    def pack_a(w, x, y, z):
      return pack_vals(w, x, y, z, offset=-128)

    def pack_b(w, x, y, z):
      return pack_vals(w, x, y, z, offset=0)

    # e.g. 99*-2 + 22*6 + 2*7 + 1*111 == 59
    cases = [
        (pack_a(99, 22, 2, 1), pack_b(-2, 6, 7, 111), 59),
        (pack_a(63, 161, 15, 0), pack_b(29, 13, 62, -38), 4850),
    ]

    def process():
      for a_word, b_word, want in cases:
        yield self.dut.a_word.eq(a_word)
        yield self.dut.b_word.eq(b_word)
        # Let the combinational result settle before sampling it.
        yield Delay(0.1)
        got = yield self.dut.result
        self.assertEqual(want, got)
        yield

    self.run_sim(process)

# Execute the WordMultiplyAdd tests; unittest prints the summary below.
runTests(WordMultiplyAddTest)

.
----------------------------------------------------------------------
Ran 1 test in 0.009s

OK


In [None]:
class WordMultiplyAccumulate(SimpleElaboratable):
  """Registered accumulator of four-lane byte multiply-adds.

  On each clock edge where `enable` is high, adds
  sum over lanes i of (a_i + 128) * b_i (the four signed bytes of
  `a_word` / `b_word`, lowest byte first) to `accumulator`. When `clear` is
  high, `accumulator` is set to 0 instead.
  """
  def __init__(self):
    # Packed inputs: four 8-bit lanes per 32-bit word.
    self.a_word = Signal(32)
    self.b_word = Signal(32)
    # Running total, updated in the sync domain.
    self.accumulator = Signal(signed(32))
    # Control inputs.
    self.enable = Signal()
    self.clear = Signal()
  def elab(self, m):
    # Split each word into four signed bytes, lowest byte first.
    a_bytes = [self.a_word[i:i+8].as_signed() for i in range(0, 32, 8)]
    b_bytes = [self.b_word[i:i+8].as_signed() for i in range(0, 32, 8)]
    calculations = ((a + 128) * b for a, b in zip(a_bytes, b_bytes))
    summed = sum(calculations)
    with m.If(self.enable):
      m.d.sync += self.accumulator.eq(self.accumulator + summed)
    # Deliberately placed after the enable branch: in Amaranth a later
    # assignment to the same signal takes precedence, so clear wins when
    # both enable and clear are high.
    with m.If(self.clear):
      m.d.sync += self.accumulator.eq(0)


class WordMultiplyAccumulateTest(TestBase):
  """Cycle-by-cycle check of WordMultiplyAccumulate against a fixed table."""
  def create_dut(self):
    return WordMultiplyAccumulate()

  def test(self):
    # `a` operands are packed with a -128 offset so the DUT's +128 recovers
    # the values written in the table; `b` operands are packed as-is.
    def a(a, b, c, d): return pack_vals(a, b, c, d, offset=-128)
    def b(a, b, c, d): return pack_vals(a, b, c, d, offset=0)
    DATA = [
        # (a_word, b_word, enable, clear), expected accumulator
        # `expected` is sampled before this row's clock edge, so it reflects
        # only the inputs applied on earlier rows.
        ((a(0, 0, 0, 0),  b(0, 0, 0, 0), 0, 0), 0),

        # Simple tests: with just first byte
        ((a(10, 0, 0, 0), b(3, 0, 0, 0),  1, 0),   0),
        ((a(11, 0, 0, 0), b(-4, 0, 0, 0), 1, 0),  30),
        ((a(11, 0, 0, 0), b(-4, 0, 0, 0), 0, 0), -14),
        # The previous row had enable=0, so the accumulator is unchanged
        ((a(11, 0, 0, 0), b(-4, 0, 0, 0), 1, 0), -14),
        # The previous row had enable=1, so the accumulator changes
        # (-14 + 11 * -4 = -58)
        ((a(11, 0, 0, 0), b(-4, 0, 0, 0), 0, 1), -58),
        # Accumulator cleared by the previous row's clear=1
        ((a(11, 0, 0, 0), b(-4, 0, 0, 0), 0, 0),  0),

        # Uses all bytes (calculated on a spreadsheet)
        ((a(99, 22, 2, 1),      b(-2, 6, 7, 111), 1, 0),             0),
        ((a(2, 45, 79, 22),     b(-33, 6, -97, -22), 1, 0),         59),
        ((a(23, 34, 45, 56),    b(-128, -121, 119, 117), 1, 0),  -7884),
        ((a(188, 34, 236, 246), b(-87, 56, 52, -117), 1, 0),     -3035),
        ((a(131, 92, 21, 83),   b(-114, -72, -31, -44), 1, 0),  -33997),
        ((a(74, 68, 170, 39),   b(102, 12, 53, -128), 1, 0),    -59858),
        ((a(16, 63, 1, 198),    b(29, 36, 106, 62), 1, 0),      -47476),
        ((a(0, 0, 0, 0),        b(0, 0, 0, 0), 0, 1),           -32362),

        # Regression cases, one lane at a time, starting with a = 128.
        # NOTE(review): presumably these reproduce a previously seen bug;
        # TODO document which one.
        ((a(128, 0, 0, 0), b(-104, 0, 0, 0), 1, 0), 0),
        ((a(0, 51, 0, 0), b(0, 43, 0, 0), 1, 0), -13312),
        ((a(0, 0, 97, 0), b(0, 0, -82, 0), 1, 0), -11119),
        ((a(0, 0, 0, 156), b(0, 0, 0, -83), 1, 0), -19073),
        ((a(0, 0, 0, 0), b(0, 0, 0, 0), 1, 0), -32021),
    ]

    dut = self.dut

    def process():
        for (a_word, b_word, enable, clear), expected in DATA:
            yield dut.a_word.eq(a_word)
            yield dut.b_word.eq(b_word)
            yield dut.enable.eq(enable)
            yield dut.clear.eq(clear)
            yield Delay(0.1)  # Wait for input values to settle

            # Check the accumulator, as calculated on the previous cycle
            self.assertEqual(expected, (yield dut.accumulator))
            yield Tick()
    self.run_sim(process)

# Execute the WordMultiplyAccumulate tests; unittest prints the summary below.
runTests(WordMultiplyAccumulateTest)

.
----------------------------------------------------------------------
Ran 1 test in 0.024s

OK


In [None]:
class WordMultiplyAccumulateInstruction(InstructionBase):
  """Placeholder for a CFU instruction built around WordMultiplyAccumulate.

  TODO: not yet implemented. This docstring is the class body so that the
  cell parses and Restart & Run All can proceed; a class statement with no
  body is a SyntaxError.
  """

In [None]:
class SyncAndComb(Elaboratable):
  """Shows a sync (registered) and a comb (wired) assignment side by side.

  A 2-bit counter increments on every clock edge, wrapping around. `out` is
  combinationally driven by the counter's top bit, so it toggles every two
  cycles.
  """
  def __init__(self):
    self.out = Signal(1)
    # Ports passed to verilog.convert() by convert_elaboratable().
    self.ports = [self.out]
  def elaborate(self, platform):
    m = Module()
    counter = Signal(2)
    # sync domain: updated on each clock edge (becomes a register).
    m.d.sync += counter.eq(counter + 1)
    # comb domain: continuous assignment from the counter's MSB.
    m.d.comb += self.out.eq(counter[-1])
    return m
# Show the Verilog generated for SyncAndComb.
print(convert_elaboratable(SyncAndComb()))

In [None]:
class ConditionalEnable(Elaboratable):
  """5-bit up/down counter driven by `up` and `down` enables.

  On each clock edge the counter increments if `up` is high; otherwise it
  decrements if `down` is high; otherwise it holds. When both are high,
  `up` takes priority because it is tested first.
  """
  def __init__(self):
    self.up = Signal()
    self.down = Signal()
    self.value = Signal(5)
    # Ports passed to verilog.convert() by convert_elaboratable().
    self.ports = [self.value, self.up, self.down]

  def elaborate(self, platform):
    m = Module()
    with m.If(self.up):
      # The wider sum is truncated to 5 bits on assignment, so it wraps.
      m.d.sync += self.value.eq(self.value + 1)
    with m.Elif(self.down):
      m.d.sync += self.value.eq(self.value - 1)
    return m

# Show the Verilog generated for ConditionalEnable.
print(convert_elaboratable(ConditionalEnable()))
    

In [None]:
# Edge detector and test cases
class EdgeDetector(SimpleElaboratable):
  """Flags rising (low-to-high) transitions of `input`.

  `detected` is driven combinationally. It is high while `input` is 1 and
  the value captured at the previous clock edge was 0.
  """
  def __init__(self):
    self.input = Signal()
    self.detected = Signal()

  def elab(self, m):
    # `input` delayed by one clock edge.
    last = Signal()
    m.d.sync += last.eq(self.input)
    rising = self.input & ~last
    m.d.comb += self.detected.eq(rising)
    
class EdgeDetectorTestCase(TestBase):
  """Drives EdgeDetector with a fixed input sequence and checks `detected`.

  Derives from TestBase, which the setup cell imports from amaranth_cfu;
  `AmaranthTestBase` is not defined anywhere in this notebook.
  """
  def create_dut(self):
    return EdgeDetector()

  def test_with_table(self):
    # (input for this cycle, expected `detected` sampled before the clock edge)
    TEST_CASE = [
      (0, 0),
      (1, 1),
      (0, 0),
      (0, 0),
      (1, 1),
      (1, 0),
      (0, 0),
    ]
    def process():
      # `value` rather than `input`, to avoid shadowing the builtin.
      for (value, expected) in TEST_CASE:
        # Set input
        yield self.dut.input.eq(value)
        # Allow some time for signals to propagate
        yield Delay(0.1)
        self.assertEqual(expected, (yield self.dut.detected))
        yield
    self.run_sim(process)

# TestCase instances have no runTests() method; use the notebook's
# runTests(klazz) helper, as the other test cells do.
runTests(EdgeDetectorTestCase)