In [1]:
import os
import pandas as pd

cwd = os.getcwd()  # get directory for storage

# This file automates the entire pipeline for assertion generation with chatgpt

## Step 1) Get Asserted Code From Github

### Step 1.1) Clean and process the code
### Step 1.2) Extract Ground-Truth Assertions & Relevant Statistics

In [23]:
from google.cloud import bigquery as bq

def get_asserted_code(num=100000, ext="%.py", verbose=True):
    query_string = """SELECT f.repo_name, c.content
FROM `bigquery-public-data.github_repos.files` AS f
JOIN `bigquery-public-data.github_repos.contents` AS c
ON f.id = c.id
WHERE
NOT c.binary
AND f.path LIKE '%.py'
AND REGEXP_CONTAINS(c.content, r'(?m)^\s*assert ')
LIMIT """ + str(num)
    
    if isinstance(num, int):
        secret_dir = "Data/secret/"
        api_key = cwd + "/" + secret_dir + os.listdir(secret_dir)[0]
        assert api_key[-5:] == ".json"  # confirm that it was found
        os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = api_key
        query_string = query_string.replace("%.py", ext)

        if verbose:
            print("*Running Query:")
            print(query_string)
            print()
        client = bq.Client()
        df = (
            client.query(query_string)
            .result()
            .to_dataframe(
                create_bqstorage_client=True,
            )
        )
    elif isinstance(num, str):
        # load data from file
        df = pd.read_csv(num)
        print("Found data at", num)
    else:
        print("first param type undefined, must be string signifying directory of csv or\
               int signifying number of records to scrib from bigquery...")
        assert False
    
    if verbose:
        print("*Handling Duplicates...")
    init_len = len(df)
    df.drop_duplicates(subset=["content"], keep="first", inplace=True)
    if verbose:
        print("#Non-duplicates / #Total Retrieved =", (len(df)/init_len))
    return df

# small test
verilog_dir = cwd+"/Data/BigQuery/VerilogAssertions-ALL.csv"
python_dir = cwd+"/Data/BigQuery/PythonAssertions100k.csv"
df = get_asserted_code(python_dir)  # 10
df

Found data at /Users/korahughes/Documents/GitHub/LLMCodeGen/Data/BigQuery/PythonAssertions100k.csv
*Handling Duplicates...
Duplicate Ratio =  1.0


Unnamed: 0,repo_name,content
0,tqchen/tvm,# Licensed to the Apache Software Foundation (...
1,Lujeni/ansible,# (c) 2017 Red Hat Inc.\n#\n# This file is par...
2,lukas-hetzenecker/home-assistant,"""""""The tests for the Pilight sensor platform.""..."
3,schnoebe/fedora-mock,import fcntl\nimport glob\nimport grp\nimport ...
4,samstav/fastfood,# -*- coding: utf-8 -*-\n# Copyright 2015 Rack...
...,...,...
33788,raphaelm/django-i18nfield,from i18nfield.admin import I18nModelAdmin\nfr...
33789,fniephaus/alfred-rworkflow,# The MIT License (MIT)\n#\n# Copyright (c) 20...
33790,bgris/ODL_bgris,# -*- coding: utf-8 -*-\r\n#\r\n# Copyright © ...
33791,chrsrds/scikit-learn,"""""""\nTesting for the base module (sklearn.ense..."


In [42]:
tester_df = df.sample(200)
tester_df

