# Master Quality Control Notebook

This notebook performs comprehensive quality checks on all pipeline outputs.

## QC Checks:
1. Row counts match the expected cohort size in every dataset
2. Missingness analysis
3. Date consistency and logical ordering
4. Duplicate ID checks
5. Foreign key integrity
6. Summary dashboard with PASS/FAIL status

In [None]:
# Import libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import json
from datetime import datetime
import warnings
# Suppress all warnings for the rest of the session.
# NOTE(review): a blanket filter also hides pandas FutureWarning /
# SettingWithCopyWarning, which can point at real data problems — consider
# narrowing it to specific categories.
warnings.filterwarnings('ignore')

# Plot styling shared by every figure in this notebook
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette('husl')

# Accumulator for all QC checks: each check cell records its own entry under
# 'checks' and sets 'overall_status' to 'FAIL' when it fails.
qc_results = dict(
    timestamp=datetime.now().isoformat(),
    checks=dict(),
    overall_status='PASS',
)

## 1. Load All Datasets and Check Row Counts

In [None]:
# Pipeline outputs to validate. Each one must exist and contain exactly
# `expected_rows` rows.
# NOTE(review): the foreign-key cell treats ps_weighted as possibly filtered;
# if so, it should not be held to `expected_rows` here — confirm with the blueprint.
datasets = {
    'cohort': '../data_derived/cohort.parquet',
    'exposure': '../data_derived/exposure.parquet',
    'mediator': '../data_derived/mediator.parquet',
    'outcomes': '../data_derived/outcomes.parquet',
    'confounders': '../data_derived/confounders.parquet',
    'patient_master': '../data_derived/patient_master.parquet',
    'ps_weighted': '../data_derived/ps_weighted.parquet'
}

# Load datasets and check row counts
loaded_data = {}
row_counts = {}
missing_datasets = []   # outputs that could not be found on disk
expected_rows = 250025  # From blueprint

for name, path in datasets.items():
    try:
        df = pd.read_parquet(path)
    except FileNotFoundError:
        print(f"WARNING: {name} not found at {path}")
        row_counts[name] = 0
        missing_datasets.append(name)
        continue
    loaded_data[name] = df
    row_counts[name] = len(df)
    print(f"{name}: {len(df):,} rows")

# A missing output is a failure in its own right. Every dataset that did load,
# including an empty one, must match the expected row count exactly.
row_count_check = not missing_datasets and all(
    row_counts[name] == expected_rows for name in loaded_data
)
qc_results['checks']['row_count_consistency'] = {
    'status': 'PASS' if row_count_check else 'FAIL',
    'expected': expected_rows,
    'actual': row_counts,
    'missing_datasets': missing_datasets,
    'consistent': row_count_check
}

if not row_count_check:
    qc_results['overall_status'] = 'FAIL'
    print("\n❌ FAIL: Row counts are inconsistent!")
    if missing_datasets:
        print(f"  Missing datasets: {', '.join(missing_datasets)}")
else:
    print("\n✅ PASS: All datasets have consistent row counts")

## 2. Missingness Analysis

In [None]:
# Missingness analysis of the patient master table: a heatmap of the most
# incomplete columns plus a PASS/FAIL gate on key analysis variables.
KEY_MISSING_THRESHOLD_PCT = 10  # key variables must be strictly below this % missing
HEATMAP_TOP_N = 20              # number of most-missing columns shown in the heatmap
HEATMAP_MAX_ROWS = 5000         # thin rows for plotting; one cell per patient is very slow

