<a href="https://colab.research.google.com/github/msaligane/US_Japan_Semiconductor_Workshop/blob/main/Day%202%20-%201355%20-%20XLS%20-%20High-Level%20Synthesis/learn-xls-in-y-minutes.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Learn XLS in Y minutes

In the spirit of https://learnxinyminutes.com/, this notebook shows the basics of the [XLS toolchain](https://google.github.io/xls/) and the [DSLX language](https://google.github.io/xls/dslx_reference/) through heavily annotated interactive examples.

In [None]:
#@title Setup {run:"auto"}

# Release tag of the prebuilt XLS toolchain to fetch (editable through the Colab form field).
xls_version = 'v0.0.0-4699-gfb023174' #@param {type:"string"}

# Download the release tarball for the selected version and unpack it into the current
# working directory, dropping the archive's top-level directory (`--strip-components=1`).
!echo '📦 downloading xls-{xls_version}'
!curl --show-error -L https://github.com/proppy/xls/releases/download/{xls_version}/xls-{xls_version}-linux-x64.tar.gz | tar xzf - --strip-components=1
# Install the Colab integration wheel published with the same release.
!echo '🧪 setting up colab integration'
!python -m pip install --quiet --no-cache-dir --ignore-installed https://github.com/proppy/xls/releases/download/{xls_version}/xls_colab-0.0.0-py3-none-any.whl
import xls.contrib.colab
# Register the DSLX cell magic; the `%%dslx` cells below presumably rely on it.
_ = xls.contrib.colab.register_dslx_magic()

In [None]:
#@title {run:"auto"}

# Select the PDK (process design kit) setting of the Colab integration;
# the form field restricts the choice to "asap7" or "sky130".
xls.contrib.colab.pdk = 'asap7' #@param ["asap7", "sky130"] {allow-input: false}

In [None]:
%%dslx --top=adder8 --pipeline_stages=1 --flop_inputs=false --flop_outputs=false
// set `adder8` ↑ as the design entrypoint and specify combinational circuit code generation.

// define a function named `adder8` with two parameters `a` and `b` and one return value (after `->`).
fn adder8(a: u8, b: u8) -> u8 { // `u8` is the type for 8-bit unsigned integers.
                                // similar types such as `u1`, `u2`, `u3` or `s4` (signed 4-bit integer) are pre-defined.
                                // more explicit forms such as `uN[5]` (equivalent to `u5`) or `bits[6]` (equivalent to `u6`) are also permitted.
  // the last expression of a function denotes its return value.
  a + b
}

// unit tests are plain functions tagged with the `#[test]` annotation; this one is named `adders_test`.
#[test]
fn adders_test() {
  // bind the function's return value first, then compare it against the expected value.
  let sum = adder8(u8:41, u8:1);
  assert_eq(sum, u8:42);
}

// the interpreter tab ↓ shows the test function passing (all asserts succeeded) and the verilog tab shows the generated SystemVerilog.

In [None]:
%%dslx --top=adder8_with_carry --pipeline_stages=1 --flop_inputs=false --flop_outputs=false

fn adder8_with_carry(a: u8, b: u8) -> (u8, u1) { // the return type is a tuple: an 8-bit sum (`u8`) and a 1-bit carry (`u1`).
  // use `let` to bind intermediate results to identifiers.
  // both operands are widened to 9 bits before the addition.
  let wide_a = a as u9;
  let wide_b = b as u9;
  let wide_sum: u9 = wide_a + wide_b;
  // `[0:8]` selects a range of bits from bit `0` (inclusive) to bit `8` (exclusive).
  let low_byte = wide_sum[0:8];
  // `[8+:u1]` selects a `u1`-sized range of bits (1 bit) starting at bit `8`.
  let carry_out = wide_sum[8+:u1];

  // the tuple in last position is the return value.
  (low_byte, carry_out)
}

#[test]
fn adders_test() {
  // no overflow: the carry bit is 0.
  let (sum, carry) = adder8_with_carry(u8:41, u8:1);
  assert_eq((sum, carry), (u8:42, u1:0));
  // overflow: the 8-bit sum is 0 and the carry bit is 1.
  let (sum, carry) = adder8_with_carry(u8:255, u8:1);
  assert_eq((sum, carry), (u8:0, u1:1));
}

In [None]:
%%dslx --top=adder8_with_carry --pipeline_stages=1 --flop_inputs=false --flop_outputs=false

import std; // import the package `std` to reuse existing functionality.

fn adder8_with_carry(a: u8, b: u8) -> (u8, u1) {
  // reuse `std::uadd` rather than widening and adding by hand; the result is bound as a `u9`.
  let wide_sum: u9 = std::uadd(a, b);
  // bits 0..8 are the sum, bit 8 is the carry.
  (wide_sum[0:8], wide_sum[8+:u1])
}

#[test]
fn adders_test() {
  // carry = 0
  let (sum, carry) = adder8_with_carry(u8:41, u8:1);
  assert_eq((sum, carry), (u8:42, u1:0));
  // carry = 1
  let (sum, carry) = adder8_with_carry(u8:255, u8:1);
  assert_eq((sum, carry), (u8:0, u1:1));
}

In [None]:
%%dslx --top=adder8_with_carry --pipeline_stages=1 --flop_inputs=false --flop_outputs=false

import std;

// `adder_with_carry` is parametric in the bit-length `N` (of type u32):
// the sum is N bits long and the carry is 1 bit long.
fn adder_with_carry<N: u32>(a: uN[N], b: uN[N]) -> (uN[N], u1) {
  // `std::uadd` is parametric too, so the type of `wide_sum` is inferred from the arguments.
  let wide_sum = std::uadd(a, b);
  // slice bounds are signed values, hence the `as s32` cast on the upper bound;
  // the carry is the single bit located at index `N`.
  (wide_sum[0:N as s32], wide_sum[N+:u1])
}

