In [1]:
import os
import torch
import tqdm
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    AutoTokenizer,
    pipeline,
)
import gc
from tqdm import tqdm
import json
import re

In [2]:
# Local directory holding the fine-tuned model weights; the tokenizer is
# loaded later from this same path with "_tokenizer" appended.
model_name = "./Llama-2-7b-hf-luis"

# Dataset identifier passed to `load_dataset`.
dataset_name = "kokujin/prompts_1"

In [3]:
# Evaluation split; the generation loop reads each record's "Text" field.
test = load_dataset(dataset_name, split="test")

In [4]:
# Key under which this model's generations are stored: the model path without
# its leading "./". A name without that prefix (e.g. a Hub id) is used as-is
# instead of raising IndexError.
opt = model_name[2:] if model_name.startswith("./") else model_name

# Maps the model key to the list of {"Prompt", "Original", "Prediction"} records.
dic = {opt: []}

In [5]:
# Quantization settings handed to `from_pretrained` below; 8-bit loading is
# explicitly switched off here.
# NOTE(review): with no quantization requested this config may be unnecessary — confirm.
bnb_config = BitsAndBytesConfig(
    load_in_8bit=False,
)

In [None]:

# ---- Load model and tokenizer ------------------------------------------
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map={"": 0},  # put the whole model on device 0
)
# Disabling the key/value cache is a training-time setting; for generation it
# only forces every step to re-encode the full sequence, so keep it enabled.
model.config.use_cache = True
model.config.pretraining_tp = 1

# Load LLaMA tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name + "_tokenizer", trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
# NOTE(review): this setting was carried over from a training script; prompts
# are generated one at a time here, so padding is presumably never applied.
tokenizer.padding_side = "right"

# ---- Generate one prediction per test example ---------------------------
# The pipeline does not depend on the example, so build it once instead of
# once per iteration.
# NOTE(review): max_new_tokens=5000 is kept from the original; confirm it fits
# the model's context window.
pipe = pipeline(task="text-generation", model=model, tokenizer=tokenizer, max_new_tokens=5000)

# Start from an empty list so re-running this cell does not duplicate records.
dic[opt] = []

for example in tqdm(test):
    # Everything before the first "### Output: " is the prompt; the piece right
    # after it is the reference answer.
    resp = example["Text"].split('### Output: ')
    prompt = resp[0]
    # A record without the separator gets an empty reference instead of
    # aborting the whole run with IndexError.
    reference = resp[1] if len(resp) > 1 else ""

    result = pipe(f"{prompt}'### Output':")
    dic[opt].append({
        "Prompt": prompt,
        "Original": reference,
        "Prediction": result[0]['generated_text'],
    })

# ---- Release the model ---------------------------------------------------
# The pipeline holds its own references to the model and tokenizer, so it has
# to be dropped as well or nothing is actually freed.
del pipe
del model
del tokenizer
gc.collect()
if torch.cuda.is_available():
    # Return cached, now-unused GPU blocks to the driver.
    torch.cuda.empty_cache()

In [59]:
# Parallel lists filled by the parsing cell: entry k of each list belongs to
# the same successfully parsed example.
results = {column: [] for column in ("Prompt", "Original", "Prediction")}

In [58]:
def format(val):
    """Best-effort rewrite of single-quoted, dict-like text into JSON-style text.

    Applies a fixed sequence of literal substitutions, in order, followed by
    one regex substitution. The order matters: later rules operate on the
    output of earlier ones.

    Note: this function shadows the built-in ``format`` within the notebook.

    Args:
        val: The text to rewrite. Callers in this notebook first replace every
            double quote with a single quote before passing the text in.

    Returns:
        The rewritten string. It is not guaranteed to be valid JSON.
    """
    substitutions = (
        # Single quotes next to JSON punctuation become double quotes.
        ("':", "\":"),
        ("{'", "{\""),
        ("',", "\","),
        ("'}", "\"}"),
        (": '", ": \""),
        (", '", ", \""),
        # Collapse a doubled closing brace and drop a closing brace that sits
        # on its own line.
        ("}}", "}"),
        ("\n}", ""),
        # A blank line becomes a `","` separator; remaining newlines are removed.
        ("\n\n", "\",\""),
        ("\n", ""),
        # Turn a double quote directly before "t" or "s" back into an
        # apostrophe (presumably for contractions and possessives).
        ("\"t", "'t"),
        ("\"s", "'s"),
        # Backslashes become forward slashes.
        ("\\", "/"),
    )
    for old, new in substitutions:
        val = val.replace(old, new)

    # Raw string: the previous non-raw "\d" was an invalid escape sequence.
    # NOTE(review): the replacement is the literal text d' and therefore drops
    # the matched digits; confirm whether the digits were meant to be kept.
    return re.sub(r'\d+" ', "d'", val)

