Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 19 additions & 10 deletions core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/cumSum.kt
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,19 @@ import kotlin.reflect.KProperty
* from the first cell to the last cell.
*
* __NOTE:__ If the column contains nullable values and [skipNA\] is set to `true`,
* null and NaN values are skipped when computing the cumulative sum.
* When false, all values after the first NA will be NaN (for Double and Float columns)
* or null (for integer columns).
* `null` and `NaN` values are skipped when computing the cumulative sum.
* When `false`, all values after the first `NA` will be `NaN` (for [Double] and [Float] columns)
* or `null` (for other columns).
*
* {@get [CumSumDocs.CUMSUM_PARAM] @param [columns\]
* The names of the columns to apply cumSum operation.}
* `cumSum` only works on columns that contain solely primitive numbers.
*
* @param [skipNA\] Whether to skip null and NaN values (default: `true`).
* Similar to [sum][sum], [Byte][Byte]- and [Short][Short]-columns are converted to [Int][Int].
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sum-sum! Bye-Bye-Bye!

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, unfortunately, there are KDoc rendering issues if you don't alias the types here

*
* {@get [CumSumDocs.CUMSUM_PARAM] @param [columns\] The selection of the columns to apply the `cumSum` operation to.
* If not provided, `cumSum` will be applied to all primitive columns [at any depth][ColumnsSelectionDsl.colsAtAnyDepth].
* }
*
* @param [skipNA\] Whether to skip `null` and `NaN` values (default: `true`).
* @return A new {@get [CumSumDocs.DATA_TYPE]} of the same type with the cumulative sums.
*
* {@get [CumSumDocs.CUMSUM_PARAM] @see [Selecting Columns][SelectSelectingOptions].}
Expand All @@ -41,8 +45,11 @@ import kotlin.reflect.KProperty
// Holder for the KDoc template keys used by the `cumSum` documentation in this file.
// The nested interfaces carry no members; they are referenced from
// `{@get [CumSumDocs.X]}` / `{@set [CumSumDocs.X] ...}` tags in the KDoc.
@ExcludeFromSources
@Suppress("ClassName")
private interface CumSumDocs {

// Key for the column-selection documentation (the `@param [columns]` and `@see` text).
// Can be emptied to disable information about selecting columns
interface CUMSUM_PARAM

// Key for the type named in the `@return` text,
// e.g. [DataColumn], [DataFrame] or [GroupBy]
interface DATA_TYPE
}

