# AnnData -> GeneFormerEmbeddings -> AnnData

In this file Joshua changes a file made by cooper to standardize the time point and replicate naming from files made by cooper

In [1]:
import torch

if torch.cuda.is_available():
    print("GPU is available")
else:
    print("GPU is not available")


GPU is not available


## main()

In [None]:
import sys
import os
import argparse
import pandas as pd
import numpy as np
import pickle
import scipy.sparse as sp
import scanpy as sc
import anndata as an
from datasets import Dataset, load_from_disk
import torch

sys.path.append('/home/jpic/geneformer_dev/scripts')
import geneformer_utils as gtu

def main(input_file=None, output_directory=None, verbose=True):

    input_path  = input_file
    base_name = os.path.splitext(os.path.basename(input_file))[0]
    output_path = os.path.join(output_directory, base_name + '.dataset')
    outpath = os.path.join(output_directory, base_name + '.h5ad')
    
    # Default values
    MODEL_PATH = "/nfs/turbo/umms-indikar/shared/projects/geneformer/geneformer-12L-30M/"
    DEFAULT_NAME_PATH = "/nfs/turbo/umms-indikar/shared/projects/geneformer/geneformer/gene_name_id_dict.pkl"
    DEFAULT_TOKEN_PATH = "/nfs/turbo/umms-indikar/shared/projects/geneformer/token_dictionary.pkl"
    DEFAULT_MEDIAN_PATH = "/nfs/turbo/umms-indikar/shared/projects/geneformer/geneformer/gene_median_dictionary.pkl"
    MODEL_INPUT_SIZE = 2048
    NUMBER_PROC = 16
    TARGET_SUM = 10000
    GENE_ID = 'ensembl_id'
    COUNTS_COLUMN = 'n_counts'
    LAYER = 'X'
    GENE_NAME_COLUMN = 'gene_name'

    # set values used for embedding
    token_path  = DEFAULT_TOKEN_PATH
    median_path = DEFAULT_MEDIAN_PATH
    n_proc      = NUMBER_PROC
    model_size  = MODEL_INPUT_SIZE
    target_sum  = TARGET_SUM
    gene_id     = GENE_ID
    aggregate_transcripts = False
    counts_column = COUNTS_COLUMN
    layer       = LAYER
    gene_names  = DEFAULT_NAME_PATH
    gene_name_column = GENE_NAME_COLUMN
    map_names   = False
    num_cells   = None # all cells, useful for testing 


    ###########################################
    #
    #   TOKENIZE COUNTS DATA FOR GENEFORMER
    #
    ###########################################
    print("Loading gene tokenization data...") if verbose else None
    gene_token_dict, gene_keys, genelist_dict = load_gene_tokenization(token_path)
    print(f"Loaded {len(gene_token_dict)} gene tokens") if verbose else None
    
    print("Loading gene median expression data...") if verbose else None
    gene_median_dict = load_gene_median_dict(median_path)
    print(f"Loaded {len(gene_median_dict)} gene median expression values") if verbose else None
    
    if map_names:
        print("Loading gene name mapping data...") if verbose else None
        gene_names = load_gene_names(gene_names)
        print(f"Loaded {len(gene_names)} gene name mappings") if verbose else None
    
    # Load and pre-process data
    print(f"Loading AnnData from {input_path}...") if verbose else None
    adata = sc.read_h5ad(input_path)
    print(f"Loaded AnnData with shape {adata.shape}") if verbose else None
    
    if map_names:
        print("Mapping gene names to Ensembl IDs...") if verbose else None
        adata = map_gene_names(adata, gene_id, gene_name_column, gene_names)
    
    if not layer == 'X':
        print(f"Using layer '{layer}' for expression data...") if verbose else None
        adata.X = adata.layers[layer]
        
    print("Checking for and/or calculating total counts per cell...") if verbose else None
    adata = check_counts_column(adata, counts_column)
    
    # Tokenize and rank genes
    print("Tokenizing and ranking genes...") if verbose else None
    tokenized_cells, cell_metadata = tokenize_anndata(
        adata, genelist_dict, gene_median_dict,
        target_sum=target_sum, gene_id=gene_id, counts_column=counts_column
    )
    print(f"Processed {len(tokenized_cells)} cells") if verbose else None
    
    # Create Hugging Face dataset
    print("Creating Hugging Face dataset...") if verbose else None
    dataset_dict = {
        "input_ids": tokenized_cells,
        **cell_metadata
    }
    output_dataset = Dataset.from_dict(dataset_dict)
    print(f"Dataset has {len(output_dataset)} examples") if verbose else None
    
    # Format cell features
    print("Formatting cell features...") if verbose else None
    dataset = output_dataset.map(format_cell_features, num_proc=n_proc)
    
    # Save dataset
    print(f"Saving processed dataset to {output_path}...") if verbose else None
    
    save_hf_dataset(dataset, output_path, overwrite=True)
    print("Processing completed successfully!") if verbose else None

    
    ###########################################
    #
    #   EMBED TOKENS WITH GENEFORMER TO ANNDATA
    #
    ###########################################
    dataset_path = output_path
    
    print(MODEL_PATH)
    
    print(f"Loading model from '{MODEL_PATH}'...") if verbose else None
    model = gtu.load_model(MODEL_PATH)
    print("Model loaded successfully!") if verbose else None
    
    print(f"Loading dataset from '{dataset_path}' (up to {num_cells} cells)...") if verbose else None
    try:
        df = gtu.load_data_as_dataframe(dataset_path, num_cells=num_cells)
        data = Dataset.from_pandas(df)
        df = df.drop(columns='input_ids')
    except FileNotFoundError:
        print(f"Error: Dataset file not found at '{dataset_path}'") if verbose else None
        sys.exit(1)
    except Exception as e:  # Catching other potential errors
        print(f"Error loading dataset: {e}") if verbose else None
        sys.exit(1)
    print("Dataset loaded successfully!") if verbose else None
    
    print("Extracting embeddings...") if verbose else None
    embs = gtu.extract_embedding_in_mem(model, data)
    adata = gtu.embedding_to_adata(embs)
    adata.obs = df.astype(str).reset_index().copy()
    print("Embeddings extracted successfully!") if verbose else None
    
    print(f"Writing results to '{outpath}'...") if verbose else None
    try:
        adata.write(outpath)
    except Exception as e:
        print(f"Error writing output file: {e}") if verbose else None
        sys.exit(1)
    print("Output file written successfully!") if verbose else None



