diff --git a/p-isa_tools/kerngen/kernel_optimization/loops.py b/p-isa_tools/kerngen/kernel_optimization/loops.py
index 7c36a835..16d598f1 100644
--- a/p-isa_tools/kerngen/kernel_optimization/loops.py
+++ b/p-isa_tools/kerngen/kernel_optimization/loops.py
@@ -10,6 +10,54 @@
from high_parser.pisa_operations import PIsaOp, Comment
+def remove_comments(pisa_list: list[PIsaOp]) -> list[PIsaOp]:
+ """Remove comments from a list of PIsaOp instructions.
+
+ Args:
+ pisa_list: List of PIsaOp instructions
+
+ Returns:
+ List of PIsaOp instructions without comments
+ """
+ return [pisa for pisa in pisa_list if not isinstance(pisa, Comment)]
+
+
+def split_by_reorderable(pisa_list: list[PIsaOp]) -> tuple[list[PIsaOp], list[PIsaOp]]:
+ """Split a list of PIsaOp instructions into reorderable and non-reorderable groups.
+
+ Args:
+ pisa_list: List of PIsaOp instructions
+
+ Returns:
+ Tuple containing two lists:
+ - reorderable: Instructions that can be reordered
+ - non_reorderable: Instructions that cannot be reordered
+ """
+
+ reorderable = []
+ non_reorderable = []
+ is_reorderable = False
+
+ for pisa in pisa_list:
+ # if the pisa is a comment and it contains tag, treat the following pisa as reorderable until a tag is found.
+ if isinstance(pisa, Comment):
+ if "" in pisa.line:
+ is_reorderable = True
+ elif "" in pisa.line:
+ is_reorderable = False
+
+ if is_reorderable:
+ reorderable.append(pisa)
+ else:
+ non_reorderable.append(pisa)
+
+ # if reoderable is empty, return non_reorderable as reorderable
+ if not reorderable:
+ reorderable = non_reorderable
+ non_reorderable = []
+ return remove_comments(reorderable), remove_comments(non_reorderable)
+
+
def loop_interchange(
pisa_list: list[PIsaOp],
primary_key: LoopKey | None = LoopKey.PART,
@@ -52,7 +100,7 @@ def get_sort_key(pisa: PIsaOp) -> tuple:
return (primary_value,)
# Filter out comments
- pisa_list_wo_comments = [p for p in pisa_list if not isinstance(p, Comment)]
+ pisa_list_wo_comments = remove_comments(pisa_list)
# Sort based on primary and optional secondary keys
pisa_list_wo_comments.sort(key=get_sort_key)
return pisa_list_wo_comments
diff --git a/p-isa_tools/kerngen/kerngraph.py b/p-isa_tools/kerngen/kerngraph.py
index e7343281..e942c8d9 100755
--- a/p-isa_tools/kerngen/kerngraph.py
+++ b/p-isa_tools/kerngen/kerngraph.py
@@ -33,29 +33,26 @@
import argparse
import sys
from kernel_parser.parser import KernelParser
-from kernel_optimization.loops import loop_interchange
+from kernel_optimization.loops import loop_interchange, split_by_reorderable
from const.options import LoopKey
from pisa_generators.basic import mixed_to_pisa_ops
+from high_parser.config import Config
def parse_args():
"""Parse arguments from the commandline"""
parser = argparse.ArgumentParser(description="Kernel Graph Parser")
parser.add_argument("-d", "--debug", action="store_true", help="Enable Debug Print")
+ parser.add_argument(
+ "-l", "--legacy", action="store_true", help="Enable Legacy Mode"
+ )
parser.add_argument(
"-t",
"--target",
nargs="*",
default=[],
# Composition high ops such are ntt, mod, and relin are not currently supported
- choices=[
- "add",
- "sub",
- "mul",
- "muli",
- "copy",
- "ntt",
- ], # currently supports single ops
+ choices=["add", "sub", "mul", "muli", "copy", "ntt", "intt", "mod"],
help="List of high_op names",
)
parser.add_argument(
@@ -85,6 +82,7 @@ def main(args):
"""Main function to read input and parse each line with KernelParser."""
input_lines = sys.stdin.read().strip().splitlines()
valid_kernels = []
+ Config.legacy_mode = args.legacy
for line in input_lines:
try:
@@ -106,11 +104,16 @@ def main(args):
if args.target and any(
target.lower() in str(kernel).lower() for target in args.target
):
- kernel = loop_interchange(
- kernel.to_pisa(),
- primary_key=args.primary,
- secondary_key=args.secondary,
+ reorderable, non_reorderable = split_by_reorderable(kernel.to_pisa())
+ kernel = non_reorderable
+ kernel.append(
+ loop_interchange(
+ reorderable,
+ primary_key=args.primary,
+ secondary_key=args.secondary,
+ )
)
+
for pisa in mixed_to_pisa_ops(kernel):
print(pisa)
else:
diff --git a/p-isa_tools/kerngen/pisa_generators/mod.py b/p-isa_tools/kerngen/pisa_generators/mod.py
index 4c01e069..9be5e2cd 100644
--- a/p-isa_tools/kerngen/pisa_generators/mod.py
+++ b/p-isa_tools/kerngen/pisa_generators/mod.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2024 Intel Corporation
+# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""Module containing conversions or operations from isa to p-isa."""
@@ -89,7 +89,7 @@ def generate_mod_stages() -> list[Stage]:
stages.append(
Stage(
[
- Comment("Mod Stage 1"),
+ Comment("Mod Stage 1 "),
muli_last_half(
self.context,
temp_input_remaining_rns,
@@ -194,7 +194,7 @@ def generate_mod_stages() -> list[Stage]:
+ stages[2].pisa_ops
+ [
Muli(self.context, self.output, temp_input_remaining_rns, iq),
- Comment("End of mod kernel"),
+ Comment("End of mod kernel "),
]
)