In [1]:
import networkx as nx
from torch_geometric.utils import from_networkx
from community import community_louvain

import pandas as pd
pd.set_option('display.max_colwidth', None)

import torch
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
from collections import defaultdict

from models.SpatialVectorizer import SpatialVectorizer

import concurrent.futures
import time

import torch.optim as optim

import pickle

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
msmarco_docs = pd.read_csv('data/doc_reranking/qids_1000/sampled_msmarco-docs.csv',
                            header=None, names=['docid', 'url', 'title', 'body'])

In [107]:
len(msmarco_docs)

183859

In [3]:
def get_cos_sim_tensor(df):
    texts = df['body'].fillna('')[:10000]
    vectorizer = TfidfVectorizer()
    tfidf_matrix = vectorizer.fit_transform(texts) 

    cosine_sim_matrix = cosine_similarity(tfidf_matrix)
    cosine_tensor = torch.tensor(cosine_sim_matrix, dtype=torch.float)
    return cosine_tensor

cos_tensor = get_cos_sim_tensor(msmarco_docs)
cos_tensor.shape

torch.Size([10000, 10000])

In [4]:
# from sklearn.neighbors import NearestNeighbors

# texts = msmarco_docs['body'].fillna('')[:10000]
# vectorizer = TfidfVectorizer()
# tfidf_matrix = vectorizer.fit_transform(texts)

# k = 10
# nn = NearestNeighbors(n_neighbors=k, metric='cosine', algorithm='brute')
# nn.fit(tfidf_matrix)

# # Get top-k cosine similarities (1 - distances)
# distances, indices = nn.kneighbors(tfidf_matrix)
# similarities = 1 - distances  # shape: (100000, k)

# cos_tensor = similarities

# edge_index = np.where((cos_tensor > 0.5))

Get non-diagonal edges where the cosine similarity is greater than the threshold value

In [5]:
edge_index = np.where((cos_tensor > 0.5) & (np.arange(cos_tensor.shape[0])[:, None] != np.arange(cos_tensor.shape[1])))


In [6]:
edges = list(zip(edge_index[0], edge_index[1]))

G = nx.Graph()
G.add_edges_from(edges)

Get node embeddings. Maps created as Node2Vec expects 0-indexed edge_index

In [7]:
components = list(nx.connected_components(G))


In [8]:
edge_index = torch.tensor(np.vstack(edge_index), dtype=torch.long)

src,dst = edge_index
unique_nodes = torch.unique(edge_index)
node_id_map = {old.item(): new for new, old in enumerate(unique_nodes)}
id_node_map = {new: old.item() for new, old in enumerate(unique_nodes)}

mapped_edge_index = torch.stack([
    torch.tensor([node_id_map[s.item()] for s in src]),
    torch.tensor([node_id_map[d.item()] for d in dst])
], dim=0)

sp = SpatialVectorizer(mapped_edge_index, "node2vec")
embeddings = sp.get_embeddings()

In [9]:
partition = community_louvain.best_partition(G) 

In [10]:
print(f"Total communities created: {len(set(partition.values()))}")

Total communities created: 793


Calculate central vectors for each of the clusters. These central vectors will be queried during QA

In [11]:
cluster_vectors = defaultdict(list)
for node, cluster_id in partition.items():
    cluster_vectors[cluster_id].append(embeddings[node_id_map[node]])

central_vectors = {
    cluster_id: torch.stack(vectors).mean(dim=0)
    for cluster_id, vectors in cluster_vectors.items()
}

In [12]:
# msmarco_doctrain_top100 = pd.read_csv('/home/nxz190009/phd/graph_reranking/msmarco_doc_reranking/msmarco-doctrain-top100', sep="\s+",header=None, names=["qid", "placeholder1", "docid", "placeholder2", "placeholder3", "placeholder4"])

Loading the dataset into RAM and filtering takes less time. Obviously we will run into an issue when the data does not fit in the RAM which is when we will need the chunking approach

In [13]:
# values_to_keep = ['D59221', 'D59220', 'D1555982']
# chunksize = 1000
# output_csv = 'demo.csv'

# def process_chunk(df):
#     return df[df['docid'].isin(values_to_keep)]

# def read_and_filter_in_parallel():
#     results = []

#     with pd.read_csv('msmarco_doc_reranking/msmarco-docs.tsv',
#                             sep="\t",header=None, names=['docid', 'url', 'title', 'body'], nrows=100000, 
#                               chunksize=chunksize) as reader:
#         with concurrent.futures.ProcessPoolExecutor() as executor:
#             futures = [executor.submit(process_chunk, chunk) for chunk in reader]
#             for future in concurrent.futures.as_completed(futures):
#                 result = future.result()
#                 if not result.empty:
#                     results.append(result)

