diff --git a/src/abm_shape_collection/calculate_shape_stats.py b/src/abm_shape_collection/calculate_shape_stats.py index 90b476e..ecb5266 100644 --- a/src/abm_shape_collection/calculate_shape_stats.py +++ b/src/abm_shape_collection/calculate_shape_stats.py @@ -11,8 +11,9 @@ def calculate_shape_stats( ) -> pd.DataFrame: all_stats = [] - data_transform = pca.transform(data.filter(like="shcoeffs").values) - ref_data_transform = pca.transform(ref_data.filter(like="shcoeffs").values) + columns = ref_data.filter(like="shcoeffs").columns + data_transform = pca.transform(data[columns].values) + ref_data_transform = pca.transform(ref_data[columns].values) for component in range(components): ks_stats = get_ks_statistic(data_transform[:, component], ref_data_transform[:, component]) @@ -20,7 +21,7 @@ def calculate_shape_stats( all_stats.append(ks_stats) for tick, tick_data in data.groupby("TICK"): - tick_data_transform = pca.transform(tick_data.filter(like="shcoeffs").values) + tick_data_transform = pca.transform(tick_data[columns].values) for component in range(components): tick_ks_stats = get_ks_statistic(