-
Notifications
You must be signed in to change notification settings - Fork 2
/
kgen.py
executable file
·136 lines (115 loc) · 4.38 KB
/
kgen.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
#!/usr/bin/env python3
#
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
# Author: John Linford <jlinford@nvidia.com>
#
import sys
import argparse
# C source skeleton for the generated kernel, filled in by generate() using
# %-style named substitution.  In GCC extended asm the colon-separated
# sections after the template string are, in order: output operands, input
# operands, clobbers -- the placeholder comments below follow that order.
KERNEL_FILE_TEMPLATE = """
//
// SPDX-License-Identifier: BSD-3-Clause
// Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
// Author: John Linford <jlinford@nvidia.com>
//
// This file was autogenerated by kgen.py
//
%(headers)s
const char * description = "%(descr)s";
unsigned long block_inst = %(block_inst)s;
unsigned long block_ops = %(block_ops)s;
unsigned long unroll = %(unroll)s;
void kernel(unsigned long iters)
{
for (unsigned long i=0; i<iters; ++i) {
asm volatile (
%(body)s
: /* no output */
: /* no input */
: %(clobber)s
);
}
}
"""
def modrange(start, stop, mod):
    """Return the integers ``start .. stop-1`` each reduced modulo ``mod``.

    Not called elsewhere in this file; presumably meant for use inside the
    eval-ed operand range expressions passed with ``-b`` (e.g. to cycle
    register numbers) -- TODO confirm.
    """
    wrapped = []
    for value in range(start, stop):
        wrapped.append(value % mod)
    return wrapped
def usage_error(*args, **kwargs):
    """Report a usage problem on stderr and terminate the process.

    All arguments are forwarded to ``print`` with ``file`` forced to
    ``sys.stderr``; afterwards the process exits with status 100.
    """
    print(*args, **kwargs, file=sys.stderr)
    raise SystemExit(100)
def generate_block(lines, clobber, block_ops, blk):
    """Expand one parsed ``-b`` block template into inline-asm source lines.

    ``blk.operand`` holds (format, value) pairs.  Each value string is
    ``eval``-ed: ``None`` means the format is emitted verbatim for every
    instruction; otherwise the result must be a sequence of exactly
    ``blk.count`` elements, and element ``i`` is %-substituted into the format
    for instruction ``i``.

    Results are accumulated in the caller's containers:

    * ``lines``     -- list; one quoted asm string per instruction is appended.
    * ``clobber``   -- set; register-looking operands (first letter x, w, v, z,
                       d, s or h) are added with any ``.suffix`` stripped,
                       e.g. ``z3.d`` -> ``z3``.
    * ``block_ops`` -- list; one C expression ``count * (laneops * lanes)``
                       for this block is appended.

    Invalid operand specifications are reported via ``usage_error``, which
    exits the process.
    """
    def ignore_register(operand):
        # Only operands that start like x/w/v/z/d/s/h registers go into the
        # clobber list; everything else (immediates, predicate or memory
        # operands, ...) is skipped.  Slicing tolerates an empty operand.
        return operand[:1] not in ("x", "w", "v", "z", "d", "s", "h")

    indent = 8 * " "
    count = blk.count
    opcode = blk.opcode

    # Operands must come in (format, value) pairs; an unpaired trailing format
    # would otherwise be silently dropped by zip() below.
    if len(blk.operand) % 2:
        usage_error("Operand '%s' has no range value (operands are format/value pairs; "
                    "use None for a fixed operand)" % blk.operand[-1])
    formats = blk.operand[0::2]

    values = []
    for val in blk.operand[1::2]:
        # NOTE(review): eval() runs arbitrary Python taken from the command
        # line.  Acceptable for a local developer tool; never feed it
        # untrusted input.
        try:
            evaluated = eval(val)
        except SyntaxError as err:
            usage_error("Syntax error in operand range value: %s" % err.text)
        if evaluated is not None:
            try:
                length = len(evaluated)
            except TypeError:
                usage_error("Operand range value '%s' is not a sequence" % val)
            if length != count:
                usage_error("Invalid length %d of operand range value '%s' (expected %d)"
                            % (length, val, count))
        values.append(evaluated)

    for i in range(count):
        operands = [fmt % val[i] if val is not None else fmt
                    for fmt, val in zip(formats, values)]
        clobber.update(op.split(".")[0] for op in operands if not ignore_register(op))
        lines.append('%s"%s %s \\n\\t"' % (indent, opcode, ", ".join(operands)))

    # Lanes processed by one instruction, as a C expression.
    if blk.isa == "SCALAR":
        lanes = "1"
    elif blk.isa == "NEON":
        lanes = "(128/%s)" % blk.typebits         # NEON vectors are 128 bits
    elif blk.isa == "SVE":
        lanes = "(8*svcntb()/%s)" % blk.typebits  # svcntb(): vector length in bytes
    else:
        usage_error("Unsupported ISA '%s'" % blk.isa)
    block_ops.append("(%s*(%s*%s))" % (blk.count, blk.laneops, lanes))
