In [497]:
from models import get_gpt_logprobs, tokenize_gpt
from data_utils import save_to_json, load_from_json
from collections import defaultdict
from pprint import pprint
import random
from tqdm import tqdm

In [498]:
# model = 'gpt-3.5-turbo-0125'
# model = 'gpt-4-turbo-2024-04-09'
model = "gpt-4o-2024-05-13"

model_name = (
    "gpt4o" if "gpt-4o" in model else "gpt4" if "gpt-4" in model else "gpt35"
)
input = """Your task: Produce a list that contains, in order, a one-word country, a one-word capital city of a country, a one-word US state, a one-word US state capital, and the surname of an American president."""

In [499]:
# from data_utils import update_solutions_json
# update_solutions_json()
results = load_from_json(f"results/{model_name}_results.json")
solutions = load_from_json("solutions.json")

In [500]:
def produce_results(model_name, number_of_results=100):
    results = []
    for i in range(number_of_results):
        logprobs, message = get_gpt_logprobs(model, input, temperature=1)
        results.append({'logprob_data': logprobs, 'message': message, 'tokens': tokenize_gpt(logprobs)})

    existing_results = load_from_json(f'results/{model_name}_results.json')
    existing_results.extend(results)
    save_to_json(existing_results, 'results/{model_name}_results.json')

In [501]:
prefix_map = {}
prefix_map[''] = results[0]['logprob_data'][0]
for result in results:
    prefix = ''
    for token, logprobs in zip(result['tokens'], result['logprob_data'][1:]):
        prefix += token
        prefix_map[prefix] = logprobs

In [502]:
def compute_prefix_map(results):
    prefix_map = {}
    prefix_map[''] = results[0]['logprob_data'][0]
    for result in results:
        prefix = ''
        for token, logprobs in zip(result['tokens'], result['logprob_data'][1:]):
            prefix += token
            if prefix == 'Canada  \nAmman  \nNevada  \nAustin  \nN':
                print(prefix, logprobs)
            prefix_map[prefix] = logprobs
    
    return prefix_map

In [503]:
def check_word_chain(words):
    """ Returns True if the words follow the last letter/first letter rule, else False """
    words = [word.lower() for word in words]
    for i in range(len(words) - 1):
        if words[i][-1] != words[i + 1][0]:
            return False
    return True

In [504]:
def first_letter_set(arr):
    return set([s[0] for s in arr])

def last_letter_set(arr):
    return set([s[-1] for s in arr])

In [505]:
""" Print total number of correct solution chains and the number of choosable items for each category """

def build_last_letter_dict(strings):
    prefix_dict = defaultdict(list)
    for s in strings:
        prefix_dict[s[-1]].append(s)
    return prefix_dict

last_letter_dict_countries = build_last_letter_dict(solutions['countries'])
last_letter_dict_world_capitals = build_last_letter_dict(solutions['world_capitals'])
last_letter_dict_states = build_last_letter_dict(solutions['states'])
last_letter_dict_state_capitals = build_last_letter_dict(solutions['state_capitals'])

correct_chains = []
for p in solutions['presidents']:
    for result in last_letter_dict_state_capitals[p[0]]:
        for s in last_letter_dict_states[result[0]]:
            for w in last_letter_dict_world_capitals[s[0]]:
                for x in last_letter_dict_countries[w[0]]:
                    correct_chains.append([x, w, s, result, p])

print(f'Total number of correct chains: {len(correct_chains)}')
print('Examples:')
pprint(correct_chains[10000:10010])

choosable_countries = set([c[0] for c in correct_chains])
choosable_world_capitals = set([c[1] for c in correct_chains])
choosable_states = set([c[2] for c in correct_chains])
choosable_state_capitals = set([c[3] for c in correct_chains])
choosable_presidents = set([c[4] for c in correct_chains])
choosable = choosable_countries | choosable_world_capitals | choosable_states | choosable_state_capitals | choosable_presidents

