
# NBFNet on STRING PPI (True NBFNet) + PAM50 Probe

This notebook:
1) Downloads **STRING PPI** (human; configurable threshold)
2) Builds triples and train/valid/test splits
3) **Clones & imports the real NBFNet** from `KiddoZhu/NBFNet-PyG` (no dot-product fallback)
4) Trains & evaluates a 6-layer NBFNet on the PPI KG
5) **Probes PAM50 genes**: for each of the 50, ranks likely partners



## 1) Imports & global config

In [1]:
# 항상 맨 첫 셀에서 실행!
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"  # ★ 변경점: 반드시 0,1만 노출

import random
import numpy as np
import torch
from pathlib import Path


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# ==== Global settings ====
DATA_ROOT = Path("./data_string"); DATA_ROOT.mkdir(parents=True, exist_ok=True)
RUN_DIR = Path("./nbfnet_runs"); RUN_DIR.mkdir(parents=True, exist_ok=True)
EXPT_NAME = "string_ppi_interact"

# STRING
SPECIES_ID = 9606  # Human
CONFIDENCE_THRESHOLD = 800  # 0..1000
MAX_EDGES = None           # set to int to cap edges for quick dev
RANDOM_SEED = 42
TRAIN_RATIO, VALID_RATIO, TEST_RATIO = 0.85 , 0.05, 0.10

# NBFNet hyperparams (6 layers)
INPUT_DIM = 64
HIDDEN_DIMS = [64,64,64,64,64,64]   # 6 layers
# (참고) 아래에서 정/역 관계 2개를 쓸 것이므로 여기의 NUM_RELATIONS는 사용하지 않음
NUM_EPOCHS = 10
BATCH_SIZE = 4096
LR = 2e-3
NUM_NEG = 16
TOPK = 20
MESSAGE_FUNCT = "distmult"
AGGREGATE_FUNC = "pna"
SHORTCUT = True
LAYER_NORM = True

random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

<torch._C.Generator at 0x7f5752f928f0>

## 2) Download STRING (aliases + links)

In [3]:

import requests
import shutil

def download_with_fallback(urls, dest_path):
    dest_path = Path(dest_path)
    dest_path.parent.mkdir(parents=True, exist_ok=True)
    for u in urls:
        try:
            print(f"Trying: {u}")
            with requests.get(u, stream=True, timeout=60) as r:
                r.raise_for_status()
                tmp = dest_path.with_suffix(dest_path.suffix + ".part")
                with open(tmp, "wb") as f:
                    shutil.copyfileobj(r.raw, f)
                tmp.rename(dest_path)
            print(f"Downloaded -> {dest_path}")
            return dest_path
        except Exception as e:
            print("  Failed:", e)
    raise RuntimeError("All URL candidates failed.")

def string_aliases_urls(species):
    return [
        f"https://stringdb-static.org/download/protein.aliases.v12.0/{species}.protein.aliases.v12.0.txt.gz",
        f"https://stringdb-static.org/download/protein.aliases.v11.5/{species}.protein.aliases.v11.5.txt.gz",
    ]

def string_links_urls(species):
    return [
        f"https://stringdb-static.org/download/protein.links.full.v12.0/{species}.protein.links.full.v12.0.txt.gz",
        f"https://stringdb-static.org/download/protein.links.detailed.v11.5/{species}.protein.links.detailed.v11.5.txt.gz",
        f"https://stringdb-static.org/download/protein.links.v11.5/{species}.protein.links.v11.5.txt.gz",
    ]

ALIASES_GZ = DATA_ROOT / f"{SPECIES_ID}.protein.aliases.txt.gz"
LINKS_GZ   = DATA_ROOT / f"{SPECIES_ID}.protein.links.txt.gz"

if not ALIASES_GZ.exists():
    download_with_fallback(string_aliases_urls(SPECIES_ID), ALIASES_GZ)

if not LINKS_GZ.exists():
    download_with_fallback(string_links_urls(SPECIES_ID), LINKS_GZ)

print("Ready:", ALIASES_GZ, LINKS_GZ)


Ready: data_string/9606.protein.aliases.txt.gz data_string/9606.protein.links.txt.gz


