In [1]:
import sys
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import pickle
import scipy.sparse as sp

# env_path = "/home/cstansbu/miniconda3/envs/scanpy/lib/python3.12/site-packages/"
# sys.path.append(env_path)

# for datasets
env_path = "/home/cstansbu/miniconda3/envs/geneformer/lib/python3.10/site-packages/"
sys.path.insert(0, env_path)

import scanpy as sc
import anndata as an
from datasets import Dataset

In [2]:
fpath = "/scratch/indikar_root/indikar1/cstansbu/HSC/scanpy/clustered.anndata.h5ad"

adata = sc.read_h5ad(fpath)
adata

AnnData object with n_obs × n_vars = 8562 × 14972
    obs: 'n_genes', 'doublet_score', 'predicted_doublet', 'n_genes_by_counts', 'total_counts', 'total_counts_mt', 'pct_counts_mt'
    var: 'gene_name', 'Chromosome', 'Start', 'End', 'Strand', 'n_cells', 'mt', 'n_cells_by_counts', 'mean_counts', 'pct_dropout_by_counts', 'total_counts', 'highly_variable', 'means', 'dispersions', 'dispersions_norm', 'ensembl_id'
    uns: 'clusters', 'fb_vs_hsc_up', 'go_annotations', 'hsc_v_fib_up', 'hvg', 'log1p', 'neighbors', 'panglaodb', 'pca', 'scenic_transcription_factors', 'scrublet', 'tabula_sapiens_deg', 'umap', 'v5_tags'
    obsm: 'X_pca', 'X_umap'
    varm: 'PCs'
    layers: 'filtered_counts', 'raw_counts'
    obsp: 'connectivities', 'distances'

In [3]:
def load_gene_map(gene_table_path, gene_type="protein_coding", keys='gene_name', values='gene_id'):
    """
    Build a ``{key: value}`` gene lookup from a CSV gene table.

    The table is read, exact duplicate rows are dropped, rows are restricted to
    the requested ``gene_biotype``, and rows with missing identifiers are
    removed before the mapping is built. If a key occurs on several rows, the
    last row wins (standard ``dict`` construction semantics).

    Args:
        gene_table_path (str): Path to the CSV file containing the gene table.
            It must provide the columns 'gene_id', 'gene_name' and
            'gene_biotype', plus `keys` / `values` if they name other columns.
        gene_type (str, optional): Value of 'gene_biotype' to keep
            (default: 'protein_coding').
        keys (str, optional): Column used for the dictionary keys
            (default: 'gene_name').
        values (str, optional): Column used for the dictionary values
            (default: 'gene_id').

    Returns:
        dict: Mapping from `keys` entries to `values` entries for genes of the
        requested biotype.
    """
    # Always read the three core columns (the biotype is needed for the filter)
    # and additionally whatever columns the caller asked for. Previously a
    # `keys`/`values` column outside the core three raised KeyError because it
    # was never loaded. dict.fromkeys de-duplicates while preserving order.
    usecols = list(dict.fromkeys(['gene_id', 'gene_name', 'gene_biotype', keys, values]))
    df = pd.read_csv(gene_table_path, usecols=usecols)

    # Rows lacking a name/ID are dropped as before; the requested columns are
    # included too so NaN can never end up as a key or value of the mapping.
    required = list(dict.fromkeys(['gene_name', 'gene_id', keys, values]))

    # Filter and clean data in a single chain
    df = (
        df.drop_duplicates()
        .query("gene_biotype == @gene_type")
        .dropna(subset=required)
    )

    return dict(zip(df[keys], df[values]))
    
# Build the lookup with the function defaults (gene_name -> gene_id, restricted to
# gene_biotype == 'protein_coding') and show how many entries it contains.
gene_table_path = "/scratch/indikar_root/indikar1/cstansbu/HSC/references/geneTable.csv"
gene_map = load_gene_map(gene_table_path)    
len(gene_map)

19393

