# Prepare a collection for Merlin benchmarks 

In [None]:
# Shell command (IPython `!` escape): load the lamin instance
# laminlabs/arrayloader-benchmarks that the rest of the notebook queries.
!lamin load laminlabs/arrayloader-benchmarks

In [None]:
import lamindb as ln
import anndata as ad
import dask
import dask.array as da
import dask.dataframe as dd
from dask.distributed import Client, LocalCluster
import numpy as np
import pandas as pd
import pyarrow as pa
from tqdm.notebook import tqdm

In [None]:
# Start lamindb tracking for this notebook
# (presumably links the records saved below to this run; confirm in the lamindb docs).
ln.track()

In [None]:
# Look up the source collection by its uid; `.one()` is expected to return a single record.
collection = ln.Collection.filter(uid="1gsdckxvOvIjQgeDVS1F").one()

In [None]:
# Show the collection record as the cell output.
collection

In [None]:
# Map the collection with an inner join and keep its joint variable index
# (presumably the variables shared by all artifacts; confirm in the lamindb docs).
with collection.mapped(join="inner") as ds:
    var_inner = ds.var_joint

In [None]:
# Keep only the first 20,000 joint variable names, converted to a plain list.
var_inner = var_inner[:20000].tolist()

In [None]:
# Re-save every artifact of the collection restricted to the variables in
# var_inner, with the columns in var_inner order.
artifacts_processed = []
for artifact in tqdm(collection.artifacts):
    print(artifact.key)
    with artifact.backed() as access:
        # Selecting the named variables directly,
        #     access[:, var_inner].to_memory()
        # raised "TypeError: Indexing elements must be in increasing order"
        # for the second artifact, 0325478a-9b52-45b5-b40a-2e2ab0d72eb1.h5ad;
        # it seems backed access doesn't work with non-increasing indices.
        # todo: fix
        # Workaround: read the columns in increasing positional order, then
        # restore the var_inner order once the data is in memory.
        var_positions = access.var_names.get_indexer(var_inner)
        idx_sort, reverse = np.unique(var_positions, return_inverse=True)
        adata_sorted = access[:, idx_sort].to_memory()
        adata = adata_sorted[:, reverse]
    # Sanity check: the reordered columns match the requested variables.
    assert all(adata.var_names == var_inner)
    print("adata loaded")
    subset_description = artifact.description + " subset of 20k vars"
    artifact_processed = ln.Artifact(adata, description=subset_description)
    artifact_processed.save()
    artifacts_processed.append(artifact_processed)

In [None]:
# Group the re-saved artifacts into a new collection whose name is derived
# from the source collection's name.
collection_inner = ln.Collection(
    artifacts_processed, 
    name=collection.name + " inner join and subset of 20k vars."
)

In [None]:
# Save the new collection.
collection_inner.save()

## Prepare parquet

In [None]:
# Fetch the saved collection again via a case-insensitive substring match on
# its name, so this section does not need the in-memory object created above.
# NOTE(review): `.one()` presumably fails unless exactly one collection matches; confirm.
collection_inner = ln.Collection.filter(name__icontains=" inner join and subset of 20k vars.").one()

In [None]:
# Start a local Dask cluster with a single worker running 4 threads and
# connect a client to it; the last line displays the client as the cell output.
cluster = LocalCluster(n_workers=1, threads_per_worker=4)
client = Client(cluster)
client

In [None]:
@dask.delayed
def read_X(path, idx):
    """Read selected rows of ``X`` from an ``.h5ad`` file as a dense float32 array.

    The function is wrapped in ``dask.delayed``, so calling it only creates a
    lazy task; the file is opened when that task is computed.

    Args:
        path: Path of the ``.h5ad`` file, opened in backed read-only mode.
        idx: Row indices to load; all columns are kept.

    Returns:
        The selected rows, densified with ``.toarray()`` and cast to float32.
    """
    adata = ad.read_h5ad(path, backed="r")
    try:
        # ``.toarray()`` requires the backed selection to come back sparse.
        return adata.X[idx, :].toarray().astype("float32", copy=False)
    finally:
        # Every task opens its own handle on the file; close it explicitly so
        # open handles do not pile up while the chunks are being computed.
        adata.file.close()

