Skip to content

Commit

Permalink
[mlir] Expose Value hierarchy to C API
Browse files Browse the repository at this point in the history
The Value hierarchy consists of BlockArgument and OpResult, both of which
derive Value. Introduce IsA functions and functions specific to each class,
similarly to other class hierarchies. Also, introduce functions for
pointer-comparison of Block and Operation that are necessary for testing and
are generally useful.

Reviewed By: stellaraccident, mehdi_amini

Differential Revision: https://reviews.llvm.org/D89714
  • Loading branch information
ftynse committed Oct 20, 2020
1 parent 595c615 commit 39613c2
Show file tree
Hide file tree
Showing 3 changed files with 121 additions and 6 deletions.
32 changes: 32 additions & 0 deletions mlir/include/mlir-c/IR.h
Expand Up @@ -241,6 +241,10 @@ void mlirOperationDestroy(MlirOperation op);
/** Checks whether the underlying operation is null. */
static inline int mlirOperationIsNull(MlirOperation op) { return !op.ptr; }

/** Checks whether two operation handles point to the same operation. This does
* not perform deep comparison. */
int mlirOperationEqual(MlirOperation op, MlirOperation other);

/** Returns the number of regions attached to the given operation. */
intptr_t mlirOperationGetNumRegions(MlirOperation op);

Expand Down Expand Up @@ -348,6 +352,10 @@ void mlirBlockDestroy(MlirBlock block);
/** Checks whether a block is null. */
static inline int mlirBlockIsNull(MlirBlock block) { return !block.ptr; }

/** Checks whether two blocks handles point to the same block. This does not
* perform deep comparison. */
int mlirBlockEqual(MlirBlock block, MlirBlock other);

