### Setup

In [4]:
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
import os
from datasets import load_dataset
import time
import csv


In [2]:
model_id = "nvidia/Llama3-ChatQA-1.5-8B"

In [5]:
tokenizer = AutoTokenizer.from_pretrained(model_id)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


### Loading dataset

- All data have been downloaded, we are just loading it from disk to memory

In [8]:
dataset_name = dataset_names[0] #choose
dataset = load_dataset("nvidia/ChatRAG-Bench", dataset_name)

In [9]:
def reformat_question(turn_list, dataset_name):

    ## only take the lastest 7 turns
    turn_list = turn_list[-7:]
    assert turn_list[-1]['role'] == 'user'

    long_answer_dataset_list = ["doc2dial", "quac", "qrecc", "inscit", "doqa_movies", "doqa_travel", "doqa_cooking", "hybridial", "convfinqa"]
    long_and_short_dataset_list = ["topiocqa"]
    entity_dataset_list = ["sqa"]
    short_dataset_list = ["coqa"]

    if dataset_name in long_answer_dataset_list:
        for item in turn_list:
            if item['role'] == 'user':
                ## only needs to add it on the first user turn
                item['content'] = 'Please give a full and complete answer for the question. ' + item['content']
                break
    
    elif dataset_name in long_and_short_dataset_list:
        turn_list[-1]['content'] = "Answer the following question with a short span, or a full and complete answer. " + turn_list[-1]['content']

    elif dataset_name in entity_dataset_list:
        turn_list[-1]['content'] = "Answer the following question with one or a list of items. " + turn_list[-1]['content']

    elif dataset_name in short_dataset_list:
        turn_list[-1]['content'] = "Answer the following question with a short span. The answer needs to be just in a few words. " + turn_list[-1]['content']

    else:
        raise Exception("please input a correct dataset name!")
    
    question = ""
    for item in turn_list:
        if item["role"] == "user":
            question += "User: " + item["content"] + "\n\n"
        else:
            assert item["role"] == "assistant"
            question += "Assistant: " + item["content"] + "\n\n"
    
    question += "Assistant:"
    
    return question


def get_inputs(example, dataset_name, tokenizer, num_ctx, max_output_len, max_seq_length=4096):
    system = "System: This is a chat between a user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions based on the context. The assistant should also indicate when the answer cannot be found in the context."

    turn_list = example['messages']
    question_formatted = reformat_question(turn_list, dataset_name)

    ctx_list = ["title: " + ctx["title"] + ", source: " + ctx["text"] for ctx in example['ctxs'][:num_ctx]]
    context = "\n\n".join(ctx_list)

    context_tokens = tokenizer.encode(context)
    question_tokens = tokenizer.encode(question_formatted)
    system_tokens = tokenizer.encode(system)

    if len(context_tokens) + len(question_tokens) + len(system_tokens) + max_output_len >= max_seq_length:
        context_tokens = context_tokens[:max_seq_length - max_output_len - len(question_tokens) - len(system_tokens)]
        context = tokenizer.decode(context_tokens, skip_special_tokens=True)

    model_input = system + "\n\n" + context + "\n\n" + question_formatted

    return {"model_input": model_input}

def process_dataset(dataset, dataset_name, tokenizer, num_ctx, max_output_len, max_seq_length=4096):
    processed_dataset = dataset.map(
        lambda example: get_inputs(example, dataset_name, tokenizer, num_ctx, max_output_len, max_seq_length),
        batched=False,
        remove_columns=dataset.column_names,
    )
    return processed_dataset

In [10]:

num_ctx = 5  # Specify the number of contexts to use
max_output_len = 64  # Specify the maximum output length
max_seq_length = 64  # Specify the maximum sequence length

processed_dataset = process_dataset(dataset['dev'], dataset_name, tokenizer, num_ctx, max_output_len, max_seq_length)

Map: 100%|██████████| 7983/7983 [00:13<00:00, 612.36 examples/s]


### Sampling Inputs

In [11]:
print(processed_dataset[1]['model_input'])

