This notebook illustrates how we perform constraint reasoning on top of GPT-3's outputs, given a device and its list of parts.


In [1]:
import ast
import csv
import json
import os
import pickle

from csp import *

In [2]:
# Maps each mental model (MM) -- an everyday thing as sketched by one turker,
# keyed by the first two columns of a row -- to:
#   "triplets":   list of (triplet_tuple, True_False_label) pairs
#   "parts-list": unique parts named by those triplets, in first-seen order
et2triplets_ann = {}
# newline="" is what the csv module asks for so that the reader itself decides
# how line endings (including any inside quoted fields) are interpreted.
with open("enriched_mms/full-ET-dataset.tsv", "r", newline="") as dataset:
    rows = csv.reader(dataset, delimiter="\t")
    next(rows, None)  # skip header
    for row in rows:
        et_turker = (row[0], row[1])
        mm_entry = et2triplets_ann.setdefault(et_turker, {"triplets": [], "parts-list": []})

        # Column 2 holds the triplet as a Python literal; column 3 its label
        # (kept as the raw string read from the file).
        triplet = ast.literal_eval(row[2])
        annotated_relation = (triplet, row[3])
        # The same annotated triplet must not appear twice within one MM.
        assert annotated_relation not in mm_entry["triplets"]
        mm_entry["triplets"].append(annotated_relation)

        # The first and last elements of the triplet are the two parts.
        for part in (triplet[0], triplet[2]):
            if part not in mm_entry["parts-list"]:
                mm_entry["parts-list"].append(part)

## Process predictions from GPT3 (text-davinci-003)

In [3]:
# Stored GPT-3 judgments, keyed by the part of each record's id that precedes
# the last "-". Each value holds the answer label and both probabilities.
et_triplet_2_probTF = {}
gpt3_output_dir = "gpt3_query_api_output/"
for pred_filename in os.listdir(gpt3_output_dir):
    # Hidden files and anything with "test" in its name are not prediction dumps.
    if pred_filename.startswith(".") or "test" in pred_filename:
        continue
    with open(gpt3_output_dir + pred_filename, "r") as pred_handle:
        print(pred_filename)
        # One JSON record per line.
        for raw_record in pred_handle.readlines():
            record = json.loads(raw_record)

            # Only the prefix before the last "-" is used; the suffix is discarded.
            triplet_key, _label_suffix = record['id'].rsplit("-", 1)

            # Every key must be seen at most once across all files.
            assert triplet_key not in et_triplet_2_probTF

            gpt3_answer = record["gpt3_answer"]
            entry = {
                "answer": str(gpt3_answer["answer"]).capitalize(),
                "prob_True": gpt3_answer["prob_True"],
                "prob_False": gpt3_answer["prob_False"],
            }
            et_triplet_2_probTF[triplet_key] = entry

            # After capitalisation the label has to be exactly "True" or "False".
            assert entry["answer"] in ("True", "False")


gpt3_zero-shot_pred_on_full-ET-dataset_0-idx0-976.jsonl
gpt3_zero-shot_pred_on_full-ET-dataset_5-idx60976-75976.jsonl
gpt3_zero-shot_pred_on_full-ET-dataset_4-idx45976-60976.jsonl
gpt3_zero-shot_pred_on_full-ET-dataset_7-idx90976-105976.jsonl
gpt3_zero-shot_pred_on_full-ET-dataset_2-idx15976-30976.jsonl
gpt3_zero-shot_pred_on_full-ET-dataset_8-idx105976-end.jsonl
gpt3_zero-shot_pred_on_full-ET-dataset_3-idx30976-45976.jsonl
gpt3_zero-shot_pred_on_full-ET-dataset_1-idx976-15976.jsonl
gpt3_zero-shot_pred_on_full-ET-dataset_6-idx75976-90976.jsonl


In [4]:
# Share of stored predictions whose answer label is "True".
total_cnt = len(et_triplet_2_probTF)
true_cnt = sum(1 for pred in et_triplet_2_probTF.values() if pred['answer'] == 'True')
true_pct = round((true_cnt / total_cnt) * 100, 2)
print("% True tuples: {}/{} ({})".format(true_cnt, total_cnt, true_pct))

% True tuples: 13722/108528 (12.64)


The code below is a slight modification of the code used for Macaw in Notebook 2a, so you may see some parts of it being reused here.