/** Returns the block immediately following the given block in its parent
* region. */
MlirBlock mlirBlockGetNextInRegion(MlirBlock block);
Expand Down Expand Up @@ -397,6 +405,30 @@ void mlirBlockPrint(MlirBlock block, MlirStringCallback callback,
/** Returns whether the value is null. */
static inline int mlirValueIsNull(MlirValue value) { return !value.ptr; }

/** Returns 1 if the value is a block argument, 0 otherwise. */
int mlirValueIsABlockArgument(MlirValue value);

/** Returns 1 if the value is an operation result, 0 otherwise. */
int mlirValueIsAOpResult(MlirValue value);

/** Returns the block in which this value is defined as an argument. Asserts if
* the value is not a block argument. */
MlirBlock mlirBlockArgumentGetOwner(MlirValue value);

/** Returns the position of the value in the argument list of its block. */
intptr_t mlirBlockArgumentGetArgNumber(MlirValue value);

/** Sets the type of the block argument to the given type. */
void mlirBlockArgumentSetType(MlirValue value, MlirType type);

/** Returns an operation that produced this value as its result. Asserts if the
* value is not an op result. */
MlirOperation mlirOpResultGetOwner(MlirValue value);

/** Returns the position of the value in the list of results of the operation
* that produced it. */
intptr_t mlirOpResultGetResultNumber(MlirValue value);

/** Returns the type of the value. */
MlirType mlirValueGetType(MlirValue value);

Expand Down
38 changes: 38 additions & 0 deletions mlir/lib/CAPI/IR/IR.cpp
Expand Up @@ -211,6 +211,10 @@ MlirOperation mlirOperationCreate(const MlirOperationState *state) {

void mlirOperationDestroy(MlirOperation op) { unwrap(op)->erase(); }

int mlirOperationEqual(MlirOperation op, MlirOperation other) {
return unwrap(op) == unwrap(other);
}

intptr_t mlirOperationGetNumRegions(MlirOperation op) {
return static_cast<intptr_t>(unwrap(op)->getNumRegions());
}
Expand Down Expand Up @@ -343,6 +347,10 @@ MlirBlock mlirBlockCreate(intptr_t nArgs, MlirType *args) {
return wrap(b);
}

int mlirBlockEqual(MlirBlock block, MlirBlock other) {
return unwrap(block) == unwrap(other);
}

MlirBlock mlirBlockGetNextInRegion(MlirBlock block) {
return wrap(unwrap(block)->getNextNode());
}
Expand Down Expand Up @@ -412,6 +420,36 @@ void mlirBlockPrint(MlirBlock block, MlirStringCallback callback,
/* Value API. */
/* ========================================================================== */

int mlirValueIsABlockArgument(MlirValue value) {
return unwrap(value).isa<BlockArgument>();
}

int mlirValueIsAOpResult(MlirValue value) {
return unwrap(value).isa<OpResult>();
}

MlirBlock mlirBlockArgumentGetOwner(MlirValue value) {
return wrap(unwrap(value).cast<BlockArgument>().getOwner());
}

intptr_t mlirBlockArgumentGetArgNumber(MlirValue value) {
return static_cast<intptr_t>(
unwrap(value).cast<BlockArgument>().getArgNumber());
}

void mlirBlockArgumentSetType(MlirValue value, MlirType type) {
unwrap(value).cast<BlockArgument>().setType(unwrap(type));
}

MlirOperation mlirOpResultGetOwner(MlirValue value) {
return wrap(unwrap(value).cast<OpResult>().getOwner());
}

intptr_t mlirOpResultGetResultNumber(MlirValue value) {
return static_cast<intptr_t>(
unwrap(value).cast<OpResult>().getResultNumber());
}

MlirType mlirValueGetType(MlirValue value) {
return wrap(unwrap(value).getType());
}
Expand Down
57 changes: 51 additions & 6 deletions mlir/test/CAPI/ir.c
Expand Up @@ -153,10 +153,12 @@ struct ModuleStats {
unsigned numBlocks;
unsigned numRegions;
unsigned numValues;
unsigned numBlockArguments;
unsigned numOpResults;
};
typedef struct ModuleStats ModuleStats;

void collectStatsSingle(OpListNode *head, ModuleStats *stats) {
int collectStatsSingle(OpListNode *head, ModuleStats *stats) {
MlirOperation operation = head->op;
stats->numOperations += 1;
stats->numValues += mlirOperationGetNumResults(operation);
Expand All @@ -166,12 +168,39 @@ void collectStatsSingle(OpListNode *head, ModuleStats *stats) {

stats->numRegions += numRegions;

intptr_t numResults = mlirOperationGetNumResults(operation);
for (intptr_t i = 0; i < numResults; ++i) {
MlirValue result = mlirOperationGetResult(operation, i);
if (!mlirValueIsAOpResult(result))
return 1;
if (mlirValueIsABlockArgument(result))
return 2;
if (!mlirOperationEqual(operation, mlirOpResultGetOwner(result)))
return 3;
if (i != mlirOpResultGetResultNumber(result))
return 4;
++stats->numOpResults;
}

for (unsigned i = 0; i < numRegions; ++i) {
MlirRegion region = mlirOperationGetRegion(operation, i);
for (MlirBlock block = mlirRegionGetFirstBlock(region);
!mlirBlockIsNull(block); block = mlirBlockGetNextInRegion(block)) {
++stats->numBlocks;
stats->numValues += mlirBlockGetNumArguments(block);
intptr_t numArgs = mlirBlockGetNumArguments(block);
stats->numValues += numArgs;
for (intptr_t j = 0; j < numArgs; ++j) {
MlirValue arg = mlirBlockGetArgument(block, j);
if (!mlirValueIsABlockArgument(arg))
return 5;
if (mlirValueIsAOpResult(arg))
return 6;
if (!mlirBlockEqual(block, mlirBlockArgumentGetOwner(arg)))
return 7;
if (j != mlirBlockArgumentGetArgNumber(arg))
return 8;
++stats->numBlockArguments;
}

for (MlirOperation child = mlirBlockGetFirstOperation(block);
!mlirOperationIsNull(child);
Expand All @@ -183,9 +212,10 @@ void collectStatsSingle(OpListNode *head, ModuleStats *stats) {
}
}
}
return 0;
}

void collectStats(MlirOperation operation) {
int collectStats(MlirOperation operation) {
OpListNode *head = malloc(sizeof(OpListNode));
head->op = operation;
head->next = NULL;
Expand All @@ -196,9 +226,13 @@ void collectStats(MlirOperation operation) {
stats.numBlocks = 0;
stats.numRegions = 0;
stats.numValues = 0;
stats.numBlockArguments = 0;
stats.numOpResults = 0;

do {
collectStatsSingle(head, &stats);
int retval = collectStatsSingle(head, &stats);
if (retval)
return retval;
OpListNode *next = head->next;
free(head);
head = next;
Expand All @@ -209,6 +243,11 @@ void collectStats(MlirOperation operation) {
fprintf(stderr, "Number of blocks: %u\n", stats.numBlocks);
fprintf(stderr, "Number of regions: %u\n", stats.numRegions);
fprintf(stderr, "Number of values: %u\n", stats.numValues);
fprintf(stderr, "Number of block arguments: %u\n", stats.numBlockArguments);
fprintf(stderr, "Number of op results: %u\n", stats.numOpResults);
if (stats.numValues != stats.numBlockArguments + stats.numOpResults)
return 100;
return 0;
}

static void printToStderr(const char *str, intptr_t len, void *userData) {
Expand Down Expand Up @@ -914,13 +953,19 @@ int main() {
// CHECK: }
// clang-format on

collectStats(module);
fprintf(stderr, "@stats\n");
int errcode = collectStats(module);
fprintf(stderr, "%d\n", errcode);
// clang-format off
// CHECK-LABEL: @stats
// CHECK: Number of operations: 13
// CHECK: Number of attributes: 4
// CHECK: Number of blocks: 3
// CHECK: Number of regions: 3
// CHECK: Number of values: 9
// CHECK: Number of block arguments: 3
// CHECK: Number of op results: 6
// CHECK: 0
// clang-format on

printFirstOfEach(ctx, module);
Expand Down Expand Up @@ -988,7 +1033,7 @@ int main() {
// CHECK: 0
// clang-format on
fprintf(stderr, "@types\n");
int errcode = printStandardTypes(ctx);
errcode = printStandardTypes(ctx);
fprintf(stderr, "%d\n", errcode);

// clang-format off
Expand Down

0 comments on commit 39613c2

Please sign in to comment.