In [1]:
!nvidia-smi

Sat Apr 12 21:17:18 2025       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.54.03              Driver Version: 535.54.03    CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA A100 80GB PCIe          Off | 00000000:17:00.0 Off |                    0 |
| N/A   58C    P0             190W / 300W |  23874MiB / 81920MiB |     29%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA A100 80GB PCIe          Off | 00000000:31:00.0 Off |  

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.data import DataLoader
import os

In [3]:
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

In [4]:
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# === GNN Ranker Model ===
class GNNRanker(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(GNNRanker, self).__init__()
        
        #Mean Aggregation
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.conv3 = GCNConv(hidden_dim, hidden_dim)
        self.conv4 = GCNConv(hidden_dim, hidden_dim)# third layer
        self.scorer = nn.Linear(hidden_dim, 1)

        #Max Aggrgation
        # self.conv1 = SAGEConv(input_dim, hidden_dim, aggr='max')
        # self.conv2 = SAGEConv(hidden_dim, hidden_dim, aggr='max')

        #Min Aggrgation
        # self.conv1 = SAGEConv(input_dim, hidden_dim, aggr='min')
        # self.conv2 = SAGEConv(hidden_dim, hidden_dim, aggr='min')

        #Sum Aggrgation
        # self.conv1 = SAGEConv(input_dim, hidden_dim, aggr='sum')
        # self.conv2 = SAGEConv(hidden_dim, hidden_dim, aggr='sum')

    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = F.relu(self.conv3(x, edge_index))
        x = self.conv4(x, edge_index)
        return self.scorer(x).squeeze(-1)  # shape: [num_nodes]


In [5]:
# === Loss Function ===
def ranking_loss(scores, positive_ids):
    loss = 0.0
    pos_scores = scores[positive_ids]
    all_indices = torch.arange(len(scores), device=scores.device)
    neg_indices = list(set(all_indices.tolist()) - set(positive_ids.tolist()))
    
    if len(neg_indices) == 0 or len(pos_scores) == 0:
        return torch.tensor(0.0, device=scores.device, requires_grad=True)

    neg_scores = scores[neg_indices]
    for pos in pos_scores:
        margin = 1.0
        loss += torch.sum(F.relu(margin - pos + neg_scores))
    return loss / len(pos_scores)

In [6]:
# === MRR & F1 Evaluation ===
def compute_mrr(scores, positive_ids):
    sorted_indices = torch.argsort(scores, descending=True)
    ranks = [(sorted_indices == pid).nonzero(as_tuple=True)[0].item() + 1 for pid in positive_ids if pid < len(scores)]
    reciprocals = [1.0 / rank for rank in ranks]
    return sum(reciprocals) / len(ranks) if ranks else 0.0

def compute_f1_at_k(scores, positive_ids, k=5):
    top_k = torch.argsort(scores, descending=True)[:k]
    retrieved_set = set(top_k.tolist())
    gold_set = set(positive_ids.tolist())
    tp = len(retrieved_set & gold_set)
    precision = tp / k
    recall = tp / len(gold_set) if gold_set else 0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
    return f1

In [7]:
from tqdm import tqdm

In [8]:
# === Load Data ===
train_graphs = torch.load('data/train.pt')
train_list = [v['graph'] for v in train_graphs.values()]
train_loader = DataLoader(train_list, batch_size=1, shuffle=True)

# === Initialize Model and Optimizer ===
model = GNNRanker(input_dim=768, hidden_dim=256).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)



# Mean Aggregation

In [12]:
# === Training Loop ===
for epoch in range(5):
    model.train()
    total_loss = 0
    total_mrr = 0
    total_f1 = 0
    batches = 0

    for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
        batch = batch.to(DEVICE)

        if batch.edges_index is None:
            continue
        if batch.edges_index.dtype != torch.long:
            batch.edges_index = batch.edges_index.long()

        scores = model(batch.x, batch.edges_index)
        positive_ids = batch.positive_ids.to(DEVICE)

        loss = ranking_loss(scores, positive_ids)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_mrr += compute_mrr(scores, positive_ids)
        total_f1 += compute_f1_at_k(scores, positive_ids)
        batches += 1

    avg_loss = total_loss / batches if batches else 0
    avg_mrr = total_mrr / batches if batches else 0
    avg_f1 = total_f1 / batches if batches else 0

    print(f"Epoch {epoch+1} | Loss: {avg_loss:.4f} | MRR: {avg_mrr:.4f} | F1@5: {avg_f1:.4f}")

Epoch 1: 100%|███████████████████████████████████████████████████████████████████| 167454/167454 [28:41<00:00, 97.26it/s]


