In [None]:
import os
from dotenv import load_dotenv
import json
import random
import matplotlib.pyplot as plt
import numpy as np
import asyncio

from benchmark import run_benchmark_sample, Trials, mark_job_failure, create_compilation_failure_trial

In [None]:
# Run this cell to load the job results for all trials that are still waiting on pending jobs.
# (Top-level `await` is valid here because IPython/Jupyter cells support it.)
trials = Trials()
await trials.load_results()

In [None]:
def get_probability_data(trials, function_success_threshold, use_hamming=False,
                         num_vars_range=range(2, 33), complexity_range=range(1, 22),
                         verbose=True):
    """Compute, per (num_vars, complexity) cell, the fraction of functions that succeed.

    For every cell of the grid, the trials are fetched grouped by function via
    ``trials.get_per_statement(num_vars=..., complexity=...)``. A function is
    "successful" when the mean score over its trials is strictly greater than
    ``function_success_threshold``.

    Args:
        trials: Object exposing ``get_per_statement(num_vars=, complexity=)``
            that returns a mapping ``function -> sequence of trials``.
        function_success_threshold: Mean score a function must exceed (strict
            ``>``) to be counted as successful.
        use_hamming: When True, a trial's score is ``1 - trial.mean_hamming_distance``;
            otherwise it is ``trial.exact_match_rate``.
        num_vars_range: Iterable of variable counts to sweep. Defaults to the
            previously hard-coded ``range(2, 33)``.
        complexity_range: Iterable of complexities to sweep. Defaults to the
            previously hard-coded ``range(1, 22)``. It is iterated once per
            ``num_vars`` value, so pass a re-iterable (range, list, tuple).
        verbose: When True (default), print the cell being processed. Warnings
            about skipped cells are always printed.

    Returns:
        Three parallel lists ``(num_vars_data, complexity_data, probability_data)``.
        Cells with no usable data are omitted from all three lists.
    """
    num_vars_data = []
    complexity_data = []
    probability_data = []

    for num_vars in num_vars_range:
        for complexity in complexity_range:
            if verbose:
                print(f"(num_vars, complexity) = ({num_vars}, {complexity})")

            trials_by_function = trials.get_per_statement(num_vars=num_vars, complexity=complexity)

            # A function with an empty trial list has no mean score; ignoring it
            # avoids a ZeroDivisionError and keeps it out of the denominator.
            scored_trial_lists = [
                function_trials
                for function_trials in trials_by_function.values()
                if len(function_trials) > 0
            ]

            # Decide whether the cell is usable before doing any scoring work.
            if len(scored_trial_lists) == 0:
                print(f"Warning: no results for {num_vars} variables, complexity {complexity}; skipping")
                continue

            successful_function_count = 0
            for function_trials in scored_trial_lists:
                score_sum = 0.0
                for trial in function_trials:
                    if use_hamming:
                        score_sum += (1 - trial.mean_hamming_distance)
                    else:
                        score_sum += trial.exact_match_rate

                if score_sum / len(function_trials) > function_success_threshold:
                    successful_function_count += 1

            num_vars_data.append(num_vars)
            complexity_data.append(complexity)
            probability_data.append(successful_function_count / len(scored_trial_lists))

    return num_vars_data, complexity_data, probability_data

In [None]:
# Exact-match view: a function counts as successful when the mean
# exact_match_rate over its trials is greater than 0.6.
trials = Trials()
(
    num_vars_data_exact,
    complexity_data_exact,
    probability_data_exact,
) = get_probability_data(trials, function_success_threshold=0.6, use_hamming=False)

In [None]:
# Hamming view: a function counts as successful when the mean of
# (1 - mean_hamming_distance) over its trials is greater than 0.8.
trials = Trials()
(
    num_vars_data_hamming,
    complexity_data_hamming,
    probability_data_hamming,
) = get_probability_data(trials, function_success_threshold=0.8, use_hamming=True)

