In [1]:
import os
import torch
import tqdm
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    AutoTokenizer,
    pipeline,
)
import gc
from tqdm import tqdm
import json
import re

In [2]:
model_name = "./Mistral-7B-Instruct-v0.3-luis"

dataset_name = "kokujin/prompts_1"

In [3]:
test = load_dataset(dataset_name, split="test")

In [4]:
opt = model_name.split("./")[1]
dic = {opt: []}

In [5]:
bnb_config = BitsAndBytesConfig(
    load_in_8bit=False,
)

In [6]:

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map={"": 0}
)
model.config.use_cache = False
model.config.pretraining_tp = 1

# Load LLaMA tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name + "_tokenizer", trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right" # Fix weird overflow issue with fp16 training

for prompt in tqdm(test):
    tmp = {
        "Prompt": "",
        "Original": "",
        "Prediction": ""
    }
    resp = prompt["Text"].split('### Output: ')
    prompt = resp[0]
    pipe = pipeline(task="text-generation", model=model, tokenizer=tokenizer, max_new_tokens=5000)
    result = pipe(f"{prompt}'### Output':")
    result = result[0]['generated_text']
    tmp["Prompt"] = prompt
    tmp["Original"] = resp[1]
    tmp["Prediction"] = result
    dic[opt].append(tmp)
    
del model
del tokenizer
gc.collect()
gc.collect()

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

100%|██████████| 2399/2399 [12:08:03<00:00, 18.21s/it]  


0

In [7]:
print(dic["StructLM-7B-luis"][0]["Prediction"])

KeyError: 'StructLM-7B-luis'

In [20]:
results = {
    "Prompt": [],
    "Original": [],
    "Prediction": []
}

In [9]:
def format(val):
    val = val.replace("':", "\":")
    val = val.replace("{'", "{\"")
    val = val.replace("',", "\",")
    val = val.replace("'}", "\"}")
    val = val.replace(": '", ": \"")
    val = val.replace(", '", ", \"")
    val = val.replace("}}", "}")
    val = val.replace("\n}", "")
    val = val.replace("\n\n", "\",\"")
    val = val.replace("\n", "")
    val = val.replace("\"t", "'t")
    val = val.replace("\"s", "'s")
    val = val.replace("\\", "/")
    val = re.sub("\d+\" ", "d'", val)
    return val

In [21]:
ommited = []
for i, t in enumerate(dic[opt]):
    try:
        val_p = t["Prediction"].split("'### Output': ")[1].replace("\"", "'").split("### Output:")[0].split("### Keys:")[0]
        val_p = "{\"Overview\": " + val_p if "{\"" not in val_p and "{'" not in val_p else val_p        
        val_p = format(val_p)
        val_p = val_p + "\"}" if "\"}" not in val_p else val_p

        val = t["Original"].replace("\"", "'")
        val = format(val)
        print(val_p)
        results["Prediction"].append(json.loads(val_p))
        results["Original"].append(json.loads(val))
        results["Prompt"].append(t["Prompt"])

    except:
        ommited.append(i)

print(len(ommited), ommited)

{"Design and Display": "The LG G Flex 2 16GB has a 5.5 inch display with a resolution of 1080 x 1920 pixels and a pixel density of 401PPI. The display is a P-OLED display and has a screen to body ratio of 77.3%. The phone has a plastic body and is available in colors like Platinum Silver, Armor Gold, and Philosopher Black. The phone has a 3.5mm jack for headphones and a 2.5D curved display.", "Battery and Connectivity": "The battery on the LG G Flex 2 is a 3000mAh unit that is user replaceable. The phone has Bluetooth 4.1 and NFC built in. The phone has a microUSB 2.0 port for charging and syncing with a computer. The phone has WiFi a/b/g/n/ac and can be a mobile hotspot for other devices to connect to the internet through the phone. The phone has a 3.5mm headphone jack for connecting to headphones. The phone has a single SIM slot and does not have dual SIM support."}
{"Design and Display": "The Vivo X70 Pro Plus boasts a 6.78-inch AMOLED display with a resolution of 1440 x 3200 pixels

In [22]:
print(ommited, len(ommited))

[29, 32, 38, 44, 93, 101, 166, 265, 329, 352, 356, 370, 386, 407, 420, 431, 450, 460, 509, 533, 566, 576, 596, 629, 660, 691, 711, 713, 728, 729, 748, 758, 772, 851, 884, 918, 988, 1004, 1016, 1048, 1088, 1103, 1167, 1185, 1214, 1231, 1232, 1248, 1283, 1284, 1301, 1328, 1345, 1346, 1412, 1438, 1447, 1452, 1457, 1479, 1495, 1541, 1559, 1614, 1623, 1629, 1630, 1646, 1700, 1703, 1715, 1733, 1735, 1755, 1759, 1776, 1779, 1854, 1855, 1891, 1896, 1910, 1920, 1962, 1966, 1981, 2006, 2030, 2034, 2098, 2113, 2119, 2160, 2166, 2186, 2202, 2237, 2267, 2272, 2299, 2336, 2339, 2345, 2351, 2369] 105


In [23]:
print(len(results["Original"]),len(results["Prompt"]),len(results["Prediction"]))

2294 2294 2294


In [24]:
with open('Outputs/Mistral-7B-Instruct-v0.3/Mistral-7B-Instruct-v0.3.json', 'w') as f:
    json.dump(results, f, indent=4)