In [5]:
def query_gpt3_statements_batch_mode(device, perm):
    '''
    Look up the stored GPT-3 judgment for every (part, relation, part) triplet of a device.

    Input:
        device - everyday thing the parts belong to
        perm   - list of tuples; entry[0] and entry[1] of each tuple are the two parts
    Output:
        triplet_ans_conf_lst - list of [triplet, ans, p_statement]
        neg_ans_conf_lst     - list of p_neg_statement, index-aligned with triplet_ans_conf_lst

    Relations are taken from the global all_relations_lst and judgments from the
    global et_triplet_2_probTF, keyed by str((device, triplet)).

    A triplet with no stored judgment is reported on stdout and left out of both
    lists, so the two lists always stay aligned and never carry values that
    belong to a different triplet.
    '''

    triplet_ans_conf_lst = [] # list of list
    neg_ans_conf_lst = [] # list
    for entry in perm:
        for rln in all_relations_lst:

            triplet = (entry[0], rln, entry[1])
            et_triplet_str = str((device, triplet))

            stored_data = et_triplet_2_probTF.get(et_triplet_str)
            if stored_data is None:
                # No offline prediction exists. The online query path is disabled:
                #     statement = triplet2statement(triplet)
                #     ans, p_statement, p_neg_statement = get_p_statement_and_p_neg_statement(device, statement)
                # Skip instead of appending undefined or stale values.
                print("Need to query gpt3 online for", et_triplet_str)
                continue

            triplet_ans_conf_lst.append([triplet, stored_data["answer"], stored_data["prob_True"]])
            neg_ans_conf_lst.append(stored_data["prob_False"])

    return triplet_ans_conf_lst, neg_ans_conf_lst

# query gpt on "everyday thing"
def run_query_gpt_everyday_thing_batch_mode(device, parts):
    '''
    Collect the stored GPT-3 judgments for a device and its parts.

    Input: everyday thing, list of its parts
    Output: (all judgments, the subset returned by get_statements_that_macaw_believesT,
             list of negated-statement confidences)
    '''
    # pairs of parts to ask about
    part_pairs = get_parts_perm(device, parts)
    # stored judgment for every pair and relation
    all_judgments, negated_confidences = query_gpt3_statements_batch_mode(device, part_pairs)
    # the Macaw helper is reused as-is on the GPT-3 judgments
    believed_true = get_statements_that_macaw_believesT(all_judgments)
    return all_judgments, believed_true, negated_confidences

In [6]:
# triplet_ans_conf_lst, triplet_ans_conf_lst_true, neg_ans_conf_lst =\
#     run_query_gpt_everyday_thing_batch_mode("egg", ['yolk', 'egg white', 'shell membrane', 'shell', 'air cell'])


In [7]:
# triplet_ans_conf_lst, triplet_ans_conf_lst_true, neg_ans_conf_lst

# Part 2: Constraint satisfaction

