In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.nn import SAGEConv
from torch_geometric.data import DataLoader
import os

In [18]:
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# === GNN Ranker Model ===
class GNNRanker(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(GNNRanker, self).__init__()
        
        #Mean Aggregation
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)

        # Max Aggrgation
        # self.conv1 = SAGEConv(input_dim, hidden_dim, aggr='max')
        # self.conv2 = SAGEConv(hidden_dim, hidden_dim, aggr='max')

        # Min Aggrgation
        # self.conv1 = SAGEConv(input_dim, hidden_dim, aggr='min')
        # self.conv2 = SAGEConv(hidden_dim, hidden_dim, aggr='min')

        # Sum Aggrgation
        # self.conv1 = SAGEConv(input_dim, hidden_dim, aggr='sum')
        # self.conv2 = SAGEConv(hidden_dim, hidden_dim, aggr='sum')

        self.scorer = nn.Linear(hidden_dim, 1)
    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        x = self.conv2(x, edge_index)
        return self.scorer(x).squeeze(-1)  # shape: [num_nodes]


In [19]:
# === Loss Function ===
def ranking_loss(scores, positive_ids):
    loss = 0.0
    pos_scores = scores[positive_ids]
    all_indices = torch.arange(len(scores), device=scores.device)
    neg_indices = list(set(all_indices.tolist()) - set(positive_ids.tolist()))
    
    if len(neg_indices) == 0 or len(pos_scores) == 0:
        return torch.tensor(0.0, device=scores.device, requires_grad=True)

    neg_scores = scores[neg_indices]
    for pos in pos_scores:
        margin = 1.0
        loss += torch.sum(F.relu(margin - pos + neg_scores))
    return loss / len(pos_scores)

In [20]:
# === MRR & F1 Evaluation ===
def compute_mrr(scores, positive_ids):
    sorted_indices = torch.argsort(scores, descending=True)
    ranks = [(sorted_indices == pid).nonzero(as_tuple=True)[0].item() + 1 for pid in positive_ids if pid < len(scores)]
    reciprocals = [1.0 / rank for rank in ranks]
    return sum(reciprocals) / len(ranks) if ranks else 0.0

def compute_f1_at_k(scores, positive_ids, k=5):
    top_k = torch.argsort(scores, descending=True)[:k]
    retrieved_set = set(top_k.tolist())
    gold_set = set(positive_ids.tolist())
    tp = len(retrieved_set & gold_set)
    precision = tp / k
    recall = tp / len(gold_set) if gold_set else 0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
    return f1

In [21]:
from tqdm import tqdm

In [22]:
# === Load Data ===
train_graphs = torch.load('data/train.pt')
train_list = [v['graph'] for v in train_graphs.values()]
train_loader = DataLoader(train_list, batch_size=1, shuffle=True)

# === Initialize Model and Optimizer ===
model = GNNRanker(input_dim=768, hidden_dim=256).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)



# Mean Aggregation

In [None]:
# === Training Loop ===
for epoch in range(10):
    model.train()
    total_loss = 0
    total_mrr = 0
    total_f1 = 0
    batches = 0

    for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
        batch = batch.to(DEVICE)

        if batch.edges_index is None:
            continue
        if batch.edges_index.dtype != torch.long:
            batch.edges_index = batch.edges_index.long()

        scores = model(batch.x, batch.edges_index)
        positive_ids = batch.positive_ids.to(DEVICE)

        loss = ranking_loss(scores, positive_ids)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_mrr += compute_mrr(scores, positive_ids)
        total_f1 += compute_f1_at_k(scores, positive_ids)
        batches += 1

    avg_loss = total_loss / batches if batches else 0
    avg_mrr = total_mrr / batches if batches else 0
    avg_f1 = total_f1 / batches if batches else 0

    print(f"Epoch {epoch+1} | Loss: {avg_loss:.4f} | MRR: {avg_mrr:.4f} | F1@5: {avg_f1:.4f}")

Epoch 1: 100%|███████████████████████████████████████████████████████████████████| 167454/167454 [29:10<00:00, 95.67it/s]


Epoch 1 | Loss: 0.3840 | MRR: 0.6766 | F1@5: 0.6313


