In [1]:
from vllm import LLM, SamplingParams
import torch
import gzip
import json
import jsonlines
import os

In [2]:
# Sampling settings passed to llm.generate for every batch of prompts.
sampling_params = SamplingParams(
    temperature=0.0,
    top_p=0.95,
    max_tokens=512,
)

In [3]:
# Engine used by get_labelled_data to generate the product summaries.
llm = LLM(
    model="mistralai/Mistral-7B-Instruct-v0.1",
    dtype=torch.float16,
    gpu_memory_utilization=0.95,
    enforce_eager=True,
)

INFO 03-14 19:39:47 llm_engine.py:87] Initializing an LLM engine with config: model='mistralai/Mistral-7B-Instruct-v0.1', tokenizer='mistralai/Mistral-7B-Instruct-v0.1', tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=32768, download_dir=None, load_format=auto, tensor_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=True, kv_cache_dtype=auto, device_config=cuda, seed=0)
INFO 03-14 19:39:54 weight_utils.py:163] Using model weights format ['*.safetensors']
INFO 03-14 19:40:18 llm_engine.py:357] # GPU blocks: 5835, # CPU blocks: 2048


In [4]:
# Instruction text placed in front of every product description in the prompt.
instruction = (
    "Give a concise summary for the below description of the product "
    'in the form {"summary": ...}.\n\n'
    "Product Info:\n"
)

In [5]:
def parse(path):
    """Yield one decoded JSON object per line of a gzip-compressed file.

    Args:
        path: Path to a gzip file holding one JSON document per line.

    Yields:
        The value returned by ``json.loads`` for each non-blank line.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # The context manager guarantees the gzip handle is released when the
    # generator is exhausted, closed early, or garbage collected.
    with gzip.open(path, "rb") as g:
        for line in g:
            # Blank lines (e.g. a trailing newline) carry no record.
            if not line.strip():
                continue
            yield json.loads(line)

In [6]:
# Input directories, one per data split; every file inside is read with parse().
train_dir = "data/raw_compressed/metadata/train/"
test_dir = "data/raw_compressed/metadata/test/"

In [7]:
def get_info_from_sample(sample: dict):
    """Build the product-info text for one metadata record.

    Args:
        sample: Metadata record; only its ``"description"`` and ``"features"``
            entries are read. Either may be missing, a string, or a list of
            strings.

    Returns:
        * ``"Description:\\n<description>\\nFeatures:<bullets>"`` when both are
          present,
        * the description alone or the bulleted features alone when only one
          is present,
        * ``None`` when neither carries any text.
    """
    description = sample.get("description", None)
    if isinstance(description, list):
        description = " ".join(description)

    features = sample.get("features", None)
    if isinstance(features, list):
        # Drop empty entries first: an empty list (or a list of empty strings)
        # must count as "no features" instead of yielding a dangling "\n- ".
        features = [feature for feature in features if feature]
        features = "\n- " + "\n- ".join(features) if features else None

    if description and features:
        return f"""Description:\n{description}\nFeatures:{features}"""
    elif description:
        return description
    elif features:
        return features
    else:
        return None

In [8]:
# Number of product descriptions sent to the model in one llm.generate call.
batch_size = 16

In [9]:
# test = """<s>[INST] Give a concise summary for the below description of the product in the form {"summary": ...}.\n\nProduct Info:\n<div class="aplus"> <br>Amazon.com Gift Cards are the perfect way to give them exactly what they\'re hoping for--even if you don\'t know what it is. Amazon.com Gift Cards are redeemable for millions of items storewide, and never expire.<br /><br /> Box of 50 physical Amazon.com Gift Cards. Each card is attached to a folded greeting card and is packed in an individual 5 x 7 inch unsealed envelope. Also available to purchase as <a href="http://www.amazon.com/gp/product/B001H53QDK/ref=g_gc_asin_dp_sgl">individual gift cards</a>.<br/><br/> <a href="http://www.amazon.com/gp/product/B00067L6TQ/ref=g_gc_asin_dp_gclp">Check out our customized E-mail, Print at Home, and Mail gift card options</a>. Need a gift card in a hurry? Buy an Amazon.com Gift Card at <a href="http://www.amazon.com/gp/feature.html?docId=1000465651">a store near you</a>.<br /><br /> Amazon.com Gift Cards are a great way to motivate, reward, and appreciate your employees or customers. Order large quantities of bulk cards or codes, custom denominations, and custom Gift Card messaging through the <a href="http://www.amazon.com/gp/browse.html/ref=g_gc_asin_dp_corp?node=165034011">Amazon.com Corporate Gift Card Program</a>. Advertising the use of Amazon.com Gift Cards as an incentive or reward requires a Corporate Gift Card agreement. <a href="http://www.amazon.com/gp/browse.html/ref=g_gc_asin_dp_corp?node=165034011">Learn more</a>.<br /><br /> See Amazon.com Gift Card <a href="http://www.amazon.com/gp/browse.html/ref=g_gc_asin_dp_legal?node=3122091"> Terms and Conditions</a>.<br> </div> [/INST]"""

In [10]:
# print(test.split("Product Info:")[1].strip().removesuffix(" [/INST]"))

In [11]:
# batch = []
# for sample in parse("data/raw_compressed/metadata/test/meta_Gift_Cards.json.gz"):
#     info = get_info_from_sample(sample)
#     gct+=1
#     if info:
#         batch.append(info)
#         ct+=1
#     if len(batch)==4:
#         print(batch)
#         batch=[]
        

