In [36]:
import torch
from torch.utils.data import SequentialSampler

import numpy as np
from anndata import AnnData
from anndata._torch import AnnDataSet, AnnDataLoader, split_into_datasets

In [37]:
# Small synthetic AnnData object used throughout this notebook.
n_obs = 101
n_vars = 10

# Fix the global NumPy seed so the toy data (and every output below) is reproducible.
np.random.seed(0)

obsm = {'X_repr': np.random.randn(n_obs, 5)}        # (n_obs, 5) matrix stored in adata.obsm
layers = {'Lay': np.random.randn(n_obs, n_vars)}    # same shape as X, stored in adata.layers
obs = {'label': np.random.binomial(1, 0.5, n_obs)}  # one 0/1 label per observation

# Size X from n_obs/n_vars (not a hard-coded 101) so it always agrees with obsm/layers/obs.
adata = AnnData(X=np.random.randn(n_obs, n_vars), obsm=obsm, layers=layers, obs=obs)

### AnnData to Dataset

In [38]:
# Default: no obsm/layer given — presumably features come from adata.X (confirm in AnnDataSet).
# label_key names the adata.obs column to use as labels.
dataset = AnnDataSet(adata, label_key='label')

In [39]:
# or: select the adata.obsm['X_repr'] matrix via the obsm= argument.
dataset = AnnDataSet(adata, obsm='X_repr', label_key='label')

In [40]:
# or: select the adata.layers['Lay'] matrix via the layer= argument.
# Note: this overwrites `dataset` again, so the cells below that use `dataset`
# get this layer-based version.
dataset = AnnDataSet(adata, layer='Lay', label_key='label')

In [41]:
# You can pass any argument of torch.utils.data.DataLoader.
# If batch_size > 1, a custom default sampler is used so that each step yields a proper
# batch, rather than separately fetched indices concatenated together.
dloader = AnnDataLoader(dataset, batch_size=32, shuffle=True)

### AnnDataLoader directly from AnnData

In [42]:
# You can pass any argument of torch.utils.data.DataLoader and of AnnDataSet.
# Here layer= and label_key= are AnnDataSet arguments (presumably forwarded to build
# the dataset internally); batch_size= and shuffle= are DataLoader arguments.
dloader = AnnDataLoader(adata, batch_size=32, shuffle=True, layer='Lay', label_key='label')

In [43]:
# Iterate over the whole loader once as a smoke test that batching works end to end.
for batch in dloader:
    pass

# `batch` now holds the last batch yielded; display it so the batch structure is visible.
batch

In [44]:
# Pass an explicit sampler to control iteration order, e.g. sequential instead of shuffled.
dloader = AnnDataLoader(dataset, batch_size=32, sampler=SequentialSampler(dataset))
# or, building the loader directly from AnnData. SequentialSampler only needs the number
# of samples, so size it from adata itself rather than borrowing `dataset` from another
# cell (which silently breaks if that cell changes or is run out of order).
dloader = AnnDataLoader(adata, batch_size=32, layer='Lay', label_key='label',
                        sampler=SequentialSampler(range(adata.n_obs)))

### Split adata into datasets

In [45]:
# Mark every observation as 'train', then the last 40 as 'test'.
# The 'split' column is presumably what split_into_datasets reads below (confirm its default key).
adata.obs['split'] = 'train'
# Select the last 40 rows by *label* via obs_names. A positional slice such as
# `.loc[-40:]` on the string obs_names index is deprecated in pandas 1.x and
# raises TypeError in pandas 2.x.
adata.obs.loc[adata.obs_names[-40:], 'split'] = 'test'
assert (adata.obs['split'] == 'test').sum() == min(40, adata.n_obs)

In [48]:
# You can pass any argument of AnnDataSet (here obsm= and label_key=).
# The result is a dict of AnnDataSet objects keyed by split name ('train' and 'test' in the output below).
datasets = split_into_datasets(adata, obsm='X_repr', label_key='label')

In [53]:
# Display the split-name -> AnnDataSet mapping.
datasets

{'train': <anndata._torch.interface.AnnDataSet at 0x57020b8>,
 'test': <anndata._torch.interface.AnnDataSet at 0x5577470>}

In [None]:
# Shuffled training loader over the 'train' split ('tarin' was a typo and raised KeyError).
train_loader = AnnDataLoader(datasets['train'], batch_size=32, shuffle=True)