### BabyLM data preprocessing for GPT-2 pretraining

In [2]:
import sys
import os
import glob

project_root = os.path.abspath('..')
if project_root not in sys.path:
    sys.path.insert(0, project_root)
    
training_data_folder_path = "../data/train_10M"

output_file = "../data/training_data_cleaned.txt"

file_paths = glob.glob(os.path.join(training_data_folder_path, "*"))
contents: list[str] = []
for file_path in file_paths:
        with open(file_path, encoding="utf-8") as f:
            contents.append(f.read())
            
print(f"number of documents: {len(contents)}")

size = 0
for doc in contents:
    print(f"document length (chars): {len(doc)}")
    size += sys.getsizeof(doc)
    
print(f"Memory usage: {size / (1024 * 1024):.2f} MB")

number of documents: 6
document length (chars): 8411630
document length (chars): 13910986
document length (chars): 4883879
document length (chars): 719322
document length (chars): 10806305
document length (chars): 15482927
Memory usage: 118.77 MB


#### Text Cleaning

In [3]:
import re

# Updated patterns including removal of lines with excessive spacing
patterns = [
    r'^= = = .* = = =$',        # Remove lines like "= = = Corcelles-près-Payerne = = ="
    r'^-\s*',                   # Remove "- " from the beginning of lines
    r'^[A-Z]:\s*',             # Remove speaker labels like "A:", "B:" (with space or tab)
    r'^\*[^*]+\*$',            # Remove chapter headers like "*CHAPTER XI*"
    r'^\*\w+:\s*',             # Remove speaker labels like "*CHI:" (with space or tab)
    r'\[.*?\]',                # Remove any content between square brackets [anything]
    r'^.*\s{4,}.*$',           # Remove lines with 4 or more consecutive spaces
]

# Function to remove short lines (less than 10 characters of actual content)
def remove_short_lines(text: str, min_length: int = 10) -> str:
    """Remove lines that have less than min_length characters of actual content."""
    lines = text.split('\n')
    filtered_lines = []
    
    for line in lines:
        # Strip whitespace and check actual content length
        content = line.strip()
        if len(content) >= min_length:
            filtered_lines.append(line)
    
    return '\n'.join(filtered_lines)

# Apply cleaning to each document and store the results
cleaned_contents = []
for i, doc in enumerate(contents):
    print(f"Cleaning document {i+1}/{len(contents)}...")
    cleaned_doc = doc
    
    # Apply regex patterns
    for pattern in patterns:
        # Count matches before cleaning
        matches_before = len(re.findall(pattern, cleaned_doc, flags=re.MULTILINE))
        cleaned_doc = re.sub(pattern, '', cleaned_doc, flags=re.MULTILINE)
        matches_after = len(re.findall(pattern, cleaned_doc, flags=re.MULTILINE))
        
        if matches_before > 0:
            print(f"  Pattern '{pattern}': removed {matches_before} matches")
    
    # Remove short lines
    lines_before = len(cleaned_doc.split('\n'))
    cleaned_doc = remove_short_lines(cleaned_doc, min_length=10)
    lines_after = len(cleaned_doc.split('\n'))
    short_lines_removed = lines_before - lines_after
    
    if short_lines_removed > 0:
        print(f"  Short lines: removed {short_lines_removed} lines shorter than 10 characters")
    
    cleaned_contents.append(cleaned_doc)
    print(f"Document {i+1} cleaned. Original: {len(doc)} chars, Cleaned: {len(cleaned_doc)} chars")

# Update the contents list with cleaned data
contents = cleaned_contents

Cleaning document 1/6...
  Pattern '^= = = .* = = =$': removed 12284 matches
  Pattern '^-\s*': removed 26 matches
  Pattern '\[.*?\]': removed 179 matches
  Pattern '^.*\s{4,}.*$': removed 945 matches
  Short lines: removed 26421 lines shorter than 10 characters
Document 1 cleaned. Original: 8411630 chars, Cleaned: 7582309 chars
Cleaning document 2/6...
  Pattern '^= = = .* = = =$': removed 50 matches
  Pattern '^-\s*': removed 102 matches
  Pattern '^\*[^*]+\*$': removed 55 matches
  Pattern '\[.*?\]': removed 2203 matches
  Pattern '^.*\s{4,}.*$': removed 1233 matches
  Short lines: removed 2638 lines shorter than 10 characters
Document 2 cleaned. Original: 13910986 chars, Cleaned: 13697692 chars
Cleaning document 3/6...
  Pattern '^.*\s{4,}.*$': removed 268 matches
  Short lines: removed 22086 lines shorter than 10 characters
Document 3 cleaned. Original: 4883879 chars, Cleaned: 4725138 chars
Cleaning document 4/6...
  Pattern '^[A-Z]:\s*': removed 18000 matches
  Short lines: remo

In [4]:
# Write the cleaned text to a new file
final_text = ""
for doc in contents:
    final_text += doc

with open(output_file, 'w', encoding='utf-8') as f:
    f.write(final_text)

In [14]:
def add_specialized_tokens_to_sentences(text: str) -> str:
    """
    Add specialized tokens for each line (sentence) in the text.
    Each newline is treated as a sentence boundary.
    Only adds <|endoftext|> token at the end of each sentence.
    
    Args:
        text: The input text to process
        
    Returns:
        Text with <|endoftext|> tokens added for each sentence
    """
    # Define specialized token (following GPT-2 style)
    END_OF_TEXT = "<|endoftext|>"
    
    # Split text into lines (each line is a sentence)
    lines = text.split('\n')
    processed_sentences = []
    
    for line in lines:
        line = line.strip()
        if line:  # Only process non-empty lines
            # Add only end of text token
            sentence_with_tokens = f"{line} {END_OF_TEXT}"
            processed_sentences.append(sentence_with_tokens)
    
    return '\n'.join(processed_sentences)

