In [None]:
"""
Assignment 1: Data Collection and Preprocessing for Foundation Model Pre-Training
Complete implementation with dataset collection, cleaning, tokenization, and custom data loader
"""

import os
import re
import hashlib
from collections import Counter
from typing import List, Dict, Iterator
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from transformers import AutoTokenizer
from tqdm import tqdm
import json
import gc

In [None]:
# ============================================================================
# 1. DATA COLLECTION
# ============================================================================
class DataCollector:
    """Collect raw text from streaming Hugging Face datasets and store it on disk.

    Each ``collect_*`` method streams at most ``num_samples`` documents, writes
    them to ``<output_dir>/<name>.txt`` separated by blank lines, and returns
    the documents as a list of strings.
    """

    def __init__(self, output_dir: str = "./raw_data"):
        """Remember ``output_dir`` and create it if it does not exist yet."""
        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)

    def _collect_streaming(self, dataset, num_samples: int, desc: str, filename: str) -> List[str]:
        """Take up to ``num_samples`` ``'text'`` fields from ``dataset`` and save them.

        Args:
            dataset: Iterable of examples, each exposing a ``'text'`` entry.
            num_samples: Maximum number of documents to keep.
            desc: Label shown on the progress bar.
            filename: Name of the output file inside ``self.output_dir``.

        Returns:
            The collected documents, in stream order.
        """
        texts = []
        for i, example in enumerate(tqdm(dataset, total=num_samples, desc=desc)):
            if i >= num_samples:
                break
            texts.append(example['text'])

        # Write document by document instead of materialising
        # '\n\n'.join(texts): the joined string would be a second full copy of
        # the corpus in memory. The resulting file content is identical.
        output_path = os.path.join(self.output_dir, filename)
        with open(output_path, 'w', encoding='utf-8') as f:
            for i, text in enumerate(texts):
                if i:
                    f.write('\n\n')
                f.write(text)

        return texts

    def collect_wikipedia(self, num_samples: int = 100000):
        """
        Collect Wikipedia articles
        Source: English Wikipedia dump ("wikimedia/wikipedia", config "20231101.en")
        Domain: Encyclopedic knowledge

        Returns the list of article texts; also written to ``wikipedia.txt``.
        """
        print("Collecting Wikipedia data...")
        dataset = load_dataset("wikimedia/wikipedia", "20231101.en", split="train", streaming=True)
        return self._collect_streaming(dataset, num_samples, "Wikipedia", "wikipedia.txt")

    def collect_openwebtext(self, num_samples: int = 50000):
        """
        Collect OpenWebText data
        Source: OpenWebText corpus (web pages gathered from links shared on Reddit)
        Domain: General web text, discussions

        Returns the list of document texts; also written to ``openwebtext.txt``.
        """
        print("Collecting OpenWebText data...")
        dataset = load_dataset("openwebtext", split="train", streaming=True)
        return self._collect_streaming(dataset, num_samples, "OpenWebText", "openwebtext.txt")

    def get_dataset_stats(self):
        """Print and return the on-disk size of every file in ``output_dir``.

        Returns:
            Dict mapping file name to ``{'size_mb': float, 'size_gb': float}``.
            Sub-directories (for example notebook checkpoint folders) are
            ignored so they are not reported as datasets.
        """
        total_size = 0
        stats = {}

        for filename in sorted(os.listdir(self.output_dir)):
            filepath = os.path.join(self.output_dir, filename)
            if not os.path.isfile(filepath):
                continue
            size = os.path.getsize(filepath)
            total_size += size
            stats[filename] = {
                'size_mb': size / (1024 * 1024),
                'size_gb': size / (1024 * 1024 * 1024)
            }

        print("\n" + "="*60)
        print("DATASET STATISTICS")
        print("="*60)
        for filename, stat in stats.items():
            print(f"{filename}: {stat['size_mb']:.2f} MB ({stat['size_gb']:.4f} GB)")
        print(f"\nTotal Size: {total_size / (1024**3):.4f} GB")
        print("="*60 + "\n")

        return stats

In [None]:
# ============================================================================
# 2. DATA CLEANING AND PREPROCESSING
# ============================================================================

