diff --git a/python/pyspark/tests/test_statcounter.py b/python/pyspark/tests/test_statcounter.py
index 9651871e113a8..b10fe7cd911c4 100644
--- a/python/pyspark/tests/test_statcounter.py
+++ b/python/pyspark/tests/test_statcounter.py
@@ -16,6 +16,7 @@
 #
 from pyspark.statcounter import StatCounter
 from pyspark.testing.utils import ReusedPySparkTestCase
+import math


 class StatCounterTests(ReusedPySparkTestCase):
@@ -76,6 +77,31 @@ def test_merge_stats(self):
         self.assertEqual(stats.sum(), 20.0)
         self.assertAlmostEqual(stats.variance(), 1.25)
         self.assertAlmostEqual(stats.sampleVariance(), 1.4285714285714286)
+        execution_statements = [
+            StatCounter([1.0, 2.0]).mergeStats(StatCounter(range(1, 301))),
+            StatCounter(range(1, 301)).mergeStats(StatCounter([1.0, 2.0])),
+        ]
+        for stats in execution_statements:
+            self.assertEqual(stats.count(), 302)
+            self.assertEqual(stats.max(), 300.0)
+            self.assertEqual(stats.min(), 1.0)
+            self.assertAlmostEqual(stats.mean(), 149.51324503311)
+            self.assertAlmostEqual(stats.variance(), 7596.302804701549)
+            self.assertAlmostEqual(stats.sampleVariance(), 7621.539691095905)
+
+    def test_variance_when_size_zero(self):
+        # SPARK-38854: Test case to improve test coverage when the
+        # StatCounter argument is an empty list or None
+        arguments = [[], None]
+
+        for arg in arguments:
+            stats = StatCounter(arg)
+            self.assertTrue(math.isnan(stats.variance()))
+            self.assertTrue(math.isnan(stats.sampleVariance()))
+            self.assertEqual(stats.count(), 0)
+            self.assertTrue(math.isinf(stats.max()))
+            self.assertTrue(math.isinf(stats.min()))
+            self.assertEqual(stats.mean(), 0.0)

     def test_merge_stats_with_self(self):
         stats = StatCounter([1.0, 2.0, 3.0, 4.0])
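To double-check the hard-coded expectations in `test_merge_stats` without a Spark session: merging the two StatCounters should yield the same statistics as computing them over the concatenated data, which the test asserts holds in both merge orders. A minimal stand-alone sketch (an assumption for verification only, not part of the patch; it uses just the standard library, where `statistics.pvariance` corresponds to `StatCounter.variance()` and `statistics.variance` to `sampleVariance()`):

```python
# Stand-alone sanity check for the expected values in test_merge_stats.
# Assumption: a merged StatCounter reports the same statistics as one
# built directly over the combined data, as the test itself asserts.
import statistics

data = [1.0, 2.0] + list(range(1, 301))  # the two merged datasets, combined

print(len(data))                   # 302
print(max(data), min(data))        # 300 1.0
print(sum(data) / len(data))       # 149.51324503311258 -> stats.mean()
print(statistics.pvariance(data))  # 7596.302804701549  -> stats.variance()
print(statistics.variance(data))   # 7621.539691095905  -> stats.sampleVariance()
```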