[Back to table of contents](0_Table_of_Contents.ipynb)

# Supporting a custom L0 RISC-V accelerator


Let's assume the following "fell from the sky" (a higher plane of abstraction):

 - A program compiled halfway down to RISC-V assembly
 - The format of this program is what we call "RISC-V SSA"

Let's not worry too much about what all this means; it will all be explained in due time. Have faith!

Let's have a look at what we got:

In [None]:
import xdsl, riscemu
from riscv.higher_plane_of_abstraction import module, printer

printer.print(module)

Okay. This *really* doesn't look like RISC-V. But if I look at it for a while, I can see some familiar stuff:


```
    %0 = "riscv_ssa.li"() {"immediate" = 82 : i32}
    ^^              ^^                   ^^
    Result?         Op-name              Attribute
```

We seem to have the `li` pseudo-op here, with an immediate value of `82`. Its result is stored in `%0`.
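To make the three parts concrete, here is a minimal sketch that splits a line of this generic format into result, op name, operands and attributes. In this notation the `(...)` part lists SSA operands (empty for `li`) and the `{...}` part holds attributes, which is where the immediate lives. The helper `parse_generic_op` is hypothetical and only handles the simplified shape shown above; it is not an xDSL or MLIR parser.

```python
import re

# Hypothetical helper: handles only the simple shape shown above
# (one result, optional operands, integer-valued attributes).
GENERIC_OP = re.compile(
    r'^\s*(?P<result>%\w+)\s*=\s*"(?P<name>[\w.]+)"\((?P<operands>[^)]*)\)'
    r'\s*(?:\{(?P<attrs>[^}]*)\})?'
)
ATTR = re.compile(r'"(?P<key>\w+)"\s*=\s*(?P<value>-?\d+)\s*:\s*(?P<type>\w+)')


def parse_generic_op(line):
    """Split one generic-format op into (result, op name, operands, attributes)."""
    m = GENERIC_OP.match(line)
    if m is None:
        raise ValueError(f"not a generic op: {line!r}")
    operands = [o.strip() for o in m["operands"].split(",") if o.strip()]
    attrs = {a["key"]: (int(a["value"]), a["type"])
             for a in ATTR.finditer(m["attrs"] or "")}
    return m["result"], m["name"], operands, attrs


result, name, operands, attrs = parse_generic_op(
    '%0 = "riscv_ssa.li"() {"immediate" = 82 : i32}'
)
print(result, name, operands, attrs)
# %0 riscv_ssa.li [] {'immediate': (82, 'i32')}
```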

But what register is `%0`?

## A short (and incomplete) introduction to SSA

SSA stands for "Static Single Assignment", and it is the form our compiler framework works with. In SSA:

 - Each variable is assigned exactly once!
 - We have infinitely many variables (or registers)

So `%0` is just an SSA variable. The type of the variable is `#riscv_ssa.reg`, which is a convoluted way of saying that it represents a register. The compiler just hasn't made up its mind yet which specific register it means.

This abstraction will enable us to do some powerful stuff later on.
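The two SSA rules can be illustrated with a tiny toy builder: every emitted operation gets a brand-new name from an unbounded counter, and operands must refer to names that were already defined. The `SSABuilder` class is a hypothetical illustration of the idea, not how xDSL actually stores values.

```python
import itertools


class SSABuilder:
    """Hypothetical toy builder: hands out fresh SSA names, each defined once.

    Not how xDSL stores values; it only illustrates the two SSA rules above.
    """

    def __init__(self):
        self._ids = itertools.count()  # unbounded supply of fresh names
        self.defs = {}  # SSA name -> (op name, operand names, attributes)

    def emit(self, op, *operands, **attrs):
        # every operand must already have been defined
        for o in operands:
            if o not in self.defs:
                raise ValueError(f"{o} is used before it is defined")
        name = f"%{next(self._ids)}"  # a brand-new name every time
        self.defs[name] = (op, operands, attrs)
        return name


b = SSABuilder()
x = b.emit("li", immediate=82)  # %0
y = b.emit("li", immediate=3)   # %1
z = b.emit("mul", x, y)         # %2
w = b.emit("add", z, x)         # %3: a new name instead of overwriting %2
```

Where hand-written assembly might reuse `t0` for both the product and the sum, SSA gives the sum its own name (`%3`), so every value has exactly one, unambiguous definition. Picking real registers for these names is deferred to a later step.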


Anyway, we won't interact with SSA much **yet**, so let's print it in a more accessible way:

In [None]:
from riscv.emulator_iop import print_riscv_ssa, run_riscv

print(print_riscv_ssa(module))

Okay, this isn't exactly valid RISC-V assembly, but it's relatively close. Why don't we just try to run it in a RISC-V emulator?

Luckily, this notebook also provides an emulator that understands RISC-V code with an unlimited number of registers:

In [None]:
run_riscv(print_riscv_ssa(module), unlimited_regs=True)
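How can an emulator offer unlimited registers? One simple approach is to store registers in a dictionary keyed by name, so `%0`, `%17` and `a0` are all equally valid. The sketch below is a hypothetical toy (`run_toy` is not part of this notebook or of riscemu). It assumes a textual form like `li %0, 82`, supports only `li`, `mul` and `add`, and wraps results to signed 32 bits like a real 32-bit register would.

```python
def run_toy(program):
    """Toy emulator with unlimited registers (hypothetical, not riscemu)."""
    regs = {}  # register name -> value; any name is allowed

    def wrap32(v):
        # interpret v as a signed 32-bit two's-complement integer
        v &= 0xFFFFFFFF
        return v - (1 << 32) if v & 0x80000000 else v

    for line in program.strip().splitlines():
        mnemonic, _, rest = line.strip().partition(" ")
        args = [a.strip() for a in rest.split(",")]
        if mnemonic == "li":
            rd, imm = args
            regs[rd] = wrap32(int(imm))
        elif mnemonic == "mul":
            rd, rs1, rs2 = args
            regs[rd] = wrap32(regs[rs1] * regs[rs2])
        elif mnemonic == "add":
            rd, rs1, rs2 = args
            regs[rd] = wrap32(regs[rs1] + regs[rs2])
        else:
            raise ValueError(f"unsupported instruction: {mnemonic}")
    return regs


regs = run_toy("""
    li %0, 82
    li %1, 3
    mul %2, %0, %1
    add %3, %2, %0
""")
print(regs["%3"])  # 82 * 3 + 82 = 328
```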

## Let's Accelerate This:

Let's start simple: we want to add a fused multiply-add instruction to our RISC-V ISA.

Let's define its structure as `fmadd rd, rs1, rs2, rs3`.

We first need to tell our compiler about the structure of our new instruction. For that we can use `xDSL` and its interface for defining new operations, called `irdl`:

In [None]:
from xdsl.irdl import irdl_op_definition, Operation, Operand, Annotated, OpResult
from riscv.riscv_ssa import *

@irdl_op_definition
class FmaddOp(Operation):
    name = "riscv_ssa.fmadd"
    
    rd: Annotated[OpResult, RegisterType]
    """
    We return a single value in a register
    """
    
    rs1: Annotated[Operand, RegisterType]
    rs2: Annotated[Operand, RegisterType]
    rs3: Annotated[Operand, RegisterType]
    """
    We take three arguments (Operands), which are also registers.
    """
    
    @classmethod
    def get(cls, *rs):
        """
        A small helper to construct an fmadd operation from three register operands
        """
        return cls.build(operands=rs, result_types=[RegisterType()])

### How do we get `fmadd` into our RISC-V code?

Now to the interesting part. We need to create a compiler optimization that replaces a `mul` followed by an `add` with a single `fmadd`.

For that, we can use the xDSL `pattern_rewriter` module, which provides us with a neat interface for defining optimizations:

In [None]:
# Import OpResult and the pattern rewriting machinery from xDSL:
from xdsl.ir import OpResult
from xdsl.pattern_rewriter import (GreedyRewritePatternApplier,
                                   PatternRewriter, PatternRewriteWalker,
                                   RewritePattern, op_type_rewrite_pattern)

# Create our rewriter class:
class FmaddOpOptimizer(RewritePattern):
    
    @op_type_rewrite_pattern
    def match_and_rewrite(self, add: AddOp, rewriter: PatternRewriter):
        """
        This method will be called on each AddOp in our RISC-V SSA module.
        """
        # we iterate over all operands (arguments) of the add instruction
        for operand in add.operands:
            # and try to find a value that was the result of a MULOp.
            # Block arguments have no defining op, so we only look at OpResults.
            # We also check that the result is used only once (by this AddOp).
            if (isinstance(operand, OpResult)
                    and isinstance(operand.op, MULOp)
                    and len(operand.uses) == 1):
                # if we find one, we grab its arguments
                a, b = operand.op.operands
                # and the other argument to our add instruction
                other_operand = add.rs1 if operand == add.rs2 else add.rs2

                # we then replace the add instruction with a fmadd instruction
                rewriter.replace_matched_op(
                    FmaddOp.get(a, b, other_operand)
                )

                # and erase the now-unused mul instruction
                rewriter.erase_op(operand.op)
                break

This is a pretty naive rewrite (it only fuses a `mul` whose result feeds directly into an `add` and is used nowhere else), but it will work for now.
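To see the same idea without xDSL, here is a sketch of the mul+add fusion on a plain list of operations. The `Op` dataclass and `fuse_mul_add` function are hypothetical; they only mirror the logic of `FmaddOpOptimizer`: find an `add` whose operand is produced by a `mul`, require that the product has exactly one use, then emit `fmadd(a, b, other)` and drop the `mul`.

```python
from dataclasses import dataclass


@dataclass
class Op:
    result: str      # SSA name defined by this op
    name: str        # "li", "mul", "add", "fmadd", ... (immediates omitted)
    operands: tuple  # SSA names used by this op


def fuse_mul_add(ops):
    """Toy mul+add -> fmadd fusion on a list-based IR (hypothetical, not xDSL)."""
    defs = {op.result: op for op in ops}
    uses = {}
    for op in ops:
        for o in op.operands:
            uses[o] = uses.get(o, 0) + 1

    dead = set()  # results of muls that were folded into an fmadd
    out = []
    for op in ops:
        if op.name == "add":
            for i, operand in enumerate(op.operands):
                producer = defs.get(operand)
                # only fuse a mul whose result is used by this add alone,
                # otherwise erasing the mul would break its other users
                if (producer is not None and producer.name == "mul"
                        and uses[operand] == 1):
                    other = op.operands[1 - i]
                    a, b = producer.operands
                    op = Op(op.result, "fmadd", (a, b, other))
                    dead.add(producer.result)
                    break
        out.append(op)
    return [op for op in out if op.result not in dead]


prog = [
    Op("%0", "li", ()),
    Op("%1", "li", ()),
    Op("%2", "mul", ("%0", "%1")),
    Op("%3", "add", ("%2", "%0")),
]
for op in fuse_mul_add(prog):
    print(op)
```

The single-use check is what makes deleting the `mul` safe: if another instruction also reads the product, the `mul` has to stay, and fusing would then only add work.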

Let's apply this rewrite:

In [None]:
PatternRewriteWalker(GreedyRewritePatternApplier([FmaddOpOptimizer()])).rewrite_module(module)

In [None]:
# A small helper that makes the cell above more concise:
def apply_rewrites(module, *rewriters):
    """Instantiate each rewrite pattern class and apply them all to the module."""
    PatternRewriteWalker(GreedyRewritePatternApplier([r() for r in rewriters])).rewrite_module(module)

# With it, the cell above could be written as:
# apply_rewrites(module, FmaddOpOptimizer)

Okay, let's look at what happened to our assembly:

In [None]:
printer.print(module)

We can see that an `fmadd` operation was inserted, and our `mul` and `add` are gone.

Let's print it as RISC-V SSA assembly:

In [None]:
print(print_riscv_ssa(module))

**Success!**

## Emulation Time:

We have defined the syntax of the `fmadd` instruction, but we still need to define its semantics for the emulator:

In [None]:
from riscemu.instructions import InstructionSet, Instruction

# Define a RISC-V ISA extension by subclassing InstructionSet
class RV_fmadd(InstructionSet):
    # each method beginning with instruction_ will be available to the Emulator
    
    def instruction_fmadd(self, ins: Instruction):
        """
        This method defines the semantics of the fmadd instruction. Let's settle on:
        
        rd = (rs1 * rs2) + rs3
        """
        # get all register names from the instruction:
        rd, rs1, rs2, rs3 = (ins.get_reg(i) for i in (0,1,2,3))
        
        # we can access the cpu registers through self.regs
        
        # we can set a register value using self.regs.set(name: str, value: Int32)
        self.regs.set(
            rd,
            (self.regs.get(rs1) * self.regs.get(rs2)) + self.regs.get(rs3)
        )
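
The comment above says every method starting with `instruction_` becomes available to the emulator. A minimal sketch of how such name-based dispatch can work is shown below; `ToyCPU` is hypothetical and is an assumption about the general mechanism, not riscemu's code. Registers are a plain dict, and results are wrapped to signed 32 bits on the assumption that the emulated registers are 32 bits wide.

```python
class ToyCPU:
    """Hypothetical toy emulator with name-based instruction dispatch."""

    def __init__(self):
        self.regs = {}  # register name -> signed 32-bit value

    @staticmethod
    def wrap32(v):
        # keep results in signed 32-bit two's-complement range
        v &= 0xFFFFFFFF
        return v - (1 << 32) if v & 0x80000000 else v

    def execute(self, mnemonic, *args):
        # "fmadd" is looked up as the method "instruction_fmadd"
        handler = getattr(self, f"instruction_{mnemonic}", None)
        if handler is None:
            raise NotImplementedError(f"unknown instruction: {mnemonic}")
        handler(*args)

    def instruction_li(self, rd, imm):
        self.regs[rd] = self.wrap32(imm)

    def instruction_fmadd(self, rd, rs1, rs2, rs3):
        # rd = (rs1 * rs2) + rs3, the semantics settled on above
        r = self.regs
        r[rd] = self.wrap32(r[rs1] * r[rs2] + r[rs3])


cpu = ToyCPU()
cpu.execute("li", "a1", 82)
cpu.execute("li", "a2", 3)
cpu.execute("li", "a3", 82)
cpu.execute("fmadd", "a0", "a1", "a2", "a3")
print(cpu.regs["a0"])  # 82 * 3 + 82 = 328
```

Because dispatch only depends on method names, an extension can add new instructions simply by defining new `instruction_*` methods, without touching the emulator's main loop.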

We need to tell the emulator about our new instruction set extension. Luckily, our `run_riscv` function accepts an `extensions` argument that takes a sequence of extensions for the emulator!

Let's give it a go!

In [None]:
run_riscv(print_riscv_ssa(module), extensions=(RV_fmadd,), unlimited_regs=True)

## Success!

[Back to table of contents](0_Table_of_Contents.ipynb)