In [23]:
import pandas as pd
import anndata as ad
import numpy as np
import h5py
import os
import pickle
import scipy.sparse as sp
from geneformer import TranscriptomeTokenizer

In [12]:
input_path = "/nfs/turbo/umms-indikar/shared/projects/geneformer/data/rajapakse_lab_data.h5ad"
input_dir = os.path.dirname(input_path)
output_path = "/scratch/indikar_root/indikar1/cstansbu/geneformer/"
prefix = "test"

def get_attributes(h5ad_path):
    """
    Extracts attribute names from the `.obs` field of an h5ad AnnData file,
    returning them as a dictionary with keys and values being the attribute names.

    Args:
        h5ad_path (str): The path to the h5ad file.

    Returns:
        dict: A dictionary where keys and values are the unique attribute names 
              found in the `.obs` field of the h5ad file.
    """
    with h5py.File(h5ad_path, mode="r") as store:
        attribute_names = list(store["obs"].keys())

    attribute_name_dict = {name: name for name in attribute_names}  # Create dictionary
    return attribute_name_dict
    

custom_attr_name_dict = get_attributes(input_path)
custom_attr_name_dict

{'control': 'control',
 'data_id': 'data_id',
 'dataset': 'dataset',
 'hour': 'hour',
 'n_counts': 'n_counts',
 'sample_id': 'sample_id',
 'timepoint': 'timepoint'}

In [15]:
def load_gene_median_dict(gene_median_file):
    """
    Loads a gene median dictionary from a pickle file.

    Args:
        gene_median_file (str): Path to the pickle file containing the gene median dictionary.

    Returns:
        dict: A dictionary mapping gene IDs to their median expression values.
    """

    with open(gene_median_file, "rb") as f:
        gene_median_dict = pickle.load(f)

    return gene_median_dict


def load_gene_tokenization(token_dictionary_file):
    """
    Loads gene tokenization data from a pickle file.

    Args:
        token_dictionary_file (str): Path to the pickle file containing the gene-token dictionary.

    Returns:
        dict: Gene-token dictionary (Ensembl ID: token).
        list: List of all gene keys (Ensembl IDs).
        dict: Dictionary mapping gene keys to True (used for selecting genes later).
    """

    with open(token_dictionary_file, "rb") as f:
        gene_token_dict = pickle.load(f)

    gene_keys = list(gene_token_dict.keys())

    # Optimization: Pre-allocate the list for slight performance improvement
    genelist_dict = dict.fromkeys(gene_keys, True)

    return gene_token_dict, gene_keys, genelist_dict


def rank_genes(gene_vector, gene_tokens):
    """Ranks genes based on expression values in descending order.

    Args:
        gene_vector (numpy.ndarray): Array of gene expression values.
        gene_tokens (numpy.ndarray): Array of corresponding gene tokens.

    Returns:
        numpy.ndarray: Array of gene tokens sorted by descending expression value.
    """
    return gene_tokens[np.argsort(-gene_vector)]


def normalize_counts(adata_chunk,  counts_column='n_counts', target_sum=10000):
    """Normalizes gene expression counts within a chunk of AnnData.

    Args:
        adata_chunk (AnnData): A chunk of the AnnData object containing gene expression data.
        counts_column (str): Name of the column in `adata_chunk.obs` containing the total counts per cell.
        target_sum (float): The desired total count per cell after normalization.
        norm_factor_vector (numpy.ndarray): An array of normalization factors for each gene.

    Returns:
        scipy.sparse.csr_matrix: A sparse matrix containing the normalized gene expression counts.

    This function performs the following steps:
        1. Extracts the total counts per cell from the specified column (`counts_column`).
        2. Normalizes the gene expression matrix (`adata_chunk.X`) by dividing by the total counts 
           and multiplying by the `target_sum`.
        3. Further adjusts the normalized values by dividing by the gene-specific normalization 
           factors (`norm_factor_vector`).
        4. Returns the normalized expression matrix as a sparse CSR matrix for efficient storage 
           and computation.
    """
    
    n_counts = adata_chunk.obs[counts_column].values[:, None]  # Cell counts as column vector
    X_norm = adata_chunk.X / n_counts * target_sum / norm_factor_vector
    return sp.csr_matrix(X_norm)  # Efficient sparse representation


