diff --git a/assembler_tools/hec-assembler-tools/README.md b/assembler_tools/hec-assembler-tools/README.md index 3cc31a17..6857aca4 100644 --- a/assembler_tools/hec-assembler-tools/README.md +++ b/assembler_tools/hec-assembler-tools/README.md @@ -78,7 +78,7 @@ python3 he_as.py filename.tw.csv --input_mem_file filename.mem ```bash # link assembled output (input prefix: filename.tw) # outputs filename.minst, filename.cinst, filename.xinst -python3 he_link.py filename.tw --input_mem_file filename.mem --output_prefix filename +python3 he_link.py --input_prefixes filename.tw --input_mem_file filename.mem --output_prefix filename ``` This will generate the main three output files in the same directory as the input file: @@ -89,7 +89,7 @@ This will generate the main three output files in the same directory as the inpu Intermediate files, if any, are kept as well. -The linker program is able to link several assembled kernels into a single HERACLES program, given a correct memory mapping for the resulting program. +The linker program is able to link several assembled kernels into a single HERACLES program, given a correct memory mapping or trace file for the resulting program. This version of executing is intended for the assembler to be usable as part of a compilation pipeline. diff --git a/assembler_tools/hec-assembler-tools/he_link.py b/assembler_tools/hec-assembler-tools/he_link.py index b674494e..d2cc57b1 100644 --- a/assembler_tools/hec-assembler-tools/he_link.py +++ b/assembler_tools/hec-assembler-tools/he_link.py @@ -11,7 +11,7 @@ @par Classes: - LinkerRunConfig: Maintains the configuration data for the run. - - KernelFiles: Structure for kernel files. + - KernelInfo: Structure for kernel files. @par Functions: - main(run_config: LinkerRunConfig, verbose_stream=None): Executes the linking process using the provided configuration. @@ -22,321 +22,28 @@ to specify input and output files and configuration options for the linking process. """ import argparse -import io import os -import pathlib import sys import warnings -from typing import NamedTuple, Any, Optional -import linker -from assembler.common import constants -from assembler.common import makeUniquePath from assembler.common.counter import Counter -from assembler.common.run_config import RunConfig from assembler.common.config import GlobalConfig -from assembler.memory_model import mem_info from assembler.spec_config.mem_spec import MemSpecConfig from assembler.spec_config.isa_spec import ISASpecConfig -from linker import loader -from linker.steps import variable_discovery -from linker.steps import program_linker from linker.instructions import BaseInstruction - - -class NullIO: - """ - @class NullIO - @brief A class that provides a no-operation implementation of write and flush methods. - """ - - def write(self, *argts, **kwargs): - """ - @brief A no-operation write method. - """ - - def flush(self): - """ - @brief A no-operation flush method. - """ - - -class LinkerRunConfig(RunConfig): - """ - @class LinkerRunConfig - @brief Maintains the configuration data for the run. - - @fn as_dict - @brief Returns the configuration as a dictionary. - - @return dict The configuration as a dictionary. - """ - - # Type annotations for class attributes - input_prefixes: list[str] - input_mem_file: str - multi_mem_files: bool - output_dir: str - output_prefix: str - - __initialized = False # specifies whether static members have been initialized - # contains the dictionary of all configuration items supported and their - # default value (or None if no default) - __default_config: dict[str, Any] = {} - - def __init__(self, **kwargs): - """ - @brief Constructs a new LinkerRunConfig Object from input parameters. - - See base class constructor for more parameters. - - @param input_prefixes List of input prefixes, including full path. For an input prefix, linker will - assume there are three files named `input_prefixes[i] + '.minst'`, - `input_prefixes[i] + '.cinst'`, and `input_prefixes[i] + '.xinst'`. - This list must not be empty. - @param output_prefix Prefix for the output file names. - Three files will be generated: - `output_dir/output_prefix.minst`, `output_dir/output_prefix.cinst`, and - `output_dir/output_prefix.xinst`. - Output filenames cannot match input file names. - @param input_mem_file Input memory file associated with the result kernel. - @param output_dir OPTIONAL directory where to store all intermediate files and final output. - This will be created if it doesn't exists. - Defaults to current working directory. - - @exception TypeError A mandatory configuration value was missing. - @exception ValueError At least, one of the arguments passed is invalid. - """ - super().__init__(**kwargs) - - self.init_default_config() - - # class members based on configuration - for config_name, default_value in self.__default_config.items(): - value = kwargs.get(config_name, default_value) - if value is not None: - setattr(self, config_name, value) - else: - if not hasattr(self, config_name): - setattr(self, config_name, default_value) - if getattr(self, config_name) is None: - raise TypeError( - f"Expected value for configuration `{config_name}`, but `None` received." - ) - - # fix file names - self.output_dir = makeUniquePath(self.output_dir) - # E0203: Access to member 'input_mem_file' before its definition. - # But it was defined in previous loop. - if self.input_mem_file != "": # pylint: disable=E0203 - self.input_mem_file = makeUniquePath(self.input_mem_file) - - @classmethod - def init_default_config(cls): - """ - @brief Initializes static members of the class. - """ - if not cls.__initialized: - cls.__default_config["input_prefixes"] = None - cls.__default_config["input_mem_file"] = "" - cls.__default_config["multi_mem_files"] = False - cls.__default_config["output_dir"] = os.getcwd() - cls.__default_config["output_prefix"] = None - - cls.__initialized = True - - def __str__(self): - """ - @brief Provides a string representation of the configuration. - - @return str The string for the configuration. - """ - self_dict = self.as_dict() - with io.StringIO() as retval_f: - for key, value in self_dict.items(): - print(f"{key}: {value}", file=retval_f) - retval = retval_f.getvalue() - return retval - - def as_dict(self) -> dict: - """ - @brief Provides the configuration as a dictionary. - - @return dict The configuration. - """ - retval = super().as_dict() - tmp_self_dict = vars(self) - retval.update( - { - config_name: tmp_self_dict[config_name] - for config_name in self.__default_config - } - ) - return retval - - -class KernelFiles(NamedTuple): - """ - @class KernelFiles - @brief Structure for kernel files. - - @var prefix - Index = 0 - @var minst - Index = 1. Name for file containing MInstructions for represented kernel. - @var cinst - Index = 2. Name for file containing CInstructions for represented kernel. - @var xinst - Index = 3. Name for file containing XInstructions for represented kernel. - @var mem - Index = 4. Name for file containing memory information for represented kernel. - This is used only when multi_mem_files is set. - """ - - prefix: str - minst: str - cinst: str - xinst: str - mem: Optional[str] = None - - -def link_kernels(input_files, output_files, mem_model, verbose_stream): - """ - @brief Links input kernels and writes the output to the specified files. - - @param input_files List of KernelFiles for input kernels. - @param output_files KernelFiles for output. - @param mem_model Memory model to use. - @param run_config LinkerRunConfig object. - @param verbose_stream Stream for verbose output. - """ - with open(output_files.minst, "w", encoding="utf-8") as fnum_output_minst, open( - output_files.cinst, "w", encoding="utf-8" - ) as fnum_output_cinst, open( - output_files.xinst, "w", encoding="utf-8" - ) as fnum_output_xinst: - - result_program = program_linker.LinkedProgram( - fnum_output_minst, fnum_output_cinst, fnum_output_xinst, mem_model - ) - for idx, kernel in enumerate(input_files): - if verbose_stream: - print( - f"[ {idx * 100 // len(input_files): >3}% ]", - kernel.prefix, - file=verbose_stream, - ) - kernel_minstrs = loader.load_minst_kernel_from_file(kernel.minst) - kernel_cinstrs = loader.load_cinst_kernel_from_file(kernel.cinst) - kernel_xinstrs = loader.load_xinst_kernel_from_file(kernel.xinst) - result_program.link_kernel(kernel_minstrs, kernel_cinstrs, kernel_xinstrs) - if verbose_stream: - print( - "[ 100% ] Finalizing output", output_files.prefix, file=verbose_stream - ) - result_program.close() - - -def prepare_output_files(run_config) -> KernelFiles: - """ - @brief Prepares output file names and directories. - - @param run_config LinkerRunConfig object. - @return KernelFiles Output file paths. - """ - output_prefix = os.path.join(run_config.output_dir, run_config.output_prefix) - output_dir = os.path.dirname(output_prefix) - pathlib.Path(output_dir).mkdir(exist_ok=True, parents=True) - out_mem_file = ( - makeUniquePath(output_prefix + ".mem") if run_config.multi_mem_files else None - ) - return KernelFiles( - prefix=makeUniquePath(output_prefix), - minst=makeUniquePath(output_prefix + ".minst"), - cinst=makeUniquePath(output_prefix + ".cinst"), - xinst=makeUniquePath(output_prefix + ".xinst"), - mem=out_mem_file, - ) - - -def prepare_input_files(run_config, output_files) -> list: - """ - @brief Prepares input file names and checks for existence and conflicts. - - @param run_config LinkerRunConfig object. - @param output_files KernelFiles for output. - @return list List of KernelFiles for input. - @exception FileNotFoundError If an input file does not exist. - @exception RuntimeError If an input file matches an output file. - """ - input_files = [] - for file_prefix in run_config.input_prefixes: - mem_file = ( - makeUniquePath(file_prefix + ".mem") if run_config.multi_mem_files else None - ) - kernel_files = KernelFiles( - prefix=makeUniquePath(file_prefix), - minst=makeUniquePath(file_prefix + ".minst"), - cinst=makeUniquePath(file_prefix + ".cinst"), - xinst=makeUniquePath(file_prefix + ".xinst"), - mem=mem_file, - ) - input_files.append(kernel_files) - for input_filename in kernel_files[1:]: - if input_filename: - if not os.path.isfile(input_filename): - raise FileNotFoundError(input_filename) - if input_filename in output_files: - raise RuntimeError( - f'Input files cannot match output files: "{input_filename}"' - ) - return input_files - - -def scan_variables(input_files, mem_model, verbose_stream): - """ - @brief Scans input files for variables and adds them to the memory model. - - @param input_files List of KernelFiles for input. - @param mem_model Memory model to update. - @param verbose_stream Stream for verbose output. - """ - for idx, kernel in enumerate(input_files): - if not GlobalConfig.hasHBM: - if verbose_stream: - print( - f" {idx + 1}/{len(input_files)}", - kernel.cinst, - file=verbose_stream, - ) - kernel_cinstrs = loader.load_cinst_kernel_from_file(kernel.cinst) - for var_name in variable_discovery.discover_variables_spad(kernel_cinstrs): - mem_model.add_variable(var_name) - else: - if verbose_stream: - print( - f" {idx + 1}/{len(input_files)}", - kernel.minst, - file=verbose_stream, - ) - kernel_minstrs = loader.load_minst_kernel_from_file(kernel.minst) - for var_name in variable_discovery.discover_variables(kernel_minstrs): - mem_model.add_variable(var_name) - - -def check_unused_variables(mem_model): - """ - @brief Checks for unused variables in the memory model and raises an error if found. - - @param mem_model Memory model to check. - @exception RuntimeError If an unused variable is found. - """ - for var_name in mem_model.mem_info_vars: - if var_name not in mem_model.variables: - if GlobalConfig.hasHBM or var_name not in mem_model.mem_info_meta: - raise RuntimeError( - f'Unused variable from input mem file: "{var_name}" not in memory model.' - ) +from linker.linker_run_config import LinkerRunConfig +from linker.steps.variable_discovery import scan_variables, check_unused_variables +from linker.steps import program_linker +from linker.kern_trace.trace_info import TraceInfo +from linker.loader import Loader +from linker.he_link_utils import ( + NullIO, + prepare_output_files, + prepare_input_files, + update_input_prefixes, + remap_vars, + initialize_memory_model, +) def main(run_config: LinkerRunConfig, verbose_stream=NullIO()): @@ -358,69 +65,81 @@ def main(run_config: LinkerRunConfig, verbose_stream=NullIO()): GlobalConfig.hasHBM = run_config.has_hbm GlobalConfig.suppress_comments = run_config.suppress_comments - mem_filename: str = run_config.input_mem_file - hbm_capacity_words: int = constants.convertBytes2Words( - run_config.hbm_size * constants.Constants.KILOBYTE - ) + # Process trace file if enabled + kernel_ops = [] + if run_config.using_trace_file: + kernel_ops = TraceInfo.parse_kernel_ops_from_file(run_config.trace_file) + update_input_prefixes(kernel_ops, run_config) + + print( + f"Found {len(kernel_ops)} kernel ops in trace file:", + file=verbose_stream, + ) + + print("", file=verbose_stream) # Prepare input and output files - output_files: KernelFiles = prepare_output_files(run_config) - input_files: list[KernelFiles] = prepare_input_files(run_config, output_files) + program_info = prepare_output_files(run_config) + kernels_info = prepare_input_files(run_config, program_info) # Reset counters Counter.reset() - # parse mem file + # Parse memory information and setup memory model print("Linking...", file=verbose_stream) print("", file=verbose_stream) print("Interpreting variable meta information...", file=verbose_stream) - if run_config.multi_mem_files: - kernels_dinstrs = [] - for kernel in input_files: - if kernel.mem is None: - raise RuntimeError(f"Memory file not found for kernel {kernel.prefix}") - kernel_dinstrs = loader.load_dinst_kernel_from_file(kernel.mem) - kernels_dinstrs.append(kernel_dinstrs) + # Process kernel DInstructions when using trace file + program_dinstrs = [] + if run_config.using_trace_file: + dinstrs_per_kernel = [] + for kernel_info in kernels_info: + kernel_dinstrs = Loader.load_dinst_kernel_from_file(kernel_info.mem) + dinstrs_per_kernel.append(kernel_dinstrs) + + remap_vars(kernels_info, dinstrs_per_kernel, kernel_ops, verbose_stream) # Concatenate all mem info objects into one - kernel_dinstrs = program_linker.LinkedProgram.join_dinst_kernels( - kernels_dinstrs + program_dinstrs = program_linker.LinkedProgram.join_dinst_kernels( + dinstrs_per_kernel ) - mem_meta_info = mem_info.MemInfo.from_dinstrs(kernel_dinstrs) - else: - with open(mem_filename, "r", encoding="utf-8") as mem_ifnum: - mem_meta_info = mem_info.MemInfo.from_file_iter(mem_ifnum) - # Initialize memory model - print("Initializing linker memory model", file=verbose_stream) + # Write new program memory model to an output file + if program_info.mem is None: + raise RuntimeError("Output memory file path is None") + BaseInstruction.dump_instructions_to_file(program_dinstrs, program_info.mem) - mem_model = linker.MemoryModel(hbm_capacity_words, mem_meta_info) - print(f" HBM capacity: {mem_model.hbm.capacity} words", file=verbose_stream) + # Initialize memory model + mem_model = initialize_memory_model(run_config, program_dinstrs, verbose_stream) + # Discover variables print(" Finding all program variables...", file=verbose_stream) print(" Scanning", file=verbose_stream) - scan_variables(input_files, mem_model, verbose_stream) + scan_variables( + kernels_info=kernels_info, mem_model=mem_model, verbose_stream=verbose_stream + ) + check_unused_variables(mem_model) print(f" Variables found: {len(mem_model.variables)}", file=verbose_stream) print("Linking started", file=verbose_stream) - link_kernels(input_files, output_files, mem_model, verbose_stream) + # Link kernels and generate outputs + program_linker.LinkedProgram.link_kernels_to_files( + kernels_info, program_info, mem_model, verbose_stream=verbose_stream + ) - # Write the memory model to the output file - if run_config.multi_mem_files: - if output_files.mem is None: - raise RuntimeError("Output memory file path is None") - BaseInstruction.dump_instructions_to_file(kernel_dinstrs, output_files.mem) + # Flush cached kernels + Loader.flush_cache() print("Output written to files:", file=verbose_stream) - print(" ", output_files.minst, file=verbose_stream) - print(" ", output_files.cinst, file=verbose_stream) - print(" ", output_files.xinst, file=verbose_stream) - if run_config.multi_mem_files: - print(" ", output_files.mem, file=verbose_stream) + print(" ", program_info.minst, file=verbose_stream) + print(" ", program_info.cinst, file=verbose_stream) + print(" ", program_info.xinst, file=verbose_stream) + if run_config.using_trace_file: + print(" ", program_info.mem, file=verbose_stream) def parse_args(): @@ -445,12 +164,14 @@ def parse_args(): ) ) parser.add_argument( - "input_prefixes", + "-ip", + "--input_prefixes", + dest="input_prefixes", nargs="+", help=( - "List of input prefixes, including full path. For an input prefix, linker will " - "assume three files exist named `input_prefixes[i] + '.minst'`, " - "`input_prefixes[i] + '.cinst'`, and `input_prefixes[i] + '.xinst'`." + "List of input prefixes. For an input prefix, linker will " + "assume three files exist named `.minst`, " + "`.cinst`, and `.xinst`." ), ) parser.add_argument( @@ -466,13 +187,25 @@ def parse_args(): help=("Input ISA specification (.json) file."), ) parser.add_argument( - "--multi_mem_files", - action="store_true", - dest="multi_mem_files", + "--use_trace_file", + default="", + dest="trace_file", help=( - "Tells the linker to find a memory file (*.tw.mem) for each input prefix given." - "This can be used to link multiple kernels together. " - "If this flag is not set, the linker will use the input_mem_file argument instead" + "Instructs the linker to use a trace file to determine the required input files for each kernel line. " + "The linker will look for the following files: *.minst, *.cinst, *.xinst, and *.mem. " + "When this flag is set, the 'input_mem_file' and 'input_prefixes' flags are ignored." + ), + ) + parser.add_argument( + "-id", + "--input_dir", + dest="input_dir", + default="", + help=( + "Directory where input files are located. " + "If not provided and use_trace_file is set, the directory of the trace file will be used. " + "This is useful when input files are in a different location than the trace file. " + "If not provided and use_trace_file is not set, the current working directory will be used." ), ) parser.add_argument( @@ -538,11 +271,33 @@ def parse_args(): ) p_args = parser.parse_args() - # Enforce input_mem_file only if multi_mem_files is not set - if not p_args.multi_mem_files and p_args.input_mem_file == "": - parser.error( - "the following arguments are required: -im/--input_mem_file (unless --multi_mem_files is set)" - ) + # Determine if using trace file based on trace_file argument + p_args.using_trace_file = p_args.trace_file != "" + + # Set input_dir to trace_file directory if not provided and trace_file is set + if p_args.input_dir == "" and p_args.trace_file: + p_args.input_dir = os.path.dirname(p_args.trace_file) + + # Enforce only if use_trace_file is not set + if not p_args.using_trace_file: + if p_args.input_mem_file == "": + parser.error( + "the following arguments are required: -im/--input_mem_file (unless --use_trace_file is set)" + ) + if not p_args.input_prefixes: + parser.error( + "the following arguments are required: -ip/--input_prefixes (unless --use_trace_file is set)" + ) + else: + # If using trace file, input_mem_file and input_prefixes are ignored + if p_args.input_mem_file != "": + warnings.warn( + "Ignoring input_mem_file argument because --use_trace_file is set." + ) + if p_args.input_prefixes: + warnings.warn( + "Ignoring input_prefixes argument because --use_trace_file is set." + ) return p_args diff --git a/assembler_tools/hec-assembler-tools/linker/he_link_utils.py b/assembler_tools/hec-assembler-tools/linker/he_link_utils.py new file mode 100644 index 00000000..dc99672d --- /dev/null +++ b/assembler_tools/hec-assembler-tools/linker/he_link_utils.py @@ -0,0 +1,178 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +# These contents may have been developed with support from one +# or more Intel-operated generative artificial intelligence solutions + +""" +@file he_link_utils.py +@brief Utility functions for the he_link module +""" +import os +import pathlib + +import linker +from assembler.common import constants +from assembler.common import makeUniquePath +from assembler.memory_model import mem_info +from linker.kern_trace import KernelInfo, remap_dinstrs_vars + + +class NullIO: + """ + @class NullIO + @brief A class that provides a no-operation implementation of write and flush methods. + """ + + def write(self, *argts, **kwargs): + """ + @brief A no-operation write method. + """ + + def flush(self): + """ + @brief A no-operation flush method. + """ + + +def prepare_output_files(run_config) -> KernelInfo: + """ + @brief Prepares output file names and directories. + + @param run_config LinkerRunConfig object. + @return KernelInfo with output file paths. + """ + path_prefix = os.path.join(run_config.output_dir, run_config.output_prefix) + pathlib.Path(run_config.output_dir).mkdir(exist_ok=True, parents=True) + out_mem_file = ( + makeUniquePath(path_prefix + ".mem") if run_config.using_trace_file else None + ) + return KernelInfo( + { + "directory": run_config.output_dir, + "prefix": run_config.output_prefix, + "minst": makeUniquePath(path_prefix + ".minst"), + "cinst": makeUniquePath(path_prefix + ".cinst"), + "xinst": makeUniquePath(path_prefix + ".xinst"), + "mem": out_mem_file, + } + ) + + +def prepare_input_files(run_config, output_files) -> list: + """ + @brief Prepares input file names and checks for existence and conflicts. + + @param run_config LinkerRunConfig object. + @param output_files KernelInfo for output. + @return list List of KernelInfo for input. + @exception FileNotFoundError If an input file does not exist. + @exception RuntimeError If an input file matches an output file. + """ + input_files = [] + for file_prefix in run_config.input_prefixes: + path_prefix = os.path.join(run_config.input_dir, file_prefix) + mem_file = ( + makeUniquePath(path_prefix + ".mem") + if run_config.using_trace_file + else None + ) + kernel_info = KernelInfo( + { + "directory": run_config.input_dir, + "prefix": file_prefix, + "minst": makeUniquePath(path_prefix + ".minst"), + "cinst": makeUniquePath(path_prefix + ".cinst"), + "xinst": makeUniquePath(path_prefix + ".xinst"), + "mem": mem_file, + } + ) + input_files.append(kernel_info) + for input_filename in kernel_info.files: + if not os.path.isfile(input_filename): + raise FileNotFoundError(input_filename) + if input_filename in output_files.files: + raise RuntimeError( + f'Input files cannot match output files: "{input_filename}"' + ) + return input_files + + +def update_input_prefixes(kernel_ops, run_config): + """ + @brief Update input prefixes in run_config. + + @param kernel_ops List of kernel operations to extract prefixes from. + @param run_config LinkerRunConfig object to update with input prefixes. + """ + # Extract kernel prefixes and create list of (prefix, operation) tuples + prefixes = [] + for kernel_op in kernel_ops: + prefix = f"{kernel_op.expected_in_kern_file_name}_pisa.tw" + prefixes.append(prefix) + + # Update input_prefixes in run_config + run_config.input_prefixes = prefixes + + +def remap_vars( + kernels_info: list[KernelInfo], kernels_dinstrs, kernel_ops, verbose_stream +): + """ + @brief Process kernel DInstructions to remap variables based on kernel operations + and update KernelInfo with remap_dict. + + @param kernels_info List of input KernelInfo. + @param kernels_dinstrs List of kernel DInstructions. + @param kernel_ops List of kernel operations. + @param verbose_stream Stream for verbose output. + """ + assert len(kernels_info) == len( + kernel_ops + ), "Number of kernels_files must match number of kernel operations." + assert len(kernels_dinstrs) == len( + kernel_ops + ), "Number of kernel_dinstrs must match number of kernel operations." + + for kernel_info, kernel_op, kernel_dinstrs in zip( + kernels_info, kernel_ops, kernels_dinstrs + ): + print(f"\tProcessing kernel: {kernel_info.prefix}", file=verbose_stream) + + expected_prefix = f"{kernel_op.expected_in_kern_file_name}_pisa.tw" + assert expected_prefix in kernel_info.prefix, ( + f"Kernel operation prefix {expected_prefix} does not match " + f"kernel file prefix {kernel_info.prefix}" + ) + + # Remap dintrs' variables in kernel_dinstrs and return a mapping dict + var_map = remap_dinstrs_vars(kernel_dinstrs, kernel_op) + kernel_info.remap_dict = var_map + + +def initialize_memory_model(run_config, kernel_dinstrs=None, verbose_stream=None): + """ + @brief Initialize the memory model based on configuration. + + @param run_config The configuration object. + @param kernel_dinstrs Optional list of kernel DInstructions for trace file mode. + @param verbose_stream Stream for verbose output. + @return MemoryModel instance. + """ + hbm_capacity_words = constants.convertBytes2Words( + run_config.hbm_size * constants.Constants.KILOBYTE + ) + + # Parse memory information + if kernel_dinstrs: + mem_meta_info = mem_info.MemInfo.from_dinstrs(kernel_dinstrs) + else: + with open(run_config.input_mem_file, "r", encoding="utf-8") as mem_ifnum: + mem_meta_info = mem_info.MemInfo.from_file_iter(mem_ifnum) + + # Initialize memory model + print("Initializing linker memory model", file=verbose_stream) + mem_model = linker.MemoryModel(hbm_capacity_words, mem_meta_info) + print(f" HBM capacity: {mem_model.hbm.capacity} words", file=verbose_stream) + + return mem_model diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/dinst/__init__.py b/assembler_tools/hec-assembler-tools/linker/instructions/dinst/__init__.py index 1bcc2452..9cbd9899 100644 --- a/assembler_tools/hec-assembler-tools/linker/instructions/dinst/__init__.py +++ b/assembler_tools/hec-assembler-tools/linker/instructions/dinst/__init__.py @@ -9,7 +9,6 @@ from typing import Optional from assembler.instructions import tokenize_from_line -from assembler.memory_model.mem_info import MemInfo from . import dload, dstore, dkeygen from . import dinstruction @@ -49,13 +48,4 @@ def create_from_mem_line(line: str) -> dinstruction.DInstruction: if not retval: raise RuntimeError(f'No valid instruction found for line "{line}"') - try: - miv, _ = MemInfo.get_meminfo_var_from_tokens(tokens) - except RuntimeError as e: - raise RuntimeError(f'Error parsing memory map line "{line}"') from e - - miv_dict = miv.as_dict() - retval.var = miv_dict["var_name"] - retval.address = miv_dict["hbm_address"] - return retval diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dinstruction.py b/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dinstruction.py index 9db0ad16..bdc3d127 100644 --- a/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dinstruction.py +++ b/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dinstruction.py @@ -14,6 +14,7 @@ from linker.instructions.instruction import BaseInstruction from assembler.common.counter import Counter from assembler.common.decorators import classproperty +from assembler.memory_model.mem_info import MemInfo class DInstruction(BaseInstruction): @@ -91,13 +92,24 @@ def __init__(self, tokens: list, comment: str = ""): @param comment Optional comment for the instruction. """ # Do not increment the global instruction count; skip BaseInstruction's __init__ logic for __id - # Call BaseInstruction constructor but perform our own initialization - super().__init__(tokens, comment=comment) + # Perform our own initialization + super().__init__(tokens, comment=comment, count=False) self.comment = comment self._tokens = list(tokens) self._local_id = next(DInstruction._local_id_count) + try: + miv, _ = MemInfo.get_meminfo_var_from_tokens(tokens) + miv_dict = miv.as_dict() + self.var = miv_dict["var_name"] + if self.name in [MemInfo.Const.Keyword.LOAD, MemInfo.Const.Keyword.STORE]: + self.address = miv_dict["hbm_address"] + except RuntimeError as e: + raise ValueError( + f"Failed to parse memory info from tokens: {tokens}. Error: {str(e)}" + ) from e + @property def id(self): """ diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dload.py b/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dload.py index e0902735..2bf91fd9 100644 --- a/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dload.py +++ b/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dload.py @@ -45,4 +45,8 @@ def tokens(self) -> list: @return The list of tokens. """ - return [self.name, self._tokens[1], str(self.address)] + self._tokens[3:] + extra_tokens = [] + if len(self._tokens) > 4: + extra_tokens = self._tokens[4:] + + return [self.name, self._tokens[1], str(self.address), self.var] + extra_tokens diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/instruction.py b/assembler_tools/hec-assembler-tools/linker/instructions/instruction.py index 7a0ff43a..5b1d6770 100644 --- a/assembler_tools/hec-assembler-tools/linker/instructions/instruction.py +++ b/assembler_tools/hec-assembler-tools/linker/instructions/instruction.py @@ -111,19 +111,21 @@ def dump_instructions_to_file(cls, instructions: list, filename: str): # Constructor # ----------- - def __init__(self, tokens: list, comment: str = ""): + def __init__(self, tokens: list, comment: str = "", count: bool = True): """ @brief Creates a new BaseInstruction object. @param tokens List of tokens for the instruction. @param comment Optional comment for the instruction. + @param count If True, increments the global instruction count and sets a unique ID. @throws ValueError If the number of tokens is invalid or the instruction name is incorrect. """ assert self.name_token_index < self.num_tokens self._validate_tokens(tokens) - self._id = next(BaseInstruction.__id_count) + if count: + self._id = next(BaseInstruction.__id_count) self._tokens = list(tokens) self.comment = comment diff --git a/assembler_tools/hec-assembler-tools/linker/kern_trace/__init__.py b/assembler_tools/hec-assembler-tools/linker/kern_trace/__init__.py new file mode 100644 index 00000000..4d07e3e2 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/linker/kern_trace/__init__.py @@ -0,0 +1,25 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +""" +@brief Package for handling kernel operation tracing and analysis. + +This package provides utilities for parsing trace files and extracting kernel operation information. +""" + +from linker.kern_trace.kern_var import KernVar +from linker.kern_trace.context_config import ContextConfig +from linker.kern_trace.kernel_op import KernelOp +from linker.kern_trace.trace_info import TraceInfo +from linker.kern_trace.kern_remap import remap_dinstrs_vars, remap_m_c_instrs_vars +from linker.kern_trace.trace_info import KernelInfo + +__all__ = [ + "KernVar", + "ContextConfig", + "KernelOp", + "TraceInfo", + "remap_dinstrs_vars", + "remap_m_c_instrs_vars", + "KernelInfo", +] diff --git a/assembler_tools/hec-assembler-tools/linker/kern_trace/context_config.py b/assembler_tools/hec-assembler-tools/linker/kern_trace/context_config.py new file mode 100644 index 00000000..f6a6d05d --- /dev/null +++ b/assembler_tools/hec-assembler-tools/linker/kern_trace/context_config.py @@ -0,0 +1,23 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +# These contents may have been developed with support from one or more Intel-operated +# generative artificial intelligence solutions + +"""@brief Module for encryption scheme context configuration.""" + +from dataclasses import dataclass + + +@dataclass +class ContextConfig: + """ + @brief Configuration class for encryption scheme parameters. + + @details This class encapsulates the parameters related to an encryption scheme, + including the scheme name, polynomial modulus degree, and key RNS terms. + """ + + scheme: str + poly_mod_degree: int + keyrns_terms: int diff --git a/assembler_tools/hec-assembler-tools/linker/kern_trace/kern_remap.py b/assembler_tools/hec-assembler-tools/linker/kern_trace/kern_remap.py new file mode 100644 index 00000000..890d22d2 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/linker/kern_trace/kern_remap.py @@ -0,0 +1,110 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +# These contents may have been developed with support from one or more Intel-operated +# generative artificial intelligence solutions + +"""@brief Module for remapping kernel variables in DINST files.""" + +import re +from linker.instructions.dinst.dinstruction import DInstruction +from linker.instructions.minst.minstruction import MInstruction +from linker.instructions.cinst.cinstruction import CInstruction +from linker.instructions import minst, cinst +from linker.kern_trace.kernel_op import KernelOp + + +def remap_dinstrs_vars( + kernel_dinstrs: list[DInstruction], kernel_op: KernelOp +) -> dict[str, str]: + """ + @brief Remaps variable names in DInstructions based on KernelOp variables. + + For each variable name in the kernel_dinstrs: + 1. Extracts a prefix separated by '_' + 2. Ignores variables with prefixes 'ntt', 'intt', 'ones', 'ipsi', 'psi', 'rlk' or 'twid' + 3. Extracts a number from the prefix (digits after text) + 4. Uses this number as an index into the sorted list of KernelOp variables + 5. Replaces the dinstr var name prefix with the value obtained by the index + + @param kernel_dinstrs: List of DInstruction objects to process + @param kernel_op: KernelOp containing variables to use for remapping + @return: Dictionary mapping old variable names to new variable names + """ + + # Sort kernel_op variables by label + sorted_kern_vars = sorted(kernel_op.kern_vars, key=lambda x: x.label) + + # Dictionary to store mapping of old var names to new var names + var_mapping = {} + + # Process each DInstruction + for dinstr in kernel_dinstrs: + # Split the variable name by '_' to get the prefix + try: + prefix, rest = dinstr.var.split("_", 1) + except ValueError as e: + raise ValueError( + f"Unexpected format: variable name '{dinstr.var}' does not contain items to split by '_': {e}" + ) from e + + # Skip if prefix is not 'ct' or 'pt' + if not (prefix.lower().startswith("ct") or prefix.lower().startswith("pt")): + continue + + # Extract number from prefix (digits after text) + match = re.search(r"([a-zA-Z]+)(\d+)", prefix) + + if not match: + raise ValueError( + f"Unexpected format: variable prefix '{prefix}' does not contain a number after text." + ) + + number_part = int(match.group(2)) + + # Use number as index if it's valid + try: + # Replace prefix with kernel variable label + kern_var = sorted_kern_vars[number_part] + except IndexError as exc: + raise IndexError( + f"Number part {number_part} from prefix '{prefix}' is out of range [0, {len(sorted_kern_vars)-1}] for the KernelOp variables" + ) from exc + + old_var = dinstr.var + new_var = f"{kern_var.label}_{rest}" + dinstr.var = new_var + var_mapping[old_var] = new_var + + return var_mapping + + +def remap_m_c_instrs_vars(kernel_instrs: list, remap_dict: dict[str, str]) -> None: + """ + @brief Remaps variable names in M or C Instructions based on a provided remap dictionary. + + This function updates the variable names in each Instruction by replacing them + with their corresponding values from the remap dictionary. + + @param kernel_instrs: List of M or M Instruction objects to process + @param remap_dict: Dictionary mapping old variable names to new variable names + """ + if remap_dict: + for instr in kernel_instrs: + if not isinstance(instr, (MInstruction, CInstruction)): + raise TypeError(f"Item {instr} is not a valid M or C Instruction.") + + if isinstance( + instr, (minst.MLoad, cinst.BLoad, cinst.CLoad, cinst.BOnes, cinst.NLoad) + ): + if instr.source in remap_dict: + instr.comment = instr.comment.replace( + instr.source, remap_dict[instr.source] + ) + instr.source = remap_dict[instr.source] + elif isinstance(instr, (minst.MStore, cinst.CStore)): + if instr.dest in remap_dict: + instr.comment = instr.comment.replace( + instr.dest, remap_dict[instr.dest] + ) + instr.dest = remap_dict[instr.dest] diff --git a/assembler_tools/hec-assembler-tools/linker/kern_trace/kern_var.py b/assembler_tools/hec-assembler-tools/linker/kern_trace/kern_var.py new file mode 100644 index 00000000..0b886bdd --- /dev/null +++ b/assembler_tools/hec-assembler-tools/linker/kern_trace/kern_var.py @@ -0,0 +1,80 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +# These contents may have been developed with support from one or more Intel-operated +# generative artificial intelligence solutions + +"""@brief Module for handling kernel variables in trace files.""" + + +class KernVar: + """ + @brief Class representing a kernel variable in trace files. + + @details This class encapsulates the properties of a kernel variable, + including its label, degree, and level. + """ + + def __init__(self, label: str, degree: int, level: int): + """ + @brief Initializes a KernVar instance. + + @param label: The label of the kernel variable. + @param degree: The polynomial degree of the variable. + @param level: The current RNS level of the variable. + """ + self._label = label + self._degree = degree + self._level = level + + @classmethod + def from_string(cls, var_str: str): + """ + @brief Creates a KernVar instance from a string representation. + + @param var_str: The string representation of the kernel variable in the format "label_degree_level". + + @return KernVar: An instance of KernVar initialized with the parsed values. + """ + parts = var_str.split("-") + if len(parts) != 3: + raise ValueError(f"Invalid kernel variable string format: {var_str}") + if not parts[1].isdigit() or not parts[2].isdigit(): + raise ValueError( + f"Invalid degree or level in kernel variable string: {var_str}" + ) + if not parts[0]: + raise ValueError(f"Invalid label in kernel variable string: {var_str}") + + label = parts[0] + degree = int(parts[1]) + level = int(parts[2]) + + return cls(label, degree, level) + + @property + def label(self) -> str: + """ + @brief Returns the label of the kernel variable. + + @return str: The label of the kernel variable. + """ + return self._label + + @property + def degree(self) -> int: + """ + @brief Returns the polynomial degree of the kernel variable. + + @return int: The polynomial degree of the kernel variable. + """ + return self._degree + + @property + def level(self) -> int: + """ + @brief Returns the current RNS level of the kernel variable. + + @return int: The current RNS level of the kernel variable. + """ + return self._level diff --git a/assembler_tools/hec-assembler-tools/linker/kern_trace/kernel_op.py b/assembler_tools/hec-assembler-tools/linker/kern_trace/kernel_op.py new file mode 100644 index 00000000..1f320598 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/linker/kern_trace/kernel_op.py @@ -0,0 +1,196 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +# These contents may have been developed with support from one or more Intel-operated +# generative artificial intelligence solutions + +"""@brief Module for kernel operation representation and analysis.""" + +from linker.kern_trace.context_config import ContextConfig +from linker.kern_trace.kern_var import KernVar + + +class KernelOp: + """ + @brief Base class for kernel operations in trace files. + + @details This class serves as a base for all kernel operations, providing a common interface + and functionality for handling kernel operations in trace files. + """ + + # List of valid kernel operation names. + valid_kernel_ops = [ + "add", + "sub", + "mul", + "relin", + "mod_switch", + "add_plain", + "rotate", + "ntt", + "intt", + "square", + "mul_plain", + "rescale", + ] + + # List of valid encryption schemes. + valid_schemes = ["bgv", "ckks", "bfv"] + + def _get_expected_in_kern_file_name( + self, + name: str, + context_config: ContextConfig, + level: int, + ) -> str: + """ + @brief Returns the expected kernel file name based on internal params. + + @param name: The name of the kernel operation. + @param context_config: Configuration object containing encryption scheme parameters. + @param level: The current RNS level. + + @return str: The expected kernel file name formatted as: + "{scheme}_{name}_{poly_modulus_degree}_l{level}_m{keyrns_terms}" + """ + return ( + f"{context_config.scheme.lower()}_" + f"{name.lower()}_" + f"{context_config.poly_mod_degree}_" + f"l{level}_" + f"m{context_config.keyrns_terms}" + ) + + def get_kern_var_objs(self, kern_var_strs: list[str]) -> list[KernVar]: + """ + @brief Converts a list of kernel variable strings to KernVar objects. + + @param kern_var_strs: A list of strings representing kernel variables. + + @return list: A list of KernVar objects created from the input strings. + """ + return [KernVar.from_string(var_str) for var_str in kern_var_strs] + + def get_level(self, kern_vars: list[KernVar]) -> int: + """ + @brief Sets the level of the kernel operation based on input's current RNS level. + + @details The level is determined by current RNS level on input variables, + which is used to categorize the kernel operation. + """ + if not kern_vars: + raise ValueError( + "Kernel operation must have at least one variable to determine level." + ) + + # Assuming all input variables have the same level for the operation + return kern_vars[1].level if len(kern_vars) > 1 else kern_vars[0].level + + def __init__( + self, + name: str, + context_config: ContextConfig, + kern_args: list, + ): + """ + @brief Initializes a KernelOp instance. + + @param name: The name of the kernel operation. + @param context_config: Configuration object containing encryption scheme parameters. + @param kern_args: List of arguments for the kernel operation. + """ + + if name.lower() not in self.valid_kernel_ops: + raise ValueError( + f"Invalid kernel operation name: {name}. " + f"Valid names are: {', '.join(self.valid_kernel_ops)}" + ) + if context_config.scheme.lower() not in self.valid_schemes: + raise ValueError( + f"Invalid encryption scheme: {context_config.scheme}. " + f"Valid schemes are: {', '.join(self.valid_schemes)}" + ) + if len(kern_args) < 2: + raise ValueError("Kernel operation must have at least two arguments.") + + self._name = name + self._scheme = context_config.scheme + self._poly_modulus_degree = context_config.poly_mod_degree + self._keyrns_terms = context_config.keyrns_terms + self._vars = self.get_kern_var_objs(kern_args) + self._level = self.get_level(self._vars) + self._expected_in_kern_file_name = self._get_expected_in_kern_file_name( + name, + context_config, + self._level, + ) + + def __str__(self): + """ + @brief Returns a string representation of the KernelOp instance. + """ + return f"KernelOp(name={self.name})" + + @property + def kern_vars(self) -> list: + """ + @brief Returns the arguments of the kernel operation. + + @return list: A list of arguments for the kernel operation. + """ + return self._vars + + @property + def name(self) -> str: + """ + @brief Returns the name of the kernel operation. + + @return str: The name of the kernel operation. + """ + return self._name + + @property + def scheme(self) -> str: + """ + @brief Returns the encryption scheme used by the kernel operation. + + @return str: The encryption scheme (e.g., BGV, CKKS). + """ + return self._scheme + + @property + def poly_modulus_degree(self) -> int: + """ + @brief Returns the polynomial modulus degree of the kernel operation. + + @return int: The polynomial modulus degree. + """ + return self._poly_modulus_degree + + @property + def keyrns_terms(self) -> int: + """ + @brief Returns the number of key RNS terms for the kernel operation. + + @return int: The number of key RNS terms. + """ + return self._keyrns_terms + + @property + def level(self) -> int: + """ + @brief Returns the current RNS level of the kernel operation. + + @return int: The current RNS level. + """ + return self._level + + @property + def expected_in_kern_file_name(self) -> str: + """ + @brief Returns the expected file prefix for the kernel operation. + + @return str: The expected file prefix formatted as: + "{scheme}_{name}_{poly_modulus_degree}_l{level}_m{keyrns_terms}" + """ + return self._expected_in_kern_file_name diff --git a/assembler_tools/hec-assembler-tools/linker/kern_trace/trace_info.py b/assembler_tools/hec-assembler-tools/linker/kern_trace/trace_info.py new file mode 100644 index 00000000..010f7bf0 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/linker/kern_trace/trace_info.py @@ -0,0 +1,192 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +# These contents may have been developed with support from one or more Intel-operated +# generative artificial intelligence solutions + +"""@brief Module for parsing and analyzing trace files.""" + +import os +from typing import Optional + +from assembler.instructions import tokenize_from_line +from linker.kern_trace.context_config import ContextConfig +from linker.kern_trace.kernel_op import KernelOp + + +class KernelInfo: + """ + @class KernelInfo + @brief Structure for kernel files. + + @details This class holds information about the kernel files used in the linker. + + @var directory + @var prefix + @var minst + @var cinst + @var xinst + @var mem + + @var remap_dict + @brief Dictionary for remapping variable names in DInstructions. + + """ + + directory: str + prefix: str + minst: str + cinst: str + xinst: str + mem: Optional[str] = None + remap_dict: dict[str, str] = {} + + def __init__(self, config: dict): + """ + @brief Initializes KernelInfo with a configuration dictionary. + + @param config: Dictionary with keys 'directory', 'prefix', 'minst', 'cinst', 'xinst', and optional 'mem'. + """ + self.directory = config["directory"] + self.prefix = config["prefix"] + self.minst = config["minst"] + self.cinst = config["cinst"] + self.xinst = config["xinst"] + self.mem = config.get("mem") + + @property + def files(self) -> list[str]: + """ + @brief Returns a list of file names associated with the kernel. + """ + return [self.minst, self.cinst, self.xinst] + ([self.mem] if self.mem else []) + + +class TraceInfo: + """ + @brief Class for handling trace files. + + @details This class provides an interface + and functionality for handling trace file info. + """ + + def __init__(self, filename: str): + """ + @brief Initializes a TraceFile instance. + + @param filename: The name of the trace file. + """ + self._trace_file = filename + + def __str__(self): + return f"TraceFile(trace_file={self._trace_file})" + + def get_trace_file(self) -> str: + """ + @brief Returns the trace file name. + + @return str: The name of the trace file. + """ + return self._trace_file + + def get_param_index_dict(self, tokens: list[str]) -> dict: + """ + @brief Returns a dictionary mapping property names to their indices in the trace file. + + @return dict: A dictionary mapping property names to their indices. + """ + param_idxs = {} + for i, token in enumerate(tokens): + param_idxs[token] = i + return param_idxs + + def extract_context_and_args(self, tokens, param_idxs, line_num): + """ + @brief Extract context configuration and arguments from tokens. + + @param tokens: List of tokens from a trace file line. + @param param_idxs: Dictionary mapping parameter names to their indices. + @param line_num: Current line number for error reporting. + + @return tuple: A tuple containing (context_config, kern_args). + """ + try: + # Extract required parameters + name = tokens[param_idxs["instruction"]] + scheme = tokens[param_idxs["scheme"]] + poly_mod_degree = int(tokens[param_idxs["poly_modulus_degree"]]) + keyrns_terms = int(tokens[param_idxs["keyrns_terms"]]) + + # Create scheme configuration + context_config = ContextConfig(scheme, poly_mod_degree, keyrns_terms) + + # Collect all parameters from the trace file line that start with "arg" + kern_args = [] + arg_keys = [key for key in param_idxs if key.startswith("arg")] + arg_keys.sort() + for arg_key in arg_keys: + if param_idxs[arg_key] < len(tokens) and tokens[param_idxs[arg_key]]: + kern_args.append(tokens[param_idxs[arg_key]]) + + return name, context_config, kern_args + + except KeyError as e: + raise KeyError( + f"Missing required parameter in line {line_num} with tokens: {tokens}: {e}" + ) from e + except IndexError as e: + raise ValueError( + f"Invalid number of parameters in line {line_num}: {e}" + ) from e + except ValueError as e: + raise ValueError(f"Invalid value in line {line_num}: {e}") from e + + def parse_kernel_ops(self) -> list[KernelOp]: + """ + @brief Parses the kernel operations from the trace file. + + @return list: A list of KernelOp instances parsed from the trace file. + """ + # Validate that trace file exists + if not os.path.isfile(self._trace_file): + raise FileNotFoundError(f"Trace file not found: {self._trace_file}") + + kernel_ops: list = [] + + with open(self._trace_file, "r", encoding="utf-8") as file: + lines = file.readlines() + + if not lines: + return kernel_ops + + # Process header line to get parameter indices + header_tokens, _ = tokenize_from_line(lines[0]) + param_idxs = self.get_param_index_dict(header_tokens) + + # Process the rest of the lines to get kernel operations + for line_num, line in enumerate(lines[1:], 2): # Start at line 2 (index+1) + tokens, _ = tokenize_from_line(line.strip()) + + if not tokens or not tokens[0]: # Skip empty lines + continue + + name, context_config, kern_args = self.extract_context_and_args( + tokens, param_idxs, line_num + ) + + # Create and add KernelOp with all arguments + kernel_op = KernelOp(name, context_config, kern_args) + kernel_ops.append(kernel_op) + + return kernel_ops + + @classmethod + def parse_kernel_ops_from_file(cls, filename: str) -> list[KernelOp]: + """ + @brief Parses kernel operations from a given trace file. + + @param filename: The name of the trace file. + @return list: A list of KernelOp instances parsed from the trace file. + """ + trace_info = cls(filename) + return trace_info.parse_kernel_ops() diff --git a/assembler_tools/hec-assembler-tools/linker/linker_run_config.py b/assembler_tools/hec-assembler-tools/linker/linker_run_config.py new file mode 100644 index 00000000..a98f7f17 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/linker/linker_run_config.py @@ -0,0 +1,150 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +# These contents may have been developed with support from one +# or more Intel-operated generative artificial intelligence solutions + +""" +@file linker_run_config.py +@brief This module provides configuration for the linker process. +""" +import io +import os +from typing import Any + +from assembler.common import makeUniquePath +from assembler.common.run_config import RunConfig + + +class LinkerRunConfig(RunConfig): + """ + @class LinkerRunConfig + @brief Maintains the configuration data for the run. + + @fn as_dict + @brief Returns the configuration as a dictionary. + + @return dict The configuration as a dictionary. + """ + + # Type annotations for class attributes + input_prefixes: list[str] + input_mem_file: str + using_trace_file: bool + trace_file: str + input_dir: str + output_dir: str + output_prefix: str + + __initialized = False # specifies whether static members have been initialized + # contains the dictionary of all configuration items supported and their + # default value (or None if no default) + __default_config: dict[str, Any] = {} + + def __init__(self, **kwargs): + """ + @brief Constructs a new LinkerRunConfig Object from input parameters. + + See base class constructor for more parameters. + + @param input_prefixes List of input prefixes, including full path. For an input prefix, linker will + assume there are three files named `input_prefixes[i] + '.minst'`, + `input_prefixes[i] + '.cinst'`, and `input_prefixes[i] + '.xinst'`. + This list must not be empty. + @param output_prefix Prefix for the output file names. + Three files will be generated: + `output_dir/output_prefix.minst`, `output_dir/output_prefix.cinst`, and + `output_dir/output_prefix.xinst`. + Output filenames cannot match input file names. + @param input_mem_file Input memory file associated with the result kernel. + @param output_dir OPTIONAL directory where to store all intermediate files and final output. + This will be created if it doesn't exists. + Defaults to current working directory. + + @exception TypeError A mandatory configuration value was missing. + @exception ValueError At least, one of the arguments passed is invalid. + """ + super().__init__(**kwargs) + + self.init_default_config() + + # Validate input parameters + if "hbm_size" in kwargs and kwargs["hbm_size"] is not None: + if not isinstance(kwargs["hbm_size"], int): + raise ValueError("Invalid param: hbm_size must be an integer") + if kwargs["hbm_size"] < 0: + raise ValueError( + "Invalid param: hbm_size must be a non-negative integer" + ) + + if "has_hbm" in kwargs and not isinstance(kwargs["has_hbm"], bool): + raise ValueError("Invalid param: has_hbm must be a boolean value") + + # class members based on configuration + for config_name, default_value in self.__default_config.items(): + value = kwargs.get(config_name, default_value) + if value is not None: + setattr(self, config_name, value) + else: + if not hasattr(self, config_name): + setattr(self, config_name, default_value) + if getattr(self, config_name) is None: + raise TypeError( + f"Expected value for configuration `{config_name}`, but `None` received." + ) + + # Fix file paths + # E0203: Access to member 'input_mem_file' before its definition. + # But it was defined in previous loop. + if self.input_mem_file != "": # pylint: disable=E0203 + self.input_mem_file = makeUniquePath(self.input_mem_file) + if self.trace_file != "": + self.trace_file = makeUniquePath(self.trace_file) + + self.output_dir = makeUniquePath(self.output_dir) + self.input_dir = makeUniquePath(self.input_dir) + + @classmethod + def init_default_config(cls): + """ + @brief Initializes static members of the class. + """ + if not cls.__initialized: + cls.__default_config["input_prefixes"] = "" + cls.__default_config["input_mem_file"] = "" + cls.__default_config["using_trace_file"] = False + cls.__default_config["trace_file"] = "" + cls.__default_config["output_dir"] = os.getcwd() + cls.__default_config["input_dir"] = os.getcwd() + cls.__default_config["output_prefix"] = None + + cls.__initialized = True + + def __str__(self): + """ + @brief Provides a string representation of the configuration. + + @return str The string for the configuration. + """ + self_dict = self.as_dict() + with io.StringIO() as retval_f: + for key, value in self_dict.items(): + print(f"{key}: {value}", file=retval_f) + retval = retval_f.getvalue() + return retval + + def as_dict(self) -> dict: + """ + @brief Provides the configuration as a dictionary. + + @return dict The configuration. + """ + retval = super().as_dict() + tmp_self_dict = vars(self) + retval.update( + { + config_name: tmp_self_dict[config_name] + for config_name in self.__default_config + } + ) + return retval diff --git a/assembler_tools/hec-assembler-tools/linker/loader.py b/assembler_tools/hec-assembler-tools/linker/loader.py index eeee007e..443914c4 100644 --- a/assembler_tools/hec-assembler-tools/linker/loader.py +++ b/assembler_tools/hec-assembler-tools/linker/loader.py @@ -8,6 +8,10 @@ @brief This module provides functionality to load different types of instruction kernels """ +import copy + +from typing import Any + from linker.instructions import minst from linker.instructions import cinst from linker.instructions import xinst @@ -15,130 +19,179 @@ from linker import instructions -def load_minst_kernel(line_iter) -> list: - """ - @brief Loads MInstruction kernel from an iterator of lines. - - @param line_iter An iterator over lines of MInstruction strings. - @return A list of MInstruction objects. - @throws RuntimeError If a line cannot be parsed into an MInstruction. - """ - retval = [] - for idx, s_line in enumerate(line_iter): - minstr = instructions.create_from_str_line(s_line, minst.factory()) - if not minstr: - raise RuntimeError(f"Error parsing line {idx + 1}: {s_line}") - retval.append(minstr) - return retval - - -def load_minst_kernel_from_file(filename: str) -> list: - """ - @brief Loads MInstruction kernel from a file. - - @param filename The file containing MInstruction strings. - @return A list of MInstruction objects. - @throws RuntimeError If an error occurs while loading the file. - """ - with open(filename, "r", encoding="utf-8") as kernel_minsts: - try: - return load_minst_kernel(kernel_minsts) - except Exception as e: - raise RuntimeError(f'Error occurred loading file "{filename}"') from e - - -def load_cinst_kernel(line_iter) -> list: - """ - @brief Loads CInstruction kernel from an iterator of lines. - - @param line_iter An iterator over lines of CInstruction strings. - @return A list of CInstruction objects. - @throws RuntimeError If a line cannot be parsed into a CInstruction. - """ - retval = [] - for idx, s_line in enumerate(line_iter): - cinstr = instructions.create_from_str_line(s_line, cinst.factory()) - if not cinstr: - raise RuntimeError(f"Error parsing line {idx + 1}: {s_line}") - retval.append(cinstr) - return retval - - -def load_cinst_kernel_from_file(filename: str) -> list: - """ - @brief Loads CInstruction kernel from a file. - - @param filename The file containing CInstruction strings. - @return A list of CInstruction objects. - @throws RuntimeError If an error occurs while loading the file. - """ - with open(filename, "r", encoding="utf-8") as kernel_cinsts: - try: - return load_cinst_kernel(kernel_cinsts) - except Exception as e: - raise RuntimeError(f'Error occurred loading file "{filename}"') from e - - -def load_xinst_kernel(line_iter) -> list: - """ - @brief Loads XInstruction kernel from an iterator of lines. - - @param line_iter An iterator over lines of XInstruction strings. - @return A list of XInstruction objects. - @throws RuntimeError If a line cannot be parsed into an XInstruction. - """ - retval = [] - for idx, s_line in enumerate(line_iter): - xinstr = instructions.create_from_str_line(s_line, xinst.factory()) - if not xinstr: - raise RuntimeError(f"Error parsing line {idx + 1}: {s_line}") - retval.append(xinstr) - return retval - - -def load_xinst_kernel_from_file(filename: str) -> list: - """ - @brief Loads XInstruction kernel from a file. - - @param filename The file containing XInstruction strings. - @return A list of XInstruction objects. - @throws RuntimeError If an error occurs while loading the file. - """ - with open(filename, "r", encoding="utf-8") as kernel_xinsts: - try: - return load_xinst_kernel(kernel_xinsts) - except Exception as e: - raise RuntimeError(f'Error occurred loading file "{filename}"') from e - - -def load_dinst_kernel(line_iter) -> list: - """ - @brief Loads DInstruction kernel from an iterator of lines. - - @param line_iter An iterator over lines of DInstruction strings. - @return A list of DInstruction objects. - @throws RuntimeError If a line cannot be parsed into an DInstruction. - """ - retval = [] - for idx, s_line in enumerate(line_iter): - dinstr = dinst.create_from_mem_line(s_line) - if not dinstr: - raise RuntimeError(f"Error parsing line {idx + 1}: {s_line}") - retval.append(dinstr) - - return retval - - -def load_dinst_kernel_from_file(filename: str) -> list: - """ - @brief Loads DInstruction kernel from a file. - - @param filename The file containing DInstruction strings. - @return A list of DInstruction objects. - @throws RuntimeError If an error occurs while loading the file. - """ - with open(filename, "r", encoding="utf-8") as kernel_dinsts: - try: - return load_dinst_kernel(kernel_dinsts) - except Exception as e: - raise RuntimeError(f'Error occurred loading file "{filename}"') from e +class Loader: + """ + @class Loader + @brief A class that provides methods to load different types of instruction kernels. + """ + + # Class-level file cache + _file_cache: dict[tuple, Any] = {} + + @classmethod + def flush_cache(cls): + """ + @brief Clears the file loading cache. + """ + cls._file_cache.clear() + + @classmethod + def load_minst_kernel(cls, line_iter) -> list: + """ + @brief Loads MInstruction kernel from an iterator of lines. + + @param line_iter An iterator over lines of MInstruction strings. + @return A list of MInstruction objects. + @throws RuntimeError If a line cannot be parsed into an MInstruction. + """ + retval = [] + for idx, s_line in enumerate(line_iter): + minstr = instructions.create_from_str_line(s_line, minst.factory()) + if not minstr: + raise RuntimeError(f"Error parsing line {idx + 1}: {s_line}") + retval.append(minstr) + return retval + + @classmethod + def load_minst_kernel_from_file(cls, filename: str, use_cache: bool = True) -> list: + """ + @brief Loads MInstruction kernel from a file. + + @param filename The file containing MInstruction strings. + @param use_cache Whether to use cached results if available. + @return A list of MInstruction objects. + @throws RuntimeError If an error occurs while loading the file. + """ + cache_key = (filename, "minst") + if use_cache and cache_key in cls._file_cache: + return copy.deepcopy(cls._file_cache[cache_key]) + + with open(filename, "r", encoding="utf-8") as kernel_minsts: + try: + result = cls.load_minst_kernel(kernel_minsts) + if use_cache: + cls._file_cache[cache_key] = result + return copy.deepcopy(result) + except Exception as e: + raise RuntimeError(f'Error occurred loading file "{filename}"') from e + + @classmethod + def load_cinst_kernel(cls, line_iter) -> list: + """ + @brief Loads CInstruction kernel from an iterator of lines. + + @param line_iter An iterator over lines of CInstruction strings. + @return A list of CInstruction objects. + @throws RuntimeError If a line cannot be parsed into a CInstruction. + """ + retval = [] + for idx, s_line in enumerate(line_iter): + cinstr = instructions.create_from_str_line(s_line, cinst.factory()) + if not cinstr: + raise RuntimeError(f"Error parsing line {idx + 1}: {s_line}") + retval.append(cinstr) + return retval + + @classmethod + def load_cinst_kernel_from_file(cls, filename: str, use_cache: bool = True) -> list: + """ + @brief Loads CInstruction kernel from a file. + + @param filename The file containing CInstruction strings. + @param use_cache Whether to use cached results if available. + @return A list of CInstruction objects. + @throws RuntimeError If an error occurs while loading the file. + """ + cache_key = (filename, "cinst") + if use_cache and cache_key in cls._file_cache: + return copy.deepcopy(cls._file_cache[cache_key]) + + with open(filename, "r", encoding="utf-8") as kernel_cinsts: + try: + result = cls.load_cinst_kernel(kernel_cinsts) + if use_cache: + cls._file_cache[cache_key] = result + return copy.deepcopy(result) + except Exception as e: + raise RuntimeError(f'Error occurred loading file "{filename}"') from e + + @classmethod + def load_xinst_kernel(cls, line_iter) -> list: + """ + @brief Loads XInstruction kernel from an iterator of lines. + + @param line_iter An iterator over lines of XInstruction strings. + @return A list of XInstruction objects. + @throws RuntimeError If a line cannot be parsed into an XInstruction. + """ + retval = [] + for idx, s_line in enumerate(line_iter): + xinstr = instructions.create_from_str_line(s_line, xinst.factory()) + if not xinstr: + raise RuntimeError(f"Error parsing line {idx + 1}: {s_line}") + retval.append(xinstr) + return retval + + @classmethod + def load_xinst_kernel_from_file(cls, filename: str, use_cache: bool = True) -> list: + """ + @brief Loads XInstruction kernel from a file. + + @param filename The file containing XInstruction strings. + @param use_cache Whether to use cached results if available. + @return A list of XInstruction objects. + @throws RuntimeError If an error occurs while loading the file. + """ + cache_key = (filename, "xinst") + if use_cache and cache_key in cls._file_cache: + return copy.deepcopy(cls._file_cache[cache_key]) + + with open(filename, "r", encoding="utf-8") as kernel_xinsts: + try: + result = cls.load_xinst_kernel(kernel_xinsts) + if use_cache: + cls._file_cache[cache_key] = result + return copy.deepcopy(result) + except Exception as e: + raise RuntimeError(f'Error occurred loading file "{filename}"') from e + + @classmethod + def load_dinst_kernel(cls, line_iter) -> list: + """ + @brief Loads DInstruction kernel from an iterator of lines. + + @param line_iter An iterator over lines of DInstruction strings. + @return A list of DInstruction objects. + @throws RuntimeError If a line cannot be parsed into an DInstruction. + """ + retval = [] + for idx, s_line in enumerate(line_iter): + dinstr = dinst.create_from_mem_line(s_line) + if not dinstr: + raise RuntimeError(f"Error parsing line {idx + 1}: {s_line}") + retval.append(dinstr) + + return retval + + @classmethod + def load_dinst_kernel_from_file(cls, filename: str, use_cache: bool = True) -> list: + """ + @brief Loads DInstruction kernel from a file. + + @param filename The file containing DInstruction strings. + @param use_cache Whether to use cached results if available. + @return A list of DInstruction objects. + @throws RuntimeError If an error occurs while loading the file. + """ + cache_key = (filename, "dinst") + if use_cache and cache_key in cls._file_cache: + return copy.deepcopy(cls._file_cache[cache_key]) + + with open(filename, "r", encoding="utf-8") as kernel_dinsts: + try: + result = cls.load_dinst_kernel(kernel_dinsts) + if use_cache: + cls._file_cache[cache_key] = result + return copy.deepcopy(result) + except Exception as e: + raise RuntimeError(f'Error occurred loading file "{filename}"') from e diff --git a/assembler_tools/hec-assembler-tools/linker/steps/program_linker.py b/assembler_tools/hec-assembler-tools/linker/steps/program_linker.py index 4096bd9e..43db1a98 100644 --- a/assembler_tools/hec-assembler-tools/linker/steps/program_linker.py +++ b/assembler_tools/hec-assembler-tools/linker/steps/program_linker.py @@ -8,8 +8,10 @@ from typing import Dict, Any, cast from linker import MemoryModel +from linker.loader import Loader from linker.instructions import minst, cinst, dinst from linker.instructions.dinst.dinstruction import DInstruction +from linker.kern_trace.kern_remap import remap_m_c_instrs_vars from assembler.common.config import GlobalConfig from assembler.instructions import cinst as ISACInst @@ -490,9 +492,56 @@ def join_dinst_kernels( continue # Add remaining carry-over variables to the new instructions - for _, var in carry_over_vars.items(): - var.address = mem_address - new_kernels_instrs.append(var) + for _, dintr in carry_over_vars.items(): + dintr.address = mem_address + new_kernels_instrs.append(dintr) mem_address = mem_address + 1 return new_kernels_instrs + + @staticmethod + def link_kernels_to_files( + input_files, output_files, mem_model, verbose_stream=None + ): + """ + @brief Links input kernels and writes the output to the specified files. + + @param input_files List of KernelInfo for input kernels. + @param output_files KernelInfo for output. + @param mem_model Memory model to use. + @param verbose_stream Stream for verbose output. + """ + with open(output_files.minst, "w", encoding="utf-8") as fnum_output_minst, open( + output_files.cinst, "w", encoding="utf-8" + ) as fnum_output_cinst, open( + output_files.xinst, "w", encoding="utf-8" + ) as fnum_output_xinst: + + result_program = LinkedProgram( + fnum_output_minst, fnum_output_cinst, fnum_output_xinst, mem_model + ) + + for idx, kernel in enumerate(input_files): + if verbose_stream: + print( + f"[ {idx * 100 // len(input_files): >3}% ]", + kernel.prefix, + file=verbose_stream, + ) + kernel_minstrs = Loader.load_minst_kernel_from_file(kernel.minst) + kernel_cinstrs = Loader.load_cinst_kernel_from_file(kernel.cinst) + kernel_xinstrs = Loader.load_xinst_kernel_from_file(kernel.xinst) + + remap_m_c_instrs_vars(kernel_minstrs, kernel.remap_dict) + remap_m_c_instrs_vars(kernel_cinstrs, kernel.remap_dict) + + result_program.link_kernel( + kernel_minstrs, kernel_cinstrs, kernel_xinstrs + ) + if verbose_stream: + print( + "[ 100% ] Finalizing output", + output_files.prefix, + file=verbose_stream, + ) + result_program.close() diff --git a/assembler_tools/hec-assembler-tools/linker/steps/variable_discovery.py b/assembler_tools/hec-assembler-tools/linker/steps/variable_discovery.py index c862de25..dfd1485c 100644 --- a/assembler_tools/hec-assembler-tools/linker/steps/variable_discovery.py +++ b/assembler_tools/hec-assembler-tools/linker/steps/variable_discovery.py @@ -4,10 +4,15 @@ """ @brief This module provides functionality to discover variable names in MInstructions and CInstructions. """ +from typing import Optional, TextIO, List from assembler.memory_model.variable import Variable +from assembler.memory_model import MemoryModel +from assembler.common.config import GlobalConfig from linker.instructions import minst, cinst from linker.instructions.minst.minstruction import MInstruction from linker.instructions.cinst.cinstruction import CInstruction +from linker.kern_trace import KernelInfo, remap_m_c_instrs_vars +from linker.loader import Loader def discover_variables_spad(cinstrs: list): @@ -23,7 +28,7 @@ def discover_variables_spad(cinstrs: list): for idx, cinstr in enumerate(cinstrs): if not isinstance(cinstr, CInstruction): raise TypeError( - f"Item {idx} in list of MInstructions is not a valid MInstruction." + f"Item {idx} in list of CInstructions is not a valid CInstruction." ) retval = None # TODO: Implement variable counting for CInst @@ -69,3 +74,56 @@ def discover_variables(minstrs: list): f'Invalid Variable name "{retval}" detected in instruction "{idx}, {minstr.to_line()}"' ) yield retval + + +def scan_variables( + kernels_info: List[KernelInfo], + mem_model: MemoryModel, + verbose_stream: Optional[TextIO] = None, +): + """ + @brief Scans input files for variables and adds them to the memory model. + + @param kernels_info List of KernelInfo for input. + @param mem_model Memory model to update. + @param verbose_stream Stream for verbose output. + """ + for idx, kernel_info in enumerate(kernels_info): + + if not GlobalConfig.hasHBM: + if verbose_stream: + print( + f" {idx + 1}/{len(kernels_info)}", + kernel_info.cinst, + file=verbose_stream, + ) + kernel_cinstrs = Loader.load_cinst_kernel_from_file(kernel_info.cinst) + remap_m_c_instrs_vars(kernel_cinstrs, kernel_info.remap_dict) + for var_name in discover_variables_spad(kernel_cinstrs): + mem_model.add_variable(var_name) + else: + if verbose_stream: + print( + f" {idx + 1}/{len(kernels_info)}", + kernel_info.minst, + file=verbose_stream, + ) + kernel_minstrs = Loader.load_minst_kernel_from_file(kernel_info.minst) + remap_m_c_instrs_vars(kernel_minstrs, kernel_info.remap_dict) + for var_name in discover_variables(kernel_minstrs): + mem_model.add_variable(var_name) + + +def check_unused_variables(mem_model): + """ + @brief Checks for unused variables in the memory model and raises an error if found. + + @param mem_model Memory model to check. + @exception RuntimeError If an unused variable is found. + """ + for var_name in mem_model.mem_info_vars: + if var_name not in mem_model.variables: + if GlobalConfig.hasHBM or var_name not in mem_model.mem_info_meta: + raise RuntimeError( + f'Unused variable from input mem file: "{var_name}" not in memory model.' + ) diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_he_link.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_he_link.py index eff9a306..07b7f73a 100644 --- a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_he_link.py +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_he_link.py @@ -8,437 +8,13 @@ @file test_he_link.py @brief Unit tests for the he_link module """ - -import os +import io import argparse -from unittest.mock import patch, mock_open, MagicMock, PropertyMock +from unittest.mock import patch, mock_open, MagicMock import pytest import he_link -from assembler.common.config import GlobalConfig - - -class TestLinkerRunConfig: - """ - @class TestLinkerRunConfig - @brief Test cases for the LinkerRunConfig class - """ - - def test_init_with_valid_params(self): - """ - @brief Test initialization with valid parameters - """ - # Arrange - kwargs = { - "input_prefixes": ["prefix1", "prefix2"], - "output_prefix": "output_prefix", - "input_mem_file": "input.mem", - "output_dir": "/tmp", - "has_hbm": True, - "hbm_size": 1024, - "suppress_comments": False, - "use_xinstfetch": False, - "multi_mem_files": False, - } - - # Act - with patch("he_link.makeUniquePath", side_effect=lambda x: x): - config = he_link.LinkerRunConfig(**kwargs) - - # Assert - assert config.input_prefixes == ["prefix1", "prefix2"] - assert config.output_prefix == "output_prefix" - assert config.input_mem_file == "input.mem" - assert config.output_dir == "/tmp" - assert config.has_hbm is True - assert config.hbm_size == 1024 - assert config.suppress_comments is False - assert config.use_xinstfetch is False - assert config.multi_mem_files is False - - def test_init_with_missing_required_param(self): - """ - @brief Test initialization with missing required parameters - """ - # Arrange - kwargs = { - "output_prefix": "output_prefix", - "input_mem_file": "input.mem", - # Missing input_prefixes - } - - # Act & Assert - with pytest.raises(TypeError): - he_link.LinkerRunConfig(**kwargs) - - def test_as_dict(self): - """ - @brief Test the as_dict method returns a proper dictionary - """ - # Arrange - kwargs = { - "input_prefixes": ["prefix1"], - "output_prefix": "output_prefix", - "input_mem_file": "input.mem", - "output_dir": "/tmp", - "has_hbm": True, - "hbm_size": 1024, - } - - # Act - with patch("he_link.makeUniquePath", side_effect=lambda x: x): - config = he_link.LinkerRunConfig(**kwargs) - result = config.as_dict() - - # Assert Keys - assert isinstance(result, dict) - assert "input_prefixes" in result - assert "output_prefix" in result - assert "input_mem_file" in result - assert "output_dir" in result - assert "has_hbm" in result - assert "hbm_size" in result - - # Assert values - assert result["input_prefixes"] == ["prefix1"] - assert result["output_prefix"] == "output_prefix" - assert result["input_mem_file"] == "input.mem" - assert result["output_dir"] == "/tmp" - assert result["has_hbm"] is True - assert result["hbm_size"] == 1024 - - def test_str_representation(self): - """ - @brief Test the string representation of the configuration - """ - # Arrange - kwargs = { - "input_prefixes": ["prefix1"], - "output_prefix": "output_prefix", - "input_mem_file": "input.mem", - } - - # Act - with patch("he_link.makeUniquePath", side_effect=lambda x: x): - config = he_link.LinkerRunConfig(**kwargs) - result = str(config) - - # Assert params - assert "input_prefixes" in result - assert "output_prefix" in result - assert "input_mem_file" in result - # Assert values - assert "prefix1" in result - assert "output_prefix" in result - assert "input.mem" in result - - def test_init_for_default_params(self): - """ - @brief Test initialization with default parameters - """ - - # Arrange - kwargs = {"input_prefixes": ["prefix1"], "output_prefix": ""} - - # Reset the class-level config so the patch will take effect - he_link.RunConfig.reset_class_state() - - # Act - with patch("he_link.makeUniquePath", side_effect=lambda x: x), patch.object( - he_link.RunConfig, "DEFAULT_HBM_SIZE_KB", new_callable=PropertyMock - ) as mock_hbm_size, patch.object( - GlobalConfig, "suppress_comments", new_callable=PropertyMock - ) as mock_suppress_comments, patch.object( - GlobalConfig, "useXInstFetch", new_callable=PropertyMock - ) as mock_use_xinstfetch: - - # Mock the default HBM size - mock_suppress_comments.return_value = False - mock_use_xinstfetch.return_value = False - mock_hbm_size.return_value = 1024 - config = he_link.LinkerRunConfig(**kwargs) - - # Assert - assert config.output_prefix == "" - assert config.input_mem_file == "" - assert config.output_dir == os.getcwd() - assert config.has_hbm is True - assert config.hbm_size == 1024 - assert config.suppress_comments is False - assert config.use_xinstfetch is False - assert config.multi_mem_files is False - - -class TestKernelFiles: - """ - @class TestKernelFiles - @brief Test cases for the KernelFiles class - """ - - def test_kernel_files_creation(self): - """ - @brief Test KernelFiles creation and attribute access - """ - # Act - kernel_files = he_link.KernelFiles( - prefix="prefix", - minst="prefix.minst", - cinst="prefix.cinst", - xinst="prefix.xinst", - mem="prefix.mem", - ) - - # Assert - assert kernel_files.prefix == "prefix" - assert kernel_files.minst == "prefix.minst" - assert kernel_files.cinst == "prefix.cinst" - assert kernel_files.xinst == "prefix.xinst" - assert kernel_files.mem == "prefix.mem" - - def test_kernel_files_without_mem(self): - """ - @brief Test KernelFiles creation without mem file - """ - # Act - kernel_files = he_link.KernelFiles( - prefix="prefix", - minst="prefix.minst", - cinst="prefix.cinst", - xinst="prefix.xinst", - ) - - # Assert - assert kernel_files.prefix == "prefix" - assert kernel_files.minst == "prefix.minst" - assert kernel_files.cinst == "prefix.cinst" - assert kernel_files.xinst == "prefix.xinst" - assert kernel_files.mem is None - - -class TestHelperFunctions: - """ - @class TestHelperFunctions - @brief Test cases for helper functions in he_link - """ - - def test_prepare_output_files(self): - """ - @brief Test prepare_output_files function creates correct output files - """ - # Arrange - mock_config = MagicMock() - mock_config.output_dir = "/tmp" - mock_config.output_prefix = "output" - mock_config.multi_mem_files = False - - # Act - with patch("os.path.dirname", return_value="/tmp"), patch( - "pathlib.Path.mkdir" - ), patch("he_link.makeUniquePath", side_effect=lambda x: x): - result = he_link.prepare_output_files(mock_config) - - # Assert - assert result.prefix == "/tmp/output" - assert result.minst == "/tmp/output.minst" - assert result.cinst == "/tmp/output.cinst" - assert result.xinst == "/tmp/output.xinst" - assert result.mem is None - - def test_prepare_output_files_with_mem(self): - """ - @brief Test prepare_output_files with multi_mem_files=True - """ - # Arrange - mock_config = MagicMock() - mock_config.output_dir = "/tmp" - mock_config.output_prefix = "output" - mock_config.multi_mem_files = True - - # Act - with patch("os.path.dirname", return_value="/tmp"), patch( - "pathlib.Path.mkdir" - ), patch("he_link.makeUniquePath", side_effect=lambda x: x): - result = he_link.prepare_output_files(mock_config) - - # Assert - assert result.prefix == "/tmp/output" - assert result.minst == "/tmp/output.minst" - assert result.cinst == "/tmp/output.cinst" - assert result.xinst == "/tmp/output.xinst" - assert result.mem == "/tmp/output.mem" - - def test_prepare_input_files(self): - """ - @brief Test prepare_input_files function - """ - # Arrange - mock_config = MagicMock() - mock_config.input_prefixes = ["/tmp/input1", "/tmp/input2"] - mock_config.multi_mem_files = False - - mock_output_files = he_link.KernelFiles( - prefix="/tmp/output", - minst="/tmp/output.minst", - cinst="/tmp/output.cinst", - xinst="/tmp/output.xinst", - ) - - # Act - with patch("os.path.isfile", return_value=True), patch( - "he_link.makeUniquePath", side_effect=lambda x: x - ): - result = he_link.prepare_input_files(mock_config, mock_output_files) - - # Assert - assert len(result) == 2 - assert result[0].prefix == "/tmp/input1" - assert result[0].minst == "/tmp/input1.minst" - assert result[0].cinst == "/tmp/input1.cinst" - assert result[0].xinst == "/tmp/input1.xinst" - assert result[0].mem is None - assert result[1].prefix == "/tmp/input2" - - def test_prepare_input_files_file_not_found(self): - """ - @brief Test prepare_input_files when a file doesn't exist - """ - # Arrange - mock_config = MagicMock() - mock_config.input_prefixes = ["/tmp/input1"] - mock_config.multi_mem_files = False - - mock_output_files = he_link.KernelFiles( - prefix="/tmp/output", - minst="/tmp/output.minst", - cinst="/tmp/output.cinst", - xinst="/tmp/output.xinst", - ) - - # Act & Assert - with patch("os.path.isfile", return_value=False), patch( - "he_link.makeUniquePath", side_effect=lambda x: x - ): - with pytest.raises(FileNotFoundError): - he_link.prepare_input_files(mock_config, mock_output_files) - - def test_prepare_input_files_output_conflict(self): - """ - @brief Test prepare_input_files when input and output files conflict - """ - # Arrange - mock_config = MagicMock() - mock_config.input_prefixes = ["/tmp/input1"] - mock_config.multi_mem_files = False - - # Output file matching an input file - mock_output_files = he_link.KernelFiles( - prefix="/tmp/output", - minst="/tmp/input1.minst", # Conflict - cinst="/tmp/output.cinst", - xinst="/tmp/output.xinst", - ) - - # Act & Assert - with patch("os.path.isfile", return_value=True), patch( - "he_link.makeUniquePath", side_effect=lambda x: x - ): - with pytest.raises(RuntimeError): - he_link.prepare_input_files(mock_config, mock_output_files) - - @pytest.mark.parametrize("has_hbm", [True, False]) - def test_scan_variables(self, has_hbm): - """ - @brief Test scan_variables function with and without HBM - @param has_hbm Boolean indicating whether HBM is enabled - """ - # Arrange - GlobalConfig.hasHBM = has_hbm - mock_mem_model = MagicMock() - mock_verbose = MagicMock() - - input_files = [ - he_link.KernelFiles( - prefix="/tmp/input1", - minst="/tmp/input1.minst", - cinst="/tmp/input1.cinst", - xinst="/tmp/input1.xinst", - ) - ] - - # Act - with patch("linker.loader.load_minst_kernel_from_file", return_value=[]), patch( - "linker.loader.load_cinst_kernel_from_file", return_value=[] - ), patch( - "linker.steps.variable_discovery.discover_variables", - return_value=["var1", "var2"], - ), patch( - "linker.steps.variable_discovery.discover_variables_spad", - return_value=["var1", "var2"], - ): - he_link.scan_variables(input_files, mock_mem_model, mock_verbose) - - # Assert - if has_hbm: - assert mock_mem_model.add_variable.call_count == 2 - else: - assert mock_mem_model.add_variable.call_count == 2 - - def test_check_unused_variables(self): - """ - @brief Test check_unused_variables function - """ - # Arrange - GlobalConfig.hasHBM = True - mock_mem_model = MagicMock() - mock_mem_model.mem_info_vars = {"var1": MagicMock(), "var2": MagicMock()} - mock_mem_model.variables = {"var1"} - mock_mem_model.mem_info_meta = {} - - # Act & Assert - with pytest.raises(RuntimeError): - he_link.check_unused_variables(mock_mem_model) - - def test_link_kernels(self): - """ - @brief Test link_kernels function - """ - # Arrange - input_files = [ - he_link.KernelFiles( - prefix="/tmp/input1", - minst="/tmp/input1.minst", - cinst="/tmp/input1.cinst", - xinst="/tmp/input1.xinst", - ) - ] - - output_files = he_link.KernelFiles( - prefix="/tmp/output", - minst="/tmp/output.minst", - cinst="/tmp/output.cinst", - xinst="/tmp/output.xinst", - ) - - mock_mem_model = MagicMock() - mock_verbose = MagicMock() - - # Act - with patch("builtins.open", mock_open()), patch( - "linker.loader.load_minst_kernel_from_file", return_value=[] - ), patch("linker.loader.load_cinst_kernel_from_file", return_value=[]), patch( - "linker.loader.load_xinst_kernel_from_file", return_value=[] - ), patch( - "linker.steps.program_linker.LinkedProgram" - ) as mock_linked_program: - he_link.link_kernels( - input_files, output_files, mock_mem_model, mock_verbose - ) - - # Assert - mock_linked_program.assert_called_once() - instance = mock_linked_program.return_value - assert instance.link_kernel.call_count == 1 - assert instance.close.call_count == 1 +from linker.kern_trace import KernelInfo class TestMainFunction: @@ -447,37 +23,42 @@ class TestMainFunction: @brief Test cases for the main function """ - @pytest.mark.parametrize("multi_mem_files", [True, False]) - def test_main(self, multi_mem_files): + @pytest.mark.parametrize("using_trace_file", [True, False]) + def test_main(self, using_trace_file): """ - @brief Test main function with and without multi_mem_files + @brief Test main function with and without using_trace_file """ # Arrange mock_config = MagicMock() - mock_config.multi_mem_files = multi_mem_files + mock_config.using_trace_file = using_trace_file mock_config.has_hbm = True mock_config.hbm_size = 1024 mock_config.suppress_comments = False mock_config.use_xinstfetch = False - # Setup input files with conditional mem files + # The expected kernel name pattern from parse_kernel_ops + expected_kernel_name = "kernel1_pisa.tw" + + # Setup input files with conditional mem files - ensure prefix matches expected pattern input_files = [ - he_link.KernelFiles( - prefix="prefix1", - minst="prefix1.minst", - cinst="prefix1.cinst", - xinst="prefix1.xinst", - mem="prefix1.mem" if multi_mem_files else None, - ), - he_link.KernelFiles( - prefix="prefix2", - minst="prefix2.minst", - cinst="prefix2.cinst", - xinst="prefix2.xinst", - mem="prefix2.mem" if multi_mem_files else None, + KernelInfo( + { + "directory": "/tmp", + "prefix": expected_kernel_name, # Match the expected name pattern + "minst": f"{expected_kernel_name}.minst", + "cinst": f"{expected_kernel_name}.cinst", + "xinst": f"{expected_kernel_name}.xinst", + "mem": f"{expected_kernel_name}.mem" if using_trace_file else None, + } ), ] + # Create mock DInstructions with proper .var attributes + mock_dinstr1 = MagicMock() + mock_dinstr1.var = "ct0_data" + mock_dinstr2 = MagicMock() + mock_dinstr2.var = "pt1_result" + # Create a dictionary of mocks to reduce the number of local variables mocks = { "prepare_output": MagicMock(), @@ -487,11 +68,36 @@ def test_main(self, multi_mem_files): "link_kernels": MagicMock(), "from_dinstrs": MagicMock(), "from_file_iter": MagicMock(), - "load_dinst": MagicMock(return_value=["1", "2"]), + "load_dinst": MagicMock( + return_value=[mock_dinstr1, mock_dinstr2] + ), # Return mock DInstructions "join_dinst": MagicMock(return_value=[]), "dump_instructions": MagicMock(), + "remap_dinstrs_vars": MagicMock(return_value={"old_var": "new_var"}), + "update_input_prefixes": MagicMock( + return_value={"kernel1_pisa.tw": MagicMock()} + ), + "remap_vars": MagicMock( + return_value=([mock_dinstr1, mock_dinstr2], {"key": "value"}) + ), + "initialize_memory_model": MagicMock(), + # Return a kernel_op with expected_in_kern_file_name that will match our input file prefix + "parse_kernel_ops": MagicMock( + return_value=[ + MagicMock( + expected_in_kern_file_name="kernel1", + kern_vars=[ + MagicMock(label="input"), + MagicMock(label="output"), + ], # Add mock kern_vars + ) + ] + ), } + # Add trace_file property to mock_config + mock_config.trace_file = "mock_trace.txt" if using_trace_file else "" + # Act with patch( "assembler.common.constants.convertBytes2Words", return_value=1024 @@ -500,7 +106,7 @@ def test_main(self, multi_mem_files): ), patch( "assembler.common.counter.Counter.reset" ), patch( - "linker.loader.load_dinst_kernel_from_file", mocks["load_dinst"] + "he_link.Loader.load_dinst_kernel_from_file", mocks["load_dinst"] ), patch( "linker.instructions.BaseInstruction.dump_instructions_to_file", mocks["dump_instructions"], @@ -520,12 +126,24 @@ def test_main(self, multi_mem_files): ), patch( "he_link.check_unused_variables", mocks["check_unused_variables"] ), patch( - "he_link.link_kernels", mocks["link_kernels"] + "linker.kern_trace.TraceInfo.parse_kernel_ops", mocks["parse_kernel_ops"] ), patch( - "he_link.BaseInstruction.dump_instructions_to_file", - mocks["dump_instructions"], + "os.path.isfile", + return_value=True, # Make all file existence checks return True + ), patch( + "linker.steps.program_linker.LinkedProgram.link_kernels_to_files", + mocks["link_kernels"], + ), patch( + "linker.kern_trace.remap_dinstrs_vars", mocks["remap_dinstrs_vars"] + ), patch( + "he_link.update_input_prefixes", mocks["update_input_prefixes"] + ), patch( + "he_link.remap_vars", mocks["remap_vars"] + ), patch( + "he_link.initialize_memory_model", mocks["initialize_memory_model"] ): + # Run the main function with all patches in place he_link.main(mock_config, MagicMock()) # Assert pipeline is run as expected @@ -535,18 +153,17 @@ def test_main(self, multi_mem_files): mocks["check_unused_variables"].assert_called_once() mocks["link_kernels"].assert_called_once() - if multi_mem_files: - # Should use from_dinstrs, not from_file_iter - assert mocks["from_dinstrs"].called - assert mocks["load_dinst"].called - assert mocks["join_dinst"].called - assert mocks["dump_instructions"].called - + if using_trace_file: + # Assert that the trace processing flow was used + mocks["update_input_prefixes"].assert_called_once() + mocks["remap_vars"].assert_called_once() + mocks["initialize_memory_model"].assert_called_once() assert not mocks["from_file_iter"].called else: - # Should use from_file_iter, not from_dinstrs - assert mocks["from_file_iter"].called - assert not mocks["from_dinstrs"].called + # Assert that the normal flow was used + assert not mocks["update_input_prefixes"].called + assert not mocks["remap_vars"].called + mocks["initialize_memory_model"].assert_called_once() def test_warning_on_use_xinstfetch(self): """ @@ -554,7 +171,7 @@ def test_warning_on_use_xinstfetch(self): """ # Arrange mock_config = MagicMock() - mock_config.multi_mem_files = False + mock_config.using_trace_file = False mock_config.has_hbm = True mock_config.hbm_size = 1024 mock_config.suppress_comments = False @@ -564,8 +181,8 @@ def test_warning_on_use_xinstfetch(self): # Act & Assert with patch("warnings.warn") as mock_warn, patch( "assembler.common.constants.convertBytes2Words", return_value=1024 - ), patch("he_link.prepare_output_files"), patch( - "he_link.prepare_input_files" + ), patch("linker.he_link_utils.prepare_output_files"), patch( + "linker.he_link_utils.prepare_input_files" ), patch( "assembler.common.counter.Counter.reset" ), patch( @@ -575,11 +192,11 @@ def test_warning_on_use_xinstfetch(self): ), patch( "linker.MemoryModel" ), patch( - "he_link.scan_variables" + "linker.steps.variable_discovery.scan_variables" ), patch( - "he_link.check_unused_variables" + "linker.steps.variable_discovery.check_unused_variables" ), patch( - "he_link.link_kernels" + "linker.steps.program_linker.LinkedProgram.link_kernels_to_files" ): he_link.main(mock_config, None) mock_warn.assert_called_once() @@ -595,25 +212,17 @@ def test_parse_args_minimal(self): """ @brief Test parse_args with minimal arguments """ - # Arrange - test_args = [ - "program", - "input_prefix", - "-o", - "output_prefix", - "-im", - "input.mem", - ] - - # Act - with patch("sys.argv", test_args), patch( + # Act - Mock the return value of parse_args directly + with patch( "argparse.ArgumentParser.parse_args", return_value=argparse.Namespace( input_prefixes=["input_prefix"], output_prefix="output_prefix", input_mem_file="input.mem", + trace_file="", + input_dir="", output_dir="", - multi_mem_files=False, + using_trace_file=False, mem_spec_file="", isa_spec_file="", has_hbm=True, @@ -628,30 +237,23 @@ def test_parse_args_minimal(self): assert args.input_prefixes == ["input_prefix"] assert args.output_prefix == "output_prefix" assert args.input_mem_file == "input.mem" - assert args.multi_mem_files is False + assert args.using_trace_file is False - def test_parse_args_multi_mem_files(self): + def test_parse_args_using_trace_file(self): """ - @brief Test parse_args with multi_mem_files flag + @brief Test parse_args with using_trace_file flag """ - # Arrange - test_args = [ - "program", - "input_prefix", - "-o", - "output_prefix", - "--multi_mem_files", - ] - - # Act - with patch("sys.argv", test_args), patch( + # Act - Mock the return value of parse_args directly + with patch( "argparse.ArgumentParser.parse_args", return_value=argparse.Namespace( - input_prefixes=["input_prefix"], + input_prefixes=None, output_prefix="output_prefix", input_mem_file="", + input_dir="", + trace_file="trace_file_path", output_dir="", - multi_mem_files=True, + using_trace_file=None, # This should be computed by parse_args function mem_spec_file="", isa_spec_file="", has_hbm=True, @@ -663,27 +265,76 @@ def test_parse_args_multi_mem_files(self): args = he_link.parse_args() # Assert - assert args.input_prefixes == ["input_prefix"] assert args.output_prefix == "output_prefix" - assert args.input_mem_file == "" - assert args.multi_mem_files is True + assert args.trace_file == "trace_file_path" + assert args.using_trace_file is True # Should be computed from trace_file - def test_missing_input_mem_file(self): + def test_trace_file_with_missing_output_prefix(self): """ - @brief Test parse_args with missing input_mem_file when multi_mem_files is False + @brief Test parse_args when trace_file is provided but output_prefix (always required) is missing """ - # Arrange - test_args = ["program", "input_prefix", "-o", "output_prefix"] + # Instead of manually creating a namespace with missing required arguments, + # we'll create an argv list that's missing the required argument + mock_argv = ["he_link.py", "--use_trace_file", "trace_file_path"] - # Act & Assert - with patch("sys.argv", test_args), patch( + # Create a StringIO to capture the error output + error_output = io.StringIO() + + # Patch sys.argv and sys.stderr + with patch("sys.argv", mock_argv), patch("sys.stderr", error_output), patch( + "sys.exit" + ) as mock_exit: + # When required args are missing, argparse will call sys.exit() + he_link.parse_args() + + # Verify that exit was called (indicating an error) + mock_exit.assert_called() + + # Verify the error output contains information about the missing required argument + error_message = error_output.getvalue() + assert "output_prefix" in error_message + assert "required" in error_message.lower() + + def test_required_args_when_trace_file_not_set(self): + """ + @brief Test that input_mem_file and input_prefixes are required when trace_file is not set + """ + # Case 1: Missing input_mem_file + with patch( "argparse.ArgumentParser.parse_args", return_value=argparse.Namespace( input_prefixes=["input_prefix"], output_prefix="output_prefix", - input_mem_file="", + input_mem_file="", # Empty input_mem_file + trace_file="", # No trace file + input_dir="", + output_dir="", + using_trace_file=False, + mem_spec_file="", + isa_spec_file="", + has_hbm=True, + hbm_size=None, + suppress_comments=False, + verbose=0, + ), + ), patch("argparse.ArgumentParser.error") as mock_error: + he_link.parse_args() + # Verify error was called for missing input_mem_file + mock_error.assert_called_once_with( + "the following arguments are required: -im/--input_mem_file (unless --use_trace_file is set)" + ) + + # Case 2: Missing input_prefixes + with patch( + "argparse.ArgumentParser.parse_args", + return_value=argparse.Namespace( + input_prefixes=None, # Missing input_prefixes + output_prefix="output_prefix", + input_mem_file="input.mem", + trace_file="", # No trace file + input_dir="", output_dir="", - multi_mem_files=False, + using_trace_file=False, mem_spec_file="", isa_spec_file="", has_hbm=True, @@ -693,4 +344,260 @@ def test_missing_input_mem_file(self): ), ), patch("argparse.ArgumentParser.error") as mock_error: he_link.parse_args() - mock_error.assert_called_once() + # Verify error was called for missing input_prefixes + mock_error.assert_called_once_with( + "the following arguments are required: -ip/--input_prefixes (unless --use_trace_file is set)" + ) + + def test_ignored_args_when_trace_file_set(self): + """ + @brief Test that input_mem_file and input_prefixes are ignored with warnings when trace_file is set + """ + # Both input_mem_file and input_prefixes are provided but should be ignored + with patch( + "argparse.ArgumentParser.parse_args", + return_value=argparse.Namespace( + input_prefixes=["input_prefix"], # Will be ignored + output_prefix="output_prefix", + input_mem_file="input.mem", # Will be ignored + trace_file="trace_file_path", # Trace file is provided + input_dir="", + output_dir="", + using_trace_file=None, # Will be computed + mem_spec_file="", + isa_spec_file="", + has_hbm=True, + hbm_size=None, + suppress_comments=False, + verbose=0, + ), + ), patch("warnings.warn") as mock_warn: + args = he_link.parse_args() + + # Verify using_trace_file is set based on trace_file + assert args.using_trace_file is True + + # Verify warnings were issued for ignored arguments + assert mock_warn.call_count == 2 + # Check warning messages + warning_messages = [call.args[0] for call in mock_warn.call_args_list] + assert any("Ignoring input_mem_file" in msg for msg in warning_messages) + assert any("Ignoring input_prefixes" in msg for msg in warning_messages) + + def test_hbm_flags_parsing(self): + """ + @brief Test the parsing of --hbm_size and --no_hbm flags + """ + # Test with hbm_size set to valid value + with patch( + "argparse.ArgumentParser.parse_args", + return_value=argparse.Namespace( + input_prefixes=["input_prefix"], + output_prefix="output_prefix", + input_mem_file="input.mem", + trace_file="", + input_dir="", + output_dir="", + using_trace_file=False, + mem_spec_file="", + isa_spec_file="", + has_hbm=True, + hbm_size=2048, # Valid hbm_size + suppress_comments=False, + verbose=0, + ), + ): + args = he_link.parse_args() + assert args.hbm_size == 2048 + assert args.has_hbm is True + + # Test with --no_hbm flag set + with patch( + "argparse.ArgumentParser.parse_args", + return_value=argparse.Namespace( + input_prefixes=["input_prefix"], + output_prefix="output_prefix", + input_mem_file="input.mem", + trace_file="", + input_dir="", + output_dir="", + using_trace_file=False, + mem_spec_file="", + isa_spec_file="", + has_hbm=False, # --no_hbm flag set + hbm_size=None, + suppress_comments=False, + verbose=0, + ), + ): + args = he_link.parse_args() + assert args.has_hbm is False + + # Test with hbm_size set to 0 + with patch( + "argparse.ArgumentParser.parse_args", + return_value=argparse.Namespace( + input_prefixes=["input_prefix"], + output_prefix="output_prefix", + input_mem_file="input.mem", + trace_file="", + input_dir="", + output_dir="", + using_trace_file=False, + mem_spec_file="", + isa_spec_file="", + has_hbm=True, + hbm_size=0, # Edge case: zero + suppress_comments=False, + verbose=0, + ), + ): + args = he_link.parse_args() + assert args.hbm_size == 0 + + def test_verbose_flag_parsing(self): + """ + @brief Test the parsing of -v/--verbose flag at different levels + """ + # Test with no verbose flag (default) + with patch( + "argparse.ArgumentParser.parse_args", + return_value=argparse.Namespace( + input_prefixes=["input_prefix"], + output_prefix="output_prefix", + input_mem_file="input.mem", + trace_file="", + input_dir="", + output_dir="", + using_trace_file=False, + mem_spec_file="", + isa_spec_file="", + has_hbm=True, + hbm_size=None, + suppress_comments=False, + verbose=0, # Default level + ), + ): + args = he_link.parse_args() + assert args.verbose == 0 + + # Test with single -v flag + with patch( + "argparse.ArgumentParser.parse_args", + return_value=argparse.Namespace( + input_prefixes=["input_prefix"], + output_prefix="output_prefix", + input_mem_file="input.mem", + trace_file="", + input_dir="", + output_dir="", + using_trace_file=False, + mem_spec_file="", + isa_spec_file="", + has_hbm=True, + hbm_size=None, + suppress_comments=False, + verbose=1, # Single -v + ), + ): + args = he_link.parse_args() + assert args.verbose == 1 + + # Test with double -vv flag + with patch( + "argparse.ArgumentParser.parse_args", + return_value=argparse.Namespace( + input_prefixes=["input_prefix"], + output_prefix="output_prefix", + input_mem_file="input.mem", + trace_file="", + input_dir="", + output_dir="", + using_trace_file=False, + mem_spec_file="", + isa_spec_file="", + has_hbm=True, + hbm_size=None, + suppress_comments=False, + verbose=2, # Double -vv + ), + ): + args = he_link.parse_args() + assert args.verbose == 2 + + # Test with high verbosity + with patch( + "argparse.ArgumentParser.parse_args", + return_value=argparse.Namespace( + input_prefixes=["input_prefix"], + output_prefix="output_prefix", + input_mem_file="input.mem", + trace_file="", + input_dir="", + output_dir="", + using_trace_file=False, + mem_spec_file="", + isa_spec_file="", + has_hbm=True, + hbm_size=None, + suppress_comments=False, + verbose=5, # High verbosity + ), + ): + args = he_link.parse_args() + assert args.verbose == 5 + + def test_input_dir_defaults_to_trace_file_directory(self): + """ + @brief Test that input_dir defaults to the directory of trace_file when not specified + """ + # Test with trace_file set but input_dir not set + with patch( + "argparse.ArgumentParser.parse_args", + return_value=argparse.Namespace( + input_prefixes=None, + output_prefix="output_prefix", + input_mem_file="", + input_dir="", # Not specified + trace_file="/path/to/trace_file.txt", # Trace file with a directory path + output_dir="", + using_trace_file=None, # Will be computed + mem_spec_file="", + isa_spec_file="", + has_hbm=True, + hbm_size=None, + suppress_comments=False, + verbose=0, + ), + ), patch("os.path.dirname", return_value="/path/to") as mock_dirname: + args = he_link.parse_args() + + # Verify input_dir is set to the directory of trace_file + mock_dirname.assert_called_once_with("/path/to/trace_file.txt") + assert args.input_dir == "/path/to" + + # Test with both trace_file and input_dir specified - input_dir should not be overwritten + with patch( + "argparse.ArgumentParser.parse_args", + return_value=argparse.Namespace( + input_prefixes=None, + output_prefix="output_prefix", + input_mem_file="", + input_dir="/custom/path", # Specified by user + trace_file="/path/to/trace_file.txt", + output_dir="", + using_trace_file=None, + mem_spec_file="", + isa_spec_file="", + has_hbm=True, + hbm_size=None, + suppress_comments=False, + verbose=0, + ), + ), patch("os.path.dirname") as mock_dirname: + args = he_link.parse_args() + + # Verify dirname was not called since input_dir was already specified + mock_dirname.assert_not_called() + # Input_dir should remain as specified + assert args.input_dir == "/custom/path" diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_he_link_utils.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_he_link_utils.py new file mode 100644 index 00000000..c46b889e --- /dev/null +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_he_link_utils.py @@ -0,0 +1,509 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +# These contents may have been developed with support from one +# or more Intel-operated generative artificial intelligence solutions + +""" +@file test_he_link_utils.py +@brief Unit tests for the he_link_utils module +""" +from unittest.mock import patch, mock_open, MagicMock +import pytest + +from linker.he_link_utils import ( + prepare_output_files, + prepare_input_files, + update_input_prefixes, + remap_vars, + initialize_memory_model, +) +from linker.kern_trace.trace_info import KernelInfo +from assembler.common import constants + + +class TestHelperFunctions: + """ + @class TestHelperFunctions + @brief Test cases for helper functions in he_link_utils + """ + + def test_prepare_output_files(self): + """ + @brief Test prepare_output_files function creates correct output files + """ + # Arrange + mock_config = MagicMock() + mock_config.output_dir = "/tmp" + mock_config.output_prefix = "output" + mock_config.using_trace_file = False + + # Act + with patch("os.path.dirname", return_value="/tmp"), patch( + "pathlib.Path.mkdir" + ), patch("assembler.common.makeUniquePath", side_effect=lambda x: x): + result = prepare_output_files(mock_config) + + # Assert + assert result.directory == "/tmp" + assert result.prefix == "output" + assert result.minst == "/tmp/output.minst" + assert result.cinst == "/tmp/output.cinst" + assert result.xinst == "/tmp/output.xinst" + assert result.mem is None + + def test_prepare_output_files_with_mem(self): + """ + @brief Test prepare_output_files with using_trace_file=True + """ + # Arrange + mock_config = MagicMock() + mock_config.output_dir = "/tmp" + mock_config.output_prefix = "output" + mock_config.using_trace_file = True + + # Act + with patch("os.path.dirname", return_value="/tmp"), patch( + "pathlib.Path.mkdir" + ), patch("assembler.common.makeUniquePath", side_effect=lambda x: x): + result = prepare_output_files(mock_config) + + # Assert + assert result.directory == "/tmp" + assert result.prefix == "output" + assert result.minst == "/tmp/output.minst" + assert result.cinst == "/tmp/output.cinst" + assert result.xinst == "/tmp/output.xinst" + assert result.mem == "/tmp/output.mem" + + def test_prepare_input_files(self): + """ + @brief Test prepare_input_files function + """ + # Arrange + mock_config = MagicMock() + mock_config.input_dir = "/tmp" + mock_config.input_prefixes = ["input1", "input2"] + mock_config.using_trace_file = False + + mock_output_files = KernelInfo( + { + "directory": "/tmp", + "prefix": "output", + "minst": "/tmp/output.minst", + "cinst": "/tmp/output.cinst", + "xinst": "/tmp/output.xinst", + } + ) + + # Act + with patch("os.path.isfile", return_value=True), patch( + "assembler.common.makeUniquePath", side_effect=lambda x: x + ): + result = prepare_input_files(mock_config, mock_output_files) + + # Assert + assert len(result) == 2 + assert result[0].directory == "/tmp" + assert result[0].prefix == "input1" + assert result[0].minst == "/tmp/input1.minst" + assert result[0].cinst == "/tmp/input1.cinst" + assert result[0].xinst == "/tmp/input1.xinst" + assert result[0].mem is None + assert result[1].prefix == "input2" + + def test_prepare_input_files_file_not_found(self): + """ + @brief Test prepare_input_files when a file doesn't exist + """ + # Arrange + mock_config = MagicMock() + mock_config.input_dir = "/tmp" + mock_config.input_prefixes = ["input1"] + mock_config.using_trace_file = False + + mock_output_files = KernelInfo( + { + "directory": "/tmp", + "prefix": "output", + "minst": "/tmp/output.minst", + "cinst": "/tmp/output.cinst", + "xinst": "/tmp/output.xinst", + "mem": None, + } + ) + + # Act & Assert + with patch("os.path.isfile", return_value=False), patch( + "assembler.common.makeUniquePath", side_effect=lambda x: x + ): + with pytest.raises(FileNotFoundError): + prepare_input_files(mock_config, mock_output_files) + + def test_prepare_input_files_output_conflict(self): + """ + @brief Test prepare_input_files when input and output files conflict + """ + # Arrange + mock_config = MagicMock() + mock_config.input_dir = "/tmp" + mock_config.input_prefixes = ["input1"] + mock_config.using_trace_file = False + + # Output file matching an input file + output_files = KernelInfo( + { + "directory": "/tmp", + "prefix": "output", + "minst": "/tmp/input1.minst", # Conflict + "cinst": "/tmp/output.cinst", + "xinst": "/tmp/output.xinst", + "mem": None, + } + ) + + # Act & Assert + with patch("os.path.isfile", return_value=True), patch( + "assembler.common.makeUniquePath", side_effect=lambda x: x + ): + with pytest.raises(RuntimeError): + prepare_input_files(mock_config, output_files) + + def test_update_input_prefixes(self): + """ + @brief Test update_input_prefixes correctly processes a trace file and returns kernel operations dictionary + """ + # Arrange + mock_config = MagicMock() + mock_config.input_prefixes = [] + + # Create mock kernel ops with expected_in_kern_file_name attribute + mock_kernel_op1 = MagicMock() + mock_kernel_op1.expected_in_kern_file_name = "kernel1" + + mock_kernel_op2 = MagicMock() + mock_kernel_op2.expected_in_kern_file_name = "kernel2" + + mock_kernel_ops = [mock_kernel_op1, mock_kernel_op2] + + # Act + # with patch("linker.he_link_utils.TraceInfo") as mock_trace_info_class: + # Configure the mock TraceInfo instance + # mock_trace_info = mock_trace_info_class.return_value + # mock_trace_info.parse_kernel_ops.return_value = mock_kernel_ops + + # Call the function under test + update_input_prefixes(mock_kernel_ops, mock_config) + + # Assert + # Verify the input_prefixes were updated in the run_config + assert mock_config.input_prefixes == ["kernel1_pisa.tw", "kernel2_pisa.tw"] + + def test_update_input_prefixes_with_empty_kernel_ops(self): + """ + @brief Test update_input_prefixes correctly handles empty kernel operations + """ + # Arrange + mock_config = MagicMock() + mock_config.input_prefixes = ["should_be_cleared"] + mock_kernel_ops = [] # Empty list of kernel ops + + # Act + # Call the function under test with empty kernel_ops list + update_input_prefixes(mock_kernel_ops, mock_config) + + # Assert + # Verify the input_prefixes were updated (cleared) in the run_config + assert not mock_config.input_prefixes + + def _create_kernel_test_data(self): + """ + @brief Helper method to create test data for remap_vars tests + @return Tuple containing test data: (kernels_files, kernels_dinstrs, kernel_ops, expected_dicts) + """ + # Create mock kernel files + mock_files = [] + for i in range(1, 3): + kernel_file = MagicMock(spec=KernelInfo) + kernel_file.prefix = f"kernel{i}_pisa.tw" + kernel_file.mem = f"/path/to/kernel{i}.mem" + mock_files.append(kernel_file) + + # Create mock kernel operations + kernel_ops = [] + for i in range(1, 3): + kernel_op = MagicMock() + kernel_op.expected_in_kern_file_name = f"kernel{i}" + kernel_ops.append(kernel_op) + + # Create test dinstructions data + dinstrs = [] + for i in range(1, 4): + dinstr = MagicMock() + dinstr.var = f"var{i}" + dinstrs.append(dinstr) + + # Setup test data structures + kernel_dinstrs = [ + [dinstrs[0]], # kernel1 dinstrs + [dinstrs[1], dinstrs[2]], # kernel2 dinstrs + ] + + # Expected remap dictionaries + expected_dicts = { + "var1": "mapped_var1", + "var2": "mapped_var2", + "var3": "mapped_var3", + } + + # Pack test data + test_data = { + "files": mock_files, + "kernel_ops": kernel_ops, + "dinstrs": dinstrs, + "kernel_dinstrs": kernel_dinstrs, + "expected_dicts": expected_dicts, + "joined_dinstrs": dinstrs, # All dinstrs joined + } + + return test_data + + def test_remap_vars_with_multiple_kernels(self): + """ + @brief Test remap_vars with multiple input kernel files + """ + # Arrange - Get test data from helper method + test_data = self._create_kernel_test_data() + + # Act + with patch("linker.he_link_utils.remap_dinstrs_vars") as mock_remap_vars: + + # Configure mocks + mock_remap_vars.side_effect = [ + test_data["expected_dicts"], + test_data["expected_dicts"], + ] + + # Call function under test + remap_vars( + test_data["files"], + test_data["kernel_dinstrs"], + test_data["kernel_ops"], + MagicMock(), + ) + + # Assert + # Verify remap_dinstrs_vars was called for each kernel with the correct arguments + assert mock_remap_vars.call_count == 2 + + # First call + mock_remap_vars.assert_any_call( + test_data["kernel_dinstrs"][0], test_data["kernel_ops"][0] + ) + + # Second call + mock_remap_vars.assert_any_call( + test_data["kernel_dinstrs"][1], test_data["kernel_ops"][1] + ) + + # Verify the remap_dict was set on each kernel file + assert test_data["files"][0].remap_dict == test_data["expected_dicts"] + assert test_data["files"][1].remap_dict == test_data["expected_dicts"] + + def test_remap_vars_with_mismatched_prefixes(self): + """ + @brief Test remap_vars correctly handles mismatched prefixes + """ + # Arrange + mock_files = [MagicMock(spec=KernelInfo)] + mock_files[0].prefix = "kernel1_pisa.tw" + + kernel_ops = [MagicMock()] + kernel_ops[0].expected_in_kern_file_name = "different_kernel" + + kernel_dinstrs = [[MagicMock()]] + + # Act & Assert + with pytest.raises(AssertionError, match="prefix .* does not match"): + remap_vars(mock_files, kernel_dinstrs, kernel_ops, MagicMock()) + + def test_remap_vars_with_empty_input(self): + """ + @brief Test remap_vars with an empty list of input files + """ + # Arrange + kernel_files = [] + kernel_dinstrs = [] + kernel_ops = [] + verbose_stream = MagicMock() + + # Act + # No exception should be raised for empty inputs + remap_vars(kernel_files, kernel_dinstrs, kernel_ops, verbose_stream) + + # Assert + # Just verifying the function completes without error + + def test_remap_vars_length_mismatch(self): + """ + @brief Test remap_vars correctly handles mismatched lengths + """ + # Arrange - mismatched lengths between files and ops + kernel_files = [MagicMock(), MagicMock()] + kernel_dinstrs = [[MagicMock()]] + kernel_ops = [MagicMock()] + + # Act & Assert + with pytest.raises(AssertionError, match="Number of kernels_files must match"): + remap_vars(kernel_files, kernel_dinstrs, kernel_ops, MagicMock()) + + # Arrange - mismatched lengths between dinstrs and ops + kernel_files = [MagicMock()] + kernel_dinstrs = [[MagicMock()], [MagicMock()]] + kernel_ops = [MagicMock()] + + # Act & Assert + with pytest.raises(AssertionError, match="Number of kernel_dinstrs must match"): + remap_vars(kernel_files, kernel_dinstrs, kernel_ops, MagicMock()) + + def test_initialize_memory_model_with_kernel_dinstrs(self): + """ + @brief Test initialize_memory_model when kernel_dinstrs is provided (trace file mode) + """ + # Arrange + mock_config = MagicMock() + mock_config.hbm_size = 1024 + + # Create mock kernel DInstructions + mock_dinstrs = [MagicMock(), MagicMock()] + + # Create mock mem_meta_info + mock_mem_info = MagicMock() + + # Create mock verbose stream + mock_stream = MagicMock() + + # Act + with patch( + "assembler.common.constants.convertBytes2Words", return_value=1024 * 1024 + ) as mock_convert, patch( + "assembler.memory_model.mem_info.MemInfo.from_dinstrs", + return_value=mock_mem_info, + ) as mock_from_dinstrs, patch( + "linker.MemoryModel" + ) as mock_memory_model_class: + + # Configure mock memory model + mock_memory_model = mock_memory_model_class.return_value + mock_memory_model.hbm.capacity = 1024 * 1024 + + # Call function under test + result = initialize_memory_model(mock_config, mock_dinstrs, mock_stream) + + # Assert + # Verify convertBytes2Words was called with correct parameters + mock_convert.assert_called_once_with( + mock_config.hbm_size * constants.Constants.KILOBYTE + ) + + # Verify from_dinstrs was called with kernel_dinstrs + mock_from_dinstrs.assert_called_once_with(mock_dinstrs) + + # Verify MemoryModel was initialized with correct parameters + mock_memory_model_class.assert_called_once_with(1024 * 1024, mock_mem_info) + + # Verify output was written to the verbose stream + assert mock_stream.write.call_count >= 1 + + # Verify the result is the mock memory model + assert result is mock_memory_model + + def test_initialize_memory_model_with_input_mem_file(self): + """ + @brief Test initialize_memory_model when reading from input_mem_file (standard mode) + """ + # Arrange + mock_config = MagicMock() + mock_config.hbm_size = 2048 + mock_config.input_mem_file = "/path/to/input.mem" + + # Create mock mem_meta_info + mock_mem_info = MagicMock() # Create mock verbose stream + mock_stream = MagicMock() + + # Act + with patch( + "assembler.common.constants.convertBytes2Words", return_value=2048 * 1024 + ) as mock_convert, patch("builtins.open", mock_open()) as mock_open_file, patch( + "assembler.memory_model.mem_info.MemInfo.from_file_iter", + return_value=mock_mem_info, + ) as mock_from_file_iter, patch( + "linker.MemoryModel" + ) as mock_memory_model_class: + + # Configure mock memory model + mock_memory_model = mock_memory_model_class.return_value + mock_memory_model.hbm.capacity = 2048 * 1024 + + # Call function under test + result = initialize_memory_model(mock_config, None, mock_stream) + + # Assert + # Verify convertBytes2Words was called with correct parameters + mock_convert.assert_called_once_with( + mock_config.hbm_size * constants.Constants.KILOBYTE + ) + + # Verify open was called with input_mem_file + mock_open_file.assert_called_once_with( + mock_config.input_mem_file, "r", encoding="utf-8" + ) + + # Verify from_file_iter was called + assert mock_from_file_iter.called + + # Verify MemoryModel was initialized with correct parameters + mock_memory_model_class.assert_called_once_with(2048 * 1024, mock_mem_info) + + # Verify output was written to the verbose stream + assert mock_stream.write.call_count >= 1 + + # Verify the result is the mock memory model + assert result is mock_memory_model + + def test_initialize_memory_model_with_zero_hbm_size(self): + """ + @brief Test initialize_memory_model with hbm_size=0 + """ + # Arrange + mock_config = MagicMock() + mock_config.hbm_size = 0 # Zero HBM size + + # Create mock kernel DInstructions + mock_dinstrs = [MagicMock()] + + # Create mock mem_meta_info + mock_mem_info = MagicMock() + + # Act + with patch( + "assembler.common.constants.convertBytes2Words", return_value=0 + ) as mock_convert, patch( + "assembler.memory_model.mem_info.MemInfo.from_dinstrs", + return_value=mock_mem_info, + ), patch( + "linker.MemoryModel" + ) as mock_memory_model_class: + + # Call function under test + result = initialize_memory_model(mock_config, mock_dinstrs) + + # Assert + # Verify convertBytes2Words was called with 0 + mock_convert.assert_called_once_with(0) + + # Verify MemoryModel was initialized with hbm_capacity_words=0 + mock_memory_model_class.assert_called_once_with(0, mock_mem_info) + + # Verify the result is the mock memory model + assert result is mock_memory_model_class.return_value diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_dinstruction.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_dinstruction.py index a6c2cf0a..933f8ee1 100644 --- a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_dinstruction.py +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_dinstruction.py @@ -9,6 +9,7 @@ """ import unittest +from unittest.mock import patch, MagicMock from linker.instructions.dinst.dinstruction import DInstruction @@ -22,6 +23,17 @@ class TestDInstruction(unittest.TestCase): """ def setUp(self): + # Create a mock MemInfoVar class for testing + self.mock_miv = MagicMock() + self.mock_miv.as_dict.return_value = {"var_name": "var1", "hbm_address": 123} + + # Patch the MemInfo.get_meminfo_var_from_tokens method + self.mem_info_patcher = patch( + "linker.instructions.dinst.dinstruction.MemInfo.get_meminfo_var_from_tokens" + ) + self.mock_get_meminfo = self.mem_info_patcher.start() + self.mock_get_meminfo.return_value = (self.mock_miv, 1) + # Create a concrete subclass for testing since DInstruction is abstract class ConcreteDInstruction(DInstruction): """ @@ -37,13 +49,17 @@ def _get_num_tokens(cls) -> int: @classmethod def _get_name(cls) -> str: - return "test_instruction" + return "dload" - self.d_instruction_class = ConcreteDInstruction # Changed to snake_case - self.tokens = ["test_instruction", "var1", "123"] + self.d_instruction_class = ConcreteDInstruction + self.tokens = ["dload", "var1", "123"] self.comment = "Test comment" self.dinst = self.d_instruction_class(self.tokens, self.comment) + def tearDown(self): + # Stop the patcher + self.mem_info_patcher.stop() + def test_get_name_token_index(self): """@brief Test _get_name_token_index returns 0 @@ -95,16 +111,6 @@ def test_id_property(self): inst2 = self.d_instruction_class(self.tokens) self.assertNotEqual(inst1.id, inst2.id) - def test_to_line_method(self): - """@brief Test to_line method returns expected string - - @test Verifies the to_line method correctly formats the instruction as a string - """ - tokens = ["test_instruction", "var1", "123"] - inst = self.d_instruction_class(tokens, "") - expected = "test_instruction, var1, 123" - self.assertEqual(inst.to_line(), expected) - def test_consecutive_ids(self): """@brief Test that consecutive instructions get incremental ids @@ -114,6 +120,39 @@ def test_consecutive_ids(self): inst2 = self.d_instruction_class(self.tokens) self.assertEqual(inst2.id, inst1.id + 1) + def test_var_and_address_properties(self): + """@brief Test var and address properties are correctly set from MemInfo + + @test Verifies the var and address properties are set from MemInfo during initialization + """ + # Check that var and address were set from the mock MemInfo data + self.assertEqual(self.dinst.var, "var1") + self.assertEqual(self.dinst.address, 123) + + # Test property setters + self.dinst.var = "new_var" + self.assertEqual(self.dinst.var, "new_var") + + self.dinst.address = 456 + self.assertEqual(self.dinst.address, 456) + + def test_memory_info_error_handling(self): + """@brief Test error handling when MemInfo parsing fails + + @test Verifies that when MemInfo parsing fails, a ValueError is raised + with information about the parsing failure + """ + # Make the mock raise an exception + error_message = "Test error" + self.mock_get_meminfo.side_effect = RuntimeError(error_message) + + # The DInstruction.__init__ should convert RuntimeError to ValueError + with self.assertRaises(ValueError) as context: + self.d_instruction_class(self.tokens, self.comment) + + # Verify the error message contains the original error + self.assertIn(error_message, str(context.exception)) + if __name__ == "__main__": unittest.main() diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_dload.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_dload.py index c6168790..601a4aae 100644 --- a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_dload.py +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_dload.py @@ -27,7 +27,7 @@ def setUp(self): # Create the instruction self.var_name = "test_var" self.address = 123 - self.type = "type1" + self.type = "poly" def test_get_num_tokens(self): """@brief Test that _get_num_tokens returns 3 @@ -59,7 +59,8 @@ def test_initialization_valid_meta(self): @test Verifies the instruction handles metadata loading correctly """ - inst = Instruction([MemInfo.Const.Keyword.LOAD, self.type, str(self.address)]) + metadata = "ones" + inst = Instruction([MemInfo.Const.Keyword.LOAD, metadata, str(self.address)]) self.assertEqual(inst.name, MemInfo.Const.Keyword.LOAD) @@ -86,8 +87,6 @@ def test_tokens_property(self): [Instruction.name, self.type, str(self.address), self.var_name] ) - # Manually set properties to match expected behavior - inst.address = self.address self.assertEqual(inst.tokens, expected_tokens) def test_tokens_with_additional_data(self): diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_kern_trace/test_kern_remap.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_kern_trace/test_kern_remap.py new file mode 100644 index 00000000..2bbc8f09 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_kern_trace/test_kern_remap.py @@ -0,0 +1,336 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +# These contents may have been developed with support from one +# or more Intel-operated generative artificial intelligence solutions + +""" +@file test_kern_remap.py +@brief Unit tests for the kern_remap module +""" +from unittest.mock import MagicMock +import pytest + +from linker.kern_trace.kern_remap import remap_dinstrs_vars, remap_m_c_instrs_vars +from linker.kern_trace.kern_var import KernVar +from linker.kern_trace.kernel_op import KernelOp +from linker.instructions.minst import MLoad, MStore +from linker.instructions.cinst import BLoad, CLoad, BOnes, NLoad, CStore + + +class TestRemapDinstrsVars: + """ + @class TestRemapDinstrsVars + @brief Test cases for the remap_dinstrs_vars function + """ + + def _create_mock_kernel_op(self): + """ + @brief Helper method to create a mock KernelOp with sorted variables + """ + mock_kernel_op = MagicMock(spec=KernelOp) + + # Create mock KernVar objects for the kern_vars property + mock_vars = [ + KernVar("input", 8192, 2), + KernVar("output", 8192, 2), + KernVar("temp", 8192, 2), + ] + + # Configure the mock to return the sorted variables + mock_kernel_op.kern_vars = mock_vars + + return mock_kernel_op + + def test_remap_ct_variables(self): + """ + @brief Test remapping of CT (ciphertext) variables + """ + # Arrange + # Create mock DInstructions with CT variable names + dinstr1 = MagicMock() + dinstr1.var = "ct0_data" + + dinstr2 = MagicMock() + dinstr2.var = "ct1_result" + + kernel_dinstrs = [dinstr1, dinstr2] + + # Create mock KernelOp + mock_kernel_op = self._create_mock_kernel_op() + + # Act + result = remap_dinstrs_vars(kernel_dinstrs, mock_kernel_op) + + # Assert + assert dinstr1.var == "input_data" # ct0 -> input (index 0) + assert dinstr2.var == "output_result" # ct1 -> output (index 1) + assert result == {"ct0_data": "input_data", "ct1_result": "output_result"} + + def test_remap_pt_variables(self): + """ + @brief Test remapping of PT (plaintext) variables + """ + # Arrange + # Create mock DInstructions with PT variable names + dinstr1 = MagicMock() + dinstr1.var = "pt0_data" + + dinstr2 = MagicMock() + dinstr2.var = "pt2_result" + + kernel_dinstrs = [dinstr1, dinstr2] + + # Create mock KernelOp + mock_kernel_op = self._create_mock_kernel_op() + + # Act + result = remap_dinstrs_vars(kernel_dinstrs, mock_kernel_op) + + # Assert + assert dinstr1.var == "input_data" # pt0 -> input (index 0) + assert dinstr2.var == "temp_result" # pt2 -> temp (index 2) + assert result == {"pt0_data": "input_data", "pt2_result": "temp_result"} + + def test_skip_non_ct_pt_variables(self): + """ + @brief Test that variables with prefixes other than CT/PT are skipped + """ + # Arrange + # Create mock DInstructions with various variable names + dinstr1 = MagicMock() + dinstr1.var = "ct0_data" # Should be remapped + + dinstr2 = MagicMock() + dinstr2.var = "ntt_data" # Should be skipped + + dinstr3 = MagicMock() + dinstr3.var = "psi_data" # Should be skipped + + kernel_dinstrs = [dinstr1, dinstr2, dinstr3] + + # Create mock KernelOp + mock_kernel_op = self._create_mock_kernel_op() + + # Act + result = remap_dinstrs_vars(kernel_dinstrs, mock_kernel_op) + + # Assert + assert dinstr1.var == "input_data" # ct0 -> input (index 0) + assert dinstr2.var == "ntt_data" # Unchanged + assert dinstr3.var == "psi_data" # Unchanged + assert result == { + "ct0_data": "input_data", + } + + def test_case_insensitivity(self): + """ + @brief Test that CT/PT prefixes are case-insensitive + """ + # Arrange + dinstr1 = MagicMock() + dinstr1.var = "CT0_data" # Uppercase CT + + dinstr2 = MagicMock() + dinstr2.var = "Pt1_result" # Mixed case PT + + kernel_dinstrs = [dinstr1, dinstr2] + + # Create mock KernelOp + mock_kernel_op = self._create_mock_kernel_op() + + # Act + result = remap_dinstrs_vars(kernel_dinstrs, mock_kernel_op) + + # Assert + assert dinstr1.var == "input_data" # CT0 -> input (index 0) + assert dinstr2.var == "output_result" # Pt1 -> output (index 1) + assert result == {"CT0_data": "input_data", "Pt1_result": "output_result"} + + def test_error_when_no_underscore(self): + """ + @brief Test error when variable name doesn't contain underscore + """ + # Arrange + dinstr = MagicMock() + dinstr.var = "ct0data" # No underscore + + kernel_dinstrs = [dinstr] + mock_kernel_op = self._create_mock_kernel_op() + + # Act & Assert + with pytest.raises(ValueError, match="does not contain items to split by '_'"): + remap_dinstrs_vars(kernel_dinstrs, mock_kernel_op) + + def test_error_when_no_number_in_prefix(self): + """ + @brief Test error when prefix doesn't contain a number + """ + # Arrange + dinstr = MagicMock() + dinstr.var = "ct_data" # No number in prefix + + kernel_dinstrs = [dinstr] + mock_kernel_op = self._create_mock_kernel_op() + + # Act & Assert + with pytest.raises(ValueError, match="does not contain a number after text"): + remap_dinstrs_vars(kernel_dinstrs, mock_kernel_op) + + def test_error_when_index_out_of_range(self): + """ + @brief Test error when index is out of range of kernel variables + """ + # Arrange + # Create a simple MagicMock instead of using spec=DInstruction + dinstr = MagicMock() + dinstr.var = "ct5_data" # Index 5 is out of range (only 3 variables) + + kernel_dinstrs = [dinstr] + mock_kernel_op = self._create_mock_kernel_op() + + # Act & Assert + with pytest.raises(IndexError, match="out of range"): + remap_dinstrs_vars(kernel_dinstrs, mock_kernel_op) + + +class TestRemapMCInstrsVars: + """ + @class TestRemapMCInstrsVars + @brief Test cases for the remap_m_c_instrs_vars function + """ + + def _create_remap_dict(self): + """ + @brief Helper method to create a remap dictionary + """ + return {"old_source": "new_source", "old_dest": "new_dest"} + + def test_remap_m_load_instructions(self): + """ + @brief Test remapping variables in MLoad instructions + """ + # Arrange + mock_instr = MagicMock(spec=MLoad) + mock_instr.source = "old_source" + mock_instr.comment = "" + + kernel_instrs = [mock_instr] + remap_dict = self._create_remap_dict() + + # Act + remap_m_c_instrs_vars(kernel_instrs, remap_dict) + + # Assert + assert mock_instr.source == "new_source" + + def test_remap_m_store_instructions(self): + """ + @brief Test remapping variables in MStore instructions + """ + # Arrange + mock_instr = MagicMock(spec=MStore) + mock_instr.dest = "old_dest" + mock_instr.comment = "Store old_dest" + + kernel_instrs = [mock_instr] + remap_dict = self._create_remap_dict() + + # Act + remap_m_c_instrs_vars(kernel_instrs, remap_dict) + + # Assert + assert mock_instr.dest == "new_dest" + + def test_remap_c_load_instructions(self): + """ + @brief Test remapping variables in CLoad, BLoad, BOnes, and NLoad instructions + """ + # Arrange + c_instrs = [] + + # Create mock instructions of each type + for instr_class in [CLoad, BLoad, BOnes, NLoad]: + mock_instr = MagicMock(spec=instr_class) + mock_instr.source = "old_source" + mock_instr.comment = "" + c_instrs.append(mock_instr) + + remap_dict = self._create_remap_dict() + + # Act + remap_m_c_instrs_vars(c_instrs, remap_dict) + + # Assert + for instr in c_instrs: + assert instr.source == "new_source" + + def test_remap_c_store_instructions(self): + """ + @brief Test remapping variables in CStore instructions + """ + # Arrange + mock_instr = MagicMock(spec=CStore) + mock_instr.dest = "old_dest" + mock_instr.comment = "Store old_dest" + + kernel_instrs = [mock_instr] + remap_dict = self._create_remap_dict() + + # Act + remap_m_c_instrs_vars(kernel_instrs, remap_dict) + + # Assert + assert mock_instr.dest == "new_dest" + + def test_skip_unmapped_variables(self): + """ + @brief Test that variables not in the remap dictionary are not changed + """ + # Arrange + mock_load = MagicMock(spec=MLoad) + mock_load.source = "unmapped_source" + + mock_store = MagicMock(spec=MStore) + mock_store.dest = "unmapped_dest" + + kernel_instrs = [mock_load, mock_store] + remap_dict = self._create_remap_dict() + + # Act + remap_m_c_instrs_vars(kernel_instrs, remap_dict) + + # Assert + assert mock_load.source == "unmapped_source" # Unchanged + assert mock_store.dest == "unmapped_dest" # Unchanged + + def test_empty_remap_dict(self): + """ + @brief Test function with an empty remap dictionary + """ + # Arrange + mock_instr = MagicMock(spec=MLoad) + mock_instr.source = "source" + + kernel_instrs = [mock_instr] + remap_dict = {} # Empty dict + + # Act + remap_m_c_instrs_vars(kernel_instrs, remap_dict) + + # Assert + assert mock_instr.source == "source" # Unchanged + + def test_invalid_instruction_type(self): + """ + @brief Test error when instruction is not a valid M or C instruction + """ + # Arrange + mock_instr = MagicMock() # Not a proper instruction type + + kernel_instrs = [mock_instr] + remap_dict = self._create_remap_dict() + + # Act & Assert + with pytest.raises(TypeError, match="not a valid M or C Instruction"): + remap_m_c_instrs_vars(kernel_instrs, remap_dict) diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_kern_trace/test_kern_var.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_kern_trace/test_kern_var.py new file mode 100644 index 00000000..b079ffb6 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_kern_trace/test_kern_var.py @@ -0,0 +1,123 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +# These contents may have been developed with support from one +# or more Intel-operated generative artificial intelligence solutions + +""" +@file test_kern_var.py +@brief Unit tests for the KernVar class +""" +import pytest + +from linker.kern_trace.kern_var import KernVar + + +class TestKernVar: + """ + @class TestKernVar + @brief Test cases for the KernVar class + """ + + def test_init_with_valid_params(self): + """ + @brief Test initialization of KernVar with valid parameters + """ + # Arrange & Act + kern_var = KernVar("input", 8192, 3) + + # Assert + assert kern_var.label == "input" + assert kern_var.degree == 8192 + assert kern_var.level == 3 + + def test_from_string_with_valid_input(self): + """ + @brief Test from_string class method with valid input + """ + # Arrange + var_str = "input-8192-3" + + # Act + kern_var = KernVar.from_string(var_str) + + # Assert + assert kern_var.label == "input" + assert kern_var.degree == 8192 + assert kern_var.level == 3 + + def test_from_string_with_invalid_format(self): + """ + @brief Test from_string with invalid format (missing parts) + """ + # Arrange + invalid_var_strs = [ + "input", # Missing degree and level + "input-8192", # Missing level + "-8192-3", # Missing label + "input-8192-a", # Non digit + "input-d-0", # Non digit + "input-8192-3-extra", # Too many parts + ] + + # Act & Assert + for invalid_str in invalid_var_strs: + with pytest.raises(ValueError, match="Invalid"): + KernVar.from_string(invalid_str) + + def test_from_string_with_non_numeric_degree(self): + """ + @brief Test from_string with non-numeric degree + """ + # Arrange + var_str = "input-degree-3" + + # Act & Assert + with pytest.raises(ValueError): + KernVar.from_string(var_str) + + def test_from_string_with_non_numeric_level(self): + """ + @brief Test from_string with non-numeric level + """ + # Arrange + var_str = "input-8192-level" + + # Act & Assert + with pytest.raises(ValueError): + KernVar.from_string(var_str) + + def test_label_property_immutability(self): + """ + @brief Test that label property is immutable (read-only) + """ + # Arrange + kern_var = KernVar("input", 8192, 3) + + # Act & Assert + with pytest.raises(AttributeError): + kern_var.label = ( + "new_label" # Should raise AttributeError for read-only property + ) + + def test_degree_property_immutability(self): + """ + @brief Test that degree property is immutable (read-only) + """ + # Arrange + kern_var = KernVar("input", 8192, 3) + + # Act & Assert + with pytest.raises(AttributeError): + kern_var.degree = 4096 # Should raise AttributeError for read-only property + + def test_level_property_immutability(self): + """ + @brief Test that level property is immutable (read-only) + """ + # Arrange + kern_var = KernVar("input", 8192, 3) + + # Act & Assert + with pytest.raises(AttributeError): + kern_var.level = 1 # Should raise AttributeError for read-only property diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_kern_trace/test_kernel_op.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_kern_trace/test_kernel_op.py new file mode 100644 index 00000000..9bebf623 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_kern_trace/test_kernel_op.py @@ -0,0 +1,328 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +# These contents may have been developed with support from one +# or more Intel-operated generative artificial intelligence solutions + +""" +@file test_kernel_op.py +@brief Unit tests for the KernelOp class +""" +from unittest.mock import patch +import pytest + +from linker.kern_trace.kernel_op import KernelOp +from linker.kern_trace.context_config import ContextConfig +from linker.kern_trace.kern_var import KernVar + + +class TestKernelOp: + """ + @class TestKernelOp + @brief Test cases for the KernelOp class + """ + + def _create_test_context_config(self): + """ + @brief Helper method to create a test ContextConfig + """ + return ContextConfig(scheme="CKKS", poly_mod_degree=8192, keyrns_terms=2) + + def _create_test_kern_args(self): + """ + @brief Helper method to create test kernel arguments + """ + return ["input-8192-2", "output-8192-2"] + + def test_init_with_valid_params(self): + """ + @brief Test initialization of KernelOp with valid parameters + """ + # Arrange + context_config = self._create_test_context_config() + kern_args = self._create_test_kern_args() + + # Act + kernel_op = KernelOp("add", context_config, kern_args) + + # Assert + assert kernel_op.name == "add" + assert kernel_op.scheme == "CKKS" + assert kernel_op.poly_modulus_degree == 8192 + assert kernel_op.keyrns_terms == 2 + assert kernel_op.level == 2 # From the level in test args + assert len(kernel_op.kern_vars) == 2 + assert isinstance(kernel_op.kern_vars[0], KernVar) + assert isinstance(kernel_op.kern_vars[1], KernVar) + assert kernel_op.expected_in_kern_file_name == "ckks_add_8192_l2_m2" + + def test_init_with_invalid_kernel_operation_name(self): + """ + @brief Test initialization with invalid kernel operation name + """ + # Arrange + context_config = self._create_test_context_config() + kern_args = self._create_test_kern_args() + + # Act & Assert + with pytest.raises(ValueError, match="Invalid kernel operation name"): + KernelOp("invalid_op", context_config, kern_args) + + def test_init_with_invalid_encryption_scheme(self): + """ + @brief Test initialization with invalid encryption scheme + """ + # Arrange + invalid_context = ContextConfig( + scheme="INVALID", poly_mod_degree=8192, keyrns_terms=2 + ) + kern_args = self._create_test_kern_args() + + # Act & Assert + with pytest.raises(ValueError, match="Invalid encryption scheme"): + KernelOp("add", invalid_context, kern_args) + + def test_init_with_insufficient_arguments(self): + """ + @brief Test initialization with insufficient arguments + """ + # Arrange + context_config = self._create_test_context_config() + insufficient_args = ["input-8192-2"] # Only one argument + + # Act & Assert + with pytest.raises(ValueError, match="at least two arguments"): + KernelOp("add", context_config, insufficient_args) + + def test_get_kern_var_objs(self): + """ + @brief Test get_kern_var_objs method + """ + # Arrange + kernel_op = KernelOp( + "add", self._create_test_context_config(), self._create_test_kern_args() + ) + test_var_strs = ["var1-1024-1", "var2-2048-2"] + + # Act - Using the private method for testing + with patch( + "linker.kern_trace.kern_var.KernVar.from_string" + ) as mock_from_string: + mock_from_string.side_effect = [ + KernVar("var1", 1024, 1), + KernVar("var2", 2048, 2), + ] + result = kernel_op.get_kern_var_objs(test_var_strs) + + # Assert + assert len(result) == 2 + assert isinstance(result[0], KernVar) + assert isinstance(result[1], KernVar) + assert result[0].label == "var1" + assert result[1].label == "var2" + + def test_get_level(self): + """ + @brief Test get_level method + """ + # Arrange + kernel_op = KernelOp( + "add", self._create_test_context_config(), self._create_test_kern_args() + ) + + # Create test KernVar objects + test_vars = [KernVar("var1", 1024, 1), KernVar("var2", 2048, 3)] + + # Act - Using the private method for testing + result = kernel_op.get_level(test_vars) + + # Assert + assert result == 3 # Should use the level from the second variable + + def test_get_level_with_single_var(self): + """ + @brief Test get_level method with a single variable + """ + # Arrange + kernel_op = KernelOp( + "add", self._create_test_context_config(), self._create_test_kern_args() + ) + + # Create test KernVar objects + test_vars = [KernVar("var1", 1024, 2)] + + # Act - Using the private method for testing + result = kernel_op.get_level(test_vars) + + # Assert + assert result == 2 # Should use the level from the only variable + + def test_get_level_with_empty_vars(self): + """ + @brief Test get_level method with empty variables list + """ + # Arrange + kernel_op = KernelOp( + "add", self._create_test_context_config(), self._create_test_kern_args() + ) + + # Act & Assert + with pytest.raises(ValueError, match="at least one variable"): + kernel_op.get_level([]) + + def test_str_representation(self): + """ + @brief Test string representation of KernelOp + """ + # Arrange + kernel_op = KernelOp( + "add", self._create_test_context_config(), self._create_test_kern_args() + ) + + # Act + result = str(kernel_op) + + # Assert + assert "KernelOp" in result + assert "add" in result + + def test_property_kern_vars(self): + """ + @brief Test kern_vars property + """ + # Arrange + kernel_op = KernelOp( + "add", self._create_test_context_config(), self._create_test_kern_args() + ) + + # Act + result = kernel_op.kern_vars + + # Assert + assert len(result) == 2 + assert isinstance(result[0], KernVar) + assert isinstance(result[1], KernVar) + assert result[0].label == "input" + assert result[1].label == "output" + + def test_property_name(self): + """ + @brief Test name property + """ + # Arrange + kernel_op = KernelOp( + "mul", self._create_test_context_config(), self._create_test_kern_args() + ) + + # Act + result = kernel_op.name + + # Assert + assert result == "mul" + + def test_property_scheme(self): + """ + @brief Test scheme property + """ + # Arrange + context = ContextConfig(scheme="BFV", poly_mod_degree=4096, keyrns_terms=1) + kernel_op = KernelOp("add", context, self._create_test_kern_args()) + + # Act + result = kernel_op.scheme + + # Assert + assert result == "BFV" + + def test_property_poly_modulus_degree(self): + """ + @brief Test poly_modulus_degree property + """ + # Arrange + context = ContextConfig(scheme="CKKS", poly_mod_degree=16384, keyrns_terms=2) + kernel_op = KernelOp("add", context, self._create_test_kern_args()) + + # Act + result = kernel_op.poly_modulus_degree + + # Assert + assert result == 16384 + + def test_property_keyrns_terms(self): + """ + @brief Test keyrns_terms property + """ + # Arrange + context = ContextConfig(scheme="CKKS", poly_mod_degree=8192, keyrns_terms=3) + kernel_op = KernelOp("add", context, self._create_test_kern_args()) + + # Act + result = kernel_op.keyrns_terms + + # Assert + assert result == 3 + + def test_property_level(self): + """ + @brief Test level property + """ + # Arrange + kernel_op = KernelOp( + "add", self._create_test_context_config(), ["var1-8192-4", "var2-8192-4"] + ) + + # Act + result = kernel_op.level + + # Assert + assert result == 4 + + def test_property_expected_in_kern_file_name(self): + """ + @brief Test expected_in_kern_file_name property + """ + # Arrange + context = ContextConfig(scheme="BGV", poly_mod_degree=2048, keyrns_terms=1) + kernel_op = KernelOp("mul", context, ["var1-2048-5", "var2-2048-5"]) + + # Act + result = kernel_op.expected_in_kern_file_name + + # Assert + assert result == "bgv_mul_2048_l5_m1" + + def test_case_insensitivity_of_operation_name(self): + """ + @brief Test that operation names are case-insensitive + """ + # Arrange + context_config = self._create_test_context_config() + kern_args = self._create_test_kern_args() + + # Act + kernel_op = KernelOp( + "ADD", context_config, kern_args + ) # Uppercase operation name + + # Assert + assert kernel_op.name == "ADD" + assert ( + kernel_op.expected_in_kern_file_name == "ckks_add_8192_l2_m2" + ) # Note: lowercase in file name + + def test_case_insensitivity_of_scheme(self): + """ + @brief Test that scheme names are case-insensitive + """ + # Arrange + context = ContextConfig( + scheme="ckks", poly_mod_degree=8192, keyrns_terms=2 + ) # Lowercase scheme + kern_args = self._create_test_kern_args() + + # Act + kernel_op = KernelOp("add", context, kern_args) + + # Assert + assert kernel_op.scheme == "ckks" + assert kernel_op.expected_in_kern_file_name == "ckks_add_8192_l2_m2" diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_kern_trace/test_trace_info.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_kern_trace/test_trace_info.py new file mode 100644 index 00000000..de18a111 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_kern_trace/test_trace_info.py @@ -0,0 +1,334 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +# These contents may have been developed with support from one +# or more Intel-operated generative artificial intelligence solutions + +""" +@file test_trace_info.py +@brief Unit tests for the TraceInfo module and related classes +""" +from unittest.mock import patch, mock_open +import pytest + +from linker.kern_trace.trace_info import KernelInfo, TraceInfo +from linker.kern_trace.context_config import ContextConfig +from linker.kern_trace.kernel_op import KernelOp + + +class TestKernelInfo: + """ + @class TestKernelInfo + @brief Test cases for the KernelInfo class + """ + + def test_kernel_files_creation(self): + """ + @brief Test KernelInfo creation and attribute access + """ + # Act + kernel_files = KernelInfo( + { + "directory": "/tmp/dir", + "prefix": "prefix", + "minst": "prefix.minst", + "cinst": "prefix.cinst", + "xinst": "prefix.xinst", + "mem": "prefix.mem", + } + ) + + # Assert + assert kernel_files.directory == "/tmp/dir" + assert kernel_files.prefix == "prefix" + assert kernel_files.minst == "prefix.minst" + assert kernel_files.cinst == "prefix.cinst" + assert kernel_files.xinst == "prefix.xinst" + assert kernel_files.mem == "prefix.mem" + + def test_kernel_files_without_mem(self): + """ + @brief Test KernelInfo creation without mem file + """ + # Act + kernel_files = KernelInfo( + { + "directory": "/tmp/dir", + "prefix": "prefix", + "minst": "prefix.minst", + "cinst": "prefix.cinst", + "xinst": "prefix.xinst", + } + ) + + # Assert + assert kernel_files.directory == "/tmp/dir" + assert kernel_files.prefix == "prefix" + assert kernel_files.minst == "prefix.minst" + assert kernel_files.cinst == "prefix.cinst" + assert kernel_files.xinst == "prefix.xinst" + assert kernel_files.mem is None + + +class TestTraceInfo: + """ + @class TestTraceInfo + @brief Test cases for the TraceInfo class + """ + + def test_init(self): + """ + @brief Test initialization of TraceInfo class + """ + # Arrange & Act + trace_info = TraceInfo("/path/to/trace.txt") + + # Assert + assert trace_info.get_trace_file() == "/path/to/trace.txt" + + def test_str_representation(self): + """ + @brief Test string representation of TraceInfo + """ + # Arrange + trace_info = TraceInfo("/path/to/trace.txt") + + # Act + result = str(trace_info) + + # Assert + assert "TraceFile" in result + assert "/path/to/trace.txt" in result + + def test_get_param_index_dict(self): + """ + @brief Test get_param_index_dict method + """ + # Arrange + trace_info = TraceInfo("/path/to/trace.txt") + tokens = [ + "instruction", + "scheme", + "poly_modulus_degree", + "keyrns_terms", + "arg0", + "arg1", + ] + + # Act + result = trace_info.get_param_index_dict(tokens) + + # Assert + assert isinstance(result, dict) + assert len(result) == len(tokens) + assert result["instruction"] == 0 + assert result["scheme"] == 1 + assert result["poly_modulus_degree"] == 2 + assert result["keyrns_terms"] == 3 + assert result["arg0"] == 4 + assert result["arg1"] == 5 + + def test_extract_context_and_args(self): + """ + @brief Test extract_context_and_args method + """ + # Arrange + trace_info = TraceInfo("/path/to/trace.txt") + tokens = ["kernel1", "CKKS", "8192", "2", "input_var", "output_var"] + param_idxs = { + "instruction": 0, + "scheme": 1, + "poly_modulus_degree": 2, + "keyrns_terms": 3, + "arg0": 4, + "arg1": 5, + } + + # Act + name, context_config, kern_args = trace_info.extract_context_and_args( + tokens, param_idxs, 1 + ) + + # Assert + assert name == "kernel1" + assert isinstance(context_config, ContextConfig) + assert context_config.scheme == "CKKS" + assert context_config.poly_mod_degree == 8192 + assert context_config.keyrns_terms == 2 + assert kern_args == ["input_var", "output_var"] + + def test_extract_context_and_args_missing_param(self): + """ + @brief Test extract_context_and_args with missing parameter + """ + # Arrange + trace_info = TraceInfo("/path/to/trace.txt") + tokens = ["kernel1", "CKKS", "8192", "2", "input_var", "output_var"] + param_idxs = { + "instruction": 0, + "scheme": 1, + # Missing "poly_modulus_degree" + "keyrns_terms": 3, + "arg0": 4, + "arg1": 5, + } + + # Act & Assert + with pytest.raises(KeyError, match="poly_modulus_degree"): + trace_info.extract_context_and_args(tokens, param_idxs, 1) + + def test_extract_context_and_args_invalid_number(self): + """ + @brief Test extract_context_and_args with invalid number + """ + # Arrange + trace_info = TraceInfo("/path/to/trace.txt") + tokens = ["kernel1", "CKKS", "invalid", "2", "input_var", "output_var"] + param_idxs = { + "instruction": 0, + "scheme": 1, + "poly_modulus_degree": 2, # Will try to convert "invalid" to int + "keyrns_terms": 3, + "arg0": 4, + "arg1": 5, + } + + # Act & Assert + with pytest.raises(ValueError): + trace_info.extract_context_and_args(tokens, param_idxs, 1) + + def test_parse_kernel_ops_with_valid_trace(self): + """ + @brief Test parse_kernel_ops with a valid trace file + """ + # Arrange + trace_file = "/path/to/trace.txt" + trace_content = ( + "instruction scheme poly_modulus_degree keyrns_terms arg0 arg1\n" + "kernel1 CKKS 8192 2 input_var output_var\n" + "kernel2 BFV 4096 1 input_var2 output_var2\n" + ) + + # Act + with patch("os.path.isfile", return_value=True), patch( + "builtins.open", mock_open(read_data=trace_content) + ), patch("linker.kern_trace.trace_info.tokenize_from_line") as mock_tokenize: + + # Mock the tokenize_from_line function to return expected tokens + mock_tokenize.side_effect = [ + ( + [ + "instruction", + "scheme", + "poly_modulus_degree", + "keyrns_terms", + "arg0", + "arg1", + ], + None, + ), + ( + ["add", "CKKS", "8192", "2", "input_var1-0-1", "output_var0-0-1"], + None, + ), + ( + ["mul", "BFV", "4096", "1", "input_var2-3-4", "output_var0-2-2"], + None, + ), + ] + + trace_info = TraceInfo(trace_file) + result = trace_info.parse_kernel_ops() + + # Assert + assert len(result) == 2 + assert isinstance(result[0], KernelOp) + assert isinstance(result[1], KernelOp) + assert result[0].expected_in_kern_file_name == "ckks_add_8192_l1_m2" + assert result[1].expected_in_kern_file_name == "bfv_mul_4096_l2_m1" + assert len(result[0].kern_vars) == 2 + assert len(result[1].kern_vars) == 2 + + def test_parse_kernel_ops_with_empty_trace(self): + """ + @brief Test parse_kernel_ops with an empty trace file + """ + # Arrange + trace_file = "/path/to/empty_trace.txt" + + # Act + with patch("os.path.isfile", return_value=True), patch( + "builtins.open", mock_open(read_data="") + ): + + trace_info = TraceInfo(trace_file) + result = trace_info.parse_kernel_ops() + + # Assert + assert isinstance(result, list) + assert len(result) == 0 + + def test_parse_kernel_ops_with_nonexistent_file(self): + """ + @brief Test parse_kernel_ops with a nonexistent file + """ + # Arrange + trace_file = "/path/to/nonexistent.txt" + + # Act & Assert + with patch("os.path.isfile", return_value=False): + trace_info = TraceInfo(trace_file) + with pytest.raises(FileNotFoundError): + trace_info.parse_kernel_ops() + + def test_parse_kernel_ops_skip_empty_lines(self): + """ + @brief Test parse_kernel_ops skips empty lines + """ + # Arrange + trace_file = "/path/to/trace.txt" + trace_content = ( + "instruction scheme poly_modulus_degree keyrns_terms arg0 arg1\n" + "\n" # Empty line + "kernel1 CKKS 8192 2 input_var output_var\n" + " \n" # Line with whitespace + "kernel2 BFV 4096 1 input_var2 output_var2\n" + ) + + # Act + with patch("os.path.isfile", return_value=True), patch( + "builtins.open", mock_open(read_data=trace_content) + ), patch("linker.kern_trace.trace_info.tokenize_from_line") as mock_tokenize: + + # Mock the tokenize_from_line function to return expected tokens + mock_tokenize.side_effect = [ + ( + [ + "instruction", + "scheme", + "poly_modulus_degree", + "keyrns_terms", + "arg0", + "arg1", + ], + None, + ), + ([], None), # Empty line + ( + ["add", "CKKS", "8192", "2", "input_var1-0-2", "output_var0-3-2"], + None, + ), + ([""], None), # Line with whitespace tokenizes to empty string + ( + ["mul", "BGV", "4096", "1", "input_var2-3-4", "output_var2-3-3"], + None, + ), + ] + + trace_info = TraceInfo(trace_file) + result = trace_info.parse_kernel_ops() + + # Assert + assert len(result) == 2 # Only 2 valid kernel operations + assert isinstance(result[0], KernelOp) + assert isinstance(result[1], KernelOp) diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_linker_run_config.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_linker_run_config.py new file mode 100644 index 00000000..26f126dd --- /dev/null +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_linker_run_config.py @@ -0,0 +1,199 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +# These contents may have been developed with support from one +# or more Intel-operated generative artificial intelligence solutions + +""" +@file test_linker_run_config.py +@brief Unit tests for the LinkerRunConfig class +""" +import os +from unittest.mock import patch, PropertyMock +import pytest + +from linker.linker_run_config import LinkerRunConfig +from assembler.common.run_config import RunConfig +from assembler.common.config import GlobalConfig + + +class TestLinkerRunConfig: + """ + @class TestLinkerRunConfig + @brief Test cases for the LinkerRunConfig class + """ + + def test_init_with_valid_params(self): + """ + @brief Test initialization with valid parameters + """ + # Arrange + kwargs = { + "input_prefixes": ["prefix1", "prefix2"], + "output_prefix": "output_prefix", + "input_mem_file": "input.mem", + "output_dir": "/tmp", + "has_hbm": True, + "hbm_size": 1024, + "suppress_comments": False, + "use_xinstfetch": False, + "using_trace_file": False, + } + + # Act + with patch("linker.linker_run_config.makeUniquePath", side_effect=lambda x: x): + config = LinkerRunConfig(**kwargs) + + # Assert + assert config.input_prefixes == ["prefix1", "prefix2"] + assert config.output_prefix == "output_prefix" + assert config.input_mem_file == "input.mem" + assert config.output_dir == "/tmp" + assert config.has_hbm is True + assert config.hbm_size == 1024 + assert config.suppress_comments is False + assert config.use_xinstfetch is False + assert config.using_trace_file is False + + def test_init_with_missing_required_param(self): + """ + @brief Test initialization with missing required parameters + """ + # Arrange + kwargs = { + "input_prefixes": ["prefix1"], + "input_mem_file": "input.mem", + "output_dir": "/tmp", + # Missing output_prefixes + } + + # Act & Assert + with pytest.raises(TypeError): + LinkerRunConfig(**kwargs) + + def test_as_dict(self): + """ + @brief Test the as_dict method returns a proper dictionary + """ + # Arrange + kwargs = { + "input_prefixes": ["prefix1"], + "output_prefix": "output_prefix", + "input_mem_file": "input.mem", + "output_dir": "/tmp", + "has_hbm": True, + "hbm_size": 1024, + } + + # Act + with patch("linker.linker_run_config.makeUniquePath", side_effect=lambda x: x): + config = LinkerRunConfig(**kwargs) + result = config.as_dict() + + # Assert Keys + assert isinstance(result, dict) + assert "input_prefixes" in result + assert "output_prefix" in result + assert "input_mem_file" in result + assert "output_dir" in result + assert "has_hbm" in result + assert "hbm_size" in result + + # Assert values + assert result["input_prefixes"] == ["prefix1"] + assert result["output_prefix"] == "output_prefix" + assert result["input_mem_file"] == "input.mem" + assert result["output_dir"] == "/tmp" + assert result["has_hbm"] is True + assert result["hbm_size"] == 1024 + + def test_str_representation(self): + """ + @brief Test the string representation of the configuration + """ + # Arrange + kwargs = { + "input_prefixes": ["prefix1"], + "output_prefix": "output_prefix", + "input_mem_file": "input.mem", + } + + # Act + with patch("assembler.common.makeUniquePath", side_effect=lambda x: x): + config = LinkerRunConfig(**kwargs) + result = str(config) + + # Assert params + assert "input_prefixes" in result + assert "output_prefix" in result + assert "input_mem_file" in result + # Assert values + assert "prefix1" in result + assert "output_prefix" in result + assert "input.mem" in result + + def test_init_for_default_params(self): + """ + @brief Test initialization with default parameters + """ + + # Arrange + kwargs = {"input_prefixes": ["prefix1"], "output_prefix": ""} + + # Reset the class-level config so the patch will take effect + RunConfig.reset_class_state() + + # Act + with patch( + "assembler.common.makeUniquePath", side_effect=lambda x: x + ), patch.object( + RunConfig, "DEFAULT_HBM_SIZE_KB", new_callable=PropertyMock + ) as mock_hbm_size, patch.object( + GlobalConfig, "suppress_comments", new_callable=PropertyMock + ) as mock_suppress_comments, patch.object( + GlobalConfig, "useXInstFetch", new_callable=PropertyMock + ) as mock_use_xinstfetch: + + # Mock the default HBM size + mock_suppress_comments.return_value = False + mock_use_xinstfetch.return_value = False + mock_hbm_size.return_value = 1024 + config = LinkerRunConfig(**kwargs) + + # Assert + assert config.output_prefix == "" + assert config.input_mem_file == "" + assert config.output_dir == os.getcwd() + assert config.has_hbm is True + assert config.hbm_size == 1024 + assert config.suppress_comments is False + assert config.use_xinstfetch is False + assert config.using_trace_file is False + + def test_init_with_invalid_param_values(self): + """ + @brief Test initialization with invalid parameter values + """ + # Arrange + base_kwargs = { + "input_prefixes": ["prefix1"], + "output_prefix": "output_prefix", + "input_mem_file": "input.mem", + "output_dir": "/tmp", + } + + # Test cases with invalid values + invalid_test_cases = [ + # Test negative hbm_size + {**base_kwargs, "hbm_size": -1024}, + # Test non-integer hbm_size + {**base_kwargs, "hbm_size": "not_an_integer"}, + # Test invalid boolean value for has_hbm + {**base_kwargs, "has_hbm": "not_a_boolean"}, + ] + + # Act & Assert + for test_case in invalid_test_cases: + with patch("assembler.common.makeUniquePath", side_effect=lambda x: x): + with pytest.raises(ValueError, match=r".*invalid.*|.*Invalid.*"): + LinkerRunConfig(**test_case) diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_loader.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_loader.py index 06a3827d..29f7d94b 100644 --- a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_loader.py +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_loader.py @@ -11,16 +11,7 @@ import unittest from unittest.mock import patch, mock_open, MagicMock, call -from linker.loader import ( - load_minst_kernel, - load_minst_kernel_from_file, - load_cinst_kernel, - load_cinst_kernel_from_file, - load_xinst_kernel, - load_xinst_kernel_from_file, - load_dinst_kernel, - load_dinst_kernel_from_file, -) +from linker.loader import Loader class TestLoader(unittest.TestCase): @@ -52,7 +43,7 @@ def test_load_minst_kernel_success(self, mock_factory, mock_create): mock_create.side_effect = self.mock_minst # Call the function - result = load_minst_kernel(self.minst_lines) + result = Loader.load_minst_kernel(self.minst_lines) # Verify the results self.assertEqual(result, self.mock_minst) @@ -79,14 +70,14 @@ def test_load_minst_kernel_failure(self, mock_factory, mock_create): # Call the function and check for exception with self.assertRaises(RuntimeError) as context: - load_minst_kernel(self.minst_lines) + Loader.load_minst_kernel(self.minst_lines) self.assertIn( f"Error parsing line 1: {self.minst_lines[0]}", str(context.exception) ) @patch("builtins.open", new_callable=mock_open) - @patch("linker.loader.load_minst_kernel") + @patch("linker.loader.Loader.load_minst_kernel") def test_load_minst_kernel_from_file_success(self, mock_load, mock_file): """@brief Test successful loading of MInstructions from a file. @@ -97,7 +88,7 @@ def test_load_minst_kernel_from_file_success(self, mock_load, mock_file): mock_load.return_value = self.mock_minst # Call the function - result = load_minst_kernel_from_file("test.minst") + result = Loader.load_minst_kernel_from_file("test.minst") # Verify the results self.assertEqual(result, self.mock_minst) @@ -105,7 +96,7 @@ def test_load_minst_kernel_from_file_success(self, mock_load, mock_file): mock_load.assert_called_once_with(self.minst_lines) @patch("builtins.open", new_callable=mock_open) - @patch("linker.loader.load_minst_kernel") + @patch("linker.loader.Loader.load_minst_kernel") def test_load_minst_kernel_from_file_failure(self, mock_load, mock_file): """@brief Test error handling when loading MInstructions from a file fails. @@ -117,7 +108,7 @@ def test_load_minst_kernel_from_file_failure(self, mock_load, mock_file): # Call the function and check for exception with self.assertRaises(RuntimeError) as context: - load_minst_kernel_from_file("test.minst") + Loader.load_minst_kernel_from_file("test.minst") self.assertIn( 'Error occurred loading file "test.minst"', str(context.exception) @@ -135,7 +126,7 @@ def test_load_cinst_kernel_success(self, mock_factory, mock_create): mock_create.side_effect = self.mock_cinst # Call the function - result = load_cinst_kernel(self.cinst_lines) + result = Loader.load_cinst_kernel(self.cinst_lines) # Verify the results self.assertEqual(result, self.mock_cinst) @@ -162,14 +153,14 @@ def test_load_cinst_kernel_failure(self, mock_factory, mock_create): # Call the function and check for exception with self.assertRaises(RuntimeError) as context: - load_cinst_kernel(self.cinst_lines) + Loader.load_cinst_kernel(self.cinst_lines) self.assertIn( f"Error parsing line 1: {self.cinst_lines[0]}", str(context.exception) ) @patch("builtins.open", new_callable=mock_open) - @patch("linker.loader.load_cinst_kernel") + @patch("linker.loader.Loader.load_cinst_kernel") def test_load_cinst_kernel_from_file_success(self, mock_load, mock_file): """@brief Test successful loading of CInstructions from a file. @@ -180,7 +171,7 @@ def test_load_cinst_kernel_from_file_success(self, mock_load, mock_file): mock_load.return_value = self.mock_cinst # Call the function - result = load_cinst_kernel_from_file("test.cinst") + result = Loader.load_cinst_kernel_from_file("test.cinst") # Verify the results self.assertEqual(result, self.mock_cinst) @@ -188,7 +179,7 @@ def test_load_cinst_kernel_from_file_success(self, mock_load, mock_file): mock_load.assert_called_once_with(self.cinst_lines) @patch("builtins.open", new_callable=mock_open) - @patch("linker.loader.load_cinst_kernel") + @patch("linker.loader.Loader.load_cinst_kernel") def test_load_cinst_kernel_from_file_failure(self, mock_load, mock_file): """@brief Test error handling when loading CInstructions from a file fails. @@ -200,7 +191,7 @@ def test_load_cinst_kernel_from_file_failure(self, mock_load, mock_file): # Call the function and check for exception with self.assertRaises(RuntimeError) as context: - load_cinst_kernel_from_file("test.cinst") + Loader.load_cinst_kernel_from_file("test.cinst") self.assertIn( 'Error occurred loading file "test.cinst"', str(context.exception) @@ -218,7 +209,7 @@ def test_load_xinst_kernel_success(self, mock_factory, mock_create): mock_create.side_effect = self.mock_xinst # Call the function - result = load_xinst_kernel(self.xinst_lines) + result = Loader.load_xinst_kernel(self.xinst_lines) # Verify the results self.assertEqual(result, self.mock_xinst) @@ -245,14 +236,14 @@ def test_load_xinst_kernel_failure(self, mock_factory, mock_create): # Call the function and check for exception with self.assertRaises(RuntimeError) as context: - load_xinst_kernel(self.xinst_lines) + Loader.load_xinst_kernel(self.xinst_lines) self.assertIn( f"Error parsing line 1: {self.xinst_lines[0]}", str(context.exception) ) @patch("builtins.open", new_callable=mock_open) - @patch("linker.loader.load_xinst_kernel") + @patch("linker.loader.Loader.load_xinst_kernel") def test_load_xinst_kernel_from_file_success(self, mock_load, mock_file): """@brief Test successful loading of XInstructions from a file. @@ -263,7 +254,7 @@ def test_load_xinst_kernel_from_file_success(self, mock_load, mock_file): mock_load.return_value = self.mock_xinst # Call the function - result = load_xinst_kernel_from_file("test.xinst") + result = Loader.load_xinst_kernel_from_file("test.xinst") # Verify the results self.assertEqual(result, self.mock_xinst) @@ -271,7 +262,7 @@ def test_load_xinst_kernel_from_file_success(self, mock_load, mock_file): mock_load.assert_called_once_with(self.xinst_lines) @patch("builtins.open", new_callable=mock_open) - @patch("linker.loader.load_xinst_kernel") + @patch("linker.loader.Loader.load_xinst_kernel") def test_load_xinst_kernel_from_file_failure(self, mock_load, mock_file): """@brief Test error handling when loading XInstructions from a file fails. @@ -283,7 +274,7 @@ def test_load_xinst_kernel_from_file_failure(self, mock_load, mock_file): # Call the function and check for exception with self.assertRaises(RuntimeError) as context: - load_xinst_kernel_from_file("test.xinst") + Loader.load_xinst_kernel_from_file("test.xinst") self.assertIn( 'Error occurred loading file "test.xinst"', str(context.exception) @@ -299,7 +290,7 @@ def test_load_dinst_kernel_success(self, mock_create): mock_create.side_effect = self.mock_dinst # Call the function - result = load_dinst_kernel(self.dinst_lines) + result = Loader.load_dinst_kernel(self.dinst_lines) # Verify the results self.assertEqual(result, self.mock_dinst) @@ -319,14 +310,14 @@ def test_load_dinst_kernel_failure(self, mock_create): # Call the function and check for exception with self.assertRaises(RuntimeError) as context: - load_dinst_kernel(self.dinst_lines) + Loader.load_dinst_kernel(self.dinst_lines) self.assertIn( f"Error parsing line 1: {self.dinst_lines[0]}", str(context.exception) ) @patch("builtins.open", new_callable=mock_open) - @patch("linker.loader.load_dinst_kernel") + @patch("linker.loader.Loader.load_dinst_kernel") def test_load_dinst_kernel_from_file_success(self, mock_load, mock_file): """@brief Test successful loading of DInstructions from a file. @@ -337,7 +328,7 @@ def test_load_dinst_kernel_from_file_success(self, mock_load, mock_file): mock_load.return_value = self.mock_dinst # Call the function - result = load_dinst_kernel_from_file("test.dinst") + result = Loader.load_dinst_kernel_from_file("test.dinst") # Verify the results self.assertEqual(result, self.mock_dinst) @@ -345,7 +336,7 @@ def test_load_dinst_kernel_from_file_success(self, mock_load, mock_file): mock_load.assert_called_once_with(self.dinst_lines) @patch("builtins.open", new_callable=mock_open) - @patch("linker.loader.load_dinst_kernel") + @patch("linker.loader.Loader.load_dinst_kernel") def test_load_dinst_kernel_from_file_failure(self, mock_load, mock_file): """@brief Test error handling when loading DInstructions from a file fails. @@ -357,7 +348,7 @@ def test_load_dinst_kernel_from_file_failure(self, mock_load, mock_file): # Call the function and check for exception with self.assertRaises(RuntimeError) as context: - load_dinst_kernel_from_file("test.dinst") + Loader.load_dinst_kernel_from_file("test.dinst") self.assertIn( 'Error occurred loading file "test.dinst"', str(context.exception) diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_steps/test_program_linker.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_steps/test_program_linker.py index aa89d860..b306daf8 100644 --- a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_steps/test_program_linker.py +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_steps/test_program_linker.py @@ -10,7 +10,8 @@ import io import unittest -from unittest.mock import patch, MagicMock, call +from unittest.mock import patch, MagicMock, call, mock_open +from collections import namedtuple from assembler.common.config import GlobalConfig from linker import MemoryModel @@ -115,73 +116,6 @@ def test_close(self): # Should not contain "terminating MInstQ" comment self.assertNotIn("terminating MInstQ", self.streams["minst"].getvalue()) - def test_validate_hbm_address(self): - """@brief Test validating a HBM address. - - @test Verifies that valid addresses are accepted and invalid ones raise exceptions - """ - - # Test validating a valid HBM address - self.mem_model.mem_info_vars = {} - self.program._validate_hbm_address("test_var", 10) - # No exception should be raised - - # Test validating a negative HBM address - with self.assertRaises(RuntimeError): - self.program._validate_hbm_address("test_var", -1) - - def test_validate_hbm_address_mismatch(self): - """@brief Test validating an HBM address that doesn't match the declared address. - - @test Verifies that a RuntimeError is raised when address doesn't match - """ - mock_var = MagicMock() - mock_var.hbm_address = 5 - self.mem_model.mem_info_vars = {"test_var": mock_var} - - with self.assertRaises(RuntimeError): - self.program._validate_hbm_address("test_var", 10) - - def test_validate_spad_address_valid(self): - """@brief Test validating a valid SPAD address with HBM disabled. - - @test Verifies that valid SPAD addresses are accepted when HBM is disabled - """ - with patch.object(GlobalConfig, "hasHBM", False): - self.mem_model.mem_info_vars = {} - self.program._validate_spad_address("test_var", 10) - # No exception should be raised - - def test_validate_spad_address_with_hbm_enabled(self): - """@brief Test validating a SPAD address with HBM enabled. - - @test Verifies that an AssertionError is raised when HBM is enabled - """ - with self.assertRaises(AssertionError): - self.program._validate_spad_address("test_var", 10) - - def test_validate_spad_address_negative(self): - """@brief Test validating a negative SPAD address. - - @test Verifies that a RuntimeError is raised for negative addresses - """ - with patch.object(GlobalConfig, "hasHBM", False): - with self.assertRaises(RuntimeError): - self.program._validate_spad_address("test_var", -1) - - def test_validate_spad_address_mismatch(self): - """@brief Test validating a SPAD address that doesn't match the declared address. - - @test Verifies that a RuntimeError is raised when address doesn't match - """ - with patch.object(GlobalConfig, "hasHBM", False): - mock_var = MagicMock() - mock_var.hbm_address = 5 - self.mem_model.mem_info_vars = {"test_var": mock_var} - - with self.assertRaises(RuntimeError): - self.program._validate_spad_address("test_var", 10) - def test_update_minsts(self): """@brief Test updating MInsts. @@ -573,6 +507,171 @@ def test_link_kernel_with_suppress_comments(self): self.assertNotIn("cinst_comment", cinst_output) self.assertNotIn("minst_comment", minst_output) + def test_link_kernels_to_files(self): + """ + @brief Test the link_kernels_to_files static method. + + @test Verifies that kernels are correctly linked and written to output files + """ + # Create a namedtuple similar to KernelInfo for testing + KernelInfo = namedtuple( + "KernelInfo", ["prefix", "minst", "cinst", "xinst", "mem", "remap_dict"] + ) + + # Arrange + input_files = [ + KernelInfo( + prefix="/tmp/input1", + minst="/tmp/input1.minst", + cinst="/tmp/input1.cinst", + xinst="/tmp/input1.xinst", + mem=None, + remap_dict={}, + ) + ] + + output_files = KernelInfo( + prefix="/tmp/output", + minst="/tmp/output.minst", + cinst="/tmp/output.cinst", + xinst="/tmp/output.xinst", + mem=None, + remap_dict=None, + ) + + mock_mem_model = MagicMock() + mock_verbose = MagicMock() + + # Act + with patch("builtins.open", mock_open()), patch( + "linker.steps.program_linker.Loader.load_minst_kernel_from_file", + return_value=[], + ), patch( + "linker.steps.program_linker.Loader.load_cinst_kernel_from_file", + return_value=[], + ), patch( + "linker.steps.program_linker.Loader.load_xinst_kernel_from_file", + return_value=[], + ), patch.object( + LinkedProgram, "__init__", return_value=None + ) as mock_init, patch.object( + LinkedProgram, "link_kernel" + ) as mock_link_kernel, patch.object( + LinkedProgram, "close" + ) as mock_close: + + LinkedProgram.link_kernels_to_files( + input_files, output_files, mock_mem_model, mock_verbose + ) + + # Assert + mock_init.assert_called_once() + mock_link_kernel.assert_called_once_with([], [], []) + mock_close.assert_called_once() + + +class TestLinkedProgramValidation(unittest.TestCase): + """@brief Tests for the validation methods of the LinkedProgram class.""" + + def setUp(self): + """@brief Set up test fixtures.""" + # Group related stream objects into a dictionary + self.streams = { + "minst": io.StringIO(), + "cinst": io.StringIO(), + "xinst": io.StringIO(), + } + self.mem_model = MagicMock(spec=MemoryModel) + + # Mock the hasHBM property to return True by default + self.has_hbm_patcher = patch.object(GlobalConfig, "hasHBM", True) + self.mock_has_hbm = self.has_hbm_patcher.start() + + # Mock the suppress_comments property to return False by default + self.suppress_comments_patcher = patch.object( + GlobalConfig, "suppress_comments", False + ) + self.mock_suppress_comments = self.suppress_comments_patcher.start() + + self.program = LinkedProgram( + self.streams["minst"], + self.streams["cinst"], + self.streams["xinst"], + self.mem_model, + ) + + def tearDown(self): + """@brief Tear down test fixtures.""" + self.has_hbm_patcher.stop() + self.suppress_comments_patcher.stop() + + def test_validate_hbm_address(self): + """@brief Test validating a HBM address. + + @test Verifies that valid addresses are accepted and invalid ones raise exceptions + """ + + # Test validating a valid HBM address + self.mem_model.mem_info_vars = {} + self.program._validate_hbm_address("test_var", 10) + # No exception should be raised + + # Test validating a negative HBM address + with self.assertRaises(RuntimeError): + self.program._validate_hbm_address("test_var", -1) + + def test_validate_hbm_address_mismatch(self): + """@brief Test validating an HBM address that doesn't match the declared address. + + @test Verifies that a RuntimeError is raised when address doesn't match + """ + mock_var = MagicMock() + mock_var.hbm_address = 5 + self.mem_model.mem_info_vars = {"test_var": mock_var} + + with self.assertRaises(RuntimeError): + self.program._validate_hbm_address("test_var", 10) + + def test_validate_spad_address_valid(self): + """@brief Test validating a valid SPAD address with HBM disabled. + + @test Verifies that valid SPAD addresses are accepted when HBM is disabled + """ + with patch.object(GlobalConfig, "hasHBM", False): + self.mem_model.mem_info_vars = {} + self.program._validate_spad_address("test_var", 10) + # No exception should be raised + + def test_validate_spad_address_with_hbm_enabled(self): + """@brief Test validating a SPAD address with HBM enabled. + + @test Verifies that an AssertionError is raised when HBM is enabled + """ + with self.assertRaises(AssertionError): + self.program._validate_spad_address("test_var", 10) + + def test_validate_spad_address_negative(self): + """@brief Test validating a negative SPAD address. + + @test Verifies that a RuntimeError is raised for negative addresses + """ + with patch.object(GlobalConfig, "hasHBM", False): + with self.assertRaises(RuntimeError): + self.program._validate_spad_address("test_var", -1) + + def test_validate_spad_address_mismatch(self): + """@brief Test validating a SPAD address that doesn't match the declared address. + + @test Verifies that a RuntimeError is raised when address doesn't match + """ + with patch.object(GlobalConfig, "hasHBM", False): + mock_var = MagicMock() + mock_var.hbm_address = 5 + self.mem_model.mem_info_vars = {"test_var": mock_var} + + with self.assertRaises(RuntimeError): + self.program._validate_spad_address("test_var", 10) + class TestJoinDinstKernels(unittest.TestCase): """@brief Tests for the join_dinst_kernels static method.""" diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_steps/test_variable_discovery.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_steps/test_variable_discovery.py index 4e9e942b..241bb3fe 100644 --- a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_steps/test_variable_discovery.py +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_steps/test_variable_discovery.py @@ -10,8 +10,16 @@ import unittest from unittest.mock import patch, MagicMock +from collections import namedtuple +import pytest -from linker.steps.variable_discovery import discover_variables, discover_variables_spad +from assembler.common.config import GlobalConfig +from linker.steps.variable_discovery import ( + discover_variables, + discover_variables_spad, + scan_variables, + check_unused_variables, +) class TestVariableDiscovery(unittest.TestCase): @@ -232,7 +240,7 @@ def test_discover_variables_spad_invalid_type(self): list(discover_variables_spad([invalid_obj])) # Verify the error message - self.assertIn("not a valid MInstruction", str(context.exception)) + self.assertIn("not a valid CInstruction", str(context.exception)) @patch("linker.steps.variable_discovery.cinst") @patch("linker.steps.variable_discovery.CInstruction") @@ -270,6 +278,72 @@ def test_discover_variables_spad_invalid_variable_name( # Verify the error message self.assertIn("Invalid Variable name", str(context.exception)) + def test_scan_variables(self): + """ + @brief Test scan_variables function with and without HBM + + @test Verifies that scan_variables correctly processes input files and updates the memory model + in both HBM and non-HBM modes + """ + # Create a namedtuple similar to KernelInfo for testing + KernelInfo = namedtuple( + "KernelInfo", + ["directory", "prefix", "minst", "cinst", "xinst", "mem", "remap_dict"], + ) + input_files = [ + KernelInfo( + directory="/tmp", + prefix="input1", + minst="/tmp/input1.minst", + cinst="/tmp/input1.cinst", + xinst="/tmp/input1.xinst", + mem=None, + remap_dict=None, + ) + ] + + # Test with both True and False for hasHBM + for has_hbm in [True, False]: + with self.subTest(has_hbm=has_hbm): + # Arrange + GlobalConfig.hasHBM = has_hbm + mock_mem_model = MagicMock() + mock_verbose = MagicMock() + + # Act + with patch( + "linker.steps.variable_discovery.Loader.load_minst_kernel_from_file", + return_value=[], + ), patch( + "linker.steps.variable_discovery.Loader.load_cinst_kernel_from_file", + return_value=[], + ), patch( + "linker.steps.variable_discovery.discover_variables", + return_value=["var1", "var2"], + ), patch( + "linker.steps.variable_discovery.discover_variables_spad", + return_value=["var1", "var2"], + ): + scan_variables(input_files, mock_mem_model, mock_verbose) + + # Assert + self.assertEqual(mock_mem_model.add_variable.call_count, 2) + + def test_check_unused_variables(self): + """ + @brief Test check_unused_variables function + """ + # Arrange + GlobalConfig.hasHBM = True + mock_mem_model = MagicMock() + mock_mem_model.mem_info_vars = {"var1": MagicMock(), "var2": MagicMock()} + mock_mem_model.variables = {"var1"} + mock_mem_model.mem_info_meta = {} + + # Act & Assert + with pytest.raises(RuntimeError): + check_unused_variables(mock_mem_model) + if __name__ == "__main__": unittest.main()