Expand Down Expand Up @@ -157,10 +164,11 @@ public fun <T> DataFrame<T>.cumSum(
* {@set [CumSumDocs.DATA_TYPE] [DataFrame]}
* {@set [CumSumDocs.CUMSUM_PARAM]}
*/
@Refine
@Interpretable("DataFrameCumSum0")
public fun <T> DataFrame<T>.cumSum(skipNA: Boolean = defaultCumSumSkipNA): DataFrame<T> =
    cumSum(skipNA) {
        // TODO keep at any depth?
        // Default selection: only primitive (or mixed-primitive) number columns.
        // Columns of other number types, such as BigInteger, are not selected and stay as they are.
        colsAtAnyDepth().filter { it.isPrimitiveOrMixedNumber() }.cast()
    }

// endregion
Expand Down Expand Up @@ -212,10 +220,11 @@ public fun <T, G> GroupBy<T, G>.cumSum(
* {@set [CumSumDocs.DATA_TYPE] [GroupBy]}
* {@set [CumSumDocs.CUMSUM_PARAM]}
*/
@Refine
@Interpretable("GroupByCumSum0")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we are marking the functions for Compiler Plugin, should we add a test for this trivial case?
@koperagen

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

testing what? and how?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, usually I simply annotate the functions and adjust later if anything is wrong (one time I forgot the Refine annotation)

public fun <T, G> GroupBy<T, G>.cumSum(skipNA: Boolean = defaultCumSumSkipNA): GroupBy<T, G> =
    cumSum(skipNA) {
        // TODO keep at any depth?
        // Default selection: only primitive (or mixed-primitive) number columns.
        // Columns of other number types, such as BigInteger, are not selected and stay as they are.
        colsAtAnyDepth().filter { it.isPrimitiveOrMixedNumber() }.cast()
    }

// endregion
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ internal fun DataColumn<Long?>.cumSumImpl(skipNA: Boolean): DataColumn<Long?> {
* T : Number(?) -> T(?)
*/
public val cumSumTypeConversion: CalculateReturnType = { type, _ ->
when (val type = type.withNullability(false)) {
when (type.withNullability(false)) {
// type changes to Int, carrying nullability
typeOf<Short>(), typeOf<Byte>() -> typeOf<Int>().withNullability(type.isMarkedNullable)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package org.jetbrains.kotlinx.dataframe.statistics

import io.kotest.assertions.throwables.shouldThrow
import io.kotest.matchers.shouldBe
import org.jetbrains.kotlinx.dataframe.DataColumn
import org.jetbrains.kotlinx.dataframe.api.columnOf
Expand All @@ -8,7 +9,10 @@ import org.jetbrains.kotlinx.dataframe.api.cumSum
import org.jetbrains.kotlinx.dataframe.api.dataFrameOf
import org.jetbrains.kotlinx.dataframe.api.groupBy
import org.jetbrains.kotlinx.dataframe.api.map
import org.jetbrains.kotlinx.dataframe.impl.nullableNothingType
import org.jetbrains.kotlinx.dataframe.math.cumSumTypeConversion
import org.junit.Test
import kotlin.reflect.typeOf

@Suppress("ktlint:standard:argument-list-wrapping")
class CumsumTests {
Expand Down Expand Up @@ -92,4 +96,39 @@ class CumsumTests {
"c", 4,
)
}

@Test
fun `df cumSum default`() {
    // One column per number flavour, each ending in a null.
    val input = dataFrameOf(
        "doubles" to columnOf(1.0, 2.0, null),
        "shorts" to columnOf(1.toShort(), 2.toShort(), null),
        "bigInts" to columnOf(1.toBigInteger(), 2.toBigInteger(), null),
        "mixed" to columnOf<Number?>(1.0, 2, null),
    )

    // No column selector and no explicit skipNA: exercises the default overload.
    val summed = input.cumSum()

    // Doubles are summed; the trailing null comes back as NaN.
    val expectedDoubles = columnOf(1.0, 3.0, Double.NaN).values()
    summed["doubles"].values() shouldBe expectedDoubles

    // Shorts are summed as Ints; the trailing null stays null.
    val expectedShorts = columnOf(1, 3, null).values()
    summed["shorts"].values() shouldBe expectedShorts

    // Big numbers are not supported by cumSum, so the column is returned unchanged.
    val expectedBigInts = columnOf(1.toBigInteger(), 2.toBigInteger(), null).values()
    summed["bigInts"].values() shouldBe expectedBigInts

    // Mixed primitive numbers are unified to a single type first; here to Doubles.
    val expectedMixed = columnOf(1.0, 3.0, Double.NaN).values()
    summed["mixed"].values() shouldBe expectedMixed
}

@Test
fun `cumSumTypeConversion tests`() {
    // Input type -> expected result type of cumSumTypeConversion, checked in order.
    val expectedByInput = listOf(
        typeOf<Int>() to typeOf<Int>(),
        typeOf<Long?>() to typeOf<Long?>(),
        // Short and Byte widen to Int, keeping their nullability
        typeOf<Short?>() to typeOf<Int?>(),
        typeOf<Byte>() to typeOf<Int>(),
        // Float and Double lose their nullability
        typeOf<Float?>() to typeOf<Float>(),
        typeOf<Double?>() to typeOf<Double>(),
        typeOf<Double>() to typeOf<Double>(),
        nullableNothingType to nullableNothingType,
    )

    for ((inputType, expectedType) in expectedByInput) {
        cumSumTypeConversion(inputType, false) shouldBe expectedType
    }

    // Non-number types are rejected.
    shouldThrow<IllegalStateException> { cumSumTypeConversion(typeOf<String>(), false) }
}
}