# Assignment 1: Data Collection and Preprocessing for Foundation Model Pre‑Training

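This notebook demonstrates how to build a data preprocessing pipeline for pre‑training transformer‑based foundation models. It mirrors the original Python script, split into explanatory cells, so you can run only the parts you need. Documents are read with streaming, so the loader itself scales to large corpora; the later stages still collect everything in memory, so adapt them (see the notes at the end) before processing a full corpus.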

**What you'll see:**
- Collect raw text via Hugging Face `datasets` (streaming)
- Clean & normalize text (lowercasing, strip HTML/URLs, collapse whitespace, etc.)
- Deduplicate and filter too‑short documents
- Tokenize with a Hugging Face `AutoTokenizer` and split each document into blocks of at most `block_size` tokens
- Build a PyTorch `Dataset` + `DataLoader` with padding & masks
- Save a few tokenized batches to `.pt` for inspection

> **Note**: This notebook expects `datasets`, `transformers`, and `torch` to be installed in your environment. Install them with the cell below if needed.

In [25]:
# Optional: install dependencies (uncomment to run in a fresh environment)
# !pip install -U datasets transformers torch


## Imports & Logger
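We import the core libraries inside `try`/`except` blocks and fall back to `None` placeholders, so the cells below can still be defined and read without every library installed; each function checks for the library it needs before using it. `tqdm` is optional and falls back to a no-op wrapper.

In [26]:
from __future__ import annotations

import html
import logging
import re
from typing import Iterable, Iterator, List

# Optional imports: fall back to None so later cells can still be defined;
# each function raises a clear ImportError if its library is missing.
try:
    import datasets  # type: ignore
except ImportError:  # pragma: no cover
    datasets = None
try:
    import transformers  # type: ignore
except ImportError:  # pragma: no cover
    transformers = None
try:
    import torch  # type: ignore
except ImportError:  # pragma: no cover
    torch = None

try:
    from tqdm.auto import tqdm
except ImportError:  # pragma: no cover - fallback without tqdm
    def tqdm(iterable=None, **kwargs):
        """No-op stand-in for tqdm: return the iterable unchanged."""
        return iterable if iterable is not None else []


logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")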


## 1) Stream a dataset from the Hugging Face Hub

In [27]:
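def load_dataset_stream(dataset: str, dataset_name: str | None = None, split: str = "train") -> Iterator[str]:
    """Stream text documents from a Hugging Face dataset.

    Yields each example's "text" field, or its first string-valued field when
    "text" is missing or None; examples without any string field are skipped.
    """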
    if datasets is None:
        raise ImportError("The 'datasets' library is required to load datasets.")
    logger.info("Loading dataset %s%s with split=%s", dataset, f"[{dataset_name}]" if dataset_name else "", split)
    data = datasets.load_dataset(dataset, dataset_name, split=split, streaming=True)
    for example in data:
        text = example.get("text")
        if text is None:
            for value in example.values():
                if isinstance(value, str):
                    text = value
                    break
        if text is None:
            continue
        yield text

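The field-selection logic in `load_dataset_stream` can be checked offline on plain dicts. A minimal sketch, assuming records shaped like streamed examples; `extract_text` is a hypothetical helper that mirrors the loop above and is not part of the `datasets` API.

```python
from typing import Any, Mapping, Optional


def extract_text(example: Mapping[str, Any]) -> Optional[str]:
    """Return the 'text' field, else the first string-valued field, else None.

    Hypothetical helper mirroring the fallback inside load_dataset_stream.
    """
    text = example.get("text")
    if text is not None:
        return text
    for value in example.values():
        if isinstance(value, str):
            return value
    return None


records = [
    {"text": "an article body", "id": 1},
    {"id": 2, "content": "no 'text' key, so the first string field is used"},
    {"id": 3, "score": 0.5},  # no string field at all, so it is skipped
]
texts = [t for t in (extract_text(r) for r in records) if t is not None]
print(texts)
```

Because the fallback takes whichever string field comes first, it can pick up an ID or URL column instead of the body text; for datasets without a `text` column, selecting the column by name is the safer choice.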

## 2) Cleaning utilities
Basic cleaning: unescape HTML, strip tags/URLs, lowercase, ASCII normalize, remove punctuation, collapse whitespace.

In [28]:
def unicodedata_normalize(text: str) -> str:
    """Normalize accented characters to ASCII equivalents (for primarily English corpora)."""
    import unicodedata
    normalized = unicodedata.normalize("NFKD", text)
    return normalized.encode("ascii", "ignore").decode("utf-8", "ignore")

def clean_text(text: str) -> str:
    text = html.unescape(text)
    text = re.sub(r"<[^>]+>", " ", text)  # remove HTML tags
    text = re.sub(r"https?://\S+|www\.\S+", " ", text)  # remove URLs
    text = text.lower()
    text = unicodedata_normalize(text)
    text = re.sub(r"[^\w\s]", " ", text)  # remove punctuation/symbols
    text = re.sub(r"\s+", " ", text).strip()
    return text

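To see what each cleaning step does, here is a worked example on a small HTML snippet. The function is a self-contained copy of `clean_text` with the ASCII folding from `unicodedata_normalize` inlined; the input string is made up for illustration.

```python
import html
import re
import unicodedata


def clean_text(text: str) -> str:
    """Copy of the notebook's clean_text with the ASCII folding inlined."""
    text = html.unescape(text)                             # "&eacute;" -> "é"
    text = re.sub(r"<[^>]+>", " ", text)                   # drop HTML tags
    text = re.sub(r"https?://\S+|www\.\S+", " ", text)     # drop URLs
    text = text.lower()
    text = unicodedata.normalize("NFKD", text)             # "é" -> "e" + combining accent
    text = text.encode("ascii", "ignore").decode("ascii")  # drop non-ASCII code points
    text = re.sub(r"[^\w\s]", " ", text)                   # punctuation/symbols -> space
    return re.sub(r"\s+", " ", text).strip()               # collapse whitespace


raw = "<p>Caf&eacute; prices: <b>$5</b> at https://example.com!</p>"
print(clean_text(raw))  # cafe prices 5 at
```

The URL pattern also swallows the trailing `!` because `\S+` matches any non-space character, and digits and underscores survive because `\w` includes them.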

## 3) Deduplication & filtering short docs

In [29]:
def deduplicate_documents(docs: Iterable[str]) -> List[str]:
    seen: set[str] = set()
    unique_docs: List[str] = []
    for doc in tqdm(docs, desc="Deduplicating documents"):
        if doc not in seen:
            seen.add(doc)
            unique_docs.append(doc)
    return unique_docs

def filter_short_documents(docs: Iterable[str], min_words: int = 50) -> List[str]:
    filtered: List[str] = []
    for doc in tqdm(docs, desc="Filtering short documents"):
        if len(doc.split()) >= min_words:
            filtered.append(doc)
    return filtered

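Both helpers above build full lists, and `deduplicate_documents` keeps every distinct document string in `seen`. A sketch of a lazier variant, assuming exact-match deduplication is still what you want: hypothetical generator versions that store a 20-byte SHA‑1 digest per document instead of the text itself. This is still exact matching, so near-duplicates (for example, pages that differ only in boilerplate) pass through.

```python
import hashlib
from typing import Iterable, Iterator


def dedup_stream(docs: Iterable[str]) -> Iterator[str]:
    """Yield each distinct document once, remembering only its SHA-1 digest."""
    seen: set[bytes] = set()
    for doc in docs:
        digest = hashlib.sha1(doc.encode("utf-8")).digest()  # 20 bytes per document
        if digest not in seen:
            seen.add(digest)
            yield doc


def filter_short_stream(docs: Iterable[str], min_words: int = 50) -> Iterator[str]:
    """Yield only documents with at least `min_words` whitespace-separated words."""
    for doc in docs:
        if len(doc.split()) >= min_words:
            yield doc


docs = ["a b c", "a b c", "x y", "p q r s"]
kept = list(filter_short_stream(dedup_stream(docs), min_words=3))
print(kept)  # ['a b c', 'p q r s']
```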

## 4) Tokenize and chunk with Hugging Face `AutoTokenizer`

In [30]:
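def tokenize_and_chunk(docs: Iterable[str], tokenizer_name: str, block_size: int) -> List[List[int]]:
    """Tokenize each document and split it into blocks of at most `block_size` tokens.

    Blocks never span documents, so the last block of a document is usually shorter.
    """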
    if transformers is None:
        raise ImportError("The 'transformers' library is required for tokenisation.")
    tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_name)
    tokenised_blocks: List[List[int]] = []
    for doc in tqdm(docs, desc="Tokenizing documents"):
        tokens = tokenizer.encode(doc, add_special_tokens=False)
        for i in range(0, len(tokens), block_size):
            chunk = tokens[i : i + block_size]
            tokenised_blocks.append(chunk)
    return tokenised_blocks

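`tokenize_and_chunk` splits each document on its own, so a 1,420-token document with `block_size=1024` yields one full block plus a 396-token block that is later padded; this is why the run below reports more blocks (1,184) than documents (876). A common alternative for causal-LM pre-training, similar in spirit to the `group_texts` step in Hugging Face's language-modeling examples, concatenates documents with an end-of-text separator and cuts the stream into full blocks. A sketch on plain token-id lists; `concat_and_chunk` is hypothetical, and the `eos_id` is an assumption you would take from your tokenizer (for `gpt2`, `tokenizer.eos_token_id` is 50256).

```python
from typing import Iterable, List


def concat_and_chunk(token_docs: Iterable[List[int]], block_size: int, eos_id: int) -> List[List[int]]:
    """Join documents with an EOS separator and cut the stream into full blocks.

    The trailing partial block is dropped, so every returned block has exactly
    `block_size` tokens and no padding is needed.
    """
    blocks: List[List[int]] = []
    buffer: List[int] = []
    for ids in token_docs:
        buffer.extend(ids)
        buffer.append(eos_id)  # mark the document boundary
        while len(buffer) >= block_size:
            blocks.append(buffer[:block_size])
            buffer = buffer[block_size:]
    return blocks  # leftover tokens in `buffer` are discarded


print(concat_and_chunk([[1, 2, 3], [4, 5]], block_size=3, eos_id=0))
# [[1, 2, 3], [0, 4, 5]]
```

The trade-off: blocks may start mid-document, but no compute is spent on padding and every block carries `block_size` real tokens.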

## 5) PyTorch Dataset, collate function, and saving sample batches

In [31]:
class TextDataset:
    """PyTorch-like dataset for tokenised text blocks."""
    def __init__(self, token_blocks: List[List[int]], pad_token_id: int = 0) -> None:
        if torch is None:
            raise ImportError("The 'torch' library is required for creating the dataset.")
        self.token_blocks = token_blocks
        self.pad_token_id = pad_token_id
    def __len__(self) -> int:
        return len(self.token_blocks)
    def __getitem__(self, idx: int):
        ids = self.token_blocks[idx]
        return torch.tensor(ids, dtype=torch.long), torch.ones(len(ids), dtype=torch.long)

def collate_fn(batch, pad_token_id: int = 0):
    """Right-pad a batch of (input_ids, attention_mask) pairs to the longest sequence.

    Padded positions get `pad_token_id` in input_ids and 0 in attention_mask.
    To use a tokenizer-specific pad id, bind it with functools.partial.
    """
    if torch is None:
        raise ImportError("The 'torch' library is required for collation.")
    input_ids, attention_masks = zip(*batch)
    max_len = max(seq.size(0) for seq in input_ids)
    padded_ids, padded_masks = [], []
    for ids, mask in zip(input_ids, attention_masks):
        pad_length = max_len - ids.size(0)
        padded_ids.append(torch.cat([ids, torch.full((pad_length,), pad_token_id, dtype=torch.long)]))
        padded_masks.append(torch.cat([mask, torch.zeros(pad_length, dtype=torch.long)]))
    return {"input_ids": torch.stack(padded_ids), "attention_mask": torch.stack(padded_masks)}

def save_sample_batches(dataloader, num_batches: int | None, output_path: str) -> int:
    if torch is None:
        raise ImportError("The 'torch' library is required for saving sample batches.")
    saved_batches: List[dict] = []
    saved_count = 0
    for i, batch in enumerate(tqdm(dataloader, desc="Saving sample batches", total=num_batches)):
        if num_batches is not None and i >= num_batches:
            break
        saved_batches.append({k: v.clone().cpu() for k, v in batch.items()})
        saved_count += 1
    try:
        torch.save(saved_batches, output_path)
    except Exception:
        import pickle
        with open(output_path, "wb") as fh:
            pickle.dump(saved_batches, fh)
    return saved_count

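`collate_fn` right-pads every sequence to the longest one in the batch and marks real tokens with 1 in `attention_mask`. A dependency-free sketch of the same rule on plain lists, assuming a configurable `pad_id` (50256 below, in case you reuse GPT‑2's end-of-text token as padding, a common workaround since `gpt2` defines no pad token). It also builds a hypothetical `labels` list with -100 at padded positions, the default `ignore_index` of PyTorch's `CrossEntropyLoss`, so padding would not contribute to a language-modeling loss.

```python
from typing import Dict, List


def pad_batch(seqs: List[List[int]], pad_id: int = 0) -> Dict[str, List[List[int]]]:
    """Right-pad token-id lists to equal length and build masks and labels."""
    max_len = max(len(s) for s in seqs)
    input_ids, attention_mask, labels = [], [], []
    for s in seqs:
        n_pad = max_len - len(s)
        input_ids.append(s + [pad_id] * n_pad)
        attention_mask.append([1] * len(s) + [0] * n_pad)  # 1 = real token, 0 = padding
        labels.append(s + [-100] * n_pad)  # -100 is ignored by CrossEntropyLoss by default
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}


batch = pad_batch([[5, 6, 7], [8]], pad_id=50256)
print(batch["input_ids"])       # [[5, 6, 7], [8, 50256, 50256]]
print(batch["attention_mask"])  # [[1, 1, 1], [1, 0, 0]]
```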

## 6) Full preprocessing pipeline (callable from this notebook)

In [32]:
def preprocess_pipeline(
    dataset: str,
    dataset_name: str | None,
    tokenizer_name: str,
    block_size: int,
    min_words: int,
    sample_batches: int | None,
    output_path: str,
    max_documents: int | None = None,
) -> None:
    docs_iter: Iterator[str] = load_dataset_stream(dataset, dataset_name)
    cleaned_docs: List[str] = []
    for i, doc in enumerate(tqdm(docs_iter, desc="Cleaning documents", total=max_documents), start=1):
        cleaned = clean_text(doc)
        cleaned_docs.append(cleaned)
        if max_documents is not None and i >= max_documents:
            break
    logger.info("Loaded %d documents", len(cleaned_docs))
    unique_docs = deduplicate_documents(cleaned_docs)
    logger.info("After deduplication: %d documents", len(unique_docs))
    filtered_docs = filter_short_documents(unique_docs, min_words=min_words)
    logger.info("After filtering short docs: %d documents", len(filtered_docs))
    token_blocks = tokenize_and_chunk(filtered_docs, tokenizer_name, block_size)
    logger.info("Number of token blocks: %d", len(token_blocks))
    if transformers is None or torch is None:
        raise ImportError("Both 'transformers' and 'torch' are required to continue.")
    tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_name)
    pad_token_id = tokenizer.pad_token_id or 0
    dataset_obj = TextDataset(token_blocks, pad_token_id=pad_token_id)
    dataloader = torch.utils.data.DataLoader(dataset_obj, batch_size=8, shuffle=True, collate_fn=collate_fn)
    saved = save_sample_batches(dataloader, sample_batches, output_path)
    logger.info("Saved %d batches to %s", saved, output_path)

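`preprocess_pipeline` builds a full list after every stage. A sketch of the same stage order composed lazily with generators, where `itertools.islice` plays the role of `max_documents`; the stand-in `clean` and the name `lazy_pipeline` are hypothetical simplifications, not the notebook's functions.

```python
import itertools
import re
from typing import Iterable, Iterator, Optional


def clean(text: str) -> str:
    """Stand-in for clean_text: lowercase and collapse whitespace only."""
    return re.sub(r"\s+", " ", text.lower()).strip()


def lazy_pipeline(stream: Iterable[str], max_documents: Optional[int], min_words: int) -> Iterator[str]:
    """Cap, clean, deduplicate and length-filter documents one at a time."""
    capped = itertools.islice(stream, max_documents)  # max_documents=None means no cap
    cleaned = (clean(doc) for doc in capped)
    seen: set[str] = set()
    for doc in cleaned:
        if doc in seen or len(doc.split()) < min_words:
            continue
        seen.add(doc)
        yield doc


raw = ["Hello   World foo", "hello world FOO", "too short", "a b c d", "never read"]
result = list(lazy_pipeline(iter(raw), max_documents=4, min_words=3))
print(result)  # ['hello world foo', 'a b c d']
```

Tokenization could be chained the same way, so that only the final token blocks need to be held or written to disk.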


## 7) Example configuration (edit and run)
You can tweak these parameters for your environment/dataset.

In [33]:
config = {
    "dataset": "wiki40b",          # e.g., 'openwebtext', 'c4', etc.
    "dataset_name": "en",           # or None if single-config dataset
    "tokenizer": "gpt2",            # or 'bert-base-uncased', etc.
    "block_size": 1024,
    "min_words": 50,
    "sample_batches": 10,            # or None to get all
    "output": "sample_dataset.pt",
    "max_docs": 1000,                # limit for quick trial; set None for full stream
}
config


{'dataset': 'wiki40b',
 'dataset_name': 'en',
 'tokenizer': 'gpt2',
 'block_size': 1024,
 'min_words': 50,
 'sample_batches': 10,
 'output': 'sample_dataset.pt',
 'max_docs': 1000}

## 8) Run the full pipeline (may take time depending on dataset/network)
Run the cell below to execute it.

In [34]:
_cfg = config
preprocess_pipeline(
    dataset=_cfg["dataset"],
    dataset_name=_cfg["dataset_name"],
    tokenizer_name=_cfg["tokenizer"],
    block_size=_cfg["block_size"],
    min_words=_cfg["min_words"],
    sample_batches=_cfg["sample_batches"],
    output_path=_cfg["output"],
    max_documents=_cfg["max_docs"],
)


2025-09-21 02:42:48,174 - INFO - Loading dataset wiki40b[en] with split=train


2025-09-21 02:42:52,023 - INFO - Loaded 1000 documents
2025-09-21 02:42:52,026 - INFO - After deduplication: 1000 documents
2025-09-21 02:42:52,055 - INFO - After filtering short docs: 876 documents
Token indices sequence length is longer than the specified maximum sequence length for this model (1420 > 1024). Running this sequence through the model will result in indexing errors
2025-09-21 02:42:56,054 - INFO - Number of token blocks: 1184
2025-09-21 02:42:56,588 - INFO - Saved 10 batches to sample_dataset.pt

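To inspect the saved file, it can be loaded back with `torch.load`. Because `save_sample_batches` falls back to plain `pickle` if `torch.save` fails, a hypothetical loader that tries both formats is sketched below; it assumes you trust the file, since unpickling can execute arbitrary code.

```python
import pickle
from typing import Any


def load_sample_batches(path: str) -> Any:
    """Load batches written by save_sample_batches (torch.save or pickle fallback)."""
    try:
        import torch  # optional: only needed for files written by torch.save
        return torch.load(path)
    except Exception:
        # Either torch is missing or the file is a plain pickle, not a torch archive.
        with open(path, "rb") as fh:
            return pickle.load(fh)


# Example (requires the file from the run above):
# batches = load_sample_batches("sample_dataset.pt")
# batches[0]["input_ids"].shape  # (8, L), L = longest block in that batch
```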

### Notes & Tips
- For multilingual corpora, consider **not** converting to ASCII in `unicodedata_normalize`.
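- `gpt2` defines no pad token, so the pipeline falls back to id 0, which is a real vocabulary entry; this does not affect attention because padded positions have mask 0. If your tokenizer's `pad_token_id` is not 0, you can modify `collate_fn` to use it.
- The "Token indices sequence length is longer than the specified maximum sequence length for this model" warning in the output above comes from encoding whole documents; it is harmless here because each document is split into `block_size`-token blocks before anything reaches a model.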
- For large-scale jobs, shard tokenized blocks to disk instead of keeping all in memory.
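For the multilingual tip above: NFKD followed by ASCII encoding silently deletes every non-Latin character. A sketch of a Unicode-preserving alternative, assuming NFKC compatibility normalization is acceptable for your corpus; `normalize_keep_unicode` is a hypothetical drop-in replacement for `unicodedata_normalize`.

```python
import unicodedata


def ascii_fold(text: str) -> str:
    """Same behaviour as the notebook's unicodedata_normalize."""
    return unicodedata.normalize("NFKD", text).encode("ascii", "ignore").decode("ascii")


def normalize_keep_unicode(text: str) -> str:
    """NFKC: fold compatibility forms (ligatures, full-width letters) but keep all scripts."""
    return unicodedata.normalize("NFKC", text)


sample = "\ufb01le \uff34\uff4f\uff4b\uff59\uff4f \u6771\u4eac na\u00efve"  # "ﬁle Ｔｏｋｙｏ 東京 naïve"
print(repr(ascii_fold(sample)))              # 'file Tokyo  naive'
print(repr(normalize_keep_unicode(sample)))  # 'file Tokyo 東京 naïve'
```

If you swap this in, the later `[^\w\s]` step in `clean_text` keeps CJK and accented letters, because Python's `re` treats `\w` as Unicode-aware for `str` patterns.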