### BulkFormer feature extraction

In [1]:
import os
# Restrict this process to GPU 0. This cell runs before the cell that imports
# torch, so the variable is already in place when CUDA is first initialised.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [64]:
import math
import pandas as pd
import numpy as np
from tqdm import tqdm
from scipy.stats import pearsonr, spearmanr
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset,DataLoader,random_split
from torch_geometric.typing import SparseTensor
from utils.BulkFormer import BulkFormer
from model.config import model_params

In [4]:
# Configuration
# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA instead of failing at the first `.to(device)`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Paths are relative to the notebook's working directory.
graph_path = 'data/G_gtex.pt'                  # loaded as `graph` in the next cell
weights_path = 'data/G_gtex_weight.pt'         # loaded as `weights` (values of the sparse graph)
gene_emb_path = 'data/esm2_feature_concat.pt'  # loaded as `gene_emb`

In [5]:
# Initialize the BulkFormer model with preloaded graph structure and gene embeddings.
# All three artifacts are first loaded onto the CPU; only the assembled sparse
# graph and the model itself are moved to `device` in this cell.
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary objects —
# only point these paths at trusted files.
graph = torch.load(graph_path, map_location='cpu', weights_only=False)
weights = torch.load(weights_path, map_location='cpu', weights_only=False)
# `graph[1]` is passed as rows and `graph[0]` as columns, then the result is
# transposed. Presumably `graph` is a 2 x E edge index and `weights` holds one
# value per edge — TODO confirm. Note that the name `graph` is rebound from the
# raw loaded object to the SparseTensor here.
graph = SparseTensor(row=graph[1], col=graph[0], value=weights).t().to(device)
gene_emb = torch.load(gene_emb_path, map_location='cpu', weights_only=False)
# `model_params` is the dict imported from model.config and is mutated in place;
# later cells read `model_params['gene_emb']` back when calling extract_feature.
model_params['graph'] = graph
model_params['gene_emb'] = gene_emb
model = BulkFormer(**model_params).to(device)

In [6]:
# Load the pretrained BulkFormer model checkpoint for inference or fine-tuning.
# map_location='cpu' matches the other torch.load calls in this notebook and lets
# a checkpoint saved on a GPU be read on a CPU-only machine; load_state_dict then
# copies the values into the parameters wherever `model` already lives.
# NOTE(review): weights_only=False unpickles arbitrary objects — only load
# checkpoints from a trusted source.
ckpt_model = torch.load('model/Bulkformer_ckpt_epoch_29.pt', map_location='cpu', weights_only=False)

# Keys that carry a leading "module." (the prefix this loop has always stripped)
# are renamed so they match the bare model's parameter names; other keys pass
# through unchanged.
ckpt_prefix = "module."
new_state_dict = OrderedDict(
    (key[len(ckpt_prefix):] if key.startswith(ckpt_prefix) else key, value)
    for key, value in ckpt_model.items()
)

# Last expression of the cell, so the key-matching report is displayed.
model.load_state_dict(new_state_dict)

<All keys matched successfully>

In [7]:
def normalize_data(X_df, gene_length_dict, default_length=1000):
    """
    Normalize RNA-seq count data to log-transformed TPM values.

    Parameters
    ----------
    X_df : pandas.DataFrame
        Count matrix with samples as rows and genes as columns. Every column is
        treated as a gene and must be numeric; the column labels are the keys
        looked up in `gene_length_dict`.

    gene_length_dict : dict
        Mapping from gene identifier to gene length in base pairs.

    default_length : float, optional
        Length in base pairs used for genes that are absent from
        `gene_length_dict`, or whose recorded length is missing, non-finite or
        not positive. The default of 1,000 bp makes the length correction a
        no-op for those genes.

    Returns
    -------
    log_tpm_df : pandas.DataFrame
        Same index and columns as `X_df`, holding log(TPM + 1) values
        (natural logarithm, via `numpy.log1p`).

    Raises
    ------
    ValueError
        If `default_length` is not a positive number, or if `X_df` contains
        values that cannot be converted to float.

    Notes
    -----
    Each count is divided by the gene length in kilobases, the resulting rates
    are scaled so that every sample sums to one million, and the result is
    log1p-transformed. A sample whose rates sum to zero keeps all-zero output.
    """
    if not default_length > 0:
        raise ValueError(f"default_length must be positive, got {default_length!r}")

    gene_lengths_bp = np.array(
        [gene_length_dict.get(gene, default_length) for gene in X_df.columns],
        dtype=float,
    )
    # A zero, negative or NaN length would turn the whole sample into inf/NaN
    # once the per-sample sum is taken, so treat it like a missing length.
    unusable = ~np.isfinite(gene_lengths_bp) | (gene_lengths_bp <= 0)
    gene_lengths_bp[unusable] = default_length
    gene_lengths_kb = gene_lengths_bp / 1000

    counts_matrix = X_df.to_numpy(dtype=float)
    rate = counts_matrix / gene_lengths_kb

    sum_per_sample = rate.sum(axis=1, keepdims=True)
    # Avoid 0/0 for samples without any reads; their rates are all zero anyway.
    sum_per_sample[sum_per_sample == 0] = 1e-6

    tpm = rate / sum_per_sample * 1e6
    log_tpm_df = pd.DataFrame(np.log1p(tpm), index=X_df.index, columns=X_df.columns)
    return log_tpm_df

