Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEAT] [New Query Planner] Add support for `df.count_rows(). #1273

Merged
merged 1 commit into from
Aug 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import pyarrow as pa

from daft import context
from daft.daft import ImageFormat
from daft.daft import CountMode, ImageFormat
from daft.daft import PyExpr as _PyExpr
from daft.daft import col as _col
from daft.daft import lit as _lit
Expand Down Expand Up @@ -275,8 +275,8 @@ def cast(self, dtype: DataType) -> Expression:
expr = self._expr.cast(dtype._dtype)
return Expression._from_pyexpr(expr)

def _count(self) -> Expression:
expr = self._expr.count()
def _count(self, mode: CountMode = CountMode.Valid) -> Expression:
expr = self._expr.count(mode)
return Expression._from_pyexpr(expr)

def _sum(self) -> Expression:
Expand Down
7 changes: 5 additions & 2 deletions daft/logical/rust_logical_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from daft import DataType, col
from daft.context import get_context
from daft.daft import FileFormat, FileFormatConfig, JoinType
from daft.daft import CountMode, FileFormat, FileFormatConfig, JoinType
from daft.daft import LogicalPlanBuilder as _LogicalPlanBuilder
from daft.daft import PartitionScheme, PartitionSpec, ResourceRequest
from daft.errors import ExpressionTypeError
Expand Down Expand Up @@ -130,7 +130,10 @@
return RustLogicalPlanBuilder(builder)

def count(self) -> RustLogicalPlanBuilder:
raise NotImplementedError("not implemented")
# TODO(Clark): Add dedicated logical/physical ops when introducing metadata-based count optimizations.
first_col = col(self.schema().column_names()[0])
builder = self._builder.aggregate([first_col._count(CountMode.All)], [])
return RustLogicalPlanBuilder(builder)

Check warning on line 136 in daft/logical/rust_logical_plan.py

View check run for this annotation

Codecov / codecov/patch

daft/logical/rust_logical_plan.py#L134-L136

Added lines #L134 - L136 were not covered by tests

def distinct(self) -> RustLogicalPlanBuilder:
builder = self._builder.distinct()
Expand Down
6 changes: 3 additions & 3 deletions daft/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pyarrow as pa

from daft.arrow_utils import ensure_array, ensure_chunked_array
from daft.daft import ImageFormat, PySeries
from daft.daft import CountMode, ImageFormat, PySeries
from daft.datatype import DataType
from daft.utils import pyarrow_supports_fixed_shape_tensor

Expand Down Expand Up @@ -434,9 +434,9 @@ def __xor__(self, other: object) -> Series:
assert self._series is not None and other._series is not None
return Series._from_pyseries(self._series ^ other._series)

def _count(self) -> Series:
def _count(self, mode: CountMode = CountMode.Valid) -> Series:
assert self._series is not None
return Series._from_pyseries(self._series._count())
return Series._from_pyseries(self._series._count(mode))

def _min(self) -> Series:
assert self._series is not None
Expand Down
42 changes: 29 additions & 13 deletions src/daft-core/src/array/ops/count.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::sync::Arc;

use arrow2;

use crate::{array::DataArray, datatypes::*};
use crate::{array::DataArray, count_mode::CountMode, datatypes::*};
use common_error::DaftResult;

use super::{DaftCountAggable, GroupIndices};
Expand All @@ -13,31 +13,47 @@ where
{
type Output = DaftResult<DataArray<UInt64Type>>;

fn count(&self) -> Self::Output {
fn count(&self, mode: CountMode) -> Self::Output {
let arrow_array = &self.data;
let count = arrow_array.len() - arrow_array.null_count();
let count = match mode {
CountMode::All => arrow_array.len(),
CountMode::Valid => arrow_array.len() - arrow_array.null_count(),
CountMode::Null => arrow_array.null_count(),
};
let result_arrow_array =
Box::new(arrow2::array::PrimitiveArray::from([Some(count as u64)]));
DataArray::<UInt64Type>::new(
Arc::new(Field::new(self.field.name.clone(), DataType::UInt64)),
result_arrow_array,
)
}
fn grouped_count(&self, groups: &GroupIndices) -> Self::Output {
fn grouped_count(&self, groups: &GroupIndices, mode: CountMode) -> Self::Output {
let arrow_array = self.data.as_ref();

let counts_per_group: Vec<_> = if arrow_array.null_count() > 0 {
groups
let counts_per_group: Vec<_> = match mode {
CountMode::All => groups.iter().map(|g| g.len() as u64).collect(),
CountMode::Valid => {
if arrow_array.null_count() > 0 {
groups
.iter()
.map(|g| {
let null_count = g
.iter()
.fold(0u64, |acc, v| acc + arrow_array.is_null(*v as usize) as u64);
(g.len() as u64) - null_count
})
.collect()
} else {
groups.iter().map(|g| g.len() as u64).collect()
}
}
CountMode::Null => groups
.iter()
.map(|g| {
let null_count = g
.iter()
.fold(0u64, |acc, v| acc + arrow_array.is_null(*v as usize) as u64);
(g.len() as u64) - null_count
g.iter()
.fold(0u64, |acc, v| acc + arrow_array.is_null(*v as usize) as u64)
})
.collect()
} else {
groups.iter().map(|g| g.len() as u64).collect()
.collect(),
};

Ok(DataArray::<UInt64Type>::from((
Expand Down
7 changes: 5 additions & 2 deletions src/daft-core/src/array/ops/mean.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::sync::Arc;

use arrow2;

use crate::count_mode::CountMode;
use crate::{array::DataArray, datatypes::*};

use common_error::DaftResult;
Expand All @@ -16,7 +17,9 @@ impl DaftMeanAggable for &DataArray<Float64Type> {

fn mean(&self) -> Self::Output {
let sum_value = DaftSumAggable::sum(self)?.as_arrow().value(0);
let count_value = DaftCountAggable::count(self)?.as_arrow().value(0);
let count_value = DaftCountAggable::count(self, CountMode::Valid)?
.as_arrow()
.value(0);

let result = match count_value {
0 => None,
Expand All @@ -33,7 +36,7 @@ impl DaftMeanAggable for &DataArray<Float64Type> {
fn grouped_mean(&self, groups: &GroupIndices) -> Self::Output {
use arrow2::array::PrimitiveArray;
let sum_values = self.grouped_sum(groups)?;
let count_values = self.grouped_count(groups)?;
let count_values = self.grouped_count(groups, CountMode::Valid)?;
assert_eq!(sum_values.len(), count_values.len());
let mean_per_group = sum_values
.as_arrow()
Expand Down
6 changes: 4 additions & 2 deletions src/daft-core/src/array/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ pub use sort::{build_multi_array_bicompare, build_multi_array_compare};

use common_error::DaftResult;

use crate::count_mode::CountMode;

pub trait DaftCompare<Rhs> {
type Output;

Expand Down Expand Up @@ -93,8 +95,8 @@ pub trait IntoGroups {

pub trait DaftCountAggable {
type Output;
fn count(&self) -> Self::Output;
fn grouped_count(&self, groups: &GroupIndices) -> Self::Output;
fn count(&self, mode: CountMode) -> Self::Output;
fn grouped_count(&self, groups: &GroupIndices, mode: CountMode) -> Self::Output;
}

pub trait DaftSumAggable {
Expand Down
94 changes: 94 additions & 0 deletions src/daft-core/src/count_mode.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
#[cfg(feature = "python")]
use pyo3::{
exceptions::PyValueError,
prelude::*,
types::{PyBytes, PyTuple},
};
use serde::{Deserialize, Serialize};
use std::fmt::{Display, Formatter, Result};
use std::str::FromStr;
use std::string::ToString;

use crate::impl_bincode_py_state_serialization;

use common_error::{DaftError, DaftResult};

/// Supported count modes for Daft's count aggregation.
///
/// | All - Count both non-null and null values.
/// | Valid - Count only valid values.
/// | Null - Count only null values.
#[allow(clippy::upper_case_acronyms)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize, Hash)]
#[cfg_attr(feature = "python", pyclass)]
pub enum CountMode {
All = 1,
Valid = 2,
Null = 3,
}

#[cfg(feature = "python")]
#[pymethods]
impl CountMode {
#[new]
#[pyo3(signature = (*args))]
pub fn new(args: &PyTuple) -> PyResult<Self> {
match args.len() {
// Create dummy variant, to be overridden by __setstate__.
0 => Ok(Self::All),
_ => Err(PyValueError::new_err(format!(
"expected no arguments to make new JoinType, got : {}",
args.len()
))),
}
}

/// Create a CountMode from its string representation.
///
/// Args:
/// count_mode: String representation of the count mode , e.g. "all", "valid", or "null".
#[staticmethod]
pub fn from_count_mode_str(count_mode: &str) -> PyResult<Self> {
Self::from_str(count_mode).map_err(|e| PyValueError::new_err(e.to_string()))
}
pub fn __str__(&self) -> PyResult<String> {
Ok(self.to_string())
}
}

impl_bincode_py_state_serialization!(CountMode);

impl CountMode {
pub fn iterator() -> std::slice::Iter<'static, CountMode> {
use CountMode::*;

static COUNT_MODES: [CountMode; 3] = [All, Valid, Null];
COUNT_MODES.iter()
}
}

impl FromStr for CountMode {
type Err = DaftError;

fn from_str(count_mode: &str) -> DaftResult<Self> {
use CountMode::*;

match count_mode {
"all" => Ok(All),
"valid" => Ok(Valid),
"null" => Ok(Null),
_ => Err(DaftError::TypeError(format!(
"Count mode {} is not supported; only the following modes are supported: {:?}",
count_mode,
CountMode::iterator().as_slice()
))),
}
}
}

impl Display for CountMode {
fn fmt(&self, f: &mut Formatter) -> Result {
// Leverage Debug trait implementation, which will already return the enum variant as a string.
write!(f, "{:?}", self)
}
}
11 changes: 11 additions & 0 deletions src/daft-core/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#![feature(let_chains)]

pub mod array;
pub mod count_mode;
pub mod datatypes;
#[cfg(feature = "python")]
pub mod ffi;
Expand All @@ -10,7 +11,10 @@ pub mod python;
pub mod schema;
pub mod series;
pub mod utils;
#[cfg(feature = "python")]
use pyo3::prelude::*;

pub use count_mode::CountMode;
pub use datatypes::DataType;
pub use series::{IntoSeries, Series};

Expand All @@ -23,3 +27,10 @@ pub const DAFT_BUILD_TYPE: &str = {
None => BUILD_TYPE_DEV,
}
};

#[cfg(feature = "python")]
pub fn register_modules(_py: Python, parent: &PyModule) -> PyResult<()> {
parent.add_class::<CountMode>()?;

Ok(())
}
5 changes: 3 additions & 2 deletions src/daft-core/src/python/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use pyo3::{exceptions::PyValueError, prelude::*, pyclass::CompareOp, types::PyLi

use crate::{
array::{ops::DaftLogical, pseudo_arrow::PseudoArrowArray, DataArray},
count_mode::CountMode,
datatypes::{DataType, Field, ImageFormat, ImageMode, PythonType, UInt64Type},
ffi,
series::{self, IntoSeries, Series},
Expand Down Expand Up @@ -176,8 +177,8 @@ impl PySeries {
Ok((&self.series).not()?.into())
}

pub fn _count(&self) -> PyResult<Self> {
Ok((self.series).count(None)?.into())
pub fn _count(&self, mode: CountMode) -> PyResult<Self> {
Ok((self.series).count(None, mode)?.into())
}

pub fn _sum(&self) -> PyResult<Self> {
Expand Down
7 changes: 4 additions & 3 deletions src/daft-core/src/series/ops/agg.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
use crate::count_mode::CountMode;
use crate::series::IntoSeries;
use crate::{array::ops::GroupIndices, series::Series, with_match_physical_daft_types};
use common_error::{DaftError, DaftResult};

use crate::datatypes::*;

impl Series {
pub fn count(&self, groups: Option<&GroupIndices>) -> DaftResult<Series> {
pub fn count(&self, groups: Option<&GroupIndices>, mode: CountMode) -> DaftResult<Series> {
use crate::array::ops::DaftCountAggable;
let s = self.as_physical()?;
with_match_physical_daft_types!(s.data_type(), |$T| {
match groups {
Some(groups) => Ok(DaftCountAggable::grouped_count(&s.downcast::<$T>()?, groups)?.into_series()),
None => Ok(DaftCountAggable::count(&s.downcast::<$T>()?)?.into_series())
Some(groups) => Ok(DaftCountAggable::grouped_count(&s.downcast::<$T>()?, groups, mode)?.into_series()),
None => Ok(DaftCountAggable::count(&s.downcast::<$T>()?, mode)?.into_series())
}
})
}
Expand Down
Loading
Loading