In [2]:
!pip install -q datasets transformers torch tqdm

In [3]:
import os
import sys
from datasets import load_dataset, concatenate_datasets
from transformers import AutoTokenizer
import torch
from collections import Counter
import json
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

NUM_WORKERS = 2
os.environ["TOKENIZERS_PARALLELISM"] = "true"

device = torch.device("cpu")
print(f"Using device: {device}")

print(f"Python version: {sys.version}")
print(f"PyTorch version: {torch.__version__}")
print(f"CPU cores available: {os.cpu_count()}")

# Create working directories
os.makedirs('./data', exist_ok=True)
os.makedirs('./processed', exist_ok=True)
os.makedirs('./samples', exist_ok=True)

Using device: cpu
Python version: 3.12.12 (main, Oct 10 2025, 08:52:57) [GCC 11.4.0]
PyTorch version: 2.9.0+cpu
CPU cores available: 8


# Load Datasets

In [13]:
# The sample sizes were tuned iteratively to exceed the 1GB raw text requirement while keeping a balanced distribution across domains.
DATASET_CONFIG = {
    'wikitext_samples': 500000,
    'openwebtext_samples': 150000,
    'cnn_samples': 150000
}

print("Loading WikiText-103...")
wikitext_dataset = load_dataset(
    "Salesforce/wikitext",
    "wikitext-103-raw-v1",
    split="train",
    streaming=True # Streaming mode iterate through samples one at a time without loading the entire corpus into RAM, preventing out-of-memory crashes in Colab.
)

print("Loading OpenWebText...")
openwebtext_dataset = load_dataset(
    "Skylion007/openwebtext",
    split="train",
    streaming=True
)

print("Loading CNN/DailyMail...")
cnn_dataset = load_dataset(
    "abisee/cnn_dailymail",
    "3.0.0",
    split="train",
    streaming=True,
    trust_remote_code=True # The dataset uses a custom loading script
)

print(f"\nDatasets loaded successfully")
print(f"Target samples: {sum(DATASET_CONFIG.values())} total documents")

Loading WikiText-103...
Loading OpenWebText...


Resolving data files:   0%|          | 0/80 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/80 [00:00<?, ?it/s]

`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'abisee/cnn_dailymail' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.
ERROR:datasets.load:`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'abisee/cnn_dailymail' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


Loading CNN/DailyMail...

Datasets loaded successfully
Target samples: 800000 total documents


# Calculate Total Dataset Size

In [14]:
import sys

# The function iterates through each dataset, encoding text to UTF-8 bytes to get an accurate byte count
def estimate_text_size(text):
    return sys.getsizeof(text.encode('utf-8'))

def calculate_dataset_size(dataset, num_samples, text_field='text'):
    total_bytes = 0
    count = 0

    for example in dataset:
        if text_field in example and example[text_field]:
            total_bytes += estimate_text_size(example[text_field])
            count += 1
            if count >= num_samples:
                break

    size_mb = total_bytes / (1024 * 1024)
    size_gb = size_mb / 1024
    return size_mb, size_gb, count

wikitext_mb, wikitext_gb, wikitext_count = calculate_dataset_size(
    wikitext_dataset,
    DATASET_CONFIG['wikitext_samples']
)

openwebtext_mb, openwebtext_gb, openwebtext_count = calculate_dataset_size(
    openwebtext_dataset,
    DATASET_CONFIG['openwebtext_samples']
)

cnn_mb, cnn_gb, cnn_count = calculate_dataset_size(
    cnn_dataset,
    DATASET_CONFIG['cnn_samples'],
    text_field='article'
)

total_mb = wikitext_mb + openwebtext_mb + cnn_mb
total_gb = total_mb / 1024

print(f"WikiText-103: {wikitext_mb:.2f} MB ({wikitext_count} docs)")
print(f"OpenWebText: {openwebtext_mb:.2f} MB ({openwebtext_count} docs)")
print(f"CNN/DailyMail: {cnn_mb:.2f} MB ({cnn_count} docs)")
print(f"Total: {total_mb:.2f} MB ({total_gb:.3f} GB)")
print(f"Status: {'PASS' if total_gb >= 1.0 else 'FAIL'}")

WikiText-103: 235.75 MB (500000 docs)
OpenWebText: 718.47 MB (150000 docs)
CNN/DailyMail: 584.16 MB (150000 docs)
Total: 1538.38 MB (1.502 GB)
Status: PASS


# Collect Raw Data into Memory

In [15]:
#  Collect all documents into a single dictionary keyed by source.
raw_data = {
    'wikitext': [],
    'openwebtext': [],
    'cnn_dailymail': []
}

