Skip to content

Commit

Permalink
[Dataset] Fixed split code
Browse files Browse the repository at this point in the history
  • Loading branch information
YanSte committed Aug 30, 2023
1 parent 8dc3723 commit f050db2
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions src/skit/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,17 +245,18 @@ def stratifiedTrainValidSplit(df, x_feature_columns, y_target_columns, num_split
valid_df : DataFrame
The validation set DataFrame.
"""

# Initialize StratifiedKFold
# ----
stratifiedKFold = StratifiedKFold(n_splits=num_splits, random_state=seed, shuffle=shuffle)

# Add a new column for the fold
# ----
df["Fold"] = "train"

# Prepare the features and labels
X = df[x_feature_columns]
y = df[y_target_columns]

# Add a new column for the fold
df["Fold"] = "train"

# Perform the split
for fold_no, (train, valid) in enumerate(stratifiedKFold.split(X, y), start=1):
if fold_no == selected_fold:
Expand All @@ -264,6 +265,7 @@ def stratifiedTrainValidSplit(df, x_feature_columns, y_target_columns, num_split
# Separate into train and valid DataFrames and reset index
train_df = df[df.Fold == "train"].reset_index(drop=True)
valid_df = df[df.Fold == "valid"].reset_index(drop=True)

df.drop(columns=['Fold'], inplace=True)

return train_df, valid_df
Expand Down

0 comments on commit f050db2

Please sign in to comment.