In [1]:
import os
import torch
import tqdm
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    AutoTokenizer,
    pipeline,
)
import gc
from tqdm import tqdm

In [2]:
# Path of the model checkpoint to evaluate; also used to derive the tokenizer path.
model_name = "./StructLM-7B-luis"

# Dataset identifier handed to `load_dataset` to fetch the evaluation prompts.
dataset_name = "kokujin/prompts_1"

In [3]:
# Load only the "test" split; later cells read each row's "Text" field.
test = load_dataset(dataset_name, split="test")

In [4]:
# Generated texts keyed by model name; the inference cell appends one string per test row.
dic = {'StructLM-7B-luis': []}

In [5]:
# Quantization settings passed to `from_pretrained`; 8-bit loading is explicitly turned off.
bnb_config = BitsAndBytesConfig(
    load_in_8bit=False,
)

In [6]:
# Key under which this model's generations are stored in `dic`
# (the model path without its leading "./").
opt = model_name.split("./")[1]
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map={"": 0},  # put every module on device 0
)
model.config.use_cache = False
model.config.pretraining_tp = 1

# The tokenizer is saved separately, in "<model_name>_tokenizer".
tokenizer = AutoTokenizer.from_pretrained(model_name + "_tokenizer", trust_remote_code=True)
# Use EOS as the padding token and pad on the right.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# Build the generation pipeline once: it does not depend on the prompt, so
# re-creating it for every test row only adds per-iteration overhead.
pipe = pipeline(task="text-generation", model=model, tokenizer=tokenizer, max_new_tokens=4000)

# Tolerate a `model_name` whose key was not pre-registered in `dic`.
generations = dic.setdefault(opt, [])

for row in tqdm(test):
    # Keep only the text before the '### Output' marker, i.e. drop the reference answer.
    prompt = row["Text"].split('### Output')[0]
    result = pipe(f"{prompt}'### Output':")
    generations.append(result[0]['generated_text'])

# The pipeline keeps its own references to the model and tokenizer, so it must
# be dropped as well or the two `del`s below cannot release anything.
del pipe
del model
del tokenizer
gc.collect()
gc.collect()
# Hand the now-unused cached GPU blocks back to the driver.
if torch.cuda.is_available():
    torch.cuda.empty_cache()

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

100%|██████████| 5087/5087 [7:31:07<00:00,  5.32s/it]   


0

In [7]:
# Peek at the first ten generated texts (each one begins with the prompt it was generated from).
print(dic["StructLM-7B-luis"][:10])

