In [1]:
import gc
from os.path import exists
from shutil import rmtree
from sys import stderr
import time
from tqdm import tqdm

def err(*args):
    """Write ``args`` to standard error, space-separated and newline-terminated."""
    message = " ".join(map(str, args))
    print(message, file=stderr)

import cellxgene_census
import cellxgene_census.experimental.ml as census_ml
import somacore as soma
from somacore import AxisQuery, ExperimentAxisQuery, Measurement
from tiledbsoma import Experiment

In [2]:
# Pin the Census release so that re-running the notebook reads the same data.
# Without `census_version`, open_soma() follows the moving "stable" tag and
# warns that results may change between runs.
census = cellxgene_census.open_soma(census_version="2023-12-15")

# CELLxGENE collection whose datasets are benchmarked. The id is hard-coded
# here; it could alternatively be looked up with:
#   reference = ln.Collection.filter(uid="1gsdckxvOvIjQgeDVS1F").one().reference
reference = '283d65eb-dd53-496d-adb7-7570c7caa443'
query_collection_id = f"collection_id == '{reference}'"

# Ids of every dataset in the Census that belongs to that collection.
datasets = (
    census["census_info"]["datasets"]
    .read(column_names=["dataset_id"], value_filter=query_collection_id)
    .concat().to_pandas()
)["dataset_id"].tolist()

datasets[:10]

The "stable" release is currently 2023-12-15. Specify 'census_version="2023-12-15"' in future calls to open_soma() to ensure data consistency.


['8e10f1c4-8e98-41e5-b65f-8cd89a887122',
 'b165f033-9dec-468a-9248-802fc6902a74',
 'ff7d15fa-f4b6-4a0e-992e-fd0c9d088ded',
 'fe1a73ab-a203-45fd-84e9-0f7fd19efcbd',
 'fbf173f9-f809-4d84-9b65-ae205d35b523',
 'fa554686-fc07-44dd-b2de-b726d82d26ec',
 'f9034091-2e8f-4ac6-9874-e7b7eb566824',
 'f8dda921-5fb4-4c94-a654-c6fc346bfd6d',
 'f7d003d4-40d5-4de8-858c-a9a8b48fcc67',
 'f6d9f2ad-5ec7-4d53-b7f0-ceb0e7bcd181']

In [3]:
# Number of dataset ids returned for the collection.
len(datasets)

138

In [4]:
def subset_census(query: ExperimentAxisQuery, output_base_dir: str) -> None:
    """
    Write the data selected by ``query`` to a new SOMA experiment on disk.

    A new experiment is created at ``output_base_dir`` with this layout:

    * ``obs``          -- the query's obs rows, restricted to rows that have
                          at least one entry in the ``raw`` X layer
    * ``ms/RNA/var``   -- the query's var rows
    * ``ms/RNA/X/raw`` -- the query's ``raw`` X layer as a sparse array

    The selected obs, var and X data are each concatenated into in-memory
    Arrow tables before being written.

    Args:
        query: axis query selecting the obs/var rows to copy.
        output_base_dir: URI at which the new experiment is created.

    Returns:
        None. The result is the experiment written to ``output_base_dir``.
    """
    # pyarrow is not imported by the notebook's import cell; import it here so
    # the function does not fail with a NameError on `pa`.
    import pyarrow as pa

    with Experiment.create(uri=output_base_dir) as exp_subset:
        x_data = query.X(layer_name="raw").tables().concat()

        obs_data = query.obs().concat()
        # remove obs rows with no X data
        x_soma_dim_0_unique = pa.Table.from_arrays([x_data["soma_dim_0"].unique()], names=["soma_dim_0"])
        obs_data = obs_data.join(x_soma_dim_0_unique, keys="soma_joinid", right_keys="soma_dim_0", join_type="inner")
        # Create obs through the experiment itself (same pattern as `var`
        # below). This registers it under the "obs" key and avoids calling
        # `create` on the abstract somacore DataFrame class.
        obs = exp_subset.add_new_dataframe("obs", schema=obs_data.schema)
        obs.write(obs_data)

        ms = exp_subset.add_new_collection("ms")
        # NOTE(review): `Measurement` comes from somacore in this notebook;
        # confirm tiledbsoma accepts it as the collection kind.
        rna = ms.add_new_collection("RNA", Measurement)

        var_data = query.var().concat()
        var = rna.add_new_dataframe("var", schema=var_data.schema)
        var.write(var_data)

        # Reuse the value type of the source layer for the new sparse array.
        x_type = x_data.schema.field("soma_data").type
        rna.add_new_collection("X")
        rna["X"].add_new_sparse_ndarray("raw", type=x_type, shape=(None, None))
        rna.X["raw"].write(x_data)

