In [None]:
import openai
import pickle
from nltk.tokenize import word_tokenize
from nltk.translate.bleu_score import sentence_bleu
import shap
import numpy as np

In [1]:
def llm(prompt, model="gpt4"):
    """Send ``prompt`` as a single user message and return the reply text.

    Args:
        prompt: Text placed in the ``content`` of the one user message.
        model: Backend alias. Only ``"gpt4"`` is supported; it is routed to
            the ``gpt-4-1106-preview`` chat-completions model.

    Returns:
        The ``message.content`` of the first choice in the completion.

    Raises:
        ValueError: If ``model`` is not a supported alias. Failing here is
            deliberate: silently returning ``None`` only moves the crash to
            whatever code later treats the reply as a string.
    """
    if model != "gpt4":
        raise ValueError(f"Unsupported model alias: {model!r}")

    # No api_key argument on purpose: the client then reads OPENAI_API_KEY
    # from the environment. Passing an empty string would override that
    # lookup and make every request fail authentication, and a real key
    # must never be committed in a notebook.
    client = openai.OpenAI()

    chat_completion = client.chat.completions.create(
        messages=[{"role": "user", "content": prompt}],
        temperature=0,
        model="gpt-4-1106-preview",
    )
    return chat_completion.choices[0].message.content

In [2]:
# Load the inputs used by the rest of the notebook. `dataset` is iterated by
# the processing cell and each of its items is used as a key into
# `original_results`.
# NOTE: pickle.load can execute arbitrary code embedded in the file; only load
# pickles that come from a trusted source.
# Context managers make sure both file handles are closed after loading.
with open(r"dataset.pkl", "rb") as dataset_file:
    dataset = pickle.load(dataset_file)
with open(r"original_results.pkl", "rb") as results_file:
    original_results = pickle.load(results_file)

In [14]:
def tokenizer(text):
    """Split ``text`` into a list of tokens using NLTK's ``word_tokenize``."""
    return word_tokenize(text)
    
def generate_perturbed_texts(tokens, prefix="", postfix=""):
    """Build one leave-one-out variant of the token sequence per position.

    Variant ``i`` is the tokens joined by single spaces with the token at
    position ``i`` replaced by the literal ``"[MASK]"``, wrapped as
    ``prefix + joined + postfix``. The input list is left untouched.

    Args:
        tokens: List of token strings.
        prefix: Text prepended verbatim to every variant.
        postfix: Text appended verbatim to every variant.

    Returns:
        A list with ``len(tokens)`` strings, ordered by masked position.
    """
    def _with_mask_at(position):
        # Swap exactly one position for the placeholder, keep the rest.
        masked = [
            "[MASK]" if i == position else tok
            for i, tok in enumerate(tokens)
        ]
        return prefix + " ".join(masked) + postfix

    return [_with_mask_at(position) for position in range(len(tokens))]

def compare(original_code, perturbated_code):
    """Return ``1 - BLEU`` between two whitespace-tokenised strings.

    ``original_code`` is the single reference and ``perturbated_code`` the
    hypothesis passed to ``sentence_bleu``; both are split on whitespace.
    """
    reference_tokens = original_code.split()
    hypothesis_tokens = perturbated_code.split()
    bleu = sentence_bleu([reference_tokens], hypothesis_tokens)
    return 1 - bleu

def each_shap(original_text, original_output, nsamples=50):
    """Estimate per-token SHAP importances for one prompt.

    The prompt is tokenised (only the span from the first ``def`` up to the
    first ``>>>`` when a ``>>>`` is present; the whole text otherwise). Each
    token is a binary feature: 1 means the token is kept, 0 means it is
    replaced by ``"[MASK]"``. The explained model maps such a vector to
    ``compare(original_output, llm(<prompt rebuilt from the vector>))``.

    The background set is the leave-one-out vectors (exactly one token
    masked) and the explained instance is the all-ones vector (nothing
    masked). Because the model output is a distance from ``original_output``,
    a token whose presence moves the LLM output closer to the original
    receives a negative SHAP value.

    Args:
        original_text: The prompt to explain.
        original_output: The LLM output for the unperturbed prompt, used as
            the reference in ``compare``.
        nsamples: Number of coalition samples passed to
            ``KernelExplainer.shap_values``.

    Returns:
        ``(original_text, importances)`` where ``importances`` is a list of
        ``(token, shap_value)`` pairs in token order. The list is empty when
        the selected span contains no tokens.
    """
    # Split the prompt into an untouched prefix, the span to perturb and an
    # untouched postfix.
    doctest_start = original_text.find(">>>")
    if doctest_start != -1:
        def_start = original_text.find("def")
        # find() returns -1 when "def" is absent; that (or a "def" located
        # after the ">>>") would yield nonsensical slices, so perturb from
        # the start of the text instead.
        if def_start == -1 or def_start > doctest_start:
            def_start = 0
        prefix = original_text[:def_start]
        postfix = original_text[doctest_start:]
        tokens = tokenizer(original_text[def_start:doctest_start])
    else:
        prefix, postfix = "", ""
        tokens = tokenizer(original_text)

    # Nothing to attribute; KernelExplainer cannot work with zero features.
    if not tokens:
        return original_text, []

    cache = {}

    def llm_with_caching(perturbed_text):
        # Identical prompts recur across SHAP samples; query the LLM once.
        if perturbed_text not in cache:
            cache[perturbed_text] = llm(perturbed_text)
        return cache[perturbed_text]

    def text_for(binary_vector):
        # Rebuild the prompt positionally from the vector. This handles
        # repeated tokens and every coalition SHAP may ask for, instead of
        # looking the vector up in a fixed table (where an unmatched vector
        # silently resolved to the first perturbed text).
        kept = [
            token if value > 0.5 else "[MASK]"
            for token, value in zip(tokens, binary_vector)
        ]
        return prefix + " ".join(kept) + postfix

    def model_function(binary_vectors):
        scores = [
            compare(original_output, llm_with_caching(text_for(vector)))
            for vector in binary_vectors
        ]
        return np.array(scores)

    # Row i has a 0 at position i and 1 elsewhere: "every token present
    # except token i".
    binary_matrix = 1.0 - np.eye(len(tokens))

    explainer = shap.KernelExplainer(model_function, binary_matrix)
    # Explain the instance in which all tokens are present.
    shap_values = explainer.shap_values(np.ones((1, len(tokens))), nsamples=nsamples)

    importances = list(zip(tokens, shap_values[0]))
    return original_text, importances

In [None]:
import concurrent.futures
from tqdm import tqdm

def process_item(item):
    """Run ``each_shap`` for one dataset item unless it is already done.

    Returns ``None`` when ``item`` is already present in
    ``human_eval_shap_results``; otherwise returns the result of
    ``each_shap(item, original_results[item])``.
    """
    if item in human_eval_shap_results:
        return None
    reference_output = original_results[item]
    return each_shap(item, reference_output)

# Maps each prompt text to its list of (token, shap_value) pairs. This has to
# be a dict: results are stored under the prompt string, and indexing a list
# with a string raises TypeError.
human_eval_shap_results = {}

with concurrent.futures.ThreadPoolExecutor(max_workers=50) as executor:
    futures = [executor.submit(process_item, item) for item in dataset]

    for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures)):
        # result() re-raises any exception raised inside the worker; read it
        # once and reuse the value.
        outcome = future.result()
        if outcome:
            original_text, importances = outcome
            human_eval_shap_results[original_text] = importances
