In [1]:
# Install required packages if needed
# import subprocess; subprocess.run(['pip', 'install', 'torch-geometric', 'sentence-transformers', 'tqdm'], check=True)

In [2]:
import torch
from torch_geometric.data import HeteroData
from torch_geometric.nn import HeteroConv, GCNConv, SAGEConv, Linear
from torch_geometric.utils import train_test_split_edges, negative_sampling
from tqdm import tqdm
import json
import numpy as np

GNN_config = {
    'epoch': 10,
    'batch': 128,
    'dropout': 0.2,
    'early_stopping': 2,
    'learning_rate': 1e-3
}

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")


Using device: cuda


In [3]:
import os
from pathlib import Path
import shutil

base_dir = Path('.')
data_dir = base_dir / 'data'
checkpoints_dir = base_dir / 'checkpoints'

# Fix nested directory structure if files were transferred from Turing
if (base_dir / 'data' / 'data').exists():
    print("Moving files from data/data/ to data/...")
    for item in (base_dir / 'data' / 'data').iterdir():
        if item.is_file():
            shutil.move(str(item), str(data_dir / item.name))
    print("[OK] Done")

if (base_dir / 'checkpoints' / 'checkpoints').exists():
    print("Moving files from checkpoints/checkpoints/ to checkpoints/...")
    for item in (base_dir / 'checkpoints' / 'checkpoints').iterdir():
        if item.is_file():
            shutil.move(str(item), str(checkpoints_dir / item.name))
    print("[OK] Done")

data_dir.mkdir(exist_ok=True)
checkpoints_dir.mkdir(exist_ok=True)
print(f"Data: {data_dir.absolute()}\nCheckpoints: {checkpoints_dir.absolute()}")

Data: /home/upandit/mag_citation_recommender/data
Checkpoints: /home/upandit/mag_citation_recommender/checkpoints


In [4]:
import urllib.request

urls = {
    'train.txt.zip.001': 'https://github.com/QianWangWPI/Released-Microsoft-dataset/raw/main/train.txt.zip.001',
    'train.txt.zip.002': 'https://github.com/QianWangWPI/Released-Microsoft-dataset/raw/main/train.txt.zip.002',
    'test.txt': 'https://raw.githubusercontent.com/QianWangWPI/Released-Microsoft-dataset/refs/heads/main/test.txt',
    'val.txt': 'https://raw.githubusercontent.com/QianWangWPI/Released-Microsoft-dataset/refs/heads/main/val.txt'
}

print("Downloading dataset files...")
for filename, url in urls.items():
    filepath = data_dir / filename
    if not filepath.exists():
        urllib.request.urlretrieve(url, filepath)
        print(f"[OK] {filename}")
    else:
        print(f"⊙ {filename} (exists)")

print("Download complete!")



Downloading dataset files...
⊙ train.txt.zip.001 (exists)
⊙ train.txt.zip.002 (exists)
⊙ test.txt (exists)
⊙ val.txt (exists)
Download complete!


In [5]:
import zipfile
import glob

train_txt_path = data_dir / "train.txt"
train_zip_path = data_dir / "train.txt.zip"

if not train_txt_path.exists():
    zip_parts = sorted(glob.glob(str(data_dir / "train.txt.zip.*")))
    if zip_parts:
        print("Combining zip parts...")
        with open(train_zip_path, 'wb') as outfile:
            for part in zip_parts:
                with open(part, 'rb') as infile:
                    outfile.write(infile.read())
        print("[OK] Combined")
    
    if train_zip_path.exists():
        print("Extracting...")
        with zipfile.ZipFile(train_zip_path, 'r') as zip_ref:
            zip_ref.extractall(data_dir)
        print("[OK] Extracted")
else:
    print("train.txt already exists")

train.txt already exists


In [6]:
import json

with open(data_dir / "train.txt", encoding="utf-8") as f:
    train_data = json.load(f)
with open(data_dir / "val.txt", encoding="utf-8") as f:
    val_data = json.load(f)
with open(data_dir / "test.txt", encoding="utf-8") as f:
    test_data = json.load(f)

print(f"Train: {len(train_data)}, Val: {len(val_data)}, Test: {len(test_data)}")


Train: 42000, Val: 9000, Test: 9000


In [7]:
train_data[0].keys()

dict_keys(['publication_ID', 'Citations', 'pubDate', 'language', 'title', 'journal', 'abstract', 'keywords', 'authors', 'venue', 'doi'])

In [8]:
import pandas as pd

def preprocess(data):
    df = pd.DataFrame(data)

    # Split citations into list
    df["Citations"] = df["Citations"].apply(lambda x: x.split(";") if isinstance(x, str) else [])

    df["title"] = df["title"].astype(str)
    df["abstract"] = df["abstract"].astype(str)
    df["keywords"] = df["keywords"].astype(str)
    # Extract doi for reporting
    df["doi"] = df["doi"].apply(lambda x: x.strip() if isinstance(x, str) else None)

    # Combine title + abstract + keywords as paper text
    df["text"] = df["title"] + ". " + df["abstract"] + ". " + df["keywords"] + '. doi:' + df['doi']

    # Extract authors as list of IDs
    df["author_ids"] = df["authors"].apply(lambda authors: [a.get('id', a.get('name')) for a in authors] if isinstance(authors, list) else [])


    # Extract keywords as lowercase list
    df["keyword_list"] = df["keywords"].apply(lambda x: [kw.strip().lower() for kw in x.split(';')] if isinstance(x, str) else [])

    # Extract venues as lowercase list (will be node)
    df["venue"] = df["venue"].apply(lambda x: x.split("name':")[-1].strip()[1:-2] if isinstance(x, str) else None)

    # Extract pub date for attributes
    df["pubDate"] = df["pubDate"].apply(lambda x: x.strip() if isinstance(x, str) else None)

    # Comment out the following if we want to consider the full date.
    # My reasoning for only using the year is that it compresses the graph down significantly if multiple papers share a year node. May cause more noise tho
    df["pubDate"] = df["pubDate"].apply(lambda date: int(date.split(' ')[0]) if date[0].isdigit() else None)

    return df[["publication_ID", "title", "text", "Citations", "author_ids", "keyword_list", "venue", "pubDate"]] #Added title for inference later

train_df = preprocess(train_data)
val_df = preprocess(val_data)
test_df = preprocess(test_data)

train_df.head(2)
val_df.head(5)

Unnamed: 0,publication_ID,title,text,Citations,author_ids,keyword_list,venue,pubDate
0,23641233,Local functional connectivity as a pre surgica...,Local functional connectivity as a pre surgica...,"[23847586, 24073391, 26204264]","[560cdf1545ce1e5960a19851, 560cdf1545ce1e5960a...",[0],"Frontiers in neurology', 'id': '5451a5cae0cf0b...",2013
1,17157189,Heat shock response and acute lung injury,Heat shock response and acute lung injury. All...,"[20465849, 23536968, 24524071, 22140545, 21543...","[53f4442cdabfaee43ec75166, 54096d37dabfae8faa6...","[animals, heat-shock proteins, genetics, metab...","Free Radical Biology and Medicine', 'id': '545...",2007
2,15872007,STa and cGMP stimulate CFTR translocation to t...,STa and cGMP stimulate CFTR translocation to t...,"[21347269, 22069681, 24275951, 21347269, 22069...","[53f467c8dabfaeecd6a126b7, 5608ce0745cedb3396d...","[animals, bacterial toxins, pharmacology, biot...",American journal of physiology. Cell physiolog...,2005
3,20360276,Stigma and depression treatment utilization am...,Stigma and depression treatment utilization am...,"[27473569, 26576680, 24938081, 28774339, 29536...","[53f42d05dabfaedf43511829, 53f4263edabfaeb2acf...","[adolescent, adult, aged, antidepressive agent...","Psychiatric services (Washington, D.C.)",2010
4,15963034,Increased incidence and severity of diabetic k...,Increased incidence and severity of diabetic k...,"[31086620, 24355514, 34188679, 35511179]","[53f43fb3dabfaee4dc7be511, 53f4d409dabfaeedd17...","[adolescent, child, child, preschool, colorado...","Pediatric Diabetes', 'id': '5451a5c4e0cf0b02b5...",2005


In [9]:
import itertools

all_df = pd.concat([train_df, val_df, test_df], ignore_index=True)

# All papers (in data or cited)
all_papers = set(all_df['publication_ID'].astype(str))
for cits in all_df['Citations']:
    all_papers.update(cits)
paper2idx = {pid: i for i, pid in enumerate(sorted(all_papers))}

# All authors
all_authors = set(itertools.chain.from_iterable(all_df['author_ids']))
# Filter out None values before sorting
all_authors = {aid for aid in all_authors if aid is not None}
author2idx = {aid: i for i, aid in enumerate(sorted(all_authors))}

# All keywords
all_keywords = set(itertools.chain.from_iterable(all_df['keyword_list']))
keyword2idx = {kw: i for i, kw in enumerate(sorted(all_keywords))}

# All venues
all_venues = set(all_df['venue'].dropna())
venue2idx = {v: i for i, v in enumerate(sorted(all_venues))}

# All pub dates
all_pubDates = set(all_df['pubDate'].dropna())
pubDate2idx = {pd: i for i, pd in enumerate(sorted(all_pubDates))}

print(f"#papers={len(paper2idx)} #authors={len(author2idx)} #keywords={len(keyword2idx)} #venues={len(venue2idx)} #pubDates={len(pubDate2idx)}")

#papers=424279 #authors=307266 #keywords=17207 #venues=5501 #pubDates=35


# Upgrading the graph where paper node features come from real text embeddings

In [10]:
# Install sentence-transformers if needed (uncomment if required)
# import subprocess; subprocess.run(['pip', 'install', '-q', 'sentence-transformers'], check=True)

from sentence_transformers import SentenceTransformer
import torch
import numpy as np
from tqdm import tqdm


In [11]:
model = SentenceTransformer("all-MiniLM-L6-v2")  # 384-dim embeddings

In [12]:
import math
import numpy as np

paper_texts = []
valid_pids = [pid for pid in paper2idx.keys() if str(pid).isdigit()]  # keep only numeric strings

for pid in sorted(valid_pids, key=lambda x: int(x)):
    pid_int = int(pid)
    text = all_df.loc[all_df["publication_ID"] == pid_int, "text"]

    if len(text) > 0 and isinstance(text.values[0], str):
        paper_texts.append(text.values[0])
    else:
        paper_texts.append(" ")  # placeholder if no text found


In [13]:
print(f"Total paper nodes: {len(paper2idx)}")
print(f"Total valid papers with text: {len(paper_texts)}")


Total paper nodes: 424279
Total valid papers with text: 424278


In [14]:
paper_embs = model.encode(
    paper_texts,
    batch_size=64,
    show_progress_bar=True,
    convert_to_numpy=True,
    normalize_embeddings=True
)


Batches:   0%|          | 0/6630 [00:00<?, ?it/s]

In [15]:
# Save paper embeddings to data directory
np.save(data_dir / "paper_embeddings.npy", paper_embs)
print(f"Paper embeddings saved to {data_dir / 'paper_embeddings.npy'}")

Paper embeddings saved to data/paper_embeddings.npy


In [16]:
# Load the embeddings
# import numpy as np
# paper_embs = np.load(data_dir / "paper_embeddings.npy")
# print("Loaded paper embeddings:", paper_embs.shape)


In [17]:
# Jacob Question: Why are only author and keyword learnable embeddings? what about paper?

from torch_geometric.data import HeteroData
import torch

data = HeteroData()

# Paper node features
embedding_dim = 384 #this is because the transformer has 384-dim embeddings

data['paper'].x = torch.tensor(paper_embs, dtype=torch.float) #the paper node is now embeddings from transformer
data['author'].x = torch.randn(len(author2idx), embedding_dim)
data['keyword'].x = torch.randn(len(keyword2idx), embedding_dim)
data['venue'].x = torch.randn(len(venue2idx), embedding_dim)
data['pubDate'].x = torch.randn(len(pubDate2idx), embedding_dim)

In [18]:
print(data['paper'].x.shape)
print(data['author'].x.shape)
print(data['keyword'].x.shape)


torch.Size([424278, 384])
torch.Size([307266, 384])
torch.Size([17207, 384])


In [19]:
# Saw time inefficiency in the above code, the below is a fix
#-----Shik Fixed here -------------
# Fixed edge overwriting (previously only kept last record)

edge_store = {
    ('paper', 'cites', 'paper'): [[], []],
    ('paper', 'written_by', 'author'): [[], []],
    ('author', 'authored', 'paper'): [[], []],
    ('paper', 'mentions', 'keyword'): [[], []],
    ('keyword', 'appears_in', 'paper'): [[], []],
    ('paper', 'published_in', 'venue'): [[], []],
    ('venue', 'published', 'paper'): [[], []],
    ('paper', 'publication_date', 'pubDate'): [[], []],
}

for _, row in tqdm(train_df.iterrows(), total=len(train_df), desc="Building edges"):
    pid = str(row['publication_ID'])
    if pid not in paper2idx:
        continue
    pidx = paper2idx[pid]

    # --- (a) Citations: paper → paper ---
    for cited in row['Citations']:
        cited = str(cited)
        if cited in paper2idx:
            edge_store[('paper', 'cites', 'paper')][0].append(pidx)
            edge_store[('paper', 'cites', 'paper')][1].append(paper2idx[cited])

    # --- (b) Authors: paper ↔ author ---
    for aid in row['author_ids']:
        if aid in author2idx:
            edge_store[('paper', 'written_by', 'author')][0].append(pidx)
            edge_store[('paper', 'written_by', 'author')][1].append(author2idx[aid])
            edge_store[('author', 'authored', 'paper')][0].append(author2idx[aid])
            edge_store[('author', 'authored', 'paper')][1].append(pidx)

    # --- (c) Keywords: paper ↔ keyword ---
    for kw in row['keyword_list']:
        if kw in keyword2idx:
            edge_store[('paper', 'mentions', 'keyword')][0].append(pidx)
            edge_store[('paper', 'mentions', 'keyword')][1].append(keyword2idx[kw])
            edge_store[('keyword', 'appears_in', 'paper')][0].append(keyword2idx[kw])
            edge_store[('keyword', 'appears_in', 'paper')][1].append(pidx)

    # --- (d) Venue: paper ↔ venue ---
    venue = row['venue']
    if venue in venue2idx:
        edge_store[('paper', 'published_in', 'venue')][0].append(pidx)
        edge_store[('paper', 'published_in', 'venue')][1].append(venue2idx[venue])
        edge_store[('venue', 'published', 'paper')][0].append(venue2idx[venue])
        edge_store[('venue', 'published', 'paper')][1].append(pidx)

    # --- (e) Publication Date: paper → pubDate ---
    pubDate = row['pubDate']
    if pubDate in pubDate2idx:
        edge_store[('paper', 'publication_date', 'pubDate')][0].append(pidx)
        edge_store[('paper', 'publication_date', 'pubDate')][1].append(pubDate2idx[pubDate])

for rel, (src, dst) in edge_store.items():
    if len(src) > 0:
        data[rel].edge_index = torch.tensor([src, dst], dtype=torch.long)
    else:
        print(f"No edges found for relation {rel}")


Building edges: 100%|██████████████████| 42000/42000 [00:03<00:00, 13580.83it/s]


In [20]:
def summarize_heterodata(data):
    print("=== Node Counts ===")
    for ntype in data.node_types:
        print(f"{ntype:<12} → {data[ntype].num_nodes:,} nodes | "
              f"Feature dim: {data[ntype].x.shape[1] if 'x' in data[ntype] else 'N/A'}")

    print("\n=== Edge Counts ===")
    for etype in data.edge_types:
        e = data[etype].edge_index
        print(f"{etype} → {e.shape[1]:,} edges (shape={tuple(e.shape)})")

summarize_heterodata(data)


=== Node Counts ===
paper        → 424,278 nodes | Feature dim: 384
author       → 307,266 nodes | Feature dim: 384
keyword      → 17,207 nodes | Feature dim: 384
venue        → 5,501 nodes | Feature dim: 384
pubDate      → 35 nodes | Feature dim: 384

=== Edge Counts ===
('paper', 'cites', 'paper') → 486,632 edges (shape=(2, 486632))
('paper', 'written_by', 'author') → 291,666 edges (shape=(2, 291666))
('author', 'authored', 'paper') → 291,666 edges (shape=(2, 291666))
('paper', 'mentions', 'keyword') → 948,766 edges (shape=(2, 948766))
('keyword', 'appears_in', 'paper') → 948,766 edges (shape=(2, 948766))
('paper', 'published_in', 'venue') → 42,000 edges (shape=(2, 42000))
('venue', 'published', 'paper') → 42,000 edges (shape=(2, 42000))
('paper', 'publication_date', 'pubDate') → 42,000 edges (shape=(2, 42000))


In [21]:
# Checking for self-loops
num_self_loops = (data['paper', 'cites', 'paper'].edge_index[0] ==
                  data['paper', 'cites', 'paper'].edge_index[1]).sum().item()
print(f"Self-loops in cites relation: {num_self_loops}")

# Checking for duplicate edges
ei = data['paper', 'cites', 'paper'].edge_index
num_unique = torch.unique(ei, dim=1).shape[1]
print(f"Duplicate edges: {ei.shape[1] - num_unique}")


Self-loops in cites relation: 0
Duplicate edges: 63994


In [22]:
import torch

# Remove duplicate edges from (paper, cites, paper)
ei = data['paper', 'cites', 'paper'].edge_index

# Sort columns so [src, dst] and [dst, src] duplicates align
ei_unique = torch.unique(ei, dim=1)
num_removed = ei.shape[1] - ei_unique.shape[1]

data['paper', 'cites', 'paper'].edge_index = ei_unique

print(f"Removed {num_removed} duplicate citation edges. New total: {ei_unique.shape[1]}")


Removed 63994 duplicate citation edges. New total: 422638


In [23]:
ei = data['paper', 'cites', 'paper'].edge_index
print(f"New edge count: {ei.shape[1]}")
num_unique = torch.unique(ei, dim=1).shape[1]
print(f"Remaining duplicates: {ei.shape[1] - num_unique}")


New edge count: 422638
Remaining duplicates: 0


In [24]:
from torch_geometric.nn import HeteroConv, GATConv, Linear
import torch.nn.functional as F
import torch.nn as nn