In [4]:
list(gene_map.keys())[:10]

['ATAD3B',
 'PRDM16',
 'SKI',
 'PEX14',
 'PLCH2',
 'SPSB1',
 'HES3',
 'PLEKHM2',
 'CA6',
 'NMNAT1']

In [52]:
def skeletonize(adata, gene_map, gene_identifier, gene_column_type, gene_index, counts_layer):
    """
    Build a minimal AnnData holding one counts layer for the genes found in `gene_map`.

    The identifiers in ``adata.var[gene_identifier]`` are looked up in `gene_map`;
    genes without a mapping are dropped. The returned object carries the selected
    layer as ``.X``, a copy of the original ``.obs`` and a two-column ``.var``
    ('gene_name' / 'ensemble_id'), one of which becomes the index.

    Note that this function spells the ID column 'ensemble_id', whereas
    `tokenize_anndata` defaults to ``gene_id="ensembl_id"``; pass the matching
    name explicitly when chaining the two.

    Args:
        adata (anndata.AnnData): Input AnnData; must contain `counts_layer` in
            ``.layers`` and `gene_identifier` in ``.var``.
        gene_map (dict): Lookup applied to the identifier column. Keys must be of
            the type named by `gene_column_type`; values are the other identifier.
        gene_identifier (str): Column of ``adata.var`` holding the identifiers to map.
        gene_column_type (str): Kind of identifier stored in `gene_identifier`:
            'gene_name' or 'ensemble_id'.
        gene_index (str): Which of 'gene_name' / 'ensemble_id' becomes the index
            of the output ``.var``.
        counts_layer (str): Name of the layer in ``adata.layers`` to use as ``.X``
            (required; there is no default).

    Returns:
        anndata.AnnData: A new AnnData object with:
            - `.X`: `counts_layer` restricted to the mapped genes.
            - `.obs`: Copy of ``adata.obs`` (same obs names).
            - `.var`: Indexed by `gene_index`, containing the other identifier.
            - `.obs['n_counts']`: Kept if already present; otherwise the per-cell
              sum over the *retained* genes only.

    Raises:
        ValueError: If an invalid `gene_column_type` is provided.
    """
    # The two supported cases only differ in which column is the source and
    # which one receives the mapped values.
    if gene_column_type == 'gene_name':
        mapped_column = 'ensemble_id'
    elif gene_column_type == 'ensemble_id':
        mapped_column = 'gene_name'
    else:
        raise ValueError('gene_column_type must be one of: `ensemble_id` or `gene_name`')

    var = pd.DataFrame(adata.var[gene_identifier].copy())
    var.columns = [gene_column_type]
    var[mapped_column] = var[gene_column_type].map(gene_map)
    var = var[var[mapped_column].notna()]
    gene_idx = var.index  # existing var index of the genes that could be mapped

    X = adata[:, gene_idx].layers[counts_layer]
    ndata = an.AnnData(X)
    ndata.obs = adata.obs.copy()
    ndata.obs_names = adata.obs_names

    if 'n_counts' not in ndata.obs.columns:
        # A scipy sparse matrix returns an (n, 1) np.matrix from sum(axis=1);
        # flatten so the obs column is a plain 1-D vector for dense and sparse X.
        ndata.obs['n_counts'] = np.asarray(X.sum(axis=1)).ravel()

    ndata.var_names = var[gene_index].values
    ndata.var = var.set_index(gene_index)

    return ndata
    
    
# Smoke test: map the 'gene_name' column through gene_map, keep gene names as the
# var index, and use the 'raw_counts' layer as .X of the skeleton object.
test = skeletonize(adata, gene_map,
                   gene_identifier='gene_name', 
                   gene_column_type='gene_name', 
                   gene_index='gene_name',
                   counts_layer="raw_counts")
test

AnnData object with n_obs × n_vars = 8562 × 14972
    obs: 'n_genes', 'doublet_score', 'predicted_doublet', 'n_genes_by_counts', 'total_counts', 'total_counts_mt', 'pct_counts_mt', 'n_counts'
    var: 'ensemble_id'

