Skip to content

Commit

Permalink
Adding shape to DataTable and DataColumn (#358)
Browse files Browse the repository at this point in the history
* adding shape and testing

* release notes updated

* isinstance, api ref

* better/more testing

* linting
  • Loading branch information
ctduffy committed Nov 5, 2020
1 parent 4be6b85 commit e4435e9
Show file tree
Hide file tree
Showing 6 changed files with 48 additions and 0 deletions.
2 changes: 2 additions & 0 deletions docs/source/api_reference.rst
Expand Up @@ -10,6 +10,7 @@ DataTable
:toctree: generated/

DataTable
DataTable.shape
DataTable.add_semantic_tags
DataTable.remove_semantic_tags
DataTable.reset_semantic_tags
Expand All @@ -35,6 +36,7 @@ DataColumn
:toctree: generated/

DataColumn
DataColumn.shape
DataColumn.add_semantic_tags
DataColumn.remove_semantic_tags
DataColumn.reset_semantic_tags
Expand Down
1 change: 1 addition & 0 deletions docs/source/release_notes.rst
Expand Up @@ -8,6 +8,7 @@ Release Notes
* Add ``__eq__`` to DataTable and DataColumn and update LogicalType equality (:pr:`318`)
* Add ``value_counts()`` method to DataTable (:pr:`342`)
* Support serialization and deserialization of DataTables via csv, pickle, or parquet (:pr:`293`)
* Add ``shape`` property to DataTable and DataColumn (:pr:`358`)
* Fixes
* Catch non numeric time index at validation (:pr:`332`)
* Changes
Expand Down
6 changes: 6 additions & 0 deletions woodwork/data_column.py
Expand Up @@ -280,6 +280,12 @@ def to_series(self):
"""
return self._series

@property
def shape(self):
"""Returns a tuple representing the dimensionality of the DataTable. If Dask DataFrame, returns
a Dask `Delayed` object for the number of rows."""
return self._series.shape

@property
def logical_type(self):
"""The logical type for the column"""
Expand Down
6 changes: 6 additions & 0 deletions woodwork/data_table.py
Expand Up @@ -205,6 +205,12 @@ def semantic_tags(self):
"""A dictionary containing semantic tags for each column"""
return {dc.name: dc.semantic_tags for dc in self.columns.values()}

@property
def shape(self):
"""Returns a tuple representing the dimensionality of the DataTable. If Dask DataFrame, returns
a Dask `Delayed` object for the number of rows."""
return self._dataframe.shape

@property
def index(self):
"""The index column for the table"""
Expand Down
11 changes: 11 additions & 0 deletions woodwork/tests/data_column/test_data_column.py
Expand Up @@ -436,6 +436,17 @@ def test_to_series(sample_series):
pd.testing.assert_series_equal(to_pandas(series), to_pandas(data_col._series))


def test_shape(categorical_df):
col = DataColumn(categorical_df['ints'])
assert col.shape == (9,)
assert col.shape == col.to_series().shape


def test_shape_dask(categorical_dd):
col = DataColumn(categorical_dd['ints'])
assert col.to_series().compute().shape == col.shape[0].compute()


def test_dtype_update_on_init(datetime_series):
dc = DataColumn(datetime_series,
logical_type='DateTime')
Expand Down
22 changes: 22 additions & 0 deletions woodwork/tests/data_table/test_datatable.py
Expand Up @@ -4,6 +4,7 @@
import numpy as np
import pandas as pd
import pytest
from dask.delayed import Delayed

import woodwork as ww
from woodwork import DataColumn, DataTable
Expand Down Expand Up @@ -1557,6 +1558,27 @@ def test_datatable_clear_time_index(sample_df):
assert all(['time_index' not in col.semantic_tags for col in dt.columns.values()])


def test_shape(categorical_df, categorical_log_types):
dt = ww.DataTable(categorical_df, logical_types=categorical_log_types)
assert dt.shape == (9, 5)
assert dt.shape == dt.to_dataframe().shape

dt.pop('ints')
assert dt.shape == (9, 4)


def test_shape_dask(categorical_dd, categorical_log_types):
dt = ww.DataTable(categorical_dd, logical_types=categorical_log_types)
assert isinstance(dt.shape[0], Delayed)
assert dt.shape[1] == 5

dt.pop('bools')
assert isinstance(dt.shape[0], Delayed)
assert dt.shape[1] == 4

assert (dt.shape[0].compute(), dt.shape[1]) == (len(dt.to_dataframe()), len(dt.columns))


def test_select_invalid_inputs(sample_df):
dt = DataTable(sample_df, time_index='signup_date', index='id', name='dt_name')
dt = dt.set_logical_types({
Expand Down

0 comments on commit e4435e9

Please sign in to comment.