In [5]:
# Human ("homo_sapiens") experiment of the opened Census; the queries and the
# datapipe in the following cells read from this handle.
experiment = census["census_data"]["homo_sapiens"]
experiment

<Experiment 's3://cellxgene-census-public-us-west-2/cell-census/2023-12-15/soma/census_data/homo_sapiens' (open for 'r') (2 items)
    'ms': 's3://cellxgene-census-public-us-west-2/cell-census/2023-12-15/soma/census_data/homo_sapiens/ms' (unopened)
    'obs': 's3://cellxgene-census-public-us-west-2/cell-census/2023-12-15/soma/census_data/homo_sapiens/obs' (unopened)>

In [6]:
def download_datasets(n):
    """
    Copy the first ``n`` datasets of the collection to a local SOMA experiment.

    Selects the cells whose ``dataset_id`` is one of ``datasets[:n]`` and the
    var rows addressed by ``slice(20000-1)``, then writes them with
    ``subset_census`` to ``/mnt/nvme/census-benchmark_:{n}``. An existing
    directory at that path is deleted first.

    Args:
        n: number of leading entries of the global ``datasets`` list to copy.

    Returns:
        None.
    """
    query_datasets = f'dataset_id in {datasets[:n]}'

    output_base_dir = f'/mnt/nvme/census-benchmark_:{n}'
    if exists(output_base_dir):
        err(f"Removing {output_base_dir}")
        rmtree(output_base_dir)

    # Use the query as a context manager so its resources are released even
    # if writing the subset fails.
    with experiment.axis_query(
        "RNA",
        obs_query=AxisQuery(value_filter=query_datasets),
        # NOTE(review): SOMA coordinate slices are documented as inclusive, so
        # this presumably selects var joinids 0..19999 -- confirm.
        var_query=AxisQuery(coords=(slice(20000-1),)),
    ) as query:
        subset_census(query, output_base_dir)

In [None]:
%%time
# Write the first dataset of the collection to local disk; any existing copy
# at the target path is removed first.
download_datasets(1)

Removing /mnt/nvme/census-benchmark_:1


In [None]:
# Batch size passed to the datapipe; benchmark() also uses it to convert
# batch counts into sample counts.
BATCH_SIZE = 1024

In [None]:
# `query_datasets` only exists as a local inside download_datasets(), so it
# has to be defined here. Benchmark the same selection that was written to
# disk by download_datasets(1); raise N_BENCHMARK_DATASETS to stream more.
N_BENCHMARK_DATASETS = 1
query_datasets = f'dataset_id in {datasets[:N_BENCHMARK_DATASETS]}'

experiment_datapipe = census_ml.ExperimentDataPipe(
    experiment,
    measurement_name="RNA",
    X_name="raw",
    obs_query=soma.AxisQuery(value_filter=query_datasets),
    var_query=soma.AxisQuery(coords=(slice(20000-1),)),
    batch_size=BATCH_SIZE,
    shuffle=True,
    soma_chunk_size=10000,
)

loader = census_ml.experiment_dataloader(experiment_datapipe)

In [None]:
# Number of passes over the loader to benchmark.
n_epochs = 1
# Alternative setting for a longer, multi-epoch run:
# n_epochs = 5