In [60]:
# Parse every generated prediction and its reference answer into JSON.
# Examples where either side cannot be parsed are skipped and recorded.
ommited = []  # indices (into dic[opt]) of the skipped examples
a = []        # prediction text of each skipped example, for manual inspection

# Empty the result lists first so re-running this cell does not append duplicates.
for parsed_column in results.values():
    parsed_column.clear()

for i, t in enumerate(dic[opt]):
    # Start from the raw generation so that a failure in the very first step
    # still logs this example's text, not a stale value from the previous one.
    val_p = t["Prediction"]
    try:
        # Keep what follows the "'### Output': " marker appended to the prompt,
        # switch to single quotes, and cut anything from "### Keys:" onwards.
        val_p = val_p.split("'### Output': ")[1].replace("\"", "'").split("### Keys:")[0]

        # If the model repeated the "### Output: " marker, keep the text after it.
        # The original only did this for llama models and assigned a list
        # otherwise, which always failed later on.
        # NOTE(review): assumes non-llama models should get the same treatment.
        parts = val_p.split("### Output: ")
        if len(parts) > 1:
            val_p = parts[1]

        # Text without an opening brace is wrapped as the "Overview" value.
        if "{\"" not in val_p and "{'" not in val_p:
            val_p = "{\"Overview\": " + val_p
        val_p = format(val_p)
        # Close the object if the generation was cut off before `"}`.
        if "\"}" not in val_p:
            val_p = val_p + "\"}"

        val = format(t["Original"].replace("\"", "'"))

        # Parse both sides before storing anything, so the three result lists
        # can never get out of step with each other.
        prediction_json = json.loads(val_p)
        original_json = json.loads(val)
    except (IndexError, ValueError):
        # IndexError: the output marker is missing from the generation.
        # ValueError: json.JSONDecodeError (a ValueError subclass) from json.loads.
        ommited.append(i)
        a.append(val_p)
        continue

    results["Prediction"].append(prediction_json)
    results["Original"].append(original_json)
    results["Prompt"].append(t["Prompt"])

print(len(ommited), ommited)

In [62]:
# Show the indices of the examples that were skipped, followed by their count.
print(ommited, len(ommited))

[2, 3, 5, 6, 10, 11, 13, 18, 21, 22, 25, 26, 27, 32, 33, 36, 37, 38, 39, 40, 44, 47, 48, 54, 55, 56, 57, 58, 61, 63, 64, 65, 66, 67, 68, 69, 80, 82, 84, 86, 89, 90, 91, 98, 101, 103, 105, 106, 107, 108, 109, 111, 112, 116, 118, 119, 120, 121, 123, 126, 127, 129, 133, 134, 139, 140, 143, 144, 145, 146, 149, 150, 151, 152, 154, 155, 157, 158, 159, 161, 162, 163, 164, 167, 168, 171, 173, 174, 181, 183, 187, 188, 189, 190, 192, 193, 195, 198, 201, 202, 203, 204, 206, 207, 208, 210, 215, 216, 217, 218, 219, 221, 224, 225, 228, 229, 231, 232, 233, 235, 236, 238, 239, 240, 243, 245, 246, 248, 250, 251, 254, 256, 259, 260, 261, 263, 265, 267, 268, 270, 271, 273, 274, 275, 276, 278, 280, 281, 282, 284, 285, 287, 288, 289, 290, 292, 294, 295, 296, 298, 299, 301, 302, 304, 306, 307, 309, 310, 311, 312, 314, 315, 316, 317, 321, 322, 323, 329, 330, 331, 332, 335, 336, 337, 339, 342, 343, 345, 347, 350, 351, 353, 355, 356, 357, 358, 360, 361, 362, 363, 365, 366, 367, 369, 370, 371, 372, 374, 376, 37

In [64]:
# Sanity check: the three result lists are expected to have the same length.
print(len(results["Original"]),len(results["Prompt"]),len(results["Prediction"]))

1022 1022 1022


In [63]:
# Number of prediction strings collected from examples that failed to parse.
print(len(a))

1377


In [61]:
# Save the predictions that could not be parsed, for manual inspection.
# NOTE(review): the directory name does not match `model_name`
# ("Llama-2-7b-hf-luis"); confirm this is the intended location.
out_path = 'Outputs/Llama-2-7b-hf/a.json'
# open() does not create missing parent directories, so make sure they exist.
os.makedirs(os.path.dirname(out_path), exist_ok=True)
with open(out_path, 'w') as f:
    json.dump(a, f, indent=4)