# Prepare a collection for Merlin benchmarks 

In [1]:
import anndata as ad
import dask
import dask.array as da
import dask.dataframe as dd
from dask.distributed import Client, LocalCluster
from functools import reduce
from glob import glob
import json
import numpy as np
from os import makedirs, stat
from os.path import basename, exists, join
import pandas as pd
import pyarrow as pa
import shlex
from subprocess import check_call, check_output
from sys import stderr
from tqdm.notebook import tqdm

def err(msg):
    stderr.write(msg)
    stderr.write('\n')

In [2]:
with open('h5ads.txt', 'r') as f:
    h5ad_names = [ l.rstrip('\n') for l in f.readlines() ]
print('\n'.join(h5ad_names))

04a23820-ffa8-4be5-9f65-64db15631d1e.h5ad
0325478a-9b52-45b5-b40a-2e2ab0d72eb1.h5ad
0d7f4c06-a6bd-47d2-a42b-3a7196704f77.h5ad
090da8ea-46e8-40df-bffc-1f78e1538d27.h5ad
07b1d7c8-5c2e-42f7-9246-26f746cd6013.h5ad
0ee5ae70-c3f5-473f-bd1c-287f4690ffc5.h5ad
1cf24082-59de-4029-ac81-6e398768af3a.h5ad
1b767f95-d0a0-4a3d-b394-cc665d86c3dc.h5ad
19b21f40-db42-4a71-a0d6-913e83b17784.h5ad
18500fcd-9960-49cb-8a8e-7d868dc14efe.h5ad
182f6a56-7360-4924-a74e-1772e07b3031.h5ad
22658f4f-9268-41ad-8828-cc53f4baa9fa.h5ad
2190bd4d-3be0-4bf7-8ca8-8d6f71228936.h5ad
2185eb07-22e2-4209-b3c8-7111afc6aa90.h5ad
2a8ca8f3-5599-4cda-b973-3a2dfc3c1fe6.h5ad
35090826-f636-40c1-a3ef-4466beeab9f8.h5ad
31f04740-c712-4c4b-a3f8-55c0506b3034.h5ad
3a7f3ab4-a280-4b3b-b2c0-6dd05614a78c.h5ad
367b55f4-d543-49aa-90e8-4765fcb8c687.h5ad
470565f2-5afc-456a-b617-18e4496c04fd.h5ad
46ff9dc2-3d87-4b36-91aa-ffa8aa13c52e.h5ad
43b7e156-65b3-4a7b-8c7a-08528e4b21d0.h5ad
421e5f54-5de7-425f-b399-34ead0651ce1.h5ad
40e79234-65e8-45e9-b555-5c663154a1

In [3]:
bkt = 'cellxgene-census-public-us-west-2'
bkt_dir = 'cell-census/2023-12-15/h5ads'
prefix = f'{bkt}/{bkt_dir}'

def sync(start, end, dryrun=False):
    cmd = [
        'aws', 's3', 'sync',
        *(['--dryrun'] if dryrun else []),
        '--exclude', '*',
        *[
            arg
            for name in h5ad_names[start:end]
            for arg in [ '--include', name ]
        ],
        f's3://{prefix}',
        f's3/{prefix}',
    ]
    err(f'Running: {shlex.join(cmd)}')
    check_call(cmd)

In [4]:
%%time
sync(0, 10)

Running: aws s3 sync --exclude '*' --include 04a23820-ffa8-4be5-9f65-64db15631d1e.h5ad --include 0325478a-9b52-45b5-b40a-2e2ab0d72eb1.h5ad --include 0d7f4c06-a6bd-47d2-a42b-3a7196704f77.h5ad --include 090da8ea-46e8-40df-bffc-1f78e1538d27.h5ad --include 07b1d7c8-5c2e-42f7-9246-26f746cd6013.h5ad --include 0ee5ae70-c3f5-473f-bd1c-287f4690ffc5.h5ad --include 1cf24082-59de-4029-ac81-6e398768af3a.h5ad --include 1b767f95-d0a0-4a3d-b394-cc665d86c3dc.h5ad --include 19b21f40-db42-4a71-a0d6-913e83b17784.h5ad --include 18500fcd-9960-49cb-8a8e-7d868dc14efe.h5ad s3://cellxgene-census-public-us-west-2/cell-census/2023-12-15/h5ads s3/cellxgene-census-public-us-west-2/cell-census/2023-12-15/h5ads