# Apply specialized tokens to your final_text
print("Adding <|endoftext|> tokens to final_text (treating each line as a sentence)...")
final_text_with_tokens = add_specialized_tokens_to_sentences(final_text)

# Save the processed text
output_file_with_tokens = "../data/training_data_with_special_tokens.txt"
with open(output_file_with_tokens, 'w', encoding='utf-8') as f:
    f.write(final_text_with_tokens)

print(f"Original text length: {len(final_text)} characters")
print(f"Text with special tokens length: {len(final_text_with_tokens)} characters")

# Count lines/sentences
original_lines = len([line for line in final_text.split('\n') if line.strip()])
processed_lines = len([line for line in final_text_with_tokens.split('\n') if line.strip()])
print(f"Number of sentences (lines): {original_lines}")

# Test tokenization with the new text
from src.data.tokenizer import get_tokenizer, text_to_token_ids

tokenizer = get_tokenizer()
tokens_with_special = text_to_token_ids(final_text_with_tokens, tokenizer)

print(f"Number of tokens with special tokens: {len(tokens_with_special[0])}")

# Show a sample of the processed text
sample_lines = final_text_with_tokens.split('\n')[:5]  # First 5 sentences
print("\nSample of first 5 sentences with <|endoftext|> tokens:")
for i, line in enumerate(sample_lines, 1):
    print(f"{i}. {line}")

# Check how the special tokens are tokenized
special_token_sample = "Hello world. <|endoftext|>"
special_tokens = text_to_token_ids(special_token_sample, tokenizer)
print(f"\nSpecial token sample: '{special_token_sample}'")
print(f"Tokenized as: {special_tokens[0].tolist()}")
print(f"Decoded back: '{tokenizer.decode(special_tokens[0].tolist())}'")

Adding <|endoftext|> tokens to final_text (treating each line as a sentence)...
Original text length: 45891376 characters
Text with special tokens length: 57133118 characters
Number of sentences (lines): 815911
Number of tokens with special tokens: 13784480

Sample of first 5 sentences with <|endoftext|> tokens:
1. The usage of .gov as a gTLD controlled only by the U.S. is controversial, as some people believe this to be an example of arrogance by the U.S. – such views hold that usage of .fed.us or a new second-level domain of .gov.us would be more suitable. Others believe that U.S. control of .gov is a natural result of the fact that the U.S. federal government was responsible for the initial creation of the Internet and its first user. <|endoftext|>
2. Corcelles-près-Payerne is a municipality in the Broye-Vully district in the canton of Vaud in Switzerland. <|endoftext|>
3. Cremin was a municipality of the district of Broye-Vully in the canton of Vaud in Switzerland. On 1 January 201

In [15]:
# tokenize the data for some analysis
from src.data.tokenizer import get_tokenizer, text_to_token_ids

tokenizer = get_tokenizer()

tokens = text_to_token_ids(final_text_with_tokens, tokenizer)

print(f"Number of words: {len(final_text_with_tokens.split())}")
print(f"Number of tokens: {len(tokens[0])}")

Number of words: 9297503
Number of tokens: 13784480


In [16]:
import numpy as np

token_array = tokens[0].numpy()

unique_tokens, counts = np.unique(token_array, return_counts=True)
print(f"Unique tokens: {len(unique_tokens)}")
mft = unique_tokens[np.argmax(counts)]
lft = unique_tokens[np.argmin(counts)]
print(f"Most frequent token: {mft}:{tokenizer.decode([mft])} (appears {counts.max()} times)")
print(f"Least frequent token: {lft}:{tokenizer.decode([lft])} (appears {counts.min()} times)")

Unique tokens: 45201
Most frequent token: 220:  (appears 852775 times)
Least frequent token: 98:� (appears 1 times)


In [17]:
from collections import Counter
from src.data.tokenizer import token_ids_to_text
import torch

# Get top 20 most frequent tokens
token_counter = Counter(token_array)
top_tokens = token_counter.most_common(20)

print("Top 20 most frequent tokens:")
for i, (token_id, count) in enumerate(top_tokens, 1):
    token_text = token_ids_to_text(torch.tensor([[token_id]]), tokenizer)
    percentage = (count / len(token_array)) * 100
    print(f"{i:2d}. Token ID: {token_id:5d} | Text: '{token_text}' | Count: {count:8d} | Percentage: {percentage:5.2f}%")

Top 20 most frequent tokens:
 1. Token ID:   220 | Text: ' ' | Count:   852775 | Percentage:  6.19%
 2. Token ID: 50256 | Text: '<|endoftext|>' | Count:   815911 | Percentage:  5.92%
 3. Token ID:   198 | Text: '
' | Count:   815910 | Percentage:  5.92%
 4. Token ID:    13 | Text: '.' | Count:   664473 | Percentage:  4.82%
 5. Token ID:    11 | Text: ',' | Count:   462790 | Percentage:  3.36%
 6. Token ID:   262 | Text: ' the' | Count:   354518 | Percentage:  2.57%
 7. Token ID:   284 | Text: ' to' | Count:   189168 | Percentage:  1.37%
 8. Token ID:   257 | Text: ' a' | Count:   178461 | Percentage:  1.29%
 9. Token ID:   345 | Text: ' you' | Count:   176208 | Percentage:  1.28%
10. Token ID:   290 | Text: ' and' | Count:   171039 | Percentage:  1.24%
11. Token ID:    30 | Text: '?' | Count:   162007 | Percentage:  1.18%
12. Token ID:   286 | Text: ' of' | Count:   156721 | Percentage:  1.14%
13. Token ID:   338 | Text: ''s' | Count:   139308 | Percentage:  1.01%
14. Token ID:   287 |