---
title: Utils for Polars
---

In [None]:
#| default_exp polars

In [None]:
#| export
import polars as pl
from typing import Any, Collection

## IO

In [None]:
#| export
def convert_to_pd_dataframe(
    df: pl.DataFrame | pl.LazyFrame, # Polars frame to convert (eager or lazy)
):
    """
    Return `df` as a pandas DataFrame built with PyArrow extension arrays.

    A `LazyFrame` is collected first. Anything that is neither a Polars
    `DataFrame` nor a `LazyFrame` raises `TypeError`.
    """
    if not isinstance(df, (pl.DataFrame, pl.LazyFrame)):
        raise TypeError("Input must be a Polars DataFrame or LazyFrame")

    # A lazy query has to be materialised before it can be handed to pandas.
    eager_frame = df.collect() if isinstance(df, pl.LazyFrame) else df
    return eager_frame.to_pandas(use_pyarrow_extension_array=True)


## Functions

In [None]:
#| export
def sort(df: pl.DataFrame, col="time"):
    """
    Return `df` ordered by `col`.

    When the column already reports itself as sorted, only the sorted flag is
    set through `set_sorted`; otherwise the frame is sorted by `col`.
    """
    already_sorted = df.get_column(col).is_sorted()
    return df.set_sorted(col) if already_sorted else df.sort(col)


In [None]:
#| export
def _expand_selectors(items: Any, *more_items: Any) -> list[Any]:
    """
    Flatten `items` and `more_items` into one list.

    Mirrors the argument handling of `_expand_selectors` in `polars`: a
    non-string collection passed as `items` is unpacked, anything else counts
    as a single element, and `more_items` follow in the order given.
    """
    treat_as_single = isinstance(items, str) or not isinstance(items, Collection)
    leading = [items] if treat_as_single else items
    return [*leading, *more_items]

In [None]:
#| export

def pl_norm(columns, *more_columns) -> pl.Expr:
    """
    Build an expression for the square root of the sum of the squared columns.

    Args:
    columns: A column name, or a collection of column names.
    *more_columns: Additional column names, given positionally.

    Returns:
    pl.Expr: Expression representing the square root of the sum of squares.
    """
    names = _expand_selectors(columns, *more_columns)
    sum_of_squares = sum(pl.col(name).pow(2) for name in names)
    return sum_of_squares.sqrt()

In [None]:
# | export
def decompose_vector(
    df: pl.DataFrame,
    vector_col,
    name=None,
    suffixes: list | tuple | None = ("_x", "_y", "_z"),
):
    """
    Decompose a vector column in a DataFrame into separate columns for each component with custom suffixes.

    Parameters:
    - df (pl.DataFrame): The input DataFrame.
    - vector_col (str): The name of the list column to decompose.
    - name (str, optional): Base name for the decomposed columns. If None, uses `vector_col` as the base name.
    - suffixes (list or tuple, optional): Suffixes to use for the decomposed columns.
      If None or not enough suffixes are provided, the missing ones default to '_0', '_1', etc.
      (numbered by component position). The argument is never modified.

    Returns:
    - pl.DataFrame: The input DataFrame with one extra column per vector component,
      named `name + suffix`. The original vector column is kept.
    """

    if name is None:
        name = vector_col

    # Determine the maximum length of vectors in the column to handle dynamic vector lengths
    max_length = df.select(pl.col(vector_col).list.len()).max()[0, 0]
    if max_length is None:
        # No non-null lengths (e.g. an empty frame): there is nothing to decompose.
        max_length = 0

    # Work on a private copy so that neither the default value nor a list owned
    # by the caller is extended as a side effect of this call.
    suffixes = [] if suffixes is None else list(suffixes)
    suffixes.extend(f"_{i}" for i in range(len(suffixes), max_length))

    # Create column expressions for each element in the vector.
    # The output name is set with a single explicit alias instead of chaining
    # `.alias(name).name.suffix(...)`, so the requested base name is always used.
    # NOTE(review): depending on the Polars version, `list.get` may raise for rows
    # shorter than `max_length` unless out-of-bounds access is allowed — confirm.
    column_expressions = [
        pl.col(vector_col).list.get(i).alias(f"{name}{suffix}")
        for i, suffix in enumerate(suffixes[:max_length])
    ]

    return df.with_columns(column_expressions)

## Fast filter for a list of predicates