In [8]:
def main_gene_selection(X_df, gene_list, fill_value=-10):
    """
    Align an expression matrix to a fixed gene list, padding genes that are missing.

    Parameters
    ----------
    X_df : pandas.DataFrame
        Expression matrix with samples as rows and genes as columns.

    gene_list : list of str
        Gene identifiers that define the columns, and their order, of the
        returned matrix. Columns of `X_df` that are not in this list are dropped.

    fill_value : scalar, optional
        Placeholder written into every cell of a gene that is in `gene_list`
        but not in `X_df`. Defaults to -10, which cannot collide with
        log1p-transformed TPM values because those are never negative.

    Returns
    -------
    X_df : pandas.DataFrame
        Matrix with the same rows as the input and exactly the columns of
        `gene_list`, in that order.

    to_fill_columns : list of str
        Genes from `gene_list` that were absent from the input and were padded,
        listed in `gene_list` order so the result is reproducible across runs.

    var : pandas.DataFrame
        One row per gene of the aligned matrix with a single column `'mask'`:
        1 if the gene was padded, 0 if it came from the input.
    """
    present_genes = set(X_df.columns)
    # dict.fromkeys de-duplicates while keeping gene_list order; a plain set
    # difference would return the padded genes in an arbitrary order.
    to_fill_columns = list(dict.fromkeys(
        gene for gene in gene_list if gene not in present_genes
    ))

    padding = np.full((X_df.shape[0], len(to_fill_columns)), fill_value)
    aligned_df = pd.DataFrame(
        np.concatenate([X_df.values, padding], axis=1),
        index=X_df.index,
        columns=list(X_df.columns) + to_fill_columns,
    )
    # Re-order to gene_list and drop input genes that are not part of it.
    aligned_df = aligned_df[gene_list]

    # Set membership keeps the mask construction linear in the number of genes.
    padded_genes = set(to_fill_columns)
    var = pd.DataFrame(index=aligned_df.columns)
    var['mask'] = [1 if gene in padded_genes else 0 for gene in var.index]
    return aligned_df, to_fill_columns, var