In [None]:
def benchmark(loader, n_samples=None, batch_size=None):
    """
    Measure how fast ``loader`` yields batches.

    The first batch is fetched before the clock starts, because it includes
    the loader's setup time. Each following batch is unpacked (``batch["x"]``
    for dict batches, ``batch[0]`` otherwise) and, if the data exposes
    ``is_cuda`` and is not yet on a CUDA device, moved there with ``.cuda()``.

    Throughput is computed from the batches that were actually timed, each
    counted as ``batch_size`` samples.

    Args:
        loader: iterable of batches.
        n_samples: if given, at most ``n_samples // batch_size`` batches are
            timed; if None, the loader is consumed until it is exhausted.
        batch_size: number of samples per batch. Defaults to the module-level
            ``BATCH_SIZE``.

    Returns:
        Tuple ``(samples_per_sec, time_per_sample, batch_times)`` where
        ``time_per_sample`` is in microseconds and ``batch_times`` holds the
        elapsed seconds of every timed batch. If no batch was timed the result
        is ``(0.0, nan, [])``.
    """
    if batch_size is None:
        batch_size = BATCH_SIZE

    loader_iter = iter(loader)
    # exclude first batch from benchmark as this includes the setup time
    next(loader_iter)

    num_iter = n_samples // batch_size if n_samples is not None else None

    # Expected number of timed batches, used only for the progress bar.
    if num_iter is not None:
        total = num_iter
    else:
        try:
            # The warm-up batch has already been consumed.
            total = max(len(loader) - 1, 0)
        except (TypeError, NotImplementedError):
            # Loader without a length: tqdm falls back to a plain counter.
            total = None

    batch_times = []
    # perf_counter is monotonic, so the measurement cannot be distorted by
    # system clock adjustments.
    start_time = time.perf_counter()
    batch_time = start_time

    for i, batch in tqdm(enumerate(loader_iter), total=total):
        # Stop before doing any work on a batch that is not part of the run.
        if num_iter is not None and i == num_iter:
            break

        X = batch["x"] if isinstance(batch, dict) else batch[0]
        # for pytorch DataLoader
        # Merlin sends to cuda by default
        if hasattr(X, "is_cuda") and not X.is_cuda:
            X = X.cuda()

        if i % 10 == 0:
            gc.collect()

        now = time.perf_counter()
        batch_elapsed = now - batch_time
        print(f'Batch {i:04d} took {batch_elapsed:.2f}')
        batch_times.append(batch_elapsed)
        batch_time = time.perf_counter()

    execution_time = time.perf_counter() - start_time
    gc.collect()

    # Count only the batches that were really processed; the planned total can
    # be larger than what the loader delivered after the warm-up batch.
    n_processed = len(batch_times) * batch_size
    if n_processed > 0 and execution_time > 0:
        time_per_sample = (1e6 * execution_time) / n_processed
        samples_per_sec = n_processed / execution_time
    else:
        time_per_sample = float("nan")
        samples_per_sec = 0.0
    print(f'time per sample: {time_per_sample:.2f} μs')
    print(f'samples per sec: {samples_per_sec:.2f} samples/sec')

    return samples_per_sec, time_per_sample, batch_times

In [None]:
%%time
# Shape of the datapipe; its first element is later passed to benchmark() as
# n_samples.
experiment_datapipe.shape

In [None]:
print("cellxgene_census")
# `results` is not created in any of the cells shown above. Create it on first
# use, and keep entries that already exist when the cell is re-run.
if "results" not in globals():
    results = {}
census_results = results.setdefault("cellxgene_census", {})

for epoch in range(n_epochs):
    samples_per_sec, time_per_sample, batch_times = benchmark(loader, n_samples=experiment_datapipe.shape[0])
    epoch_results = census_results.setdefault(f"epoch_{epoch}", {})
    epoch_results["time_per_sample"] = time_per_sample
    epoch_results["samples_per_sec"] = samples_per_sec
    epoch_results["batch_times"] = batch_times

In [None]:
# Release the Census handle opened at the top of the notebook. Re-run the
# cell that calls open_soma() before executing any earlier cell again.
census.close()