# DAX Performance Testing

## Summary

This notebook is designed to measure DAX query timings under different cache states (cold, warm, and hot). Specifically:

1. **DAX Queries from Excel**  
   - You must provide an Excel file containing a table of the DAX queries you wish to test.  
   - For each `queryId`, the table needs a column whose name matches the `runQueryType` configured for the model; that column holds the query text to run.  
   - This notebook reads those queries and executes them on one or more Power BI/Fabric models.

2. **Lakehouse Logging**  
   - You also must attach the appropriate Lakehouse in Fabric so that logs can be saved (both in a table and as files if you choose).  

3. **Capacity Pause/Resume**  
   - In some scenarios (e.g., simulating a "cold" cache on DirectQuery or Import models), the code pauses and resumes capacities.  
   - **Warning**: Pausing a capacity will interrupt any running workloads on that capacity. Resuming will take time and resources, and can affect other workspaces assigned to the same capacity.  

Overall, the purpose is to capture performance metrics (timings, CPU usage, etc.) for DAX queries under different cache states.
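
As a rough illustration of the expected table layout, the sketch below mimics a few rows of the Excel table as plain Python dictionaries and picks the query text by column name. The column names `query` and `query_optimized`, the DAX text, and the helper `select_query_text` are hypothetical; only `queryId` and the idea that `runQueryType` names a column come from this notebook.

```python
# Hypothetical rows, shaped like the Excel table after it has been loaded.
# 'queryId' is required; the other column names are illustrative only.
dax_query_rows = [
    {"queryId": 1, "query": "EVALUATE {1}", "query_optimized": "EVALUATE {1}"},
    {"queryId": 2, "query": 'EVALUATE ROW("x", 1 + 1)', "query_optimized": "EVALUATE {2}"},
]


def select_query_text(row, run_query_type):
    """Return the DAX text from the column named by a model's runQueryType."""
    if run_query_type not in row:
        raise KeyError(
            f"Column '{run_query_type}' not found for queryId {row['queryId']}"
        )
    return row[run_query_type]


# A model configured with runQueryType="query_optimized" would run these:
selected = [select_query_text(row, "query_optimized") for row in dax_query_rows]
print(selected)
```

Keeping several query variants side by side in one table makes it possible to point different models at different columns without duplicating the `queryId` list.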


### Install the latest .whl package

Check [here](https://pypi.org/project/semantic-link-labs/) to see the latest version.

In [None]:
%pip install semantic-link-labs

### Import the library and necessary packages

In [None]:
import sempy.fabric as fabric
import sempy_labs as labs
import pandas as pd
import time
import itertools
import random
import requests
import functools
import builtins
from threading import local
from contextlib import contextmanager
from uuid import uuid4
from pyspark.sql.functions import col, sum as _sum, when, countDistinct
from datetime import datetime

### Global configurations & variables

In [None]:
# Generate a unique run ID for this test run
run_id = str(uuid4())

# Define models and their configurations for testing
models = [
    {
        "name": "Model Name", # The name of the semantic model
        "storageMode": "DirectLake",  # Import, DirectQuery, or DirectLake
        "cache_types": ["cold", "warm", "hot"], # List of cache types to be run (hot, warm, and cold)
        "model_workspace_name": "Model Workspace Name", # The workspace name of the semantic model
        "database_name": "Lakehouse Name",  # Only needed for cold cache queries for Import and DirectQuery
        "database_workspace_name": "Lakehouse Workspace Name",  # Only needed for cold cache queries for Import and DirectQuery
        "runQueryType": "query", # The name of the column in your DAX Excel file that contains the query to be run
    },
]

# Only needed for cold cache queries for Import and DirectQuery
workspace_capacities = {
    "Workspace Name": {
        "capacity_name": "Testing Capacity Name",
        "alt_capacity_name": "Alternate Capacity Name",
    }
}

# Read DAX queries from the Excel file uploaded to the attached lakehouse
# The first column must be 'queryId' and additional columns should contain variants of the DAX query.
dax_queries = pd.read_excel(
    "/lakehouse/default/Files/DAXExcelFileName.xlsx", "DAXTableName"
)

# Additional arguments controlling the behavior of query execution and logging
additional_arguments = {
    "roundNumber": 1, # The current round of DAX testing. Considered when determining whether maxNumberPerQuery has been met
    "onlyRunNewQueries": True, # Will determine if queries will stop being tested after maxNumberPerQuery is met
    "maxNumberPerQuery": 1, # The max number of queries to capture per round, queryId, model and cache type
    "maxFailuresBeforeSkipping": 5, # The number of failed query attempts per round, queryId, model and cache type before skipping
    "numberOfRunsPerQueryId": 15, # The number of times to loop over each queryId. If all combos have met maxNumberPerQuery, the loop will break
    "stopQueryIdsAt": 99, # Allows you to stop the queryId loop at a certain number, even if there are more queries present, i.e., there are queryIds 1-20 but stop at 5
    "forceStartQueriesAt1": False, # If set to False, testing will start at the first incomplete queryId instead of starting at queryId 1
    "logTableName": "DAXTestingLogTableName", # The name of the table in the attached lakehouse to save the performance logs to
    "clearAllLogs": False, # Will drop the existing logs table before starting testing
    "clearCurrentRoundLogs": False, # Will delete the logs associated with the current roundNumber before starting testing
    "randomizeRuns": True, # Will randomize the model and cache type combos when testing
    "skipSettingHotCache": False, # Should be False if randomizing the runs. If the runs are randomized, the previous warm cache run will set the hot cache
    "pauseAfterSettingCache": 5, # The number of seconds to wait after setting the cache
    "pauseAfterRunningQuery": 5, # The number of seconds to wait before writing the logs to the log table
    "pauseBetweenRuns": 30, # The number of seconds to wait before starting the next query
}

# Define the expected schema for DAX trace log events
event_schema = {
    "DirectQueryBegin": [
        "EventClass",
        "CurrentTime",
        "TextData",
        "StartTime",
        "EndTime",
        "Duration",
        "CpuTime",
        "Success",
    ],
    "DirectQueryEnd": [
        "EventClass",
        "CurrentTime",
        "TextData",
        "StartTime",
        "EndTime",
        "Duration",
        "CpuTime",
        "Success",
    ],
    "VertiPaqSEQueryBegin": [
        "EventClass",
        "EventSubclass",
        "CurrentTime",
        "TextData",
        "StartTime",
    ],
    "VertiPaqSEQueryEnd": [
        "EventClass",
        "EventSubclass",
        "CurrentTime",
        "TextData",
        "StartTime",
        "EndTime",
        "Duration",
        "CpuTime",
        "Success",
    ],
    "VertiPaqSEQueryCacheMatch": [
        "EventClass",
        "EventSubclass",
        "CurrentTime",
        "TextData",
    ],
    "QueryBegin": [
        "EventClass",
        "EventSubclass",
        "CurrentTime",
        "TextData",
        "StartTime",
        "ConnectionID",
        "SessionID",
        "RequestProperties"
    ],
    "QueryEnd": [
        "EventClass",
        "EventSubclass",
        "CurrentTime",
        "TextData",
        "StartTime",
        "EndTime",
        "Duration",
        "CpuTime",
        "Success",
        "ConnectionID",
        "SessionID",
    ],
}

# Dictionary to track if a capacity pause is needed for each model during testing
model_pause_capacity_needed = {}

# Variables for Pausing/Resuming Capacities: credentials and configuration parameters for Azure Key Vault and resource management
resource_group_name = ""
subscription_id = ""
key_vault_uri = ""
key_vault_client_id = ""
key_vault_tenant_id = ""
key_vault_client_secret = ""

# Enforce case-sensitivity in Spark to ensure column name matching is exact
spark.conf.set("spark.sql.caseSensitive", True)
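
One way to picture how `models`, `cache_types`, and `randomizeRuns` interact: the sketch below expands each model into one entry per cache type and optionally shuffles the result. The helper name `build_model_cache_combos` and its `seed` parameter are assumptions made for illustration, and the notebook's own orchestration code may build its combinations differently. The keys `model` and `cache_type` follow the `_model_cache_combo` structure used later in this notebook.

```python
import random


def build_model_cache_combos(models, randomize=False, seed=None):
    """Expand each model into one combo per cache type it lists."""
    combos = [
        {"model": model, "cache_type": cache_type}
        for model in models
        for cache_type in model["cache_types"]
    ]
    if randomize:
        # The seed is only used here to make the illustration repeatable.
        random.Random(seed).shuffle(combos)
    return combos


# Hypothetical models, reduced to the keys this sketch needs.
example_models = [
    {"name": "Model A", "storageMode": "DirectLake", "cache_types": ["cold", "warm", "hot"]},
    {"name": "Model B", "storageMode": "Import", "cache_types": ["hot"]},
]

ordered_combos = build_model_cache_combos(example_models)
combos = build_model_cache_combos(example_models, randomize=True, seed=0)
for combo in combos:
    print(combo["model"]["name"], combo["cache_type"])
```

Shuffling the combinations spreads any time-dependent noise (for example, background load on the capacity) across models and cache types instead of concentrating it on whichever combination happens to run last.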

### Logging & Retry Decorators, Basic Helpers

In [None]:
# Thread-local storage for tracking the call depth (used for indented printing)
_thread_local = local()

@contextmanager
def indented_print(indent_level: int):
    """
    Temporarily replaces the built-in print function with an indented version.
    This helps in visually distinguishing nested function calls in the logs.
    
    Parameters:
        indent_level (int): The indentation depth to apply.
    """
    original_print = builtins.print

    def custom_print(*args, **kwargs):
        indent = "    " * indent_level
        original_print(indent + " ".join(map(str, args)), **kwargs)

    builtins.print = custom_print
    try:
        yield
    finally:
        builtins.print = original_print

def log_function_calls(func):
    """
    Decorator that logs the start and end of a function call with indented printing.
    This is useful for tracking nested function calls in the execution logs.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if not hasattr(_thread_local, "call_depth"):
            _thread_local.call_depth = 0

        indent = _thread_local.call_depth

        with indented_print(indent):
            print(f"✅ {func.__name__} - Starting")

        _thread_local.call_depth += 1
        try:
            with indented_print(_thread_local.call_depth):
                result = func(*args, **kwargs)
        finally:
            _thread_local.call_depth -= 1
            with indented_print(_thread_local.call_depth):
                print(f"✅ {func.__name__} - Ending")

        return result

    return wrapper

def retry(exceptions, tries=3, delay=5, backoff=2, logger=None):
    """
    Decorator for retrying a function call with exponential backoff.
    It will attempt to call the function and, if specified exceptions occur, wait and retry.

    Parameters:
        exceptions (tuple): Exception classes to catch.
        tries (int): Number of attempts.
        delay (int): Initial delay between attempts in seconds.
        backoff (int): Multiplier applied to the delay between attempts.
        logger (callable): Logging function to use for printing warnings.

    Returns:
        The decorated function.
    """
    def decorator_retry(func):
        @functools.wraps(func)
        def wrapper_retry(*args, **kwargs):
            if not hasattr(_thread_local, "call_depth"):
                _thread_local.call_depth = 0

            _tries, _delay = tries, delay
            first_fail = True  # Track whether we've failed before

            while _tries > 1:
                try:
                    return func(*args, **kwargs)
                except exceptions as e:
                    # Skip printing on the first failure; print on subsequent failures
                    if not first_fail:
                        with indented_print(_thread_local.call_depth):
                            print(f"⚠️ {func.__name__} failed with {e}, retrying in {_delay} seconds...")
                    else:
                        first_fail = False

                    time.sleep(_delay)
                    _tries -= 1
                    _delay *= backoff

            # Last attempt (no retries left); let the exception bubble if it fails
            return func(*args, **kwargs)

        return wrapper_retry

    return decorator_retry

In [None]:
def trace_started(_traces, _trace_name):
    # Check if a specific trace (by name) has started by looking into the traces DataFrame.
    return _traces.loc[_traces["Name"] == _trace_name].shape[0] > 0

def query_end_event_collected(_trace_events):
    # Verify if the 'QueryEnd' event is present in the collected trace events.
    return _trace_events.loc[_trace_events["Event Class"] == "QueryEnd"].shape[0] > 0

@log_function_calls
@retry(Exception, tries=10, delay=2, backoff=2, logger=print)
def wait_for_trace_start(trace_connection, trace_name):
    # Wait until the trace with the specified name is detected as started.
    if not trace_started(trace_connection.list_traces(), trace_name):
        raise Exception("Trace has not started yet")
    return True

@log_function_calls
@retry(Exception, tries=60, delay=3, backoff=1, logger=print)
def wait_for_query_end_event(trace):
    # Poll until the trace logs indicate that the query end event has been collected.
    logs = trace.get_trace_logs()
    if not query_end_event_collected(logs):
        raise Exception("Query end event not collected yet")
    return logs

@log_function_calls
@retry(Exception, tries=30, delay=5, backoff=1, logger=print)
def check_model_online(_model):
    # Check if the model is online by executing a simple DAX query.
    dax_query_eval_1(_model)

@log_function_calls
@retry(Exception, tries=30, delay=2, backoff=2, logger=print)
def wait_for_capacity_status(_capacity_name, target_status):
    # Check the current status of a capacity; wait until it matches the target status.
    capacities = labs.list_capacities()  # Fetch once so the filter and the lookup see the same data
    current_status = capacities.loc[
        capacities["Display Name"] == _capacity_name, "State"
    ].iloc[0]
    
    if current_status != target_status:
        raise Exception("Capacity status not updated yet")
    return current_status

def dax_query_eval_1(_model):
    # Execute a simple DAX query to verify connectivity and evaluate model responsiveness.
    fabric.evaluate_dax(
        _model["name"], "EVALUATE {1}", workspace=_model["model_workspace_name"]
    )

@log_function_calls
def wait_for_model_to_come_online(_model):
    # Wait until the model is confirmed to be online; raise an exception if it fails.
    try:
        check_model_online(_model)
        print("✅ Model is online")
    except Exception as e:
        raise Exception("❌ Model failed to come online") from e
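
To get a feel for how long the polling helpers above can block, the following sketch adds up the sleeps performed by the `retry` loop when every attempt except the last one fails. The function name `worst_case_wait_seconds` is an assumption introduced for this illustration; the `(tries, delay, backoff)` values are copied from the decorators in this notebook, and the time spent inside the wrapped calls themselves is not counted.

```python
def worst_case_wait_seconds(tries, delay, backoff):
    """Total sleep time if every attempt except the last one fails.

    Mirrors the retry loop: it sleeps after each of the first (tries - 1)
    failures and multiplies the delay by `backoff` after every sleep.
    """
    total, current_delay = 0, delay
    for _ in range(tries - 1):
        total += current_delay
        current_delay *= backoff
    return total


# (tries, delay, backoff) as configured on the helpers in this notebook.
settings = {
    "wait_for_trace_start": (10, 2, 2),
    "wait_for_query_end_event": (60, 3, 1),
    "check_model_online": (30, 5, 1),
    "wait_for_capacity_status": (30, 2, 2),
}

worst_cases = {
    name: worst_case_wait_seconds(*params) for name, params in settings.items()
}
for name, seconds in worst_cases.items():
    print(f"{name}: up to {seconds:,} s of sleeping")
```

With `backoff=1` the upper bound grows linearly (177 s and 145 s above), whereas `backoff=2` with 30 tries gives a theoretical bound above 10^9 seconds. In practice the loop exits as soon as a call succeeds, but capping the delay or using `backoff=1` may be worth considering for long-running polls.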

### Pause & Resume Capacity

In [None]:
@log_function_calls
def update_model_pause_status(event, model=None, workspace=None):
    """
    Updates the model_pause_capacity_needed dictionary based on specific events.

    Parameters:
        event (str): The event type. Allowed values are:
            - "initialize": Initializes the dictionary so that every model is set to True,
              except DirectLake models which are always False.
            - "model_queried": A model has been queried. For Import models, mark it as True.
              For DirectQuery models, mark all models sharing the same database_name and
              database_workspace_name as True.
            - "capacity_paused": After a capacity pause, for the given workspace,
              mark as False all Import models in that workspace and all DirectQuery models whose
              database_workspace_name matches that workspace.
        model (dict, optional): The model dictionary (required for "model_queried").
        workspace (str, optional): The workspace name (required for "capacity_paused").
    """
    global model_pause_capacity_needed

    if event == "initialize":
        # Set every model to True except for DirectLake models, which are always False.
        for m in models:
            if m["storageMode"] == "DirectLake":
                model_pause_capacity_needed[m["name"]] = False
            else:
                model_pause_capacity_needed[m["name"]] = True

    elif event == "model_queried" and model is not None:
        if model["storageMode"] == "Import":
            model_pause_capacity_needed[model["name"]] = True
        elif model["storageMode"] == "DirectQuery":
            # Mark all DirectQuery models sharing the same database settings as True.
            target_db = model["database_name"]
            target_db_workspace = model["database_workspace_name"]
            for m in models:
                if (
                    m["storageMode"] == "DirectQuery"
                    and m["database_name"] == target_db
                    and m["database_workspace_name"] == target_db_workspace
                ):
                    model_pause_capacity_needed[m["name"]] = True

    elif event == "capacity_paused" and workspace is not None:
        # For the given workspace, mark as False all Import models and all DirectQuery models
        # whose database_workspace_name matches the workspace.
        for m in models:
            if m["model_workspace_name"] == workspace and m["storageMode"] == "Import":
                model_pause_capacity_needed[m["name"]] = False
        for m in models:
            if (
                m["storageMode"] == "DirectQuery"
                and m["database_workspace_name"] == workspace
            ):
                model_pause_capacity_needed[m["name"]] = False

    else:
        print(f"⚠️ Unknown event '{event}' or missing required parameter(s).")

    # Debug output (optional)
    print("📝 Updated model_pause_capacity_needed")
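
The three events can be read as a small state machine over one boolean flag per model. The sketch below restates the same rules as a pure function so the transitions can be followed without touching the global dictionary. The function name `apply_pause_event` and the example models are hypothetical, and this is an illustration of the rules described in the docstring rather than a replacement for `update_model_pause_status`.

```python
def apply_pause_event(flags, models, event, model=None, workspace=None):
    """Return a new flag dict after applying one event (pure-function sketch)."""
    flags = dict(flags)
    if event == "initialize":
        # Every model needs a pause except DirectLake models.
        for m in models:
            flags[m["name"]] = m["storageMode"] != "DirectLake"
    elif event == "model_queried" and model is not None:
        if model["storageMode"] == "Import":
            flags[model["name"]] = True
        elif model["storageMode"] == "DirectQuery":
            # All DirectQuery models on the same source are affected.
            for m in models:
                if (
                    m["storageMode"] == "DirectQuery"
                    and m["database_name"] == model["database_name"]
                    and m["database_workspace_name"] == model["database_workspace_name"]
                ):
                    flags[m["name"]] = True
    elif event == "capacity_paused" and workspace is not None:
        for m in models:
            import_here = m["storageMode"] == "Import" and m["model_workspace_name"] == workspace
            dq_here = m["storageMode"] == "DirectQuery" and m["database_workspace_name"] == workspace
            if import_here or dq_here:
                flags[m["name"]] = False
    return flags


# Hypothetical models sharing one workspace.
example_models = [
    {"name": "Sales Import", "storageMode": "Import", "model_workspace_name": "WS1",
     "database_name": "LH", "database_workspace_name": "WS1"},
    {"name": "Sales DQ", "storageMode": "DirectQuery", "model_workspace_name": "WS1",
     "database_name": "LH", "database_workspace_name": "WS1"},
    {"name": "Sales DL", "storageMode": "DirectLake", "model_workspace_name": "WS1",
     "database_name": "LH", "database_workspace_name": "WS1"},
]

after_init = apply_pause_event({}, example_models, "initialize")
after_pause = apply_pause_event(after_init, example_models, "capacity_paused", workspace="WS1")
after_query = apply_pause_event(after_pause, example_models, "model_queried", model=example_models[1])
print(after_init, after_pause, after_query, sep="\n")
```

The effect is that a capacity is only paused again once a model on it has actually been queried since the last pause, which avoids redundant pause/resume cycles.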

In [None]:
@log_function_calls
def pause_resume_capacity(_capacity_name, _action, _simplify_logs=False):
    """
    Pauses or resumes a given capacity using Semantic Link Labs functions.
    
    Parameters:
       _capacity_name: The name of the capacity to be paused or resumed.
       _action: The action to perform, either "pause" or "resume".
       _simplify_logs: Optional flag to simplify logging output.
    """
    print(f"🔄 {_action.title()} capacity '{_capacity_name}': Attempting")

    # Check the current status via labs.list_capacities() (fetched once).
    capacities = labs.list_capacities()
    current_status = capacities.loc[
        capacities["Display Name"] == _capacity_name, "State"
    ].iloc[0]
    
    # Define mapping options for pause and resume actions.
    action_options = {
        "pause": {
            "expected_status": "Active",  # Current status needed to perform a pause.
            "target_status": "Paused",
            "action_function": labs.suspend_fabric_capacity,
        },
        "resume": {
            "expected_status": "Paused",  # Current status needed to perform a resume.
            "target_status": "Active",
            "action_function": labs.resume_fabric_capacity,
        },
    }
    
    if current_status == action_options[_action]["expected_status"]:
        print(f"🛠️ {_action.title()} capacity '{_capacity_name}': Requesting action")
        
        # Call the appropriate labs function for pausing or resuming the capacity.
        action_options[_action]["action_function"](
            capacity_name=_capacity_name,
            azure_subscription_id=subscription_id,
            resource_group=resource_group_name,
            key_vault_uri=key_vault_uri,
            key_vault_tenant_id=key_vault_tenant_id,
            key_vault_client_id=key_vault_client_id,
            key_vault_client_secret=key_vault_client_secret  # Use the provided secret
        )
        
        # Note: status polling below goes through labs.list_capacities(), so no direct
        # Azure Resource Manager request (or token headers) needs to be built here.
        
        # Wait for the capacity status to change to the target status.
        try:
            wait_for_capacity_status(_capacity_name, action_options[_action]["target_status"])
            print(f"✅ {_action.title()} capacity '{_capacity_name}': Action successful")
        except Exception as e:
            print(f"⚠️ {_action.title()} capacity '{_capacity_name}': Timeout waiting for target status. Error: {e}")
    else:
        print(f"ℹ️ {_action.title()} capacity '{_capacity_name}': Already '{current_status}'")


### Cache-setting functions

In [None]:
@log_function_calls
@retry(exceptions=(Exception,), tries=5, delay=5, backoff=2)
def clear_vertipaq_cache(_model):
    """
    Clears the VertiPaq cache by calling labs.clear_cache.
    Retries automatically up to 'tries' times if there's an error.
    """
    print("🧹 Clearing VertiPaq cache")
    wait_for_model_to_come_online(_model)

    # Attempt the clearing and verify by executing a simple DAX query.
    try:
        labs.clear_cache(_model["name"], workspace=_model["model_workspace_name"])
        dax_query_eval_1(_model)
        print("✅ Clear VertiPaq cache successful")
    except Exception as e:
        # If clearing fails, refresh the TOM cache before retrying.
        print("🔄 Clearing VertiPaq cache failed; retrying...")
        fabric.refresh_tom_cache(_model["model_workspace_name"])
        raise  # Re-raise the original exception to trigger the retry mechanism

    # Small buffer after clearing to allow processes to settle.
    time.sleep(5)

@log_function_calls
def set_hot_cache(_model, _expression, successful_query_count_goal=2):
    """
    Executes the same query multiple times to prime the cache (hot cache).
    The goal is to have a specified number of successful queries to confirm the cache is set.
    """
    print("🔥 Setting Hot Cache")
    successful_query_count = 0
    number_of_query_attempts = (
        successful_query_count_goal * 5 if successful_query_count_goal > 1 else 1
    )

    if additional_arguments["skipSettingHotCache"]:
        successful_query_count = successful_query_count_goal
    else:
        for _ in range(number_of_query_attempts):
            # Drop any existing traces before starting a new trace for hot cache priming.
            fabric.create_trace_connection(
                _model["name"], _model["model_workspace_name"]
            ).drop_traces()
            trace_name = f"Cache Trace {str(uuid4())}"
            with fabric.create_trace_connection(
                _model["name"], _model["model_workspace_name"]
            ) as trace_connection:
                with trace_connection.create_trace(event_schema, trace_name) as trace:
                    print("🔍 Starting trace for hot cache")
                    trace.start()
                    wait_for_trace_start(trace_connection, trace_name)
                    try:
                        print("⚡ Executing DAX query for hot cache")
                        fabric.evaluate_dax(
                            _model["name"],
                            _expression,
                            workspace=_model["model_workspace_name"],
                        )
                        successful_query_count += 1
                        print("✅ DAX query succeeded for hot cache")
                    except Exception as e:
                        print("❌ DAX query failed for hot cache:", e)

                    print("📜 Collecting trace logs for hot cache")
                    wait_for_query_end_event(trace)
                    trace.stop()

            if successful_query_count == successful_query_count_goal:
                break

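    cache_set = successful_query_count == successful_query_count_goal
    if cache_set:
        print(f"✅ Hot cache set; goal: {successful_query_count_goal} successful queries")
    else:
        print(
            f"⚠️ Hot cache not confirmed: {successful_query_count} of "
            f"{successful_query_count_goal} successful queries"
        )
    return cache_set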

@log_function_calls
def set_warm_cache(_model, _expression):
    """
    Sets a warm cache for the model.
      - For DirectQuery models, a cold state is set first.
      - For every storage mode, the query is then run once and the VertiPaq cache is cleared.
    """
    print("🔥 Setting Warm Cache")

    if _model["storageMode"] == "DirectQuery":
        # For DirectQuery, simulate a cold cache before warming.
        set_cold_cache(_model)

    # Prime the cache with a hot query and then clear the VertiPaq cache.
    hot_cache_successful = set_hot_cache(
        _model, _expression, successful_query_count_goal=1
    )
    clear_vertipaq_cache(_model)

    print("✅ Warm cache set")
    return hot_cache_successful

@retry(exceptions=(Exception,), tries=5, delay=5, backoff=2)
def _refresh_tom_cache(workspace_name):
    """
    Calls fabric.refresh_tom_cache(workspace_name) exactly once,
    raising an exception if it fails. The '@retry' decorator will call it again on failure.
    """
    print(f"⌛ Refreshing TOM cache for workspace '{workspace_name}'")
    fabric.refresh_tom_cache(workspace_name)
    print("✅ TOM cache refreshed successfully")

@retry(exceptions=(Exception,), tries=30, delay=3, backoff=2)
def _wait_for_refresh_to_complete(_model, refresh_id):
    """
    Checks the status of a dataset refresh once. If it is still 'InProgress' or not in the expected terminal states,
    raises an exception to trigger a retry.
    """
    status = fabric.get_refresh_execution_details(
        _model["name"],
        refresh_id,
        workspace=_model["model_workspace_name"],
    ).status

    if status not in ["Completed", "Failed"]:
        # If refresh is still in progress, raise an exception to trigger another retry attempt.
        raise Exception(f"Refresh status is '{status}'; not done yet.")

    # Log completion status for the refresh process.
    print(f"✅ Refresh status: '{status}' - finishing polling.")

@log_function_calls
def set_cold_cache(_model):
    """
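    Sets a cold cache for the model:
      - For DirectLake: perform a clearValues refresh, then a full refresh, and finally clear the VertiPaq cache.
      - For Import/DirectQuery: if a capacity pause is still needed, move the workspace to the alternate capacity,
        pause and resume the primary capacity, move the workspace back, refresh the TOM cache,
        and clear the VertiPaq cache. If no pause is needed, nothing is done.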
    """
    print("❄️ Setting Cold Cache")
    if _model["storageMode"] != "DirectLake":
        if model_pause_capacity_needed[_model["name"]]:
            ws_caps = workspace_capacities[_model["model_workspace_name"]]
            print(f"🔄 Assigning alternate capacity for workspace '{_model['model_workspace_name']}'")
            labs.assign_workspace_to_capacity(ws_caps["alt_capacity_name"], _model["model_workspace_name"])
            print(f"✅ Alternate capacity assigned: {ws_caps['alt_capacity_name']}")

            pause_resume_capacity(ws_caps["capacity_name"], "pause")
            pause_resume_capacity(ws_caps["capacity_name"], "resume")

            print(f"🔄 Reassigning primary capacity for workspace '{_model['model_workspace_name']}'")
            labs.assign_workspace_to_capacity(ws_caps["capacity_name"], _model["model_workspace_name"])
            print(f"✅ Primary capacity assigned: {ws_caps['capacity_name']}")

            # Refresh the TOM cache and ensure the model is online after capacity actions.
            _refresh_tom_cache(_model["model_workspace_name"])

            wait_for_model_to_come_online(_model)

            # Update the pause status after the capacity has been paused.
            update_model_pause_status("capacity_paused", workspace=_model["model_workspace_name"])

            time.sleep(30)  # Allow time for the system to settle
            clear_vertipaq_cache(_model)
    else:
        print("ℹ️ Performing clear refresh for cold cache")
        refresh_status_clear = fabric.refresh_dataset(
            _model["name"],
            refresh_type="clearValues",
            workspace=_model["model_workspace_name"],
        )

        # Wait (by polling) for the clear refresh to complete or fail.
        _wait_for_refresh_to_complete(_model, refresh_status_clear)

        print("✅ Clear refresh completed; performing full refresh")
        refresh_status_full = fabric.refresh_dataset(
            _model["name"],
            refresh_type="full",
            workspace=_model["model_workspace_name"],
        )

        # Wait for the full refresh to complete or fail.
        _wait_for_refresh_to_complete(_model, refresh_status_full)

        # Finally, clear the VertiPaq cache (with retry logic in place).
        clear_vertipaq_cache(_model)

    print("✅ Cold cache set")
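
The preparation performed for each storage mode and cache type can be summarized as a lookup. The sketch below is a plain-Python restatement of the functions above for reading purposes only: the names `cache_prep_steps` and `hot_cache_attempt_budget` and the step labels are assumptions introduced here, and nothing in it calls Fabric.

```python
COLD_STEPS_DIRECT_LAKE = ["clearValues refresh", "full refresh", "clear VertiPaq cache"]
COLD_STEPS_CAPACITY = [
    "assign alternate capacity",
    "pause primary capacity",
    "resume primary capacity",
    "reassign primary capacity",
    "refresh TOM cache",
    "wait for model online",
    "clear VertiPaq cache",
]


def hot_cache_attempt_budget(successful_query_count_goal):
    """Maximum number of priming attempts, as computed in set_hot_cache."""
    return successful_query_count_goal * 5 if successful_query_count_goal > 1 else 1


def cache_prep_steps(storage_mode, cache_type, pause_needed=True):
    """List the preparation steps for one storage mode and cache type."""
    if cache_type == "cold":
        if storage_mode == "DirectLake":
            return list(COLD_STEPS_DIRECT_LAKE)
        # Import/DirectQuery only act when a capacity pause is still needed.
        return list(COLD_STEPS_CAPACITY) if pause_needed else []
    if cache_type == "warm":
        steps = []
        if storage_mode == "DirectQuery":
            steps = cache_prep_steps(storage_mode, "cold", pause_needed)
        return steps + ["run query once", "clear VertiPaq cache"]
    if cache_type == "hot":
        goal = 2  # Default successful_query_count_goal of set_hot_cache
        budget = hot_cache_attempt_budget(goal)
        return [f"run query until {goal} successes (at most {budget} attempts)"]
    raise ValueError(f"Unknown cache type: {cache_type}")


for mode in ("DirectLake", "Import", "DirectQuery"):
    for cache in ("cold", "warm", "hot"):
        print(mode, cache, cache_prep_steps(mode, cache))
```

Read this way, a warm run differs from a hot run mainly in the final VertiPaq cache clear, which is why the two can share the same priming query.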

### Log-table helpers & query checks

In [None]:
@log_function_calls
def get_log_table(_table_name):
    """
    Returns a Spark DataFrame of the existing log table, filtered to
    'QueryEnd' events for the current roundNumber. If the table doesn't exist, returns None.
    """
    try:
        raw_table = spark.table(_table_name)
        base_filters = (
            (col("roundNumber") == additional_arguments["roundNumber"])
            & (col("Event_Class") == "QueryEnd")
        )
        return raw_table.filter(base_filters)
    except Exception:
        return None

@log_function_calls
def max_queries_met(_check_logs, _log_table, _model_cache_combo, _queryId):
    """
    Checks if the maximum number of queries for a given
    model/cache/queryId combo has been met.
    
    The query is skipped if either:
      - The count of successful queries is greater than or equal to additional_arguments["maxNumberPerQuery"], or
      - The count of failed queries is greater than or equal to additional_arguments["maxFailuresBeforeSkipping"].
    """
    if _check_logs and _log_table is not None:
        base_filters = (
            (col("modelName") == _model_cache_combo["model"]["name"])
            & (col("queryId") == _queryId)
            & (col("cacheType") == _model_cache_combo["cache_type"])
        )
        
        success_count = _log_table.filter(
            base_filters & (col("Success") == "Success")
        ).count()
        
        failure_count = _log_table.filter(
            base_filters & (col("Success") == "Failure")
        ).count()

        result = (
            success_count >= additional_arguments["maxNumberPerQuery"]
            or failure_count >= additional_arguments["maxFailuresBeforeSkipping"]
        )
        
        print(
            f"📊 {'Skipping' if result else 'Continuing'} queries (Success: {success_count}, Failure: {failure_count})"
        )
        return result
    else:
        return False

@log_function_calls
def get_starting_query_id(_log_table, additional_arguments):
    """
    Determines the next queryId to start from by checking how many queries
    have fully met the maxNumberPerQuery across all model/cache combos.
    """
    print("🔍 Determining starting query ID")
    if _log_table is not None:
        success_failure_counts = _log_table.groupBy(
            "modelName", "cacheType", "queryId"
        ).agg(
            _sum(when(col("Success") == "Success", 1).otherwise(0)).alias(
                "success_count"
            ),
            _sum(when(col("Success") == "Failure", 1).otherwise(0)).alias(
                "failure_count"
            ),
        )

        # Use Column expressions with '|' (Python's 'or' does not work on Spark Columns).
        valid_queries = success_failure_counts.filter(
            (col("success_count") >= additional_arguments["maxNumberPerQuery"])
            | (col("failure_count") >= additional_arguments["maxFailuresBeforeSkipping"])
        )

        distinct_combos_count = (
            valid_queries.select("modelName", "cacheType").distinct().count()
        )

        valid_query_ids = (
            valid_queries.groupBy("queryId")
            .agg(countDistinct("modelName", "cacheType").alias("valid_combinations"))
            .filter(col("valid_combinations") == distinct_combos_count)
        )

        query_id_list = [
            row["queryId"] for row in valid_query_ids.select("queryId").collect()
        ]
        query_id_list.sort()

        max_query_id = 0
        for i, qid in enumerate(query_id_list):
            if qid != i + 1:
                break
            max_query_id = qid

        starting_query_id = 1 if max_query_id == 0 else max_query_id + 1
    else:
        print(f"ℹ️ Log table {additional_arguments['logTableName']} does not exist")
        starting_query_id = 1

    print(f"✅ Starting query ID set to {starting_query_id}")
    return starting_query_id
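
The final loop in `get_starting_query_id` looks for the longest run of completed query IDs that starts at 1 without gaps. The sketch below isolates that step on plain Python lists; the helper name `first_incomplete_query_id` is hypothetical and the input lists are made-up examples.

```python
def first_incomplete_query_id(completed_query_ids):
    """Return the first queryId not covered by a gap-free run starting at 1."""
    next_id = 1
    for qid in sorted(set(completed_query_ids)):
        if qid != next_id:
            break  # Gap found: everything from next_id onward still needs testing
        next_id += 1
    return next_id


examples = {
    "no logs yet": [],
    "1-3 done": [1, 2, 3],
    "gap at 4": [1, 2, 3, 5],
    "1 missing": [2, 3],
}
starting_ids = {
    label: first_incomplete_query_id(ids) for label, ids in examples.items()
}
for label, start in starting_ids.items():
    print(f"{label}: start at queryId {start}")
```

Because only a gap-free prefix counts, a query ID completed out of order (such as 5 in the third example) does not cause earlier incomplete IDs to be skipped.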

### Main DAX testing orchestration functions

In [None]:
@log_function_calls
def run_dax_query_and_collect_logs(_model_cache_combo, _dax_query, _log_table):
    """
    Runs a single DAX query (given model + cache type + queryId),
    collects logs, and appends them to the table.
    """
    _model = _model_cache_combo["model"]
    used_dax_expression = _dax_query[_model["runQueryType"]]
    query_run_name = f"Model: {_model['name']}, QueryId: {_dax_query['queryId']}, Cache Type: {_model_cache_combo['cache_type']}"
    valid_cache_type_for_model = (
        _model_cache_combo["cache_type"] in _model["cache_types"]
    )

    print(f"🚀 Starting query: {query_run_name}")

    if (
        not max_queries_met(
            additional_arguments["onlyRunNewQueries"],
            _log_table,
            _model_cache_combo,
            _dax_query["queryId"],
        )
        and valid_cache_type_for_model
    ):

        # Record the time before cache setup begins.
        set_cache_start_time = datetime.now().isoformat()

        wait_for_model_to_come_online(_model)

        # Set the desired cache state before running the query.
        if _model_cache_combo["cache_type"] == "cold":
            set_cold_cache(_model)
            cache_set = True
        elif _model_cache_combo["cache_type"] == "warm":
            cache_set = set_warm_cache(_model, used_dax_expression)
        else:  # 'hot'
            cache_set = set_hot_cache(_model, used_dax_expression)

        # Wait for the configured pause, then record the time cache setup ended.
        time.sleep(additional_arguments["pauseAfterSettingCache"])
        set_cache_end_time = datetime.now().isoformat()

        # Mark that the model has been queried for pause/resume tracking.
        update_model_pause_status("model_queried", model=_model)

        query_start_time = datetime.now().isoformat()

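        # Drop any leftover traces (closing that connection afterwards),
        # then name a new trace for the DAX query execution.
        with fabric.create_trace_connection(
            _model["name"], _model["model_workspace_name"]
        ) as cleanup_connection:
            cleanup_connection.drop_traces()
        trace_name = f"Simple DAX Trace {uuid4()}"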

        with fabric.create_trace_connection(
            _model["name"], _model["model_workspace_name"]
        ) as trace_connection:
            with trace_connection.create_trace(event_schema, trace_name) as trace:
                print("🔍 Starting trace for DAX query")
                trace.start()
                wait_for_trace_start(trace_connection, trace_name)

                dax_query_result = "Success"
                try:
                    print("⚡ Executing DAX query")
                    fabric.evaluate_dax(
                        _model["name"],
                        used_dax_expression,
                        workspace=_model["model_workspace_name"],
                    )
                    print("✅ DAX query executed successfully")
                except Exception as e:
                    dax_query_result = str(e)
                    print("❌ DAX query execution failed:", e)

                print("📜 Collecting trace logs")
                wait_for_query_end_event(trace)
                current_query_trace_logs = trace.stop()

                # Extract RequestProperties from the QueryBegin event if available.
                if "Request Properties" in current_query_trace_logs.columns:
                    query_begin_rows = current_query_trace_logs[current_query_trace_logs["Event Class"] == "QueryBegin"]
                    if not query_begin_rows.empty:
                        request_properties_value = query_begin_rows.iloc[0]["Request Properties"]
                    else:
                        request_properties_value = None
                    current_query_trace_logs["Request Properties"] = request_properties_value
                else:
                    current_query_trace_logs["Request Properties"] = None

        time.sleep(additional_arguments["pauseAfterRunningQuery"])
        query_end_time = datetime.now().isoformat()
        
        # Append metadata columns to the trace logs DataFrame.
        current_query_trace_logs = current_query_trace_logs.assign(
            runId=run_id,
            setCacheStartTime=set_cache_start_time,
            setCacheEndTime=set_cache_end_time,
            queryStartTime=query_start_time,
            queryEndTime=query_end_time,
            modelName=_model["name"],
            queryId=_dax_query["queryId"],
            runQueryType=_model["runQueryType"],
            queryUUID=str(uuid4()),
            cacheType=_model_cache_combo["cache_type"],
            queryResult=dax_query_result,
            storageMode=_model["storageMode"],
            sentExpression=used_dax_expression,
            roundNumber=additional_arguments["roundNumber"],
        )

        # If cache was not set properly, mark the query as failed.
        if not cache_set:
            print("❌ Cache was not set properly; marking query as failed")
            current_query_trace_logs = current_query_trace_logs.assign(
                queryResult="Failure"
            )

        # Format column names for Spark by replacing spaces with underscores.
        current_query_trace_logs.columns = current_query_trace_logs.columns.str.replace(
            " ", "_"
        )
        current_query_trace_logs = spark.createDataFrame(current_query_trace_logs)

        print("💾 Appending logs to table")
        current_query_trace_logs.write.format("delta").mode("append").option(
            "mergeSchema", "true"
        ).saveAsTable(additional_arguments["logTableName"])

        print("✅ Logs appended to table")

        print("ℹ️ Pausing between runs")
        time.sleep(additional_arguments["pauseBetweenRuns"])

        return "Ran"

    else:
        print("⏭️ Query skipped (logs exist or invalid cache type)")
        return "Skipped"
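
The post-processing applied to the trace logs before they are written to the log table can be tried in isolation. The sketch below uses a made-up two-row DataFrame in place of real trace output; the event values and metadata are placeholders, not results from an actual run.

```python
import pandas as pd

# Hypothetical trace output: two events from one query, with made-up values.
trace_logs = pd.DataFrame(
    {
        "Event Class": ["QueryBegin", "QueryEnd"],
        "Request Properties": ["<PropertyList/>", None],
        "Duration": [None, 120],
    }
)

# Copy the QueryBegin value of "Request Properties" onto every row.
query_begin_rows = trace_logs[trace_logs["Event Class"] == "QueryBegin"]
request_properties = (
    query_begin_rows.iloc[0]["Request Properties"]
    if not query_begin_rows.empty
    else None
)
trace_logs["Request Properties"] = request_properties

# Attach run metadata as constant columns (placeholder values).
trace_logs = trace_logs.assign(queryId=1, cacheType="warm", queryResult="Success")

# Replace spaces with underscores, as the notebook does before writing to Delta.
trace_logs.columns = trace_logs.columns.str.replace(" ", "_")

print(list(trace_logs.columns))
```

Because the metadata is assigned as constants, every trace event for a query carries the same `queryId`, `cacheType`, and `queryResult`, which keeps later grouping in Spark simple.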

In [None]:
@log_function_calls
def run_dax_queries():
    """
    Primary entry point: runs all queries from the loaded DAX Excel file.
    Handles log table management, capacity checks, and iteration over queries.
    """
    print("🚀 Starting all DAX queries")

    # Initialize the pause status for each model.
    update_model_pause_status("initialize")

    # Handle log table clearing or retrieval based on additional arguments.
    if additional_arguments["clearCurrentRoundLogs"]:
        # Only delete when the table exists; DELETE fails on a missing table.
        if spark.catalog.tableExists(additional_arguments["logTableName"]):
            print(
                f"🗑️ Deleting round {additional_arguments['roundNumber']} logs from {additional_arguments['logTableName']}"
            )
            spark.sql(
                f"DELETE FROM {additional_arguments['logTableName']} WHERE roundNumber = {additional_arguments['roundNumber']}"
            )
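    if additional_arguments["clearAllLogs"]:
        print(f"🗑️ Dropping entire table {additional_arguments['logTableName']}")
        spark.sql(f"DROP TABLE IF EXISTS {additional_arguments['logTableName']}")
        # No log table exists after the drop; define log_table so the query loop can pass it on.
        log_table = None
        startQueryIdsAt = 1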
    else:
        print(f"🔍 Retrieving table {additional_arguments['logTableName']}")
        log_table = get_log_table(additional_arguments["logTableName"])
        startQueryIdsAt = (
            1
            if additional_arguments["clearCurrentRoundLogs"]
            or additional_arguments["forceStartQueriesAt1"]
            or log_table is None
            else get_starting_query_id(log_table, additional_arguments)
        )

    # Check if capacity pause/resume logic is required based on model storage modes.
    include_pause_resume_logic = any(
        model["storageMode"] in ["Import", "DirectQuery"]
        and "cold" in model["cache_types"]
        for model in models
    )

    if include_pause_resume_logic:
        # Validate that all required workspace capacities are defined.
        for m in models:
            ws = m["model_workspace_name"]
            if ws not in workspace_capacities:
                raise ValueError(
                    f"The workspace '{ws}' (in model '{m['name']}') is not found "
                    "in the 'workspace_capacities' dictionary. Please add it."
                )

        # Resume both primary and alternate capacities and reassign the primary capacity.
        for ws, caps in workspace_capacities.items():
            pause_resume_capacity(caps["capacity_name"], "resume", _simplify_logs=True)
            pause_resume_capacity(
                caps["alt_capacity_name"], "resume", _simplify_logs=True
            )
            print(
                f"✅ Assigning primary capacity '{caps['capacity_name']}' to workspace '{ws}'"
            )
            labs.assign_workspace_to_capacity(caps["capacity_name"], ws)

    # Loop over each DAX query from the Excel file.
    for _, dax_query in dax_queries.iterrows():
        if (
            dax_query["queryId"] <= additional_arguments["stopQueryIdsAt"]
            and dax_query["queryId"] >= startQueryIdsAt
        ):
            for _ in range(additional_arguments["numberOfRunsPerQueryId"]):
                total_query_count = 0
                skipped_query_count = 0
                if additional_arguments["randomizeRuns"]:
                    print("🔀 Randomizing run order of (model, cache_type)")
                    model_cache_combo = (
                        pd.DataFrame(
                            itertools.product(models, ["cold", "warm", "hot"]),
                            columns=["model", "cache_type"],
                        )
                        .sample(frac=1)
                        .reset_index(drop=True)
                    )
                else:
                    df = pd.DataFrame(
                        itertools.product(models, ["cold", "warm", "hot"]),
                        columns=["model", "cache_type"],
                    )
                    df["model_name"] = df["model"].apply(lambda m: m["name"])
                    df["cache_order"] = pd.Categorical(
                        df["cache_type"],
                        categories=["cold", "warm", "hot"],
                        ordered=True,
                    )
                    model_cache_combo = (
                        df.sort_values(by=["model_name", "cache_order"])
                        .drop(["model_name", "cache_order"], axis=1)
                        .reset_index(drop=True)
                    )

                for _, current_combo in model_cache_combo.iterrows():
                    total_query_count += 1
                    # Update the pause status when a model is queried.
                    if include_pause_resume_logic:
                        update_model_pause_status(
                            "model_queried", model=current_combo["model"]
                        )
                    run_status = run_dax_query_and_collect_logs(
                        current_combo, dax_query, log_table
                    )
                    if run_status == "Skipped":
                        skipped_query_count += 1

                print(
                    f"🔄 Refreshing log table from {additional_arguments['logTableName']}"
                )
                log_table = get_log_table(additional_arguments["logTableName"])

                if total_query_count == skipped_query_count:
                    print(
                        "ℹ️ No new queries; skipping additional runs for this query group"
                    )
                    break

    # Testing is finished: pause the alternate capacities.
    if include_pause_resume_logic:
        for ws, caps in workspace_capacities.items():
            pause_resume_capacity(
                caps["alt_capacity_name"], "pause", _simplify_logs=True
            )
    print("✅ All queries complete")
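
The run order of (model, cache type) combinations can be illustrated without pandas. This is a minimal sketch using only the standard library; the model names, the `build_run_order` helper, and the `seed` parameter are hypothetical and exist only for this example.

```python
import itertools
import random

# Hypothetical model definitions; only the keys used here are shown.
models = [
    {"name": "Sales DL", "cache_types": ["cold", "warm", "hot"]},
    {"name": "Sales Import", "cache_types": ["warm", "hot"]},
]
CACHE_ORDER = ["cold", "warm", "hot"]


def build_run_order(models, randomize, seed=None):
    """Return (model, cache_type) pairs, shuffled or sorted by model name then cache order."""
    combos = list(itertools.product(models, CACHE_ORDER))
    if randomize:
        random.Random(seed).shuffle(combos)
    else:
        combos.sort(key=lambda c: (c[0]["name"], CACHE_ORDER.index(c[1])))
    return combos


for model, cache_type in build_run_order(models, randomize=False):
    # As in the notebook, combinations not listed in cache_types are skipped at run time.
    status = "run" if cache_type in model["cache_types"] else "skip"
    print(f"{model['name']:<13} {cache_type:<5} {status}")
```

Randomizing the order helps spread any time-dependent effects (for example, capacity load) across models and cache types, while the sorted order makes a run easier to follow in the logs.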

### Execute main flow

In [None]:
run_dax_queries()

### Stop session

In [None]:
mssparkutils.session.stop()