# Data Integrity & Customer Uniqueness Validation

## Overview
This notebook validates the fundamental data integrity constraint: **customer-level isolation across chunks**.

### Key Validation Rules:
1. **Within each chunk**: A customer_id is expected to appear multiple times (one row per effective_date)
2. **Across chunks**: A customer_id should appear in EXACTLY ONE chunk (all their samples in one chunk)
3. **Chunk 0-255**: Train customers (in-time data before 2024)
4. **Chunk 256-319**: Validation customers (in-time data before 2024)
5. **OOT data**: Effective dates on or after 2024-01-01, present across all chunks

### Outputs:
- Customer count per chunk
- Sample count per chunk by time period
- Cross-chunk customer_id duplication report (should be ZERO)
- Split confirmation (train/valid/OOT)

---


## 1. Environment Setup


In [None]:
# Import required libraries
from pyspark.sql import functions as F
from pyspark.sql import Window
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, FloatType
import pandas as pd
import numpy as np
import json
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns
from io import BytesIO, StringIO

# Notebook display settings: print every column, cap printed rows at 100
pd.options.display.max_columns = None
pd.options.display.max_rows = 100
# White-grid seaborn theme for the figures drawn later in the notebook
sns.set_style(style="whitegrid")

print("✓ Libraries imported successfully")


## 2. Helper Functions for ADLS File Operations

Functions to save CSV and PNG files to ADLS paths (CSVs are written directly; PNGs are staged in a local temporary file and then copied).


In [None]:
# Helper functions to save files directly to ADLS rather than DBFS
def save_pandas_to_csv_adls(df_pandas, adls_path):
    """Write a pandas DataFrame to ``adls_path`` as CSV text.

    The frame is rendered without its index and passed to ``dbutils.fs.put``
    with ``overwrite=True``, so an existing file at the destination is replaced.

    Args:
        df_pandas: pandas DataFrame to serialize.
        adls_path: Destination file path handed to ``dbutils.fs.put``.
    """
    dbutils.fs.put(adls_path, df_pandas.to_csv(index=False), overwrite=True)
    print(f"✓ Saved CSV to {adls_path}")

def save_plot_to_adls(fig, adls_path, dpi=150):
    """Render a matplotlib figure as PNG and copy it to ``adls_path``.

    The figure is first written to a local temporary file, which is then copied
    to the destination with ``dbutils.fs.cp``. The temporary file is removed
    whether or not rendering or the copy succeeds.

    Args:
        fig: matplotlib figure to render.
        adls_path: Destination file path handed to ``dbutils.fs.cp``.
        dpi: Resolution passed to ``fig.savefig``.
    """
    import os
    import tempfile

    # mkstemp gives a path that outlives the file handle, so dbutils can read it
    tmp_fd, tmp_path = tempfile.mkstemp(suffix='.png')
    try:
        with os.fdopen(tmp_fd, 'wb') as tmp:
            fig.savefig(tmp, format='png', dpi=dpi, bbox_inches='tight')
        # Copy from the local temp file to ADLS
        dbutils.fs.cp(f"file:{tmp_path}", adls_path)
    finally:
        # Always clean up, even when savefig or the copy raises
        os.remove(tmp_path)
    print(f"✓ Saved plot to {adls_path}")

print("✓ ADLS helper functions defined")


## 3. Configuration & Paths

Define data paths and key parameters for the validation.


In [None]:
# ==================== CONFIGURATION ====================

# --- ADLS locations ---
# Root folder of the input data
DATA_PATH = "abfss://home@edaaaazepcalayelaye0001.dfs.core.windows.net/MD_Artifacts/money-out/data/"

# Target folders (contain customer_id and effective_date)
TARGET_TRAIN_VAL_PATH = f"{DATA_PATH}target/cust/all_products_chunk_320/train_val/"
TARGET_TEST_PATH = f"{DATA_PATH}target/cust/all_products_chunk_320/test/"

# Folder that receives the validation outputs; created up front
OUTPUT_PATH = "abfss://home@edaaaazepcalayelaye0001.dfs.core.windows.net/MD_Artifacts/money-out/mv/eda_validation/data_integrity/"
dbutils.fs.mkdirs(OUTPUT_PATH)

# --- Chunk layout ---
TOTAL_CHUNKS = 320
ALL_CHUNKS = list(range(0, 320))
TRAIN_CHUNKS = ALL_CHUNKS[:256]  # 256 chunks: 0-255
VALID_CHUNKS = ALL_CHUNKS[256:]  # 64 chunks: 256-319

