## Agile Hardware Design
***
# Functional Programming (continued)

## Prof. Scott Beamer
### sbeamer@ucsc.edu

## [CSE 293](https://classes.soe.ucsc.edu/cse293/Winter22/)

## Plan for Today

* reduce + fold
* Scala type signatures
* zipWithIndex
* FP considerations

## Loading The Chisel Library Into a Notebook

In [None]:
val path = System.getProperty("user.dir") + "/../resource/chisel_deps.sc"
interp.load.module(ammonite.ops.Path(java.nio.file.FileSystems.getDefault().getPath(path)))

In [None]:
import chisel3._
import chisel3.util._
import chisel3.tester._
import chisel3.tester.RawTester.test

## Motivation for `reduce` and `fold`

* In the last lecture, we applied a function to each element (e.g. `map`, `foreach`, `zip`)
  * The resulting collection (if there is one) has the same size as the input collection

* What if we want to combine (collapse) the elements into a single result?

* How do we gracefully handle collapsing with an empty collection?

## Scala `reduce`

* Given a binary operator, `reduce` applies it repeatedly to the collection's elements until only one value remains

* Placeholder syntax (e.g. `_ + _`) keeps the expressions concise

In [None]:
val l = Seq(0,1,2,3,4,5)
l reduce { _ + _ }
l.reduce{(a,b) => a+b}

val squares = l map { i => i*i }
val sumOfSquares = squares reduce { _ + _ }
l map { i => i*i } reduce { _ + _ }

## Tweaking Our Arbiter with FP (1/2) - original

In [None]:
class MyArb(numPorts: Int, n: Int) extends Module {
    val io = IO(new Bundle {
        val req = Flipped(Vec(numPorts, Decoupled(UInt(n.W))))
        val out = Decoupled(UInt(n.W))
    })
    require (numPorts > 0)
    val inValids = Wire(Vec(numPorts, Bool()))
    val inBits   = Wire(Vec(numPorts, UInt(n.W)))
    val chosenOH = PriorityEncoderOH(inValids)
    for (p <- 0 until numPorts) {
        io.req(p).ready := chosenOH(p) && io.out.fire
        inValids(p) := io.req(p).valid
        inBits(p) := io.req(p).bits
    }
    io.out.valid := inValids.asUInt.orR
    io.out.bits := Mux1H(chosenOH, inBits)
}

## Tweaking Our Arbiter with FP (2/2) - with FP + reduce

In [None]:
class MyArb(numPorts: Int, n: Int) extends Module {
    val io = IO(new Bundle {
        val req = Flipped(Vec(numPorts, Decoupled(UInt(n.W))))
        val out = Decoupled(UInt(n.W))
    })
    require (numPorts > 0)
    val inValids = io.req map { _.valid }
//     io.out.valid := VecInit(inValids).asUInt.orR
    io.out.valid := inValids reduce { _ || _ }
    val chosenOH = PriorityEncoderOH(inValids)
    io.out.bits := Mux1H(chosenOH, io.req map { _.bits })
    io.req.zip(chosenOH) foreach { case (i, c) => i.ready := c && io.out.fire}
}

## How Do You Reduce on 0 Elements?

* What should `reduce` return when the collection has 0 elements?

* Alternatively, what if we want to collapse a collection into a different type?
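
A minimal sketch of the empty-collection problem, using only the Scala standard library. The value names are illustrative. `reduce` has nothing to return for an empty collection, so it throws. `reduceOption` makes the empty case explicit, and a fold avoids the issue by starting from an initial value.

```scala
import scala.util.Try

val empty = Seq.empty[Int]

// reduce has no value to return for an empty collection, so it throws
val reduceAttempt = Try(empty.reduce(_ + _))

// reduceOption wraps the result so the empty case is visible in the type
val safeEmpty    = empty.reduceOption(_ + _)
val safeNonEmpty = Seq(1, 2, 3).reduceOption(_ + _)

// a fold starts from an initial value, so the empty case is well defined
val foldedEmpty = empty.foldLeft(0)(_ + _)
```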

## Scala `foldLeft`

* Given an initial value and an operator, applies the operator _left to right_
  * "Left" is element 0, i.e. in the iterable's order
* Can be used to implement `reduce`
* Can return a type different from the collection's element type

In [None]:
val l = Seq(1,2,3,4,5)
l.foldLeft(0)((totalSoFar, elem) => totalSoFar + elem)
l.foldLeft(0)(_ + _)
l reduce { _ + _ }
l.sum

def myMax(maxSoFar: Int, x: Int) = if (maxSoFar > x) maxSoFar else x
val maxTheHardWay = l.foldLeft(0)(myMax)
l.max
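
As a sketch of the "different type" point, the folds below collapse a `Seq[Int]` into a `String` and into a tuple. The last two lines illustrate a caveat of `myMax` with an initial value of 0: it only works when at least one element is non-negative, so `Int.MinValue` is assumed here as a safer starting point.

```scala
val l = Seq(1, 2, 3, 4, 5)

// Seq[Int] collapsed into a String: the result type follows the initial value
val asString = l.foldLeft("")((soFar, elem) => soFar + elem.toString)

// Seq[Int] collapsed into a (count, sum) tuple in a single pass
val (count, total) = l.foldLeft((0, 0)) { case ((c, s), elem) => (c + 1, s + elem) }

// Starting from 0 silently gives the wrong max when every element is negative
def myMax(maxSoFar: Int, x: Int) = if (maxSoFar > x) maxSoFar else x
val wrongMax = Seq(-3, -1, -2).foldLeft(0)(myMax)
val rightMax = Seq(-3, -1, -2).foldLeft(Int.MinValue)(myMax)
```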

## Visualizing `foldLeft` & `foldRight`

<img src="images/folds.svg" alt="foldLeft & foldRight" style="width:70%;margin-left:auto;margin-right:auto"/>
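
One way to see the two evaluation orders without the figure is to fold with an operator that records its own nesting. This is a small illustrative sketch; the strings and the initial value `"z"` are arbitrary.

```scala
val l = Seq("a", "b", "c")

// foldLeft nests to the left: the accumulator is the first argument
val left  = l.foldLeft("z")((acc, x) => s"($acc+$x)")

// foldRight nests to the right: the accumulator is the second argument
val right = l.foldRight("z")((x, acc) => s"($x+$acc)")

println(left)   // (((z+a)+b)+c)
println(right)  // (a+(b+(c+z)))
```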

## Brief Detour: Currying (functions) in Scala

* A function can take multiple argument lists
* We have seen it and used it without talking about it yet
  * e.g. `Seq.fill(4)(0)`
* Can create partially applied functions to pass to FP operations

In [None]:
def sum(a: Int, b: Int) = a + b

def plusX(x: Int)(b: Int) = x + b

val f = plusX(1)_

f(2)

plusX(1)(2)

Seq(0,1,2,3,4) map plusX(10)
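
A short sketch of two related idioms, assuming Scala 2 syntax as used in the cell above: giving the partially applied function an explicit type instead of the trailing `_`, and converting an ordinary two-argument method into curried form with `.curried`. The names `plusTen`, `curriedSum`, and `plusOne` are illustrative.

```scala
def sum(a: Int, b: Int) = a + b
def plusX(x: Int)(b: Int) = x + b

// With an expected function type, the remaining argument list is left open
val plusTen: Int => Int = plusX(10)

// An ordinary two-argument method converted into curried form
val curriedSum: Int => Int => Int = (sum _).curried
val plusOne = curriedSum(1)

val shifted = Seq(0, 1, 2).map(plusTen)
```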

## Brief Detour: Scala Function Signatures

<img src="images/map.png" alt="map signature" style="width:70%;margin-left:auto;margin-right:auto"/>
<img src="images/foldLeft.png" alt="foldLeft signature" style="width:70%;margin-left:auto;margin-right:auto"/>

* Screenshots are from the Scala API docs, which are worth perusing for available FP operations
* Square brackets `[]` indicate parameterized types, and often type inference determines them (e.g. `A`)
* These operations take functions as arguments (e.g. `op`), whose types are written as (_input arg types_) `=>` _return type_
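
To make the signatures concrete, here is a hand-written sketch of both operations over `Seq`. These are hypothetical helpers (`myFoldLeft`, `myMap`) written for illustration, not the library's implementation; they only mirror the shape of the type parameters and function arguments.

```scala
// A is the element type, B the accumulator/result type; op has type (B, A) => B
def myFoldLeft[A, B](xs: Seq[A])(z: B)(op: (B, A) => B): B = {
  var acc = z
  for (x <- xs) acc = op(acc, x)
  acc
}

// map expressed with the fold above; f has type A => B
def myMap[A, B](xs: Seq[A])(f: A => B): Seq[B] =
  myFoldLeft(xs)(Seq.empty[B])((soFar, x) => soFar :+ f(x))

// Type inference fills in A and B at each call site
val lengths = myMap(Seq("chisel", "scala"))(_.length)
val total   = myFoldLeft(Seq(1, 2, 3))(0)(_ + _)
```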

## `reduce`X vs `fold`Y

* All 6 variants exist (`reduce`, `reduceLeft`, `reduceRight`, `fold`, `foldLeft`, `foldRight`)

* The `Left`/`Right` variants give an explicit evaluation order; for plain `reduce` and `fold`, the order is unspecified

* In practice, `foldLeft` is often most versatile/appropriate, but brevity of `reduce` makes it tempting
  * Typically use `reduce` to collapse, but `foldLeft` to do it in deliberate order

* Can use `foldRight` and `reduceRight` to effectively do things in reverse (can also use `.reverse`)
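
The direction only matters when the operator is not associative. A small sketch with subtraction (order-sensitive) next to addition (order-insensitive):

```scala
val l = Seq(1, 2, 3, 4)

val fromLeft  = l.reduceLeft(_ - _)   // ((1 - 2) - 3) - 4
val fromRight = l.reduceRight(_ - _)  // 1 - (2 - (3 - 4))

// For an associative operator, both directions agree
val sumLeft  = l.reduceLeft(_ + _)
val sumRight = l.reduceRight(_ + _)
```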

## Redoing Reducer with `reduce`

In [None]:
class Reducer(n: Int, m: Int) extends Module {
    val io = IO(new Bundle {
        val in  = Input(Vec(n, UInt(m.W)))
        val out = Output(UInt(m.W))
    })
    require(n > 0)
//     var totalSoFar = io.in(0)
//     for (i <- 1 until n)
//         totalSoFar = io.in(i) + totalSoFar
//     io.out := totalSoFar
    io.out := io.in.reduce{ _ + _ }
}
println(getVerilog(new Reducer(4,2)))

## Redoing DelayN (Pipe) with `foldLeft`

In [None]:
class DelayNCycles(n: Int) extends Module {
    val io = IO(new Bundle {
        val in  = Input(Bool())
        val out = Output(Bool())
    })
    require(n >= 0)
//     def helper(n: Int, lastConn: Bool): Bool = {
//         if (n == 0) lastConn
//         else helper(n-1, RegNext(lastConn))
//     }
//     io.out := helper(n, io.in)
    io.out := (0 until n).foldLeft(io.in){(lastConn, _) => RegNext(lastConn)}
}
println(getVerilog(new DelayNCycles(3)))
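
The same idiom can be tried in plain Scala, without hardware. The sketch below uses a hypothetical helper `applyN`: the range only counts iterations, and the accumulator carries the most recent result, just as `lastConn` carries the most recent register above.

```scala
// Apply f to init n times; the range element is ignored because it only counts iterations
def applyN[T](n: Int)(init: T)(f: T => T): T =
  (0 until n).foldLeft(init)((soFar, _) => f(soFar))

val unchanged = applyN(0)(7)(_ + 1)               // n = 0 returns the initial value
val plusThree = applyN(3)(7)(_ + 1)
val wrapped   = applyN(2)("in")(s => s"Reg($s)")  // mirrors the nesting of RegNext
```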

## Scala `zipWithIndex`

* Sometimes we want access to the index while performing an FP operation
  * Analogous to `enumerate` in Python

<img src="images/zipWithIndex.svg" alt="zipWithIndex" style="width:40%;margin-left:auto;margin-right:auto"/>

In [None]:
val l = Seq(5,6,7,8)
l.zip(0 until l.size)
l.zipWithIndex
l.zipWithIndex.map{ t => t._1 * t._2 }
l.zipWithIndex.map{ case (x, i) => x * i }
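
Two further uses of the index, as an illustrative sketch: collecting the positions of elements that satisfy a condition, and finding the position of the first maximum. `reduceLeft` is used so that ties resolve in a deliberate order.

```scala
val l = Seq(5, 8, 7, 8)

// Indices of the even elements
val evenIdx = l.zipWithIndex.collect { case (x, i) if x % 2 == 0 => i }

// Index of the first maximum: keep the earlier pair unless a strictly larger value appears
val argMax = l.zipWithIndex.reduceLeft((a, b) => if (b._1 > a._1) b else a)._2
```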

## One-Hot Priority Encoder (with muxes) Redone with FP

In [None]:
class MyPriEncodeOH(n: Int) extends Module {
    val io = IO(new Bundle {
        val in  = Input(UInt(n.W))
        val out = Output(UInt())
    })
    require (n > 0)
//     def withMuxes(index: Int): UInt = {
//         if (index < n) Mux(io.in(index), (1 << index).U, withMuxes(index+1))
//         else 0.U
//     }
//     io.out := withMuxes(0)
    io.out := io.in.asBools.zipWithIndex.reverse.foldLeft(0.U) {
        case (soFar, (b, index)) => Mux(b, (1 << index).U, soFar)
//         case ((b, index), soFar) => Mux(b, (1 << index).U, soFar)
    }
//    io.out := PriorityEncoderOH(io.in)    // Standard Library
//     printf("%b -> %b\n", io.in, io.out)
}

println(getVerilog(new MyPriEncodeOH(3)))
// test(new MyPriEncodeOH(3)) { c =>
//     for (i <- 0 until 8) {
//         c.io.in.poke(i.U)
//         c.clock.step()
//     }
// }
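
To check the fold's logic without elaborating hardware, here is a plain-Scala model of the same pattern. It is a sketch: `priEncodeOH` is a hypothetical name, `Seq[Boolean]` stands in for `asBools` (index 0 is the least-significant bit), and `if` stands in for `Mux`. Reversing means the lowest index is folded in last, so it takes priority.

```scala
// Software model: bits(0) is the least-significant, highest-priority input bit
def priEncodeOH(bits: Seq[Boolean]): Int =
  bits.zipWithIndex.reverse.foldLeft(0) {
    case (soFar, (b, index)) => if (b) 1 << index else soFar
  }

val noneSet  = priEncodeOH(Seq(false, false, false))
val lowWins  = priEncodeOH(Seq(true, true, true))
val midWins  = priEncodeOH(Seq(false, true, true))
val highOnly = priEncodeOH(Seq(false, false, true))
```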

## Redoing Crossbar with FP (1/4) - IO declarations

In [None]:
class Message(numOuts: Int, length: Int) extends Bundle {
    val addr = UInt(log2Ceil(numOuts+1).W)
    val data = UInt(length.W)
}

class XBarIO(numIns: Int, numOuts: Int, length: Int) extends Bundle {
    val in  = Vec(numIns, Flipped(Decoupled(new Message(numOuts, length))))
    val out = Vec(numOuts, Decoupled(new Message(numOuts, length)))
}

## Redoing Crossbar with FP (2/4) - inner loops only

In [None]:
class XBar(numIns: Int, numOuts: Int, length: Int) extends Module {
    val io = IO(new XBarIO(numIns, numOuts, length))
    val arbs = Seq.fill(numOuts)(Module(new RRArbiter(new Message(numOuts, length), numIns)))
    for (ip <- 0 until numIns) {
        val inReadys = Wire(Vec(numOuts, Bool()))
        for (op <- 0 until numOuts) {
            inReadys(op) := arbs(op).io.in(ip).ready
        }
        io.in(ip).ready := inReadys.asUInt.orR
//         io.in(ip).ready := arbs.map{ _.io.in(ip).ready }.reduce{ _ || _ }
    }
    for (op <- 0 until numOuts) {
        for (ip <- 0 until numIns) {
            arbs(op).io.in(ip).bits <> io.in(ip).bits
            arbs(op).io.in(ip).valid := io.in(ip).valid && (io.in(ip).bits.addr === op.U)
        }
//         arbs(op).io.in.zip(io.in).foreach { case (arbIn, ioIn) =>
//             arbIn.bits <> ioIn.bits
//             arbIn.valid := ioIn.valid && (ioIn.bits.addr === op.U)
//         }
        io.out(op) <> arbs(op).io.out
    }
    for (op <- 0 until numOuts) {
        printf(" %d -> %d (%b)", io.out(op).bits.data, op.U, io.out(op).valid)
    }
    printf("\n")
}

## Redoing Crossbar with FP (3/4) - all loops

In [None]:
class XBar(numIns: Int, numOuts: Int, length: Int) extends Module {
    val io = IO(new XBarIO(numIns, numOuts, length))
    val arbs = Seq.fill(numOuts)(Module(new RRArbiter(new Message(numOuts, length), numIns)))
    for (ip <- 0 until numIns) {
        io.in(ip).ready := arbs.map{ _.io.in(ip).ready }.reduce{ _ || _ }
    }
//     io.in.zipWithIndex.foreach { case (in, ip) =>
//         in.ready := arbs.map{ _.io.in(ip).ready }.reduce{ _ || _ }
//     }
    for (op <- 0 until numOuts) {
        arbs(op).io.in.zip(io.in).foreach { case (arbIn, ioIn) =>
            arbIn.bits <> ioIn.bits
            arbIn.valid := ioIn.valid && (ioIn.bits.addr === op.U)
        }
        io.out(op) <> arbs(op).io.out
    }
//     io.out.zip(arbs).zipWithIndex.foreach { case ((ioOut, arbOut), op) =>
//         arbOut.io.in.zip(io.in).foreach { case (arbIn, ioIn) =>
//             arbIn.bits <> ioIn.bits
//             arbIn.valid := ioIn.valid && (ioIn.bits.addr === op.U)
//         }
//         ioOut <> arbOut.io.out
//     }
    for (op <- 0 until numOuts) {
        printf(" %d -> %d (%b)", io.out(op).bits.data, op.U, io.out(op).valid)
    }
//     io.out.zipWithIndex.foreach{
//         case (outP, op) => printf(" %d -> %d (%b)", outP.bits.data, op.U, outP.valid)
//     }
    printf("\n")
}

## Redoing Crossbar with FP (4/4) - Tests

In [None]:
val numIns = 4
val numOuts = 2
test(new XBar(numIns,numOuts,8)) { c =>
    for (ip <- 0 until numIns) {
        c.io.in(ip).valid.poke(true.B)
        c.io.in(ip).bits.data.poke(ip.U)
        c.io.in(ip).bits.addr.poke((ip % numOuts).U)
    }
    for (op <- 0 until numOuts) {
        c.io.out(op).ready.poke(true.B)
    }
    for (cycle <- 0 until 4) {
        c.clock.step()
    }
}

## Only Use FP When it is an Improvement!

* FP used well...
  * Leverages an FP operation to execute a common pattern
  * Improves readability and simplifies code

* FP used over-eagerly...
  * Harder to read/understand
  * Brittle

* Consider...
  * Would a simple for loop or even recursion be clearer?
  * Limit yourself to at most 2-3 FP operations per line
  * Multiple lines for the function literal?
    * Maybe pull into a named helper function or fall back to _for_
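
As a sketch of these guidelines, the two versions below compute the same value: the sum of the positive samples at even indices. The data and the helper name `isPositiveAtEvenIndex` are made up for illustration. Which one reads better is a judgment call, but the second names the condition and keeps each line to a single idea.

```scala
val samples = Seq(3, -1, 4, -1, 5, -9, 2, 6)

// Dense: four chained operations and an inline function literal on one line
val dense = samples.zipWithIndex.filter { case (x, i) => x > 0 && i % 2 == 0 }.map(_._1).foldLeft(0)(_ + _)

// Same computation with a named predicate and a plain for loop
def isPositiveAtEvenIndex(x: Int, i: Int): Boolean = x > 0 && i % 2 == 0
var clearer = 0
for ((x, i) <- samples.zipWithIndex if isPositiveAtEvenIndex(x, i)) clearer += x
```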