class TextCleaner:
    """Clean, quality-filter and de-duplicate raw text documents.

    The instance keeps the hashes of every cleaned document it has been asked
    about in ``seen_hashes``, so de-duplication spans successive calls to
    ``process_documents`` on the same cleaner.
    """

    def __init__(self, min_words: int = 50):
        """
        Args:
            min_words: Documents with fewer whitespace-separated words than
                this are treated as low quality.
        """
        self.min_words = min_words
        self.seen_hashes = set()

    def clean_text(self, text: str) -> str:
        """Strip markup and unusual characters, then normalise whitespace.

        Returns the cleaned text with single spaces between tokens and no
        leading or trailing whitespace.
        """
        # Remove HTML tags
        text = re.sub(r'<[^>]+>', '', text)

        # Remove markdown artifacts
        text = re.sub(r'\[.*?\]\(.*?\)', '', text)  # Remove links
        text = re.sub(r'[#*_~`]', '', text)  # Remove markdown symbols

        # Remove reference markers like [1], [citation needed]
        text = re.sub(r'\[\d+\]', '', text)
        text = re.sub(r'\[citation needed\]', '', text, flags=re.IGNORECASE)

        # Remove special characters but keep punctuation
        text = re.sub(r'[^\w\s.,!?;:\'\"-]', '', text)

        # Normalize whitespace LAST: deleting characters above can leave
        # doubled or leading/trailing spaces behind (e.g. "a © b" -> "a  b"),
        # so collapsing must happen after every removal step.
        text = re.sub(r'\s+', ' ', text)
        return text.strip()

    def is_low_quality(self, text: str) -> bool:
        """Return True when ``text`` should be dropped.

        A document is rejected when it has fewer than ``min_words`` words,
        when less than 70% of its characters are letters or whitespace, or
        when a single (lower-cased) word makes up more than 30% of all words.
        """
        # Too short
        word_count = len(text.split())
        if word_count < self.min_words:
            return True

        # Too many non-alphabetic characters (spam indicators)
        alpha_ratio = sum(c.isalpha() or c.isspace() for c in text) / max(len(text), 1)
        if alpha_ratio < 0.7:
            return True

        # Too repetitive (same word repeated many times)
        words = text.lower().split()
        if words:
            most_common_ratio = Counter(words).most_common(1)[0][1] / len(words)
            if most_common_ratio > 0.3:
                return True

        return False

    def get_hash(self, text: str) -> str:
        """Return the hex MD5 digest of ``text`` (used only as a dedup key)."""
        return hashlib.md5(text.encode('utf-8')).hexdigest()

    def is_duplicate(self, text: str) -> bool:
        """Return True if ``text`` was seen before; otherwise record it.

        Note the side effect: a text that is not a duplicate is added to
        ``seen_hashes`` by this call.
        """
        text_hash = self.get_hash(text)
        if text_hash in self.seen_hashes:
            return True
        self.seen_hashes.add(text_hash)
        return False

    def process_documents(self, texts: List[str]) -> List[str]:
        """Clean every document, drop duplicates and low-quality ones.

        Args:
            texts: Raw documents.

        Returns:
            The cleaned documents that survived filtering, in input order.
            A summary of what was removed is printed.
        """
        cleaned_texts = []
        stats = {
            'total': len(texts),
            'duplicates': 0,
            'low_quality': 0,
            'kept': 0
        }

        for text in tqdm(texts, desc="Cleaning documents"):
            # Clean text
            cleaned = self.clean_text(text)

            # Check for duplicates
            if self.is_duplicate(cleaned):
                stats['duplicates'] += 1
                continue

            # Check quality
            if self.is_low_quality(cleaned):
                stats['low_quality'] += 1
                continue

            cleaned_texts.append(cleaned)
            stats['kept'] += 1

        # Guard the percentage: an empty input must not raise ZeroDivisionError.
        retention = stats['kept'] / stats['total'] * 100 if stats['total'] else 0.0

        print("\n" + "="*60)
        print("CLEANING STATISTICS")
        print("="*60)
        print(f"Total documents: {stats['total']}")
        print(f"Duplicates removed: {stats['duplicates']}")
        print(f"Low quality removed: {stats['low_quality']}")
        print(f"Documents kept: {stats['kept']}")
        print(f"Retention rate: {retention:.2f}%")
        print("="*60 + "\n")

        return cleaned_texts

In [None]:
# ============================================================================
# 3. TOKENIZATION
# ============================================================================

