In [1]:
import os
import torch
import tqdm
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    AutoTokenizer,
    pipeline,
)
import gc
from tqdm import tqdm
import json
import re

In [2]:
# Relative path of the model checkpoint; it is passed to `from_pretrained`, and
# the tokenizer is loaded from the same path with a "_tokenizer" suffix.
model_name = "./StructLM-7B-luis"

# Dataset identifier handed to `load_dataset`
# (presumably a Hugging Face Hub repo id — TODO confirm).
dataset_name = "kokujin/prompts_1"

In [3]:
# Only the "test" split is loaded; the generation loop reads each record's "Text" field.
test = load_dataset(dataset_name, split="test")

In [4]:
# Raw generation results keyed by model name. The generation loop appends one
# dict per record with the keys "Prompt", "Original" and "Prediction".
# The key must equal `model_name` without its leading "./".
dic = {'StructLM-7B-luis': []}

In [5]:
# Quantization settings passed to `from_pretrained`; 8-bit loading is explicitly disabled.
bnb_config = BitsAndBytesConfig(
    load_in_8bit=False,
)

In [6]:
# Key under which predictions are stored in `dic`: the checkpoint path without its "./" prefix.
opt = model_name.split("./")[1]
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map={"": 0},
)
# NOTE(review): these two settings look carried over from a fine-tuning script.
# Disabling the KV cache usually only slows generation down — confirm whether
# it is still wanted for inference before changing it.
model.config.use_cache = False
model.config.pretraining_tp = 1