# --- Time split ---
# Rows dated on or after this day are treated as out-of-time
OOT_START_DATE = '2024-01-01'

# --- Sampling (for quick test runs) ---
# 1.0 = use all data, 0.01 = 1% sample
SAMPLING_RATIO = 1.0

print(
    "\n".join(
        [
            "✓ Configuration loaded",
            f"  - Total chunks: {TOTAL_CHUNKS}",
            f"  - Train chunks: {len(TRAIN_CHUNKS)} (0-255)",
            f"  - Valid chunks: {len(VALID_CHUNKS)} (256-319)",
            f"  - OOT start date: {OOT_START_DATE}",
            f"  - Sampling ratio: {SAMPLING_RATIO*100}%",
        ]
    )
)


## 4. Load Target Data

Load train_val and test target data to extract customer_id and effective_date information.


In [None]:
def _load_chunked_csv(base_path, label):
    """Read every ``chunk=<id>/`` CSV folder under ``base_path`` into one DataFrame.

    Each chunk folder is read separately and the chunk id is attached as a
    literal ``chunk`` column. Chunks that cannot be read are skipped (best
    effort) but reported, because a silently missing chunk would weaken every
    check that follows.

    Args:
        base_path: Folder containing the ``chunk=<id>/`` sub-folders (with trailing slash).
        label: Short name used in progress messages.

    Returns:
        The union of all readable chunks, with an added ``chunk`` column.

    Raises:
        RuntimeError: If not a single chunk could be read.
    """
    combined = None
    loaded = 0
    skipped = {}
    for chunk_id in range(TOTAL_CHUNKS):
        chunk_path = f"{base_path}chunk={chunk_id}/"
        # `except Exception` (not a bare except) so a notebook interrupt still stops the loop
        try:
            df_chunk = spark.read.option("delimiter", ",") \
                .option("quoteMode", "NONE") \
                .option("header", "true") \
                .option("escape", "\\") \
                .csv(chunk_path)

            # The chunk id only exists in the folder name, so add it as a column
            df_chunk = df_chunk.withColumn("chunk", F.lit(chunk_id))

            combined = df_chunk if combined is None else combined.union(df_chunk)
            loaded += 1
        except Exception as exc:
            # Best effort: keep going, but remember what was skipped and why
            skipped[chunk_id] = type(exc).__name__

    if combined is None:
        raise RuntimeError(f"No {label} chunks could be loaded from {base_path}")

    print(f"  Loaded {loaded} of {TOTAL_CHUNKS} chunks for {label}")
    if skipped:
        print(f"  ✗ WARNING: skipped {len(skipped)} {label} chunk(s): {skipped}")
    return combined


# Load train_val data (in-time: train + validation customers)
print("Loading train_val data from all chunks...")
df_train_val = _load_chunked_csv(TARGET_TRAIN_VAL_PATH, "train_val")

# Load test data (OOT: 2024 data)
print("Loading test data from all chunks...")
df_test = _load_chunked_csv(TARGET_TEST_PATH, "test")

# Rename columns for consistency and tag each row with its source folder
df_train_val = df_train_val.withColumnRenamed("pid", "cust_id") \
    .withColumnRenamed("pred_dt", "efectv_dt") \
    .withColumn("data_split", F.lit("in-time"))

df_test = df_test.withColumnRenamed("pid", "cust_id") \
    .withColumnRenamed("pred_dt", "efectv_dt") \
    .withColumn("data_split", F.lit("OOT"))

# Keep only the columns the validation needs
cols_needed = ["cust_id", "efectv_dt", "chunk", "data_split"]
df_train_val = df_train_val.select(cols_needed)
df_test = df_test.select(cols_needed)

# Apply row-level sampling if requested (quick test runs)
if SAMPLING_RATIO < 1.0:
    print(f"Applying {SAMPLING_RATIO*100}% sampling...")
    df_train_val = df_train_val.sample(fraction=SAMPLING_RATIO, withReplacement=False, seed=42)
    df_test = df_test.sample(fraction=SAMPLING_RATIO, withReplacement=False, seed=42)

# Union all data; cached because every later section reuses it
df_all = df_train_val.union(df_test).cache()

print(f"✓ Data loaded successfully")
print(f"  - Total rows: {df_all.count():,}")
print(f"  - Columns: {df_all.columns}")
df_all.printSchema()


