In [1]:
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torch
from tqdm import tqdm
import scanpy as sc
import pandas as pd
from collections import defaultdict


## Load Data

In [12]:
train_data = torch.load('../data/train_data.pt')
val_data = torch.load('../data/val_data.pt')

In [3]:
# Load gene2go
gene2go = pd.read_csv("../../decode-agent/data/gene2go_human.tsv", sep="\t")
gene2go = gene2go[["GeneID", "GO_ID"]]

geneid_to_uniprot = torch.load('../data/geneid_to_uniprot.pt')

gene2go['GeneID'] = gene2go['GeneID'].astype(str)
gene2go['uniprot_id'] = gene2go['GeneID'].map(geneid_to_uniprot)
gene2go.dropna(axis=0, inplace=True)
gene2go.drop_duplicates(inplace=True)
gene2go

Unnamed: 0,GeneID,GO_ID,uniprot_id
0,1,GO:0002764,P04217
1,1,GO:0005576,P04217
5,1,GO:0005615,P04217
6,1,GO:0005886,P04217
7,1,GO:0031093,P04217
...,...,...,...
489840,128854680,GO:0005737,Q6B8I1
489842,128854680,GO:0008138,Q6B8I1
489843,128854680,GO:0016311,Q6B8I1
489844,128854680,GO:0016787,Q6B8I1


In [4]:
def build_ground_truth_map(gene2go_df):
    mapping = defaultdict(set)
    for _, row in gene2go_df.iterrows():
        mapping[row['GO_ID']].add(row['uniprot_id'])
    return mapping

In [5]:
class GeneGODataset(Dataset):
    def __init__(self, data):
        self.data = data  # list of (go_embed, esm_embed, go_id, uniprot_id)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        go_embed, esm_embed, go_id, uniprot_id = self.data[idx]
        return {
            "go_embed": go_embed,
            "esm_embed": esm_embed,
            "go_id": go_id,
            "uniprot_id": uniprot_id,
        }

In [7]:
class DualEncoder(nn.Module):
    def __init__(self, go_dim=768, esm_dim=1280, hidden_dim=512):
        super().__init__()
        self.go_proj = nn.Sequential(
            nn.Linear(go_dim, hidden_dim),
            nn.ReLU(),
            nn.LayerNorm(hidden_dim),
        )
        self.esm_proj = nn.Sequential(
            nn.Linear(esm_dim, hidden_dim),
            nn.ReLU(),
            nn.LayerNorm(hidden_dim),
        )

    def forward(self, go_embed, esm_embed):
        go_latent = self.go_proj(go_embed)
        esm_latent = self.esm_proj(esm_embed)
        return go_latent, esm_latent

In [8]:
def contrastive_loss(go_latents, esm_latents, temperature=0.07):
    # Normalize embeddings
    go_latents = F.normalize(go_latents, dim=1)
    esm_latents = F.normalize(esm_latents, dim=1)

    # Compute cosine similarity
    logits = torch.matmul(go_latents, esm_latents.T) / temperature

    # Labels: correct pair is at same index
    labels = torch.arange(go_latents.size(0), device=go_latents.device)
    loss_i = F.cross_entropy(logits, labels)
    loss_j = F.cross_entropy(logits.T, labels)
    return (loss_i + loss_j) / 2

In [14]:
def cosine_sim(a, b):
    return F.normalize(a, dim=-1) @ F.normalize(b, dim=-1).T