def describe(unroll, blocks):
    """Build the human-readable kernel description string.

    Format: ``"<unroll>( <count>(<ISA>_<OPCODE>_<bits>b) ... )"``, with one
    space-separated entry per block and the opcode upper-cased.
    """
    entries = ["%d(%s_%s_%db)" % (b.count, b.isa, b.opcode.upper(), b.typebits)
               for b in blocks]
    return " ".join(["%d(" % unroll, *entries, ")"])
def generate(unroll, blocks):
    """Print a complete C kernel source file to stdout.

    ``blocks`` are the parsed ``-b`` templates (see ``generate_block``).  The
    instructions of all blocks, in order, form one pass; the pass is repeated
    ``unroll`` times inside a single ``asm volatile`` statement.

    Values substituted into ``KERNEL_FILE_TEMPLATE``:
      * block_inst -- number of asm lines in one pass (i.e. before unrolling);
      * block_ops  -- "+"-joined per-block operation-count C expressions;
      * clobber    -- every register collected by ``generate_block``, quoted
                      and sorted;
      * headers    -- ``#include`` lines, sorted so output is deterministic.

    Calls ``usage_error`` (which exits) if the asm body would be empty.
    """
    if unroll < 1:
        # With no instructions the asm statement would lack its template
        # string, which is not valid C.
        usage_error("Invalid unroll factor %d (must be at least 1)" % unroll)

    lines = []
    clobber = set()
    headers = set()
    block_ops = []
    for blk in blocks:
        generate_block(lines, clobber, block_ops, blk)
        if blk.isa == "SVE":
            # Declares svcntb(), used in the SVE lane-count expression.
            headers.add("arm_sve.h")
    if not lines:
        usage_error("No instructions generated; check the block counts")

    block_inst = str(len(lines))  # per pass, before unrolling
    print(KERNEL_FILE_TEMPLATE % {
        "headers": "\n".join("#include <%s>" % h for h in sorted(headers)),
        "descr": describe(unroll, blocks),
        "block_inst": block_inst,
        "block_ops": "+".join(block_ops),
        "unroll": unroll,
        "body": "\n".join(lines * unroll),
        "clobber": ", ".join(sorted(['"%s"' % reg for reg in clobber])),
    })
def _positive_int(text):
    """argparse ``type=`` converter that accepts only integers >= 1."""
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError("invalid int value: %r" % text) from None
    if value < 1:
        raise argparse.ArgumentTypeError("%r must be a positive integer" % text)
    return value


def parse_args(args):
    """Parse the kgen command line.

    Returns the top-level namespace (``unroll``, ``blocks``) with ``blocks``
    replaced by one namespace per ``-b`` option, each parsed by the block
    parser into ``isa``, ``typebits``, ``laneops``, ``count``, ``opcode`` and
    ``operand``.  On invalid input argparse prints a message and exits with
    status 2.

    ``unroll``, ``typebits`` and ``count`` must be >= 1: zero would yield an
    empty asm body or a division by zero in the generated lane expression.
    """
    # Each -b option carries its own mini command line, parsed separately.
    block_parser = argparse.ArgumentParser(prog="", add_help=False)
    block_parser.add_argument("isa", choices=["SCALAR", "SVE", "NEON"], help="Instruction ISA")
    block_parser.add_argument("typebits", type=_positive_int, help="Size of the operation datatype in bits")
    block_parser.add_argument("laneops", type=int, help="Operations performed per lane")
    block_parser.add_argument("count", type=_positive_int, help="Instructions in the block")
    block_parser.add_argument("opcode", help="Instruction opcode")
    block_parser.add_argument("operand", nargs="+", help="Instruction operands")
    block_help = block_parser.format_usage().replace("usage: ", "")

    parser = argparse.ArgumentParser()
    parser.add_argument("-u", "--unroll", type=_positive_int, help="Number of times to unroll the loop", default=4)
    parser.add_argument("-b", required=True, nargs="+", metavar="block_template", dest="blocks", action="append", help=block_help)
    parsed = parser.parse_args(args)
    parsed.blocks = [block_parser.parse_args(blk) for blk in parsed.blocks]
    return parsed
def main(*args, **kwargs):
    """Command-line entry point: parse arguments and print the generated kernel.

    ``args[0]``, when given, is the argument list without the program name.
    With no positional arguments argparse falls back to ``sys.argv[1:]``.
    Keyword arguments are accepted for compatibility and ignored.
    """
    # Passing None explicitly lets a bare main() use argparse's sys.argv default
    # instead of failing on parse_args()'s required parameter.
    parsed = parse_args(*args) if args else parse_args(None)
    generate(parsed.unroll, parsed.blocks)
if __name__ == "__main__":
    # Everything after the script name is handed to the argument parser.
    cli_args = sys.argv[1:]
    main(cli_args)