diff --git a/src/safeds/data/tabular/containers/_row.py b/src/safeds/data/tabular/containers/_row.py index a0049471b..f6fd2f3f0 100644 --- a/src/safeds/data/tabular/containers/_row.py +++ b/src/safeds/data/tabular/containers/_row.py @@ -1,29 +1,26 @@ from __future__ import annotations -from hashlib import md5 from typing import TYPE_CHECKING, Any -import pandas as pd -from IPython.core.display_functions import DisplayHandle, display -from pandas.core.util.hashing import hash_pandas_object +import polars as pl from safeds.data.tabular.exceptions import UnknownColumnNameError from safeds.data.tabular.typing import ColumnType, Schema if TYPE_CHECKING: - from collections.abc import Iterable, Iterator + from collections.abc import Iterator class Row: """ A row is a collection of values, where each value is associated with a column name. - Parameters - ---------- - data : Iterable - The data. - schema : Schema - The schema of the row. + To create a row manually, use the static method [from_dict][safeds.data.tabular.containers._row.Row.from_dict]. + + Examples + -------- + >>> from safeds.data.tabular.containers import Row + >>> row = Row.from_dict({"a": 1, "b": 2}) """ # ------------------------------------------------------------------------------------------------------------------ @@ -44,60 +41,111 @@ def from_dict(data: dict[str, Any]) -> Row: ------- row : Row The generated row. + + Examples + -------- + >>> from safeds.data.tabular.containers import Row + >>> row = Row.from_dict({"a": 1, "b": 2}) """ - row_frame = pd.DataFrame([data.values()], columns=list(data.keys())) - # noinspection PyProtectedMember - return Row(data.values(), Schema._from_pandas_dataframe(row_frame)) + return Row(pl.DataFrame(data)) # ------------------------------------------------------------------------------------------------------------------ # Dunder methods # ------------------------------------------------------------------------------------------------------------------ - def __init__(self, data: Iterable, schema: Schema | None = None): - self._data: pd.Series = data if isinstance(data, pd.Series) else pd.Series(data) - self._data = self._data.reset_index(drop=True) + def __init__(self, data: pl.DataFrame, schema: Schema | None = None): + """ + Initialize a row from a `polars.DataFrame`. + + **Do not use this method directly.** It is not part of the public interface and may change in the future + without a major version bump. Use the static method + [from_dict][safeds.data.tabular.containers._row.Row.from_dict] instead. + + Parameters + ---------- + data : polars.DataFrame + The data. + schema : Schema | None + The schema. If None, the schema is inferred from the data. + """ + self._data: pl.DataFrame = data self._schema: Schema if schema is not None: self._schema = schema else: - column_names = [f"column_{i}" for i in range(len(self._data))] - dataframe = self._data.to_frame().T - dataframe.columns = column_names # noinspection PyProtectedMember - self._schema = Schema._from_pandas_dataframe(dataframe) + self._schema = Schema._from_polars_dataframe(self._data) def __eq__(self, other: Any) -> bool: if not isinstance(other, Row): return NotImplemented if self is other: return True - return self._schema == other._schema and self._data.equals(other._data) + return self._schema == other._schema and self._data.frame_equal(other._data) def __getitem__(self, column_name: str) -> Any: - return self.get_value(column_name) + """ + Return the value of a specified column. - def __hash__(self) -> int: - data_hash_string = md5(hash_pandas_object(self._data, index=True).values).hexdigest() - column_names_frozenset = frozenset(self.get_column_names()) + Parameters + ---------- + column_name : str + The column name. + + Returns + ------- + value : Any + The value of the column. - return hash((data_hash_string, column_names_frozenset)) + Raises + ------ + UnknownColumnNameError + If the row does not contain the specified column. + + Examples + -------- + >>> from safeds.data.tabular.containers import Row + >>> row = Row.from_dict({"a": 1, "b": 2}) + >>> row["a"] + 1 + """ + return self.get_value(column_name) def __iter__(self) -> Iterator[Any]: return iter(self.get_column_names()) def __len__(self) -> int: - return len(self._data) + """ + Return the number of columns in this row. + + Returns + ------- + count : int + The number of columns. + + Examples + -------- + >>> from safeds.data.tabular.containers import Row + >>> row = Row.from_dict({"a": 1, "b": 2}) + >>> len(row) + 2 + """ + return self._data.shape[1] def __repr__(self) -> str: - tmp = self._data.to_frame().T - tmp.columns = self.get_column_names() - return tmp.__repr__() + return f"Row({str(self)})" def __str__(self) -> str: - tmp = self._data.to_frame().T - tmp.columns = self.get_column_names() - return tmp.__str__() + match len(self): + case 0: + return "{}" + case 1: + return str(self.to_dict()) + case _: + lines = (f" {name!r}: {value!r}" for name, value in self.to_dict().items()) + joined = ",\n".join(lines) + return f"{{\n{joined}\n}}" # ------------------------------------------------------------------------------------------------------------------ # Properties @@ -112,6 +160,12 @@ def schema(self) -> Schema: ------- schema : Schema The schema. + + Examples + -------- + >>> from safeds.data.tabular.containers import Row + >>> row = Row.from_dict({"a": 1, "b": 2}) + >>> schema = row.schema """ return self._schema @@ -130,20 +184,30 @@ def get_value(self, column_name: str) -> Any: Returns ------- - value : + value : Any The value of the column. + + Raises + ------ + UnknownColumnNameError + If the row does not contain the specified column. + + Examples + -------- + >>> from safeds.data.tabular.containers import Row + >>> row = Row.from_dict({"a": 1, "b": 2}) + >>> row.get_value("a") + 1 """ - if not self._schema.has_column(column_name): + if not self.has_column(column_name): raise UnknownColumnNameError([column_name]) - # noinspection PyProtectedMember - return self._data[self._schema._get_column_index(column_name)] + + return self._data[0, column_name] def has_column(self, column_name: str) -> bool: """ Return whether the row contains a given column. - Alias for self.schema.hasColumn(column_name: str) -> bool. - Parameters ---------- column_name : str @@ -151,29 +215,42 @@ def has_column(self, column_name: str) -> bool: Returns ------- - contains : bool + has_column : bool True, if row contains the column. + + Examples + -------- + >>> from safeds.data.tabular.containers import Row + >>> row = Row.from_dict({"a": 1, "b": 2}) + >>> row.has_column("a") + True + + >>> row.has_column("c") + False """ return self._schema.has_column(column_name) def get_column_names(self) -> list[str]: """ - Return a list of all column names saved in this schema. - - Alias for self.schema.get_column_names() -> list[str]. + Return a list of all column names in the row. Returns ------- column_names : list[str] The column names. + + Examples + -------- + >>> from safeds.data.tabular.containers import Row + >>> row = Row.from_dict({"a": 1, "b": 2}) + >>> row.get_column_names() + ['a', 'b'] """ return self._schema.get_column_names() def get_type_of_column(self, column_name: str) -> ColumnType: """ - Return the type of a specified column. - - Alias for self.schema.get_type_of_column(column_name: str) -> ColumnType. + Return the type of the specified column. Parameters ---------- @@ -187,8 +264,15 @@ def get_type_of_column(self, column_name: str) -> ColumnType: Raises ------ - ColumnNameError - If the specified target column name does not exist. + UnknownColumnNameError + If the row does not contain the specified column. + + Examples + -------- + >>> from safeds.data.tabular.containers import Row + >>> row = Row.from_dict({"a": 1, "b": 2}) + >>> row.get_type_of_column("a") + Integer """ return self._schema.get_type_of_column(column_name) @@ -204,8 +288,15 @@ def count(self) -> int: ------- count : int The number of columns. + + Examples + -------- + >>> from safeds.data.tabular.containers import Row + >>> row = Row.from_dict({"a": 1, "b": 2}) + >>> row.count() + 2 """ - return len(self._data) + return self._data.shape[1] # ------------------------------------------------------------------------------------------------------------------ # Conversion @@ -219,6 +310,13 @@ def to_dict(self) -> dict[str, Any]: ------- data : dict[str, Any] Dictionary representation of the row. + + Examples + -------- + >>> from safeds.data.tabular.containers import Row + >>> row = Row.from_dict({"a": 1, "b": 2}) + >>> row.to_dict() + {'a': 1, 'b': 2} """ return {column_name: self.get_value(column_name) for column_name in self.get_column_names()} @@ -226,17 +324,14 @@ def to_dict(self) -> dict[str, Any]: # IPython integration # ------------------------------------------------------------------------------------------------------------------ - def _ipython_display_(self) -> DisplayHandle: + def _repr_html_(self) -> str: """ - Return a display object for the column to be used in Jupyter Notebooks. + Return an HTML representation of the row. Returns ------- - output : DisplayHandle - Output object. + output : str + The generated HTML. """ - tmp = self._data.to_frame().T - tmp.columns = self.get_column_names() - - with pd.option_context("display.max_rows", tmp.shape[0], "display.max_columns", tmp.shape[1]): - return display(tmp) + # noinspection PyProtectedMember + return self._data._repr_html_() diff --git a/src/safeds/data/tabular/containers/_table.py b/src/safeds/data/tabular/containers/_table.py index f9a4a115a..0f6d52263 100644 --- a/src/safeds/data/tabular/containers/_table.py +++ b/src/safeds/data/tabular/containers/_table.py @@ -8,9 +8,10 @@ import matplotlib.pyplot as plt import numpy as np import pandas as pd +import polars as pl import seaborn as sns from IPython.core.display_functions import DisplayHandle, display -from pandas import DataFrame, Series +from pandas import DataFrame from scipy import stats from safeds.data.image.containers import Image @@ -203,14 +204,14 @@ def from_rows(rows: list[Row]) -> Table: raise MissingDataError("This function requires at least one row.") schema_compare: Schema = rows[0]._schema - row_array: list[Series] = [] + row_array: list[pd.DataFrame] = [] for row in rows: if schema_compare != row._schema: raise SchemaMismatchError - row_array.append(row._data) + row_array.append(row._data.to_pandas()) - dataframe: DataFrame = pd.DataFrame(row_array) + dataframe: DataFrame = pd.concat(row_array, ignore_index=True) dataframe.columns = schema_compare.get_column_names() return Table(dataframe) @@ -387,7 +388,8 @@ def get_row(self, index: int) -> Row: """ if len(self._data.index) - 1 < index or index < 0: raise IndexOutOfBoundsError(index) - return Row(self._data.iloc[[index]].squeeze(), self._schema) + + return Row(pl.DataFrame(self._data.iloc[[index]]), self._schema) # ------------------------------------------------------------------------------------------------------------------ # Information @@ -543,8 +545,7 @@ def add_row(self, row: Row) -> Table: if self._schema != row.schema: raise SchemaMismatchError - row_frame = row._data.to_frame().T - row_frame.columns = self.get_column_names() + row_frame = row._data.to_pandas() new_df = pd.concat([self._data, row_frame]).infer_objects() new_df.columns = self.get_column_names() @@ -571,7 +572,7 @@ def add_rows(self, rows: list[Row] | Table) -> Table: if self._schema != row.schema: raise SchemaMismatchError - row_frames = [row._data.to_frame().T for row in rows] + row_frames = [row._data.to_pandas() for row in rows] for row_frame in row_frames: row_frame.columns = self.get_column_names() @@ -1186,7 +1187,10 @@ def to_rows(self) -> list[Row]: rows : list[Row] List of rows. """ - return [Row(series_row, self._schema) for (_, series_row) in self._data.iterrows()] + return [ + Row(pl.DataFrame([list(series_row)], schema=self._schema.get_column_names()), self._schema) + for (_, series_row) in self._data.iterrows() + ] # ------------------------------------------------------------------------------------------------------------------ # IPython integration diff --git a/tests/safeds/data/tabular/containers/_table/test_add_row.py b/tests/safeds/data/tabular/containers/_table/test_add_row.py index 0832949c7..91ec0280b 100644 --- a/tests/safeds/data/tabular/containers/_table/test_add_row.py +++ b/tests/safeds/data/tabular/containers/_table/test_add_row.py @@ -1,12 +1,11 @@ from _pytest.python_api import raises from safeds.data.tabular.containers import Row, Table from safeds.data.tabular.exceptions import SchemaMismatchError -from safeds.data.tabular.typing import Integer, Schema, String def test_add_row_valid() -> None: table1 = Table.from_dict({"col1": [1, 2, 1], "col2": [1, 2, 4]}) - row = Row([5, 6], table1.schema) + row = Row.from_dict({"col1": 5, "col2": 6}) table1 = table1.add_row(row) assert table1.count_rows() == 4 assert table1.get_row(3) == row @@ -15,9 +14,6 @@ def test_add_row_valid() -> None: def test_add_row_invalid() -> None: table1 = Table.from_dict({"col1": [1, 2, 1], "col2": [1, 2, 4]}) - row = Row( - [5, "Hallo"], - Schema({"col1": Integer(), "col2": String()}), - ) + row = Row.from_dict({"col1": 5, "col2": "Hallo"}) with raises(SchemaMismatchError): table1 = table1.add_row(row) diff --git a/tests/safeds/data/tabular/containers/_table/test_add_rows.py b/tests/safeds/data/tabular/containers/_table/test_add_rows.py index d4687f154..62d87bbdb 100644 --- a/tests/safeds/data/tabular/containers/_table/test_add_rows.py +++ b/tests/safeds/data/tabular/containers/_table/test_add_rows.py @@ -1,10 +1,27 @@ +import polars as pl from safeds.data.tabular.containers import Row, Table def test_add_rows_valid() -> None: table1 = Table.from_dict({"col1": ["a", "b", "c"], "col2": [1, 2, 4]}) - row1 = Row(["d", 6], table1.schema) - row2 = Row(["e", 8], table1.schema) + row1 = Row( + pl.DataFrame( + { + "col1": "d", + "col2": 6, + }, + ), + table1.schema, + ) + row2 = Row( + pl.DataFrame( + { + "col1": "e", + "col2": 8, + }, + ), + table1.schema, + ) table1 = table1.add_rows([row1, row2]) assert table1.count_rows() == 5 assert table1.get_row(3) == row1 @@ -15,8 +32,24 @@ def test_add_rows_valid() -> None: def test_add_rows_table_valid() -> None: table1 = Table.from_dict({"col1": [1, 2, 1], "col2": [1, 2, 4]}) - row1 = Row([5, 6], table1.schema) - row2 = Row([7, 8], table1.schema) + row1 = Row( + pl.DataFrame( + { + "col1": 5, + "col2": 6, + }, + ), + table1.schema, + ) + row2 = Row( + pl.DataFrame( + { + "col1": 7, + "col2": 8, + }, + ), + table1.schema, + ) table2 = Table.from_rows([row1, row2]) table1 = table1.add_rows(table2) assert table1.count_rows() == 5 diff --git a/tests/safeds/data/tabular/containers/_table/test_to_rows.py b/tests/safeds/data/tabular/containers/_table/test_to_rows.py index 22de2a837..12b7e1a23 100644 --- a/tests/safeds/data/tabular/containers/_table/test_to_rows.py +++ b/tests/safeds/data/tabular/containers/_table/test_to_rows.py @@ -1,3 +1,4 @@ +import polars as pl from safeds.data.tabular.containers import Row, Table from safeds.data.tabular.typing import Integer, Schema, String @@ -19,9 +20,9 @@ def test_to_rows() -> None: }, ) rows_expected = [ - Row([1, 4, "d"], expected_schema), - Row([2, 5, "e"], expected_schema), - Row([3, 6, "f"], expected_schema), + Row(pl.DataFrame({"A": 1, "B": 4, "D": "d"}), expected_schema), + Row(pl.DataFrame({"A": 2, "B": 5, "D": "e"}), expected_schema), + Row(pl.DataFrame({"A": 3, "B": 6, "D": "f"}), expected_schema), ] rows_is = table.to_rows() diff --git a/tests/safeds/data/tabular/containers/test_row.py b/tests/safeds/data/tabular/containers/test_row.py index 2a71c1c28..4c62e0988 100644 --- a/tests/safeds/data/tabular/containers/test_row.py +++ b/tests/safeds/data/tabular/containers/test_row.py @@ -1,5 +1,6 @@ from typing import Any +import polars as pl import pytest from safeds.data.tabular.containers import Row, Table from safeds.data.tabular.exceptions import UnknownColumnNameError @@ -12,16 +13,20 @@ class TestFromDict: [ ( {}, - Row([]), + Row(pl.DataFrame()), ), ( { "a": 1, "b": 2, }, - Row([1, 2], schema=Schema({"a": Integer(), "b": Integer()})), + Row(pl.DataFrame({"a": 1, "b": 2})), ), ], + ids=[ + "empty", + "non-empty", + ], ) def test_should_create_row_from_dict(self, data: dict[str, Any], expected: Row) -> None: assert Row.from_dict(data) == expected @@ -31,13 +36,24 @@ class TestInit: @pytest.mark.parametrize( ("row", "expected"), [ - (Row([], Schema({})), Schema({})), - (Row([0], Schema({"col1": Integer()})), Schema({"col1": Integer()})), ( - Row([0, "a"], Schema({"col1": Integer(), "col2": String()})), + Row(pl.DataFrame(), Schema({})), + Schema({}), + ), + ( + Row(pl.DataFrame({"col1": 0}), Schema({"col1": Integer()})), + Schema({"col1": Integer()}), + ), + ( + Row(pl.DataFrame({"col1": 0, "col2": "a"}), Schema({"col1": Integer(), "col2": String()})), Schema({"col1": Integer(), "col2": String()}), ), ], + ids=[ + "empty", + "one column", + "two columns", + ], ) def test_should_use_the_schema_if_passed(self, row: Row, expected: Schema) -> None: assert row._schema == expected @@ -45,8 +61,12 @@ def test_should_use_the_schema_if_passed(self, row: Row, expected: Schema) -> No @pytest.mark.parametrize( ("row", "expected"), [ - (Row([]), Schema({})), - (Row([0]), Schema({"column_0": Integer()})), + (Row(pl.DataFrame()), Schema({})), + (Row(pl.DataFrame({"col1": 0})), Schema({"col1": Integer()})), + ], + ids=[ + "empty", + "one column", ], ) def test_should_infer_the_schema_if_not_passed(self, row: Row, expected: Schema) -> None: @@ -63,16 +83,41 @@ class TestEq: (Row.from_dict({"col1": 0}), Row.from_dict({"col2": 0}), False), (Row.from_dict({"col1": 0}), Row.from_dict({"col1": "a"}), False), ], + ids=[ + "empty rows", + "equal rows", + "different values", + "different columns", + "different types", + ], ) def test_should_return_whether_two_rows_are_equal(self, row1: Row, row2: Row, expected: bool) -> None: assert (row1.__eq__(row2)) == expected + @pytest.mark.parametrize( + "row", + [ + Row.from_dict({}), + Row.from_dict({"col1": 0}), + ], + ids=[ + "empty", + "non-empty", + ], + ) + def test_should_return_true_if_objects_are_identical(self, row: Row) -> None: + assert (row.__eq__(row)) is True + @pytest.mark.parametrize( ("row", "other"), [ (Row.from_dict({"col1": 0}), None), (Row.from_dict({"col1": 0}), Table([])), ], + ids=[ + "Row vs. None", + "Row vs. Table", + ], ) def test_should_return_not_implemented_if_other_is_not_row(self, row: Row, other: Any) -> None: assert (row.__eq__(other)) is NotImplemented @@ -85,6 +130,10 @@ class TestGetitem: (Row.from_dict({"col1": 0}), "col1", 0), (Row.from_dict({"col1": 0, "col2": "a"}), "col2", "a"), ], + ids=[ + "one column", + "two columns", + ], ) def test_should_return_the_value_in_the_column(self, row: Row, column_name: str, expected: Any) -> None: assert row[column_name] == expected @@ -95,6 +144,10 @@ def test_should_return_the_value_in_the_column(self, row: Row, column_name: str, (Row.from_dict({}), "col1"), (Row.from_dict({"col1": 0}), "col2"), ], + ids=[ + "empty row", + "column does not exist", + ], ) def test_should_raise_if_column_does_not_exist(self, row: Row, column_name: str) -> None: with pytest.raises(UnknownColumnNameError): @@ -102,51 +155,72 @@ def test_should_raise_if_column_does_not_exist(self, row: Row, column_name: str) row[column_name] -class TestHash: +class TestIter: @pytest.mark.parametrize( - ("row1", "row2"), + ("row", "expected"), [ - (Row.from_dict({}), Row.from_dict({})), - (Row.from_dict({"col1": 0}), Row.from_dict({"col1": 0})), + (Row.from_dict({}), []), + (Row.from_dict({"col1": 0}), ["col1"]), + ], + ids=[ + "empty", + "non-empty", ], ) - def test_should_return_same_hash_for_equal_rows(self, row1: Row, row2: Row) -> None: - assert hash(row1) == hash(row2) + def test_should_return_an_iterator_for_the_column_names(self, row: Row, expected: list[str]) -> None: + assert list(row) == expected + +class TestLen: @pytest.mark.parametrize( - ("row1", "row2"), + ("row", "expected"), [ - (Row.from_dict({"col1": 0}), Row.from_dict({"col1": 1})), - (Row.from_dict({"col1": 0}), Row.from_dict({"col2": 0})), - (Row.from_dict({"col1": 0}), Row.from_dict({"col1": "a"})), + (Row.from_dict({}), 0), + (Row.from_dict({"col1": 0, "col2": "a"}), 2), + ], + ids=[ + "empty", + "non-empty", ], ) - def test_should_return_different_hash_for_unequal_rows(self, row1: Row, row2: Row) -> None: - assert hash(row1) != hash(row2) + def test_should_return_the_number_of_columns(self, row: Row, expected: int) -> None: + assert len(row) == expected -class TestIter: +class TestStr: @pytest.mark.parametrize( ("row", "expected"), [ - (Row.from_dict({}), []), - (Row.from_dict({"col1": 0}), ["col1"]), + (Row.from_dict({}), "{}"), + (Row.from_dict({"col1": 0}), "{'col1': 0}"), + (Row.from_dict({"col1": 0, "col2": "a"}), "{\n 'col1': 0,\n 'col2': 'a'\n}"), + ], + ids=[ + "empty", + "single column", + "multiple columns", ], ) - def test_should_return_an_iterator_for_the_column_names(self, row: Row, expected: list[str]) -> None: - assert list(row) == expected + def test_should_return_a_string_representation(self, row: Row, expected: str) -> None: + assert str(row) == expected -class TestLen: +class TestRepr: @pytest.mark.parametrize( ("row", "expected"), [ - (Row.from_dict({}), 0), - (Row.from_dict({"col1": 0, "col2": "a"}), 2), + (Row.from_dict({}), "Row({})"), + (Row.from_dict({"col1": 0}), "Row({'col1': 0})"), + (Row.from_dict({"col1": 0, "col2": "a"}), "Row({\n 'col1': 0,\n 'col2': 'a'\n})"), + ], + ids=[ + "empty", + "single column", + "multiple columns", ], ) - def test_should_return_the_number_of_columns(self, row: Row, expected: int) -> None: - assert len(row) == expected + def test_should_return_a_string_representation(self, row: Row, expected: str) -> None: + assert repr(row) == expected class TestGetValue: @@ -156,6 +230,10 @@ class TestGetValue: (Row.from_dict({"col1": 0}), "col1", 0), (Row.from_dict({"col1": 0, "col2": "a"}), "col2", "a"), ], + ids=[ + "one column", + "two columns", + ], ) def test_should_return_the_value_in_the_column(self, row: Row, column_name: str, expected: Any) -> None: assert row.get_value(column_name) == expected @@ -166,6 +244,10 @@ def test_should_return_the_value_in_the_column(self, row: Row, column_name: str, (Row.from_dict({}), "col1"), (Row.from_dict({"col1": 0}), "col2"), ], + ids=[ + "empty row", + "column does not exist", + ], ) def test_should_raise_if_column_does_not_exist(self, row: Row, column_name: str) -> None: with pytest.raises(UnknownColumnNameError): @@ -180,6 +262,11 @@ class TestHasColumn: (Row.from_dict({"col1": 0}), "col1", True), (Row.from_dict({"col1": 0}), "col2", False), ], + ids=[ + "empty row", + "column exists", + "column does not exist", + ], ) def test_should_return_whether_the_row_has_the_column(self, row: Row, column_name: str, expected: bool) -> None: assert row.has_column(column_name) == expected @@ -192,6 +279,10 @@ class TestGetColumnNames: (Row.from_dict({}), []), (Row.from_dict({"col1": 0}), ["col1"]), ], + ids=[ + "empty", + "non-empty", + ], ) def test_should_return_the_column_names(self, row: Row, expected: list[str]) -> None: assert row.get_column_names() == expected @@ -204,6 +295,10 @@ class TestGetTypeOfColumn: (Row.from_dict({"col1": 0}), "col1", Integer()), (Row.from_dict({"col1": 0, "col2": "a"}), "col2", String()), ], + ids=[ + "one column", + "two columns", + ], ) def test_should_return_the_type_of_the_column(self, row: Row, column_name: str, expected: ColumnType) -> None: assert row.get_type_of_column(column_name) == expected @@ -214,6 +309,10 @@ def test_should_return_the_type_of_the_column(self, row: Row, column_name: str, (Row.from_dict({}), "col1"), (Row.from_dict({"col1": 0}), "col2"), ], + ids=[ + "empty row", + "column does not exist", + ], ) def test_should_raise_if_column_does_not_exist(self, row: Row, column_name: str) -> None: with pytest.raises(UnknownColumnNameError): @@ -227,6 +326,10 @@ class TestCount: (Row.from_dict({}), 0), (Row.from_dict({"col1": 0, "col2": "a"}), 2), ], + ids=[ + "empty", + "non-empty", + ], ) def test_should_return_the_number_of_columns(self, row: Row, expected: int) -> None: assert row.count() == expected @@ -237,17 +340,37 @@ class TestToDict: ("row", "expected"), [ ( - Row([]), + Row(pl.DataFrame({})), {}, ), ( - Row([1, 2], schema=Schema({"a": Integer(), "b": Integer()})), + Row(pl.DataFrame({"a": 1, "b": 2})), { "a": 1, "b": 2, }, ), ], + ids=[ + "empty", + "non-empty", + ], ) def test_should_return_dict_for_table(self, row: Row, expected: dict[str, Any]) -> None: assert row.to_dict() == expected + + +class TestReprHtml: + @pytest.mark.parametrize( + "row", + [ + Row(pl.DataFrame({})), + Row(pl.DataFrame({"a": 1, "b": 2})), + ], + ids=[ + "empty", + "non-empty", + ], + ) + def test_should_contain_table_element(self, row: Row) -> None: + assert "