In [1]:
!pip install -q datasets transformers torch tqdm

In [2]:
import os
import sys
from datasets import load_dataset, concatenate_datasets
from transformers import AutoTokenizer
import torch
from collections import Counter
import json
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

NUM_WORKERS = 2
os.environ["TOKENIZERS_PARALLELISM"] = "true"

device = torch.device("cpu")
print(f"Using device: {device}")

print(f"Python version: {sys.version}")
print(f"PyTorch version: {torch.__version__}")
print(f"CPU cores available: {os.cpu_count()}")

# Create working directories
os.makedirs('./data', exist_ok=True)
os.makedirs('./processed', exist_ok=True)
os.makedirs('./samples', exist_ok=True)

Using device: cpu
Python version: 3.12.12 (main, Oct 10 2025, 08:52:57) [GCC 11.4.0]
PyTorch version: 2.9.0+cpu
CPU cores available: 2


# Load Datasets

In [3]:
# Target sample sizes to meet 1GB requirement
DATASET_CONFIG = {
    'wikitext_samples': 250000,
    'openwebtext_samples': 200000,
    'ag_news_samples': 50000
}

print("Loading WikiText-103...")
wikitext_dataset = load_dataset(
    "Salesforce/wikitext",
    "wikitext-103-raw-v1",
    split="train",
    streaming=True # Streaming mode for memory efficiency
)

print("Loading OpenWebText...")
openwebtext_dataset = load_dataset(
    "Skylion007/openwebtext",
    split="train",
    streaming=True # Streaming mode for memory efficiency
)

print("Loading AG News...")
ag_news_dataset = load_dataset(
    "fancyzhx/ag_news",
    split="train",
    streaming=False
)

print(f"\nDatasets loaded successfully")
print(f"Target samples: {sum(DATASET_CONFIG.values())} total documents")

Loading WikiText-103...


README.md: 0.00B [00:00, ?B/s]

Loading OpenWebText...


README.md: 0.00B [00:00, ?B/s]

Resolving data files:   0%|          | 0/80 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/80 [00:00<?, ?it/s]

Loading AG News...


README.md: 0.00B [00:00, ?B/s]

data/train-00000-of-00001.parquet:   0%|          | 0.00/18.6M [00:00<?, ?B/s]

data/test-00000-of-00001.parquet:   0%|          | 0.00/1.23M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/120000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/7600 [00:00<?, ? examples/s]


Datasets loaded successfully
Target samples: 500000 total documents


# Calculate Total Dataset Size

In [4]:
import sys

def estimate_text_size(text):
    return sys.getsizeof(text.encode('utf-8'))

def calculate_dataset_size(dataset, num_samples, text_field='text'):
    total_bytes = 0
    count = 0

    for example in dataset:
        if text_field in example and example[text_field]:
            total_bytes += estimate_text_size(example[text_field])
            count += 1
            if count >= num_samples:
                break

    size_mb = total_bytes / (1024 * 1024)
    size_gb = size_mb / 1024
    return size_mb, size_gb, count

# Calculate size for each dataset
wikitext_mb, wikitext_gb, wikitext_count = calculate_dataset_size(
    wikitext_dataset,
    DATASET_CONFIG['wikitext_samples']
)

openwebtext_mb, openwebtext_gb, openwebtext_count = calculate_dataset_size(
    openwebtext_dataset,
    DATASET_CONFIG['openwebtext_samples']
)

ag_news_mb, ag_news_gb, ag_news_count = calculate_dataset_size(
    ag_news_dataset.select(range(DATASET_CONFIG['ag_news_samples'])),
    DATASET_CONFIG['ag_news_samples']
)

total_mb = wikitext_mb + openwebtext_mb + ag_news_mb
total_gb = total_mb / 1024

print(f"WikiText-103: {wikitext_mb:.2f} MB ({wikitext_count} docs)")
print(f"OpenWebText: {openwebtext_mb:.2f} MB ({openwebtext_count} docs)")
print(f"AG News: {ag_news_mb:.2f} MB ({ag_news_count} docs)")
print(f"Total: {total_mb:.2f} MB ({total_gb:.3f} GB)")
print(f"Status: {'PASS' if total_gb >= 1.0 else 'FAIL'}")