In [8]:
def imagine_a_device_with_csp(device, turker, outputs_dir, filter_threshold, parts=None):
    '''
    Run the constraint-reasoning pipeline for one device and turker, caching results on disk.

    Input:
        device           - everyday thing; lower-cased before use
        turker           - string identifying the turker; used in file names
        outputs_dir      - root output directory; must end with "/" because all
                           sub-paths are built by string concatenation
        filter_threshold - threshold forwarded to filter_props; also part of the
                           results file name
        parts            - list of the device's parts (an empty list when omitted)
    Output:
        all_result_dict with keys "gpt3_predictions", "gpt3_predictions_believe_true",
        "model_believe_true_props", "maxsat_selected_props", "filter_threshold",
        "model_believe_true_props_filtered", "maxsat_selected_props_filtered"

    Two pickle caches are used:
      * Props/<device>_<turker>_threshold<N>.pkl holds the full result; if it exists
        it is returned immediately and nothing is recomputed or re-plotted.
      * LMResponses/<device>-<turker>.pkl holds the LM judgments. Its name does not
        depend on `parts`, so a cached file is reused even if `parts` differs.

    NOTE: both caches are read with pickle.load, which can execute arbitrary code.
    Only point outputs_dir at directories whose contents you trust.
    '''
    # A fresh list per call; a literal [] default would be shared between calls.
    if parts is None:
        parts = []

    device = device.lower()
    tag = "threshold" + str(filter_threshold)
    device_slug = device.replace(" ", "-")

    lm_query_dir = outputs_dir + "LMResponses/" # dir where you want to save output
    wcnf_dir = outputs_dir + "WCNF_format/" # dir where you want to save these wcnf for reference
    plots_dir = outputs_dir + "VizPlots/" # dir where you want to store output files
    statements_dir = outputs_dir + "Props/" # dir where you save data from this run
    for desired_dir in [outputs_dir, lm_query_dir, wcnf_dir, plots_dir, statements_dir]:
        make_sure_dir_exists(desired_dir)

    # Each cache path is built once and reused for the existence check, read and write.
    results_path = statements_dir + device_slug + "_" + turker + "_" + tag + ".pkl"
    lm_cache_path = lm_query_dir + device_slug + "-" + turker + ".pkl"

    if os.path.isfile(results_path):
        with open(results_path, 'rb') as f:
            all_result_dict = pickle.load(f)
        print("Read from file ...", len(all_result_dict["gpt3_predictions"]), "triplets ...")
        return all_result_dict

    # lm response - do not want to query LM again if the same device has been asked
    if os.path.isfile(lm_cache_path):
        with open(lm_cache_path, 'rb') as f:
            triplet_ans_conf_lst, triplet_ans_conf_lst_true, neg_ans_conf_lst = pickle.load(f)
    else:
        # query gpt
        triplet_ans_conf_lst, triplet_ans_conf_lst_true, neg_ans_conf_lst = run_query_gpt_everyday_thing_batch_mode(device, parts)
        with open(lm_cache_path, 'wb') as f:
            pickle.dump([triplet_ans_conf_lst, triplet_ans_conf_lst_true, neg_ans_conf_lst], f)

    # use maxsat
    print("Running maxsat ...", len(triplet_ans_conf_lst), "triplets...")
    model_believe_true_props, maxsat_selected_props = run_maxsat(device, turker, wcnf_dir, triplet_ans_conf_lst, neg_ans_conf_lst, triplet_ans_conf_lst_true, use_only_model_true_props = False)

    print("Filtering ...", len(model_believe_true_props), "triplets...", len(maxsat_selected_props), "triplets...")
    # filter based on confidence
    model_believe_true_props_filtered = filter_props(model_believe_true_props, filter_threshold)
    maxsat_selected_props_filtered = filter_props(maxsat_selected_props, filter_threshold)

    # plot
    print("Generating visualization ...", len(model_believe_true_props_filtered), "believed...", len(maxsat_selected_props_filtered), "selected")
    generate_graph_png(device, turker, model_believe_true_props_filtered, plots_dir, "model_believe_true_" + tag)
    generate_graph_png(device, turker, maxsat_selected_props_filtered, plots_dir, "maxsat_selected_" + tag)
    # props that survive filtering in both the model-believed and the maxsat-selected sets
    believed_selected = [k for k, _ in model_believe_true_props_filtered.items() if k in maxsat_selected_props_filtered]
    generate_graph_png(device, turker, believed_selected, plots_dir, "believed_selected_" + tag)

    # save result
    all_result_dict = {
        "gpt3_predictions": triplet_ans_conf_lst,
        "gpt3_predictions_believe_true": triplet_ans_conf_lst_true,
        "model_believe_true_props": model_believe_true_props,
        "maxsat_selected_props": maxsat_selected_props,
        "filter_threshold": filter_threshold,
        "model_believe_true_props_filtered": model_believe_true_props_filtered,
        "maxsat_selected_props_filtered": maxsat_selected_props_filtered,
    }

    with open(results_path, 'wb') as f:
        pickle.dump(all_result_dict, f)
    print()
    return all_result_dict

    

In [9]:
# Root directory under which imagine_a_device_with_csp creates its LMResponses/,
# WCNF_format/, VizPlots/ and Props/ sub-directories. The trailing "/" matters:
# sub-paths are built by plain string concatenation.
outputs_dir = "0_gpt3-text-davinci-003-ImagineADevice-CSP-Viz-full-ET-dataset/"
# Threshold forwarded to filter_props and embedded in the results file name.
# NOTE(review): presumably a confidence on a 0-100 scale -- confirm against filter_props.
filter_threshold = 50 

In [10]:
# MM keys ordered by ascending number of unique parts.
sorted_et2triplets_ann = sorted(
    et2triplets_ann.keys(),
    key=lambda mm_key: len(et2triplets_ann[mm_key]["parts-list"]),
)

In [5]:
# Run the pipeline for every MM, in the order given by sorted_et2triplets_ann.
for mm_number, et_turker in enumerate(sorted_et2triplets_ann, start=1):
    print(et_turker, "MM #", mm_number)

    et, turker = et_turker
    parts_list = et2triplets_ann[et_turker]['parts-list']
    print(len(parts_list))
    all_result_dict = imagine_a_device_with_csp(et, turker, outputs_dir, filter_threshold, parts_list)