In [1]:
import pickle
from datasets import load_from_disk
import torch
from transformers import AutoTokenizer
import os

In [2]:
# Evaluate the model responses and compute their accuracy
import hashlib
import re
import sys

import numpy as np

def get_label(expr, base):
    """Return the sum of a two-operand addition, written in ``base``.

    Args:
        expr: Text of the form ``"<lhs>+<rhs>"``; both operands are read as
            integers in ``base``.
        base: Radix used to parse the operands and to render the result.

    Returns:
        The sum rendered by ``np.base_repr`` in ``base``.

    Raises:
        ValueError: If ``expr`` does not split into exactly two parts on
            ``"+"`` or an operand is not a valid number in ``base``.
    """
    left_text, right_text = expr.split("+")
    total = int(left_text, base) + int(right_text, base)
    return np.base_repr(total, base)

def unescape(str):
    """Turn the two-character escapes ``\\n`` and ``\\r`` into real control characters.

    A doubled backslash protects the escape: backslash-backslash-``n`` becomes
    the literal two characters backslash-``n`` (and likewise for ``r``).
    ``n`` is processed completely before ``r``.

    Args:
        str: Text to unescape. It must not contain the internal ``<TMP>``
            marker (checked with an ``assert``).

    Returns:
        The unescaped text.
    """
    marker = "<TMP>"
    assert marker not in str

    text = str
    for letter, control_char in (("n", "\n"), ("r", "\r")):
        single = "\\" + letter   # backslash + letter
        double = "\\" + single   # backslash + backslash + letter
        # Park the protected form on the marker so the next replace skips it.
        text = text.replace(double, marker)
        text = text.replace(single, control_char)
        text = text.replace(marker, single)
    return text