class TextTokenizer:
    """Tokenize documents and cut them into fixed-length chunks.

    Large corpora are processed in segments; every finished segment is written
    to its own ``segment_XXX.pt`` file so that memory use stays bounded and an
    interrupted run can resume from the segments already on disk.
    """

    # Chunks shorter than this many tokens are discarded.
    MIN_CHUNK_TOKENS = 50
    # Documents that tokenize to more than this many tokens are skipped.
    MAX_DOC_TOKENS = 50000
    # Documents with fewer characters than this (after strip) are skipped.
    MIN_DOC_CHARS = 100
    # Number of documents handed to the tokenizer per call.
    TOKENIZE_BATCH_SIZE = 5000

    def __init__(self, model_name: str = "gpt2", max_length: int = 512):
        """Load the tokenizer for ``model_name`` and print its configuration.

        Args:
            model_name: Name passed to ``AutoTokenizer.from_pretrained``.
            max_length: Maximum number of tokens per chunk.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
        self.max_length = max_length

        # Add padding token if not present
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        print(f"\nTokenizer Info:")
        print(f"  Model: {model_name}")
        print(f"  Vocab size: {self.tokenizer.vocab_size}")
        print(f"  Max length: {max_length}")
        print(f"  Fast tokenizer: {self.tokenizer.is_fast}")
        print(f"  Tokenization type: BPE (Byte-Pair Encoding)\n")

    def _chunk_tokens(self, token_ids: List[int]) -> List[List[int]]:
        """Split ``token_ids`` into ``max_length`` pieces, dropping short ones."""
        chunks = []
        for start in range(0, len(token_ids), self.max_length):
            chunk = token_ids[start:start + self.max_length]
            if len(chunk) >= self.MIN_CHUNK_TOKENS:
                chunks.append(chunk)
        return chunks

    def tokenize_and_chunk(self, text: str) -> List[List[int]]:
        """
        Tokenize text and chunk into sequences of max_length
        Returns list of token sequences (chunks under 50 tokens are dropped)
        """
        tokens = self.tokenizer.encode(text, add_special_tokens=False)
        return self._chunk_tokens(tokens)

    def process_documents_segments(self,
                                   documents: List[str],
                                   segment_size: int = 50000,
                                   output_dir: str = "./tokenization_segments") -> str:
        """
        Process documents in segments to avoid memory issues
        Each segment is saved immediately to disk

        A segment whose file already exists in ``output_dir`` is not
        recomputed; only its sequence count is read back.

        Args:
            documents: Cleaned documents to tokenize.
            segment_size: Number of documents per segment file (must be > 0).
            output_dir: Directory receiving the ``segment_XXX.pt`` files.

        Returns: output_dir path

        Raises:
            ValueError: If ``segment_size`` is not positive.
        """
        if segment_size <= 0:
            raise ValueError(f"segment_size must be positive, got {segment_size}")

        os.makedirs(output_dir, exist_ok=True)

        num_segments = (len(documents) + segment_size - 1) // segment_size

        print(f"\n{'='*60}")
        print("MEMORY-SAFE SEGMENTED TOKENIZATION")
        print('='*60)
        print(f"  Total documents: {len(documents):,}")
        print(f"  Segment size: {segment_size:,}")
        print(f"  Number of segments: {num_segments}")
        print(f"  Output directory: {output_dir}")
        print('='*60 + '\n')

        total_sequences = 0
        batch_size = self.TOKENIZE_BATCH_SIZE

        for seg_idx in range(num_segments):
            start = seg_idx * segment_size
            end = min(start + segment_size, len(documents))
            segment_docs = documents[start:end]

            segment_file = os.path.join(output_dir, f"segment_{seg_idx:03d}.pt")

            # Check if already processed
            if os.path.exists(segment_file):
                print(f"✓ Segment {seg_idx}/{num_segments-1} already exists, loading...")
                # NOTE: torch.load unpickles the file; only point this at
                # segment files produced by this pipeline.
                segment_data = torch.load(segment_file)
                total_sequences += len(segment_data)
                del segment_data
                continue

            print(f"\n{'─'*60}")
            print(f"SEGMENT {seg_idx}/{num_segments-1}")
            print(f"Documents: {start:,} to {end:,}")
            print('─'*60)

            segment_sequences = []
            failed_batches = 0

            # Filter short documents
            valid_docs = [doc for doc in segment_docs if len(doc.strip()) >= self.MIN_DOC_CHARS]
            print(f"  Valid documents: {len(valid_docs):,}")

            # Process in batches
            for i in tqdm(range(0, len(valid_docs), batch_size),
                          desc=f"  Processing",
                          leave=False):
                batch_docs = valid_docs[i:i + batch_size]

                try:
                    # batch tokenize
                    encoded = self.tokenizer(
                        batch_docs,
                        add_special_tokens=False,
                        truncation=False,
                        return_attention_mask=False,
                        padding=False
                    )
                except Exception as e:
                    # Best effort: skip the batch, but count it so the loss
                    # of data is visible in the segment summary below.
                    failed_batches += 1
                    print(f"\n    ⚠️ Batch {i//batch_size} failed: {str(e)[:100]}")
                    continue

                # Chunking
                for token_ids in encoded['input_ids']:
                    if len(token_ids) > self.MAX_DOC_TOKENS:
                        continue
                    segment_sequences.extend(self._chunk_tokens(token_ids))

                del encoded
                del batch_docs

                # Clear memory periodically
                if i % (batch_size * 3) == 0 and i > 0:
                    gc.collect()

            # Save this segment immediately. Write to a temporary name first
            # and rename afterwards: the existence of segment_file is used as
            # the "already processed" marker above, so a partially written
            # file left by an interrupted run must never carry that name.
            # The temporary name does not end in '.pt', so readers that list
            # '*.pt' files ignore it.
            tmp_file = segment_file + ".tmp"
            torch.save(segment_sequences, tmp_file)
            os.replace(tmp_file, segment_file)
            total_sequences += len(segment_sequences)

            print(f"  ✅ Segment complete: {len(segment_sequences):,} sequences")
            if failed_batches:
                print(f"  ⚠️ {failed_batches} batch(es) failed and were skipped in this segment")
            print(f"  💾 Saved to: {segment_file}")
            print(f"  📊 Running total: {total_sequences:,} sequences")

            # Clear memory
            del segment_docs
            del valid_docs
            del segment_sequences
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        print(f"\n{'='*60}")
        print("✅ ALL SEGMENTS COMPLETE")
        print('='*60)
        print(f"  Total sequences: {total_sequences:,}")
        print(f"  Segments saved in: {output_dir}")
        print('='*60 + '\n')

        return output_dir

    def load_segments_metadata(self, segment_dir: str) -> dict:
        """Count the sequences stored in every ``*.pt`` file of ``segment_dir``.

        Each segment file is fully loaded to obtain its length, so this call
        is as expensive as reading the whole tokenized corpus once.

        Returns:
            ``{'total_sequences': int, 'num_segments': int,
            'segments': [{'file': str, 'count': int}, ...]}`` with segments
            listed in sorted file-name order.
        """
        segment_files = sorted(f for f in os.listdir(segment_dir) if f.endswith('.pt'))

        total_sequences = 0
        segment_info = []

        print("Loading segment metadata...")
        for seg_file in segment_files:
            seg_data = torch.load(os.path.join(segment_dir, seg_file))
            count = len(seg_data)
            del seg_data
            total_sequences += count
            segment_info.append({
                'file': seg_file,
                'count': count
            })

        return {
            'total_sequences': total_sequences,
            'num_segments': len(segment_files),
            'segments': segment_info
        }


In [None]:
# ============================================================================
# 4. CUSTOM DATASET AND DATALOADER
# ============================================================================

class SegmentedPretrainingDataset(Dataset):
    """
    Dataset that loads segments on-demand
    Prevents loading all sequences into memory at once

    Only one segment is held in memory at a time. Whenever the requested index
    falls in a different segment than the previous one, that segment file is
    loaded from disk, so access patterns that jump between segments (such as a
    fully shuffled DataLoader) trigger repeated segment loads.
    """

    def __init__(self, segment_dir: str, max_length: int = 512, pad_token_id: int = 0,
                 ignore_index: int = -100):
        """
        Args:
            segment_dir: Directory containing the ``*.pt`` segment files.
            max_length: Every item is padded or truncated to this length.
            pad_token_id: Token id used to pad ``input_ids``.
            ignore_index: Value written into ``labels`` at padded positions so
                a loss configured with the same ignore index skips them.

        Raises:
            ValueError: If ``segment_dir`` contains no ``*.pt`` file.
        """
        self.segment_dir = segment_dir
        self.max_length = max_length
        self.pad_token_id = pad_token_id
        self.ignore_index = ignore_index

        # Find all segment files
        self.segment_files = sorted([
            f for f in os.listdir(segment_dir) if f.endswith('.pt')
        ])

        if not self.segment_files:
            raise ValueError(f"No segment files found in {segment_dir}")

        # Count sequences in each segment (each file is loaded once for this)
        print("Counting sequences in segments...")
        self.sequence_counts = []
        total = 0

        for seg_file in tqdm(self.segment_files, desc="Loading metadata"):
            seg_path = os.path.join(segment_dir, seg_file)
            seg_data = torch.load(seg_path)
            count = len(seg_data)
            self.sequence_counts.append(count)
            total += count
            del seg_data

        self.total_sequences = total
        print(f"✅ Total sequences: {self.total_sequences:,}\n")

        # Currently loaded segment
        self.current_segment_idx = None
        self.current_segment_data = None

    def __len__(self) -> int:
        """Total number of sequences across all segments."""
        return self.total_sequences

    def _load_segment(self, segment_idx: int):
        """Make ``segment_idx`` the cached segment, loading it if necessary."""
        if self.current_segment_idx == segment_idx:
            return

        # Drop the old segment first so two segments are never resident at
        # once. The attribute is reset (not deleted) so the object stays in a
        # consistent state even if the load below raises.
        self.current_segment_data = None
        self.current_segment_idx = None
        gc.collect()

        seg_path = os.path.join(self.segment_dir, self.segment_files[segment_idx])
        self.current_segment_data = torch.load(seg_path)
        self.current_segment_idx = segment_idx

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Return one padded/truncated sequence as tensors of length ``max_length``.

        Returns:
            Dict with ``input_ids``, ``attention_mask`` and ``labels``, all
            ``torch.long``. ``labels`` equals ``input_ids`` on real tokens and
            ``ignore_index`` on padding.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        # Resolve negative indices against the whole dataset; previously a
        # negative index was applied to segment 0 only.
        original_idx = idx
        if idx < 0:
            idx += self.total_sequences
        if not 0 <= idx < self.total_sequences:
            raise IndexError(
                f"index {original_idx} out of range for dataset of size {self.total_sequences}"
            )

        # Find which segment contains this index
        cumsum = 0
        segment_idx = 0
        local_idx = idx
        for i, count in enumerate(self.sequence_counts):
            if idx < cumsum + count:
                segment_idx = i
                local_idx = idx - cumsum
                break
            cumsum += count

        # Load the appropriate segment
        self._load_segment(segment_idx)

        # Copy so the cached segment data is never mutated.
        sequence = list(self.current_segment_data[local_idx])[:self.max_length]
        padding_length = self.max_length - len(sequence)

        input_ids = sequence + [self.pad_token_id] * padding_length
        attention_mask = [1] * len(sequence) + [0] * padding_length
        # Padding must not contribute to the loss. The pad token may be a real
        # vocabulary token (the pipeline reuses EOS when no pad token exists),
        # so padded label positions get ignore_index rather than pad_token_id.
        labels = sequence + [self.ignore_index] * padding_length

        return {
            'input_ids': torch.tensor(input_ids, dtype=torch.long),
            'attention_mask': torch.tensor(attention_mask, dtype=torch.long),
            'labels': torch.tensor(labels, dtype=torch.long)
        }

In [None]:
# ============================================================================
# 5. MAIN PIPELINE
# ============================================================================

def main(save_to_drive: bool = True,
         num_samples: int = 200000,
         max_length: int = 512,
         batch_size: int = 8,
         seed: int = 42):
    """
    Complete preprocessing pipeline

    Collects Wikipedia and OpenWebText documents, cleans them, tokenizes them
    into on-disk segments, wraps the segments in a dataset/dataloader, and
    saves five sample batches plus a metadata JSON file.

    Args:
        save_to_drive: If True, saves all outputs to Google Drive (recommended).
            Requires a Google Colab runtime.
        num_samples: Number of documents to stream from each source.
        max_length: Token length of every training sequence.
        batch_size: Batch size of the returned DataLoader.
        seed: Seed for the DataLoader shuffle, so the saved sample batches are
            reproducible.

    Returns:
        ``(dataset, dataloader)``.
    """

    print("="*60)
    print("FOUNDATION MODEL DATA PREPROCESSING PIPELINE")
    if save_to_drive:
        print("(AUTO-SAVE TO GOOGLE DRIVE)")
    print("="*60 + "\n")

    # Mount Google Drive if needed
    drive_base = None
    if save_to_drive:
        # Imported here because the module only exists on Colab runtimes.
        from google.colab import drive
        print("Mounting Google Drive...")
        drive.mount('/content/drive')

        drive_base = '/content/drive/MyDrive/ML_Assignment_Data'
        os.makedirs(drive_base, exist_ok=True)
        print(f"✅ Save location: {drive_base}\n")

    # All generated artifacts live under this root.
    output_root = drive_base if save_to_drive else '.'

    # Step 1: Data Collection
    print("STEP 1: Data Collection")
    print("-" * 60)
    collector = DataCollector(output_dir="./raw_data")

    # Collect from multiple sources for diversity
    wiki_texts = collector.collect_wikipedia(num_samples=num_samples)
    web_texts = collector.collect_openwebtext(num_samples=num_samples)

    all_texts = wiki_texts + web_texts
    collector.get_dataset_stats()

    # Step 2: Cleaning and Preprocessing
    print("\nSTEP 2: Text Cleaning and Preprocessing")
    print("-" * 60)
    cleaner = TextCleaner(min_words=50)
    cleaned_texts = cleaner.process_documents(all_texts)

    # Clear memory
    del wiki_texts, web_texts, all_texts
    gc.collect()

    # Step 3: Tokenization (MEMORY-SAFE SEGMENTED VERSION)
    print("\nSTEP 3: Tokenization")
    print("-" * 60)
    tokenizer = TextTokenizer(model_name="gpt2", max_length=max_length)

    segment_output = os.path.join(output_root, 'tokenization_segments')

    # Process in segments (prevents memory crash)
    segment_dir = tokenizer.process_documents_segments(
        documents=cleaned_texts,
        segment_size=50000,
        output_dir=segment_output
    )

    # Clear memory
    del cleaned_texts
    gc.collect()

    # Step 4: Create Dataset and DataLoader
    print("\nSTEP 4: Creating Dataset and DataLoader")
    print("-" * 60)

    # Use SegmentedDataset (memory-efficient)
    dataset = SegmentedPretrainingDataset(
        segment_dir=segment_dir,
        max_length=max_length,
        pad_token_id=tokenizer.tokenizer.pad_token_id
    )

    # The dataset already listed the segment files while counting sequences,
    # so the segment count is taken from it instead of calling
    # tokenizer.load_segments_metadata(), which would deserialize every
    # segment file a second time just to produce the same number.
    num_segments = len(dataset.segment_files)

    # Create DataLoader
    # NOTE: shuffle=True draws indices across all segments, and the dataset
    # keeps only one segment in memory, so iteration reloads segment files
    # whenever consecutive indices land in different segments.
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
        pin_memory=False,
        generator=torch.Generator().manual_seed(seed)
    )

    print(f"✅ Dataset created: {len(dataset):,} sequences")
    print(f"✅ DataLoader created")

    # Step 5: Save sample batches
    print("\nSTEP 5: Saving Sample Batches")
    print("-" * 60)

    processed_dir = os.path.join(output_root, 'processed_data')
    os.makedirs(processed_dir, exist_ok=True)

    sample_batches = []
    print("Generating sample batches...")
    for i, batch in enumerate(dataloader):
        if i >= 5:  # Save first 5 batches
            break
        sample_batches.append(batch)
        print(f"  Batch {i+1} shape: {batch['input_ids'].shape}")

    batch_path = os.path.join(processed_dir, 'sample_batches.pt')
    torch.save(sample_batches, batch_path)
    print(f"✅ Saved to: {batch_path}")

    # Step 6: Save metadata
    print("\nSTEP 6: Saving Metadata")
    print("-" * 60)

    metadata = {
        'total_sequences': len(dataset),
        'max_sequence_length': max_length,
        'vocab_size': tokenizer.tokenizer.vocab_size,
        'tokenizer_type': 'GPT2-BPE',
        'num_segments': num_segments,
        'segment_dir': segment_dir,
        'saved_to_drive': save_to_drive
    }

    metadata_path = os.path.join(processed_dir, 'metadata.json')
    with open(metadata_path, 'w') as f:
        json.dump(metadata, f, indent=2)

    print(f"✅ Saved to: {metadata_path}")

    print("\n" + "="*60)
    print("✅ PIPELINE COMPLETE!")
    print("="*60)

    if save_to_drive:
        print(f"\n💾 All files saved to Google Drive:")
        print(f"   {drive_base}")
        print(f"\n✅ Files will persist after restart!")
    else:
        print(f"\n⚠️  Files saved to temporary storage")
        print(f"   Will be deleted after restart!")

    print("\nOutputs:")
    print(f"  - Raw data: ./raw_data/")
    print(f"  - Tokenization segments: {segment_dir}/")
    print(f"  - Sample batches: {batch_path}")
    print(f"  - Metadata: {metadata_path}")
    print(f"\nDataset Info:")
    print(f"  - Total sequences: {len(dataset):,}")
    print(f"  - Number of segments: {num_segments}")
    print("\n" + "="*60)

    return dataset, dataloader


if __name__ == "__main__":

    # Entry point: run the whole pipeline with outputs persisted to Google
    # Drive. The returned dataset and dataloader are bound at module level so
    # later notebook cells can use them.
    dataset, dataloader = main(save_to_drive=True)

FOUNDATION MODEL DATA PREPROCESSING PIPELINE
(AUTO-SAVE TO GOOGLE DRIVE)

Mounting Google Drive...
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
‚úÖ Save location: /content/drive/MyDrive/ML_Assignment_Data

STEP 1: Data Collection
------------------------------------------------------------
Collecting Wikipedia data...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Resolving data files:   0%|          | 0/41 [00:00<?, ?it/s]

Wikipedia: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 200000/200000 [01:21<00:00, 2464.22it/s]


Collecting OpenWebText data...


Resolving data files:   0%|          | 0/80 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/80 [00:00<?, ?it/s]

OpenWebText: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 200000/200000 [01:01<00:00, 3251.76it/s]



DATASET STATISTICS
wikipedia.txt: 835.40 MB (0.8158 GB)
openwebtext.txt: 951.99 MB (0.9297 GB)

Total Size: 1.7455 GB


STEP 2: Text Cleaning and Preprocessing
------------------------------------------------------------


Cleaning documents: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 400000/400000 [08:57<00:00, 743.84it/s]



CLEANING STATISTICS
Total documents: 400000
Duplicates removed: 70
Low quality removed: 10695
Documents kept: 389235
Retention rate: 97.31%


STEP 3: Tokenization
------------------------------------------------------------

Tokenizer Info:
  Model: gpt2
  Vocab size: 50257
  Max length: 512
  Fast tokenizer: True
  Tokenization type: BPE (Byte-Pair Encoding)


MEMORY-SAFE SEGMENTED TOKENIZATION
  Total documents: 389,235
  Segment size: 50,000
  Number of segments: 8
  Output directory: /content/drive/MyDrive/ML_Assignment_Data/tokenization_segments

‚úì Segment 0/7 already exists, loading...
‚úì Segment 1/7 already exists, loading...
‚úì Segment 2/7 already exists, loading...
‚úì Segment 3/7 already exists, loading...
‚úì Segment 4/7 already exists, loading...
‚úì Segment 5/7 already exists, loading...
‚úì Segment 6/7 already exists, loading...
‚úì Segment 7/7 already exists, loading...

‚úÖ ALL SEGMENTS COMPLETE
  Total sequences: 930,229
  Segments saved in: /content/drive/MyDrive

Loading metadata: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 8/8 [11:16<00:00, 84.53s/it]


‚úÖ Total sequences: 930,229

‚úÖ Dataset created: 930,229 sequences
‚úÖ DataLoader created

STEP 5: Saving Sample Batches
------------------------------------------------------------
Generating sample batches...
  Batch 1 shape: torch.Size([8, 512])
  Batch 2 shape: torch.Size([8, 512])
  Batch 3 shape: torch.Size([8, 512])
  Batch 4 shape: torch.Size([8, 512])
  Batch 5 shape: torch.Size([8, 512])
‚úÖ Saved to: /content/drive/MyDrive/ML_Assignment_Data/processed_data/sample_batches.pt

STEP 6: Saving Metadata
------------------------------------------------------------
‚úÖ Saved to: /content/drive/MyDrive/ML_Assignment_Data/processed_data/metadata.json

‚úÖ PIPELINE COMPLETE!

üíæ All files saved to Google Drive:
   /content/drive/MyDrive/ML_Assignment_Data

‚úÖ Files will persist after restart!

Outputs:
  - Raw data: ./raw_data/
  - Tokenization segments: /content/drive/MyDrive/ML_Assignment_Data/tokenization_segments/
  - Sample batches: /content/drive/MyDrive/ML_Assignment_Data/