print(f'Choosable countries: {len(choosable_countries)} / {len(solutions["countries"])}')
print(f'Choosable world capitals: {len(choosable_world_capitals)} / {len(solutions["world_capitals"])}')
print(f'Choosable states: {len(choosable_states)} / {len(solutions["states"])}')
print(f'Choosable state capitals: {len(choosable_state_capitals)} / {len(solutions["state_capitals"])}')
print(f'Choosable presidents: {len(choosable_presidents)} / {len(solutions["presidents"])}')

Total number of correct chains: 24686
Examples:
[['russia', 'amsterdam', 'massachusetts', 'salem', 'monroe'],
 ['rwanda', 'amsterdam', 'massachusetts', 'salem', 'monroe'],
 ['samoa', 'amsterdam', 'massachusetts', 'salem', 'monroe'],
 ['serbia', 'amsterdam', 'massachusetts', 'salem', 'monroe'],
 ['slovakia', 'amsterdam', 'massachusetts', 'salem', 'monroe'],
 ['slovenia', 'amsterdam', 'massachusetts', 'salem', 'monroe'],
 ['somalia', 'amsterdam', 'massachusetts', 'salem', 'monroe'],
 ['syria', 'amsterdam', 'massachusetts', 'salem', 'monroe'],
 ['tanzania', 'amsterdam', 'massachusetts', 'salem', 'monroe'],
 ['tonga', 'amsterdam', 'massachusetts', 'salem', 'monroe']]
Choosable countries: 152 / 182
Choosable world capitals: 118 / 209
Choosable states: 30 / 40
Choosable state capitals: 14 / 41
Choosable presidents: 12 / 39


In [506]:
""" Prints the invalid submitted tokens for each category"""

countries = []
capitals = []
states = []
state_capitals = []
presidents = []
submitted_items = [countries, capitals, states, state_capitals, presidents]
for result in results:
    words = result['message'].split()
    for index in range(min(5, len(words))):
        if index == 4:
            presidents.append(words[index].replace('.', ''))
        else:
            submitted_items[index].append(words[index])
    
for name, l in [('countries', countries), ('world_capitals', capitals), ('states', states), ('state_capitals', state_capitals), ('presidents', presidents)]:
    print(len(l), len(set(l)))
    print([i for i in set(l) if i.lower() not in [j.lower() for j in solutions[name]]])
    print('---')

