In [1]:
import datasets # type: ignore

# Identifier of the conversations dataset handed to `datasets.load_dataset`.
input_dataset = 'lmsys/chatbot_arena_conversations'

# Only the 'train' split is kept; later cells chunk this object.
dataset = datasets.load_dataset(input_dataset)['train']

# Report how many rows were loaded.
n_rows = len(dataset)
print(n_rows)

33000


In [3]:
from typing import Generator

def fixed_chunker(dataset: datasets.Dataset,
                  chunk_size: int) -> Generator:
    ''' Yield consecutive slices of `dataset` of `chunk_size` rows each.

    The last slice holds whatever rows remain, so it may be shorter than
    `chunk_size`. Each slice is produced with `dataset.select` over a
    contiguous index range, in the dataset's current row order. '''
    total_rows = len(dataset)
    for first_row in range(0, total_rows, chunk_size):
        # Clamp the end of the window so the final slice stops at the last row.
        end_row = min(first_row + chunk_size, total_rows)
        yield dataset.select(range(first_row, end_row))

def add_order_column(dataset: datasets.Dataset) -> datasets.Dataset:
    ''' Return `dataset` mapped so every row carries an 'order' value equal
    to the index `map` passes for that row.

    Sorting by 'order' later restores the row order the dataset had when
    this function was called. '''
    def _index_as_order(example, idx):
        # The row content is not needed; only the index supplied by `map`.
        return {'order': idx}

    return dataset.map(_index_as_order, with_indices=True)
    
def add_question_column(dataset: datasets.Dataset) -> datasets.Dataset:
    ''' Return `dataset` mapped so every row carries a 'question' value: the
    'content' of the first entry in that row's 'conversation_a' list. '''
    def _opening_content(example):
        opening_turn = example['conversation_a'][0]
        return {'question': opening_turn['content']}

    return dataset.map(_opening_content)

# Test:
# Coarse pass: report the size of every 10000-row chunk.
for big_chunk in fixed_chunker(dataset, 10000):
    chunk_rows = len(big_chunk)
    print(f"Chunk len: {chunk_rows}")

# Fine pass: exercise the column helpers and a sort / unsort round trip
# on the first 100-row chunk only.
for big_chunk in fixed_chunker(dataset, 100):
    big_chunk = add_order_column(big_chunk)
    first_row = big_chunk[0]
    print(f"First Order: {first_row['order']}")

    big_chunk = add_question_column(big_chunk)
    first_row = big_chunk[0]
    print(f"First Question: {first_row['question']}")

    # Sort the chunk.
    big_chunk = big_chunk.sort('question')
    first_row = big_chunk[0]
    print(f"First Question after sort: {first_row['question']}")

    # Unsort the chunk.
    big_chunk = big_chunk.sort('order')
    first_row = big_chunk[0]
    print(f"First Question after unsort: {first_row['question']}")

    # One chunk is enough for this check.
    break

Chunk len: 10000
Chunk len: 10000
Chunk len: 10000
Chunk len: 3000
First Order: 0
First Question: What is the difference between OpenCL and CUDA?
First Question after sort: 3,14 + 9855 + 0,000001 = ?
First Question after unsort: What is the difference between OpenCL and CUDA?


In [4]:
import torch
import utils.llm_utils as llm

# Name of the model handed to the project helper below (presumably a
# Hugging Face model id -- confirm against utils.llm_utils).
model_name = 'microsoft/Phi-3-mini-4k-instruct'
# The helper's result is unpacked as a (tokenizer, model) pair; how the two
# are loaded is defined in utils.llm_utils, not in this notebook.
tokenizer, model = llm.load_tokenizer_and_model(model_name)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [12]:
def token_len(tokenizer, text: str) -> int:
    ''' Return the length of `text` after it is wrapped as a single user
    chat message and passed through the tokenizer's chat template.

    Args:
        tokenizer: object exposing `apply_chat_template`, called here with a
            batch of conversations and `add_generation_prompt=True`.
        text: content of the user message.

    Returns:
        `len()` of the templated result for that one conversation, as a plain
        int (not a tensor). '''
    message = [{"role": "user", "content": text}]
    # The template is applied to a batch holding one conversation; [0] picks
    # the result for that conversation.
    tokens = tokenizer.apply_chat_template([message], add_generation_prompt=True)[0]
    return len(tokens)