// 8-bit wrapper around the parametric `adder_with_carry`.
fn adder8_with_carry(a: u8, b: u8) -> (u8, u1) {
  // set the bit-length parameter explicitly from a literal.
  adder_with_carry<u32:8>(a, b)
}

// define a module level constant for the bit-length.
const BIT_LENGTH = u32:4;

#[test]
fn adders_test() {
  // the bit-length parameter is set explicitly from a constant.
  let from_constant = adder_with_carry<BIT_LENGTH>(u4:14, u4:1);
  assert_eq(from_constant, (u4:15, u1:0));
  // the bit-length parameter is inferred from the argument types.
  let inferred = adder_with_carry(u8:255, u8:1);
  assert_eq(inferred, (u8:0, u1:1));
}

In [None]:
%%dslx --top=muladd8 --pipeline_stages=2
// set number of pipeline stages ↑ explicitly to 2.

import std;

// multiply `a` by `b`, add `c`, and return the 2N-bit result together with the carry-out of the addition.
fn muladd<N: u32>(a: uN[N], b: uN[N], c: uN[N]) -> (uN[N+N], u1) {
  let product = std::umul(a, b); // std::umul is parametric: result type is inferred as uN[N*2].
  let sum_with_carry = std::uadd(product, c); // one bit wider than `product`.
  let sum = sum_with_carry[0:N as s32 * s32:2]; // slice indexes are signed values.
  // the carry-out sits right above the 2N-bit sum, i.e. at bit index 2N
  // (bit N is still part of the sum itself).
  let carry_index = N + N;
  let carry_bit = sum_with_carry[carry_index+:u1];
  (sum, carry_bit)
}

// 8-bit entry point (the `--top` of this cell); the bit-length of `muladd` is inferred from the arguments.
fn muladd8(a: u8, b: u8, c: u8) -> (u16, u1) {
  muladd(a, b, c)
}

#[test]
fn adders_test() {
  // 1 * 2 + 3 = 5, without carry.
  let (sum, carry) = muladd8(u8:1, u8:2, u8:3);
  assert_eq((sum, carry), (u16:5, u1:0));
}

// schedule tab ↓ shows two pipeline stages: with `product` in stage 0 and `sum_with_carry` in stage 1.

In [None]:
%%dslx --top=muladd8 --clock_period_ps=1000
// set the target clock period to `1000ps` and let XLS schedule the pipeline stages.

import std;

// multiply `a` by `b`, add `c`, and return the 2N-bit result together with the carry-out of the addition.
fn muladd<N: u32>(a: uN[N], b: uN[N], c: uN[N]) -> (uN[N+N], u1) {
  let product = std::umul(a, b);
  let sum_with_carry = std::uadd(product, c); // one bit wider than `product`.
  let sum = sum_with_carry[0:N as s32 * s32:2];
  // the carry-out sits right above the 2N-bit sum, i.e. at bit index 2N
  // (bit N is still part of the sum itself).
  let carry_index = N + N;
  let carry_bit = sum_with_carry[carry_index+:u1];
  (sum, carry_bit)
}

// 8-bit entry point (the `--top` of this cell); the bit-length of `muladd` is inferred from the arguments.
fn muladd8(a: u8, b: u8, c: u8) -> (u16, u1) {
  muladd(a, b, c)
}

#[test]
fn adders_test() {
  // 1 * 2 + 3 = 5, without carry.
  let (sum, carry) = muladd8(u8:1, u8:2, u8:3);
  assert_eq((sum, carry), (u16:5, u1:0));
}

// schedule tab ↓ shows the same 2 pipeline stages: with `product` in stage 0 and `sum_with_carry` in stage 1.

In [None]:
%%dslx --top=muladd8 --clock_period_ps=850
// set a more aggressive target clock period of `850ps`.

import std;

// multiply `a` by `b`, add `c`, and return the 2N-bit result together with the carry-out of the addition.
fn muladd<N: u32>(a: uN[N], b: uN[N], c: uN[N]) -> (uN[N+N], u1) {
  let product = std::umul(a, b);
  let sum_with_carry = std::uadd(product, c); // one bit wider than `product`.
  let sum = sum_with_carry[0:N as s32 * s32:2];
  // the carry-out sits right above the 2N-bit sum, i.e. at bit index 2N
  // (bit N is still part of the sum itself).
  let carry_index = N + N;
  let carry_bit = sum_with_carry[carry_index+:u1];
  (sum, carry_bit)
}

// 8-bit entry point (the `--top` of this cell); the bit-length of `muladd` is inferred from the arguments.
fn muladd8(a: u8, b: u8, c: u8) -> (u16, u1) {
  muladd(a, b, c)
}

#[test]
fn adders_test() {
  // 1 * 2 + 3 = 5, without carry.
  let (sum, carry) = muladd8(u8:1, u8:2, u8:3);
  assert_eq((sum, carry), (u16:5, u1:0));
}

// XLS can't schedule the pipeline stages because the delay of the multiplier (`883ps` for this technology) exceeds the target clock period of `850ps`.

In [None]:
%%dslx --top=muladd8 --clock_period_ps=850

import std;

