In [1]:
# ░░░  Cell: SpaFormer embedding generation  ░░░
import os, sys, numpy as np, torch, dgl
sys.path.append('.')                           # make "spaformer/" visible

from spaformer.edcoder import PreModel

# -- 1. load prepared arrays -------------------------------------------------
# Expression matrix: read from disk, then cast to float32 (expected layout: (N, G)).
X = np.load("Data/spaformer_prepared/X.npy")
X = X.astype(np.float32)

# Edge list: read from disk, then cast to int64 (expected layout: (2, E)).
edges = np.load("Data/spaformer_prepared/edges.npy")
edges = edges.astype(np.int64)

# Quick sanity print: rows / columns of X and the number of edge columns.
print(f"Cells {X.shape[0]:>5}   genes {X.shape[1]:>4}   graph edges {edges.shape[1]}")

# -- 2. build a DGL graph -----------------------------------------------------
# Row 0 of `edges` holds the source node ids, row 1 the destination node ids.
src_nodes = edges[0]
dst_nodes = edges[1]

# One graph node per row of X; a self-loop is then appended for every node.
# NOTE(review): dgl.add_self_loop does not de-duplicate -- if edges.npy already
# contains self-loops they would appear twice; confirm against the data prep.
g = dgl.add_self_loop(
    dgl.graph((src_nodes, dst_nodes), num_nodes=X.shape[0])
)

# Everything in this cell runs on the CPU.
device = torch.device('cpu')
g = g.to(device)

# Keep the feature matrix on the graph as well, under the 'feat' node key,
# for encoders that read their inputs from g.ndata.
g.ndata['feat'] = torch.from_numpy(X).to(device)

# -- 3. instantiate **encoder-only** SpaFormer -------------------------------
# All constructor arguments are collected in one mapping so the configuration
# can be read (or logged) in a single place before the model is built.
model_kwargs = dict(
    # input width = number of columns in X
    in_dim=X.shape[1],
    num_hidden=256,
    num_layers=4,
    nhead=8,
    nhead_out=8,
    activation='gelu',
    feat_drop=0.1,
    attn_drop=0.1,
    negative_slope=0.2,
    norm='layernorm',
    # 'gin' was chosen as a lightweight encoder that needs no extra kwargs
    encoder_type='gin',
    # the decoder is only a stub here; its output is never used in this cell
    decoder_type='linear',
    loss_fn='mse',
    latent_dim=256,
    # both masking rates are set to zero -- presumably disables masking; verify
    mask_node_rate=0.0,
    mask_feature_rate=0.0,
    objective='ae',
)

model = PreModel(**model_kwargs)
model = model.to(device)
# Evaluation mode (assuming PreModel follows torch.nn.Module semantics).
model = model.eval()

# -- 4. (optional) load checkpoint & discard decoder -------------------------
def encoder_state_from_checkpoint(state, ddp_prefix='module.', drop_prefix='decoder'):
    """Return a new dict holding only the non-decoder entries of *state*.

    Each key first loses a single *leading* ``ddp_prefix`` (the prefix left
    behind by DistributedDataParallel wrappers); occurrences of that text in
    the middle of a key are left alone.  The cleaned key is then dropped when
    it starts with ``drop_prefix``.  Testing the cleaned key (rather than the
    raw one) is what makes ``module.decoder.*`` entries get filtered as well.

    Args:
        state: mapping of parameter name -> value, as stored in a checkpoint.
        ddp_prefix: leading text to strip from every key.
        drop_prefix: keys starting with this text (after stripping) are skipped.

    Returns:
        A fresh dict; *state* itself is not modified.
    """
    cleaned = {}
    for key, value in state.items():
        if key.startswith(ddp_prefix):
            key = key[len(ddp_prefix):]
        if key.startswith(drop_prefix):
            continue
        cleaned[key] = value
    return cleaned


ckpt = "checkpoints/spaformer_encoder_pretrain.pth"
if os.path.exists(ckpt):
    # NOTE(security): torch.load unpickles the file -- only load checkpoints
    # that come from a trusted source.
    state = torch.load(ckpt, map_location=device)
    state = encoder_state_from_checkpoint(state)

    # strict=False tolerates the deliberately missing decoder weights, but it
    # would also hide genuine mismatches, so the leftovers are reported below.
    load_result = model.load_state_dict(state, strict=False)
    print("✓ encoder checkpoint loaded")

    missing_non_decoder = [k for k in load_result.missing_keys
                           if not k.startswith('decoder')]
    if missing_non_decoder or load_result.unexpected_keys:
        print(f"  ⚠ {len(missing_non_decoder)} non-decoder keys missing from checkpoint, "
              f"{len(load_result.unexpected_keys)} unexpected keys ignored")
else:
    # Without a checkpoint the model keeps the weights it was constructed with,
    # so make that visible instead of silently producing embeddings from them.
    print(f"⚠ checkpoint not found at {ckpt!r} – embeddings will come from an untrained model")

# -- 5. forward pass (embedding extraction) ----------------------------------
# Node features are handed to the model explicitly, as a tensor on `device`.
node_feats = torch.from_numpy(X).to(device)

# Gradients are not needed for inference, so autograd tracking is switched off.
with torch.no_grad():
    z = model.embed(g, node_feats)     # expected result: (N, latent_dim)

print("Latent embeddings:", z.shape)

# -- 6. save -----------------------------------------------------------------
out_path = "Data/spaformer_embeddings.npy"

# Bring the embeddings back to host memory as a NumPy array, then write them out.
embeddings = z.cpu().numpy()
np.save(out_path, embeddings)

print("✓ saved to", out_path)

  @custom_fwd(cast_inputs=th.float16)
  def backward(ctx, dZ):
  @custom_fwd(cast_inputs=th.float16)
  def backward(ctx, *dZ):
  @custom_fwd(cast_inputs=th.float16)
  def backward(ctx, dZ):
  @custom_fwd(cast_inputs=th.float16)
  def backward(ctx, *dZ):
  @custom_fwd(cast_inputs=th.float16)
  def backward(ctx, grad_out):
  @custom_fwd(cast_inputs=th.float16)
  def backward(ctx, *grad_out):
  @custom_fwd(cast_inputs=th.float16)
  def backward(ctx, dy):
  @custom_fwd(cast_inputs=th.float16)
  def backward(ctx, dy):
  @custom_fwd(cast_inputs=th.float16)
  @custom_fwd(cast_inputs=th.float16)


Cells  4992   genes 3004   graph edges 29952


  assert input.numel() == input.storage().size(), "Cannot convert view " \


Latent embeddings: torch.Size([4992, 256])
✓ saved to Data/spaformer_embeddings.npy
