### I. Dataset

In [30]:
from pathlib import Path
import json, ast
import pandas as pd

ROOT = Path(".").resolve()          # your notebook folder
PATH_TRAIN = ROOT / "dataset" / "train.txt"
PATH_VAL   = ROOT / "dataset" / "val.txt"
PATH_TEST  = ROOT / "dataset" / "test.txt"

PATH_TRAIN, PATH_VAL, PATH_TEST

def load_split_whole_file(path: Path) -> pd.DataFrame:
    # read the whole file into one string
    with path.open("r", encoding="utf-8") as f:
        content = f.read()
    
    # try JSON first
    try:
        obj = json.loads(content)
    except json.JSONDecodeError:
        # if it's not valid JSON, treat it as a Python literal
        obj = ast.literal_eval(content)

    # we expect a *list of dicts*
    if not isinstance(obj, list):
        raise TypeError(f"Expected a list, got {type(obj)} from {path}")
    
    # one row per paper
    df = pd.json_normalize(obj)   # or pd.DataFrame(obj)
    return df

df_train = load_split_whole_file(PATH_TRAIN)
df_val   = load_split_whole_file(PATH_VAL)
df_test  = load_split_whole_file(PATH_TEST)

df_train.shape, df_val.shape, df_test.shape



((42000, 11), (9000, 11), (9000, 11))

In [33]:
df_train["split"] = "train"
df_val["split"]   = "val"
df_test["split"]  = "test"

df_raw = pd.concat([df_train, df_val, df_test], ignore_index=True)
df_raw.shape, df_raw["split"].value_counts()


((60000, 12),
 split
 train    42000
 val       9000
 test      9000
 Name: count, dtype: int64)

In [35]:
def preprocess_papers(df):
    df = df.copy()

    # keep only non-empty title & abstract
    df = df[df["title"].notna() & df["abstract"].notna()]
    df = df[df["abstract"].str.len() > 20]

    # parse "179;238;..." → [179, 238, ...]
    def parse_citations(cstr):
        if not isinstance(cstr, str):
            return []
        out = []
        for c in cstr.split(";"):
            c = c.strip()
            if c.isdigit():
                out.append(int(c))
        return out

    df["citation_list"]  = df["Citations"].apply(parse_citations)
    df["citation_count"] = df["citation_list"].str.len()

    # optional: filter
    df = df[df["citation_count"] > 0]
    df["publication_ID"] = df["publication_ID"].astype(int)

    return df

df = preprocess_papers(df_raw)
df.head()


