In [1]:
import os

# Make the project's `src` directory the working directory so that the
# `utils.*` imports in the following cells resolve.
#
# The kernel's starting directory is remembered the first time this cell runs,
# and the target is always resolved relative to it. On a fresh kernel this is
# the same as a plain os.chdir("../../src"); on a re-run it lands in the same
# place instead of climbing two more levels from wherever we already are.
if "_NOTEBOOK_START_DIR" not in globals():
    _NOTEBOOK_START_DIR = os.getcwd()

os.chdir(os.path.join(_NOTEBOOK_START_DIR, "../../src"))

In [2]:
from tqdm import tqdm

import utils.data as data_utils
import utils.data.extraction, utils.data.splits, utils.data.dataset, utils.data.dataloader
import utils.dev.notebook as dev
from utils.paths import RAW_DATA_PATH, PREPROCESSED_DATA_PATH

#### Extracting relevant class variables from `.obs`

In [3]:
# Describe which `.obs` columns to pull from the preprocessed file. Both
# columns are requested with `as_codes=True`; the dtypes displayed below come
# back as int8, i.e. integer category codes rather than label strings.
extraction_cfg = dev.dict_to_namespace(
    {
        "path": PREPROCESSED_DATA_PATH
        / "GSE194122_openproblems_neurips2021_cite_BMMC_processed_processed.h5",
        "obs": {
            "columns": [
                {"name": "cell_type", "as_codes": True},
                {"name": "batch", "as_codes": True},
            ]
        },
    }
)

# Load the per-cell table; it is the input to every split built further down.
obs = data_utils.extraction.get_dataset_obs(extraction_cfg)

# Show a small sample and the column dtypes. (A bare `extraction_cfg`
# expression used to sit above the load; it was not the last expression of the
# cell, so it displayed nothing and has been removed.)
display(obs.head(6))
display(obs.dtypes)

Unnamed: 0,cell_type,batch
GCATTAGCATAAGCGG-1-s1d1,27,0
TACAGGTGTTAGAGTA-1-s1d1,15,0
AGGATCTAGGTCTACT-1-s1d1,27,0
GTAGAAAGTGACACAG-1-s1d1,19,0
TCCGAAAAGGATCATA-1-s1d1,35,0
CTCCCAATCCATTGGA-1-s1d1,29,0


cell_type    int8
batch        int8
dtype: object

#### Creating naive split mixing batches etc.

In [4]:
# Single train/validation split over all row positions of `obs`, with 20% of
# the rows requested for validation.
# NOTE(review): no seed is passed in this config, unlike the k-fold configs
# further down - check whether this split is reproducible across runs.
fraction_split_cfg = dev.dict_to_namespace(dict(val_fraction=0.2))
split = data_utils.splits.naive_mixing_fraction_split(
    cfg=fraction_split_cfg,
    max_idx=len(obs),
)
split

Split(train_indices=array([60195,  1089, 30812, ..., 16049, 18228, 22661]), val_indices=array([30112, 23109,  4035, ..., 78416, 50356, 85997]))

In [5]:
# Unpack the naive fraction split into its train / validation index arrays.
# NOTE(review): these names are spelled `*_indxs`; a later cell binds
# `train_idxs` (different spelling) and rebinds `val_indxs`, so it is easy to
# pick up indices from the wrong split.
train_indxs, val_indxs = split

In [6]:
# K-fold variant of the naive split: 4 folds over all row positions of `obs`,
# with a fixed `random_state`.
naive_kfold_split_cfg = dev.dict_to_namespace(dict(n_splits=4, random_state=2))
kf_split = data_utils.splits.naive_mixing_k_fold_split(
    cfg=naive_kfold_split_cfg,
    max_idx=len(obs),
)

# The splitter is lazy (it displays as a generator object): show it first,
# then pull its first fold and show that.
display(kf_split)
first_naive_fold = next(iter(kf_split))
display(first_naive_fold)

<generator object naive_mixing_k_fold_split at 0x7fc23b01cb30>

Split(train_indices=array([    1,     2,     3, ..., 90258, 90259, 90260]), val_indices=array([    0,     4,     9, ..., 90254, 90255, 90256]))

#### Composite split separating samples from different batches etc.

In [7]:
# Two-fold split driven by the `cell_type` and `batch` columns of `obs`
# (passed as `grid_variables`), with a fixed `random_state`.
composite_kfold_split_cfg = dev.dict_to_namespace(
    dict(
        grid_variables=["cell_type", "batch"],
        n_splits=2,
        random_state=0,
    )
)
kfsplit = data_utils.splits.composite_k_fold_split(
    cfg=composite_kfold_split_cfg,
    df=obs,
)

In [8]:
# Take the first fold of the composite k-fold split.
# NOTE(review): if `kfsplit` is a one-shot iterator, re-running this cell
# yields the *next* fold rather than the same one - confirm.
# NOTE(review): this rebinds `val_indxs` from the earlier naive split while
# introducing `train_idxs` (the earlier cell used `train_indxs`).
train_idxs, val_indxs = next(iter(kfsplit))

In [9]:
# Dataset over the raw .h5ad file, restricted to the training indices of the
# composite split. Each listed `.obs` column is exposed on every sample under
# `new_name`, with category remapping switched off.
train_dataset_cfg = dev.dict_to_namespace(
    {
        "path": RAW_DATA_PATH
        / "GSE194122_openproblems_neurips2021_cite_BMMC_processed.h5ad",
        # presumably the length of one data row - TODO confirm against the dataset class
        "rowsize": 14087,
        "obs": {
            "columns": [
                {"org_name": org_name, "new_name": new_name, "remap_categories": False}
                for org_name, new_name in (
                    ("cell_type", "cell_type"),
                    ("batch", "batch"),
                    ("Site", "site"),
                )
            ]
        },
    }
)
train_dataset = data_utils.dataset.hdf5SparseDataset(
    cfg=train_dataset_cfg,
    dataset_idxs=train_idxs,
)