def parse_output(output: str) -> str:
    """Extract the final answer string from a model response.

    A fixed cascade of heuristics is tried in order; the first one that
    matches decides the result:

    1. Empty output, or output whose MD5 digest is in the hard-coded
       blocklist, yields ``"FAILED"``.
    2. Commas are removed and one non-ASCII look-alike character is mapped to
       ASCII ``C``; if the whole text then consists only of digits and
       upper-case letters it is returned unchanged.
    3. Several LaTeX ``boxed{...}`` layouts are tried (whole output, last
       line, last sentence, then more specific variants).
    4. A long list of natural-language phrasings is tried, each anchored at
       the end of the output.

    Args:
        output: Raw response text.

    Returns:
        The extracted answer, or ``"FAILED"`` when nothing matched.
    """
    if len(output) == 0:
        return "FAILED"

    # Known-bad responses are rejected by the MD5 of their exact text.
    output_hash = hashlib.md5(output.encode("utf-8")).hexdigest()
    if output_hash in {"a7994fde4fba7d27500e6f03008abd7c"}:
        return "FAILED"

    # Drop thousands separators; the first replace argument is a non-ASCII
    # look-alike (presumably Cyrillic Es) that is mapped to ASCII "C".
    output = output.replace(",", "").replace("С", "C")

    # Fast path: the response is nothing but the answer.
    if (match := re.search("^[0-9A-Z]+$", output)) is not None:
        return output

    # Trim trailing newlines, "$", spaces and backticks.
    output = output.rstrip("\n$ `")
    # NOTE(review): the rstrip above already removes trailing newlines, so
    # this branch looks unreachable.
    if output.endswith("\n"):
        output = output[:-1]
    output = output.replace("\\text{", "")

    # General boxed{...} pattern. It has 12 capture groups; the answer digits
    # are the 7th group, i.e. index -6 of each tuple returned by re.findall.
    boxed_regex = r"boxed{(\\text{)?(result=)?([0-9A-Z]+(_{?[0-9]+}?)?\s*\+\s*[0-9A-Z]+(_{?[0-9]+}?)?\s*=\s*)?(0x)?([0-9A-Za-f \\.]+)(_ ?{?(base-)?([0-9]+|ten)}?)?}?(_{?([0-9]+|ten)}?)?}"
    get_result_from_boxed_regex = lambda match: match[-6].replace(" ", "").replace("\\", "")
    # Accept one or more boxed{...} matches as long as they all carry the same value
    match = re.findall(boxed_regex, output)
    if len(match) >= 1 and all(get_result_from_boxed_regex(m) == get_result_from_boxed_regex(match[0]) for m in match):
        return get_result_from_boxed_regex(match[0])

    # Same check restricted to the last line of the output.
    last_line = output.split("\n")[-1]
    match = re.findall(boxed_regex, last_line)
    if len(match) >= 1 and all(get_result_from_boxed_regex(m) == get_result_from_boxed_regex(match[0]) for m in match):
        return get_result_from_boxed_regex(match[0])

    # Same check restricted to the text after the last "." (trailing spaces
    # and periods are stripped first). From here on `last_line` holds that
    # last-sentence text, not the last physical line.
    last_line = output.rstrip(" .").split(".")[-1]
    match = re.findall(boxed_regex, last_line)
    if len(match) >= 1 and all(get_result_from_boxed_regex(m) == get_result_from_boxed_regex(match[0]) for m in match):
        return get_result_from_boxed_regex(match[0])

    # Each operand and the result boxed separately: boxed{a} + boxed{b} = boxed{c}
    if (match := re.search(r"\\boxed{[0-9A-Z]+}(_{?[0-9]+}?)?\s*\+\s*\\boxed{[0-9A-Z]+}(_{?[0-9]+}?)?\s*=\s*\\boxed{([0-9A-Z]+)}(_{?[0-9]+}?)?\.?$", last_line)) is not None:
        return match.groups()[-2]

    # boxed{X_b = Y_10}: keep X, the value written before the "="
    if (match := re.search(r"\\boxed{([0-9A-Z]+)_{?[0-9]+}?\s*=\s*[0-9A-Z]+_{?10}?}\$?\.?$", last_line)) is not None:
        return match.groups()[0]

    # Unboxed equation with base subscripts: a_b + c_b = d_b
    if (match := re.search(r"\$?[0-9A-Z]+(_{?[0-9]+}?)\s*\+\s*[0-9A-Z]+(_{?[0-9]+}?)\s*=\s*(0x)?([0-9A-Z]+)(_{?[0-9]+}?)\$?( in base-[0-9]+)?\.?$", output)) is not None:
        return match.groups()[-3]

    # Two boxed values, one "in base-N" and one "in base-10"/"decimal"; the
    # base-N value is returned whichever order they appear in.
    if (match := re.search(r"(=|is):?\s*\$?\\boxed{(0x)?([0-9A-Z]+)}\$? \(?in base-[0-9]+\)?,?( and| or| =) \$?\\boxed{(0x)?[0-9A-Z]+}\$? \(?in (base-10|decimal)\)?\.?$", output)) is not None:
        return match.groups()[2]
    if (match := re.search(r"(=|is):?\s*\$?\\boxed{(0x)?[0-9A-Z]+}\$? \(?in (base-10|decimal)\)?,?( and| or| =) \$?\\boxed{(0x)?([0-9A-Z]+)}\$? \(?in base-[0-9]+\)?\.?$", output)) is not None:
        return match.groups()[-1]
    # \boxed{207}_{10}$ which in base-11 is $\boxed{18A}$.
    if (match := re.search(r"\\boxed{[0-9A-Z]+}_\{10\}\$? which in base-[0-9]+ is \$?\\boxed{(0x)?([0-9A-Z]+)}\$?\.?$", output)) is not None:
        return match.groups()[-1]
    # 39 + 31 = 5A\boxed{}
    if (match := re.search(r"[0-9]+\s*\+\s*[0-9]+\s*=\s*([0-9A-Z]+)\\boxed\{\}\.?$", output)) is not None:
        return match.groups()[-1]

    # \boxed{result}\n62
    if (match := re.search(r"\\boxed{result}\s*(\n|=)?\s*([0-9A-Z+*^. ]+=\s*)?([0-9A-Z.]+)\$?\.?\**}?$", output)) is not None:
        return match.groups()[-1]

    # \boxed{result: 62}
    if (match := re.search(r"\\boxed{result: ([0-9A-Z]+)}$", output)) is not None:
        return match.groups()[0]

    # Exactly one "a + b = c" equation in the last sentence: return c.
    match = re.findall(r"[0-9A-Z]+\s*\+\s*[0-9A-Z]+\s*=\s*(0x)?([0-9A-Z]+)", last_line)
    if len(match) == 1:
        return match[0][1]

    # Shared tail appended to several phrase patterns below. It has 10
    # capture groups and the answer is the 6th, hence `groups()[-5]` at every
    # call site that appends it.
    match_after_semicolon = r"\s+((\n|[ 0-9A-Z*^])+(\+(\n|[ 0-9A-Z*^])+)+(=|-+|_+)\s*)*([0-9A-Z]+)\s*(\(?(in )?base-[0-9]+\)?)?(, which [^,.]+)?(\s*\([^()]+\))?\.?$"
    # Bare answer on its own final line.
    if (match := re.search(r"\n([0-9A-Z]+)$", output)) is not None:
        return match.groups()[-1]
    # Natural-language phrasings, all anchored at the end of the output.
    # Order matters: earlier, more specific patterns take precedence.
    if (match := re.search(r" in base-[0-9]+ is (equal to )?\"?(0x)?([0-9A-Z]+)\"?( base-[0-9]+)?(, (or|since) [^.]+)?( \([^()]+\))?\.$", output)) is not None:
        return match.groups()[-5]
    if (match := re.search(r" in base-[0-9]+: \$?([0-9A-Z]+)\$?\.$", output)) is not None:
        return match.groups()[-1]
    if (match := re.search(r" the base-[0-9]+ sum: ([0-9A-Z]+)\.$", output)) is not None:
        return match.groups()[-1]
    if (match := re.search(r"the result in base-[0-9]+ is ([0-9A-Z]+), which is equal to [0-9 *^+()]+\.$", output)) is not None:
        return match[1]
    if (match := re.search(r"the sum of [0-9A-Z]+ and [0-9A-Z]+ (in base-[0-9]+ )?(is|as):?" + match_after_semicolon, output)) is not None:
        return match.groups()[-5]
    if (match := re.search(r"the result of [0-9A-Z]+\s*\+\s*[0-9A-Z]+ (in base-[0-9]+ )?(is|as):?" + match_after_semicolon, output)) is not None:
        return match.groups()[-5]
    if (match := re.search(r"[0-9A-Z]+\s*\+\s*[0-9A-Z]+( in base-[0-9]+)?,? (which )?(equals|is equal to|as):? \$?([0-9A-Z]+)\$?(, written as [0-9A-Z]+)?\.?$", output)) is not None:
        return match.groups()[-2]
    if (match := re.search(r"in base-10 is \$?[0-9]+\$?,? (which )?(equals|is equal to|as):? \$?([0-9A-Z]+)\$?(, written as [0-9A-Z]+)?\.?$", output)) is not None:
        return match.groups()[-2]
    if (match := re.search(r"[0-9A-Z]+\s*\+\s*[0-9A-Z]+\s*=\s*([0-9A-Z]+)( in base-[0-9]+)?\.?$", output)) is not None:
        return match.groups()[-2]
    if (match := re.search(r"we can simply write the result as ([0-9A-Z]+)\.?$", output)) is not None:
        return match.groups()[-1]
    if (match := re.search(r"which can be written as ([0-9A-Z]+)\.?$", output)) is not None:
        return match.groups()[-1]
    if (match := re.search(r"(which gives|giving) us the( base-[0-9]+)? number ([0-9A-Z]+)\.?$", output)) is not None:
        return match.groups()[-1]
    if (match := re.search(r"the final result is simply the sum of the tens and ones places: ([0-9A-Z]+)\.?$", output)) is not None:
        return match.groups()[-1]
    if (match := re.search(r"the result is simply the combination of these two sums: ([0-9A-Z]+)\.?$", output)) is not None:
        return match.groups()[-1]
    if (match := re.search(r"we have ([0-9A-Z]+) in base-[0-9]+ as the (final answer for|result of|sum of) [0-9A-Z]+ (\+|and) [0-9A-Z]+\.$", output)) is not None:
        return match[1]
    if (match := re.search(r"we (have|get|end up with) ([0-9A-Z]+)( in base-[0-9]+)? as the( final)? (result|answer|sum)( in base-[0-9]+)?\.$", output)) is not None:
        return match.groups()[1]
    if (match := re.search(r"(=| is) \"?([0-9A-Z]+)\"?\s*(\s+\(?(in )?base-[0-9]+\)?)?\.?$", output)) is not None:
        return match.groups()[1]
    if (match := re.search(r"( final)?( base-[0-9]+)? (result|answer|sum)( in base-[0-9]+)?( is)?( simply)?( of)?( as)?:?" + match_after_semicolon, output)) is not None:
        return match.groups()[-5]
    if (match := re.search(r"we get:" + match_after_semicolon, output)) is not None:
        return match.groups()[-5]
    if (match := re.search(r"we can add the two numbers in base-[0-9]+:" + match_after_semicolon, output)) is not None:
        return match.groups()[-5]
    if (match := re.search(r"[tT]he combination of these sums:\s+([0-9A-Z]+)(\(in base-[0-9]+\))?\.?$", output)) is not None:
        return match.groups()[-2]
    if (match := re.search(r"(Result|Answer)( is)?:?\s+([0-9A-Z]+)\.?$", output)) is not None:
        return match.groups()[-1]
    if (match := re.search(r"The decimal equivalent of \$?([0-9A-Z]+)\$? is therefore \$?[0-9A-Z]+\$?\.?$", output)) is not None:
        return match.groups()[0]
    # The next two patterns allow spaces inside the answer and strip them.
    if (match := re.search(r"(T|t)he final (result|answer) is:?\s+([0-9A-Z ]+\s*\+\s*[0-9A-Z ]+\s*(=|-+)+\s*)?([0-9A-Z ]+)(\(in base-[0-9]+\))?\.?\**$", output)) is not None:
        return match.groups()[-2].replace(" ", "")
    if (match := re.search(r" in base-[0-9]+ is (equal to )?\"?(0x)?([0-9A-Z ]+)\"?(, or [^,.]+)?\.$", output)) is not None:
        return match.groups()[-2].replace(" ", "")
    if (match := re.search(r"( |(\n))([0-9A-Z]+) \(?in base-[0-9]+\)?\.$", output)) is not None:
        return match.groups()[-1]

    # Debug aids: uncomment to inspect unparsed responses and collect their
    # hash for the blocklist near the top of this function.
    #print("Failed to parse output:", output)
    #print(output_hash)
    return "FAILED"

