# RNA and ATAC Data Preprocessing Notebook
This notebook provides a complete pipeline for preprocessing single-cell RNA-seq and ATAC-seq data, including normalization, filtering, and tokenization for downstream analysis.



## Import Required Libraries

In [1]:
import os
import scanpy as sc
import anndata
import pandas as pd
import numpy as np
from typing import List, Dict
from tqdm import tqdm
from operator import itemgetter
from datasets import Dataset, Features, Sequence, Value
import pickle
import multiprocessing
from collections import Counter



In [2]:
species_list = ['hg38', 'mm10']
species2token = {'hg38': 0, 'mm10': 1} 
dict_data_dir = './preprocess/dict_data'

In [3]:
sample_raw_names_all=['pbmc10k']
species_=['hg38']
rna_raw_names_all = [
    "sample_data/hPBMC_10k_scGLUE_10xDemo/gex.h5ad"
]

atac_raw_names_all = [
    "sample_data/hPBMC_10k_scGLUE_10xDemo/atac.h5ad"
]

# Define output directory
output_dir = "./processed_data/"
os.makedirs(output_dir, exist_ok=True)

### RNA Data Processing and Ranking

In [4]:
def gene_filter_based_ENSid(ENSG2token: Dict[str, str], ENSG_list: List[str]) -> List[str]:
    '''
    Filter genes based on ENSEMBL ID to token mapping.
    Genes not in the mapping are marked as 'delete'.
    '''
    res = []
    for ENSG in ENSG_list:
        res.append(ENSG if ENSG in ENSG2token else "delete")
    return res

def Normalization_with_median(adata: anndata.AnnData, ENSid2median: Dict[str, int]) -> anndata.AnnData:
    '''
    Normalize expression data using gene-specific median values.
    '''
    if isinstance(adata.X, np.ndarray):
        X = adata.X
    else:
        X = adata.X.toarray()

    # Get median values for each gene
    gene_nonzero_median = []
    for gene_ENSid in adata.var.gene_ids.to_list():
        gene_nonzero_median.append(ENSid2median.get(gene_ENSid, np.nan))
    gene_nonzero_median = np.array(gene_nonzero_median)

    # Normalize by median values
    adata.X = np.nan_to_num(X / np.tile(gene_nonzero_median, (X.shape[0], 1)))
    return adata

def rank_rna_value(adata: anndata.AnnData, ENSG2token: Dict[str, int], species='hg38') -> Dict:
    '''
    Process RNA data: remove zeros, sort in descending order, and convert ENSEMBL IDs to tokens.
    '''
    ENSid_list = adata.var.index.to_list()
    cell_names = adata.obs.index.to_list()
    cell2data = {}
    
    
    for cell_idx, cell_data in enumerate(tqdm(adata.X)):
        # Get non-zero genes and sort in descending order
        nonzero_mask = np.nonzero(cell_data)[0]
        sorted_descend_indices = np.argsort(-cell_data[nonzero_mask])
        value = cell_data[nonzero_mask][sorted_descend_indices]
        ENSid_list_ = np.array(ENSid_list)[nonzero_mask][sorted_descend_indices]
        
        # Convert ENSEMBL IDs to tokens
        id_list = list(itemgetter(*ENSid_list_)(ENSG2token))
        assert len(id_list) == len(value)
        
        cell_name = cell_names[cell_idx]
        cell2data[cell_name] = {
            'input_ids': np.array(id_list).astype(np.int32),
            'values': np.array(value).astype(np.float32),
            'length': len(id_list),
            'species': species2token[species],
        }
        
        # Add cell type and batch information if available
        if 'cell_type' in adata.obs.keys():
            cell2data[cell_name]['cell_types'] = adata.obs.cell_type[cell_name]
        if 'batch' in adata.obs.keys():
            cell2data[cell_name]['batchs'] = adata.obs.batch[cell_name]
            
    return cell2data

### ATAC Data Processing Functions

In [5]:
def peak_filter_based_name(peak2token: Dict[str, str], peak_name_list: List[str], species=None) -> List[str]:
    '''
    Filter peaks based on name to token mapping.
    Peaks not in the mapping are marked as 'delete'.
    '''
    
    res = []
    for peak_name in peak_name_list:
        res.append(f"{species}_{peak_name}" if f"{species}_{peak_name}" in peak2token else "delete")
    return res