In [50]:
test.to_df()

gene_name,ATAD3B,SKI,PEX14,PLCH2,SPSB1,HES3,PLEKHM2,CA6,NMNAT1,CCDC27,...,CDY1,TSPY4,TSPY9,KDM5D,BPY2C,CDY2B,SRY,VCY,DAZ1,RBMY1E
cell_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
AAACCCAAGGTTACCT,0,1,0,0,0,0,0,0,0,1,...,0,0,0,0,0,0,0,4,0,0
AAACCCAAGTTGAAGT,0,0,0,0,0,0,0,0,0,0,...,0,0,1,0,0,0,0,3,0,0
AAACCCAAGTTGTCGT,0,1,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,1,1,0,0
AAACCCACAGAAGCGT,0,0,0,0,0,0,0,0,0,0,...,0,0,1,0,0,0,0,0,0,0
AAACCCACAGGAGGTT,0,0,0,0,0,0,0,0,0,0,...,0,0,1,1,0,0,0,3,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
TTTGTTGCAAGAGGTC,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,1,0,0
TTTGTTGCATGTGGTT,1,0,0,0,0,0,1,0,0,0,...,0,0,0,0,0,0,0,3,0,0
TTTGTTGGTATACCCA,0,1,0,0,0,0,0,0,0,1,...,0,0,1,0,0,0,0,2,0,0
TTTGTTGTCACGTAGT,0,0,0,0,0,0,2,0,0,1,...,0,0,0,0,0,0,0,1,0,0


In [6]:
# `break` outside a loop is a SyntaxError (see the recorded output), so this cell
# always fails — presumably a deliberate "stop here" marker that halts Run All.
break

SyntaxError: 'break' outside loop (668683560.py, line 1)

In [None]:
adata.obs.head()

In [None]:
# tokenize_anndata reads per-cell totals from obs['n_counts'] by default
# (counts_column='n_counts'); alias the existing 'total_counts' column.
# This mutates adata.obs in place.
adata.obs['n_counts'] = adata.obs['total_counts']

In [None]:
def load_gene_median_dict(gene_median_file):
    """
    Unpickle and return the gene median dictionary stored at `gene_median_file`.

    Args:
        gene_median_file (str): Path to the pickle file to read.

    Returns:
        dict: The unpickled object; expected to map gene IDs to their median
        expression values.
    """
    # NOTE: pickle executes arbitrary code on load — only use trusted files.
    with open(gene_median_file, "rb") as handle:
        return pickle.load(handle)


def load_gene_tokenization(token_dictionary_file):
    """
    Unpickle a gene-token dictionary and derive two helper views of its keys.

    Args:
        token_dictionary_file (str): Path to the pickle file holding the
            gene-token dictionary.

    Returns:
        dict: The unpickled gene-token dictionary (expected: Ensembl ID -> token).
        list: All keys of that dictionary, in dictionary order.
        dict: Every key mapped to True (membership lookup used to select genes later).
    """
    # NOTE: pickle executes arbitrary code on load — only use trusted files.
    with open(token_dictionary_file, "rb") as handle:
        token_lookup = pickle.load(handle)

    all_genes = [*token_lookup]
    membership = {gene: True for gene in all_genes}

    return token_lookup, all_genes, membership


# Load the two Geneformer dictionaries used by the tokenizer below.
# Note: `fpath` is reused here and no longer points at the .h5ad file loaded earlier.

# GENE TOKEN DICT
fpath = "/home/cstansbu/git_repositories/Geneformer/geneformer/token_dictionary.pkl"
gene_token_dict, gene_keys, genelist_dict = load_gene_tokenization(fpath)

# GENE MEDIAN DICT
fpath = "/home/cstansbu/git_repositories/Geneformer/geneformer/gene_median_dictionary.pkl"
gene_median_dict = load_gene_median_dict(fpath)