WikiText-103: 117.35 MB (250000 docs)
OpenWebText: 957.90 MB (200000 docs)
AG News: 12.94 MB (50000 docs)
Total: 1088.19 MB (1.063 GB)
Status: PASS


# Collect Raw Data into Memory

In [5]:
raw_data = {
    'wikitext': [],
    'openwebtext': [],
    'ag_news': []
}

# Collect WikiText documents (streaming)
print("Collecting WikiText-103...")
for i, example in enumerate(wikitext_dataset):
    if i >= DATASET_CONFIG['wikitext_samples']:
        break
    if example['text'] and example['text'].strip():
        raw_data['wikitext'].append(example['text'])

# Collect OpenWebText documents (streaming)
print("Collecting OpenWebText...")
for i, example in enumerate(openwebtext_dataset):
    if i >= DATASET_CONFIG['openwebtext_samples']:
        break
    if example['text'] and example['text'].strip():
        raw_data['openwebtext'].append(example['text'])

# Collect AG News documents (full dataset loaded)
print("Collecting AG News...")
for i in range(DATASET_CONFIG['ag_news_samples']):
    example = ag_news_dataset[i]
    if example['text'] and example['text'].strip():
        raw_data['ag_news'].append(example['text'])

total_docs = sum(len(v) for v in raw_data.values())
print(f"\nCollected {total_docs} documents")
print(f"  WikiText: {len(raw_data['wikitext'])}")
print(f"  OpenWebText: {len(raw_data['openwebtext'])}")
print(f"  AG News: {len(raw_data['ag_news'])}")

Collecting WikiText-103...
Collecting OpenWebText...
Collecting AG News...

Collected 411633 documents
  WikiText: 161633
  OpenWebText: 200000
  AG News: 50000


# Text Cleaning - Remove Duplicates


In [6]:
# Deduplicate documents using hash-based matching

print("Starting deduplication...")

def deduplicate_texts(texts):
    seen = set()
    unique_texts = []
    duplicates = 0

    for text in texts:
        text_hash = hash(text)
        if text_hash not in seen:
            seen.add(text_hash)
            unique_texts.append(text)
        else:
            duplicates += 1

    return unique_texts, duplicates

cleaned_data = {}
total_duplicates = 0

for source, texts in raw_data.items():
    print(f"Processing {source}...")
    unique_texts, dupes = deduplicate_texts(texts)
    cleaned_data[source] = unique_texts
    total_duplicates += dupes
    print(f"  Original: {len(texts)}, Unique: {len(unique_texts)}, Duplicates: {dupes}")

total_after = sum(len(v) for v in cleaned_data.values())
print(f"\nTotal documents before: {sum(len(v) for v in raw_data.values())}")
print(f"Total documents after: {total_after}")
print(f"Total duplicates removed: {total_duplicates}")

Starting deduplication...
Processing wikitext...
  Original: 161633, Unique: 141630, Duplicates: 20003
Processing openwebtext...
  Original: 200000, Unique: 200000, Duplicates: 0
Processing ag_news...
  Original: 50000, Unique: 50000, Duplicates: 0

Total documents before: 411633
Total documents after: 391630
Total duplicates removed: 20003


# Text Normalization

In [7]:
# Clean text: remove HTML/markdown, normalize case/whitespace, filter short documents

import re

def clean_html_markdown(text):
  """Remove HTML tags and markdown formatting"""
    text = re.sub(r'<[^>]+>', '', text)
    text = re.sub(r'\[.*?\]\(.*?\)', '', text)
    text = re.sub(r'={2,}', '', text)
    text = re.sub(r'\*{2,}', '', text)
    text = re.sub(r'#{1,6}\s', '', text)
    return text

def normalize_text(text):
  """Normalize text: clean markup, lowercase, remove extra whitespace and symbols"""
    text = clean_html_markdown(text)
    text = text.lower()
    text = re.sub(r'\s+', ' ', text)
    text = re.sub(r'[^\w\s.,!?;:\-\'\"]', '', text)
    text = text.strip()
    return text

def filter_short_documents(texts, min_words=50):
  """Remove documents with fewer than min_words"""
    filtered = []
    removed = 0

    for text in texts:
        if len(text.split()) >= min_words:
            filtered.append(text)
        else:
            removed += 1

    return filtered, removed

