In [1]:
import sys
import json
from collections import defaultdict
import torch
from datasets import Dataset
from torch.utils.data import DataLoader
from transformers import LlamaTokenizer, LlamaForCausalLM
from tqdm import tqdm
from peft import PeftModel

sys.path.append("../")
from pp_utils import get_model_and_tokenizer, load_arithmetic_dataloader

# Seed torch's global RNG; every random draw in this notebook goes through
# torch.randint, so this is the only generator that needs fixing.
seed = 0
torch.manual_seed(seed)

# Prefer the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the model and its tokenizer through the project helper from pp_utils.
# NOTE(review): "llama" is presumably a key that selects which checkpoint the
# helper loads; device placement is also left to the helper — confirm both.
model, tokenizer = get_model_and_tokenizer("llama")

You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thouroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


Loading checkpoint shards: 100%|██████████| 2/2 [00:05<00:00,  2.71s/it]


## Creating Clean Data

In [9]:
# Build the "clean" prompts: "x=<x>; y=<y>; x+y=" with both operands
# zero-padded to n_digits. The label is the first character of the sum after
# the same zero-padding. A sum can have n_digits + 1 digits; zfill then leaves
# it untouched and the label is simply its leading digit.
raw_data = defaultdict(list)
answer_dist = defaultdict(int)  # label -> number of prompts kept so far
n_digits = 5
upper = 10 ** n_digits  # exclusive upper bound for both operands

# 100 random x values, each paired with 10000 random y values.
for x in torch.randint(0, upper, (100,)).tolist():
    x_ = str(x).zfill(n_digits)
    for y in torch.randint(0, upper, (10000,)).tolist():
        y_ = str(y).zfill(n_digits)
        out = str(x + y).zfill(n_digits)
        answer = out[0]

        # Keep at most 50 prompts per label so the labels stay balanced.
        if answer_dist[answer] >= 50:
            continue
        answer_dist[answer] += 1

        expression = f"x={x_}; y={y_}; x+y="
        raw_data["expression"].append(expression)
        raw_data["label"].append(answer)

# Persist the prompts and labels as two parallel lists in one JSON object.
with open("../../data/arithmetic_data_clean.json", "w") as f:
    json.dump(raw_data, f)

## Creating Counterfactual Data

In [10]:
# Read back the clean prompts written by the previous section; the
# counterfactual generation below only needs their "label" list.
with open("../../data/arithmetic_data_clean.json", "r") as f:
    clean_data = json.load(f)

In [11]:
# Build one counterfactual prompt per clean example. Each attempt draws fresh
# operands and two distinct variable names; it is accepted only when its label
# (first character of the zero-padded sum) differs from the clean label at the
# same position, otherwise everything is redrawn.
raw_data = defaultdict(list)
n_digits = 5
upper = 10 ** n_digits  # exclusive upper bound for both operands
letters = [chr(code) for code in range(97, 123)]  # 'a' .. 'z'
first_choices = [c for c in letters if c != "x"]  # first name is never "x"

for idx, label in enumerate(clean_data["label"]):
    while True:
        x = torch.randint(0, upper, (1,)).item()
        y = torch.randint(0, upper, (1,)).item()

        x_ = str(x).zfill(n_digits)
        y_ = str(y).zfill(n_digits)
        out = str(x + y).zfill(n_digits)
        answer = out[0]

        # The names are drawn on every attempt, including rejected ones, so
        # the sequence of random draws is the same for every attempt.
        first_var = first_choices[torch.randint(0, len(first_choices), (1,)).item()]

        # Second name is never "y" and never equal to the first name.
        second_choices = [c for c in letters if c not in ("y", first_var)]
        second_var = second_choices[
            torch.randint(0, len(second_choices), (1,)).item()
        ]

        if answer == label:
            continue  # same label as the clean example: redraw

        expression = f"{first_var}={x_}; {second_var}={y_}; {first_var}+{second_var}="
        raw_data["expression"].append(expression)
        raw_data["label"].append(answer)
        break

# Persist the prompts and labels as two parallel lists in one JSON object.
with open("../../data/arithmetic_data_corrupt.json", "w") as f:
    json.dump(raw_data, f)

## Evaluating Model

In [3]:
# Build the evaluation dataloader from the two JSON files written above,
# using the project helper from pp_utils.
# The clean set holds at most 500 prompts (50 per leading-digit label), which
# matches num_samples here.
# NOTE(review): how the helper pairs clean and corrupt examples, and whether it
# pads or truncates sequences, is not visible from this notebook — confirm.
dataloader = load_arithmetic_dataloader(
    tokenizer=tokenizer,
    clean_data_file="../../data/arithmetic_data_clean.json",
    corrupt_data_file="../../data/arithmetic_data_corrupt.json",
    num_samples=500,
    batch_size=64,
)

In [4]:
# Next-token accuracy on the clean prompts: a prediction counts as correct
# when the arg-max token at the final position equals the batch label.
correct, total = 0, 0

# This is pure inference, so autograd is switched off for the whole loop.
# Without it every forward pass records a graph and keeps activations alive,
# which wastes memory for no benefit.
with torch.no_grad():
    for data in dataloader:
        input_ids = data["base_tokens"].to(device)
        labels = data["labels"].to(device)

        outputs = model(input_ids=input_ids)
        # Logits at the last sequence position only.
        # NOTE(review): assumes the last position is the real end of every
        # prompt (equal lengths or left padding) — confirm in the dataloader.
        logits = outputs.logits[:, -1]
        predicted_index = torch.argmax(logits, dim=-1)

        # NOTE(review): assumes labels is a 1-D tensor of token ids with one
        # entry per example — confirm against load_arithmetic_dataloader.
        correct += (predicted_index == labels).sum().item()
        total += labels.size(0)

        # Drop the per-batch tensors and release cached GPU blocks before the
        # next batch is moved to the device.
        del input_ids, labels, outputs, logits, predicted_index
        torch.cuda.empty_cache()

print(f"Accuracy: {round(correct/total, 2)}")

Accuracy: 0.65
