In [1]:
import os
import torch
import tqdm
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    AutoTokenizer,
    pipeline,
)
import gc
from tqdm import tqdm
import json
import re

In [2]:
#name = "StructLM-7B-luis"
#name = "Mistral-7B-Instruct-v0.3-luis"
name = "Llama-2-7b-hf-luis"
#model_name = "./models/Llama-2-7b-hf-luis"
model_name = f"./models/{name}"

dataset_name = "yale-nlp/QTSumm"
#dataset_name = "kokujin/prompts_1"

In [3]:
test = load_dataset(dataset_name, split="test")

if "qtsumm" in dataset_name.lower():
    test = test.add_column("prompt", [f"""Given following json that contains specifications of a product, generate a review of the key characteristics with json format. Follow the values on {{Keys}} to write the Output
    ### Product: {x["table"]}
    ### Keys: {x["query"]}
    ### Output:""" for x in test])

In [4]:
opt = model_name.split("/")[-1]
dic = {opt: []}

In [5]:
# Activate 8-bit precision base model loading
use_8bit = True

# Compute dtype for 4-bit base models
bnb_8bit_compute_dtype = "float16"

# Quantization type (fp4 or nf4)
bnb_4bit_quant_type = "nf4"

# Activate nested quantization for 4-bit base models (double quantization)
use_nested_quant = False

compute_dtype = getattr(torch, bnb_8bit_compute_dtype)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=use_8bit,
    bnb_4bit_quant_type=bnb_4bit_quant_type,
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=use_nested_quant,
)

In [6]:

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map={"": 0}
)
model.config.use_cache = False
model.config.pretraining_tp = 1

