Skip to content

Commit

Permalink
[CHERI] Use separate Pseudo instructions for cmpxchg nodes
Browse files Browse the repository at this point in the history
Using separate pseudos for exact and inexact comparsions ensures that
our lowering does not depend on the MachineMemOperand (after SDAG) since
passes could drop it (which means use the most conservative approach).
This adds a bit of boilerplate but it's not as bad as I expected and is
less fragile than the previous approach.
  • Loading branch information
arichardson committed Feb 18, 2024
1 parent f60ed4d commit 056a8a4
Show file tree
Hide file tree
Showing 11 changed files with 390 additions and 48 deletions.
6 changes: 6 additions & 0 deletions llvm/include/llvm/CodeGen/SelectionDAGNodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -1453,6 +1453,12 @@ class AtomicSDNode : public MemSDNode {
return MMO->getFailureOrdering();
}

/// Return true if the memory operation ordering is Unordered or higher.
bool isExactCmpXchg() const {
assert(getMemoryVT().isFatPointer());
return MMO->isExactCompare();
}

// Methods to support isa and dyn_cast
static bool classof(const SDNode *N) {
return N->getOpcode() == ISD::ATOMIC_CMP_SWAP ||
Expand Down
18 changes: 15 additions & 3 deletions llvm/include/llvm/Target/TargetSelectionDAG.td
Original file line number Diff line number Diff line change
Expand Up @@ -1888,10 +1888,21 @@ multiclass binary_atomic_op_cap<SDNode atomic_op> {
defm NAME : binary_atomic_op_ord;
}

multiclass ternary_atomic_op_cap<SDNode atomic_op> {
multiclass ternary_atomic_op_cap_inexact<SDNode atomic_op> {
def "" : PatFrag<(ops node:$ptr, node:$cmp, node:$val),
(atomic_op node:$ptr, node:$cmp, node:$val), [{
return cast<AtomicSDNode>(N)->getMemoryVT().isFatPointer();
auto AN = cast<AtomicSDNode>(N);
return AN->getMemoryVT().isFatPointer() && !AN->isExactCmpXchg();
}]>;

defm NAME : ternary_atomic_op_ord;
}

multiclass ternary_atomic_op_cap_exact<SDNode atomic_op> {
def "" : PatFrag<(ops node:$ptr, node:$cmp, node:$val),
(atomic_op node:$ptr, node:$cmp, node:$val), [{
auto AN = cast<AtomicSDNode>(N);
return AN->getMemoryVT().isFatPointer() && AN->isExactCmpXchg();
}]>;

defm NAME : ternary_atomic_op_ord;
Expand All @@ -1910,7 +1921,8 @@ defm atomic_load_max_cap : binary_atomic_op_cap<atomic_load_max_cap_node>;
defm atomic_load_umin_cap : binary_atomic_op_cap<atomic_load_umin_cap_node>;
defm atomic_load_umax_cap : binary_atomic_op_cap<atomic_load_umax_cap_node>;
defm atomic_store_cap : binary_atomic_op_cap<atomic_store_cap_node>;
defm atomic_cmp_swap_cap : ternary_atomic_op_cap<atomic_cmp_swap_cap_node>;
defm atomic_cmp_swap_cap_addr : ternary_atomic_op_cap_inexact<atomic_cmp_swap_cap_node>;
defm atomic_cmp_swap_cap_exact : ternary_atomic_op_cap_exact<atomic_cmp_swap_cap_node>;

def atomic_load_cap :
PatFrag<(ops node:$ptr),
Expand Down
12 changes: 7 additions & 5 deletions llvm/lib/Target/Mips/MipsExpandPseudo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -212,14 +212,18 @@ bool MipsExpandPseudo::expandAtomicCmpSwap(MachineBasicBlock &BB,

unsigned Size = -1;
bool IsCapCmpXchg = false;
bool UseExactEquals = false;
switch(I->getOpcode()) {
case Mips::ATOMIC_CMP_SWAP_I32_POSTRA: Size = 4; break;
case Mips::ATOMIC_CMP_SWAP_I64_POSTRA: Size = 8; break;
case Mips::CAP_ATOMIC_CMP_SWAP_I8_POSTRA: Size = 1; break;
case Mips::CAP_ATOMIC_CMP_SWAP_I16_POSTRA: Size = 2; break;
case Mips::CAP_ATOMIC_CMP_SWAP_I32_POSTRA: Size = 4; break;
case Mips::CAP_ATOMIC_CMP_SWAP_I64_POSTRA: Size = 8; break;
case Mips::CAP_ATOMIC_CMP_SWAP_CAP_POSTRA:
case Mips::CAP_ATOMIC_CMP_SWAP_CAP_EXACT_POSTRA:
UseExactEquals = true;
LLVM_FALLTHROUGH;
case Mips::CAP_ATOMIC_CMP_SWAP_CAP_ADDR_POSTRA:
Size = CAP_ATOMIC_SIZE;
IsCapCmpXchg = true;
break;
Expand Down Expand Up @@ -327,9 +331,6 @@ bool MipsExpandPseudo::expandAtomicCmpSwap(MachineBasicBlock &BB,
if (!IsCapOp)
LLOp.addImm(0);
if (IsCapCmpXchg) {
assert(I->hasOneMemOperand());
bool UseExactEquals =
STI->useCheriExactEquals() || I->memoperands()[0]->isExactCompare();
unsigned CapCmp = UseExactEquals ? Mips::CEXEQ : Mips::CEQ;
// load, compare, and exit if not equal
// cllc dest, ptr
Expand Down Expand Up @@ -1098,7 +1099,8 @@ bool MipsExpandPseudo::expandMI(MachineBasicBlock &MBB,
case Mips::CAP_ATOMIC_CMP_SWAP_I16_POSTRA:
case Mips::CAP_ATOMIC_CMP_SWAP_I32_POSTRA:
case Mips::CAP_ATOMIC_CMP_SWAP_I64_POSTRA:
case Mips::CAP_ATOMIC_CMP_SWAP_CAP_POSTRA:
case Mips::CAP_ATOMIC_CMP_SWAP_CAP_ADDR_POSTRA:
case Mips::CAP_ATOMIC_CMP_SWAP_CAP_EXACT_POSTRA:
return expandAtomicCmpSwap(MBB, MBBI, NMBB, /*IsCapOp=*/true);
case Mips::PseudoPccRelativeAddressPostRA:
return expandPccRelativeAddr(MBB, MBBI, NMBB);
Expand Down
11 changes: 8 additions & 3 deletions llvm/lib/Target/Mips/MipsISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1837,7 +1837,8 @@ MipsTargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI,
case Mips::CAP_ATOMIC_CMP_SWAP_I16:
case Mips::CAP_ATOMIC_CMP_SWAP_I32:
case Mips::CAP_ATOMIC_CMP_SWAP_I64:
case Mips::CAP_ATOMIC_CMP_SWAP_CAP:
case Mips::CAP_ATOMIC_CMP_SWAP_CAP_ADDR:
case Mips::CAP_ATOMIC_CMP_SWAP_CAP_EXACT:
return emitAtomicCmpSwap(MI, BB);


Expand Down Expand Up @@ -2445,8 +2446,12 @@ MipsTargetLowering::emitAtomicCmpSwap(MachineInstr &MI,
AtomicOp = Mips::CAP_ATOMIC_CMP_SWAP_I64_POSTRA;
ScratchTy = MVT::i64;
break;
case Mips::CAP_ATOMIC_CMP_SWAP_CAP:
AtomicOp = Mips::CAP_ATOMIC_CMP_SWAP_CAP_POSTRA;
case Mips::CAP_ATOMIC_CMP_SWAP_CAP_ADDR:
AtomicOp = Mips::CAP_ATOMIC_CMP_SWAP_CAP_ADDR_POSTRA;
ScratchTy = MVT::i64;
break;
case Mips::CAP_ATOMIC_CMP_SWAP_CAP_EXACT:
AtomicOp = Mips::CAP_ATOMIC_CMP_SWAP_CAP_EXACT_POSTRA;
ScratchTy = MVT::i64;
break;
default:
Expand Down
14 changes: 8 additions & 6 deletions llvm/lib/Target/Mips/MipsInstrCheri.td
Original file line number Diff line number Diff line change
Expand Up @@ -763,8 +763,9 @@ let usesCustomInserter = 1 in {

// Capability atomics:
// FIXME: this seems wrong it should be CheriGPROrCNULL
def CAP_ATOMIC_SWAP_CAP : CapAtomic2Ops<atomic_swap_cap, CheriOpnd>;
def CAP_ATOMIC_CMP_SWAP_CAP : CapAtomicCmpSwap<atomic_cmp_swap_cap, CheriOpnd>;
def CAP_ATOMIC_SWAP_CAP : CapAtomic2Ops<atomic_swap_cap, CheriOpnd>;
def CAP_ATOMIC_CMP_SWAP_CAP_ADDR : CapAtomicCmpSwap<atomic_cmp_swap_cap_addr, CheriOpnd>;
def CAP_ATOMIC_CMP_SWAP_CAP_EXACT : CapAtomicCmpSwap<atomic_cmp_swap_cap_exact, CheriOpnd>;

// TODO: implement these:
// def ATOMIC_LOAD_ADD_CAP : Atomic2Ops<atomic_load_add_cap, CheriOpnd>;
Expand Down Expand Up @@ -816,8 +817,9 @@ def CAP_ATOMIC_CMP_SWAP_I64_POSTRA : CapAtomicCmpSwapPostRA<GPR64Opnd>;
// Capability postra atomics:
// TODO: do we want add/sub/or/xor/nand/and for capabilities?
// I guess add/sub makes sense but the others don't
def CAP_ATOMIC_SWAP_CAP_POSTRA : CapAtomic2OpsPostRA<CheriOpnd>;
def CAP_ATOMIC_CMP_SWAP_CAP_POSTRA : CapAtomicCmpSwapPostRA<CheriOpnd>;
def CAP_ATOMIC_SWAP_CAP_POSTRA : CapAtomic2OpsPostRA<CheriOpnd>;
def CAP_ATOMIC_CMP_SWAP_CAP_ADDR_POSTRA : CapAtomicCmpSwapPostRA<CheriOpnd>;
def CAP_ATOMIC_CMP_SWAP_CAP_EXACT_POSTRA : CapAtomicCmpSwapPostRA<CheriOpnd>;
// TODO:
// def CAP_ATOMIC_LOAD_ADD_CAP_POSTRA : CapAtomic2OpsPostRA<CheriOpnd>;
// def CAP_ATOMIC_LOAD_SUB_CAP_POSTRA : CapAtomic2OpsPostRA<CheriOpnd>;
Expand Down Expand Up @@ -853,8 +855,8 @@ def : MipsPat<(atomic_store_cap GPR64Opnd:$a, CheriOpnd:$v),
(STORECAP $v, GPR64Opnd:$a, (i64 0), DDC)>;
def : MipsPat<(atomic_swap_cap GPR64Opnd:$a, CheriOpnd:$swap),
(CAP_ATOMIC_SWAP_CAP (CFromPtr DDC, GPR64Opnd:$a), CheriOpnd:$swap)>;
def : MipsPat<(atomic_cmp_swap_cap GPR64Opnd:$a, CheriOpnd:$cmp, CheriOpnd:$swap),
(CAP_ATOMIC_CMP_SWAP_CAP (CFromPtr DDC, GPR64Opnd:$a), CheriOpnd:$cmp, CheriOpnd:$swap)>;
def : MipsPat<(atomic_cmp_swap_cap_addr GPR64Opnd:$a, CheriOpnd:$cmp, CheriOpnd:$swap),
(CAP_ATOMIC_CMP_SWAP_CAP_ADDR (CFromPtr DDC, GPR64Opnd:$a), CheriOpnd:$cmp, CheriOpnd:$swap)>;
}
////////////////////////////////////////////////////////////////////////////////
// Helpers for capability-using calls and returns
Expand Down
10 changes: 6 additions & 4 deletions llvm/lib/Target/RISCV/RISCVExpandAtomicPseudoInsts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,8 @@ bool RISCVExpandAtomicPseudo::expandMI(MachineBasicBlock &MBB,
case RISCV::PseudoAtomicLoadUMinCap:
return expandAtomicMinMaxOp(MBB, MBBI, AtomicRMWInst::UMin, false, CLenVT,
false, NextMBBI);
case RISCV::PseudoCmpXchgCap:
case RISCV::PseudoCmpXchgCapAddr:
case RISCV::PseudoCmpXchgCapExact:
return expandAtomicCmpXchg(MBB, MBBI, false, CLenVT, false, NextMBBI);
case RISCV::PseudoCheriAtomicSwap8:
return expandAtomicBinOp(MBB, MBBI, AtomicRMWInst::Xchg, false, MVT::i8,
Expand Down Expand Up @@ -272,7 +273,8 @@ bool RISCVExpandAtomicPseudo::expandMI(MachineBasicBlock &MBB,
case RISCV::PseudoCheriAtomicLoadUMinCap:
return expandAtomicMinMaxOp(MBB, MBBI, AtomicRMWInst::UMin, false, CLenVT,
true, NextMBBI);
case RISCV::PseudoCheriCmpXchgCap:
case RISCV::PseudoCheriCmpXchgCapAddr:
case RISCV::PseudoCheriCmpXchgCapExact:
return expandAtomicCmpXchg(MBB, MBBI, false, CLenVT, true, NextMBBI);
}

Expand Down Expand Up @@ -1020,8 +1022,8 @@ bool RISCVExpandAtomicPseudo::expandAtomicCmpXchg(
BuildMI(LoopHeadMBB, DL, TII->get(getLRForRMW(PtrIsCap, Ordering, VT)),
DestReg)
.addReg(AddrReg);
assert(MI.hasOneMemOperand());
if (VT.isFatPointer() && MI.memoperands()[0]->isExactCompare()) {
bool ExactCapCompare = MI.getOpcode() == RISCV::PseudoCheriCmpXchgCapExact;
if (VT.isFatPointer() && ExactCapCompare) {
BuildMI(LoopHeadMBB, DL, TII->get(RISCV::CSEQX), ScratchReg)
.addReg(DestReg, 0)
.addReg(CmpValReg, 0);
Expand Down
12 changes: 8 additions & 4 deletions llvm/lib/Target/RISCV/RISCVInstrInfoXCheri.td
Original file line number Diff line number Diff line change
Expand Up @@ -1621,7 +1621,8 @@ def PseudoAtomicLoadMinCap : PseudoAMO<GPCR> { let Size = 24; }
def PseudoAtomicLoadUMaxCap : PseudoAMO<GPCR> { let Size = 24; }
def PseudoAtomicLoadUMinCap : PseudoAMO<GPCR> { let Size = 24; }
def PseudoAtomicLoadNandCap : PseudoAMO<GPCR> { let Size = 24; }
def PseudoCmpXchgCap : PseudoCmpXchg<GPCR> { let Size = 16; }
def PseudoCmpXchgCapAddr : PseudoCmpXchg<GPCR> { let Size = 16; }
def PseudoCmpXchgCapExact : PseudoCmpXchg<GPCR> { let Size = 16; }
} // Predicates = [HasCheri, HasStdExtA]f

let Predicates = [HasCheri, HasStdExtA, NotCapMode] in {
Expand All @@ -1635,7 +1636,8 @@ defm : PseudoAMOPat<"atomic_load_min_cap", PseudoAtomicLoadMinCap, GPCR>;
defm : PseudoAMOPat<"atomic_load_umax_cap", PseudoAtomicLoadUMaxCap, GPCR>;
defm : PseudoAMOPat<"atomic_load_umin_cap", PseudoAtomicLoadUMinCap, GPCR>;
defm : PseudoAMOPat<"atomic_load_nand_cap", PseudoAtomicLoadNandCap, GPCR>;
defm : PseudoCmpXchgPat<"atomic_cmp_swap_cap", PseudoCmpXchgCap, GPCR>;
defm : PseudoCmpXchgPat<"atomic_cmp_swap_cap_addr", PseudoCmpXchgCapAddr, GPCR>;
defm : PseudoCmpXchgPat<"atomic_cmp_swap_cap_exact", PseudoCmpXchgCapExact, GPCR>;
} // Predicates = [HasCheri, HasStdExtA, NotCapMode]

/// Capability Mode Instructions
Expand Down Expand Up @@ -1782,7 +1784,8 @@ def PseudoCheriAtomicLoadMinCap : PseudoCheriAMO<GPCR> { let Size = 24; }
def PseudoCheriAtomicLoadUMaxCap : PseudoCheriAMO<GPCR> { let Size = 24; }
def PseudoCheriAtomicLoadUMinCap : PseudoCheriAMO<GPCR> { let Size = 24; }
def PseudoCheriAtomicLoadNandCap : PseudoCheriAMO<GPCR> { let Size = 24; }
def PseudoCheriCmpXchgCap : PseudoCheriCmpXchg<GPCR> { let Size = 16; }
def PseudoCheriCmpXchgCapAddr : PseudoCheriCmpXchg<GPCR> { let Size = 16; }
def PseudoCheriCmpXchgCapExact : PseudoCheriCmpXchg<GPCR> { let Size = 16; }
} // Predicates = [HasCheri, HasStdExtA]

let Predicates = [HasCheri, HasStdExtA, IsRV64] in {
Expand Down Expand Up @@ -1981,7 +1984,8 @@ defm : PseudoCheriCmpXchgPat<"atomic_cmp_swap_8", PseudoCheriCmpXchg8>;
defm : PseudoCheriCmpXchgPat<"atomic_cmp_swap_16", PseudoCheriCmpXchg16>;
defm : PseudoCheriCmpXchgPat<"atomic_cmp_swap_32", PseudoCheriCmpXchg32>;

defm : PseudoCheriCmpXchgPat<"atomic_cmp_swap_cap", PseudoCheriCmpXchgCap, GPCR>;
defm : PseudoCheriCmpXchgPat<"atomic_cmp_swap_cap_addr", PseudoCheriCmpXchgCapAddr, GPCR>;
defm : PseudoCheriCmpXchgPat<"atomic_cmp_swap_cap_exact", PseudoCheriCmpXchgCapExact, GPCR>;

} // Predicates = [HasCheri, HasStdExtA, IsCapMode]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,17 @@
; CHERI-GENERIC-UTC: mir
@IF-RISCV@; RUN: llc @PURECAP_HARDFLOAT_ARGS@ -mattr=+a < %s --stop-after=branch-folder | FileCheck %s --check-prefixes=MIR
@IFNOT-RISCV@; RUN: llc @PURECAP_HARDFLOAT_ARGS@ < %s --stop-after=branch-folder --enable-tail-merge | FileCheck %s --check-prefixes=MIR
@IF-RISCV@; RUN: not --crash llc @PURECAP_HARDFLOAT_ARGS@ -mattr=+a < %s
@IFNOT-RISCV@; RUN: not --crash llc @PURECAP_HARDFLOAT_ARGS@ --enable-tail-merge < %s
; Note: cat %s is needed so that update_mir_test_checks.py does not process these RUN lines.
@IF-RISCV@; RUN: cat %s | llc @PURECAP_HARDFLOAT_ARGS@ -mattr=+a | FileCheck %s
@IFNOT-RISCV@; RUN: cat %s | llc @PURECAP_HARDFLOAT_ARGS@ --enable-tail-merge | FileCheck %s
; REQUIRES: asserts

; The branch-folder MIR pass will merge the two blocks inside these functions but
; since the base pointer is distinct it will have two MachineMemOperands.
; The cmpxchg exact logic stored the exact flag in the MachineMemOperand and
; previously assumed there would only ever be one operand, so this test ensures
; we can handle the merged logic.
; we can handle the merged logic by adding separate pseudo instructions (which
; ensures that the branches with different comparisons can no longer be merged).

define dso_local signext i32 @merge_i32(i1 %cond1, ptr addrspace(200) %ptr, i32 %newval, i32 %cmpval) {
entry:
Expand Down Expand Up @@ -66,7 +68,6 @@ end:
ret i32 0
}

; FIXME: these two branches should not be merged!
define dso_local signext i32 @merge_ptr_mismatch_exact_flag(i1 %cond1, ptr addrspace(200) %ptr, ptr addrspace(200) %newval, ptr addrspace(200) %cmpval) {
entry:
br i1 %cond1, label %if.then, label %if.else
Expand Down
Loading

0 comments on commit 056a8a4

Please sign in to comment.