In [45]:
from pathlib import Path

import pandas as pd
import torch
from tqdm import tqdm
from tira.third_party_integrations import ir_datasets, get_output_directory
from transformers import XLMRobertaForSequenceClassification, AutoTokenizer

In [44]:
# Training split of the workshop-on-open-web-search query-processing task,
# loaded through TIRA's ir_datasets integration.
dataset = ir_datasets.load('workshop-on-open-web-search/query-processing-20231027-training')

# Query processors persist their results in a file queries.jsonl in the output directory.
# NOTE(review): '.' is passed to get_output_directory, presumably as the fallback
# location when not running inside TIRA — confirm.
output_file = Path(get_output_directory('.')).joinpath('queries.jsonl')

In [4]:
# NOTE(review): the tokenizer is loaded from the base 'xlm-roberta-large' checkpoint,
# not from the fine-tuned repository; this assumes the classifier kept the base
# vocabulary — confirm.
tokenizer = AutoTokenizer.from_pretrained('xlm-roberta-large')

# Fine-tuned XLM-RoBERTa sequence classifier for comparative questions;
# predict() interprets class index 1 as "comparative".
model = XLMRobertaForSequenceClassification.from_pretrained('OnnoLander/XLMRoberta-comparative-questions')

  return self.fget.__get__(instance, owner)()


In [33]:
def predict(query_text):
    """Classify a single query as comparative or not.

    Uses the module-level ``tokenizer`` and ``model``.

    Args:
        query_text: The raw query string.

    Returns:
        A dict with ``is_comparative`` (True when class index 1 has the highest
        logit) and ``is_comparative_logits`` (one raw logit per class, as Python
        floats, in class-index order).
    """
    # Pure inference: disable autograd so no computation graph is built and kept
    # alive for every query that is processed.
    with torch.no_grad():
        # truncation=True keeps over-long inputs within the model's maximum
        # sequence length instead of passing an over-long sequence to the model.
        inputs = tokenizer(query_text, return_tensors="pt", truncation=True)
        logits = model(**inputs).logits
    # A single string was tokenized, so the batch holds exactly one row.
    scores = logits[0]
    return {'is_comparative': int(scores.argmax()) == 1, 'is_comparative_logits': [float(i) for i in scores]}


In [37]:
# Canonical comparative question ("which is better, X or Y?" form).
# The header is derived from the query itself so the printed label always matches
# the exact text being classified (previously the label said "what" while the
# query said "What").
example_query = 'What is better, playstation or xbox?'
print(f'Predict some example: "{example_query}"')
print(predict(example_query))

Predict some example: "what is better, playstation or xbox?"
{'is_comparative': True, 'is_comparative_logits': [-5.1698899269104, 5.360167980194092]}


In [38]:
# Terse "X vs Y" comparison without any question syntax.
example_query = 'playstation vs xbox'
print(f'Predict some example: "{example_query}"')
print(predict(example_query))

Predict some example: "playstation vs xbox"
{'is_comparative': True, 'is_comparative_logits': [-5.046789169311523, 5.182837963104248]}


In [39]:
# Control example: an informational query about a single topic.
example_query = 'hubble telescope achievements'
print(f'Predict some example: "{example_query}"')
print(predict(example_query))

Predict some example: "hubble telescope achievements"
{'is_comparative': False, 'is_comparative_logits': [4.803865909576416, -4.987368583679199]}


In [42]:
# Second control example: an entity-lookup query with nothing being compared.
example_query = 'barack obama family tree'
print(f'Predict some example: "{example_query}"')
print(predict(example_query))

Predict some example: "barack obama family tree"
{'is_comparative': False, 'is_comparative_logits': [4.693194389343262, -5.171028137207031]}


In [46]:
# Classify every query of the dataset and collect one record per query:
# its id, the boolean verdict, and the raw per-class logits.
processed_queries = []
for query in tqdm(dataset.queries_iter()):
    verdict = predict(query.text)
    record = {'query_id': query.query_id}
    record['is_comparative'] = verdict['is_comparative']
    record['is_comparative_logits'] = verdict['is_comparative_logits']
    processed_queries.append(record)

# Write the records as JSON Lines (one object per line) to queries.jsonl.
results_frame = pd.DataFrame(processed_queries)
results_frame.to_json(output_file, lines=True, orient='records')

3it [00:00,  5.25it/s]
