# Prepare a collection for Merlin benchmarks 

In [7]:
import anndata as ad
import dask
import dask.array as da
import dask.dataframe as dd
from dask.distributed import Client, LocalCluster
import numpy as np
from os.path import basename
import pandas as pd
import pyarrow as pa
import shlex
from subprocess import check_call, check_output
from sys import stderr
from tqdm.notebook import tqdm


def err(msg):
    """Write `msg` to standard error, terminated by a newline."""
    stderr.write(msg + '\n')

In [2]:
# h5ads.txt lists one h5ad path per line; only the trailing newline is stripped.
with open('h5ads.txt', 'r') as f:
    h5ad_paths = [line.rstrip('\n') for line in f]
h5ad_paths

['cell-census/2023-12-15/h5ads/04a23820-ffa8-4be5-9f65-64db15631d1e.h5ad',
 'cell-census/2023-12-15/h5ads/0325478a-9b52-45b5-b40a-2e2ab0d72eb1.h5ad',
 'cell-census/2023-12-15/h5ads/0d7f4c06-a6bd-47d2-a42b-3a7196704f77.h5ad',
 'cell-census/2023-12-15/h5ads/090da8ea-46e8-40df-bffc-1f78e1538d27.h5ad',
 'cell-census/2023-12-15/h5ads/07b1d7c8-5c2e-42f7-9246-26f746cd6013.h5ad',
 'cell-census/2023-12-15/h5ads/0ee5ae70-c3f5-473f-bd1c-287f4690ffc5.h5ad',
 'cell-census/2023-12-15/h5ads/1cf24082-59de-4029-ac81-6e398768af3a.h5ad',
 'cell-census/2023-12-15/h5ads/1b767f95-d0a0-4a3d-b394-cc665d86c3dc.h5ad',
 'cell-census/2023-12-15/h5ads/19b21f40-db42-4a71-a0d6-913e83b17784.h5ad',
 'cell-census/2023-12-15/h5ads/18500fcd-9960-49cb-8a8e-7d868dc14efe.h5ad',
 'cell-census/2023-12-15/h5ads/182f6a56-7360-4924-a74e-1772e07b3031.h5ad',
 'cell-census/2023-12-15/h5ads/22658f4f-9268-41ad-8828-cc53f4baa9fa.h5ad',
 'cell-census/2023-12-15/h5ads/2190bd4d-3be0-4bf7-8ca8-8d6f71228936.h5ad',
 'cell-census/2023-12-15/

In [8]:
# S3 bucket name and the key prefix (directory) under it; joined into `prefix`,
# which `sync` uses for both the s3:// source and the local `s3/` mirror path.
bkt = 'cellxgene-census-public-us-west-2'
bkt_dir = 'cell-census/2023-12-15/h5ads'
prefix = f'{bkt}/{bkt_dir}'

def sync(start, end, dryrun=False):
    """Run `aws s3 sync` for the slice `h5ad_paths[start:end]`.

    The command excludes everything and then re-includes only the basenames
    of the selected paths, copying from `s3://{prefix}` to the local
    directory `s3/{prefix}`. The full command line is logged to stderr
    before it is executed.

    Parameters:
        start, end: slice bounds applied to the module-level `h5ad_paths`.
        dryrun: when true, `--dryrun` is passed to the AWS CLI.

    Raises:
        subprocess.CalledProcessError: if the `aws` command exits non-zero
            (raised by `check_call`).
    """
    include_args = []
    for h5ad_path in h5ad_paths[start:end]:
        include_args += ['--include', basename(h5ad_path)]

    dryrun_args = ['--dryrun'] if dryrun else []
    cmd = [
        'aws', 's3', 'sync',
        *dryrun_args,
        '--exclude', '*',
        *include_args,
        f's3://{prefix}',
        f's3/{prefix}',
    ]
    err(f'Running: {shlex.join(cmd)}')
    check_call(cmd)

In [9]:
# Dry run over the first 10 paths: the command is logged and `--dryrun` is passed to `aws s3 sync`.
sync(0, 10, dryrun=True)

Running: aws s3 sync --dryrun --exclude '*' --include 04a23820-ffa8-4be5-9f65-64db15631d1e.h5ad --include 0325478a-9b52-45b5-b40a-2e2ab0d72eb1.h5ad --include 0d7f4c06-a6bd-47d2-a42b-3a7196704f77.h5ad --include 090da8ea-46e8-40df-bffc-1f78e1538d27.h5ad --include 07b1d7c8-5c2e-42f7-9246-26f746cd6013.h5ad --include 0ee5ae70-c3f5-473f-bd1c-287f4690ffc5.h5ad --include 1cf24082-59de-4029-ac81-6e398768af3a.h5ad --include 1b767f95-d0a0-4a3d-b394-cc665d86c3dc.h5ad --include 19b21f40-db42-4a71-a0d6-913e83b17784.h5ad --include 18500fcd-9960-49cb-8a8e-7d868dc14efe.h5ad s3://cellxgene-census-public-us-west-2/cell-census/2023-12-15/h5ads s3/cellxgene-census-public-us-west-2/cell-census/2023-12-15/h5ads


(dryrun) download: s3://cellxgene-census-public-us-west-2/cell-census/2023-12-15/h5ads/0325478a-9b52-45b5-b40a-2e2ab0d72eb1.h5ad to s3/cellxgene-census-public-us-west-2/cell-census/2023-12-15/h5ads/0325478a-9b52-45b5-b40a-2e2ab0d72eb1.h5ad
(dryrun) download: s3://cellxgene-census-public-us-west-2/cell-census/2023-12-15/h5ads/04a23820-ffa8-4be5-9f65-64db15631d1e.h5ad to s3/cellxgene-census-public-us-west-2/cell-census/2023-12-15/h5ads/04a23820-ffa8-4be5-9f65-64db15631d1e.h5ad
(dryrun) download: s3://cellxgene-census-public-us-west-2/cell-census/2023-12-15/h5ads/07b1d7c8-5c2e-42f7-9246-26f746cd6013.h5ad to s3/cellxgene-census-public-us-west-2/cell-census/2023-12-15/h5ads/07b1d7c8-5c2e-42f7-9246-26f746cd6013.h5ad
(dryrun) download: s3://cellxgene-census-public-us-west-2/cell-census/2023-12-15/h5ads/090da8ea-46e8-40df-bffc-1f78e1538d27.h5ad to s3/cellxgene-census-public-us-west-2/cell-census/2023-12-15/h5ads/090da8ea-46e8-40df-bffc-1f78e1538d27.h5ad
(dryrun) download: s3://cellxgene-census

## Prepare parquet

In [None]:
# Start a local Dask cluster (1 worker, 4 threads per worker) and connect a client to it.
cluster = LocalCluster(n_workers=1, threads_per_worker=4)
client = Client(cluster)
# Bare expression so the notebook displays the client.
client

In [None]:
@dask.delayed
def read_X(path, idx):
    """Load rows `idx` of X from the h5ad file at `path` as a dense float32 array.

    The file is opened in backed mode so only the requested rows are read.
    The backing file is closed before returning, so repeated delayed calls
    do not accumulate open file handles.

    Parameters:
        path: path to an .h5ad file.
        idx: row indices to select from X.

    Returns:
        A 2-D numpy array with one row per entry of `idx`, dtype float32.
    """
    adata = ad.read_h5ad(path, backed="r")
    try:
        rows = adata.X[idx, :]
        # Sparse selections expose `toarray`; a dense X already yields an ndarray.
        if hasattr(rows, "toarray"):
            rows = rows.toarray()
        return np.asarray(rows).astype("float32", copy=False)
    finally:
        # Backed mode keeps the underlying file open until it is closed explicitly.
        adata.file.close()

In [None]:
# Number of rows (observations) per chunk: each h5ad is read in pieces of at most
# this many rows, and X is rechunked to this many rows per block.
# NOTE(review): presumably one parquet file is written per block — confirm.
CHUNK_SIZE = 32768
# row group size of parquet files (passed as `row_group_size` to `to_parquet`)
ROW_GROUP_SIZE = 1024

In [None]:
# Delayed row blocks of X, and the number of rows in each, in artifact order.
array_chunks = []
chunk_sizes = []

# NOTE(review): `collection_inner` is not defined in any visible cell — it must
# exist in the kernel before this cell runs; confirm where it comes from.
for artifact in collection_inner.artifacts:
    with artifact.backed() as access:
        n_obs = access.shape[0]
    if n_obs == 0:
        # np.array_split rejects zero sections; an empty artifact contributes no rows.
        continue
    # Resolve the local path once per artifact rather than once per chunk.
    # Assumes stage() returns the same path on every call — TODO confirm.
    local_path = artifact.stage().as_posix()
    # Split the row indices into near-equal pieces of at most CHUNK_SIZE rows.
    n_splits = int(np.ceil(n_obs / CHUNK_SIZE))
    idx_splits = np.array_split(np.arange(n_obs), n_splits)
    for idx in idx_splits:
        array_chunks.append(read_X(local_path, idx))
        chunk_sizes.append(len(idx))

In [None]:
# Wrap each delayed block as a lazy float32 dask array of shape
# (rows in block, len(var_inner)), stack them row-wise, then rechunk so every
# block holds CHUNK_SIZE rows and all columns.
# NOTE(review): `var_inner` is not defined in any visible cell — confirm.
n_vars = len(var_inner)
lazy_blocks = [
    da.from_delayed(block, (n_rows, n_vars), dtype="float32")
    for block, n_rows in zip(array_chunks, chunk_sizes)
]
X = da.concatenate(lazy_blocks).rechunk((CHUNK_SIZE, -1))

In [None]:
# Display the dask array's notebook representation.
X

In [None]:
@dask.delayed
def convert_to_dataframe(x, start, end):
    """Turn a 2-D block into a one-column DataFrame of per-row float32 vectors.

    Each row of `x` becomes one cell of column 'X', stored as a 1-D float32
    array. Rows are taken by iterating over the first axis, so a block with
    zero rows gives an empty frame and a block with one column still gives
    1-D rows (rather than 0-d scalars).

    Parameters:
        x: 2-D array-like block.
        start, end: bounds of the `pd.RangeIndex(start, end)` used as the
            frame's index; `end - start` must equal the number of rows in `x`.

    Returns:
        A pandas DataFrame with a single object column 'X'.
    """
    # Cast the whole block once; every row is then a 1-D view into it.
    rows = np.asarray(x).astype("float32", copy=False)
    return pd.DataFrame({'X': list(rows)}, index=pd.RangeIndex(start, end))

In [None]:
# Cumulative row offsets of X's row chunks: [0, end_0, end_1, ..., total_rows].
row_bounds = [0, *np.cumsum(X.chunks[0])]
start_index = row_bounds[:-1]
end_index = row_bounds[1:]
# Divisions for the dask dataframe: the first index of every partition, followed
# by the last index (inclusive) of the final partition.
divisions = row_bounds[:-1] + [row_bounds[-1] - 1]
partitions = [
    convert_to_dataframe(block, lo, hi)
    for block, lo, hi in zip(X.to_delayed().flatten().tolist(), start_index, end_index)
]
ddf = dd.from_delayed(partitions, divisions=divisions)

In [None]:
# Display the dask dataframe's notebook representation.
ddf

In [None]:
# Column 'X' is written as a list<float32> column.
parquet_schema = pa.schema([pa.field('X', pa.list_(pa.float32()))])
ddf.to_parquet(
    "./merlin_benchmark",
    schema=parquet_schema,
    engine='pyarrow',
    row_group_size=ROW_GROUP_SIZE,
    write_metadata_file=True,
)

In [None]:
# Create an artifact object for the parquet directory, described by the source collection's name.
# NOTE(review): `ln` is not imported and `collection_inner` is not defined in any
# visible cell (presumably `ln` is lamindb — confirm); add the import/definition
# so the notebook runs on a fresh kernel.
artifact_parquet = ln.Artifact("./merlin_benchmark", description=collection_inner.name + " counts in parquet")

In [None]:
# Call save() on the artifact object built for "./merlin_benchmark".
artifact_parquet.save()