### Calculating the average token length of Datasets

In [1]:
from transformers import AutoTokenizer
from datasets import load_dataset

# Configuration
# root_path = "selected_data/filtered-cured-50k_dataset.json"
root_path = "limo_dataset.json"

# Load dataset
data = load_dataset('json', data_files=root_path)
dialogs = data['train']

# sorted_labels = ['dolly', 'flan_v2', 'stanford_alpaca', 'wizardlm', 'oasst1']

# target_subset_name = 'wizardlm'
# dialogs = dialogs.filter(lambda sample: sample['dataset'] == target_subset_name)

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B")

# Define data processing function
def process_dialog(batch):
    dialog_prompts = []
    dialog_completions = []
    for dialog in batch['messages']:  # Access the 'messages' field
        dialog_prompt = " ".join([msg['content'] for msg in dialog if msg['role'] == 'user'])
        dialog_completion = " ".join([msg['content'] for msg in dialog if msg['role'] == 'assistant'])
        dialog_prompts.append(dialog_prompt)
        dialog_completions.append(dialog_completion)
    return {"prompt": dialog_prompts, "completion": dialog_completions}

# Process dialogs using datasets.map
processed_data = dialogs.map(
    process_dialog,
    batched=True,
    batch_size=8,  # Adjust the batch size based on memory capacity
    num_proc=4
)

# Extract prompts and completions
prompts = processed_data['prompt']
completions = processed_data['completion']

# Batch tokenize data
prompts_tokenized = tokenizer(prompts, padding=False, truncation=False, return_length=True)
completions_tokenized = tokenizer(completions, padding=False, truncation=False, return_length=True)

# Compute average lengths
prompts_avg_tokens = sum(prompts_tokenized['length']) / len(prompts_tokenized['length'])
completions_avg_tokens = sum(completions_tokenized['length']) / len(completions_tokenized['length'])

overall_avg_tokens = prompts_avg_tokens + completions_avg_tokens

# Print results
print(f"Prompts Average Tokens: {prompts_avg_tokens}")
print(f"Completions Average Tokens: {completions_avg_tokens}")
print(f"Overall Tokens: {overall_avg_tokens}")



Map (num_proc=4):   0%|          | 0/817 [00:00<?, ? examples/s]

Prompts Average Tokens: 103.10771113831089
Completions Average Tokens: 6357.093023255814
Overall Tokens: 6460.200734394125


In [8]:
dialogs

Dataset({
    features: ['dataset', 'id', 'messages'],
    num_rows: 1567
})