In [1]:
import os
import torch
import tqdm
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    AutoTokenizer,
    pipeline,
)
import gc
from tqdm import tqdm
import json
import re

In [2]:
# Checkpoint to evaluate. Other names kept for reference:
#   "StructLM-7B-luis", "Mistral-7B-Instruct-v0.3-luis", "Llama-2-7b-hf-luis-QTSUMM"
name = "StructLM-7B-luis-QTSUMM"

# Path handed to `from_pretrained`. Other locations kept for reference:
#   "./models/Llama-2-7b-hf-luis", f"./models/{name}"
model_name = "./qtsumm/" + name

# Dataset identifier handed to `load_dataset`. Alternative: "yale-nlp/QTSumm"
dataset_name = "kokujin/prompts_1"

In [3]:
def _qtsumm_prompt(row):
    """Render the instruction prompt for one row from its "table" and "query" fields."""
    return f"""Given following json that contains specifications of a product, generate a review of the key characteristics with json format. Follow the values on {{Keys}} to write the Output
    ### Product: {row["table"]}
    ### Keys: {row["query"]}
    ### Output:"""


# Evaluation split of the configured dataset.
test = load_dataset(dataset_name, split="test")

# Datasets whose name contains "qtsumm" get a "prompt" column built row by row.
if "qtsumm" in dataset_name.lower():
    test = test.add_column("prompt", [_qtsumm_prompt(row) for row in test])

In [4]:
# Run label: everything after the last "/" of the checkpoint path.
opt = model_name.rsplit("/", 1)[-1]

# Generation records are collected under the run label.
dic = {}
dic[opt] = []

In [5]:
# Load the base model in 4-bit precision (this flag is passed to `load_in_4bit`).
use_4bit = True

# Compute dtype for the 4-bit base model (name of a `torch` dtype attribute).
bnb_4bit_compute_dtype = "float16"

# Quantization type (fp4 or nf4)
bnb_4bit_quant_type = "nf4"

# Activate nested quantization for 4-bit base models (double quantization)
use_nested_quant = False

# Backward-compatible aliases: these names said "8bit" although their values
# have always configured 4-bit loading. Kept in case other cells read them.
use_8bit = use_4bit
bnb_8bit_compute_dtype = bnb_4bit_compute_dtype

compute_dtype = getattr(torch, bnb_4bit_compute_dtype)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=use_4bit,
    bnb_4bit_quant_type=bnb_4bit_quant_type,
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=use_nested_quant,
)

In [None]:

# Load the quantized checkpoint onto GPU 0.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map={"": 0},
)
# Inference only: keep the key/value cache enabled. Disabling it (a setting
# used for training with gradient checkpointing) forces every decoding step
# to re-process the whole sequence.
model.config.use_cache = True
model.config.pretraining_tp = 1

