In [None]:
import json
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from sentence_transformers import SentenceTransformer
from tqdm.std import tqdm

In [None]:
# Choose the compute device and report the hardware that was detected.
if not torch.cuda.is_available():
    num_gpus = 0
    device = torch.device("cpu")
    print("Using CPU")
else:
    num_gpus = torch.cuda.device_count()
    device = torch.device("cuda")
    gpu_names = [torch.cuda.get_device_name(i) for i in range(num_gpus)]
    print(f"Using {num_gpus} GPUs: {gpu_names}")

In [None]:
# trust_remote_code=True allows Python code from the Hub repository to run while
# loading; presumably this checkpoint needs it. Only enable it for trusted repos.
model = SentenceTransformer(
    model_name_or_path="dunzhang/stella_en_1.5B_v5",
    trust_remote_code=True,
)
# Inputs longer than this many tokens are truncated by the encoder.
model.max_seq_length = 1024

In [None]:
# Keep `model` a plain SentenceTransformer and just place it on the chosen device.
# Wrapping it in torch.nn.DataParallel hides encode()/similarity(): DataParallel only
# exposes forward() plus registered parameters/buffers/submodules, so get_score()
# would raise AttributeError on any multi-GPU machine. If multi-GPU encoding is
# needed, sentence-transformers' own multi-process pool (start_multi_process_pool)
# is the supported route.
model.to(device)

In [None]:
# Candidate (query, passage) rows; read_table uses a tab separator by default.
data = pd.read_table("../document_ranking_input_true_data/document_ranking_query.tsv")

In [None]:
# Preview a few rows instead of rendering the entire candidate table.
data.head()

In [None]:
def get_score(query, documents, prompt_name="s2p_query", encoder=None):
    """Score one query against a list of candidate documents.

    The query is encoded with ``prompt_name`` and the documents without a prompt,
    both as normalized embeddings, and they are compared with the encoder's own
    ``similarity`` function.

    Args:
        query: The query text.
        documents: Sequence of document texts to score against ``query``.
        prompt_name: Prompt applied to the query only. Defaults to
            ``"s2p_query"``, the value this notebook was written with.
        encoder: Object exposing ``encode`` and ``similarity`` (e.g. a
            ``SentenceTransformer``). Defaults to the notebook-level ``model``.

    Returns:
        1-D numpy array with one score per document, in input order; an empty
        array when ``documents`` is empty.
    """
    if encoder is None:
        encoder = model
    if len(documents) == 0:
        # Encoding an empty list may not give a (0, dim) tensor, which would make
        # the similarity call fail; there is nothing to score anyway.
        return np.empty(0, dtype=np.float32)

    with torch.no_grad():
        query_embedding = encoder.encode(
            [query], prompt_name=prompt_name, convert_to_tensor=True, normalize_embeddings=True
        )  # (1, dim)
        document_embeddings = encoder.encode(
            documents, convert_to_tensor=True, normalize_embeddings=True
        )  # (len(documents), dim)
        scores = encoder.similarity(query_embedding, document_embeddings)  # (1, len(documents))

        # Drop the embeddings and release cached GPU memory between calls
        # (does nothing if CUDA was never initialised).
        del query_embedding, document_embeddings
        torch.cuda.empty_cache()
    return scores.cpu().numpy().flatten()

In [None]:
# Output columns, filled row by row by the batch-scoring loop.
qids = []
pids = []
ranked_pids = []
ranked_scores = []

In [None]:
# Score candidates in fixed-size chunks. Each chunk is expected to hold all the
# candidate passages of exactly one query (batch_size rows per qid in the input);
# the assertion below fails loudly if that layout assumption does not hold.
batch_size = 100
num_samples = len(data)

# Reset the accumulators so re-running this cell does not append duplicate rows.
qids, pids, ranked_pids, ranked_scores = [], [], [], []

for i in tqdm(range(0, num_samples, batch_size)):
    temp_df = data.iloc[i:i + batch_size]
    n_rows = len(temp_df)  # the last chunk can be shorter than batch_size

    # Every row in the chunk shares one query, so it is encoded only once.
    query_id = temp_df["qid"].iloc[0]
    assert temp_df["qid"].eq(query_id).all(), (
        f"rows {i}-{i + n_rows - 1} contain more than one qid; "
        f"expected {batch_size} consecutive rows per query"
    )
    query = temp_df["query"].iloc[0]
    passages = temp_df["passage"].tolist()
    passage_ids = temp_df["docid"].to_numpy()

    scores = get_score(query, passages)  # one score per passage, shape (n_rows,)

    # Descending by score; a stable sort keeps tied passages in input order.
    sorted_indices = np.argsort(-scores, kind="stable")
    sorted_pids_batch = passage_ids[sorted_indices]
    sorted_scores_batch = scores[sorted_indices]

    # One output row per input row: the original docid next to the docid/score
    # found at the same position in the ranked list. Use the real chunk length
    # so all four columns stay the same size when the last chunk is short.
    qids.extend([query_id] * n_rows)
    pids.extend(passage_ids)
    ranked_pids.extend(sorted_pids_batch)
    ranked_scores.extend(sorted_scores_batch)


In [None]:
# Assemble the per-row ranking output; column order matches the keyword order.
df = pd.DataFrame(
    dict(
        qid=qids,
        docid=pids,
        ranked_docid=ranked_pids,
        scores=ranked_scores,
    )
)

In [None]:
# Write the ranking as TSV. to_csv does not create missing directories, so
# create the output folder first to avoid failing after the whole scoring run.
output_path = Path("../passage_output_result/stella_result.tsv")
output_path.parent.mkdir(parents=True, exist_ok=True)
df.to_csv(output_path, sep="\t", index=False)