print("Collecting WikiText-103...")
for i, example in enumerate(wikitext_dataset):
    if i >= DATASET_CONFIG['wikitext_samples']:
        break
    if example['text'] and example['text'].strip():
        raw_data['wikitext'].append(example['text'])

print("Collecting OpenWebText...")
for i, example in enumerate(openwebtext_dataset):
    if i >= DATASET_CONFIG['openwebtext_samples']:
        break
    if example['text'] and example['text'].strip():
        raw_data['openwebtext'].append(example['text'])

print("Collecting CNN/DailyMail...")
for i, example in enumerate(cnn_dataset):
    if i >= DATASET_CONFIG['cnn_samples']:
        break
    if example['article'] and example['article'].strip():
        raw_data['cnn_dailymail'].append(example['article'])

total_docs = sum(len(v) for v in raw_data.values())
print(f"\nCollected {total_docs} documents")
print(f"  WikiText: {len(raw_data['wikitext'])}")
print(f"  OpenWebText: {len(raw_data['openwebtext'])}")
print(f"  CNN/DailyMail: {len(raw_data['cnn_dailymail'])}")

Collecting WikiText-103...
Collecting OpenWebText...
Collecting CNN/DailyMail...

Collected 623818 documents
  WikiText: 323818
  OpenWebText: 150000
  CNN/DailyMail: 150000


# Text Cleaning - Remove Duplicates


In [16]:
# Use Python's built-in hash() for exact deduplication, it's fast and memory-efficient since only store hashes in the set
print("Starting deduplication...")

def deduplicate_texts(texts):
    seen = set()
    unique_texts = []
    duplicates = 0

    for text in texts:
        text_hash = hash(text)
        if text_hash not in seen:
            seen.add(text_hash)
            unique_texts.append(text)
        else:
            duplicates += 1

    return unique_texts, duplicates

cleaned_data = {}
total_duplicates = 0
# WikiText typically has the most duplicates (~12-14%)
for source, texts in raw_data.items():
    print(f"Processing {source}...")
    unique_texts, dupes = deduplicate_texts(texts)
    cleaned_data[source] = unique_texts
    total_duplicates += dupes
    print(f"  Original: {len(texts)}, Unique: {len(unique_texts)}, Duplicates: {dupes}")

total_after = sum(len(v) for v in cleaned_data.values())
print(f"\nTotal documents before: {sum(len(v) for v in raw_data.values())}")
print(f"Total documents after: {total_after}")
print(f"Total duplicates removed: {total_duplicates}")

Starting deduplication...
Processing wikitext...
  Original: 323818, Unique: 279733, Duplicates: 44085
Processing openwebtext...
  Original: 150000, Unique: 150000, Duplicates: 0
Processing cnn_dailymail...
  Original: 150000, Unique: 146912, Duplicates: 3088

Total documents before: 623818
Total documents after: 576645
Total duplicates removed: 47173


# Text Normalization

In [17]:
# Clean text: remove HTML/markdown, normalize case/whitespace, filter short documents

import re
# HTML/Markdown removal — Raw web text often contains <tags>, markdown links [text](url), and header markers (##, ==).
def clean_html_markdown(text):
    """Remove HTML tags and markdown formatting"""
    text = re.sub(r'<[^>]+>', '', text)
    text = re.sub(r'\[.*?\]\(.*?\)', '', text)
    text = re.sub(r'={2,}', '', text)
    text = re.sub(r'\*{2,}', '', text)
    text = re.sub(r'#{1,6}\s', '', text)
    return text

# Lowercasing + whitespace normalization — Lowercasing standardizes the input and effectively reduces the vocabulary the model needs to learn.
def normalize_text(text):
    """Normalize text: clean markup, lowercase, remove extra whitespace and symbols"""
    text = clean_html_markdown(text)
    text = text.lower()
    text = re.sub(r'\s+', ' ', text)
    text = re.sub(r'[^\w\s.,!?;:\-\'\"]', '', text)
    text = text.strip()
    return text

# Short document filtering (min 50 words) — Extremely short texts lack sufficient context for learning long-range dependencie
def filter_short_documents(texts, min_words=50):
    """Remove documents with fewer than min_words"""
    filtered = []
    removed = 0

    for text in texts:
        if len(text.split()) >= min_words:
            filtered.append(text)
        else:
            removed += 1

    return filtered, removed

print("Normalizing and filtering texts...")

normalized_data = {}
total_removed = 0

