In [None]:
# Load the edge index tensor that was extracted from KNN
# edge_index has shape [2, num_edges] where:
#   - edge_index[0] = source node indices
#   - edge_index[1] = target node indices
# With top-20 KNN, num_edges should be num_genes * 20 (here 19357 * 20 = 387140).
import torch

edge_index = torch.load("./graph/edge_index_top20.pt")

# Fail fast if the saved tensor is not in the [2, num_edges] COO layout the GAT
# layers below consume; a transposed [num_edges, 2] tensor would otherwise only
# surface as a confusing error (or wrong graph) inside message passing.
assert edge_index.dim() == 2 and edge_index.size(0) == 2, (
    f"expected edge_index of shape [2, num_edges], got {tuple(edge_index.shape)}"
)
print(edge_index.shape)


torch.Size([2, 387140])


In [None]:
# Initialize node features for the graph
# num_genes = number of node slots implied by edge_index (max node id + 1).
# NOTE(review): this assumes node ids are 0..N-1 and that the highest-numbered
# gene appears in at least one edge; here it yields 19357, matching the ESM2
# gene count loaded further down.
# We create random embeddings of size [num_genes, d_model] to simulate gene embeddings
num_genes = edge_index.max().item() + 1
d_model = 128   # small for testing

# Seeded, local generator: the placeholder features (and therefore the
# smoke-test output below) are reproducible across kernel restarts, and
# torch's global RNG state is left untouched.
SEED = 0
feature_rng = torch.Generator().manual_seed(SEED)

# x shape: [num_genes, d_model] - each gene gets a random feature vector
x = torch.randn(num_genes, d_model, generator=feature_rng)


In [None]:
# Define a Graph Attention Network (GAT) with 2 layers
# GATv2Conv uses multi-head attention to aggregate information from neighboring nodes
# heads=4: use 4 attention heads for richer representational capacity
# concat=False: average the heads instead of concatenating (keeps dimensions constant)
import torch.nn as nn
from torch_geometric.nn import GATv2Conv

class GeneGAT(nn.Module):
    """Two-layer GATv2 encoder that refines per-gene embeddings over a gene graph.

    Both layers map ``dim -> dim`` and average their attention heads
    (``concat=False``), so the output has the same shape as the input node
    features: ``[num_nodes, dim]``.

    Args:
        dim: Input and output feature size per node.
        heads: Number of attention heads per layer (default 4). Because heads
            are averaged rather than concatenated, changing this does not
            change the output width.
    """

    def __init__(self, dim, heads=4):
        super().__init__()
        # First GAT layer: attends over the KNN edges, outputs [num_nodes, dim].
        self.gat1 = GATv2Conv(dim, dim, heads=heads, concat=False)
        # Second GAT layer: refines the embeddings again, outputs [num_nodes, dim].
        self.gat2 = GATv2Conv(dim, dim, heads=heads, concat=False)

    def forward(self, x, edge_index):
        """Run both attention layers over the graph.

        Args:
            x: Node features, ``[num_nodes, dim]``.
            edge_index: Edges in COO form, ``[2, num_edges]``
                (row 0 = source nodes, row 1 = target nodes).

        Returns:
            Refined node features, ``[num_nodes, dim]``.
        """
        # ReLU only between the layers; the second layer's output is returned
        # without an activation.
        hidden = self.gat1(x, edge_index).relu()
        return self.gat2(hidden, edge_index)


  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Instantiate the GAT model and run a forward pass
# This tests whether the KNN-extracted edges work correctly with the GATv2 layers
model = GeneGAT(d_model)

# Smoke test only: no_grad avoids building the autograd graph, which would
# otherwise keep per-edge, per-head attention activations for every edge
# alive in memory for a backward pass that never happens.
with torch.no_grad():
    out = model(x, edge_index)

# GAT preserves node count and (heads averaged, concat=False) feature width,
# so the output must have exactly the input's shape.
assert out.shape == x.shape, f"shape mismatch: {tuple(x.shape)} -> {tuple(out.shape)}"
print(x.shape, "→", out.shape)


torch.Size([19357, 128]) → torch.Size([19357, 128])


In [5]:
# Peek at the raw ESM2 checkpoint before using it: report its container type
# and the keys under which the embeddings and gene ordering are stored.
esm2_raw = torch.load(
    "./data/embeddings/esm2_t6_8M_UR50D_gene_embeddings.pt"
)
for summary in (type(esm2_raw), esm2_raw.keys()):
    print(summary)

<class 'dict'>
dict_keys(['embeddings', 'genes', 'model'])


In [6]:
print("\n[DATA] Loading ESM2 protein embeddings...")
raw = torch.load("./data/embeddings/esm2_t6_8M_UR50D_gene_embeddings.pt")