CPU times: user 2.39 ms, sys: 236 µs, total: 2.63 ms
Wall time: 2.46 s


In [5]:
h5ad_paths = glob(f's3/{prefix}/*.h5ad')
h5ad_paths

['s3/cellxgene-census-public-us-west-2/cell-census/2023-12-15/h5ads/090da8ea-46e8-40df-bffc-1f78e1538d27.h5ad',
 's3/cellxgene-census-public-us-west-2/cell-census/2023-12-15/h5ads/0ee5ae70-c3f5-473f-bd1c-287f4690ffc5.h5ad',
 's3/cellxgene-census-public-us-west-2/cell-census/2023-12-15/h5ads/18500fcd-9960-49cb-8a8e-7d868dc14efe.h5ad',
 's3/cellxgene-census-public-us-west-2/cell-census/2023-12-15/h5ads/0d7f4c06-a6bd-47d2-a42b-3a7196704f77.h5ad',
 's3/cellxgene-census-public-us-west-2/cell-census/2023-12-15/h5ads/04a23820-ffa8-4be5-9f65-64db15631d1e.h5ad',
 's3/cellxgene-census-public-us-west-2/cell-census/2023-12-15/h5ads/07b1d7c8-5c2e-42f7-9246-26f746cd6013.h5ad',
 's3/cellxgene-census-public-us-west-2/cell-census/2023-12-15/h5ads/1b767f95-d0a0-4a3d-b394-cc665d86c3dc.h5ad',
 's3/cellxgene-census-public-us-west-2/cell-census/2023-12-15/h5ads/19b21f40-db42-4a71-a0d6-913e83b17784.h5ad',
 's3/cellxgene-census-public-us-west-2/cell-census/2023-12-15/h5ads/1cf24082-59de-4029-ac81-6e398768af3a

Adapted [from `lamindb.MappedCollection._make_join_vars`](https://github.com/laminlabs/lamindb/blob/0.69.9/lamindb/core/_mapped_collection.py#L198-L221):

In [6]:
%%time
var_list = []
for path in h5ad_paths:
    h5ad = ad.read_h5ad(path, 'r')
    var_list.append(h5ad.var.index)

print('lens: %s' % ', '.join(map(str, [ len(v) for v in var_list ])))

var_inner = reduce(pd.Index.intersection, var_list)
print(f'{len(var_inner)} joint vars')

var_inner = var_inner[:20000].tolist()
var_inner[:10]

lens: 59357, 59357, 59357, 59357, 59357, 59357, 59357, 59357, 59357, 59357
59357 joint vars
CPU times: user 3.77 s, sys: 310 ms, total: 4.08 s
Wall time: 4.84 s


['ENSG00000287383',
 'ENSG00000100097',
 'ENSG00000271850',
 'ENSG00000156049',
 'ENSG00000231373',
 'ENSG00000232560',
 'ENSG00000233359',
 'ENSG00000172197',
 'ENSG00000286271',
 'ENSG00000150656']

In [7]:
import dask
dask.config.set({'temporary_directory': 'scratch'})
cluster = LocalCluster(n_workers=1, threads_per_worker=8)
client = Client(cluster)
client

0,1
Connection method: Cluster object,Cluster type: distributed.LocalCluster
Dashboard: http://127.0.0.1:8787/status,

0,1
Dashboard: http://127.0.0.1:8787/status,Workers: 1
Total threads: 8,Total memory: 62.10 GiB
Status: running,Using processes: True

0,1
Comm: tcp://127.0.0.1:37379,Workers: 1
Dashboard: http://127.0.0.1:8787/status,Total threads: 8
Started: Just now,Total memory: 62.10 GiB

0,1
Comm: tcp://127.0.0.1:37645,Total threads: 8
Dashboard: http://127.0.0.1:34761/status,Memory: 62.10 GiB
Nanny: tcp://127.0.0.1:42891,
Local directory: /home/ec2-user/c/tiledb/arrayloader-benchmarks/scratch/dask-scratch-space/worker-yetne5mo,Local directory: /home/ec2-user/c/tiledb/arrayloader-benchmarks/scratch/dask-scratch-space/worker-yetne5mo


In [8]:
out_dir = 'var20k'
makedirs(out_dir, exist_ok=True)

@dask.delayed
def slice_20k_vars(h5ad_path):    
    out_path = join(out_dir, basename(h5ad_path))
    if exists(out_path):
        #stderr(f"Found {out_path}, skipping")
        return
    h5ad = ad.read_h5ad(h5ad_path, 'r')
    # todo: check why this error happens
    # TypeError: Indexing elements must be in increasing order
    # via selection of named varibales
    # access[:, var_inner].to_memory()
    # for the second artifact 0325478a-9b52-45b5-b40a-2e2ab0d72eb1.h5ad
    # ok, it seems it doens't work with non-increasing indices
    # todo: fix
    adata = h5ad[:, var_inner]#.to_memory()
    # idx_sort, reverse = np.unique(h5ad.var_names.get_indexer(var_inner), return_inverse=True)
    # adata = h5ad[:, idx_sort].to_memory()[:, reverse]
    #err(f"Writing {out_path}")
    adata.write_h5ad(out_path)

In [16]:
%%time
dask.compute([ slice_20k_vars(h5ad_path) for h5ad_path in h5ad_paths ])

CPU times: user 2.51 s, sys: 197 ms, total: 2.7 s
Wall time: 3min 13s


([None, None, None, None, None, None, None, None, None, None],)

In [19]:
stat(h5ad_path).st_size

532566682

In [23]:
h5ad_path = 'var20k/0d7f4c06-a6bd-47d2-a42b-3a7196704f77.h5ad'
a1 = ad.read_h5ad(h5ad_path, 'r')
a1

AnnData object with n_obs × n_vars = 9932 × 20000 backed at 'var20k/0d7f4c06-a6bd-47d2-a42b-3a7196704f77.h5ad'
    obs: 'roi', 'organism_ontology_term_id', 'disease_ontology_term_id', 'self_reported_ethnicity_ontology_term_id', 'assay_ontology_term_id', 'sex_ontology_term_id', 'development_stage_ontology_term_id', 'donor_id', 'suspension_type', 'dissection', 'fraction_mitochondrial', 'fraction_unspliced', 'cell_cycle_score', 'total_genes', 'total_UMIs', 'sample_id', 'supercluster_term', 'cluster_id', 'subcluster_id', 'cell_type_ontology_term_id', 'tissue_ontology_term_id', 'is_primary_data', 'cell_type', 'assay', 'disease', 'organism', 'sex', 'tissue', 'self_reported_ethnicity', 'development_stage'
    var: 'Biotype', 'Chromosome', 'End', 'Gene', 'Start', 'feature_is_filtered', 'feature_name', 'feature_reference', 'feature_biotype'
    uns: 'batch_condition', 'schema_version', 'title'
    obsm: 'X_UMAP', 'X_tSNE'

In [18]:
a0 = ad.read_h5ad('s3/cellxgene-census-public-us-west-2/cell-census/2023-12-15/h5ads/0d7f4c06-a6bd-47d2-a42b-3a7196704f77.h5ad', 'r')
a0

AnnData object with n_obs × n_vars = 9932 × 59357 backed at 's3/cellxgene-census-public-us-west-2/cell-census/2023-12-15/h5ads/0d7f4c06-a6bd-47d2-a42b-3a7196704f77.h5ad'
    obs: 'roi', 'organism_ontology_term_id', 'disease_ontology_term_id', 'self_reported_ethnicity_ontology_term_id', 'assay_ontology_term_id', 'sex_ontology_term_id', 'development_stage_ontology_term_id', 'donor_id', 'suspension_type', 'dissection', 'fraction_mitochondrial', 'fraction_unspliced', 'cell_cycle_score', 'total_genes', 'total_UMIs', 'sample_id', 'supercluster_term', 'cluster_id', 'subcluster_id', 'cell_type_ontology_term_id', 'tissue_ontology_term_id', 'is_primary_data', 'cell_type', 'assay', 'disease', 'organism', 'sex', 'tissue', 'self_reported_ethnicity', 'development_stage'
    var: 'Biotype', 'Chromosome', 'End', 'Gene', 'Start', 'feature_is_filtered', 'feature_name', 'feature_reference', 'feature_biotype'
    uns: 'batch_condition', 'schema_version', 'title'
    obsm: 'X_UMAP', 'X_tSNE'

In [11]:
# InstanceNotSetupError: To use lamindb, you need to connect to an instance.
# from lamindb import MappedCollection
# mc = MappedCollection(h5ad_paths, join="inner")
# mc

InstanceNotSetupError: To use lamindb, you need to connect to an instance.

Connect to an instance: `ln.connect()`. Init an instance: `ln.setup.init()`.

If you used the CLI to set up lamindb in a notebook, restart the Python session.


## Prepare parquet

In [20]:
cluster.close()

In [22]:
@dask.delayed
def read_X(path, idx):
    return ad.read_h5ad(path, backed="r").X[idx, :].toarray().astype("float32", copy=False)

In [23]:
# number of files per parquet file
CHUNK_SIZE = 32768
# row group size of parquet files
ROW_GROUP_SIZE = 1024

In [24]:
%%time
array_chunks = []
chunk_sizes = []

for path in h5ad_paths:
    name = basename(path)
    h5ad_path = join(out_dir, name)
    access = ad.read_h5ad(h5ad_path, 'r')
    n_obs = access.shape[0]
    idx_splits = np.array_split(np.arange(n_obs), np.ceil(n_obs / CHUNK_SIZE))
    for idx in idx_splits:
        array_chunks.append(read_X(h5ad_path, idx))
        chunk_sizes.append(len(idx))

CPU times: user 2.45 s, sys: 540 ms, total: 2.99 s
Wall time: 11.1 s


In [25]:
X = da.concatenate([
    da.from_delayed(chunk, (shape, len(var_inner)), dtype="float32") 
    for chunk, shape in zip(array_chunks, chunk_sizes)
]).rechunk((CHUNK_SIZE, -1))

In [26]:
X

Unnamed: 0,Array,Chunk
Bytes,61.39 GiB,2.44 GiB
Shape,"(823976, 20000)","(32768, 20000)"
Dask graph,26 chunks in 62 graph layers,26 chunks in 62 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 61.39 GiB 2.44 GiB Shape (823976, 20000) (32768, 20000) Dask graph 26 chunks in 62 graph layers Data type float32 numpy.ndarray",20000  823976,

Unnamed: 0,Array,Chunk
Bytes,61.39 GiB,2.44 GiB
Shape,"(823976, 20000)","(32768, 20000)"
Dask graph,26 chunks in 62 graph layers,26 chunks in 62 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


In [27]:
@dask.delayed
def convert_to_dataframe(x, start, end):
    return pd.DataFrame(
        {'X': [arr.squeeze().astype("float32", copy=False) for arr in np.vsplit(x, x.shape[0])]},
        index=pd.RangeIndex(start, end)
    )

In [28]:
start_index = [0] + list(np.cumsum(X.chunks[0]))[:-1]
end_index = list(np.cumsum(X.chunks[0]))
# calculate divisons for dask dataframe
divisions = [0] + list(np.cumsum(X.chunks[0]))
divisions[-1] = divisions[-1] - 1
ddf = dd.from_delayed(
    [
        convert_to_dataframe(arr, start, end) for arr, start, end in 
        zip(X.to_delayed().flatten().tolist(), start_index, end_index)
    ],
    divisions=divisions
)

In [29]:
ddf

Unnamed: 0_level_0,X
npartitions=26,Unnamed: 1_level_1
0,object
32768,...
...,...
819200,...
823975,...


In [30]:
%%time
ddf.to_parquet(
    "./merlin_benchmark", 
    engine='pyarrow',
    schema=pa.schema([('X', pa.list_(pa.float32()))]),
    write_metadata_file=True,
    row_group_size=ROW_GROUP_SIZE
)

2024-04-10 16:44:22,815 - distributed.spill - ERROR - Spill to disk failed; keeping data in memory
Traceback (most recent call last):
  File "/home/ec2-user/miniconda3/envs/arrayloader-benchmarks/lib/python3.11/site-packages/distributed/spill.py", line 124, in _handle_errors
    yield
  File "/home/ec2-user/miniconda3/envs/arrayloader-benchmarks/lib/python3.11/site-packages/distributed/spill.py", line 199, in evict
    _, _, weight = self.fast.evict()
                   ^^^^^^^^^^^^^^^^^
  File "/home/ec2-user/miniconda3/envs/arrayloader-benchmarks/lib/python3.11/site-packages/zict/common.py", line 127, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/ec2-user/miniconda3/envs/arrayloader-benchmarks/lib/python3.11/site-packages/zict/lru.py", line 227, in evict
    cb(key, value)
  File "/home/ec2-user/miniconda3/envs/arrayloader-benchmarks/lib/python3.11/site-packages/zict/buffer.py", line 139, in fast_to_slow
    self.slow[key] = value
    ~~~~

OSError: [Errno 28] No space left on device