In [3]:
import pandas as pd
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.impute import SimpleImputer
from sklearn.impute import KNNImputer
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.metrics import confusion_matrix, roc_auc_score, mean_absolute_error, mean_squared_error
import xgboost as xgb
from sklearn.model_selection import GridSearchCV
from sklearn.preprocessing import StandardScaler
import scipy.stats as stats

In [4]:
#import tensorflow_data_validation as tfdv
#from tensorflow_data_validation.utils import display_util
import logging
import os

In [1]:
from dags.utils.logger_config import setup_logging
logger = setup_logging('.',"x.py")



.
.


In [4]:
logger.info("yy")

2024-06-07 00:02:08,758 - x.py - INFO - yy


In [10]:
df=pd.read_csv('Dataset.csv')

In [32]:
schema = {col: str(df[col].dtype) for col in df.columns}
schema

{'Unnamed: 0': 'int64',
 'Hour': 'int64',
 'HR': 'float64',
 'O2Sat': 'float64',
 'Temp': 'float64',
 'SBP': 'float64',
 'MAP': 'float64',
 'DBP': 'float64',
 'Resp': 'float64',
 'EtCO2': 'float64',
 'BaseExcess': 'float64',
 'HCO3': 'float64',
 'FiO2': 'float64',
 'pH': 'float64',
 'PaCO2': 'float64',
 'SaO2': 'float64',
 'AST': 'float64',
 'BUN': 'float64',
 'Alkalinephos': 'float64',
 'Calcium': 'float64',
 'Chloride': 'float64',
 'Creatinine': 'float64',
 'Bilirubin_direct': 'float64',
 'Glucose': 'float64',
 'Lactate': 'float64',
 'Magnesium': 'float64',
 'Phosphate': 'float64',
 'Potassium': 'float64',
 'Bilirubin_total': 'float64',
 'TroponinI': 'float64',
 'Hct': 'float64',
 'Hgb': 'float64',
 'PTT': 'float64',
 'WBC': 'float64',
 'Fibrinogen': 'float64',
 'Platelets': 'float64',
 'Age': 'float64',
 'Gender': 'int64',
 'Unit1': 'float64',
 'Unit2': 'float64',
 'HospAdmTime': 'float64',
 'ICULOS': 'int64',
 'SepsisLabel': 'int64',
 'Patient_ID': 'int64'}

In [11]:
unique_patient_ids = df['Patient_ID'].drop_duplicates().head(24000)

In [12]:
filtered_df = df[df['Patient_ID'].isin(unique_patient_ids)]

In [13]:
train_unique_patient_ids=filtered_df['Patient_ID'].drop_duplicates().head(18000)
test_unique_patient_ids=filtered_df['Patient_ID'].drop_duplicates().tail(6000)
train_df = df[df['Patient_ID'].isin(train_unique_patient_ids)]
test_df=df[df['Patient_ID'].isin(test_unique_patient_ids)]

In [62]:
from dags.utils.logger_config import setup_logging

In [66]:
import pandas as pd
import json
import logging
from dags.utils.logger_config import setup_logging
import numpy as np


setup_logging()

logger=logging.getLogger('Data_Validation.py')

STATS_SCHEMA_FILE = 'schema_and_stats.json'

def convert_to_serializable(value):
    """
    Convert a value to a JSON-serializable format.

    Args:
        value: The value to convert.

    Returns:
        The converted value.
    """
    if isinstance(value, (np.integer, np.floating)):
        return value.item()
    if isinstance(value, np.ndarray):
        return value.tolist()
    return value