print('done')

In [None]:
def rank_genes(gene_vector, gene_tokens):
    """Ranks genes based on expression values in descending order.

    Args:
        gene_vector (numpy.ndarray): Array of gene expression values.
        gene_tokens (numpy.ndarray): Array of corresponding gene tokens.

    Returns:
        numpy.ndarray: Array of gene tokens sorted by descending expression value.
    """
    return gene_tokens[np.argsort(-gene_vector)]


def tokenize_anndata(adata, genelist_dict, gene_median_dict,
                     chunk_size=1000, target_sum=10000,
                     counts_column='n_counts', gene_id="ensembl_id",
                     gene_token_dict=None):
    """
    Converts each cell of an AnnData object into a rank-ordered list of gene tokens.

    Genes are restricted to those flagged in `genelist_dict`. Cells are processed
    in chunks: ``adata.X`` is divided by the per-cell total (`counts_column`),
    scaled to `target_sum` and divided by each gene's median (`gene_median_dict`).
    For every cell the non-zero genes are then ordered by descending normalized
    value and replaced by their tokens.

    Args:
        adata (AnnData): Object whose ``.X`` holds the expression values to rank.
        genelist_dict (dict): Gene ID -> truthy flag; only flagged genes are used.
        gene_median_dict (dict): Gene ID -> median expression used as per-gene scale.
            Must contain every gene selected through `genelist_dict`.
        chunk_size (int, optional): Number of cells to process in each chunk (default: 1000).
        target_sum (int, optional): Target sum for count normalization (default: 10000).
        counts_column (str, optional): Column in `adata.obs` with per-cell total counts
            (default: 'n_counts').
        gene_id (str, optional): Column in `adata.var` containing gene IDs
            (default: 'ensembl_id').
        gene_token_dict (dict, optional): Gene ID -> token. When omitted, the
            notebook-level ``gene_token_dict`` variable is used, which is what the
            function silently relied on before this parameter existed.

    Returns:
        tuple:
            - list: One list of int tokens per cell, highest normalized value first.
            - dict: ``adata.obs`` as ``{column name: list of values}``, in cell order.

    Raises:
        NameError: If `gene_token_dict` is not passed and no notebook-level
            ``gene_token_dict`` exists.
    """
    if gene_token_dict is None:
        # Backward compatibility with callers that relied on the global variable.
        try:
            gene_token_dict = globals()["gene_token_dict"]
        except KeyError:
            raise NameError(
                "gene_token_dict was not passed and no notebook-level "
                "`gene_token_dict` is defined"
            ) from None

    # Select the genes known to the tokenizer. Working on a plain array avoids
    # indexing a label-indexed Series with integer positions, whose positional
    # fallback is deprecated in pandas.
    gene_ids = adata.var[gene_id].to_numpy()
    keep_mask = np.array([genelist_dict.get(g, False) for g in gene_ids], dtype=bool)
    keep_loc = np.where(keep_mask)[0]
    kept_ids = gene_ids[keep_loc]

    norm_factor_vector = np.array([gene_median_dict[g] for g in kept_ids])
    kept_tokens = np.array([gene_token_dict[g] for g in kept_ids])

    tokenized_cells = []
    file_cell_metadata = {k: [] for k in adata.obs.columns}  # Initialize metadata dict

    # Process in chunks so that only `chunk_size` cells are densified at a time
    for chunk_start in range(0, adata.shape[0], chunk_size):
        chunk_end = chunk_start + chunk_size
        adata_chunk = adata[chunk_start:chunk_end, keep_loc]

        # Normalize: per-cell depth, then per-gene median
        n_counts = adata_chunk.obs[counts_column].values[:, None]
        X_chunk = adata_chunk.X
        if sp.issparse(X_chunk):
            # sparse / dense-column division densifies anyway; do it explicitly
            X_chunk = X_chunk.toarray()
        with np.errstate(divide='ignore', invalid='ignore'):
            X_norm = np.asarray(X_chunk) / n_counts * target_sum / norm_factor_vector
        X_norm = sp.csr_matrix(X_norm)

        # Walk the CSR buffers directly instead of slicing every row twice
        indptr, indices, data = X_norm.indptr, X_norm.indices, X_norm.data
        for i in range(X_norm.shape[0]):
            lo, hi = indptr[i], indptr[i + 1]
            cell_values = data[lo:hi]
            cell_tokens = kept_tokens[indices[lo:hi]]

            # Drop NaN values (e.g. 0/0 for a cell with zero total counts).
            # The previous NaN check ran on the integer tokens and never fired.
            valid = ~np.isnan(cell_values)
            ranks = rank_genes(cell_values[valid], cell_tokens[valid])

            tokenized_cells.append(ranks.astype(int).tolist())

        # Update metadata
        for k in adata.obs.columns:
            file_cell_metadata[k].extend(adata_chunk.obs[k].tolist())

    return tokenized_cells, file_cell_metadata