#This step removes ~35% of WikiText documents
for source, texts in cleaned_data.items():
    print(f"Processing {source}...")
    normalized_texts = [normalize_text(text) for text in texts]
    filtered_texts, removed = filter_short_documents(normalized_texts, min_words=50)
    normalized_data[source] = filtered_texts
    total_removed += removed
    print(f"  After normalization: {len(normalized_texts)}")
    print(f"  After filtering (<50 words): {len(filtered_texts)}, Removed: {removed}")

total_final = sum(len(v) for v in normalized_data.values())
print(f"\nTotal documents after cleaning: {total_final}")
print(f"Total short documents removed: {total_removed}")

Normalizing and filtering texts...
Processing wikitext...
  After normalization: 279733
  After filtering (<50 words): 184165, Removed: 95568
Processing openwebtext...
  After normalization: 150000
  After filtering (<50 words): 149994, Removed: 6
Processing cnn_dailymail...
  After normalization: 146912
  After filtering (<50 words): 146887, Removed: 25

Total documents after cleaning: 481046
Total short documents removed: 95599


In [18]:
import pickle

print("Saving cleaned data...")
with open('./processed/normalized_data.pkl', 'wb') as f:
    pickle.dump(normalized_data, f)

print(f"Saved {sum(len(v) for v in normalized_data.values())} documents")
print("File: ./processed/normalized_data.pkl")

Saving cleaned data...
Saved 481046 documents
File: ./processed/normalized_data.pkl


# Tokenization

In [19]:
# Tokenize text into fixed-size blocks using GPT-2 tokenizer

# Byte-Pair Encoding handles out-of-vocabulary words gracefully through subword units
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Balances context length with memory efficiency.
# Larger blocks (1024) give more context but require more GPU memory per batch.
# Smaller blocks (128) are efficient but limit the model's ability to learn long-range dependencies.
BLOCK_SIZE = 512

def tokenize_and_chunk(texts, block_size=512):
    all_input_ids = []

    for text in tqdm(texts, desc="Tokenizing"):
        tokens = tokenizer(text, truncation=False, add_special_tokens=True)
        input_ids = tokens['input_ids']

        for i in range(0, len(input_ids), block_size):
            chunk = input_ids[i:i + block_size]
            # Only keep full blocks — Partial chunks at the end of documents are discarded.
            if len(chunk) == block_size:
                all_input_ids.append(chunk)

    return all_input_ids

tokenized_data = {}

for source, texts in normalized_data.items():
    checkpoint_path = f'./processed/tokenized_{source}.pkl'

    if os.path.exists(checkpoint_path):
        print(f"\nLoading cached {source}...")
        with open(checkpoint_path, 'rb') as f:
            tokenized_data[source] = pickle.load(f)
        print(f"  Loaded {len(tokenized_data[source])} blocks")
    else:
        print(f"\nTokenizing {source}...")
        tokenized_data[source] = tokenize_and_chunk(texts, BLOCK_SIZE)

        with open(checkpoint_path, 'wb') as f:
            pickle.dump(tokenized_data[source], f)

        print(f"  Documents: {len(texts)}")
        print(f"  Tokenized blocks: {len(tokenized_data[source])}")
        print(f"  Saved to {checkpoint_path}")

total_blocks = sum(len(v) for v in tokenized_data.values())
print(f"\nTotal tokenized blocks ({BLOCK_SIZE} tokens each): {total_blocks}")

Loading tokenizer...


config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]


Tokenizing wikitext...


Tokenizing: 100%|██████████| 184165/184165 [02:06<00:00, 1453.63it/s]


  Documents: 184165
  Tokenized blocks: 395
  Saved to ./processed/tokenized_wikitext.pkl

Tokenizing openwebtext...


Tokenizing:   0%|          | 0/149994 [00:00<?, ?it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1207 > 1024). Running this sequence through the model will result in indexing errors
Tokenizing: 100%|██████████| 149994/149994 [13:07<00:00, 190.54it/s]


  Documents: 149994
  Tokenized blocks: 234187
  Saved to ./processed/tokenized_openwebtext.pkl

Tokenizing cnn_dailymail...


Tokenizing: 100%|██████████| 146887/146887 [10:16<00:00, 238.13it/s]


  Documents: 146887
  Tokenized blocks: 181676
  Saved to ./processed/tokenized_cnn_dailymail.pkl

Total tokenized blocks (512 tokens each): 416258


# Custom PyTorch Dataset and DataLoader

In [20]:
# Create PyTorch Dataset and DataLoader for batch training

from torch.utils.data import Dataset, DataLoader