In [None]:
# number of rows (observations) per chunk; the row indices of each artifact are
# split by this size and X is rechunked to it before being written to parquet
CHUNK_SIZE = 32768
# row group size of parquet files
ROW_GROUP_SIZE = 1024

In [None]:
# Build one lazy read task per block of at most CHUNK_SIZE rows of every artifact.
array_chunks = []  # delayed read_X tasks, in artifact order and row order
chunk_sizes = []  # number of rows the matching task in array_chunks returns

for artifact in collection_inner.artifacts:
    with artifact.backed() as access:
        n_obs = access.shape[0]
    if n_obs == 0:
        # Nothing to read, and np.array_split rejects zero sections.
        continue
    # The path does not depend on the chunk: resolve it once per artifact
    # instead of calling stage() again for every chunk.
    path = artifact.stage().as_posix()
    # Integer ceiling division: number of nearly equal blocks of <= CHUNK_SIZE rows.
    n_splits = (n_obs + CHUNK_SIZE - 1) // CHUNK_SIZE
    idx_splits = np.array_split(np.arange(n_obs), n_splits)
    for idx in idx_splits:
        array_chunks.append(read_X(path, idx))
        chunk_sizes.append(len(idx))

In [None]:
# Wrap every delayed read in a dask array of known shape (rows of the chunk x
# number of selected variables), stack them along the rows, and rechunk so
# that each block holds up to CHUNK_SIZE full rows.
delayed_blocks = [
    da.from_delayed(chunk, shape=(n_rows, len(var_inner)), dtype="float32")
    for chunk, n_rows in zip(array_chunks, chunk_sizes)
]
X = da.concatenate(delayed_blocks).rechunk((CHUNK_SIZE, -1))

In [None]:
# Show the lazy dask array as the cell output.
X

In [None]:
@dask.delayed
def convert_to_dataframe(x, start, end):
    """Wrap the rows of a 2-D array into a single-column DataFrame.

    The function is wrapped in ``dask.delayed``, so calling it only creates a
    lazy task.

    Args:
        x: 2-D array; every row becomes one cell of column ``"X"``.
        start: Index label of the first row.
        end: One past the index label of the last row; ``end - start`` has to
            equal the number of rows of ``x``.

    Returns:
        DataFrame indexed by ``RangeIndex(start, end)`` with a column ``"X"``
        that holds one 1-D float32 array per row of ``x``.
    """
    # Cast the whole block once (no copy when it already is float32).
    x = x.astype("float32", copy=False)
    # Iterating a 2-D array yields its rows as 1-D views. Compared with
    # np.vsplit(x, x.shape[0]) followed by squeeze(), this does not raise on a
    # block without rows and keeps rows 1-D when there is a single column.
    return pd.DataFrame({'X': list(x)}, index=pd.RangeIndex(start, end))

In [None]:
# Cumulative row counts give the (exclusive) end row of every block of X.
row_ends = np.cumsum(X.chunks[0])
end_index = list(row_ends)
start_index = [0] + end_index[:-1]
# Divisions for the dask dataframe: the first index label of every partition
# followed by the last label of the final partition, which is why the final
# bound is reduced by one (the partitions use RangeIndex(start, end)).
divisions = [0] + end_index
divisions[-1] -= 1
# One delayed DataFrame per block of X, paired with its row range.
partitions = [
    convert_to_dataframe(block, first, stop)
    for block, first, stop in zip(X.to_delayed().flatten().tolist(), start_index, end_index)
]
ddf = dd.from_delayed(partitions, divisions=divisions)

In [None]:
# Show the lazy dask dataframe as the cell output.
ddf

In [None]:
# Write the dataframe to the ./merlin_benchmark directory with the pyarrow
# engine. The schema is given explicitly: column X is a list of float32.
# A metadata file is requested and ROW_GROUP_SIZE is passed as the row group size.
ddf.to_parquet(
    "./merlin_benchmark", 
    engine='pyarrow',
    schema=pa.schema([('X', pa.list_(pa.float32()))]),
    write_metadata_file=True,
    row_group_size=ROW_GROUP_SIZE
)

In [None]:
# Create an artifact for the parquet output directory, described by the
# collection name plus a " counts in parquet" suffix.
artifact_parquet = ln.Artifact("./merlin_benchmark", description=collection_inner.name + " counts in parquet")

In [None]:
# Save the parquet artifact.
artifact_parquet.save()