From 66f650927c2486f297954f903bd3894d5de93dcf Mon Sep 17 00:00:00 2001 From: "Rojas Chaves, Jose" Date: Sat, 28 Jun 2025 04:32:09 +0000 Subject: [PATCH 01/12] Initial linker atumation commit --- .../assembler/memory_model/mem_info.py | 225 +++++++++++------- .../hec-assembler-tools/debug_tools/main.py | 8 +- assembler_tools/hec-assembler-tools/he_as.py | 4 +- .../hec-assembler-tools/he_link.py | 118 ++++++--- .../hec-assembler-tools/linker/__init__.py | 4 + .../linker/instructions/__init__.py | 3 +- .../linker/instructions/dinst/__init__.py | 49 ++++ .../linker/instructions/dinst/dinstruction.py | 96 ++++++++ .../linker/instructions/dinst/dkeygen.py | 29 +++ .../linker/instructions/dinst/dload.py | 38 +++ .../linker/instructions/dinst/dstore.py | 38 +++ .../linker/instructions/instruction.py | 56 +++-- .../hec-assembler-tools/linker/loader.py | 70 +++++- .../linker/steps/program_linker.py | 63 ++++- 14 files changed, 641 insertions(+), 160 deletions(-) create mode 100644 assembler_tools/hec-assembler-tools/linker/instructions/dinst/__init__.py create mode 100644 assembler_tools/hec-assembler-tools/linker/instructions/dinst/dinstruction.py create mode 100644 assembler_tools/hec-assembler-tools/linker/instructions/dinst/dkeygen.py create mode 100644 assembler_tools/hec-assembler-tools/linker/instructions/dinst/dload.py create mode 100644 assembler_tools/hec-assembler-tools/linker/instructions/dinst/dstore.py diff --git a/assembler_tools/hec-assembler-tools/assembler/memory_model/mem_info.py b/assembler_tools/hec-assembler-tools/assembler/memory_model/mem_info.py index 35c5bb19..a30cc7c5 100644 --- a/assembler_tools/hec-assembler-tools/assembler/memory_model/mem_info.py +++ b/assembler_tools/hec-assembler-tools/assembler/memory_model/mem_info.py @@ -3,6 +3,8 @@ from assembler.common import constants from assembler.instructions import tokenize_from_line +from typing import Optional +from assembler.common.decorators import * from assembler.memory_model.variable import Variable from . import MemoryModel @@ -113,7 +115,7 @@ class Metadata: class Ones: @classmethod - def parseFromMemLine(cls, tokens: list) -> MemInfoVariable: + def parse_from_mem_tokens(cls, tokens: list) -> MemInfoVariable: """ Parses a ones metadata variable from a tokenized line. @@ -123,15 +125,13 @@ def parseFromMemLine(cls, tokens: list) -> MemInfoVariable: Returns: MemInfoVariable: The parsed ones metadata variable. """ - return MemInfo.Metadata.parseMetaFieldFromMemLine( - tokens, - MemInfo.Const.Keyword.LOAD_ONES, - var_prefix=MemInfo.Const.Keyword.LOAD_ONES, - ) + return MemInfo.Metadata.parse_meta_field_from_mem_tokens(tokens, + MemInfo.Const.Keyword.LOAD_ONES, + var_prefix=MemInfo.Const.Keyword.LOAD_ONES) class NTTAuxTable: @classmethod - def parseFromMemLine(cls, tokens: list) -> MemInfoVariable: + def parse_from_mem_tokens(cls, tokens: list) -> MemInfoVariable: """ Parses an NTT auxiliary table metadata variable from a tokenized line. @@ -141,15 +141,13 @@ def parseFromMemLine(cls, tokens: list) -> MemInfoVariable: Returns: MemInfoVariable: The parsed NTT auxiliary table metadata variable. """ - return MemInfo.Metadata.parseMetaFieldFromMemLine( - tokens, - MemInfo.Const.Keyword.LOAD_NTT_AUX_TABLE, - var_prefix=MemInfo.Const.Keyword.LOAD_NTT_AUX_TABLE, - ) + return MemInfo.Metadata.parse_meta_field_from_mem_tokens(tokens, + MemInfo.Const.Keyword.LOAD_NTT_AUX_TABLE, + var_prefix=MemInfo.Const.Keyword.LOAD_NTT_AUX_TABLE) class NTTRoutingTable: @classmethod - def parseFromMemLine(cls, tokens: list) -> MemInfoVariable: + def parse_from_mem_tokens(cls, tokens: list) -> MemInfoVariable: """ Parses an NTT routing table metadata variable from a tokenized line. @@ -159,15 +157,13 @@ def parseFromMemLine(cls, tokens: list) -> MemInfoVariable: Returns: MemInfoVariable: The parsed NTT routing table metadata variable. """ - return MemInfo.Metadata.parseMetaFieldFromMemLine( - tokens, - MemInfo.Const.Keyword.LOAD_NTT_ROUTING_TABLE, - var_prefix=MemInfo.Const.Keyword.LOAD_NTT_ROUTING_TABLE, - ) + return MemInfo.Metadata.parse_meta_field_from_mem_tokens(tokens, + MemInfo.Const.Keyword.LOAD_NTT_ROUTING_TABLE, + var_prefix=MemInfo.Const.Keyword.LOAD_NTT_ROUTING_TABLE) class iNTTAuxTable: @classmethod - def parseFromMemLine(cls, tokens: list) -> MemInfoVariable: + def parse_from_mem_tokens(cls, tokens: list) -> MemInfoVariable: """ Parses an iNTT auxiliary table metadata variable from a tokenized line. @@ -177,15 +173,13 @@ def parseFromMemLine(cls, tokens: list) -> MemInfoVariable: Returns: MemInfoVariable: The parsed iNTT auxiliary table metadata variable. """ - return MemInfo.Metadata.parseMetaFieldFromMemLine( - tokens, - MemInfo.Const.Keyword.LOAD_iNTT_AUX_TABLE, - var_prefix=MemInfo.Const.Keyword.LOAD_iNTT_AUX_TABLE, - ) + return MemInfo.Metadata.parse_meta_field_from_mem_tokens(tokens, + MemInfo.Const.Keyword.LOAD_iNTT_AUX_TABLE, + var_prefix=MemInfo.Const.Keyword.LOAD_iNTT_AUX_TABLE) class iNTTRoutingTable: @classmethod - def parseFromMemLine(cls, tokens: list) -> MemInfoVariable: + def parse_from_mem_tokens(cls, tokens: list) -> MemInfoVariable: """ Parses an iNTT routing table metadata variable from a tokenized line. @@ -195,15 +189,13 @@ def parseFromMemLine(cls, tokens: list) -> MemInfoVariable: Returns: MemInfoVariable: The parsed iNTT routing table metadata variable. """ - return MemInfo.Metadata.parseMetaFieldFromMemLine( - tokens, - MemInfo.Const.Keyword.LOAD_iNTT_ROUTING_TABLE, - var_prefix=MemInfo.Const.Keyword.LOAD_iNTT_ROUTING_TABLE, - ) + return MemInfo.Metadata.parse_meta_field_from_mem_tokens(tokens, + MemInfo.Const.Keyword.LOAD_iNTT_ROUTING_TABLE, + var_prefix=MemInfo.Const.Keyword.LOAD_iNTT_ROUTING_TABLE) class Twiddle: @classmethod - def parseFromMemLine(cls, tokens: list) -> MemInfoVariable: + def parse_from_mem_tokens(cls, tokens: list) -> MemInfoVariable: """ Parses a twiddle metadata variable from a tokenized line. @@ -213,15 +205,13 @@ def parseFromMemLine(cls, tokens: list) -> MemInfoVariable: Returns: MemInfoVariable: The parsed twiddle metadata variable. """ - return MemInfo.Metadata.parseMetaFieldFromMemLine( - tokens, - MemInfo.Const.Keyword.LOAD_TWIDDLE, - var_prefix=MemInfo.Const.Keyword.LOAD_TWIDDLE, - ) + return MemInfo.Metadata.parse_meta_field_from_mem_tokens(tokens, + MemInfo.Const.Keyword.LOAD_TWIDDLE, + var_prefix=MemInfo.Const.Keyword.LOAD_TWIDDLE) class KeygenSeed: @classmethod - def parseFromMemLine(cls, tokens: list) -> MemInfoVariable: + def parse_from_mem_tokens(cls, tokens: list) -> MemInfoVariable: """ Parses a keygen seed metadata variable from a tokenized line. @@ -231,20 +221,16 @@ def parseFromMemLine(cls, tokens: list) -> MemInfoVariable: Returns: MemInfoVariable: The parsed keygen seed metadata variable. """ - return MemInfo.Metadata.parseMetaFieldFromMemLine( - tokens, - MemInfo.Const.Keyword.LOAD_KEYGEN_SEED, - var_prefix=MemInfo.Const.Keyword.LOAD_KEYGEN_SEED, - ) + return MemInfo.Metadata.parse_meta_field_from_mem_tokens(tokens, + MemInfo.Const.Keyword.LOAD_KEYGEN_SEED, + var_prefix=MemInfo.Const.Keyword.LOAD_KEYGEN_SEED) @classmethod - def parseMetaFieldFromMemLine( - cls, - tokens: list, - meta_field_name: str, - var_prefix: str = "meta", - var_extra: str = None, - ) -> MemInfoVariable: + def parse_meta_field_from_mem_tokens(cls, + tokens: list, + meta_field_name: str, + var_prefix: str = "meta", + var_extra: str = None) -> MemInfoVariable: """ Parses a metadata variable name from a tokenized line. @@ -289,7 +275,7 @@ def __init__(self, **kwargs): MemInfoVariable(**d) for d in kwargs.get(meta_field, []) ] - def __getitem__(self, key): + def get_item(self, key): """ Retrieves the list of MemInfoVariable objects for the specified metadata field. @@ -373,7 +359,7 @@ def keygen_seeds(self) -> list: class Keygen: @classmethod - def parseFromMemLine(cls, tokens: list) -> MemInfoVariable: + def parse_from_mem_tokens(cls, tokens: list) -> MemInfoVariable: """ Parses a keygen variable from a tokenized line. @@ -396,7 +382,7 @@ def parseFromMemLine(cls, tokens: list) -> MemInfoVariable: class Input: @classmethod - def parseFromMemLine(cls, tokens: list) -> MemInfoVariable: + def parse_from_mem_tokens(cls, tokens: list) -> MemInfoVariable: """ Parses an input variable from a tokenized line. @@ -422,7 +408,7 @@ def parseFromMemLine(cls, tokens: list) -> MemInfoVariable: class Output: @classmethod - def parseFromMemLine(cls, tokens: list) -> MemInfoVariable: + def parse_from_mem_tokens(cls, tokens: list) -> MemInfoVariable: """ Parses an output variable from a tokenized line. @@ -448,7 +434,7 @@ def __init__(self, **kwargs): Initializes a new MemInfo object. Clients may call this method without parameters for default initialization. - Clients should use MemInfo.from_iter() constructor to parse the contents of a .mem file. + Clients should use MemInfo.from_file_iter() constructor to parse the contents of a .mem file. Args: kwargs (dict): A dictionary as generated by the method MemInfo.as_dict(). This is provided as @@ -468,8 +454,72 @@ def __init__(self, **kwargs): ) self.validate() + @property + def factory_dict(self): + """ + Returns a dictionary mapping MemInfo types to their respective factory methods. + + This is used to create instances of MemInfoVariable based on the type of memory information. + + Returns: + dict: A dictionary mapping MemInfo types to their factory methods. + """ + return { MemInfo.Keygen: self.keygens, + MemInfo.Input: self.inputs, + MemInfo.Output: self.outputs, + MemInfo.Metadata.KeygenSeed: self.metadata.keygen_seeds, + MemInfo.Metadata.Ones: self.metadata.ones, + MemInfo.Metadata.NTTAuxTable: self.metadata.ntt_auxiliary_table, + MemInfo.Metadata.NTTRoutingTable: self.metadata.ntt_routing_table, + MemInfo.Metadata.iNTTAuxTable: self.metadata.intt_auxiliary_table, + MemInfo.Metadata.iNTTRoutingTable: self.metadata.intt_routing_table, + MemInfo.Metadata.Twiddle: self.metadata.twiddle } + + @classproperty + def mem_info_types(cls): + """ Retrieves the list of MemInfo variable types.""" + dummy = cls() + return dummy.factory_dict.keys() + @classmethod - def from_iter(cls, line_iter): + def get_meminfo_var_from_tokens(cls, tokens) -> tuple[Optional[MemInfoVariable], Optional[type]]: + """ + Parses a MemInfo variable from a list of tokens. + + Args: + tokens (list[str]): List of tokens to parse. + + Returns: + tuple[MemInfoVariable,type]: The parsed MemInfo variable and its type, or None if no variable could be parsed. + """ + miv: Optional[MemInfoVariable] = None + mem_info_type: Optional[type] = None + for mem_info_type in cls.mem_info_types: + miv: MemInfoVariable = mem_info_type.parse_from_mem_tokens(tokens) + if miv is not None: + break + + return miv, mem_info_type + + def add_meminfo_var_from_tokens(self, tokens): + """ + Parses a MemInfo variable from a list of tokens and adds it to the appropriate list. + + Args: + tokens (list[str]): List of tokens to parse. + + Raises: + RuntimeError: If the line could not be parsed. + """ + miv: MemInfoVariable = None + miv, mem_info_type = MemInfo.get_meminfo_var_from_tokens(tokens) + if miv is not None and mem_info_type is not None: + self.factory_dict[mem_info_type].append(miv) + else: + raise RuntimeError(f"Could not parse line") + + @classmethod + def from_file_iter(cls, line_iter): """ Creates a new MemInfo object from an iterator of strings, where each string is a line of text to parse. @@ -486,38 +536,47 @@ def from_iter(cls, line_iter): """ retval = cls() - - factory_dict = { - MemInfo.Keygen: retval.keygens, - MemInfo.Input: retval.inputs, - MemInfo.Output: retval.outputs, - MemInfo.Metadata.KeygenSeed: retval.metadata.keygen_seeds, - MemInfo.Metadata.Ones: retval.metadata.ones, - MemInfo.Metadata.NTTAuxTable: retval.metadata.ntt_auxiliary_table, - MemInfo.Metadata.NTTRoutingTable: retval.metadata.ntt_routing_table, - MemInfo.Metadata.iNTTAuxTable: retval.metadata.intt_auxiliary_table, - MemInfo.Metadata.iNTTRoutingTable: retval.metadata.intt_routing_table, - MemInfo.Metadata.Twiddle: retval.metadata.twiddle, - } + for line_no, s_line in enumerate(line_iter, 1): s_line = s_line.strip() if s_line: # skip empty lines tokens, _ = tokenize_from_line(s_line) if tokens and len(tokens) > 0: - b_parsed = False - for mem_info_type in factory_dict: - miv: MemInfoVariable = mem_info_type.parseFromMemLine(tokens) - if miv is not None: - factory_dict[mem_info_type].append(miv) - b_parsed = True - break # next line - if not b_parsed: - raise RuntimeError( - f'Could not parse line {line_no}: "{s_line}"' - ) + try: + retval.add_meminfo_var_from_tokens(tokens) + except RuntimeError as e: + raise RuntimeError(f"{e} {line_no}: {s_line}") from e retval.validate() return retval + @classmethod + def from_dinstrs(cls, dinstrs): + """ + Creates a new MemInfo object from an list of DInstructions. + + Args: + dinstrs (list[DInstruction]): List of DInstructions. + + Raises: + RuntimeError: If there is an error parsing the instruction tokens. + + Returns: + MemInfo: The constructed MemInfo object. + """ + + retval = cls() + for ints_no, dinstr in enumerate(dinstrs, 1): + tokens = dinstr.tokens + if tokens and len(tokens) > 0: + try: + retval.add_meminfo_var_from_tokens(tokens) + print(f"Added {tokens} to MemInfo") + except RuntimeError as e: + raise RuntimeError(f"{e} {ints_no}: {tokens}") from e + + retval.validate() + return retval + @property def keygens(self) -> list: """ @@ -526,7 +585,7 @@ def keygens(self) -> list: Returns: list: Keygen variables. """ - return self.__keygens + return self._keygens @property def inputs(self) -> list: @@ -536,7 +595,7 @@ def inputs(self) -> list: Returns: list: Input variables. """ - return self.__inputs + return self._inputs @property def outputs(self) -> list: @@ -546,7 +605,7 @@ def outputs(self) -> list: Returns: list: Output variables. """ - return self.__outputs + return self._outputs @property def metadata(self) -> Metadata: @@ -556,7 +615,7 @@ def metadata(self) -> Metadata: Returns: Metadata: MemInfo's metadata. """ - return self.__metadata + return self._metadata def as_dict(self): """ diff --git a/assembler_tools/hec-assembler-tools/debug_tools/main.py b/assembler_tools/hec-assembler-tools/debug_tools/main.py index a46c0f3a..9a675cd1 100644 --- a/assembler_tools/hec-assembler-tools/debug_tools/main.py +++ b/assembler_tools/hec-assembler-tools/debug_tools/main.py @@ -82,8 +82,8 @@ def main_readmem(args): ) mem_meta_info = None - with open(mem_filename, "r") as mem_ifnum: - mem_meta_info = mem_info.MemInfo.from_iter(mem_ifnum) + with open(mem_filename, 'r') as mem_ifnum: + mem_meta_info = mem_info.MemInfo.from_file_iter(mem_ifnum) if mem_meta_info: with io.StringIO() as retval_f: @@ -216,8 +216,8 @@ def asmisa_assembly( if b_verbose: print("Interpreting variable meta information...") - with open(mem_filename, "r") as mem_ifnum: - mem_meta_info = mem_info.MemInfo.from_iter(mem_ifnum) + with open(mem_filename, 'r') as mem_ifnum: + mem_meta_info = mem_info.MemInfo.from_file_iter(mem_ifnum) mem_info.updateMemoryModelWithMemInfo(hec_mem_model, mem_meta_info) if b_verbose: diff --git a/assembler_tools/hec-assembler-tools/he_as.py b/assembler_tools/hec-assembler-tools/he_as.py index 8c3e07b4..f92d63c9 100644 --- a/assembler_tools/hec-assembler-tools/he_as.py +++ b/assembler_tools/hec-assembler-tools/he_as.py @@ -250,8 +250,8 @@ def asmisaAssemble( if b_verbose: print("Interpreting variable meta information...") - with open(mem_filename, "r") as mem_ifnum: - mem_meta_info = mem_info.MemInfo.from_iter(mem_ifnum) + with open(mem_filename, 'r') as mem_ifnum: + mem_meta_info = mem_info.MemInfo.from_file_iter(mem_ifnum) mem_info.updateMemoryModelWithMemInfo(hec_mem_model, mem_meta_info) if b_verbose: diff --git a/assembler_tools/hec-assembler-tools/he_link.py b/assembler_tools/hec-assembler-tools/he_link.py index 2560ad99..f29390cb 100644 --- a/assembler_tools/hec-assembler-tools/he_link.py +++ b/assembler_tools/hec-assembler-tools/he_link.py @@ -45,6 +45,7 @@ from linker import loader from linker.steps import variable_discovery from linker.steps import program_linker +from linker.instructions import BaseInstruction class LinkerRunConfig(RunConfig): """ @@ -108,7 +109,8 @@ def __init__(self, **kwargs): # fix file names self.output_dir = makeUniquePath(self.output_dir) - self.input_mem_file = makeUniquePath(self.input_mem_file) + if self.input_mem_file is not None: + self.input_mem_file = makeUniquePath(self.input_mem_file) @classmethod def init_default_config(cls): @@ -118,6 +120,7 @@ def init_default_config(cls): if not cls.__initialized: cls.__default_config["input_prefixes"] = None cls.__default_config["input_mem_file"] = None + cls.__default_config["find_mem_files"] = False cls.__default_config["output_dir"] = os.getcwd() cls.__default_config["output_prefix"] = None cls.__default_config["has_hbm"] = True @@ -158,19 +161,23 @@ class KernelFiles(NamedTuple): Structure for kernel files. Attributes: + prefix (str): + Index = 0 minst (str): - Index = 0. Name for file containing MInstructions for represented kernel. + Index = 1. Name for file containing MInstructions for represented kernel. cinst (str): - Index = 1. Name for file containing CInstructions for represented kernel. + Index = 2. Name for file containing CInstructions for represented kernel. xinst (str): - Index = 2. Name for file containing XInstructions for represented kernel. - prefix (str): - Index = 3 + Index = 3. Name for file containing XInstructions for represented kernel. + mem (str, optional): + Index = 4. Name for file containing memory information for represented kernel. + This is used only when find_mem_files is set. """ + prefix: str minst: str cinst: str xinst: str - prefix: str + mem: str = None def main(run_config: LinkerRunConfig, verbose_stream = None): """ @@ -194,6 +201,7 @@ def main(run_config: LinkerRunConfig, verbose_stream = None): # Update global config GlobalConfig.hasHBM = run_config.has_hbm + GlobalConfig.suppressComments = run_config.suppress_comments mem_filename: str = run_config.input_mem_file hbm_capcity_words: int = constants.convertBytes2Words(run_config.hbm_size * constants.Constants.KILOBYTE) @@ -203,23 +211,40 @@ def main(run_config: LinkerRunConfig, verbose_stream = None): # prepare output file names output_prefix = os.path.join(run_config.output_dir, run_config.output_prefix) output_dir = os.path.dirname(output_prefix) + + # If find_mem_files is enabled, create a new memory file for the output + out_mem_file = None + if run_config.find_mem_files: + out_mem_file = makeUniquePath(output_prefix + '.mem') + pathlib.Path(output_dir).mkdir(exist_ok = True, parents=True) - output_files = KernelFiles(minst=makeUniquePath(output_prefix + '.minst'), + output_files = KernelFiles(prefix=makeUniquePath(output_prefix), + minst=makeUniquePath(output_prefix + '.minst'), cinst=makeUniquePath(output_prefix + '.cinst'), xinst=makeUniquePath(output_prefix + '.xinst'), - prefix=makeUniquePath(output_prefix)) + mem=out_mem_file + ) # prepare input file names for file_prefix in run_config.input_prefixes: - input_files.append(KernelFiles(minst=makeUniquePath(file_prefix + '.minst'), + + # If find_mem_files is enabled, try to find a .tw.mem file for each prefix + mem_file = None + if run_config.find_mem_files: + mem_file = makeUniquePath(file_prefix + '.mem') + + input_files.append(KernelFiles(prefix=makeUniquePath(file_prefix), + minst=makeUniquePath(file_prefix + '.minst'), cinst=makeUniquePath(file_prefix + '.cinst'), xinst=makeUniquePath(file_prefix + '.xinst'), - prefix=makeUniquePath(file_prefix))) - for input_filename in input_files[-1][:-1]: - if not os.path.isfile(input_filename): - raise FileNotFoundError(input_filename) - if input_filename in output_files: - raise RuntimeError(f'Input files cannot match output files: "{input_filename}"') + mem=mem_file)) + + for input_filename in input_files[-1][1:]: + if input_filename: + if not os.path.isfile(input_filename): + raise FileNotFoundError(input_filename) + if input_filename in output_files: + raise RuntimeError(f'Input files cannot match output files: "{input_filename}"') # reset counters Counter.reset() @@ -230,10 +255,23 @@ def main(run_config: LinkerRunConfig, verbose_stream = None): print("", file=verbose_stream) print("Interpreting variable meta information...", file=verbose_stream) - with open(mem_filename, 'r') as mem_ifnum: - mem_meta_info = mem_info.MemInfo.from_iter(mem_ifnum) + if run_config.find_mem_files: + if verbose_stream: + print("Linking together multiple memory files", file=verbose_stream) + + kernels_dinstrs = [] + for kernel in input_files: + kernel_dinstrs = loader.load_dinst_kernel_from_file(kernel.mem) + kernels_dinstrs.append(kernel_dinstrs) + + # Concatenate all mem info objects into one + kernel_dinstrs = program_linker.LinkedProgram.join_dinst_kernels(kernels_dinstrs) + mem_meta_info = mem_info.MemInfo.from_dinstrs(kernel_dinstrs) + else: + with open(mem_filename, 'r') as mem_ifnum: + mem_meta_info = mem_info.MemInfo.from_file_iter(mem_ifnum) - # initialize memory model + # Initialize memory model if verbose_stream: print("Initializing linker memory model", file=verbose_stream) @@ -241,7 +279,7 @@ def main(run_config: LinkerRunConfig, verbose_stream = None): if verbose_stream: print(f" HBM capacity: {mem_model.hbm.capacity} words", file=verbose_stream) - # find all variables and usage across all the input kernels + # Find all variables and usage across all the input kernels if verbose_stream: print(" Finding all program variables...", file=verbose_stream) @@ -250,18 +288,18 @@ def main(run_config: LinkerRunConfig, verbose_stream = None): for idx, kernel in enumerate(input_files): if not GlobalConfig.hasHBM: if verbose_stream: - print(" {}/{}".format(idx + 1, len(input_files)), kernel.cinst, + print(f" {idx + 1}/{len(input_files)}", kernel.cinst, file=verbose_stream) # load next CInst kernel and scan for variables used in SPAD - kernel_cinstrs = loader.loadCInstKernelFromFile(kernel.cinst) + kernel_cinstrs = loader.load_cinst_kernel_from_file(kernel.cinst) for var_name in variable_discovery.discoverVariablesSPAD(kernel_cinstrs): mem_model.addVariable(var_name) else: if verbose_stream: - print(" {}/{}".format(idx + 1, len(input_files)), kernel.minst, + print(f" {idx + 1}/{len(input_files)}", kernel.minst, file=verbose_stream) - # load next MInst kernel and scan for variables used - kernel_minstrs = loader.loadMInstKernelFromFile(kernel.minst) + # Load next MInst kernel and scan for variables used + kernel_minstrs = loader.load_minst_kernel_from_file(kernel.minst) for var_name in variable_discovery.discoverVariables(kernel_minstrs): mem_model.addVariable(var_name) @@ -293,9 +331,9 @@ def main(run_config: LinkerRunConfig, verbose_stream = None): if verbose_stream: print("[ {: >3}% ]".format(idx * 100 // len(input_files)), kernel.prefix, file=verbose_stream) - kernel_minstrs = loader.loadMInstKernelFromFile(kernel.minst) - kernel_cinstrs = loader.loadCInstKernelFromFile(kernel.cinst) - kernel_xinstrs = loader.loadXInstKernelFromFile(kernel.xinst) + kernel_minstrs = loader.load_minst_kernel_from_file(kernel.minst) + kernel_cinstrs = loader.load_cinst_kernel_from_file(kernel.cinst) + kernel_xinstrs = loader.load_xinst_kernel_from_file(kernel.xinst) result_program.linkKernel(kernel_minstrs, kernel_cinstrs, kernel_xinstrs) @@ -305,11 +343,19 @@ def main(run_config: LinkerRunConfig, verbose_stream = None): # signal that we have linked all kernels result_program.close() + if run_config.find_mem_files: + # Write the memory model to the output file + if verbose_stream: + print("Writing memory model to", output_files.mem, file=verbose_stream) + BaseInstruction.dump_instructions_to_file(kernel_dinstrs, output_files.mem) + if verbose_stream: print("Output written to files:", file=verbose_stream) print(" ", output_files.minst, file=verbose_stream) print(" ", output_files.cinst, file=verbose_stream) print(" ", output_files.xinst, file=verbose_stream) + if run_config.find_mem_files: + print(" ", output_files.mem, file=verbose_stream) def parse_args(): """ @@ -338,12 +384,14 @@ def parse_args(): help=("Input Mem specification (.json) file.")) parser.add_argument("--isa_spec", default="", dest="isa_spec_file", help=("Input ISA specification (.json) file.")) - parser.add_argument("-im", "--input_mem_file", dest="input_mem_file", required=True, + parser.add_argument("--find_mem_files", action="store_true", dest="find_mem_files", + help=("Tells the linker to find a memory file (*.tw.mem) for each input prefix given." + "This can be used to link multiple kernels together. " + "If this flag is not set, the linker will use the input_mem_file argument instead")) + parser.add_argument("-im", "--input_mem_file", dest="input_mem_file", required=False, help=("Input memory mapping file associated with the resulting program. " - "Specifies the names for input, output, and metadata variables for the full program. " - "This file is usually the same as the kernel's when converting a single kernel into " - "a program, but, when linking multiple kernels together, it should be tailored to the " - "whole program.")) + "Specifies the names for input, output, and metadata variables for a single kernel" + " or also a full program if instead this is used to link multiple kernels together.")) parser.add_argument("-o", "--output_prefix", dest="output_prefix", required=True, help=("Prefix for the output file names. " "Three files will be generated: \n" @@ -364,6 +412,10 @@ def parse_args(): "Increase level of verbosity by specifying flag multiple times, e.g. -vv")) args = parser.parse_args() + # Enforce input_mem_file only if find_mem_files is not set + if not args.find_mem_files and not args.input_mem_file: + parser.error("the following arguments are required: -im/--input_mem_file (unless --find_mem_files is set)") + return args if __name__ == "__main__": diff --git a/assembler_tools/hec-assembler-tools/linker/__init__.py b/assembler_tools/hec-assembler-tools/linker/__init__.py index a32676d9..b26839bb 100644 --- a/assembler_tools/hec-assembler-tools/linker/__init__.py +++ b/assembler_tools/hec-assembler-tools/linker/__init__.py @@ -202,13 +202,17 @@ def addVariable(self, var_name: str): """ var_info: VariableInfo if var_name in self.variables: + print(f' ROCHA Variable {var_name} already exists in memory model.') var_info = self.variables[var_name] else: + print(f'ROCHA Adding variable {var_name} to memory model.') var_info = VariableInfo(var_name) if var_name in self.__mem_info_vars: + print(f'\tROCHA Variable {var_name} is in MemInfo, allocating HBM address. HBM address {self.__mem_info_vars[var_name].hbm_address}.') # Variables explicitly marked in mem file must persist throughout the program # with predefined HBM address if var_name in self.__mem_info_fixed_addr_vars: + print(f'\tROCHA Variable {var_name} has fixed HBM address {self.__mem_info_vars[var_name].hbm_address}.') var_info.uses = float('inf') self.hbm.forceAllocate(var_info, self.__mem_info_vars[var_name].hbm_address) diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/__init__.py b/assembler_tools/hec-assembler-tools/linker/instructions/__init__.py index 135608cc..2ccbf24d 100644 --- a/assembler_tools/hec-assembler-tools/linker/instructions/__init__.py +++ b/assembler_tools/hec-assembler-tools/linker/instructions/__init__.py @@ -4,8 +4,7 @@ from assembler.instructions import tokenize_from_line from linker.instructions.instruction import BaseInstruction - -def fromStrLine(line: str, factory) -> BaseInstruction: +def create_from_str_line(line: str, factory) -> BaseInstruction: """ Parses an instruction from a line of text. diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/dinst/__init__.py b/assembler_tools/hec-assembler-tools/linker/instructions/dinst/__init__.py new file mode 100644 index 00000000..92f5f9e3 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/linker/instructions/dinst/__init__.py @@ -0,0 +1,49 @@ +from . import dload, dstore, dkeygen +from . import dinstruction +from assembler.instructions import tokenizeFromLine +from assembler.memory_model.mem_info import MemInfo, MemInfoVariable + +DLoad = dload.Instruction +DStore = dstore.Instruction +DKeyGen = dkeygen.Instruction + +def factory() -> set: + """ + Creates a set of all DInstruction classes. + + Returns: + set: A set containing all DInstruction classes. + """ + return { DLoad, DStore, DKeyGen } + +def create_from_mem_line(line: str) -> dinstruction.DInstruction: + """ + Parses an data instruction from a line of the memory map. + + Parameters: + line (str): Line of text from which to parse an instruction. + + Returns: + DInstruction or None: The parsed DInstruction object, or None if no object could be + parsed from the specified input line. + """ + retval = None + tokens, comment = tokenizeFromLine(line) + for instr_type in factory(): + try: + retval = instr_type(tokens, comment) + except ValueError: + retval = None + if retval: + break + + try: + miv, _ = MemInfo.get_meminfo_var_from_tokens(tokens) + except RuntimeError as e: + raise RuntimeError(f'Error parsing memory map line "{line}"') from e + + miv_dict = miv.as_dict() + retval.var = miv_dict["var_name"] + retval.address = miv_dict["hbm_address"] + + return retval \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dinstruction.py b/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dinstruction.py new file mode 100644 index 00000000..74ff95e0 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dinstruction.py @@ -0,0 +1,96 @@ +from linker.instructions.instruction import BaseInstruction +from assembler.common.counter import Counter +from assembler.common.decorators import * + +class DInstruction(BaseInstruction): + """ + Represents a DInstruction, inheriting from BaseInstruction. + """ + + _local_id_count = Counter.count(0) # Local counter for DInstruction IDs + _var: str = "" + _address: int = 0 + + @classmethod + def _get_name_token_index(cls) -> int: + """ + Gets the index of the token containing the name of the instruction. + + Returns: + int: The index of the name token, which is 0. + """ + return 0 + + @classproperty + def num_tokens(cls) -> int: + """ + Valid number of tokens for this instruction. + + Returns: + tupple: Valid number of tokens. + """ + return cls._get_num_tokens() + + def __init__(self, tokens: list, comment: str = ""): + """ + Constructs a new DInstruction. + + Parameters: + tokens (list): List of tokens for the instruction. + comment (str): Optional comment for the instruction. + """ + # Do not increment the global instruction count; skip BaseInstruction's __init__ logic for __id + assert self.name_token_index < self.num_tokens + + if len(tokens) > self.num_tokens: + raise ValueError((f"`tokens`: invalid amount of tokens. " + f"Instruction {self.name} requires at least {self.num_tokens}, but {len(tokens)} received")) + if tokens[self.name_token_index] != self.name: + raise ValueError(f"`tokens`: invalid name. Expected {self.name}, but {tokens[self.name_token_index]} received") + + self.comment = comment + self._tokens = list(tokens) + self._local_id = next(DInstruction._local_id_count) + + @property + def id(self): + """ + Unique ID for the instruction. + + This is a combination of the client ID specified during construction and a unique nonce per instruction. + + Returns: + tuple: (client_id: int, nonce: int) where client_id is the id specified at construction. + """ + return self._local_id + + @property + def var(self) -> str: + """ + Name of source/dest var. + """ + return self._var + + @var.setter + def var(self, value: str): + self._var = value + + @property + def address(self) -> str: + """ + Should be set to source/dest Mem address. + """ + return self._address + + @address.setter + def address(self, value: str): + self._address = value + + def to_line(self) -> str: + """ + Retrieves the string form of the instruction to write to the instruction file. + + Returns: + str: The string representation of the instruction. + """ + return ", ".join(self.tokens) diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dkeygen.py b/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dkeygen.py new file mode 100644 index 00000000..ba19d31e --- /dev/null +++ b/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dkeygen.py @@ -0,0 +1,29 @@ +from .dinstruction import DInstruction +from assembler.common.config import GlobalConfig +from assembler.memory_model.mem_info import MemInfo + +class Instruction(DInstruction): + """ + Encapsulates a `dkeygen` DInstruction. + """ + + @classmethod + def _get_num_tokens(cls) -> int: + """ + Gets the number of tokens allowed for the instruction. + + Returns: + int: The number of tokens, which is 4. + """ + return 4 + + @classmethod + def _get_name(cls) -> str: + """ + Gets the name of the instruction. + + Returns: + str: The name of the instruction. + """ + return MemInfo.Const.Keyword.KEYGEN + diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dload.py b/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dload.py new file mode 100644 index 00000000..b398f8d1 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dload.py @@ -0,0 +1,38 @@ +from .dinstruction import DInstruction +from assembler.memory_model.mem_info import MemInfo + +class Instruction(DInstruction): + """ + Encapsulates a `dload` DInstruction. + """ + + @classmethod + def _get_num_tokens(cls) -> tuple: + """ + Gets the number of tokens allowed for the instruction. + + Returns: + tupple: The number of tokens, which is 4. + """ + return 4 + + @classmethod + def _get_name(cls) -> str: + """ + Gets the name of the instruction. + + Returns: + str: The name of the instruction. + """ + return MemInfo.Const.Keyword.LOAD + + @property + def tokens(self) -> list: + """ + Gets the list of tokens for the instruction. + + Returns: + list: The list of tokens. + """ + return [self.name, self._tokens[1], str(self.address)] + self._tokens[3:] + diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dstore.py b/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dstore.py new file mode 100644 index 00000000..543172d6 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dstore.py @@ -0,0 +1,38 @@ +from .dinstruction import DInstruction +from assembler.common.config import GlobalConfig +from assembler.memory_model.mem_info import MemInfo + +class Instruction(DInstruction): + """ + Encapsulates a `dstore` DInstruction. + """ + + @classmethod + def _get_num_tokens(cls) -> int: + """ + Gets the number of tokens allowed for the instruction. + + Returns: + int: The number of tokens, which is 3. + """ + return 3 + + @classmethod + def _get_name(cls) -> str: + """ + Gets the name of the instruction. + + Returns: + str: The name of the instruction. + """ + return MemInfo.Const.Keyword.STORE + + @property + def tokens(self) -> list: + """ + Gets the list of tokens for the instruction. + + Returns: + list: The list of tokens. + """ + return [self.name, self.var, str(self.address)] + self._tokens[3:] diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/instruction.py b/assembler_tools/hec-assembler-tools/linker/instructions/instruction.py index ad029607..bb51287a 100644 --- a/assembler_tools/hec-assembler-tools/linker/instructions/instruction.py +++ b/assembler_tools/hec-assembler-tools/linker/instructions/instruction.py @@ -1,5 +1,6 @@ from assembler.common.decorators import * from assembler.common.counter import Counter +from assembler.common.config import GlobalConfig class BaseInstruction: """ @@ -49,7 +50,7 @@ def _get_name(cls) -> str: raise NotImplementedError() @classproperty - def NAME_TOKEN_INDEX(cls) -> int: + def name_token_index(cls) -> int: """ Index for the token containing the name of the instruction in the list of tokens. @@ -72,7 +73,7 @@ def _get_name_token_index(cls) -> int: raise NotImplementedError() @classproperty - def NUM_TOKENS(cls) -> int: + def num_tokens(cls) -> int: """ Number of tokens required for this instruction. @@ -92,6 +93,21 @@ def _get_num_tokens(cls) -> int: """ raise NotImplementedError() + @classmethod + def dump_instructions_to_file(cls, instructions: list, filename: str): + """ + Writes a list of instruction objects to a file, one per line. + + Each instruction is converted to its string representation using the `to_line()` method. + + Args: + instructions (list): List of instruction objects (must have a to_line() method). + filename (str): Path to the output file. + """ + with open(filename, 'w') as f: + for instr in instructions: + f.write(instr.to_line() + '\n') + # Constructor # ----------- @@ -106,28 +122,22 @@ def __init__(self, tokens: list, comment: str = ""): Raises: ValueError: If the number of tokens is invalid or the instruction name is incorrect. """ - assert self.NAME_TOKEN_INDEX < self.NUM_TOKENS + assert self.name_token_index < self.num_tokens - if len(tokens) != self.NUM_TOKENS: - raise ValueError(('`tokens`: invalid amount of tokens. ' - 'Instruction {} requires {}, but {} received').format(self.name, - self.NUM_TOKENS, - len(tokens))) - if tokens[self.NAME_TOKEN_INDEX] != self.name: - raise ValueError('`tokens`: invalid name. Expected {}, but {} received'.format(self.name, - tokens[self.NAME_TOKEN_INDEX])) + if len(tokens) != self.num_tokens: + raise ValueError((f"`tokens`: invalid amount of tokens. " + f"Instruction {self.name} requires less " + f"than {self.num_tokens}, but {len(tokens)} received")) + if tokens[self.name_token_index] != self.name: + raise ValueError(f"`tokens`: invalid name. Expected {self.name}, but {tokens[self.name_token_index]} received") - self.__id = next(BaseInstruction.__id_count) + self._id = next(BaseInstruction.__id_count) - self.__tokens = list(tokens) + self._tokens = list(tokens) self.comment = comment def __repr__(self): - retval = ('<{}({}, id={}) object at {}>(tokens={})').format(type(self).__name__, - self.name, - self.id, - hex(id(self)), - self.token) + retval = (f"<{type(self).__name__}({self.name}, id={self.id}) object at {hex(id(self))}>(tokens={self.tokens})") return retval def __eq__(self, other): @@ -153,7 +163,7 @@ def id(self) -> tuple: Returns: tuple: (client_id: int, nonce: int) where client_id is the id specified at construction. """ - return self.__id + return self._id @property def tokens(self) -> list: @@ -163,7 +173,7 @@ def tokens(self) -> list: Returns: list: The list of tokens. """ - return self.__tokens + return self._tokens def to_line(self) -> str: """ @@ -172,4 +182,8 @@ def to_line(self) -> str: Returns: str: The string representation of the instruction. """ - return ", ".join(self.tokens) \ No newline at end of file + if not GlobalConfig.suppressComments: + comment_str = f" # {self.comment}" if self.comment else "" + + tokens_str = ", ".join(self._tokens) + return f"{tokens_str}{comment_str}" \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/linker/loader.py b/assembler_tools/hec-assembler-tools/linker/loader.py index 26894912..bbbe5b83 100644 --- a/assembler_tools/hec-assembler-tools/linker/loader.py +++ b/assembler_tools/hec-assembler-tools/linker/loader.py @@ -1,9 +1,11 @@ from linker.instructions import minst from linker.instructions import cinst from linker.instructions import xinst +from linker.instructions import dinst from linker import instructions +from assembler.memory_model.mem_info import MemInfo -def loadMInstKernel(line_iter) -> list: +def load_minst_kernel(line_iter) -> list: """ Loads MInstruction kernel from an iterator of lines. @@ -18,13 +20,13 @@ def loadMInstKernel(line_iter) -> list: """ retval = [] for idx, s_line in enumerate(line_iter): - minstr = instructions.fromStrLine(s_line, minst.factory()) + minstr = instructions.create_from_str_line(s_line, minst.factory()) if not minstr: raise RuntimeError(f'Error parsing line {idx + 1}: {s_line}') retval.append(minstr) return retval -def loadMInstKernelFromFile(filename: str) -> list: +def load_minst_kernel_from_file(filename: str) -> list: """ Loads MInstruction kernel from a file. @@ -39,11 +41,11 @@ def loadMInstKernelFromFile(filename: str) -> list: """ with open(filename, 'r') as kernel_minsts: try: - return loadMInstKernel(kernel_minsts) + return load_minst_kernel(kernel_minsts) except Exception as e: raise RuntimeError(f'Error occurred loading file "{filename}"') from e -def loadCInstKernel(line_iter) -> list: +def load_cinst_kernel(line_iter) -> list: """ Loads CInstruction kernel from an iterator of lines. @@ -58,13 +60,13 @@ def loadCInstKernel(line_iter) -> list: """ retval = [] for idx, s_line in enumerate(line_iter): - cinstr = instructions.fromStrLine(s_line, cinst.factory()) + cinstr = instructions.create_from_str_line(s_line, cinst.factory()) if not cinstr: raise RuntimeError(f'Error parsing line {idx + 1}: {s_line}') retval.append(cinstr) return retval -def loadCInstKernelFromFile(filename: str) -> list: +def load_cinst_kernel_from_file(filename: str) -> list: """ Loads CInstruction kernel from a file. @@ -79,11 +81,11 @@ def loadCInstKernelFromFile(filename: str) -> list: """ with open(filename, 'r') as kernel_cinsts: try: - return loadCInstKernel(kernel_cinsts) + return load_cinst_kernel(kernel_cinsts) except Exception as e: raise RuntimeError(f'Error occurred loading file "{filename}"') from e -def loadXInstKernel(line_iter) -> list: +def load_xinst_kernel(line_iter) -> list: """ Loads XInstruction kernel from an iterator of lines. @@ -98,13 +100,13 @@ def loadXInstKernel(line_iter) -> list: """ retval = [] for idx, s_line in enumerate(line_iter): - xinstr = instructions.fromStrLine(s_line, xinst.factory()) + xinstr = instructions.create_from_str_line(s_line, xinst.factory()) if not xinstr: raise RuntimeError(f'Error parsing line {idx + 1}: {s_line}') retval.append(xinstr) return retval -def loadXInstKernelFromFile(filename: str) -> list: +def load_xinst_kernel_from_file(filename: str) -> list: """ Loads XInstruction kernel from a file. @@ -119,6 +121,48 @@ def loadXInstKernelFromFile(filename: str) -> list: """ with open(filename, 'r') as kernel_xinsts: try: - return loadXInstKernel(kernel_xinsts) + return load_xinst_kernel(kernel_xinsts) except Exception as e: - raise RuntimeError(f'Error occurred loading file "{filename}"') from e \ No newline at end of file + raise RuntimeError(f'Error occurred loading file "{filename}"') from e + +def load_dinst_kernel(line_iter) -> list: + """ + Loads DInstruction kernel from an iterator of lines. + + Parameters: + line_iter: An iterator over lines of DInstruction strings. + + Returns: + list: A list of DInstruction objects. + + Raises: + RuntimeError: If a line cannot be parsed into an DInstruction. + """ + retval = [] + for idx, s_line in enumerate(line_iter): + dinstr = dinst.create_from_mem_line(s_line) + if not dinstr: + raise RuntimeError(f'Error parsing line {idx + 1}: {s_line}') + retval.append(dinstr) + + return retval + +def load_dinst_kernel_from_file(filename: str) -> list: + """ + Loads DInstruction kernel from a file. + + Parameters: + filename (str): The file containing DInstruction strings. + + Returns: + list: A list of DInstruction objects. + + Raises: + RuntimeError: If an error occurs while loading the file. + """ + with open(filename, 'r') as kernel_dinsts: + try: + return load_dinst_kernel(kernel_dinsts) + except Exception as e: + raise RuntimeError(f'Error occurred loading file "{filename}"') from e + \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/linker/steps/program_linker.py b/assembler_tools/hec-assembler-tools/linker/steps/program_linker.py index 6f6c1999..e80cea99 100644 --- a/assembler_tools/hec-assembler-tools/linker/steps/program_linker.py +++ b/assembler_tools/hec-assembler-tools/linker/steps/program_linker.py @@ -1,8 +1,10 @@ from linker import MemoryModel -from linker.instructions import minst, cinst, xinst +from linker.instructions import minst, cinst, dinst +from linker.instructions.dinst.dinstruction import DInstruction from assembler.common.config import GlobalConfig from assembler.instructions import cinst as ISACInst +from assembler.memory_model.mem_info import MemInfo class LinkedProgram: """ @@ -349,4 +351,61 @@ def linkKernel(self, self.__minst_line_offset += (len(kernel_minstrs) - 1) # Subtract last line that is getting removed self.__cinst_line_offset += (len(kernel_cinstrs) - 1) # Subtract last line that is getting removed - self.__kernel_count += 1 # Count the appended kernel \ No newline at end of file + self.__kernel_count += 1 # Count the appended kernel + + @classmethod + def join_dinst_kernels(cls, kernels_instrs: list[list[DInstruction]]) -> list[DInstruction]: + """ + Joins a list of dinst kernels, consolidating variables that are outputs in one kernel + and inputs in the next. This ensures that variables carried across kernels are not duplicated, + and their Mem addresses are consistent. + + Args: + kernels_instrs (list): List of Kernels' DInstructions lists. + + Returns: + list[DInstructions]: A new instruction list representing the concatenated memory info. + """ + + if not kernels_instrs: + raise ValueError("No DInstructions lists provided for concatenation.") + + # Use dictionaries to track unique variables by name + inputs: dict[str: DInstruction] = {} + carry_over_vars: dict[str: DInstruction] = {} + + mem_address: int = 0 + new_kernels_instrs: list[DInstruction] = [] + for k_idx, kernel_instrs in enumerate(kernels_instrs): + + for idx, cur_dinst in enumerate(kernel_instrs): + + # Save the current output instruction to add at the end + if isinstance(cur_dinst, dinst.DStore): + key = cur_dinst.var + carry_over_vars[key] = cur_dinst + continue + + if isinstance(cur_dinst, (dinst.DLoad, dinst.DKeyGen)): + key = cur_dinst.var + # Skip if the input is already in carry-over from previous outputs + if key in carry_over_vars: + carry_over_vars.pop(key) # Remove from (output) carry-overs since it's now an input + continue + + # If the input is not (a previous output) in carry-over, add if it's not already (loaded) in inputs + if key not in inputs: + inputs[key] = cur_dinst + cur_dinst.address = mem_address + mem_address = mem_address + 1 + + new_kernels_instrs.append(cur_dinst) + continue + + # Add remaining carry-over variables to the new instructions + for var in carry_over_vars: + carry_over_vars[var].address = mem_address + new_kernels_instrs.append(carry_over_vars[var]) + mem_address = mem_address + 1 + + return new_kernels_instrs From 04ed69ecf6b9356a3c3bffa5c85f5ce3576eb17e Mon Sep 17 00:00:00 2001 From: "Rojas Chaves, Jose" Date: Tue, 1 Jul 2025 00:48:04 +0000 Subject: [PATCH 02/12] post rebase fixes --- .../assembler/common/config.py | 33 +- .../assembler/common/run_config.py | 61 +- .../assembler/instructions/instruction.py | 2 +- .../assembler/memory_model/mem_info.py | 128 ++-- .../hec-assembler-tools/debug_tools/main.py | 6 +- assembler_tools/hec-assembler-tools/he_as.py | 6 +- .../hec-assembler-tools/he_link.py | 629 +++++++++++------- .../linker/instructions/dinst/__init__.py | 19 +- .../linker/instructions/instruction.py | 57 +- .../linker/steps/program_linker.py | 532 ++++++++------- 10 files changed, 871 insertions(+), 602 deletions(-) diff --git a/assembler_tools/hec-assembler-tools/assembler/common/config.py b/assembler_tools/hec-assembler-tools/assembler/common/config.py index 4e31682e..49f2578a 100644 --- a/assembler_tools/hec-assembler-tools/assembler/common/config.py +++ b/assembler_tools/hec-assembler-tools/assembler/common/config.py @@ -1,29 +1,34 @@ - +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""A configuration class for controlling various aspects of the assembler's behavior.""" + + class GlobalConfig: """ A configuration class for controlling various aspects of the assembler's behavior. Attributes: - suppressComments (bool): + suppress_comments (bool): If True, no comments will be emitted in the output generated by the assembler. - - useHBMPlaceHolders (bool): - Specifies whether to use placeholders (names) for variable locations in HBM + + useHBMPlaceHolders (bool): + Specifies whether to use placeholders (names) for variable locations in HBM or the actual variable locations. - - useXInstFetch (bool): - Specifies whether `xinstfetch` instructions should be added into CInstQ or not. - When no `xinstfetch` instructions are added, it is assumed that the HERACLES + + useXInstFetch (bool): + Specifies whether `xinstfetch` instructions should be added into CInstQ or not. + When no `xinstfetch` instructions are added, it is assumed that the HERACLES automated mechanism for `xinstfetch` will be activated. - - debugVerbose (int): - If greater than 0, verbose prints will occur. Its value indicates how often to - print within loops (every `debugVerbose` iterations). This is used for internal + + debugVerbose (int): + If greater than 0, verbose prints will occur. Its value indicates how often to + print within loops (every `debugVerbose` iterations). This is used for internal debugging purposes. hashHBM (bool): Specifies whether the target architecture has HBM or not. """ - suppressComments = False + suppress_comments = False useHBMPlaceHolders = True useXInstFetch = True debugVerbose: int = 0 diff --git a/assembler_tools/hec-assembler-tools/assembler/common/run_config.py b/assembler_tools/hec-assembler-tools/assembler/common/run_config.py index 9737a962..54600a56 100644 --- a/assembler_tools/hec-assembler-tools/assembler/common/run_config.py +++ b/assembler_tools/hec-assembler-tools/assembler/common/run_config.py @@ -1,9 +1,13 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + import io from .decorators import * from . import constants from .config import GlobalConfig + class RunConfig: """ Configuration class for running the assembler with specific settings. @@ -12,11 +16,12 @@ class RunConfig: policies, and other options that affect the behavior of the assembler. """ - __initialized = False # Specifies whether static members have been initialized - __default_config = {} # Dictionary of all configuration items supported and their default values + __initialized = False # Specifies whether static members have been initialized + __default_config = ( + {} + ) # Dictionary of all configuration items supported and their default values - def __init__(self, - **kwargs): + def __init__(self, **kwargs): """ Constructs a new RunConfig object from input parameters. @@ -33,7 +38,7 @@ def __init__(self, suppress_comments (bool, optional): If true, no comments will be emitted in the output generated by the assembler. - Defaults to GlobalConfig.suppressComments (`False`). + Defaults to GlobalConfig.suppress_comments (`False`). use_hbm_placeholders (bool, optional): [DEPRECATED]/[UNUSED] Specifies whether to use placeholders (names) for variable locations in HBM (`True`) @@ -52,28 +57,42 @@ def __init__(self, ValueError: If at least one of the arguments passed is invalid. """ + RunConfig.init_default_config() + # Initialize class members for config_name, default_value in self.__default_config.items(): - setattr(self, config_name, kwargs.get(config_name, default_value)) + value = kwargs.get(config_name) + if value is not None: + setattr(self, config_name, value) + else: + setattr(self, config_name, default_value) # Validate inputs if self.repl_policy not in constants.Constants.REPLACEMENT_POLICIES: - raise ValueError('Invalid `repl_policy`. "{}" not in {}'.format(self.repl_policy, - constants.Constants.REPLACEMENT_POLICIES)) + raise ValueError( + 'Invalid `repl_policy`. "{}" not in {}'.format( + self.repl_policy, constants.Constants.REPLACEMENT_POLICIES + ) + ) + @classproperty def DEFAULT_HBM_SIZE_KB(cls) -> int: - return int(constants.MemoryModel.HBM.MAX_CAPACITY / constants.Constants.KILOBYTE) + return int( + constants.MemoryModel.HBM.MAX_CAPACITY / constants.Constants.KILOBYTE + ) @classproperty def DEFAULT_SPAD_SIZE_KB(cls) -> int: - return int(constants.MemoryModel.SPAD.MAX_CAPACITY / constants.Constants.KILOBYTE) + return int( + constants.MemoryModel.SPAD.MAX_CAPACITY / constants.Constants.KILOBYTE + ) @classproperty def DEFAULT_REPL_POLICY(cls) -> int: return constants.Constants.REPLACEMENT_POLICY_FTBU @classmethod - def init_static(cls): + def init_default_config(cls): """ Initializes static members of the RunConfig class. @@ -81,13 +100,14 @@ def init_static(cls): that they are only initialized once. """ if not cls.__initialized: - cls.__default_config["hbm_size"] = cls.DEFAULT_HBM_SIZE_KB - cls.__default_config["spad_size"] = cls.DEFAULT_SPAD_SIZE_KB - cls.__default_config["repl_policy"] = cls.DEFAULT_REPL_POLICY - cls.__default_config["suppress_comments"] = GlobalConfig.suppressComments - #cls.__default_config["use_hbm_placeholders"] = GlobalConfig.useHBMPlaceHolders - cls.__default_config["use_xinstfetch"] = GlobalConfig.useXInstFetch - cls.__default_config["debug_verbose"] = GlobalConfig.debugVerbose + cls.__default_config["has_hbm"] = True + cls.__default_config["hbm_size"] = cls.DEFAULT_HBM_SIZE_KB + cls.__default_config["spad_size"] = cls.DEFAULT_SPAD_SIZE_KB + cls.__default_config["repl_policy"] = cls.DEFAULT_REPL_POLICY + cls.__default_config["suppress_comments"] = GlobalConfig.suppress_comments + # cls.__default_config["use_hbm_placeholders"] = GlobalConfig.useHBMPlaceHolders + cls.__default_config["use_xinstfetch"] = GlobalConfig.useXInstFetch + cls.__default_config["debug_verbose"] = GlobalConfig.debugVerbose cls.__initialized = True @@ -113,4 +133,7 @@ def as_dict(self) -> dict: dict: A dictionary representation of the current configuration settings. """ tmp_self_dict = vars(self) - return { config_name: tmp_self_dict[config_name] for config_name in self.__default_config } + return { + config_name: tmp_self_dict[config_name] + for config_name in self.__default_config + } diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/instruction.py b/assembler_tools/hec-assembler-tools/assembler/instructions/instruction.py index ecbf9e9b..0418b49e 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/instruction.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/instruction.py @@ -552,7 +552,7 @@ def to_string_format(self, preamble, op_name: str, *extra_args) -> str: retval = f'{", ".join(str(x) for x in preamble)}, {retval}' if extra_args: retval += f', {", ".join([str(extra) for extra in extra_args])}' - if not GlobalConfig.suppressComments: + if not GlobalConfig.suppress_comments: if self.comment: retval += f" #{self.comment}" return retval diff --git a/assembler_tools/hec-assembler-tools/assembler/memory_model/mem_info.py b/assembler_tools/hec-assembler-tools/assembler/memory_model/mem_info.py index a30cc7c5..834fc8c5 100644 --- a/assembler_tools/hec-assembler-tools/assembler/memory_model/mem_info.py +++ b/assembler_tools/hec-assembler-tools/assembler/memory_model/mem_info.py @@ -125,9 +125,11 @@ def parse_from_mem_tokens(cls, tokens: list) -> MemInfoVariable: Returns: MemInfoVariable: The parsed ones metadata variable. """ - return MemInfo.Metadata.parse_meta_field_from_mem_tokens(tokens, - MemInfo.Const.Keyword.LOAD_ONES, - var_prefix=MemInfo.Const.Keyword.LOAD_ONES) + return MemInfo.Metadata.parse_meta_field_from_mem_tokens( + tokens, + MemInfo.Const.Keyword.LOAD_ONES, + var_prefix=MemInfo.Const.Keyword.LOAD_ONES, + ) class NTTAuxTable: @classmethod @@ -141,9 +143,11 @@ def parse_from_mem_tokens(cls, tokens: list) -> MemInfoVariable: Returns: MemInfoVariable: The parsed NTT auxiliary table metadata variable. """ - return MemInfo.Metadata.parse_meta_field_from_mem_tokens(tokens, - MemInfo.Const.Keyword.LOAD_NTT_AUX_TABLE, - var_prefix=MemInfo.Const.Keyword.LOAD_NTT_AUX_TABLE) + return MemInfo.Metadata.parse_meta_field_from_mem_tokens( + tokens, + MemInfo.Const.Keyword.LOAD_NTT_AUX_TABLE, + var_prefix=MemInfo.Const.Keyword.LOAD_NTT_AUX_TABLE, + ) class NTTRoutingTable: @classmethod @@ -157,9 +161,11 @@ def parse_from_mem_tokens(cls, tokens: list) -> MemInfoVariable: Returns: MemInfoVariable: The parsed NTT routing table metadata variable. """ - return MemInfo.Metadata.parse_meta_field_from_mem_tokens(tokens, - MemInfo.Const.Keyword.LOAD_NTT_ROUTING_TABLE, - var_prefix=MemInfo.Const.Keyword.LOAD_NTT_ROUTING_TABLE) + return MemInfo.Metadata.parse_meta_field_from_mem_tokens( + tokens, + MemInfo.Const.Keyword.LOAD_NTT_ROUTING_TABLE, + var_prefix=MemInfo.Const.Keyword.LOAD_NTT_ROUTING_TABLE, + ) class iNTTAuxTable: @classmethod @@ -173,9 +179,11 @@ def parse_from_mem_tokens(cls, tokens: list) -> MemInfoVariable: Returns: MemInfoVariable: The parsed iNTT auxiliary table metadata variable. """ - return MemInfo.Metadata.parse_meta_field_from_mem_tokens(tokens, - MemInfo.Const.Keyword.LOAD_iNTT_AUX_TABLE, - var_prefix=MemInfo.Const.Keyword.LOAD_iNTT_AUX_TABLE) + return MemInfo.Metadata.parse_meta_field_from_mem_tokens( + tokens, + MemInfo.Const.Keyword.LOAD_iNTT_AUX_TABLE, + var_prefix=MemInfo.Const.Keyword.LOAD_iNTT_AUX_TABLE, + ) class iNTTRoutingTable: @classmethod @@ -189,9 +197,11 @@ def parse_from_mem_tokens(cls, tokens: list) -> MemInfoVariable: Returns: MemInfoVariable: The parsed iNTT routing table metadata variable. """ - return MemInfo.Metadata.parse_meta_field_from_mem_tokens(tokens, - MemInfo.Const.Keyword.LOAD_iNTT_ROUTING_TABLE, - var_prefix=MemInfo.Const.Keyword.LOAD_iNTT_ROUTING_TABLE) + return MemInfo.Metadata.parse_meta_field_from_mem_tokens( + tokens, + MemInfo.Const.Keyword.LOAD_iNTT_ROUTING_TABLE, + var_prefix=MemInfo.Const.Keyword.LOAD_iNTT_ROUTING_TABLE, + ) class Twiddle: @classmethod @@ -205,9 +215,11 @@ def parse_from_mem_tokens(cls, tokens: list) -> MemInfoVariable: Returns: MemInfoVariable: The parsed twiddle metadata variable. """ - return MemInfo.Metadata.parse_meta_field_from_mem_tokens(tokens, - MemInfo.Const.Keyword.LOAD_TWIDDLE, - var_prefix=MemInfo.Const.Keyword.LOAD_TWIDDLE) + return MemInfo.Metadata.parse_meta_field_from_mem_tokens( + tokens, + MemInfo.Const.Keyword.LOAD_TWIDDLE, + var_prefix=MemInfo.Const.Keyword.LOAD_TWIDDLE, + ) class KeygenSeed: @classmethod @@ -221,16 +233,20 @@ def parse_from_mem_tokens(cls, tokens: list) -> MemInfoVariable: Returns: MemInfoVariable: The parsed keygen seed metadata variable. """ - return MemInfo.Metadata.parse_meta_field_from_mem_tokens(tokens, - MemInfo.Const.Keyword.LOAD_KEYGEN_SEED, - var_prefix=MemInfo.Const.Keyword.LOAD_KEYGEN_SEED) + return MemInfo.Metadata.parse_meta_field_from_mem_tokens( + tokens, + MemInfo.Const.Keyword.LOAD_KEYGEN_SEED, + var_prefix=MemInfo.Const.Keyword.LOAD_KEYGEN_SEED, + ) @classmethod - def parse_meta_field_from_mem_tokens(cls, - tokens: list, - meta_field_name: str, - var_prefix: str = "meta", - var_extra: str = None) -> MemInfoVariable: + def parse_meta_field_from_mem_tokens( + cls, + tokens: list, + meta_field_name: str, + var_prefix: str = "meta", + var_extra: str = None, + ) -> MemInfoVariable: """ Parses a metadata variable name from a tokenized line. @@ -440,16 +456,16 @@ def __init__(self, **kwargs): kwargs (dict): A dictionary as generated by the method MemInfo.as_dict(). This is provided as a shortcut to creating a MemInfo object from structured data such as the contents of a YAML file. """ - self.__keygens = [ + self._keygens = [ MemInfoVariable(**d) for d in kwargs.get(MemInfo.Const.FIELD_KEYGENS, []) ] - self.__inputs = [ + self._inputs = [ MemInfoVariable(**d) for d in kwargs.get(MemInfo.Const.FIELD_INPUTS, []) ] - self.__outputs = [ + self._outputs = [ MemInfoVariable(**d) for d in kwargs.get(MemInfo.Const.FIELD_OUTPUTS, []) ] - self.__metadata = MemInfo.Metadata( + self._metadata = MemInfo.Metadata( **kwargs.get(MemInfo.Const.FIELD_METADATA, {}) ) self.validate() @@ -460,29 +476,33 @@ def factory_dict(self): Returns a dictionary mapping MemInfo types to their respective factory methods. This is used to create instances of MemInfoVariable based on the type of memory information. - + Returns: dict: A dictionary mapping MemInfo types to their factory methods. """ - return { MemInfo.Keygen: self.keygens, - MemInfo.Input: self.inputs, - MemInfo.Output: self.outputs, - MemInfo.Metadata.KeygenSeed: self.metadata.keygen_seeds, - MemInfo.Metadata.Ones: self.metadata.ones, - MemInfo.Metadata.NTTAuxTable: self.metadata.ntt_auxiliary_table, - MemInfo.Metadata.NTTRoutingTable: self.metadata.ntt_routing_table, - MemInfo.Metadata.iNTTAuxTable: self.metadata.intt_auxiliary_table, - MemInfo.Metadata.iNTTRoutingTable: self.metadata.intt_routing_table, - MemInfo.Metadata.Twiddle: self.metadata.twiddle } - + return { + MemInfo.Keygen: self.keygens, + MemInfo.Input: self.inputs, + MemInfo.Output: self.outputs, + MemInfo.Metadata.KeygenSeed: self.metadata.keygen_seeds, + MemInfo.Metadata.Ones: self.metadata.ones, + MemInfo.Metadata.NTTAuxTable: self.metadata.ntt_auxiliary_table, + MemInfo.Metadata.NTTRoutingTable: self.metadata.ntt_routing_table, + MemInfo.Metadata.iNTTAuxTable: self.metadata.intt_auxiliary_table, + MemInfo.Metadata.iNTTRoutingTable: self.metadata.intt_routing_table, + MemInfo.Metadata.Twiddle: self.metadata.twiddle, + } + @classproperty def mem_info_types(cls): - """ Retrieves the list of MemInfo variable types.""" + """Retrieves the list of MemInfo variable types.""" dummy = cls() - return dummy.factory_dict.keys() - + return dummy.factory_dict.keys() + @classmethod - def get_meminfo_var_from_tokens(cls, tokens) -> tuple[Optional[MemInfoVariable], Optional[type]]: + def get_meminfo_var_from_tokens( + cls, tokens + ) -> tuple[Optional[MemInfoVariable], Optional[type]]: """ Parses a MemInfo variable from a list of tokens. @@ -498,26 +518,26 @@ def get_meminfo_var_from_tokens(cls, tokens) -> tuple[Optional[MemInfoVariable], miv: MemInfoVariable = mem_info_type.parse_from_mem_tokens(tokens) if miv is not None: break - + return miv, mem_info_type - + def add_meminfo_var_from_tokens(self, tokens): """ Parses a MemInfo variable from a list of tokens and adds it to the appropriate list. - + Args: tokens (list[str]): List of tokens to parse. Raises: RuntimeError: If the line could not be parsed. """ - miv: MemInfoVariable = None + miv: MemInfoVariable = None miv, mem_info_type = MemInfo.get_meminfo_var_from_tokens(tokens) if miv is not None and mem_info_type is not None: self.factory_dict[mem_info_type].append(miv) else: raise RuntimeError(f"Could not parse line") - + @classmethod def from_file_iter(cls, line_iter): """ @@ -536,7 +556,7 @@ def from_file_iter(cls, line_iter): """ retval = cls() - + for line_no, s_line in enumerate(line_iter, 1): s_line = s_line.strip() if s_line: # skip empty lines @@ -573,10 +593,10 @@ def from_dinstrs(cls, dinstrs): print(f"Added {tokens} to MemInfo") except RuntimeError as e: raise RuntimeError(f"{e} {ints_no}: {tokens}") from e - + retval.validate() return retval - + @property def keygens(self) -> list: """ diff --git a/assembler_tools/hec-assembler-tools/debug_tools/main.py b/assembler_tools/hec-assembler-tools/debug_tools/main.py index 9a675cd1..31573508 100644 --- a/assembler_tools/hec-assembler-tools/debug_tools/main.py +++ b/assembler_tools/hec-assembler-tools/debug_tools/main.py @@ -82,7 +82,7 @@ def main_readmem(args): ) mem_meta_info = None - with open(mem_filename, 'r') as mem_ifnum: + with open(mem_filename, "r") as mem_ifnum: mem_meta_info = mem_info.MemInfo.from_file_iter(mem_ifnum) if mem_meta_info: @@ -216,7 +216,7 @@ def asmisa_assembly( if b_verbose: print("Interpreting variable meta information...") - with open(mem_filename, 'r') as mem_ifnum: + with open(mem_filename, "r") as mem_ifnum: mem_meta_info = mem_info.MemInfo.from_file_iter(mem_ifnum) mem_info.updateMemoryModelWithMemInfo(hec_mem_model, mem_meta_info) @@ -289,7 +289,7 @@ def main_asmisa(args): b_use_old_mem_file = False b_verbose = True if args.verbose > 0 else False GlobalConfig.debugVerbose = 0 - GlobalConfig.suppressComments = False + GlobalConfig.suppress_comments = False GlobalConfig.useHBMPlaceHolders = True GlobalConfig.useXInstFetch = False diff --git a/assembler_tools/hec-assembler-tools/he_as.py b/assembler_tools/hec-assembler-tools/he_as.py index f92d63c9..0b666c39 100644 --- a/assembler_tools/hec-assembler-tools/he_as.py +++ b/assembler_tools/hec-assembler-tools/he_as.py @@ -146,7 +146,7 @@ def init_default_config(cls): cls.__default_config["spad_size"] = cls.DEFAULT_SPAD_SIZE_KB cls.__default_config["repl_policy"] = cls.DEFAULT_REPL_POLICY cls.__default_config["use_xinstfetch"] = GlobalConfig.useXInstFetch - cls.__default_config["suppress_comments"] = GlobalConfig.suppressComments + cls.__default_config["suppress_comments"] = GlobalConfig.suppress_comments cls.__default_config["debug_verbose"] = GlobalConfig.debugVerbose cls.__initialized = True @@ -250,7 +250,7 @@ def asmisaAssemble( if b_verbose: print("Interpreting variable meta information...") - with open(mem_filename, 'r') as mem_ifnum: + with open(mem_filename, "r") as mem_ifnum: mem_meta_info = mem_info.MemInfo.from_file_iter(mem_ifnum) mem_info.updateMemoryModelWithMemInfo(hec_mem_model, mem_meta_info) @@ -358,7 +358,7 @@ def main(config: AssemblerRunConfig, verbose: bool = False): GlobalConfig.useHBMPlaceHolders = True # config.use_hbm_placeholders GlobalConfig.useXInstFetch = config.use_xinstfetch - GlobalConfig.suppressComments = config.suppress_comments + GlobalConfig.suppress_comments = config.suppress_comments GlobalConfig.hasHBM = config.has_hbm GlobalConfig.debugVerbose = config.debug_verbose diff --git a/assembler_tools/hec-assembler-tools/he_link.py b/assembler_tools/hec-assembler-tools/he_link.py index f29390cb..ba64a947 100644 --- a/assembler_tools/hec-assembler-tools/he_link.py +++ b/assembler_tools/hec-assembler-tools/he_link.py @@ -1,39 +1,35 @@ #! /usr/bin/env python3 -# encoding: utf-8 -""" -This module provides functionality for linking assembled kernels into a full HERACLES program for execution queues: MINST, CINST, and XINST. +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 -Classes: - LinkerRunConfig - Maintains the configuration data for the run. +# These contents may have been developed with support from one or more Intel-operated +# generative artificial intelligence solutions - KernelFiles - Structure for kernel files. +""" +@file he_link.py +@brief This module provides functionality for linking assembled kernels into a full HERACLES program for execution queues: MINST, CINST, and XINST. -Functions: - main(run_config: LinkerRunConfig, verbose_stream=None) - Executes the linking process using the provided configuration. +@par Classes: + - LinkerRunConfig: Maintains the configuration data for the run. + - KernelFiles: Structure for kernel files. - parse_args() -> argparse.Namespace - Parses command-line arguments for the linker script. +@par Functions: + - main(run_config: LinkerRunConfig, verbose_stream=None): Executes the linking process using the provided configuration. + - parse_args() -> argparse.Namespace: Parses command-line arguments for the linker script. -Usage: +@par Usage: This script is intended to be run as a standalone program. It requires specific command-line arguments to specify input and output files and configuration options for the linking process. - """ import argparse import io import os import pathlib import sys -import time import warnings +from typing import NamedTuple, Any, Optional import linker - -from typing import NamedTuple - from assembler.common import constants from assembler.common import makeUniquePath from assembler.common.counter import Counter @@ -47,306 +43,357 @@ from linker.steps import program_linker from linker.instructions import BaseInstruction + class LinkerRunConfig(RunConfig): """ - Maintains the configuration data for the run. + @class LinkerRunConfig + @brief Maintains the configuration data for the run. + + @fn as_dict + @brief Returns the configuration as a dictionary. - Methods: - as_dict() -> dict - Returns the configuration as a dictionary. + @return dict The configuration as a dictionary. """ - __initialized = False # specifies whether static members have been initialized + __initialized = False # specifies whether static members have been initialized # contains the dictionary of all configuration items supported and their # default value (or None if no default) - __default_config = {} + __default_config: dict[str, Any] = {} def __init__(self, **kwargs): """ - Constructs a new LinkerRunConfig Object from input parameters. + @brief Constructs a new LinkerRunConfig Object from input parameters. See base class constructor for more parameters. - Args: - input_prefixes (list[str]): - List of input prefixes, including full path. For an input prefix, linker will - assume there are three files named `input_prefixes[i] + '.minst'`, - `input_prefixes[i] + '.cinst'`, and `input_prefixes[i] + '.xinst'`. - This list must not be empty. - output_prefix (str): - Prefix for the output file names. - Three files will be generated: - `output_dir/output_prefix.minst`, `output_dir/output_prefix.cinst`, and - `output_dir/output_prefix.xinst`. - Output filenames cannot match input file names. - input_mem_file (str): - Input memory file associated with the result kernel. - output_dir (str): current working directory - OPTIONAL directory where to store all intermediate files and final output. - This will be created if it doesn't exists. - Defaults to current working directory. - - Raises: - TypeError: - A mandatory configuration value was missing. - ValueError: - At least, one of the arguments passed is invalid. + @param input_prefixes List of input prefixes, including full path. For an input prefix, linker will + assume there are three files named `input_prefixes[i] + '.minst'`, + `input_prefixes[i] + '.cinst'`, and `input_prefixes[i] + '.xinst'`. + This list must not be empty. + @param output_prefix Prefix for the output file names. + Three files will be generated: + `output_dir/output_prefix.minst`, `output_dir/output_prefix.cinst`, and + `output_dir/output_prefix.xinst`. + Output filenames cannot match input file names. + @param input_mem_file Input memory file associated with the result kernel. + @param output_dir OPTIONAL directory where to store all intermediate files and final output. + This will be created if it doesn't exists. + Defaults to current working directory. + + @exception TypeError A mandatory configuration value was missing. + @exception ValueError At least, one of the arguments passed is invalid. """ + super().__init__(**kwargs) self.init_default_config() - + # class members based on configuration for config_name, default_value in self.__default_config.items(): - value = kwargs.get(config_name) + print(f"ROCHA DEBUG: {config_name} = {default_value}") + value = kwargs.get(config_name, default_value) if value is not None: - assert(not hasattr(self, config_name)) setattr(self, config_name, value) + print(f"ROCHA ADDED: {config_name} = {value}") else: if not hasattr(self, config_name): setattr(self, config_name, default_value) + print(f"ROCHA ADDED: {config_name} = {default_value}") if getattr(self, config_name) is None: - raise TypeError(f'Expected value for configuration `{config_name}`, but `None` received.') + raise TypeError( + f"Expected value for configuration `{config_name}`, but `None` received." + ) # fix file names self.output_dir = makeUniquePath(self.output_dir) - if self.input_mem_file is not None: + # E0203: Access to member 'input_mem_file' before its definition. + # But it was defined in previous loop. + if self.input_mem_file != "": # pylint: disable=E0203 self.input_mem_file = makeUniquePath(self.input_mem_file) @classmethod def init_default_config(cls): """ - Initializes static members of the class. + @brief Initializes static members of the class. """ if not cls.__initialized: - cls.__default_config["input_prefixes"] = None - cls.__default_config["input_mem_file"] = None - cls.__default_config["find_mem_files"] = False - cls.__default_config["output_dir"] = os.getcwd() - cls.__default_config["output_prefix"] = None - cls.__default_config["has_hbm"] = True - cls.__default_config["hbm_size"] = cls.DEFAULT_HBM_SIZE_KB - cls.__default_config["use_xinstfetch"] = GlobalConfig.useXInstFetch - cls.__default_config["suppress_comments"] = GlobalConfig.suppressComments + cls.__default_config["input_prefixes"] = None + cls.__default_config["input_mem_file"] = "" + cls.__default_config["find_mem_files"] = False + cls.__default_config["output_dir"] = os.getcwd() + cls.__default_config["output_prefix"] = None cls.__initialized = True def __str__(self): """ - Provides a string representation of the configuration. - - Returns: - str: The string for the configuration. + @brief Provides a string representation of the configuration. + + @return str The string for the configuration. """ self_dict = self.as_dict() with io.StringIO() as retval_f: for key, value in self_dict.items(): - print("{}: {}".format(key, value), file=retval_f) + print(f"{key}: {value}", file=retval_f) retval = retval_f.getvalue() return retval def as_dict(self) -> dict: """ - Provides the configuration as a dictionary. + @brief Provides the configuration as a dictionary. - Returns: - dict: The configuration. + @return dict The configuration. """ retval = super().as_dict() tmp_self_dict = vars(self) - retval.update({ config_name: tmp_self_dict[config_name] for config_name in self.__default_config }) + retval.update( + { + config_name: tmp_self_dict[config_name] + for config_name in self.__default_config + } + ) return retval + class KernelFiles(NamedTuple): """ - Structure for kernel files. - - Attributes: - prefix (str): - Index = 0 - minst (str): - Index = 1. Name for file containing MInstructions for represented kernel. - cinst (str): - Index = 2. Name for file containing CInstructions for represented kernel. - xinst (str): - Index = 3. Name for file containing XInstructions for represented kernel. - mem (str, optional): - Index = 4. Name for file containing memory information for represented kernel. - This is used only when find_mem_files is set. + @class KernelFiles + @brief Structure for kernel files. + + @var prefix + Index = 0 + @var minst + Index = 1. Name for file containing MInstructions for represented kernel. + @var cinst + Index = 2. Name for file containing CInstructions for represented kernel. + @var xinst + Index = 3. Name for file containing XInstructions for represented kernel. + @var mem + Index = 4. Name for file containing memory information for represented kernel. + This is used only when find_mem_files is set. """ + prefix: str minst: str cinst: str xinst: str - mem: str = None + mem: Optional[str] = None -def main(run_config: LinkerRunConfig, verbose_stream = None): - """ - Executes the linking process using the provided configuration. - - This function prepares input and output file names, initializes the memory model, discovers variables, - and links each kernel, writing the output to specified files. - Args: - run_config (LinkerRunConfig): The configuration object containing run parameters. - verbose_stream: The stream to which verbose output is printed. Defaults to None. +def link_kernels(input_files, output_files, mem_model, verbose_stream): + """ + @brief Links input kernels and writes the output to the specified files. - Returns: - None + @param input_files List of KernelFiles for input kernels. + @param output_files KernelFiles for output. + @param mem_model Memory model to use. + @param run_config LinkerRunConfig object. + @param verbose_stream Stream for verbose output. """ - if verbose_stream: - print("Linking...", file=verbose_stream) + with open(output_files.minst, "w", encoding="utf-8") as fnum_output_minst, open( + output_files.cinst, "w", encoding="utf-8" + ) as fnum_output_cinst, open( + output_files.xinst, "w", encoding="utf-8" + ) as fnum_output_xinst: + + result_program = program_linker.LinkedProgram( + fnum_output_minst, fnum_output_cinst, fnum_output_xinst, mem_model + ) + for idx, kernel in enumerate(input_files): + if verbose_stream: + print( + f"[ {idx * 100 // len(input_files): >3}% ]", + kernel.prefix, + file=verbose_stream, + ) + kernel_minstrs = loader.load_minst_kernel_from_file(kernel.minst) + kernel_cinstrs = loader.load_cinst_kernel_from_file(kernel.cinst) + kernel_xinstrs = loader.load_xinst_kernel_from_file(kernel.xinst) + result_program.link_kernel(kernel_minstrs, kernel_cinstrs, kernel_xinstrs) + if verbose_stream: + print( + "[ 100% ] Finalizing output", output_files.prefix, file=verbose_stream + ) + result_program.close() - if run_config.use_xinstfetch: - warnings.warn(f'Ignoring configuration flag "use_xinstfetch".') - # Update global config - GlobalConfig.hasHBM = run_config.has_hbm - GlobalConfig.suppressComments = run_config.suppress_comments - - mem_filename: str = run_config.input_mem_file - hbm_capcity_words: int = constants.convertBytes2Words(run_config.hbm_size * constants.Constants.KILOBYTE) - input_files = [] # list(KernelFiles) - output_files: KernelFiles = None +def prepare_output_files(run_config) -> KernelFiles: + """ + @brief Prepares output file names and directories. - # prepare output file names + @param run_config LinkerRunConfig object. + @return KernelFiles Output file paths. + """ output_prefix = os.path.join(run_config.output_dir, run_config.output_prefix) output_dir = os.path.dirname(output_prefix) + pathlib.Path(output_dir).mkdir(exist_ok=True, parents=True) + out_mem_file = ( + makeUniquePath(output_prefix + ".mem") if run_config.find_mem_files else None + ) + return KernelFiles( + prefix=makeUniquePath(output_prefix), + minst=makeUniquePath(output_prefix + ".minst"), + cinst=makeUniquePath(output_prefix + ".cinst"), + xinst=makeUniquePath(output_prefix + ".xinst"), + mem=out_mem_file, + ) + + +def prepare_input_files(run_config, output_files) -> list: + """ + @brief Prepares input file names and checks for existence and conflicts. - # If find_mem_files is enabled, create a new memory file for the output - out_mem_file = None - if run_config.find_mem_files: - out_mem_file = makeUniquePath(output_prefix + '.mem') - - pathlib.Path(output_dir).mkdir(exist_ok = True, parents=True) - output_files = KernelFiles(prefix=makeUniquePath(output_prefix), - minst=makeUniquePath(output_prefix + '.minst'), - cinst=makeUniquePath(output_prefix + '.cinst'), - xinst=makeUniquePath(output_prefix + '.xinst'), - mem=out_mem_file - ) - - # prepare input file names + @param run_config LinkerRunConfig object. + @param output_files KernelFiles for output. + @return list List of KernelFiles for input. + @exception FileNotFoundError If an input file does not exist. + @exception RuntimeError If an input file matches an output file. + """ + input_files = [] for file_prefix in run_config.input_prefixes: - - # If find_mem_files is enabled, try to find a .tw.mem file for each prefix - mem_file = None - if run_config.find_mem_files: - mem_file = makeUniquePath(file_prefix + '.mem') - - input_files.append(KernelFiles(prefix=makeUniquePath(file_prefix), - minst=makeUniquePath(file_prefix + '.minst'), - cinst=makeUniquePath(file_prefix + '.cinst'), - xinst=makeUniquePath(file_prefix + '.xinst'), - mem=mem_file)) - - for input_filename in input_files[-1][1:]: + mem_file = ( + makeUniquePath(file_prefix + ".mem") if run_config.find_mem_files else None + ) + kernel_files = KernelFiles( + prefix=makeUniquePath(file_prefix), + minst=makeUniquePath(file_prefix + ".minst"), + cinst=makeUniquePath(file_prefix + ".cinst"), + xinst=makeUniquePath(file_prefix + ".xinst"), + mem=mem_file, + ) + input_files.append(kernel_files) + for input_filename in kernel_files[1:]: if input_filename: if not os.path.isfile(input_filename): raise FileNotFoundError(input_filename) if input_filename in output_files: - raise RuntimeError(f'Input files cannot match output files: "{input_filename}"') + raise RuntimeError( + f'Input files cannot match output files: "{input_filename}"' + ) + return input_files + + +def scan_variables(input_files, mem_model, verbose_stream): + """ + @brief Scans input files for variables and adds them to the memory model. + + @param input_files List of KernelFiles for input. + @param mem_model Memory model to update. + @param verbose_stream Stream for verbose output. + """ + for idx, kernel in enumerate(input_files): + if not GlobalConfig.hasHBM: + if verbose_stream: + print( + f" {idx + 1}/{len(input_files)}", + kernel.cinst, + file=verbose_stream, + ) + kernel_cinstrs = loader.load_cinst_kernel_from_file(kernel.cinst) + for var_name in variable_discovery.discoverVariablesSPAD(kernel_cinstrs): + mem_model.addVariable(var_name) + else: + if verbose_stream: + print( + f" {idx + 1}/{len(input_files)}", + kernel.minst, + file=verbose_stream, + ) + kernel_minstrs = loader.load_minst_kernel_from_file(kernel.minst) + for var_name in variable_discovery.discoverVariables(kernel_minstrs): + mem_model.addVariable(var_name) + + +def check_unused_variables(mem_model): + """ + @brief Checks for unused variables in the memory model and raises an error if found. + + @param mem_model Memory model to check. + @exception RuntimeError If an unused variable is found. + """ + for var_name in mem_model.mem_info_vars: + if var_name not in mem_model.variables: + if GlobalConfig.hasHBM or var_name not in mem_model.mem_info_meta: + raise RuntimeError( + f'Unused variable from input mem file: "{var_name}" not in memory model.' + ) + + +def main(run_config: LinkerRunConfig, verbose_stream=None): + """ + @brief Executes the linking process using the provided configuration. + + This function prepares input and output file names, initializes the memory model, discovers variables, + and links each kernel, writing the output to specified files. + + @param run_config The configuration object containing run parameters. + @param verbose_stream The stream to which verbose output is printed. Defaults to None. + + @return None + """ + if run_config.use_xinstfetch: + warnings.warn("Ignoring configuration flag 'use_xinstfetch'.") + + # Update global config + GlobalConfig.hasHBM = run_config.has_hbm + GlobalConfig.suppress_comments = run_config.suppress_comments - # reset counters + mem_filename: str = run_config.input_mem_file + hbm_capacity_words: int = constants.convertBytes2Words( + run_config.hbm_size * constants.Constants.KILOBYTE + ) + + # Prepare input and output files + output_files: KernelFiles = prepare_output_files(run_config) + input_files: list[KernelFiles] = prepare_input_files(run_config, output_files) + + # Reset counters Counter.reset() # parse mem file if verbose_stream: + print("Linking...", file=verbose_stream) print("", file=verbose_stream) print("Interpreting variable meta information...", file=verbose_stream) if run_config.find_mem_files: - if verbose_stream: - print("Linking together multiple memory files", file=verbose_stream) - kernels_dinstrs = [] for kernel in input_files: kernel_dinstrs = loader.load_dinst_kernel_from_file(kernel.mem) kernels_dinstrs.append(kernel_dinstrs) # Concatenate all mem info objects into one - kernel_dinstrs = program_linker.LinkedProgram.join_dinst_kernels(kernels_dinstrs) + kernel_dinstrs = program_linker.LinkedProgram.join_dinst_kernels( + kernels_dinstrs + ) mem_meta_info = mem_info.MemInfo.from_dinstrs(kernel_dinstrs) else: - with open(mem_filename, 'r') as mem_ifnum: + with open(mem_filename, "r", encoding="utf-8") as mem_ifnum: mem_meta_info = mem_info.MemInfo.from_file_iter(mem_ifnum) # Initialize memory model if verbose_stream: print("Initializing linker memory model", file=verbose_stream) - mem_model = linker.MemoryModel(hbm_capcity_words, mem_meta_info) + mem_model = linker.MemoryModel(hbm_capacity_words, mem_meta_info) if verbose_stream: print(f" HBM capacity: {mem_model.hbm.capacity} words", file=verbose_stream) - # Find all variables and usage across all the input kernels - - if verbose_stream: print(" Finding all program variables...", file=verbose_stream) print(" Scanning", file=verbose_stream) - for idx, kernel in enumerate(input_files): - if not GlobalConfig.hasHBM: - if verbose_stream: - print(f" {idx + 1}/{len(input_files)}", kernel.cinst, - file=verbose_stream) - # load next CInst kernel and scan for variables used in SPAD - kernel_cinstrs = loader.load_cinst_kernel_from_file(kernel.cinst) - for var_name in variable_discovery.discoverVariablesSPAD(kernel_cinstrs): - mem_model.addVariable(var_name) - else: - if verbose_stream: - print(f" {idx + 1}/{len(input_files)}", kernel.minst, - file=verbose_stream) - # Load next MInst kernel and scan for variables used - kernel_minstrs = loader.load_minst_kernel_from_file(kernel.minst) - for var_name in variable_discovery.discoverVariables(kernel_minstrs): - mem_model.addVariable(var_name) - - # check that all non-keygen variables from MemInfo are used - for var_name in mem_model.mem_info_vars: - if var_name not in mem_model.variables: - if GlobalConfig.hasHBM or var_name not in mem_model.mem_info_meta: # skip checking meta vars when no HBM - raise RuntimeError(f'Unused variable from input mem file: "{var_name}" not in memory model.') + scan_variables(input_files, mem_model, verbose_stream) + check_unused_variables(mem_model) if verbose_stream: print(f" Variables found: {len(mem_model.variables)}", file=verbose_stream) - - if verbose_stream: print("Linking started", file=verbose_stream) - # open the output files - with open(output_files.minst, 'w') as fnum_output_minst, \ - open(output_files.cinst, 'w') as fnum_output_cinst, \ - open(output_files.xinst, 'w') as fnum_output_xinst: - - # prepare the linker class - result_program = program_linker.LinkedProgram(fnum_output_minst, - fnum_output_cinst, - fnum_output_xinst, - mem_model, - supress_comments=run_config.suppress_comments) - # start linking each kernel - for idx, kernel in enumerate(input_files): - if verbose_stream: - print("[ {: >3}% ]".format(idx * 100 // len(input_files)), kernel.prefix, - file=verbose_stream) - kernel_minstrs = loader.load_minst_kernel_from_file(kernel.minst) - kernel_cinstrs = loader.load_cinst_kernel_from_file(kernel.cinst) - kernel_xinstrs = loader.load_xinst_kernel_from_file(kernel.xinst) - - result_program.linkKernel(kernel_minstrs, kernel_cinstrs, kernel_xinstrs) - - if verbose_stream: - print("[ 100% ] Finalizing output", output_files.prefix, file=verbose_stream) - - # signal that we have linked all kernels - result_program.close() + link_kernels(input_files, output_files, mem_model, verbose_stream) + # Write the memory model to the output file if run_config.find_mem_files: - # Write the memory model to the output file - if verbose_stream: - print("Writing memory model to", output_files.mem, file=verbose_stream) + BaseInstruction.dump_instructions_to_file(kernel_dinstrs, output_files.mem) if verbose_stream: @@ -357,75 +404,143 @@ def main(run_config: LinkerRunConfig, verbose_stream = None): if run_config.find_mem_files: print(" ", output_files.mem, file=verbose_stream) + def parse_args(): """ - Parses command-line arguments for the linker script. + @brief Parses command-line arguments for the linker script. This function sets up the argument parser and defines the expected arguments for the script. It returns a Namespace object containing the parsed arguments. - Returns: - argparse.Namespace: Parsed command-line arguments. + @return argparse.Namespace Parsed command-line arguments. """ parser = argparse.ArgumentParser( - description=("HERACLES Linker.\n" - "Links assembled kernels into a full HERACLES program " - "for each of the three execution queues: MINST, CINST, and XINST.\n\n" - "To link several kernels, specify each kernel's input prefix in order. " - "Variables that should carry on across kernels should be have the same name. " - "Linker will recognize matching variables and keep their values between kernels. " - "Variables that are inputs and outputs (and metadata) for the whole program must " - "be indicated in the input memory mapping file.")) - parser.add_argument("input_prefixes", nargs="+", - help=("List of input prefixes, including full path. For an input prefix, linker will " - "assume three files exist named `input_prefixes[i] + '.minst'`, " - "`input_prefixes[i] + '.cinst'`, and `input_prefixes[i] + '.xinst'`.")) - parser.add_argument("--mem_spec", default="", dest="mem_spec_file", - help=("Input Mem specification (.json) file.")) - parser.add_argument("--isa_spec", default="", dest="isa_spec_file", - help=("Input ISA specification (.json) file.")) - parser.add_argument("--find_mem_files", action="store_true", dest="find_mem_files", - help=("Tells the linker to find a memory file (*.tw.mem) for each input prefix given." - "This can be used to link multiple kernels together. " - "If this flag is not set, the linker will use the input_mem_file argument instead")) - parser.add_argument("-im", "--input_mem_file", dest="input_mem_file", required=False, - help=("Input memory mapping file associated with the resulting program. " - "Specifies the names for input, output, and metadata variables for a single kernel" - " or also a full program if instead this is used to link multiple kernels together.")) - parser.add_argument("-o", "--output_prefix", dest="output_prefix", required=True, - help=("Prefix for the output file names. " - "Three files will be generated: \n" - "`output_dir/output_prefix.minst`, `output_dir/output_prefix.cinst`, and " - "`output_dir/output_prefix.xinst`. \n" - "Output filenames cannot match input file names.")) - parser.add_argument("-od", "--output_dir", dest="output_dir", default="", - help=("Directory where to store all intermediate files and final output. " - "This will be created if it doesn't exists. " - "Defaults to current working directory.")) + description=( + "HERACLES Linker.\n" + "Links assembled kernels into a full HERACLES program " + "for each of the three execution queues: MINST, CINST, and XINST.\n\n" + "To link several kernels, specify each kernel's input prefix in order. " + "Variables that should carry on across kernels should be have the same name. " + "Linker will recognize matching variables and keep their values between kernels. " + "Variables that are inputs and outputs (and metadata) for the whole program must " + "be indicated in the input memory mapping file." + ) + ) + parser.add_argument( + "input_prefixes", + nargs="+", + help=( + "List of input prefixes, including full path. For an input prefix, linker will " + "assume three files exist named `input_prefixes[i] + '.minst'`, " + "`input_prefixes[i] + '.cinst'`, and `input_prefixes[i] + '.xinst'`." + ), + ) + parser.add_argument( + "--mem_spec", + default="", + dest="mem_spec_file", + help=("Input Mem specification (.json) file."), + ) + parser.add_argument( + "--isa_spec", + default="", + dest="isa_spec_file", + help=("Input ISA specification (.json) file."), + ) + parser.add_argument( + "--find_mem_files", + action="store_true", + dest="find_mem_files", + help=( + "Tells the linker to find a memory file (*.tw.mem) for each input prefix given." + "This can be used to link multiple kernels together. " + "If this flag is not set, the linker will use the input_mem_file argument instead" + ), + ) + parser.add_argument( + "-im", + "--input_mem_file", + dest="input_mem_file", + required=False, + default="", + help=( + "Input memory mapping file associated with the resulting program. " + "Specifies the names for input, output, and metadata variables for a single kernel" + " or also a full program if instead this is used to link multiple kernels together." + ), + ) + parser.add_argument( + "-o", + "--output_prefix", + dest="output_prefix", + required=True, + help=( + "Prefix for the output file names. " + "Three files will be generated: \n" + "`output_dir/output_prefix.minst`, `output_dir/output_prefix.cinst`, and " + "`output_dir/output_prefix.xinst`. \n" + "Output filenames cannot match input file names." + ), + ) + parser.add_argument( + "-od", + "--output_dir", + dest="output_dir", + default="", + help=( + "Directory where to store all intermediate files and final output. " + "This will be created if it doesn't exists. " + "Defaults to current working directory." + ), + ) parser.add_argument("--hbm_size", type=int, help="HBM size in KB.") - parser.add_argument("--no_hbm", dest="has_hbm", action="store_false", - help="If set, this flag tells he_prep there is no HBM in the target chip.") - parser.add_argument("--suppress_comments", "--no_comments", dest="suppress_comments", action="store_true", - help=("When enabled, no comments will be emited on the output generated.")) - parser.add_argument("-v", "--verbose", dest="verbose", action="count", default=0, - help=("If enabled, extra information and progress reports are printed to stdout. " - "Increase level of verbosity by specifying flag multiple times, e.g. -vv")) - args = parser.parse_args() + parser.add_argument( + "--no_hbm", + dest="has_hbm", + action="store_false", + help="If set, this flag tells he_prep there is no HBM in the target chip.", + ) + parser.add_argument( + "--suppress_comments", + "--no_comments", + dest="suppress_comments", + action="store_true", + help=("When enabled, no comments will be emitted on the output generated."), + ) + parser.add_argument( + "-v", + "--verbose", + dest="verbose", + action="count", + default=0, + help=( + "If enabled, extra information and progress reports are printed to stdout. " + "Increase level of verbosity by specifying flag multiple times, e.g. -vv" + ), + ) + p_args = parser.parse_args() # Enforce input_mem_file only if find_mem_files is not set - if not args.find_mem_files and not args.input_mem_file: - parser.error("the following arguments are required: -im/--input_mem_file (unless --find_mem_files is set)") + if not p_args.find_mem_files and p_args.input_mem_file == "": + parser.error( + "the following arguments are required: -im/--input_mem_file (unless --find_mem_files is set)" + ) + + return p_args - return args if __name__ == "__main__": module_dir = os.path.dirname(__file__) module_name = os.path.basename(__file__) args = parse_args() - args.mem_spec_file = MemSpecConfig.initialize_mem_spec(module_dir, args.mem_spec_file) - args.isa_spec_file = ISASpecConfig.initialize_isa_spec(module_dir, args.isa_spec_file) - config = LinkerRunConfig(**vars(args)) # convert argsparser into a dictionary + args.mem_spec_file = MemSpecConfig.initialize_mem_spec( + module_dir, args.mem_spec_file + ) + args.isa_spec_file = ISASpecConfig.initialize_isa_spec( + module_dir, args.isa_spec_file + ) + config = LinkerRunConfig(**vars(args)) # convert argsparser into a dictionary if args.verbose > 0: print(module_name) diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/dinst/__init__.py b/assembler_tools/hec-assembler-tools/linker/instructions/dinst/__init__.py index 92f5f9e3..3d21769b 100644 --- a/assembler_tools/hec-assembler-tools/linker/instructions/dinst/__init__.py +++ b/assembler_tools/hec-assembler-tools/linker/instructions/dinst/__init__.py @@ -1,12 +1,18 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""This module provides functionality to create and manage data instructions""" + +from assembler.instructions import tokenize_from_line +from assembler.memory_model.mem_info import MemInfo from . import dload, dstore, dkeygen from . import dinstruction -from assembler.instructions import tokenizeFromLine -from assembler.memory_model.mem_info import MemInfo, MemInfoVariable DLoad = dload.Instruction DStore = dstore.Instruction DKeyGen = dkeygen.Instruction + def factory() -> set: """ Creates a set of all DInstruction classes. @@ -14,7 +20,8 @@ def factory() -> set: Returns: set: A set containing all DInstruction classes. """ - return { DLoad, DStore, DKeyGen } + return {DLoad, DStore, DKeyGen} + def create_from_mem_line(line: str) -> dinstruction.DInstruction: """ @@ -27,8 +34,8 @@ def create_from_mem_line(line: str) -> dinstruction.DInstruction: DInstruction or None: The parsed DInstruction object, or None if no object could be parsed from the specified input line. """ - retval = None - tokens, comment = tokenizeFromLine(line) + retval: dinstruction.DInstruction = None + tokens, comment = tokenize_from_line(line) for instr_type in factory(): try: retval = instr_type(tokens, comment) @@ -46,4 +53,4 @@ def create_from_mem_line(line: str) -> dinstruction.DInstruction: retval.var = miv_dict["var_name"] retval.address = miv_dict["hbm_address"] - return retval \ No newline at end of file + return retval diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/instruction.py b/assembler_tools/hec-assembler-tools/linker/instructions/instruction.py index bb51287a..768aa35a 100644 --- a/assembler_tools/hec-assembler-tools/linker/instructions/instruction.py +++ b/assembler_tools/hec-assembler-tools/linker/instructions/instruction.py @@ -1,7 +1,15 @@ -from assembler.common.decorators import * +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +""" +Base class for all instructions in the linker. +""" + +from assembler.common.decorators import classproperty from assembler.common.counter import Counter from assembler.common.config import GlobalConfig + class BaseInstruction: """ Base class for all instructions. @@ -23,20 +31,22 @@ class BaseInstruction: Retrieves the string form of the instruction to write to the instruction file. """ - __id_count = Counter.count(0) # Internal unique sequence counter to generate unique IDs + __id_count = Counter.count( + 0 + ) # Internal unique sequence counter to generate unique IDs # Class methods and properties # ---------------------------- @classproperty - def name(cls) -> str: + def name(self) -> str: """ Name for the instruction. Returns: str: The name of the instruction. """ - return cls._get_name() + return self._get_name() @classmethod def _get_name(cls) -> str: @@ -50,7 +60,7 @@ def _get_name(cls) -> str: raise NotImplementedError() @classproperty - def name_token_index(cls) -> int: + def name_token_index(self) -> int: """ Index for the token containing the name of the instruction in the list of tokens. @@ -58,7 +68,7 @@ def name_token_index(cls) -> int: Returns: int: The index of the name token. """ - return cls._get_name_token_index() + return self._get_name_token_index() @classmethod def _get_name_token_index(cls) -> int: @@ -73,14 +83,14 @@ def _get_name_token_index(cls) -> int: raise NotImplementedError() @classproperty - def num_tokens(cls) -> int: + def num_tokens(self) -> int: """ Number of tokens required for this instruction. Returns: int: The number of tokens required. """ - return cls._get_num_tokens() + return self._get_num_tokens() @classmethod def _get_num_tokens(cls) -> int: @@ -104,9 +114,9 @@ def dump_instructions_to_file(cls, instructions: list, filename: str): instructions (list): List of instruction objects (must have a to_line() method). filename (str): Path to the output file. """ - with open(filename, 'w') as f: + with open(filename, "w", encoding="utf-8") as f: for instr in instructions: - f.write(instr.to_line() + '\n') + f.write(instr.to_line() + "\n") # Constructor # ----------- @@ -124,12 +134,18 @@ def __init__(self, tokens: list, comment: str = ""): """ assert self.name_token_index < self.num_tokens - if len(tokens) != self.num_tokens: - raise ValueError((f"`tokens`: invalid amount of tokens. " - f"Instruction {self.name} requires less " - f"than {self.num_tokens}, but {len(tokens)} received")) - if tokens[self.name_token_index] != self.name: - raise ValueError(f"`tokens`: invalid name. Expected {self.name}, but {tokens[self.name_token_index]} received") + if len(tokens) != self.num_tokens: # pylint: disable=W0143 + raise ValueError( + ( + f"`tokens`: invalid amount of tokens. " + f"Instruction {self.name} requires less " + f"than {self.num_tokens}, but {len(tokens)} received" + ) + ) + if tokens[self.name_token_index] != self.name: # pylint: disable=W0143 + raise ValueError( + f"`tokens`: invalid name. Expected {self.name}, but {tokens[self.name_token_index]} received" + ) self._id = next(BaseInstruction.__id_count) @@ -137,7 +153,7 @@ def __init__(self, tokens: list, comment: str = ""): self.comment = comment def __repr__(self): - retval = (f"<{type(self).__name__}({self.name}, id={self.id}) object at {hex(id(self))}>(tokens={self.tokens})") + retval = f"<{type(self).__name__}({self.name}, id={self.id}) object at {hex(id(self))}>(tokens={self.tokens})" return retval def __eq__(self, other): @@ -148,7 +164,7 @@ def __hash__(self): return hash(self.id) def __str__(self): - return f'{self.name}({self.id})' + return f"{self.name}({self.id})" # Methods and properties # ---------------------------- @@ -182,8 +198,9 @@ def to_line(self) -> str: Returns: str: The string representation of the instruction. """ - if not GlobalConfig.suppressComments: + comment_str = "" + if not GlobalConfig.suppress_comments: comment_str = f" # {self.comment}" if self.comment else "" tokens_str = ", ".join(self._tokens) - return f"{tokens_str}{comment_str}" \ No newline at end of file + return f"{tokens_str}{comment_str}" diff --git a/assembler_tools/hec-assembler-tools/linker/steps/program_linker.py b/assembler_tools/hec-assembler-tools/linker/steps/program_linker.py index e80cea99..9ed8603d 100644 --- a/assembler_tools/hec-assembler-tools/linker/steps/program_linker.py +++ b/assembler_tools/hec-assembler-tools/linker/steps/program_linker.py @@ -1,14 +1,22 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +# These contents may have been developed with support from one or more Intel-operated +# generative artificial intelligence solutions + +"""This module provides functionality to link kernels into a program.""" from linker import MemoryModel from linker.instructions import minst, cinst, dinst from linker.instructions.dinst.dinstruction import DInstruction from assembler.common.config import GlobalConfig from assembler.instructions import cinst as ISACInst -from assembler.memory_model.mem_info import MemInfo -class LinkedProgram: + +class LinkedProgram: # pylint: disable=too-many-instance-attributes """ - Encapsulates a linked program. + @class LinkedProgram + @brief Encapsulates a linked program. This class offers facilities to track and link kernels, and outputs the linked program to specified output streams as kernels @@ -17,130 +25,146 @@ class LinkedProgram: The program itself is not contained in this object. """ - def __init__(self, - program_minst_ostream, - program_cinst_ostream, - program_xinst_ostream, - mem_model: MemoryModel, - supress_comments: bool): + def __init__( + self, + program_minst_ostream, + program_cinst_ostream, + program_xinst_ostream, + mem_model: MemoryModel, + ): """ - Initializes a LinkedProgram object. - - Parameters: - program_minst_ostream: Output stream for MInst instructions. - program_cinst_ostream: Output stream for CInst instructions. - program_xinst_ostream: Output stream for XInst instructions. - mem_model (MemoryModel): Correctly initialized linker memory model. It must already contain the - variables used throughout the program and their usage. - This memory model will be modified by this object when linking kernels. - supress_comments (bool): Whether to suppress comments in the output. + @brief Initializes a LinkedProgram object. + + @param program_minst_ostream Output stream for MInst instructions. + @param program_cinst_ostream Output stream for CInst instructions. + @param program_xinst_ostream Output stream for XInst instructions. + @param mem_model (MemoryModel): Correctly initialized linker memory model. It must already contain the + variables used throughout the program and their usage. + This memory model will be modified by this object when linking kernels. + @param suppress_comments (bool): Whether to suppress comments in the output. """ - self.__minst_ostream = program_minst_ostream - self.__cinst_ostream = program_cinst_ostream - self.__xinst_ostream = program_xinst_ostream - self.__mem_model = mem_model - self.__supress_comments = supress_comments - self.__bundle_offset = 0 + self.__minst_ostream = program_minst_ostream + self.__cinst_ostream = program_cinst_ostream + self.__xinst_ostream = program_xinst_ostream + self.__mem_model = mem_model + self.__bundle_offset = 0 self.__minst_line_offset = 0 self.__cinst_line_offset = 0 - self.__xinst_line_offset = 0 - self.__kernel_count = 0 # Number of kernels linked into this program - self.__is_open = True # Tracks whether this program is still accepting kernels to link + self.__kernel_count = 0 # Number of kernels linked into this program + self.__is_open = ( + True # Tracks whether this program is still accepting kernels to link + ) @property - def isOpen(self) -> bool: + def is_open(self) -> bool: """ - Checks if the program is open for linking new kernels. + @brief Checks if the program is open for linking new kernels. - Returns: - bool: True if the program is open, False otherwise. + @return bool True if the program is open, False otherwise. """ return self.__is_open - @property - def supressComments(self) -> bool: - """ - Checks if comments are suppressed in the output. - - Returns: - bool: True if comments are suppressed, False otherwise. - """ - return self.__supress_comments - def close(self): """ - Completes the program by terminating the queues with the correct exit code. + @brief Completes the program by terminating the queues with the correct exit code. Program will not accept new kernels to link after this call. - Raises: - RuntimeError: If the program is already closed. + @exception RuntimeError If the program is already closed. """ - if not self.isOpen: - raise RuntimeError('Program is already closed.') + if not self.is_open: + raise RuntimeError("Program is already closed.") # Add closing `cexit` tokens = [str(self.__cinst_line_offset), cinst.CExit.name] cexit_cinstr = cinst.CExit(tokens) - print(f'{cexit_cinstr.tokens[0]}, {cexit_cinstr.to_line()}', file=self.__cinst_ostream) + print( + f"{cexit_cinstr.tokens[0]}, {cexit_cinstr.to_line()}", + file=self.__cinst_ostream, + ) # Add closing msyncc - tokens = [str(self.__minst_line_offset), minst.MSyncc.name, str(self.__cinst_line_offset + 1)] + tokens = [ + str(self.__minst_line_offset), + minst.MSyncc.name, + str(self.__cinst_line_offset + 1), + ] cmsyncc_minstr = minst.MSyncc(tokens) - print(f'{cmsyncc_minstr.tokens[0]}, {cmsyncc_minstr.to_line()}', end="", file=self.__minst_ostream) - if not self.supressComments: - print(' # terminating MInstQ', end="", file=self.__minst_ostream) + print( + f"{cmsyncc_minstr.tokens[0]}, {cmsyncc_minstr.to_line()}", + end="", + file=self.__minst_ostream, + ) + if not GlobalConfig.suppress_comments: + print(" # terminating MInstQ", end="", file=self.__minst_ostream) print(file=self.__minst_ostream) # Program has been closed self.__is_open = False - def __validateHBMAddress(self, var_name: str, hbm_address: int): + def _validate_hbm_address(self, var_name: str, hbm_address: int): """ - Validates the HBM address for a variable. + @brief Validates the HBM address for a variable. - Parameters: - var_name (str): The name of the variable. - hbm_address (int): The HBM address to validate. + @param var_name The name of the variable. + @param hbm_address The HBM address to validate. - Raises: - RuntimeError: If the HBM address is invalid or does not match the declared address. + @exception RuntimeError If the HBM address is invalid or does not match the declared address. """ if hbm_address < 0: - raise RuntimeError(f'Invalid negative HBM address for variable "{var_name}".') + raise RuntimeError( + f'Invalid negative HBM address for variable "{var_name}".' + ) if var_name in self.__mem_model.mem_info_vars: if self.__mem_model.mem_info_vars[var_name].hbm_address != hbm_address: - raise RuntimeError(('Declared HBM address ({}) of mem Variable "{}"' - ' differs from allocated HBM address ({}).').format(self.__mem_model.mem_info_vars[var_name].hbm_address, - var_name, - hbm_address)) + raise RuntimeError( + ( + f"Declared HBM address " + f"({self.__mem_model.mem_info_vars[var_name].hbm_address})" + f" of mem Variable '{var_name}'" + f" differs from allocated HBM address ({hbm_address})." + ) + ) + + def _validate_spad_address(self, var_name: str, spad_address: int): + """ + @brief Validates the SPAD address for a variable (only available when no HBM). - def __validateSPADAddress(self, var_name: str, spad_address: int): + @param var_name The name of the variable. + @param spad_address The SPAD address to validate. + + @exception RuntimeError If the SPAD address is invalid or does not match the declared address. + """ # only available when no HBM assert not GlobalConfig.hasHBM # this method will validate the variable SPAD address against the - # original HBM address, since ther is no HBM + # original HBM address, since there is no HBM if spad_address < 0: - raise RuntimeError(f'Invalid negative SPAD address for variable "{var_name}".') + raise RuntimeError( + f'Invalid negative SPAD address for variable "{var_name}".' + ) if var_name in self.__mem_model.mem_info_vars: if self.__mem_model.mem_info_vars[var_name].hbm_address != spad_address: - raise RuntimeError(('Declared HBM address ({}) of mem Variable "{}"' - ' differs from allocated HBM address ({}).').format(self.__mem_model.mem_info_vars[var_name].hbm_address, - var_name, - spad_address)) - - def __updateMInsts(self, kernel_minstrs: list): + raise RuntimeError( + ( + f"Declared HBM address" + f" ({self.__mem_model.mem_info_vars[var_name].hbm_address})" + f" of mem Variable '{var_name}'" + f" differs from allocated HBM address ({spad_address})." + ) + ) + + def _update_minsts(self, kernel_minstrs: list): """ - Updates the MInsts in the kernel to offset to the current expected + @brief Updates the MInsts in the kernel to offset to the current expected synchronization points, and convert variable placeholders/names into the corresponding HBM address. All MInsts in the kernel are expected to synchronize with CInsts starting at line 0. - Does not change the `LinkedProgram` object. + Does not change the LinkedProgram object. - Parameters: - kernel_minstrs (list): List of MInstructions to update. + @param kernel_minstrs List of MInstructions to update. """ for minstr in kernel_minstrs: # Update msyncc @@ -149,110 +173,134 @@ def __updateMInsts(self, kernel_minstrs: list): # Change mload variable names into HBM addresses if isinstance(minstr, minst.MLoad): var_name = minstr.source - hbm_address = self.__mem_model.useVariable(var_name, self.__kernel_count) - self.__validateHBMAddress(var_name, hbm_address) + hbm_address = self.__mem_model.useVariable( + var_name, self.__kernel_count + ) + self._validate_hbm_address(var_name, hbm_address) minstr.source = str(hbm_address) - minstr.comment = f" var: {var_name} - HBM({hbm_address})" + f";{minstr.comment}" if minstr.comment else "" + minstr.comment = ( + f" var: {var_name} - HBM({hbm_address})" + f";{minstr.comment}" + if minstr.comment + else "" + ) # Change mstore variable names into HBM addresses if isinstance(minstr, minst.MStore): var_name = minstr.dest - hbm_address = self.__mem_model.useVariable(var_name, self.__kernel_count) - self.__validateHBMAddress(var_name, hbm_address) + hbm_address = self.__mem_model.useVariable( + var_name, self.__kernel_count + ) + self._validate_hbm_address(var_name, hbm_address) minstr.dest = str(hbm_address) - minstr.comment = f" var: {var_name} - HBM({hbm_address})" + f";{minstr.comment}" if minstr.comment else "" + minstr.comment = ( + f" var: {var_name} - HBM({hbm_address})" + f";{minstr.comment}" + if minstr.comment + else "" + ) - def __updateCInsts(self, kernel_cinstrs: list): + def _remove_and_merge_csyncm_cnop(self, kernel_cinstrs: list): """ - Updates the CInsts in the kernel to offset to the current expected bundle - and synchronization points. + @brief Remove csyncm instructions and merge consecutive cnop instructions. + + @param kernel_cinstrs List of CInstructions to process. + """ + i = 0 + current_bundle = 0 + csyncm_count = 0 + while i < len(kernel_cinstrs): + cinstr = kernel_cinstrs[i] + cinstr.tokens[0] = i # Update the line number + + # ------------------------------ + # This code block will remove csyncm instructions and keep track, + # later adding their throughput into a cnop instruction before + # a new bundle is fetched. + + if isinstance(cinstr, cinst.CNop): + # Add the missing cycles to any cnop we encounter up to this point + cinstr.cycles += csyncm_count * ISACInst.CSyncm.get_throughput() + # Idle cycles to account for the csyncm have been added + csyncm_count = 0 + + if isinstance(cinstr, (cinst.IFetch, cinst.NLoad, cinst.BLoad)): + if csyncm_count > 0: + # Extra cycles needed before scheduling next bundle + # Subtract 1 because cnop n, waits for n+1 cycles + cinstr_nop = cinst.CNop( + [ + i, + cinst.CNop.name, + str(csyncm_count * ISACInst.CSyncm.get_throughput() - 1), + ] + ) + kernel_cinstrs.insert(i, cinstr_nop) + csyncm_count = 0 + i += 1 + if isinstance(cinstr, cinst.IFetch): + current_bundle = cinstr.bundle + 1 + # Update the line number + cinstr.tokens[0] = i + + if isinstance(cinstr, cinst.CSyncm): + # Remove instruction + kernel_cinstrs.pop(i) + if current_bundle > 0: + csyncm_count += 1 + else: + i += 1 + + # ------------------------------ + # This code block differs from previous in that csyncm instructions + # are replaced in place by cnops with the corresponding throughput. + # This may result in several continuous cnop instructions, so, + # the cnop merging code afterwards is needed to remove this side effect + # if contiguous cnops are not desired. + + # if isinstance(cinstr, cinst.IFetch): + # current_bundle = cinstr.bundle + 1 + # + # if isinstance(cinstr, cinst.CSyncm): + # # replace instruction by cnop + # kernel_cinstrs.pop(i) + # if current_bundle > 0: + # cinstr_nop = cinst.CNop([i, cinst.CNop.name, str(ISACInst.CSyncm.get_throughput())]) # Subtract 1 because cnop n, waits for n+1 cycles + # kernel_cinstrs.insert(i, cinstr_nop) + # + # i += 1 # next instruction + + # Merge continuous cnop + i = 0 + while i < len(kernel_cinstrs): + cinstr = kernel_cinstrs[i] + cinstr.tokens[0] = i + if isinstance(cinstr, cinst.CNop): + # Do look ahead + if i + 1 < len(kernel_cinstrs): + if isinstance(kernel_cinstrs[i + 1], cinst.CNop): + # Add 1 because cnop n, waits for n+1 cycles + kernel_cinstrs[i + 1].cycles += cinstr.cycles + 1 + kernel_cinstrs.pop(i) + i -= 1 + i += 1 + + def _update_cinsts_addresses_and_offsets(self, kernel_cinstrs: list): + """ + @brief Updates bundle/target offsets and variable names to addresses for CInsts. All CInsts in the kernel are expected to start at bundle 0, and to synchronize with MInsts starting at line 0. - Does not change the `LinkedProgram` object. + Does not change the LinkedProgram object. - Parameters: - kernel_cinstrs (list): List of CInstructions to update. + @param kernel_cinstrs List of CInstructions to update. """ - - if not GlobalConfig.hasHBM: - # Remove csyncm instructions - i = 0 - current_bundle = 0 - csyncm_count = 0 # Used by 1st code block: plz remove if second code block ends up being the one used - while i < len(kernel_cinstrs): - cinstr = kernel_cinstrs[i] - cinstr.tokens[0] = i # Update the line number - - #------------------------------ - # This code block will remove csyncm instructions and keep track, - # later adding their throughput into a cnop instruction before - # a new bundle is fetched. - - if isinstance(cinstr, cinst.CNop): - # Add the missing cycles to any cnop we encounter up to this point - cinstr.cycles += (csyncm_count * ISACInst.CSyncm.get_throughput()) - csyncm_count = 0 # Idle cycles to account for the csyncm have been added - - if isinstance(cinstr, (cinst.IFetch, cinst.NLoad, cinst.BLoad)): - if csyncm_count > 0: - # Extra cycles needed before scheduling next bundle - cinstr_nop = cinst.CNop([i, cinst.CNop.name, str(csyncm_count * ISACInst.CSyncm.get_throughput() - 1)]) # Subtract 1 because cnop n, waits for n+1 cycles - kernel_cinstrs.insert(i, cinstr_nop) - csyncm_count = 0 # Idle cycles to account for the csyncm have been added - i += 1 - if isinstance(cinstr, cinst.IFetch): - current_bundle = cinstr.bundle + 1 - cinstr.tokens[0] = i # Update the line number - - if isinstance(cinstr, cinst.CSyncm): - # Remove instruction - kernel_cinstrs.pop(i) - if current_bundle > 0: - csyncm_count += 1 - else: - i += 1 # Next instruction - - #------------------------------ - # This code block differs from previous in that csyncm instructions - # are replaced in place by cnops with the corresponding throughput. - # This may result in several continuous cnop instructions, so, - # the cnop merging code afterwards is needed to remove this side effect - # if contiguous cnops are not desired. - - # if isinstance(cinstr, cinst.IFetch): - # current_bundle = cinstr.bundle + 1 - # - # if isinstance(cinstr, cinst.CSyncm): - # # replace instruction by cnop - # kernel_cinstrs.pop(i) - # if current_bundle > 0: - # cinstr_nop = cinst.CNop([i, cinst.CNop.name, str(ISACInst.CSyncm.get_throughput())]) # Subtract 1 because cnop n, waits for n+1 cycles - # kernel_cinstrs.insert(i, cinstr_nop) - # - # i += 1 # next instruction - - # Merge continuous cnop - i = 0 - while i < len(kernel_cinstrs): - cinstr = kernel_cinstrs[i] - cinstr.tokens[0] = i # Update the line number - - if isinstance(cinstr, cinst.CNop): - # Do look ahead - if i + 1 < len(kernel_cinstrs): - if isinstance(kernel_cinstrs[i + 1], cinst.CNop): - kernel_cinstrs[i + 1].cycles += (cinstr.cycles + 1) # Add 1 because cnop n, waits for n+1 cycles - kernel_cinstrs.pop(i) - i -= 1 - i += 1 - for cinstr in kernel_cinstrs: # Update ifetch if isinstance(cinstr, cinst.IFetch): cinstr.bundle = cinstr.bundle + self.__bundle_offset # Update xinstfetch if isinstance(cinstr, cinst.XInstFetch): - raise NotImplementedError('`xinstfetch` not currently supported by linker.') + raise NotImplementedError( + "`xinstfetch` not currently supported by linker." + ) # Update csyncm if isinstance(cinstr, cinst.CSyncm): cinstr.target = cinstr.target + self.__minst_line_offset @@ -260,125 +308,157 @@ def __updateCInsts(self, kernel_cinstrs: list): if not GlobalConfig.hasHBM: # update all SPAD instruction variable names to be SPAD addresses # change xload variable names into SPAD addresses - if isinstance(cinstr, (cinst.BLoad, cinst.BOnes, cinst.CLoad, cinst.NLoad)): + if isinstance( + cinstr, (cinst.BLoad, cinst.BOnes, cinst.CLoad, cinst.NLoad) + ): var_name = cinstr.source - hbm_address = self.__mem_model.useVariable(var_name, self.__kernel_count) - self.__validateSPADAddress(var_name, hbm_address) + hbm_address = self.__mem_model.useVariable( + var_name, self.__kernel_count + ) + self._validate_spad_address(var_name, hbm_address) cinstr.source = str(hbm_address) - cinstr.comment = f" var: {var_name} - HBM({hbm_address})" + f";{cinstr.comment}" if cinstr.comment else "" + cinstr.comment = ( + f" var: {var_name} - HBM({hbm_address})" + f";{cinstr.comment}" + if cinstr.comment + else "" + ) if isinstance(cinstr, cinst.CStore): var_name = cinstr.dest - hbm_address = self.__mem_model.useVariable(var_name, self.__kernel_count) - self.__validateSPADAddress(var_name, hbm_address) + hbm_address = self.__mem_model.useVariable( + var_name, self.__kernel_count + ) + self._validate_spad_address(var_name, hbm_address) cinstr.dest = str(hbm_address) - cinstr.comment = f" var: {var_name} - HBM({hbm_address})" + f";{cinstr.comment}" if cinstr.comment else "" + cinstr.comment = ( + f" var: {var_name} - HBM({hbm_address})" + f";{cinstr.comment}" + if cinstr.comment + else "" + ) + + def _update_cinsts(self, kernel_cinstrs: list): + """ + @brief Updates the CInsts in the kernel to offset to the current expected bundle + and synchronization points. + + All CInsts in the kernel are expected to start at bundle 0, and to + synchronize with MInsts starting at line 0. + Does not change the LinkedProgram object. + + @param kernel_cinstrs List of CInstructions to update. + """ + if not GlobalConfig.hasHBM: + self._remove_and_merge_csyncm_cnop(kernel_cinstrs) + + self._update_cinsts_addresses_and_offsets(kernel_cinstrs) - def __updateXInsts(self, kernel_xinstrs: list) -> int: + def _update_xinsts(self, kernel_xinstrs: list) -> int: """ - Updates the XInsts in the kernel to offset to the current expected bundle. + @brief Updates the XInsts in the kernel to offset to the current expected bundle. All XInsts in the kernel are expected to start at bundle 0. - Does not change the `LinkedProgram` object. + Does not change the LinkedProgram object. - Parameters: - kernel_xinstrs (list): List of XInstructions to update. + @param kernel_xinstrs List of XInstructions to update. - Returns: - int: The last bundle number after updating. + @return int The last bundle number after updating. """ last_bundle = self.__bundle_offset for xinstr in kernel_xinstrs: xinstr.bundle = xinstr.bundle + self.__bundle_offset if last_bundle > xinstr.bundle: - raise RuntimeError(f'Detected invalid bundle. Instruction bundle is less than previous: "{xinstr.to_line()}"') + raise RuntimeError( + f'Detected invalid bundle. Instruction bundle is less than previous: "{xinstr.to_line()}"' + ) last_bundle = xinstr.bundle return last_bundle - def linkKernel(self, - kernel_minstrs: list, - kernel_cinstrs: list, - kernel_xinstrs: list): + def link_kernel( + self, kernel_minstrs: list, kernel_cinstrs: list, kernel_xinstrs: list + ): """ - Links a specified kernel (given by its three instruction queues) into this + @brief Links a specified kernel (given by its three instruction queues) into this program. The adjusted kernels will be appended into the output streams specified during construction of this object. - Parameters: - kernel_minstrs (list): List of MInstructions for the MInst Queue corresponding to the kernel to link. - These instructions will be modified by this method. - kernel_cinstrs (list): List of CInstructions for the CInst Queue corresponding to the kernel to link. - These instructions will be modified by this method. - kernel_xinstrs (list): List of XInstructions for the XInst Queue corresponding to the kernel to link. - These instructions will be modified by this method. + @param kernel_minstrs List of MInstructions for the MInst Queue corresponding to the kernel to link. + These instructions will be modified by this method. + @param kernel_cinstrs List of CInstructions for the CInst Queue corresponding to the kernel to link. + These instructions will be modified by this method. + @param kernel_xinstrs List of XInstructions for the XInst Queue corresponding to the kernel to link. + These instructions will be modified by this method. - Raises: - RuntimeError: If the program is closed and does not accept new kernels. + @exception RuntimeError If the program is closed and does not accept new kernels. """ - if not self.isOpen: - raise RuntimeError('Program is closed and does not accept new kernels.') + if not self.is_open: + raise RuntimeError("Program is closed and does not accept new kernels.") # No minsts without HBM if not GlobalConfig.hasHBM: kernel_minstrs = [] - self.__updateMInsts(kernel_minstrs) - self.__updateCInsts(kernel_cinstrs) - self.__bundle_offset = self.__updateXInsts(kernel_xinstrs) + 1 + self._update_minsts(kernel_minstrs) + self._update_cinsts(kernel_cinstrs) + self.__bundle_offset = self._update_xinsts(kernel_xinstrs) + 1 # Append the kernel to the output for xinstr in kernel_xinstrs: print(xinstr.to_line(), end="", file=self.__xinst_ostream) - if not self.supressComments and xinstr.comment: - print(f' #{xinstr.comment}', end="", file=self.__xinst_ostream) + if not GlobalConfig.suppress_comments and xinstr.comment: + print(f" #{xinstr.comment}", end="", file=self.__xinst_ostream) print(file=self.__xinst_ostream) for idx, cinstr in enumerate(kernel_cinstrs[:-1]): # Skip the `cexit` line_no = idx + self.__cinst_line_offset - print(f'{line_no}, {cinstr.to_line()}', end="", file=self.__cinst_ostream) - if not self.supressComments and cinstr.comment: - print(f' #{cinstr.comment}', end="", file=self.__cinst_ostream) + print(f"{line_no}, {cinstr.to_line()}", end="", file=self.__cinst_ostream) + if not GlobalConfig.suppress_comments and cinstr.comment: + print(f" #{cinstr.comment}", end="", file=self.__cinst_ostream) print(file=self.__cinst_ostream) for idx, minstr in enumerate(kernel_minstrs[:-1]): # Skip the exit `msyncc` line_no = idx + self.__minst_line_offset - print(f'{line_no}, {minstr.to_line()}', end="", file=self.__minst_ostream) - if not self.supressComments and minstr.comment: - print(f' #{minstr.comment}', end="", file=self.__minst_ostream) + print(f"{line_no}, {minstr.to_line()}", end="", file=self.__minst_ostream) + if not GlobalConfig.suppress_comments and minstr.comment: + print(f" #{minstr.comment}", end="", file=self.__minst_ostream) print(file=self.__minst_ostream) - self.__minst_line_offset += (len(kernel_minstrs) - 1) # Subtract last line that is getting removed - self.__cinst_line_offset += (len(kernel_cinstrs) - 1) # Subtract last line that is getting removed + self.__minst_line_offset += ( + len(kernel_minstrs) - 1 + ) # Subtract last line that is getting removed + self.__cinst_line_offset += ( + len(kernel_cinstrs) - 1 + ) # Subtract last line that is getting removed self.__kernel_count += 1 # Count the appended kernel @classmethod - def join_dinst_kernels(cls, kernels_instrs: list[list[DInstruction]]) -> list[DInstruction]: + def join_dinst_kernels( + cls, kernels_instrs: list[list[DInstruction]] + ) -> list[DInstruction]: """ - Joins a list of dinst kernels, consolidating variables that are outputs in one kernel + @brief Joins a list of dinst kernels, consolidating variables that are outputs in one kernel and inputs in the next. This ensures that variables carried across kernels are not duplicated, and their Mem addresses are consistent. - Args: - kernels_instrs (list): List of Kernels' DInstructions lists. + @param kernels_instrs List of Kernels' DInstructions lists. - Returns: - list[DInstructions]: A new instruction list representing the concatenated memory info. + @return list[DInstruction] A new instruction list representing the concatenated memory info. + + @exception ValueError If no DInstructions lists are provided for concatenation. """ - + if not kernels_instrs: raise ValueError("No DInstructions lists provided for concatenation.") - + # Use dictionaries to track unique variables by name - inputs: dict[str: DInstruction] = {} - carry_over_vars: dict[str: DInstruction] = {} + inputs: dict[str, DInstruction] = {} + carry_over_vars: dict[str, DInstruction] = {} mem_address: int = 0 new_kernels_instrs: list[DInstruction] = [] - for k_idx, kernel_instrs in enumerate(kernels_instrs): - - for idx, cur_dinst in enumerate(kernel_instrs): + for kernel_instrs in kernels_instrs: + for cur_dinst in kernel_instrs: # Save the current output instruction to add at the end if isinstance(cur_dinst, dinst.DStore): @@ -390,10 +470,12 @@ def join_dinst_kernels(cls, kernels_instrs: list[list[DInstruction]]) -> list[DI key = cur_dinst.var # Skip if the input is already in carry-over from previous outputs if key in carry_over_vars: - carry_over_vars.pop(key) # Remove from (output) carry-overs since it's now an input + carry_over_vars.pop( + key + ) # Remove from (output) carry-overs since it's now an input continue - # If the input is not (a previous output) in carry-over, add if it's not already (loaded) in inputs + # If the input is not (a previous output) in carry-over, add if it's not already (loaded) in inputs if key not in inputs: inputs[key] = cur_dinst cur_dinst.address = mem_address @@ -403,9 +485,9 @@ def join_dinst_kernels(cls, kernels_instrs: list[list[DInstruction]]) -> list[DI continue # Add remaining carry-over variables to the new instructions - for var in carry_over_vars: - carry_over_vars[var].address = mem_address - new_kernels_instrs.append(carry_over_vars[var]) + for _, var in carry_over_vars.items(): + var.address = mem_address + new_kernels_instrs.append(var) mem_address = mem_address + 1 return new_kernels_instrs From f89d26547421fa2f80ebb881a8728faaba7653d3 Mon Sep 17 00:00:00 2001 From: "Rojas Chaves, Jose" Date: Wed, 2 Jul 2025 17:22:13 +0000 Subject: [PATCH 03/12] Tests for he_link --- .../assembler/common/run_config.py | 15 + .../hec-assembler-tools/he_link.py | 15 +- .../hec-assembler-tools/tests/conftest.py | 22 +- .../tests/unit_tests/test_he_link.py | 678 ++++++++++++++++++ 4 files changed, 723 insertions(+), 7 deletions(-) create mode 100644 assembler_tools/hec-assembler-tools/tests/unit_tests/test_he_link.py diff --git a/assembler_tools/hec-assembler-tools/assembler/common/run_config.py b/assembler_tools/hec-assembler-tools/assembler/common/run_config.py index 54600a56..e5ccaef2 100644 --- a/assembler_tools/hec-assembler-tools/assembler/common/run_config.py +++ b/assembler_tools/hec-assembler-tools/assembler/common/run_config.py @@ -16,6 +16,15 @@ class RunConfig: policies, and other options that affect the behavior of the assembler. """ + # Type annotations for class attributes + has_hbm: bool + hbm_size: int + spad_size: int + repl_policy: str + suppress_comments: bool + use_xinstfetch: bool + debug_verbose: int + __initialized = False # Specifies whether static members have been initialized __default_config = ( {} @@ -111,6 +120,12 @@ def init_default_config(cls): cls.__initialized = True + # For testing only + @classmethod + def reset_class_state(cls): + cls.__initialized = False + cls.__default_config = {} + def __str__(self): """ Returns a string representation of the configuration. diff --git a/assembler_tools/hec-assembler-tools/he_link.py b/assembler_tools/hec-assembler-tools/he_link.py index ba64a947..0d9f39a6 100644 --- a/assembler_tools/hec-assembler-tools/he_link.py +++ b/assembler_tools/hec-assembler-tools/he_link.py @@ -2,8 +2,8 @@ # Copyright (C) 2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -# These contents may have been developed with support from one or more Intel-operated -# generative artificial intelligence solutions +# These contents may have been developed with support from one +# or more Intel-operated generative artificial intelligence solutions """ @file he_link.py @@ -55,6 +55,13 @@ class LinkerRunConfig(RunConfig): @return dict The configuration as a dictionary. """ + # Type annotations for class attributes + input_prefixes: list[str] + input_mem_file: str + find_mem_files: bool + output_dir: str + output_prefix: str + __initialized = False # specifies whether static members have been initialized # contains the dictionary of all configuration items supported and their # default value (or None if no default) @@ -89,15 +96,12 @@ def __init__(self, **kwargs): # class members based on configuration for config_name, default_value in self.__default_config.items(): - print(f"ROCHA DEBUG: {config_name} = {default_value}") value = kwargs.get(config_name, default_value) if value is not None: setattr(self, config_name, value) - print(f"ROCHA ADDED: {config_name} = {value}") else: if not hasattr(self, config_name): setattr(self, config_name, default_value) - print(f"ROCHA ADDED: {config_name} = {default_value}") if getattr(self, config_name) is None: raise TypeError( f"Expected value for configuration `{config_name}`, but `None` received." @@ -393,7 +397,6 @@ def main(run_config: LinkerRunConfig, verbose_stream=None): # Write the memory model to the output file if run_config.find_mem_files: - BaseInstruction.dump_instructions_to_file(kernel_dinstrs, output_files.mem) if verbose_stream: diff --git a/assembler_tools/hec-assembler-tools/tests/conftest.py b/assembler_tools/hec-assembler-tools/tests/conftest.py index 5f7b42e1..4ce586a0 100644 --- a/assembler_tools/hec-assembler-tools/tests/conftest.py +++ b/assembler_tools/hec-assembler-tools/tests/conftest.py @@ -2,14 +2,34 @@ # SPDX-License-Identifier: Apache-2.0 """ -Pytest configuration and fixtures for assembler_tools tests. +@file conftest.py +@brief Configuration and fixtures for pytest """ import os +import sys +from unittest.mock import patch + import pytest + from assembler.spec_config.isa_spec import ISASpecConfig from assembler.spec_config.mem_spec import MemSpecConfig +# Add the parent directory to sys.path to make imports work +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) + + +@pytest.fixture(autouse=True) +def mock_env_variables(): + """ + @brief Fixture to mock environment variables and provide common mocks + """ + with patch.dict( + "os.environ", + {"PYTHONPATH": "/home/jmrojasc/test/linker_sdk/encrypted-computing-sdk"}, + ): + yield + @pytest.fixture(scope="session", autouse=True) def initialize_specs(): diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_he_link.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_he_link.py new file mode 100644 index 00000000..446fa522 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_he_link.py @@ -0,0 +1,678 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +# These contents may have been developed with support from one +# or more Intel-operated generative artificial intelligence solutions + +""" +@file test_he_link.py +@brief Unit tests for the he_link module +""" + +import os +import argparse +from unittest.mock import patch, mock_open, MagicMock, PropertyMock +import pytest + +import he_link +from assembler.common.config import GlobalConfig + + +class TestLinkerRunConfig: + """ + @class TestLinkerRunConfig + @brief Test cases for the LinkerRunConfig class + """ + + def test_init_with_valid_params(self): + """ + @brief Test initialization with valid parameters + """ + # Arrange + kwargs = { + "input_prefixes": ["prefix1", "prefix2"], + "output_prefix": "output_prefix", + "input_mem_file": "input.mem", + "output_dir": "/tmp", + "has_hbm": True, + "hbm_size": 1024, + "suppress_comments": False, + "use_xinstfetch": False, + "find_mem_files": False, + } + + # Act + with patch("he_link.makeUniquePath", side_effect=lambda x: x): + config = he_link.LinkerRunConfig(**kwargs) + + # Assert + assert config.input_prefixes == ["prefix1", "prefix2"] + assert config.output_prefix == "output_prefix" + assert config.input_mem_file == "input.mem" + assert config.output_dir == "/tmp" + assert config.has_hbm is True + assert config.hbm_size == 1024 + assert config.suppress_comments is False + assert config.use_xinstfetch is False + assert config.find_mem_files is False + + def test_init_with_missing_required_param(self): + """ + @brief Test initialization with missing required parameters + """ + # Arrange + kwargs = { + "output_prefix": "output_prefix", + "input_mem_file": "input.mem", + # Missing input_prefixes + } + + # Act & Assert + with pytest.raises(TypeError): + he_link.LinkerRunConfig(**kwargs) + + def test_as_dict(self): + """ + @brief Test the as_dict method returns a proper dictionary + """ + # Arrange + kwargs = { + "input_prefixes": ["prefix1"], + "output_prefix": "output_prefix", + "input_mem_file": "input.mem", + "output_dir": "/tmp", + "has_hbm": True, + "hbm_size": 1024, + } + + # Act + with patch("he_link.makeUniquePath", side_effect=lambda x: x): + config = he_link.LinkerRunConfig(**kwargs) + result = config.as_dict() + + # Assert Keys + assert isinstance(result, dict) + assert "input_prefixes" in result + assert "output_prefix" in result + assert "input_mem_file" in result + assert "output_dir" in result + assert "has_hbm" in result + assert "hbm_size" in result + + # Assert values + assert result["input_prefixes"] == ["prefix1"] + assert result["output_prefix"] == "output_prefix" + assert result["input_mem_file"] == "input.mem" + assert result["output_dir"] == "/tmp" + assert result["has_hbm"] is True + assert result["hbm_size"] == 1024 + + def test_str_representation(self): + """ + @brief Test the string representation of the configuration + """ + # Arrange + kwargs = { + "input_prefixes": ["prefix1"], + "output_prefix": "output_prefix", + "input_mem_file": "input.mem", + } + + # Act + with patch("he_link.makeUniquePath", side_effect=lambda x: x): + config = he_link.LinkerRunConfig(**kwargs) + result = str(config) + + # Assert params + assert "input_prefixes" in result + assert "output_prefix" in result + assert "input_mem_file" in result + # Assert values + assert "prefix1" in result + assert "output_prefix" in result + assert "input.mem" in result + + def test_init_for_default_params(self): + """ + @brief Test initialization with default parameters + """ + + # Arrange + kwargs = {"input_prefixes": ["prefix1"], "output_prefix": ""} + + # Reset the class-level config so the patch will take effect + he_link.RunConfig.reset_class_state() + + # Act + with patch("he_link.makeUniquePath", side_effect=lambda x: x), patch.object( + he_link.RunConfig, "DEFAULT_HBM_SIZE_KB", new_callable=PropertyMock + ) as mock_hbm_size, patch.object( + GlobalConfig, "suppress_comments", new_callable=PropertyMock + ) as mock_suppress_comments, patch.object( + GlobalConfig, "useXInstFetch", new_callable=PropertyMock + ) as mock_use_xinstfetch: + + # Mock the default HBM size + mock_suppress_comments.return_value = False + mock_use_xinstfetch.return_value = False + mock_hbm_size.return_value = 1024 + config = he_link.LinkerRunConfig(**kwargs) + + # Assert + assert config.output_prefix == "" + assert config.input_mem_file == "" + assert config.output_dir == os.getcwd() + assert config.has_hbm is True + assert config.hbm_size == 1024 + assert config.suppress_comments is False + assert config.use_xinstfetch is False + assert config.find_mem_files is False + + +class TestKernelFiles: + """ + @class TestKernelFiles + @brief Test cases for the KernelFiles class + """ + + def test_kernel_files_creation(self): + """ + @brief Test KernelFiles creation and attribute access + """ + # Act + kernel_files = he_link.KernelFiles( + prefix="prefix", + minst="prefix.minst", + cinst="prefix.cinst", + xinst="prefix.xinst", + mem="prefix.mem", + ) + + # Assert + assert kernel_files.prefix == "prefix" + assert kernel_files.minst == "prefix.minst" + assert kernel_files.cinst == "prefix.cinst" + assert kernel_files.xinst == "prefix.xinst" + assert kernel_files.mem == "prefix.mem" + + def test_kernel_files_without_mem(self): + """ + @brief Test KernelFiles creation without mem file + """ + # Act + kernel_files = he_link.KernelFiles( + prefix="prefix", + minst="prefix.minst", + cinst="prefix.cinst", + xinst="prefix.xinst", + ) + + # Assert + assert kernel_files.prefix == "prefix" + assert kernel_files.minst == "prefix.minst" + assert kernel_files.cinst == "prefix.cinst" + assert kernel_files.xinst == "prefix.xinst" + assert kernel_files.mem is None + + +class TestHelperFunctions: + """ + @class TestHelperFunctions + @brief Test cases for helper functions in he_link + """ + + def test_prepare_output_files(self): + """ + @brief Test prepare_output_files function creates correct output files + """ + # Arrange + mock_config = MagicMock() + mock_config.output_dir = "/tmp" + mock_config.output_prefix = "output" + mock_config.find_mem_files = False + + # Act + with patch("os.path.dirname", return_value="/tmp"), patch( + "pathlib.Path.mkdir" + ), patch("he_link.makeUniquePath", side_effect=lambda x: x): + result = he_link.prepare_output_files(mock_config) + + # Assert + assert result.prefix == "/tmp/output" + assert result.minst == "/tmp/output.minst" + assert result.cinst == "/tmp/output.cinst" + assert result.xinst == "/tmp/output.xinst" + assert result.mem is None + + def test_prepare_output_files_with_mem(self): + """ + @brief Test prepare_output_files with find_mem_files=True + """ + # Arrange + mock_config = MagicMock() + mock_config.output_dir = "/tmp" + mock_config.output_prefix = "output" + mock_config.find_mem_files = True + + # Act + with patch("os.path.dirname", return_value="/tmp"), patch( + "pathlib.Path.mkdir" + ), patch("he_link.makeUniquePath", side_effect=lambda x: x): + result = he_link.prepare_output_files(mock_config) + + # Assert + assert result.prefix == "/tmp/output" + assert result.minst == "/tmp/output.minst" + assert result.cinst == "/tmp/output.cinst" + assert result.xinst == "/tmp/output.xinst" + assert result.mem == "/tmp/output.mem" + + def test_prepare_input_files(self): + """ + @brief Test prepare_input_files function + """ + # Arrange + mock_config = MagicMock() + mock_config.input_prefixes = ["/tmp/input1", "/tmp/input2"] + mock_config.find_mem_files = False + + mock_output_files = he_link.KernelFiles( + prefix="/tmp/output", + minst="/tmp/output.minst", + cinst="/tmp/output.cinst", + xinst="/tmp/output.xinst", + ) + + # Act + with patch("os.path.isfile", return_value=True), patch( + "he_link.makeUniquePath", side_effect=lambda x: x + ): + result = he_link.prepare_input_files(mock_config, mock_output_files) + + # Assert + assert len(result) == 2 + assert result[0].prefix == "/tmp/input1" + assert result[0].minst == "/tmp/input1.minst" + assert result[0].cinst == "/tmp/input1.cinst" + assert result[0].xinst == "/tmp/input1.xinst" + assert result[0].mem is None + assert result[1].prefix == "/tmp/input2" + + def test_prepare_input_files_file_not_found(self): + """ + @brief Test prepare_input_files when a file doesn't exist + """ + # Arrange + mock_config = MagicMock() + mock_config.input_prefixes = ["/tmp/input1"] + mock_config.find_mem_files = False + + mock_output_files = he_link.KernelFiles( + prefix="/tmp/output", + minst="/tmp/output.minst", + cinst="/tmp/output.cinst", + xinst="/tmp/output.xinst", + ) + + # Act & Assert + with patch("os.path.isfile", return_value=False), patch( + "he_link.makeUniquePath", side_effect=lambda x: x + ): + with pytest.raises(FileNotFoundError): + he_link.prepare_input_files(mock_config, mock_output_files) + + def test_prepare_input_files_output_conflict(self): + """ + @brief Test prepare_input_files when input and output files conflict + """ + # Arrange + mock_config = MagicMock() + mock_config.input_prefixes = ["/tmp/input1"] + mock_config.find_mem_files = False + + # Output file matching an input file + mock_output_files = he_link.KernelFiles( + prefix="/tmp/output", + minst="/tmp/input1.minst", # Conflict + cinst="/tmp/output.cinst", + xinst="/tmp/output.xinst", + ) + + # Act & Assert + with patch("os.path.isfile", return_value=True), patch( + "he_link.makeUniquePath", side_effect=lambda x: x + ): + with pytest.raises(RuntimeError): + he_link.prepare_input_files(mock_config, mock_output_files) + + @pytest.mark.parametrize("has_hbm", [True, False]) + def test_scan_variables(self, has_hbm): + """ + @brief Test scan_variables function with and without HBM + @param has_hbm Boolean indicating whether HBM is enabled + """ + # Arrange + GlobalConfig.hasHBM = has_hbm + mock_mem_model = MagicMock() + mock_verbose = MagicMock() + + input_files = [ + he_link.KernelFiles( + prefix="/tmp/input1", + minst="/tmp/input1.minst", + cinst="/tmp/input1.cinst", + xinst="/tmp/input1.xinst", + ) + ] + + # Act + with patch("linker.loader.load_minst_kernel_from_file", return_value=[]), patch( + "linker.loader.load_cinst_kernel_from_file", return_value=[] + ), patch( + "linker.steps.variable_discovery.discoverVariables", + return_value=["var1", "var2"], + ), patch( + "linker.steps.variable_discovery.discoverVariablesSPAD", + return_value=["var1", "var2"], + ): + he_link.scan_variables(input_files, mock_mem_model, mock_verbose) + + # Assert + if has_hbm: + assert mock_mem_model.addVariable.call_count == 2 + else: + assert mock_mem_model.addVariable.call_count == 2 + + def test_check_unused_variables(self): + """ + @brief Test check_unused_variables function + """ + # Arrange + GlobalConfig.hasHBM = True + mock_mem_model = MagicMock() + mock_mem_model.mem_info_vars = {"var1": MagicMock(), "var2": MagicMock()} + mock_mem_model.variables = {"var1"} + mock_mem_model.mem_info_meta = {} + + # Act & Assert + with pytest.raises(RuntimeError): + he_link.check_unused_variables(mock_mem_model) + + def test_link_kernels(self): + """ + @brief Test link_kernels function + """ + # Arrange + input_files = [ + he_link.KernelFiles( + prefix="/tmp/input1", + minst="/tmp/input1.minst", + cinst="/tmp/input1.cinst", + xinst="/tmp/input1.xinst", + ) + ] + + output_files = he_link.KernelFiles( + prefix="/tmp/output", + minst="/tmp/output.minst", + cinst="/tmp/output.cinst", + xinst="/tmp/output.xinst", + ) + + mock_mem_model = MagicMock() + mock_verbose = MagicMock() + + # Act + with patch("builtins.open", mock_open()), patch( + "linker.loader.load_minst_kernel_from_file", return_value=[] + ), patch("linker.loader.load_cinst_kernel_from_file", return_value=[]), patch( + "linker.loader.load_xinst_kernel_from_file", return_value=[] + ), patch( + "linker.steps.program_linker.LinkedProgram" + ) as mock_linked_program: + he_link.link_kernels( + input_files, output_files, mock_mem_model, mock_verbose + ) + + # Assert + mock_linked_program.assert_called_once() + instance = mock_linked_program.return_value + assert instance.link_kernel.call_count == 1 + assert instance.close.call_count == 1 + + +class TestMainFunction: + """ + @class TestMainFunction + @brief Test cases for the main function + """ + + @pytest.mark.parametrize("find_mem_files", [True, False]) + def test_main(self, find_mem_files): + """ + @brief Test main function with find_mem_files=True + """ + # Arrange + mock_config = MagicMock() + mock_config.find_mem_files = find_mem_files + mock_config.has_hbm = True + mock_config.hbm_size = 1024 + mock_config.suppress_comments = False + mock_config.use_xinstfetch = False + + mock_verbose = MagicMock() + + # Act + with patch( + "assembler.common.constants.convertBytes2Words", return_value=1024 + ), patch("he_link.prepare_output_files") as mock_prepare_output, patch( + "he_link.prepare_input_files" + ) as mock_prepare_input, patch( + "assembler.common.counter.Counter.reset" + ), patch( + "linker.loader.load_dinst_kernel_from_file", return_value=["1", "2"] + ) as mock_load_dinst_kernel_from_file, patch( + "linker.instructions.BaseInstruction.dump_instructions_to_file" + ) as mock_dump_instructions, patch( + "linker.steps.program_linker.LinkedProgram.join_dinst_kernels", + return_value=[], + ) as mock_join_dinst_kernels, patch( + "assembler.memory_model.mem_info.MemInfo.from_dinstrs" + ) as mock_from_dinstrs, patch( + "assembler.memory_model.mem_info.MemInfo.from_file_iter" + ) as mock_from_file_iter, patch( + "linker.MemoryModel" + ), patch( + "he_link.scan_variables" + ) as mock_scan_variables, patch( + "he_link.check_unused_variables" + ) as mock_check_unused_variables, patch( + "he_link.link_kernels" + ) as mock_link_kernels, patch( + "he_link.BaseInstruction.dump_instructions_to_file" + ) as mock_dump_instructions: + + mock_prepare_input.return_value = [ + he_link.KernelFiles( + prefix="prefix1", + minst="prefix1.minst", + cinst="prefix1.cinst", + xinst="prefix1.xinst", + mem=None, + ), + he_link.KernelFiles( + prefix="prefix2", + minst="prefix2.minst", + cinst="prefix2.cinst", + xinst="prefix2.xinst", + mem=None, + ), + ] + he_link.main(mock_config, mock_verbose) + + # Assert pipeline is run as expected + mock_prepare_output.assert_called_once() + mock_prepare_input.assert_called_once() + mock_scan_variables.assert_called_once() + mock_check_unused_variables.assert_called_once() + mock_link_kernels.assert_called_once() + + if find_mem_files: + # Should use from_dinstrs, not from_file_iter + assert mock_from_dinstrs.called + assert mock_load_dinst_kernel_from_file.called + assert mock_join_dinst_kernels.called + assert mock_dump_instructions.called + + assert not mock_from_file_iter.called + else: + # Should use from_file_iter, not from_dinstrs + assert mock_from_file_iter.called + assert not mock_from_dinstrs.called + + def test_warning_on_use_xinstfetch(self): + """ + @brief Test warning is issued when use_xinstfetch is True + """ + # Arrange + mock_config = MagicMock() + mock_config.find_mem_files = False + mock_config.has_hbm = True + mock_config.hbm_size = 1024 + mock_config.suppress_comments = False + mock_config.use_xinstfetch = True # Should trigger warning + mock_config.input_mem_file = "input.mem" + + # Act & Assert + with patch("warnings.warn") as mock_warn, patch( + "assembler.common.constants.convertBytes2Words", return_value=1024 + ), patch("he_link.prepare_output_files"), patch( + "he_link.prepare_input_files" + ), patch( + "assembler.common.counter.Counter.reset" + ), patch( + "builtins.open", mock_open() + ), patch( + "assembler.memory_model.mem_info.MemInfo.from_file_iter" + ), patch( + "linker.MemoryModel" + ), patch( + "he_link.scan_variables" + ), patch( + "he_link.check_unused_variables" + ), patch( + "he_link.link_kernels" + ): + he_link.main(mock_config, None) + mock_warn.assert_called_once() + + +class TestParseArgs: + """ + @class TestParseArgs + @brief Test cases for the parse_args function + """ + + def test_parse_args_minimal(self): + """ + @brief Test parse_args with minimal arguments + """ + # Arrange + test_args = [ + "program", + "input_prefix", + "-o", + "output_prefix", + "-im", + "input.mem", + ] + + # Act + with patch("sys.argv", test_args), patch( + "argparse.ArgumentParser.parse_args", + return_value=argparse.Namespace( + input_prefixes=["input_prefix"], + output_prefix="output_prefix", + input_mem_file="input.mem", + output_dir="", + find_mem_files=False, + mem_spec_file="", + isa_spec_file="", + has_hbm=True, + hbm_size=None, + suppress_comments=False, + verbose=0, + ), + ): + args = he_link.parse_args() + + # Assert + assert args.input_prefixes == ["input_prefix"] + assert args.output_prefix == "output_prefix" + assert args.input_mem_file == "input.mem" + assert args.find_mem_files is False + + def test_parse_args_find_mem_files(self): + """ + @brief Test parse_args with find_mem_files flag + """ + # Arrange + test_args = [ + "program", + "input_prefix", + "-o", + "output_prefix", + "--find_mem_files", + ] + + # Act + with patch("sys.argv", test_args), patch( + "argparse.ArgumentParser.parse_args", + return_value=argparse.Namespace( + input_prefixes=["input_prefix"], + output_prefix="output_prefix", + input_mem_file="", + output_dir="", + find_mem_files=True, + mem_spec_file="", + isa_spec_file="", + has_hbm=True, + hbm_size=None, + suppress_comments=False, + verbose=0, + ), + ): + args = he_link.parse_args() + + # Assert + assert args.input_prefixes == ["input_prefix"] + assert args.output_prefix == "output_prefix" + assert args.input_mem_file == "" + assert args.find_mem_files is True + + def test_missing_input_mem_file(self): + """ + @brief Test parse_args with missing input_mem_file when find_mem_files is False + """ + # Arrange + test_args = ["program", "input_prefix", "-o", "output_prefix"] + + # Act & Assert + with patch("sys.argv", test_args), patch( + "argparse.ArgumentParser.parse_args", + return_value=argparse.Namespace( + input_prefixes=["input_prefix"], + output_prefix="output_prefix", + input_mem_file="", + output_dir="", + find_mem_files=False, + mem_spec_file="", + isa_spec_file="", + has_hbm=True, + hbm_size=None, + suppress_comments=False, + verbose=0, + ), + ), patch("argparse.ArgumentParser.error") as mock_error: + he_link.parse_args() + mock_error.assert_called_once() From 661a8a0a5cc835a3746ca937f9d71a815ce15048 Mon Sep 17 00:00:00 2001 From: "Rojas Chaves, Jose" Date: Mon, 7 Jul 2025 15:40:12 +0000 Subject: [PATCH 04/12] Adding more unit tests --- .../assembler/common/decorators.py | 32 +- .../assembler/memory_model/mem_info.py | 25 +- .../hec-assembler-tools/linker/__init__.py | 126 ++- .../linker/instructions/cinst/cinstruction.py | 32 +- .../linker/instructions/dinst/__init__.py | 9 +- .../linker/instructions/dinst/dinstruction.py | 79 +- .../linker/instructions/dinst/dkeygen.py | 15 +- .../linker/instructions/dinst/dload.py | 24 +- .../linker/instructions/dinst/dstore.py | 16 +- .../linker/instructions/instruction.py | 33 +- .../linker/instructions/minst/minstruction.py | 31 +- .../linker/instructions/xinst/xinstruction.py | 37 +- .../hec-assembler-tools/pytest.ini | 2 +- .../hec-assembler-tools/tests/__init__.py | 4 + .../hec-assembler-tools/tests/conftest.py | 28 +- .../tests/unit_tests/__init__.py | 2 + .../unit_tests/test_assembler/__init__.py | 2 + .../test_assembler/memory_model/__init__.py | 2 + .../memory_model/test_mem_info.py | 973 ++++++++++++++++++ .../tests/unit_tests/test_linker/__init__.py | 2 + .../tests/unit_tests/test_linker/test_init.py | 324 ++++++ .../test_linker/test_instructions/__init__.py | 4 + .../test_instructions/test_dinst/__init__.py | 4 + .../test_dinst/test_dinstruction.py | 95 ++ .../test_dinst/test_dkeygen.py | 116 +++ .../test_dinst/test_dload.py | 132 +++ .../test_dinst/test_dstore.py | 122 +++ .../test_instructions/test_dinst/test_init.py | 165 +++ .../test_instructions/test_init.py | 122 +++ .../test_instructions/test_instruction.py | 132 +++ .../unit_tests/test_linker/test_loader.py | 317 ++++++ .../test_linker/test_steps/__init__.py | 2 + .../test_steps/test_program_linker.py | 623 +++++++++++ .../test_steps/test_variable_discovery.py | 248 +++++ 34 files changed, 3758 insertions(+), 122 deletions(-) create mode 100644 assembler_tools/hec-assembler-tools/tests/__init__.py create mode 100644 assembler_tools/hec-assembler-tools/tests/unit_tests/__init__.py create mode 100644 assembler_tools/hec-assembler-tools/tests/unit_tests/test_assembler/__init__.py create mode 100644 assembler_tools/hec-assembler-tools/tests/unit_tests/test_assembler/memory_model/__init__.py create mode 100644 assembler_tools/hec-assembler-tools/tests/unit_tests/test_assembler/memory_model/test_mem_info.py create mode 100644 assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/__init__.py create mode 100644 assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_init.py create mode 100644 assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/__init__.py create mode 100644 assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/__init__.py create mode 100644 assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_dinstruction.py create mode 100644 assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_dkeygen.py create mode 100644 assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_dload.py create mode 100644 assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_dstore.py create mode 100644 assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_init.py create mode 100644 assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_init.py create mode 100644 assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_instruction.py create mode 100644 assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_loader.py create mode 100644 assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_steps/__init__.py create mode 100644 assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_steps/test_program_linker.py create mode 100644 assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_steps/test_variable_discovery.py diff --git a/assembler_tools/hec-assembler-tools/assembler/common/decorators.py b/assembler_tools/hec-assembler-tools/assembler/common/decorators.py index 09beaa98..1b4c5aea 100644 --- a/assembler_tools/hec-assembler-tools/assembler/common/decorators.py +++ b/assembler_tools/hec-assembler-tools/assembler/common/decorators.py @@ -1,28 +1,30 @@ - -class classproperty(object): +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +""" +This module provides decorator utilities for the assembler. + +It contains decorators that enhance class and function behavior, +including property-like decorators for class methods. +""" + + +# pylint: disable=invalid-name +class classproperty(property): """ A decorator that allows a method to be accessed as a class-level property rather than on instances of the class. """ - def __init__(self, f): - """ - Initializes the classproperty with the given function. - - Args: - f (function): The function to be used as a class-level property. - """ - self.f = f - - def __get__(self, obj, owner): + def __get__(self, cls, owner): """ Retrieves the value of the class-level property. Args: - obj: The instance of the class (ignored in this context). - owner: The class that owns the property. + cls: The class that owns the property. + owner: The owner of the class (ignored in this context). Returns: The result of calling the decorated function with the class as an argument. """ - return self.f(owner) + return self.fget(owner) diff --git a/assembler_tools/hec-assembler-tools/assembler/memory_model/mem_info.py b/assembler_tools/hec-assembler-tools/assembler/memory_model/mem_info.py index 834fc8c5..d039851e 100644 --- a/assembler_tools/hec-assembler-tools/assembler/memory_model/mem_info.py +++ b/assembler_tools/hec-assembler-tools/assembler/memory_model/mem_info.py @@ -291,7 +291,7 @@ def __init__(self, **kwargs): MemInfoVariable(**d) for d in kwargs.get(meta_field, []) ] - def get_item(self, key): + def __getitem__(self, key): """ Retrieves the list of MemInfoVariable objects for the specified metadata field. @@ -457,7 +457,8 @@ def __init__(self, **kwargs): a shortcut to creating a MemInfo object from structured data such as the contents of a YAML file. """ self._keygens = [ - MemInfoVariable(**d) for d in kwargs.get(MemInfo.Const.FIELD_KEYGENS, []) + MemInfoKeygenVariable(**d) + for d in kwargs.get(MemInfo.Const.FIELD_KEYGENS, []) ] self._inputs = [ MemInfoVariable(**d) for d in kwargs.get(MemInfo.Const.FIELD_INPUTS, []) @@ -703,7 +704,7 @@ def validate(self): ) -def __allocateMemInfoVariable(mem_model: MemoryModel, v_info: MemInfoVariable): +def _allocateMemInfoVariable(mem_model: MemoryModel, v_info: MemInfoVariable): """ Allocates a memory information variable in the memory model. @@ -759,11 +760,11 @@ def updateMemoryModelWithMemInfo(mem_model: MemoryModel, mem_info: MemInfo): # Inputs for v_info in mem_info.inputs: - __allocateMemInfoVariable(mem_model, v_info) + _allocateMemInfoVariable(mem_model, v_info) # Outputs for v_info in mem_info.outputs: - __allocateMemInfoVariable(mem_model, v_info) + _allocateMemInfoVariable(mem_model, v_info) mem_model.output_variables.push(v_info.var_name, None) # Metadata @@ -771,7 +772,7 @@ def updateMemoryModelWithMemInfo(mem_model: MemoryModel, mem_info: MemInfo): # Ones for v_info in mem_info.metadata.ones: mem_model.retrieveVarAdd(v_info.var_name) - __allocateMemInfoVariable(mem_model, v_info) + _allocateMemInfoVariable(mem_model, v_info) mem_model.add_meta_ones_var(v_info.var_name) # Shuffle meta vars @@ -779,40 +780,40 @@ def updateMemoryModelWithMemInfo(mem_model: MemoryModel, mem_info: MemInfo): assert len(mem_info.metadata.ntt_auxiliary_table) == 1 v_info = mem_info.metadata.ntt_auxiliary_table[0] mem_model.retrieveVarAdd(v_info.var_name) - __allocateMemInfoVariable(mem_model, v_info) + _allocateMemInfoVariable(mem_model, v_info) mem_model.meta_ntt_aux_table = v_info.var_name if mem_info.metadata.ntt_routing_table: assert len(mem_info.metadata.ntt_routing_table) == 1 v_info = mem_info.metadata.ntt_routing_table[0] mem_model.retrieveVarAdd(v_info.var_name) - __allocateMemInfoVariable(mem_model, v_info) + _allocateMemInfoVariable(mem_model, v_info) mem_model.meta_ntt_routing_table = v_info.var_name if mem_info.metadata.intt_auxiliary_table: assert len(mem_info.metadata.intt_auxiliary_table) == 1 v_info = mem_info.metadata.intt_auxiliary_table[0] mem_model.retrieveVarAdd(v_info.var_name) - __allocateMemInfoVariable(mem_model, v_info) + _allocateMemInfoVariable(mem_model, v_info) mem_model.meta_intt_aux_table = v_info.var_name if mem_info.metadata.intt_routing_table: assert len(mem_info.metadata.intt_routing_table) == 1 v_info = mem_info.metadata.intt_routing_table[0] mem_model.retrieveVarAdd(v_info.var_name) - __allocateMemInfoVariable(mem_model, v_info) + _allocateMemInfoVariable(mem_model, v_info) mem_model.meta_intt_routing_table = v_info.var_name # Twiddle for v_info in mem_info.metadata.twiddle: mem_model.retrieveVarAdd(v_info.var_name) - __allocateMemInfoVariable(mem_model, v_info) + _allocateMemInfoVariable(mem_model, v_info) mem_model.add_meta_twiddle_var(v_info.var_name) # Keygen seeds for v_info in mem_info.metadata.keygen_seeds: mem_model.retrieveVarAdd(v_info.var_name) - __allocateMemInfoVariable(mem_model, v_info) + _allocateMemInfoVariable(mem_model, v_info) mem_model.add_meta_keygen_seed_var(v_info.var_name) # End metadata diff --git a/assembler_tools/hec-assembler-tools/linker/__init__.py b/assembler_tools/hec-assembler-tools/linker/__init__.py index b26839bb..9dadd1f4 100644 --- a/assembler_tools/hec-assembler-tools/linker/__init__.py +++ b/assembler_tools/hec-assembler-tools/linker/__init__.py @@ -1,3 +1,6 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + import collections.abc as collections from assembler.common.config import GlobalConfig from assembler.memory_model import mem_info @@ -5,6 +8,7 @@ # linker/__init__.py contains classes to encapsulate the memory model used # by the linker. + class VariableInfo(mem_info.MemInfoVariable): """ Represents information about a variable in the memory model. @@ -22,6 +26,7 @@ def __init__(self, var_name, hbm_address=-1): self.uses = 0 self.last_kernel_used = -1 + class HBM: """ Represents the HBM model. @@ -38,7 +43,7 @@ def __init__(self, hbm_size_words: int): ValueError: If hbm_size_words is less than 1. """ if hbm_size_words < 1: - raise ValueError('`hbm_size_words` must be a positive integer.') + raise ValueError("`hbm_size_words` must be a positive integer.") # Represents the memory buffer where variables live self.__buffer = [None] * hbm_size_words @@ -76,11 +81,16 @@ def forceAllocate(self, var_info: VariableInfo, hbm_address: int): RuntimeError: If the HBM address is already occupied by another variable. """ if hbm_address < 0 or hbm_address >= len(self.buffer): - raise IndexError('`hbm_address` out of bounds. Expected a word address in range [0, {}), but {} received'.format(len(self.buffer), - hbm_address)) + raise IndexError( + "`hbm_address` out of bounds. Expected a word address in range [0, {}), but {} received".format( + len(self.buffer), hbm_address + ) + ) if var_info.hbm_address != hbm_address: if var_info.hbm_address >= 0: - raise ValueError(f'`var_info`: variable {var_info.var_name} already allocated in address {var_info.hbm_address}.') + raise ValueError( + f"`var_info`: variable {var_info.var_name} already allocated in address {var_info.hbm_address}." + ) in_var_info = self.buffer[hbm_address] # Validate hbm address @@ -88,17 +98,23 @@ def forceAllocate(self, var_info: VariableInfo, hbm_address: int): # Attempt to recycle SPAD locations inside kernel when no HBM # Note: there is no HBM, so, SPAD is used as the sole memory space if in_var_info and in_var_info.uses > 0: - raise RuntimeError(('HBM address {} already occupied by variable {} ' - 'when attempting to allocate variable {}').format(hbm_address, - in_var_info.var_name, - var_info.var_name)) + raise RuntimeError( + ( + "HBM address {} already occupied by variable {} " + "when attempting to allocate variable {}" + ).format(hbm_address, in_var_info.var_name, var_info.var_name) + ) else: - if in_var_info \ - and (in_var_info.uses > 0 or in_var_info.last_kernel_used >= var_info.last_kernel_used): - raise RuntimeError(('HBM address {} already occupied by variable {} ' - 'when attempting to allocate variable {}').format(hbm_address, - in_var_info.var_name, - var_info.var_name)) + if in_var_info and ( + in_var_info.uses > 0 + or in_var_info.last_kernel_used >= var_info.last_kernel_used + ): + raise RuntimeError( + ( + "HBM address {} already occupied by variable {} " + "when attempting to allocate variable {}" + ).format(hbm_address, in_var_info.var_name, var_info.var_name) + ) var_info.hbm_address = hbm_address self.buffer[hbm_address] = var_info @@ -122,14 +138,17 @@ def allocate(self, var_info: VariableInfo): retval = idx break else: - if not in_var_info \ - or (in_var_info.uses <= 0 and in_var_info.last_kernel_used < var_info.last_kernel_used): + if not in_var_info or ( + in_var_info.uses <= 0 + and in_var_info.last_kernel_used < var_info.last_kernel_used + ): retval = idx break if retval < 0: - raise RuntimeError('Out of HBM memory.') + raise RuntimeError("Out of HBM memory.") self.forceAllocate(var_info, retval) + class MemoryModel: """ Encapsulates the memory model for a linker run, tracking HBM usage and program variables. @@ -146,20 +165,51 @@ def __init__(self, hbm_size_words: int, mem_meta_info: mem_info.MemInfo): self.hbm = HBM(hbm_size_words) self.__mem_info = mem_meta_info self.__variables = {} # dict(var_name: str, VariableInfo) - self.__keygen_vars = {var_info.var_name: var_info for var_info in self.__mem_info.keygens} - self.__mem_info_inputs = {var_info.var_name: var_info for var_info in self.__mem_info.inputs} - self.__mem_info_outputs = {var_info.var_name: var_info for var_info in self.__mem_info.outputs} - self.__mem_info_meta = {var_info.var_name: var_info for var_info in self.__mem_info.metadata.intt_auxiliary_table} \ - | {var_info.var_name: var_info for var_info in self.__mem_info.metadata.intt_routing_table} \ - | {var_info.var_name: var_info for var_info in self.__mem_info.metadata.ntt_auxiliary_table} \ - | {var_info.var_name: var_info for var_info in self.__mem_info.metadata.ntt_routing_table} \ - | {var_info.var_name: var_info for var_info in self.__mem_info.metadata.ones} \ - | {var_info.var_name: var_info for var_info in self.__mem_info.metadata.twiddle} \ - | {var_info.var_name: var_info for var_info in self.__mem_info.metadata.keygen_seeds} + self.__keygen_vars = { + var_info.var_name: var_info for var_info in self.__mem_info.keygens + } + self.__mem_info_inputs = { + var_info.var_name: var_info for var_info in self.__mem_info.inputs + } + self.__mem_info_outputs = { + var_info.var_name: var_info for var_info in self.__mem_info.outputs + } + self.__mem_info_meta = ( + { + var_info.var_name: var_info + for var_info in self.__mem_info.metadata.intt_auxiliary_table + } + | { + var_info.var_name: var_info + for var_info in self.__mem_info.metadata.intt_routing_table + } + | { + var_info.var_name: var_info + for var_info in self.__mem_info.metadata.ntt_auxiliary_table + } + | { + var_info.var_name: var_info + for var_info in self.__mem_info.metadata.ntt_routing_table + } + | { + var_info.var_name: var_info + for var_info in self.__mem_info.metadata.ones + } + | { + var_info.var_name: var_info + for var_info in self.__mem_info.metadata.twiddle + } + | { + var_info.var_name: var_info + for var_info in self.__mem_info.metadata.keygen_seeds + } + ) self.__mem_info_fixed_addr_vars = self.__mem_info_outputs | self.__mem_info_meta # Keygen variables should not be part of mem_info_vars set since they # do not start in HBM - self.__mem_info_vars = self.__mem_info_inputs | self.__mem_info_outputs | self.__mem_info_meta + self.__mem_info_vars = ( + self.__mem_info_inputs | self.__mem_info_outputs | self.__mem_info_meta + ) @property def mem_info_meta(self) -> collections.Collection: @@ -168,7 +218,7 @@ def mem_info_meta(self) -> collections.Collection: Clients must not modify this set. """ return self.__mem_info_meta - + @property def mem_info_vars(self) -> collections.Collection: """ @@ -202,20 +252,17 @@ def addVariable(self, var_name: str): """ var_info: VariableInfo if var_name in self.variables: - print(f' ROCHA Variable {var_name} already exists in memory model.') var_info = self.variables[var_name] else: - print(f'ROCHA Adding variable {var_name} to memory model.') var_info = VariableInfo(var_name) if var_name in self.__mem_info_vars: - print(f'\tROCHA Variable {var_name} is in MemInfo, allocating HBM address. HBM address {self.__mem_info_vars[var_name].hbm_address}.') # Variables explicitly marked in mem file must persist throughout the program # with predefined HBM address if var_name in self.__mem_info_fixed_addr_vars: - print(f'\tROCHA Variable {var_name} has fixed HBM address {self.__mem_info_vars[var_name].hbm_address}.') - var_info.uses = float('inf') - self.hbm.forceAllocate(var_info, - self.__mem_info_vars[var_name].hbm_address) + var_info.uses = float("inf") + self.hbm.forceAllocate( + var_info, self.__mem_info_vars[var_name].hbm_address + ) self.variables[var_name] = var_info var_info.uses += 1 @@ -244,7 +291,8 @@ def useVariable(self, var_name: str, kernel: int) -> int: self.hbm.allocate(var_info) assert var_info.hbm_address >= 0 - assert self.hbm.buffer[var_info.hbm_address].var_name == var_info.var_name, \ - f'Expected variable {var_info.var_name} in HBM {var_info.hbm_address}, but variable {self.hbm[var_info.hbm_address].var_name} found instead.' + assert ( + self.hbm.buffer[var_info.hbm_address].var_name == var_info.var_name + ), f"Expected variable {var_info.var_name} in HBM {var_info.hbm_address}, but variable {self.hbm[var_info.hbm_address].var_name} found instead." - return var_info.hbm_address \ No newline at end of file + return var_info.hbm_address diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/cinst/cinstruction.py b/assembler_tools/hec-assembler-tools/linker/instructions/cinst/cinstruction.py index a82fb36d..81df3417 100644 --- a/assembler_tools/hec-assembler-tools/linker/instructions/cinst/cinstruction.py +++ b/assembler_tools/hec-assembler-tools/linker/instructions/cinst/cinstruction.py @@ -1,10 +1,27 @@ -from linker.instructions.instruction import BaseInstruction +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""This module implements the base class for CInstructions.""" + +from linker.instructions.instruction import BaseInstruction + class CInstruction(BaseInstruction): """ Represents a CInstruction, inheriting from BaseInstruction. """ + @classmethod + def _get_name(cls) -> str: + """ + Derived classes should implement this method and return correct + name for the instruction. + + Raises: + NotImplementedError: Abstract method. This base method should not be called. + """ + raise NotImplementedError() + @classmethod def _get_name_token_index(cls) -> int: """ @@ -15,6 +32,17 @@ def _get_name_token_index(cls) -> int: """ return 1 + @classmethod + def _get_num_tokens(cls) -> int: + """ + Derived classes should implement this method and return correct + required number of tokens for the instruction. + + Raises: + NotImplementedError: Abstract method. This base method should not be called. + """ + raise NotImplementedError() + # Constructor # ----------- @@ -38,4 +66,4 @@ def to_line(self) -> str: Returns: str: The string representation of the instruction, excluding the first token. """ - return ", ".join(self.tokens[1:]) \ No newline at end of file + return ", ".join(self.tokens[1:]) diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/dinst/__init__.py b/assembler_tools/hec-assembler-tools/linker/instructions/dinst/__init__.py index 3d21769b..d345637e 100644 --- a/assembler_tools/hec-assembler-tools/linker/instructions/dinst/__init__.py +++ b/assembler_tools/hec-assembler-tools/linker/instructions/dinst/__init__.py @@ -3,6 +3,8 @@ """This module provides functionality to create and manage data instructions""" +from typing import Optional + from assembler.instructions import tokenize_from_line from assembler.memory_model.mem_info import MemInfo from . import dload, dstore, dkeygen @@ -34,16 +36,21 @@ def create_from_mem_line(line: str) -> dinstruction.DInstruction: DInstruction or None: The parsed DInstruction object, or None if no object could be parsed from the specified input line. """ - retval: dinstruction.DInstruction = None + print(f"ROCHA: create_from_mem_line {line}") + retval: Optional[dinstruction.DInstruction] = None tokens, comment = tokenize_from_line(line) for instr_type in factory(): try: retval = instr_type(tokens, comment) + print(f"ROCHA: {instr_type.__name__} {tokens} {retval}") except ValueError: retval = None if retval: break + if not retval: + raise RuntimeError(f'No valid instruction found for line "{line}"') + try: miv, _ = MemInfo.get_meminfo_var_from_tokens(tokens) except RuntimeError as e: diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dinstruction.py b/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dinstruction.py index 74ff95e0..72358a4a 100644 --- a/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dinstruction.py +++ b/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dinstruction.py @@ -1,6 +1,17 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +""" +This module defines the base DInstruction class for data handling instructions. + +DInstruction is the parent class for all data instructions used in the +assembly process, providing common functionality and interfaces. +""" + from linker.instructions.instruction import BaseInstruction from assembler.common.counter import Counter -from assembler.common.decorators import * +from assembler.common.decorators import classproperty + class DInstruction(BaseInstruction): """ @@ -11,6 +22,17 @@ class DInstruction(BaseInstruction): _var: str = "" _address: int = 0 + @classmethod + def _get_name(cls) -> str: + """ + Derived classes should implement this method and return correct + name for the instruction. + + Raises: + NotImplementedError: Abstract method. This base method should not be called. + """ + raise NotImplementedError() + @classmethod def _get_name_token_index(cls) -> int: """ @@ -21,16 +43,50 @@ def _get_name_token_index(cls) -> int: """ return 0 + @classmethod + def _get_num_tokens(cls) -> int: + """ + Derived classes should implement this method and return correct + required number of tokens for the instruction. + + Raises: + NotImplementedError: Abstract method. This base method should not be called. + """ + raise NotImplementedError() + @classproperty - def num_tokens(cls) -> int: + def num_tokens(self) -> int: """ Valid number of tokens for this instruction. Returns: - tupple: Valid number of tokens. + tuple: Valid number of tokens. """ - return cls._get_num_tokens() - + return self._get_num_tokens() + + def _validate_tokens(self, tokens: list) -> None: + """ + Validates the tokens for this instruction. + + DInstruction allows at least the required number of tokens. + + Parameters: + tokens (list): List of tokens to validate. + + Raises: + ValueError: If tokens are invalid. + """ + assert self.name_token_index < self.num_tokens + if len(tokens) < self.num_tokens: + raise ValueError( + f"`tokens`: invalid amount of tokens. " + f"Instruction {self.name} requires at least {self.num_tokens}, but {len(tokens)} received" + ) + if tokens[self.name_token_index] != self.name: + raise ValueError( + f"`tokens`: invalid name. Expected {self.name}, but {tokens[self.name_token_index]} received" + ) + def __init__(self, tokens: list, comment: str = ""): """ Constructs a new DInstruction. @@ -40,13 +96,8 @@ def __init__(self, tokens: list, comment: str = ""): comment (str): Optional comment for the instruction. """ # Do not increment the global instruction count; skip BaseInstruction's __init__ logic for __id - assert self.name_token_index < self.num_tokens - - if len(tokens) > self.num_tokens: - raise ValueError((f"`tokens`: invalid amount of tokens. " - f"Instruction {self.name} requires at least {self.num_tokens}, but {len(tokens)} received")) - if tokens[self.name_token_index] != self.name: - raise ValueError(f"`tokens`: invalid name. Expected {self.name}, but {tokens[self.name_token_index]} received") + # Call BaseInstruction constructor but perform our own initialization + super().__init__(tokens, comment=comment) self.comment = comment self._tokens = list(tokens) @@ -76,7 +127,7 @@ def var(self, value: str): self._var = value @property - def address(self) -> str: + def address(self) -> int: """ Should be set to source/dest Mem address. """ @@ -84,7 +135,7 @@ def address(self) -> str: @address.setter def address(self, value: str): - self._address = value + self._address = int(value) if isinstance(value, str) else value def to_line(self) -> str: """ diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dkeygen.py b/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dkeygen.py index ba19d31e..543633a2 100644 --- a/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dkeygen.py +++ b/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dkeygen.py @@ -1,7 +1,15 @@ -from .dinstruction import DInstruction -from assembler.common.config import GlobalConfig +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +""" +This module implements the DKeyGen instruction for key generation operations. +""" + from assembler.memory_model.mem_info import MemInfo +from .dinstruction import DInstruction + + class Instruction(DInstruction): """ Encapsulates a `dkeygen` DInstruction. @@ -10,7 +18,7 @@ class Instruction(DInstruction): @classmethod def _get_num_tokens(cls) -> int: """ - Gets the number of tokens allowed for the instruction. + Gets the number of tokens required for the instruction. Returns: int: The number of tokens, which is 4. @@ -26,4 +34,3 @@ def _get_name(cls) -> str: str: The name of the instruction. """ return MemInfo.Const.Keyword.KEYGEN - diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dload.py b/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dload.py index b398f8d1..c6fb694b 100644 --- a/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dload.py +++ b/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dload.py @@ -1,5 +1,16 @@ -from .dinstruction import DInstruction +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +""" +This module implements the DLoad instruction for loading data from memory. + +The DLoad instruction is used to load data from specified memory locations +during the assembly process. +""" + from assembler.memory_model.mem_info import MemInfo +from .dinstruction import DInstruction + class Instruction(DInstruction): """ @@ -7,14 +18,14 @@ class Instruction(DInstruction): """ @classmethod - def _get_num_tokens(cls) -> tuple: + def _get_num_tokens(cls) -> int: """ - Gets the number of tokens allowed for the instruction. + Gets the number of tokens required for the instruction. Returns: - tupple: The number of tokens, which is 4. + int: The number of tokens, which is 3. """ - return 4 + return 3 @classmethod def _get_name(cls) -> str: @@ -25,7 +36,7 @@ def _get_name(cls) -> str: str: The name of the instruction. """ return MemInfo.Const.Keyword.LOAD - + @property def tokens(self) -> list: """ @@ -35,4 +46,3 @@ def tokens(self) -> list: list: The list of tokens. """ return [self.name, self._tokens[1], str(self.address)] + self._tokens[3:] - diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dstore.py b/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dstore.py index 543172d6..fef41739 100644 --- a/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dstore.py +++ b/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dstore.py @@ -1,6 +1,16 @@ -from .dinstruction import DInstruction -from assembler.common.config import GlobalConfig +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +""" +This module implements the DStore instruction for storing data to memory. + +The DStore instruction is used to store data to specified memory locations +during the assembly process. +""" + from assembler.memory_model.mem_info import MemInfo +from .dinstruction import DInstruction + class Instruction(DInstruction): """ @@ -10,7 +20,7 @@ class Instruction(DInstruction): @classmethod def _get_num_tokens(cls) -> int: """ - Gets the number of tokens allowed for the instruction. + Gets the number of tokens required for the instruction. Returns: int: The number of tokens, which is 3. diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/instruction.py b/assembler_tools/hec-assembler-tools/linker/instructions/instruction.py index 768aa35a..1764fcf3 100644 --- a/assembler_tools/hec-assembler-tools/linker/instructions/instruction.py +++ b/assembler_tools/hec-assembler-tools/linker/instructions/instruction.py @@ -134,24 +134,37 @@ def __init__(self, tokens: list, comment: str = ""): """ assert self.name_token_index < self.num_tokens + self._validate_tokens(tokens) + + self._id = next(BaseInstruction.__id_count) + + self._tokens = list(tokens) + self.comment = comment + + def _validate_tokens(self, tokens: list) -> None: + """ + Validates the tokens for this instruction. + + Default implementation checks for exact token count match. + Child classes can override this method to implement different validation logic. + + Parameters: + tokens (list): List of tokens to validate. + + Raises: + ValueError: If tokens are invalid. + """ if len(tokens) != self.num_tokens: # pylint: disable=W0143 raise ValueError( - ( - f"`tokens`: invalid amount of tokens. " - f"Instruction {self.name} requires less " - f"than {self.num_tokens}, but {len(tokens)} received" - ) + f"`tokens`: invalid amount of tokens. " + f"Instruction {self.name} requires exactly {self.num_tokens}, but {len(tokens)} received" ) + if tokens[self.name_token_index] != self.name: # pylint: disable=W0143 raise ValueError( f"`tokens`: invalid name. Expected {self.name}, but {tokens[self.name_token_index]} received" ) - self._id = next(BaseInstruction.__id_count) - - self._tokens = list(tokens) - self.comment = comment - def __repr__(self): retval = f"<{type(self).__name__}({self.name}, id={self.id}) object at {hex(id(self))}>(tokens={self.tokens})" return retval diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/minst/minstruction.py b/assembler_tools/hec-assembler-tools/linker/instructions/minst/minstruction.py index 00d40452..10f04ecb 100644 --- a/assembler_tools/hec-assembler-tools/linker/instructions/minst/minstruction.py +++ b/assembler_tools/hec-assembler-tools/linker/instructions/minst/minstruction.py @@ -1,11 +1,27 @@ - +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""This module implements the base class for MInstructions.""" + from linker.instructions.instruction import BaseInstruction + class MInstruction(BaseInstruction): """ Represents an MInstruction, inheriting from BaseInstruction. """ + @classmethod + def _get_name(cls) -> str: + """ + Derived classes should implement this method and return correct + name for the instruction. + + Raises: + NotImplementedError: Abstract method. This base method should not be called. + """ + raise NotImplementedError() + @classmethod def _get_name_token_index(cls) -> int: """ @@ -16,6 +32,17 @@ def _get_name_token_index(cls) -> int: """ return 1 + @classmethod + def _get_num_tokens(cls) -> int: + """ + Derived classes should implement this method and return correct + required number of tokens for the instruction. + + Raises: + NotImplementedError: Abstract method. This base method should not be called. + """ + raise NotImplementedError() + # Constructor # ----------- @@ -39,4 +66,4 @@ def to_line(self) -> str: Returns: str: The string representation of the instruction, excluding the first token. """ - return ", ".join(self.tokens[1:]) \ No newline at end of file + return ", ".join(self.tokens[1:]) diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/xinst/xinstruction.py b/assembler_tools/hec-assembler-tools/linker/instructions/xinst/xinstruction.py index 3a622539..43b2b9b4 100644 --- a/assembler_tools/hec-assembler-tools/linker/instructions/xinst/xinstruction.py +++ b/assembler_tools/hec-assembler-tools/linker/instructions/xinst/xinstruction.py @@ -1,11 +1,27 @@ - +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""This module implements the XInstruction""" + from linker.instructions.instruction import BaseInstruction + class XInstruction(BaseInstruction): """ Represents an XInstruction, inheriting from BaseInstruction. """ + @classmethod + def _get_name(cls) -> str: + """ + Derived classes should implement this method and return correct + name for the instruction. + + Raises: + NotImplementedError: Abstract method. This base method should not be called. + """ + raise NotImplementedError() + @classmethod def _get_name_token_index(cls) -> int: """ @@ -17,6 +33,17 @@ def _get_name_token_index(cls) -> int: # Name at index 2. return 2 + @classmethod + def _get_num_tokens(cls) -> int: + """ + Derived classes should implement this method and return correct + required number of tokens for the instruction. + + Raises: + NotImplementedError: Abstract method. This base method should not be called. + """ + raise NotImplementedError() + # Constructor # ----------- @@ -44,7 +71,7 @@ def bundle(self) -> int: Raises: RuntimeError: If the bundle format is invalid. """ - if len(self.tokens[0]) < 2 or self.tokens[0][0] != 'F': + if len(self.tokens[0]) < 2 or self.tokens[0][0] != "F": raise RuntimeError(f'Invalid bundle format detected: "{self.tokens[0]}".') return int(self.tokens[0][1:]) @@ -60,5 +87,7 @@ def bundle(self, value: int): ValueError: If the value is negative. """ if value < 0: - raise ValueError(f'`value`: expected non-negative bundle index, but {value} received.') - self.tokens[0] = f'F{value}' \ No newline at end of file + raise ValueError( + f"`value`: expected non-negative bundle index, but {value} received." + ) + self.tokens[0] = f"F{value}" diff --git a/assembler_tools/hec-assembler-tools/pytest.ini b/assembler_tools/hec-assembler-tools/pytest.ini index 06e92c68..be1763fa 100644 --- a/assembler_tools/hec-assembler-tools/pytest.ini +++ b/assembler_tools/hec-assembler-tools/pytest.ini @@ -1,4 +1,4 @@ [pytest] pythonpath = . testpaths = tests -#addopts = --cov=. +addopts = --cov=. diff --git a/assembler_tools/hec-assembler-tools/tests/__init__.py b/assembler_tools/hec-assembler-tools/tests/__init__.py new file mode 100644 index 00000000..58bdd8eb --- /dev/null +++ b/assembler_tools/hec-assembler-tools/tests/__init__.py @@ -0,0 +1,4 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +# Empty init file to make the directory a Python package diff --git a/assembler_tools/hec-assembler-tools/tests/conftest.py b/assembler_tools/hec-assembler-tools/tests/conftest.py index 4ce586a0..b8bc75ec 100644 --- a/assembler_tools/hec-assembler-tools/tests/conftest.py +++ b/assembler_tools/hec-assembler-tools/tests/conftest.py @@ -15,8 +15,22 @@ from assembler.spec_config.isa_spec import ISASpecConfig from assembler.spec_config.mem_spec import MemSpecConfig -# Add the parent directory to sys.path to make imports work -sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) +# Remove any existing paths that might conflict +for path in list(sys.path): + if "hec-assembler-tools" in path: + sys.path.remove(path) + +# Get the absolute path to the repository root directory +# Structure: /path/to/repo/encrypted-computing-sdk/assembler_tools/linker_sdk/tests/conftest.py +current_dir = os.path.dirname(os.path.abspath(__file__)) +linker_sdk_dir = os.path.dirname(current_dir) # linker_sdk directory +assembler_tools_dir = os.path.dirname(linker_sdk_dir) # assembler_tools directory +repo_root = os.path.dirname(assembler_tools_dir) # encrypted-computing-sdk directory + +# Add the paths to sys.path +sys.path.insert(0, linker_sdk_dir) +sys.path.insert(0, assembler_tools_dir) +sys.path.insert(0, repo_root) @pytest.fixture(autouse=True) @@ -24,10 +38,8 @@ def mock_env_variables(): """ @brief Fixture to mock environment variables and provide common mocks """ - with patch.dict( - "os.environ", - {"PYTHONPATH": "/home/jmrojasc/test/linker_sdk/encrypted-computing-sdk"}, - ): + # Use the repository root in PYTHONPATH instead of an absolute path + with patch.dict("os.environ", {"PYTHONPATH": repo_root}): yield @@ -42,7 +54,7 @@ def initialize_specs(): initialization methods for both ISASpecConfig and MemSpecConfig. Note: - This fixture is intended to be run from the assembler root directory. + This fixture is intended to be run from any location. Yields: None @@ -51,6 +63,6 @@ def initialize_specs(): Any exceptions raised by ISASpecConfig.initialize_isa_spec or MemSpecConfig.initialize_mem_spec will propagate. """ - module_dir = os.path.dirname(os.path.dirname(__file__)) + module_dir = linker_sdk_dir ISASpecConfig.initialize_isa_spec(module_dir, "") MemSpecConfig.initialize_mem_spec(module_dir, "") diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/__init__.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/__init__.py new file mode 100644 index 00000000..4057dc01 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/__init__.py @@ -0,0 +1,2 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_assembler/__init__.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_assembler/__init__.py new file mode 100644 index 00000000..4057dc01 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_assembler/__init__.py @@ -0,0 +1,2 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_assembler/memory_model/__init__.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_assembler/memory_model/__init__.py new file mode 100644 index 00000000..4057dc01 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_assembler/memory_model/__init__.py @@ -0,0 +1,2 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_assembler/memory_model/test_mem_info.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_assembler/memory_model/test_mem_info.py new file mode 100644 index 00000000..60ba2d44 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_assembler/memory_model/test_mem_info.py @@ -0,0 +1,973 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +""" +Unit tests for the memory model mem_info module. +""" + +import unittest +from unittest.mock import patch, MagicMock, call, PropertyMock + +from assembler.memory_model import MemoryModel +from assembler.memory_model.mem_info import ( + MemInfoVariable, + MemInfoKeygenVariable, + MemInfo, + updateMemoryModelWithMemInfo, + _allocateMemInfoVariable, +) + + +class TestMemInfoVariable(unittest.TestCase): + """Tests for the MemInfoVariable class.""" + + def test_init_valid(self): + """Test initialization with valid parameters.""" + with patch( + "assembler.memory_model.variable.Variable.validateName", return_value=True + ): + var = MemInfoVariable("test_var", 42) + self.assertEqual(var.var_name, "test_var") + self.assertEqual(var.hbm_address, 42) + + def test_init_strips_whitespace(self): + """Test that initialization strips whitespace from variable name.""" + with patch( + "assembler.memory_model.variable.Variable.validateName", return_value=True + ): + var = MemInfoVariable(" test_var ", 42) + self.assertEqual(var.var_name, "test_var") + + def test_init_invalid_name(self): + """Test initialization with invalid variable name.""" + with patch( + "assembler.memory_model.variable.Variable.validateName", return_value=False + ): + with self.assertRaises(RuntimeError): + MemInfoVariable("invalid!var", 42) + + def test_repr(self): + """Test the __repr__ method.""" + with patch( + "assembler.memory_model.variable.Variable.validateName", return_value=True + ): + var = MemInfoVariable("test_var", 42) + self.assertEqual( + repr(var), repr({"var_name": "test_var", "hbm_address": 42}) + ) + + def test_as_dict(self): + """Test the as_dict method.""" + with patch( + "assembler.memory_model.variable.Variable.validateName", return_value=True + ): + var = MemInfoVariable("test_var", 42) + self.assertEqual(var.as_dict(), {"var_name": "test_var", "hbm_address": 42}) + + +class TestMemInfoKeygenVariable(unittest.TestCase): + """Tests for the MemInfoKeygenVariable class.""" + + def test_init_valid(self): + """Test initialization with valid parameters.""" + with patch( + "assembler.memory_model.variable.Variable.validateName", return_value=True + ): + var = MemInfoKeygenVariable("test_var", 2, 3) + self.assertEqual(var.var_name, "test_var") + self.assertEqual(var.hbm_address, -1) # Should be initialized to -1 + self.assertEqual(var.seed_index, 2) + self.assertEqual(var.key_index, 3) + + def test_init_negative_seed_index(self): + """Test initialization with negative seed index.""" + with patch( + "assembler.memory_model.variable.Variable.validateName", return_value=True + ): + with self.assertRaises(IndexError): + MemInfoKeygenVariable("test_var", -1, 3) + + def test_init_negative_key_index(self): + """Test initialization with negative key index.""" + with patch( + "assembler.memory_model.variable.Variable.validateName", return_value=True + ): + with self.assertRaises(IndexError): + MemInfoKeygenVariable("test_var", 2, -1) + + def test_as_dict(self): + """Test the as_dict method.""" + with patch( + "assembler.memory_model.variable.Variable.validateName", return_value=True + ): + var = MemInfoKeygenVariable("test_var", 2, 3) + self.assertEqual( + var.as_dict(), {"var_name": "test_var", "seed_index": 2, "key_index": 3} + ) + + +class TestMemInfoMetadata(unittest.TestCase): + """Tests for the MemInfo.Metadata class.""" + + def test_parse_meta_field_from_mem_tokens_valid(self): + """Test parsing a valid metadata field.""" + tokens = ["dload", "LOAD_ONES", "42", "ones_var"] + result = MemInfo.Metadata.parse_meta_field_from_mem_tokens( + tokens, "LOAD_ONES", var_prefix="ONES" + ) + self.assertIsNotNone(result) + self.assertEqual(result.var_name, "ones_var") + self.assertEqual(result.hbm_address, 42) + + def test_parse_meta_field_from_mem_tokens_no_name(self): + """Test parsing a metadata field without explicit name.""" + tokens = ["dload", "LOAD_ONES", "42"] + result = MemInfo.Metadata.parse_meta_field_from_mem_tokens( + tokens, "LOAD_ONES", var_prefix="ONES" + ) + self.assertIsNotNone(result) + self.assertEqual(result.var_name, "ONES_42") + self.assertEqual(result.hbm_address, 42) + + def test_parse_meta_field_from_mem_tokens_with_extra(self): + """Test parsing a metadata field with var_extra.""" + tokens = ["dload", "LOAD_ONES", "42"] + result = MemInfo.Metadata.parse_meta_field_from_mem_tokens( + tokens, "LOAD_ONES", var_prefix="ONES", var_extra="_extra" + ) + self.assertIsNotNone(result) + self.assertEqual(result.var_name, "ONES_extra") + self.assertEqual(result.hbm_address, 42) + + def test_parse_meta_field_from_mem_tokens_invalid(self): + """Test parsing an invalid metadata field.""" + # Not enough tokens + tokens = ["dload"] + result = MemInfo.Metadata.parse_meta_field_from_mem_tokens( + tokens, "LOAD_ONES", var_prefix="ONES" + ) + self.assertIsNone(result) + + # Wrong first token + tokens = ["wrong", "LOAD_ONES", "42"] + result = MemInfo.Metadata.parse_meta_field_from_mem_tokens( + tokens, "LOAD_ONES", var_prefix="ONES" + ) + self.assertIsNone(result) + + # Wrong second token + tokens = ["dload", "WRONG", "42"] + result = MemInfo.Metadata.parse_meta_field_from_mem_tokens( + tokens, "LOAD_ONES", var_prefix="ONES" + ) + self.assertIsNone(result) + + def test_metadata_init_and_properties(self): + """Test initialization and properties of Metadata class.""" + # Prepare test data + metadata_dict = { + "ones": [{"var_name": "ones_var", "hbm_address": 1}], + "ntt_auxiliary_table": [{"var_name": "ntt_aux", "hbm_address": 2}], + "ntt_routing_table": [{"var_name": "ntt_route", "hbm_address": 3}], + "intt_auxiliary_table": [{"var_name": "intt_aux", "hbm_address": 4}], + "intt_routing_table": [{"var_name": "intt_route", "hbm_address": 5}], + "twid": [{"var_name": "twiddle_var", "hbm_address": 6}], + "keygen_seed": [{"var_name": "keygen_seed", "hbm_address": 7}], + } + + # Initialize Metadata + metadata = MemInfo.Metadata(**metadata_dict) + + # Test property access + self.assertEqual(len(metadata.ones), 1) + self.assertEqual(metadata.ones[0].var_name, "ones_var") + + self.assertEqual(len(metadata.ntt_auxiliary_table), 1) + self.assertEqual(metadata.ntt_auxiliary_table[0].var_name, "ntt_aux") + + self.assertEqual(len(metadata.ntt_routing_table), 1) + self.assertEqual(metadata.ntt_routing_table[0].var_name, "ntt_route") + + self.assertEqual(len(metadata.intt_auxiliary_table), 1) + self.assertEqual(metadata.intt_auxiliary_table[0].var_name, "intt_aux") + + self.assertEqual(len(metadata.intt_routing_table), 1) + self.assertEqual(metadata.intt_routing_table[0].var_name, "intt_route") + + self.assertEqual(len(metadata.twiddle), 1) + self.assertEqual(metadata.twiddle[0].var_name, "twiddle_var") + + self.assertEqual(len(metadata.keygen_seeds), 1) + self.assertEqual(metadata.keygen_seeds[0].var_name, "keygen_seed") + + def test_get_item(self): + """Test the __getitem__ method.""" + metadata_dict = {"ones": [{"var_name": "ones_var", "hbm_address": 1}]} + metadata = MemInfo.Metadata(**metadata_dict) + + # Test __getitem__ using bracket notation + ones_list = metadata["ones"] + self.assertEqual(len(ones_list), 1) + self.assertEqual(ones_list[0].var_name, "ones_var") + + +class TestMemInfoParsers(unittest.TestCase): + """Tests for the various parser methods in MemInfo.""" + + def test_ones_parse_from_mem_tokens(self): + """Test parsing Ones metadata from tokens.""" + tokens = ["dload", "LOAD_ONES", "42", "ones_var"] + with patch( + "assembler.memory_model.mem_info.MemInfo.Metadata.parse_meta_field_from_mem_tokens" + ) as mock_parse: + mock_parse.return_value = MagicMock() + result = MemInfo.Metadata.Ones.parse_from_mem_tokens(tokens) + mock_parse.assert_called_once_with( + tokens, + MemInfo.Const.Keyword.LOAD_ONES, + var_prefix=MemInfo.Const.Keyword.LOAD_ONES, + ) + self.assertEqual(result, mock_parse.return_value) + + def test_ntt_aux_table_parse_from_mem_tokens(self): + """Test parsing NTTAuxTable metadata from tokens.""" + tokens = ["dload", "LOAD_NTT_AUX_TABLE", "42", "ntt_aux_var"] + with patch( + "assembler.memory_model.mem_info.MemInfo.Metadata.parse_meta_field_from_mem_tokens" + ) as mock_parse: + mock_parse.return_value = MagicMock() + result = MemInfo.Metadata.NTTAuxTable.parse_from_mem_tokens(tokens) + mock_parse.assert_called_once_with( + tokens, + MemInfo.Const.Keyword.LOAD_NTT_AUX_TABLE, + var_prefix=MemInfo.Const.Keyword.LOAD_NTT_AUX_TABLE, + ) + self.assertEqual(result, mock_parse.return_value) + + def test_ntt_routing_table_parse_from_mem_tokens(self): + """Test parsing NTTRoutingTable metadata from tokens.""" + tokens = ["dload", "LOAD_NTT_ROUTING_TABLE", "42", "ntt_route_var"] + with patch( + "assembler.memory_model.mem_info.MemInfo.Metadata.parse_meta_field_from_mem_tokens" + ) as mock_parse: + mock_parse.return_value = MagicMock() + result = MemInfo.Metadata.NTTRoutingTable.parse_from_mem_tokens(tokens) + mock_parse.assert_called_once_with( + tokens, + MemInfo.Const.Keyword.LOAD_NTT_ROUTING_TABLE, + var_prefix=MemInfo.Const.Keyword.LOAD_NTT_ROUTING_TABLE, + ) + self.assertEqual(result, mock_parse.return_value) + + def test_intt_aux_table_parse_from_mem_tokens(self): + """Test parsing iNTTAuxTable metadata from tokens.""" + tokens = ["dload", "LOAD_iNTT_AUX_TABLE", "42", "intt_aux_var"] + with patch( + "assembler.memory_model.mem_info.MemInfo.Metadata.parse_meta_field_from_mem_tokens" + ) as mock_parse: + mock_parse.return_value = MagicMock() + result = MemInfo.Metadata.iNTTAuxTable.parse_from_mem_tokens(tokens) + mock_parse.assert_called_once_with( + tokens, + MemInfo.Const.Keyword.LOAD_iNTT_AUX_TABLE, + var_prefix=MemInfo.Const.Keyword.LOAD_iNTT_AUX_TABLE, + ) + self.assertEqual(result, mock_parse.return_value) + + def test_intt_routing_table_parse_from_mem_tokens(self): + """Test parsing iNTTRoutingTable metadata from tokens.""" + tokens = ["dload", "LOAD_iNTT_ROUTING_TABLE", "42", "intt_route_var"] + with patch( + "assembler.memory_model.mem_info.MemInfo.Metadata.parse_meta_field_from_mem_tokens" + ) as mock_parse: + mock_parse.return_value = MagicMock() + result = MemInfo.Metadata.iNTTRoutingTable.parse_from_mem_tokens(tokens) + mock_parse.assert_called_once_with( + tokens, + MemInfo.Const.Keyword.LOAD_iNTT_ROUTING_TABLE, + var_prefix=MemInfo.Const.Keyword.LOAD_iNTT_ROUTING_TABLE, + ) + self.assertEqual(result, mock_parse.return_value) + + def test_twiddle_parse_from_mem_tokens(self): + """Test parsing Twiddle metadata from tokens.""" + tokens = ["dload", "LOAD_TWIDDLE", "42", "twiddle_var"] + with patch( + "assembler.memory_model.mem_info.MemInfo.Metadata.parse_meta_field_from_mem_tokens" + ) as mock_parse: + mock_parse.return_value = MagicMock() + result = MemInfo.Metadata.Twiddle.parse_from_mem_tokens(tokens) + mock_parse.assert_called_once_with( + tokens, + MemInfo.Const.Keyword.LOAD_TWIDDLE, + var_prefix=MemInfo.Const.Keyword.LOAD_TWIDDLE, + ) + self.assertEqual(result, mock_parse.return_value) + + def test_keygen_seed_parse_from_mem_tokens(self): + """Test parsing KeygenSeed metadata from tokens.""" + tokens = ["dload", "LOAD_KEYGEN_SEED", "42", "keygen_seed_var"] + with patch( + "assembler.memory_model.mem_info.MemInfo.Metadata.parse_meta_field_from_mem_tokens" + ) as mock_parse: + mock_parse.return_value = MagicMock() + result = MemInfo.Metadata.KeygenSeed.parse_from_mem_tokens(tokens) + mock_parse.assert_called_once_with( + tokens, + MemInfo.Const.Keyword.LOAD_KEYGEN_SEED, + var_prefix=MemInfo.Const.Keyword.LOAD_KEYGEN_SEED, + ) + self.assertEqual(result, mock_parse.return_value) + + def test_keygen_parse_from_mem_tokens_valid(self): + """Test parsing a valid keygen variable.""" + tokens = ["keygen", "2", "3", "keygen_var"] + result = MemInfo.Keygen.parse_from_mem_tokens(tokens) + self.assertIsNotNone(result) + self.assertEqual(result.var_name, "keygen_var") + self.assertEqual(result.seed_index, 2) + self.assertEqual(result.key_index, 3) + + def test_keygen_parse_from_mem_tokens_invalid(self): + """Test parsing an invalid keygen variable.""" + # Not enough tokens + tokens = ["keygen", "2", "3"] + result = MemInfo.Keygen.parse_from_mem_tokens(tokens) + self.assertIsNone(result) + + # Wrong first token + tokens = ["wrong", "2", "3", "keygen_var"] + result = MemInfo.Keygen.parse_from_mem_tokens(tokens) + self.assertIsNone(result) + + def test_input_parse_from_mem_tokens_valid(self): + """Test parsing a valid input variable.""" + tokens = ["dload", "poly", "42", "input_var"] + with patch( + "assembler.memory_model.variable.Variable.validateName", return_value=True + ): + result = MemInfo.Input.parse_from_mem_tokens(tokens) + self.assertIsNotNone(result) + self.assertEqual(result.var_name, "input_var") + self.assertEqual(result.hbm_address, 42) + + def test_input_parse_from_mem_tokens_invalid(self): + """Test parsing an invalid input variable.""" + # Not enough tokens + tokens = ["dload", "poly", "42"] + result = MemInfo.Input.parse_from_mem_tokens(tokens) + self.assertIsNone(result) + + # Wrong tokens + tokens = ["wrong", "poly", "42", "input_var"] + result = MemInfo.Input.parse_from_mem_tokens(tokens) + self.assertIsNone(result) + + tokens = ["dload", "wrong", "42", "input_var"] + result = MemInfo.Input.parse_from_mem_tokens(tokens) + self.assertIsNone(result) + + def test_output_parse_from_mem_tokens_valid(self): + """Test parsing a valid output variable.""" + tokens = ["dstore", "output_var", "42"] + with patch( + "assembler.memory_model.variable.Variable.validateName", return_value=True + ): + result = MemInfo.Output.parse_from_mem_tokens(tokens) + self.assertIsNotNone(result) + self.assertEqual(result.var_name, "output_var") + self.assertEqual(result.hbm_address, 42) + + def test_output_parse_from_mem_tokens_invalid(self): + """Test parsing an invalid output variable.""" + # Not enough tokens + tokens = ["store", "output_var"] + result = MemInfo.Output.parse_from_mem_tokens(tokens) + self.assertIsNone(result) + + # Wrong first token + tokens = ["wrong", "output_var", "42"] + result = MemInfo.Output.parse_from_mem_tokens(tokens) + self.assertIsNone(result) + + +class TestMemInfo(unittest.TestCase): + """Tests for the MemInfo class.""" + + def test_init_default(self): + """Test default initialization.""" + mem_info = MemInfo() + self.assertEqual(len(mem_info.keygens), 0) + self.assertEqual(len(mem_info.inputs), 0) + self.assertEqual(len(mem_info.outputs), 0) + # Verify metadata was initialized + self.assertIsInstance(mem_info.metadata, MemInfo.Metadata) + + def test_init_with_data(self): + """Test initialization with data.""" + # Prepare test data + test_data = { + "keygens": [{"var_name": "keygen_var", "seed_index": 1, "key_index": 2}], + "inputs": [{"var_name": "input_var", "hbm_address": 42}], + "outputs": [{"var_name": "output_var", "hbm_address": 43}], + "metadata": { + "ones": [{"var_name": "ones_var", "hbm_address": 44}], + "twid": [{"var_name": "twiddle_var", "hbm_address": 45}], + }, + } + + # Initialize with test data + with patch("assembler.memory_model.mem_info.MemInfo.validate"): + mem_info = MemInfo(**test_data) + + # Verify data was loaded correctly + self.assertEqual(len(mem_info.keygens), 1) + self.assertEqual(mem_info.keygens[0].var_name, "keygen_var") + + self.assertEqual(len(mem_info.inputs), 1) + self.assertEqual(mem_info.inputs[0].var_name, "input_var") + + self.assertEqual(len(mem_info.outputs), 1) + self.assertEqual(mem_info.outputs[0].var_name, "output_var") + + # Verify metadata + self.assertEqual(len(mem_info.metadata.ones), 1) + self.assertEqual(mem_info.metadata.ones[0].var_name, "ones_var") + + self.assertEqual(len(mem_info.metadata.twiddle), 1) + self.assertEqual(mem_info.metadata.twiddle[0].var_name, "twiddle_var") + + def test_factory_dict(self): + """Test the factory_dict property.""" + mem_info = MemInfo() + factory_dict = mem_info.factory_dict + + # Verify all expected keys are present + self.assertIn(MemInfo.Keygen, factory_dict) + self.assertIn(MemInfo.Input, factory_dict) + self.assertIn(MemInfo.Output, factory_dict) + self.assertIn(MemInfo.Metadata.KeygenSeed, factory_dict) + self.assertIn(MemInfo.Metadata.Ones, factory_dict) + self.assertIn(MemInfo.Metadata.NTTAuxTable, factory_dict) + self.assertIn(MemInfo.Metadata.NTTRoutingTable, factory_dict) + self.assertIn(MemInfo.Metadata.iNTTAuxTable, factory_dict) + self.assertIn(MemInfo.Metadata.iNTTRoutingTable, factory_dict) + self.assertIn(MemInfo.Metadata.Twiddle, factory_dict) + + # Verify values point to correct lists + self.assertEqual(factory_dict[MemInfo.Keygen], mem_info.keygens) + self.assertEqual(factory_dict[MemInfo.Input], mem_info.inputs) + self.assertEqual(factory_dict[MemInfo.Output], mem_info.outputs) + + def test_mem_info_types(self): + """Test the mem_info_types class property.""" + mem_info_types = MemInfo.mem_info_types + + # Verify expected types are in the list + self.assertIn(MemInfo.Keygen, mem_info_types) + self.assertIn(MemInfo.Input, mem_info_types) + self.assertIn(MemInfo.Output, mem_info_types) + self.assertIn(MemInfo.Metadata.KeygenSeed, mem_info_types) + self.assertIn(MemInfo.Metadata.Ones, mem_info_types) + self.assertIn(MemInfo.Metadata.NTTAuxTable, mem_info_types) + self.assertIn(MemInfo.Metadata.NTTRoutingTable, mem_info_types) + self.assertIn(MemInfo.Metadata.iNTTAuxTable, mem_info_types) + self.assertIn(MemInfo.Metadata.iNTTRoutingTable, mem_info_types) + self.assertIn(MemInfo.Metadata.Twiddle, mem_info_types) + + def test_get_meminfo_var_from_tokens_valid(self): + """Test getting a MemInfo variable from valid tokens.""" + tokens = ["keygen", "2", "3", "keygen_var"] + + # Mock the parse_from_mem_tokens method to return a mock variable + mock_variable = MagicMock() + with patch.object( + MemInfo.Keygen, "parse_from_mem_tokens", return_value=mock_variable + ): + var, var_type = MemInfo.get_meminfo_var_from_tokens(tokens) + + # Verify results + self.assertEqual(var, mock_variable) + self.assertEqual(var_type, MemInfo.Keygen) + + def test_get_meminfo_var_from_tokens_not_found(self): + """Test getting a MemInfo variable when no parser can handle it.""" + tokens = ["unknown", "token"] + + # Mock all parse_from_mem_tokens methods to return None + with patch.object( + MemInfo, + "mem_info_types", + return_value=[ + MagicMock(parse_from_mem_tokens=MagicMock(return_value=None)) + ], + ): + var, var_type = MemInfo.get_meminfo_var_from_tokens(tokens) + + # Verify results + self.assertIsNone(var) + self.assertIsNone(var_type) + + def test_add_meminfo_var_from_tokens_valid(self): + """Test adding a MemInfo variable from valid tokens.""" + tokens = ["keygen", "2", "3", "keygen_var"] + mem_info = MemInfo() + + # Mock get_meminfo_var_from_tokens + mock_variable = MagicMock() + mock_type = MagicMock() + mock_list = MagicMock() + mock_dict = {mock_type: mock_list} + + with patch.object( + MemInfo, + "get_meminfo_var_from_tokens", + return_value=(mock_variable, mock_type), + ), patch.object( + MemInfo, "factory_dict", new_callable=PropertyMock, return_value=mock_dict + ): + + # Call the method + mem_info.add_meminfo_var_from_tokens(tokens) + + # Verify the variable was added to the correct list + mock_list.append.assert_called_once_with(mock_variable) + + def test_add_meminfo_var_from_tokens_not_found(self): + """Test adding a MemInfo variable when no parser can handle it.""" + tokens = ["unknown", "token"] + mem_info = MemInfo() + + # Mock get_meminfo_var_from_tokens to return None + with patch.object( + MemInfo, "get_meminfo_var_from_tokens", return_value=(None, None) + ): + # Verify exception is raised + with self.assertRaises(RuntimeError): + mem_info.add_meminfo_var_from_tokens(tokens) + + def test_from_file_iter_valid(self): + """Test creating a MemInfo from a valid file iterator.""" + # Mock lines + lines = [ + "keygen, 2, 3, keygen_var", + "dload, input, 42, input_var", + "store, output_var, 43", + "dload, LOAD_ONES, 44, ones_var", + " ", # Empty line to test skipping + "# Comment line", # Comment line to test skipping + ] + + # Mock tokenize_from_line + def mock_tokenize(line): + if line.startswith("keygen"): + return (["keygen", "2", "3", "keygen_var"], "") + if line.startswith("dload, input"): + return (["dload", "input", "42", "input_var"], "") + if line.startswith("store"): + return (["store", "output_var", "43"], "") + if line.startswith("dload, LOAD_ONES"): + return (["dload", "LOAD_ONES", "44", "ones_var"], "") + + return ([], "") + + # Mock methods - patch the function where it's imported in mem_info + with patch( + "assembler.memory_model.mem_info.tokenize_from_line", + side_effect=mock_tokenize, + ), patch.object( + MemInfo, "add_meminfo_var_from_tokens" + ) as mock_add_var, patch.object( + MemInfo, "validate" + ): + + # Call the method + MemInfo.from_file_iter(lines) + + # Verify add_meminfo_var_from_tokens was called for each valid line + self.assertEqual(mock_add_var.call_count, 4) + + def test_from_file_iter_error(self): + """Test creating a MemInfo when an error occurs.""" + # Mock lines + lines = ["invalid line"] + + # Mock tokenize_from_line + def mock_tokenize(line): + return (["invalid"], line) + + # Mock methods + with patch( + "assembler.instructions.tokenize_from_line", side_effect=mock_tokenize + ), patch.object( + MemInfo, + "add_meminfo_var_from_tokens", + side_effect=RuntimeError("Test error"), + ), patch.object( + MemInfo, "validate" + ): + + # Verify exception is raised with line number and content + with self.assertRaises(RuntimeError) as context: + MemInfo.from_file_iter(lines) + + self.assertIn("1: invalid line", str(context.exception)) + + def test_from_dinstrs_valid(self): + """Test creating a MemInfo from valid DInstructions.""" + # Mock DInstructions + dinstrs = [ + MagicMock(tokens=["keygen", "2", "3", "keygen_var"]), + MagicMock(tokens=["dload", "input", "42", "input_var"]), + MagicMock(tokens=["store", "output_var", "43"]), + ] + + # Mock methods + with patch.object( + MemInfo, "add_meminfo_var_from_tokens" + ) as mock_add_var, patch.object(MemInfo, "validate"), patch( + "builtins.print" + ): # Mock print to avoid output + + # Call the method + MemInfo.from_dinstrs(dinstrs) + + # Verify add_meminfo_var_from_tokens was called for each instruction + self.assertEqual(mock_add_var.call_count, 3) + mock_add_var.assert_has_calls( + [ + call(["keygen", "2", "3", "keygen_var"]), + call(["dload", "input", "42", "input_var"]), + call(["store", "output_var", "43"]), + ] + ) + + def test_from_dinstrs_error(self): + """Test creating a MemInfo when an error occurs.""" + # Mock DInstructions + dinstrs = [MagicMock(tokens=["invalid"])] + + # Mock methods + with patch.object( + MemInfo, + "add_meminfo_var_from_tokens", + side_effect=RuntimeError("Test error"), + ), patch.object(MemInfo, "validate"), patch( + "builtins.print" + ): # Mock print to avoid output + + # Verify exception is raised with instruction number + with self.assertRaises(RuntimeError) as context: + MemInfo.from_dinstrs(dinstrs) + + self.assertIn("1: ['invalid']", str(context.exception)) + + def test_as_dict(self): + """Test the as_dict method.""" + # Create a MemInfo with test data + with patch("assembler.memory_model.mem_info.MemInfo.validate"): + + # dicts + keygens_dict = {"var_name": "keygen_var", "seed_index": 1, "key_index": 2} + inputs_dict = {"var_name": "input_var", "hbm_address": 42} + outputs_dict = {"var_name": "output_var", "hbm_address": 43} + ones_dict = {"var_name": "ones_var", "hbm_address": 44} + twiddle_dict = {"var_name": "twiddle_var", "hbm_address": 45} + + # Prepare test data + test_data = { + "keygens": [keygens_dict], + "inputs": [inputs_dict], + "outputs": [outputs_dict], + "metadata": { + "ones": [ones_dict], + "twid": [twiddle_dict], + }, + } + + # Create the MemInfo with the test data + mem_info = MemInfo(**test_data) + + # Call the method + result = mem_info.as_dict() + + # Verify result structure + self.assertIn("keygens", result) + self.assertIn("inputs", result) + self.assertIn("outputs", result) + self.assertIn("metadata", result) + + # Verify values + self.assertEqual(result["keygens"], [keygens_dict]) + self.assertEqual(result["inputs"], [inputs_dict]) + self.assertEqual(result["outputs"], [outputs_dict]) + self.assertIn("ones", result["metadata"]) + self.assertEqual(result["metadata"]["ones"], [ones_dict]) + + def test_validate_valid(self): + """Test validation with valid data.""" + + ones_dict = {"var_name": "ones_var", "hbm_address": 44} + twiddle_dict = {"var_name": "twiddle_var", "hbm_address": 45} + + twiddle_list = [ + twiddle_dict for _ in range(MemoryModel.MAX_TWIDDLE_META_VARS_PER_SEGMENT) + ] + + # Create metadata dictionary for initialization + metadata_dict = { + "ones": [ones_dict], + "twid": twiddle_list, + } + + # Initialize without validation to set up the test + mem_info = MemInfo(metadata=metadata_dict) + + # Now explicitly call validate which we want to test + mem_info.validate() # Should not raise any exceptions + + def test_validate_twiddle_mismatch(self): + """Test validation with mismatched twiddle count.""" + + ones_dict = {"var_name": "ones_var", "hbm_address": 44} + twiddle_dict = {"var_name": "twiddle_var", "hbm_address": 45} + + # Create metadata dictionary for initialization + metadata_dict = { + "ones": [ones_dict], + "twid": [twiddle_dict], + } + + # Call MemInfo initialization with metadata + with patch( + "assembler.memory_model.mem_info.MemoryModel.MAX_TWIDDLE_META_VARS_PER_SEGMENT", + 2, + ): + with self.assertRaises(RuntimeError) as context: + # Initialize without validation to set up the test + MemInfo(metadata=metadata_dict) + + self.assertIn( + "Expected 2 times as many twiddles as ones", str(context.exception) + ) + + def test_validate_duplicate_var_name(self): + """Test validation with duplicate variable names but different HBM addresses.""" + # Create variable dictionaries with duplicate names but different addresses + intt_aux_dict = {"var_name": "duplicate", "hbm_address": 1} + ntt_route_dict = {"var_name": "duplicate", "hbm_address": 2} + + # Create metadata dictionary for initialization with the duplicate variables + metadata_dict = { + "ones": [], + "twid": [], + "intt_auxiliary_table": [intt_aux_dict], + "ntt_routing_table": [ntt_route_dict], + "ntt_auxiliary_table": [], + "intt_routing_table": [], + } + + # Initialize MemInfo with the metadata containing duplicates + with self.assertRaises(RuntimeError) as context: + MemInfo(metadata=metadata_dict) + + self.assertIn( + 'Variable "duplicate" already allocated', str(context.exception) + ) + + +class TestUpdateMemoryModelWithMemInfo(unittest.TestCase): + """Tests for the updateMemoryModelWithMemInfo function.""" + + def setUp(self): + """Set up common test fixtures.""" + # Create mock MemoryModel + self.mock_mem_model = MagicMock() + self.mock_mem_model.retrieveVarAdd = MagicMock() + + # Create mock MemInfo + self.mock_mem_info = MagicMock() + + # Group all mock variables in a dictionary + self.vars = { + "input": MagicMock(var_name="input_var", hbm_address=1), + "output": MagicMock(var_name="output_var", hbm_address=2), + "ones": MagicMock(var_name="ones_var", hbm_address=3), + "ntt_aux": MagicMock(var_name="ntt_aux", hbm_address=4), + "ntt_route": MagicMock(var_name="ntt_route", hbm_address=5), + "intt_aux": MagicMock(var_name="intt_aux", hbm_address=6), + "intt_route": MagicMock(var_name="intt_route", hbm_address=7), + "twiddle": MagicMock(var_name="twiddle_var", hbm_address=8), + "keygen_seed": MagicMock(var_name="keygen_seed", hbm_address=9), + } + + # Set up MemInfo + self.mock_mem_info.inputs = [self.vars["input"]] + self.mock_mem_info.outputs = [self.vars["output"]] + + # Set up metadata + self.mock_metadata = MagicMock() + self.mock_metadata.ones = [self.vars["ones"]] + self.mock_metadata.ntt_auxiliary_table = [self.vars["ntt_aux"]] + self.mock_metadata.ntt_routing_table = [self.vars["ntt_route"]] + self.mock_metadata.intt_auxiliary_table = [self.vars["intt_aux"]] + self.mock_metadata.intt_routing_table = [self.vars["intt_route"]] + self.mock_metadata.twiddle = [self.vars["twiddle"]] + self.mock_metadata.keygen_seeds = [self.vars["keygen_seed"]] + + self.mock_mem_info.metadata = self.mock_metadata + + def test_update_memory_model_inputs(self): + """Test updating memory model with input variables.""" + # Call the function + with patch( + "assembler.memory_model.mem_info._allocateMemInfoVariable" + ) as mock_allocate: + updateMemoryModelWithMemInfo(self.mock_mem_model, self.mock_mem_info) + + # Verify input variables were allocated + mock_allocate.assert_any_call(self.mock_mem_model, self.vars["input"]) + + def test_update_memory_model_outputs(self): + """Test updating memory model with output variables.""" + # Call the function + with patch( + "assembler.memory_model.mem_info._allocateMemInfoVariable" + ) as mock_allocate: + updateMemoryModelWithMemInfo(self.mock_mem_model, self.mock_mem_info) + + # Verify output variables were allocated and added to output_variables + mock_allocate.assert_any_call(self.mock_mem_model, self.vars["output"]) + self.mock_mem_model.output_variables.push.assert_called_once_with( + "output_var", None + ) + + def test_update_memory_model_metadata(self): + """Test updating memory model with metadata variables.""" + # Call the function + with patch( + "assembler.memory_model.mem_info._allocateMemInfoVariable" + ) as mock_allocate: + updateMemoryModelWithMemInfo(self.mock_mem_model, self.mock_mem_info) + + # Verify metadata variables were retrieved, allocated and added to their respective lists + self.mock_mem_model.retrieveVarAdd.assert_has_calls( + [ + call("ones_var"), + call("ntt_aux"), + call("ntt_route"), + call("intt_aux"), + call("intt_route"), + call("twiddle_var"), + call("keygen_seed"), + ], + any_order=True, + ) + + mock_allocate.assert_has_calls( + [ + call(self.mock_mem_model, self.vars["ones"]), + call(self.mock_mem_model, self.vars["ntt_aux"]), + call(self.mock_mem_model, self.vars["ntt_route"]), + call(self.mock_mem_model, self.vars["intt_aux"]), + call(self.mock_mem_model, self.vars["intt_route"]), + call(self.mock_mem_model, self.vars["twiddle"]), + call(self.mock_mem_model, self.vars["keygen_seed"]), + ], + any_order=True, + ) + + self.mock_mem_model.add_meta_ones_var.assert_called_once_with("ones_var") + self.assertEqual(self.mock_mem_model.meta_ntt_aux_table, "ntt_aux") + self.assertEqual(self.mock_mem_model.meta_ntt_routing_table, "ntt_route") + self.assertEqual(self.mock_mem_model.meta_intt_aux_table, "intt_aux") + self.assertEqual(self.mock_mem_model.meta_intt_routing_table, "intt_route") + self.mock_mem_model.add_meta_twiddle_var.assert_called_once_with( + "twiddle_var" + ) + self.mock_mem_model.add_meta_keygen_seed_var.assert_called_once_with( + "keygen_seed" + ) + + +class TestAllocateMemInfoVariable(unittest.TestCase): + """Tests for the _allocateMemInfoVariable function.""" + + def test_allocate_mem_info_variable_success(self): + """Test successful allocation of a MemInfo variable.""" + # Create mock MemoryModel and variable + mock_mem_model = MagicMock() + mock_var_info = MagicMock(var_name="test_var", hbm_address=42) + + # Mock variables dictionary + mock_mem_model.variables = {"test_var": MagicMock(hbm_address=-1)} + + # Call the function + with patch( + "assembler.memory_model.mem_info._allocateMemInfoVariable" + ) as mock_function: + # Make it actually call the real function - simplified without lambda + mock_function.original = _allocateMemInfoVariable + mock_function.side_effect = mock_function.original + + mock_function(mock_mem_model, mock_var_info) + + # Verify the variable was allocated + mock_mem_model.hbm.allocateForce.assert_called_once_with( + 42, mock_mem_model.variables["test_var"] + ) + + def test_allocate_mem_info_variable_not_in_model(self): + """Test allocation when the variable is not in the memory model.""" + # Create mock MemoryModel and variable + mock_mem_model = MagicMock() + mock_var_info = MagicMock(var_name="missing_var", hbm_address=42) + + # Mock variables dictionary (missing the variable) + mock_mem_model.variables = {} + + # Call the function + with patch( + "assembler.memory_model.mem_info._allocateMemInfoVariable" + ) as mock_function: + # Make it actually call the real function - simplified without lambda + mock_function.original = _allocateMemInfoVariable + mock_function.side_effect = mock_function.original + + # Verify exception is raised + with self.assertRaises(RuntimeError) as context: + mock_function(mock_mem_model, mock_var_info) + + self.assertIn( + "Variable missing_var not in memory model", str(context.exception) + ) + + def test_allocate_mem_info_variable_mismatch(self): + """Test allocation when the variable has a different HBM address.""" + # Create mock MemoryModel and variable + mock_mem_model = MagicMock() + mock_var_info = MagicMock(var_name="test_var", hbm_address=42) + + # Mock variables dictionary with a variable that already has a different HBM address + mock_mem_model.variables = {"test_var": MagicMock(hbm_address=24)} + + # Call the function + with patch( + "assembler.memory_model.mem_info._allocateMemInfoVariable" + ) as mock_function: + # Make it actually call the real function - simplified without lambda + mock_function.original = _allocateMemInfoVariable + mock_function.side_effect = mock_function.original + + # Verify exception is raised + with self.assertRaises(RuntimeError) as context: + mock_function(mock_mem_model, mock_var_info) + + self.assertIn( + "Variable test_var already allocated in HBM address 24", + str(context.exception), + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/__init__.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/__init__.py new file mode 100644 index 00000000..4057dc01 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/__init__.py @@ -0,0 +1,2 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_init.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_init.py new file mode 100644 index 00000000..7ef5e968 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_init.py @@ -0,0 +1,324 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +""" +Unit tests for the memory model classes in linker/__init__.py. +""" + +import unittest +from unittest.mock import MagicMock, patch + +from assembler.common.config import GlobalConfig +from assembler.memory_model import mem_info +from linker import VariableInfo, HBM, MemoryModel + + +class TestVariableInfo(unittest.TestCase): + """Tests for the VariableInfo class.""" + + def test_init(self): + """Test initialization of VariableInfo.""" + var_info = VariableInfo("test_var", 42) + self.assertEqual(var_info.var_name, "test_var") + self.assertEqual(var_info.hbm_address, 42) + self.assertEqual(var_info.uses, 0) + self.assertEqual(var_info.last_kernel_used, -1) + + def test_init_default_values(self): + """Test initialization with default values.""" + var_info = VariableInfo("test_var") + self.assertEqual(var_info.var_name, "test_var") + self.assertEqual(var_info.hbm_address, -1) + self.assertEqual(var_info.uses, 0) + self.assertEqual(var_info.last_kernel_used, -1) + + +class TestHBM(unittest.TestCase): + """Tests for the HBM class.""" + + def setUp(self): + """Set up test fixtures.""" + self.hbm_size = 10 + self.hbm = HBM(self.hbm_size) + + def test_init(self): + """Test initialization of HBM.""" + self.assertEqual(len(self.hbm.buffer), self.hbm_size) + self.assertEqual(self.hbm.capacity, self.hbm_size) + # Check that buffer is initialized with None values + for item in self.hbm.buffer: + self.assertIsNone(item) + + def test_init_invalid_size(self): + """Test initialization with invalid size.""" + with self.assertRaises(ValueError): + HBM(0) + with self.assertRaises(ValueError): + HBM(-1) + + def test_capacity_property(self): + """Test the capacity property.""" + self.assertEqual(self.hbm.capacity, self.hbm_size) + + def test_buffer_property(self): + """Test the buffer property.""" + buffer = self.hbm.buffer + self.assertEqual(len(buffer), self.hbm_size) + # Check that buffer is initialized with None values + for item in buffer: + self.assertIsNone(item) + + def test_force_allocate_valid(self): + """Test forceAllocate with valid parameters.""" + var_info = VariableInfo("test_var") + self.hbm.forceAllocate(var_info, 5) + self.assertEqual(var_info.hbm_address, 5) + self.assertEqual(self.hbm.buffer[5], var_info) + + def test_force_allocate_out_of_bounds(self): + """Test forceAllocate with out of bounds address.""" + var_info = VariableInfo("test_var") + with self.assertRaises(IndexError): + self.hbm.forceAllocate(var_info, -1) + with self.assertRaises(IndexError): + self.hbm.forceAllocate(var_info, self.hbm_size) + + def test_force_allocate_already_allocated(self): + """Test forceAllocate with already allocated variable.""" + var_info = VariableInfo("test_var", 3) + with self.assertRaises(ValueError): + self.hbm.forceAllocate(var_info, 5) + + def test_force_allocate_address_occupied_with_hbm(self): + """Test forceAllocate with address occupied and HBM enabled.""" + with patch.object(GlobalConfig, "hasHBM", True): + # Occupy address 5 + var_info1 = VariableInfo("var1") + var_info1.uses = 1 + self.hbm.forceAllocate(var_info1, 5) + + # Try to allocate another variable at the same address + var_info2 = VariableInfo("var2") + with self.assertRaises(RuntimeError): + self.hbm.forceAllocate(var_info2, 5) + + def test_force_allocate_address_occupied_without_hbm(self): + """Test forceAllocate with address occupied and HBM disabled.""" + with patch.object(GlobalConfig, "hasHBM", False): + # Occupy address 5 + var_info1 = VariableInfo("var1") + var_info1.uses = 1 + self.hbm.forceAllocate(var_info1, 5) + + # Try to allocate another variable at the same address + var_info2 = VariableInfo("var2") + with self.assertRaises(RuntimeError): + self.hbm.forceAllocate(var_info2, 5) + + def test_force_allocate_address_recyclable_with_hbm(self): + """Test forceAllocate with recyclable address and HBM enabled.""" + with patch.object(GlobalConfig, "hasHBM", True): + # Occupy address 5 with a variable that's not used + var_info1 = VariableInfo("var1") + var_info1.uses = 0 + var_info1.last_kernel_used = 1 + self.hbm.forceAllocate(var_info1, 5) + + # Allocate another variable at the same address with higher kernel index + var_info2 = VariableInfo("var2") + var_info2.last_kernel_used = 2 + self.hbm.forceAllocate(var_info2, 5) + + # Check that the new variable is at the address + self.assertEqual(self.hbm.buffer[5], var_info2) + + def test_allocate(self): + """Test allocate method.""" + var_info = VariableInfo("test_var") + self.hbm.allocate(var_info) + # The variable should be allocated at the first available address (0) + self.assertEqual(var_info.hbm_address, 0) + self.assertEqual(self.hbm.buffer[0], var_info) + + def test_allocate_full_memory(self): + """Test allocate with full memory.""" + # Fill up the HBM + for i in range(self.hbm_size): + var_info = VariableInfo(f"var{i}") + var_info.uses = 1 + self.hbm.forceAllocate(var_info, i) + + # Try to allocate another variable + var_info = VariableInfo("test_var") + with self.assertRaises(RuntimeError): + self.hbm.allocate(var_info) + + def test_allocate_with_recycling(self): + """Test allocate with recycling unused addresses.""" + with patch.object(GlobalConfig, "hasHBM", True): + # Fill up the HBM + for i in range(self.hbm_size): + var_info = VariableInfo(f"var{i}") + var_info.uses = 1 if i != 3 else 0 + var_info.last_kernel_used = 1 + self.hbm.forceAllocate(var_info, i) + + # Allocate a new variable - should reuse address 3 + var_info = VariableInfo("test_var") + var_info.last_kernel_used = 2 + self.hbm.allocate(var_info) + self.assertEqual(var_info.hbm_address, 3) + + +class TestMemoryModel(unittest.TestCase): + """Tests for the MemoryModel class.""" + + def setUp(self): + """Set up test fixtures.""" + # Create a mock MemInfo + self.mock_mem_info = MagicMock(spec=mem_info.MemInfo) + + # Set up mock input variables + self.input_var = MagicMock(spec=mem_info.MemInfoVariable) + self.input_var.var_name = "input_var" + self.input_var.hbm_address = 1 + + # Set up mock output variables + self.output_var = MagicMock(spec=mem_info.MemInfoVariable) + self.output_var.var_name = "output_var" + self.output_var.hbm_address = 2 + + # Set up mock keygen variables + self.keygen_var = MagicMock(spec=mem_info.MemInfoVariable) + self.keygen_var.var_name = "keygen_var" + self.keygen_var.hbm_address = 3 + + # Set up mock metadata variables + self.meta_var = MagicMock(spec=mem_info.MemInfoVariable) + self.meta_var.var_name = "meta_var" + self.meta_var.hbm_address = 4 + + # Configure mock mem_info + self.mock_mem_info.inputs = [self.input_var] + self.mock_mem_info.outputs = [self.output_var] + self.mock_mem_info.keygens = [self.keygen_var] + + # Configure metadata + mock_metadata = MagicMock() + mock_metadata.intt_auxiliary_table = [self.meta_var] + mock_metadata.intt_routing_table = [] + mock_metadata.ntt_auxiliary_table = [] + mock_metadata.ntt_routing_table = [] + mock_metadata.ones = [] + mock_metadata.twiddle = [] + mock_metadata.keygen_seeds = [] + self.mock_mem_info.metadata = mock_metadata + + # Create the memory model + self.memory_model = MemoryModel(10, self.mock_mem_info) + + def test_init(self): + """Test initialization of MemoryModel.""" + self.assertIsInstance(self.memory_model.hbm, HBM) + self.assertEqual(self.memory_model.hbm.capacity, 10) + + # Check that variables are correctly initialized + self.assertEqual(len(self.memory_model.variables), 0) + + # Check mem_info_vars + self.assertIn("input_var", self.memory_model.mem_info_vars) + self.assertIn("output_var", self.memory_model.mem_info_vars) + self.assertIn("meta_var", self.memory_model.mem_info_vars) + self.assertNotIn("keygen_var", self.memory_model.mem_info_vars) + + # Check mem_info_meta + self.assertIn("meta_var", self.memory_model.mem_info_meta) + + def test_add_variable_new(self): + """Test adding a new variable.""" + self.memory_model.addVariable("test_var") + + # Check that variable was added + self.assertIn("test_var", self.memory_model.variables) + var_info = self.memory_model.variables["test_var"] + self.assertEqual(var_info.var_name, "test_var") + self.assertEqual(var_info.uses, 1) + + # Since it's not in mem_info_vars, it should not have an HBM address yet + self.assertEqual(var_info.hbm_address, -1) + + def test_add_variable_existing(self): + """Test adding an existing variable.""" + # Add the variable first + self.memory_model.addVariable("test_var") + + # Add it again + self.memory_model.addVariable("test_var") + + # Check that the uses were incremented + var_info = self.memory_model.variables["test_var"] + self.assertEqual(var_info.uses, 2) + + def test_add_variable_from_mem_info(self): + """Test adding a variable that's in mem_info.""" + self.memory_model.addVariable("input_var") + + # Check that variable was added + self.assertIn("input_var", self.memory_model.variables) + var_info = self.memory_model.variables["input_var"] + self.assertEqual(var_info.var_name, "input_var") + self.assertEqual(var_info.uses, 1) + + # It should have the HBM address from mem_info + self.assertEqual(var_info.hbm_address, 1) + + def test_add_variable_from_fixed_addr_vars(self): + """Test adding a variable that's in fixed_addr_vars.""" + self.memory_model.addVariable("output_var") + + # Check that variable was added + self.assertIn("output_var", self.memory_model.variables) + var_info = self.memory_model.variables["output_var"] + + # It should have infinite uses (float('inf')) + self.assertEqual(var_info.uses, float("inf") + 1) + + # It should have the HBM address from mem_info + self.assertEqual(var_info.hbm_address, 2) + + def test_use_variable(self): + """Test using a variable.""" + # Add the variable first + self.memory_model.addVariable("test_var") + + # Use the variable + hbm_address = self.memory_model.useVariable("test_var", 1) + + # Check that uses were decremented + var_info = self.memory_model.variables["test_var"] + self.assertEqual(var_info.uses, 0) + + # Check that last_kernel_used was updated + self.assertEqual(var_info.last_kernel_used, 1) + + # Check that hbm_address was allocated and returned + self.assertGreaterEqual(hbm_address, 0) + self.assertEqual(var_info.hbm_address, hbm_address) + + # Check that the variable is in the HBM buffer + self.assertEqual(self.memory_model.hbm.buffer[hbm_address], var_info) + + def test_use_variable_already_allocated(self): + """Test using a variable that already has an HBM address.""" + # Add a variable from mem_info which already has an HBM address + self.memory_model.addVariable("input_var") + + # Use the variable + hbm_address = self.memory_model.useVariable("input_var", 1) + + # Check that the returned HBM address is the one from mem_info + self.assertEqual(hbm_address, 1) + + +if __name__ == "__main__": + unittest.main() diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/__init__.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/__init__.py new file mode 100644 index 00000000..58bdd8eb --- /dev/null +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/__init__.py @@ -0,0 +1,4 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +# Empty init file to make the directory a Python package diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/__init__.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/__init__.py new file mode 100644 index 00000000..58bdd8eb --- /dev/null +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/__init__.py @@ -0,0 +1,4 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +# Empty init file to make the directory a Python package diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_dinstruction.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_dinstruction.py new file mode 100644 index 00000000..b0e7219a --- /dev/null +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_dinstruction.py @@ -0,0 +1,95 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +""" +Unit tests for the DInstruction base class. + +This module tests the core functionality of the DInstruction class which +serves as the base for all data instructions. +""" + +import unittest + +from linker.instructions.dinst.dinstruction import DInstruction + + +class TestDInstruction(unittest.TestCase): + """ + Test cases for the DInstruction base class. + + These tests verify the common functionality shared by all data instructions, + including token handling, ID generation, and property access. + """ + + def setUp(self): + # Create a concrete subclass for testing since DInstruction is abstract + class ConcreteDInstruction(DInstruction): + """ + Concrete implementation of DInstruction for testing purposes. + + This class provides implementations of the abstract methods + required to instantiate and test the DInstruction class. + """ + + @classmethod + def _get_num_tokens(cls) -> int: + return 3 + + @classmethod + def _get_name(cls) -> str: + return "test_instruction" + + self.d_instruction_class = ConcreteDInstruction # Changed to snake_case + self.tokens = ["test_instruction", "var1", "123"] + self.comment = "Test comment" + self.dinst = self.d_instruction_class(self.tokens, self.comment) + + def test_get_name_token_index(self): + """Test _get_name_token_index returns 0""" + self.assertEqual( + self.d_instruction_class.name_token_index, 0 + ) # Updated reference + + def test_num_tokens_property(self): + """Test num_tokens property returns expected value""" + self.assertEqual(self.d_instruction_class.num_tokens, 3) # Updated reference + + def test_initialization_valid_tokens(self): + """Test initialization with valid tokens""" + inst = self.d_instruction_class(self.tokens, self.comment) + self.assertEqual(inst.tokens, self.tokens) + self.assertEqual(inst.comment, self.comment) + self.assertIsNotNone(inst.id) + + def test_initialization_token_count_too_few(self): + """Test initialization with too few tokens""" + with self.assertRaises(ValueError): + self.d_instruction_class(["test_instruction", "var1"]) + + def test_initialization_invalid_name(self): + """Test initialization with invalid name token""" + with self.assertRaises(ValueError): + self.d_instruction_class(["wrong_name", "var1", "123"]) + + def test_id_property(self): + """Test id property returns a unique id""" + inst1 = self.d_instruction_class(self.tokens) + inst2 = self.d_instruction_class(self.tokens) + self.assertNotEqual(inst1.id, inst2.id) + + def test_to_line_method(self): + """Test to_line method returns expected string""" + tokens = ["test_instruction", "var1", "123"] + inst = self.d_instruction_class(tokens, "") + expected = "test_instruction, var1, 123" + self.assertEqual(inst.to_line(), expected) + + def test_consecutive_ids(self): + """Test that consecutive instructions get incremental ids""" + inst1 = self.d_instruction_class(self.tokens) + inst2 = self.d_instruction_class(self.tokens) + self.assertEqual(inst2.id, inst1.id + 1) + + +if __name__ == "__main__": + unittest.main() diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_dkeygen.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_dkeygen.py new file mode 100644 index 00000000..91a6f207 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_dkeygen.py @@ -0,0 +1,116 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +""" +Unit tests for the DKeygen instruction class. + +This module tests the functionality of the DKeygen instruction which is +responsible for key generation operations. +""" + +import unittest +from unittest.mock import patch + +from assembler.memory_model.mem_info import MemInfo +from linker.instructions.dinst.dkeygen import Instruction + + +class TestDKeygenInstruction(unittest.TestCase): + """ + Test cases for the DKeygen instruction class. + + These tests verify that the DKeygen instruction correctly handles token + parsing, name resolution, and serialization. + """ + + def setUp(self): + # Create the instruction with sample parameters + self.seed_idx = 1 + self.key_idx = 2 + self.var_name = "var1" + self.inst = Instruction( + [Instruction.name, self.seed_idx, self.key_idx, self.var_name] + ) + + def test_get_num_tokens(self): + """Test that _get_num_tokens returns 4""" + self.assertEqual(Instruction.num_tokens, 4) + + def test_get_name(self): + """Test that _get_name returns the expected value""" + self.assertEqual(Instruction.name, MemInfo.Const.Keyword.KEYGEN) + + def test_initialization_valid_input(self): + """Test that initialization can set up the correct properties with valid name""" + inst = Instruction( + [MemInfo.Const.Keyword.KEYGEN, self.seed_idx, self.key_idx, self.var_name] + ) + self.assertEqual(inst.name, MemInfo.Const.Keyword.KEYGEN) + + def test_initialization_invalid_name(self): + """Test that initialization raises exception with invalid name""" + with self.assertRaises(ValueError): # Adjust exception type if needed + Instruction(["invalid_name", self.seed_idx, self.key_idx, self.var_name]) + + def test_tokens_property(self): + """Test that tokens property returns the correct list""" + # Since tokens property implementation is not visible in the dkeygen.py file, + # this test assumes default behavior from parent class or basic functionality + expected_tokens = [ + MemInfo.Const.Keyword.KEYGEN, + self.seed_idx, + self.key_idx, + self.var_name, + ] + self.assertEqual(self.inst.tokens[:4], expected_tokens) + + def test_tokens_with_additional_data(self): + """Test tokens property with additional tokens""" + additional_token = "extra" + inst_with_extra = Instruction( + [ + Instruction.name, + self.seed_idx, + self.key_idx, + self.var_name, + additional_token, + ] + ) + # If tokens property uses default implementation, it should include the additional token + self.assertIn(additional_token, inst_with_extra.tokens) + + @patch( + "linker.instructions.dinst.dinstruction.DInstruction.__init__", + return_value=None, + ) + def test_inheritance(self, mock_init): + """Test that Instruction properly extends DInstruction""" + # Ensure that DInstruction methods are called as expected + Instruction([Instruction.name, self.seed_idx, self.key_idx, self.var_name]) + # Verify DInstruction.__init__ was called + mock_init.assert_called() + + def test_invalid_token_count_too_few(self): + """Test behavior when fewer tokens than required are provided""" + with self.assertRaises(ValueError): # Adjust exception type if needed + Instruction([MemInfo.Const.Keyword.KEYGEN, self.seed_idx, self.key_idx]) + + def test_invalid_token_count_too_many(self): + """Test behavior when more tokens than required are provided""" + # This should not raise an error as additional tokens are handled + inst = Instruction( + [ + MemInfo.Const.Keyword.KEYGEN, + self.seed_idx, + self.key_idx, + self.var_name, + "extra1", + "extra2", + ] + ) + # Check that basic properties are still set correctly + self.assertEqual(inst.name, MemInfo.Const.Keyword.KEYGEN) + + +if __name__ == "__main__": + unittest.main() diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_dload.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_dload.py new file mode 100644 index 00000000..36e20e6d --- /dev/null +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_dload.py @@ -0,0 +1,132 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +""" +Unit tests for the DLoad instruction class. + +This module tests the functionality of the DLoad instruction which is +responsible for loading data from memory locations. +""" + +import unittest +from unittest.mock import patch + +from assembler.memory_model.mem_info import MemInfo +from linker.instructions.dinst.dload import Instruction + + +class TestDLoadInstruction(unittest.TestCase): + """ + Test cases for the DLoad instruction class. + + These tests verify that the DLoad instruction correctly handles token + parsing, name resolution, and serialization. + """ + + def setUp(self): + # Create the instruction + self.var_name = "test_var" + self.address = 123 + self.type = "type1" + + def test_get_num_tokens(self): + """Test that _get_num_tokens returns 3""" + self.assertEqual(Instruction.num_tokens, 3) + + def test_get_name(self): + """Test that _get_name returns the expected value""" + self.assertEqual(Instruction.name, MemInfo.Const.Keyword.LOAD) + + def test_initialization_valid_input(self): + """Test that initialization can set up the correct properties with valid name""" + inst = Instruction( + [MemInfo.Const.Keyword.LOAD, self.type, str(self.address), self.var_name] + ) + + self.assertEqual(inst.name, MemInfo.Const.Keyword.LOAD) + + def test_initialization_valid_meta(self): + """Test that initialization can set up the correct properties with valid name""" + inst = Instruction([MemInfo.Const.Keyword.LOAD, self.type, str(self.address)]) + + self.assertEqual(inst.name, MemInfo.Const.Keyword.LOAD) + + def test_initialization_invalid_name(self): + """Test that initialization raises exception with invalid name""" + with self.assertRaises(ValueError): # Adjust exception type if needed + Instruction(["invalid_name", self.type, str(self.address), self.var_name]) + + def test_tokens_property(self): + """Test that tokens property returns the correct list""" + expected_tokens = [ + MemInfo.Const.Keyword.LOAD, + self.type, + str(self.address), + self.var_name, + ] + inst = Instruction( + [Instruction.name, self.type, str(self.address), self.var_name] + ) + + # Manually set properties to match expected behavior + inst.address = self.address + self.assertEqual(inst.tokens, expected_tokens) + + def test_tokens_with_additional_data(self): + """Test tokens property with additional tokens""" + additional_token = "extra" + inst_with_extra = Instruction( + [ + Instruction.name, + self.type, + str(self.address), + self.var_name, + additional_token, + ] + ) + inst_with_extra.address = self.address + expected_tokens = [ + MemInfo.Const.Keyword.LOAD, + self.type, + str(self.address), + self.var_name, + additional_token, + ] + self.assertEqual(inst_with_extra.tokens, expected_tokens) + + @patch( + "linker.instructions.dinst.dinstruction.DInstruction.__init__", + return_value=None, + ) + def test_inheritance(self, mock_init): + """Test that Instruction properly extends DInstruction""" + # Ensure that DInstruction methods are called as expected + Instruction([Instruction.name, self.type, str(self.address), self.var_name]) + # Verify DInstruction.__init__ was called + mock_init.assert_called() + + def test_invalid_token_count_too_few(self): + """Test behavior when fewer tokens than required are provided""" + with self.assertRaises(ValueError): # Adjust exception type if needed + Instruction([MemInfo.Const.Keyword.LOAD, self.var_name]) + + def test_invalid_token_count_too_many(self): + """Test behavior when more tokens than required are provided""" + # This should not raise an error as additional tokens are handled + inst = Instruction( + [ + MemInfo.Const.Keyword.LOAD, + self.type, + str(self.address), + self.var_name, + "extra1", + "extra2", + ] + ) + + # Check that basic properties are still set correctly + self.assertEqual(inst.name, MemInfo.Const.Keyword.LOAD) + + +if __name__ == "__main__": + unittest.main() diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_dstore.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_dstore.py new file mode 100644 index 00000000..71762da1 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_dstore.py @@ -0,0 +1,122 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +""" +Unit tests for the DStore instruction class. + +This module tests the functionality of the DStore instruction which is +responsible for storing data to memory locations. +""" + +import unittest +from unittest.mock import patch + +from assembler.memory_model.mem_info import MemInfo +from linker.instructions.dinst.dstore import Instruction + + +class TestDStoreInstruction(unittest.TestCase): + """ + Test cases for the DStore instruction class. + + These tests verify that the DStore instruction correctly handles token + parsing, name resolution, and serialization. + """ + + def setUp(self): + # Create the instruction + self.var_name = "test_var" + self.address = 123 + + def test_get_num_tokens(self): + """Test that _get_num_tokens returns 3""" + self.assertEqual(Instruction.num_tokens, 3) + + def test_get_name(self): + """Test that _get_name returns the expected value""" + self.assertEqual(Instruction.name, MemInfo.Const.Keyword.STORE) + + def test_initialization_valid_input(self): + """Test that initialization can set up the correct properties with valid name""" + inst = Instruction( + [MemInfo.Const.Keyword.STORE, self.var_name, str(self.address)] + ) + + self.assertEqual(inst.name, MemInfo.Const.Keyword.STORE) + + def test_initialization_invalid_name(self): + """Test that initialization raises exception with invalid name""" + with self.assertRaises(ValueError): # Adjust exception type if needed + Instruction(["invalid_name", self.var_name, str(self.address)]) + + def test_tokens_property(self): + """Test that tokens property returns the correct list""" + expected_tokens = [ + MemInfo.Const.Keyword.STORE, + self.var_name, + str(self.address), + ] + inst = Instruction([Instruction.name, self.var_name, str(self.address)]) + + # Manually set properties to match expected behavior + inst.var = self.var_name + inst.address = self.address + + self.assertEqual(inst.tokens, expected_tokens) + + def test_tokens_with_additional_data(self): + """Test tokens property with additional tokens""" + additional_token = "extra" + inst_with_extra = Instruction( + [ + Instruction.name, + self.var_name, + str(self.address), + additional_token, + ] + ) + inst_with_extra.var = self.var_name + inst_with_extra.address = self.address + expected_tokens = [ + MemInfo.Const.Keyword.STORE, + self.var_name, + str(self.address), + additional_token, + ] + self.assertEqual(inst_with_extra.tokens, expected_tokens) + + @patch( + "linker.instructions.dinst.dinstruction.DInstruction.__init__", + return_value=None, + ) + def test_inheritance(self, mock_init): + """Test that Instruction properly extends DInstruction""" + # Ensure that DInstruction methods are called as expected + Instruction([Instruction.name, self.var_name, str(self.address)]) + # Verify DInstruction.__init__ was called + mock_init.assert_called() + + def test_invalid_token_count_too_few(self): + """Test behavior when fewer tokens than required are provided""" + with self.assertRaises(ValueError): # Adjust exception type if needed + Instruction([MemInfo.Const.Keyword.STORE]) + + def test_invalid_token_count_too_many(self): + """Test behavior when more tokens than required are provided""" + # This should not raise an error as additional tokens are handled + inst = Instruction( + [ + MemInfo.Const.Keyword.STORE, + self.var_name, + str(self.address), + "extra1", + "extra2", + ] + ) + + # Check that basic properties are still set correctly + self.assertEqual(inst.name, MemInfo.Const.Keyword.STORE) + + +if __name__ == "__main__": + unittest.main() diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_init.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_init.py new file mode 100644 index 00000000..eb9579fb --- /dev/null +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_init.py @@ -0,0 +1,165 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +""" +Unit tests for the dinst package initialization module. + +This module tests the factory functions and initialization utilities for +data instructions. +""" + +import unittest +from unittest.mock import patch, MagicMock + +from linker.instructions.dinst import factory, create_from_mem_line +from linker.instructions.dinst import DLoad, DStore, DKeyGen + + +class TestDInstModule(unittest.TestCase): + """ + Test cases for data instruction initialization. + + These tests verify that the data instruction factory correctly creates + instruction instances and properly handles initialization errors. + """ + + def test_factory(self): + """Test that factory returns the expected set of instruction classes""" + instruction_set = factory() + self.assertIsInstance(instruction_set, set) + self.assertEqual(len(instruction_set), 3) + self.assertIn(DLoad, instruction_set) + self.assertIn(DStore, instruction_set) + self.assertIn(DKeyGen, instruction_set) + + @patch("assembler.instructions.tokenize_from_line") + @patch("assembler.memory_model.mem_info.MemInfo.get_meminfo_var_from_tokens") + def test_create_from_mem_line_dload_input(self, mock_get_meminfo, mock_tokenize): + """Test create_from_mem_line creates DLoad instruction""" + # Setup mocks + tokens = ["dload", "poly", "0x123", "var1"] + comment = "Test comment" + mock_tokenize.return_value = (tokens, comment) + + # Setup MemInfo mock + miv_mock = MagicMock() + miv_mock.as_dict.return_value = {"var_name": "var1", "hbm_address": 0x123} + mock_get_meminfo.return_value = (miv_mock, None) + + # Call function under test + result = create_from_mem_line("dload, poly, 0x123, var1 # Test comment") + + # Verify results + self.assertIsNotNone(result) + self.assertIsInstance(result, DLoad) + self.assertEqual(result.var, "var1") + self.assertEqual(result.address, 0x123) + + @patch("assembler.instructions.tokenize_from_line") + @patch("assembler.memory_model.mem_info.MemInfo.get_meminfo_var_from_tokens") + def test_create_from_mem_line_dload_meta(self, mock_get_meminfo, mock_tokenize): + """Test create_from_mem_line creates DLoad instruction""" + # Setup mocks + tokens = ["dload", "meta", "1"] + comment = "Test comment" + mock_tokenize.return_value = (tokens, comment) + + # Setup MemInfo mock + miv_mock = MagicMock() + miv_mock.as_dict.return_value = {"var_name": "meta1", "hbm_address": 1} + mock_get_meminfo.return_value = (miv_mock, None) + + # Call function under test + result = create_from_mem_line("dload, meta, 1 # Test comment") + + # Verify results + self.assertIsNotNone(result) + self.assertIsInstance(result, DLoad) + self.assertEqual(result.var, "meta1") + self.assertEqual(result.address, 1) + + @patch("assembler.instructions.tokenize_from_line") + @patch("assembler.memory_model.mem_info.MemInfo.get_meminfo_var_from_tokens") + def test_create_from_mem_line_dstore(self, mock_get_meminfo, mock_tokenize): + """Test create_from_mem_line creates DStore instruction""" + # Setup mocks + tokens = ["dstore", "var1", "0x456"] + comment = "Test comment" + mock_tokenize.return_value = (tokens, comment) + + # Setup MemInfo mock + miv_mock = MagicMock() + miv_mock.as_dict.return_value = {"var_name": "var1", "hbm_address": 0x456} + mock_get_meminfo.return_value = (miv_mock, None) + + # Call function under test + result = create_from_mem_line("dstore, var1, 0x456 # Test comment") + + # Verify results + self.assertIsNotNone(result) + self.assertIsInstance(result, DStore) + self.assertEqual(result.var, "var1") + self.assertEqual(result.address, 0x456) + + @patch("assembler.instructions.tokenize_from_line") + @patch("assembler.memory_model.mem_info.MemInfo.get_meminfo_var_from_tokens") + def test_create_from_mem_line_dkeygen(self, mock_get_meminfo, mock_tokenize): + """Test create_from_mem_line creates DKeyGen instruction""" + # Setup mocks + tokens = ["keygen", "key1", "type1", "256"] + comment = "Test comment" + mock_tokenize.return_value = (tokens, comment) + + # Setup MemInfo mock + miv_mock = MagicMock() + miv_mock.as_dict.return_value = {"var_name": "key1", "hbm_address": 0x0} + mock_get_meminfo.return_value = (miv_mock, None) + + # Call function under test + result = create_from_mem_line("keygen, key1, type1, 256 # Test comment") + + # Verify results + self.assertIsNotNone(result) + self.assertIsInstance(result, DKeyGen) + # Verify var and address were set correctly + self.assertEqual(result.var, "key1") + self.assertEqual(result.address, 0x0) + + @patch("assembler.instructions.tokenize_from_line") + @patch("assembler.memory_model.mem_info.MemInfo.get_meminfo_var_from_tokens") + def test_create_from_mem_line_invalid(self, mock_get_meminfo, mock_tokenize): + """Test create_from_mem_line with invalid instruction""" + # Setup mocks to return invalid tokens + tokens = ["invalid_instruction", "var1", "0x123"] + comment = "" + mock_tokenize.return_value = (tokens, comment) + + # Make get_meminfo_var_from_tokens raise RuntimeError + mock_get_meminfo.side_effect = RuntimeError("Invalid instruction") + + # This should raise RuntimeError due to no valid instruction found + with self.assertRaises(RuntimeError): + create_from_mem_line("invalid_instruction, var1, 0x123") + + @patch("assembler.instructions.tokenize_from_line") + @patch("assembler.memory_model.mem_info.MemInfo.get_meminfo_var_from_tokens") + def test_create_from_mem_line_meminfo_error(self, mock_get_meminfo, mock_tokenize): + """Test create_from_mem_line with MemInfo error""" + # Setup mocks + tokens = ["dstore", "var1", "0x123"] + comment = "" + mock_tokenize.return_value = (tokens, comment) + + # Make get_meminfo_var_from_tokens raise RuntimeError + mock_get_meminfo.side_effect = RuntimeError("Test error") + + # This should wrap the RuntimeError with information about the line + with self.assertRaises(RuntimeError) as context: + create_from_mem_line("dstore, var1, 0x123") + + # Verify the error message contains the original line + self.assertIn("dstore, var1, 0x123", str(context.exception)) + + +if __name__ == "__main__": + unittest.main() diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_init.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_init.py new file mode 100644 index 00000000..ea2e11b8 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_init.py @@ -0,0 +1,122 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +""" +Unit tests for the linker instructions initialization module. + +This module contains tests that verify the behavior of the instruction factory +and initialization functionality. +""" + +import unittest +from unittest.mock import patch, MagicMock + +from linker.instructions import create_from_str_line + + +class TestCreateFromStrLine(unittest.TestCase): + """ + Test cases for instruction initialization functionality. + + These tests verify that instructions are correctly initialized, + their tokens are properly processed, and their factories work as expected. + """ + + def setUp(self): + # Create a mock class (not instance) + self.mock_class = MagicMock() + + # Create a mock instance that will be returned when the class is called + self.mock_instance = MagicMock() + self.mock_instance.__bool__.return_value = True + + # Configure the class to return the instance when called + self.mock_class.return_value = self.mock_instance + + # Create a factory with the mock class + self.factory = {self.mock_class} + + @patch("linker.instructions.tokenize_from_line") + def test_create_from_str_line_success(self, mock_tokenize): + """Test successful instruction creation""" + # Setup mock + tokens = ["instruction", "arg1", "arg2"] + comment = "Test comment" + mock_tokenize.return_value = (tokens, comment) + + # Call function + result = create_from_str_line( + "instruction, arg1, arg2 # Test comment", self.factory + ) + + # Verify + mock_tokenize.assert_called_once_with("instruction, arg1, arg2 # Test comment") + self.mock_class.assert_called_once_with(tokens, comment) + self.assertEqual(result, self.mock_instance) + + @patch("linker.instructions.tokenize_from_line") + def test_create_from_str_line_failure(self, mock_tokenize): + """Test when no instruction can be created""" + # Setup mock + tokens = ["unknown", "arg1", "arg2"] + comment = "Test comment" + mock_tokenize.return_value = (tokens, comment) + + # Make instruction creation fail + self.mock_class.side_effect = ValueError("Invalid instruction") + + # Call function + result = create_from_str_line( + "unknown, arg1, arg2 # Test comment", self.factory + ) + + # Verify + self.assertIsNone(result) + + @patch("linker.instructions.tokenize_from_line") + def test_create_from_str_line_multiple_instruction_types(self, mock_tokenize): + """Test with multiple instruction types in factory""" + # Setup mocks + tokens = ["instruction", "arg1", "arg2"] + comment = "Test comment" + mock_tokenize.return_value = (tokens, comment) + + # Create a second mock instruction class that fails + mock_class2 = MagicMock() + mock_class2.side_effect = ValueError("Invalid instruction") + + # Set up the factory with a specific order - use a list instead of a set + # to control the iteration order + factory = [mock_class2, self.mock_class] + + # Call function + result = create_from_str_line("instruction, arg1, arg2 # Test comment", factory) + + # Verify that it tried both instruction types and returned the successful one + mock_class2.assert_called_once_with(tokens, comment) + self.mock_class.assert_called_once_with(tokens, comment) + mock_class2.assert_called_once() + self.assertEqual(result, self.mock_instance) + + @patch("linker.instructions.tokenize_from_line") + def test_create_from_str_line_exception_handling(self, mock_tokenize): + """Test that general exceptions are caught""" + # Setup mock + tokens = ["instruction", "arg1", "arg2"] + comment = "Test comment" + mock_tokenize.return_value = (tokens, comment) + + # Make instruction creation raise a different exception + self.mock_class.side_effect = Exception("Unexpected error") + + # Call function - should handle the exception and return None + result = create_from_str_line( + "instruction, arg1, arg2 # Test comment", self.factory + ) + + # Verify + self.assertIsNone(result) + + +if __name__ == "__main__": + unittest.main() diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_instruction.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_instruction.py new file mode 100644 index 00000000..604e872e --- /dev/null +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_instruction.py @@ -0,0 +1,132 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +""" +Unit tests for the BaseInstruction class. +""" + +import os +import unittest +import tempfile +from unittest.mock import patch + +from assembler.common.config import GlobalConfig +from linker.instructions.instruction import BaseInstruction + + +class MockInstruction(BaseInstruction): + """Concrete implementation of BaseInstruction for testing.""" + + @classmethod + def _get_name(cls) -> str: + return "TEST" + + @classmethod + def _get_name_token_index(cls) -> int: + return 0 + + @classmethod + def _get_num_tokens(cls) -> int: + return 3 + + +class TestBaseInstruction(unittest.TestCase): + """Tests for the BaseInstruction class.""" + + def setUp(self): + """Setup for tests.""" + self.valid_tokens = ["TEST", "arg1", "arg2"] + self.comment = "This is a test comment" + + def test_init_valid(self): + """Test initialization with valid tokens.""" + instruction = MockInstruction(self.valid_tokens, self.comment) + self.assertEqual(instruction.tokens, self.valid_tokens) + self.assertEqual(instruction.comment, self.comment) + + def test_init_invalid_name(self): + """Test initialization with invalid instruction name.""" + invalid_tokens = ["WRONG", "arg1", "arg2"] + with self.assertRaises(ValueError) as context: + MockInstruction(invalid_tokens) + self.assertIn("invalid name", str(context.exception)) + + def test_init_invalid_num_tokens(self): + """Test initialization with incorrect number of tokens.""" + invalid_tokens = ["TEST", "arg1"] + with self.assertRaises(ValueError) as context: + MockInstruction(invalid_tokens) + self.assertIn("invalid amount of tokens", str(context.exception)) + + def test_id_generation(self): + """Test that each instruction gets a unique ID.""" + instruction1 = MockInstruction(self.valid_tokens) + instruction2 = MockInstruction(self.valid_tokens) + self.assertNotEqual(instruction1.id, instruction2.id) + + def test_str_representation(self): + """Test string representation.""" + instruction = MockInstruction(self.valid_tokens) + self.assertEqual(str(instruction), f"TEST({instruction.id})") + + def test_repr_representation(self): + """Test repr representation.""" + instruction = MockInstruction(self.valid_tokens) + self.assertIn("MockInstruction(TEST, id=", repr(instruction)) + self.assertIn("tokens=", repr(instruction)) + + def test_equality(self): + """Test equality operator.""" + instruction1 = MockInstruction(self.valid_tokens) + instruction2 = MockInstruction(self.valid_tokens) + self.assertNotEqual(instruction1, instruction2) + self.assertEqual(instruction1, instruction1) + + def test_hash(self): + """Test hash function.""" + instruction = MockInstruction(self.valid_tokens) + self.assertEqual(hash(instruction), hash(instruction.id)) + + def test_to_line_with_comment(self): + """Test to_line method with comment.""" + instruction = MockInstruction(self.valid_tokens, self.comment) + expected = f"TEST, arg1, arg2 # {self.comment}" + self.assertEqual(instruction.to_line(), expected) + + def test_to_line_without_comment(self): + """Test to_line method without comment.""" + instruction = MockInstruction(self.valid_tokens) + expected = "TEST, arg1, arg2" + self.assertEqual(instruction.to_line(), expected) + + def test_to_line_suppressed_comments(self): + """Test to_line method with suppressed comments.""" + with patch.object(GlobalConfig, "suppress_comments", True): + instruction = MockInstruction(self.valid_tokens, self.comment) + expected = "TEST, arg1, arg2" + self.assertEqual(instruction.to_line(), expected) + + def test_dump_instructions_to_file(self): + """Test dump_instructions_to_file method.""" + instruction1 = MockInstruction(self.valid_tokens, "Comment 1") + instruction2 = MockInstruction(self.valid_tokens, "Comment 2") + instructions = [instruction1, instruction2] + + with tempfile.NamedTemporaryFile(delete=False) as temp_file: + file_path = temp_file.name + + try: + BaseInstruction.dump_instructions_to_file(instructions, file_path) + + with open(file_path, "r", encoding="utf-8") as f: + lines = f.read().splitlines() + + self.assertEqual(len(lines), 2) + self.assertEqual(lines[0], instruction1.to_line()) + self.assertEqual(lines[1], instruction2.to_line()) + finally: + os.unlink(file_path) + + +if __name__ == "__main__": + unittest.main() diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_loader.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_loader.py new file mode 100644 index 00000000..2c6d1728 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_loader.py @@ -0,0 +1,317 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +""" +Unit tests for the loader module. +""" + +import unittest +from unittest.mock import patch, mock_open, MagicMock, call + +from linker.loader import ( + load_minst_kernel, + load_minst_kernel_from_file, + load_cinst_kernel, + load_cinst_kernel_from_file, + load_xinst_kernel, + load_xinst_kernel_from_file, + load_dinst_kernel, + load_dinst_kernel_from_file, +) + + +class TestLoader(unittest.TestCase): + """Tests for the loader module functions.""" + + def setUp(self): + """Set up test fixtures.""" + # Sample instruction lines for each type + self.minst_lines = ["MINST arg1, arg2", "MINST arg3, arg4"] + self.cinst_lines = ["CINST arg1, arg2", "CINST arg3, arg4"] + self.xinst_lines = ["XINST arg1, arg2", "XINST arg3, arg4"] + self.dinst_lines = ["DINST arg1, arg2", "DINST arg3, arg4"] + + # Create mock instruction objects + self.mock_minst = [MagicMock(), MagicMock()] + self.mock_cinst = [MagicMock(), MagicMock()] + self.mock_xinst = [MagicMock(), MagicMock()] + self.mock_dinst = [MagicMock(), MagicMock()] + + @patch("linker.instructions.create_from_str_line") + @patch("linker.instructions.minst.factory") + def test_load_minst_kernel_success(self, mock_factory, mock_create): + """Test successful loading of MInstructions from an iterator.""" + # Configure mocks + mock_factory.return_value = "minst_factory" + mock_create.side_effect = self.mock_minst + + # Call the function + result = load_minst_kernel(self.minst_lines) + + # Verify the results + self.assertEqual(result, self.mock_minst) + # Factory is called once per line, so 2 times total + self.assertEqual(mock_factory.call_count, 2) + self.assertEqual(mock_create.call_count, 2) + mock_create.assert_has_calls( + [ + call(self.minst_lines[0], "minst_factory"), + call(self.minst_lines[1], "minst_factory"), + ] + ) + + @patch("linker.instructions.create_from_str_line") + @patch("linker.instructions.minst.factory") + def test_load_minst_kernel_failure(self, mock_factory, mock_create): + """Test error handling when loading MInstructions fails.""" + # Configure mocks + mock_factory.return_value = "minst_factory" + mock_create.return_value = None + + # Call the function and check for exception + with self.assertRaises(RuntimeError) as context: + load_minst_kernel(self.minst_lines) + + self.assertIn( + f"Error parsing line 1: {self.minst_lines[0]}", str(context.exception) + ) + + @patch("builtins.open", new_callable=mock_open) + @patch("linker.loader.load_minst_kernel") + def test_load_minst_kernel_from_file_success(self, mock_load, mock_file): + """Test successful loading of MInstructions from a file.""" + # Configure mocks + mock_file.return_value.__enter__.return_value = self.minst_lines + mock_load.return_value = self.mock_minst + + # Call the function + result = load_minst_kernel_from_file("test.minst") + + # Verify the results + self.assertEqual(result, self.mock_minst) + mock_file.assert_called_once_with("test.minst", "r") + mock_load.assert_called_once_with(self.minst_lines) + + @patch("builtins.open", new_callable=mock_open) + @patch("linker.loader.load_minst_kernel") + def test_load_minst_kernel_from_file_failure(self, mock_load, mock_file): + """Test error handling when loading MInstructions from a file fails.""" + # Configure mocks + mock_file.return_value.__enter__.return_value = self.minst_lines + mock_load.side_effect = Exception("Test error") + + # Call the function and check for exception + with self.assertRaises(RuntimeError) as context: + load_minst_kernel_from_file("test.minst") + + self.assertIn( + 'Error occurred loading file "test.minst"', str(context.exception) + ) + + @patch("linker.instructions.create_from_str_line") + @patch("linker.instructions.cinst.factory") + def test_load_cinst_kernel_success(self, mock_factory, mock_create): + """Test successful loading of CInstructions from an iterator.""" + # Configure mocks + mock_factory.return_value = "cinst_factory" + mock_create.side_effect = self.mock_cinst + + # Call the function + result = load_cinst_kernel(self.cinst_lines) + + # Verify the results + self.assertEqual(result, self.mock_cinst) + # Factory is called once per line, so 2 times total + self.assertEqual(mock_factory.call_count, 2) + self.assertEqual(mock_create.call_count, 2) + mock_create.assert_has_calls( + [ + call(self.cinst_lines[0], "cinst_factory"), + call(self.cinst_lines[1], "cinst_factory"), + ] + ) + + @patch("linker.instructions.create_from_str_line") + @patch("linker.instructions.cinst.factory") + def test_load_cinst_kernel_failure(self, mock_factory, mock_create): + """Test error handling when loading CInstructions fails.""" + # Configure mocks + mock_factory.return_value = "cinst_factory" + mock_create.return_value = None + + # Call the function and check for exception + with self.assertRaises(RuntimeError) as context: + load_cinst_kernel(self.cinst_lines) + + self.assertIn( + f"Error parsing line 1: {self.cinst_lines[0]}", str(context.exception) + ) + + @patch("builtins.open", new_callable=mock_open) + @patch("linker.loader.load_cinst_kernel") + def test_load_cinst_kernel_from_file_success(self, mock_load, mock_file): + """Test successful loading of CInstructions from a file.""" + # Configure mocks + mock_file.return_value.__enter__.return_value = self.cinst_lines + mock_load.return_value = self.mock_cinst + + # Call the function + result = load_cinst_kernel_from_file("test.cinst") + + # Verify the results + self.assertEqual(result, self.mock_cinst) + mock_file.assert_called_once_with("test.cinst", "r") + mock_load.assert_called_once_with(self.cinst_lines) + + @patch("builtins.open", new_callable=mock_open) + @patch("linker.loader.load_cinst_kernel") + def test_load_cinst_kernel_from_file_failure(self, mock_load, mock_file): + """Test error handling when loading CInstructions from a file fails.""" + # Configure mocks + mock_file.return_value.__enter__.return_value = self.cinst_lines + mock_load.side_effect = Exception("Test error") + + # Call the function and check for exception + with self.assertRaises(RuntimeError) as context: + load_cinst_kernel_from_file("test.cinst") + + self.assertIn( + 'Error occurred loading file "test.cinst"', str(context.exception) + ) + + @patch("linker.instructions.create_from_str_line") + @patch("linker.instructions.xinst.factory") + def test_load_xinst_kernel_success(self, mock_factory, mock_create): + """Test successful loading of XInstructions from an iterator.""" + # Configure mocks + mock_factory.return_value = "xinst_factory" + mock_create.side_effect = self.mock_xinst + + # Call the function + result = load_xinst_kernel(self.xinst_lines) + + # Verify the results + self.assertEqual(result, self.mock_xinst) + # Factory is called once per line, so 2 times total + self.assertEqual(mock_factory.call_count, 2) + self.assertEqual(mock_create.call_count, 2) + mock_create.assert_has_calls( + [ + call(self.xinst_lines[0], "xinst_factory"), + call(self.xinst_lines[1], "xinst_factory"), + ] + ) + + @patch("linker.instructions.create_from_str_line") + @patch("linker.instructions.xinst.factory") + def test_load_xinst_kernel_failure(self, mock_factory, mock_create): + """Test error handling when loading XInstructions fails.""" + # Configure mocks + mock_factory.return_value = "xinst_factory" + mock_create.return_value = None + + # Call the function and check for exception + with self.assertRaises(RuntimeError) as context: + load_xinst_kernel(self.xinst_lines) + + self.assertIn( + f"Error parsing line 1: {self.xinst_lines[0]}", str(context.exception) + ) + + @patch("builtins.open", new_callable=mock_open) + @patch("linker.loader.load_xinst_kernel") + def test_load_xinst_kernel_from_file_success(self, mock_load, mock_file): + """Test successful loading of XInstructions from a file.""" + # Configure mocks + mock_file.return_value.__enter__.return_value = self.xinst_lines + mock_load.return_value = self.mock_xinst + + # Call the function + result = load_xinst_kernel_from_file("test.xinst") + + # Verify the results + self.assertEqual(result, self.mock_xinst) + mock_file.assert_called_once_with("test.xinst", "r") + mock_load.assert_called_once_with(self.xinst_lines) + + @patch("builtins.open", new_callable=mock_open) + @patch("linker.loader.load_xinst_kernel") + def test_load_xinst_kernel_from_file_failure(self, mock_load, mock_file): + """Test error handling when loading XInstructions from a file fails.""" + # Configure mocks + mock_file.return_value.__enter__.return_value = self.xinst_lines + mock_load.side_effect = Exception("Test error") + + # Call the function and check for exception + with self.assertRaises(RuntimeError) as context: + load_xinst_kernel_from_file("test.xinst") + + self.assertIn( + 'Error occurred loading file "test.xinst"', str(context.exception) + ) + + @patch("linker.instructions.dinst.create_from_mem_line") + def test_load_dinst_kernel_success(self, mock_create): + """Test successful loading of DInstructions from an iterator.""" + # Configure mocks + mock_create.side_effect = self.mock_dinst + + # Call the function + result = load_dinst_kernel(self.dinst_lines) + + # Verify the results + self.assertEqual(result, self.mock_dinst) + self.assertEqual(mock_create.call_count, 2) + mock_create.assert_has_calls( + [call(self.dinst_lines[0]), call(self.dinst_lines[1])] + ) + + @patch("linker.instructions.dinst.create_from_mem_line") + def test_load_dinst_kernel_failure(self, mock_create): + """Test error handling when loading DInstructions fails.""" + # Configure mocks + mock_create.return_value = None + + # Call the function and check for exception + with self.assertRaises(RuntimeError) as context: + load_dinst_kernel(self.dinst_lines) + + self.assertIn( + f"Error parsing line 1: {self.dinst_lines[0]}", str(context.exception) + ) + + @patch("builtins.open", new_callable=mock_open) + @patch("linker.loader.load_dinst_kernel") + def test_load_dinst_kernel_from_file_success(self, mock_load, mock_file): + """Test successful loading of DInstructions from a file.""" + # Configure mocks + mock_file.return_value.__enter__.return_value = self.dinst_lines + mock_load.return_value = self.mock_dinst + + # Call the function + result = load_dinst_kernel_from_file("test.dinst") + + # Verify the results + self.assertEqual(result, self.mock_dinst) + mock_file.assert_called_once_with("test.dinst", "r") + mock_load.assert_called_once_with(self.dinst_lines) + + @patch("builtins.open", new_callable=mock_open) + @patch("linker.loader.load_dinst_kernel") + def test_load_dinst_kernel_from_file_failure(self, mock_load, mock_file): + """Test error handling when loading DInstructions from a file fails.""" + # Configure mocks + mock_file.return_value.__enter__.return_value = self.dinst_lines + mock_load.side_effect = Exception("Test error") + + # Call the function and check for exception + with self.assertRaises(RuntimeError) as context: + load_dinst_kernel_from_file("test.dinst") + + self.assertIn( + 'Error occurred loading file "test.dinst"', str(context.exception) + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_steps/__init__.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_steps/__init__.py new file mode 100644 index 00000000..4057dc01 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_steps/__init__.py @@ -0,0 +1,2 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_steps/test_program_linker.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_steps/test_program_linker.py new file mode 100644 index 00000000..5147c332 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_steps/test_program_linker.py @@ -0,0 +1,623 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +""" +Unit tests for the program_linker module. +""" + +import io +import unittest +from unittest.mock import patch, MagicMock, call + +from assembler.common.config import GlobalConfig +from linker import MemoryModel +from linker.instructions import minst, cinst, dinst +from linker.steps.program_linker import LinkedProgram + + +# pylint: disable=protected-access +class TestLinkedProgram(unittest.TestCase): + """Tests for the LinkedProgram class.""" + + def setUp(self): + """Set up test fixtures.""" + # Group related stream objects into a dictionary + self.streams = { + "minst": io.StringIO(), + "cinst": io.StringIO(), + "xinst": io.StringIO(), + } + self.mem_model = MagicMock(spec=MemoryModel) + + # Mock the hasHBM property to return True by default + self.has_hbm_patcher = patch.object(GlobalConfig, "hasHBM", True) + self.mock_has_hbm = self.has_hbm_patcher.start() + + # Mock the suppress_comments property to return False by default + self.suppress_comments_patcher = patch.object( + GlobalConfig, "suppress_comments", False + ) + self.mock_suppress_comments = self.suppress_comments_patcher.start() + + self.program = LinkedProgram( + self.streams["minst"], + self.streams["cinst"], + self.streams["xinst"], + self.mem_model, + ) + + def tearDown(self): + """Tear down test fixtures.""" + self.has_hbm_patcher.stop() + self.suppress_comments_patcher.stop() + + def test_init(self): + """Test initialization of LinkedProgram.""" + self.assertEqual( + self.program._LinkedProgram__minst_ostream, self.streams["minst"] + ) + self.assertEqual( + self.program._LinkedProgram__cinst_ostream, self.streams["cinst"] + ) + self.assertEqual( + self.program._LinkedProgram__xinst_ostream, self.streams["xinst"] + ) + self.assertEqual(self.program._LinkedProgram__mem_model, self.mem_model) + self.assertEqual(self.program._LinkedProgram__bundle_offset, 0) + self.assertEqual(self.program._LinkedProgram__minst_line_offset, 0) + self.assertEqual(self.program._LinkedProgram__cinst_line_offset, 0) + self.assertEqual(self.program._LinkedProgram__kernel_count, 0) + self.assertTrue(self.program._LinkedProgram__is_open) + + def test_is_open_property(self): + """Test the is_open property.""" + self.assertTrue(self.program.is_open) + self.program._LinkedProgram__is_open = False + self.assertFalse(self.program.is_open) + + def test_close(self): + """Test closing the program.""" + self.program.close() + + # Verify cexit and msyncc were added + self.assertIn("cexit", self.streams["cinst"].getvalue().lower()) + self.assertIn("msyncc", self.streams["minst"].getvalue().lower()) + self.assertFalse(self.program.is_open) + + # Test that closing an already closed program raises RuntimeError + with self.assertRaises(RuntimeError): + self.program.close() + + # Clean the StringIO object properly + self.streams["minst"].seek(0) + self.streams["minst"].truncate(0) + self.streams["cinst"].seek(0) + self.streams["cinst"].truncate(0) + self.streams["xinst"].seek(0) + self.streams["xinst"].truncate(0) + + # Test closing the program with comments suppressed. + with patch.object(GlobalConfig, "suppress_comments", True): + program = LinkedProgram( + self.streams["minst"], + self.streams["cinst"], + self.streams["xinst"], + self.mem_model, + ) + program.close() + + # Should not contain "terminating MInstQ" comment + self.assertNotIn("terminating MInstQ", self.streams["minst"].getvalue()) + + def test_validate_hbm_address(self): + """Test validating a HBM address.""" + + # Test validating a valid HBM address + self.mem_model.mem_info_vars = {} + self.program._validate_hbm_address("test_var", 10) + # No exception should be raised + + # Test validating a negative HBM address + with self.assertRaises(RuntimeError): + self.program._validate_hbm_address("test_var", -1) + + def test_validate_hbm_address_mismatch(self): + """Test validating an HBM address that doesn't match the declared address.""" + mock_var = MagicMock() + mock_var.hbm_address = 5 + self.mem_model.mem_info_vars = {"test_var": mock_var} + + with self.assertRaises(RuntimeError): + self.program._validate_hbm_address("test_var", 10) + + def test_validate_spad_address_valid(self): + """Test validating a valid SPAD address with HBM disabled.""" + with patch.object(GlobalConfig, "hasHBM", False): + self.mem_model.mem_info_vars = {} + self.program._validate_spad_address("test_var", 10) + # No exception should be raised + + def test_validate_spad_address_with_hbm_enabled(self): + """Test validating a SPAD address with HBM enabled (should raise AssertionError).""" + with self.assertRaises(AssertionError): + self.program._validate_spad_address("test_var", 10) + + def test_validate_spad_address_negative(self): + """Test validating a negative SPAD address.""" + with patch.object(GlobalConfig, "hasHBM", False): + with self.assertRaises(RuntimeError): + self.program._validate_spad_address("test_var", -1) + + def test_validate_spad_address_mismatch(self): + """Test validating a SPAD address that doesn't match the declared address.""" + with patch.object(GlobalConfig, "hasHBM", False): + mock_var = MagicMock() + mock_var.hbm_address = 5 + self.mem_model.mem_info_vars = {"test_var": mock_var} + + with self.assertRaises(RuntimeError): + self.program._validate_spad_address("test_var", 10) + + def test_update_minsts(self): + """Test updating MInsts.""" + # Create mock MInstructions + mock_msyncc = MagicMock(spec=minst.MSyncc) + mock_msyncc.target = 5 + + mock_mload = MagicMock(spec=minst.MLoad) + mock_mload.source = "input_var" + mock_mload.comment = "original comment" + + mock_mstore = MagicMock(spec=minst.MStore) + mock_mstore.dest = "output_var" + mock_mstore.comment = None + + # Set up memory model mock + self.mem_model.useVariable.side_effect = [ + 10, + 20, + ] # Return different addresses for different vars + + # Execute the update + kernel_minstrs = [mock_msyncc, mock_mload, mock_mstore] + self.program._LinkedProgram__cinst_line_offset = 10 # Set initial offset + self.program._LinkedProgram__kernel_count = 1 # Set kernel count + self.program._update_minsts(kernel_minstrs) + + # Verify results + self.assertEqual(mock_msyncc.target, 15) # 5 + 10 + self.assertEqual(mock_mload.source, "10") # Replaced with HBM address + self.assertIn("input_var", mock_mload.comment) # Comment updated + self.assertIn( + "original comment", mock_mload.comment + ) # Original comment preserved + + self.assertEqual(mock_mstore.dest, "20") # Replaced with HBM address + + # Verify the memory model was used correctly + self.mem_model.useVariable.assert_has_calls( + [call("input_var", 1), call("output_var", 1)] + ) + + def test_remove_and_merge_csyncm_cnop(self): + """Test removing CSyncm instructions and merging CNop instructions.""" + # Create mock CInstructions + mock_ifetch = MagicMock(spec=cinst.IFetch) + mock_ifetch.bundle = 1 + mock_ifetch.tokens = [0] + + mock_csyncm1 = MagicMock(spec=cinst.CSyncm) + mock_csyncm1.tokens = [0] + + mock_cnop1 = MagicMock(spec=cinst.CNop) + mock_cnop1.cycles = 2 + mock_cnop1.tokens = [0] + + mock_csyncm2 = MagicMock(spec=cinst.CSyncm) + mock_csyncm2.tokens = [0] + + mock_cnop2 = MagicMock(spec=cinst.CNop) + mock_cnop2.cycles = 3 + mock_cnop2.tokens = [0] + + # Set up ISACInst.CSyncm.get_throughput + with patch( + "assembler.instructions.cinst.CSyncm.get_throughput", return_value=2 + ): + # Execute the method + kernel_cinstrs = [ + mock_ifetch, + mock_csyncm1, + mock_cnop1, + mock_csyncm2, + mock_cnop2, + ] + self.program._remove_and_merge_csyncm_cnop(kernel_cinstrs) + + # Verify CSyncm instructions were removed + self.assertNotIn(mock_csyncm1, kernel_cinstrs) + self.assertNotIn(mock_csyncm2, kernel_cinstrs) + + # Verify CNop cycles were updated (should have added 2 for each CSyncm) + # First CNop gets 2 cycles added from first CSyncm + self.assertEqual(mock_cnop1.cycles, 4) # 2 + 2 + + # Verify the line numbers were updated + for i, instr in enumerate(kernel_cinstrs): + self.assertEqual(instr.tokens[0], i) + + def test_update_cinsts_addresses_and_offsets(self): + """Test updating CInst addresses and offsets.""" + # Create mock CInstructions + mock_ifetch = MagicMock(spec=cinst.IFetch) + mock_ifetch.bundle = 1 + + mock_csyncm = MagicMock(spec=cinst.CSyncm) + mock_csyncm.target = 5 + + mock_xinstfetch = MagicMock(spec=cinst.XInstFetch) + + # Create SPAD instructions for no-HBM case + mock_bload = MagicMock(spec=cinst.BLoad) + mock_bload.source = "var1" + mock_bload.comment = "original comment" + + mock_cstore = MagicMock(spec=cinst.CStore) + mock_cstore.dest = "var2" + mock_cstore.comment = None + + # Execute the method with HBM enabled + kernel_cinstrs = [mock_ifetch, mock_csyncm] + self.program._LinkedProgram__bundle_offset = 10 + self.program._LinkedProgram__minst_line_offset = 20 + self.program._update_cinsts_addresses_and_offsets(kernel_cinstrs) + + # Verify results with HBM enabled + self.assertEqual(mock_ifetch.bundle, 11) # 1 + 10 + self.assertEqual(mock_csyncm.target, 25) # 5 + 20 + + # Test with HBM disabled + with patch.object(GlobalConfig, "hasHBM", False): + # Set up memory model mock + self.mem_model.useVariable.side_effect = [ + 30, + 40, + ] # Return different addresses for different vars + + kernel_cinstrs = [mock_bload, mock_cstore] + self.program._LinkedProgram__kernel_count = 2 + self.program._update_cinsts_addresses_and_offsets(kernel_cinstrs) + + # Verify SPAD instructions were updated + self.assertEqual(mock_bload.source, "30") + self.assertIn("var1", mock_bload.comment) + self.assertIn("original comment", mock_bload.comment) + + self.assertEqual(mock_cstore.dest, "40") + + # Verify the memory model was used correctly + self.mem_model.useVariable.assert_has_calls( + [call("var1", 2), call("var2", 2)] + ) + + # Test that XInstFetch raises NotImplementedError + with self.assertRaises(NotImplementedError): + self.program._update_cinsts_addresses_and_offsets([mock_xinstfetch]) + + def test_update_cinsts(self): + """Test updating CInsts.""" + # Create a mock for _remove_and_merge_csyncm_cnop and _update_cinsts_addresses_and_offsets + with patch.object( + LinkedProgram, "_remove_and_merge_csyncm_cnop" + ) as mock_remove, patch.object( + LinkedProgram, "_update_cinsts_addresses_and_offsets" + ) as mock_update: + + # Execute the method with HBM enabled + kernel_cinstrs = ["cinst1", "cinst2"] + self.program._update_cinsts(kernel_cinstrs) + + # Verify that only _update_cinsts_addresses_and_offsets was called + mock_remove.assert_not_called() + mock_update.assert_called_once_with(kernel_cinstrs) + + # Reset mocks + mock_remove.reset_mock() + mock_update.reset_mock() + + # Execute the method with HBM disabled + with patch.object(GlobalConfig, "hasHBM", False): + self.program._update_cinsts(kernel_cinstrs) + + # Verify that both methods were called + mock_remove.assert_called_once_with(kernel_cinstrs) + mock_update.assert_called_once_with(kernel_cinstrs) + + def test_update_xinsts(self): + """Test updating XInsts.""" + # Create mock XInstructions + mock_xinst1 = MagicMock() + mock_xinst1.bundle = 1 + + mock_xinst2 = MagicMock() + mock_xinst2.bundle = 2 + + mock_xinst3 = MagicMock() + mock_xinst3.bundle = 0 # Will cause an error when updated after mock_xinst2 + + # Execute the method + kernel_xinstrs = [mock_xinst1, mock_xinst2] + self.program._LinkedProgram__bundle_offset = 10 + last_bundle = self.program._update_xinsts(kernel_xinstrs) + + # Verify results + self.assertEqual(mock_xinst1.bundle, 11) # 1 + 10 + self.assertEqual(mock_xinst2.bundle, 12) # 2 + 10 + self.assertEqual(last_bundle, 12) + + # Test that an invalid bundle sequence raises RuntimeError + kernel_xinstrs = [ + mock_xinst2, + mock_xinst3, + ] # xinst3 has lower bundle than xinst2 + with self.assertRaises(RuntimeError): + self.program._update_xinsts(kernel_xinstrs) + + def test_link_kernel(self): + """Test linking a kernel.""" + # Create mocks for the update methods + with patch.object( + LinkedProgram, "_update_minsts" + ) as mock_update_minsts, patch.object( + LinkedProgram, "_update_cinsts" + ) as mock_update_cinsts, patch.object( + LinkedProgram, "_update_xinsts" + ) as mock_update_xinsts: + + # Setup mock_update_xinsts to return a bundle offset + mock_update_xinsts.return_value = 5 + + # Create mock instruction lists + kernel_minstrs = [MagicMock(), MagicMock()] + kernel_cinstrs = [MagicMock(), MagicMock()] + kernel_xinstrs = [MagicMock(), MagicMock()] + + # Configure the mocks for to_line method + for i, xinstr in enumerate(kernel_xinstrs): + xinstr.to_line.return_value = f"xinst{i}" + xinstr.comment = f"xinst_comment{i}" if i % 2 == 0 else None + + for i, cinstr in enumerate( + kernel_cinstrs[:-1] + ): # Skip the last one (cexit) + cinstr.to_line.return_value = f"cinst{i}" + cinstr.comment = f"cinst_comment{i}" if i % 2 == 0 else None + + for i, minstr in enumerate( + kernel_minstrs[:-1] + ): # Skip the last one (msyncc) + minstr.to_line.return_value = f"minst{i}" + minstr.comment = f"minst_comment{i}" if i % 2 == 0 else None + + # Execute the method + self.program.link_kernel(kernel_minstrs, kernel_cinstrs, kernel_xinstrs) + + # Verify update methods were called + mock_update_minsts.assert_called_once_with(kernel_minstrs) + mock_update_cinsts.assert_called_once_with(kernel_cinstrs) + mock_update_xinsts.assert_called_once_with(kernel_xinstrs) + + # Verify bundle offset was updated + self.assertEqual(self.program._LinkedProgram__bundle_offset, 6) # 5 + 1 + + # Verify line offsets were updated + self.assertEqual( + self.program._LinkedProgram__minst_line_offset, 1 + ) # len(kernel_minstrs) - 1 + self.assertEqual( + self.program._LinkedProgram__cinst_line_offset, 1 + ) # len(kernel_cinstrs) - 1 + + # Verify kernel count was incremented + self.assertEqual(self.program._LinkedProgram__kernel_count, 1) + + # Verify output streams contain the instructions + xinst_output = self.streams["xinst"].getvalue() + cinst_output = self.streams["cinst"].getvalue() + minst_output = self.streams["minst"].getvalue() + + self.assertIn("xinst0", xinst_output) + self.assertIn("xinst1", xinst_output) + self.assertIn("xinst_comment0", xinst_output) + + self.assertIn("0, cinst0", cinst_output) + self.assertIn("cinst_comment0", cinst_output) + + self.assertIn("0, minst0", minst_output) + self.assertIn("minst_comment0", minst_output) + + def test_link_kernel_with_no_hbm(self): + """Test linking a kernel with HBM disabled.""" + with patch.object(GlobalConfig, "hasHBM", False): + # Create mocks for the update methods + with patch.object( + LinkedProgram, "_update_cinsts" + ) as mock_update_cinsts, patch.object( + LinkedProgram, "_update_xinsts" + ) as mock_update_xinsts: + + # Setup mock_update_xinsts to return a bundle offset + mock_update_xinsts.return_value = 5 + + # Create mock instruction lists + kernel_minstrs = [MagicMock(), MagicMock()] # Should be ignored + kernel_cinstrs = [MagicMock(), MagicMock()] + kernel_xinstrs = [MagicMock(), MagicMock()] + + # Configure the mocks for to_line method + for xinstr in kernel_xinstrs: + xinstr.to_line.return_value = "xinst" + xinstr.comment = None + + for cinstr in kernel_cinstrs[:-1]: # Skip the last one (cexit) + cinstr.to_line.return_value = "cinst" + cinstr.comment = None + + # Execute the method + self.program.link_kernel(kernel_minstrs, kernel_cinstrs, kernel_xinstrs) + + # Verify update methods were called + # No minsts should be processed when HBM is disabled + mock_update_cinsts.assert_called_once_with(kernel_cinstrs) + mock_update_xinsts.assert_called_once_with(kernel_xinstrs) + + # Verify bundle offset was updated + self.assertEqual(self.program._LinkedProgram__bundle_offset, 6) # 5 + 1 + + # No MInst output when HBM is disabled + minst_output = self.streams["minst"].getvalue() + self.assertEqual(minst_output, "") + + def test_link_kernel_with_closed_program(self): + """Test linking a kernel with a closed program.""" + # Close the program + self.program._LinkedProgram__is_open = False + + # Try to link a kernel + with self.assertRaises(RuntimeError): + self.program.link_kernel([], [], []) + + def test_link_kernel_with_suppress_comments(self): + """Test linking a kernel with comments suppressed.""" + with patch.object(GlobalConfig, "suppress_comments", True): + # Create mocks for the update methods + with patch.object(LinkedProgram, "_update_minsts"), patch.object( + LinkedProgram, "_update_cinsts" + ), patch.object(LinkedProgram, "_update_xinsts"): + + # Create mock instruction lists with comments + kernel_minstrs = [MagicMock(), MagicMock()] + kernel_cinstrs = [MagicMock(), MagicMock()] + kernel_xinstrs = [MagicMock()] + + # Configure the mocks for to_line method + kernel_xinstrs[0].to_line.return_value = "xinst" + kernel_xinstrs[0].comment = "xinst_comment" + + kernel_cinstrs[0].to_line.return_value = "cinst" + kernel_cinstrs[0].comment = "cinst_comment" + + kernel_minstrs[0].to_line.return_value = "minst" + kernel_minstrs[0].comment = "minst_comment" + + # Execute the method + self.program.link_kernel(kernel_minstrs, kernel_cinstrs, kernel_xinstrs) + + # Verify comments were suppressed + xinst_output = self.streams["xinst"].getvalue() + cinst_output = self.streams["cinst"].getvalue() + minst_output = self.streams["minst"].getvalue() + + self.assertNotIn("xinst_comment", xinst_output) + self.assertNotIn("cinst_comment", cinst_output) + self.assertNotIn("minst_comment", minst_output) + + +class TestJoinDinstKernels(unittest.TestCase): + """Tests for the join_dinst_kernels static method.""" + + def test_join_dinst_kernels_empty(self): + """Test joining empty list of DInst kernels.""" + with self.assertRaises(ValueError): + LinkedProgram.join_dinst_kernels([]) + + def test_join_dinst_kernels_single_kernel(self): + """Test joining a single DInst kernel.""" + # Create mock DInstructions + mock_dload = MagicMock(spec=dinst.DLoad) + mock_dload.var = "var1" + + mock_dstore = MagicMock(spec=dinst.DStore) + mock_dstore.var = "var2" + + # Execute the method + result = LinkedProgram.join_dinst_kernels([[mock_dload, mock_dstore]]) + + # Verify result + self.assertEqual(len(result), 2) + self.assertEqual(result[0], mock_dload) + self.assertEqual(result[1], mock_dstore) + + # Verify address was set + self.assertEqual(mock_dload.address, 0) + self.assertEqual(mock_dstore.address, 1) + + def test_join_dinst_kernels_multiple_kernels(self): + """Test joining multiple DInst kernels.""" + # Create mock DInstructions for first kernel + mock_dload1 = MagicMock(spec=dinst.DLoad) + mock_dload1.var = "var1" + + mock_dstore1 = MagicMock(spec=dinst.DStore) + mock_dstore1.var = "var2" + + # Create mock DInstructions for second kernel + mock_dload2 = MagicMock(spec=dinst.DLoad) + mock_dload2.var = "var2" # Same as output from first kernel + + mock_dkeygen = MagicMock(spec=dinst.DKeyGen) + mock_dkeygen.var = "var3" + + mock_dstore2 = MagicMock(spec=dinst.DStore) + mock_dstore2.var = "var4" + + # Execute the method + result = LinkedProgram.join_dinst_kernels( + [[mock_dload1, mock_dstore1], [mock_dload2, mock_dkeygen, mock_dstore2]] + ) + + # Verify result - should contain load1, store1 (output), keygen, store2 (output) + # dload2 should be skipped since it loads var2 which is already an output from kernel1 + self.assertEqual(len(result), 3) + self.assertIn(mock_dload1, result) + self.assertNotIn(mock_dload2, result) # Should be skipped + self.assertIn(mock_dkeygen, result) + self.assertIn(mock_dstore2, result) + + # Verify addresses were set correctly and sequentially + # Note: exact order depends on dictionary iteration which is not guaranteed + used_addresses = {dinst.address for dinst in result} + self.assertEqual(used_addresses, {0, 1, 2}) # Three consecutive addresses + + def test_join_dinst_kernels_with_carry_over_vars(self): + """Test joining DInst kernels with carry-over variables that are both input and output.""" + # Create mock DInstructions for first kernel + mock_dload1 = MagicMock(spec=dinst.DLoad) + mock_dload1.var = "var1" + + mock_dstore1 = MagicMock(spec=dinst.DStore) + mock_dstore1.var = "var2" + + # Create mock DInstructions for second kernel + mock_dload2 = MagicMock(spec=dinst.DLoad) + mock_dload2.var = "var2" # Same as output from first kernel + + mock_dstore2 = MagicMock(spec=dinst.DStore) + mock_dstore2.var = "var2" # Same variable is also an output + + # Execute the method + result = LinkedProgram.join_dinst_kernels( + [[mock_dload1, mock_dstore1], [mock_dload2, mock_dstore2]] + ) + + # Verify result - should contain load1, store2 + # Both dload2 and dstore1 should be skipped since var2 is carried over + self.assertEqual(len(result), 2) + self.assertIn(mock_dload1, result) + self.assertNotIn(mock_dload2, result) # Should be skipped + self.assertNotIn(mock_dstore1, result) # Should be skipped + self.assertIn(mock_dstore2, result) # Final output for var2 + + +if __name__ == "__main__": + unittest.main() diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_steps/test_variable_discovery.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_steps/test_variable_discovery.py new file mode 100644 index 00000000..6d1a471d --- /dev/null +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_steps/test_variable_discovery.py @@ -0,0 +1,248 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +""" +Unit tests for the variable discovery module. +""" + +import unittest +from unittest.mock import patch, MagicMock + +from linker.steps.variable_discovery import discoverVariables, discoverVariablesSPAD + + +class TestVariableDiscovery(unittest.TestCase): + """Tests for the variable discovery functions.""" + + def setUp(self): + """Set up test fixtures.""" + # Group MInstructions in a dictionary + self.m_instrs = { + "load": MagicMock(source="var1"), + "store": MagicMock(dest="var2"), + "other": MagicMock(), # MInstruction that's neither MLoad nor MStore + } + + # Group CInstructions in a dictionary + self.c_instrs = { + "bload": MagicMock(source="var3"), + "cload": MagicMock(source="var4"), + "bones": MagicMock(source="var5"), + "nload": MagicMock(source="var6"), + "cstore": MagicMock(dest="var7"), + "other": MagicMock(), # CInstruction that's none of the above + } + + @patch("linker.steps.variable_discovery.minst") + @patch("linker.steps.variable_discovery.MInstruction") + @patch("assembler.memory_model.variable.Variable.validateName") + def test_discover_variables_valid( + self, mock_validate, mock_minst_class, mock_minst + ): + """Test discovering variables from valid MInstructions.""" + # Setup mocks + mock_minst.MLoad = MagicMock() + mock_minst.MStore = MagicMock() + + # Configure type checking at the module level, avoiding patching isinstance + def is_minst_side_effect(obj): + return obj in [ + self.m_instrs["load"], + self.m_instrs["store"], + self.m_instrs["other"], + ] + + def is_mload_side_effect(obj): + return obj is self.m_instrs["load"] + + def is_mstore_side_effect(obj): + return obj is self.m_instrs["store"] + + # Replace the actual isinstance calls in the module with our mock functions + with patch( + "linker.steps.variable_discovery.isinstance", + side_effect=lambda obj, cls: { + mock_minst_class: is_minst_side_effect(obj), + mock_minst.MLoad: is_mload_side_effect(obj), + mock_minst.MStore: is_mstore_side_effect(obj), + }.get(cls, False), + ): + + # Configure validateName to return True + mock_validate.return_value = True + + # Test with a list containing both MLoad and MStore + minstrs = [ + self.m_instrs["load"], + self.m_instrs["store"], + self.m_instrs["other"], + ] + + # Call the function + result = list(discoverVariables(minstrs)) + + # Verify results + self.assertEqual(result, ["var1", "var2"]) + mock_validate.assert_any_call("var1") + mock_validate.assert_any_call("var2") + + def test_discover_variables_empty_list(self): + """Test discovering variables from an empty list of MInstructions.""" + # No need to patch isinstance for an empty list + result = list(discoverVariables([])) + + # Verify results - should be an empty list + self.assertEqual(result, []) + + def test_discover_variables_invalid_type(self): + """Test discovering variables with invalid types in the list.""" + # Setup mock to fail the isinstance check + invalid_obj = MagicMock() + + with patch("linker.steps.variable_discovery.isinstance", return_value=False): + # Call the function with a list containing an invalid type + with self.assertRaises(TypeError) as context: + list(discoverVariables([invalid_obj])) + + # Verify the error message + self.assertIn("not a valid MInstruction", str(context.exception)) + + @patch("linker.steps.variable_discovery.minst") + @patch("assembler.memory_model.variable.Variable.validateName") + def test_discover_variables_invalid_variable_name(self, mock_validate, mock_minst): + """Test discovering variables with an invalid variable name.""" + # Setup mocks + mock_minst.MLoad = MagicMock() + + # Configure validateName to return False + mock_validate.return_value = False + + with patch( + "linker.steps.variable_discovery.isinstance", + side_effect=lambda obj, cls: True, + ): + # Call the function + with self.assertRaises(RuntimeError) as context: + list(discoverVariables([self.m_instrs["load"]])) + + # Verify the error message + self.assertIn("Invalid Variable name", str(context.exception)) + + @patch("linker.steps.variable_discovery.cinst") + @patch("linker.steps.variable_discovery.CInstruction") + @patch("assembler.memory_model.variable.Variable.validateName") + def test_discover_variables_spad_valid( + self, mock_validate, mock_cinst_class, mock_cinst + ): + """Test discovering variables from valid CInstructions.""" + # Setup mocks + mock_cinst.BLoad = MagicMock() + mock_cinst.CLoad = MagicMock() + mock_cinst.BOnes = MagicMock() + mock_cinst.NLoad = MagicMock() + mock_cinst.CStore = MagicMock() + + # Configure validateName to return True + mock_validate.return_value = True + + # Test with a list containing all types of CInstructions + cinstrs = [ + self.c_instrs["bload"], + self.c_instrs["cload"], + self.c_instrs["bones"], + self.c_instrs["nload"], + self.c_instrs["cstore"], + self.c_instrs["other"], + ] + + # Improved mock for isinstance that handles tuples of classes + def mock_isinstance(obj, cls): + # Handle tuple case first + if isinstance(cls, tuple): + return any(mock_isinstance(obj, c) for c in cls) + + # Use a dictionary to map class types to their respective checks + class_checks = { + mock_cinst_class: lambda: obj in cinstrs, + mock_cinst.BLoad: lambda: obj is self.c_instrs["bload"], + mock_cinst.CLoad: lambda: obj is self.c_instrs["cload"], + mock_cinst.BOnes: lambda: obj is self.c_instrs["bones"], + mock_cinst.NLoad: lambda: obj is self.c_instrs["nload"], + mock_cinst.CStore: lambda: obj is self.c_instrs["cstore"], + } + + # Check if cls is in our mapping and return the result of its check function + return class_checks.get(cls, lambda: False)() + + # Patch the isinstance calls at the module level + with patch( + "linker.steps.variable_discovery.isinstance", side_effect=mock_isinstance + ): + # Call the function + result = list(discoverVariablesSPAD(cinstrs)) + + # Verify results + self.assertEqual(result, ["var3", "var4", "var5", "var6", "var7"]) + mock_validate.assert_any_call("var3") + mock_validate.assert_any_call("var4") + mock_validate.assert_any_call("var5") + mock_validate.assert_any_call("var6") + mock_validate.assert_any_call("var7") + + def test_discover_variables_spad_empty_list(self): + """Test discovering variables from an empty list of CInstructions.""" + # Call the function with an empty list + result = list(discoverVariablesSPAD([])) + + # Verify results - should be an empty list + self.assertEqual(result, []) + + def test_discover_variables_spad_invalid_type(self): + """Test discovering variables with invalid types in the list.""" + # Setup mock + invalid_obj = MagicMock() + + with patch("linker.steps.variable_discovery.isinstance", return_value=False): + # Call the function with a list containing an invalid type + with self.assertRaises(TypeError) as context: + list(discoverVariablesSPAD([invalid_obj])) + + # Verify the error message + self.assertIn("not a valid MInstruction", str(context.exception)) + + @patch("linker.steps.variable_discovery.cinst") + @patch("linker.steps.variable_discovery.CInstruction") + @patch("assembler.memory_model.variable.Variable.validateName") + def test_discover_variables_spad_invalid_variable_name( + self, mock_validate, mock_cinst_class, mock_cinst + ): + """Test discovering variables with an invalid variable name.""" + # Setup mocks + mock_cinst.BLoad = MagicMock() + + # Configure validateName to return False + mock_validate.return_value = False + + # Mock isinstance to make our mock object appear as CInstruction and BLoad + with patch( + "linker.steps.variable_discovery.isinstance", + side_effect=lambda obj, cls: { + mock_cinst_class: True, + ( + mock_cinst.BLoad, + mock_cinst.CLoad, + mock_cinst.BOnes, + mock_cinst.NLoad, + ): True, + }.get(cls, False), + ): + # Call the function + with self.assertRaises(RuntimeError) as context: + list(discoverVariablesSPAD([self.c_instrs["bload"]])) + + # Verify the error message + self.assertIn("Invalid Variable name", str(context.exception)) + + +if __name__ == "__main__": + unittest.main() From 359ba9649f5d581fd590f3d5fe87e9ab5e0034e0 Mon Sep 17 00:00:00 2001 From: "Rojas Chaves, Jose" Date: Mon, 7 Jul 2025 16:58:06 +0000 Subject: [PATCH 05/12] Cleaning --- .pre-commit-config.yaml | 2 - .../assembler/memory_model/mem_info.py | 4 +- .../hec-assembler-tools/he_link.py | 85 +++++++++++-------- .../hec-assembler-tools/linker/__init__.py | 15 ++-- .../linker/instructions/__init__.py | 9 +- .../linker/instructions/dinst/__init__.py | 5 +- .../linker/instructions/dinst/dinstruction.py | 3 + .../linker/instructions/dinst/dkeygen.py | 3 + .../linker/instructions/dinst/dload.py | 3 + .../linker/instructions/dinst/dstore.py | 3 + .../linker/instructions/instruction.py | 3 + .../hec-assembler-tools/linker/loader.py | 38 ++++++--- .../tests/unit_tests/test_he_link.py | 48 +++++------ .../tests/unit_tests/test_he_prep.py | 3 + .../tests/unit_tests/test_linker/test_init.py | 3 + .../test_instructions/test_init.py | 3 + .../test_instructions/test_instruction.py | 3 + .../unit_tests/test_linker/test_loader.py | 3 + .../test_steps/test_program_linker.py | 3 + .../test_steps/test_variable_discovery.py | 3 + 20 files changed, 161 insertions(+), 81 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ffdc582f..84e1b336 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -61,8 +61,6 @@ repos: *assembler_tools/hec-assembler-tools/debug_tools/main\.py| *assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/| *assembler_tools/hec-assembler-tools/he_as\.py| - *assembler_tools/hec-assembler-tools/linker/__init__\.py| - *assembler_tools/hec-assembler-tools/linker/instructions/__init__\.py| *assembler_tools/hec-assembler-tools/assembler/spec_config/isa_spec.py) args: ["--follow-imports=skip", "--install-types", "--non-interactive"] - repo: local diff --git a/assembler_tools/hec-assembler-tools/assembler/memory_model/mem_info.py b/assembler_tools/hec-assembler-tools/assembler/memory_model/mem_info.py index d039851e..8d034d59 100644 --- a/assembler_tools/hec-assembler-tools/assembler/memory_model/mem_info.py +++ b/assembler_tools/hec-assembler-tools/assembler/memory_model/mem_info.py @@ -1,6 +1,9 @@ # Copyright (C) 2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 +# These contents may have been developed with support from one +# or more Intel-operated generative artificial intelligence solutions + from assembler.common import constants from assembler.instructions import tokenize_from_line from typing import Optional @@ -591,7 +594,6 @@ def from_dinstrs(cls, dinstrs): if tokens and len(tokens) > 0: try: retval.add_meminfo_var_from_tokens(tokens) - print(f"Added {tokens} to MemInfo") except RuntimeError as e: raise RuntimeError(f"{e} {ints_no}: {tokens}") from e diff --git a/assembler_tools/hec-assembler-tools/he_link.py b/assembler_tools/hec-assembler-tools/he_link.py index 0d9f39a6..f12b97d8 100644 --- a/assembler_tools/hec-assembler-tools/he_link.py +++ b/assembler_tools/hec-assembler-tools/he_link.py @@ -44,6 +44,23 @@ from linker.instructions import BaseInstruction +class NullIO: + """ + @class NullIO + @brief A class that provides a no-operation implementation of write and flush methods. + """ + + def write(self, *argts, **kwargs): + """ + @brief A no-operation write method. + """ + + def flush(self): + """ + @brief A no-operation flush method. + """ + + class LinkerRunConfig(RunConfig): """ @class LinkerRunConfig @@ -58,7 +75,7 @@ class LinkerRunConfig(RunConfig): # Type annotations for class attributes input_prefixes: list[str] input_mem_file: str - find_mem_files: bool + multi_mem_files: bool output_dir: str output_prefix: str @@ -122,7 +139,7 @@ def init_default_config(cls): if not cls.__initialized: cls.__default_config["input_prefixes"] = None cls.__default_config["input_mem_file"] = "" - cls.__default_config["find_mem_files"] = False + cls.__default_config["multi_mem_files"] = False cls.__default_config["output_dir"] = os.getcwd() cls.__default_config["output_prefix"] = None @@ -173,7 +190,7 @@ class KernelFiles(NamedTuple): Index = 3. Name for file containing XInstructions for represented kernel. @var mem Index = 4. Name for file containing memory information for represented kernel. - This is used only when find_mem_files is set. + This is used only when multi_mem_files is set. """ prefix: str @@ -231,7 +248,7 @@ def prepare_output_files(run_config) -> KernelFiles: output_dir = os.path.dirname(output_prefix) pathlib.Path(output_dir).mkdir(exist_ok=True, parents=True) out_mem_file = ( - makeUniquePath(output_prefix + ".mem") if run_config.find_mem_files else None + makeUniquePath(output_prefix + ".mem") if run_config.multi_mem_files else None ) return KernelFiles( prefix=makeUniquePath(output_prefix), @@ -255,7 +272,7 @@ def prepare_input_files(run_config, output_files) -> list: input_files = [] for file_prefix in run_config.input_prefixes: mem_file = ( - makeUniquePath(file_prefix + ".mem") if run_config.find_mem_files else None + makeUniquePath(file_prefix + ".mem") if run_config.multi_mem_files else None ) kernel_files = KernelFiles( prefix=makeUniquePath(file_prefix), @@ -322,7 +339,7 @@ def check_unused_variables(mem_model): ) -def main(run_config: LinkerRunConfig, verbose_stream=None): +def main(run_config: LinkerRunConfig, verbose_stream=NullIO()): """ @brief Executes the linking process using the provided configuration. @@ -354,15 +371,15 @@ def main(run_config: LinkerRunConfig, verbose_stream=None): Counter.reset() # parse mem file + print("Linking...", file=verbose_stream) + print("", file=verbose_stream) + print("Interpreting variable meta information...", file=verbose_stream) - if verbose_stream: - print("Linking...", file=verbose_stream) - print("", file=verbose_stream) - print("Interpreting variable meta information...", file=verbose_stream) - - if run_config.find_mem_files: + if run_config.multi_mem_files: kernels_dinstrs = [] for kernel in input_files: + if kernel.mem is None: + raise RuntimeError(f"Memory file not found for kernel {kernel.prefix}") kernel_dinstrs = loader.load_dinst_kernel_from_file(kernel.mem) kernels_dinstrs.append(kernel_dinstrs) @@ -376,36 +393,34 @@ def main(run_config: LinkerRunConfig, verbose_stream=None): mem_meta_info = mem_info.MemInfo.from_file_iter(mem_ifnum) # Initialize memory model - if verbose_stream: - print("Initializing linker memory model", file=verbose_stream) + print("Initializing linker memory model", file=verbose_stream) mem_model = linker.MemoryModel(hbm_capacity_words, mem_meta_info) - if verbose_stream: - print(f" HBM capacity: {mem_model.hbm.capacity} words", file=verbose_stream) + print(f" HBM capacity: {mem_model.hbm.capacity} words", file=verbose_stream) - print(" Finding all program variables...", file=verbose_stream) - print(" Scanning", file=verbose_stream) + print(" Finding all program variables...", file=verbose_stream) + print(" Scanning", file=verbose_stream) scan_variables(input_files, mem_model, verbose_stream) check_unused_variables(mem_model) - if verbose_stream: - print(f" Variables found: {len(mem_model.variables)}", file=verbose_stream) - print("Linking started", file=verbose_stream) + print(f" Variables found: {len(mem_model.variables)}", file=verbose_stream) + print("Linking started", file=verbose_stream) link_kernels(input_files, output_files, mem_model, verbose_stream) # Write the memory model to the output file - if run_config.find_mem_files: + if run_config.multi_mem_files: + if output_files.mem is None: + raise RuntimeError("Output memory file path is None") BaseInstruction.dump_instructions_to_file(kernel_dinstrs, output_files.mem) - if verbose_stream: - print("Output written to files:", file=verbose_stream) - print(" ", output_files.minst, file=verbose_stream) - print(" ", output_files.cinst, file=verbose_stream) - print(" ", output_files.xinst, file=verbose_stream) - if run_config.find_mem_files: - print(" ", output_files.mem, file=verbose_stream) + print("Output written to files:", file=verbose_stream) + print(" ", output_files.minst, file=verbose_stream) + print(" ", output_files.cinst, file=verbose_stream) + print(" ", output_files.xinst, file=verbose_stream) + if run_config.multi_mem_files: + print(" ", output_files.mem, file=verbose_stream) def parse_args(): @@ -451,9 +466,9 @@ def parse_args(): help=("Input ISA specification (.json) file."), ) parser.add_argument( - "--find_mem_files", + "--multi_mem_files", action="store_true", - dest="find_mem_files", + dest="multi_mem_files", help=( "Tells the linker to find a memory file (*.tw.mem) for each input prefix given." "This can be used to link multiple kernels together. " @@ -523,10 +538,10 @@ def parse_args(): ) p_args = parser.parse_args() - # Enforce input_mem_file only if find_mem_files is not set - if not p_args.find_mem_files and p_args.input_mem_file == "": + # Enforce input_mem_file only if multi_mem_files is not set + if not p_args.multi_mem_files and p_args.input_mem_file == "": parser.error( - "the following arguments are required: -im/--input_mem_file (unless --find_mem_files is set)" + "the following arguments are required: -im/--input_mem_file (unless --multi_mem_files is set)" ) return p_args @@ -554,7 +569,7 @@ def parse_args(): print("=================") print() - main(config, sys.stdout if args.verbose > 1 else None) + main(config, sys.stdout if args.verbose > 1 else NullIO()) if args.verbose > 0: print() diff --git a/assembler_tools/hec-assembler-tools/linker/__init__.py b/assembler_tools/hec-assembler-tools/linker/__init__.py index 9dadd1f4..5fb57cda 100644 --- a/assembler_tools/hec-assembler-tools/linker/__init__.py +++ b/assembler_tools/hec-assembler-tools/linker/__init__.py @@ -1,12 +1,15 @@ # Copyright (C) 2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 +# These contents may have been developed with support from one or more Intel-operated +# generative artificial intelligence solutions + +"""linker/__init__.py contains classes to encapsulate the memory model used by the linker.""" + import collections.abc as collections from assembler.common.config import GlobalConfig from assembler.memory_model import mem_info - -# linker/__init__.py contains classes to encapsulate the memory model used -# by the linker. +from typing import Dict class VariableInfo(mem_info.MemInfoVariable): @@ -164,7 +167,9 @@ def __init__(self, hbm_size_words: int, mem_meta_info: mem_info.MemInfo): """ self.hbm = HBM(hbm_size_words) self.__mem_info = mem_meta_info - self.__variables = {} # dict(var_name: str, VariableInfo) + self.__variables: Dict[str, VariableInfo] = ( + {} + ) # dict(var_name: str, VariableInfo) self.__keygen_vars = { var_info.var_name: var_info for var_info in self.__mem_info.keygens } @@ -293,6 +298,6 @@ def useVariable(self, var_name: str, kernel: int) -> int: assert var_info.hbm_address >= 0 assert ( self.hbm.buffer[var_info.hbm_address].var_name == var_info.var_name - ), f"Expected variable {var_info.var_name} in HBM {var_info.hbm_address}, but variable {self.hbm[var_info.hbm_address].var_name} found instead." + ), f"Expected variable {var_info.var_name} in HBM {var_info.hbm_address}, but variable {self.hbm.buffer[var_info.hbm_address].var_name} found instead." return var_info.hbm_address diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/__init__.py b/assembler_tools/hec-assembler-tools/linker/instructions/__init__.py index 2ccbf24d..fdb88b28 100644 --- a/assembler_tools/hec-assembler-tools/linker/instructions/__init__.py +++ b/assembler_tools/hec-assembler-tools/linker/instructions/__init__.py @@ -1,10 +1,17 @@ # Copyright (C) 2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 +# These contents may have been developed with support from one or more Intel-operated +# generative artificial intelligence solutions + +"""This module provides functionality to create instruction objects from a line of text.""" + +from typing import Optional from assembler.instructions import tokenize_from_line from linker.instructions.instruction import BaseInstruction -def create_from_str_line(line: str, factory) -> BaseInstruction: + +def create_from_str_line(line: str, factory) -> Optional[BaseInstruction]: """ Parses an instruction from a line of text. diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/dinst/__init__.py b/assembler_tools/hec-assembler-tools/linker/instructions/dinst/__init__.py index d345637e..17371fae 100644 --- a/assembler_tools/hec-assembler-tools/linker/instructions/dinst/__init__.py +++ b/assembler_tools/hec-assembler-tools/linker/instructions/dinst/__init__.py @@ -1,6 +1,9 @@ # Copyright (C) 2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 +# These contents may have been developed with support from one or more Intel-operated +# generative artificial intelligence solutions + """This module provides functionality to create and manage data instructions""" from typing import Optional @@ -36,13 +39,11 @@ def create_from_mem_line(line: str) -> dinstruction.DInstruction: DInstruction or None: The parsed DInstruction object, or None if no object could be parsed from the specified input line. """ - print(f"ROCHA: create_from_mem_line {line}") retval: Optional[dinstruction.DInstruction] = None tokens, comment = tokenize_from_line(line) for instr_type in factory(): try: retval = instr_type(tokens, comment) - print(f"ROCHA: {instr_type.__name__} {tokens} {retval}") except ValueError: retval = None if retval: diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dinstruction.py b/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dinstruction.py index 72358a4a..4ccb2d0c 100644 --- a/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dinstruction.py +++ b/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dinstruction.py @@ -1,6 +1,9 @@ # Copyright (C) 2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 +# These contents may have been developed with support from one or more Intel-operated +# generative artificial intelligence solutions + """ This module defines the base DInstruction class for data handling instructions. diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dkeygen.py b/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dkeygen.py index 543633a2..70866265 100644 --- a/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dkeygen.py +++ b/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dkeygen.py @@ -1,6 +1,9 @@ # Copyright (C) 2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 +# These contents may have been developed with support from one or more Intel-operated +# generative artificial intelligence solutions + """ This module implements the DKeyGen instruction for key generation operations. """ diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dload.py b/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dload.py index c6fb694b..63322a46 100644 --- a/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dload.py +++ b/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dload.py @@ -1,6 +1,9 @@ # Copyright (C) 2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 +# These contents may have been developed with support from one or more Intel-operated +# generative artificial intelligence solutions + """ This module implements the DLoad instruction for loading data from memory. diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dstore.py b/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dstore.py index fef41739..a759f971 100644 --- a/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dstore.py +++ b/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dstore.py @@ -1,6 +1,9 @@ # Copyright (C) 2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 +# These contents may have been developed with support from one or more Intel-operated +# generative artificial intelligence solutions + """ This module implements the DStore instruction for storing data to memory. diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/instruction.py b/assembler_tools/hec-assembler-tools/linker/instructions/instruction.py index 1764fcf3..fcc10be8 100644 --- a/assembler_tools/hec-assembler-tools/linker/instructions/instruction.py +++ b/assembler_tools/hec-assembler-tools/linker/instructions/instruction.py @@ -1,6 +1,9 @@ # Copyright (C) 2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 +# These contents may have been developed with support from one or more Intel-operated +# generative artificial intelligence solutions + """ Base class for all instructions in the linker. """ diff --git a/assembler_tools/hec-assembler-tools/linker/loader.py b/assembler_tools/hec-assembler-tools/linker/loader.py index bbbe5b83..ea20f6c1 100644 --- a/assembler_tools/hec-assembler-tools/linker/loader.py +++ b/assembler_tools/hec-assembler-tools/linker/loader.py @@ -1,9 +1,19 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +# These contents may have been developed with support from one +# or more Intel-operated generative artificial intelligence solutions + +""" +This module provides functionality to load different types of instruction kernels +""" + from linker.instructions import minst from linker.instructions import cinst from linker.instructions import xinst from linker.instructions import dinst from linker import instructions -from assembler.memory_model.mem_info import MemInfo + def load_minst_kernel(line_iter) -> list: """ @@ -22,10 +32,11 @@ def load_minst_kernel(line_iter) -> list: for idx, s_line in enumerate(line_iter): minstr = instructions.create_from_str_line(s_line, minst.factory()) if not minstr: - raise RuntimeError(f'Error parsing line {idx + 1}: {s_line}') + raise RuntimeError(f"Error parsing line {idx + 1}: {s_line}") retval.append(minstr) return retval + def load_minst_kernel_from_file(filename: str) -> list: """ Loads MInstruction kernel from a file. @@ -39,12 +50,13 @@ def load_minst_kernel_from_file(filename: str) -> list: Raises: RuntimeError: If an error occurs while loading the file. """ - with open(filename, 'r') as kernel_minsts: + with open(filename, "r", encoding="utf-8") as kernel_minsts: try: return load_minst_kernel(kernel_minsts) except Exception as e: raise RuntimeError(f'Error occurred loading file "{filename}"') from e + def load_cinst_kernel(line_iter) -> list: """ Loads CInstruction kernel from an iterator of lines. @@ -62,10 +74,11 @@ def load_cinst_kernel(line_iter) -> list: for idx, s_line in enumerate(line_iter): cinstr = instructions.create_from_str_line(s_line, cinst.factory()) if not cinstr: - raise RuntimeError(f'Error parsing line {idx + 1}: {s_line}') + raise RuntimeError(f"Error parsing line {idx + 1}: {s_line}") retval.append(cinstr) return retval + def load_cinst_kernel_from_file(filename: str) -> list: """ Loads CInstruction kernel from a file. @@ -79,12 +92,13 @@ def load_cinst_kernel_from_file(filename: str) -> list: Raises: RuntimeError: If an error occurs while loading the file. """ - with open(filename, 'r') as kernel_cinsts: + with open(filename, "r", encoding="utf-8") as kernel_cinsts: try: return load_cinst_kernel(kernel_cinsts) except Exception as e: raise RuntimeError(f'Error occurred loading file "{filename}"') from e + def load_xinst_kernel(line_iter) -> list: """ Loads XInstruction kernel from an iterator of lines. @@ -102,10 +116,11 @@ def load_xinst_kernel(line_iter) -> list: for idx, s_line in enumerate(line_iter): xinstr = instructions.create_from_str_line(s_line, xinst.factory()) if not xinstr: - raise RuntimeError(f'Error parsing line {idx + 1}: {s_line}') + raise RuntimeError(f"Error parsing line {idx + 1}: {s_line}") retval.append(xinstr) return retval + def load_xinst_kernel_from_file(filename: str) -> list: """ Loads XInstruction kernel from a file. @@ -119,12 +134,13 @@ def load_xinst_kernel_from_file(filename: str) -> list: Raises: RuntimeError: If an error occurs while loading the file. """ - with open(filename, 'r') as kernel_xinsts: + with open(filename, "r", encoding="utf-8") as kernel_xinsts: try: return load_xinst_kernel(kernel_xinsts) except Exception as e: raise RuntimeError(f'Error occurred loading file "{filename}"') from e + def load_dinst_kernel(line_iter) -> list: """ Loads DInstruction kernel from an iterator of lines. @@ -142,11 +158,12 @@ def load_dinst_kernel(line_iter) -> list: for idx, s_line in enumerate(line_iter): dinstr = dinst.create_from_mem_line(s_line) if not dinstr: - raise RuntimeError(f'Error parsing line {idx + 1}: {s_line}') + raise RuntimeError(f"Error parsing line {idx + 1}: {s_line}") retval.append(dinstr) - + return retval + def load_dinst_kernel_from_file(filename: str) -> list: """ Loads DInstruction kernel from a file. @@ -160,9 +177,8 @@ def load_dinst_kernel_from_file(filename: str) -> list: Raises: RuntimeError: If an error occurs while loading the file. """ - with open(filename, 'r') as kernel_dinsts: + with open(filename, "r", encoding="utf-8") as kernel_dinsts: try: return load_dinst_kernel(kernel_dinsts) except Exception as e: raise RuntimeError(f'Error occurred loading file "{filename}"') from e - \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_he_link.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_he_link.py index 446fa522..36439e21 100644 --- a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_he_link.py +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_he_link.py @@ -38,7 +38,7 @@ def test_init_with_valid_params(self): "hbm_size": 1024, "suppress_comments": False, "use_xinstfetch": False, - "find_mem_files": False, + "multi_mem_files": False, } # Act @@ -54,7 +54,7 @@ def test_init_with_valid_params(self): assert config.hbm_size == 1024 assert config.suppress_comments is False assert config.use_xinstfetch is False - assert config.find_mem_files is False + assert config.multi_mem_files is False def test_init_with_missing_required_param(self): """ @@ -166,7 +166,7 @@ def test_init_for_default_params(self): assert config.hbm_size == 1024 assert config.suppress_comments is False assert config.use_xinstfetch is False - assert config.find_mem_files is False + assert config.multi_mem_files is False class TestKernelFiles: @@ -229,7 +229,7 @@ def test_prepare_output_files(self): mock_config = MagicMock() mock_config.output_dir = "/tmp" mock_config.output_prefix = "output" - mock_config.find_mem_files = False + mock_config.multi_mem_files = False # Act with patch("os.path.dirname", return_value="/tmp"), patch( @@ -246,13 +246,13 @@ def test_prepare_output_files(self): def test_prepare_output_files_with_mem(self): """ - @brief Test prepare_output_files with find_mem_files=True + @brief Test prepare_output_files with multi_mem_files=True """ # Arrange mock_config = MagicMock() mock_config.output_dir = "/tmp" mock_config.output_prefix = "output" - mock_config.find_mem_files = True + mock_config.multi_mem_files = True # Act with patch("os.path.dirname", return_value="/tmp"), patch( @@ -274,7 +274,7 @@ def test_prepare_input_files(self): # Arrange mock_config = MagicMock() mock_config.input_prefixes = ["/tmp/input1", "/tmp/input2"] - mock_config.find_mem_files = False + mock_config.multi_mem_files = False mock_output_files = he_link.KernelFiles( prefix="/tmp/output", @@ -305,7 +305,7 @@ def test_prepare_input_files_file_not_found(self): # Arrange mock_config = MagicMock() mock_config.input_prefixes = ["/tmp/input1"] - mock_config.find_mem_files = False + mock_config.multi_mem_files = False mock_output_files = he_link.KernelFiles( prefix="/tmp/output", @@ -328,7 +328,7 @@ def test_prepare_input_files_output_conflict(self): # Arrange mock_config = MagicMock() mock_config.input_prefixes = ["/tmp/input1"] - mock_config.find_mem_files = False + mock_config.multi_mem_files = False # Output file matching an input file mock_output_files = he_link.KernelFiles( @@ -447,14 +447,14 @@ class TestMainFunction: @brief Test cases for the main function """ - @pytest.mark.parametrize("find_mem_files", [True, False]) - def test_main(self, find_mem_files): + @pytest.mark.parametrize("multi_mem_files", [True, False]) + def test_main(self, multi_mem_files): """ - @brief Test main function with find_mem_files=True + @brief Test main function with multi_mem_files=True """ # Arrange mock_config = MagicMock() - mock_config.find_mem_files = find_mem_files + mock_config.multi_mem_files = multi_mem_files mock_config.has_hbm = True mock_config.hbm_size = 1024 mock_config.suppress_comments = False @@ -517,7 +517,7 @@ def test_main(self, find_mem_files): mock_check_unused_variables.assert_called_once() mock_link_kernels.assert_called_once() - if find_mem_files: + if multi_mem_files: # Should use from_dinstrs, not from_file_iter assert mock_from_dinstrs.called assert mock_load_dinst_kernel_from_file.called @@ -536,7 +536,7 @@ def test_warning_on_use_xinstfetch(self): """ # Arrange mock_config = MagicMock() - mock_config.find_mem_files = False + mock_config.multi_mem_files = False mock_config.has_hbm = True mock_config.hbm_size = 1024 mock_config.suppress_comments = False @@ -595,7 +595,7 @@ def test_parse_args_minimal(self): output_prefix="output_prefix", input_mem_file="input.mem", output_dir="", - find_mem_files=False, + multi_mem_files=False, mem_spec_file="", isa_spec_file="", has_hbm=True, @@ -610,11 +610,11 @@ def test_parse_args_minimal(self): assert args.input_prefixes == ["input_prefix"] assert args.output_prefix == "output_prefix" assert args.input_mem_file == "input.mem" - assert args.find_mem_files is False + assert args.multi_mem_files is False - def test_parse_args_find_mem_files(self): + def test_parse_args_multi_mem_files(self): """ - @brief Test parse_args with find_mem_files flag + @brief Test parse_args with multi_mem_files flag """ # Arrange test_args = [ @@ -622,7 +622,7 @@ def test_parse_args_find_mem_files(self): "input_prefix", "-o", "output_prefix", - "--find_mem_files", + "--multi_mem_files", ] # Act @@ -633,7 +633,7 @@ def test_parse_args_find_mem_files(self): output_prefix="output_prefix", input_mem_file="", output_dir="", - find_mem_files=True, + multi_mem_files=True, mem_spec_file="", isa_spec_file="", has_hbm=True, @@ -648,11 +648,11 @@ def test_parse_args_find_mem_files(self): assert args.input_prefixes == ["input_prefix"] assert args.output_prefix == "output_prefix" assert args.input_mem_file == "" - assert args.find_mem_files is True + assert args.multi_mem_files is True def test_missing_input_mem_file(self): """ - @brief Test parse_args with missing input_mem_file when find_mem_files is False + @brief Test parse_args with missing input_mem_file when multi_mem_files is False """ # Arrange test_args = ["program", "input_prefix", "-o", "output_prefix"] @@ -665,7 +665,7 @@ def test_missing_input_mem_file(self): output_prefix="output_prefix", input_mem_file="", output_dir="", - find_mem_files=False, + multi_mem_files=False, mem_spec_file="", isa_spec_file="", has_hbm=True, diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_he_prep.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_he_prep.py index c87ad818..68078899 100644 --- a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_he_prep.py +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_he_prep.py @@ -1,6 +1,9 @@ # Copyright (C) 2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 +# These contents may have been developed with support from one or more Intel-operated +# generative artificial intelligence solutions + """ Unit tests for he_prep module. """ diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_init.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_init.py index 7ef5e968..cb354ed7 100644 --- a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_init.py +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_init.py @@ -1,6 +1,9 @@ # Copyright (C) 2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 +# These contents may have been developed with support from one or more Intel-operated +# generative artificial intelligence solutions + """ Unit tests for the memory model classes in linker/__init__.py. """ diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_init.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_init.py index ea2e11b8..e688a545 100644 --- a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_init.py +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_init.py @@ -1,6 +1,9 @@ # Copyright (C) 2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 +# These contents may have been developed with support from one or more Intel-operated +# generative artificial intelligence solutions + """ Unit tests for the linker instructions initialization module. diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_instruction.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_instruction.py index 604e872e..0f8245e1 100644 --- a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_instruction.py +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_instruction.py @@ -1,6 +1,9 @@ # Copyright (C) 2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 +# These contents may have been developed with support from one or more Intel-operated +# generative artificial intelligence solutions + """ Unit tests for the BaseInstruction class. """ diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_loader.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_loader.py index 2c6d1728..05dd61dd 100644 --- a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_loader.py +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_loader.py @@ -1,6 +1,9 @@ # Copyright (C) 2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 +# These contents may have been developed with support from one or more Intel-operated +# generative artificial intelligence solutions + """ Unit tests for the loader module. """ diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_steps/test_program_linker.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_steps/test_program_linker.py index 5147c332..a4c58fde 100644 --- a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_steps/test_program_linker.py +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_steps/test_program_linker.py @@ -1,6 +1,9 @@ # Copyright (C) 2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 +# These contents may have been developed with support from one or more Intel-operated +# generative artificial intelligence solutions + """ Unit tests for the program_linker module. """ diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_steps/test_variable_discovery.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_steps/test_variable_discovery.py index 6d1a471d..41e1844b 100644 --- a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_steps/test_variable_discovery.py +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_steps/test_variable_discovery.py @@ -1,6 +1,9 @@ # Copyright (C) 2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 +# These contents may have been developed with support from one or more Intel-operated +# generative artificial intelligence solutions + """ Unit tests for the variable discovery module. """ From ccf69fd4a6bb32b9f0d2454b4934a541108532d4 Mon Sep 17 00:00:00 2001 From: "Rojas Chaves, Jose" Date: Mon, 7 Jul 2025 18:04:49 +0000 Subject: [PATCH 06/12] Doxygen style --- .../hec-assembler-tools/he_link.py | 4 +- .../hec-assembler-tools/linker/__init__.py | 99 ++++++-------- .../linker/instructions/__init__.py | 14 +- .../linker/instructions/dinst/__init__.py | 19 ++- .../linker/instructions/dinst/dinstruction.py | 72 +++++----- .../linker/instructions/dinst/dkeygen.py | 14 +- .../linker/instructions/dinst/dload.py | 19 ++- .../linker/instructions/dinst/dstore.py | 19 ++- .../linker/instructions/instruction.py | 92 +++++-------- .../hec-assembler-tools/linker/loader.py | 106 +++++---------- .../linker/steps/program_linker.py | 15 +- .../linker/steps/variable_discovery.py | 64 +++++---- .../memory_model/test_mem_info.py | 118 ++++++++-------- .../tests/unit_tests/test_he_link.py | 128 ++++++++++-------- .../tests/unit_tests/test_he_prep.py | 25 ++-- .../tests/unit_tests/test_linker/test_init.py | 122 +++++++++++++---- .../test_dinst/test_dinstruction.py | 50 +++++-- .../test_dinst/test_dkeygen.py | 51 +++++-- .../test_dinst/test_dload.py | 56 ++++++-- .../test_dinst/test_dstore.py | 51 +++++-- .../test_instructions/test_dinst/test_init.py | 41 ++++-- .../test_instructions/test_init.py | 31 ++++- .../test_instructions/test_instruction.py | 68 +++++++--- .../unit_tests/test_linker/test_loader.py | 94 +++++++++---- .../test_steps/test_program_linker.py | 120 ++++++++++++---- .../test_steps/test_variable_discovery.py | 64 ++++++--- 26 files changed, 944 insertions(+), 612 deletions(-) diff --git a/assembler_tools/hec-assembler-tools/he_link.py b/assembler_tools/hec-assembler-tools/he_link.py index f12b97d8..68a2c00e 100644 --- a/assembler_tools/hec-assembler-tools/he_link.py +++ b/assembler_tools/hec-assembler-tools/he_link.py @@ -310,7 +310,7 @@ def scan_variables(input_files, mem_model, verbose_stream): file=verbose_stream, ) kernel_cinstrs = loader.load_cinst_kernel_from_file(kernel.cinst) - for var_name in variable_discovery.discoverVariablesSPAD(kernel_cinstrs): + for var_name in variable_discovery.discover_variables_spad(kernel_cinstrs): mem_model.addVariable(var_name) else: if verbose_stream: @@ -320,7 +320,7 @@ def scan_variables(input_files, mem_model, verbose_stream): file=verbose_stream, ) kernel_minstrs = loader.load_minst_kernel_from_file(kernel.minst) - for var_name in variable_discovery.discoverVariables(kernel_minstrs): + for var_name in variable_discovery.discover_variables(kernel_minstrs): mem_model.addVariable(var_name) diff --git a/assembler_tools/hec-assembler-tools/linker/__init__.py b/assembler_tools/hec-assembler-tools/linker/__init__.py index 5fb57cda..ded86921 100644 --- a/assembler_tools/hec-assembler-tools/linker/__init__.py +++ b/assembler_tools/hec-assembler-tools/linker/__init__.py @@ -4,7 +4,7 @@ # These contents may have been developed with support from one or more Intel-operated # generative artificial intelligence solutions -"""linker/__init__.py contains classes to encapsulate the memory model used by the linker.""" +"""@brief linker/__init__.py contains classes to encapsulate the memory model used by the linker.""" import collections.abc as collections from assembler.common.config import GlobalConfig @@ -14,16 +14,15 @@ class VariableInfo(mem_info.MemInfoVariable): """ - Represents information about a variable in the memory model. + @brief Represents information about a variable in the memory model. """ def __init__(self, var_name, hbm_address=-1): """ - Initializes a VariableInfo object. + @brief Initializes a VariableInfo object. - Parameters: - var_name (str): The name of the variable. - hbm_address (int): The HBM address of the variable. Defaults to -1. + @param var_name The name of the variable. + @param hbm_address The HBM address of the variable. Defaults to -1. """ super().__init__(var_name, hbm_address) self.uses = 0 @@ -32,18 +31,15 @@ def __init__(self, var_name, hbm_address=-1): class HBM: """ - Represents the HBM model. + @brief Represents the HBM model. """ def __init__(self, hbm_size_words: int): """ - Initializes an HBM object. + @brief Initializes an HBM object. - Parameters: - hbm_size_words (int): The size of the HBM in words. - - Raises: - ValueError: If hbm_size_words is less than 1. + @param hbm_size_words The size of the HBM in words. + @throws ValueError If hbm_size_words is less than 1. """ if hbm_size_words < 1: raise ValueError("`hbm_size_words` must be a positive integer.") @@ -53,35 +49,30 @@ def __init__(self, hbm_size_words: int): @property def capacity(self) -> int: """ - Gets the capacity in words for the HBM buffer. + @brief Gets the capacity in words for the HBM buffer. - Returns: - int: The capacity of the HBM buffer. + @return The capacity of the HBM buffer. """ return len(self.buffer) @property def buffer(self) -> list: """ - Gets the HBM buffer. + @brief Gets the HBM buffer. - Returns: - list: The HBM buffer. + @return The HBM buffer. """ return self.__buffer def forceAllocate(self, var_info: VariableInfo, hbm_address: int): """ - Forcefully allocates a variable at a specific HBM address. - - Parameters: - var_info (VariableInfo): The variable information. - hbm_address (int): The HBM address to allocate the variable. + @brief Forcefully allocates a variable at a specific HBM address. - Raises: - IndexError: If hbm_address is out of bounds. - ValueError: If the variable is already allocated at a different address. - RuntimeError: If the HBM address is already occupied by another variable. + @param var_info The variable information. + @param hbm_address The HBM address to allocate the variable. + @throws IndexError If hbm_address is out of bounds. + @throws ValueError If the variable is already allocated at a different address. + @throws RuntimeError If the HBM address is already occupied by another variable. """ if hbm_address < 0 or hbm_address >= len(self.buffer): raise IndexError( @@ -123,13 +114,10 @@ def forceAllocate(self, var_info: VariableInfo, hbm_address: int): def allocate(self, var_info: VariableInfo): """ - Allocates a variable in the HBM. + @brief Allocates a variable in the HBM. - Parameters: - var_info (VariableInfo): The variable information. - - Raises: - RuntimeError: If there is no available HBM memory. + @param var_info The variable information. + @throws RuntimeError If there is no available HBM memory. """ # Find next available HBM address retval = -1 @@ -154,16 +142,15 @@ def allocate(self, var_info: VariableInfo): class MemoryModel: """ - Encapsulates the memory model for a linker run, tracking HBM usage and program variables. + @brief Encapsulates the memory model for a linker run, tracking HBM usage and program variables. """ def __init__(self, hbm_size_words: int, mem_meta_info: mem_info.MemInfo): """ - Initializes a MemoryModel object. + @brief Initializes a MemoryModel object. - Parameters: - hbm_size_words (int): The size of the HBM in words. - mem_meta_info (mem_info.MemInfo): The memory metadata information. + @param hbm_size_words The size of the HBM in words. + @param mem_meta_info The memory metadata information. """ self.hbm = HBM(hbm_size_words) self.__mem_info = mem_meta_info @@ -219,41 +206,42 @@ def __init__(self, hbm_size_words: int, mem_meta_info: mem_info.MemInfo): @property def mem_info_meta(self) -> collections.Collection: """ - Set of metadata variable names in MemInfo used to construct this object. + @brief Set of metadata variable names in MemInfo used to construct this object. + Clients must not modify this set. + + @return Collection of metadata variable names. """ return self.__mem_info_meta @property def mem_info_vars(self) -> collections.Collection: """ - Gets the set of variable names in MemInfo used to construct this object. + @brief Gets the set of variable names in MemInfo used to construct this object. - Returns: - collections.Collection: The set of variable names. + @return The set of variable names. """ return self.__mem_info_vars @property def variables(self) -> dict: """ - Gets direct access to internal variables dictionary. + @brief Gets direct access to internal variables dictionary. Clients should use as read-only. Must not add, replace, remove or change contents in any way. Use provided helper functions to manipulate. - Returns: - dict: A dictionary of variables. + @return A dictionary of variables. """ return self.__variables def addVariable(self, var_name: str): """ - Adds a variable to the HBM model. If variable already exists, its `uses` - field is incremented. + @brief Adds a variable to the HBM model. + + If variable already exists, its `uses` field is incremented. - Parameters: - var_name (str): The name of the variable to add. + @param var_name The name of the variable to add. """ var_info: VariableInfo if var_name in self.variables: @@ -273,17 +261,14 @@ def addVariable(self, var_name: str): def useVariable(self, var_name: str, kernel: int) -> int: """ - Uses a variable, decrementing its usage count. + @brief Uses a variable, decrementing its usage count. If a variable usage count reaches zero, it will be deallocated from HBM, if needed, when a future kernel requires HBM space. - Parameters: - var_name (str): The name of the variable to use. - kernel (int): The kernel that is using the variable. - - Returns: - int: The HBM address for the variable. + @param var_name The name of the variable to use. + @param kernel The kernel that is using the variable. + @return The HBM address for the variable. """ var_info: VariableInfo = self.variables[var_name] assert var_info.uses > 0 diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/__init__.py b/assembler_tools/hec-assembler-tools/linker/instructions/__init__.py index fdb88b28..57d2d7a5 100644 --- a/assembler_tools/hec-assembler-tools/linker/instructions/__init__.py +++ b/assembler_tools/hec-assembler-tools/linker/instructions/__init__.py @@ -4,7 +4,7 @@ # These contents may have been developed with support from one or more Intel-operated # generative artificial intelligence solutions -"""This module provides functionality to create instruction objects from a line of text.""" +"""@brief This module provides functionality to create instruction objects from a line of text.""" from typing import Optional from assembler.instructions import tokenize_from_line @@ -13,14 +13,12 @@ def create_from_str_line(line: str, factory) -> Optional[BaseInstruction]: """ - Parses an instruction from a line of text. + @brief Parses an instruction from a line of text. - Parameters: - line (str): Line of text from which to parse an instruction. - - Returns: - BaseInstruction or None: The parsed BaseInstruction object, or None if no object could be - parsed from the specified input line. + @param line Line of text from which to parse an instruction. + @param factory Factory function or collection to create instruction objects. + @return The parsed BaseInstruction object, or None if no object could be + parsed from the specified input line. """ retval = None tokens, comment = tokenize_from_line(line) diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/dinst/__init__.py b/assembler_tools/hec-assembler-tools/linker/instructions/dinst/__init__.py index 17371fae..1bcc2452 100644 --- a/assembler_tools/hec-assembler-tools/linker/instructions/dinst/__init__.py +++ b/assembler_tools/hec-assembler-tools/linker/instructions/dinst/__init__.py @@ -4,7 +4,7 @@ # These contents may have been developed with support from one or more Intel-operated # generative artificial intelligence solutions -"""This module provides functionality to create and manage data instructions""" +"""@brief This module provides functionality to create and manage data instructions""" from typing import Optional @@ -20,24 +20,21 @@ def factory() -> set: """ - Creates a set of all DInstruction classes. + @brief Creates a set of all DInstruction classes. - Returns: - set: A set containing all DInstruction classes. + @return A set containing all DInstruction classes. """ return {DLoad, DStore, DKeyGen} def create_from_mem_line(line: str) -> dinstruction.DInstruction: """ - Parses an data instruction from a line of the memory map. + @brief Parses an data instruction from a line of the memory map. - Parameters: - line (str): Line of text from which to parse an instruction. - - Returns: - DInstruction or None: The parsed DInstruction object, or None if no object could be - parsed from the specified input line. + @param line Line of text from which to parse an instruction. + @return The parsed DInstruction object, or None if no object could be + parsed from the specified input line. + @throws RuntimeError If no valid instruction is found or if there's an error parsing the memory map line. """ retval: Optional[dinstruction.DInstruction] = None tokens, comment = tokenize_from_line(line) diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dinstruction.py b/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dinstruction.py index 4ccb2d0c..14c69f30 100644 --- a/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dinstruction.py +++ b/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dinstruction.py @@ -5,7 +5,7 @@ # generative artificial intelligence solutions """ -This module defines the base DInstruction class for data handling instructions. +@brief This module defines the base DInstruction class for data handling instructions. DInstruction is the parent class for all data instructions used in the assembly process, providing common functionality and interfaces. @@ -18,7 +18,7 @@ class DInstruction(BaseInstruction): """ - Represents a DInstruction, inheriting from BaseInstruction. + @brief Represents a DInstruction, inheriting from BaseInstruction. """ _local_id_count = Counter.count(0) # Local counter for DInstruction IDs @@ -28,56 +28,49 @@ class DInstruction(BaseInstruction): @classmethod def _get_name(cls) -> str: """ - Derived classes should implement this method and return correct + @brief Derived classes should implement this method and return correct name for the instruction. - Raises: - NotImplementedError: Abstract method. This base method should not be called. + @throws NotImplementedError Abstract method. This base method should not be called. """ raise NotImplementedError() @classmethod def _get_name_token_index(cls) -> int: """ - Gets the index of the token containing the name of the instruction. + @brief Gets the index of the token containing the name of the instruction. - Returns: - int: The index of the name token, which is 0. + @return The index of the name token, which is 0. """ return 0 @classmethod def _get_num_tokens(cls) -> int: """ - Derived classes should implement this method and return correct + @brief Derived classes should implement this method and return correct required number of tokens for the instruction. - Raises: - NotImplementedError: Abstract method. This base method should not be called. + @throws NotImplementedError Abstract method. This base method should not be called. """ raise NotImplementedError() @classproperty def num_tokens(self) -> int: """ - Valid number of tokens for this instruction. + @brief Valid number of tokens for this instruction. - Returns: - tuple: Valid number of tokens. + @return Valid number of tokens. """ return self._get_num_tokens() def _validate_tokens(self, tokens: list) -> None: """ - Validates the tokens for this instruction. + @brief Validates the tokens for this instruction. DInstruction allows at least the required number of tokens. - Parameters: - tokens (list): List of tokens to validate. - - Raises: - ValueError: If tokens are invalid. + @param tokens List of tokens to validate. + @throws ValueError If tokens are invalid. """ assert self.name_token_index < self.num_tokens if len(tokens) < self.num_tokens: @@ -92,11 +85,10 @@ def _validate_tokens(self, tokens: list) -> None: def __init__(self, tokens: list, comment: str = ""): """ - Constructs a new DInstruction. + @brief Constructs a new DInstruction. - Parameters: - tokens (list): List of tokens for the instruction. - comment (str): Optional comment for the instruction. + @param tokens List of tokens for the instruction. + @param comment Optional comment for the instruction. """ # Do not increment the global instruction count; skip BaseInstruction's __init__ logic for __id # Call BaseInstruction constructor but perform our own initialization @@ -109,42 +101,54 @@ def __init__(self, tokens: list, comment: str = ""): @property def id(self): """ - Unique ID for the instruction. + @brief Unique ID for the instruction. This is a combination of the client ID specified during construction and a unique nonce per instruction. - Returns: - tuple: (client_id: int, nonce: int) where client_id is the id specified at construction. + @return (client_id: int, nonce: int) where client_id is the id specified at construction. """ return self._local_id @property def var(self) -> str: """ - Name of source/dest var. + @brief Name of source/dest var. + + @return The variable name. """ return self._var @var.setter def var(self, value: str): + """ + @brief Sets the variable name. + + @param value The new variable name. + """ self._var = value @property def address(self) -> int: """ - Should be set to source/dest Mem address. + @brief Should be set to source/dest Mem address. + + @return The memory address. """ return self._address @address.setter - def address(self, value: str): - self._address = int(value) if isinstance(value, str) else value + def address(self, value: int): + """ + @brief Sets the memory address. + + @param value The new memory address (string or integer). + """ + self._address = value def to_line(self) -> str: """ - Retrieves the string form of the instruction to write to the instruction file. + @brief Retrieves the string form of the instruction to write to the instruction file. - Returns: - str: The string representation of the instruction. + @return The string representation of the instruction. """ return ", ".join(self.tokens) diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dkeygen.py b/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dkeygen.py index 70866265..3980fdf2 100644 --- a/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dkeygen.py +++ b/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dkeygen.py @@ -5,7 +5,7 @@ # generative artificial intelligence solutions """ -This module implements the DKeyGen instruction for key generation operations. +@brief This module implements the DKeyGen instruction for key generation operations. """ from assembler.memory_model.mem_info import MemInfo @@ -15,25 +15,23 @@ class Instruction(DInstruction): """ - Encapsulates a `dkeygen` DInstruction. + @brief Encapsulates a `dkeygen` DInstruction. """ @classmethod def _get_num_tokens(cls) -> int: """ - Gets the number of tokens required for the instruction. + @brief Gets the number of tokens required for the instruction. - Returns: - int: The number of tokens, which is 4. + @return The number of tokens, which is 4. """ return 4 @classmethod def _get_name(cls) -> str: """ - Gets the name of the instruction. + @brief Gets the name of the instruction. - Returns: - str: The name of the instruction. + @return The name of the instruction. """ return MemInfo.Const.Keyword.KEYGEN diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dload.py b/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dload.py index 63322a46..e0902735 100644 --- a/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dload.py +++ b/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dload.py @@ -5,7 +5,7 @@ # generative artificial intelligence solutions """ -This module implements the DLoad instruction for loading data from memory. +@brief This module implements the DLoad instruction for loading data from memory. The DLoad instruction is used to load data from specified memory locations during the assembly process. @@ -17,35 +17,32 @@ class Instruction(DInstruction): """ - Encapsulates a `dload` DInstruction. + @brief Encapsulates a `dload` DInstruction. """ @classmethod def _get_num_tokens(cls) -> int: """ - Gets the number of tokens required for the instruction. + @brief Gets the number of tokens required for the instruction. - Returns: - int: The number of tokens, which is 3. + @return The number of tokens, which is 3. """ return 3 @classmethod def _get_name(cls) -> str: """ - Gets the name of the instruction. + @brief Gets the name of the instruction. - Returns: - str: The name of the instruction. + @return The name of the instruction. """ return MemInfo.Const.Keyword.LOAD @property def tokens(self) -> list: """ - Gets the list of tokens for the instruction. + @brief Gets the list of tokens for the instruction. - Returns: - list: The list of tokens. + @return The list of tokens. """ return [self.name, self._tokens[1], str(self.address)] + self._tokens[3:] diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dstore.py b/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dstore.py index a759f971..61d99546 100644 --- a/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dstore.py +++ b/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dstore.py @@ -5,7 +5,7 @@ # generative artificial intelligence solutions """ -This module implements the DStore instruction for storing data to memory. +@brief This module implements the DStore instruction for storing data to memory. The DStore instruction is used to store data to specified memory locations during the assembly process. @@ -17,35 +17,32 @@ class Instruction(DInstruction): """ - Encapsulates a `dstore` DInstruction. + @brief Encapsulates a `dstore` DInstruction. """ @classmethod def _get_num_tokens(cls) -> int: """ - Gets the number of tokens required for the instruction. + @brief Gets the number of tokens required for the instruction. - Returns: - int: The number of tokens, which is 3. + @return The number of tokens, which is 3. """ return 3 @classmethod def _get_name(cls) -> str: """ - Gets the name of the instruction. + @brief Gets the name of the instruction. - Returns: - str: The name of the instruction. + @return The name of the instruction. """ return MemInfo.Const.Keyword.STORE @property def tokens(self) -> list: """ - Gets the list of tokens for the instruction. + @brief Gets the list of tokens for the instruction. - Returns: - list: The list of tokens. + @return The list of tokens. """ return [self.name, self.var, str(self.address)] + self._tokens[3:] diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/instruction.py b/assembler_tools/hec-assembler-tools/linker/instructions/instruction.py index fcc10be8..76d9a22e 100644 --- a/assembler_tools/hec-assembler-tools/linker/instructions/instruction.py +++ b/assembler_tools/hec-assembler-tools/linker/instructions/instruction.py @@ -5,7 +5,7 @@ # generative artificial intelligence solutions """ -Base class for all instructions in the linker. +@brief Base class for all instructions in the linker. """ from assembler.common.decorators import classproperty @@ -15,23 +15,17 @@ class BaseInstruction: """ - Base class for all instructions. + @brief Base class for all instructions. This class provides common functionality for all instructions in the linker. - Class Properties: - name (str): Returns the name of the represented operation. + @var comment Comment for the instruction. - Attributes: - comment (str): Comment for the instruction. + @property name Returns the name of the represented operation. + @property tokens List of tokens for the instruction. + @property id Unique instruction ID. This is a unique nonce representing the instruction. - Properties: - tokens (list[str]): List of tokens for the instruction. - id (int): Unique instruction ID. This is a unique nonce representing the instruction. - - Methods: - to_line(self) -> str: - Retrieves the string form of the instruction to write to the instruction file. + @fn to_line Retrieves the string form of the instruction to write to the instruction file. """ __id_count = Counter.count( @@ -44,78 +38,71 @@ class BaseInstruction: @classproperty def name(self) -> str: """ - Name for the instruction. + @brief Name for the instruction. - Returns: - str: The name of the instruction. + @return The name of the instruction. """ return self._get_name() @classmethod def _get_name(cls) -> str: """ - Derived classes should implement this method and return correct + @brief Derived classes should implement this method and return correct name for the instruction. - Raises: - NotImplementedError: Abstract method. This base method should not be called. + @throws NotImplementedError Abstract method. This base method should not be called. """ raise NotImplementedError() @classproperty def name_token_index(self) -> int: """ - Index for the token containing the name of the instruction + @brief Index for the token containing the name of the instruction in the list of tokens. - Returns: - int: The index of the name token. + @return The index of the name token. """ return self._get_name_token_index() @classmethod def _get_name_token_index(cls) -> int: """ - Derived classes should implement this method and return correct + @brief Derived classes should implement this method and return correct index for the token containing the name of the instruction in the list of tokens. - Raises: - NotImplementedError: Abstract method. This base method should not be called. + @throws NotImplementedError Abstract method. This base method should not be called. """ raise NotImplementedError() @classproperty def num_tokens(self) -> int: """ - Number of tokens required for this instruction. + @brief Number of tokens required for this instruction. - Returns: - int: The number of tokens required. + @return The number of tokens required. """ return self._get_num_tokens() @classmethod def _get_num_tokens(cls) -> int: """ - Derived classes should implement this method and return correct + @brief Derived classes should implement this method and return correct required number of tokens for the instruction. - Raises: - NotImplementedError: Abstract method. This base method should not be called. + @throws NotImplementedError Abstract method. This base method should not be called. """ raise NotImplementedError() @classmethod def dump_instructions_to_file(cls, instructions: list, filename: str): """ - Writes a list of instruction objects to a file, one per line. + @brief Writes a list of instruction objects to a file, one per line. Each instruction is converted to its string representation using the `to_line()` method. - Args: - instructions (list): List of instruction objects (must have a to_line() method). - filename (str): Path to the output file. + @param instructions List of instruction objects (must have a to_line() method). + @param filename Path to the output file. """ with open(filename, "w", encoding="utf-8") as f: for instr in instructions: @@ -126,14 +113,11 @@ def dump_instructions_to_file(cls, instructions: list, filename: str): def __init__(self, tokens: list, comment: str = ""): """ - Creates a new BaseInstruction object. - - Parameters: - tokens (list): List of tokens for the instruction. - comment (str): Optional comment for the instruction. + @brief Creates a new BaseInstruction object. - Raises: - ValueError: If the number of tokens is invalid or the instruction name is incorrect. + @param tokens List of tokens for the instruction. + @param comment Optional comment for the instruction. + @throws ValueError If the number of tokens is invalid or the instruction name is incorrect. """ assert self.name_token_index < self.num_tokens @@ -146,16 +130,13 @@ def __init__(self, tokens: list, comment: str = ""): def _validate_tokens(self, tokens: list) -> None: """ - Validates the tokens for this instruction. + @brief Validates the tokens for this instruction. Default implementation checks for exact token count match. Child classes can override this method to implement different validation logic. - Parameters: - tokens (list): List of tokens to validate. - - Raises: - ValueError: If tokens are invalid. + @param tokens List of tokens to validate. + @throws ValueError If tokens are invalid. """ if len(tokens) != self.num_tokens: # pylint: disable=W0143 raise ValueError( @@ -188,31 +169,28 @@ def __str__(self): @property def id(self) -> tuple: """ - Unique ID for the instruction. + @brief Unique ID for the instruction. This is a combination of the client ID specified during construction and a unique nonce per instruction. - Returns: - tuple: (client_id: int, nonce: int) where client_id is the id specified at construction. + @return (client_id: int, nonce: int) where client_id is the id specified at construction. """ return self._id @property def tokens(self) -> list: """ - Gets the list of tokens for the instruction. + @brief Gets the list of tokens for the instruction. - Returns: - list: The list of tokens. + @return The list of tokens. """ return self._tokens def to_line(self) -> str: """ - Retrieves the string form of the instruction to write to the instruction file. + @brief Retrieves the string form of the instruction to write to the instruction file. - Returns: - str: The string representation of the instruction. + @return The string representation of the instruction. """ comment_str = "" if not GlobalConfig.suppress_comments: diff --git a/assembler_tools/hec-assembler-tools/linker/loader.py b/assembler_tools/hec-assembler-tools/linker/loader.py index ea20f6c1..eeee007e 100644 --- a/assembler_tools/hec-assembler-tools/linker/loader.py +++ b/assembler_tools/hec-assembler-tools/linker/loader.py @@ -5,7 +5,7 @@ # or more Intel-operated generative artificial intelligence solutions """ -This module provides functionality to load different types of instruction kernels +@brief This module provides functionality to load different types of instruction kernels """ from linker.instructions import minst @@ -17,16 +17,11 @@ def load_minst_kernel(line_iter) -> list: """ - Loads MInstruction kernel from an iterator of lines. + @brief Loads MInstruction kernel from an iterator of lines. - Parameters: - line_iter: An iterator over lines of MInstruction strings. - - Returns: - list: A list of MInstruction objects. - - Raises: - RuntimeError: If a line cannot be parsed into an MInstruction. + @param line_iter An iterator over lines of MInstruction strings. + @return A list of MInstruction objects. + @throws RuntimeError If a line cannot be parsed into an MInstruction. """ retval = [] for idx, s_line in enumerate(line_iter): @@ -39,16 +34,11 @@ def load_minst_kernel(line_iter) -> list: def load_minst_kernel_from_file(filename: str) -> list: """ - Loads MInstruction kernel from a file. - - Parameters: - filename (str): The file containing MInstruction strings. - - Returns: - list: A list of MInstruction objects. + @brief Loads MInstruction kernel from a file. - Raises: - RuntimeError: If an error occurs while loading the file. + @param filename The file containing MInstruction strings. + @return A list of MInstruction objects. + @throws RuntimeError If an error occurs while loading the file. """ with open(filename, "r", encoding="utf-8") as kernel_minsts: try: @@ -59,16 +49,11 @@ def load_minst_kernel_from_file(filename: str) -> list: def load_cinst_kernel(line_iter) -> list: """ - Loads CInstruction kernel from an iterator of lines. + @brief Loads CInstruction kernel from an iterator of lines. - Parameters: - line_iter: An iterator over lines of CInstruction strings. - - Returns: - list: A list of CInstruction objects. - - Raises: - RuntimeError: If a line cannot be parsed into a CInstruction. + @param line_iter An iterator over lines of CInstruction strings. + @return A list of CInstruction objects. + @throws RuntimeError If a line cannot be parsed into a CInstruction. """ retval = [] for idx, s_line in enumerate(line_iter): @@ -81,16 +66,11 @@ def load_cinst_kernel(line_iter) -> list: def load_cinst_kernel_from_file(filename: str) -> list: """ - Loads CInstruction kernel from a file. - - Parameters: - filename (str): The file containing CInstruction strings. - - Returns: - list: A list of CInstruction objects. + @brief Loads CInstruction kernel from a file. - Raises: - RuntimeError: If an error occurs while loading the file. + @param filename The file containing CInstruction strings. + @return A list of CInstruction objects. + @throws RuntimeError If an error occurs while loading the file. """ with open(filename, "r", encoding="utf-8") as kernel_cinsts: try: @@ -101,16 +81,11 @@ def load_cinst_kernel_from_file(filename: str) -> list: def load_xinst_kernel(line_iter) -> list: """ - Loads XInstruction kernel from an iterator of lines. + @brief Loads XInstruction kernel from an iterator of lines. - Parameters: - line_iter: An iterator over lines of XInstruction strings. - - Returns: - list: A list of XInstruction objects. - - Raises: - RuntimeError: If a line cannot be parsed into an XInstruction. + @param line_iter An iterator over lines of XInstruction strings. + @return A list of XInstruction objects. + @throws RuntimeError If a line cannot be parsed into an XInstruction. """ retval = [] for idx, s_line in enumerate(line_iter): @@ -123,16 +98,11 @@ def load_xinst_kernel(line_iter) -> list: def load_xinst_kernel_from_file(filename: str) -> list: """ - Loads XInstruction kernel from a file. - - Parameters: - filename (str): The file containing XInstruction strings. - - Returns: - list: A list of XInstruction objects. + @brief Loads XInstruction kernel from a file. - Raises: - RuntimeError: If an error occurs while loading the file. + @param filename The file containing XInstruction strings. + @return A list of XInstruction objects. + @throws RuntimeError If an error occurs while loading the file. """ with open(filename, "r", encoding="utf-8") as kernel_xinsts: try: @@ -143,16 +113,11 @@ def load_xinst_kernel_from_file(filename: str) -> list: def load_dinst_kernel(line_iter) -> list: """ - Loads DInstruction kernel from an iterator of lines. + @brief Loads DInstruction kernel from an iterator of lines. - Parameters: - line_iter: An iterator over lines of DInstruction strings. - - Returns: - list: A list of DInstruction objects. - - Raises: - RuntimeError: If a line cannot be parsed into an DInstruction. + @param line_iter An iterator over lines of DInstruction strings. + @return A list of DInstruction objects. + @throws RuntimeError If a line cannot be parsed into an DInstruction. """ retval = [] for idx, s_line in enumerate(line_iter): @@ -166,16 +131,11 @@ def load_dinst_kernel(line_iter) -> list: def load_dinst_kernel_from_file(filename: str) -> list: """ - Loads DInstruction kernel from a file. - - Parameters: - filename (str): The file containing DInstruction strings. - - Returns: - list: A list of DInstruction objects. + @brief Loads DInstruction kernel from a file. - Raises: - RuntimeError: If an error occurs while loading the file. + @param filename The file containing DInstruction strings. + @return A list of DInstruction objects. + @throws RuntimeError If an error occurs while loading the file. """ with open(filename, "r", encoding="utf-8") as kernel_dinsts: try: diff --git a/assembler_tools/hec-assembler-tools/linker/steps/program_linker.py b/assembler_tools/hec-assembler-tools/linker/steps/program_linker.py index 9ed8603d..092a86dd 100644 --- a/assembler_tools/hec-assembler-tools/linker/steps/program_linker.py +++ b/assembler_tools/hec-assembler-tools/linker/steps/program_linker.py @@ -4,8 +4,9 @@ # These contents may have been developed with support from one or more Intel-operated # generative artificial intelligence solutions -"""This module provides functionality to link kernels into a program.""" +"""@brief This module provides functionality to link kernels into a program.""" +from typing import Dict, Any, cast from linker import MemoryModel from linker.instructions import minst, cinst, dinst from linker.instructions.dinst.dinstruction import DInstruction @@ -116,11 +117,13 @@ def _validate_hbm_address(self, var_name: str, hbm_address: int): f'Invalid negative HBM address for variable "{var_name}".' ) if var_name in self.__mem_model.mem_info_vars: - if self.__mem_model.mem_info_vars[var_name].hbm_address != hbm_address: + # Cast to dictionary to fix the indexing error + mem_info_vars_dict = cast(Dict[str, Any], self.__mem_model.mem_info_vars) + if mem_info_vars_dict[var_name].hbm_address != hbm_address: raise RuntimeError( ( f"Declared HBM address " - f"({self.__mem_model.mem_info_vars[var_name].hbm_address})" + f"({mem_info_vars_dict[var_name].hbm_address})" f" of mem Variable '{var_name}'" f" differs from allocated HBM address ({hbm_address})." ) @@ -145,11 +148,13 @@ def _validate_spad_address(self, var_name: str, spad_address: int): f'Invalid negative SPAD address for variable "{var_name}".' ) if var_name in self.__mem_model.mem_info_vars: - if self.__mem_model.mem_info_vars[var_name].hbm_address != spad_address: + # Cast to dictionary to fix the indexing error + mem_info_vars_dict = cast(Dict[str, Any], self.__mem_model.mem_info_vars) + if mem_info_vars_dict[var_name].hbm_address != spad_address: raise RuntimeError( ( f"Declared HBM address" - f" ({self.__mem_model.mem_info_vars[var_name].hbm_address})" + f" ({mem_info_vars_dict[var_name].hbm_address})" f" of mem Variable '{var_name}'" f" differs from allocated HBM address ({spad_address})." ) diff --git a/assembler_tools/hec-assembler-tools/linker/steps/variable_discovery.py b/assembler_tools/hec-assembler-tools/linker/steps/variable_discovery.py index 69457786..c862de25 100644 --- a/assembler_tools/hec-assembler-tools/linker/steps/variable_discovery.py +++ b/assembler_tools/hec-assembler-tools/linker/steps/variable_discovery.py @@ -1,26 +1,30 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +""" +@brief This module provides functionality to discover variable names in MInstructions and CInstructions. +""" from assembler.memory_model.variable import Variable from linker.instructions import minst, cinst from linker.instructions.minst.minstruction import MInstruction from linker.instructions.cinst.cinstruction import CInstruction -def discoverVariablesSPAD(cinstrs: list): + +def discover_variables_spad(cinstrs: list): """ - Finds Variable names used in a list of CInstructions. - - Attributes: - cinstrs (list[CInstruction]): - List of CInstructions where to find variable names. - Raises: - RuntimeError: - Invalid Variable name detected in an CInstruction. - Returns: - Iterable: - Yields an iterable over variable names identified in the listing + @brief Finds Variable names used in a list of CInstructions. + + @param cinstrs List of CInstructions where to find variable names. + @throws TypeError If an item in the list is not a valid CInstruction. + @throws RuntimeError If an invalid Variable name is detected in a CInstruction. + @return Yields an iterable over variable names identified in the listing of CInstructions specified. """ for idx, cinstr in enumerate(cinstrs): if not isinstance(cinstr, CInstruction): - raise TypeError(f'Item {idx} in list of MInstructions is not a valid MInstruction.') + raise TypeError( + f"Item {idx} in list of MInstructions is not a valid MInstruction." + ) retval = None # TODO: Implement variable counting for CInst ############### @@ -32,27 +36,27 @@ def discoverVariablesSPAD(cinstrs: list): if retval is not None: if not Variable.validateName(retval): - raise RuntimeError(f'Invalid Variable name "{retval}" detected in instruction "{idx}, {cinstr.to_line()}"') + raise RuntimeError( + f'Invalid Variable name "{retval}" detected in instruction "{idx}, {cinstr.to_line()}"' + ) yield retval -def discoverVariables(minstrs: list): - """ - Finds variable names used in a list of MInstructions. - Parameters: - minstrs (list[MInstruction]): List of MInstructions where to find variable names. - - Raises: - TypeError: If an item in the list is not a valid MInstruction. - RuntimeError: If an invalid variable name is detected in an MInstruction. +def discover_variables(minstrs: list): + """ + @brief Finds variable names used in a list of MInstructions. - Returns: - Iterable: Yields an iterable over variable names identified in the listing - of MInstructions specified. + @param minstrs List of MInstructions where to find variable names. + @throws TypeError If an item in the list is not a valid MInstruction. + @throws RuntimeError If an invalid variable name is detected in an MInstruction. + @return Yields an iterable over variable names identified in the listing + of MInstructions specified. """ for idx, minstr in enumerate(minstrs): if not isinstance(minstr, MInstruction): - raise TypeError(f'Item {idx} in list of MInstructions is not a valid MInstruction.') + raise TypeError( + f"Item {idx} in list of MInstructions is not a valid MInstruction." + ) retval = None if isinstance(minstr, minst.MLoad): retval = minstr.source @@ -61,5 +65,7 @@ def discoverVariables(minstrs: list): if retval is not None: if not Variable.validateName(retval): - raise RuntimeError(f'Invalid Variable name "{retval}" detected in instruction "{idx}, {minstr.to_line()}"') - yield retval \ No newline at end of file + raise RuntimeError( + f'Invalid Variable name "{retval}" detected in instruction "{idx}, {minstr.to_line()}"' + ) + yield retval diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_assembler/memory_model/test_mem_info.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_assembler/memory_model/test_mem_info.py index 60ba2d44..6858b858 100644 --- a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_assembler/memory_model/test_mem_info.py +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_assembler/memory_model/test_mem_info.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 """ -Unit tests for the memory model mem_info module. +@brief Unit tests for the memory model mem_info module. """ import unittest @@ -19,10 +19,10 @@ class TestMemInfoVariable(unittest.TestCase): - """Tests for the MemInfoVariable class.""" + """@brief Tests for the MemInfoVariable class.""" def test_init_valid(self): - """Test initialization with valid parameters.""" + """@brief Test initialization with valid parameters.""" with patch( "assembler.memory_model.variable.Variable.validateName", return_value=True ): @@ -31,7 +31,7 @@ def test_init_valid(self): self.assertEqual(var.hbm_address, 42) def test_init_strips_whitespace(self): - """Test that initialization strips whitespace from variable name.""" + """@brief Test that initialization strips whitespace from variable name.""" with patch( "assembler.memory_model.variable.Variable.validateName", return_value=True ): @@ -39,7 +39,7 @@ def test_init_strips_whitespace(self): self.assertEqual(var.var_name, "test_var") def test_init_invalid_name(self): - """Test initialization with invalid variable name.""" + """@brief Test initialization with invalid variable name.""" with patch( "assembler.memory_model.variable.Variable.validateName", return_value=False ): @@ -47,7 +47,7 @@ def test_init_invalid_name(self): MemInfoVariable("invalid!var", 42) def test_repr(self): - """Test the __repr__ method.""" + """@brief Test the __repr__ method.""" with patch( "assembler.memory_model.variable.Variable.validateName", return_value=True ): @@ -57,7 +57,7 @@ def test_repr(self): ) def test_as_dict(self): - """Test the as_dict method.""" + """@brief Test the as_dict method.""" with patch( "assembler.memory_model.variable.Variable.validateName", return_value=True ): @@ -66,10 +66,10 @@ def test_as_dict(self): class TestMemInfoKeygenVariable(unittest.TestCase): - """Tests for the MemInfoKeygenVariable class.""" + """@brief Tests for the MemInfoKeygenVariable class.""" def test_init_valid(self): - """Test initialization with valid parameters.""" + """@brief Test initialization with valid parameters.""" with patch( "assembler.memory_model.variable.Variable.validateName", return_value=True ): @@ -80,7 +80,7 @@ def test_init_valid(self): self.assertEqual(var.key_index, 3) def test_init_negative_seed_index(self): - """Test initialization with negative seed index.""" + """@brief Test initialization with negative seed index.""" with patch( "assembler.memory_model.variable.Variable.validateName", return_value=True ): @@ -88,7 +88,7 @@ def test_init_negative_seed_index(self): MemInfoKeygenVariable("test_var", -1, 3) def test_init_negative_key_index(self): - """Test initialization with negative key index.""" + """@brief Test initialization with negative key index.""" with patch( "assembler.memory_model.variable.Variable.validateName", return_value=True ): @@ -96,7 +96,7 @@ def test_init_negative_key_index(self): MemInfoKeygenVariable("test_var", 2, -1) def test_as_dict(self): - """Test the as_dict method.""" + """@brief Test the as_dict method.""" with patch( "assembler.memory_model.variable.Variable.validateName", return_value=True ): @@ -107,10 +107,10 @@ def test_as_dict(self): class TestMemInfoMetadata(unittest.TestCase): - """Tests for the MemInfo.Metadata class.""" + """@brief Tests for the MemInfo.Metadata class.""" def test_parse_meta_field_from_mem_tokens_valid(self): - """Test parsing a valid metadata field.""" + """@brief Test parsing a valid metadata field.""" tokens = ["dload", "LOAD_ONES", "42", "ones_var"] result = MemInfo.Metadata.parse_meta_field_from_mem_tokens( tokens, "LOAD_ONES", var_prefix="ONES" @@ -120,7 +120,7 @@ def test_parse_meta_field_from_mem_tokens_valid(self): self.assertEqual(result.hbm_address, 42) def test_parse_meta_field_from_mem_tokens_no_name(self): - """Test parsing a metadata field without explicit name.""" + """@brief Test parsing a metadata field without explicit name.""" tokens = ["dload", "LOAD_ONES", "42"] result = MemInfo.Metadata.parse_meta_field_from_mem_tokens( tokens, "LOAD_ONES", var_prefix="ONES" @@ -130,7 +130,7 @@ def test_parse_meta_field_from_mem_tokens_no_name(self): self.assertEqual(result.hbm_address, 42) def test_parse_meta_field_from_mem_tokens_with_extra(self): - """Test parsing a metadata field with var_extra.""" + """@brief Test parsing a metadata field with var_extra.""" tokens = ["dload", "LOAD_ONES", "42"] result = MemInfo.Metadata.parse_meta_field_from_mem_tokens( tokens, "LOAD_ONES", var_prefix="ONES", var_extra="_extra" @@ -140,7 +140,7 @@ def test_parse_meta_field_from_mem_tokens_with_extra(self): self.assertEqual(result.hbm_address, 42) def test_parse_meta_field_from_mem_tokens_invalid(self): - """Test parsing an invalid metadata field.""" + """@brief Test parsing an invalid metadata field.""" # Not enough tokens tokens = ["dload"] result = MemInfo.Metadata.parse_meta_field_from_mem_tokens( @@ -163,7 +163,7 @@ def test_parse_meta_field_from_mem_tokens_invalid(self): self.assertIsNone(result) def test_metadata_init_and_properties(self): - """Test initialization and properties of Metadata class.""" + """@brief Test initialization and properties of Metadata class.""" # Prepare test data metadata_dict = { "ones": [{"var_name": "ones_var", "hbm_address": 1}], @@ -201,7 +201,7 @@ def test_metadata_init_and_properties(self): self.assertEqual(metadata.keygen_seeds[0].var_name, "keygen_seed") def test_get_item(self): - """Test the __getitem__ method.""" + """@brief Test the __getitem__ method.""" metadata_dict = {"ones": [{"var_name": "ones_var", "hbm_address": 1}]} metadata = MemInfo.Metadata(**metadata_dict) @@ -212,10 +212,10 @@ def test_get_item(self): class TestMemInfoParsers(unittest.TestCase): - """Tests for the various parser methods in MemInfo.""" + """@brief Tests for the various parser methods in MemInfo.""" def test_ones_parse_from_mem_tokens(self): - """Test parsing Ones metadata from tokens.""" + """@brief Test parsing Ones metadata from tokens.""" tokens = ["dload", "LOAD_ONES", "42", "ones_var"] with patch( "assembler.memory_model.mem_info.MemInfo.Metadata.parse_meta_field_from_mem_tokens" @@ -230,7 +230,7 @@ def test_ones_parse_from_mem_tokens(self): self.assertEqual(result, mock_parse.return_value) def test_ntt_aux_table_parse_from_mem_tokens(self): - """Test parsing NTTAuxTable metadata from tokens.""" + """@brief Test parsing NTTAuxTable metadata from tokens.""" tokens = ["dload", "LOAD_NTT_AUX_TABLE", "42", "ntt_aux_var"] with patch( "assembler.memory_model.mem_info.MemInfo.Metadata.parse_meta_field_from_mem_tokens" @@ -245,7 +245,7 @@ def test_ntt_aux_table_parse_from_mem_tokens(self): self.assertEqual(result, mock_parse.return_value) def test_ntt_routing_table_parse_from_mem_tokens(self): - """Test parsing NTTRoutingTable metadata from tokens.""" + """@brief Test parsing NTTRoutingTable metadata from tokens.""" tokens = ["dload", "LOAD_NTT_ROUTING_TABLE", "42", "ntt_route_var"] with patch( "assembler.memory_model.mem_info.MemInfo.Metadata.parse_meta_field_from_mem_tokens" @@ -260,7 +260,7 @@ def test_ntt_routing_table_parse_from_mem_tokens(self): self.assertEqual(result, mock_parse.return_value) def test_intt_aux_table_parse_from_mem_tokens(self): - """Test parsing iNTTAuxTable metadata from tokens.""" + """@brief Test parsing iNTTAuxTable metadata from tokens.""" tokens = ["dload", "LOAD_iNTT_AUX_TABLE", "42", "intt_aux_var"] with patch( "assembler.memory_model.mem_info.MemInfo.Metadata.parse_meta_field_from_mem_tokens" @@ -275,7 +275,7 @@ def test_intt_aux_table_parse_from_mem_tokens(self): self.assertEqual(result, mock_parse.return_value) def test_intt_routing_table_parse_from_mem_tokens(self): - """Test parsing iNTTRoutingTable metadata from tokens.""" + """@brief Test parsing iNTTRoutingTable metadata from tokens.""" tokens = ["dload", "LOAD_iNTT_ROUTING_TABLE", "42", "intt_route_var"] with patch( "assembler.memory_model.mem_info.MemInfo.Metadata.parse_meta_field_from_mem_tokens" @@ -290,7 +290,7 @@ def test_intt_routing_table_parse_from_mem_tokens(self): self.assertEqual(result, mock_parse.return_value) def test_twiddle_parse_from_mem_tokens(self): - """Test parsing Twiddle metadata from tokens.""" + """@brief Test parsing Twiddle metadata from tokens.""" tokens = ["dload", "LOAD_TWIDDLE", "42", "twiddle_var"] with patch( "assembler.memory_model.mem_info.MemInfo.Metadata.parse_meta_field_from_mem_tokens" @@ -305,7 +305,7 @@ def test_twiddle_parse_from_mem_tokens(self): self.assertEqual(result, mock_parse.return_value) def test_keygen_seed_parse_from_mem_tokens(self): - """Test parsing KeygenSeed metadata from tokens.""" + """@brief Test parsing KeygenSeed metadata from tokens.""" tokens = ["dload", "LOAD_KEYGEN_SEED", "42", "keygen_seed_var"] with patch( "assembler.memory_model.mem_info.MemInfo.Metadata.parse_meta_field_from_mem_tokens" @@ -320,7 +320,7 @@ def test_keygen_seed_parse_from_mem_tokens(self): self.assertEqual(result, mock_parse.return_value) def test_keygen_parse_from_mem_tokens_valid(self): - """Test parsing a valid keygen variable.""" + """@brief Test parsing a valid keygen variable.""" tokens = ["keygen", "2", "3", "keygen_var"] result = MemInfo.Keygen.parse_from_mem_tokens(tokens) self.assertIsNotNone(result) @@ -329,7 +329,7 @@ def test_keygen_parse_from_mem_tokens_valid(self): self.assertEqual(result.key_index, 3) def test_keygen_parse_from_mem_tokens_invalid(self): - """Test parsing an invalid keygen variable.""" + """@brief Test parsing an invalid keygen variable.""" # Not enough tokens tokens = ["keygen", "2", "3"] result = MemInfo.Keygen.parse_from_mem_tokens(tokens) @@ -341,7 +341,7 @@ def test_keygen_parse_from_mem_tokens_invalid(self): self.assertIsNone(result) def test_input_parse_from_mem_tokens_valid(self): - """Test parsing a valid input variable.""" + """@brief Test parsing a valid input variable.""" tokens = ["dload", "poly", "42", "input_var"] with patch( "assembler.memory_model.variable.Variable.validateName", return_value=True @@ -352,7 +352,7 @@ def test_input_parse_from_mem_tokens_valid(self): self.assertEqual(result.hbm_address, 42) def test_input_parse_from_mem_tokens_invalid(self): - """Test parsing an invalid input variable.""" + """@brief Test parsing an invalid input variable.""" # Not enough tokens tokens = ["dload", "poly", "42"] result = MemInfo.Input.parse_from_mem_tokens(tokens) @@ -368,7 +368,7 @@ def test_input_parse_from_mem_tokens_invalid(self): self.assertIsNone(result) def test_output_parse_from_mem_tokens_valid(self): - """Test parsing a valid output variable.""" + """@brief Test parsing a valid output variable.""" tokens = ["dstore", "output_var", "42"] with patch( "assembler.memory_model.variable.Variable.validateName", return_value=True @@ -379,7 +379,7 @@ def test_output_parse_from_mem_tokens_valid(self): self.assertEqual(result.hbm_address, 42) def test_output_parse_from_mem_tokens_invalid(self): - """Test parsing an invalid output variable.""" + """@brief Test parsing an invalid output variable.""" # Not enough tokens tokens = ["store", "output_var"] result = MemInfo.Output.parse_from_mem_tokens(tokens) @@ -392,10 +392,10 @@ def test_output_parse_from_mem_tokens_invalid(self): class TestMemInfo(unittest.TestCase): - """Tests for the MemInfo class.""" + """@brief Tests for the MemInfo class.""" def test_init_default(self): - """Test default initialization.""" + """@brief Test default initialization.""" mem_info = MemInfo() self.assertEqual(len(mem_info.keygens), 0) self.assertEqual(len(mem_info.inputs), 0) @@ -404,7 +404,7 @@ def test_init_default(self): self.assertIsInstance(mem_info.metadata, MemInfo.Metadata) def test_init_with_data(self): - """Test initialization with data.""" + """@brief Test initialization with data.""" # Prepare test data test_data = { "keygens": [{"var_name": "keygen_var", "seed_index": 1, "key_index": 2}], @@ -438,7 +438,7 @@ def test_init_with_data(self): self.assertEqual(mem_info.metadata.twiddle[0].var_name, "twiddle_var") def test_factory_dict(self): - """Test the factory_dict property.""" + """@brief Test the factory_dict property.""" mem_info = MemInfo() factory_dict = mem_info.factory_dict @@ -460,7 +460,7 @@ def test_factory_dict(self): self.assertEqual(factory_dict[MemInfo.Output], mem_info.outputs) def test_mem_info_types(self): - """Test the mem_info_types class property.""" + """@brief Test the mem_info_types class property.""" mem_info_types = MemInfo.mem_info_types # Verify expected types are in the list @@ -476,7 +476,7 @@ def test_mem_info_types(self): self.assertIn(MemInfo.Metadata.Twiddle, mem_info_types) def test_get_meminfo_var_from_tokens_valid(self): - """Test getting a MemInfo variable from valid tokens.""" + """@brief Test getting a MemInfo variable from valid tokens.""" tokens = ["keygen", "2", "3", "keygen_var"] # Mock the parse_from_mem_tokens method to return a mock variable @@ -491,7 +491,7 @@ def test_get_meminfo_var_from_tokens_valid(self): self.assertEqual(var_type, MemInfo.Keygen) def test_get_meminfo_var_from_tokens_not_found(self): - """Test getting a MemInfo variable when no parser can handle it.""" + """@brief Test getting a MemInfo variable when no parser can handle it.""" tokens = ["unknown", "token"] # Mock all parse_from_mem_tokens methods to return None @@ -509,7 +509,7 @@ def test_get_meminfo_var_from_tokens_not_found(self): self.assertIsNone(var_type) def test_add_meminfo_var_from_tokens_valid(self): - """Test adding a MemInfo variable from valid tokens.""" + """@brief Test adding a MemInfo variable from valid tokens.""" tokens = ["keygen", "2", "3", "keygen_var"] mem_info = MemInfo() @@ -534,7 +534,7 @@ def test_add_meminfo_var_from_tokens_valid(self): mock_list.append.assert_called_once_with(mock_variable) def test_add_meminfo_var_from_tokens_not_found(self): - """Test adding a MemInfo variable when no parser can handle it.""" + """@brief Test adding a MemInfo variable when no parser can handle it.""" tokens = ["unknown", "token"] mem_info = MemInfo() @@ -547,7 +547,7 @@ def test_add_meminfo_var_from_tokens_not_found(self): mem_info.add_meminfo_var_from_tokens(tokens) def test_from_file_iter_valid(self): - """Test creating a MemInfo from a valid file iterator.""" + """@brief Test creating a MemInfo from a valid file iterator.""" # Mock lines lines = [ "keygen, 2, 3, keygen_var", @@ -588,7 +588,7 @@ def mock_tokenize(line): self.assertEqual(mock_add_var.call_count, 4) def test_from_file_iter_error(self): - """Test creating a MemInfo when an error occurs.""" + """@brief Test creating a MemInfo when an error occurs.""" # Mock lines lines = ["invalid line"] @@ -614,7 +614,7 @@ def mock_tokenize(line): self.assertIn("1: invalid line", str(context.exception)) def test_from_dinstrs_valid(self): - """Test creating a MemInfo from valid DInstructions.""" + """@brief Test creating a MemInfo from valid DInstructions.""" # Mock DInstructions dinstrs = [ MagicMock(tokens=["keygen", "2", "3", "keygen_var"]), @@ -643,7 +643,7 @@ def test_from_dinstrs_valid(self): ) def test_from_dinstrs_error(self): - """Test creating a MemInfo when an error occurs.""" + """@brief Test creating a MemInfo when an error occurs.""" # Mock DInstructions dinstrs = [MagicMock(tokens=["invalid"])] @@ -663,7 +663,7 @@ def test_from_dinstrs_error(self): self.assertIn("1: ['invalid']", str(context.exception)) def test_as_dict(self): - """Test the as_dict method.""" + """@brief Test the as_dict method.""" # Create a MemInfo with test data with patch("assembler.memory_model.mem_info.MemInfo.validate"): @@ -705,7 +705,7 @@ def test_as_dict(self): self.assertEqual(result["metadata"]["ones"], [ones_dict]) def test_validate_valid(self): - """Test validation with valid data.""" + """@brief Test validation with valid data.""" ones_dict = {"var_name": "ones_var", "hbm_address": 44} twiddle_dict = {"var_name": "twiddle_var", "hbm_address": 45} @@ -727,7 +727,7 @@ def test_validate_valid(self): mem_info.validate() # Should not raise any exceptions def test_validate_twiddle_mismatch(self): - """Test validation with mismatched twiddle count.""" + """@brief Test validation with mismatched twiddle count.""" ones_dict = {"var_name": "ones_var", "hbm_address": 44} twiddle_dict = {"var_name": "twiddle_var", "hbm_address": 45} @@ -752,7 +752,7 @@ def test_validate_twiddle_mismatch(self): ) def test_validate_duplicate_var_name(self): - """Test validation with duplicate variable names but different HBM addresses.""" + """@brief Test validation with duplicate variable names but different HBM addresses.""" # Create variable dictionaries with duplicate names but different addresses intt_aux_dict = {"var_name": "duplicate", "hbm_address": 1} ntt_route_dict = {"var_name": "duplicate", "hbm_address": 2} @@ -777,10 +777,10 @@ def test_validate_duplicate_var_name(self): class TestUpdateMemoryModelWithMemInfo(unittest.TestCase): - """Tests for the updateMemoryModelWithMemInfo function.""" + """@brief Tests for the updateMemoryModelWithMemInfo function.""" def setUp(self): - """Set up common test fixtures.""" + """@brief Set up common test fixtures.""" # Create mock MemoryModel self.mock_mem_model = MagicMock() self.mock_mem_model.retrieveVarAdd = MagicMock() @@ -818,7 +818,7 @@ def setUp(self): self.mock_mem_info.metadata = self.mock_metadata def test_update_memory_model_inputs(self): - """Test updating memory model with input variables.""" + """@brief Test updating memory model with input variables.""" # Call the function with patch( "assembler.memory_model.mem_info._allocateMemInfoVariable" @@ -829,7 +829,7 @@ def test_update_memory_model_inputs(self): mock_allocate.assert_any_call(self.mock_mem_model, self.vars["input"]) def test_update_memory_model_outputs(self): - """Test updating memory model with output variables.""" + """@brief Test updating memory model with output variables.""" # Call the function with patch( "assembler.memory_model.mem_info._allocateMemInfoVariable" @@ -843,7 +843,7 @@ def test_update_memory_model_outputs(self): ) def test_update_memory_model_metadata(self): - """Test updating memory model with metadata variables.""" + """@brief Test updating memory model with metadata variables.""" # Call the function with patch( "assembler.memory_model.mem_info._allocateMemInfoVariable" @@ -891,10 +891,10 @@ def test_update_memory_model_metadata(self): class TestAllocateMemInfoVariable(unittest.TestCase): - """Tests for the _allocateMemInfoVariable function.""" + """@brief Tests for the _allocateMemInfoVariable function.""" def test_allocate_mem_info_variable_success(self): - """Test successful allocation of a MemInfo variable.""" + """@brief Test successful allocation of a MemInfo variable.""" # Create mock MemoryModel and variable mock_mem_model = MagicMock() mock_var_info = MagicMock(var_name="test_var", hbm_address=42) @@ -918,7 +918,7 @@ def test_allocate_mem_info_variable_success(self): ) def test_allocate_mem_info_variable_not_in_model(self): - """Test allocation when the variable is not in the memory model.""" + """@brief Test allocation when the variable is not in the memory model.""" # Create mock MemoryModel and variable mock_mem_model = MagicMock() mock_var_info = MagicMock(var_name="missing_var", hbm_address=42) @@ -943,7 +943,7 @@ def test_allocate_mem_info_variable_not_in_model(self): ) def test_allocate_mem_info_variable_mismatch(self): - """Test allocation when the variable has a different HBM address.""" + """@brief Test allocation when the variable has a different HBM address.""" # Create mock MemoryModel and variable mock_mem_model = MagicMock() mock_var_info = MagicMock(var_name="test_var", hbm_address=42) diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_he_link.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_he_link.py index 36439e21..d857f66a 100644 --- a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_he_link.py +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_he_link.py @@ -369,10 +369,10 @@ def test_scan_variables(self, has_hbm): with patch("linker.loader.load_minst_kernel_from_file", return_value=[]), patch( "linker.loader.load_cinst_kernel_from_file", return_value=[] ), patch( - "linker.steps.variable_discovery.discoverVariables", + "linker.steps.variable_discovery.discover_variables", return_value=["var1", "var2"], ), patch( - "linker.steps.variable_discovery.discoverVariablesSPAD", + "linker.steps.variable_discovery.discover_variables_spad", return_value=["var1", "var2"], ): he_link.scan_variables(input_files, mock_mem_model, mock_verbose) @@ -450,7 +450,7 @@ class TestMainFunction: @pytest.mark.parametrize("multi_mem_files", [True, False]) def test_main(self, multi_mem_files): """ - @brief Test main function with multi_mem_files=True + @brief Test main function with and without multi_mem_files """ # Arrange mock_config = MagicMock() @@ -460,75 +460,93 @@ def test_main(self, multi_mem_files): mock_config.suppress_comments = False mock_config.use_xinstfetch = False - mock_verbose = MagicMock() + # Setup input files with conditional mem files + input_files = [ + he_link.KernelFiles( + prefix="prefix1", + minst="prefix1.minst", + cinst="prefix1.cinst", + xinst="prefix1.xinst", + mem="prefix1.mem" if multi_mem_files else None, + ), + he_link.KernelFiles( + prefix="prefix2", + minst="prefix2.minst", + cinst="prefix2.cinst", + xinst="prefix2.xinst", + mem="prefix2.mem" if multi_mem_files else None, + ), + ] + + # Create a dictionary of mocks to reduce the number of local variables + mocks = { + "prepare_output": MagicMock(), + "prepare_input": MagicMock(return_value=input_files), + "scan_variables": MagicMock(), + "check_unused_variables": MagicMock(), + "link_kernels": MagicMock(), + "from_dinstrs": MagicMock(), + "from_file_iter": MagicMock(), + "load_dinst": MagicMock(return_value=["1", "2"]), + "join_dinst": MagicMock(return_value=[]), + "dump_instructions": MagicMock(), + } # Act with patch( "assembler.common.constants.convertBytes2Words", return_value=1024 - ), patch("he_link.prepare_output_files") as mock_prepare_output, patch( - "he_link.prepare_input_files" - ) as mock_prepare_input, patch( + ), patch("he_link.prepare_output_files", mocks["prepare_output"]), patch( + "he_link.prepare_input_files", mocks["prepare_input"] + ), patch( "assembler.common.counter.Counter.reset" ), patch( - "linker.loader.load_dinst_kernel_from_file", return_value=["1", "2"] - ) as mock_load_dinst_kernel_from_file, patch( - "linker.instructions.BaseInstruction.dump_instructions_to_file" - ) as mock_dump_instructions, patch( + "linker.loader.load_dinst_kernel_from_file", mocks["load_dinst"] + ), patch( + "linker.instructions.BaseInstruction.dump_instructions_to_file", + mocks["dump_instructions"], + ), patch( "linker.steps.program_linker.LinkedProgram.join_dinst_kernels", - return_value=[], - ) as mock_join_dinst_kernels, patch( - "assembler.memory_model.mem_info.MemInfo.from_dinstrs" - ) as mock_from_dinstrs, patch( - "assembler.memory_model.mem_info.MemInfo.from_file_iter" - ) as mock_from_file_iter, patch( + mocks["join_dinst"], + ), patch( + "assembler.memory_model.mem_info.MemInfo.from_dinstrs", + mocks["from_dinstrs"], + ), patch( + "assembler.memory_model.mem_info.MemInfo.from_file_iter", + mocks["from_file_iter"], + ), patch( "linker.MemoryModel" ), patch( - "he_link.scan_variables" - ) as mock_scan_variables, patch( - "he_link.check_unused_variables" - ) as mock_check_unused_variables, patch( - "he_link.link_kernels" - ) as mock_link_kernels, patch( - "he_link.BaseInstruction.dump_instructions_to_file" - ) as mock_dump_instructions: - - mock_prepare_input.return_value = [ - he_link.KernelFiles( - prefix="prefix1", - minst="prefix1.minst", - cinst="prefix1.cinst", - xinst="prefix1.xinst", - mem=None, - ), - he_link.KernelFiles( - prefix="prefix2", - minst="prefix2.minst", - cinst="prefix2.cinst", - xinst="prefix2.xinst", - mem=None, - ), - ] - he_link.main(mock_config, mock_verbose) + "he_link.scan_variables", mocks["scan_variables"] + ), patch( + "he_link.check_unused_variables", mocks["check_unused_variables"] + ), patch( + "he_link.link_kernels", mocks["link_kernels"] + ), patch( + "he_link.BaseInstruction.dump_instructions_to_file", + mocks["dump_instructions"], + ): + + he_link.main(mock_config, MagicMock()) # Assert pipeline is run as expected - mock_prepare_output.assert_called_once() - mock_prepare_input.assert_called_once() - mock_scan_variables.assert_called_once() - mock_check_unused_variables.assert_called_once() - mock_link_kernels.assert_called_once() + mocks["prepare_output"].assert_called_once() + mocks["prepare_input"].assert_called_once() + mocks["scan_variables"].assert_called_once() + mocks["check_unused_variables"].assert_called_once() + mocks["link_kernels"].assert_called_once() if multi_mem_files: # Should use from_dinstrs, not from_file_iter - assert mock_from_dinstrs.called - assert mock_load_dinst_kernel_from_file.called - assert mock_join_dinst_kernels.called - assert mock_dump_instructions.called + assert mocks["from_dinstrs"].called + assert mocks["load_dinst"].called + assert mocks["join_dinst"].called + assert mocks["dump_instructions"].called - assert not mock_from_file_iter.called + assert not mocks["from_file_iter"].called else: # Should use from_file_iter, not from_dinstrs - assert mock_from_file_iter.called - assert not mock_from_dinstrs.called + assert mocks["from_file_iter"].called + assert not mocks["from_dinstrs"].called def test_warning_on_use_xinstfetch(self): """ diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_he_prep.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_he_prep.py index 68078899..48edc8d3 100644 --- a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_he_prep.py +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_he_prep.py @@ -5,7 +5,7 @@ # generative artificial intelligence solutions """ -Unit tests for he_prep module. +@brief Unit tests for he_prep module. """ from unittest import mock @@ -18,10 +18,10 @@ def test_main_assigns_and_saves(monkeypatch, tmp_path): """ - Test that the main function assigns register banks, processes instructions, and saves the output. + @brief Test that the main function assigns register banks, processes instructions, and saves the output. - This test uses monkeypatching to mock dependencies and verifies that the output file - contains the expected instruction after processing a dummy input file. + @details This test uses monkeypatching to mock dependencies and verifies that the output file + contains the expected instruction after processing a dummy input file. """ # Prepare dummy input file input_file = tmp_path / "input.csv" @@ -48,7 +48,7 @@ def test_main_assigns_and_saves(monkeypatch, tmp_path): def test_main_no_input_file(): """ - Test that main raises an error when no input file is provided. + @brief Test that main raises an error when no input file is provided. """ with pytest.raises(FileNotFoundError): he_prep.main( @@ -58,7 +58,7 @@ def test_main_no_input_file(): def test_main_no_output_file(): """ - Test that main raises an error when no output file is provided. + @brief Test that main raises an error when no output file is provided. """ with pytest.raises(FileNotFoundError): he_prep.main( @@ -68,9 +68,9 @@ def test_main_no_output_file(): def test_main_no_instructions(monkeypatch): """ - Test that main handles the case where no instructions are processed. + @brief Test that main handles the case where no instructions are processed. - This test checks that the function can handle an empty instruction list without errors. + @details This test checks that the function can handle an empty instruction list without errors. """ input_file = "empty_input.csv" output_file = "empty_output.csv" @@ -101,7 +101,7 @@ def test_main_no_instructions(monkeypatch): def test_main_invalid_input_file(tmp_path): """ - Test that main raises an error when the input file does not exist. + @brief Test that main raises an error when the input file does not exist. """ input_file = tmp_path / "non_existent.csv" output_file = tmp_path / "output.csv" @@ -114,8 +114,9 @@ def test_main_invalid_input_file(tmp_path): def test_main_invalid_output_file(tmp_path): """ - Test that main raises an error when the output file cannot be created. - This test checks that the function handles file permission errors gracefully. + @brief Test that main raises an error when the output file cannot be created. + + @details This test checks that the function handles file permission errors gracefully. """ input_file = tmp_path / "input.csv" input_file.write_text("") # Write empty string to avoid SyntaxError @@ -133,7 +134,7 @@ def test_main_invalid_output_file(tmp_path): def test_parse_args(): """ - Test that parse_args returns the expected arguments. + @brief Test that parse_args returns the expected arguments. """ test_args = [ "prog", diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_init.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_init.py index cb354ed7..1570bcf4 100644 --- a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_init.py +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_init.py @@ -5,7 +5,7 @@ # generative artificial intelligence solutions """ -Unit tests for the memory model classes in linker/__init__.py. +@brief Unit tests for the memory model classes in linker/__init__.py. """ import unittest @@ -17,10 +17,13 @@ class TestVariableInfo(unittest.TestCase): - """Tests for the VariableInfo class.""" + """@brief Tests for the VariableInfo class.""" def test_init(self): - """Test initialization of VariableInfo.""" + """@brief Test initialization of VariableInfo. + + @test Verifies that VariableInfo is properly initialized with the given values + """ var_info = VariableInfo("test_var", 42) self.assertEqual(var_info.var_name, "test_var") self.assertEqual(var_info.hbm_address, 42) @@ -28,7 +31,10 @@ def test_init(self): self.assertEqual(var_info.last_kernel_used, -1) def test_init_default_values(self): - """Test initialization with default values.""" + """@brief Test initialization with default values. + + @test Verifies that VariableInfo is properly initialized with default values + """ var_info = VariableInfo("test_var") self.assertEqual(var_info.var_name, "test_var") self.assertEqual(var_info.hbm_address, -1) @@ -37,15 +43,18 @@ def test_init_default_values(self): class TestHBM(unittest.TestCase): - """Tests for the HBM class.""" + """@brief Tests for the HBM class.""" def setUp(self): - """Set up test fixtures.""" + """@brief Set up test fixtures.""" self.hbm_size = 10 self.hbm = HBM(self.hbm_size) def test_init(self): - """Test initialization of HBM.""" + """@brief Test initialization of HBM. + + @test Verifies that HBM is properly initialized with the given size + """ self.assertEqual(len(self.hbm.buffer), self.hbm_size) self.assertEqual(self.hbm.capacity, self.hbm_size) # Check that buffer is initialized with None values @@ -53,18 +62,27 @@ def test_init(self): self.assertIsNone(item) def test_init_invalid_size(self): - """Test initialization with invalid size.""" + """@brief Test initialization with invalid size. + + @test Verifies that ValueError is raised for invalid sizes + """ with self.assertRaises(ValueError): HBM(0) with self.assertRaises(ValueError): HBM(-1) def test_capacity_property(self): - """Test the capacity property.""" + """@brief Test the capacity property. + + @test Verifies that the capacity property returns the correct size + """ self.assertEqual(self.hbm.capacity, self.hbm_size) def test_buffer_property(self): - """Test the buffer property.""" + """@brief Test the buffer property. + + @test Verifies that the buffer property returns the correct buffer + """ buffer = self.hbm.buffer self.assertEqual(len(buffer), self.hbm_size) # Check that buffer is initialized with None values @@ -72,14 +90,20 @@ def test_buffer_property(self): self.assertIsNone(item) def test_force_allocate_valid(self): - """Test forceAllocate with valid parameters.""" + """@brief Test forceAllocate with valid parameters. + + @test Verifies that a variable is properly allocated at the specified address + """ var_info = VariableInfo("test_var") self.hbm.forceAllocate(var_info, 5) self.assertEqual(var_info.hbm_address, 5) self.assertEqual(self.hbm.buffer[5], var_info) def test_force_allocate_out_of_bounds(self): - """Test forceAllocate with out of bounds address.""" + """@brief Test forceAllocate with out of bounds address. + + @test Verifies that IndexError is raised for out-of-bounds addresses + """ var_info = VariableInfo("test_var") with self.assertRaises(IndexError): self.hbm.forceAllocate(var_info, -1) @@ -87,13 +111,19 @@ def test_force_allocate_out_of_bounds(self): self.hbm.forceAllocate(var_info, self.hbm_size) def test_force_allocate_already_allocated(self): - """Test forceAllocate with already allocated variable.""" + """@brief Test forceAllocate with already allocated variable. + + @test Verifies that ValueError is raised when variable is already allocated + """ var_info = VariableInfo("test_var", 3) with self.assertRaises(ValueError): self.hbm.forceAllocate(var_info, 5) def test_force_allocate_address_occupied_with_hbm(self): - """Test forceAllocate with address occupied and HBM enabled.""" + """@brief Test forceAllocate with address occupied and HBM enabled. + + @test Verifies that RuntimeError is raised when address is occupied + """ with patch.object(GlobalConfig, "hasHBM", True): # Occupy address 5 var_info1 = VariableInfo("var1") @@ -106,7 +136,10 @@ def test_force_allocate_address_occupied_with_hbm(self): self.hbm.forceAllocate(var_info2, 5) def test_force_allocate_address_occupied_without_hbm(self): - """Test forceAllocate with address occupied and HBM disabled.""" + """@brief Test forceAllocate with address occupied and HBM disabled. + + @test Verifies that RuntimeError is raised when address is occupied + """ with patch.object(GlobalConfig, "hasHBM", False): # Occupy address 5 var_info1 = VariableInfo("var1") @@ -119,7 +152,10 @@ def test_force_allocate_address_occupied_without_hbm(self): self.hbm.forceAllocate(var_info2, 5) def test_force_allocate_address_recyclable_with_hbm(self): - """Test forceAllocate with recyclable address and HBM enabled.""" + """@brief Test forceAllocate with recyclable address and HBM enabled. + + @test Verifies that an address can be recycled when the variable is not used + """ with patch.object(GlobalConfig, "hasHBM", True): # Occupy address 5 with a variable that's not used var_info1 = VariableInfo("var1") @@ -136,7 +172,10 @@ def test_force_allocate_address_recyclable_with_hbm(self): self.assertEqual(self.hbm.buffer[5], var_info2) def test_allocate(self): - """Test allocate method.""" + """@brief Test allocate method. + + @test Verifies that a variable is allocated at the first available address + """ var_info = VariableInfo("test_var") self.hbm.allocate(var_info) # The variable should be allocated at the first available address (0) @@ -144,7 +183,10 @@ def test_allocate(self): self.assertEqual(self.hbm.buffer[0], var_info) def test_allocate_full_memory(self): - """Test allocate with full memory.""" + """@brief Test allocate with full memory. + + @test Verifies that RuntimeError is raised when memory is full + """ # Fill up the HBM for i in range(self.hbm_size): var_info = VariableInfo(f"var{i}") @@ -157,7 +199,10 @@ def test_allocate_full_memory(self): self.hbm.allocate(var_info) def test_allocate_with_recycling(self): - """Test allocate with recycling unused addresses.""" + """@brief Test allocate with recycling unused addresses. + + @test Verifies that unused addresses can be recycled + """ with patch.object(GlobalConfig, "hasHBM", True): # Fill up the HBM for i in range(self.hbm_size): @@ -174,10 +219,10 @@ def test_allocate_with_recycling(self): class TestMemoryModel(unittest.TestCase): - """Tests for the MemoryModel class.""" + """@brief Tests for the MemoryModel class.""" def setUp(self): - """Set up test fixtures.""" + """@brief Set up test fixtures.""" # Create a mock MemInfo self.mock_mem_info = MagicMock(spec=mem_info.MemInfo) @@ -221,7 +266,10 @@ def setUp(self): self.memory_model = MemoryModel(10, self.mock_mem_info) def test_init(self): - """Test initialization of MemoryModel.""" + """@brief Test initialization of MemoryModel. + + @test Verifies that MemoryModel is properly initialized + """ self.assertIsInstance(self.memory_model.hbm, HBM) self.assertEqual(self.memory_model.hbm.capacity, 10) @@ -238,7 +286,10 @@ def test_init(self): self.assertIn("meta_var", self.memory_model.mem_info_meta) def test_add_variable_new(self): - """Test adding a new variable.""" + """@brief Test adding a new variable. + + @test Verifies that a new variable is correctly added to the model + """ self.memory_model.addVariable("test_var") # Check that variable was added @@ -251,7 +302,10 @@ def test_add_variable_new(self): self.assertEqual(var_info.hbm_address, -1) def test_add_variable_existing(self): - """Test adding an existing variable.""" + """@brief Test adding an existing variable. + + @test Verifies that the uses count is incremented for an existing variable + """ # Add the variable first self.memory_model.addVariable("test_var") @@ -263,7 +317,10 @@ def test_add_variable_existing(self): self.assertEqual(var_info.uses, 2) def test_add_variable_from_mem_info(self): - """Test adding a variable that's in mem_info.""" + """@brief Test adding a variable that's in mem_info. + + @test Verifies that a variable from mem_info is correctly added with its HBM address + """ self.memory_model.addVariable("input_var") # Check that variable was added @@ -276,7 +333,10 @@ def test_add_variable_from_mem_info(self): self.assertEqual(var_info.hbm_address, 1) def test_add_variable_from_fixed_addr_vars(self): - """Test adding a variable that's in fixed_addr_vars.""" + """@brief Test adding a variable that's in fixed_addr_vars. + + @test Verifies that a fixed-address variable is added with infinite uses + """ self.memory_model.addVariable("output_var") # Check that variable was added @@ -290,7 +350,10 @@ def test_add_variable_from_fixed_addr_vars(self): self.assertEqual(var_info.hbm_address, 2) def test_use_variable(self): - """Test using a variable.""" + """@brief Test using a variable. + + @test Verifies that using a variable decrements its uses count and allocates an HBM address + """ # Add the variable first self.memory_model.addVariable("test_var") @@ -312,7 +375,10 @@ def test_use_variable(self): self.assertEqual(self.memory_model.hbm.buffer[hbm_address], var_info) def test_use_variable_already_allocated(self): - """Test using a variable that already has an HBM address.""" + """@brief Test using a variable that already has an HBM address. + + @test Verifies that the existing HBM address is returned + """ # Add a variable from mem_info which already has an HBM address self.memory_model.addVariable("input_var") diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_dinstruction.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_dinstruction.py index b0e7219a..a6c2cf0a 100644 --- a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_dinstruction.py +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_dinstruction.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 """ -Unit tests for the DInstruction base class. +@brief Unit tests for the DInstruction base class. This module tests the core functionality of the DInstruction class which serves as the base for all data instructions. @@ -15,9 +15,9 @@ class TestDInstruction(unittest.TestCase): """ - Test cases for the DInstruction base class. + @brief Test cases for the DInstruction base class. - These tests verify the common functionality shared by all data instructions, + @details These tests verify the common functionality shared by all data instructions, including token handling, ID generation, and property access. """ @@ -25,9 +25,9 @@ def setUp(self): # Create a concrete subclass for testing since DInstruction is abstract class ConcreteDInstruction(DInstruction): """ - Concrete implementation of DInstruction for testing purposes. + @brief Concrete implementation of DInstruction for testing purposes. - This class provides implementations of the abstract methods + @details This class provides implementations of the abstract methods required to instantiate and test the DInstruction class. """ @@ -45,47 +45,71 @@ def _get_name(cls) -> str: self.dinst = self.d_instruction_class(self.tokens, self.comment) def test_get_name_token_index(self): - """Test _get_name_token_index returns 0""" + """@brief Test _get_name_token_index returns 0 + + @test Verifies the name token is at index 0 + """ self.assertEqual( self.d_instruction_class.name_token_index, 0 ) # Updated reference def test_num_tokens_property(self): - """Test num_tokens property returns expected value""" + """@brief Test num_tokens property returns expected value + + @test Verifies the num_tokens property returns the value from _get_num_tokens + """ self.assertEqual(self.d_instruction_class.num_tokens, 3) # Updated reference def test_initialization_valid_tokens(self): - """Test initialization with valid tokens""" + """@brief Test initialization with valid tokens + + @test Verifies an instance can be created with valid tokens and properties are set correctly + """ inst = self.d_instruction_class(self.tokens, self.comment) self.assertEqual(inst.tokens, self.tokens) self.assertEqual(inst.comment, self.comment) self.assertIsNotNone(inst.id) def test_initialization_token_count_too_few(self): - """Test initialization with too few tokens""" + """@brief Test initialization with too few tokens + + @test Verifies ValueError is raised when too few tokens are provided + """ with self.assertRaises(ValueError): self.d_instruction_class(["test_instruction", "var1"]) def test_initialization_invalid_name(self): - """Test initialization with invalid name token""" + """@brief Test initialization with invalid name token + + @test Verifies ValueError is raised when an invalid instruction name is provided + """ with self.assertRaises(ValueError): self.d_instruction_class(["wrong_name", "var1", "123"]) def test_id_property(self): - """Test id property returns a unique id""" + """@brief Test id property returns a unique id + + @test Verifies each instruction instance gets a unique ID + """ inst1 = self.d_instruction_class(self.tokens) inst2 = self.d_instruction_class(self.tokens) self.assertNotEqual(inst1.id, inst2.id) def test_to_line_method(self): - """Test to_line method returns expected string""" + """@brief Test to_line method returns expected string + + @test Verifies the to_line method correctly formats the instruction as a string + """ tokens = ["test_instruction", "var1", "123"] inst = self.d_instruction_class(tokens, "") expected = "test_instruction, var1, 123" self.assertEqual(inst.to_line(), expected) def test_consecutive_ids(self): - """Test that consecutive instructions get incremental ids""" + """@brief Test that consecutive instructions get incremental ids + + @test Verifies IDs are incremented sequentially for new instances + """ inst1 = self.d_instruction_class(self.tokens) inst2 = self.d_instruction_class(self.tokens) self.assertEqual(inst2.id, inst1.id + 1) diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_dkeygen.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_dkeygen.py index 91a6f207..f9769657 100644 --- a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_dkeygen.py +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_dkeygen.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 """ -Unit tests for the DKeygen instruction class. +@brief Unit tests for the DKeygen instruction class. This module tests the functionality of the DKeygen instruction which is responsible for key generation operations. @@ -17,9 +17,9 @@ class TestDKeygenInstruction(unittest.TestCase): """ - Test cases for the DKeygen instruction class. + @brief Test cases for the DKeygen instruction class. - These tests verify that the DKeygen instruction correctly handles token + @details These tests verify that the DKeygen instruction correctly handles token parsing, name resolution, and serialization. """ @@ -33,27 +33,42 @@ def setUp(self): ) def test_get_num_tokens(self): - """Test that _get_num_tokens returns 4""" + """@brief Test that _get_num_tokens returns 4 + + @test Verifies the instruction requires exactly 4 tokens + """ self.assertEqual(Instruction.num_tokens, 4) def test_get_name(self): - """Test that _get_name returns the expected value""" + """@brief Test that _get_name returns the expected value + + @test Verifies the instruction name matches the MemInfo constant + """ self.assertEqual(Instruction.name, MemInfo.Const.Keyword.KEYGEN) def test_initialization_valid_input(self): - """Test that initialization can set up the correct properties with valid name""" + """@brief Test that initialization can set up the correct properties with valid name + + @test Verifies the instruction is properly initialized with valid tokens + """ inst = Instruction( [MemInfo.Const.Keyword.KEYGEN, self.seed_idx, self.key_idx, self.var_name] ) self.assertEqual(inst.name, MemInfo.Const.Keyword.KEYGEN) def test_initialization_invalid_name(self): - """Test that initialization raises exception with invalid name""" + """@brief Test that initialization raises exception with invalid name + + @test Verifies ValueError is raised when an invalid instruction name is provided + """ with self.assertRaises(ValueError): # Adjust exception type if needed Instruction(["invalid_name", self.seed_idx, self.key_idx, self.var_name]) def test_tokens_property(self): - """Test that tokens property returns the correct list""" + """@brief Test that tokens property returns the correct list + + @test Verifies the tokens property correctly formats the instruction tokens + """ # Since tokens property implementation is not visible in the dkeygen.py file, # this test assumes default behavior from parent class or basic functionality expected_tokens = [ @@ -65,7 +80,10 @@ def test_tokens_property(self): self.assertEqual(self.inst.tokens[:4], expected_tokens) def test_tokens_with_additional_data(self): - """Test tokens property with additional tokens""" + """@brief Test tokens property with additional tokens + + @test Verifies extra tokens are preserved in the tokens property + """ additional_token = "extra" inst_with_extra = Instruction( [ @@ -84,19 +102,28 @@ def test_tokens_with_additional_data(self): return_value=None, ) def test_inheritance(self, mock_init): - """Test that Instruction properly extends DInstruction""" + """@brief Test that Instruction properly extends DInstruction + + @test Verifies the parent constructor is called during initialization + """ # Ensure that DInstruction methods are called as expected Instruction([Instruction.name, self.seed_idx, self.key_idx, self.var_name]) # Verify DInstruction.__init__ was called mock_init.assert_called() def test_invalid_token_count_too_few(self): - """Test behavior when fewer tokens than required are provided""" + """@brief Test behavior when fewer tokens than required are provided + + @test Verifies ValueError is raised when too few tokens are provided + """ with self.assertRaises(ValueError): # Adjust exception type if needed Instruction([MemInfo.Const.Keyword.KEYGEN, self.seed_idx, self.key_idx]) def test_invalid_token_count_too_many(self): - """Test behavior when more tokens than required are provided""" + """@brief Test behavior when more tokens than required are provided + + @test Verifies extra tokens are handled gracefully without errors + """ # This should not raise an error as additional tokens are handled inst = Instruction( [ diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_dload.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_dload.py index 36e20e6d..c6168790 100644 --- a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_dload.py +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_dload.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 """ -Unit tests for the DLoad instruction class. +@brief Unit tests for the DLoad instruction class. This module tests the functionality of the DLoad instruction which is responsible for loading data from memory locations. @@ -17,9 +17,9 @@ class TestDLoadInstruction(unittest.TestCase): """ - Test cases for the DLoad instruction class. + @brief Test cases for the DLoad instruction class. - These tests verify that the DLoad instruction correctly handles token + @details These tests verify that the DLoad instruction correctly handles token parsing, name resolution, and serialization. """ @@ -30,15 +30,24 @@ def setUp(self): self.type = "type1" def test_get_num_tokens(self): - """Test that _get_num_tokens returns 3""" + """@brief Test that _get_num_tokens returns 3 + + @test Verifies the instruction requires exactly 3 tokens + """ self.assertEqual(Instruction.num_tokens, 3) def test_get_name(self): - """Test that _get_name returns the expected value""" + """@brief Test that _get_name returns the expected value + + @test Verifies the instruction name matches the MemInfo constant + """ self.assertEqual(Instruction.name, MemInfo.Const.Keyword.LOAD) def test_initialization_valid_input(self): - """Test that initialization can set up the correct properties with valid name""" + """@brief Test that initialization can set up the correct properties with valid name + + @test Verifies the instruction is properly initialized with valid tokens + """ inst = Instruction( [MemInfo.Const.Keyword.LOAD, self.type, str(self.address), self.var_name] ) @@ -46,18 +55,27 @@ def test_initialization_valid_input(self): self.assertEqual(inst.name, MemInfo.Const.Keyword.LOAD) def test_initialization_valid_meta(self): - """Test that initialization can set up the correct properties with valid name""" + """@brief Test that initialization can set up the correct properties with metadata + + @test Verifies the instruction handles metadata loading correctly + """ inst = Instruction([MemInfo.Const.Keyword.LOAD, self.type, str(self.address)]) self.assertEqual(inst.name, MemInfo.Const.Keyword.LOAD) def test_initialization_invalid_name(self): - """Test that initialization raises exception with invalid name""" + """@brief Test that initialization raises exception with invalid name + + @test Verifies ValueError is raised when an invalid instruction name is provided + """ with self.assertRaises(ValueError): # Adjust exception type if needed Instruction(["invalid_name", self.type, str(self.address), self.var_name]) def test_tokens_property(self): - """Test that tokens property returns the correct list""" + """@brief Test that tokens property returns the correct list + + @test Verifies the tokens property correctly formats the instruction tokens + """ expected_tokens = [ MemInfo.Const.Keyword.LOAD, self.type, @@ -73,7 +91,10 @@ def test_tokens_property(self): self.assertEqual(inst.tokens, expected_tokens) def test_tokens_with_additional_data(self): - """Test tokens property with additional tokens""" + """@brief Test tokens property with additional tokens + + @test Verifies extra tokens are preserved in the tokens property + """ additional_token = "extra" inst_with_extra = Instruction( [ @@ -99,19 +120,28 @@ def test_tokens_with_additional_data(self): return_value=None, ) def test_inheritance(self, mock_init): - """Test that Instruction properly extends DInstruction""" + """@brief Test that Instruction properly extends DInstruction + + @test Verifies the parent constructor is called during initialization + """ # Ensure that DInstruction methods are called as expected Instruction([Instruction.name, self.type, str(self.address), self.var_name]) # Verify DInstruction.__init__ was called mock_init.assert_called() def test_invalid_token_count_too_few(self): - """Test behavior when fewer tokens than required are provided""" + """@brief Test behavior when fewer tokens than required are provided + + @test Verifies ValueError is raised when too few tokens are provided + """ with self.assertRaises(ValueError): # Adjust exception type if needed Instruction([MemInfo.Const.Keyword.LOAD, self.var_name]) def test_invalid_token_count_too_many(self): - """Test behavior when more tokens than required are provided""" + """@brief Test behavior when more tokens than required are provided + + @test Verifies extra tokens are handled gracefully without errors + """ # This should not raise an error as additional tokens are handled inst = Instruction( [ diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_dstore.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_dstore.py index 71762da1..938a9db9 100644 --- a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_dstore.py +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_dstore.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 """ -Unit tests for the DStore instruction class. +@brief Unit tests for the DStore instruction class. This module tests the functionality of the DStore instruction which is responsible for storing data to memory locations. @@ -17,9 +17,9 @@ class TestDStoreInstruction(unittest.TestCase): """ - Test cases for the DStore instruction class. + @brief Test cases for the DStore instruction class. - These tests verify that the DStore instruction correctly handles token + @details These tests verify that the DStore instruction correctly handles token parsing, name resolution, and serialization. """ @@ -29,15 +29,24 @@ def setUp(self): self.address = 123 def test_get_num_tokens(self): - """Test that _get_num_tokens returns 3""" + """@brief Test that _get_num_tokens returns 3 + + @test Verifies the instruction requires exactly 3 tokens + """ self.assertEqual(Instruction.num_tokens, 3) def test_get_name(self): - """Test that _get_name returns the expected value""" + """@brief Test that _get_name returns the expected value + + @test Verifies the instruction name matches the MemInfo constant + """ self.assertEqual(Instruction.name, MemInfo.Const.Keyword.STORE) def test_initialization_valid_input(self): - """Test that initialization can set up the correct properties with valid name""" + """@brief Test that initialization can set up the correct properties with valid name + + @test Verifies the instruction is properly initialized with valid tokens + """ inst = Instruction( [MemInfo.Const.Keyword.STORE, self.var_name, str(self.address)] ) @@ -45,12 +54,18 @@ def test_initialization_valid_input(self): self.assertEqual(inst.name, MemInfo.Const.Keyword.STORE) def test_initialization_invalid_name(self): - """Test that initialization raises exception with invalid name""" + """@brief Test that initialization raises exception with invalid name + + @test Verifies ValueError is raised when an invalid instruction name is provided + """ with self.assertRaises(ValueError): # Adjust exception type if needed Instruction(["invalid_name", self.var_name, str(self.address)]) def test_tokens_property(self): - """Test that tokens property returns the correct list""" + """@brief Test that tokens property returns the correct list + + @test Verifies the tokens property correctly formats the instruction tokens + """ expected_tokens = [ MemInfo.Const.Keyword.STORE, self.var_name, @@ -65,7 +80,10 @@ def test_tokens_property(self): self.assertEqual(inst.tokens, expected_tokens) def test_tokens_with_additional_data(self): - """Test tokens property with additional tokens""" + """@brief Test tokens property with additional tokens + + @test Verifies extra tokens are preserved in the tokens property + """ additional_token = "extra" inst_with_extra = Instruction( [ @@ -90,19 +108,28 @@ def test_tokens_with_additional_data(self): return_value=None, ) def test_inheritance(self, mock_init): - """Test that Instruction properly extends DInstruction""" + """@brief Test that Instruction properly extends DInstruction + + @test Verifies the parent constructor is called during initialization + """ # Ensure that DInstruction methods are called as expected Instruction([Instruction.name, self.var_name, str(self.address)]) # Verify DInstruction.__init__ was called mock_init.assert_called() def test_invalid_token_count_too_few(self): - """Test behavior when fewer tokens than required are provided""" + """@brief Test behavior when fewer tokens than required are provided + + @test Verifies ValueError is raised when too few tokens are provided + """ with self.assertRaises(ValueError): # Adjust exception type if needed Instruction([MemInfo.Const.Keyword.STORE]) def test_invalid_token_count_too_many(self): - """Test behavior when more tokens than required are provided""" + """@brief Test behavior when more tokens than required are provided + + @test Verifies extra tokens are handled gracefully without errors + """ # This should not raise an error as additional tokens are handled inst = Instruction( [ diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_init.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_init.py index eb9579fb..26a9a799 100644 --- a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_init.py +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_init.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 """ -Unit tests for the dinst package initialization module. +@brief Unit tests for the dinst package initialization module. This module tests the factory functions and initialization utilities for data instructions. @@ -17,14 +17,17 @@ class TestDInstModule(unittest.TestCase): """ - Test cases for data instruction initialization. + @brief Test cases for data instruction initialization. - These tests verify that the data instruction factory correctly creates + @details These tests verify that the data instruction factory correctly creates instruction instances and properly handles initialization errors. """ def test_factory(self): - """Test that factory returns the expected set of instruction classes""" + """@brief Test that factory returns the expected set of instruction classes + + @test Verifies the factory returns a set containing DLoad, DStore, and DKeyGen + """ instruction_set = factory() self.assertIsInstance(instruction_set, set) self.assertEqual(len(instruction_set), 3) @@ -35,7 +38,10 @@ def test_factory(self): @patch("assembler.instructions.tokenize_from_line") @patch("assembler.memory_model.mem_info.MemInfo.get_meminfo_var_from_tokens") def test_create_from_mem_line_dload_input(self, mock_get_meminfo, mock_tokenize): - """Test create_from_mem_line creates DLoad instruction""" + """@brief Test create_from_mem_line creates DLoad instruction + + @test Verifies that a DLoad instruction is created with correct properties + """ # Setup mocks tokens = ["dload", "poly", "0x123", "var1"] comment = "Test comment" @@ -58,7 +64,10 @@ def test_create_from_mem_line_dload_input(self, mock_get_meminfo, mock_tokenize) @patch("assembler.instructions.tokenize_from_line") @patch("assembler.memory_model.mem_info.MemInfo.get_meminfo_var_from_tokens") def test_create_from_mem_line_dload_meta(self, mock_get_meminfo, mock_tokenize): - """Test create_from_mem_line creates DLoad instruction""" + """@brief Test create_from_mem_line creates DLoad instruction for metadata + + @test Verifies that a DLoad instruction is created for metadata entries + """ # Setup mocks tokens = ["dload", "meta", "1"] comment = "Test comment" @@ -81,7 +90,10 @@ def test_create_from_mem_line_dload_meta(self, mock_get_meminfo, mock_tokenize): @patch("assembler.instructions.tokenize_from_line") @patch("assembler.memory_model.mem_info.MemInfo.get_meminfo_var_from_tokens") def test_create_from_mem_line_dstore(self, mock_get_meminfo, mock_tokenize): - """Test create_from_mem_line creates DStore instruction""" + """@brief Test create_from_mem_line creates DStore instruction + + @test Verifies that a DStore instruction is created with correct properties + """ # Setup mocks tokens = ["dstore", "var1", "0x456"] comment = "Test comment" @@ -104,7 +116,10 @@ def test_create_from_mem_line_dstore(self, mock_get_meminfo, mock_tokenize): @patch("assembler.instructions.tokenize_from_line") @patch("assembler.memory_model.mem_info.MemInfo.get_meminfo_var_from_tokens") def test_create_from_mem_line_dkeygen(self, mock_get_meminfo, mock_tokenize): - """Test create_from_mem_line creates DKeyGen instruction""" + """@brief Test create_from_mem_line creates DKeyGen instruction + + @test Verifies that a DKeyGen instruction is created with correct properties + """ # Setup mocks tokens = ["keygen", "key1", "type1", "256"] comment = "Test comment" @@ -128,7 +143,10 @@ def test_create_from_mem_line_dkeygen(self, mock_get_meminfo, mock_tokenize): @patch("assembler.instructions.tokenize_from_line") @patch("assembler.memory_model.mem_info.MemInfo.get_meminfo_var_from_tokens") def test_create_from_mem_line_invalid(self, mock_get_meminfo, mock_tokenize): - """Test create_from_mem_line with invalid instruction""" + """@brief Test create_from_mem_line with invalid instruction + + @test Verifies that RuntimeError is raised for invalid instructions + """ # Setup mocks to return invalid tokens tokens = ["invalid_instruction", "var1", "0x123"] comment = "" @@ -144,7 +162,10 @@ def test_create_from_mem_line_invalid(self, mock_get_meminfo, mock_tokenize): @patch("assembler.instructions.tokenize_from_line") @patch("assembler.memory_model.mem_info.MemInfo.get_meminfo_var_from_tokens") def test_create_from_mem_line_meminfo_error(self, mock_get_meminfo, mock_tokenize): - """Test create_from_mem_line with MemInfo error""" + """@brief Test create_from_mem_line with MemInfo error + + @test Verifies that RuntimeError is wrapped with line information + """ # Setup mocks tokens = ["dstore", "var1", "0x123"] comment = "" diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_init.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_init.py index e688a545..08771961 100644 --- a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_init.py +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_init.py @@ -5,7 +5,7 @@ # generative artificial intelligence solutions """ -Unit tests for the linker instructions initialization module. +@brief Unit tests for the linker instructions initialization module. This module contains tests that verify the behavior of the instruction factory and initialization functionality. @@ -19,7 +19,7 @@ class TestCreateFromStrLine(unittest.TestCase): """ - Test cases for instruction initialization functionality. + @brief Test cases for instruction initialization functionality. These tests verify that instructions are correctly initialized, their tokens are properly processed, and their factories work as expected. @@ -41,7 +41,12 @@ def setUp(self): @patch("linker.instructions.tokenize_from_line") def test_create_from_str_line_success(self, mock_tokenize): - """Test successful instruction creation""" + """ + @brief Test successful instruction creation + + @test Verifies that an instruction is correctly created from a string line + when a valid factory is provided + """ # Setup mock tokens = ["instruction", "arg1", "arg2"] comment = "Test comment" @@ -59,7 +64,11 @@ def test_create_from_str_line_success(self, mock_tokenize): @patch("linker.instructions.tokenize_from_line") def test_create_from_str_line_failure(self, mock_tokenize): - """Test when no instruction can be created""" + """ + @brief Test when no instruction can be created + + @test Verifies that None is returned when instruction creation fails + """ # Setup mock tokens = ["unknown", "arg1", "arg2"] comment = "Test comment" @@ -78,7 +87,12 @@ def test_create_from_str_line_failure(self, mock_tokenize): @patch("linker.instructions.tokenize_from_line") def test_create_from_str_line_multiple_instruction_types(self, mock_tokenize): - """Test with multiple instruction types in factory""" + """ + @brief Test with multiple instruction types in factory + + @test Verifies that the function tries each instruction type in the factory + until one succeeds + """ # Setup mocks tokens = ["instruction", "arg1", "arg2"] comment = "Test comment" @@ -103,7 +117,12 @@ def test_create_from_str_line_multiple_instruction_types(self, mock_tokenize): @patch("linker.instructions.tokenize_from_line") def test_create_from_str_line_exception_handling(self, mock_tokenize): - """Test that general exceptions are caught""" + """ + @brief Test that general exceptions are caught + + @test Verifies that unexpected exceptions during instruction creation are + handled gracefully and None is returned + """ # Setup mock tokens = ["instruction", "arg1", "arg2"] comment = "Test comment" diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_instruction.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_instruction.py index 0f8245e1..b5842745 100644 --- a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_instruction.py +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_instruction.py @@ -5,7 +5,7 @@ # generative artificial intelligence solutions """ -Unit tests for the BaseInstruction class. +@brief Unit tests for the BaseInstruction class. """ import os @@ -18,7 +18,7 @@ class MockInstruction(BaseInstruction): - """Concrete implementation of BaseInstruction for testing.""" + """@brief Concrete implementation of BaseInstruction for testing.""" @classmethod def _get_name(cls) -> str: @@ -34,83 +34,119 @@ def _get_num_tokens(cls) -> int: class TestBaseInstruction(unittest.TestCase): - """Tests for the BaseInstruction class.""" + """@brief Tests for the BaseInstruction class.""" def setUp(self): - """Setup for tests.""" + """@brief Setup for tests.""" self.valid_tokens = ["TEST", "arg1", "arg2"] self.comment = "This is a test comment" def test_init_valid(self): - """Test initialization with valid tokens.""" + """@brief Test initialization with valid tokens. + + @test Verifies that an instruction can be correctly initialized with valid tokens + """ instruction = MockInstruction(self.valid_tokens, self.comment) self.assertEqual(instruction.tokens, self.valid_tokens) self.assertEqual(instruction.comment, self.comment) def test_init_invalid_name(self): - """Test initialization with invalid instruction name.""" + """@brief Test initialization with invalid instruction name. + + @test Verifies that a ValueError is raised when the instruction name is invalid + """ invalid_tokens = ["WRONG", "arg1", "arg2"] with self.assertRaises(ValueError) as context: MockInstruction(invalid_tokens) self.assertIn("invalid name", str(context.exception)) def test_init_invalid_num_tokens(self): - """Test initialization with incorrect number of tokens.""" + """@brief Test initialization with incorrect number of tokens. + + @test Verifies that a ValueError is raised when the number of tokens is incorrect + """ invalid_tokens = ["TEST", "arg1"] with self.assertRaises(ValueError) as context: MockInstruction(invalid_tokens) self.assertIn("invalid amount of tokens", str(context.exception)) def test_id_generation(self): - """Test that each instruction gets a unique ID.""" + """@brief Test that each instruction gets a unique ID. + + @test Verifies that different instruction instances have different IDs + """ instruction1 = MockInstruction(self.valid_tokens) instruction2 = MockInstruction(self.valid_tokens) self.assertNotEqual(instruction1.id, instruction2.id) def test_str_representation(self): - """Test string representation.""" + """@brief Test string representation. + + @test Verifies that the string representation is correctly formatted + """ instruction = MockInstruction(self.valid_tokens) self.assertEqual(str(instruction), f"TEST({instruction.id})") def test_repr_representation(self): - """Test repr representation.""" + """@brief Test repr representation. + + @test Verifies that the repr representation contains the expected information + """ instruction = MockInstruction(self.valid_tokens) self.assertIn("MockInstruction(TEST, id=", repr(instruction)) self.assertIn("tokens=", repr(instruction)) def test_equality(self): - """Test equality operator.""" + """@brief Test equality operator. + + @test Verifies that equality is based on object identity rather than value + """ instruction1 = MockInstruction(self.valid_tokens) instruction2 = MockInstruction(self.valid_tokens) self.assertNotEqual(instruction1, instruction2) self.assertEqual(instruction1, instruction1) def test_hash(self): - """Test hash function.""" + """@brief Test hash function. + + @test Verifies that the hash is based on the instruction's ID + """ instruction = MockInstruction(self.valid_tokens) self.assertEqual(hash(instruction), hash(instruction.id)) def test_to_line_with_comment(self): - """Test to_line method with comment.""" + """@brief Test to_line method with comment. + + @test Verifies that to_line correctly includes the comment when present + """ instruction = MockInstruction(self.valid_tokens, self.comment) expected = f"TEST, arg1, arg2 # {self.comment}" self.assertEqual(instruction.to_line(), expected) def test_to_line_without_comment(self): - """Test to_line method without comment.""" + """@brief Test to_line method without comment. + + @test Verifies that to_line works correctly when no comment is present + """ instruction = MockInstruction(self.valid_tokens) expected = "TEST, arg1, arg2" self.assertEqual(instruction.to_line(), expected) def test_to_line_suppressed_comments(self): - """Test to_line method with suppressed comments.""" + """@brief Test to_line method with suppressed comments. + + @test Verifies that comments are not included when GlobalConfig.suppress_comments is True + """ with patch.object(GlobalConfig, "suppress_comments", True): instruction = MockInstruction(self.valid_tokens, self.comment) expected = "TEST, arg1, arg2" self.assertEqual(instruction.to_line(), expected) def test_dump_instructions_to_file(self): - """Test dump_instructions_to_file method.""" + """@brief Test dump_instructions_to_file method. + + @test Verifies that instructions are correctly written to a file + """ instruction1 = MockInstruction(self.valid_tokens, "Comment 1") instruction2 = MockInstruction(self.valid_tokens, "Comment 2") instructions = [instruction1, instruction2] diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_loader.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_loader.py index 05dd61dd..06a3827d 100644 --- a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_loader.py +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_loader.py @@ -5,7 +5,7 @@ # generative artificial intelligence solutions """ -Unit tests for the loader module. +@brief Unit tests for the loader module. """ import unittest @@ -24,10 +24,10 @@ class TestLoader(unittest.TestCase): - """Tests for the loader module functions.""" + """@brief Tests for the loader module functions.""" def setUp(self): - """Set up test fixtures.""" + """@brief Set up test fixtures.""" # Sample instruction lines for each type self.minst_lines = ["MINST arg1, arg2", "MINST arg3, arg4"] self.cinst_lines = ["CINST arg1, arg2", "CINST arg3, arg4"] @@ -43,7 +43,10 @@ def setUp(self): @patch("linker.instructions.create_from_str_line") @patch("linker.instructions.minst.factory") def test_load_minst_kernel_success(self, mock_factory, mock_create): - """Test successful loading of MInstructions from an iterator.""" + """@brief Test successful loading of MInstructions from an iterator. + + @test Verifies that MInstructions are properly created from string lines + """ # Configure mocks mock_factory.return_value = "minst_factory" mock_create.side_effect = self.mock_minst @@ -66,7 +69,10 @@ def test_load_minst_kernel_success(self, mock_factory, mock_create): @patch("linker.instructions.create_from_str_line") @patch("linker.instructions.minst.factory") def test_load_minst_kernel_failure(self, mock_factory, mock_create): - """Test error handling when loading MInstructions fails.""" + """@brief Test error handling when loading MInstructions fails. + + @test Verifies that a RuntimeError is raised when parsing fails + """ # Configure mocks mock_factory.return_value = "minst_factory" mock_create.return_value = None @@ -82,7 +88,10 @@ def test_load_minst_kernel_failure(self, mock_factory, mock_create): @patch("builtins.open", new_callable=mock_open) @patch("linker.loader.load_minst_kernel") def test_load_minst_kernel_from_file_success(self, mock_load, mock_file): - """Test successful loading of MInstructions from a file.""" + """@brief Test successful loading of MInstructions from a file. + + @test Verifies that file contents are properly passed to load_minst_kernel + """ # Configure mocks mock_file.return_value.__enter__.return_value = self.minst_lines mock_load.return_value = self.mock_minst @@ -92,13 +101,16 @@ def test_load_minst_kernel_from_file_success(self, mock_load, mock_file): # Verify the results self.assertEqual(result, self.mock_minst) - mock_file.assert_called_once_with("test.minst", "r") + mock_file.assert_called_once_with("test.minst", "r", encoding="utf-8") mock_load.assert_called_once_with(self.minst_lines) @patch("builtins.open", new_callable=mock_open) @patch("linker.loader.load_minst_kernel") def test_load_minst_kernel_from_file_failure(self, mock_load, mock_file): - """Test error handling when loading MInstructions from a file fails.""" + """@brief Test error handling when loading MInstructions from a file fails. + + @test Verifies that a RuntimeError is raised with appropriate message + """ # Configure mocks mock_file.return_value.__enter__.return_value = self.minst_lines mock_load.side_effect = Exception("Test error") @@ -114,7 +126,10 @@ def test_load_minst_kernel_from_file_failure(self, mock_load, mock_file): @patch("linker.instructions.create_from_str_line") @patch("linker.instructions.cinst.factory") def test_load_cinst_kernel_success(self, mock_factory, mock_create): - """Test successful loading of CInstructions from an iterator.""" + """@brief Test successful loading of CInstructions from an iterator. + + @test Verifies that CInstructions are properly created from string lines + """ # Configure mocks mock_factory.return_value = "cinst_factory" mock_create.side_effect = self.mock_cinst @@ -137,7 +152,10 @@ def test_load_cinst_kernel_success(self, mock_factory, mock_create): @patch("linker.instructions.create_from_str_line") @patch("linker.instructions.cinst.factory") def test_load_cinst_kernel_failure(self, mock_factory, mock_create): - """Test error handling when loading CInstructions fails.""" + """@brief Test error handling when loading CInstructions fails. + + @test Verifies that a RuntimeError is raised when parsing fails + """ # Configure mocks mock_factory.return_value = "cinst_factory" mock_create.return_value = None @@ -153,7 +171,10 @@ def test_load_cinst_kernel_failure(self, mock_factory, mock_create): @patch("builtins.open", new_callable=mock_open) @patch("linker.loader.load_cinst_kernel") def test_load_cinst_kernel_from_file_success(self, mock_load, mock_file): - """Test successful loading of CInstructions from a file.""" + """@brief Test successful loading of CInstructions from a file. + + @test Verifies that file contents are properly passed to load_cinst_kernel + """ # Configure mocks mock_file.return_value.__enter__.return_value = self.cinst_lines mock_load.return_value = self.mock_cinst @@ -163,13 +184,16 @@ def test_load_cinst_kernel_from_file_success(self, mock_load, mock_file): # Verify the results self.assertEqual(result, self.mock_cinst) - mock_file.assert_called_once_with("test.cinst", "r") + mock_file.assert_called_once_with("test.cinst", "r", encoding="utf-8") mock_load.assert_called_once_with(self.cinst_lines) @patch("builtins.open", new_callable=mock_open) @patch("linker.loader.load_cinst_kernel") def test_load_cinst_kernel_from_file_failure(self, mock_load, mock_file): - """Test error handling when loading CInstructions from a file fails.""" + """@brief Test error handling when loading CInstructions from a file fails. + + @test Verifies that a RuntimeError is raised with appropriate message + """ # Configure mocks mock_file.return_value.__enter__.return_value = self.cinst_lines mock_load.side_effect = Exception("Test error") @@ -185,7 +209,10 @@ def test_load_cinst_kernel_from_file_failure(self, mock_load, mock_file): @patch("linker.instructions.create_from_str_line") @patch("linker.instructions.xinst.factory") def test_load_xinst_kernel_success(self, mock_factory, mock_create): - """Test successful loading of XInstructions from an iterator.""" + """@brief Test successful loading of XInstructions from an iterator. + + @test Verifies that XInstructions are properly created from string lines + """ # Configure mocks mock_factory.return_value = "xinst_factory" mock_create.side_effect = self.mock_xinst @@ -208,7 +235,10 @@ def test_load_xinst_kernel_success(self, mock_factory, mock_create): @patch("linker.instructions.create_from_str_line") @patch("linker.instructions.xinst.factory") def test_load_xinst_kernel_failure(self, mock_factory, mock_create): - """Test error handling when loading XInstructions fails.""" + """@brief Test error handling when loading XInstructions fails. + + @test Verifies that a RuntimeError is raised when parsing fails + """ # Configure mocks mock_factory.return_value = "xinst_factory" mock_create.return_value = None @@ -224,7 +254,10 @@ def test_load_xinst_kernel_failure(self, mock_factory, mock_create): @patch("builtins.open", new_callable=mock_open) @patch("linker.loader.load_xinst_kernel") def test_load_xinst_kernel_from_file_success(self, mock_load, mock_file): - """Test successful loading of XInstructions from a file.""" + """@brief Test successful loading of XInstructions from a file. + + @test Verifies that file contents are properly passed to load_xinst_kernel + """ # Configure mocks mock_file.return_value.__enter__.return_value = self.xinst_lines mock_load.return_value = self.mock_xinst @@ -234,13 +267,16 @@ def test_load_xinst_kernel_from_file_success(self, mock_load, mock_file): # Verify the results self.assertEqual(result, self.mock_xinst) - mock_file.assert_called_once_with("test.xinst", "r") + mock_file.assert_called_once_with("test.xinst", "r", encoding="utf-8") mock_load.assert_called_once_with(self.xinst_lines) @patch("builtins.open", new_callable=mock_open) @patch("linker.loader.load_xinst_kernel") def test_load_xinst_kernel_from_file_failure(self, mock_load, mock_file): - """Test error handling when loading XInstructions from a file fails.""" + """@brief Test error handling when loading XInstructions from a file fails. + + @test Verifies that a RuntimeError is raised with appropriate message + """ # Configure mocks mock_file.return_value.__enter__.return_value = self.xinst_lines mock_load.side_effect = Exception("Test error") @@ -255,7 +291,10 @@ def test_load_xinst_kernel_from_file_failure(self, mock_load, mock_file): @patch("linker.instructions.dinst.create_from_mem_line") def test_load_dinst_kernel_success(self, mock_create): - """Test successful loading of DInstructions from an iterator.""" + """@brief Test successful loading of DInstructions from an iterator. + + @test Verifies that DInstructions are properly created from string lines + """ # Configure mocks mock_create.side_effect = self.mock_dinst @@ -271,7 +310,10 @@ def test_load_dinst_kernel_success(self, mock_create): @patch("linker.instructions.dinst.create_from_mem_line") def test_load_dinst_kernel_failure(self, mock_create): - """Test error handling when loading DInstructions fails.""" + """@brief Test error handling when loading DInstructions fails. + + @test Verifies that a RuntimeError is raised when parsing fails + """ # Configure mocks mock_create.return_value = None @@ -286,7 +328,10 @@ def test_load_dinst_kernel_failure(self, mock_create): @patch("builtins.open", new_callable=mock_open) @patch("linker.loader.load_dinst_kernel") def test_load_dinst_kernel_from_file_success(self, mock_load, mock_file): - """Test successful loading of DInstructions from a file.""" + """@brief Test successful loading of DInstructions from a file. + + @test Verifies that file contents are properly passed to load_dinst_kernel + """ # Configure mocks mock_file.return_value.__enter__.return_value = self.dinst_lines mock_load.return_value = self.mock_dinst @@ -296,13 +341,16 @@ def test_load_dinst_kernel_from_file_success(self, mock_load, mock_file): # Verify the results self.assertEqual(result, self.mock_dinst) - mock_file.assert_called_once_with("test.dinst", "r") + mock_file.assert_called_once_with("test.dinst", "r", encoding="utf-8") mock_load.assert_called_once_with(self.dinst_lines) @patch("builtins.open", new_callable=mock_open) @patch("linker.loader.load_dinst_kernel") def test_load_dinst_kernel_from_file_failure(self, mock_load, mock_file): - """Test error handling when loading DInstructions from a file fails.""" + """@brief Test error handling when loading DInstructions from a file fails. + + @test Verifies that a RuntimeError is raised with appropriate message + """ # Configure mocks mock_file.return_value.__enter__.return_value = self.dinst_lines mock_load.side_effect = Exception("Test error") diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_steps/test_program_linker.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_steps/test_program_linker.py index a4c58fde..46f88f70 100644 --- a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_steps/test_program_linker.py +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_steps/test_program_linker.py @@ -5,7 +5,7 @@ # generative artificial intelligence solutions """ -Unit tests for the program_linker module. +@brief Unit tests for the program_linker module. """ import io @@ -20,10 +20,10 @@ # pylint: disable=protected-access class TestLinkedProgram(unittest.TestCase): - """Tests for the LinkedProgram class.""" + """@brief Tests for the LinkedProgram class.""" def setUp(self): - """Set up test fixtures.""" + """@brief Set up test fixtures.""" # Group related stream objects into a dictionary self.streams = { "minst": io.StringIO(), @@ -50,12 +50,15 @@ def setUp(self): ) def tearDown(self): - """Tear down test fixtures.""" + """@brief Tear down test fixtures.""" self.has_hbm_patcher.stop() self.suppress_comments_patcher.stop() def test_init(self): - """Test initialization of LinkedProgram.""" + """@brief Test initialization of LinkedProgram. + + @test Verifies that all instance variables are correctly initialized + """ self.assertEqual( self.program._LinkedProgram__minst_ostream, self.streams["minst"] ) @@ -73,13 +76,19 @@ def test_init(self): self.assertTrue(self.program._LinkedProgram__is_open) def test_is_open_property(self): - """Test the is_open property.""" + """@brief Test the is_open property. + + @test Verifies that the is_open property reflects the internal state + """ self.assertTrue(self.program.is_open) self.program._LinkedProgram__is_open = False self.assertFalse(self.program.is_open) def test_close(self): - """Test closing the program.""" + """@brief Test closing the program. + + @test Verifies that cexit and msyncc instructions are added and program is marked as closed + """ self.program.close() # Verify cexit and msyncc were added @@ -113,7 +122,10 @@ def test_close(self): self.assertNotIn("terminating MInstQ", self.streams["minst"].getvalue()) def test_validate_hbm_address(self): - """Test validating a HBM address.""" + """@brief Test validating a HBM address. + + @test Verifies that valid addresses are accepted and invalid ones raise exceptions + """ # Test validating a valid HBM address self.mem_model.mem_info_vars = {} @@ -125,7 +137,10 @@ def test_validate_hbm_address(self): self.program._validate_hbm_address("test_var", -1) def test_validate_hbm_address_mismatch(self): - """Test validating an HBM address that doesn't match the declared address.""" + """@brief Test validating an HBM address that doesn't match the declared address. + + @test Verifies that a RuntimeError is raised when address doesn't match + """ mock_var = MagicMock() mock_var.hbm_address = 5 self.mem_model.mem_info_vars = {"test_var": mock_var} @@ -134,25 +149,37 @@ def test_validate_hbm_address_mismatch(self): self.program._validate_hbm_address("test_var", 10) def test_validate_spad_address_valid(self): - """Test validating a valid SPAD address with HBM disabled.""" + """@brief Test validating a valid SPAD address with HBM disabled. + + @test Verifies that valid SPAD addresses are accepted when HBM is disabled + """ with patch.object(GlobalConfig, "hasHBM", False): self.mem_model.mem_info_vars = {} self.program._validate_spad_address("test_var", 10) # No exception should be raised def test_validate_spad_address_with_hbm_enabled(self): - """Test validating a SPAD address with HBM enabled (should raise AssertionError).""" + """@brief Test validating a SPAD address with HBM enabled. + + @test Verifies that an AssertionError is raised when HBM is enabled + """ with self.assertRaises(AssertionError): self.program._validate_spad_address("test_var", 10) def test_validate_spad_address_negative(self): - """Test validating a negative SPAD address.""" + """@brief Test validating a negative SPAD address. + + @test Verifies that a RuntimeError is raised for negative addresses + """ with patch.object(GlobalConfig, "hasHBM", False): with self.assertRaises(RuntimeError): self.program._validate_spad_address("test_var", -1) def test_validate_spad_address_mismatch(self): - """Test validating a SPAD address that doesn't match the declared address.""" + """@brief Test validating a SPAD address that doesn't match the declared address. + + @test Verifies that a RuntimeError is raised when address doesn't match + """ with patch.object(GlobalConfig, "hasHBM", False): mock_var = MagicMock() mock_var.hbm_address = 5 @@ -162,7 +189,10 @@ def test_validate_spad_address_mismatch(self): self.program._validate_spad_address("test_var", 10) def test_update_minsts(self): - """Test updating MInsts.""" + """@brief Test updating MInsts. + + @test Verifies that MInsts are correctly updated with offsets and variable addresses + """ # Create mock MInstructions mock_msyncc = MagicMock(spec=minst.MSyncc) mock_msyncc.target = 5 @@ -203,7 +233,10 @@ def test_update_minsts(self): ) def test_remove_and_merge_csyncm_cnop(self): - """Test removing CSyncm instructions and merging CNop instructions.""" + """@brief Test removing CSyncm instructions and merging CNop instructions. + + @test Verifies that CSyncm instructions are removed and CNop cycles are updated correctly + """ # Create mock CInstructions mock_ifetch = MagicMock(spec=cinst.IFetch) mock_ifetch.bundle = 1 @@ -250,7 +283,10 @@ def test_remove_and_merge_csyncm_cnop(self): self.assertEqual(instr.tokens[0], i) def test_update_cinsts_addresses_and_offsets(self): - """Test updating CInst addresses and offsets.""" + """@brief Test updating CInst addresses and offsets. + + @test Verifies that CInst addresses and offsets are correctly updated + """ # Create mock CInstructions mock_ifetch = MagicMock(spec=cinst.IFetch) mock_ifetch.bundle = 1 @@ -308,7 +344,10 @@ def test_update_cinsts_addresses_and_offsets(self): self.program._update_cinsts_addresses_and_offsets([mock_xinstfetch]) def test_update_cinsts(self): - """Test updating CInsts.""" + """@brief Test updating CInsts. + + @test Verifies that the correct update methods are called based on HBM configuration + """ # Create a mock for _remove_and_merge_csyncm_cnop and _update_cinsts_addresses_and_offsets with patch.object( LinkedProgram, "_remove_and_merge_csyncm_cnop" @@ -337,7 +376,10 @@ def test_update_cinsts(self): mock_update.assert_called_once_with(kernel_cinstrs) def test_update_xinsts(self): - """Test updating XInsts.""" + """@brief Test updating XInsts. + + @test Verifies that XInst bundles are correctly updated and invalid sequences are detected + """ # Create mock XInstructions mock_xinst1 = MagicMock() mock_xinst1.bundle = 1 @@ -367,7 +409,10 @@ def test_update_xinsts(self): self.program._update_xinsts(kernel_xinstrs) def test_link_kernel(self): - """Test linking a kernel.""" + """@brief Test linking a kernel. + + @test Verifies that a kernel is correctly linked with updated instructions + """ # Create mocks for the update methods with patch.object( LinkedProgram, "_update_minsts" @@ -440,7 +485,10 @@ def test_link_kernel(self): self.assertIn("minst_comment0", minst_output) def test_link_kernel_with_no_hbm(self): - """Test linking a kernel with HBM disabled.""" + """@brief Test linking a kernel with HBM disabled. + + @test Verifies that MInsts are ignored when HBM is disabled + """ with patch.object(GlobalConfig, "hasHBM", False): # Create mocks for the update methods with patch.object( @@ -482,7 +530,10 @@ def test_link_kernel_with_no_hbm(self): self.assertEqual(minst_output, "") def test_link_kernel_with_closed_program(self): - """Test linking a kernel with a closed program.""" + """@brief Test linking a kernel with a closed program. + + @test Verifies that a RuntimeError is raised when linking to a closed program + """ # Close the program self.program._LinkedProgram__is_open = False @@ -491,7 +542,10 @@ def test_link_kernel_with_closed_program(self): self.program.link_kernel([], [], []) def test_link_kernel_with_suppress_comments(self): - """Test linking a kernel with comments suppressed.""" + """@brief Test linking a kernel with comments suppressed. + + @test Verifies that comments are not included in the output when suppressed + """ with patch.object(GlobalConfig, "suppress_comments", True): # Create mocks for the update methods with patch.object(LinkedProgram, "_update_minsts"), patch.object( @@ -527,15 +581,21 @@ def test_link_kernel_with_suppress_comments(self): class TestJoinDinstKernels(unittest.TestCase): - """Tests for the join_dinst_kernels static method.""" + """@brief Tests for the join_dinst_kernels static method.""" def test_join_dinst_kernels_empty(self): - """Test joining empty list of DInst kernels.""" + """@brief Test joining empty list of DInst kernels. + + @test Verifies that a ValueError is raised for an empty list + """ with self.assertRaises(ValueError): LinkedProgram.join_dinst_kernels([]) def test_join_dinst_kernels_single_kernel(self): - """Test joining a single DInst kernel.""" + """@brief Test joining a single DInst kernel. + + @test Verifies that instructions from a single kernel are correctly processed + """ # Create mock DInstructions mock_dload = MagicMock(spec=dinst.DLoad) mock_dload.var = "var1" @@ -556,7 +616,10 @@ def test_join_dinst_kernels_single_kernel(self): self.assertEqual(mock_dstore.address, 1) def test_join_dinst_kernels_multiple_kernels(self): - """Test joining multiple DInst kernels.""" + """@brief Test joining multiple DInst kernels. + + @test Verifies that instructions from multiple kernels are correctly merged + """ # Create mock DInstructions for first kernel mock_dload1 = MagicMock(spec=dinst.DLoad) mock_dload1.var = "var1" @@ -593,7 +656,10 @@ def test_join_dinst_kernels_multiple_kernels(self): self.assertEqual(used_addresses, {0, 1, 2}) # Three consecutive addresses def test_join_dinst_kernels_with_carry_over_vars(self): - """Test joining DInst kernels with carry-over variables that are both input and output.""" + """@brief Test joining DInst kernels with carry-over variables. + + @test Verifies that variables used across kernels are properly consolidated + """ # Create mock DInstructions for first kernel mock_dload1 = MagicMock(spec=dinst.DLoad) mock_dload1.var = "var1" diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_steps/test_variable_discovery.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_steps/test_variable_discovery.py index 41e1844b..4e9e942b 100644 --- a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_steps/test_variable_discovery.py +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_steps/test_variable_discovery.py @@ -5,20 +5,20 @@ # generative artificial intelligence solutions """ -Unit tests for the variable discovery module. +@brief Unit tests for the variable discovery module. """ import unittest from unittest.mock import patch, MagicMock -from linker.steps.variable_discovery import discoverVariables, discoverVariablesSPAD +from linker.steps.variable_discovery import discover_variables, discover_variables_spad class TestVariableDiscovery(unittest.TestCase): - """Tests for the variable discovery functions.""" + """@brief Tests for the variable discovery functions.""" def setUp(self): - """Set up test fixtures.""" + """@brief Set up test fixtures.""" # Group MInstructions in a dictionary self.m_instrs = { "load": MagicMock(source="var1"), @@ -42,7 +42,10 @@ def setUp(self): def test_discover_variables_valid( self, mock_validate, mock_minst_class, mock_minst ): - """Test discovering variables from valid MInstructions.""" + """@brief Test discovering variables from valid MInstructions. + + @test Verifies that variables are correctly discovered from MLoad and MStore instructions + """ # Setup mocks mock_minst.MLoad = MagicMock() mock_minst.MStore = MagicMock() @@ -82,7 +85,7 @@ def is_mstore_side_effect(obj): ] # Call the function - result = list(discoverVariables(minstrs)) + result = list(discover_variables(minstrs)) # Verify results self.assertEqual(result, ["var1", "var2"]) @@ -90,22 +93,28 @@ def is_mstore_side_effect(obj): mock_validate.assert_any_call("var2") def test_discover_variables_empty_list(self): - """Test discovering variables from an empty list of MInstructions.""" + """@brief Test discovering variables from an empty list of MInstructions. + + @test Verifies that an empty list is returned when no instructions are provided + """ # No need to patch isinstance for an empty list - result = list(discoverVariables([])) + result = list(discover_variables([])) # Verify results - should be an empty list self.assertEqual(result, []) def test_discover_variables_invalid_type(self): - """Test discovering variables with invalid types in the list.""" + """@brief Test discovering variables with invalid types in the list. + + @test Verifies that a TypeError is raised when an invalid object is in the list + """ # Setup mock to fail the isinstance check invalid_obj = MagicMock() with patch("linker.steps.variable_discovery.isinstance", return_value=False): # Call the function with a list containing an invalid type with self.assertRaises(TypeError) as context: - list(discoverVariables([invalid_obj])) + list(discover_variables([invalid_obj])) # Verify the error message self.assertIn("not a valid MInstruction", str(context.exception)) @@ -113,7 +122,10 @@ def test_discover_variables_invalid_type(self): @patch("linker.steps.variable_discovery.minst") @patch("assembler.memory_model.variable.Variable.validateName") def test_discover_variables_invalid_variable_name(self, mock_validate, mock_minst): - """Test discovering variables with an invalid variable name.""" + """@brief Test discovering variables with an invalid variable name. + + @test Verifies that a RuntimeError is raised when a variable name is invalid + """ # Setup mocks mock_minst.MLoad = MagicMock() @@ -126,7 +138,7 @@ def test_discover_variables_invalid_variable_name(self, mock_validate, mock_mins ): # Call the function with self.assertRaises(RuntimeError) as context: - list(discoverVariables([self.m_instrs["load"]])) + list(discover_variables([self.m_instrs["load"]])) # Verify the error message self.assertIn("Invalid Variable name", str(context.exception)) @@ -137,7 +149,10 @@ def test_discover_variables_invalid_variable_name(self, mock_validate, mock_mins def test_discover_variables_spad_valid( self, mock_validate, mock_cinst_class, mock_cinst ): - """Test discovering variables from valid CInstructions.""" + """@brief Test discovering variables from valid CInstructions. + + @test Verifies that variables are correctly discovered from all relevant CInstruction types + """ # Setup mocks mock_cinst.BLoad = MagicMock() mock_cinst.CLoad = MagicMock() @@ -182,7 +197,7 @@ def mock_isinstance(obj, cls): "linker.steps.variable_discovery.isinstance", side_effect=mock_isinstance ): # Call the function - result = list(discoverVariablesSPAD(cinstrs)) + result = list(discover_variables_spad(cinstrs)) # Verify results self.assertEqual(result, ["var3", "var4", "var5", "var6", "var7"]) @@ -193,22 +208,28 @@ def mock_isinstance(obj, cls): mock_validate.assert_any_call("var7") def test_discover_variables_spad_empty_list(self): - """Test discovering variables from an empty list of CInstructions.""" + """@brief Test discovering variables from an empty list of CInstructions. + + @test Verifies that an empty list is returned when no instructions are provided + """ # Call the function with an empty list - result = list(discoverVariablesSPAD([])) + result = list(discover_variables_spad([])) # Verify results - should be an empty list self.assertEqual(result, []) def test_discover_variables_spad_invalid_type(self): - """Test discovering variables with invalid types in the list.""" + """@brief Test discovering variables with invalid types in the list. + + @test Verifies that a TypeError is raised when an invalid object is in the list + """ # Setup mock invalid_obj = MagicMock() with patch("linker.steps.variable_discovery.isinstance", return_value=False): # Call the function with a list containing an invalid type with self.assertRaises(TypeError) as context: - list(discoverVariablesSPAD([invalid_obj])) + list(discover_variables_spad([invalid_obj])) # Verify the error message self.assertIn("not a valid MInstruction", str(context.exception)) @@ -219,7 +240,10 @@ def test_discover_variables_spad_invalid_type(self): def test_discover_variables_spad_invalid_variable_name( self, mock_validate, mock_cinst_class, mock_cinst ): - """Test discovering variables with an invalid variable name.""" + """@brief Test discovering variables with an invalid variable name. + + @test Verifies that a RuntimeError is raised when a variable name is invalid + """ # Setup mocks mock_cinst.BLoad = MagicMock() @@ -241,7 +265,7 @@ def test_discover_variables_spad_invalid_variable_name( ): # Call the function with self.assertRaises(RuntimeError) as context: - list(discoverVariablesSPAD([self.c_instrs["bload"]])) + list(discover_variables_spad([self.c_instrs["bload"]])) # Verify the error message self.assertIn("Invalid Variable name", str(context.exception)) From c7e6fca2eede5de12135c0aadb240ba1ca2cae48 Mon Sep 17 00:00:00 2001 From: "Rojas Chaves, Jose" Date: Mon, 7 Jul 2025 19:52:09 +0000 Subject: [PATCH 07/12] Fixing for CI --- .../linker/instructions/cinst/cinstruction.py | 9 --------- .../linker/instructions/dinst/dinstruction.py | 8 -------- .../linker/instructions/minst/minstruction.py | 9 --------- assembler_tools/hec-assembler-tools/pytest.ini | 2 +- 4 files changed, 1 insertion(+), 27 deletions(-) diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/cinst/cinstruction.py b/assembler_tools/hec-assembler-tools/linker/instructions/cinst/cinstruction.py index 81df3417..a6157425 100644 --- a/assembler_tools/hec-assembler-tools/linker/instructions/cinst/cinstruction.py +++ b/assembler_tools/hec-assembler-tools/linker/instructions/cinst/cinstruction.py @@ -58,12 +58,3 @@ def __init__(self, tokens: list, comment: str = ""): ValueError: If the number of tokens is invalid or the instruction name is incorrect. """ super().__init__(tokens, comment=comment) - - def to_line(self) -> str: - """ - Retrieves the string form of the instruction to write to the instruction file. - - Returns: - str: The string representation of the instruction, excluding the first token. - """ - return ", ".join(self.tokens[1:]) diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dinstruction.py b/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dinstruction.py index 14c69f30..9db0ad16 100644 --- a/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dinstruction.py +++ b/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dinstruction.py @@ -144,11 +144,3 @@ def address(self, value: int): @param value The new memory address (string or integer). """ self._address = value - - def to_line(self) -> str: - """ - @brief Retrieves the string form of the instruction to write to the instruction file. - - @return The string representation of the instruction. - """ - return ", ".join(self.tokens) diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/minst/minstruction.py b/assembler_tools/hec-assembler-tools/linker/instructions/minst/minstruction.py index 10f04ecb..e78ce2cd 100644 --- a/assembler_tools/hec-assembler-tools/linker/instructions/minst/minstruction.py +++ b/assembler_tools/hec-assembler-tools/linker/instructions/minst/minstruction.py @@ -58,12 +58,3 @@ def __init__(self, tokens: list, comment: str = ""): ValueError: If the number of tokens is invalid or the instruction name is incorrect. """ super().__init__(tokens, comment=comment) - - def to_line(self) -> str: - """ - Retrieves the string form of the instruction to write to the instruction file. - - Returns: - str: The string representation of the instruction, excluding the first token. - """ - return ", ".join(self.tokens[1:]) diff --git a/assembler_tools/hec-assembler-tools/pytest.ini b/assembler_tools/hec-assembler-tools/pytest.ini index be1763fa..06e92c68 100644 --- a/assembler_tools/hec-assembler-tools/pytest.ini +++ b/assembler_tools/hec-assembler-tools/pytest.ini @@ -1,4 +1,4 @@ [pytest] pythonpath = . testpaths = tests -addopts = --cov=. +#addopts = --cov=. From c8d83343682ea6058248a7b083c799a67f6a416b Mon Sep 17 00:00:00 2001 From: "Rojas Chaves, Jose" Date: Mon, 7 Jul 2025 22:06:37 +0000 Subject: [PATCH 08/12] Fixing pyLimt --- .../linker/steps/program_linker.py | 86 +++++++++---------- .../test_steps/test_program_linker.py | 48 +++++------ 2 files changed, 62 insertions(+), 72 deletions(-) diff --git a/assembler_tools/hec-assembler-tools/linker/steps/program_linker.py b/assembler_tools/hec-assembler-tools/linker/steps/program_linker.py index 092a86dd..63e5511f 100644 --- a/assembler_tools/hec-assembler-tools/linker/steps/program_linker.py +++ b/assembler_tools/hec-assembler-tools/linker/steps/program_linker.py @@ -44,15 +44,15 @@ def __init__( This memory model will be modified by this object when linking kernels. @param suppress_comments (bool): Whether to suppress comments in the output. """ - self.__minst_ostream = program_minst_ostream - self.__cinst_ostream = program_cinst_ostream - self.__xinst_ostream = program_xinst_ostream + self._minst_ostream = program_minst_ostream + self._cinst_ostream = program_cinst_ostream + self._xinst_ostream = program_xinst_ostream self.__mem_model = mem_model - self.__bundle_offset = 0 - self.__minst_line_offset = 0 - self.__cinst_line_offset = 0 - self.__kernel_count = 0 # Number of kernels linked into this program - self.__is_open = ( + self._bundle_offset = 0 + self._minst_line_offset = 0 + self._cinst_line_offset = 0 + self._kernel_count = 0 # Number of kernels linked into this program + self._is_open = ( True # Tracks whether this program is still accepting kernels to link ) @@ -63,7 +63,7 @@ def is_open(self) -> bool: @return bool True if the program is open, False otherwise. """ - return self.__is_open + return self._is_open def close(self): """ @@ -77,31 +77,31 @@ def close(self): raise RuntimeError("Program is already closed.") # Add closing `cexit` - tokens = [str(self.__cinst_line_offset), cinst.CExit.name] + tokens = [str(self._cinst_line_offset), cinst.CExit.name] cexit_cinstr = cinst.CExit(tokens) print( f"{cexit_cinstr.tokens[0]}, {cexit_cinstr.to_line()}", - file=self.__cinst_ostream, + file=self._cinst_ostream, ) # Add closing msyncc tokens = [ - str(self.__minst_line_offset), + str(self._minst_line_offset), minst.MSyncc.name, - str(self.__cinst_line_offset + 1), + str(self._cinst_line_offset + 1), ] cmsyncc_minstr = minst.MSyncc(tokens) print( f"{cmsyncc_minstr.tokens[0]}, {cmsyncc_minstr.to_line()}", end="", - file=self.__minst_ostream, + file=self._minst_ostream, ) if not GlobalConfig.suppress_comments: - print(" # terminating MInstQ", end="", file=self.__minst_ostream) - print(file=self.__minst_ostream) + print(" # terminating MInstQ", end="", file=self._minst_ostream) + print(file=self._minst_ostream) # Program has been closed - self.__is_open = False + self._is_open = False def _validate_hbm_address(self, var_name: str, hbm_address: int): """ @@ -174,13 +174,11 @@ def _update_minsts(self, kernel_minstrs: list): for minstr in kernel_minstrs: # Update msyncc if isinstance(minstr, minst.MSyncc): - minstr.target = minstr.target + self.__cinst_line_offset + minstr.target = minstr.target + self._cinst_line_offset # Change mload variable names into HBM addresses if isinstance(minstr, minst.MLoad): var_name = minstr.source - hbm_address = self.__mem_model.useVariable( - var_name, self.__kernel_count - ) + hbm_address = self.__mem_model.useVariable(var_name, self._kernel_count) self._validate_hbm_address(var_name, hbm_address) minstr.source = str(hbm_address) minstr.comment = ( @@ -191,9 +189,7 @@ def _update_minsts(self, kernel_minstrs: list): # Change mstore variable names into HBM addresses if isinstance(minstr, minst.MStore): var_name = minstr.dest - hbm_address = self.__mem_model.useVariable( - var_name, self.__kernel_count - ) + hbm_address = self.__mem_model.useVariable(var_name, self._kernel_count) self._validate_hbm_address(var_name, hbm_address) minstr.dest = str(hbm_address) minstr.comment = ( @@ -300,7 +296,7 @@ def _update_cinsts_addresses_and_offsets(self, kernel_cinstrs: list): for cinstr in kernel_cinstrs: # Update ifetch if isinstance(cinstr, cinst.IFetch): - cinstr.bundle = cinstr.bundle + self.__bundle_offset + cinstr.bundle = cinstr.bundle + self._bundle_offset # Update xinstfetch if isinstance(cinstr, cinst.XInstFetch): raise NotImplementedError( @@ -308,7 +304,7 @@ def _update_cinsts_addresses_and_offsets(self, kernel_cinstrs: list): ) # Update csyncm if isinstance(cinstr, cinst.CSyncm): - cinstr.target = cinstr.target + self.__minst_line_offset + cinstr.target = cinstr.target + self._minst_line_offset if not GlobalConfig.hasHBM: # update all SPAD instruction variable names to be SPAD addresses @@ -318,7 +314,7 @@ def _update_cinsts_addresses_and_offsets(self, kernel_cinstrs: list): ): var_name = cinstr.source hbm_address = self.__mem_model.useVariable( - var_name, self.__kernel_count + var_name, self._kernel_count ) self._validate_spad_address(var_name, hbm_address) cinstr.source = str(hbm_address) @@ -330,7 +326,7 @@ def _update_cinsts_addresses_and_offsets(self, kernel_cinstrs: list): if isinstance(cinstr, cinst.CStore): var_name = cinstr.dest hbm_address = self.__mem_model.useVariable( - var_name, self.__kernel_count + var_name, self._kernel_count ) self._validate_spad_address(var_name, hbm_address) cinstr.dest = str(hbm_address) @@ -367,9 +363,9 @@ def _update_xinsts(self, kernel_xinstrs: list) -> int: @return int The last bundle number after updating. """ - last_bundle = self.__bundle_offset + last_bundle = self._bundle_offset for xinstr in kernel_xinstrs: - xinstr.bundle = xinstr.bundle + self.__bundle_offset + xinstr.bundle = xinstr.bundle + self._bundle_offset if last_bundle > xinstr.bundle: raise RuntimeError( f'Detected invalid bundle. Instruction bundle is less than previous: "{xinstr.to_line()}"' @@ -405,37 +401,37 @@ def link_kernel( self._update_minsts(kernel_minstrs) self._update_cinsts(kernel_cinstrs) - self.__bundle_offset = self._update_xinsts(kernel_xinstrs) + 1 + self._bundle_offset = self._update_xinsts(kernel_xinstrs) + 1 # Append the kernel to the output for xinstr in kernel_xinstrs: - print(xinstr.to_line(), end="", file=self.__xinst_ostream) + print(xinstr.to_line(), end="", file=self._xinst_ostream) if not GlobalConfig.suppress_comments and xinstr.comment: - print(f" #{xinstr.comment}", end="", file=self.__xinst_ostream) - print(file=self.__xinst_ostream) + print(f" #{xinstr.comment}", end="", file=self._xinst_ostream) + print(file=self._xinst_ostream) for idx, cinstr in enumerate(kernel_cinstrs[:-1]): # Skip the `cexit` - line_no = idx + self.__cinst_line_offset - print(f"{line_no}, {cinstr.to_line()}", end="", file=self.__cinst_ostream) + line_no = idx + self._cinst_line_offset + print(f"{line_no}, {cinstr.to_line()}", end="", file=self._cinst_ostream) if not GlobalConfig.suppress_comments and cinstr.comment: - print(f" #{cinstr.comment}", end="", file=self.__cinst_ostream) - print(file=self.__cinst_ostream) + print(f" #{cinstr.comment}", end="", file=self._cinst_ostream) + print(file=self._cinst_ostream) for idx, minstr in enumerate(kernel_minstrs[:-1]): # Skip the exit `msyncc` - line_no = idx + self.__minst_line_offset - print(f"{line_no}, {minstr.to_line()}", end="", file=self.__minst_ostream) + line_no = idx + self._minst_line_offset + print(f"{line_no}, {minstr.to_line()}", end="", file=self._minst_ostream) if not GlobalConfig.suppress_comments and minstr.comment: - print(f" #{minstr.comment}", end="", file=self.__minst_ostream) - print(file=self.__minst_ostream) + print(f" #{minstr.comment}", end="", file=self._minst_ostream) + print(file=self._minst_ostream) - self.__minst_line_offset += ( + self._minst_line_offset += ( len(kernel_minstrs) - 1 ) # Subtract last line that is getting removed - self.__cinst_line_offset += ( + self._cinst_line_offset += ( len(kernel_cinstrs) - 1 ) # Subtract last line that is getting removed - self.__kernel_count += 1 # Count the appended kernel + self._kernel_count += 1 # Count the appended kernel @classmethod def join_dinst_kernels( diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_steps/test_program_linker.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_steps/test_program_linker.py index 46f88f70..96c0ccbe 100644 --- a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_steps/test_program_linker.py +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_steps/test_program_linker.py @@ -59,21 +59,15 @@ def test_init(self): @test Verifies that all instance variables are correctly initialized """ - self.assertEqual( - self.program._LinkedProgram__minst_ostream, self.streams["minst"] - ) - self.assertEqual( - self.program._LinkedProgram__cinst_ostream, self.streams["cinst"] - ) - self.assertEqual( - self.program._LinkedProgram__xinst_ostream, self.streams["xinst"] - ) + self.assertEqual(self.program._minst_ostream, self.streams["minst"]) + self.assertEqual(self.program._cinst_ostream, self.streams["cinst"]) + self.assertEqual(self.program._xinst_ostream, self.streams["xinst"]) self.assertEqual(self.program._LinkedProgram__mem_model, self.mem_model) - self.assertEqual(self.program._LinkedProgram__bundle_offset, 0) - self.assertEqual(self.program._LinkedProgram__minst_line_offset, 0) - self.assertEqual(self.program._LinkedProgram__cinst_line_offset, 0) - self.assertEqual(self.program._LinkedProgram__kernel_count, 0) - self.assertTrue(self.program._LinkedProgram__is_open) + self.assertEqual(self.program._bundle_offset, 0) + self.assertEqual(self.program._minst_line_offset, 0) + self.assertEqual(self.program._cinst_line_offset, 0) + self.assertEqual(self.program._kernel_count, 0) + self.assertTrue(self.program.is_open) def test_is_open_property(self): """@brief Test the is_open property. @@ -81,7 +75,7 @@ def test_is_open_property(self): @test Verifies that the is_open property reflects the internal state """ self.assertTrue(self.program.is_open) - self.program._LinkedProgram__is_open = False + self.program._is_open = False self.assertFalse(self.program.is_open) def test_close(self): @@ -213,8 +207,8 @@ def test_update_minsts(self): # Execute the update kernel_minstrs = [mock_msyncc, mock_mload, mock_mstore] - self.program._LinkedProgram__cinst_line_offset = 10 # Set initial offset - self.program._LinkedProgram__kernel_count = 1 # Set kernel count + self.program._cinst_line_offset = 10 # Set initial offset + self.program._kernel_count = 1 # Set kernel count self.program._update_minsts(kernel_minstrs) # Verify results @@ -307,8 +301,8 @@ def test_update_cinsts_addresses_and_offsets(self): # Execute the method with HBM enabled kernel_cinstrs = [mock_ifetch, mock_csyncm] - self.program._LinkedProgram__bundle_offset = 10 - self.program._LinkedProgram__minst_line_offset = 20 + self.program._bundle_offset = 10 + self.program._minst_line_offset = 20 self.program._update_cinsts_addresses_and_offsets(kernel_cinstrs) # Verify results with HBM enabled @@ -324,7 +318,7 @@ def test_update_cinsts_addresses_and_offsets(self): ] # Return different addresses for different vars kernel_cinstrs = [mock_bload, mock_cstore] - self.program._LinkedProgram__kernel_count = 2 + self.program._kernel_count = 2 self.program._update_cinsts_addresses_and_offsets(kernel_cinstrs) # Verify SPAD instructions were updated @@ -392,7 +386,7 @@ def test_update_xinsts(self): # Execute the method kernel_xinstrs = [mock_xinst1, mock_xinst2] - self.program._LinkedProgram__bundle_offset = 10 + self.program._bundle_offset = 10 last_bundle = self.program._update_xinsts(kernel_xinstrs) # Verify results @@ -456,18 +450,18 @@ def test_link_kernel(self): mock_update_xinsts.assert_called_once_with(kernel_xinstrs) # Verify bundle offset was updated - self.assertEqual(self.program._LinkedProgram__bundle_offset, 6) # 5 + 1 + self.assertEqual(self.program._bundle_offset, 6) # 5 + 1 # Verify line offsets were updated self.assertEqual( - self.program._LinkedProgram__minst_line_offset, 1 + self.program._minst_line_offset, 1 ) # len(kernel_minstrs) - 1 self.assertEqual( - self.program._LinkedProgram__cinst_line_offset, 1 + self.program._cinst_line_offset, 1 ) # len(kernel_cinstrs) - 1 # Verify kernel count was incremented - self.assertEqual(self.program._LinkedProgram__kernel_count, 1) + self.assertEqual(self.program._kernel_count, 1) # Verify output streams contain the instructions xinst_output = self.streams["xinst"].getvalue() @@ -523,7 +517,7 @@ def test_link_kernel_with_no_hbm(self): mock_update_xinsts.assert_called_once_with(kernel_xinstrs) # Verify bundle offset was updated - self.assertEqual(self.program._LinkedProgram__bundle_offset, 6) # 5 + 1 + self.assertEqual(self.program._bundle_offset, 6) # 5 + 1 # No MInst output when HBM is disabled minst_output = self.streams["minst"].getvalue() @@ -535,7 +529,7 @@ def test_link_kernel_with_closed_program(self): @test Verifies that a RuntimeError is raised when linking to a closed program """ # Close the program - self.program._LinkedProgram__is_open = False + self.program._is_open = False # Try to link a kernel with self.assertRaises(RuntimeError): From e88abcca2557d08cbb66370449fcb02b864a511e Mon Sep 17 00:00:00 2001 From: "Rojas Chaves, Jose" Date: Mon, 7 Jul 2025 22:12:52 +0000 Subject: [PATCH 09/12] Removing pyLint filters --- .pre-commit-config.yaml | 2 -- 1 file changed, 2 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 84e1b336..aa5b6563 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -79,8 +79,6 @@ repos: *assembler_tools/hec-assembler-tools/debug_tools/main\.py| *assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/| *assembler_tools/hec-assembler-tools/he_as\.py| - *assembler_tools/hec-assembler-tools/linker/__init__\.py| - *assembler_tools/hec-assembler-tools/linker/instructions/__init__\.py| *assembler_tools/hec-assembler-tools/assembler/spec_config/isa_spec.py) args: - -rn # Only display messages From e00667ced19a98c91c956fb9cddf94cf7a06c552 Mon Sep 17 00:00:00 2001 From: "Rojas Chaves, Jose" Date: Tue, 8 Jul 2025 00:03:01 +0000 Subject: [PATCH 10/12] Fixing inits --- .../hec-assembler-tools/he_link.py | 4 +- .../hec-assembler-tools/linker/__init__.py | 119 +++++++++--------- .../linker/instructions/__init__.py | 2 +- .../linker/steps/program_linker.py | 12 +- .../tests/unit_tests/test_he_link.py | 4 +- .../tests/unit_tests/test_linker/test_init.py | 54 ++++---- .../test_instructions/test_init.py | 8 +- .../test_steps/test_program_linker.py | 8 +- 8 files changed, 110 insertions(+), 101 deletions(-) diff --git a/assembler_tools/hec-assembler-tools/he_link.py b/assembler_tools/hec-assembler-tools/he_link.py index 68a2c00e..f714529e 100644 --- a/assembler_tools/hec-assembler-tools/he_link.py +++ b/assembler_tools/hec-assembler-tools/he_link.py @@ -311,7 +311,7 @@ def scan_variables(input_files, mem_model, verbose_stream): ) kernel_cinstrs = loader.load_cinst_kernel_from_file(kernel.cinst) for var_name in variable_discovery.discover_variables_spad(kernel_cinstrs): - mem_model.addVariable(var_name) + mem_model.add_variable(var_name) else: if verbose_stream: print( @@ -321,7 +321,7 @@ def scan_variables(input_files, mem_model, verbose_stream): ) kernel_minstrs = loader.load_minst_kernel_from_file(kernel.minst) for var_name in variable_discovery.discover_variables(kernel_minstrs): - mem_model.addVariable(var_name) + mem_model.add_variable(var_name) def check_unused_variables(mem_model): diff --git a/assembler_tools/hec-assembler-tools/linker/__init__.py b/assembler_tools/hec-assembler-tools/linker/__init__.py index ded86921..3d6174e2 100644 --- a/assembler_tools/hec-assembler-tools/linker/__init__.py +++ b/assembler_tools/hec-assembler-tools/linker/__init__.py @@ -7,9 +7,10 @@ """@brief linker/__init__.py contains classes to encapsulate the memory model used by the linker.""" import collections.abc as collections +from typing import Dict + from assembler.common.config import GlobalConfig from assembler.memory_model import mem_info -from typing import Dict class VariableInfo(mem_info.MemInfoVariable): @@ -64,7 +65,7 @@ def buffer(self) -> list: """ return self.__buffer - def forceAllocate(self, var_info: VariableInfo, hbm_address: int): + def force_allocate(self, var_info: VariableInfo, hbm_address: int): """ @brief Forcefully allocates a variable at a specific HBM address. @@ -76,9 +77,7 @@ def forceAllocate(self, var_info: VariableInfo, hbm_address: int): """ if hbm_address < 0 or hbm_address >= len(self.buffer): raise IndexError( - "`hbm_address` out of bounds. Expected a word address in range [0, {}), but {} received".format( - len(self.buffer), hbm_address - ) + f"`hbm_address` out of bounds. Expected a word address in range [0, {len(self.buffer)}), but {hbm_address} received" ) if var_info.hbm_address != hbm_address: if var_info.hbm_address >= 0: @@ -93,10 +92,8 @@ def forceAllocate(self, var_info: VariableInfo, hbm_address: int): # Note: there is no HBM, so, SPAD is used as the sole memory space if in_var_info and in_var_info.uses > 0: raise RuntimeError( - ( - "HBM address {} already occupied by variable {} " - "when attempting to allocate variable {}" - ).format(hbm_address, in_var_info.var_name, var_info.var_name) + f"HBM address {hbm_address} already occupied by variable {in_var_info.var_name} " + f"when attempting to allocate variable {var_info.var_name}" ) else: if in_var_info and ( @@ -104,10 +101,8 @@ def forceAllocate(self, var_info: VariableInfo, hbm_address: int): or in_var_info.last_kernel_used >= var_info.last_kernel_used ): raise RuntimeError( - ( - "HBM address {} already occupied by variable {} " - "when attempting to allocate variable {}" - ).format(hbm_address, in_var_info.var_name, var_info.var_name) + f"HBM address {hbm_address} already occupied by variable {in_var_info.var_name} " + f"when attempting to allocate variable {var_info.var_name}" ) var_info.hbm_address = hbm_address self.buffer[hbm_address] = var_info @@ -137,7 +132,7 @@ def allocate(self, var_info: VariableInfo): break if retval < 0: raise RuntimeError("Out of HBM memory.") - self.forceAllocate(var_info, retval) + self.force_allocate(var_info, retval) class MemoryModel: @@ -157,50 +152,60 @@ def __init__(self, hbm_size_words: int, mem_meta_info: mem_info.MemInfo): self.__variables: Dict[str, VariableInfo] = ( {} ) # dict(var_name: str, VariableInfo) - self.__keygen_vars = { - var_info.var_name: var_info for var_info in self.__mem_info.keygens - } - self.__mem_info_inputs = { - var_info.var_name: var_info for var_info in self.__mem_info.inputs - } - self.__mem_info_outputs = { - var_info.var_name: var_info for var_info in self.__mem_info.outputs + + # Group related collections into a dictionary + self.__mem_collections = { + "keygen_vars": { + var_info.var_name: var_info for var_info in self.__mem_info.keygens + }, + "inputs": { + var_info.var_name: var_info for var_info in self.__mem_info.inputs + }, + "outputs": { + var_info.var_name: var_info for var_info in self.__mem_info.outputs + }, + "meta": ( + { + var_info.var_name: var_info + for var_info in self.__mem_info.metadata.intt_auxiliary_table + } + | { + var_info.var_name: var_info + for var_info in self.__mem_info.metadata.intt_routing_table + } + | { + var_info.var_name: var_info + for var_info in self.__mem_info.metadata.ntt_auxiliary_table + } + | { + var_info.var_name: var_info + for var_info in self.__mem_info.metadata.ntt_routing_table + } + | { + var_info.var_name: var_info + for var_info in self.__mem_info.metadata.ones + } + | { + var_info.var_name: var_info + for var_info in self.__mem_info.metadata.twiddle + } + | { + var_info.var_name: var_info + for var_info in self.__mem_info.metadata.keygen_seeds + } + ), } - self.__mem_info_meta = ( - { - var_info.var_name: var_info - for var_info in self.__mem_info.metadata.intt_auxiliary_table - } - | { - var_info.var_name: var_info - for var_info in self.__mem_info.metadata.intt_routing_table - } - | { - var_info.var_name: var_info - for var_info in self.__mem_info.metadata.ntt_auxiliary_table - } - | { - var_info.var_name: var_info - for var_info in self.__mem_info.metadata.ntt_routing_table - } - | { - var_info.var_name: var_info - for var_info in self.__mem_info.metadata.ones - } - | { - var_info.var_name: var_info - for var_info in self.__mem_info.metadata.twiddle - } - | { - var_info.var_name: var_info - for var_info in self.__mem_info.metadata.keygen_seeds - } + + # Derived collections + self.__mem_info_fixed_addr_vars = ( + self.__mem_collections["outputs"] | self.__mem_collections["meta"] ) - self.__mem_info_fixed_addr_vars = self.__mem_info_outputs | self.__mem_info_meta # Keygen variables should not be part of mem_info_vars set since they # do not start in HBM self.__mem_info_vars = ( - self.__mem_info_inputs | self.__mem_info_outputs | self.__mem_info_meta + self.__mem_collections["inputs"] + | self.__mem_collections["outputs"] + | self.__mem_collections["meta"] ) @property @@ -212,7 +217,7 @@ def mem_info_meta(self) -> collections.Collection: @return Collection of metadata variable names. """ - return self.__mem_info_meta + return self.__mem_collections["meta"] @property def mem_info_vars(self) -> collections.Collection: @@ -235,7 +240,7 @@ def variables(self) -> dict: """ return self.__variables - def addVariable(self, var_name: str): + def add_variable(self, var_name: str): """ @brief Adds a variable to the HBM model. @@ -253,13 +258,13 @@ def addVariable(self, var_name: str): # with predefined HBM address if var_name in self.__mem_info_fixed_addr_vars: var_info.uses = float("inf") - self.hbm.forceAllocate( + self.hbm.force_allocate( var_info, self.__mem_info_vars[var_name].hbm_address ) self.variables[var_name] = var_info var_info.uses += 1 - def useVariable(self, var_name: str, kernel: int) -> int: + def use_variable(self, var_name: str, kernel: int) -> int: """ @brief Uses a variable, decrementing its usage count. diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/__init__.py b/assembler_tools/hec-assembler-tools/linker/instructions/__init__.py index 57d2d7a5..32e96a7a 100644 --- a/assembler_tools/hec-assembler-tools/linker/instructions/__init__.py +++ b/assembler_tools/hec-assembler-tools/linker/instructions/__init__.py @@ -25,7 +25,7 @@ def create_from_str_line(line: str, factory) -> Optional[BaseInstruction]: for instr_type in factory: try: retval = instr_type(tokens, comment) - except: + except (TypeError, ValueError, AttributeError): retval = None if retval: break diff --git a/assembler_tools/hec-assembler-tools/linker/steps/program_linker.py b/assembler_tools/hec-assembler-tools/linker/steps/program_linker.py index 63e5511f..738cd125 100644 --- a/assembler_tools/hec-assembler-tools/linker/steps/program_linker.py +++ b/assembler_tools/hec-assembler-tools/linker/steps/program_linker.py @@ -178,7 +178,9 @@ def _update_minsts(self, kernel_minstrs: list): # Change mload variable names into HBM addresses if isinstance(minstr, minst.MLoad): var_name = minstr.source - hbm_address = self.__mem_model.useVariable(var_name, self._kernel_count) + hbm_address = self.__mem_model.use_variable( + var_name, self._kernel_count + ) self._validate_hbm_address(var_name, hbm_address) minstr.source = str(hbm_address) minstr.comment = ( @@ -189,7 +191,9 @@ def _update_minsts(self, kernel_minstrs: list): # Change mstore variable names into HBM addresses if isinstance(minstr, minst.MStore): var_name = minstr.dest - hbm_address = self.__mem_model.useVariable(var_name, self._kernel_count) + hbm_address = self.__mem_model.use_variable( + var_name, self._kernel_count + ) self._validate_hbm_address(var_name, hbm_address) minstr.dest = str(hbm_address) minstr.comment = ( @@ -313,7 +317,7 @@ def _update_cinsts_addresses_and_offsets(self, kernel_cinstrs: list): cinstr, (cinst.BLoad, cinst.BOnes, cinst.CLoad, cinst.NLoad) ): var_name = cinstr.source - hbm_address = self.__mem_model.useVariable( + hbm_address = self.__mem_model.use_variable( var_name, self._kernel_count ) self._validate_spad_address(var_name, hbm_address) @@ -325,7 +329,7 @@ def _update_cinsts_addresses_and_offsets(self, kernel_cinstrs: list): ) if isinstance(cinstr, cinst.CStore): var_name = cinstr.dest - hbm_address = self.__mem_model.useVariable( + hbm_address = self.__mem_model.use_variable( var_name, self._kernel_count ) self._validate_spad_address(var_name, hbm_address) diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_he_link.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_he_link.py index d857f66a..eff9a306 100644 --- a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_he_link.py +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_he_link.py @@ -379,9 +379,9 @@ def test_scan_variables(self, has_hbm): # Assert if has_hbm: - assert mock_mem_model.addVariable.call_count == 2 + assert mock_mem_model.add_variable.call_count == 2 else: - assert mock_mem_model.addVariable.call_count == 2 + assert mock_mem_model.add_variable.call_count == 2 def test_check_unused_variables(self): """ diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_init.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_init.py index 1570bcf4..910ceb34 100644 --- a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_init.py +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_init.py @@ -90,37 +90,37 @@ def test_buffer_property(self): self.assertIsNone(item) def test_force_allocate_valid(self): - """@brief Test forceAllocate with valid parameters. + """@brief Test force_allocate with valid parameters. @test Verifies that a variable is properly allocated at the specified address """ var_info = VariableInfo("test_var") - self.hbm.forceAllocate(var_info, 5) + self.hbm.force_allocate(var_info, 5) self.assertEqual(var_info.hbm_address, 5) self.assertEqual(self.hbm.buffer[5], var_info) def test_force_allocate_out_of_bounds(self): - """@brief Test forceAllocate with out of bounds address. + """@brief Test force_allocate with out of bounds address. @test Verifies that IndexError is raised for out-of-bounds addresses """ var_info = VariableInfo("test_var") with self.assertRaises(IndexError): - self.hbm.forceAllocate(var_info, -1) + self.hbm.force_allocate(var_info, -1) with self.assertRaises(IndexError): - self.hbm.forceAllocate(var_info, self.hbm_size) + self.hbm.force_allocate(var_info, self.hbm_size) def test_force_allocate_already_allocated(self): - """@brief Test forceAllocate with already allocated variable. + """@brief Test force_allocate with already allocated variable. @test Verifies that ValueError is raised when variable is already allocated """ var_info = VariableInfo("test_var", 3) with self.assertRaises(ValueError): - self.hbm.forceAllocate(var_info, 5) + self.hbm.force_allocate(var_info, 5) def test_force_allocate_address_occupied_with_hbm(self): - """@brief Test forceAllocate with address occupied and HBM enabled. + """@brief Test force_allocate with address occupied and HBM enabled. @test Verifies that RuntimeError is raised when address is occupied """ @@ -128,15 +128,15 @@ def test_force_allocate_address_occupied_with_hbm(self): # Occupy address 5 var_info1 = VariableInfo("var1") var_info1.uses = 1 - self.hbm.forceAllocate(var_info1, 5) + self.hbm.force_allocate(var_info1, 5) # Try to allocate another variable at the same address var_info2 = VariableInfo("var2") with self.assertRaises(RuntimeError): - self.hbm.forceAllocate(var_info2, 5) + self.hbm.force_allocate(var_info2, 5) def test_force_allocate_address_occupied_without_hbm(self): - """@brief Test forceAllocate with address occupied and HBM disabled. + """@brief Test force_allocate with address occupied and HBM disabled. @test Verifies that RuntimeError is raised when address is occupied """ @@ -144,15 +144,15 @@ def test_force_allocate_address_occupied_without_hbm(self): # Occupy address 5 var_info1 = VariableInfo("var1") var_info1.uses = 1 - self.hbm.forceAllocate(var_info1, 5) + self.hbm.force_allocate(var_info1, 5) # Try to allocate another variable at the same address var_info2 = VariableInfo("var2") with self.assertRaises(RuntimeError): - self.hbm.forceAllocate(var_info2, 5) + self.hbm.force_allocate(var_info2, 5) def test_force_allocate_address_recyclable_with_hbm(self): - """@brief Test forceAllocate with recyclable address and HBM enabled. + """@brief Test force_allocate with recyclable address and HBM enabled. @test Verifies that an address can be recycled when the variable is not used """ @@ -161,12 +161,12 @@ def test_force_allocate_address_recyclable_with_hbm(self): var_info1 = VariableInfo("var1") var_info1.uses = 0 var_info1.last_kernel_used = 1 - self.hbm.forceAllocate(var_info1, 5) + self.hbm.force_allocate(var_info1, 5) # Allocate another variable at the same address with higher kernel index var_info2 = VariableInfo("var2") var_info2.last_kernel_used = 2 - self.hbm.forceAllocate(var_info2, 5) + self.hbm.force_allocate(var_info2, 5) # Check that the new variable is at the address self.assertEqual(self.hbm.buffer[5], var_info2) @@ -191,7 +191,7 @@ def test_allocate_full_memory(self): for i in range(self.hbm_size): var_info = VariableInfo(f"var{i}") var_info.uses = 1 - self.hbm.forceAllocate(var_info, i) + self.hbm.force_allocate(var_info, i) # Try to allocate another variable var_info = VariableInfo("test_var") @@ -209,7 +209,7 @@ def test_allocate_with_recycling(self): var_info = VariableInfo(f"var{i}") var_info.uses = 1 if i != 3 else 0 var_info.last_kernel_used = 1 - self.hbm.forceAllocate(var_info, i) + self.hbm.force_allocate(var_info, i) # Allocate a new variable - should reuse address 3 var_info = VariableInfo("test_var") @@ -290,7 +290,7 @@ def test_add_variable_new(self): @test Verifies that a new variable is correctly added to the model """ - self.memory_model.addVariable("test_var") + self.memory_model.add_variable("test_var") # Check that variable was added self.assertIn("test_var", self.memory_model.variables) @@ -307,10 +307,10 @@ def test_add_variable_existing(self): @test Verifies that the uses count is incremented for an existing variable """ # Add the variable first - self.memory_model.addVariable("test_var") + self.memory_model.add_variable("test_var") # Add it again - self.memory_model.addVariable("test_var") + self.memory_model.add_variable("test_var") # Check that the uses were incremented var_info = self.memory_model.variables["test_var"] @@ -321,7 +321,7 @@ def test_add_variable_from_mem_info(self): @test Verifies that a variable from mem_info is correctly added with its HBM address """ - self.memory_model.addVariable("input_var") + self.memory_model.add_variable("input_var") # Check that variable was added self.assertIn("input_var", self.memory_model.variables) @@ -337,7 +337,7 @@ def test_add_variable_from_fixed_addr_vars(self): @test Verifies that a fixed-address variable is added with infinite uses """ - self.memory_model.addVariable("output_var") + self.memory_model.add_variable("output_var") # Check that variable was added self.assertIn("output_var", self.memory_model.variables) @@ -355,10 +355,10 @@ def test_use_variable(self): @test Verifies that using a variable decrements its uses count and allocates an HBM address """ # Add the variable first - self.memory_model.addVariable("test_var") + self.memory_model.add_variable("test_var") # Use the variable - hbm_address = self.memory_model.useVariable("test_var", 1) + hbm_address = self.memory_model.use_variable("test_var", 1) # Check that uses were decremented var_info = self.memory_model.variables["test_var"] @@ -380,10 +380,10 @@ def test_use_variable_already_allocated(self): @test Verifies that the existing HBM address is returned """ # Add a variable from mem_info which already has an HBM address - self.memory_model.addVariable("input_var") + self.memory_model.add_variable("input_var") # Use the variable - hbm_address = self.memory_model.useVariable("input_var", 1) + hbm_address = self.memory_model.use_variable("input_var", 1) # Check that the returned HBM address is the one from mem_info self.assertEqual(hbm_address, 1) diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_init.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_init.py index 08771961..a0e1f1f0 100644 --- a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_init.py +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_init.py @@ -118,9 +118,9 @@ def test_create_from_str_line_multiple_instruction_types(self, mock_tokenize): @patch("linker.instructions.tokenize_from_line") def test_create_from_str_line_exception_handling(self, mock_tokenize): """ - @brief Test that general exceptions are caught + @brief Test that specific exceptions are caught - @test Verifies that unexpected exceptions during instruction creation are + @test Verifies that expected exceptions during instruction creation are handled gracefully and None is returned """ # Setup mock @@ -128,8 +128,8 @@ def test_create_from_str_line_exception_handling(self, mock_tokenize): comment = "Test comment" mock_tokenize.return_value = (tokens, comment) - # Make instruction creation raise a different exception - self.mock_class.side_effect = Exception("Unexpected error") + # Make instruction creation raise one of the expected exception types + self.mock_class.side_effect = ValueError("Invalid values") # Call function - should handle the exception and return None result = create_from_str_line( diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_steps/test_program_linker.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_steps/test_program_linker.py index 96c0ccbe..b9802696 100644 --- a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_steps/test_program_linker.py +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_steps/test_program_linker.py @@ -200,7 +200,7 @@ def test_update_minsts(self): mock_mstore.comment = None # Set up memory model mock - self.mem_model.useVariable.side_effect = [ + self.mem_model.use_variable.side_effect = [ 10, 20, ] # Return different addresses for different vars @@ -222,7 +222,7 @@ def test_update_minsts(self): self.assertEqual(mock_mstore.dest, "20") # Replaced with HBM address # Verify the memory model was used correctly - self.mem_model.useVariable.assert_has_calls( + self.mem_model.use_variable.assert_has_calls( [call("input_var", 1), call("output_var", 1)] ) @@ -312,7 +312,7 @@ def test_update_cinsts_addresses_and_offsets(self): # Test with HBM disabled with patch.object(GlobalConfig, "hasHBM", False): # Set up memory model mock - self.mem_model.useVariable.side_effect = [ + self.mem_model.use_variable.side_effect = [ 30, 40, ] # Return different addresses for different vars @@ -329,7 +329,7 @@ def test_update_cinsts_addresses_and_offsets(self): self.assertEqual(mock_cstore.dest, "40") # Verify the memory model was used correctly - self.mem_model.useVariable.assert_has_calls( + self.mem_model.use_variable.assert_has_calls( [call("var1", 2), call("var2", 2)] ) From 214c774e0594d9391811cca7d5ca578ef74ff0a2 Mon Sep 17 00:00:00 2001 From: "Rojas Chaves, Jose" Date: Tue, 8 Jul 2025 20:57:34 +0000 Subject: [PATCH 11/12] Fixes for manual integration tests --- assembler_tools/hec-assembler-tools/he_link.py | 2 +- .../linker/instructions/cinst/cinstruction.py | 9 +++++++++ .../linker/instructions/instruction.py | 2 +- .../linker/instructions/minst/minstruction.py | 10 ++++++++++ .../hec-assembler-tools/linker/steps/program_linker.py | 4 ++-- 5 files changed, 23 insertions(+), 4 deletions(-) diff --git a/assembler_tools/hec-assembler-tools/he_link.py b/assembler_tools/hec-assembler-tools/he_link.py index f714529e..b674494e 100644 --- a/assembler_tools/hec-assembler-tools/he_link.py +++ b/assembler_tools/hec-assembler-tools/he_link.py @@ -347,7 +347,7 @@ def main(run_config: LinkerRunConfig, verbose_stream=NullIO()): and links each kernel, writing the output to specified files. @param run_config The configuration object containing run parameters. - @param verbose_stream The stream to which verbose output is printed. Defaults to None. + @param verbose_stream The stream to which verbose output is printed. Defaults to NullIO. @return None """ diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/cinst/cinstruction.py b/assembler_tools/hec-assembler-tools/linker/instructions/cinst/cinstruction.py index a6157425..81df3417 100644 --- a/assembler_tools/hec-assembler-tools/linker/instructions/cinst/cinstruction.py +++ b/assembler_tools/hec-assembler-tools/linker/instructions/cinst/cinstruction.py @@ -58,3 +58,12 @@ def __init__(self, tokens: list, comment: str = ""): ValueError: If the number of tokens is invalid or the instruction name is incorrect. """ super().__init__(tokens, comment=comment) + + def to_line(self) -> str: + """ + Retrieves the string form of the instruction to write to the instruction file. + + Returns: + str: The string representation of the instruction, excluding the first token. + """ + return ", ".join(self.tokens[1:]) diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/instruction.py b/assembler_tools/hec-assembler-tools/linker/instructions/instruction.py index 76d9a22e..7a0ff43a 100644 --- a/assembler_tools/hec-assembler-tools/linker/instructions/instruction.py +++ b/assembler_tools/hec-assembler-tools/linker/instructions/instruction.py @@ -196,5 +196,5 @@ def to_line(self) -> str: if not GlobalConfig.suppress_comments: comment_str = f" # {self.comment}" if self.comment else "" - tokens_str = ", ".join(self._tokens) + tokens_str = ", ".join(self.tokens) return f"{tokens_str}{comment_str}" diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/minst/minstruction.py b/assembler_tools/hec-assembler-tools/linker/instructions/minst/minstruction.py index e78ce2cd..bb00f61f 100644 --- a/assembler_tools/hec-assembler-tools/linker/instructions/minst/minstruction.py +++ b/assembler_tools/hec-assembler-tools/linker/instructions/minst/minstruction.py @@ -1,6 +1,7 @@ # Copyright (C) 2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 +# pylint: disable=duplicate-code """This module implements the base class for MInstructions.""" from linker.instructions.instruction import BaseInstruction @@ -58,3 +59,12 @@ def __init__(self, tokens: list, comment: str = ""): ValueError: If the number of tokens is invalid or the instruction name is incorrect. """ super().__init__(tokens, comment=comment) + + def to_line(self) -> str: + """ + Retrieves the string form of the instruction to write to the instruction file. + + Returns: + str: The string representation of the instruction, excluding the first token. + """ + return ", ".join(self.tokens[1:]) diff --git a/assembler_tools/hec-assembler-tools/linker/steps/program_linker.py b/assembler_tools/hec-assembler-tools/linker/steps/program_linker.py index 738cd125..4096bd9e 100644 --- a/assembler_tools/hec-assembler-tools/linker/steps/program_linker.py +++ b/assembler_tools/hec-assembler-tools/linker/steps/program_linker.py @@ -213,7 +213,7 @@ def _remove_and_merge_csyncm_cnop(self, kernel_cinstrs: list): csyncm_count = 0 while i < len(kernel_cinstrs): cinstr = kernel_cinstrs[i] - cinstr.tokens[0] = i # Update the line number + cinstr.tokens[0] = str(i) # Update the line number # ------------------------------ # This code block will remove csyncm instructions and keep track, @@ -276,7 +276,7 @@ def _remove_and_merge_csyncm_cnop(self, kernel_cinstrs: list): i = 0 while i < len(kernel_cinstrs): cinstr = kernel_cinstrs[i] - cinstr.tokens[0] = i + cinstr.tokens[0] = str(i) if isinstance(cinstr, cinst.CNop): # Do look ahead if i + 1 < len(kernel_cinstrs): From 22c95a1afd9223e602c390d07035f1bd9534afac Mon Sep 17 00:00:00 2001 From: "Rojas Chaves, Jose" Date: Tue, 8 Jul 2025 21:02:12 +0000 Subject: [PATCH 12/12] Updating test --- .../unit_tests/test_linker/test_steps/test_program_linker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_steps/test_program_linker.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_steps/test_program_linker.py index b9802696..aa89d860 100644 --- a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_steps/test_program_linker.py +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_steps/test_program_linker.py @@ -274,7 +274,7 @@ def test_remove_and_merge_csyncm_cnop(self): # Verify the line numbers were updated for i, instr in enumerate(kernel_cinstrs): - self.assertEqual(instr.tokens[0], i) + self.assertEqual(instr.tokens[0], str(i)) def test_update_cinsts_addresses_and_offsets(self): """@brief Test updating CInst addresses and offsets.