# Prepare a collection for Merlin benchmarks 

In [1]:
import anndata as ad
import dask
import dask.array as da
import dask.dataframe as dd
from dask.distributed import Client, LocalCluster
from functools import reduce
from glob import glob
import json
import numpy as np
from os import makedirs, stat
from os.path import basename, exists, join
import pandas as pd
import pyarrow as pa
import shlex
from subprocess import check_call, check_output
from sys import stderr
from tqdm.notebook import tqdm

def err(msg):
    stderr.write(msg)
    stderr.write('\n')

In [2]:
with open('h5ads.txt', 'r') as f:
    h5ad_names = [ l.rstrip('\n') for l in f.readlines() ]
print('\n'.join(h5ad_names))

04a23820-ffa8-4be5-9f65-64db15631d1e.h5ad
0325478a-9b52-45b5-b40a-2e2ab0d72eb1.h5ad
0d7f4c06-a6bd-47d2-a42b-3a7196704f77.h5ad
090da8ea-46e8-40df-bffc-1f78e1538d27.h5ad
07b1d7c8-5c2e-42f7-9246-26f746cd6013.h5ad
0ee5ae70-c3f5-473f-bd1c-287f4690ffc5.h5ad
1cf24082-59de-4029-ac81-6e398768af3a.h5ad
1b767f95-d0a0-4a3d-b394-cc665d86c3dc.h5ad
19b21f40-db42-4a71-a0d6-913e83b17784.h5ad
18500fcd-9960-49cb-8a8e-7d868dc14efe.h5ad
182f6a56-7360-4924-a74e-1772e07b3031.h5ad
22658f4f-9268-41ad-8828-cc53f4baa9fa.h5ad
2190bd4d-3be0-4bf7-8ca8-8d6f71228936.h5ad
2185eb07-22e2-4209-b3c8-7111afc6aa90.h5ad
2a8ca8f3-5599-4cda-b973-3a2dfc3c1fe6.h5ad
35090826-f636-40c1-a3ef-4466beeab9f8.h5ad
31f04740-c712-4c4b-a3f8-55c0506b3034.h5ad
3a7f3ab4-a280-4b3b-b2c0-6dd05614a78c.h5ad
367b55f4-d543-49aa-90e8-4765fcb8c687.h5ad
470565f2-5afc-456a-b617-18e4496c04fd.h5ad
46ff9dc2-3d87-4b36-91aa-ffa8aa13c52e.h5ad
43b7e156-65b3-4a7b-8c7a-08528e4b21d0.h5ad
421e5f54-5de7-425f-b399-34ead0651ce1.h5ad
40e79234-65e8-45e9-b555-5c663154a1

In [3]:
bkt = 'cellxgene-census-public-us-west-2'
bkt_dir = 'cell-census/2023-12-15/h5ads'
prefix = f'{bkt}/{bkt_dir}'

def sync(start, end, dryrun=False):
    cmd = [
        'aws', 's3', 'sync',
        *(['--dryrun'] if dryrun else []),
        '--exclude', '*',
        *[
            arg
            for name in h5ad_names[start:end]
            for arg in [ '--include', name ]
        ],
        f's3://{prefix}',
        f's3/{prefix}',
    ]
    err(f'Running: {shlex.join(cmd)}')
    check_call(cmd)

In [4]:
%%time
sync(0, 10)

Running: aws s3 sync --exclude '*' --include 04a23820-ffa8-4be5-9f65-64db15631d1e.h5ad --include 0325478a-9b52-45b5-b40a-2e2ab0d72eb1.h5ad --include 0d7f4c06-a6bd-47d2-a42b-3a7196704f77.h5ad --include 090da8ea-46e8-40df-bffc-1f78e1538d27.h5ad --include 07b1d7c8-5c2e-42f7-9246-26f746cd6013.h5ad --include 0ee5ae70-c3f5-473f-bd1c-287f4690ffc5.h5ad --include 1cf24082-59de-4029-ac81-6e398768af3a.h5ad --include 1b767f95-d0a0-4a3d-b394-cc665d86c3dc.h5ad --include 19b21f40-db42-4a71-a0d6-913e83b17784.h5ad --include 18500fcd-9960-49cb-8a8e-7d868dc14efe.h5ad s3://cellxgene-census-public-us-west-2/cell-census/2023-12-15/h5ads s3/cellxgene-census-public-us-west-2/cell-census/2023-12-15/h5ads