Epoch 2: 100%|██████████████████████████████████████████████████████████████████| 167454/167454 [26:38<00:00, 104.79it/s]


Epoch 2 | Loss: 0.2549 | MRR: 0.6856 | F1@5: 0.6333


Epoch 3: 100%|███████████████████████████████████████████████████████████████████| 167454/167454 [29:11<00:00, 95.59it/s]


Epoch 3 | Loss: 0.2282 | MRR: 0.6874 | F1@5: 0.6339


Epoch 4: 100%|██████████████████████████████████████████████████████████████████| 167454/167454 [26:10<00:00, 106.60it/s]


Epoch 4 | Loss: 0.2139 | MRR: 0.6885 | F1@5: 0.6341


Epoch 5: 100%|███████████████████████████████████████████████████████████████████| 167454/167454 [28:09<00:00, 99.11it/s]


Epoch 5 | Loss: 0.2024 | MRR: 0.6892 | F1@5: 0.6343


Epoch 8: 100%|███████████████████████████████████████████████████████████████████| 167454/167454 [47:39<00:00, 58.56it/s]


Epoch 8 | Loss: 0.1832 | MRR: 0.6907 | F1@5: 0.6347


Epoch 9: 100%|█████████████████████████████████████████████████████████████████| 167454/167454 [1:48:16<00:00, 25.78it/s]


Epoch 9 | Loss: 0.1801 | MRR: 0.6911 | F1@5: 0.6348


Epoch 10:   5%|███▎                                                                | 8196/167454 [01:26<27:40, 95.89it/s]

In [25]:
dev_graphs = torch.load('data/dev.pt')
dev_list = [v['graph'] for v in dev_graphs.values()]
dev_loader = DataLoader(dev_list, batch_size=1, shuffle=False)


In [15]:
model.eval()
total_mrr = 0.0
total_f1 = 0.0
batches = 0

with torch.no_grad():
    for batch in tqdm(dev_loader):
        batch = batch.to(DEVICE)

        if batch.edges_index.dtype != torch.long:
            batch.edges_index = batch.edges_index.long()

        scores = model(batch.x, batch.edges_index)
        positive_ids = batch.positive_ids.to(DEVICE)

        mrr = compute_mrr(scores, positive_ids)
        f1 = compute_f1_at_k(scores, positive_ids, k=5)

        total_mrr += mrr
        total_f1 += f1
        batches += 1

print(f"[Dev Eval] MRR: {total_mrr / batches:.4f} | F1@5: {total_f1 / batches:.4f}")


100%|█████████████████████████████████████████████████████████████████████████████| 12576/12576 [01:04<00:00, 196.28it/s]

[Dev Eval] MRR: 0.6462 | F1@5: 0.6258





#Generation

In [45]:
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration
from rouge_score import rouge_scorer
from nltk import word_tokenize
import numpy as np
from tqdm import tqdm

# Load LLM
tokenizer = T5Tokenizer.from_pretrained('t5-small')
t5_model = T5ForConditionalGeneration.from_pretrained('t5-small').to(DEVICE)

scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)

def compute_f1(prediction, ground_truth):
    pred_tokens = word_tokenize(prediction.lower())
    gold_tokens = word_tokenize(ground_truth.lower())
    common = set(pred_tokens) & set(gold_tokens)
    if len(common) == 0:
        return 0.0
    precision = len(common) / len(pred_tokens)
    recall = len(common) / len(gold_tokens)
    f1 = 2 * precision * recall / (precision + recall)
    return f1

# === Evaluation Loop ===
total_f1, total_rouge = 0.0, 0.0
count = 0
model.eval()