Unnamed: 0,repo_name,content
13113,taroplus/spark,# -*- encoding: utf-8 -*-\n#\n# Licensed to th...
8658,spadejac/web-python,"'''\nCreated on Nov 10, 2013\n\n@author: Manoj..."
30406,tchellomello/home-assistant,"""""""Test the devolo_home_control config flow.""""..."
16410,sauloal/pycluster,from __future__ import absolute_import\nfrom l...
32522,Answeror/torabot,from functools import partial\nfrom datetime i...
...,...,...
13343,google-research/google-research,# coding=utf-8\n# Copyright 2022 The Google Re...
22028,WiseDoge/plume,from plume.knn import KNeighborClassifier\nfro...
7176,cvxopt/chompack,"from cvxopt import matrix, lapack, spmatrix\nf..."
10586,radarsat1/siconos,"""""""Tests for python interface of friction cont..."


In [43]:
from tqdm import tqdm

conditionals = dict([[cond, i] for i, cond in enumerate(["==", "!=", "<=", ">=", "<", ">"])])
compounding_statements = ["and"]
bad_statements = [" or ", " in ", "isinstance"]  # TODO: properly account for OR
def get_assertions(func, is_split=True, verbose=True):
    """
    Format: "assert [expression], [return_string]"
    
    Exceptions to Handle:
    - 'in'/'not in' keyword
    - boolean functions - ex. isinstance(var, type)
    - separation of attributes - ex. len(var), var[i]
    """
#     if verbose:
#         print("*Extracting Assertions...")
    out = []
    asserted_lines = 0
    lines = []
    for temp in func.split('\n'):  # find lines with assert in them
        if "assert" in temp:
            asserted_lines += 1
            bad_flag = False
            for bad in bad_statements:
                if bad in temp:
                    bad_flag = True
            if not bad_flag:
                lines.append(temp.strip())
    # TODO: experiment with smaller content window for assertions
    ind = 0
    while ind < len(lines):
        data = lines[ind].strip()
        start = data.find('assert')
        if start == -1:  # double checking that the assertion exists in this line
            ind += 1
            continue
        # account for combination statements
        for statement in compounding_statements:
            add_statement = data.find(statement)
            if add_statement != -1:
                extra_line = data[add_statement+len(statement):]
                lines.insert(ind+1, "assert "+extra_line)
                data = data[:add_statement].strip()

        com = data.find(',')   # parsing out return_string
        if com != -1:
            data = data[:com]
        com = data.find('#')
        if com != -1:   # parsing out comments
            if com < start:  # if the assertion itself is a comment
                ind += 1
                continue
            else:
                data = data[:com]

        if is_split:  # splitting the assertion into components for analysis
            data = [var.strip() for var in data.split(' ') if len(var.strip()) > 0]
            
            if len(data) < 1:  # edge case: nothing after 'assert' (likely typo)
                if verbose:
                    print("empty assertion found?: ", data, '\n', lines[ind])
                ind += 1
                continue
                
            if data[0] != "assert":  # edge case: something before the 'assert' statement
                ind += 1
#                 if verbose:
#                     print("something was found before the assertion on this line:\n", data)
                continue
    
            data = data[1:]  # from here on we only care about the content after the 'assert' keyword
            if len(data) < 1:  # edge case: nothing after 'assert' (likely typo)
                if verbose:
                    print("empty assertion found?: ", data, '\n', lines[ind])
                ind += 1
                continue

            condition = True  # assertion [variable] == condition by default
            if data[0] == "not":  # accounting for 'not' keyword
                condition = False
                data = data[1:]
            
            if len(data) == 1:  # adding == to simlify
                data = data + ["==", str(condition)]

            for i in range(len(data)):
                if data[i] == "is":  # simplifying is to ==
                    data[i] = "=="
                if data[i] in conditionals.keys():  # parsing common conditionals
                    data = [' '.join(data[:i]), data[i], ' '.join(data[i+1:])]  # conditionals[data[i]]
                    break

        if verbose and len(data) != 3:
            print("Weird assertion found:\n", data, '\n', lines[ind])
            print()
#             assert len(data) == 3, "found conditional-less assertion:\n" + str(data) + '\n' + str(lines[ind-1:ind+2])
        else:
            out.append(data)
        ind += 1
    return out, asserted_lines

unassert = lambda code: [line for line in code.split('\n') if "assert" not in line]

# small test
# tester_df["assertions"] = tester_df["content"].apply(lambda code: get_assertions(code))
assertions = []  # list of parsed assertions
asserted_lines = []  # number of lines with 'assert' in them
parsed_lines = []  # number of assertions easily parsed
arr = []  # assertion recovery ratio
atl = []  # assertions to size
for i, row in tqdm(tester_df.iterrows()):
    parsed, lines = get_assertions(row["content"])
    assertions.append(parsed)
    asserted_lines.append(lines)
    parsed_lines.append(len(parsed))
    arr.append(len(parsed)/lines)
    atl.append(len())
tester_df["unasserted"] = tester_df["content"].apply(unassert)

tester_df["assertions"] = assertions
tester_df["asserted_lines"] = asserted_lines
tester_df["parsed_lines"] = parsed_lines
tester_df["arr"] = arr
tester_df.describe()

200it [00:00, 2040.44it/s]

Weird assertion found:
 ['assert', 'that', 'it', 'fails'] 
 assert  assert that it fails

Weird assertion found:
 ['"ERROR', 'could', 'not', 'open', 'HDF', 'file:\\n', '-->', 'raw:', '"+self.raw.filename+"\\n', '-->', 'hdf:', '"+self.HDFObj.filename+"\\n"'] 
 assert "ERROR could not open HDF file:\n --> raw: "+self.raw.filename+"\n --> hdf: "+self.HDFObj.filename+"\n"

Weird assertion found:
 ['await', 'async_setup_component(hass'] 
 assert await async_setup_component(hass, const.DOMAIN, hass_config)

Weird assertion found:
 ['db._is_valid_field_name("foo', 'bar")'] 
 assert db._is_valid_field_name("foo bar")

Weird assertion found:
 ['db._is_valid_table_name("foo', 'bar")'] 
 assert db._is_valid_table_name("foo bar")

Weird assertion found:
 ['remove', "assert.'"] 
 assert  remove assert.'

Weird assertion found:
 ['loc', 'should', 'be', 'length', '6', '(', 'scan'] 
 assert loc should be length 6 ( scan, segment, candint, dmind, dtind, beamnum ).'

Weird assertion found:
 ['re.search(




Unnamed: 0,asserted_lines,parsed_lines,arr
count,200.0,200.0,200.0
mean,24.08,12.37,0.704143
std,57.765281,44.210018,0.425701
min,1.0,0.0,0.0
25%,2.0,1.0,0.414189
50%,5.0,3.0,0.877885
75%,19.0,10.0,1.0
max,578.0,578.0,2.5


In [35]:
print("Dropping Data with No Parsed Lines:")
no_parsed = len(tester_df[tester_df["parsed_lines"]==0])
print("#No-Parsed / Total =", (no_parsed/len(tester_df)))
tester_df = tester_df[~tester_df["parsed_lines"]=="0"]

Dropping Data with No Parsed Lines:
#No-Parsed/Total = 0.14


In [44]:
import plotly.express as px
tester_df

Unnamed: 0,repo_name,content,unasserted,assertions,asserted_lines,parsed_lines,arr
13113,taroplus/spark,# -*- encoding: utf-8 -*-\n#\n# Licensed to th...,"[# -*- encoding: utf-8 -*-, #, # Licensed to t...","[[datatype, ==, pickled], [datatype, ==, pytho...",157,2,0.012739
8658,spadejac/web-python,"'''\nCreated on Nov 10, 2013\n\n@author: Manoj...","[''', Created on Nov 10, 2013, , @author: Mano...","[[start, <=, end]]",3,1,0.333333
30406,tchellomello/home-assistant,"""""""Test the devolo_home_control config flow.""""...","[""""""Test the devolo_home_control config flow.""...","[[result[""type""], ==, ""form""], [result[""errors...",19,19,1.000000
16410,sauloal/pycluster,from __future__ import absolute_import\nfrom l...,"[from __future__ import absolute_import, from ...","[[list(d), ==, [2], [repr(d), ==, ""deque([2], ...",8,8,1.000000
32522,Answeror/torabot,from functools import partial\nfrom datetime i...,"[from functools import partial, from datetime ...","[[regular, ==, True]]",2,1,0.500000
...,...,...,...,...,...,...,...
13343,google-research/google-research,# coding=utf-8\n# Copyright 2022 The Google Re...,"[# coding=utf-8, # Copyright 2022 The Google R...","[[transformer, ==, not None]]",2,1,0.500000
22028,WiseDoge/plume,from plume.knn import KNeighborClassifier\nfro...,"[from plume.knn import KNeighborClassifier, fr...","[[np.all(y_pred, ==, self._train_y[:-1])]]",1,1,1.000000
7176,cvxopt/chompack,"from cvxopt import matrix, lapack, spmatrix\nf...","[from cvxopt import matrix, lapack, spmatrix, ...",[],1,0,0.000000
10586,radarsat1/siconos,"""""""Tests for python interface of friction cont...","[""""""Tests for python interface of friction con...","[[options.dparam[1], <, options.dparam[0]], [r...",2,2,1.000000


In [45]:
fig = px.box(tester_df, y="arr") # showing distribution of assertion recovery rate
fig.show()

## Step 2) Generate LLM Prompt & Query a GPT

In [None]:
def generate_prompt(asserted_code, verbose=True):
    ...
    
    
banned_vars = ['', '*', 'self']
def get_variables(func, verbose=False):
    out = []
    for line in func.split('\n'):
        line = line.strip()
        if "def " in line:  # add params if its a function
            start = line.find('(')
            end = line.find(')')
            for new_param in line[start+1:end].split(','):
                default = new_param.find("=")
                if default != -1:
                    new_param = new_param[:default]
                new_param = new_param.strip()
                if new_param not in out and new_param not in banned_vars:
                    if verbose:
                        print("*Found  {", new_param, "}  at:\n", line, '\n')
                    out.append(new_param)
        else: # add variables if equals operation
            find_var = line.find(' = ')
            if find_var != -1:
                new_var = line[:find_var].strip()
                
                if ',' in new_var: # handle tuple equalities edge case (ex: a, b, c = fn_output())
                    var_list = [tuple_var.strip() for tuple_var in new_var.split(',')]
                else:
                    var_list = [new_var]
                for new_var in var_list:
                    if new_var not in out and new_var not in banned_vars:
                        if verbose:
                            print("**Found  {", new_var, "}  at:\n", line, '\n')
                        out.append(new_var)
            # TODO: handle indexing
    return out

# out = get_variables(df.sample()["content"].iloc[0])
get_vars = lambda code: get_variables(code)
df["variables"] = df["content"].apply(get_vars)
df

In [None]:
# querying
import openai
import altair as alt
import json
from vega_datasets import data

def run_gpt4(messages):
    OPENAI_API_KEY = "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" # os.environ['OPENAI_API_KEY']
    openai.api_key = OPENAI_API_KEY
    response = openai.ChatCompletion.create(
        model="gpt-4",
        messages=messages
    )
    return response["choices"][0]["message"]["content"]


# TODO: add coding language versatility
def gpt_oneshot(input_prompt, directive="You are a helpful bot that adds assertions to pieces of Python code.", verbose=False):
    message_hist = [{"role": "system", "content": directive},
                    {"role": "user", "content": input_prompt}]  # init
    response = run_gpt4(message_hist)
    if verbose:
        print("chat_gpt: ", response, '\n')
#     message_hist.append({"role": "system", "content": response})
    return response

print("\n\n", gpt_oneshot("what do you do?"))

## Step 3) Parse & Evaluate GPT's Response

### Step 3.1) Restore the assertion(s) generated to code and evaluate
> Metrics of evaluation, does it run? does it add to the code? is it ground-truth-like? human evaluator rank? gpt evaluator rank?

## Step 4) ...