49 changes: 19 additions & 30 deletions graph_net/analysis_util.py
@@ -3,6 +3,8 @@
import sys
from scipy.stats import gmean
from graph_net.config.datatype_tolerance_config import get_precision
from graph_net.positive_tolerance_interpretation import PositiveToleranceInterpretation
from graph_net.verify_aggregated_params import determine_tolerances


def detect_sample_status(log_text: str) -> str:
@@ -293,38 +295,24 @@ def get_correctness(dtype: str, t: int, correctness_data: dict, index: int) -> bool:
    return False


def fake_perf_degrad(tolerance, error_code, type="default") -> str:
def fake_perf_degrad(
    tolerance,
    error_code,
    positive_tolerance_interpretation: PositiveToleranceInterpretation,
) -> str:
    """
    Judge current correctness based on tolerance t and status.
    Refactored to delegate logic to PositiveToleranceInterpretation classes.
    """
    if type == "default":
        if tolerance >= 3:
            return "correct"
        elif error_code == "accuracy" and tolerance >= 1:
            return "correct"
        else:
            return error_code
    elif type == "extended":
        if (
            error_code == "compile_fail" or error_code == "runtime_fail"
        ) and tolerance >= 4:
            return "correct"
        elif error_code == "eager_fail" and tolerance >= 3:
            return "correct"
        elif (
            error_code == "shape_mismatch" or error_code == "type_mismatch"
        ) and tolerance >= 2:
            return "correct"
        elif error_code == "accuracy" and tolerance >= 1:
            return "correct"
        else:
            return error_code
    else:
        raise NotImplementedError
    if positive_tolerance_interpretation.is_error_tolerated(tolerance, error_code):
        return "correct"

    return error_code


def calculate_scores(
    samples: list,
    positive_tolerance_interpretation: PositiveToleranceInterpretation,
    p: float = 0,
    b: float = 0.1,
    type: str = "ESt",
@@ -339,7 +327,10 @@

    scores = {}

    for tolerance in range(-10, 5):
    strategy = positive_tolerance_interpretation
    tolerances = determine_tolerances(samples, positive_tolerance_interpretation)

    for tolerance in tolerances:
        rectified_speedups = []
        rectified_speedups_fake_degrad = []

Expand Down Expand Up @@ -373,12 +364,10 @@ def calculate_scores(
                )
            else:
                if not is_correct_at_t1[idx]:
                    current_correctness = fake_perf_degrad(
                    is_tolerated = strategy.is_error_tolerated(
                        tolerance, fail_type_at_t1[idx]
                    )
                    rec_speedup_fake_degrad = (
                        1 if current_correctness == "correct" else b
                    )
                    rec_speedup_fake_degrad = 1 if is_tolerated else b
                else:
                    rec_speedup_fake_degrad = (
                        speedup_at_t1[idx] ** (p + 1)
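For reference, a minimal sketch of how the refactored entry point behaves under the default strategy (the expected results follow from the tolerance mapping in default_positive_tolerance_interpretation.py below; the sketch assumes both modules are importable as shown):

    from graph_net.analysis_util import fake_perf_degrad
    from graph_net.default_positive_tolerance_interpretation import (
        DefaultPositiveToleranceInterpretation,
    )

    interp = DefaultPositiveToleranceInterpretation()
    # Accuracy errors are tolerated from t=1; runtime/compile errors from t=3.
    assert fake_perf_degrad(1, "accuracy", interp) == "correct"
    assert fake_perf_degrad(2, "runtime_fail", interp) == "runtime_fail"
    assert fake_perf_degrad(3, "compile_fail", interp) == "correct"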
77 changes: 77 additions & 0 deletions graph_net/default_positive_tolerance_interpretation.py
@@ -0,0 +1,77 @@
from enum import IntEnum

from graph_net.positive_tolerance_interpretation import PositiveToleranceInterpretation


class DefaultErrorEnum(IntEnum):
    """
    Error categories. The minimum tolerance level required to tolerate each
    category is given by get_tolerance_mapping() below.
    """

    kAccuracyViolation = 1  # Accuracy
    kRuntimeFailure = 2  # Includes Runtime, NaN, Inf, TypeMismatch, etc.
    kCompilationFailed = 3  # Compile Failure

    @classmethod
    def get_error_enum(cls, base_error_type: str) -> "DefaultErrorEnum":
        if not base_error_type:
            return cls.kRuntimeFailure

        etype = base_error_type.lower()

        if "accuracy" in etype:
            return cls.kAccuracyViolation

        if "compile_fail" in etype:
            return cls.kCompilationFailed

        return cls.kRuntimeFailure


class DefaultPositiveToleranceInterpretation(PositiveToleranceInterpretation):
    """
    Legacy interpretation:
    - t=1: Accuracy errors tolerated.
    - t=3: Runtime/Compilation errors tolerated.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def type_name(self) -> str:
        return "default"

    def get_errno(self, error_type: str) -> int:
        return DefaultErrorEnum.get_error_enum(error_type).value

    def get_error_type(self, errno: int) -> str:
        mapping = {1: "accuracy", 2: "runtime_fail", 3: "compile_fail"}
        return mapping.get(errno, "unknown_error")

    def get_tolerance_mapping(self) -> dict[int, int]:
        return {
            DefaultErrorEnum.kAccuracyViolation.value: 1,
            DefaultErrorEnum.kRuntimeFailure.value: 3,
            DefaultErrorEnum.kCompilationFailed.value: 3,
        }

    def is_error_tolerated(self, tolerance: int, base_error_code: str) -> bool:
        if base_error_code == "correct":
            return True
        if base_error_code in ["eager_fail", "reference_fail"]:
            return False

        error_enum = DefaultErrorEnum.get_error_enum(base_error_code)
        mapping = self.get_tolerance_mapping()
        required_threshold = mapping.get(error_enum.value, 999)

        return tolerance >= required_threshold

    def num_errno_enum_values(self) -> int:
        """
        Default mode defines 3 levels of errors:
        1: Accuracy
        2: Runtime (Generic)
        3: Compilation
        """
        return len(DefaultErrorEnum)
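A quick check of the thresholds this class encodes (a sketch; note that eager_fail and reference_fail are short-circuited to not-tolerated regardless of tolerance):

    interp = DefaultPositiveToleranceInterpretation()
    assert interp.is_error_tolerated(1, "accuracy")           # accuracy: t >= 1
    assert not interp.is_error_tolerated(2, "compile_fail")   # compile: t >= 3
    assert interp.is_error_tolerated(3, "runtime_fail")       # runtime: t >= 3
    assert not interp.is_error_tolerated(10, "eager_fail")    # never tolerated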
87 changes: 87 additions & 0 deletions graph_net/mismatch_extended_positive_tolerance_interpretation.py
@@ -0,0 +1,87 @@
from enum import IntEnum

from graph_net.positive_tolerance_interpretation import PositiveToleranceInterpretation


class MismatchExtendedErrorEnum(IntEnum):
    """
    Values correspond to the minimum tolerance level required.
    """

    kAccuracyViolation = 1
    kValueTypeOrMetaMismatch = 2
    kExecutionFailed = 3
    kCompilationFailed = 4

    @classmethod
    def get_error_enum(cls, base_error_type: str) -> "MismatchExtendedErrorEnum":
        if not base_error_type:
            return cls.kExecutionFailed

        etype = base_error_type.lower()
        if "accuracy" in etype:
            return cls.kAccuracyViolation
        if any(x in etype for x in ["nan", "inf", "type_mismatch", "shape_mismatch"]):
            return cls.kValueTypeOrMetaMismatch
        if "compile_fail" in etype:
            return cls.kCompilationFailed

        return cls.kExecutionFailed


class MismatchExtendedPositiveToleranceInterpretation(PositiveToleranceInterpretation):
    """
    Extended interpretation (ESt):
    - t=1: Accuracy
    - t=2: NaN/Inf/Type/Shape
    - t=3: Runtime
    - t=4: Compile
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def type_name(self) -> str:
        return "mismatch_extended"

    def get_errno(self, error_type: str) -> int:
        return MismatchExtendedErrorEnum.get_error_enum(error_type).value

    def get_error_type(self, errno: int) -> str:
        mapping = {
            1: "accuracy",
            2: "type/shape_mismatch",
            3: "runtime_fail",
            4: "compile_fail",
        }
        return mapping.get(errno, "unknown_error")

    def get_tolerance_mapping(self) -> dict[int, int]:
        return {
            MismatchExtendedErrorEnum.kAccuracyViolation.value: 1,
            MismatchExtendedErrorEnum.kValueTypeOrMetaMismatch.value: 2,
            MismatchExtendedErrorEnum.kExecutionFailed.value: 3,
            MismatchExtendedErrorEnum.kCompilationFailed.value: 4,
        }

    def is_error_tolerated(self, tolerance: int, base_error_code: str) -> bool:
        if base_error_code == "correct":
            return True
        if base_error_code in ["eager_fail", "reference_fail"]:
            return False

        error_enum = MismatchExtendedErrorEnum.get_error_enum(base_error_code)
        mapping = self.get_tolerance_mapping()
        required_threshold = mapping.get(error_enum.value, 999)

        return tolerance >= required_threshold

    def num_errno_enum_values(self) -> int:
        """
        Extended mode defines 4 levels of errors:
        1: Accuracy
        2: Type/Shape/NaN
        3: Runtime
        4: Compilation
        """
        return len(MismatchExtendedErrorEnum)
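And the extended ladder, exercised the same way (a sketch; the expected values follow directly from the mappings above):

    interp = MismatchExtendedPositiveToleranceInterpretation()
    assert interp.is_error_tolerated(2, "shape_mismatch")    # meta mismatch: t >= 2
    assert not interp.is_error_tolerated(3, "compile_fail")  # compile: t >= 4
    assert interp.get_errno("nan") == 2                      # NaN/Inf map to level 2
    assert interp.get_error_type(4) == "compile_fail"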
16 changes: 16 additions & 0 deletions graph_net/plot_ESt.py
@@ -4,6 +4,10 @@
import matplotlib.pyplot as plt
from graph_net import analysis_util
from graph_net import verify_aggregated_params
from graph_net.positive_tolerance_interpretation_manager import (
    get_supported_positive_tolerance_interpretation_types,
    get_positive_tolerance_interpretation,
)


class ESScoresWrapper:
@@ -262,6 +266,9 @@ def main(args):
    # 2. Calculate scores for each curve and verify aggregated/microscopic consistency
    all_es_scores = {}
    all_aggregated_results = {}
    positive_tolerance_interpretation = get_positive_tolerance_interpretation(
        args.interpretation_type
    )

    for folder_name, samples in all_results.items():
        print(f"\nCalculating ESt scores for '{folder_name}'...")
@@ -271,6 +278,7 @@
            p=args.negative_speedup_penalty,
            b=args.fpdb,
            type="ESt",
            positive_tolerance_interpretation=positive_tolerance_interpretation,
        )

        # Keep original behavior: assign es_scores directly
@@ -285,6 +293,7 @@
                folder_name,
                negative_speedup_penalty=args.negative_speedup_penalty,
                fpdb=args.fpdb,
                positive_tolerance_interpretation=positive_tolerance_interpretation,
            )
        )
        # Store aggregated results for plotting
@@ -429,6 +438,13 @@ def main(args):
action="store_false",
help="Disable aggregation mode verification.",
)
parser.add_argument(
"--positive-tolerance-interpretation",
dest="interpretation_type",
choices=get_supported_positive_tolerance_interpretation_types(),
default="default",
help="Select how positive tolerance values are interpreted into error types.",
)
parser.set_defaults(enable_aggregation_mode=True)
args = parser.parse_args()
main(args)
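Assuming the manager registers the two implementations under their type_name() values ("default" and "mismatch_extended"; the manager module itself is not part of this diff), the new flag would be used like:

    python -m graph_net.plot_ESt --positive-tolerance-interpretation mismatch_extended

Omitting the flag keeps the legacy behavior, since the default choice is "default".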
53 changes: 53 additions & 0 deletions graph_net/positive_tolerance_interpretation.py
@@ -0,0 +1,53 @@
from abc import ABC, abstractmethod


class PositiveToleranceInterpretation(ABC):
    """
    Abstract base class defining how positive tolerance values (t > 0)
    are interpreted and mapped to specific error types.
    """

    def __init__(self, *args, **kwargs):
        pass

    @abstractmethod
    def type_name(self) -> str:
        """Return the unique string identifier for this interpretation strategy."""
        raise NotImplementedError

    @abstractmethod
    def get_errno(self, error_type: str) -> int:
        """Map a raw error type string to an internal error number (errno)."""
        raise NotImplementedError

    @abstractmethod
    def get_error_type(self, errno: int) -> str:
        """Map an internal error number (errno) back to a representative string."""
        raise NotImplementedError

    @abstractmethod
    def get_tolerance_mapping(self) -> dict[int, int]:
        """
        Return the mapping from errno to the minimum tolerance level at which
        that error is tolerated. Used for statistical calculations (Gamma/Pi).
        """
        raise NotImplementedError

    @abstractmethod
    def is_error_tolerated(self, tolerance: int, base_error_code: str) -> bool:
        """
        Determine whether a specific error is considered 'correct' under the
        given tolerance. Replaces the old 'fake_perf_degrad' branching logic.
        """
        raise NotImplementedError

    @abstractmethod
    def num_errno_enum_values(self) -> int:
        """
        Return the number of defined error categories (or the maximum errno).

        Example:
        - Default: returns 3 (Accuracy, Runtime, Compile)
        - MismatchExtended: returns 4 (Accuracy, Data, Runtime, Compile)
        """
        raise NotImplementedError
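A new strategy plugs in by implementing the six abstract methods. A minimal hypothetical sketch (the class name and thresholds are invented for illustration; registration with the interpretation manager is not shown):

    from graph_net.positive_tolerance_interpretation import PositiveToleranceInterpretation


    class StrictPositiveToleranceInterpretation(PositiveToleranceInterpretation):
        """Hypothetical strategy: only accuracy errors are ever tolerated."""

        def type_name(self) -> str:
            return "strict"

        def get_errno(self, error_type: str) -> int:
            return 1 if "accuracy" in (error_type or "").lower() else 2

        def get_error_type(self, errno: int) -> str:
            return {1: "accuracy", 2: "other_fail"}.get(errno, "unknown_error")

        def get_tolerance_mapping(self) -> dict[int, int]:
            # Only errno 1 (accuracy) has a finite threshold.
            return {1: 1}

        def is_error_tolerated(self, tolerance: int, base_error_code: str) -> bool:
            if base_error_code == "correct":
                return True
            threshold = self.get_tolerance_mapping().get(
                self.get_errno(base_error_code), 999
            )
            return tolerance >= threshold

        def num_errno_enum_values(self) -> int:
            return 2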