# Supporting a custom L0 RISC-V accelerator


Let's assume we have been given the following from God:

 - A program compiled half-way down to RISC-V Assembly
 - The format of this program is what we call "RISC-V SSA"

Let's not worry too much about what all this means, and instead just look at it:


In [11]:
from riscv.riscv_ssa import *
from xdsl.dialects.builtin import ModuleOp
from riscv.emulator_iop import run_riscv, print_riscv_ssa

# Printer targeting MLIR syntax; the dumps below show the generic
# `"dialect.op"(operands) {attributes} : (types) -> types` form it produces.
p = Printer(target=Printer.Target.MLIR)

### Our SSA RISC-V:

In [12]:
# Example program: compute 82 * 5 + 10 and pass it to syscall 93
# (the emulator run at the end reports it as the exit code, 420).
a0 = LIOp.get(82)
a1 = LIOp.get(5)
mul = MULOp.get(a0, a1)   # 82 * 5
a2 = LIOp.get(10)
add = AddOp.get(mul, a2)  # ... + 10

# The ops appear in the module in exactly this order; SSA names (%0, %1, ...)
# are assigned by the printer from this order, not from creation order.
module = ModuleOp.from_region_or_ops([
    LabelOp.get('main'),
    a0,
    a1,
    mul,
    a2,
    add,
    ECALLOp.get(93, add),
])

**Let's print it out:**

In [3]:
# Dump the freshly built module in MLIR syntax.
p.print(module)

"builtin.module"() ({
  "riscv_ssa.label"() {"label" = #riscv.label<main>} : () -> ()
  %0 = "riscv_ssa.li"() {"immediate" = 82 : i32} : () -> #riscv_ssa.reg
  %1 = "riscv_ssa.li"() {"immediate" = 5 : i32} : () -> #riscv_ssa.reg
  %2 = "riscv_ssa.mul"(%0, %1) : (#riscv_ssa.reg, #riscv_ssa.reg) -> #riscv_ssa.reg
  %3 = "riscv_ssa.li"() {"immediate" = 10 : i32} : () -> #riscv_ssa.reg
  %4 = "riscv_ssa.add"(%2, %3) : (#riscv_ssa.reg, #riscv_ssa.reg) -> #riscv_ssa.reg
  "riscv_ssa.ecall"(%4) {"syscall_num" = 93 : i32} : (#riscv_ssa.reg) -> ()
}) : () -> ()


This *really* doesn't look like RISC-V. But if I look at it for a while, I can see some familiar stuff:


```
    %0 = "riscv_ssa.li"() {"immediate" = 82 : i32}
    ^^              ^^                   ^^
    Result?         Op-name              Argument
```

We seem to have the `li` pseudo-op here, with an immediate value of `82`. Its result is stored in `%0`.

### A short (and incomplete) introduction to SSA

 - We have infinitely many variables (or registers)
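 - We can only assign to each register once

Because every value is defined exactly once, an optimization can always follow a value back to the single operation that produced it. That is what the fusion below relies on. It looks for an `add` whose operand was produced by a `mul`, and replaces the pair with one `fmadd` instruction for our accelerator.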

In [4]:
@irdl_op_definition
class FmaddOp(Operation):
    """Fused multiply-add on registers: ``rd = rs1 * rs2 + rs3``.

    This is the custom accelerator instruction targeted by the fusion pattern
    in this notebook; its runtime semantics are supplied by the ``RV_fmadd``
    emulator extension rather than by a standard RISC-V extension.
    """

    name = "riscv_ssa.fmadd"

    # Declaration order matters to IRDL: one result, then three operands.
    rd: Annotated[OpResult, RegisterType]
    rs1: Annotated[Operand, RegisterType]
    rs2: Annotated[Operand, RegisterType]
    rs3: Annotated[Operand, RegisterType]

    @classmethod
    def get(cls, *operands):
        """Build an fmadd from its source operands (rs1, rs2, rs3), in order."""
        result_types = [RegisterType()]
        return cls.build(operands=operands, result_types=result_types)

In [5]:
from xdsl.pattern_rewriter import (GreedyRewritePatternApplier,
                                   PatternRewriter, PatternRewriteWalker,
                                   RewritePattern, op_type_rewrite_pattern)

class FmaddOpOptimizer(RewritePattern):
    """Fuse ``add(mul(a, b), c)`` (either operand order) into ``fmadd(a, b, c)``.

    The pattern fires on each ``riscv_ssa.add``. If one of its operands is the
    result of a ``riscv_ssa.mul``, the add is replaced by a single ``fmadd`` and
    the now-dead mul is erased.

    Because the mul is deleted, fusion only happens when this add is the mul's
    sole user. Otherwise other consumers would be left referencing a removed
    value. The same rule also rejects ``add(m, m)``, where ``m`` is used twice.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, add: AddOp, rewriter: PatternRewriter):
        operands = list(add.operands)
        for index, operand in enumerate(operands):
            # Only op results have a defining op; block arguments do not,
            # so they can never come from a MULOp.
            if not isinstance(operand, OpResult):
                continue
            if not isinstance(operand.op, MULOp):
                continue
            # The mul is erased below, so nothing else may depend on it.
            if len(operand.uses) != 1:
                continue

            mul = operand.op
            # The addend is whichever add operand is NOT the fused product.
            addend = operands[1 - index]
            a, b = mul.operands

            rewriter.replace_matched_op(FmaddOp.get(a, b, addend))
            # The only use of the mul disappeared with the replaced add.
            rewriter.erase_op(mul)
            return

In [6]:
# Walk the module and greedily apply the fmadd fusion pattern. `module` is
# rewritten in place, as the next dump shows.
PatternRewriteWalker(GreedyRewritePatternApplier([FmaddOpOptimizer()])).rewrite_module(module)

In [7]:
# Dump the module again to see the mul + add pair replaced by one fmadd.
p.print(module)

"builtin.module"() ({
  "riscv_ssa.label"() {"label" = #riscv.label<main>} : () -> ()
  %0 = "riscv_ssa.li"() {"immediate" = 82 : i32} : () -> #riscv_ssa.reg
  %1 = "riscv_ssa.li"() {"immediate" = 5 : i32} : () -> #riscv_ssa.reg
  %3 = "riscv_ssa.li"() {"immediate" = 10 : i32} : () -> #riscv_ssa.reg
  %5 = "riscv_ssa.fmadd"(%0, %1, %3) : (#riscv_ssa.reg, #riscv_ssa.reg, #riscv_ssa.reg) -> #riscv_ssa.reg
  "riscv_ssa.ecall"(%5) {"syscall_num" = 93 : i32} : (#riscv_ssa.reg) -> ()
}) : () -> ()


In [8]:
# Lower the SSA module to RISC-V assembly text and show it. The output still
# names registers %0, %1, ... (virtual registers, not ABI registers).
print(print_riscv_ssa(module))

.text
main:
	li	%0, 82
	li	%1, 5
	li	%2, 10
	fmadd	%3, %0, %1, %2
	mv	a0, %3
	li	a7, 93
	scall



In [9]:
from riscemu.instructions import InstructionSet
from riscemu.types import Instruction

class RV_fmadd(InstructionSet):
    """riscemu extension implementing the custom ``fmadd`` instruction.

    ``fmadd rd, rs1, rs2, rs3`` writes ``rs1 * rs2 + rs3`` into ``rd``. It
    reads and writes the emulator's ``self.regs`` register file.
    """

    def instruction_fmadd(self, ins: Instruction):
        dest = ins.get_reg(0)
        lhs, rhs, acc = (self.regs.get(ins.get_reg(i)) for i in (1, 2, 3))
        product = lhs * rhs
        self.regs.set(dest, product + acc)

In [10]:
# Execute the generated assembly in the emulator with the custom fmadd
# instruction set registered. `unlimited_regs=True` presumably lets the
# virtual %N registers be used as-is; verify against run_riscv.
run_riscv(print_riscv_ssa(module), extensions=(RV_fmadd,), unlimited_regs=True)

[34m[1m[CPU] Started running from example.asm:.text at main (0x100) + 0x0[0m
Program(name=example.asm,sections=set(),base=['.text'])
[34m[1m   Running 0x00000100:[0m li %0, 82
[34m[1m   Running 0x00000104:[0m li %1, 5
[34m[1m   Running 0x00000108:[0m li %2, 10
[34m[1m   Running 0x0000010C:[0m fmadd %3, %0, %1, %2
[34m[1m   Running 0x00000110:[0m mv a0, %3
[34m[1m   Running 0x00000114:[0m li a7, 93
[34m[1m   Running 0x00000118:[0m scall 
[34m[1m[CPU] Program exited with code 420[0m
