# ELM Data Preparation Pipeline - Google Colab

This notebook implements a complete data preparation pipeline for Embedding Language Models (ELM) using:
- **Dataset**: WikiText-2 from HuggingFace
- **Model**: Qwen3-Embedding-4B (2560-dimensional embeddings)

## Pipeline Overview
1. Setup & Installation
2. Configuration
3. Utility Functions
4. Download WikiText-2
5. Extract & Clean Paragraphs
6. Filter by Token Count
7. Split Dataset
8. Save Processed Data
9. Generate Embeddings
10. Dataset Class
11. Example Usage & Verification

**Estimated Runtime**: 10-20 minutes for data preparation, plus 1-2 hours for embedding generation (depending on GPU availability)

---
## 1. Setup & Installation

Install required packages for the pipeline.

In [None]:
# Install required packages
# Version specifiers are quoted so the shell does not treat ">=" as an output redirect
!pip install -q "transformers>=4.51.0" "datasets>=2.0.0" "polars>=0.20.0" "safetensors>=0.4.0" tqdm

# Install flash-attention (optional, improves performance)
# Note: FlashAttention-2 requires an Ampere-or-newer GPU (e.g. A100, L4); it does not run on a T4
# Note: This may take a few minutes to compile
!pip install -q flash-attn --no-build-isolation

In [None]:
# Import libraries
import torch
import torch.nn.functional as F
import numpy as np
import polars as pl
import re
import random
import time
from pathlib import Path
from dataclasses import dataclass, field
from typing import List, Dict, Tuple, Optional
from tqdm import tqdm
from datasets import load_dataset, DatasetDict
from transformers import AutoTokenizer, AutoModel
from safetensors.numpy import save_file, load_file
from torch import Tensor
from torch.utils.data import Dataset, DataLoader

print("✓ All packages imported successfully")

In [None]:
# Check GPU availability
if torch.cuda.is_available():
    print(f"✓ GPU Available: {torch.cuda.get_device_name(0)}")
    print(f"  VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    print(f"  CUDA Version: {torch.version.cuda}")
else:
    print("⚠ WARNING: No GPU detected. This will be very slow on CPU.")
    print("  Go to Runtime > Change runtime type > Hardware accelerator > GPU")

### Optional: Mount Google Drive for Persistent Storage

Uncomment and run this cell if you want to save data to Google Drive (recommended for large datasets).

In [None]:
# # Mount Google Drive
# from google.colab import drive
# drive.mount('/content/drive')
# 
# # Use Drive for data storage
# BASE_DIR = Path('/content/drive/MyDrive/elm_data')
# BASE_DIR.mkdir(parents=True, exist_ok=True)
# print(f"Using Google Drive: {BASE_DIR}")

---
## 2. Configuration

Define configuration settings for the pipeline.

In [None]:
@dataclass
class Config:
    """Configuration for the ELM data preparation pipeline."""
    
    # Paths (Colab-friendly)
    base_dir: Path = field(default_factory=lambda: Path('/content/elm_data'))
    data_dir: Path = field(init=False)
    processed_dir: Path = field(init=False)
    embeddings_dir: Path = field(init=False)
    
    # Dataset parameters
    hf_dataset_name: str = "Salesforce/wikitext"
    hf_dataset_config: str = "wikitext-2-v1"
    
    # Text filtering parameters
    min_tokens: int = 100
    max_tokens: int = 2000
    
    # Dataset split ratios
    train_ratio: float = 0.8
    val_ratio: float = 0.1
    test_ratio: float = 0.1
    random_seed: int = 42
    
    # Embedding model parameters
    model_name: str = "Qwen/Qwen3-Embedding-4B"
    embedding_dim: int = 2560
    max_length: int = 8192
    batch_size: int = 4  # Reduced for Colab T4 GPU (16GB VRAM)
    
    # Model optimization parameters
    use_flash_attention: bool = True
    use_fp16: bool = True
    
    def __post_init__(self):
        """Initialize derived paths."""
        self.data_dir = self.base_dir / "data"
        self.processed_dir = self.data_dir / "wikitext2_processed"
        self.embeddings_dir = self.data_dir / "embeddings"
        
        # Validate ratios
        if not abs(self.train_ratio + self.val_ratio + self.test_ratio - 1.0) < 1e-6:
            raise ValueError("Train, val, and test ratios must sum to 1.0")
    
    def create_directories(self):
        """Create necessary directories if they don't exist."""
        self.data_dir.mkdir(parents=True, exist_ok=True)
        self.processed_dir.mkdir(parents=True, exist_ok=True)
        self.embeddings_dir.mkdir(parents=True, exist_ok=True)
    
    def get_processed_path(self, split: str) -> Path:
        """Get path to processed data file for a given split."""
        return self.processed_dir / f"{split}.parquet"
    
    def get_embeddings_path(self, split: str) -> Path:
        """Get path to embeddings file for a given split."""
        return self.embeddings_dir / f"{split}_embeddings.safetensors"

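# Initialize configuration
# If you mounted Google Drive above, use Config(base_dir=BASE_DIR) instead,
# otherwise the data is written to the default /content/elm_data
config = Config()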
config.create_directories()

print("Configuration:")
print(f"  Base directory: {config.base_dir}")
print(f"  Model: {config.model_name}")
print(f"  Token range: {config.min_tokens}-{config.max_tokens}")
print(f"  Batch size: {config.batch_size}")
print(f"  Split: {config.train_ratio}/{config.val_ratio}/{config.test_ratio}")
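
The derived paths are computed once, in `__post_init__`, so they follow `base_dir` only when it is passed to the constructor. A minimal sketch of that behaviour, using a hypothetical `MiniConfig` stand-in (not part of this notebook) rather than the full `Config` class:

```python
from dataclasses import dataclass, field
from pathlib import Path


@dataclass
class MiniConfig:
    """Hypothetical, trimmed-down stand-in for Config (illustration only)."""

    base_dir: Path = field(default_factory=lambda: Path("/content/elm_data"))
    data_dir: Path = field(init=False)

    def __post_init__(self):
        # Derived paths are computed exactly once, at construction time.
        self.data_dir = self.base_dir / "data"


# Passing base_dir to the constructor updates the derived path.
drive_cfg = MiniConfig(base_dir=Path("/content/drive/MyDrive/elm_data"))
print(drive_cfg.data_dir.as_posix())  # /content/drive/MyDrive/elm_data/data

# Reassigning base_dir afterwards does not: data_dir still uses the default.
stale_cfg = MiniConfig()
stale_cfg.base_dir = Path("/content/drive/MyDrive/elm_data")
print(stale_cfg.data_dir.as_posix())  # /content/elm_data/data
```

In practice this suggests choosing the storage location when the configuration object is created, not afterwards.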

---
## 3. Utility Functions

Helper functions for formatting and logging.

In [None]:
def format_time(seconds: float) -> str:
    """Format seconds as human-readable string."""
    if seconds < 60:
        return f"{seconds:.1f}s"
    elif seconds < 3600:
        minutes = int(seconds // 60)
        secs = seconds % 60
        return f"{minutes}m {secs:.0f}s"
    else:
        hours = int(seconds // 3600)
        minutes = int((seconds % 3600) // 60)
        secs = seconds % 60
        return f"{hours}h {minutes}m {secs:.0f}s"

def format_size(num_bytes: int) -> str:
    """Format bytes as human-readable string."""
    for unit in ["B", "KB", "MB", "GB", "TB"]:
        if num_bytes < 1024.0:
            return f"{num_bytes:.2f} {unit}"
        num_bytes /= 1024.0
    return f"{num_bytes:.2f} PB"

def count_parameters(model) -> int:
    """Count the number of trainable parameters in a model."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print("✓ Utility functions defined")

---
## 4. STEP 1: Download WikiText-2 Dataset

Download the WikiText-2 dataset from HuggingFace.

In [None]:
print("=" * 80)
print("STEP 1: Loading WikiText-2 dataset")
print("=" * 80)

start_time = time.time()

# Load dataset from HuggingFace
dataset = load_dataset(
    config.hf_dataset_name,
    config.hf_dataset_config,
    trust_remote_code=False
)

print(f"\n✓ Dataset loaded successfully in {format_time(time.time() - start_time)}")
print(f"  Train samples: {len(dataset['train'])}")
print(f"  Validation samples: {len(dataset['validation'])}")
print(f"  Test samples: {len(dataset['test'])}")

---
## 5. STEP 2: Extract & Clean Paragraphs

Extract paragraphs from the dataset and clean Wikipedia formatting artifacts.
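
The extraction step treats an empty row as a paragraph break and joins the non-empty rows in between. A small sketch of that grouping rule on made-up rows; `group_paragraphs` and the sample `rows` are illustrative names, not part of the pipeline:

```python
from typing import Iterable, List


def group_paragraphs(lines: Iterable[str]) -> List[str]:
    """Join consecutive non-empty lines; an empty line closes the paragraph."""
    paragraphs, current = [], []
    for line in lines:
        line = line.strip()
        if line:
            current.append(line)
        elif current:
            paragraphs.append(" ".join(current))
            current = []
    if current:  # flush the trailing paragraph
        paragraphs.append(" ".join(current))
    return paragraphs


# Hypothetical rows, shaped like the "text" column of a WikiText split.
rows = ["", " = Title = \n", "", " First sentence . \n", " Second sentence . \n", ""]
print(group_paragraphs(rows))
# ['= Title =', 'First sentence . Second sentence .']
```

Section headers come out as their own short "paragraphs", which is why the cleaning and quality checks run afterwards.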

In [None]:
def clean_text(text: str) -> str:
    """Clean Wikipedia formatting artifacts from text."""
    # Remove Wikipedia section headers (e.g., " = = Section = = ")
    text = re.sub(r'\s*=\s+=\s+.*?\s+=\s+=\s*', ' ', text)
    text = re.sub(r'\s*=\s+.*?\s+=\s*', ' ', text)
    
    # Remove Wikipedia formatting artifacts
    text = text.replace("@-@", "-")
    text = text.replace("@.@", ".")
    text = text.replace("@,@", ",")
    
    # Remove <unk> tokens
    text = text.replace("<unk>", "")
    
    # Collapse multiple spaces
    text = re.sub(r'\s+', ' ', text)
    
    # Strip leading/trailing whitespace
    text = text.strip()
    
    return text

def is_low_quality(text: str) -> bool:
    """Check if text is low quality and should be filtered out."""
    if not text:
        return True
    
    total_chars = len(text)
    if total_chars == 0:
        return True
    
    # Count special characters and digits
    special_chars = sum(1 for c in text if not c.isalnum() and not c.isspace())
    digits = sum(1 for c in text if c.isdigit())
    
    # Filter if mostly special characters (>50%)
    if special_chars / total_chars > 0.5:
        return True
    
    # Filter if mostly digits (>50%)
    if digits / total_chars > 0.5:
        return True
    
    # Filter very short texts (< 20 characters)
    if total_chars < 20:
        return True
    
    return False

def extract_paragraphs(dataset: DatasetDict) -> List[str]:
    """Extract paragraphs from HuggingFace dataset."""
    all_paragraphs = []
    current_paragraph = []
    
    # Process all splits together
    for split in ["train", "validation", "test"]:
        print(f"  Processing {split} split...")
        
        for item in tqdm(dataset[split], desc=f"  Extracting from {split}"):
            text = item.get("text", "").strip()
            
            # Empty line indicates paragraph break
            if not text:
                if current_paragraph:
                    paragraph = " ".join(current_paragraph)
                    all_paragraphs.append(paragraph)
                    current_paragraph = []
            else:
                current_paragraph.append(text)
        
        # Don't forget the last paragraph
        if current_paragraph:
            paragraph = " ".join(current_paragraph)
            all_paragraphs.append(paragraph)
            current_paragraph = []
    
    return all_paragraphs

print("=" * 80)
print("STEP 2: Extracting and preprocessing paragraphs")
print("=" * 80)

start_time = time.time()
paragraphs = extract_paragraphs(dataset)

print(f"\n✓ Extracted {len(paragraphs):,} raw paragraphs in {format_time(time.time() - start_time)}")
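
WikiText writes its `@-@`, `@,@` and `@.@` markers with a space on each side, so a plain `str.replace` leaves those spaces in place. The sketch below compares that behaviour with a hypothetical variant, `tight_clean` (an assumption, not part of this notebook), which also removes the surrounding spaces. The sample sentence is made up.

```python
import re

SAMPLE = "The 1 @,@ 000 @-@ year old <unk> site"


def replace_clean(text: str) -> str:
    """Mirror of the marker handling in clean_text: swap markers, drop <unk>."""
    for marker, char in (("@-@", "-"), ("@.@", "."), ("@,@", ",")):
        text = text.replace(marker, char)
    text = text.replace("<unk>", "")
    return re.sub(r"\s+", " ", text).strip()


def tight_clean(text: str) -> str:
    """Hypothetical variant: also remove the spaces around each marker."""
    text = re.sub(r"\s*@([-.,])@\s*", r"\1", text)
    text = text.replace("<unk>", "")
    return re.sub(r"\s+", " ", text).strip()


print(replace_clean(SAMPLE))  # The 1 , 000 - year old site
print(tight_clean(SAMPLE))    # The 1,000-year old site
```

Which form is preferable depends on the downstream model; the tighter variant is closer to ordinary prose and usually yields slightly fewer tokens.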

---
## 6. STEP 3: Filter Paragraphs by Token Count

Filter paragraphs to keep only those with 100-2000 tokens.

In [None]:
print("=" * 80)
print("STEP 3: Filtering paragraphs")
print("=" * 80)

start_time = time.time()

# Load tokenizer for filtering
print(f"Loading tokenizer: {config.model_name}")
tokenizer = AutoTokenizer.from_pretrained(config.model_name)

# Filter paragraphs
filtered_data = []

print(f"\nFiltering paragraphs (token range: {config.min_tokens}-{config.max_tokens})...")
for text in tqdm(paragraphs, desc="Filtering"):
    # Clean the text
    cleaned_text = clean_text(text)
    
    # Skip low-quality text
    if is_low_quality(cleaned_text):
        continue
    
    # Tokenize to count tokens
    tokens = tokenizer.encode(cleaned_text, add_special_tokens=False)
    num_tokens = len(tokens)
    
    # Filter by token count
    if config.min_tokens <= num_tokens <= config.max_tokens:
        filtered_data.append({
            "text": cleaned_text,
            "token_count": num_tokens,
            "char_count": len(cleaned_text),
        })

print(f"\n✓ Kept {len(filtered_data):,}/{len(paragraphs):,} paragraphs "
      f"({100 * len(filtered_data) / len(paragraphs):.1f}%)")
print(f"  Completed in {format_time(time.time() - start_time)}")
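
The filtering rule itself is independent of the tokenizer. As a rough sketch, the function below applies the same inclusive range check with a whitespace splitter standing in for the model tokenizer; `filter_by_token_count` and `whitespace_count` are hypothetical helpers, and whitespace counts will differ from real token counts.

```python
from typing import Callable, Dict, List


def filter_by_token_count(
    texts: List[str],
    count_tokens: Callable[[str], int],
    min_tokens: int,
    max_tokens: int,
) -> List[Dict]:
    """Keep texts whose token count lies in [min_tokens, max_tokens]."""
    kept = []
    for text in texts:
        n = count_tokens(text)
        if min_tokens <= n <= max_tokens:
            kept.append({"text": text, "token_count": n, "char_count": len(text)})
    return kept


# Stand-in tokenizer (assumption): whitespace splitting, not the model tokenizer.
def whitespace_count(text: str) -> int:
    return len(text.split())


texts = ["too short", "this one has exactly six words", "word " * 12]
kept = filter_by_token_count(texts, whitespace_count, min_tokens=3, max_tokens=10)
print([row["token_count"] for row in kept])  # [6]
```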

---
## 7. STEP 4: Split Dataset

Split the filtered data into train/validation/test sets (80/10/10).

In [None]:
print("=" * 80)
print("STEP 4: Splitting dataset")
print("=" * 80)

start_time = time.time()

# Shuffle data with fixed seed
random.seed(config.random_seed)
shuffled_data = filtered_data.copy()
random.shuffle(shuffled_data)

# Calculate split indices
n_total = len(shuffled_data)
n_train = int(n_total * config.train_ratio)
n_val = int(n_total * config.val_ratio)

# Split the data
train_data = shuffled_data[:n_train]
val_data = shuffled_data[n_train:n_train + n_val]
test_data = shuffled_data[n_train + n_val:]

print(f"\n✓ Dataset split ({config.train_ratio}/{config.val_ratio}/{config.test_ratio}):")
print(f"  Train: {len(train_data):,} samples")
print(f"  Val: {len(val_data):,} samples")
print(f"  Test: {len(test_data):,} samples")
print(f"  Total: {len(train_data) + len(val_data) + len(test_data):,} samples")
print(f"  Completed in {format_time(time.time() - start_time)}")
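
Because the train and validation sizes are truncated with `int()`, the test split absorbs the rounding remainder. A minimal sketch of the same arithmetic on indices; `split_indices` is a hypothetical helper, and it uses a private `random.Random(seed)` instance instead of seeding the global generator as the cell above does:

```python
import random
from typing import List, Tuple


def split_indices(
    n_total: int, train_ratio: float, val_ratio: float, seed: int
) -> Tuple[List[int], List[int], List[int]]:
    """Shuffle range(n_total) with a private RNG and cut it into three parts."""
    indices = list(range(n_total))
    random.Random(seed).shuffle(indices)  # does not touch the global RNG state
    n_train = int(n_total * train_ratio)
    n_val = int(n_total * val_ratio)
    return (
        indices[:n_train],
        indices[n_train:n_train + n_val],
        indices[n_train + n_val:],  # test receives the rounding remainder
    )


train_idx, val_idx, test_idx = split_indices(1003, 0.8, 0.1, seed=42)
print(len(train_idx), len(val_idx), len(test_idx))  # 802 100 101
```

With a fixed seed the partition is reproducible, and the three parts are disjoint and together cover every index.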

---
## 8. STEP 5: Save Processed Data

Save the processed text data as Parquet files.

In [None]:
print("=" * 80)
print("STEP 5: Saving processed data")
print("=" * 80)

start_time = time.time()

# Add text IDs to each split
for idx, item in enumerate(train_data):
    item["text_id"] = f"train_{idx}"

for idx, item in enumerate(val_data):
    item["text_id"] = f"val_{idx}"

for idx, item in enumerate(test_data):
    item["text_id"] = f"test_{idx}"

# Save each split
splits = {
    "train": train_data,
    "val": val_data,
    "test": test_data,
}

for split_name, split_data in splits.items():
    output_path = config.get_processed_path(split_name)
    
    # Convert to Polars DataFrame and save
    df = pl.DataFrame(split_data)
    df.write_parquet(output_path)
    
    file_size = output_path.stat().st_size
    print(f"  ✓ Saved {split_name}: {output_path.name} ({len(split_data):,} samples, {format_size(file_size)})")

print(f"\n✓ All data saved successfully in {format_time(time.time() - start_time)}")
print(f"  Output directory: {config.processed_dir}")

---
## 9. STEP 6: Generate Embeddings

Load the Qwen3-Embedding-4B model and generate embeddings for all splits.

**Note**: This step takes the longest (1-2 hours depending on GPU).

In [None]:
def last_token_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
    """Extract embeddings from the last token position."""
    left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
    if left_padding:
        return last_hidden_states[:, -1]
    else:
        sequence_lengths = attention_mask.sum(dim=1) - 1
        batch_size = last_hidden_states.shape[0]
        return last_hidden_states[
            torch.arange(batch_size, device=last_hidden_states.device),
            sequence_lengths
        ]

@torch.no_grad()
def generate_embeddings_batch(
    model: AutoModel,
    tokenizer: AutoTokenizer,
    texts: List[str],
    config: Config,
    desc: str = "Generating embeddings"
) -> np.ndarray:
    """Generate embeddings for a list of texts."""
    all_embeddings = []
    num_batches = (len(texts) + config.batch_size - 1) // config.batch_size
    
    for i in tqdm(range(0, len(texts), config.batch_size), total=num_batches, desc=desc):
        batch_texts = texts[i:i + config.batch_size]
        
        # Tokenize batch
        batch_dict = tokenizer(
            batch_texts,
            padding=True,
            truncation=True,
            max_length=config.max_length,
            return_tensors="pt",
        )
        
        # Move to device
        batch_dict = {k: v.to(model.device) for k, v in batch_dict.items()}
        
        # Forward pass
        outputs = model(**batch_dict)
        
        # Extract embeddings using last token pooling
        embeddings = last_token_pool(
            outputs.last_hidden_state,
            batch_dict['attention_mask']
        )
        
        # Normalize embeddings (L2 norm)
        embeddings = F.normalize(embeddings, p=2, dim=1)
        
        # Move to CPU and convert to numpy
        embeddings_np = embeddings.cpu().float().numpy()
        all_embeddings.append(embeddings_np)
    
    # Concatenate all batches
    all_embeddings = np.vstack(all_embeddings)
    return all_embeddings

print("✓ Embedding functions defined")
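
Last-token pooling depends on where the padding sits. The pure-Python sketch below reproduces only the position logic of `last_token_pool` on toy attention masks; `last_token_positions` is an illustrative helper and works on nested lists instead of tensors.

```python
from typing import List


def last_token_positions(attention_mask: List[List[int]]) -> List[int]:
    """Index of the last real (non-padding) token in each row of the batch."""
    # Same test as last_token_pool: if every row ends in a 1, the batch is left-padded.
    left_padded = all(row[-1] == 1 for row in attention_mask)
    if left_padded:
        return [len(row) - 1 for row in attention_mask]
    # Right padding: the real tokens come first, so the last one sits at count - 1.
    return [sum(row) - 1 for row in attention_mask]


left = [[0, 0, 1, 1], [1, 1, 1, 1]]   # padding on the left
right = [[1, 1, 0, 0], [1, 1, 1, 1]]  # padding on the right
print(last_token_positions(left))   # [3, 3]
print(last_token_positions(right))  # [1, 3]
```

With left padding every sequence ends at the final position, so the hidden state can be read with a single slice.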

In [None]:
print("=" * 80)
print("STEP 6: Generating embeddings")
print("=" * 80)

start_time = time.time()

# Load embedding model
print(f"\nLoading embedding model: {config.model_name}")

# Load tokenizer with left padding
embedding_tokenizer = AutoTokenizer.from_pretrained(
    config.model_name,
    padding_side='left'
)

# Load model with optimizations
try:
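    # FlashAttention-2 needs compute capability 8.0+ (Ampere or newer); on older GPUs
    # such as the T4 (7.5) the model can load but then fail in the forward pass.
    use_fa2 = (
        config.use_flash_attention
        and torch.cuda.is_available()
        and torch.cuda.get_device_capability(0)[0] >= 8
    )
    if use_fa2: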
        print("  Loading with flash_attention_2...")
        model = AutoModel.from_pretrained(
            config.model_name,
            attn_implementation="flash_attention_2",
            torch_dtype=torch.float16 if config.use_fp16 else torch.float32
        )
    else:
        print("  Loading without flash attention...")
        model = AutoModel.from_pretrained(
            config.model_name,
            torch_dtype=torch.float16 if config.use_fp16 else torch.float32
        )
except Exception as e:
    print(f"  Warning: {e}")
    print("  Falling back to standard attention...")
    model = AutoModel.from_pretrained(
        config.model_name,
        torch_dtype=torch.float16 if config.use_fp16 else torch.float32
    )

# Move to GPU if available
if torch.cuda.is_available():
    model = model.cuda()
    print(f"  ✓ Model loaded on GPU: {torch.cuda.get_device_name(0)}")
else:
    print("  ⚠ Using CPU (this will be VERY slow)")

model.eval()

num_params = count_parameters(model)
print(f"  Model parameters: {num_params:,}")
print(f"  Batch size: {config.batch_size}")
print(f"  Max length: {config.max_length}")

In [None]:
# Generate embeddings for each split
print("\nGenerating embeddings for all splits...\n")

for split_name in ["train", "val", "test"]:
    print(f"Processing {split_name} split...")
    
    # Load texts from parquet
    parquet_path = config.get_processed_path(split_name)
    df = pl.read_parquet(parquet_path)
    texts = df["text"].to_list()
    
    print(f"  Loaded {len(texts):,} texts")
    
    # Generate embeddings
    embeddings = generate_embeddings_batch(
        model=model,
        tokenizer=embedding_tokenizer,
        texts=texts,
        config=config,
        desc=f"  Embedding {split_name}"
    )
    
    print(f"  Generated embeddings shape: {embeddings.shape}")
    
    # Save embeddings
    output_path = config.get_embeddings_path(split_name)
    
    tensors = {"embeddings": embeddings}
    metadata = {
        "split": split_name,
        "model": config.model_name,
        "num_texts": str(len(texts)),
        "shape": str(embeddings.shape),
        "dtype": str(embeddings.dtype),
    }
    
    save_file(tensors, str(output_path), metadata=metadata)
    
    file_size = output_path.stat().st_size
    print(f"  ✓ Saved to: {output_path.name} ({format_size(file_size)})\n")

print(f"✓ All embeddings completed in {format_time(time.time() - start_time)}")
print(f"  Output directory: {config.embeddings_dir}")

# Free up GPU memory
del model
if torch.cuda.is_available():
    torch.cuda.empty_cache()
print("  GPU memory cleared")
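
The size of each embeddings file can be estimated before running this step: the arrays are stored as float32, so each value takes 4 bytes. A back-of-the-envelope sketch, where `embedding_bytes` is a hypothetical helper, the paragraph count is made up, and the small safetensors header is ignored:

```python
def embedding_bytes(num_texts: int, dim: int = 2560, bytes_per_value: int = 4) -> int:
    """Raw size of a (num_texts, dim) array; 4 bytes per value for float32."""
    return num_texts * dim * bytes_per_value


n = 10_000  # hypothetical number of paragraphs in a split
size = embedding_bytes(n)
print(size)                         # 102400000
print(f"{size / 1024**2:.2f} MiB")  # 97.66 MiB
```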

---
## 10. STEP 7: Dataset Class

Define PyTorch Dataset class for loading and using the processed data.

In [None]:
class ELMDataset(Dataset):
    """Dataset class for ELM that loads text, embeddings, and metadata."""
    
    def __init__(
        self,
        data_dir: Path,
        split: str = "train",
        load_embeddings: bool = True,
        config: Optional[Config] = None
    ):
        """Initialize ELMDataset."""
        self.split = split
        self.load_embeddings = load_embeddings
        
        if config is None:
            # Pass base_dir to the constructor so the derived paths are built from it
            base_dir = data_dir.parent if data_dir.name == "data" else data_dir
            config = Config(base_dir=base_dir)
        
        self.config = config
        
        # Load text data
        self.parquet_path = config.get_processed_path(split)
        if not self.parquet_path.exists():
            raise FileNotFoundError(f"{split} data file not found: {self.parquet_path}")
        
        self.df = pl.read_parquet(self.parquet_path)
        
        # Extract columns
        self.texts = self.df["text"].to_list()
        self.text_ids = self.df["text_id"].to_list()
        self.token_counts = self.df["token_count"].to_list()
        self.char_counts = self.df["char_count"].to_list()
        
        # Load embeddings if requested
        self.embeddings = None
        if load_embeddings:
            self.embeddings_path = config.get_embeddings_path(split)
            if not self.embeddings_path.exists():
                raise FileNotFoundError(f"{split} embeddings file not found: {self.embeddings_path}")
            
            tensors = load_file(str(self.embeddings_path))
            self.embeddings = tensors["embeddings"]
            
            # Validate dimensions match
            if len(self.embeddings) != len(self.texts):
                raise ValueError(
                    f"Mismatch between texts ({len(self.texts)}) "
                    f"and embeddings ({len(self.embeddings)})"
                )
    
    def __len__(self) -> int:
        return len(self.texts)
    
    def __getitem__(self, idx: int) -> Dict:
        """Get a single sample."""
        item = {
            "text": self.texts[idx],
            "metadata": {
                "text_id": self.text_ids[idx],
                "token_count": self.token_counts[idx],
                "char_count": self.char_counts[idx],
            }
        }
        
        if self.embeddings is not None:
            item["embedding"] = self.embeddings[idx]
        
        return item
    
    def interpolate_embeddings(self, idx1: int, idx2: int, alpha: float) -> np.ndarray:
        """Linear interpolation between two embeddings."""
        if self.embeddings is None:
            raise ValueError("Embeddings not loaded. Set load_embeddings=True")
        
        emb1 = self.embeddings[idx1]
        emb2 = self.embeddings[idx2]
        
        # Linear interpolation
        interpolated = (1 - alpha) * emb1 + alpha * emb2
        
        # Re-normalize to unit length
        norm = np.linalg.norm(interpolated)
        if norm > 0:
            interpolated = interpolated / norm
        
        return interpolated
    
    def get_statistics(self) -> Dict:
        """Get dataset statistics."""
        stats = {
            "split": self.split,
            "num_samples": len(self),
            "avg_token_count": np.mean(self.token_counts),
            "avg_char_count": np.mean(self.char_counts),
            "min_token_count": np.min(self.token_counts),
            "max_token_count": np.max(self.token_counts),
        }
        
        if self.embeddings is not None:
            stats["embedding_dim"] = self.embeddings.shape[1]
            stats["embedding_dtype"] = str(self.embeddings.dtype)
        
        return stats
    
    def get_dataloader(self, batch_size: int, shuffle: bool = False) -> DataLoader:
        """Create a PyTorch DataLoader."""
        return DataLoader(
            self,
            batch_size=batch_size,
            shuffle=shuffle,
            collate_fn=ELMCollator(),
        )

class ELMCollator:
    """Collate function for batching variable-length texts."""
    
    def __call__(self, batch: List[Dict]) -> Dict:
        texts = [item["text"] for item in batch]
        metadata = [item["metadata"] for item in batch]
        
        collated = {
            "text": texts,
            "metadata": metadata,
        }
        
        if "embedding" in batch[0]:
            embeddings = np.stack([item["embedding"] for item in batch])
            collated["embedding"] = embeddings
        
        return collated

print("✓ Dataset classes defined")

---
## 11. STEP 8: Example Usage & Verification

Load the processed data and demonstrate usage examples.

In [None]:
print("=" * 80)
print("STEP 8: Example Usage & Verification")
print("=" * 80)

# Load training dataset
print("\n1. Loading dataset with embeddings...")
train_dataset = ELMDataset(
    data_dir=config.data_dir,
    split="train",
    load_embeddings=True,
    config=config
)

print(f"  ✓ Loaded {len(train_dataset):,} samples")

# Get a sample
sample = train_dataset[0]
print(f"\n  Sample 0:")
print(f"    Text: {sample['text'][:100]}...")
print(f"    Embedding shape: {sample['embedding'].shape}")
print(f"    Token count: {sample['metadata']['token_count']}")

In [None]:
# Dataset statistics
print("\n2. Dataset statistics...")
stats = train_dataset.get_statistics()
print("\n  Train dataset:")
for key, value in stats.items():
    if isinstance(value, float):
        print(f"    {key}: {value:.2f}")
    else:
        print(f"    {key}: {value}")

In [None]:
# Using DataLoader
print("\n3. Creating PyTorch DataLoader...")
dataloader = train_dataset.get_dataloader(batch_size=8, shuffle=True)
print(f"  ✓ DataLoader created with {len(dataloader)} batches")

# Get first batch
first_batch = next(iter(dataloader))
print(f"\n  First batch:")
print(f"    Number of texts: {len(first_batch['text'])}")
print(f"    Embeddings shape: {first_batch['embedding'].shape}")
print(f"    Sample text: {first_batch['text'][0][:80]}...")

In [None]:
# Embedding interpolation
print("\n4. Embedding interpolation...")
idx1, idx2 = 0, min(100, len(train_dataset) - 1)  # stay in range for small datasets
print(f"\n  Interpolating between samples {idx1} and {idx2}:")
print(f"    Text 1: {train_dataset[idx1]['text'][:60]}...")
print(f"    Text 2: {train_dataset[idx2]['text'][:60]}...")

print("\n  Interpolated embeddings:")
for alpha in [0.0, 0.25, 0.5, 0.75, 1.0]:
    interpolated = train_dataset.interpolate_embeddings(idx1, idx2, alpha)
    print(f"    alpha={alpha:.2f}: shape={interpolated.shape}, "
          f"norm={np.linalg.norm(interpolated):.4f}")
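
The re-normalization in `interpolate_embeddings` matters because a straight line between two unit vectors passes inside the unit sphere. A toy sketch with two orthogonal 2-D vectors; the `norm` and `lerp` helpers are illustrative, not part of the notebook:

```python
import math
from typing import List


def norm(v: List[float]) -> float:
    """Euclidean (L2) length of a vector."""
    return math.sqrt(sum(x * x for x in v))


def lerp(a: List[float], b: List[float], alpha: float) -> List[float]:
    """Plain linear interpolation, without re-normalization."""
    return [(1 - alpha) * x + alpha * y for x, y in zip(a, b)]


a, b = [1.0, 0.0], [0.0, 1.0]  # two orthogonal unit vectors (toy data)
mid = lerp(a, b, 0.5)
print(round(norm(mid), 4))  # 0.7071

# Dividing by the norm projects the midpoint back onto the unit sphere.
unit_mid = [x / norm(mid) for x in mid]
print(round(norm(unit_mid), 4))  # 1.0
```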

In [None]:
# Embedding similarity
print("\n5. Computing embedding similarity...")
emb1 = train_dataset[0]['embedding']
emb2 = train_dataset[1]['embedding']
emb3 = train_dataset[100]['embedding']

# Compute cosine similarity (embeddings are already normalized)
sim_01 = np.dot(emb1, emb2)
sim_02 = np.dot(emb1, emb3)
sim_12 = np.dot(emb2, emb3)

print("\n  Cosine similarities:")
print(f"    Sample 0 vs 1: {sim_01:.4f}")
print(f"    Sample 0 vs 100: {sim_02:.4f}")
print(f"    Sample 1 vs 100: {sim_12:.4f}")

print(f"\n  Sample 0: {train_dataset[0]['text'][:80]}...")
print(f"  Sample 1: {train_dataset[1]['text'][:80]}...")
print(f"  Sample 100: {train_dataset[100]['text'][:80]}...")
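
The dot product can stand in for cosine similarity here only because the stored embeddings are L2-normalized. A small sketch of that equivalence on toy vectors; the helper names are illustrative:

```python
import math
from typing import List


def dot(a: List[float], b: List[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def normalize(v: List[float]) -> List[float]:
    """Scale a vector to unit L2 length."""
    length = math.sqrt(dot(v, v))
    return [x / length for x in v]


def cosine(a: List[float], b: List[float]) -> float:
    """Cosine similarity for vectors of any length."""
    return dot(a, b) / (math.sqrt(dot(a, a)) * math.sqrt(dot(b, b)))


raw_a, raw_b = [3.0, 4.0], [4.0, 3.0]
print(round(cosine(raw_a, raw_b), 4))                     # 0.96
print(round(dot(normalize(raw_a), normalize(raw_b)), 4))  # 0.96
print(round(dot(raw_a, raw_b), 4))                        # 24.0 (not a cosine)
```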

In [None]:
# Load all splits
print("\n6. Loading all splits...")
splits = {}
for split_name in ["train", "val", "test"]:
    splits[split_name] = ELMDataset(
        data_dir=config.data_dir,
        split=split_name,
        load_embeddings=True,
        config=config
    )

print("\n  Dataset sizes:")
# Use `ds` as the loop variable so the HuggingFace `dataset` from STEP 1 is not overwritten
for split_name, ds in splits.items():
    print(f"    {split_name}: {len(ds):,} samples")

total_samples = sum(len(ds) for ds in splits.values())
print(f"    Total: {total_samples:,} samples")

print("\n" + "=" * 80)
print("✓ PIPELINE COMPLETED SUCCESSFULLY")
print("=" * 80)
print(f"\nOutput files:")
print(f"  Processed data: {config.processed_dir}")
print(f"  Embeddings: {config.embeddings_dir}")
print(f"\nYou can now use the ELMDataset class to load and use this data!")

---
## Summary

This notebook has successfully:
1. ✓ Downloaded WikiText-2 dataset
2. ✓ Extracted and cleaned paragraphs
3. ✓ Filtered by token count (100-2000 tokens)
4. ✓ Split into train/val/test (80/10/10)
5. ✓ Saved processed data as Parquet files
6. ✓ Generated 2560-dimensional embeddings using Qwen3-Embedding-4B
7. ✓ Created PyTorch Dataset class for easy loading

### Next Steps

You can now use the processed data for:
- Training language models
- Embedding-based retrieval systems
- Semantic similarity tasks
- Text generation experiments

### Downloading Results

To download the processed data:
```python
# Zip the output directory
!zip -r elm_data.zip /content/elm_data/

# Download (in Colab)
from google.colab import files
files.download('elm_data.zip')
```

Or use the Google Drive mount option at the beginning to save directly to your Drive.