In [4]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from itertools import permutations
from captum.attr import (
    FeatureAblation,
    LLMAttribution,
    TextTokenInput,
    TextTemplateInput,
)
from constants import *
import torch
import json
import os
import re

In [5]:
def load_model(hf_tag, bnb_config):
    """Load a causal language model and its tokenizer from the Hugging Face Hub.

    Args:
        hf_tag: Hub model id passed to ``from_pretrained``.
        bnb_config: a ``BitsAndBytesConfig`` to load the model quantized, or
            ``None`` to load it without a quantization config.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    # quantization_config=None is the from_pretrained default, so one call covers
    # both the quantized and the unquantized path. token=True is passed to the
    # model as well as the tokenizer so both downloads authenticate the same way.
    model = AutoModelForCausalLM.from_pretrained(
        hf_tag,
        quantization_config=bnb_config,
        device_map="auto",
        token=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(hf_tag, token=True)
    return model, tokenizer

In [6]:
def create_bnb_config():
    """Return the bitsandbytes config used for quantized model loading.

    Settings: 4-bit weights with the NF4 quant type, double quantization
    enabled, and a bfloat16 compute dtype.
    """
    quant_kwargs = dict(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    return BitsAndBytesConfig(**quant_kwargs)

In [7]:
def get_prompt(system_prompt, user_prompt):
    """Build the chat-message list fed to ``tokenizer.apply_chat_template``.

    A system message is prepended only when ``system_prompt`` is truthy, so
    ``None`` and ``""`` both omit it. The user message is always the last entry.
    """
    system_part = (
        [{"role": "system", "content": system_prompt}] if system_prompt else []
    )
    return system_part + [{"role": "user", "content": user_prompt}]

In [8]:
# Input/output locations, all under the results folder.
RESULTS_FOLDER = "results"
REFERENCE_CODE_FOLDER = os.path.join(RESULTS_FOLDER, "reference_code")
GENERATED_CODE_FOLDER = os.path.join(RESULTS_FOLDER, "accepted_code")
ATTRIBUTE_CODE_FOLDER = os.path.join(RESULTS_FOLDER, "feature_ablation")

# Experiment selection.
MODEL = "yicoder9b"
TASK = "logic_ops"
QUANTIZE = True  # True -> load_model receives create_bnb_config()
HF_MODEL_TAG = HUGGINGFACE_TAGS[MODEL]

# Every ordered (source, target) pair of distinct languages.
LANG2LANG = list(permutations(LANGS.keys(), 2))

In [None]:
# Quantized load when QUANTIZE is set; otherwise no quantization config.
model, tokenizer = load_model(
    HF_MODEL_TAG, create_bnb_config() if QUANTIZE else None
)
model.eval()

In [10]:
# Feature-ablation attribution over the loaded model, wrapped so it can be
# driven with text inputs through the tokenizer.
fa = FeatureAblation(forward_func=model)
llm_attr = LLMAttribution(attr_method=fa, tokenizer=tokenizer)

In [11]:
# reference from https://stackoverflow.com/a/61305389
# Splits source code into line-like chunks while keeping quoted string literals
# (including ones that span several lines, e.g. triple-quoted) in one chunk.
# Group 1 is the whole chunk; group 2 is the run of opening quote characters,
# which the \2 backreference must see again to close the literal. Blank lines
# produce no match.
# NOTE(review): empty literals ("") and escaped quotes (\") are not special-cased
# and can cause a literal to be dropped or merged with following lines.
quote_aware_line_pattern = r'((?:(\"+)[\s\S]+?\2|[^"\n]+)+)'
matcher = re.compile(quote_aware_line_pattern)

In [None]:
def _read_text(path):
    """Return the contents of the UTF-8 text file at ``path``, stripped."""
    with open(path, "r", encoding="utf-8") as fh:
        return fh.read().strip()


def _escape_braces(text):
    """Double ``{`` and ``}`` so ``str.format`` keeps them as literal characters."""
    return text.replace("{", "{{").replace("}", "}}")


# Loop-invariant: whether the model's chat template gets a system message.
system_prompt = (
    "You are a helpful code conversion assistant."
    if HUGGINGFACE_SYSTEM_PROMPT_SUPPORT[MODEL]
    else None
)
os.makedirs(ATTRIBUTE_CODE_FOLDER, exist_ok=True)

for FROM_LANG, TO_LANG in LANG2LANG:
    from_ext, to_ext = FROM_LANG.lower(), TO_LANG.lower()

    # Only pairs with a non-empty generated translation are attributed.
    to_code_path = os.path.join(
        GENERATED_CODE_FOLDER, f"{MODEL}-{TASK}-{from_ext}-{to_ext}.{to_ext}"
    )
    if not os.path.exists(to_code_path):
        continue
    to_code = _read_text(to_code_path)
    if not to_code:
        continue

    # Results already on disk are not recomputed.
    attribute_path = os.path.join(
        ATTRIBUTE_CODE_FOLDER, f"{MODEL}-{TASK}-{from_ext}-{to_ext}.json"
    )
    if os.path.exists(attribute_path):
        print(f"Skipping {MODEL} {TASK} {from_ext} {to_ext}")
        continue

    from_code = _read_text(os.path.join(REFERENCE_CODE_FOLDER, f"{TASK}.{from_ext}"))
    # Each regex match of the reference code becomes one attribution feature.
    values = [match[0] for match in matcher.findall(from_code)]
    if not values:
        print(f"No reference code for {MODEL} {TASK} {from_ext}; skipping")
        continue

    # One "{}" slot per feature. TextTemplateInput fills string templates with
    # str.format, so literal braces in the surrounding prompt text are escaped.
    prompt = (
        f"Convert the following code from {LANGS[FROM_LANG]} to {LANGS[TO_LANG]}. "
        f"This is the requirement for the code - {TASK_DESCRIPTION[TASK]}\n"
    )
    code_slots = "\n".join(["{}"] * len(values))
    formatted_prompt = _escape_braces(prompt) + "```\n" + code_slots + "\n```\n"
    chat_prompt = get_prompt(system_prompt, formatted_prompt)
    # NOTE(review): add_generation_prompt is left at its default, so for most chat
    # templates the rendered prompt does not end with the assistant-turn header
    # and the target is scored without it. Confirm this matches how the accepted
    # code was generated.
    template = tokenizer.apply_chat_template(chat_prompt, tokenize=False)
    inp = TextTemplateInput(template=template, values=values)
    attr_res = llm_attr.attribute(inp, target=to_code)

    attr_record = {
        "input_tokens": [str(x) for x in attr_res.input_tokens],
        "output_tokens": [str(x) for x in attr_res.output_tokens],
        "token_attr": attr_res.token_attr.squeeze().tolist(),
        "seq_attr": attr_res.seq_attr.tolist(),
    }
    # Serialize before opening the file: a serialization error then cannot
    # leave a truncated JSON that the existence check above would skip forever.
    payload = json.dumps(attr_record)
    with open(attribute_path, "w", encoding="utf-8") as fh:
        fh.write(payload)
    print(f"Done for {MODEL} {TASK} {from_ext} {to_ext}")