Skip to content

Commit

Permalink
Better label refs (#310)
Browse files Browse the repository at this point in the history
Previously we were using a `Box<dyn FnOnce>` to support patching the
code when jumping to labels. We needed to do this because some of the
closures that were being used to patch needed to capture local variables
(on both X86 and ARM it was the type of condition for the conditional
jumps).

To get around that, we can instead use const generics since the
condition codes are always known at compile-time. This means that the
closures no longer need to capture any local variables, so they can be
represented as a plain `fn` pointer instead of a `Box<dyn FnOnce>`.
This simplifies the storage of the `LabelRef` structs and should
hopefully be a better default going forward.

PR: #310
  • Loading branch information
kddnewton authored and noahgibbs committed Aug 24, 2022
1 parent 9934499 commit 5a1375f
Show file tree
Hide file tree
Showing 6 changed files with 67 additions and 65 deletions.
34 changes: 18 additions & 16 deletions yjit/src/asm/arm64/arg/condition.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,22 @@
/// Various instructions in A64 can have condition codes attached. This enum
/// includes all of the various kinds of conditions along with their respective
/// encodings.
pub enum Condition {
EQ = 0b0000, // equal to
NE = 0b0001, // not equal to
CS = 0b0010, // carry set (alias for HS)
CC = 0b0011, // carry clear (alias for LO)
MI = 0b0100, // minus, negative
PL = 0b0101, // positive or zero
VS = 0b0110, // signed overflow
VC = 0b0111, // no signed overflow
HI = 0b1000, // greater than (unsigned)
LS = 0b1001, // less than or equal to (unsigned)
GE = 0b1010, // greater than or equal to (signed)
LT = 0b1011, // less than (signed)
GT = 0b1100, // greater than (signed)
LE = 0b1101, // less than or equal to (signed)
AL = 0b1110, // always
/// `Condition` is a unit struct used purely as a namespace: the condition
/// codes are associated `u8` constants rather than enum variants, so that a
/// code can be supplied as a const generic argument (for example
/// `emit_conditional_jump::<{Condition::EQ}>`).
pub struct Condition;

impl Condition {
// Every constant below is a 4-bit condition encoding.
pub const EQ: u8 = 0b0000; // equal to
pub const NE: u8 = 0b0001; // not equal to
pub const CS: u8 = 0b0010; // carry set (alias for HS)
pub const CC: u8 = 0b0011; // carry clear (alias for LO)
pub const MI: u8 = 0b0100; // minus, negative
pub const PL: u8 = 0b0101; // positive or zero
pub const VS: u8 = 0b0110; // signed overflow
pub const VC: u8 = 0b0111; // no signed overflow
pub const HI: u8 = 0b1000; // greater than (unsigned)
pub const LS: u8 = 0b1001; // less than or equal to (unsigned)
pub const GE: u8 = 0b1010; // greater than or equal to (signed)
pub const LT: u8 = 0b1011; // less than (signed)
pub const GT: u8 = 0b1100; // greater than (signed)
pub const LE: u8 = 0b1101; // less than or equal to (signed)
pub const AL: u8 = 0b1110; // always
}
4 changes: 2 additions & 2 deletions yjit/src/asm/arm64/inst/branch_cond.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use super::super::arg::Condition;
///
pub struct BranchCond {
/// The kind of condition to check before branching.
cond: Condition,
cond: u8,

/// The instruction offset from this instruction to branch to.
imm19: i32
Expand All @@ -20,7 +20,7 @@ pub struct BranchCond {
impl BranchCond {
/// B.cond
/// https://developer.arm.com/documentation/ddi0596/2020-12/Base-Instructions/B-cond--Branch-conditionally-
///
/// Builds the instruction from a condition code and a branch distance
/// given in bytes.
pub fn bcond(cond: Condition, byte_offset: i32) -> Self {
pub fn bcond(cond: u8, byte_offset: i32) -> Self {
// `imm19` holds an instruction offset, so the byte offset is divided by 4
// with a shift. `>>` on an `i32` is an arithmetic shift, so negative
// offsets keep their sign; the low two bits of `byte_offset` are dropped.
Self { cond, imm19: byte_offset >> 2 }
}
}
Expand Down
2 changes: 1 addition & 1 deletion yjit/src/asm/arm64/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ pub const fn bcond_offset_fits_bits(offset: i64) -> bool {
}

/// B.cond - branch to target if condition is true
pub fn bcond(cb: &mut CodeBlock, cond: Condition, byte_offset: A64Opnd) {
pub fn bcond(cb: &mut CodeBlock, cond: u8, byte_offset: A64Opnd) {
let bytes: [u8; 4] = match byte_offset {
A64Opnd::Imm(imm) => {
assert!(bcond_offset_fits_bits(imm), "The immediate operand must be 21 bits or less and be aligned to a 2-bit boundary.");
Expand Down
6 changes: 3 additions & 3 deletions yjit/src/asm/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ struct LabelRef {
num_bytes: usize,

/// The object that knows how to encode the branch instruction.
encode: Box<dyn FnOnce(&mut CodeBlock, i64, i64)>
encode: fn(&mut CodeBlock, i64, i64)
}

/// Block of memory into which instructions can be assembled
Expand Down Expand Up @@ -227,11 +227,11 @@ impl CodeBlock {
}

// Add a label reference at the current write position
pub fn label_ref<E: 'static>(&mut self, label_idx: usize, num_bytes: usize, encode: E) where E: FnOnce(&mut CodeBlock, i64, i64) {
pub fn label_ref(&mut self, label_idx: usize, num_bytes: usize, encode: fn(&mut CodeBlock, i64, i64)) {
assert!(label_idx < self.label_addrs.len());

// Keep track of the reference
self.label_refs.push(LabelRef { pos: self.write_pos, label_idx, num_bytes, encode: Box::new(encode) });
self.label_refs.push(LabelRef { pos: self.write_pos, label_idx, num_bytes, encode });

// Move past however many bytes the instruction takes up
self.write_pos += num_bytes;
Expand Down
66 changes: 33 additions & 33 deletions yjit/src/asm/x86_64/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -799,45 +799,45 @@ pub fn int3(cb: &mut CodeBlock) {

// Encode a conditional relative jump to a label
// Note: this always encodes a 32-bit offset
fn write_jcc(cb: &mut CodeBlock, op: u8, label_idx: usize) {
cb.label_ref(label_idx, 6, move |cb, src_addr, dst_addr| {
// The second opcode byte is supplied as the const generic `OP` instead of a
// runtime argument, so the closure handed to `label_ref` captures nothing
// and can be passed as a plain `fn` pointer.
fn write_jcc<const OP: u8>(cb: &mut CodeBlock, label_idx: usize) {
// Register a 6-byte label reference: two opcode bytes followed by a 32-bit
// offset, matching what the closure writes.
cb.label_ref(label_idx, 6, |cb, src_addr, dst_addr| {
cb.write_byte(0x0F);
cb.write_byte(op);
cb.write_byte(OP);
// The offset written is the distance from `src_addr` to `dst_addr`.
cb.write_int((dst_addr - src_addr) as u64, 32);
});
}

/// jcc - relative jumps to a label
pub fn ja_label (cb: &mut CodeBlock, label_idx: usize) { write_jcc(cb, 0x87, label_idx); }
pub fn jae_label (cb: &mut CodeBlock, label_idx: usize) { write_jcc(cb, 0x83, label_idx); }
pub fn jb_label (cb: &mut CodeBlock, label_idx: usize) { write_jcc(cb, 0x82, label_idx); }
pub fn jbe_label (cb: &mut CodeBlock, label_idx: usize) { write_jcc(cb, 0x86, label_idx); }
pub fn jc_label (cb: &mut CodeBlock, label_idx: usize) { write_jcc(cb, 0x82, label_idx); }
pub fn je_label (cb: &mut CodeBlock, label_idx: usize) { write_jcc(cb, 0x84, label_idx); }
pub fn jg_label (cb: &mut CodeBlock, label_idx: usize) { write_jcc(cb, 0x8F, label_idx); }
pub fn jge_label (cb: &mut CodeBlock, label_idx: usize) { write_jcc(cb, 0x8D, label_idx); }
pub fn jl_label (cb: &mut CodeBlock, label_idx: usize) { write_jcc(cb, 0x8C, label_idx); }
pub fn jle_label (cb: &mut CodeBlock, label_idx: usize) { write_jcc(cb, 0x8E, label_idx); }
pub fn jna_label (cb: &mut CodeBlock, label_idx: usize) { write_jcc(cb, 0x86, label_idx); }
pub fn jnae_label(cb: &mut CodeBlock, label_idx: usize) { write_jcc(cb, 0x82, label_idx); }
pub fn jnb_label (cb: &mut CodeBlock, label_idx: usize) { write_jcc(cb, 0x83, label_idx); }
pub fn jnbe_label(cb: &mut CodeBlock, label_idx: usize) { write_jcc(cb, 0x87, label_idx); }
pub fn jnc_label (cb: &mut CodeBlock, label_idx: usize) { write_jcc(cb, 0x83, label_idx); }
pub fn jne_label (cb: &mut CodeBlock, label_idx: usize) { write_jcc(cb, 0x85, label_idx); }
pub fn jng_label (cb: &mut CodeBlock, label_idx: usize) { write_jcc(cb, 0x8E, label_idx); }
pub fn jnge_label(cb: &mut CodeBlock, label_idx: usize) { write_jcc(cb, 0x8C, label_idx); }
pub fn jnl_label (cb: &mut CodeBlock, label_idx: usize) { write_jcc(cb, 0x8D, label_idx); }
pub fn jnle_label(cb: &mut CodeBlock, label_idx: usize) { write_jcc(cb, 0x8F, label_idx); }
pub fn jno_label (cb: &mut CodeBlock, label_idx: usize) { write_jcc(cb, 0x81, label_idx); }
pub fn jnp_label (cb: &mut CodeBlock, label_idx: usize) { write_jcc(cb, 0x8b, label_idx); }
pub fn jns_label (cb: &mut CodeBlock, label_idx: usize) { write_jcc(cb, 0x89, label_idx); }
pub fn jnz_label (cb: &mut CodeBlock, label_idx: usize) { write_jcc(cb, 0x85, label_idx); }
pub fn jo_label (cb: &mut CodeBlock, label_idx: usize) { write_jcc(cb, 0x80, label_idx); }
pub fn jp_label (cb: &mut CodeBlock, label_idx: usize) { write_jcc(cb, 0x8A, label_idx); }
pub fn jpe_label (cb: &mut CodeBlock, label_idx: usize) { write_jcc(cb, 0x8A, label_idx); }
pub fn jpo_label (cb: &mut CodeBlock, label_idx: usize) { write_jcc(cb, 0x8B, label_idx); }
pub fn js_label (cb: &mut CodeBlock, label_idx: usize) { write_jcc(cb, 0x88, label_idx); }
pub fn jz_label (cb: &mut CodeBlock, label_idx: usize) { write_jcc(cb, 0x84, label_idx); }
// Each wrapper below instantiates `write_jcc` with the second opcode byte for
// its condition. Mnemonics that are aliases of one another share a byte:
//   0x82: jb, jc, jnae      0x83: jae, jnb, jnc
//   0x84: je, jz            0x85: jne, jnz
//   0x86: jbe, jna          0x87: ja, jnbe
//   0x8A: jp, jpe           0x8B: jnp, jpo
//   0x8C: jl, jnge          0x8D: jge, jnl
//   0x8E: jle, jng          0x8F: jg, jnle
pub fn ja_label (cb: &mut CodeBlock, label_idx: usize) { write_jcc::<0x87>(cb, label_idx); }
pub fn jae_label (cb: &mut CodeBlock, label_idx: usize) { write_jcc::<0x83>(cb, label_idx); }
pub fn jb_label (cb: &mut CodeBlock, label_idx: usize) { write_jcc::<0x82>(cb, label_idx); }
pub fn jbe_label (cb: &mut CodeBlock, label_idx: usize) { write_jcc::<0x86>(cb, label_idx); }
pub fn jc_label (cb: &mut CodeBlock, label_idx: usize) { write_jcc::<0x82>(cb, label_idx); }
pub fn je_label (cb: &mut CodeBlock, label_idx: usize) { write_jcc::<0x84>(cb, label_idx); }
pub fn jg_label (cb: &mut CodeBlock, label_idx: usize) { write_jcc::<0x8F>(cb, label_idx); }
pub fn jge_label (cb: &mut CodeBlock, label_idx: usize) { write_jcc::<0x8D>(cb, label_idx); }
pub fn jl_label (cb: &mut CodeBlock, label_idx: usize) { write_jcc::<0x8C>(cb, label_idx); }
pub fn jle_label (cb: &mut CodeBlock, label_idx: usize) { write_jcc::<0x8E>(cb, label_idx); }
pub fn jna_label (cb: &mut CodeBlock, label_idx: usize) { write_jcc::<0x86>(cb, label_idx); }
pub fn jnae_label(cb: &mut CodeBlock, label_idx: usize) { write_jcc::<0x82>(cb, label_idx); }
pub fn jnb_label (cb: &mut CodeBlock, label_idx: usize) { write_jcc::<0x83>(cb, label_idx); }
pub fn jnbe_label(cb: &mut CodeBlock, label_idx: usize) { write_jcc::<0x87>(cb, label_idx); }
pub fn jnc_label (cb: &mut CodeBlock, label_idx: usize) { write_jcc::<0x83>(cb, label_idx); }
pub fn jne_label (cb: &mut CodeBlock, label_idx: usize) { write_jcc::<0x85>(cb, label_idx); }
pub fn jng_label (cb: &mut CodeBlock, label_idx: usize) { write_jcc::<0x8E>(cb, label_idx); }
pub fn jnge_label(cb: &mut CodeBlock, label_idx: usize) { write_jcc::<0x8C>(cb, label_idx); }
pub fn jnl_label (cb: &mut CodeBlock, label_idx: usize) { write_jcc::<0x8D>(cb, label_idx); }
pub fn jnle_label(cb: &mut CodeBlock, label_idx: usize) { write_jcc::<0x8F>(cb, label_idx); }
pub fn jno_label (cb: &mut CodeBlock, label_idx: usize) { write_jcc::<0x81>(cb, label_idx); }
pub fn jnp_label (cb: &mut CodeBlock, label_idx: usize) { write_jcc::<0x8b>(cb, label_idx); }
pub fn jns_label (cb: &mut CodeBlock, label_idx: usize) { write_jcc::<0x89>(cb, label_idx); }
pub fn jnz_label (cb: &mut CodeBlock, label_idx: usize) { write_jcc::<0x85>(cb, label_idx); }
pub fn jo_label (cb: &mut CodeBlock, label_idx: usize) { write_jcc::<0x80>(cb, label_idx); }
pub fn jp_label (cb: &mut CodeBlock, label_idx: usize) { write_jcc::<0x8A>(cb, label_idx); }
pub fn jpe_label (cb: &mut CodeBlock, label_idx: usize) { write_jcc::<0x8A>(cb, label_idx); }
pub fn jpo_label (cb: &mut CodeBlock, label_idx: usize) { write_jcc::<0x8B>(cb, label_idx); }
pub fn js_label (cb: &mut CodeBlock, label_idx: usize) { write_jcc::<0x88>(cb, label_idx); }
pub fn jz_label (cb: &mut CodeBlock, label_idx: usize) { write_jcc::<0x84>(cb, label_idx); }

pub fn jmp_label(cb: &mut CodeBlock, label_idx: usize) {
cb.label_ref(label_idx, 5, |cb, src_addr, dst_addr| {
Expand Down
20 changes: 10 additions & 10 deletions yjit/src/backend/arm64/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ impl Assembler

/// Emit a conditional jump instruction to a specific target. This is
/// called when lowering any of the conditional jump instructions.
fn emit_conditional_jump(cb: &mut CodeBlock, condition: Condition, target: Target) {
fn emit_conditional_jump<const CONDITION: u8>(cb: &mut CodeBlock, target: Target) {
match target {
Target::CodePtr(dst_ptr) => {
let src_addr = cb.get_write_ptr().into_i64() + 4;
Expand All @@ -297,12 +297,12 @@ impl Assembler
// to load the address into a register and use the branch
// register instruction.
if bcond_offset_fits_bits(offset) {
bcond(cb, condition, A64Opnd::new_imm(dst_addr - src_addr));
bcond(cb, CONDITION, A64Opnd::new_imm(dst_addr - src_addr));
} else {
// If the condition is met, then we'll skip past the
// next instruction, put the address in a register, and
// jump to it.
bcond(cb, condition, A64Opnd::new_imm(4));
bcond(cb, CONDITION, A64Opnd::new_imm(4));

// If the offset fits into a direct jump, then we'll use
// that and the number of instructions will be shorter.
Expand Down Expand Up @@ -333,7 +333,7 @@ impl Assembler
// offset. We're going to assume we can fit into a single
// b.cond instruction. It will panic otherwise.
cb.label_ref(label_idx, 4, |cb, src_addr, dst_addr| {
bcond(cb, condition, A64Opnd::new_imm(dst_addr - src_addr));
bcond(cb, CONDITION, A64Opnd::new_imm(dst_addr - src_addr));
});
},
Target::FunPtr(_) => unreachable!()
Expand Down Expand Up @@ -395,7 +395,7 @@ impl Assembler
// being loaded is a heap object, we'll report that
// back out to the gc_offsets list.
ldr(cb, insn.out.into(), 1);
b(cb, A64Opnd::new_uimm((SIZEOF_VALUE as u64) / 4));
b(cb, A64Opnd::new_imm((SIZEOF_VALUE as i64) / 4));
cb.write_bytes(&value.as_u64().to_le_bytes());

if !value.special_const_p() {
Expand Down Expand Up @@ -507,19 +507,19 @@ impl Assembler
};
},
Op::Je => {
emit_conditional_jump(cb, Condition::EQ, insn.target.unwrap());
emit_conditional_jump::<{Condition::EQ}>(cb, insn.target.unwrap());
},
Op::Jbe => {
emit_conditional_jump(cb, Condition::LS, insn.target.unwrap());
emit_conditional_jump::<{Condition::LS}>(cb, insn.target.unwrap());
},
Op::Jz => {
emit_conditional_jump(cb, Condition::EQ, insn.target.unwrap());
emit_conditional_jump::<{Condition::EQ}>(cb, insn.target.unwrap());
},
Op::Jnz => {
emit_conditional_jump(cb, Condition::NE, insn.target.unwrap());
emit_conditional_jump::<{Condition::NE}>(cb, insn.target.unwrap());
},
Op::Jo => {
emit_conditional_jump(cb, Condition::VS, insn.target.unwrap());
emit_conditional_jump::<{Condition::VS}>(cb, insn.target.unwrap());
},
Op::IncrCounter => {
ldaddal(cb, insn.opnds[0].into(), insn.opnds[0].into(), insn.opnds[1].into());
Expand Down

0 comments on commit 5a1375f

Please sign in to comment.