Skip to content

Commit

Permalink
Fix error in gene_random_walk (#375)
Browse files Browse the repository at this point in the history
train_X_ori and val_X_ori didn't get standardized, fix in this PR
  • Loading branch information
WenjieDu committed Apr 30, 2024
1 parent f332a0c commit 6446149
Showing 1 changed file with 12 additions and 11 deletions.
23 changes: 12 additions & 11 deletions pypots/data/generating.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,8 +274,6 @@ def gene_random_walk(
# create random missing values
train_X_ori = train_X
train_X = mcar(train_X, missing_rate)
val_X_ori = val_X
val_X = mcar(val_X, missing_rate)
# test set is left to mask after normalization

train_X = train_X.reshape(-1, n_features)
Expand Down Expand Up @@ -305,18 +303,21 @@ def gene_random_walk(

if missing_rate > 0:
# mask values in the test set as ground truth
test_X_ori = test_X
test_X = mcar(test_X, missing_rate)

data["train_X"] = train_X
train_X_ori = scaler.transform(train_X_ori.reshape(-1, n_features)).reshape(
-1, n_steps, n_features
)
data["train_X_ori"] = train_X_ori

val_X_ori = val_X
val_X = mcar(val_X, missing_rate)
data["val_X"] = val_X
data["val_X_ori"] = val_X_ori

# test_X is for model input
test_X_ori = test_X
test_X = mcar(test_X, missing_rate)
data["test_X"] = test_X
data["test_X_ori"] = test_X_ori
data["test_X_indicating_mask"] = ~np.isnan(test_X_ori) ^ ~np.isnan(test_X)
data["test_X_ori"] = np.nan_to_num(test_X_ori) # fill NaNs for later error calc
data["test_X_indicating_mask"] = np.isnan(test_X_ori) ^ np.isnan(test_X)

return data

Expand Down Expand Up @@ -421,7 +422,7 @@ def gene_physionet2012(artificially_missing_rate: float = 0.1):
# test_X is for model input
data["test_X"] = test_X
# test_X_ori is for error calc, not for model input, hence mustn't have NaNs
data["test_X_ori"] = np.nan_to_num(test_X_ori)
data["test_X_indicating_mask"] = ~np.isnan(test_X_ori) ^ ~np.isnan(test_X)
data["test_X_ori"] = np.nan_to_num(test_X_ori) # fill NaNs for later error calc
data["test_X_indicating_mask"] = np.isnan(test_X_ori) ^ np.isnan(test_X)

return data

0 comments on commit 6446149

Please sign in to comment.