CPU times: user 1.95 ms, sys: 236 µs, total: 2.18 ms
Wall time: 2.27 s


In [5]:
h5ad_paths = glob(f's3/{prefix}/*.h5ad')
h5ad_paths

['s3/cellxgene-census-public-us-west-2/cell-census/2023-12-15/h5ads/090da8ea-46e8-40df-bffc-1f78e1538d27.h5ad',
 's3/cellxgene-census-public-us-west-2/cell-census/2023-12-15/h5ads/0ee5ae70-c3f5-473f-bd1c-287f4690ffc5.h5ad',
 's3/cellxgene-census-public-us-west-2/cell-census/2023-12-15/h5ads/18500fcd-9960-49cb-8a8e-7d868dc14efe.h5ad',
 's3/cellxgene-census-public-us-west-2/cell-census/2023-12-15/h5ads/0d7f4c06-a6bd-47d2-a42b-3a7196704f77.h5ad',
 's3/cellxgene-census-public-us-west-2/cell-census/2023-12-15/h5ads/04a23820-ffa8-4be5-9f65-64db15631d1e.h5ad',
 's3/cellxgene-census-public-us-west-2/cell-census/2023-12-15/h5ads/07b1d7c8-5c2e-42f7-9246-26f746cd6013.h5ad',
 's3/cellxgene-census-public-us-west-2/cell-census/2023-12-15/h5ads/1b767f95-d0a0-4a3d-b394-cc665d86c3dc.h5ad',
 's3/cellxgene-census-public-us-west-2/cell-census/2023-12-15/h5ads/19b21f40-db42-4a71-a0d6-913e83b17784.h5ad',
 's3/cellxgene-census-public-us-west-2/cell-census/2023-12-15/h5ads/1cf24082-59de-4029-ac81-6e398768af3a

Adapted [from `lamindb.MappedCollection._make_join_vars`](https://github.com/laminlabs/lamindb/blob/0.69.9/lamindb/core/_mapped_collection.py#L198-L221):

In [6]:
%%time
var_list = []
for path in h5ad_paths:
    h5ad = ad.read_h5ad(path, 'r')
    var_list.append(h5ad.var.index)

print('lens: %s' % ', '.join(map(str, [ len(v) for v in var_list ])))

var_inner = reduce(pd.Index.intersection, var_list)
print(f'{len(var_inner)} joint vars')

var_inner = var_inner[:20000].tolist()
var_inner[:10]

lens: 59357, 59357, 59357, 59357, 59357, 59357, 59357, 59357, 59357, 59357
59357 joint vars
CPU times: user 3.49 s, sys: 377 ms, total: 3.86 s
Wall time: 4.66 s


['ENSG00000287383',
 'ENSG00000100097',
 'ENSG00000271850',
 'ENSG00000156049',
 'ENSG00000231373',
 'ENSG00000232560',
 'ENSG00000233359',
 'ENSG00000172197',
 'ENSG00000286271',
 'ENSG00000150656']

In [7]:
dask.config.set({'temporary_directory': 'scratch'})
cluster = LocalCluster()
client = Client(cluster)
client

0,1
Connection method: Cluster object,Cluster type: distributed.LocalCluster
Dashboard: http://127.0.0.1:8787/status,

0,1
Dashboard: http://127.0.0.1:8787/status,Workers: 4
Total threads: 8,Total memory: 62.10 GiB
Status: running,Using processes: True

0,1
Comm: tcp://127.0.0.1:42251,Workers: 4
Dashboard: http://127.0.0.1:8787/status,Total threads: 8
Started: Just now,Total memory: 62.10 GiB

0,1
Comm: tcp://127.0.0.1:36715,Total threads: 2
Dashboard: http://127.0.0.1:35907/status,Memory: 15.52 GiB
Nanny: tcp://127.0.0.1:39515,
Local directory: /home/ec2-user/c/tiledb/arrayloader-benchmarks/scratch/dask-scratch-space/worker-rkfkhf3w,Local directory: /home/ec2-user/c/tiledb/arrayloader-benchmarks/scratch/dask-scratch-space/worker-rkfkhf3w

0,1
Comm: tcp://127.0.0.1:40631,Total threads: 2
Dashboard: http://127.0.0.1:44035/status,Memory: 15.52 GiB
Nanny: tcp://127.0.0.1:46121,
Local directory: /home/ec2-user/c/tiledb/arrayloader-benchmarks/scratch/dask-scratch-space/worker-2t4xtl5x,Local directory: /home/ec2-user/c/tiledb/arrayloader-benchmarks/scratch/dask-scratch-space/worker-2t4xtl5x

0,1
Comm: tcp://127.0.0.1:43137,Total threads: 2
Dashboard: http://127.0.0.1:38873/status,Memory: 15.52 GiB
Nanny: tcp://127.0.0.1:37459,
Local directory: /home/ec2-user/c/tiledb/arrayloader-benchmarks/scratch/dask-scratch-space/worker-mte63x66,Local directory: /home/ec2-user/c/tiledb/arrayloader-benchmarks/scratch/dask-scratch-space/worker-mte63x66

0,1
Comm: tcp://127.0.0.1:42875,Total threads: 2
Dashboard: http://127.0.0.1:40037/status,Memory: 15.52 GiB
Nanny: tcp://127.0.0.1:42607,
Local directory: /home/ec2-user/c/tiledb/arrayloader-benchmarks/scratch/dask-scratch-space/worker-880fo0pn,Local directory: /home/ec2-user/c/tiledb/arrayloader-benchmarks/scratch/dask-scratch-space/worker-880fo0pn


2024-04-11 13:49:04,519 - distributed.scheduler - ERROR - Task ('fromdelayed-reset_index-operation-toparquetdata-a29d2a4e1199659c5d173574e75ce747', 9) has 61.37 GiB worth of input dependencies, but worker tcp://127.0.0.1:42875 has memory_limit set to 15.52 GiB.


In [8]:
out_dir = 'var20k'
makedirs(out_dir, exist_ok=True)

@dask.delayed
def slice_20k_vars(h5ad_path):    
    out_path = join(out_dir, basename(h5ad_path))
    if exists(out_path):
        #stderr(f"Found {out_path}, skipping")
        return
    h5ad = ad.read_h5ad(h5ad_path, 'r')
    # todo: check why this error happens
    # TypeError: Indexing elements must be in increasing order
    # via selection of named varibales
    # access[:, var_inner].to_memory()
    # for the second artifact 0325478a-9b52-45b5-b40a-2e2ab0d72eb1.h5ad
    # ok, it seems it doens't work with non-increasing indices
    # todo: fix
    adata = h5ad[:, var_inner]#.to_memory()
    # idx_sort, reverse = np.unique(h5ad.var_names.get_indexer(var_inner), return_inverse=True)
    # adata = h5ad[:, idx_sort].to_memory()[:, reverse]
    #err(f"Writing {out_path}")
    adata.write_h5ad(out_path)

In [9]:
%%time
dask.compute([ slice_20k_vars(h5ad_path) for h5ad_path in h5ad_paths ])

CPU times: user 66.8 ms, sys: 7.56 ms, total: 74.4 ms
Wall time: 260 ms


([None, None, None, None, None, None, None, None, None, None],)

## Prepare parquet

In [11]:
#cluster.close()

In [10]:
@dask.delayed
def read_X(path, idx):
    return ad.read_h5ad(path, backed="r").X[idx, :].toarray().astype("float32", copy=False)

In [11]:
# number of files per parquet file
CHUNK_SIZE = 32768
# row group size of parquet files
ROW_GROUP_SIZE = 1024

In [12]:
%%time
array_chunks = []
chunk_sizes = []

for path in h5ad_paths:
    name = basename(path)
    h5ad_path = join(out_dir, name)
    access = ad.read_h5ad(h5ad_path, 'r')
    n_obs = access.shape[0]
    idx_splits = np.array_split(np.arange(n_obs), np.ceil(n_obs / CHUNK_SIZE))
    for idx in idx_splits:
        array_chunks.append(read_X(h5ad_path, idx))
        chunk_sizes.append(len(idx))

CPU times: user 1.93 s, sys: 263 ms, total: 2.19 s
Wall time: 3.1 s


In [13]:
X = da.concatenate([
    da.from_delayed(chunk, (shape, len(var_inner)), dtype="float32") 
    for chunk, shape in zip(array_chunks, chunk_sizes)
]).rechunk((CHUNK_SIZE, -1))

In [14]:
X

Unnamed: 0,Array,Chunk
Bytes,61.39 GiB,2.44 GiB
Shape,"(823976, 20000)","(32768, 20000)"
Dask graph,26 chunks in 62 graph layers,26 chunks in 62 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 61.39 GiB 2.44 GiB Shape (823976, 20000) (32768, 20000) Dask graph 26 chunks in 62 graph layers Data type float32 numpy.ndarray",20000  823976,

Unnamed: 0,Array,Chunk
Bytes,61.39 GiB,2.44 GiB
Shape,"(823976, 20000)","(32768, 20000)"
Dask graph,26 chunks in 62 graph layers,26 chunks in 62 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


In [15]:
@dask.delayed
def convert_to_dataframe(x, start, end):
    return pd.DataFrame(
        {'X': [arr.squeeze().astype("float32", copy=False) for arr in np.vsplit(x, x.shape[0])]},
        index=pd.RangeIndex(start, end)
    )

In [16]:
start_index = [0] + list(np.cumsum(X.chunks[0]))[:-1]
end_index = list(np.cumsum(X.chunks[0]))
# calculate divisons for dask dataframe
divisions = [0] + list(np.cumsum(X.chunks[0]))
divisions[-1] = divisions[-1] - 1
ddf = dd.from_delayed(
    [
        convert_to_dataframe(arr, start, end) for arr, start, end in 
        zip(X.to_delayed().flatten().tolist(), start_index, end_index)
    ],
    divisions=divisions
)

In [17]:
ddf

Unnamed: 0_level_0,X
npartitions=26,Unnamed: 1_level_1
0,object
32768,...
...,...
819200,...
823975,...


In [18]:
%%time
ddf.to_parquet(
    "./merlin_benchmark", 
    engine='pyarrow',
    schema=pa.schema([('X', pa.list_(pa.float32()))]),
    write_metadata_file=True,
    row_group_size=ROW_GROUP_SIZE
)

2024-04-11 13:33:40,532 - distributed.worker - ERROR - Worker stream died during communication: tcp://127.0.0.1:42875
Traceback (most recent call last):
  File "/home/ec2-user/miniconda3/envs/arrayloader-benchmarks/lib/python3.11/site-packages/tornado/iostream.py", line 861, in _read_to_buffer
    bytes_read = self.read_from_fd(buf)
                 ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ec2-user/miniconda3/envs/arrayloader-benchmarks/lib/python3.11/site-packages/tornado/iostream.py", line 1116, in read_from_fd
    return self.socket.recv_into(buf, len(buf))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ConnectionResetError: [Errno 104] Connection reset by peer

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/ec2-user/miniconda3/envs/arrayloader-benchmarks/lib/python3.11/site-packages/distributed/worker.py", line 2059, in gather_dep
    response = await get_data_from_worker(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^

MemoryError: Task ('fromdelayed-reset_index-operation-toparquetdata-a29d2a4e1199659c5d173574e75ce747', 9) has 61.37 GiB worth of input dependencies, but worker tcp://127.0.0.1:42875 has memory_limit set to 15.52 GiB.