// multiply `a` by `b`, add `c`, and return the low N bits of the result plus the carry-out of the addition.
// note: the product is sliced down to N bits (`product[0+:uN[N]]`) before `c` is added, so the
// returned carry only reflects the final addition, not an overflow of the multiplication.
fn muladd<N: u32>(a: uN[N], b: uN[N], c: uN[N]) -> (uN[N], u1) {
  let (product_a, product_b) = umulp(a, b); // uses `umulp` partial multiplication to allow additional pipelining.
  // the two partial products are summed to obtain the product.
  let product = std::uadd(product_a, product_b);
  let sum_with_carry = std::uadd(product[0+:uN[N]], c);
  let sum = sum_with_carry[0:N as s32]; // slice indexes are signed values.
  let carry_bit = sum_with_carry[N+:u1];
  (sum, carry_bit)
}

// 8-bit entry point (the `--top` of this cell); the bit-length of `muladd` is inferred from the arguments.
fn muladd8(a: u8, b: u8, c: u8) -> (u8, u1) {
  muladd(a, b, c)
}

#[test]
fn adders_test() {
  // 1 * 2 + 3 = 5, without carry.
  let (sum, carry) = muladd8(u8:1, u8:2, u8:3);
  assert_eq((sum, carry), (u8:5, u1:0));
}

// schedule tab ↓ shows 2 pipeline stages with a shorter path_delay: `801ps`.

In [None]:
%%dslx --top=muladd_accumulate --clock_period_ps=850 --reset=reset
import std;

// multiply `a` by `b`, add `c`, and return `(carry, sum)`: the carry-out of the addition and the low N bits of the result.
// note: the product is sliced down to N bits (`product[0+:uN[N]]`) before `c` is added, so the
// returned carry only reflects the final addition, not an overflow of the multiplication.
fn muladd<N: u32>(a: uN[N], b: uN[N], c: uN[N]) -> (u1, uN[N]) {
  let (product_a, product_b) = umulp(a, b);
  // the two partial products are summed to obtain the product.
  let product = std::uadd(product_a, product_b);
  let sum_with_carry = std::uadd(product[0+:uN[N]], c);
  let sum = sum_with_carry[0:N as s32];
  let carry_bit = sum_with_carry[N+:u1];
  (carry_bit, sum)
}

// define a proc named `muladd_accumulate` implementing a multiplier-accumulator.
// procs express stateful sequential logic, similar to an `always_ff` block with registers in SystemVerilog.
proc muladd_accumulate {
    // define `input_a` and `input_b` as 8-bit input channels.
    // channels are uni-directional data FIFOs that allow communication between procs;
    // they materialize as data buses with ready/valid signals in SystemVerilog.
    input_a: chan<u8> in;
    input_b: chan<u8> in;
    // define `output` as an output channel carrying a `(carry, 8-bit result)` tuple.
    output: chan<(u1, u8)> out;

    // `init` returns `0` as the initial 8-bit state (the accumulator value) of the sequential logic.
    init {
        u8:0
    }

    // `config` takes external channels as parameters and returns the channels used by the sequential logic.
    // the order of the returned channels has to match the proc channel definitions.
    // the proc can `recv` values from the `in` channels and `send` values to the `out` channel,
    // similar to the module input and output definitions in SystemVerilog.
    config(input_a: chan<u8> in, input_b: chan<u8> in, output: chan<(u1, u8)> out) {
        (input_a, input_b, output)
    }

    // `next` takes the current state (the current accumulated value) as a parameter
    // and returns the next state (the newly accumulated value).
    // the `tok` parameter is used to synchronize and order I/O operations on the channels of the proc:
    // if two I/O operations use the same token, they happen concurrently.
    // each operation returns a new `token` that can be used to sequence further operations.
    next(tok: token, acc: u8) {
        // receive one 8-bit value from the first input channel.
        let (tok_a, a) = recv(tok, input_a);
        // receive one 8-bit value from the second input channel.
        // `tok` is used in both `recv`s so those happen concurrently.
        let (tok_b, b) = recv(tok, input_b);
        // wait for both receive operations to complete.
        let tok_c = join(tok_a, tok_b);
        // multiply `a` and `b` and add (accumulate) them to the previous `acc` value;
        // `c` is the carry and `n` the new accumulated value.
        let (c, n) = muladd(a, b, acc);
        // send the carry and the result to the output channel.
        send(tok_c, output, (c, n));
        // return the accumulated result as the new state.
        n
    }
}

In [None]:
%%dslx --top=muladd_accumulate --clock_period_ps=850 --reset=reset
import std;

// multiply `a` by `b`, add `c`, and return `(carry, sum)`: the carry-out of the addition and the low N bits of the result.
// note: the product is sliced down to N bits (`product[0+:uN[N]]`) before `c` is added, so the
// returned carry only reflects the final addition, not an overflow of the multiplication.
fn muladd<N: u32>(a: uN[N], b: uN[N], c: uN[N]) -> (u1, uN[N]) {
  let (product_a, product_b) = umulp(a, b);
  // the two partial products are summed to obtain the product.
  let product = std::uadd(product_a, product_b);
  let sum_with_carry = std::uadd(product[0+:uN[N]], c);
  let sum = sum_with_carry[0:N as s32];
  let carry_bit = sum_with_carry[N+:u1];
  (carry_bit, sum)
}