In [9]:
def extract_feature(expr_array, 
                    high_var_gene_idx,
                    feature_type,
                    aggregate_type,
                    device,
                    batch_size,
                    return_expr_value = False,
                    esm2_emb = None,
                    valid_gene_idx = None):
    """
    Run the notebook-level `model` over expression profiles and return either
    sample embeddings, per-gene embeddings, or the model's predicted expression.

    The function reads the global `model` defined earlier in the notebook; it
    is switched to eval mode and called as `model(X, [2])`, and the entry `2`
    of the returned embedding collection is used.
    NOTE(review): presumably `2` selects one intermediate layer of BulkFormer —
    confirm against BulkFormer.forward.

    Parameters
    ----------
    expr_array : np.ndarray
        Array of shape [N_samples, N_genes] with expression values
        (e.g. log-transformed TPM). It is converted to float32.

    high_var_gene_idx : list or np.ndarray
        Gene positions that are aggregated when
        `feature_type='transcriptome_level'`. Ignored otherwise.

    feature_type : str
        - 'transcriptome_level': aggregate the embeddings of the selected genes
          into one vector per sample.
        - 'gene_level': keep one embedding per gene and concatenate `esm2_emb`
          along the last dimension.

    aggregate_type : str
        Aggregation over genes for `feature_type='transcriptome_level'`:
        'max', 'mean', 'median', or 'all' (element-wise sum of the three).
        Ignored for 'gene_level'.

    device : torch.device or str
        Device the batches are moved to before calling the model.

    batch_size : int
        Number of samples per batch.

    return_expr_value : bool, optional
        If True, return the model's first output (the predicted expression
        values) instead of embeddings. Default is False.

    esm2_emb : torch.Tensor, optional
        Per-gene tensor that is concatenated to the gene-level embeddings.
        Required when `feature_type='gene_level'` and embeddings are requested.

    valid_gene_idx : list or np.ndarray, optional
        Gene positions kept for `feature_type='gene_level'`. None keeps all genes.

    Returns
    -------
    result_emb : torch.Tensor
        float32 CPU tensor, [N_samples, D] for 'transcriptome_level' or
        [N_samples, N_selected_genes, D + D_esm2] for 'gene_level'.

    or (if `return_expr_value=True`)
    expr_predictions : np.ndarray
        The model's first output for all samples, stacked along axis 0.

    Raises
    ------
    ValueError
        If `feature_type` or `aggregate_type` is not one of the supported
        options, if `esm2_emb` is missing for gene-level embeddings, or if
        `expr_array` contains no samples (nothing to stack).
    """
    aggregators = {
        'max': lambda e: np.max(e, axis=1),
        'mean': lambda e: np.mean(e, axis=1),
        'median': lambda e: np.median(e, axis=1),
        'all': lambda e: np.max(e, axis=1) + np.mean(e, axis=1) + np.median(e, axis=1),
    }

    # Fail early with a clear message instead of an UnboundLocalError after
    # the first (possibly expensive) forward pass.
    if feature_type not in ('transcriptome_level', 'gene_level'):
        raise ValueError(
            f"feature_type must be 'transcriptome_level' or 'gene_level', got {feature_type!r}"
        )
    if feature_type == 'transcriptome_level' and aggregate_type not in aggregators:
        raise ValueError(
            f"aggregate_type must be one of {sorted(aggregators)}, got {aggregate_type!r}"
        )
    if feature_type == 'gene_level' and not return_expr_value and esm2_emb is None:
        raise ValueError("esm2_emb is required when feature_type='gene_level'")

    # Indexing with None would insert a new axis rather than select anything,
    # so map the documented default to "all genes".
    gene_selector = slice(None) if valid_gene_idx is None else valid_gene_idx

    # Keep the full dataset on the CPU and move one batch at a time, so device
    # memory scales with batch_size rather than with the number of samples.
    expr_tensor = torch.tensor(expr_array, dtype=torch.float32)
    myloader = DataLoader(TensorDataset(expr_tensor), batch_size=batch_size, shuffle=False)
    model.eval()

    all_emb_list = []
    all_expr_value_list = []

    with torch.no_grad():
        for (X,) in tqdm(myloader, total=len(myloader)):
            output, emb = model(X.to(device), [2])

            if return_expr_value:
                # Embeddings would be discarded, so do not materialise them.
                all_expr_value_list.append(output.detach().cpu().numpy())
                continue

            emb = emb[2].detach().cpu().numpy()
            if feature_type == 'transcriptome_level':
                emb_valid = emb[:, high_var_gene_idx, :]
                all_emb_list.append(aggregators[aggregate_type](emb_valid))
            else:
                all_emb_list.append(emb[:, gene_selector, :])

    if return_expr_value:
        return np.vstack(all_expr_value_list)

    result_emb = torch.tensor(np.vstack(all_emb_list), device='cpu', dtype=torch.float32)

    if feature_type == 'gene_level':
        esm2_emb_selected = esm2_emb[gene_selector].to('cpu')
        # expand() broadcasts the same per-gene rows to every sample without copying.
        esm2_emb_expanded = esm2_emb_selected.unsqueeze(0).expand(result_emb.shape[0], -1, -1)
        result_emb = torch.cat([result_emb, esm2_emb_expanded], dim=-1)

    return result_emb