class HeteroGNN(nn.Module):
    def __init__(self, hidden_channels, out_channels):
        super().__init__()

        self.conv1 = HeteroConv({
            ('paper', 'cites', 'paper'): GATConv(-1, hidden_channels, residual=True, add_self_loops=True),
            ('paper', 'written_by', 'author'): GATConv(-1, hidden_channels, residual=True, add_self_loops=False),
            ('author', 'authored', 'paper'): GATConv(-1, hidden_channels, residual=True, add_self_loops=False),
            ('paper', 'mentions', 'keyword'): GATConv(-1, hidden_channels, residual=True, add_self_loops=False),
            ('keyword', 'appears_in', 'paper'): GATConv(-1, hidden_channels, residual=True, add_self_loops=False),
            ('paper', 'published_in', 'venue'): GATConv(-1, hidden_channels, residual=True, add_self_loops=False),
            ('venue', 'published', 'paper'): GATConv(-1, hidden_channels, residual=True, add_self_loops=False),
            ('paper','publication_date', 'pubDate'): GATConv(-1, hidden_channels, residual=True, add_self_loops=False)
        }, aggr='sum')

        self.lin = nn.Linear(hidden_channels, out_channels)

    def forward(self, x_dict, edge_index_dict):
        x_dict = self.conv1(x_dict, edge_index_dict)
        x_dict = {key: F.relu(x) for key, x in x_dict.items()}
        x_dict = {key: self.lin(x) for key, x in x_dict.items()}
        return x_dict

In [25]:
model = HeteroGNN(hidden_channels=64, out_channels=64).to(device)
data = data.to(device) #Needed for all values to be on same device
print(model)


HeteroGNN(
  (conv1): HeteroConv(num_relations=8)
  (lin): Linear(in_features=64, out_features=64, bias=True)
)


In [26]:
from torch_geometric.utils import negative_sampling

def loss_fn(emb_src, emb_dst, pos_edges, neg_edges):
    pos_score = (emb_src[pos_edges[0]] * emb_dst[pos_edges[1]]).sum(dim=1)
    neg_score = (emb_src[neg_edges[0]] * emb_dst[neg_edges[1]]).sum(dim=1)
    labels = torch.cat([torch.ones_like(pos_score), torch.zeros_like(neg_score)])
    scores = torch.cat([pos_score, neg_score])
    return F.binary_cross_entropy_with_logits(scores, labels)


In [27]:
def clean_invalid_edges(data):
    for etype in data.edge_types:
        if 'edge_index' not in data[etype]:
            continue

        ei = data[etype].edge_index
        src_type, _, dst_type = etype
        src_nodes = data[src_type].num_nodes
        dst_nodes = data[dst_type].num_nodes

        mask = (ei[0] < src_nodes) & (ei[1] < dst_nodes)
        valid_count = mask.sum().item()

        if valid_count < ei.shape[1]:
            print(f"Removing {ei.shape[1] - valid_count} invalid edges from {etype}")
            data[etype].edge_index = ei[:, mask]

clean_invalid_edges(data)


Removing 48 invalid edges from ('paper', 'cites', 'paper')


In [28]:
def check_hetero_integrity(data):
    print("=== Sanity Check for HeteroData ===")
    for etype in data.edge_types:
        ei = data[etype].edge_index
        src_type, _, dst_type = etype

        if ei is None or ei.numel() == 0:
            print(f"{etype} has no edges")
            continue

        src_nodes, dst_nodes = data[src_type].num_nodes, data[dst_type].num_nodes
        max_src, max_dst = ei[0].max().item(), ei[1].max().item()
        min_src, min_dst = ei[0].min().item(), ei[1].min().item()

        if max_src >= src_nodes or max_dst >= dst_nodes:
            print(f"{etype} has invalid indices: "
                  f"src max {max_src}/{src_nodes}, dst max {max_dst}/{dst_nodes}")
        elif min_src < 0 or min_dst < 0:
            print(f"{etype} has negative indices!")
        else:
            print(f"{etype} OK ({ei.shape[1]} edges)")

check_hetero_integrity(data)


=== Sanity Check for HeteroData ===
('paper', 'cites', 'paper') OK (422590 edges)
('paper', 'written_by', 'author') OK (291666 edges)
('author', 'authored', 'paper') OK (291666 edges)
('paper', 'mentions', 'keyword') OK (948766 edges)
('keyword', 'appears_in', 'paper') OK (948766 edges)
('paper', 'published_in', 'venue') OK (42000 edges)
('venue', 'published', 'paper') OK (42000 edges)
('paper', 'publication_date', 'pubDate') OK (42000 edges)


In [29]:
# Split Citation Edges
from torch_geometric.transforms import RandomLinkSplit

# Randomly split only the citation edges for link prediction
transform = RandomLinkSplit(
    num_val=0.1,                   # 10% validation
    num_test=0.1,                  # 10% test
    is_undirected=False,           # citations are directional
    add_negative_train_samples=True,
    edge_types=[('paper', 'cites', 'paper')],  # focus only on citation edges
    rev_edge_types=[None]          # no reverse relation
)

train_data, val_data, test_data = transform(data)
train_data = train_data.to(device)
val_data = val_data.to(device)
test_data = test_data.to(device)



In [30]:
# Fixed to focus only on paper, cites, paper edges
optimizer = torch.optim.Adam(model.parameters(), lr=GNN_config["learning_rate"])

for epoch in range(GNN_config['epoch']):
    model.train()
    optimizer.zero_grad()

    out_dict = model(train_data.x_dict, train_data.edge_index_dict)

    # --- Focus only on (paper, cites, paper) edges ---
    etype = ('paper', 'cites', 'paper')
    pos_edges = train_data[etype].edge_label_index  # edges for training

    # Negative sampling for link prediction
    neg_edges = negative_sampling(
        edge_index=pos_edges,
        num_nodes=(train_data['paper'].num_nodes, train_data['paper'].num_nodes),
        num_neg_samples=pos_edges.size(1)
    )


    emb_src = out_dict['paper']
    emb_dst = out_dict['paper']

    total_loss = loss_fn(emb_src, emb_dst, pos_edges, neg_edges)

    total_loss.backward()
    optimizer.step()

    print(f"Epoch {epoch+1}/{GNN_config['epoch']} | Loss: {total_loss.item():.4f}")

Epoch 1/10 | Loss: 0.7390
Epoch 2/10 | Loss: 0.7063
Epoch 3/10 | Loss: 0.6940
Epoch 4/10 | Loss: 0.6880
Epoch 5/10 | Loss: 0.6846
Epoch 6/10 | Loss: 0.6809
Epoch 7/10 | Loss: 0.6784
Epoch 8/10 | Loss: 0.6756
Epoch 9/10 | Loss: 0.6730
Epoch 10/10 | Loss: 0.6706


In [31]:
torch.save(model.state_dict(), "hetero_gnn_checkpoint.pt")

In [32]:
import torch
import os
from pathlib import Path

# Use local checkpoints directory (defined in cell 2)
# checkpoint_dir is already defined as checkpoints_dir from earlier setup
checkpoint_path = checkpoints_dir / 'hetero_gnn_checkpoint.pt'

# Save model
torch.save(model.state_dict(), checkpoint_path)
print(f"Model saved at {checkpoint_path.absolute()}")

# Load:
# model.load_state_dict(torch.load(checkpoint_path))
# model.to(device)


Model saved at /home/upandit/mag_citation_recommender/checkpoints/hetero_gnn_checkpoint.pt


# Inference

In [33]:
# Set model to evaluation mode
import torch
from pathlib import Path

# Load model from local checkpoints directory
checkpoint_path = checkpoints_dir / 'hetero_gnn_checkpoint.pt'

model = HeteroGNN(hidden_channels=64, out_channels=64)
model.load_state_dict(torch.load(checkpoint_path))
model = model.to(device)
model.eval()



HeteroGNN(
  (conv1): HeteroConv(num_relations=8)
  (lin): Linear(in_features=64, out_features=64, bias=True)
)

In [34]:
model.eval()
with torch.no_grad():
    out_dict = model(data.x_dict, data.edge_index_dict)
    paper_emb = out_dict['paper'].cpu()

In [35]:
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np

query_pid = "17396995"  # Example paper ID
query_idx = paper2idx[query_pid]

if isinstance(paper_embs, torch.Tensor):
    paper_embs = paper_embs.cpu().numpy()

# Extract query vector
query_vec = paper_embs[query_idx].reshape(1, -1)

# Compute cosine similarity across all papers
sims = cosine_similarity(query_vec, paper_embs)[0]
topk = np.argsort(sims)[::-1][1:6]  # Skip itself, take top 5

# Show info about the query paper itself
query_row = all_df.loc[all_df["publication_ID"] == int(query_pid)]
if len(query_row) > 0:
    print("Query Paper:")
    print(f"ID: {query_pid}")
    print(f"Title: {query_row['title'].values[0] if 'title' in query_row.columns else 'Unknown'}")
    print(f"Year: {query_row['pubDate'].values[0]}")
    print(f"Venue: {query_row['venue'].values[0]}")
    print(f"Abstract: {query_row['text'].values[0][:400]}...")
else:
    print("Query paper not found in all_df.")

print("\nTop 5 similar papers:")
print("-" * 50)
for i in topk:
    pid = list(paper2idx.keys())[i]
    row = all_df.loc[all_df["publication_ID"] == int(pid)]
    if len(row) > 0:
        title = row["title"].values[0] if "title" in row.columns else "Unknown"
        print(f"{pid} — {title}")
    else:
        print(f"{pid} — Not found")


Query Paper:
ID: 17396995
Title: Herpes simplex virus type 2 infection does not influence viral dynamics during early HIV 1 infection
Year: 2007
Venue: The Journal of infectious diseases
Abstract: Herpes simplex virus type 2 infection does not influence viral dynamics during early HIV 1 infection. We sought to compare baseline and longitudinal plasma HIV-1 loads between herpes simplex virus type 2 (HSV-2)-seropositive and -seronegative individuals who are enrolled in a primary HIV-1 infection cohort in San Diego, California.. Adult;California;epidemiology;Cohort Studies;HIV Infections;blood...

Top 5 similar papers:
--------------------------------------------------
17264332 — Clinicopathologic features of osteosarcoma in patients with Rothmund Thomson syndrome
15372107 — PDX 1 haploinsufficiency limits the compensatory islet hyperplasia that occurs in response to insulin resistance
19759291 — Medial prefrontal cortex secondary hyperalgesia and the default mode network
15983384 — High 

In [36]:
from sklearn.metrics import roc_auc_score, average_precision_score

@torch.no_grad()
def evaluate(model, data, device):
    model.eval()
    out = model(data.x_dict, data.edge_index_dict)
    src_emb = out['paper']
    dst_emb = out['paper']

    # Positive and negative edges from test split
    pos_edges = data['paper', 'cites', 'paper'].edge_label_index[:, data['paper', 'cites', 'paper'].edge_label == 1]
    neg_edges = data['paper', 'cites', 'paper'].edge_label_index[:, data['paper', 'cites', 'paper'].edge_label == 0]

    # Compute scores (dot product similarity)
    pos_scores = (src_emb[pos_edges[0]] * dst_emb[pos_edges[1]]).sum(dim=1).cpu().numpy()
    neg_scores = (src_emb[neg_edges[0]] * dst_emb[neg_edges[1]]).sum(dim=1).cpu().numpy()

    y_true = np.concatenate([np.ones_like(pos_scores), np.zeros_like(neg_scores)])
    y_scores = np.concatenate([pos_scores, neg_scores])

    auc = roc_auc_score(y_true, y_scores)
    ap = average_precision_score(y_true, y_scores)

    return auc, ap


In [37]:
val_auc, val_ap = evaluate(model, val_data, device)
print(f"Validation AUC: {val_auc:.4f}, AP: {val_ap:.4f}")

test_auc, test_ap = evaluate(model, test_data, device)
print(f"Test AUC: {test_auc:.4f}, AP: {test_ap:.4f}")


Validation AUC: 0.9093, AP: 0.8504
Test AUC: 0.9050, AP: 0.8441


# =========================================
# PART 2: GRIL ALGORITHMS IMPLEMENTATION
# =========================================

This section implements the GRIL algorithms for citation recommendation:
- Algorithm 1: Attention-based Graph Retriever (with Entity Updates and Gumbel-Softmax)
- Algorithm 3: SAG Pooling Layer
- Algorithm 4: Joint Training Framework
- Algorithm 5: Graph Supervision
- LLM Integration: Llama3-8B with LoRA
- Verbalization: Triple-to-text conversion


In [38]:
# Verify Part 1 outputs are available for Part 2
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import HeteroData

# Check required variables from Part 1
required_vars = {
    'model': 'HeteroGNN model (loaded from checkpoint)',
    'data': 'HeteroData graph',
    'paper2idx': 'Paper ID to index mapping',
    'all_df': 'DataFrame with paper metadata',
    'device': 'torch.device (cuda/cpu)',
    'checkpoints_dir': 'Path to checkpoints directory'
}

missing = []
for var_name, description in required_vars.items():
    if var_name not in globals():
        missing.append(f"{var_name} ({description})")
    else:
        print(f"[OK] {var_name}: {type(globals()[var_name]).__name__}")

if missing:
    print(f"\n[WARNING]  Missing variables: {', '.join(missing)}")
    print("Please run Part 1 cells first (especially model loading cell).")
else:
    print("\nAll required variables from Part 1 are available!")
    print(f"   Model: {type(model).__name__}")
    print(f"   Data: {type(data).__name__} with {len(data.node_types)} node types")
    print(f"   Device: {device}")


 model: HeteroGNN
 data: HeteroData
 paper2idx: dict
 all_df: DataFrame
 device: device
 checkpoints_dir: PosixPath

 All required variables from Part 1 are available!
   Model: HeteroGNN
   Data: HeteroData with 5 node types
   Device: cuda


## Query Encoder (SentenceTransformer)


In [39]:
# =========================================
# QUERY ENCODER
# =========================================
from sentence_transformers import SentenceTransformer
from typing import List
import torch

class QueryEncoder:
    """
    Encodes natural language queries into dense vectors using SentenceTransformer.
    """
    def __init__(self, model_name: str = 'all-MiniLM-L6-v2', device=None):
        """
        Args:
            model_name: SentenceTransformer model name (default: 'all-MiniLM-L6-v2' - 384 dim)
            device: Device to run the model on
        """
        self.device = device if device is not None else torch.device('cpu')
        self.model = SentenceTransformer(model_name, device=self.device)
        self.embedding_dim = self.model.get_sentence_embedding_dimension()
        print(f"QueryEncoder initialized with {model_name}, embedding_dim={self.embedding_dim}")
    
    def encode(self, query_text: str, convert_to_tensor: bool = True) -> torch.Tensor:
        """Encode a single query text into a dense vector."""
        embedding = self.model.encode(query_text, convert_to_tensor=convert_to_tensor, device=self.device)
        return embedding
    
    def encode_batch(self, query_texts: List[str], batch_size: int = 32) -> torch.Tensor:
        """Encode a batch of query texts."""
        embeddings = self.model.encode(query_texts, batch_size=batch_size, 
                                      convert_to_tensor=True, device=self.device)
        return embeddings

# Initialize query encoder
query_encoder = QueryEncoder(device=device)


QueryEncoder initialized with all-MiniLM-L6-v2, embedding_dim=384


## Attention-based Relevance Scorer


In [40]:
# =========================================
# ATTENTION-BASED RELEVANCE SCORER
# =========================================
import torch.nn as nn
import torch.nn.functional as F

class RelevanceScorer(nn.Module):
    """
    Computes relevance scores between query, source node, and destination node embeddings.
    This is the core attention mechanism for Algorithm 1.
    """
    def __init__(self, query_dim: int, node_dim: int, hidden_dim: int = 128):
        super().__init__()
        
        # Project query to node dimension if needed
        if query_dim != node_dim:
            self.query_proj = nn.Linear(query_dim, node_dim)
        else:
            self.query_proj = nn.Identity()
        
        # MLP for computing relevance: f(query, src, dst) -> score
        self.mlp = nn.Sequential(
            nn.Linear(node_dim * 3, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 1)
        )
    
    def forward(self, query_emb: torch.Tensor, src_embs: torch.Tensor, 
                dst_embs: torch.Tensor) -> torch.Tensor:
        """Compute relevance scores for edges."""
        query_proj = self.query_proj(query_emb)
        if query_proj.dim() == 1:
            query_proj = query_proj.unsqueeze(0)
        
        num_edges = src_embs.size(0)
        query_expanded = query_proj.expand(num_edges, -1)
        combined = torch.cat([query_expanded, src_embs, dst_embs], dim=1)
        scores = self.mlp(combined).squeeze(-1)
        scores = torch.sigmoid(scores)
        return scores

# Initialize relevance scorer
query_dim = query_encoder.embedding_dim  # 384
node_dim = 64  # From HeteroGNN output
relevance_scorer = RelevanceScorer(query_dim=query_dim, node_dim=node_dim).to(device)
print(f"RelevanceScorer initialized (query_dim={query_dim}, node_dim={node_dim})")


 RelevanceScorer initialized (query_dim=384, node_dim=64)


## Algorithm 1: Attention-based Growing and Pruning


In [41]:
# =========================================
# ALGORITHM 1: ATTENTION-BASED GROWING AND PRUNING
# =========================================
from typing import Dict, List, Tuple, Set, Optional
import torch
import numpy as np
from torch_geometric.data import HeteroData

