In [4]:
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from scipy.stats import pearsonr
import matplotlib.pyplot as plt
from torch_geometric.nn import GATv2Conv
from performer_pytorch import Performer


# ============================================================
# 1️⃣ ROTARY EXPRESSION EMBEDDING
# ============================================================
class PositionalExprEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.mask_token_id = -10  # same as BulkFormer
        self.inv_freq = nn.Parameter(
            1.0 / (100 ** (torch.arange(0, dim, 2).float() / dim)),
            requires_grad=False
        )

    def forward(self, x):
        # x : [batch, genes]
        mask_idx = (x == self.mask_token_id).nonzero()

        rot = torch.einsum("bi,j->bij", x, self.inv_freq)
        rot = torch.cat((rot.sin(), rot.cos()), dim=-1)

        # zero out masked tokens
        if len(mask_idx) > 0:
            rot[mask_idx[:, 0], mask_idx[:, 1]] = 0

        return rot

# ============================================================
# 2️⃣ GBFormer (GAT + Performer blocks)
# ============================================================
class GBFormer(nn.Module):
    def __init__(self, dim, gene_length,
                 bin_head=2, full_head=2, bins=6, p_repeat=1):
        super().__init__()

        self.dim = dim
        self.bins = bins
        self.gene_length = gene_length

        # GAT module
        self.g = GATv2Conv(dim, dim, add_self_loops=False)

        # bin selector
        self.which_b = nn.Linear(dim, 1)

        # small Performer heads
        self.b = nn.ModuleList([
            Performer(
                dim=dim,
                heads=bin_head,
                dim_head=32,
                depth=1,
                attn_dropout=0.1,
                reversible=False
            )
            for _ in range(bins)
        ])

        self.f = nn.ModuleList([
            Performer(
                dim=dim,
                heads=full_head,
                dim_head=32,
                depth=1,
                attn_dropout=0.1,
                reversible=False
            )
            for _ in range(p_repeat)
        ])

        self.ln = nn.LayerNorm(dim)

    def forward(self, x, edge_index):
        # x: [B, G, E]
        B, G, E = x.shape

        x = self.ln(x)

        # --- SAFE GRAPH UPDATE (no in-place) ---
        x_graph = []
        for b in range(B):
            gx = self.g(x[b], edge_index)
            x_graph.append(x[b] + gx)
        x = torch.stack(x_graph, dim=0)

        # --- choose bins ---
        scores = self.which_b(x).squeeze(-1)      # [B, G]
        order = torch.argsort(scores, dim=1, descending=True)  # [B, G]
        order_exp = order.unsqueeze(-1).expand(-1, -1, E)

        # reorder (no in-place)
        x_sorted = torch.gather(x, 1, order_exp)

        # split
        n = (G - 1) // self.bins + 1
        chunks = torch.split(x_sorted, n, dim=1)

        # run performer per chunk
        outs = []
        for chunk, layer in zip(chunks, self.b):
            outs.append(layer(chunk))
        xs = torch.cat(outs, dim=1)

        # --- UNSORT (no inplace scatter_) ---
        out = torch.zeros_like(xs)
        out = out.scatter(1, order_exp, xs)   # **THIS IS NOT INPLACE**

        # --- global Performer ---
        for layer in self.f:
            out = layer(out)

        return out




# ============================================================
# 3️⃣ BULKFORMER MODEL (Modified)
# ============================================================
class BulkFormer(nn.Module):
    def __init__(self, dim, graph, gene_emb, gene_length,
                 bin_head=2, full_head=2, bins=10, gb_repeat=2, p_repeat=1):
        super().__init__()

        self.dim = dim
        self.graph = graph
        self.gene_length = gene_length

        # 320-dim ESM2 embedding
        self.gene_emb = nn.Parameter(gene_emb)

        self.gene_emb_proj = nn.Sequential(
            nn.Linear(gene_emb.shape[1], 4 * dim),
            nn.ReLU(),
            nn.Linear(4 * dim, dim)
        )

        self.expr_emb = PositionalExprEmbedding(dim)

        # AE latent → sample context vector
        self.ae_enc = nn.Sequential(
            nn.Linear(gene_length, 4 * dim),
            nn.ReLU(),
            nn.Linear(4 * dim, dim)
        )

        self.x_proj = nn.Sequential(
            nn.Linear(dim, 4 * dim),
            nn.ReLU(),
            nn.Linear(4 * dim, dim)
        )

        self.blocks = nn.ModuleList([
            GBFormer(dim, gene_length,
                     bin_head=bin_head,
                     full_head=full_head,
                     bins=bins,
                     p_repeat=p_repeat)
            for _ in range(gb_repeat)
        ])

        self.ln = nn.LayerNorm(dim)

        # Final head predicts expression
        self.head = nn.Sequential(
            nn.Linear(dim, 4 * dim),
            nn.ReLU(),
            nn.Linear(4 * dim, 1)
        )

    def forward(self, x, ae_latent=None):
        # x: [B, G]
        B, G = x.shape

        gene_tok = self.gene_emb_proj(self.gene_emb).unsqueeze(0)  # [1, G, dim]
        expr_tok = self.expr_emb(x)                                # [B, G, dim]
        ae_tok = ae_latent.unsqueeze(1)                            # [B, 1, dim]

        x = expr_tok + gene_tok + ae_tok
        x = self.x_proj(x)

        for block in self.blocks:
            x = block(x, self.graph)

        x = self.ln(x)

        out = self.head(x).squeeze(-1)  # predict masked expression
        return out



  from .autonotebook import tqdm as notebook_tqdm