## 3) Parse aliases → symbol ↔ protein id

In [4]:
import re, gzip
import pandas as pd

def read_aliases_gz(path, prefer_sources=("Ensembl_HGNC", "HGNC", "Ensembl", "UniProtKB")):
    rows = []
    with gzip.open(path, "rt", encoding="utf-8", errors="ignore") as f:
        for line in f:
            if line.startswith("#"): continue
            parts = line.rstrip("\n").split("\t")
            if len(parts) < 3: 
                continue
            prot, alias, source = parts[0], parts[1], parts[2]
            rows.append((prot, alias, source))
    df = pd.DataFrame(rows, columns=["protein_id", "alias", "source"])
    df["is_symbol_like"] = df["alias"].str.match(r"^[A-Z0-9._-]+$")
    pref = df["source"].apply(lambda s: (prefer_sources.index(s) if s in prefer_sources else len(prefer_sources)))
    df["pref_rank"] = pref
    sym_df = df[df["is_symbol_like"]].copy()
    best_alias = sym_df.sort_values(["protein_id","pref_rank"]).groupby("protein_id").first().reset_index()
    prot_to_sym = dict(zip(best_alias["protein_id"], best_alias["alias"]))
    sym_to_prot = {}
    for _, r in best_alias.iterrows():
        sym = r["alias"]; pid = r["protein_id"]
        if sym not in sym_to_prot:
            sym_to_prot[sym] = pid
    return prot_to_sym, sym_to_prot

prot_to_sym, sym_to_prot = read_aliases_gz(ALIASES_GZ)
len(prot_to_sym), len(sym_to_prot)


(19699, 19691)

In [5]:
import gzip, random
import pandas as pd

def preview_mappings(prot_to_sym, sym_to_prot, n=10, random_sample=False):
    print(f"[counts] prot_to_sym={len(prot_to_sym):,}, sym_to_prot={len(sym_to_prot):,}")

    if not prot_to_sym or not sym_to_prot:
        print("\n[debug] mappings are empty. Showing first few raw lines from the gz file…")
        with gzip.open(ALIASES_GZ, "rt", encoding="utf-8", errors="ignore") as f:
            for i, line in zip(range(10), f):
                print(line.rstrip("\n"))
        return

    if random_sample:
        pts_items = random.sample(list(prot_to_sym.items()), min(n, len(prot_to_sym)))
        stp_items = random.sample(list(sym_to_prot.items()), min(n, len(sym_to_prot)))
    else:
        pts_items = sorted(prot_to_sym.items(), key=lambda kv: kv[0])[:n]
        stp_items = sorted(sym_to_prot.items(), key=lambda kv: kv[0])[:n]

    print("\n[preview] prot_to_sym:")
    for k, v in pts_items:
        print(f"  {k}  ->  {v}")

    print("\n[preview] sym_to_prot:")
    for k, v in stp_items:
        print(f"  {k}  ->  {v}")

    display(pd.DataFrame(pts_items, columns=["protein_id", "symbol"]).head(n))
    display(pd.DataFrame(stp_items, columns=["symbol", "protein_id"]).head(n))

prot_to_sym, sym_to_prot = read_aliases_gz(ALIASES_GZ)
preview_mappings(prot_to_sym, sym_to_prot, n=10, random_sample=False)


[counts] prot_to_sym=19,699, sym_to_prot=19,691

[preview] prot_to_sym:
  9606.ENSP00000000233  ->  ARF5
  9606.ENSP00000000412  ->  M6PR
  9606.ENSP00000001008  ->  FKBP4
  9606.ENSP00000001146  ->  CYP26B1
  9606.ENSP00000002125  ->  NDUFAF7
  9606.ENSP00000002165  ->  FUCA2
  9606.ENSP00000002596  ->  HS3ST1
  9606.ENSP00000002829  ->  SEMA3F
  9606.ENSP00000003084  ->  CFTR
  9606.ENSP00000003100  ->  CYP51A1