Epoch 1 | Loss: 0.2005 | MRR: 0.6892 | F1@5: 0.6339


Epoch 2: 100%|███████████████████████████████████████████████████████████████████| 167454/167454 [28:55<00:00, 96.50it/s]


Epoch 2 | Loss: 0.1888 | MRR: 0.6900 | F1@5: 0.6341


Epoch 3: 100%|███████████████████████████████████████████████████████████████████| 167454/167454 [28:18<00:00, 98.59it/s]


Epoch 3 | Loss: 0.1806 | MRR: 0.6904 | F1@5: 0.6342


Epoch 4: 100%|██████████████████████████████████████████████████████████████████| 167454/167454 [25:38<00:00, 108.86it/s]


Epoch 4 | Loss: 0.1736 | MRR: 0.6909 | F1@5: 0.6342


Epoch 5: 100%|██████████████████████████████████████████████████████████████████| 167454/167454 [25:36<00:00, 108.97it/s]

Epoch 5 | Loss: 0.1693 | MRR: 0.6912 | F1@5: 0.6343





In [None]:
train_graphs

In [None]:
train_list['edge']

In [13]:
dev_graphs = torch.load('data/dev.pt')
dev_list = [v['graph'] for v in dev_graphs.values()]
dev_loader = DataLoader(dev_list, batch_size=1, shuffle=False)


In [14]:
model.eval()
total_mrr = 0.0
total_f1 = 0.0
batches = 0

with torch.no_grad():
    for batch in tqdm(dev_loader):
        batch = batch.to(DEVICE)

        if batch.edges_index.dtype != torch.long:
            batch.edges_index = batch.edges_index.long()

        scores = model(batch.x, batch.edges_index)
        positive_ids = batch.positive_ids.to(DEVICE)

        mrr = compute_mrr(scores, positive_ids)
        f1 = compute_f1_at_k(scores, positive_ids, k=5)

        total_mrr += mrr
        total_f1 += f1
        batches += 1

print(f"[Dev Eval] MRR: {total_mrr / batches:.4f} | F1@5: {total_f1 / batches:.4f}")


100%|█████████████████████████████████████████████████████████████████████████████| 12576/12576 [00:34<00:00, 364.18it/s]

[Dev Eval] MRR: 0.6518 | F1@5: 0.6251





In [18]:
import nltk
nltk.download('punkt_tab')

[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


True

In [19]:
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration
from rouge_score import rouge_scorer
from nltk import word_tokenize
import numpy as np
from tqdm import tqdm

# Load LLM
tokenizer = T5Tokenizer.from_pretrained('t5-small')
t5_model = T5ForConditionalGeneration.from_pretrained('t5-small').to(DEVICE)

scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)

def compute_f1(prediction, ground_truth):
    pred_tokens = word_tokenize(prediction.lower())
    gold_tokens = word_tokenize(ground_truth.lower())
    common = set(pred_tokens) & set(gold_tokens)
    if len(common) == 0:
        return 0.0
    precision = len(common) / len(pred_tokens)
    recall = len(common) / len(gold_tokens)
    f1 = 2 * precision * recall / (precision + recall)
    return f1

# === Evaluation Loop ===
total_f1, total_rouge = 0.0, 0.0
count = 0
model.eval()

for i, (qid, sample) in enumerate(tqdm(dev_graphs.items())):
    graph = sample['graph'].to(DEVICE)
    query = sample['query']
    passage_titles = sample['passage_titles']
    positive_ids = sample['positive_ids']

    # Gold answer (concatenated titles of gold passages)
    gold_answer = " ".join([passage_titles[i] for i in positive_ids])

    # GNN Scoring + Ranking
    with torch.no_grad():
        scores = model(graph.x, graph.edges_index)
    topk = torch.argsort(scores, descending=True)[:5]
    top_passages = [passage_titles[i] for i in topk.cpu()]
    context = " ".join(top_passages)

    # LLM Input
    prompt = f"question: {query} context: {context}"
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(DEVICE)

    # Generate Answer
    output = t5_model.generate(input_ids, max_length=50)
    prediction = tokenizer.decode(output[0], skip_special_tokens=True)

    # Score
    f1 = compute_f1(prediction, gold_answer)
    rouge = scorer.score(prediction, gold_answer)['rougeL'].fmeasure

    total_f1 += f1
    total_rouge += rouge
    count += 1

# Final Results
print(f"\nAverage F1: {total_f1 / count:.4f}")
print(f"Average ROUGE-L: {total_rouge / count:.4f}")


100%|██████████████████████████████████████████████████████████████████████████████| 12576/12576 [19:26<00:00, 10.78it/s]


Average F1: 0.2347
Average ROUGE-L: 0.2353



