Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 49 additions & 1 deletion p-isa_tools/kerngen/kernel_optimization/loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,54 @@
from high_parser.pisa_operations import PIsaOp, Comment


def remove_comments(pisa_list: list[PIsaOp]) -> list[PIsaOp]:
"""Remove comments from a list of PIsaOp instructions.

Args:
pisa_list: List of PIsaOp instructions

Returns:
List of PIsaOp instructions without comments
"""
return [pisa for pisa in pisa_list if not isinstance(pisa, Comment)]


def split_by_reorderable(pisa_list: list[PIsaOp]) -> tuple[list[PIsaOp], list[PIsaOp]]:
"""Split a list of PIsaOp instructions into reorderable and non-reorderable groups.

Args:
pisa_list: List of PIsaOp instructions

Returns:
Tuple containing two lists:
- reorderable: Instructions that can be reordered
- non_reorderable: Instructions that cannot be reordered
"""

reorderable = []
non_reorderable = []
is_reorderable = False

for pisa in pisa_list:
# if the pisa is a comment and it contains <reorderable> tag, treat the following pisa as reorderable until a </reorderable> tag is found.
if isinstance(pisa, Comment):
if "<reorderable>" in pisa.line:
is_reorderable = True
elif "</reorderable>" in pisa.line:
is_reorderable = False

if is_reorderable:
reorderable.append(pisa)
else:
non_reorderable.append(pisa)

# if reoderable is empty, return non_reorderable as reorderable
if not reorderable:
reorderable = non_reorderable
non_reorderable = []
return remove_comments(reorderable), remove_comments(non_reorderable)


def loop_interchange(
pisa_list: list[PIsaOp],
primary_key: LoopKey | None = LoopKey.PART,
Expand Down Expand Up @@ -52,7 +100,7 @@ def get_sort_key(pisa: PIsaOp) -> tuple:
return (primary_value,)

# Filter out comments
pisa_list_wo_comments = [p for p in pisa_list if not isinstance(p, Comment)]
pisa_list_wo_comments = remove_comments(pisa_list)
# Sort based on primary and optional secondary keys
pisa_list_wo_comments.sort(key=get_sort_key)
return pisa_list_wo_comments
29 changes: 16 additions & 13 deletions p-isa_tools/kerngen/kerngraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,29 +33,26 @@
import argparse
import sys
from kernel_parser.parser import KernelParser
from kernel_optimization.loops import loop_interchange
from kernel_optimization.loops import loop_interchange, split_by_reorderable
from const.options import LoopKey
from pisa_generators.basic import mixed_to_pisa_ops
from high_parser.config import Config


def parse_args():
"""Parse arguments from the commandline"""
parser = argparse.ArgumentParser(description="Kernel Graph Parser")
parser.add_argument("-d", "--debug", action="store_true", help="Enable Debug Print")
parser.add_argument(
"-l", "--legacy", action="store_true", help="Enable Legacy Mode"
)
parser.add_argument(
"-t",
"--target",
nargs="*",
default=[],
# Composition high ops such are ntt, mod, and relin are not currently supported
choices=[
"add",
"sub",
"mul",
"muli",
"copy",
"ntt",
], # currently supports single ops
choices=["add", "sub", "mul", "muli", "copy", "ntt", "intt", "mod"],
help="List of high_op names",
)
parser.add_argument(
Expand Down Expand Up @@ -85,6 +82,7 @@ def main(args):
"""Main function to read input and parse each line with KernelParser."""
input_lines = sys.stdin.read().strip().splitlines()
valid_kernels = []
Config.legacy_mode = args.legacy

for line in input_lines:
try:
Expand All @@ -106,11 +104,16 @@ def main(args):
if args.target and any(
target.lower() in str(kernel).lower() for target in args.target
):
kernel = loop_interchange(
kernel.to_pisa(),
primary_key=args.primary,
secondary_key=args.secondary,
reorderable, non_reorderable = split_by_reorderable(kernel.to_pisa())
kernel = non_reorderable
kernel.append(
loop_interchange(
reorderable,
primary_key=args.primary,
secondary_key=args.secondary,
)
)

for pisa in mixed_to_pisa_ops(kernel):
print(pisa)
else:
Expand Down
6 changes: 3 additions & 3 deletions p-isa_tools/kerngen/pisa_generators/mod.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (C) 2024 Intel Corporation
# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""Module containing conversions or operations from isa to p-isa."""
Expand Down Expand Up @@ -89,7 +89,7 @@ def generate_mod_stages() -> list[Stage]:
stages.append(
Stage(
[
Comment("Mod Stage 1"),
Comment("Mod Stage 1 <reorderable>"),
muli_last_half(
self.context,
temp_input_remaining_rns,
Expand Down Expand Up @@ -194,7 +194,7 @@ def generate_mod_stages() -> list[Stage]:
+ stages[2].pisa_ops
+ [
Muli(self.context, self.output, temp_input_remaining_rns, iq),
Comment("End of mod kernel"),
Comment("End of mod kernel </reorderable>"),
]
)

Expand Down