In [None]:
def _plot_grid_scatter(num_vars_data, complexity_data, values, title, threshold, filepath, figsize, marker_size):
    """Draw one marker per (num_vars, complexity) point, coloured by ``values``.

    Shared implementation behind ``plot_probability_data`` and
    ``plot_counts_data``, which differ only in figure size and marker size.
    """
    if len(num_vars_data) == 0 or len(complexity_data) == 0:
        # min()/max() below would fail with an opaque message after an empty
        # figure had already been created; fail early and clearly instead.
        raise ValueError("nothing to plot: num_vars_data and complexity_data must be non-empty")

    fig, ax = plt.subplots(figsize=figsize)

    if threshold is not None:
        # Binary colouring: green strictly above the threshold, red otherwise.
        # Explicit colour names bypass colormapping, so no cmap and no colorbar.
        colors = ['green' if value > threshold else 'red' for value in values]
        ax.scatter(num_vars_data, complexity_data, c=colors, edgecolors='black', alpha=0.75, s=marker_size)
    else:
        # Continuous colouring through the red-yellow-green colormap, with a colorbar legend.
        points = ax.scatter(num_vars_data, complexity_data, c=values, cmap='RdYlGn',
                            edgecolors='black', alpha=0.75, s=marker_size)
        fig.colorbar(points, ax=ax)

    # One tick per integer value so every grid row and column is labelled.
    ax.set_xticks(np.arange(min(num_vars_data), max(num_vars_data) + 1, 1))
    ax.set_yticks(np.arange(min(complexity_data), max(complexity_data) + 1, 1))

    ax.set_xlabel('Variables Count')
    ax.set_ylabel('Complexity')
    ax.set_title(title)

    if filepath is not None:
        fig.savefig(filepath, dpi=300, bbox_inches='tight')
    plt.show()


def plot_probability_data(num_vars_data, complexity_data, probability_data, title, threshold=None, filepath=None):
    """Scatter-plot a per-cell value over the (variables count, complexity) grid.

    Args:
        num_vars_data: x coordinates (variable counts), one per point.
        complexity_data: y coordinates (complexities), one per point.
        probability_data: Value used to colour each point.
        title: Figure title.
        threshold: When None (default), points are coloured continuously with
            the 'RdYlGn' colormap and a colorbar is shown. When given, points
            are green if their value is strictly greater than ``threshold`` and
            red otherwise. Previously this argument was accepted but ignored.
        filepath: When given, the figure is also saved there (300 dpi, tight bbox).

    Raises:
        ValueError: If ``num_vars_data`` or ``complexity_data`` is empty.
    """
    _plot_grid_scatter(num_vars_data, complexity_data, probability_data, title,
                       threshold, filepath, figsize=(20, 10), marker_size=450)


def plot_counts_data(num_vars_data, complexity_data, probability_data, title, threshold=None, filepath=None):
    """Scatter-plot a per-cell value on a square (10x10) figure with smaller markers.

    Same contract as ``plot_probability_data``. The third parameter keeps the
    name ``probability_data`` for backward compatibility, although this notebook
    passes per-cell function counts to it.

    Args:
        num_vars_data: x coordinates (variable counts), one per point.
        complexity_data: y coordinates (complexities), one per point.
        probability_data: Value used to colour each point.
        title: Figure title.
        threshold: None for continuous colouring with a colorbar; otherwise
            points strictly above it are green and the rest red.
        filepath: When given, the figure is also saved there (300 dpi, tight bbox).

    Raises:
        ValueError: If ``num_vars_data`` or ``complexity_data`` is empty.
    """
    _plot_grid_scatter(num_vars_data, complexity_data, probability_data, title,
                       threshold, filepath, figsize=(10, 10), marker_size=250)

In [None]:
# Render the exact-match grid and save a copy next to the notebook.
plot_probability_data(
    num_vars_data_exact,
    complexity_data_exact,
    probability_data_exact,
    title='Proportion of functions with exact match rate > 0.6',
    filepath='exact_match_rate.png',
)

In [None]:
# Collect, for each (num_vars, complexity) cell, the number of distinct statements
# (functions) that have at least one trial row whose `counts` column is not empty.
trials = Trials()
num_vars_data = []
complexity_data = []
function_count_data = []

# Values are bound through `?` placeholders rather than formatted into the SQL text.
count_query = "SELECT COUNT(DISTINCT statement) FROM trials WHERE num_vars = ? AND complexity = ? AND NOT counts = ''"

# Enter the connection context once and reuse it for the whole sweep instead of
# once per grid cell (651 times): the connection does not depend on the loop variables.
# NOTE(review): `_connect` is a private method of Trials; consider a public accessor.
with trials._connect() as conn:
    cursor = conn.cursor()
    for num_vars in range(2, 33):
        for complexity in range(1, 22):
            cursor.execute(count_query, (num_vars, complexity))
            count = cursor.fetchone()[0]
            num_vars_data.append(num_vars)
            complexity_data.append(complexity)
            function_count_data.append(count)
            print(f"(num_vars, complexity) = ({num_vars}, {complexity})")

In [None]:
# Cap every count at 30, in place (presumably so a few large cells do not
# dominate the colour scale of the plot; confirm intent).
for position, count in enumerate(function_count_data):
    if count > 30:
        function_count_data[position] = 30

In [None]:
# Render the per-cell function counts and save a copy next to the notebook.
plot_counts_data(
    num_vars_data,
    complexity_data,
    function_count_data,
    title='Number of functions with data',
    filepath='function_count.png',
)