System: This is a chat between a user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions based on the context. The assistant should also indicate when the answer cannot be found in the context.

title:, source: Once upon a time, in a barn near a farm house, there lived a little white kitten named Cotton. Cotton lived high up in a nice warm place above the barn where all of the farmer's horses slept. But Cotton wasn't alone in her little home above the barn, oh no. She shared her hay bed with her mommy and 5 other sisters. All of her sisters were cute and fluffy, like Cotton. But she was the only white one in the bunch. The rest of her sisters were all orange with beautiful white tiger stripes like Cotton's mommy. Being different made Cotton quite sad. She often wished she looked like the rest of her family. So one day, when Cotton found a can of the old farmer's orange paint, she used it to paint herself like them

In [12]:
print(processed_dataset[-1]['model_input'])

System: This is a chat between a user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions based on the context. The assistant should also indicate when the answer cannot be found in the context.

title:, source: Las Vegas (, Spanish for "The Meadows"), officially the City of Las Vegas and often known simply as Vegas, is the 28th-most populated city in the United States, the most populated city in the state of Nevada, and the county seat of Clark County. The city anchors the Las Vegas Valley metropolitan area and is the largest city within the greater Mojave Desert. Las Vegas is an internationally renowned major resort city, known primarily for its gambling, shopping, fine dining, entertainment, and nightlife. The Las Vegas Valley as a whole serves as the leading financial, commercial, and cultural center

User: Which state is it in?

Assistant: Nevada

User: Is it located in a desert?

Assistant: Yes

User: what is

In [13]:
#processed_dataset is python dict

In [14]:
processed_dataset

Dataset({
    features: ['model_input'],
    num_rows: 7983
})

In [15]:
#model is already downloaded as we have added it the image creation process, 
#the time taken to run the cell is due to loading of model from disk to memory

In [16]:
model_vllm = LLM(model_id)
sampling_params = SamplingParams(temperature=0, top_k=1, max_tokens= 64)

INFO 06-05 20:02:04 llm_engine.py:73] Initializing an LLM engine with config: model='nvidia/Llama3-ChatQA-1.5-8B', tokenizer='nvidia/Llama3-ChatQA-1.5-8B', tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=8192, download_dir=None, load_format=auto, tensor_parallel_size=1, quantization=None, seed=0)
INFO 06-05 20:02:04 tokenizer.py:32] For some LLaMA V1 models, initializing the fast tokenizer may take a long time. To reduce the initialization time, consider using 'hf-internal-testing/llama-tokenizer' instead of the original tokenizer.


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


INFO 06-05 20:02:31 llm_engine.py:222] # GPU blocks: 27144, # CPU blocks: 2048


In [18]:
## This changes the GPU support to 8
tik = time.time()
## bos token for llama-3
bos_token = "<|begin_of_text|>"
output_list = []
prompts = []
for example in processed_dataset['model_input']:
    prompt = bos_token + example
    prompts.append(prompt)

outputs = model_vllm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    #print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

    # print("generated_text:", generated_text)
    output_list.append(generated_text)
tok = time.time()
print(tok - tik)

Processed prompts: 100%|██████████| 7983/7983 [01:33<00:00, 85.22it/s] 

98.32302021980286





In [19]:
output_datapath = os.path.join(f"{dataset_name}_output.csv")

In [20]:
print("writing to %s" % output_datapath)
with open(output_datapath, "w", newline="") as csvfile:
    csv_writer = csv.writer(csvfile)
    csv_writer.writerow(["Generated Text"])  # Write the header row
    for output in output_list:
        csv_writer.writerow([output])  # Write each generated text as a row in the CSV

writing to coqa_output.csv


In [21]:
!ls

__pycache__  chatrag.ipynb  coqa_output.csv  rag_bench.py


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [22]:
!nvidia-smi

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Wed Jun  5 20:05:56 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.129.03             Driver Version: 535.129.03   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA H100 80GB HBM3          On  | 00000000:8B:00.0 Off |                    0 |
| N/A   39C    P0             113W / 700W |  72973MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
                                                                    