def evaluate(model, val_dataloader, gene2go_df, k=10):
    model.eval()
    all_go_latents = []
    all_go_ids = []
    all_protein_latents = []
    all_protein_ids = []

    # Build lookup from gene2go
    ground_truth = build_ground_truth_map(gene2go_df)

    with torch.no_grad():
        for batch in val_dataloader:
            go = batch["go_embed"]
            esm = batch["esm_embed"]
            go_ids = batch["go_id"]
            protein_ids = batch["uniprot_id"]

            go_latent, esm_latent = model(go, esm)

            all_go_latents.append(go_latent)
            all_go_ids.extend(go_ids)
            all_protein_latents.append(esm_latent)
            all_protein_ids.extend(protein_ids)

    go_latents = torch.cat(all_go_latents)  # (N_go, D)
    protein_latents = torch.cat(all_protein_latents)  # (N_protein, D)

    sim_matrix = cosine_sim(go_latents, protein_latents)  # (N_go, N_protein)

    recalls = []
    precisions = []
    mrrs = []

    for i, go_id in enumerate(all_go_ids):
        true_proteins = ground_truth.get(go_id, set())
        if not true_proteins:
            continue

        sims = sim_matrix[i]  # similarities to all proteins
        topk_idx = torch.topk(sims, k)[1]
        retrieved_ids = [all_protein_ids[j] for j in topk_idx]

        hits = [1 if pid in true_proteins else 0 for pid in retrieved_ids]
        num_hits = sum(hits)

        # Recall@k
        recalls.append(num_hits / len(true_proteins))

        # Precision@k
        precisions.append(num_hits / k)

        # MRR
        ranks = [j + 1 for j, h in enumerate(hits) if h]
        mrr = 1 / ranks[0] if ranks else 0
        mrrs.append(mrr)

    return {
        "Recall@k": sum(recalls) / len(recalls),
        "Precision@k": sum(precisions) / len(precisions),
        "MRR": sum(mrrs) / len(mrrs)
    }


In [21]:
### model = DualEncoder()
model = DualEncoder()
model.load_state_dict(torch.load('../models/20250516_20epoch.pt'))

train_dataset = GeneGODataset(train_data)
train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True)
val_dataset = GeneGODataset(val_data)
val_dataloader = DataLoader(val_dataset, batch_size=128, shuffle=True)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

for epoch in range(10):
    model.train()
    total_loss = 0
    for batch in tqdm(train_dataloader):
        go = batch["go_embed"]
        esm = batch["esm_embed"]

        go_latent, esm_latent = model(go, esm)
        loss = contrastive_loss(go_latent, esm_latent)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    print(f"Epoch {epoch+1}: Loss = {total_loss / len(train_dataloader):.4f}")

    metrics = evaluate(model, val_dataloader, gene2go, k=10)
    print(f"Evaluation: Recall@10 = {metrics['Recall@k']:.4f}, "
          f"Precision@10 = {metrics['Precision@k']:.4f}, "
          f"MRR = {metrics['MRR']:.4f}")

100%|██████████| 599/599 [08:19<00:00,  1.20it/s]


Epoch 1: Loss = 2.3164
Evaluation: Recall@10 = 0.1188, Precision@10 = 0.6804, MRR = 0.6914


100%|██████████| 599/599 [09:19<00:00,  1.07it/s]


Epoch 2: Loss = 2.3024
Evaluation: Recall@10 = 0.1037, Precision@10 = 0.6362, MRR = 0.6536


100%|██████████| 599/599 [09:40<00:00,  1.03it/s]


Epoch 3: Loss = 2.2890
Evaluation: Recall@10 = 0.1089, Precision@10 = 0.6755, MRR = 0.6852


100%|██████████| 599/599 [09:33<00:00,  1.04it/s]


Epoch 4: Loss = 2.2823
Evaluation: Recall@10 = 0.1123, Precision@10 = 0.6832, MRR = 0.6912


100%|██████████| 599/599 [09:33<00:00,  1.04it/s]


Epoch 5: Loss = 2.2717
Evaluation: Recall@10 = 0.1016, Precision@10 = 0.6763, MRR = 0.6721


100%|██████████| 599/599 [09:28<00:00,  1.05it/s]


Epoch 6: Loss = 2.2616
Evaluation: Recall@10 = 0.1177, Precision@10 = 0.6899, MRR = 0.6947


100%|██████████| 599/599 [09:03<00:00,  1.10it/s]


Epoch 7: Loss = 2.2596
Evaluation: Recall@10 = 0.1137, Precision@10 = 0.6213, MRR = 0.6279


100%|██████████| 599/599 [09:10<00:00,  1.09it/s]


Epoch 8: Loss = 2.2507
Evaluation: Recall@10 = 0.1061, Precision@10 = 0.6852, MRR = 0.6881


100%|██████████| 599/599 [09:38<00:00,  1.04it/s]


Epoch 9: Loss = 2.2399
Evaluation: Recall@10 = 0.1073, Precision@10 = 0.6899, MRR = 0.6903


100%|██████████| 599/599 [09:13<00:00,  1.08it/s]


Epoch 10: Loss = 2.2370
Evaluation: Recall@10 = 0.1159, Precision@10 = 0.6063, MRR = 0.6070


In [22]:
torch.save(model.state_dict(), '../models/20250516_30epoch.pt')