// multiplier-accumulator proc: each activation receives one value from each input channel
// and sends the `muladd` result (carry, new accumulator value) on the output channel.
proc muladd_accumulate {
    input_a: chan<u8> in;
    input_b: chan<u8> in;
    output: chan<(u1, u8)> out;

    // the accumulator starts at zero.
    init { u8:0 }

    config(input_a: chan<u8> in, input_b: chan<u8> in, output: chan<(u1, u8)> out) {
        (input_a, input_b, output)
    }

    next(tok: token, acc: u8) {
        // both receives use `tok`, so neither is ordered before the other.
        let (a_done, lhs) = recv(tok, input_a);
        let (b_done, rhs) = recv(tok, input_b);
        let inputs_done = join(a_done, b_done);
        let (carry, total) = muladd(lhs, rhs, acc);
        // the send is ordered after both receives.
        send(inputs_done, output, (carry, total));
        // the new total becomes the next accumulator state.
        total
    }
}

// define a test proc to test the sequential logic of the `muladd_accumulate` proc.
// note: XLS does not currently generate the equivalent SystemVerilog testbench.
#[test_proc]
proc muladd_accumulate_test {
    // define the channels used to communicate with the proc under test.
    input_a_s: chan<u8> out;
    input_b_s: chan<u8> out;
    output_r: chan<(u1, u8)> in;
    // define one output channel to terminate the test proc.
    terminator: chan<bool> out;

    init {
        () // no state
    }

    // `config` takes the `terminator` output channel as an argument.
    config(terminator: chan<bool> out) {
      // define a channel pair (sender, receiver) to communicate `input_a` values.
      let (input_a_s, input_a_r) = chan<u8>;
      // define a channel pair (sender, receiver) to communicate `input_b` values.
      let (input_b_s, input_b_r) = chan<u8>;
      // define a channel pair (sender, receiver) to communicate `output` values.
      let (output_s, output_r) = chan<(u1, u8)>;
      // spawn the `muladd_accumulate` proc with the receiving ends of the `input_a` and `input_b` channels (to receive inputs)
      // and the sending end of the `output` channel (to send the result).
      // this is equivalent to "instantiating" the underlying SystemVerilog module of the corresponding proc.
      spawn muladd_accumulate(input_a_r, input_b_r, output_s);
      // return the sending ends of the `input_a` and `input_b` channels (to send inputs),
      // the receiving end of the `output` channel (to get the result)
      // and the terminator channel.
      (input_a_s, input_b_s, output_r, terminator)
    }

    // `next` performs the actual test.
    next(tok: token, state: ()) {
      // send `1` as the first input.
      let tok_a = send(tok, input_a_s, u8:1);
      // send `2` as the second input (concurrently).
      let tok_b = send(tok, input_b_s, u8:2);
      // wait for both sends to complete (and overwrite the existing `tok` binding).
      let tok = join(tok_a, tok_b);
      // receive and assert the result value: 0 + 1 * 2 = 2, no carry.
      let (tok, (c, n)) = recv(tok, output_r);
      assert_eq(n, u8:2);
      assert_eq(c, u1:0);

      // send `3` as the first input.
      let tok_a = send(tok, input_a_s, u8:3);
      // send `4` as the second input (concurrently).
      let tok_b = send(tok, input_b_s, u8:4);
      // wait for both sends to complete (and overwrite the existing `tok` binding).
      let tok = join(tok_a, tok_b);
      // receive and assert the accumulated value: 2 + 3 * 4 = 14, no carry.
      let (tok, (c, n)) = recv(tok, output_r);
      assert_eq(n, u8:14);
      assert_eq(c, u1:0);

      // terminate the test proc.
      send(tok, terminator, true);
    }
}

In [None]:
%%dslx --top=muladd_accumulate --clock_period_ps=850 --reset=reset
import std;

// multiply `a` by `b`, add `c`, and return `(carry, sum)`: the carry-out of the addition and the low N bits of the result.
// note: the product is sliced down to N bits (`product[0+:uN[N]]`) before `c` is added, so the
// returned carry only reflects the final addition, not an overflow of the multiplication.
fn muladd<N: u32>(a: uN[N], b: uN[N], c: uN[N]) -> (u1, uN[N]) {
  let (product_a, product_b) = umulp(a, b);
  // the two partial products are summed to obtain the product.
  let product = std::uadd(product_a, product_b);
  let sum_with_carry = std::uadd(product[0+:uN[N]], c);
  let sum = sum_with_carry[0:N as s32];
  let carry_bit = sum_with_carry[N+:u1];
  (carry_bit, sum)
}

// multiplier-accumulator proc with an additional channel to reset the accumulator.
proc muladd_accumulate {
    input_a: chan<u8> in;
    input_b: chan<u8> in;
    // `reset` is a 1-bit input channel.
    reset: chan<bool> in;
    output: chan<(u1, u8)> out;

    // the accumulator starts at zero.
    init { u8:0 }

    config(input_a: chan<u8> in, input_b: chan<u8> in, reset: chan<bool> in, output: chan<(u1, u8)> out) {
        (input_a, input_b, reset, output)
    }

    next(tok: token, acc: u8) {
        let (a_done, lhs) = recv(tok, input_a);
        let (b_done, rhs) = recv(tok, input_b);
        let inputs_done = join(a_done, b_done);
        let (carry, total) = muladd(lhs, rhs, acc);
        let sent = send(inputs_done, output, (carry, total));

        // `recv_non_blocking` conditionally receives from a channel; its last parameter is the default value.
        // the returned tuple includes a flag telling whether a new value was actually received.
        let (reset_done, reset_value, reset_seen) = recv_non_blocking(tok, reset, false);
        let clear_acc = reset_seen && reset_value;
        if clear_acc {
          u8:0 // the new state is 0 (effectively resetting the accumulator).
        } else {
          total // the new state is the accumulated result.
        }
    }
}

