In [None]:
#| default_exp cohort_selector

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
import re
from typing import List, Any, Dict, Union
import warnings

import numpy as np
import pandas as pd

In [None]:
#| export
from pheno_utils.config import (
    DATASETS_PATH, 
    COHORT, 
    ERROR_ACTION
)
from pheno_utils.meta_loader import MetaLoader

In [None]:
#| export

class CohortSelector:
    """
    Class for selecting a subset of a cohort's data based on a query.

    The query is evaluated with ``pandas.DataFrame.query``. Field names may
    optionally be prefixed by a dataset name, e.g. ``fundus/collection_date``.

    Args:

        base_path (str, optional): Base path of the datasets. Defaults to DATASETS_PATH.
        cohort (str, optional): Name of the cohort. Defaults to COHORT.
        errors (str, optional): Error action. Defaults to ERROR_ACTION.
        **kwargs: Additional keyword arguments.

    Attributes:

        cohort (str): Name of the cohort.
        base_path (str): Base path of the datasets.
        errors (str): Error action.
        kwargs: Additional keyword arguments.
        ml (MetaLoader): MetaLoader object for loading metadata and data.

    """

    # Quoted string literals: their contents are values, never field names.
    # NOTE: escaped quotes inside a literal are not handled.
    _STRING_LITERAL = re.compile(r'''("[^"]*"|'[^']*')''')
    # A field name, optionally containing a `dataset/` prefix.
    _FIELD_NAME = re.compile(r'([a-zA-Z][a-zA-Z0-9_/]*)\b')
    # Words that belong to the query syntax itself and must not be looked up
    # as fields.
    _QUERY_KEYWORDS = frozenset(
        {'and', 'or', 'not', 'in', 'True', 'False', 'None'})

    def __init__(
        self,
        base_path: str = DATASETS_PATH,
        cohort: str = COHORT,
        errors: str = ERROR_ACTION,
        **kwargs,
    ) -> None:
        """
        Initialize CohortSelector object.

        Args:

            base_path (str, optional): Base path of the datasets. Defaults to DATASETS_PATH.
            cohort (str, optional): Name of the cohort. Defaults to COHORT.
            errors (str, optional): Error action. Defaults to ERROR_ACTION.
            **kwargs: Additional keyword arguments.

        """
        self.cohort = cohort
        self.base_path = base_path
        self.errors = errors
        self.kwargs = kwargs

        self.ml = MetaLoader(
            base_path=self.base_path, cohort=self.cohort,
            flexible_field_search=False, errors=self.errors,
            **self.kwargs)

    @classmethod
    def _extract_column_names(cls, query: str) -> List[str]:
        """
        Return the distinct field names referenced in `query`, in order of
        first appearance.

        Text inside quoted string literals and query keywords (``and``,
        ``or``, ``not``, ``in``, ``True``, ``False``, ``None``) are ignored.
        """
        # Blank out literals with a space so neighbouring tokens do not merge.
        without_literals = cls._STRING_LITERAL.sub(' ', query)
        names = (name for name in cls._FIELD_NAME.findall(without_literals)
                 if name not in cls._QUERY_KEYWORDS)
        # dict.fromkeys de-duplicates while preserving order.
        return list(dict.fromkeys(names))

    @classmethod
    def _sub_outside_literals(cls, pattern: str, repl: str, query: str) -> str:
        """
        Apply ``re.sub(pattern, repl, ...)`` to `query`, leaving the contents
        of quoted string literals untouched.
        """
        # The pattern has one capturing group, so re.split alternates between
        # text outside literals (even indices) and the literals (odd indices).
        parts = cls._STRING_LITERAL.split(query)
        parts[::2] = [re.sub(pattern, repl, part) for part in parts[::2]]
        return ''.join(parts)

    def select(self, query: str) -> pd.DataFrame:
        """
        Select a subset of the cohort's data based on the given query.

        Args:

            query (str): Query string to filter the data, in
                ``pandas.DataFrame.query`` syntax. Field names may be given
                as ``field`` or ``dataset/field``.

        Returns:

            pd.DataFrame: Filtered DataFrame based on the query.

        Raises:

            ValueError: If no column names are found in the query.
            ValueError: If column names in the query do not match the column names in the metadata.

        """
        column_names = self._extract_column_names(query)
        if not column_names:
            raise ValueError('No column names found in query')

        test_cols = self.ml.get(column_names)
        # A queried name is known if it equals a metadata column or the
        # second '/'-separated segment of one.
        known = set(test_cols.columns)
        for name in test_cols.columns:
            if isinstance(name, str):
                pieces = name.split('/')
                if len(pieces) > 1:
                    known.add(pieces[1])
        missing_cols = [col for col in column_names if col not in known]
        if missing_cols:
            raise ValueError(f'Column names {missing_cols} in query do not match column names in metadata')

        df = self.ml.load(column_names)

        try:
            # First attempt: drop the `dataset/` prefixes and query by bare
            # field name.
            return df.query(self._sub_outside_literals(r'\w+/', '', query))
        except Exception:
            # Fallback: query by `dataset__field` names instead.
            # NOTE(review): presumably the loaded frame uses such names in
            # some cases — confirm against MetaLoader.load.
            fallback_query = self._sub_outside_literals(
                r'(\w+)/', r'\1__', query)
            # Show the query that is actually being retried.
            print(fallback_query)
            return df.query(fallback_query)


In [None]:
# Build a selector using the default base path, cohort and error action.
cs = CohortSelector()

You can use `CohortSelector` to select participants based on any field. Queries are written in pandas `DataFrame.query` syntax. The selector returns a DataFrame containing the selected sub-cohort, along with the fields that were used in the query.

For example, the following query selects participants with an AHI between 15 and 20 (within the moderate obstructive sleep apnea range), based on recordings of more than 4 hours of sleep.

In [None]:
# Keep rows with 15 < ahi < 20 and total_sleep_time above 4*3600
# (4 hours, assuming total_sleep_time is in seconds — TODO confirm units).
cs.select('15 < ahi < 20 & total_sleep_time > 4*3600')

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,ahi,total_sleep_time
participant_id,cohort,research_stage,array_index,Unnamed: 4_level_1,Unnamed: 5_level_1
9,10k,02_00_visit,2,18.39,23748.0
15,10k,00_00_visit,0,19.54,19230.0
30,10k,00_00_visit,0,18.52,25111.0
42,10k,00_00_visit,1,17.84,24966.0
49,10k,02_00_visit,0,17.58,21980.0
...,...,...,...,...,...
902,10k,02_00_visit,2,19.78,24162.0
914,10k,02_00_visit,2,17.54,26479.0
936,10k,00_00_visit,1,17.43,20865.0
941,10k,00_00_visit,0,17.82,17606.0


You may also use the selector to filter on dates. Here we filter on dates of image collection in the fundus imaging dataset. A field can be qualified with its dataset by writing it as `dataset/field`.

In [None]:
# The `fundus/` prefix qualifies the field with its dataset; the date is
# given as a quoted string literal.
cs.select('fundus/collection_date > "2022-01-01"')

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,collection_date
participant_id,cohort,research_stage,array_index,Unnamed: 4_level_1
0,10k,00_00_visit,0,2022-11-16
1,10k,00_00_visit,0,2022-06-30
3,10k,00_00_visit,0,2022-04-26


In [None]:
#| hide
# Run nbdev's export step, which writes the `#| export` cells to the library module.
import nbdev; nbdev.nbdev_export()