# Average Tokens per Shard by Dataset

Estimates the average number of tokens per parquet shard for each of the four dataset types:
FineWeb-Edu, SmolTalk, UltraChat-Gen, and UltraChat-SFT.

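For each dataset, samples N shards and tokenizes a random subset of rows from each shard.
The mean token count per sampled row is multiplied by the shard's full row count to estimate
that shard's total tokens, and the per-shard estimates are then averaged across the sampled shards.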

In [1]:
import sys, os
sys.path.insert(0, os.path.join(os.getcwd(), '..'))

import random
import numpy as np
import pyarrow.parquet as pq
from transformers import AutoTokenizer

# Seed the stdlib RNG used by every random.sample call in this notebook.
random.seed(42)

# Extra special tokens layered on top of the stock GPT-2 vocabulary:
# a BOS token, a PAD token, and one marker token per chat role.
special_tokens = {
    'bos_token': '<|beginoftext|>',
    'pad_token': '<|pad|>',
    'additional_special_tokens': ['<|user|>', '<|assistant|>', '<|system|>'],
}

tokenizer = AutoTokenizer.from_pretrained('gpt2')
tokenizer.add_special_tokens(special_tokens)

# Report the resulting vocabulary size and the ids of the control tokens.
print(f'Vocab: {len(tokenizer)}')
print(f'BOS={tokenizer.bos_token_id}, EOS={tokenizer.eos_token_id}, PAD={tokenizer.pad_token_id}')

  from .autonotebook import tqdm as notebook_tqdm


Vocab: 50262
BOS=50257, EOS=50256, PAD=50258


## Config

In [2]:
# Number of shards drawn at random from each dataset directory.
N_SHARDS = 10

# Data directories, resolved relative to the notebook's working directory.
BASE_DIR = os.path.join(os.getcwd(), '..', 'data')
fineweb_dir = os.path.join(BASE_DIR, 'base_data')
conv_dir = os.path.join(BASE_DIR, 'conversation_data')

# Names of the .parquet shard files in each directory, in sorted order.
fineweb_files = [name for name in os.listdir(fineweb_dir) if name.endswith('.parquet')]
fineweb_files.sort()
conv_files = [name for name in os.listdir(conv_dir) if name.endswith('.parquet')]
conv_files.sort()

print(f'FineWeb-Edu shards : {len(fineweb_files)} total, sampling {N_SHARDS}')
print(f'Conversation shards: {len(conv_files)} total, sampling {N_SHARDS}')

FineWeb-Edu shards : 1706 total, sampling 10
Conversation shards: 31 total, sampling 10


## Helpers

In [3]:
# Default number of rows tokenized per shard: higher is more accurate, lower is faster.
SAMPLE_PER_SHARD = 4000

# Token id of each chat-role marker, keyed by the role name found in message dicts.
role_tok = {
    role: tokenizer.convert_tokens_to_ids(marker)
    for role, marker in (
        ('user', '<|user|>'),
        ('assistant', '<|assistant|>'),
        ('system', '<|system|>'),
    )
}

def tokens_for_text(text):
    """Return how many tokens `text` encodes to, with no special tokens added."""
    token_ids = tokenizer.encode(text, add_special_tokens=False)
    return len(token_ids)

def tokens_for_conv(messages):
    """Count the tokens of one conversation.

    The count is one BOS plus one EOS, plus for every turn its encoded
    content length and one extra token when the turn's role has a marker
    in `role_tok`.
    """
    BOS_AND_EOS = 2
    per_turn = []
    for turn in messages:
        # A role without a marker token contributes content tokens only.
        marker = 1 if turn['role'] in role_tok else 0
        content_ids = tokenizer.encode(turn['content'], add_special_tokens=False)
        per_turn.append(marker + len(content_ids))
    return BOS_AND_EOS + sum(per_turn)

def estimate_shard_tokens(rows, tok_fn, sample_size=SAMPLE_PER_SHARD):
    """Estimate a shard's total token count from a random sample of its rows.

    Args:
        rows: Sequence of rows (must be acceptable to `random.sample`).
        tok_fn: Callable mapping a single row to its token count.
        sample_size: Maximum number of rows to tokenize; capped at `len(rows)`.

    Returns:
        Tuple `(estimated_total_tokens, mean_tokens_per_row, n_rows)`.
        An empty shard yields `(0.0, 0.0, 0)` rather than NaN.

    Raises:
        ValueError: If `sample_size` is smaller than 1.
    """
    if sample_size < 1:
        # A zero-row sample has no mean; fail loudly instead of returning NaN.
        raise ValueError(f'sample_size must be >= 1, got {sample_size}')

    n_rows = len(rows)
    if n_rows == 0:
        # np.mean([]) would return NaN (with a RuntimeWarning) and poison any
        # average taken over shards downstream.
        return 0.0, 0.0, 0

    sample = random.sample(rows, min(sample_size, n_rows))
    mean_tok = float(np.mean([tok_fn(r) for r in sample]))
    return mean_tok * n_rows, mean_tok, n_rows

## FineWeb-Edu: tokens per shard

In [4]:
# Clamp the sample size so a directory holding fewer than N_SHARDS files does
# not make random.sample raise ValueError.
fw_sample = random.sample(fineweb_files, min(N_SHARDS, len(fineweb_files)))
assert fw_sample, f'no .parquet shards found in {fineweb_dir}'
fw_shard_tokens = []

for fname in fw_sample:
    path = os.path.join(fineweb_dir, fname)
    # Only the 'text' column is tokenized, so read just that column from disk.
    texts = pq.read_table(path, columns=['text'])['text'].to_pylist()
    total, mean, n = estimate_shard_tokens(texts, tokens_for_text)
    fw_shard_tokens.append(total)
    print(f'  {fname}: {n:,} docs, ~{mean:.0f} tok/doc → {total/1e6:.1f}M tokens')

# np.std uses ddof=0 by default, i.e. the population standard deviation of the
# sampled shards.
fw_avg = np.mean(fw_shard_tokens)
fw_std = np.std(fw_shard_tokens)
# Report the number of shards actually measured rather than the configured target.
print(f'\nFineWeb-Edu  avg tokens/shard: {fw_avg/1e6:.1f}M ± {fw_std/1e6:.1f}M  (N={len(fw_shard_tokens)})')

Token indices sequence length is longer than the specified maximum sequence length for this model (1764 > 1024). Running this sequence through the model will result in indexing errors


  shard_01309.parquet: 54,272 docs, ~1014 tok/doc → 55.0M tokens
  shard_00228.parquet: 54,272 docs, ~1019 tok/doc → 55.3M tokens
  shard_00051.parquet: 53,248 docs, ~1000 tok/doc → 53.3M tokens
  shard_01518.parquet: 53,248 docs, ~1040 tok/doc → 55.4M tokens
  shard_00563.parquet: 53,248 docs, ~1055 tok/doc → 56.2M tokens
  shard_00501.parquet: 53,248 docs, ~1053 tok/doc → 56.1M tokens
  shard_00457.parquet: 53,248 docs, ~1019 tok/doc → 54.2M tokens
  shard_00285.parquet: 53,248 docs, ~991 tok/doc → 52.8M tokens
  shard_01508.parquet: 53,248 docs, ~1010 tok/doc → 53.8M tokens
  shard_00209.parquet: 53,248 docs, ~1064 tok/doc → 56.7M tokens

FineWeb-Edu  avg tokens/shard: 54.9M ± 1.2M  (N=10)


## Conversation datasets: tokens per shard

Each conversation shard contains rows from multiple sources. For each sampled shard,
rows are split by source, sampled independently, and scaled to estimate per-source token totals.

In [5]:
SOURCES = ['smoltalk', 'ultrachat_gen', 'ultrachat_sft']
CONV_SAMPLE_PER_SOURCE = 1000  # conversations tokenized per source within each shard

# Clamp the sample size so fewer than N_SHARDS files does not raise ValueError.
cv_sample = random.sample(conv_files, min(N_SHARDS, len(conv_files)))
shard_totals = {src: [] for src in SOURCES}

for fname in cv_sample:
    path = os.path.join(conv_dir, fname)
    # Only these two columns are used below, so read just them from disk.
    rows = pq.read_table(path, columns=['source', 'messages']).to_pylist()

    # Group each row's message list under its source; unknown sources are ignored.
    by_source = {src: [] for src in SOURCES}
    for row in rows:
        src = row['source']
        if src in by_source:
            by_source[src].append(row['messages'])

    print(f'{fname}:')
    per_source_total = {}
    for src in SOURCES:
        msgs = by_source[src]
        if not msgs:
            per_source_total[src] = 0.0
            print(f'  {src:<22}: 0 rows')
            continue
        total, mean, n = estimate_shard_tokens(
            msgs, tokens_for_conv, sample_size=CONV_SAMPLE_PER_SOURCE
        )
        per_source_total[src] = total
        print(f'  {src:<22}: {n:,} convs, ~{mean:.0f} tok/conv → {total/1e6:.1f}M tokens')

    # Record the shard only after every source has been estimated, so that
    # interrupting the cell mid-shard cannot leave the per-source lists with
    # different lengths (which would skew the cross-source comparison).
    for src in SOURCES:
        shard_totals[src].append(per_source_total[src])
    print()

shard_00023.parquet:
  smoltalk              : 34,513 convs, ~866 tok/conv → 29.9M tokens
  ultrachat_gen         : 8,523 convs, ~901 tok/conv → 7.7M tokens
  ultrachat_sft         : 6,964 convs, ~1260 tok/conv → 8.8M tokens

shard_00008.parquet:
  smoltalk              : 34,734 convs, ~893 tok/conv → 31.0M tokens
  ultrachat_gen         : 8,457 convs, ~955 tok/conv → 8.1M tokens
  ultrachat_sft         : 6,809 convs, ~1223 tok/conv → 8.3M tokens

shard_00027.parquet:
  smoltalk              : 34,494 convs, ~962 tok/conv → 33.2M tokens
  ultrachat_gen         : 8,610 convs, ~934 tok/conv → 8.0M tokens
  ultrachat_sft         : 6,896 convs, ~1219 tok/conv → 8.4M tokens

shard_00016.parquet:
  smoltalk              : 34,619 convs, ~929 tok/conv → 32.1M tokens
  ultrachat_gen         : 8,424 convs, ~893 tok/conv → 7.5M tokens
  ultrachat_sft         : 6,957 convs, ~1212 tok/conv → 8.4M tokens

shard_00018.parquet:
  smoltalk              : 34,459 convs, ~886 tok/conv → 30.5M tokens


KeyboardInterrupt: 

In [12]:
# Debug peek at a single conversation. `msgs` is not defined in this cell: it is
# whatever the last iteration of the per-source loop in the conversation cell
# left bound, so this only works after that cell has run.
msgs[0]
# tokens_for_conv

[{'content': 'In what ways do contemporary street art and public installation art challenge traditional notions of art and aesthetics, and how do they engage with and reflect the cultural and social climates of the communities where they are created? Additionally, how do factors such as accessibility, public space, and the use of non-traditional materials play a role in the creation and reception of these art forms?',
  'role': 'user'},
 {'content': 'Contemporary street art and public installation art challenge traditional notions of art and aesthetics by rejecting the idea that art is only meant for the elite and should be displayed in galleries and museums. Instead, these art forms are accessible to anyone who passes by them on the street or in public spaces. Additionally, they often incorporate non-traditional materials, such as spray paint, stencils, and found objects.\n\nThese art forms engage with and reflect the cultural and social climates of the communities where they are crea

In [None]:
# Quick 5-row sanity check of the sampling estimator on `rows`, the last
# conversation shard loaded by the loop above. The earlier version of this cell
# referenced `tok_fn`, which only exists as a parameter of
# estimate_shard_tokens, and passed whole row dicts where tokens_for_conv
# expects a message list. NOTE(review): assumes the intent was a
# whole-shard conversation estimate — confirm or delete this scratch cell.
estimate_shard_tokens([row['messages'] for row in rows], tokens_for_conv, sample_size=5)

## Summary

In [None]:
# Summary table. Value columns are formatted as (width - 1) digits plus a
# trailing 'M' so they line up under the 18- and 8-character headers.
print(f'{"Dataset":<22}  {"Avg tokens/shard":>18}  {"Std":>8}  {"N shards":>8}')
print('-' * 62)
# N is the number of shards actually measured, not the configured target.
print(f'{"fineweb_edu":<22}  {fw_avg/1e6:>17.1f}M  {fw_std/1e6:>7.1f}M  {len(fw_shard_tokens):>8}')
for src in SOURCES:
    vals = np.array(shard_totals[src], dtype=float)
    # Count every shard that feeds the mean/std below (including shards where
    # the source had zero rows), so N describes the same sample as the stats.
    n = int(vals.size)
    if n == 0:
        # Nothing was measured for this source; avoid printing NaN.
        print(f'{src:<22}  {"n/a":>18}  {"n/a":>8}  {n:>8}')
        continue
    avg = vals.mean()
    std = vals.std()
    print(f'{src:<22}  {avg/1e6:>17.1f}M  {std/1e6:>7.1f}M  {n:>8}')

Dataset                   Avg tokens/shard       Std  N shards
--------------------------------------------------------------
fineweb_edu                        54.9M    1.2M        10
smoltalk                           31.7M    1.1M        10
ultrachat_gen                       7.8M    0.2M        10
ultrachat_sft                       8.5M    0.2M        10
