diff --git a/python/pyspark/testing/mlutils.py b/python/pyspark/testing/mlutils.py index 8c1c7c3b1242e..e26a4cc83ee52 100644 --- a/python/pyspark/testing/mlutils.py +++ b/python/pyspark/testing/mlutils.py @@ -25,8 +25,7 @@ from pyspark.ml.classification import Classifier, ClassificationModel from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable from pyspark.ml.wrapper import _java2py -from pyspark.sql import SparkSession -from pyspark.sql.classic.dataframe import DataFrame +from pyspark.sql import DataFrame, SparkSession from pyspark.sql.types import DoubleType from pyspark.testing.utils import ReusedPySparkTestCase as PySparkTestCase @@ -100,6 +99,11 @@ def tearDownClass(cls): class MockDataset(DataFrame): + def __new__(cls, *args, **kwargs): + # DataFrame by default creates classic DataFrame, we need this to + # overwrite the default behavior. + return object.__new__(cls) + def __init__(self): self.index = 0