From b4de39800d31752b59bc43c9da30fe52213904ca Mon Sep 17 00:00:00 2001
From: Dayuxiaoshui <792179245@qq.com>
Date: Fri, 14 Nov 2025 18:28:16 +0800
Subject: [PATCH 1/8] feat: implement ES(t) macro/micro cross-validation and
 refactor analysis utilities
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This commit implements the Error-aware Speedup Score (ES_t) metric from
Section 3.2.2 of the technical report (arXiv:2510.24035), along with the
mathematical proofs from Appendix B and C that establish the sample-level
validity of both the S_t and ES_t metrics.

Key Features:
=============

1. Appendix B Implementation - Sample-level proof for S_t:
   - Micro-level calculation: geometric mean of rectified speedups for all
     samples
   - Macro-level calculation: S_t = α^λ · β^(ληp) · b^(1-λ)
   - Cross-validation: both methods produce matching results, confirming that
     S_t equals the geometric mean of sample-level rectified speedups

2. Appendix C Implementation - Sample-level proof for ES_t:
   - Micro-level calculation: geometric mean of error-aware rectified speedups
   - Macro-level calculation: ES_t = α^λ · β^(ληp) · γ_t^(1-λ)
   - Dynamic penalty factor: γ_t = b^(sum(π_c * indicator(t < c)))
   - Cross-validation: confirms that ES_t is the geometric mean of error-aware
     rectified speedups, where failed samples use type-specific dynamic
     penalties instead of the fixed penalty b

3. Error-aware design (Section 3.2.2):
   - Error type classification: c=1 (accuracy), c=2 (runtime crash),
     c=3 (compile failure)
   - Tiered tolerance rules: t≥1 tolerates accuracy errors, t≥2 additionally
     tolerates runtime crashes, t≥3 tolerates all error types
   - The dynamic penalty γ_t adapts to the error type distribution and the
     tolerance level

4. Independent verification script:
   - verify_macro_params.py: calculates and prints all macro parameters
     (alpha, beta, gamma, lambda, eta, pi) independently
   - Enables validation of plot_ESt results by computing each parameter
     separately

5. Mandatory validation mechanism:
   - plot_ESt.py: enforces macro/micro result matching before adoption
   - Rejects results if validation fails, guarding calculation correctness

6. Code refactoring for maintainability:
   - macro_statistics.py: dedicated module for macro parameter calculations
   - Each parameter has an independent function (alpha, beta, gamma, lambda,
     eta, pi)
   - Reduced nesting levels in analysis_util.py by extracting helper functions
   - Simplified scan_all_folders and added .txt file support
   - Improved code organization following software engineering best practices

Technical Details:
==================

- Micro calculation: processes each sample individually, applies the rectified
  speedup rules, then computes the geometric mean
- Macro calculation: uses aggregated statistics (correct count, speedup
  distributions, error type proportions) to compute the expected values
- Validation: compares micro and macro results against a tolerance threshold
  (1e-6)
- All calculations verified against real benchmark data (118 samples)

Files Changed:
==============

- graph_net/analysis_util.py: refactored with helper functions, integrated the
  macro_statistics module, reduced nesting, simplified scan_all_folders
- graph_net/macro_statistics.py: new module for macro parameter calculations
- graph_net/plot_ESt.py: added mandatory macro/micro validation
- graph_net/verify_macro_params.py: new independent verification script

All code passes pre-commit checks, compiles successfully, and has been
validated with real benchmark data.
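
Worked Example (illustrative):
==============================

The sketch below is NOT part of this patch; it is a minimal, self-contained
reproduction of the macro/micro identity on hypothetical toy data. Following
the γ_t formula above, it assumes a tolerated error contributes a neutral
factor of 1 to the geometric mean, and that every failed sample carries an
error code c in {1, 2, 3}:

    from scipy.stats import gmean

    # (speedup, error_code): None = correct, 1 = accuracy error,
    # 2 = runtime crash, 3 = compile failure
    samples = [(2.0, None), (0.5, None), (1.5, None), (0.8, 1), (None, 3)]
    p, b, t = 0.5, 0.1, 2  # penalty power, base penalty, tolerance level

    def es_rectify(speedup, err):
        if err is not None:               # failed sample
            return b if t < err else 1.0  # dynamic penalty b**indicator(t < c)
        return speedup ** (p + 1) if speedup < 1.0 else speedup

    # Micro level: geometric mean of per-sample error-aware rectified speedups
    micro_es = gmean([es_rectify(s, e) for s, e in samples])

    # Macro level: ES_t = alpha^lambda * beta^(lambda*eta*p) * gamma^(1-lambda)
    correct = [s for s, e in samples if e is None]
    slow = [s for s in correct if s < 1.0]
    lam = len(correct) / len(samples)  # lambda: fraction of correct samples
    eta = len(slow) / len(correct)     # eta: slowdown fraction among correct
    alpha = gmean(correct)
    beta = gmean(slow) if slow else 1.0
    errors = [e for _, e in samples if e is not None]
    pi = {c: sum(e == c for e in errors) / len(errors) for c in set(errors)}
    gamma = b ** sum(p_c for c, p_c in pi.items() if t < c)

    macro_es = alpha ** lam * beta ** (lam * eta * p) * gamma ** (1 - lam)
    assert abs(micro_es - macro_es) < 1e-9  # the Appendix B/C identity

For t < 1 every error type is penalized, so γ_t reduces to b and ES(t)
coincides with S(t); this is the same branch structure implemented in
calculate_es_rectified_speedup below.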
--- graph_net/analysis_util.py | 252 ++++++++++++++++++++---------- graph_net/macro_statistics.py | 257 +++++++++++++++++++++++++++++++ graph_net/plot_ESt.py | 54 ++++++- graph_net/verify_macro_params.py | 244 +++++++++++++++++++++++++++++ 4 files changed, 721 insertions(+), 86 deletions(-) create mode 100644 graph_net/macro_statistics.py create mode 100644 graph_net/verify_macro_params.py diff --git a/graph_net/analysis_util.py b/graph_net/analysis_util.py index d5150f8c8..5b4a6486a 100644 --- a/graph_net/analysis_util.py +++ b/graph_net/analysis_util.py @@ -4,7 +4,9 @@ import numpy as np from scipy.stats import gmean from collections import OrderedDict, defaultdict +from typing import Tuple from graph_net.config.datatype_tolerance_config import get_precision +from graph_net import macro_statistics def extract_speedup_data_from_subdirs(benchmark_path: str) -> dict: @@ -414,6 +416,114 @@ def get_correctness(dtype: str, t: int, correctness_data: dict, index: int) -> b return False +def check_sample_correctness(sample: dict, t_key: int) -> Tuple[bool, str]: + """ + Check if a sample is correct at the given tolerance level. + + Args: + sample: Sample data dictionary + t_key: Tolerance level + + Returns: + Tuple of (is_correct, fail_type) + - is_correct: True if sample is correct at this tolerance + - fail_type: Error type if not correct, None if correct + """ + performance_data = sample.get("performance", {}) + fail_type = performance_data.get("failure") + + # If there's already a failure type, return it + if fail_type is not None: + return False, fail_type + + # Check correctness based on datatype and tolerance + datatype_data = performance_data.get("datatype", {}) + eager_dtypes = datatype_data.get("eager", []) + compiled_dtypes = datatype_data.get("compiled", []) + + # Check if datatypes match and are valid + if not (eager_dtypes and eager_dtypes == compiled_dtypes and len(eager_dtypes) > 0): + return False, "accuracy" + + correctness_data = sample.get("correctness", {}) + output_count = len(correctness_data.get("[equal]", [])) + + if len(eager_dtypes) != output_count: + return False, "accuracy" + + # Check all outputs for correctness + is_correct = all( + get_correctness(eager_dtypes[i], t_key, correctness_data, i) + for i in range(output_count) + ) + + return is_correct, None if is_correct else "accuracy" + + +def calculate_rectified_speedup( + speedup: float, fail_type: str, negative_speedup_penalty: float, fpdb: float +) -> float: + """ + Calculate rectified speedup for S(t) calculation. + + Args: + speedup: Original speedup value + fail_type: Error type or None if correct + negative_speedup_penalty: Penalty power p for negative speedup + fpdb: Base penalty for failures + + Returns: + Rectified speedup value + """ + if fail_type is not None or speedup is None: + return fpdb + + if speedup < 1: + return speedup ** (negative_speedup_penalty + 1) + return speedup + + +def calculate_es_rectified_speedup( + speedup: float, + fail_type: str, + t_key: int, + is_correct_at_t1: bool, + speedup_at_t1: float, + fail_type_at_t1: str, + negative_speedup_penalty: float, + fpdb: float, +) -> float: + """ + Calculate rectified speedup for ES(t) calculation. 
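+
+    Illustrative example (added for documentation, not from the original
+    tests; assumes p=1, b=0.1 and a sample frozen as correct-but-slow at t=1):
+        >>> calculate_es_rectified_speedup(
+        ...     speedup=0.8, fail_type=None, t_key=2,
+        ...     is_correct_at_t1=True, speedup_at_t1=0.5,
+        ...     fail_type_at_t1="CORRECT",
+        ...     negative_speedup_penalty=1, fpdb=0.1,
+        ... )
+        0.25
+
+        (The current speedup 0.8 is ignored; the state frozen at t=1 drives
+        the result.)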
+ + Args: + speedup: Current speedup value + fail_type: Current error type + t_key: Current tolerance level + is_correct_at_t1: Whether sample was correct at t=1 + speedup_at_t1: Speedup value at t=1 + fail_type_at_t1: Error type at t=1 + negative_speedup_penalty: Penalty power p + fpdb: Base penalty for failures + + Returns: + Error-aware rectified speedup value + """ + if t_key < 1: + # For t < 1, ES(t) = S(t) + return calculate_rectified_speedup( + speedup, fail_type, negative_speedup_penalty, fpdb + ) + + # For t >= 1, use frozen state from t=1 + if not is_correct_at_t1 or speedup_at_t1 is None: + return fake_perf_degrad(t_key, fail_type_at_t1, fpdb) + + if speedup_at_t1 < 1: + return speedup_at_t1 ** (negative_speedup_penalty + 1) + return speedup_at_t1 + + def fake_perf_degrad(t, error_code, fpdb=0.1): """ Calculate fake performance degradation based on tolerance t and error code. @@ -445,6 +555,9 @@ def calculate_s_scores( """ s_scores = OrderedDict() s_scores_fake_degrad = OrderedDict() + # Store macro-level calculation results for cross-validation + s_scores._macro_results = OrderedDict() + s_scores_fake_degrad._macro_results = OrderedDict() begin = -10 end = 4 @@ -462,33 +575,34 @@ def print_stat_info( correct_speedups, slowdown_speedups, ): + """ + Calculate and print macro statistics for a given tolerance level. + + Uses the macro_statistics module for all parameter calculations. + """ print(f" - Details for tolerance={t_key}:") if total_samples > 0: - alpha = gmean(correct_speedups) if correct_speedups else 1 - beta = gmean(slowdown_speedups) if slowdown_speedups else 1 - lambda_ = correct_count / total_samples if total_samples > 0 else 0 - eta = ( - correct_negative_speedup_count / correct_count - if correct_count > 0 - else 0 - ) - indicator = [1 if t_key < 1 else 0, 1 if t_key < 3 else 0] - gamma = ( - fpdb ** sum(pi[i] * indicator[i] for i in range(len(pi))) - if t_key >= 1 - else fpdb + # Calculate all macro parameters using the dedicated module + macro_params = macro_statistics.calculate_all_macro_parameters( + correct_count=correct_count, + total_samples=total_samples, + correct_negative_speedup_count=correct_negative_speedup_count, + correct_speedups=correct_speedups, + slowdown_speedups=slowdown_speedups, + acc_failure_count=acc_failure_count, + t_key=t_key, + negative_speedup_penalty=negative_speedup_penalty, + fpdb=fpdb, + pi=pi, ) - expected_s = ( - alpha**lambda_ - * beta ** (lambda_ * eta * negative_speedup_penalty) - * fpdb ** (1 - lambda_) - ) - expected_es = ( - alpha**lambda_ - * beta ** (lambda_ * eta * negative_speedup_penalty) - * gamma ** (1 - lambda_) - ) + alpha = macro_params["alpha"] + beta = macro_params["beta"] + lambda_ = macro_params["lambda"] + eta = macro_params["eta"] + gamma = macro_params["gamma"] + expected_s = macro_params["s_t"] + expected_es = macro_params["es_t"] print( f" - alpha: {alpha:.3f} (Geometric mean speedup of correct samples)" @@ -501,11 +615,14 @@ def print_stat_info( ) else: print(" - No samples to analyze.") + expected_s = fpdb + expected_es = fpdb return expected_s, expected_es - # pi is a list of constants for t > 0 for each group - pi = [0, 0] + # pi is a tuple of constants for t > 0 for each group: (pi[0], pi[1]) + # Calculated at t=1, used for all t >= 1 + pi = (0.0, 0.0) is_correct_at_t1 = [False] * total_samples speedup_at_t1 = [None] * total_samples @@ -525,31 +642,13 @@ def print_stat_info( correct_speedups = [] slowdown_speedups = [] + # Process all samples using helper functions to reduce nesting for idx, sample 
in enumerate(samples): performance_data = sample.get("performance", {}) - fail_type = performance_data.get("failure") speedup = performance_data.get("speedup", {}).get("e2e") - # Determine the true state of the current sample (for statistics and S curve) - is_correct = False - if fail_type is None: - datatype_data = performance_data.get("datatype", {}) - eager_dtypes = datatype_data.get("eager", []) - compiled_dtypes = datatype_data.get("compiled", []) - if ( - eager_dtypes - and eager_dtypes == compiled_dtypes - and len(eager_dtypes) > 0 - ): - correctness_data = sample.get("correctness", {}) - output_count = len(correctness_data.get("[equal]", [])) - if len(eager_dtypes) == output_count: - is_correct = all( - get_correctness(eager_dtypes[i], t_key, correctness_data, i) - for i in range(output_count) - ) - if not is_correct: - fail_type = "accuracy" + # Check correctness using dedicated function + is_correct, fail_type = check_sample_correctness(sample, t_key) # Collect statistics if is_correct: @@ -563,53 +662,35 @@ def print_stat_info( if fail_type == "accuracy": acc_failure_count += 1 + # Store state at t=1 for ES(t) calculation if t_key == 1: is_correct_at_t1[idx] = is_correct speedup_at_t1[idx] = speedup fail_type_at_t1[idx] = fail_type if fail_type is not None else "CORRECT" - # S(t) calculation - if fail_type is not None or speedup is None: - regularized_speedup = fpdb - else: - regularized_speedup = ( - speedup ** (negative_speedup_penalty + 1) - if speedup < 1 - else speedup - ) + # Calculate rectified speedups using dedicated functions + regularized_speedup = calculate_rectified_speedup( + speedup, fail_type, negative_speedup_penalty, fpdb + ) rectified_speedups.append(regularized_speedup) - # ES(t) calculation: based on state change - if t_key < 1: - if fail_type is not None or speedup is None: - rec_speedup_fake_degrad = fpdb - else: - rec_speedup_fake_degrad = ( - speedup ** (negative_speedup_penalty + 1) - if speedup < 1 - else speedup - ) - else: - if not is_correct_at_t1[idx] or speedup_at_t1[idx] is None: - fail_type_frozen = fail_type_at_t1[idx] - rec_speedup_fake_degrad = fake_perf_degrad( - t_key, fail_type_frozen, fpdb - ) - else: - rec_speedup_fake_degrad = ( - speedup_at_t1[idx] ** (negative_speedup_penalty + 1) - if speedup_at_t1[idx] < 1 - else speedup_at_t1[idx] - ) + rec_speedup_fake_degrad = calculate_es_rectified_speedup( + speedup, + fail_type, + t_key, + is_correct_at_t1[idx], + speedup_at_t1[idx], + fail_type_at_t1[idx], + negative_speedup_penalty, + fpdb, + ) rectified_speedups_fake_degrad.append(rec_speedup_fake_degrad) if t_key == 1: - if total_samples == correct_count: - pi[0] = 0 - pi[1] = 0 - else: - pi[0] = acc_failure_count / (total_samples - correct_count) - pi[1] = 1 - pi[0] + # Calculate pi at t=1 using the dedicated function + pi = macro_statistics.calculate_pi( + acc_failure_count, total_samples, correct_count + ) final_correct_count = correct_count final_correct_negative_speedup_count = correct_negative_speedup_count final_correct_speedups = correct_speedups @@ -644,7 +725,10 @@ def print_stat_info( print( f" - S(t)={expected_s:.3f}, ES(t)={expected_es:.3f} for tolerance={t_key} from macro level." 
) + # Store macro results for cross-validation + s_scores._macro_results[t_key] = expected_s + s_scores_fake_degrad._macro_results[t_key] = expected_es - print(f" - pi: {pi}") + print(f" - pi: {list(pi)}") return s_scores, s_scores_fake_degrad diff --git a/graph_net/macro_statistics.py b/graph_net/macro_statistics.py new file mode 100644 index 000000000..bfb2b55d6 --- /dev/null +++ b/graph_net/macro_statistics.py @@ -0,0 +1,257 @@ +""" +Macro-level statistics calculation module for S(t) and ES(t) metrics. + +This module provides independent functions for calculating each macro parameter +(alpha, beta, gamma, lambda, eta, pi) according to Appendix B and C of the paper. +""" + +from scipy.stats import gmean +from typing import List, Tuple + + +def calculate_alpha(correct_speedups: List[float]) -> float: + """ + Calculate alpha: geometric mean of correct sample speedups. + + According to Appendix B: alpha is the geometric mean of all correct sample speedups. + + Args: + correct_speedups: List of speedup values for correct samples + + Returns: + Alpha value (geometric mean), or 1.0 if list is empty + """ + return gmean(correct_speedups) if correct_speedups else 1.0 + + +def calculate_beta(slowdown_speedups: List[float]) -> float: + """ + Calculate beta: geometric mean of slowdown sample speedups. + + According to Appendix B: beta is the geometric mean of speedups for samples + that are correct but have speedup < 1 (slowdown cases). + + Args: + slowdown_speedups: List of speedup values for slowdown cases (speedup < 1) + + Returns: + Beta value (geometric mean), or 1.0 if list is empty + """ + return gmean(slowdown_speedups) if slowdown_speedups else 1.0 + + +def calculate_lambda(correct_count: int, total_samples: int) -> float: + """ + Calculate lambda: fraction of correct samples. + + According to Appendix B: lambda = M / N, where M is correct count and N is total samples. + + Args: + correct_count: Number of correct samples + total_samples: Total number of samples + + Returns: + Lambda value (fraction of correct samples), or 0.0 if total_samples is 0 + """ + return correct_count / total_samples if total_samples > 0 else 0.0 + + +def calculate_eta(correct_negative_speedup_count: int, correct_count: int) -> float: + """ + Calculate eta: fraction of slowdown cases within correct samples. + + According to Appendix B: eta = K / M, where K is the number of slowdown cases + within correct samples, and M is the total correct count. + + Args: + correct_negative_speedup_count: Number of correct samples with speedup < 1 + correct_count: Total number of correct samples + + Returns: + Eta value (fraction of slowdown cases), or 0.0 if correct_count is 0 + """ + return correct_negative_speedup_count / correct_count if correct_count > 0 else 0.0 + + +def calculate_pi( + acc_failure_count: int, total_samples: int, correct_count: int +) -> Tuple[float, float]: + """ + Calculate pi: error type proportions for t > 0. + + According to Appendix C: pi[0] is the proportion of accuracy errors (c=1), + pi[1] is the proportion of other errors (c=2,3) among all error samples. 
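+
+    Illustrative example (added for documentation, not from the original
+    tests): 3 accuracy failures among 10 - 6 = 4 error samples:
+        >>> calculate_pi(acc_failure_count=3, total_samples=10, correct_count=6)
+        (0.75, 0.25)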
+ + Args: + acc_failure_count: Number of accuracy failure samples + total_samples: Total number of samples + correct_count: Number of correct samples + + Returns: + Tuple of (pi[0], pi[1]) where: + - pi[0]: proportion of accuracy errors among error samples + - pi[1]: proportion of other errors among error samples (1 - pi[0]) + """ + error_count = total_samples - correct_count + if error_count == 0: + return 0.0, 0.0 + + pi_0 = acc_failure_count / error_count + pi_1 = 1.0 - pi_0 + return pi_0, pi_1 + + +def calculate_gamma(t_key: int, pi: Tuple[float, float], fpdb: float = 0.1) -> float: + """ + Calculate gamma_t: average error penalty factor. + + According to Appendix C: gamma_t = b^(sum(π_c * indicator(t < c))) + where: + - pi[0]: proportion of accuracy errors (c=1, tolerated when t >= 1) + - pi[1]: proportion of other errors (c=2,3, tolerated when t >= 3) + - indicator[0] = 1 if t < 1 else 0 (accuracy errors not tolerated) + - indicator[1] = 1 if t < 3 else 0 (runtime/compile errors not tolerated) + + Args: + t_key: Tolerance level + pi: Tuple of (pi[0], pi[1]) error type proportions + fpdb: Base penalty for severe errors (default: 0.1) + + Returns: + Gamma value (average error penalty) + """ + if t_key < 1: + return fpdb + + # indicator[0] = 1 if t < 1 else 0 (accuracy errors not tolerated) + # indicator[1] = 1 if t < 3 else 0 (runtime/compile errors not tolerated) + indicator = [1 if t_key < 1 else 0, 1 if t_key < 3 else 0] + pi_sum = sum(pi[i] * indicator[i] for i in range(len(pi))) + return fpdb**pi_sum + + +def calculate_s_t_macro( + alpha: float, + beta: float, + lambda_: float, + eta: float, + negative_speedup_penalty: float, + fpdb: float, +) -> float: + """ + Calculate S(t) from macro parameters. + + According to Appendix B: S_t = α^λ · β^(ληp) · b^(1-λ) + + Args: + alpha: Geometric mean speedup of correct samples + beta: Geometric mean speedup of slowdown cases + lambda_: Fraction of correct samples + eta: Fraction of slowdown cases within correct samples + negative_speedup_penalty: Penalty power p for negative speedup + fpdb: Base penalty b for severe errors + + Returns: + S(t) value calculated from macro parameters + """ + return ( + alpha**lambda_ + * beta ** (lambda_ * eta * negative_speedup_penalty) + * fpdb ** (1 - lambda_) + ) + + +def calculate_es_t_macro( + alpha: float, + beta: float, + lambda_: float, + eta: float, + gamma: float, + negative_speedup_penalty: float, +) -> float: + """ + Calculate ES(t) from macro parameters. + + According to Appendix C: ES_t = α^λ · β^(ληp) · γ_t^(1-λ) + + Args: + alpha: Geometric mean speedup of correct samples + beta: Geometric mean speedup of slowdown cases + lambda_: Fraction of correct samples + eta: Fraction of slowdown cases within correct samples + gamma: Average error penalty factor + negative_speedup_penalty: Penalty power p for negative speedup + + Returns: + ES(t) value calculated from macro parameters + """ + return ( + alpha**lambda_ + * beta ** (lambda_ * eta * negative_speedup_penalty) + * gamma ** (1 - lambda_) + ) + + +def calculate_all_macro_parameters( + correct_count: int, + total_samples: int, + correct_negative_speedup_count: int, + correct_speedups: List[float], + slowdown_speedups: List[float], + acc_failure_count: int, + t_key: int, + negative_speedup_penalty: float = 0.0, + fpdb: float = 0.1, + pi: Tuple[float, float] = (0.0, 0.0), +) -> dict: + """ + Calculate all macro parameters for a given tolerance level. + + This is a convenience function that calculates all macro parameters at once. 
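+
+    Illustrative example (added for documentation, not from the original
+    tests): 3 of 5 samples correct, one slowdown, evaluated at t=2:
+        >>> params = calculate_all_macro_parameters(
+        ...     correct_count=3, total_samples=5,
+        ...     correct_negative_speedup_count=1,
+        ...     correct_speedups=[2.0, 0.5, 1.5], slowdown_speedups=[0.5],
+        ...     acc_failure_count=1, t_key=2, pi=(0.5, 0.5),
+        ... )
+        >>> round(params["lambda"], 2), round(params["gamma"], 3)
+        (0.6, 0.316)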
+ + Args: + correct_count: Number of correct samples + total_samples: Total number of samples + correct_negative_speedup_count: Number of slowdown cases + correct_speedups: List of speedup values for correct samples + slowdown_speedups: List of speedup values for slowdown cases + acc_failure_count: Number of accuracy failure samples + t_key: Tolerance level + negative_speedup_penalty: Penalty power p for negative speedup + fpdb: Base penalty b for severe errors + pi: Tuple of (pi[0], pi[1]) error type proportions (calculated at t=1) + + Returns: + Dictionary containing all macro parameters and calculated scores: + { + 'alpha': float, + 'beta': float, + 'lambda': float, + 'eta': float, + 'gamma': float, + 'pi': Tuple[float, float], + 's_t': float, + 'es_t': float + } + """ + alpha = calculate_alpha(correct_speedups) + beta = calculate_beta(slowdown_speedups) + lambda_ = calculate_lambda(correct_count, total_samples) + eta = calculate_eta(correct_negative_speedup_count, correct_count) + gamma = calculate_gamma(t_key, pi, fpdb) + + s_t = calculate_s_t_macro(alpha, beta, lambda_, eta, negative_speedup_penalty, fpdb) + es_t = calculate_es_t_macro( + alpha, beta, lambda_, eta, gamma, negative_speedup_penalty + ) + + return { + "alpha": alpha, + "beta": beta, + "lambda": lambda_, + "eta": eta, + "gamma": gamma, + "pi": pi, + "s_t": s_t, + "es_t": es_t, + } diff --git a/graph_net/plot_ESt.py b/graph_net/plot_ESt.py index 38482688c..5281fc555 100644 --- a/graph_net/plot_ESt.py +++ b/graph_net/plot_ESt.py @@ -138,8 +138,10 @@ def main(): print("No valid data found. Exiting.") return - # 2. Calculate scores for each curve + # 2. Calculate scores for each curve and verify macro/micro consistency all_es_scores = {} + tolerance_threshold = 1e-6 # Tolerance for floating point comparison + for folder_name, samples in all_results.items(): _, es_scores = analysis_util.calculate_s_scores( samples, @@ -147,7 +149,55 @@ def main(): negative_speedup_penalty=args.negative_speedup_penalty, fpdb=args.fpdb, ) - all_es_scores[folder_name] = es_scores + + # Verify macro/micro consistency + macro_results = getattr(es_scores, "_macro_results", {}) + verified_scores = {} + all_matched = True + + print(f"\n{'='*80}") + print(f"Verifying Macro/Micro Consistency for '{folder_name}'") + print(f"{'='*80}") + + for t_key, micro_es in es_scores.items(): + macro_es = macro_results.get(t_key) + + if macro_es is None: + print(f"ERROR: No macro result for t={t_key}, cannot verify") + all_matched = False + continue + + # Compare macro and micro results + diff = abs(micro_es - macro_es) + relative_diff = diff / max(abs(micro_es), abs(macro_es), 1e-10) + + if diff < tolerance_threshold or relative_diff < tolerance_threshold: + # Results match, use micro result + verified_scores[t_key] = micro_es + print( + f"t={t_key:3d}: MATCHED - Micro: {micro_es:.6f}, Macro: {macro_es:.6f}, Diff: {diff:.2e}" + ) + else: + # Results don't match - mark as failed + all_matched = False + print( + f"t={t_key:3d}: MISMATCH - Micro: {micro_es:.6f}, Macro: {macro_es:.6f}, Diff: {diff:.2e} ({relative_diff*100:.4f}%)" + ) + # Don't add to verified_scores if mismatch + + if not all_matched: + print(f"\nERROR: Macro and micro results do not match for '{folder_name}'!") + print( + "Calculation validation failed. Results will NOT be used for plotting." 
+ ) + print("Please verify the calculation logic using verify_macro_params.py") + print(f"{'='*80}\n") + continue # Skip this curve if validation fails + + print(f"\nSUCCESS: All macro and micro results match for '{folder_name}'.") + print(f"{'='*80}\n") + + all_es_scores[folder_name] = verified_scores # 3. Plot the results if any(all_es_scores.values()): diff --git a/graph_net/verify_macro_params.py b/graph_net/verify_macro_params.py new file mode 100644 index 000000000..4b68923f2 --- /dev/null +++ b/graph_net/verify_macro_params.py @@ -0,0 +1,244 @@ +import os +import argparse +import numpy as np +from collections import OrderedDict +from graph_net import analysis_util +from graph_net import macro_statistics + + +def calculate_macro_parameters( + samples: list, + folder_name: str, + negative_speedup_penalty: float = 0, + fpdb: float = 0.1, +) -> dict: + """ + Calculate and print all macro parameters (alpha, beta, gamma, lambda, eta, pi) + for each tolerance level independently. + + This function extracts the macro parameter calculation logic from calculate_s_scores + to verify the correctness of macro-level calculations. + + Returns: + Dictionary mapping tolerance -> dict of macro parameters and calculated scores + """ + begin = -10 + end = 4 + t_keys = list(range(begin, end + 1)) + total_samples = len(samples) + + print(f"\n{'='*80}") + print(f"Verifying Macro Parameters for '{folder_name}'") + print(f"{'='*80}") + + # pi is a tuple of constants for t > 0 for each group: (pi[0], pi[1]) + # Calculated at t=1, used for all t >= 1 + pi = (0.0, 0.0) + + # Store state at t=1 for ES(t) calculation + is_correct_at_t1 = [False] * total_samples + speedup_at_t1 = [None] * total_samples + fail_type_at_t1 = ["CORRECT"] * total_samples + + # Final statistics at t=1 + final_correct_count = 0 + final_correct_negative_speedup_count = 0 + final_correct_speedups = [] + final_slowdown_speedups = [] + + results = OrderedDict() + + for t_key in t_keys: + correct_count = 0 + acc_failure_count = 0 + correct_negative_speedup_count = 0 + correct_speedups = [] + slowdown_speedups = [] + + # Collect statistics for current tolerance using helper function + for idx, sample in enumerate(samples): + performance_data = sample.get("performance", {}) + speedup = performance_data.get("speedup", {}).get("e2e") + + # Check correctness using dedicated function + is_correct, fail_type = analysis_util.check_sample_correctness( + sample, t_key + ) + + # Collect statistics + if is_correct: + correct_count += 1 + if speedup is not None: + correct_speedups.append(speedup) + if speedup is not None and speedup < 1: + correct_negative_speedup_count += 1 + slowdown_speedups.append(speedup) + + if fail_type == "accuracy": + acc_failure_count += 1 + + # Store state at t=1 + if t_key == 1: + is_correct_at_t1[idx] = is_correct + speedup_at_t1[idx] = speedup + fail_type_at_t1[idx] = fail_type if fail_type is not None else "CORRECT" + + # Calculate pi at t=1 using the dedicated function + if t_key == 1: + pi = macro_statistics.calculate_pi( + acc_failure_count, total_samples, correct_count + ) + final_correct_count = correct_count + final_correct_negative_speedup_count = correct_negative_speedup_count + final_correct_speedups = correct_speedups + final_slowdown_speedups = slowdown_speedups + + # Calculate macro parameters + if total_samples > 0: + # For t < 1, use current tolerance statistics + # For t >= 1, use t=1 statistics (frozen state) + if t_key < 1: + stats_correct_count = correct_count + stats_correct_negative_speedup_count = 
correct_negative_speedup_count + stats_correct_speedups = correct_speedups + stats_slowdown_speedups = slowdown_speedups + else: + stats_correct_count = final_correct_count + stats_correct_negative_speedup_count = ( + final_correct_negative_speedup_count + ) + stats_correct_speedups = final_correct_speedups + stats_slowdown_speedups = final_slowdown_speedups + + # Calculate all macro parameters using the dedicated module + macro_params = macro_statistics.calculate_all_macro_parameters( + correct_count=stats_correct_count, + total_samples=total_samples, + correct_negative_speedup_count=stats_correct_negative_speedup_count, + correct_speedups=stats_correct_speedups, + slowdown_speedups=stats_slowdown_speedups, + acc_failure_count=acc_failure_count, + t_key=t_key, + negative_speedup_penalty=negative_speedup_penalty, + fpdb=fpdb, + pi=pi, + ) + + alpha = macro_params["alpha"] + beta = macro_params["beta"] + lambda_ = macro_params["lambda"] + eta = macro_params["eta"] + gamma = macro_params["gamma"] + expected_s = macro_params["s_t"] + expected_es = macro_params["es_t"] + + results[t_key] = { + "alpha": alpha, + "beta": beta, + "gamma": gamma, + "lambda": lambda_, + "eta": eta, + "pi": pi, + "expected_s": expected_s, + "expected_es": expected_es, + "correct_count": stats_correct_count, + "total_samples": total_samples, + "correct_speedups_count": len(stats_correct_speedups), + "slowdown_count": len(stats_slowdown_speedups), + } + + # Print detailed information + print(f"\nTolerance t = {t_key}:") + print(f" Total samples: {total_samples}") + print(f" Correct samples: {stats_correct_count} (lambda = {lambda_:.6f})") + print(f" Correct speedups collected: {len(stats_correct_speedups)}") + print(f" Slowdown cases: {len(stats_slowdown_speedups)} (eta = {eta:.6f})") + print(f" alpha (geometric mean of correct speedups): {alpha:.6f}") + if stats_correct_speedups: + print( + f" - Correct speedups: {stats_correct_speedups[:10]}{'...' if len(stats_correct_speedups) > 10 else ''}" + ) + print(f" beta (geometric mean of slowdown speedups): {beta:.6f}") + if stats_slowdown_speedups: + print( + f" - Slowdown speedups: {stats_slowdown_speedups[:10]}{'...' 
if len(stats_slowdown_speedups) > 10 else ''}" + ) + print(f" gamma (average error penalty): {gamma:.6f}") + if t_key >= 1: + indicator = [1 if t_key < 1 else 0, 1 if t_key < 3 else 0] + pi_indicator_sum = sum(pi[i] * indicator[i] for i in range(len(pi))) + print(f" - pi: {list(pi)}") + print(f" - indicator: {indicator}") + print( + f" - gamma = fpdb^(sum(pi[i] * indicator[i])) = {fpdb}^{pi_indicator_sum:.6f} = {gamma:.6f}" + ) + print(f" Expected S(t) from macro: {expected_s:.6f}") + print(f" Expected ES(t) from macro: {expected_es:.6f}") + else: + results[t_key] = { + "alpha": 1.0, + "beta": 1.0, + "gamma": fpdb, + "lambda": 0.0, + "eta": 0.0, + "pi": pi, + "expected_s": fpdb, + "expected_es": fpdb, + "correct_count": 0, + "total_samples": 0, + "correct_speedups_count": 0, + "slowdown_count": 0, + } + print(f"\nTolerance t = {t_key}: No samples to analyze") + + print(f"\n{'='*80}") + print(f"Macro Parameter Verification Complete") + print(f"{'='*80}\n") + + return results + + +def main(): + """Main execution function for verifying macro parameters.""" + parser = argparse.ArgumentParser( + description="Verify macro parameters (alpha, beta, gamma, lambda, eta, pi) calculation.", + formatter_class=argparse.RawTextHelpFormatter, + ) + parser.add_argument( + "--benchmark-path", + type=str, + required=True, + help="Path to the benchmark log file or directory containing benchmark JSON files or sub-folders.", + ) + parser.add_argument( + "--negative-speedup-penalty", + type=float, + default=0.0, + help="Penalty power (p) for negative speedup. Formula: speedup**(p+1). Default: 0.0.", + ) + parser.add_argument( + "--fpdb", + type=float, + default=0.1, + help="Base penalty for severe errors (e.g., crashes, correctness failures).", + ) + args = parser.parse_args() + + # Scan folders to get data + all_results = analysis_util.scan_all_folders(args.benchmark_path) + if not all_results: + print("No valid data found. Exiting.") + return + + # Calculate and print macro parameters for each curve + for folder_name, samples in all_results.items(): + macro_results = calculate_macro_parameters( + samples, + folder_name, + negative_speedup_penalty=args.negative_speedup_penalty, + fpdb=args.fpdb, + ) + + +if __name__ == "__main__": + main() From 603371669450f58da1731717c11114c6c7ee28e8 Mon Sep 17 00:00:00 2001 From: Dayuxiaoshui <792179245@qq.com> Date: Sun, 16 Nov 2025 18:53:21 +0800 Subject: [PATCH 2/8] refactor: rename macro to aggregated and improve code quality MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This commit refactors the evaluation metrics calculation code with the following improvements: 1. Terminology refactoring: macro -> aggregated - Rename macro_statistics.py to samples_statistics.py - Rename verify_macro_params.py to verify_aggregated_params.py - Update all variable and function names accordingly 2. Code structure improvements - Extract verification logic in plot_ESt.py into separate functions * compare_single_tolerance_level (12 lines) * print_verification_result (1 line) * verify_aggregated_micro_consistency (28 lines, meets ≤30 line requirement) - Refactor verify_aggregated_params.py to use functional programming style * Replace structured loops with list comprehensions * Use Counter for error type counting * Reduce multiple traversals to single pass where possible 3. 
Reduce function parameter coupling - calculate_beta: derive slowdown_speedups internally from correct_speedups - calculate_lambda: derive correct_count internally from correct_speedups - calculate_eta: derive statistics internally from correct_speedups 4. Decouple error type handling - calculate_pi: accept error_type_counts (dict) instead of hardcoded types - calculate_gamma: accept generic parameters (tolerance, get_pi, errno_tolerances) - Support user-defined error codes instead of hardcoded error types 5. Code quality improvements - Use explicit len() checks instead of implicit boolean conversion - Use modern Python type hints (list/tuple instead of typing.List/Tuple) - Improve code readability and maintainability All changes have been verified and pass pre-commit checks. --- graph_net/analysis_util.py | 73 ++--- graph_net/macro_statistics.py | 257 --------------- graph_net/plot_ESt.py | 154 ++++++--- graph_net/samples_statistics.py | 297 ++++++++++++++++++ ..._params.py => verify_aggregated_params.py} | 133 ++++---- 5 files changed, 511 insertions(+), 403 deletions(-) delete mode 100644 graph_net/macro_statistics.py create mode 100644 graph_net/samples_statistics.py rename graph_net/{verify_macro_params.py => verify_aggregated_params.py} (62%) diff --git a/graph_net/analysis_util.py b/graph_net/analysis_util.py index 5b4a6486a..cb143aed5 100644 --- a/graph_net/analysis_util.py +++ b/graph_net/analysis_util.py @@ -4,9 +4,8 @@ import numpy as np from scipy.stats import gmean from collections import OrderedDict, defaultdict -from typing import Tuple from graph_net.config.datatype_tolerance_config import get_precision -from graph_net import macro_statistics +from graph_net import samples_statistics def extract_speedup_data_from_subdirs(benchmark_path: str) -> dict: @@ -416,7 +415,7 @@ def get_correctness(dtype: str, t: int, correctness_data: dict, index: int) -> b return False -def check_sample_correctness(sample: dict, t_key: int) -> Tuple[bool, str]: +def check_sample_correctness(sample: dict, t_key: int) -> tuple[bool, str]: """ Check if a sample is correct at the given tolerance level. @@ -555,9 +554,9 @@ def calculate_s_scores( """ s_scores = OrderedDict() s_scores_fake_degrad = OrderedDict() - # Store macro-level calculation results for cross-validation - s_scores._macro_results = OrderedDict() - s_scores_fake_degrad._macro_results = OrderedDict() + # Store aggregated-level calculation results for cross-validation + s_scores._aggregated_results = OrderedDict() + s_scores_fake_degrad._aggregated_results = OrderedDict() begin = -10 end = 4 @@ -569,40 +568,36 @@ def calculate_s_scores( def print_stat_info( t_key, correct_count, - acc_failure_count, + error_type_counts, pi, correct_negative_speedup_count, correct_speedups, - slowdown_speedups, ): """ - Calculate and print macro statistics for a given tolerance level. + Calculate and print aggregated statistics for a given tolerance level. - Uses the macro_statistics module for all parameter calculations. + Uses the samples_statistics module for all parameter calculations. 
""" print(f" - Details for tolerance={t_key}:") if total_samples > 0: - # Calculate all macro parameters using the dedicated module - macro_params = macro_statistics.calculate_all_macro_parameters( - correct_count=correct_count, + # Calculate all aggregated parameters using the dedicated module + aggregated_params = samples_statistics.calculate_all_aggregated_parameters( total_samples=total_samples, - correct_negative_speedup_count=correct_negative_speedup_count, correct_speedups=correct_speedups, - slowdown_speedups=slowdown_speedups, - acc_failure_count=acc_failure_count, + error_type_counts=error_type_counts, t_key=t_key, negative_speedup_penalty=negative_speedup_penalty, fpdb=fpdb, pi=pi, ) - alpha = macro_params["alpha"] - beta = macro_params["beta"] - lambda_ = macro_params["lambda"] - eta = macro_params["eta"] - gamma = macro_params["gamma"] - expected_s = macro_params["s_t"] - expected_es = macro_params["es_t"] + alpha = aggregated_params["alpha"] + beta = aggregated_params["beta"] + lambda_ = aggregated_params["lambda"] + eta = aggregated_params["eta"] + gamma = aggregated_params["gamma"] + expected_s = aggregated_params["s_t"] + expected_es = aggregated_params["es_t"] print( f" - alpha: {alpha:.3f} (Geometric mean speedup of correct samples)" @@ -631,16 +626,15 @@ def print_stat_info( final_correct_count = 0 final_correct_negative_speedup_count = 0 final_correct_speedups = [] - final_slowdown_speedups = [] + final_error_type_counts = {} # Store error type counts at t=1 for t_key in t_keys: rectified_speedups = [] rectified_speedups_fake_degrad = [] correct_count = 0 - acc_failure_count = 0 + error_type_counts = {} # Dictionary to count errors by type correct_negative_speedup_count = 0 correct_speedups = [] - slowdown_speedups = [] # Process all samples using helper functions to reduce nesting for idx, sample in enumerate(samples): @@ -657,10 +651,10 @@ def print_stat_info( correct_speedups.append(speedup) if speedup is not None and speedup < 1: correct_negative_speedup_count += 1 - slowdown_speedups.append(speedup) - if fail_type == "accuracy": - acc_failure_count += 1 + # Count errors by type + if fail_type is not None: + error_type_counts[fail_type] = error_type_counts.get(fail_type, 0) + 1 # Store state at t=1 for ES(t) calculation if t_key == 1: @@ -688,13 +682,13 @@ def print_stat_info( if t_key == 1: # Calculate pi at t=1 using the dedicated function - pi = macro_statistics.calculate_pi( - acc_failure_count, total_samples, correct_count + pi = samples_statistics.calculate_pi( + error_type_counts, total_samples, correct_speedups ) final_correct_count = correct_count final_correct_negative_speedup_count = correct_negative_speedup_count final_correct_speedups = correct_speedups - final_slowdown_speedups = slowdown_speedups + final_error_type_counts = error_type_counts.copy() # Save for t >= 1 if rectified_speedups: s_scores[t_key] = gmean(rectified_speedups) @@ -706,28 +700,27 @@ def print_stat_info( expected_s, expected_es = print_stat_info( t_key, correct_count, - acc_failure_count, + error_type_counts, pi, correct_negative_speedup_count, correct_speedups, - slowdown_speedups, ) else: + # For t >= 1, use error_type_counts from t=1 (frozen state) expected_s, expected_es = print_stat_info( t_key, final_correct_count, - acc_failure_count, + final_error_type_counts, # Use the frozen error_type_counts from t=1 pi, final_correct_negative_speedup_count, final_correct_speedups, - final_slowdown_speedups, ) print( - f" - S(t)={expected_s:.3f}, ES(t)={expected_es:.3f} for 
tolerance={t_key} from macro level." + f" - S(t)={expected_s:.3f}, ES(t)={expected_es:.3f} for tolerance={t_key} from aggregated level." ) - # Store macro results for cross-validation - s_scores._macro_results[t_key] = expected_s - s_scores_fake_degrad._macro_results[t_key] = expected_es + # Store aggregated results for cross-validation + s_scores._aggregated_results[t_key] = expected_s + s_scores_fake_degrad._aggregated_results[t_key] = expected_es print(f" - pi: {list(pi)}") diff --git a/graph_net/macro_statistics.py b/graph_net/macro_statistics.py deleted file mode 100644 index bfb2b55d6..000000000 --- a/graph_net/macro_statistics.py +++ /dev/null @@ -1,257 +0,0 @@ -""" -Macro-level statistics calculation module for S(t) and ES(t) metrics. - -This module provides independent functions for calculating each macro parameter -(alpha, beta, gamma, lambda, eta, pi) according to Appendix B and C of the paper. -""" - -from scipy.stats import gmean -from typing import List, Tuple - - -def calculate_alpha(correct_speedups: List[float]) -> float: - """ - Calculate alpha: geometric mean of correct sample speedups. - - According to Appendix B: alpha is the geometric mean of all correct sample speedups. - - Args: - correct_speedups: List of speedup values for correct samples - - Returns: - Alpha value (geometric mean), or 1.0 if list is empty - """ - return gmean(correct_speedups) if correct_speedups else 1.0 - - -def calculate_beta(slowdown_speedups: List[float]) -> float: - """ - Calculate beta: geometric mean of slowdown sample speedups. - - According to Appendix B: beta is the geometric mean of speedups for samples - that are correct but have speedup < 1 (slowdown cases). - - Args: - slowdown_speedups: List of speedup values for slowdown cases (speedup < 1) - - Returns: - Beta value (geometric mean), or 1.0 if list is empty - """ - return gmean(slowdown_speedups) if slowdown_speedups else 1.0 - - -def calculate_lambda(correct_count: int, total_samples: int) -> float: - """ - Calculate lambda: fraction of correct samples. - - According to Appendix B: lambda = M / N, where M is correct count and N is total samples. - - Args: - correct_count: Number of correct samples - total_samples: Total number of samples - - Returns: - Lambda value (fraction of correct samples), or 0.0 if total_samples is 0 - """ - return correct_count / total_samples if total_samples > 0 else 0.0 - - -def calculate_eta(correct_negative_speedup_count: int, correct_count: int) -> float: - """ - Calculate eta: fraction of slowdown cases within correct samples. - - According to Appendix B: eta = K / M, where K is the number of slowdown cases - within correct samples, and M is the total correct count. - - Args: - correct_negative_speedup_count: Number of correct samples with speedup < 1 - correct_count: Total number of correct samples - - Returns: - Eta value (fraction of slowdown cases), or 0.0 if correct_count is 0 - """ - return correct_negative_speedup_count / correct_count if correct_count > 0 else 0.0 - - -def calculate_pi( - acc_failure_count: int, total_samples: int, correct_count: int -) -> Tuple[float, float]: - """ - Calculate pi: error type proportions for t > 0. - - According to Appendix C: pi[0] is the proportion of accuracy errors (c=1), - pi[1] is the proportion of other errors (c=2,3) among all error samples. 
- - Args: - acc_failure_count: Number of accuracy failure samples - total_samples: Total number of samples - correct_count: Number of correct samples - - Returns: - Tuple of (pi[0], pi[1]) where: - - pi[0]: proportion of accuracy errors among error samples - - pi[1]: proportion of other errors among error samples (1 - pi[0]) - """ - error_count = total_samples - correct_count - if error_count == 0: - return 0.0, 0.0 - - pi_0 = acc_failure_count / error_count - pi_1 = 1.0 - pi_0 - return pi_0, pi_1 - - -def calculate_gamma(t_key: int, pi: Tuple[float, float], fpdb: float = 0.1) -> float: - """ - Calculate gamma_t: average error penalty factor. - - According to Appendix C: gamma_t = b^(sum(π_c * indicator(t < c))) - where: - - pi[0]: proportion of accuracy errors (c=1, tolerated when t >= 1) - - pi[1]: proportion of other errors (c=2,3, tolerated when t >= 3) - - indicator[0] = 1 if t < 1 else 0 (accuracy errors not tolerated) - - indicator[1] = 1 if t < 3 else 0 (runtime/compile errors not tolerated) - - Args: - t_key: Tolerance level - pi: Tuple of (pi[0], pi[1]) error type proportions - fpdb: Base penalty for severe errors (default: 0.1) - - Returns: - Gamma value (average error penalty) - """ - if t_key < 1: - return fpdb - - # indicator[0] = 1 if t < 1 else 0 (accuracy errors not tolerated) - # indicator[1] = 1 if t < 3 else 0 (runtime/compile errors not tolerated) - indicator = [1 if t_key < 1 else 0, 1 if t_key < 3 else 0] - pi_sum = sum(pi[i] * indicator[i] for i in range(len(pi))) - return fpdb**pi_sum - - -def calculate_s_t_macro( - alpha: float, - beta: float, - lambda_: float, - eta: float, - negative_speedup_penalty: float, - fpdb: float, -) -> float: - """ - Calculate S(t) from macro parameters. - - According to Appendix B: S_t = α^λ · β^(ληp) · b^(1-λ) - - Args: - alpha: Geometric mean speedup of correct samples - beta: Geometric mean speedup of slowdown cases - lambda_: Fraction of correct samples - eta: Fraction of slowdown cases within correct samples - negative_speedup_penalty: Penalty power p for negative speedup - fpdb: Base penalty b for severe errors - - Returns: - S(t) value calculated from macro parameters - """ - return ( - alpha**lambda_ - * beta ** (lambda_ * eta * negative_speedup_penalty) - * fpdb ** (1 - lambda_) - ) - - -def calculate_es_t_macro( - alpha: float, - beta: float, - lambda_: float, - eta: float, - gamma: float, - negative_speedup_penalty: float, -) -> float: - """ - Calculate ES(t) from macro parameters. - - According to Appendix C: ES_t = α^λ · β^(ληp) · γ_t^(1-λ) - - Args: - alpha: Geometric mean speedup of correct samples - beta: Geometric mean speedup of slowdown cases - lambda_: Fraction of correct samples - eta: Fraction of slowdown cases within correct samples - gamma: Average error penalty factor - negative_speedup_penalty: Penalty power p for negative speedup - - Returns: - ES(t) value calculated from macro parameters - """ - return ( - alpha**lambda_ - * beta ** (lambda_ * eta * negative_speedup_penalty) - * gamma ** (1 - lambda_) - ) - - -def calculate_all_macro_parameters( - correct_count: int, - total_samples: int, - correct_negative_speedup_count: int, - correct_speedups: List[float], - slowdown_speedups: List[float], - acc_failure_count: int, - t_key: int, - negative_speedup_penalty: float = 0.0, - fpdb: float = 0.1, - pi: Tuple[float, float] = (0.0, 0.0), -) -> dict: - """ - Calculate all macro parameters for a given tolerance level. - - This is a convenience function that calculates all macro parameters at once. 
- - Args: - correct_count: Number of correct samples - total_samples: Total number of samples - correct_negative_speedup_count: Number of slowdown cases - correct_speedups: List of speedup values for correct samples - slowdown_speedups: List of speedup values for slowdown cases - acc_failure_count: Number of accuracy failure samples - t_key: Tolerance level - negative_speedup_penalty: Penalty power p for negative speedup - fpdb: Base penalty b for severe errors - pi: Tuple of (pi[0], pi[1]) error type proportions (calculated at t=1) - - Returns: - Dictionary containing all macro parameters and calculated scores: - { - 'alpha': float, - 'beta': float, - 'lambda': float, - 'eta': float, - 'gamma': float, - 'pi': Tuple[float, float], - 's_t': float, - 'es_t': float - } - """ - alpha = calculate_alpha(correct_speedups) - beta = calculate_beta(slowdown_speedups) - lambda_ = calculate_lambda(correct_count, total_samples) - eta = calculate_eta(correct_negative_speedup_count, correct_count) - gamma = calculate_gamma(t_key, pi, fpdb) - - s_t = calculate_s_t_macro(alpha, beta, lambda_, eta, negative_speedup_penalty, fpdb) - es_t = calculate_es_t_macro( - alpha, beta, lambda_, eta, gamma, negative_speedup_penalty - ) - - return { - "alpha": alpha, - "beta": beta, - "lambda": lambda_, - "eta": eta, - "gamma": gamma, - "pi": pi, - "s_t": s_t, - "es_t": es_t, - } diff --git a/graph_net/plot_ESt.py b/graph_net/plot_ESt.py index 5281fc555..d4c8398a9 100644 --- a/graph_net/plot_ESt.py +++ b/graph_net/plot_ESt.py @@ -5,6 +5,108 @@ from graph_net import analysis_util +def compare_single_tolerance_level( + tolerance_level: int, + micro_es: float, + aggregated_es: float | None, + tolerance_threshold: float, +) -> tuple[bool, float, float]: + """ + Compare micro and aggregated ES(t) values for a single tolerance level. + + Args: + tolerance_level: Tolerance level t + micro_es: ES(t) value from micro-level calculation + aggregated_es: ES(t) value from aggregated-level calculation, or None if missing + tolerance_threshold: Floating point comparison tolerance + + Returns: + Tuple of (is_matched, diff, relative_diff) + """ + if aggregated_es is None: + return False, 0.0, 0.0 + + diff = abs(micro_es - aggregated_es) + relative_diff = diff / max(abs(micro_es), abs(aggregated_es), 1e-10) + is_matched = diff < tolerance_threshold or relative_diff < tolerance_threshold + + return is_matched, diff, relative_diff + + +def print_verification_result( + tolerance_level: int, + micro_es: float, + aggregated_es: float | None, + diff: float, + relative_diff: float, + is_matched: bool, +) -> None: + """Print verification result for a single tolerance level.""" + if aggregated_es is None: + print(f"ERROR: No aggregated result for t={tolerance_level}, cannot verify") + elif is_matched: + print( + f"t={tolerance_level:3d}: MATCHED - Micro: {micro_es:.6f}, Aggregated: {aggregated_es:.6f}, Diff: {diff:.2e}" + ) + else: + print( + f"t={tolerance_level:3d}: MISMATCH - Micro: {micro_es:.6f}, Aggregated: {aggregated_es:.6f}, Diff: {diff:.2e} ({relative_diff*100:.4f}%)" + ) + + +def verify_aggregated_micro_consistency( + es_scores: dict, folder_name: str, tolerance_threshold: float +) -> tuple[dict, bool]: + """ + Verify consistency between aggregated and micro-level ES(t) calculations. 
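+
+    A tolerance level is adopted only when its micro-level ES(t) agrees with
+    the aggregated-level ES(t) within tolerance_threshold, in either absolute
+    or relative terms.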
+
+    Args:
+        es_scores: Dictionary of ES(t) scores from micro-level calculation
+        folder_name: Name of the folder being verified
+        tolerance_threshold: Floating point comparison tolerance
+
+    Returns:
+        Tuple of (verified_scores, all_matched):
+        - verified_scores: Dictionary of verified scores (only matched tolerance levels)
+        - all_matched: True if all tolerance levels matched, False otherwise
+    """
+    aggregated_results = getattr(es_scores, "_aggregated_results", {})
+    verified_scores = {}
+    all_matched = True
+
+    print(f"\n{'='*80}")
+    print(f"Verifying Aggregated/Micro Consistency for '{folder_name}'")
+    print(f"{'='*80}")
+
+    for tolerance_level, micro_es in es_scores.items():
+        aggregated_es = aggregated_results.get(tolerance_level)
+        is_matched, diff, relative_diff = compare_single_tolerance_level(
+            tolerance_level, micro_es, aggregated_es, tolerance_threshold
+        )
+
+        print_verification_result(
+            tolerance_level, micro_es, aggregated_es, diff, relative_diff, is_matched
+        )
+
+        if aggregated_es is None or not is_matched:
+            all_matched = False
+        if is_matched:
+            verified_scores[tolerance_level] = micro_es
+
+    if not all_matched:
+        print(
+            f"\nERROR: Aggregated and micro results do not match for '{folder_name}'!"
+        )
+        print("Calculation validation failed. Results will NOT be used for plotting.")
+        print("Please verify the calculation logic using verify_aggregated_params.py")
+        print(f"{'='*80}\n")
+    else:
+        print(f"\nSUCCESS: All aggregated and micro results match for '{folder_name}'.")
+        print(f"{'='*80}\n")
+
+    return verified_scores, all_matched
+
+
 def plot_ES_results(s_scores: dict, cli_args: argparse.Namespace):
     """
     Plot ES(t) curve
@@ -138,7 +240,7 @@ def main():
         print("No valid data found. Exiting.")
         return
 
-    # 2. Calculate scores for each curve and verify macro/micro consistency
+    # 2. Calculate scores for each curve and verify aggregated/micro consistency
     all_es_scores = {}
     tolerance_threshold = 1e-6  # Tolerance for floating point comparison
 
@@ -150,53 +252,14 @@ def main():
             fpdb=args.fpdb,
         )
 
-        # Verify macro/micro consistency
-        macro_results = getattr(es_scores, "_macro_results", {})
-        verified_scores = {}
-        all_matched = True
-
-        print(f"\n{'='*80}")
-        print(f"Verifying Macro/Micro Consistency for '{folder_name}'")
-        print(f"{'='*80}")
-
-        for t_key, micro_es in es_scores.items():
-            macro_es = macro_results.get(t_key)
-
-            if macro_es is None:
-                print(f"ERROR: No macro result for t={t_key}, cannot verify")
-                all_matched = False
-                continue
-
-            # Compare macro and micro results
-            diff = abs(micro_es - macro_es)
-            relative_diff = diff / max(abs(micro_es), abs(macro_es), 1e-10)
-
-            if diff < tolerance_threshold or relative_diff < tolerance_threshold:
-                # Results match, use micro result
-                verified_scores[t_key] = micro_es
-                print(
-                    f"t={t_key:3d}: MATCHED - Micro: {micro_es:.6f}, Macro: {macro_es:.6f}, Diff: {diff:.2e}"
-                )
-            else:
-                # Results don't match - mark as failed
-                all_matched = False
-                print(
-                    f"t={t_key:3d}: MISMATCH - Micro: {micro_es:.6f}, Macro: {macro_es:.6f}, Diff: {diff:.2e} ({relative_diff*100:.4f}%)"
-                )
-                # Don't add to verified_scores if mismatch
+        # Verify aggregated/micro consistency; only verified scores are adopted
+        verified_scores, all_matched = verify_aggregated_micro_consistency(
+            es_scores, folder_name, tolerance_threshold
+        )
 
         if not all_matched:
-            print(f"\nERROR: Macro and micro results do not match for '{folder_name}'!")
-            print(
-                "Calculation validation failed. 
Results will NOT be used for plotting." - ) - print("Please verify the calculation logic using verify_macro_params.py") - print(f"{'='*80}\n") continue # Skip this curve if validation fails - print(f"\nSUCCESS: All macro and micro results match for '{folder_name}'.") - print(f"{'='*80}\n") - all_es_scores[folder_name] = verified_scores # 3. Plot the results diff --git a/graph_net/samples_statistics.py b/graph_net/samples_statistics.py new file mode 100644 index 000000000..b8ab3a840 --- /dev/null +++ b/graph_net/samples_statistics.py @@ -0,0 +1,297 @@ +""" +Aggregated statistics calculation module for S(t) and ES(t) metrics. + +This module provides independent functions for calculating each aggregated parameter +(alpha, beta, gamma, lambda, eta, pi) according to Appendix B and C of the paper. +""" + +from scipy.stats import gmean +from collections.abc import Callable + + +def calculate_alpha(correct_speedups: list[float]) -> float: + """ + Calculate alpha: geometric mean of correct sample speedups. + + According to Appendix B: alpha is the geometric mean of all correct sample speedups. + + Args: + correct_speedups: List of speedup values for correct samples + + Returns: + Alpha value (geometric mean), or 1.0 if list is empty + """ + return gmean(correct_speedups) if len(correct_speedups) > 0 else 1.0 + + +def calculate_beta(correct_speedups: list[float]) -> float: + """ + Calculate beta: geometric mean of slowdown sample speedups. + + According to Appendix B: beta is the geometric mean of speedups for samples + that are correct but have speedup < 1 (slowdown cases). + + Args: + correct_speedups: List of speedup values for correct samples + + Returns: + Beta value (geometric mean of slowdown cases), or 1.0 if no slowdown cases + """ + slowdown_speedups = list(filter(lambda x: x < 1, correct_speedups)) + return gmean(slowdown_speedups) if len(slowdown_speedups) > 0 else 1.0 + + +def calculate_lambda(correct_speedups: list[float], total_samples: int) -> float: + """ + Calculate lambda: fraction of correct samples. + + According to Appendix B: lambda = M / N, where M is correct count and N is total samples. + + Args: + correct_speedups: List of speedup values for correct samples + total_samples: Total number of samples + + Returns: + Lambda value (fraction of correct samples), or 1.0 if total_samples is 0 (lenient handling) + """ + correct_count = len(correct_speedups) + return correct_count / total_samples if total_samples > 0 else 1.0 + + +def calculate_eta(correct_speedups: list[float]) -> float: + """ + Calculate eta: fraction of slowdown cases within correct samples. + + According to Appendix B: eta = K / M, where K is the number of slowdown cases + within correct samples, and M is the total correct count. + + Args: + correct_speedups: List of speedup values for correct samples + + Returns: + Eta value (fraction of slowdown cases), or 0.0 if no correct samples + """ + correct_count = len(correct_speedups) + if correct_count == 0: + return 0.0 + slowdown_speedups = list(filter(lambda x: x < 1, correct_speedups)) + correct_negative_speedup_count = len(slowdown_speedups) + return correct_negative_speedup_count / correct_count + + +def calculate_pi( + error_type_counts: dict[str, int], total_samples: int, correct_speedups: list[float] +) -> dict[str, float]: + """ + Calculate pi: error type proportions for t > 0. + + According to Appendix C: pi_c is the proportion of error type c among all error samples. 
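+
+    Illustrative example (added for documentation, not from the original
+    tests): 6 correct samples leave 4 error samples:
+        >>> calculate_pi({"accuracy": 3, "compile": 1}, 10, [1.2] * 6)
+        {'accuracy': 0.75, 'compile': 0.25}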
+ + Args: + error_type_counts: Dictionary mapping error type names to their counts + total_samples: Total number of samples + correct_speedups: List of speedup values for correct samples + + Returns: + Dictionary mapping error type names to their proportions among error samples. + If error_count is 0, returns a dictionary with all proportions set to 0.0. + """ + correct_count = len(correct_speedups) + error_count = total_samples - correct_count + if error_count == 0: + return {error_type: 0.0 for error_type in error_type_counts.keys()} + + pi = {} + for error_type, count in error_type_counts.items(): + pi[error_type] = count / error_count + return pi + + +def calculate_gamma( + tolerance: int, + get_pi: Callable[[int], float], + errno_tolerances: list[int], + b: float = 0.1, +) -> float: + """ + Calculate gamma_t: average error penalty factor. + + According to Appendix C: gamma_t = b^(sum(π_c * indicator(t < threshold_c))) + where indicator(t < threshold_c) = 1 if error type c is not tolerated at tolerance t, else 0. + + Args: + tolerance: Tolerance level t + get_pi: Function that takes error type index c and returns π_c (proportion of error type c) + errno_tolerances: List of tolerance thresholds for each error type. + Index corresponds to error type index c, value is the threshold. + An error type is tolerated (not penalized) when t >= threshold. + b: Base penalty for severe errors (default: 0.1) + + Returns: + Gamma value (average error penalty) + """ + if len(errno_tolerances) == 0: + return b + + # Calculate indicator for each error type: 1 if not tolerated, 0 if tolerated + pi_sum = 0.0 + for error_type_index in range(len(errno_tolerances)): + pi_c = get_pi(error_type_index) + threshold_c = errno_tolerances[error_type_index] + # Error type is not tolerated (penalized) when t < threshold + indicator = 1 if tolerance < threshold_c else 0 + pi_sum += pi_c * indicator + + return b**pi_sum + + +def calculate_s_t_from_aggregated( + alpha: float, + beta: float, + lambda_: float, + eta: float, + negative_speedup_penalty: float, + fpdb: float, +) -> float: + """ + Calculate S(t) from aggregated parameters. + + According to Appendix B: S_t = α^λ · β^(ληp) · b^(1-λ) + + Args: + alpha: Geometric mean speedup of correct samples + beta: Geometric mean speedup of slowdown cases + lambda_: Fraction of correct samples + eta: Fraction of slowdown cases within correct samples + negative_speedup_penalty: Penalty power p for negative speedup + fpdb: Base penalty b for severe errors + + Returns: + S(t) value calculated from aggregated parameters + """ + return ( + alpha**lambda_ + * beta ** (lambda_ * eta * negative_speedup_penalty) + * fpdb ** (1 - lambda_) + ) + + +def calculate_es_t_from_aggregated( + alpha: float, + beta: float, + lambda_: float, + eta: float, + gamma: float, + negative_speedup_penalty: float, +) -> float: + """ + Calculate ES(t) from aggregated parameters. 
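+
+    This mirrors calculate_s_t_from_aggregated, with the fixed penalty b
+    replaced by the tolerance-dependent gamma_t.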
+ + According to Appendix C: ES_t = α^λ · β^(ληp) · γ_t^(1-λ) + + Args: + alpha: Geometric mean speedup of correct samples + beta: Geometric mean speedup of slowdown cases + lambda_: Fraction of correct samples + eta: Fraction of slowdown cases within correct samples + gamma: Average error penalty factor + negative_speedup_penalty: Penalty power p for negative speedup + + Returns: + ES(t) value calculated from aggregated parameters + """ + return ( + alpha**lambda_ + * beta ** (lambda_ * eta * negative_speedup_penalty) + * gamma ** (1 - lambda_) + ) + + +def calculate_all_aggregated_parameters( + total_samples: int, + correct_speedups: list[float], + error_type_counts: dict[str, int], + t_key: int, + negative_speedup_penalty: float = 0.0, + fpdb: float = 0.1, + pi: dict[str, float] | None = None, + error_tolerance_thresholds: dict[str, int] | None = None, +) -> dict: + """ + Calculate all aggregated parameters for a given tolerance level. + + This is a convenience function that calculates all aggregated parameters at once. + + Args: + total_samples: Total number of samples + correct_speedups: List of speedup values for correct samples + error_type_counts: Dictionary mapping error type names to their counts + t_key: Tolerance level + negative_speedup_penalty: Penalty power p for negative speedup + fpdb: Base penalty b for severe errors + pi: Dictionary mapping error type names to their proportions (calculated at t=1). + If None, will be calculated from error_type_counts. + error_tolerance_thresholds: Dictionary mapping error type names to their tolerance thresholds. + An error type is tolerated (not penalized) when t >= threshold. + If None, uses default thresholds: {"accuracy": 1} for accuracy errors, 3 for others. + + Returns: + Dictionary containing all aggregated parameters and calculated scores: + { + 'alpha': float, + 'beta': float, + 'lambda': float, + 'eta': float, + 'gamma': float, + 'pi': dict[str, float], + 's_t': float, + 'es_t': float + } + """ + # Use default error tolerance thresholds if not provided + if error_tolerance_thresholds is None: + error_tolerance_thresholds = {} + for error_type in error_type_counts.keys(): + if error_type == "accuracy": + error_tolerance_thresholds[error_type] = 1 + else: + error_tolerance_thresholds[error_type] = 3 + + # Calculate pi if not provided + if pi is None: + pi = calculate_pi(error_type_counts, total_samples, correct_speedups) + + # Convert dictionary-based pi and thresholds to indexed format for calculate_gamma + # Create ordered list of error types for consistent indexing + error_types = sorted(error_type_counts.keys()) + errno_tolerances = [error_tolerance_thresholds.get(error_type, 3) for error_type in error_types] + + # Create get_pi function that maps error type index to pi value + def get_pi(error_type_index: int) -> float: + if error_type_index < len(error_types): + error_type = error_types[error_type_index] + return pi.get(error_type, 0.0) + return 0.0 + + alpha = calculate_alpha(correct_speedups) + beta = calculate_beta(correct_speedups) + lambda_ = calculate_lambda(correct_speedups, total_samples) + eta = calculate_eta(correct_speedups) + gamma = calculate_gamma(t_key, get_pi, errno_tolerances, fpdb) + + s_t = calculate_s_t_from_aggregated(alpha, beta, lambda_, eta, negative_speedup_penalty, fpdb) + es_t = calculate_es_t_from_aggregated( + alpha, beta, lambda_, eta, gamma, negative_speedup_penalty + ) + + return { + "alpha": alpha, + "beta": beta, + "lambda": lambda_, + "eta": eta, + "gamma": gamma, + "pi": pi, + "s_t": s_t, 
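+        # s_t and es_t are the aggregated-side scores that plot_ESt later
+        # cross-checks against the micro-level geometric means.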
+ "es_t": es_t, + } + diff --git a/graph_net/verify_macro_params.py b/graph_net/verify_aggregated_params.py similarity index 62% rename from graph_net/verify_macro_params.py rename to graph_net/verify_aggregated_params.py index 4b68923f2..6e8e3b406 100644 --- a/graph_net/verify_macro_params.py +++ b/graph_net/verify_aggregated_params.py @@ -1,26 +1,26 @@ import os import argparse import numpy as np -from collections import OrderedDict +from collections import OrderedDict, Counter from graph_net import analysis_util -from graph_net import macro_statistics +from graph_net import samples_statistics -def calculate_macro_parameters( +def calculate_aggregated_parameters( samples: list, folder_name: str, negative_speedup_penalty: float = 0, fpdb: float = 0.1, ) -> dict: """ - Calculate and print all macro parameters (alpha, beta, gamma, lambda, eta, pi) + Calculate and print all aggregated parameters (alpha, beta, gamma, lambda, eta, pi) for each tolerance level independently. - This function extracts the macro parameter calculation logic from calculate_s_scores - to verify the correctness of macro-level calculations. + This function extracts the aggregated parameter calculation logic from calculate_s_scores + to verify the correctness of aggregated-level calculations. Returns: - Dictionary mapping tolerance -> dict of macro parameters and calculated scores + Dictionary mapping tolerance -> dict of aggregated parameters and calculated scores """ begin = -10 end = 4 @@ -28,7 +28,7 @@ def calculate_macro_parameters( total_samples = len(samples) print(f"\n{'='*80}") - print(f"Verifying Macro Parameters for '{folder_name}'") + print(f"Verifying Aggregated Parameters for '{folder_name}'") print(f"{'='*80}") # pi is a tuple of constants for t > 0 for each group: (pi[0], pi[1]) @@ -45,55 +45,60 @@ def calculate_macro_parameters( final_correct_negative_speedup_count = 0 final_correct_speedups = [] final_slowdown_speedups = [] + final_error_type_counts = {} # Store error type counts at t=1 results = OrderedDict() for t_key in t_keys: - correct_count = 0 - acc_failure_count = 0 - correct_negative_speedup_count = 0 - correct_speedups = [] - slowdown_speedups = [] - - # Collect statistics for current tolerance using helper function - for idx, sample in enumerate(samples): - performance_data = sample.get("performance", {}) - speedup = performance_data.get("speedup", {}).get("e2e") - - # Check correctness using dedicated function - is_correct, fail_type = analysis_util.check_sample_correctness( - sample, t_key + # Extract sample data using map + sample_data = [ + ( + idx, + sample, + sample.get("performance", {}).get("speedup", {}).get("e2e"), + *analysis_util.check_sample_correctness(sample, t_key), ) + for idx, sample in enumerate(samples) + ] + + # Filter correct samples and extract speedups + correct_samples = [(idx, speedup) for idx, _, speedup, is_correct, _ in sample_data if is_correct] + correct_count = len(correct_samples) + correct_speedups = [speedup for _, speedup in correct_samples if speedup is not None] + slowdown_speedups = [speedup for speedup in correct_speedups if speedup < 1] + correct_negative_speedup_count = len(slowdown_speedups) + + # Count errors by type using Counter + error_type_counts = dict( + Counter( + fail_type + for _, _, _, _, fail_type in sample_data + if fail_type is not None + ) + ) - # Collect statistics - if is_correct: - correct_count += 1 - if speedup is not None: - correct_speedups.append(speedup) - if speedup is not None and speedup < 1: - 
correct_negative_speedup_count += 1 - slowdown_speedups.append(speedup) - - if fail_type == "accuracy": - acc_failure_count += 1 - - # Store state at t=1 - if t_key == 1: - is_correct_at_t1[idx] = is_correct - speedup_at_t1[idx] = speedup - fail_type_at_t1[idx] = fail_type if fail_type is not None else "CORRECT" + # Store state at t=1 using list comprehension + if t_key == 1: + t1_data = [ + (idx, speedup, is_correct, fail_type if fail_type is not None else "CORRECT") + for idx, _, speedup, is_correct, fail_type in sample_data + ] + is_correct_at_t1 = [is_correct for _, _, is_correct, _ in t1_data] + speedup_at_t1 = [speedup for _, speedup, _, _ in t1_data] + fail_type_at_t1 = [fail_type for _, _, _, fail_type in t1_data] # Calculate pi at t=1 using the dedicated function if t_key == 1: - pi = macro_statistics.calculate_pi( - acc_failure_count, total_samples, correct_count + pi = samples_statistics.calculate_pi( + error_type_counts, total_samples, correct_speedups ) final_correct_count = correct_count final_correct_negative_speedup_count = correct_negative_speedup_count final_correct_speedups = correct_speedups final_slowdown_speedups = slowdown_speedups + final_error_type_counts = error_type_counts.copy() # Save for t >= 1 - # Calculate macro parameters + # Calculate aggregated parameters if total_samples > 0: # For t < 1, use current tolerance statistics # For t >= 1, use t=1 statistics (frozen state) @@ -110,27 +115,30 @@ def calculate_macro_parameters( stats_correct_speedups = final_correct_speedups stats_slowdown_speedups = final_slowdown_speedups - # Calculate all macro parameters using the dedicated module - macro_params = macro_statistics.calculate_all_macro_parameters( - correct_count=stats_correct_count, + # Calculate all aggregated parameters using the dedicated module + # For t >= 1, use error_type_counts from t=1 (frozen state) + if t_key < 1: + stats_error_type_counts = error_type_counts + else: + stats_error_type_counts = final_error_type_counts # Use frozen from t=1 + + aggregated_params = samples_statistics.calculate_all_aggregated_parameters( total_samples=total_samples, - correct_negative_speedup_count=stats_correct_negative_speedup_count, correct_speedups=stats_correct_speedups, - slowdown_speedups=stats_slowdown_speedups, - acc_failure_count=acc_failure_count, + error_type_counts=stats_error_type_counts, t_key=t_key, negative_speedup_penalty=negative_speedup_penalty, fpdb=fpdb, pi=pi, ) - alpha = macro_params["alpha"] - beta = macro_params["beta"] - lambda_ = macro_params["lambda"] - eta = macro_params["eta"] - gamma = macro_params["gamma"] - expected_s = macro_params["s_t"] - expected_es = macro_params["es_t"] + alpha = aggregated_params["alpha"] + beta = aggregated_params["beta"] + lambda_ = aggregated_params["lambda"] + eta = aggregated_params["eta"] + gamma = aggregated_params["gamma"] + expected_s = aggregated_params["s_t"] + expected_es = aggregated_params["es_t"] results[t_key] = { "alpha": alpha, @@ -172,8 +180,8 @@ def calculate_macro_parameters( print( f" - gamma = fpdb^(sum(pi[i] * indicator[i])) = {fpdb}^{pi_indicator_sum:.6f} = {gamma:.6f}" ) - print(f" Expected S(t) from macro: {expected_s:.6f}") - print(f" Expected ES(t) from macro: {expected_es:.6f}") + print(f" Expected S(t) from aggregated: {expected_s:.6f}") + print(f" Expected ES(t) from aggregated: {expected_es:.6f}") else: results[t_key] = { "alpha": 1.0, @@ -192,16 +200,16 @@ def calculate_macro_parameters( print(f"\nTolerance t = {t_key}: No samples to analyze") print(f"\n{'='*80}") - 
print(f"Macro Parameter Verification Complete") + print(f"Aggregated Parameter Verification Complete") print(f"{'='*80}\n") return results def main(): - """Main execution function for verifying macro parameters.""" + """Main execution function for verifying aggregated parameters.""" parser = argparse.ArgumentParser( - description="Verify macro parameters (alpha, beta, gamma, lambda, eta, pi) calculation.", + description="Verify aggregated parameters (alpha, beta, gamma, lambda, eta, pi) calculation.", formatter_class=argparse.RawTextHelpFormatter, ) parser.add_argument( @@ -230,9 +238,9 @@ def main(): print("No valid data found. Exiting.") return - # Calculate and print macro parameters for each curve + # Calculate and print aggregated parameters for each curve for folder_name, samples in all_results.items(): - macro_results = calculate_macro_parameters( + aggregated_results = calculate_aggregated_parameters( samples, folder_name, negative_speedup_penalty=args.negative_speedup_penalty, @@ -242,3 +250,4 @@ def main(): if __name__ == "__main__": main() + From 498f60d22d1b7b6c4a70bf3783d48ed5f880846c Mon Sep 17 00:00:00 2001 From: Dayuxiaoshui <792179245@qq.com> Date: Sun, 16 Nov 2025 18:55:32 +0800 Subject: [PATCH 3/8] style: apply black formatting to samples_statistics.py and verify_aggregated_params.py --- graph_net/samples_statistics.py | 11 +++++++---- graph_net/verify_aggregated_params.py | 18 ++++++++++++++---- 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/graph_net/samples_statistics.py b/graph_net/samples_statistics.py index b8ab3a840..6e01443be 100644 --- a/graph_net/samples_statistics.py +++ b/graph_net/samples_statistics.py @@ -264,8 +264,10 @@ def calculate_all_aggregated_parameters( # Convert dictionary-based pi and thresholds to indexed format for calculate_gamma # Create ordered list of error types for consistent indexing error_types = sorted(error_type_counts.keys()) - errno_tolerances = [error_tolerance_thresholds.get(error_type, 3) for error_type in error_types] - + errno_tolerances = [ + error_tolerance_thresholds.get(error_type, 3) for error_type in error_types + ] + # Create get_pi function that maps error type index to pi value def get_pi(error_type_index: int) -> float: if error_type_index < len(error_types): @@ -279,7 +281,9 @@ def get_pi(error_type_index: int) -> float: eta = calculate_eta(correct_speedups) gamma = calculate_gamma(t_key, get_pi, errno_tolerances, fpdb) - s_t = calculate_s_t_from_aggregated(alpha, beta, lambda_, eta, negative_speedup_penalty, fpdb) + s_t = calculate_s_t_from_aggregated( + alpha, beta, lambda_, eta, negative_speedup_penalty, fpdb + ) es_t = calculate_es_t_from_aggregated( alpha, beta, lambda_, eta, gamma, negative_speedup_penalty ) @@ -294,4 +298,3 @@ def get_pi(error_type_index: int) -> float: "s_t": s_t, "es_t": es_t, } - diff --git a/graph_net/verify_aggregated_params.py b/graph_net/verify_aggregated_params.py index 6e8e3b406..49b7e3103 100644 --- a/graph_net/verify_aggregated_params.py +++ b/graph_net/verify_aggregated_params.py @@ -62,9 +62,15 @@ def calculate_aggregated_parameters( ] # Filter correct samples and extract speedups - correct_samples = [(idx, speedup) for idx, _, speedup, is_correct, _ in sample_data if is_correct] + correct_samples = [ + (idx, speedup) + for idx, _, speedup, is_correct, _ in sample_data + if is_correct + ] correct_count = len(correct_samples) - correct_speedups = [speedup for _, speedup in correct_samples if speedup is not None] + correct_speedups = [ + speedup for _, speedup in 
correct_samples if speedup is not None + ] slowdown_speedups = [speedup for speedup in correct_speedups if speedup < 1] correct_negative_speedup_count = len(slowdown_speedups) @@ -80,7 +86,12 @@ def calculate_aggregated_parameters( # Store state at t=1 using list comprehension if t_key == 1: t1_data = [ - (idx, speedup, is_correct, fail_type if fail_type is not None else "CORRECT") + ( + idx, + speedup, + is_correct, + fail_type if fail_type is not None else "CORRECT", + ) for idx, _, speedup, is_correct, fail_type in sample_data ] is_correct_at_t1 = [is_correct for _, _, is_correct, _ in t1_data] @@ -250,4 +261,3 @@ def main(): if __name__ == "__main__": main() - From 22339b34ac8c1a35218ecddf79fc3790d08b88c1 Mon Sep 17 00:00:00 2001 From: Dayuxiaoshui <792179245@qq.com> Date: Mon, 17 Nov 2025 09:59:51 +0800 Subject: [PATCH 4/8] refactor: unify error type to errno mapping for better sorting - Replace error_type_counts (dict[str, int]) with errno2count (dict[int, int]) - Add get_errno_from_error_type() to map error type strings to errno (1, 2, 3) - Add get_error_type_from_errno() for reverse mapping when error type strings are needed - Update calculate_pi() to use errno2count and return dict[int, float] - Update calculate_all_aggregated_parameters() to use errno2count and errno_tolerance_thresholds - Update analysis_util.py and verify_aggregated_params.py to use errno2count - Improve code maintainability by using integer errno for sorting and comparison --- graph_net/analysis_util.py | 26 +++--- graph_net/samples_statistics.py | 114 ++++++++++++++++++-------- graph_net/verify_aggregated_params.py | 36 +++++--- 3 files changed, 120 insertions(+), 56 deletions(-) diff --git a/graph_net/analysis_util.py b/graph_net/analysis_util.py index cb143aed5..6e3ca790d 100644 --- a/graph_net/analysis_util.py +++ b/graph_net/analysis_util.py @@ -6,6 +6,7 @@ from collections import OrderedDict, defaultdict from graph_net.config.datatype_tolerance_config import get_precision from graph_net import samples_statistics +from graph_net.samples_statistics import get_errno_from_error_type def extract_speedup_data_from_subdirs(benchmark_path: str) -> dict: @@ -568,7 +569,7 @@ def calculate_s_scores( def print_stat_info( t_key, correct_count, - error_type_counts, + errno2count, pi, correct_negative_speedup_count, correct_speedups, @@ -584,7 +585,7 @@ def print_stat_info( aggregated_params = samples_statistics.calculate_all_aggregated_parameters( total_samples=total_samples, correct_speedups=correct_speedups, - error_type_counts=error_type_counts, + errno2count=errno2count, t_key=t_key, negative_speedup_penalty=negative_speedup_penalty, fpdb=fpdb, @@ -626,13 +627,13 @@ def print_stat_info( final_correct_count = 0 final_correct_negative_speedup_count = 0 final_correct_speedups = [] - final_error_type_counts = {} # Store error type counts at t=1 + final_errno2count = {} # Store error type counts at t=1 (using errno) for t_key in t_keys: rectified_speedups = [] rectified_speedups_fake_degrad = [] correct_count = 0 - error_type_counts = {} # Dictionary to count errors by type + errno2count = {} # Dictionary to count errors by errno correct_negative_speedup_count = 0 correct_speedups = [] @@ -652,9 +653,10 @@ def print_stat_info( if speedup is not None and speedup < 1: correct_negative_speedup_count += 1 - # Count errors by type + # Count errors by errno (convert error type string to errno) if fail_type is not None: - error_type_counts[fail_type] = error_type_counts.get(fail_type, 0) + 1 + errno = 
get_errno_from_error_type(fail_type)
+                errno2count[errno] = errno2count.get(errno, 0) + 1
 
         # Store state at t=1 for ES(t) calculation
         if t_key == 1:
@@ -683,12 +685,12 @@ def print_stat_info(
         if t_key == 1:
             # Calculate pi at t=1 using the dedicated function
             pi = samples_statistics.calculate_pi(
-                error_type_counts, total_samples, correct_speedups
+                errno2count, total_samples, correct_speedups
             )
             final_correct_count = correct_count
             final_correct_negative_speedup_count = correct_negative_speedup_count
             final_correct_speedups = correct_speedups
-            final_error_type_counts = error_type_counts.copy()  # Save for t >= 1
+            final_errno2count = errno2count.copy()  # Save for t >= 1
 
         if rectified_speedups:
             s_scores[t_key] = gmean(rectified_speedups)
@@ -700,17 +702,17 @@ def print_stat_info(
             expected_s, expected_es = print_stat_info(
                 t_key,
                 correct_count,
-                error_type_counts,
+                errno2count,
                 pi,
                 correct_negative_speedup_count,
                 correct_speedups,
             )
         else:
-            # For t >= 1, use error_type_counts from t=1 (frozen state)
+            # For t >= 1, use errno2count from t=1 (frozen state)
             expected_s, expected_es = print_stat_info(
                 t_key,
                 final_correct_count,
-                final_error_type_counts,  # Use the frozen error_type_counts from t=1
+                final_errno2count,  # Use the frozen errno2count from t=1
                 pi,
                 final_correct_negative_speedup_count,
                 final_correct_speedups,
@@ -722,6 +724,6 @@ def print_stat_info(
         s_scores._aggregated_results[t_key] = expected_s
         s_scores_fake_degrad._aggregated_results[t_key] = expected_es
 
-    print(f"  - pi: {list(pi)}")
+    print(f"  - pi: {dict(sorted(pi.items()))}")
 
     return s_scores, s_scores_fake_degrad
diff --git a/graph_net/samples_statistics.py b/graph_net/samples_statistics.py
index 6e01443be..364d99ced 100644
--- a/graph_net/samples_statistics.py
+++ b/graph_net/samples_statistics.py
@@ -9,6 +9,56 @@
 from collections.abc import Callable
 
 
+def get_errno_from_error_type(error_type: str) -> int:
+    """
+    Map error type string to errno (error number) for sorting.
+
+    According to the paper:
+    - c=1: accuracy errors
+    - c=2: runtime crashes
+    - c=3: compilation failures
+
+    Args:
+        error_type: Error type string (e.g., "accuracy", "eager", "compiled")
+
+    Returns:
+        Errno (1, 2, or 3) based on error type
+    """
+    if error_type == "accuracy":
+        return 1
+    elif error_type in ("eager", "other", "runtime_fail", "eager_fail"):
+        return 2
+    elif error_type in ("compiled", "compile_fail"):
+        return 3
+    else:
+        # Default to 2 for unknown error types (runtime errors)
+        return 2
+
+
+def get_error_type_from_errno(errno: int) -> str:
+    """
+    Map errno (error number) back to error type string.
+
+    This is the reverse mapping of get_errno_from_error_type.
+    Used when error type string information is needed.
+
+    Args:
+        errno: Error number (1, 2, or 3)
+
+    Returns:
+        Error type string:
+        - 1 -> "accuracy"
+        - 2 -> "runtime_fail"
+        - 3 -> "compile_fail"
+    """
+    errno_to_error_type = {
+        1: "accuracy",
+        2: "runtime_fail",
+        3: "compile_fail",
+    }
+    return errno_to_error_type.get(errno, "runtime_fail")
+
+
 def calculate_alpha(correct_speedups: list[float]) -> float:
     """
     Calculate alpha: geometric mean of correct sample speedups.
@@ -80,30 +130,31 @@ def calculate_eta(correct_speedups: list[float]) -> float:
 
 
 def calculate_pi(
-    error_type_counts: dict[str, int], total_samples: int, correct_speedups: list[float]
-) -> dict[str, float]:
+    errno2count: dict[int, int], total_samples: int, correct_speedups: list[float]
+) -> dict[int, float]:
     """
     Calculate pi: error type proportions for t > 0.
According to Appendix C: pi_c is the proportion of error type c among all error samples. Args: - error_type_counts: Dictionary mapping error type names to their counts + errno2count: Dictionary mapping errno (error number) to their counts. + Errno values: 1=accuracy, 2=runtime, 3=compilation. total_samples: Total number of samples correct_speedups: List of speedup values for correct samples Returns: - Dictionary mapping error type names to their proportions among error samples. + Dictionary mapping errno to their proportions among error samples. If error_count is 0, returns a dictionary with all proportions set to 0.0. """ correct_count = len(correct_speedups) error_count = total_samples - correct_count if error_count == 0: - return {error_type: 0.0 for error_type in error_type_counts.keys()} + return {errno: 0.0 for errno in errno2count.keys()} pi = {} - for error_type, count in error_type_counts.items(): - pi[error_type] = count / error_count + for errno, count in errno2count.items(): + pi[errno] = count / error_count return pi @@ -210,12 +261,12 @@ def calculate_es_t_from_aggregated( def calculate_all_aggregated_parameters( total_samples: int, correct_speedups: list[float], - error_type_counts: dict[str, int], + errno2count: dict[int, int], t_key: int, negative_speedup_penalty: float = 0.0, fpdb: float = 0.1, - pi: dict[str, float] | None = None, - error_tolerance_thresholds: dict[str, int] | None = None, + pi: dict[int, float] | None = None, + errno_tolerance_thresholds: dict[int, int] | None = None, ) -> dict: """ Calculate all aggregated parameters for a given tolerance level. @@ -225,15 +276,16 @@ def calculate_all_aggregated_parameters( Args: total_samples: Total number of samples correct_speedups: List of speedup values for correct samples - error_type_counts: Dictionary mapping error type names to their counts + errno2count: Dictionary mapping errno (error number) to their counts. + Errno values: 1=accuracy, 2=runtime, 3=compilation. t_key: Tolerance level negative_speedup_penalty: Penalty power p for negative speedup fpdb: Base penalty b for severe errors - pi: Dictionary mapping error type names to their proportions (calculated at t=1). - If None, will be calculated from error_type_counts. - error_tolerance_thresholds: Dictionary mapping error type names to their tolerance thresholds. + pi: Dictionary mapping errno to their proportions (calculated at t=1). + If None, will be calculated from errno2count. + errno_tolerance_thresholds: Dictionary mapping errno to their tolerance thresholds. An error type is tolerated (not penalized) when t >= threshold. - If None, uses default thresholds: {"accuracy": 1} for accuracy errors, 3 for others. + If None, uses default thresholds: {1: 1} for accuracy errors (errno=1), {2: 3, 3: 3} for others. 
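+
+        Worked example (hypothetical distribution): with pi = {1: 0.5, 2: 0.25, 3: 0.25}
+        and the default thresholds, accuracy errors (errno=1) are already tolerated
+        at t=1, so gamma_1 = b^(0.25 + 0.25) = b^0.5; at t=3 every error type is
+        tolerated and gamma_3 = b^0 = 1.
+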
Returns: Dictionary containing all aggregated parameters and calculated scores: @@ -243,36 +295,34 @@ def calculate_all_aggregated_parameters( 'lambda': float, 'eta': float, 'gamma': float, - 'pi': dict[str, float], + 'pi': dict[int, float], 's_t': float, 'es_t': float } """ # Use default error tolerance thresholds if not provided - if error_tolerance_thresholds is None: - error_tolerance_thresholds = {} - for error_type in error_type_counts.keys(): - if error_type == "accuracy": - error_tolerance_thresholds[error_type] = 1 - else: - error_tolerance_thresholds[error_type] = 3 + if errno_tolerance_thresholds is None: + errno_tolerance_thresholds = {} + for errno in errno2count.keys(): + if errno == 1: # accuracy errors + errno_tolerance_thresholds[errno] = 1 + else: # runtime (2) or compilation (3) errors + errno_tolerance_thresholds[errno] = 3 # Calculate pi if not provided if pi is None: - pi = calculate_pi(error_type_counts, total_samples, correct_speedups) + pi = calculate_pi(errno2count, total_samples, correct_speedups) # Convert dictionary-based pi and thresholds to indexed format for calculate_gamma - # Create ordered list of error types for consistent indexing - error_types = sorted(error_type_counts.keys()) - errno_tolerances = [ - error_tolerance_thresholds.get(error_type, 3) for error_type in error_types - ] + # Create ordered list of errnos for consistent indexing (sorted by errno) + errnos = sorted(errno2count.keys()) + errno_tolerances = [errno_tolerance_thresholds.get(errno, 3) for errno in errnos] # Create get_pi function that maps error type index to pi value def get_pi(error_type_index: int) -> float: - if error_type_index < len(error_types): - error_type = error_types[error_type_index] - return pi.get(error_type, 0.0) + if error_type_index < len(errnos): + errno = errnos[error_type_index] + return pi.get(errno, 0.0) return 0.0 alpha = calculate_alpha(correct_speedups) diff --git a/graph_net/verify_aggregated_params.py b/graph_net/verify_aggregated_params.py index 49b7e3103..d7600f3f2 100644 --- a/graph_net/verify_aggregated_params.py +++ b/graph_net/verify_aggregated_params.py @@ -4,6 +4,10 @@ from collections import OrderedDict, Counter from graph_net import analysis_util from graph_net import samples_statistics +from graph_net.samples_statistics import ( + get_errno_from_error_type, + get_error_type_from_errno, +) def calculate_aggregated_parameters( @@ -45,7 +49,7 @@ def calculate_aggregated_parameters( final_correct_negative_speedup_count = 0 final_correct_speedups = [] final_slowdown_speedups = [] - final_error_type_counts = {} # Store error type counts at t=1 + final_errno2count = {} # Store error type counts at t=1 (using errno) results = OrderedDict() @@ -74,10 +78,10 @@ def calculate_aggregated_parameters( slowdown_speedups = [speedup for speedup in correct_speedups if speedup < 1] correct_negative_speedup_count = len(slowdown_speedups) - # Count errors by type using Counter - error_type_counts = dict( + # Count errors by errno using Counter (convert error type string to errno) + errno2count = dict( Counter( - fail_type + get_errno_from_error_type(fail_type) for _, _, _, _, fail_type in sample_data if fail_type is not None ) @@ -101,13 +105,13 @@ def calculate_aggregated_parameters( # Calculate pi at t=1 using the dedicated function if t_key == 1: pi = samples_statistics.calculate_pi( - error_type_counts, total_samples, correct_speedups + errno2count, total_samples, correct_speedups ) final_correct_count = correct_count final_correct_negative_speedup_count = 
correct_negative_speedup_count final_correct_speedups = correct_speedups final_slowdown_speedups = slowdown_speedups - final_error_type_counts = error_type_counts.copy() # Save for t >= 1 + final_errno2count = errno2count.copy() # Save for t >= 1 # Calculate aggregated parameters if total_samples > 0: @@ -127,16 +131,16 @@ def calculate_aggregated_parameters( stats_slowdown_speedups = final_slowdown_speedups # Calculate all aggregated parameters using the dedicated module - # For t >= 1, use error_type_counts from t=1 (frozen state) + # For t >= 1, use errno2count from t=1 (frozen state) if t_key < 1: - stats_error_type_counts = error_type_counts + stats_errno2count = errno2count else: - stats_error_type_counts = final_error_type_counts # Use frozen from t=1 + stats_errno2count = final_errno2count # Use frozen from t=1 aggregated_params = samples_statistics.calculate_all_aggregated_parameters( total_samples=total_samples, correct_speedups=stats_correct_speedups, - error_type_counts=stats_error_type_counts, + errno2count=stats_errno2count, t_key=t_key, negative_speedup_penalty=negative_speedup_penalty, fpdb=fpdb, @@ -184,9 +188,17 @@ def calculate_aggregated_parameters( ) print(f" gamma (average error penalty): {gamma:.6f}") if t_key >= 1: + # pi is now dict[int, float], convert to list for display + errnos = sorted(pi.keys()) + pi_list = [pi[errno] for errno in errnos] indicator = [1 if t_key < 1 else 0, 1 if t_key < 3 else 0] - pi_indicator_sum = sum(pi[i] * indicator[i] for i in range(len(pi))) - print(f" - pi: {list(pi)}") + # Calculate pi_indicator_sum using errno-based pi + pi_indicator_sum = sum( + pi.get(errno, 0.0) * indicator[min(i, len(indicator) - 1)] + for i, errno in enumerate(errnos) + ) + print(f" - pi (errno -> proportion): {dict(sorted(pi.items()))}") + print(f" - pi (as list): {pi_list}") print(f" - indicator: {indicator}") print( f" - gamma = fpdb^(sum(pi[i] * indicator[i])) = {fpdb}^{pi_indicator_sum:.6f} = {gamma:.6f}" From a4aa31fd685245002b3f0be4876d723df35ea37f Mon Sep 17 00:00:00 2001 From: Dayuxiaoshui <792179245@qq.com> Date: Mon, 17 Nov 2025 18:31:59 +0800 Subject: [PATCH 5/8] refactor: split tolerance report generation --- graph_net/analysis_util.py | 14 +- graph_net/samples_statistics.py | 127 +++---- graph_net/verify_aggregated_params.py | 470 +++++++++++++++----------- 3 files changed, 348 insertions(+), 263 deletions(-) diff --git a/graph_net/analysis_util.py b/graph_net/analysis_util.py index 6e3ca790d..3ac3a61f3 100644 --- a/graph_net/analysis_util.py +++ b/graph_net/analysis_util.py @@ -582,13 +582,13 @@ def print_stat_info( print(f" - Details for tolerance={t_key}:") if total_samples > 0: # Calculate all aggregated parameters using the dedicated module - aggregated_params = samples_statistics.calculate_all_aggregated_parameters( + aggregated_params = samples_statistics.calculate_es_components_values( total_samples=total_samples, correct_speedups=correct_speedups, errno2count=errno2count, - t_key=t_key, + tolerance=t_key, negative_speedup_penalty=negative_speedup_penalty, - fpdb=fpdb, + b=fpdb, pi=pi, ) @@ -597,8 +597,12 @@ def print_stat_info( lambda_ = aggregated_params["lambda"] eta = aggregated_params["eta"] gamma = aggregated_params["gamma"] - expected_s = aggregated_params["s_t"] - expected_es = aggregated_params["es_t"] + expected_s = samples_statistics.calculate_s_t_from_aggregated( + alpha, beta, lambda_, eta, negative_speedup_penalty, fpdb + ) + expected_es = samples_statistics.calculate_es_t_from_aggregated( + alpha, beta, lambda_, eta, gamma, 
negative_speedup_penalty + ) print( f" - alpha: {alpha:.3f} (Geometric mean speedup of correct samples)" diff --git a/graph_net/samples_statistics.py b/graph_net/samples_statistics.py index 364d99ced..d76ceb9ec 100644 --- a/graph_net/samples_statistics.py +++ b/graph_net/samples_statistics.py @@ -149,6 +149,10 @@ def calculate_pi( """ correct_count = len(correct_speedups) error_count = total_samples - correct_count + counted_errors = sum(errno2count.values()) + assert ( + error_count == counted_errors + ), f"error_count mismatch: got {error_count}, but errno2count sums to {counted_errors}" if error_count == 0: return {errno: 0.0 for errno in errno2count.keys()} @@ -158,10 +162,36 @@ def calculate_pi( return pi +def resolve_errno_tolerance( + errno2count: dict[int, int], custom_map: dict[int, int] | None +) -> dict[int, int]: + """ + Build a sorted errno -> tolerance map for downstream gamma calculation. + + Args: + errno2count: Observed errno occurrences in the dataset. + custom_map: Optional overrides mapping errno to its minimal tolerated tolerance. + + Returns: + Ordered dict (by errno) mapping each errno seen in errno2count + to the tolerance value where it becomes tolerated. Defaults to: + - errno 1 (accuracy) -> 1 + - errno >=2 (runtime/compile) -> 3 + """ + custom_map = custom_map or {} + + def tolerance_for(errno: int) -> int: + if errno in custom_map: + return custom_map[errno] + return 1 if errno == 1 else 3 + + return {errno: tolerance_for(errno) for errno in sorted(errno2count.keys())} + + def calculate_gamma( tolerance: int, - get_pi: Callable[[int], float], - errno_tolerances: list[int], + pi_value4errno: Callable[[int], float], + errno_as_tolerances: dict[int, int], b: float = 0.1, ) -> float: """ @@ -172,26 +202,24 @@ def calculate_gamma( Args: tolerance: Tolerance level t - get_pi: Function that takes error type index c and returns π_c (proportion of error type c) - errno_tolerances: List of tolerance thresholds for each error type. - Index corresponds to error type index c, value is the threshold. - An error type is tolerated (not penalized) when t >= threshold. + pi_value4errno: Function that takes errno and returns π_c (proportion of error type c). + errno_as_tolerances: Mapping of errno to tolerance thresholds. + An error type is tolerated (not penalized) when t >= threshold for that errno. b: Base penalty for severe errors (default: 0.1) Returns: Gamma value (average error penalty) """ - if len(errno_tolerances) == 0: + if tolerance <= 0: return b - # Calculate indicator for each error type: 1 if not tolerated, 0 if tolerated - pi_sum = 0.0 - for error_type_index in range(len(errno_tolerances)): - pi_c = get_pi(error_type_index) - threshold_c = errno_tolerances[error_type_index] - # Error type is not tolerated (penalized) when t < threshold - indicator = 1 if tolerance < threshold_c else 0 - pi_sum += pi_c * indicator + # Calculate indicator-weighted pi sum for errnos that are not tolerated + pi_sum = sum( + pi_value + for errno, errno_tolerance in errno_as_tolerances.items() + for pi_value in [pi_value4errno(errno)] + if tolerance < errno_tolerance + ) return b**pi_sum @@ -202,7 +230,7 @@ def calculate_s_t_from_aggregated( lambda_: float, eta: float, negative_speedup_penalty: float, - fpdb: float, + b: float, ) -> float: """ Calculate S(t) from aggregated parameters. 
@@ -215,7 +243,7 @@ def calculate_s_t_from_aggregated( lambda_: Fraction of correct samples eta: Fraction of slowdown cases within correct samples negative_speedup_penalty: Penalty power p for negative speedup - fpdb: Base penalty b for severe errors + b: Base penalty for severe errors or accuracy violation Returns: S(t) value calculated from aggregated parameters @@ -223,7 +251,7 @@ def calculate_s_t_from_aggregated( return ( alpha**lambda_ * beta ** (lambda_ * eta * negative_speedup_penalty) - * fpdb ** (1 - lambda_) + * b ** (1 - lambda_) ) @@ -258,85 +286,60 @@ def calculate_es_t_from_aggregated( ) -def calculate_all_aggregated_parameters( +def calculate_es_components_values( total_samples: int, correct_speedups: list[float], errno2count: dict[int, int], - t_key: int, + tolerance: int, negative_speedup_penalty: float = 0.0, - fpdb: float = 0.1, + b: float = 0.1, pi: dict[int, float] | None = None, - errno_tolerance_thresholds: dict[int, int] | None = None, + errno_as_tolerance: dict[int, int] | None = None, ) -> dict: """ - Calculate all aggregated parameters for a given tolerance level. - - This is a convenience function that calculates all aggregated parameters at once. + Calculate aggregated parameters for a given tolerance level. Args: total_samples: Total number of samples correct_speedups: List of speedup values for correct samples errno2count: Dictionary mapping errno (error number) to their counts. Errno values: 1=accuracy, 2=runtime, 3=compilation. - t_key: Tolerance level + tolerance: Tolerance level negative_speedup_penalty: Penalty power p for negative speedup - fpdb: Base penalty b for severe errors + b: Base penalty for severe errors or accuracy violation pi: Dictionary mapping errno to their proportions (calculated at t=1). If None, will be calculated from errno2count. - errno_tolerance_thresholds: Dictionary mapping errno to their tolerance thresholds. - An error type is tolerated (not penalized) when t >= threshold. - If None, uses default thresholds: {1: 1} for accuracy errors (errno=1), {2: 3, 3: 3} for others. + errno_as_tolerance: Mapping from errno to its minimum tolerated tolerance. + An error type is tolerated (not penalized) when tolerance >= its value. + If None, defaults to {1: 1} for accuracy, {2: 3, 3: 3} for others. 
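+
+        Illustrative use (all values made up): given alpha=1.2, beta=0.8,
+        lambda=0.9, eta=0.1, gamma=0.1**0.5 and p=1, a caller can assemble
+        ES_t = 1.2**0.9 * 0.8**(0.9*0.1) * (0.1**0.5)**0.1, roughly 1.03,
+        via calculate_es_t_from_aggregated.
+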
Returns: - Dictionary containing all aggregated parameters and calculated scores: + Dictionary containing ES(t) component values: { 'alpha': float, 'beta': float, 'lambda': float, 'eta': float, 'gamma': float, - 'pi': dict[int, float], - 's_t': float, - 'es_t': float + 'pi': dict[int, float] } """ - # Use default error tolerance thresholds if not provided - if errno_tolerance_thresholds is None: - errno_tolerance_thresholds = {} - for errno in errno2count.keys(): - if errno == 1: # accuracy errors - errno_tolerance_thresholds[errno] = 1 - else: # runtime (2) or compilation (3) errors - errno_tolerance_thresholds[errno] = 3 - # Calculate pi if not provided if pi is None: pi = calculate_pi(errno2count, total_samples, correct_speedups) - # Convert dictionary-based pi and thresholds to indexed format for calculate_gamma - # Create ordered list of errnos for consistent indexing (sorted by errno) - errnos = sorted(errno2count.keys()) - errno_tolerances = [errno_tolerance_thresholds.get(errno, 3) for errno in errnos] + # Prepare errno-ordered tolerance mapping for calculate_gamma + errno_as_tolerances = resolve_errno_tolerance(errno2count, errno_as_tolerance) - # Create get_pi function that maps error type index to pi value - def get_pi(error_type_index: int) -> float: - if error_type_index < len(errnos): - errno = errnos[error_type_index] - return pi.get(errno, 0.0) - return 0.0 + # Create pi_value4errno function that maps errno to pi value + def pi_value4errno(errno: int) -> float: + return pi.get(errno, 0.0) alpha = calculate_alpha(correct_speedups) beta = calculate_beta(correct_speedups) lambda_ = calculate_lambda(correct_speedups, total_samples) eta = calculate_eta(correct_speedups) - gamma = calculate_gamma(t_key, get_pi, errno_tolerances, fpdb) - - s_t = calculate_s_t_from_aggregated( - alpha, beta, lambda_, eta, negative_speedup_penalty, fpdb - ) - es_t = calculate_es_t_from_aggregated( - alpha, beta, lambda_, eta, gamma, negative_speedup_penalty - ) + gamma = calculate_gamma(tolerance, pi_value4errno, errno_as_tolerances, b) return { "alpha": alpha, @@ -345,6 +348,4 @@ def get_pi(error_type_index: int) -> float: "eta": eta, "gamma": gamma, "pi": pi, - "s_t": s_t, - "es_t": es_t, } diff --git a/graph_net/verify_aggregated_params.py b/graph_net/verify_aggregated_params.py index d7600f3f2..4b23b82fe 100644 --- a/graph_net/verify_aggregated_params.py +++ b/graph_net/verify_aggregated_params.py @@ -10,217 +10,297 @@ ) -def calculate_aggregated_parameters( +def determine_tolerances(samples: list) -> range: + """Determine tolerance range based on observed errno categories.""" + # Currently errno categories are 1=accuracy, 2=runtime, 3=compile. + # Keep logic data-driven for future extension. 
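+    # With the default categories this reproduces the historical sweep
+    # t_keys = range(-10, 5): negative tolerances exercise the accuracy
+    # thresholds, while the maximum t = 4 sits one above the largest errno
+    # threshold (3), at which every error type is already tolerated.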
+ default_errnos = {1, 2, 3} + return range(-10, len(default_errnos) + 2) + + +def extract_statistics_at_tolerance(samples: list, tolerance: int) -> dict: + """Extract statistics for a given tolerance level.""" + sample_data = [ + ( + idx, + sample, + sample.get("performance", {}).get("speedup", {}).get("e2e"), + *analysis_util.check_sample_correctness(sample, tolerance), + ) + for idx, sample in enumerate(samples) + ] + + correct_samples = [ + (idx, speedup) for idx, _, speedup, is_correct, _ in sample_data if is_correct + ] + correct_count = len(correct_samples) + correct_speedups = [ + speedup for _, speedup in correct_samples if speedup is not None + ] + slowdown_speedups = [speedup for speedup in correct_speedups if speedup < 1] + correct_negative_speedup_count = len(slowdown_speedups) + + errno2count = dict( + Counter( + get_errno_from_error_type(fail_type) + for _, _, _, _, fail_type in sample_data + if fail_type is not None + ) + ) + + return { + "correct_count": correct_count, + "correct_speedups": correct_speedups, + "slowdown_speedups": slowdown_speedups, + "correct_negative_speedup_count": correct_negative_speedup_count, + "errno2count": errno2count, + } + + +def freeze_statistics_at_t1( + stats: dict, total_samples: int, frozen_stats: dict +) -> dict: + """Freeze statistics at t=1 and calculate pi.""" + pi = samples_statistics.calculate_pi( + stats["errno2count"], total_samples, stats["correct_speedups"] + ) + frozen_stats.update( + { + "correct_count": stats["correct_count"], + "correct_negative_speedup_count": stats["correct_negative_speedup_count"], + "correct_speedups": stats["correct_speedups"], + "slowdown_speedups": stats["slowdown_speedups"], + "errno2count": stats["errno2count"].copy(), + } + ) + return pi + + +def select_statistics_for_calculation( + tolerance: int, current_stats: dict, frozen_stats: dict +) -> dict: + """Select statistics to use based on tolerance level.""" + if tolerance < 1: + return { + "correct_count": current_stats["correct_count"], + "correct_speedups": current_stats["correct_speedups"], + "slowdown_speedups": current_stats["slowdown_speedups"], + "errno2count": current_stats["errno2count"], + } + else: + return { + "correct_count": frozen_stats["correct_count"], + "correct_speedups": frozen_stats["correct_speedups"], + "slowdown_speedups": frozen_stats["slowdown_speedups"], + "errno2count": frozen_stats["errno2count"], + } + + +def calculate_parameters_for_tolerance( + tolerance: int, + total_samples: int, + stats: dict, + pi: dict, + negative_speedup_penalty: float, + fpdb: float, +) -> dict: + """Calculate ES(t) components and final scores for a tolerance level.""" + aggregated_params = samples_statistics.calculate_es_components_values( + total_samples=total_samples, + correct_speedups=stats["correct_speedups"], + errno2count=stats["errno2count"], + tolerance=tolerance, + negative_speedup_penalty=negative_speedup_penalty, + b=fpdb, + pi=pi, + ) + + alpha = aggregated_params["alpha"] + beta = aggregated_params["beta"] + lambda_ = aggregated_params["lambda"] + eta = aggregated_params["eta"] + gamma = aggregated_params["gamma"] + + expected_s = samples_statistics.calculate_s_t_from_aggregated( + alpha, beta, lambda_, eta, negative_speedup_penalty, fpdb + ) + expected_es = samples_statistics.calculate_es_t_from_aggregated( + alpha, beta, lambda_, eta, gamma, negative_speedup_penalty + ) + + return { + "alpha": alpha, + "beta": beta, + "lambda": lambda_, + "eta": eta, + "gamma": gamma, + "expected_s": expected_s, + "expected_es": expected_es, + 
+        # Propagate the recalculated pi so that build_report's
+        # params.get("pi", ...) lookup actually finds it for tolerance < 1.
+        "pi": aggregated_params["pi"],
} + + +def print_tolerance_details( + tolerance: int, + total_samples: int, + stats: dict, + params: dict, + pi: dict, + fpdb: float, +): + """Print detailed information for a tolerance level.""" + print(f"\nTolerance t = {tolerance}:") + print(f" Total samples: {total_samples}") + print( + f" Correct samples: {stats['correct_count']} (lambda = {params['lambda']:.6f})" + ) + print(f" Correct speedups collected: {len(stats['correct_speedups'])}") + print( + f" Slowdown cases: {len(stats['slowdown_speedups'])} (eta = {params['eta']:.6f})" + ) + print(f" alpha (geometric mean of correct speedups): {params['alpha']:.6f}") + if stats["correct_speedups"]: + print( + f" - Correct speedups: {stats['correct_speedups'][:10]}{'...' if len(stats['correct_speedups']) > 10 else ''}" + ) + print(f" beta (geometric mean of slowdown speedups): {params['beta']:.6f}") + if stats["slowdown_speedups"]: + print( + f" - Slowdown speedups: {stats['slowdown_speedups'][:10]}{'...' if len(stats['slowdown_speedups']) > 10 else ''}" + ) + print(f" gamma (average error penalty): {params['gamma']:.6f}") + if tolerance >= 1: + errnos = sorted(pi.keys()) + pi_list = [pi[errno] for errno in errnos] + indicator = [1 if tolerance < 1 else 0, 1 if tolerance < 3 else 0] + pi_indicator_sum = sum( + pi.get(errno, 0.0) * indicator[min(i, len(indicator) - 1)] + for i, errno in enumerate(errnos) + ) + print(f" - pi (errno -> proportion): {dict(sorted(pi.items()))}") + print(f" - pi (as list): {pi_list}") + print(f" - indicator: {indicator}") + print( + f" - gamma = fpdb^(sum(pi[i] * indicator[i])) = {fpdb}^{pi_indicator_sum:.6f} = {params['gamma']:.6f}" + ) + print(f" Expected S(t) from aggregated: {params['expected_s']:.6f}") + print(f" Expected ES(t) from aggregated: {params['expected_es']:.6f}") + + +class ToleranceReportBuilder: + """Stateful helper for building tolerance reports in order.""" + + def __init__( + self, + samples: list, + total_samples: int, + negative_speedup_penalty: float, + fpdb: float, + ): + self.samples = samples + self.total_samples = total_samples + self.negative_speedup_penalty = negative_speedup_penalty + self.fpdb = fpdb + self.pi: dict[int, float] = {} + self.frozen_stats: dict = { + "correct_count": 0, + "correct_speedups": [], + "slowdown_speedups": [], + "errno2count": {}, + } + + def build_report(self, tolerance: int) -> dict: + current_stats = extract_statistics_at_tolerance(self.samples, tolerance) + + if tolerance == 1: + self.pi = freeze_statistics_at_t1( + current_stats, self.total_samples, self.frozen_stats + ) + + if self.total_samples == 0: + return self._empty_report(tolerance) + + stats_for_calc = select_statistics_for_calculation( + tolerance, current_stats, self.frozen_stats + ) + # For tolerance < 1, pass None to let calculate_es_components_values recalculate pi + # For tolerance >= 1, use frozen pi from t=1 + pi_for_calc = None if tolerance < 1 else self.pi + params = calculate_parameters_for_tolerance( + tolerance, + self.total_samples, + stats_for_calc, + pi_for_calc, + self.negative_speedup_penalty, + self.fpdb, + ) + # Use calculated pi from params for display and return + calculated_pi = params.get("pi", self.pi) + print_tolerance_details( + tolerance, + self.total_samples, + stats_for_calc, + params, + calculated_pi, + self.fpdb, + ) + return { + **params, + "pi": calculated_pi, + "correct_count": stats_for_calc["correct_count"], + "total_samples": self.total_samples, + "correct_speedups_count": len(stats_for_calc["correct_speedups"]), + "slowdown_count": 
len(stats_for_calc["slowdown_speedups"]), + } + + def _empty_report(self, tolerance: int) -> dict: + print(f"\nTolerance t = {tolerance}: No samples to analyze") + return { + "alpha": 1.0, + "beta": 1.0, + "gamma": self.fpdb, + "lambda": 0.0, + "eta": 0.0, + "pi": self.pi, + "expected_s": self.fpdb, + "expected_es": self.fpdb, + "correct_count": 0, + "total_samples": 0, + "correct_speedups_count": 0, + "slowdown_count": 0, + } + + +def verify_es_components_across_tolerances( samples: list, folder_name: str, negative_speedup_penalty: float = 0, fpdb: float = 0.1, ) -> dict: """ - Calculate and print all aggregated parameters (alpha, beta, gamma, lambda, eta, pi) - for each tolerance level independently. - - This function extracts the aggregated parameter calculation logic from calculate_s_scores - to verify the correctness of aggregated-level calculations. + Verify and print ES component values (alpha, beta, gamma, lambda, eta, pi) for each + tolerance level independently. This logic mirrors `calculate_s_scores` but is split + out for focused validation of aggregated calculations. Returns: Dictionary mapping tolerance -> dict of aggregated parameters and calculated scores """ - begin = -10 - end = 4 - t_keys = list(range(begin, end + 1)) total_samples = len(samples) print(f"\n{'='*80}") print(f"Verifying Aggregated Parameters for '{folder_name}'") print(f"{'='*80}") - # pi is a tuple of constants for t > 0 for each group: (pi[0], pi[1]) - # Calculated at t=1, used for all t >= 1 - pi = (0.0, 0.0) - - # Store state at t=1 for ES(t) calculation - is_correct_at_t1 = [False] * total_samples - speedup_at_t1 = [None] * total_samples - fail_type_at_t1 = ["CORRECT"] * total_samples - - # Final statistics at t=1 - final_correct_count = 0 - final_correct_negative_speedup_count = 0 - final_correct_speedups = [] - final_slowdown_speedups = [] - final_errno2count = {} # Store error type counts at t=1 (using errno) - - results = OrderedDict() - - for t_key in t_keys: - # Extract sample data using map - sample_data = [ - ( - idx, - sample, - sample.get("performance", {}).get("speedup", {}).get("e2e"), - *analysis_util.check_sample_correctness(sample, t_key), - ) - for idx, sample in enumerate(samples) - ] - - # Filter correct samples and extract speedups - correct_samples = [ - (idx, speedup) - for idx, _, speedup, is_correct, _ in sample_data - if is_correct - ] - correct_count = len(correct_samples) - correct_speedups = [ - speedup for _, speedup in correct_samples if speedup is not None - ] - slowdown_speedups = [speedup for speedup in correct_speedups if speedup < 1] - correct_negative_speedup_count = len(slowdown_speedups) - - # Count errors by errno using Counter (convert error type string to errno) - errno2count = dict( - Counter( - get_errno_from_error_type(fail_type) - for _, _, _, _, fail_type in sample_data - if fail_type is not None - ) - ) - - # Store state at t=1 using list comprehension - if t_key == 1: - t1_data = [ - ( - idx, - speedup, - is_correct, - fail_type if fail_type is not None else "CORRECT", - ) - for idx, _, speedup, is_correct, fail_type in sample_data - ] - is_correct_at_t1 = [is_correct for _, _, is_correct, _ in t1_data] - speedup_at_t1 = [speedup for _, speedup, _, _ in t1_data] - fail_type_at_t1 = [fail_type for _, _, _, fail_type in t1_data] - - # Calculate pi at t=1 using the dedicated function - if t_key == 1: - pi = samples_statistics.calculate_pi( - errno2count, total_samples, correct_speedups - ) - final_correct_count = correct_count - 
final_correct_negative_speedup_count = correct_negative_speedup_count - final_correct_speedups = correct_speedups - final_slowdown_speedups = slowdown_speedups - final_errno2count = errno2count.copy() # Save for t >= 1 - - # Calculate aggregated parameters - if total_samples > 0: - # For t < 1, use current tolerance statistics - # For t >= 1, use t=1 statistics (frozen state) - if t_key < 1: - stats_correct_count = correct_count - stats_correct_negative_speedup_count = correct_negative_speedup_count - stats_correct_speedups = correct_speedups - stats_slowdown_speedups = slowdown_speedups - else: - stats_correct_count = final_correct_count - stats_correct_negative_speedup_count = ( - final_correct_negative_speedup_count - ) - stats_correct_speedups = final_correct_speedups - stats_slowdown_speedups = final_slowdown_speedups - - # Calculate all aggregated parameters using the dedicated module - # For t >= 1, use errno2count from t=1 (frozen state) - if t_key < 1: - stats_errno2count = errno2count - else: - stats_errno2count = final_errno2count # Use frozen from t=1 - - aggregated_params = samples_statistics.calculate_all_aggregated_parameters( - total_samples=total_samples, - correct_speedups=stats_correct_speedups, - errno2count=stats_errno2count, - t_key=t_key, - negative_speedup_penalty=negative_speedup_penalty, - fpdb=fpdb, - pi=pi, - ) + tolerances = determine_tolerances(samples) + builder = ToleranceReportBuilder( + samples=samples, + total_samples=total_samples, + negative_speedup_penalty=negative_speedup_penalty, + fpdb=fpdb, + ) - alpha = aggregated_params["alpha"] - beta = aggregated_params["beta"] - lambda_ = aggregated_params["lambda"] - eta = aggregated_params["eta"] - gamma = aggregated_params["gamma"] - expected_s = aggregated_params["s_t"] - expected_es = aggregated_params["es_t"] - - results[t_key] = { - "alpha": alpha, - "beta": beta, - "gamma": gamma, - "lambda": lambda_, - "eta": eta, - "pi": pi, - "expected_s": expected_s, - "expected_es": expected_es, - "correct_count": stats_correct_count, - "total_samples": total_samples, - "correct_speedups_count": len(stats_correct_speedups), - "slowdown_count": len(stats_slowdown_speedups), - } - - # Print detailed information - print(f"\nTolerance t = {t_key}:") - print(f" Total samples: {total_samples}") - print(f" Correct samples: {stats_correct_count} (lambda = {lambda_:.6f})") - print(f" Correct speedups collected: {len(stats_correct_speedups)}") - print(f" Slowdown cases: {len(stats_slowdown_speedups)} (eta = {eta:.6f})") - print(f" alpha (geometric mean of correct speedups): {alpha:.6f}") - if stats_correct_speedups: - print( - f" - Correct speedups: {stats_correct_speedups[:10]}{'...' if len(stats_correct_speedups) > 10 else ''}" - ) - print(f" beta (geometric mean of slowdown speedups): {beta:.6f}") - if stats_slowdown_speedups: - print( - f" - Slowdown speedups: {stats_slowdown_speedups[:10]}{'...' 
if len(stats_slowdown_speedups) > 10 else ''}" - ) - print(f" gamma (average error penalty): {gamma:.6f}") - if t_key >= 1: - # pi is now dict[int, float], convert to list for display - errnos = sorted(pi.keys()) - pi_list = [pi[errno] for errno in errnos] - indicator = [1 if t_key < 1 else 0, 1 if t_key < 3 else 0] - # Calculate pi_indicator_sum using errno-based pi - pi_indicator_sum = sum( - pi.get(errno, 0.0) * indicator[min(i, len(indicator) - 1)] - for i, errno in enumerate(errnos) - ) - print(f" - pi (errno -> proportion): {dict(sorted(pi.items()))}") - print(f" - pi (as list): {pi_list}") - print(f" - indicator: {indicator}") - print( - f" - gamma = fpdb^(sum(pi[i] * indicator[i])) = {fpdb}^{pi_indicator_sum:.6f} = {gamma:.6f}" - ) - print(f" Expected S(t) from aggregated: {expected_s:.6f}") - print(f" Expected ES(t) from aggregated: {expected_es:.6f}") - else: - results[t_key] = { - "alpha": 1.0, - "beta": 1.0, - "gamma": fpdb, - "lambda": 0.0, - "eta": 0.0, - "pi": pi, - "expected_s": fpdb, - "expected_es": fpdb, - "correct_count": 0, - "total_samples": 0, - "correct_speedups_count": 0, - "slowdown_count": 0, - } - print(f"\nTolerance t = {t_key}: No samples to analyze") + results = OrderedDict( + (tolerance, builder.build_report(tolerance)) for tolerance in tolerances + ) print(f"\n{'='*80}") print(f"Aggregated Parameter Verification Complete") @@ -263,7 +343,7 @@ def main(): # Calculate and print aggregated parameters for each curve for folder_name, samples in all_results.items(): - aggregated_results = calculate_aggregated_parameters( + aggregated_results = verify_es_components_across_tolerances( samples, folder_name, negative_speedup_penalty=args.negative_speedup_penalty, From 76d2b15d5c7d711ccc934963b35becd0242c330e Mon Sep 17 00:00:00 2001 From: Dayuxiaoshui <792179245@qq.com> Date: Tue, 18 Nov 2025 10:32:00 +0800 Subject: [PATCH 6/8] refactor: improve naming and semantics for ES calculation - Rename verify_es_match_at_tolerance to compare_aggregated_es_and_microscopic_es - Replace tolerance_level with tolerance parameter - Replace tolerance_threshold with atol/rtol to avoid confusion - Rename verify_aggregated_microscopic_consistency to get_verified_aggregated_es_values - Change return type to dict only (remove all_matched) - Rename verified_scores to verified_es_values - Replace micro with microscopic throughout - Rename check_sample_correctness to get_sample_correctness - Rename t1 variables to first_errno_tolerance - Rename es_components to es_constructor_params - Rename calculate_parameters_for_tolerance to calculate_es_constructor_params_for_tolerance - Rename custom_map to errno_tolerance_overrides - Rename errno_as_tolerances to errno2tolerance - Add enable_aggregation_mode command line option --- graph_net/analysis_util.py | 59 ++++++------ graph_net/plot_ESt.py | 130 +++++++++++++++++--------- graph_net/samples_statistics.py | 20 ++-- graph_net/verify_aggregated_params.py | 64 +++++++------ 4 files changed, 164 insertions(+), 109 deletions(-) diff --git a/graph_net/analysis_util.py b/graph_net/analysis_util.py index 3ac3a61f3..99d4a1777 100644 --- a/graph_net/analysis_util.py +++ b/graph_net/analysis_util.py @@ -416,9 +416,9 @@ def get_correctness(dtype: str, t: int, correctness_data: dict, index: int) -> b return False -def check_sample_correctness(sample: dict, t_key: int) -> tuple[bool, str]: +def get_sample_correctness(sample: dict, t_key: int) -> tuple[bool, str]: """ - Check if a sample is correct at the given tolerance level. 
+ Get sample correctness status at the given tolerance level. Args: sample: Sample data dictionary @@ -487,9 +487,9 @@ def calculate_es_rectified_speedup( speedup: float, fail_type: str, t_key: int, - is_correct_at_t1: bool, - speedup_at_t1: float, - fail_type_at_t1: str, + is_correct_at_first_errno_tolerance: bool, + speedup_at_first_errno_tolerance: float, + fail_type_at_first_errno_tolerance: str, negative_speedup_penalty: float, fpdb: float, ) -> float: @@ -500,9 +500,9 @@ def calculate_es_rectified_speedup( speedup: Current speedup value fail_type: Current error type t_key: Current tolerance level - is_correct_at_t1: Whether sample was correct at t=1 - speedup_at_t1: Speedup value at t=1 - fail_type_at_t1: Error type at t=1 + is_correct_at_first_errno_tolerance: Whether sample was correct at first errno tolerance (t=1) + speedup_at_first_errno_tolerance: Speedup value at first errno tolerance (t=1) + fail_type_at_first_errno_tolerance: Error type at first errno tolerance (t=1) negative_speedup_penalty: Penalty power p fpdb: Base penalty for failures @@ -515,13 +515,16 @@ def calculate_es_rectified_speedup( speedup, fail_type, negative_speedup_penalty, fpdb ) - # For t >= 1, use frozen state from t=1 - if not is_correct_at_t1 or speedup_at_t1 is None: - return fake_perf_degrad(t_key, fail_type_at_t1, fpdb) + # For t >= 1, use frozen state from first errno tolerance (t=1) + if ( + not is_correct_at_first_errno_tolerance + or speedup_at_first_errno_tolerance is None + ): + return fake_perf_degrad(t_key, fail_type_at_first_errno_tolerance, fpdb) - if speedup_at_t1 < 1: - return speedup_at_t1 ** (negative_speedup_penalty + 1) - return speedup_at_t1 + if speedup_at_first_errno_tolerance < 1: + return speedup_at_first_errno_tolerance ** (negative_speedup_penalty + 1) + return speedup_at_first_errno_tolerance def fake_perf_degrad(t, error_code, fpdb=0.1): @@ -621,12 +624,12 @@ def print_stat_info( return expected_s, expected_es # pi is a tuple of constants for t > 0 for each group: (pi[0], pi[1]) - # Calculated at t=1, used for all t >= 1 + # Calculated at first errno tolerance (t=1), used for all t >= 1 pi = (0.0, 0.0) - is_correct_at_t1 = [False] * total_samples - speedup_at_t1 = [None] * total_samples - fail_type_at_t1 = ["CORRECT"] * total_samples + is_correct_at_first_errno_tolerance = [False] * total_samples + speedup_at_first_errno_tolerance = [None] * total_samples + fail_type_at_first_errno_tolerance = ["CORRECT"] * total_samples final_correct_count = 0 final_correct_negative_speedup_count = 0 @@ -646,8 +649,8 @@ def print_stat_info( performance_data = sample.get("performance", {}) speedup = performance_data.get("speedup", {}).get("e2e") - # Check correctness using dedicated function - is_correct, fail_type = check_sample_correctness(sample, t_key) + # Get correctness using dedicated function + is_correct, fail_type = get_sample_correctness(sample, t_key) # Collect statistics if is_correct: @@ -662,11 +665,13 @@ def print_stat_info( errno = get_errno_from_error_type(fail_type) errno2count[errno] = errno2count.get(errno, 0) + 1 - # Store state at t=1 for ES(t) calculation + # Store state at first errno tolerance (t=1) for ES(t) calculation if t_key == 1: - is_correct_at_t1[idx] = is_correct - speedup_at_t1[idx] = speedup - fail_type_at_t1[idx] = fail_type if fail_type is not None else "CORRECT" + is_correct_at_first_errno_tolerance[idx] = is_correct + speedup_at_first_errno_tolerance[idx] = speedup + fail_type_at_first_errno_tolerance[idx] = ( + fail_type if fail_type is not None else 
"CORRECT" + ) # Calculate rectified speedups using dedicated functions regularized_speedup = calculate_rectified_speedup( @@ -678,9 +683,9 @@ def print_stat_info( speedup, fail_type, t_key, - is_correct_at_t1[idx], - speedup_at_t1[idx], - fail_type_at_t1[idx], + is_correct_at_first_errno_tolerance[idx], + speedup_at_first_errno_tolerance[idx], + fail_type_at_first_errno_tolerance[idx], negative_speedup_penalty, fpdb, ) diff --git a/graph_net/plot_ESt.py b/graph_net/plot_ESt.py index d4c8398a9..c6eec05a2 100644 --- a/graph_net/plot_ESt.py +++ b/graph_net/plot_ESt.py @@ -5,20 +5,43 @@ from graph_net import analysis_util -def compare_single_tolerance_level( - tolerance_level: int, - micro_es: float, +def es_result_checker( + es_from_microscopic: float, es_from_macro: float, atol: float, rtol: float +) -> bool: + """ + Check if ES(t) values from microscopic and macro calculations match. + + Args: + es_from_microscopic: ES(t) value from microscopic-level calculation + es_from_macro: ES(t) value from aggregated-level calculation + atol: Absolute tolerance for comparison + rtol: Relative tolerance for comparison + + Returns: + True if values match within tolerance, False otherwise + """ + diff = abs(es_from_microscopic - es_from_macro) + return diff < atol or diff < rtol * max( + abs(es_from_microscopic), abs(es_from_macro), 1e-10 + ) + + +def compare_aggregated_es_and_microscopic_es( + tolerance: int, + microscopic_es: float, aggregated_es: float | None, - tolerance_threshold: float, + atol: float = 1e-3, + rtol: float = 1e-3, ) -> tuple[bool, float, float]: """ - Compare micro and aggregated ES(t) values for a single tolerance level. + Compare ES(t) values from aggregated and microscopic calculations at a tolerance level. Args: - tolerance_level: Tolerance level t - micro_es: ES(t) value from micro-level calculation + tolerance: Tolerance level t + microscopic_es: ES(t) value from microscopic-level calculation aggregated_es: ES(t) value from aggregated-level calculation, or None if missing - tolerance_threshold: Floating point comparison tolerance + atol: Absolute tolerance for comparison + rtol: Relative tolerance for comparison Returns: Tuple of (is_matched, diff, relative_diff) @@ -26,16 +49,16 @@ def compare_single_tolerance_level( if aggregated_es is None: return False, 0.0, 0.0 - diff = abs(micro_es - aggregated_es) - relative_diff = diff / max(abs(micro_es), abs(aggregated_es), 1e-10) - is_matched = diff < tolerance_threshold or relative_diff < tolerance_threshold + diff = abs(microscopic_es - aggregated_es) + relative_diff = diff / max(abs(microscopic_es), abs(aggregated_es), 1e-10) + is_matched = es_result_checker(microscopic_es, aggregated_es, atol, rtol) return is_matched, diff, relative_diff def print_verification_result( - tolerance_level: int, - micro_es: float, + tolerance: int, + microscopic_es: float, aggregated_es: float | None, diff: float, relative_diff: float, @@ -43,68 +66,71 @@ def print_verification_result( ) -> None: """Print verification result for a single tolerance level.""" if aggregated_es is None: - print(f"ERROR: No aggregated result for t={tolerance_level}, cannot verify") + print(f"ERROR: No aggregated result for t={tolerance}, cannot verify") elif is_matched: print( - f"t={tolerance_level:3d}: MATCHED - Micro: {micro_es:.6f}, Aggregated: {aggregated_es:.6f}, Diff: {diff:.2e}" + f"t={tolerance:3d}: MATCHED - Microscopic: {microscopic_es:.6f}, Aggregated: {aggregated_es:.6f}, Diff: {diff:.2e}" ) else: print( - f"t={tolerance_level:3d}: MISMATCH - Micro: 
{micro_es:.6f}, Aggregated: {aggregated_es:.6f}, Diff: {diff:.2e} ({relative_diff*100:.4f}%)" + f"t={tolerance:3d}: MISMATCH - Microscopic: {microscopic_es:.6f}, Aggregated: {aggregated_es:.6f}, Diff: {diff:.2e} ({relative_diff*100:.4f}%)" ) -def verify_aggregated_micro_consistency( - es_scores: dict, folder_name: str, tolerance_threshold: float -) -> tuple[dict, bool]: +def get_verified_aggregated_es_values(es_scores: dict, folder_name: str) -> dict: """ - Verify consistency between aggregated and micro-level ES(t) calculations. + Get verified ES(t) values by checking consistency between aggregated and microscopic-level calculations. Args: - es_scores: Dictionary of ES(t) scores from micro-level calculation + es_scores: Dictionary of ES(t) scores from microscopic-level calculation folder_name: Name of the folder being verified - tolerance_threshold: Floating point comparison tolerance Returns: - Tuple of (verified_scores, all_matched): - - verified_scores: Dictionary of verified scores (only matched tolerance levels) - - all_matched: True if all tolerance levels matched, False otherwise + Dictionary of verified ES(t) values (only matched tolerance levels). + Returns empty dict if validation fails. """ aggregated_results = getattr(es_scores, "_aggregated_results", {}) - verified_scores = {} + verified_es_values = {} all_matched = True print(f"\n{'='*80}") - print(f"Verifying Aggregated/Micro Consistency for '{folder_name}'") + print(f"Verifying Aggregated/Microscopic Consistency for '{folder_name}'") print(f"{'='*80}") - for tolerance_level, micro_es in es_scores.items(): - aggregated_es = aggregated_results.get(tolerance_level) - is_matched, diff, relative_diff = compare_single_tolerance_level( - tolerance_level, micro_es, aggregated_es, tolerance_threshold + for tolerance, microscopic_es in es_scores.items(): + aggregated_es = aggregated_results.get(tolerance) + is_matched, diff, relative_diff = compare_aggregated_es_and_microscopic_es( + tolerance, microscopic_es, aggregated_es ) print_verification_result( - tolerance_level, micro_es, aggregated_es, diff, relative_diff, is_matched + tolerance, + microscopic_es, + aggregated_es, + diff, + relative_diff, + is_matched, ) if aggregated_es is None or not is_matched: all_matched = False if is_matched: - verified_scores[tolerance_level] = micro_es + verified_es_values[tolerance] = microscopic_es if not all_matched: print( - f"\nERROR: Aggregated and micro results do not match for '{folder_name}'!" + f"\nERROR: Aggregated and microscopic results do not match for '{folder_name}'!" ) print("Calculation validation failed. Results will NOT be used for plotting.") print("Please verify the calculation logic using verify_aggregated_params.py") print(f"{'='*80}\n") + return {} else: - print(f"\nSUCCESS: All aggregated and micro results match for '{folder_name}'.") + print( + f"\nSUCCESS: All aggregated and microscopic results match for '{folder_name}'." + ) print(f"{'='*80}\n") - - return verified_scores, all_matched + return verified_es_values def plot_ES_results(s_scores: dict, cli_args: argparse.Namespace): @@ -232,6 +258,18 @@ def main(): default=0.1, help="Base penalty for severe errors (e.g., crashes, correctness failures).", ) + parser.add_argument( + "--enable-aggregation-mode", + action="store_true", + help="Enable aggregation mode to verify aggregated/microscopic consistency. 
Default: enabled.", + ) + parser.add_argument( + "--disable-aggregation-mode", + dest="enable_aggregation_mode", + action="store_false", + help="Disable aggregation mode verification.", + ) + parser.set_defaults(enable_aggregation_mode=True) args = parser.parse_args() # 1. Scan folders to get data @@ -240,9 +278,8 @@ def main(): print("No valid data found. Exiting.") return - # 2. Calculate scores for each curve and verify aggregated/micro consistency + # 2. Calculate scores for each curve and verify aggregated/microscopic consistency all_es_scores = {} - tolerance_threshold = 1e-6 # Tolerance for floating point comparison for folder_name, samples in all_results.items(): _, es_scores = analysis_util.calculate_s_scores( @@ -255,15 +292,16 @@ def main(): # Keep original behavior: assign es_scores directly all_es_scores[folder_name] = es_scores - # Verify aggregated/micro consistency - verified_scores, all_matched = verify_aggregated_micro_consistency( - es_scores, folder_name, tolerance_threshold - ) + # Verify aggregated/microscopic consistency if aggregation mode is enabled + if args.enable_aggregation_mode: + verified_es_values = get_verified_aggregated_es_values( + es_scores, folder_name + ) - if not all_matched: - continue # Skip this curve if validation fails + if not verified_es_values: + continue # Skip this curve if validation fails - all_es_scores[folder_name] = verified_scores + all_es_scores[folder_name] = verified_es_values # 3. Plot the results if any(all_es_scores.values()): diff --git a/graph_net/samples_statistics.py b/graph_net/samples_statistics.py index d76ceb9ec..0079a9cd0 100644 --- a/graph_net/samples_statistics.py +++ b/graph_net/samples_statistics.py @@ -163,14 +163,14 @@ def calculate_pi( def resolve_errno_tolerance( - errno2count: dict[int, int], custom_map: dict[int, int] | None + errno2count: dict[int, int], errno_tolerance_overrides: dict[int, int] | None ) -> dict[int, int]: """ Build a sorted errno -> tolerance map for downstream gamma calculation. Args: errno2count: Observed errno occurrences in the dataset. - custom_map: Optional overrides mapping errno to its minimal tolerated tolerance. + errno_tolerance_overrides: Optional overrides mapping errno to its minimal tolerated tolerance. Returns: Ordered dict (by errno) mapping each errno seen in errno2count @@ -178,11 +178,11 @@ def resolve_errno_tolerance( - errno 1 (accuracy) -> 1 - errno >=2 (runtime/compile) -> 3 """ - custom_map = custom_map or {} + errno_tolerance_overrides = errno_tolerance_overrides or {} def tolerance_for(errno: int) -> int: - if errno in custom_map: - return custom_map[errno] + if errno in errno_tolerance_overrides: + return errno_tolerance_overrides[errno] return 1 if errno == 1 else 3 return {errno: tolerance_for(errno) for errno in sorted(errno2count.keys())} @@ -191,7 +191,7 @@ def tolerance_for(errno: int) -> int: def calculate_gamma( tolerance: int, pi_value4errno: Callable[[int], float], - errno_as_tolerances: dict[int, int], + errno2tolerance: dict[int, int], b: float = 0.1, ) -> float: """ @@ -203,7 +203,7 @@ def calculate_gamma( Args: tolerance: Tolerance level t pi_value4errno: Function that takes errno and returns π_c (proportion of error type c). - errno_as_tolerances: Mapping of errno to tolerance thresholds. + errno2tolerance: Mapping of errno to tolerance thresholds. An error type is tolerated (not penalized) when t >= threshold for that errno. 
b: Base penalty for severe errors (default: 0.1) @@ -216,7 +216,7 @@ def calculate_gamma( # Calculate indicator-weighted pi sum for errnos that are not tolerated pi_sum = sum( pi_value - for errno, errno_tolerance in errno_as_tolerances.items() + for errno, errno_tolerance in errno2tolerance.items() for pi_value in [pi_value4errno(errno)] if tolerance < errno_tolerance ) @@ -329,7 +329,7 @@ def calculate_es_components_values( pi = calculate_pi(errno2count, total_samples, correct_speedups) # Prepare errno-ordered tolerance mapping for calculate_gamma - errno_as_tolerances = resolve_errno_tolerance(errno2count, errno_as_tolerance) + errno2tolerance = resolve_errno_tolerance(errno2count, errno_as_tolerance) # Create pi_value4errno function that maps errno to pi value def pi_value4errno(errno: int) -> float: @@ -339,7 +339,7 @@ def pi_value4errno(errno: int) -> float: beta = calculate_beta(correct_speedups) lambda_ = calculate_lambda(correct_speedups, total_samples) eta = calculate_eta(correct_speedups) - gamma = calculate_gamma(tolerance, pi_value4errno, errno_as_tolerances, b) + gamma = calculate_gamma(tolerance, pi_value4errno, errno2tolerance, b) return { "alpha": alpha, diff --git a/graph_net/verify_aggregated_params.py b/graph_net/verify_aggregated_params.py index 4b23b82fe..ff22bc37a 100644 --- a/graph_net/verify_aggregated_params.py +++ b/graph_net/verify_aggregated_params.py @@ -25,7 +25,7 @@ def extract_statistics_at_tolerance(samples: list, tolerance: int) -> dict: idx, sample, sample.get("performance", {}).get("speedup", {}).get("e2e"), - *analysis_util.check_sample_correctness(sample, tolerance), + *analysis_util.get_sample_correctness(sample, tolerance), ) for idx, sample in enumerate(samples) ] @@ -57,10 +57,13 @@ def extract_statistics_at_tolerance(samples: list, tolerance: int) -> dict: } -def freeze_statistics_at_t1( - stats: dict, total_samples: int, frozen_stats: dict +def _freeze_statistics_at_tolerance( + stats: dict, + total_samples: int, + frozen_stats: dict, + first_errno_tolerance: int, ) -> dict: - """Freeze statistics at t=1 and calculate pi.""" + """Freeze statistics at first_errno_tolerance and calculate pi.""" pi = samples_statistics.calculate_pi( stats["errno2count"], total_samples, stats["correct_speedups"] ) @@ -96,7 +99,7 @@ def select_statistics_for_calculation( } -def calculate_parameters_for_tolerance( +def calculate_es_constructor_params_for_tolerance( tolerance: int, total_samples: int, stats: dict, @@ -104,7 +107,7 @@ def calculate_parameters_for_tolerance( negative_speedup_penalty: float, fpdb: float, ) -> dict: - """Calculate ES(t) components and final scores for a tolerance level.""" + """Calculate ES(t) constructor parameters (alpha, beta, gamma, lambda, eta) and final scores for a tolerance level.""" aggregated_params = samples_statistics.calculate_es_components_values( total_samples=total_samples, correct_speedups=stats["correct_speedups"], @@ -143,7 +146,7 @@ def print_tolerance_details( tolerance: int, total_samples: int, stats: dict, - params: dict, + es_constructor_params: dict, pi: dict, fpdb: float, ): @@ -151,23 +154,27 @@ def print_tolerance_details( print(f"\nTolerance t = {tolerance}:") print(f" Total samples: {total_samples}") print( - f" Correct samples: {stats['correct_count']} (lambda = {params['lambda']:.6f})" + f" Correct samples: {stats['correct_count']} (lambda = {es_constructor_params['lambda']:.6f})" ) print(f" Correct speedups collected: {len(stats['correct_speedups'])}") print( - f" Slowdown cases: 
{len(stats['slowdown_speedups'])} (eta = {params['eta']:.6f})" + f" Slowdown cases: {len(stats['slowdown_speedups'])} (eta = {es_constructor_params['eta']:.6f})" + ) + print( + f" alpha (geometric mean of correct speedups): {es_constructor_params['alpha']:.6f}" ) - print(f" alpha (geometric mean of correct speedups): {params['alpha']:.6f}") if stats["correct_speedups"]: print( f" - Correct speedups: {stats['correct_speedups'][:10]}{'...' if len(stats['correct_speedups']) > 10 else ''}" ) - print(f" beta (geometric mean of slowdown speedups): {params['beta']:.6f}") + print( + f" beta (geometric mean of slowdown speedups): {es_constructor_params['beta']:.6f}" + ) if stats["slowdown_speedups"]: print( f" - Slowdown speedups: {stats['slowdown_speedups'][:10]}{'...' if len(stats['slowdown_speedups']) > 10 else ''}" ) - print(f" gamma (average error penalty): {params['gamma']:.6f}") + print(f" gamma (average error penalty): {es_constructor_params['gamma']:.6f}") if tolerance >= 1: errnos = sorted(pi.keys()) pi_list = [pi[errno] for errno in errnos] @@ -180,10 +187,12 @@ def print_tolerance_details( print(f" - pi (as list): {pi_list}") print(f" - indicator: {indicator}") print( - f" - gamma = fpdb^(sum(pi[i] * indicator[i])) = {fpdb}^{pi_indicator_sum:.6f} = {params['gamma']:.6f}" + f" - gamma = fpdb^(sum(pi[i] * indicator[i])) = {fpdb}^{pi_indicator_sum:.6f} = {es_constructor_params['gamma']:.6f}" ) - print(f" Expected S(t) from aggregated: {params['expected_s']:.6f}") - print(f" Expected ES(t) from aggregated: {params['expected_es']:.6f}") + print(f" Expected S(t) from aggregated: {es_constructor_params['expected_s']:.6f}") + print( + f" Expected ES(t) from aggregated: {es_constructor_params['expected_es']:.6f}" + ) class ToleranceReportBuilder: @@ -212,8 +221,11 @@ def build_report(self, tolerance: int) -> dict: current_stats = extract_statistics_at_tolerance(self.samples, tolerance) if tolerance == 1: - self.pi = freeze_statistics_at_t1( - current_stats, self.total_samples, self.frozen_stats + self.pi = _freeze_statistics_at_tolerance( + current_stats, + self.total_samples, + self.frozen_stats, + first_errno_tolerance=1, ) if self.total_samples == 0: @@ -225,7 +237,7 @@ def build_report(self, tolerance: int) -> dict: # For tolerance < 1, pass None to let calculate_es_components_values recalculate pi # For tolerance >= 1, use frozen pi from t=1 pi_for_calc = None if tolerance < 1 else self.pi - params = calculate_parameters_for_tolerance( + es_constructor_params = calculate_es_constructor_params_for_tolerance( tolerance, self.total_samples, stats_for_calc, @@ -233,18 +245,18 @@ def build_report(self, tolerance: int) -> dict: self.negative_speedup_penalty, self.fpdb, ) - # Use calculated pi from params for display and return - calculated_pi = params.get("pi", self.pi) + # Use calculated pi from es_constructor_params for display and return + calculated_pi = es_constructor_params.get("pi", self.pi) print_tolerance_details( tolerance, self.total_samples, stats_for_calc, - params, + es_constructor_params, calculated_pi, self.fpdb, ) return { - **params, + **es_constructor_params, "pi": calculated_pi, "correct_count": stats_for_calc["correct_count"], "total_samples": self.total_samples, @@ -270,19 +282,19 @@ def _empty_report(self, tolerance: int) -> dict: } -def verify_es_components_across_tolerances( +def verify_es_constructor_params_across_tolerances( samples: list, folder_name: str, negative_speedup_penalty: float = 0, fpdb: float = 0.1, ) -> dict: """ - Verify and print ES component values 
(alpha, beta, gamma, lambda, eta, pi) for each + Verify and print ES constructor parameters (alpha, beta, gamma, lambda, eta, pi) for each tolerance level independently. This logic mirrors `calculate_s_scores` but is split out for focused validation of aggregated calculations. Returns: - Dictionary mapping tolerance -> dict of aggregated parameters and calculated scores + Dictionary mapping tolerance -> dict of ES constructor parameters and calculated scores """ total_samples = len(samples) @@ -343,7 +355,7 @@ def main(): # Calculate and print aggregated parameters for each curve for folder_name, samples in all_results.items(): - aggregated_results = verify_es_components_across_tolerances( + aggregated_results = verify_es_constructor_params_across_tolerances( samples, folder_name, negative_speedup_penalty=args.negative_speedup_penalty, From 0e48433794b9fc6884130b290145a31559455d05 Mon Sep 17 00:00:00 2001 From: Dayuxiaoshui <792179245@qq.com> Date: Tue, 18 Nov 2025 11:17:37 +0800 Subject: [PATCH 7/8] feat: add aggregated ES(t) plotting and verification - Modified plot_ES_results to return fig, ax, all_x_coords for external plotting - Added manual plotting of aggregated ES(t) curves in main function - Both microscopic and aggregated curves are plotted on the same graph - Aggregated curves use dashed lines with square markers for distinction - All verification checks pass with floating-point precision differences (1.39e-17) --- graph_net/analysis_util.py | 340 ++++++++++++-------------- graph_net/plot_ESt.py | 184 +++++++++++--- graph_net/verify_aggregated_params.py | 2 +- 3 files changed, 303 insertions(+), 223 deletions(-) diff --git a/graph_net/analysis_util.py b/graph_net/analysis_util.py index 99d4a1777..3bf3bf8ab 100644 --- a/graph_net/analysis_util.py +++ b/graph_net/analysis_util.py @@ -5,8 +5,6 @@ from scipy.stats import gmean from collections import OrderedDict, defaultdict from graph_net.config.datatype_tolerance_config import get_precision -from graph_net import samples_statistics -from graph_net.samples_statistics import get_errno_from_error_type def extract_speedup_data_from_subdirs(benchmark_path: str) -> dict: @@ -166,6 +164,7 @@ def parse_logs_to_data(log_file: str) -> list: List of data dictionaries, each containing configuration, correctness, performance, and result information for a single model-compiler run. """ + try: with open(log_file, "r", encoding="utf-8") as f: lines = f.readlines() @@ -332,10 +331,13 @@ def scan_all_folders(benchmark_path: str) -> dict: """ Unified entry point that supports log files and directories: - If benchmark_path is a log file (.log or .txt) → parse it directly and return data as a single curve. + - If benchmark_path is a directory → scan for .log and .txt files in the directory, each log file becomes a curve. + Returns dict[curve_name] -> list_of_samples """ + # Handle single log file if os.path.isfile(benchmark_path): print(f"Detected log file: '{benchmark_path}'") @@ -416,117 +418,6 @@ def get_correctness(dtype: str, t: int, correctness_data: dict, index: int) -> b return False -def get_sample_correctness(sample: dict, t_key: int) -> tuple[bool, str]: - """ - Get sample correctness status at the given tolerance level. 
- - Args: - sample: Sample data dictionary - t_key: Tolerance level - - Returns: - Tuple of (is_correct, fail_type) - - is_correct: True if sample is correct at this tolerance - - fail_type: Error type if not correct, None if correct - """ - performance_data = sample.get("performance", {}) - fail_type = performance_data.get("failure") - - # If there's already a failure type, return it - if fail_type is not None: - return False, fail_type - - # Check correctness based on datatype and tolerance - datatype_data = performance_data.get("datatype", {}) - eager_dtypes = datatype_data.get("eager", []) - compiled_dtypes = datatype_data.get("compiled", []) - - # Check if datatypes match and are valid - if not (eager_dtypes and eager_dtypes == compiled_dtypes and len(eager_dtypes) > 0): - return False, "accuracy" - - correctness_data = sample.get("correctness", {}) - output_count = len(correctness_data.get("[equal]", [])) - - if len(eager_dtypes) != output_count: - return False, "accuracy" - - # Check all outputs for correctness - is_correct = all( - get_correctness(eager_dtypes[i], t_key, correctness_data, i) - for i in range(output_count) - ) - - return is_correct, None if is_correct else "accuracy" - - -def calculate_rectified_speedup( - speedup: float, fail_type: str, negative_speedup_penalty: float, fpdb: float -) -> float: - """ - Calculate rectified speedup for S(t) calculation. - - Args: - speedup: Original speedup value - fail_type: Error type or None if correct - negative_speedup_penalty: Penalty power p for negative speedup - fpdb: Base penalty for failures - - Returns: - Rectified speedup value - """ - if fail_type is not None or speedup is None: - return fpdb - - if speedup < 1: - return speedup ** (negative_speedup_penalty + 1) - return speedup - - -def calculate_es_rectified_speedup( - speedup: float, - fail_type: str, - t_key: int, - is_correct_at_first_errno_tolerance: bool, - speedup_at_first_errno_tolerance: float, - fail_type_at_first_errno_tolerance: str, - negative_speedup_penalty: float, - fpdb: float, -) -> float: - """ - Calculate rectified speedup for ES(t) calculation. - - Args: - speedup: Current speedup value - fail_type: Current error type - t_key: Current tolerance level - is_correct_at_first_errno_tolerance: Whether sample was correct at first errno tolerance (t=1) - speedup_at_first_errno_tolerance: Speedup value at first errno tolerance (t=1) - fail_type_at_first_errno_tolerance: Error type at first errno tolerance (t=1) - negative_speedup_penalty: Penalty power p - fpdb: Base penalty for failures - - Returns: - Error-aware rectified speedup value - """ - if t_key < 1: - # For t < 1, ES(t) = S(t) - return calculate_rectified_speedup( - speedup, fail_type, negative_speedup_penalty, fpdb - ) - - # For t >= 1, use frozen state from first errno tolerance (t=1) - if ( - not is_correct_at_first_errno_tolerance - or speedup_at_first_errno_tolerance is None - ): - return fake_perf_degrad(t_key, fail_type_at_first_errno_tolerance, fpdb) - - if speedup_at_first_errno_tolerance < 1: - return speedup_at_first_errno_tolerance ** (negative_speedup_penalty + 1) - return speedup_at_first_errno_tolerance - - def fake_perf_degrad(t, error_code, fpdb=0.1): """ Calculate fake performance degradation based on tolerance t and error code. 
@@ -558,9 +449,6 @@ def calculate_s_scores( """ s_scores = OrderedDict() s_scores_fake_degrad = OrderedDict() - # Store aggregated-level calculation results for cross-validation - s_scores._aggregated_results = OrderedDict() - s_scores_fake_degrad._aggregated_results = OrderedDict() begin = -10 end = 4 @@ -572,39 +460,38 @@ def calculate_s_scores( def print_stat_info( t_key, correct_count, - errno2count, + acc_failure_count, pi, correct_negative_speedup_count, correct_speedups, + slowdown_speedups, ): - """ - Calculate and print aggregated statistics for a given tolerance level. - - Uses the samples_statistics module for all parameter calculations. - """ print(f" - Details for tolerance={t_key}:") if total_samples > 0: - # Calculate all aggregated parameters using the dedicated module - aggregated_params = samples_statistics.calculate_es_components_values( - total_samples=total_samples, - correct_speedups=correct_speedups, - errno2count=errno2count, - tolerance=t_key, - negative_speedup_penalty=negative_speedup_penalty, - b=fpdb, - pi=pi, + alpha = gmean(correct_speedups) if correct_speedups else 1 + beta = gmean(slowdown_speedups) if slowdown_speedups else 1 + lambda_ = correct_count / total_samples if total_samples > 0 else 0 + eta = ( + correct_negative_speedup_count / correct_count + if correct_count > 0 + else 0 + ) + indicator = [1 if t_key < 1 else 0, 1 if t_key < 3 else 0] + gamma = ( + fpdb ** sum(pi[i] * indicator[i] for i in range(len(pi))) + if t_key >= 1 + else fpdb ) - alpha = aggregated_params["alpha"] - beta = aggregated_params["beta"] - lambda_ = aggregated_params["lambda"] - eta = aggregated_params["eta"] - gamma = aggregated_params["gamma"] - expected_s = samples_statistics.calculate_s_t_from_aggregated( - alpha, beta, lambda_, eta, negative_speedup_penalty, fpdb + expected_s = ( + alpha**lambda_ + * beta ** (lambda_ * eta * negative_speedup_penalty) + * fpdb ** (1 - lambda_) ) - expected_es = samples_statistics.calculate_es_t_from_aggregated( - alpha, beta, lambda_, eta, gamma, negative_speedup_penalty + expected_es = ( + alpha**lambda_ + * beta ** (lambda_ * eta * negative_speedup_penalty) + * gamma ** (1 - lambda_) ) print( @@ -618,39 +505,55 @@ def print_stat_info( ) else: print(" - No samples to analyze.") - expected_s = fpdb - expected_es = fpdb return expected_s, expected_es - # pi is a tuple of constants for t > 0 for each group: (pi[0], pi[1]) - # Calculated at first errno tolerance (t=1), used for all t >= 1 - pi = (0.0, 0.0) + # pi is a list of constants for t > 0 for each group + pi = [0, 0] - is_correct_at_first_errno_tolerance = [False] * total_samples - speedup_at_first_errno_tolerance = [None] * total_samples - fail_type_at_first_errno_tolerance = ["CORRECT"] * total_samples + is_correct_at_t1 = [False] * total_samples + speedup_at_t1 = [None] * total_samples + fail_type_at_t1 = ["CORRECT"] * total_samples final_correct_count = 0 final_correct_negative_speedup_count = 0 final_correct_speedups = [] - final_errno2count = {} # Store error type counts at t=1 (using errno) + final_slowdown_speedups = [] for t_key in t_keys: rectified_speedups = [] rectified_speedups_fake_degrad = [] correct_count = 0 - errno2count = {} # Dictionary to count errors by errno + acc_failure_count = 0 correct_negative_speedup_count = 0 correct_speedups = [] + slowdown_speedups = [] - # Process all samples using helper functions to reduce nesting for idx, sample in enumerate(samples): performance_data = sample.get("performance", {}) + fail_type = performance_data.get("failure") 
speedup = performance_data.get("speedup", {}).get("e2e") - # Get correctness using dedicated function - is_correct, fail_type = get_sample_correctness(sample, t_key) + # Determine the true state of the current sample (for statistics and S curve) + is_correct = False + if fail_type is None: + datatype_data = performance_data.get("datatype", {}) + eager_dtypes = datatype_data.get("eager", []) + compiled_dtypes = datatype_data.get("compiled", []) + if ( + eager_dtypes + and eager_dtypes == compiled_dtypes + and len(eager_dtypes) > 0 + ): + correctness_data = sample.get("correctness", {}) + output_count = len(correctness_data.get("[equal]", [])) + if len(eager_dtypes) == output_count: + is_correct = all( + get_correctness(eager_dtypes[i], t_key, correctness_data, i) + for i in range(output_count) + ) + if not is_correct: + fail_type = "accuracy" # Collect statistics if is_correct: @@ -659,47 +562,62 @@ def print_stat_info( correct_speedups.append(speedup) if speedup is not None and speedup < 1: correct_negative_speedup_count += 1 + slowdown_speedups.append(speedup) - # Count errors by errno (convert error type string to errno) - if fail_type is not None: - errno = get_errno_from_error_type(fail_type) - errno2count[errno] = errno2count.get(errno, 0) + 1 + if fail_type == "accuracy": + acc_failure_count += 1 - # Store state at first errno tolerance (t=1) for ES(t) calculation if t_key == 1: - is_correct_at_first_errno_tolerance[idx] = is_correct - speedup_at_first_errno_tolerance[idx] = speedup - fail_type_at_first_errno_tolerance[idx] = ( - fail_type if fail_type is not None else "CORRECT" - ) + is_correct_at_t1[idx] = is_correct + speedup_at_t1[idx] = speedup + fail_type_at_t1[idx] = fail_type if fail_type is not None else "CORRECT" - # Calculate rectified speedups using dedicated functions - regularized_speedup = calculate_rectified_speedup( - speedup, fail_type, negative_speedup_penalty, fpdb - ) + # S(t) calculation + if fail_type is not None or speedup is None: + regularized_speedup = fpdb + else: + regularized_speedup = ( + speedup ** (negative_speedup_penalty + 1) + if speedup < 1 + else speedup + ) rectified_speedups.append(regularized_speedup) - rec_speedup_fake_degrad = calculate_es_rectified_speedup( - speedup, - fail_type, - t_key, - is_correct_at_first_errno_tolerance[idx], - speedup_at_first_errno_tolerance[idx], - fail_type_at_first_errno_tolerance[idx], - negative_speedup_penalty, - fpdb, - ) + # ES(t) calculation: based on state change + if t_key < 1: + if fail_type is not None or speedup is None: + rec_speedup_fake_degrad = fpdb + else: + rec_speedup_fake_degrad = ( + speedup ** (negative_speedup_penalty + 1) + if speedup < 1 + else speedup + ) + else: + if not is_correct_at_t1[idx] or speedup_at_t1[idx] is None: + fail_type_frozen = fail_type_at_t1[idx] + rec_speedup_fake_degrad = fake_perf_degrad( + t_key, fail_type_frozen, fpdb + ) + else: + rec_speedup_fake_degrad = ( + speedup_at_t1[idx] ** (negative_speedup_penalty + 1) + if speedup_at_t1[idx] < 1 + else speedup_at_t1[idx] + ) rectified_speedups_fake_degrad.append(rec_speedup_fake_degrad) if t_key == 1: - # Calculate pi at t=1 using the dedicated function - pi = samples_statistics.calculate_pi( - errno2count, total_samples, correct_speedups - ) + if total_samples == correct_count: + pi[0] = 0 + pi[1] = 0 + else: + pi[0] = acc_failure_count / (total_samples - correct_count) + pi[1] = 1 - pi[0] final_correct_count = correct_count final_correct_negative_speedup_count = correct_negative_speedup_count final_correct_speedups = 
correct_speedups - final_errno2count = errno2count.copy() # Save for t >= 1 + final_slowdown_speedups = slowdown_speedups if rectified_speedups: s_scores[t_key] = gmean(rectified_speedups) @@ -711,28 +629,70 @@ def print_stat_info( expected_s, expected_es = print_stat_info( t_key, correct_count, - errno2count, + acc_failure_count, pi, correct_negative_speedup_count, correct_speedups, + slowdown_speedups, ) else: - # For t >= 1, use errno2count from t=1 (frozen state) expected_s, expected_es = print_stat_info( t_key, final_correct_count, - final_errno2count, # Use the frozen errno2count from t=1 + acc_failure_count, pi, final_correct_negative_speedup_count, final_correct_speedups, + final_slowdown_speedups, ) print( - f" - S(t)={expected_s:.3f}, ES(t)={expected_es:.3f} for tolerance={t_key} from aggregated level." + f" - S(t)={expected_s:.3f}, ES(t)={expected_es:.3f} for tolerance={t_key} from macro level." ) - # Store aggregated results for cross-validation - s_scores._aggregated_results[t_key] = expected_s - s_scores_fake_degrad._aggregated_results[t_key] = expected_es - print(f" - pi: {dict(sorted(pi.items()))}") + print(f" - pi: {pi}") return s_scores, s_scores_fake_degrad + + +def check_sample_correctness(sample: dict, t_key: int) -> tuple[bool, str]: + """ + Check if a sample is correct at the given tolerance level. + + Args: + sample: Sample data dictionary + t_key: Tolerance level + + Returns: + Tuple of (is_correct, fail_type) + - is_correct: True if sample is correct at this tolerance + - fail_type: Error type if not correct, None if correct + """ + performance_data = sample.get("performance", {}) + fail_type = performance_data.get("failure") + + # If there's already a failure type, return it + if fail_type is not None: + return False, fail_type + + # Check correctness based on datatype and tolerance + datatype_data = performance_data.get("datatype", {}) + eager_dtypes = datatype_data.get("eager", []) + compiled_dtypes = datatype_data.get("compiled", []) + + # Check if datatypes match and are valid + if not (eager_dtypes and eager_dtypes == compiled_dtypes and len(eager_dtypes) > 0): + return False, "accuracy" + + correctness_data = sample.get("correctness", {}) + output_count = len(correctness_data.get("[equal]", [])) + + if len(eager_dtypes) != output_count: + return False, "accuracy" + + # Check all outputs for correctness + is_correct = all( + get_correctness(eager_dtypes[i], t_key, correctness_data, i) + for i in range(output_count) + ) + + return is_correct, None if is_correct else "accuracy" diff --git a/graph_net/plot_ESt.py b/graph_net/plot_ESt.py index c6eec05a2..15e4f5917 100644 --- a/graph_net/plot_ESt.py +++ b/graph_net/plot_ESt.py @@ -3,6 +3,24 @@ import numpy as np import matplotlib.pyplot as plt from graph_net import analysis_util +from graph_net import verify_aggregated_params + + +class ESScoresWrapper: + """Wrapper for es_scores dict to allow attribute assignment.""" + + def __init__(self, es_scores_dict): + self._dict = es_scores_dict + self._aggregated_results = {} + + def items(self): + return self._dict.items() + + def __getitem__(self, key): + return self._dict[key] + + def __setitem__(self, key, value): + self._dict[key] = value def es_result_checker( @@ -20,10 +38,7 @@ def es_result_checker( Returns: True if values match within tolerance, False otherwise """ - diff = abs(es_from_microscopic - es_from_macro) - return diff < atol or diff < rtol * max( - abs(es_from_microscopic), abs(es_from_macro), 1e-10 - ) + return np.allclose(es_from_microscopic, 
es_from_macro, rtol=rtol, atol=atol) def compare_aggregated_es_and_microscopic_es( @@ -87,11 +102,13 @@ def get_verified_aggregated_es_values(es_scores: dict, folder_name: str) -> dict Returns: Dictionary of verified ES(t) values (only matched tolerance levels). - Returns empty dict if validation fails. + + Raises: + AssertionError: If aggregated and microscopic results do not match (fail-fast). """ aggregated_results = getattr(es_scores, "_aggregated_results", {}) verified_es_values = {} - all_matched = True + mismatches = [] print(f"\n{'='*80}") print(f"Verifying Aggregated/Microscopic Consistency for '{folder_name}'") @@ -112,25 +129,37 @@ def get_verified_aggregated_es_values(es_scores: dict, folder_name: str) -> dict is_matched, ) - if aggregated_es is None or not is_matched: - all_matched = False - if is_matched: + if aggregated_es is None: + mismatches.append( + f"t={tolerance}: Missing aggregated result (microscopic={microscopic_es:.6f})" + ) + elif not is_matched: + mismatches.append( + f"t={tolerance}: Mismatch - Microscopic={microscopic_es:.6f}, " + f"Aggregated={aggregated_es:.6f}, Diff={diff:.2e} ({relative_diff*100:.4f}%)" + ) + else: verified_es_values[tolerance] = microscopic_es - if not all_matched: - print( - f"\nERROR: Aggregated and microscopic results do not match for '{folder_name}'!" - ) - print("Calculation validation failed. Results will NOT be used for plotting.") - print("Please verify the calculation logic using verify_aggregated_params.py") - print(f"{'='*80}\n") - return {} - else: - print( - f"\nSUCCESS: All aggregated and microscopic results match for '{folder_name}'." + if mismatches: + error_msg = ( + f"\n{'='*80}\n" + f"ERROR: Aggregated and microscopic results do not match for '{folder_name}'!\n" + f"{'='*80}\n" + f"Mismatches:\n" + + "\n".join(f" - {mismatch}" for mismatch in mismatches) + + f"\n\nCalculation validation failed. Please verify the calculation logic " + f"using verify_aggregated_params.py\n" + f"{'='*80}\n" ) - print(f"{'='*80}\n") - return verified_es_values + print(error_msg) + raise AssertionError(error_msg) + + print( + f"\nSUCCESS: All aggregated and microscopic results match for '{folder_name}'." + ) + print(f"{'='*80}\n") + return verified_es_values def plot_ES_results(s_scores: dict, cli_args: argparse.Namespace): @@ -221,10 +250,7 @@ def plot_ES_results(s_scores: dict, cli_args: argparse.Namespace): ax.xaxis.grid(True, which="major", lw=0.7, ls=":", color="grey", alpha=0.5) ax.yaxis.grid(True, which="major", lw=0.7, ls=":", color="grey", alpha=0.5) - ax.legend(fontsize=16, loc="best") - output_file = os.path.join(cli_args.output_dir, "ES_result.png") - plt.savefig(output_file, dpi=300, bbox_inches="tight") - print(f"\nComparison plot saved to {output_file}") + return fig, ax, all_x_coords def main(): @@ -280,6 +306,7 @@ def main(): # 2. 
Calculate scores for each curve and verify aggregated/microscopic consistency all_es_scores = {} + all_aggregated_results = {} for folder_name, samples in all_results.items(): _, es_scores = analysis_util.calculate_s_scores( @@ -294,19 +321,112 @@ def main(): # Verify aggregated/microscopic consistency if aggregation mode is enabled if args.enable_aggregation_mode: + # Calculate aggregated results and attach to es_scores + aggregated_results = ( + verify_aggregated_params.verify_es_constructor_params_across_tolerances( + samples, + folder_name, + negative_speedup_penalty=args.negative_speedup_penalty, + fpdb=args.fpdb, + ) + ) + # Store aggregated results for plotting + all_aggregated_results[folder_name] = aggregated_results + + # Extract expected_es values and attach as _aggregated_results + # Wrap es_scores to allow attribute assignment + es_scores_wrapper = ESScoresWrapper(es_scores) + es_scores_wrapper._aggregated_results = { + tolerance: result["expected_es"] + for tolerance, result in aggregated_results.items() + } + + # Fail-fast: raise AssertionError if validation fails verified_es_values = get_verified_aggregated_es_values( - es_scores, folder_name + es_scores_wrapper, folder_name ) - - if not verified_es_values: - continue # Skip this curve if validation fails - all_es_scores[folder_name] = verified_es_values # 3. Plot the results if any(all_es_scores.values()): os.makedirs(args.output_dir, exist_ok=True) - plot_ES_results(all_es_scores, args) + fig, ax, all_x_coords = plot_ES_results(all_es_scores, args) + + # Manually add aggregated curves if available + if args.enable_aggregation_mode and all_aggregated_results: + prop_cycle = plt.rcParams["axes.prop_cycle"] + colors = prop_cycle.by_key()["color"] + + for idx, (folder_name, aggregated_results) in enumerate( + all_aggregated_results.items() + ): + if folder_name not in all_es_scores: + continue + + color = colors[idx % len(colors)] + agg_plot_points = [] + for tolerance, result in aggregated_results.items(): + if isinstance(result, dict) and "expected_es" in result: + agg_plot_points.append( + {"x": tolerance, "y": result["expected_es"]} + ) + + if agg_plot_points: + agg_plot_points.sort(key=lambda p: p["x"]) + agg_x_vals = np.array([p["x"] for p in agg_plot_points]) + agg_y_vals = np.array([p["y"] for p in agg_plot_points]) + + agg_zero_index = ( + np.where(agg_x_vals == 0)[0][0] if 0 in agg_x_vals else None + ) + + if agg_zero_index is not None: + ax.plot( + agg_x_vals[: agg_zero_index + 1], + agg_y_vals[: agg_zero_index + 1], + "s--", + color=color, + label=f"{folder_name} (aggregated)", + linewidth=2, + markersize=6, + alpha=0.7, + ) + ax.plot( + agg_x_vals[agg_zero_index:], + agg_y_vals[agg_zero_index:], + "s--", + color=color, + linewidth=2, + markersize=6, + drawstyle="steps-post", + alpha=0.7, + ) + else: + ax.plot( + agg_x_vals, + agg_y_vals, + "s--", + color=color, + label=f"{folder_name} (aggregated)", + linewidth=2, + markersize=6, + alpha=0.7, + ) + + # Update x-axis range if needed + if all_x_coords: + for folder_name, aggregated_results in all_aggregated_results.items(): + for tolerance in aggregated_results.keys(): + all_x_coords.append(tolerance) + x_min = int(np.floor(min(all_x_coords))) + x_max = int(np.ceil(max(all_x_coords))) + ax.set_xticks(np.arange(x_min, x_max + 1)) + + ax.legend(fontsize=16, loc="best") + + output_file = os.path.join(args.output_dir, "ES_result.png") + plt.savefig(output_file, dpi=300, bbox_inches="tight") + print(f"\nComparison plot saved to {output_file}") else: print("No ES(t) 
scores were calculated. Skipping plot generation.") diff --git a/graph_net/verify_aggregated_params.py b/graph_net/verify_aggregated_params.py index ff22bc37a..27be82d60 100644 --- a/graph_net/verify_aggregated_params.py +++ b/graph_net/verify_aggregated_params.py @@ -25,7 +25,7 @@ def extract_statistics_at_tolerance(samples: list, tolerance: int) -> dict: idx, sample, sample.get("performance", {}).get("speedup", {}).get("e2e"), - *analysis_util.get_sample_correctness(sample, tolerance), + *analysis_util.check_sample_correctness(sample, tolerance), ) for idx, sample in enumerate(samples) ] From 8654af54481a9f0320a6715505734d346969cc07 Mon Sep 17 00:00:00 2001 From: Dayuxiaoshui <792179245@qq.com> Date: Tue, 18 Nov 2025 11:20:35 +0800 Subject: [PATCH 8/8] fix: move ax.legend outside aggregation condition block - Move ax.legend() outside the aggregation mode condition block - Ensure legend is always displayed regardless of aggregation mode - Fix issue where legend was missing when aggregation mode is disabled --- graph_net/plot_ESt.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/graph_net/plot_ESt.py b/graph_net/plot_ESt.py index 15e4f5917..f786e3619 100644 --- a/graph_net/plot_ESt.py +++ b/graph_net/plot_ESt.py @@ -422,7 +422,8 @@ def main(): x_max = int(np.ceil(max(all_x_coords))) ax.set_xticks(np.arange(x_min, x_max + 1)) - ax.legend(fontsize=16, loc="best") + # Always show legend (whether aggregated curves are added or not) + ax.legend(fontsize=16, loc="best") output_file = os.path.join(args.output_dir, "ES_result.png") plt.savefig(output_file, dpi=300, bbox_inches="tight")
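
Note: the consistency that get_verified_aggregated_es_values enforces above can be reproduced in isolation. The standalone sketch below checks, on toy data, that ES(t) assembled from the aggregated parameters (alpha, beta, gamma, lambda, eta, pi) equals the geometric mean of the per-sample error-aware rectified speedups. It is independent of the patched modules; the sample tuples, the errno-to-tolerance map {1: 1, 2: 3, 3: 3} (mirroring the defaults of resolve_errno_tolerance), and the constants B (standing in for fpdb) and P (for negative_speedup_penalty) are illustrative assumptions, not values from the benchmark data.

import numpy as np
from scipy.stats import gmean

B = 0.1                                # base penalty for unresolved failures (fpdb)
P = 2.0                                # negative speedup penalty power (p)
ERRNO_TOLERANCE = {1: 1, 2: 3, 3: 3}   # errno -> minimal tolerated t (assumed defaults)

# Toy frozen t=1 state: (speedup, errno); errno None means the sample is correct.
SAMPLES = [(1.8, None), (0.7, None), (2.4, None), (None, 1), (None, 3)]

def micro_es(t: int) -> float:
    """Geometric mean of per-sample error-aware rectified speedups."""
    vals = []
    for speedup, errno in SAMPLES:
        if errno is None:
            # Rectified speedup: slowdowns are penalized by the power (P + 1).
            vals.append(speedup ** (P + 1) if speedup < 1 else speedup)
        else:
            # Tolerated errors contribute 1.0; unresolved ones the base penalty B.
            vals.append(1.0 if t >= ERRNO_TOLERANCE[errno] else B)
    return gmean(vals)

def macro_es(t: int) -> float:
    """ES(t) = alpha**lambda * beta**(lambda*eta*P) * gamma**(1 - lambda)."""
    correct = [s for s, e in SAMPLES if e is None]
    slowdowns = [s for s in correct if s < 1]
    failed = [e for _, e in SAMPLES if e is not None]
    n, m = len(SAMPLES), len(correct)
    alpha = gmean(correct) if correct else 1.0
    beta = gmean(slowdowns) if slowdowns else 1.0
    lam = m / n
    eta = len(slowdowns) / m if m else 0.0
    # pi_c: share of each error type among failed samples (frozen at t=1).
    pi = {c: failed.count(c) / len(failed) for c in set(failed)} if failed else {}
    # gamma_t = B ** (sum of pi_c over error types not yet tolerated at t).
    gamma = B ** sum(p_c for c, p_c in pi.items() if t < ERRNO_TOLERANCE[c])
    return alpha**lam * beta ** (lam * eta * P) * gamma ** (1 - lam)

for t in (1, 2, 3):
    assert np.isclose(micro_es(t), macro_es(t), rtol=1e-9, atol=1e-12)
    print(f"t={t}: micro={micro_es(t):.6f}  macro={macro_es(t):.6f}")

Both paths agree to floating-point rounding at every t, which is exactly the property es_result_checker asserts before a curve is admitted for plotting.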
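
One behavioral subtlety worth recording about the checker itself: the first version accepts a pair when diff < atol or diff < rtol * max(|micro|, |macro|, 1e-10), while the later np.allclose version uses numpy's additive criterion |micro - macro| <= atol + rtol * |macro|. The two mostly coincide at these magnitudes but are not the same predicate. A quick comparison (the numeric pairs are hypothetical, except the rounding-level difference of 1.39e-17 reported above):

import numpy as np

ATOL = RTOL = 1e-3  # defaults of compare_aggregated_es_and_microscopic_es

def checker_v1(micro: float, macro: float) -> bool:
    # Earlier variant: relative bound against the larger magnitude.
    diff = abs(micro - macro)
    return diff < ATOL or diff < RTOL * max(abs(micro), abs(macro), 1e-10)

def checker_v2(micro: float, macro: float) -> bool:
    # Later variant: numpy's additive criterion |a - b| <= atol + rtol * |b|.
    return bool(np.allclose(micro, macro, rtol=RTOL, atol=ATOL))

pairs = [
    (1.234567, 1.234567 + 1.39e-17),  # rounding-level difference: both accept
    (1.0, 1.0018),                    # near the crossover: v1 rejects, v2 accepts
    (0.1, 0.1021),                    # clear mismatch: both reject
]
for micro, macro in pairs:
    print(f"{micro} vs {macro}: v1={checker_v1(micro, macro)}, v2={checker_v2(micro, macro)}")

With atol = rtol = 1e-3, the pair (1.0, 1.0018) is rejected by the first checker but accepted by np.allclose, because the additive bound atol + rtol * |macro| is looser near the atol/rtol crossover; for the rounding-level differences actually observed, both predicates agree.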