Skip to content

Commit

Permalink
Merge pull request #2424 from ROCm/r2.14-rocm-enhanced-fp16-atomic-add
Browse files Browse the repository at this point in the history
[r2.14-rocm-enhanced] Enabled real HW fp16 atomic add instead of CAS loop for MI200 and lat…
  • Loading branch information
jayfurmanek committed May 9, 2024
2 parents 10cc386 + 76a8ca7 commit b0d9c59
Showing 1 changed file with 37 additions and 0 deletions.
37 changes: 37 additions & 0 deletions tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,43 @@ void EmitAMDGPUAtomicAdd(llvm::IRBuilder<>* builder,
output_address,
llvm::PointerType::getWithSamePointeeType(output_address_type,
/*AddressSpace=*/1));
if (source->getType()->getPrimitiveSizeInBits() == 16) {
  // 16-bit source: instead of falling through to the generic atomic RMW
  // below, emit a call to the packed <2 x half> global atomic fadd intrinsic
  // on the 4-byte word that contains the target element. The source value is
  // placed in the matching 16-bit lane and the other lane is all-zero bits
  // (+0.0 as a half).
  // NOTE(review): the size check matches any 16-bit source type, not only
  // half; presumably callers only reach this with half values -- confirm.
  llvm::VectorType* half2type = llvm::VectorType::get(
      builder->getHalfTy(), llvm::ElementCount::getFixed(2));
  auto i16 = builder->getInt16Ty();
  auto i32 = builder->getInt32Ty();
  auto i64 = builder->getInt64Ty();
  // Pointer to <2 x half> in address space 1.
  auto half2ptr = llvm::PointerType::get(half2type, 1);

  // Split the address: bit 1 selects which 16-bit lane of the word is
  // targeted, and clearing the two low bits yields the 4-byte-aligned word
  // address. Bit 0 is discarded (presumably always zero for 2-byte elements).
  auto intptr = builder->CreatePtrToInt(output_address, i64);
  auto alignment =
      builder->CreateAnd(intptr, llvm::ConstantInt::get(i64, 2ull));
  intptr = builder->CreateAnd(intptr, llvm::ConstantInt::get(i64, ~3ull));
  output_ptr = builder->CreateIntToPtr(intptr, half2ptr);

  // alignment is 0 or 2 (a byte offset); << 3 turns it into a bit shift of
  // 0 or 16.
  auto shift = builder->CreateShl(builder->CreateTrunc(alignment, i32), 3);
  // Reinterpret the source as raw bits, widen to 32 bits, move it into the
  // selected lane, and view the 32-bit result as <2 x half>.
  auto i16src = builder->CreateBitCast(source, i16);
  auto intsrc = builder->CreateZExt(i16src, i32);
  source = builder->CreateShl(intsrc, shift);
  source = builder->CreateBitCast(source, half2type);

  llvm::Module* module = builder->GetInsertBlock()->getModule();
  std::vector<llvm::Type*> ir_input_types{half2ptr, half2type};

  llvm::FunctionType* callee_type =
      llvm::FunctionType::get(half2type, ir_input_types, false);

  // Declares the callee if it is not declared already. llvm::cast (rather
  // than dyn_cast) asserts that the returned callee is a Function instead of
  // silently yielding a null pointer that is dereferenced below.
  llvm::Function* callee = llvm::cast<llvm::Function>(
      module
          ->getOrInsertFunction(
              "llvm.amdgcn.global.atomic.fadd.v2f16.p1.v2f16", callee_type)
          .getCallee());

  callee->addFnAttr(llvm::Attribute::NoUnwind);
  callee->setMemoryEffects(llvm::MemoryEffects::argMemOnly());

  // The intrinsic's return value is not used.
  builder->CreateCall(callee, {output_ptr, source});
  return;
}

builder->CreateAtomicRMW(
llvm::AtomicRMWInst::FAdd, output_ptr, source, llvm::MaybeAlign(),
Expand Down

0 comments on commit b0d9c59

Please sign in to comment.