// test proc exercising `muladd_accumulate`, including its `reset` channel.
#[test_proc]
proc muladd_accumulate_test {
    input_a_s: chan<u8> out;
    input_b_s: chan<u8> out;
    // define an additional output channel for sending reset requests.
    reset_s: chan<bool> out;
    output_r: chan<(u1, u8)> in;
    terminator: chan<bool> out;

    init {
        ()
    }

    config(terminator: chan<bool> out) {
      let (input_a_s, input_a_r) = chan<u8>;
      let (input_b_s, input_b_r) = chan<u8>;
      // define a channel pair (sender, receiver) to communicate `reset` values.
      let (reset_s, reset_r) = chan<bool>;
      let (output_s, output_r) = chan<(u1, u8)>;

      spawn muladd_accumulate(input_a_r, input_b_r, reset_r, output_s);
      (input_a_s, input_b_s, reset_s, output_r, terminator)
    }

    next(tok: token, state: ()) {
      // first round: 0 + 1 * 2 = 2.
      let tok_a = send(tok, input_a_s, u8:1);
      let tok_b = send(tok, input_b_s, u8:2);
      let tok = join(tok_a, tok_b);
      let (tok, (c, n)) = recv(tok, output_r);
      assert_eq(n, u8:2);
      assert_eq(c, u1:0);

      // second round: 2 + 3 * 4 = 14, with a reset request sent concurrently.
      let tok_a = send(tok, input_a_s, u8:3);
      let tok_b = send(tok, input_b_s, u8:4);
      let tok_reset = send(tok, reset_s, true);
      // wait for all sends to complete (and overwrite the existing `tok` binding).
      let tok = join(tok_a, tok_b, tok_reset);
      let (tok, (c, n)) = recv(tok, output_r);
      assert_eq(n, u8:14);
      assert_eq(c, u1:0);

      // send `1` as the first input.
      let tok_a = send(tok, input_a_s, u8:1);
      // send `2` as the second input (concurrently).
      let tok_b = send(tok, input_b_s, u8:2);
      // wait for both sends to complete (and overwrite the existing `tok` binding).
      let tok = join(tok_a, tok_b);
      // receive and assert the accumulated value.
      let (tok, (c, n)) = recv(tok, output_r);
      // assert that the accumulator was reset: 0 + 1 * 2 = 2.
      assert_eq(n, u8:2);
      assert_eq(c, u1:0);

      // terminate the test proc.
      send(tok, terminator, true);
    }
}

In [None]:
%%dslx --top=muladd_accumulate --clock_period_ps=850 --reset=reset
import std;

// multiply `a` by `b`, add `c`, and return `(carry, sum)`: the carry-out of the addition and the low N bits of the result.
// note: the product is sliced down to N bits (`product[0+:uN[N]]`) before `c` is added, so the
// returned carry only reflects the final addition, not an overflow of the multiplication.
fn muladd<N: u32>(a: uN[N], b: uN[N], c: uN[N]) -> (u1, uN[N]) {
  let (product_a, product_b) = umulp(a, b);
  // the two partial products are summed to obtain the product.
  let product = std::uadd(product_a, product_b);
  let sum_with_carry = std::uadd(product[0+:uN[N]], c);
  let sum = sum_with_carry[0:N as s32];
  let carry_bit = sum_with_carry[N+:u1];
  (carry_bit, sum)
}

// define the `Op` enum to support multiple types of operations.
// the underlying type is `u2`: three of its four possible values are named here.
enum Op: u2 {
  MUL = 0,
  MUL_ACC = 1,
  RESET_ACC = 2,
}

// define the `Command` struct to encapsulate an operation together with its two 8-bit operands.
struct Command {
  op: Op,
  a: u8,
  b: u8,
}

// command-driven multiplier-accumulator: each activation receives one `Command`,
// sends one `(carry, result)` tuple and updates the accumulator according to `op`.
proc muladd_accumulate {
    // a single input channel carries whole commands.
    input: chan<Command> in;
    output: chan<(u1, u8)> out;

    // the accumulator starts at zero.
    init { u8:0 }

    config(input: chan<Command> in, output: chan<(u1, u8)> out) {
        (input, output)
    }

    next(tok: token, acc: u8) {
        // receive the current command.
        let (tok, cmd) = recv(tok, input);
        // each `op` maps to a (carry, result, new accumulator) triple.
        // the `'` naming suffix is a convention denoting a new "version" of an existing binding.
        let (carry, result, acc') = match (cmd.op) {
          Op::MUL => {
            let (overflow, product) = std::umul_with_overflow<u32:8>(cmd.a, cmd.b);
            // the accumulator keeps its current value.
            (overflow, product, acc)
          },
          Op::MUL_ACC => {
            let (carry, total) = muladd(cmd.a, cmd.b, acc);
            // the `muladd` result is also the new accumulator value.
            (carry, total, total)
          },
          Op::RESET_ACC => {
            // the accumulator is cleared.
            (false, u8:0, u8:0)
          },
          // catch-all pattern covering unsupported ops.
          _ => fail!("unsupported_op", (false, u8:0, u8:0))
        };
        send(tok, output, (carry, result));
        // the new accumulator value is the next state.
        acc'
    }
}

