# Data Quality Exploration for CRISP Pipeline

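This notebook helps you understand and assess the quality of your OMOP CDM data before running the CRISP pipeline. It expects one CSV file per OMOP table, named after the table (e.g. `PERSON.csv`), in `../data`.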

**Purpose:**
- Examine data completeness and quality
- Identify potential issues before processing
- Generate quality metrics and visualizations

**Expected runtime:** ~5 minutes for 1000 patients

## Part 1: Setup and Data Loading

In [None]:
# Import required libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import warnings

# Silence only FutureWarning noise from library deprecations. A blanket
# 'ignore' would also hide pandas' DtypeWarning (columns with mixed types in a
# CSV) and similar messages, which are themselves data-quality signals this
# notebook should surface.
warnings.filterwarnings('ignore', category=FutureWarning)

# Set display options: show every column, cap printed rows at 100
pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', 100)

# Configure visualization
plt.style.use('default')
sns.set_palette('husl')

print("Libraries imported successfully")

In [None]:
# Load OMOP data tables: one CSV per table, named after the table (e.g. PERSON.csv)
data_dir = Path('../data')

# Define required tables
required_tables = [
    'PERSON', 'VISIT_OCCURRENCE', 'CONDITION_OCCURRENCE',
    'PROCEDURE_OCCURRENCE', 'DRUG_EXPOSURE', 'MEASUREMENT',
    'OBSERVATION'
]

# Load available tables; anything absent or unreadable is recorded in missing_tables
tables = {}
missing_tables = []

for table_name in required_tables:
    file_path = data_dir / f"{table_name}.csv"
    if not file_path.exists():
        missing_tables.append(table_name)
        print(f"⚠️  {table_name} not found")
        continue

    print(f"Loading {table_name}...", end=' ')
    try:
        df = pd.read_csv(file_path)
    except (pd.errors.EmptyDataError, pd.errors.ParserError) as exc:
        # A zero-byte or malformed file is reported and skipped instead of
        # aborting every remaining load.
        missing_tables.append(table_name)
        print(f"⚠️  could not be parsed ({type(exc).__name__})")
        continue

    # All later checks look up lowercase field names (person_id, ...);
    # normalize headers so exports with upper-case or padded names still match.
    df.columns = df.columns.str.strip().str.lower()
    tables[table_name] = df
    print(f"✓ ({len(df):,} rows)")

if missing_tables:
    print(f"\n⚠️  Missing {len(missing_tables)} tables: {missing_tables}")
else:
    print(f"\n✅ All {len(required_tables)} required tables loaded successfully")

## Part 2: Basic Quality Checks

In [None]:
# Table sizes and basic statistics (one row per loaded table)
stats_columns = ['Table', 'Rows', 'Columns', 'Memory (MB)', 'Null %']
table_stats = []

for name, df in tables.items():
    stats = {
        'Table': name,
        'Rows': len(df),
        'Columns': len(df.columns),
        'Memory (MB)': df.memory_usage(deep=True).sum() / 1024**2,
        # Share of all cells that are null; df.size == rows * columns
        'Null %': (df.isnull().sum().sum() / df.size * 100) if df.size > 0 else 0
    }
    table_stats.append(stats)

# Explicit column names keep stats_df well-formed (so the sort and the totals
# below still work) even when no tables were loaded at all.
stats_df = pd.DataFrame(table_stats, columns=stats_columns).round(2)
stats_df = stats_df.sort_values('Rows', ascending=False)

print("📊 Table Statistics Summary:")
print("="*60)
print(stats_df.to_string(index=False))
print(f"\nTotal rows across all tables: {stats_df['Rows'].sum():,}")
print(f"Total memory usage: {stats_df['Memory (MB)'].sum():.1f} MB")

In [None]:
# Check for missing values in critical columns
print("🔍 Missing Values in Critical Columns:")
print("="*60)

# Columns each checked table is expected to provide
critical_columns = {
    'PERSON': ['person_id', 'gender_concept_id', 'year_of_birth'],
    'VISIT_OCCURRENCE': ['visit_occurrence_id', 'person_id', 'visit_start_date'],
    'MEASUREMENT': ['measurement_id', 'person_id', 'measurement_concept_id'],
    'CONDITION_OCCURRENCE': ['condition_occurrence_id', 'person_id', 'condition_concept_id']
}

for table_name, columns in critical_columns.items():
    if table_name not in tables:
        continue
    df = tables[table_name]
    print(f"\n{table_name}:")
    for col in columns:
        # A critical column that is absent altogether is worse than one with
        # nulls, so report it rather than skipping it silently.
        if col not in df.columns:
            print(f"  ❌ {col}: column not present")
            continue
        null_count = df[col].isnull().sum()
        null_pct = (null_count / len(df) * 100) if len(df) > 0 else 0
        if null_count > 0:
            print(f"  ⚠️  {col}: {null_count:,} nulls ({null_pct:.1f}%)")
        else:
            print(f"  ✓ {col}: no nulls")

In [None]:
# Check for duplicate (and null) person_ids in PERSON
if 'PERSON' in tables:
    person_df = tables['PERSON']

    print("👥 Person ID Analysis:")
    print("="*60)

    if 'person_id' not in person_df.columns:
        print("⚠️  PERSON table has no person_id column")
    else:
        id_series = person_df['person_id']
        total_persons = len(person_df)
        unique_persons = id_series.nunique()  # nunique() ignores nulls
        null_ids = int(id_series.isnull().sum())
        # Count repeat occurrences among non-null ids only. The old
        # "total - nunique" arithmetic counted every null row as a duplicate.
        duplicates = int(id_series.dropna().duplicated().sum())

        print(f"Total person records: {total_persons:,}")
        print(f"Unique person IDs: {unique_persons:,}")
        if null_ids > 0:
            print(f"⚠️  Found {null_ids:,} records with a null person_id")

        if duplicates > 0:
            print(f"⚠️  Found {duplicates} duplicate person_id entries")
            # value_counts() drops nulls, so this lists genuinely repeated ids
            id_counts = id_series.value_counts()
            dup_ids = id_counts[id_counts > 1].head()
            print("\nTop duplicate person_ids:")
            print(dup_ids)
        else:
            print("✅ No duplicate person_ids found")

In [None]:
# Date range analysis
print("📅 Temporal Coverage Analysis:")
print("="*60)

date_columns = {
    'VISIT_OCCURRENCE': ['visit_start_date', 'visit_end_date'],
    'CONDITION_OCCURRENCE': ['condition_start_date'],
    'MEASUREMENT': ['measurement_date'],
    'DRUG_EXPOSURE': ['drug_exposure_start_date']
}

for table_name, cols in date_columns.items():
    if table_name not in tables:
        continue
    df = tables[table_name]
    print(f"\n{table_name}:")
    for col in cols:
        if col not in df.columns:
            continue
        raw_values = df[col]
        # Values that cannot be parsed become NaT
        dates = pd.to_datetime(raw_values, errors='coerce')
        valid_dates = dates.dropna()
        # Keep empty cells apart from values that are present but unparseable
        missing_count = int(raw_values.isnull().sum())
        invalid_count = int((dates.isnull() & raw_values.notnull()).sum())

        print(f"  {col}:")
        if len(valid_dates) > 0:
            min_date = valid_dates.min()
            max_date = valid_dates.max()
            span_years = (max_date - min_date).days / 365.25
            print(f"    Range: {min_date.date()} to {max_date.date()} ({span_years:.1f} years)")
        else:
            # Still report the column when nothing parses instead of omitting it
            print("    Range: no parseable dates")
        print(f"    Missing dates: {missing_count:,}")
        print(f"    Invalid dates: {invalid_count:,}")

In [None]:
# Foreign key validation - Check if all persons in other tables exist in PERSON
if 'PERSON' in tables and 'person_id' in tables['PERSON'].columns:
    # Nulls are excluded: NaN never compares equal, so a null would otherwise
    # be reported as an "orphaned" id.
    person_ids = set(tables['PERSON']['person_id'].dropna().unique())

    print("🔗 Foreign Key Validation (person_id):")
    print("="*60)

    for table_name, df in tables.items():
        if table_name == 'PERSON' or 'person_id' not in df.columns:
            continue
        fk_values = df['person_id']
        table_person_ids = set(fk_values.dropna().unique())
        orphaned = table_person_ids - person_ids
        null_fk_count = int(fk_values.isnull().sum())

        if orphaned:
            orphan_rows = int(fk_values.isin(orphaned).sum())
            print(f"⚠️  {table_name}: {len(orphaned)} person_ids not found in PERSON table "
                  f"({orphan_rows:,} records)")
        else:
            print(f"✓ {table_name}: All person_ids valid")
        if null_fk_count > 0:
            print(f"⚠️  {table_name}: {null_fk_count:,} records with a null person_id")

## Part 3: Data Quality Visualizations

In [None]:
# Missing data heatmap
# Percentage of nulls per column, kept for every table that has at least one null.
missing_data = {}
for name, df in tables.items():
    if len(df) > 0:
        missing_pct = df.isnull().mean() * 100
        # Filter on unrounded values: rounding first would hide columns with
        # < 0.05% nulls and could wrongly report "no missing data".
        if (missing_pct > 0).any():
            missing_data[name] = missing_pct

if missing_data:
    # Columns that have missing values in at least one table
    all_cols = sorted({col for pct in missing_data.values() for col in pct[pct > 0].index})
    # NaN (drawn as a blank cell) marks a column the table does not have, so it
    # is not confused with a column that is present and fully populated (0.0).
    matrix = pd.DataFrame(
        [[pct.get(col, np.nan) for col in all_cols] for pct in missing_data.values()],
        index=list(missing_data),
        columns=all_cols,
    )

    # Create the figure only when there is something to draw
    fig, ax = plt.subplots(figsize=(12, 6))
    sns.heatmap(matrix,
                annot=True, fmt='.1f',
                cmap='YlOrRd',
                cbar_kws={'label': 'Missing %'},
                ax=ax)
    ax.set_title('Missing Data Heatmap (% missing by column)', fontsize=14, fontweight='bold')
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
    plt.tight_layout()
    plt.show()
else:
    print("✅ No missing data found in any table!")

In [None]:
# Record count bar chart and memory-share pie chart, both driven by stats_df
fig, (count_ax, mem_ax) = plt.subplots(1, 2, figsize=(14, 5))

# Ascending order places the largest table at the top of a horizontal bar chart
by_rows = stats_df.sort_values('Rows', ascending=True)
row_counts = by_rows['Rows']

palette = plt.cm.viridis(np.linspace(0.3, 0.9, len(by_rows)))
bar_container = count_ax.barh(by_rows['Table'], row_counts, color=palette)
count_ax.set_xlabel('Number of Records', fontsize=12)
count_ax.set_title('Record Count by Table', fontsize=14, fontweight='bold')
count_ax.grid(axis='x', alpha=0.3)

# Annotate each bar with its exact count, placed at the bar's end
for rect, n_rows in zip(bar_container, row_counts):
    y_mid = rect.get_y() + rect.get_height() / 2
    count_ax.text(n_rows, y_mid, f'{n_rows:,}',
                  ha='left', va='center', fontsize=10)

# Tables with zero reported memory are left out of the pie
nonzero_mem = by_rows.loc[by_rows['Memory (MB)'] > 0]
if not nonzero_mem.empty:
    mem_ax.pie(nonzero_mem['Memory (MB)'],
               labels=nonzero_mem['Table'],
               autopct='%1.1f%%',
               startangle=90)
    mem_ax.set_title('Memory Usage Distribution', fontsize=14, fontweight='bold')

plt.tight_layout()
plt.show()

In [None]:
# Timeline visualization - data coverage over time (visits per calendar month)
visit_table = tables.get('VISIT_OCCURRENCE')
if visit_table is not None and len(visit_table) > 0 and 'visit_start_date' in visit_table.columns:
    # Parse only the needed column; unparseable values become NaT and are dropped
    visit_dates = pd.to_datetime(visit_table['visit_start_date'], errors='coerce').dropna()

    if len(visit_dates) > 0:
        # Group by month
        monthly_visits = visit_dates.groupby(visit_dates.dt.to_period('M')).size()
        # groupby yields only months that have visits. Reindex over the full span
        # so coverage gaps show up as zeros instead of being skipped, which would
        # also undercount the months and inflate the monthly average.
        full_span = pd.period_range(monthly_visits.index.min(), monthly_visits.index.max(), freq='M')
        monthly_visits = monthly_visits.reindex(full_span, fill_value=0)
        empty_months = int((monthly_visits == 0).sum())

        # Plot
        fig, ax = plt.subplots(figsize=(14, 5))
        monthly_visits.plot(kind='line', ax=ax, linewidth=2, marker='o', markersize=4)
        ax.set_xlabel('Time Period', fontsize=12)
        ax.set_ylabel('Number of Visits', fontsize=12)
        ax.set_title('Temporal Data Coverage (Visits Over Time)', fontsize=14, fontweight='bold')
        ax.grid(True, alpha=0.3)

        # Add statistics
        avg_visits = monthly_visits.mean()
        ax.axhline(y=avg_visits, color='r', linestyle='--', alpha=0.5, label=f'Average: {avg_visits:.0f}')
        ax.legend()

        plt.xticks(rotation=45)
        plt.tight_layout()
        plt.show()

        # Print summary
        print("📊 Temporal Coverage Summary:")
        print(f"  • Date range: {monthly_visits.index.min()} to {monthly_visits.index.max()}")
        print(f"  • Total months: {len(monthly_visits)}")
        print(f"  • Months with no visits: {empty_months}")
        print(f"  • Average visits/month: {avg_visits:.0f}")
        print(f"  • Peak month: {monthly_visits.idxmax()} ({monthly_visits.max():,} visits)")

## Part 4: Quality Report Summary

In [None]:
# Generate comprehensive quality report
print("📋 DATA QUALITY REPORT SUMMARY")
print("="*60)

# Unique patients; guard the column too so a PERSON file without person_id
# does not crash the report.
person_table = tables.get('PERSON')
if person_table is not None and 'person_id' in person_table.columns:
    unique_patients = person_table['person_id'].nunique()
else:
    unique_patients = 0

# Collect all quality metrics
quality_metrics = {
    'Total Tables Loaded': len(tables),
    'Missing Tables': len(missing_tables),
    'Total Records': stats_df['Rows'].sum(),
    'Total Unique Patients': unique_patients,
    'Data Size (MB)': stats_df['Memory (MB)'].sum(),
    'Tables with Missing Data': sum(1 for df in tables.values() if df.isnull().any().any()),
}

# Display metrics: floats with one decimal, everything else with thousands separators
for metric, value in quality_metrics.items():
    if isinstance(value, float):
        print(f"{metric:.<30} {value:,.1f}")
    else:
        print(f"{metric:.<30} {value:,}")

print("\n" + "="*60)

In [None]:
# Identify critical issues and provide recommendations
print("⚠️  CRITICAL ISSUES & RECOMMENDATIONS")
print("="*60)

# Parallel lists: recommendations[i] addresses issues[i]
issues = []
recommendations = []

# Check for critical issues
if missing_tables:
    issues.append(f"Missing {len(missing_tables)} required tables: {', '.join(missing_tables)}")
    recommendations.append("Ensure all required OMOP tables are present before running pipeline")

if 'PERSON' in tables:
    person_df = tables['PERSON']
    if len(person_df) == 0:
        issues.append("PERSON table is empty")
        recommendations.append("Load patient data before proceeding")
    elif 'person_id' not in person_df.columns:
        issues.append("PERSON table has no person_id column")
        recommendations.append("Check the column headers of the PERSON file")
    else:
        pid_values = person_df['person_id']
        null_pid_count = int(pid_values.isnull().sum())
        if null_pid_count > 0:
            issues.append(f"{null_pid_count:,} PERSON records have a null person_id")
            recommendations.append("Remove or repair PERSON records without a person_id")
        # Nulls are handled above; duplicated() on the raw column would also
        # treat multiple nulls as repeats of each other.
        if pid_values.dropna().duplicated().any():
            issues.append("Duplicate person_ids found")
            recommendations.append("Review and deduplicate PERSON table")

# Check for excessive missing data (share of all cells that are null)
for name, df in tables.items():
    null_pct = (df.isnull().sum().sum() / (len(df) * len(df.columns)) * 100) if len(df) > 0 else 0
    if null_pct > 50:
        issues.append(f"{name} has {null_pct:.1f}% missing data")
        recommendations.append(f"Review data quality for {name} table")

# Display results
if issues:
    print("\n🔴 Issues Found:")
    for i, issue in enumerate(issues, 1):
        print(f"  {i}. {issue}")

    print("\n💡 Recommendations:")
    for i, rec in enumerate(recommendations, 1):
        print(f"  {i}. {rec}")
else:
    print("\n✅ No critical issues found!")
    print("\nYour data appears ready for the CRISP pipeline.")
    print("\nNext steps:")
    print("  1. Run data validation: python data_preparation/validate_data.py")
    print("  2. Execute pipeline: python pipeline_modules/run_all_module.py")

print("\n" + "="*60)
print("Report generated successfully")