In [10]:
# Pull the first item from the dataset to inspect the dict it yields
# (a `data` tensor plus one entry per configured `.obs` column).
next(iter(train_dataset))

{'data': tensor([[0.6763, 0.0000, 0.0000,  ..., 1.5909, 1.6189, 1.6097]]),
 'cell_type': tensor([15]),
 'batch': tensor([0]),
 'site': tensor([0])}

In [11]:
# Inspect the row indices held by the dataset. This reads a private
# attribute, so treat it as a debugging aid rather than a supported API.
train_dataset._dataset_idxs

array([    2,     4,     7, ..., 90254, 90257, 90258])

In [12]:
# Wrap the training dataset in a loader: batches of 16, one worker process.
train_dataloader_cfg = dev.dict_to_namespace(
    {
        "dataloader": dict(batch_size=16, num_workers=1),
    }
)
train_dataloader = data_utils.dataloader.get_hdf5SparseDataloader(
    train_dataloader_cfg,
    train_dataset,
)

In [13]:
# Fetch the first batch from the loader so its structure can be inspected below.
batch = next(iter(train_dataloader))

In [14]:
# Display the batch: same keys as a single sample, stacked along the first dimension.
batch

{'data': tensor([[0.0000, 0.0000, 0.0000,  ..., 0.7793, 3.2602, 0.5459],
         [0.0000, 0.0000, 0.0000,  ..., 0.3247, 0.5694, 0.5694],
         [0.0000, 0.0000, 0.0000,  ..., 0.6212, 0.4983, 0.4983],
         ...,
         [0.0000, 0.0000, 0.0000,  ..., 0.5289, 0.9049, 0.9947],
         [0.0000, 0.0000, 0.0000,  ..., 0.2788, 0.8268, 0.9164],
         [0.0000, 0.0000, 0.0000,  ..., 0.8869, 1.1819, 1.2177]]),
 'cell_type': tensor([35, 44, 43,  3, 15, 15,  3,  5,  3, 18, 15, 15, 27, 10, 38, 12]),
 'batch': tensor([ 2,  3,  4,  5,  6,  6,  6,  6,  6,  7,  8,  8,  9, 10, 11, 11]),
 'site': tensor([0, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3])}

In [16]:
# Run one complete pass over the training loader, discarding every batch.
# This checks that all batches can be produced, and the progress bar reports
# how long a full pass takes. `tqdm` is imported in the imports cell at the
# top of the notebook.
for _ in tqdm(train_dataloader):
    pass

100%|██████████| 2821/2821 [01:58<00:00, 23.79it/s]


#### Composite parameterised split - another way of defining a parameterised split, without randomness

In [17]:
# Re-inspect `obs` before building the next split.
# NOTE(review): the frame now shows an extra `index` column that was absent
# when `obs` was first displayed, so something above modified it in place -
# presumably one of the split helpers; confirm.
obs

Unnamed: 0,cell_type,batch,index
GCATTAGCATAAGCGG-1-s1d1,27,0,0
TACAGGTGTTAGAGTA-1-s1d1,15,0,1
AGGATCTAGGTCTACT-1-s1d1,27,0,2
GTAGAAAGTGACACAG-1-s1d1,19,0,3
TCCGAAAAGGATCATA-1-s1d1,35,0,4
...,...,...,...
GAATCACCACGGAAGT-1-s4d9,22,11,90256
GCTGGGTGTACGGATG-1-s4d9,13,11,90257
TCGAAGTGTGACAGGT-1-s4d9,37,11,90258
GCAGGCTGTTGCATAC-1-s4d9,5,11,90259


In [18]:
# Split defined purely by column values, with no random component in the
# config: one filter entry per column, each listing the values of interest.
# presumably rows matching these filters form the validation set - TODO confirm
# the exact combination rule in `subset_parameterised_composite_split`.
subset_param_composite_split_cfg = dev.dict_to_namespace(
    {
        "val_filter_values": [
            {"name": column, "filter_values": values}
            for column, values in (
                ("batch", [0, 2]),
                ("cell_type", [6, 11, 12]),
            )
        ]
    }
)

# NOTE(review): this rebinds `split`, replacing the naive fraction split
# created earlier in the notebook.
split = data_utils.splits.subset_parameterised_composite_split(
    cfg=subset_param_composite_split_cfg,
    df=obs,
)
split

Split(train_indices=array([    0,     1,     2, ..., 90258, 90259, 90260]), val_indices=array([   59,    63,    70,    84,    98,   166,   241,   320,   326,
         356,   404,   408,   420,   447,   541,   547,   580,   600,
         661,   735,   759,   796,   888,   920,   936,   941,   943,
         955,   965,  1017,  1035,  1167,  1203,  1206,  1231,  1282,
        1283,  1291,  1384,  1429,  1439,  1542,  1568,  1593,  1595,
        1659,  1689,  1710,  1734,  1784,  1799,  1822,  1845,  1914,
        1915,  1964,  2002,  2037,  2043,  2081,  2213,  2280,  2300,
        2335,  2337,  2369,  2385,  2410,  2433,  2562,  2639,  2644,
        2722,  2782,  2837,  2853,  2861,  2866,  3025,  3032,  3036,
        3039,  3057,  3073,  3084,  3217,  3220,  3234,  3295,  3317,
        3340,  3362,  3369,  3375,  3384,  3393,  3422,  3453,  3486,
        3514,  3578,  3614,  3623,  3635,  3668,  3669,  3708,  3768,
        3775,  3871,  3872,  3884,  3923,  3950,  3977,  3980,  4007,
  