diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/cumSum.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/cumSum.kt index b32409c692..cede3a6932 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/cumSum.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/cumSum.kt @@ -24,15 +24,19 @@ import kotlin.reflect.KProperty * from the first cell to the last cell. * * __NOTE:__ If the column contains nullable values and [skipNA\] is set to `true`, - * null and NaN values are skipped when computing the cumulative sum. - * When false, all values after the first NA will be NaN (for Double and Float columns) - * or null (for integer columns). + * `null` and `NaN` values are skipped when computing the cumulative sum. + * When `false`, all values after the first `NA` will be `NaN` (for [Double] and [Float] columns) + * or `null` (for other columns). * - * {@get [CumSumDocs.CUMSUM_PARAM] @param [columns\] - * The names of the columns to apply cumSum operation.} + * `cumSum` only works on columns that contain solely primitive numbers. * - * @param [skipNA\] Whether to skip null and NaN values (default: `true`). + * Similar to [sum][sum], [Byte][Byte]- and [Short][Short]-columns are converted to [Int][Int]. * + * {@get [CumSumDocs.CUMSUM_PARAM] @param [columns\] The selection of the columns to apply the `cumSum` operation to. + * If not provided, `cumSum` will be applied to all primitive columns [at any depth][ColumnsSelectionDsl.colsAtAnyDepth]. + * } + * + * @param [skipNA\] Whether to skip `null` and `NaN` values (default: `true`). * @return A new {@get [CumSumDocs.DATA_TYPE]} of the same type with the cumulative sums. * * {@get [CumSumDocs.CUMSUM_PARAM] @see [Selecting Columns][SelectSelectingOptions].} @@ -41,8 +45,11 @@ import kotlin.reflect.KProperty @ExcludeFromSources @Suppress("ClassName") private interface CumSumDocs { + + // Can be emptied to disable information about selecting columns interface CUMSUM_PARAM + // Either [DataColumn] or [DataFrame] interface DATA_TYPE } @@ -157,10 +164,11 @@ public fun DataFrame.cumSum( * {@set [CumSumDocs.DATA_TYPE] [DataFrame]} * {@set [CumSumDocs.CUMSUM_PARAM]} */ +@Refine +@Interpretable("DataFrameCumSum0") public fun DataFrame.cumSum(skipNA: Boolean = defaultCumSumSkipNA): DataFrame = cumSum(skipNA) { - // TODO keep at any depth? - colsAtAnyDepth().filter { it.isNumber() }.cast() + colsAtAnyDepth().filter { it.isPrimitiveOrMixedNumber() }.cast() } // endregion @@ -212,10 +220,11 @@ public fun GroupBy.cumSum( * {@set [CumSumDocs.DATA_TYPE] [GroupBy]} * {@set [CumSumDocs.CUMSUM_PARAM]} */ +@Refine +@Interpretable("GroupByCumSum0") public fun GroupBy.cumSum(skipNA: Boolean = defaultCumSumSkipNA): GroupBy = cumSum(skipNA) { - // TODO keep at any depth? - colsAtAnyDepth().filter { it.isNumber() }.cast() + colsAtAnyDepth().filter { it.isPrimitiveOrMixedNumber() }.cast() } // endregion diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/math/cumsum.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/math/cumsum.kt index 6f73220b66..2e7c6b505b 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/math/cumsum.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/math/cumsum.kt @@ -292,7 +292,7 @@ internal fun DataColumn.cumSumImpl(skipNA: Boolean): DataColumn { * T : Number(?) -> T(?) */ public val cumSumTypeConversion: CalculateReturnType = { type, _ -> - when (val type = type.withNullability(false)) { + when (type.withNullability(false)) { // type changes to Int, carrying nullability typeOf(), typeOf() -> typeOf().withNullability(type.isMarkedNullable) diff --git a/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/statistics/cumsum.kt b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/statistics/cumsum.kt index 05fc9438ce..04a44bbe10 100644 --- a/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/statistics/cumsum.kt +++ b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/statistics/cumsum.kt @@ -1,5 +1,6 @@ package org.jetbrains.kotlinx.dataframe.statistics +import io.kotest.assertions.throwables.shouldThrow import io.kotest.matchers.shouldBe import org.jetbrains.kotlinx.dataframe.DataColumn import org.jetbrains.kotlinx.dataframe.api.columnOf @@ -8,7 +9,10 @@ import org.jetbrains.kotlinx.dataframe.api.cumSum import org.jetbrains.kotlinx.dataframe.api.dataFrameOf import org.jetbrains.kotlinx.dataframe.api.groupBy import org.jetbrains.kotlinx.dataframe.api.map +import org.jetbrains.kotlinx.dataframe.impl.nullableNothingType +import org.jetbrains.kotlinx.dataframe.math.cumSumTypeConversion import org.junit.Test +import kotlin.reflect.typeOf @Suppress("ktlint:standard:argument-list-wrapping") class CumsumTests { @@ -92,4 +96,39 @@ class CumsumTests { "c", 4, ) } + + @Test + fun `df cumSum default`() { + val df = dataFrameOf( + "doubles" to columnOf(1.0, 2.0, null), + "shorts" to columnOf(1.toShort(), 2.toShort(), null), + "bigInts" to columnOf(1.toBigInteger(), 2.toBigInteger(), null), + "mixed" to columnOf(1.0, 2, null), + ) + + val res = df.cumSum() + + // works for Doubles, turns nulls into NaNs + res["doubles"].values() shouldBe columnOf(1.0, 3.0, Double.NaN).values() + // works for Shorts, turns into Ints, skips nulls + res["shorts"].values() shouldBe columnOf(1, 3, null).values() + // does not work for big numbers, keeps them as is + res["bigInts"].values() shouldBe columnOf(1.toBigInteger(), 2.toBigInteger(), null).values() + // works for mixed columns of primitives, number-unifies them; in this case to Doubles + res["mixed"].values() shouldBe columnOf(1.0, 3.0, Double.NaN).values() + } + + @Test + fun `cumSumTypeConversion tests`() { + cumSumTypeConversion(typeOf(), false) shouldBe typeOf() + cumSumTypeConversion(typeOf(), false) shouldBe typeOf() + cumSumTypeConversion(typeOf(), false) shouldBe typeOf() + cumSumTypeConversion(typeOf(), false) shouldBe typeOf() + cumSumTypeConversion(typeOf(), false) shouldBe typeOf() + cumSumTypeConversion(typeOf(), false) shouldBe typeOf() + cumSumTypeConversion(typeOf(), false) shouldBe typeOf() + cumSumTypeConversion(nullableNothingType, false) shouldBe nullableNothingType + + shouldThrow { cumSumTypeConversion(typeOf(), false) } + } }