### BulkFormer feature extraction

In [1]:
import os
# Expose only GPU 0 to this process. This is set in the first cell, before
# torch is imported further down.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  

In [40]:
import math
import pandas as pd
import numpy as np
from tqdm import tqdm
from scipy.stats import pearsonr, spearmanr
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset,DataLoader,random_split
from torch_geometric.typing import SparseTensor

In [3]:
from utils.BulkFormer import BulkFormer

In [4]:
from model.config import model_params

In [7]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without a usable CUDA device.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# Locations of the pre-computed model inputs, relative to the working directory.
graph_path, weights_path, gene_emb_path = (
    'data/G_gtex.pt',               # graph passed to SparseTensor as row/col
    'data/G_gtex_weight.pt',        # values attached to that graph
    'data/esm2_feature_concat.pt',  # stored as model_params['gene_emb']
)

In [9]:
# Read the three input files onto the CPU first.
# NOTE(review): weights_only=False unpickles arbitrary Python objects, so only
# load files that come from a trusted source.
graph, weights, gene_emb = (
    torch.load(path, map_location='cpu', weights_only=False)
    for path in (graph_path, weights_path, gene_emb_path)
)

# Build the sparse adjacency from graph[1] (rows) and graph[0] (columns),
# transpose it, and only then move it to the compute device.
# presumably `graph` is a 2 x num_edges index tensor — TODO confirm
graph = SparseTensor(row=graph[1], col=graph[0], value=weights).t()
graph = graph.to(device)

# The model constructor receives both objects through model_params.
model_params['graph'] = graph
model_params['gene_emb'] = gene_emb

In [10]:
# Instantiate BulkFormer from the configured parameters (which by now include
# the graph and the gene embeddings), then move it to the compute device.
model = BulkFormer(**model_params)
model = model.to(device)

In [None]:
# Deserialize the checkpoint onto the CPU, like the other torch.load calls in
# this notebook. Without map_location, tensors are restored to the device they
# were saved from, which fails if that device is not visible here.
# load_state_dict later copies the values into the model's own parameters.
# NOTE(review): weights_only=False unpickles arbitrary Python objects, so only
# load checkpoints that come from a trusted source.
ckpt_model = torch.load('model/Bulkformer_ckpt_epoch_29.pt', map_location='cpu', weights_only=False)

In [None]:
# Copy the checkpoint entries into a fresh ordered mapping, dropping a leading
# "module." from any key that has one and leaving all other keys unchanged.
new_state_dict = OrderedDict()
for key, value in ckpt_model.items():
    if key.startswith("module."):
        new_key = key[len("module."):]
    else:
        new_key = key
    new_state_dict[new_key] = value

In [14]:
# Load the renamed checkpoint entries into the model. The value of this
# expression is what the cell displays (the recorded run reports that all
# keys matched).
model.load_state_dict(new_state_dict)

<All keys matched successfully>