In [None]:
# ---------------------------------------------------------------------------
# Experiment settings
# ---------------------------------------------------------------------------
base = 10                      # numeral base; used in the paths below and as the radix when scoring
type = "probe"                 # results sub-folder. NOTE(review): shadows the builtin `type` for the session
with_intervention = True       # selects which results folder is read
alpha = 0.0                    # appears in the intervention folder name (presumably the intervention strength)
layer = 9                      # appears in the intervention folder name (presumably the intervened layer)
model_name = 'meta-llama/Llama-3.1-8B'

# ---------------------------------------------------------------------------
# Resolve where to read results from and where to write derived artefacts
# ---------------------------------------------------------------------------
if not with_intervention:
    # Baseline run: the intervention parameters are reset.
    alpha = 0.0
    layer = None
    input_dir = f"../../../../results/arithmetic/{type}/base{base}/base{base}"
    output_dir = f"{model_name.split('/')[-1]}/base{base}"
else:
    input_dir = f"../../../../results/arithmetic/{type}/base{base}/with_intervention/alpha_{alpha:0.2f}_layer_dofm_{layer}/combined"
    output_dir = f"{model_name.split('/')[-1]}/base{base}/with_intervention/alpha_{alpha:0.2f}_layer_dofm_{layer}"
os.makedirs(output_dir, exist_ok=True)