def tokenize_anndata(adata, genelist_dict, gene_median_dict, 
                     chunk_size=100000, target_sum=10000):
    """
    Tokenizes and ranks genes within an AnnData object, optimizing for memory efficiency.

    This function processes gene expression data in chunks, applies normalization, and ranks genes
    for each cell based on their expression levels. The resulting tokenized and ranked gene
    representations, along with cell metadata, are returned.

    Args:
        adata (AnnData): The AnnData object containing gene expression data.
        genelist_dict (dict): Dictionary mapping gene IDs to boolean values indicating relevance.
        gene_median_dict (dict): Dictionary mapping gene IDs to their median expression values.
        chunk_size (int, optional): Number of cells to process in each chunk (default: 1000).
        target_sum (int, optional): Target sum for count normalization (default: 10000).

    Returns:
        tuple: 
            - list: List of tokenized and ranked gene lists for each cell.
            - dict: Dictionary containing cell metadata (keys are metadata column names).
    """
    # Filter relevant miRNAs
    coding_miRNA_mask = np.array([genelist_dict.get(i, False) for i in adata.var['ensembl_id']])
    coding_miRNA_loc = np.where(coding_miRNA_mask)[0]

    # Extract miRNA information
    coding_miRNA_ids = adata.var['ensembl_id'][coding_miRNA_loc]
    norm_factor_vector = np.array([gene_median_dict[i] for i in coding_miRNA_ids])
    coding_miRNA_tokens = np.array([gene_token_dict[i] for i in coding_miRNA_ids])

    tokenized_cells = []
    file_cell_metadata = {k: [] for k in adata.obs.columns}  # Initialize metadata dict

    # Process in chunks for memory efficiency
    for chunk_start in range(0, adata.shape[0], chunk_size):
        chunk_end = chunk_start + chunk_size
        adata_chunk = adata[chunk_start:chunk_end, coding_miRNA_loc]
        
        # Normalize counts (could be replaced with the untested function above)
        n_counts = adata_chunk.obs['n_counts'].values[:, None]
        X_norm = adata_chunk.X / n_counts * target_sum / norm_factor_vector
        X_norm = sp.csr_matrix(X_norm)  

        # Tokenize and rank genes for each cell in chunk
        for i in range(X_norm.shape[0]):
            ranks = rank_genes(X_norm[i].data, coding_miRNA_tokens[X_norm[i].indices])
            ranks = list(ranks[~np.isnan(ranks)].astype(int))

            tokenized_cells.append(ranks)

        # Update metadata
        for k in adata.obs.columns:
            file_cell_metadata[k].extend(adata_chunk.obs[k].tolist())

    return tokenized_cells, file_cell_metadata

In [20]:
DEFAULT_TOKEN_PATH = "/nfs/turbo/umms-indikar/shared/projects/geneformer/token_dictionary.pkl"
DEFAULT_MEDIAN_PATH = "/nfs/turbo/umms-indikar/shared/projects/geneformer/geneformer/gene_median_dictionary.pkl"

gene_token_dict, gene_keys, genelist_dict = load_gene_tokenization(DEFAULT_TOKEN_PATH)
gene_median_dict = load_gene_median_dict(DEFAULT_MEDIAN_PATH)

In [24]:
tokenized_cells, cell_metadata = tokenize_anndata(adata, 
                                                  genelist_dict, 
                                                  gene_median_dict)

  coding_miRNA_ids = adata.var['ensembl_id'][coding_miRNA_loc]


In [27]:
adata

AnnData object with n_obs × n_vars = 66 × 19393 backed at '/nfs/turbo/umms-indikar/shared/projects/geneformer/data/rajapakse_lab_data.h5ad'
    obs: 'dataset', 'sample_id', 'timepoint', 'hour', 'n_counts', 'control'
    var: 'gene_id', 'token_id', 'Chromosome', 'Source', 'Feature', 'Start', 'End', 'Score', 'Strand', 'Frame', 'gene_version', 'gene_source', 'gene_biotype', 'transcript_id', 'transcript_version', 'transcript_name', 'transcript_source', 'transcript_biotype', 'tag', 'ccds_id', 'exon_number', 'exon_id', 'exon_version', 'protein_id', 'protein_version', 'transcript_support_level', 'ensembl_id'

