In [1]:
import os
import torch
import tqdm
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    AutoTokenizer,
    pipeline,
)
import gc
from tqdm import tqdm
import json
import re

In [2]:
model_name = "meta-llama/Llama-2-7b-hf"

#dataset_name = "kokujin/prompts_1"

dataset_name = "yale-nlp/QTSumm"

In [3]:
test = load_dataset(dataset_name, split="test")

In [4]:
opt = model_name + "_raw"
dic = {opt: []}

In [5]:
# LoRA attention dimension
lora_r = 64

# Alpha parameter for LoRA scaling
lora_alpha = 16

# Dropout probability for LoRA layers
lora_dropout = 0.1

In [6]:
# Activate 8-bit precision base model loading
use_8bit = True

# Compute dtype for 4-bit base models
bnb_8bit_compute_dtype = "float16"

# Quantization type (fp4 or nf4)
bnb_4bit_quant_type = "nf4"

# Activate nested quantization for 4-bit base models (double quantization)
use_nested_quant = False

In [7]:
# Output directory where the model predictions and checkpoints will be stored
output_dir = f"./results/{model_name}/"

# Number of training epochs
num_train_epochs = 2

# Enable fp16/bf16 training (set bf16 to True with an A100)
fp16 = False
bf16 = False

# Batch size per GPU for training
per_device_train_batch_size = 1

# Batch size per GPU for evaluation
per_device_eval_batch_size = 1

# Number of update steps to accumulate the gradients for
gradient_accumulation_steps = 3

# Enable gradient checkpointing
gradient_checkpointing = True

# Maximum gradient normal (gradient clipping)
max_grad_norm = 0.3

# Initial learning rate (AdamW optimizer)
learning_rate = 2e-4

# Weight decay to apply to all layers except bias/LayerNorm weights
weight_decay = 0.001

# Optimizer to use
optim = "paged_adamw_32bit"

# Learning rate schedule
lr_scheduler_type = "cosine"

# Number of training steps (overrides num_train_epochs)
max_steps = -1

# Ratio of steps for a linear warmup (from 0 to learning rate)
warmup_ratio = 0.03

# Group sequences into batches with same length
# Saves memory and speeds up training considerably
group_by_length = True

# Save checkpoint every X updates steps
save_steps = 500

# Log every X updates steps
logging_steps = 25

In [8]:
# Maximum sequence length to use
max_seq_length = 1000

# Pack multiple short examples in the same input sequence to increase efficiency
packing = False

# Load the entire model on the GPU 0
device_map = {"": 0}

In [9]:
# Load tokenizer and model with QLoRA configuration
compute_dtype = getattr(torch, bnb_8bit_compute_dtype)

bnb_config = BitsAndBytesConfig(
    load_in_4bit=use_8bit,
    bnb_4bit_quant_type=bnb_4bit_quant_type,
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=use_nested_quant,
)

In [10]:

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map={"": 0}
)
model.config.use_cache = False
model.config.pretraining_tp = 1

# Load LLaMA tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right" # Fix weird overflow issue with fp16 training

for prompt in tqdm(test):
    tmp = {
        "Prompt": "",
        "Original": "",
        "Prediction": ""
    }
    resp = prompt["Text"].split('### Output: ')
    prompt = resp[0]
    pipe = pipeline(task="text-generation", model=model, tokenizer=tokenizer, max_new_tokens=3000)
    result = pipe(f"{prompt}'### Output':")
    result = result[0]['generated_text']
    tmp["Prompt"] = prompt
    tmp["Original"] = resp[1]
    tmp["Prediction"] = result
    dic[opt].append(tmp)
    
del model
del tokenizer
gc.collect()
gc.collect()

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 1/2399 [00:54<36:17:25, 54.48s/it]This is a friendly reminder - the current text generation call will exceed the model's predefined maximum length (4096). Depending on the model, you may observe exceptions, performance degradation, or nothing at all.
100%|██████████| 2399/2399 [32:53:49<00:00, 49.37s/it]   


0

In [11]:
results = {
    "Prompt": [],
    "Original": [],
    "Prediction": []
}

In [12]:
def format(val):
    val = val.replace("':", "\":")
    val = val.replace("{'", "{\"")
    val = val.replace("',", "\",")
    val = val.replace("'}", "\"}")
    val = val.replace(": '", ": \"")
    val = val.replace(", '", ", \"")
    val = val.replace("}}", "}")
    val = val.replace("\n}", "")
    val = val.replace("\n\n", "\",\"")
    val = val.replace("\n", "")
    val = val.replace("\"t", "'t")
    val = val.replace("\"s", "'s")
    val = val.replace("\\", "/")
    val = re.sub("\d+\" ", "d'", val)
    return val

In [13]:
def formatting_2(val):
    out = val.split("'### Output':")[1]
    k = val.split("### Keys: ")[1].split("'### Output':")[0].replace('[', '').replace(']','').replace('\n', '').replace('\'', '').split(", ")
    res = {}
    for i, k_ in enumerate(k):
        if len(res) == 0:
            res[k_] = out.replace("\n", "").replace("#", "").split(k_)[1][1:]
        else:
            res[k_] = res[k[i - 1]].split(k_)[1][1:]
            res[k[i - 1]] = res[k[i - 1]].split(k_)[0]
    return res