tokenizer = AutoTokenizer.from_pretrained(model_name)

# ---------------------------------------------------------------------------
# Load the stored run and pull out the columns used by the cells below
# ---------------------------------------------------------------------------
dataset = load_from_disk(input_dir)
activations = dataset["residual_activations"]
expressions = dataset["expr"]
llm_responses = dataset["llm_response"]

print(len(expressions))





In [5]:
# Persist the expressions and the raw LLM responses as pickles in output_dir.
os.makedirs(output_dir, exist_ok=True)

_expressions_path = f"{output_dir}/expressions_base{base}.pkl"
_responses_path = f"{output_dir}/llm_responses_base{base}.pkl"

for _payload, _target in ((expressions, _expressions_path), (llm_responses, _responses_path)):
    with open(_target, "wb") as f:
        pickle.dump(_payload, f)

In [6]:
# Get the accuracy of the LLM responses.
#
# Every response ends up in exactly one bucket:
#   * correct     - parsed answer equals the sum written in `base`
#   * real world  - (only when base < 10) parsed answer equals the sum obtained
#                   by reading the operands as base-10 numbers
#   * wrong       - anything else that could be scored, including "FAILED" parses
#   * error       - scoring raised an exception (e.g. the response has no "=")
# Errored samples are excluded from the first three percentages, so they are
# tracked in `error_mask` and reported whenever there are any.
num_responses = len(llm_responses)
wrong_answer_mask = torch.zeros(num_responses, dtype=torch.bool)
correct_answer_mask = torch.zeros(num_responses, dtype=torch.bool)
real_world_answer_mask = torch.zeros(num_responses, dtype=torch.bool)
error_mask = torch.zeros(num_responses, dtype=torch.bool)
count = 0  # NOTE(review): unused in the visible cells; kept in case another cell reads it