[preview] sym_to_prot:
  A1BG  ->  9606.ENSP00000263100
  A1CF  ->  9606.ENSP00000378868
  A2M  ->  9606.ENSP00000323929
  A2ML1  ->  9606.ENSP00000299698
  A3GALT2  ->  9606.ENSP00000475261
  A4GALT  ->  9606.ENSP00000384794
  A4GNT  ->  9606.ENSP00000236709
  AAAS  ->  9606.ENSP00000209873
  AACS  ->  9606.ENSP00000324842
  AADAC  ->  9606.ENSP00000232892


Unnamed: 0,protein_id,symbol
0,9606.ENSP00000000233,ARF5
1,9606.ENSP00000000412,M6PR
2,9606.ENSP00000001008,FKBP4
3,9606.ENSP00000001146,CYP26B1
4,9606.ENSP00000002125,NDUFAF7
5,9606.ENSP00000002165,FUCA2
6,9606.ENSP00000002596,HS3ST1
7,9606.ENSP00000002829,SEMA3F
8,9606.ENSP00000003084,CFTR
9,9606.ENSP00000003100,CYP51A1


Unnamed: 0,symbol,protein_id
0,A1BG,9606.ENSP00000263100
1,A1CF,9606.ENSP00000378868
2,A2M,9606.ENSP00000323929
3,A2ML1,9606.ENSP00000299698
4,A3GALT2,9606.ENSP00000475261
5,A4GALT,9606.ENSP00000384794
6,A4GNT,9606.ENSP00000236709
7,AAAS,9606.ENSP00000209873
8,AACS,9606.ENSP00000324842
9,AADAC,9606.ENSP00000232892


In [6]:
def read_links_gz(path, score_threshold=CONFIDENCE_THRESHOLD, max_edges=MAX_EDGES):
    rows = []
    with gzip.open(path, "rt", encoding="utf-8", errors="ignore") as f:
        header = f.readline().strip().split()
        if len(header) < 3:
            f.seek(0); header = None
        count = 0
        for line in f:
            if line.startswith("#"): continue
            parts = line.strip().split()
            if header is None:
                if len(parts) < 3: continue
                p1, p2, sc = parts[0], parts[1], parts[2]
            else:
                def col(name, default_idx=-1):
                    try:
                        return parts[header.index(name)]
                    except Exception:
                        return parts[default_idx] if default_idx >= 0 else None
                p1 = col("protein1", 0)
                p2 = col("protein2", 1)
                sc = None
                for c in ["combined_score","experimental","experimental_score","score"]:
                    if c in header:
                        try:
                            sc = parts[header.index(c)]; break
                        except: pass
                if sc is None: sc = parts[-1]
            try: sc = int(float(sc))
            except: continue
            if sc >= score_threshold:
                rows.append((p1, p2, sc))
                count += 1
                if max_edges is not None and count >= max_edges:
                    break
    df = pd.DataFrame(rows, columns=["protein1","protein2","score"])
    df = df[df["protein1"] != df["protein2"]].copy()
    df["pair"] = df.apply(lambda r: tuple(sorted((r["protein1"], r["protein2"]))), axis=1)
    df = df.drop_duplicates("pair").drop(columns=["pair"])
    return df

links_df = read_links_gz(LINKS_GZ)
proteins = pd.Index(pd.unique(links_df[["protein1","protein2"]].values.ravel()))
pid_to_idx = {pid:i for i,pid in enumerate(proteins)}
idx_to_pid = {i:pid for pid,i in pid_to_idx.items()}

edges = np.array([(pid_to_idx[p1], pid_to_idx[p2]) for p1,p2 in links_df[["protein1","protein2"]].values], dtype=np.int64)
print(f"#nodes={len(proteins)}, #edges={len(edges)}")

(pd.Series(proteins, name="protein_id").to_frame()
 .assign(gene_symbol=lambda df: df["protein_id"].map(lambda x: prot_to_sym.get(x, None)))
 .to_csv(DATA_ROOT/"protein_index_map.csv", index_label="node_idx"))


#nodes=14481, #edges=162305


In [7]:
REL = 0  # single relation id

def make_triples(edges, both=True):
    triples = []
    for u,v in edges:
        triples.append((u, REL, v))
        if both:
            triples.append((v, REL, u))
    return np.array(triples, dtype=np.int64)

