Skip to content

Commit

Permalink
[SME] Add support for inserting processor state annotations (#16761)
Browse files Browse the repository at this point in the history
Execution of SME instructions requires the processor be in a certain
state. This functionality can be controlled using LLVM function
level annotations such as "aarch64_pstate_sm_enabled" or
"aarch64_pstate_za_new" (see arm_utils.py for more information).

This commit exposes this functionality for AArch64 schedules where SME
intrinsics will be called. The attributes are intended to be added
at the block-level around the compute definition. They are prepended
with "pragma" to ensure they remain in the lowering.

In order to detect these attributes and convert them to the relevant
LLVM function attributes, a new AArch64 LLVM codegen backend is added.
This backend extends the functionality of `codegen_llvm` for AArch64
specific compilation.

Tests to check these attributes propagate correctly have been added.
  • Loading branch information
lhutton1 committed Mar 26, 2024
1 parent 4f3a863 commit ac2f478
Show file tree
Hide file tree
Showing 3 changed files with 300 additions and 2 deletions.
84 changes: 84 additions & 0 deletions python/tvm/topi/arm_cpu/pstate_attributes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""
Specialized attributes that can be added to schedules to alter
the behaviour of AArch64 codegen.
"""


class SMEAttributes:
    """
    This class serves as a convenience wrapper for processor state annotations
    relating to the Scalable Matrix Extension (SME). Processor state annotations
    are inserted at compile time and alter some global state of the processor
    during execution. For example, the streaming mode attribute can be used to
    transfer some vector operations to a separate processing element. These
    attributes can be added to block-level annotations in AArch64 schedules to
    define a desired state.

    Please refer to the following pages for more information regarding the SME
    attributes and their behaviours:
     - https://arm-software.github.io/acle/main/acle.html#markdown-toc-sme-attributes
     - https://llvm.org/docs/AArch64SME.html

    Attributes
    ----------
    STREAMING_MODE : str
        Whether execution should occur in regular mode or streaming mode. When
        enabled, some vector operations may be transferred to a separate processing
        element.
    ZA_STORAGE : str
        Defines how the ZA area of storage provided by the SME extension should be
        utilized.
    """

    # Annotation key selecting the streaming mode; its value is one of StreamingModeValues.
    STREAMING_MODE = "pragma_aarch64_pstate_sm"

    class StreamingModeValues:
        """
        Streaming mode attribute values. By default, a function is considered
        'non-streaming' (often referred to as 'regular').

        Attributes
        ----------
        ENABLED : str
            The processor state must be in streaming mode before executing the marked function.
        COMPATIBLE : str
            The marked function can be run in either streaming or non-streaming mode.
        """

        ENABLED = "enabled"
        COMPATIBLE = "compatible"

    # Annotation key selecting the ZA storage behaviour; its value is one of ZAStorageValues.
    ZA_STORAGE = "pragma_aarch64_pstate_za"

    class ZAStorageValues:
        """
        ZA Storage attribute values. By default, a function has no ZA state. In other words, it
        does not use the ZA storage.

        Attributes
        ----------
        NEW : str
            A new ZA state is created "from scratch".
        SHARED : str
            The ZA state is shared with the calling function.
        """

        NEW = "new"
        SHARED = "shared"
102 changes: 102 additions & 0 deletions src/target/llvm/codegen_aarch64.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file src/target/llvm/codegen_aarch64.cc
* \brief AArch64 specific LLVM code generator.
*/
#ifdef TVM_LLVM_VERSION

#include <llvm/IR/Intrinsics.h>
#include <llvm/Target/TargetMachine.h>
#include <tvm/runtime/registry.h>

#include "codegen_cpu.h"
#include "llvm_instance.h"

namespace tvm {
namespace codegen {

class CodeGenAArch64 final : public CodeGenCPU {
public:
CodeGenAArch64() = default;
virtual ~CodeGenAArch64() = default;

void VisitStmt_(const AttrStmtNode* op);
void AddFunction(const GlobalVar& gvar, const PrimFunc& f);

bool func_has_pstate_sm = false;
bool func_has_pstate_za = false;
};

void CodeGenAArch64::AddFunction(const GlobalVar& gvar, const PrimFunc& f) {
func_has_pstate_sm = false;
func_has_pstate_za = false;
CodeGenCPU::AddFunction(gvar, f);
}

/*!
 * \brief Visit and handle AArch64 specific pragmas. To be AArch64 specific,
 *  the expectation is that they are prepended with "pragma_aarch64".
 *
 * Attribute statements that are not pragmas, or pragmas that are not AArch64
 * specific, are forwarded unchanged to CodeGenCPU. A recognised pragma adds the
 * LLVM function attribute "<key without pragma_>_<value>" to the function being
 * generated; an unrecognised AArch64 pragma only produces a warning. In both of
 * these cases the statement body is visited afterwards.
 *
 * \param op The attribute statement being visited. AArch64 pragmas must carry a
 *  string value, and each recognised pragma may appear once per function;
 *  violations fail an ICHECK.
 */
void CodeGenAArch64::VisitStmt_(const AttrStmtNode* op) {
  // Length of the "pragma_" prefix and of the "aarch64" tag that follows it.
  // NOTE(review): assumes tir::attr::IsPragmaKey guarantees the "pragma_" prefix - confirm.
  constexpr char kPragmaPrefix[] = "pragma_";
  constexpr char kAArch64Tag[] = "aarch64";
  constexpr size_t kPragmaPrefixLen = sizeof(kPragmaPrefix) - 1;
  constexpr size_t kAArch64TagLen = sizeof(kAArch64Tag) - 1;

  const std::string attr_key = op->attr_key;

  // Short-circuit keeps the offset comparison from running on non-pragma keys.
  if (!tir::attr::IsPragmaKey(attr_key) ||
      attr_key.compare(kPragmaPrefixLen, kAArch64TagLen, kAArch64Tag) != 0) {
    CodeGenCPU::VisitStmt_(op);
    return;
  }

  const auto* attr_value = op->value.as<StringImmNode>();
  ICHECK(attr_value) << "Expect " << attr_key << " to have a String value but was "
                     << op->value->GetTypeKey();

  // Strip "pragma_" so the remainder matches the LLVM attribute naming scheme.
  const std::string aarch64_attr_key = attr_key.substr(kPragmaPrefixLen);

  // Select the flag recording whether this pragma was already seen in the function.
  bool* already_defined = nullptr;
  if (aarch64_attr_key == "aarch64_pstate_sm") {
    already_defined = &func_has_pstate_sm;
  } else if (aarch64_attr_key == "aarch64_pstate_za") {
    already_defined = &func_has_pstate_za;
  }

  if (already_defined != nullptr) {
    // Use str() rather than data(): a StringRef buffer is not guaranteed to be
    // null-terminated, so it must not be streamed as a C string.
    ICHECK(!*already_defined) << "Multiple definitions of " << op->attr_key
                              << " attribute found in the function "
                              << function_->getName().str();
    function_->addFnAttr({aarch64_attr_key + "_" + attr_value->value});
    *already_defined = true;
  } else {
    LOG(WARNING) << "Unknown pragma " << op->attr_key;
  }
  this->VisitStmt(op->body);
}

// Register a factory under the global name "tvm.codegen.llvm.target_aarch64" that
// returns a newly allocated CodeGenAArch64 as an opaque void*.
// NOTE(review): presumably the caller takes ownership of the returned pointer - confirm.
TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_aarch64")
    .set_body([](const TVMArgs& targs, TVMRetValue* rv) {
      *rv = static_cast<void*>(new CodeGenAArch64());
    });

} // namespace codegen
} // namespace tvm

#endif // TVM_LLVM_VERSION
116 changes: 114 additions & 2 deletions tests/python/codegen/test_target_codegen_aarch64.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,15 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import re

import pytest

import tvm
from tvm import te
from tvm.script import tir as T
import re
import pytest
from tvm.topi.arm_cpu.pstate_attributes import SMEAttributes

from tvm.target.codegen import llvm_version_major

Expand Down Expand Up @@ -533,5 +537,113 @@ def my_func(a: T.handle):
assert re.findall(r" store <vscale x 4 x float>", llvm), "No scalable store in generated LLVM."


@pytest.mark.skipif(
    llvm_version_major() < 16, reason="Test requires an LLVM version of at least 16 to target SME"
)
@pytest.mark.parametrize(
    "attr_key,attr_value,expected",
    [
        (
            SMEAttributes.STREAMING_MODE,
            SMEAttributes.StreamingModeValues.ENABLED,
            "aarch64_pstate_sm_enabled",
        ),
        (
            SMEAttributes.STREAMING_MODE,
            SMEAttributes.StreamingModeValues.COMPATIBLE,
            "aarch64_pstate_sm_compatible",
        ),
        (SMEAttributes.ZA_STORAGE, SMEAttributes.ZAStorageValues.NEW, "aarch64_pstate_za_new"),
        (
            SMEAttributes.ZA_STORAGE,
            SMEAttributes.ZAStorageValues.SHARED,
            "aarch64_pstate_za_shared",
        ),
    ],
)
def test_function_attributes(attr_key, attr_value, expected):
    """
    Check that a block-level SME annotation ends up as the expected LLVM function
    attribute, and that the attribute group holding it is referenced by the
    generated "compute" function.
    """
    target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sme"

    @T.prim_func
    def prim_func(a: T.handle, c: T.handle):
        T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
        A = T.match_buffer(a, (16,), "float32")
        C = T.match_buffer(c, (1,), "float32")

        with T.block("extern"):
            T.block_attr({attr_key: attr_value})
            for i in range(16):
                C[0] += A[i]

    func = tvm.build(prim_func, target=target)
    ll = func.get_source("ll")

    # Check that the attribute exists. The expected text is escaped and matched in
    # full; the quantifier belongs to the surrounding wildcard, not to its last character.
    attr = re.findall(rf".*{re.escape(expected)}.*", ll)
    assert attr, f"Function attribute {expected} was not found in generated LLVM IR"

    # Check this attribute is used on the "compute" function. The second
    # space-separated token of the first matching line is taken as the attribute
    # group label (e.g. "#0"). The lookahead stops "#1" from also matching "#10".
    func_attr_label = attr[0].split(" ")[1]
    label_pattern = rf".*{re.escape(func_attr_label)}(?!\d).*"
    found_compute_func = any("_compute_" in match for match in re.findall(label_pattern, ll))

    assert found_compute_func, (
        f"The attribute {expected} was found to be under the label {func_attr_label}, "
        "but it was not used by the 'compute' scope function."
    )


def test_unsupported_function_attribute_type():
    """
    Check that building fails with a TVMError when the streaming mode annotation
    is given a boolean instead of a string value.
    """
    target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sme"

    @T.prim_func
    def prim_func(a: T.handle, c: T.handle):
        T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
        A = T.match_buffer(a, (16,), "float32")
        C = T.match_buffer(c, (1,), "float32")

        with T.block("extern"):
            # Deliberately pass a non-string value for the annotation.
            T.block_attr({SMEAttributes.STREAMING_MODE: True})
            with T.block("root"):
                for i in range(16):
                    C[0] += A[i]

    # The error is expected to report the value's type as IntImm.
    err_msg = f"Expect {SMEAttributes.STREAMING_MODE} to have a String value but was IntImm"
    with pytest.raises(tvm.error.TVMError, match=err_msg):
        tvm.build(prim_func, target=target)


@pytest.mark.parametrize(
    "attr_key,attr_value",
    [
        (SMEAttributes.STREAMING_MODE, SMEAttributes.StreamingModeValues.ENABLED),
        (SMEAttributes.ZA_STORAGE, SMEAttributes.ZAStorageValues.NEW),
    ],
)
def test_unsupported_multiple_function_attributes(attr_key, attr_value):
    """
    Check that building fails with a TVMError when the same SME annotation is
    attached to two blocks nested under the same root block.
    """
    target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sme"

    @T.prim_func
    def prim_func(a: T.handle, c: T.handle):
        A = T.match_buffer(a, (16,), "float32")
        C = T.match_buffer(c, (1,), "float32")

        with T.block("root"):
            # Two sibling blocks carry the same annotation key and value.
            with T.block("extern"):
                T.block_attr({attr_key: attr_value})
                for i in range(16):
                    C[0] += A[i] * 2
            with T.block("extern2"):
                T.block_attr({attr_key: attr_value})
                for i in range(16):
                    C[0] += A[i] * 3

    # The error is expected to name the duplicated key and the generated compute function.
    err_msg = f"Multiple definitions of {attr_key} attribute found in the function default_function_compute_"
    with pytest.raises(tvm.error.TVMError, match=err_msg):
        tvm.build(prim_func, target=target)


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit ac2f478

Please sign in to comment.