diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index d4397f593570c..c49527ecb7dd6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -2386,7 +2386,7 @@ class AstBuilder extends DataTypeAstBuilder */ private def withSample(ctx: SampleContext, query: LogicalPlan): LogicalPlan = withOrigin(ctx) { // Create a sampled plan if we need one. - def sample(fraction: Double, seed: Long): Sample = { + def sample(fraction: Double, seed: Option[Long]): Sample = { // The range of fraction accepted by Sample is [0, 1]. Because Hive's block sampling // function takes X PERCENT as the input and the range of X is [0, 100], we need to // adjust the fraction. @@ -2401,11 +2401,7 @@ class AstBuilder extends DataTypeAstBuilder throw QueryParsingErrors.emptyInputForTableSampleError(ctx) } - val seed = if (ctx.seed != null) { - ctx.seed.getText.toLong - } else { - (math.random() * 1000).toLong - } + val seed: Option[Long] = Option(ctx.seed).map(_.getText.toLong) ctx.sampleMethod() match { case ctx: SampleByRowsContext => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/NormalizePlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/NormalizePlan.scala index f7740cf558c1f..9341dee19c742 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/NormalizePlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/NormalizePlan.scala @@ -141,7 +141,7 @@ object NormalizePlan extends PredicateHelper { child ) case sample: Sample => - sample.copy(seed = 0L) + sample.copy(seed = sample.seed.map(_ => 0L)) case Join(left, right, joinType, condition, hint) if condition.isDefined => val newJoinType = joinType match { case ExistenceJoin(a: Attribute) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index c18b7fcecc484..8e9f264698caf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -1912,6 +1912,22 @@ object SubqueryAlias { } } +object Sample { + /** + * Convenience constructor that wraps a concrete seed in [[Some]]. + * Use the case-class constructor directly with [[None]] when no seed + * was specified and a random seed should be generated at execution time. + */ + def apply( + lowerBound: Double, + upperBound: Double, + withReplacement: Boolean, + seed: Long, + child: LogicalPlan): Sample = { + new Sample(lowerBound, upperBound, withReplacement, Some(seed), child) + } +} + /** * Sample the dataset. * @@ -1919,14 +1935,16 @@ object SubqueryAlias { * @param upperBound Upper-bound of the sampling probability. The expected fraction sampled * will be ub - lb. * @param withReplacement Whether to sample with replacement. - * @param seed the random seed + * @param seed the random seed. `Some(seed)` when the user explicitly specified a seed + * (SQL `REPEATABLE` clause or programmatic API), `None` when no seed was + * specified and a random seed should be generated at execution time. * @param child the LogicalPlan */ case class Sample( lowerBound: Double, upperBound: Double, withReplacement: Boolean, - seed: Long, + seed: Option[Long], child: LogicalPlan) extends UnaryNode { val eps = RandomSampler.roundingEpsilon diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index b659e8273d124..edaa7aee5cabb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -831,9 +831,14 @@ class PlanParserSuite extends AnalysisTest { assertEqual(s"$sql tablesample(100 rows)", table("t").limit(100).select(star())) assertEqual(s"$sql tablesample(43 percent) as x", - Sample(0, .43d, withReplacement = false, 10L, table("t").as("x")).select(star())) + Sample(0, .43d, withReplacement = false, None, table("t").as("x")).select(star())) assertEqual(s"$sql tablesample(bucket 4 out of 10) as x", - Sample(0, .4d, withReplacement = false, 10L, table("t").as("x")).select(star())) + Sample(0, .4d, withReplacement = false, None, table("t").as("x")).select(star())) + // REPEATABLE clause produces Some(seed) + assertEqual(s"$sql tablesample(43 percent) repeatable (10) as x", + Sample(0, .43d, withReplacement = false, 10L, table("t").as("x")).select(star())) + assertEqual(s"$sql tablesample(bucket 4 out of 10) repeatable (99) as x", + Sample(0, .4d, withReplacement = false, 99L, table("t").as("x")).select(star())) val sql1 = s"$sql tablesample(bucket 4 out of 10 on x) as x" val fragment1 = "tablesample(bucket 4 out of 10 on x)" diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 37bcf995ee16d..9b6f09c983fc5 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -454,7 +454,7 @@ class SparkConnectPlanner( rel.getLowerBound, rel.getUpperBound, rel.getWithReplacement, - if (rel.hasSeed) rel.getSeed else Utils.random.nextLong, + if (rel.hasSeed) Some(rel.getSeed) else None, plan) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala index 683b66861a812..2070873f96579 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala @@ -1943,11 +1943,18 @@ class Dataset[T] private[sql]( override def sample(fraction: Double, seed: Long): Dataset[T] = super.sample(fraction, seed) /** @inheritdoc */ - override def sample(fraction: Double): Dataset[T] = super.sample(fraction) + override def sample(fraction: Double): Dataset[T] = { + withSameTypedPlan { + Sample(0.0, fraction, withReplacement = false, None, logicalPlan) + } + } /** @inheritdoc */ - override def sample(withReplacement: Boolean, fraction: Double): Dataset[T] = - super.sample(withReplacement, fraction) + override def sample(withReplacement: Boolean, fraction: Double): Dataset[T] = { + withSameTypedPlan { + Sample(0.0, fraction, withReplacement, None, logicalPlan) + } + } /** @inheritdoc */ override def dropDuplicates(colNames: Array[String]): Dataset[T] = super.dropDuplicates(colNames) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 5e73cdd6da8f4..aaafeccdfc6fa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -352,8 +352,11 @@ case class SampleExec( lowerBound: Double, upperBound: Double, withReplacement: Boolean, - seed: Long, + seed: Option[Long], child: SparkPlan) extends UnaryExecNode with CodegenSupport { + + val resolvedSeed: Long = seed.getOrElse((math.random() * 1000).toLong) + override def output: Seq[Attribute] = child.output override def outputPartitioning: Partitioning = child.outputPartitioning @@ -369,9 +372,9 @@ case class SampleExec( child.execute(), new PoissonSampler[InternalRow](upperBound - lowerBound, useGapSamplingIfPossible = false), preservesPartitioning = true, - seed) + resolvedSeed) } else { - child.execute().randomSampleWithRange(lowerBound, upperBound, seed) + child.execute().randomSampleWithRange(lowerBound, upperBound, resolvedSeed) } } @@ -405,7 +408,7 @@ case class SampleExec( s""" | private void $initSampler() { | $v = new $samplerClass($upperBound - $lowerBound, false); - | java.util.Random random = new java.util.Random(${seed}L); + | java.util.Random random = new java.util.Random(${resolvedSeed}L); | long randomSeed = random.nextLong(); | int loopCount = 0; | while (loopCount < partitionIndex) { @@ -431,7 +434,7 @@ case class SampleExec( val sampler = ctx.addMutableState(s"$samplerClass", "sampler", v => s""" | $v = new $samplerClass($lowerBound, $upperBound, false); - | $v.setSeed(${seed}L + partitionIndex); + | $v.setSeed(${resolvedSeed}L + partitionIndex); """.stripMargin.trim) s""" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index 4a4ccab47cad0..f7eaafa4b63b2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -818,7 +818,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { sample.lowerBound, sample.upperBound, sample.withReplacement, - sample.seed) + sample.seed.getOrElse((math.random() * 1000).toLong)) val pushed = PushDownUtils.pushTableSample(sHolder.builder, tableSample) if (pushed) { sHolder.pushedSample = Some(tableSample)