# Supporting a custom L0 RISC-V accelerator


Let's assume we have been given the following from God:

 - A program compiled half-way down to RISC-V Assembly
 - The format of this program is what we call "RISC-V SSA"

Let's not worry too much about what all this means, and instead just look at it:


In [1]:
from riscv.riscv_ssa import *
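from xdsl.dialects.builtin import ModuleOp
# Printer is used below; imported explicitly in case the star import above does not provide it
from xdsl.printer import Printer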
from riscv.emulator_iop import run_riscv, print_riscv_ssa

p = Printer(target=Printer.Target.MLIR)

## Our SSA RISC-V:

In [2]:
module = ModuleOp.from_region_or_ops([
    LabelOp.get('main'),
    a0  := LIOp.get(82),
    a1  := LIOp.get(5),
    mul := MULOp.get(a0, a1),
    a2  := LIOp.get(10),
    add := AddOp.get(mul, a2),
    ECALLOp.get(93, add)
])

**Let's print it out:**

In [3]:
p.print(module)

"builtin.module"() ({
  "riscv_ssa.label"() {"label" = #riscv.label<main>} : () -> ()
  %0 = "riscv_ssa.li"() {"immediate" = 82 : i32} : () -> #riscv_ssa.reg
  %1 = "riscv_ssa.li"() {"immediate" = 5 : i32} : () -> #riscv_ssa.reg
  %2 = "riscv_ssa.mul"(%0, %1) : (#riscv_ssa.reg, #riscv_ssa.reg) -> #riscv_ssa.reg
  %3 = "riscv_ssa.li"() {"immediate" = 10 : i32} : () -> #riscv_ssa.reg
  %4 = "riscv_ssa.add"(%2, %3) : (#riscv_ssa.reg, #riscv_ssa.reg) -> #riscv_ssa.reg
  "riscv_ssa.ecall"(%4) {"syscall_num" = 93 : i32} : (#riscv_ssa.reg) -> ()
}) : () -> ()


This *really* doesn't look like RISC-V. But if I look at it for a while, I can see some familiar stuff:


```
    %0 = "riscv_ssa.li"() {"immediate" = 82 : i32}
    ^^              ^^                   ^^
    Result?         Op-name              Argument
```

We seem to have the `li` pseudo-op here, with an immediate value of `82`. Its result is stored in `%0`.

## A short (and incomplete) introduction to SSA

 - We have infinitely many variables (or registers)
 - We can only assign to each register once
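The second rule is easy to check mechanically. Here is a minimal sketch in plain Python, assuming a made-up tuple encoding `(destination, opcode, operands)` for instructions; `is_ssa` is a hypothetical helper and not part of xDSL:

```python
from collections import Counter

# Each instruction is (destination, opcode, operands).
# This encoding is illustrative only, it is not how xDSL stores operations.
program = [
    ("%0", "li", (82,)),
    ("%1", "li", (5,)),
    ("%2", "mul", ("%0", "%1")),
    ("%3", "li", (10,)),
    ("%4", "add", ("%2", "%3")),
]


def is_ssa(instructions):
    """Return True if every destination register is assigned exactly once."""
    counts = Counter(dest for dest, _, _ in instructions)
    return all(n == 1 for n in counts.values())


print(is_ssa(program))
# Assigning to %0 a second time breaks the single-assignment rule.
print(is_ssa(program + [("%0", "li", (1,))]))
```

The first call reports that the program is in SSA form, the second one does not, because `%0` is written twice.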
 
This Notebook provides a function to print the RISC-V SSA in a more RISC-V-like format:

In [4]:
print(print_riscv_ssa(module))

.text
main:
	li	%0, 82
	li	%1, 5
	mul	%2, %0, %1
	li	%3, 10
	add	%4, %2, %3
	mv	a0, %4
	li	a7, 93
	scall



Okay, this isn't exactly valid RISC-V, but it's relatively close. We can try to run it in a RISC-V emulator.

Luckily, this Notebook comes with a built-in emulator:

In [5]:
run_riscv(print_riscv_ssa(module))

[CPU] Started running from example.asm:.text at main (0x100) + 0x0
Program(name=example.asm,sections=set(),base=['.text'])
   Running 0x00000100: li %0, 82
Invalid register: %0


**Aaaargh, this is bad! We need to do register allocation!**


If *only* there was a way to make the emulator accept our "unlimited register" RISC-V...


In [6]:
run_riscv(print_riscv_ssa(module), unlimited_regs=True)

[CPU] Started running from example.asm:.text at main (0x100) + 0x0
Program(name=example.asm,sections=set(),base=['.text'])
   Running 0x00000100: li %0, 82
   Running 0x00000104: li %1, 5
   Running 0x00000108: mul %2, %0, %1
   Running 0x0000010C: li %3, 10
   Running 0x00000110: add %4, %2, %3
   Running 0x00000114: mv a0, %4
   Running 0x00000118: li a7, 93
   Running 0x0000011C: scall
[CPU] Program exited with code 420


## Let's Accelerate This:

Let's start simple. We want to add a fused multiply-add instruction to our RISC-V ISA.

Let's define its structure as `fmadd  rd, rs1, rs2, rs3`

We first need to tell our compiler about the structure of our new instruction. For that we can use `xDSL` and its interface for defining new operations, called `irdl`:

In [7]:
@irdl_op_definition
class FmaddOp(Operation):
    name = "riscv_ssa.fmadd"
    
    rd: Annotated[OpResult, RegisterType]
    """
    We return a single value in a register
    """
    
    rs1: Annotated[Operand, RegisterType]
    rs2: Annotated[Operand, RegisterType]
    rs3: Annotated[Operand, RegisterType]
    """
    We take three arguments (Operands), which are also registers.
    """
    
    @classmethod
    def get(cls, *rs):
        """
        This is a little helper function, to help us construct an fmadd operation
        """
        return cls.build(operands=rs, result_types=[RegisterType()])

### How to get the Fmadd into the RISC-V?

Now to the interesting part. We need to create a compiler optimization that replaces a `mul` and `add` with an `fmadd`.

For that, we can use the xDSL `pattern_rewriter` module, which provides us with a neat interface for defining optimizations:

In [8]:
# Import some things from the xdsl.pattern_rewriter module:
from xdsl.pattern_rewriter import (GreedyRewritePatternApplier,
                                   PatternRewriter, PatternRewriteWalker,
                                   RewritePattern, op_type_rewrite_pattern)

# Create our rewriter class:
class FmaddOpOptimizer(RewritePattern):
    
    @op_type_rewrite_pattern
    def match_and_rewrite(self, add: AddOp, rewriter: PatternRewriter):
        """
        This method will be called on each AddOp in our RISC-V SSA definition.
        """
        # we iterate over all operands (arguments) of the add instruction
        for operand in add.operands:
            # and try to find a value that was the result of a MULOp
            if isinstance(operand.op, MULOp):
                # if we find one, we grab its arguments
                a, b = operand.op.operands
                # and the other argument to our add instruction
                other_operand = add.rs1 if operand == add.rs2 else add.rs2

                # we then replace the add instruction with an fmadd instruction
                rewriter.replace_matched_op(
                    FmaddOp.get(a, b, other_operand)
                )

                # and erase the mul instruction
                rewriter.erase_op(operand.op)
                break

This is a pretty naive rewrite, but it will work for now.

We really should check that the result of the `mul` is not used anywhere else before erasing it (otherwise those other users would be left without a definition), but that is for future-us to worry about.
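For the curious, the guard that future-us would add can be sketched on a toy IR in plain Python. Everything here is an assumption made for illustration: `Instr` and `fuse_mul_add` are hypothetical names, and the code does not use the xDSL rewriting API. It only fuses when the `mul` result has exactly one use.

```python
from dataclasses import dataclass


@dataclass
class Instr:
    dest: str
    op: str
    args: tuple


def fuse_mul_add(instrs):
    """Fuse mul + add into fmadd, but only if the mul result has a single use."""
    # Count how often each register is read.
    uses = {}
    for ins in instrs:
        for arg in ins.args:
            if isinstance(arg, str):
                uses[arg] = uses.get(arg, 0) + 1
    defs = {ins.dest: ins for ins in instrs}

    dead = set()  # destinations of mul instructions that were fused away
    out = []
    for ins in instrs:
        if ins.op == "add":
            for i, arg in enumerate(ins.args):
                producer = defs.get(arg)
                if producer is not None and producer.op == "mul" and uses[arg] == 1:
                    other = ins.args[1 - i]
                    ins = Instr(ins.dest, "fmadd", (*producer.args, other))
                    dead.add(arg)
                    break
        out.append(ins)
    return [ins for ins in out if ins.dest not in dead]


prog = [
    Instr("%0", "li", (82,)),
    Instr("%1", "li", (5,)),
    Instr("%2", "mul", ("%0", "%1")),
    Instr("%3", "li", (10,)),
    Instr("%4", "add", ("%2", "%3")),
]

print([ins.op for ins in fuse_mul_add(prog)])

# With a second user of %2 the mul has to stay, so nothing is fused.
shared = prog + [Instr("%5", "add", ("%2", "%0"))]
print([ins.op for ins in fuse_mul_add(shared)])
```

In the first program the `mul` disappears and the `add` becomes an `fmadd`. In the second one the rewrite leaves everything alone, which is the conservative choice; a smarter version could still emit the `fmadd` and simply keep the `mul` around.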

Let's apply this rewrite:

In [9]:
PatternRewriteWalker(GreedyRewritePatternApplier([FmaddOpOptimizer()])).rewrite_module(module)

In [10]:
# proposed solution to make the cell above nicer:
def apply_rewrites(module, *rewriters):
    PatternRewriteWalker(GreedyRewritePatternApplier(list(rewriters))).rewrite_module(module)

# cell above can be written as:
#apply_rewrites(module, FmaddOpOptimizer())

Okay, let's look at what happened to our assembly:

In [11]:
p.print(module)

"builtin.module"() ({
  "riscv_ssa.label"() {"label" = #riscv.label<main>} : () -> ()
  %0 = "riscv_ssa.li"() {"immediate" = 82 : i32} : () -> #riscv_ssa.reg
  %1 = "riscv_ssa.li"() {"immediate" = 5 : i32} : () -> #riscv_ssa.reg
  %3 = "riscv_ssa.li"() {"immediate" = 10 : i32} : () -> #riscv_ssa.reg
  %5 = "riscv_ssa.fmadd"(%0, %1, %3) : (#riscv_ssa.reg, #riscv_ssa.reg, #riscv_ssa.reg) -> #riscv_ssa.reg
  "riscv_ssa.ecall"(%5) {"syscall_num" = 93 : i32} : (#riscv_ssa.reg) -> ()
}) : () -> ()


We can see that an `fmadd` operation was inserted, and our `mul` and `add` are gone.

Let's print it as RISC-V SSA Assembly:

In [12]:
print(print_riscv_ssa(module))

.text
main:
	li	%0, 82
	li	%1, 5
	li	%2, 10
	fmadd	%3, %0, %1, %2
	mv	a0, %3
	li	a7, 93
	scall



**Success!**

## Emulation Time:

We defined the syntax of the `fmadd` instruction, but we now need to define the semantics for the emulator:

In [13]:
from riscemu.instructions import InstructionSet, Instruction

# Define a RISC-V ISA extension by subclassing InstructionSet
class RV_fmadd(InstructionSet):
    # each method beginning with instruction_ will be available to the Emulator
    
    def instruction_fmadd(self, ins: Instruction):
        """
        We now have to define the fmadd semantics. Let's define it as
        
        rd = (rs1 * rs2) + rs3
        """
        # get all register names from the instruction:
        rd, rs1, rs2, rs3 = (ins.get_reg(i) for i in (0,1,2,3))
        
        # we can access the cpu registers through self.regs
        
        # we can set a register value using self.regs.set(name: str, value: Int32)
        self.regs.set(
            rd,
            # rd = (rs1 * rs2) + rs3
            (self.regs.get(rs1) * self.regs.get(rs2)) + self.regs.get(rs3)
        )
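The comment above mentions `Int32` values, so the arithmetic presumably wraps around at 32 bits. Here is a minimal sketch of what `rd = (rs1 * rs2) + rs3` means under that assumption, in plain Python. `to_int32` and `fmadd32` are hypothetical helpers and not part of the emulator:

```python
def to_int32(value: int) -> int:
    """Wrap an arbitrary Python int to a signed 32-bit two's complement value."""
    value &= 0xFFFFFFFF
    return value - 0x100000000 if value & 0x80000000 else value


def fmadd32(rs1: int, rs2: int, rs3: int) -> int:
    """rd = (rs1 * rs2) + rs3, wrapped to 32 bits (assumed register width)."""
    return to_int32(rs1 * rs2 + rs3)


print(fmadd32(82, 5, 10))        # the values from our example program
print(fmadd32(2**31 - 1, 2, 0))  # overflows and wraps around to a negative value
```

For the values in our program nothing overflows, so the result is the plain `82 * 5 + 10`.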

We need to tell the emulator about our new instruction set extension. Luckily, our `run_riscv` function accepts an `extensions` argument that takes a collection of extensions for the emulator!

Let's give it a go!

In [14]:
run_riscv(print_riscv_ssa(module), extensions=(RV_fmadd,), unlimited_regs=True)

[CPU] Started running from example.asm:.text at main (0x100) + 0x0
Program(name=example.asm,sections=set(),base=['.text'])
   Running 0x00000100: li %0, 82
   Running 0x00000104: li %1, 5
   Running 0x00000108: li %2, 10
   Running 0x0000010C: fmadd %3, %0, %1, %2
   Running 0x00000110: mv a0, %3
   Running 0x00000114: li a7, 93
   Running 0x00000118: scall
[CPU] Program exited with code 420


## Success!