triples = make_triples(edges, both=True)
num = len(triples)
perm = np.random.permutation(num)
n_train = int(num * TRAIN_RATIO)
n_valid = int(num * VALID_RATIO)
train_triples = triples[perm[:n_train]]
valid_triples = triples[perm[n_train:n_train+n_valid]]
test_triples  = triples[perm[n_train+n_valid:]]
len(train_triples), len(valid_triples), len(test_triples)


(275918, 16230, 32462)

In [8]:
import torch
from torch_geometric.data import Data

# direct relation id = 0, inverse relation id = 1
NUM_RELATIONS = 2

edge_index_fwd = torch.as_tensor(edges.T, dtype=torch.long)   # [2, E]
edge_index_rev = edge_index_fwd.flip(0)                       # [2, E] v->u

edge_index = torch.cat([edge_index_fwd, edge_index_rev], dim=1)  # [2, 2E]
edge_type  = torch.cat([
    torch.zeros(edge_index_fwd.size(1), dtype=torch.long),       # 0
    torch.ones(edge_index_rev.size(1), dtype=torch.long)         # 1
], dim=0)

data = Data(edge_index=edge_index, edge_type=edge_type)
data.num_nodes = len(proteins)
data.num_edges = edge_index.size(1)
data.num_relations = NUM_RELATIONS

print(data)


Data(edge_index=[2, 324610], edge_type=[324610], num_nodes=14481, num_edges=324610, num_relations=2)


In [9]:
# === PyG MessagePassing init shim (RUN ONCE, before creating the model) ===
from torch_geometric.nn.conv import MessagePassing

# 기존 __init__을 감싸서, 인스턴스 생성 시 구버전 필드명들도 채워준다.
if not hasattr(MessagePassing, "_nbfnet_init_shim_applied"):
    _orig_init = MessagePassing.__init__

    def _init_shim(self, *args, **kwargs):
        _orig_init(self, *args, **kwargs)
        # 최신 필드명을 구버전 이름으로 복사
        if hasattr(self, "_fused_user_args") and not hasattr(self, "__fused_user_args__"):
            self.__fused_user_args__ = self._fused_user_args
        if hasattr(self, "_user_args") and not hasattr(self, "__user_args__"):
            self.__user_args__ = self._user_args
        if hasattr(self, "_special_args") and not hasattr(self, "__special_args__"):
            self.__special_args__ = self._special_args

    MessagePassing.__init__ = _init_shim
    MessagePassing._nbfnet_init_shim_applied = True

# map new (single-underscore) names -> old (double-underscore) names expected by nbfnet-pyg
if not hasattr(MessagePassing, "__check_input__") and hasattr(MessagePassing, "_check_input"):
    MessagePassing.__check_input__ = MessagePassing._check_input
if not hasattr(MessagePassing, "__collect__") and hasattr(MessagePassing, "_collect"):
    MessagePassing.__collect__ = MessagePassing._collect
if not hasattr(MessagePassing, "__fused_user_args__") and hasattr(MessagePassing, "_fused_user_args"):
    MessagePassing.__fused_user_args__ = MessagePassing._fused_user_args
# 추가: size 처리
if not hasattr(MessagePassing, "__set_size__") and hasattr(MessagePassing, "_set_size"):
    MessagePassing.__set_size__ = MessagePassing._set_size
# 일부 PyG 버전에서는 lift도 단일 밑줄로 제공
if not hasattr(MessagePassing, "__lift__") and hasattr(MessagePassing, "_lift"):
    MessagePassing.__lift__ = MessagePassing._lift
# === Torch sparse_csr/csc compatibility shim (RUN ONCE before training) ===
import torch

# PyTorch<1.10 에는 sparse_csr / sparse_csc 속성이 없음 → 더미 속성으로 막아줌
if not hasattr(torch, "sparse_csr"):
    torch.sparse_csr = object()
if not hasattr(torch, "sparse_csc"):
    torch.sparse_csc = object()


In [16]:
# ==== Fix nvcc toolchain for rspmm build (RUN ONCE, before importing NBFNet) ====
import os, sys, shutil, subprocess, importlib
from pathlib import Path
import torch.utils.cpp_extension as cpp_ext

# 0) 기존 중복 패치 제거
cpp_ext = importlib.reload(cpp_ext)