# Load the tokenizer saved alongside the fine-tuned checkpoint.
tokenizer = AutoTokenizer.from_pretrained("./tokenizers/" + name + "_tokenizer", trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"  # carried over from the fp16 training setup

# Build the generation pipeline once; it does not depend on the sample.
pipe = pipeline(task="text-generation", model=model, tokenizer=tokenizer, max_new_tokens=3000)

is_qtsumm = "qtsumm" in dataset_name.lower()

# `tqdm(test)` (rather than `tqdm(enumerate(test))`) lets the bar show the total.
for row in tqdm(test):
    if is_qtsumm:
        # The prompt column was prepared when the dataset was loaded; the
        # reference text is the row's own "summary" field.
        prompt = row["prompt"]
        reference = row["summary"]
        model_input = prompt
    else:
        # "Text" holds prompt and reference separated by the output marker.
        parts = row["Text"].split('### Output: ')
        prompt = parts[0]
        reference = parts[1]
        model_input = f"{prompt}'### Output':"

    generated = pipe(model_input)[0]['generated_text']
    dic[opt].append({
        "Prompt": prompt,
        "Original": reference,
        "Prediction": generated,
    })

# The pipeline keeps its own reference to the model, so it has to be dropped
# as well before the memory can actually be reclaimed.
del pipe
del model
del tokenizer
gc.collect()
torch.cuda.empty_cache()

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

  1%|          | 15/2399 [14:15<37:20:11, 56.38s/it]

In [None]:
# One independent list per field; the parsing loop appends to all three together.
results = {field: [] for field in ("Prompt", "Original", "Prediction")}

In [None]:
def format(val):
    """Apply a fixed, ordered list of text substitutions to a dict-like string.

    The substitutions swap single quotes next to JSON punctuation for double
    quotes, collapse "}}" to "}", drop or replace newlines, turn a double
    quote before "t" or "s" back into an apostrophe and turn backslashes into
    forward slashes. Order matters: later rules see the output of earlier ones.

    Note: the name shadows the built-in ``format``.

    Args:
        val: Text to rewrite.

    Returns:
        The rewritten text.
    """
    replacements = (
        ("':", '":'),
        ("{'", '{"'),
        ("',", '",'),
        ("'}", '"}'),
        (": '", ': "'),
        (", '", ', "'),
        ("}}", "}"),
        ("\n}", ""),
        ("\n\n", '","'),
        ("\n", ""),
        ('"t', "'t"),
        ('"s', "'s"),
        ("\\", "/"),
    )
    for old, new in replacements:
        val = val.replace(old, new)
    # Raw string: the non-raw pattern relied on an invalid escape sequence.
    # NOTE(review): the replacement is the literal text d' and therefore drops
    # the matched digits - this looks unintended, but is kept as-is; confirm.
    return re.sub(r'\d+" ', "d'", val)

In [None]:
def formatting_2(val: str):
    """Split the first line that follows "### Output: " into a dict of colon-separated pieces.

    NOTE(review): another ``formatting_2`` is defined further down in this
    notebook; once that cell runs it replaces this definition.

    Args:
        val: Text containing the "### Output: " marker.

    Returns:
        dict mapping each colon-separated piece of that line to the text found
        after it (see the inline comments for the exact slicing).

    Raises:
        IndexError: If the marker is absent, or a piece cannot be found in the
            text it is searched in.
    """
    # First line after the marker, with single quotes, braces, brackets and colons removed.
    out = val.split("### Output: ")[1].split("\n")[0].replace("\'", "").replace("{", "").replace("}", "").replace("[", "").replace("]", "").replace(":", "")
    # Same line with only braces and brackets removed, cut at every ":".
    k = val.split("### Output: ")[1].split("\n")[0].replace("{", "").replace("}", "").replace("[", "").replace("]", "").split(":")
    res = {}
    for i, k_ in enumerate(k):
        if len(res) == 0:
            # First piece: keep what follows it in `out`, minus one leading character.
            res[k_] = out.replace("\n", "").replace("#", "").split(k_)[1][1:]
        else:
            # Later pieces: take what follows k_ inside the previous piece's value
            # (again minus one leading character) ...
            res[k_] = res[k[i - 1]].split(k_)[1][1:]
            # ... and cut the previous piece's value off where k_ starts.
            res[k[i - 1]] = res[k[i - 1]].split(k_)[0]
    return res

In [None]:
import re

def convert_to_json_format(text):
    """Rewrite single-quoted tokens in dict-like text as double-quoted ones.

    Three regex substitutions run in order:
      1. a quoted token preceded by one of ``: { , [`` and followed by one of
         ``, } ]`` (surrounding whitespace is consumed);
      2. a quoted key directly after a leading ``{``;
      3. a quoted key after ``{`` or ``,`` that is followed by ``:``.

    Args:
        text: Text to rewrite.

    Returns:
        The rewritten text.
    """
    substitutions = (
        (r"(?<=[:{,\[])\s*'([^']+?)'\s*(?=[,}\]])", r'"\1"'),
        (r"^{\s*'([^']+?)'\s*:", r'{"\1":'),
        (r"(?<={|,)\s*'([^']+?)'\s*:", r'"\1":'),
    )
    for pattern, replacement in substitutions:
        text = re.sub(pattern, replacement, text)
    return text

In [None]:
def formatting_2(val: str):
    """Extract the model's answer line and massage it towards parseable JSON.

    The answer is the first line after the "### Output: " marker, or after
    "### Output': " when the first marker is absent. Existing double quotes
    end up as single quotes, single-quoted tokens are converted by
    ``convert_to_json_format``, a surplus trailing "}" is trimmed, and an
    answer that contains neither "{" nor "[" is wrapped in double quotes.

    Args:
        val: Full generated text, including the prompt.

    Returns:
        The cleaned answer string (not yet parsed).

    Raises:
        IndexError: If neither output marker occurs in ``val``.
    """
    for marker in ("### Output: ", "### Output': "):
        if marker in val:
            out = val.split(marker)[1].split("\n")[0]
            break
    else:
        # Same exception type the bare list indexing used to raise.
        raise IndexError("no '### Output' marker found in generated text")

    # Mask existing double quotes so the single-quote conversion cannot touch them.
    out = out.replace('"', "__")
    out = convert_to_json_format(out)
    # A former `out.replace("'", "\\'")` here discarded its result; it is
    # dropped rather than assigned because \' is not a valid JSON escape.
    out = out.replace("__", "'")
    out = out.replace("''", '""')

    # Drop one closing brace when the text ends in "}}" with braces unbalanced.
    if out.count("{") < out.count("}") and out.endswith("}}"):
        out = out[:-1]

    # Plain text (no object or list): turn it into a quoted string.
    if "{" not in out and "[" not in out:
        out = '"' + out + '"'
    return out

In [None]:
def formatqt(val):
    """Encode a reference text as a JSON string literal.

    Double quotes inside ``val`` are replaced by single quotes, then the text
    is serialized with ``json.dumps`` so that backslashes, newlines and other
    control characters are escaped and ``json.loads`` can always read it back.
    For text without such characters the result is the text wrapped in double
    quotes.

    Args:
        val: Reference text.

    Returns:
        A JSON string literal whose decoded value is ``val`` with every
        double quote replaced by a single quote.
    """
    # ensure_ascii=False keeps non-ASCII characters verbatim instead of \uXXXX.
    return json.dumps(val.replace('"', "'"), ensure_ascii=False)

In [None]:
# Start from empty columns so that re-running this cell does not append the
# same samples to `results` a second time.
for column in results.values():
    column.clear()

ommited = []  # indices of samples dropped because parsing failed
a = []        # prediction strings that failed for a reason other than "Extra data"

for i, t in enumerate(dic[opt]):
    val_p = t["Prediction"]
    try:
        # The former llama / non-llama branches performed the same cleanup.
        val_p = formatting_2(val_p)
        val = formatqt(t["Original"])
        try:
            pred = json.loads(val_p)
        except json.JSONDecodeError as first_err:
            if "Extra data" not in first_err.msg:
                raise
            # Several top-level JSON values in a row: wrap them in a list and retry.
            val_p = "[" + val_p + "]"
            try:
                pred = json.loads(val_p)
            except Exception as retry_err:
                print(retry_err)
                ommited.append(i)
                continue
        org = json.loads(val)
    except Exception:
        # Cleanup or parsing failed; keep the offending text for inspection.
        ommited.append(i)
        a.append(val_p)
        continue

    # Append to all three columns only after both values parsed, so the
    # columns stay index-aligned.
    results["Prediction"].append(pred)
    results["Original"].append(org)
    results["Prompt"].append(t["Prompt"])

print(len(ommited), ommited)

Expecting ',' delimiter: line 1 column 1238 (char 1237)
Expecting ',' delimiter: line 1 column 1180 (char 1179)
Expecting ',' delimiter: line 1 column 1056 (char 1055)
Expecting ',' delimiter: line 1 column 8874 (char 8873)
Expecting ',' delimiter: line 1 column 1377 (char 1376)
Expecting ',' delimiter: line 1 column 8949 (char 8948)
Expecting ',' delimiter: line 1 column 679 (char 678)
Expecting ',' delimiter: line 1 column 9541 (char 9540)
Expecting ',' delimiter: line 1 column 1049 (char 1048)
Expecting value: line 1 column 9517 (char 9516)
Expecting ',' delimiter: line 1 column 10055 (char 10054)
13 [37, 217, 239, 626, 943, 1202, 1261, 1347, 1348, 1430, 1628, 2102, 2170]


In [None]:
# Sanity check: print the length of each results column.
print(*(len(results[field]) for field in ("Original", "Prompt", "Prediction")))

2386 2386 2386


In [None]:
# Show the prediction strings the parsing loop stored in `a` because they could not be parsed.
print(a)

['\'The Samsung Galaxy A6 64GB comes with a 3000 mAh Li-ion battery, which is not compatible with wireless charging or fast charging. It has 2G, 3G, and 4G (compatible with India) network connectivity, and dual SIM slots. It supports Wi-Fi 802.11, b/g/n/n 5GHz, Direct, Mobile Hotspot, Bluetooth 4.2, and NFC. The device runs on Android 8.0 (Oreo) operating system with no mention of custom UI. Additionally, it has GPS with A-GPS, Glonass, FM Radio, and 3.5mm headphone jack. The Samsung Galaxy A6 64GB does not support Wi-Fi Calling or FM Radio. It has a 16 MP f/1.7 Wide Angle main camera with LED Flash, and a 16 MP f/1.9, Wide Angle main camera for the front camera. The device has a light sensor, proximity sensor, accelerometer, and gyroscope sensors. The Samsung Galaxy A6 64GB does not have a fingerprint sensor, but it has a MicroSD card slot that can accommodate up to 256 GB. The Samsung Galaxy A6 64GB is available in Black, Blue, and Gold colour options. The Samsung Galaxy A6 64GB weig

In [None]:
# NOTE(review): the directory is hard-coded to "Llama-2-7b-hf" whatever
# checkpoint `name` selects - confirm that this is the intended location.
out_path = f'Outputs/Llama-2-7b-hf/{name}.json'

# Create the output directory if it is missing instead of failing on open().
os.makedirs(os.path.dirname(out_path), exist_ok=True)

with open(out_path, 'w') as f:
    json.dump(results, f, indent=4)