Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve intersectional slicing #606

Merged
merged 1 commit into from
May 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 55 additions & 2 deletions cyclops/data/slicer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""Functions and classes for creating subsets of Hugging Face datasets."""

import copy
import datetime
import itertools
from dataclasses import dataclass, field
from functools import partial
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
Expand Down Expand Up @@ -60,10 +62,16 @@ class SliceSpec:
- `keep_nulls`: A boolean flag indicating whether to keep rows where the value
is null. If used in conjunction with `negate`, the slice selects rows where
the value is not null. Can be used on its own. Defaults to False.
validate : bool, default=True
Whether to validate the column names in the slice specifications.
intersections : List[Tuple[int]], int, optional, default=None
An indication of slices to intersect. If a list of tuples is provided, the
tuples should contain the indices of the slices to intersect. If an integer is
provided, it will be passed as the argument `r` in `itertools.combinations`,
and all combinations of `r` slices will be intersected. The intersections are
created _before_ the slices are registered.
include_overall : bool, default=True
Whether to include an `overall` slice that selects all examples.
validate : bool, default=True
Whether to validate the column names in the slice specifications.
column_names : List[str], optional, default=None
List of column names in the dataset. If provided and `validate` is True, it is
used to validate the column names in the slice specifications.
Expand Down Expand Up @@ -137,6 +145,22 @@ class SliceSpec:
feature_1:value_1&feature_2:[2020-01-01 - inf]&feature_3:year=['2000', '2010', '2020']
overall

# a different way to create intersections/compound slices
>>> slice_spec = SliceSpec(
... spec_list=[
... {"feature_1": {"keep_nulls": False}},
... {"feature_2": {"keep_nulls": False}},
... ],
... include_overall=False,
... intersections=2,
... )
>>> for slice_name, slice_func in slice_spec.slices():
... print(slice_name)
... # do something with slice_func here (e.g. dataset.filter(slice_func))
feature_1:non_null
feature_2:non_null
feature_1:non_null&feature_2:non_null

""" # noqa: W505

spec_list: List[Dict[str, Dict[str, Any]]] = field(
Expand All @@ -146,6 +170,7 @@ class SliceSpec:
hash=True,
compare=True,
)
intersections: Optional[Union[List[Tuple[int, ...]], int]] = None
validate: bool = True
include_overall: bool = True
column_names: Optional[List[str]] = None
Expand All @@ -160,6 +185,9 @@ class SliceSpec:

def __post_init__(self) -> None:
"""Create and register slice functions out of the slice specifications."""
self.spec_list = copy.deepcopy(self.spec_list)
if self.intersections is not None:
self._create_intersections()
for slice_spec in self.spec_list:
self._parse_and_register_slice_specs(slice_spec)

Expand Down Expand Up @@ -192,6 +220,31 @@ def slices(self) -> Generator[Tuple[str, Callable[..., Any]], None, None]:
for registration_key, slice_function in self._registry.items():
yield registration_key, slice_function

def _create_intersections(self) -> None:
"""Create intersections of slices."""
intersect_list = []
if isinstance(self.intersections, list) and isinstance(
self.intersections[0], tuple
):
for intersection in self.intersections:
intersect_dict = {}
for index in set(intersection):
intersect_dict.update(self.spec_list[index])
intersect_list.append(intersect_dict)
elif isinstance(self.intersections, int):
combinations = itertools.combinations(self.spec_list, self.intersections)
for combination in combinations:
intersect_dict = {}
for slice_ in combination:
intersect_dict.update(slice_)
intersect_list.append(intersect_dict)
else:
raise ValueError(
"Expected `intersections` to be a list of tuples or an integer. "
f"Got {self.intersections} instead.",
)
self.spec_list.extend(intersect_list)

def _parse_and_register_slice_specs(
self,
slice_spec: Dict[str, Dict[str, Any]],
Expand Down
29 changes: 29 additions & 0 deletions tests/cyclops/data/test_slicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,3 +526,32 @@ def test_slice_spec():
month,
),
)


def test_create_intersection():
"""Test the creation of slice intersections."""
spec_list = [
{"feature_1": {"value": "value_1"}},
{"feature_2": {"min_value": "2020-01-01", "keep_nulls": False}},
{"feature_3": {"year": ["2000", "2010", "2020"]}},
]

slice_spec = SliceSpec(spec_list)
assert slice_spec.spec_list == spec_list

intersect_list = [
{"feature_1": {"value": "value_1"}},
{"feature_2": {"min_value": "2020-01-01", "keep_nulls": False}},
{"feature_3": {"year": ["2000", "2010", "2020"]}},
{
"feature_1": {"value": "value_1"},
"feature_2": {"min_value": "2020-01-01", "keep_nulls": False},
"feature_3": {"year": ["2000", "2010", "2020"]},
},
]

int_slice_spec1 = SliceSpec(spec_list, intersections=[(0, 1, 2)])
assert int_slice_spec1.spec_list == intersect_list

int_slice_spec2 = SliceSpec(spec_list, intersections=3)
assert int_slice_spec2.spec_list == intersect_list
Loading