def rank_atac_peaks(adata: anndata.AnnData, peak2token: Dict[str, int], species) -> Dict:
    '''
    Process ATAC data: remove zeros and convert peak names to tokens.
    '''
    peak_name_list = [i for i in adata.var.index.to_list()]
    peak_token_list = list(itemgetter(*peak_name_list)(peak2token))
    cell_names = adata.obs.index.to_list()
    cell2data = {}
    
    if isinstance(adata.X, np.ndarray):
        X = adata.X
    else:
        X = adata.X.toarray()
    
    for cell_idx, cell_data in enumerate(tqdm(X)):
        nonzero_mask = np.nonzero(cell_data)[0]
        peak_token_list_cell = np.array(peak_token_list)[nonzero_mask]
        cell_name = cell_names[cell_idx]
        cell2data[cell_name] = peak_token_list_cell

    return cell2data

### Main Processing Function

In [6]:
def save_data(path: str, dataset: Dataset, rna_length: List[int], atac_length: List[int]) -> None:
    '''save dataset to path, ascending sort length and save as pickle'''
    dataset.save_to_disk(path)
    sorted_list = sorted(rna_length)
    with open(path + '/sorted_rna_length.pickle', 'wb') as f:
        pickle.dump(sorted_list, f)
    sorted_list = sorted(atac_length)
    with open(path + '/sorted_atac_length.pickle', 'wb') as f:
        pickle.dump(sorted_list, f)

In [None]:
def process(each_cell_name):
    """
    Process multi-omics data for a single cell.
    
    Args:
        each_cell_name: Identifier for the cell to process
        
    Returns:
        Tuple containing processed RNA and ATAC data for the cell
    """
    # Extract RNA and ATAC data for the current cell
    rna_data = rna_cell2data[each_cell_name]
    atac_data = atac_cell2data[each_cell_name]
    peak_num = len(atac_data)
    
    # Filter cells with insufficient ATAC data
    if len(atac_data) < 1000:
        return []

    # Extract RNA features
    rna_gene_ids = rna_data['input_ids']
    rna_gene_values = rna_data['values']
    rna_length = rna_data['length']
    cell_type = rna_data['cell_types'] if 'cell_types' in rna_data else None

    # Filter cells with insufficient RNA data
    if rna_length < 50:
        return []

    # Map ATAC peaks to genes
    gene_pos = []
    for peak_id in atac_data:
        if peak_id in peak2gene:
            gene_pos.extend(peak2gene[peak_id])
    gene_pos = np.array(gene_pos)

    # Calculate ATAC gene accessibility scores
    geneId2count = Counter(gene_pos[:, 0])
    geneId_counts = np.array([[gene_id, count] for gene_id, count in geneId2count.items()])
    geneId_total_peaks = np.array([ENSG2peakNum[token2ENSG[gene_token]] for gene_token in geneId_counts[:, 0]])
    geneId_score = geneId_counts[:, 1] / geneId_total_peaks
    
    # Filter genes by accessibility score and select top ones
    geneId_counts = geneId_counts[geneId_score > 0.01]
    geneId_score = geneId_score[geneId_score > 0.01]
    arg_sort_score = np.argsort(-geneId_score)
    geneId_score = geneId_score[arg_sort_score]
    geneId_counts = geneId_counts[arg_sort_score]
    
    if geneId_counts.shape[0] == 0:
        return []

    # Find intersection between RNA and ATAC genes
    atac_genes_list = geneId_counts[:, 0].tolist()
    ran_gene_list = rna_gene_ids.tolist()
    gene_intersection = list(set(atac_genes_list) & set(ran_gene_list))
    geneId_has_atac_ratio = len(gene_intersection) / len(ran_gene_list)

    # Filter gene positions based on ATAC genes
    id_indexs = np.isin(gene_pos[:, 0], atac_genes_list)
    gene_pos = gene_pos[id_indexs]

    # Organize peak-gene relationships
    atac_length = geneId_counts.shape[0]
    gene_peak = {}
    unique_genes = geneId_counts[:, 0]
    poses_list = [gene_pos[gene_pos[:, 0] == gene_id][:, 1].tolist() for gene_id in unique_genes]
    lengths = [len(poses) for poses in poses_list]
    count = np.cumsum([0] + lengths).tolist()
    index = [pos for poses in poses_list for pos in poses]
    gene_peak['count'] = count
    gene_peak['index'] = index

    # Extract batch information if available
    batch = rna_data['batchs'] if 'batchs' in rna_data else None

    # Return all processed data
    return (np.array(rna_gene_ids).tolist(),
            np.array(rna_gene_values).tolist(),
            rna_length,
            geneId_counts[:, 0],
            geneId_score,
            gene_peak,
            atac_length,
            geneId_has_atac_ratio,
            peak_num,
            each_cell_name,
            atac_data,
            cell_type,
            batch)