# Tokenize the full object with the default settings and inspect the first cell.
# NOTE(review): tokenize_anndata normalizes `adata.X`. This AnnData also carries
# 'raw_counts' / 'filtered_counts' layers and a 'log1p' entry in .uns, so .X may
# not hold raw counts — confirm, or tokenize the output of
# skeletonize(..., counts_layer="raw_counts") instead (its ID column is spelled
# 'ensemble_id', so gene_id would have to be passed accordingly).
# The call also depends on the notebook-level `gene_token_dict` loaded above.
tokenized_cells, cell_metadata = tokenize_anndata(adata, genelist_dict, gene_median_dict)
print(len(tokenized_cells))
print(tokenized_cells[0])

In [None]:
# break

In [None]:
def create_dataset(tokenized_cells, cell_metadata, gene_token_dict, model_input_size=2048, nproc=16):
    """
    Creates a Hugging Face Dataset from tokenized cells and associated metadata.

    Each row gets an ``input_ids`` column (the cell's tokens, truncated to
    `model_input_size`), a ``length`` column (number of tokens after truncation)
    and one column per metadata key.

    Args:
        tokenized_cells (list): List of tokenized cell representations (lists of tokens).
        cell_metadata (dict or None): Column name -> list of per-cell values, aligned
            with `tokenized_cells`. ``None`` or an empty dict adds no metadata columns.
        gene_token_dict (dict): Currently unused; kept so existing positional calls
            keep working.
        model_input_size (int, optional): Maximum number of tokens kept per cell
            (default: 2048).
        nproc (int, optional): Number of processes to use for mapping. Defaults to 16.

    Returns:
        datasets.Dataset: The processed Hugging Face dataset.

    Raises:
        ValueError: If `cell_metadata` contains an 'input_ids' key, which would
            silently replace the tokenized cells.
    """
    # The docstring always described the metadata as optional, but unpacking
    # None with ** raises TypeError — normalize it to an empty dict first.
    metadata = cell_metadata or {}

    if "input_ids" in metadata:
        raise ValueError("cell_metadata must not contain an 'input_ids' key")

    dataset_dict = {
        "input_ids": tokenized_cells,
        **metadata
    }

    output_dataset = Dataset.from_dict(dataset_dict)

    def format_cell_features(example):
        example["input_ids"] = example["input_ids"][:model_input_size]  # truncate
        example["length"] = len(example["input_ids"])  # length after truncation
        return example

    return output_dataset.map(format_cell_features, num_proc=nproc)

# Build the dataset and show its first record. Only the last expression of a
# cell is displayed, so the bare `type(dataset)` line produces no output here.
dataset = create_dataset(tokenized_cells, cell_metadata, gene_token_dict)
type(dataset)
dataset[0]

In [None]:
# `break` outside a loop is a SyntaxError, so this cell always fails —
# presumably a deliberate "stop here" marker placed before the save step.
break