1000 64
[]
---
1000 102
['York', 'Denmark', 'Mexico', 'Yaoundé', 'Egypt', 'Nicaragua', 'Norway', 'Ecuador', 'Australia', 'Montreal', 'Ethiopia', 'Tangier', 'Yakarta', 'Kuwait', 'Yemen', 'November', 'Estonia', 'Addis', 'Yakima', 'England', 'ElKuwait', 'Uganda', 'Uruguay', 'Uzbekistan', 'Naples', 'Atlanta', 'Utah', 'Namaqua', 'Yakutsk', 'Lagos', 'Yak', 'Denver', 'Namibia', 'Youngstown', 'Edmonton', 'Kansas', 'Nigeria', 'Laos', 'La', 'Andorra', 'Nashville', 'Nepal', 'Elgin', 'Tallahassee', 'Austin', 'Lahore', 'New', 'Yangon']
---
1000 54
['North', 'York', 'Seol', 'Southhold', 'South', 'Southhampton', 'Yemen', 'Harrisburg', 'Estonia', 'x-ray', 'Paz', 'Rhode', 'Delhi', 'Yale', 'X', 'Sofia', 'Everglades', 'Yellowstone', 'Xenia', 'New']
---
1000 63
['Newmont', 'Exeter', 'Emporia', 'York', 'Santa', 'Omaha', 'Aspen', 'Eugene', 'Seattle', 'Edmondson', 'Nebraska', 'Illinois', 'Esmeralda', 'Kansas', 'Samoa', 'Texas', 'Nevada', 'Idaho', 'Yuma', 'Yakima', 'Orlando', 'Anchorage', 'Carolina', 'Evansto

In [507]:
score = 0
valid_results = []
for result in results:
    try:
        a, b, c, d, e = result['message'].split()[:5]
        if a.lower() not in solutions['countries']:
            continue
        if b.lower() not in solutions['world_capitals']:
            continue
        if c.lower() not in solutions['states']:
            continue
        if d.lower() not in solutions['state_capitals']:
            continue
        if e.lower() not in solutions['presidents']:
            continue
        score += 1
        valid_results.append(result)
    except:
        continue

print(score, len(results), score / len(results))

409 1000 0.409


In [508]:
passing_results = []
for result in valid_results:
    words = result['message'].split()[:5]
    if check_word_chain(words):
        passing_results.append(result)

len(passing_results), len(valid_results), len(passing_results) / len(valid_results)

(341, 409, 0.8337408312958435)

In [509]:
prefix_map = compute_prefix_map(results)

total_prob_map = []
for result in tqdm(passing_results):
    # print(result['message'])
    # print(result['tokens'])
    response = ''
    try:
        total_prob = [i[1] for i in prefix_map[response] if i[0] == result['tokens'][0]][0]
        # print(total_prob)
    except:
        total_prob = prefix_map[''][-1][1]
    # print(result['tokens'][0], total_prob)
    for token, next_token in zip(result['tokens'], result['tokens'][1:]):
        response += token
        try:
            prob = [i[1] for i in prefix_map[response] if i[0] == next_token][0]
        except:
            try:
                prob = prefix_map[response][-1][1]
            except:
                prob = 1
        total_prob *= prob
        # print(next_token, prob)
    result['total_prob']= total_prob
    # print('\n', total_prob)

save_to_json(results, f"results/{model_name}_results.json")

Canada  
Amman  
Nevada  
Austin  
N [['ixon', 0.9999993295729247], ['ix', 3.466323946336287e-07], ['ash', 6.82560337633487e-08], ['orris', 2.8453348089834e-08], ['olan', 2.8453348089834e-08], ['ixen', 2.510999155743982e-08], ['ixo', 2.215948977336598e-08], ['icol', 1.0467401794744658e-08], ['<|end|>', 1.0467401794744658e-08], ['aylor', 9.237449661970594e-09], ['icolas', 8.152020714470167e-09], ['orton', 5.602796437537268e-09], ['ielsen', 5.602796437537268e-09], ['oble', 4.363462252943702e-09], ['ox', 3.850741922767617e-09], ['issan', 3.850741922767617e-09], ['IX', 3.850741922767617e-09], ['ikon', 2.335593038799337e-09], ['ye', 1.8189616875530459e-09], ['ison', 1.8189616875530459e-09]]


100%|██████████| 341/341 [00:00<00:00, 17948.67it/s]


In [510]:
save_to_json([result['total_prob'] for result in passing_results], f'{model_name}_total_probs.json')

In [511]:
prefix_map['Canada  \nAmman  \nNevada  \nAustin  \nN']

[['ixon', 0.9999993295729247],
 ['ix', 3.466323946336287e-07],
 ['ash', 6.82560337633487e-08],
 ['orris', 2.8453348089834e-08],
 ['olan', 2.8453348089834e-08],
 ['ixen', 2.510999155743982e-08],
 ['ixo', 2.215948977336598e-08],
 ['icol', 1.0467401794744658e-08],
 ['<|end|>', 1.0467401794744658e-08],
 ['aylor', 9.237449661970594e-09],
 ['icolas', 8.152020714470167e-09],
 ['orton', 5.602796437537268e-09],
 ['ielsen', 5.602796437537268e-09],
 ['oble', 4.363462252943702e-09],
 ['ox', 3.850741922767617e-09],
 ['issan', 3.850741922767617e-09],
 ['IX', 3.850741922767617e-09],
 ['ikon', 2.335593038799337e-09],
 ['ye', 1.8189616875530459e-09],
 ['ison', 1.8189616875530459e-09]]