In [12]:
import json
import itertools
from typing import Dict, List

import numpy as np

import torch
from torch import Tensor
import torch.nn.functional as F

from transformers import AutoTokenizer, AutoModelForCausalLM

from tqdm import tqdm

device = torch.device('mps' if torch.backends.mps.is_available() else 'cuda' if torch.cuda.is_available() else 'cpu')

### reranker

In [13]:
tokenizer = AutoTokenizer.from_pretrained("/Users/timmiakov/Dev/models/Qwen3-Reranker-4B", padding_side='left')
model = AutoModelForCausalLM.from_pretrained("/Users/timmiakov/Dev/models/Qwen3-Reranker-4B").to(device).eval()

token_false_id = tokenizer.convert_tokens_to_ids("no")
token_true_id = tokenizer.convert_tokens_to_ids("yes")
max_length = 8192

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [14]:
def format_instruction(instructions: List[str]) -> str:
    output = "<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}".format(
        instruction=instructions[0], query=instructions[1], doc=instructions[2]
    )
    return output

def process_inputs(pairs: List, prefix_tokens: List[int], suffix_tokens: List[int]) -> Dict:
    
    tokenizer_max_length = max_length - len(prefix_tokens) - len(suffix_tokens)
    
    inputs = tokenizer(
        pairs, padding=True, truncation='longest_first',
        return_attention_mask=False, max_length=tokenizer_max_length
    )
    for i, ele in enumerate(inputs['input_ids']):
        inputs['input_ids'][i] = prefix_tokens + ele + suffix_tokens
    inputs = tokenizer.pad(inputs, padding=True, return_tensors="pt", max_length=max_length)
    for key in inputs:
        inputs[key] = inputs[key].to(model.device)

    return inputs

@torch.no_grad()
def compute_logits(inputs: Dict, **kwargs) -> List[float]:
    batch_scores = model(**inputs).logits[:, -1, :]
    true_vector = batch_scores[:, token_true_id]
    false_vector = batch_scores[:, token_false_id]
    batch_scores = torch.stack([false_vector, true_vector], dim=1)
    batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
    scores = batch_scores[:, 1].exp().tolist()
    return scores

In [15]:
def invoke(prefix: str, suffix: str, insturction: str, query: str, documents: List[str]):
    
    assert isinstance(insturction, str) == True
    assert isinstance(query, str) == True
    
    prefix_tokens = tokenizer.encode(prefix, add_special_tokens=False)
    suffix_tokens = tokenizer.encode(suffix, add_special_tokens=False)

    pairs = list(map(format_instruction, list(itertools.product([insturction], [query], documents))))

    inputs = process_inputs(pairs, prefix_tokens, suffix_tokens)
    
    return compute_logits(inputs)

### test queires

In [20]:
with open("results/processed/webqsp_test_results_resolved.json", 'r') as f:
    test_data = json.load(f)

with open("results/processed/webqsp_test_results_PROMPTS.json", 'r') as f:
    test_prompts = json.load(f)

In [119]:
idx = 47
test_q = test_data[idx]
test_q
nodes = [n[1] for n in test_q['retrieved']]

In [120]:
test_q

{'id': 'WebQTest-63',
 'retrieved': [['m.0rmg', 'andrew johnson administration', 0.3487710654735565],
  ['m.03mpk', 'hannibal hamlin', 0.20501472055912018],
  ['m.04gc2', 'defense attourney', 0.15570005774497986],
  ['m.0fj9f', 'politition', 0.15130020678043365],
  ['m.016fc2', 'statesman"@e', 0.13650622963905334]],
 'answers': [['m.03mpk', 'hannibal hamlin'],
  ['m.0rmg', 'andrew johnson administration']],
 'question': 'who was vp for lincoln'}

In [121]:
test_graph = test_prompts[test_q['id']]

In [122]:
test_graph

['abraham lincon --> government.us_vice_president.to_president --> andrew johnson administration',
 'andrew johnson administration --> government.us_president.vice_president --> abraham lincon',
 'abraham lincon --> government.us_vice_president.to_president --> hannibal hamlin',
 'hannibal hamlin --> government.us_president.vice_president --> abraham lincon',
 'defense attourney --> people.person.profession --> abraham lincon',
 'abraham lincon --> media_common.dedication.dedicated_to --> m.04tl_wn --> media_common.dedicator.dedications --> us legislative branch --> government.government_office_or_title.governmental_body_if_any --> united states congressperson --> base.onephylogeny.type_of_thing.things_of_this_type --> politition',
 'politition --> people.person.profession --> abraham lincon',
 'abraham lincon --> media_common.dedication.dedicated_to --> m.04tl_wn --> media_common.dedicator.dedications --> us legislative branch --> government.government_office_or_title.governmental_bod

In [123]:
reverse_test_graph = []
for st in test_graph:
    ents = st.split(' --> ')
    new_str = ents[2] + ' --> ' + ents[1] + ' --> ' + ents[0] 
    reverse_test_graph.append(new_str)

In [124]:
reverse_test_graph

['andrew johnson administration --> government.us_vice_president.to_president --> abraham lincon',
 'abraham lincon --> government.us_president.vice_president --> andrew johnson administration',
 'hannibal hamlin --> government.us_vice_president.to_president --> abraham lincon',
 'abraham lincon --> government.us_president.vice_president --> hannibal hamlin',
 'abraham lincon --> people.person.profession --> defense attourney',
 'm.04tl_wn --> media_common.dedication.dedicated_to --> abraham lincon',
 'abraham lincon --> people.person.profession --> politition',
 'm.04tl_wn --> media_common.dedication.dedicated_to --> abraham lincon',
 'abraham lincon --> people.person.profession --> statesman"@e']

In [125]:
prefix = "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n"
suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"

instruction = """The graph of knowledge is given in the format:
Entity 1 (vertex) -> entity relation (edge) -> Entity 2 (vertex)
Having this graph and a query for it, choose the Entities that are relevant answers to that query.
Graph:\n"""
instruction += '\n'.join(reverse_test_graph)

queries = test_q['question']

documents = [r[1] for r in test_q['retrieved']]

In [126]:
scores = invoke(
    prefix=prefix,
    suffix=suffix,
    insturction=instruction,
    query=queries,
    documents=documents
)



In [127]:
documents

['andrew johnson administration',
 'hannibal hamlin',
 'defense attourney',
 'politition',
 'statesman"@e']

In [128]:
scores

[0.09347117692232132,
 0.9863868355751038,
 0.003919102717190981,
 0.0025081755593419075,
 0.0012620736379176378]