## 5. Chunk-Level Summary Statistics

Calculate basic statistics for each chunk to understand data distribution.


In [None]:
# Per-chunk statistics: distinct customers, row count and date range
print("Calculating chunk-level statistics...")

chunk_stats = df_all.groupBy("chunk").agg(
    F.countDistinct("cust_id").alias("unique_customers"),
    F.count("*").alias("total_samples"),
    F.min("efectv_dt").alias("min_date"),
    F.max("efectv_dt").alias("max_date")
).orderBy("chunk")

# Convert to pandas for display and save
chunk_stats_pd = chunk_stats.toPandas()
chunk_stats_pd['chunk'] = chunk_stats_pd['chunk'].astype(int)
# Label chunks from the configured train chunk list rather than a hard-coded
# boundary; any chunk that is not a train chunk is labelled as validation
chunk_stats_pd['chunk_type'] = np.where(
    chunk_stats_pd['chunk'].isin(TRAIN_CHUNKS), 'train_OOT', 'valid_OOT'
)

# Save to CSV on ADLS
save_pandas_to_csv_adls(chunk_stats_pd, OUTPUT_PATH + "chunk_summary_statistics.csv")

# Display summary
print("\n" + "="*80)
print("CHUNK SUMMARY STATISTICS")
print("="*80)
print(f"\nFirst 10 chunks:")
print(chunk_stats_pd.head(10))
print(f"\nLast 10 chunks:")
print(chunk_stats_pd.tail(10))

# Aggregate statistics across chunks of the same type
print("\n" + "="*80)
print("AGGREGATE STATISTICS BY CHUNK TYPE")
print("="*80)
summary_by_type = chunk_stats_pd.groupby('chunk_type').agg({
    'unique_customers': ['sum', 'mean', 'std'],
    'total_samples': ['sum', 'mean', 'std']
})
print(summary_by_type)


## 6. Critical Validation: Customer Uniqueness Across Chunks

**MOST IMPORTANT CHECK**: Verify that each customer_id appears in exactly ONE chunk folder.

This is the core isolation constraint - if violated, the modeling setup is broken.


In [None]:
print("="*80)
print("CRITICAL VALIDATION: CUSTOMER UNIQUENESS ACROSS CHUNKS")
print("="*80)
print("\nChecking if any customer_id appears in multiple chunk folders...")
print("(This should return ZERO violations)\n")

# One row per (customer, chunk) pair present in the data
customer_chunk_mapping = df_all.select("cust_id", "chunk").distinct()

# Per customer: number of distinct chunks and the chunk ids themselves
customer_chunk_count = customer_chunk_mapping.groupBy("cust_id").agg(
    F.countDistinct("chunk").alias("num_chunks"),
    F.collect_set("chunk").alias("chunks")
)

# Violations are customers that occur in more than one chunk
violations = customer_chunk_count.filter(F.col("num_chunks") > 1)
violation_count = violations.count()

if violation_count == 0:
    print("✓✓✓ VALIDATION PASSED ✓✓✓")
    print("   No customer appears in multiple chunks.")
    print("   Customer-level isolation is maintained correctly!")
else:
    print("✗✗✗ VALIDATION FAILED ✗✗✗")
    print(f"   Found {violation_count:,} customers appearing in multiple chunks!")
    print("   This violates the per-customer isolation constraint.")

    # Collect once and reuse for both the preview and the saved report
    violations_full = violations.toPandas()
    print("\nFirst 20 violations:")
    violations_pd = violations_full.head(20)
    print(violations_pd)

    # Save violations to file
    save_pandas_to_csv_adls(violations_full, OUTPUT_PATH + "CRITICAL_customer_chunk_violations.csv")
    print(f"\n   Full violation list saved")

# Summary statistics
print("\n" + "="*80)
print("CUSTOMER DISTRIBUTION SUMMARY")
print("="*80)
total_unique_customers = customer_chunk_count.count()
print(f"Total unique customers across all chunks: {total_unique_customers:,}")
print(f"Customers with violations (in multiple chunks): {violation_count:,}")
# Guard against an empty dataset so the rate does not raise ZeroDivisionError
if total_unique_customers > 0:
    print(f"Violation rate: {(violation_count/total_unique_customers*100):.4f}%")
else:
    print("Violation rate: n/a (no customers found)")


## 7. Split Confirmation: Train / Valid / OOT