for i in range(num_responses):
    try:
        correct_response = get_label(expressions[i], base)
        real_world_response = None
        if base < 10:
            real_world_response = get_label(expressions[i], 10)

        # Keep only the first line of the text after the first "=".
        llm_response = llm_responses[i].split("=")[1].strip()
        llm_response = llm_response.split("\n")[0].strip()
        pred = parse_output(llm_response).upper()
    except Exception as e:
        # Best effort: keep going, but record the sample and show why it failed.
        # (`type` is rebound to a string in the settings cell, so the exception
        # is formatted with !r rather than type(e).__name__.)
        error_mask[i] = True
        print(f"[sample {i}] could not be scored: {e!r}")
        print(llm_responses[i])
        continue

    if pred == correct_response:
        correct_answer_mask[i] = True
    elif real_world_response is not None and pred == real_world_response:
        real_world_answer_mask[i] = True
    else:
        wrong_answer_mask[i] = True


print("Percentage of wrong answers:", wrong_answer_mask.sum()/len(llm_responses))
print("Percentage of correct answers:", correct_answer_mask.sum()/len(llm_responses))
if base < 10:
    print("Percentage of real world answers:", real_world_answer_mask.sum()/len(llm_responses))
if error_mask.any():
    print("Percentage of unscored (errored) responses:", error_mask.sum()/len(llm_responses))


