# Lab 08 — Why Schema Validation in ML/LLM Pipelines

**Focus Area:** Why schema validation — catching upstream drift; protecting training/eval

> This lab is the *why* and *show‑me* for validation. You'll simulate upstream changes (types, ranges, unexpected categories) and see how a light schema gate prevents bad data from reaching LLM‑adjacent stages.

---

## Outcomes

By the end of this lab, you will be able to:

1. Explain the difference between **structural drift** (columns/types) and **semantic drift** (values/ranges/categories), and why each harms LLM workflows.
2. Add a **pre‑flight validation gate** that fails fast with actionable messages.
3. Use a **minimal Pandera schema** (or Pydantic model per row) to enforce types, ranges, and categorical sets.
4. Capture **human‑readable failure reports** for debugging, CI, and incident triage.

## Prerequisites & Setup

- Python 3.13 with `pandas`, `numpy`, `pandera`, `pydantic`, `pyarrow` installed.  
- JupyterLab or VS Code with Jupyter extension.
- Artifacts from previous labs (optional but recommended): `artifacts/clean/per_customer.parquet` or `users_clean.parquet`  

**Start a notebook:** `week02_lab08.ipynb`

If you don't have prior artifacts, synthesize a small frame now:

In [None]:
import numpy as np
import pandas as pd

rng = np.random.default_rng(42)
# Draw ages once so that is_adult can be derived from the same values
age = rng.integers(16, 80, size=300).astype('int64')
users2 = pd.DataFrame({
    'CustomerID': [f'C{i:05d}' for i in range(300)],
    'country_norm': rng.choice(['USA','DE','SG','BR'], size=300, p=[.55,.2,.15,.1]),
    'age': age,
    'ltv_usd': np.round(np.clip(rng.lognormal(3.0, 0.7, size=300), 0, 5e4), 2),
    'is_adult': age >= 18,  # consistent with the age column
    'is_high_value': rng.random(300) > 0.85,
})
users2.head()

## Part A — What can go wrong, concretely?

In LLM/ML pipelines, silent data drift can:

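- **Break transforms** (e.g., after a type flip from string→int, `pd.to_datetime` stops parsing the values as date strings and reads them as epoch offsets, so the dates come out wrong without any error).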
- **Bias metrics** (e.g., new country labels split a cohort: `U.S.A.` appears again).
- **Explode tokens/costs** (e.g., unexpectedly long text fields; numeric → string inflation).
- **Poison eval/train** (e.g., negative prices; out‑of‑range ages; missing required keys).
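
To make the cohort-split bullet concrete, here is a minimal standard-library sketch. The label list and the `ALIASES` map are made up for illustration; a real alias map would come from your own data contract.

```python
from collections import Counter

# Made-up labels: four of them mean "USA" but are spelled differently
labels = ['USA', 'USA', 'U.S.A.', 'usa', 'DE', 'United States', 'DE', 'SG']

# Counting raw labels splits the USA cohort across several keys
raw_counts = Counter(labels)

# Hypothetical alias map used to fold variants back into the policy label
ALIASES = {'U.S.A.': 'USA', 'usa': 'USA', 'United States': 'USA', 'US': 'USA'}
normalized_counts = Counter(ALIASES.get(label, label) for label in labels)

print('raw USA count:       ', raw_counts['USA'])
print('normalized USA count:', normalized_counts['USA'])
print('distinct raw labels: ', len(raw_counts))
```

Any per-country metric computed on the raw labels would under-count the USA cohort, which is the kind of quiet bias a categorical allow-list is meant to surface.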

**Exercise:** Create 3 synthetic drifts.

In [None]:
broken = users2.copy()
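# 1) Structural drift: age becomes string for some rows
# Cast to object first; writing strings into an int64 column warns or errors in recent pandas
broken['age'] = broken['age'].astype(object)
broken.loc[broken.index[:20], 'age'] = broken.loc[broken.index[:20], 'age'].astype(str)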
# 2) Semantic drift: country label out of policy
broken.loc[10:15, 'country_norm'] = ['U.S.A.','United States','usa','US','USA','USA']
# 3) Range drift: negative ltv sneaks in
broken.loc[50:55, 'ltv_usd'] = [-10, -5, -1, 0, 1, 2]
broken.head()

## Part B — Minimal Pandera schema as a gate

We'll define a small DataFrame schema to catch the above.

In [None]:
%pip install pandera

import pandera.pandas as pa
# Import from the same module as `pa` so the import paths are not mixed
from pandera.pandas import Column, Check

Schema = pa.DataFrameSchema({
    'CustomerID': Column(object, nullable=False),
    'country_norm': Column(object, Check.isin(['USA','DE','SG','BR']), nullable=False),
    'age': Column(pa.Int64, Check.in_range(0, 120), nullable=False),
    'ltv_usd': Column(float, Check.ge(0), nullable=False),
    'is_adult': Column(bool, nullable=False),
    'is_high_value': Column(bool, nullable=False),
})

### B1. Validate clean vs broken

In [None]:
# Clean should pass
ok = Schema.validate(users2, lazy=True)
print('clean rows:', len(ok))

# Broken should fail with a report
try:
    Schema.validate(broken, lazy=True)
except pa.errors.SchemaErrors as err:
    report = err.failure_cases
    print(report.head(10))

**Checkpoint:** Inspect `report` to see: wrong dtype (`age`), out‑of‑set categories (`country_norm`), and negative values (`ltv_usd`).
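
One way to connect the report back to the structural/semantic distinction is to bucket failures by their check label. The sketch below is standard-library only and rests on an assumption: that dtype and missing-column failures are labelled roughly like `dtype('int64')` and `column_in_dataframe` in your Pandera version. The helper `drift_kind` and the example labels are hypothetical; print `report['check'].unique()` to see what your run actually produced.

```python
from collections import Counter

def drift_kind(check_name: str) -> str:
    """Map a failure's check label to a coarse drift category."""
    # Assumption: dtype and missing-column failures carry these labels;
    # confirm against your own report before relying on this rule.
    if check_name.startswith('dtype(') or check_name == 'column_in_dataframe':
        return 'structural'
    return 'semantic'

# Illustrative check labels, not real output from this lab
example_checks = [
    "dtype('int64')",
    "isin(['USA', 'DE', 'SG', 'BR'])",
    "isin(['USA', 'DE', 'SG', 'BR'])",
    "greater_than_or_equal_to(0)",
]
kinds = Counter(drift_kind(c) for c in example_checks)
print(dict(kinds))
```

On the real report, the same helper could be applied with `report['check'].astype(str).map(drift_kind)`.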

### B2. Actionable messages for CI / logs

In [None]:
# Summarize by column + failure type
summary = (report
           .groupby(['column', 'check'])
           .size()
           .reset_index(name='failures')
           .sort_values('failures', ascending=False))
summary

> **Interpretation:** This summary is what you'd attach to a CI artifact or Slack alert.
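
As a rough sketch of what such an alert could look like, the helper below turns summary records into a short plain-text message. `format_alert` is a hypothetical name, and the example counts and check labels are invented; in the notebook you would pass `summary.to_dict(orient='records')` instead.

```python
def format_alert(records, dataset, limit=3):
    """Render failure counts as a short plain-text message."""
    ranked = sorted(records, key=lambda r: r['failures'], reverse=True)
    total = sum(r['failures'] for r in ranked)
    lines = [f"Schema validation failed for {dataset}: {total} failure cases"]
    for r in ranked[:limit]:
        lines.append(f"- {r['column']} / {r['check']}: {r['failures']}")
    if len(ranked) > limit:
        # Keep the message short; the full report lives in the CI artifact
        lines.append(f"(+{len(ranked) - limit} more failure types)")
    return "\n".join(lines)

# Made-up counts for illustration only
example_records = [
    {'column': 'country_norm', 'check': "isin(['USA', 'DE', 'SG', 'BR'])", 'failures': 4},
    {'column': 'ltv_usd', 'check': 'greater_than_or_equal_to(0)', 'failures': 3},
    {'column': 'age', 'check': "dtype('int64')", 'failures': 1},
    {'column': 'age', 'check': 'in_range(0, 120)', 'failures': 1},
]
message = format_alert(example_records, dataset='users2_broken')
print(message)
```

Capping the list at a few entries keeps the alert readable; the detailed rows stay in the attached report.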

## Part C — Row‑level validation with Pydantic (optional)

Use Pydantic models when you're validating **per‑row payloads** (e.g., API messages) or writing contracts across services.

In [None]:
%pip install pydantic

from pydantic import BaseModel, Field, ValidationError
from typing import Literal

class CustomerRow(BaseModel):
    CustomerID: str
    country_norm: Literal['USA','DE','SG','BR']
    age: int = Field(ge=0, le=120)
    ltv_usd: float = Field(ge=0)
    is_adult: bool
    is_high_value: bool

row = users2.iloc[0].to_dict()
CustomerRow(**row)

In [None]:
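# Row 12 was given country_norm='usa' and a string-typed age in Part A.
# Pydantic's default (lax) mode coerces numeric strings to int, so only country_norm is reported.
try:
    CustomerRow(**broken.iloc[12].to_dict())
except ValidationError as e:
    print(e)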

**When to prefer Pydantic:** API boundaries, message queues, microservices. **When to prefer Pandera:** bulk DataFrame validation in ETL/ELT.
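
One detail worth knowing at API boundaries: by default Pydantic coerces compatible inputs, so a numeric string passes an `int` field. If you want a type flip to be rejected rather than repaired, strict mode is one option. The sketch below assumes Pydantic v2; `LaxRow` and `StrictRow` are hypothetical single-field models, not part of the lab schema.

```python
from pydantic import BaseModel, ConfigDict, ValidationError

class LaxRow(BaseModel):
    # Default (lax) mode: compatible inputs are coerced to the declared type
    age: int

class StrictRow(BaseModel):
    # Strict mode: the input must already have the declared type
    model_config = ConfigDict(strict=True)
    age: int

lax_age = LaxRow(age='34').age  # the string is coerced to an int

try:
    StrictRow(age='34')
    strict_rejected = False
except ValidationError as exc:
    strict_rejected = True
    strict_error_fields = [err['loc'] for err in exc.errors()]

print('lax age:', lax_age, type(lax_age).__name__)
print('strict rejected:', strict_rejected)
```

Whether coercion is acceptable depends on the contract: lax mode tolerates harmless serialization differences, while strict mode turns a structural drift into an explicit failure.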

## Part D — Pre‑flight gate function + fail‑fast

Wrap the schema check in a reusable function that raises a concise, friendly error and writes a CSV report for triage.

In [None]:
from pathlib import Path

def validate_or_raise(df: pd.DataFrame, schema: pa.DataFrameSchema, name: str, out_dir: str = 'artifacts/validation') -> pd.DataFrame:
    """Return the validated frame, or write a failure report and raise RuntimeError."""
    try:
        return schema.validate(df, lazy=True)
    except pa.errors.SchemaErrors as err:
        rep = err.failure_cases
        # create the report directory only when there is something to write
        Path(out_dir).mkdir(parents=True, exist_ok=True)
        dest = Path(out_dir) / f'{name}_schema_failures.csv'
        rep.to_csv(dest, index=False)
        # compact message for logs/CI
        top = (rep.groupby(['column','check']).size().reset_index(name='n')
                 .sort_values('n', ascending=False).head(5).to_dict(orient='records'))
        # chain the original exception so the full Pandera details stay available
        raise RuntimeError(f"Validation failed for {name}. Top issues: {top}. See {dest}") from err

In [None]:
# Example usage
_ = validate_or_raise(users2, Schema, name='users2_clean')
try:
    _ = validate_or_raise(broken, Schema, name='users2_broken')
except RuntimeError as e:
    print('\nGATE BLOCKED ->', e)
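
The point of failing fast is that nothing downstream runs on bad input. The standard-library sketch below illustrates that ordering with stand-ins: `gate` plays the role of `validate_or_raise`, and `build_prompts` is a hypothetical LLM-adjacent stage. Neither is part of the lab code.

```python
calls = []  # records which downstream stages actually ran

def gate(rows):
    """Hypothetical stand-in for validate_or_raise: reject negative ltv_usd."""
    bad_rows = [r for r in rows if r['ltv_usd'] < 0]
    if bad_rows:
        raise RuntimeError(f"Validation failed: {len(bad_rows)} rows with negative ltv_usd")
    return rows

def build_prompts(rows):
    """Hypothetical downstream stage that would feed an LLM."""
    calls.append('build_prompts')
    return [f"Customer {r['CustomerID']} has LTV ${r['ltv_usd']:.2f}" for r in rows]

def run_pipeline(rows):
    # The gate runs first, so build_prompts never sees rejected data
    return build_prompts(gate(rows))

good = [{'CustomerID': 'C00001', 'ltv_usd': 12.5}]
bad = good + [{'CustomerID': 'C00002', 'ltv_usd': -5.0}]

prompts = run_pipeline(good)
try:
    run_pipeline(bad)
    blocked = False
except RuntimeError as exc:
    blocked = True
    print('GATE BLOCKED ->', exc)

print(prompts)
print('downstream calls:', calls)
```

Only the clean batch reaches `build_prompts`; the broken batch stops at the gate before any prompt is built.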

## Part E — Wrap‑Up

Add a markdown cell and answer:

1. Name one structural and one semantic drift you simulated. How would each impact an LLM component downstream?  
2. Paste the top 3 failure types from your summary and propose a remediation (fix in source vs transform rule).  
3. Where would you place this validation gate in your Day‑1/Day‑2 pipeline, and why?

### Your answers here:

**1. Structural and Semantic Drift:**

- **Structural drift:** The `age` column changed from `int64` to string for some rows. This would break downstream operations like numerical aggregations, comparisons, or ML model features that expect numeric input.
- **Semantic drift:** Country labels changed to non-standard values ('U.S.A.', 'United States', 'usa', 'US'). This would cause category splits in LLM prompts, create duplicate embeddings for the same entity, and inflate token usage unnecessarily.

**2. Top 3 Failure Types:**

*(Fill in based on your actual output)*

Example:
- `age` dtype mismatch: Fix at source by enforcing integer type constraints in upstream system
- `country_norm` out of allowed set: Add transform rule to normalize variants before validation
- `ltv_usd` negative values: Fix at source by adding database constraint or API validation

**3. Validation Gate Placement:**

I would place this validation gate:
- **Immediately after data ingestion** (Day-1) to fail fast before any transformation or enrichment
- **Before feature engineering** (Day-2) to protect ML/LLM components from corrupted inputs
- **As a CI/CD check** before promoting data to production storage

Why: Early detection minimizes wasted processing, prevents cascading errors, and provides clear diagnostic information at the point of failure.

---

**Common pitfalls:** Using `object` dtypes everywhere; not distinguishing structural from semantic drift; over‑fitting schemas (too strict for expected evolution).

## Solution Snippets (reference)

**Quick failure roll‑up:**

In [None]:
summary = (report.groupby(['column','check'])
           .size().reset_index(name='failures')
           .sort_values('failures', ascending=False))
summary.head()

**CI‑style assert:**

In [None]:
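# validate() raises SchemaErrors on failure and returns the frame on success,
# so assert something that can actually fail: no rows were lost
assert len(Schema.validate(users2, lazy=True)) == len(users2)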

**Lightweight allow‑list for categories:**

In [None]:
allowed = {'USA','DE','SG','BR'}
viol = set(broken['country_norm']) - allowed
viol

## Bonus: Real-World Example with Lab Artifacts

Let's apply the validation concepts to the actual `per_customer.parquet` file from previous labs to see how schema validation works with real data.

In [None]:
# Load the real customer data from previous labs
from pathlib import Path

# Check if the per_customer.parquet file exists
artifact_path = Path('artifacts/clean/per_customer.parquet')
if artifact_path.exists():
    print(f"✓ Found artifact: {artifact_path}")
    real_customers = pd.read_parquet(artifact_path)
    print(f"Shape: {real_customers.shape}")
    print(f"Columns: {list(real_customers.columns)}")
    print("\nFirst few rows:")
    print(real_customers.head())
    print("\nData types:")
    print(real_customers.dtypes)
else:
    print(f"✗ Artifact not found at {artifact_path}")
    print("Available files in artifacts/clean/:")
    clean_dir = Path('artifacts/clean')
    if clean_dir.exists():
        for f in clean_dir.iterdir():
            print(f"  - {f.name}")
    else:
        print("  artifacts/clean directory not found")

In [None]:
# Create a dynamic schema based on the actual data structure
if 'real_customers' in locals() and not real_customers.empty:
    print("Creating schema based on actual data structure...")
    
    # Inspect unique values in categorical columns
    categorical_cols = real_customers.select_dtypes(include=['object']).columns
    print(f"\nCategorical columns: {list(categorical_cols)}")
    
    for col in categorical_cols:
        unique_vals = real_customers[col].unique()
        print(f"{col}: {len(unique_vals)} unique values")
        if len(unique_vals) <= 10:
            print(f"  Values: {list(unique_vals)}")
        else:
            print(f"  Sample: {list(unique_vals[:5])}...")
    
    # Check for potential issues
    print("\n" + "="*50)
    print("POTENTIAL DATA QUALITY ISSUES:")
    print("="*50)
    
    # Check for null values
    nulls = real_customers.isnull().sum()
    if nulls.any():
        print("\n🚨 Null values detected:")
        for col, count in nulls[nulls > 0].items():
            print(f"  {col}: {count} nulls ({count/len(real_customers)*100:.1f}%)")
    else:
        print("\n✅ No null values found")
    
    # Check for negative values in numeric columns
    numeric_cols = real_customers.select_dtypes(include=[np.number]).columns
    print(f"\n📊 Numeric columns: {list(numeric_cols)}")
    
    for col in numeric_cols:
        if (real_customers[col] < 0).any():
            neg_count = (real_customers[col] < 0).sum()
            print(f"🚨 {col}: {neg_count} negative values")
        else:
            print(f"✅ {col}: No negative values")
    
    print(f"\n📈 Data summary:")
    print(real_customers.describe())

In [None]:
# Build a Pandera schema for the real customer data
if 'real_customers' in locals() and not real_customers.empty:
    print("Building Pandera schema for real customer data...\n")
    
    # Create schema constraints based on discovered data patterns
    real_schema_constraints = {}
    
    for col in real_customers.columns:
        dtype = real_customers[col].dtype
        print(f"Processing {col} ({dtype})...")
        
        nullable = bool(real_customers[col].isnull().any())

        if pd.api.types.is_bool_dtype(dtype):
            # Test bool before numeric: is_numeric_dtype() also returns True for bool columns
            real_schema_constraints[col] = Column(bool, nullable=nullable)
            print(f"  → Boolean validation")

        elif pd.api.types.is_object_dtype(dtype):
            # For categorical columns, create allow-list from unique values
            unique_vals = real_customers[col].dropna().unique().tolist()
            if len(unique_vals) <= 20:  # Only create strict validation for small sets
                real_schema_constraints[col] = Column(object, Check.isin(unique_vals), nullable=nullable)
                print(f"  → Created allow-list with {len(unique_vals)} values")
            else:
                real_schema_constraints[col] = Column(object, nullable=nullable)
                print(f"  → Too many values ({len(unique_vals)}), using basic object validation")

        elif pd.api.types.is_numeric_dtype(dtype):
            # For numeric columns, set reasonable range constraints
            min_val = real_customers[col].min()
            max_val = real_customers[col].max()

            # Build checks based on data characteristics
            checks = []
            if min_val >= 0:  # If all values are non-negative, enforce that
                checks.append(Check.ge(0))
                print(f"  → Added non-negative constraint (min: {min_val})")

            # Require a positive max: scaling a zero or negative max by 1.1 leaves no headroom
            if 0 < max_val < 1000:
                checks.append(Check.le(max_val * 1.1))  # 10% buffer
                print(f"  → Added upper bound: {max_val * 1.1}")

            real_schema_constraints[col] = Column(dtype, checks, nullable=nullable)

        else:
            real_schema_constraints[col] = Column(dtype, nullable=nullable)
            print(f"  → Generic validation for {dtype}")
    
    # Create the schema
    RealCustomerSchema = pa.DataFrameSchema(real_schema_constraints)
    print(f"\n✅ Created schema with {len(real_schema_constraints)} column validations")

In [None]:
# Test the schema on the real data
if 'RealCustomerSchema' in locals() and 'real_customers' in locals():
    print("Testing schema validation on real customer data...\n")
    
    try:
        validated_data = RealCustomerSchema.validate(real_customers, lazy=True)
        print("✅ VALIDATION PASSED!")
        print(f"Successfully validated {len(validated_data)} rows")
    except pa.errors.SchemaErrors as err:
        print("🚨 VALIDATION FAILED!")
        failures = err.failure_cases
        print(f"Found {len(failures)} validation failures:")
        print(failures.groupby(['column', 'check']).size().reset_index(name='count').head(10))
    
    # Now let's simulate some data corruption and see how our schema catches it
    print("\n" + "="*60)
    print("SIMULATING DATA CORRUPTION SCENARIOS:")
    print("="*60)
    
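    # Create a corrupted version of real data
    corrupted_real = real_customers.copy()

    # Corrupt only columns that carry checks; a column without checks has nothing to catch
    checked_cols = [c for c, spec in RealCustomerSchema.columns.items() if spec.checks]

    # Scenario 1: Introduce negative values in a checked numeric column
    numeric_candidates = [c for c in real_customers.select_dtypes(include=[np.number]).columns if c in checked_cols]
    if numeric_candidates:
        numeric_col = numeric_candidates[0]
        # iloc is positional, so this works whatever the index labels are
        corrupted_real.iloc[:6, corrupted_real.columns.get_loc(numeric_col)] = -999
        print(f"Scenario 1: Added negative values to {numeric_col}")

    # Scenario 2: Add invalid categories to a categorical column that has an allow-list
    cat_candidates = [c for c in real_customers.select_dtypes(include=['object']).columns if c in checked_cols]
    if cat_candidates:
        cat_col = cat_candidates[0]
        corrupted_real.iloc[:4, corrupted_real.columns.get_loc(cat_col)] = 'INVALID_CATEGORY'
        print(f"Scenario 2: Added invalid category to {cat_col}")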
    # Test the corrupted data
    try:
        RealCustomerSchema.validate(corrupted_real, lazy=True)
        print("❌ Schema failed to catch corruption!")
    except pa.errors.SchemaErrors as err:
        print(f"\n✅ Schema successfully caught {len(err.failure_cases)} corruption issues:")
        failure_summary = (err.failure_cases
                          .groupby(['column', 'check'])
                          .size()
                          .reset_index(name='failures')
                          .sort_values('failures', ascending=False))
        print(failure_summary)