class TextDataset(Dataset):
    def __init__(self, tokenized_data_dict):
       # Flatten all tokenized blocks from different sources into a single flat list
        self.data = []
        for source, blocks in tokenized_data_dict.items():
            self.data.extend(blocks)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Return input_ids and labels (for language modeling, labels = input_ids)
        input_ids = torch.tensor(self.data[idx], dtype=torch.long)
        return {
            'input_ids': input_ids,
            'labels': input_ids.clone()
        }

dataset = TextDataset(tokenized_data)
print(f"Dataset size: {len(dataset)} blocks")

# Create DataLoader with batching and shuffling
BATCH_SIZE = 8 #  Chosen for Colab's CPU/memory constraints
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True, # this ensures each training batch contains a mix of domains rather than seeing all WikiText, then all OpenWebText
    num_workers=2,
    pin_memory=False
)

print(f"DataLoader created with batch size: {BATCH_SIZE}")
print(f"Total batches: {len(dataloader)}")

print("\nTesting DataLoader...")
for batch in dataloader:
    print(f"Batch input_ids shape: {batch['input_ids'].shape}")
    print(f"Batch labels shape: {batch['labels'].shape}")
    break

print("\nDataLoader test successful")

Dataset size: 416258 blocks
DataLoader created with batch size: 8
Total batches: 52033

Testing DataLoader...
Batch input_ids shape: torch.Size([8, 512])
Batch labels shape: torch.Size([8, 512])

DataLoader test successful


# Save Sample Processed Data

In [21]:
# Save first 10 batches as sample output for assignment submission
print("Saving sample batches...")

sample_batches = []
for i, batch in enumerate(dataloader):
    sample_batches.append(batch)
    if i >= 9:
        break

torch.save(sample_batches, './samples/sample_dataset.pt')
print(f"Saved 10 sample batches to ./samples/sample_dataset.pt")

print("\nVerifying saved samples...")
loaded_samples = torch.load('./samples/sample_dataset.pt')
print(f"Loaded {len(loaded_samples)} batches")
print(f"First batch shape: {loaded_samples[0]['input_ids'].shape}")

print("\nSample statistics:")
print(f"Total dataset blocks: {len(dataset)}")
print(f"Block size: 512 tokens")
print(f"Batch size: {BATCH_SIZE}")
print(f"Total batches available: {len(dataloader)}")

Saving sample batches...
Saved 10 sample batches to ./samples/sample_dataset.pt

Verifying saved samples...
Loaded 10 batches
First batch shape: torch.Size([8, 512])

Sample statistics:
Total dataset blocks: 416258
Block size: 512 tokens
Batch size: 8
Total batches available: 52033


# Display Sample Data

In [22]:
# Show examples of tokenized data for report documentation

print("Sample Tokenized Data")
print("="*50)

# Display first 3 blocks from sample batches
for i in range(3):
    sample_block = loaded_samples[0]['input_ids'][i]
    print(f"\nBlock {i+1}:")
    print(f"Shape: {sample_block.shape}")
    print(f"Token IDs (first 20): {sample_block[:20].tolist()}")

    # Decode to show actual text
    decoded_text = tokenizer.decode(sample_block)
    print(f"Decoded text (first 200 chars): {decoded_text[:200]}...")

Sample Tokenized Data

Block 1:
Shape: torch.Size([512])
Token IDs (first 20): [13, 262, 976, 2925, 329, 13720, 281, 31506, 25, 345, 297, 307, 1654, 13, 611, 345, 423, 284, 1265, 11]
Decoded text (first 200 chars): . the same goes for identifying an ace: youll be sure. if you have to ask, a pitcher isnt an ace. across mlb in the 2016 regular season, when a starting pitcher completed at least 7.0 innings, his tea...

Block 2:
Shape: torch.Size([512])
Token IDs (first 20): [65, 16296, 1204, 11, 815, 1312, 307, 5213, 5633, 644, 1324, 2945, 3544, 308, 862, 1752, 284, 5004, 534, 4238]
Decoded text (first 200 chars): battery life, should i be concerned ? whatappened uses gps once to determine your initial location, after that it uses wifi hotspots and cell towers to determine your location. one of the beauties of ...

Block 3:
Shape: torch.Size([512])
Token IDs (first 20): [318, 6159, 20109, 4249, 4336, 39056, 13, 340, 4206, 3446, 644, 340, 318, 1804, 290, 340, 1724, 284, 466, 340]
Decoded t

In [23]:
torch.save(tokenized_data, './processed/full_tokenized_dataset.pt')