In [1]:
!pip install -q bitsandbytes lxt

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/81.7 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m81.7/81.7 kB[0m [31m3.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m122.4/122.4 MB[0m [31m4.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.5/1.5 MB[0m [31m49.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.8/44.8 kB[0m [31m3.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for lxt (setup.py) ... [?25l[?25hdone


In [4]:
from transformers import AutoTokenizer, BitsAndBytesConfig
from lxt.models.llama import LlamaForCausalLM, attnlrp
from itertools import permutations
from constants import *
import torch
import json
import os
import re

In [5]:
def load_model(hf_tag, bnb_config):
    n_gpus = torch.cuda.device_count()
    if bnb_config is None:
        model = LlamaForCausalLM.from_pretrained(hf_tag, device_map="auto")
    else:
        model = LlamaForCausalLM.from_pretrained(
            hf_tag, quantization_config=bnb_config, device_map="auto"
        )
    tokenizer = AutoTokenizer.from_pretrained(hf_tag, token=True)
    return model, tokenizer

In [6]:
def create_bnb_config():
    return BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

In [7]:
def get_prompt(system_prompt, user_prompt):
    prompt = []
    if system_prompt:
        prompt.append({"role": "system", "content": system_prompt})
    prompt.append({"role": "user", "content": user_prompt})
    return prompt

In [8]:
RESULTS_FOLDER = "results"
GENERATED_CODE_FOLDER = os.path.join(RESULTS_FOLDER, "accepted_code")
ATTRIBUTE_CODE_FOLDER = os.path.join(RESULTS_FOLDER, "layer_relevance_propagation")
REFERENCE_CODE_FOLDER = os.path.join(RESULTS_FOLDER, "reference_code")
MODEL = "yicoder9b"
TASK = "logic_ops"
HF_MODEL_TAG = HUGGINGFACE_TAGS[MODEL]
QUANTIZE = True
LANG2LANG = list(permutations(LANGS.keys(), 2))

In [10]:
if QUANTIZE:
    model, tokenizer = load_model(HF_MODEL_TAG, create_bnb_config())
else:
    model, tokenizer = load_model(HF_MODEL_TAG, None)
model.eval()
attnlrp.register(model)

KeyError: 'sdpa'

In [None]:
def hidden_relevance_hook(module, input, output):
    if isinstance(output, tuple):
        output = output[0]
    module.hidden_relevance = output.detach().cpu()

In [None]:
for layer in model.model.layers:
    layer.register_full_backward_hook(hidden_relevance_hook)

In [None]:
for FROM_LANG, TO_LANG in LANG2LANG:
    from_code = (
        open(os.path.join(REFERENCE_CODE_FOLDER, f"{TASK}.{FROM_LANG.lower()}"), "r")
        .read()
        .strip()
    )
    to_code_path = os.path.join(
        GENERATED_CODE_FOLDER,
        f"{MODEL}-{TASK}-{FROM_LANG.lower()}-{TO_LANG.lower()}.{TO_LANG.lower()}",
    )
    if not os.path.exists(to_code_path):
        continue
    to_code = open(to_code_path, "r").read().strip()
    if not to_code:
        continue
    attribute_path = os.path.join(
        ATTRIBUTE_CODE_FOLDER,
        f"{MODEL}-{TASK}-{FROM_LANG.lower()}-{TO_LANG.lower()}.json",
    )
    if os.path.exists(attribute_path):
        print(f"Skipping {MODEL} {TASK} {FROM_LANG.lower()} {TO_LANG.lower()}")
        continue
    prompt = f"Convert the following code from {LANGS[FROM_LANG]} to {LANGS[TO_LANG]}. "
    prompt += f"This is the requirement for the code - {TASK_DESCRIPTION[TASK]}\n"
    formatted_prompt = prompt + "```\n" + from_code + "\n```\n"
    if HUGGINGFACE_SYSTEM_PROMPT_SUPPORT[MODEL]:
        chat_prompt = get_prompt(
            "You are a helpful code conversion assistant.", formatted_prompt
        )
    else:
        chat_prompt = get_prompt(None, formatted_prompt)
    template = tokenizer.apply_chat_template(chat_prompt, tokenize=False)
    input_ids = tokenizer(template, return_tensors="pt", add_special_tokens=False).input_ids.to(model.device)
    input_embeds = model.get_input_embeddings()(input_ids)
    output_logits = model(inputs_embeds=input_embeds.requires_grad_(), use_cache=False).logits
    max_logits, max_indices = torch.max(output_logits[:, -1, :], dim=-1)
    max_logits.backward(max_logits)
    relevance_trace = []
    for layer in model.model.layers:
        relevance = layer.hidden_relevance[0].sum(-1)
        relevance = relevance / relevance.abs().max()
        relevance_trace.append(relevance)
    relevance_trace = torch.stack(relevance_trace)
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    break