def generate_and_save_schema_and_stats(df, schema_file=STATS_SCHEMA_FILE):
    """
    Generate the schema and statistics from a DataFrame and save them to a JSON file.

    Args:
        df (pd.DataFrame): DataFrame from which to infer the schema and statistics.
        schema_file (str): Path to save the schema and statistics.
    """
    try:
        schema = {
        "HR": "float64",
        "O2Sat": "float64",
        "Temp": "float64",
        "SBP": "float64",
        "MAP": "float64",
        "DBP": "float64",
        "Resp": "float64",
        "EtCO2": "float64",
        "BaseExcess": "float64",
        "HCO3": "float64",
        "FiO2": "float64",
        "pH": "float64",
        "PaCO2": "float64",
        "SaO2": "float64",
        "AST": "float64",
        "BUN": "float64",
        "Alkalinephos": "float64",
        "Calcium": "float64",
        "Chloride": "float64",
        "Creatinine": "float64",
        "Bilirubin_direct": "float64",
        "Glucose": "float64",
        "Lactate": "float64",
        "Magnesium": "float64",
        "Phosphate": "float64",
        "Potassium": "float64",
        "Bilirubin_total": "float64",
        "TroponinI": "float64",
        "Hct": "float64",
        "Hgb": "float64",
        "PTT": "float64",
        "WBC": "float64",
        "Fibrinogen": "float64",
        "Platelets": "float64",
        "Age": "float64",
        "Gender": "int64",
        "Unit1": "float64",
        "Unit2": "float64",
        "HospAdmTime": "float64",
        "ICULOS": "int64",
        "SepsisLabel": "int64",
        "Patient_ID": "int64"  # Assuming Patient_ID is extracted as a string
        }
        stats = {}
        for col in df.columns:
            if pd.api.types.is_numeric_dtype(df[col]):
                stats[col] = {
                    'min': convert_to_serializable(df[col].min()),
                    'max': convert_to_serializable(df[col].max()),
                    'mean': convert_to_serializable(df[col].mean()) if not df[col].isnull().all() else None,
                    'median': convert_to_serializable(df[col].median()) if not df[col].isnull().all() else None,
                    'std': convert_to_serializable(df[col].std()) if not df[col].isnull().all() else None,
                    'null_count': convert_to_serializable(df[col].isnull().sum())
                }
            else:
                stats[col] = {
                    'unique_values': convert_to_serializable(df[col].unique()),
                    'null_count': convert_to_serializable(df[col].isnull().sum())
                }
        schema_and_stats = {'schema': schema, 'statistics': stats}
        with open(schema_file, 'w') as f:
            json.dump(schema_and_stats, f, indent=4)
        logger.info(f"Schema and statistics generated and saved to {schema_file}.")
    except Exception as e:
        logger.error(f"Error generating or saving schema and statistics: {e}")
        raise

def load_schema_and_stats(schema_file=STATS_SCHEMA_FILE):
    """
    Load the schema and statistics from a JSON file.

    Args:
        schema_file (str): Path to the schema and statistics file.

    Returns:
        dict: Loaded schema and statistics.
    """
    try:
        with open(schema_file, 'r') as f:
            schema_and_stats = json.load(f)
        logger.info(f"Schema and statistics loaded from {schema_file}.")
        return schema_and_stats
    except Exception as e:
        logger.error(f"Error loading schema and statistics from {schema_file}: {e}")
        raise

In [67]:


def validate_schema(df, schema):
    """
    Validate the schema of the DataFrame against the expected schema.

    Args:
        df (pd.DataFrame): DataFrame to validate.
        schema (dict): Expected schema.

    Returns:
        bool: True if schema is valid, False otherwise.
    """
    for column, dtype in schema.items():
        if column not in df.columns:
            logger.error(f"Missing column: {column}")
            return False
        if str(df[column].dtype) != dtype:
            logger.error(f"Invalid type for column {column}. Expected {dtype}, got {df[column].dtype}")
            return False
    logger.info("Schema validation passed.")
    return True

def validate_statistics(df, stats):
    """
    Validate statistics of the DataFrame against expected statistics.

    Args:
        df (pd.DataFrame): DataFrame to validate.
        stats (dict): Expected statistics.

    Returns:
        bool: True if statistics are valid, False otherwise.
    """
    try:
        for col, stat in stats.items():
            if col not in df.columns:
                logger.error(f"Missing column: {col}")
                return False
            
            if col == 'Patient_ID':
                if df[col].isnull().any():
                    logger.error("The 'patient_id' column cannot have null values.")
                    return False
                continue

            if 'min' in stat and 'max' in stat:
                if stat['min'] is not None and stat['max'] is not None:
                    if df[col].min() < stat['min']:
                        logger.warning(f"Column {col} min value anomaly: {df[col].min()} < {stat['min']}")
                    if df[col].max() > stat['max']:
                        logger.warning(f"Column {col} max value anomaly: {df[col].max()} > {stat['max']}")

            if 'mean' in stat and 'std' in stat:
                if stat['mean'] is not None and stat['std'] is not None:
                    if not df[col].isnull().all():  # Check if any non-null values exist
                        if abs(df[col].mean() - stat['mean']) > 3 * stat['std']:
                            logger.warning(f"Column {col} mean value anomaly: {df[col].mean()} != {stat['mean']}")

            if 'median' in stat and 'std' in stat:
                if stat['median'] is not None and stat['std'] is not None:
                    if not df[col].isnull().all():  # Check if any non-null values exist
                        if abs(df[col].median() - stat['median']) > 3 * stat['std']:
                            logger.warning(f"Column {col} median value anomaly: {df[col].median()} != {stat['median']}")

            if 'null_count' in stat:
                null_count = df[col].isnull().sum()
                if stat['null_count'] is not None:
                    if null_count > stat['null_count']:
                        logger.warning(f"Column {col} null value count anomaly: {null_count} > {stat['null_count']}")

            if 'unique_values' in stat:
                if stat['unique_values'] is not None:
                    unique_values = df[col].unique()
                    if set(unique_values) != set(stat['unique_values']):
                        logger.warning(f"Column {col} unique values anomaly: {unique_values} != {stat['unique_values']}")

        logger.info("Statistical validation passed.")
        return True
    except Exception as e:
        logger.error(f"Error during statistical validation: {e}")
        return False


