diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 7c6a50acc..8bcbfefdd 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -186,9 +186,9 @@ Note: in order to pass the full test suite (step 5), you'll need to install all ### Tests -An extensive test suite is included to test the library behavior and several examples. +An extensive test suite is included to test the library behavior. Library tests can be found in the -[tests folder](https://github.com/robustness-gym/robustness-gym/tree/main/tests). +[tests folder](https://github.com/robustness-gym/meerkat/tree/main/tests). From the root of the repository, here's how to run tests with `pytest` for the library: @@ -200,15 +200,13 @@ $ make test You can specify a smaller set of tests in order to test only the feature you're working on. -Meerkat uses `pytest` as a test runner only. It doesn't use any -`pytest`-specific features in the test suite itself. - -This means `unittest` is fully supported. Here's how to run tests with -`unittest`: - -```bash -$ python -m unittest discover -s tests -t . -v +Per the checklist above, all PRs should include high-coverage tests. +To produce a code coverage report, run the following `pytest` command: +```bash +$ pytest --cov-report=term-missing --cov-report=html --cov=meerkat . ``` +This will populate a directory `htmlcov` with an HTML report. +Open `htmlcov/index.html` in a browser to view the report. ### Style guide diff --git a/README.md b/README.md index 864106ee9..3a5aeb707 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,7 @@ ![GitHub](https://img.shields.io/github/license/robustness-gym/meerkat) [![Documentation Status](https://readthedocs.org/projects/meerkat/badge/?version=latest)](https://meerkat.readthedocs.io/en/latest/?badge=latest) [![pre-commit](https://img.shields.io/badge/pre--commit-enabled-brightgreen?logo=pre-commit&logoColor=white)](https://github.com/pre-commit/pre-commit) - +[![codecov](https://codecov.io/gh/robustness-gym/meerkat/branch/main/graph/badge.svg?token=MOLQYUSYQU)](https://codecov.io/gh/robustness-gym/meerkat) Meerkat provides fast and flexible data structures for working with complex machine learning datasets.
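For day-to-day development, the coverage report described above can also be scoped to a single test module. A minimal sketch, assuming `pytest-cov` is installed and the command is run from the repository root (the test path is illustrative):

```bash
# Collect coverage while running only the block manager tests;
# term-missing prints uncovered line numbers directly in the terminal.
$ pytest --cov=meerkat --cov-report=term-missing tests/meerkat/block/test_manager.py
```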
diff --git a/meerkat/block/manager.py b/meerkat/block/manager.py index 8ca509375..38e54ba6e 100644 --- a/meerkat/block/manager.py +++ b/meerkat/block/manager.py @@ -3,7 +3,7 @@ import os from collections import defaultdict from collections.abc import MutableMapping -from typing import Dict, Sequence, Union +from typing import Dict, Mapping, Sequence, Union import pandas as pd import yaml @@ -18,16 +18,20 @@ class BlockManager(MutableMapping): """This manager manages all blocks.""" def __init__(self) -> None: - self._columns: Dict[str, AbstractColumn] = {} + self._columns: Dict[str, AbstractColumn] = {} # ordered as of 3.7 self._column_to_block_id: Dict[str, int] = {} self._block_refs: Dict[int, BlockRef] = {} def update(self, block_ref: BlockRef): """data (): a single blockable object, potentially contains multiple columns.""" + for name in block_ref: + if name in self: + self.remove(name) + # although we can't have the same column living in multiple managers # we don't view here because it can lead to multiple calls to clone - self._columns.update({name: column for name, column in block_ref.items()}) + self._columns.update(block_ref) block_id = id(block_ref.block) # check if there already is a block_ref in the manager for this block @@ -66,9 +70,15 @@ def apply(self, method_name: str = "_get", *args, **kwargs) -> BlockManager: results[name] = result + if isinstance(results, BlockManager): + results.reorder(self.keys()) return results def consolidate(self): + column_order = list( + self._columns.keys() + ) # need to maintain order after consolidate + block_ref_groups = defaultdict(list) for block_ref in self._block_refs.values(): block_ref_groups[block_ref.block.signature].append(block_ref) @@ -78,15 +88,13 @@ def consolidate(self): # if there is only one block ref in the group, do not consolidate continue - # remove old block_refs - for old_ref in block_refs: - self._block_refs.pop(id(old_ref.block)) - # consolidate group block_class = block_refs[0].block.__class__ block_ref = block_class.consolidate(block_refs) self.update(block_ref) + self.reorder(column_order) + def remove(self, name): if name not in self._columns: raise ValueError(f"Remove failed: no column '{name}' in BlockManager.") @@ -103,6 +111,11 @@ def remove(self, name): self._column_to_block_id.pop(name) + def reorder(self, order: Sequence[str]): + if set(order) != set(self._columns): + raise ValueError("Must include all columns when reordering a BlockManager.") + self._columns = {name: self._columns[name] for name in order} + def __getitem__( self, index: Union[str, Sequence[str]] ) -> Union[AbstractColumn, BlockManager]: @@ -128,6 +141,7 @@ def __getitem__( for block_id, names in block_id_to_names.items(): block_ref = self._block_refs[block_id] mgr.update(block_ref[names]) + mgr.reorder(order=index) return mgr else: raise ValueError( @@ -148,6 +162,14 @@ def __delitem__(self, key): def __len__(self): return len(self._columns) + @property + def nrows(self): + return 0 if len(self) == 0 else len(next(iter(self._columns.values()))) + + @property + def ncols(self): + return len(self) + def __contains__(self, value): return value in self._columns @@ -160,8 +182,11 @@ def get_block_ref(self, name: str): def add_column(self, col: AbstractColumn, name: str): """Convert data to a meerkat column using the appropriate Column type.""" - if name in self._columns: - self.remove(name) + if len(self) > 0 and len(col) != self.nrows: + raise ValueError( + f"Cannot add column '{name}' with length {len(col)} to `BlockManager` " + f" with 
{self.nrows} rows." ) if not col.is_blockable(): col = col.view() @@ -172,7 +197,7 @@ def add_column(self, col: AbstractColumn, name: str): self.update(BlockRef(columns={name: col}, block=col._block)) @classmethod - def from_dict(cls, data): + def from_dict(cls, data: Mapping[str, object]): mgr = cls() for name, data in data.items(): col = AbstractColumn.from_data(data) @@ -180,7 +205,11 @@ def from_dict(cls, data): return mgr def write(self, path: str): - meta = {"dtype": BlockManager, "columns": {}} + meta = { + "dtype": BlockManager, + "columns": {}, + "_column_order": list(self.keys()), + } # prepare directories os.makedirs(path, exist_ok=True) @@ -264,7 +293,7 @@ def read( mgr.add_column( col_meta["dtype"].read(path=column_dir, _meta=col_meta), name ) - + mgr.reorder(meta["_column_order"]) return mgr def _repr_pandas_(self): @@ -279,13 +308,13 @@ def _repr_pandas_(self): def view(self): mgr = BlockManager() - for name, col in self._columns.items(): + for name, col in self.items(): mgr.add_column(col.view(), name) return mgr def copy(self): mgr = BlockManager() - for name, col in self._columns.items(): + for name, col in self.items(): mgr.add_column(col.copy(), name) return mgr diff --git a/meerkat/block/numpy_block.py b/meerkat/block/numpy_block.py index 74ed9ee1a..33ba5e1a5 100644 --- a/meerkat/block/numpy_block.py +++ b/meerkat/block/numpy_block.py @@ -5,6 +5,7 @@ from typing import Hashable, Mapping, Sequence, Tuple, Union import numpy as np +import torch from meerkat.block.ref import BlockRef from meerkat.errors import ConsolidationError @@ -109,9 +110,17 @@ def _consolidate( } return BlockRef(block=block, columns=new_columns) + @staticmethod + def _convert_index(index): + if torch.is_tensor(index): + # need to convert to numpy for boolean indexing + return index.numpy() + return index + def _get( self, index, block_ref: BlockRef, materialize: bool = True ) -> Union[BlockRef, dict]: + index = self._convert_index(index) # TODO: check if they're trying to index more than just the row dimension data = self.data[index] if isinstance(index, int): diff --git a/meerkat/block/pandas_block.py b/meerkat/block/pandas_block.py index c0feb9fbc..87faf2b4e 100644 --- a/meerkat/block/pandas_block.py +++ b/meerkat/block/pandas_block.py @@ -5,6 +5,7 @@ from typing import Hashable, Mapping, Sequence, Tuple, Union import pandas as pd +import torch from meerkat.block.ref import BlockRef from meerkat.columns.numpy_column import NumpyArrayColumn @@ -79,6 +80,9 @@ def _consolidate( @staticmethod def _convert_index(index): + if torch.is_tensor(index): + # need to convert to numpy for boolean indexing + return index.numpy() if isinstance(index, NumpyArrayColumn): return index.data if isinstance(index, TensorColumn): diff --git a/meerkat/columns/abstract.py b/meerkat/columns/abstract.py index ec34749aa..fc46e5e53 100644 --- a/meerkat/columns/abstract.py +++ b/meerkat/columns/abstract.py @@ -13,14 +13,12 @@ from meerkat.mixins.blockable import BlockableMixin from meerkat.mixins.cloneable import CloneableMixin from meerkat.mixins.collate import CollateMixin -from meerkat.mixins.identifier import IdentifierMixin from meerkat.mixins.inspect_fn import FunctionInspectorMixin from meerkat.mixins.io import ColumnIOMixin from meerkat.mixins.lambdable import LambdaMixin from meerkat.mixins.mapping import MappableMixin from meerkat.mixins.materialize import MaterializationMixin from meerkat.provenance import ProvenanceMixin, capture_provenance -from meerkat.tools.identifier import Identifier from 
meerkat.tools.utils import convert_to_batch_column_fn logger = logging.getLogger(__name__) @@ -32,7 +30,6 @@ class AbstractColumn( CollateMixin, ColumnIOMixin, FunctionInspectorMixin, - IdentifierMixin, LambdaMixin, MappableMixin, MaterializationMixin, @@ -46,7 +43,6 @@ class AbstractColumn( def __init__( self, data: Sequence = None, - identifier: Identifier = None, collate_fn: Callable = None, *args, **kwargs, @@ -55,7 +51,6 @@ def __init__( self._set_data(data) super(AbstractColumn, self).__init__( - identifier=identifier, collate_fn=collate_fn, *args, **kwargs, @@ -180,6 +175,12 @@ def _translate_index(self, index): if not self._is_batch_index(index): return index + if isinstance(index, pd.Series): + index = index.values + + if torch.is_tensor(index): + index = index.numpy() + # `index` should return a batch if isinstance(index, slice): # int or slice index => standard list slicing diff --git a/meerkat/columns/numpy_column.py b/meerkat/columns/numpy_column.py index 9580be7c2..bb4fa48e7 100644 --- a/meerkat/columns/numpy_column.py +++ b/meerkat/columns/numpy_column.py @@ -113,6 +113,7 @@ def _set_batch(self, indices, values): self._data[indices] = values def _get(self, index, materialize: bool = True): + index = NumpyBlock._convert_index(index) data = self._data[index] if self._is_batch_index(index): # only create a numpy array column diff --git a/meerkat/datapanel.py b/meerkat/datapanel.py index b6fe3aa96..5e40070b8 100644 --- a/meerkat/datapanel.py +++ b/meerkat/datapanel.py @@ -4,8 +4,6 @@ import logging import os import pathlib -from contextlib import contextmanager -from copy import copy from typing import ( Callable, Dict, @@ -26,8 +24,6 @@ import torch import ujson as json import yaml -from datasets import DatasetInfo, NamedSplit -from datasets.arrow_dataset import DatasetInfoMixin from jsonlines import jsonlines import meerkat @@ -40,7 +36,6 @@ from meerkat.mixins.mapping import MappableMixin from meerkat.mixins.materialize import MaterializationMixin from meerkat.provenance import ProvenanceMixin, capture_provenance -from meerkat.tools.identifier import Identifier from meerkat.tools.utils import convert_to_batch_fn logger = logging.getLogger(__name__) @@ -57,7 +52,6 @@ class DataPanel( MappableMixin, MaterializationMixin, ProvenanceMixin, - DatasetInfoMixin, # this should be the last in order of mixins ): """Meerkat DataPanel class.""" @@ -70,102 +64,27 @@ class DataPanel( def __init__( self, data: Union[dict, list, datasets.Dataset] = None, - identifier: Identifier = None, - column_names: List[str] = None, - info: DatasetInfo = None, - split: Optional[NamedSplit] = None, *args, **kwargs, ): super(DataPanel, self).__init__( - info=info, - split=split, *args, **kwargs, ) - - # TODO(karan, sabri): copy columns when they're passed in and prevent users - # from setting visible_rows inside columns that belong to a datapanel logger.debug("Creating DataPanel.") - if data is not None: - assert column_names is None, "Don't pass in column_names." - # The data is passed in - - # `data` is a dictionary - if isinstance(data, dict) and len(data): - data = self._create_columns(data) - self._assert_columns_all_equal_length(data) - self.data = data - - # `data` is a list - elif isinstance(data, list) and len(data): - # Transpose the list of dicts to a dict of lists i.e. 
a batch - data = tz.merge_with(list, *data) - # Assert all columns are the same length - data = self._create_columns(data) - self._assert_columns_all_equal_length(data) - self.data = data - - # `data` is a datasets.Dataset - elif isinstance(data, datasets.Dataset): - self.data = self._create_columns(data[:]) - info, split = data.info, data.split - elif isinstance(data, BlockManager): - self.data = data - - else: - if column_names: - # Use column_names to setup the manager - self._check_columns_unique(column_names) - self.data = {k: [] for k in column_names} - else: - self.data = {} - - # Setup the DatasetInfo - info = info.copy() if info is not None else DatasetInfo() - DatasetInfoMixin.__init__(self, info=info, split=split) - - # Create attributes for all columns and visible columns - self._visible_columns = None - - # Create an identifier - # TODO(Sabri): make _autobuild_identifier more informative - self._identifier = Identifier( - self._autobuild_identifier() if not identifier else identifier - ) - - # Create logging directory - self._create_logdir() - - self._initialize_state() + self.data = data # TODO(Sabri): fix add_index for new datset # Add an index to the dataset if not self.has_index: self._add_index() - @classmethod - def _create_columns(cls, name_to_data: Dict[str, AbstractColumn.Columnable]): - new_data = {} - for column_name, data in name_to_data.items(): - new_data[column_name] = AbstractColumn.from_data(data=data) - - return new_data - def _repr_pandas_(self): - return self.data._repr_pandas_()[self.column_names].rename( + return self.data._repr_pandas_()[self.columns].rename( columns={k: f"{k} ({v.__class__.__name__})" for k, v in self.items()} ) - def old_repr(self): - return pd.DataFrame( - { - f"{k} ({v.__class__.__name__})": v._repr_pandas_() - for k, v in self.items() - } - ) - def _repr_html_(self): return self._repr_pandas_()._repr_html_() @@ -173,16 +92,15 @@ def streamlit(self): return self._repr_pandas_() def __repr__(self): - return f"{self.__class__.__name__}" f"(num_rows: {self.num_rows})" + return ( + f"{self.__class__.__name__}" f"(nrows: {self.nrows}, ncols: {self.ncols})" + ) def __len__(self): - # If only a subset of rows are visible - if len(self.visible_columns) == 0: - return 0 - return len(self[self.visible_columns[0]]) + return self.nrows def __contains__(self, item): - return item in self.visible_columns + return item in self.columns @property def data(self) -> BlockManager: @@ -197,168 +115,61 @@ def _set_data(self, value: Union[BlockManager, Mapping] = None): self._data = value elif isinstance(value, Mapping): self._data = BlockManager.from_dict(value) + elif isinstance(value, Sequence): + if not isinstance(value[0], Mapping): + raise ValueError( + "Cannot set DataPanel `data` to a Sequence containing object of " + f" type {type(value[0])}. Must be a Sequence of Mapping." + ) + self._data = BlockManager.from_dict(tz.merge_with(list, *value)) elif value is None: self._data = BlockManager() else: raise ValueError( - f"Cannot set DataPanel data to object of type {type(value)}" + f"Cannot set DataPanel `data` to object of type {type(value)}." 
) @data.setter def data(self, value): self._set_data(value) - def full_length(self): - # If there are columns, full_length of any column, since they must be same size - if self.column_names: - return self.data[self.column_names[0]].full_length() - return 0 - - @property - def column_names(self): - """Column names in the dataset.""" - return self.visible_columns - @property def columns(self): - """Column names in the dataset.""" - return self.visible_columns - - @property - def num_rows(self): - """Number of rows in the dataset.""" - return len(self) - - @property - def shape(self): - """Shape of the dataset (num_rows, num_columns).""" - return self.num_rows, len(self.columns) - - @classmethod - def _assert_columns_all_equal_length(cls, batch: Batch): - """Check that all columns have the same length so that the data is - tabular.""" - assert cls._columns_all_equal_length( - batch - ), "All columns must have equal length." - - @classmethod - def _columns_all_equal_length(cls, batch: Batch): - """Check that all columns have the same length so that the data is - tabular.""" - if len(set([len(v) for k, v in batch.items()])) == 1: - return True - return False - - def _check_columns_exist(self, columns: List[str]): - """Check that every column in `columns` exists.""" - for col in columns: - assert col in self.all_columns, f"{col} is not a valid column." - - def _check_columns_unique(self, columns: List[str]): - """Checks that all columns are unique.""" - assert len(columns) == len(set(columns)) - - def _initialize_state(self): - """Dataset state initialization.""" - # Show all columns by default - self.visible_columns = copy(self.all_columns) - - # Set the features - self._set_features() - - @property - def all_columns(self): + """Column names in the DataPanel.""" return list(self.data.keys()) @property - def visible_columns(self): - if self._visible_columns is None: - return self.all_columns - return self._visible_columns - - @visible_columns.setter - def visible_columns(self, columns: Optional[Sequence[str]] = None): - if columns is None: - # do nothing, keep old visible columns - return - for c in columns: - if c not in self.all_columns: - raise ValueError(f"Trying to set nonexistant column {c} to visible.") - - self._visible_columns = copy(columns) - if "index" not in self._visible_columns and "index" in self.all_columns: - self._visible_columns.append("index") - - @contextmanager - def format(self, columns: List[str] = None): - """Context where only `columns` will be visible.""" - # Get the current format - current_format = self.get_format() - - if columns: - # View only `columns` - self.set_format(columns) - else: - # Use all columns - self.set_format(self.column_names) - try: - yield - finally: - # Reset the format back - self.set_format(current_format) - - def get_format(self) -> List[str]: - """Get the dataset format.""" - return self.visible_columns - - def set_format(self, columns: List[str]): - """Set the dataset format. - - Only `columns` are visible after set_format is invoked. - """ - # Check that the columns exist - self._check_columns_exist(columns) - # Set visible columns - self.visible_columns = columns - - def reset_format(self): - """Reset the dataset format. + def nrows(self): + """Number of rows in the DataPanel.""" + if self.ncols == 0: + return 0 + return self.data.nrows - All columns are visible. 
- """ - # All columns are visible - self.visible_columns = self.all_columns + @property + def ncols(self): + """Number of rows in the DataPanel.""" + return self.data.ncols @property - def identifier(self): - """Identifier.""" - return self._identifier - - def _set_features(self): - """Set the features of the dataset.""" - with self.format(): - self.info.features = None # Features.from_arrow_schema( - # pa.Table.from_pydict( - # self[:1], - # ).schema - # ) + def shape(self): + """Shape of the DataPanel (num_rows, num_columns).""" + return self.nrows, self.ncols def add_column( self, name: str, data: AbstractColumn.Columnable, overwrite=False ) -> None: - """Add a column to the dataset.""" + """Add a column to the DataPanel.""" assert isinstance( name, str ), f"Column name must of type `str`, not `{type(name)}`." - assert (name not in self.all_columns) or overwrite, ( + assert (name not in self.columns) or overwrite, ( f"Column with name `{name}` already exists, " f"set `overwrite=True` to overwrite." ) - if name in self.all_columns: + if name in self.columns: self.remove_column(name) column = AbstractColumn.from_data(data) @@ -371,24 +182,14 @@ def add_column( # Add the column self.data[name] = column - if self._visible_columns is not None: - self.visible_columns = self.visible_columns + [name] - - # Set features - self._set_features() - logger.info(f"Added column `{name}` with length `{len(column)}`.") def remove_column(self, column: str) -> None: """Remove a column from the dataset.""" - assert column in self.all_columns, f"Column `{column}` does not exist." + assert column in self.columns, f"Column `{column}` does not exist." # Remove the column del self.data[column] - self.visible_columns = [col for col in self.visible_columns if col != column] - - # Set features - self._set_features() logger.info(f"Removed column `{column}`.") @@ -405,34 +206,9 @@ def append( `example_or_batch` must have the same columns as the dataset (regardless of what columns are visible). """ - if axis == 0 or axis == "rows": - # append new rows - return meerkat.concat([self, dp], axis="rows") - elif axis == 1 or axis == "columns": - # append new columns - if len(dp) != len(self): - raise ValueError( - "Can only append DataPanels along axis 1 (columns) if they have the" - f"same length. 
{len(self)} != {len(dp)}" - ) - - shared = set(dp.visible_columns).intersection(set(self.visible_columns)) - if not overwrite and shared: - if suffixes is None: - raise ValueError() - left_suf, right_suf = suffixes - data = { - **{k + left_suf if k in shared else k: v for k, v in self.items()}, - **{k + right_suf if k in shared else k: v for k, v in dp.items()}, - } - else: - data = {**dict(self.items()), **dict(dp.items())} - - col = self._clone(data=data) - col._visible_columns = None - return col - else: - raise ValueError("DataPanel `axis` must be either 0 or 1.") + return meerkat.concat( + [self, dp], axis=axis, suffixes=suffixes, overwrite=overwrite + ) def _add_index(self): """Add an index to the dataset.""" @@ -446,48 +222,18 @@ def tail(self, n: int = 5) -> DataPanel: """Get the last `n` examples of the DataPanel.""" return self.lz[-n:] - def _create_logdir(self): - """Create and assign a directory for logging this dataset's files.""" - if self.identifier.name == "RGDataset": - # TODO(karan): handle temporarily constructed datasets differently - self.logdir /= str(self.identifier) - self.logdir.mkdir(parents=True, exist_ok=True) - else: - self.logdir /= str(self.identifier) - self.logdir.mkdir(parents=True, exist_ok=True) - - def _autobuild_identifier(self) -> Identifier: - """Automatically build an identifier for the dataset using available - information.""" - # Look for a name, otherwise assign a default - _name = self.info.builder_name if self.info.builder_name else "RGDataset" - - # Check for split, version information - split = str(self.split) if self.split else None - version = str(self.version) if self.version else None - - # Add all available information to kwargs dict - kwargs = {} - if split: - kwargs["split"] = split - if version: - kwargs["version"] = version - - # Create identifier - return Identifier(_name=_name, **kwargs) - def _get(self, index, materialize: bool = False): if isinstance(index, str): # str index => column selection (AbstractColumn) - if index in self.column_names: + if index in self.columns: return self.data[index] - raise AttributeError(f"Column {index} does not exist.") + raise KeyError(f"Column `{index}` does not exist.") elif isinstance(index, int): # int index => single row (dict) return { k: self.data[k]._get(index, materialize=materialize) - for k in self.visible_columns + for k in self.columns } # cases where `index` returns a datapanel @@ -530,14 +276,13 @@ def _get(self, index, materialize: bool = False): raise TypeError("Invalid index type: {}".format(type(index))) if index_type == "column": - if not set(index).issubset(self.visible_columns): - missing_cols = set(index) - set(self.visible_columns) - raise ValueError(f"DataPanel does not have columns {missing_cols}") + if not set(index).issubset(self.columns): + missing_cols = set(index) - set(self.columns) + raise KeyError(f"DataPanel does not have columns {missing_cols}") dp = self._clone(data=self.data[index]) - dp.visible_columns = index return dp - elif index_type == "row": + elif index_type == "row": # pragma: no cover return self._clone( data=self.data.apply("_get", index=index, materialize=materialize) ) @@ -546,19 +291,14 @@ def _get(self, index, materialize: bool = False): def __getitem__(self, index): return self._get(index, materialize=True) - def get(self, column, value=None): - if column in self: - return self[column] - return value - def __setitem__(self, index, value): self.add_column(name=index, data=value, overwrite=True) @property def has_index(self) -> bool: """Check if 
the dataset has an index column.""" - if self.column_names: - return "index" in self.column_names + if self.columns: + return "index" in self.columns # Just return True if the dataset is empty return True @@ -590,25 +330,11 @@ def from_huggingface(cls, *args, **kwargs): else: return cls(dataset) - @classmethod - @capture_provenance() - def from_columns( - cls, - columns: Dict[str, AbstractColumn], - identifier: Identifier = None, - ) -> DataPanel: - """Create a Dataset from a dict of columns.""" - return cls( - columns, - identifier=identifier, - ) - @classmethod @capture_provenance() def from_jsonl( cls, json_path: str, - identifier: Identifier = None, ) -> DataPanel: """Load a dataset from a .jsonl file on disk, where each line of the json file consists of a single example.""" @@ -622,28 +348,23 @@ def from_jsonl( data[k].append(line[k]) return cls( - data, - identifier=identifier - if identifier - else Identifier("Jsonl", jsonl=json_path), + data=data, ) @classmethod - # @capture_provenance() + @capture_provenance() def from_batch( cls, batch: Batch, - identifier: Identifier = None, ) -> DataPanel: """Convert a batch to a Dataset.""" - return cls(batch, identifier=identifier) + return cls(batch) @classmethod @capture_provenance() def from_batches( cls, batches: Sequence[Batch], - identifier: Identifier = None, ) -> DataPanel: """Convert a list of batches to a dataset.""" @@ -652,7 +373,6 @@ def from_batches( tz.compose(list, tz.concat), *batches, ), - identifier=identifier, ) @classmethod @@ -660,7 +380,6 @@ def from_batches( def from_dict( cls, d: Dict, - identifier: Identifier = None, ) -> DataPanel: """Convert a dictionary to a dataset. @@ -668,7 +387,6 @@ def from_dict( """ return cls.from_batch( batch=d, - identifier=identifier, ) @classmethod @@ -676,14 +394,12 @@ def from_dict( def from_pandas( cls, df: pd.DataFrame, - identifier: Identifier = None, ): """Create a Dataset from a pandas DataFrame.""" # column names must be str in meerkat df = df.rename(mapper=str, axis="columns") return cls.from_batch( df.to_dict("series"), - identifier=identifier, ) @classmethod @@ -707,14 +423,10 @@ def from_csv(cls, filepath: str, *args, **kwargs): def from_feather( cls, path: str, - identifier: Identifier = None, ): """Create a Dataset from a feather file.""" return cls.from_batch( pd.read_feather(path).to_dict("list"), - identifier=Identifier("Feather", path=path) - if not identifier - else identifier, ) @capture_provenance() @@ -777,8 +489,10 @@ def batch( batches of data """ cell_columns, batch_columns = [], [] + from meerkat.columns.lambda_column import LambdaColumn + for name, column in self.items(): - if isinstance(column, CellColumn): + if isinstance(column, (CellColumn, LambdaColumn)) and materialize: cell_columns.append(name) else: batch_columns.append(name) @@ -842,11 +556,6 @@ def update( # TODO(karan): make this fn go faster # most of the time is spent on the merge, speed it up further - # Return if the function is None - if function is None: - logger.info("`function` None, returning None.") - return self - # Return if `self` has no examples if not len(self): logger.info("Dataset empty, returning None.") @@ -903,6 +612,7 @@ def update( return new_dp + @capture_provenance() def map( self, function: Optional[Callable] = None, @@ -918,7 +628,7 @@ def map( pbar: bool = False, **kwargs, ) -> Optional[Union[Dict, List, AbstractColumn]]: - input_columns = self.visible_columns if input_columns is None else input_columns + input_columns = self.columns if input_columns is None else 
input_columns dp = self[input_columns] return super(DataPanel, dp).map( function=function, @@ -950,11 +660,6 @@ def filter( ) -> Optional[DataPanel]: """Filter operation on the DataPanel.""" - # Just return if the function is None - if function is None: - logger.info("`function` None, returning None.") - return None - # Return if `self` has no examples if not len(self): logger.info("DataPanel empty, returning None.") @@ -1018,14 +723,14 @@ def merge( ) def items(self): - for name in self.visible_columns: + for name in self.columns: yield name, self.data[name] def keys(self): - return self.visible_columns + return self.columns def values(self): - for name in self.visible_columns: + for name in self.columns: yield self.data[name] @classmethod @@ -1095,12 +800,7 @@ def write( @classmethod def _state_keys(cls) -> set: """List of attributes that describe the state of the object.""" - return { - "_identifier", - "_visible_columns", - "_info", - "_split", - } + return {} def _view_data(self) -> object: return self.data.view() diff --git a/meerkat/mixins/cloneable.py b/meerkat/mixins/cloneable.py index 600f82783..bb0eeeb46 100644 --- a/meerkat/mixins/cloneable.py +++ b/meerkat/mixins/cloneable.py @@ -1,6 +1,7 @@ from dataclasses import dataclass from meerkat.mixins.blockable import BlockableMixin +from meerkat.provenance import ProvenanceMixin @dataclass @@ -44,6 +45,11 @@ def _clone(self, data: object = None): obj = self.__class__.__new__(self.__class__) obj._set_state(state) obj._set_data(data) + + if isinstance(self, ProvenanceMixin): + # need to create a node for the object + obj._init_node() + return obj def _copy_data(self) -> object: diff --git a/meerkat/mixins/identifier.py b/meerkat/mixins/identifier.py deleted file mode 100644 index d4dda3ba8..000000000 --- a/meerkat/mixins/identifier.py +++ /dev/null @@ -1,19 +0,0 @@ -from meerkat.tools.identifier import Identifier - - -class IdentifierMixin: - def __init__(self, identifier: Identifier, *args, **kwargs): - super(IdentifierMixin, self).__init__(*args, **kwargs) - - # Identifier for the object - self._identifier = ( - Identifier(self.__class__.__name__) if not identifier else identifier - ) - - @property - def identifier(self): - return self._identifier - - @property - def id(self): - return self.identifier diff --git a/meerkat/mixins/mapping.py b/meerkat/mixins/mapping.py index d10bcb05c..1f7a05694 100644 --- a/meerkat/mixins/mapping.py +++ b/meerkat/mixins/mapping.py @@ -33,15 +33,6 @@ def map( """Map a function over the elements of the column.""" - # Just return if the function is None - if function is None: - logger.info("`function` None, returning None.") - return None - - # Ensure that num_workers is not None - if num_workers is None: - num_workers = 0 - # Return if `self` has no examples if not len(self): logger.info("Dataset empty, returning None.") diff --git a/meerkat/ops/concat.py b/meerkat/ops/concat.py index 010e99040..1d18a7fa7 100644 --- a/meerkat/ops/concat.py +++ b/meerkat/ops/concat.py @@ -1,4 +1,5 @@ -from typing import Sequence, Union +from itertools import combinations +from typing import Sequence, Tuple, Union import cytoolz as tz @@ -12,6 +13,8 @@ def concat( objs: Union[Sequence[DataPanel], Sequence[AbstractColumn]], axis: Union[str, int] = "rows", + suffixes: Tuple[str] = None, + overwrite: bool = False, ) -> Union[DataPanel, AbstractColumn]: """Concatenate a sequence of columns or a sequence of `DataPanel`s. If sequence is empty, returns an empty `DataPanel`. 
@@ -41,8 +44,8 @@ if isinstance(objs[0], DataPanel): if axis == 0 or axis == "rows": # append new rows - columns = set(objs[0].visible_columns) - if not all([set(dp.visible_columns) == columns for dp in objs]): + columns = set(objs[0].columns) + if not all([set(dp.columns) == columns for dp in objs]): raise ConcatError( "Can only concatenate DataPanels along axis 0 (rows) if they have " " the same set of columns names." @@ -59,15 +62,27 @@ "have the same length." ) - columns = list(tz.concat((dp.visible_columns for dp in objs))) - if not tz.isdistinct(columns): - raise ConcatError( - "Can only concatenate DataPanels along axis 1 (columns) if they " - "have distinct column names." + # get all column names that appear in more than one DataPanel + shared = set() + for dp1, dp2 in combinations(objs, 2): + shared |= set(dp1.columns) & set(dp2.columns) + + # TODO (sabri): I removed the index column for now to address + # https://github.com/robustness-gym/meerkat/issues/65, but when we refactor + # index with https://github.com/robustness-gym/meerkat/issues/117 we should + # take this out + shared -= {"index"} + if shared and not overwrite: + if suffixes is None: + raise ConcatError("Must pass `suffixes` when concatenating DataPanels with shared column names.") + data = tz.merge( + {k + suffixes[idx] if k in shared else k: v for k, v in dp.items()} + for idx, dp in enumerate(objs) ) + else: + data = tz.merge(dict(dp.items()) for dp in objs) - data = tz.merge(*(dict(dp.items()) for dp in objs)) - return objs[0].from_batch(data) + return objs[0]._clone(data=data) else: raise ConcatError(f"Invalid axis `{axis}` passed to concat.") elif isinstance(objs[0], AbstractColumn): diff --git a/meerkat/ops/merge.py b/meerkat/ops/merge.py index a9d6f81b0..a5d981292 100644 --- a/meerkat/ops/merge.py +++ b/meerkat/ops/merge.py @@ -84,9 +84,7 @@ def _cols_to_construct(dp: DataPanel): # add columns in both `left_on` and `right_on`, casting to the column type in left for name, column in merged_df.iteritems(): merged_dp.add_column(name, left[name]._clone(data=column.values)) - merged_dp.visible_columns = ( - merged_dp.visible_columns[-1:] + merged_dp.visible_columns[:-1] - ) + merged_dp.data.reorder(merged_dp.columns[-1:] + merged_dp.columns[:-1]) if ( not keep_indexes and ("index" + suffixes[0]) in merged_dp diff --git a/meerkat/pipelines/entitydatapanel.py b/meerkat/pipelines/entitydatapanel.py index e0e1ed7bc..4972eb8fb 100644 --- a/meerkat/pipelines/entitydatapanel.py +++ b/meerkat/pipelines/entitydatapanel.py @@ -10,7 +10,6 @@ from meerkat import DataPanel, ListColumn, NumpyArrayColumn, TensorColumn from meerkat.nn import EmbeddingColumn -from meerkat.tools.identifier import Identifier logger = logging.getLogger(__name__) @@ -19,8 +18,7 @@ class EntityDataPanel(DataPanel): def __init__( self, data: Union[dict, list, datasets.Dataset] = None, - identifier: Identifier = None, - column_names: List[str] = None, + columns: List[str] = None, embedding_columns: List[str] = None, index_column: str = None, **kwargs, ): @@ -34,20 +32,15 @@ operations such as nearest neighbor search. 
Args: - identifier: identifier - column_names: all column names + columns: all column names embedding_columns: embedding columns in all columns index_column: index column """ super().__init__( data=data, - identifier=identifier, - column_names=column_names, - info=None, - split=None, **kwargs, ) - if len(self.column_names) > 0: + if len(self.columns) > 0: self._embedding_columns = embedding_columns if embedding_columns else [] self._check_columns_unique(self._embedding_columns) @@ -67,6 +60,15 @@ def __init__( self._index_column = None self._index_to_rowid = {} + def _check_columns_exist(self, columns: List[str]): + """Check that every column in `columns` exists.""" + for col in columns: + assert col in self.columns, f"{col} is not a valid column." + + def _check_columns_unique(self, columns: List[str]): + """Checks that all columns are unique.""" + assert len(columns) == len(set(columns)) + @classmethod def from_datapanel( cls, @@ -96,7 +98,7 @@ def to_datapanel(self, klass: type = None): klass = DataPanel elif not issubclass(klass, DataPanel): raise ValueError("`klass` must be a subclass of DataPanel") - return klass.from_batch({k: self[k] for k in self.visible_columns}) + return klass.from_batch({k: self[k] for k in self.columns}) @property def index(self): @@ -106,7 +108,7 @@ def index(self): @property def embedding_columns(self): """Returns _visible_ embedding columns.""" - return [e for e in self._embedding_columns if e in self.visible_columns] + return [e for e in self._embedding_columns if e in self.columns] @property def index_column(self): @@ -214,7 +216,7 @@ def append( ) # Save the new index column for saving EntityDataPanel new_index_column = self.index_column - if self.index_column in dp.column_names and not overwrite: + if self.index_column in dp.columns and not overwrite: new_index_column += suffixes[0] ret = super(EntityDataPanel, self).append(dp, axis, suffixes, overwrite) ret._embedding_columns = new_embedding_cols @@ -247,7 +249,7 @@ def merge( # the joining column, then the index column will change to # have a suffix new_index_column = self.index_column - if self.index_column in right.column_names: + if self.index_column in right.columns: # Column will stay the same if it's the joining # column of both left and right if not ( @@ -372,8 +374,6 @@ def most_similar( def _state_keys(cls) -> set: """List of attributes that describe the state of the object.""" return { - "_visible_columns", - "_identifier", "_embedding_columns", "_index_column", } diff --git a/meerkat/provenance.py b/meerkat/provenance.py index a573a8771..47298daf5 100644 --- a/meerkat/provenance.py +++ b/meerkat/provenance.py @@ -1,5 +1,6 @@ from __future__ import annotations +import warnings import weakref from copy import copy from functools import wraps @@ -7,6 +8,7 @@ from typing import Any, Dict, List, Mapping, Sequence, Tuple, Union import meerkat as mk +from meerkat.errors import ExperimentalWarning _provenance_enabled = False @@ -47,6 +49,9 @@ def is_provenance_enabled(): class ProvenanceMixin: def __init__(self, *args, **kwargs): super(ProvenanceMixin, self).__init__(*args, **kwargs) + self._init_node() + + def _init_node(self): self._node = ProvenanceObjNode(self) @property @@ -246,18 +251,25 @@ def visualize_provenance( show_columns: bool = False, last_parent_only: bool = False, ): - try: + + warnings.warn( # pragma: no cover + ExperimentalWarning( + "The function `meerkat.provenance.visualize_provenance` is experimental and" + " has limited test coverage. Proceed with caution." 
+ ) + ) + try: # pragma: no cover import cyjupyter - except ImportError: + except ImportError: # pragma: no cover raise ImportError( "`visualize_provenance` requires the `cyjupyter` dependency." "See https://github.com/cytoscape/cytoscape-jupyter-widget" ) - nodes, edges = obj.get_provenance( + nodes, edges = obj.get_provenance( # pragma: no cover include_columns=show_columns, last_parent_only=last_parent_only ) - cy_nodes = [ + cy_nodes = [ # pragma: no cover { "data": { "id": id(node), @@ -267,7 +279,7 @@ def visualize_provenance( } for node in nodes ] - cy_edges = [ + cy_edges = [ # pragma: no cover { "data": { "source": id(edge[0]), @@ -278,9 +290,9 @@ def visualize_provenance( for edge in edges ] - cy_data = {"elements": {"nodes": cy_nodes, "edges": cy_edges}} + cy_data = {"elements": {"nodes": cy_nodes, "edges": cy_edges}} # pragma: no cover - style = [ + style = [ # pragma: no cover { "selector": "node", "css": { @@ -318,6 +330,6 @@ def visualize_provenance( }, }, ] - return cyjupyter.Cytoscape( + return cyjupyter.Cytoscape( # pragma: no cover data=cy_data, visual_style=style, layout_name="breadthfirst" ) diff --git a/meerkat/tools/identifier.py b/meerkat/tools/identifier.py deleted file mode 100644 index 3123a62ea..000000000 --- a/meerkat/tools/identifier.py +++ /dev/null @@ -1,153 +0,0 @@ -"""Identifiers for objects in Meerkat.""" -from __future__ import annotations - -import ast -import json -from typing import Any, Callable, List, Union - - -class Identifier: - """Class for creating identifiers for objects in Robustness Gym.""" - - def __init__( - self, - _name: str, - _index: Union[str, int] = None, - **kwargs, - ): - - self._name = _name - self._index = str(_index) if _index is not None else None - self._parameters = kwargs - - # Add the parameter - for param, value in self.parameters.items(): - self.add_parameter(param, value) - - @property - def name(self): - """Base name.""" - return self._name - - @property - def index(self): - """Index associated with the identifier.""" - return self._index - - @property - def parameters(self): - """Additional parameters contained in the identifier.""" - return self._parameters - - @classmethod - def range(cls, n: int, _name: str, **kwargs) -> List[Identifier]: - """Create a list of identifiers, with index varying from 1 to `n`.""" - - if n > 1: - return [cls(_name=_name, _index=i, **kwargs) for i in range(1, n + 1)] - return [cls(_name=_name, **kwargs)] - - def __call__(self, **kwargs): - """Call the identifier with additional parameters to return a new - identifier.""" - ident = Identifier.loads(self.dumps()) - for parameter, value in kwargs.items(): - ident.add_parameter(parameter, value) - return ident - - def __repr__(self): - params = ", ".join([f"{k}={v}" for k, v in self.parameters.items()]) - if self.index is not None: - return ( - f"{self.name}-{self.index}({params})" - if len(params) > 0 - else f"{self.name}-{self.index}" - ) - return f"{self.name}({params})" if len(params) > 0 else f"{self.name}" - - def __hash__(self): - # return persistent_hash(str(self)) - return hash(str(self)) - - def __eq__(self, other: Union[Identifier, str]): - return str(self) == str(other) - - def dumps(self): - """Dump the identifier to JSON.""" - return json.dumps(self.__dict__) - - @staticmethod - def _parse_args(s: str): - """https://stackoverflow.com/questions/49723047/parsing-a-string-as-a- - python-argument-list.""" - args = "f({})".format(s) - tree = ast.parse(args) - funccall = tree.body[0].value - # return {arg.arg: 
ast.literal_eval(arg.value) for arg in funccall.keywords} - params = {} - for arg in funccall.keywords: - try: - params[arg.arg] = ast.literal_eval(arg.value) - except ValueError: - params[arg.arg] = arg.value.id - return params - - @classmethod - def parse(cls, s: str) -> Identifier: - """Parse in an identifier from string.""" - # Parse out the various components - if "(" in s: - name_index, params = s.split("(") - params = params.split(")")[0] - else: - name_index = s - params = None - - # Create the name and index - if "-" in name_index: - name, index = name_index.split("-")[:-1], name_index.split("-")[-1] - name = "-".join(name) - if index.isnumeric(): - index = int(index) - else: - name = "-".join([name, index]) - index = None - else: - name = name_index - index = None - - # Parse out the params - if params is not None: - params = cls._parse_args(params) - else: - params = {} - - return cls(_name=name, _index=index, **params) - - def without(self, *params) -> Identifier: - """Returns an identifier without `params`.""" - return Identifier( - self.name, - self.index, - **{k: v for k, v in self.parameters.items() if k not in set(params)}, - ) - - @classmethod - def loads(cls, s: str): - """Load the identifier from JSON.""" - identifier = Identifier(_name="") - identifier.__dict__ = json.loads(s) - return identifier - - def add_parameter(self, parameter: str, value: Any) -> None: - """Add a parameter to the identifier.""" - if isinstance(value, Callable): - self.parameters[parameter] = ".".join( - [str(value.__module__), str(value.__name__)] - ) - else: - self.parameters[parameter] = value - - -# Assign Id as an alias for the Identifier class -Id = Identifier diff --git a/tests/meerkat/block/test_manager.py b/tests/meerkat/block/test_manager.py index ba59bde96..25bb5af26 100644 --- a/tests/meerkat/block/test_manager.py +++ b/tests/meerkat/block/test_manager.py @@ -79,6 +79,23 @@ def test_consolidate_multiple_types(): assert len(mgr._block_refs) == 3 +def test_consolidate_preserves_order(): + mgr = BlockManager() + + col1 = mk.NumpyArrayColumn(data=np.arange(10)) + mgr.add_column(col1, "col1") + col2 = mk.NumpyArrayColumn(np.arange(10) * 2) + mgr.add_column(col2, "col2") + col3 = mk.PandasSeriesColumn(np.arange(10) * 3) + mgr.add_column(col3, "col3") + + order = ["col2", "col3", "col1"] + mgr.reorder(order) + assert list(mgr.keys()) == order + mgr.consolidate() + assert list(mgr.keys()) == order + + @pytest.mark.parametrize( "num_blocks, consolidated", product([1, 2, 3], [True, False]), diff --git a/tests/meerkat/columns/abstract.py b/tests/meerkat/columns/abstract.py index aad1bbc99..18df82e0f 100644 --- a/tests/meerkat/columns/abstract.py +++ b/tests/meerkat/columns/abstract.py @@ -8,6 +8,7 @@ import pytest from meerkat.datapanel import DataPanel +from meerkat.ops.concat import concat @pytest.fixture @@ -207,6 +208,15 @@ def func(x): assert result.is_equal(filter_spec["expected_result"]) + def test_concat(self, testbed: AbstractColumnTestBed, n: int = 2): + col = testbed.col + out = concat([col] * n) + + assert len(out) == len(col) * n + assert isinstance(out, type(col)) + for i in range(n): + assert out.lz[i * len(col) : (i + 1) * len(col)].is_equal(col) + def test_copy(self, testbed: AbstractColumnTestBed): col, _ = testbed.col, testbed.data col_copy = col.copy() diff --git a/tests/meerkat/columns/test_image_column.py b/tests/meerkat/columns/test_image_column.py index 39025eaf6..9f048ab22 100644 --- a/tests/meerkat/columns/test_image_column.py +++ 
b/tests/meerkat/columns/test_image_column.py @@ -253,6 +253,10 @@ def test_map_return_single_w_kwarg( ): return super().test_map_return_single_w_kwarg(testbed, batched, materialize) + @ImageColumnTestBed.parametrize(params={"n": [1, 2, 3]}) + def test_concat(self, testbed: AbstractColumnTestBed, n: int): + return super().test_concat(testbed, n=n) + @ImageColumnTestBed.parametrize() def test_copy(self, testbed: AbstractColumnTestBed): return super().test_copy(testbed) diff --git a/tests/meerkat/columns/test_numpy_column.py b/tests/meerkat/columns/test_numpy_column.py index f41750edf..06dd5a219 100644 --- a/tests/meerkat/columns/test_numpy_column.py +++ b/tests/meerkat/columns/test_numpy_column.py @@ -121,6 +121,10 @@ def test_map_return_single_w_kwarg( testbed, batched, materialize=True ) + @NumpyArrayColumnTestBed.parametrize(params={"n": [1, 2, 3]}) + def test_concat(self, testbed: AbstractColumnTestBed, n: int): + return super().test_concat(testbed, n=n) + @NumpyArrayColumnTestBed.parametrize() def test_copy(self, testbed: AbstractColumnTestBed): return super().test_copy(testbed) diff --git a/tests/meerkat/columns/test_pandas_column.py b/tests/meerkat/columns/test_pandas_column.py index 05ba41ae7..05111c042 100644 --- a/tests/meerkat/columns/test_pandas_column.py +++ b/tests/meerkat/columns/test_pandas_column.py @@ -150,6 +150,10 @@ def test_map_return_single_w_kwarg( testbed, batched, materialize=True ) + @PandasSeriesColumnTestBed.parametrize(params={"n": [1, 2, 3]}) + def test_concat(self, testbed: AbstractColumnTestBed, n: int): + return super().test_concat(testbed, n=n) + @PandasSeriesColumnTestBed.parametrize() def test_copy(self, testbed: AbstractColumnTestBed): return super().test_copy(testbed) diff --git a/tests/meerkat/columns/test_tensor_column.py b/tests/meerkat/columns/test_tensor_column.py index ab3bcddc4..ee0d83d52 100644 --- a/tests/meerkat/columns/test_tensor_column.py +++ b/tests/meerkat/columns/test_tensor_column.py @@ -120,6 +120,10 @@ def test_map_return_single_w_kwarg( testbed, batched, materialize=True ) + @TensorColumnTestBed.parametrize(params={"n": [1, 2, 3]}) + def test_concat(self, testbed: AbstractColumnTestBed, n: int): + return super().test_concat(testbed, n=n) + @TensorColumnTestBed.parametrize() def test_copy(self, testbed: AbstractColumnTestBed): return super().test_copy(testbed) diff --git a/tests/meerkat/ops/test_concat.py b/tests/meerkat/ops/test_concat.py index 4494b715c..525c0c73b 100644 --- a/tests/meerkat/ops/test_concat.py +++ b/tests/meerkat/ops/test_concat.py @@ -5,32 +5,13 @@ import pytest from meerkat import concat -from meerkat.columns.image_column import ImageColumn from meerkat.columns.list_column import ListColumn from meerkat.columns.numpy_column import NumpyArrayColumn -from meerkat.columns.pandas_column import PandasSeriesColumn -from meerkat.columns.tensor_column import TensorColumn from meerkat.datapanel import DataPanel from meerkat.errors import ConcatError from meerkat.nn.prediction_column import ClassificationOutputColumn -from ...testbeds import MockColumn, MockDatapanel, MockImageColumn - - -@pytest.mark.parametrize( - "col_type,n", - product( - [ListColumn, NumpyArrayColumn, TensorColumn, PandasSeriesColumn], - [1, 2, 3], - ), -) -def test_column_concat(col_type, n): - mock_col = MockColumn(col_type=col_type) - out = concat([mock_col.col] * n) - - assert len(out) == len(mock_col.visible_rows) * n - assert isinstance(out, col_type) - assert list(out.data) == list(np.concatenate([mock_col.visible_rows] * n)) +from ...testbeds 
import MockDatapanel @pytest.mark.parametrize( @@ -48,7 +29,7 @@ def test_datapanel_row_concat(use_visible_columns, n): assert len(out) == len(mock_dp.visible_rows) * n assert isinstance(out, DataPanel) - assert set(out.visible_columns) == set(mock_dp.visible_columns) + assert set(out.columns) == set(mock_dp.dp.columns) assert (out["a"].data == np.concatenate([mock_dp.visible_rows] * n)).all() assert out["b"].data == list(np.concatenate([mock_dp.visible_rows] * n)) @@ -64,19 +45,18 @@ def test_datapanel_column_concat(): assert len(out) == len(mock_dp.visible_rows) assert isinstance(out, DataPanel) - assert set(out.visible_columns) == {"a", "b", "index"} + assert set(out.columns) == {"a", "b"} assert list(out["a"].data) == out["b"].data -def test_image_column(tmpdir): - - mock = MockImageColumn(length=16, tmpdir=tmpdir) - - out = concat([mock.col.lz[:5], mock.col.lz[5:10]]) +def test_concat_same_columns(): + a = DataPanel.from_batch({"a": [1, 2, 3]}) + b = DataPanel.from_batch({"a": [2, 3, 4]}) - assert len(out) == 10 - assert isinstance(out, ImageColumn) - assert [cell for cell in out.data] == mock.image_paths[:10] + out = concat([a, b], axis="columns", suffixes=["_a", "_b"]) + assert out.columns == ["a_a", "index", "a_b"] + assert list(out["a_a"].data) == [1, 2, 3] + assert list(out["a_b"].data) == [2, 3, 4] def test_concat_different_type(): @@ -115,14 +95,6 @@ def test_concat_different_lengths(): concat([a, b], axis="columns") -def test_concat_same_columns(): - a = DataPanel.from_batch({"a": [1, 2, 3]}) - b = DataPanel.from_batch({"a": [1, 2, 3]}) - - with pytest.raises(ConcatError): - concat([a, b], axis="columns") - - def test_concat_maintains_subclass(): col = ClassificationOutputColumn(logits=[0, 1, 0, 1], num_classes=2) out = concat([col, col]) diff --git a/tests/meerkat/ops/test_merge.py b/tests/meerkat/ops/test_merge.py index 6a18207c5..312614c9a 100644 --- a/tests/meerkat/ops/test_merge.py +++ b/tests/meerkat/ops/test_merge.py @@ -1,10 +1,12 @@ """Unittests for Datasets.""" -from itertools import product +import os +from typing import Dict import numpy as np import pytest import torch +from meerkat.columns.abstract import AbstractColumn from meerkat.columns.image_column import ImageCellColumn, ImageColumn from meerkat.columns.list_column import ListColumn from meerkat.columns.numpy_column import NumpyArrayColumn @@ -13,148 +15,159 @@ from meerkat.errors import MergeError from ...testbeds import MockImageColumn +from ..test_datapanel import DataPanelTestBed -def get_dps( - length1: int, - length2: int, - use_visible_columns: bool = False, - include_image_column: bool = False, - tmpdir: str = None, -): - shuffle1 = np.arange(length1) - batch1 = { - "a": np.arange(length1)[shuffle1], - "b": list(np.arange(length1)[shuffle1]), - "c": [[i] for i in np.arange(length1)[shuffle1]], - "d": (torch.arange(length1) % 3)[shuffle1], - "e": [f"1_{i}" for i in np.arange(length1)[shuffle1]], +class MergeTestBed(DataPanelTestBed): + DEFAULT_CONFIG = { + "lengths": [ + {"left": 12, "right": 16}, + {"left": 16, "right": 16}, + {"left": 16, "right": 12}, + ], + "consolidated": [True, False], } - np.random.seed(1) - shuffle2 = np.random.permutation(np.arange(length2)) - batch2 = { - "a": np.arange(length2)[shuffle2], - "b": list(np.arange(length2)[shuffle2]), - "e": [f"1_{i}" for i in np.arange(length1)[shuffle2]], - "f": (np.arange(length2) % 2)[shuffle2], - } + def __init__( + self, + column_configs: Dict[str, AbstractColumn], + simple: bool = False, + lengths: int = 16, + consolidated: int = 16, + 
+        tmpdir: str = None,
+    ):
+        self.side_to_dp = {}
+        if simple:
+            # TODO (Sabri): do away with the simple testbed, and replace with the full
+            # one after updating support for missing values
+            # https://github.com/robustness-gym/meerkat/issues/123
+            np.random.seed(1)
+            self.side_to_dp["left"] = DataPanel.from_batch(
+                {
+                    "key": np.arange(lengths["left"]),
+                    "b": list(np.arange(lengths["left"])),
+                    "c": [[i] for i in np.arange(lengths["left"])],
+                    "d": (torch.arange(lengths["left"]) % 3),
+                    "e": [f"1_{i}" for i in np.arange(lengths["left"])],
+                }
+            ).lz[np.random.permutation(np.arange(lengths["left"]))]
+
+            self.side_to_dp["right"] = DataPanel.from_batch(
+                {
+                    "key": np.arange(lengths["right"]),
+                    "b": list(np.arange(lengths["right"])),
+                    "e": [f"1_{i}" for i in np.arange(lengths["right"])],
+                    "f": (np.arange(lengths["right"]) % 2),
+                }
+            )
+        else:
+            for side in ["left", "right"]:
+                side_tmpdir = os.path.join(tmpdir, side)
+                os.makedirs(side_tmpdir)
+                column_testbeds = self._build_column_testbeds(
+                    column_configs, length=lengths[side], tmpdir=side_tmpdir
+                )
+                columns = {
+                    name: testbed.col for name, testbed in column_testbeds.items()
+                }
+                dp = DataPanel.from_batch(columns)
+
+                dp["key"] = np.arange(len(dp))
+
+                if consolidated:
+                    dp.consolidate()
+
+                if side == "left":
+                    np.random.seed(1)
+                    dp = dp.lz[np.random.permutation(np.arange(len(dp)))]
+                self.side_to_dp[side] = dp
+
+
+@pytest.fixture
+def testbed(request, tmpdir):
+    config = request.param
+    return MergeTestBed(**config, tmpdir=tmpdir)
+
+
+class TestMerge:
+    @MergeTestBed.parametrize(params={"sort": [True, False]})
+    def test_merge_inner(self, testbed: MergeTestBed, sort):
+        dp1, dp2 = (
+            testbed.side_to_dp["left"],
+            testbed.side_to_dp["right"],
+        )
+
+        out = dp1.merge(
+            dp2,
+            on="key",
+            how="inner",
+            keep_indexes=False,
+            suffixes=("_1", "_2"),
+            sort=sort,
+        )
+
+        assert isinstance(out, DataPanel)
+        assert len(out) == min(len(dp1), len(dp2))
+
+        # check sorted
+        if sort:
+            assert np.all(np.diff(out["key"]) >= 0)
+
+        # assert set(out.columns) == set(expected_columns)
+        for name in dp1.columns:
+            if name in ["key", "index"]:
+                continue
+
+            if isinstance(out[f"{name}_1"], ImageColumn):
+                assert out[f"{name}_1"].__class__ == out[f"{name}_2"].__class__
+                assert out[f"{name}_1"].data.is_equal(
+                    out[f"{name}_2"].data.str.replace("right", "left")
+                )
+            else:
+                assert out[f"{name}_1"].is_equal(out[f"{name}_2"])
+
+    @MergeTestBed.parametrize(config={"simple": [True]}, params={"sort": [True, False]})
+    def test_merge_outer(self, testbed, sort):
+        dp1, dp2 = (
+            testbed.side_to_dp["left"],
+            testbed.side_to_dp["right"],
+        )
+        out = dp1.merge(
+            dp2,
+            on="key",
+            how="outer",
+            keep_indexes=False,
+            suffixes=("_1", "_2"),
+            sort=sort,
+        )
+
+        a1 = set(dp1["key"])
+        a2 = set(dp2["key"])
+
+        assert isinstance(out, DataPanel)
+        assert len(out) == len(a1 | a2)
+
+        # check columns
+        expected_columns = ["key", "b_1", "b_2", "c", "d", "e_1", "e_2", "f", "index"]
+        assert set(out.columns) == set(expected_columns)
+
+        # check sorted
+        if sort:
+            assert np.all(np.diff(out["key"]) >= 0)
+
+        # check for `None` at unmatched rows
+        mask_both = np.where([val in (a1 & a2) for val in out["key"]])[0]
+        mask_1 = np.where([val in (a1 - a2) for val in out["key"]])[0]
+        mask_2 = np.where([val in (a2 - a1) for val in out["key"]])[0]
+        # check for equality at matched rows
+        assert list(out.lz[mask_both]["b_1"]) == list(out.lz[mask_both]["b_2"])
+        # check for `values` at unmatched rows
+        assert set(out.lz[mask_1]["b_1"]) == a1 - a2
+        assert set(out.lz[mask_2]["b_2"]) == a2 - a1
+        # check for `None` at unmatched rows
+        assert list(out.lz[mask_1]["b_2"]) == [None] * len(mask_1)
+        assert list(out.lz[mask_2]["b_1"]) == [None] * len(mask_2)
-    if include_image_column:
-        img_col = MockImageColumn(length=length1, tmpdir=tmpdir).col
-        batch1["img"] = img_col
-        img_col = MockImageColumn(length=length2, tmpdir=tmpdir).col
-        batch2["img"] = img_col.lz[shuffle2]
-
-    visible_columns = ["a", "b", "c"] if use_visible_columns else None
-
-    dps = []
-    for batch, shuffle in [(batch1, shuffle1), (batch2, shuffle2)]:
-        dp = DataPanel.from_batch(batch)
-
-        if use_visible_columns:
-            dp.visible_columns = [c for c in visible_columns if c in dp.all_columns]
-
-        dps.append(dp)
-    return dps[0], dps[1], visible_columns, shuffle1, shuffle2
-
-
-def test_no_on():
-    length = 16
-    # check dictionary not hashable
-    dp1 = DataPanel.from_batch(
-        {
-            "a": ListColumn([{"a": 1}] * length),
-            "b": list(np.arange(length)),
-        }
-    )
-    dp2 = dp1.copy()
-    with pytest.raises(MergeError):
-        dp1.merge(dp2)
-
-
-@pytest.mark.parametrize(
-    "use_visible_columns,diff_length,sort",
-    product([True, False], [True, False], [True, False]),
-)
-def test_merge_inner(use_visible_columns, diff_length, sort):
-    length1 = 16
-    length2 = 12 if diff_length else 16
-    dp1, dp2, visible_columns, shuffle1, shuffle2 = get_dps(
-        length1=length1,
-        length2=length2,
-        use_visible_columns=use_visible_columns,
-    )
-    out = dp1.merge(
-        dp2, on="a", how="inner", keep_indexes=False, suffixes=("_1", "_2"), sort=sort
-    )
-
-    assert isinstance(out, DataPanel)
-    expected_length = min(length1, length2)
-    assert len(out) == expected_length
-
-    # check columns
-    if use_visible_columns:
-        expected_columns = ["a", "b_1", "b_2", "c"]
-    else:
-        expected_columns = ["a", "b_1", "b_2", "c", "d", "e_1", "e_2", "f"]
-
-    # check sorted
-    if sort:
-        assert np.all(np.diff(out["a"]) >= 0)
-
-    assert set(out.columns) == set(expected_columns)
-
-    assert (out["b_1"] == out["b_2"]).all()
-    if not use_visible_columns:
-        assert list(out["e_1"]) == list(out["e_2"])
-
-
-@pytest.mark.parametrize(
-    "use_visible_columns,sort",
-    product([True, False], [True, False]),
-)
-def test_merge_outer(use_visible_columns, sort):
-    dp1, dp2, visible_columns, shuffle1, shuffle2 = get_dps(
-        length1=16,
-        length2=12,
-        use_visible_columns=use_visible_columns,
-    )
-    out = dp1.merge(
-        dp2, on="a", how="outer", keep_indexes=False, suffixes=("_1", "_2"), sort=sort
-    )
-
-    a1 = set(shuffle1)
-    a2 = set(shuffle2)
-
-    assert isinstance(out, DataPanel)
-    assert len(out) == len(a1 | a2)
-
-    # check columns
-    if use_visible_columns:
-        expected_columns = ["a", "b_1", "b_2", "c"]
-    else:
-        expected_columns = ["a", "b_1", "b_2", "c", "d", "e_1", "e_2", "f"]
-    assert set(out.columns) == set(expected_columns)
-
-    # check sorted
-    if sort:
-        assert np.all(np.diff(out["a"]) >= 0)
-
-    # check for `None` at unmatched rows
-    mask_both = np.where([val in (a1 & a2) for val in out["a"]])[0]
-    mask_1 = np.where([val in (a1 - a2) for val in out["a"]])[0]
-    mask_2 = np.where([val in (a2 - a1) for val in out["a"]])[0]
-    # check for equality at matched rows
-    assert list(out.lz[mask_both]["b_1"]) == list(out.lz[mask_both]["b_2"])
-    # check for `values` at unmatched rows
-    assert set(out.lz[mask_1]["b_1"]) == a1 - a2
-    assert set(out.lz[mask_2]["b_2"]) == a2 - a1
-    # check for `None` at unmatched rows
-    assert list(out.lz[mask_1]["b_2"]) == [None] * len(mask_1)
-    assert list(out.lz[mask_2]["b_1"]) == [None] * len(mask_2)
-
-    if not use_visible_columns:
-        # check for `values` at unmatched rows
         assert set(out.lz[mask_1]["e_1"]) == set([f"1_{i}" for i in a1 - a2])
         assert set(out.lz[mask_2]["e_2"]) == set([f"1_{i}" for i in a2 - a1])
@@ -162,254 +175,256 @@ def test_merge_outer(use_visible_columns, sort):
         assert list(out.lz[mask_1]["e_2"]) == [None] * len(mask_1)
         assert list(out.lz[mask_2]["e_1"]) == [None] * len(mask_2)

+    @MergeTestBed.parametrize(config={"simple": [True]}, params={"sort": [True, False]})
+    def test_merge_left(self, testbed, sort):
+        dp1, dp2 = (
+            testbed.side_to_dp["left"],
+            testbed.side_to_dp["right"],
+        )
+        out = dp1.merge(
+            dp2,
+            on="key",
+            how="left",
+            keep_indexes=False,
+            suffixes=("_1", "_2"),
+            sort=sort,
+        )
+
+        a1 = set(dp1["key"])
+        a2 = set(dp2["key"])
+
+        assert isinstance(out, DataPanel)
+        assert len(out) == len(a1)
+
+        # check columns
+        expected_columns = ["key", "b_1", "b_2", "c", "d", "e_1", "e_2", "index", "f"]
+        assert set(out.columns) == set(expected_columns)
+
+        # check sorted
+        if sort:
+            assert np.all(np.diff(out["key"]) >= 0)
+
+        # check for `None` at unmatched rows
+        mask_both = np.where([val in (a1 & a2) for val in out["key"]])[0]
+        mask_1 = np.where([val in (a1 - a2) for val in out["key"]])[0]
+
+        # check for equality at matched rows
+        assert list(out.lz[mask_both]["b_1"]) == list(out.lz[mask_both]["b_2"])
+        # check for `values` at unmatched rows
+        assert set(out.lz[mask_1]["b_1"]) == a1 - a2
+        # check for `None` at unmatched rows
+        assert list(out.lz[mask_1]["b_2"]) == [None] * len(mask_1)
-@pytest.mark.parametrize(
-    "use_visible_columns,sort",
-    product([True, False], [True, False]),
-)
-def test_merge_left(use_visible_columns, sort):
-    dp1, dp2, visible_columns, shuffle1, shuffle2 = get_dps(
-        length1=16,
-        length2=12,
-        use_visible_columns=use_visible_columns,
-    )
-    out = dp1.merge(
-        dp2, on="a", how="left", keep_indexes=False, suffixes=("_1", "_2"), sort=sort
-    )
-
-    a1 = set(shuffle1)
-    a2 = set(shuffle2)
-
-    assert isinstance(out, DataPanel)
-    assert len(out) == len(a1)
-
-    # check columns
-    if use_visible_columns:
-        expected_columns = ["a", "b_1", "b_2", "c"]
-    else:
-        expected_columns = ["a", "b_1", "b_2", "c", "d", "e_1", "e_2", "f"]
-    assert set(out.columns) == set(expected_columns)
-
-    # check sorted
-    if sort:
-        assert np.all(np.diff(out["a"]) >= 0)
-
-    # check for `None` at unmatched rows
-    mask_both = np.where([val in (a1 & a2) for val in out["a"]])[0]
-    mask_1 = np.where([val in (a1 - a2) for val in out["a"]])[0]
-
-    # check for equality at matched rows
-    assert list(out.lz[mask_both]["b_1"]) == list(out.lz[mask_both]["b_2"])
-    # check for `values` at unmatched rows
-    assert set(out.lz[mask_1]["b_1"]) == a1 - a2
-    # check for `None` at unmatched rows
-    assert list(out.lz[mask_1]["b_2"]) == [None] * len(mask_1)
-
-    if not use_visible_columns:
         # check for `values` at unmatched rows
         assert set(out.lz[mask_1]["e_1"]) == set([f"1_{i}" for i in a1 - a2])
         # check for equality at matched rows
         assert list(out.lz[mask_1]["e_2"]) == [None] * len(mask_1)

+    @MergeTestBed.parametrize(config={"simple": [True]}, params={"sort": [True, False]})
+    def test_merge_right(self, testbed, sort):
+        dp1, dp2 = (
+            testbed.side_to_dp["left"],
+            testbed.side_to_dp["right"],
+        )
+        out = dp1.merge(
+            dp2,
+            on="key",
+            how="right",
+            keep_indexes=False,
+            suffixes=("_1", "_2"),
+            sort=sort,
+        )
+
+        a1 = set(dp1["key"])
+        a2 = set(dp2["key"])
+
+        assert isinstance(out, DataPanel)
+        assert len(out) == len(a2)
+
+        # check columns
+        expected_columns = ["key", "b_1", "b_2", "c", "d", "e_1", "e_2", "f", "index"]
+        assert set(out.columns) == set(expected_columns)
+
+        # check sorted
+        if sort:
+            assert np.all(np.diff(out["key"]) >= 0)
+
+        # check for `None` at unmatched rows
+        mask_both = np.where([val in (a1 & a2) for val in out["key"]])[0]
+        mask_2 = np.where([val in (a2 - a1) for val in out["key"]])[0]
+        # check for equality at matched rows
+        assert list(out.lz[mask_both]["b_1"]) == list(out.lz[mask_both]["b_2"])
+        # check for `values` at unmatched rows
+        assert set(out.lz[mask_2]["b_2"]) == a2 - a1
+        # check for `None` at unmatched rows
+        assert list(out.lz[mask_2]["b_1"]) == [None] * len(mask_2)
-@pytest.mark.parametrize(
-    "use_visible_columns,sort",
-    product([True, False], [True, False]),
-)
-def test_merge_right(use_visible_columns, sort):
-    dp1, dp2, visible_columns, shuffle1, shuffle2 = get_dps(
-        length1=16,
-        length2=12,
-        use_visible_columns=use_visible_columns,
-    )
-    out = dp1.merge(
-        dp2, on="a", how="right", keep_indexes=False, suffixes=("_1", "_2"), sort=sort
-    )
-
-    a1 = set(shuffle1)
-    a2 = set(shuffle2)
-
-    assert isinstance(out, DataPanel)
-    assert len(out) == len(a2)
-
-    # check columns
-    if use_visible_columns:
-        expected_columns = ["a", "b_1", "b_2", "c"]
-    else:
-        expected_columns = ["a", "b_1", "b_2", "c", "d", "e_1", "e_2", "f"]
-    assert set(out.columns) == set(expected_columns)
-
-    # check sorted
-    if sort:
-        assert np.all(np.diff(out["a"]) >= 0)
-
-    # check for `None` at unmatched rows
-    mask_both = np.where([val in (a1 & a2) for val in out["a"]])[0]
-    mask_2 = np.where([val in (a2 - a1) for val in out["a"]])[0]
-    # check for equality at matched rows
-    assert list(out.lz[mask_both]["b_1"]) == list(out.lz[mask_both]["b_2"])
-    # check for `values` at unmatched rows
-    assert set(out.lz[mask_2]["b_2"]) == a2 - a1
-    # check for `None` at unmatched rows
-    assert list(out.lz[mask_2]["b_1"]) == [None] * len(mask_2)
-
-    if not use_visible_columns:
         # check for `values` at unmatched rows
         assert set(out.lz[mask_2]["e_2"]) == set([f"1_{i}" for i in a2 - a1])
        # check for equality at matched rows
         assert list(out.lz[mask_2]["e_1"]) == [None] * len(mask_2)
-
-def test_merge_output_column_types():
-    dp1 = DataPanel.from_batch({"a": np.arange(3), "b": ListColumn(["1", "2", "3"])})
-    dp2 = dp1.copy()
-
-    out = dp1.merge(dp2, on="b", how="inner")
-    assert isinstance(out["b"], ListColumn)
-
-
-def test_image_merge(tmpdir):
-    length = 16
-    img_col_test_bed = MockImageColumn(length=length, tmpdir=tmpdir)
-    dp1 = DataPanel.from_batch(
-        {
-            "a": np.arange(length),
-            "img": img_col_test_bed.col,
-        }
-    )
-    rows = np.arange(4, 8)
-    dp2 = DataPanel.from_batch(
-        {
-            "a": rows,
-        }
-    )
-
-    out = dp1.merge(dp2, on="a", how="inner")
-    assert isinstance(out["img"], ImageColumn)
-    assert [str(fp) for fp in out["img"].data] == [
-        img_col_test_bed.image_paths[row] for row in rows
-    ]
-
-
-def test_cell_merge(tmpdir):
-    length = 16
-    img_col_test_bed = MockImageColumn(
-        length=length, tmpdir=tmpdir, use_cell_column=True
-    )
-    dp1 = DataPanel.from_batch(
-        {
-            "a": np.arange(length),
-            "img": img_col_test_bed.col,
-        }
-    )
-    rows = np.arange(4, 8)
-    dp2 = DataPanel.from_batch(
-        {
-            "a": rows,
-        }
-    )
-
-    out = dp1.merge(dp2, on="a", how="inner")
-    assert isinstance(out["img"], ImageCellColumn)
-    assert [str(cell.filepath) for cell in out["img"].data] == [
-        img_col_test_bed.image_paths[row] for row in rows
-    ]
-
-
-def test_cell_merge_names(tmpdir):
-    length = 16
-    img_col_test_bed = MockImageColumn(
-        length=length, tmpdir=tmpdir, use_cell_column=True
-    )
-    dp1 = DataPanel.from_batch(
-        {
-            "dicom_id": np.arange(length),
-            "dicom": img_col_test_bed.col,
-        }
-    )
-    rows = np.arange(4, 8)
-    dp2 = DataPanel.from_batch(
-        {
-            "dicom_id": rows,
-        }
-    )
-
-    out = dp1.merge(dp2, on="dicom_id", how="inner")
-    assert isinstance(out["dicom"], ImageCellColumn)
-    assert [str(cell.filepath) for cell in out["dicom"].data] == [
-        img_col_test_bed.image_paths[row] for row in rows
-    ]
-
-
-def test_check_merge_columns():
-    length = 16
-    # check dictionary not hashable
-    dp1 = DataPanel.from_batch(
-        {
-            "a": ListColumn([{"a": 1}] * length),
-            "b": list(np.arange(length)),
-        }
-    )
-    dp2 = dp1.copy()
-    with pytest.raises(MergeError):
-        dp1.merge(dp2, on=["a"])
-
-    # check multi-on
-    with pytest.raises(MergeError):
-        dp1.merge(dp2, on=["a", "b"])
-
-    # check multi-dimensional numpy array
-    dp1 = DataPanel.from_batch(
-        {
-            "a": NumpyArrayColumn(np.stack([np.arange(5)] * length)),
-            "b": list(np.arange(length)),
-        }
-    )
-    dp2 = dp1.copy()
-    with pytest.raises(MergeError):
-        dp1.merge(dp2, on="a")
-
-    # check multi-dimensional numpy array
-    dp1 = DataPanel.from_batch(
-        {
-            "a": TensorColumn(torch.stack([torch.arange(5)] * length)),
-            "b": list(np.arange(length)),
-        }
-    )
-    dp2 = dp1.copy()
-    with pytest.raises(MergeError):
-        dp1.merge(dp2, on="a")
-
-    # checks that **all** cells are hashable (not just the first)
-    dp1 = DataPanel.from_batch(
-        {
-            "a": ListColumn(["hello"] + [{"a": 1}] * (length - 1)),
-            "b": list(np.arange(length)),
-        }
-    )
-    dp2 = dp1.copy()
-    with pytest.raises(MergeError):
-        dp1.merge(dp2, on="a")
-
-    # checks if Cells in cell columns are NOT hashable
-    dp1 = DataPanel.from_batch(
-        {
-            "a": ImageCellColumn.from_filepaths(["a"] * length),
-            "b": list(np.arange(length)),
-        }
-    )
-    dp2 = dp1.copy()
-    with pytest.raises(MergeError):
-        dp1.merge(dp2, on="a")
-
-    # checks that having a column called __right_indices__ raises a merge error
-    dp1 = DataPanel.from_batch(
-        {
-            "a": ListColumn(["hello"] + [{"a": 1}] * (length - 1)),
-            "b": list(np.arange(length)),
-            "__right_indices__": list(np.arange(length)),
-        }
-    )
-    dp2 = dp1.copy()
-    with pytest.raises(MergeError):
-        dp1.merge(dp2, on="__right_indices__")
+    def test_merge_output_column_types(self):
+        dp1 = DataPanel.from_batch(
+            {"a": np.arange(3), "b": ListColumn(["1", "2", "3"])}
+        )
+        dp2 = dp1.copy()
+
+        out = dp1.merge(dp2, on="b", how="inner")
+        assert isinstance(out["b"], ListColumn)
+
+    def test_image_merge(self, tmpdir):
+        length = 16
+        img_col_test_bed = MockImageColumn(length=length, tmpdir=tmpdir)
+        dp1 = DataPanel.from_batch(
+            {
+                "a": np.arange(length),
+                "img": img_col_test_bed.col,
+            }
+        )
+        rows = np.arange(4, 8)
+        dp2 = DataPanel.from_batch(
+            {
+                "a": rows,
+            }
+        )
+
+        out = dp1.merge(dp2, on="a", how="inner")
+        assert isinstance(out["img"], ImageColumn)
+        assert [str(fp) for fp in out["img"].data] == [
+            img_col_test_bed.image_paths[row] for row in rows
+        ]
+
+    def test_cell_merge(self, tmpdir):
+        length = 16
+        img_col_test_bed = MockImageColumn(
+            length=length, tmpdir=tmpdir, use_cell_column=True
+        )
+        dp1 = DataPanel.from_batch(
+            {
+                "a": np.arange(length),
+                "img": img_col_test_bed.col,
+            }
+        )
+        rows = np.arange(4, 8)
+        dp2 = DataPanel.from_batch(
+            {
+                "a": rows,
+            }
+        )
+
+        out = dp1.merge(dp2, on="a", how="inner")
+        assert isinstance(out["img"], ImageCellColumn)
+        assert [str(cell.filepath) for cell in out["img"].data] == [
+            img_col_test_bed.image_paths[row] for row in rows
+        ]
+
+    def test_cell_merge_names(self, tmpdir):
+        length = 16
+        img_col_test_bed = MockImageColumn(
+            length=length, tmpdir=tmpdir, use_cell_column=True
+        )
+        dp1 = DataPanel.from_batch(
+            {
+                "dicom_id": np.arange(length),
+                "dicom": img_col_test_bed.col,
+            }
+        )
+        rows = np.arange(4, 8)
+        dp2 = DataPanel.from_batch(
+            {
+                "dicom_id": rows,
+            }
+        )
+
+        out = dp1.merge(dp2, on="dicom_id", how="inner")
+        assert isinstance(out["dicom"], ImageCellColumn)
+        assert [str(cell.filepath) for cell in out["dicom"].data] == [
+            img_col_test_bed.image_paths[row] for row in rows
+        ]
+
+    def test_no_on(self):
+        length = 16
+        # check dictionary not hashable
+        dp1 = DataPanel.from_batch(
+            {
+                "a": ListColumn([{"a": 1}] * length),
+                "b": list(np.arange(length)),
+            }
+        )
+        dp2 = dp1.copy()
+        with pytest.raises(MergeError):
+            dp1.merge(dp2)
+
+    def test_check_merge_columns(self):
+        length = 16
+        # check dictionary not hashable
+        dp1 = DataPanel.from_batch(
+            {
+                "a": ListColumn([{"a": 1}] * length),
+                "b": list(np.arange(length)),
+            }
+        )
+        dp2 = dp1.copy()
+        with pytest.raises(MergeError):
+            dp1.merge(dp2, on=["a"])
+
+        # check multi-on
+        with pytest.raises(MergeError):
+            dp1.merge(dp2, on=["a", "b"])
+
+        # check multi-dimensional numpy array
+        dp1 = DataPanel.from_batch(
+            {
+                "a": NumpyArrayColumn(np.stack([np.arange(5)] * length)),
+                "b": list(np.arange(length)),
+            }
+        )
+        dp2 = dp1.copy()
+        with pytest.raises(MergeError):
+            dp1.merge(dp2, on="a")
+
+        # check multi-dimensional tensor column
+        dp1 = DataPanel.from_batch(
+            {
+                "a": TensorColumn(torch.stack([torch.arange(5)] * length)),
+                "b": list(np.arange(length)),
+            }
+        )
+        dp2 = dp1.copy()
+        with pytest.raises(MergeError):
+            dp1.merge(dp2, on="a")
+
+        # checks that **all** cells are hashable (not just the first)
+        dp1 = DataPanel.from_batch(
+            {
+                "a": ListColumn(["hello"] + [{"a": 1}] * (length - 1)),
+                "b": list(np.arange(length)),
+            }
+        )
+        dp2 = dp1.copy()
+        with pytest.raises(MergeError):
+            dp1.merge(dp2, on="a")
+
+        # checks if Cells in cell columns are NOT hashable
+        dp1 = DataPanel.from_batch(
+            {
+                "a": ImageCellColumn.from_filepaths(["a"] * length),
+                "b": list(np.arange(length)),
+            }
+        )
+        dp2 = dp1.copy()
+        with pytest.raises(MergeError):
+            dp1.merge(dp2, on="a")
+
+        # checks that having a column called __right_indices__ raises a merge error
+        dp1 = DataPanel.from_batch(
+            {
+                "a": ListColumn(["hello"] + [{"a": 1}] * (length - 1)),
+                "b": list(np.arange(length)),
+                "__right_indices__": list(np.arange(length)),
+            }
+        )
+        dp2 = dp1.copy()
+        with pytest.raises(MergeError):
+            dp1.merge(dp2, on="__right_indices__")
diff --git a/tests/meerkat/pipelines/test_entitydatapanel.py b/tests/meerkat/pipelines/test_entitydatapanel.py
index ab0b2a032..ea92785e6 100644
--- a/tests/meerkat/pipelines/test_entitydatapanel.py
+++ b/tests/meerkat/pipelines/test_entitydatapanel.py
@@ -118,7 +118,7 @@ def test_find_similar(k):
     sim = ent.most_similar("x", k)
     assert isinstance(sim, EntityDataPanel)
     assert len(sim) == k
-    assert sim.column_names == ["a", "b", "c", "d", "e", "f", "g", "index"]
+    assert sim.columns == ["a", "b", "c", "d", "e", "f", "g", "index"]


 @pytest.mark.parametrize(
@@ -150,7 +150,7 @@ def test_find_similar_multiple_columns(k):
     )
     assert isinstance(sim, EntityDataPanel)
     assert len(sim) == k
-    assert sim.column_names == ["a", "b", "c", "d", "e", "f", "g", "index", "embs2"]
+    assert sim.columns == ["a", "b", "c", "d", "e", "f", "g", "index", "embs2"]


 def test_convert_entities_to_ids():
@@ -202,7 +202,7 @@ def test_append_entities():
         "g": data["g"],
         "h": [3, 4, 5],
     }
-    assert ent3.column_names == ["a", "b", "c", "d", "e", "f", "g", "index", "h"]
+    assert ent3.columns == ["a", "b", "c", "d", "e", "f", "index", "g", "h"]
     assert ent3["h"].tolist() == gold_data["h"]
     assert ent3["c"]._data == gold_data["c"]
     assert ent3._index_column == "c"
@@ -273,8 +273,8 @@ def test_merge_entities():
         "g_y": data["g"],
         "h": [3, 4, 5],
     }
-    # import pdb; pdb.set_trace()
-    assert set(ent3.all_columns) == set(
+
+    assert set(ent3.columns) == set(
         [
             "a",
             "b",
@@ -316,20 +316,22 @@ def test_merge_entities():
         "g_y": data["g"],
         "h": ["x", "y", "z"],
     }
-    # import pdb; pdb.set_trace()
-    assert ent3.column_names == [
-        "c_x",
-        "h",
-        "g_x",
-        "a",
-        "b",
-        "c_y",
-        "d",
-        "e",
-        "f",
-        "g_y",
-        "index",
-    ]
+
+    assert set(ent3.columns) == set(
+        [
+            "c_x",
+            "h",
+            "index",
+            "g_x",
+            "a",
+            "b",
+            "c_y",
+            "d",
+            "e",
+            "f",
+            "g_y",
+        ]
+    )
     assert ent3._index_column == "c_x"
     assert ent3._embedding_columns == ["g_x", "g_y"]
     for c in ["a", "b", "c_x", "c_y", "d", "e", "f", "g_x", "g_y", "h"]:
@@ -359,20 +361,22 @@ def test_merge_entities():
         "g_y": data["g"],
         "h": [3, 4, 5],
     }
-    # import pdb; pdb.set_trace()
-    assert ent3.column_names == [
-        "a",
-        "b",
-        "c",
-        "d_x",
-        "e",
-        "f",
-        "g_x",
-        "d_y",
-        "h",
-        "g_y",
-        "index",
-    ]
+
+    assert set(ent3.columns) == set(
+        [
+            "a",
+            "b",
+            "c",
+            "d_x",
+            "e",
+            "f",
+            "g_x",
+            "d_y",
+            "h",
+            "g_y",
+            "index",
+        ]
+    )
     assert ent3._index_column == "c"
     assert ent3._embedding_columns == ["g_x", "g_y"]
     for c in ["a", "b", "c", "d_x", "d_y", "e", "f", "g_x", "g_y", "h"]:
@@ -406,19 +410,21 @@ def test_merge_entities():
         ),
         "c_y": [3, 5],
     }
-    assert ent3.column_names == [
-        "a",
-        "b",
-        "c_x",
-        "d_x",
-        "e",
-        "f",
-        "g",
-        "d_y",
-        "c_y",
-        "i",
-        "index",
-    ]
+    assert set(ent3.columns) == set(
+        [
+            "a",
+            "b",
+            "c_x",
+            "d_x",
+            "e",
+            "f",
+            "g",
+            "d_y",
+            "c_y",
+            "i",
+            "index",
+        ]
+    )
     assert ent3._index_column == "c_x"
     assert ent3._embedding_columns == ["g", "i"]
     for c in ["a", "b", "c_x", "c_y", "d_x", "d_y", "e", "f", "g", "i"]:
diff --git a/tests/meerkat/test_datapanel.py b/tests/meerkat/test_datapanel.py
index d47063db6..f87dc5222 100644
--- a/tests/meerkat/test_datapanel.py
+++ b/tests/meerkat/test_datapanel.py
@@ -12,6 +12,7 @@
 import ujson as json

 from meerkat import NumpyArrayColumn
+from meerkat.block.manager import BlockManager
 from meerkat.columns.abstract import AbstractColumn
 from meerkat.columns.list_column import ListColumn
 from meerkat.columns.pandas_column import PandasSeriesColumn
@@ -21,6 +22,7 @@
 from .columns.test_image_column import ImageColumnTestBed
 from .columns.test_numpy_column import NumpyArrayColumnTestBed
 from .columns.test_pandas_column import PandasSeriesColumnTestBed
+from .columns.test_tensor_column import TensorColumnTestBed


 class DataPanelTestBed:
@@ -32,6 +34,7 @@ class DataPanelTestBed:
     DEFAULT_COLUMN_CONFIGS = {
         "np": {"testbed_class": NumpyArrayColumnTestBed, "n": 2},
         "pd": {"testbed_class": PandasSeriesColumnTestBed, "n": 2},
+        "torch": {"testbed_class": TensorColumnTestBed, "n": 2},
         "img": {"testbed_class": ImageColumnTestBed, "n": 2},
     }

@@ -42,16 +45,30 @@ def __init__(
         length: int = 16,
         tmpdir: str = None,
     ):
-        self.column_testbeds = {}
+        self.column_testbeds = self._build_column_testbeds(
+            column_configs, length=length, tmpdir=tmpdir
+        )
+
+        self.columns = {
+            name: testbed.col for name, testbed in self.column_testbeds.items()
+        }
+        self.dp = DataPanel.from_batch(self.columns)
+        if consolidated:
+            self.dp.consolidate()
+
+    def _build_column_testbeds(
+        self, column_configs: Dict[str, AbstractColumn], length: int, tmpdir: str
+    ):
         def _get_tmpdir(name):
             path = os.path.join(tmpdir, name)
             os.makedirs(path)
             return path

+        column_testbeds = {}
         for name, config in column_configs.items():
             params = config["testbed_class"].get_params(**config.get("kwargs", {}))
-            self.column_testbeds.update(
+            column_testbeds.update(
                 {
                     f"{name}_{col_id}_{idx}": config["testbed_class"](
                         **col_config[1],
@@ -63,14 +80,7 @@ def _get_tmpdir(name):
                     for col_config, col_id in zip(params["argvalues"], params["ids"])
                 }
             )
-
-        self.columns = {
-            name: testbed.col for name, testbed in self.column_testbeds.items()
-        }
-        self.dp = DataPanel.from_batch(self.columns)
-
-        if consolidated:
-            self.dp.consolidate()
+        return column_testbeds

     @classmethod
     def get_params(
@@ -112,9 +122,17 @@ def get_params(

     @classmethod
     @wraps(pytest.mark.parametrize)
-    def parametrize(cls, config: dict = None, params: dict = None):
+    def parametrize(
+        cls,
+        config: dict = None,
+        column_configs: Sequence[Dict] = None,
+        params: dict = None,
+    ):
         return pytest.mark.parametrize(
-            **cls.get_params(config=config, params=params), indirect=["testbed"]
+            **cls.get_params(
+                config=config, params=params, column_configs=column_configs
+            ),
+            indirect=["testbed"],
         )

@@ -177,8 +195,8 @@ def test_row_index_single(self, testbed):
         params={
             "index_type": [
                 np.array,
-                # pd.Series,
-                # torch.Tensor,
+                pd.Series,
+                torch.Tensor,
                 NumpyArrayColumn,
                 PandasSeriesColumn,
                 TensorColumn,
@@ -189,16 +207,32 @@ def test_row_index_multiple(self, testbed, index_type):
         dp = testbed.dp
         rows = np.arange(len(dp))

+        def convert_to_index_type(index, dtype):
+            index = index_type(index)
+            if index_type == torch.Tensor:
+                return index.to(dtype)
+            return index
+
         # slice index => multiple row selection (DataPanel)
         # tuple or list index => multiple row selection (DataPanel)
         # np.array indeex => multiple row selection (DataPanel)
         for rows, indices in (
             (dp[1:3], rows[1:3]),
             (dp[[0, 2]], rows[[0, 2]]),
-            (dp[index_type(np.array((0,)))], rows[np.array((0,))]),
-            (dp[index_type(np.array((1, 1)))], rows[np.array((1, 1))]),
             (
-                dp[index_type(np.array((True, False) * (len(dp) // 2)))],
+                dp[convert_to_index_type(np.array((0,)), dtype=int)],
+                rows[np.array((0,))],
+            ),
+            (
+                dp[convert_to_index_type(np.array((1, 1)), dtype=int)],
+                rows[np.array((1, 1))],
+            ),
+            (
+                dp[
+                    convert_to_index_type(
+                        np.array((True, False) * (len(dp) // 2)), dtype=bool
+                    )
+                ],
                 rows[np.array((True, False) * (len(dp) // 2))],
             ),
         ):
@@ -240,8 +274,8 @@ def test_row_lz_index_single(self, testbed):
         params={
             "index_type": [
                 np.array,
-                # pd.Series,
-                # torch.Tensor,
+                pd.Series,
+                torch.Tensor,
                 NumpyArrayColumn,
                 PandasSeriesColumn,
                 TensorColumn,
@@ -252,16 +286,32 @@ def test_row_lz_index_multiple(self, testbed, index_type):
         dp = testbed.dp
         rows = np.arange(len(dp))

+        def convert_to_index_type(index, dtype):
+            index = index_type(index)
+            if index_type == torch.Tensor:
+                return index.to(dtype)
+            return index
+
         # slice index => multiple row selection (DataPanel)
         # tuple or list index => multiple row selection (DataPanel)
         # np.array indeex => multiple row selection (DataPanel)
         for rows, indices in (
             (dp.lz[1:3], rows[1:3]),
             (dp.lz[[0, 2]], rows[[0, 2]]),
-            (dp.lz[index_type(np.array((0,)))], rows[np.array((0,))]),
-            (dp.lz[index_type(np.array((1, 1)))], rows[np.array((1, 1))]),
             (
-                dp.lz[index_type(np.array((True, False) * (len(dp) // 2)))],
+                dp.lz[convert_to_index_type(np.array((0,)), dtype=int)],
+                rows[np.array((0,))],
+            ),
+            (
+                dp.lz[convert_to_index_type(np.array((1, 1)), dtype=int)],
+                rows[np.array((1, 1))],
+            ),
+            (
+                dp.lz[
+                    convert_to_index_type(
+                        np.array((True, False) * (len(dp) // 2)), dtype=bool
+                    )
+                ],
                 rows[np.array((True, False) * (len(dp) // 2))],
             ),
         ):
@@ -281,6 +331,42 @@ def test_row_lz_index_multiple(self, testbed, index_type):
             if value.__class__ == dp[key].__class__:
                 assert dp[key]._clone(data=data).is_equal(value)

+    @DataPanelTestBed.parametrize()
+    def test_invalid_indices(self, testbed):
+        dp = testbed.dp
+        index = ["nonexistent_column"]
+        missing_cols = set(index) - set(dp.columns)
+        with pytest.raises(
+            KeyError, match=f"DataPanel does not have columns {missing_cols}"
+        ):
+            dp[index]
+
+        dp = testbed.dp
+        index = "nonexistent_column"
+        with pytest.raises(KeyError, match=f"Column `{index}` does not exist."):
+            dp[index]
+
+        dp = testbed.dp
+        index = np.zeros((len(dp), 10))
+        with pytest.raises(
+            ValueError, match="Index must have 1 axis, not {}".format(len(index.shape))
+        ):
+            dp[index]
+
+        dp = testbed.dp
+        index = torch.zeros((len(dp), 10))
+        with pytest.raises(
+            ValueError, match="Index must have 1 axis, not {}".format(len(index.shape))
+        ):
+            dp[index]
+
+        dp = testbed.dp
+        index = {"a": 1}
+        with pytest.raises(
+            TypeError, match="Invalid index type: {}".format(type(index))
+        ):
+            dp[index]
+
     @DataPanelTestBed.parametrize()
     def test_col_indexing_view_copy_semantics(self, testbed):
         dp = testbed.dp
@@ -397,6 +483,18 @@ def func(x):
         for key, map_spec in map_specs.items():
             assert result[key].is_equal(map_spec["expected_result"])

+    @DataPanelTestBed.parametrize(
+        column_configs={"img": {"testbed_class": ImageColumnTestBed, "n": 2}},
+        params={"batched": [True, False], "materialize": [True, False]},
+    )
+    def test_map_return_multiple_img_only(
+        self, testbed: DataPanelTestBed, batched: bool, materialize: bool
+    ):
+        testbed.dp.remove_column("index")
+        self.test_map_return_multiple(
+            testbed=testbed, batched=batched, materialize=materialize
+        )
+
     @DataPanelTestBed.parametrize(
         params={
             "batched": [True, False],
@@ -601,7 +699,7 @@ def test_append_columns(self):
         assert len(out) == len(dp) * 2
         assert isinstance(out, DataPanel)
-        assert set(out.visible_columns) == set(dp.visible_columns)
+        assert set(out.columns) == set(dp.columns)

         assert (out["a"].data == np.concatenate([np.arange(length)] * 2)).all()
         assert out["b"].data == list(np.concatenate([np.arange(length)] * 2))
@@ -612,7 +710,7 @@ def test_tail(self, testbed):

         new_dp = dp.tail(n=2)

         assert isinstance(new_dp, DataPanel)
-        assert new_dp.visible_columns == dp.visible_columns
+        assert new_dp.columns == dp.columns
         assert len(new_dp) == 2

     @DataPanelTestBed.parametrize()
@@ -622,7 +720,7 @@ def test_head(self, testbed):

         new_dp = dp.head(n=2)

         assert isinstance(new_dp, DataPanel)
-        assert new_dp.visible_columns == dp.visible_columns
+        assert new_dp.columns == dp.columns
         assert len(new_dp) == 2

     class DataPanelSubclass(DataPanel):
@@ -657,7 +755,7 @@ def test_from_csv(self):
             pd.DataFrame(data).to_csv(temp_f.name)

             dp_new = DataPanel.from_csv(temp_f.name)
-            assert dp_new.column_names == ["Unnamed: 0", "a", "b", "c", "index"]
+            assert dp_new.columns == ["Unnamed: 0", "a", "b", "c", "index"]

             # Skip index column
             for k in data:
                 if isinstance(dp_new[k], PandasSeriesColumn):
@@ -680,7 +778,7 @@ def test_from_jsonl(self):
                 out_f.write(json.dumps(to_write) + "\n")

             dp_new = DataPanel.from_jsonl(temp_f.name)
-            assert dp_new.column_names == ["a", "b", "c", "index"]
+            assert dp_new.columns == ["a", "b", "c", "index"]
             # Skip index column
             for k in data:
                 if isinstance(dp_new[k], NumpyArrayColumn):
@@ -702,7 +800,7 @@ def test_from_batch(self):
                 "f": np.ones(3),
             },
         )
-        assert set(datapanel.column_names) == {"a", "b", "c", "d", "e", "f", "index"}
+        assert set(datapanel.columns) == {"a", "b", "c", "d", "e", "f", "index"}
         assert len(datapanel) == 3

     def test_to_pandas(self):
@@ -724,7 +822,7 @@ def test_to_pandas(self):

         df = dp.to_pandas()
         assert isinstance(df, pd.DataFrame)
-        assert list(df.columns) == dp.visible_columns
+        assert list(df.columns) == dp.columns
         assert len(df) == len(dp)

         assert (df["a"].values == dp["a"].data).all()
@@ -734,3 +832,92 @@ def test_to_pandas(self):
         assert (df["d"].values == dp["d"].numpy()).all()

         assert (df["e"].values == dp["e"].values).all()
+
+    def test_constructor(self):
+        length = 16
+
+        # from dictionary
+        data = {
+            "a": np.arange(length),
+            "b": ListColumn(np.arange(length)),
+        }
+        dp = DataPanel(data=data)
+        assert len(dp) == length
+        assert dp["a"].is_equal(NumpyArrayColumn(np.arange(length)))
+
+        # from BlockManager
+        mgr = BlockManager.from_dict(data)
+        dp = DataPanel(data=mgr)
+        assert len(dp) == length
+        assert dp["a"].is_equal(NumpyArrayColumn(np.arange(length)))
+        assert dp.columns == ["a", "b", "index"]
+
+        # from list of dictionaries
+        data = [{"a": idx, "b": str(idx)} for idx in range(length)]
+        dp = DataPanel(data=data)
+        assert len(dp) == length
+        assert dp["a"].is_equal(NumpyArrayColumn(np.arange(length)))
+        assert dp.columns == ["a", "b", "index"]
+
+        # from nothing
+        dp = DataPanel()
+        assert len(dp) == 0
+
+    def test_constructor_w_invalid_data(self):
+        with pytest.raises(
+            ValueError,
+            match=f"Cannot set DataPanel `data` to object of type {type(5)}.",
+        ):
+            DataPanel(data=5)
+
+    def test_constructor_w_invalid_sequence(self):
+        data = list(range(4))
+        with pytest.raises(
+            ValueError,
+            match="Cannot set DataPanel `data` to a Sequence containing object of "
+            f" type {type(data[0])}. Must be a Sequence of Mapping.",
+        ):
+            DataPanel(data=data)
+
+    def test_constructor_w_unequal_lengths(self):
+        length = 16
+        data = {
+            "a": np.arange(length),
+            "b": ListColumn(np.arange(length - 1)),
+        }
+        with pytest.raises(
+            ValueError,
+            match=(
+                f"Cannot add column 'b' with length {length - 1} to `BlockManager` "
+                f" with length {length} columns."
+            ),
+        ):
+            DataPanel(data=data)
+
+    def test_shape(self):
+        length = 16
+        data = {
+            "a": np.arange(length),
+            "b": ListColumn(np.arange(length)),
+        }
+        dp = DataPanel(data)
+        assert dp.shape == (16, 3)
+
+    @DataPanelTestBed.parametrize()
+    def test_streamlit(self, testbed):
+        testbed.dp.streamlit()
+
+    @DataPanelTestBed.parametrize()
+    def test_str(self, testbed):
+        result = str(testbed.dp)
+        assert isinstance(result, str)
+
+    @DataPanelTestBed.parametrize()
+    def test_repr(self, testbed):
+        result = repr(testbed.dp)
+        assert isinstance(result, str)
+
+    @DataPanelTestBed.parametrize()
+    def test_repr_pandas(self, testbed):
+        df = testbed.dp._repr_pandas_()
+        assert isinstance(df, pd.DataFrame)
diff --git a/tests/meerkat/test_provenance.py b/tests/meerkat/test_provenance.py
index ea46ad8b5..0ff2ba263 100644
--- a/tests/meerkat/test_provenance.py
+++ b/tests/meerkat/test_provenance.py
@@ -5,8 +5,9 @@
 import meerkat as mk
 from meerkat.datapanel import DataPanel

-from meerkat.provenance import (  # ProvenanceOpNode,
+from meerkat.provenance import (
     ProvenanceObjNode,
+    ProvenanceOpNode,
     capture_provenance,
     provenance,
 )
@@ -24,46 +25,22 @@ def test_obj_del():
     assert node.ref() is None


-# def test_from_batch():
-#
-#     with provenance():
-#         dp = DataPanel.from_batch(
-#             {
-#                 "x": np.arange(4),
-#             }
-#         )
-#     assert isinstance(dp.node.last_parent[0], ProvenanceOpNode)
-#     assert dp.node.last_parent[0].name == "DataPanel.from_batch"
-#     assert dp.node.last_parent[0]
-
-
-def test_from_batch_no_provenance():
-
-    with provenance(enabled=False):
-        dp = DataPanel.from_batch(
+def test_map():
+    with provenance():
+        dp1 = DataPanel.from_batch(
             {
                 "x": np.arange(4),
             }
         )
-    assert dp.node.last_parent is None
-
-
-# def test_map():
-#     with provenance():
-#         dp1 = DataPanel.from_batch(
-#             {
-#                 "x": np.arange(4),
-#             }
-#         )
-#
-#     dp2 = dp1.map(lambda x: {"z": x["x"] + 1}, batched=True, batch_size=2)
-#
-#     assert isinstance(dp2.node.last_parent[0], ProvenanceOpNode)
-#     assert dp2.node.last_parent[1] == tuple()
-#     assert dp2.node.last_parent[0].name == "MappableMixin.map"
-#     assert isinstance(dp1.node.last_parent[0], ProvenanceOpNode)
-#     assert dp1.node.children[-1][1] == ("self",)
-#     assert dp1.node.children[-1][0].name == "MappableMixin.map"
+
+    dp2 = dp1.map(lambda x: {"z": x["x"] + 1}, is_batched_fn=True, batch_size=2)
+
+    assert isinstance(dp2.node.last_parent[0], ProvenanceOpNode)
+    assert dp2.node.last_parent[1] == tuple()
+    assert dp2.node.last_parent[0].name == "DataPanel.map"
+    assert isinstance(dp1.node.last_parent[0], ProvenanceOpNode)
+    assert dp1.node.children[-1][1] == ("self",)
+    assert dp1.node.children[-1][0].name == "DataPanel.map"


 @capture_provenance(capture_args=["x"])
@@ -72,80 +49,100 @@ def custom_fn(dp1, dp2, x):
     return {"dp": dp3, "x": x}, dp2


-# def test_custom_fn():
-#
-#     with provenance():
-#         dp1 = DataPanel.from_batch(
-#             {
-#                 "x": np.arange(4),
-#             }
-#         )
-#         dp2 = DataPanel.from_batch(
-#             {
-#                 "y": np.arange(4),
-#             }
-#         )
-#         d, _ = custom_fn(dp1, dp2, x="abc")
-#         dp3 = d["dp"]
-#
-#     assert isinstance(dp3.node.last_parent[0], ProvenanceOpNode)
-#     assert dp3.node.last_parent[0].name == "custom_fn"
-#     assert dp3.node.last_parent[1] == (0, "dp")
-#
-#     assert isinstance(dp1.node.last_parent[0], ProvenanceOpNode)
-#     assert dp1.node.children[-1][0].name == "custom_fn"
-#     assert dp1.node.children[-1][1] == ("dp1",)
-#
-#     assert isinstance(dp2.node.last_parent[0], ProvenanceOpNode)
-#     assert dp2.node.children[-1][0].name == "custom_fn"
-#     assert dp2.node.children[-1][1] == ("dp2",)
-#
-#     custom_op = dp3.node.last_parent[0]
-#     # test that dp2, which was passed in to `custom_fn`, is not in children
-#     assert custom_op.children == [
-#         (dp3.node, (0, "dp")),
-#         *[(dp3[key].node, (0, "dp", key)) for key in dp3.keys()],
-#     ]
-#
-#     custom_op.captured_args["x"] == "abc"
-
-
-# def test_get_provenance():
-#
-#     with provenance():
-#         dp1 = DataPanel.from_batch(
-#             {
-#                 "x": np.arange(4),
-#             }
-#         )
-#         dp2 = DataPanel.from_batch(
-#             {
-#                 "y": np.arange(4),
-#             }
-#         )
-#         d, _ = custom_fn(dp1, dp2, x="abc")
-#         dp3 = d["dp"]
-#
-#     nodes, edges = dp3.get_provenance(include_columns=False, last_parent_only=False)
-#     assert len(nodes) == 8
-#     assert sum([isinstance(node, ProvenanceObjNode) for node in nodes]) == 3
-#     assert sum([isinstance(node, ProvenanceOpNode) for node in nodes]) == 5
-#     assert len(edges) == 9
-#
-#     nodes, edges = dp3.get_provenance(include_columns=False, last_parent_only=True)
-#     assert len(nodes) == 6
-#     assert sum([isinstance(node, ProvenanceObjNode) for node in nodes]) == 3
-#     assert sum([isinstance(node, ProvenanceOpNode) for node in nodes]) == 3
-#     assert len(edges) == 5
-#
-#     nodes, edges = dp3.get_provenance(include_columns=True, last_parent_only=True)
-#     assert len(nodes) == 10
-#     assert sum([isinstance(node, ProvenanceObjNode) for node in nodes]) == 7
-#     assert sum([isinstance(node, ProvenanceOpNode) for node in nodes]) == 3
-#     assert len(edges) == 13
-#
-#     nodes, edges = dp3.get_provenance(include_columns=True, last_parent_only=False)
-#     assert len(nodes) == 16
-#     assert sum([isinstance(node, ProvenanceObjNode) for node in nodes]) == 7
-#     assert sum([isinstance(node, ProvenanceOpNode) for node in nodes]) == 9
-#     assert len(edges) == 28
+def test_custom_fn():
+
+    with provenance():
+        dp1 = DataPanel.from_batch(
+            {
+                "x": np.arange(4),
+            }
+        )
+        dp2 = DataPanel.from_batch(
+            {
+                "y": np.arange(4),
+            }
+        )
+        d, _ = custom_fn(dp1, dp2, x="abc")
+        dp3 = d["dp"]
+
+    assert isinstance(dp3.node.last_parent[0], ProvenanceOpNode)
+    assert dp3.node.last_parent[0].name == "custom_fn"
+    assert dp3.node.last_parent[1] == (0, "dp")
+
+    assert isinstance(dp1.node.last_parent[0], ProvenanceOpNode)
+    assert dp1.node.children[-1][0].name == "custom_fn"
+    assert dp1.node.children[-1][1] == ("dp1",)
+
+    assert isinstance(dp2.node.last_parent[0], ProvenanceOpNode)
+    assert dp2.node.children[-1][0].name == "custom_fn"
+    assert dp2.node.children[-1][1] == ("dp2",)
+
+    custom_op = dp3.node.last_parent[0]
+    # test that dp2, which was passed in to `custom_fn`, is not in children
+    assert custom_op.children == [
+        (dp3.node, (0, "dp")),
+        *[(dp3[key].node, (0, "dp", key)) for key in dp3.keys()],
+    ]
+
+    assert custom_op.captured_args["x"] == "abc"
+
+
+def test_get_provenance():
+
+    with provenance():
+        dp1 = DataPanel.from_batch(
+            {
+                "x": np.arange(4),
+            }
+        )
+        dp2 = DataPanel.from_batch(
+            {
+                "y": np.arange(4),
+            }
+        )
+        d, _ = custom_fn(dp1, dp2, x="abc")
+        dp3 = d["dp"]
+
+    nodes, edges = dp3.get_provenance(include_columns=False, last_parent_only=False)
+    assert len(nodes) == 7
+    assert sum([isinstance(node, ProvenanceObjNode) for node in nodes]) == 3
+    assert sum([isinstance(node, ProvenanceOpNode) for node in nodes]) == 4
+    assert len(edges) == 8
+
+    nodes, edges = dp3.get_provenance(include_columns=False, last_parent_only=True)
+    assert len(nodes) == 6
+    assert sum([isinstance(node, ProvenanceObjNode) for node in nodes]) == 3
+    assert sum([isinstance(node, ProvenanceOpNode) for node in nodes]) == 3
+    assert len(edges) == 5
+
+    nodes, edges = dp3.get_provenance(include_columns=True, last_parent_only=True)
+    assert len(nodes) == 10
+    assert sum([isinstance(node, ProvenanceObjNode) for node in nodes]) == 7
+    assert sum([isinstance(node, ProvenanceOpNode) for node in nodes]) == 3
+    assert len(edges) == 13
+
+    nodes, edges = dp3.get_provenance(include_columns=True, last_parent_only=False)
+    assert len(nodes) == 11
+    assert sum([isinstance(node, ProvenanceObjNode) for node in nodes]) == 7
+    assert sum([isinstance(node, ProvenanceOpNode) for node in nodes]) == 4
+    assert len(edges) == 20
+
+
+def test_repr():
+    with provenance():
+        dp1 = DataPanel.from_batch(
+            {
+                "x": np.arange(4),
+            }
+        )
+        dp2 = DataPanel.from_batch(
+            {
+                "y": np.arange(4),
+            }
+        )
+        d, _ = custom_fn(dp1, dp2, x="abc")
+        dp3 = d["dp"]
+
+    assert repr(dp1) == "DataPanel(nrows: 4, ncols: 2)"
+    assert repr(dp2) == "DataPanel(nrows: 4, ncols: 2)"
+    assert repr(dp3) == "DataPanel(nrows: 4, ncols: 3)"