In [7]:
# Load expression validation set
val_expr = pd.read_parquet(
    "./data/archs4/processed_short_proteins/val_expr_logtpm_short.parquet"
)

X_val = torch.tensor(val_expr.T.values.astype("float32"), device=device)
N_val, G = X_val.shape

print("Validation expression:", X_val.shape)

# Load ESM2 gene identity embeddings
esm2_raw = torch.load("./data/embeddings/esm2_t6_8M_UR50D_gene_embeddings.pt")
esm2 = esm2_raw["embeddings"].float().to(device)
print("ESM2:", esm2.shape)

# # Load AE sample latent embeddings (if needed)
# ae_val = torch.load("./data/embeddings/ae_gene_latents_320_val_set.pt")
# ae_val = torch.tensor(ae_val, dtype=torch.float32).to(device)
# print("AE latent:", ae_val.shape)


Validation expression: torch.Size([9557, 19357])
ESM2: torch.Size([19357, 320])


In [8]:
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from scipy.stats import pearsonr
import matplotlib.pyplot as plt

# -------------------------------------------------------------
# Load model (assumes BulkFormer class already imported)
# -------------------------------------------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"

model = BulkFormer(
    dim=320,
    graph=torch.load("./graph/edge_index_top20.pt").long().to(device),
    gene_emb=esm2,          # <-- MUST MATCH the same tensor shape
    gene_length=G,
    gb_repeat=1,
    bins=1,
    bin_head=2,
    full_head=2,
    p_repeat=1
).to(device)

model.load_state_dict(torch.load("bulkformer_gbformer.pt", map_location=device))
model.eval()