Unnamed: 0,publication_ID,Citations,pubDate,language,title,journal,abstract,keywords,authors,venue,doi,split,citation_list,citation_count
0,17396995,17957262;21818356;24164861;21818356;24164861;2...,2007 May 1,eng,Herpes simplex virus type 2 infection does not...,The Journal of infectious diseases,We sought to compare baseline and longitudinal...,Adult;California;epidemiology;Cohort Studies;H...,"[{'name': 'Edward R Cachay', 'org': 'Universit...",{'name': 'The Journal of infectious diseases'},10.1086/513568,train,"[17957262, 21818356, 24164861, 21818356, 24164...",7
1,16779733,19197361;19399183;20041174;20300572;17311474;2...,2006 Jul 15,eng,Efficacy of the anti Candida rAls3p N or rAls1...,The Journal of infectious diseases.,We have shown that vaccination with the recomb...,Animals;Candida;immunology;isolation & purific...,"[{'name': 'Brad J Spellberg', 'org': 'Departme...",{'name': 'The Journal of infectious diseases'},10.1086/504691,train,"[19197361, 19399183, 20041174, 20300572, 17311...",21
2,12412787,28740334,2002 Nov,eng,Role of the interleukin 6 interleukin 6 solubl...,Journal of bone and mineral research : the off...,We have observed a strong correlation between ...,Adult;Animals;Bone Resorption;metabolism;Bone ...,"[{'name': 'Karl Insogna', 'org': 'Department o...",{'name': 'Journal of bone and mineral research...,0,train,[28740334],1
3,18070707,22567368;22348393;22495885;23874387;23100393;2...,2007 Dec,eng,Genetic events in the pathogenesis of multiple...,Best practice & research. Clinical haematology,The genetics of myeloma has been increasingly ...,Gene Expression Profiling;Humans;Immunoglobuli...,"[{'name': 'W.J. Chng', 'org': 'Mayo Clinic Ari...",{'name': 'Best practice & research. Clinical h...,10.1016/j.beha.2007.08.004,train,"[22567368, 22348393, 22495885, 23874387, 23100...",18
4,16365419,20498830;26334995,2006 Jan 01,eng,PU 1 regulates cathepsin S expression in profe...,"Journal of immunology (Baltimore, Md. : 1950)",Cathepsin S (CTSS) is a cysteine protease that...,Animals;Antigen-Presenting Cells;immunology;me...,"[{'name': 'Ying Wang', 'id': '53f5626bdabfae5d...","{'name': 'Journal of immunology (Baltimore, Md...",10.4049/jimmunol.176.1.275,train,"[20498830, 26334995]",2


In [37]:
paper_ids = sorted(df["publication_ID"].unique())
paper_id2idx = {pid: i for i, pid in enumerate(paper_ids)}
idx2paper_id = {i: pid for pid, i in paper_id2idx.items()}
num_nodes = len(paper_ids)
num_nodes


56416

In [38]:
paper_set = set(paper_ids)

def build_edges_for_split(df_split):
    src_list, dst_list = [], []
    for _, row in df_split.iterrows():
        src_pid = int(row["publication_ID"])
        if src_pid not in paper_id2idx:
            continue
        src_idx = paper_id2idx[src_pid]
        for cited_pid in row["citation_list"]:
            if cited_pid in paper_set:
                dst_idx = paper_id2idx[cited_pid]
                src_list.append(src_idx)
                dst_list.append(dst_idx)
    return src_list, dst_list

df_tr = df[df["split"] == "train"]
df_va = df[df["split"] == "val"]
df_te = df[df["split"] == "test"]

train_src, train_dst = build_edges_for_split(df_tr)
val_src,   val_dst   = build_edges_for_split(df_va)
test_src,  test_dst  = build_edges_for_split(df_te)

len(train_src), len(val_src), len(test_src)


(6789, 1481, 1446)

In [39]:
from sentence_transformers import SentenceTransformer
import numpy as np

model = SentenceTransformer("all-MiniLM-L6-v2") 

texts = (df["title"].fillna("") + " " + df["abstract"].fillna("")).tolist()
emb = model.encode(texts, batch_size=64, show_progress_bar=True)  # (N_df, d)

emb_dim = emb.shape[1]
x = np.zeros((num_nodes, emb_dim), dtype="float32")

# put embeddings into the correct row for each node index
for i, (_, row) in enumerate(df.iterrows()):
    pid = int(row["publication_ID"])
    idx = paper_id2idx[pid]
    x[idx] = emb[i]


  from .autonotebook import tqdm as notebook_tqdm
Batches: 100%|██████████| 910/910 [04:14<00:00,  3.57it/s]


In [40]:
import torch
from torch_geometric.data import Data

x_tensor = torch.tensor(x, dtype=torch.float32)

edge_index_train = torch.tensor([train_src, train_dst], dtype=torch.long)
edge_index_val   = torch.tensor([val_src,   val_dst],   dtype=torch.long)
edge_index_test  = torch.tensor([test_src,  test_dst],  dtype=torch.long)

data = Data(x=x_tensor, edge_index=edge_index_train)
data.train_pos_edge_index = edge_index_train
data.val_pos_edge_index   = edge_index_val
data.test_pos_edge_index  = edge_index_test

data


Data(x=[56416, 384], edge_index=[2, 6789], train_pos_edge_index=[2, 6789], val_pos_edge_index=[2, 1481], test_pos_edge_index=[2, 1446])

In [41]:
from torch_geometric.nn import SAGEConv
import torch.nn.functional as F
from torch.nn import Linear, Module

class GraphSAGEEncoder(Module):
    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        self.conv1 = SAGEConv(in_dim, hidden_dim)
        self.conv2 = SAGEConv(hidden_dim, out_dim)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        return x

class LinkPredictor(Module):
    def __init__(self, in_dim, hidden_dim=64):
        super().__init__()
        self.lin1 = Linear(in_dim * 2, hidden_dim)
        self.lin2 = Linear(hidden_dim, 1)

    def forward(self, z_src, z_dst):
        h = torch.cat([z_src, z_dst], dim=-1)
        h = F.relu(self.lin1(h))
        h = self.lin2(h)
        return torch.sigmoid(h).view(-1)


In [42]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

in_dim = data.x.size(1)
hidden_dim = 256
out_dim = 256   # embedding size for papers

encoder = GraphSAGEEncoder(in_dim, hidden_dim, out_dim).to(device)
predictor = LinkPredictor(out_dim).to(device)

data = data.to(device)
optimizer = torch.optim.Adam(
    list(encoder.parameters()) + list(predictor.parameters()),
    lr=1e-3,
    weight_decay=1e-5,
)


In [43]:
import torch

def negative_sampling(num_nodes, pos_edge_index, num_neg_samples=None):
    """
    Very simple uniform negative sampling.
    Returns [2, num_neg_samples].
    """
    if num_neg_samples is None:
        num_neg_samples = pos_edge_index.size(1)

    # set of existing positive edges to avoid sampling them as negatives
    pos = set(
        (int(s), int(d))
        for s, d in zip(pos_edge_index[0].tolist(), pos_edge_index[1].tolist())
    )

    neg_src = []
    neg_dst = []

    # brute-force is OK for moderate graphs (your subset)
    while len(neg_src) < num_neg_samples:
        src = torch.randint(0, num_nodes, (1,)).item()
        dst = torch.randint(0, num_nodes, (1,)).item()
        if (src, dst) in pos:
            continue
        neg_src.append(src)
        neg_dst.append(dst)

    neg_edge_index = torch.tensor([neg_src, neg_dst], dtype=torch.long, device=pos_edge_index.device)
    return neg_edge_index


In [None]:
def train_step(data):
    encoder.train()
    predictor.train()

    optimizer.zero_grad()

    # 1. Node embeddings from GraphSAGE
    z = encoder(data.x, data.edge_index)  # [N, out_dim]

    pos_edge_index = data.train_pos_edge_index  # [2, E_train]
    num_nodes = data.num_nodes

    # 2. Positive samples
    src_pos = pos_edge_index[0]
    dst_pos = pos_edge_index[1]
    z_src_pos = z[src_pos]
    z_dst_pos = z[dst_pos]
    pred_pos = predictor(z_src_pos, z_dst_pos)

    # 3. Negative samples (same number as positives)
    neg_edge_index = negative_sampling(num_nodes, pos_edge_index,
                                       num_neg_samples=pos_edge_index.size(1))
    src_neg = neg_edge_index[0]
    dst_neg = neg_edge_index[1]
    z_src_neg = z[src_neg]
    z_dst_neg = z[dst_neg]
    pred_neg = predictor(z_src_neg, z_dst_neg)

    # 4. Loss: encourage high scores for pos, low for neg
    loss_pos = -torch.log(pred_pos + 1e-15).mean()
    loss_neg = -torch.log(1 - pred_neg + 1e-15).mean()
    loss = loss_pos + loss_neg

    # 5. Backprop once
    loss.backward()
    optimizer.step()

    return float(loss.item())


In [45]:
def recall_at_k(z, pos_edge_index, k=10, num_neg_candidates=99):
    num_nodes = z.size(0)
    pos_src = pos_edge_index[0]
    pos_dst = pos_edge_index[1]
    num_pos = pos_edge_index.size(1)

    hits = 0

    for i in range(num_pos):
        src = int(pos_src[i])
        true_dst = int(pos_dst[i])

        # sample negatives
        neg_dst = []
        while len(neg_dst) < num_neg_candidates:
            cand = torch.randint(0, num_nodes, (1,)).item()
            if cand == true_dst:
                continue
            neg_dst.append(cand)

        candidates = torch.tensor([true_dst] + neg_dst, device=z.device)
        src_batch = torch.full_like(candidates, src)

        scores = predictor(z[src_batch], z[candidates])  # [num_candidates]

        # higher score = more likely citation
        topk = torch.topk(scores, k=min(k, len(scores)), largest=True).indices
        # index 0 in `candidates` corresponds to true_dst
        if 0 in topk:
            hits += 1

    return hits / num_pos


In [46]:
@torch.no_grad()
def evaluate(data, k=10):
    encoder.eval()
    predictor.eval()

    z = encoder(data.x, data.edge_index)

    val_recall = recall_at_k(z, data.val_pos_edge_index, k=k)
    test_recall = recall_at_k(z, data.test_pos_edge_index, k=k)

    return val_recall, test_recall


In [47]:
num_epochs = 20
best_val = 0.0
best_state = None

for epoch in range(1, num_epochs + 1):
    loss = train_step(data, batch_size=4096)
    val_recall, test_recall = evaluate(data, k=10)

    if val_recall > best_val:
        best_val = val_recall
        best_state = {
            "encoder": encoder.state_dict(),
            "predictor": predictor.state_dict(),
        }

    print(
        f"Epoch {epoch:02d} | "
        f"loss = {loss:.4f} | "
        f"val@10 = {val_recall:.4f} | "
        f"test@10 = {test_recall:.4f}"
    )

# Load best model (according to validation Recall@K)
if best_state is not None:
    encoder.load_state_dict(best_state["encoder"])
    predictor.load_state_dict(best_state["predictor"])


RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.