diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ffdc582f..aa5b6563 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -61,8 +61,6 @@ repos: *assembler_tools/hec-assembler-tools/debug_tools/main\.py| *assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/| *assembler_tools/hec-assembler-tools/he_as\.py| - *assembler_tools/hec-assembler-tools/linker/__init__\.py| - *assembler_tools/hec-assembler-tools/linker/instructions/__init__\.py| *assembler_tools/hec-assembler-tools/assembler/spec_config/isa_spec.py) args: ["--follow-imports=skip", "--install-types", "--non-interactive"] - repo: local @@ -81,8 +79,6 @@ repos: *assembler_tools/hec-assembler-tools/debug_tools/main\.py| *assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/| *assembler_tools/hec-assembler-tools/he_as\.py| - *assembler_tools/hec-assembler-tools/linker/__init__\.py| - *assembler_tools/hec-assembler-tools/linker/instructions/__init__\.py| *assembler_tools/hec-assembler-tools/assembler/spec_config/isa_spec.py) args: - -rn # Only display messages diff --git a/assembler_tools/hec-assembler-tools/assembler/common/config.py b/assembler_tools/hec-assembler-tools/assembler/common/config.py index 4e31682e..49f2578a 100644 --- a/assembler_tools/hec-assembler-tools/assembler/common/config.py +++ b/assembler_tools/hec-assembler-tools/assembler/common/config.py @@ -1,29 +1,34 @@ - +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""A configuration class for controlling various aspects of the assembler's behavior.""" + + class GlobalConfig: """ A configuration class for controlling various aspects of the assembler's behavior. Attributes: - suppressComments (bool): + suppress_comments (bool): If True, no comments will be emitted in the output generated by the assembler. - - useHBMPlaceHolders (bool): - Specifies whether to use placeholders (names) for variable locations in HBM + + useHBMPlaceHolders (bool): + Specifies whether to use placeholders (names) for variable locations in HBM or the actual variable locations. - - useXInstFetch (bool): - Specifies whether `xinstfetch` instructions should be added into CInstQ or not. - When no `xinstfetch` instructions are added, it is assumed that the HERACLES + + useXInstFetch (bool): + Specifies whether `xinstfetch` instructions should be added into CInstQ or not. + When no `xinstfetch` instructions are added, it is assumed that the HERACLES automated mechanism for `xinstfetch` will be activated. - - debugVerbose (int): - If greater than 0, verbose prints will occur. Its value indicates how often to - print within loops (every `debugVerbose` iterations). This is used for internal + + debugVerbose (int): + If greater than 0, verbose prints will occur. Its value indicates how often to + print within loops (every `debugVerbose` iterations). This is used for internal debugging purposes. hashHBM (bool): Specifies whether the target architecture has HBM or not. """ - suppressComments = False + suppress_comments = False useHBMPlaceHolders = True useXInstFetch = True debugVerbose: int = 0 diff --git a/assembler_tools/hec-assembler-tools/assembler/common/decorators.py b/assembler_tools/hec-assembler-tools/assembler/common/decorators.py index 09beaa98..1b4c5aea 100644 --- a/assembler_tools/hec-assembler-tools/assembler/common/decorators.py +++ b/assembler_tools/hec-assembler-tools/assembler/common/decorators.py @@ -1,28 +1,30 @@ - -class classproperty(object): +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +""" +This module provides decorator utilities for the assembler. + +It contains decorators that enhance class and function behavior, +including property-like decorators for class methods. +""" + + +# pylint: disable=invalid-name +class classproperty(property): """ A decorator that allows a method to be accessed as a class-level property rather than on instances of the class. """ - def __init__(self, f): - """ - Initializes the classproperty with the given function. - - Args: - f (function): The function to be used as a class-level property. - """ - self.f = f - - def __get__(self, obj, owner): + def __get__(self, cls, owner): """ Retrieves the value of the class-level property. Args: - obj: The instance of the class (ignored in this context). - owner: The class that owns the property. + cls: The class that owns the property. + owner: The owner of the class (ignored in this context). Returns: The result of calling the decorated function with the class as an argument. """ - return self.f(owner) + return self.fget(owner) diff --git a/assembler_tools/hec-assembler-tools/assembler/common/run_config.py b/assembler_tools/hec-assembler-tools/assembler/common/run_config.py index 9737a962..e5ccaef2 100644 --- a/assembler_tools/hec-assembler-tools/assembler/common/run_config.py +++ b/assembler_tools/hec-assembler-tools/assembler/common/run_config.py @@ -1,9 +1,13 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + import io from .decorators import * from . import constants from .config import GlobalConfig + class RunConfig: """ Configuration class for running the assembler with specific settings. @@ -12,11 +16,21 @@ class RunConfig: policies, and other options that affect the behavior of the assembler. """ - __initialized = False # Specifies whether static members have been initialized - __default_config = {} # Dictionary of all configuration items supported and their default values - - def __init__(self, - **kwargs): + # Type annotations for class attributes + has_hbm: bool + hbm_size: int + spad_size: int + repl_policy: str + suppress_comments: bool + use_xinstfetch: bool + debug_verbose: int + + __initialized = False # Specifies whether static members have been initialized + __default_config = ( + {} + ) # Dictionary of all configuration items supported and their default values + + def __init__(self, **kwargs): """ Constructs a new RunConfig object from input parameters. @@ -33,7 +47,7 @@ def __init__(self, suppress_comments (bool, optional): If true, no comments will be emitted in the output generated by the assembler. - Defaults to GlobalConfig.suppressComments (`False`). + Defaults to GlobalConfig.suppress_comments (`False`). use_hbm_placeholders (bool, optional): [DEPRECATED]/[UNUSED] Specifies whether to use placeholders (names) for variable locations in HBM (`True`) @@ -52,28 +66,42 @@ def __init__(self, ValueError: If at least one of the arguments passed is invalid. """ + RunConfig.init_default_config() + # Initialize class members for config_name, default_value in self.__default_config.items(): - setattr(self, config_name, kwargs.get(config_name, default_value)) + value = kwargs.get(config_name) + if value is not None: + setattr(self, config_name, value) + else: + setattr(self, config_name, default_value) # Validate inputs if self.repl_policy not in constants.Constants.REPLACEMENT_POLICIES: - raise ValueError('Invalid `repl_policy`. "{}" not in {}'.format(self.repl_policy, - constants.Constants.REPLACEMENT_POLICIES)) + raise ValueError( + 'Invalid `repl_policy`. "{}" not in {}'.format( + self.repl_policy, constants.Constants.REPLACEMENT_POLICIES + ) + ) + @classproperty def DEFAULT_HBM_SIZE_KB(cls) -> int: - return int(constants.MemoryModel.HBM.MAX_CAPACITY / constants.Constants.KILOBYTE) + return int( + constants.MemoryModel.HBM.MAX_CAPACITY / constants.Constants.KILOBYTE + ) @classproperty def DEFAULT_SPAD_SIZE_KB(cls) -> int: - return int(constants.MemoryModel.SPAD.MAX_CAPACITY / constants.Constants.KILOBYTE) + return int( + constants.MemoryModel.SPAD.MAX_CAPACITY / constants.Constants.KILOBYTE + ) @classproperty def DEFAULT_REPL_POLICY(cls) -> int: return constants.Constants.REPLACEMENT_POLICY_FTBU @classmethod - def init_static(cls): + def init_default_config(cls): """ Initializes static members of the RunConfig class. @@ -81,16 +109,23 @@ def init_static(cls): that they are only initialized once. """ if not cls.__initialized: - cls.__default_config["hbm_size"] = cls.DEFAULT_HBM_SIZE_KB - cls.__default_config["spad_size"] = cls.DEFAULT_SPAD_SIZE_KB - cls.__default_config["repl_policy"] = cls.DEFAULT_REPL_POLICY - cls.__default_config["suppress_comments"] = GlobalConfig.suppressComments - #cls.__default_config["use_hbm_placeholders"] = GlobalConfig.useHBMPlaceHolders - cls.__default_config["use_xinstfetch"] = GlobalConfig.useXInstFetch - cls.__default_config["debug_verbose"] = GlobalConfig.debugVerbose + cls.__default_config["has_hbm"] = True + cls.__default_config["hbm_size"] = cls.DEFAULT_HBM_SIZE_KB + cls.__default_config["spad_size"] = cls.DEFAULT_SPAD_SIZE_KB + cls.__default_config["repl_policy"] = cls.DEFAULT_REPL_POLICY + cls.__default_config["suppress_comments"] = GlobalConfig.suppress_comments + # cls.__default_config["use_hbm_placeholders"] = GlobalConfig.useHBMPlaceHolders + cls.__default_config["use_xinstfetch"] = GlobalConfig.useXInstFetch + cls.__default_config["debug_verbose"] = GlobalConfig.debugVerbose cls.__initialized = True + # For testing only + @classmethod + def reset_class_state(cls): + cls.__initialized = False + cls.__default_config = {} + def __str__(self): """ Returns a string representation of the configuration. @@ -113,4 +148,7 @@ def as_dict(self) -> dict: dict: A dictionary representation of the current configuration settings. """ tmp_self_dict = vars(self) - return { config_name: tmp_self_dict[config_name] for config_name in self.__default_config } + return { + config_name: tmp_self_dict[config_name] + for config_name in self.__default_config + } diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/instruction.py b/assembler_tools/hec-assembler-tools/assembler/instructions/instruction.py index ecbf9e9b..0418b49e 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/instruction.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/instruction.py @@ -552,7 +552,7 @@ def to_string_format(self, preamble, op_name: str, *extra_args) -> str: retval = f'{", ".join(str(x) for x in preamble)}, {retval}' if extra_args: retval += f', {", ".join([str(extra) for extra in extra_args])}' - if not GlobalConfig.suppressComments: + if not GlobalConfig.suppress_comments: if self.comment: retval += f" #{self.comment}" return retval diff --git a/assembler_tools/hec-assembler-tools/assembler/memory_model/mem_info.py b/assembler_tools/hec-assembler-tools/assembler/memory_model/mem_info.py index 35c5bb19..8d034d59 100644 --- a/assembler_tools/hec-assembler-tools/assembler/memory_model/mem_info.py +++ b/assembler_tools/hec-assembler-tools/assembler/memory_model/mem_info.py @@ -1,8 +1,13 @@ # Copyright (C) 2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 +# These contents may have been developed with support from one +# or more Intel-operated generative artificial intelligence solutions + from assembler.common import constants from assembler.instructions import tokenize_from_line +from typing import Optional +from assembler.common.decorators import * from assembler.memory_model.variable import Variable from . import MemoryModel @@ -113,7 +118,7 @@ class Metadata: class Ones: @classmethod - def parseFromMemLine(cls, tokens: list) -> MemInfoVariable: + def parse_from_mem_tokens(cls, tokens: list) -> MemInfoVariable: """ Parses a ones metadata variable from a tokenized line. @@ -123,7 +128,7 @@ def parseFromMemLine(cls, tokens: list) -> MemInfoVariable: Returns: MemInfoVariable: The parsed ones metadata variable. """ - return MemInfo.Metadata.parseMetaFieldFromMemLine( + return MemInfo.Metadata.parse_meta_field_from_mem_tokens( tokens, MemInfo.Const.Keyword.LOAD_ONES, var_prefix=MemInfo.Const.Keyword.LOAD_ONES, @@ -131,7 +136,7 @@ def parseFromMemLine(cls, tokens: list) -> MemInfoVariable: class NTTAuxTable: @classmethod - def parseFromMemLine(cls, tokens: list) -> MemInfoVariable: + def parse_from_mem_tokens(cls, tokens: list) -> MemInfoVariable: """ Parses an NTT auxiliary table metadata variable from a tokenized line. @@ -141,7 +146,7 @@ def parseFromMemLine(cls, tokens: list) -> MemInfoVariable: Returns: MemInfoVariable: The parsed NTT auxiliary table metadata variable. """ - return MemInfo.Metadata.parseMetaFieldFromMemLine( + return MemInfo.Metadata.parse_meta_field_from_mem_tokens( tokens, MemInfo.Const.Keyword.LOAD_NTT_AUX_TABLE, var_prefix=MemInfo.Const.Keyword.LOAD_NTT_AUX_TABLE, @@ -149,7 +154,7 @@ def parseFromMemLine(cls, tokens: list) -> MemInfoVariable: class NTTRoutingTable: @classmethod - def parseFromMemLine(cls, tokens: list) -> MemInfoVariable: + def parse_from_mem_tokens(cls, tokens: list) -> MemInfoVariable: """ Parses an NTT routing table metadata variable from a tokenized line. @@ -159,7 +164,7 @@ def parseFromMemLine(cls, tokens: list) -> MemInfoVariable: Returns: MemInfoVariable: The parsed NTT routing table metadata variable. """ - return MemInfo.Metadata.parseMetaFieldFromMemLine( + return MemInfo.Metadata.parse_meta_field_from_mem_tokens( tokens, MemInfo.Const.Keyword.LOAD_NTT_ROUTING_TABLE, var_prefix=MemInfo.Const.Keyword.LOAD_NTT_ROUTING_TABLE, @@ -167,7 +172,7 @@ def parseFromMemLine(cls, tokens: list) -> MemInfoVariable: class iNTTAuxTable: @classmethod - def parseFromMemLine(cls, tokens: list) -> MemInfoVariable: + def parse_from_mem_tokens(cls, tokens: list) -> MemInfoVariable: """ Parses an iNTT auxiliary table metadata variable from a tokenized line. @@ -177,7 +182,7 @@ def parseFromMemLine(cls, tokens: list) -> MemInfoVariable: Returns: MemInfoVariable: The parsed iNTT auxiliary table metadata variable. """ - return MemInfo.Metadata.parseMetaFieldFromMemLine( + return MemInfo.Metadata.parse_meta_field_from_mem_tokens( tokens, MemInfo.Const.Keyword.LOAD_iNTT_AUX_TABLE, var_prefix=MemInfo.Const.Keyword.LOAD_iNTT_AUX_TABLE, @@ -185,7 +190,7 @@ def parseFromMemLine(cls, tokens: list) -> MemInfoVariable: class iNTTRoutingTable: @classmethod - def parseFromMemLine(cls, tokens: list) -> MemInfoVariable: + def parse_from_mem_tokens(cls, tokens: list) -> MemInfoVariable: """ Parses an iNTT routing table metadata variable from a tokenized line. @@ -195,7 +200,7 @@ def parseFromMemLine(cls, tokens: list) -> MemInfoVariable: Returns: MemInfoVariable: The parsed iNTT routing table metadata variable. """ - return MemInfo.Metadata.parseMetaFieldFromMemLine( + return MemInfo.Metadata.parse_meta_field_from_mem_tokens( tokens, MemInfo.Const.Keyword.LOAD_iNTT_ROUTING_TABLE, var_prefix=MemInfo.Const.Keyword.LOAD_iNTT_ROUTING_TABLE, @@ -203,7 +208,7 @@ def parseFromMemLine(cls, tokens: list) -> MemInfoVariable: class Twiddle: @classmethod - def parseFromMemLine(cls, tokens: list) -> MemInfoVariable: + def parse_from_mem_tokens(cls, tokens: list) -> MemInfoVariable: """ Parses a twiddle metadata variable from a tokenized line. @@ -213,7 +218,7 @@ def parseFromMemLine(cls, tokens: list) -> MemInfoVariable: Returns: MemInfoVariable: The parsed twiddle metadata variable. """ - return MemInfo.Metadata.parseMetaFieldFromMemLine( + return MemInfo.Metadata.parse_meta_field_from_mem_tokens( tokens, MemInfo.Const.Keyword.LOAD_TWIDDLE, var_prefix=MemInfo.Const.Keyword.LOAD_TWIDDLE, @@ -221,7 +226,7 @@ def parseFromMemLine(cls, tokens: list) -> MemInfoVariable: class KeygenSeed: @classmethod - def parseFromMemLine(cls, tokens: list) -> MemInfoVariable: + def parse_from_mem_tokens(cls, tokens: list) -> MemInfoVariable: """ Parses a keygen seed metadata variable from a tokenized line. @@ -231,14 +236,14 @@ def parseFromMemLine(cls, tokens: list) -> MemInfoVariable: Returns: MemInfoVariable: The parsed keygen seed metadata variable. """ - return MemInfo.Metadata.parseMetaFieldFromMemLine( + return MemInfo.Metadata.parse_meta_field_from_mem_tokens( tokens, MemInfo.Const.Keyword.LOAD_KEYGEN_SEED, var_prefix=MemInfo.Const.Keyword.LOAD_KEYGEN_SEED, ) @classmethod - def parseMetaFieldFromMemLine( + def parse_meta_field_from_mem_tokens( cls, tokens: list, meta_field_name: str, @@ -373,7 +378,7 @@ def keygen_seeds(self) -> list: class Keygen: @classmethod - def parseFromMemLine(cls, tokens: list) -> MemInfoVariable: + def parse_from_mem_tokens(cls, tokens: list) -> MemInfoVariable: """ Parses a keygen variable from a tokenized line. @@ -396,7 +401,7 @@ def parseFromMemLine(cls, tokens: list) -> MemInfoVariable: class Input: @classmethod - def parseFromMemLine(cls, tokens: list) -> MemInfoVariable: + def parse_from_mem_tokens(cls, tokens: list) -> MemInfoVariable: """ Parses an input variable from a tokenized line. @@ -422,7 +427,7 @@ def parseFromMemLine(cls, tokens: list) -> MemInfoVariable: class Output: @classmethod - def parseFromMemLine(cls, tokens: list) -> MemInfoVariable: + def parse_from_mem_tokens(cls, tokens: list) -> MemInfoVariable: """ Parses an output variable from a tokenized line. @@ -448,28 +453,97 @@ def __init__(self, **kwargs): Initializes a new MemInfo object. Clients may call this method without parameters for default initialization. - Clients should use MemInfo.from_iter() constructor to parse the contents of a .mem file. + Clients should use MemInfo.from_file_iter() constructor to parse the contents of a .mem file. Args: kwargs (dict): A dictionary as generated by the method MemInfo.as_dict(). This is provided as a shortcut to creating a MemInfo object from structured data such as the contents of a YAML file. """ - self.__keygens = [ - MemInfoVariable(**d) for d in kwargs.get(MemInfo.Const.FIELD_KEYGENS, []) + self._keygens = [ + MemInfoKeygenVariable(**d) + for d in kwargs.get(MemInfo.Const.FIELD_KEYGENS, []) ] - self.__inputs = [ + self._inputs = [ MemInfoVariable(**d) for d in kwargs.get(MemInfo.Const.FIELD_INPUTS, []) ] - self.__outputs = [ + self._outputs = [ MemInfoVariable(**d) for d in kwargs.get(MemInfo.Const.FIELD_OUTPUTS, []) ] - self.__metadata = MemInfo.Metadata( + self._metadata = MemInfo.Metadata( **kwargs.get(MemInfo.Const.FIELD_METADATA, {}) ) self.validate() + @property + def factory_dict(self): + """ + Returns a dictionary mapping MemInfo types to their respective factory methods. + + This is used to create instances of MemInfoVariable based on the type of memory information. + + Returns: + dict: A dictionary mapping MemInfo types to their factory methods. + """ + return { + MemInfo.Keygen: self.keygens, + MemInfo.Input: self.inputs, + MemInfo.Output: self.outputs, + MemInfo.Metadata.KeygenSeed: self.metadata.keygen_seeds, + MemInfo.Metadata.Ones: self.metadata.ones, + MemInfo.Metadata.NTTAuxTable: self.metadata.ntt_auxiliary_table, + MemInfo.Metadata.NTTRoutingTable: self.metadata.ntt_routing_table, + MemInfo.Metadata.iNTTAuxTable: self.metadata.intt_auxiliary_table, + MemInfo.Metadata.iNTTRoutingTable: self.metadata.intt_routing_table, + MemInfo.Metadata.Twiddle: self.metadata.twiddle, + } + + @classproperty + def mem_info_types(cls): + """Retrieves the list of MemInfo variable types.""" + dummy = cls() + return dummy.factory_dict.keys() + @classmethod - def from_iter(cls, line_iter): + def get_meminfo_var_from_tokens( + cls, tokens + ) -> tuple[Optional[MemInfoVariable], Optional[type]]: + """ + Parses a MemInfo variable from a list of tokens. + + Args: + tokens (list[str]): List of tokens to parse. + + Returns: + tuple[MemInfoVariable,type]: The parsed MemInfo variable and its type, or None if no variable could be parsed. + """ + miv: Optional[MemInfoVariable] = None + mem_info_type: Optional[type] = None + for mem_info_type in cls.mem_info_types: + miv: MemInfoVariable = mem_info_type.parse_from_mem_tokens(tokens) + if miv is not None: + break + + return miv, mem_info_type + + def add_meminfo_var_from_tokens(self, tokens): + """ + Parses a MemInfo variable from a list of tokens and adds it to the appropriate list. + + Args: + tokens (list[str]): List of tokens to parse. + + Raises: + RuntimeError: If the line could not be parsed. + """ + miv: MemInfoVariable = None + miv, mem_info_type = MemInfo.get_meminfo_var_from_tokens(tokens) + if miv is not None and mem_info_type is not None: + self.factory_dict[mem_info_type].append(miv) + else: + raise RuntimeError(f"Could not parse line") + + @classmethod + def from_file_iter(cls, line_iter): """ Creates a new MemInfo object from an iterator of strings, where each string is a line of text to parse. @@ -487,34 +561,42 @@ def from_iter(cls, line_iter): retval = cls() - factory_dict = { - MemInfo.Keygen: retval.keygens, - MemInfo.Input: retval.inputs, - MemInfo.Output: retval.outputs, - MemInfo.Metadata.KeygenSeed: retval.metadata.keygen_seeds, - MemInfo.Metadata.Ones: retval.metadata.ones, - MemInfo.Metadata.NTTAuxTable: retval.metadata.ntt_auxiliary_table, - MemInfo.Metadata.NTTRoutingTable: retval.metadata.ntt_routing_table, - MemInfo.Metadata.iNTTAuxTable: retval.metadata.intt_auxiliary_table, - MemInfo.Metadata.iNTTRoutingTable: retval.metadata.intt_routing_table, - MemInfo.Metadata.Twiddle: retval.metadata.twiddle, - } for line_no, s_line in enumerate(line_iter, 1): s_line = s_line.strip() if s_line: # skip empty lines tokens, _ = tokenize_from_line(s_line) if tokens and len(tokens) > 0: - b_parsed = False - for mem_info_type in factory_dict: - miv: MemInfoVariable = mem_info_type.parseFromMemLine(tokens) - if miv is not None: - factory_dict[mem_info_type].append(miv) - b_parsed = True - break # next line - if not b_parsed: - raise RuntimeError( - f'Could not parse line {line_no}: "{s_line}"' - ) + try: + retval.add_meminfo_var_from_tokens(tokens) + except RuntimeError as e: + raise RuntimeError(f"{e} {line_no}: {s_line}") from e + retval.validate() + return retval + + @classmethod + def from_dinstrs(cls, dinstrs): + """ + Creates a new MemInfo object from an list of DInstructions. + + Args: + dinstrs (list[DInstruction]): List of DInstructions. + + Raises: + RuntimeError: If there is an error parsing the instruction tokens. + + Returns: + MemInfo: The constructed MemInfo object. + """ + + retval = cls() + for ints_no, dinstr in enumerate(dinstrs, 1): + tokens = dinstr.tokens + if tokens and len(tokens) > 0: + try: + retval.add_meminfo_var_from_tokens(tokens) + except RuntimeError as e: + raise RuntimeError(f"{e} {ints_no}: {tokens}") from e + retval.validate() return retval @@ -526,7 +608,7 @@ def keygens(self) -> list: Returns: list: Keygen variables. """ - return self.__keygens + return self._keygens @property def inputs(self) -> list: @@ -536,7 +618,7 @@ def inputs(self) -> list: Returns: list: Input variables. """ - return self.__inputs + return self._inputs @property def outputs(self) -> list: @@ -546,7 +628,7 @@ def outputs(self) -> list: Returns: list: Output variables. """ - return self.__outputs + return self._outputs @property def metadata(self) -> Metadata: @@ -556,7 +638,7 @@ def metadata(self) -> Metadata: Returns: Metadata: MemInfo's metadata. """ - return self.__metadata + return self._metadata def as_dict(self): """ @@ -624,7 +706,7 @@ def validate(self): ) -def __allocateMemInfoVariable(mem_model: MemoryModel, v_info: MemInfoVariable): +def _allocateMemInfoVariable(mem_model: MemoryModel, v_info: MemInfoVariable): """ Allocates a memory information variable in the memory model. @@ -680,11 +762,11 @@ def updateMemoryModelWithMemInfo(mem_model: MemoryModel, mem_info: MemInfo): # Inputs for v_info in mem_info.inputs: - __allocateMemInfoVariable(mem_model, v_info) + _allocateMemInfoVariable(mem_model, v_info) # Outputs for v_info in mem_info.outputs: - __allocateMemInfoVariable(mem_model, v_info) + _allocateMemInfoVariable(mem_model, v_info) mem_model.output_variables.push(v_info.var_name, None) # Metadata @@ -692,7 +774,7 @@ def updateMemoryModelWithMemInfo(mem_model: MemoryModel, mem_info: MemInfo): # Ones for v_info in mem_info.metadata.ones: mem_model.retrieveVarAdd(v_info.var_name) - __allocateMemInfoVariable(mem_model, v_info) + _allocateMemInfoVariable(mem_model, v_info) mem_model.add_meta_ones_var(v_info.var_name) # Shuffle meta vars @@ -700,40 +782,40 @@ def updateMemoryModelWithMemInfo(mem_model: MemoryModel, mem_info: MemInfo): assert len(mem_info.metadata.ntt_auxiliary_table) == 1 v_info = mem_info.metadata.ntt_auxiliary_table[0] mem_model.retrieveVarAdd(v_info.var_name) - __allocateMemInfoVariable(mem_model, v_info) + _allocateMemInfoVariable(mem_model, v_info) mem_model.meta_ntt_aux_table = v_info.var_name if mem_info.metadata.ntt_routing_table: assert len(mem_info.metadata.ntt_routing_table) == 1 v_info = mem_info.metadata.ntt_routing_table[0] mem_model.retrieveVarAdd(v_info.var_name) - __allocateMemInfoVariable(mem_model, v_info) + _allocateMemInfoVariable(mem_model, v_info) mem_model.meta_ntt_routing_table = v_info.var_name if mem_info.metadata.intt_auxiliary_table: assert len(mem_info.metadata.intt_auxiliary_table) == 1 v_info = mem_info.metadata.intt_auxiliary_table[0] mem_model.retrieveVarAdd(v_info.var_name) - __allocateMemInfoVariable(mem_model, v_info) + _allocateMemInfoVariable(mem_model, v_info) mem_model.meta_intt_aux_table = v_info.var_name if mem_info.metadata.intt_routing_table: assert len(mem_info.metadata.intt_routing_table) == 1 v_info = mem_info.metadata.intt_routing_table[0] mem_model.retrieveVarAdd(v_info.var_name) - __allocateMemInfoVariable(mem_model, v_info) + _allocateMemInfoVariable(mem_model, v_info) mem_model.meta_intt_routing_table = v_info.var_name # Twiddle for v_info in mem_info.metadata.twiddle: mem_model.retrieveVarAdd(v_info.var_name) - __allocateMemInfoVariable(mem_model, v_info) + _allocateMemInfoVariable(mem_model, v_info) mem_model.add_meta_twiddle_var(v_info.var_name) # Keygen seeds for v_info in mem_info.metadata.keygen_seeds: mem_model.retrieveVarAdd(v_info.var_name) - __allocateMemInfoVariable(mem_model, v_info) + _allocateMemInfoVariable(mem_model, v_info) mem_model.add_meta_keygen_seed_var(v_info.var_name) # End metadata diff --git a/assembler_tools/hec-assembler-tools/debug_tools/main.py b/assembler_tools/hec-assembler-tools/debug_tools/main.py index a46c0f3a..31573508 100644 --- a/assembler_tools/hec-assembler-tools/debug_tools/main.py +++ b/assembler_tools/hec-assembler-tools/debug_tools/main.py @@ -83,7 +83,7 @@ def main_readmem(args): mem_meta_info = None with open(mem_filename, "r") as mem_ifnum: - mem_meta_info = mem_info.MemInfo.from_iter(mem_ifnum) + mem_meta_info = mem_info.MemInfo.from_file_iter(mem_ifnum) if mem_meta_info: with io.StringIO() as retval_f: @@ -217,7 +217,7 @@ def asmisa_assembly( if b_verbose: print("Interpreting variable meta information...") with open(mem_filename, "r") as mem_ifnum: - mem_meta_info = mem_info.MemInfo.from_iter(mem_ifnum) + mem_meta_info = mem_info.MemInfo.from_file_iter(mem_ifnum) mem_info.updateMemoryModelWithMemInfo(hec_mem_model, mem_meta_info) if b_verbose: @@ -289,7 +289,7 @@ def main_asmisa(args): b_use_old_mem_file = False b_verbose = True if args.verbose > 0 else False GlobalConfig.debugVerbose = 0 - GlobalConfig.suppressComments = False + GlobalConfig.suppress_comments = False GlobalConfig.useHBMPlaceHolders = True GlobalConfig.useXInstFetch = False diff --git a/assembler_tools/hec-assembler-tools/he_as.py b/assembler_tools/hec-assembler-tools/he_as.py index 8c3e07b4..0b666c39 100644 --- a/assembler_tools/hec-assembler-tools/he_as.py +++ b/assembler_tools/hec-assembler-tools/he_as.py @@ -146,7 +146,7 @@ def init_default_config(cls): cls.__default_config["spad_size"] = cls.DEFAULT_SPAD_SIZE_KB cls.__default_config["repl_policy"] = cls.DEFAULT_REPL_POLICY cls.__default_config["use_xinstfetch"] = GlobalConfig.useXInstFetch - cls.__default_config["suppress_comments"] = GlobalConfig.suppressComments + cls.__default_config["suppress_comments"] = GlobalConfig.suppress_comments cls.__default_config["debug_verbose"] = GlobalConfig.debugVerbose cls.__initialized = True @@ -251,7 +251,7 @@ def asmisaAssemble( if b_verbose: print("Interpreting variable meta information...") with open(mem_filename, "r") as mem_ifnum: - mem_meta_info = mem_info.MemInfo.from_iter(mem_ifnum) + mem_meta_info = mem_info.MemInfo.from_file_iter(mem_ifnum) mem_info.updateMemoryModelWithMemInfo(hec_mem_model, mem_meta_info) if b_verbose: @@ -358,7 +358,7 @@ def main(config: AssemblerRunConfig, verbose: bool = False): GlobalConfig.useHBMPlaceHolders = True # config.use_hbm_placeholders GlobalConfig.useXInstFetch = config.use_xinstfetch - GlobalConfig.suppressComments = config.suppress_comments + GlobalConfig.suppress_comments = config.suppress_comments GlobalConfig.hasHBM = config.has_hbm GlobalConfig.debugVerbose = config.debug_verbose diff --git a/assembler_tools/hec-assembler-tools/he_link.py b/assembler_tools/hec-assembler-tools/he_link.py index 2560ad99..b674494e 100644 --- a/assembler_tools/hec-assembler-tools/he_link.py +++ b/assembler_tools/hec-assembler-tools/he_link.py @@ -1,39 +1,35 @@ #! /usr/bin/env python3 -# encoding: utf-8 -""" -This module provides functionality for linking assembled kernels into a full HERACLES program for execution queues: MINST, CINST, and XINST. +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 -Classes: - LinkerRunConfig - Maintains the configuration data for the run. +# These contents may have been developed with support from one +# or more Intel-operated generative artificial intelligence solutions - KernelFiles - Structure for kernel files. +""" +@file he_link.py +@brief This module provides functionality for linking assembled kernels into a full HERACLES program for execution queues: MINST, CINST, and XINST. -Functions: - main(run_config: LinkerRunConfig, verbose_stream=None) - Executes the linking process using the provided configuration. +@par Classes: + - LinkerRunConfig: Maintains the configuration data for the run. + - KernelFiles: Structure for kernel files. - parse_args() -> argparse.Namespace - Parses command-line arguments for the linker script. +@par Functions: + - main(run_config: LinkerRunConfig, verbose_stream=None): Executes the linking process using the provided configuration. + - parse_args() -> argparse.Namespace: Parses command-line arguments for the linker script. -Usage: +@par Usage: This script is intended to be run as a standalone program. It requires specific command-line arguments to specify input and output files and configuration options for the linking process. - """ import argparse import io import os import pathlib import sys -import time import warnings +from typing import NamedTuple, Any, Optional import linker - -from typing import NamedTuple - from assembler.common import constants from assembler.common import makeUniquePath from assembler.common.counter import Counter @@ -45,335 +41,524 @@ from linker import loader from linker.steps import variable_discovery from linker.steps import program_linker +from linker.instructions import BaseInstruction + + +class NullIO: + """ + @class NullIO + @brief A class that provides a no-operation implementation of write and flush methods. + """ + + def write(self, *argts, **kwargs): + """ + @brief A no-operation write method. + """ + + def flush(self): + """ + @brief A no-operation flush method. + """ + class LinkerRunConfig(RunConfig): """ - Maintains the configuration data for the run. + @class LinkerRunConfig + @brief Maintains the configuration data for the run. + + @fn as_dict + @brief Returns the configuration as a dictionary. - Methods: - as_dict() -> dict - Returns the configuration as a dictionary. + @return dict The configuration as a dictionary. """ - __initialized = False # specifies whether static members have been initialized + # Type annotations for class attributes + input_prefixes: list[str] + input_mem_file: str + multi_mem_files: bool + output_dir: str + output_prefix: str + + __initialized = False # specifies whether static members have been initialized # contains the dictionary of all configuration items supported and their # default value (or None if no default) - __default_config = {} + __default_config: dict[str, Any] = {} def __init__(self, **kwargs): """ - Constructs a new LinkerRunConfig Object from input parameters. + @brief Constructs a new LinkerRunConfig Object from input parameters. See base class constructor for more parameters. - Args: - input_prefixes (list[str]): - List of input prefixes, including full path. For an input prefix, linker will - assume there are three files named `input_prefixes[i] + '.minst'`, - `input_prefixes[i] + '.cinst'`, and `input_prefixes[i] + '.xinst'`. - This list must not be empty. - output_prefix (str): - Prefix for the output file names. - Three files will be generated: - `output_dir/output_prefix.minst`, `output_dir/output_prefix.cinst`, and - `output_dir/output_prefix.xinst`. - Output filenames cannot match input file names. - input_mem_file (str): - Input memory file associated with the result kernel. - output_dir (str): current working directory - OPTIONAL directory where to store all intermediate files and final output. - This will be created if it doesn't exists. - Defaults to current working directory. - - Raises: - TypeError: - A mandatory configuration value was missing. - ValueError: - At least, one of the arguments passed is invalid. + @param input_prefixes List of input prefixes, including full path. For an input prefix, linker will + assume there are three files named `input_prefixes[i] + '.minst'`, + `input_prefixes[i] + '.cinst'`, and `input_prefixes[i] + '.xinst'`. + This list must not be empty. + @param output_prefix Prefix for the output file names. + Three files will be generated: + `output_dir/output_prefix.minst`, `output_dir/output_prefix.cinst`, and + `output_dir/output_prefix.xinst`. + Output filenames cannot match input file names. + @param input_mem_file Input memory file associated with the result kernel. + @param output_dir OPTIONAL directory where to store all intermediate files and final output. + This will be created if it doesn't exists. + Defaults to current working directory. + + @exception TypeError A mandatory configuration value was missing. + @exception ValueError At least, one of the arguments passed is invalid. """ + super().__init__(**kwargs) self.init_default_config() - + # class members based on configuration for config_name, default_value in self.__default_config.items(): - value = kwargs.get(config_name) + value = kwargs.get(config_name, default_value) if value is not None: - assert(not hasattr(self, config_name)) setattr(self, config_name, value) else: if not hasattr(self, config_name): setattr(self, config_name, default_value) if getattr(self, config_name) is None: - raise TypeError(f'Expected value for configuration `{config_name}`, but `None` received.') + raise TypeError( + f"Expected value for configuration `{config_name}`, but `None` received." + ) # fix file names self.output_dir = makeUniquePath(self.output_dir) - self.input_mem_file = makeUniquePath(self.input_mem_file) + # E0203: Access to member 'input_mem_file' before its definition. + # But it was defined in previous loop. + if self.input_mem_file != "": # pylint: disable=E0203 + self.input_mem_file = makeUniquePath(self.input_mem_file) @classmethod def init_default_config(cls): """ - Initializes static members of the class. + @brief Initializes static members of the class. """ if not cls.__initialized: - cls.__default_config["input_prefixes"] = None - cls.__default_config["input_mem_file"] = None - cls.__default_config["output_dir"] = os.getcwd() - cls.__default_config["output_prefix"] = None - cls.__default_config["has_hbm"] = True - cls.__default_config["hbm_size"] = cls.DEFAULT_HBM_SIZE_KB - cls.__default_config["use_xinstfetch"] = GlobalConfig.useXInstFetch - cls.__default_config["suppress_comments"] = GlobalConfig.suppressComments + cls.__default_config["input_prefixes"] = None + cls.__default_config["input_mem_file"] = "" + cls.__default_config["multi_mem_files"] = False + cls.__default_config["output_dir"] = os.getcwd() + cls.__default_config["output_prefix"] = None cls.__initialized = True def __str__(self): """ - Provides a string representation of the configuration. - - Returns: - str: The string for the configuration. + @brief Provides a string representation of the configuration. + + @return str The string for the configuration. """ self_dict = self.as_dict() with io.StringIO() as retval_f: for key, value in self_dict.items(): - print("{}: {}".format(key, value), file=retval_f) + print(f"{key}: {value}", file=retval_f) retval = retval_f.getvalue() return retval def as_dict(self) -> dict: """ - Provides the configuration as a dictionary. + @brief Provides the configuration as a dictionary. - Returns: - dict: The configuration. + @return dict The configuration. """ retval = super().as_dict() tmp_self_dict = vars(self) - retval.update({ config_name: tmp_self_dict[config_name] for config_name in self.__default_config }) + retval.update( + { + config_name: tmp_self_dict[config_name] + for config_name in self.__default_config + } + ) return retval + class KernelFiles(NamedTuple): """ - Structure for kernel files. - - Attributes: - minst (str): - Index = 0. Name for file containing MInstructions for represented kernel. - cinst (str): - Index = 1. Name for file containing CInstructions for represented kernel. - xinst (str): - Index = 2. Name for file containing XInstructions for represented kernel. - prefix (str): - Index = 3 + @class KernelFiles + @brief Structure for kernel files. + + @var prefix + Index = 0 + @var minst + Index = 1. Name for file containing MInstructions for represented kernel. + @var cinst + Index = 2. Name for file containing CInstructions for represented kernel. + @var xinst + Index = 3. Name for file containing XInstructions for represented kernel. + @var mem + Index = 4. Name for file containing memory information for represented kernel. + This is used only when multi_mem_files is set. """ + + prefix: str minst: str cinst: str xinst: str - prefix: str - -def main(run_config: LinkerRunConfig, verbose_stream = None): - """ - Executes the linking process using the provided configuration. + mem: Optional[str] = None - This function prepares input and output file names, initializes the memory model, discovers variables, - and links each kernel, writing the output to specified files. - Args: - run_config (LinkerRunConfig): The configuration object containing run parameters. - verbose_stream: The stream to which verbose output is printed. Defaults to None. - - Returns: - None +def link_kernels(input_files, output_files, mem_model, verbose_stream): """ - if verbose_stream: - print("Linking...", file=verbose_stream) + @brief Links input kernels and writes the output to the specified files. - if run_config.use_xinstfetch: - warnings.warn(f'Ignoring configuration flag "use_xinstfetch".') + @param input_files List of KernelFiles for input kernels. + @param output_files KernelFiles for output. + @param mem_model Memory model to use. + @param run_config LinkerRunConfig object. + @param verbose_stream Stream for verbose output. + """ + with open(output_files.minst, "w", encoding="utf-8") as fnum_output_minst, open( + output_files.cinst, "w", encoding="utf-8" + ) as fnum_output_cinst, open( + output_files.xinst, "w", encoding="utf-8" + ) as fnum_output_xinst: + + result_program = program_linker.LinkedProgram( + fnum_output_minst, fnum_output_cinst, fnum_output_xinst, mem_model + ) + for idx, kernel in enumerate(input_files): + if verbose_stream: + print( + f"[ {idx * 100 // len(input_files): >3}% ]", + kernel.prefix, + file=verbose_stream, + ) + kernel_minstrs = loader.load_minst_kernel_from_file(kernel.minst) + kernel_cinstrs = loader.load_cinst_kernel_from_file(kernel.cinst) + kernel_xinstrs = loader.load_xinst_kernel_from_file(kernel.xinst) + result_program.link_kernel(kernel_minstrs, kernel_cinstrs, kernel_xinstrs) + if verbose_stream: + print( + "[ 100% ] Finalizing output", output_files.prefix, file=verbose_stream + ) + result_program.close() - # Update global config - GlobalConfig.hasHBM = run_config.has_hbm - mem_filename: str = run_config.input_mem_file - hbm_capcity_words: int = constants.convertBytes2Words(run_config.hbm_size * constants.Constants.KILOBYTE) - input_files = [] # list(KernelFiles) - output_files: KernelFiles = None +def prepare_output_files(run_config) -> KernelFiles: + """ + @brief Prepares output file names and directories. - # prepare output file names + @param run_config LinkerRunConfig object. + @return KernelFiles Output file paths. + """ output_prefix = os.path.join(run_config.output_dir, run_config.output_prefix) output_dir = os.path.dirname(output_prefix) - pathlib.Path(output_dir).mkdir(exist_ok = True, parents=True) - output_files = KernelFiles(minst=makeUniquePath(output_prefix + '.minst'), - cinst=makeUniquePath(output_prefix + '.cinst'), - xinst=makeUniquePath(output_prefix + '.xinst'), - prefix=makeUniquePath(output_prefix)) + pathlib.Path(output_dir).mkdir(exist_ok=True, parents=True) + out_mem_file = ( + makeUniquePath(output_prefix + ".mem") if run_config.multi_mem_files else None + ) + return KernelFiles( + prefix=makeUniquePath(output_prefix), + minst=makeUniquePath(output_prefix + ".minst"), + cinst=makeUniquePath(output_prefix + ".cinst"), + xinst=makeUniquePath(output_prefix + ".xinst"), + mem=out_mem_file, + ) + + +def prepare_input_files(run_config, output_files) -> list: + """ + @brief Prepares input file names and checks for existence and conflicts. - # prepare input file names + @param run_config LinkerRunConfig object. + @param output_files KernelFiles for output. + @return list List of KernelFiles for input. + @exception FileNotFoundError If an input file does not exist. + @exception RuntimeError If an input file matches an output file. + """ + input_files = [] for file_prefix in run_config.input_prefixes: - input_files.append(KernelFiles(minst=makeUniquePath(file_prefix + '.minst'), - cinst=makeUniquePath(file_prefix + '.cinst'), - xinst=makeUniquePath(file_prefix + '.xinst'), - prefix=makeUniquePath(file_prefix))) - for input_filename in input_files[-1][:-1]: - if not os.path.isfile(input_filename): - raise FileNotFoundError(input_filename) - if input_filename in output_files: - raise RuntimeError(f'Input files cannot match output files: "{input_filename}"') - - # reset counters - Counter.reset() - - # parse mem file - - if verbose_stream: - print("", file=verbose_stream) - print("Interpreting variable meta information...", file=verbose_stream) - - with open(mem_filename, 'r') as mem_ifnum: - mem_meta_info = mem_info.MemInfo.from_iter(mem_ifnum) - - # initialize memory model - if verbose_stream: - print("Initializing linker memory model", file=verbose_stream) - - mem_model = linker.MemoryModel(hbm_capcity_words, mem_meta_info) - if verbose_stream: - print(f" HBM capacity: {mem_model.hbm.capacity} words", file=verbose_stream) - - # find all variables and usage across all the input kernels - - if verbose_stream: - print(" Finding all program variables...", file=verbose_stream) - print(" Scanning", file=verbose_stream) + mem_file = ( + makeUniquePath(file_prefix + ".mem") if run_config.multi_mem_files else None + ) + kernel_files = KernelFiles( + prefix=makeUniquePath(file_prefix), + minst=makeUniquePath(file_prefix + ".minst"), + cinst=makeUniquePath(file_prefix + ".cinst"), + xinst=makeUniquePath(file_prefix + ".xinst"), + mem=mem_file, + ) + input_files.append(kernel_files) + for input_filename in kernel_files[1:]: + if input_filename: + if not os.path.isfile(input_filename): + raise FileNotFoundError(input_filename) + if input_filename in output_files: + raise RuntimeError( + f'Input files cannot match output files: "{input_filename}"' + ) + return input_files + + +def scan_variables(input_files, mem_model, verbose_stream): + """ + @brief Scans input files for variables and adds them to the memory model. + @param input_files List of KernelFiles for input. + @param mem_model Memory model to update. + @param verbose_stream Stream for verbose output. + """ for idx, kernel in enumerate(input_files): if not GlobalConfig.hasHBM: if verbose_stream: - print(" {}/{}".format(idx + 1, len(input_files)), kernel.cinst, - file=verbose_stream) - # load next CInst kernel and scan for variables used in SPAD - kernel_cinstrs = loader.loadCInstKernelFromFile(kernel.cinst) - for var_name in variable_discovery.discoverVariablesSPAD(kernel_cinstrs): - mem_model.addVariable(var_name) + print( + f" {idx + 1}/{len(input_files)}", + kernel.cinst, + file=verbose_stream, + ) + kernel_cinstrs = loader.load_cinst_kernel_from_file(kernel.cinst) + for var_name in variable_discovery.discover_variables_spad(kernel_cinstrs): + mem_model.add_variable(var_name) else: if verbose_stream: - print(" {}/{}".format(idx + 1, len(input_files)), kernel.minst, - file=verbose_stream) - # load next MInst kernel and scan for variables used - kernel_minstrs = loader.loadMInstKernelFromFile(kernel.minst) - for var_name in variable_discovery.discoverVariables(kernel_minstrs): - mem_model.addVariable(var_name) - - # check that all non-keygen variables from MemInfo are used + print( + f" {idx + 1}/{len(input_files)}", + kernel.minst, + file=verbose_stream, + ) + kernel_minstrs = loader.load_minst_kernel_from_file(kernel.minst) + for var_name in variable_discovery.discover_variables(kernel_minstrs): + mem_model.add_variable(var_name) + + +def check_unused_variables(mem_model): + """ + @brief Checks for unused variables in the memory model and raises an error if found. + + @param mem_model Memory model to check. + @exception RuntimeError If an unused variable is found. + """ for var_name in mem_model.mem_info_vars: if var_name not in mem_model.variables: - if GlobalConfig.hasHBM or var_name not in mem_model.mem_info_meta: # skip checking meta vars when no HBM - raise RuntimeError(f'Unused variable from input mem file: "{var_name}" not in memory model.') - - if verbose_stream: - print(f" Variables found: {len(mem_model.variables)}", file=verbose_stream) - - if verbose_stream: - print("Linking started", file=verbose_stream) - - # open the output files - with open(output_files.minst, 'w') as fnum_output_minst, \ - open(output_files.cinst, 'w') as fnum_output_cinst, \ - open(output_files.xinst, 'w') as fnum_output_xinst: - - # prepare the linker class - result_program = program_linker.LinkedProgram(fnum_output_minst, - fnum_output_cinst, - fnum_output_xinst, - mem_model, - supress_comments=run_config.suppress_comments) - # start linking each kernel - for idx, kernel in enumerate(input_files): - if verbose_stream: - print("[ {: >3}% ]".format(idx * 100 // len(input_files)), kernel.prefix, - file=verbose_stream) - kernel_minstrs = loader.loadMInstKernelFromFile(kernel.minst) - kernel_cinstrs = loader.loadCInstKernelFromFile(kernel.cinst) - kernel_xinstrs = loader.loadXInstKernelFromFile(kernel.xinst) + if GlobalConfig.hasHBM or var_name not in mem_model.mem_info_meta: + raise RuntimeError( + f'Unused variable from input mem file: "{var_name}" not in memory model.' + ) - result_program.linkKernel(kernel_minstrs, kernel_cinstrs, kernel_xinstrs) - if verbose_stream: - print("[ 100% ] Finalizing output", output_files.prefix, file=verbose_stream) +def main(run_config: LinkerRunConfig, verbose_stream=NullIO()): + """ + @brief Executes the linking process using the provided configuration. - # signal that we have linked all kernels - result_program.close() + This function prepares input and output file names, initializes the memory model, discovers variables, + and links each kernel, writing the output to specified files. + + @param run_config The configuration object containing run parameters. + @param verbose_stream The stream to which verbose output is printed. Defaults to NullIO. + + @return None + """ + if run_config.use_xinstfetch: + warnings.warn("Ignoring configuration flag 'use_xinstfetch'.") + + # Update global config + GlobalConfig.hasHBM = run_config.has_hbm + GlobalConfig.suppress_comments = run_config.suppress_comments + + mem_filename: str = run_config.input_mem_file + hbm_capacity_words: int = constants.convertBytes2Words( + run_config.hbm_size * constants.Constants.KILOBYTE + ) + + # Prepare input and output files + output_files: KernelFiles = prepare_output_files(run_config) + input_files: list[KernelFiles] = prepare_input_files(run_config, output_files) + + # Reset counters + Counter.reset() + + # parse mem file + print("Linking...", file=verbose_stream) + print("", file=verbose_stream) + print("Interpreting variable meta information...", file=verbose_stream) + + if run_config.multi_mem_files: + kernels_dinstrs = [] + for kernel in input_files: + if kernel.mem is None: + raise RuntimeError(f"Memory file not found for kernel {kernel.prefix}") + kernel_dinstrs = loader.load_dinst_kernel_from_file(kernel.mem) + kernels_dinstrs.append(kernel_dinstrs) + + # Concatenate all mem info objects into one + kernel_dinstrs = program_linker.LinkedProgram.join_dinst_kernels( + kernels_dinstrs + ) + mem_meta_info = mem_info.MemInfo.from_dinstrs(kernel_dinstrs) + else: + with open(mem_filename, "r", encoding="utf-8") as mem_ifnum: + mem_meta_info = mem_info.MemInfo.from_file_iter(mem_ifnum) + + # Initialize memory model + print("Initializing linker memory model", file=verbose_stream) + + mem_model = linker.MemoryModel(hbm_capacity_words, mem_meta_info) + print(f" HBM capacity: {mem_model.hbm.capacity} words", file=verbose_stream) + + print(" Finding all program variables...", file=verbose_stream) + print(" Scanning", file=verbose_stream) + + scan_variables(input_files, mem_model, verbose_stream) + check_unused_variables(mem_model) + + print(f" Variables found: {len(mem_model.variables)}", file=verbose_stream) + print("Linking started", file=verbose_stream) + + link_kernels(input_files, output_files, mem_model, verbose_stream) + + # Write the memory model to the output file + if run_config.multi_mem_files: + if output_files.mem is None: + raise RuntimeError("Output memory file path is None") + BaseInstruction.dump_instructions_to_file(kernel_dinstrs, output_files.mem) + + print("Output written to files:", file=verbose_stream) + print(" ", output_files.minst, file=verbose_stream) + print(" ", output_files.cinst, file=verbose_stream) + print(" ", output_files.xinst, file=verbose_stream) + if run_config.multi_mem_files: + print(" ", output_files.mem, file=verbose_stream) - if verbose_stream: - print("Output written to files:", file=verbose_stream) - print(" ", output_files.minst, file=verbose_stream) - print(" ", output_files.cinst, file=verbose_stream) - print(" ", output_files.xinst, file=verbose_stream) def parse_args(): """ - Parses command-line arguments for the linker script. + @brief Parses command-line arguments for the linker script. This function sets up the argument parser and defines the expected arguments for the script. It returns a Namespace object containing the parsed arguments. - Returns: - argparse.Namespace: Parsed command-line arguments. + @return argparse.Namespace Parsed command-line arguments. """ parser = argparse.ArgumentParser( - description=("HERACLES Linker.\n" - "Links assembled kernels into a full HERACLES program " - "for each of the three execution queues: MINST, CINST, and XINST.\n\n" - "To link several kernels, specify each kernel's input prefix in order. " - "Variables that should carry on across kernels should be have the same name. " - "Linker will recognize matching variables and keep their values between kernels. " - "Variables that are inputs and outputs (and metadata) for the whole program must " - "be indicated in the input memory mapping file.")) - parser.add_argument("input_prefixes", nargs="+", - help=("List of input prefixes, including full path. For an input prefix, linker will " - "assume three files exist named `input_prefixes[i] + '.minst'`, " - "`input_prefixes[i] + '.cinst'`, and `input_prefixes[i] + '.xinst'`.")) - parser.add_argument("--mem_spec", default="", dest="mem_spec_file", - help=("Input Mem specification (.json) file.")) - parser.add_argument("--isa_spec", default="", dest="isa_spec_file", - help=("Input ISA specification (.json) file.")) - parser.add_argument("-im", "--input_mem_file", dest="input_mem_file", required=True, - help=("Input memory mapping file associated with the resulting program. " - "Specifies the names for input, output, and metadata variables for the full program. " - "This file is usually the same as the kernel's when converting a single kernel into " - "a program, but, when linking multiple kernels together, it should be tailored to the " - "whole program.")) - parser.add_argument("-o", "--output_prefix", dest="output_prefix", required=True, - help=("Prefix for the output file names. " - "Three files will be generated: \n" - "`output_dir/output_prefix.minst`, `output_dir/output_prefix.cinst`, and " - "`output_dir/output_prefix.xinst`. \n" - "Output filenames cannot match input file names.")) - parser.add_argument("-od", "--output_dir", dest="output_dir", default="", - help=("Directory where to store all intermediate files and final output. " - "This will be created if it doesn't exists. " - "Defaults to current working directory.")) + description=( + "HERACLES Linker.\n" + "Links assembled kernels into a full HERACLES program " + "for each of the three execution queues: MINST, CINST, and XINST.\n\n" + "To link several kernels, specify each kernel's input prefix in order. " + "Variables that should carry on across kernels should be have the same name. " + "Linker will recognize matching variables and keep their values between kernels. " + "Variables that are inputs and outputs (and metadata) for the whole program must " + "be indicated in the input memory mapping file." + ) + ) + parser.add_argument( + "input_prefixes", + nargs="+", + help=( + "List of input prefixes, including full path. For an input prefix, linker will " + "assume three files exist named `input_prefixes[i] + '.minst'`, " + "`input_prefixes[i] + '.cinst'`, and `input_prefixes[i] + '.xinst'`." + ), + ) + parser.add_argument( + "--mem_spec", + default="", + dest="mem_spec_file", + help=("Input Mem specification (.json) file."), + ) + parser.add_argument( + "--isa_spec", + default="", + dest="isa_spec_file", + help=("Input ISA specification (.json) file."), + ) + parser.add_argument( + "--multi_mem_files", + action="store_true", + dest="multi_mem_files", + help=( + "Tells the linker to find a memory file (*.tw.mem) for each input prefix given." + "This can be used to link multiple kernels together. " + "If this flag is not set, the linker will use the input_mem_file argument instead" + ), + ) + parser.add_argument( + "-im", + "--input_mem_file", + dest="input_mem_file", + required=False, + default="", + help=( + "Input memory mapping file associated with the resulting program. " + "Specifies the names for input, output, and metadata variables for a single kernel" + " or also a full program if instead this is used to link multiple kernels together." + ), + ) + parser.add_argument( + "-o", + "--output_prefix", + dest="output_prefix", + required=True, + help=( + "Prefix for the output file names. " + "Three files will be generated: \n" + "`output_dir/output_prefix.minst`, `output_dir/output_prefix.cinst`, and " + "`output_dir/output_prefix.xinst`. \n" + "Output filenames cannot match input file names." + ), + ) + parser.add_argument( + "-od", + "--output_dir", + dest="output_dir", + default="", + help=( + "Directory where to store all intermediate files and final output. " + "This will be created if it doesn't exists. " + "Defaults to current working directory." + ), + ) parser.add_argument("--hbm_size", type=int, help="HBM size in KB.") - parser.add_argument("--no_hbm", dest="has_hbm", action="store_false", - help="If set, this flag tells he_prep there is no HBM in the target chip.") - parser.add_argument("--suppress_comments", "--no_comments", dest="suppress_comments", action="store_true", - help=("When enabled, no comments will be emited on the output generated.")) - parser.add_argument("-v", "--verbose", dest="verbose", action="count", default=0, - help=("If enabled, extra information and progress reports are printed to stdout. " - "Increase level of verbosity by specifying flag multiple times, e.g. -vv")) - args = parser.parse_args() + parser.add_argument( + "--no_hbm", + dest="has_hbm", + action="store_false", + help="If set, this flag tells he_prep there is no HBM in the target chip.", + ) + parser.add_argument( + "--suppress_comments", + "--no_comments", + dest="suppress_comments", + action="store_true", + help=("When enabled, no comments will be emitted on the output generated."), + ) + parser.add_argument( + "-v", + "--verbose", + dest="verbose", + action="count", + default=0, + help=( + "If enabled, extra information and progress reports are printed to stdout. " + "Increase level of verbosity by specifying flag multiple times, e.g. -vv" + ), + ) + p_args = parser.parse_args() + + # Enforce input_mem_file only if multi_mem_files is not set + if not p_args.multi_mem_files and p_args.input_mem_file == "": + parser.error( + "the following arguments are required: -im/--input_mem_file (unless --multi_mem_files is set)" + ) + + return p_args - return args if __name__ == "__main__": module_dir = os.path.dirname(__file__) module_name = os.path.basename(__file__) args = parse_args() - args.mem_spec_file = MemSpecConfig.initialize_mem_spec(module_dir, args.mem_spec_file) - args.isa_spec_file = ISASpecConfig.initialize_isa_spec(module_dir, args.isa_spec_file) - config = LinkerRunConfig(**vars(args)) # convert argsparser into a dictionary + args.mem_spec_file = MemSpecConfig.initialize_mem_spec( + module_dir, args.mem_spec_file + ) + args.isa_spec_file = ISASpecConfig.initialize_isa_spec( + module_dir, args.isa_spec_file + ) + config = LinkerRunConfig(**vars(args)) # convert argsparser into a dictionary if args.verbose > 0: print(module_name) @@ -384,7 +569,7 @@ def parse_args(): print("=================") print() - main(config, sys.stdout if args.verbose > 1 else None) + main(config, sys.stdout if args.verbose > 1 else NullIO()) if args.verbose > 0: print() diff --git a/assembler_tools/hec-assembler-tools/linker/__init__.py b/assembler_tools/hec-assembler-tools/linker/__init__.py index a32676d9..3d6174e2 100644 --- a/assembler_tools/hec-assembler-tools/linker/__init__.py +++ b/assembler_tools/hec-assembler-tools/linker/__init__.py @@ -1,86 +1,89 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +# These contents may have been developed with support from one or more Intel-operated +# generative artificial intelligence solutions + +"""@brief linker/__init__.py contains classes to encapsulate the memory model used by the linker.""" + import collections.abc as collections +from typing import Dict + from assembler.common.config import GlobalConfig from assembler.memory_model import mem_info -# linker/__init__.py contains classes to encapsulate the memory model used -# by the linker. class VariableInfo(mem_info.MemInfoVariable): """ - Represents information about a variable in the memory model. + @brief Represents information about a variable in the memory model. """ def __init__(self, var_name, hbm_address=-1): """ - Initializes a VariableInfo object. + @brief Initializes a VariableInfo object. - Parameters: - var_name (str): The name of the variable. - hbm_address (int): The HBM address of the variable. Defaults to -1. + @param var_name The name of the variable. + @param hbm_address The HBM address of the variable. Defaults to -1. """ super().__init__(var_name, hbm_address) self.uses = 0 self.last_kernel_used = -1 + class HBM: """ - Represents the HBM model. + @brief Represents the HBM model. """ def __init__(self, hbm_size_words: int): """ - Initializes an HBM object. - - Parameters: - hbm_size_words (int): The size of the HBM in words. + @brief Initializes an HBM object. - Raises: - ValueError: If hbm_size_words is less than 1. + @param hbm_size_words The size of the HBM in words. + @throws ValueError If hbm_size_words is less than 1. """ if hbm_size_words < 1: - raise ValueError('`hbm_size_words` must be a positive integer.') + raise ValueError("`hbm_size_words` must be a positive integer.") # Represents the memory buffer where variables live self.__buffer = [None] * hbm_size_words @property def capacity(self) -> int: """ - Gets the capacity in words for the HBM buffer. + @brief Gets the capacity in words for the HBM buffer. - Returns: - int: The capacity of the HBM buffer. + @return The capacity of the HBM buffer. """ return len(self.buffer) @property def buffer(self) -> list: """ - Gets the HBM buffer. + @brief Gets the HBM buffer. - Returns: - list: The HBM buffer. + @return The HBM buffer. """ return self.__buffer - def forceAllocate(self, var_info: VariableInfo, hbm_address: int): + def force_allocate(self, var_info: VariableInfo, hbm_address: int): """ - Forcefully allocates a variable at a specific HBM address. - - Parameters: - var_info (VariableInfo): The variable information. - hbm_address (int): The HBM address to allocate the variable. + @brief Forcefully allocates a variable at a specific HBM address. - Raises: - IndexError: If hbm_address is out of bounds. - ValueError: If the variable is already allocated at a different address. - RuntimeError: If the HBM address is already occupied by another variable. + @param var_info The variable information. + @param hbm_address The HBM address to allocate the variable. + @throws IndexError If hbm_address is out of bounds. + @throws ValueError If the variable is already allocated at a different address. + @throws RuntimeError If the HBM address is already occupied by another variable. """ if hbm_address < 0 or hbm_address >= len(self.buffer): - raise IndexError('`hbm_address` out of bounds. Expected a word address in range [0, {}), but {} received'.format(len(self.buffer), - hbm_address)) + raise IndexError( + f"`hbm_address` out of bounds. Expected a word address in range [0, {len(self.buffer)}), but {hbm_address} received" + ) if var_info.hbm_address != hbm_address: if var_info.hbm_address >= 0: - raise ValueError(f'`var_info`: variable {var_info.var_name} already allocated in address {var_info.hbm_address}.') + raise ValueError( + f"`var_info`: variable {var_info.var_name} already allocated in address {var_info.hbm_address}." + ) in_var_info = self.buffer[hbm_address] # Validate hbm address @@ -88,29 +91,28 @@ def forceAllocate(self, var_info: VariableInfo, hbm_address: int): # Attempt to recycle SPAD locations inside kernel when no HBM # Note: there is no HBM, so, SPAD is used as the sole memory space if in_var_info and in_var_info.uses > 0: - raise RuntimeError(('HBM address {} already occupied by variable {} ' - 'when attempting to allocate variable {}').format(hbm_address, - in_var_info.var_name, - var_info.var_name)) + raise RuntimeError( + f"HBM address {hbm_address} already occupied by variable {in_var_info.var_name} " + f"when attempting to allocate variable {var_info.var_name}" + ) else: - if in_var_info \ - and (in_var_info.uses > 0 or in_var_info.last_kernel_used >= var_info.last_kernel_used): - raise RuntimeError(('HBM address {} already occupied by variable {} ' - 'when attempting to allocate variable {}').format(hbm_address, - in_var_info.var_name, - var_info.var_name)) + if in_var_info and ( + in_var_info.uses > 0 + or in_var_info.last_kernel_used >= var_info.last_kernel_used + ): + raise RuntimeError( + f"HBM address {hbm_address} already occupied by variable {in_var_info.var_name} " + f"when attempting to allocate variable {var_info.var_name}" + ) var_info.hbm_address = hbm_address self.buffer[hbm_address] = var_info def allocate(self, var_info: VariableInfo): """ - Allocates a variable in the HBM. + @brief Allocates a variable in the HBM. - Parameters: - var_info (VariableInfo): The variable information. - - Raises: - RuntimeError: If there is no available HBM memory. + @param var_info The variable information. + @throws RuntimeError If there is no available HBM memory. """ # Find next available HBM address retval = -1 @@ -122,83 +124,129 @@ def allocate(self, var_info: VariableInfo): retval = idx break else: - if not in_var_info \ - or (in_var_info.uses <= 0 and in_var_info.last_kernel_used < var_info.last_kernel_used): + if not in_var_info or ( + in_var_info.uses <= 0 + and in_var_info.last_kernel_used < var_info.last_kernel_used + ): retval = idx break if retval < 0: - raise RuntimeError('Out of HBM memory.') - self.forceAllocate(var_info, retval) + raise RuntimeError("Out of HBM memory.") + self.force_allocate(var_info, retval) + class MemoryModel: """ - Encapsulates the memory model for a linker run, tracking HBM usage and program variables. + @brief Encapsulates the memory model for a linker run, tracking HBM usage and program variables. """ def __init__(self, hbm_size_words: int, mem_meta_info: mem_info.MemInfo): """ - Initializes a MemoryModel object. + @brief Initializes a MemoryModel object. - Parameters: - hbm_size_words (int): The size of the HBM in words. - mem_meta_info (mem_info.MemInfo): The memory metadata information. + @param hbm_size_words The size of the HBM in words. + @param mem_meta_info The memory metadata information. """ self.hbm = HBM(hbm_size_words) self.__mem_info = mem_meta_info - self.__variables = {} # dict(var_name: str, VariableInfo) - self.__keygen_vars = {var_info.var_name: var_info for var_info in self.__mem_info.keygens} - self.__mem_info_inputs = {var_info.var_name: var_info for var_info in self.__mem_info.inputs} - self.__mem_info_outputs = {var_info.var_name: var_info for var_info in self.__mem_info.outputs} - self.__mem_info_meta = {var_info.var_name: var_info for var_info in self.__mem_info.metadata.intt_auxiliary_table} \ - | {var_info.var_name: var_info for var_info in self.__mem_info.metadata.intt_routing_table} \ - | {var_info.var_name: var_info for var_info in self.__mem_info.metadata.ntt_auxiliary_table} \ - | {var_info.var_name: var_info for var_info in self.__mem_info.metadata.ntt_routing_table} \ - | {var_info.var_name: var_info for var_info in self.__mem_info.metadata.ones} \ - | {var_info.var_name: var_info for var_info in self.__mem_info.metadata.twiddle} \ - | {var_info.var_name: var_info for var_info in self.__mem_info.metadata.keygen_seeds} - self.__mem_info_fixed_addr_vars = self.__mem_info_outputs | self.__mem_info_meta + self.__variables: Dict[str, VariableInfo] = ( + {} + ) # dict(var_name: str, VariableInfo) + + # Group related collections into a dictionary + self.__mem_collections = { + "keygen_vars": { + var_info.var_name: var_info for var_info in self.__mem_info.keygens + }, + "inputs": { + var_info.var_name: var_info for var_info in self.__mem_info.inputs + }, + "outputs": { + var_info.var_name: var_info for var_info in self.__mem_info.outputs + }, + "meta": ( + { + var_info.var_name: var_info + for var_info in self.__mem_info.metadata.intt_auxiliary_table + } + | { + var_info.var_name: var_info + for var_info in self.__mem_info.metadata.intt_routing_table + } + | { + var_info.var_name: var_info + for var_info in self.__mem_info.metadata.ntt_auxiliary_table + } + | { + var_info.var_name: var_info + for var_info in self.__mem_info.metadata.ntt_routing_table + } + | { + var_info.var_name: var_info + for var_info in self.__mem_info.metadata.ones + } + | { + var_info.var_name: var_info + for var_info in self.__mem_info.metadata.twiddle + } + | { + var_info.var_name: var_info + for var_info in self.__mem_info.metadata.keygen_seeds + } + ), + } + + # Derived collections + self.__mem_info_fixed_addr_vars = ( + self.__mem_collections["outputs"] | self.__mem_collections["meta"] + ) # Keygen variables should not be part of mem_info_vars set since they # do not start in HBM - self.__mem_info_vars = self.__mem_info_inputs | self.__mem_info_outputs | self.__mem_info_meta + self.__mem_info_vars = ( + self.__mem_collections["inputs"] + | self.__mem_collections["outputs"] + | self.__mem_collections["meta"] + ) @property def mem_info_meta(self) -> collections.Collection: """ - Set of metadata variable names in MemInfo used to construct this object. + @brief Set of metadata variable names in MemInfo used to construct this object. + Clients must not modify this set. + + @return Collection of metadata variable names. """ - return self.__mem_info_meta - + return self.__mem_collections["meta"] + @property def mem_info_vars(self) -> collections.Collection: """ - Gets the set of variable names in MemInfo used to construct this object. + @brief Gets the set of variable names in MemInfo used to construct this object. - Returns: - collections.Collection: The set of variable names. + @return The set of variable names. """ return self.__mem_info_vars @property def variables(self) -> dict: """ - Gets direct access to internal variables dictionary. + @brief Gets direct access to internal variables dictionary. Clients should use as read-only. Must not add, replace, remove or change contents in any way. Use provided helper functions to manipulate. - Returns: - dict: A dictionary of variables. + @return A dictionary of variables. """ return self.__variables - def addVariable(self, var_name: str): + def add_variable(self, var_name: str): """ - Adds a variable to the HBM model. If variable already exists, its `uses` - field is incremented. + @brief Adds a variable to the HBM model. + + If variable already exists, its `uses` field is incremented. - Parameters: - var_name (str): The name of the variable to add. + @param var_name The name of the variable to add. """ var_info: VariableInfo if var_name in self.variables: @@ -209,25 +257,23 @@ def addVariable(self, var_name: str): # Variables explicitly marked in mem file must persist throughout the program # with predefined HBM address if var_name in self.__mem_info_fixed_addr_vars: - var_info.uses = float('inf') - self.hbm.forceAllocate(var_info, - self.__mem_info_vars[var_name].hbm_address) + var_info.uses = float("inf") + self.hbm.force_allocate( + var_info, self.__mem_info_vars[var_name].hbm_address + ) self.variables[var_name] = var_info var_info.uses += 1 - def useVariable(self, var_name: str, kernel: int) -> int: + def use_variable(self, var_name: str, kernel: int) -> int: """ - Uses a variable, decrementing its usage count. + @brief Uses a variable, decrementing its usage count. If a variable usage count reaches zero, it will be deallocated from HBM, if needed, when a future kernel requires HBM space. - Parameters: - var_name (str): The name of the variable to use. - kernel (int): The kernel that is using the variable. - - Returns: - int: The HBM address for the variable. + @param var_name The name of the variable to use. + @param kernel The kernel that is using the variable. + @return The HBM address for the variable. """ var_info: VariableInfo = self.variables[var_name] assert var_info.uses > 0 @@ -240,7 +286,8 @@ def useVariable(self, var_name: str, kernel: int) -> int: self.hbm.allocate(var_info) assert var_info.hbm_address >= 0 - assert self.hbm.buffer[var_info.hbm_address].var_name == var_info.var_name, \ - f'Expected variable {var_info.var_name} in HBM {var_info.hbm_address}, but variable {self.hbm[var_info.hbm_address].var_name} found instead.' + assert ( + self.hbm.buffer[var_info.hbm_address].var_name == var_info.var_name + ), f"Expected variable {var_info.var_name} in HBM {var_info.hbm_address}, but variable {self.hbm.buffer[var_info.hbm_address].var_name} found instead." - return var_info.hbm_address \ No newline at end of file + return var_info.hbm_address diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/__init__.py b/assembler_tools/hec-assembler-tools/linker/instructions/__init__.py index 135608cc..32e96a7a 100644 --- a/assembler_tools/hec-assembler-tools/linker/instructions/__init__.py +++ b/assembler_tools/hec-assembler-tools/linker/instructions/__init__.py @@ -1,27 +1,31 @@ # Copyright (C) 2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 +# These contents may have been developed with support from one or more Intel-operated +# generative artificial intelligence solutions + +"""@brief This module provides functionality to create instruction objects from a line of text.""" + +from typing import Optional from assembler.instructions import tokenize_from_line from linker.instructions.instruction import BaseInstruction -def fromStrLine(line: str, factory) -> BaseInstruction: +def create_from_str_line(line: str, factory) -> Optional[BaseInstruction]: """ - Parses an instruction from a line of text. - - Parameters: - line (str): Line of text from which to parse an instruction. + @brief Parses an instruction from a line of text. - Returns: - BaseInstruction or None: The parsed BaseInstruction object, or None if no object could be - parsed from the specified input line. + @param line Line of text from which to parse an instruction. + @param factory Factory function or collection to create instruction objects. + @return The parsed BaseInstruction object, or None if no object could be + parsed from the specified input line. """ retval = None tokens, comment = tokenize_from_line(line) for instr_type in factory: try: retval = instr_type(tokens, comment) - except: + except (TypeError, ValueError, AttributeError): retval = None if retval: break diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/cinst/cinstruction.py b/assembler_tools/hec-assembler-tools/linker/instructions/cinst/cinstruction.py index a82fb36d..81df3417 100644 --- a/assembler_tools/hec-assembler-tools/linker/instructions/cinst/cinstruction.py +++ b/assembler_tools/hec-assembler-tools/linker/instructions/cinst/cinstruction.py @@ -1,10 +1,27 @@ -from linker.instructions.instruction import BaseInstruction +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""This module implements the base class for CInstructions.""" + +from linker.instructions.instruction import BaseInstruction + class CInstruction(BaseInstruction): """ Represents a CInstruction, inheriting from BaseInstruction. """ + @classmethod + def _get_name(cls) -> str: + """ + Derived classes should implement this method and return correct + name for the instruction. + + Raises: + NotImplementedError: Abstract method. This base method should not be called. + """ + raise NotImplementedError() + @classmethod def _get_name_token_index(cls) -> int: """ @@ -15,6 +32,17 @@ def _get_name_token_index(cls) -> int: """ return 1 + @classmethod + def _get_num_tokens(cls) -> int: + """ + Derived classes should implement this method and return correct + required number of tokens for the instruction. + + Raises: + NotImplementedError: Abstract method. This base method should not be called. + """ + raise NotImplementedError() + # Constructor # ----------- @@ -38,4 +66,4 @@ def to_line(self) -> str: Returns: str: The string representation of the instruction, excluding the first token. """ - return ", ".join(self.tokens[1:]) \ No newline at end of file + return ", ".join(self.tokens[1:]) diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/dinst/__init__.py b/assembler_tools/hec-assembler-tools/linker/instructions/dinst/__init__.py new file mode 100644 index 00000000..1bcc2452 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/linker/instructions/dinst/__init__.py @@ -0,0 +1,61 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +# These contents may have been developed with support from one or more Intel-operated +# generative artificial intelligence solutions + +"""@brief This module provides functionality to create and manage data instructions""" + +from typing import Optional + +from assembler.instructions import tokenize_from_line +from assembler.memory_model.mem_info import MemInfo +from . import dload, dstore, dkeygen +from . import dinstruction + +DLoad = dload.Instruction +DStore = dstore.Instruction +DKeyGen = dkeygen.Instruction + + +def factory() -> set: + """ + @brief Creates a set of all DInstruction classes. + + @return A set containing all DInstruction classes. + """ + return {DLoad, DStore, DKeyGen} + + +def create_from_mem_line(line: str) -> dinstruction.DInstruction: + """ + @brief Parses an data instruction from a line of the memory map. + + @param line Line of text from which to parse an instruction. + @return The parsed DInstruction object, or None if no object could be + parsed from the specified input line. + @throws RuntimeError If no valid instruction is found or if there's an error parsing the memory map line. + """ + retval: Optional[dinstruction.DInstruction] = None + tokens, comment = tokenize_from_line(line) + for instr_type in factory(): + try: + retval = instr_type(tokens, comment) + except ValueError: + retval = None + if retval: + break + + if not retval: + raise RuntimeError(f'No valid instruction found for line "{line}"') + + try: + miv, _ = MemInfo.get_meminfo_var_from_tokens(tokens) + except RuntimeError as e: + raise RuntimeError(f'Error parsing memory map line "{line}"') from e + + miv_dict = miv.as_dict() + retval.var = miv_dict["var_name"] + retval.address = miv_dict["hbm_address"] + + return retval diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dinstruction.py b/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dinstruction.py new file mode 100644 index 00000000..9db0ad16 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dinstruction.py @@ -0,0 +1,146 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +# These contents may have been developed with support from one or more Intel-operated +# generative artificial intelligence solutions + +""" +@brief This module defines the base DInstruction class for data handling instructions. + +DInstruction is the parent class for all data instructions used in the +assembly process, providing common functionality and interfaces. +""" + +from linker.instructions.instruction import BaseInstruction +from assembler.common.counter import Counter +from assembler.common.decorators import classproperty + + +class DInstruction(BaseInstruction): + """ + @brief Represents a DInstruction, inheriting from BaseInstruction. + """ + + _local_id_count = Counter.count(0) # Local counter for DInstruction IDs + _var: str = "" + _address: int = 0 + + @classmethod + def _get_name(cls) -> str: + """ + @brief Derived classes should implement this method and return correct + name for the instruction. + + @throws NotImplementedError Abstract method. This base method should not be called. + """ + raise NotImplementedError() + + @classmethod + def _get_name_token_index(cls) -> int: + """ + @brief Gets the index of the token containing the name of the instruction. + + @return The index of the name token, which is 0. + """ + return 0 + + @classmethod + def _get_num_tokens(cls) -> int: + """ + @brief Derived classes should implement this method and return correct + required number of tokens for the instruction. + + @throws NotImplementedError Abstract method. This base method should not be called. + """ + raise NotImplementedError() + + @classproperty + def num_tokens(self) -> int: + """ + @brief Valid number of tokens for this instruction. + + @return Valid number of tokens. + """ + return self._get_num_tokens() + + def _validate_tokens(self, tokens: list) -> None: + """ + @brief Validates the tokens for this instruction. + + DInstruction allows at least the required number of tokens. + + @param tokens List of tokens to validate. + @throws ValueError If tokens are invalid. + """ + assert self.name_token_index < self.num_tokens + if len(tokens) < self.num_tokens: + raise ValueError( + f"`tokens`: invalid amount of tokens. " + f"Instruction {self.name} requires at least {self.num_tokens}, but {len(tokens)} received" + ) + if tokens[self.name_token_index] != self.name: + raise ValueError( + f"`tokens`: invalid name. Expected {self.name}, but {tokens[self.name_token_index]} received" + ) + + def __init__(self, tokens: list, comment: str = ""): + """ + @brief Constructs a new DInstruction. + + @param tokens List of tokens for the instruction. + @param comment Optional comment for the instruction. + """ + # Do not increment the global instruction count; skip BaseInstruction's __init__ logic for __id + # Call BaseInstruction constructor but perform our own initialization + super().__init__(tokens, comment=comment) + + self.comment = comment + self._tokens = list(tokens) + self._local_id = next(DInstruction._local_id_count) + + @property + def id(self): + """ + @brief Unique ID for the instruction. + + This is a combination of the client ID specified during construction and a unique nonce per instruction. + + @return (client_id: int, nonce: int) where client_id is the id specified at construction. + """ + return self._local_id + + @property + def var(self) -> str: + """ + @brief Name of source/dest var. + + @return The variable name. + """ + return self._var + + @var.setter + def var(self, value: str): + """ + @brief Sets the variable name. + + @param value The new variable name. + """ + self._var = value + + @property + def address(self) -> int: + """ + @brief Should be set to source/dest Mem address. + + @return The memory address. + """ + return self._address + + @address.setter + def address(self, value: int): + """ + @brief Sets the memory address. + + @param value The new memory address (string or integer). + """ + self._address = value diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dkeygen.py b/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dkeygen.py new file mode 100644 index 00000000..3980fdf2 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dkeygen.py @@ -0,0 +1,37 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +# These contents may have been developed with support from one or more Intel-operated +# generative artificial intelligence solutions + +""" +@brief This module implements the DKeyGen instruction for key generation operations. +""" + +from assembler.memory_model.mem_info import MemInfo + +from .dinstruction import DInstruction + + +class Instruction(DInstruction): + """ + @brief Encapsulates a `dkeygen` DInstruction. + """ + + @classmethod + def _get_num_tokens(cls) -> int: + """ + @brief Gets the number of tokens required for the instruction. + + @return The number of tokens, which is 4. + """ + return 4 + + @classmethod + def _get_name(cls) -> str: + """ + @brief Gets the name of the instruction. + + @return The name of the instruction. + """ + return MemInfo.Const.Keyword.KEYGEN diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dload.py b/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dload.py new file mode 100644 index 00000000..e0902735 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dload.py @@ -0,0 +1,48 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +# These contents may have been developed with support from one or more Intel-operated +# generative artificial intelligence solutions + +""" +@brief This module implements the DLoad instruction for loading data from memory. + +The DLoad instruction is used to load data from specified memory locations +during the assembly process. +""" + +from assembler.memory_model.mem_info import MemInfo +from .dinstruction import DInstruction + + +class Instruction(DInstruction): + """ + @brief Encapsulates a `dload` DInstruction. + """ + + @classmethod + def _get_num_tokens(cls) -> int: + """ + @brief Gets the number of tokens required for the instruction. + + @return The number of tokens, which is 3. + """ + return 3 + + @classmethod + def _get_name(cls) -> str: + """ + @brief Gets the name of the instruction. + + @return The name of the instruction. + """ + return MemInfo.Const.Keyword.LOAD + + @property + def tokens(self) -> list: + """ + @brief Gets the list of tokens for the instruction. + + @return The list of tokens. + """ + return [self.name, self._tokens[1], str(self.address)] + self._tokens[3:] diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dstore.py b/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dstore.py new file mode 100644 index 00000000..61d99546 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/linker/instructions/dinst/dstore.py @@ -0,0 +1,48 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +# These contents may have been developed with support from one or more Intel-operated +# generative artificial intelligence solutions + +""" +@brief This module implements the DStore instruction for storing data to memory. + +The DStore instruction is used to store data to specified memory locations +during the assembly process. +""" + +from assembler.memory_model.mem_info import MemInfo +from .dinstruction import DInstruction + + +class Instruction(DInstruction): + """ + @brief Encapsulates a `dstore` DInstruction. + """ + + @classmethod + def _get_num_tokens(cls) -> int: + """ + @brief Gets the number of tokens required for the instruction. + + @return The number of tokens, which is 3. + """ + return 3 + + @classmethod + def _get_name(cls) -> str: + """ + @brief Gets the name of the instruction. + + @return The name of the instruction. + """ + return MemInfo.Const.Keyword.STORE + + @property + def tokens(self) -> list: + """ + @brief Gets the list of tokens for the instruction. + + @return The list of tokens. + """ + return [self.name, self.var, str(self.address)] + self._tokens[3:] diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/instruction.py b/assembler_tools/hec-assembler-tools/linker/instructions/instruction.py index ad029607..7a0ff43a 100644 --- a/assembler_tools/hec-assembler-tools/linker/instructions/instruction.py +++ b/assembler_tools/hec-assembler-tools/linker/instructions/instruction.py @@ -1,133 +1,156 @@ -from assembler.common.decorators import * +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +# These contents may have been developed with support from one or more Intel-operated +# generative artificial intelligence solutions + +""" +@brief Base class for all instructions in the linker. +""" + +from assembler.common.decorators import classproperty from assembler.common.counter import Counter +from assembler.common.config import GlobalConfig + class BaseInstruction: """ - Base class for all instructions. + @brief Base class for all instructions. This class provides common functionality for all instructions in the linker. - Class Properties: - name (str): Returns the name of the represented operation. - - Attributes: - comment (str): Comment for the instruction. + @var comment Comment for the instruction. - Properties: - tokens (list[str]): List of tokens for the instruction. - id (int): Unique instruction ID. This is a unique nonce representing the instruction. + @property name Returns the name of the represented operation. + @property tokens List of tokens for the instruction. + @property id Unique instruction ID. This is a unique nonce representing the instruction. - Methods: - to_line(self) -> str: - Retrieves the string form of the instruction to write to the instruction file. + @fn to_line Retrieves the string form of the instruction to write to the instruction file. """ - __id_count = Counter.count(0) # Internal unique sequence counter to generate unique IDs + __id_count = Counter.count( + 0 + ) # Internal unique sequence counter to generate unique IDs # Class methods and properties # ---------------------------- @classproperty - def name(cls) -> str: + def name(self) -> str: """ - Name for the instruction. + @brief Name for the instruction. - Returns: - str: The name of the instruction. + @return The name of the instruction. """ - return cls._get_name() + return self._get_name() @classmethod def _get_name(cls) -> str: """ - Derived classes should implement this method and return correct + @brief Derived classes should implement this method and return correct name for the instruction. - Raises: - NotImplementedError: Abstract method. This base method should not be called. + @throws NotImplementedError Abstract method. This base method should not be called. """ raise NotImplementedError() @classproperty - def NAME_TOKEN_INDEX(cls) -> int: + def name_token_index(self) -> int: """ - Index for the token containing the name of the instruction + @brief Index for the token containing the name of the instruction in the list of tokens. - Returns: - int: The index of the name token. + @return The index of the name token. """ - return cls._get_name_token_index() + return self._get_name_token_index() @classmethod def _get_name_token_index(cls) -> int: """ - Derived classes should implement this method and return correct + @brief Derived classes should implement this method and return correct index for the token containing the name of the instruction in the list of tokens. - Raises: - NotImplementedError: Abstract method. This base method should not be called. + @throws NotImplementedError Abstract method. This base method should not be called. """ raise NotImplementedError() @classproperty - def NUM_TOKENS(cls) -> int: + def num_tokens(self) -> int: """ - Number of tokens required for this instruction. + @brief Number of tokens required for this instruction. - Returns: - int: The number of tokens required. + @return The number of tokens required. """ - return cls._get_num_tokens() + return self._get_num_tokens() @classmethod def _get_num_tokens(cls) -> int: """ - Derived classes should implement this method and return correct + @brief Derived classes should implement this method and return correct required number of tokens for the instruction. - Raises: - NotImplementedError: Abstract method. This base method should not be called. + @throws NotImplementedError Abstract method. This base method should not be called. """ raise NotImplementedError() + @classmethod + def dump_instructions_to_file(cls, instructions: list, filename: str): + """ + @brief Writes a list of instruction objects to a file, one per line. + + Each instruction is converted to its string representation using the `to_line()` method. + + @param instructions List of instruction objects (must have a to_line() method). + @param filename Path to the output file. + """ + with open(filename, "w", encoding="utf-8") as f: + for instr in instructions: + f.write(instr.to_line() + "\n") + # Constructor # ----------- def __init__(self, tokens: list, comment: str = ""): """ - Creates a new BaseInstruction object. - - Parameters: - tokens (list): List of tokens for the instruction. - comment (str): Optional comment for the instruction. + @brief Creates a new BaseInstruction object. - Raises: - ValueError: If the number of tokens is invalid or the instruction name is incorrect. + @param tokens List of tokens for the instruction. + @param comment Optional comment for the instruction. + @throws ValueError If the number of tokens is invalid or the instruction name is incorrect. """ - assert self.NAME_TOKEN_INDEX < self.NUM_TOKENS + assert self.name_token_index < self.num_tokens - if len(tokens) != self.NUM_TOKENS: - raise ValueError(('`tokens`: invalid amount of tokens. ' - 'Instruction {} requires {}, but {} received').format(self.name, - self.NUM_TOKENS, - len(tokens))) - if tokens[self.NAME_TOKEN_INDEX] != self.name: - raise ValueError('`tokens`: invalid name. Expected {}, but {} received'.format(self.name, - tokens[self.NAME_TOKEN_INDEX])) + self._validate_tokens(tokens) - self.__id = next(BaseInstruction.__id_count) + self._id = next(BaseInstruction.__id_count) - self.__tokens = list(tokens) + self._tokens = list(tokens) self.comment = comment + def _validate_tokens(self, tokens: list) -> None: + """ + @brief Validates the tokens for this instruction. + + Default implementation checks for exact token count match. + Child classes can override this method to implement different validation logic. + + @param tokens List of tokens to validate. + @throws ValueError If tokens are invalid. + """ + if len(tokens) != self.num_tokens: # pylint: disable=W0143 + raise ValueError( + f"`tokens`: invalid amount of tokens. " + f"Instruction {self.name} requires exactly {self.num_tokens}, but {len(tokens)} received" + ) + + if tokens[self.name_token_index] != self.name: # pylint: disable=W0143 + raise ValueError( + f"`tokens`: invalid name. Expected {self.name}, but {tokens[self.name_token_index]} received" + ) + def __repr__(self): - retval = ('<{}({}, id={}) object at {}>(tokens={})').format(type(self).__name__, - self.name, - self.id, - hex(id(self)), - self.token) + retval = f"<{type(self).__name__}({self.name}, id={self.id}) object at {hex(id(self))}>(tokens={self.tokens})" return retval def __eq__(self, other): @@ -138,7 +161,7 @@ def __hash__(self): return hash(self.id) def __str__(self): - return f'{self.name}({self.id})' + return f"{self.name}({self.id})" # Methods and properties # ---------------------------- @@ -146,30 +169,32 @@ def __str__(self): @property def id(self) -> tuple: """ - Unique ID for the instruction. + @brief Unique ID for the instruction. This is a combination of the client ID specified during construction and a unique nonce per instruction. - Returns: - tuple: (client_id: int, nonce: int) where client_id is the id specified at construction. + @return (client_id: int, nonce: int) where client_id is the id specified at construction. """ - return self.__id + return self._id @property def tokens(self) -> list: """ - Gets the list of tokens for the instruction. + @brief Gets the list of tokens for the instruction. - Returns: - list: The list of tokens. + @return The list of tokens. """ - return self.__tokens + return self._tokens def to_line(self) -> str: """ - Retrieves the string form of the instruction to write to the instruction file. + @brief Retrieves the string form of the instruction to write to the instruction file. - Returns: - str: The string representation of the instruction. + @return The string representation of the instruction. """ - return ", ".join(self.tokens) \ No newline at end of file + comment_str = "" + if not GlobalConfig.suppress_comments: + comment_str = f" # {self.comment}" if self.comment else "" + + tokens_str = ", ".join(self.tokens) + return f"{tokens_str}{comment_str}" diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/minst/minstruction.py b/assembler_tools/hec-assembler-tools/linker/instructions/minst/minstruction.py index 00d40452..bb00f61f 100644 --- a/assembler_tools/hec-assembler-tools/linker/instructions/minst/minstruction.py +++ b/assembler_tools/hec-assembler-tools/linker/instructions/minst/minstruction.py @@ -1,11 +1,28 @@ - +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +# pylint: disable=duplicate-code +"""This module implements the base class for MInstructions.""" + from linker.instructions.instruction import BaseInstruction + class MInstruction(BaseInstruction): """ Represents an MInstruction, inheriting from BaseInstruction. """ + @classmethod + def _get_name(cls) -> str: + """ + Derived classes should implement this method and return correct + name for the instruction. + + Raises: + NotImplementedError: Abstract method. This base method should not be called. + """ + raise NotImplementedError() + @classmethod def _get_name_token_index(cls) -> int: """ @@ -16,6 +33,17 @@ def _get_name_token_index(cls) -> int: """ return 1 + @classmethod + def _get_num_tokens(cls) -> int: + """ + Derived classes should implement this method and return correct + required number of tokens for the instruction. + + Raises: + NotImplementedError: Abstract method. This base method should not be called. + """ + raise NotImplementedError() + # Constructor # ----------- @@ -39,4 +67,4 @@ def to_line(self) -> str: Returns: str: The string representation of the instruction, excluding the first token. """ - return ", ".join(self.tokens[1:]) \ No newline at end of file + return ", ".join(self.tokens[1:]) diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/xinst/xinstruction.py b/assembler_tools/hec-assembler-tools/linker/instructions/xinst/xinstruction.py index 3a622539..43b2b9b4 100644 --- a/assembler_tools/hec-assembler-tools/linker/instructions/xinst/xinstruction.py +++ b/assembler_tools/hec-assembler-tools/linker/instructions/xinst/xinstruction.py @@ -1,11 +1,27 @@ - +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""This module implements the XInstruction""" + from linker.instructions.instruction import BaseInstruction + class XInstruction(BaseInstruction): """ Represents an XInstruction, inheriting from BaseInstruction. """ + @classmethod + def _get_name(cls) -> str: + """ + Derived classes should implement this method and return correct + name for the instruction. + + Raises: + NotImplementedError: Abstract method. This base method should not be called. + """ + raise NotImplementedError() + @classmethod def _get_name_token_index(cls) -> int: """ @@ -17,6 +33,17 @@ def _get_name_token_index(cls) -> int: # Name at index 2. return 2 + @classmethod + def _get_num_tokens(cls) -> int: + """ + Derived classes should implement this method and return correct + required number of tokens for the instruction. + + Raises: + NotImplementedError: Abstract method. This base method should not be called. + """ + raise NotImplementedError() + # Constructor # ----------- @@ -44,7 +71,7 @@ def bundle(self) -> int: Raises: RuntimeError: If the bundle format is invalid. """ - if len(self.tokens[0]) < 2 or self.tokens[0][0] != 'F': + if len(self.tokens[0]) < 2 or self.tokens[0][0] != "F": raise RuntimeError(f'Invalid bundle format detected: "{self.tokens[0]}".') return int(self.tokens[0][1:]) @@ -60,5 +87,7 @@ def bundle(self, value: int): ValueError: If the value is negative. """ if value < 0: - raise ValueError(f'`value`: expected non-negative bundle index, but {value} received.') - self.tokens[0] = f'F{value}' \ No newline at end of file + raise ValueError( + f"`value`: expected non-negative bundle index, but {value} received." + ) + self.tokens[0] = f"F{value}" diff --git a/assembler_tools/hec-assembler-tools/linker/loader.py b/assembler_tools/hec-assembler-tools/linker/loader.py index 26894912..eeee007e 100644 --- a/assembler_tools/hec-assembler-tools/linker/loader.py +++ b/assembler_tools/hec-assembler-tools/linker/loader.py @@ -1,124 +1,144 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +# These contents may have been developed with support from one +# or more Intel-operated generative artificial intelligence solutions + +""" +@brief This module provides functionality to load different types of instruction kernels +""" + from linker.instructions import minst from linker.instructions import cinst from linker.instructions import xinst +from linker.instructions import dinst from linker import instructions -def loadMInstKernel(line_iter) -> list: - """ - Loads MInstruction kernel from an iterator of lines. - - Parameters: - line_iter: An iterator over lines of MInstruction strings. - Returns: - list: A list of MInstruction objects. +def load_minst_kernel(line_iter) -> list: + """ + @brief Loads MInstruction kernel from an iterator of lines. - Raises: - RuntimeError: If a line cannot be parsed into an MInstruction. + @param line_iter An iterator over lines of MInstruction strings. + @return A list of MInstruction objects. + @throws RuntimeError If a line cannot be parsed into an MInstruction. """ retval = [] for idx, s_line in enumerate(line_iter): - minstr = instructions.fromStrLine(s_line, minst.factory()) + minstr = instructions.create_from_str_line(s_line, minst.factory()) if not minstr: - raise RuntimeError(f'Error parsing line {idx + 1}: {s_line}') + raise RuntimeError(f"Error parsing line {idx + 1}: {s_line}") retval.append(minstr) return retval -def loadMInstKernelFromFile(filename: str) -> list: - """ - Loads MInstruction kernel from a file. - - Parameters: - filename (str): The file containing MInstruction strings. - Returns: - list: A list of MInstruction objects. +def load_minst_kernel_from_file(filename: str) -> list: + """ + @brief Loads MInstruction kernel from a file. - Raises: - RuntimeError: If an error occurs while loading the file. + @param filename The file containing MInstruction strings. + @return A list of MInstruction objects. + @throws RuntimeError If an error occurs while loading the file. """ - with open(filename, 'r') as kernel_minsts: + with open(filename, "r", encoding="utf-8") as kernel_minsts: try: - return loadMInstKernel(kernel_minsts) + return load_minst_kernel(kernel_minsts) except Exception as e: raise RuntimeError(f'Error occurred loading file "{filename}"') from e -def loadCInstKernel(line_iter) -> list: - """ - Loads CInstruction kernel from an iterator of lines. - - Parameters: - line_iter: An iterator over lines of CInstruction strings. - Returns: - list: A list of CInstruction objects. +def load_cinst_kernel(line_iter) -> list: + """ + @brief Loads CInstruction kernel from an iterator of lines. - Raises: - RuntimeError: If a line cannot be parsed into a CInstruction. + @param line_iter An iterator over lines of CInstruction strings. + @return A list of CInstruction objects. + @throws RuntimeError If a line cannot be parsed into a CInstruction. """ retval = [] for idx, s_line in enumerate(line_iter): - cinstr = instructions.fromStrLine(s_line, cinst.factory()) + cinstr = instructions.create_from_str_line(s_line, cinst.factory()) if not cinstr: - raise RuntimeError(f'Error parsing line {idx + 1}: {s_line}') + raise RuntimeError(f"Error parsing line {idx + 1}: {s_line}") retval.append(cinstr) return retval -def loadCInstKernelFromFile(filename: str) -> list: - """ - Loads CInstruction kernel from a file. - - Parameters: - filename (str): The file containing CInstruction strings. - Returns: - list: A list of CInstruction objects. +def load_cinst_kernel_from_file(filename: str) -> list: + """ + @brief Loads CInstruction kernel from a file. - Raises: - RuntimeError: If an error occurs while loading the file. + @param filename The file containing CInstruction strings. + @return A list of CInstruction objects. + @throws RuntimeError If an error occurs while loading the file. """ - with open(filename, 'r') as kernel_cinsts: + with open(filename, "r", encoding="utf-8") as kernel_cinsts: try: - return loadCInstKernel(kernel_cinsts) + return load_cinst_kernel(kernel_cinsts) except Exception as e: raise RuntimeError(f'Error occurred loading file "{filename}"') from e -def loadXInstKernel(line_iter) -> list: - """ - Loads XInstruction kernel from an iterator of lines. - - Parameters: - line_iter: An iterator over lines of XInstruction strings. - Returns: - list: A list of XInstruction objects. +def load_xinst_kernel(line_iter) -> list: + """ + @brief Loads XInstruction kernel from an iterator of lines. - Raises: - RuntimeError: If a line cannot be parsed into an XInstruction. + @param line_iter An iterator over lines of XInstruction strings. + @return A list of XInstruction objects. + @throws RuntimeError If a line cannot be parsed into an XInstruction. """ retval = [] for idx, s_line in enumerate(line_iter): - xinstr = instructions.fromStrLine(s_line, xinst.factory()) + xinstr = instructions.create_from_str_line(s_line, xinst.factory()) if not xinstr: - raise RuntimeError(f'Error parsing line {idx + 1}: {s_line}') + raise RuntimeError(f"Error parsing line {idx + 1}: {s_line}") retval.append(xinstr) return retval -def loadXInstKernelFromFile(filename: str) -> list: + +def load_xinst_kernel_from_file(filename: str) -> list: """ - Loads XInstruction kernel from a file. + @brief Loads XInstruction kernel from a file. - Parameters: - filename (str): The file containing XInstruction strings. + @param filename The file containing XInstruction strings. + @return A list of XInstruction objects. + @throws RuntimeError If an error occurs while loading the file. + """ + with open(filename, "r", encoding="utf-8") as kernel_xinsts: + try: + return load_xinst_kernel(kernel_xinsts) + except Exception as e: + raise RuntimeError(f'Error occurred loading file "{filename}"') from e - Returns: - list: A list of XInstruction objects. - Raises: - RuntimeError: If an error occurs while loading the file. +def load_dinst_kernel(line_iter) -> list: """ - with open(filename, 'r') as kernel_xinsts: + @brief Loads DInstruction kernel from an iterator of lines. + + @param line_iter An iterator over lines of DInstruction strings. + @return A list of DInstruction objects. + @throws RuntimeError If a line cannot be parsed into an DInstruction. + """ + retval = [] + for idx, s_line in enumerate(line_iter): + dinstr = dinst.create_from_mem_line(s_line) + if not dinstr: + raise RuntimeError(f"Error parsing line {idx + 1}: {s_line}") + retval.append(dinstr) + + return retval + + +def load_dinst_kernel_from_file(filename: str) -> list: + """ + @brief Loads DInstruction kernel from a file. + + @param filename The file containing DInstruction strings. + @return A list of DInstruction objects. + @throws RuntimeError If an error occurs while loading the file. + """ + with open(filename, "r", encoding="utf-8") as kernel_dinsts: try: - return loadXInstKernel(kernel_xinsts) + return load_dinst_kernel(kernel_dinsts) except Exception as e: - raise RuntimeError(f'Error occurred loading file "{filename}"') from e \ No newline at end of file + raise RuntimeError(f'Error occurred loading file "{filename}"') from e diff --git a/assembler_tools/hec-assembler-tools/linker/steps/program_linker.py b/assembler_tools/hec-assembler-tools/linker/steps/program_linker.py index 6f6c1999..4096bd9e 100644 --- a/assembler_tools/hec-assembler-tools/linker/steps/program_linker.py +++ b/assembler_tools/hec-assembler-tools/linker/steps/program_linker.py @@ -1,12 +1,23 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# These contents may have been developed with support from one or more Intel-operated +# generative artificial intelligence solutions + +"""@brief This module provides functionality to link kernels into a program.""" + +from typing import Dict, Any, cast from linker import MemoryModel -from linker.instructions import minst, cinst, xinst +from linker.instructions import minst, cinst, dinst +from linker.instructions.dinst.dinstruction import DInstruction from assembler.common.config import GlobalConfig from assembler.instructions import cinst as ISACInst -class LinkedProgram: + +class LinkedProgram: # pylint: disable=too-many-instance-attributes """ - Encapsulates a linked program. + @class LinkedProgram + @brief Encapsulates a linked program. This class offers facilities to track and link kernels, and outputs the linked program to specified output streams as kernels @@ -15,338 +26,473 @@ class LinkedProgram: The program itself is not contained in this object. """ - def __init__(self, - program_minst_ostream, - program_cinst_ostream, - program_xinst_ostream, - mem_model: MemoryModel, - supress_comments: bool): - """ - Initializes a LinkedProgram object. - - Parameters: - program_minst_ostream: Output stream for MInst instructions. - program_cinst_ostream: Output stream for CInst instructions. - program_xinst_ostream: Output stream for XInst instructions. - mem_model (MemoryModel): Correctly initialized linker memory model. It must already contain the - variables used throughout the program and their usage. - This memory model will be modified by this object when linking kernels. - supress_comments (bool): Whether to suppress comments in the output. - """ - self.__minst_ostream = program_minst_ostream - self.__cinst_ostream = program_cinst_ostream - self.__xinst_ostream = program_xinst_ostream - self.__mem_model = mem_model - self.__supress_comments = supress_comments - self.__bundle_offset = 0 - self.__minst_line_offset = 0 - self.__cinst_line_offset = 0 - self.__xinst_line_offset = 0 - self.__kernel_count = 0 # Number of kernels linked into this program - self.__is_open = True # Tracks whether this program is still accepting kernels to link - - @property - def isOpen(self) -> bool: + def __init__( + self, + program_minst_ostream, + program_cinst_ostream, + program_xinst_ostream, + mem_model: MemoryModel, + ): """ - Checks if the program is open for linking new kernels. - - Returns: - bool: True if the program is open, False otherwise. + @brief Initializes a LinkedProgram object. + + @param program_minst_ostream Output stream for MInst instructions. + @param program_cinst_ostream Output stream for CInst instructions. + @param program_xinst_ostream Output stream for XInst instructions. + @param mem_model (MemoryModel): Correctly initialized linker memory model. It must already contain the + variables used throughout the program and their usage. + This memory model will be modified by this object when linking kernels. + @param suppress_comments (bool): Whether to suppress comments in the output. """ - return self.__is_open + self._minst_ostream = program_minst_ostream + self._cinst_ostream = program_cinst_ostream + self._xinst_ostream = program_xinst_ostream + self.__mem_model = mem_model + self._bundle_offset = 0 + self._minst_line_offset = 0 + self._cinst_line_offset = 0 + self._kernel_count = 0 # Number of kernels linked into this program + self._is_open = ( + True # Tracks whether this program is still accepting kernels to link + ) @property - def supressComments(self) -> bool: + def is_open(self) -> bool: """ - Checks if comments are suppressed in the output. + @brief Checks if the program is open for linking new kernels. - Returns: - bool: True if comments are suppressed, False otherwise. + @return bool True if the program is open, False otherwise. """ - return self.__supress_comments + return self._is_open def close(self): """ - Completes the program by terminating the queues with the correct exit code. + @brief Completes the program by terminating the queues with the correct exit code. Program will not accept new kernels to link after this call. - Raises: - RuntimeError: If the program is already closed. + @exception RuntimeError If the program is already closed. """ - if not self.isOpen: - raise RuntimeError('Program is already closed.') + if not self.is_open: + raise RuntimeError("Program is already closed.") # Add closing `cexit` - tokens = [str(self.__cinst_line_offset), cinst.CExit.name] + tokens = [str(self._cinst_line_offset), cinst.CExit.name] cexit_cinstr = cinst.CExit(tokens) - print(f'{cexit_cinstr.tokens[0]}, {cexit_cinstr.to_line()}', file=self.__cinst_ostream) + print( + f"{cexit_cinstr.tokens[0]}, {cexit_cinstr.to_line()}", + file=self._cinst_ostream, + ) # Add closing msyncc - tokens = [str(self.__minst_line_offset), minst.MSyncc.name, str(self.__cinst_line_offset + 1)] + tokens = [ + str(self._minst_line_offset), + minst.MSyncc.name, + str(self._cinst_line_offset + 1), + ] cmsyncc_minstr = minst.MSyncc(tokens) - print(f'{cmsyncc_minstr.tokens[0]}, {cmsyncc_minstr.to_line()}', end="", file=self.__minst_ostream) - if not self.supressComments: - print(' # terminating MInstQ', end="", file=self.__minst_ostream) - print(file=self.__minst_ostream) + print( + f"{cmsyncc_minstr.tokens[0]}, {cmsyncc_minstr.to_line()}", + end="", + file=self._minst_ostream, + ) + if not GlobalConfig.suppress_comments: + print(" # terminating MInstQ", end="", file=self._minst_ostream) + print(file=self._minst_ostream) # Program has been closed - self.__is_open = False + self._is_open = False - def __validateHBMAddress(self, var_name: str, hbm_address: int): + def _validate_hbm_address(self, var_name: str, hbm_address: int): """ - Validates the HBM address for a variable. + @brief Validates the HBM address for a variable. - Parameters: - var_name (str): The name of the variable. - hbm_address (int): The HBM address to validate. + @param var_name The name of the variable. + @param hbm_address The HBM address to validate. - Raises: - RuntimeError: If the HBM address is invalid or does not match the declared address. + @exception RuntimeError If the HBM address is invalid or does not match the declared address. """ if hbm_address < 0: - raise RuntimeError(f'Invalid negative HBM address for variable "{var_name}".') + raise RuntimeError( + f'Invalid negative HBM address for variable "{var_name}".' + ) if var_name in self.__mem_model.mem_info_vars: - if self.__mem_model.mem_info_vars[var_name].hbm_address != hbm_address: - raise RuntimeError(('Declared HBM address ({}) of mem Variable "{}"' - ' differs from allocated HBM address ({}).').format(self.__mem_model.mem_info_vars[var_name].hbm_address, - var_name, - hbm_address)) + # Cast to dictionary to fix the indexing error + mem_info_vars_dict = cast(Dict[str, Any], self.__mem_model.mem_info_vars) + if mem_info_vars_dict[var_name].hbm_address != hbm_address: + raise RuntimeError( + ( + f"Declared HBM address " + f"({mem_info_vars_dict[var_name].hbm_address})" + f" of mem Variable '{var_name}'" + f" differs from allocated HBM address ({hbm_address})." + ) + ) + + def _validate_spad_address(self, var_name: str, spad_address: int): + """ + @brief Validates the SPAD address for a variable (only available when no HBM). + + @param var_name The name of the variable. + @param spad_address The SPAD address to validate. - def __validateSPADAddress(self, var_name: str, spad_address: int): + @exception RuntimeError If the SPAD address is invalid or does not match the declared address. + """ # only available when no HBM assert not GlobalConfig.hasHBM # this method will validate the variable SPAD address against the - # original HBM address, since ther is no HBM + # original HBM address, since there is no HBM if spad_address < 0: - raise RuntimeError(f'Invalid negative SPAD address for variable "{var_name}".') + raise RuntimeError( + f'Invalid negative SPAD address for variable "{var_name}".' + ) if var_name in self.__mem_model.mem_info_vars: - if self.__mem_model.mem_info_vars[var_name].hbm_address != spad_address: - raise RuntimeError(('Declared HBM address ({}) of mem Variable "{}"' - ' differs from allocated HBM address ({}).').format(self.__mem_model.mem_info_vars[var_name].hbm_address, - var_name, - spad_address)) - - def __updateMInsts(self, kernel_minstrs: list): + # Cast to dictionary to fix the indexing error + mem_info_vars_dict = cast(Dict[str, Any], self.__mem_model.mem_info_vars) + if mem_info_vars_dict[var_name].hbm_address != spad_address: + raise RuntimeError( + ( + f"Declared HBM address" + f" ({mem_info_vars_dict[var_name].hbm_address})" + f" of mem Variable '{var_name}'" + f" differs from allocated HBM address ({spad_address})." + ) + ) + + def _update_minsts(self, kernel_minstrs: list): """ - Updates the MInsts in the kernel to offset to the current expected + @brief Updates the MInsts in the kernel to offset to the current expected synchronization points, and convert variable placeholders/names into the corresponding HBM address. All MInsts in the kernel are expected to synchronize with CInsts starting at line 0. - Does not change the `LinkedProgram` object. + Does not change the LinkedProgram object. - Parameters: - kernel_minstrs (list): List of MInstructions to update. + @param kernel_minstrs List of MInstructions to update. """ for minstr in kernel_minstrs: # Update msyncc if isinstance(minstr, minst.MSyncc): - minstr.target = minstr.target + self.__cinst_line_offset + minstr.target = minstr.target + self._cinst_line_offset # Change mload variable names into HBM addresses if isinstance(minstr, minst.MLoad): var_name = minstr.source - hbm_address = self.__mem_model.useVariable(var_name, self.__kernel_count) - self.__validateHBMAddress(var_name, hbm_address) + hbm_address = self.__mem_model.use_variable( + var_name, self._kernel_count + ) + self._validate_hbm_address(var_name, hbm_address) minstr.source = str(hbm_address) - minstr.comment = f" var: {var_name} - HBM({hbm_address})" + f";{minstr.comment}" if minstr.comment else "" + minstr.comment = ( + f" var: {var_name} - HBM({hbm_address})" + f";{minstr.comment}" + if minstr.comment + else "" + ) # Change mstore variable names into HBM addresses if isinstance(minstr, minst.MStore): var_name = minstr.dest - hbm_address = self.__mem_model.useVariable(var_name, self.__kernel_count) - self.__validateHBMAddress(var_name, hbm_address) + hbm_address = self.__mem_model.use_variable( + var_name, self._kernel_count + ) + self._validate_hbm_address(var_name, hbm_address) minstr.dest = str(hbm_address) - minstr.comment = f" var: {var_name} - HBM({hbm_address})" + f";{minstr.comment}" if minstr.comment else "" + minstr.comment = ( + f" var: {var_name} - HBM({hbm_address})" + f";{minstr.comment}" + if minstr.comment + else "" + ) - def __updateCInsts(self, kernel_cinstrs: list): + def _remove_and_merge_csyncm_cnop(self, kernel_cinstrs: list): """ - Updates the CInsts in the kernel to offset to the current expected bundle - and synchronization points. + @brief Remove csyncm instructions and merge consecutive cnop instructions. + + @param kernel_cinstrs List of CInstructions to process. + """ + i = 0 + current_bundle = 0 + csyncm_count = 0 + while i < len(kernel_cinstrs): + cinstr = kernel_cinstrs[i] + cinstr.tokens[0] = str(i) # Update the line number + + # ------------------------------ + # This code block will remove csyncm instructions and keep track, + # later adding their throughput into a cnop instruction before + # a new bundle is fetched. + + if isinstance(cinstr, cinst.CNop): + # Add the missing cycles to any cnop we encounter up to this point + cinstr.cycles += csyncm_count * ISACInst.CSyncm.get_throughput() + # Idle cycles to account for the csyncm have been added + csyncm_count = 0 + + if isinstance(cinstr, (cinst.IFetch, cinst.NLoad, cinst.BLoad)): + if csyncm_count > 0: + # Extra cycles needed before scheduling next bundle + # Subtract 1 because cnop n, waits for n+1 cycles + cinstr_nop = cinst.CNop( + [ + i, + cinst.CNop.name, + str(csyncm_count * ISACInst.CSyncm.get_throughput() - 1), + ] + ) + kernel_cinstrs.insert(i, cinstr_nop) + csyncm_count = 0 + i += 1 + if isinstance(cinstr, cinst.IFetch): + current_bundle = cinstr.bundle + 1 + # Update the line number + cinstr.tokens[0] = i + + if isinstance(cinstr, cinst.CSyncm): + # Remove instruction + kernel_cinstrs.pop(i) + if current_bundle > 0: + csyncm_count += 1 + else: + i += 1 + + # ------------------------------ + # This code block differs from previous in that csyncm instructions + # are replaced in place by cnops with the corresponding throughput. + # This may result in several continuous cnop instructions, so, + # the cnop merging code afterwards is needed to remove this side effect + # if contiguous cnops are not desired. + + # if isinstance(cinstr, cinst.IFetch): + # current_bundle = cinstr.bundle + 1 + # + # if isinstance(cinstr, cinst.CSyncm): + # # replace instruction by cnop + # kernel_cinstrs.pop(i) + # if current_bundle > 0: + # cinstr_nop = cinst.CNop([i, cinst.CNop.name, str(ISACInst.CSyncm.get_throughput())]) # Subtract 1 because cnop n, waits for n+1 cycles + # kernel_cinstrs.insert(i, cinstr_nop) + # + # i += 1 # next instruction + + # Merge continuous cnop + i = 0 + while i < len(kernel_cinstrs): + cinstr = kernel_cinstrs[i] + cinstr.tokens[0] = str(i) + if isinstance(cinstr, cinst.CNop): + # Do look ahead + if i + 1 < len(kernel_cinstrs): + if isinstance(kernel_cinstrs[i + 1], cinst.CNop): + # Add 1 because cnop n, waits for n+1 cycles + kernel_cinstrs[i + 1].cycles += cinstr.cycles + 1 + kernel_cinstrs.pop(i) + i -= 1 + i += 1 + + def _update_cinsts_addresses_and_offsets(self, kernel_cinstrs: list): + """ + @brief Updates bundle/target offsets and variable names to addresses for CInsts. All CInsts in the kernel are expected to start at bundle 0, and to synchronize with MInsts starting at line 0. - Does not change the `LinkedProgram` object. + Does not change the LinkedProgram object. - Parameters: - kernel_cinstrs (list): List of CInstructions to update. + @param kernel_cinstrs List of CInstructions to update. """ - - if not GlobalConfig.hasHBM: - # Remove csyncm instructions - i = 0 - current_bundle = 0 - csyncm_count = 0 # Used by 1st code block: plz remove if second code block ends up being the one used - while i < len(kernel_cinstrs): - cinstr = kernel_cinstrs[i] - cinstr.tokens[0] = i # Update the line number - - #------------------------------ - # This code block will remove csyncm instructions and keep track, - # later adding their throughput into a cnop instruction before - # a new bundle is fetched. - - if isinstance(cinstr, cinst.CNop): - # Add the missing cycles to any cnop we encounter up to this point - cinstr.cycles += (csyncm_count * ISACInst.CSyncm.get_throughput()) - csyncm_count = 0 # Idle cycles to account for the csyncm have been added - - if isinstance(cinstr, (cinst.IFetch, cinst.NLoad, cinst.BLoad)): - if csyncm_count > 0: - # Extra cycles needed before scheduling next bundle - cinstr_nop = cinst.CNop([i, cinst.CNop.name, str(csyncm_count * ISACInst.CSyncm.get_throughput() - 1)]) # Subtract 1 because cnop n, waits for n+1 cycles - kernel_cinstrs.insert(i, cinstr_nop) - csyncm_count = 0 # Idle cycles to account for the csyncm have been added - i += 1 - if isinstance(cinstr, cinst.IFetch): - current_bundle = cinstr.bundle + 1 - cinstr.tokens[0] = i # Update the line number - - if isinstance(cinstr, cinst.CSyncm): - # Remove instruction - kernel_cinstrs.pop(i) - if current_bundle > 0: - csyncm_count += 1 - else: - i += 1 # Next instruction - - #------------------------------ - # This code block differs from previous in that csyncm instructions - # are replaced in place by cnops with the corresponding throughput. - # This may result in several continuous cnop instructions, so, - # the cnop merging code afterwards is needed to remove this side effect - # if contiguous cnops are not desired. - - # if isinstance(cinstr, cinst.IFetch): - # current_bundle = cinstr.bundle + 1 - # - # if isinstance(cinstr, cinst.CSyncm): - # # replace instruction by cnop - # kernel_cinstrs.pop(i) - # if current_bundle > 0: - # cinstr_nop = cinst.CNop([i, cinst.CNop.name, str(ISACInst.CSyncm.get_throughput())]) # Subtract 1 because cnop n, waits for n+1 cycles - # kernel_cinstrs.insert(i, cinstr_nop) - # - # i += 1 # next instruction - - # Merge continuous cnop - i = 0 - while i < len(kernel_cinstrs): - cinstr = kernel_cinstrs[i] - cinstr.tokens[0] = i # Update the line number - - if isinstance(cinstr, cinst.CNop): - # Do look ahead - if i + 1 < len(kernel_cinstrs): - if isinstance(kernel_cinstrs[i + 1], cinst.CNop): - kernel_cinstrs[i + 1].cycles += (cinstr.cycles + 1) # Add 1 because cnop n, waits for n+1 cycles - kernel_cinstrs.pop(i) - i -= 1 - i += 1 - for cinstr in kernel_cinstrs: # Update ifetch if isinstance(cinstr, cinst.IFetch): - cinstr.bundle = cinstr.bundle + self.__bundle_offset + cinstr.bundle = cinstr.bundle + self._bundle_offset # Update xinstfetch if isinstance(cinstr, cinst.XInstFetch): - raise NotImplementedError('`xinstfetch` not currently supported by linker.') + raise NotImplementedError( + "`xinstfetch` not currently supported by linker." + ) # Update csyncm if isinstance(cinstr, cinst.CSyncm): - cinstr.target = cinstr.target + self.__minst_line_offset + cinstr.target = cinstr.target + self._minst_line_offset if not GlobalConfig.hasHBM: # update all SPAD instruction variable names to be SPAD addresses # change xload variable names into SPAD addresses - if isinstance(cinstr, (cinst.BLoad, cinst.BOnes, cinst.CLoad, cinst.NLoad)): + if isinstance( + cinstr, (cinst.BLoad, cinst.BOnes, cinst.CLoad, cinst.NLoad) + ): var_name = cinstr.source - hbm_address = self.__mem_model.useVariable(var_name, self.__kernel_count) - self.__validateSPADAddress(var_name, hbm_address) + hbm_address = self.__mem_model.use_variable( + var_name, self._kernel_count + ) + self._validate_spad_address(var_name, hbm_address) cinstr.source = str(hbm_address) - cinstr.comment = f" var: {var_name} - HBM({hbm_address})" + f";{cinstr.comment}" if cinstr.comment else "" + cinstr.comment = ( + f" var: {var_name} - HBM({hbm_address})" + f";{cinstr.comment}" + if cinstr.comment + else "" + ) if isinstance(cinstr, cinst.CStore): var_name = cinstr.dest - hbm_address = self.__mem_model.useVariable(var_name, self.__kernel_count) - self.__validateSPADAddress(var_name, hbm_address) + hbm_address = self.__mem_model.use_variable( + var_name, self._kernel_count + ) + self._validate_spad_address(var_name, hbm_address) cinstr.dest = str(hbm_address) - cinstr.comment = f" var: {var_name} - HBM({hbm_address})" + f";{cinstr.comment}" if cinstr.comment else "" + cinstr.comment = ( + f" var: {var_name} - HBM({hbm_address})" + f";{cinstr.comment}" + if cinstr.comment + else "" + ) + + def _update_cinsts(self, kernel_cinstrs: list): + """ + @brief Updates the CInsts in the kernel to offset to the current expected bundle + and synchronization points. - def __updateXInsts(self, kernel_xinstrs: list) -> int: + All CInsts in the kernel are expected to start at bundle 0, and to + synchronize with MInsts starting at line 0. + Does not change the LinkedProgram object. + + @param kernel_cinstrs List of CInstructions to update. """ - Updates the XInsts in the kernel to offset to the current expected bundle. + if not GlobalConfig.hasHBM: + self._remove_and_merge_csyncm_cnop(kernel_cinstrs) + + self._update_cinsts_addresses_and_offsets(kernel_cinstrs) + + def _update_xinsts(self, kernel_xinstrs: list) -> int: + """ + @brief Updates the XInsts in the kernel to offset to the current expected bundle. All XInsts in the kernel are expected to start at bundle 0. - Does not change the `LinkedProgram` object. + Does not change the LinkedProgram object. - Parameters: - kernel_xinstrs (list): List of XInstructions to update. + @param kernel_xinstrs List of XInstructions to update. - Returns: - int: The last bundle number after updating. + @return int The last bundle number after updating. """ - last_bundle = self.__bundle_offset + last_bundle = self._bundle_offset for xinstr in kernel_xinstrs: - xinstr.bundle = xinstr.bundle + self.__bundle_offset + xinstr.bundle = xinstr.bundle + self._bundle_offset if last_bundle > xinstr.bundle: - raise RuntimeError(f'Detected invalid bundle. Instruction bundle is less than previous: "{xinstr.to_line()}"') + raise RuntimeError( + f'Detected invalid bundle. Instruction bundle is less than previous: "{xinstr.to_line()}"' + ) last_bundle = xinstr.bundle return last_bundle - def linkKernel(self, - kernel_minstrs: list, - kernel_cinstrs: list, - kernel_xinstrs: list): + def link_kernel( + self, kernel_minstrs: list, kernel_cinstrs: list, kernel_xinstrs: list + ): """ - Links a specified kernel (given by its three instruction queues) into this + @brief Links a specified kernel (given by its three instruction queues) into this program. The adjusted kernels will be appended into the output streams specified during construction of this object. - Parameters: - kernel_minstrs (list): List of MInstructions for the MInst Queue corresponding to the kernel to link. - These instructions will be modified by this method. - kernel_cinstrs (list): List of CInstructions for the CInst Queue corresponding to the kernel to link. - These instructions will be modified by this method. - kernel_xinstrs (list): List of XInstructions for the XInst Queue corresponding to the kernel to link. - These instructions will be modified by this method. + @param kernel_minstrs List of MInstructions for the MInst Queue corresponding to the kernel to link. + These instructions will be modified by this method. + @param kernel_cinstrs List of CInstructions for the CInst Queue corresponding to the kernel to link. + These instructions will be modified by this method. + @param kernel_xinstrs List of XInstructions for the XInst Queue corresponding to the kernel to link. + These instructions will be modified by this method. - Raises: - RuntimeError: If the program is closed and does not accept new kernels. + @exception RuntimeError If the program is closed and does not accept new kernels. """ - if not self.isOpen: - raise RuntimeError('Program is closed and does not accept new kernels.') + if not self.is_open: + raise RuntimeError("Program is closed and does not accept new kernels.") # No minsts without HBM if not GlobalConfig.hasHBM: kernel_minstrs = [] - self.__updateMInsts(kernel_minstrs) - self.__updateCInsts(kernel_cinstrs) - self.__bundle_offset = self.__updateXInsts(kernel_xinstrs) + 1 + self._update_minsts(kernel_minstrs) + self._update_cinsts(kernel_cinstrs) + self._bundle_offset = self._update_xinsts(kernel_xinstrs) + 1 # Append the kernel to the output for xinstr in kernel_xinstrs: - print(xinstr.to_line(), end="", file=self.__xinst_ostream) - if not self.supressComments and xinstr.comment: - print(f' #{xinstr.comment}', end="", file=self.__xinst_ostream) - print(file=self.__xinst_ostream) + print(xinstr.to_line(), end="", file=self._xinst_ostream) + if not GlobalConfig.suppress_comments and xinstr.comment: + print(f" #{xinstr.comment}", end="", file=self._xinst_ostream) + print(file=self._xinst_ostream) for idx, cinstr in enumerate(kernel_cinstrs[:-1]): # Skip the `cexit` - line_no = idx + self.__cinst_line_offset - print(f'{line_no}, {cinstr.to_line()}', end="", file=self.__cinst_ostream) - if not self.supressComments and cinstr.comment: - print(f' #{cinstr.comment}', end="", file=self.__cinst_ostream) - print(file=self.__cinst_ostream) + line_no = idx + self._cinst_line_offset + print(f"{line_no}, {cinstr.to_line()}", end="", file=self._cinst_ostream) + if not GlobalConfig.suppress_comments and cinstr.comment: + print(f" #{cinstr.comment}", end="", file=self._cinst_ostream) + print(file=self._cinst_ostream) for idx, minstr in enumerate(kernel_minstrs[:-1]): # Skip the exit `msyncc` - line_no = idx + self.__minst_line_offset - print(f'{line_no}, {minstr.to_line()}', end="", file=self.__minst_ostream) - if not self.supressComments and minstr.comment: - print(f' #{minstr.comment}', end="", file=self.__minst_ostream) - print(file=self.__minst_ostream) - - self.__minst_line_offset += (len(kernel_minstrs) - 1) # Subtract last line that is getting removed - self.__cinst_line_offset += (len(kernel_cinstrs) - 1) # Subtract last line that is getting removed - self.__kernel_count += 1 # Count the appended kernel \ No newline at end of file + line_no = idx + self._minst_line_offset + print(f"{line_no}, {minstr.to_line()}", end="", file=self._minst_ostream) + if not GlobalConfig.suppress_comments and minstr.comment: + print(f" #{minstr.comment}", end="", file=self._minst_ostream) + print(file=self._minst_ostream) + + self._minst_line_offset += ( + len(kernel_minstrs) - 1 + ) # Subtract last line that is getting removed + self._cinst_line_offset += ( + len(kernel_cinstrs) - 1 + ) # Subtract last line that is getting removed + self._kernel_count += 1 # Count the appended kernel + + @classmethod + def join_dinst_kernels( + cls, kernels_instrs: list[list[DInstruction]] + ) -> list[DInstruction]: + """ + @brief Joins a list of dinst kernels, consolidating variables that are outputs in one kernel + and inputs in the next. This ensures that variables carried across kernels are not duplicated, + and their Mem addresses are consistent. + + @param kernels_instrs List of Kernels' DInstructions lists. + + @return list[DInstruction] A new instruction list representing the concatenated memory info. + + @exception ValueError If no DInstructions lists are provided for concatenation. + """ + + if not kernels_instrs: + raise ValueError("No DInstructions lists provided for concatenation.") + + # Use dictionaries to track unique variables by name + inputs: dict[str, DInstruction] = {} + carry_over_vars: dict[str, DInstruction] = {} + + mem_address: int = 0 + new_kernels_instrs: list[DInstruction] = [] + for kernel_instrs in kernels_instrs: + for cur_dinst in kernel_instrs: + + # Save the current output instruction to add at the end + if isinstance(cur_dinst, dinst.DStore): + key = cur_dinst.var + carry_over_vars[key] = cur_dinst + continue + + if isinstance(cur_dinst, (dinst.DLoad, dinst.DKeyGen)): + key = cur_dinst.var + # Skip if the input is already in carry-over from previous outputs + if key in carry_over_vars: + carry_over_vars.pop( + key + ) # Remove from (output) carry-overs since it's now an input + continue + + # If the input is not (a previous output) in carry-over, add if it's not already (loaded) in inputs + if key not in inputs: + inputs[key] = cur_dinst + cur_dinst.address = mem_address + mem_address = mem_address + 1 + + new_kernels_instrs.append(cur_dinst) + continue + + # Add remaining carry-over variables to the new instructions + for _, var in carry_over_vars.items(): + var.address = mem_address + new_kernels_instrs.append(var) + mem_address = mem_address + 1 + + return new_kernels_instrs diff --git a/assembler_tools/hec-assembler-tools/linker/steps/variable_discovery.py b/assembler_tools/hec-assembler-tools/linker/steps/variable_discovery.py index 69457786..c862de25 100644 --- a/assembler_tools/hec-assembler-tools/linker/steps/variable_discovery.py +++ b/assembler_tools/hec-assembler-tools/linker/steps/variable_discovery.py @@ -1,26 +1,30 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +""" +@brief This module provides functionality to discover variable names in MInstructions and CInstructions. +""" from assembler.memory_model.variable import Variable from linker.instructions import minst, cinst from linker.instructions.minst.minstruction import MInstruction from linker.instructions.cinst.cinstruction import CInstruction -def discoverVariablesSPAD(cinstrs: list): + +def discover_variables_spad(cinstrs: list): """ - Finds Variable names used in a list of CInstructions. - - Attributes: - cinstrs (list[CInstruction]): - List of CInstructions where to find variable names. - Raises: - RuntimeError: - Invalid Variable name detected in an CInstruction. - Returns: - Iterable: - Yields an iterable over variable names identified in the listing + @brief Finds Variable names used in a list of CInstructions. + + @param cinstrs List of CInstructions where to find variable names. + @throws TypeError If an item in the list is not a valid CInstruction. + @throws RuntimeError If an invalid Variable name is detected in a CInstruction. + @return Yields an iterable over variable names identified in the listing of CInstructions specified. """ for idx, cinstr in enumerate(cinstrs): if not isinstance(cinstr, CInstruction): - raise TypeError(f'Item {idx} in list of MInstructions is not a valid MInstruction.') + raise TypeError( + f"Item {idx} in list of MInstructions is not a valid MInstruction." + ) retval = None # TODO: Implement variable counting for CInst ############### @@ -32,27 +36,27 @@ def discoverVariablesSPAD(cinstrs: list): if retval is not None: if not Variable.validateName(retval): - raise RuntimeError(f'Invalid Variable name "{retval}" detected in instruction "{idx}, {cinstr.to_line()}"') + raise RuntimeError( + f'Invalid Variable name "{retval}" detected in instruction "{idx}, {cinstr.to_line()}"' + ) yield retval -def discoverVariables(minstrs: list): - """ - Finds variable names used in a list of MInstructions. - Parameters: - minstrs (list[MInstruction]): List of MInstructions where to find variable names. - - Raises: - TypeError: If an item in the list is not a valid MInstruction. - RuntimeError: If an invalid variable name is detected in an MInstruction. +def discover_variables(minstrs: list): + """ + @brief Finds variable names used in a list of MInstructions. - Returns: - Iterable: Yields an iterable over variable names identified in the listing - of MInstructions specified. + @param minstrs List of MInstructions where to find variable names. + @throws TypeError If an item in the list is not a valid MInstruction. + @throws RuntimeError If an invalid variable name is detected in an MInstruction. + @return Yields an iterable over variable names identified in the listing + of MInstructions specified. """ for idx, minstr in enumerate(minstrs): if not isinstance(minstr, MInstruction): - raise TypeError(f'Item {idx} in list of MInstructions is not a valid MInstruction.') + raise TypeError( + f"Item {idx} in list of MInstructions is not a valid MInstruction." + ) retval = None if isinstance(minstr, minst.MLoad): retval = minstr.source @@ -61,5 +65,7 @@ def discoverVariables(minstrs: list): if retval is not None: if not Variable.validateName(retval): - raise RuntimeError(f'Invalid Variable name "{retval}" detected in instruction "{idx}, {minstr.to_line()}"') - yield retval \ No newline at end of file + raise RuntimeError( + f'Invalid Variable name "{retval}" detected in instruction "{idx}, {minstr.to_line()}"' + ) + yield retval diff --git a/assembler_tools/hec-assembler-tools/tests/__init__.py b/assembler_tools/hec-assembler-tools/tests/__init__.py new file mode 100644 index 00000000..58bdd8eb --- /dev/null +++ b/assembler_tools/hec-assembler-tools/tests/__init__.py @@ -0,0 +1,4 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +# Empty init file to make the directory a Python package diff --git a/assembler_tools/hec-assembler-tools/tests/conftest.py b/assembler_tools/hec-assembler-tools/tests/conftest.py index 5f7b42e1..b8bc75ec 100644 --- a/assembler_tools/hec-assembler-tools/tests/conftest.py +++ b/assembler_tools/hec-assembler-tools/tests/conftest.py @@ -2,14 +2,46 @@ # SPDX-License-Identifier: Apache-2.0 """ -Pytest configuration and fixtures for assembler_tools tests. +@file conftest.py +@brief Configuration and fixtures for pytest """ import os +import sys +from unittest.mock import patch + import pytest + from assembler.spec_config.isa_spec import ISASpecConfig from assembler.spec_config.mem_spec import MemSpecConfig +# Remove any existing paths that might conflict +for path in list(sys.path): + if "hec-assembler-tools" in path: + sys.path.remove(path) + +# Get the absolute path to the repository root directory +# Structure: /path/to/repo/encrypted-computing-sdk/assembler_tools/linker_sdk/tests/conftest.py +current_dir = os.path.dirname(os.path.abspath(__file__)) +linker_sdk_dir = os.path.dirname(current_dir) # linker_sdk directory +assembler_tools_dir = os.path.dirname(linker_sdk_dir) # assembler_tools directory +repo_root = os.path.dirname(assembler_tools_dir) # encrypted-computing-sdk directory + +# Add the paths to sys.path +sys.path.insert(0, linker_sdk_dir) +sys.path.insert(0, assembler_tools_dir) +sys.path.insert(0, repo_root) + + +@pytest.fixture(autouse=True) +def mock_env_variables(): + """ + @brief Fixture to mock environment variables and provide common mocks + """ + # Use the repository root in PYTHONPATH instead of an absolute path + with patch.dict("os.environ", {"PYTHONPATH": repo_root}): + yield + @pytest.fixture(scope="session", autouse=True) def initialize_specs(): @@ -22,7 +54,7 @@ def initialize_specs(): initialization methods for both ISASpecConfig and MemSpecConfig. Note: - This fixture is intended to be run from the assembler root directory. + This fixture is intended to be run from any location. Yields: None @@ -31,6 +63,6 @@ def initialize_specs(): Any exceptions raised by ISASpecConfig.initialize_isa_spec or MemSpecConfig.initialize_mem_spec will propagate. """ - module_dir = os.path.dirname(os.path.dirname(__file__)) + module_dir = linker_sdk_dir ISASpecConfig.initialize_isa_spec(module_dir, "") MemSpecConfig.initialize_mem_spec(module_dir, "") diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/__init__.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/__init__.py new file mode 100644 index 00000000..4057dc01 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/__init__.py @@ -0,0 +1,2 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_assembler/__init__.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_assembler/__init__.py new file mode 100644 index 00000000..4057dc01 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_assembler/__init__.py @@ -0,0 +1,2 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_assembler/memory_model/__init__.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_assembler/memory_model/__init__.py new file mode 100644 index 00000000..4057dc01 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_assembler/memory_model/__init__.py @@ -0,0 +1,2 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_assembler/memory_model/test_mem_info.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_assembler/memory_model/test_mem_info.py new file mode 100644 index 00000000..6858b858 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_assembler/memory_model/test_mem_info.py @@ -0,0 +1,973 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +""" +@brief Unit tests for the memory model mem_info module. +""" + +import unittest +from unittest.mock import patch, MagicMock, call, PropertyMock + +from assembler.memory_model import MemoryModel +from assembler.memory_model.mem_info import ( + MemInfoVariable, + MemInfoKeygenVariable, + MemInfo, + updateMemoryModelWithMemInfo, + _allocateMemInfoVariable, +) + + +class TestMemInfoVariable(unittest.TestCase): + """@brief Tests for the MemInfoVariable class.""" + + def test_init_valid(self): + """@brief Test initialization with valid parameters.""" + with patch( + "assembler.memory_model.variable.Variable.validateName", return_value=True + ): + var = MemInfoVariable("test_var", 42) + self.assertEqual(var.var_name, "test_var") + self.assertEqual(var.hbm_address, 42) + + def test_init_strips_whitespace(self): + """@brief Test that initialization strips whitespace from variable name.""" + with patch( + "assembler.memory_model.variable.Variable.validateName", return_value=True + ): + var = MemInfoVariable(" test_var ", 42) + self.assertEqual(var.var_name, "test_var") + + def test_init_invalid_name(self): + """@brief Test initialization with invalid variable name.""" + with patch( + "assembler.memory_model.variable.Variable.validateName", return_value=False + ): + with self.assertRaises(RuntimeError): + MemInfoVariable("invalid!var", 42) + + def test_repr(self): + """@brief Test the __repr__ method.""" + with patch( + "assembler.memory_model.variable.Variable.validateName", return_value=True + ): + var = MemInfoVariable("test_var", 42) + self.assertEqual( + repr(var), repr({"var_name": "test_var", "hbm_address": 42}) + ) + + def test_as_dict(self): + """@brief Test the as_dict method.""" + with patch( + "assembler.memory_model.variable.Variable.validateName", return_value=True + ): + var = MemInfoVariable("test_var", 42) + self.assertEqual(var.as_dict(), {"var_name": "test_var", "hbm_address": 42}) + + +class TestMemInfoKeygenVariable(unittest.TestCase): + """@brief Tests for the MemInfoKeygenVariable class.""" + + def test_init_valid(self): + """@brief Test initialization with valid parameters.""" + with patch( + "assembler.memory_model.variable.Variable.validateName", return_value=True + ): + var = MemInfoKeygenVariable("test_var", 2, 3) + self.assertEqual(var.var_name, "test_var") + self.assertEqual(var.hbm_address, -1) # Should be initialized to -1 + self.assertEqual(var.seed_index, 2) + self.assertEqual(var.key_index, 3) + + def test_init_negative_seed_index(self): + """@brief Test initialization with negative seed index.""" + with patch( + "assembler.memory_model.variable.Variable.validateName", return_value=True + ): + with self.assertRaises(IndexError): + MemInfoKeygenVariable("test_var", -1, 3) + + def test_init_negative_key_index(self): + """@brief Test initialization with negative key index.""" + with patch( + "assembler.memory_model.variable.Variable.validateName", return_value=True + ): + with self.assertRaises(IndexError): + MemInfoKeygenVariable("test_var", 2, -1) + + def test_as_dict(self): + """@brief Test the as_dict method.""" + with patch( + "assembler.memory_model.variable.Variable.validateName", return_value=True + ): + var = MemInfoKeygenVariable("test_var", 2, 3) + self.assertEqual( + var.as_dict(), {"var_name": "test_var", "seed_index": 2, "key_index": 3} + ) + + +class TestMemInfoMetadata(unittest.TestCase): + """@brief Tests for the MemInfo.Metadata class.""" + + def test_parse_meta_field_from_mem_tokens_valid(self): + """@brief Test parsing a valid metadata field.""" + tokens = ["dload", "LOAD_ONES", "42", "ones_var"] + result = MemInfo.Metadata.parse_meta_field_from_mem_tokens( + tokens, "LOAD_ONES", var_prefix="ONES" + ) + self.assertIsNotNone(result) + self.assertEqual(result.var_name, "ones_var") + self.assertEqual(result.hbm_address, 42) + + def test_parse_meta_field_from_mem_tokens_no_name(self): + """@brief Test parsing a metadata field without explicit name.""" + tokens = ["dload", "LOAD_ONES", "42"] + result = MemInfo.Metadata.parse_meta_field_from_mem_tokens( + tokens, "LOAD_ONES", var_prefix="ONES" + ) + self.assertIsNotNone(result) + self.assertEqual(result.var_name, "ONES_42") + self.assertEqual(result.hbm_address, 42) + + def test_parse_meta_field_from_mem_tokens_with_extra(self): + """@brief Test parsing a metadata field with var_extra.""" + tokens = ["dload", "LOAD_ONES", "42"] + result = MemInfo.Metadata.parse_meta_field_from_mem_tokens( + tokens, "LOAD_ONES", var_prefix="ONES", var_extra="_extra" + ) + self.assertIsNotNone(result) + self.assertEqual(result.var_name, "ONES_extra") + self.assertEqual(result.hbm_address, 42) + + def test_parse_meta_field_from_mem_tokens_invalid(self): + """@brief Test parsing an invalid metadata field.""" + # Not enough tokens + tokens = ["dload"] + result = MemInfo.Metadata.parse_meta_field_from_mem_tokens( + tokens, "LOAD_ONES", var_prefix="ONES" + ) + self.assertIsNone(result) + + # Wrong first token + tokens = ["wrong", "LOAD_ONES", "42"] + result = MemInfo.Metadata.parse_meta_field_from_mem_tokens( + tokens, "LOAD_ONES", var_prefix="ONES" + ) + self.assertIsNone(result) + + # Wrong second token + tokens = ["dload", "WRONG", "42"] + result = MemInfo.Metadata.parse_meta_field_from_mem_tokens( + tokens, "LOAD_ONES", var_prefix="ONES" + ) + self.assertIsNone(result) + + def test_metadata_init_and_properties(self): + """@brief Test initialization and properties of Metadata class.""" + # Prepare test data + metadata_dict = { + "ones": [{"var_name": "ones_var", "hbm_address": 1}], + "ntt_auxiliary_table": [{"var_name": "ntt_aux", "hbm_address": 2}], + "ntt_routing_table": [{"var_name": "ntt_route", "hbm_address": 3}], + "intt_auxiliary_table": [{"var_name": "intt_aux", "hbm_address": 4}], + "intt_routing_table": [{"var_name": "intt_route", "hbm_address": 5}], + "twid": [{"var_name": "twiddle_var", "hbm_address": 6}], + "keygen_seed": [{"var_name": "keygen_seed", "hbm_address": 7}], + } + + # Initialize Metadata + metadata = MemInfo.Metadata(**metadata_dict) + + # Test property access + self.assertEqual(len(metadata.ones), 1) + self.assertEqual(metadata.ones[0].var_name, "ones_var") + + self.assertEqual(len(metadata.ntt_auxiliary_table), 1) + self.assertEqual(metadata.ntt_auxiliary_table[0].var_name, "ntt_aux") + + self.assertEqual(len(metadata.ntt_routing_table), 1) + self.assertEqual(metadata.ntt_routing_table[0].var_name, "ntt_route") + + self.assertEqual(len(metadata.intt_auxiliary_table), 1) + self.assertEqual(metadata.intt_auxiliary_table[0].var_name, "intt_aux") + + self.assertEqual(len(metadata.intt_routing_table), 1) + self.assertEqual(metadata.intt_routing_table[0].var_name, "intt_route") + + self.assertEqual(len(metadata.twiddle), 1) + self.assertEqual(metadata.twiddle[0].var_name, "twiddle_var") + + self.assertEqual(len(metadata.keygen_seeds), 1) + self.assertEqual(metadata.keygen_seeds[0].var_name, "keygen_seed") + + def test_get_item(self): + """@brief Test the __getitem__ method.""" + metadata_dict = {"ones": [{"var_name": "ones_var", "hbm_address": 1}]} + metadata = MemInfo.Metadata(**metadata_dict) + + # Test __getitem__ using bracket notation + ones_list = metadata["ones"] + self.assertEqual(len(ones_list), 1) + self.assertEqual(ones_list[0].var_name, "ones_var") + + +class TestMemInfoParsers(unittest.TestCase): + """@brief Tests for the various parser methods in MemInfo.""" + + def test_ones_parse_from_mem_tokens(self): + """@brief Test parsing Ones metadata from tokens.""" + tokens = ["dload", "LOAD_ONES", "42", "ones_var"] + with patch( + "assembler.memory_model.mem_info.MemInfo.Metadata.parse_meta_field_from_mem_tokens" + ) as mock_parse: + mock_parse.return_value = MagicMock() + result = MemInfo.Metadata.Ones.parse_from_mem_tokens(tokens) + mock_parse.assert_called_once_with( + tokens, + MemInfo.Const.Keyword.LOAD_ONES, + var_prefix=MemInfo.Const.Keyword.LOAD_ONES, + ) + self.assertEqual(result, mock_parse.return_value) + + def test_ntt_aux_table_parse_from_mem_tokens(self): + """@brief Test parsing NTTAuxTable metadata from tokens.""" + tokens = ["dload", "LOAD_NTT_AUX_TABLE", "42", "ntt_aux_var"] + with patch( + "assembler.memory_model.mem_info.MemInfo.Metadata.parse_meta_field_from_mem_tokens" + ) as mock_parse: + mock_parse.return_value = MagicMock() + result = MemInfo.Metadata.NTTAuxTable.parse_from_mem_tokens(tokens) + mock_parse.assert_called_once_with( + tokens, + MemInfo.Const.Keyword.LOAD_NTT_AUX_TABLE, + var_prefix=MemInfo.Const.Keyword.LOAD_NTT_AUX_TABLE, + ) + self.assertEqual(result, mock_parse.return_value) + + def test_ntt_routing_table_parse_from_mem_tokens(self): + """@brief Test parsing NTTRoutingTable metadata from tokens.""" + tokens = ["dload", "LOAD_NTT_ROUTING_TABLE", "42", "ntt_route_var"] + with patch( + "assembler.memory_model.mem_info.MemInfo.Metadata.parse_meta_field_from_mem_tokens" + ) as mock_parse: + mock_parse.return_value = MagicMock() + result = MemInfo.Metadata.NTTRoutingTable.parse_from_mem_tokens(tokens) + mock_parse.assert_called_once_with( + tokens, + MemInfo.Const.Keyword.LOAD_NTT_ROUTING_TABLE, + var_prefix=MemInfo.Const.Keyword.LOAD_NTT_ROUTING_TABLE, + ) + self.assertEqual(result, mock_parse.return_value) + + def test_intt_aux_table_parse_from_mem_tokens(self): + """@brief Test parsing iNTTAuxTable metadata from tokens.""" + tokens = ["dload", "LOAD_iNTT_AUX_TABLE", "42", "intt_aux_var"] + with patch( + "assembler.memory_model.mem_info.MemInfo.Metadata.parse_meta_field_from_mem_tokens" + ) as mock_parse: + mock_parse.return_value = MagicMock() + result = MemInfo.Metadata.iNTTAuxTable.parse_from_mem_tokens(tokens) + mock_parse.assert_called_once_with( + tokens, + MemInfo.Const.Keyword.LOAD_iNTT_AUX_TABLE, + var_prefix=MemInfo.Const.Keyword.LOAD_iNTT_AUX_TABLE, + ) + self.assertEqual(result, mock_parse.return_value) + + def test_intt_routing_table_parse_from_mem_tokens(self): + """@brief Test parsing iNTTRoutingTable metadata from tokens.""" + tokens = ["dload", "LOAD_iNTT_ROUTING_TABLE", "42", "intt_route_var"] + with patch( + "assembler.memory_model.mem_info.MemInfo.Metadata.parse_meta_field_from_mem_tokens" + ) as mock_parse: + mock_parse.return_value = MagicMock() + result = MemInfo.Metadata.iNTTRoutingTable.parse_from_mem_tokens(tokens) + mock_parse.assert_called_once_with( + tokens, + MemInfo.Const.Keyword.LOAD_iNTT_ROUTING_TABLE, + var_prefix=MemInfo.Const.Keyword.LOAD_iNTT_ROUTING_TABLE, + ) + self.assertEqual(result, mock_parse.return_value) + + def test_twiddle_parse_from_mem_tokens(self): + """@brief Test parsing Twiddle metadata from tokens.""" + tokens = ["dload", "LOAD_TWIDDLE", "42", "twiddle_var"] + with patch( + "assembler.memory_model.mem_info.MemInfo.Metadata.parse_meta_field_from_mem_tokens" + ) as mock_parse: + mock_parse.return_value = MagicMock() + result = MemInfo.Metadata.Twiddle.parse_from_mem_tokens(tokens) + mock_parse.assert_called_once_with( + tokens, + MemInfo.Const.Keyword.LOAD_TWIDDLE, + var_prefix=MemInfo.Const.Keyword.LOAD_TWIDDLE, + ) + self.assertEqual(result, mock_parse.return_value) + + def test_keygen_seed_parse_from_mem_tokens(self): + """@brief Test parsing KeygenSeed metadata from tokens.""" + tokens = ["dload", "LOAD_KEYGEN_SEED", "42", "keygen_seed_var"] + with patch( + "assembler.memory_model.mem_info.MemInfo.Metadata.parse_meta_field_from_mem_tokens" + ) as mock_parse: + mock_parse.return_value = MagicMock() + result = MemInfo.Metadata.KeygenSeed.parse_from_mem_tokens(tokens) + mock_parse.assert_called_once_with( + tokens, + MemInfo.Const.Keyword.LOAD_KEYGEN_SEED, + var_prefix=MemInfo.Const.Keyword.LOAD_KEYGEN_SEED, + ) + self.assertEqual(result, mock_parse.return_value) + + def test_keygen_parse_from_mem_tokens_valid(self): + """@brief Test parsing a valid keygen variable.""" + tokens = ["keygen", "2", "3", "keygen_var"] + result = MemInfo.Keygen.parse_from_mem_tokens(tokens) + self.assertIsNotNone(result) + self.assertEqual(result.var_name, "keygen_var") + self.assertEqual(result.seed_index, 2) + self.assertEqual(result.key_index, 3) + + def test_keygen_parse_from_mem_tokens_invalid(self): + """@brief Test parsing an invalid keygen variable.""" + # Not enough tokens + tokens = ["keygen", "2", "3"] + result = MemInfo.Keygen.parse_from_mem_tokens(tokens) + self.assertIsNone(result) + + # Wrong first token + tokens = ["wrong", "2", "3", "keygen_var"] + result = MemInfo.Keygen.parse_from_mem_tokens(tokens) + self.assertIsNone(result) + + def test_input_parse_from_mem_tokens_valid(self): + """@brief Test parsing a valid input variable.""" + tokens = ["dload", "poly", "42", "input_var"] + with patch( + "assembler.memory_model.variable.Variable.validateName", return_value=True + ): + result = MemInfo.Input.parse_from_mem_tokens(tokens) + self.assertIsNotNone(result) + self.assertEqual(result.var_name, "input_var") + self.assertEqual(result.hbm_address, 42) + + def test_input_parse_from_mem_tokens_invalid(self): + """@brief Test parsing an invalid input variable.""" + # Not enough tokens + tokens = ["dload", "poly", "42"] + result = MemInfo.Input.parse_from_mem_tokens(tokens) + self.assertIsNone(result) + + # Wrong tokens + tokens = ["wrong", "poly", "42", "input_var"] + result = MemInfo.Input.parse_from_mem_tokens(tokens) + self.assertIsNone(result) + + tokens = ["dload", "wrong", "42", "input_var"] + result = MemInfo.Input.parse_from_mem_tokens(tokens) + self.assertIsNone(result) + + def test_output_parse_from_mem_tokens_valid(self): + """@brief Test parsing a valid output variable.""" + tokens = ["dstore", "output_var", "42"] + with patch( + "assembler.memory_model.variable.Variable.validateName", return_value=True + ): + result = MemInfo.Output.parse_from_mem_tokens(tokens) + self.assertIsNotNone(result) + self.assertEqual(result.var_name, "output_var") + self.assertEqual(result.hbm_address, 42) + + def test_output_parse_from_mem_tokens_invalid(self): + """@brief Test parsing an invalid output variable.""" + # Not enough tokens + tokens = ["store", "output_var"] + result = MemInfo.Output.parse_from_mem_tokens(tokens) + self.assertIsNone(result) + + # Wrong first token + tokens = ["wrong", "output_var", "42"] + result = MemInfo.Output.parse_from_mem_tokens(tokens) + self.assertIsNone(result) + + +class TestMemInfo(unittest.TestCase): + """@brief Tests for the MemInfo class.""" + + def test_init_default(self): + """@brief Test default initialization.""" + mem_info = MemInfo() + self.assertEqual(len(mem_info.keygens), 0) + self.assertEqual(len(mem_info.inputs), 0) + self.assertEqual(len(mem_info.outputs), 0) + # Verify metadata was initialized + self.assertIsInstance(mem_info.metadata, MemInfo.Metadata) + + def test_init_with_data(self): + """@brief Test initialization with data.""" + # Prepare test data + test_data = { + "keygens": [{"var_name": "keygen_var", "seed_index": 1, "key_index": 2}], + "inputs": [{"var_name": "input_var", "hbm_address": 42}], + "outputs": [{"var_name": "output_var", "hbm_address": 43}], + "metadata": { + "ones": [{"var_name": "ones_var", "hbm_address": 44}], + "twid": [{"var_name": "twiddle_var", "hbm_address": 45}], + }, + } + + # Initialize with test data + with patch("assembler.memory_model.mem_info.MemInfo.validate"): + mem_info = MemInfo(**test_data) + + # Verify data was loaded correctly + self.assertEqual(len(mem_info.keygens), 1) + self.assertEqual(mem_info.keygens[0].var_name, "keygen_var") + + self.assertEqual(len(mem_info.inputs), 1) + self.assertEqual(mem_info.inputs[0].var_name, "input_var") + + self.assertEqual(len(mem_info.outputs), 1) + self.assertEqual(mem_info.outputs[0].var_name, "output_var") + + # Verify metadata + self.assertEqual(len(mem_info.metadata.ones), 1) + self.assertEqual(mem_info.metadata.ones[0].var_name, "ones_var") + + self.assertEqual(len(mem_info.metadata.twiddle), 1) + self.assertEqual(mem_info.metadata.twiddle[0].var_name, "twiddle_var") + + def test_factory_dict(self): + """@brief Test the factory_dict property.""" + mem_info = MemInfo() + factory_dict = mem_info.factory_dict + + # Verify all expected keys are present + self.assertIn(MemInfo.Keygen, factory_dict) + self.assertIn(MemInfo.Input, factory_dict) + self.assertIn(MemInfo.Output, factory_dict) + self.assertIn(MemInfo.Metadata.KeygenSeed, factory_dict) + self.assertIn(MemInfo.Metadata.Ones, factory_dict) + self.assertIn(MemInfo.Metadata.NTTAuxTable, factory_dict) + self.assertIn(MemInfo.Metadata.NTTRoutingTable, factory_dict) + self.assertIn(MemInfo.Metadata.iNTTAuxTable, factory_dict) + self.assertIn(MemInfo.Metadata.iNTTRoutingTable, factory_dict) + self.assertIn(MemInfo.Metadata.Twiddle, factory_dict) + + # Verify values point to correct lists + self.assertEqual(factory_dict[MemInfo.Keygen], mem_info.keygens) + self.assertEqual(factory_dict[MemInfo.Input], mem_info.inputs) + self.assertEqual(factory_dict[MemInfo.Output], mem_info.outputs) + + def test_mem_info_types(self): + """@brief Test the mem_info_types class property.""" + mem_info_types = MemInfo.mem_info_types + + # Verify expected types are in the list + self.assertIn(MemInfo.Keygen, mem_info_types) + self.assertIn(MemInfo.Input, mem_info_types) + self.assertIn(MemInfo.Output, mem_info_types) + self.assertIn(MemInfo.Metadata.KeygenSeed, mem_info_types) + self.assertIn(MemInfo.Metadata.Ones, mem_info_types) + self.assertIn(MemInfo.Metadata.NTTAuxTable, mem_info_types) + self.assertIn(MemInfo.Metadata.NTTRoutingTable, mem_info_types) + self.assertIn(MemInfo.Metadata.iNTTAuxTable, mem_info_types) + self.assertIn(MemInfo.Metadata.iNTTRoutingTable, mem_info_types) + self.assertIn(MemInfo.Metadata.Twiddle, mem_info_types) + + def test_get_meminfo_var_from_tokens_valid(self): + """@brief Test getting a MemInfo variable from valid tokens.""" + tokens = ["keygen", "2", "3", "keygen_var"] + + # Mock the parse_from_mem_tokens method to return a mock variable + mock_variable = MagicMock() + with patch.object( + MemInfo.Keygen, "parse_from_mem_tokens", return_value=mock_variable + ): + var, var_type = MemInfo.get_meminfo_var_from_tokens(tokens) + + # Verify results + self.assertEqual(var, mock_variable) + self.assertEqual(var_type, MemInfo.Keygen) + + def test_get_meminfo_var_from_tokens_not_found(self): + """@brief Test getting a MemInfo variable when no parser can handle it.""" + tokens = ["unknown", "token"] + + # Mock all parse_from_mem_tokens methods to return None + with patch.object( + MemInfo, + "mem_info_types", + return_value=[ + MagicMock(parse_from_mem_tokens=MagicMock(return_value=None)) + ], + ): + var, var_type = MemInfo.get_meminfo_var_from_tokens(tokens) + + # Verify results + self.assertIsNone(var) + self.assertIsNone(var_type) + + def test_add_meminfo_var_from_tokens_valid(self): + """@brief Test adding a MemInfo variable from valid tokens.""" + tokens = ["keygen", "2", "3", "keygen_var"] + mem_info = MemInfo() + + # Mock get_meminfo_var_from_tokens + mock_variable = MagicMock() + mock_type = MagicMock() + mock_list = MagicMock() + mock_dict = {mock_type: mock_list} + + with patch.object( + MemInfo, + "get_meminfo_var_from_tokens", + return_value=(mock_variable, mock_type), + ), patch.object( + MemInfo, "factory_dict", new_callable=PropertyMock, return_value=mock_dict + ): + + # Call the method + mem_info.add_meminfo_var_from_tokens(tokens) + + # Verify the variable was added to the correct list + mock_list.append.assert_called_once_with(mock_variable) + + def test_add_meminfo_var_from_tokens_not_found(self): + """@brief Test adding a MemInfo variable when no parser can handle it.""" + tokens = ["unknown", "token"] + mem_info = MemInfo() + + # Mock get_meminfo_var_from_tokens to return None + with patch.object( + MemInfo, "get_meminfo_var_from_tokens", return_value=(None, None) + ): + # Verify exception is raised + with self.assertRaises(RuntimeError): + mem_info.add_meminfo_var_from_tokens(tokens) + + def test_from_file_iter_valid(self): + """@brief Test creating a MemInfo from a valid file iterator.""" + # Mock lines + lines = [ + "keygen, 2, 3, keygen_var", + "dload, input, 42, input_var", + "store, output_var, 43", + "dload, LOAD_ONES, 44, ones_var", + " ", # Empty line to test skipping + "# Comment line", # Comment line to test skipping + ] + + # Mock tokenize_from_line + def mock_tokenize(line): + if line.startswith("keygen"): + return (["keygen", "2", "3", "keygen_var"], "") + if line.startswith("dload, input"): + return (["dload", "input", "42", "input_var"], "") + if line.startswith("store"): + return (["store", "output_var", "43"], "") + if line.startswith("dload, LOAD_ONES"): + return (["dload", "LOAD_ONES", "44", "ones_var"], "") + + return ([], "") + + # Mock methods - patch the function where it's imported in mem_info + with patch( + "assembler.memory_model.mem_info.tokenize_from_line", + side_effect=mock_tokenize, + ), patch.object( + MemInfo, "add_meminfo_var_from_tokens" + ) as mock_add_var, patch.object( + MemInfo, "validate" + ): + + # Call the method + MemInfo.from_file_iter(lines) + + # Verify add_meminfo_var_from_tokens was called for each valid line + self.assertEqual(mock_add_var.call_count, 4) + + def test_from_file_iter_error(self): + """@brief Test creating a MemInfo when an error occurs.""" + # Mock lines + lines = ["invalid line"] + + # Mock tokenize_from_line + def mock_tokenize(line): + return (["invalid"], line) + + # Mock methods + with patch( + "assembler.instructions.tokenize_from_line", side_effect=mock_tokenize + ), patch.object( + MemInfo, + "add_meminfo_var_from_tokens", + side_effect=RuntimeError("Test error"), + ), patch.object( + MemInfo, "validate" + ): + + # Verify exception is raised with line number and content + with self.assertRaises(RuntimeError) as context: + MemInfo.from_file_iter(lines) + + self.assertIn("1: invalid line", str(context.exception)) + + def test_from_dinstrs_valid(self): + """@brief Test creating a MemInfo from valid DInstructions.""" + # Mock DInstructions + dinstrs = [ + MagicMock(tokens=["keygen", "2", "3", "keygen_var"]), + MagicMock(tokens=["dload", "input", "42", "input_var"]), + MagicMock(tokens=["store", "output_var", "43"]), + ] + + # Mock methods + with patch.object( + MemInfo, "add_meminfo_var_from_tokens" + ) as mock_add_var, patch.object(MemInfo, "validate"), patch( + "builtins.print" + ): # Mock print to avoid output + + # Call the method + MemInfo.from_dinstrs(dinstrs) + + # Verify add_meminfo_var_from_tokens was called for each instruction + self.assertEqual(mock_add_var.call_count, 3) + mock_add_var.assert_has_calls( + [ + call(["keygen", "2", "3", "keygen_var"]), + call(["dload", "input", "42", "input_var"]), + call(["store", "output_var", "43"]), + ] + ) + + def test_from_dinstrs_error(self): + """@brief Test creating a MemInfo when an error occurs.""" + # Mock DInstructions + dinstrs = [MagicMock(tokens=["invalid"])] + + # Mock methods + with patch.object( + MemInfo, + "add_meminfo_var_from_tokens", + side_effect=RuntimeError("Test error"), + ), patch.object(MemInfo, "validate"), patch( + "builtins.print" + ): # Mock print to avoid output + + # Verify exception is raised with instruction number + with self.assertRaises(RuntimeError) as context: + MemInfo.from_dinstrs(dinstrs) + + self.assertIn("1: ['invalid']", str(context.exception)) + + def test_as_dict(self): + """@brief Test the as_dict method.""" + # Create a MemInfo with test data + with patch("assembler.memory_model.mem_info.MemInfo.validate"): + + # dicts + keygens_dict = {"var_name": "keygen_var", "seed_index": 1, "key_index": 2} + inputs_dict = {"var_name": "input_var", "hbm_address": 42} + outputs_dict = {"var_name": "output_var", "hbm_address": 43} + ones_dict = {"var_name": "ones_var", "hbm_address": 44} + twiddle_dict = {"var_name": "twiddle_var", "hbm_address": 45} + + # Prepare test data + test_data = { + "keygens": [keygens_dict], + "inputs": [inputs_dict], + "outputs": [outputs_dict], + "metadata": { + "ones": [ones_dict], + "twid": [twiddle_dict], + }, + } + + # Create the MemInfo with the test data + mem_info = MemInfo(**test_data) + + # Call the method + result = mem_info.as_dict() + + # Verify result structure + self.assertIn("keygens", result) + self.assertIn("inputs", result) + self.assertIn("outputs", result) + self.assertIn("metadata", result) + + # Verify values + self.assertEqual(result["keygens"], [keygens_dict]) + self.assertEqual(result["inputs"], [inputs_dict]) + self.assertEqual(result["outputs"], [outputs_dict]) + self.assertIn("ones", result["metadata"]) + self.assertEqual(result["metadata"]["ones"], [ones_dict]) + + def test_validate_valid(self): + """@brief Test validation with valid data.""" + + ones_dict = {"var_name": "ones_var", "hbm_address": 44} + twiddle_dict = {"var_name": "twiddle_var", "hbm_address": 45} + + twiddle_list = [ + twiddle_dict for _ in range(MemoryModel.MAX_TWIDDLE_META_VARS_PER_SEGMENT) + ] + + # Create metadata dictionary for initialization + metadata_dict = { + "ones": [ones_dict], + "twid": twiddle_list, + } + + # Initialize without validation to set up the test + mem_info = MemInfo(metadata=metadata_dict) + + # Now explicitly call validate which we want to test + mem_info.validate() # Should not raise any exceptions + + def test_validate_twiddle_mismatch(self): + """@brief Test validation with mismatched twiddle count.""" + + ones_dict = {"var_name": "ones_var", "hbm_address": 44} + twiddle_dict = {"var_name": "twiddle_var", "hbm_address": 45} + + # Create metadata dictionary for initialization + metadata_dict = { + "ones": [ones_dict], + "twid": [twiddle_dict], + } + + # Call MemInfo initialization with metadata + with patch( + "assembler.memory_model.mem_info.MemoryModel.MAX_TWIDDLE_META_VARS_PER_SEGMENT", + 2, + ): + with self.assertRaises(RuntimeError) as context: + # Initialize without validation to set up the test + MemInfo(metadata=metadata_dict) + + self.assertIn( + "Expected 2 times as many twiddles as ones", str(context.exception) + ) + + def test_validate_duplicate_var_name(self): + """@brief Test validation with duplicate variable names but different HBM addresses.""" + # Create variable dictionaries with duplicate names but different addresses + intt_aux_dict = {"var_name": "duplicate", "hbm_address": 1} + ntt_route_dict = {"var_name": "duplicate", "hbm_address": 2} + + # Create metadata dictionary for initialization with the duplicate variables + metadata_dict = { + "ones": [], + "twid": [], + "intt_auxiliary_table": [intt_aux_dict], + "ntt_routing_table": [ntt_route_dict], + "ntt_auxiliary_table": [], + "intt_routing_table": [], + } + + # Initialize MemInfo with the metadata containing duplicates + with self.assertRaises(RuntimeError) as context: + MemInfo(metadata=metadata_dict) + + self.assertIn( + 'Variable "duplicate" already allocated', str(context.exception) + ) + + +class TestUpdateMemoryModelWithMemInfo(unittest.TestCase): + """@brief Tests for the updateMemoryModelWithMemInfo function.""" + + def setUp(self): + """@brief Set up common test fixtures.""" + # Create mock MemoryModel + self.mock_mem_model = MagicMock() + self.mock_mem_model.retrieveVarAdd = MagicMock() + + # Create mock MemInfo + self.mock_mem_info = MagicMock() + + # Group all mock variables in a dictionary + self.vars = { + "input": MagicMock(var_name="input_var", hbm_address=1), + "output": MagicMock(var_name="output_var", hbm_address=2), + "ones": MagicMock(var_name="ones_var", hbm_address=3), + "ntt_aux": MagicMock(var_name="ntt_aux", hbm_address=4), + "ntt_route": MagicMock(var_name="ntt_route", hbm_address=5), + "intt_aux": MagicMock(var_name="intt_aux", hbm_address=6), + "intt_route": MagicMock(var_name="intt_route", hbm_address=7), + "twiddle": MagicMock(var_name="twiddle_var", hbm_address=8), + "keygen_seed": MagicMock(var_name="keygen_seed", hbm_address=9), + } + + # Set up MemInfo + self.mock_mem_info.inputs = [self.vars["input"]] + self.mock_mem_info.outputs = [self.vars["output"]] + + # Set up metadata + self.mock_metadata = MagicMock() + self.mock_metadata.ones = [self.vars["ones"]] + self.mock_metadata.ntt_auxiliary_table = [self.vars["ntt_aux"]] + self.mock_metadata.ntt_routing_table = [self.vars["ntt_route"]] + self.mock_metadata.intt_auxiliary_table = [self.vars["intt_aux"]] + self.mock_metadata.intt_routing_table = [self.vars["intt_route"]] + self.mock_metadata.twiddle = [self.vars["twiddle"]] + self.mock_metadata.keygen_seeds = [self.vars["keygen_seed"]] + + self.mock_mem_info.metadata = self.mock_metadata + + def test_update_memory_model_inputs(self): + """@brief Test updating memory model with input variables.""" + # Call the function + with patch( + "assembler.memory_model.mem_info._allocateMemInfoVariable" + ) as mock_allocate: + updateMemoryModelWithMemInfo(self.mock_mem_model, self.mock_mem_info) + + # Verify input variables were allocated + mock_allocate.assert_any_call(self.mock_mem_model, self.vars["input"]) + + def test_update_memory_model_outputs(self): + """@brief Test updating memory model with output variables.""" + # Call the function + with patch( + "assembler.memory_model.mem_info._allocateMemInfoVariable" + ) as mock_allocate: + updateMemoryModelWithMemInfo(self.mock_mem_model, self.mock_mem_info) + + # Verify output variables were allocated and added to output_variables + mock_allocate.assert_any_call(self.mock_mem_model, self.vars["output"]) + self.mock_mem_model.output_variables.push.assert_called_once_with( + "output_var", None + ) + + def test_update_memory_model_metadata(self): + """@brief Test updating memory model with metadata variables.""" + # Call the function + with patch( + "assembler.memory_model.mem_info._allocateMemInfoVariable" + ) as mock_allocate: + updateMemoryModelWithMemInfo(self.mock_mem_model, self.mock_mem_info) + + # Verify metadata variables were retrieved, allocated and added to their respective lists + self.mock_mem_model.retrieveVarAdd.assert_has_calls( + [ + call("ones_var"), + call("ntt_aux"), + call("ntt_route"), + call("intt_aux"), + call("intt_route"), + call("twiddle_var"), + call("keygen_seed"), + ], + any_order=True, + ) + + mock_allocate.assert_has_calls( + [ + call(self.mock_mem_model, self.vars["ones"]), + call(self.mock_mem_model, self.vars["ntt_aux"]), + call(self.mock_mem_model, self.vars["ntt_route"]), + call(self.mock_mem_model, self.vars["intt_aux"]), + call(self.mock_mem_model, self.vars["intt_route"]), + call(self.mock_mem_model, self.vars["twiddle"]), + call(self.mock_mem_model, self.vars["keygen_seed"]), + ], + any_order=True, + ) + + self.mock_mem_model.add_meta_ones_var.assert_called_once_with("ones_var") + self.assertEqual(self.mock_mem_model.meta_ntt_aux_table, "ntt_aux") + self.assertEqual(self.mock_mem_model.meta_ntt_routing_table, "ntt_route") + self.assertEqual(self.mock_mem_model.meta_intt_aux_table, "intt_aux") + self.assertEqual(self.mock_mem_model.meta_intt_routing_table, "intt_route") + self.mock_mem_model.add_meta_twiddle_var.assert_called_once_with( + "twiddle_var" + ) + self.mock_mem_model.add_meta_keygen_seed_var.assert_called_once_with( + "keygen_seed" + ) + + +class TestAllocateMemInfoVariable(unittest.TestCase): + """@brief Tests for the _allocateMemInfoVariable function.""" + + def test_allocate_mem_info_variable_success(self): + """@brief Test successful allocation of a MemInfo variable.""" + # Create mock MemoryModel and variable + mock_mem_model = MagicMock() + mock_var_info = MagicMock(var_name="test_var", hbm_address=42) + + # Mock variables dictionary + mock_mem_model.variables = {"test_var": MagicMock(hbm_address=-1)} + + # Call the function + with patch( + "assembler.memory_model.mem_info._allocateMemInfoVariable" + ) as mock_function: + # Make it actually call the real function - simplified without lambda + mock_function.original = _allocateMemInfoVariable + mock_function.side_effect = mock_function.original + + mock_function(mock_mem_model, mock_var_info) + + # Verify the variable was allocated + mock_mem_model.hbm.allocateForce.assert_called_once_with( + 42, mock_mem_model.variables["test_var"] + ) + + def test_allocate_mem_info_variable_not_in_model(self): + """@brief Test allocation when the variable is not in the memory model.""" + # Create mock MemoryModel and variable + mock_mem_model = MagicMock() + mock_var_info = MagicMock(var_name="missing_var", hbm_address=42) + + # Mock variables dictionary (missing the variable) + mock_mem_model.variables = {} + + # Call the function + with patch( + "assembler.memory_model.mem_info._allocateMemInfoVariable" + ) as mock_function: + # Make it actually call the real function - simplified without lambda + mock_function.original = _allocateMemInfoVariable + mock_function.side_effect = mock_function.original + + # Verify exception is raised + with self.assertRaises(RuntimeError) as context: + mock_function(mock_mem_model, mock_var_info) + + self.assertIn( + "Variable missing_var not in memory model", str(context.exception) + ) + + def test_allocate_mem_info_variable_mismatch(self): + """@brief Test allocation when the variable has a different HBM address.""" + # Create mock MemoryModel and variable + mock_mem_model = MagicMock() + mock_var_info = MagicMock(var_name="test_var", hbm_address=42) + + # Mock variables dictionary with a variable that already has a different HBM address + mock_mem_model.variables = {"test_var": MagicMock(hbm_address=24)} + + # Call the function + with patch( + "assembler.memory_model.mem_info._allocateMemInfoVariable" + ) as mock_function: + # Make it actually call the real function - simplified without lambda + mock_function.original = _allocateMemInfoVariable + mock_function.side_effect = mock_function.original + + # Verify exception is raised + with self.assertRaises(RuntimeError) as context: + mock_function(mock_mem_model, mock_var_info) + + self.assertIn( + "Variable test_var already allocated in HBM address 24", + str(context.exception), + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_he_link.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_he_link.py new file mode 100644 index 00000000..eff9a306 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_he_link.py @@ -0,0 +1,696 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +# These contents may have been developed with support from one +# or more Intel-operated generative artificial intelligence solutions + +""" +@file test_he_link.py +@brief Unit tests for the he_link module +""" + +import os +import argparse +from unittest.mock import patch, mock_open, MagicMock, PropertyMock +import pytest + +import he_link +from assembler.common.config import GlobalConfig + + +class TestLinkerRunConfig: + """ + @class TestLinkerRunConfig + @brief Test cases for the LinkerRunConfig class + """ + + def test_init_with_valid_params(self): + """ + @brief Test initialization with valid parameters + """ + # Arrange + kwargs = { + "input_prefixes": ["prefix1", "prefix2"], + "output_prefix": "output_prefix", + "input_mem_file": "input.mem", + "output_dir": "/tmp", + "has_hbm": True, + "hbm_size": 1024, + "suppress_comments": False, + "use_xinstfetch": False, + "multi_mem_files": False, + } + + # Act + with patch("he_link.makeUniquePath", side_effect=lambda x: x): + config = he_link.LinkerRunConfig(**kwargs) + + # Assert + assert config.input_prefixes == ["prefix1", "prefix2"] + assert config.output_prefix == "output_prefix" + assert config.input_mem_file == "input.mem" + assert config.output_dir == "/tmp" + assert config.has_hbm is True + assert config.hbm_size == 1024 + assert config.suppress_comments is False + assert config.use_xinstfetch is False + assert config.multi_mem_files is False + + def test_init_with_missing_required_param(self): + """ + @brief Test initialization with missing required parameters + """ + # Arrange + kwargs = { + "output_prefix": "output_prefix", + "input_mem_file": "input.mem", + # Missing input_prefixes + } + + # Act & Assert + with pytest.raises(TypeError): + he_link.LinkerRunConfig(**kwargs) + + def test_as_dict(self): + """ + @brief Test the as_dict method returns a proper dictionary + """ + # Arrange + kwargs = { + "input_prefixes": ["prefix1"], + "output_prefix": "output_prefix", + "input_mem_file": "input.mem", + "output_dir": "/tmp", + "has_hbm": True, + "hbm_size": 1024, + } + + # Act + with patch("he_link.makeUniquePath", side_effect=lambda x: x): + config = he_link.LinkerRunConfig(**kwargs) + result = config.as_dict() + + # Assert Keys + assert isinstance(result, dict) + assert "input_prefixes" in result + assert "output_prefix" in result + assert "input_mem_file" in result + assert "output_dir" in result + assert "has_hbm" in result + assert "hbm_size" in result + + # Assert values + assert result["input_prefixes"] == ["prefix1"] + assert result["output_prefix"] == "output_prefix" + assert result["input_mem_file"] == "input.mem" + assert result["output_dir"] == "/tmp" + assert result["has_hbm"] is True + assert result["hbm_size"] == 1024 + + def test_str_representation(self): + """ + @brief Test the string representation of the configuration + """ + # Arrange + kwargs = { + "input_prefixes": ["prefix1"], + "output_prefix": "output_prefix", + "input_mem_file": "input.mem", + } + + # Act + with patch("he_link.makeUniquePath", side_effect=lambda x: x): + config = he_link.LinkerRunConfig(**kwargs) + result = str(config) + + # Assert params + assert "input_prefixes" in result + assert "output_prefix" in result + assert "input_mem_file" in result + # Assert values + assert "prefix1" in result + assert "output_prefix" in result + assert "input.mem" in result + + def test_init_for_default_params(self): + """ + @brief Test initialization with default parameters + """ + + # Arrange + kwargs = {"input_prefixes": ["prefix1"], "output_prefix": ""} + + # Reset the class-level config so the patch will take effect + he_link.RunConfig.reset_class_state() + + # Act + with patch("he_link.makeUniquePath", side_effect=lambda x: x), patch.object( + he_link.RunConfig, "DEFAULT_HBM_SIZE_KB", new_callable=PropertyMock + ) as mock_hbm_size, patch.object( + GlobalConfig, "suppress_comments", new_callable=PropertyMock + ) as mock_suppress_comments, patch.object( + GlobalConfig, "useXInstFetch", new_callable=PropertyMock + ) as mock_use_xinstfetch: + + # Mock the default HBM size + mock_suppress_comments.return_value = False + mock_use_xinstfetch.return_value = False + mock_hbm_size.return_value = 1024 + config = he_link.LinkerRunConfig(**kwargs) + + # Assert + assert config.output_prefix == "" + assert config.input_mem_file == "" + assert config.output_dir == os.getcwd() + assert config.has_hbm is True + assert config.hbm_size == 1024 + assert config.suppress_comments is False + assert config.use_xinstfetch is False + assert config.multi_mem_files is False + + +class TestKernelFiles: + """ + @class TestKernelFiles + @brief Test cases for the KernelFiles class + """ + + def test_kernel_files_creation(self): + """ + @brief Test KernelFiles creation and attribute access + """ + # Act + kernel_files = he_link.KernelFiles( + prefix="prefix", + minst="prefix.minst", + cinst="prefix.cinst", + xinst="prefix.xinst", + mem="prefix.mem", + ) + + # Assert + assert kernel_files.prefix == "prefix" + assert kernel_files.minst == "prefix.minst" + assert kernel_files.cinst == "prefix.cinst" + assert kernel_files.xinst == "prefix.xinst" + assert kernel_files.mem == "prefix.mem" + + def test_kernel_files_without_mem(self): + """ + @brief Test KernelFiles creation without mem file + """ + # Act + kernel_files = he_link.KernelFiles( + prefix="prefix", + minst="prefix.minst", + cinst="prefix.cinst", + xinst="prefix.xinst", + ) + + # Assert + assert kernel_files.prefix == "prefix" + assert kernel_files.minst == "prefix.minst" + assert kernel_files.cinst == "prefix.cinst" + assert kernel_files.xinst == "prefix.xinst" + assert kernel_files.mem is None + + +class TestHelperFunctions: + """ + @class TestHelperFunctions + @brief Test cases for helper functions in he_link + """ + + def test_prepare_output_files(self): + """ + @brief Test prepare_output_files function creates correct output files + """ + # Arrange + mock_config = MagicMock() + mock_config.output_dir = "/tmp" + mock_config.output_prefix = "output" + mock_config.multi_mem_files = False + + # Act + with patch("os.path.dirname", return_value="/tmp"), patch( + "pathlib.Path.mkdir" + ), patch("he_link.makeUniquePath", side_effect=lambda x: x): + result = he_link.prepare_output_files(mock_config) + + # Assert + assert result.prefix == "/tmp/output" + assert result.minst == "/tmp/output.minst" + assert result.cinst == "/tmp/output.cinst" + assert result.xinst == "/tmp/output.xinst" + assert result.mem is None + + def test_prepare_output_files_with_mem(self): + """ + @brief Test prepare_output_files with multi_mem_files=True + """ + # Arrange + mock_config = MagicMock() + mock_config.output_dir = "/tmp" + mock_config.output_prefix = "output" + mock_config.multi_mem_files = True + + # Act + with patch("os.path.dirname", return_value="/tmp"), patch( + "pathlib.Path.mkdir" + ), patch("he_link.makeUniquePath", side_effect=lambda x: x): + result = he_link.prepare_output_files(mock_config) + + # Assert + assert result.prefix == "/tmp/output" + assert result.minst == "/tmp/output.minst" + assert result.cinst == "/tmp/output.cinst" + assert result.xinst == "/tmp/output.xinst" + assert result.mem == "/tmp/output.mem" + + def test_prepare_input_files(self): + """ + @brief Test prepare_input_files function + """ + # Arrange + mock_config = MagicMock() + mock_config.input_prefixes = ["/tmp/input1", "/tmp/input2"] + mock_config.multi_mem_files = False + + mock_output_files = he_link.KernelFiles( + prefix="/tmp/output", + minst="/tmp/output.minst", + cinst="/tmp/output.cinst", + xinst="/tmp/output.xinst", + ) + + # Act + with patch("os.path.isfile", return_value=True), patch( + "he_link.makeUniquePath", side_effect=lambda x: x + ): + result = he_link.prepare_input_files(mock_config, mock_output_files) + + # Assert + assert len(result) == 2 + assert result[0].prefix == "/tmp/input1" + assert result[0].minst == "/tmp/input1.minst" + assert result[0].cinst == "/tmp/input1.cinst" + assert result[0].xinst == "/tmp/input1.xinst" + assert result[0].mem is None + assert result[1].prefix == "/tmp/input2" + + def test_prepare_input_files_file_not_found(self): + """ + @brief Test prepare_input_files when a file doesn't exist + """ + # Arrange + mock_config = MagicMock() + mock_config.input_prefixes = ["/tmp/input1"] + mock_config.multi_mem_files = False + + mock_output_files = he_link.KernelFiles( + prefix="/tmp/output", + minst="/tmp/output.minst", + cinst="/tmp/output.cinst", + xinst="/tmp/output.xinst", + ) + + # Act & Assert + with patch("os.path.isfile", return_value=False), patch( + "he_link.makeUniquePath", side_effect=lambda x: x + ): + with pytest.raises(FileNotFoundError): + he_link.prepare_input_files(mock_config, mock_output_files) + + def test_prepare_input_files_output_conflict(self): + """ + @brief Test prepare_input_files when input and output files conflict + """ + # Arrange + mock_config = MagicMock() + mock_config.input_prefixes = ["/tmp/input1"] + mock_config.multi_mem_files = False + + # Output file matching an input file + mock_output_files = he_link.KernelFiles( + prefix="/tmp/output", + minst="/tmp/input1.minst", # Conflict + cinst="/tmp/output.cinst", + xinst="/tmp/output.xinst", + ) + + # Act & Assert + with patch("os.path.isfile", return_value=True), patch( + "he_link.makeUniquePath", side_effect=lambda x: x + ): + with pytest.raises(RuntimeError): + he_link.prepare_input_files(mock_config, mock_output_files) + + @pytest.mark.parametrize("has_hbm", [True, False]) + def test_scan_variables(self, has_hbm): + """ + @brief Test scan_variables function with and without HBM + @param has_hbm Boolean indicating whether HBM is enabled + """ + # Arrange + GlobalConfig.hasHBM = has_hbm + mock_mem_model = MagicMock() + mock_verbose = MagicMock() + + input_files = [ + he_link.KernelFiles( + prefix="/tmp/input1", + minst="/tmp/input1.minst", + cinst="/tmp/input1.cinst", + xinst="/tmp/input1.xinst", + ) + ] + + # Act + with patch("linker.loader.load_minst_kernel_from_file", return_value=[]), patch( + "linker.loader.load_cinst_kernel_from_file", return_value=[] + ), patch( + "linker.steps.variable_discovery.discover_variables", + return_value=["var1", "var2"], + ), patch( + "linker.steps.variable_discovery.discover_variables_spad", + return_value=["var1", "var2"], + ): + he_link.scan_variables(input_files, mock_mem_model, mock_verbose) + + # Assert + if has_hbm: + assert mock_mem_model.add_variable.call_count == 2 + else: + assert mock_mem_model.add_variable.call_count == 2 + + def test_check_unused_variables(self): + """ + @brief Test check_unused_variables function + """ + # Arrange + GlobalConfig.hasHBM = True + mock_mem_model = MagicMock() + mock_mem_model.mem_info_vars = {"var1": MagicMock(), "var2": MagicMock()} + mock_mem_model.variables = {"var1"} + mock_mem_model.mem_info_meta = {} + + # Act & Assert + with pytest.raises(RuntimeError): + he_link.check_unused_variables(mock_mem_model) + + def test_link_kernels(self): + """ + @brief Test link_kernels function + """ + # Arrange + input_files = [ + he_link.KernelFiles( + prefix="/tmp/input1", + minst="/tmp/input1.minst", + cinst="/tmp/input1.cinst", + xinst="/tmp/input1.xinst", + ) + ] + + output_files = he_link.KernelFiles( + prefix="/tmp/output", + minst="/tmp/output.minst", + cinst="/tmp/output.cinst", + xinst="/tmp/output.xinst", + ) + + mock_mem_model = MagicMock() + mock_verbose = MagicMock() + + # Act + with patch("builtins.open", mock_open()), patch( + "linker.loader.load_minst_kernel_from_file", return_value=[] + ), patch("linker.loader.load_cinst_kernel_from_file", return_value=[]), patch( + "linker.loader.load_xinst_kernel_from_file", return_value=[] + ), patch( + "linker.steps.program_linker.LinkedProgram" + ) as mock_linked_program: + he_link.link_kernels( + input_files, output_files, mock_mem_model, mock_verbose + ) + + # Assert + mock_linked_program.assert_called_once() + instance = mock_linked_program.return_value + assert instance.link_kernel.call_count == 1 + assert instance.close.call_count == 1 + + +class TestMainFunction: + """ + @class TestMainFunction + @brief Test cases for the main function + """ + + @pytest.mark.parametrize("multi_mem_files", [True, False]) + def test_main(self, multi_mem_files): + """ + @brief Test main function with and without multi_mem_files + """ + # Arrange + mock_config = MagicMock() + mock_config.multi_mem_files = multi_mem_files + mock_config.has_hbm = True + mock_config.hbm_size = 1024 + mock_config.suppress_comments = False + mock_config.use_xinstfetch = False + + # Setup input files with conditional mem files + input_files = [ + he_link.KernelFiles( + prefix="prefix1", + minst="prefix1.minst", + cinst="prefix1.cinst", + xinst="prefix1.xinst", + mem="prefix1.mem" if multi_mem_files else None, + ), + he_link.KernelFiles( + prefix="prefix2", + minst="prefix2.minst", + cinst="prefix2.cinst", + xinst="prefix2.xinst", + mem="prefix2.mem" if multi_mem_files else None, + ), + ] + + # Create a dictionary of mocks to reduce the number of local variables + mocks = { + "prepare_output": MagicMock(), + "prepare_input": MagicMock(return_value=input_files), + "scan_variables": MagicMock(), + "check_unused_variables": MagicMock(), + "link_kernels": MagicMock(), + "from_dinstrs": MagicMock(), + "from_file_iter": MagicMock(), + "load_dinst": MagicMock(return_value=["1", "2"]), + "join_dinst": MagicMock(return_value=[]), + "dump_instructions": MagicMock(), + } + + # Act + with patch( + "assembler.common.constants.convertBytes2Words", return_value=1024 + ), patch("he_link.prepare_output_files", mocks["prepare_output"]), patch( + "he_link.prepare_input_files", mocks["prepare_input"] + ), patch( + "assembler.common.counter.Counter.reset" + ), patch( + "linker.loader.load_dinst_kernel_from_file", mocks["load_dinst"] + ), patch( + "linker.instructions.BaseInstruction.dump_instructions_to_file", + mocks["dump_instructions"], + ), patch( + "linker.steps.program_linker.LinkedProgram.join_dinst_kernels", + mocks["join_dinst"], + ), patch( + "assembler.memory_model.mem_info.MemInfo.from_dinstrs", + mocks["from_dinstrs"], + ), patch( + "assembler.memory_model.mem_info.MemInfo.from_file_iter", + mocks["from_file_iter"], + ), patch( + "linker.MemoryModel" + ), patch( + "he_link.scan_variables", mocks["scan_variables"] + ), patch( + "he_link.check_unused_variables", mocks["check_unused_variables"] + ), patch( + "he_link.link_kernels", mocks["link_kernels"] + ), patch( + "he_link.BaseInstruction.dump_instructions_to_file", + mocks["dump_instructions"], + ): + + he_link.main(mock_config, MagicMock()) + + # Assert pipeline is run as expected + mocks["prepare_output"].assert_called_once() + mocks["prepare_input"].assert_called_once() + mocks["scan_variables"].assert_called_once() + mocks["check_unused_variables"].assert_called_once() + mocks["link_kernels"].assert_called_once() + + if multi_mem_files: + # Should use from_dinstrs, not from_file_iter + assert mocks["from_dinstrs"].called + assert mocks["load_dinst"].called + assert mocks["join_dinst"].called + assert mocks["dump_instructions"].called + + assert not mocks["from_file_iter"].called + else: + # Should use from_file_iter, not from_dinstrs + assert mocks["from_file_iter"].called + assert not mocks["from_dinstrs"].called + + def test_warning_on_use_xinstfetch(self): + """ + @brief Test warning is issued when use_xinstfetch is True + """ + # Arrange + mock_config = MagicMock() + mock_config.multi_mem_files = False + mock_config.has_hbm = True + mock_config.hbm_size = 1024 + mock_config.suppress_comments = False + mock_config.use_xinstfetch = True # Should trigger warning + mock_config.input_mem_file = "input.mem" + + # Act & Assert + with patch("warnings.warn") as mock_warn, patch( + "assembler.common.constants.convertBytes2Words", return_value=1024 + ), patch("he_link.prepare_output_files"), patch( + "he_link.prepare_input_files" + ), patch( + "assembler.common.counter.Counter.reset" + ), patch( + "builtins.open", mock_open() + ), patch( + "assembler.memory_model.mem_info.MemInfo.from_file_iter" + ), patch( + "linker.MemoryModel" + ), patch( + "he_link.scan_variables" + ), patch( + "he_link.check_unused_variables" + ), patch( + "he_link.link_kernels" + ): + he_link.main(mock_config, None) + mock_warn.assert_called_once() + + +class TestParseArgs: + """ + @class TestParseArgs + @brief Test cases for the parse_args function + """ + + def test_parse_args_minimal(self): + """ + @brief Test parse_args with minimal arguments + """ + # Arrange + test_args = [ + "program", + "input_prefix", + "-o", + "output_prefix", + "-im", + "input.mem", + ] + + # Act + with patch("sys.argv", test_args), patch( + "argparse.ArgumentParser.parse_args", + return_value=argparse.Namespace( + input_prefixes=["input_prefix"], + output_prefix="output_prefix", + input_mem_file="input.mem", + output_dir="", + multi_mem_files=False, + mem_spec_file="", + isa_spec_file="", + has_hbm=True, + hbm_size=None, + suppress_comments=False, + verbose=0, + ), + ): + args = he_link.parse_args() + + # Assert + assert args.input_prefixes == ["input_prefix"] + assert args.output_prefix == "output_prefix" + assert args.input_mem_file == "input.mem" + assert args.multi_mem_files is False + + def test_parse_args_multi_mem_files(self): + """ + @brief Test parse_args with multi_mem_files flag + """ + # Arrange + test_args = [ + "program", + "input_prefix", + "-o", + "output_prefix", + "--multi_mem_files", + ] + + # Act + with patch("sys.argv", test_args), patch( + "argparse.ArgumentParser.parse_args", + return_value=argparse.Namespace( + input_prefixes=["input_prefix"], + output_prefix="output_prefix", + input_mem_file="", + output_dir="", + multi_mem_files=True, + mem_spec_file="", + isa_spec_file="", + has_hbm=True, + hbm_size=None, + suppress_comments=False, + verbose=0, + ), + ): + args = he_link.parse_args() + + # Assert + assert args.input_prefixes == ["input_prefix"] + assert args.output_prefix == "output_prefix" + assert args.input_mem_file == "" + assert args.multi_mem_files is True + + def test_missing_input_mem_file(self): + """ + @brief Test parse_args with missing input_mem_file when multi_mem_files is False + """ + # Arrange + test_args = ["program", "input_prefix", "-o", "output_prefix"] + + # Act & Assert + with patch("sys.argv", test_args), patch( + "argparse.ArgumentParser.parse_args", + return_value=argparse.Namespace( + input_prefixes=["input_prefix"], + output_prefix="output_prefix", + input_mem_file="", + output_dir="", + multi_mem_files=False, + mem_spec_file="", + isa_spec_file="", + has_hbm=True, + hbm_size=None, + suppress_comments=False, + verbose=0, + ), + ), patch("argparse.ArgumentParser.error") as mock_error: + he_link.parse_args() + mock_error.assert_called_once() diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_he_prep.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_he_prep.py index c87ad818..48edc8d3 100644 --- a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_he_prep.py +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_he_prep.py @@ -1,8 +1,11 @@ # Copyright (C) 2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 +# These contents may have been developed with support from one or more Intel-operated +# generative artificial intelligence solutions + """ -Unit tests for he_prep module. +@brief Unit tests for he_prep module. """ from unittest import mock @@ -15,10 +18,10 @@ def test_main_assigns_and_saves(monkeypatch, tmp_path): """ - Test that the main function assigns register banks, processes instructions, and saves the output. + @brief Test that the main function assigns register banks, processes instructions, and saves the output. - This test uses monkeypatching to mock dependencies and verifies that the output file - contains the expected instruction after processing a dummy input file. + @details This test uses monkeypatching to mock dependencies and verifies that the output file + contains the expected instruction after processing a dummy input file. """ # Prepare dummy input file input_file = tmp_path / "input.csv" @@ -45,7 +48,7 @@ def test_main_assigns_and_saves(monkeypatch, tmp_path): def test_main_no_input_file(): """ - Test that main raises an error when no input file is provided. + @brief Test that main raises an error when no input file is provided. """ with pytest.raises(FileNotFoundError): he_prep.main( @@ -55,7 +58,7 @@ def test_main_no_input_file(): def test_main_no_output_file(): """ - Test that main raises an error when no output file is provided. + @brief Test that main raises an error when no output file is provided. """ with pytest.raises(FileNotFoundError): he_prep.main( @@ -65,9 +68,9 @@ def test_main_no_output_file(): def test_main_no_instructions(monkeypatch): """ - Test that main handles the case where no instructions are processed. + @brief Test that main handles the case where no instructions are processed. - This test checks that the function can handle an empty instruction list without errors. + @details This test checks that the function can handle an empty instruction list without errors. """ input_file = "empty_input.csv" output_file = "empty_output.csv" @@ -98,7 +101,7 @@ def test_main_no_instructions(monkeypatch): def test_main_invalid_input_file(tmp_path): """ - Test that main raises an error when the input file does not exist. + @brief Test that main raises an error when the input file does not exist. """ input_file = tmp_path / "non_existent.csv" output_file = tmp_path / "output.csv" @@ -111,8 +114,9 @@ def test_main_invalid_input_file(tmp_path): def test_main_invalid_output_file(tmp_path): """ - Test that main raises an error when the output file cannot be created. - This test checks that the function handles file permission errors gracefully. + @brief Test that main raises an error when the output file cannot be created. + + @details This test checks that the function handles file permission errors gracefully. """ input_file = tmp_path / "input.csv" input_file.write_text("") # Write empty string to avoid SyntaxError @@ -130,7 +134,7 @@ def test_main_invalid_output_file(tmp_path): def test_parse_args(): """ - Test that parse_args returns the expected arguments. + @brief Test that parse_args returns the expected arguments. """ test_args = [ "prog", diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/__init__.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/__init__.py new file mode 100644 index 00000000..4057dc01 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/__init__.py @@ -0,0 +1,2 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_init.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_init.py new file mode 100644 index 00000000..910ceb34 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_init.py @@ -0,0 +1,393 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +# These contents may have been developed with support from one or more Intel-operated +# generative artificial intelligence solutions + +""" +@brief Unit tests for the memory model classes in linker/__init__.py. +""" + +import unittest +from unittest.mock import MagicMock, patch + +from assembler.common.config import GlobalConfig +from assembler.memory_model import mem_info +from linker import VariableInfo, HBM, MemoryModel + + +class TestVariableInfo(unittest.TestCase): + """@brief Tests for the VariableInfo class.""" + + def test_init(self): + """@brief Test initialization of VariableInfo. + + @test Verifies that VariableInfo is properly initialized with the given values + """ + var_info = VariableInfo("test_var", 42) + self.assertEqual(var_info.var_name, "test_var") + self.assertEqual(var_info.hbm_address, 42) + self.assertEqual(var_info.uses, 0) + self.assertEqual(var_info.last_kernel_used, -1) + + def test_init_default_values(self): + """@brief Test initialization with default values. + + @test Verifies that VariableInfo is properly initialized with default values + """ + var_info = VariableInfo("test_var") + self.assertEqual(var_info.var_name, "test_var") + self.assertEqual(var_info.hbm_address, -1) + self.assertEqual(var_info.uses, 0) + self.assertEqual(var_info.last_kernel_used, -1) + + +class TestHBM(unittest.TestCase): + """@brief Tests for the HBM class.""" + + def setUp(self): + """@brief Set up test fixtures.""" + self.hbm_size = 10 + self.hbm = HBM(self.hbm_size) + + def test_init(self): + """@brief Test initialization of HBM. + + @test Verifies that HBM is properly initialized with the given size + """ + self.assertEqual(len(self.hbm.buffer), self.hbm_size) + self.assertEqual(self.hbm.capacity, self.hbm_size) + # Check that buffer is initialized with None values + for item in self.hbm.buffer: + self.assertIsNone(item) + + def test_init_invalid_size(self): + """@brief Test initialization with invalid size. + + @test Verifies that ValueError is raised for invalid sizes + """ + with self.assertRaises(ValueError): + HBM(0) + with self.assertRaises(ValueError): + HBM(-1) + + def test_capacity_property(self): + """@brief Test the capacity property. + + @test Verifies that the capacity property returns the correct size + """ + self.assertEqual(self.hbm.capacity, self.hbm_size) + + def test_buffer_property(self): + """@brief Test the buffer property. + + @test Verifies that the buffer property returns the correct buffer + """ + buffer = self.hbm.buffer + self.assertEqual(len(buffer), self.hbm_size) + # Check that buffer is initialized with None values + for item in buffer: + self.assertIsNone(item) + + def test_force_allocate_valid(self): + """@brief Test force_allocate with valid parameters. + + @test Verifies that a variable is properly allocated at the specified address + """ + var_info = VariableInfo("test_var") + self.hbm.force_allocate(var_info, 5) + self.assertEqual(var_info.hbm_address, 5) + self.assertEqual(self.hbm.buffer[5], var_info) + + def test_force_allocate_out_of_bounds(self): + """@brief Test force_allocate with out of bounds address. + + @test Verifies that IndexError is raised for out-of-bounds addresses + """ + var_info = VariableInfo("test_var") + with self.assertRaises(IndexError): + self.hbm.force_allocate(var_info, -1) + with self.assertRaises(IndexError): + self.hbm.force_allocate(var_info, self.hbm_size) + + def test_force_allocate_already_allocated(self): + """@brief Test force_allocate with already allocated variable. + + @test Verifies that ValueError is raised when variable is already allocated + """ + var_info = VariableInfo("test_var", 3) + with self.assertRaises(ValueError): + self.hbm.force_allocate(var_info, 5) + + def test_force_allocate_address_occupied_with_hbm(self): + """@brief Test force_allocate with address occupied and HBM enabled. + + @test Verifies that RuntimeError is raised when address is occupied + """ + with patch.object(GlobalConfig, "hasHBM", True): + # Occupy address 5 + var_info1 = VariableInfo("var1") + var_info1.uses = 1 + self.hbm.force_allocate(var_info1, 5) + + # Try to allocate another variable at the same address + var_info2 = VariableInfo("var2") + with self.assertRaises(RuntimeError): + self.hbm.force_allocate(var_info2, 5) + + def test_force_allocate_address_occupied_without_hbm(self): + """@brief Test force_allocate with address occupied and HBM disabled. + + @test Verifies that RuntimeError is raised when address is occupied + """ + with patch.object(GlobalConfig, "hasHBM", False): + # Occupy address 5 + var_info1 = VariableInfo("var1") + var_info1.uses = 1 + self.hbm.force_allocate(var_info1, 5) + + # Try to allocate another variable at the same address + var_info2 = VariableInfo("var2") + with self.assertRaises(RuntimeError): + self.hbm.force_allocate(var_info2, 5) + + def test_force_allocate_address_recyclable_with_hbm(self): + """@brief Test force_allocate with recyclable address and HBM enabled. + + @test Verifies that an address can be recycled when the variable is not used + """ + with patch.object(GlobalConfig, "hasHBM", True): + # Occupy address 5 with a variable that's not used + var_info1 = VariableInfo("var1") + var_info1.uses = 0 + var_info1.last_kernel_used = 1 + self.hbm.force_allocate(var_info1, 5) + + # Allocate another variable at the same address with higher kernel index + var_info2 = VariableInfo("var2") + var_info2.last_kernel_used = 2 + self.hbm.force_allocate(var_info2, 5) + + # Check that the new variable is at the address + self.assertEqual(self.hbm.buffer[5], var_info2) + + def test_allocate(self): + """@brief Test allocate method. + + @test Verifies that a variable is allocated at the first available address + """ + var_info = VariableInfo("test_var") + self.hbm.allocate(var_info) + # The variable should be allocated at the first available address (0) + self.assertEqual(var_info.hbm_address, 0) + self.assertEqual(self.hbm.buffer[0], var_info) + + def test_allocate_full_memory(self): + """@brief Test allocate with full memory. + + @test Verifies that RuntimeError is raised when memory is full + """ + # Fill up the HBM + for i in range(self.hbm_size): + var_info = VariableInfo(f"var{i}") + var_info.uses = 1 + self.hbm.force_allocate(var_info, i) + + # Try to allocate another variable + var_info = VariableInfo("test_var") + with self.assertRaises(RuntimeError): + self.hbm.allocate(var_info) + + def test_allocate_with_recycling(self): + """@brief Test allocate with recycling unused addresses. + + @test Verifies that unused addresses can be recycled + """ + with patch.object(GlobalConfig, "hasHBM", True): + # Fill up the HBM + for i in range(self.hbm_size): + var_info = VariableInfo(f"var{i}") + var_info.uses = 1 if i != 3 else 0 + var_info.last_kernel_used = 1 + self.hbm.force_allocate(var_info, i) + + # Allocate a new variable - should reuse address 3 + var_info = VariableInfo("test_var") + var_info.last_kernel_used = 2 + self.hbm.allocate(var_info) + self.assertEqual(var_info.hbm_address, 3) + + +class TestMemoryModel(unittest.TestCase): + """@brief Tests for the MemoryModel class.""" + + def setUp(self): + """@brief Set up test fixtures.""" + # Create a mock MemInfo + self.mock_mem_info = MagicMock(spec=mem_info.MemInfo) + + # Set up mock input variables + self.input_var = MagicMock(spec=mem_info.MemInfoVariable) + self.input_var.var_name = "input_var" + self.input_var.hbm_address = 1 + + # Set up mock output variables + self.output_var = MagicMock(spec=mem_info.MemInfoVariable) + self.output_var.var_name = "output_var" + self.output_var.hbm_address = 2 + + # Set up mock keygen variables + self.keygen_var = MagicMock(spec=mem_info.MemInfoVariable) + self.keygen_var.var_name = "keygen_var" + self.keygen_var.hbm_address = 3 + + # Set up mock metadata variables + self.meta_var = MagicMock(spec=mem_info.MemInfoVariable) + self.meta_var.var_name = "meta_var" + self.meta_var.hbm_address = 4 + + # Configure mock mem_info + self.mock_mem_info.inputs = [self.input_var] + self.mock_mem_info.outputs = [self.output_var] + self.mock_mem_info.keygens = [self.keygen_var] + + # Configure metadata + mock_metadata = MagicMock() + mock_metadata.intt_auxiliary_table = [self.meta_var] + mock_metadata.intt_routing_table = [] + mock_metadata.ntt_auxiliary_table = [] + mock_metadata.ntt_routing_table = [] + mock_metadata.ones = [] + mock_metadata.twiddle = [] + mock_metadata.keygen_seeds = [] + self.mock_mem_info.metadata = mock_metadata + + # Create the memory model + self.memory_model = MemoryModel(10, self.mock_mem_info) + + def test_init(self): + """@brief Test initialization of MemoryModel. + + @test Verifies that MemoryModel is properly initialized + """ + self.assertIsInstance(self.memory_model.hbm, HBM) + self.assertEqual(self.memory_model.hbm.capacity, 10) + + # Check that variables are correctly initialized + self.assertEqual(len(self.memory_model.variables), 0) + + # Check mem_info_vars + self.assertIn("input_var", self.memory_model.mem_info_vars) + self.assertIn("output_var", self.memory_model.mem_info_vars) + self.assertIn("meta_var", self.memory_model.mem_info_vars) + self.assertNotIn("keygen_var", self.memory_model.mem_info_vars) + + # Check mem_info_meta + self.assertIn("meta_var", self.memory_model.mem_info_meta) + + def test_add_variable_new(self): + """@brief Test adding a new variable. + + @test Verifies that a new variable is correctly added to the model + """ + self.memory_model.add_variable("test_var") + + # Check that variable was added + self.assertIn("test_var", self.memory_model.variables) + var_info = self.memory_model.variables["test_var"] + self.assertEqual(var_info.var_name, "test_var") + self.assertEqual(var_info.uses, 1) + + # Since it's not in mem_info_vars, it should not have an HBM address yet + self.assertEqual(var_info.hbm_address, -1) + + def test_add_variable_existing(self): + """@brief Test adding an existing variable. + + @test Verifies that the uses count is incremented for an existing variable + """ + # Add the variable first + self.memory_model.add_variable("test_var") + + # Add it again + self.memory_model.add_variable("test_var") + + # Check that the uses were incremented + var_info = self.memory_model.variables["test_var"] + self.assertEqual(var_info.uses, 2) + + def test_add_variable_from_mem_info(self): + """@brief Test adding a variable that's in mem_info. + + @test Verifies that a variable from mem_info is correctly added with its HBM address + """ + self.memory_model.add_variable("input_var") + + # Check that variable was added + self.assertIn("input_var", self.memory_model.variables) + var_info = self.memory_model.variables["input_var"] + self.assertEqual(var_info.var_name, "input_var") + self.assertEqual(var_info.uses, 1) + + # It should have the HBM address from mem_info + self.assertEqual(var_info.hbm_address, 1) + + def test_add_variable_from_fixed_addr_vars(self): + """@brief Test adding a variable that's in fixed_addr_vars. + + @test Verifies that a fixed-address variable is added with infinite uses + """ + self.memory_model.add_variable("output_var") + + # Check that variable was added + self.assertIn("output_var", self.memory_model.variables) + var_info = self.memory_model.variables["output_var"] + + # It should have infinite uses (float('inf')) + self.assertEqual(var_info.uses, float("inf") + 1) + + # It should have the HBM address from mem_info + self.assertEqual(var_info.hbm_address, 2) + + def test_use_variable(self): + """@brief Test using a variable. + + @test Verifies that using a variable decrements its uses count and allocates an HBM address + """ + # Add the variable first + self.memory_model.add_variable("test_var") + + # Use the variable + hbm_address = self.memory_model.use_variable("test_var", 1) + + # Check that uses were decremented + var_info = self.memory_model.variables["test_var"] + self.assertEqual(var_info.uses, 0) + + # Check that last_kernel_used was updated + self.assertEqual(var_info.last_kernel_used, 1) + + # Check that hbm_address was allocated and returned + self.assertGreaterEqual(hbm_address, 0) + self.assertEqual(var_info.hbm_address, hbm_address) + + # Check that the variable is in the HBM buffer + self.assertEqual(self.memory_model.hbm.buffer[hbm_address], var_info) + + def test_use_variable_already_allocated(self): + """@brief Test using a variable that already has an HBM address. + + @test Verifies that the existing HBM address is returned + """ + # Add a variable from mem_info which already has an HBM address + self.memory_model.add_variable("input_var") + + # Use the variable + hbm_address = self.memory_model.use_variable("input_var", 1) + + # Check that the returned HBM address is the one from mem_info + self.assertEqual(hbm_address, 1) + + +if __name__ == "__main__": + unittest.main() diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/__init__.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/__init__.py new file mode 100644 index 00000000..58bdd8eb --- /dev/null +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/__init__.py @@ -0,0 +1,4 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +# Empty init file to make the directory a Python package diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/__init__.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/__init__.py new file mode 100644 index 00000000..58bdd8eb --- /dev/null +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/__init__.py @@ -0,0 +1,4 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +# Empty init file to make the directory a Python package diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_dinstruction.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_dinstruction.py new file mode 100644 index 00000000..a6c2cf0a --- /dev/null +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_dinstruction.py @@ -0,0 +1,119 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +""" +@brief Unit tests for the DInstruction base class. + +This module tests the core functionality of the DInstruction class which +serves as the base for all data instructions. +""" + +import unittest + +from linker.instructions.dinst.dinstruction import DInstruction + + +class TestDInstruction(unittest.TestCase): + """ + @brief Test cases for the DInstruction base class. + + @details These tests verify the common functionality shared by all data instructions, + including token handling, ID generation, and property access. + """ + + def setUp(self): + # Create a concrete subclass for testing since DInstruction is abstract + class ConcreteDInstruction(DInstruction): + """ + @brief Concrete implementation of DInstruction for testing purposes. + + @details This class provides implementations of the abstract methods + required to instantiate and test the DInstruction class. + """ + + @classmethod + def _get_num_tokens(cls) -> int: + return 3 + + @classmethod + def _get_name(cls) -> str: + return "test_instruction" + + self.d_instruction_class = ConcreteDInstruction # Changed to snake_case + self.tokens = ["test_instruction", "var1", "123"] + self.comment = "Test comment" + self.dinst = self.d_instruction_class(self.tokens, self.comment) + + def test_get_name_token_index(self): + """@brief Test _get_name_token_index returns 0 + + @test Verifies the name token is at index 0 + """ + self.assertEqual( + self.d_instruction_class.name_token_index, 0 + ) # Updated reference + + def test_num_tokens_property(self): + """@brief Test num_tokens property returns expected value + + @test Verifies the num_tokens property returns the value from _get_num_tokens + """ + self.assertEqual(self.d_instruction_class.num_tokens, 3) # Updated reference + + def test_initialization_valid_tokens(self): + """@brief Test initialization with valid tokens + + @test Verifies an instance can be created with valid tokens and properties are set correctly + """ + inst = self.d_instruction_class(self.tokens, self.comment) + self.assertEqual(inst.tokens, self.tokens) + self.assertEqual(inst.comment, self.comment) + self.assertIsNotNone(inst.id) + + def test_initialization_token_count_too_few(self): + """@brief Test initialization with too few tokens + + @test Verifies ValueError is raised when too few tokens are provided + """ + with self.assertRaises(ValueError): + self.d_instruction_class(["test_instruction", "var1"]) + + def test_initialization_invalid_name(self): + """@brief Test initialization with invalid name token + + @test Verifies ValueError is raised when an invalid instruction name is provided + """ + with self.assertRaises(ValueError): + self.d_instruction_class(["wrong_name", "var1", "123"]) + + def test_id_property(self): + """@brief Test id property returns a unique id + + @test Verifies each instruction instance gets a unique ID + """ + inst1 = self.d_instruction_class(self.tokens) + inst2 = self.d_instruction_class(self.tokens) + self.assertNotEqual(inst1.id, inst2.id) + + def test_to_line_method(self): + """@brief Test to_line method returns expected string + + @test Verifies the to_line method correctly formats the instruction as a string + """ + tokens = ["test_instruction", "var1", "123"] + inst = self.d_instruction_class(tokens, "") + expected = "test_instruction, var1, 123" + self.assertEqual(inst.to_line(), expected) + + def test_consecutive_ids(self): + """@brief Test that consecutive instructions get incremental ids + + @test Verifies IDs are incremented sequentially for new instances + """ + inst1 = self.d_instruction_class(self.tokens) + inst2 = self.d_instruction_class(self.tokens) + self.assertEqual(inst2.id, inst1.id + 1) + + +if __name__ == "__main__": + unittest.main() diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_dkeygen.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_dkeygen.py new file mode 100644 index 00000000..f9769657 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_dkeygen.py @@ -0,0 +1,143 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +""" +@brief Unit tests for the DKeygen instruction class. + +This module tests the functionality of the DKeygen instruction which is +responsible for key generation operations. +""" + +import unittest +from unittest.mock import patch + +from assembler.memory_model.mem_info import MemInfo +from linker.instructions.dinst.dkeygen import Instruction + + +class TestDKeygenInstruction(unittest.TestCase): + """ + @brief Test cases for the DKeygen instruction class. + + @details These tests verify that the DKeygen instruction correctly handles token + parsing, name resolution, and serialization. + """ + + def setUp(self): + # Create the instruction with sample parameters + self.seed_idx = 1 + self.key_idx = 2 + self.var_name = "var1" + self.inst = Instruction( + [Instruction.name, self.seed_idx, self.key_idx, self.var_name] + ) + + def test_get_num_tokens(self): + """@brief Test that _get_num_tokens returns 4 + + @test Verifies the instruction requires exactly 4 tokens + """ + self.assertEqual(Instruction.num_tokens, 4) + + def test_get_name(self): + """@brief Test that _get_name returns the expected value + + @test Verifies the instruction name matches the MemInfo constant + """ + self.assertEqual(Instruction.name, MemInfo.Const.Keyword.KEYGEN) + + def test_initialization_valid_input(self): + """@brief Test that initialization can set up the correct properties with valid name + + @test Verifies the instruction is properly initialized with valid tokens + """ + inst = Instruction( + [MemInfo.Const.Keyword.KEYGEN, self.seed_idx, self.key_idx, self.var_name] + ) + self.assertEqual(inst.name, MemInfo.Const.Keyword.KEYGEN) + + def test_initialization_invalid_name(self): + """@brief Test that initialization raises exception with invalid name + + @test Verifies ValueError is raised when an invalid instruction name is provided + """ + with self.assertRaises(ValueError): # Adjust exception type if needed + Instruction(["invalid_name", self.seed_idx, self.key_idx, self.var_name]) + + def test_tokens_property(self): + """@brief Test that tokens property returns the correct list + + @test Verifies the tokens property correctly formats the instruction tokens + """ + # Since tokens property implementation is not visible in the dkeygen.py file, + # this test assumes default behavior from parent class or basic functionality + expected_tokens = [ + MemInfo.Const.Keyword.KEYGEN, + self.seed_idx, + self.key_idx, + self.var_name, + ] + self.assertEqual(self.inst.tokens[:4], expected_tokens) + + def test_tokens_with_additional_data(self): + """@brief Test tokens property with additional tokens + + @test Verifies extra tokens are preserved in the tokens property + """ + additional_token = "extra" + inst_with_extra = Instruction( + [ + Instruction.name, + self.seed_idx, + self.key_idx, + self.var_name, + additional_token, + ] + ) + # If tokens property uses default implementation, it should include the additional token + self.assertIn(additional_token, inst_with_extra.tokens) + + @patch( + "linker.instructions.dinst.dinstruction.DInstruction.__init__", + return_value=None, + ) + def test_inheritance(self, mock_init): + """@brief Test that Instruction properly extends DInstruction + + @test Verifies the parent constructor is called during initialization + """ + # Ensure that DInstruction methods are called as expected + Instruction([Instruction.name, self.seed_idx, self.key_idx, self.var_name]) + # Verify DInstruction.__init__ was called + mock_init.assert_called() + + def test_invalid_token_count_too_few(self): + """@brief Test behavior when fewer tokens than required are provided + + @test Verifies ValueError is raised when too few tokens are provided + """ + with self.assertRaises(ValueError): # Adjust exception type if needed + Instruction([MemInfo.Const.Keyword.KEYGEN, self.seed_idx, self.key_idx]) + + def test_invalid_token_count_too_many(self): + """@brief Test behavior when more tokens than required are provided + + @test Verifies extra tokens are handled gracefully without errors + """ + # This should not raise an error as additional tokens are handled + inst = Instruction( + [ + MemInfo.Const.Keyword.KEYGEN, + self.seed_idx, + self.key_idx, + self.var_name, + "extra1", + "extra2", + ] + ) + # Check that basic properties are still set correctly + self.assertEqual(inst.name, MemInfo.Const.Keyword.KEYGEN) + + +if __name__ == "__main__": + unittest.main() diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_dload.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_dload.py new file mode 100644 index 00000000..c6168790 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_dload.py @@ -0,0 +1,162 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +""" +@brief Unit tests for the DLoad instruction class. + +This module tests the functionality of the DLoad instruction which is +responsible for loading data from memory locations. +""" + +import unittest +from unittest.mock import patch + +from assembler.memory_model.mem_info import MemInfo +from linker.instructions.dinst.dload import Instruction + + +class TestDLoadInstruction(unittest.TestCase): + """ + @brief Test cases for the DLoad instruction class. + + @details These tests verify that the DLoad instruction correctly handles token + parsing, name resolution, and serialization. + """ + + def setUp(self): + # Create the instruction + self.var_name = "test_var" + self.address = 123 + self.type = "type1" + + def test_get_num_tokens(self): + """@brief Test that _get_num_tokens returns 3 + + @test Verifies the instruction requires exactly 3 tokens + """ + self.assertEqual(Instruction.num_tokens, 3) + + def test_get_name(self): + """@brief Test that _get_name returns the expected value + + @test Verifies the instruction name matches the MemInfo constant + """ + self.assertEqual(Instruction.name, MemInfo.Const.Keyword.LOAD) + + def test_initialization_valid_input(self): + """@brief Test that initialization can set up the correct properties with valid name + + @test Verifies the instruction is properly initialized with valid tokens + """ + inst = Instruction( + [MemInfo.Const.Keyword.LOAD, self.type, str(self.address), self.var_name] + ) + + self.assertEqual(inst.name, MemInfo.Const.Keyword.LOAD) + + def test_initialization_valid_meta(self): + """@brief Test that initialization can set up the correct properties with metadata + + @test Verifies the instruction handles metadata loading correctly + """ + inst = Instruction([MemInfo.Const.Keyword.LOAD, self.type, str(self.address)]) + + self.assertEqual(inst.name, MemInfo.Const.Keyword.LOAD) + + def test_initialization_invalid_name(self): + """@brief Test that initialization raises exception with invalid name + + @test Verifies ValueError is raised when an invalid instruction name is provided + """ + with self.assertRaises(ValueError): # Adjust exception type if needed + Instruction(["invalid_name", self.type, str(self.address), self.var_name]) + + def test_tokens_property(self): + """@brief Test that tokens property returns the correct list + + @test Verifies the tokens property correctly formats the instruction tokens + """ + expected_tokens = [ + MemInfo.Const.Keyword.LOAD, + self.type, + str(self.address), + self.var_name, + ] + inst = Instruction( + [Instruction.name, self.type, str(self.address), self.var_name] + ) + + # Manually set properties to match expected behavior + inst.address = self.address + self.assertEqual(inst.tokens, expected_tokens) + + def test_tokens_with_additional_data(self): + """@brief Test tokens property with additional tokens + + @test Verifies extra tokens are preserved in the tokens property + """ + additional_token = "extra" + inst_with_extra = Instruction( + [ + Instruction.name, + self.type, + str(self.address), + self.var_name, + additional_token, + ] + ) + inst_with_extra.address = self.address + expected_tokens = [ + MemInfo.Const.Keyword.LOAD, + self.type, + str(self.address), + self.var_name, + additional_token, + ] + self.assertEqual(inst_with_extra.tokens, expected_tokens) + + @patch( + "linker.instructions.dinst.dinstruction.DInstruction.__init__", + return_value=None, + ) + def test_inheritance(self, mock_init): + """@brief Test that Instruction properly extends DInstruction + + @test Verifies the parent constructor is called during initialization + """ + # Ensure that DInstruction methods are called as expected + Instruction([Instruction.name, self.type, str(self.address), self.var_name]) + # Verify DInstruction.__init__ was called + mock_init.assert_called() + + def test_invalid_token_count_too_few(self): + """@brief Test behavior when fewer tokens than required are provided + + @test Verifies ValueError is raised when too few tokens are provided + """ + with self.assertRaises(ValueError): # Adjust exception type if needed + Instruction([MemInfo.Const.Keyword.LOAD, self.var_name]) + + def test_invalid_token_count_too_many(self): + """@brief Test behavior when more tokens than required are provided + + @test Verifies extra tokens are handled gracefully without errors + """ + # This should not raise an error as additional tokens are handled + inst = Instruction( + [ + MemInfo.Const.Keyword.LOAD, + self.type, + str(self.address), + self.var_name, + "extra1", + "extra2", + ] + ) + + # Check that basic properties are still set correctly + self.assertEqual(inst.name, MemInfo.Const.Keyword.LOAD) + + +if __name__ == "__main__": + unittest.main() diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_dstore.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_dstore.py new file mode 100644 index 00000000..938a9db9 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_dstore.py @@ -0,0 +1,149 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +""" +@brief Unit tests for the DStore instruction class. + +This module tests the functionality of the DStore instruction which is +responsible for storing data to memory locations. +""" + +import unittest +from unittest.mock import patch + +from assembler.memory_model.mem_info import MemInfo +from linker.instructions.dinst.dstore import Instruction + + +class TestDStoreInstruction(unittest.TestCase): + """ + @brief Test cases for the DStore instruction class. + + @details These tests verify that the DStore instruction correctly handles token + parsing, name resolution, and serialization. + """ + + def setUp(self): + # Create the instruction + self.var_name = "test_var" + self.address = 123 + + def test_get_num_tokens(self): + """@brief Test that _get_num_tokens returns 3 + + @test Verifies the instruction requires exactly 3 tokens + """ + self.assertEqual(Instruction.num_tokens, 3) + + def test_get_name(self): + """@brief Test that _get_name returns the expected value + + @test Verifies the instruction name matches the MemInfo constant + """ + self.assertEqual(Instruction.name, MemInfo.Const.Keyword.STORE) + + def test_initialization_valid_input(self): + """@brief Test that initialization can set up the correct properties with valid name + + @test Verifies the instruction is properly initialized with valid tokens + """ + inst = Instruction( + [MemInfo.Const.Keyword.STORE, self.var_name, str(self.address)] + ) + + self.assertEqual(inst.name, MemInfo.Const.Keyword.STORE) + + def test_initialization_invalid_name(self): + """@brief Test that initialization raises exception with invalid name + + @test Verifies ValueError is raised when an invalid instruction name is provided + """ + with self.assertRaises(ValueError): # Adjust exception type if needed + Instruction(["invalid_name", self.var_name, str(self.address)]) + + def test_tokens_property(self): + """@brief Test that tokens property returns the correct list + + @test Verifies the tokens property correctly formats the instruction tokens + """ + expected_tokens = [ + MemInfo.Const.Keyword.STORE, + self.var_name, + str(self.address), + ] + inst = Instruction([Instruction.name, self.var_name, str(self.address)]) + + # Manually set properties to match expected behavior + inst.var = self.var_name + inst.address = self.address + + self.assertEqual(inst.tokens, expected_tokens) + + def test_tokens_with_additional_data(self): + """@brief Test tokens property with additional tokens + + @test Verifies extra tokens are preserved in the tokens property + """ + additional_token = "extra" + inst_with_extra = Instruction( + [ + Instruction.name, + self.var_name, + str(self.address), + additional_token, + ] + ) + inst_with_extra.var = self.var_name + inst_with_extra.address = self.address + expected_tokens = [ + MemInfo.Const.Keyword.STORE, + self.var_name, + str(self.address), + additional_token, + ] + self.assertEqual(inst_with_extra.tokens, expected_tokens) + + @patch( + "linker.instructions.dinst.dinstruction.DInstruction.__init__", + return_value=None, + ) + def test_inheritance(self, mock_init): + """@brief Test that Instruction properly extends DInstruction + + @test Verifies the parent constructor is called during initialization + """ + # Ensure that DInstruction methods are called as expected + Instruction([Instruction.name, self.var_name, str(self.address)]) + # Verify DInstruction.__init__ was called + mock_init.assert_called() + + def test_invalid_token_count_too_few(self): + """@brief Test behavior when fewer tokens than required are provided + + @test Verifies ValueError is raised when too few tokens are provided + """ + with self.assertRaises(ValueError): # Adjust exception type if needed + Instruction([MemInfo.Const.Keyword.STORE]) + + def test_invalid_token_count_too_many(self): + """@brief Test behavior when more tokens than required are provided + + @test Verifies extra tokens are handled gracefully without errors + """ + # This should not raise an error as additional tokens are handled + inst = Instruction( + [ + MemInfo.Const.Keyword.STORE, + self.var_name, + str(self.address), + "extra1", + "extra2", + ] + ) + + # Check that basic properties are still set correctly + self.assertEqual(inst.name, MemInfo.Const.Keyword.STORE) + + +if __name__ == "__main__": + unittest.main() diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_init.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_init.py new file mode 100644 index 00000000..26a9a799 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_dinst/test_init.py @@ -0,0 +1,186 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +""" +@brief Unit tests for the dinst package initialization module. + +This module tests the factory functions and initialization utilities for +data instructions. +""" + +import unittest +from unittest.mock import patch, MagicMock + +from linker.instructions.dinst import factory, create_from_mem_line +from linker.instructions.dinst import DLoad, DStore, DKeyGen + + +class TestDInstModule(unittest.TestCase): + """ + @brief Test cases for data instruction initialization. + + @details These tests verify that the data instruction factory correctly creates + instruction instances and properly handles initialization errors. + """ + + def test_factory(self): + """@brief Test that factory returns the expected set of instruction classes + + @test Verifies the factory returns a set containing DLoad, DStore, and DKeyGen + """ + instruction_set = factory() + self.assertIsInstance(instruction_set, set) + self.assertEqual(len(instruction_set), 3) + self.assertIn(DLoad, instruction_set) + self.assertIn(DStore, instruction_set) + self.assertIn(DKeyGen, instruction_set) + + @patch("assembler.instructions.tokenize_from_line") + @patch("assembler.memory_model.mem_info.MemInfo.get_meminfo_var_from_tokens") + def test_create_from_mem_line_dload_input(self, mock_get_meminfo, mock_tokenize): + """@brief Test create_from_mem_line creates DLoad instruction + + @test Verifies that a DLoad instruction is created with correct properties + """ + # Setup mocks + tokens = ["dload", "poly", "0x123", "var1"] + comment = "Test comment" + mock_tokenize.return_value = (tokens, comment) + + # Setup MemInfo mock + miv_mock = MagicMock() + miv_mock.as_dict.return_value = {"var_name": "var1", "hbm_address": 0x123} + mock_get_meminfo.return_value = (miv_mock, None) + + # Call function under test + result = create_from_mem_line("dload, poly, 0x123, var1 # Test comment") + + # Verify results + self.assertIsNotNone(result) + self.assertIsInstance(result, DLoad) + self.assertEqual(result.var, "var1") + self.assertEqual(result.address, 0x123) + + @patch("assembler.instructions.tokenize_from_line") + @patch("assembler.memory_model.mem_info.MemInfo.get_meminfo_var_from_tokens") + def test_create_from_mem_line_dload_meta(self, mock_get_meminfo, mock_tokenize): + """@brief Test create_from_mem_line creates DLoad instruction for metadata + + @test Verifies that a DLoad instruction is created for metadata entries + """ + # Setup mocks + tokens = ["dload", "meta", "1"] + comment = "Test comment" + mock_tokenize.return_value = (tokens, comment) + + # Setup MemInfo mock + miv_mock = MagicMock() + miv_mock.as_dict.return_value = {"var_name": "meta1", "hbm_address": 1} + mock_get_meminfo.return_value = (miv_mock, None) + + # Call function under test + result = create_from_mem_line("dload, meta, 1 # Test comment") + + # Verify results + self.assertIsNotNone(result) + self.assertIsInstance(result, DLoad) + self.assertEqual(result.var, "meta1") + self.assertEqual(result.address, 1) + + @patch("assembler.instructions.tokenize_from_line") + @patch("assembler.memory_model.mem_info.MemInfo.get_meminfo_var_from_tokens") + def test_create_from_mem_line_dstore(self, mock_get_meminfo, mock_tokenize): + """@brief Test create_from_mem_line creates DStore instruction + + @test Verifies that a DStore instruction is created with correct properties + """ + # Setup mocks + tokens = ["dstore", "var1", "0x456"] + comment = "Test comment" + mock_tokenize.return_value = (tokens, comment) + + # Setup MemInfo mock + miv_mock = MagicMock() + miv_mock.as_dict.return_value = {"var_name": "var1", "hbm_address": 0x456} + mock_get_meminfo.return_value = (miv_mock, None) + + # Call function under test + result = create_from_mem_line("dstore, var1, 0x456 # Test comment") + + # Verify results + self.assertIsNotNone(result) + self.assertIsInstance(result, DStore) + self.assertEqual(result.var, "var1") + self.assertEqual(result.address, 0x456) + + @patch("assembler.instructions.tokenize_from_line") + @patch("assembler.memory_model.mem_info.MemInfo.get_meminfo_var_from_tokens") + def test_create_from_mem_line_dkeygen(self, mock_get_meminfo, mock_tokenize): + """@brief Test create_from_mem_line creates DKeyGen instruction + + @test Verifies that a DKeyGen instruction is created with correct properties + """ + # Setup mocks + tokens = ["keygen", "key1", "type1", "256"] + comment = "Test comment" + mock_tokenize.return_value = (tokens, comment) + + # Setup MemInfo mock + miv_mock = MagicMock() + miv_mock.as_dict.return_value = {"var_name": "key1", "hbm_address": 0x0} + mock_get_meminfo.return_value = (miv_mock, None) + + # Call function under test + result = create_from_mem_line("keygen, key1, type1, 256 # Test comment") + + # Verify results + self.assertIsNotNone(result) + self.assertIsInstance(result, DKeyGen) + # Verify var and address were set correctly + self.assertEqual(result.var, "key1") + self.assertEqual(result.address, 0x0) + + @patch("assembler.instructions.tokenize_from_line") + @patch("assembler.memory_model.mem_info.MemInfo.get_meminfo_var_from_tokens") + def test_create_from_mem_line_invalid(self, mock_get_meminfo, mock_tokenize): + """@brief Test create_from_mem_line with invalid instruction + + @test Verifies that RuntimeError is raised for invalid instructions + """ + # Setup mocks to return invalid tokens + tokens = ["invalid_instruction", "var1", "0x123"] + comment = "" + mock_tokenize.return_value = (tokens, comment) + + # Make get_meminfo_var_from_tokens raise RuntimeError + mock_get_meminfo.side_effect = RuntimeError("Invalid instruction") + + # This should raise RuntimeError due to no valid instruction found + with self.assertRaises(RuntimeError): + create_from_mem_line("invalid_instruction, var1, 0x123") + + @patch("assembler.instructions.tokenize_from_line") + @patch("assembler.memory_model.mem_info.MemInfo.get_meminfo_var_from_tokens") + def test_create_from_mem_line_meminfo_error(self, mock_get_meminfo, mock_tokenize): + """@brief Test create_from_mem_line with MemInfo error + + @test Verifies that RuntimeError is wrapped with line information + """ + # Setup mocks + tokens = ["dstore", "var1", "0x123"] + comment = "" + mock_tokenize.return_value = (tokens, comment) + + # Make get_meminfo_var_from_tokens raise RuntimeError + mock_get_meminfo.side_effect = RuntimeError("Test error") + + # This should wrap the RuntimeError with information about the line + with self.assertRaises(RuntimeError) as context: + create_from_mem_line("dstore, var1, 0x123") + + # Verify the error message contains the original line + self.assertIn("dstore, var1, 0x123", str(context.exception)) + + +if __name__ == "__main__": + unittest.main() diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_init.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_init.py new file mode 100644 index 00000000..a0e1f1f0 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_init.py @@ -0,0 +1,144 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +# These contents may have been developed with support from one or more Intel-operated +# generative artificial intelligence solutions + +""" +@brief Unit tests for the linker instructions initialization module. + +This module contains tests that verify the behavior of the instruction factory +and initialization functionality. +""" + +import unittest +from unittest.mock import patch, MagicMock + +from linker.instructions import create_from_str_line + + +class TestCreateFromStrLine(unittest.TestCase): + """ + @brief Test cases for instruction initialization functionality. + + These tests verify that instructions are correctly initialized, + their tokens are properly processed, and their factories work as expected. + """ + + def setUp(self): + # Create a mock class (not instance) + self.mock_class = MagicMock() + + # Create a mock instance that will be returned when the class is called + self.mock_instance = MagicMock() + self.mock_instance.__bool__.return_value = True + + # Configure the class to return the instance when called + self.mock_class.return_value = self.mock_instance + + # Create a factory with the mock class + self.factory = {self.mock_class} + + @patch("linker.instructions.tokenize_from_line") + def test_create_from_str_line_success(self, mock_tokenize): + """ + @brief Test successful instruction creation + + @test Verifies that an instruction is correctly created from a string line + when a valid factory is provided + """ + # Setup mock + tokens = ["instruction", "arg1", "arg2"] + comment = "Test comment" + mock_tokenize.return_value = (tokens, comment) + + # Call function + result = create_from_str_line( + "instruction, arg1, arg2 # Test comment", self.factory + ) + + # Verify + mock_tokenize.assert_called_once_with("instruction, arg1, arg2 # Test comment") + self.mock_class.assert_called_once_with(tokens, comment) + self.assertEqual(result, self.mock_instance) + + @patch("linker.instructions.tokenize_from_line") + def test_create_from_str_line_failure(self, mock_tokenize): + """ + @brief Test when no instruction can be created + + @test Verifies that None is returned when instruction creation fails + """ + # Setup mock + tokens = ["unknown", "arg1", "arg2"] + comment = "Test comment" + mock_tokenize.return_value = (tokens, comment) + + # Make instruction creation fail + self.mock_class.side_effect = ValueError("Invalid instruction") + + # Call function + result = create_from_str_line( + "unknown, arg1, arg2 # Test comment", self.factory + ) + + # Verify + self.assertIsNone(result) + + @patch("linker.instructions.tokenize_from_line") + def test_create_from_str_line_multiple_instruction_types(self, mock_tokenize): + """ + @brief Test with multiple instruction types in factory + + @test Verifies that the function tries each instruction type in the factory + until one succeeds + """ + # Setup mocks + tokens = ["instruction", "arg1", "arg2"] + comment = "Test comment" + mock_tokenize.return_value = (tokens, comment) + + # Create a second mock instruction class that fails + mock_class2 = MagicMock() + mock_class2.side_effect = ValueError("Invalid instruction") + + # Set up the factory with a specific order - use a list instead of a set + # to control the iteration order + factory = [mock_class2, self.mock_class] + + # Call function + result = create_from_str_line("instruction, arg1, arg2 # Test comment", factory) + + # Verify that it tried both instruction types and returned the successful one + mock_class2.assert_called_once_with(tokens, comment) + self.mock_class.assert_called_once_with(tokens, comment) + mock_class2.assert_called_once() + self.assertEqual(result, self.mock_instance) + + @patch("linker.instructions.tokenize_from_line") + def test_create_from_str_line_exception_handling(self, mock_tokenize): + """ + @brief Test that specific exceptions are caught + + @test Verifies that expected exceptions during instruction creation are + handled gracefully and None is returned + """ + # Setup mock + tokens = ["instruction", "arg1", "arg2"] + comment = "Test comment" + mock_tokenize.return_value = (tokens, comment) + + # Make instruction creation raise one of the expected exception types + self.mock_class.side_effect = ValueError("Invalid values") + + # Call function - should handle the exception and return None + result = create_from_str_line( + "instruction, arg1, arg2 # Test comment", self.factory + ) + + # Verify + self.assertIsNone(result) + + +if __name__ == "__main__": + unittest.main() diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_instruction.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_instruction.py new file mode 100644 index 00000000..b5842745 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_instructions/test_instruction.py @@ -0,0 +1,171 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +# These contents may have been developed with support from one or more Intel-operated +# generative artificial intelligence solutions + +""" +@brief Unit tests for the BaseInstruction class. +""" + +import os +import unittest +import tempfile +from unittest.mock import patch + +from assembler.common.config import GlobalConfig +from linker.instructions.instruction import BaseInstruction + + +class MockInstruction(BaseInstruction): + """@brief Concrete implementation of BaseInstruction for testing.""" + + @classmethod + def _get_name(cls) -> str: + return "TEST" + + @classmethod + def _get_name_token_index(cls) -> int: + return 0 + + @classmethod + def _get_num_tokens(cls) -> int: + return 3 + + +class TestBaseInstruction(unittest.TestCase): + """@brief Tests for the BaseInstruction class.""" + + def setUp(self): + """@brief Setup for tests.""" + self.valid_tokens = ["TEST", "arg1", "arg2"] + self.comment = "This is a test comment" + + def test_init_valid(self): + """@brief Test initialization with valid tokens. + + @test Verifies that an instruction can be correctly initialized with valid tokens + """ + instruction = MockInstruction(self.valid_tokens, self.comment) + self.assertEqual(instruction.tokens, self.valid_tokens) + self.assertEqual(instruction.comment, self.comment) + + def test_init_invalid_name(self): + """@brief Test initialization with invalid instruction name. + + @test Verifies that a ValueError is raised when the instruction name is invalid + """ + invalid_tokens = ["WRONG", "arg1", "arg2"] + with self.assertRaises(ValueError) as context: + MockInstruction(invalid_tokens) + self.assertIn("invalid name", str(context.exception)) + + def test_init_invalid_num_tokens(self): + """@brief Test initialization with incorrect number of tokens. + + @test Verifies that a ValueError is raised when the number of tokens is incorrect + """ + invalid_tokens = ["TEST", "arg1"] + with self.assertRaises(ValueError) as context: + MockInstruction(invalid_tokens) + self.assertIn("invalid amount of tokens", str(context.exception)) + + def test_id_generation(self): + """@brief Test that each instruction gets a unique ID. + + @test Verifies that different instruction instances have different IDs + """ + instruction1 = MockInstruction(self.valid_tokens) + instruction2 = MockInstruction(self.valid_tokens) + self.assertNotEqual(instruction1.id, instruction2.id) + + def test_str_representation(self): + """@brief Test string representation. + + @test Verifies that the string representation is correctly formatted + """ + instruction = MockInstruction(self.valid_tokens) + self.assertEqual(str(instruction), f"TEST({instruction.id})") + + def test_repr_representation(self): + """@brief Test repr representation. + + @test Verifies that the repr representation contains the expected information + """ + instruction = MockInstruction(self.valid_tokens) + self.assertIn("MockInstruction(TEST, id=", repr(instruction)) + self.assertIn("tokens=", repr(instruction)) + + def test_equality(self): + """@brief Test equality operator. + + @test Verifies that equality is based on object identity rather than value + """ + instruction1 = MockInstruction(self.valid_tokens) + instruction2 = MockInstruction(self.valid_tokens) + self.assertNotEqual(instruction1, instruction2) + self.assertEqual(instruction1, instruction1) + + def test_hash(self): + """@brief Test hash function. + + @test Verifies that the hash is based on the instruction's ID + """ + instruction = MockInstruction(self.valid_tokens) + self.assertEqual(hash(instruction), hash(instruction.id)) + + def test_to_line_with_comment(self): + """@brief Test to_line method with comment. + + @test Verifies that to_line correctly includes the comment when present + """ + instruction = MockInstruction(self.valid_tokens, self.comment) + expected = f"TEST, arg1, arg2 # {self.comment}" + self.assertEqual(instruction.to_line(), expected) + + def test_to_line_without_comment(self): + """@brief Test to_line method without comment. + + @test Verifies that to_line works correctly when no comment is present + """ + instruction = MockInstruction(self.valid_tokens) + expected = "TEST, arg1, arg2" + self.assertEqual(instruction.to_line(), expected) + + def test_to_line_suppressed_comments(self): + """@brief Test to_line method with suppressed comments. + + @test Verifies that comments are not included when GlobalConfig.suppress_comments is True + """ + with patch.object(GlobalConfig, "suppress_comments", True): + instruction = MockInstruction(self.valid_tokens, self.comment) + expected = "TEST, arg1, arg2" + self.assertEqual(instruction.to_line(), expected) + + def test_dump_instructions_to_file(self): + """@brief Test dump_instructions_to_file method. + + @test Verifies that instructions are correctly written to a file + """ + instruction1 = MockInstruction(self.valid_tokens, "Comment 1") + instruction2 = MockInstruction(self.valid_tokens, "Comment 2") + instructions = [instruction1, instruction2] + + with tempfile.NamedTemporaryFile(delete=False) as temp_file: + file_path = temp_file.name + + try: + BaseInstruction.dump_instructions_to_file(instructions, file_path) + + with open(file_path, "r", encoding="utf-8") as f: + lines = f.read().splitlines() + + self.assertEqual(len(lines), 2) + self.assertEqual(lines[0], instruction1.to_line()) + self.assertEqual(lines[1], instruction2.to_line()) + finally: + os.unlink(file_path) + + +if __name__ == "__main__": + unittest.main() diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_loader.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_loader.py new file mode 100644 index 00000000..06a3827d --- /dev/null +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_loader.py @@ -0,0 +1,368 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +# These contents may have been developed with support from one or more Intel-operated +# generative artificial intelligence solutions + +""" +@brief Unit tests for the loader module. +""" + +import unittest +from unittest.mock import patch, mock_open, MagicMock, call + +from linker.loader import ( + load_minst_kernel, + load_minst_kernel_from_file, + load_cinst_kernel, + load_cinst_kernel_from_file, + load_xinst_kernel, + load_xinst_kernel_from_file, + load_dinst_kernel, + load_dinst_kernel_from_file, +) + + +class TestLoader(unittest.TestCase): + """@brief Tests for the loader module functions.""" + + def setUp(self): + """@brief Set up test fixtures.""" + # Sample instruction lines for each type + self.minst_lines = ["MINST arg1, arg2", "MINST arg3, arg4"] + self.cinst_lines = ["CINST arg1, arg2", "CINST arg3, arg4"] + self.xinst_lines = ["XINST arg1, arg2", "XINST arg3, arg4"] + self.dinst_lines = ["DINST arg1, arg2", "DINST arg3, arg4"] + + # Create mock instruction objects + self.mock_minst = [MagicMock(), MagicMock()] + self.mock_cinst = [MagicMock(), MagicMock()] + self.mock_xinst = [MagicMock(), MagicMock()] + self.mock_dinst = [MagicMock(), MagicMock()] + + @patch("linker.instructions.create_from_str_line") + @patch("linker.instructions.minst.factory") + def test_load_minst_kernel_success(self, mock_factory, mock_create): + """@brief Test successful loading of MInstructions from an iterator. + + @test Verifies that MInstructions are properly created from string lines + """ + # Configure mocks + mock_factory.return_value = "minst_factory" + mock_create.side_effect = self.mock_minst + + # Call the function + result = load_minst_kernel(self.minst_lines) + + # Verify the results + self.assertEqual(result, self.mock_minst) + # Factory is called once per line, so 2 times total + self.assertEqual(mock_factory.call_count, 2) + self.assertEqual(mock_create.call_count, 2) + mock_create.assert_has_calls( + [ + call(self.minst_lines[0], "minst_factory"), + call(self.minst_lines[1], "minst_factory"), + ] + ) + + @patch("linker.instructions.create_from_str_line") + @patch("linker.instructions.minst.factory") + def test_load_minst_kernel_failure(self, mock_factory, mock_create): + """@brief Test error handling when loading MInstructions fails. + + @test Verifies that a RuntimeError is raised when parsing fails + """ + # Configure mocks + mock_factory.return_value = "minst_factory" + mock_create.return_value = None + + # Call the function and check for exception + with self.assertRaises(RuntimeError) as context: + load_minst_kernel(self.minst_lines) + + self.assertIn( + f"Error parsing line 1: {self.minst_lines[0]}", str(context.exception) + ) + + @patch("builtins.open", new_callable=mock_open) + @patch("linker.loader.load_minst_kernel") + def test_load_minst_kernel_from_file_success(self, mock_load, mock_file): + """@brief Test successful loading of MInstructions from a file. + + @test Verifies that file contents are properly passed to load_minst_kernel + """ + # Configure mocks + mock_file.return_value.__enter__.return_value = self.minst_lines + mock_load.return_value = self.mock_minst + + # Call the function + result = load_minst_kernel_from_file("test.minst") + + # Verify the results + self.assertEqual(result, self.mock_minst) + mock_file.assert_called_once_with("test.minst", "r", encoding="utf-8") + mock_load.assert_called_once_with(self.minst_lines) + + @patch("builtins.open", new_callable=mock_open) + @patch("linker.loader.load_minst_kernel") + def test_load_minst_kernel_from_file_failure(self, mock_load, mock_file): + """@brief Test error handling when loading MInstructions from a file fails. + + @test Verifies that a RuntimeError is raised with appropriate message + """ + # Configure mocks + mock_file.return_value.__enter__.return_value = self.minst_lines + mock_load.side_effect = Exception("Test error") + + # Call the function and check for exception + with self.assertRaises(RuntimeError) as context: + load_minst_kernel_from_file("test.minst") + + self.assertIn( + 'Error occurred loading file "test.minst"', str(context.exception) + ) + + @patch("linker.instructions.create_from_str_line") + @patch("linker.instructions.cinst.factory") + def test_load_cinst_kernel_success(self, mock_factory, mock_create): + """@brief Test successful loading of CInstructions from an iterator. + + @test Verifies that CInstructions are properly created from string lines + """ + # Configure mocks + mock_factory.return_value = "cinst_factory" + mock_create.side_effect = self.mock_cinst + + # Call the function + result = load_cinst_kernel(self.cinst_lines) + + # Verify the results + self.assertEqual(result, self.mock_cinst) + # Factory is called once per line, so 2 times total + self.assertEqual(mock_factory.call_count, 2) + self.assertEqual(mock_create.call_count, 2) + mock_create.assert_has_calls( + [ + call(self.cinst_lines[0], "cinst_factory"), + call(self.cinst_lines[1], "cinst_factory"), + ] + ) + + @patch("linker.instructions.create_from_str_line") + @patch("linker.instructions.cinst.factory") + def test_load_cinst_kernel_failure(self, mock_factory, mock_create): + """@brief Test error handling when loading CInstructions fails. + + @test Verifies that a RuntimeError is raised when parsing fails + """ + # Configure mocks + mock_factory.return_value = "cinst_factory" + mock_create.return_value = None + + # Call the function and check for exception + with self.assertRaises(RuntimeError) as context: + load_cinst_kernel(self.cinst_lines) + + self.assertIn( + f"Error parsing line 1: {self.cinst_lines[0]}", str(context.exception) + ) + + @patch("builtins.open", new_callable=mock_open) + @patch("linker.loader.load_cinst_kernel") + def test_load_cinst_kernel_from_file_success(self, mock_load, mock_file): + """@brief Test successful loading of CInstructions from a file. + + @test Verifies that file contents are properly passed to load_cinst_kernel + """ + # Configure mocks + mock_file.return_value.__enter__.return_value = self.cinst_lines + mock_load.return_value = self.mock_cinst + + # Call the function + result = load_cinst_kernel_from_file("test.cinst") + + # Verify the results + self.assertEqual(result, self.mock_cinst) + mock_file.assert_called_once_with("test.cinst", "r", encoding="utf-8") + mock_load.assert_called_once_with(self.cinst_lines) + + @patch("builtins.open", new_callable=mock_open) + @patch("linker.loader.load_cinst_kernel") + def test_load_cinst_kernel_from_file_failure(self, mock_load, mock_file): + """@brief Test error handling when loading CInstructions from a file fails. + + @test Verifies that a RuntimeError is raised with appropriate message + """ + # Configure mocks + mock_file.return_value.__enter__.return_value = self.cinst_lines + mock_load.side_effect = Exception("Test error") + + # Call the function and check for exception + with self.assertRaises(RuntimeError) as context: + load_cinst_kernel_from_file("test.cinst") + + self.assertIn( + 'Error occurred loading file "test.cinst"', str(context.exception) + ) + + @patch("linker.instructions.create_from_str_line") + @patch("linker.instructions.xinst.factory") + def test_load_xinst_kernel_success(self, mock_factory, mock_create): + """@brief Test successful loading of XInstructions from an iterator. + + @test Verifies that XInstructions are properly created from string lines + """ + # Configure mocks + mock_factory.return_value = "xinst_factory" + mock_create.side_effect = self.mock_xinst + + # Call the function + result = load_xinst_kernel(self.xinst_lines) + + # Verify the results + self.assertEqual(result, self.mock_xinst) + # Factory is called once per line, so 2 times total + self.assertEqual(mock_factory.call_count, 2) + self.assertEqual(mock_create.call_count, 2) + mock_create.assert_has_calls( + [ + call(self.xinst_lines[0], "xinst_factory"), + call(self.xinst_lines[1], "xinst_factory"), + ] + ) + + @patch("linker.instructions.create_from_str_line") + @patch("linker.instructions.xinst.factory") + def test_load_xinst_kernel_failure(self, mock_factory, mock_create): + """@brief Test error handling when loading XInstructions fails. + + @test Verifies that a RuntimeError is raised when parsing fails + """ + # Configure mocks + mock_factory.return_value = "xinst_factory" + mock_create.return_value = None + + # Call the function and check for exception + with self.assertRaises(RuntimeError) as context: + load_xinst_kernel(self.xinst_lines) + + self.assertIn( + f"Error parsing line 1: {self.xinst_lines[0]}", str(context.exception) + ) + + @patch("builtins.open", new_callable=mock_open) + @patch("linker.loader.load_xinst_kernel") + def test_load_xinst_kernel_from_file_success(self, mock_load, mock_file): + """@brief Test successful loading of XInstructions from a file. + + @test Verifies that file contents are properly passed to load_xinst_kernel + """ + # Configure mocks + mock_file.return_value.__enter__.return_value = self.xinst_lines + mock_load.return_value = self.mock_xinst + + # Call the function + result = load_xinst_kernel_from_file("test.xinst") + + # Verify the results + self.assertEqual(result, self.mock_xinst) + mock_file.assert_called_once_with("test.xinst", "r", encoding="utf-8") + mock_load.assert_called_once_with(self.xinst_lines) + + @patch("builtins.open", new_callable=mock_open) + @patch("linker.loader.load_xinst_kernel") + def test_load_xinst_kernel_from_file_failure(self, mock_load, mock_file): + """@brief Test error handling when loading XInstructions from a file fails. + + @test Verifies that a RuntimeError is raised with appropriate message + """ + # Configure mocks + mock_file.return_value.__enter__.return_value = self.xinst_lines + mock_load.side_effect = Exception("Test error") + + # Call the function and check for exception + with self.assertRaises(RuntimeError) as context: + load_xinst_kernel_from_file("test.xinst") + + self.assertIn( + 'Error occurred loading file "test.xinst"', str(context.exception) + ) + + @patch("linker.instructions.dinst.create_from_mem_line") + def test_load_dinst_kernel_success(self, mock_create): + """@brief Test successful loading of DInstructions from an iterator. + + @test Verifies that DInstructions are properly created from string lines + """ + # Configure mocks + mock_create.side_effect = self.mock_dinst + + # Call the function + result = load_dinst_kernel(self.dinst_lines) + + # Verify the results + self.assertEqual(result, self.mock_dinst) + self.assertEqual(mock_create.call_count, 2) + mock_create.assert_has_calls( + [call(self.dinst_lines[0]), call(self.dinst_lines[1])] + ) + + @patch("linker.instructions.dinst.create_from_mem_line") + def test_load_dinst_kernel_failure(self, mock_create): + """@brief Test error handling when loading DInstructions fails. + + @test Verifies that a RuntimeError is raised when parsing fails + """ + # Configure mocks + mock_create.return_value = None + + # Call the function and check for exception + with self.assertRaises(RuntimeError) as context: + load_dinst_kernel(self.dinst_lines) + + self.assertIn( + f"Error parsing line 1: {self.dinst_lines[0]}", str(context.exception) + ) + + @patch("builtins.open", new_callable=mock_open) + @patch("linker.loader.load_dinst_kernel") + def test_load_dinst_kernel_from_file_success(self, mock_load, mock_file): + """@brief Test successful loading of DInstructions from a file. + + @test Verifies that file contents are properly passed to load_dinst_kernel + """ + # Configure mocks + mock_file.return_value.__enter__.return_value = self.dinst_lines + mock_load.return_value = self.mock_dinst + + # Call the function + result = load_dinst_kernel_from_file("test.dinst") + + # Verify the results + self.assertEqual(result, self.mock_dinst) + mock_file.assert_called_once_with("test.dinst", "r", encoding="utf-8") + mock_load.assert_called_once_with(self.dinst_lines) + + @patch("builtins.open", new_callable=mock_open) + @patch("linker.loader.load_dinst_kernel") + def test_load_dinst_kernel_from_file_failure(self, mock_load, mock_file): + """@brief Test error handling when loading DInstructions from a file fails. + + @test Verifies that a RuntimeError is raised with appropriate message + """ + # Configure mocks + mock_file.return_value.__enter__.return_value = self.dinst_lines + mock_load.side_effect = Exception("Test error") + + # Call the function and check for exception + with self.assertRaises(RuntimeError) as context: + load_dinst_kernel_from_file("test.dinst") + + self.assertIn( + 'Error occurred loading file "test.dinst"', str(context.exception) + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_steps/__init__.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_steps/__init__.py new file mode 100644 index 00000000..4057dc01 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_steps/__init__.py @@ -0,0 +1,2 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_steps/test_program_linker.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_steps/test_program_linker.py new file mode 100644 index 00000000..aa89d860 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_steps/test_program_linker.py @@ -0,0 +1,686 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +# These contents may have been developed with support from one or more Intel-operated +# generative artificial intelligence solutions + +""" +@brief Unit tests for the program_linker module. +""" + +import io +import unittest +from unittest.mock import patch, MagicMock, call + +from assembler.common.config import GlobalConfig +from linker import MemoryModel +from linker.instructions import minst, cinst, dinst +from linker.steps.program_linker import LinkedProgram + + +# pylint: disable=protected-access +class TestLinkedProgram(unittest.TestCase): + """@brief Tests for the LinkedProgram class.""" + + def setUp(self): + """@brief Set up test fixtures.""" + # Group related stream objects into a dictionary + self.streams = { + "minst": io.StringIO(), + "cinst": io.StringIO(), + "xinst": io.StringIO(), + } + self.mem_model = MagicMock(spec=MemoryModel) + + # Mock the hasHBM property to return True by default + self.has_hbm_patcher = patch.object(GlobalConfig, "hasHBM", True) + self.mock_has_hbm = self.has_hbm_patcher.start() + + # Mock the suppress_comments property to return False by default + self.suppress_comments_patcher = patch.object( + GlobalConfig, "suppress_comments", False + ) + self.mock_suppress_comments = self.suppress_comments_patcher.start() + + self.program = LinkedProgram( + self.streams["minst"], + self.streams["cinst"], + self.streams["xinst"], + self.mem_model, + ) + + def tearDown(self): + """@brief Tear down test fixtures.""" + self.has_hbm_patcher.stop() + self.suppress_comments_patcher.stop() + + def test_init(self): + """@brief Test initialization of LinkedProgram. + + @test Verifies that all instance variables are correctly initialized + """ + self.assertEqual(self.program._minst_ostream, self.streams["minst"]) + self.assertEqual(self.program._cinst_ostream, self.streams["cinst"]) + self.assertEqual(self.program._xinst_ostream, self.streams["xinst"]) + self.assertEqual(self.program._LinkedProgram__mem_model, self.mem_model) + self.assertEqual(self.program._bundle_offset, 0) + self.assertEqual(self.program._minst_line_offset, 0) + self.assertEqual(self.program._cinst_line_offset, 0) + self.assertEqual(self.program._kernel_count, 0) + self.assertTrue(self.program.is_open) + + def test_is_open_property(self): + """@brief Test the is_open property. + + @test Verifies that the is_open property reflects the internal state + """ + self.assertTrue(self.program.is_open) + self.program._is_open = False + self.assertFalse(self.program.is_open) + + def test_close(self): + """@brief Test closing the program. + + @test Verifies that cexit and msyncc instructions are added and program is marked as closed + """ + self.program.close() + + # Verify cexit and msyncc were added + self.assertIn("cexit", self.streams["cinst"].getvalue().lower()) + self.assertIn("msyncc", self.streams["minst"].getvalue().lower()) + self.assertFalse(self.program.is_open) + + # Test that closing an already closed program raises RuntimeError + with self.assertRaises(RuntimeError): + self.program.close() + + # Clean the StringIO object properly + self.streams["minst"].seek(0) + self.streams["minst"].truncate(0) + self.streams["cinst"].seek(0) + self.streams["cinst"].truncate(0) + self.streams["xinst"].seek(0) + self.streams["xinst"].truncate(0) + + # Test closing the program with comments suppressed. + with patch.object(GlobalConfig, "suppress_comments", True): + program = LinkedProgram( + self.streams["minst"], + self.streams["cinst"], + self.streams["xinst"], + self.mem_model, + ) + program.close() + + # Should not contain "terminating MInstQ" comment + self.assertNotIn("terminating MInstQ", self.streams["minst"].getvalue()) + + def test_validate_hbm_address(self): + """@brief Test validating a HBM address. + + @test Verifies that valid addresses are accepted and invalid ones raise exceptions + """ + + # Test validating a valid HBM address + self.mem_model.mem_info_vars = {} + self.program._validate_hbm_address("test_var", 10) + # No exception should be raised + + # Test validating a negative HBM address + with self.assertRaises(RuntimeError): + self.program._validate_hbm_address("test_var", -1) + + def test_validate_hbm_address_mismatch(self): + """@brief Test validating an HBM address that doesn't match the declared address. + + @test Verifies that a RuntimeError is raised when address doesn't match + """ + mock_var = MagicMock() + mock_var.hbm_address = 5 + self.mem_model.mem_info_vars = {"test_var": mock_var} + + with self.assertRaises(RuntimeError): + self.program._validate_hbm_address("test_var", 10) + + def test_validate_spad_address_valid(self): + """@brief Test validating a valid SPAD address with HBM disabled. + + @test Verifies that valid SPAD addresses are accepted when HBM is disabled + """ + with patch.object(GlobalConfig, "hasHBM", False): + self.mem_model.mem_info_vars = {} + self.program._validate_spad_address("test_var", 10) + # No exception should be raised + + def test_validate_spad_address_with_hbm_enabled(self): + """@brief Test validating a SPAD address with HBM enabled. + + @test Verifies that an AssertionError is raised when HBM is enabled + """ + with self.assertRaises(AssertionError): + self.program._validate_spad_address("test_var", 10) + + def test_validate_spad_address_negative(self): + """@brief Test validating a negative SPAD address. + + @test Verifies that a RuntimeError is raised for negative addresses + """ + with patch.object(GlobalConfig, "hasHBM", False): + with self.assertRaises(RuntimeError): + self.program._validate_spad_address("test_var", -1) + + def test_validate_spad_address_mismatch(self): + """@brief Test validating a SPAD address that doesn't match the declared address. + + @test Verifies that a RuntimeError is raised when address doesn't match + """ + with patch.object(GlobalConfig, "hasHBM", False): + mock_var = MagicMock() + mock_var.hbm_address = 5 + self.mem_model.mem_info_vars = {"test_var": mock_var} + + with self.assertRaises(RuntimeError): + self.program._validate_spad_address("test_var", 10) + + def test_update_minsts(self): + """@brief Test updating MInsts. + + @test Verifies that MInsts are correctly updated with offsets and variable addresses + """ + # Create mock MInstructions + mock_msyncc = MagicMock(spec=minst.MSyncc) + mock_msyncc.target = 5 + + mock_mload = MagicMock(spec=minst.MLoad) + mock_mload.source = "input_var" + mock_mload.comment = "original comment" + + mock_mstore = MagicMock(spec=minst.MStore) + mock_mstore.dest = "output_var" + mock_mstore.comment = None + + # Set up memory model mock + self.mem_model.use_variable.side_effect = [ + 10, + 20, + ] # Return different addresses for different vars + + # Execute the update + kernel_minstrs = [mock_msyncc, mock_mload, mock_mstore] + self.program._cinst_line_offset = 10 # Set initial offset + self.program._kernel_count = 1 # Set kernel count + self.program._update_minsts(kernel_minstrs) + + # Verify results + self.assertEqual(mock_msyncc.target, 15) # 5 + 10 + self.assertEqual(mock_mload.source, "10") # Replaced with HBM address + self.assertIn("input_var", mock_mload.comment) # Comment updated + self.assertIn( + "original comment", mock_mload.comment + ) # Original comment preserved + + self.assertEqual(mock_mstore.dest, "20") # Replaced with HBM address + + # Verify the memory model was used correctly + self.mem_model.use_variable.assert_has_calls( + [call("input_var", 1), call("output_var", 1)] + ) + + def test_remove_and_merge_csyncm_cnop(self): + """@brief Test removing CSyncm instructions and merging CNop instructions. + + @test Verifies that CSyncm instructions are removed and CNop cycles are updated correctly + """ + # Create mock CInstructions + mock_ifetch = MagicMock(spec=cinst.IFetch) + mock_ifetch.bundle = 1 + mock_ifetch.tokens = [0] + + mock_csyncm1 = MagicMock(spec=cinst.CSyncm) + mock_csyncm1.tokens = [0] + + mock_cnop1 = MagicMock(spec=cinst.CNop) + mock_cnop1.cycles = 2 + mock_cnop1.tokens = [0] + + mock_csyncm2 = MagicMock(spec=cinst.CSyncm) + mock_csyncm2.tokens = [0] + + mock_cnop2 = MagicMock(spec=cinst.CNop) + mock_cnop2.cycles = 3 + mock_cnop2.tokens = [0] + + # Set up ISACInst.CSyncm.get_throughput + with patch( + "assembler.instructions.cinst.CSyncm.get_throughput", return_value=2 + ): + # Execute the method + kernel_cinstrs = [ + mock_ifetch, + mock_csyncm1, + mock_cnop1, + mock_csyncm2, + mock_cnop2, + ] + self.program._remove_and_merge_csyncm_cnop(kernel_cinstrs) + + # Verify CSyncm instructions were removed + self.assertNotIn(mock_csyncm1, kernel_cinstrs) + self.assertNotIn(mock_csyncm2, kernel_cinstrs) + + # Verify CNop cycles were updated (should have added 2 for each CSyncm) + # First CNop gets 2 cycles added from first CSyncm + self.assertEqual(mock_cnop1.cycles, 4) # 2 + 2 + + # Verify the line numbers were updated + for i, instr in enumerate(kernel_cinstrs): + self.assertEqual(instr.tokens[0], str(i)) + + def test_update_cinsts_addresses_and_offsets(self): + """@brief Test updating CInst addresses and offsets. + + @test Verifies that CInst addresses and offsets are correctly updated + """ + # Create mock CInstructions + mock_ifetch = MagicMock(spec=cinst.IFetch) + mock_ifetch.bundle = 1 + + mock_csyncm = MagicMock(spec=cinst.CSyncm) + mock_csyncm.target = 5 + + mock_xinstfetch = MagicMock(spec=cinst.XInstFetch) + + # Create SPAD instructions for no-HBM case + mock_bload = MagicMock(spec=cinst.BLoad) + mock_bload.source = "var1" + mock_bload.comment = "original comment" + + mock_cstore = MagicMock(spec=cinst.CStore) + mock_cstore.dest = "var2" + mock_cstore.comment = None + + # Execute the method with HBM enabled + kernel_cinstrs = [mock_ifetch, mock_csyncm] + self.program._bundle_offset = 10 + self.program._minst_line_offset = 20 + self.program._update_cinsts_addresses_and_offsets(kernel_cinstrs) + + # Verify results with HBM enabled + self.assertEqual(mock_ifetch.bundle, 11) # 1 + 10 + self.assertEqual(mock_csyncm.target, 25) # 5 + 20 + + # Test with HBM disabled + with patch.object(GlobalConfig, "hasHBM", False): + # Set up memory model mock + self.mem_model.use_variable.side_effect = [ + 30, + 40, + ] # Return different addresses for different vars + + kernel_cinstrs = [mock_bload, mock_cstore] + self.program._kernel_count = 2 + self.program._update_cinsts_addresses_and_offsets(kernel_cinstrs) + + # Verify SPAD instructions were updated + self.assertEqual(mock_bload.source, "30") + self.assertIn("var1", mock_bload.comment) + self.assertIn("original comment", mock_bload.comment) + + self.assertEqual(mock_cstore.dest, "40") + + # Verify the memory model was used correctly + self.mem_model.use_variable.assert_has_calls( + [call("var1", 2), call("var2", 2)] + ) + + # Test that XInstFetch raises NotImplementedError + with self.assertRaises(NotImplementedError): + self.program._update_cinsts_addresses_and_offsets([mock_xinstfetch]) + + def test_update_cinsts(self): + """@brief Test updating CInsts. + + @test Verifies that the correct update methods are called based on HBM configuration + """ + # Create a mock for _remove_and_merge_csyncm_cnop and _update_cinsts_addresses_and_offsets + with patch.object( + LinkedProgram, "_remove_and_merge_csyncm_cnop" + ) as mock_remove, patch.object( + LinkedProgram, "_update_cinsts_addresses_and_offsets" + ) as mock_update: + + # Execute the method with HBM enabled + kernel_cinstrs = ["cinst1", "cinst2"] + self.program._update_cinsts(kernel_cinstrs) + + # Verify that only _update_cinsts_addresses_and_offsets was called + mock_remove.assert_not_called() + mock_update.assert_called_once_with(kernel_cinstrs) + + # Reset mocks + mock_remove.reset_mock() + mock_update.reset_mock() + + # Execute the method with HBM disabled + with patch.object(GlobalConfig, "hasHBM", False): + self.program._update_cinsts(kernel_cinstrs) + + # Verify that both methods were called + mock_remove.assert_called_once_with(kernel_cinstrs) + mock_update.assert_called_once_with(kernel_cinstrs) + + def test_update_xinsts(self): + """@brief Test updating XInsts. + + @test Verifies that XInst bundles are correctly updated and invalid sequences are detected + """ + # Create mock XInstructions + mock_xinst1 = MagicMock() + mock_xinst1.bundle = 1 + + mock_xinst2 = MagicMock() + mock_xinst2.bundle = 2 + + mock_xinst3 = MagicMock() + mock_xinst3.bundle = 0 # Will cause an error when updated after mock_xinst2 + + # Execute the method + kernel_xinstrs = [mock_xinst1, mock_xinst2] + self.program._bundle_offset = 10 + last_bundle = self.program._update_xinsts(kernel_xinstrs) + + # Verify results + self.assertEqual(mock_xinst1.bundle, 11) # 1 + 10 + self.assertEqual(mock_xinst2.bundle, 12) # 2 + 10 + self.assertEqual(last_bundle, 12) + + # Test that an invalid bundle sequence raises RuntimeError + kernel_xinstrs = [ + mock_xinst2, + mock_xinst3, + ] # xinst3 has lower bundle than xinst2 + with self.assertRaises(RuntimeError): + self.program._update_xinsts(kernel_xinstrs) + + def test_link_kernel(self): + """@brief Test linking a kernel. + + @test Verifies that a kernel is correctly linked with updated instructions + """ + # Create mocks for the update methods + with patch.object( + LinkedProgram, "_update_minsts" + ) as mock_update_minsts, patch.object( + LinkedProgram, "_update_cinsts" + ) as mock_update_cinsts, patch.object( + LinkedProgram, "_update_xinsts" + ) as mock_update_xinsts: + + # Setup mock_update_xinsts to return a bundle offset + mock_update_xinsts.return_value = 5 + + # Create mock instruction lists + kernel_minstrs = [MagicMock(), MagicMock()] + kernel_cinstrs = [MagicMock(), MagicMock()] + kernel_xinstrs = [MagicMock(), MagicMock()] + + # Configure the mocks for to_line method + for i, xinstr in enumerate(kernel_xinstrs): + xinstr.to_line.return_value = f"xinst{i}" + xinstr.comment = f"xinst_comment{i}" if i % 2 == 0 else None + + for i, cinstr in enumerate( + kernel_cinstrs[:-1] + ): # Skip the last one (cexit) + cinstr.to_line.return_value = f"cinst{i}" + cinstr.comment = f"cinst_comment{i}" if i % 2 == 0 else None + + for i, minstr in enumerate( + kernel_minstrs[:-1] + ): # Skip the last one (msyncc) + minstr.to_line.return_value = f"minst{i}" + minstr.comment = f"minst_comment{i}" if i % 2 == 0 else None + + # Execute the method + self.program.link_kernel(kernel_minstrs, kernel_cinstrs, kernel_xinstrs) + + # Verify update methods were called + mock_update_minsts.assert_called_once_with(kernel_minstrs) + mock_update_cinsts.assert_called_once_with(kernel_cinstrs) + mock_update_xinsts.assert_called_once_with(kernel_xinstrs) + + # Verify bundle offset was updated + self.assertEqual(self.program._bundle_offset, 6) # 5 + 1 + + # Verify line offsets were updated + self.assertEqual( + self.program._minst_line_offset, 1 + ) # len(kernel_minstrs) - 1 + self.assertEqual( + self.program._cinst_line_offset, 1 + ) # len(kernel_cinstrs) - 1 + + # Verify kernel count was incremented + self.assertEqual(self.program._kernel_count, 1) + + # Verify output streams contain the instructions + xinst_output = self.streams["xinst"].getvalue() + cinst_output = self.streams["cinst"].getvalue() + minst_output = self.streams["minst"].getvalue() + + self.assertIn("xinst0", xinst_output) + self.assertIn("xinst1", xinst_output) + self.assertIn("xinst_comment0", xinst_output) + + self.assertIn("0, cinst0", cinst_output) + self.assertIn("cinst_comment0", cinst_output) + + self.assertIn("0, minst0", minst_output) + self.assertIn("minst_comment0", minst_output) + + def test_link_kernel_with_no_hbm(self): + """@brief Test linking a kernel with HBM disabled. + + @test Verifies that MInsts are ignored when HBM is disabled + """ + with patch.object(GlobalConfig, "hasHBM", False): + # Create mocks for the update methods + with patch.object( + LinkedProgram, "_update_cinsts" + ) as mock_update_cinsts, patch.object( + LinkedProgram, "_update_xinsts" + ) as mock_update_xinsts: + + # Setup mock_update_xinsts to return a bundle offset + mock_update_xinsts.return_value = 5 + + # Create mock instruction lists + kernel_minstrs = [MagicMock(), MagicMock()] # Should be ignored + kernel_cinstrs = [MagicMock(), MagicMock()] + kernel_xinstrs = [MagicMock(), MagicMock()] + + # Configure the mocks for to_line method + for xinstr in kernel_xinstrs: + xinstr.to_line.return_value = "xinst" + xinstr.comment = None + + for cinstr in kernel_cinstrs[:-1]: # Skip the last one (cexit) + cinstr.to_line.return_value = "cinst" + cinstr.comment = None + + # Execute the method + self.program.link_kernel(kernel_minstrs, kernel_cinstrs, kernel_xinstrs) + + # Verify update methods were called + # No minsts should be processed when HBM is disabled + mock_update_cinsts.assert_called_once_with(kernel_cinstrs) + mock_update_xinsts.assert_called_once_with(kernel_xinstrs) + + # Verify bundle offset was updated + self.assertEqual(self.program._bundle_offset, 6) # 5 + 1 + + # No MInst output when HBM is disabled + minst_output = self.streams["minst"].getvalue() + self.assertEqual(minst_output, "") + + def test_link_kernel_with_closed_program(self): + """@brief Test linking a kernel with a closed program. + + @test Verifies that a RuntimeError is raised when linking to a closed program + """ + # Close the program + self.program._is_open = False + + # Try to link a kernel + with self.assertRaises(RuntimeError): + self.program.link_kernel([], [], []) + + def test_link_kernel_with_suppress_comments(self): + """@brief Test linking a kernel with comments suppressed. + + @test Verifies that comments are not included in the output when suppressed + """ + with patch.object(GlobalConfig, "suppress_comments", True): + # Create mocks for the update methods + with patch.object(LinkedProgram, "_update_minsts"), patch.object( + LinkedProgram, "_update_cinsts" + ), patch.object(LinkedProgram, "_update_xinsts"): + + # Create mock instruction lists with comments + kernel_minstrs = [MagicMock(), MagicMock()] + kernel_cinstrs = [MagicMock(), MagicMock()] + kernel_xinstrs = [MagicMock()] + + # Configure the mocks for to_line method + kernel_xinstrs[0].to_line.return_value = "xinst" + kernel_xinstrs[0].comment = "xinst_comment" + + kernel_cinstrs[0].to_line.return_value = "cinst" + kernel_cinstrs[0].comment = "cinst_comment" + + kernel_minstrs[0].to_line.return_value = "minst" + kernel_minstrs[0].comment = "minst_comment" + + # Execute the method + self.program.link_kernel(kernel_minstrs, kernel_cinstrs, kernel_xinstrs) + + # Verify comments were suppressed + xinst_output = self.streams["xinst"].getvalue() + cinst_output = self.streams["cinst"].getvalue() + minst_output = self.streams["minst"].getvalue() + + self.assertNotIn("xinst_comment", xinst_output) + self.assertNotIn("cinst_comment", cinst_output) + self.assertNotIn("minst_comment", minst_output) + + +class TestJoinDinstKernels(unittest.TestCase): + """@brief Tests for the join_dinst_kernels static method.""" + + def test_join_dinst_kernels_empty(self): + """@brief Test joining empty list of DInst kernels. + + @test Verifies that a ValueError is raised for an empty list + """ + with self.assertRaises(ValueError): + LinkedProgram.join_dinst_kernels([]) + + def test_join_dinst_kernels_single_kernel(self): + """@brief Test joining a single DInst kernel. + + @test Verifies that instructions from a single kernel are correctly processed + """ + # Create mock DInstructions + mock_dload = MagicMock(spec=dinst.DLoad) + mock_dload.var = "var1" + + mock_dstore = MagicMock(spec=dinst.DStore) + mock_dstore.var = "var2" + + # Execute the method + result = LinkedProgram.join_dinst_kernels([[mock_dload, mock_dstore]]) + + # Verify result + self.assertEqual(len(result), 2) + self.assertEqual(result[0], mock_dload) + self.assertEqual(result[1], mock_dstore) + + # Verify address was set + self.assertEqual(mock_dload.address, 0) + self.assertEqual(mock_dstore.address, 1) + + def test_join_dinst_kernels_multiple_kernels(self): + """@brief Test joining multiple DInst kernels. + + @test Verifies that instructions from multiple kernels are correctly merged + """ + # Create mock DInstructions for first kernel + mock_dload1 = MagicMock(spec=dinst.DLoad) + mock_dload1.var = "var1" + + mock_dstore1 = MagicMock(spec=dinst.DStore) + mock_dstore1.var = "var2" + + # Create mock DInstructions for second kernel + mock_dload2 = MagicMock(spec=dinst.DLoad) + mock_dload2.var = "var2" # Same as output from first kernel + + mock_dkeygen = MagicMock(spec=dinst.DKeyGen) + mock_dkeygen.var = "var3" + + mock_dstore2 = MagicMock(spec=dinst.DStore) + mock_dstore2.var = "var4" + + # Execute the method + result = LinkedProgram.join_dinst_kernels( + [[mock_dload1, mock_dstore1], [mock_dload2, mock_dkeygen, mock_dstore2]] + ) + + # Verify result - should contain load1, store1 (output), keygen, store2 (output) + # dload2 should be skipped since it loads var2 which is already an output from kernel1 + self.assertEqual(len(result), 3) + self.assertIn(mock_dload1, result) + self.assertNotIn(mock_dload2, result) # Should be skipped + self.assertIn(mock_dkeygen, result) + self.assertIn(mock_dstore2, result) + + # Verify addresses were set correctly and sequentially + # Note: exact order depends on dictionary iteration which is not guaranteed + used_addresses = {dinst.address for dinst in result} + self.assertEqual(used_addresses, {0, 1, 2}) # Three consecutive addresses + + def test_join_dinst_kernels_with_carry_over_vars(self): + """@brief Test joining DInst kernels with carry-over variables. + + @test Verifies that variables used across kernels are properly consolidated + """ + # Create mock DInstructions for first kernel + mock_dload1 = MagicMock(spec=dinst.DLoad) + mock_dload1.var = "var1" + + mock_dstore1 = MagicMock(spec=dinst.DStore) + mock_dstore1.var = "var2" + + # Create mock DInstructions for second kernel + mock_dload2 = MagicMock(spec=dinst.DLoad) + mock_dload2.var = "var2" # Same as output from first kernel + + mock_dstore2 = MagicMock(spec=dinst.DStore) + mock_dstore2.var = "var2" # Same variable is also an output + + # Execute the method + result = LinkedProgram.join_dinst_kernels( + [[mock_dload1, mock_dstore1], [mock_dload2, mock_dstore2]] + ) + + # Verify result - should contain load1, store2 + # Both dload2 and dstore1 should be skipped since var2 is carried over + self.assertEqual(len(result), 2) + self.assertIn(mock_dload1, result) + self.assertNotIn(mock_dload2, result) # Should be skipped + self.assertNotIn(mock_dstore1, result) # Should be skipped + self.assertIn(mock_dstore2, result) # Final output for var2 + + +if __name__ == "__main__": + unittest.main() diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_steps/test_variable_discovery.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_steps/test_variable_discovery.py new file mode 100644 index 00000000..4e9e942b --- /dev/null +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_linker/test_steps/test_variable_discovery.py @@ -0,0 +1,275 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +# These contents may have been developed with support from one or more Intel-operated +# generative artificial intelligence solutions + +""" +@brief Unit tests for the variable discovery module. +""" + +import unittest +from unittest.mock import patch, MagicMock + +from linker.steps.variable_discovery import discover_variables, discover_variables_spad + + +class TestVariableDiscovery(unittest.TestCase): + """@brief Tests for the variable discovery functions.""" + + def setUp(self): + """@brief Set up test fixtures.""" + # Group MInstructions in a dictionary + self.m_instrs = { + "load": MagicMock(source="var1"), + "store": MagicMock(dest="var2"), + "other": MagicMock(), # MInstruction that's neither MLoad nor MStore + } + + # Group CInstructions in a dictionary + self.c_instrs = { + "bload": MagicMock(source="var3"), + "cload": MagicMock(source="var4"), + "bones": MagicMock(source="var5"), + "nload": MagicMock(source="var6"), + "cstore": MagicMock(dest="var7"), + "other": MagicMock(), # CInstruction that's none of the above + } + + @patch("linker.steps.variable_discovery.minst") + @patch("linker.steps.variable_discovery.MInstruction") + @patch("assembler.memory_model.variable.Variable.validateName") + def test_discover_variables_valid( + self, mock_validate, mock_minst_class, mock_minst + ): + """@brief Test discovering variables from valid MInstructions. + + @test Verifies that variables are correctly discovered from MLoad and MStore instructions + """ + # Setup mocks + mock_minst.MLoad = MagicMock() + mock_minst.MStore = MagicMock() + + # Configure type checking at the module level, avoiding patching isinstance + def is_minst_side_effect(obj): + return obj in [ + self.m_instrs["load"], + self.m_instrs["store"], + self.m_instrs["other"], + ] + + def is_mload_side_effect(obj): + return obj is self.m_instrs["load"] + + def is_mstore_side_effect(obj): + return obj is self.m_instrs["store"] + + # Replace the actual isinstance calls in the module with our mock functions + with patch( + "linker.steps.variable_discovery.isinstance", + side_effect=lambda obj, cls: { + mock_minst_class: is_minst_side_effect(obj), + mock_minst.MLoad: is_mload_side_effect(obj), + mock_minst.MStore: is_mstore_side_effect(obj), + }.get(cls, False), + ): + + # Configure validateName to return True + mock_validate.return_value = True + + # Test with a list containing both MLoad and MStore + minstrs = [ + self.m_instrs["load"], + self.m_instrs["store"], + self.m_instrs["other"], + ] + + # Call the function + result = list(discover_variables(minstrs)) + + # Verify results + self.assertEqual(result, ["var1", "var2"]) + mock_validate.assert_any_call("var1") + mock_validate.assert_any_call("var2") + + def test_discover_variables_empty_list(self): + """@brief Test discovering variables from an empty list of MInstructions. + + @test Verifies that an empty list is returned when no instructions are provided + """ + # No need to patch isinstance for an empty list + result = list(discover_variables([])) + + # Verify results - should be an empty list + self.assertEqual(result, []) + + def test_discover_variables_invalid_type(self): + """@brief Test discovering variables with invalid types in the list. + + @test Verifies that a TypeError is raised when an invalid object is in the list + """ + # Setup mock to fail the isinstance check + invalid_obj = MagicMock() + + with patch("linker.steps.variable_discovery.isinstance", return_value=False): + # Call the function with a list containing an invalid type + with self.assertRaises(TypeError) as context: + list(discover_variables([invalid_obj])) + + # Verify the error message + self.assertIn("not a valid MInstruction", str(context.exception)) + + @patch("linker.steps.variable_discovery.minst") + @patch("assembler.memory_model.variable.Variable.validateName") + def test_discover_variables_invalid_variable_name(self, mock_validate, mock_minst): + """@brief Test discovering variables with an invalid variable name. + + @test Verifies that a RuntimeError is raised when a variable name is invalid + """ + # Setup mocks + mock_minst.MLoad = MagicMock() + + # Configure validateName to return False + mock_validate.return_value = False + + with patch( + "linker.steps.variable_discovery.isinstance", + side_effect=lambda obj, cls: True, + ): + # Call the function + with self.assertRaises(RuntimeError) as context: + list(discover_variables([self.m_instrs["load"]])) + + # Verify the error message + self.assertIn("Invalid Variable name", str(context.exception)) + + @patch("linker.steps.variable_discovery.cinst") + @patch("linker.steps.variable_discovery.CInstruction") + @patch("assembler.memory_model.variable.Variable.validateName") + def test_discover_variables_spad_valid( + self, mock_validate, mock_cinst_class, mock_cinst + ): + """@brief Test discovering variables from valid CInstructions. + + @test Verifies that variables are correctly discovered from all relevant CInstruction types + """ + # Setup mocks + mock_cinst.BLoad = MagicMock() + mock_cinst.CLoad = MagicMock() + mock_cinst.BOnes = MagicMock() + mock_cinst.NLoad = MagicMock() + mock_cinst.CStore = MagicMock() + + # Configure validateName to return True + mock_validate.return_value = True + + # Test with a list containing all types of CInstructions + cinstrs = [ + self.c_instrs["bload"], + self.c_instrs["cload"], + self.c_instrs["bones"], + self.c_instrs["nload"], + self.c_instrs["cstore"], + self.c_instrs["other"], + ] + + # Improved mock for isinstance that handles tuples of classes + def mock_isinstance(obj, cls): + # Handle tuple case first + if isinstance(cls, tuple): + return any(mock_isinstance(obj, c) for c in cls) + + # Use a dictionary to map class types to their respective checks + class_checks = { + mock_cinst_class: lambda: obj in cinstrs, + mock_cinst.BLoad: lambda: obj is self.c_instrs["bload"], + mock_cinst.CLoad: lambda: obj is self.c_instrs["cload"], + mock_cinst.BOnes: lambda: obj is self.c_instrs["bones"], + mock_cinst.NLoad: lambda: obj is self.c_instrs["nload"], + mock_cinst.CStore: lambda: obj is self.c_instrs["cstore"], + } + + # Check if cls is in our mapping and return the result of its check function + return class_checks.get(cls, lambda: False)() + + # Patch the isinstance calls at the module level + with patch( + "linker.steps.variable_discovery.isinstance", side_effect=mock_isinstance + ): + # Call the function + result = list(discover_variables_spad(cinstrs)) + + # Verify results + self.assertEqual(result, ["var3", "var4", "var5", "var6", "var7"]) + mock_validate.assert_any_call("var3") + mock_validate.assert_any_call("var4") + mock_validate.assert_any_call("var5") + mock_validate.assert_any_call("var6") + mock_validate.assert_any_call("var7") + + def test_discover_variables_spad_empty_list(self): + """@brief Test discovering variables from an empty list of CInstructions. + + @test Verifies that an empty list is returned when no instructions are provided + """ + # Call the function with an empty list + result = list(discover_variables_spad([])) + + # Verify results - should be an empty list + self.assertEqual(result, []) + + def test_discover_variables_spad_invalid_type(self): + """@brief Test discovering variables with invalid types in the list. + + @test Verifies that a TypeError is raised when an invalid object is in the list + """ + # Setup mock + invalid_obj = MagicMock() + + with patch("linker.steps.variable_discovery.isinstance", return_value=False): + # Call the function with a list containing an invalid type + with self.assertRaises(TypeError) as context: + list(discover_variables_spad([invalid_obj])) + + # Verify the error message + self.assertIn("not a valid MInstruction", str(context.exception)) + + @patch("linker.steps.variable_discovery.cinst") + @patch("linker.steps.variable_discovery.CInstruction") + @patch("assembler.memory_model.variable.Variable.validateName") + def test_discover_variables_spad_invalid_variable_name( + self, mock_validate, mock_cinst_class, mock_cinst + ): + """@brief Test discovering variables with an invalid variable name. + + @test Verifies that a RuntimeError is raised when a variable name is invalid + """ + # Setup mocks + mock_cinst.BLoad = MagicMock() + + # Configure validateName to return False + mock_validate.return_value = False + + # Mock isinstance to make our mock object appear as CInstruction and BLoad + with patch( + "linker.steps.variable_discovery.isinstance", + side_effect=lambda obj, cls: { + mock_cinst_class: True, + ( + mock_cinst.BLoad, + mock_cinst.CLoad, + mock_cinst.BOnes, + mock_cinst.NLoad, + ): True, + }.get(cls, False), + ): + # Call the function + with self.assertRaises(RuntimeError) as context: + list(discover_variables_spad([self.c_instrs["bload"]])) + + # Verify the error message + self.assertIn("Invalid Variable name", str(context.exception)) + + +if __name__ == "__main__": + unittest.main()