print("Normalizing and filtering texts...")

normalized_data = {}
total_removed = 0

for source, texts in cleaned_data.items():
    print(f"Processing {source}...")
    normalized_texts = [normalize_text(text) for text in texts]
    filtered_texts, removed = filter_short_documents(normalized_texts, min_words=50)
    normalized_data[source] = filtered_texts
    total_removed += removed
    print(f"  After normalization: {len(normalized_texts)}")
    print(f"  After filtering (<50 words): {len(filtered_texts)}, Removed: {removed}")

total_final = sum(len(v) for v in normalized_data.values())
print(f"\nTotal documents after cleaning: {total_final}")
print(f"Total short documents removed: {total_removed}")

Normalizing and filtering texts...
Processing wikitext...
  After normalization: 141630
  After filtering (<50 words): 91771, Removed: 49859
Processing openwebtext...
  After normalization: 200000
  After filtering (<50 words): 199994, Removed: 6
Processing ag_news...
  After normalization: 50000
  After filtering (<50 words): 4663, Removed: 45337

Total documents after cleaning: 296428
Total short documents removed: 95202


In [9]:
import pickle

print("Saving cleaned data...")
with open('./processed/normalized_data.pkl', 'wb') as f:
    pickle.dump(normalized_data, f)

print(f"Saved {sum(len(v) for v in normalized_data.values())} documents")
print("File: ./processed/normalized_data.pkl")

Saving cleaned data...
Saved 296428 documents
File: ./processed/normalized_data.pkl


# Tokenization

In [10]:
# Tokenize text into fixed-size blocks using GPT-2 tokenizer

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

BLOCK_SIZE = 512

def tokenize_and_chunk(texts, block_size=512):
    all_input_ids = []

    for text in tqdm(texts, desc="Tokenizing"):
        tokens = tokenizer(text, truncation=False, add_special_tokens=True)
        input_ids = tokens['input_ids']

        for i in range(0, len(input_ids), block_size):
            chunk = input_ids[i:i + block_size]
            if len(chunk) == block_size:
                all_input_ids.append(chunk)

    return all_input_ids

tokenized_data = {}

for source, texts in normalized_data.items():
    checkpoint_path = f'./processed/tokenized_{source}.pkl'

    if os.path.exists(checkpoint_path):
        print(f"\nLoading cached {source}...")
        with open(checkpoint_path, 'rb') as f:
            tokenized_data[source] = pickle.load(f)
        print(f"  Loaded {len(tokenized_data[source])} blocks")
    else:
        print(f"\nTokenizing {source}...")
        tokenized_data[source] = tokenize_and_chunk(texts, BLOCK_SIZE)

        with open(checkpoint_path, 'wb') as f:
            pickle.dump(tokenized_data[source], f)

        print(f"  Documents: {len(texts)}")
        print(f"  Tokenized blocks: {len(tokenized_data[source])}")
        print(f"  Saved to {checkpoint_path}")

total_blocks = sum(len(v) for v in tokenized_data.values())
print(f"\nTotal tokenized blocks ({BLOCK_SIZE} tokens each): {total_blocks}")

Loading tokenizer...

Tokenizing wikitext...


Tokenizing: 100%|██████████| 91771/91771 [01:37<00:00, 940.76it/s] 


  Documents: 91771
  Tokenized blocks: 189
  Saved to ./processed/tokenized_wikitext.pkl

Tokenizing openwebtext...


Tokenizing:   0%|          | 0/199994 [00:00<?, ?it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1207 > 1024). Running this sequence through the model will result in indexing errors
Tokenizing: 100%|██████████| 199994/199994 [30:25<00:00, 109.56it/s]


  Documents: 199994
  Tokenized blocks: 312366
  Saved to ./processed/tokenized_openwebtext.pkl

Tokenizing ag_news...


Tokenizing: 100%|██████████| 4663/4663 [00:03<00:00, 1404.59it/s]

  Documents: 4663
  Tokenized blocks: 0
  Saved to ./processed/tokenized_ag_news.pkl

Total tokenized blocks (512 tokens each): 312555





# Custom PyTorch Dataset and DataLoader

In [11]:
# Create PyTorch Dataset and DataLoader for batch training

from torch.utils.data import Dataset, DataLoader

