In [None]:
pip install transformers

In [3]:
from transformers import  AutoTokenizer, OPTForCausalLM, AutoModel, GPT2Tokenizer, GPT2LMHeadModel
import torch
import numpy as np
from sklearn.utils import shuffle

class QAModel():
  """Zero-shot multiple-choice answer picker built on a causal language model.

  Every candidate answer is appended to the question, and the candidates are
  ranked by the negated language-modelling loss the model reports when only
  the answer tokens are used as labels.
  """

  def __init__(self, model_name="facebook/opt-1.3b", device='cuda'):
    """Load the model and its tokenizer.

    Args:
      model_name: checkpoint identifier handed to ``from_pretrained``.
      device: torch device for the model and its inputs. A CUDA device is
        replaced by ``'cpu'`` (with a warning) when CUDA is not available.
    """
    if str(device).startswith('cuda') and not torch.cuda.is_available():
      import warnings
      warnings.warn("CUDA is not available; QAModel is falling back to CPU.")
      device = 'cpu'
    self.model = OPTForCausalLM.from_pretrained(model_name).to(device)
    self.tokenizer = AutoTokenizer.from_pretrained(model_name)
    self.device = device

  def get_answer(self, q, options):
    """Return the option the model scores highest as a continuation of ``q``.

    Args:
      q: question text.
      options: sequence of candidate answer strings.

    Returns:
      The element of ``options`` with the highest score, where the score is
      the negated loss returned by the model over the answer tokens.

    Raises:
      ValueError: if ``options`` is empty.
    """
    options = list(options)
    if not options:
      raise ValueError("get_answer() needs at least one option")

    scores = []
    for option in options:
      input_ids = self.tokenizer(q + ' ' + option, return_tensors="pt").input_ids.to(self.device)

      # Count the answer tokens the way they occur inside the full prompt:
      # with the separating space and WITHOUT special tokens. Encoding the
      # bare option with special tokens (OPT prepends one to every string)
      # over-counts by one and lets the last question token leak into the
      # score. Always score at least one token so the loss stays defined.
      option_len = max(1, len(self.tokenizer(' ' + option, add_special_tokens=False).input_ids))
      answer_start = max(input_ids.size(1) - option_len, 0)

      # -100 is the label value the loss ignores, so only the answer tokens
      # contribute to it.
      target_ids = input_ids.clone()
      target_ids[:, :answer_start] = -100

      with torch.no_grad():
        outputs = self.model(input_ids, labels=target_ids)

      # First output is the loss; negate it so that higher means more likely.
      # Stored as a plain float so numpy can sort the list directly.
      scores.append(-outputs[0].item())

    # The same fixed permutation is applied to both lists, so this can only
    # influence which option wins when several share the top score.
    scores, options = shuffle(scores, options, random_state=0)
    ranking = np.argsort(scores)
    return options[ranking[-1]]

In [4]:
import json

# Location of the question set, relative to the notebook's working directory.
all_path = "../data/500QA.json"
# Read the file as UTF-8 explicitly so that decoding does not depend on the
# platform's default encoding.
with open(all_path, encoding="utf-8") as f:
  all_data = json.load(f)


In [None]:
# Build the scorer once; the arguments spell out the class defaults.
qa_model = QAModel(model_name="facebook/opt-1.3b", device='cuda')

In [6]:
def ZeroShot_QA(data, name, model=None):
  """Run zero-shot multiple-choice evaluation and print the tallies.

  Args:
    data: iterable of question records. Each record is read through the keys
      'query', 'options' (mapping of label to answer text), 'answer' (a key of
      'options') and 'query_type' (mapping of type name to a flag, where 1
      marks the type as applying to the question).
    name: label printed in the results header.
    model: object exposing ``get_answer(query, options)``. When omitted, the
      notebook-level ``qa_model`` is used.

  Prints the header, the number of correctly answered questions and the
  number of correct answers per query type. Returns nothing.
  """
  if model is None:
    model = qa_model

  correct = 0

  type_correct = {
      "Specific": 0,
      "Subjective": 0,
      "Indirect": 0,
      "Compound": 0,
      "Negated": 0,
      "Analogical": 0,
      "Temporal": 0}

  for d in data:
    options_list = list(d['options'].values())
    correct_answer = d['options'][d['answer']]
    answer = model.get_answer(d['query'], options_list)
    # Answers are compared by text, not by option label.
    if answer == correct_answer:
      correct += 1
      for key, flag in d['query_type'].items():
        if flag == 1:
          # Tolerate type names missing from the predefined tally.
          type_correct[key] = type_correct.get(key, 0) + 1

  print("Results for {}:".format(name))
  print("Total correct: ", correct)
  print(type_correct)

In [None]:
# Evaluate the complete question set.
ZeroShot_QA(data=all_data, name="All")