Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion src/tether/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,15 @@ class ValidationResult:
num_elements: int
threshold: float
details: str
cosine_sim: float = float("nan")

def to_dict(self) -> dict:
return {
"passed": self.passed,
"max_abs_diff": round(self.max_abs_diff, 6),
"mean_abs_diff": round(self.mean_abs_diff, 6),
"max_rel_diff": round(self.max_rel_diff, 6),
"cosine_sim": round(self.cosine_sim, 6),
"num_elements": self.num_elements,
"threshold": self.threshold,
"details": self.details,
Expand Down Expand Up @@ -63,14 +65,25 @@ def validate_outputs(
max_abs = float(abs_diff.max())
mean_abs = float(abs_diff.mean())

reference_flat = reference.ravel().astype(np.float64, copy=False)
candidate_flat = candidate.ravel().astype(np.float64, copy=False)
reference_norm = float(np.linalg.norm(reference_flat))
candidate_norm = float(np.linalg.norm(candidate_flat))
if reference_norm == 0.0 or candidate_norm == 0.0:
cosine_sim = 0.0
else:
cosine_sim = float(
np.dot(reference_flat, candidate_flat) / (reference_norm * candidate_norm)
)

denom = np.maximum(np.abs(reference), 1e-8)
rel_diff = abs_diff / denom
max_rel = float(rel_diff.max())

passed = max_abs < threshold
details = (
f"{name}: max_abs={max_abs:.6f}, mean_abs={mean_abs:.6f}, "
f"max_rel={max_rel:.4f}, threshold={threshold}"
f"max_rel={max_rel:.4f}, cosine_sim={cosine_sim:.6f}, threshold={threshold}"
)
if passed:
logger.info("PASS %s", details)
Expand All @@ -85,6 +98,7 @@ def validate_outputs(
num_elements=int(reference.size),
threshold=threshold,
details=details,
cosine_sim=cosine_sim,
)


Expand Down
39 changes: 39 additions & 0 deletions tests/test_validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ def test_identical_outputs_pass(self):
result = validate_outputs(ref, ref, threshold=0.01)
assert result.passed
assert result.max_abs_diff == 0.0
assert result.cosine_sim == pytest.approx(1.0)

def test_small_diff_passes(self):
ref = np.array([1.0, 2.0, 3.0])
Expand All @@ -27,14 +28,52 @@ def test_large_diff_fails(self):
assert not result.passed
assert result.max_abs_diff == pytest.approx(0.5)

def test_opposite_direction_reports_negative_cosine(self):
ref = np.array([1.0, 0.0])
candidate = np.array([-1.0, 0.0])
result = validate_outputs(ref, candidate, threshold=0.01)
assert not result.passed
assert result.cosine_sim == pytest.approx(-1.0)

def test_orthogonal_vectors_report_zero_cosine(self):
ref = np.array([1.0, 0.0])
candidate = np.array([0.0, 1.0])
result = validate_outputs(ref, candidate, threshold=2.0)
assert result.cosine_sim == pytest.approx(0.0)

def test_identical_tiny_outputs_report_unit_cosine(self):
ref = np.array([1e-6, 0.0])
result = validate_outputs(ref, ref, threshold=0.01)
assert result.cosine_sim == pytest.approx(1.0)

def test_zero_norm_outputs_report_zero_cosine(self):
ref = np.array([0.0, 0.0])
result = validate_outputs(ref, ref, threshold=0.01)
assert result.passed
assert result.cosine_sim == 0.0

def test_shape_mismatch_fails(self):
ref = np.array([1.0, 2.0])
candidate = np.array([1.0, 2.0, 3.0])
result = validate_outputs(ref, candidate)
assert not result.passed
assert np.isnan(result.cosine_sim)

def test_torch_tensor_input(self):
ref = torch.tensor([1.0, 2.0, 3.0])
candidate = torch.tensor([1.001, 2.001, 3.001])
result = validate_outputs(ref, candidate, threshold=0.01)
assert result.passed

def test_to_dict_includes_rounded_cosine(self):
result = ValidationResult(
passed=True,
max_abs_diff=0.0,
mean_abs_diff=0.0,
max_rel_diff=0.0,
num_elements=3,
threshold=0.01,
details="output",
cosine_sim=0.123456789,
)
assert result.to_dict()["cosine_sim"] == 0.123457