In [None]:
def save_hf_dataset(dataset: Dataset, output_directory: str, output_prefix: str, overwrite=False):
    """
    Write a Hugging Face Dataset to ``<output_directory>/<output_prefix>.dataset``.

    The output directory is created if needed. With ``overwrite=True`` an existing
    target is not removed first; it is simply handed to ``Dataset.save_to_disk``.

    Args:
        dataset (Dataset): The Hugging Face Dataset to be saved.
        output_directory (str): Directory that will contain the saved dataset.
        output_prefix (str): Name of the dataset, without the '.dataset' suffix.
        overwrite (bool, optional): Allow saving when the target already exists.
            Defaults to False.

    Raises:
        TypeError: If `dataset` is not a Hugging Face Dataset instance.
        FileExistsError: If the target already exists and `overwrite` is False.
    """
    if not isinstance(dataset, Dataset):
        raise TypeError("The provided dataset is not a Hugging Face Dataset.")

    target = os.path.join(output_directory, f"{output_prefix}.dataset")

    # Refuse to touch an existing dataset unless explicitly allowed.
    if not overwrite and os.path.exists(target):
        raise FileExistsError(
            f"Dataset '{target}' already exists. Set `overwrite=True` to overwrite."
        )

    os.makedirs(output_directory, exist_ok=True)
    dataset.save_to_disk(target)
    
# Save as <output_directory>/test.dataset; overwrite=True skips the
# "already exists" check in save_hf_dataset.
output_directory = "/scratch/indikar_root/indikar1/cstansbu/geneformer/"
output_prefix = "test"
save_hf_dataset(dataset, output_directory, output_prefix, overwrite=True)

In [None]:
# `break` outside a loop is a SyntaxError, so this cell always fails —
# presumably a deliberate "stop here" marker; the cells below look exploratory.
break

In [None]:
# NOTE(review): neither `Path` (no pathlib import in the visible import cell) nor
# `tokenized_dataset` is defined anywhere in the visible notebook, so this cell
# would raise NameError on a fresh kernel. It appears to target the same
# "<output_directory>/<output_prefix>.dataset" location that save_hf_dataset()
# writes — presumably a leftover from before that helper existed; confirm and remove.
output_path = (Path(output_directory) / output_prefix).with_suffix(".dataset")
tokenized_dataset.save_to_disk(output_path)

In [None]:
tokenized_cells[0]

In [None]:
# `break` outside a loop is a SyntaxError, so this cell always fails —
# presumably a deliberate "stop here" marker before the plotting cells.
break

In [None]:
# adata.obs['n_counts']

In [None]:
adata.obs['predicted_doublet'].value_counts()

In [None]:
sc.pl.umap(
    adata,
    color=["doublet_score", "n_genes_by_counts", "total_counts", "pct_counts_mt"],
    size=25,
)

In [None]:
# NOTE(review): 'r12', 'r16' and 'r97' are not among the obs columns listed when
# the file was loaded, and no gene_symbols column is given, so they would have
# to be var_names or obs columns added elsewhere — confirm these keys exist.
sc.pl.umap(
    adata,
    color=["r12", "r16", "r97"],
    size=25,
)

In [None]:
adata.var[['gene_name', 'means', 'dispersions']].sort_values(by='means', ascending=False).head(30)

In [None]:
# ?sc.pl.umap

In [None]:
sc.pl.umap(
    adata,
    gene_symbols='gene_name',
    color=["KLF2", "FSTL1", "GATA2", "FOXL2NB"],
    size=25,
)

In [None]:
sc.pl.umap(
    adata,
    gene_symbols='gene_name',
    color=["MPL", "CD34", "MGST1", "PTPRC"],
    size=25,
)

In [None]:
sc.pl.umap(
    adata,
    gene_symbols='gene_name',
    color=["GATA2", "FOS", "STAT5A", "REL", "GFI1B"],
    size=25,
)

In [None]:
# `break` outside a loop is a SyntaxError, so this cell always fails —
# presumably a deliberate "stop here" marker before the annotation plots.
break