if 'patient_master' in loaded_data:
    master_df = loaded_data['patient_master']

    # Percentage missing per column, worst first; keep only columns with any gaps
    missing_pct = (master_df.isnull().sum() / len(master_df) * 100).sort_values(ascending=False)
    missing_pct = missing_pct[missing_pct > 0]

    # Only draw the heatmap when there is something to show, so no empty figure is rendered
    cols_with_missing = missing_pct.index[:HEATMAP_TOP_N]
    if len(cols_with_missing) > 0:
        missing_matrix = master_df[cols_with_missing].isnull()
        # Systematic thinning keeps the row order and the real index labels on the x-axis
        row_step = max(1, len(missing_matrix) // HEATMAP_MAX_ROWS)
        fig, ax = plt.subplots(figsize=(12, 8))
        sns.heatmap(missing_matrix.iloc[::row_step].T, cbar=True, yticklabels=True,
                    cmap='RdYlBu', ax=ax)
        ax.set_title(f'Missingness Heatmap (Top {HEATMAP_TOP_N} Variables)')
        ax.set_xlabel('Patient Index')
        plt.tight_layout()
        plt.show()

    # Key variables must be below the threshold. A key variable that is absent
    # from the table counts as 100% missing instead of silently as 0%.
    key_vars = ['Patient_ID', 'ssd_flag', 'age', 'sex_M']
    absent_key_vars = [var for var in key_vars if var not in master_df.columns]
    key_missing = {
        var: 100.0 if var in absent_key_vars else float(missing_pct.get(var, 0.0))
        for var in key_vars
    }

    missingness_check = all(pct < KEY_MISSING_THRESHOLD_PCT for pct in key_missing.values())
    qc_results['checks']['missingness'] = {
        'status': 'PASS' if missingness_check else 'FAIL',
        'key_variables': key_missing,
        'absent_key_variables': absent_key_vars,
        'total_columns_with_missing': len(missing_pct),
        'max_missing_pct': float(missing_pct.max()) if len(missing_pct) > 0 else 0
    }

    if not missingness_check:
        qc_results['overall_status'] = 'FAIL'

    print(f"\nMissingness Summary:")
    print(f"Columns with missing data: {len(missing_pct)}")
    print(f"Max missing percentage: {missing_pct.max():.1f}%" if len(missing_pct) > 0 else "No missing data")
    print(f"\nKey variable missingness:")
    for var, pct in key_missing.items():
        print(f"  {var}: {pct:.1f}%")
    if absent_key_vars:
        print(f"  Absent key variables (counted as 100% missing): {absent_key_vars}")
else:
    # Record the skipped check explicitly so the overall status cannot pass vacuously
    qc_results['checks']['missingness'] = {
        'status': 'FAIL',
        'issues': ['patient_master not loaded; missingness check could not run'],
    }
    qc_results['overall_status'] = 'FAIL'
    print("\n❌ FAIL: patient_master not loaded; missingness check could not run")

## 3. Date Consistency Checks

In [None]:
# Date consistency checks on the patient master table:
#   * columns whose name contains "date" should parse as datetimes,
#   * no date may lie in the future,
#   * death_date may not precede index_date,
#   * at least MIN_STUDY_PERIOD_PCT % of index dates fall within STUDY_YEARS.
STUDY_YEARS = [2018, 2019, 2020]
MIN_STUDY_PERIOD_PCT = 95

date_issues = []
# Reset explicitly so a value left over from a previous run of this cell is never reported
study_period_pct = None

if 'patient_master' in loaded_data:
    df = loaded_data['patient_master']

    # Identify date columns
    date_cols = [col for col in df.columns if 'date' in str(col).lower()]
    print("Date columns found:", date_cols)

    # Parse into a separate mapping instead of assigning back into `df`. `df`
    # is the same object as loaded_data['patient_master'], which later cells reuse.
    parsed_dates = {}
    for col in date_cols:
        values = df[col]
        # Accept any datetime resolution or timezone. Parquet often yields
        # datetime64[us]/[ms], which an exact 'datetime64[ns]' test would skip.
        if not pd.api.types.is_datetime64_any_dtype(values):
            try:
                values = pd.to_datetime(values)
            except (ValueError, TypeError, OverflowError):
                values = None
        if values is None or not pd.api.types.is_datetime64_any_dtype(values):
            print(f"Warning: Could not convert {col} to datetime")
            continue
        parsed_dates[col] = values

    # The ordering and study-period checks depend on these columns being usable
    for col in ('index_date', 'death_date'):
        if col in df.columns and col not in parsed_dates:
            date_issues.append(f"{col} could not be parsed as dates")

    # Check for future dates. Match each column's timezone so tz-aware columns compare without raising.
    for col, values in parsed_dates.items():
        now = pd.Timestamp.now(tz=values.dt.tz)
        n_future = int((values > now).sum())
        if n_future > 0:
            date_issues.append(f"{col} has {n_future} future dates")

    # Check logical ordering
    if 'index_date' in parsed_dates and 'death_date' in parsed_dates:
        n_illogical = int((parsed_dates['death_date'] < parsed_dates['index_date']).sum())
        if n_illogical > 0:
            date_issues.append(f"Death before index date: {n_illogical} cases")

    # Check study period consistency (2018-2020)
    if 'index_date' in parsed_dates:
        year_counts = parsed_dates['index_date'].dt.year.value_counts().sort_index()

        if not year_counts.empty:
            plt.figure(figsize=(8, 5))
            year_counts.plot(kind='bar')
            plt.title('Patient Distribution by Index Year')
            plt.xlabel('Year')
            plt.ylabel('Number of Patients')
            plt.axhline(y=expected_rows / len(STUDY_YEARS), color='r', linestyle='--',
                        alpha=0.5, label='Expected if uniform')
            plt.legend()
            plt.tight_layout()
            plt.show()

        # Share of all patients (NaT index dates count against coverage)
        if len(df) > 0:
            in_window = year_counts[year_counts.index.isin(STUDY_YEARS)].sum()
            study_period_pct = float(in_window / len(df) * 100)
            if study_period_pct < MIN_STUDY_PERIOD_PCT:
                date_issues.append(f"Only {study_period_pct:.1f}% of patients in 2018-2020")
else:
    # Record the skipped check so the overall status cannot pass vacuously
    date_issues.append("patient_master not loaded; date checks could not run")

date_check = len(date_issues) == 0
qc_results['checks']['date_consistency'] = {
    'status': 'PASS' if date_check else 'FAIL',
    'issues': date_issues,
    'study_period_coverage': study_period_pct
}

if not date_check:
    qc_results['overall_status'] = 'FAIL'
    print("\n❌ Date issues found:")
    for issue in date_issues:
        print(f"  - {issue}")
else:
    print("\n✅ PASS: All date checks passed")

## 4. Duplicate ID Checks

In [None]:
# Duplicate-ID check: any dataset that carries a Patient_ID column must have at
# most one row per patient.
duplicate_issues = []

for dataset_name, frame in loaded_data.items():
    if 'Patient_ID' not in frame.columns:
        continue
    n_dupes = frame['Patient_ID'].duplicated().sum()
    if n_dupes > 0:
        duplicate_issues.append(f"{dataset_name}: {n_dupes} duplicate Patient_IDs")
        print(f"WARNING: {dataset_name} has {n_dupes} duplicate Patient_IDs")

duplicate_check = not duplicate_issues
qc_results['checks']['duplicate_ids'] = dict(
    status='PASS' if duplicate_check else 'FAIL',
    issues=duplicate_issues,
)

if duplicate_check:
    print("\n✅ PASS: No duplicate Patient_IDs found")
else:
    qc_results['overall_status'] = 'FAIL'
    print("\n❌ FAIL: Duplicate IDs found")

## 5. Foreign Key Integrity

In [None]:
# Foreign-key integrity: every Patient_ID in a derived dataset must exist in the
# cohort table. Cohort patients absent from a dataset are only reported as INFO.
integrity_issues = []

if 'cohort' not in loaded_data:
    # Nothing to check against; report it rather than passing vacuously
    integrity_issues.append("cohort not loaded; Patient_ID linkage could not be verified")
elif 'Patient_ID' not in loaded_data['cohort'].columns:
    integrity_issues.append("cohort has no Patient_ID column; Patient_ID linkage could not be verified")
else:
    master_ids = set(loaded_data['cohort']['Patient_ID'])

    for name, df in loaded_data.items():
        if name == 'cohort' or 'Patient_ID' not in df.columns:
            continue
        dataset_ids = set(df['Patient_ID'])

        # IDs not in the cohort are orphans, which is a hard failure
        orphan_ids = dataset_ids - master_ids
        if orphan_ids:
            integrity_issues.append(f"{name}: {len(orphan_ids)} Patient_IDs not in cohort")

        # Cohort patients missing from the dataset are informational only;
        # ps_weighted might be a filtered subset, so it is not reported.
        missing_ids = master_ids - dataset_ids
        if missing_ids and name != 'ps_weighted':
            print(f"INFO: {name} missing {len(missing_ids)} patients from cohort")

integrity_check = len(integrity_issues) == 0
qc_results['checks']['foreign_key_integrity'] = {
    'status': 'PASS' if integrity_check else 'FAIL',
    'issues': integrity_issues
}

if not integrity_check:
    qc_results['overall_status'] = 'FAIL'
    print("\n❌ FAIL: Foreign key integrity issues")
    for issue in integrity_issues:
        print(f"  - {issue}")
else:
    print("\n✅ PASS: All Patient_IDs properly linked")

## 6. Summary Dashboard

In [None]:
# Summary dashboard: derive the overall verdict from the recorded checks, print
# a per-check table, and persist everything to JSON for downstream automation.

# Recompute the overall status from the individual checks rather than trusting
# the running flag. Check cells only ever set it to FAIL, so re-running a single
# check cell after fixing the data would otherwise leave a stale FAIL behind.
# No recorded checks at all is treated as a failure, not a pass.
recorded_checks = qc_results['checks']
all_checks_passed = bool(recorded_checks) and all(
    check['status'] == 'PASS' for check in recorded_checks.values()
)
qc_results['overall_status'] = 'PASS' if all_checks_passed else 'FAIL'

print("\n" + "="*50)
print("QUALITY CONTROL SUMMARY DASHBOARD")
print("="*50)
print(f"\nRun Date: {qc_results['timestamp']}")
print(f"\nOVERALL STATUS: {qc_results['overall_status']}")
if not recorded_checks:
    print("\nNo checks were recorded — run the check cells above first.")
print("\nIndividual Checks:")

# Summary table: one row per check; Details is the number of issues when the check reports them
check_summary = []
for check_name, check_result in recorded_checks.items():
    status = check_result['status']
    icon = "✅" if status == "PASS" else "❌"
    check_summary.append({
        'Check': check_name.replace('_', ' ').title(),
        'Status': f"{icon} {status}",
        'Details': len(check_result['issues']) if 'issues' in check_result else ''
    })

summary_df = pd.DataFrame(check_summary)
print(summary_df.to_string(index=False))


def _json_default(obj):
    """Convert values json cannot serialise natively (e.g. numpy int64) to plain Python."""
    if hasattr(obj, 'item'):
        return obj.item()
    return str(obj)


# Save results to JSON
output_path = Path('../results/qc_results.json')
output_path.parent.mkdir(parents=True, exist_ok=True)
with open(output_path, 'w') as f:
    json.dump(qc_results, f, indent=2, default=_json_default)
print(f"\nQC results saved to: {output_path}")

# Final status
if qc_results['overall_status'] == 'PASS':
    print("\n🎉 All quality checks PASSED! The pipeline outputs are consistent and valid.")
else:
    print("\n⚠️  Some quality checks FAILED. Please review the issues above.")

## Additional Diagnostics: Data Distributions (informational; not part of the PASS/FAIL status)

In [None]:
# Visual sanity check of key variable distributions (informational only; these
# plots do not feed into qc_results).
if 'patient_master' in loaded_data:
    df = loaded_data['patient_master']

    fig, axes = plt.subplots(2, 2, figsize=(12, 10))

    # Age distribution with the mean marked
    if 'age' in df.columns:
        mean_age = df['age'].mean()
        df['age'].hist(bins=30, ax=axes[0, 0], edgecolor='black')
        axes[0, 0].set_title('Age Distribution')
        axes[0, 0].set_xlabel('Age')
        axes[0, 0].axvline(mean_age, color='red', linestyle='--', label=f'Mean: {mean_age:.1f}')
        axes[0, 0].legend()

    # Sex distribution. value_counts() orders bars by frequency, so sort by value
    # and label each bar from its own value. Fixed ['Female', 'Male'] labels would
    # be wrong whenever males outnumber females, or if only one value is present.
    # Assumes sex_M is a 0/1 (or bool) indicator with 1 = male; other values are shown as-is.
    if 'sex_M' in df.columns:
        sex_counts = df['sex_M'].value_counts().sort_index()
        sex_counts.plot(kind='bar', ax=axes[0, 1])
        axes[0, 1].set_title('Sex Distribution')
        axes[0, 1].set_xticklabels(
            [{0: 'Female', 1: 'Male'}.get(v, str(v)) for v in sex_counts.index], rotation=0
        )
        axes[0, 1].set_ylabel('Count')

    # SSD flag distribution, labelled by value for the same reason (assumes 1 = SSD)
    if 'ssd_flag' in df.columns:
        ssd_counts = df['ssd_flag'].value_counts().sort_index()
        ssd_counts.plot(kind='bar', ax=axes[1, 0])
        axes[1, 0].set_title('SSD Flag Distribution')
        axes[1, 0].set_xticklabels(
            [{0: 'Control', 1: 'SSD'}.get(v, str(v)) for v in ssd_counts.index], rotation=0
        )
        axes[1, 0].set_ylabel('Count')

        # Percentage labels above each bar (denominator is all rows, including missing flags)
        total = len(df)
        for i, v in enumerate(ssd_counts):
            axes[1, 0].text(i, v + total * 0.01, f'{v / total * 100:.1f}%', ha='center')

    # Charlson score distribution
    if 'charlson_score' in df.columns:
        df['charlson_score'].value_counts().sort_index().plot(kind='bar', ax=axes[1, 1])
        axes[1, 1].set_title('Charlson Comorbidity Score Distribution')
        axes[1, 1].set_xlabel('Charlson Score')
        axes[1, 1].set_ylabel('Count')

    plt.tight_layout()
    plt.show()

In [None]:
# Expose the final verdict as a plain top-level variable so an executing
# harness (e.g. papermill) can capture it alongside the printed status line.
qc_status = qc_results['overall_status']
print(f"\nFinal QC Status for pipeline: {qc_status}")