#     # Optionally concatenate
#     final_df = pd.concat(results, ignore_index=True)
#     final_df.to_csv(output_csv, index=False)

# s = time.time()
# read_and_filter_in_parallel()
# print(f"Chunking and parallelizing Took {time.time() - s}s")


# s = time.time()
# df = pd.read_csv('msmarco_doc_reranking/msmarco-docs.tsv',
#                             sep="\t",header=None, names=['docid', 'url', 'title', 'body'], nrows=100000)
# filtered_df = df[df['docid'].isin(values_to_keep)]
# print(f"Loading into RAM and filtering Took {time.time() - s}s")

In [14]:
# msmarco_doctrain_top100

In [None]:


with open('data/doc_reranking/qids_1000/qids.pkl', 'rb') as f:
    sampled_qids = pickle.load(f)



In [16]:
query_collection = pd.read_csv('msmarco_doc_reranking/msmarco-doctrain-queries.tsv', sep="\t",
                                header=None, names=['id', 'text'])
questions = query_collection[query_collection['id'].isin(sampled_qids)].text.tolist()

Map questions to the same vector space as the node2vec central vectors

In [17]:
from collections import Counter
cluster_sizes = Counter(partition.values())
cluster_sizes

Counter({4: 549,
         2: 306,
         10: 169,
         220: 144,
         1: 110,
         126: 92,
         35: 86,
         11: 83,
         0: 79,
         46: 71,
         12: 70,
         49: 69,
         90: 63,
         48: 62,
         187: 52,
         17: 51,
         216: 43,
         72: 41,
         30: 38,
         85: 38,
         60: 37,
         15: 36,
         139: 36,
         36: 35,
         47: 35,
         67: 35,
         125: 34,
         23: 33,
         78: 33,
         18: 32,
         98: 32,
         138: 32,
         75: 30,
         22: 29,
         172: 28,
         109: 25,
         5: 24,
         481: 24,
         323: 23,
         140: 22,
         37: 21,
         114: 21,
         6: 20,
         43: 20,
         155: 20,
         115: 19,
         119: 19,
         193: 19,
         148: 17,
         248: 17,
         3: 16,
         69: 16,
         111: 16,
         127: 16,
         79: 15,
         206: 15,
         19: 14,
         20

In [19]:
len(questions)

1000

In [97]:
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel

class TextEncoder(nn.Module):
    def __init__(self, model_name='bert-base-uncased', output_dim=128):
        super().__init__()
        self.bert = AutoModel.from_pretrained(model_name)
        self.project = nn.Linear(self.bert.config.hidden_size, output_dim)

    def forward(self, input_ids, attention_mask):
        output = self.bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state[:, 0]
        return self.project(output)  # shape: [batch_size, output_dim]
    
model_t = TextEncoder(model_name='bert-base-uncased', output_dim=128).to('cuda')
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

In [95]:
questions

['pluvious definition',
 'how to sync computer to ps3',
 'things about gettysburg battlefield gettysburg pennsylvania haunted',
 'what chinese dynasty invented abacus',
 'what do episcopalians eat',
 'utm earth definition',
 'population in cincinnati ohio',
 'describe how an earthquake occurs',
 'what county for westland mi',
 'what education is needed to be a psychiatrist',
 'what county is campbell hall, ny',
 'how long to hip nerve blocks last',
 'is grinding your teeth bad',
 'meaning of osgood',
 'disability pensioner can i get age pension',
 'probable cause is derived from what amendment',
 'is iran fighting isis',
 'how far is pittsburgh pa from nyc',
 'postscript ps definition',
 'does wisconsin have state tax',
 'how far is danbury from nyc',
 'what is hardboard made from',
 'what are some organic compounds',
 'what does the term anti semitism mean',
 'what does a logistic manager do',
 'what do the initials hsw stand for',
 'std effects on the body',
 'what does a cre',
 'pop

In [56]:
from torch.utils.data import DataLoader

batch_size = 16  # or even smaller if needed
device = 'cuda' if torch.cuda.is_available() else 'cpu'

all_vecs = []

for i in range(0, len(questions), batch_size):
    batch_questions = questions[i:i+batch_size]
    encodings = tokenizer(batch_questions, return_tensors='pt', padding=True, truncation=True).to(device)
    
    with torch.no_grad():
        vecs = model(encodings['input_ids'], encodings['attention_mask'])
    
    all_vecs.append(vecs.cpu())  # move back to CPU to save GPU memory

text_vecs = torch.cat(all_vecs, dim=0).to('cuda')  # [total_questions, 128]


In [25]:
graph_vecs = torch.stack([v for i, v in central_vectors.items()]).to('cuda')

In [69]:
torch.arange(max(cluster_vectors.keys())+1)

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
        36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,
        54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71,
        72, 73])

In [None]:

msmarco_doctrain_top100 = pd.read_csv('/home/nxz190009/phd/graph_reranking/msmarco_doc_reranking/msmarco-doctrain-top100', 
                                        sep="\s+",header=None, 
                                        names=["qid", "placeholder1", "docid", "placeholder2", "placeholder3", "placeholder4"])
sampled_qids = msmarco_doctrain_top100['qid'].sample(n=10, random_state=42).tolist()
msmarco_doctrain_top100_sampled = msmarco_doctrain_top100[msmarco_doctrain_top100['qid'].isin(sampled_qids)]

def get_cluster_labels(sampled_qids):
    labels = []
    for qid in sampled_qids:
        docids = msmarco_doctrain_top100[msmarco_doctrain_top100['qid'] == qid]['docid'].tolist()
        node_ids = msmarco_docs[msmarco_docs['docid'].isin(docids)].index
        clusters = set()
        for node_id in node_ids:
            if node_id in partition:
                clusters.add(partition[node_id])
        
        labels.append(list(clusters))
    return labels

def get_labels(sampled_qids):
    labels = torch.zeros((len(sampled_qids), max(cluster_vectors.keys())+1))
    targets = get_cluster_labels(sampled_qids)
    for i,r in enumerate(targets):
        cluster_ids = targets[i]
        for c_id in cluster_ids:
            labels[i][c_id] = 1
    return labels

In [70]:
labels = get_labels(sampled_qids).to('cuda')

In [78]:
class SimpleQueryEncoder(nn.Module):
    def __init__(self, embed_dim=300, output_dim=128):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(embed_dim, 256),
            nn.ReLU(),
            nn.Linear(256, output_dim)
        )

    def forward(self, embeddings):
        return self.encoder(embeddings)
    

