# Synthetic PII Data Generation

Generate 11,000 rows of challenging PII examples targeting the six feature dimensions from Singh & Narayanan (2025)
"Unmasking the Reality of PII Masking Models".

Target dimensions:

- basic: Straightforward, well-formatted entities
- contextual: Entities requiring disambiguation
- noisy: Real-world imperfections (typos, OCR errors, abbreviations)
- evolving: New/emerging PII formats (crypto, UPI, modern handles)
- multilingual: International PII formats in English text
- adversarial: Intentionally confusing inputs designed to fool NER models


## Imports and Environment Setup


In [1]:
%pip install xai-sdk openai pandas faker tqdm pydantic python-dotenv tenacity json-repair


ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
streamlit 1.45.1 requires packaging<25,>=20, but you have packaging 25.0 which is incompatible.


In [2]:
import asyncio
import json
import os
import random
import re
import string
import sys
import time
from collections import defaultdict
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from enum import Enum
from pathlib import Path
from typing import Any

import pandas as pd
from dotenv import load_dotenv
from faker import Faker
from pydantic import BaseModel, Field, field_validator, model_validator
from tenacity import (
    retry,
    stop_after_attempt,
    wait_exponential,
    retry_if_exception_type,
)
from tqdm.auto import tqdm

In [3]:
# Load environment variables from .env file
load_dotenv()

# Verify required API keys are present
XAI_API_KEY: str | None = os.getenv("XAI_API_KEY")
if not XAI_API_KEY:
    raise EnvironmentError(
        "XAI_API_KEY not found in environment. "
        "Create a .env file with your xAI API key."
    )

print("✓ Environment loaded successfully")
print(f"  XAI_API_KEY: {'*' * 8}...{XAI_API_KEY[-4:]}")

✓ Environment loaded successfully
  XAI_API_KEY: ********...MXdq


## Label Mapping Configuration

This section harmonizes the label schema between the paper's 16 PII categories and the ai4privacy dataset labels. All
synthetic data uses this unified mapping.


In [4]:
# Maps paper's categories → ai4privacy labels (primary label is first)
PAPER_TO_AI4PRIVACY: dict[str, list[str]] = {
    "NAME": ["FIRSTNAME", "LASTNAME", "MIDDLENAME"],
    "EMAIL": ["EMAIL"],
    "PHONE": ["PHONENUMBER"],
    "DATE_OF_BIRTH": ["DOB"],
    "POSTAL_CODE": ["ZIPCODE"],
    "CREDIT_CARD": ["CREDITCARDNUMBER"],
    "BANK_ACCOUNT": ["ACCOUNTNUMBER", "IBAN", "BIC"],
    "DRIVER_LICENSE": ["DRIVERLICENSE"],
    "PASSPORT_NUMBER": ["PASSPORT"],
    "NATIONAL_IDENTITY_SSN_AADHAR": ["SSN"],
    "OTHER_NATIONAL_IDENTITY": ["NATIONALID"],
    "TAX_IDENTIFICATION": ["TAXID"],
    "VEHICLE_REGISTRATION": ["VEHICLEVRM", "VEHICLEVIN"],
    "INSURANCE_NUMBER": ["INSURANCENUMBER"],
    "BANK_UPI_ID": ["UPIID"],
    "NAMES_OF_PLACES_OR_NOUNS": ["CITY", "STATE", "COUNTY", "STREET"],
}

# Inverse mapping for evaluation (ai4privacy → paper categories)
AI4PRIVACY_TO_PAPER: dict[str, str] = {
    v: k for k, vs in PAPER_TO_AI4PRIVACY.items() for v in vs
}

# All base labels for synthetic generation (we use paper categories)
ALL_PII_TYPES: list[str] = list(PAPER_TO_AI4PRIVACY.keys())

# Feature dimensions from Singh & Narayanan (2025)
FEATURE_DIMENSIONS: list[str] = [
    "basic",
    "contextual",
    "noisy",
    "evolving",
    "multilingual",
    "adversarial",
]

# Locales for international PII formats (all in English text context)
SUPPORTED_LOCALES: dict[str, str] = {
    "en_US": "United States",
    "en_GB": "United Kingdom",
    "en_IN": "India",
    "de_DE": "Germany",
    "fr_FR": "France",
    "en_AU": "Australia",
    "en_CA": "Canada",
    "it_IT": "Italy",
    "es_ES": "Spain",
    "nl_NL": "Netherlands",
}

print(f"✓ Configured {len(ALL_PII_TYPES)} PII types across {len(FEATURE_DIMENSIONS)} dimensions")
print(f"  PII Types: {ALL_PII_TYPES}")
print(f"  Dimensions: {FEATURE_DIMENSIONS}")
print(f"  Locales: {list(SUPPORTED_LOCALES.keys())}")

✓ Configured 16 PII types across 6 dimensions
  PII Types: ['NAME', 'EMAIL', 'PHONE', 'DATE_OF_BIRTH', 'POSTAL_CODE', 'CREDIT_CARD', 'BANK_ACCOUNT', 'DRIVER_LICENSE', 'PASSPORT_NUMBER', 'NATIONAL_IDENTITY_SSN_AADHAR', 'OTHER_NATIONAL_IDENTITY', 'TAX_IDENTIFICATION', 'VEHICLE_REGISTRATION', 'INSURANCE_NUMBER', 'BANK_UPI_ID', 'NAMES_OF_PLACES_OR_NOUNS']
  Dimensions: ['basic', 'contextual', 'noisy', 'evolving', 'multilingual', 'adversarial']
  Locales: ['en_US', 'en_GB', 'en_IN', 'de_DE', 'fr_FR', 'en_AU', 'en_CA', 'it_IT', 'es_ES', 'nl_NL']
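
As an illustration of how the inverse mapping can be used during evaluation, the sketch below normalizes ai4privacy-style predictions back to the paper's categories. It repeats a small subset of the mapping so that it runs on its own. The helper `to_paper_label` and its `"UNMAPPED"` fallback are hypothetical names, not part of the notebook.

```python
# Subset of the notebook's mapping, repeated here so the example is self-contained
PAPER_TO_AI4PRIVACY_SUBSET: dict[str, list[str]] = {
    "NAME": ["FIRSTNAME", "LASTNAME", "MIDDLENAME"],
    "BANK_ACCOUNT": ["ACCOUNTNUMBER", "IBAN", "BIC"],
    "POSTAL_CODE": ["ZIPCODE"],
}

# Invert: each ai4privacy label points back to exactly one paper category
INVERSE_SUBSET: dict[str, str] = {
    label: category
    for category, labels in PAPER_TO_AI4PRIVACY_SUBSET.items()
    for label in labels
}


def to_paper_label(ai4privacy_label: str, default: str = "UNMAPPED") -> str:
    """Hypothetical helper: map an ai4privacy label to a paper category."""
    return INVERSE_SUBSET.get(ai4privacy_label.upper(), default)


predicted = ["FIRSTNAME", "LASTNAME", "IBAN", "USERNAME"]
normalized = [to_paper_label(label) for label in predicted]
print(normalized)
```

Labels outside the mapping (here `USERNAME`) fall through to the default rather than raising, which keeps an evaluation loop running while still making unmapped labels visible.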


## Pydantic Schemas for Synthetic Output

Pydantic schemas defining the structure of synthetic PII samples. All generated data must conform to these schemas for
validation.


In [5]:
class FeatureDimension(str, Enum):
    """
    The six NER failure mode dimensions from Singh & Narayanan (2025).
    
    Each dimension represents a specific type of challenge for PII detection:
        - basic: Standard, well-formatted entities with clear boundaries
        - contextual: Ambiguous entities requiring surrounding context
        - noisy: Real-world text imperfections and formatting variations
        - evolving: Modern/emerging PII formats not in traditional training data
        - multilingual: International formats embedded in English prose
        - adversarial: Intentionally deceptive patterns designed to evade detection
    """
    BASIC = "basic"
    CONTEXTUAL = "contextual"
    NOISY = "noisy"
    EVOLVING = "evolving"
    MULTILINGUAL = "multilingual"
    ADVERSARIAL = "adversarial"


class EntitySpan(BaseModel):
    """
    A single PII entity annotation with character-level span positions.
    
    Attributes:
        start: Starting character index (0-based, inclusive)
        end: Ending character index (exclusive, like Python slicing)
        label: PII type label from the unified taxonomy
        text: The actual text content of the entity (for verification)
    """
    start: int = Field(..., ge=0, description="Start character index (inclusive)")
    end: int = Field(..., gt=0, description="End character index (exclusive)")
    label: str = Field(..., description="PII type label")
    text: str = Field(..., min_length=1, description="Entity text content")
    
    @model_validator(mode="after")
    def validate_span_bounds(self) -> "EntitySpan":
        """Ensure start < end for valid span."""
        if self.start >= self.end:
            raise ValueError(f"Invalid span: start ({self.start}) must be < end ({self.end})")
        return self


class SyntheticSample(BaseModel):
    """
    A complete synthetic PII training sample with text and annotations.
    
    This schema captures everything needed for training and validation:
    the generated text, all entity annotations, metadata about the
    generation process, and the feature dimension being targeted.
    
    Attributes:
        text: The generated English text containing PII entities
        entities: List of all PII entity annotations with spans
        feature_dimension: Which NER challenge dimension this targets
        seed_pii_type: The primary PII type used to seed generation
        seed_pii_value: The actual PII value that was seeded
        seed_pii_locale: Locale/region for international formats
        scenario: Brief description of the text scenario/context
        type_variant: Specific variant or sub-type of the PII
        generation_id: Unique identifier for this generation attempt
        timestamp: When this sample was generated
    """
    text: str = Field(..., min_length=50, max_length=600, description="Generated text")
    entities: list[EntitySpan] = Field(..., min_length=1, description="Entity annotations")
    feature_dimension: FeatureDimension = Field(..., description="Target dimension")
    seed_pii_type: str = Field(..., description="Primary PII type")
    seed_pii_value: str = Field(..., description="Seeded PII value")
    seed_pii_locale: str | None = Field(None, description="Locale for international formats")
    scenario: str = Field(..., description="Text scenario description")
    type_variant: str = Field(..., description="PII format variant")
    generation_id: str = Field(..., description="Unique generation ID")
    timestamp: str = Field(default_factory=lambda: datetime.utcnow().isoformat())
    
    @model_validator(mode="after")
    def validate_entities_in_text(self) -> "SyntheticSample":
        """Verify all entity spans are valid within the text."""
        text_len = len(self.text)
        for entity in self.entities:
            if entity.end > text_len:
                raise ValueError(
                    f"Entity span [{entity.start}:{entity.end}] exceeds "
                    f"text length {text_len}"
                )
            actual_text = self.text[entity.start:entity.end]
            if actual_text != entity.text:
                raise ValueError(
                    f"Entity text mismatch at [{entity.start}:{entity.end}]: "
                    f"expected '{entity.text}', found '{actual_text}'"
                )
        return self
    
    @model_validator(mode="after")
    def validate_seed_pii_present(self) -> "SyntheticSample":
        """Ensure the seed PII value appears in the text."""
        if self.seed_pii_value not in self.text:
            raise ValueError(
                f"Seed PII value '{self.seed_pii_value}' not found in generated text"
            )
        return self


class GenerationBatch(BaseModel):
    """
    A batch of synthetic samples with generation metadata.
    
    Attributes:
        samples: List of generated samples in this batch
        dimension: The feature dimension for all samples in batch
        batch_id: Unique batch identifier
        total_requested: How many samples were requested
        successful: How many were successfully generated
        failed: How many failed generation/validation
    """
    samples: list[SyntheticSample] = Field(default_factory=list)
    dimension: FeatureDimension
    batch_id: str
    total_requested: int
    successful: int = 0
    failed: int = 0


print("✓ Pydantic schemas defined")
print(f"  SyntheticSample fields: {list(SyntheticSample.model_fields.keys())}")

✓ Pydantic schemas defined
  SyntheticSample fields: ['text', 'entities', 'feature_dimension', 'seed_pii_type', 'seed_pii_value', 'seed_pii_locale', 'scenario', 'type_variant', 'generation_id', 'timestamp']


## Faker-Based PII Value Generators

Realistic PII value generation using Faker with locale support. These generators create seed PII values that will be
embedded in synthetic text.


In [6]:
class PIIGenerator:
    """
    Generates realistic PII values across multiple locales using Faker.
    
    This class provides methods to generate each of the 16 PII types with
    proper formatting for different countries/regions. All generated PII
    is synthetic and safe for training data.
    
    Attributes:
        fakers: Dictionary mapping locale codes to Faker instances
        default_locale: Fallback locale when requested locale unavailable
    """
    
    def __init__(self, locales: list[str] | None = None):
        """
        Initialize PII generators for specified locales.
        
        Args:
            locales: List of locale codes (e.g., ['en_US', 'en_GB', 'de_DE']).
                     Defaults to all supported locales if not specified.
        """
        self.locales = locales or list(SUPPORTED_LOCALES.keys())
        self.fakers: dict[str, Faker] = {}
        self.default_locale = "en_US"
        
        for locale in self.locales:
            try:
                self.fakers[locale] = Faker(locale)
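                # Seed each instance. The seed is drawn from the global `random` state, so runs are
                # reproducible only if `random.seed(...)` is called beforehand
                self.fakers[locale].seed_instance(random.randint(0, 10000))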
            except Exception as e:
                print(f"Warning: Could not initialize Faker for {locale}: {e}")
        
        if not self.fakers:
            raise RuntimeError("No Faker instances could be initialized")
    
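    def _get_faker(self, locale: str | None = None) -> tuple[Faker, str]:
        """Get Faker instance for locale, falling back to default."""
        if locale and locale in self.fakers:
            return self.fakers[locale], locale
        # Use the default locale, or any initialized locale if the default was not loaded
        fallback = self.default_locale if self.default_locale in self.fakers else next(iter(self.fakers))
        return self.fakers[fallback], fallback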
    
    def generate_name(
        self, 
        locale: str | None = None,
        name_type: str = "full",
    ) -> tuple[str, str, str]:
        """
        Generate a realistic person name.
        
        Args:
            locale: Target locale for name generation
            name_type: One of 'full', 'first', 'last', 'middle'
        
        Returns:
            Tuple of (name_value, actual_locale, label) where label is
            FIRSTNAME, LASTNAME, MIDDLENAME, or NAME for full names.
        """
        faker, actual_locale = self._get_faker(locale)
        
        if name_type == "first":
            return faker.first_name(), actual_locale, "FIRSTNAME"
        elif name_type == "last":
            return faker.last_name(), actual_locale, "LASTNAME"
        elif name_type == "middle":
            return faker.first_name(), actual_locale, "MIDDLENAME"
        else:
            return faker.name(), actual_locale, "NAME"
    
    def generate_email(self, locale: str | None = None) -> tuple[str, str, str]:
        """Generate a realistic email address."""
        faker, actual_locale = self._get_faker(locale)
        return faker.email(), actual_locale, "EMAIL"
    
    def generate_phone(self, locale: str | None = None) -> tuple[str, str, str]:
        """Generate a phone number in locale-appropriate format."""
        faker, actual_locale = self._get_faker(locale)
        return faker.phone_number(), actual_locale, "PHONE"
    
    def generate_dob(self, locale: str | None = None) -> tuple[str, str, str]:
        """Generate a date of birth with locale-appropriate formatting."""
        faker, actual_locale = self._get_faker(locale)
        dob = faker.date_of_birth(minimum_age=18, maximum_age=85)
        
        # Format varies by locale
        if actual_locale in ["en_US", "en_CA"]:
            formatted = dob.strftime("%m/%d/%Y")
        elif actual_locale in ["en_GB", "en_AU", "en_IN", "de_DE", "fr_FR", "it_IT", "es_ES", "nl_NL"]:
            formatted = dob.strftime("%d/%m/%Y")
        else:
            formatted = dob.strftime("%Y-%m-%d")
        
        return formatted, actual_locale, "DATE_OF_BIRTH"
    
    def generate_postal_code(self, locale: str | None = None) -> tuple[str, str, str]:
        """Generate a postal/ZIP code for the specified locale."""
        faker, actual_locale = self._get_faker(locale)
        return faker.postcode(), actual_locale, "POSTAL_CODE"
    
    def generate_credit_card(self, locale: str | None = None) -> tuple[str, str, str]:
        """Generate a credit card number (Luhn-valid but fake)."""
        faker, actual_locale = self._get_faker(locale)
        return faker.credit_card_number(), actual_locale, "CREDIT_CARD"
    
    def generate_bank_account(
        self, 
        locale: str | None = None,
        account_type: str = "iban",
    ) -> tuple[str, str, str]:
        """
        Generate a bank account identifier.
        
        Args:
            locale: Target locale
            account_type: 'iban' for European, 'account' for numeric, 'bic' for SWIFT
        """
        faker, actual_locale = self._get_faker(locale)
        
        if account_type == "iban" and hasattr(faker, "iban"):
            return faker.iban(), actual_locale, "BANK_ACCOUNT"
        elif account_type == "bic" and hasattr(faker, "swift"):
            return faker.swift(), actual_locale, "BANK_ACCOUNT"
        else:
            # Generate account-style number
            return faker.bban(), actual_locale, "BANK_ACCOUNT"
    
    def generate_driver_license(self, locale: str | None = None) -> tuple[str, str, str]:
        """Generate a driver's license number pattern for locale."""
        faker, actual_locale = self._get_faker(locale)
        
        # Different formats by country
        patterns = {
            "en_US": lambda: f"{faker.random_letter().upper()}{faker.random_number(digits=7, fix_len=True)}",
            "en_GB": lambda: f"{faker.random_uppercase_letter()}{faker.random_uppercase_letter()}{faker.random_number(digits=6, fix_len=True)}{faker.random_uppercase_letter()}{faker.random_uppercase_letter()}99{faker.random_uppercase_letter()}{faker.random_uppercase_letter()}",
            "en_IN": lambda: f"{faker.state_abbr() if hasattr(faker, 'state_abbr') else 'MH'}{faker.random_number(digits=13, fix_len=True)}",
            "de_DE": lambda: f"{faker.random_number(digits=11, fix_len=True)}",
        }
        
        generator = patterns.get(actual_locale, patterns["en_US"])
        return generator(), actual_locale, "DRIVER_LICENSE"
    
    def generate_passport(self, locale: str | None = None) -> tuple[str, str, str]:
        """Generate a passport number pattern."""
        faker, actual_locale = self._get_faker(locale)
        
        # Common patterns: digits only (US, UK) or a letter prefix followed by digits
        patterns = {
            "en_US": lambda: f"{faker.random_number(digits=9, fix_len=True)}",
            "en_GB": lambda: f"{faker.random_number(digits=9, fix_len=True)}",
            "en_IN": lambda: f"{faker.random_uppercase_letter()}{faker.random_number(digits=7, fix_len=True)}",
            "de_DE": lambda: f"C{faker.random_number(digits=8, fix_len=True)}",
        }
        
        generator = patterns.get(actual_locale, lambda: f"{faker.random_uppercase_letter()}{faker.random_number(digits=8, fix_len=True)}")
        return generator(), actual_locale, "PASSPORT_NUMBER"
    
    def generate_ssn(self, locale: str | None = None) -> tuple[str, str, str]:
        """Generate a national identity number (SSN, Aadhaar, NI, etc.)."""
        faker, actual_locale = self._get_faker(locale)
        
        patterns = {
            "en_US": lambda: f"{faker.random_number(digits=3, fix_len=True)}-{faker.random_number(digits=2, fix_len=True)}-{faker.random_number(digits=4, fix_len=True)}",
            "en_GB": lambda: f"{faker.random_uppercase_letter()}{faker.random_uppercase_letter()}{faker.random_number(digits=6, fix_len=True)}{faker.random_uppercase_letter()}",
            "en_IN": lambda: f"{faker.random_number(digits=4, fix_len=True)} {faker.random_number(digits=4, fix_len=True)} {faker.random_number(digits=4, fix_len=True)}",
            "de_DE": lambda: f"{faker.random_number(digits=11, fix_len=True)}",
        }
        
        generator = patterns.get(actual_locale, patterns["en_US"])
        return generator(), actual_locale, "NATIONAL_IDENTITY_SSN_AADHAR"
    
    def generate_other_national_id(self, locale: str | None = None) -> tuple[str, str, str]:
        """Generate other national identity formats (PAN, TFN, etc.)."""
        faker, actual_locale = self._get_faker(locale)
        
        patterns = {
            "en_IN": lambda: f"{faker.random_uppercase_letter()}{faker.random_uppercase_letter()}{faker.random_uppercase_letter()}{faker.random_uppercase_letter()}{faker.random_uppercase_letter()}{faker.random_number(digits=4, fix_len=True)}{faker.random_uppercase_letter()}",  # PAN
            "en_AU": lambda: f"{faker.random_number(digits=9, fix_len=True)}",  # TFN
        }
        
        generator = patterns.get(actual_locale, lambda: f"{faker.random_uppercase_letter()}{faker.random_number(digits=8, fix_len=True)}")
        return generator(), actual_locale, "OTHER_NATIONAL_IDENTITY"
    
    def generate_tax_id(self, locale: str | None = None) -> tuple[str, str, str]:
        """Generate a tax identification number."""
        faker, actual_locale = self._get_faker(locale)
        
        patterns = {
            "en_US": lambda: f"{faker.random_number(digits=2, fix_len=True)}-{faker.random_number(digits=7, fix_len=True)}",  # EIN
            "de_DE": lambda: f"DE{faker.random_number(digits=9, fix_len=True)}",  # VAT
            "en_GB": lambda: f"GB{faker.random_number(digits=9, fix_len=True)}",  # VAT
        }
        
        generator = patterns.get(actual_locale, patterns["en_US"])
        return generator(), actual_locale, "TAX_IDENTIFICATION"
    
    def generate_vehicle_registration(
        self, 
        locale: str | None = None,
        reg_type: str = "plate",
    ) -> tuple[str, str, str]:
        """Generate vehicle registration (plate number or VIN)."""
        faker, actual_locale = self._get_faker(locale)
        
        if reg_type == "vin":
            # VIN is international 17-character format
            chars = string.ascii_uppercase.replace("I", "").replace("O", "").replace("Q", "") + string.digits
            vin = "".join(random.choices(chars, k=17))
            return vin, actual_locale, "VEHICLE_REGISTRATION"
        
        # License plate patterns
        patterns = {
            "en_US": lambda: f"{faker.random_uppercase_letter()}{faker.random_uppercase_letter()}{faker.random_uppercase_letter()}-{faker.random_number(digits=4, fix_len=True)}",
            "en_GB": lambda: f"{faker.random_uppercase_letter()}{faker.random_uppercase_letter()}{faker.random_number(digits=2, fix_len=True)} {faker.random_uppercase_letter()}{faker.random_uppercase_letter()}{faker.random_uppercase_letter()}",
            "de_DE": lambda: f"{faker.city()[:2].upper()}-{faker.random_uppercase_letter()}{faker.random_uppercase_letter()} {faker.random_number(digits=4, fix_len=True)}",
            "en_IN": lambda: f"{faker.state_abbr() if hasattr(faker, 'state_abbr') else 'MH'}{faker.random_number(digits=2, fix_len=True)} {faker.random_uppercase_letter()}{faker.random_uppercase_letter()} {faker.random_number(digits=4, fix_len=True)}",
        }
        
        generator = patterns.get(actual_locale, patterns["en_US"])
        return generator(), actual_locale, "VEHICLE_REGISTRATION"
    
    def generate_insurance_number(self, locale: str | None = None) -> tuple[str, str, str]:
        """Generate an insurance policy/member number."""
        faker, actual_locale = self._get_faker(locale)
        
        prefixes = ["INS", "POL", "MBR", "HLT", "AUT"]
        prefix = random.choice(prefixes)
        number = faker.random_number(digits=10, fix_len=True)
        
        return f"{prefix}-{number}", actual_locale, "INSURANCE_NUMBER"
    
    def generate_upi_id(self, locale: str | None = None) -> tuple[str, str, str]:
        """Generate an Indian UPI ID (user@provider format)."""
        faker, _ = self._get_faker("en_IN")  # UPI is India-specific
        
        providers = ["okicici", "okhdfcbank", "oksbi", "ybl", "paytm", "gpay", "phonepe"]
        username = faker.user_name().lower()
        provider = random.choice(providers)
        
        return f"{username}@{provider}", "en_IN", "BANK_UPI_ID"
    
    def generate_place_name(
        self, 
        locale: str | None = None,
        place_type: str = "city",
    ) -> tuple[str, str, str]:
        """Generate a place name (city, state, street, etc.)."""
        faker, actual_locale = self._get_faker(locale)
        
        if place_type == "city":
            return faker.city(), actual_locale, "NAMES_OF_PLACES_OR_NOUNS"
        elif place_type == "state":
            if hasattr(faker, "state"):
                return faker.state(), actual_locale, "NAMES_OF_PLACES_OR_NOUNS"
            return faker.city(), actual_locale, "NAMES_OF_PLACES_OR_NOUNS"
        elif place_type == "street":
            return faker.street_name(), actual_locale, "NAMES_OF_PLACES_OR_NOUNS"
        else:
            return faker.city(), actual_locale, "NAMES_OF_PLACES_OR_NOUNS"
    
    def generate_pii(
        self, 
        pii_type: str, 
        locale: str | None = None,
        **kwargs: Any,
    ) -> tuple[str, str, str]:
        """
        Generate PII of the specified type.
        
        Args:
            pii_type: One of the 16 PII types from ALL_PII_TYPES
            locale: Target locale for generation
            **kwargs: Additional arguments for specific generators
        
        Returns:
            Tuple of (pii_value, actual_locale, label)
        """
        generators = {
            "NAME": self.generate_name,
            "EMAIL": self.generate_email,
            "PHONE": self.generate_phone,
            "DATE_OF_BIRTH": self.generate_dob,
            "POSTAL_CODE": self.generate_postal_code,
            "CREDIT_CARD": self.generate_credit_card,
            "BANK_ACCOUNT": self.generate_bank_account,
            "DRIVER_LICENSE": self.generate_driver_license,
            "PASSPORT_NUMBER": self.generate_passport,
            "NATIONAL_IDENTITY_SSN_AADHAR": self.generate_ssn,
            "OTHER_NATIONAL_IDENTITY": self.generate_other_national_id,
            "TAX_IDENTIFICATION": self.generate_tax_id,
            "VEHICLE_REGISTRATION": self.generate_vehicle_registration,
            "INSURANCE_NUMBER": self.generate_insurance_number,
            "BANK_UPI_ID": self.generate_upi_id,
            "NAMES_OF_PLACES_OR_NOUNS": self.generate_place_name,
        }
        
        if pii_type not in generators:
            raise ValueError(f"Unknown PII type: {pii_type}. Valid types: {list(generators.keys())}")
        
        return generators[pii_type](locale=locale, **kwargs)


# Initialize global PII generator
pii_gen = PIIGenerator()

# Test generation
print("✓ PIIGenerator initialized")
print("\nSample generated PII values:")
for pii_type in random.sample(ALL_PII_TYPES, 5):
    value, locale, label = pii_gen.generate_pii(pii_type, locale="en_US")
    print(f"  {pii_type}: {value} ({locale})")

✓ PIIGenerator initialized

Sample generated PII values:
  NATIONAL_IDENTITY_SSN_AADHAR: 397-60-1868 (en_US)
  DRIVER_LICENSE: C7637049 (en_US)
  NAME: Eduardo Barnes (en_US)
  POSTAL_CODE: 77492 (en_US)
  EMAIL: harriscarolyn@example.net (en_US)
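
The `generate_credit_card` docstring describes the generated numbers as Luhn-valid. One way to spot-check that property without third-party dependencies is a small Luhn checksum function. The sketch below is an illustration only: `luhn_valid` is a hypothetical helper and the inputs are a widely used test card number and an altered copy, not output from the notebook.

```python
def luhn_valid(number: str) -> bool:
    """Return True if the digits in `number` pass the Luhn checksum."""
    digits = [int(ch) for ch in number if ch.isdigit()]
    if not digits:
        return False
    total = 0
    # Walk from the rightmost digit, doubling every second one
    for index, digit in enumerate(reversed(digits)):
        if index % 2 == 1:
            digit *= 2
            if digit > 9:
                digit -= 9
        total += digit
    return total % 10 == 0


print(luhn_valid("4111 1111 1111 1111"))  # widely used test card number
print(luhn_valid("4111 1111 1111 1112"))  # last digit altered
```

Non-digit characters such as spaces or hyphens are ignored, so the same function can be applied to formatted card numbers embedded in generated text.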


## Dimension-Specific Prompt Templates

Prompt templates for each feature dimension, designed to generate challenging PII examples that specifically target NER
model failure modes.


In [7]:
# Base system prompt for all dimensions
SYSTEM_PROMPT_BASE: str = """You are a synthetic data generator creating training examples for PII (Personally Identifiable Information) detection models.

Your task is to generate realistic English text containing the provided PII value, properly annotated with character-level spans.

CRITICAL REQUIREMENTS:
1. The text MUST be natural, coherent English
2. The provided PII value MUST appear EXACTLY as given (no modifications)
3. You MUST include additional contextually-relevant PII entities beyond the seed
4. All entity spans MUST be accurate character positions (0-indexed, exclusive end)
5. The scenario MUST feel realistic and plausible

OUTPUT FORMAT (strict JSON):
{
    "text": "The generated text containing PII...",
    "entities": [
        {"start": 0, "end": 10, "label": "LABEL", "text": "exact text"},
        ...
    ],
    "scenario": "Brief description of the scenario"
}

IMPORTANT: 
- The "text" field in each entity MUST exactly match text[start:end]
- Include 2-5 PII entities total (including the seed)
- Text length should be 100-500 characters
- Do NOT use markdown formatting in the text"""


DIMENSION_PROMPTS: dict[str, str] = {
    "basic": """DIMENSION: BASIC
Generate straightforward text where the PII is clearly formatted and easily identifiable.

CHARACTERISTICS:
- PII appears in standard, expected formats
- Clear contextual cues (e.g., "Email:", "Phone:", "SSN:")
- Well-structured sentences
- No ambiguity about entity boundaries

EXAMPLE SCENARIOS:
- Contact information in a directory entry
- Form data confirmation message
- Official document excerpt
- Registration confirmation

The goal is clean, well-formatted examples that establish baseline performance.""",

    "contextual": """DIMENSION: CONTEXTUAL
Generate text where PII requires context to disambiguate from similar-looking non-PII.

CHARACTERISTICS:
- Potential false positives nearby (e.g., product codes that look like IDs)
- Ambiguous strings that could be PII or not depending on context
- Names that could be company names, place names, or person names
- Numbers that could be IDs, prices, or dates

EXAMPLE SCENARIOS:
- Email discussing both a person named "Amazon" and the company Amazon
- Text containing both a date "March 15" and a person named "March"
- Discussion mixing product serial numbers with actual SSNs
- Street addresses where street names are also person names

The goal is examples requiring semantic understanding, not pattern matching.""",

    "noisy": """DIMENSION: NOISY
Generate text with real-world imperfections that challenge NER systems.

CHARACTERISTICS:
- Typos and misspellings in surrounding text (NOT in the PII itself)
- OCR-style errors (l/1, O/0 confusion in context)
- Inconsistent formatting and spacing
- Abbreviations and informal language
- Missing punctuation or extra whitespace
- SMS/chat-style shortened text

EXAMPLE SCENARIOS:
- Scanned document with OCR artifacts
- Hastily typed customer service chat
- Social media post with typos
- Informal email with abbreviations

The PII values themselves should remain accurate - the noise is in the surrounding text.""",

    "evolving": """DIMENSION: EVOLVING
Generate text containing modern/emerging PII formats not in traditional training data.

CHARACTERISTICS:
- Cryptocurrency wallet addresses (Bitcoin, Ethereum, etc.)
- UPI IDs (username@provider format)
- Modern usernames/handles (@mentions, Discord tags)
- Digital payment identifiers
- Cloud service identifiers
- API keys or tokens (realistic-looking fakes)
- Modern two-factor authentication codes

EXAMPLE SCENARIOS:
- Cryptocurrency transaction discussion
- Digital payment confirmation
- Tech support for modern apps
- Social media account setup
- Fintech application onboarding

The goal is PII types that have emerged in the last 5-10 years.""",

    "multilingual": """DIMENSION: MULTILINGUAL
Generate English text containing PII in international formats from various countries.

CHARACTERISTICS:
- International phone number formats (+44, +91, +49, etc.)
- Non-US ID formats (IBAN, UK NI numbers, Indian Aadhaar, German Personalausweis)
- International postal codes (UK postcodes, German PLZ, Indian PIN codes)
- Date formats from different regions (DD/MM/YYYY vs MM/DD/YYYY)
- International vehicle registration formats

EXAMPLE SCENARIOS:
- International business correspondence
- Immigration/visa documentation
- International banking transaction
- Multinational company HR records
- Travel booking confirmation

Text MUST be in English, but PII formats should be from non-US locales.""",

    "adversarial": """DIMENSION: ADVERSARIAL
Generate text with patterns designed to confuse or evade NER systems.

CHARACTERISTICS:
- Unusual spacing or formatting within PII
- PII split across sentence boundaries
- Obfuscated but recognizable PII (spaces in SSN: "123 45 6789")
- PII embedded in code snippets or technical text
- Edge cases with unusual but valid formats
- Deliberately misleading context

EXAMPLE SCENARIOS:
- PII hidden in debug logs or error messages
- Social engineering attempts with formatted PII
- Technical documentation with embedded real values
- PII in URLs, file paths, or JSON structures
- Creatively formatted attempts to evade filters

The goal is testing model robustness against evasion attempts.""",
}


def get_generation_prompt(
    dimension: str,
    pii_type: str,
    pii_value: str,
    locale: str,
    type_variant: str = "standard",
) -> tuple[str, str]:
    """
    Construct the complete prompt for synthetic data generation.
    
    Args:
        dimension: Target feature dimension
        pii_type: Type of PII being seeded
        pii_value: The actual PII value to embed
        locale: Locale/region for the PII
        type_variant: Specific variant description
    
    Returns:
        Tuple of (system_prompt, user_prompt)
    """
    if dimension not in DIMENSION_PROMPTS:
        raise ValueError(f"Unknown dimension: {dimension}")
    
    system_prompt = f"{SYSTEM_PROMPT_BASE}\n\n{DIMENSION_PROMPTS[dimension]}"
    
    user_prompt = f"""Generate a {dimension.upper()} dimension training example.

SEED PII:
- Type: {pii_type}
- Value: {pii_value}
- Locale: {locale}
- Variant: {type_variant}

The provided PII value MUST appear exactly as shown in your generated text.
Include 2-4 additional relevant PII entities.
Ensure all entity spans are accurate character positions.

Generate the JSON output now:"""
    
    return system_prompt, user_prompt


print("✓ Prompt templates configured for all 6 dimensions")
for dim in FEATURE_DIMENSIONS:
    print(f"  {dim}: {len(DIMENSION_PROMPTS[dim])} chars")

✓ Prompt templates configured for all 6 dimensions
  basic: 527 chars
  contextual: 737 chars
  noisy: 629 chars
  evolving: 669 chars
  multilingual: 719 chars
  adversarial: 710 chars


## xAI API Client Wrapper

This section wraps the official xai_sdk.AsyncClient in a thin GrokClient class.

This implementation follows the xAI async documentation: https://docs.x.ai/docs/guides/async

The SDK is gRPC-based and handles:

-   Connection pooling and management
-   Authentication via API key
-   Automatic retries for transient errors
-   Proper timeout handling

Rate Limit Strategy:

-   xAI allows 480 requests/minute (8 req/sec average)
-   We use asyncio.Semaphore to limit concurrent in-flight requests
-   Batch processing with semaphore ensures we stay under limits
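
The semaphore idea above can be illustrated without touching the API. The following is a minimal sketch, not the client defined below: `fake_request` and `run_all` are hypothetical names, and the network call is simulated with `asyncio.sleep`. It shows that `asyncio.gather` may be handed any number of coroutines while the semaphore caps how many run at once.

```python
import asyncio


async def fake_request(i: int, sem: asyncio.Semaphore, state: dict) -> int:
    """Stand-in for an API call; records how many calls overlap."""
    async with sem:
        state["in_flight"] += 1
        state["peak"] = max(state["peak"], state["in_flight"])
        await asyncio.sleep(0.01)  # simulated network latency
        state["in_flight"] -= 1
        return i


async def run_all(n_requests: int, max_concurrent: int) -> tuple[list[int], int]:
    """Run n_requests fake calls with at most max_concurrent in flight."""
    sem = asyncio.Semaphore(max_concurrent)
    state = {"in_flight": 0, "peak": 0}
    # gather() preserves input order in its result list
    results = await asyncio.gather(
        *(fake_request(i, sem, state) for i in range(n_requests))
    )
    return list(results), state["peak"]


results, peak = asyncio.run(run_all(n_requests=50, max_concurrent=8))
print(f"completed={len(results)} peak_in_flight={peak}")
```

Inside a notebook cell, where an event loop is already running, `await run_all(...)` would be used in place of `asyncio.run(...)`. Note that a semaphore bounds concurrency, not requests per minute, which is why the client also enforces a minimum interval between batches.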


In [8]:
import asyncio
import json
import os
import time  # used by generate_batch() for batch-interval timing
from typing import Any

import json_repair
from xai_sdk import AsyncClient
from xai_sdk.chat import system, user, Response


class GrokClient:
    """
    Async client for xAI's Grok API using the official SDK.
    
    This client wraps xai_sdk.AsyncClient and provides:
        - Semaphore-controlled concurrency for rate limit compliance
        - Batch processing for high-throughput generation
        - JSON response parsing with repair for malformed LLM output
        - Consistent error handling across all requests
    
    The xAI SDK is gRPC-based, which provides better performance and
    reliability compared to raw HTTP requests.
    
    Attributes:
        client: The underlying xai_sdk.AsyncClient instance
        model: Model identifier (e.g., 'grok-4-1-fast-non-reasoning')
        max_concurrent: Maximum simultaneous in-flight requests
        min_batch_interval: Minimum seconds between batch starts (for rate limiting)
    """
    
    # Rate limit: 480 requests/minute = 8 req/sec
    # With 20 concurrent requests, we need ~2.5 seconds per batch minimum
    # Using 3.0 seconds for safety margin
    DEFAULT_BATCH_SIZE: int = 20
    DEFAULT_MIN_BATCH_INTERVAL: float = 3.0
    
    def __init__(
        self,
        api_key: str | None = None,
        model: str = "grok-4-1-fast-non-reasoning",
        max_concurrent: int = DEFAULT_BATCH_SIZE,
        min_batch_interval: float = DEFAULT_MIN_BATCH_INTERVAL,
        timeout: int = 900,
    ):
        """
        Initialize the Grok API client with the official xAI SDK.
        
        Args:
            api_key: xAI API key. If None, uses XAI_API_KEY environment variable.
            model: Model to use for generation (default: grok-4-1-fast-non-reasoning).
            max_concurrent: Maximum concurrent in-flight requests (default: 20).
            min_batch_interval: Minimum seconds between batch starts (default: 3.0).
            timeout: Request timeout in seconds (default: 900, the SDK default).
        """
        self.model = model
        self.max_concurrent = max_concurrent
        self.min_batch_interval = min_batch_interval
        
        # Initialize the official xAI AsyncClient
        # The SDK reads XAI_API_KEY from environment if api_key is None
        self.client = AsyncClient(
            api_key=api_key or os.getenv("XAI_API_KEY"),
            timeout=timeout,
        )
        
        # Semaphore controls maximum concurrent in-flight requests
        self._semaphore = asyncio.Semaphore(max_concurrent)
        
        # Track timing for rate limit compliance
        self._last_batch_start: float = 0.0
    
    async def close(self) -> None:
        """
        Close the client and release resources.
        
        The xAI SDK AsyncClient manages its own connection lifecycle,
        but we provide this method for explicit cleanup if needed.
        """
        # The AsyncClient doesn't have an explicit close method,
        # but we reset our state
        self._last_batch_start = 0.0
    
    async def __aenter__(self) -> "GrokClient":
        """Async context manager entry."""
        return self
    
    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        """Async context manager exit."""
        await self.close()
    
    async def _generate_single(
        self,
        system_prompt: str,
        user_prompt: str,
        temperature: float = 0.7,
        max_tokens: int = 1500,
    ) -> dict[str, Any] | None:
        """
        Generate a single response with semaphore-controlled concurrency.
        
        This method acquires the semaphore before making the request,
        ensuring we never exceed max_concurrent simultaneous requests.
        Uses the official xAI SDK chat interface.
        
        Args:
            system_prompt: System message with generation instructions.
            user_prompt: User message with specific request.
            temperature: Sampling temperature (0.0-1.0).
            max_tokens: Maximum tokens in the response.
        
        Returns:
            Parsed JSON response dict, or None if generation/parsing fails.
        """
        async with self._semaphore:
            try:
                # Create a new chat instance with system message
                chat = self.client.chat.create(
                    model=self.model,
                    messages=[system(system_prompt)],
                    temperature=temperature,
                    max_tokens=max_tokens,
                )
                
                # Append the user message
                chat.append(user(user_prompt))
                
                # Sample a response (this is the actual API call)
                response: Response = await chat.sample()
                
                # Extract the content from the response
                raw_content: str = response.content
                
                # Parse JSON from the response content
                return self._parse_json_response(raw_content)
                
            except Exception as e:
                # Log error but don't crash - return None to indicate failure
                print(f"Generation error: {type(e).__name__}: {e}")
                return None
    
    def _parse_json_response(self, raw_content: str) -> dict[str, Any] | None:
        """
        Parse JSON from LLM response content.
        
        LLMs often wrap JSON in markdown code blocks or produce slightly
        malformed JSON. This method handles common cases and uses json_repair
        as a fallback for malformed output.
        
        Args:
            raw_content: Raw string content from the LLM response.
        
        Returns:
            Parsed JSON dict, or None if parsing fails completely.
        """
        if not raw_content:
            return None
        
        json_str = raw_content.strip()
        
        # Remove markdown code blocks if present
        if "```json" in json_str:
            # Extract content between ```json and ```
            parts = json_str.split("```json")
            if len(parts) > 1:
                json_str = parts[1].split("```")[0]
        elif "```" in json_str:
            # Generic code block
            parts = json_str.split("```")
            if len(parts) >= 2:
                json_str = parts[1]
        
        json_str = json_str.strip()
        
        # Attempt standard JSON parsing first
        try:
            return json.loads(json_str)
        except json.JSONDecodeError:
            pass
        
        # Fall back to json_repair for malformed JSON
        try:
            return json_repair.loads(json_str)
        except Exception:
            return None
    
    async def generate(
        self,
        system_prompt: str,
        user_prompt: str,
        temperature: float = 0.7,
        max_tokens: int = 1500,
    ) -> dict[str, Any] | None:
        """
        Public interface for single generation.
        
        For high-throughput scenarios, use generate_batch() instead.
        
        Args:
            system_prompt: System message with generation instructions.
            user_prompt: User message with specific request.
            temperature: Sampling temperature.
            max_tokens: Maximum tokens in response.
        
        Returns:
            Parsed JSON response dict, or None on failure.
        """
        return await self._generate_single(
            system_prompt=system_prompt,
            user_prompt=user_prompt,
            temperature=temperature,
            max_tokens=max_tokens,
        )
    
    async def generate_batch(
        self,
        requests: list[tuple[str, str, float]],
        max_tokens: int = 1500,
    ) -> list[dict[str, Any] | None]:
        """
        Generate multiple responses concurrently with rate limit compliance.
        
        This method implements the pattern from the xAI async documentation:
        1. Create tasks for all requests
        2. Use semaphore to limit concurrent in-flight requests
        3. Use asyncio.gather() to execute all tasks
        4. Enforce minimum batch interval for rate limiting
        
        The semaphore ensures that even if you pass 100 requests, only
        max_concurrent will be in-flight at any given moment.
        
        Args:
            requests: List of (system_prompt, user_prompt, temperature) tuples.
            max_tokens: Maximum tokens per response.
        
        Returns:
            List of results in the same order as input requests.
            Each result is either a parsed JSON dict or None on failure.
        """
        # Enforce minimum time since last batch started
        now = time.time()
        elapsed_since_last_batch = now - self._last_batch_start
        if elapsed_since_last_batch < self.min_batch_interval:
            wait_time = self.min_batch_interval - elapsed_since_last_batch
            await asyncio.sleep(wait_time)
        
        # Record batch start time
        self._last_batch_start = time.time()
        
        # Create async task for each request
        # The semaphore inside _generate_single controls actual concurrency
        async def process_request(
            system_prompt: str,
            user_prompt: str,
            temperature: float,
        ) -> dict[str, Any] | None:
            return await self._generate_single(
                system_prompt=system_prompt,
                user_prompt=user_prompt,
                temperature=temperature,
                max_tokens=max_tokens,
            )
        
        # Build task list
        tasks = [
            process_request(system_prompt, user_prompt, temperature)
            for system_prompt, user_prompt, temperature in requests
        ]
        
        # Execute all tasks concurrently
        # The semaphore limits how many are actually in-flight simultaneously
        results = await asyncio.gather(*tasks, return_exceptions=True)
        
        # Convert exceptions to None for consistent return type
        processed_results: list[dict[str, Any] | None] = []
        for result in results:
            if isinstance(result, Exception):
                print(f"Task exception: {type(result).__name__}: {result}")
                processed_results.append(None)
            else:
                processed_results.append(result)
        
        return processed_results


# Initialize client using the official SDK
grok_client = GrokClient(
    api_key=XAI_API_KEY,
    model="grok-4-1-fast-non-reasoning",
    max_concurrent=40,          # Max 40 in-flight requests at once
    min_batch_interval=1.0,     # Wait at least 1 second between batches
    timeout=900,                # 15 minute timeout (SDK default)
)

print("✓ Grok API client initialized using official xai_sdk.AsyncClient")
print(f"  Model: {grok_client.model}")
print(f"  Max concurrent requests: {grok_client.max_concurrent}")
print(f"  Min batch interval: {grok_client.min_batch_interval}s")
print(f"  Theoretical max throughput: {60 / grok_client.min_batch_interval * grok_client.max_concurrent:.0f} req/min")

✓ Grok API client initialized using official xai_sdk.AsyncClient
  Model: grok-4-1-fast-non-reasoning
  Max concurrent requests: 40
  Min batch interval: 1.0s
  Theoretical max throughput: 2400 req/min
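
The printed ceiling of 2400 req/min is well above the 480 req/min limit quoted earlier in this notebook, so it is worth checking the arithmetic. The helpers below are a small sketch with hypothetical names (`max_requests_per_minute`, `min_interval_for_limit`); they assume every batch is full and starts exactly one interval after the previous one, which ignores request latency.

```python
def max_requests_per_minute(batch_size: int, batch_interval_s: float) -> float:
    """Upper bound on requests/minute if a full batch starts every interval."""
    return 60.0 / batch_interval_s * batch_size


def min_interval_for_limit(batch_size: int, limit_per_minute: int) -> float:
    """Smallest batch interval (seconds) that keeps full batches under the limit."""
    return batch_size * 60.0 / limit_per_minute


LIMIT = 480  # requests/minute, the figure quoted in this notebook

print(max_requests_per_minute(40, 1.0))   # 2400.0 (the values configured above)
print(min_interval_for_limit(20, LIMIT))  # 2.5    (matches the class-level comment)
print(min_interval_for_limit(40, LIMIT))  # 5.0    (interval needed for batches of 40)
```

Under these assumptions, batches of 40 would need roughly a 5 second interval to stay within 480 req/min. In practice the observed rate may be lower than the ceiling, because each batch also waits for its slowest request to finish.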


## Sample Generation and Validation Functions

Core generation logic: creates individual samples, validates them, and handles retry logic for failed generations.


In [9]:
import uuid


def find_entity_span(text: str, entity_text: str, label: str) -> EntitySpan | None:
    """
    Find the character span of an entity in text.
    
    Handles cases where the LLM might provide incorrect spans by
    searching for the actual text position.
    
    Args:
        text: The full text to search in
        entity_text: The entity string to find
        label: The entity label
    
    Returns:
        EntitySpan if found, None otherwise
    """
    start_idx = text.find(entity_text)
    if start_idx == -1:
        return None
    
    return EntitySpan(
        start=start_idx,
        end=start_idx + len(entity_text),
        label=label,
        text=entity_text,
    )


def repair_entity_spans(
    text: str,
    raw_entities: list[dict[str, Any]],
) -> list[EntitySpan]:
    """
    Repair and validate entity spans from LLM output.
    
    LLMs frequently produce incorrect character positions. This function:
    1. Validates each span against the actual text
    2. Attempts to find correct positions for misaligned entities
    3. Filters out entities that cannot be located
    
    Args:
        text: The generated text
        raw_entities: List of entity dicts from LLM response
    
    Returns:
        List of validated EntitySpan objects
    """
    repaired: list[EntitySpan] = []
    seen_spans: set[tuple[int, int]] = set()  # Avoid duplicates
    
    for entity_dict in raw_entities:
        try:
            start = entity_dict.get("start", 0)
            end = entity_dict.get("end", 0)
            label = entity_dict.get("label", "UNKNOWN")
            entity_text = entity_dict.get("text", "")
            
            # First, check if provided span is correct
            if 0 <= start < end <= len(text):
                actual_text = text[start:end]
                if actual_text == entity_text:
                    # Span is correct
                    span_key = (start, end)
                    if span_key not in seen_spans:
                        repaired.append(EntitySpan(
                            start=start,
                            end=end,
                            label=label,
                            text=entity_text,
                        ))
                        seen_spans.add(span_key)
                    continue
            
            # Span is incorrect, try to find the text
            if entity_text:
                found_span = find_entity_span(text, entity_text, label)
                if found_span:
                    span_key = (found_span.start, found_span.end)
                    if span_key not in seen_spans:
                        repaired.append(found_span)
                        seen_spans.add(span_key)
        
        except Exception:
            # Skip malformed entities
            continue
    
    return repaired


async def generate_single_sample(
    client: GrokClient,
    dimension: str,
    pii_type: str,
    locale: str,
    generation_id: str,
    max_attempts: int = 3,
) -> SyntheticSample | None:
    """
    Generate a single synthetic sample with retry logic.
    
    Args:
        client: Initialized GrokClient
        dimension: Target feature dimension
        pii_type: Type of PII to generate
        locale: Locale for PII formatting
        generation_id: Unique ID for this generation
        max_attempts: Max generation attempts before giving up
    
    Returns:
        Validated SyntheticSample, or None if generation fails
    """
    # Generate seed PII value
    pii_value, actual_locale, label = pii_gen.generate_pii(pii_type, locale=locale)
    
    # Determine type variant based on PII type
    type_variants = {
        "NAME": ["full name", "first name only", "formal with title"],
        "EMAIL": ["personal", "work", "academic"],
        "PHONE": ["mobile", "landline", "with extension"],
        "CREDIT_CARD": ["Visa", "Mastercard", "Amex"],
        "BANK_ACCOUNT": ["IBAN", "domestic", "SWIFT/BIC"],
    }
    type_variant = random.choice(type_variants.get(pii_type, ["standard"]))
    
    # Get prompts
    system_prompt, user_prompt = get_generation_prompt(
        dimension=dimension,
        pii_type=pii_type,
        pii_value=pii_value,
        locale=actual_locale,
        type_variant=type_variant,
    )
    
    for attempt in range(max_attempts):
        try:
            # Generate from LLM
            response = await client.generate(
                system_prompt=system_prompt,
                user_prompt=user_prompt,
                temperature=0.7 + (attempt * 0.1),  # Increase temp on retries
            )
            
            if response is None:
                continue
            
            # Extract fields
            text = response.get("text", "")
            raw_entities = response.get("entities", [])
            scenario = response.get("scenario", "Unspecified scenario")
            
            if not text or not raw_entities:
                continue
            
            # Verify seed PII is in text
            if pii_value not in text:
                # Try to find a close match (case-insensitive)
                if pii_value.lower() not in text.lower():
                    continue
            
            # Repair entity spans
            entities = repair_entity_spans(text, raw_entities)
            
            if not entities:
                continue
            
            # Create and validate sample
            sample = SyntheticSample(
                text=text,
                entities=entities,
                feature_dimension=FeatureDimension(dimension),
                seed_pii_type=pii_type,
                seed_pii_value=pii_value,
                seed_pii_locale=actual_locale,
                scenario=scenario,
                type_variant=type_variant,
                generation_id=generation_id,
            )
            
            return sample
            
        except Exception as e:
            if attempt < max_attempts - 1:
                await asyncio.sleep(1)  # Brief pause before retry
            continue
    
    return None


print("✓ Generation functions defined")

✓ Generation functions defined


## Batch Generation with Checkpointing

Batch generation orchestration with concurrent processing, progress tracking, checkpointing, and balanced sampling
across dimensions and PII types.

Samples are processed in batches of 20 concurrent requests (the default batch_size), which gives much higher throughput
than sending requests one at a time.
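
The batching-plus-checkpointing pattern can be sketched in isolation. The code below is a simplified illustration under stated assumptions, not the generator class that follows: `chunked` and `run_with_checkpoints` are hypothetical helpers, the LLM call is replaced by a placeholder transformation, and a checkpoint is written each time another `every` items have completed.

```python
import json
import tempfile
from pathlib import Path


def chunked(items: list, size: int):
    """Yield consecutive slices of at most `size` items."""
    for i in range(0, len(items), size):
        yield items[i:i + size]


def run_with_checkpoints(
    plan: list[str], batch_size: int, every: int, out_dir: Path
) -> list[Path]:
    """Process plan in batches, checkpointing whenever `every` more items are done."""
    done: list[str] = []
    written: list[Path] = []
    next_checkpoint = every
    for batch in chunked(plan, batch_size):
        done.extend(item.upper() for item in batch)  # placeholder for real LLM calls
        if len(done) >= next_checkpoint:
            path = out_dir / f"checkpoint_{len(written) + 1:04d}.json"
            path.write_text(json.dumps({"samples": done}), encoding="utf-8")
            written.append(path)
            next_checkpoint += every
    return written


with tempfile.TemporaryDirectory() as tmp:
    plan = [f"task{i}" for i in range(50)]
    paths = run_with_checkpoints(plan, batch_size=20, every=20, out_dir=Path(tmp))
    counts = [len(json.loads(p.read_text(encoding="utf-8"))["samples"]) for p in paths]

print(counts)  # [20, 40]
```

In this sketch the final partial batch (items 41 to 50) is processed but never checkpointed, so a real run would also want one last save after the loop ends.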


In [10]:
@dataclass
class GenerationConfig:
    """
    Configuration for synthetic data generation run.
    
    Attributes:
        total_samples: Total number of samples to generate
        samples_per_dimension: Samples per feature dimension (auto-calculated if 0)
        batch_size: Number of concurrent requests per batch
        samples_per_checkpoint: How often to save checkpoints
        output_dir: Directory for output files and checkpoints
        checkpoint_prefix: Prefix for checkpoint filenames
    """
    total_samples: int = 11000
    samples_per_dimension: int = 0  # 0 = auto-calculate
    batch_size: int = 20  # Concurrent requests per batch
    samples_per_checkpoint: int = 100
    output_dir: str = "./data/synthetic"
    checkpoint_prefix: str = "synthetic_checkpoint"
    
    def __post_init__(self):
        if self.samples_per_dimension == 0:
            self.samples_per_dimension = self.total_samples // len(FEATURE_DIMENSIONS)
        
        # Ensure output directory exists
        Path(self.output_dir).mkdir(parents=True, exist_ok=True)


@dataclass 
class GenerationTask:
    """
    A single generation task with all parameters needed.
    
    Attributes:
        dimension: Target feature dimension
        pii_type: Type of PII to generate
        locale: Locale for PII formatting
        generation_id: Unique identifier for this task
        pii_value: Pre-generated seed PII value
        actual_locale: Actual locale used (may differ if fallback)
        label: PII label for the seed value
        type_variant: Specific variant description
        system_prompt: Complete system prompt
        user_prompt: Complete user prompt
    """
    dimension: str
    pii_type: str
    locale: str
    generation_id: str
    pii_value: str
    actual_locale: str
    label: str
    type_variant: str
    system_prompt: str
    user_prompt: str


class SyntheticDataGenerator:
    """
    Orchestrates batch generation of synthetic PII data with concurrent processing.
    
    This generator processes samples in batches, firing multiple concurrent
    requests to maximize throughput while respecting API rate limits.
    
    Features:
        - Concurrent batch processing (default: 20 requests at a time)
        - Balanced sampling across dimensions, PII types, and locales
        - Periodic checkpointing to prevent data loss
        - Progress tracking with accurate ETA estimation
        - Automatic retry for failed generations within batches
    
    Attributes:
        client: GrokClient for LLM generation
        config: GenerationConfig with settings
        generated_samples: List of all successfully generated samples
        stats: Dictionary tracking generation statistics
    """
    
    def __init__(
        self,
        client: GrokClient,
        config: GenerationConfig | None = None,
    ):
        """
        Initialize the synthetic data generator.
        
        Args:
            client: Initialized GrokClient with concurrency settings
            config: Generation configuration (uses defaults if None)
        """
        self.client = client
        self.config = config or GenerationConfig()
        self.generated_samples: list[SyntheticSample] = []
        self.failed_tasks: list[GenerationTask] = []  # Track failures for potential retry
        self.stats: dict[str, Any] = {
            "total_attempts": 0,
            "successful": 0,
            "failed": 0,
            "by_dimension": defaultdict(int),
            "by_pii_type": defaultdict(int),
            "by_locale": defaultdict(int),
            "batches_processed": 0,
            "start_time": None,
        }
    
    def _create_generation_task(
        self,
        dimension: str,
        pii_type: str,
        locale: str,
    ) -> GenerationTask:
        """
        Create a complete generation task with pre-generated PII and prompts.
        
        This front-loads all the work that doesn't require API calls,
        so batch processing only involves the actual LLM requests.
        
        Args:
            dimension: Target feature dimension
            pii_type: Type of PII to generate
            locale: Target locale for PII formatting
        
        Returns:
            GenerationTask with all fields populated
        """
        generation_id = f"{dimension}_{pii_type}_{uuid.uuid4().hex[:8]}"
        
        # Generate seed PII value
        pii_value, actual_locale, label = pii_gen.generate_pii(pii_type, locale=locale)
        
        # Determine type variant
        type_variants = {
            "NAME": ["full name", "first name only", "formal with title"],
            "EMAIL": ["personal", "work", "academic"],
            "PHONE": ["mobile", "landline", "with extension"],
            "CREDIT_CARD": ["Visa", "Mastercard", "Amex"],
            "BANK_ACCOUNT": ["IBAN", "domestic", "SWIFT/BIC"],
        }
        type_variant = random.choice(type_variants.get(pii_type, ["standard"]))
        
        # Build prompts
        system_prompt, user_prompt = get_generation_prompt(
            dimension=dimension,
            pii_type=pii_type,
            pii_value=pii_value,
            locale=actual_locale,
            type_variant=type_variant,
        )
        
        return GenerationTask(
            dimension=dimension,
            pii_type=pii_type,
            locale=locale,
            generation_id=generation_id,
            pii_value=pii_value,
            actual_locale=actual_locale,
            label=label,
            type_variant=type_variant,
            system_prompt=system_prompt,
            user_prompt=user_prompt,
        )
    
    def _get_generation_plan(self) -> list[tuple[str, str, str]]:
        """
        Create a balanced generation plan across dimensions, PII types, and locales.
        
        Returns:
            List of (dimension, pii_type, locale) tuples representing generation tasks
        """
        plan: list[tuple[str, str, str]] = []
        locales = list(SUPPORTED_LOCALES.keys())
        
        samples_per_dim = self.config.samples_per_dimension
        samples_per_type_per_dim = max(1, samples_per_dim // len(ALL_PII_TYPES))
        
        for dimension in FEATURE_DIMENSIONS:
            for pii_type in ALL_PII_TYPES:
                for _ in range(samples_per_type_per_dim):
                    locale = random.choice(locales)
                    plan.append((dimension, pii_type, locale))
        
        # Shuffle to avoid sequential patterns and distribute load
        random.shuffle(plan)
        
        return plan[:self.config.total_samples]
    
    def _process_response(
        self,
        task: GenerationTask,
        response: dict[str, Any] | None,
    ) -> SyntheticSample | None:
        """
        Process an LLM response into a validated SyntheticSample.
        
        Args:
            task: The generation task that produced this response
            response: Parsed JSON response from the LLM, or None on failure
        
        Returns:
            Validated SyntheticSample, or None if validation fails
        """
        if response is None:
            return None
        
        try:
            text = response.get("text", "")
            raw_entities = response.get("entities", [])
            scenario = response.get("scenario", "Unspecified scenario")
            
            if not text or not raw_entities:
                return None
            
            # Verify seed PII is in text
            if task.pii_value not in text:
                if task.pii_value.lower() not in text.lower():
                    return None
            
            # Repair entity spans (LLMs often get positions wrong)
            entities = repair_entity_spans(text, raw_entities)
            
            if not entities:
                return None
            
            # Create and validate sample using Pydantic
            sample = SyntheticSample(
                text=text,
                entities=entities,
                feature_dimension=FeatureDimension(task.dimension),
                seed_pii_type=task.pii_type,
                seed_pii_value=task.pii_value,
                seed_pii_locale=task.actual_locale,
                scenario=scenario,
                type_variant=task.type_variant,
                generation_id=task.generation_id,
            )
            
            return sample
            
        except Exception:
            # Treat any validation or construction error as a failed sample
            return None
    
    async def _process_batch(
        self,
        tasks: list[GenerationTask],
    ) -> list[tuple[GenerationTask, SyntheticSample | None]]:
        """
        Process a batch of generation tasks concurrently.
        
        This fires all requests in the batch simultaneously (respecting
        the client's concurrency limit) and waits for all to complete.
        
        Args:
            tasks: List of GenerationTask objects to process
        
        Returns:
            List of (task, sample_or_none) tuples in the same order
        """
        # Prepare request tuples for the client
        requests = [
            (task.system_prompt, task.user_prompt, 0.7)
            for task in tasks
        ]
        
        # Fire all requests concurrently
        responses = await self.client.generate_batch(requests)
        
        # Process responses into samples
        results: list[tuple[GenerationTask, SyntheticSample | None]] = []
        for task, response in zip(tasks, responses):
            sample = self._process_response(task, response)
            results.append((task, sample))
        
        return results
    
    def _save_checkpoint(self, checkpoint_num: int) -> str:
        """
        Save current progress to a checkpoint file.
        
        Args:
            checkpoint_num: Checkpoint sequence number
        
        Returns:
            Path to saved checkpoint file
        """
        checkpoint_path = (
            Path(self.config.output_dir) / 
            f"{self.config.checkpoint_prefix}_{checkpoint_num:04d}.json"
        )
        
        # Convert defaultdicts to regular dicts for JSON serialization
        stats_copy = {
            k: (dict(v) if isinstance(v, defaultdict) else v)
            for k, v in self.stats.items()
        }
        
        checkpoint_data = {
            "checkpoint_num": checkpoint_num,
            "timestamp": datetime.utcnow().isoformat(),
            "stats": stats_copy,
            "samples": [s.model_dump() for s in self.generated_samples],
            "failed_task_count": len(self.failed_tasks),
        }
        
        with open(checkpoint_path, "w", encoding="utf-8") as f:
            json.dump(checkpoint_data, f, indent=2, ensure_ascii=False)
        
        return str(checkpoint_path)
    
    def load_checkpoint(self, checkpoint_path: str) -> int:
        """
        Load progress from a checkpoint file.
        
        Args:
            checkpoint_path: Path to checkpoint file
        
        Returns:
            Number of samples loaded
        """
        with open(checkpoint_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        
        # Restore stats with defaultdict behavior
        loaded_stats = data.get("stats", {})
        self.stats = {
            "total_attempts": loaded_stats.get("total_attempts", 0),
            "successful": loaded_stats.get("successful", 0),
            "failed": loaded_stats.get("failed", 0),
            "by_dimension": defaultdict(int, loaded_stats.get("by_dimension", {})),
            "by_pii_type": defaultdict(int, loaded_stats.get("by_pii_type", {})),
            "by_locale": defaultdict(int, loaded_stats.get("by_locale", {})),
            "batches_processed": loaded_stats.get("batches_processed", 0),
            "start_time": loaded_stats.get("start_time"),
        }
        
        self.generated_samples = [
            SyntheticSample.model_validate(s) 
            for s in data.get("samples", [])
        ]
        
        print(f"✓ Loaded {len(self.generated_samples)} samples from checkpoint")
        return len(self.generated_samples)
    
    async def generate_batch(
        self,
        start_from: int = 0,
        progress_bar: bool = True,
    ) -> list[SyntheticSample]:
        """
        Generate a full batch of synthetic samples using concurrent processing.
        
        This method processes samples in batches of config.batch_size,
        firing concurrent requests and waiting for each batch to complete
        before starting the next.
        
        Args:
            start_from: Index to start from (for resuming from checkpoint)
            progress_bar: Whether to display a progress bar
        
        Returns:
            List of all successfully generated samples
        """
        # Get the full generation plan
        plan = self._get_generation_plan()
        
        if start_from > 0:
            plan = plan[start_from:]
            print(f"Resuming from sample {start_from}")
        
        self.stats["start_time"] = time.time()
        
        # Create all tasks upfront (this is fast, no API calls)
        print("Preparing generation tasks...")
        all_tasks = [
            self._create_generation_task(dimension, pii_type, locale)
            for dimension, pii_type, locale in plan
        ]
        
        # Split into batches
        batch_size = self.config.batch_size
        batches = [
            all_tasks[i:i + batch_size] 
            for i in range(0, len(all_tasks), batch_size)
        ]
        
        total_batches = len(batches)
        samples_since_checkpoint = len(self.generated_samples) % self.config.samples_per_checkpoint
        
        print(f"Processing {len(all_tasks)} tasks in {total_batches} batches of {batch_size}")
        
        # Process batches with progress tracking
        batch_iterator = tqdm(
            enumerate(batches),
            total=total_batches,
            desc="Processing batches",
            disable=not progress_bar,
        )
        
        for batch_idx, batch_tasks in batch_iterator:
            batch_start_time = time.time()
            
            # Process this batch concurrently
            results = await self._process_batch(batch_tasks)
            
            # Update stats and collect samples
            for task, sample in results:
                self.stats["total_attempts"] += 1
                
                if sample is not None:
                    self.generated_samples.append(sample)
                    self.stats["successful"] += 1
                    self.stats["by_dimension"][task.dimension] += 1
                    self.stats["by_pii_type"][task.pii_type] += 1
                    self.stats["by_locale"][task.actual_locale] += 1
                    samples_since_checkpoint += 1
                else:
                    self.stats["failed"] += 1
                    self.failed_tasks.append(task)
            
            self.stats["batches_processed"] += 1
            
            # Calculate metrics for progress display
            batch_elapsed = time.time() - batch_start_time
            success_rate = self.stats["successful"] / self.stats["total_attempts"]
            samples_per_second = len(batch_tasks) / batch_elapsed if batch_elapsed > 0 else 0
            
            batch_iterator.set_postfix({
                "success": f"{self.stats['successful']}/{self.stats['total_attempts']}",
                "rate": f"{success_rate:.1%}",
                "speed": f"{samples_per_second:.1f}/s",
            })
            
            # Checkpoint if needed
            if samples_since_checkpoint >= self.config.samples_per_checkpoint:
                cp_num = len(self.generated_samples) // self.config.samples_per_checkpoint
                cp_path = self._save_checkpoint(cp_num)
                batch_iterator.write(f"  💾 Checkpoint {cp_num}: {cp_path}")
                samples_since_checkpoint = 0
        
        # Final checkpoint
        self._save_checkpoint(9999)
        
        return self.generated_samples
    
    def get_statistics(self) -> dict[str, Any]:
        """
        Get comprehensive generation statistics.
        
        Returns:
            Dictionary with generation metrics and breakdowns
        """
        elapsed = time.time() - self.stats["start_time"] if self.stats["start_time"] else 0
        
        return {
            "total_generated": len(self.generated_samples),
            "total_attempts": self.stats["total_attempts"],
            "total_failed": self.stats["failed"],
            "success_rate": (
                self.stats["successful"] / self.stats["total_attempts"]
                if self.stats["total_attempts"] > 0 else 0
            ),
            "batches_processed": self.stats["batches_processed"],
            "elapsed_seconds": elapsed,
            "samples_per_second": len(self.generated_samples) / elapsed if elapsed > 0 else 0,
            "by_dimension": dict(self.stats["by_dimension"]),
            "by_pii_type": dict(self.stats["by_pii_type"]),
            "by_locale": dict(self.stats["by_locale"]),
        }
    
    def save_final_dataset(self, filename: str = "synthetic_pii_data.json") -> str:
        """
        Save the complete dataset to a JSON file.
        
        Args:
            filename: Output filename
        
        Returns:
            Path to saved file
        """
        output_path = Path(self.config.output_dir) / filename
        
        dataset = {
            "metadata": {
                "generated_at": datetime.utcnow().isoformat(),
                "total_samples": len(self.generated_samples),
                "dimensions": FEATURE_DIMENSIONS,
                "pii_types": ALL_PII_TYPES,
                "locales": list(SUPPORTED_LOCALES.keys()),
            },
            "statistics": self.get_statistics(),
            "samples": [s.model_dump() for s in self.generated_samples],
        }
        
        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(dataset, f, indent=2, ensure_ascii=False)
        
        print(f"✓ Dataset saved to: {output_path}")
        return str(output_path)


# Initialize generator with concurrent processing config
config = GenerationConfig(
    total_samples=11000,
    batch_size=40,  # 40 concurrent requests per batch
    samples_per_checkpoint=100,
    output_dir="./data/synthetic",
)

generator = SyntheticDataGenerator(
    client=grok_client,
    config=config,
)

print("✓ SyntheticDataGenerator initialized with concurrent batch processing")
print(f"  Target: {config.total_samples} samples")
print(f"  Batch size: {config.batch_size} concurrent requests")
print(f"  Per dimension: {config.samples_per_dimension} samples")
print(f"  Checkpoint every: {config.samples_per_checkpoint} samples")
print(f"  Output directory: {config.output_dir}")

# Estimate completion time
# With 40 concurrent requests and ~3.5 second batches, that's ~685 requests/min
# 11,000 samples / 685 per min ≈ 16 minutes (plus overhead)
estimated_batches = config.total_samples // config.batch_size
estimated_time_minutes = estimated_batches * 3.5 / 60  # 3.5 sec per batch average
print(f"  Estimated completion time: ~{estimated_time_minutes:.0f} minutes")

✓ SyntheticDataGenerator initialized with concurrent batch processing
  Target: 11000 samples
  Batch size: 40 concurrent requests
  Per dimension: 1833 samples
  Checkpoint every: 100 samples
  Output directory: ./data/synthetic
  Estimated completion time: ~16 minutes


## Quick Test Generation

Generate a small batch to verify that everything works before running the full 11,000-sample generation.


In [11]:
async def test_generation(num_samples: int = 10) -> list[SyntheticSample]:
    """
    Run a quick test generation with a small number of samples.
    
    Args:
        num_samples: Number of test samples to generate
    
    Returns:
        List of generated test samples
    """
    print(f"Running test generation with {num_samples} samples...")
    
    test_config = GenerationConfig(
        total_samples=num_samples,
        samples_per_checkpoint=num_samples + 1,  # No checkpoints for test
        output_dir="./data/synthetic_test",
    )
    
    test_generator = SyntheticDataGenerator(
        client=grok_client,
        config=test_config,
    )
    
    test_samples = await test_generator.generate_batch(progress_bar=True)
    
    print(f"\nTest complete: {len(test_samples)} samples generated")
    
    if test_samples:
        print("\nFirst sample:")
        s = test_samples[0]
        print(f"  Dimension: {s.feature_dimension.value}")
        print(f"  Text: {s.text[:150]}...")
        print(f"  Entities: {len(s.entities)}")
        for e in s.entities:
            print(f"    - {e.label}: '{e.text}' [{e.start}:{e.end}]")
    
    return test_samples


# Run quick test (comment out for full generation)
test_samples = await test_generation(10)

Running test generation with 10 samples...
Preparing generation tasks...
Processing 10 tasks in 1 batches of 20


Processing batches:   0%|          | 0/1 [00:00<?, ?it/s]


Test complete: 10 samples generated

First sample:
  Dimension: basic
  Text: Dear Mr. Rossi,

Thank you for registering for the conference. Here is your confirmation:

Name: Marco Rossi
Email: silvestro28@example.com (work)
Pho...
  Entities: 4
    - EMAIL: 'silvestro28@example.com' [116:139]
    - PERSON: 'Marco Rossi' [97:108]
    - PHONE: '+39 333 1234567' [154:169]
    - ADDRESS: 'Via Roma 45, 20121 Milano' [179:204]




In [12]:
print(test_samples)

[SyntheticSample(text='Dear Mr. Rossi,\n\nThank you for registering for the conference. Here is your confirmation:\n\nName: Marco Rossi\nEmail: silvestro28@example.com (work)\nPhone: +39 333 1234567\nAddress: Via Roma 45, 20121 Milano\n\nWe look forward to seeing you!\n\nBest regards,\nConference Team', entities=[EntitySpan(start=116, end=139, label='EMAIL', text='silvestro28@example.com'), EntitySpan(start=97, end=108, label='PERSON', text='Marco Rossi'), EntitySpan(start=154, end=169, label='PHONE', text='+39 333 1234567'), EntitySpan(start=179, end=204, label='ADDRESS', text='Via Roma 45, 20121 Milano')], feature_dimension=<FeatureDimension.BASIC: 'basic'>, seed_pii_type='EMAIL', seed_pii_value='silvestro28@example.com', seed_pii_locale='it_IT', scenario='Email confirmation for conference registration with contact details', type_variant='work', generation_id='basic_EMAIL_e94f71d6', timestamp='2025-12-03T23:09:38.691184'), SyntheticSample(text='DEBUG: User auth failed for IT citizen.

## Run Generation

Main execution cell - runs the full generation process.

WARNING: This will make ~11,000+ API calls to xAI. Ensure you have:

1. Sufficient API credits
2. Stable internet connection
3. Time for completion

To resume from a checkpoint, uncomment the load_checkpoint line and the start_idx assignment that follows it, and make sure start_idx is not reset to 0 afterwards.
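
Before resuming, it helps to see which checkpoint files exist. The sketch below lists them in numeric order. It assumes the `synthetic_checkpoint_NNNN.json` naming that appears in the logged paths; `list_checkpoints` is a hypothetical helper and not part of the generator class. The demo files are created in a temporary directory purely for illustration.

```python
import re
import tempfile
from pathlib import Path


def list_checkpoints(
    output_dir: str, prefix: str = "synthetic_checkpoint"
) -> list[tuple[int, Path]]:
    """Return (number, path) pairs for checkpoint files, sorted by number."""
    pattern = re.compile(rf"^{re.escape(prefix)}_(\d+)\.json$")
    found: list[tuple[int, Path]] = []
    for path in Path(output_dir).glob(f"{prefix}_*.json"):
        match = pattern.match(path.name)
        if match:
            found.append((int(match.group(1)), path))
    return sorted(found, key=lambda item: item[0])


# Demo with empty placeholder files in a temporary directory
with tempfile.TemporaryDirectory() as tmp:
    for num in (1, 2, 10):
        (Path(tmp) / f"synthetic_checkpoint_{num:04d}.json").write_text("{}")
    (Path(tmp) / "synthetic_pii_data.json").write_text("{}")  # not a checkpoint
    found = list_checkpoints(tmp)
    numbers = [num for num, _ in found]
    newest_name = found[-1][1].name
```

Note that this notebook writes its interrupt and final checkpoints under the numbers 9998 and 9999, so those sort last and may need to be filtered out when picking a regular checkpoint.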


In [13]:
async def run_generation():
    """Execute the full synthetic data generation pipeline."""
    
    print("=" * 60)
    print("SYNTHETIC PII DATA GENERATION")
    print("=" * 60)
    print(f"Start time: {datetime.now().isoformat()}")
    print(f"Target samples: {generator.config.total_samples}")
    print()
    
    # Default: start from the beginning of the plan
    start_idx = 0
    # Uncomment to resume from checkpoint (overrides the default above):
    # generator.load_checkpoint("./data/synthetic/synthetic_checkpoint_0050.json")
    # start_idx = len(generator.generated_samples)
    
    try:
        samples = await generator.generate_batch(
            start_from=start_idx,
            progress_bar=True,
        )
        
        print()
        print("=" * 60)
        print("GENERATION COMPLETE")
        print("=" * 60)
        
        # Print statistics
        stats = generator.get_statistics()
        print(f"\nTotal generated: {stats['total_generated']}")
        print(f"Success rate: {stats['success_rate']:.1%}")
        
        print("\nBy dimension:")
        for dim, count in stats["by_dimension"].items():
            print(f"  {dim}: {count}")
        
        print("\nBy PII type:")
        for pii_type, count in sorted(stats["by_pii_type"].items(), key=lambda x: -x[1])[:10]:
            print(f"  {pii_type}: {count}")
        
        # Save final dataset
        output_path = generator.save_final_dataset("synthetic_pii_data.json")
        
        return samples
        
    except KeyboardInterrupt:
        print("\n\nGeneration interrupted! Saving checkpoint...")
        generator._save_checkpoint(9998)
        print("Checkpoint saved. Run again to resume.")
        raise
    
    finally:
        await grok_client.close()


# Run the generation
# In Jupyter, use: await run_generation()
# In script, use: asyncio.run(run_generation())

# For Jupyter notebooks:
samples = await run_generation()

SYNTHETIC PII DATA GENERATION
Start time: 2025-12-03T15:09:58.386096
Target samples: 11000

Preparing generation tasks...
Processing 10944 tasks in 274 batches of 40


Processing batches:   0%|          | 0/274 [00:00<?, ?it/s]



  💾 Checkpoint 1: data\synthetic\synthetic_checkpoint_0001.json
  💾 Checkpoint 2: data\synthetic\synthetic_checkpoint_0002.json
  💾 Checkpoint 3: data\synthetic\synthetic_checkpoint_0003.json
  💾 Checkpoint 4: data\synthetic\synthetic_checkpoint_0004.json
  💾 Checkpoint 5: data\synthetic\synthetic_checkpoint_0005.json
  💾 Checkpoint 7: data\synthetic\synthetic_checkpoint_0007.json
  💾 Checkpoint 8: data\synthetic\synthetic_checkpoint_0008.json
  💾 Checkpoint 9: data\synthetic\synthetic_checkpoint_0009.json
  💾 Checkpoint 10: data\synthetic\synthetic_checkpoint_0010.json
  💾 Checkpoint 11: data\synthetic\synthetic_checkpoint_0011.json
  💾 Checkpoint 13: data\synthetic\synthetic_checkpoint_0013.json
  💾 Checkpoint 14: data\synthetic\synthetic_checkpoint_0014.json
  💾 Checkpoint 15: data\synthetic\synthetic_checkpoint_0015.json
  💾 Checkpoint 16: data\synthetic\synthetic_checkpoint_0016.json
  💾 Checkpoint 17: data\synthetic\synthetic_checkpoin



✓ Dataset saved to: data\synthetic\synthetic_pii_data.json


## Post-Generation Analysis and Export

Analyze the generated dataset and prepare for the next notebook (validation).
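
One check that is cheap to run before validation is whether each entity's offsets actually reproduce its annotated text. The sketch below is illustrative only: `Span` is a hypothetical stand-in for the notebook's EntitySpan model (same four fields), and `find_misaligned` is not part of the pipeline.

```python
from dataclasses import dataclass


@dataclass
class Span:
    """Hypothetical stand-in for the notebook's EntitySpan model."""
    start: int
    end: int
    label: str
    text: str


def find_misaligned(text: str, spans: list[Span]) -> list[Span]:
    """Return spans whose [start:end] slice does not reproduce the annotated text."""
    return [s for s in spans if text[s.start:s.end] != s.text]


text = "Name: Marco Rossi\nEmail: silvestro28@example.com (work)"
spans = [
    Span(6, 17, "PERSON", "Marco Rossi"),
    Span(25, 48, "EMAIL", "silvestro28@example.com"),
    Span(0, 4, "PERSON", "Marco"),  # deliberately wrong offsets
]
bad = find_misaligned(text, spans)
```

Here only the third span is reported, because `text[0:4]` is `"Name"` rather than `"Marco"`.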


In [14]:
def analyze_dataset(samples: list[SyntheticSample]) -> None:
    """Print comprehensive analysis of generated dataset."""
    
    print("=" * 60)
    print("DATASET ANALYSIS")
    print("=" * 60)
    
    # Basic counts
    print(f"\nTotal samples: {len(samples)}")
    
    # Dimension distribution
    dim_counts = defaultdict(int)
    for s in samples:
        dim_counts[s.feature_dimension.value] += 1
    
    print("\nDistribution by dimension:")
    for dim in FEATURE_DIMENSIONS:
        count = dim_counts.get(dim, 0)
        pct = count / len(samples) * 100 if samples else 0
        print(f"  {dim:15s}: {count:5d} ({pct:5.1f}%)")
    
    # PII type distribution
    pii_counts = defaultdict(int)
    for s in samples:
        pii_counts[s.seed_pii_type] += 1
    
    print("\nDistribution by PII type:")
    for pii_type in sorted(pii_counts.keys()):
        count = pii_counts[pii_type]
        pct = count / len(samples) * 100 if samples else 0
        print(f"  {pii_type:30s}: {count:4d} ({pct:5.1f}%)")
    
    # Locale distribution
    locale_counts = defaultdict(int)
    for s in samples:
        if s.seed_pii_locale:
            locale_counts[s.seed_pii_locale] += 1
    
    print("\nDistribution by locale:")
    for locale in sorted(locale_counts.keys()):
        count = locale_counts[locale]
        pct = count / len(samples) * 100 if samples else 0
        print(f"  {locale:10s}: {count:4d} ({pct:5.1f}%)")
    
    # Entity statistics
    total_entities = sum(len(s.entities) for s in samples)
    avg_entities = total_entities / len(samples) if samples else 0
    
    print(f"\nEntity statistics:")
    print(f"  Total entities: {total_entities}")
    print(f"  Average per sample: {avg_entities:.2f}")
    
    # Text length statistics
    text_lengths = [len(s.text) for s in samples]
    print(f"\nText length statistics:")
    print(f"  Min: {min(text_lengths) if text_lengths else 0}")
    print(f"  Max: {max(text_lengths) if text_lengths else 0}")
    print(f"  Mean: {sum(text_lengths) / len(text_lengths) if text_lengths else 0:.1f}")
    
    # Sample examples
    print("\n" + "=" * 60)
    print("SAMPLE EXAMPLES (one per dimension)")
    print("=" * 60)
    
    shown_dims = set()
    for sample in samples:
        if sample.feature_dimension.value not in shown_dims:
            shown_dims.add(sample.feature_dimension.value)
            print(f"\n[{sample.feature_dimension.value.upper()}]")
            print(f"Text: {sample.text[:200]}...")
            print(f"Entities: {len(sample.entities)}")
            for ent in sample.entities[:3]:
                print(f"  - [{ent.start}:{ent.end}] {ent.label}: '{ent.text}'")
            
            if len(shown_dims) >= 6:
                break


# Run analysis
if 'samples' in dir() and samples:
    analyze_dataset(samples)
else:
    print("No samples generated yet. Run the generation cell first.")

DATASET ANALYSIS

Total samples: 10836

Distribution by dimension:
  basic          :  1818 ( 16.8%)
  contextual     :  1819 ( 16.8%)
  noisy          :  1801 ( 16.6%)
  evolving       :  1813 ( 16.7%)
  multilingual   :  1819 ( 16.8%)
  adversarial    :  1766 ( 16.3%)

Distribution by PII type:
  BANK_ACCOUNT                  :  671 (  6.2%)
  BANK_UPI_ID                   :  680 (  6.3%)
  CREDIT_CARD                   :  670 (  6.2%)
  DATE_OF_BIRTH                 :  684 (  6.3%)
  DRIVER_LICENSE                :  677 (  6.2%)
  EMAIL                         :  672 (  6.2%)
  INSURANCE_NUMBER              :  673 (  6.2%)
  NAME                          :  662 (  6.1%)
  NAMES_OF_PLACES_OR_NOUNS      :  682 (  6.3%)
  NATIONAL_IDENTITY_SSN_AADHAR  :  681 (  6.3%)
  OTHER_NATIONAL_IDENTITY       :  679 (  6.3%)
  PASSPORT_NUMBER               :  684 (  6.3%)
  PHONE                         :  682 (  6.3%)
  POSTAL_CODE                   :  682 (  6.3%)
  TAX_IDENTIFICATION          

## Export to CSV/Parquet

Export the generated data in formats suitable for later validation.


In [15]:
def export_for_validation(
    samples: list[SyntheticSample],
    output_dir: str = "./data/synthetic",
) -> dict[str, str]:
    """
    Export samples in multiple formats for downstream processing.
    
    Args:
        samples: List of generated samples
        output_dir: Output directory path
    
    Returns:
        Dictionary mapping format names to file paths
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    
    exported_files = {}
    
    # 1. JSON Lines format (one sample per line, for streaming)
    jsonl_path = output_dir / "synthetic_samples.jsonl"
    with open(jsonl_path, "w", encoding="utf-8") as f:
        for sample in samples:
            f.write(json.dumps(sample.model_dump(), ensure_ascii=False) + "\n")
    exported_files["jsonl"] = str(jsonl_path)
    print(f"✓ Exported JSONL: {jsonl_path}")
    
    # 2. CSV format (flattened, for quick inspection)
    csv_data = []
    for sample in samples:
        csv_data.append({
            "generation_id": sample.generation_id,
            "text": sample.text,
            "feature_dimension": sample.feature_dimension.value,
            "seed_pii_type": sample.seed_pii_type,
            "seed_pii_value": sample.seed_pii_value,
            "seed_pii_locale": sample.seed_pii_locale,
            "scenario": sample.scenario,
            "type_variant": sample.type_variant,
            "num_entities": len(sample.entities),
            "entities_json": json.dumps([e.model_dump() for e in sample.entities]),
            "timestamp": sample.timestamp,
        })
    
    df = pd.DataFrame(csv_data)
    csv_path = output_dir / "synthetic_samples.csv"
    df.to_csv(csv_path, index=False, encoding="utf-8")
    exported_files["csv"] = str(csv_path)
    print(f"✓ Exported CSV: {csv_path}")
    
    # 3. Parquet format (efficient for large datasets)
    try:
        parquet_path = output_dir / "synthetic_samples.parquet"
        df.to_parquet(parquet_path, index=False)
        exported_files["parquet"] = str(parquet_path)
        print(f"✓ Exported Parquet: {parquet_path}")
    except Exception as e:
        print(f"Note: Parquet export skipped ({e})")
    
    # 4. Dimension-specific JSON files (for targeted validation)
    for dimension in FEATURE_DIMENSIONS:
        dim_samples = [s for s in samples if s.feature_dimension.value == dimension]
        if dim_samples:
            dim_path = output_dir / f"synthetic_{dimension}.json"
            with open(dim_path, "w", encoding="utf-8") as f:
                json.dump(
                    [s.model_dump() for s in dim_samples],
                    f,
                    indent=2,
                    ensure_ascii=False,
                )
            exported_files[f"json_{dimension}"] = str(dim_path)
    print("✓ Exported dimension-specific JSON files")
    
    print(f"\nTotal files exported: {len(exported_files)}")
    return exported_files


# Export
if 'samples' in dir() and samples:
    exported = export_for_validation(samples)
    print("\nExported files:")
    for fmt, path in exported.items():
        print(f"  {fmt}: {path}")
else:
    print("No samples to export. Run generation first.")

✓ Exported JSONL: data\synthetic\synthetic_samples.jsonl
✓ Exported CSV: data\synthetic\synthetic_samples.csv
✓ Exported Parquet: data\synthetic\synthetic_samples.parquet
✓ Exported dimension-specific JSON files

Total files exported: 9

Exported files:
  jsonl: data\synthetic\synthetic_samples.jsonl
  csv: data\synthetic\synthetic_samples.csv
  parquet: data\synthetic\synthetic_samples.parquet
  json_basic: data\synthetic\synthetic_basic.json
  json_contextual: data\synthetic\synthetic_contextual.json
  json_noisy: data\synthetic\synthetic_noisy.json
  json_evolving: data\synthetic\synthetic_evolving.json
  json_multilingual: data\synthetic\synthetic_multilingual.json
  json_adversarial: data\synthetic\synthetic_adversarial.json
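
The CSV export stores each sample's entities as a JSON string in the `entities_json` column, so a consumer has to decode that column after reading. A minimal round-trip sketch using only the standard library is shown below. The row values are made up for illustration, and an in-memory buffer stands in for the exported file.

```python
import csv
import io
import json

# Write one row the way the export cell flattens a sample (values are made up)
buffer = io.StringIO()
writer = csv.DictWriter(
    buffer, fieldnames=["generation_id", "text", "num_entities", "entities_json"]
)
writer.writeheader()
writer.writerow({
    "generation_id": "demo_0001",
    "text": "Call me at 555-0100,\nthanks",
    "num_entities": 1,
    "entities_json": json.dumps(
        [{"start": 11, "end": 19, "label": "PHONE", "text": "555-0100"}]
    ),
})

# Read it back: the csv module handles the embedded newline, comma and quotes
buffer.seek(0)
rows = list(csv.DictReader(buffer))
entities = json.loads(rows[0]["entities_json"])
recovered = rows[0]["text"][entities[0]["start"]:entities[0]["end"]]
```

All CSV fields come back as strings (for example `num_entities` is `"1"`), so numeric columns need an explicit conversion. With pandas, the equivalent decode step would be `df["entities_json"].apply(json.loads)` after `pd.read_csv`.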