In [14]:
def formatting_3(val):
    out = val.split("'### Output':")[1].replace("'", "\"")
    k = val.split("### Keys: ")[1].split("'### Output':")[0].replace('[', '').replace(']','').replace('\n', '').replace('\'', '').split(", ")
    res = {}
    for i, k_ in enumerate(k):
        if len(res) == 0:
            res[k_] = out.replace("\n", "").replace("#", "").split(k_)[1][1:]
        else:
            res[k_] = res[k[i - 1]].split(k_)[1][1:]
            res[k[i - 1]] = res[k[i - 1]].split(k_)[0]
    return res

In [15]:
ommited = []
a = []
for i, t in enumerate(dic[opt]):
    try:
        val_p = t["Prediction"]
        if "llama" in model_name.lower():
            val_p = formatting_2(val_p)
            val_p = json.dumps(val_p)
        elif "Struct" in model_name:
            val_p = formatting_2(val_p) 
            val_p = json.dumps(val_p)
        else:
            val_p = val_p.split("'### Output':")[1].replace("'", "\"")
            val_p = "{\"Overview\": " + val_p if "{\"" not in val_p and "{'" not in val_p and "{" not in val_p else val_p        
        
        print(val_p, '\n aca termina la review')
        #val_p = format(val_p)
        #val_p = val_p + "\"}" if "\"}" not in val_p else val_p

        val = t["Original"].replace("\"", "'")
        val = format(val)
        results["Prediction"].append(json.loads(val_p))
        results["Original"].append(json.loads(val))
        results["Prompt"].append(t["Prompt"])

    
    except Exception as e:
        ommited.append(i)
        #print(e)
        a.append(t["Prediction"])

print(len(ommited), ommited)


{"Design and Display": ", '", "Battery and Connectivity": "]"} 
 aca termina la review
{"Design and Display": ": {    \"Screen to Body Ratio\": 92.2,    \"Screen Protection\": \"Gorilla Glass 6\",    \"Water Resistance\": \"IP68, Water resistant (up to 30 minutes in a depth of 1.5 meter)\"  },  \"Hardware\": {    \"Chipset\": \"Qualcomm Snapdragon 888 Plus\",    \"CPU\": [      \" 1 x 3GHz Kryo 680\",      \"3 x 2.42GHz Kryo 680\",      \"4 x 1.8GHz Kryo 680\"    ],    \"GPU\": \"Adreno 660\",    \"Architecture\": \"64-bit\",    \"RAM\": \"12 GB\",    \"Internal Storage\": \"256 GB\",    \"MicroSD Card Slot\": \"No\"  },  \"Main Camera\": {    \"Number of Cameras\": \"Quad\",    \"Resolution\": [      \" 50 MP f/1.57 Wide Angle main camera\",      \"PDAF, OIS, Digital Zoom\",      \"48 MP f/2.2 ultra-wide camera\",      \"12 MP f/1.6 camera\",      \"Optical Zoom\",      \"8 MP f/3.4 camera\",      \"60x Digital Zoom, 5x Optical Zoom, Periscope\"    ],    \"Flash\": \"Dual-color LED Fl

In [16]:
print(ommited, len(ommited))

[4, 5, 6, 9, 10, 12, 16, 18, 35, 38, 42, 44, 71, 75, 80, 81, 86, 94, 101, 103, 105, 106, 115, 126, 129, 130, 132, 134, 136, 138, 145, 151, 154, 167, 170, 179, 182, 183, 185, 189, 193, 196, 199, 201, 205, 217, 220, 221, 229, 234, 242, 247, 249, 250, 251, 256, 259, 262, 263, 267, 280, 282, 283, 286, 291, 295, 299, 300, 306, 312, 315, 316, 317, 319, 328, 334, 339, 355, 356, 357, 360, 362, 364, 365, 371, 372, 377, 380, 382, 383, 389, 390, 391, 396, 400, 402, 405, 409, 412, 423, 426, 429, 431, 438, 442, 445, 451, 457, 461, 465, 466, 477, 478, 486, 492, 493, 497, 500, 503, 510, 518, 519, 524, 525, 528, 530, 532, 534, 538, 543, 545, 546, 547, 551, 562, 567, 569, 583, 585, 599, 605, 611, 613, 617, 620, 622, 625, 628, 629, 630, 634, 636, 639, 652, 653, 659, 666, 670, 676, 683, 687, 688, 693, 694, 695, 697, 702, 704, 713, 714, 716, 722, 724, 745, 747, 754, 755, 757, 764, 769, 770, 773, 774, 779, 783, 786, 791, 792, 802, 807, 813, 814, 815, 816, 825, 827, 839, 840, 841, 843, 845, 848, 860, 864, 8

In [17]:
print(len(results["Original"]),len(results["Prompt"]),len(results["Prediction"]))

1844 1844 1844


In [18]:
print(len(a))

555


In [19]:
with open('Outputs/Llama-2-7b-hf/Llama-2-7b-hf_raw.json', 'w') as f:
    json.dump(results, f, indent=4)