-
Notifications
You must be signed in to change notification settings - Fork 63
/
row_comparer.py
45 lines (38 loc) · 1.47 KB
/
row_comparer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
from pyspark.sql import Row
from chispa.number_helpers import nan_safe_equality, nan_safe_approx_equality
import math
def are_rows_equal(r1: Row, r2: Row) -> bool:
    """Return the result of comparing the two rows with plain ``==``.

    No NaN or tolerance handling is applied; this is a straight equality
    check between ``r1`` and ``r2``.
    """
    rows_match = r1 == r2
    return rows_match
def are_rows_equal_enhanced(r1: Row, r2: Row, allow_nan_equality: bool) -> bool:
    """Compare two rows, optionally treating NaN values as equal.

    Two ``None`` rows are equal; a ``None`` row never equals a non-``None``
    row.

    When ``allow_nan_equality`` is false the rows are compared with plain
    ``==``. When it is true the rows are converted with ``asDict()`` and
    compared field by field using ``nan_safe_equality`` (from
    ``chispa.number_helpers``), which is expected to treat NaN as equal to
    NaN. In that mode, rows whose field names differ are unequal.

    Args:
        r1: First row (may be ``None``).
        r2: Second row (may be ``None``).
        allow_nan_equality: Whether NaN-vs-NaN should count as equal.

    Returns:
        True if the rows are considered equal, False otherwise.
    """
    if r1 is None and r2 is None:
        return True
    if r1 is None or r2 is None:
        return False

    if not allow_nan_equality:
        # asDict() is not needed here, so don't call it (it raises for
        # rows built positionally without field names).
        return r1 == r2

    d1 = r1.asDict()
    d2 = r2.asDict()
    # Previously only the shared keys were checked, so a row with extra or
    # missing fields could compare equal; require identical field names.
    if d1.keys() != d2.keys():
        return False
    return all(nan_safe_equality(value, d2[key]) for key, value in d1.items())
def are_rows_approx_equal(r1: Row, r2: Row, precision: float, allow_nan_equality=False) -> bool:
    """Compare two rows, allowing float fields to differ by up to ``precision``.

    Two ``None`` rows are equal; a ``None`` row never equals a non-``None``
    row. Rows are converted with ``asDict()``; rows whose field names differ
    are unequal.

    For each field:
      * If both values are floats, they match when they are identical
        (this covers equal infinities, whose difference is NaN) or when
        ``abs(v1 - v2) <= precision``. With ``allow_nan_equality`` the check
        is delegated to ``nan_safe_approx_equality`` (from
        ``chispa.number_helpers``), which is expected to also accept
        NaN vs NaN. Without it, any NaN makes the field unequal.
      * Otherwise the values must compare equal with ``==``.

    Args:
        r1: First row (may be ``None``).
        r2: Second row (may be ``None``).
        precision: Maximum allowed absolute difference between float fields.
        allow_nan_equality: Whether NaN-vs-NaN should count as equal.

    Returns:
        True if every field matches, False otherwise.
    """
    if r1 is None and r2 is None:
        return True
    if r1 is None or r2 is None:
        return False

    d1 = r1.asDict()
    d2 = r2.asDict()
    # Only comparing shared keys would let rows with extra/missing fields
    # pass; require identical field names.
    if d1.keys() != d2.keys():
        return False

    for key, v1 in d1.items():
        v2 = d2[key]
        if isinstance(v1, float) and isinstance(v2, float):
            if v1 == v2:
                # Exact match. Needed for equal infinities: inf - inf is NaN,
                # which the tolerance checks below would reject.
                continue
            if allow_nan_equality:
                if not nan_safe_approx_equality(v1, v2, precision):
                    return False
            elif math.isnan(v1) or math.isnan(v2) or abs(v1 - v2) > precision:
                return False
        elif v1 != v2:
            return False
    return True