Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2386,7 +2386,7 @@ class AstBuilder extends DataTypeAstBuilder
*/
private def withSample(ctx: SampleContext, query: LogicalPlan): LogicalPlan = withOrigin(ctx) {
// Create a sampled plan if we need one.
def sample(fraction: Double, seed: Long): Sample = {
def sample(fraction: Double, seed: Option[Long]): Sample = {
// The range of fraction accepted by Sample is [0, 1]. Because Hive's block sampling
// function takes X PERCENT as the input and the range of X is [0, 100], we need to
// adjust the fraction.
Expand All @@ -2401,11 +2401,7 @@ class AstBuilder extends DataTypeAstBuilder
throw QueryParsingErrors.emptyInputForTableSampleError(ctx)
}

val seed = if (ctx.seed != null) {
ctx.seed.getText.toLong
} else {
(math.random() * 1000).toLong
}
val seed: Option[Long] = Option(ctx.seed).map(_.getText.toLong)

ctx.sampleMethod() match {
case ctx: SampleByRowsContext =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ object NormalizePlan extends PredicateHelper {
child
)
case sample: Sample =>
sample.copy(seed = 0L)
sample.copy(seed = sample.seed.map(_ => 0L))
case Join(left, right, joinType, condition, hint) if condition.isDefined =>
val newJoinType = joinType match {
case ExistenceJoin(a: Attribute) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1912,21 +1912,39 @@ object SubqueryAlias {
}
}

object Sample {
  /**
   * Builds a [[Sample]] node from a seed value that is already known.
   *
   * The supplied seed is stored as `Some(seed)`. To express "no seed was specified",
   * bypass this overload and call the case-class constructor with `None`.
   */
  def apply(
      lowerBound: Double,
      upperBound: Double,
      withReplacement: Boolean,
      seed: Long,
      child: LogicalPlan): Sample = {
    // `new` is required here so the call targets the constructor and not an `apply` overload.
    val explicitSeed: Option[Long] = Some(seed)
    new Sample(lowerBound, upperBound, withReplacement, explicitSeed, child)
  }
}

/**
* Sample the dataset.
*
* @param lowerBound Lower-bound of the sampling probability (usually 0.0)
* @param upperBound Upper-bound of the sampling probability. The expected fraction sampled
* will be ub - lb.
* @param withReplacement Whether to sample with replacement.
* @param seed the random seed
* @param seed the random seed. `Some(seed)` when the user explicitly specified a seed
* (SQL `REPEATABLE` clause or programmatic API), `None` when no seed was
* specified and a random seed should be generated at execution time.
* @param child the LogicalPlan
*/
case class Sample(
lowerBound: Double,
upperBound: Double,
withReplacement: Boolean,
seed: Long,
seed: Option[Long],
child: LogicalPlan) extends UnaryNode {

val eps = RandomSampler.roundingEpsilon
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -831,9 +831,14 @@ class PlanParserSuite extends AnalysisTest {
assertEqual(s"$sql tablesample(100 rows)",
table("t").limit(100).select(star()))
assertEqual(s"$sql tablesample(43 percent) as x",
Sample(0, .43d, withReplacement = false, 10L, table("t").as("x")).select(star()))
Sample(0, .43d, withReplacement = false, None, table("t").as("x")).select(star()))
assertEqual(s"$sql tablesample(bucket 4 out of 10) as x",
Sample(0, .4d, withReplacement = false, 10L, table("t").as("x")).select(star()))
Sample(0, .4d, withReplacement = false, None, table("t").as("x")).select(star()))
// REPEATABLE clause produces Some(seed)
assertEqual(s"$sql tablesample(43 percent) repeatable (10) as x",
Sample(0, .43d, withReplacement = false, 10L, table("t").as("x")).select(star()))
assertEqual(s"$sql tablesample(bucket 4 out of 10) repeatable (99) as x",
Sample(0, .4d, withReplacement = false, 99L, table("t").as("x")).select(star()))

val sql1 = s"$sql tablesample(bucket 4 out of 10 on x) as x"
val fragment1 = "tablesample(bucket 4 out of 10 on x)"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,7 @@ class SparkConnectPlanner(
rel.getLowerBound,
rel.getUpperBound,
rel.getWithReplacement,
if (rel.hasSeed) rel.getSeed else Utils.random.nextLong,
if (rel.hasSeed) Some(rel.getSeed) else None,
plan)
}

Expand Down
13 changes: 10 additions & 3 deletions sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1943,11 +1943,18 @@ class Dataset[T] private[sql](
override def sample(fraction: Double, seed: Long): Dataset[T] = super.sample(fraction, seed)

/** @inheritdoc */
override def sample(fraction: Double): Dataset[T] = super.sample(fraction)
override def sample(fraction: Double): Dataset[T] =
  // Sampling without replacement over [0.0, fraction]; the caller gave no seed, so the
  // plan node is built with `None` for its seed.
  withSameTypedPlan(Sample(0.0, fraction, withReplacement = false, None, logicalPlan))

/** @inheritdoc */
override def sample(withReplacement: Boolean, fraction: Double): Dataset[T] =
super.sample(withReplacement, fraction)
override def sample(withReplacement: Boolean, fraction: Double): Dataset[T] =
  // Bounds are [0.0, fraction]; the caller gave no seed, so the plan node is built with
  // `None` for its seed.
  withSameTypedPlan(Sample(0.0, fraction, withReplacement, None, logicalPlan))

/** @inheritdoc */
override def dropDuplicates(colNames: Array[String]): Dataset[T] = super.dropDuplicates(colNames)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -352,8 +352,11 @@ case class SampleExec(
lowerBound: Double,
upperBound: Double,
withReplacement: Boolean,
seed: Long,
seed: Option[Long],
child: SparkPlan) extends UnaryExecNode with CodegenSupport {

val resolvedSeed: Long = seed.getOrElse((math.random() * 1000).toLong)

override def output: Seq[Attribute] = child.output

override def outputPartitioning: Partitioning = child.outputPartitioning
Expand All @@ -369,9 +372,9 @@ case class SampleExec(
child.execute(),
new PoissonSampler[InternalRow](upperBound - lowerBound, useGapSamplingIfPossible = false),
preservesPartitioning = true,
seed)
resolvedSeed)
} else {
child.execute().randomSampleWithRange(lowerBound, upperBound, seed)
child.execute().randomSampleWithRange(lowerBound, upperBound, resolvedSeed)
}
}

Expand Down Expand Up @@ -405,7 +408,7 @@ case class SampleExec(
s"""
| private void $initSampler() {
| $v = new $samplerClass<UnsafeRow>($upperBound - $lowerBound, false);
| java.util.Random random = new java.util.Random(${seed}L);
| java.util.Random random = new java.util.Random(${resolvedSeed}L);
| long randomSeed = random.nextLong();
| int loopCount = 0;
| while (loopCount < partitionIndex) {
Expand All @@ -431,7 +434,7 @@ case class SampleExec(
val sampler = ctx.addMutableState(s"$samplerClass<UnsafeRow>", "sampler",
v => s"""
| $v = new $samplerClass<UnsafeRow>($lowerBound, $upperBound, false);
| $v.setSeed(${seed}L + partitionIndex);
| $v.setSeed(${resolvedSeed}L + partitionIndex);
""".stripMargin.trim)

s"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -818,7 +818,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
sample.lowerBound,
sample.upperBound,
sample.withReplacement,
sample.seed)
sample.seed.getOrElse((math.random() * 1000).toLong))
val pushed = PushDownUtils.pushTableSample(sHolder.builder, tableSample)
if (pushed) {
sHolder.pushedSample = Some(tableSample)
Expand Down