In [None]:
adata.uns.keys()

In [None]:
n = 25
sc.pl.umap(
    adata,
    gene_symbols='gene_name',
    color=adata.uns['fb_vs_hsc_up']['gene_name'].head(n).unique(),
    size=25,
)

adata.uns['fb_vs_hsc_up'].head(n)

In [None]:
n = 25
sc.pl.umap(
    adata,
    gene_symbols='gene_name',
    color=adata.uns['hsc_v_fib_up']['gene_name'].head(n).unique(),
    size=25,
)

adata.uns['hsc_v_fib_up'].head(n)

In [None]:
# Plot the first `n` PanglaoDB marker genes on the UMAP, then show those rows.
# Hidden state: `n` is not set in this cell; it comes from the earlier cells
# that assign `n = 25`, so this fails on a fresh kernel if they were skipped.
sc.pl.umap(
    adata,
    gene_symbols='gene_name',
    color=adata.uns['panglaodb']['gene_name'].head(n).unique(),
    size=25,
)

adata.uns['panglaodb'].head(n)

In [None]:
sc.pl.umap(
    adata,
    gene_symbols='gene_name',
    color=adata.uns['tabula_sapiens_deg']['gene_name'].head(10).unique(),
    size=25,
)

adata.uns['tabula_sapiens_deg'].head(10)

In [None]:
sc.pl.umap(
    adata,
    gene_symbols='gene_name',
    color=adata.uns['tabula_sapiens_deg']['gene_name'].tail(10).unique(),
    size=25,
)

adata.uns['tabula_sapiens_deg'].tail(10)

In [None]:
sc.pl.umap(
    adata,
    gene_symbols='gene_name',
    color=adata.uns['go_annotations']['gene_name'].unique(),
    size=25,
)


# Playground

In [None]:
adata

In [None]:
# Attach every CSV annotation table found in `dpath` to adata.uns, keyed by the
# file name without its extension and restricted to genes present in adata.

# folder path
dpath = "../config/gene_annotations/"

# gene symbol -> Ensembl ID lookup built from the AnnData itself.
# Note: this rebinds `gene_map`, replacing the table-derived mapping built earlier.
gene_map = dict(zip(adata.var['gene_name'].values, adata.var['ensembl_id'].values))


# sorted() makes the processing order (and thus the order of the uns keys)
# independent of the file system's directory ordering
for f in sorted(os.listdir(dpath)):
    if not f.endswith(".csv"):
        continue

    # load the annotation
    fpath = os.path.join(dpath, f)
    df = pd.read_csv(fpath)

    # make a key name: strip only the trailing extension
    key_name = os.path.splitext(f)[0]

    # map ensembl ids, drop unseen genes
    # NOTE(review): annotation symbols are upper-cased but the gene_map keys are
    # not, so mixed-case symbols in adata (e.g. 'C1orf...') can never match — confirm.
    df['gene_name'] = df['gene_name'].astype(str).str.upper()
    # .copy() gives an independent frame, so adding the column below does not
    # write into a slice of the original (avoids SettingWithCopyWarning)
    df = df[df['gene_name'].isin(list(gene_map))].copy()
    df['ensembl_id'] = df['gene_name'].map(gene_map)

    adata.uns[key_name] = df

adata

In [None]:
adata.uns['tabula_sapiens_deg'].query("cell_type == 'hematopoietic stem cell' and pct_nz_group > 0.9 and pvals_adj > 0").sort_values(by='logfoldchanges').head(20)

In [None]:
sc.pl.umap(
    adata,
    gene_symbols='gene_name',
    color=["FSTL1", "JUNB", "RHOA", "MGST1", "NOP53", "ADGRG1"],
    size=25,
)

In [None]:
adata.uns['go_annotations']

In [None]:
sc.pl.umap(
    adata,
    gene_symbols='gene_name',
    color=adata.uns['go_annotations']['gene_name'].to_list(),
    size=25,
)