In [12]:
def _label_batch(batch, out_file):
    """Summarise one batch of product-info strings and append the results.

    Args:
        batch: List of product-info strings (as built by get_info_from_sample).
        out_file: Open jsonlines writer that receives one
            ``{"product_info": ..., "summary": ...}`` record per success.

    Returns:
        ``(written, failed)``: number of records written, and number of
        failures (unparseable generations plus one if the batch as a whole
        raised).
    """
    written = 0
    failed = 0
    try:
        # NOTE(review): the tokenizer may prepend its own BOS token in addition
        # to the literal "<s>" here - confirm before changing; the prompt is
        # kept as-is so new labels stay consistent with existing ones.
        input_batch = ["<s>[INST] "+instruction+info_+" [/INST]" for info_ in batch]
        outputs = llm.generate(input_batch, sampling_params, use_tqdm=False)

        for output in outputs:
            # maxsplit=1: only the first marker belongs to the instruction, so
            # product text that itself contains "Product Info:" is kept whole.
            prompt_ = output.prompt.split("Product Info:", 1)[1].strip().removesuffix(" [/INST]")
            generated_text = output.outputs[0].text
            try:
                summary = json.loads(generated_text.strip())["summary"]
                out_file.write({"product_info": prompt_.strip(), "summary": summary})
                written += 1
            except Exception:
                # Best effort: the model did not return the requested
                # {"summary": ...} JSON (or the record could not be written);
                # count it and move on to the next output.
                failed += 1
    except Exception as e:
        # A failure of the whole generation call abandons the rest of this
        # batch but must not stop the labelling run.
        print(e)
        failed += 1
    return written, failed


def get_labelled_data(split_dir, out_dir, skiplines=0, filename=None):
    """Generate summaries for every metadata file in ``split_dir``.

    For each input file, records whose product info has more than 50
    whitespace-separated words are sent to the model in batches of
    ``batch_size``; successful summaries are appended to
    ``<out_dir>/<input name up to the first dot>.jsonl``.

    Args:
        split_dir: Directory containing the gzip-compressed input files.
        out_dir: Directory for the output ``.jsonl`` files (created if
            missing). Output files are opened in append mode.
        skiplines: Resume point. In the file named ``filename``, records whose
            1-based position is below this value are skipped.
        filename: Name of the input file ``skiplines`` applies to. Other files
            are always processed from their first record.

    Returns:
        Total number of summaries written across all processed files
        (0 when ``split_dir`` contains no files).
    """
    os.makedirs(out_dir, exist_ok=True)
    total_pred_ct = 0

    # Sorted so the processing order is reproducible between runs.
    for file in sorted(os.listdir(split_dir)):
        # The batch is per file so leftovers never spill into another file's
        # output.
        batch = []
        ct = 0
        err_ct = 0
        pred_ct = 0
        input_file_path = os.path.join(split_dir, file)
        out_file_path = os.path.join(out_dir, f'{file.split(".")[0]}.jsonl')

        with jsonlines.open(out_file_path, mode="a") as out_file:
            print(f"Processing file - {file}")
            for sample in parse(input_file_path):
                ct+=1
                if filename and filename==file and ct<skiplines:
                    continue
                if ct%1000 == 0: print(f"\n========\nProcessed {ct} products so far. Additional/Total labelled - {pred_ct} samples.\n========")

                info = get_info_from_sample(sample)
                if info and len(info.split())>50:
                    batch.append(info)

                if len(batch) >= batch_size:
                    written, failed = _label_batch(batch, out_file)
                    pred_ct += written
                    err_ct += failed
                    batch = []

            # Flush the final partial batch; without this the last
            # (< batch_size) records of every file would never be labelled.
            if batch:
                written, failed = _label_batch(batch, out_file)
                pred_ct += written
                err_ct += failed

        print(f"Finished file - {file}: read {ct} products, labelled {pred_ct}, errors {err_ct}.")
        total_pred_ct += pred_ct

    return total_pred_ct

In [13]:
# get_labelled_data(test_dir, "data/labelled/test.jsonl")

In [None]:
# Resume the train split: in the named file, records numbered below 153000 are skipped.
get_labelled_data(
    train_dir,
    "data/labelled/metadata/train",
    skiplines=153000,
    filename="meta_Cell_Phones_and_Accessories.json.gz",
)

Processing file - meta_Cell_Phones_and_Accessories.json.gz

Processed 153000 products so far. Additional/Total labelled - 0 samples.

Processed 154000 products so far. Additional/Total labelled - 410 samples.

Processed 155000 products so far. Additional/Total labelled - 784 samples.

Processed 156000 products so far. Additional/Total labelled - 1212 samples.

Processed 157000 products so far. Additional/Total labelled - 1640 samples.

Processed 158000 products so far. Additional/Total labelled - 2077 samples.

Processed 159000 products so far. Additional/Total labelled - 2546 samples.

Processed 160000 products so far. Additional/Total labelled - 3033 samples.

Processed 161000 products so far. Additional/Total labelled - 3479 samples.

Processed 162000 products so far. Additional/Total labelled - 3930 samples.

Processed 163000 products so far. Additional/Total labelled - 4255 samples.

Processed 164000 products so far. Additional/Total labelled - 4701 samples.

Processed 165000 prod

In [None]:
#SBATCH --job-name=create_labels_nlp_project
#SBATCH --partition=gpu
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=4
#SBATCH --mem=16G
#SBATCH --gres=gpu:v100-sxm2:1
#SBATCH --time=6:00:00
#SBATCH -o %J.log
#SBATCH -e %J.log