def attention_based_graph_retriever(
    query_text: str,
    query_seed_entities: Dict[str, List[int]],
    gnn_model: nn.Module,
    full_data: HeteroData,
    query_encoder: QueryEncoder,
    relevance_scorer: RelevanceScorer,
    max_hops: int = 2,
    relevance_threshold: float = 0.1,
    max_nodes_per_hop: Optional[int] = None,
    device: torch.device = None
) -> HeteroData:
    """
    Implements Algorithm 1: Attention-based Growing and Pruning for dynamic subgraph retrieval.
    
    Args:
        query_text: Natural language query (e.g., paper title + abstract)
        query_seed_entities: Dict mapping node types to lists of node indices to start from
        gnn_model: Trained HeteroGNN model
        full_data: Full HeteroData graph
        query_encoder: QueryEncoder instance
        relevance_scorer: RelevanceScorer instance
        max_hops: Maximum number of hops to expand (default: 2)
        relevance_threshold: Minimum relevance score to retain an edge (default: 0.1)
        max_nodes_per_hop: Maximum number of nodes to retain per hop (None = no limit)
        device: Device to run computation on
    
    Returns:
        HeteroData: Retrieved subgraph containing only relevant nodes and edges
    """
    if device is None:
        device = next(gnn_model.parameters()).device
    
    # 1. Encode query
    query_emb = query_encoder.encode(query_text).to(device)
    
    # 2. Get GNN embeddings for all nodes
    gnn_model.eval()
    with torch.no_grad():
        out_dict = gnn_model(full_data.x_dict, full_data.edge_index_dict)
    
    # 3. Initialize current nodes (seed entities)
    current_nodes: Dict[str, Set[int]] = {k: set(v) for k, v in query_seed_entities.items()}
    
    # 4. Track all retained edges and nodes across hops
    retained_edges: Dict[Tuple[str, str, str], List[Tuple[int, int, float]]] = {}
    all_retained_nodes: Dict[str, Set[int]] = {k: set(v) for k, v in current_nodes.items()}
    
    # 5. Multi-hop expansion
    for hop in range(max_hops):
        newly_retained_nodes: Dict[str, Set[int]] = {}
        
        # Process each edge type
        for etype in full_data.edge_types:
            src_type, rel_type, dst_type = etype
            edge_index = full_data[etype].edge_index
            
            # 5.1. Grow: Filter edges starting from current nodes
            if src_type not in current_nodes or len(current_nodes[src_type]) == 0:
                continue
            
            src_nodes_tensor = torch.tensor(list(current_nodes[src_type]), device=device)
            src_mask = torch.isin(edge_index[0], src_nodes_tensor)
            potential_edge_indices = torch.where(src_mask)[0]
            
            if potential_edge_indices.numel() == 0:
                continue
            
            # Get source and destination indices for potential edges
            potential_src = edge_index[0, potential_edge_indices]
            potential_dst = edge_index[1, potential_edge_indices]
            
            # 5.2. Score: Calculate relevance for each potential edge
            # Clone tensors to detach from inference mode (allows use in autograd if needed)
            src_node_embs = out_dict[src_type][potential_src].clone()
            dst_node_embs = out_dict[dst_type][potential_dst].clone()
            # For inference, wrap in no_grad to avoid gradient computation
            with torch.no_grad():
                relevance_scores = relevance_scorer(query_emb, src_node_embs, dst_node_embs)
            
            # 5.3. Prune: Filter edges below threshold
            relevant_mask = relevance_scores > relevance_threshold
            
            if relevant_mask.sum() == 0:
                continue
            
            # Get retained edges
            retained_src = potential_src[relevant_mask].cpu().numpy()
            retained_dst = potential_dst[relevant_mask].cpu().numpy()
            retained_scores = relevance_scores[relevant_mask].cpu().numpy()
            
            # Optionally limit nodes per hop
            if max_nodes_per_hop is not None:
                top_k = min(max_nodes_per_hop, len(retained_scores))
                top_indices = np.argsort(retained_scores)[::-1][:top_k]
                retained_src = retained_src[top_indices]
                retained_dst = retained_dst[top_indices]
                retained_scores = retained_scores[top_indices]
            
            # Store retained edges
            edge_triples = list(zip(retained_src.tolist(), retained_dst.tolist(), retained_scores.tolist()))
            if etype not in retained_edges:
                retained_edges[etype] = []
            retained_edges[etype].extend(edge_triples)
            
            # 5.4. Update: Add destination nodes for next hop
            newly_retained_nodes.setdefault(dst_type, set()).update(retained_dst.tolist())
            all_retained_nodes.setdefault(dst_type, set()).update(retained_dst.tolist())
        
        # Merge newly retained nodes for next iteration
        for node_type, new_nodes in newly_retained_nodes.items():
            current_nodes.setdefault(node_type, set()).update(new_nodes)
    
    # 6. Construct subgraph from retained nodes and edges
    subgraph_data = _construct_subgraph(full_data, all_retained_nodes, retained_edges, device)
    return subgraph_data


def _construct_subgraph(
    full_data: HeteroData,
    retained_nodes: Dict[str, Set[int]],
    retained_edges: Dict[Tuple[str, str, str], List[Tuple[int, int, float]]],
    device: torch.device
) -> HeteroData:
    """Construct a HeteroData subgraph from retained nodes and edges."""
    subgraph_data = HeteroData().to(device)
    
    # Add node features for retained nodes
    for node_type in retained_nodes:
        if node_type not in full_data.node_types:
            continue
        
        node_indices = sorted(list(retained_nodes[node_type]))
        if len(node_indices) == 0:
            continue
        
        node_mapping = {orig_idx: new_idx for new_idx, orig_idx in enumerate(node_indices)}
        node_tensor = torch.tensor(node_indices, device=device)
        subgraph_data[node_type].x = full_data[node_type].x[node_tensor]
        subgraph_data[node_type]._node_mapping = node_mapping
    
    # Add edges
    for etype, edge_list in retained_edges.items():
        if len(edge_list) == 0:
            continue
        
        src_type, rel_type, dst_type = etype
        src_mapping = subgraph_data[src_type]._node_mapping
        dst_mapping = subgraph_data[dst_type]._node_mapping
        
        edge_src = []
        edge_dst = []
        
        for src, dst, score in edge_list:
            if src in src_mapping and dst in dst_mapping:
                edge_src.append(src_mapping[src])
                edge_dst.append(dst_mapping[dst])
        
        if len(edge_src) > 0:
            edge_index = torch.tensor([edge_src, edge_dst], dtype=torch.long, device=device)
            subgraph_data[etype].edge_index = edge_index
    
    return subgraph_data

print("Algorithm 1 functions defined")


 Algorithm 1 functions defined


## Entity Embedding Update (Eq. 4) and Gumbel-Softmax Sampling (Eq. 5)

This section implements:
- **Entity Embedding Update (Eq. 4)**: Message-passing mechanism that updates entity embeddings after each growing/pruning step
- **Gumbel-Softmax Sampling (Eq. 5)**: Differentiable subgraph sampling for end-to-end training


In [42]:
# =========================================
# ENTITY EMBEDDING UPDATE MODULE (Eq. 4)
# =========================================
class EntityEmbeddingUpdater(nn.Module):
    """
    Implements Equation 4 from GRIL paper:
    h'_ei = W₁h_ei + W₂ Σ_{j∈N(vi)} α_ji h_ej
    
    Updates entity embeddings through attention-weighted message passing.
    """
    def __init__(self, node_dim: int, hidden_dim: int = None):
        super().__init__()
        self.node_dim = node_dim
        self.hidden_dim = hidden_dim if hidden_dim is not None else node_dim
        
        # W₁: Linear transformation for self-embedding
        self.W1 = nn.Linear(node_dim, self.hidden_dim)
        
        # W₂: Linear transformation for aggregated neighbor embeddings
        self.W2 = nn.Linear(node_dim, self.hidden_dim)
        
        # Optional: Layer normalization
        self.layer_norm = nn.LayerNorm(self.hidden_dim)
    
    def forward(self, 
                node_embeddings: torch.Tensor,
                neighbor_embeddings: torch.Tensor,
                attention_weights: torch.Tensor,
                node_indices: torch.Tensor = None) -> torch.Tensor:
        """
        Update entity embeddings using attention-weighted message passing.
        
        Args:
            node_embeddings: Current node embeddings [num_nodes, node_dim]
            neighbor_embeddings: Neighbor embeddings [num_edges, node_dim]
            attention_weights: Attention scores α_ji [num_edges]
            node_indices: Optional mapping from edges to nodes [num_edges] (source node for each edge)
        
        Returns:
            Updated embeddings [num_nodes, hidden_dim]
        """
        # W₁h_ei: Self-embedding transformation
        self_emb = self.W1(node_embeddings)  # [num_nodes, hidden_dim]
        
        # W₂ Σ_{j∈N(vi)} α_ji h_ej: Aggregated neighbor embeddings
        if node_indices is not None:
            # Aggregate neighbors for each node using attention weights
            # neighbor_embeddings: [num_edges, node_dim]
            # attention_weights: [num_edges]
            weighted_neighbors = neighbor_embeddings * attention_weights.unsqueeze(-1)  # [num_edges, node_dim]
            
            # Aggregate by node (sum neighbors for each node)
            num_nodes = node_embeddings.size(0)
            aggregated = torch.zeros(num_nodes, self.node_dim, 
                                    device=node_embeddings.device, 
                                    dtype=node_embeddings.dtype)
            aggregated.index_add_(0, node_indices, weighted_neighbors)
            
            # Transform aggregated neighbors
            neighbor_emb = self.W2(aggregated)  # [num_nodes, hidden_dim]
        else:
            # Simple case: if node_indices not provided, assume direct aggregation
            # This is a fallback - in practice, node_indices should be provided
            weighted_neighbors = neighbor_embeddings * attention_weights.unsqueeze(-1)
            neighbor_emb = self.W2(weighted_neighbors.mean(dim=0, keepdim=True).expand_as(self_emb))
        
        # Combine: h'_ei = W₁h_ei + W₂ Σ_{j∈N(vi)} α_ji h_ej
        updated_emb = self_emb + neighbor_emb
        
        # Apply layer normalization
        updated_emb = self.layer_norm(updated_emb)
        
        return updated_emb

# Initialize entity embedding updater
entity_updater = EntityEmbeddingUpdater(node_dim=64, hidden_dim=64).to(device)
print(f"EntityEmbeddingUpdater initialized (node_dim=64, hidden_dim=64)")


 EntityEmbeddingUpdater initialized (node_dim=64, hidden_dim=64)


In [43]:
# =========================================
# GUMBEL-SOFTMAX SAMPLING (Eq. 5)
# =========================================
def gumbel_softmax_sampling(
    logits: torch.Tensor,
    temperature: float = 1.0,
    hard: bool = False,
    training: bool = True
) -> torch.Tensor:
    """
    Implements Equation 5 from GRIL paper:
    Mi = σ((log(ϵ/(1-ϵ)) + log(Pi/(1-Pi))) / τ)
    
    Differentiable subgraph sampling using Gumbel-Softmax reparameterization trick.
    
    Args:
        logits: Probability logits for each edge/triplet [num_edges] or [batch_size, num_edges]
        temperature: Temperature parameter τ (default: 1.0)
        hard: If True, returns hard (one-hot) samples, but gradients flow through soft samples
        training: If False, uses hard sampling (deterministic)
    
    Returns:
        Sampled mask M [num_edges] or [batch_size, num_edges]
    """
    if not training:
        # During inference, use hard thresholding
        probs = torch.sigmoid(logits)
        return (probs > 0.5).float()
    
    # Convert logits to probabilities
    # Pi = sigmoid(logits) for binary case
    probs = torch.sigmoid(logits)
    
    # Avoid numerical issues
    eps = 1e-10
    probs = torch.clamp(probs, eps, 1.0 - eps)
    
    # Sample Gumbel noise: ϵ ~ Uniform(0, 1)
    # Gumbel(0, 1) = -log(-log(U)) where U ~ Uniform(0, 1)
    uniform_noise = torch.rand_like(probs)
    uniform_noise = torch.clamp(uniform_noise, eps, 1.0 - eps)
    gumbel_noise = -torch.log(-torch.log(uniform_noise))
    
    # Compute log(Pi / (1 - Pi))
    log_odds = torch.log(probs / (1 - probs))
    
    # Apply Gumbel-Softmax: Mi = σ((log(ϵ/(1-ϵ)) + log(Pi/(1-Pi))) / τ)
    # Note: log(ϵ/(1-ϵ)) is the Gumbel noise
    y = (log_odds + gumbel_noise) / temperature
    soft_samples = torch.sigmoid(y)
    
    if hard:
        # Hard sampling: return one-hot, but gradients flow through soft
        hard_samples = (soft_samples > 0.5).float()
        # Straight-through estimator: use hard in forward, soft in backward
        return hard_samples + soft_samples - soft_samples.detach()
    else:
        return soft_samples


def compute_triplet_probabilities(
    relevance_scores: torch.Tensor,
    attention_weights: torch.Tensor = None
) -> torch.Tensor:
    """
    Compute probability scores P on triplets for Gumbel-Softmax sampling.
    
    Args:
        relevance_scores: Relevance scores from RelevanceScorer [num_edges]
        attention_weights: Optional attention weights α_ij [num_edges]
    
    Returns:
        Probability logits for triplets [num_edges]
    """
    if attention_weights is not None:
        # Combine relevance and attention: P = relevance * attention
        probs = relevance_scores * attention_weights
    else:
        probs = relevance_scores
    
    # Convert to logits for Gumbel-Softmax
    # Use logit transform: logit(p) = log(p / (1-p))
    eps = 1e-10
    probs = torch.clamp(probs, eps, 1.0 - eps)
    logits = torch.log(probs / (1 - probs))
    
    return logits

print("Gumbel-Softmax sampling functions defined")


 Gumbel-Softmax sampling functions defined


## Enhanced Algorithm 1: With Entity Updates and Gumbel-Softmax

This enhanced version integrates:
- Entity embedding updates after each hop (Eq. 4)
- Gumbel-Softmax sampling for differentiable subgraph selection (Eq. 5)
- Support for both training and inference modes