def add_token_len_column(tokenizer, dataset: datasets.Dataset) -> datasets.Dataset:
    ''' Return `dataset` mapped so every row carries a 'token_len' value:
    `token_len(tokenizer, ...)` applied to that row's 'question'.

    The rows must already have a 'question' field. '''
    def _measure(example):
        question_text = example['question']
        return {'token_len': token_len(tokenizer, question_text)}

    return dataset.map(_measure)

def token_chunker(dataset: datasets.Dataset, 
                  chunk_tokens: int,
                  generated_tokens:int) -> Generator:
    ''' Yield consecutive chunks of examples sized by a token budget.

    Each chunk is sized from its first example: that example's 'token_len'
    plus `generated_tokens` is taken as the per-example cost, and the chunk
    holds as many examples as fit in `chunk_tokens` at that cost (at most
    `chunk_tokens`, not strictly smaller).
    This assumes the examples are sorted descending by 'token_len', so that
    every example in a chunk is padded to the length of the chunk's first
    example; the function does not sort or check this itself.

    A single example whose cost already exceeds `chunk_tokens` is yielded
    alone as an over-budget chunk, so the generator always makes progress.

    Args:
        dataset: rows with an integer 'token_len' field.
        chunk_tokens: token budget per chunk.
        generated_tokens: tokens reserved per example for generation.

    Yields:
        `dataset.select(...)` over a contiguous, non-empty range of rows. '''
    total = len(dataset)
    start = 0
    while start < total:
        remaining = total - start
        example_len = dataset[start]['token_len'] + generated_tokens
        if example_len > 0:
            # At least one example per chunk: without the floor of 1, an
            # example larger than the budget gives 0 here and `start` would
            # never advance (endless stream of empty chunks).
            n_fit = max(1, chunk_tokens // example_len)
        else:
            # Zero cost per example: avoid dividing by zero, take the rest.
            n_fit = remaining
        n_chunk = min(n_fit, remaining)
        yield dataset.select(range(start, start + n_chunk))
        start += n_chunk

# Sanity check of the tokenizer helper on a short prompt.
hello_len = token_len(tokenizer, 'Hello World')
print(f"Tokenized Hello World: {hello_len}")

# Build the helper columns on the first 100-row chunk only, then order it
# longest prompt first, which is the order token_chunker expects.
for big_chunk in fixed_chunker(dataset, 100):
    big_chunk = add_question_column(big_chunk)
    big_chunk = add_token_len_column(tokenizer, big_chunk)
    big_chunk = add_order_column(big_chunk)
    big_chunk = big_chunk.sort('token_len', reverse=True)
    longest_len = big_chunk[0]['token_len']
    print(f"First Question tokenization len: {longest_len}")

    # Split by token budget: 500 tokens per chunk, 50 reserved per example.
    for small_chunk in token_chunker(big_chunk, chunk_tokens=500, generated_tokens=50):
        small_rows = len(small_chunk)
        small_first_len = small_chunk[0]['token_len']
        print(f"Small chunk len: {small_rows} and tokenization len: {small_first_len}")

    # One big chunk is enough for this demonstration.
    break

Tokenized Hello World: 5
First Question tokenization len: 353
Small chunk len: 1 and tokenization len: 353
Small chunk len: 1 and tokenization len: 288
Small chunk len: 3 and tokenization len: 97
Small chunk len: 4 and tokenization len: 54
Small chunk len: 5 and tokenization len: 44
Small chunk len: 6 and tokenization len: 32
Small chunk len: 6 and tokenization len: 26
Small chunk len: 6 and tokenization len: 24
Small chunk len: 7 and tokenization len: 21
Small chunk len: 7 and tokenization len: 18
Small chunk len: 7 and tokenization len: 16
Small chunk len: 7 and tokenization len: 15
Small chunk len: 7 and tokenization len: 14
Small chunk len: 7 and tokenization len: 13
Small chunk len: 8 and tokenization len: 12
Small chunk len: 8 and tokenization len: 10
Small chunk len: 8 and tokenization len: 9
Small chunk len: 2 and tokenization len: 5
