Skip to content

Commit

Permalink
[xla:cpu] Emit partitioned loops if operation marked for parallel exe…
Browse files Browse the repository at this point in the history
…cution

+ temporary rolled back thread pool support in HostKernel as I'll be reworking it to support thunks

PiperOrigin-RevId: 641724588
  • Loading branch information
ezhulenev authored and tkiela1 committed Jun 9, 2024
1 parent 63cde3b commit 32ba408
Show file tree
Hide file tree
Showing 9 changed files with 259 additions and 141 deletions.
4 changes: 4 additions & 0 deletions xla/service/cpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -617,8 +617,11 @@ cc_library(
srcs = ["ir_emitter2.cc"],
hdrs = ["ir_emitter2.h"],
deps = [
":backend_config_proto_cc",
":elemental_math_emitter",
":ir_emitter",
":parallel_loop_emitter",
":shape_partition",
"//xla:shape_util",
"//xla:util",
"//xla/hlo/ir:hlo",
Expand Down Expand Up @@ -1486,6 +1489,7 @@ cc_library(
hdrs = ["shape_partition.h"],
deps = [
"//xla:shape_util",
"@com_google_absl//absl/types:span",
],
)

Expand Down
8 changes: 1 addition & 7 deletions xla/service/cpu/benchmarks/elementwise_benchmark_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,10 @@ static void BM_AddF32(benchmark::State& state) {
std::string_view hlo = R"(
HloModule add_f32_$d0
add {
p0 = f32[1,2,1,$d0,256] parameter(0)
p1 = f32[1,2,1,$d0,256] parameter(1)
ROOT add = f32[1,2,1,$d0,256] add(p0, p1)
}
ENTRY e {
p0 = f32[1,2,1,$d0,256] parameter(0)
p1 = f32[1,2,1,$d0,256] parameter(1)
ROOT fusion = f32[1,2,1,$d0,256] fusion(p0, p1), kind=kLoop, calls=add
ROOT add = f32[1,2,1,$d0,256] add(p0, p1)
}
)";

Expand Down
197 changes: 156 additions & 41 deletions xla/service/cpu/ir_emitter2.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,11 @@ limitations under the License.

#include "xla/service/cpu/ir_emitter2.h"

#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>
#include <string_view>
#include <utility>
#include <vector>

Expand All @@ -30,6 +33,7 @@ limitations under the License.
#include "llvm/IR/CallingConv.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
Expand All @@ -41,8 +45,11 @@ limitations under the License.
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/hlo/ir/hlo_schedule.h"
#include "xla/service/cpu/backend_config.pb.h"
#include "xla/service/cpu/elemental_math_emitter.h"
#include "xla/service/cpu/ir_emitter.h"
#include "xla/service/cpu/parallel_loop_emitter.h"
#include "xla/service/cpu/shape_partition.h"
#include "xla/service/elemental_ir_emitter.h"
#include "xla/service/llvm_ir/fused_ir_emitter.h"
#include "xla/service/llvm_ir/ir_array.h"
Expand Down Expand Up @@ -222,35 +229,6 @@ bool IrEmitter2::fast_min_max() const {
return hlo_module_.config().debug_options().xla_cpu_enable_fast_min_max();
}

static absl::Status EmitElementalLoops(
llvm::IRBuilder<>& b, const HloInstruction* instr,
const llvm_ir::ElementGenerator& element_generator,
absl::Span<const llvm_ir::IrArray> results) {
// We can emit loops for instruction with multiple results only if it is a
// fusion, reduce or reduce window.
bool multiple_results = results.size() > 1;
bool support_multiple_results = instr->opcode() == HloOpcode::kFusion ||
instr->opcode() == HloOpcode::kReduce ||
instr->opcode() == HloOpcode::kReduceWindow;

if (multiple_results && !support_multiple_results) {
return Internal(
"Multi-output host kernels are not supported for %s instruction",
HloOpcodeString(instr->opcode()));
}

if (multiple_results) {
TF_RETURN_IF_ERROR(llvm_ir::LoopEmitter(element_generator, results, &b)
.EmitLoop(llvm_ir::IrName(instr)));
} else {
TF_RETURN_IF_ERROR(
llvm_ir::LoopEmitter(element_generator, results.front(), &b)
.EmitLoop(llvm_ir::IrName(instr)));
}

return absl::OkStatus();
}

absl::StatusOr<IrEmitter2::KernelInfo> IrEmitter2::EmitElementalHostKernel(
const HloInstruction* instr) {
VLOG(2) << "Emit elemental host kernel: " << instr->name();
Expand All @@ -273,11 +251,12 @@ absl::StatusOr<IrEmitter2::KernelInfo> IrEmitter2::EmitElementalHostKernel(
llvm_ir::ElementGenerator element_generator =
elemental_emitter.MakeElementGenerator(instr, operand_to_generator);

TF_RETURN_IF_ERROR(EmitElementalLoops(b, instr, element_generator,
kernel_prototype.results));
return kernels_.emplace_back(
KernelInfo{kernel_prototype.function->getName().str(), se::BlockDim(),
se::ThreadDim()});
TF_ASSIGN_OR_RETURN(
se::ThreadDim thread_dims,
EmitElementalLoops(b, instr, kernel_prototype, element_generator));

return kernels_.emplace_back(KernelInfo{
kernel_prototype.function->getName().str(), se::BlockDim(), thread_dims});
}

absl::StatusOr<IrEmitter2::KernelInfo> IrEmitter2::EmitFusionHostKernel(
Expand Down Expand Up @@ -309,11 +288,12 @@ absl::StatusOr<IrEmitter2::KernelInfo> IrEmitter2::EmitFusionHostKernel(
auto element_generator,
fused_emitter.GetGenerator(*fusion->fused_expression_root()));

TF_RETURN_IF_ERROR(EmitElementalLoops(b, fusion, element_generator,
kernel_prototype.results));
return kernels_.emplace_back(
KernelInfo{kernel_prototype.function->getName().str(), se::BlockDim(),
se::ThreadDim()});
TF_ASSIGN_OR_RETURN(
se::ThreadDim thread_dims,
EmitElementalLoops(b, fusion, kernel_prototype, element_generator));

return kernels_.emplace_back(KernelInfo{
kernel_prototype.function->getName().str(), se::BlockDim(), thread_dims});
}

absl::StatusOr<IrEmitter2::KernelInfo> IrEmitter2::EmitReductionHostKernel(
Expand All @@ -330,7 +310,8 @@ absl::StatusOr<IrEmitter2::KernelInfo> IrEmitter2::EmitReductionHostKernel(

IrEmitter2::KernelThreadDims IrEmitter2::EmitKernelThreadDims(
llvm::IRBuilder<>& b, llvm::Value* call_frame) {
auto* tdims = b.CreateStructGEP(call_frame_ty_, call_frame, 0, "tdims_gep");
auto* td_gep = b.CreateStructGEP(call_frame_ty_, call_frame, 0, "tdims_gep");
auto* tdims = b.CreateLoad(b.getPtrTy(), td_gep, "tdims");
auto* x_gep = b.CreateStructGEP(thread_dims_ty_, tdims, 0, "tdim_x_gep");
auto* y_gep = b.CreateStructGEP(thread_dims_ty_, tdims, 1, "tdim_y_gep");
auto* z_gep = b.CreateStructGEP(thread_dims_ty_, tdims, 2, "tdim_z_gep");
Expand All @@ -342,7 +323,8 @@ IrEmitter2::KernelThreadDims IrEmitter2::EmitKernelThreadDims(

IrEmitter2::KernelThread IrEmitter2::EmitKernelThread(llvm::IRBuilder<>& b,
llvm::Value* call_frame) {
auto* tids = b.CreateStructGEP(call_frame_ty_, call_frame, 1, "tid_gep");
auto* t_gep = b.CreateStructGEP(call_frame_ty_, call_frame, 1, "tid_gep");
auto* tids = b.CreateLoad(b.getPtrTy(), t_gep, "tids");
auto* x_gep = b.CreateStructGEP(thread_ty_, tids, 0, "tid_x_gep");
auto* y_gep = b.CreateStructGEP(thread_ty_, tids, 1, "tid_y_gep");
auto* z_gep = b.CreateStructGEP(thread_ty_, tids, 2, "tid_z_gep");
Expand Down Expand Up @@ -430,4 +412,137 @@ IrEmitter2::KernelPrototype IrEmitter2::EmitKernelPrototype(
FlattenedResults(instr));
}

std::optional<IrEmitter2::ParallelConfig> IrEmitter2::GetParallelConfig(
const HloInstruction* instr) {
// Check if the instruction is marked for parallel execution.
auto backend_config = instr->backend_config<BackendConfig>();
if (!backend_config.ok() ||
backend_config->outer_dimension_partitions().empty()) {
return std::nullopt;
}

ParallelConfig config;
config.outer_dimension_partitions.assign(
backend_config->outer_dimension_partitions().begin(),
backend_config->outer_dimension_partitions().end());

return config;
}

IrEmitter2::ParallelPartitionBounds IrEmitter2::EmitParallelPartitionBounds(
llvm::IRBuilder<>& b, const KernelPrototype& kernel_prototype,
const ParallelConfig& parallel_config, const Shape& shape,
std::string_view name) {
ShapePartitionIterator it(shape, parallel_config.outer_dimension_partitions);

size_t num_parallel_dimensions =
parallel_config.outer_dimension_partitions.size();

// Create a constant array of all partition bounds. We will be indexing into
// this array using block and thread dimension indices passed in a call frame.
//
// Type: [#partitions x [#outer_dimensions x [lower_bound, upper_bound]]]
//
llvm::ArrayType* dim_bounds_ty = llvm::ArrayType::get(b.getInt64Ty(), 2);
llvm::ArrayType* partition_bounds_ty =
llvm::ArrayType::get(dim_bounds_ty, num_parallel_dimensions);
llvm::ArrayType* parallel_bounds_ty =
llvm::ArrayType::get(partition_bounds_ty, it.GetTotalPartitionCount());

// Build a nested array of partition bounds from shape partition iterator.
std::vector<llvm::Constant*> partition_bounds;
for (int64_t i = 0; i < it.GetTotalPartitionCount(); ++i) {
std::vector<llvm::Constant*> dim_counts;
for (auto [lower, size] : it.GetPartition(i)) {
dim_counts.push_back(llvm::ConstantArray::get(
dim_bounds_ty, {b.getInt64(lower), b.getInt64(lower + size)}));
}
partition_bounds.push_back(
llvm::ConstantArray::get(partition_bounds_ty, dim_counts));
}

llvm::Constant* parallel_bounds =
llvm::ConstantArray::get(parallel_bounds_ty, partition_bounds);

llvm::Module* module = b.GetInsertBlock()->getParent()->getParent();
llvm::GlobalVariable* parallel_bounds_global = new llvm::GlobalVariable(
/*M=*/*module,
/*Ty=*/parallel_bounds_ty,
/*isConstant=*/true,
/*Linkage=*/llvm::GlobalValue::PrivateLinkage,
/*Initializer=*/parallel_bounds,
/*Name=*/absl::StrCat(name, "_parallel_bounds"));

// Construct IR to load bounds for all parallel dimensions.
ParallelPartitionBounds bounds;
for (size_t i = 0; i < num_parallel_dimensions; ++i) {
llvm::Value* partition = kernel_prototype.thread.x;
llvm::Value* parallel_dim = b.getInt32(i);

llvm::Value* lower_gep = b.CreateInBoundsGEP(
parallel_bounds_ty, parallel_bounds_global,
{b.getInt32(0), partition, parallel_dim, b.getInt32(0)},
absl::StrCat("lo_dim_", i, "_gep"));

llvm::Value* upper_gep = b.CreateInBoundsGEP(
parallel_bounds_ty, parallel_bounds_global,
{b.getInt32(0), partition, parallel_dim, b.getInt32(1)},
absl::StrCat("up_dim_", i, "_gep"));

bounds.emplace_back(
b.CreateLoad(b.getInt64Ty(), lower_gep, absl::StrCat("lo_dim_", i)),
b.CreateLoad(b.getInt64Ty(), upper_gep, absl::StrCat("up_dim_", i)));
}

return bounds;
}

absl::StatusOr<se::ThreadDim> IrEmitter2::EmitElementalLoops(
llvm::IRBuilder<>& b, const HloInstruction* instr,
const KernelPrototype& kernel_prototype,
const llvm_ir::ElementGenerator& element_generator) {
// We can emit loops for instruction with multiple results only if it is a
// fusion, reduce or reduce window.
bool multiple_results = kernel_prototype.results.size() > 1;
bool support_multiple_results = instr->opcode() == HloOpcode::kFusion ||
instr->opcode() == HloOpcode::kReduce ||
instr->opcode() == HloOpcode::kReduceWindow;

auto parallel_config = GetParallelConfig(instr);
bool has_parallel_config = parallel_config.has_value();

if (multiple_results && !support_multiple_results) {
return Internal(
"Multi-output host kernels are not supported for %s instruction",
HloOpcodeString(instr->opcode()));
}

// TODO(ezhulenev): Support multiple results for parallel loops.
if (multiple_results) {
TF_RETURN_IF_ERROR(
llvm_ir::LoopEmitter(element_generator, kernel_prototype.results, &b)
.EmitLoop(llvm_ir::IrName(instr)));
return se::ThreadDim();
}

const llvm_ir::IrArray& result = kernel_prototype.results.front();

// Emit a loop for a single parallel partition with dynamic bounds computed
// from thread index.
if (has_parallel_config) {
ParallelPartitionBounds parallel_bounds = EmitParallelPartitionBounds(
b, kernel_prototype, *parallel_config, instr->shape(), instr->name());
TF_RETURN_IF_ERROR(
ParallelLoopEmitter(element_generator, result, &parallel_bounds, &b)
.EmitLoop(llvm_ir::IrName(instr)));
return se::ThreadDim(ShapePartitionAssigner::GetTotalPartitionCount(
parallel_config->outer_dimension_partitions));
}

// Emit a whole loop for the instruction.
TF_RETURN_IF_ERROR(llvm_ir::LoopEmitter(element_generator, result, &b)
.EmitLoop(llvm_ir::IrName(instr)));
return se::ThreadDim();
}

} // namespace xla::cpu
34 changes: 34 additions & 0 deletions xla/service/cpu/ir_emitter2.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ limitations under the License.
#ifndef XLA_SERVICE_CPU_IR_EMITTER2_H_
#define XLA_SERVICE_CPU_IR_EMITTER2_H_

#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>
#include <string_view>
#include <utility>
Expand All @@ -31,6 +33,7 @@ limitations under the License.
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/service/cpu/ir_emitter.h"
#include "xla/service/llvm_ir/ir_array.h"
#include "xla/service/llvm_ir/loop_emitter.h"
#include "xla/shape.h"
#include "xla/stream_executor/launch_dim.h"

Expand Down Expand Up @@ -123,6 +126,18 @@ class IrEmitter2 {
private:
class ElementalIrEmitter;

// Parallel partition bounds for parallelized outer dimensions:
// vector<[i64 lower_bound, i64 upper_bound]>
using ParallelPartitionBounds =
std::vector<std::pair<llvm::Value*, llvm::Value*>>;

// A config for running kernel in parallel. We rely on partitioning iteration
// space along the outer dimension(s) and run each partition as a separate
// task inside a runtime-managed thread pool.
struct ParallelConfig {
std::vector<int64_t> outer_dimension_partitions;
};

KernelThreadDims EmitKernelThreadDims(llvm::IRBuilder<>& b,
llvm::Value* call_frame);

Expand All @@ -132,6 +147,25 @@ class IrEmitter2 {
llvm::Value* call_frame, int64_t index,
const Shape& shape);

// Returns parallel config for the given instruction or std::nullopt if
// the instruction has to be compiled to a single threaded loop.
std::optional<ParallelConfig> GetParallelConfig(const HloInstruction* instr);

// Emits LLVM IR that computes parallel partition bounds from the call frame's
// block and thread dimensions and parallel execution config.
ParallelPartitionBounds EmitParallelPartitionBounds(
llvm::IRBuilder<>& b, const KernelPrototype& kernel_prototype,
const ParallelConfig& parallel_config, const Shape& shape,
std::string_view name);

// Emits LLVM IR using elemental loop emitter and the given element generator.
// If the instruction is parallelized, it will emit a parallel loop partition
// and return the requested number of execution threads.
absl::StatusOr<se::ThreadDim> EmitElementalLoops(
llvm::IRBuilder<>& b, const HloInstruction* instr,
const KernelPrototype& kernel_prototype,
const llvm_ir::ElementGenerator& element_generator);

bool fast_min_max() const;

const HloModule& hlo_module_;
Expand Down
Loading

0 comments on commit 32ba408

Please sign in to comment.