Increase default logging for MLIR embedding rewrite passes
PiperOrigin-RevId: 623625663
patnotz authored and pull[bot] committed Apr 30, 2024
1 parent 4dbba26 commit 2347423
Showing 2 changed files with 31 additions and 30 deletions.
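
The substantive change in both files is a logging promotion: VLOG statements, which are silent at default settings, become LOG(INFO) statements, which appear in standard logs. As background, here is a minimal sketch (not part of this commit) using TensorFlow's logging macros; the TF_CPP_* variables are the usual TensorFlow logging controls:

    #include "tensorflow/core/platform/logging.h"

    void LoggingVisibilityExample() {
      // Emitted at default verbosity (unless TF_CPP_MIN_LOG_LEVEL is raised
      // above 0, which suppresses INFO).
      LOG(INFO) << "visible by default";
      // Compiled in but only emitted when verbose logging is enabled, e.g.
      // TF_CPP_MAX_VLOG_LEVEL=1 or TF_CPP_VMODULE=embedding_pipelining=1.
      VLOG(1) << "hidden unless verbosity is raised";
    }

Hence the commit title: the messages are largely unchanged; their default visibility is what increases.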
File 1 of 2:
@@ -193,25 +193,26 @@ struct EmbeddingPipeliningPass
 
 bool UseEmbeddingPipelining(ModuleOp& module) {
   // Enable automated pipelining pass unless:
-  // 1. The user disables it via flog, or
+  // 1. The user disables it via flag, or
   // 2. The graph contains TF.Summary ops. Graphs like this typically only run
   //    for a single step which doesn't work in pipelining.
 
   if (tensorflow::GetBuildXlaOpsPassFlags()
-          ->tf_xla_disable_full_embedding_pipelining)
+          ->tf_xla_disable_full_embedding_pipelining) {
+    LOG(INFO) << "Embedding pipelining disabled via flag.";
     return false;
-
+  }
   // Detect summaries by looking for key Ops in the graph. It would be better to
   // do this via operator attributes rather than looking for a specific op.
   WalkResult walk_result = module.walk([&](Operation* op) -> WalkResult {
     if (llvm::isa<TF::WriteSummaryOp>(op)) return WalkResult::interrupt();
     return WalkResult::advance();
   });
   if (walk_result.wasInterrupted()) {
-    VLOG(1) << "TF summaries detected - disabling embedding pipelining.";
+    LOG(INFO) << "TF summaries detected - disabling embedding pipelining.";
     return false;
   }
-  VLOG(1) << "Embedding pipelining rewrite enabled.";
+  LOG(INFO) << "Embedding pipelining rewrite enabled.";
   return true;
 }
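
Two operational notes on the hunk above. First, tf_xla_disable_full_embedding_pipelining is one of the BuildXlaOpsPassFlags, which are typically supplied through the TF_XLA_FLAGS environment variable (e.g. TF_XLA_FLAGS=--tf_xla_disable_full_embedding_pipelining=true); the new LOG(INFO) line confirms at default verbosity that the opt-out took effect. Second, the summary detection uses MLIR's walk with WalkResult::interrupt() as an early-exit "does the module contain an op of type T" query. A generic sketch of that idiom (ContainsOpOfType is a hypothetical helper, not part of the pass):

    #include "llvm/Support/Casting.h"
    #include "mlir/IR/BuiltinOps.h"  // mlir::ModuleOp
    #include "mlir/IR/Visitors.h"    // mlir::WalkResult

    // Returns true iff `module` contains an operation of type OpT; stops
    // walking at the first match, mirroring UseEmbeddingPipelining above.
    template <typename OpT>
    bool ContainsOpOfType(mlir::ModuleOp module) {
      mlir::WalkResult result = module.walk([](mlir::Operation* op) {
        if (llvm::isa<OpT>(op)) return mlir::WalkResult::interrupt();
        return mlir::WalkResult::advance();
      });
      return result.wasInterrupted();
    }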

@@ -1685,12 +1686,11 @@ Operation* LiftNonTpuFuncCaller(mlir::OpBuilder& builder,
 }
 
 void EmbeddingPipeliningPass::runOnOperation() {
-  VLOG(3) << "EmbeddingPipeliningPass::runOnOperation()";
+  LOG(INFO) << "EmbeddingPipeliningPass::runOnOperation()";
   ModuleOp module = getOperation();
 
   // We only use one of the EmbeddingPipelining and EmbeddingSequencing passes.
   if (!UseEmbeddingPipelining(module)) return;
-  VLOG(1) << "Embedding pipelining rewrite enabled.";
 
   SymbolTable symbol_table(module);
 
@@ -1722,7 +1722,7 @@ void EmbeddingPipeliningPass::runOnOperation() {
   // If there are no forward pass ops, there is no SC, so we end early.
   if (forward_pass_ops.empty()) {
     if (backward_pass_ops.empty()) {
-      VLOG(1) << "no pipelining ops found";
+      LOG(INFO) << "no pipelining ops found";
       return;
     } else {
       (*backward_pass_ops.begin())->emitOpError()
@@ -1812,11 +1812,11 @@ void EmbeddingPipeliningPass::runOnOperation() {
   if (failed(result)) return signalPassFailure();
   merged_set.insert(non_tpu_ops.begin(), non_tpu_ops.end());
 
-  VLOG(3) << "Forwards pass " << forward_pass_ops.size()
-          << " ops, backwards pass " << backward_pass_ops.size()
-          << " ops, core " << core_tpu_ops.size()
-          << " ops. Total = " << merged_set.size() << " of "
-          << GetNumOps(loop_body_func);
+  LOG(INFO) << "Forwards pass " << forward_pass_ops.size()
+            << " ops, backwards pass " << backward_pass_ops.size()
+            << " ops, core " << core_tpu_ops.size()
+            << " ops. Total = " << merged_set.size() << " of "
+            << GetNumOps(loop_body_func);
 
   builder.setInsertionPointAfter(*non_tpu_ops.begin());
   TF::StatefulPartitionedCallOp non_tpu_caller = nullptr;
@@ -2185,7 +2185,8 @@ void EmbeddingPipeliningPass::runOnOperation() {
   int parallel_iterations = parallel_iterations_flag > 0
                                 ? parallel_iterations_flag
                                 : orig_while_op.getParallelIterations();
-  VLOG(1) << "Setting parallel_iterations_flag to " << parallel_iterations_flag;
+  LOG(INFO) << "Setting parallel_iterations_flag to "
+            << parallel_iterations_flag;
   auto new_while_op = builder.create<TF::WhileOp>(
       orig_while_op->getLoc(), new_body_return_types,
       new_while_operands.getArrayRef(), cond.getSymName(), body.getSymName(),
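
For context on this hunk: parallel_iterations on a tf.While bounds how many loop iterations may execute concurrently, and the pipelining rewrite only overlaps work across iterations when more than one is in flight, which is presumably why the pass exposes the tf_xla_embedding_parallel_iterations override and now reports the chosen value at default verbosity.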
@@ -2252,7 +2253,7 @@ void EmbeddingPipeliningPass::runOnOperation() {
   orig_while_op.body_function().erase();
   orig_while_op.erase();
 
-  VLOG(3) << "EmbeddingPipeliningPass::runOnOperation done.";
+  LOG(INFO) << "EmbeddingPipeliningPass::runOnOperation done.";
 }
 }  // namespace
 
File 2 of 2:
@@ -413,8 +413,7 @@ LogicalResult FindForwardPassOps(OpBuilder& builder,
   if (is_non_variable && is_variable) {
     loop_body_func.emitOpError()
         << "resource input " << argument.getArgNumber()
-        << " is used both as a varible and not "
-        << " a variable";
+        << " is used both as a varible and not a variable";
     return LogicalResult::failure();
   }
   if (is_variable && use_in_forward)
@@ -772,7 +771,7 @@ LogicalResult ExtractOpsAsFunc(
 }
 
 void EmbeddingSequencingPass::runOnOperation() {
-  VLOG(3) << "EmbeddingSequencingPass::runOnOperation()";
+  LOG(INFO) << "EmbeddingSequencingPass::runOnOperation()";
   ModuleOp module = getOperation();
 
   llvm::SetVector<Operation*> forward_pass_ops;
@@ -803,14 +802,16 @@ void EmbeddingSequencingPass::runOnOperation() {
   // If there are no forward pass ops, there is no SC, so we end early.
   if (forward_pass_ops.empty()) {
     if (backward_pass_ops.empty()) {
+      LOG(INFO) << "No unprocessed embedding ops found - skipping embedding "
+                << "sequencing rewrite.";
       return;
     } else {
       (*backward_pass_ops.begin())->emitOpError()
           << "embedding backwards pass op with no forwards pass ops.";
       return signalPassFailure();
     }
   }
-  VLOG(1) << "Embedding sequencing rewrite enabled.";
+  LOG(INFO) << "Embedding sequencing rewrite enabled.";
 
   // Ensure that all ops are in the same region, and have the same replication
   // info.
@@ -860,18 +861,17 @@ void EmbeddingSequencingPass::runOnOperation() {
   TF::WhileOp while_op = nullptr;
   result = FindOwningWhileOp(loop_body_func, module, &while_op);
   if (failed(result)) {
-    VLOG(1) << "WhileOp not found: assuming external loop.";
+    LOG(INFO) << "WhileOp not found: assuming external loop.";
   } else {
     // Override the WhileOp parallel_iterations if requested by flag.
     int parallel_iterations_flag = tensorflow::GetBuildXlaOpsPassFlags()
                                        ->tf_xla_embedding_parallel_iterations;
     if (parallel_iterations_flag > 0) {
-      VLOG(1) << "Setting WhileOp parallel_iterations_flag to "
-              << parallel_iterations_flag;
+      LOG(INFO) << "Setting WhileOp parallel_iterations_flag to "
+                << parallel_iterations_flag;
       while_op.setParallelIterations(parallel_iterations_flag);
     } else {
-      VLOG(1) << "Using original WhileOp parallel_iterations = "
-              << while_op.getParallelIterations();
+      LOG(INFO) << "Using original WhileOp parallel_iteration";
     }
   }
 
@@ -898,11 +898,11 @@ void EmbeddingSequencingPass::runOnOperation() {
   if (failed(result)) return signalPassFailure();
   merged_set.insert(non_tpu_ops.begin(), non_tpu_ops.end());
 
-  VLOG(2) << "Forwards pass " << forward_pass_ops.size()
-          << " ops, backwards pass " << backward_pass_ops.size()
-          << " ops, core " << core_tpu_ops.size()
-          << " ops. Total = " << merged_set.size() << " of "
-          << GetNumOps(loop_body_func) << ".\n";
+  LOG(INFO) << "Forwards pass " << forward_pass_ops.size()
+            << " ops, backwards pass " << backward_pass_ops.size()
+            << " ops, core " << core_tpu_ops.size()
+            << " ops. Total = " << merged_set.size() << " of "
+            << GetNumOps(loop_body_func) << ".\n";
 
   builder.setInsertionPointAfter(*non_tpu_ops.begin());
   Operation* non_tpu_caller = nullptr;
@@ -936,7 +936,7 @@ void EmbeddingSequencingPass::runOnOperation() {
   metadata_op->erase();
   compilation_op->erase();
 
-  VLOG(3) << "EmbeddingSequencingPass::runOnOperation done.";
+  LOG(INFO) << "EmbeddingSequencingPass::runOnOperation done.";
 }
 
 }  // namespace
