In [None]:
import os
import torch


In [3]:
!pip install torch_geometric rank_bm25

Collecting torch_geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
Collecting rank_bm25
  Downloading rank_bm25-0.2.2-py3-none-any.whl.metadata (3.2 kB)
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m38.9 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading rank_bm25-0.2.2-py3-none-any.whl (8.6 kB)
Installing collected packages: rank_bm25, torch_geometric
Successfully installed rank_bm25-0.2.2 torch_geometric-2.6.1
[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.3.1[0m[39;49m -> [0m[32;49m25.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython3.10 -m pip install --upgrade pip[0m


In [1]:
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import DataLoader
from torch_geometric.nn import GATConv, global_max_pool
from transformers import T5ForConditionalGeneration, T5Tokenizer
from rank_bm25 import BM25Okapi
import numpy as np
from tqdm import tqdm

In [2]:
import json
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.feature_extraction.text import TfidfVectorizer
from collections import defaultdict
from torch_geometric.data import Data

# Data Preprocessing

In [11]:
def extract_passages(instance):
    """Flatten one dataset instance into a list of sentence-level passages.

    Args:
        instance: mapping with
            * ``'context'``: sequence of ``(entity, facts)`` pairs, where
              ``facts`` is a sequence of sentences (presumably strings).
            * ``'supporting_facts'``: sequence of ``(entity, sentence_index)``
              pairs; tuples or 2-element lists are both accepted.

    Returns:
        list of dicts, one per sentence in context order, with keys
        ``'text'``, ``'entity'``, ``'position'`` (index of the sentence within
        its entity's facts) and ``'is_supporting'`` (bool).
    """
    # JSON has no tuple type, so supporting facts read with json.load arrive as
    # [entity, idx] lists. A tuple never equals a list, so testing
    # `(entity, idx) in [[...], ...]` would always be False. Normalising to a
    # set of tuples fixes the comparison and makes each lookup O(1).
    supporting = {tuple(fact) for fact in instance['supporting_facts']}

    passages = []
    for entity, facts in instance['context']:
        for idx, fact in enumerate(facts):
            passages.append({
                'text': fact,
                'entity': entity,
                'position': idx,
                'is_supporting': (entity, idx) in supporting,
            })
    return passages


In [12]:
class SplitProcessor:
    """Turn dataset instances into per-instance passage graphs.

    The TF-IDF vocabulary is learned from the training split only (``fit``);
    ``process`` then featurises any split with that fitted vectorizer plus a
    sentence-embedding model and hands the features to ``build_graph``.
    """

    def __init__(self):
        self.tfidf = TfidfVectorizer(max_features=5000)
        self.embedder = SentenceTransformer('all-mpnet-base-v2')

    def fit(self, train_instances):
        """Fit the TF-IDF vectorizer on every passage text of the training split."""
        corpus = []
        for instance in train_instances:
            corpus.extend(passage['text'] for passage in extract_passages(instance))
        self.tfidf.fit(corpus)
        # Warm-up call on the embedder: encodes at most the first 1000 texts and
        # discards the result.
        self.embedder.encode(corpus[:1000])

    def process(self, instances):
        """Return one ``build_graph`` result per instance, in input order."""
        graphs = []
        for instance in instances:
            passages = extract_passages(instance)
            texts = [passage['text'] for passage in passages]
            tfidf_features = self.tfidf.transform(texts)
            embeddings = self.embedder.encode(texts)
            graphs.append(
                build_graph(passages, tfidf_features, embeddings, instance['evidences'])
            )
        return graphs


In [None]:
class SelectiveSampler:
    """Build a k-nearest-neighbour edge set from node features.

    Each node is linked to its ``num_neighbors`` highest-scoring other nodes,
    scored by cosine similarity or by (negated) Euclidean distance between the
    rows of the feature matrix.
    """

    def __init__(self, num_neighbors=10, distance_metric='cosine'):
        self.num_neighbors = num_neighbors
        self.distance_metric = distance_metric

    def __call__(self, x, edge_index=None):
        """Return a ``(2, E)`` long tensor of directed kNN edges.

        Args:
            x: 2-D feature tensor, one row per node.
            edge_index: unused by this implementation.

        Returns:
            ``(2, E)`` ``torch.long`` tensor on ``x.device``. Row 0 holds the
            source node ``i`` and row 1 a neighbour ``j``. Edges are grouped by
            source node and, within a source, ordered by decreasing similarity.
            Self-pairs are dropped. Fewer than two nodes yields an empty
            ``(2, 0)`` tensor.

        Raises:
            ValueError: if ``distance_metric`` is neither ``'cosine'`` nor
                ``'euclidean'`` (only reached when ``x`` has >= 2 rows).
        """
        num_nodes = x.size(0)
        if num_nodes < 2:
            return torch.empty((2, 0), dtype=torch.long, device=x.device)

        # +1 because a node normally ranks as its own nearest neighbour; that
        # self-pair is filtered out below.
        k = min(self.num_neighbors + 1, num_nodes)

        if self.distance_metric == 'cosine':
            x_norm = torch.nn.functional.normalize(x, p=2, dim=1)
            scores = x_norm @ x_norm.T
        elif self.distance_metric == 'euclidean':
            # Negate so that "largest score" means "closest".
            scores = -torch.cdist(x, x, p=2)
        else:
            raise ValueError("Unsupported distance metric")
        _, topk = torch.topk(scores, k, dim=-1)

        # Vectorised form of "for i, for j in topk[i], if i != j": pair every
        # node with each of its top-k indices (row-major order), then drop
        # self-pairs. Result stays on the same device as x.
        src = torch.arange(num_nodes, device=topk.device).repeat_interleave(k)
        dst = topk.reshape(-1)
        keep = src != dst
        return torch.stack((src[keep], dst[keep]), dim=0)


In [16]:
def build_graph(passages, tfidf, embeds, evidences, sem_threshold=0.7, tfidf_threshold=0.25):
    """Build a PyG ``Data`` graph over the sentence-level passages of one instance.

    Edges come from four rules and are de-duplicated (first occurrence kept):
      1. sequential: each passage -> the next passage of the same entity;
      2. semantic: pairs whose embedding cosine similarity exceeds
         ``sem_threshold``;
      3. lexical: pairs whose TF-IDF row dot product exceeds ``tfidf_threshold``;
      4. evidence: for each 3-element evidence item ``(subj, _, obj)`` whose
         ``subj`` and ``obj`` both occur as passage entities, an edge between
         the last passage of ``subj`` and the last passage of ``obj``.
    Rules 2 and 3 compare every row with itself too, so self-pairs appear
    whenever a passage's self-score passes the threshold.

    Args:
        passages: list of passage dicts (as built by ``extract_passages``);
            keys ``'entity'`` and ``'is_supporting'`` are read.
        tfidf: scipy sparse TF-IDF matrix, one row per passage.
        embeds: embedding array, one row per passage.
        evidences: iterable of 3-element items; the middle element is ignored.
        sem_threshold: cosine-similarity cutoff for semantic edges.
        tfidf_threshold: TF-IDF dot-product cutoff for lexical edges.

    Returns:
        ``torch_geometric.data.Data`` with ``x`` (float32 embeddings),
        ``edge_index`` (``(2, E)`` long, ``(2, 0)`` when there are no edges)
        and ``y`` (1.0 for supporting passages, else 0.0).
    """
    edges = []

    # 1. Sequential connections inside each entity, in passage order.
    entity_pos = defaultdict(list)
    for i, p in enumerate(passages):
        entity_pos[p['entity']].append(i)
    for idxs in entity_pos.values():
        edges.extend(zip(idxs, idxs[1:]))

    # Similarity rules need at least one row (cosine_similarity rejects empty input).
    if len(passages) > 0:
        # 2. Semantic similarity.
        cos_sim = cosine_similarity(embeds)
        rows, cols = np.where(cos_sim > sem_threshold)
        edges.extend(zip(rows.tolist(), cols.tolist()))

        # 3. Keyword overlap. `@` is matrix multiplication for both scipy sparse
        # matrices and sparse arrays; `*` is elementwise for sparse arrays.
        tfidf_sim = (tfidf @ tfidf.T).toarray()
        rows, cols = np.where(tfidf_sim > tfidf_threshold)
        edges.extend(zip(rows.tolist(), cols.tolist()))

    # 4. Evidence links. Later passages overwrite earlier ones in this dict, so
    # each entity maps to the index of its LAST passage.
    entity_map = {p['entity']: i for i, p in enumerate(passages)}
    for subj, _, obj in evidences:
        if subj in entity_map and obj in entity_map:
            edges.append((entity_map[subj], entity_map[obj]))

    # A pair found by several rules would otherwise appear several times in
    # edge_index. Message passing treats every entry as a separate message, so
    # duplicates would be over-weighted. Keep the first occurrence of each pair.
    unique_edges = list(dict.fromkeys(edges))
    # reshape(-1, 2) keeps an empty edge list as a (2, 0) long tensor instead of
    # a 1-D float tensor.
    edge_index = torch.tensor(unique_edges, dtype=torch.long).reshape(-1, 2).t().contiguous()

    return Data(
        x=torch.tensor(embeds, dtype=torch.float32),
        edge_index=edge_index,
        y=torch.tensor([p['is_supporting'] for p in passages], dtype=torch.float),
    )


In [None]:
# Creates the shared processor: an unfitted TfidfVectorizer(max_features=5000)
# and a SentenceTransformer('all-mpnet-base-v2') embedder. It must be fitted on
# the training split before calling process().
processor = SplitProcessor()

In [None]:



# Fit TF-IDF on the training split only, then featurise every split and save
# each graph list as `<split>.pt` next to its raw JSON.
DATA_DIR = 'data'


def _load_split(name):
    """Read ``<DATA_DIR>/<name>.json`` with the file closed afterwards."""
    with open(os.path.join(DATA_DIR, f'{name}.json')) as f:
        return json.load(f)


train = _load_split('train')
processor.fit(train)

# Save each split as soon as it is processed, so that a failure on a later
# split does not lose the earlier results.
train_data = processor.process(train)
torch.save(train_data, os.path.join(DATA_DIR, 'train.pt'))

val_data = processor.process(_load_split('dev'))
torch.save(val_data, os.path.join(DATA_DIR, 'dev.pt'))

test_data = processor.process(_load_split('test'))
torch.save(test_data, os.path.join(DATA_DIR, 'test.pt'))