for i, (qid, sample) in enumerate(tqdm(dev_graphs.items())):
    graph = sample['graph'].to(DEVICE)
    query = sample['query']
    passage_titles = sample['passage_titles']
    positive_ids = sample['positive_ids']

    # Gold answer (concatenated titles of gold passages)
    gold_answer = " ".join([passage_titles[i] for i in positive_ids])

    # GNN Scoring + Ranking
    with torch.no_grad():
        scores = model(graph.x, graph.edges_index)
    topk = torch.argsort(scores, descending=True)[:5]
    top_passages = [passage_titles[i] for i in topk.cpu()]
    context = " ".join(top_passages)

    # LLM Input
    prompt = f"question: {query} context: {context}"
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(DEVICE)

    # Generate Answer
    output = t5_model.generate(input_ids, max_length=50)
    prediction = tokenizer.decode(output[0], skip_special_tokens=True)

    # Score
    f1 = compute_f1(prediction, gold_answer)
    rouge = scorer.score(prediction, gold_answer)['rougeL'].fmeasure

    total_f1 += f1
    total_rouge += rouge
    count += 1

# Final Results
print(f"\nAverage F1: {total_f1 / count:.4f}")
print(f"Average ROUGE-L: {total_rouge / count:.4f}")


100%|██████████████████████████████████████████████████████████████████████████████| 12576/12576 [21:31<00:00,  9.74it/s]


Average F1: 0.2339
Average ROUGE-L: 0.2350





In [43]:
import nltk
nltk.download('punkt_tab')

