Skip to content

Commit

Permalink
Remove visible_columns from DataPanel (#125)
Browse files Browse the repository at this point in the history
DataPanels no longer rely on `visible_columns` to create views. This PR removes `visible_columns` entirely.

Other changes:
- Improve code coverage 
  - Reactivate provenance tests
  - DataPanel batch tests
  - Concat tests
  - Merge tests
- Remove Identifiers, Splits and Info from DataPanel and `AbstractColumn`
  • Loading branch information
seyuboglu committed Aug 13, 2021
1 parent 647c6d3 commit 0440b6d
Show file tree
Hide file tree
Showing 27 changed files with 1,020 additions and 1,208 deletions.
18 changes: 8 additions & 10 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -186,9 +186,9 @@ Note: in order to pass the full test suite (step 5), you'll need to install all

### Tests

An extensive test suite is included to test the library behavior and several examples.
An extensive test suite is included to test the library behavior.
Library tests can be found in the
[tests folder](https://github.com/robustness-gym/robustness-gym/tree/main/tests).
[tests folder](https://github.com/robustness-gym/meerkat/tree/main/tests).

From the root of the
repository, here's how to run tests with `pytest` for the library:
Expand All @@ -200,15 +200,13 @@ $ make test
You can specify a smaller set of tests in order to test only the feature
you're working on.

Meerkat uses `pytest` as a test runner only. It doesn't use any
`pytest`-specific features in the test suite itself.

This means `unittest` is fully supported. Here's how to run tests with
`unittest`:

```bash
$ python -m unittest discover -s tests -t . -v
Per the checklist above, all PRs should include high-coverage tests.
To produce a code coverage report, run the following `pytest` command:
```
pytest --cov-report term-missing,html --cov=meerkat .
```
This will populate a directory `htmlcov` with an HTML report.
Open `htmlcov/index.html` in a browser to view the report.


### Style guide
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
![GitHub](https://img.shields.io/github/license/robustness-gym/meerkat)
[![Documentation Status](https://readthedocs.org/projects/meerkat/badge/?version=latest)](https://meerkat.readthedocs.io/en/latest/?badge=latest)
[![pre-commit](https://img.shields.io/badge/pre--commit-enabled-brightgreen?logo=pre-commit&logoColor=white)](https://github.com/pre-commit/pre-commit)
<!---[![codecov](https://codecov.io/gh/robustness-gym/meerkat/branch/main/graph/badge.svg?token=MOLQYUSYQU)](https://codecov.io/gh/robustness-gym/meerkat)--->
[![codecov](https://codecov.io/gh/robustness-gym/meerkat/branch/main/graph/badge.svg?token=MOLQYUSYQU)](https://codecov.io/gh/robustness-gym/meerkat)

Meerkat provides fast and flexible data structures for working with complex machine learning datasets.

Expand Down
57 changes: 43 additions & 14 deletions meerkat/block/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
from collections import defaultdict
from collections.abc import MutableMapping
from typing import Dict, Sequence, Union
from typing import Dict, Mapping, Sequence, Union

import pandas as pd
import yaml
Expand All @@ -18,16 +18,20 @@ class BlockManager(MutableMapping):
"""This manager manages all blocks."""

def __init__(self) -> None:
self._columns: Dict[str, AbstractColumn] = {}
self._columns: Dict[str, AbstractColumn] = {} # ordered as of 3.7
self._column_to_block_id: Dict[str, int] = {}
self._block_refs: Dict[int, BlockRef] = {}

def update(self, block_ref: BlockRef):
"""data (): a single blockable object, potentially contains multiple
columns."""
for name in block_ref:
if name in self:
self.remove(name)

# although we can't have the same column living in multiple managers
# we don't view here because it can lead to multiple calls to clone
self._columns.update({name: column for name, column in block_ref.items()})
self._columns.update(block_ref)

block_id = id(block_ref.block)
# check if there already is a block_ref in the manager for this block
Expand Down Expand Up @@ -66,9 +70,15 @@ def apply(self, method_name: str = "_get", *args, **kwargs) -> BlockManager:

results[name] = result

if isinstance(results, BlockManager):
results.reorder(self.keys())
return results

def consolidate(self):
column_order = list(
self._columns.keys()
) # need to maintain order after consolidate

block_ref_groups = defaultdict(list)
for block_ref in self._block_refs.values():
block_ref_groups[block_ref.block.signature].append(block_ref)
Expand All @@ -78,15 +88,13 @@ def consolidate(self):
# if there is only one block ref in the group, do not consolidate
continue

# remove old block_refs
for old_ref in block_refs:
self._block_refs.pop(id(old_ref.block))

# consolidate group
block_class = block_refs[0].block.__class__
block_ref = block_class.consolidate(block_refs)
self.update(block_ref)

self.reorder(column_order)

def remove(self, name):
if name not in self._columns:
raise ValueError(f"Remove failed: no column '{name}' in BlockManager.")
Expand All @@ -103,6 +111,11 @@ def remove(self, name):

self._column_to_block_id.pop(name)

def reorder(self, order: Sequence[str]):
    """Rearrange the manager's columns so they follow ``order``.

    Args:
        order: Column names in the desired order. As a set, it must match
            the set of column names currently held by the manager.

    Raises:
        ValueError: If the names in ``order`` are not exactly the names of
            the columns currently in the manager.
    """
    current = self._columns
    if set(current) != set(order):
        raise ValueError("Must include all columns when reordering a BlockManager.")

    # Rebuild the mapping by inserting the columns in the requested order.
    reordered = {}
    for name in order:
        reordered[name] = current[name]
    self._columns = reordered

def __getitem__(
self, index: Union[str, Sequence[str]]
) -> Union[AbstractColumn, BlockManager]:
Expand All @@ -128,6 +141,7 @@ def __getitem__(
for block_id, names in block_id_to_names.items():
block_ref = self._block_refs[block_id]
mgr.update(block_ref[names])
mgr.reorder(order=index)
return mgr
else:
raise ValueError(
Expand All @@ -148,6 +162,14 @@ def __delitem__(self, key):
def __len__(self):
return len(self._columns)

@property
def nrows(self):
    """Number of rows, taken as the length of the first column.

    Returns 0 when the manager holds no columns.
    """
    if len(self) == 0:
        return 0
    first_column = next(iter(self._columns.values()))
    return len(first_column)

@property
def ncols(self):
    """Number of columns, i.e. the length of the manager itself."""
    column_count = len(self)
    return column_count

def __contains__(self, value):
return value in self._columns

Expand All @@ -160,8 +182,11 @@ def get_block_ref(self, name: str):
def add_column(self, col: AbstractColumn, name: str):
"""Convert data to a meerkat column using the appropriate Column
type."""
if name in self._columns:
self.remove(name)
if len(self) > 0 and len(col) != self.nrows:
raise ValueError(
f"Cannot add column '{name}' with length {len(col)} to `BlockManager` "
f" with length {self.nrows} columns."
)

if not col.is_blockable():
col = col.view()
Expand All @@ -172,15 +197,19 @@ def add_column(self, col: AbstractColumn, name: str):
self.update(BlockRef(columns={name: col}, block=col._block))

@classmethod
def from_dict(cls, data):
def from_dict(cls, data: Mapping[str, object]):
mgr = cls()
for name, data in data.items():
col = AbstractColumn.from_data(data)
mgr.add_column(col=col, name=name)
return mgr

def write(self, path: str):
meta = {"dtype": BlockManager, "columns": {}}
meta = {
"dtype": BlockManager,
"columns": {},
"_column_order": list(self.keys()),
}

# prepare directories
os.makedirs(path, exist_ok=True)
Expand Down Expand Up @@ -264,7 +293,7 @@ def read(
mgr.add_column(
col_meta["dtype"].read(path=column_dir, _meta=col_meta), name
)

mgr.reorder(meta["_column_order"])
return mgr

def _repr_pandas_(self):
Expand All @@ -279,13 +308,13 @@ def _repr_pandas_(self):

def view(self):
mgr = BlockManager()
for name, col in self._columns.items():
for name, col in self.items():
mgr.add_column(col.view(), name)
return mgr

def copy(self):
mgr = BlockManager()
for name, col in self._columns.items():
for name, col in self.items():
mgr.add_column(col.copy(), name)
return mgr

Expand Down
9 changes: 9 additions & 0 deletions meerkat/block/numpy_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Hashable, Mapping, Sequence, Tuple, Union

import numpy as np
import torch

from meerkat.block.ref import BlockRef
from meerkat.errors import ConsolidationError
Expand Down Expand Up @@ -109,9 +110,17 @@ def _consolidate(
}
return BlockRef(block=block, columns=new_columns)

@staticmethod
def _convert_index(index):
if torch.is_tensor(index):
# need to convert to numpy for boolean indexing
return index.numpy()
return index

def _get(
self, index, block_ref: BlockRef, materialize: bool = True
) -> Union[BlockRef, dict]:
index = self._convert_index(index)
# TODO: check if they're trying to index more than just the row dimension
data = self.data[index]
if isinstance(index, int):
Expand Down
4 changes: 4 additions & 0 deletions meerkat/block/pandas_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Hashable, Mapping, Sequence, Tuple, Union

import pandas as pd
import torch

from meerkat.block.ref import BlockRef
from meerkat.columns.numpy_column import NumpyArrayColumn
Expand Down Expand Up @@ -79,6 +80,9 @@ def _consolidate(

@staticmethod
def _convert_index(index):
if torch.is_tensor(index):
# need to convert to numpy for boolean indexing
return index.numpy()
if isinstance(index, NumpyArrayColumn):
return index.data
if isinstance(index, TensorColumn):
Expand Down
11 changes: 6 additions & 5 deletions meerkat/columns/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,12 @@
from meerkat.mixins.blockable import BlockableMixin
from meerkat.mixins.cloneable import CloneableMixin
from meerkat.mixins.collate import CollateMixin
from meerkat.mixins.identifier import IdentifierMixin
from meerkat.mixins.inspect_fn import FunctionInspectorMixin
from meerkat.mixins.io import ColumnIOMixin
from meerkat.mixins.lambdable import LambdaMixin
from meerkat.mixins.mapping import MappableMixin
from meerkat.mixins.materialize import MaterializationMixin
from meerkat.provenance import ProvenanceMixin, capture_provenance
from meerkat.tools.identifier import Identifier
from meerkat.tools.utils import convert_to_batch_column_fn

logger = logging.getLogger(__name__)
Expand All @@ -32,7 +30,6 @@ class AbstractColumn(
CollateMixin,
ColumnIOMixin,
FunctionInspectorMixin,
IdentifierMixin,
LambdaMixin,
MappableMixin,
MaterializationMixin,
Expand All @@ -46,7 +43,6 @@ class AbstractColumn(
def __init__(
self,
data: Sequence = None,
identifier: Identifier = None,
collate_fn: Callable = None,
*args,
**kwargs,
Expand All @@ -55,7 +51,6 @@ def __init__(
self._set_data(data)

super(AbstractColumn, self).__init__(
identifier=identifier,
collate_fn=collate_fn,
*args,
**kwargs,
Expand Down Expand Up @@ -180,6 +175,12 @@ def _translate_index(self, index):
if not self._is_batch_index(index):
return index

if isinstance(index, pd.Series):
index = index.values

if torch.is_tensor(index):
index = index.numpy()

# `index` should return a batch
if isinstance(index, slice):
# int or slice index => standard list slicing
Expand Down
1 change: 1 addition & 0 deletions meerkat/columns/numpy_column.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ def _set_batch(self, indices, values):
self._data[indices] = values

def _get(self, index, materialize: bool = True):
index = NumpyBlock._convert_index(index)
data = self._data[index]
if self._is_batch_index(index):
# only create a numpy array column
Expand Down

0 comments on commit 0440b6d

Please sign in to comment.