// test proc driving the command-based `muladd_accumulate` proc through a sequence of commands.
#[test_proc]
proc muladd_accumulate_test {
    input_s: chan<Command> out;
    output_r: chan<(u1, u8)> in;
    terminator: chan<bool> out;

    init {
        ()
    }

    config(terminator: chan<bool> out) {
      let (input_s, input_r) = chan<Command>;
      let (output_s, output_r) = chan<(u1, u8)>;

      spawn muladd_accumulate(input_r, output_s);
      (input_s, output_r, terminator)
    }

    next(tok: token, state: ()) {
      // send a `MUL` command: 1 * 2 = 2.
      let tok = send(tok, input_s, Command{
        op: Op::MUL, a: u8:1, b: u8:2
      });
      let (tok, (c, n)) = recv(tok, output_r);
      assert_eq(n, u8:2);
      assert_eq(c, u1:0);

      // send a `MUL_ACC` command.
      let tok = send(tok, input_s, Command{
        op: Op::MUL_ACC, a: u8:1, b: u8:2
      });
      let (tok, (c, n)) = recv(tok, output_r);
      // assert that the previous `MUL` operation was not accumulated: 0 + 1 * 2 = 2.
      assert_eq(n, u8:2);
      assert_eq(c, u1:0);

      // send a `RESET_ACC` command; the operands are not used for this op.
      let tok = send(tok, input_s, Command{
        op: Op::RESET_ACC, a: u8:0, b: u8:0
      });
      let (tok, (c, n)) = recv(tok, output_r);
      assert_eq(n, u8:0);
      assert_eq(c, u1:0);

      // send a `MUL_ACC` command.
      let tok = send(tok, input_s, Command{
        op: Op::MUL_ACC, a: u8:1, b: u8:2
      });
      let (tok, (c, n)) = recv(tok, output_r);
      // assert that the accumulator was reset: 0 + 1 * 2 = 2.
      assert_eq(n, u8:2);
      assert_eq(c, u1:0);

      // terminate the test proc.
      send(tok, terminator, true);
    }
}

# Bonus examples

## simple counter

In [None]:
%%dslx --top=counter --clock_period_ps=850 --reset=reset

// 8-bit counter: each activation sends the current count on `output`
// and uses the count plus one as the next state.
proc counter {
  output: chan<u8> out;

  // counting starts at zero.
  init { u8:0 }

  config(output: chan<u8> out) {
    (output,)
  }

  next(tok: token, state: u8) {
    let incremented = state + u8:1;
    let tok = send(tok, output, state);
    incremented
  }
}

// test proc checking the first four values produced by `counter`.
#[test_proc]
proc test {
  counter_r: chan<u8> in;
  terminator: chan<bool> out;

  init { () }

  config(terminator: chan<bool> out) {
    let (counter_s, counter_r) = chan<u8>;
    spawn counter(counter_s);
    (counter_r, terminator)
  }

  next(tok: token, state: ()) {
    // the receives are chained through `tok`, so the values are read in order.
    let (tok, first) = recv(tok, counter_r);
    assert_eq(first, u8:0);
    let (tok, second) = recv(tok, counter_r);
    assert_eq(second, u8:1);
    let (tok, third) = recv(tok, counter_r);
    assert_eq(third, u8:2);
    let (tok, fourth) = recv(tok, counter_r);
    assert_eq(fourth, u8:3);
    // terminate the test proc.
    let tok = send(tok, terminator, true);
  }
}

## streaming protobuf varint decoder

In [None]:
%%dslx --top=varint_streaming_u32_decode --reset=rst --clock_period_ps=660 --worst_case_throughput=4
import std;

// Number of bytes carried by one input message.
const INPUT_BYTES = u32:5;
// Maximum number of decoded words carried by one output message.
const OUTPUT_WORDS = u32:2;
// Size, in bytes, of the large shift applied to the work chunk (the small shift is 1 byte).
const BIG_SHIFT = u32:4;
// Bit widths of the counters: the `+ u32:1` makes room for the count value itself
// (e.g. a length of INPUT_BYTES), not just for indexes.
const INPUT_BYTES_WIDTH = std::clog2(INPUT_BYTES + u32:1);
const OUTPUT_WORDS_WIDTH = std::clog2(OUTPUT_WORDS + u32:1);
const BIG_SHIFT_WIDTH = std::clog2(BIG_SHIFT);
// Input bytes plus 4 additional bytes (the proc state keeps 4 leftover bytes in `old_bytes`).
const COMBINED_BYTES = INPUT_BYTES + u32:4;
const COMBINED_BYTES_WIDTH = std::clog2(COMBINED_BYTES + u32:1);

// Decodes one varint located at the start of `bytes`.
//
// Each byte contributes its low 7 bits, least-significant group first; the first byte
// whose most-significant bit is clear terminates the varint.
// Returns the decoded 32-bit value and the number of bytes consumed.
// Fails (via `fail!`) when no terminating byte is found or when the value does not fit in 32 bits.
pub fn varint_decode_u32<NUM_BYTES:u32={std::clog2(u32:32)},
                         LEN_WIDTH:u32={std::clog2(NUM_BYTES)}>(
 bytes: u8[NUM_BYTES]) -> (u32, uN[LEN_WIDTH]) {
  type LenType = uN[LEN_WIDTH];

  // Collect the 7-bit chunks up to (and including) the terminating byte,
  // remembering the index of that last chunk.
  let (chunks, last_chunk, saw_last_chunk) =
   for (i, (chunks, last_chunk, saw_last_chunk)):
   (u32, (u7[NUM_BYTES], LenType, bool)) in u32:0..NUM_BYTES {
    let current_byte = bytes[i];
    let lsbs = current_byte as u7;
    let msb = current_byte[-1:];
    if saw_last_chunk {
      // Bytes after the terminator are ignored.
      (chunks, last_chunk, saw_last_chunk)
    } else {
      (update(chunks, i, lsbs), i as LenType, !msb)
    }
  }((u7[NUM_BYTES]:[u7:0, ...], LenType:0, bool:0));

  if !saw_last_chunk { fail!("did_not_see_last_chunk", ()) } else { () };

  // Concatenate the chunks, chunk `i` landing at bit offset `7 * i`.
  const FLATTENED_CHUNK_BITS = NUM_BYTES * u32:7;
  type FlattenedChunkType = uN[FLATTENED_CHUNK_BITS];
  let flattened: FlattenedChunkType =
  for (i, flattened): (u32, FlattenedChunkType) in u32:0..NUM_BYTES {
    flattened | ((chunks[i]  as FlattenedChunkType) << (u32:7 * i))
  }(zero!<FlattenedChunkType>());

  // Any bit set above bit 31 means the value does not fit in a u32.
  const NUM_EXTRA_BITS = FLATTENED_CHUNK_BITS - u32:32;
  let msbs = (flattened >> u32:32) as uN[NUM_EXTRA_BITS];

  if msbs != uN[NUM_EXTRA_BITS]:0 {
    fail!("did_not_fit_in_u32", ())
  } else { () };

  // Bytes consumed = index of the last chunk + 1; `LenType:1` keeps the addition
  // well-typed for every `LEN_WIDTH`, not only for the default 3-bit width.
  (flattened as u32, last_chunk + LenType:1)
}