Verify the split configuration:
- **Train**: Chunks 0-255, dates before 2024
- **Valid**: Chunks 256-319, dates before 2024  
- **OOT**: All chunks, dates in 2024+


In [None]:
print("="*80)
print("SPLIT CONFIRMATION: TRAIN / VALID / OOT")
print("="*80)

# First validation chunk id, taken from the configuration instead of a literal
VALID_CHUNK_START = min(VALID_CHUNKS)

# Add split labels based on chunk and date.
# NOTE(review): efectv_dt is read from CSV without schema inference, so this is
# presumably a string comparison; it is only correct if the dates are stored
# as ISO yyyy-MM-dd — confirm the format.
# Rows with a missing date match no branch and end up as "unknown".
df_with_split = df_all.withColumn(
    "split_label",
    F.when(
        (F.col("efectv_dt") >= OOT_START_DATE), "OOT"
    ).when(
        (F.col("chunk") < VALID_CHUNK_START) & (F.col("efectv_dt") < OOT_START_DATE), "train"
    ).when(
        (F.col("chunk") >= VALID_CHUNK_START) & (F.col("efectv_dt") < OOT_START_DATE), "valid"
    ).otherwise("unknown")
)

# Calculate split statistics
split_stats = df_with_split.groupBy("split_label").agg(
    F.countDistinct("cust_id").alias("unique_customers"),
    F.count("*").alias("total_samples"),
    F.countDistinct("chunk").alias("num_chunks"),
    F.min("efectv_dt").alias("min_date"),
    F.max("efectv_dt").alias("max_date")
).orderBy("split_label")

split_stats_pd = split_stats.toPandas()
save_pandas_to_csv_adls(split_stats_pd, OUTPUT_PATH + "split_confirmation_statistics.csv")

print("\nSplit Statistics:")
print(split_stats_pd.to_string(index=False))

# Check for unknown splits (should be zero)
unknown_count = split_stats_pd[split_stats_pd['split_label'] == 'unknown']['total_samples'].sum()
if unknown_count > 0:
    print(f"\n✗ WARNING: Found {unknown_count:,} samples with 'unknown' split label!")
else:
    print("\n✓ All samples correctly assigned to train/valid/OOT splits")

# Cross-check the source folder (data_split) against the date-derived label:
# train_val rows should not be dated in the OOT period, and test rows should
# not be dated before it
folder_mismatch_count = df_with_split.filter(
    ((F.col("data_split") == "in-time") & (F.col("split_label") == "OOT"))
    | ((F.col("data_split") == "OOT") & F.col("split_label").isin(["train", "valid"]))
).count()
if folder_mismatch_count > 0:
    print(f"✗ WARNING: Found {folder_mismatch_count:,} samples whose source folder "
          f"(train_val/test) disagrees with their date-based split!")
else:
    print("✓ Source folders (train_val/test) agree with the date-based split")

# Calculate customer overlap between in-time and OOT
print("\n" + "="*80)
print("CUSTOMER OVERLAP: IN-TIME vs OOT")
print("="*80)

intime_customers = df_with_split.filter(
    F.col("split_label").isin(["train", "valid"])
).select("cust_id").distinct()

oot_customers = df_with_split.filter(
    F.col("split_label") == "OOT"
).select("cust_id").distinct()

# Count overlap
overlap = intime_customers.join(oot_customers, on="cust_id", how="inner").count()
total_intime = intime_customers.count()
total_oot = oot_customers.count()

print(f"In-time unique customers: {total_intime:,}")
print(f"OOT unique customers: {total_oot:,}")
print(f"Customers appearing in both: {overlap:,}")
# Guard against an empty in-time set so the percentage does not raise ZeroDivisionError
if total_intime > 0:
    print(f"Overlap percentage: {(overlap/total_intime*100):.2f}% of in-time customers")
else:
    print("Overlap percentage: n/a (no in-time customers found)")
print("\nNote: Overlap is expected since same customers appear in both time periods.")


## 8. Temporal Distribution by Split

Analyze sample distribution over time for each split.


In [None]:
print("Analyzing temporal distribution by split...")

# Count samples per effective date and split
temporal_dist = df_with_split.groupBy("efectv_dt", "split_label").agg(
    F.count("*").alias("total_samples")
).orderBy("efectv_dt", "split_label")

