In [None]:
import logging
import numpy as np
import scanpy as sc
import pandas as pd
import scipy
import json
import pickle

# Cells of the clustering/UMAP output are re-keyed by their original IDs so that their cluster
# assignments can be matched to `adata` by cell ID further down (needed to allow transfer).
umap_embedding_path = snakemake.input.umap_embedding
clustered_embeddings = sc.read_h5ad(umap_embedding_path)
clustered_embeddings.obs.set_index(keys="orig_ids", inplace=True)

In [None]:
# Read counts, scaled so that every cell sums to 10,000. A later cell either log1p-transforms X
# or replaces it with a provided `normalized` layer.
read_count_table_path = snakemake.input.read_count_table
adata = sc.read_h5ad(read_count_table_path)
sc.pp.normalize_total(adata=adata, target_sum=1e4)

In [None]:
# EnrichR terms (stored verbatim in `adata.uns["terms"]` further down)
enrichr_terms_path = snakemake.input.enrichr_terms
with open(enrichr_terms_path) as enrichr_file:
    terms = json.load(enrichr_file)


In [None]:
# Per-cell model outputs; the keys `orig_ids` and `transcriptome_embeds` are read in later cells.
# NOTE(review): `allow_pickle=True` permits unpickling object arrays, which can execute arbitrary code —
# presumably acceptable because this file is produced by the pipeline itself; confirm it is never user-supplied.
processed_data = np.load(file=snakemake.input.processed_data, allow_pickle=True)

In [None]:
# assert that the order of orig_ids matches the one in adata.obs.index: `processed_data` arrays are
# attached to `adata` positionally further down. `np.array_equal` also reports a length mismatch as a
# plain assertion failure (element-wise `==` between arrays of different length raises a ValueError instead).
assert np.array_equal(processed_data["orig_ids"], adata.obs.index), \
    "`processed_data['orig_ids']` does not match `adata.obs.index` (length or order differs)"

In [None]:
# Per-gene log1p normalizers, stored as a `var` column.
# NOTE(review): pickle.load can execute arbitrary code — only load this from a trusted pipeline output.
normalizers_path = snakemake.input.gene_log1p_normalizers
with open(normalizers_path, "rb") as normalizers_file:
    gene_log1p_normalizers = pickle.load(normalizers_file)
# assumes one value per gene in `adata.var` order (or a Series indexed like `adata.var`) — TODO confirm
adata.var["log1p_normalizer"] = gene_log1p_normalizers

In [None]:
# Add the cluster labels: curated labels keyed by (cluster_field, cluster_values).
# `cluster_values` is forced to str so lookups do not depend on how the CSV parser typed that column.
llava_cluster_labels = (
    pd.read_csv(snakemake.input.cellwhisperer_llava_labels)
    .astype({"cluster_values": str})
    .set_index(["cluster_field", "cluster_values"])
)
llava_cluster_labels.head()

In [None]:
# Attach per-cell embeddings. The UMAP rows are looked up by cell ID instead of being copied positionally:
# `clustered_embeddings` is indexed by `orig_ids`, so a different cell order there cannot silently give
# one cell another cell's coordinates (the `leiden` transfer below aligns by the same index).
umap_rows = clustered_embeddings.obs.index.get_indexer(adata.obs.index)
assert (umap_rows >= 0).all(), "some cells of `adata` are missing from the UMAP embedding"
adata.obsm["X_cellwhisperer_umap"] = clustered_embeddings.obsm["X_umap"][umap_rows]
# Positional assignment is fine here: `processed_data["orig_ids"]` was asserted to match `adata.obs.index`
adata.obsm["transcriptome_embeds"] = processed_data["transcriptome_embeds"]