In [44]:
# =========================================
# ENHANCED ALGORITHM 1: WITH ENTITY UPDATES AND GUMBEL-SOFTMAX
# =========================================
def attention_based_graph_retriever_enhanced(
    query_text: str,
    query_seed_entities: Dict[str, List[int]],
    gnn_model: nn.Module,
    full_data: HeteroData,
    query_encoder: QueryEncoder,
    relevance_scorer: RelevanceScorer,
    entity_updater: EntityEmbeddingUpdater = None,
    max_hops: int = 2,
    relevance_threshold: float = 0.1,
    max_nodes_per_hop: Optional[int] = None,
    use_gumbel_softmax: bool = True,
    gumbel_temperature: float = 1.0,
    training: bool = False,
    device: torch.device = None
) -> Tuple[HeteroData, Dict]:
    """
    Enhanced Algorithm 1 with Entity Embedding Updates (Eq. 4) and Gumbel-Softmax Sampling (Eq. 5).
    
    Args:
        query_text: Natural language query (e.g., paper title + abstract)
        query_seed_entities: Dict mapping node types to lists of node indices to start from
        gnn_model: Trained HeteroGNN model
        full_data: Full HeteroData graph
        query_encoder: QueryEncoder instance
        relevance_scorer: RelevanceScorer instance
        entity_updater: EntityEmbeddingUpdater instance (optional, for Eq. 4)
        max_hops: Maximum number of hops to expand (default: 2)
        relevance_threshold: Minimum relevance score to retain an edge (default: 0.1)
        max_nodes_per_hop: Maximum number of nodes to retain per hop (None = no limit)
        use_gumbel_softmax: Whether to use Gumbel-Softmax sampling (default: True)
        gumbel_temperature: Temperature for Gumbel-Softmax (default: 1.0)
        training: Whether in training mode (affects Gumbel-Softmax behavior)
        device: Device to run computation on
    
    Returns:
        Tuple of:
        - HeteroData: Retrieved subgraph
        - Dict: Additional info (updated embeddings, triplet probabilities, etc.)
    """
    if device is None:
        device = next(gnn_model.parameters()).device
    
    # 1. Encode query
    # Ensure encoding is done outside inference mode
    if training:
        # During training, we want regular tensors
        query_emb = query_encoder.encode(query_text, convert_to_tensor=True).to(device)
        # Clone to ensure it's a regular tensor (not inference mode)
        query_emb = query_emb.clone()
    else:
        # During inference, use no_grad and detach
        with torch.no_grad():
            query_emb = query_encoder.encode(query_text, convert_to_tensor=True).to(device)
        query_emb = query_emb.clone().detach()
    
    # 2. Get initial GNN embeddings for all nodes
    gnn_model.eval()
    if training:
        # During training, we freeze GNN (no gradients through GNN)
        # Entity updater will create new tensors with gradients
        with torch.no_grad():
            out_dict = gnn_model(full_data.x_dict, full_data.edge_index_dict)
        # Clone to make them regular tensors, but don't require grad on initial embeddings
        # The entity updater will create new tensors with gradients
        updated_embeddings = {node_type: emb.clone().detach() 
                            for node_type, emb in out_dict.items()}
    else:
        # During inference, use no_grad
        with torch.no_grad():
            out_dict = gnn_model(full_data.x_dict, full_data.edge_index_dict)
        # Clone to ensure they're regular tensors (not inference mode)
        updated_embeddings = {node_type: emb.clone().detach() 
                            for node_type, emb in out_dict.items()}
    
    # 3. Initialize current nodes (seed entities)
    current_nodes: Dict[str, Set[int]] = {k: set(v) for k, v in query_seed_entities.items()}
    
    # 4. Track all retained edges and nodes across hops
    retained_edges: Dict[Tuple[str, str, str], List[Tuple[int, int, float]]] = {}
    all_retained_nodes: Dict[str, Set[int]] = {k: set(v) for k, v in current_nodes.items()}
    
    # Track triplet probabilities for Gumbel-Softmax
    all_triplet_logits: Dict[Tuple[str, str, str], torch.Tensor] = {}
    all_attention_weights: Dict[Tuple[str, str, str], torch.Tensor] = {}
    
    # 5. Multi-hop expansion with entity embedding updates
    for hop in range(max_hops):
        newly_retained_nodes: Dict[str, Set[int]] = {}
        hop_edges: Dict[Tuple[str, str, str], List[Tuple[int, int, float, torch.Tensor]]] = {}
        
        # Process each edge type
        for etype in full_data.edge_types:
            src_type, rel_type, dst_type = etype
            edge_index = full_data[etype].edge_index
            
            # 5.1. Grow: Filter edges starting from current nodes
            if src_type not in current_nodes or len(current_nodes[src_type]) == 0:
                continue
            
            src_nodes_tensor = torch.tensor(list(current_nodes[src_type]), device=device)
            src_mask = torch.isin(edge_index[0], src_nodes_tensor)
            potential_edge_indices = torch.where(src_mask)[0]
            
            if potential_edge_indices.numel() == 0:
                continue
            
            # Get source and destination indices for potential edges
            potential_src = edge_index[0, potential_edge_indices]
            potential_dst = edge_index[1, potential_edge_indices]
            
            # 5.2. Score: Calculate relevance using UPDATED embeddings
            # Use updated embeddings from previous hop (or initial if first hop)
            src_node_embs = updated_embeddings[src_type][potential_src]
            dst_node_embs = updated_embeddings[dst_type][potential_dst]
            
            # Compute relevance scores (with gradients if training)
            if training:
                relevance_scores = relevance_scorer(query_emb, src_node_embs, dst_node_embs)
            else:
                # During inference, use no_grad to avoid autograd tracking
                with torch.no_grad():
                    relevance_scores = relevance_scorer(query_emb, src_node_embs, dst_node_embs)
            
            # Compute attention weights (softmax over neighbors for each source node)
            # Group by source node and compute softmax
            if training:
                attention_weights = relevance_scores.clone()
                if len(potential_src) > 0:
                    # Normalize attention scores per source node
                    src_unique, src_inverse = torch.unique(potential_src, return_inverse=True)
                    for src_idx in range(len(src_unique)):
                        mask = (src_inverse == src_idx)
                        if mask.sum() > 1:
                            attention_weights[mask] = F.softmax(relevance_scores[mask] / 0.1, dim=0)
                        else:
                            attention_weights[mask] = 1.0
            else:
                # During inference, compute in no_grad context
                with torch.no_grad():
                    attention_weights = relevance_scores.clone()
                    if len(potential_src) > 0:
                        # Normalize attention scores per source node
                        src_unique, src_inverse = torch.unique(potential_src, return_inverse=True)
                        for src_idx in range(len(src_unique)):
                            mask = (src_inverse == src_idx)
                            if mask.sum() > 1:
                                attention_weights[mask] = F.softmax(relevance_scores[mask] / 0.1, dim=0)
                            else:
                                attention_weights[mask] = 1.0
            
            # 5.3. Prune: Filter edges below threshold
            relevant_mask = relevance_scores > relevance_threshold
            
            if relevant_mask.sum() == 0:
                continue
            
            # Apply Gumbel-Softmax sampling if enabled
            if use_gumbel_softmax and training:
                # Compute triplet probabilities
                triplet_logits = compute_triplet_probabilities(
                    relevance_scores[relevant_mask],
                    attention_weights[relevant_mask]
                )
                
                # Sample using Gumbel-Softmax
                sampled_mask = gumbel_softmax_sampling(
                    triplet_logits,
                    temperature=gumbel_temperature,
                    hard=False,
                    training=training
                )
                
                # Further filter by sampled mask
                sampled_indices = torch.where(sampled_mask > 0.5)[0]
                if len(sampled_indices) == 0:
                    continue
                
                # Get final retained edges
                final_mask = torch.zeros_like(relevant_mask)
                relevant_indices = torch.where(relevant_mask)[0]
                final_mask[relevant_indices[sampled_indices]] = True
            else:
                # Standard thresholding (inference mode)
                final_mask = relevant_mask
                if use_gumbel_softmax and not training:
                    # During inference, use hard thresholding
                    triplet_logits = compute_triplet_probabilities(
                        relevance_scores[relevant_mask],
                        attention_weights[relevant_mask]
                    )
                    sampled_mask = gumbel_softmax_sampling(
                        triplet_logits,
                        temperature=gumbel_temperature,
                        hard=True,
                        training=False
                    )
                else:
                    sampled_mask = None
            
            if final_mask.sum() == 0:
                continue
            
            # Get retained edges
            # Detach before converting to numpy if tensor requires grad
            if training:
                retained_src = potential_src[final_mask].detach().cpu().numpy()
                retained_dst = potential_dst[final_mask].detach().cpu().numpy()
                retained_scores = relevance_scores[final_mask].detach().cpu().numpy()
            else:
                retained_src = potential_src[final_mask].cpu().numpy()
                retained_dst = potential_dst[final_mask].cpu().numpy()
                retained_scores = relevance_scores[final_mask].cpu().numpy()
            retained_attention = attention_weights[final_mask]
            
            # Optionally limit nodes per hop
            if max_nodes_per_hop is not None:
                top_k = min(max_nodes_per_hop, len(retained_scores))
                # Use argsort and get top_k indices - ensure contiguous array
                sorted_indices = np.argsort(retained_scores)
                # Get top_k indices in descending order (largest first)
                # Use np.flip to reverse, but ensure it's a copy, not a view
                top_indices = np.flip(sorted_indices[-top_k:], axis=0).copy()  # Copy to avoid negative strides
                # Ensure arrays are contiguous to avoid negative stride issues
                retained_src = np.ascontiguousarray(retained_src[top_indices])
                retained_dst = np.ascontiguousarray(retained_dst[top_indices])
                retained_scores = np.ascontiguousarray(retained_scores[top_indices])
                # For tensor, use indexing and ensure contiguous
                retained_attention = retained_attention[top_indices]
                if isinstance(retained_attention, torch.Tensor):
                    retained_attention = retained_attention.contiguous()
            
            # Store retained edges with attention weights
            # Convert attention to numpy and ensure it's contiguous
            if isinstance(retained_attention, torch.Tensor):
                attention_numpy = retained_attention.detach().cpu().numpy()
                # Ensure contiguous array to avoid negative stride issues
                attention_numpy = np.ascontiguousarray(attention_numpy)
                attention_list = attention_numpy.tolist()
            else:
                attention_list = retained_attention.tolist() if hasattr(retained_attention, 'tolist') else list(retained_attention)
            
            edge_triples = list(zip(
                retained_src.tolist(), 
                retained_dst.tolist(), 
                retained_scores.tolist(),
                attention_list
            ))
            if etype not in retained_edges:
                retained_edges[etype] = []
            retained_edges[etype].extend([(s, d, sc) for s, d, sc, _ in edge_triples])
            
            # Store for entity embedding update
            hop_edges[etype] = edge_triples
            
            # Store triplet logits and attention weights
            if use_gumbel_softmax:
                if etype not in all_triplet_logits:
                    all_triplet_logits[etype] = []
                all_triplet_logits[etype].append(triplet_logits if sampled_mask is None else triplet_logits)
                all_attention_weights[etype] = retained_attention
            
            # 5.4. Update: Add destination nodes for next hop
            newly_retained_nodes.setdefault(dst_type, set()).update(retained_dst.tolist())
            all_retained_nodes.setdefault(dst_type, set()).update(retained_dst.tolist())
        
        # 5.5. ENTITY EMBEDDING UPDATE (Eq. 4) after each hop
        if entity_updater is not None and hop < max_hops - 1:  # Don't update after last hop
            for node_type in updated_embeddings.keys():
                if node_type not in all_retained_nodes or len(all_retained_nodes[node_type]) == 0:
                    continue
                
                # Collect edges involving this node type
                node_edges = []
                for etype, edge_list in hop_edges.items():
                    src_type, rel_type, dst_type = etype
                    
                    if src_type == node_type:
                        # Outgoing edges: node is source
                        for src, dst, score, attn in edge_list:
                            node_edges.append((src, dst, attn, 'out'))
                    elif dst_type == node_type:
                        # Incoming edges: node is destination
                        for src, dst, score, attn in edge_list:
                            node_edges.append((src, dst, attn, 'in'))
                
                if len(node_edges) == 0:
                    continue
                
                # Prepare for message passing
                node_indices_list = []
                neighbor_emb_list = []
                attention_list = []
                
                for src, dst, attn, direction in node_edges:
                    if direction == 'out':
                        # Node is source, neighbor is destination
                        node_idx = src
                        neighbor_idx = dst
                        neighbor_type = dst_type
                    else:
                        # Node is destination, neighbor is source
                        node_idx = dst
                        neighbor_idx = src
                        neighbor_type = src_type
                    
                    # Only process if node is in retained nodes
                    # The neighbor should exist in the full graph embeddings
                    if node_idx not in all_retained_nodes[node_type]:
                        continue
                    
                    # Verify neighbor type and index are valid
                    if neighbor_type not in updated_embeddings:
                        continue
                    
                    embedding_size = updated_embeddings[neighbor_type].size(0)
                    if neighbor_idx >= embedding_size:
                        # Skip if neighbor index is out of bounds for this node type
                        # This can happen if there's a mismatch between edge indices and node types
                        continue
                    
                    # Get neighbor embedding safely (neighbor might not be in retained set yet,
                    # but we can still access its embedding from the full graph)
                    try:
                        neighbor_emb = updated_embeddings[neighbor_type][neighbor_idx]
                        
                        node_indices_list.append(int(node_idx))  # Ensure Python int
                        neighbor_emb_list.append(neighbor_emb)
                        # Convert attn to Python float if it's a numpy scalar
                        if isinstance(attn, (np.number, np.ndarray)):
                            attention_list.append(float(attn))
                        elif isinstance(attn, torch.Tensor):
                            attention_list.append(float(attn.item()))
                        else:
                            attention_list.append(float(attn))
                    except (IndexError, RuntimeError) as e:
                        # Skip if there's any error accessing the embedding
                        continue
                
                if len(node_indices_list) > 0:
                    # Convert to tensors
                    # Ensure node_indices_list contains Python ints, not numpy scalars
                    node_indices_list_clean = [int(idx) for idx in node_indices_list]
                    node_indices_tensor = torch.tensor(node_indices_list_clean, device=device, dtype=torch.long)
                    neighbor_embs_tensor = torch.stack(neighbor_emb_list)
                    # Ensure attention_list contains Python floats, not numpy scalars
                    attention_list_clean = [float(attn) if isinstance(attn, (np.number, np.ndarray)) else attn for attn in attention_list]
                    attention_tensor = torch.tensor(attention_list_clean, device=device, dtype=neighbor_embs_tensor.dtype)
                    
                    # Get current node embeddings
                    current_node_indices = sorted(list(all_retained_nodes[node_type]))
                    current_embs = updated_embeddings[node_type][current_node_indices]
                    
                    # Create mapping from node indices to position in current_embs
                    idx_to_pos = {idx: pos for pos, idx in enumerate(current_node_indices)}
                    node_positions = torch.tensor([idx_to_pos[idx] for idx in node_indices_list], device=device)
                    
                    # Update embeddings using EntityEmbeddingUpdater
                    if training:
                        updated_embs = entity_updater(
                            current_embs,
                            neighbor_embs_tensor,
                            attention_tensor,
                            node_positions
                        )
                    else:
                        # During inference, use no_grad
                        with torch.no_grad():
                            updated_embs = entity_updater(
                                current_embs,
                                neighbor_embs_tensor,
                                attention_tensor,
                                node_positions
                            )
                    
                    # Update stored embeddings
                    # Since initial embeddings don't require grad, we can do in-place updates
                    # The entity updater creates new tensors with gradients, which we store
                    if not isinstance(current_node_indices, torch.Tensor):
                        indices_tensor = torch.tensor(current_node_indices, device=device, dtype=torch.long)
                    else:
                        indices_tensor = current_node_indices.to(device)
                    
                    # Update embeddings (updated_embs from entity_updater has gradients if training)
                    updated_embeddings[node_type][indices_tensor] = updated_embs
        
        # Merge newly retained nodes for next iteration
        for node_type, new_nodes in newly_retained_nodes.items():
            current_nodes.setdefault(node_type, set()).update(new_nodes)
    
    # 6. Construct subgraph from retained nodes and edges
    subgraph_data = _construct_subgraph(full_data, all_retained_nodes, retained_edges, device)
    
    # Store updated embeddings in subgraph metadata
    subgraph_data._updated_embeddings = updated_embeddings
    subgraph_data._triplet_logits = all_triplet_logits
    subgraph_data._attention_weights = all_attention_weights
    
    info_dict = {
        'updated_embeddings': updated_embeddings,
        'triplet_logits': all_triplet_logits,
        'attention_weights': all_attention_weights
    }
    
    return subgraph_data, info_dict

print("Enhanced Algorithm 1 with Entity Updates and Gumbel-Softmax defined")


 Enhanced Algorithm 1 with Entity Updates and Gumbel-Softmax defined


## Test: Enhanced Algorithm 1

Test the enhanced Algorithm 1 with entity embedding updates and Gumbel-Softmax sampling.


In [45]:
# =========================================
# TEST: Enhanced Algorithm 1
# =========================================
print("="*80)
print("TESTING ENHANCED ALGORITHM 1: With Entity Updates & Gumbel-Softmax")
print("="*80)

# Check dependencies
required = {
    'model': model,
    'data': data,
    'paper2idx': paper2idx,
    'all_df': all_df,
    'query_encoder': query_encoder,
    'relevance_scorer': relevance_scorer,
    'entity_updater': entity_updater,
    'device': device
}

missing = [k for k, v in required.items() if v is None or k not in globals()]
if missing:
    print(f"[ERROR] Missing dependencies: {missing}")
    print("Please run Part 1 and Part 2 initialization cells first.")
else:
    print("All dependencies available")
    
    # Test with a known paper ID
    query_paper_id = "17396995"  # Example paper from our dataset
    
    try:
        # Check if paper exists
        if query_paper_id not in paper2idx:
            print(f"[ERROR] Paper {query_paper_id} not found in paper2idx")
        else:
            query_paper_idx = paper2idx[query_paper_id]
            print(f"[OK] Found query paper: {query_paper_id} (index: {query_paper_idx})")
            
            # Get query text
            query_row = all_df[all_df['publication_ID'] == int(query_paper_id)]
            if len(query_row) == 0:
                print(f"[ERROR] Paper {query_paper_id} not found in all_df")
            else:
                query_text = query_row['text'].values[0]
                print(f"[OK] Query text length: {len(query_text)} characters")
                
                # Define seed entities
                query_seed_entities = {'paper': [query_paper_idx]}
                print(f"[OK] Seed entities: {query_seed_entities}")
                
                # Test 1: Enhanced Algorithm 1 (inference mode)
                print("\n" + "-"*80)
                print("Test 1: Enhanced Algorithm 1 (Inference Mode)")
                print("-"*80)
                
                retrieved_subgraph, info_dict = attention_based_graph_retriever_enhanced(
                    query_text=query_text,
                    query_seed_entities=query_seed_entities,
                    gnn_model=model,
                    full_data=data,
                    query_encoder=query_encoder,
                    relevance_scorer=relevance_scorer,
                    entity_updater=entity_updater,
                    max_hops=2,
                    relevance_threshold=0.1,
                    max_nodes_per_hop=100,
                    use_gumbel_softmax=True,
                    gumbel_temperature=1.0,
                    training=False,  # Inference mode
                    device=device
                )
                
                # Verify results
                print("\n" + "="*80)
                print("[OK] ENHANCED ALGORITHM 1 TEST PASSED (Inference Mode)!")
                print("="*80)
                print(f"Retrieved subgraph statistics:")
                print(f"  Node types: {len(retrieved_subgraph.node_types)}")
                for node_type in retrieved_subgraph.node_types:
                    num_nodes = retrieved_subgraph[node_type].num_nodes
                    print(f"    {node_type}: {num_nodes} nodes")
                
                print(f"\n  Edge types: {len(retrieved_subgraph.edge_types)}")
                for etype in retrieved_subgraph.edge_types:
                    num_edges = retrieved_subgraph[etype].edge_index.size(1) if hasattr(retrieved_subgraph[etype], 'edge_index') else 0
                    print(f"    {etype}: {num_edges} edges")
                
                # Check for updated embeddings
                if hasattr(retrieved_subgraph, '_updated_embeddings'):
                    print(f"\n  [OK] Entity embeddings updated: {len(retrieved_subgraph._updated_embeddings)} node types")
                    for node_type, emb in retrieved_subgraph._updated_embeddings.items():
                        print(f"    {node_type}: {emb.shape}")
                
                # Check for triplet logits (Gumbel-Softmax)
                if hasattr(retrieved_subgraph, '_triplet_logits'):
                    print(f"\n  [OK] Triplet logits computed: {len(retrieved_subgraph._triplet_logits)} edge types")
                
                # Compare with full graph
                print(f"\n  Comparison with full graph:")
                print(f"    Full graph papers: {data['paper'].num_nodes:,}")
                print(f"    Retrieved papers: {retrieved_subgraph['paper'].num_nodes:,}")
                print(f"    Reduction: {(1 - retrieved_subgraph['paper'].num_nodes / data['paper'].num_nodes) * 100:.2f}%")
                
                # Test 2: Training mode (with gradients)
                print("\n" + "-"*80)
                print("Test 2: Enhanced Algorithm 1 (Training Mode)")
                print("-"*80)
                
                # Enable training mode
                entity_updater.train()
                relevance_scorer.train()
                
                retrieved_subgraph_train, info_dict_train = attention_based_graph_retriever_enhanced(
                    query_text=query_text,
                    query_seed_entities=query_seed_entities,
                    gnn_model=model,
                    full_data=data,
                    query_encoder=query_encoder,
                    relevance_scorer=relevance_scorer,
                    entity_updater=entity_updater,
                    max_hops=2,
                    relevance_threshold=0.1,
                    max_nodes_per_hop=50,  # Smaller for training test
                    use_gumbel_softmax=True,
                    gumbel_temperature=1.0,
                    training=True,  # Training mode
                    device=device
                )
                
                print("[OK] Training mode test passed!")
                print(f"  Retrieved {retrieved_subgraph_train['paper'].num_nodes} papers in training mode")
                
                # Verify gradients can flow
                if hasattr(retrieved_subgraph_train, '_updated_embeddings'):
                    # Check if embeddings require grad
                    for node_type, emb in retrieved_subgraph_train._updated_embeddings.items():
                        if emb.requires_grad:
                            print(f"  [OK] Gradients enabled for {node_type} embeddings")
                            break
                
                print("\n[OK] Enhanced Algorithm 1 is working correctly with Entity Updates and Gumbel-Softmax!")
                
    except Exception as e:
        print(f"\n[ERROR] ERROR during Enhanced Algorithm 1 test:")
        print(f"   {type(e).__name__}: {str(e)}")
        import traceback
        print("\nFull traceback:")
        traceback.print_exc()


TESTING ENHANCED ALGORITHM 1: With Entity Updates & Gumbel-Softmax
 All dependencies available
 Found query paper: 17396995 (index: 18104)
 Query text length: 645 characters
 Seed entities: {'paper': [18104]}

--------------------------------------------------------------------------------
Test 1: Enhanced Algorithm 1 (Inference Mode)
--------------------------------------------------------------------------------

 ENHANCED ALGORITHM 1 TEST PASSED (Inference Mode)!