In [None]:
# Load demo normalized data (log-transformed TPM)
# Alternative entry point: use this when the data are already normalized.
# The raw-count path below assigns `log_tpm_df` again via normalize_data, so
# when both cells are run the later assignment wins.
log_tpm_df = pd.read_csv('data/demo_normalized_data.csv')

In [34]:
# Load demo count data (raw count)
# normalize_data treats every column of this frame as a gene, so the CSV is
# presumably purely numeric with gene IDs as headers — TODO confirm it carries
# no sample-ID column.
count_df = pd.read_csv('data/demo_count_data.csv')

In [35]:
# Convert raw counts to normalized expression values (log-transformed TPM)
gene_length_df = pd.read_csv('data/gene_length_df.csv')
# Lookup table from the `ensg_id` column to the `length` column; normalize_data
# divides lengths by 1000, i.e. it expects them in base pairs.
gene_length_dict = gene_length_df.set_index('ensg_id')['length'].to_dict()
log_tpm_df = normalize_data(X_df=count_df, gene_length_dict=gene_length_dict)

In [36]:
# Gene identifiers, in file order, that define the columns and column order of
# the matrix built by main_gene_selection in the next cell.
bulkformer_gene_info = pd.read_csv('data/bulkformer_gene_info.csv')
bulkformer_gene_list = bulkformer_gene_info['ensg_id'].to_list()

In [37]:
# Align expression data to a predefined gene list with placeholder imputation for missing genes.
# Returns the aligned matrix, the list of padded genes, and `var`, whose 'mask'
# column is 1 for padded genes and 0 for genes present in `log_tpm_df`.
input_df , to_fill_columns, var= main_gene_selection(X_df=log_tpm_df,gene_list=bulkformer_gene_list)

In [38]:
# Positions, in the column order of `input_df`, of the genes that were present
# in the input data (mask == 0), as opposed to padded with the placeholder.
# The gene-ID index is flattened into a column only once, so re-running this
# cell cannot stack extra `index`/`level_0` columns onto `var` or fail.
if not isinstance(var.index, pd.RangeIndex):
    var.reset_index(inplace=True)
# Positional lookup, independent of whatever index `var` currently carries.
valid_gene_idx = np.flatnonzero(var['mask'].to_numpy() == 0).tolist()

In [39]:
# Gene selection used by extract_feature to index the gene axis before
# aggregating transcriptome-level embeddings.
# NOTE(review): despite the file name, the value is used as positional indices,
# not gene names — confirm the stored type.
# NOTE(review): weights_only=False unpickles arbitrary objects; trusted files only.
high_var_gene_idx = torch.load('data/high_var_gene_list.pt',weights_only=False)

In [41]:
# Extract transcriptome-level embedding
# Demo run on the first 16 samples, in batches of 4. With
# feature_type='transcriptome_level' the embeddings of the genes in
# `high_var_gene_idx` are reduced with the chosen aggregate ('max' here);
# `esm2_emb` and `valid_gene_idx` are not used by this branch.
res1 = extract_feature(
    expr_array= input_df.values[:16],
    high_var_gene_idx=high_var_gene_idx,
    feature_type='transcriptome_level',
    aggregate_type='max',
    device=device,
    batch_size=4,
    return_expr_value=False,
    esm2_emb=model_params['gene_emb'],
    valid_gene_idx=valid_gene_idx
)

100%|██████████| 4/4 [00:04<00:00,  1.12s/it]


In [42]:
# One row per sample: 16 x 640 in the recorded run.
res1.shape

torch.Size([16, 640])

In [43]:
# Show the transcriptome-level embeddings (torch abbreviates the repr).
res1

tensor([[0.7260, 0.6496, 0.5655,  ..., 1.4178, 1.4416, 0.9152],
        [0.8606, 0.5044, 0.5764,  ..., 1.6397, 1.8156, 1.0455],
        [0.7369, 0.6450, 0.5741,  ..., 1.6920, 1.5773, 1.0917],
        ...,
        [0.8528, 0.6433, 0.4971,  ..., 1.4733, 1.5902, 1.0816],
        [0.6435, 0.7442, 0.6239,  ..., 1.2714, 1.8716, 1.0143],
        [0.8938, 0.5444, 0.4105,  ..., 1.2515, 1.6817, 0.9039]])