#[test]
fn varint_decode_u32_test() {
  // the bytes [172, 2] decode to 300; two bytes are consumed.
  let (value, used) = varint_decode_u32(u8[5]:[172, 2, 0, 0, 0]);
  assert_eq(u32:300, value);
  assert_eq(u3:2, used);
  // bytes located after the first complete varint do not change the result.
  let (value, used) = varint_decode_u32(u8[5]:[172, 2, 172, 2, 0]);
  assert_eq(u32:300, value);
  assert_eq(u3:2, used);
}

// Convenience for use with map(): boolean negation as a named function.
fn not(x: bool) -> bool { !x }

// Shifts a byte array left by the static amount `SHIFT`: element `i` of the result is
// `bytes[i + SHIFT]`, and positions whose source index falls past the end receive `fill`.
fn byte_array_shl<SHIFT:u32, N:u32>(bytes: u8[N], fill: u8) -> u8[N] {
  for (dst, shifted): (u32, u8[N]) in u32:0..N {
    let src = dst + SHIFT;
    let value = if src < N { bytes[src] } else { fill };
    update(shifted, dst, value)
  }(u8[N]:[u8:0, ...])
}

// State carried between activations of the streaming varint decoder proc.
struct State<NUM_BYTES:u32, NUM_BYTES_WIDTH:u32> {
  work_chunk: u8[NUM_BYTES],       // Chunk of bytes currently being worked on.
  len: uN[NUM_BYTES_WIDTH],        // Number of valid bytes in work_chunk.
  old_bytes: u8[4],                // Leftover bytes from a previous work chunk. Guaranteed not to
                                   // have any varint terminators, else they'd already be decoded.
  old_bytes_len: u3,               // Number of valid bytes in old_bytes. Invalid bytes start
                                   // with index 0.
  drop_count: uN[NUM_BYTES_WIDTH], // Number of bytes to drop from work_chunk/input.
}