BulkFormer(
  (gene_emb_proj): Sequential(
    (0): Linear(in_features=320, out_features=1280, bias=True)
    (1): ReLU()
    (2): Linear(in_features=1280, out_features=320, bias=True)
  )
  (expr_emb): PositionalExprEmbedding()
  (ae_enc): Sequential(
    (0): Linear(in_features=19357, out_features=1280, bias=True)
    (1): ReLU()
    (2): Linear(in_features=1280, out_features=320, bias=True)
  )
  (x_proj): Sequential(
    (0): Linear(in_features=320, out_features=1280, bias=True)
    (1): ReLU()
    (2): Linear(in_features=1280, out_features=320, bias=True)
  )
  (blocks): ModuleList(
    (0): GBFormer(
      (g): GATv2Conv(320, 320, heads=1)
      (which_b): Linear(in_features=320, out_features=1, bias=True)
      (b): ModuleList(
        (0): Performer(
          (net): SequentialSequence(
            (layers): ModuleList(
              (0): ModuleList(
                (0): PreLayerNorm(
                  (norm): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
              

In [9]:
expr = pd.read_parquet(
    "./data/archs4/processed_short_proteins/test_expr_logtpm_short.parquet"
).T.values.astype("float32")

X = torch.tensor(expr).to(device)    # shape: [samples, genes]
N, G = X.shape
print(X.shape)


torch.Size([9446, 19357])


In [10]:
def apply_mask(x, mask_ratio=0.15):
    mask = (torch.rand_like(x) < mask_ratio)
    x_masked = x.clone()
    x_masked[mask] = -10       # BulkFormer mask token
    return x_masked, mask


In [11]:
def impute_bulkformer(model, X, mask_ratio=0.15):
    X_masked, mask = apply_mask(X, mask_ratio)
    with torch.no_grad():
        # AE latents = zeros if you are not using AE for inference
        ae_latent = torch.zeros((X.shape[0], 320), device=device)
        pred = model(X_masked, ae_latent)

    # extract masked positions only
    true_vals = X[mask].detach().cpu().numpy()
    pred_vals = pred[mask].detach().cpu().numpy()

    return pearsonr(true_vals, pred_vals)[0]


In [12]:
gene_means = X.mean(dim=0, keepdim=True)

def impute_mean(X, mask_ratio=0.15):
    _, mask = apply_mask(X, mask_ratio)
    true_vals = X[mask].cpu().numpy()
    pred_vals = gene_means.repeat(X.shape[0], 1)[mask].cpu().numpy()
    return pearsonr(true_vals, pred_vals)[0]


In [13]:
gene_medians = X.median(dim=0, keepdim=True).values

def impute_median(X, mask_ratio=0.15):
    _, mask = apply_mask(X, mask_ratio)
    true_vals = X[mask].cpu().numpy()
    pred_vals = gene_medians.repeat(X.shape[0], 1)[mask].cpu().numpy()
    return pearsonr(true_vals, pred_vals)[0]


In [None]:
import time

def impute_bulkformer_batched_gpu(model, X, mask_ratio=0.15, batch_size=64):
    """
    Fast GPU-optimized batched imputation.
    - Larger batch_size (64) to maximize GPU utilization
    - Vectorized mask operations on GPU
    - Minimal CPU transfers
    """
    model.eval()
    device = next(model.parameters()).device
    
    batch_start = time.time()
    N, G = X.shape
    
    # Create mask and masked input on GPU
    print(f"  [GPU] Creating masks... (shape: {X.shape})")
    mask = (torch.rand_like(X) < mask_ratio)
    X_masked = X.clone()
    X_masked[mask] = -10
    
    # Pre-allocate output on GPU
    ae_latent = torch.zeros((N, 320), device=device)
    preds = torch.zeros_like(X, device=device)
    
    print(f"  [GPU] Processing {N} samples in batches of {batch_size}...")
    num_batches = (N + batch_size - 1) // batch_size
    
    with torch.no_grad():
        for batch_idx, i in enumerate(range(0, N, batch_size)):
            batch_time = time.time()
            end_idx = min(i + batch_size, N)
            batch_size_actual = end_idx - i
            
            xb = X_masked[i:end_idx]
            aeb = ae_latent[i:end_idx]
            
            # Forward pass on GPU
            pred_b = model(xb, aeb)
            preds[i:end_idx] = pred_b
            
            batch_elapsed = time.time() - batch_time
            throughput = batch_size_actual / batch_elapsed
            print(f"    Batch {batch_idx+1}/{num_batches} | "
                  f"Samples {i+1}-{end_idx} | Time: {batch_elapsed:.2f}s | "
                  f"Throughput: {throughput:.0f} samples/s")
    
    # Extract masked positions only (keep on GPU until final computation)
    true_vals = X[mask]
    pred_vals = preds[mask]
    
    # Compute correlation on GPU for speed
    pcc = np.corrcoef(true_vals.detach().cpu().numpy(), 
                      pred_vals.detach().cpu().numpy())[0, 1]
    
    total_time = time.time() - batch_start
    print(f"  ✓ Total inference time: {total_time:.2f}s ({N/total_time:.0f} samples/s)")
    
    return {
        "true": true_vals.detach().cpu().numpy(),
        "pred": pred_vals.detach().cpu().numpy(),
        "pcc": pcc,
        "time": total_time
    }


# Run optimized GPU inference
print("\n[INFERENCE] Running GPU-optimized BulkFormer imputation...")
results = {}
results["BulkFormer"] = impute_bulkformer_batched_gpu(model, X, mask_ratio=0.15, batch_size=64)
results["Mean"]       = impute_mean(X, 0.15)
results["Median"]     = impute_median(X, 0.15)

print("\n[RESULTS] Imputation Performance (15% masked):")
for method, res in results.items():
    if isinstance(res, dict):
        print(f"  {method:12} | PCC: {res['pcc']:.4f} | Time: {res.get('time', 'N/A')}")
    else:
        print(f"  {method:12} | PCC: {res:.4f}")

results



[INFERENCE] Running GPU-optimized BulkFormer imputation...
  [GPU] Creating masks... (shape: torch.Size([9446, 19357]))
  [GPU] Processing 9446 samples in batches of 64...


OutOfMemoryError: CUDA out of memory. Tried to allocate 5.91 GiB. GPU 0 has a total capacity of 24.00 GiB of which 0 bytes is free. Process 66668 has 17179869184.00 GiB memory in use. Of the allocated memory 18.40 GiB is allocated by PyTorch, and 4.62 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
plt.figure(figsize=(6,4))
plt.bar(results.keys(), results.values(), color=["steelblue","gray","darkgray"])
plt.ylabel("Pearson r")
plt.title("Imputation performance (15% missing)")
plt.ylim(0,1)
plt.show()


In [None]:
print("\n[BENCHMARK] Testing masking ratios with GPU acceleration...\n")

mask_rates = [0.05, 0.15, 0.25, 0.35]
pcc_curve = []
times = []

for r in mask_rates:
    print(f"Testing mask ratio: {int(r*100)}%")
    start = time.time()
    result = impute_bulkformer_batched_gpu(model, X, mask_ratio=r, batch_size=64)
    elapsed = time.time() - start
    
    pcc_curve.append(result["pcc"])
    times.append(elapsed)
    print(f"  → PCC: {result['pcc']:.4f} | Total time: {elapsed:.2f}s\n")

# Summary
print("\n[BENCHMARK SUMMARY]")
print("Mask % | PCC    | Time (s) | Throughput")
print("-------|--------|----------|----------")
for r, p, t in zip(mask_rates, pcc_curve, times):
    throughput = len(X) / t
    print(f"{int(r*100):5}% | {p:.4f} | {t:8.2f} | {throughput:8.0f} samples/s")


In [None]:
plt.figure(figsize=(6,4))
plt.plot([r*100 for r in mask_rates], pcc_curve, marker="o")
plt.xlabel("Masking ratio (%)")
plt.ylabel("Pearson r")
plt.title("BulkFormer imputation vs missingness")
plt.ylim(0,1)
plt.grid()
plt.show()