# 1) conda의 gcc-10 / g++-10 우선 선택 (없으면 일반 conda gcc, 마지막엔 시스템 g++까지)
PREFIX   = Path(sys.prefix)
cand_gcc = [
    PREFIX/"bin/x86_64-conda-linux-gnu-gcc-10",
    PREFIX/"bin/x86_64-conda-linux-gnu-gcc",
    Path("/usr/bin/gcc-10"),
    Path("/usr/bin/gcc"),
]
cand_gxx = [
    PREFIX/"bin/x86_64-conda-linux-gnu-g++-10",
    PREFIX/"bin/x86_64-conda-linux-gnu-g++",
    Path("/usr/bin/g++-10"),
    Path("/usr/bin/g++"),
]
CONDAGCC = next((p for p in cand_gcc if p.exists()), cand_gcc[-1])
CONDAGXX = next((p for p in cand_gxx if p.exists()), cand_gxx[-1])
print("use gcc :", CONDAGCC)
print("use g++ :", CONDAGXX)
assert CONDAGCC.exists() and CONDAGXX.exists(), "conda/system gcc/g++ not found"

# 2) libstdc++ include 경로 (선택된 컴파일러 버전에 맞춰 계산)
def _detect_includes(gcc_path: Path):
    if gcc_path.name.endswith("gcc-10"):
        base = Path("/usr/include/c++/10")
        tgt  = Path("/usr/include/x86_64-linux-gnu/c++/10")
        return base, tgt
    ver = subprocess.check_output([str(gcc_path), "-dumpversion"]).decode().strip()
    base = PREFIX / "x86_64-conda-linux-gnu" / "include" / "c++" / ver
    tgt  = base / "x86_64-conda-linux-gnu"
    return base, tgt

INC_BASE, INC_TGT = _detect_includes(CONDAGCC)
print("libstdc++ inc base:", INC_BASE, INC_BASE.exists())
print("libstdc++ inc tgt :", INC_TGT , INC_TGT.exists())

# (추가) conda nvcc가 우선 잡히도록 PATH 앞에 붙임
os.environ["PATH"] = f"{str(Path(sys.prefix)/'bin')}:{os.environ['PATH']}"

# (추가) conda CUDA 헤더 경로를 -isystem으로 같이 주기 위해 기억
CUDA_INC_CONDA = Path(sys.prefix) / "targets" / "x86_64-linux" / "include"

# 3) nvcc / ninja 확인
nvcc  = shutil.which("nvcc")
ninja = shutil.which("ninja") or shutil.which("ninja-build")
print("nvcc:", nvcc)
print("ninja:", ninja)
assert nvcc  is not None, "nvcc not found (e.g., `conda install -y cuda-nvcc`)"
assert ninja is not None, "ninja not found (e.g., `conda install -y ninja` or `pip install ninja`)"

# 4) 환경변수 고정
os.environ["CC"]          = str(CONDAGCC)
os.environ["CXX"]         = str(CONDAGXX)
os.environ["CUDAHOSTCXX"] = str(CONDAGXX)
os.environ.setdefault("TORCH_CUDA_ARCH_LIST", "8.6")  # RTX A6000
TORCH_EXT_DIR = (Path("./nbfnet_runs/torch_ext").resolve())
os.environ["TORCH_EXTENSIONS_DIR"] = str(TORCH_EXT_DIR)
print("TORCH_CUDA_ARCH_LIST:", os.environ["TORCH_CUDA_ARCH_LIST"])
print("TORCH_EXTENSIONS_DIR:", os.environ["TORCH_EXTENSIONS_DIR"])

# 5) 실패 캐시 정리
shutil.rmtree(os.path.expanduser("~/.cache/torch_extensions"), ignore_errors=True)
shutil.rmtree(TORCH_EXT_DIR, ignore_errors=True)
TORCH_EXT_DIR.mkdir(parents=True, exist_ok=True)