for sample_raw_name,species, sample_file_name_rna,sample_file_name_atac in zip(sample_raw_names_all, species_, rna_raw_names_all,atac_raw_names_all):

    save_path = './processed_data/' + (sample_raw_name) + '/RNA_ATAC_data_v1'
    print('Processing sample:', sample_raw_name)
    
    # 1. Load ENSG to token dictionary
    print("1. Loading ENSG2token dictionary")
    ENSG2token_path = f'{dict_data_dir}/hm_ENSG2token_dict.pickle'
    ENSG2token = pd.read_pickle(ENSG2token_path)
    token2ENSG = {j: i for i, j in ENSG2token.items()}
    
    # 2. Load median expression values
    print("2. Loading median expression dictionary")
    ENSG2median_path = f'{dict_data_dir}/RNA_nonzero_median_10W.{species}.pickle'
    ENSG2median = pd.read_pickle(ENSG2median_path)
    
    # 3. Load RNA and ATAC data
    print("3. Loading data")
    adata_rna_ = sc.read(sample_file_name_rna)
        
    adata_rna = anndata.AnnData(X=adata_rna_.X, var=adata_rna_.var, obs=adata_rna_.obs)
    adata_atac = sc.read(sample_file_name_atac)
    
    # 4. Filter genes based on ENSG IDs
    print("4. Filtering genes")
    gene_ENSid_list = gene_filter_based_ENSid(ENSG2token, adata_rna.var.index.to_list())
    adata_rna.var['gene_names'] = adata_rna.var.gene_symbols
    adata_rna.var.index = gene_ENSid_list
    adata_rna = adata_rna[:, adata_rna.var.index != "delete"]
    
    # 5. Normalize and process RNA data
    print("5. Normalizing and processing RNA data")
    sc.pp.normalize_total(adata_rna, target_sum=1e4, inplace=True)
    sc.pp.log1p(adata_rna)
    adata_rna = Normalization_with_median(adata_rna, ENSG2median)
    rna_cell2data = rank_rna_value(adata_rna, ENSG2token, species=species)
    
    # 6. Load peak dictionaries
    print("6. Loading peak dictionaries")
    peak2token_path = f'{dict_data_dir}/peak2token_dict.pickle'
    peak2token = pd.read_pickle(peak2token_path)
    peak2gene_path = f'{dict_data_dir}/peakId2geneID_dict.pickle'
    peak2gene = pd.read_pickle(peak2gene_path)
    
    # 7-8. Filter and process ATAC peaks
    print("7. Filtering and processing ATAC peaks")
    peak_name_list = peak_filter_based_name(peak2token, adata_atac.var.index.to_list(), species=species)
    adata_atac.var.index = peak_name_list
    adata_atac = adata_atac[:, adata_atac.var.index != "delete"]
    
    print("8. Ranking ATAC peaks")
    atac_cell2data = rank_atac_peaks(adata_atac, peak2token, species=species)
    
    # 9. Load ENSG to peak number dictionary
    print("9. Loading ENSG2peakNum dictionary")
    ENSG2peakNum_path = f'{dict_data_dir}/ENSG2peakNum_dict.pickle'
    ENSG2peakNum = pd.read_pickle(ENSG2peakNum_path)
    
    # 10. Process cells with both RNA and ATAC data
    print("10. Processing multi-omics cells")
    rna_cell_names = list(rna_cell2data.keys())
    atac_cell_names = list(atac_cell2data.keys())
    cell_inter_names = sorted(list(set(rna_cell_names) & set(atac_cell_names)))
    
    print(f"Sample {sample_raw_name} has {len(cell_inter_names)} multi-omics cells")
    
    # Initialize lists to store processed data
    rna_gene_ids = []
    rna_gene_values = []
    atac_gene_peaks = []
    rna_lengths = []
    cell_names = []
    gene_id_has_atac_ratios = []
    peak_nums = []
    atac_gene_ids = []
    atac_gene_scores = []
    atac_lengths = []
    atac_cell_peaks = []
    cell_types = []
    batchs = []
    
    # Process cells in parallel
    pool = multiprocessing.Pool(processes=2)
    results = pool.imap(process, cell_inter_names)
    
    for res in results:
        if len(res) > 1:
            rna_gene_ids.append(res[0])
            rna_gene_values.append(res[1])
            rna_lengths.append(res[2])
            atac_gene_ids.append(res[3])
            atac_gene_scores.append(res[4])
            atac_gene_peaks.append(res[5])
            atac_lengths.append(res[6])
            gene_id_has_atac_ratios.append(res[7])
            peak_nums.append(res[8])
            cell_names.append(res[9])
            atac_cell_peaks.append(res[10])
            cell_types.append(res[11])
            batchs.append(res[12])

        if len(atac_lengths) > 0 and len(atac_lengths) % 100 == 0:
            print(f"Processed cell number: {len(atac_lengths)}")

    print(f"Sample {sample_raw_name}: {len(atac_lengths)} cells processed successfully")
    
    # 11. Create dataset dictionary
    data_dict = {
        'rna_gene_ids': rna_gene_ids,
        'rna_gene_values': rna_gene_values,
        'rna_lengths': rna_lengths,
        'atac_gene_ids': atac_gene_ids,
        'atac_gene_scores': atac_gene_scores,
        'atac_gene_peaks': atac_gene_peaks,
        'atac_cell_peaks': atac_cell_peaks,
        'atac_lengths': atac_lengths,
        'genes_has_atac': gene_id_has_atac_ratios,
        'peak_num': peak_nums,
        'species': [species2token[species]] * len(atac_lengths),
        'cell_name': cell_names,
        'cell_types': cell_types,
    }

    # Add batch information if available
    if None not in batchs:
        data_dict['batchs'] = batchs
        structure = Features({
            'rna_gene_ids': Sequence(feature=Value(dtype='int32')),
            'rna_gene_values': Sequence(feature=Value(dtype='float32')),
            'rna_lengths': Value(dtype='int16'),
            'atac_gene_ids': Sequence(feature=Value(dtype='int32')),
            'atac_gene_scores': Sequence(feature=Value(dtype='float16')),
            'atac_gene_peaks': {
                'count': Sequence(feature=Value(dtype='int32')),
                'index': Sequence(feature=Value(dtype='int16'))
            },
            'atac_cell_peaks': Sequence(feature=Value(dtype='int32')),
            'atac_lengths': Value(dtype='int16'),
            'genes_has_atac': Value(dtype='float16'),
            'peak_num': Value(dtype='int32'),
            'species': Value(dtype='int8'),
            'cell_name': Value(dtype='string'),
            'cell_types': Value(dtype='string'),
            'batchs': Value(dtype='string'),
        })
    else:
        structure = Features({
            'rna_gene_ids': Sequence(feature=Value(dtype='int32')),
            'rna_gene_values': Sequence(feature=Value(dtype='float32')),
            'rna_lengths': Value(dtype='int16'),
            'atac_gene_ids': Sequence(feature=Value(dtype='int32')),
            'atac_gene_scores': Sequence(feature=Value(dtype='float16')),
            'atac_gene_peaks': {
                'count': Sequence(feature=Value(dtype='int32')),
                'index': Sequence(feature=Value(dtype='int16'))
            },
            'atac_cell_peaks': Sequence(feature=Value(dtype='int32')),
            'atac_lengths': Value(dtype='int16'),
            'genes_has_atac': Value(dtype='float16'),
            'peak_num': Value(dtype='int32'),
            'species': Value(dtype='int8'),
            'cell_name': Value(dtype='string'),
            'cell_types': Value(dtype='string'),
        })

    # 12. Create and save dataset
    dataset = Dataset.from_dict(data_dict, features=structure)
    os.makedirs(save_path, exist_ok=True)
    save_data(save_path, dataset, rna_lengths, atac_lengths)

    # Clean up
    pool.close()
    pool.join()
    del data_dict
    print(f"Completed processing for sample: {sample_raw_name}")