def validate_data(df):
    """
    Validate data against stored schema and statistics.

    Args:
        file_path (str): Path to the data file.

    Returns:
        bool: True if validation passes, False if validation fails.
    """
    try:
        # Load data
        #df = pd.read_csv(file_path)
        logger.info(f"Data loaded successfully.")

        # Load schema and statistics
        schema_and_stats = load_schema_and_stats()
        schema = schema_and_stats['schema']
        stats = schema_and_stats['statistics']

        # Validate schema
        if not validate_schema(df, schema):
            logger.error("Schema validation failed.")
            return False

        # Validate statistics
        if not validate_statistics(df, stats):
            logger.error("Statistical validation failed.")
            return False

        logger.info("Data validation passed.")
        return True
    except Exception as e:
        logger.error(f"Error during data validation: {e}")
        return False


In [68]:
generate_and_save_schema_and_stats(train_df)

2024-06-06 18:11:17,932 - Data_Validation.py - INFO - Schema and statistics generated and saved to schema_and_stats.json.


In [69]:
validation_result = validate_data(test_df)
if not validation_result:
    raise ValueError("Data validation failed. Stopping DAG execution.")

2024-06-06 18:11:38,040 - Data_Validation.py - INFO - Data loaded successfully.
2024-06-06 18:11:38,040 - Data_Validation.py - INFO - Schema and statistics loaded from schema_and_stats.json.
2024-06-06 18:11:38,041 - Data_Validation.py - INFO - Schema validation passed.
2024-06-06 18:11:38,458 - Data_Validation.py - INFO - Statistical validation passed.
2024-06-06 18:11:38,458 - Data_Validation.py - INFO - Data validation passed.


In [27]:
test_df

Unnamed: 0.1,Unnamed: 0,Hour,HR,O2Sat,Temp,SBP,MAP,DBP,Resp,EtCO2,...,Fibrinogen,Platelets,Age,Gender,Unit1,Unit2,HospAdmTime,ICULOS,SepsisLabel,Patient_ID
699611,0,0,,,,,,,,,...,,,42.51,1,,,-0.01,1,0,11905
699612,1,1,82.0,100.0,,110.0,76.0,,11.0,,...,,,42.51,1,,,-0.01,2,0,11905
699613,2,2,81.0,97.0,,110.0,77.0,,14.0,,...,,,42.51,1,,,-0.01,3,0,11905
699614,3,3,77.0,100.0,,111.0,73.0,,12.0,,...,,238.0,42.51,1,,,-0.01,4,0,11905
699615,4,4,76.0,100.0,37.33,110.0,74.0,,11.0,,...,,,42.51,1,,,-0.01,5,0,11905
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
927802,16,16,83.0,93.0,,126.0,95.0,86.0,15.0,,...,,,33.00,1,1.0,0.0,-4.88,17,0,114990
927803,17,17,78.0,100.0,,118.0,88.0,82.0,16.0,,...,,,33.00,1,1.0,0.0,-4.88,18,0,114990
927804,18,18,83.0,100.0,,117.0,86.0,78.0,17.0,,...,,,33.00,1,1.0,0.0,-4.88,19,0,114990
927805,19,19,79.0,100.0,,120.0,91.0,87.0,18.0,,...,,,33.00,1,1.0,0.0,-4.88,20,0,114990