["Given following json that contains specifications of a product, generate a review of the key characteristics with json format. Follow the values on {Keys} to write the Output:\n### Product: {'Model Name': 'Honor 9S', 'General': {'Operating System': 'Android 10 (Q)', 'Custom UI': 'Magic UI', 'Dimensions': '146.5mm x 70.9mm x 8.3mm', 'Weight': '144g'}, 'Display & Design': {'Size': '5.45 inches (13.84 cm)', 'Resolution': '720 x 1440 pixels', 'Pixel Density': '295ppi', 'Touch Screen': 'Yes, Capacitive Touchscreen, Multi-touch', 'Type': 'TFT LCD', 'Screen To Body Ratio': '', 'Aspect Ratio': '18 9', 'Refresh Rate': '', 'Screen Protection': '', 'Design': '', 'Colour Options': ['Black', 'Blue'], 'Water Resistance': ''}, 'Hardware': {'Chipset': 'MediaTek Helio P22', 'CPU': [' 8 x 2GHz Cortex A53'], 'GPU': 'PowerVR GE8320', 'Architecture': '64-bit', 'RAM': '2 GB', 'Internal Storage': '32 GB', 'MicroSD Card Slot': 'Up to 512 GB'}, 'Main Camera': {'Number of Cameras': 'Single', 'Resolution': [' 

In [40]:
import json
import re
# Parsed JSON objects collected for evaluation:
#   "Original"   - reference outputs parsed from the test set's "Text" field
#   "Prediction" - model outputs parsed from the generated texts
results = {
    "Original": [],
    "Prediction": []
}

In [41]:
# Ordered literal substitutions that turn the single-quoted, dict-like text
# into something `json.loads` accepts. The order matters: later rules rely on
# the output of earlier ones.
_QUOTE_FIXES = (
    ("':", "\":"),
    ("{'", "{\""),
    ("',", "\","),
    ("'}", "\"}"),
    (": '", ": \""),
    (", '", ", \""),
    ("}}", "}"),
    ("\n}", ""),
    ("\n\n", "\",\""),
    ("\n", ""),
    ("\"t", "'t"),   # undo the quote swap inside contractions such as "don't"
    ("\"s", "'s"),   # ... and possessives such as "phone's"
    ("\\", "/"),
)


def _repair_quotes(val):
    """Apply the shared quote/newline clean-up to `val` and return the result.

    `val` is expected to already have every double quote replaced by a single
    quote; this helper then re-introduces double quotes at JSON delimiters.
    """
    for old, new in _QUOTE_FIXES:
        val = val.replace(old, new)
    # NOTE(review): the replacement is the literal text "d'", so the matched
    # digits are dropped; kept as-is to preserve behaviour - confirm intent.
    return re.sub(r'\d+" ', "d'", val)


# Start from empty lists so re-running this cell does not append duplicates.
results["Prediction"] = []
results["Original"] = []

# Indices of generations that could not be turned into valid JSON.
ommited = []
for i, t in enumerate(dic["StructLM-7B-luis"]):
    val = None
    try:
        val = t.split("'### Output': ")[1].replace("\"", "'")
        # No double quotes remain at this point, so only "{'" needs checking:
        # free-text answers without any object are wrapped under "Overview".
        if "{'" not in val:
            val = "{\"Overview\": " + val
        val = _repair_quotes(val)
        # Close the string and object when no string-closing brace is present.
        if "\"}" not in val:
            val = val + "\"}"

        results["Prediction"].append(json.loads(val))
    except (IndexError, ValueError):
        # IndexError: the output marker is missing from the generation.
        # ValueError: covers json.JSONDecodeError for unparsable text.
        # Show the partially cleaned text, or the raw generation when the
        # failure happened before any cleaning took place.
        print(t if val is None else val)
        ommited.append(i)

# Parse the references, skipping rows whose prediction was dropped so both
# lists stay aligned. A reference that fails to parse still raises.
skipped = set(ommited)
for i, t in enumerate(test):
    if i in skipped:
        continue
    val = t["Text"].split("### Output: ")[1].replace("\"", "'")
    results["Original"].append(json.loads(_repair_quotes(val)))


{"Design and Display": {"Size": "6.8 inches (17.27 cm)", "Resolution": "1080 x 2460 pixels", "Pixel Density": "395ppi", "Touch Screen": "Yes, Capacitive Touchscreen, Multi-touch", "Type": "IPS LCD, Dinorex Glass", "Screen To Body Ratio": "90.52 %", "Aspect Ratio": "20.5 9", "Refresh Rate": "", "Screen Protection": "Yes", "Design": "Punch-hole display", "Colour Options": ['Interstellar Black", "Komodo Island", "Turquoise Cyan", "Winsor Violet'], "Water Resistance": "IPX2, Splash proof"}, "Cameras": {"Number of Cameras": "Dual", "Resolution": [' 48 MP f/1.79 main camera", "10x Digital Zoom", "2 MP f/2.4 depth sensor'], "Flash": "Dual LED Flash", "Video": ['1920x1080@30fps'], "Features": ['Artificial Intelligence", "Bokeh Effect']}, "Software and Connectivity": {"Software and Connectivity": "Tecno Spark 8 Pro comes with Android 11 out of the box. The smartphone gets some basic connectivity options like dual sim VoLTE, Wi-Fi 802.11, ac/b/g/n, mobile hotspot, Bluetooth, USB OTG and USB Type

In [42]:
# Show the indices of the generations that failed JSON parsing, followed by their count.
print(ommited, len(ommited))

[22, 82, 119, 153, 184, 328, 397, 401, 449, 453, 472, 493, 756, 779, 872, 900, 981, 999, 1002, 1010, 1025, 1088, 1097, 1170, 1241, 1269, 1371, 1384, 1391, 1398, 1421, 1427, 1430, 1529, 1532, 1538, 1624, 1757, 1769, 1785, 1838, 1850, 1932, 1957, 1993, 2014, 2051, 2083, 2084, 2124, 2140, 2151, 2154, 2182, 2240, 2242, 2248, 2395, 2404, 2416, 2429, 2444, 2458, 2540, 2545, 2657, 2701, 2771, 2911, 2916, 2951, 3001, 3054, 3076, 3189, 3202, 3203, 3232, 3270, 3304, 3401, 3513, 3611, 3670, 3718, 3737, 3772, 3810, 3812, 3968, 4052, 4075, 4092, 4217, 4337, 4488, 4499, 4637, 4638, 4643, 4753, 4898, 4909, 4939, 4999, 5001, 5020, 5040, 5056] 109


In [43]:
# open() does not create missing parent directories, so make sure the
# output folder exists before writing.
os.makedirs('Outputs', exist_ok=True)
# Persist the aligned reference/prediction objects as pretty-printed JSON.
with open('Outputs/StructLm-7B.json', 'w') as f:
    json.dump(results, f, indent=4)