// Streaming varint decoder.
//
// Receives messages of INPUT_BYTES bytes plus a byte count on `bytes_in`, and sends
// messages of up to OUTPUT_WORDS decoded u32 words plus a word count on `words_out`.
// A new input is only received once the current work chunk is empty (`state.len == 0`);
// otherwise the work chunk is shifted by 1 or BIG_SHIFT bytes per activation.
pub proc varint_streaming_u32_decode {
  bytes_in: chan<(u8[INPUT_BYTES], uN[INPUT_BYTES_WIDTH])> in;
  words_out: chan<(u32[OUTPUT_WORDS], uN[OUTPUT_WORDS_WIDTH])> out;

  config(bytes_in: chan<(u8[INPUT_BYTES], uN[INPUT_BYTES_WIDTH])> in,
         words_out: chan<(u32[OUTPUT_WORDS], uN[OUTPUT_WORDS_WIDTH])> out) {
    (bytes_in, words_out)
  }

  // The initial state is all zeros: empty work chunk, no leftover bytes, nothing to drop.
  init {
    type MyState = State<INPUT_BYTES, INPUT_BYTES_WIDTH>;
    zero!<MyState>()
  }

  next(tok: token, state: State<INPUT_BYTES, INPUT_BYTES_WIDTH>) {
    const_assert!(INPUT_BYTES >= OUTPUT_WORDS);
    const_assert!(BIG_SHIFT > u32:1 && BIG_SHIFT < INPUT_BYTES);

    type OutputWordArray = u32[OUTPUT_WORDS];
    type OutputIdx = uN[OUTPUT_WORDS_WIDTH];
    type InputIdx = uN[INPUT_BYTES_WIDTH];
    type BigShiftIdx = uN[BIG_SHIFT_WIDTH];
    type CombinedIdx = uN[COMBINED_BYTES_WIDTH];

    let terminators: bool[INPUT_BYTES] =  // terminators[i] -> work_chunk[i] terminates a varint (its MSB is clear)
      map(map(state.work_chunk, std::is_unsigned_msb_set), not);
    // Remove terminators on invalid bytes.
    let terminators = for (i, terminators): (u32, bool[INPUT_BYTES]) in u32:0..INPUT_BYTES {
      if i < state.len as u32 { terminators } else { update(terminators, i, false) }
    }(terminators);
    let num_terminators = std::popcount(std::convert_to_bits_msb0(terminators)) as InputIdx;

    // Find the index of the last terminator that will be decoded in this proc iteration.
    // We can decode OUTPUT_WORDS varints per iteration, so count up to OUTPUT_WORDS and stop.
    let (_, last_terminator_idx): (OutputIdx, u32) =
    for (idx, (word_count, last_idx)): (u32, (OutputIdx, u32)) in u32:0..INPUT_BYTES {
      if terminators[idx] && word_count as u32 < OUTPUT_WORDS {
        (word_count + OutputIdx:1, idx)
      } else { (word_count, last_idx) }
    }((OutputIdx:0, u32:0));

    // Get a new input once we've processed the entire work chunk.
    // When no input is received, the default is an all-zero chunk with length 0.
    let do_input = state.len == InputIdx:0;
    let (input_tok, (input_data, input_len)) = recv_if(
      tok, bytes_in, do_input, (u8[INPUT_BYTES]:[u8:0, ...], InputIdx:0));

    // Receiving a new input while bytes remain to be dropped is treated as an error.
    let do_drop = state.drop_count != InputIdx:0;
    if do_input && do_drop { fail!("input_and_drop", ()) } else { () };

    // Each iteration, we either shift by 1 or BIG_SHIFT. Do the shifts now and select later.
    let work_chunk_shl_1 = byte_array_shl<u32:1>(state.work_chunk, u8:0);
    let work_chunk_shl_big = byte_array_shl<BIG_SHIFT>(state.work_chunk, u8:0);

    // Do a big shift if we're dropping a big shift's worth of bytes or if the last terminator is
    // at or after the end of the big shift's window.
    let do_big_shift = (do_drop && state.drop_count as u32 >= BIG_SHIFT) ||
                       (!do_drop && last_terminator_idx as u32 >= BIG_SHIFT - u32:1);

    // compute next state
    // NOTE(review): in the big-shift branch the next drop count is derived from
    // `last_terminator_idx` even when the shift was triggered by `do_drop` — confirm this is intended.
    let next_drop_count = if do_input {
      InputIdx:0
    } else if do_big_shift {
      InputIdx:1 + last_terminator_idx as InputIdx - BIG_SHIFT as InputIdx
    } else if do_drop {
      state.drop_count - InputIdx:1
    } else {
      last_terminator_idx as InputIdx
    };

    // The valid length shrinks by the amount shifted out (BIG_SHIFT or 1), or is reloaded from the input.
    let next_len = if do_input {
      input_len
    } else if do_big_shift {
      state.len - BIG_SHIFT as InputIdx
    } else if do_drop {
      state.len - InputIdx:1
    } else {
      state.len - InputIdx:1
    };
    // work_chunk is either set to input, work_chunk << 1, or work_chunk << BIG_SHIFT
    let next_word_chunk = if do_input {
      input_data
     } else if do_big_shift {
      work_chunk_shl_big
     } else if do_drop {
      work_chunk_shl_1
     } else {
      work_chunk_shl_1
     };

    // Leftover bytes are kept as-is on input, marked empty on a big shift or when the work
    // chunk contains a terminator, and otherwise extended with the byte shifted out of the work chunk.
    let (next_old_bytes, next_old_bytes_len) = if do_input {
      (state.old_bytes, state.old_bytes_len)
    } else if do_big_shift {
      (state.old_bytes, u3:0)
    } else if num_terminators == InputIdx:0 {
      let next_old_bytes_len = if state.old_bytes_len < u3:4 { state.old_bytes_len + u3:1 } else { state.old_bytes_len };
      (byte_array_shl<u32:1>(state.old_bytes, state.work_chunk[0]), next_old_bytes_len)
    } else {
      (state.old_bytes, u3:0)
    };
    let next_state = State {
      work_chunk: next_word_chunk,
      len: next_len,
      old_bytes: next_old_bytes,
      old_bytes_len: next_old_bytes_len,
      drop_count: next_drop_count,
    };

    // Compute and output decoded varints.
    // The leftover bytes are placed in front of the work chunk.
    let bytes = state.old_bytes ++ state.work_chunk;
    let bytes_len = state.len as CombinedIdx + CombinedIdx:4;

    // Decode up to OUTPUT_WORDS varints, each from a 5-byte window starting at `bytes_taken`;
    // the starting offset skips the invalid leading entries of `old_bytes`.
    let (output_words, num_output_words, _) =
    for (i, (output_words, num_output_words, bytes_taken)):
     (u32, (OutputWordArray, OutputIdx, CombinedIdx)) in u32:0..OUTPUT_WORDS {
      let idx = bytes_taken as u32;
      let encoded = for (j, encoded): (u32, u8[5]) in u32:0..u32:5 {
        let val = if j + idx < COMBINED_BYTES { bytes[j + idx] } else { u8: 0 };
        update(encoded, j, val)
      }(u8[5]:[u8:0, ...]);
      let (decoded, this_bytes_taken) = varint_decode_u32(encoded);
      let total_bytes_taken = bytes_taken + this_bytes_taken as CombinedIdx;
      // Only keep the decoded word if it was built entirely from valid bytes.
      if total_bytes_taken <= bytes_len {
        (
          update(output_words, i, decoded),
          num_output_words + OutputIdx:1,
          bytes_taken + this_bytes_taken as CombinedIdx,
        )
      } else {
        (output_words, num_output_words, bytes_taken)
      }
    }((zero!<OutputWordArray>(), OutputIdx:0, (u3:4 - state.old_bytes_len) as CombinedIdx));

    // Send only when this activation neither received input nor dropped bytes,
    // and at least one word was decoded.
    let output_tok = send_if(
      input_tok,
      words_out,
      !do_input && !do_drop && num_output_words > OutputIdx:0,
      (output_words, num_output_words));

    next_state
  }
}