In [44]:
# Extract gene-level embedding
# With feature_type='gene_level' the per-gene embeddings at `valid_gene_idx`
# are kept and `esm2_emb` is concatenated along the last dimension;
# `aggregate_type` and `high_var_gene_idx` are not used by this branch.
# The result is large: 16 x 20010 x 1920 float32 values in the recorded run.
res2 = extract_feature(
    expr_array= input_df.values[:16],
    high_var_gene_idx=high_var_gene_idx,
    feature_type='gene_level',
    aggregate_type='all',
    device=device,
    batch_size=4,
    return_expr_value=False,
    esm2_emb=model_params['gene_emb'],
    valid_gene_idx=valid_gene_idx
)

100%|██████████| 4/4 [00:04<00:00,  1.17s/it]


In [45]:
# samples x selected genes x concatenated feature width: 16 x 20010 x 1920 in the recorded run.
res2.shape

torch.Size([16, 20010, 1920])

In [46]:
# Show the gene-level tensor (torch abbreviates the repr).
res2

tensor([[[-3.5360e-01, -1.6242e+00, -3.5189e-01,  ..., -9.7180e-02,
          -1.1555e-01, -6.9436e-02],
         [ 1.4793e-01, -1.6561e+00, -1.6825e-01,  ..., -9.2917e-02,
          -1.0225e-02,  7.4865e-02],
         [-1.3270e+00, -1.5774e+00, -2.7984e-01,  ..., -1.5073e-01,
          -1.7446e-02,  1.4547e-01],
         ...,
         [-4.7221e-01, -1.6490e+00, -3.5948e-01,  ..., -3.1607e-02,
           7.7570e-03,  9.4292e-02],
         [ 4.1132e-01, -1.5150e+00,  8.7162e-02,  ..., -8.9079e-02,
          -4.6900e-02,  1.8972e-01],
         [ 4.2977e-01, -1.2008e+00, -7.7758e-02,  ..., -5.2847e-02,
          -9.4606e-02,  6.6996e-02]],

        [[-8.0668e-02, -1.0516e+00, -1.7151e-02,  ..., -9.7180e-02,
          -1.1555e-01, -6.9436e-02],
         [ 2.9337e-01, -1.7559e+00, -1.5734e-01,  ..., -9.2917e-02,
          -1.0225e-02,  7.4865e-02],
         [-1.4982e+00, -1.0418e+00, -2.4782e-01,  ..., -1.5073e-01,
          -1.7446e-02,  1.4547e-01],
         ...,
         [-3.7557e-01, -1

In [47]:
# Extract expression values
# return_expr_value=True makes extract_feature return the model's first output
# stacked over all samples as a NumPy array; the embedding selected by
# feature_type / aggregate_type is not returned.
res3 = extract_feature(
    expr_array= input_df.values[:16],
    high_var_gene_idx=high_var_gene_idx,
    feature_type='transcriptome_level',
    aggregate_type='all',
    device=device,
    batch_size=4,
    return_expr_value=True,
    esm2_emb=model_params['gene_emb'],
    valid_gene_idx=valid_gene_idx
)

100%|██████████| 4/4 [00:04<00:00,  1.22s/it]


In [48]:
# samples x genes: 16 x 20010 in the recorded run.
res3.shape

(16, 20010)

In [49]:
# Show the predicted expression matrix (NumPy abbreviates the repr).
res3

array([[3.7627857 , 0.33182847, 2.8268478 , ..., 3.5956197 , 0.40449452,
        0.45241898],
       [4.6875224 , 0.23467608, 3.7401597 , ..., 4.404271  , 0.25158244,
        0.33082467],
       [2.620734  , 0.33064708, 3.7773688 , ..., 2.4631546 , 0.3549513 ,
        0.44277862],
       ...,
       [0.9660633 , 0.29477847, 2.812781  , ..., 2.8931031 , 0.3250854 ,
        0.40352863],
       [0.39897698, 0.168719  , 3.5321    , ..., 4.6415453 , 0.23221342,
        0.27521023],
       [3.5847843 , 0.8626206 , 3.2892735 , ..., 4.1765776 , 0.46255967,
        0.5074125 ]], dtype=float32)