From 70e6d01ac1201b80a17529710ad730a9f73812e8 Mon Sep 17 00:00:00 2001 From: Emmanuel Jordy Menvouta <56538317+emmanueljordy@users.noreply.github.com> Date: Thu, 22 May 2025 20:45:33 +0200 Subject: [PATCH] Update data_processor.py Modify post_process function to deal with out of bounds generated values. --- synthpop/processor/data_processor.py | 52 +++++++++++++++++++++++++--- 1 file changed, 47 insertions(+), 5 deletions(-) diff --git a/synthpop/processor/data_processor.py b/synthpop/processor/data_processor.py index ee857ef..2116cb6 100644 --- a/synthpop/processor/data_processor.py +++ b/synthpop/processor/data_processor.py @@ -32,6 +32,7 @@ def __init__(self, metadata, enforce_rounding=True, enforce_min_max_values=True, self.encoders = {} # Stores encoders for categorical columns self.scalers = {} # Stores scalers for numerical columns self.original_columns = None # To restore column order + self._original_dtypes = None # Store original dtypes def preprocess(self, data: pd.DataFrame) -> pd.DataFrame: """Transform the raw data into numerical space.""" @@ -43,6 +44,7 @@ def preprocess(self, data: pd.DataFrame) -> pd.DataFrame: self.validate(data) self.original_columns = data.columns # Store original column order + self._original_dtypes = data.dtypes # Store original dtypes processed_data = self._preprocess(data) return processed_data @@ -88,6 +90,12 @@ def postprocess(self, synthetic_data: pd.DataFrame) -> pd.DataFrame: elif dtype == "numerical" and col in self.scalers: scaler = self.scalers[col] synthetic_data[col] = scaler.inverse_transform(synthetic_data[[col]]) + + # Restore original dtype for numerical columns + if self._original_dtypes is not None: + original_dtype = self._original_dtypes[col] + if np.issubdtype(original_dtype, np.integer): + synthetic_data[col] = synthetic_data[col].round().astype(original_dtype) elif dtype == "boolean": synthetic_data[col] = synthetic_data[col].round().astype(bool) @@ -123,13 +131,47 @@ def _encode_categorical(self, series: pd.Series, encoder): encoded_df = pd.DataFrame(encoded_array, columns=encoder.get_feature_names_out([series.name])) return encoded_df - def _decode_categorical(self, series: pd.Series, encoder): - """Decode categorical columns.""" + def _decode_categorical(self, encoded: pd.Series or pd.DataFrame, encoder): + """ + Decode categorical columns, snapping any out‐of‐range codes back to the nearest + valid category (or to NaN), so novel copula values won't blow up. + """ + # LABEL ENCODER CASE if isinstance(encoder, LabelEncoder): - return encoder.inverse_transform(series.astype(int)) + # Pull out the raw numeric codes (may be floats from copula) + codes = np.rint(encoded.astype(float)).astype(int) + max_idx = len(encoder.classes_) - 1 + + # Any code outside [0, max_idx] → -1 sentinel + safe_codes = np.where((codes >= 0) & (codes <= max_idx), codes, -1) + + # Map valid codes back to labels, sentinel→NaN + decoded = [ + encoder.classes_[c] if c >= 0 else np.nan + for c in safe_codes + ] + return pd.Series(decoded, index=getattr(encoded, "index", None)) + + # ONE-HOT ENCODER CASE elif isinstance(encoder, OneHotEncoder): - category_index = np.argmax(series.values, axis=1) - return encoder.categories_[0][category_index] + # Ensure a 2D array of one-hot "scores" + arr = encoded.values if isinstance(encoded, pd.DataFrame) else np.asarray(encoded) + if arr.ndim == 1: + # If someone passed a flat Series, assume the first category axis: + n_cat = len(encoder.categories_[0]) + arr = arr.reshape(-1, n_cat) + + # Argmax and clip into [0, n_cat-1] + idx = np.argmax(arr, axis=1) + max_idx = len(encoder.categories_[0]) - 1 + idx = np.clip(idx, 0, max_idx) + + # Look up the category labels + cats = encoder.categories_[0] + return pd.Series(cats[idx], index=getattr(encoded, "index", None)) + + else: + raise TypeError(f"Unsupported encoder type: {type(encoder)}") def _handle_missing_values(self, series: pd.Series): """Handle missing values based on column type."""