In [15]:
def extract_feature(expr_array,
                    high_var_gene_idx,
                    feature_type,
                    aggregate_type,
                    device,
                    batch_size,
                    return_expr_value = False,
                    esm2_emb = None,
                    valid_gene_idx = None):
    """Run the model over an expression matrix and return embeddings or its output.

    The function uses the module-level ``model`` object; it is not passed in.
    The model is called as ``model(X, [2])`` and entry ``2`` of its second
    return value is used as the embedding (presumably the hidden state of one
    specific layer — TODO confirm against BulkFormer).

    Args:
        expr_array: Array-like of expression values, one row per sample. It is
            converted to a float32 tensor.
        high_var_gene_idx: Indices along the gene axis that are pooled when
            ``feature_type`` is ``'transcriptome_level'``.
        feature_type: ``'transcriptome_level'`` (one pooled vector per sample)
            or ``'gene_level'`` (one vector per sample and gene).
        aggregate_type: Pooling used for ``'transcriptome_level'``: ``'max'``,
            ``'mean'``, ``'median'`` or ``'all'`` (element-wise sum of the
            three). Ignored for ``'gene_level'``.
        device: Device each batch is moved to before the forward pass.
        batch_size: Number of samples per forward pass.
        return_expr_value: If true, return the model's first output for all
            samples instead of embeddings.
        esm2_emb: Tensor indexed by gene; its rows selected by
            ``valid_gene_idx`` are appended to every gene-level embedding.
            Required for ``'gene_level'`` embeddings.
        valid_gene_idx: Indices along the gene axis kept for ``'gene_level'``
            embeddings. Required for ``'gene_level'`` embeddings.

    Returns:
        ``numpy.ndarray`` with the stacked model outputs when
        ``return_expr_value`` is true; otherwise a float32 CPU tensor holding
        the requested embeddings.

    Raises:
        ValueError: If ``feature_type`` is unknown, if ``aggregate_type`` is
            unknown while pooled embeddings are requested, or if gene-level
            embeddings are requested without ``esm2_emb`` / ``valid_gene_idx``.
    """
    # Validate first so a bad argument fails immediately with a clear message
    # instead of surfacing as an unbound local after the forward passes.
    if feature_type not in ('transcriptome_level', 'gene_level'):
        raise ValueError(
            "feature_type must be 'transcriptome_level' or 'gene_level', "
            f"got {feature_type!r}"
        )

    poolers = {'max': np.max, 'mean': np.mean, 'median': np.median}
    want_embedding = not return_expr_value
    if want_embedding and feature_type == 'transcriptome_level':
        if aggregate_type != 'all' and aggregate_type not in poolers:
            raise ValueError(
                "aggregate_type must be one of 'max', 'mean', 'median', 'all', "
                f"got {aggregate_type!r}"
            )
    if want_embedding and feature_type == 'gene_level':
        # Indexing with None would silently add an axis rather than fail.
        if esm2_emb is None or valid_gene_idx is None:
            raise ValueError(
                "gene_level embeddings require both esm2_emb and valid_gene_idx"
            )

    # Build the dataloader. The full matrix stays on the CPU and each batch is
    # moved to `device` inside the loop, so device memory scales with
    # batch_size rather than with the number of samples.
    expr_tensor = torch.tensor(expr_array, dtype=torch.float32)
    loader = DataLoader(TensorDataset(expr_tensor), batch_size=batch_size, shuffle=False)
    model.eval()

    emb_batches = []
    expr_batches = []

    with torch.no_grad():
        for (X,) in tqdm(loader, total=len(loader)):
            X = X.to(device)
            output, emb = model(X, [2])

            if return_expr_value:
                # Only the model output is needed; skip the embedding work.
                expr_batches.append(output.detach().cpu().numpy())
                continue

            layer_emb = emb[2].detach().cpu().numpy()

            if feature_type == 'transcriptome_level':
                # Pool over the selected genes (axis 1) with the chosen strategy.
                emb_valid = layer_emb[:, high_var_gene_idx, :]
                if aggregate_type == 'all':
                    # 'all' adds the three pooled vectors element-wise.
                    pooled = (np.max(emb_valid, axis=1)
                              + np.mean(emb_valid, axis=1)
                              + np.median(emb_valid, axis=1))
                else:
                    pooled = poolers[aggregate_type](emb_valid, axis=1)
                emb_batches.append(pooled)
            else:
                # Gene level: keep one vector per selected gene.
                emb_batches.append(layer_emb[:, valid_gene_idx, :])

    if return_expr_value:
        return np.vstack(expr_batches)

    all_emb_tensor = torch.tensor(np.vstack(emb_batches), device='cpu', dtype=torch.float32)
    if feature_type == 'transcriptome_level':
        return all_emb_tensor

    # Repeat the per-gene ESM2 rows for every sample and append them along the
    # last dimension (640 + 1280 = 1920 features in the recorded demo run).
    esm2_emb_selected = esm2_emb[valid_gene_idx]
    esm2_emb_expanded = esm2_emb_selected.unsqueeze(0).expand(all_emb_tensor.shape[0], -1, -1)
    esm2_emb_expanded = esm2_emb_expanded.to('cpu')
    return torch.cat([all_emb_tensor, esm2_emb_expanded], dim=-1)

In [None]:
def main_gene_selection(X_df, gene_list):
    """Align an expression table to ``gene_list``, padding genes it lacks.

    Args:
        X_df: DataFrame whose columns are gene identifiers; rows are kept as-is.
        gene_list: Gene identifiers that define the output columns and their
            order.

    Returns:
        A tuple ``(X_df, to_fill_columns, var)``:

        * ``X_df`` — new DataFrame with exactly the columns in ``gene_list``,
          in that order. Input columns not in ``gene_list`` are dropped and
          genes absent from the input are filled with ``-10``.
        * ``to_fill_columns`` — the genes that had to be padded, in
          ``gene_list`` order.
        * ``var`` — DataFrame indexed by the output columns with a ``'mask'``
          column: ``1`` for padded genes, ``0`` for genes present in the input.
    """
    present = set(X_df.columns)
    # Walk gene_list (rather than taking a set difference) so the returned
    # list has a stable order across runs. dict.fromkeys drops duplicates
    # while preserving that order.
    to_fill_columns = list(dict.fromkeys(gene for gene in gene_list if gene not in present))

    # Pad the missing genes with -10 (noted in the original code as the mask
    # token used during training).
    padding_df = pd.DataFrame(np.full((X_df.shape[0], len(to_fill_columns)), -10),
                              columns=to_fill_columns,
                              index=X_df.index)

    X_df = pd.DataFrame(np.concatenate([X_df.values, padding_df.values], axis=1),
                        index=X_df.index,
                        columns=list(X_df.columns) + to_fill_columns)
    X_df = X_df[gene_list]

    # Set membership keeps the mask construction linear in the number of genes.
    padded = set(to_fill_columns)
    var = pd.DataFrame(index=X_df.columns)
    var['mask'] = [1 if gene in padded else 0 for gene in var.index]
    return X_df, to_fill_columns, var