[Use a list of filters within polars - Stack Overflow](https://stackoverflow.com/questions/74993391/use-a-list-of-filters-within-polars)

In [None]:
#| export
def filter_series_by_ranges_i(data: pl.Series, starts: list, stops: list):
    """
    Return the positions in `data` covered by any of the `(start, stop)` ranges.

    `starts` and `stops` are paired element-wise. Each pair is turned into a run
    of consecutive positions using `search_sorted`, with the stop looked up on
    the right side so that values equal to `stop` are included. Positions that
    appear in several ranges are returned once, in first-occurrence order.

    NOTE: `search_sorted` is a binary search, so `data` is expected to be sorted.

    Parameters:
    - data (pl.Series): The values to search.
    - starts (list): Lower bounds of the ranges.
    - stops (list): Upper bounds of the ranges.

    Returns:
    - pl.Series: Integer positions into `data`; empty when no ranges are given.
    """
    if len(starts) == 0 or len(stops) == 0:
        # Nothing to look up, and `pl.concat` cannot be called on zero Series.
        return pl.Series(values=[], dtype=pl.Int64)

    starts_index = data.search_sorted(starts)
    ends_index = data.search_sorted(stops, side="right")

    index_runs = [
        pl.arange(first, last, eager=True)
        for first, last in zip(starts_index, ends_index)
    ]
    # maintain_order keeps the positions in the order they were generated, so
    # rows selected with them do not come back shuffled.
    return pl.concat(index_runs).unique(maintain_order=True)


def filter_df_by_ranges(data: pl.DataFrame, starts: list, stops: list, col="time"):
    """
    Select the rows of `data` whose `col` value lies in any of the given ranges.

    Row positions are computed by `filter_series_by_ranges_i` from `data[col]`
    and the paired `starts` / `stops`, then used to index `data`.
    """
    key_column = data[col]
    row_positions = filter_series_by_ranges_i(key_column, starts, stops)
    return data[row_positions]

In [None]:
#| hide
def _filter_series_by_ranges_i(data: pl.Series, starts: list, stops: list):
    """
    Concatenate one `is_between` mask of `data` per `(start, stop)` pair.

    NOTE(review): the masks are concatenated rather than OR-ed together, so the
    result appears to hold one entry per row *per range* — confirm the intent
    before relying on this helper.
    """
    masks = (data.is_between(lower, upper) for lower, upper in zip(starts, stops))
    return pl.concat(masks)


def _filter_df_by_intervals(data: pl.DataFrame, starts: list, stops: list, col="time"):
    """
    Keep the rows of `data` whose `col` value falls inside at least one interval.

    Parameters:
    - data (pl.DataFrame): The DataFrame to be filtered.
    - starts (list): Start values of the intervals.
    - stops (list): Stop values of the intervals, paired element-wise with `starts`.
    - col (str): Name of the column tested against the intervals.

    Returns:
    - pl.DataFrame: The rows for which any `is_between(start, stop)` check holds.
    """
    column = pl.col(col)
    interval_checks = (
        column.is_between(lower, upper) for lower, upper in zip(starts, stops)
    )
    # A row is kept as soon as one of the per-interval checks is true.
    return data.filter(pl.any_horizontal(interval_checks))

In [None]:
def sample_data():
    """Return a small DataFrame whose `time` column holds the integers 0 through 9."""
    return pl.DataFrame({
        # eager=True makes `pl.arange` return a concrete Series; without it the
        # call produces an expression, which is not column data.
        "time": pl.arange(0, 10, eager=True),
    })

def test_filter_df_by_intervals(sample_data):
    """Filtering with the ranges 1-3 and 5-7 should keep six rows spanning 1 to 7."""
    starts, stops = [1, 5], [3, 7]
    kept = filter_df_by_ranges(sample_data, starts, stops)
    assert len(kept) == 6
    times = kept["time"]
    assert times.min() == 1
    assert times.max() == 7

def test_filter_df_by_intervals_no_match(sample_data):
    """Filtering with the ranges 100-300 and 200-400 should return no rows."""
    starts, stops = [100, 200], [300, 400]
    kept = filter_df_by_ranges(sample_data, starts, stops)
    assert len(kept) == 0

def test_filter_df_by_intervals_edge_case(sample_data):
    """Two identical single-value ranges (1 to 1) should yield exactly one row."""
    starts, stops = [1, 1], [1, 1]
    kept = filter_df_by_ranges(sample_data, starts, stops)
    assert len(kept) == 1

In [None]:
# Run every range-filter check against one shared sample frame.
_sample_data = sample_data()
for _check in (
    test_filter_df_by_intervals,
    test_filter_df_by_intervals_no_match,
    test_filter_df_by_intervals_edge_case,
):
    _check(_sample_data)

In [None]:
n = 1000000
data = pl.DataFrame({
    "time": np.arange(n),
})

starts = list(range(0, n-200, 100))
stops = list(range(100, n-100, 100))

%time filter_df_by_ranges(data, starts, stops)

CPU times: user 38.1 ms, sys: 8.89 ms, total: 47 ms
Wall time: 45.6 ms


time
i64
0
1
2
3
4
…
999796
999797
999798
999799