def multi_label_contrastive_loss(text_embeds, graph_embeds, targets, temperature=0.07):
    text_embeds = nn.functional.normalize(text_embeds, dim=1)
    graph_embeds = nn.functional.normalize(graph_embeds, dim=1)

    logits = torch.matmul(text_embeds, graph_embeds.T) / temperature  # [B, N]
    loss_fn = nn.BCEWithLogitsLoss()
    return loss_fn(logits, targets)

In [87]:
train, val, test = [], [], []
batch_size = 50

text_vecs_batches = [text_vecs[i:i + batch_size] for i in range(0, len(text_vecs), batch_size)]
qid_batches = [sampled_qids[i:i + batch_size] for i in range(0, len(sampled_qids), batch_size)]
label_batches = [labels[i:i + batch_size] for i in range(0, len(labels), batch_size)]

In [92]:
text_vecs_batches

[tensor([[ 0.3008,  0.4821,  0.1552,  ..., -0.0583,  0.5484, -0.6296],
         [ 0.3944,  0.2107, -0.0426,  ..., -0.0277,  0.7066, -0.4537],
         [ 0.5214,  0.1793,  0.1345,  ..., -0.1644,  0.5514, -0.6884],
         ...,
         [ 0.4535,  0.1241, -0.1475,  ...,  0.0621,  0.6702, -0.4480],
         [ 0.4033,  0.2349, -0.0981,  ..., -0.0807,  0.5587, -0.6130],
         [ 0.3880,  0.1366, -0.0565,  ..., -0.0511,  0.7330, -0.4973]],
        device='cuda:0'),
 tensor([[ 0.5508,  0.2573, -0.1967,  ..., -0.0920,  0.6005, -0.6206],
         [ 0.5026,  0.0291, -0.1788,  ...,  0.4305,  0.6738, -0.7355],
         [ 0.2462,  0.1301,  0.0780,  ..., -0.1218,  0.3703, -0.4273],
         ...,
         [ 0.4339,  0.3588,  0.0083,  ..., -0.0427,  0.6582, -0.4777],
         [ 0.5079,  0.1932,  0.0895,  ...,  0.0072,  0.5474, -0.5166],
         [ 0.4230,  0.1270, -0.3164,  ...,  0.0769,  0.6449, -0.4551]],
        device='cuda:0'),
 tensor([[ 0.5694,  0.1892, -0.1474,  ..., -0.0073,  0.5662, -0.56

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

all_vecs = []

for i in range(0, len(questions), batch_size):
    batch_questions = questions[i:i+batch_size]
    encodings = tokenizer(batch_questions, return_tensors='pt', padding=True, truncation=True).to(device)
    
    with torch.no_grad():
        vecs = model(encodings['input_ids'], encodings['attention_mask'])
    
    all_vecs.append(vecs.cpu())  # move back to CPU to save GPU memory

text_vecs = torch.cat(all_vecs, dim=0).to('cuda')  # [total_questions, 128]

In [101]:
def train_text_encoder(model, graph_vecs, qid_batches, epochs):
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    model.train()
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)
    graph_vecs = graph_vecs.to(device)
    total_loss = 0

    for epoch in range(epochs):
        total_loss = 0
        for i,qid_batch in enumerate(qid_batches):
            batch_questions = questions[i:i+batch_size]
            encodings = tokenizer(batch_questions, return_tensors='pt', padding=True, truncation=True).to(device)
            text = model_t(encodings['input_ids'], encodings['attention_mask'])

            labels = label_batches[i]
                
            loss = multi_label_contrastive_loss(text, graph_vecs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        print(f"Epoch {epoch+1} Loss: {total_loss:.4f}")
    
model = SimpleQueryEncoder()
train_text_encoder(model, graph_vecs, qid_batches, 10)


Epoch 1 Loss: 19.1866
Epoch 2 Loss: 19.1866
Epoch 3 Loss: 19.1866
Epoch 4 Loss: 19.1866
Epoch 5 Loss: 19.1866
Epoch 6 Loss: 19.1866


KeyboardInterrupt: 

In [None]:
num_epochs = 1

optimizer = optim.Adam(model.parameters(), lr=2e-5)

loss = multi_label_contrastive_loss(text_vecs[:10], graph_vecs, labels[:10])

for epoch in range(num_epochs):
    for qid_batch in qid_batches:

        labels = get_labels(qid_batch).to('cuda')

        loss = multi_label_contrastive_loss(text_vecs, graph_vecs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

In [69]:
loss

tensor(0.7468, device='cuda:0')

In [111]:
enc = tokenizer(questions, return_tensors='pt', padding=True, truncation=True).to('cuda')


In [None]:
def multi_label_contrastive_loss(text_embeds, graph_embeds, targets, temperature=0.07):
    text_embeds = nn.functional.normalize(text_embeds, dim=1)
    graph_embeds = nn.functional.normalize(graph_embeds, dim=1)

    logits = torch.matmul(text_embeds, graph_embeds.T) / temperature  # [B, N]
    loss_fn = nn.BCEWithLogitsLoss()
    return loss_fn(logits, targets) 
    
optimizer = optim.Adam(model.parameters(), lr=2e-5)    
# Tokenize text
enc = tokenizer(questions, return_tensors='pt', padding=True, truncation=True).to('cuda')
text_vecs = model(enc['input_ids'], enc['attention_mask'])  # BERT → projection

# Lookup graph vectors
graph_vecs = torch.stack([v for k,v in central_vectors.items()]).to('cuda')
sampled_qids
labels = get_labels().to('cuda')

# Compute loss
loss = multi_label_contrastive_loss(text_vecs, graph_vecs, labels)
loss
# Backprop and optimize
optimizer.zero_grad()
loss.backward()
optimizer.step()

In [None]:
node_mapping = {i: n for i, n in enumerate(G.nodes())}


{0: 0,
 6: 0,
 63: 0,
 90: 0,
 160: 0,
 178: 0,
 198: 0,
 274: 0,
 330: 0,
 331: 0,
 390: 0,
 396: 0,
 412: 0,
 424: 0,
 445: 0,
 461: 0,
 481: 0,
 492: 0,
 508: 0,
 538: 0,
 623: 0,
 672: 0,
 693: 0,
 725: 0,
 747: 0,
 805: 0,
 818: 0,
 829: 0,
 849: 0,
 889: 0,
 909: 0,
 938: 0,
 1491: 0,
 1790: 0,
 1: 1,
 86: 1,
 155: 1,
 268: 1,
 563: 1,
 616: 1,
 2: 2,
 190: 2,
 3: 3,
 599: 3,
 4: 4,
 10: 4,
 15: 4,
 47: 4,
 176: 4,
 200: 4,
 203: 4,
 217: 4,
 245: 4,
 314: 4,
 323: 4,
 324: 4,
 341: 4,
 348: 4,
 350: 4,
 374: 4,
 376: 4,
 422: 4,
 435: 4,
 507: 4,
 545: 4,
 564: 4,
 620: 4,
 626: 4,
 659: 4,
 663: 4,
 697: 4,
 723: 4,
 750: 4,
 759: 4,
 844: 4,
 891: 4,
 927: 4,
 950: 4,
 959: 4,
 963: 4,
 5: 5,
 296: 5,
 329: 5,
 357: 5,
 463: 5,
 577: 5,
 762: 5,
 918: 5,
 924: 5,
 7: 6,
 179: 6,
 181: 6,
 202: 6,
 289: 6,
 304: 6,
 385: 6,
 454: 6,
 729: 6,
 755: 6,
 783: 6,
 830: 6,
 831: 6,
 861: 6,
 872: 6,
 895: 6,
 903: 6,
 970: 6,
 971: 6,
 991: 6,
 8: 5,
 29: 5,
 223: 5,
 434: 5,
 622: 