In [1]:
import torch
import torch.nn.functional as F
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from json import dumps as to_json

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
# Load the tokenizer and the sequence-classification model from the same
# checkpoint, then move the model onto the selected device.
checkpoint = 'LiYuan/amazon-query-product-ranking'
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
model = model.to(DEVICE)

In [4]:
# Candidate passages to be ranked against a single query.
passages = [
    "15x7 universal wheel",
    "17 inch wleel",
    "4x100 15x7 wheel",
    "wheel",
    "Washington D.C.",
    "mouse trap",
    "model plane",
    "cheeseburgers might be unhealthy, but they sure are tasty!",
]

# Repeat the query once per passage so the tokenizer receives aligned
# (query, passage) pairs.
query = ["4x100 wheel 15in" for _ in passages]

# Encode every pair as one padded batch (truncated to at most 512 tokens)
# and move the resulting tensors onto the selected device.
inputs = tokenizer(
    query,
    passages,
    max_length=512,
    truncation=True,
    padding=True,
    return_tensors='pt',
)
inputs = inputs.to(DEVICE)

In [5]:
# Score every (query, passage) pair without tracking gradients, then derive
# two best-first orderings of the passage indices.
# NOTE(review): presumably class 0 is the strongest-match label and class 1 the
# next-best one -- confirm the label order against the model card.
with torch.no_grad():
    logits = model(**inputs).logits
    # Normalise the logits of each pair into per-class probabilities.
    probs = logits.softmax(dim=-1)
    # "Okay" score: combined probability of the first two classes.
    probs_is_okay = probs[:, :2].sum(dim=1)
    # "Perfect" score: probability of the first class alone.
    probs_is_perfect = probs[:, 0]
    # Passage indices sorted from highest to lowest score.
    rankings_thresh_perfect = probs_is_perfect.argsort(dim=0, descending=True).tolist()
    rankings_thresh_okay = probs_is_okay.argsort(dim=0, descending=True).tolist()

In [6]:
# Print the raw class probabilities followed by the two index rankings.
# The trailing empty entry reproduces the blank line at the end of the report.
index_report_lines = [
    "RANKINGS PASSAGE INDEX",
    "All Probs:",
    f"  {probs}",
    "Threshold perfect:",
    f"  {rankings_thresh_perfect}",
    "Threshold okay:",
    f"  {rankings_thresh_okay}",
    "",
]
print("\n".join(index_report_lines))

RANKINGS PASSAGE INDEX
All Probs:
  tensor([[0.2226, 0.6718, 0.0069, 0.0988],
        [0.1434, 0.5626, 0.0203, 0.2736],
        [0.8167, 0.1502, 0.0045, 0.0286],
        [0.1641, 0.6660, 0.0073, 0.1626],
        [0.1847, 0.2299, 0.0244, 0.5610],
        [0.1862, 0.2462, 0.0400, 0.5276],
        [0.1571, 0.2369, 0.0295, 0.5765],
        [0.1146, 0.1555, 0.0214, 0.7085]], device='cuda:0')
Threshold perfect:
  [2, 0, 5, 4, 3, 6, 1, 7]
Threshold okay:
  [2, 0, 3, 1, 5, 4, 6, 7]



In [7]:
# Attach each passage's text to its score, ordered best-first under the
# "perfect" ranking (scores converted to plain Python floats).
ranked_metadata_perfect = [
    dict(score=float(probs_is_perfect[i]), text=passages[i])
    for i in rankings_thresh_perfect
]

# Same structure, ordered and scored by the "okay" ranking.
ranked_metadata_okay = [
    dict(score=float(probs_is_okay[i]), text=passages[i])
    for i in rankings_thresh_okay
]

In [8]:
# Print both ranked lists as indented JSON, each preceded by a divider line.
# The trailing empty entry reproduces the blank line at the end of the report.
divider = "###" * 27
metadata_report_lines = [
    "RANKINGS FULL METADATA",
    "",
    divider,
    "",
    "Threshold perfect:",
    "",
    to_json(ranked_metadata_perfect, indent=2),
    "",
    divider,
    "",
    "Threshold okay:",
    "",
    to_json(ranked_metadata_okay, indent=2),
    "",
]
print("\n".join(metadata_report_lines))

RANKINGS FULL METADATA

#################################################################################

Threshold perfect:

[
  {
    "score": 0.8166823983192444,
    "text": "4x100 15x7 wheel"
  },
  {
    "score": 0.22256851196289062,
    "text": "15x7 universal wheel"
  },
  {
    "score": 0.18616242706775665,
    "text": "mouse trap"
  },
  {
    "score": 0.18468773365020752,
    "text": "Washington D.C."
  },
  {
    "score": 0.16410930454730988,
    "text": "wheel"
  },
  {
    "score": 0.1570894718170166,
    "text": "model plane"
  },
  {
    "score": 0.1434362232685089,
    "text": "17 inch wleel"
  },
  {
    "score": 0.11463627219200134,
    "text": "cheeseburgers might be unhealthy, but they sure are tasty!"
  }
]

#################################################################################

Threshold okay:

[
  {
    "score": 0.9668986201286316,
    "text": "4x100 15x7 wheel"
  },
  {
    "score": 0.8943306803703308,
    "text": "15x7 universal wheel"
  },
  {
    