In [16]:
import os
import random
import time
import random
import re
from collections import Counter

from functools import wraps

def timed(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        start = time.time()
        result = func(*args, **kwargs)
        elapsed = time.time() - start
        print(f"[{func.__name__}] took {elapsed:.2f} seconds.")
        return result
    return wrapper


# TODO: move to utils

# File paths explicitly defined
DATA_DIR = "../data"
ORIGINAL_FILE = os.path.join(DATA_DIR, "TinyStoriesV2-GPT4-train.txt")
SAMPLED_FILE = os.path.join(DATA_DIR, "sampled_50k.txt")

@timed
def load_dataset(filepath):
    start_time = time.time()
    with open(filepath, "r", encoding="utf-8", errors="replace") as f:
        docs = f.read().split("<|endoftext|>")
    docs = [doc.strip() for doc in docs if doc.strip()]
    elapsed_time = time.time() - start_time
    print(f"Loaded {len(docs)} documents in {elapsed_time:0.2f} seconds.")
    return docs

# Saves memory when loading large datasets
@timed
def load_dataset_stream(filepath, split_token="<|endoftext|>"):
    docs = []
    current_doc = []
    split_token = split_token.strip()

    with open(filepath, "r", encoding="utf-8", errors="replace") as f:
        for line in f:
            line = line.strip()
            if line == split_token:
                if current_doc:
                    docs.append("\n".join(current_doc))
                    current_doc = []
            else:
                current_doc.append(line)
    if current_doc:
        docs.append("\n".join(current_doc))

    return docs

@timed
def explore_dataset(docs, num_samples=5):
    print(f"Total number of documents: {len(docs)}\n")

    # Basic stats (length distribution explicitly)
    lengths = [len(doc.split()) for doc in docs]
    avg_len = sum(lengths) / len(lengths)
    max_len = max(lengths)
    min_len = min(lengths)

    print("Explicit Document Length Stats (in tokens):")
    print(f"  Avg Length: {avg_len:.2f}")
    print(f"  Max Length: {max_len}")
    print(f"  Min Length: {min_len}\n")

    # Show some random samples explicitly
    samples = random.sample(docs, min(num_samples, len(docs)))
    for i, doc in enumerate(samples, 1):
        print(f"--- Sample Document {i} ---")
        print(doc[:500])  # Print first 500 chars explicitly for clarity
        print("--------------------------\n")

    # Count wrong (non-decodable UTF-8) symbols (replacement character "�")
    wrong_char_pattern = re.compile(r'�')
    total_wrong_chars = sum(len(wrong_char_pattern.findall(doc)) for doc in docs)
    print(f"Total number of non-decodable UTF-8 symbols (�): {total_wrong_chars}\n")

    # Explicitly show context around problematic characters
    docs_with_wrong_chars = [doc for doc in docs if '�' in doc]
    if docs_with_wrong_chars:
        num_context_samples = min(3, len(docs_with_wrong_chars))
        context_samples = random.sample(docs_with_wrong_chars, num_context_samples)

        print(f"Showing context around problematic characters in {num_context_samples} random documents:\n")
        for idx, doc in enumerate(context_samples, 1):
            print(f"--- Context Sample {idx} ---")
            for match in wrong_char_pattern.finditer(doc):
                start, end = match.start(), match.end()
                context = doc[max(0, start - 30):min(len(doc), end + 30)]
                print(f"...{context}...")
            print("--------------------------\n")
    else:
        print("No problematic characters found in documents.\n")


def sample_documents(docs, sample_size=50000, seed=42):
    random.seed(seed)
    if sample_size > len(docs):
        raise ValueError("Sample size greater than available documents")
    sampled_docs = random.sample(docs, sample_size)
    return sampled_docs

def save_sampled_data(docs, filepath):
    with open(filepath, "w", encoding="utf-8") as f:
        for doc in docs:
            f.write(doc.strip() + "\n<|endoftext|>\n")

In [17]:
print("Loading dataset...")
# docs = load_dataset(ORIGINAL_FILE)
docs = load_dataset_stream(ORIGINAL_FILE)
print("Loading finished")

Loading dataset...
[load_dataset_stream] took 3.40 seconds.
Loading finished


In [18]:
print("Exploring dataset...")
explore_dataset(docs, num_samples=2)

Exploring dataset...
Total number of documents: 2717659

Explicit Document Length Stats (in tokens):
  Avg Length: 160.62
  Max Length: 963
  Min Length: 0

--- Sample Document 1 ---
Once upon a time, there was a little boy named Tim. Tim loved playing with his toy car all day. One day, while playing, his mom called him for dinner. Tim did not want to stop playing, but he was very hungry.
Mom made Tim's favorite dinner, hot soup and yummy bread. But Tim did not come to eat right away. He wanted to play just a little more. When he finally came to eat, his soup was cold. Tim did not like cold soup.
Tim learned that when mom calls for dinner, he should stop playing and come eat
--------------------------

--- Sample Document 2 ---
One day, a pale boy named Tim went to the park. He saw a tap near the sandbox. He wanted to play with the water, so he turned the tap on.
A girl named Sue saw Tim and said, "No, Tim! We must prevent the water from going everywhere!" Sue ran to the tap and turned

In [3]:
print(f"\nSampling {50000} documents explicitly...")
sampled_docs = sample_documents(docs, sample_size=50000)

print(f"Saving sampled data to '{SAMPLED_FILE}'")
save_sampled_data(sampled_docs, SAMPLED_FILE)
print("Sampled data saved clearly!")


Sampling 50000 documents explicitly...
Saving sampled data to '../data/sampled_50k.txt'
Sampled data saved clearly!