# Load LLaMA tokenizer
tokenizer = AutoTokenizer.from_pretrained("./tokenizers/" + name + "_tokenizer", trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right" # Fix weird overflow issue with fp16 training


if "qtsumm" in dataset_name.lower():
    for i, prompt in tqdm(enumerate(test)):
        tmp = {
        "Prompt": "",
        "Original": "",
        "Prediction": ""
        }
        resp = prompt["prompt"]
        prompt = resp
        #print(prompt)
        pipe = pipeline(task="text-generation", model=model, tokenizer=tokenizer, max_new_tokens=3000)
        result = pipe(f"{prompt}")
        result = result[0]['generated_text']
        #print(result)
        tmp["Prompt"] = prompt
        tmp["Original"] = test["summary"][i]
        tmp["Prediction"] = result
        dic[opt].append(tmp)
else:
    for prompt in tqdm(test):
        tmp = {
            "Prompt": "",
            "Original": "",
            "Prediction": ""
        }
        resp = prompt["Text"].split('### Output: ')
        prompt = resp[0]
        pipe = pipeline(task="text-generation", model=model, tokenizer=tokenizer, max_new_tokens=5000)
        result = pipe(f"{prompt}'### Output':")
        result = result[0]['generated_text']
        tmp["Prompt"] = prompt
        tmp["Original"] = resp[1]
        tmp["Prediction"] = result
        dic[opt].append(tmp)
    
del model
del tokenizer
gc.collect()
gc.collect()

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

34it [31:57, 57.71s/it]This is a friendly reminder - the current text generation call will exceed the model's predefined maximum length (4096). Depending on the model, you may observe exceptions, performance degradation, or nothing at all.
1078it [16:44:34, 55.91s/it]


0

In [17]:
results = {
    "Prompt": [],
    "Original": [],
    "Prediction": []
}

In [8]:
def format(val):
    val = val.replace("':", "\":")
    val = val.replace("{'", "{\"")
    val = val.replace("',", "\",")
    val = val.replace("'}", "\"}")
    val = val.replace(": '", ": \"")
    val = val.replace(", '", ", \"")
    val = val.replace("}}", "}")
    val = val.replace("\n}", "")
    val = val.replace("\n\n", "\",\"")
    val = val.replace("\n", "")
    val = val.replace("\"t", "'t")
    val = val.replace("\"s", "'s")
    val = val.replace("\\", "/")
    val = re.sub("\d+\" ", "d'", val)
    return val

In [18]:
def formatting_2(val: str):
    out = val.split("### Output: ")[1].split("\n")[0].replace("\'", "").replace("{", "").replace("}", "").replace("[", "").replace("]", "").replace(":", "")
    k = val.split("### Output: ")[1].split("\n")[0].replace("{", "").replace("}", "").replace("[", "").replace("]", "").split(":")
    res = {}
    for i, k_ in enumerate(k):
        if len(res) == 0:
            res[k_] = out.replace("\n", "").replace("#", "").split(k_)[1][1:]
        else:
            res[k_] = res[k[i - 1]].split(k_)[1][1:]
            res[k[i - 1]] = res[k[i - 1]].split(k_)[0]
    return res

In [10]:
import re

def convert_to_json_format(text):
    # Replace single quotes around keys and list elements with double quotes
    
    text = re.sub(r"(?<=[:{,\[])\s*'([^']+?)'\s*(?=[,}\]])", r'"\1"', text)
    text = re.sub(r"^{\s*'([^']+?)'\s*:", r'{"\1":', text)

    text = re.sub(r"(?<={|,)\s*'([^']+?)'\s*:", r'"\1":', text)
    
    return text

In [19]:
def formatting_2(val: str):
    out = val.split("### Output: ")[1].split("\n")[0]
    out = out.replace('"', "__")
    out = convert_to_json_format(out)
    out.replace("'", "\\'")
    out = out.replace("__", "'")
    out = out.replace("''", '""')
    if out.count("{") < out.count("}") and out.endswith("}}"):
        out = out[:-1]
    if "{" not in out and "[" not in out:
        out = '"' + out + '"'
        return out
    
    #out = fix_and_convert_to_json(out)
    return out

In [12]:
def formatqt(val):
    val = val.replace('"', "'")
    val = '"' + val + '"'
    return val

In [20]:
ommited = []
a = []
for i, t in enumerate(dic[opt]):
    try:
        val_p = t["Prediction"]
        if "llama" in model_name.lower():
            val_p = formatting_2(val_p)
            #val_p = json.dumps(val_p)
        else:
            val_p = formatting_2(val_p) 
            #print(val_p)
            #val_p = json.loads(val_p)
        

        val = t["Original"]
        val = formatqt(val)
        pred = json.loads(val_p) 
        org = json.loads(val)
        results["Prediction"].append(pred)
        results["Original"].append(org)
        results["Prompt"].append(t["Prompt"])

    except Exception as e:
        if "Extra data" in str(e):
            try:
                val_p = "[" + val_p + "]"

                pred = json.loads(val_p) 
                org = json.loads(val)
                results["Prediction"].append(pred)
                results["Original"].append(org)
                results["Prompt"].append(t["Prompt"])
            except Exception as e:
                print(e)
                ommited.append(i)
        else:
            print(e, val_p)
            ommited.append(i)
            a.append(val_p)

print(len(ommited), ommited)

Expecting value: line 1 column 7054 (char 7053) {"4200":"The 4200 is the same as the 4250 except for the memory that is 4 MB instead of 64 MB.","4250":"The HP LaserJet 4250 is a monochrome laser printer that has a print resolution of 1200 dpi. It has a print speed of 17 pages per minute and a scan resolution of 600 x 600 dpi. The printer has a 52 MB RAM and a 80 MB hard disk. The memory is expandable to 256 MB. It has a 2.4' LCD display. It has a 200-sheet paper tray and 100-sheet multipurysize tray. The printer is 11.5' x 15.2' x 12.5' and weighs 48.5 pounds. It has a 2-year limited warranty.","4240":"The HP LaserJet 4240 is a monochrome laser printer that has a print resolution of 1200 dpi. It has a print speed of 17 pages per minute and a scan resolution of 600 x 600 dpi. The printer has a 1200 MB hard disk and a 48 MB RAM. The memory is expandable to 64 MB. It has a 2.4' LCD display. It has a 250-sheet paper tray and a 100-sheet multipurysize tray. The printer is 11.5' x 15.2' x 12

In [14]:
print(len(results["Original"]),len(results["Prompt"]),len(results["Prediction"]))

1075 1075 1075


In [15]:
print(a)

['"{\\"Quantity\\":\\"2\\",\\"Build Year\\":\\"1985\\",\\"First in\\":\\"Bellanca Super Decathlon\\",\\"Last out\\":\\"Dornier Do 228\\",\\"Ref(s)\\":\\"Bellanca Super Decathlon is a single-engined, low-wing monoplane with a retractable landing gear and tricycle landing gear. The aircraft is manufactured by Bellanca Aircraft of the USA. It is available in the USA and Europe.\\"}"', '"[\\"Iona Presentation College\\",\\"Penrhos College\\",\\"Santa Maria College\\"]"', '"{\\"No In Series\\":\\"58\\",\\"No In Season\\":\\"1\\",\\"Title\\":\\"Devil May Care\\",\\"Directed By\\":\\"Allan Arkush\\",\\"Written By\\":\\"Jim Praytor, Andi Bushell\\", \'Us Viewers (Millions):\\"12.79\\",\\"Original Air Date\\":\\"March 7, 2004\\"}"']


In [21]:
with open(f'Outputs/Llama-2-7b-hf/{name}.json', 'w') as f:
    json.dump(results, f, indent=4)