In [31]:
import os
import pandas as pd
import matplotlib.pyplot as plt
import numpy
import json
import requests
from pprint import pprint
from collections import Counter

In [7]:
data_dir = "/Users/ak/tte/data/sample-data/gsm"

basline_dir = data_dir + "/baseline_pass_20_ckpt_200" 
s3c_dir = data_dir + "/s3c_trace_pass_20_ckpt_1000" 


basline_json_file_path = basline_dir + "/predict_predictions.json"
basline_stat_log_path = basline_dir + "/predict_results.json"

s3c_json_file_path = s3c_dir + "/predict_predictions.json"
s3c_stat_log_path = s3c_dir + "/predict_results.json"

In [9]:
basline_stat_log_data = pd.read_json(basline_stat_log_path, typ='series')
basline_stat_log_dict = basline_stat_log_data.to_dict()

s3c_stat_log_data = pd.read_json(s3c_stat_log_path, typ='series')
s3c_stat_log_dict = s3c_stat_log_data.to_dict()

### Basline gsm pass@20 vs s3c round 0 pass@20

In [13]:
print(f" Basline: {basline_stat_log_dict['test_raw_pass@20']}% | s3c: {s3c_stat_log_dict['test_round_0_estimated_pass@20']}%")

 Basline: 43.5178% | s3c: 36.8461%


In [22]:
basline_json = pd.read_json(basline_json_file_path)
baseline_dict = basline_json.to_dict()

In [23]:
s3c_json = pd.read_json(s3c_json_file_path)
s3c_dict = s3c_json.to_dict()

In [18]:
def get_clusters(generated_programs_dict):

    clusters_dict = {}
    clusters_list = []
    
    for prompt in generated_programs_dict:
        i = 0

        programs_list = []
        while i < len(generated_programs_dict[prompt]):
            programs_list.append( generated_programs_dict[prompt][i]['generated_program'])
            i+=1

        counts = Counter(programs_list)
        
        for cluster in counts:
            indices = [i for i in range(len(generated_programs_dict[prompt])) if programs_list[i] == cluster]
            clusters_list.append(indices)
        clusters_dict.update({prompt:clusters_list})
        clusters_list = []

    return clusters_dict

In [19]:
basline_clusters_dict = get_clusters(baseline_dict)
s3c_clusters_dict = get_clusters(s3c_dict)

In [66]:
def filter_failed_prompts(baseline_dict):
    count = 0 
    
    failed_list = []
        
    for prompt in baseline_dict:
        fail_state = True
        for output in baseline_dict[prompt]:
            if(baseline_dict[prompt][output]['compiler_output'][0] == True):
                fail_state = False
                break
        if fail_state:
            failed_list.append(prompt)
        

    return(failed_list)

In [68]:
def filter_passing_prompts(s3c_dict):
    count = 0 
    
    passed_list = []
        
    for prompt in s3c_dict:
        success_state = False
        for output in s3c_dict[prompt]:
            if(s3c_dict[prompt][output]['compiler_output'][0] == True):
                success_state = True
                break
        if success_state:
            passed_list.append(prompt)
        

    return(passed_list)

In [20]:
def filter_clusters(generated_programs_dict, clusters_dict, correct_only = None, cluster_size_range = (1,5)):
    count = 0 
    
    filtered_dict = {}
    
    for cluster_id in clusters_dict:
        # if count == 23: break #Debug early stopping
        # count+=1

        filtered_clusters_list = []
        
        for cluster in clusters_dict[cluster_id]:
            if (cluster_size_range[0] <= (len(clusters_dict[cluster_id])) <= cluster_size_range[1]):
                if correct_only == True:
                    if (generated_programs_dict[cluster_id][cluster[0]]['compiler_output'][0] == True):
                        filtered_clusters_list.append(cluster)
                elif correct_only == False:
                    if (generated_programs_dict[cluster_id][cluster[0]]['compiler_output'][0] == False):
                        #print(cluster_id, cluster[0])
                        filtered_clusters_list.append(cluster)
                else:
                    filtered_clusters_list.append(cluster)
                    
        if(len(filtered_clusters_list) > 0): filtered_dict.update({cluster_id:filtered_clusters_list})

    return(filtered_dict)

In [47]:
incorrect_baseline_clusters = filter_clusters(baseline_dict, basline_clusters_dict, False, (1, 20))

In [122]:
def intersection(lst1, lst2):
    return list(set(lst1) & set(lst2))

In [125]:
#pprint(filter_failed_prompts(baseline_dict))
#pprint(filter_passing_prompts(s3c_dict))

tf_intersection = intersection(filter_failed_prompts(baseline_dict), filter_passing_prompts(s3c_dict))

print( len(tf_intersection) )
print( tf_intersection )