# Define "Corpora" metadata fields (some of them get shown in the UI. They are also a prerequisite for setting `default_embedding`)
adata.uns["dataset_name"] = snakemake.wildcards.dataset
adata.uns["model_name"] = snakemake.wildcards.model
adata.uns["terms"] = terms
adata.uns["version"] = {"corpora_schema_version": "1.1.0", "corpora_encoding_version": "0.1.0"}
adata.uns["title"] = f"{snakemake.wildcards.dataset} ({snakemake.wildcards.model})"
# NOTE(review): this description is inaccurate when a provided `normalized` layer replaces X in a later cell
adata.uns["layer_descriptions"] = "X: log1p after 10K total-normalization"
adata.uns["organism"] = "human"
adata.uns["organism_ontology_term_id"] = "NCBITaxon:9606"
adata.uns["default_embedding"] = "X_cellwhisperer_umap"


In [None]:
# Transfer the cluster labels. This assignment aligns on the shared cell-ID index (`orig_ids`), not on position.
adata.obs["leiden"] = clustered_embeddings.obs["leiden"]

# One `<field>_label` obs column per clustering in the label table; the Leiden labels go to `cluster_label`
cluster_map = {
    cluster_column: ("cluster_label" if cluster_column == "leiden" else f"{cluster_column}_label")
    for cluster_column in llava_cluster_labels.index.get_level_values("cluster_field").unique()
}

for cluster_column, label_column in cluster_map.items():
    curated_labels = llava_cluster_labels.loc[cluster_column, "curated_labels"]
    # The table's `cluster_values` were cast to str on load, so cast the lookup keys the same way —
    # otherwise non-string cluster IDs (e.g. ints) never match and every label silently becomes NaN.
    cluster_keys = adata.obs[cluster_column].astype(str).to_numpy()
    adata.obs[label_column] = curated_labels.reindex(cluster_keys).to_numpy()


In [None]:
# Spot-check the metadata of the first cell after the label transfer
first_cell_obs = adata.obs.iloc[0]
first_cell_obs


In [None]:
# Continuous version of the submission date, expressed as a fractional calendar year (e.g. 2015.5)
if "series_submission_date" in adata.obs:
    seconds_per_year = 365.25 * 24 * 60 * 60
    # errors="coerce": unparseable entries become NaT instead of aborting the whole notebook
    submission_dates = pd.to_datetime(adata.obs["series_submission_date"], errors="coerce")
    n_missing_dates = int(submission_dates.isna().sum())
    if n_missing_dates > 0:
        logging.warning(f"{n_missing_dates} `series_submission_date` values are missing or unparseable; imputing the mean date")
    submission_dates = submission_dates.fillna(submission_dates.mean())
    # pandas' Timestamp.timestamp() gives POSIX seconds (naive timestamps are treated as UTC)
    adata.obs["series_submission_date_cont"] = submission_dates.map(lambda ts: ts.timestamp()) / seconds_per_year + 1970

In [None]:
# Remove categorical obs columns whose number of categories exceeds `max_categories_filter`.
# The raw `series_submission_date` column is always kept.
drop_cols = []
for column_name in adata.obs.columns:
    column_dtype = adata.obs[column_name].dtype
    if str(column_dtype) != "category":
        continue
    if len(column_dtype.categories) <= snakemake.params.max_categories_filter:
        continue
    if column_name == "series_submission_date":
        continue
    drop_cols.append(column_name)
adata.obs.drop(columns=drop_cols, inplace=True)

In [None]:
# Convert int64 to int32 and float64 to float32 for `cellxgene`
def _downcast_for_cellxgene(frame: pd.DataFrame) -> pd.DataFrame:
    """Return `frame` with int64 columns cast to int32 and float64 columns cast to float32.

    An int64 column whose values do not fit into int32 is cast to float32 instead (with a warning),
    because `astype(np.int32)` would silently wrap such values around. Other dtypes are left untouched.
    """
    int32_bounds = np.iinfo(np.int32)
    target_dtypes = {}
    for column_name in frame.columns:
        column = frame[column_name]
        if column.dtype == np.float64:
            target_dtypes[column_name] = np.float32
        elif column.dtype == np.int64:
            fits_int32 = column.empty or (column.min() >= int32_bounds.min and column.max() <= int32_bounds.max)
            if not fits_int32:
                logging.warning(f"Column `{column_name}` exceeds the int32 range; converting it to float32 instead")
            target_dtypes[column_name] = np.int32 if fits_int32 else np.float32
    return frame.astype(target_dtypes)