[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt_tab.zip.


True

#Max Aggregation

In [12]:
# === Training Loop ===
for epoch in range(5):
    model.train()
    total_loss = 0
    total_mrr = 0
    total_f1 = 0
    batches = 0

    for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
        batch = batch.to(DEVICE)

        if batch.edges_index is None:
            continue
        if batch.edges_index.dtype != torch.long:
            batch.edges_index = batch.edges_index.long()

        scores = model(batch.x, batch.edges_index)
        positive_ids = batch.positive_ids.to(DEVICE)

        loss = ranking_loss(scores, positive_ids)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_mrr += compute_mrr(scores, positive_ids)
        total_f1 += compute_f1_at_k(scores, positive_ids)
        batches += 1

    avg_loss = total_loss / batches if batches else 0
    avg_mrr = total_mrr / batches if batches else 0
    avg_f1 = total_f1 / batches if batches else 0

    print(f"Epoch {epoch+1} | Loss: {avg_loss:.4f} | MRR: {avg_mrr:.4f} | F1@5: {avg_f1:.4f}")

Epoch 1: 100%|██████████████████████████████████████████████████████████████████| 167454/167454 [19:11<00:00, 145.38it/s]


Epoch 1 | Loss: 0.2561 | MRR: 0.6873 | F1@5: 0.6334


IOPub message rate exceeded.██████████████████▏                                  | 80467/167454 [10:38<13:14, 109.45it/s]
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

Epoch 3: 100%|██████████████████████████████████████████████████████████████████| 167454/167454 [21:09<00:00, 131.88it/s]


Epoch 3 | Loss: 0.1600 | MRR: 0.6953 | F1@5: 0.6355


Epoch 4: 100%|██████████████████████████████████████████████████████████████████| 167454/167454 [23:41<00:00, 117.82it/s]


Epoch 4 | Loss: 0.1502 | MRR: 0.6966 | F1@5: 0.6359


IOPub message rate exceeded.█████████████████████████████▋                      | 110864/167454 [16:56<08:20, 113.08it/s]
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



In [14]:
model.eval()
total_mrr = 0.0
total_f1 = 0.0
batches = 0

with torch.no_grad():
    for batch in tqdm(dev_loader):
        batch = batch.to(DEVICE)

        if batch.edges_index.dtype != torch.long:
            batch.edges_index = batch.edges_index.long()

        scores = model(batch.x, batch.edges_index)
        positive_ids = batch.positive_ids.to(DEVICE)

        mrr = compute_mrr(scores, positive_ids)
        f1 = compute_f1_at_k(scores, positive_ids, k=5)

        total_mrr += mrr
        total_f1 += f1
        batches += 1

print(f"[Dev Eval] MRR: {total_mrr / batches:.4f} | F1@5: {total_f1 / batches:.4f}")


100%|█████████████████████████████████████████████████████████████████████████████| 12576/12576 [00:55<00:00, 226.93it/s]

[Dev Eval] MRR: 0.6658 | F1@5: 0.6319





In [15]:
import nltk
nltk.download('punkt_tab')

[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


True

In [16]:
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration
from rouge_score import rouge_scorer
from nltk import word_tokenize
import numpy as np
from tqdm import tqdm

# Load LLM
tokenizer = T5Tokenizer.from_pretrained('t5-small')
t5_model = T5ForConditionalGeneration.from_pretrained('t5-small').to(DEVICE)

scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)

def compute_f1(prediction, ground_truth):
    pred_tokens = word_tokenize(prediction.lower())
    gold_tokens = word_tokenize(ground_truth.lower())
    common = set(pred_tokens) & set(gold_tokens)
    if len(common) == 0:
        return 0.0
    precision = len(common) / len(pred_tokens)
    recall = len(common) / len(gold_tokens)
    f1 = 2 * precision * recall / (precision + recall)
    return f1

# === Evaluation Loop ===
total_f1, total_rouge = 0.0, 0.0
count = 0
model.eval()

for i, (qid, sample) in enumerate(tqdm(dev_graphs.items())):
    graph = sample['graph'].to(DEVICE)
    query = sample['query']
    passage_titles = sample['passage_titles']
    positive_ids = sample['positive_ids']

    # Gold answer (concatenated titles of gold passages)
    gold_answer = " ".join([passage_titles[i] for i in positive_ids])

    # GNN Scoring + Ranking
    with torch.no_grad():
        scores = model(graph.x, graph.edges_index)
    topk = torch.argsort(scores, descending=True)[:5]
    top_passages = [passage_titles[i] for i in topk.cpu()]
    context = " ".join(top_passages)

    # LLM Input
    prompt = f"question: {query} context: {context}"
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(DEVICE)

    # Generate Answer
    output = t5_model.generate(input_ids, max_length=50)
    prediction = tokenizer.decode(output[0], skip_special_tokens=True)

    # Score
    f1 = compute_f1(prediction, gold_answer)
    rouge = scorer.score(prediction, gold_answer)['rougeL'].fmeasure

    total_f1 += f1
    total_rouge += rouge
    count += 1

# Final Results
print(f"\nAverage F1: {total_f1 / count:.4f}")
print(f"Average ROUGE-L: {total_rouge / count:.4f}")


You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
100%|██████████████████████████████████████████████████████████████████████████████| 12576/12576 [15:56<00:00, 13.15it/s]


Average F1: 0.2205
Average ROUGE-L: 0.2218





#Sum Aggregate

In [24]:
# === Training Loop ===
for epoch in range(5):
    model.train()
    total_loss = 0
    total_mrr = 0
    total_f1 = 0
    batches = 0

    for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
        batch = batch.to(DEVICE)

        if batch.edges_index is None:
            continue
        if batch.edges_index.dtype != torch.long:
            batch.edges_index = batch.edges_index.long()

        scores = model(batch.x, batch.edges_index)
        positive_ids = batch.positive_ids.to(DEVICE)

        loss = ranking_loss(scores, positive_ids)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_mrr += compute_mrr(scores, positive_ids)
        total_f1 += compute_f1_at_k(scores, positive_ids)
        batches += 1

    avg_loss = total_loss / batches if batches else 0
    avg_mrr = total_mrr / batches if batches else 0
    avg_f1 = total_f1 / batches if batches else 0

    print(f"Epoch {epoch+1} | Loss: {avg_loss:.4f} | MRR: {avg_mrr:.4f} | F1@5: {avg_f1:.4f}")

Epoch 1: 100%|██████████████████████████████████████████████████████████████████| 167454/167454 [25:41<00:00, 108.65it/s]


Epoch 1 | Loss: 0.4131 | MRR: 0.6840 | F1@5: 0.6319


IOPub message rate exceeded.                                                     | 30590/167454 [04:43<17:20, 131.48it/s]
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

Epoch 2: 100%|██████████████████████████████████████████████████████████████████| 167454/167454 [25:35<00:00, 109.05it/s]


Epoch 2 | Loss: 0.3813 | MRR: 0.6906 | F1@5: 0.6339


IOPub message rate exceeded.█████████████████████████████                        | 107554/167454 [16:17<10:16, 97.20it/s]
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

Epoch 3: 100%|██████████████████████████████████████████████████████████████████| 167454/167454 [26:27<00:00, 105.45it/s]


Epoch 3 | Loss: 0.4082 | MRR: 0.6927 | F1@5: 0.6345


IOPub message rate exceeded.█▉                                                    | 39121/167454 [09:15<39:20, 54.38it/s]
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

Epoch 4: 100%|█████████████████████████████████████████████████████████████████| 167454/167454 [1:05:38<00:00, 42.51it/s]


Epoch 4 | Loss: 0.4084 | MRR: 0.6941 | F1@5: 0.6348


IOPub message rate exceeded.███████████████████████▊                              | 93140/167454 [32:30<21:37, 57.27it/s]
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

Epoch 5: 100%|███████████████████████████████████████████████████████████████████| 167454/167454 [58:32<00:00, 47.68it/s]

Epoch 5 | Loss: 0.4646 | MRR: 0.6951 | F1@5: 0.6351





In [26]:
model.eval()
total_mrr = 0.0
total_f1 = 0.0
batches = 0

with torch.no_grad():
    for batch in tqdm(dev_loader):
        batch = batch.to(DEVICE)

        if batch.edges_index.dtype != torch.long:
            batch.edges_index = batch.edges_index.long()

        scores = model(batch.x, batch.edges_index)
        positive_ids = batch.positive_ids.to(DEVICE)

        mrr = compute_mrr(scores, positive_ids)
        f1 = compute_f1_at_k(scores, positive_ids, k=5)

        total_mrr += mrr
        total_f1 += f1
        batches += 1

print(f"[Dev Eval] MRR: {total_mrr / batches:.4f} | F1@5: {total_f1 / batches:.4f}")


100%|█████████████████████████████████████████████████████████████████████████████| 12576/12576 [01:53<00:00, 110.84it/s]

[Dev Eval] MRR: 0.6569 | F1@5: 0.6269





In [27]:
import nltk
nltk.download('punkt_tab')

[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


True

In [28]:
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration
from rouge_score import rouge_scorer
from nltk import word_tokenize
import numpy as np
from tqdm import tqdm

# Load LLM
tokenizer = T5Tokenizer.from_pretrained('t5-small')
t5_model = T5ForConditionalGeneration.from_pretrained('t5-small').to(DEVICE)

scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)

def compute_f1(prediction, ground_truth):
    pred_tokens = word_tokenize(prediction.lower())
    gold_tokens = word_tokenize(ground_truth.lower())
    common = set(pred_tokens) & set(gold_tokens)
    if len(common) == 0:
        return 0.0
    precision = len(common) / len(pred_tokens)
    recall = len(common) / len(gold_tokens)
    f1 = 2 * precision * recall / (precision + recall)
    return f1

# === Evaluation Loop ===
total_f1, total_rouge = 0.0, 0.0
count = 0
model.eval()

for i, (qid, sample) in enumerate(tqdm(dev_graphs.items())):
    graph = sample['graph'].to(DEVICE)
    query = sample['query']
    passage_titles = sample['passage_titles']
    positive_ids = sample['positive_ids']

    # Gold answer (concatenated titles of gold passages)
    gold_answer = " ".join([passage_titles[i] for i in positive_ids])

    # GNN Scoring + Ranking
    with torch.no_grad():
        scores = model(graph.x, graph.edges_index)
    topk = torch.argsort(scores, descending=True)[:5]
    top_passages = [passage_titles[i] for i in topk.cpu()]
    context = " ".join(top_passages)

    # LLM Input
    prompt = f"question: {query} context: {context}"
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(DEVICE)

    # Generate Answer
    output = t5_model.generate(input_ids, max_length=50)
    prediction = tokenizer.decode(output[0], skip_special_tokens=True)

    # Score
    f1 = compute_f1(prediction, gold_answer)
    rouge = scorer.score(prediction, gold_answer)['rougeL'].fmeasure

    total_f1 += f1
    total_rouge += rouge
    count += 1

# Final Results
print(f"\nAverage F1: {total_f1 / count:.4f}")
print(f"Average ROUGE-L: {total_rouge / count:.4f}")


100%|██████████████████████████████████████████████████████████████████████████████| 12576/12576 [26:31<00:00,  7.90it/s]


Average F1: 0.2259
Average ROUGE-L: 0.2264