Retrieved subgraph statistics:
  Node types: 5
    paper: 201 nodes
    author: 16 nodes
    keyword: 28 nodes
    venue: 2 nodes
    pubDate: 1 nodes

  Edge types: 8
    ('paper', 'cites', 'paper'): 15 edges
    ('paper', 'written_by', 'author'): 21 edges
    ('paper', 'mentions', 'keyword'): 73 edges
    ('paper', 'published_in', 'venue'): 3 edges
    ('paper', 'publication_date', 'pubDate'): 3 edges
    ('author', 'authored', 'paper'): 6 edges
    ('keyword', 'appears_in', 'paper'): 100 edges
    ('venue', 'published', 

In [46]:
# =========================================
# COMPLEXITY ASSESSMENT MODULE (CAM)
# =========================================
class ComplexityAssessmentModule(nn.Module):
    """
    MLP classifier that predicts question complexity (number of hops).
    Output determines retrieval budget: number of triplets = 5 × predicted_hops
    """
    def __init__(self, query_dim: int, hidden_dim: int = 256, max_hops: int = 4):
        super().__init__()
        self.query_dim = query_dim
        self.hidden_dim = hidden_dim
        self.max_hops = max_hops
        
        self.mlp = nn.Sequential(
            nn.Linear(query_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, max_hops + 1)
        )
    
    def forward(self, query_emb: torch.Tensor) -> torch.Tensor:
        """Predict complexity (number of hops)."""
        return self.mlp(query_emb)
    
    def predict_hops(self, query_emb: torch.Tensor) -> int:
        """Predict number of hops for a query."""
        with torch.no_grad():
            logits = self.forward(query_emb)
            if logits.dim() == 1:
                predicted = torch.argmax(logits, dim=0).item()
            else:
                predicted = torch.argmax(logits, dim=1).item()
        return predicted
    
    def get_retrieval_budget(self, query_emb: torch.Tensor) -> int:
        """Get number of triplets to retrieve based on predicted complexity."""
        hops = self.predict_hops(query_emb)
        return 5 * hops

# Initialize CAM (will be trained separately or jointly)
# cam = ComplexityAssessmentModule(query_dim=query_encoder.embedding_dim).to(device)
print("ComplexityAssessmentModule class defined")


 ComplexityAssessmentModule class defined


## Joint Training Framework


In [47]:
# =========================================
# JOINT TRAINING ALGORITHM
# =========================================
class JointTrainingLoss(nn.Module):
    """
    Joint loss for LLM and retriever optimization.
    Implements Algorithm 4 from GRIL paper, Section 4.3.
    """
    def __init__(self, alpha: float = 1.0, beta: float = 1.0, gamma: float = 0.1, 
                 use_graph_supervision: bool = False):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.use_graph_supervision = use_graph_supervision
    
    def forward(self, llm_logits: torch.Tensor, ground_truth: torch.Tensor,
                triplet_probabilities: Optional[torch.Tensor] = None,
                shortest_path_entities: Optional[torch.Tensor] = None,
                retrieved_entities: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, Dict[str, float]]:
        """Compute joint training loss."""
        # 1. LLM Accuracy Loss
        if llm_logits.dim() == 2 and llm_logits.size(1) > 1:
            if ground_truth.dtype == torch.long:
                accuracy_loss = F.cross_entropy(llm_logits, ground_truth)
            else:
                accuracy_loss = F.binary_cross_entropy_with_logits(llm_logits, ground_truth)
        else:
            accuracy_loss = F.mse_loss(llm_logits, ground_truth.float())
        
        # 2. Retriever Feedback Loss
        retriever_feedback_loss = torch.tensor(0.0, device=llm_logits.device)
        if triplet_probabilities is not None:
            probs = F.softmax(triplet_probabilities, dim=-1)
            entropy = -(probs * torch.log(probs + 1e-10)).sum(dim=-1).mean()
            retriever_feedback_loss = -entropy
        
        # 3. Graph Supervision Loss
        graph_supervision_loss = torch.tensor(0.0, device=llm_logits.device)
        if self.use_graph_supervision and shortest_path_entities is not None:
            if retrieved_entities is not None:
                graph_supervision_loss = F.binary_cross_entropy_with_logits(
                    retrieved_entities.float(), shortest_path_entities.float())
        
        total_loss = (self.alpha * accuracy_loss + self.beta * retriever_feedback_loss + 
                     self.gamma * graph_supervision_loss)
        
        loss_dict = {
            'total_loss': total_loss.item(),
            'accuracy_loss': accuracy_loss.item(),
            'retriever_feedback_loss': retriever_feedback_loss.item(),
            'graph_supervision_loss': graph_supervision_loss.item()
        }
        return total_loss, loss_dict

print("JointTrainingLoss class defined")


 JointTrainingLoss class defined


## SAG (Self-Attention Graph) Pooling Layer (Algorithm 3)

Implements Algorithm 3 from GRIL paper:
- Computes self-attention scores A_s for entities in retrieved subgraph
- Aggregates entity embeddings with attention weights
- Projects to LLM embedding space via MLP
- Outputs soft graph token h_GT for LLM input


In [48]:
# =========================================
# SAG (SELF-ATTENTION GRAPH) POOLING LAYER (Algorithm 3)
# =========================================
class SAGPooling(nn.Module):
    """
    Implements Algorithm 3: Self-Attention Graph Pooling
    
    Generates dense graph-level embedding (soft graph token) from retrieved subgraph.
    Based on: Lee et al., "Self-attention graph pooling" (ICML 2019)
    
    Equation: h_GT = MLP(Σ_{ei∈Gs} A_si · h'_ei)
    where:
    - A_si: Self-attention score for entity e_i
    - h'_ei: Updated entity embedding from Algorithm 1
    - h_GT: Graph token embedding for LLM input
    """
    def __init__(self, 
                 node_dim: int,
                 graph_token_dim: int = 512,
                 hidden_dim: int = 256,
                 dropout: float = 0.1):
        """
        Args:
            node_dim: Dimension of node embeddings (from entity updater)
            graph_token_dim: Output dimension for graph token (LLM embedding space)
            hidden_dim: Hidden dimension for MLP
            dropout: Dropout rate
        """
        super().__init__()
        self.node_dim = node_dim
        self.graph_token_dim = graph_token_dim
        self.hidden_dim = hidden_dim
        
        # Self-attention mechanism: computes A_s ∈ R^{|Gs|×1}
        # Single-layer MLP that outputs attention scores
        self.attention = nn.Sequential(
            nn.Linear(node_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1)
        )
        
        # MLP to project aggregated embeddings to LLM embedding space
        self.graph_token_mlp = nn.Sequential(
            nn.Linear(node_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, graph_token_dim)
        )
    
    def forward(self, 
                node_embeddings: torch.Tensor,
                node_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Compute graph token from node embeddings.
        
        Args:
            node_embeddings: Entity embeddings [num_nodes, node_dim]
                           Can be from updated_embeddings or initial GNN embeddings
            node_mask: Optional mask to indicate which nodes to include [num_nodes]
                      (1 for include, 0 for exclude)
        
        Returns:
            graph_token: Graph-level embedding [graph_token_dim]
        """
        if node_mask is not None:
            # Apply mask: set masked nodes to zero
            node_embeddings = node_embeddings * node_mask.unsqueeze(-1)
            num_valid_nodes = node_mask.sum()
        else:
            num_valid_nodes = node_embeddings.size(0)
        
        # Step 1: Compute self-attention scores A_s ∈ R^{|Gs|×1}
        attention_scores = self.attention(node_embeddings)  # [num_nodes, 1]
        
        # Apply mask to attention scores if provided
        if node_mask is not None:
            attention_scores = attention_scores * node_mask.unsqueeze(-1)
            # Set masked nodes to very negative value for softmax
            attention_scores = attention_scores + (1 - node_mask.unsqueeze(-1)) * (-1e9)
        
        # Normalize attention scores (softmax)
        attention_weights = F.softmax(attention_scores, dim=0)  # [num_nodes, 1]
        
        # Step 2: Aggregate entity embeddings with attention weights
        # Σ_{ei∈Gs} A_si · h'_ei
        aggregated_embedding = (attention_weights * node_embeddings).sum(dim=0)  # [node_dim]
        
        # Step 3: Project to LLM embedding space via MLP
        # h_GT = MLP(aggregated_embedding)
        graph_token = self.graph_token_mlp(aggregated_embedding)  # [graph_token_dim]
        
        return graph_token, attention_weights.squeeze(-1)
    
    def compute_graph_token_from_subgraph(self,
                                         subgraph: HeteroData,
                                         updated_embeddings: Dict[str, torch.Tensor],
                                         node_type: str = 'paper') -> torch.Tensor:
        """
        Convenience method to compute graph token from a HeteroData subgraph.
        
        Args:
            subgraph: Retrieved subgraph from Algorithm 1
            updated_embeddings: Dictionary of updated embeddings from Algorithm 1
            node_type: Primary node type to use for graph token (default: 'paper')
        
        Returns:
            graph_token: Graph-level embedding [graph_token_dim]
            attention_weights: Attention weights for each node [num_nodes]
        """
        if node_type not in updated_embeddings:
            # Fallback: use subgraph node features
            if node_type in subgraph.node_types and hasattr(subgraph[node_type], 'x'):
                node_embeddings = subgraph[node_type].x
            else:
                raise ValueError(f"Node type {node_type} not found in updated_embeddings or subgraph")
        else:
            # Get embeddings for nodes in subgraph
            # Map subgraph node indices to full graph indices
            if hasattr(subgraph[node_type], '_node_mapping'):
                # Reverse mapping: subgraph index -> full graph index
                subgraph_to_full = {v: k for k, v in subgraph[node_type]._node_mapping.items()}
                num_subgraph_nodes = subgraph[node_type].num_nodes
                node_embeddings = []
                for i in range(num_subgraph_nodes):
                    if i in subgraph_to_full:
                        full_idx = subgraph_to_full[i]
                        node_embeddings.append(updated_embeddings[node_type][full_idx])
                    else:
                        # Use subgraph features as fallback
                        if hasattr(subgraph[node_type], 'x'):
                            node_embeddings.append(subgraph[node_type].x[i])
                        else:
                            raise ValueError(f"Cannot find embedding for subgraph node {i}")
                node_embeddings = torch.stack(node_embeddings)
            else:
                # No mapping available, use subgraph features directly
                if hasattr(subgraph[node_type], 'x'):
                    node_embeddings = subgraph[node_type].x
                else:
                    raise ValueError(f"Subgraph node type {node_type} has no features")
        
        return self.forward(node_embeddings)

# Initialize SAG Pooling layer
# node_dim should match the output dimension of EntityEmbeddingUpdater (64)
sag_pooling = SAGPooling(
    node_dim=64,  # From EntityEmbeddingUpdater output
    graph_token_dim=512,  # LLM embedding dimension (will match LLM later)
    hidden_dim=256,
    dropout=0.1
).to(device)

print(f"SAGPooling initialized (node_dim=64, graph_token_dim=512)")


 SAGPooling initialized (node_dim=64, graph_token_dim=512)


## Test: SAG Pooling with Algorithm 1 Output

Test SAG Pooling layer with a retrieved subgraph from Algorithm 1.


In [49]:
# =========================================
# TEST: SAG Pooling with Algorithm 1 Output
# =========================================
print("="*80)
print("TESTING SAG POOLING LAYER (Algorithm 3)")
print("="*80)

# Check dependencies
required = {
    'model': model,
    'data': data,
    'paper2idx': paper2idx,
    'all_df': all_df,
    'query_encoder': query_encoder,
    'relevance_scorer': relevance_scorer,
    'entity_updater': entity_updater,
    'sag_pooling': sag_pooling,
    'device': device
}

missing = [k for k, v in required.items() if v is None or k not in globals()]
if missing:
    print(f"[ERROR] Missing dependencies: {missing}")
    print("Please run Part 1, Part 2, and SAG Pooling initialization cells first.")
else:
    print("All dependencies available")
    
    # Test with a known paper ID
    query_paper_id = "17396995"  # Example paper from our dataset
    
    try:
        # Check if paper exists
        if query_paper_id not in paper2idx:
            print(f"[ERROR] Paper {query_paper_id} not found in paper2idx")
        else:
            query_paper_idx = paper2idx[query_paper_id]
            print(f"[OK] Found query paper: {query_paper_id} (index: {query_paper_idx})")
            
            # Get query text
            query_row = all_df[all_df['publication_ID'] == int(query_paper_id)]
            if len(query_row) == 0:
                print(f"[ERROR] Paper {query_paper_id} not found in all_df")
            else:
                query_text = query_row['text'].values[0]
                print(f"[OK] Query text length: {len(query_text)} characters")
                
                # Define seed entities
                query_seed_entities = {'paper': [query_paper_idx]}
                print(f"[OK] Seed entities: {query_seed_entities}")
                
                # Step 1: Run Enhanced Algorithm 1 to get subgraph
                print("\n" + "-"*80)
                print("Step 1: Running Enhanced Algorithm 1...")
                print("-"*80)
                
                retrieved_subgraph, info_dict = attention_based_graph_retriever_enhanced(
                    query_text=query_text,
                    query_seed_entities=query_seed_entities,
                    gnn_model=model,
                    full_data=data,
                    query_encoder=query_encoder,
                    relevance_scorer=relevance_scorer,
                    entity_updater=entity_updater,
                    max_hops=2,
                    relevance_threshold=0.1,
                    max_nodes_per_hop=100,
                    use_gumbel_softmax=True,
                    gumbel_temperature=1.0,
                    training=False,  # Inference mode
                    device=device
                )
                
                print(f"[OK] Retrieved subgraph with {retrieved_subgraph['paper'].num_nodes} papers")
                
                # Step 2: Compute graph token using SAG Pooling
                print("\n" + "-"*80)
                print("Step 2: Computing Graph Token with SAG Pooling...")
                print("-"*80)
                
                # Get updated embeddings from Algorithm 1
                updated_embeddings = info_dict.get('updated_embeddings', {})
                
                # Compute graph token
                graph_token, attention_weights = sag_pooling.compute_graph_token_from_subgraph(
                    subgraph=retrieved_subgraph,
                    updated_embeddings=updated_embeddings,
                    node_type='paper'
                )
                
                # Verify results
                print("\n" + "="*80)
                print("[OK] SAG POOLING TEST PASSED!")
                print("="*80)
                print(f"Graph token statistics:")
                print(f"  Graph token shape: {graph_token.shape}")
                print(f"  Graph token dimension: {graph_token.shape[0]} (should match LLM embedding dim)")
                print(f"  Attention weights shape: {attention_weights.shape}")
                print(f"  Number of nodes: {attention_weights.shape[0]}")
                print(f"  Attention weights sum: {attention_weights.sum().item():.4f} (should be ~1.0)")
                print(f"  Max attention weight: {attention_weights.max().item():.4f}")
                print(f"  Min attention weight: {attention_weights.min().item():.4f}")
                
                # Show top nodes by attention
                top_k = min(5, len(attention_weights))
                top_indices = torch.argsort(attention_weights, descending=True)[:top_k]
                print(f"\n  Top {top_k} nodes by attention:")
                for i, idx in enumerate(top_indices):
                    print(f"    {i+1}. Node {idx.item()}: attention = {attention_weights[idx].item():.4f}")
                
                print("\n[OK] SAG Pooling is working correctly!")
                
    except Exception as e:
        print(f"\n[ERROR] ERROR during SAG Pooling test:")
        print(f"   {type(e).__name__}: {str(e)}")
        import traceback
        print("\nFull traceback:")
        traceback.print_exc()


TESTING SAG POOLING LAYER (Algorithm 3)
 All dependencies available
 Found query paper: 17396995 (index: 18104)
 Query text length: 645 characters
 Seed entities: {'paper': [18104]}

--------------------------------------------------------------------------------
Step 1: Running Enhanced Algorithm 1...
--------------------------------------------------------------------------------
 Retrieved subgraph with 201 papers

--------------------------------------------------------------------------------
Step 2: Computing Graph Token with SAG Pooling...
--------------------------------------------------------------------------------

 SAG POOLING TEST PASSED!
Graph token statistics:
  Graph token shape: torch.Size([512])
  Graph token dimension: 512 (should match LLM embedding dim)
  Attention weights shape: torch.Size([201])
  Number of nodes: 201
  Attention weights sum: 1.0000 (should be ~1.0)
  Max attention weight: 0.0067
  Min attention weight: 0.0034

  Top 5 nodes by attention:
    1.

## LLM Integration (Llama3-8B + LoRA)

This section implements:
- **Verbalization**: Convert graph triples to natural language text
- **LLM Integration**: Load Llama3-8B with LoRA fine-tuning
- **Input Formatting**: Combine Graph Token + Verbalized Triples + Question
- **Joint Training Support**: Enable end-to-end training with retriever


In [None]:
# =========================================
# INSTALL REQUIRED PACKAGES FOR LLM
# =========================================
# !pip install -U transformers peft accelerate bitsandbytes huggingface_hub

try:
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
    from peft import LoraConfig, get_peft_model, TaskType
    from huggingface_hub import login
    print("Required libraries available")
except ImportError:
    print("[WARNING]  Please install: pip install -U transformers peft accelerate bitsandbytes huggingface_hub")


## Verbalization: Convert Triples to Text


In [None]:
# =========================================
# VERBALIZATION: Convert Graph Triples to Text
# =========================================
from typing import List, Tuple, Dict
import torch

def verbalize_triples(
    subgraph: HeteroData,
    updated_embeddings: Dict[str, torch.Tensor],
    all_df: 'pd.DataFrame',
    paper2idx: Dict[str, int],
    author2idx: Dict[str, int] = None,
    keyword2idx: Dict[str, int] = None,
    venue2idx: Dict[str, int] = None,
    max_triples: int = 50,
    relevance_scores: Dict[Tuple[str, str, str], torch.Tensor] = None
) -> List[str]:
    """
    Convert graph triples to natural language text format.
    
    Format: "<entity_source → relation → entity_target>"
    
    Args:
        subgraph: Retrieved subgraph from Algorithm 1
        updated_embeddings: Updated embeddings from Algorithm 1
        all_df: DataFrame with paper metadata
        paper2idx: Paper ID to index mapping
        author2idx: Author ID to index mapping (optional)
        keyword2idx: Keyword to index mapping (optional)
        venue2idx: Venue to index mapping (optional)
        max_triples: Maximum number of triples to include
        relevance_scores: Optional relevance scores for ranking triples
    
    Returns:
        List of verbalized triple strings
    """
    verbalized_triples = []
    triple_scores = []
    
    # Helper function to get entity name from index
    def get_entity_name(node_type: str, idx: int) -> str:
        if node_type == 'paper':
            # Find paper ID from index
            paper_id = None
            for pid, i in paper2idx.items():
                if i == idx:
                    paper_id = pid
                    break
            if paper_id:
                row = all_df[all_df['publication_ID'] == int(paper_id)]
                if len(row) > 0:
                    title = row['title'].values[0] if 'title' in row.columns else f"Paper {paper_id}"
                    return title[:100]  # Truncate long titles
            return f"Paper_{idx}"
        elif node_type == 'author' and author2idx:
            for aid, i in author2idx.items():
                if i == idx:
                    return str(aid)
            return f"Author_{idx}"
        elif node_type == 'keyword' and keyword2idx:
            for kw, i in keyword2idx.items():
                if i == idx:
                    return kw
            return f"Keyword_{idx}"
        elif node_type == 'venue' and venue2idx:
            for v, i in venue2idx.items():
                if i == idx:
                    return v
            return f"Venue_{idx}"
        elif node_type == 'pubDate':
            return f"Year_{idx}"
        else:
            return f"{node_type}_{idx}"
    
    # Process each edge type
    for etype in subgraph.edge_types:
        if etype not in subgraph.edge_types:
            continue
        
        src_type, rel_type, dst_type = etype
        edge_index = subgraph[etype].edge_index
        
        if edge_index.size(1) == 0:
            continue
        
        # Get relevance scores for this edge type if available
        edge_scores = None
        if relevance_scores and etype in relevance_scores:
            edge_scores = relevance_scores[etype]
        
        # Process each edge
        for i in range(edge_index.size(1)):
            src_idx = edge_index[0, i].item()
            dst_idx = edge_index[1, i].item()
            
            # Map subgraph indices to full graph indices if mapping exists
            if hasattr(subgraph[src_type], '_node_mapping'):
                src_mapping = {v: k for k, v in subgraph[src_type]._node_mapping.items()}
                if src_idx in src_mapping:
                    src_full_idx = src_mapping[src_idx]
                else:
                    continue
            else:
                src_full_idx = src_idx
            
            if hasattr(subgraph[dst_type], '_node_mapping'):
                dst_mapping = {v: k for k, v in subgraph[dst_type]._node_mapping.items()}
                if dst_idx in dst_mapping:
                    dst_full_idx = dst_mapping[dst_idx]
                else:
                    continue
            else:
                dst_full_idx = dst_idx
            
            # Get entity names
            src_name = get_entity_name(src_type, src_full_idx)
            dst_name = get_entity_name(dst_type, dst_full_idx)
            rel_name = rel_type.replace('_', ' ').title()
            
            # Create verbalized triple
            triple_text = f"<{src_name} → {rel_name} → {dst_name}>"
            
            # Get score for ranking
            score = 1.0
            if edge_scores is not None:
                if isinstance(edge_scores, torch.Tensor) and i < len(edge_scores):
                    score = edge_scores[i].item() if isinstance(edge_scores[i], torch.Tensor) else edge_scores[i]
            
            verbalized_triples.append(triple_text)
            triple_scores.append(score)
    
    # Sort by relevance score (if available) and take top max_triples
    if triple_scores:
        sorted_indices = sorted(range(len(triple_scores)), key=lambda i: triple_scores[i], reverse=True)
        verbalized_triples = [verbalized_triples[i] for i in sorted_indices[:max_triples]]
    else:
        verbalized_triples = verbalized_triples[:max_triples]
    
    return verbalized_triples

print("Verbalization function defined")


## LLM Integration: Llama3-8B with LoRA


In [None]:
# =========================================
# LLM INTEGRATION: Llama3-8B with LoRA
# =========================================
class GRIL_LLM(nn.Module):
    """
    LLM Reasoner with Graph Token Integration for GRIL.
    
    Architecture:
    - Base Model: Llama3-8B
    - Fine-tuning: LoRA (rank 8)
    - Input Format: [Graph Token] + Reasoning Paths + Question
    - Output: Answer logits P_{φ,ψ}(a|Gs, q)
    """
    def __init__(self,
                 model_name: str = "meta-llama/Meta-Llama-3-8B",
                 use_4bit: bool = True,
                 lora_rank: int = 8,
                 lora_alpha: int = 16,
                 lora_dropout: float = 0.1,
                 device: torch.device = None):
        """
        Args:
            model_name: HuggingFace model name for Llama3-8B
            use_4bit: Whether to use 4-bit quantization (saves memory)
            lora_rank: LoRA rank (default: 8 as per GRIL paper)
            lora_alpha: LoRA alpha parameter
            lora_dropout: LoRA dropout rate
            device: Device to run model on
        """
        super().__init__()
        self.model_name = model_name
        self.device = device if device is not None else torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.use_4bit = use_4bit
        
        # Load tokenizer
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            print(f"[OK] Tokenizer loaded: {model_name}")
        except Exception as e:
            print(f"[WARNING]  Could not load tokenizer: {e}")
            print("   Make sure you've authenticated with HuggingFace and have access to the model")
            self.tokenizer = None
        
        # Load model with quantization if requested
        try:
            if use_4bit:
                quantization_config = BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_compute_dtype=torch.float16,
                    bnb_4bit_use_double_quant=True,
                    bnb_4bit_quant_type="nf4"
                )
            else:
                quantization_config = None
            
            self.base_model = AutoModelForCausalLM.from_pretrained(
                model_name,
                quantization_config=quantization_config,
                device_map="auto" if use_4bit else None,
                trust_remote_code=True,
                torch_dtype=torch.float16 if use_4bit else torch.float32
            )
            
            if not use_4bit:
                self.base_model = self.base_model.to(self.device)
            
            print(f"[OK] Base model loaded: {model_name}")
            
            # Apply LoRA
            lora_config = LoraConfig(
                task_type=TaskType.CAUSAL_LM,
                r=lora_rank,
                lora_alpha=lora_alpha,
                lora_dropout=lora_dropout,
                target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
                bias="none"
            )
            
            self.model = get_peft_model(self.base_model, lora_config)
            print(f"[OK] LoRA applied (rank={lora_rank}, alpha={lora_alpha})")
            
            # Get embedding dimension
            self.embedding_dim = self.model.get_input_embeddings().embedding_dim
            print(f"[OK] Model embedding dimension: {self.embedding_dim}")
            
        except Exception as e:
            print(f"[WARNING]  Could not load model: {e}")
            print("   Make sure you've authenticated and have access to meta-llama/Meta-Llama-3-8B")
            self.model = None
            self.embedding_dim = 512
    
    def format_input(self,
                    graph_token: torch.Tensor,
                    verbalized_triples: List[str],
                    question: str,
                    max_length: int = 2048) -> Dict[str, torch.Tensor]:
        """
        Format input for LLM: [Graph Token] + Reasoning Paths + Question
        
        Args:
            graph_token: Graph token embedding from SAG [graph_token_dim]
            verbalized_triples: List of verbalized triple strings
            question: Question text
            max_length: Maximum sequence length
        
        Returns:
            Dictionary with input_ids and attention_mask
        """
        if self.tokenizer is None:
            raise ValueError("Tokenizer not loaded. Cannot format input.")
        
        # Create reasoning paths text
        reasoning_paths = "\\n".join(verbalized_triples)
        
        # Format prompt according to GRIL specification
        prompt = f"[Graph Token] Based on the following reasoning paths, please answer the given question.\\n\\nReasoning Paths: {reasoning_paths}\\n\\nQuestion: {question}\\n\\nAnswer:"
        
        # Tokenize
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length
        )
        
        # Move to device
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        
        return inputs, prompt
    
    def forward(self,
                graph_token: torch.Tensor,
                verbalized_triples: List[str],
                question: str,
                labels: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        """
        Forward pass through LLM with graph token integration.
        
        Args:
            graph_token: Graph token embedding from SAG [graph_token_dim]
            verbalized_triples: List of verbalized triple strings
            question: Question text
            labels: Optional labels for training [batch_size, seq_len]
        
        Returns:
            Dictionary with logits and loss (if labels provided)
        """
        if self.model is None:
            raise ValueError("Model not loaded. Cannot perform forward pass.")
        
        # Format input
        inputs, prompt = self.format_input(graph_token, verbalized_triples, question)
        
        # Get input embeddings
        input_embeddings = self.model.get_input_embeddings()(inputs['input_ids'])
        
        # Project graph token to model embedding space if needed
        if graph_token.shape[0] != self.embedding_dim:
            # Use a simple linear projection (could be made learnable)
            if not hasattr(self, 'graph_token_proj'):
                self.graph_token_proj = nn.Linear(graph_token.shape[0], self.embedding_dim).to(self.device)
            graph_token_proj = self.graph_token_proj(graph_token.unsqueeze(0))  # [1, embedding_dim]
        else:
            graph_token_proj = graph_token.unsqueeze(0)  # [1, embedding_dim]
        
        # Find where to insert graph token (at the beginning, before [Graph Token] token)
        # For simplicity, prepend to input embeddings
        # In practice, you'd find the [Graph Token] token position and replace it
        batch_size = input_embeddings.size(0)
        graph_token_expanded = graph_token_proj.expand(batch_size, -1)  # [batch_size, embedding_dim]
        
        # Prepend graph token to input embeddings
        # This is a simplified approach - in practice, you'd replace the [Graph Token] token embedding
        combined_embeddings = torch.cat([graph_token_expanded, input_embeddings], dim=1)  # [batch_size, seq_len+1, embedding_dim]
        
        # Adjust attention mask
        attention_mask = inputs['attention_mask']
        graph_token_mask = torch.ones(batch_size, 1, device=self.device, dtype=attention_mask.dtype)
        combined_attention_mask = torch.cat([graph_token_mask, attention_mask], dim=1)
        
        # Forward pass
        outputs = self.model(
            inputs_embeds=combined_embeddings,
            attention_mask=combined_attention_mask,
            labels=labels,
            return_dict=True
        )
        
        return {
            'logits': outputs.logits,
            'loss': outputs.loss if labels is not None else None,
            'prompt': prompt
        }
    
    def generate(self,
                graph_token: torch.Tensor,
                verbalized_triples: List[str],
                question: str,
                max_new_tokens: int = 100,
                temperature: float = 0.7) -> str:
        """
        Generate answer from graph token and question.
        
        Args:
            graph_token: Graph token embedding from SAG
            verbalized_triples: List of verbalized triple strings
            question: Question text
            max_new_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature
        
        Returns:
            Generated answer text
        """
        if self.model is None:
            raise ValueError("Model not loaded. Cannot generate.")
        
        # Format input
        inputs, _ = self.format_input(graph_token, verbalized_triples, question)
        
        # Get input embeddings
        input_embeddings = self.model.get_input_embeddings()(inputs['input_ids'])
        
        # Project and prepend graph token (same as forward)
        if graph_token.shape[0] != self.embedding_dim:
            if not hasattr(self, 'graph_token_proj'):
                self.graph_token_proj = nn.Linear(graph_token.shape[0], self.embedding_dim).to(self.device)
            graph_token_proj = self.graph_token_proj(graph_token.unsqueeze(0))
        else:
            graph_token_proj = graph_token.unsqueeze(0)
        
        batch_size = input_embeddings.size(0)
        graph_token_expanded = graph_token_proj.expand(batch_size, -1)
        combined_embeddings = torch.cat([graph_token_expanded, input_embeddings], dim=1)
        
        attention_mask = inputs['attention_mask']
        graph_token_mask = torch.ones(batch_size, 1, device=self.device, dtype=attention_mask.dtype)
        combined_attention_mask = torch.cat([graph_token_mask, attention_mask], dim=1)
        
        # Generate
        with torch.no_grad():
            outputs = self.model.generate(
                inputs_embeds=combined_embeddings,
                attention_mask=combined_attention_mask,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                do_sample=True,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id
            )
        
        # Decode
        generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        
        # Extract answer (text after "Answer:")
        if "Answer:" in generated_text:
            answer = generated_text.split("Answer:")[-1].strip()
        else:
            answer = generated_text
        
        return answer

# =========================================
# HUGGINGFACE AUTHENTICATION
# =========================================
from huggingface_hub import login

login(token="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")
print("[OK] HuggingFace authenticated")

print("GRIL_LLM class defined")


In [None]:
# =========================================
# TEST: LLM Integration (Full Pipeline)
# =========================================
print("="*80)
print("TESTING LLM INTEGRATION (Full GRIL Pipeline)")
print("="*80)

# Check dependencies
required = {
    'model': model,
    'data': data,
    'paper2idx': paper2idx,
    'all_df': all_df,
    'query_encoder': query_encoder,
    'relevance_scorer': relevance_scorer,
    'entity_updater': entity_updater,
    'sag_pooling': sag_pooling,
    'author2idx': author2idx,
    'keyword2idx': keyword2idx,
    'venue2idx': venue2idx,
    'device': device
}

missing = [k for k, v in required.items() if v is None or k not in globals()]
if missing:
    print(f"[ERROR] Missing dependencies: {missing}")
else:
    print("All dependencies available")
    
    # Initialize LLM (this will attempt to load the model)
    print("\n" + "-"*80)
    print("Initializing LLM (Llama3-8B + LoRA)...")
    print("-"*80)
    print("[WARNING]  This requires HuggingFace authentication and model access")
    
    try:
        llm = GRIL_LLM(
            model_name="meta-llama/Meta-Llama-3-8B",
            use_4bit=True,  # Use 4-bit quantization to save memory
            lora_rank=8,
            device=device
        )
        
        if llm.model is None:
            print("[WARNING]  Model not loaded. Skipping LLM test.")
            print("   Please authenticate and request access to Llama3-8B")
        else:
            # Test with a query
            query_paper_id = "17396995"
            query_paper_idx = paper2idx[query_paper_id]
            query_row = all_df[all_df['publication_ID'] == int(query_paper_id)]
            query_text = query_row['text'].values[0]
            question = f"What papers are related to: {query_row['title'].values[0]}?"
            
            print(f"\n[OK] Query paper: {query_paper_id}")
            print(f"[OK] Question: {question}")
            
            # Step 1: Run Algorithm 1
            print("\n" + "-"*80)
            print("Step 1: Running Algorithm 1...")
            print("-"*80)
            
            retrieved_subgraph, info_dict = attention_based_graph_retriever_enhanced(
                query_text=query_text,
                query_seed_entities={'paper': [query_paper_idx]},
                gnn_model=model,
                full_data=data,
                query_encoder=query_encoder,
                relevance_scorer=relevance_scorer,
                entity_updater=entity_updater,
                max_hops=2,
                relevance_threshold=0.1,
                max_nodes_per_hop=50,  # Smaller for testing
                use_gumbel_softmax=True,
                training=False,
                device=device
            )
            
            print(f"[OK] Retrieved {retrieved_subgraph['paper'].num_nodes} papers")
            
            # Step 2: Compute Graph Token with SAG
            print("\n" + "-"*80)
            print("Step 2: Computing Graph Token with SAG...")
            print("-"*80)
            
            updated_embeddings = info_dict.get('updated_embeddings', {})
            graph_token, attention_weights = sag_pooling.compute_graph_token_from_subgraph(
                subgraph=retrieved_subgraph,
                updated_embeddings=updated_embeddings,
                node_type='paper'
            )
            
            print(f"[OK] Graph token computed: {graph_token.shape}")
            
            # Step 3: Verbalize Triples
            print("\n" + "-"*80)
            print("Step 3: Verbalizing Triples...")
            print("-"*80)
            
            verbalized_triples = verbalize_triples(
                subgraph=retrieved_subgraph,
                updated_embeddings=updated_embeddings,
                all_df=all_df,
                paper2idx=paper2idx,
                author2idx=author2idx,
                keyword2idx=keyword2idx,
                venue2idx=venue2idx,
                max_triples=20  # Limit for testing
            )
            
            print(f"[OK] Verbalized {len(verbalized_triples)} triples")
            print(f"   Example: {verbalized_triples[0] if verbalized_triples else 'None'}")
            
            # Step 4: LLM Forward Pass
            print("\n" + "-"*80)
            print("Step 4: LLM Forward Pass...")
            print("-"*80)
            
            outputs = llm.forward(
                graph_token=graph_token,
                verbalized_triples=verbalized_triples,
                question=question
            )
            
            print(f"[OK] LLM forward pass completed")
            print(f"   Logits shape: {outputs['logits'].shape}")
            
            # Step 5: Generate Answer (optional, can be slow)
            print("\n" + "-"*80)
            print("Step 5: Generating Answer (this may take a while)...")
            print("-"*80)
            
            answer = llm.generate(
                graph_token=graph_token,
                verbalized_triples=verbalized_triples,
                question=question,
                max_new_tokens=50,
                temperature=0.7
            )
            
            print(f"\n[OK] Generated Answer:")
            print(f"   {answer}")
            
            print("\n" + "="*80)
            print("[OK] FULL GRIL PIPELINE TEST PASSED!")
            print("="*80)
            print("Pipeline: Algorithm 1 → SAG → Verbalization → LLM")
            
    except Exception as e:
        print(f"\n[ERROR] ERROR during LLM integration test:")
        print(f"   {type(e).__name__}: {str(e)}")
        import traceback
        print("\nFull traceback:")
        traceback.print_exc()


## Joint Training Framework (Algorithm 4)

Complete implementation of end-to-end training with:
- LLM accuracy loss
- Retriever feedback loss (with stop-gradient)
- Graph supervision loss (optional)
- Separate optimizers for retriever and LLM


In [None]:
# =========================================
# ENHANCED JOINT TRAINING LOSS (Algorithm 4)
# =========================================
class JointTrainingLoss(nn.Module):
    """
    Enhanced Joint Training Loss with proper stop-gradient mechanism.
    
    Implements Algorithm 4 from GRIL paper:
    L_joint = L_accuracy + L_retriever + L_supervision
    
    Key: Stop-gradient on LLM/SAG when updating retriever
    """
    def __init__(self, alpha: float = 1.0, beta: float = 1.0, gamma: float = 0.1):
        super().__init__()
        self.alpha = alpha  # Weight for LLM accuracy loss
        self.beta = beta    # Weight for retriever feedback loss
        self.gamma = gamma  # Weight for graph supervision loss
    
    def forward(self,
                llm_logits: torch.Tensor,
                ground_truth: torch.Tensor,
                triplet_logits: Optional[torch.Tensor] = None,
                shortest_path_mask: Optional[torch.Tensor] = None,
                retrieved_mask: Optional[torch.Tensor] = None,
                use_stop_gradient: bool = True) -> Tuple[torch.Tensor, Dict[str, float]]:
        """
        Compute joint training loss with stop-gradient mechanism.
        
        Args:
            llm_logits: LLM output logits [batch_size, vocab_size] or [batch_size, seq_len, vocab_size]
            ground_truth: Ground truth labels [batch_size] or [batch_size, seq_len]
            triplet_logits: Triplet probability logits from retriever [num_triplets]
            shortest_path_mask: Binary mask for entities on shortest paths [num_entities]
            retrieved_mask: Binary mask for retrieved entities [num_entities]
            use_stop_gradient: Whether to use stop-gradient for retriever feedback
        
        Returns:
            total_loss: Combined loss tensor
            loss_dict: Dictionary with individual loss components
        """
        # 1. LLM Accuracy Loss: L_accuracy = -log P_{φ,ψ}(a|Gs, q)
        if llm_logits.dim() == 3:  # [batch, seq_len, vocab_size]
            # Shift logits and labels for language modeling
            shift_logits = llm_logits[..., :-1, :].contiguous()
            shift_labels = ground_truth[..., 1:].contiguous()
            accuracy_loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100
            )
        elif llm_logits.dim() == 2:
            if ground_truth.dtype == torch.long:
                accuracy_loss = F.cross_entropy(llm_logits, ground_truth)
            else:
                accuracy_loss = F.binary_cross_entropy_with_logits(llm_logits, ground_truth)
        else:
            accuracy_loss = F.mse_loss(llm_logits, ground_truth.float())
        
        # 2. Retriever Feedback Loss: L_retriever = -log(P_{φ,ψ}(a|Gs, q) · P_θ(Gs|q))
        # Uses LLM logits as implicit feedback with stop-gradient
        retriever_feedback_loss = torch.tensor(0.0, device=llm_logits.device)
        if triplet_logits is not None:
            # Get LLM probability (stop-gradient if requested)
            if use_stop_gradient:
                llm_probs = F.softmax(llm_logits.detach(), dim=-1)
            else:
                llm_probs = F.softmax(llm_logits, dim=-1)
            
            # Use max probability as feedback signal
            max_llm_prob = llm_probs.max(dim=-1)[0].mean()
            
            # Triplet probabilities from retriever
            triplet_probs = F.softmax(triplet_logits, dim=-1)
            
            # Combined probability: P(a|Gs, q) · P(Gs|q)
            # Use negative log-likelihood of combined probability
            combined_prob = max_llm_prob * triplet_probs.mean()
            retriever_feedback_loss = -torch.log(combined_prob + 1e-10)
        
        # 3. Graph Supervision Loss: L_supervision = BCE(Gs covers P(q,a))
        graph_supervision_loss = torch.tensor(0.0, device=llm_logits.device)
        if shortest_path_mask is not None and retrieved_mask is not None:
            # Binary cross-entropy: supervise retrieved entities to cover shortest path entities
            graph_supervision_loss = F.binary_cross_entropy_with_logits(
                retrieved_mask.float(),
                shortest_path_mask.float()
            )
        
        # Total loss
        total_loss = (self.alpha * accuracy_loss + 
                     self.beta * retriever_feedback_loss + 
                     self.gamma * graph_supervision_loss)
        
        loss_dict = {
            'total_loss': total_loss.item(),
            'accuracy_loss': accuracy_loss.item(),
            'retriever_feedback_loss': retriever_feedback_loss.item(),
            'graph_supervision_loss': graph_supervision_loss.item()
        }
        
        return total_loss, loss_dict

print("Enhanced JointTrainingLoss class defined")


## Graph Supervision: Shortest Path Extraction (Algorithm 5)


In [None]:
# =========================================
# GRAPH SUPERVISION: Shortest Path Extraction (Algorithm 5)
# =========================================
from collections import deque
from typing import Set

def extract_shortest_paths(
    query_entities: Set[int],
    answer_entities: Set[int],
    full_data: HeteroData,
    node_type: str = 'paper'
) -> Set[int]:
    """
    Extract entities on shortest paths between query and answer entities.
    
    Implements Algorithm 5: P(q, a) = entities on shortest paths
    
    Args:
        query_entities: Set of query entity indices
        answer_entities: Set of answer entity indices
        full_data: Full HeteroData graph
        node_type: Node type to consider (default: 'paper')
    
    Returns:
        Set of entity indices on shortest paths
    """
    if node_type not in full_data.node_types:
        return set()
    
    # Get edge index for the node type (paper-cites-paper)
    edge_index = None
    for etype in full_data.edge_types:
        if etype == (node_type, 'cites', node_type):
            edge_index = full_data[etype].edge_index
            break
    
    if edge_index is None:
        return set()
    
    # Build adjacency list
    num_nodes = full_data[node_type].num_nodes
    adj_list = [[] for _ in range(num_nodes)]
    for i in range(edge_index.size(1)):
        src = edge_index[0, i].item()
        dst = edge_index[1, i].item()
        adj_list[src].append(dst)
    
    # BFS to find shortest paths from query entities to answer entities
    path_entities = set()
    
    for query_entity in query_entities:
        if query_entity >= num_nodes:
            continue
        
        # BFS from query entity
        queue = deque([(query_entity, [query_entity])])
        visited = {query_entity}
        found_paths = []
        
        while queue:
            current, path = queue.popleft()
            
            # Check if we reached an answer entity
            if current in answer_entities:
                found_paths.append(path)
                # Don't continue from answer entities
                continue
            
            # Explore neighbors
            for neighbor in adj_list[current]:
                if neighbor not in visited:
                    visited.add(neighbor)
                    queue.append((neighbor, path + [neighbor]))
        
        # Collect all entities on shortest paths
        for path in found_paths:
            path_entities.update(path)
    
    return path_entities

def create_entity_mask(
    path_entities: Set[int],
    all_entities: Set[int],
    num_entities: int
) -> torch.Tensor:
    """
    Create binary mask for entities on shortest paths.
    
    Args:
        path_entities: Set of entity indices on shortest paths
        all_entities: Set of all entity indices to consider
        num_entities: Total number of entities
    
    Returns:
        Binary mask tensor [num_entities] (1 for path entities, 0 otherwise)
    """
    mask = torch.zeros(num_entities, dtype=torch.float)
    for entity_idx in path_entities:
        if entity_idx < num_entities:
            mask[entity_idx] = 1.0
    return mask

print("Graph supervision functions defined")


## Joint Training Loop

Complete training loop that:
- Runs forward pass through full pipeline
- Computes joint loss with stop-gradient
- Updates retriever and LLM separately
- Includes checkpointing and evaluation


In [None]:
# =========================================
# JOINT TRAINING LOOP
# =========================================
class GRILTrainer:
    """
    Complete joint training framework for GRIL.
    
    Manages:
    - Forward pass through Algorithm 1 → SAG → LLM
    - Joint loss computation with stop-gradient
    - Separate optimizers for retriever and LLM
    - Checkpointing and evaluation
    """
    def __init__(self,
                 gnn_model: nn.Module,
                 query_encoder: QueryEncoder,
                 relevance_scorer: RelevanceScorer,
                 entity_updater: EntityEmbeddingUpdater,
                 sag_pooling: SAGPooling,
                 llm: GRIL_LLM,
                 full_data: HeteroData,
                 joint_loss: JointTrainingLoss,
                 all_df: 'pd.DataFrame' = None,
                 paper2idx: Dict[str, int] = None,
                 author2idx: Dict[str, int] = None,
                 keyword2idx: Dict[str, int] = None,
                 venue2idx: Dict[str, int] = None,
                 retriever_lr: float = 1e-4,
                 llm_lr: float = 1e-5,
                 device: torch.device = None):
        """
        Args:
            gnn_model: Pre-trained HeteroGNN (frozen)
            query_encoder: Query encoder (frozen or trainable)
            relevance_scorer: Relevance scorer (trainable)
            entity_updater: Entity embedding updater (trainable)
            sag_pooling: SAG pooling layer (trainable)
            llm: LLM with LoRA (trainable)
            full_data: Full graph data
            joint_loss: Joint training loss function
            retriever_lr: Learning rate for retriever components
            llm_lr: Learning rate for LLM
            device: Device to run on
        """
        self.gnn_model = gnn_model
        self.query_encoder = query_encoder
        self.relevance_scorer = relevance_scorer
        self.entity_updater = entity_updater
        self.sag_pooling = sag_pooling
        self.llm = llm
        self.full_data = full_data
        self.joint_loss = joint_loss
        self.all_df = all_df
        self.paper2idx = paper2idx
        self.author2idx = author2idx
        self.keyword2idx = keyword2idx
        self.venue2idx = venue2idx
        self.device = device if device is not None else torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        
        # Freeze GNN (pre-trained)
        for param in self.gnn_model.parameters():
            param.requires_grad = False
        
        # Collect trainable parameters
        retriever_params = []
        retriever_params.extend(list(self.relevance_scorer.parameters()))
        retriever_params.extend(list(self.entity_updater.parameters()))
        retriever_params.extend(list(self.sag_pooling.parameters()))
        
        llm_params = []
        if self.llm.model is not None:
            # Only LoRA parameters are trainable
            llm_params.extend([p for p in self.llm.model.parameters() if p.requires_grad])
            if hasattr(self.llm, 'graph_token_proj'):
                llm_params.extend(list(self.llm.graph_token_proj.parameters()))
        
        # Create optimizers
        self.retriever_optimizer = torch.optim.Adam(retriever_params, lr=retriever_lr)
        self.llm_optimizer = torch.optim.Adam(llm_params, lr=llm_lr) if llm_params else None
        
        print(f"GRILTrainer initialized")
        print(f"   Retriever parameters: {sum(p.numel() for p in retriever_params):,}")
        print(f"   LLM parameters: {sum(p.numel() for p in llm_params):,}")
    
    def forward_pass(self,
                    query_text: str,
                    query_seed_entities: Dict[str, List[int]],
                    question: str,
                    ground_truth: Optional[torch.Tensor] = None,
                    answer_entities: Optional[Set[int]] = None) -> Dict:
        """
        Complete forward pass through GRIL pipeline.
        
        Args:
            query_text: Query text for retrieval
            query_seed_entities: Seed entities for Algorithm 1
            question: Question text for LLM
            ground_truth: Ground truth labels for training
            answer_entities: Answer entity indices (for graph supervision)
        
        Returns:
            Dictionary with outputs and intermediate results
        """
        # Step 1: Algorithm 1 - Retrieve subgraph
        retrieved_subgraph, info_dict = attention_based_graph_retriever_enhanced(
            query_text=query_text,
            query_seed_entities=query_seed_entities,
            gnn_model=self.gnn_model,
            full_data=self.full_data,
            query_encoder=self.query_encoder,
            relevance_scorer=self.relevance_scorer,
            entity_updater=self.entity_updater,
            max_hops=2,
            relevance_threshold=0.1,
            max_nodes_per_hop=100,
            use_gumbel_softmax=True,
            gumbel_temperature=1.0,
            training=True,  # Enable gradients
            device=self.device
        )
        
        # Step 2: SAG Pooling - Compute graph token
        updated_embeddings = info_dict.get('updated_embeddings', {})
        graph_token, attention_weights = self.sag_pooling.compute_graph_token_from_subgraph(
            subgraph=retrieved_subgraph,
            updated_embeddings=updated_embeddings,
            node_type='paper'
        )
        
        # Step 3: Verbalize triples
        verbalized_triples = verbalize_triples(
            subgraph=retrieved_subgraph,
            updated_embeddings=updated_embeddings,
            all_df=self.all_df,
            paper2idx=self.paper2idx,
            author2idx=self.author2idx,
            keyword2idx=self.keyword2idx,
            venue2idx=self.venue2idx,
            max_triples=50
        )
        
        # Step 4: LLM forward pass
        if self.llm.model is None:
            # Mock LLM output for testing
            llm_outputs = {
                'logits': torch.randn(1, 100, 32000, device=self.device),  # Mock logits
                'loss': None
            }
        else:
            llm_outputs = self.llm.forward(
                graph_token=graph_token,
                verbalized_triples=verbalized_triples,
                question=question,
                labels=ground_truth
            )
        
        # Step 5: Extract triplet logits for retriever feedback
        triplet_logits = None
        if 'triplet_logits' in info_dict:
            # Concatenate all triplet logits
            all_logits = []
            for etype, logits_list in info_dict['triplet_logits'].items():
                if isinstance(logits_list, list) and len(logits_list) > 0:
                    if isinstance(logits_list[0], torch.Tensor):
                        all_logits.append(torch.cat(logits_list))
            if all_logits:
                triplet_logits = torch.cat(all_logits)
        
        # Step 6: Graph supervision (if answer entities provided)
        shortest_path_mask = None
        retrieved_mask = None
        if answer_entities is not None:
            query_entity_set = set(query_seed_entities.get('paper', []))
            path_entities = extract_shortest_paths(
                query_entities=query_entity_set,
                answer_entities=answer_entities,
                full_data=self.full_data,
                node_type='paper'
            )
            
            # Create masks
            num_papers = self.full_data['paper'].num_nodes
            shortest_path_mask = create_entity_mask(path_entities, path_entities, num_papers).to(self.device)
            
            # Get retrieved entities
            retrieved_papers = set()
            if 'paper' in retrieved_subgraph.node_types:
                if hasattr(retrieved_subgraph['paper'], '_node_mapping'):
                    retrieved_papers = set(retrieved_subgraph['paper']._node_mapping.keys())
            retrieved_mask = create_entity_mask(retrieved_papers, retrieved_papers, num_papers).to(self.device)
        
        return {
            'llm_logits': llm_outputs['logits'],
            'graph_token': graph_token,
            'verbalized_triples': verbalized_triples,
            'retrieved_subgraph': retrieved_subgraph,
            'triplet_logits': triplet_logits,
            'shortest_path_mask': shortest_path_mask,
            'retrieved_mask': retrieved_mask,
            'info_dict': info_dict
        }
    
    def train_step(self,
                   query_text: str,
                   query_seed_entities: Dict[str, List[int]],
                   question: str,
                   ground_truth: torch.Tensor,
                   answer_entities: Optional[Set[int]] = None) -> Dict[str, float]:
        """
        Single training step.
        
        Returns:
            Dictionary with loss values
        """
        # Forward pass
        outputs = self.forward_pass(
            query_text=query_text,
            query_seed_entities=query_seed_entities,
            question=question,
            ground_truth=ground_truth,
            answer_entities=answer_entities
        )
        
        # Compute joint loss
        total_loss, loss_dict = self.joint_loss(
            llm_logits=outputs['llm_logits'],
            ground_truth=ground_truth,
            triplet_logits=outputs['triplet_logits'],
            shortest_path_mask=outputs['shortest_path_mask'],
            retrieved_mask=outputs['retrieved_mask'],
            use_stop_gradient=True
        )
        
        # Backward pass with separate optimizers
        # 1. Update LLM (full gradient)
        if self.llm_optimizer is not None:
            self.llm_optimizer.zero_grad()
        
        # 2. Update Retriever (with stop-gradient on LLM)
        self.retriever_optimizer.zero_grad()
        
        # Compute retriever loss separately (with stop-gradient)
        if outputs['triplet_logits'] is not None:
            # Retriever feedback loss
            llm_probs = F.softmax(outputs['llm_logits'].detach(), dim=-1)
            max_llm_prob = llm_probs.max(dim=-1)[0].mean()
            triplet_probs = F.softmax(outputs['triplet_logits'], dim=-1)
            combined_prob = max_llm_prob * triplet_probs.mean()
            retriever_loss = -torch.log(combined_prob + 1e-10)
            
            # Graph supervision loss
            if outputs['shortest_path_mask'] is not None:
                graph_sup_loss = F.binary_cross_entropy_with_logits(
                    outputs['retrieved_mask'],
                    outputs['shortest_path_mask']
                )
                retriever_loss = retriever_loss + 0.1 * graph_sup_loss
            
            retriever_loss.backward(retain_graph=True)
            self.retriever_optimizer.step()
        
        # LLM loss (accuracy loss)
        accuracy_loss = loss_dict['accuracy_loss']
        if self.llm_optimizer is not None and outputs['llm_logits'].requires_grad:
            # Re-compute accuracy loss for LLM update
            if outputs['llm_logits'].dim() == 3:
                shift_logits = outputs['llm_logits'][..., :-1, :].contiguous()
                shift_labels = ground_truth[..., 1:].contiguous()
                llm_loss = F.cross_entropy(
                    shift_logits.view(-1, shift_logits.size(-1)),
                    shift_labels.view(-1),
                    ignore_index=-100
                )
            else:
                llm_loss = F.cross_entropy(outputs['llm_logits'], ground_truth)
            
            llm_loss.backward()
            self.llm_optimizer.step()
        
        return loss_dict

print("GRILTrainer class defined")


In [None]:
# =========================================
# CHECKPOINTING AND TRAINING UTILITIES
# =========================================
import os
from pathlib import Path

def save_checkpoint(trainer: GRILTrainer,
                   epoch: int,
                   loss_history: List[float],
                   checkpoint_dir: Path,
                   best_loss: float = float('inf')):
    """Save training checkpoint."""
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    
    checkpoint = {
        'epoch': epoch,
        'loss_history': loss_history,
        'best_loss': best_loss,
        'relevance_scorer_state': trainer.relevance_scorer.state_dict(),
        'entity_updater_state': trainer.entity_updater.state_dict(),
        'sag_pooling_state': trainer.sag_pooling.state_dict(),
        'retriever_optimizer_state': trainer.retriever_optimizer.state_dict(),
    }
    
    if trainer.llm_optimizer is not None:
        checkpoint['llm_optimizer_state'] = trainer.llm_optimizer.state_dict()
    
    if trainer.llm.model is not None:
        checkpoint['llm_state'] = trainer.llm.model.state_dict()
    
    # Save latest
    torch.save(checkpoint, checkpoint_dir / 'latest_checkpoint.pt')
    
    # Save best
    if loss_history and loss_history[-1] < best_loss:
        torch.save(checkpoint, checkpoint_dir / 'best_checkpoint.pt')
        return loss_history[-1]
    
    return best_loss

def load_checkpoint(trainer: GRILTrainer, checkpoint_path: Path):
    """Load training checkpoint."""
    checkpoint = torch.load(checkpoint_path, map_location=trainer.device)
    
    trainer.relevance_scorer.load_state_dict(checkpoint['relevance_scorer_state'])
    trainer.entity_updater.load_state_dict(checkpoint['entity_updater_state'])
    trainer.sag_pooling.load_state_dict(checkpoint['sag_pooling_state'])
    trainer.retriever_optimizer.load_state_dict(checkpoint['retriever_optimizer_state'])
    
    if 'llm_optimizer_state' in checkpoint and trainer.llm_optimizer is not None:
        trainer.llm_optimizer.load_state_dict(checkpoint['llm_optimizer_state'])
    
    if 'llm_state' in checkpoint and trainer.llm.model is not None:
        trainer.llm.model.load_state_dict(checkpoint['llm_state'])
    
    return checkpoint.get('epoch', 0), checkpoint.get('loss_history', []), checkpoint.get('best_loss', float('inf'))

print("Checkpointing utilities defined")


## Complete Training Loop

Full training loop with:
- Epoch iteration
- Loss tracking
- Checkpointing
- Evaluation


In [None]:
# =========================================
# COMPLETE TRAINING LOOP
# =========================================
def train_gril(trainer: GRILTrainer,
               train_data: List[Dict],
               val_data: List[Dict] = None,
               num_epochs: int = 10,
               checkpoint_dir: Path = None,
               eval_interval: int = 1,
               save_interval: int = 5):
    """
    Complete training loop for GRIL.
    
    Args:
        trainer: GRILTrainer instance
        train_data: List of training examples, each with:
                   - 'query_text': str
                   - 'query_seed_entities': Dict[str, List[int]]
                   - 'question': str
                   - 'ground_truth': torch.Tensor (tokenized answer)
                   - 'answer_entities': Optional[Set[int]] (for graph supervision)
        val_data: Validation data (same format)
        num_epochs: Number of training epochs
        checkpoint_dir: Directory to save checkpoints
        eval_interval: Evaluate every N epochs
        save_interval: Save checkpoint every N epochs
    """
    if checkpoint_dir is None:
        checkpoint_dir = Path('checkpoints') / 'gril_training'
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    
    loss_history = []
    best_loss = float('inf')
    start_epoch = 0
    
    # Try to load existing checkpoint
    latest_checkpoint = checkpoint_dir / 'latest_checkpoint.pt'
    if latest_checkpoint.exists():
        print(f"Loading checkpoint from {latest_checkpoint}")
        start_epoch, loss_history, best_loss = load_checkpoint(trainer, latest_checkpoint)
        print(f"Resuming from epoch {start_epoch}, best loss: {best_loss:.4f}")
    
    print(f"\n{'='*80}")
    print(f"Starting GRIL Training")
    print(f"{'='*80}")
    print(f"Training examples: {len(train_data)}")
    print(f"Validation examples: {len(val_data) if val_data else 0}")
    print(f"Epochs: {num_epochs}")
    print(f"Device: {trainer.device}")
    print(f"{'='*80}\n")
    
    for epoch in range(start_epoch, num_epochs):
        print(f"Epoch {epoch+1}/{num_epochs}")
        print("-" * 80)
        
        # Training mode
        trainer.relevance_scorer.train()
        trainer.entity_updater.train()
        trainer.sag_pooling.train()
        if trainer.llm.model is not None:
            trainer.llm.model.train()
        
        epoch_losses = []
        epoch_loss_components = {'accuracy_loss': [], 'retriever_feedback_loss': [], 'graph_supervision_loss': []}
        
        # Training loop
        for i, example in enumerate(train_data):
            try:
                # Get data
                query_text = example['query_text']
                query_seed_entities = example['query_seed_entities']
                question = example['question']
                ground_truth = example['ground_truth'].to(trainer.device)
                answer_entities = example.get('answer_entities', None)
                
                # Training step
                loss_dict = trainer.train_step(
                    query_text=query_text,
                    query_seed_entities=query_seed_entities,
                    question=question,
                    ground_truth=ground_truth,
                    answer_entities=answer_entities
                )
                
                epoch_losses.append(loss_dict['total_loss'])
                epoch_loss_components['accuracy_loss'].append(loss_dict.get('accuracy_loss', 0))
                epoch_loss_components['retriever_feedback_loss'].append(loss_dict.get('retriever_feedback_loss', 0))
                epoch_loss_components['graph_supervision_loss'].append(loss_dict.get('graph_supervision_loss', 0))
                
                # Print progress
                if (i + 1) % 10 == 0:
                    avg_loss = np.mean(epoch_losses[-10:])
                    print(f"  Step {i+1}/{len(train_data)}: Loss = {avg_loss:.4f}")
            
            except Exception as e:
                print(f"  Error in step {i+1}: {e}")
                continue
        
        # Epoch statistics
        avg_epoch_loss = np.mean(epoch_losses) if epoch_losses else 0.0
        loss_history.append(avg_epoch_loss)
        
        print(f"\nEpoch {epoch+1} Summary:")
        print(f"  Average Loss: {avg_epoch_loss:.4f}")
        if epoch_losses:
            print(f"  Accuracy Loss: {np.mean(epoch_loss_components['accuracy_loss']):.4f}")
            print(f"  Retriever Feedback Loss: {np.mean(epoch_loss_components['retriever_feedback_loss']):.4f}")
            print(f"  Graph Supervision Loss: {np.mean(epoch_loss_components['graph_supervision_loss']):.4f}")
        
        # Evaluation
        if val_data and (epoch + 1) % eval_interval == 0:
            print(f"\nEvaluating on validation set...")
            val_loss = evaluate_gril(trainer, val_data)
            print(f"  Validation Loss: {val_loss:.4f}")
        
        # Save checkpoint
        if (epoch + 1) % save_interval == 0:
            best_loss = save_checkpoint(trainer, epoch + 1, loss_history, checkpoint_dir, best_loss)
            print(f"  Checkpoint saved (best loss: {best_loss:.4f})")
        
        print()
    
    # Final checkpoint
    best_loss = save_checkpoint(trainer, num_epochs, loss_history, checkpoint_dir, best_loss)
    print(f"{'='*80}")
    print(f"Training Complete!")
    print(f"Final best loss: {best_loss:.4f}")
    print(f"{'='*80}")

def evaluate_gril(trainer: GRILTrainer, val_data: List[Dict]) -> float:
    """Evaluate GRIL on validation data."""
    trainer.relevance_scorer.eval()
    trainer.entity_updater.eval()
    trainer.sag_pooling.eval()
    if trainer.llm.model is not None:
        trainer.llm.model.eval()
    
    val_losses = []
    
    with torch.no_grad():
        for example in val_data:
            try:
                query_text = example['query_text']
                query_seed_entities = example['query_seed_entities']
                question = example['question']
                ground_truth = example['ground_truth'].to(trainer.device)
                answer_entities = example.get('answer_entities', None)
                
                # Forward pass
                outputs = trainer.forward_pass(
                    query_text=query_text,
                    query_seed_entities=query_seed_entities,
                    question=question,
                    ground_truth=ground_truth,
                    answer_entities=answer_entities
                )
                
                # Compute loss
                _, loss_dict = trainer.joint_loss(
                    llm_logits=outputs['llm_logits'],
                    ground_truth=ground_truth,
                    triplet_logits=outputs['triplet_logits'],
                    shortest_path_mask=outputs['shortest_path_mask'],
                    retrieved_mask=outputs['retrieved_mask'],
                    use_stop_gradient=False  # No gradients needed for eval
                )
                
                val_losses.append(loss_dict['total_loss'])
            
            except Exception as e:
                continue
    
    return np.mean(val_losses) if val_losses else float('inf')

print("Training loop functions defined")


## Summary: Complete GRIL Implementation

**All components are now implemented:**

- **Algorithm 1**: Enhanced with Entity Updates (Eq. 4) and Gumbel-Softmax (Eq. 5)
- **Algorithm 3**: SAG Pooling Layer for graph token generation
- **Verbalization**: Triple-to-text conversion
- **LLM Integration**: Llama3-8B with LoRA fine-tuning
- **Algorithm 4**: Joint Training Framework with stop-gradient
- **Algorithm 5**: Graph Supervision (shortest path extraction)
- **Training Loop**: Complete with checkpointing and evaluation

**Ready for end-to-end training!**


## Training Setup Guide

This section shows how to prepare data and start training the GRIL model.


In [None]:
# =========================================
# PREPARE TRAINING DATA FOR GRIL
# =========================================
from transformers import AutoTokenizer
from typing import List, Dict, Set, Optional

def prepare_gril_training_data(
    raw_train_data: List[Dict],
    all_df: 'pd.DataFrame',
    paper2idx: Dict[str, int],
    tokenizer: AutoTokenizer = None,
    max_answer_length: int = 100,
    include_answer_entities: bool = False
) -> List[Dict]:
    """
    Convert raw training data to GRIL training format.
    
    Args:
        raw_train_data: List of dicts with keys like 'publication_ID', 'title', 'abstract', 'Citations'
        all_df: DataFrame with paper metadata
        paper2idx: Mapping from paper ID to index
        tokenizer: Tokenizer for encoding answers (if None, will create simple token IDs)
        max_answer_length: Maximum length for answer tokens
        include_answer_entities: Whether to extract answer entities for graph supervision
    
    Returns:
        List of training examples in GRIL format:
        {
            'query_text': str,
            'query_seed_entities': Dict[str, List[int]],
            'question': str,
            'ground_truth': torch.Tensor,
            'answer_entities': Optional[Set[int]]
        }
    """
    gril_data = []
    
    # Initialize tokenizer if not provided
    if tokenizer is None:
        try:
            tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
        except:
            print("Warning: Could not load tokenizer. Using simple token IDs.")
            tokenizer = None
    
    for item in raw_train_data:
        try:
            # Get query paper info
            paper_id = str(item.get('publication_ID', ''))
            if paper_id not in paper2idx:
                continue
            
            query_paper_idx = paper2idx[paper_id]
            
            # Get query text (title + abstract)
            query_row = all_df[all_df['publication_ID'] == int(paper_id)]
            if len(query_row) == 0:
                continue
            
            query_text = query_row['text'].values[0] if 'text' in query_row.columns else ''
            if not query_text or len(query_text) < 10:
                continue
            
            # Create question
            title = item.get('title', 'this paper')
            question = f"What papers are related to: {title}?"
            
            # Get answer (citations)
            citations = item.get('Citations', [])
            if not citations:
                # If no citations, use empty answer
                answer_text = "No related papers found."
            else:
                # Create answer from citations
                citation_titles = []
                for cit_id in citations[:5]:  # Limit to 5 citations
                    cit_row = all_df[all_df['publication_ID'] == int(cit_id)]
                    if len(cit_row) > 0 and 'title' in cit_row.columns:
                        citation_titles.append(cit_row['title'].values[0])
                
                if citation_titles:
                    answer_text = "Related papers: " + "; ".join(citation_titles)
                else:
                    answer_text = "No related papers found."
            
            # Tokenize answer
            if tokenizer is not None:
                answer_tokens = tokenizer(
                    answer_text,
                    return_tensors='pt',
                    padding='max_length',
                    truncation=True,
                    max_length=max_answer_length
                )['input_ids'].squeeze(0)
            else:
                # Simple token IDs (just use character codes)
                answer_tokens = torch.tensor([ord(c) % 1000 for c in answer_text[:max_answer_length]], dtype=torch.long)
                if len(answer_tokens) < max_answer_length:
                    answer_tokens = torch.cat([answer_tokens, torch.zeros(max_answer_length - len(answer_tokens), dtype=torch.long)])
            
            # Extract answer entities for graph supervision (optional)
            answer_entities = None
            if include_answer_entities and citations:
                answer_entities = set()
                for cit_id in citations:
                    cit_str = str(cit_id)
                    if cit_str in paper2idx:
                        answer_entities.add(paper2idx[cit_str])
            
            # Create GRIL training example
            gril_example = {
                'query_text': query_text,
                'query_seed_entities': {'paper': [query_paper_idx]},
                'question': question,
                'ground_truth': answer_tokens,
                'answer_entities': answer_entities
            }
            
            gril_data.append(gril_example)
        
        except Exception as e:
            print(f"Error processing item: {e}")
            continue
    
    print(f"Prepared {len(gril_data)} training examples from {len(raw_train_data)} raw examples")
    return gril_data

print("Training data preparation function defined")


## Initialize GRIL Trainer

Set up all components and create the trainer.


In [None]:
# =========================================
# INITIALIZE GRIL TRAINER
# =========================================

# Check that all required components are available
required_components = {
    'model': model,
    'data': data,
    'query_encoder': query_encoder,
    'relevance_scorer': relevance_scorer,
    'entity_updater': entity_updater,
    'sag_pooling': sag_pooling,
    'all_df': all_df,
    'paper2idx': paper2idx,
    'author2idx': author2idx,
    'keyword2idx': keyword2idx,
    'venue2idx': venue2idx,
    'device': device
}

missing = [k for k, v in required_components.items() if v is None or k not in globals()]
if missing:
    print(f"[ERROR] Missing components: {missing}")
    print("Please run all previous cells to initialize components.")
else:
    print("All required components available")
    
    # Initialize LLM (optional - can train without it using mock outputs)
    print("\nInitializing LLM...")
    try:
        llm = GRIL_LLM(
            model_name="meta-llama/Meta-Llama-3-8B",
            use_4bit=True,
            lora_rank=8,
            device=device
        )
        if llm.model is None:
            print("LLM not loaded. Training will use mock LLM outputs.")
        else:
            print("LLM loaded successfully")
    except Exception as e:
        print(f"Could not load LLM: {e}")
        print("Training will use mock LLM outputs.")
        llm = GRIL_LLM(model_name=None, device=device)  # Mock LLM
    
    # Initialize joint loss
    joint_loss = JointTrainingLoss(alpha=1.0, beta=1.0, gamma=0.1)
    
    # Create trainer
    trainer = GRILTrainer(
        gnn_model=model,
        query_encoder=query_encoder,
        relevance_scorer=relevance_scorer,
        entity_updater=entity_updater,
        sag_pooling=sag_pooling,
        llm=llm,
        full_data=data,
        joint_loss=joint_loss,
        all_df=all_df,
        paper2idx=paper2idx,
        author2idx=author2idx,
        keyword2idx=keyword2idx,
        venue2idx=venue2idx,
        retriever_lr=1e-4,
        llm_lr=1e-5,
        device=device
    )
    
    print("\nGRIL Trainer initialized and ready for training!")


## Prepare Training Data

Convert raw training data to GRIL format.


In [None]:
# =========================================
# PREPARE TRAINING AND VALIDATION DATA
# =========================================

# Use a subset for faster training (adjust as needed)
TRAIN_SUBSET_SIZE = 100  # Use first 100 examples for quick testing
VAL_SUBSET_SIZE = 20     # Use first 20 examples for validation

print("Preparing training data...")
gril_train_data = prepare_gril_training_data(
    raw_train_data=train_data[:TRAIN_SUBSET_SIZE],
    all_df=all_df,
    paper2idx=paper2idx,
    tokenizer=None,  # Will use simple token IDs
    max_answer_length=50,
    include_answer_entities=False  # Set to True if you have citation data
)

print("\nPreparing validation data...")
gril_val_data = prepare_gril_training_data(
    raw_train_data=val_data[:VAL_SUBSET_SIZE],
    all_df=all_df,
    paper2idx=paper2idx,
    tokenizer=None,
    max_answer_length=50,
    include_answer_entities=False
)

print(f"\nTraining data: {len(gril_train_data)} examples")
print(f"Validation data: {len(gril_val_data)} examples")
print("\nSample training example:")
if gril_train_data:
    sample = gril_train_data[0]
    print(f"  Query text length: {len(sample['query_text'])} chars")
    print(f"  Seed entities: {sample['query_seed_entities']}")
    print(f"  Question: {sample['question']}")
    print(f"  Ground truth shape: {sample['ground_truth'].shape}")


## Start Training

Run the training loop. Adjust parameters as needed.


In [None]:
# =========================================
# START GRIL TRAINING
# =========================================

# Training parameters
NUM_EPOCHS = 5           # Number of training epochs
EVAL_INTERVAL = 1        # Evaluate every N epochs
SAVE_INTERVAL = 2        # Save checkpoint every N epochs
CHECKPOINT_DIR = Path('checkpoints') / 'gril_training'

print("Starting GRIL training...")
print(f"Training examples: {len(gril_train_data)}")
print(f"Validation examples: {len(gril_val_data)}")
print(f"Epochs: {NUM_EPOCHS}")
print(f"Device: {device}")
print("\nNote: This may take a while. Each example runs the full pipeline:")
print("  Algorithm 1 → SAG Pooling → Verbalization → LLM")

# Start training
train_gril(
    trainer=trainer,
    train_data=gril_train_data,
    val_data=gril_val_data,
    num_epochs=NUM_EPOCHS,
    checkpoint_dir=CHECKPOINT_DIR,
    eval_interval=EVAL_INTERVAL,
    save_interval=SAVE_INTERVAL
)

print("\nTraining complete! Checkpoints saved in:", CHECKPOINT_DIR)


## Training Instructions Summary

**To train the GRIL model, follow these steps:**

1. **Run all previous cells** to initialize:
   - GNN model (pre-trained)
   - Graph data and mappings
   - All GRIL components (Algorithm 1, SAG, LLM, etc.)

2. **Run the "Initialize GRIL Trainer" cell** above to create the trainer

3. **Run the "Prepare Training Data" cell** to convert your data to GRIL format

4. **Run the "Start Training" cell** to begin training

**Training Parameters (adjust as needed):**
- `NUM_EPOCHS`: Number of training epochs (default: 5)
- `TRAIN_SUBSET_SIZE`: Number of training examples (default: 100 for testing)
- `VAL_SUBSET_SIZE`: Number of validation examples (default: 20)
- `retriever_lr`: Learning rate for retriever (default: 1e-4)
- `llm_lr`: Learning rate for LLM (default: 1e-5)

**Checkpoints:**
- Saved in `checkpoints/gril_training/`
- `latest_checkpoint.pt`: Most recent checkpoint
- `best_checkpoint.pt`: Best model (lowest loss)

**Note:** Training can be slow because each example runs the full pipeline. Start with a small subset to test, then scale up.