class TextDataset(Dataset):
    def __init__(self, tokenized_data_dict):
       # Flatten all tokenized blocks from different sources
        self.data = []
        for source, blocks in tokenized_data_dict.items():
            self.data.extend(blocks)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Return input_ids and labels (for language modeling, labels = input_ids)
        input_ids = torch.tensor(self.data[idx], dtype=torch.long)
        return {
            'input_ids': input_ids,
            'labels': input_ids.clone()
        }

dataset = TextDataset(tokenized_data)
print(f"Dataset size: {len(dataset)} blocks")

# Create DataLoader with batching and shuffling
BATCH_SIZE = 8
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
    pin_memory=False
)

print(f"DataLoader created with batch size: {BATCH_SIZE}")
print(f"Total batches: {len(dataloader)}")

print("\nTesting DataLoader...")
for batch in dataloader:
    print(f"Batch input_ids shape: {batch['input_ids'].shape}")
    print(f"Batch labels shape: {batch['labels'].shape}")
    break

print("\nDataLoader test successful")

Dataset size: 312555 blocks
DataLoader created with batch size: 8
Total batches: 39070

Testing DataLoader...
Batch input_ids shape: torch.Size([8, 512])
Batch labels shape: torch.Size([8, 512])

DataLoader test successful


# Save Sample Processed Data

In [12]:
# Save first 10 batches as sample output for assignment submission
print("Saving sample batches...")

sample_batches = []
for i, batch in enumerate(dataloader):
    sample_batches.append(batch)
    if i >= 9:
        break

torch.save(sample_batches, './samples/sample_dataset.pt')
print(f"Saved 10 sample batches to ./samples/sample_dataset.pt")

print("\nVerifying saved samples...")
loaded_samples = torch.load('./samples/sample_dataset.pt')
print(f"Loaded {len(loaded_samples)} batches")
print(f"First batch shape: {loaded_samples[0]['input_ids'].shape}")

print("\nSample statistics:")
print(f"Total dataset blocks: {len(dataset)}")
print(f"Block size: 512 tokens")
print(f"Batch size: {BATCH_SIZE}")
print(f"Total batches available: {len(dataloader)}")

Saving sample batches...
Saved 10 sample batches to ./samples/sample_dataset.pt

Verifying saved samples...
Loaded 10 batches
First batch shape: torch.Size([8, 512])

Sample statistics:
Total dataset blocks: 312555
Block size: 512 tokens
Batch size: 8
Total batches available: 39070


# Display Sample Data

In [14]:
# Show examples of tokenized data for report documentation

print("Sample Tokenized Data")
print("="*50)

# Display first 3 blocks from sample batches
for i in range(3):
    sample_block = loaded_samples[0]['input_ids'][i]
    print(f"\nBlock {i+1}:")
    print(f"Shape: {sample_block.shape}")
    print(f"Token IDs (first 20): {sample_block[:20].tolist()}")

    # Decode to show actual text
    decoded_text = tokenizer.decode(sample_block)
    print(f"Decoded text (first 200 chars): {decoded_text[:200]}...")

Sample Tokenized Data

Block 1:
Shape: torch.Size([512])
Token IDs (first 20): [19138, 286, 12858, 11, 1312, 481, 2298, 8591, 8591, 1956, 284, 1592, 764, 21024, 26456, 25, 2853, 271, 4489, 1734]
Decoded text (first 200 chars): otive of momentum, i will pick la la land to win . directing nominees: denis villeneuve, arrival  mel gibson, hacksaw ridge  damien chazelle, la la land  kenneth lonergan, manchester by the sea  barry...

Block 2:
Shape: torch.Size([512])
Token IDs (first 20): [11, 428, 290, 635, 428, 220, 3651, 13, 1058, 1201, 345, 423, 284, 3551, 606, 19617, 11, 345, 765, 284]
Decoded text (first 200 chars): , this and also this  comments. : since you have to write them coding, you want to be just enough specific; a commit comment becomes a , just like a bdd method name. each preemptive comment triggers a...

Block 3:
Shape: torch.Size([512])
Token IDs (first 20): [611, 484, 1657, 510, 287, 1171, 13, 1909, 11, 2828, 422, 1111, 883, 4736, 1302, 416, 262, 2458, 11, 996]
Decoded t