In [None]:
# Load the demo expression table. Its column names are matched against the
# BulkFormer gene list further down, so they are presumably Ensembl gene IDs
# with one sample per row — TODO confirm against data/demo.csv.
demo_df = pd.read_csv('data/demo.csv')

In [None]:
# Gene metadata table; its 'ensg_id' column supplies the gene list used below.
bulkformer_gene_info = pd.read_csv('data/bulkformer_gene_info.csv')

In [20]:
# Gene identifiers as a plain list, in the row order of the metadata table.
bulkformer_gene_list = bulkformer_gene_info['ensg_id'].to_list()

In [21]:
# Align the demo table to the model's gene list. Genes missing from the demo
# data are padded with -10 and flagged with mask == 1 in `var`.
input_df , to_fill_columns, var= main_gene_selection(X_df=demo_df,gene_list=bulkformer_gene_list)

In [None]:
# Replace the gene-ID index with a positional 0..N-1 index (the IDs move into
# a column), then collect the positions of genes that were not padded.
var.reset_index(inplace=True)
valid_gene_idx = var.index[var['mask'].eq(0).to_numpy()].tolist()

In [None]:
# Gene indices pooled for transcriptome-level embeddings (used to index the
# gene axis inside extract_feature).
# NOTE(review): weights_only=False unpickles arbitrary Python objects, so only
# load files that come from a trusted source.
high_var_gene_idx = torch.load('data/high_var_gene_list.pt',weights_only=False)

In [32]:
# Extract transcriptome-level embeddings for the first 16 samples: one vector
# per sample, max-pooled over the genes in high_var_gene_idx.
result = extract_feature(
    expr_array= input_df.values[:16],
    high_var_gene_idx=high_var_gene_idx,
    feature_type='transcriptome_level',
    aggregate_type='max',
    device=device,
    batch_size=4,
    return_expr_value=False,
    esm2_emb=model_params['gene_emb'],
    valid_gene_idx=valid_gene_idx
)

100%|██████████| 4/4 [00:04<00:00,  1.10s/it]


In [33]:
# Display the embedding shape; the recorded run shows 16 samples x 640 features.
result.shape

torch.Size([16, 640])

In [34]:
# Extract gene-level embeddings for the first 16 samples: one vector per
# sample and per gene in valid_gene_idx, with the matching esm2_emb row
# appended. aggregate_type is not used on the gene_level path.
result = extract_feature(
    expr_array= input_df.values[:16],
    high_var_gene_idx=high_var_gene_idx,
    feature_type='gene_level',
    aggregate_type='all',
    device=device,
    batch_size=4,
    return_expr_value=False,
    esm2_emb=model_params['gene_emb'],
    valid_gene_idx=valid_gene_idx
)

100%|██████████| 4/4 [00:04<00:00,  1.15s/it]


In [35]:
# Display the embedding shape; the recorded run shows 16 samples x 20010 genes
# x 1920 features.
result.shape

torch.Size([16, 20010, 1920])

In [36]:
# Extract expression values: with return_expr_value=True the function returns
# the model's stacked output as a NumPy array instead of embeddings.
result = extract_feature(
    expr_array= input_df.values[:16],
    high_var_gene_idx=high_var_gene_idx,
    feature_type='transcriptome_level',
    aggregate_type='all',
    device=device,
    batch_size=4,
    return_expr_value=True,
    esm2_emb=model_params['gene_emb'],
    valid_gene_idx=valid_gene_idx
)

100%|██████████| 4/4 [00:04<00:00,  1.21s/it]


In [37]:
# Display the output shape; the recorded run shows 16 samples x 20010 genes.
result.shape

(16, 20010)