adata.obs = _downcast_for_cellxgene(adata.obs)
adata.var = _downcast_for_cellxgene(adata.var)

In [None]:
# Final X: log1p of the total-normalized counts, unless the dataset ships its own `normalized` layer
use_provided_layer = "normalized" in adata.layers
if not use_provided_layer:
    sc.pp.log1p(adata)
else:
    adata.X = adata.layers["normalized"]
    logging.warning("Taking the provided `normalized` layer instead of computing log1p")

In [None]:
# Remove every layer, then every obsp (pairwise) matrix, before export. Keys are snapshotted
# with list() because the mappings are mutated while iterating.
for aligned_mapping in (adata.layers, adata.obsp):
    for mapping_key in list(aligned_mapping):
        del aligned_mapping[mapping_key]

In [None]:
# shrink huge datasets: store X as float32 and, for very large sparse matrices, drop small stored values
LARGE_MATRIX_N_STORED = 1e9  # above this many stored values, entries below SMALL_VALUE_THRESHOLD are dropped
HUGE_MATRIX_N_STORED = 1e10  # above this many, the more aggressive HUGE_VALUE_THRESHOLD is used instead
SMALL_VALUE_THRESHOLD = 0.1
HUGE_VALUE_THRESHOLD = 0.5

adata.X = adata.X.astype(np.float32)  # Convert float64 to float32
# Only sparse matrices expose a flat `.data` array of stored values and `eliminate_zeros()`;
# a dense X is left as is (its length check was never meaningful).
if scipy.sparse.issparse(adata.X):
    n_stored_values = len(adata.X.data)
    if n_stored_values > LARGE_MATRIX_N_STORED:
        logging.warning("Reducing number of elements sparse matrix")
        threshold = HUGE_VALUE_THRESHOLD if n_stored_values > HUGE_MATRIX_N_STORED else SMALL_VALUE_THRESHOLD
        adata.X.data[adata.X.data < threshold] = 0
        adata.X.eliminate_zeros()  # physically remove the entries zeroed above

In [None]:
# TODO optionally reduce `var` dimension (e.g. filter by gene names)

# Make `gene_name` unique: every repeated occurrence (all but the first) gets its var index label
# appended, e.g. a second "TP53" becomes "TP53_<var label>". A positional mask is used instead of
# label-based `.loc`, so a non-unique var index cannot cause the wrong rows to be renamed.
is_repeated_gene_name = adata.var["gene_name"].duplicated().to_numpy()
if is_repeated_gene_name.any():
    gene_names = adata.var["gene_name"].astype(str)
    adata.var["gene_name"] = [
        f"{gene_name}_{var_label}" if is_repeated else gene_name
        for gene_name, var_label, is_repeated in zip(gene_names, adata.var.index, is_repeated_gene_name)
    ]

In [None]:
# Guard: the de-duplication above must leave every gene name unique
assert adata.var["gene_name"].is_unique

In [None]:
# cellxgene is optimized for CSC, so re-encode a CSR X in column-compressed form
if isinstance(adata.X, scipy.sparse.csr_matrix):
    adata.X = adata.X.tocsc()

In [None]:
top_genes = pd.read_parquet(snakemake.input.top_genes)

# TODO: the intent was to store these columns as categoricals ("ensure they are categorical"), but no
# conversion is implemented — confirm the expected dtype before adding e.g. `top_genes.astype("category")`.

# ensure the rows are the cells of `adata`, in the same order. `np.array_equal` also reports a length
# mismatch as an assertion failure (element-wise `==` between Indexes of different length raises instead).
assert np.array_equal(top_genes.index, adata.obs.index), "`top_genes` index does not match `adata.obs.index`"

# add via obsm
adata.obsm["top_genes"] = top_genes

In [None]:
# Persist the assembled AnnData to the Snakemake output path
adata.write_h5ad(filename=snakemake.output.adata)