From 50acecc3d12778f1f30ca636b6e83163f1fc775a Mon Sep 17 00:00:00 2001
From: Yogesh Garg
Date: Fri, 2 Mar 2018 16:00:40 -0800
Subject: [PATCH 1/3] add test case for JavaWrapper that displays memory leak for JavaWrapper but not JavaParams

---
Review note: the bare "except:" around summary.__del__() was narrowed to
AttributeError and explanatory comments were added. The post-image blob id
on the "index" line predates that change; regenerate it with
git format-patch before relying on it for a 3-way merge.

 python/pyspark/ml/tests.py | 42 ++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 42 insertions(+)

diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 116885969345c..6dee6938d8916 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -173,6 +173,48 @@ class MockModel(MockTransformer, Model, HasFake):
     pass
 
 
+class JavaWrapperMemoryTests(SparkSessionTestCase):
+
+    def test_java_object_gets_detached(self):
+        """Check that deleting a wrapper detaches its Java object from the gateway."""
+        df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)),
+                                         (0.0, 2.0, Vectors.sparse(1, [], []))],
+                                        ["label", "weight", "features"])
+        lr = LinearRegression(maxIter=1, regParam=0.0, solver="normal", weightCol="weight",
+                              fitIntercept=False)
+
+        model = lr.fit(df)
+        summary = model.summary
+
+        self.assertIsInstance(model, JavaWrapper)
+        self.assertIsInstance(summary, JavaWrapper)
+        self.assertIsInstance(model, JavaParams)
+        self.assertNotIsInstance(summary, JavaParams)
+
+        error_no_object = 'Target Object ID does not exist for this gateway'
+
+        self.assertIn("LinearRegression_", model._java_obj.toString())
+        self.assertIn("LinearRegressionTrainingSummary", summary._java_obj.toString())
+
+        model.__del__()
+
+        with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object):
+            model._java_obj.toString()
+        self.assertIn("LinearRegressionTrainingSummary", summary._java_obj.toString())
+
+        # A wrapper that is not a JavaParams may not define __del__ at all; ignore only
+        # that case so the assertions below still expose a Java object left attached.
+        try:
+            summary.__del__()
+        except AttributeError:
+            pass
+
+        with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object):
+            model._java_obj.toString()
+        with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object):
+            summary._java_obj.toString()
+
+
 class ParamTypeConversionTests(PySparkTestCase):
     """
     Test that param type conversion happens.

From d36c1a10cd318d9ddeb2717737248c974a2349f1 Mon Sep 17 00:00:00 2001
From: Yogesh Garg
Date: Fri, 2 Mar 2018 16:01:19 -0800
Subject: [PATCH 2/3] send the delete method from JavaParams to JavaWrapper

---
 python/pyspark/ml/wrapper.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index 0f846fbc5b5ef..34ab627fa2956 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -36,6 +36,10 @@ def __init__(self, java_obj=None):
         super(JavaWrapper, self).__init__()
         self._java_obj = java_obj
 
+    def __del__(self):
+        if SparkContext._active_spark_context and not self._java_obj is None:
+            SparkContext._active_spark_context._gateway.detach(self._java_obj)
+
     @classmethod
     def _create_from_java_class(cls, java_class, *args):
         """
@@ -100,10 +104,6 @@ class JavaParams(JavaWrapper, Params):
 
     __metaclass__ = ABCMeta
 
-    def __del__(self):
-        if SparkContext._active_spark_context:
-            SparkContext._active_spark_context._gateway.detach(self._java_obj)
-
     def _make_java_param_pair(self, param, value):
         """
         Makes a Java param pair.

From 07e18299c83d0874f7b5aaaa301d0eb80746ab01 Mon Sep 17 00:00:00 2001
From: Yogesh Garg
Date: Fri, 2 Mar 2018 18:41:26 -0800
Subject: [PATCH 3/3] fix style

---
 python/pyspark/ml/wrapper.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index 34ab627fa2956..5061f6434794a 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -37,7 +37,7 @@ def __init__(self, java_obj=None):
         self._java_obj = java_obj
 
     def __del__(self):
-        if SparkContext._active_spark_context and not self._java_obj is None:
+        if SparkContext._active_spark_context and self._java_obj is not None:
             SparkContext._active_spark_context._gateway.detach(self._java_obj)
 
     @classmethod