Processing sample: pbmc10k
1. Loading ENSG2token dictionary
2. Loading median expression dictionary
3. Loading data
4. Filtering genes
5. Normalizing and processing RNA data


  view_to_actual(adata)
100%|██████████| 9631/9631 [00:33<00:00, 291.06it/s]


6. Loading peak dictionaries
7. Filtering and processing ATAC peaks
8. Ranking ATAC peaks


100%|██████████| 9631/9631 [11:57<00:00, 13.43it/s]


9. Loading ENSG2peakNum dictionary
10. Processing multi-omics cells
Sample pbmc10k has 9631 multi-omics cells
Processed cell number: 100
Processed cell number: 200
Processed cell number: 300
Processed cell number: 400
Processed cell number: 500
Processed cell number: 600
Processed cell number: 700
Processed cell number: 800
Processed cell number: 900
Processed cell number: 1000
Processed cell number: 1100
Processed cell number: 1200
Processed cell number: 1300
Processed cell number: 1400
Processed cell number: 1500
Processed cell number: 1600
Processed cell number: 1700
Processed cell number: 1800
Processed cell number: 1900
Processed cell number: 2000
Processed cell number: 2100
Processed cell number: 2200
Processed cell number: 2300
Processed cell number: 2400
Processed cell number: 2500
Processed cell number: 2600
Processed cell number: 2700
Processed cell number: 2800
Processed cell number: 2900
Processed cell number: 3000
Processed cell number: 3100
Processed cell number: 3200
Pro

Saving the dataset (0/10 shards):   0%|          | 0/9631 [00:00<?, ? examples/s]

Completed processing for sample: pbmc10k