temporal_dist_pd = temporal_dist.toPandas()
save_pandas_to_csv_adls(temporal_dist_pd, OUTPUT_PATH + "temporal_distribution_by_split.csv")

# Draw on the same figure object that is saved below, so the stored PNG
# contains the plot rather than an empty canvas
fig, ax = plt.subplots(figsize=(16, 6))
for split in ['train', 'valid', 'OOT']:
    split_data = temporal_dist_pd[temporal_dist_pd['split_label'] == split]
    if len(split_data) > 0:
        ax.plot(split_data['efectv_dt'], split_data['total_samples'],
                marker='o', label=split, linewidth=2)

ax.set_xlabel('Effective Date', fontsize=12)
ax.set_ylabel('Total Samples', fontsize=12)
ax.set_title('Sample Count Over Time by Split', fontsize=14, fontweight='bold')
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)
ax.tick_params(axis='x', labelrotation=45)

fig.tight_layout()
save_plot_to_adls(fig, OUTPUT_PATH + "temporal_distribution_plots.png", dpi=150)
plt.show()


## 9. Summary Report

Generate final summary of all validation checks.


In [None]:
print("\n" + "="*80)
print("FINAL VALIDATION SUMMARY REPORT")
print("="*80)

# Collect all key metrics.
# total_unique_customers and violation_count come from the uniqueness check;
# reusing them avoids re-running the distinct/groupBy job over the full data.
summary_report = {
    "validation_date": pd.Timestamp.now().strftime("%Y-%m-%d %H:%M:%S"),
    "sampling_ratio": SAMPLING_RATIO,
    "total_chunks": TOTAL_CHUNKS,
    "train_chunks": len(TRAIN_CHUNKS),
    "valid_chunks": len(VALID_CHUNKS),
    "total_samples": int(df_all.count()),
    "total_unique_customers": int(total_unique_customers),
    "customer_chunk_violations": int(violation_count),
    "validation_passed": violation_count == 0,
}

# Add split statistics: one group of columns per split label
for _, row in split_stats_pd.iterrows():
    split = row['split_label']
    summary_report[f"{split}_unique_customers"] = int(row['unique_customers'])
    summary_report[f"{split}_total_samples"] = int(row['total_samples'])
    summary_report[f"{split}_num_chunks"] = int(row['num_chunks'])

# Save summary as a single-row CSV
summary_df = pd.DataFrame([summary_report])
save_pandas_to_csv_adls(summary_df, OUTPUT_PATH + "validation_summary_report.csv")

# Print formatted summary
print("\nValidation Results:")
print(f"  Timestamp: {summary_report['validation_date']}")
print(f"  Sampling Ratio: {summary_report['sampling_ratio']*100}%")
print(f"\nData Overview:")
print(f"  Total Samples: {summary_report['total_samples']:,}")
print(f"  Total Unique Customers: {summary_report['total_unique_customers']:,}")
print(f"\nChunk Configuration:")
print(f"  Total Chunks: {summary_report['total_chunks']}")
print(f"  Train Chunks: {summary_report['train_chunks']} (0-255)")
print(f"  Valid Chunks: {summary_report['valid_chunks']} (256-319)")
print(f"\nCritical Validation:")
print(f"  Customer Chunk Violations: {summary_report['customer_chunk_violations']:,}")
print(f"  Validation Status: {'✓ PASSED' if summary_report['validation_passed'] else '✗ FAILED'}")

# Split lines are printed only when a train split exists; absent splits show 0
if 'train_unique_customers' in summary_report:
    print(f"\nSplit Statistics:")
    print(f"  Train - Customers: {summary_report.get('train_unique_customers', 0):,}, "
          f"Samples: {summary_report.get('train_total_samples', 0):,}")
    print(f"  Valid - Customers: {summary_report.get('valid_unique_customers', 0):,}, "
          f"Samples: {summary_report.get('valid_total_samples', 0):,}")
    print(f"  OOT   - Customers: {summary_report.get('OOT_unique_customers', 0):,}, "
          f"Samples: {summary_report.get('OOT_total_samples', 0):,}")

print(f"\n✓ Summary report saved to {OUTPUT_PATH}validation_summary_report.csv")
print("\n" + "="*80)
print("VALIDATION COMPLETE")
print("="*80)


## 10. Cleanup

Unpersist cached dataframes to free up memory.


In [None]:
# Release the cached combined DataFrame (non-blocking, which is the default)
df_all.unpersist(blocking=False)
print("✓ Memory cleaned up")
