Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 25 additions & 14 deletions python/pyspark/pandas/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,8 @@ class DataFrame(Frame, Generic[T]):
1, when `data` is a distributed dataset (Internal DataFrame/Spark DataFrame/
pandas-on-Spark DataFrame/pandas-on-Spark Series), it will first parallelize
the `index` if necessary, and then try to combine the `data` and `index`;
Note that in this case `compute.ops_on_diff_frames` should be turned on;
Note that if `data` and `index` don't have the same anchor, then
`compute.ops_on_diff_frames` should be turned on;
2, when `data` is a local dataset (Pandas DataFrame/numpy ndarray/list/etc),
it will first collect the `index` to driver if necessary, and then apply
the `Pandas.DataFrame(...)` creation internally;
Expand Down Expand Up @@ -527,13 +528,13 @@ def __init__( # type: ignore[no-untyped-def]
assert dtype is None
assert not copy
if index is None:
internal = data._internal.resolved_copy
internal = data._internal
elif isinstance(data, ps.Series):
assert columns is None
assert dtype is None
assert not copy
if index is None:
internal = data.to_frame()._internal.resolved_copy
internal = data.to_frame()._internal
else:
from pyspark.pandas.indexes.base import Index

Expand All @@ -558,17 +559,27 @@ def __init__( # type: ignore[no-untyped-def]
index_ps = ps.Index(index)
index_df = index_ps.to_frame()

# drop un-matched rows in `data`
# note that `combine_frames` can not work with a MultiIndex for now
combined = combine_frames(data_df, index_df, how="right")
combined_labels = combined._internal.column_labels
index_labels = [label for label in combined_labels if label[0] == "that"]
combined = combined.set_index(index_labels)

combined._internal._column_labels = data_df._internal.column_labels
combined._internal._column_label_names = data_df._internal._column_label_names
combined._internal._index_names = index_df._internal.column_labels
combined.index.name = index_ps.name
if same_anchor(data_df, index_df):
data_labels = data_df._internal.column_labels
data_pssers = [data_df._psser_for(label) for label in data_labels]
index_labels = index_df._internal.column_labels
index_pssers = [index_df._psser_for(label) for label in index_labels]
internal = data_df._internal.with_new_columns(data_pssers + index_pssers)

combined = ps.DataFrame(internal).set_index(index_labels)
combined.index.name = index_ps.name
else:
# drop un-matched rows in `data`
# note that `combine_frames` can not work with a MultiIndex for now
combined = combine_frames(data_df, index_df, how="right")
combined_labels = combined._internal.column_labels
index_labels = [label for label in combined_labels if label[0] == "that"]
combined = combined.set_index(index_labels)

combined._internal._column_labels = data_df._internal.column_labels
combined._internal._column_label_names = data_df._internal._column_label_names
combined._internal._index_names = index_df._internal.column_labels
combined.index.name = index_ps.name

internal = combined._internal

Expand Down
148 changes: 147 additions & 1 deletion python/pyspark/pandas/tests/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# limitations under the License.
#
import decimal
from datetime import datetime
from datetime import datetime, timedelta
from distutils.version import LooseVersion
import inspect
import sys
Expand Down Expand Up @@ -343,6 +343,152 @@ def test_creation_index(self):
data=ps.DataFrame([1, 2]), index=ps.MultiIndex.from_tuples([(1, 3), (2, 4)])
)

def test_creation_index_same_anchor(self):
pdf = pd.DataFrame(
{
"a": [1, 2, None, 4],
"b": [1, None, None, 4],
"c": [1, 2, None, None],
"d": [None, 2, None, 4],
}
)
psdf = ps.from_pandas(pdf)

self.assert_eq(
ps.DataFrame(data=psdf, index=psdf.index),
pd.DataFrame(data=pdf, index=pdf.index),
)
self.assert_eq(
ps.DataFrame(data=psdf + 1, index=psdf.index),
pd.DataFrame(data=pdf + 1, index=pdf.index),
)
self.assert_eq(
ps.DataFrame(data=psdf[["a", "c"]] * 2, index=psdf.index),
pd.DataFrame(data=pdf[["a", "c"]] * 2, index=pdf.index),
)

# test String Index
pdf = pd.DataFrame(
data={"s": ["Hello", "World", "Databricks"], "x": [2002, 2003, 2004], "y": [4, 5, 6]}
)
pdf = pdf.set_index("s")
pdf.index.name = None
psdf = ps.from_pandas(pdf)

self.assert_eq(
ps.DataFrame(data=psdf, index=psdf.index),
pd.DataFrame(data=pdf, index=pdf.index),
)
self.assert_eq(
ps.DataFrame(data=psdf + 1, index=psdf.index),
pd.DataFrame(data=pdf + 1, index=pdf.index),
)
self.assert_eq(
ps.DataFrame(data=psdf[["y"]] * 2, index=psdf.index),
pd.DataFrame(data=pdf[["y"]] * 2, index=pdf.index),
)

# test DatetimeIndex
pdf = pd.DataFrame(
data={
"t": [
datetime(2022, 9, 1, 0, 0, 0, 0),
datetime(2022, 9, 2, 0, 0, 0, 0),
datetime(2022, 9, 3, 0, 0, 0, 0),
],
"x": [2002, 2003, 2004],
"y": [4, 5, 6],
}
)
pdf = pdf.set_index("t")
pdf.index.name = None
psdf = ps.from_pandas(pdf)

self.assert_eq(
ps.DataFrame(data=psdf, index=psdf.index),
pd.DataFrame(data=pdf, index=pdf.index),
)
self.assert_eq(
ps.DataFrame(data=psdf + 1, index=psdf.index),
pd.DataFrame(data=pdf + 1, index=pdf.index),
)
self.assert_eq(
ps.DataFrame(data=psdf[["y"]] * 2, index=psdf.index),
pd.DataFrame(data=pdf[["y"]] * 2, index=pdf.index),
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we test it with the MultiIndex and other index types such as CategoricalIndex and TimedeltaIndex ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let me add them

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool!


# test TimedeltaIndex
pdf = pd.DataFrame(
data={
"t": [
timedelta(1),
timedelta(3),
timedelta(5),
],
"x": [2002, 2003, 2004],
"y": [4, 5, 6],
}
)
pdf = pdf.set_index("t")
pdf.index.name = None
psdf = ps.from_pandas(pdf)

self.assert_eq(
ps.DataFrame(data=psdf, index=psdf.index),
pd.DataFrame(data=pdf, index=pdf.index),
)
self.assert_eq(
ps.DataFrame(data=psdf + 1, index=psdf.index),
pd.DataFrame(data=pdf + 1, index=pdf.index),
)
self.assert_eq(
ps.DataFrame(data=psdf[["y"]] * 2, index=psdf.index),
pd.DataFrame(data=pdf[["y"]] * 2, index=pdf.index),
)

# test CategoricalIndex
pdf = pd.DataFrame(
data={
"z": [-1, -2, -3, -4],
"x": [2002, 2003, 2004, 2005],
"y": [4, 5, 6, 7],
},
index=pd.CategoricalIndex(["a", "c", "b", "a"], categories=["a", "b", "c"]),
)
psdf = ps.from_pandas(pdf)

self.assert_eq(
ps.DataFrame(data=psdf, index=psdf.index),
pd.DataFrame(data=pdf, index=pdf.index),
)
self.assert_eq(
ps.DataFrame(data=psdf + 1, index=psdf.index),
pd.DataFrame(data=pdf + 1, index=pdf.index),
)
self.assert_eq(
ps.DataFrame(data=psdf[["y"]] * 2, index=psdf.index),
pd.DataFrame(data=pdf[["y"]] * 2, index=pdf.index),
)

# test distributed data with ps.MultiIndex
pdf = pd.DataFrame(
data={
"z": [-1, -2, -3, -4],
"x": [2002, 2003, 2004, 2005],
"y": [4, 5, 6, 7],
},
index=pd.MultiIndex.from_tuples([("a", "x"), ("b", "y"), ("c", "z"), ("a", "x")]),
)
psdf = ps.from_pandas(pdf)

err_msg = "Cannot combine a Distributed Dataset with a MultiIndex"
with self.assertRaisesRegex(ValueError, err_msg):
# test ps.DataFrame with ps.MultiIndex
ps.DataFrame(data=psdf, index=psdf.index)
with self.assertRaisesRegex(ValueError, err_msg):
# test ps.DataFrame with pd.MultiIndex
ps.DataFrame(data=psdf, index=pdf.index)

def _check_extension(self, psdf, pdf):
if LooseVersion("1.1") <= LooseVersion(pd.__version__) < LooseVersion("1.2.2"):
self.assert_eq(psdf, pdf, check_exact=False)
Expand Down