# 6) cpp_extension.load 단일 패치:
#    - 기존 -ccbin(공백형/=/형 모두) 제거
#    - 선택된 g++로 -ccbin 강제
#    - libstdc++ 헤더를 -isystem으로 우선 사용
_orig_load = cpp_ext.load
def _patched_load(name, sources, extra_cflags=None, extra_cuda_cflags=None, **kw):
    ecf  = list(extra_cflags or [])
    eccf = list(extra_cuda_cflags or [])
    if CUDA_INC_CONDA.exists():
        ecf.append(f"-isystem{CUDA_INC_CONDA}")

    cleaned = []
    i = 0
    while i < len(eccf):
        f = eccf[i]
        if f == "-ccbin":
            i += 2
            continue
        if str(f).startswith("-ccbin="):
            i += 1
            continue
        cleaned.append(f)
        i += 1
    eccf = cleaned
    eccf.append(f"-ccbin={CONDAGXX}")

    if INC_BASE.exists(): ecf.append(f"-isystem{INC_BASE}")
    if INC_TGT.exists():  ecf.append(f"-isystem{INC_TGT}")

    kw.pop("verbose", None)  # 중복 verbose 방지

    print(f"[patch] -ccbin={CONDAGXX}")
    if INC_BASE.exists(): print(f"[patch] -isystem {INC_BASE}")
    if INC_TGT.exists():  print(f"[patch] -isystem {INC_TGT}")

    return _orig_load(name, sources, extra_cflags=ecf, extra_cuda_cflags=eccf, verbose=True, **kw)

cpp_ext.load = _patched_load
print("OK: cpp_extension.load single-patched (toolchain pinned)")


use gcc : /data2/project/bin_jip/miniconda3/envs/nbf37/bin/x86_64-conda-linux-gnu-gcc
use g++ : /data2/project/bin_jip/miniconda3/envs/nbf37/bin/x86_64-conda-linux-gnu-g++
libstdc++ inc base: /data2/project/bin_jip/miniconda3/envs/nbf37/x86_64-conda-linux-gnu/include/c++/10.4.0 True
libstdc++ inc tgt : /data2/project/bin_jip/miniconda3/envs/nbf37/x86_64-conda-linux-gnu/include/c++/10.4.0/x86_64-conda-linux-gnu True
nvcc: /data2/project/bin_jip/miniconda3/envs/nbf37/bin/nvcc
ninja: /data2/project/bin_jip/miniconda3/envs/nbf37/bin/ninja
TORCH_CUDA_ARCH_LIST: 8.6
TORCH_EXTENSIONS_DIR: /data2/project/bin_jip/nbfnet_pyg/nbfnet_runs/torch_ext
OK: cpp_extension.load single-patched (toolchain pinned)


In [17]:
from nbfnet.models import NBFNet

device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
print("device =", device)

model = NBFNet(
    input_dim=INPUT_DIM,
    hidden_dims=HIDDEN_DIMS,
    num_relation=NUM_RELATIONS,      # 2 (정/역)
    message_func=MESSAGE_FUNCT,      # "distmult"
    aggregate_func=AGGREGATE_FUNC,   # "pna"
    short_cut=SHORTCUT,
    layer_norm=LAYER_NORM,
    activation="relu",
    concat_hidden=False,
    num_mlp_layer=2,
    dependent=True,
    remove_one_hop=True,
    num_beam=10,
    path_topk=10,
).to(device)

# ★ 변경점: graph도 한 번만 옮겨놓고 이후에는 data 그대로 사용
data = data.to(device)

sum(p.numel() for p in model.parameters())/1e6


device = cuda:1


0.387329

In [18]:
import math, time
import torch.nn.functional as F

def make_tail_neg_batch(pos_triples, num_nodes, num_neg, rel_id=0):
    B = pos_triples.size(0)
    h = pos_triples[:, 0]
    t_pos = pos_triples[:, 2]
    r = torch.full_like(h, fill_value=rel_id)

    neg_tails = torch.randint(0, num_nodes, (B, num_neg), device=pos_triples.device)
    mask = (neg_tails == t_pos.unsqueeze(1))
    while mask.any():
        neg_tails[mask] = torch.randint(0, num_nodes, (mask.sum(),), device=pos_triples.device)
        mask = (neg_tails == t_pos.unsqueeze(1))

    h_mat = h.unsqueeze(1).expand(-1, 1+num_neg)
    t_mat = torch.cat([t_pos.unsqueeze(1), neg_tails], dim=1)
    r_mat = torch.full_like(t_mat, fill_value=rel_id)

    batch = torch.stack([h_mat, t_mat, r_mat], dim=-1).long()
    return batch