In [None]:
def tokenize_anndata(adata_file_path, target_sum=10_000, custom_attr_name_dict=None):
    adata = ad.read(adata_file_path, backed="r")

    if custom_attr_name_dict is not None:
        file_cell_metadata = {
            attr_key: [] for attr_key in custom_attr_name_dict.keys()
        }

    coding_miRNA_loc = np.where(
        [self.genelist_dict.get(i, False) for i in adata.var["ensembl_id"]]
    )[0]
    norm_factor_vector = np.array(
        [
            gene_median_dict[i]
            for i in adata.var["ensembl_id"][coding_miRNA_loc]
        ]
    )
    coding_miRNA_ids = adata.var["ensembl_id"][coding_miRNA_loc]
    coding_miRNA_tokens = np.array(
        [self.gene_token_dict[i] for i in coding_miRNA_ids]
    )

    try:
        _ = adata.obs["filter_pass"]
    except KeyError:
        var_exists = False
    else:
        var_exists = True

    if var_exists:
        filter_pass_loc = np.where([i == 1 for i in adata.obs["filter_pass"]])[0]
    elif not var_exists:
        print(
            f"{adata_file_path} has no column attribute 'filter_pass'; tokenizing all cells."
        )
        filter_pass_loc = np.array([i for i in range(adata.shape[0])])

    tokenized_cells = []

    for i in range(0, len(filter_pass_loc), self.chunk_size):
        idx = filter_pass_loc[i : i + self.chunk_size]

        n_counts = adata[idx].obs["n_counts"].values[:, None]
        X_view = adata[idx, coding_miRNA_loc].X
        X_norm = X_view / n_counts * target_sum / norm_factor_vector
        X_norm = sp.csr_matrix(X_norm)

        tokenized_cells += [
            rank_genes(X_norm[i].data, coding_miRNA_tokens[X_norm[i].indices])
            for i in range(X_norm.shape[0])
        ]

        # add custom attributes for subview to dict
        if self.custom_attr_name_dict is not None:
            for k in file_cell_metadata.keys():
                file_cell_metadata[k] += adata[idx].obs[k].tolist()
        else:
            file_cell_metadata = None

    return tokenized_cells, file_cell_metadata


tokenize_anndata(input_path)

In [3]:
# tk.tokenize_data(input_dir, 
#                  output_path, 
#                  prefix, 
#                  file_format="h5ad")

In [4]:
adata = ad.read(input_path, backed="r")
adata



AnnData object with n_obs × n_vars = 66 × 19393 backed at '/nfs/turbo/umms-indikar/shared/projects/geneformer/data/rajapakse_lab_data.h5ad'
    obs: 'dataset', 'sample_id', 'timepoint', 'hour', 'n_counts', 'control'
    var: 'gene_id', 'token_id', 'Chromosome', 'Source', 'Feature', 'Start', 'End', 'Score', 'Strand', 'Frame', 'gene_version', 'gene_source', 'gene_biotype', 'transcript_id', 'transcript_version', 'transcript_name', 'transcript_source', 'transcript_biotype', 'tag', 'ccds_id', 'exon_number', 'exon_id', 'exon_version', 'protein_id', 'protein_version', 'transcript_support_level', 'ensembl_id'

In [5]:
adata.var.head()

Unnamed: 0_level_0,gene_id,token_id,Chromosome,Source,Feature,Start,End,Score,Strand,Frame,...,transcript_biotype,tag,ccds_id,exon_number,exon_id,exon_version,protein_id,protein_version,transcript_support_level,ensembl_id
gene_name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
A1BG,ENSG00000121410,5150.0,19,ensembl_havana,gene,58345177.0,58353492.0,.,-,.,...,,,,,,,,,,ENSG00000121410
A1CF,ENSG00000148584,9064.0,10,ensembl_havana,gene,50799408.0,50885675.0,.,-,.,...,,,,,,,,,,ENSG00000148584
A2M,ENSG00000175899,13826.0,12,ensembl_havana,gene,9067663.0,9116229.0,.,-,.,...,,,,,,,,,,ENSG00000175899
A2ML1,ENSG00000166535,11812.0,12,ensembl_havana,gene,8822620.0,8887001.0,.,+,.,...,,,,,,,,,,ENSG00000166535
A3GALT2,ENSG00000184389,15327.0,1,ensembl_havana,gene,33306765.0,33321098.0,.,-,.,...,,,,,,,,,,ENSG00000184389


In [10]:
idx = ['63246_T0R1', '63249_T2R1']
var_idx = ['A1BG', 'A2M']

X_view = adata[idx, var_idx]
X_view

View of AnnData object with n_obs × n_vars = 2 × 2 backed at '/nfs/turbo/umms-indikar/shared/projects/geneformer/data/rajapakse_lab_data.h5ad'
    obs: 'dataset', 'sample_id', 'timepoint', 'hour', 'n_counts', 'control'
    var: 'gene_id', 'token_id', 'Chromosome', 'Source', 'Feature', 'Start', 'End', 'Score', 'Strand', 'Frame', 'gene_version', 'gene_source', 'gene_biotype', 'transcript_id', 'transcript_version', 'transcript_name', 'transcript_source', 'transcript_biotype', 'tag', 'ccds_id', 'exon_number', 'exon_id', 'exon_version', 'protein_id', 'protein_version', 'transcript_support_level', 'ensembl_id'