Percentage of wrong answers: tensor(0.)
Percentage of correct answers: tensor(1.)


In [5]:
# Regroup the per-sample activations into one (num_samples, 4096) float tensor
# per layer index 0..31, then pickle the resulting dict.
# Assumes activations[s][l] holds 4096 values for sample s at layer l - TODO confirm.
_num_samples = len(activations)

layer_activation_dict = {}
for _layer_idx in range(32):
    layer_activation_dict[_layer_idx] = torch.zeros(_num_samples, 4096)

for _sample_idx in range(_num_samples):
    _sample_activations = activations[_sample_idx]
    for _layer_idx in range(32):
        layer_activation_dict[_layer_idx][_sample_idx] = torch.tensor(_sample_activations[_layer_idx])

with open(f"{output_dir}/layer_activations_base{base}.pkl", "wb") as f:
    pickle.dump(layer_activation_dict, f)

In [6]:
# Logit lens: turn activations into a probability distribution over tokens.
def logit_lens(unembed_weights, final_layer_norm, activations):
    """
    Project activations through the final norm and the unembedding matrix.

    Args:
        unembed_weights: 2-D tensor; its transpose is right-multiplied onto
            the normalised activations.
        final_layer_norm: Callable applied to ``activations`` first.
        activations: Tensor accepted by ``final_layer_norm``.

    Returns:
        Softmax over the last dimension of the resulting logits.
    """
    hidden = final_layer_norm(activations)
    return (hidden @ unembed_weights.T).softmax(dim=-1)

In [7]:
# Process the activations for each layer and all instances -- takes around 6-7 minutes
correct_token_ids = []  # NOTE(review): unused in the visible cells; kept in case another cell reads it

# SECURITY: pickle.load and torch.load(weights_only=False) can execute arbitrary
# code while deserialising. Only point these at files produced by this project.
# The `with` block closes the file handle as soon as the pickle has been read.
with open(f"{output_dir}/layer_activations_base{base}.pkl", "rb") as f:
    layer_activation_dict = pickle.load(f)

# Load the unembedding weights and final layer normalization for logit lens
unembed_weights = torch.load(f"{model_name.split('/')[-1]}/unembed_weights.pt", weights_only=True)
final_layer_norm = torch.load(f"{model_name.split('/')[-1]}/final_layer_norm.pt", weights_only=False)

# Filled layer by layer, so a layer that has not been processed is simply
# absent rather than holding a placeholder tensor.
layer_probs = {}

# The loop variable is deliberately not called `layer`, so the `layer` setting
# from the configuration cell is not overwritten by running this cell.
for layer_idx in range(32):
    print(f"Processing layer {layer_idx}...")
    layer_activations = layer_activation_dict[layer_idx]
    with torch.no_grad():
        probs = logit_lens(unembed_weights, final_layer_norm, layer_activations)
    layer_probs[layer_idx] = probs.cpu()
    


Processing layer 0...
Processing layer 1...
Processing layer 2...
Processing layer 3...
Processing layer 4...
Processing layer 5...
Processing layer 6...
Processing layer 7...
Processing layer 8...
Processing layer 9...
Processing layer 10...
Processing layer 11...
Processing layer 12...
Processing layer 13...
Processing layer 14...
Processing layer 15...
Processing layer 16...
Processing layer 17...
Processing layer 18...
Processing layer 19...
Processing layer 20...
Processing layer 21...
Processing layer 22...
Processing layer 23...
Processing layer 24...
Processing layer 25...
Processing layer 26...
Processing layer 27...
Processing layer 28...
Processing layer 29...
Processing layer 30...
Processing layer 31...


: 

In [None]:
# Persist the per-layer logit-lens probabilities in output_dir.
_logit_lens_path = f"{output_dir}/logit_lens_results_base{base}.pkl"
with open(_logit_lens_path, "wb") as f:
    pickle.dump(layer_probs, f)