esm2 = raw["embeddings"].float()     # main embedding matrix, cast to float32
esm2_genes = raw["genes"]            # gene names giving the row order of esm2

# Later cells index esm2 rows by position (gene index), so the matrix and the
# gene list must line up one-to-one; catch a mismatched checkpoint here.
assert esm2.shape[0] == len(esm2_genes), (
    f"{esm2.shape[0]} embedding rows vs {len(esm2_genes)} gene names"
)

print("  ✓ ESM2 embeddings:", esm2.shape)
print("  ✓ ESM2 gene count:", len(esm2_genes))



[DATA] Loading ESM2 protein embeddings...
  ✓ ESM2 embeddings: torch.Size([19357, 320])
  ✓ ESM2 gene count: 19357
  ✓ ESM2 embeddings: torch.Size([19357, 320])
  ✓ ESM2 gene count: 19357


In [8]:
import pandas as pd

archs4_genes = pd.read_csv(
    "./data/archs4/processed_short_proteins/train_gene_order_short.csv"
)["gene_symbol"].tolist()

print(esm2_genes[:10])
print(archs4_genes[:10])
print(len(set(esm2_genes) & set(archs4_genes)))

# Set overlap only proves both lists contain the same genes. Pairing ESM2 rows
# with ARCHS4 rows by position additionally requires the *same order*.
print("Same gene order:", list(esm2_genes) == archs4_genes)



['MT-ND1', 'MT-ND2', 'MT-CO1', 'MT-CO2', 'MT-ATP8', 'MT-ATP6', 'MT-CO3', 'MT-ND3', 'MT-ND4L', 'MT-ND4']
['MT-ND1', 'MT-ND2', 'MT-CO1', 'MT-CO2', 'MT-ATP8', 'MT-ATP6', 'MT-CO3', 'MT-ND3', 'MT-ND4L', 'MT-ND4']
19357


In [9]:
import torch

# Load file
path = "./data/embeddings/esm2_t6_8M_UR50D_gene_embeddings.pt"
data = torch.load(path)

# Extract embeddings and cast to float32 exactly like the main ESM2 loading
# cell does. Without the cast, re-running this cell would silently replace the
# float32 `esm2` with whatever dtype the checkpoint stores, for every later cell.
esm2 = data["embeddings"].float()        # tensor [19357, 320]

print("ESM2 shape:", esm2.shape)

# Pick a gene index (for example: gene 0, which is MT-ND1 in this ordering)
idx = 0
vec = esm2[idx]

print(f"Embedding for gene index {idx}:")
print(vec)
print("Vector shape:", vec.shape)


ESM2 shape: torch.Size([19357, 320])
Embedding for gene index 0:
tensor([ 9.2978e-02,  6.5294e-02,  1.0580e-01,  1.7076e-01,  2.4781e-01,
        -1.1651e-01,  1.5424e-01,  1.3710e-02,  3.6349e-02, -2.3073e-01,
         1.5996e-01, -6.3580e-02, -1.9081e-02,  1.7054e-01,  2.7768e-01,
        -2.0805e-01,  5.3693e-02,  9.7801e-02, -7.7328e-02,  1.6150e-01,
         6.6416e-02,  4.3377e-02,  1.2056e-01,  5.8363e-02,  1.8182e-01,
         1.6818e-01, -5.2231e-02, -4.0707e-02,  2.4665e-01,  2.2673e-01,
        -2.4801e-02, -2.1880e-01,  7.9713e-02,  1.3767e-01,  2.4228e-01,
        -1.8152e-01,  1.5154e-01,  2.8462e-02,  1.5701e-01,  1.9493e-01,
        -6.5894e-02, -1.6705e-02,  1.0056e-01,  2.9792e-01,  7.9673e-03,
        -5.8195e-03, -7.1435e-01,  1.1258e-01,  1.3825e-01, -1.4130e-01,
        -1.1186e-01,  2.9676e-03, -4.3466e-02,  1.5705e-01, -2.6676e-02,
        -3.0482e-01,  7.9775e-02, -1.0222e-01, -1.2259e-01,  3.7344e-02,
        -5.3718e-02, -2.8479e-02, -5.1508e+00, -3.5766e-02,

In [11]:
X_df = pd.read_parquet("./data/archs4/processed_short_proteins/test_expr_logtpm_short.parquet")
print("Shape:", X_df.shape)
# Genes are the ROWS: 19357 rows equals the ESM2 / ARCHS4 gene count, so the
# 9446 columns are the other axis (samples) - the old "(genes)" label on the
# column count was wrong.
print("Index length (rows = genes):", len(X_df.index))
print("Column length (samples):", len(X_df.columns))


Shape: (19357, 9446)
Index length (rows): 19357
Column length (genes): 9446
