In [1]:
# # Process any remaining queries that did not fill a complete batch
# if queries:
#     out = pipe(queries, max_length=100, num_return_sequences=1, temperature=0.01)
#     for i, pred in enumerate(out):
#         # Since `pred` is a list of dictionaries, access the first element to get the generated text
#         generated_text = pred[0]['generated_text'] if isinstance(pred, list) and len(pred) > 0 else ""
        
#         # Strip out the prompt and process each prediction
#         stripped_pred = generated_text[len(queries[i]):].strip().lower()
#         stripped_pred = stripped_pred.split('\n')[0].strip()
#         stripped_pred = stripped_pred.rstrip(string.punctuation)

#         # Store the result
#         results[question_ids[i]][mode] = stripped_pred


In [None]:
from collections import defaultdict
from datasets import load_dataset
from tqdm import tqdm
import json
import string
import os
from transformers import pipeline
from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
pipe = pipeline("text-generation", model=model_name, pad_token_id=tokenizer.eos_token_id, device_map='auto')

def get_query_for_mode(question, choices, mode):
    if mode == 'multiple_choice':
        return f"{question} Choose one of the following: {', '.join(choices)}. Don't include any other text in your response."
    else:
        return f"{question} Respond in as few words as possible."

ds = load_dataset("HuggingFaceM4/A-OKVQA")
modes = ['multiple_choice', 'direct_answer']
splits = ['validation']
results_dir = 'results'

for split in splits:
    print(f'Starting split: {split}')
    results_path = f'../{results_dir}/{split}.json'
    if os.path.exists(results_path):
        continue
    results = defaultdict(dict)
    for example in tqdm(ds[split]):
        input_question = example['question']
        question_id = example['question_id']
        
        for mode in modes:
            query = get_query_for_mode(input_question, example.get('choices', []), mode)
            out = pipe(query, max_length=100, num_return_sequences=1, temperature=0.01)  # Use a small positive temperature
            
            # Get generated text
            pred = out[0]['generated_text']
            
            # Strip out the prompt
            stripped_pred = pred[len(query):].strip().lower()
            stripped_pred = stripped_pred.split('\n')[0].strip()
            stripped_pred = stripped_pred.rstrip(string.punctuation)
            
            # print('-----------')
            # print(query)
            # print(stripped_pred)
            # Store the result
            results[question_id][mode] = stripped_pred

    print(f'Saving results for split: {split}')
    with open(results_path, 'w') as json_file:
        json.dump(results, json_file, indent=4)

        

  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 4/4 [00:02<00:00,  1.96it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:05<00:00,  1.36s/it]


Starting split: validation


In [None]:
import re

def format_string(s):
    # Convert to lowercase
    s = s.lower()
    
    # Remove punctuation
    s = re.sub(r'[^\w\s]', '', s)
    
    # Remove articles (a, an, the)
    s = re.sub(r'\b(a|an|the)\b', '', s)
    
    # Convert words to digits (if they are numbers)
    s = re.sub(r'\bzero\b', '0', s)
    s = re.sub(r'\bone\b', '1', s)
    s = re.sub(r'\btwo\b', '2', s)
    s = re.sub(r'\bthree\b', '3', s)
    s = re.sub(r'\bfour\b', '4', s)
    s = re.sub(r'\bfive\b', '5', s)
    s = re.sub(r'\bsix\b', '6', s)
    s = re.sub(r'\bseven\b', '7', s)
    s = re.sub(r'\beight\b', '8', s)
    s = re.sub(r'\bnine\b', '9', s)
    
    # Remove extra spaces
    s = re.sub(r'\s+', ' ', s).strip()
    
    return s

for split in splits:
  results_path = f'../{results_dir}/{split}.json'
  with open(results_path, 'r') as file:
    results = json.load(file)

  mc_correct_list = []
  da_correct_list = []
  for example in tqdm(ds[split]):
    qid = example['question_id']
    target_mc = example['choices'][example['correct_choice_idx']]
    target_da_list = example['direct_answers']
    pred_mc = results[qid]['multiple_choice']
    pred_da = format_string(results[qid]['direct_answer'])
    

    mc_correct_list.append(pred_mc == target_mc)
    da_correct_list.append(min(target_da_list.count(pred_da)/10, 1))

  mc_acc = sum(mc_correct_list) / len(mc_correct_list)
  da_acc = sum(da_correct_list) / len(da_correct_list)
  print('Llama 3 8B Instruct', split)
  print('mc acc: ', round(mc_acc, 3))
  print('da acc: ', round(da_acc, 3))


100%|██████████| 1145/1145 [00:01<00:00, 965.15it/s]

Llama 3 8B Instruct validation
mc acc:  0.212
da acc:  0.059



