# BulkFormer Feature Extraction Tutorial

This notebook demonstrates how to use **BulkFormer**, a graph-based attention Transformer model for extracting rich feature representations from bulk RNA-seq expression data.

## Overview

BulkFormer integrates:
1. **Gene co-expression networks** (GTEx-based gene-gene relationships)
2. **Protein sequence embeddings** (ESM2 embeddings from protein sequences)
3. **Expression values** (log-transformed TPM values)

The model uses a Graph-Based Transformer architecture to learn contextualized gene embeddings that can be aggregated into sample-level representations for downstream tasks like drug response prediction, tissue classification, and disease phenotype analysis.

**Reference:** [BulkFormer manuscript (bioRxiv 2025)](https://www.biorxiv.org/content/10.1101/2025.06.11.659222v1.full)

## Package Structure

This notebook uses the refactored `bulkformer` package:
- **Model classes:** `bulkformer.models.model`
- **Configuration:** `bulkformer.config`
- **Utility functions:** `bulkformer.utils`

For CLI usage: `uv run bulkformer extract --help`

## 1. Setup and Configuration

First, choose which GPU PyTorch is allowed to see. This step is optional, and it must run before PyTorch initializes CUDA. On a machine without CUDA, run the model on the CPU instead (see the `device` setting in Section 3).


In [2]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  

## 2. Import Required Libraries

Import all necessary packages:
- **PyTorch & PyG:** For deep learning and graph operations
- **BulkFormer components:** Model, configuration, and utility functions
- **Data processing:** pandas, numpy for data manipulation


In [3]:
import math
import pandas as pd
import numpy as np
from tqdm import tqdm
from scipy.stats import pearsonr, spearmanr
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader, random_split
from torch_geometric.typing import SparseTensor
from bulkformer.models.model import BulkFormer
from bulkformer.config import model_params
from bulkformer.utils import normalize_data, align_genes, extract_features

## 3. Define Data Paths

The notebook automatically detects the project root directory, handling execution from either:
- The project root: `/path/to/BulkFormer/`
- The notebooks folder: `/path/to/BulkFormer/notebooks/`

All paths are resolved relative to the project root using `pathlib.Path`, ensuring the notebook works correctly regardless of the current working directory.

Set paths to required data files:
- **G_gtex.pt:** Gene co-expression graph structure (GTEx-based)
- **G_gtex_weight.pt:** Edge weights for the gene graph
- **esm2_feature_concat.pt:** Pre-computed ESM2 protein embeddings for all genes


In [4]:
from pathlib import Path

# Detect project root (handle running from root or notebooks folder)
current_dir: Path = Path.cwd()
if current_dir.name == 'notebooks':
    project_root: Path = current_dir.parent
else:
    project_root: Path = current_dir

print(f"📁 Project root: {project_root}")

# Configuration
device: str = 'cuda' if torch.cuda.is_available() else 'cpu'  # Fall back to CPU when CUDA is unavailable
graph_path: Path = project_root / 'data' / 'G_gtex.pt'
weights_path: Path = project_root / 'data' / 'G_gtex_weight.pt'
gene_emb_path: Path = project_root / 'data' / 'esm2_feature_concat.pt'

📁 Project root: /home/antonkulaga/sources/BulkFormer


### 3.1 Download Required Files

⚠️ **Note:** This notebook requires data and model files that must be downloaded before running the cells below.

If you haven't downloaded them yet, run: `uv run bulkformer download all`
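
Before running the later cells, it can help to check which files are actually present. The snippet below is a minimal sketch, not part of the `bulkformer` package: `find_missing_files` is a hypothetical helper, and the file list is simply collected from the paths used elsewhere in this notebook.

```python
from pathlib import Path


def find_missing_files(project_root: Path, relative_paths: list[str]) -> list[Path]:
    """Return the expected files that do not exist under project_root."""
    return [
        project_root / rel
        for rel in relative_paths
        if not (project_root / rel).is_file()
    ]


# File names collected from the cells in this notebook.
REQUIRED_FILES = [
    "data/G_gtex.pt",
    "data/G_gtex_weight.pt",
    "data/esm2_feature_concat.pt",
    "data/bulkformer_gene_info.csv",
    "data/high_var_gene_list.pt",
    "model/Bulkformer_ckpt_epoch_29.pt",
]

# Path.cwd() stands in for the project root detected in the cell above.
missing = find_missing_files(Path.cwd(), REQUIRED_FILES)
if missing:
    print("Missing files (try `uv run bulkformer download all`):")
    for path in missing:
        print(f"  {path}")
else:
    print("All required files found.")
```

In the notebook itself you would pass `project_root` instead of `Path.cwd()`.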


In [14]:
# This notebook uses data and model files from the data/ and model/ directories
# If you get FileNotFoundError, download them with:
#   cd {project_root}
#   uv run bulkformer download all
# 
# Or visit: https://doi.org/10.5281/zenodo.15559368


## 4. Initialize BulkFormer Model

Load the model components and initialize BulkFormer:

1. **Gene co-expression graph:** A sparse adjacency matrix representing gene-gene relationships learned from GTEx data
2. **Graph weights:** Edge weights quantifying the strength of co-expression relationships
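3. **Gene embeddings:** ESM2-derived protein sequence embeddings (one fixed-length vector per gene; 1280-dimensional, judging by the 1920-dimensional gene-level output in Section 9.2, which also contains the 640-dimensional BulkFormer embedding)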

The model uses these components to contextualize each gene's expression through graph-based attention mechanisms.


In [5]:
# Initialize the BulkFormer model with preloaded graph structure and gene embeddings
# Load the gene co-expression graph (edge index) and its edge weights
edge_index: torch.Tensor = torch.load(str(graph_path), map_location='cpu', weights_only=False)
weights: torch.Tensor = torch.load(str(weights_path), map_location='cpu', weights_only=False)

# Create sparse tensor representation for efficient graph operations
graph: SparseTensor = SparseTensor(row=edge_index[1], col=edge_index[0], value=weights).t().to(device)

# Load ESM2 protein embeddings (one embedding vector per gene, 20,010 genes)
gene_emb: torch.Tensor = torch.load(str(gene_emb_path), map_location='cpu', weights_only=False)

# Configure model with graph structure and embeddings
model_params['graph'] = graph
model_params['gene_emb'] = gene_emb

# Initialize model and move to device
model: BulkFormer = BulkFormer(**model_params).to(device)

## 5. Load Pre-trained Checkpoint

Load the pre-trained BulkFormer weights (trained on GTEx + TCGA data).

The model was trained with a reconstruction objective, learning to predict gene expression values while building contextualized gene representations through the graph-based attention layers.


In [6]:
# Load the pre-trained BulkFormer model checkpoint
model_checkpoint_path: Path = project_root / 'model' / 'Bulkformer_ckpt_epoch_29.pt'
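# map_location='cpu' lets the checkpoint load on machines without a GPU;
# load_state_dict() below copies the weights onto the model's device.
ckpt_model: dict = torch.load(str(model_checkpoint_path), map_location='cpu', weights_only=False)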

# Remove 'module.' prefix from keys (artifacts from DataParallel training)
new_state_dict: OrderedDict = OrderedDict()
for key, value in ckpt_model.items():
    new_key: str = key[7:] if key.startswith("module.") else key
    new_state_dict[new_key] = value

# Load weights into model
model.load_state_dict(new_state_dict)

<All keys matched successfully>

## 6. Utility Functions

The following functions are imported from `bulkformer.utils`:

### `normalize_data(X_df, gene_length_dict)`
Converts raw RNA-seq counts to **log-transformed TPM** (Transcripts Per Million):
- Normalizes for gene length (longer genes get more reads)
- Normalizes for sequencing depth (total reads per sample)
- Applies log1p transformation for variance stabilization
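
The three steps above can be illustrated for a single sample in plain Python. This is a sketch of the standard counts-to-TPM calculation, assuming gene lengths are given in base pairs; it is not the implementation in `bulkformer.utils`, and `counts_to_log_tpm` and the toy gene names are made up for illustration.

```python
import math


def counts_to_log_tpm(
    counts: dict[str, float], gene_lengths_bp: dict[str, float]
) -> dict[str, float]:
    """Convert one sample's raw counts to log1p(TPM).

    Assumes the sample has at least one read in total.
    """
    # 1. Length normalization: reads per kilobase of transcript.
    rpk = {gene: counts[gene] / (gene_lengths_bp[gene] / 1_000) for gene in counts}
    # 2. Depth normalization: rescale so the sample sums to one million.
    scale = sum(rpk.values()) / 1_000_000
    tpm = {gene: value / scale for gene, value in rpk.items()}
    # 3. Variance stabilization with log(1 + x).
    return {gene: math.log1p(value) for gene, value in tpm.items()}


# Toy sample: geneA and geneB have equal counts, but geneB is twice as long.
sample_counts = {"geneA": 100, "geneB": 100, "geneC": 0}
lengths = {"geneA": 1_000, "geneB": 2_000, "geneC": 500}

log_tpm = counts_to_log_tpm(sample_counts, lengths)
for gene, value in log_tpm.items():
    print(f"{gene}: {value:.3f}")
```

With equal counts, the shorter gene ends up with twice the TPM of the longer one, which is the point of the length normalization.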

### `align_genes(X_df, gene_list)`
Aligns your expression data to BulkFormer's gene space (20,010 genes):
- Adds missing genes with placeholder value (-10)
- Reorders columns to match model's expected gene order
- Returns a binary mask indicating which genes were imputed

### `extract_features(model, expr_array, ...)`
Extracts features from expression data:
- **transcriptome_level:** Sample-level embeddings (aggregated across genes)
- **gene_level:** Gene-level embeddings (per gene, per sample)
- **expression_imputation:** Model-predicted expression values


## 7. Load Expression Data

Load your bulk RNA-seq expression data. The demo file contains log-transformed TPM values for 967 samples.

**Expected format:**
- Rows: Samples
- Columns: Genes (Ensembl IDs, e.g., ENSG00000000003)
- Values: Log-transformed TPM (log(TPM + 1))

If you have **raw counts** instead, see Section 7.1 for normalization.


In [7]:
# Load demo normalized data (log-transformed TPM)
demo_data_path: Path = project_root / 'data' / 'demo.csv'
log_tpm_df: pd.DataFrame = pd.read_csv(demo_data_path, index_col=0)
print(f"Loaded expression data: {log_tpm_df.shape[0]} samples × {log_tpm_df.shape[1]} genes")

Loaded expression data: 967 samples × 20009 genes


### 7.1 Optional: Normalization from Raw Counts

If you have raw count data instead of normalized TPM, use the `normalize_data()` function:

```python
gene_length_path = project_root / 'data' / 'gene_length_df.csv'
gene_length_df = pd.read_csv(gene_length_path)
gene_length_dict = gene_length_df.set_index('ensg_id')['length'].to_dict()
log_tpm_df = normalize_data(X_df=count_df, gene_length_dict=gene_length_dict)
```

This converts counts → TPM → log(TPM + 1) in one step.


In [None]:
# Optional: Load raw count data if you have it
# count_df = pd.read_csv('your_count_data.csv', index_col=0)

In [None]:
# Optional: Convert raw counts to normalized expression values (log-transformed TPM)
# If you have raw count data, uncomment and run this cell
# gene_length_path = project_root / 'data' / 'gene_length_df.csv'
# gene_length_df = pd.read_csv(gene_length_path)
# gene_length_dict = gene_length_df.set_index('ensg_id')['length'].to_dict()
# log_tpm_df = normalize_data(X_df=count_df, gene_length_dict=gene_length_dict)

## 8. Gene Alignment

BulkFormer expects exactly **20,010 genes** in a specific order. The `align_genes()` function:
1. Identifies missing genes in your data
2. Adds them with placeholder value (-10)
3. Reorders columns to match the model's expected gene order
4. Creates a binary mask to track which genes were imputed
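
The alignment steps listed above can be sketched for one sample with plain dictionaries. This is only an illustration: the real `align_genes` operates on DataFrames, the gene IDs below are hypothetical, and dropping genes the model does not know is an assumption of this sketch.

```python
PLACEHOLDER = -10.0  # value this tutorial states is used for genes absent from the input


def align_sample(
    sample: dict[str, float], gene_order: list[str]
) -> tuple[list[float], list[int]]:
    """Reorder one sample to gene_order, filling absent genes with PLACEHOLDER.

    Returns (values, mask), where mask[i] == 1 marks a filled-in gene.
    Genes not listed in gene_order are dropped (an assumption of this sketch).
    """
    values = [sample.get(gene, PLACEHOLDER) for gene in gene_order]
    mask = [0 if gene in sample else 1 for gene in gene_order]
    return values, mask


model_genes = ["ENSG_A", "ENSG_B", "ENSG_C"]             # hypothetical IDs in model order
sample = {"ENSG_C": 2.5, "ENSG_A": 0.7, "ENSG_X": 9.9}   # ENSG_B missing, ENSG_X unknown

values, mask = align_sample(sample, model_genes)
valid_gene_idx = [i for i, flag in enumerate(mask) if flag == 0]

print(values)          # [0.7, -10.0, 2.5]
print(mask)            # [0, 1, 0]
print(valid_gene_idx)  # [0, 2]
```

The `valid_gene_idx` list plays the same role as the one derived from `var['mask'] == 0` later in this notebook.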


In [8]:
# Load BulkFormer's gene list (20,010 genes)
gene_info_path: Path = project_root / 'data' / 'bulkformer_gene_info.csv'
bulkformer_gene_info: pd.DataFrame = pd.read_csv(gene_info_path)
bulkformer_gene_list: list[str] = bulkformer_gene_info['ensg_id'].to_list()
print(f"BulkFormer expects {len(bulkformer_gene_list)} genes")

BulkFormer expects 20010 genes


In [9]:
# Align expression data to BulkFormer's gene space
input_df: pd.DataFrame
to_fill_columns: list[str]
var: pd.DataFrame
input_df, to_fill_columns, var = align_genes(X_df=log_tpm_df, gene_list=bulkformer_gene_list)

print(f"Aligned data shape: {input_df.shape}")
print(f"Number of genes imputed: {len(to_fill_columns)}")

Aligned data shape: (967, 20010)
Number of genes imputed: 1


### 8.1 Identify Valid Genes

Extract indices of genes that were **not** imputed (mask == 0). These are genes that were present in your original data.


In [10]:
var.reset_index(inplace=True)
valid_gene_idx: list[int] = list(var[var['mask'] == 0].index)
print(f"Number of valid (non-imputed) genes: {len(valid_gene_idx)}")

Number of valid (non-imputed) genes: 20009


### 8.2 Load High-Variability Genes

Load indices of highly variable genes used for transcriptome-level aggregation. These genes were selected based on variance across GTEx training data and are most informative for sample-level representations.


In [11]:
high_var_gene_path: Path = project_root / 'data' / 'high_var_gene_list.pt'
high_var_gene_idx: torch.Tensor = torch.load(str(high_var_gene_path), weights_only=False)
print(f"Number of high-variability genes: {len(high_var_gene_idx)}")

Number of high-variability genes: 2000


## 9. Feature Extraction

BulkFormer can extract three types of features:

### 9.1 Transcriptome-Level Embeddings

Sample-level representations created by aggregating gene embeddings from layer 2 of BulkFormer.

**Aggregation methods:**
- `'max'`: Maximum pooling across genes (captures peak signals)
- `'mean'`: Average pooling (balanced representation)
- `'median'`: Median pooling (robust to outliers)
- `'all'`: Sum of max, mean, and median (most comprehensive)

**Use cases:** Drug response prediction, tissue classification, disease phenotyping


In [12]:
# Extract transcriptome-level embedding (sample-level representation)
# Using first 16 samples as example
res1: torch.Tensor = extract_features(
    model=model,
    expr_array=input_df.values[:16],  # Shape: [n_samples, n_genes]
    high_var_gene_idx=high_var_gene_idx,  # Highly variable genes for aggregation
    feature_type='transcriptome_level',  # Aggregate to sample-level
    aggregate_type='max',  # Use max pooling
    device=device,
    batch_size=4,
    return_expr_value=False,  # Return embeddings, not expression predictions
    esm2_emb=model_params['gene_emb'],
    valid_gene_idx=valid_gene_idx
)

Extracting features: 100%|██████████| 4/4 [00:05<00:00,  1.46s/it]


In [13]:
print(f"Transcriptome-level embeddings shape: {res1.shape}")
print(f"  - {res1.shape[0]} samples")
print(f"  - {res1.shape[1]} embedding dimensions")
res1.shape

Transcriptome-level embeddings shape: torch.Size([16, 640])
  - 16 samples
  - 640 embedding dimensions


torch.Size([16, 640])

**Result shape:** `[n_samples, embedding_dim]`
- Each sample is represented by a single vector
- Default embedding dimension is 640 (from model config)
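
To make the aggregation options from Section 9.1 concrete, here is a toy sketch in plain Python that pools a small `[n_genes][dim]` matrix into one vector. `pool_genes` is a hypothetical helper written for this illustration, not the package's code; the `'all'` branch follows the description above (element-wise sum of the three pooled vectors).

```python
from statistics import mean, median


def pool_genes(gene_embeddings: list[list[float]], aggregate_type: str = "max") -> list[float]:
    """Collapse a [n_genes][dim] nested list into a single [dim] vector."""
    # One tuple of per-gene values for each embedding dimension.
    columns = list(zip(*gene_embeddings))
    pooled = {
        "max": [max(col) for col in columns],
        "mean": [mean(col) for col in columns],
        "median": [median(col) for col in columns],
    }
    if aggregate_type == "all":
        # Element-wise sum of the max, mean, and median vectors.
        return [sum(parts) for parts in zip(pooled["max"], pooled["mean"], pooled["median"])]
    return pooled[aggregate_type]


# Three genes, two embedding dimensions (toy numbers).
toy = [[1.0, 4.0],
       [2.0, 0.0],
       [6.0, 2.0]]

print(pool_genes(toy, "max"))     # [6.0, 4.0]
print(pool_genes(toy, "mean"))    # [3.0, 2.0]
print(pool_genes(toy, "median"))  # [2.0, 2.0]
print(pool_genes(toy, "all"))     # [11.0, 8.0]
```

Whatever the pooling choice, the output length equals the embedding dimension, which is why `res1` has one 640-dimensional row per sample.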


In [14]:
res1

tensor([[2.4622, 0.3204, 0.8734,  ..., 2.5265, 2.0351, 2.1097],
        [0.8776, 0.6784, 0.5454,  ..., 1.5435, 1.7944, 1.3206],
        [2.0156, 1.1207, 1.0379,  ..., 2.4110, 2.2231, 2.2176],
        ...,
        [1.0094, 0.5021, 0.7578,  ..., 1.5157, 1.7562, 1.3782],
        [1.1167, 0.8199, 0.5786,  ..., 1.7397, 1.8430, 1.0931],
        [2.3170, 1.0496, 1.0248,  ..., 2.4119, 2.0700, 2.2761]])

### 9.2 Gene-Level Embeddings

Per-gene, per-sample representations that concatenate:
1. **BulkFormer embeddings** (context-aware, from layer 2)
2. **ESM2 protein embeddings** (sequence-based, static)

**Result:** Fused embeddings combining expression context and protein sequence information.

**Use cases:** Gene-level predictions (e.g., essentiality, pathway activity), multi-modal learning


In [15]:
# Extract gene-level embeddings (per-gene, per-sample)
res2: torch.Tensor = extract_features(
    model=model,
    expr_array=input_df.values[:16],  # Shape: [n_samples, n_genes]
    high_var_gene_idx=high_var_gene_idx,
    feature_type='gene_level',  # Keep gene-level resolution
    aggregate_type='all',  # Not used for gene_level, but required parameter
    device=device,
    batch_size=4,
    return_expr_value=False,
    esm2_emb=model_params['gene_emb'],  # Concatenate with ESM2 embeddings
    valid_gene_idx=valid_gene_idx  # Only valid (non-imputed) genes
)

Extracting features: 100%|██████████| 4/4 [00:05<00:00,  1.49s/it]


In [16]:
print(f"Gene-level embeddings shape: {res2.shape}")
print(f"  - {res2.shape[0]} samples")
print(f"  - {res2.shape[1]} valid genes")
print(f"  - {res2.shape[2]} embedding dimensions (BulkFormer + ESM2, concatenated)")
res2.shape

Gene-level embeddings shape: torch.Size([16, 20009, 1920])
  - 16 samples
  - 20009 valid genes
  - 1920 embedding dimensions (BulkFormer + ESM2, concatenated)


torch.Size([16, 20009, 1920])

**Result shape:** `[n_samples, n_valid_genes, combined_dim]`
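- `combined_dim = bulkformer_dim + esm2_dim`; in the run above this is 1920, which leaves 1280 dimensions for the ESM2 part given the 640-dimensional BulkFormer embedding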
- Only non-imputed genes are included
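
The fusion described in this section can be pictured as appending the same per-gene static vector to every sample's contextual vector. The sketch below uses nested lists and toy dimensions; `fuse_embeddings` is a hypothetical name for illustration and is not how `extract_features` is implemented.

```python
def fuse_embeddings(
    contextual: list[list[list[float]]], static: list[list[float]]
) -> list[list[list[float]]]:
    """Concatenate per-sample contextual vectors with per-gene static vectors.

    contextual: [n_samples][n_genes][ctx_dim]   (differs between samples)
    static:     [n_genes][static_dim]           (identical for every sample)
    returns:    [n_samples][n_genes][ctx_dim + static_dim]
    """
    return [
        [ctx_vec + static[g] for g, ctx_vec in enumerate(sample)]
        for sample in contextual
    ]


contextual = [
    [[1.0, 2.0], [3.0, 4.0]],  # sample 0: two genes, ctx_dim = 2
    [[5.0, 6.0], [7.0, 8.0]],  # sample 1
]
static = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]  # two genes, static_dim = 3

fused = fuse_embeddings(contextual, static)
print(len(fused), len(fused[0]), len(fused[0][0]))  # 2 2 5
print(fused[0][0])  # [1.0, 2.0, 0.1, 0.2, 0.3]
print(fused[1][0])  # [5.0, 6.0, 0.1, 0.2, 0.3]
```

The trailing entries are the same for a given gene in every sample, which matches the repeated trailing values visible when `res2` is printed.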


In [17]:
res2

tensor([[[ 1.0551, -1.8429,  0.1478,  ..., -0.0929, -0.0102,  0.0749],
         [ 0.1574, -1.2592, -0.0561,  ..., -0.1507, -0.0174,  0.1455],
         [ 1.0401,  0.0246,  0.6002,  ..., -0.0918,  0.0819,  0.0879],
         ...,
         [ 0.6453, -1.6589,  0.2273,  ..., -0.0316,  0.0078,  0.0943],
         [ 1.1758, -1.7251,  0.3288,  ..., -0.0891, -0.0469,  0.1897],
         [ 0.9684, -1.2767,  0.5495,  ..., -0.0528, -0.0946,  0.0670]],

        [[ 0.2741, -1.6812, -0.1459,  ..., -0.0929, -0.0102,  0.0749],
         [-1.3639, -1.1414, -0.2839,  ..., -0.1507, -0.0174,  0.1455],
         [-0.4136, -2.3218, -0.5184,  ..., -0.0918,  0.0819,  0.0879],
         ...,
         [-0.3523, -1.3333, -0.3306,  ..., -0.0316,  0.0078,  0.0943],
         [ 0.5097, -1.6272,  0.1369,  ..., -0.0891, -0.0469,  0.1897],
         [ 0.5339, -1.0964,  0.0391,  ..., -0.0528, -0.0946,  0.0670]],

        [[ 0.9490, -1.4407,  0.2156,  ..., -0.0929, -0.0102,  0.0749],
         [-0.4132, -1.9796, -0.4000,  ..., -0

### 9.3 Expression Imputation

Use BulkFormer's reconstruction head to predict/impute gene expression values.

**Use cases:**
- Impute missing genes in your data
- Denoise expression measurements
- Reconstruct full expression profiles from partial data

**Note:** Set `return_expr_value=True` to get expression predictions instead of embeddings.


In [18]:
# Extract model-predicted expression values (imputation)
res3: np.ndarray = extract_features(
    model=model,
    expr_array=input_df.values[:16],  # Shape: [n_samples, n_genes]
    high_var_gene_idx=high_var_gene_idx,
    feature_type='transcriptome_level',
    aggregate_type='all',
    device=device,
    batch_size=4,
    return_expr_value=True,  # Return expression predictions instead of embeddings
    esm2_emb=model_params['gene_emb'],
    valid_gene_idx=valid_gene_idx
)

Extracting features: 100%|██████████| 4/4 [00:06<00:00,  1.54s/it]


In [19]:
print(f"Predicted expression values shape: {res3.shape}")
print(f"  - {res3.shape[0]} samples")
print(f"  - {res3.shape[1]} genes (all 20,010 genes)")
res3.shape

Predicted expression values shape: (16, 20010)
  - 16 samples
  - 20010 genes (all 20,010 genes)


(16, 20010)

**Result shape:** `[n_samples, n_genes]`
- Model's prediction of log-transformed TPM values
- Can be compared with input to assess reconstruction quality
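
One simple way to assess reconstruction quality is a per-sample Pearson correlation between observed and predicted values, restricted to genes that were actually measured so the `-10` placeholders do not distort the score. The sketch below is dependency-free and uses toy numbers; `pearson` and `reconstruction_score` are hypothetical helpers, and this metric is a suggestion rather than the evaluation used by the BulkFormer authors.

```python
import math


def pearson(x: list[float], y: list[float]) -> float:
    """Pearson correlation of two equal-length, non-constant sequences."""
    n = len(x)
    mean_x, mean_y = sum(x) / n, sum(y) / n
    cov = sum((a - mean_x) * (b - mean_y) for a, b in zip(x, y))
    var_x = sum((a - mean_x) ** 2 for a in x)
    var_y = sum((b - mean_y) ** 2 for b in y)
    return cov / math.sqrt(var_x * var_y)


def reconstruction_score(
    observed: list[float], predicted: list[float], valid_idx: list[int]
) -> float:
    """Correlate observed and predicted values over measured genes only."""
    obs = [observed[i] for i in valid_idx]
    pred = [predicted[i] for i in valid_idx]
    return pearson(obs, pred)


observed = [1.0, -10.0, 3.0, 2.0]   # -10 marks a placeholder (imputed) gene
predicted = [1.2, 0.4, 2.9, 2.1]
valid_idx = [0, 2, 3]               # skip the placeholder at position 1

score = reconstruction_score(observed, predicted, valid_idx)
print(f"Pearson r over measured genes: {score:.3f}")
```

On the real arrays, the same idea would likely be applied row by row to `input_df.values` and `res3` with `valid_gene_idx`, for example via `scipy.stats.pearsonr`, which this notebook already imports.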


In [20]:
res3

array([[1.2735547 , 0.17127015, 0.32037833, ..., 0.20279405, 0.19119827,
        0.22872534],
       [3.0849693 , 0.2441585 , 3.2666268 , ..., 3.9643104 , 0.2724411 ,
        0.34996656],
       [1.3099808 , 0.30123997, 0.7367003 , ..., 0.3702867 , 0.30796665,
        0.3844244 ],
       ...,
       [0.9285814 , 0.20617509, 1.6781065 , ..., 0.24234761, 0.23101696,
        0.29190505],
       [4.0148835 , 1.7378029 , 3.1773732 , ..., 3.713794  , 0.44067258,
        0.51471543],
       [3.4331362 , 0.31305268, 4.104111  , ..., 0.33597103, 0.3229271 ,
        0.37848967]], shape=(16, 20010), dtype=float32)

## 10. Summary and Next Steps

You've successfully extracted three types of features from bulk RNA-seq data:

1. **Transcriptome-level embeddings** (`res1`): Compact sample representations for downstream ML tasks
2. **Gene-level embeddings** (`res2`): Rich per-gene features combining expression context and protein sequence
3. **Expression predictions** (`res3`): Model-imputed gene expression values

### Downstream Applications

These features can be used for:
- **Drug response prediction:** Train ML models using transcriptome embeddings
- **Disease classification:** Use embeddings to classify cancer subtypes, disease stages, etc.
- **Cross-dataset integration:** BulkFormer embeddings can bridge different RNA-seq datasets
- **Gene function prediction:** Use gene-level embeddings for functional annotation
- **Data imputation:** Use predicted expressions to fill missing values

### Saving Results

```python
# Save embeddings (paths are relative to project root)
output_dir = project_root / 'results'
output_dir.mkdir(exist_ok=True)

torch.save(res1, output_dir / 'transcriptome_embeddings.pt')
torch.save(res2, output_dir / 'gene_level_embeddings.pt')

# Save predictions as CSV
pred_df = pd.DataFrame(res3, index=input_df.index[:16], columns=input_df.columns)
pred_df.to_csv(output_dir / 'expression_predictions.csv')
```

### References

- **BulkFormer paper:** [bioRxiv 2025.06.11.659222](https://www.biorxiv.org/content/10.1101/2025.06.11.659222v1.full)
- **GTEx Consortium:** [https://gtexportal.org](https://gtexportal.org)
- **ESM2 protein embeddings:** Lin et al., Science 2023