# The tokenizer lives next to the checkpoint, under "<model_name>_tokenizer".
tokenizer = AutoTokenizer.from_pretrained(model_name + "_tokenizer", trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
# Right padding, kept from the original setup (noted there as a fix for an
# overflow issue seen with fp16 training).
tokenizer.padding_side = "right"

# Build the pipeline once: it only wraps `model` and `tokenizer`, so
# re-creating it for every record was loop-invariant work.
pipe = pipeline(task="text-generation", model=model, tokenizer=tokenizer, max_new_tokens=5000)

for record in tqdm(test):
    # Text before the marker is the instruction; text after it is the reference answer.
    parts = record["Text"].split('### Output: ')
    instruction = parts[0]
    # A record without the marker has no reference. Store an empty string
    # instead of raising IndexError and aborting the whole (multi-hour) run.
    reference = parts[1] if len(parts) > 1 else ""

    # `generated_text` contains the prompt followed by the completion.
    generated = pipe(f"{instruction}'### Output':")[0]['generated_text']

    dic[opt].append({
        "Prompt": instruction,
        "Original": reference,
        "Prediction": generated,
    })

# The pipeline holds its own reference to the model, so it has to be dropped
# too before the weights can actually be reclaimed.
del pipe
del model
del tokenizer
gc.collect()
if torch.cuda.is_available():
    # Return the cached GPU blocks to the driver now that nothing references them.
    torch.cuda.empty_cache()

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

100%|██████████| 2399/2399 [3:32:32<00:00,  5.32s/it]   


0

In [18]:
# Inspect the first raw generation. The stored text is the pipeline's
# `generated_text`, so it begins with the prompt itself.
print(dic["StructLM-7B-luis"][0]["Prediction"])

Given following json that contains specifications of a product, generate a review of the key characteristics with json format. Follow the values on {Keys} to write the Output:
### Product: {'Model Name': 'LG G Flex 2 16GB', 'General': {'Operating System': 'Android 5.0 (Lollipop)', 'Custom UI': '', 'Dimensions': '149.1mm x 75.3mm x 7.1mm', 'Weight': '152g'}, 'Display & Design': {'Size': '5.5 inches (13.97 cm)', 'Resolution': '1080 x 1920 pixels', 'Pixel Density': '401ppi', 'Touch Screen': 'Yes, Capacitive Touchscreen, Multi-touch', 'Type': 'P-OLED', 'Screen To Body Ratio': '', 'Aspect Ratio': '', 'Refresh Rate': '', 'Screen Protection': 'Corning Gorilla Glass v3', 'Design': '', 'Colour Options': ['Red', 'Silver'], 'Water Resistance': ''}, 'Hardware': {'Chipset': 'Qualcomm Snapdragon 810 MSM8994', 'CPU': [' 4 x 2GHz Cortex A53', '4 x 1.55GHz'], 'GPU': 'Adreno 430', 'Architecture': '64-bit', 'RAM': '2 GB', 'Internal Storage': '16 GB', 'MicroSD Card Slot': 'Up to 2 TB'}, 'Main Camera': {'N

In [23]:
# Column-oriented container for the successfully parsed records: three
# parallel lists that are filled in lock-step by the parsing cell.
results = {column: [] for column in ("Prompt", "Original", "Prediction")}

In [11]:
def format(val):
    """Rewrite a single-quoted, dict-like string so that ``json.loads`` may accept it.

    This is a best-effort, purely textual clean-up; it does not parse the input.
    The callers in this notebook replace every double quote with a single quote
    before calling it, so the double quotes in the result are the ones
    introduced here.

    NOTE: the function keeps its original name, which shadows the builtin
    ``format`` for the rest of the notebook session.

    Args:
        val: The text to clean up (a ``str``).

    Returns:
        The rewritten string. It is not guaranteed to be valid JSON.
    """
    # The rules are applied strictly in this order; later rules see the output
    # of earlier ones, so the sequence must not be reordered.
    replacements = (
        # Turn quotes that sit next to structural characters into double quotes.
        ("':", "\":"),
        ("{'", "{\""),
        ("',", "\","),
        ("'}", "\"}"),
        (": '", ": \""),
        (", '", ", \""),
        # Collapse a doubled closing brace and drop a brace on its own line.
        ("}}", "}"),
        ("\n}", ""),
        # A blank line becomes a `","` separator; remaining newlines are removed.
        ("\n\n", "\",\""),
        ("\n", ""),
        # Presumably restores apostrophes in contractions/possessives that the
        # rules above turned into double quotes.
        # NOTE(review): this also rewrites any key or value that starts with
        # "t" or "s" (e.g. {"text" -> {'text"), which then fails to parse.
        ("\"t", "'t"),
        ("\"s", "'s"),
        # Backslashes would be read as JSON escape characters.
        ("\\", "/"),
    )
    for old, new in replacements:
        val = val.replace(old, new)

    # Raw string: "\d" inside a normal literal is an invalid escape sequence
    # (a SyntaxWarning on recent Python versions). The pattern is unchanged.
    # NOTE(review): the replacement is the literal text d' — the matched digits
    # are discarded. Possibly a back-reference was intended; confirm before changing.
    return re.sub(r'\d+" ', "d'", val)

In [24]:
# Indices of records whose prediction or reference could not be parsed as JSON.
ommited = []
for i, t in enumerate(dic["StructLM-7B-luis"]):
    try:
        # Keep only what the model produced after the output marker, and
        # normalise to single quotes so the clean-up helper sees one quote style.
        # (`format` here is the notebook's helper, not the builtin.)
        val_p = t["Prediction"].split("'### Output': ")[1].replace("\"", "'")
        # A completion without an opening brace is wrapped as a single "Overview" entry.
        if "{\"" not in val_p and "{'" not in val_p:
            val_p = "{\"Overview\": " + val_p
        val_p = format(val_p)
        # Close the object when the cleaned text has no closing `"}` at all.
        if "\"}" not in val_p:
            val_p = val_p + "\"}"

        val = format(t["Original"].replace("\"", "'"))

        # Parse BOTH sides before touching `results`. Appending the prediction
        # first would leave the three lists misaligned whenever the reference
        # fails to parse afterwards.
        parsed_prediction = json.loads(val_p)
        parsed_original = json.loads(val)
    except (IndexError, ValueError):
        # IndexError: the output marker is missing from the generated text.
        # ValueError: covers json.JSONDecodeError for text that is still not JSON.
        ommited.append(i)
        continue

    results["Prediction"].append(parsed_prediction)
    results["Original"].append(parsed_original)
    results["Prompt"].append(t["Prompt"])

print(len(ommited), ommited)

90 [12, 25, 49, 53, 65, 105, 171, 178, 196, 197, 199, 209, 217, 232, 261, 272, 289, 292, 296, 303, 329, 338, 339, 344, 353, 370, 393, 491, 501, 507, 514, 542, 574, 630, 674, 759, 782, 794, 808, 820, 841, 855, 937, 967, 983, 996, 1046, 1071, 1086, 1117, 1119, 1145, 1170, 1180, 1185, 1186, 1235, 1275, 1305, 1322, 1325, 1337, 1344, 1355, 1423, 1476, 1500, 1510, 1511, 1528, 1550, 1695, 1704, 1788, 1886, 1915, 1942, 1990, 2001, 2014, 2019, 2032, 2058, 2096, 2111, 2168, 2173, 2234, 2361, 2371]


In [25]:
# Same information the parsing cell prints, in the opposite order:
# the skipped record indices first, then how many there are.
print(ommited, len(ommited))

[12, 25, 49, 53, 65, 105, 171, 178, 196, 197, 199, 209, 217, 232, 261, 272, 289, 292, 296, 303, 329, 338, 339, 344, 353, 370, 393, 491, 501, 507, 514, 542, 574, 630, 674, 759, 782, 794, 808, 820, 841, 855, 937, 967, 983, 996, 1046, 1071, 1086, 1117, 1119, 1145, 1170, 1180, 1185, 1186, 1235, 1275, 1305, 1322, 1325, 1337, 1344, 1355, 1423, 1476, 1500, 1510, 1511, 1528, 1550, 1695, 1704, 1788, 1886, 1915, 1942, 1990, 2001, 2014, 2019, 2032, 2058, 2096, 2111, 2168, 2173, 2234, 2361, 2371] 90


In [26]:
# Sanity check: the three parallel lists are expected to have the same length.
print(*(len(results[column]) for column in ("Original", "Prompt", "Prediction")))

2309 2309 2309


In [27]:
# Persist the parsed prompt / reference / prediction lists as one JSON document.
output_path = 'Outputs/StructLm-7B/StructLm-7B.json'
# open(..., 'w') does not create missing parent directories, so on a fresh
# checkout the write would fail with FileNotFoundError without this.
os.makedirs(os.path.dirname(output_path), exist_ok=True)
with open(output_path, 'w', encoding='utf-8') as f:
    json.dump(results, f, indent=4)