In [6]:
arg_in  = '/nfs/turbo/umms-indikar/shared/projects/geneformer/data/rajapakse_lab_data_jpic.h5ad'
arg_out = '/nfs/turbo/umms-indikar/shared/projects/geneformer/data/test'
main(input_file=arg_in, output_directory=arg_out, verbose=True)

Loading gene tokenization data...
Loaded 25426 gene tokens
Loading gene median expression data...
Loaded 25424 gene median expression values
Loading AnnData from /nfs/turbo/umms-indikar/shared/projects/geneformer/data/rajapakse_lab_data_jpic.h5ad...
Loaded AnnData with shape (66, 19393)
Checking for and/or calculating total counts per cell...
Tokenizing and ranking genes...


NameError: name 'gene_token_dict' is not defined

In [7]:
help(tokenize_anndata)

Help on function tokenize_anndata in module __main__:

tokenize_anndata(adata, genelist_dict, gene_median_dict, chunk_size=100000, target_sum=10000, counts_column='n_counts', gene_id='ensembl_id')
    Tokenizes and ranks genes within an AnnData object, optimizing for memory efficiency.
    
    This function processes gene expression data in chunks, applies normalization, and ranks genes
    for each cell based on their expression levels. The resulting tokenized and ranked gene
    representations, along with cell metadata, are returned.
    
    Args:
        adata (AnnData): The AnnData object containing gene expression data.
        genelist_dict (dict): Dictionary mapping gene IDs to boolean values indicating relevance.
        gene_median_dict (dict): Dictionary mapping gene IDs to their median expression values.
        chunk_size (int, optional): Number of cells to process in each chunk (default: 1000).
        target_sum (int, optional): Target sum for count normalization (defa

In [None]:
arg_in  = '/nfs/turbo/umms-indikar/shared/projects/geneformer/data/rajapakse_lab_data_jpic.h5ad'
arg_out = '/nfs/turbo/umms-indikar/shared/projects/geneformer/data/rajapakse_lab_data_jpic.dataset'
arg_out2 = '/nfs/turbo/umms-indikar/shared/projects/geneformer/data/rajapakse_lab_data_jpic_GF_embeddings'
main()

## to_geneformer.py

In [6]:
import sys
import os
import argparse
import pandas as pd
import numpy as np
import pickle
import scipy.sparse as sp
import scanpy as sc
import anndata as an
from datasets import Dataset

In [2]:
# input arguments
arg_in  = '/nfs/turbo/umms-indikar/shared/projects/geneformer/data/rajapakse_lab_data_jpic.h5ad'
arg_out = '/nfs/turbo/umms-indikar/shared/projects/geneformer/data/rajapakse_lab_data_jpic.dataset'
arg_out2 = '/nfs/turbo/umms-indikar/shared/projects/geneformer/data/rajapakse_lab_data_jpic_GF_embeddings'
arg_verbos = True

In [3]:
# Default values
DEFAULT_NAME_PATH = "/nfs/turbo/umms-indikar/shared/projects/geneformer/geneformer/gene_name_id_dict.pkl"
DEFAULT_TOKEN_PATH = "/nfs/turbo/umms-indikar/shared/projects/geneformer/token_dictionary.pkl"
DEFAULT_MEDIAN_PATH = "/nfs/turbo/umms-indikar/shared/projects/geneformer/geneformer/gene_median_dictionary.pkl"
MODEL_INPUT_SIZE = 2048
NUMBER_PROC = 16
TARGET_SUM = 10000
GENE_ID = 'ensembl_id'
COUNTS_COLUMN = 'n_counts'
LAYER = 'X'
GENE_NAME_COLUMN = 'gene_name'

input_path = arg_in
output_path = arg_out
token_path  = DEFAULT_TOKEN_PATH
median_path = DEFAULT_MEDIAN_PATH
n_proc      = NUMBER_PROC
model_size  = MODEL_INPUT_SIZE
target_sum  = TARGET_SUM
gene_id     = GENE_ID
aggregate_transcripts = False  # Default to False since it's an optional flag
counts_column = COUNTS_COLUMN
layer       = LAYER
gene_names  = DEFAULT_NAME_PATH
gene_name_column = GENE_NAME_COLUMN
map_names   = False  # Default to False since it's an optional flag
verbose     = arg_verbos

In [7]:
print("Loading gene tokenization data...")
gene_token_dict, gene_keys, genelist_dict = load_gene_tokenization(token_path)
print(f"Loaded {len(gene_token_dict)} gene tokens")

print("Loading gene median expression data...")
gene_median_dict = load_gene_median_dict(median_path)
print(f"Loaded {len(gene_median_dict)} gene median expression values")

if map_names:
    print("Loading gene name mapping data...")
    gene_names = load_gene_names(gene_names)
    print(f"Loaded {len(gene_names)} gene name mappings")

# Load and pre-process data
print(f"Loading AnnData from {input_path}...")
adata = sc.read_h5ad(input_path)
print(f"Loaded AnnData with shape {adata.shape}")

if map_names:
    print("Mapping gene names to Ensembl IDs...")
    adata = map_gene_names(adata, gene_id, gene_name_column, gene_names)

if not layer == 'X':
    print(f"Using layer '{layer}' for expression data...")
    adata.X = adata.layers[layer]
    
print("Checking for and/or calculating total counts per cell...")
adata = check_counts_column(adata, counts_column)

# Tokenize and rank genes
print("Tokenizing and ranking genes...")
tokenized_cells, cell_metadata = tokenize_anndata(
    adata, genelist_dict, gene_median_dict,
    target_sum=target_sum, gene_id=gene_id, counts_column=counts_column
)
print(f"Processed {len(tokenized_cells)} cells")

# Create Hugging Face dataset
print("Creating Hugging Face dataset...")
dataset_dict = {
    "input_ids": tokenized_cells,
    **cell_metadata
}
output_dataset = Dataset.from_dict(dataset_dict)
print(f"Dataset has {len(output_dataset)} examples")

# Format cell features
print("Formatting cell features...")
dataset = output_dataset.map(format_cell_features, num_proc=n_proc)

# Save dataset
print(f"Saving processed dataset to {output_path}...")
save_hf_dataset(dataset, output_path, overwrite=True)
print("Processing completed successfully!")

Loading gene tokenization data...
Loaded 25426 gene tokens
Loading gene median expression data...
Loaded 25424 gene median expression values
Loading AnnData from /nfs/turbo/umms-indikar/shared/projects/geneformer/data/rajapakse_lab_data_jpic.h5ad...
Loaded AnnData with shape (66, 19393)
Checking for and/or calculating total counts per cell...
Tokenizing and ranking genes...
Processed 66 cells
Creating Hugging Face dataset...
Dataset has 66 examples
Formatting cell features...


Map (num_proc=16):   0%|          | 0/66 [00:00<?, ? examples/s]

Saving processed dataset to /nfs/turbo/umms-indikar/shared/projects/geneformer/data/rajapakse_lab_data_jpic.dataset...


Saving the dataset (0/1 shards):   0%|          | 0/66 [00:00<?, ? examples/s]

Processing completed successfully!


## Embedding Extractor

In [7]:
import sys
import os
import pandas as pd
import numpy as np
import torch
import anndata as an
import scanpy as sc
from datasets import Dataset, load_from_disk
sys.path.append('/home/jpic/geneformer_dev/scripts')
import geneformer_utils as gtu
torch.cuda.empty_cache()

In [8]:
sys.path.append('/home/jpic/geneformer_dev/scripts')
import geneformer_utils as gtu
torch.cuda.empty_cache()

In [9]:
dataset_path = arg_out
model_path = "/nfs/turbo/umms-indikar/shared/projects/geneformer/geneformer-12L-30M/"
outpath = arg_out2
num_cells = None # all cells, useful for testing 

print(model_path)

print(f"Loading model from '{model_path}'...")
model = gtu.load_model(model_path)
print("Model loaded successfully!")

print(f"Loading dataset from '{dataset_path}' (up to {num_cells} cells)...")
try:
    df = gtu.load_data_as_dataframe(dataset_path, num_cells=num_cells)
    data = Dataset.from_pandas(df)
    df = df.drop(columns='input_ids')
except FileNotFoundError:
    print(f"Error: Dataset file not found at '{dataset_path}'")
    sys.exit(1)
except Exception as e:  # Catching other potential errors
    print(f"Error loading dataset: {e}")
    sys.exit(1)
print("Dataset loaded successfully!")

print("Extracting embeddings...")
embs = gtu.extract_embedding_in_mem(model, data)
adata = gtu.embedding_to_adata(embs)
adata.obs = df.astype(str).reset_index().copy()
print("Embeddings extracted successfully!")

print(f"Writing results to '{outpath}'...")
try:
    adata.write(outpath)
except Exception as e:
    print(f"Error writing output file: {e}")
    sys.exit(1)
print("Output file written successfully!")


/nfs/turbo/umms-indikar/shared/projects/geneformer/geneformer-12L-30M/
Loading model from '/nfs/turbo/umms-indikar/shared/projects/geneformer/geneformer-12L-30M/'...


  return self.fget.__get__(instance, owner)()


Model loaded successfully!
Loading dataset from '/nfs/turbo/umms-indikar/shared/projects/geneformer/data/rajapakse_lab_data_jpic.dataset' (up to None cells)...
Dataset loaded successfully!
Extracting embeddings...


  0%|          | 0/7 [00:00<?, ?it/s]

Embeddings extracted successfully!
Writing results to '/nfs/turbo/umms-indikar/shared/projects/geneformer/data/rajapakse_lab_data_jpic_GF_embeddings'...
Output file written successfully!


## Helper Codes

In [5]:
# from to_geneformer.py
def check_counts_column(adata, counts_column):
    """Checks for and calculates a total counts column in AnnData.

    This function examines the AnnData object's observation (`obs`) columns for the specified 
    `counts_column`. If it doesn't exist, the function calculates the sum of each row (cell) 
    across all features in the data matrix (`X`) and stores it as a new column in `obs`.

    Args:
        adata: An AnnData object containing the data to be analyzed.
        counts_column: A string representing the desired name for the total counts column.

    Returns:
        adata: The modified AnnData object, now with the `counts_column` present (either 
               pre-existing or newly calculated).
    """
    obs_columns = adata.obs.columns
    
    if counts_column in obs_columns:
        return adata
    else:
        adata.obs[counts_column] = adata.X.sum(axis=1)
        return adata
    
    
def map_gene_names(adata, gene_id, gene_name_column, gene_names):
    """A function mapping gene names to gene ids """
    var_columns = adata.var.columns
    
    if gene_id in var_columns:
        return adata
    else:
        adata.var[gene_id] = adata.var[gene_name_column].map(gene_names)
        return adata
    
    
def load_gene_names(gene_names_file):
    """
    Loads a gene median dictionary from a pickle file.

    Args:
        gene_names_file (str): Path to the pickle file containing the gene names dictionary.

    Returns:
        dict: A dictionary mapping gene names to IDs
    """

    with open(gene_names_file, "rb") as f:
        gene_names_dict = pickle.load(f)

    return gene_names_dict


def load_gene_median_dict(gene_median_file):
    """
    Loads a gene median dictionary from a pickle file.

    Args:
        gene_median_file (str): Path to the pickle file containing the gene median dictionary.

    Returns:
        dict: A dictionary mapping gene IDs to their median expression values.
    """

    with open(gene_median_file, "rb") as f:
        gene_median_dict = pickle.load(f)

    return gene_median_dict


def load_gene_tokenization(token_dictionary_file):
    """
    Loads gene tokenization data from a pickle file.

    Args:
        token_dictionary_file (str): Path to the pickle file containing the gene-token dictionary.

    Returns:
        dict: Gene-token dictionary (Ensembl ID: token).
        list: List of all gene keys (Ensembl IDs).
        dict: Dictionary mapping gene keys to True (used for selecting genes later).
    """

    with open(token_dictionary_file, "rb") as f:
        gene_token_dict = pickle.load(f)

    gene_keys = list(gene_token_dict.keys())

    # Optimization: Pre-allocate the list for slight performance improvement
    genelist_dict = dict.fromkeys(gene_keys, True)

    return gene_token_dict, gene_keys, genelist_dict


def rank_genes(gene_vector, gene_tokens):
    """Ranks genes based on expression values in descending order.

    Args:
        gene_vector (numpy.ndarray): Array of gene expression values.
        gene_tokens (numpy.ndarray): Array of corresponding gene tokens.

    Returns:
        numpy.ndarray: Array of gene tokens sorted by descending expression value.
    """
    return gene_tokens[np.argsort(-gene_vector)]


def normalize_counts(adata_chunk,  counts_column='n_counts', target_sum=10000):
    """Normalizes gene expression counts within a chunk of AnnData.

    Args:
        adata_chunk (AnnData): A chunk of the AnnData object containing gene expression data.
        counts_column (str): Name of the column in `adata_chunk.obs` containing the total counts per cell.
        target_sum (float): The desired total count per cell after normalization.
        norm_factor_vector (numpy.ndarray): An array of normalization factors for each gene.

    Returns:
        scipy.sparse.csr_matrix: A sparse matrix containing the normalized gene expression counts.

    This function performs the following steps:
        1. Extracts the total counts per cell from the specified column (`counts_column`).
        2. Normalizes the gene expression matrix (`adata_chunk.X`) by dividing by the total counts 
           and multiplying by the `target_sum`.
        3. Further adjusts the normalized values by dividing by the gene-specific normalization 
           factors (`norm_factor_vector`).
        4. Returns the normalized expression matrix as a sparse CSR matrix for efficient storage 
           and computation.
    """
    
    n_counts = adata_chunk.obs[counts_column].values[:, None]  # Cell counts as column vector
    X_norm = adata_chunk.X / n_counts * target_sum / norm_factor_vector
    return sp.csr_matrix(X_norm)  # Efficient sparse representation


def tokenize_anndata(adata, genelist_dict, gene_median_dict, 
                     chunk_size=100000, target_sum=10000, 
                     counts_column='n_counts', gene_id="ensembl_id"):
    """
    Tokenizes and ranks genes within an AnnData object, optimizing for memory efficiency.

    This function processes gene expression data in chunks, applies normalization, and ranks genes
    for each cell based on their expression levels. The resulting tokenized and ranked gene
    representations, along with cell metadata, are returned.

    Args:
        adata (AnnData): The AnnData object containing gene expression data.
        genelist_dict (dict): Dictionary mapping gene IDs to boolean values indicating relevance.
        gene_median_dict (dict): Dictionary mapping gene IDs to their median expression values.
        chunk_size (int, optional): Number of cells to process in each chunk (default: 1000).
        target_sum (int, optional): Target sum for count normalization (default: 10000).
        counts_column (str, optional): The column in `adata.obs` containing cell counts (default: 'n_counts').
        gene_id (str, optional): The column in `adata.var` containing gene IDs (default: 'ensembl_id').

    Returns:
        tuple: 
            - list: List of tokenized and ranked gene lists for each cell.
            - dict: Dictionary containing cell metadata (keys are metadata column names).
    """
    # Filter relevant miRNAs
    coding_miRNA_mask = np.array([genelist_dict.get(i, False) for i in adata.var[gene_id]])
    coding_miRNA_loc = np.where(coding_miRNA_mask)[0]

    # Extract miRNA information
    coding_miRNA_ids = adata.var[gene_id].iloc[coding_miRNA_loc]
    norm_factor_vector = np.array([gene_median_dict[i] for i in coding_miRNA_ids])
    coding_miRNA_tokens = np.array([gene_token_dict[i] for i in coding_miRNA_ids])

    tokenized_cells = []
    file_cell_metadata = {k: [] for k in adata.obs.columns}  # Initialize metadata dict

    # Process in chunks for memory efficiency
    for chunk_start in range(0, adata.shape[0], chunk_size):
        chunk_end = chunk_start + chunk_size
        adata_chunk = adata[chunk_start:chunk_end, coding_miRNA_loc]
        
        # Normalize counts (could be replaced with the untested function above)
        n_counts = adata_chunk.obs[counts_column].values[:, None]
        X_norm = adata_chunk.X / n_counts * target_sum / norm_factor_vector
        X_norm = sp.csr_matrix(X_norm)  

        # Tokenize and rank genes for each cell in chunk
        for i in range(X_norm.shape[0]):
            ranks = rank_genes(X_norm[i].data, coding_miRNA_tokens[X_norm[i].indices])
            ranks = list(ranks[~np.isnan(ranks)].astype(int))

            tokenized_cells.append(ranks)

        # Update metadata
        for k in adata.obs.columns:
            file_cell_metadata[k].extend(adata_chunk.obs[k].astype(str).tolist())

    return tokenized_cells, file_cell_metadata


def format_cell_features(example):
    """
    Truncates gene tokens (`input_ids`) to `model_size` and adds a `length` feature.

    Args:
        example (dict): Cell data with `input_ids` (list of gene tokens).

    Returns:
        dict: Modified cell data with truncated `input_ids` and added `length`.
    """
    example["input_ids"] = example["input_ids"][0:model_size] 
    example["length"] = len(example["input_ids"]) 
    return example


def save_hf_dataset(dataset: Dataset, output_path: str, overwrite=True):
    """
    Saves a Hugging Face Dataset to disk at a specified file path.

    This function serializes a Hugging Face `Dataset` object and saves it to disk in the Arrow format.

    Args:
        dataset (Dataset): The Hugging Face `Dataset` object to be saved.
        output_path (str): The full file path (including the filename) where the dataset will be saved. 
        overwrite (bool, optional): If `True`, an existing dataset at `output_path` will be overwritten. 
                                   If `False` and the file exists, a `FileExistsError` is raised (default: True).

    Raises:
        TypeError: If `dataset` is not a Hugging Face `Dataset` instance.
        FileExistsError: If `output_path` points to an existing file and `overwrite` is False.
    """

    if not isinstance(dataset, Dataset):
        raise TypeError("The provided dataset is not a Hugging Face Dataset.")

    if os.path.exists(output_path) and not overwrite:
        raise FileExistsError(
            f"Dataset '{output_path}' already exists. Set `overwrite=True` to overwrite."
        )
    dataset.save_to_disk(output_path)

## Check output

In [10]:
adata = sc.read_h5ad(arg_out2)
adata



AnnData object with n_obs × n_vars = 66 × 512
    obs: 'index', 'dataset', 'sample_id', 'timepoint', 'hour', 'n_counts', 'control', 'order', 'replicate', 'batch', 'length'

In [13]:
adata.var

D0
D1
D2
D3
D4
...
D507
D508
D509
D510
D511


# JPIC Data Engineering

In this file Joshua changes a file made by cooper to standardize the time point and replicate naming from files made by cooper

In [1]:
import pandas as pd
import scanpy as sc
import numpy as np
import anndata as an
import re

In [2]:
coopers_data_path = "/nfs/turbo/umms-indikar/shared/projects/geneformer/data/rajapakse_lab_data.h5ad"
ad = an.read(coopers_data_path)



In [3]:
ad.var

Unnamed: 0_level_0,gene_id,token_id,Chromosome,Source,Feature,Start,End,Score,Strand,Frame,...,transcript_biotype,tag,ccds_id,exon_number,exon_id,exon_version,protein_id,protein_version,transcript_support_level,ensembl_id
gene_name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
A1BG,ENSG00000121410,5150.0,19,ensembl_havana,gene,58345177.0,58353492.0,.,-,.,...,,,,,,,,,,ENSG00000121410
A1CF,ENSG00000148584,9064.0,10,ensembl_havana,gene,50799408.0,50885675.0,.,-,.,...,,,,,,,,,,ENSG00000148584
A2M,ENSG00000175899,13826.0,12,ensembl_havana,gene,9067663.0,9116229.0,.,-,.,...,,,,,,,,,,ENSG00000175899
A2ML1,ENSG00000166535,11812.0,12,ensembl_havana,gene,8822620.0,8887001.0,.,+,.,...,,,,,,,,,,ENSG00000166535
A3GALT2,ENSG00000184389,15327.0,1,ensembl_havana,gene,33306765.0,33321098.0,.,-,.,...,,,,,,,,,,ENSG00000184389
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
ZYG11A,ENSG00000203995,17515.0,1,ensembl_havana,gene,52842510.0,52894998.0,.,+,.,...,,,,,,,,,,ENSG00000203995
ZYG11B,ENSG00000162378,10655.0,1,ensembl_havana,gene,52726452.0,52827336.0,.,+,.,...,,,,,,,,,,ENSG00000162378
ZYX,ENSG00000159840,10336.0,7,ensembl_havana,gene,143381294.0,143391111.0,.,+,.,...,,,,,,,,,,ENSG00000159840
ZZEF1,ENSG00000074755,1277.0,17,ensembl_havana,gene,4004444.0,4143030.0,.,-,.,...,,,,,,,,,,ENSG00000074755


In [4]:
def extract_timepoint_replicate_2015(data_id):
    match = re.match(r"S(\d+)([ab])", data_id)
    if match:
        time = int(match.group(1))
        replicate = 1 if match.group(2) == 'a' else 2
        return time, replicate
    else:
        return None, None

def extract_timepoint_replicate_2018(data_id):
    match = re.match(r"(\d+)_T(\d+)R(\d+)", data_id)
    if match:
        return int(match.group(2)), int(match.group(3))
    else:
        return None, None


In [5]:
adDs5 = ad[ad.obs['dataset'] == 'chen_2015']
adDs8 = ad[ad.obs['dataset'] == 'liu_2018']

In [6]:
timepoint_replicate = adDs8.obs.index.to_series().apply(extract_timepoint_replicate_2018)
timepoint_replicate_df = timepoint_replicate.apply(pd.Series)
timepoint_replicate_df.columns = ['order', 'replicate']

# Add the new columns to the AnnData object
adDs8.obs = adDs8.obs.join(timepoint_replicate_df)
adDs8

AnnData object with n_obs × n_vars = 48 × 19393
    obs: 'dataset', 'sample_id', 'timepoint', 'hour', 'n_counts', 'control', 'order', 'replicate'
    var: 'gene_id', 'token_id', 'Chromosome', 'Source', 'Feature', 'Start', 'End', 'Score', 'Strand', 'Frame', 'gene_version', 'gene_source', 'gene_biotype', 'transcript_id', 'transcript_version', 'transcript_name', 'transcript_source', 'transcript_biotype', 'tag', 'ccds_id', 'exon_number', 'exon_id', 'exon_version', 'protein_id', 'protein_version', 'transcript_support_level', 'ensembl_id'

In [7]:
timepoint_replicate = adDs5.obs.index.to_series().apply(extract_timepoint_replicate_2015)
timepoint_replicate_df = timepoint_replicate.apply(pd.Series)
timepoint_replicate_df.columns = ['order', 'replicate']

# Add the new columns to the AnnData object
adDs5.obs = adDs5.obs.join(timepoint_replicate_df)
adDs5

AnnData object with n_obs × n_vars = 18 × 19393
    obs: 'dataset', 'sample_id', 'timepoint', 'hour', 'n_counts', 'control', 'order', 'replicate'
    var: 'gene_id', 'token_id', 'Chromosome', 'Source', 'Feature', 'Start', 'End', 'Score', 'Strand', 'Frame', 'gene_version', 'gene_source', 'gene_biotype', 'transcript_id', 'transcript_version', 'transcript_name', 'transcript_source', 'transcript_biotype', 'tag', 'ccds_id', 'exon_number', 'exon_id', 'exon_version', 'protein_id', 'protein_version', 'transcript_support_level', 'ensembl_id'

In [8]:
adDs5.obs

Unnamed: 0_level_0,dataset,sample_id,timepoint,hour,n_counts,control,order,replicate
data_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
S1a,chen_2015,S1a,0.0,0.0,7901832,True,1,1
S1b,chen_2015,S1b,0.0,0.0,8113329,True,1,2
S2a,chen_2015,S2a,0.0,0.0,9831046,False,2,1
S2b,chen_2015,S2b,0.0,0.0,10123271,False,2,2
S3a,chen_2015,S3a,1.0,8.0,10490839,False,3,1
S3b,chen_2015,S3b,1.0,8.0,10713844,False,3,2
S4a,chen_2015,S4a,2.0,16.0,9183324,False,4,1
S4b,chen_2015,S4b,2.0,16.0,9401913,False,4,2
S5a,chen_2015,S5a,3.0,24.0,9655719,False,5,1
S5b,chen_2015,S5b,3.0,24.0,9863515,False,5,2


In [9]:
adDs8.obs

Unnamed: 0_level_0,dataset,sample_id,timepoint,hour,n_counts,control,order,replicate
data_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
63246_T0R1,liu_2018,63246,1.0,-48.0,11940999,True,0,1
63252_T1R1,liu_2018,63252,2.0,0.0,18063509,False,1,1
63249_T2R1,liu_2018,63249,3.0,8.0,11031474,False,2,1
63261_T3R1,liu_2018,63261,1.0,16.0,16761043,False,3,1
63258_T4R1,liu_2018,63258,2.0,24.0,8244802,False,4,1
63255_T5R1,liu_2018,63255,3.0,32.0,10615057,False,5,1
63270_T6R1,liu_2018,63270,1.0,40.0,16486670,False,6,1
63267_T7R1,liu_2018,63267,2.0,48.0,10127547,False,7,1
63264_T8R1,liu_2018,63264,3.0,56.0,11231585,False,8,1
63279_T9R1,liu_2018,63279,1.0,64.0,10781978,False,9,1


In [10]:
adDs_combined = adDs5.concatenate(adDs8, join='outer', index_unique=None)

  adDs_combined = adDs5.concatenate(adDs8, join='outer', index_unique=None)


In [12]:
adDs_combined.obs

Unnamed: 0_level_0,dataset,sample_id,timepoint,hour,n_counts,control,order,replicate,batch
data_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
S1a,chen_2015,S1a,0.0,0.0,7901832,True,1,1,0
S1b,chen_2015,S1b,0.0,0.0,8113329,True,1,2,0
S2a,chen_2015,S2a,0.0,0.0,9831046,False,2,1,0
S2b,chen_2015,S2b,0.0,0.0,10123271,False,2,2,0
S3a,chen_2015,S3a,1.0,8.0,10490839,False,3,1,0
...,...,...,...,...,...,...,...,...,...
63275_T11R3,liu_2018,63275,3.0,80.0,13515971,False,11,3,1
63290_T12R3,liu_2018,63290,1.0,88.0,9522866,False,12,3,1
63287_T13R3,liu_2018,63287,2.0,96.0,12370157,False,13,3,1
63284_T14R3,liu_2018,63284,3.0,104.0,10970735,False,14,3,1


In [14]:
adDs_combined

AnnData object with n_obs × n_vars = 66 × 19393
    obs: 'dataset', 'sample_id', 'timepoint', 'hour', 'n_counts', 'control', 'order', 'replicate', 'batch'
    var: 'gene_id', 'token_id', 'Chromosome', 'Source', 'Feature', 'Start', 'End', 'Score', 'Strand', 'Frame', 'gene_version', 'gene_source', 'gene_biotype', 'transcript_id', 'transcript_version', 'transcript_name', 'transcript_source', 'transcript_biotype', 'tag', 'ccds_id', 'exon_number', 'exon_id', 'exon_version', 'protein_id', 'protein_version', 'transcript_support_level', 'ensembl_id'

In [13]:
# check that the size of the remerged ann data object is compatible with the size of Cooper's
ad

AnnData object with n_obs × n_vars = 66 × 19393
    obs: 'dataset', 'sample_id', 'timepoint', 'hour', 'n_counts', 'control'
    var: 'gene_id', 'token_id', 'Chromosome', 'Source', 'Feature', 'Start', 'End', 'Score', 'Strand', 'Frame', 'gene_version', 'gene_source', 'gene_biotype', 'transcript_id', 'transcript_version', 'transcript_name', 'transcript_source', 'transcript_biotype', 'tag', 'ccds_id', 'exon_number', 'exon_id', 'exon_version', 'protein_id', 'protein_version', 'transcript_support_level', 'ensembl_id'

In [15]:
out_path = "/nfs/turbo/umms-indikar/shared/projects/geneformer/data/rajapakse_lab_data_jpic.h5ad"
adDs_combined.write(out_path)

# Tokenize

In [1]:
import pandas as pd
import anndata as ad
import numpy as np
import h5py
import os
import pickle
import scipy.sparse as sp
from geneformer import TranscriptomeTokenizer

In [2]:
input_path = "/nfs/turbo/umms-indikar/shared/projects/geneformer/data/rajapakse_lab_data_jpic.h5ad"

In [3]:
input_dir = os.path.dirname(input_path)
output_path = "/scratch/indikar_root/indikar1/cstansbu/geneformer/"
prefix = "test"

def get_attributes(h5ad_path):
    """
    Extracts attribute names from the `.obs` field of an h5ad AnnData file,
    returning them as a dictionary with keys and values being the attribute names.

    Args:
        h5ad_path (str): The path to the h5ad file.

    Returns:
        dict: A dictionary where keys and values are the unique attribute names 
              found in the `.obs` field of the h5ad file.
    """
    with h5py.File(h5ad_path, mode="r") as store:
        attribute_names = list(store["obs"].keys())

    attribute_name_dict = {name: name for name in attribute_names}  # Create dictionary
    return attribute_name_dict
    

custom_attr_name_dict = get_attributes(input_path)
custom_attr_name_dict

{'batch': 'batch',
 'control': 'control',
 'data_id': 'data_id',
 'dataset': 'dataset',
 'hour': 'hour',
 'n_counts': 'n_counts',
 'order': 'order',
 'replicate': 'replicate',
 'sample_id': 'sample_id',
 'timepoint': 'timepoint'}

In [4]:
def load_gene_median_dict(gene_median_file):
    """
    Loads a gene median dictionary from a pickle file.

    Args:
        gene_median_file (str): Path to the pickle file containing the gene median dictionary.

    Returns:
        dict: A dictionary mapping gene IDs to their median expression values.
    """

    with open(gene_median_file, "rb") as f:
        gene_median_dict = pickle.load(f)

    return gene_median_dict


def load_gene_tokenization(token_dictionary_file):
    """
    Loads gene tokenization data from a pickle file.

    Args:
        token_dictionary_file (str): Path to the pickle file containing the gene-token dictionary.

    Returns:
        dict: Gene-token dictionary (Ensembl ID: token).
        list: List of all gene keys (Ensembl IDs).
        dict: Dictionary mapping gene keys to True (used for selecting genes later).
    """

    with open(token_dictionary_file, "rb") as f:
        gene_token_dict = pickle.load(f)

    gene_keys = list(gene_token_dict.keys())

    # Optimization: Pre-allocate the list for slight performance improvement
    genelist_dict = dict.fromkeys(gene_keys, True)

    return gene_token_dict, gene_keys, genelist_dict


def rank_genes(gene_vector, gene_tokens):
    """Ranks genes based on expression values in descending order.

    Args:
        gene_vector (numpy.ndarray): Array of gene expression values.
        gene_tokens (numpy.ndarray): Array of corresponding gene tokens.

    Returns:
        numpy.ndarray: Array of gene tokens sorted by descending expression value.
    """
    return gene_tokens[np.argsort(-gene_vector)]


def normalize_counts(adata_chunk,  counts_column='n_counts', target_sum=10000):
    """Normalizes gene expression counts within a chunk of AnnData.

    Args:
        adata_chunk (AnnData): A chunk of the AnnData object containing gene expression data.
        counts_column (str): Name of the column in `adata_chunk.obs` containing the total counts per cell.
        target_sum (float): The desired total count per cell after normalization.
        norm_factor_vector (numpy.ndarray): An array of normalization factors for each gene.

    Returns:
        scipy.sparse.csr_matrix: A sparse matrix containing the normalized gene expression counts.

    This function performs the following steps:
        1. Extracts the total counts per cell from the specified column (`counts_column`).
        2. Normalizes the gene expression matrix (`adata_chunk.X`) by dividing by the total counts 
           and multiplying by the `target_sum`.
        3. Further adjusts the normalized values by dividing by the gene-specific normalization 
           factors (`norm_factor_vector`).
        4. Returns the normalized expression matrix as a sparse CSR matrix for efficient storage 
           and computation.
    """
    
    n_counts = adata_chunk.obs[counts_column].values[:, None]  # Cell counts as column vector
    X_norm = adata_chunk.X / n_counts * target_sum / norm_factor_vector
    return sp.csr_matrix(X_norm)  # Efficient sparse representation


def tokenize_anndata(adata, genelist_dict, gene_median_dict, 
                     chunk_size=100000, target_sum=10000):
    """
    Tokenizes and ranks genes within an AnnData object, optimizing for memory efficiency.

    This function processes gene expression data in chunks, applies normalization, and ranks genes
    for each cell based on their expression levels. The resulting tokenized and ranked gene
    representations, along with cell metadata, are returned.

    Args:
        adata (AnnData): The AnnData object containing gene expression data.
        genelist_dict (dict): Dictionary mapping gene IDs to boolean values indicating relevance.
        gene_median_dict (dict): Dictionary mapping gene IDs to their median expression values.
        chunk_size (int, optional): Number of cells to process in each chunk (default: 1000).
        target_sum (int, optional): Target sum for count normalization (default: 10000).

    Returns:
        tuple: 
            - list: List of tokenized and ranked gene lists for each cell.
            - dict: Dictionary containing cell metadata (keys are metadata column names).
    """
    # Filter relevant miRNAs
    coding_miRNA_mask = np.array([genelist_dict.get(i, False) for i in adata.var['ensembl_id']])
    coding_miRNA_loc = np.where(coding_miRNA_mask)[0]

    # Extract miRNA information
    coding_miRNA_ids = adata.var['ensembl_id'][coding_miRNA_loc]
    norm_factor_vector = np.array([gene_median_dict[i] for i in coding_miRNA_ids])
    coding_miRNA_tokens = np.array([gene_token_dict[i] for i in coding_miRNA_ids])

    tokenized_cells = []
    file_cell_metadata = {k: [] for k in adata.obs.columns}  # Initialize metadata dict

    # Process in chunks for memory efficiency
    for chunk_start in range(0, adata.shape[0], chunk_size):
        chunk_end = chunk_start + chunk_size
        adata_chunk = adata[chunk_start:chunk_end, coding_miRNA_loc]
        
        # Normalize counts (could be replaced with the untested function above)
        n_counts = adata_chunk.obs['n_counts'].values[:, None]
        X_norm = adata_chunk.X / n_counts * target_sum / norm_factor_vector
        X_norm = sp.csr_matrix(X_norm)  

        # Tokenize and rank genes for each cell in chunk
        for i in range(X_norm.shape[0]):
            ranks = rank_genes(X_norm[i].data, coding_miRNA_tokens[X_norm[i].indices])
            ranks = list(ranks[~np.isnan(ranks)].astype(int))

            tokenized_cells.append(ranks)

        # Update metadata
        for k in adata.obs.columns:
            file_cell_metadata[k].extend(adata_chunk.obs[k].tolist())

    return tokenized_cells, file_cell_metadata


In [5]:
DEFAULT_TOKEN_PATH = "/nfs/turbo/umms-indikar/shared/projects/geneformer/token_dictionary.pkl"
DEFAULT_MEDIAN_PATH = "/nfs/turbo/umms-indikar/shared/projects/geneformer/geneformer/gene_median_dictionary.pkl"

gene_token_dict, gene_keys, genelist_dict = load_gene_tokenization(DEFAULT_TOKEN_PATH)
gene_median_dict = load_gene_median_dict(DEFAULT_MEDIAN_PATH)

In [6]:
print(input_path)
adata = ad.read(input_path, backed="r")
adata

/nfs/turbo/umms-indikar/shared/projects/geneformer/data/rajapakse_lab_data_jpic.h5ad




AnnData object with n_obs × n_vars = 66 × 19393 backed at '/nfs/turbo/umms-indikar/shared/projects/geneformer/data/rajapakse_lab_data_jpic.h5ad'
    obs: 'dataset', 'sample_id', 'timepoint', 'hour', 'n_counts', 'control', 'order', 'replicate', 'batch'
    var: 'gene_id', 'token_id', 'Chromosome', 'Source', 'Feature', 'Start', 'End', 'Score', 'Strand', 'Frame', 'gene_version', 'gene_source', 'gene_biotype', 'transcript_id', 'transcript_version', 'transcript_name', 'transcript_source', 'transcript_biotype', 'tag', 'ccds_id', 'exon_number', 'exon_id', 'exon_version', 'protein_id', 'protein_version', 'transcript_support_level', 'ensembl_id'

In [7]:
tokenized_cells, cell_metadata = tokenize_anndata(adata, 
                                                  genelist_dict, 
                                                  gene_median_dict)

In [13]:
np.array(tokenized_cells).shape

ValueError: setting an array element with a sequence. The requested array has an inhomogeneous shape after 1 dimensions. The detected shape was (66,) + inhomogeneous part.

In [12]:
pd.DataFrame(cell_metadata)

Unnamed: 0,dataset,sample_id,timepoint,hour,n_counts,control,order,replicate,batch
0,chen_2015,S1a,0.0,0.0,7901832,True,1,1,0
1,chen_2015,S1b,0.0,0.0,8113329,True,1,2,0
2,chen_2015,S2a,0.0,0.0,9831046,False,2,1,0
3,chen_2015,S2b,0.0,0.0,10123271,False,2,2,0
4,chen_2015,S3a,1.0,8.0,10490839,False,3,1,0
...,...,...,...,...,...,...,...,...,...
61,liu_2018,63275,3.0,80.0,13515971,False,11,3,1
62,liu_2018,63290,1.0,88.0,9522866,False,12,3,1
63,liu_2018,63287,2.0,96.0,12370157,False,13,3,1
64,liu_2018,63284,3.0,104.0,10970735,False,14,3,1