156
['test_743', 'test_1133', 'test_710', 'test_636', 'test_978', 'test_129', 'test_1050', 'test_583', 'test_348', 'test_637', 'test_729', 'test_668', 'test_48', 'test_53', 'test_1315', 'test_1013', 'test_411', 'test_673', 'test_349', 'test_726', 'test_634', 'test_1026', 'test_25', 'test_920', 'test_1141', 'test_859', 'test_255', 'test_311', 'test_294', 'test_18', 'test_508', 'test_146', 'test_994', 'test_272', 'test_27', 'test_304', 'test_745', 'test_891', 'test_682', 'test_913', 'test_867', 'test_1301', 'test_289', 'test_704', 'test_707', 'test_969', 'test_597', 'test_893', 'test_1121', 'test_585', 'test_674', 'test_562', 'test_829', 'test_523', 'test_879', 'test_760', 'test_344', 'test_615', 'test_438', 'test_318', 'test_771', 'test_762', 'test_442', 'test_65', 'test_7', 'test_916', 'test_960', 'test_1305', 'test_1272', 'test_158', 'test_1237', 'test_852', 'test_1014', 'test_95', 'test_356', 'test_1', 'test_1114', 'test_1036', 'test_1041', 'test_1066', 'test_194', 'test_922', 'test_

In [129]:
reversed_tf_intersection = intersection(filter_failed_prompts(s3c_dict), filter_passing_prompts(baseline_dict))
print(len(reversed_tf_intersection))
print(reversed_tf_intersection)

186
['test_495', 'test_837', 'test_456', 'test_406', 'test_724', 'test_441', 'test_291', 'test_838', 'test_556', 'test_959', 'test_1284', 'test_807', 'test_170', 'test_397', 'test_863', 'test_1058', 'test_309', 'test_1170', 'test_653', 'test_499', 'test_288', 'test_460', 'test_416', 'test_350', 'test_452', 'test_803', 'test_506', 'test_16', 'test_839', 'test_750', 'test_1042', 'test_107', 'test_471', 'test_665', 'test_1208', 'test_489', 'test_1152', 'test_530', 'test_1081', 'test_930', 'test_997', 'test_697', 'test_1030', 'test_679', 'test_742', 'test_1246', 'test_1264', 'test_631', 'test_910', 'test_1167', 'test_3', 'test_247', 'test_684', 'test_628', 'test_136', 'test_1009', 'test_809', 'test_1182', 'test_1200', 'test_270', 'test_983', 'test_50', 'test_279', 'test_1187', 'test_885', 'test_892', 'test_399', 'test_0', 'test_788', 'test_1134', 'test_1285', 'test_511', 'test_450', 'test_156', 'test_187', 'test_470', 'test_89', 'test_206', 'test_519', 'test_394', 'test_446', 'test_577', '

In [139]:
len(filter_failed_prompts(baseline_dict)), len(filter_failed_prompts(s3c_dict))

(745, 775)

In [143]:
len(filter_passing_prompts(baseline_dict)), len(filter_passing_prompts(s3c_dict))

(574, 544)

In [142]:
len( intersection(filter_failed_prompts(s3c_dict), filter_failed_prompts(baseline_dict)) )

589

In [104]:
print(30/1300)

0.023076923076923078


In [106]:
len(baseline_dict), len(s3c_dict)

(1319, 1319)

In [75]:
basline_clusters_dict['test_743']

[[0, 7, 13, 17, 18],
 [1, 2, 4, 5, 6, 8, 11, 14, 15],
 [3],
 [9, 10],
 [12],
 [16],
 [19]]

In [99]:
def get_longest(lst):
    return max(lst, key=len)

def find_overlap( tf_intersection, basline_clusters_dict, s3c_clusters_dict):
    for prompt in tf_intersection:
        #print(s3c_clusters_dict[prompt])
        basline_cluster = (get_longest(basline_clusters_dict[prompt])) #longest wrong baseline cluster
        s3c_cluster = (get_longest(s3c_clusters_dict[prompt])) #longest correct s3c cluster
        
        print( f"Prompt: {prompt} \n")
        print("Incorrect basline generation: ")
        print( baseline_dict[prompt][basline_cluster[0]]['generated_program'] )
        
        gt_ans = baseline_dict[prompt][basline_cluster[0]]['expected_answer']
        gen_ans = baseline_dict[prompt][basline_cluster[0]]['compiler_output'][2]
        print(f"Ground Truth Answer: {gt_ans} | Generated Answer: {gen_ans}") 
        
        print("----------------------------------------------------------------")
        
        print("S3C Correct Generation: ")
        
find_overlap(tf_intersection, basline_clusters_dict, s3c_clusters_dict)

Prompt: test_743 

Incorrect basline generation: 
n0=3
n1=2
n2=5
t0=n0-n1
answer=t0-n2

Ground Truth Answer: 8.0 | Generated Answer: -4
----------------------------------------------------------------
Prompt: test_1133 

Incorrect basline generation: 
n0=2
n1=1
n2=3
n3=40
t0=n0*n1
t1=n2*t0
t2=t1+n3
answer=n2+t2

Ground Truth Answer: 50.0 | Generated Answer: 49
----------------------------------------------------------------
Prompt: test_710 

Incorrect basline generation: 
n0=5
n1=3.0
n2=4
n3=2.5
n4=13
n5=0.1
t0=n0*n1
t1=n2*n3
t2=n4*n5
answer=t0+t1+t2

Ground Truth Answer: 45.0 | Generated Answer: 26.3
----------------------------------------------------------------
Prompt: test_636 

Incorrect basline generation: 
n0=50
n1=2
t0=n0*n1
answer=t0*n1

Ground Truth Answer: 1050.0 | Generated Answer: 200
----------------------------------------------------------------
Prompt: test_978 

Incorrect basline generation: 
n0=100
n1=40
n2=0.01
t0=n0*n1*n2
answer=n0+t0

Ground Truth Answer: 160.0 