<a href="https://colab.research.google.com/github/YagyanshB/nhs-data-science/blob/main/data_quality_checks.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# data quality checks

def _percent(part, whole):
    """Return ``part`` as a percentage of ``whole``, or 0.0 when ``whole`` is 0.

    Used for every ratio in the report so that an empty denominator (e.g. a
    dataset with no high-risk patients) does not raise ``ZeroDivisionError``.
    """
    if not whole:
        return 0.0
    return (part / whole) * 100


def run_nhs_data_quality_checks(filepath, *, save_plots=True):
    """
    Comprehensive data quality check for NHS patient data including demographics and clinical notes.

    Prints a human-readable report to stdout as each check runs.

    Expected columns: age, gender, ethnicity, imd_quintile, news2_score,
    charlson_index, primary_diagnosis, readmission_risk and
    previous_admissions_12mo. ``clinical_note`` is optional (text checks are
    skipped without it). ``readmitted_within_30d`` is only used by the
    age/Charlson scatter plot.

    Args:
        filepath: Path (or any source accepted by ``pd.read_csv``) of the
            combined NHS patient data with notes.
        save_plots: When True (default), render the quality charts and save
            them as PNG files in the current working directory. Pass False to
            run the checks headless (batch jobs, tests).

    Returns:
        Dictionary with data quality metrics and summary, or None if the data
        could not be loaded.

    Raises:
        KeyError: If any required column is missing from the loaded data.
    """
    print(f"Running data quality checks on {filepath}")

    # Load data; any read failure is reported and signalled by returning None.
    try:
        data = pd.read_csv(filepath)
        print(f"Successfully loaded dataset with {len(data)} records and {len(data.columns)} features")
    except Exception as e:
        print(f"Error loading data: {e}")
        return None

    # Fail fast with one clear message rather than a bare KeyError mid-report.
    required_columns = [
        'age', 'gender', 'ethnicity', 'imd_quintile', 'news2_score',
        'charlson_index', 'primary_diagnosis', 'readmission_risk',
        'previous_admissions_12mo',
    ]
    absent = [col for col in required_columns if col not in data.columns]
    if absent:
        raise KeyError(f"Missing required columns: {absent}")

    # Initialize results dictionary
    quality_results = {
        "record_count": len(data),
        "feature_count": len(data.columns),
        "missing_data": {},
        "outliers": {},
        "consistency_issues": [],
        "text_quality": {}
    }

    # 1. Check for missing values
    print("\nChecking for missing values...")
    missing_counts = data.isnull().sum()
    missing_percent = (missing_counts / len(data)) * 100

    missing_summary = pd.DataFrame({
        'missing_values': missing_counts,
        'percent_missing': missing_percent
    }).sort_values('percent_missing', ascending=False)
    columns_with_missing = missing_summary[missing_summary['missing_values'] > 0]

    quality_results["missing_data"]["summary"] = columns_with_missing.to_dict()
    # data.size == rows * columns, i.e. the total number of cells.
    quality_results["missing_data"]["total_missing_percentage"] = _percent(missing_counts.sum(), data.size)

    # Print missing values summary
    print(columns_with_missing)

    # 2. Check for demographic data quality
    print("\nChecking demographic data quality...")

    # Age range check
    age_min, age_max = data['age'].min(), data['age'].max()
    invalid_ages = int(((data['age'] < 18) | (data['age'] > 110)).sum())
    quality_results["demographics"] = {
        "age_range": f"{age_min} - {age_max}",
        "invalid_ages": invalid_ages
    }
    print(f"Age range: {age_min} - {age_max}")
    print(f"Invalid ages (< 18 or > 110): {invalid_ages}")

    # Gender distribution
    gender_counts = data['gender'].value_counts()
    quality_results["demographics"]["gender_distribution"] = gender_counts.to_dict()
    print("Gender distribution:")
    print(gender_counts)

    # Ethnicity check
    ethnicity_counts = data['ethnicity'].value_counts()
    quality_results["demographics"]["ethnicity_distribution"] = ethnicity_counts.to_dict()
    print("\nEthnicity distribution:")
    print(ethnicity_counts)

    # IMD quintile check
    imd_counts = data['imd_quintile'].value_counts().sort_index()
    quality_results["demographics"]["imd_distribution"] = imd_counts.to_dict()
    print("\nIMD quintile distribution:")
    print(imd_counts)

    # Check for invalid IMD values (quintiles must be 1-5)
    invalid_imd = int(((data['imd_quintile'] < 1) | (data['imd_quintile'] > 5)).sum())
    quality_results["demographics"]["invalid_imd"] = invalid_imd
    print(f"Invalid IMD values (not 1-5): {invalid_imd}")

    # 3. Check clinical variables
    print("\nChecking clinical variables...")

    # NEWS2 score range check
    news_min, news_max = data['news2_score'].min(), data['news2_score'].max()
    invalid_news = int(((data['news2_score'] < 0) | (data['news2_score'] > 20)).sum())
    quality_results["clinical"] = {
        "news2_range": f"{news_min} - {news_max}",
        "invalid_news2": invalid_news
    }
    print(f"NEWS2 score range: {news_min} - {news_max}")
    print(f"Invalid NEWS2 scores (< 0 or > 20): {invalid_news}")

    # Charlson index check
    charlson_min, charlson_max = data['charlson_index'].min(), data['charlson_index'].max()
    invalid_charlson = int(((data['charlson_index'] < 0) | (data['charlson_index'] > 25)).sum())
    quality_results["clinical"]["charlson_range"] = f"{charlson_min} - {charlson_max}"
    quality_results["clinical"]["invalid_charlson"] = invalid_charlson
    print(f"Charlson index range: {charlson_min} - {charlson_max}")
    print(f"Invalid Charlson values (< 0 or > 25): {invalid_charlson}")

    # Primary diagnosis check
    diagnosis_counts = data['primary_diagnosis'].value_counts()
    quality_results["clinical"]["diagnosis_distribution"] = diagnosis_counts.to_dict()
    print("\nTop 5 primary diagnoses:")
    print(diagnosis_counts.head(5))

    # 4. Check for logical consistency
    print("\nChecking for logical consistency...")

    # Flag high-risk patients (risk > 0.7) that have none of the listed risk factors.
    major_risk_diagnoses = ['Heart Failure', 'COPD', 'Frailty', 'Renal Failure']
    high_risk = data[data['readmission_risk'] > 0.7]
    high_risk_without_factors = high_risk[
        (high_risk['age'] < 65) &
        (high_risk['charlson_index'] < 3) &
        (high_risk['previous_admissions_12mo'] == 0) &
        (~high_risk['primary_diagnosis'].isin(major_risk_diagnoses))
    ]

    unexplained_count = len(high_risk_without_factors)
    # Zero-safe: a dataset may contain no high-risk patients at all.
    unexplained_pct = _percent(unexplained_count, len(high_risk))
    quality_results["consistency_issues"].append({
        "issue": "High risk patients without major risk factors",
        "count": unexplained_count,
        "percentage": unexplained_pct
    })

    print(f"High risk patients without typical risk factors: {unexplained_count} "
          f"({unexplained_pct:.1f}% of high risk patients)")

    # 5. Clinical notes quality
    print("\nAnalyzing clinical notes quality...")

    if 'clinical_note' in data.columns:
        notes = data['clinical_note']
        # Missing notes get a NaN length.
        data['note_length'] = notes.str.len()
        note_min, note_mean, note_max = data['note_length'].min(), data['note_length'].mean(), data['note_length'].max()

        quality_results["text_quality"] = {
            "note_length_stats": {
                "min": note_min,
                "mean": note_mean,
                "max": note_max
            }
        }

        # Check for suspiciously short notes. NaN (missing) lengths compare
        # False, so only present-but-short notes are counted here.
        short_notes = int((data['note_length'] < 50).sum())
        short_pct = _percent(short_notes, len(data))
        quality_results["text_quality"]["short_notes_count"] = short_notes
        quality_results["text_quality"]["short_notes_percent"] = short_pct

        print(f"Clinical note length - Min: {note_min}, Mean: {note_mean:.1f}, Max: {note_max}")
        print(f"Suspiciously short notes (<50 chars): {short_notes} ({short_pct:.1f}%)")

        # Check if notes contain expected clinical terms. Terms are plain
        # words (regex=False) and missing notes count as "not present" (na=False).
        clinical_terms = ['patient', 'presented', 'symptoms', 'treatment', 'plan', 'prescribed', 'assessed']
        term_presence = {}

        for term in clinical_terms:
            term_count = int(notes.str.contains(term, case=False, regex=False, na=False).sum())
            term_presence[term] = {
                "count": term_count,
                "percentage": _percent(term_count, len(data))
            }

        quality_results["text_quality"]["clinical_terms_presence"] = term_presence

        print("\nPresence of expected clinical terms:")
        for term, stats in term_presence.items():
            print(f"'{term}': {stats['count']} records ({stats['percentage']:.1f}%)")

    # 6. Visualize key quality metrics (PNG files written to the working directory)
    if save_plots:
        print("\nGenerating data quality visualizations...")

        # Missing data heatmap
        plt.figure(figsize=(12, 8))
        sns.heatmap(data.isnull(), yticklabels=False, cbar=False, cmap='viridis')
        plt.title('Missing Value Heatmap')
        plt.tight_layout()
        plt.savefig('missing_data_heatmap.png')

        # Readmission risk distribution
        plt.figure(figsize=(10, 6))
        sns.histplot(data['readmission_risk'], bins=20, kde=True)
        plt.title('Readmission Risk Distribution')
        plt.xlabel('Readmission Risk')
        plt.ylabel('Count')
        plt.tight_layout()
        plt.savefig('readmission_risk_distribution.png')

        # Age vs Charlson Index scatter with readmission outcome. Sample at
        # most the size of the dataset, with a fixed seed so the chart is
        # reproducible between runs.
        plt.figure(figsize=(10, 6))
        sample_size = min(1000, len(data))
        sns.scatterplot(x='age', y='charlson_index', hue='readmitted_within_30d',
                        data=data.sample(sample_size, random_state=0))
        plt.title('Age vs Charlson Index by Readmission Status')
        plt.xlabel('Age')
        plt.ylabel('Charlson Comorbidity Index')
        plt.tight_layout()
        plt.savefig('age_charlson_readmission.png')

        # Note length distribution
        if 'note_length' in data.columns:
            plt.figure(figsize=(10, 6))
            sns.histplot(data['note_length'], bins=30, kde=True)
            plt.title('Clinical Note Length Distribution')
            plt.xlabel('Number of Characters')
            plt.ylabel('Count')
            plt.tight_layout()
            plt.savefig('note_length_distribution.png')

    # Summary
    print("\nData Quality Summary:")
    print(f"- Total records: {quality_results['record_count']}")
    print(f"- Missing data: {quality_results['missing_data']['total_missing_percentage']:.2f}% overall")
    print(f"- Demographics issues: {quality_results['demographics']['invalid_ages']} invalid ages, "
          f"{quality_results['demographics']['invalid_imd']} invalid IMD values")
    print(f"- Clinical data issues: {quality_results['clinical']['invalid_news2']} invalid NEWS2 scores, "
          f"{quality_results['clinical']['invalid_charlson']} invalid Charlson indices")
    print(f"- Logical consistency issues: {len(quality_results['consistency_issues'])} issues identified")
    if 'clinical_note' in data.columns:
        print(f"- Text quality: {quality_results['text_quality']['short_notes_count']} suspiciously short notes")

    return quality_results

# Run the quality checks and print results when executed as a script.
# (In a notebook cell __name__ is also "__main__", so this runs there too and
# leaves `quality_results` available to later cells.)
if __name__ == "__main__":
    quality_results = run_nhs_data_quality_checks('nhs_patient_data_with_notes.csv')
    # The checker returns None when the CSV cannot be loaded; don't report
    # success in that case.
    if quality_results is None:
        print("\n❌ Data quality assessment failed: data could not be loaded")
    else:
        print("\n✅ Data quality assessment complete")