def bce_loss_from_scores(scores):
    labels = torch.zeros_like(scores)
    labels[:, 0] = 1.0
    return F.binary_cross_entropy_with_logits(scores, labels)

@torch.no_grad()
def eval_auc(model, data, triples, num_neg=64, rel_id=0, batch_size=4096):
    model.eval()
    total, correct = 0, 0
    for i in range(0, len(triples), batch_size):
        pos = torch.as_tensor(triples[i:i+batch_size], device=device)
        batch = make_tail_neg_batch(pos, data.num_nodes, num_neg, rel_id=rel_id)
        # ★ 변경점: data는 이미 device에 있음 → 매번 .to(device) 하지 않음
        scores = model(data, batch).squeeze(-1)  # [B, 1+K]
        pos_s = scores[:, 0:1]
        neg_s = scores[:, 1:]
        correct += (pos_s > neg_s).float().mean(dim=1).sum().item()
        total += pos.size(0)
    return correct / total


In [19]:
train_t = torch.as_tensor(train_triples, device=device)
valid_t = torch.as_tensor(valid_triples, device=device)
test_t  = torch.as_tensor(test_triples , device=device)

print("CUDA available:", torch.cuda.is_available())
print("GPU count visible to torch:", torch.cuda.device_count())
for i in range(torch.cuda.device_count()):
    print(f" [{i}] name={torch.cuda.get_device_name(i)}")


CUDA available: True
GPU count visible to torch: 2
 [0] name=NVIDIA RTX A6000
 [1] name=NVIDIA RTX A6000


In [20]:
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
best_val = -1
best_state = None

for epoch in range(1, NUM_EPOCHS+1):
    model.train()
    t0 = time.time()
    # ★ 변경점: randperm은 CPU에서 생성 (device 인자 제거)
    idx = torch.randperm(train_t.size(0))

    train_loss = 0.0
    for i in range(0, train_t.size(0), BATCH_SIZE):
        pos = train_t[idx[i:i+BATCH_SIZE]]
        batch = make_tail_neg_batch(pos, data.num_nodes, NUM_NEG, rel_id=0)
        scores = model(data, batch)  # [B, 1+K]
        loss = bce_loss_from_scores(scores)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * pos.size(0)
    train_loss /= train_t.size(0)

    val_metric = eval_auc(model, data, valid_triples, num_neg=64, rel_id=0, batch_size=8192)
    dt = time.time() - t0
    print(f"[{epoch:02d}] loss={train_loss:.4f}  val@AUC≈{val_metric:.4f}  ({dt:.1f}s)")
    if val_metric > best_val:
        best_val = val_metric
        best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}

# best 로드 후 테스트 점검
if best_state is not None:
    model.load_state_dict(best_state)
test_metric = eval_auc(model, data, test_triples, num_neg=64, rel_id=0, batch_size=8192)
print(f"Test AUC≈{test_metric:.4f}")


Load rspmm extension. This may take a while...
[patch] -ccbin=/data2/project/bin_jip/miniconda3/envs/nbf37/bin/x86_64-conda-linux-gnu-g++
[patch] -isystem /data2/project/bin_jip/miniconda3/envs/nbf37/x86_64-conda-linux-gnu/include/c++/10.4.0
[patch] -isystem /data2/project/bin_jip/miniconda3/envs/nbf37/x86_64-conda-linux-gnu/include/c++/10.4.0/x86_64-conda-linux-gnu
Using /data2/project/bin_jip/nbfnet_pyg/nbfnet_runs/torch_ext as PyTorch extensions root...
Creating extension directory /data2/project/bin_jip/nbfnet_pyg/nbfnet_runs/torch_ext/rspmm...
Detected CUDA files, patching ldflags
Emitting ninja build file /data2/project/bin_jip/nbfnet_pyg/nbfnet_runs/torch_ext/rspmm/build.ninja...
Building extension module rspmm...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
[1/3] /usr/lib/nvidia-cuda-toolkit/bin/nvcc --generate-dependencies-with-compile --dependency-output rspmm.cuda.o.d -ccbin /data2/project/bin_jip/miniconda

