Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 47 additions & 5 deletions synthpop/processor/data_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def __init__(self, metadata, enforce_rounding=True, enforce_min_max_values=True,
self.encoders = {} # Stores encoders for categorical columns
self.scalers = {} # Stores scalers for numerical columns
self.original_columns = None # To restore column order
self._original_dtypes = None # Store original dtypes

def preprocess(self, data: pd.DataFrame) -> pd.DataFrame:
"""Transform the raw data into numerical space."""
Expand All @@ -43,6 +44,7 @@ def preprocess(self, data: pd.DataFrame) -> pd.DataFrame:

self.validate(data)
self.original_columns = data.columns # Store original column order
self._original_dtypes = data.dtypes # Store original dtypes
processed_data = self._preprocess(data)

return processed_data
Expand Down Expand Up @@ -88,6 +90,12 @@ def postprocess(self, synthetic_data: pd.DataFrame) -> pd.DataFrame:
elif dtype == "numerical" and col in self.scalers:
scaler = self.scalers[col]
synthetic_data[col] = scaler.inverse_transform(synthetic_data[[col]])

# Restore original dtype for numerical columns
if self._original_dtypes is not None:
original_dtype = self._original_dtypes[col]
if np.issubdtype(original_dtype, np.integer):
synthetic_data[col] = synthetic_data[col].round().astype(original_dtype)

elif dtype == "boolean":
synthetic_data[col] = synthetic_data[col].round().astype(bool)
Expand Down Expand Up @@ -123,13 +131,47 @@ def _encode_categorical(self, series: pd.Series, encoder):
encoded_df = pd.DataFrame(encoded_array, columns=encoder.get_feature_names_out([series.name]))
return encoded_df

def _decode_categorical(self, encoded: "pd.Series | pd.DataFrame", encoder):
    """Decode model-space values back into category labels.

    Synthetic values may be non-integer, out of range, or non-finite, so
    decoding is defensive rather than a plain ``inverse_transform``.

    Args:
        encoded: For a ``LabelEncoder``, a Series of (possibly float) codes.
            For a ``OneHotEncoder``, a DataFrame of one-hot "scores" with one
            column per category (a flat array-like is reshaped to that width).
        encoder: The fitted ``LabelEncoder`` or ``OneHotEncoder`` for the column.

    Returns:
        A ``pd.Series`` of decoded labels. It reuses the index of ``encoded``
        when one exists and its length matches the number of decoded rows.
        For a ``LabelEncoder``, codes that round outside
        ``[0, len(classes_) - 1]`` or are NaN/inf decode to ``NaN``.

    Raises:
        TypeError: If ``encoder`` is neither a ``LabelEncoder`` nor a
            ``OneHotEncoder``.
    """
    index = getattr(encoded, "index", None)

    # LABEL ENCODER CASE
    if isinstance(encoder, LabelEncoder):
        classes = encoder.classes_
        max_idx = len(classes) - 1

        # Round in float space first: casting NaN/inf straight to int raises
        # in pandas, so validity is decided before any integer conversion.
        rounded = np.rint(np.asarray(encoded, dtype=float)).ravel()
        valid = np.isfinite(rounded) & (rounded >= 0) & (rounded <= max_idx)

        # Invalid slots get a placeholder code of 0; they are masked out below.
        codes = np.where(valid, rounded, 0).astype(int)
        decoded = [
            classes[code] if ok else np.nan
            for code, ok in zip(codes, valid)
        ]
        return pd.Series(decoded, index=index)

    # ONE-HOT ENCODER CASE
    if isinstance(encoder, OneHotEncoder):
        cats = encoder.categories_[0]
        n_cat = len(cats)

        scores = np.asarray(encoded, dtype=float)
        if scores.ndim == 1:
            # A flat input is interpreted as rows of n_cat scores each.
            scores = scores.reshape(-1, n_cat)

        # np.argmax treats NaN as the maximum; demote NaN so it never wins
        # over a real score.
        scores = np.where(np.isnan(scores), -np.inf, scores)

        # The clip only matters if the input is wider than n_cat columns.
        idx = np.clip(np.argmax(scores, axis=1), 0, n_cat - 1)
        labels = cats[idx]

        # After a reshape the row count no longer matches the input index;
        # fall back to a default index instead of raising on the mismatch.
        if index is not None and len(index) != len(labels):
            index = None
        return pd.Series(labels, index=index)

    raise TypeError(f"Unsupported encoder type: {type(encoder)}")

def _handle_missing_values(self, series: pd.Series):
"""Handle missing values based on column type."""
Expand Down