# RNA and ATAC Data Preprocessing Notebook
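This notebook provides a pipeline for preprocessing single-cell RNA-seq data (gene filtering, normalization, and tokenization) for downstream analysis. Helper functions for filtering and tokenizing ATAC-seq peaks are also defined, but the main processing loop in this notebook only processes the RNA data.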



## Import Required Libraries

In [1]:
import scanpy as sc
import anndata
import pandas as pd
import numpy as np
from typing import List, Dict, Tuple, Optional
from tqdm import tqdm
from operator import itemgetter
from datasets import Dataset, Features, Sequence, Value
import pickle
import multiprocessing
from collections import Counter
import os
from multiprocessing import Pool
from functools import partial
from concurrent.futures import ThreadPoolExecutor, as_completed

In [2]:

# Input AnnData (.h5ad) files holding the RNA (gene expression) measurements.
# NOTE(review): absolute path on a shared filesystem - adjust for your environment.
rna_raw_names_all = [
    "/public/share/t_lgl/scFM/evaluation/data_set/raw_data/09_hPBMC_10k_scGLUE_10xDemo/gex.h5ad"
]

# Matching ATAC files.
# NOTE(review): not referenced by the processing loop visible in this notebook.
atac_raw_names_all = [
    "/public/share/t_lgl/scFM/evaluation/data_set/raw_data/09_hPBMC_10k_scGLUE_10xDemo/atac.h5ad"
]

# Define output directory (created if it does not exist yet)
output_dir = "./processed_data/"
os.makedirs(output_dir, exist_ok=True)

### Gene Filtering Based on ENSEMBL IDs

In [3]:
def gene_filter_based_ENSid(ENSG2token: Dict[str, str], ENSG_list: List[str]) -> List[str]:
    '''
    Keep every ENSEMBL ID that is a key of ``ENSG2token`` and replace all
    others with the placeholder string 'delete'.

    The returned list has the same length and order as ``ENSG_list``.
    '''
    return [gene_id if gene_id in ENSG2token else "delete" for gene_id in ENSG_list]

### Normalization with Gene Median Values

In [4]:
def Normalization_with_median(adata: anndata.AnnData, ENSid2median: Dict[str, float]) -> anndata.AnnData:
    '''
    Divide every gene (column) of ``adata.X`` by that gene's reference median.

    Parameters
    ----------
    adata
        Object whose ``X`` is a NumPy array or exposes ``toarray()``, and whose
        ``var.gene_ids`` lists one gene ID per column of ``X``.
    ENSid2median
        Mapping from gene ID to the median used as divisor.

    Returns
    -------
    The same ``adata`` object; its ``X`` is replaced in place by a dense array.

    Notes
    -----
    Entries that cannot be normalized are set to 0: genes that have no entry in
    ``ENSid2median`` and genes whose median is 0 (division by zero).
    '''
    X = adata.X if isinstance(adata.X, np.ndarray) else adata.X.toarray()

    # One divisor per gene, in column order; NaN marks genes without a median.
    gene_nonzero_median = np.array(
        [ENSid2median.get(gene_ENSid, np.nan) for gene_ENSid in adata.var.gene_ids.to_list()]
    )

    # Broadcasting divides each column by its median without materialising a
    # cells x genes copy of the median vector. NaN / inf results are expected
    # here and cleaned up just below, so the NumPy warnings are silenced.
    with np.errstate(divide='ignore', invalid='ignore'):
        scaled = X / gene_nonzero_median

    # Missing medians give NaN and zero medians give +/-inf; map all of them to
    # 0 so that such genes simply drop out instead of becoming huge values.
    adata.X = np.nan_to_num(scaled, nan=0.0, posinf=0.0, neginf=0.0)
    return adata

### RNA Data Processing and Ranking

In [5]:

def rank_rna_value(adata: anndata.AnnData, ENSG2token: Dict[str, int], species='hg38') -> Dict:
    '''
    Turn every cell into a ranked list of gene tokens.

    For each row of ``adata.X`` the zero entries are dropped, the remaining
    values are sorted in descending order, and the corresponding gene IDs
    (taken from ``adata.var.index``) are translated through ``ENSG2token``.

    Parameters
    ----------
    adata
        Object providing ``X`` (NumPy array or something with ``toarray()``),
        ``var.index`` (gene IDs, one per column) and ``obs`` (indexed by cell
        name).
    ENSG2token
        Mapping from gene ID to integer token; every gene ID with a non-zero
        value must be present.
    species
        Key into the local species table ('hg38' or 'mm10').

    Returns
    -------
    Dict keyed by cell name. Each value is a dict with 'input_ids' (int32
    array), 'values' (float32 array), 'length' and 'species', plus
    'cell_types' / 'batchs' when ``adata.obs`` has a 'cell_type' / 'batch'
    column.

    Raises
    ------
    KeyError
        If ``species`` is unknown or a gene ID is missing from ``ENSG2token``.
    '''
    species2token = {'hg38': 0, 'mm10': 1}  # Example mapping
    species_token = species2token[species]

    # Loop-invariant work is done once instead of once per cell.
    ENSid_array = np.array(adata.var.index.to_list())
    cell_names = adata.obs.index.to_list()
    has_cell_type = 'cell_type' in adata.obs.keys()
    has_batch = 'batch' in adata.obs.keys()

    # Row iteration and 1-D indexing below need a dense matrix.
    X = adata.X if isinstance(adata.X, np.ndarray) else adata.X.toarray()

    cell2data = {}
    for cell_idx, cell_data in enumerate(tqdm(X)):
        # Get non-zero genes and sort in descending order
        nonzero_mask = np.nonzero(cell_data)[0]
        nonzero_values = cell_data[nonzero_mask]
        sorted_descend_indices = np.argsort(-nonzero_values)
        value = nonzero_values[sorted_descend_indices]
        ENSid_list_ = ENSid_array[nonzero_mask][sorted_descend_indices]

        # Convert ENSEMBL IDs to tokens. A plain lookup per gene also works for
        # cells with zero or exactly one expressed gene, where
        # itemgetter(*ids)(mapping) raises or returns a bare scalar.
        id_list = [ENSG2token[ENSid] for ENSid in ENSid_list_]
        assert len(id_list) == len(value)

        cell_name = cell_names[cell_idx]
        cell_record = {
            'input_ids': np.array(id_list).astype(np.int32),
            'values': np.array(value).astype(np.float32),
            'length': len(id_list),
            'species': species_token,
        }

        # Add cell type and batch information if available
        if has_cell_type:
            cell_record['cell_types'] = adata.obs.cell_type[cell_name]
        if has_batch:
            cell_record['batchs'] = adata.obs.batch[cell_name]

        cell2data[cell_name] = cell_record

    return cell2data

### ATAC Data Processing Functions

In [6]:


def peak_filter_based_name(peak2token: Dict[str, str], peak_name_list: List[str], species=None) -> List[str]:
    '''
    Prefix each peak name with ``species`` and an underscore, then keep the
    prefixed name if it is a key of ``peak2token``; otherwise emit 'delete'.

    The returned list has the same length and order as ``peak_name_list``.
    '''
    prefixed_names = (f"{species}_{peak_name}" for peak_name in peak_name_list)
    return [name if name in peak2token else "delete" for name in prefixed_names]

def rank_atac_peaks(adata: anndata.AnnData, peak2token: Dict[str, int], species) -> Dict:
    '''
    Convert every cell into the tokens of its non-zero peaks.

    Parameters
    ----------
    adata
        Object providing ``X`` (NumPy array or something with ``toarray()``),
        ``var.index`` (peak names, one per column) and ``obs.index`` (cell
        names).
    peak2token
        Mapping from peak name to token; every name in ``adata.var.index``
        must be present.
    species
        Accepted for interface compatibility; not used by this function.

    Returns
    -------
    Dict mapping cell name to a NumPy array with the tokens of the peaks whose
    value is non-zero in that cell, in column order (no sorting by value).

    Raises
    ------
    KeyError
        If a peak name is missing from ``peak2token``.
    '''
    # Build the token array once: it is identical for every cell. A per-peak
    # lookup (instead of itemgetter(*names)) also handles zero or one peak.
    peak_token_array = np.array([peak2token[peak_name] for peak_name in adata.var.index.to_list()])
    cell_names = adata.obs.index.to_list()

    X = adata.X if isinstance(adata.X, np.ndarray) else adata.X.toarray()

    cell2data = {}
    for cell_idx, cell_data in enumerate(tqdm(X)):
        nonzero_mask = np.nonzero(cell_data)[0]
        # Indexing with an integer array returns a fresh array for each cell.
        cell2data[cell_names[cell_idx]] = peak_token_array[nonzero_mask]

    return cell2data

### Main Processing Function

In [7]:

def save_data(path: str, dataset: Dataset, rna_length: List[int]) -> None:
    '''
    Write ``dataset`` to ``path`` and store the RNA lengths, sorted in
    ascending order, as 'sorted_rna_length.pickle' inside the same directory.

    ``rna_length`` itself is left untouched; a sorted copy is pickled.
    '''
    dataset.save_to_disk(path)

    lengths_ascending = list(rna_length)
    lengths_ascending.sort()

    pickle_path = os.path.join(path, 'sorted_rna_length.pickle')
    with open(pickle_path, 'wb') as handle:
        pickle.dump(lengths_ascending, handle)

def process(each_cell_name):
    '''
    Look up one cell in the module-level ``rna_cell2data`` dict and flatten its
    RNA record into a tuple of plain Python values.

    Returns ``(gene_ids, gene_values, length, cell_name, cell_type, batch)``
    where the first two items are lists; ``cell_type`` and ``batch`` are None
    when the record has no 'cell_types' / 'batchs' entry.
    '''
    record = rna_cell2data[each_cell_name]

    # Arrays are converted to lists of Python scalars.
    token_ids = np.array(record['input_ids']).tolist()
    expression_values = np.array(record['values']).tolist()

    return (
        token_ids,
        expression_values,
        record['length'],
        each_cell_name,
        record.get('cell_types'),
        record.get('batchs'),
    )


# Species of each input file, and the integer token stored for each species
# (position in the list).
species_ = ["hg38"]
species2token = {name: token for token, name in enumerate(species_)}

# Directory holding the prior-knowledge dictionaries
dict_data_dir = '../prior_data'

# Gene ID -> token dictionary, plus the reverse lookup.
# Pickle files can execute code when loaded: only read files you trust.
print("1. Loading ENSG2token dictionary")
ENSG2token_path = f'{dict_data_dir}/hm_ENSG2token_dict.pickle'
ENSG2token = pd.read_pickle(ENSG2token_path)
token2ENSG = {token: ensg for ensg, token in ENSG2token.items()}

# Gene ID -> median dictionary, used as the per-gene divisor during normalization
print("2. Loading median value dictionary")
ENSG2median_path = f'{dict_data_dir}/RNA_nonzero_median_10W.hg38.pickle'
ENSG2median = pd.read_pickle(ENSG2median_path)

# Process each sample
# zip() silently drops unmatched entries, so require one species per RNA file.
if len(species_) != len(rna_raw_names_all):
    raise ValueError(
        f"Expected one species per RNA file, got {len(species_)} species "
        f"and {len(rna_raw_names_all)} RNA files"
    )

for species, sample_file_name_rna in zip(species_, rna_raw_names_all):
    sample_raw_name = os.path.basename(sample_file_name_rna).split(".")[0]
    save_path = os.path.join(output_dir, sample_raw_name, 'RNA_ATAC_data_v1')

    print('*' * 50)
    print(f'Processing: {sample_raw_name}')
    print('*' * 50)

    # Load RNA data and keep only X, var and obs in a fresh AnnData object
    print("3. Loading data")
    print(sample_file_name_rna)
    adata_rna_ = sc.read(sample_file_name_rna)
    adata_rna = anndata.AnnData(X=adata_rna_.X, var=adata_rna_.var, obs=adata_rna_.obs)

    # Filter genes based on ENSEMBL IDs
    print("4. Filtering genes")
    # Strip a trailing ".<digits>" suffix from the gene IDs before the lookup.
    adata_rna.var['gene_ids'] = adata_rna.var['gene_ids'].str.replace(r'\.\d+$', '', regex=True)
    gene_ENSid_list = gene_filter_based_ENSid(ENSG2token, adata_rna.var.gene_ids.to_list())
    # Preserve the original var index in a column before replacing it with gene IDs.
    adata_rna.var['gene_names'] = adata_rna.var.index.tolist()
    adata_rna.var.index = gene_ENSid_list
    # Subsetting returns a view; copy it so the in-place normalization below
    # works on a real object instead of triggering an implicit view-to-actual
    # conversion (and the warning that comes with it).
    adata_rna = adata_rna[:, adata_rna.var.index != "delete"].copy()

    # Normalize and process RNA data
    print("5. Normalizing, filtering, and ranking RNA values")
    sc.pp.normalize_total(adata_rna, target_sum=1e4, inplace=True)
    print(adata_rna.shape)
    sc.pp.log1p(adata_rna)
    adata_rna = Normalization_with_median(adata_rna, ENSG2median)
    # `process` reads this module-level dict, so the name must stay as is.
    rna_cell2data = rank_rna_value(adata_rna, ENSG2token, species=species)

    # Process cells
    print("6. Processing individual cells")
    rna_cell_names = list(rna_cell2data.keys())
    cell_inter_names = sorted(rna_cell_names)
    print(f"{sample_raw_name}, {len(cell_inter_names)} cells")

    # Flatten every cell, in sorted cell-name order. This runs serially: the
    # per-cell work is a dict lookup plus two list conversions, and a
    # multiprocessing pool over a notebook-defined function that reads a
    # global only works with the "fork" start method.
    results = [process(each_cell_name) for each_cell_name in cell_inter_names]

    # Prepare data arrays (one entry per cell, same order as cell_inter_names)
    rna_gene_ids = [res[0] for res in results]
    rna_gene_values = [res[1] for res in results]
    rna_lengths = [res[2] for res in results]
    cell_names = [res[3] for res in results]
    cell_types = [res[4] for res in results]
    batchs = [res[5] for res in results]

    # Create dataset
    print("7. Creating dataset")
    data_dict = {
        'rna_gene_ids': rna_gene_ids,
        'rna_gene_values': rna_gene_values,
        'rna_lengths': rna_lengths,
        'species': [species2token[species]] * len(rna_lengths),
        'cell_name': cell_names,
        'cell_types': cell_types,
    }

    # Define dataset structure
    feature_spec = {
        'rna_gene_ids': Sequence(feature=Value(dtype='int32')),
        'rna_gene_values': Sequence(feature=Value(dtype='float32')),
        'rna_lengths': Value(dtype='int16'),
        'species': Value(dtype='int8'),
        'cell_name': Value(dtype='string'),
        'cell_types': Value(dtype='string'),
    }

    # Add batch information (column and schema entry) only when every cell has it
    if None not in batchs:
        data_dict['batchs'] = batchs
        feature_spec['batchs'] = Value(dtype='string')

    structure = Features(feature_spec)

    # Create and save dataset
    dataset = Dataset.from_dict(data_dict, features=structure)
    os.makedirs(save_path, exist_ok=True)
    save_data(save_path, dataset, rna_lengths)

    print(f"Processing complete for {sample_raw_name}. Data saved to {save_path}")

1. Loading ENSG2token dictionary
2. Loading median value dictionary
**************************************************
Processing: gex
**************************************************
3. Loading data
/public/share/t_lgl/scFM/evaluation/data_set/raw_data/09_hPBMC_10k_scGLUE_10xDemo/gex.h5ad
4. Filtering genes
5. Normalizing, filtering, and ranking RNA values


  view_to_actual(adata)


(9631, 19365)


100%|██████████| 9631/9631 [00:40<00:00, 238.68it/s]

6. Processing individual cells
gex, 9631 cells





7. Creating dataset


Saving the dataset (0/1 shards):   0%|          | 0/9631 [00:00<?, ? examples/s]

Processing complete for gex. Data saved to ./processed_data/gex/RNA_ATAC_data_v1