RuntimeError: Error building extension 'rspmm'

In [12]:
PAM50 = [
    "BCL2", "BIRC5", "CCNB1", "CDC20", "CDH3", "CENPF", "CXXC5", "EGFR", "ERBB2", "ESR1",
    "EXO1", "FGFR4", "FOXA1", "GRB7", "KIF2C", "KRT14", "KRT17", "KRT5", "MELK", "MIA",
    "MK167", "MMP11", "MYBL2", "NAT1", "ORC6", "PGR", "PHGDH", "PTTG1", "RRM2", "SLC39A6",
    "TMEM45B", "TYMS", "UBE2C", "UBE2T", "BAG1", "BLVRA", "CXXC5", "FOXA1", "MAPT", "MMP11",
    "MMP7", "MMP9", "MYC", "NDC80", "NUF2", "SFRP1", "UBE2C", "CCNE1", "KRT8", "KRT18"
]

def symbol_to_nodeidx(sym):
    pid = sym_to_prot.get(sym)
    if pid is None:
        return None
    return pid_to_idx.get(pid)

pam50_nodes = {sym: symbol_to_nodeidx(sym) for sym in PAM50}
present = {sym: idx for sym, idx in pam50_nodes.items() if idx is not None}
missing = [sym for sym, idx in pam50_nodes.items() if idx is None]
print(f"PAM50 present: {len(present)} / 50  | missing: {len(missing)}")

@torch.no_grad()
def rank_neighbors_for_head(model, data, h_idx, topk=TOPK, rel_id=0, chunk=8192):
    model.eval()
    all_scores, all_nodes = [], []
    N = data.num_nodes
    h = torch.full((1,), h_idx, dtype=torch.long, device=device)
    r = torch.full((1,), rel_id, dtype=torch.long, device=device)
    for start in range(0, N, chunk):
        cand = torch.arange(start, min(start+chunk, N), device=device)
        h_mat = h.unsqueeze(1).expand(1, cand.numel())
        t_mat = cand.unsqueeze(0)
        r_mat = r.unsqueeze(1).expand(1, cand.numel())

        batch = torch.stack([h_mat, t_mat, r_mat], dim=-1).long()

        scores = model(data, batch).squeeze(0)  # [L]
        all_scores.append(scores.detach().cpu())
        all_nodes.append(cand.detach().cpu())
    scores = torch.cat(all_scores)
    nodes  = torch.cat(all_nodes)

    u = h_idx
    known_v = set(edges[edges[:,0] == u][:,1].tolist())
    known_v.add(u)
    mask = torch.tensor([n not in known_v for n in nodes.tolist()])
    scores = scores[mask]; nodes = nodes[mask]

    topk_idx = torch.topk(scores, k=min(topk, scores.numel())).indices
    top_nodes = nodes[topk_idx].tolist()
    top_scores = scores[topk_idx].tolist()
    return list(zip(top_nodes, top_scores))

def idx_to_symbol(node_idx):
    pid = idx_to_pid[node_idx]
    return prot_to_sym.get(pid, pid)

pam50_suggestions = {}
for sym, idx in present.items():
    pairs = rank_neighbors_for_head(model, data, idx, topk=TOPK, rel_id=0)
    pam50_suggestions[sym] = [(idx_to_symbol(n), float(s)) for n, s in pairs]

import pandas as pd
rows = []
for sym, lst in pam50_suggestions.items():
    for rank, (gsym, score) in enumerate(lst, start=1):
        rows.append({"seed_gene": sym, "rank": rank, "candidate_gene": gsym, "score": score})
pam50_df = pd.DataFrame(rows).sort_values(["seed_gene","rank"])
display(pam50_df.head(20))
pam50_df.to_csv(RUN_DIR / f"{EXPT_NAME}_pam50_top{TOPK}.csv", index=False)


PAM50 present: 43 / 50  | missing: 3


AttributeError: 'GeneralizedRelationalConv' object has no attribute '__check_input__'

In [15]:
import gc
gc.collect()
torch.cuda.empty_cache()