66 changes: 37 additions & 29 deletions src/midst_toolkit/models/clavaddpm/clustering.py
@@ -2,7 +2,7 @@

import os
import pickle
from collections import OrderedDict, defaultdict
from collections import defaultdict
from logging import INFO, WARNING
from pathlib import Path
from typing import Any
@@ -233,7 +233,7 @@ def _pair_clustering(

cluster_labels = _get_cluster_labels(cluster_data, clustering_method, num_clusters)

child_group_data = _get_group_data(sorted_child_data, foreign_key_index)
child_group_data = group_data_by_id(sorted_child_data, foreign_key_index, sort_by_column_value=True)
child_group_lengths = np.array([len(group) for group in child_group_data], dtype=int)

if clustering_method == ClusteringMethod.VARIATIONAL:
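
For context on the new call site, a minimal sketch of what `group_data_by_id` with `sort_by_column_value=True` produces for the clustering step (toy data; the three feature columns and the foreign key in column 3 are illustrative, not the real schema):

    import numpy as np

    from midst_toolkit.models.clavaddpm.clustering import group_data_by_id

    # Toy child table: three feature columns plus a foreign key in column 3.
    sorted_child_data = np.array(
        [
            [0.1, 0.2, 0.3, 0.0],
            [0.4, 0.5, 0.6, 1.0],
            [0.7, 0.8, 0.9, 0.0],
        ]
    )
    child_group_data = group_data_by_id(sorted_child_data, 3, sort_by_column_value=True)
    # Two groups, in ascending foreign-key order: rows with key 0.0, then key 1.0.
    child_group_lengths = np.array([len(group) for group in child_group_data], dtype=int)
    # -> array([2, 1])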
@@ -313,12 +313,12 @@ def _merge_parent_data_with_child_data(
child_data: Numpy array of the child data. Should be sorted by the foreign key.
parent_data: Numpy array of the parent data. Should be sorted by the parent primary key.
parent_primary_key_index: Index of the parent primary key.
foreign_key_index: Index of the foreign key to the child data.
foreign_key_index: Index of the foreign key in the child data.

Returns:
Numpy array of the parent data merged for each group of the child group data.
"""
child_group_data_dict = _group_data_by_group_id(child_data, foreign_key_index)
child_group_data_dict = group_data_by_group_id_as_dict(child_data, foreign_key_index)

group_lengths = []
unique_group_ids = parent_data[:, parent_primary_key_index]
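
A toy illustration of the lookup this hunk sets up (the merge logic below the fold is elided here; this sketch assumes every parent primary key appears as a foreign key in the child table):

    import numpy as np

    from midst_toolkit.models.clavaddpm.clustering import group_data_by_group_id_as_dict

    parent_data = np.array([[0.0, 10.0], [1.0, 20.0]])  # column 0 is the primary key
    child_data = np.array([[5.0, 0.0], [6.0, 0.0], [7.0, 1.0]])  # column 1 is the foreign key

    child_groups = group_data_by_group_id_as_dict(child_data, 1)
    # For each parent row, look up its child group and record the group size.
    group_lengths = [len(child_groups[int(pk)]) for pk in parent_data[:, 0]]
    # -> [2, 1]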
@@ -669,47 +669,55 @@ def _get_categorical_and_numerical_columns(
return numerical_columns, categorical_columns


def _group_data_by_group_id(
np_data: np.ndarray,
group_id_index: int,
def group_data_by_group_id_as_dict(
data_to_be_grouped: np.ndarray, column_index_to_group_by: int
) -> dict[int, list[np.ndarray]]:
"""
Collects the data in each group by group id and returns it as a dictionary.
Group the rows of a numpy array into a dictionary, keyed by their values in the column specified by
``column_index_to_group_by``. Returns a dict mapping each distinct value in that column to the list of
rows (the group) that share it.

Args:
np_data: Numpy array of the data.
group_id_index: The index of the data that contains the group id.
data_to_be_grouped: Numpy array of the data to be grouped.
column_index_to_group_by: Column index by which the data should be grouped.

Returns:
Dictionary of group data by group id.
Dictionary of group data where the keys are the distinct values from the column to group by and the values
are lists of full rows from ``data_to_be_grouped`` that share that column value.
"""
group_data_by_group_id = OrderedDict[int, list[np.ndarray]]()

for i in range(len(np_data)):
group_id = _parse_numpy_number_as_int(np_data[i, group_id_index])

if group_id not in group_data_by_group_id:
group_data_by_group_id[group_id] = []
grouped_data_dict: defaultdict[int, list[np.ndarray]] = defaultdict(list)
num_rows = len(data_to_be_grouped)
for row in range(num_rows):
row_id = _parse_numpy_number_as_int(data_to_be_grouped[row, column_index_to_group_by])
grouped_data_dict[row_id].append(data_to_be_grouped[row])

group_data_by_group_id[group_id].append(np_data[i])
return grouped_data_dict

return group_data_by_group_id
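
As a quick usage sketch of the defaultdict-based replacement (toy array; column 0 is the grouping column here):

    import numpy as np

    from midst_toolkit.models.clavaddpm.clustering import group_data_by_group_id_as_dict

    data = np.array([[0.0, 1.5], [1.0, 2.5], [0.0, 3.5]])
    groups = group_data_by_group_id_as_dict(data, 0)
    # Keys are ints (parsed via _parse_numpy_number_as_int); values are lists of rows:
    # {0: [array([0. , 1.5]), array([0. , 3.5])], 1: [array([1. , 2.5])]}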


def _get_group_data(np_data: np.ndarray, group_id_index: int) -> np.ndarray:
def group_data_by_id(
data_to_be_grouped: np.ndarray, column_index_to_group_by: int, sort_by_column_value: bool = False
) -> np.ndarray:
"""
Collects the data in each group by group id and returns it as a numpy array.
Group rows in a numpy array that share values in the column specified by ``column_index_to_group_by``.
Returns an array of arrays where each sub-array contains full rows sharing identical values in the grouping column.

Args:
np_data: Numpy array of the data.
group_id_index: The index of the data that contains the group id.
data_to_be_grouped: Numpy array of the data to be grouped.
column_index_to_group_by: Column index by which the data should be grouped.
sort_by_column_value: Whether or not the returned groups are sorted by the values in the column at index
``column_index_to_group_by``. Defaults to False.

Returns:
Numpy array of the data ordered by group id.
Numpy array of the data grouped by values in the column with index ``column_index_to_group_by``. The returned
array has dtype=object since groups may have different lengths.
"""
group_data_by_group_id = _group_data_by_group_id(np_data, group_id_index)
group_data_list = [np.array(group_data) for group_data in group_data_by_group_id.values()]
return np.array(group_data_list, dtype=object)
grouped_data_by_group_id = group_data_by_group_id_as_dict(data_to_be_grouped, column_index_to_group_by)
if sort_by_column_value:
grouped_data = [(key, np.array(group_data)) for key, group_data in grouped_data_by_group_id.items()]
grouped_data_list = [data for _, data in sorted(grouped_data)]
else:
grouped_data_list = [np.array(group_data) for group_data in grouped_data_by_group_id.values()]
return np.array(grouped_data_list, dtype=object)
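
And a sketch of the two orderings `group_data_by_id` can return (toy array; column 0 is the grouping column):

    import numpy as np

    from midst_toolkit.models.clavaddpm.clustering import group_data_by_id

    data = np.array([[2.0, 0.1], [0.0, 0.2], [2.0, 0.3]])
    unsorted_groups = group_data_by_id(data, 0)
    # Groups appear in first-seen order: the key-2 group (two rows), then the key-0 group.
    sorted_groups = group_data_by_id(data, 0, sort_by_column_value=True)
    # Groups sorted by key: the key-0 group first, then the key-2 group.
    # Each element is a 2D array of that group's rows; the outer array has dtype=object.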


def _parse_numpy_number_as_int(number: np.number) -> int:
96 changes: 96 additions & 0 deletions tests/unit/models/clavaddpm/test_clustering.py
@@ -5,6 +5,8 @@
_min_max_normalize_sklearn,
_quantile_normalize_sklearn,
get_normalized_numerical_columns,
group_data_by_group_id_as_dict,
group_data_by_id,
)
from midst_toolkit.models.clavaddpm.enumerations import DataAndKeyNormalizationType

@@ -81,3 +83,97 @@ def test_get_normalized_numerical_columns() -> None:
)

unset_all_random_seeds()


def test_group_data_by_id() -> None:
set_all_random_seeds(42)
data_array_with_one_foreign_keys = np.hstack(
(np.random.randn(10, 3), np.random.randint(0, 3, (10, 1)).astype(float), np.random.randn(10, 1))
)
data_array_with_foreign_key_in_front = np.hstack(
(np.random.randint(0, 2, (10, 1)).astype(float), np.random.randn(10, 3), np.random.randn(10, 1))
)

grouped_data = group_data_by_id(data_array_with_one_foreign_keys, 3)
assert len(grouped_data) == 3
assert len(grouped_data[0]) == 4
assert len(grouped_data[1]) == 2
assert len(grouped_data[2]) == 4
assert np.allclose(
grouped_data[0],
np.array(
[
[0.49671415, -0.1382643, 0.64768854, 2.0, 2.77831304],
[1.52302986, -0.23415337, -0.23413696, 2.0, 1.19363972],
[0.54256004, -0.46341769, -0.46572975, 2.0, 0.88176104],
[0.24196227, -1.91328024, -1.72491783, 2.0, -1.00908534],
]
),
atol=1e-6,
)
assert np.allclose(
grouped_data[1],
np.array(
[
[1.57921282, 0.76743473, -0.46947439, 0.0, 0.21863832],
[-0.90802408, -1.4123037, 1.46564877, 0.0, 0.77370042],
],
),
atol=1e-6,
)

grouped_data = group_data_by_id(data_array_with_foreign_key_in_front, 0, sort_by_column_value=True)
# Because the first column is non-unique, we get proper groups.
assert len(grouped_data) == 2
assert len(grouped_data[0]) == 9
assert len(grouped_data[1]) == 1
assert np.allclose(
grouped_data[1],
np.array([[1.0, -0.676922, 0.61167629, 1.03099952, 1.47789404]]),
atol=1e-6,
)
assert np.allclose(
grouped_data[0],
np.array(
[
[0.0, 0.93128012, -0.83921752, -0.30921238, -0.51827022],
[0.0, 0.33126343, 0.97554513, -0.47917424, -0.8084936],
[0.0, -0.18565898, -1.10633497, -1.19620662, -0.50175704],
[0.0, 0.81252582, 1.35624003, -0.07201012, 0.91540212],
[0.0, 1.0035329, 0.36163603, -0.64511975, 0.32875111],
[0.0, 0.36139561, 1.53803657, -0.03582604, -0.5297602],
[0.0, 1.56464366, -2.6197451, 0.8219025, 0.51326743],
[0.0, 0.08704707, -0.29900735, 0.09176078, 0.09707755],
[0.0, -1.98756891, -0.21967189, 0.35711257, 0.96864499],
]
),
atol=1e-6,
)
unset_all_random_seeds()


def test_group_data_by_group_id_as_dict() -> None:
set_all_random_seeds(42)
data_array_with_one_foreign_keys = np.hstack(
(np.random.randn(10, 3), np.random.randint(0, 3, (10, 1)).astype(float), np.random.randn(10, 1))
)
data_array_with_foreign_key_in_front = np.hstack(
(np.random.randint(0, 2, (10, 1)).astype(float), np.random.randn(10, 3), np.random.randn(10, 1))
)

grouped_data = group_data_by_group_id_as_dict(data_array_with_one_foreign_keys, 3)
assert len(grouped_data) == 3
assert len(grouped_data[2]) == 4
assert len(grouped_data[0]) == 2
assert np.allclose(grouped_data[0][0], np.array([1.57921282, 0.76743473, -0.46947439, 0.0, 0.21863832]), atol=1e-6)
assert np.allclose(grouped_data[0][1], np.array([-0.90802408, -1.4123037, 1.46564877, 0.0, 0.77370042]), atol=1e-6)
assert np.allclose(
grouped_data[2][1], np.array([1.52302986, -0.23415337, -0.23413696, 2.0, 1.19363972]), atol=1e-6
)

grouped_data = group_data_by_group_id_as_dict(data_array_with_foreign_key_in_front, 0)
# Because the first column is non-unique, we get proper groups.
assert len(grouped_data) == 2
assert len(grouped_data[0]) == 9
assert len(grouped_data[1]) == 1
unset_all_random_seeds()