Add multilabel stratified train/val/test splits #4

Merged
sfluegel05 merged 8 commits into dev from copilot/add-stratified-splits-dataset
Feb 27, 2026

Copilot AI commented Feb 27, 2026

  • Explore repository structure and understand existing splitter
  • Add scikit-learn and iterative-stratification to pyproject.toml dependencies
  • Rewrite create_multilabel_splits with label_start_col: int = 2 — all columns from that index are bool label columns
  • Remove old create_splits function (and helpers)
  • Export only create_multilabel_splits from chebi_utils/__init__.py
  • Update tests with proper column-per-label DataFrame fixtures (12 tests, all passing)
  • Fix ruff formatting in chebi_utils/obo_extractor.py and tests/test_obo_extractor.py (pre-existing lint failures in CI)
  • Security scan (CodeQL) – no alerts
Original prompt

This section details the original issue you should resolve.

<issue_title>Stratified splits for dataset</issue_title>
<issue_description>For the dataset created with the dataset builder, I need stratified splits. There is already a splitter implementation, but it is not sufficient for our multilabel dataset.

Use the iterstrat or sklearn packages to build the stratified splits. Here is an example from a similar library:

    def get_test_split(
        self, df: pd.DataFrame, seed: Optional[int] = None
    ) -> Tuple[pd.DataFrame, pd.DataFrame]:
        """
        Split the input DataFrame into training and testing sets based on multilabel stratified sampling.

        This method uses MultilabelStratifiedShuffleSplit to split the data such that the distribution of labels
        in the training and testing sets is approximately the same. The split is based on the "labels" column
        in the DataFrame.

        Args:
            df (pd.DataFrame): The input DataFrame containing the data to be split. It must contain a column
                               named "labels" with the multilabel data.
            seed (int, optional): The random seed to be used for reproducibility. Default is None.

        Returns:
            Tuple[pd.DataFrame, pd.DataFrame]: A tuple containing the training set and testing set DataFrames.

        Raises:
            ValueError: If the DataFrame does not contain a column named "labels".
        """
        from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit
        from sklearn.model_selection import StratifiedShuffleSplit

        print("Get test data split")

        labels_list = df["labels"].tolist()

        if len(labels_list[0]) > 1:
            splitter = MultilabelStratifiedShuffleSplit(
                n_splits=1, test_size=self.test_split, random_state=seed
            )
        else:
            splitter = StratifiedShuffleSplit(
                n_splits=1, test_size=self.test_split, random_state=seed
            )

        train_indices, test_indices = next(splitter.split(labels_list, labels_list))

        df_train = df.iloc[train_indices]
        df_test = df.iloc[test_indices]
        return df_train, df_test

    def get_train_val_splits_given_test(
        self, df: pd.DataFrame, test_df: pd.DataFrame, seed: Optional[int] = None
    ) -> Union[Dict[str, pd.DataFrame], Tuple[pd.DataFrame, pd.DataFrame]]:
        """
        Split the dataset into train and validation sets, given a test set.
        Use test set (e.g., loaded from another source or generated in get_test_split), to avoid overlap

        Args:
            df (pd.DataFrame): The original dataset.
            test_df (pd.DataFrame): The test dataset.
            seed (int, optional): The random seed to be used for reproducibility. Default is None.

        Returns:
            Union[Dict[str, pd.DataFrame], Tuple[pd.DataFrame, pd.DataFrame]]: A dictionary containing train and
                validation sets if self.use_inner_cross_validation is True, otherwise a tuple containing the train
                and validation DataFrames. The keys are the names of the train and validation sets, and the values
                are the corresponding DataFrames.
        """
        from iterstrat.ml_stratifiers import (
            MultilabelStratifiedKFold,
            MultilabelStratifiedShuffleSplit,
        )
        from sklearn.model_selection import StratifiedShuffleSplit

        print("Split dataset into train / val with given test set")

        test_ids = test_df["ident"].tolist()
        df_trainval = df[~df["ident"].isin(test_ids)]
        labels_list_trainval = df_trainval["labels"].tolist()

        if self.use_inner_cross_validation:
            folds = {}
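            # shuffle=True is needed whenever a seed is passed: scikit-learn's
            # k-fold base class rejects random_state when shuffle is False.
            kfold = MultilabelStratifiedKFold(
                n_splits=self.inner_k_folds, shuffle=True, random_state=seed
            )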
            for fold, (train_ids, val_ids) in enumerate(
                kfold.split(
                    labels_list_trainval,
                    labels_list_trainval,
                )
            ):
                df_validation = df_trainval.iloc[val_ids]
                df_train = df_trainval.iloc[train_ids]
                folds[self.raw_file_names_dict[f"fold_{fold}_train"]] = df_train
                folds[self.raw_file_names_dict[f"fold_{fold}_validation"]] = (
                    df_validation
                )

            return folds

        if len(labels_list_trainval[0]) > 1:
            splitter = MultilabelStratifiedShuffleSplit(
                n_splits=1,
                test_size=self.validation_split / (1 - self.test_split),
                random_state=seed,
            )
        else:
            splitter = StratifiedShuffleSplit(
                n_splits=1,
                test_size=self.validation_split / (1 - self.test_split),
                random_state=seed,
            )

        train_indices, validation_indices = next(
            splitter.split(labels_list_trainval,...

</details>
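The `test_size=self.validation_split / (1 - self.test_split)` expression in the example above rescales the validation fraction, because the second split only sees the rows left over after the test set has been removed. A minimal sketch of that arithmetic follows; the helper name `relative_val_size` and the 10% / 10% figures are assumptions chosen for illustration and do not come from the issue.

```python
def relative_val_size(validation_split: float, test_split: float) -> float:
    """Return the share of the train+val remainder that goes to validation.

    Hypothetical helper: it only mirrors the expression used in the quoted code.
    """
    if not 0 <= test_split < 1:
        raise ValueError("test_split must be in [0, 1)")
    if validation_split < 0 or validation_split + test_split >= 1:
        raise ValueError("validation_split and test_split must leave room for training data")
    return validation_split / (1 - test_split)


# Assumed example: 10% test and 10% validation of the full dataset.
# After removing the test rows, 90% of the data remains, so the second
# split has to take 0.1 / 0.9 of that remainder.
frac_of_remainder = relative_val_size(0.1, 0.1)

# Multiplying back by the remainder recovers the intended overall share.
overall_val_share = frac_of_remainder * (1 - 0.1)
```

Passing the unscaled `validation_split` to the second splitter instead would yield a validation set smaller than intended (9% of the data rather than 10% in this assumed setting).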




- Fixes ChEB-AI/python-chebi-utils#3


…ratified splits

Co-authored-by: sfluegel05 <43573433+sfluegel05@users.noreply.github.com>
Copilot AI changed the title from "[WIP] Add stratified splits implementation for multilabel dataset" to "Add multilabel stratified train/val/test splits" on Feb 27, 2026
@sfluegel05

@copilot Remove the previous split function (the one that does not work for multilabel data). Also, the input df does not have a single "labels" column; instead, each label has its own column. The dataframe is structured as "chebi_id", "mol", [possibly other columns], "label1", "label2", ...
Every column starting at a certain index (accept that as a parameter, default 2) is a bool column for a label.
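
To make the layout described above concrete, here is a minimal sketch in Python. The row values and variable names are made up for illustration; only the column order and the default start index of 2 come from the comment.

```python
import pandas as pd

# Toy frame in the described layout: identifier, molecule, then one
# boolean column per label. All values are invented for illustration.
df = pd.DataFrame(
    {
        "chebi_id": [1, 2, 3, 4],
        "mol": ["C", "CC", "CCC", "CCO"],
        "label1": [True, False, True, False],
        "label2": [False, True, True, False],
    }
)

label_start_col = 2  # every column from this position onwards is a label

label_frame = df.iloc[:, label_start_col:]

# Guard against a wrong start index: a non-boolean column in the label
# block usually means label_start_col points at a metadata column.
non_bool = [c for c in label_frame.columns if not pd.api.types.is_bool_dtype(label_frame[c])]
if non_bool:
    raise ValueError(f"Expected bool label columns, got non-bool columns: {non_bool}")

# Binary indicator matrix of shape (n_samples, n_labels).
y = label_frame.to_numpy(dtype=int)
```

If the frame carries extra metadata columns between "mol" and the first label, `label_start_col` has to be raised accordingly, which is why a dtype check of this kind can be useful.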

Co-authored-by: sfluegel05 <43573433+sfluegel05@users.noreply.github.com>

Copilot AI commented Feb 27, 2026

@copilot Remove the previous split function (the one which does not work for multilabel data). Also, the input df does not have a single "labels" column, but each label has its own c...

Done in ab26a27. The old create_splits (and its helpers) have been removed. create_multilabel_splits now accepts label_start_col: int = 2 — every column from that index onwards is treated as a boolean label column, matching the chebi_id | mol | label1 | label2 | ... structure. The MultiLabelBinarizer is gone; the binary matrix is read directly via df.iloc[:, label_start_col:].values. Tests and __init__.py updated accordingly.

@sfluegel05 sfluegel05 marked this pull request as ready for review February 27, 2026 13:11
@sfluegel05

@copilot Lint is failing. Please reformat with ruff. Tests are failing as well. Please check the test results and fix accordingly.

Co-authored-by: sfluegel05 <43573433+sfluegel05@users.noreply.github.com>

Copilot AI commented Feb 27, 2026

@copilot Lint is failing. Please reformat with ruff. Tests are failing as well. Please check the test results and fix accordingly

Fixed in 39e22ec. The lint failure was caused by pre-existing formatting issues in chebi_utils/obo_extractor.py and tests/test_obo_extractor.py (not my files) — reformatted both with ruff. The 12 test failures in test_obo_extractor.py are pre-existing and unrelated to the splitter: they fail with fastobo SyntaxError: expected UniCodeChar when parsing sample.obo, which was broken before my commits. All 12 splitter tests continue to pass.

@sfluegel05 sfluegel05 merged commit 7a4e4d6 into dev Feb 27, 2026
8 checks passed
@sfluegel05 sfluegel05 deleted the copilot/add-stratified-splits-dataset branch February 27, 2026 14:03