# In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats

# ========================================
# DATA QUALITY ASSESSMENT TOOLKIT
# ========================================

class DataQualityChecker:
    """
    Comprehensive data quality checker for time series datasets.

    Every ``check_*`` method prints its findings to stdout and returns
    ``None``. ``plot_quality_report`` draws a 3x3 figure, saves it as
    ``data_quality_report.png`` and shows it.
    """

    def __init__(self, X_train, y_train, X_val, y_val, X_test, y_test,
                 feature_names=None, target_name='PM2_5_log_scaled'):
        """
        Args:
            X_train, X_val, X_test: Shape (samples, timesteps, features)
            y_train, y_val, y_test: Shape (samples,)
            feature_names: List of feature names. When ``None``, names
                ``Feature_0 .. Feature_{n-1}`` are generated from
                ``X_train.shape[2]``.
            target_name: Name of target variable (used in printed labels
                and plot axes)
        """
        self.X_train = X_train
        self.X_val = X_val
        self.X_test = X_test
        self.y_train = y_train
        self.y_val = y_val
        self.y_test = y_test
        self.target_name = target_name

        if feature_names is None:
            feature_names = [f'Feature_{i}' for i in range(X_train.shape[2])]
        self.feature_names = feature_names

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    def _feature_splits(self):
        """Return the X arrays keyed by split name, in report order."""
        return {'Train': self.X_train, 'Val': self.X_val, 'Test': self.X_test}

    def _target_splits(self):
        """Return the y arrays keyed by split name, in report order."""
        return {'Train': self.y_train, 'Val': self.y_val, 'Test': self.y_test}

    @staticmethod
    def _relative_diff_pct(reference, other):
        """Return ``|reference - other|`` as a percentage of ``|reference|``.

        The magnitude of ``reference`` is used so that a negative reference
        (common for scaled targets) cannot produce a negative percentage.
        A zero reference yields ``0.0`` when both values are equal and
        ``inf`` otherwise, instead of dividing by zero.
        """
        gap = abs(reference - other)
        scale = abs(reference)
        if scale == 0:
            return 0.0 if gap == 0 else float('inf')
        return gap / scale * 100

    # ------------------------------------------------------------------
    # Checks
    # ------------------------------------------------------------------

    def check_all(self):
        """Run all quality checks (sections 1-9) in order; plotting is separate."""
        print("="*80)
        print("DATA QUALITY ASSESSMENT REPORT")
        print("="*80)

        self.check_basic_info()
        self.check_missing_values()
        self.check_duplicates()
        self.check_distribution()
        self.check_outliers()
        self.check_data_balance()
        self.check_feature_ranges()
        self.check_temporal_consistency()
        self.check_correlations()

        print("\n" + "="*80)
        print("QUALITY ASSESSMENT COMPLETED")
        print("="*80)

    def check_basic_info(self):
        """Print sample counts, sequence length, feature list and split ratios."""
        print("\n[1] BASIC INFORMATION")
        print("-" * 80)

        n_train, n_val, n_test = len(self.X_train), len(self.X_val), len(self.X_test)
        total = n_train + n_val + n_test

        print(f"Training samples:   {n_train:,}")
        print(f"Validation samples: {n_val:,}")
        print(f"Test samples:       {n_test:,}")
        print(f"Total samples:      {total:,}")
        print(f"\nSequence length:    {self.X_train.shape[1]}")
        print(f"Number of features: {self.X_train.shape[2]}")
        print(f"Features: {', '.join(self.feature_names)}")

        # Split ratio
        print("\nSplit ratios:")
        print(f"  Train: {n_train/total*100:.1f}%")
        print(f"  Val:   {n_val/total*100:.1f}%")
        print(f"  Test:  {n_test/total*100:.1f}%")

    def check_missing_values(self):
        """Report NaN and Inf counts in X and y for every split."""
        print("\n[2] MISSING VALUES CHECK")
        print("-" * 80)

        X_splits = self._feature_splits()
        y_splits = self._target_splits()

        has_missing = False
        for name, X in X_splits.items():
            y = y_splits[name]
            # (variable, kind, count, array size) in the order they are reported
            findings = [
                ('X', 'NaN', np.isnan(X).sum(), X.size),
                ('X', 'Inf', np.isinf(X).sum(), X.size),
                ('y', 'NaN', np.isnan(y).sum(), y.size),
                ('y', 'Inf', np.isinf(y).sum(), y.size),
            ]
            findings = [item for item in findings if item[2] > 0]
            if not findings:
                continue

            has_missing = True
            print(f"⚠️  {name} set:")
            for var, kind, count, size in findings:
                print(f"   {var}: {count:,} {kind} values ({count/size*100:.4f}%)")

        if not has_missing:
            print("✅ No missing values (NaN/Inf) detected in any dataset")

    def check_duplicates(self):
        """Count exact duplicate sequences within each split."""
        print("\n[3] DUPLICATE SEQUENCES CHECK")
        print("-" * 80)

        total_dups = 0
        for name, X in self._feature_splits().items():
            # Flatten each sequence to one row so pandas can compare whole sequences
            flat = pd.DataFrame(X.reshape(len(X), -1))
            dups = flat.duplicated().sum()
            total_dups += dups
            # Label padded to the width of "Train duplicates:" to keep columns aligned
            label = f"{name} duplicates:"
            print(f"{label:<17} {dups:,} ({dups/len(X)*100:.2f}%)")

        if total_dups == 0:
            print("✅ No duplicate sequences found")
        else:
            print("⚠️  Duplicates detected - consider removing them")

    def check_distribution(self):
        """Print target summary statistics per split and compare them to train."""
        print("\n[4] TARGET DISTRIBUTION")
        print("-" * 80)

        for name, y in self._target_splits().items():
            print(f"\n{name} set ({self.target_name}):")
            print(f"  Mean:   {np.mean(y):.6f}")
            print(f"  Std:    {np.std(y):.6f}")
            print(f"  Min:    {np.min(y):.6f}")
            print(f"  25%:    {np.percentile(y, 25):.6f}")
            print(f"  Median: {np.median(y):.6f}")
            print(f"  75%:    {np.percentile(y, 75):.6f}")
            print(f"  Max:    {np.max(y):.6f}")

            # Skewness and Kurtosis
            skewness = stats.skew(y)
            kurtosis = stats.kurtosis(y)
            print(f"  Skewness: {skewness:.4f} {'(right-skewed)' if skewness > 0 else '(left-skewed)'}")
            print(f"  Kurtosis: {kurtosis:.4f} {'(heavy-tailed)' if kurtosis > 0 else '(light-tailed)'}")

        # Check distribution consistency
        print("\n📊 Distribution Consistency Check:")
        train_mean, val_mean, test_mean = np.mean(self.y_train), np.mean(self.y_val), np.mean(self.y_test)
        train_std, val_std, test_std = np.std(self.y_train), np.std(self.y_val), np.std(self.y_test)

        # Relative to |train statistic|: a signed denominator would turn the
        # percentage negative for a negative train mean and hide any shift.
        mean_diff_val = self._relative_diff_pct(train_mean, val_mean)
        mean_diff_test = self._relative_diff_pct(train_mean, test_mean)
        std_diff_val = self._relative_diff_pct(train_std, val_std)
        std_diff_test = self._relative_diff_pct(train_std, test_std)

        print(f"  Train vs Val mean difference:  {mean_diff_val:.2f}%")
        print(f"  Train vs Test mean difference: {mean_diff_test:.2f}%")
        print(f"  Train vs Val std difference:   {std_diff_val:.2f}%")
        print(f"  Train vs Test std difference:  {std_diff_test:.2f}%")

        if mean_diff_val > 10 or mean_diff_test > 10:
            print("  ⚠️  Large mean difference detected - possible distribution shift")
        else:
            print("  ✅ Distribution is consistent across splits")

    def check_outliers(self):
        """Report target outliers per split using the 1.5 * IQR rule."""
        print("\n[5] OUTLIER DETECTION (Target Variable)")
        print("-" * 80)

        for name, y in self._target_splits().items():
            Q1 = np.percentile(y, 25)
            Q3 = np.percentile(y, 75)
            IQR = Q3 - Q1

            lower_bound = Q1 - 1.5 * IQR
            upper_bound = Q3 + 1.5 * IQR

            outliers_low = np.sum(y < lower_bound)
            outliers_high = np.sum(y > upper_bound)
            total_outliers = outliers_low + outliers_high

            print(f"\n{name} set:")
            print(f"  IQR: {IQR:.6f}")
            print(f"  Lower bound: {lower_bound:.6f}")
            print(f"  Upper bound: {upper_bound:.6f}")
            print(f"  Outliers (low):  {outliers_low:,} ({outliers_low/len(y)*100:.2f}%)")
            print(f"  Outliers (high): {outliers_high:,} ({outliers_high/len(y)*100:.2f}%)")
            print(f"  Total outliers:  {total_outliers:,} ({total_outliers/len(y)*100:.2f}%)")

            if total_outliers/len(y) > 0.05:
                print("  ⚠️  High percentage of outliers (>5%)")

    def check_data_balance(self):
        """Show how each split's targets fall into the train-set quartile bins."""
        print("\n[6] DATA BALANCE (Target Variable Distribution)")
        print("-" * 80)

        # Bin edges come from the training targets so all splits share them
        bins = np.percentile(self.y_train, [0, 25, 50, 75, 100])
        bin_labels = ['Q1 (Low)', 'Q2', 'Q3', 'Q4 (High)']

        for name, y in self._target_splits().items():
            # Only the three interior edges are used, giving bin ids 0..3
            digitized = np.digitize(y, bins[1:-1])
            counts = [np.sum(digitized == i) for i in range(4)]

            print(f"\n{name} set distribution:")
            for label, count in zip(bin_labels, counts):
                percentage = count / len(y) * 100
                bar = '█' * int(percentage / 2)
                print(f"  {label}: {count:6,} ({percentage:5.2f}%) {bar}")

            # Check balance
            min_count = min(counts)
            max_count = max(counts)
            imbalance_ratio = max_count / min_count if min_count > 0 else float('inf')

            if imbalance_ratio > 2:
                print(f"  ⚠️  Imbalanced data detected (ratio: {imbalance_ratio:.2f}:1)")
            else:
                print(f"  ✅ Data is reasonably balanced (ratio: {imbalance_ratio:.2f}:1)")

    def check_feature_ranges(self):
        """Print min/max/mean/std per feature and flag near-constant features.

        Statistics are taken from the last timestep of each training sequence.
        """
        print("\n[7] FEATURE VALUE RANGES")
        print("-" * 80)

        # Get last timestep for each sequence (most recent)
        X_train_last = self.X_train[:, -1, :]

        print(f"\n{'Feature':<30} {'Min':>10} {'Max':>10} {'Mean':>10} {'Std':>10}")
        print("-" * 70)

        constant_features = []
        for i, name in enumerate(self.feature_names):
            feature_values = X_train_last[:, i]
            spread = np.std(feature_values)
            print(f"{name:<30} {np.min(feature_values):>10.4f} {np.max(feature_values):>10.4f} "
                  f"{np.mean(feature_values):>10.4f} {spread:>10.4f}")
            if spread < 1e-6:
                constant_features.append(name)

        # Check if any features are constant
        print("\n🔍 Constant Features Check:")
        if constant_features:
            print(f"  ⚠️  Constant features detected: {', '.join(constant_features)}")
        else:
            print("  ✅ No constant features detected")

    def check_temporal_consistency(self):
        """Print mean absolute step-to-step change per feature and count flat sequences."""
        print("\n[8] TEMPORAL CONSISTENCY")
        print("-" * 80)

        # Average absolute difference between consecutive timesteps, per feature
        diffs = np.diff(self.X_train, axis=1)
        avg_diff = np.mean(np.abs(diffs), axis=(0, 1))

        print("\nAverage temporal change per feature:")
        print(f"{'Feature':<30} {'Avg Change':>15}")
        print("-" * 50)

        for i, name in enumerate(self.feature_names):
            print(f"{name:<30} {avg_diff[i]:>15.6f}")

        # A sequence is "zero-variance" when every feature is (near) constant
        # over time; computed for all sequences at once instead of a Python loop.
        per_sequence_std = np.std(self.X_train, axis=1)
        zero_variance_count = int(np.sum(np.all(per_sequence_std < 1e-6, axis=1)))

        print(f"\n🔍 Zero-variance sequences: {zero_variance_count:,} ({zero_variance_count/len(self.X_train)*100:.2f}%)")

        if zero_variance_count > 0:
            print("  ⚠️  Some sequences have no temporal variation")
        else:
            print("  ✅ All sequences have temporal variation")

    def check_correlations(self):
        """Report highly correlated feature pairs and feature/target correlations.

        Correlations are computed on the last timestep of each training
        sequence. Features with undefined (NaN) correlation, e.g. constant
        features, are ranked last rather than first in the target table.
        """
        print("\n[9] FEATURE CORRELATIONS")
        print("-" * 80)

        # Use last timestep for correlation analysis
        X_train_last = self.X_train[:, -1, :]
        n_features = X_train_last.shape[1]
        y_flat = np.ravel(self.y_train)

        # Constant columns give 0/0 inside corrcoef; the NaN result is handled
        # below, so the accompanying floating-point warnings are silenced here.
        with np.errstate(divide='ignore', invalid='ignore'):
            corr_matrix = np.atleast_2d(np.corrcoef(X_train_last.T))
            target_corr = np.array([
                np.corrcoef(X_train_last[:, i], y_flat)[0, 1]
                for i in range(n_features)
            ])

        # Find highly correlated features (>0.9); NaN entries never pass the test
        high_corr_pairs = []
        for i in range(len(self.feature_names)):
            for j in range(i+1, len(self.feature_names)):
                if abs(corr_matrix[i, j]) > 0.9:
                    high_corr_pairs.append((
                        self.feature_names[i],
                        self.feature_names[j],
                        corr_matrix[i, j]
                    ))

        if high_corr_pairs:
            print("\n⚠️  Highly correlated feature pairs (|r| > 0.9):")
            for feat1, feat2, corr in high_corr_pairs:
                print(f"  {feat1} <-> {feat2}: {corr:.4f}")
            print("\n  Consider removing one feature from each pair to reduce redundancy")
        else:
            print("\n✅ No highly correlated feature pairs detected (threshold: |r| > 0.9)")

        # Correlation with target
        print("\n📊 Feature correlation with target:")

        # Sort by absolute correlation, strongest first. NaN is mapped to -1
        # for ranking only, so undefined correlations sink to the bottom.
        abs_corr = np.abs(target_corr)
        ranking_key = np.where(np.isnan(abs_corr), -1.0, abs_corr)
        sorted_indices = np.argsort(-ranking_key, kind='stable')

        print(f"\n{'Feature':<30} {'Correlation':>15}")
        print("-" * 50)
        for idx in sorted_indices[:10]:  # Top 10
            print(f"{self.feature_names[idx]:<30} {target_corr[idx]:>15.4f}")

    def plot_quality_report(self):
        """Draw, save (``data_quality_report.png``) and show the 3x3 report figure.

        Panels: target histogram, target box plot, Q-Q plot of the training
        target, then last-timestep histograms of up to the first six features.
        """
        print("\n[10] GENERATING VISUALIZATIONS...")
        print("-" * 80)

        fig = plt.figure(figsize=(16, 12))
        split_labels = ['Train', 'Val', 'Test']
        targets = [self.y_train, self.y_val, self.y_test]

        # 1. Target distribution
        ax1 = fig.add_subplot(3, 3, 1)
        ax1.hist(targets, bins=50, label=split_labels, alpha=0.7)
        ax1.set_xlabel(self.target_name)
        ax1.set_ylabel('Frequency')
        ax1.set_title('Target Distribution')
        ax1.legend()
        ax1.grid(alpha=0.3)

        # 2. Box plots
        ax2 = fig.add_subplot(3, 3, 2)
        ax2.boxplot(targets)
        # Tick labels are set after the call: the ``labels`` keyword of
        # boxplot is deprecated in newer Matplotlib, its replacement is
        # missing in older releases; this form works in both.
        ax2.set_xticklabels(split_labels)
        ax2.set_ylabel(self.target_name)
        ax2.set_title('Target Distribution (Box Plot)')
        ax2.grid(alpha=0.3)

        # 3. Q-Q plot for normality
        ax3 = fig.add_subplot(3, 3, 3)
        stats.probplot(self.y_train, dist="norm", plot=ax3)
        ax3.set_title('Q-Q Plot (Train Set)')
        ax3.grid(alpha=0.3)

        # 4. Feature distributions (first 6 features fill grid slots 4-9)
        for i in range(min(6, len(self.feature_names))):
            ax = fig.add_subplot(3, 3, i+4)
            feature_data = self.X_train[:, -1, i]  # Last timestep
            ax.hist(feature_data, bins=50, alpha=0.7, edgecolor='black')
            ax.set_xlabel(self.feature_names[i])
            ax.set_ylabel('Frequency')
            ax.set_title(f'{self.feature_names[i]} Distribution')
            ax.grid(alpha=0.3)

        fig.tight_layout()
        fig.savefig('data_quality_report.png', dpi=150, bbox_inches='tight')
        print("✅ Visualization saved as 'data_quality_report.png'")
        plt.show()


# ========================================
# USAGE EXAMPLE
# ========================================

# Build the checker from the notebook's train/val/test arrays.
# Feature names are listed in the same order as the last axis of X.
checker = DataQualityChecker(
    X_train=X_train, y_train=y_train,
    X_val=X_val, y_val=y_val,
    X_test=X_test, y_test=y_test,
    target_name='PM2_5_log_scaled',
    feature_names=[
        # Pollutants
        "PM10_scaled",
        "NO2_scaled",
        "SO2_scaled",
        # Weather
        "temperature_2m_scaled",
        "relative_humidity_2m_scaled",
        "wind_speed_10m_scaled",
        "surface_pressure_scaled",
        "precipitation_scaled",
        # Cyclical time encodings
        "hour_sin",
        "hour_cos",
        "month_sin",
        "month_cos",
        "day_of_week_sin",
        "day_of_week_cos",
        # Wind direction encoding and weekend flag
        "wind_direction_sin",
        "wind_direction_cos",
        "is_weekend",
    ],
)

# Print the textual report (sections 1-9)
checker.check_all()

# Then render and save the figure report
checker.plot_quality_report()

# ========================================
# ADDITIONAL CHECKS
# ========================================

print("\n" + "="*80)
print("ADDITIONAL SANITY CHECKS")
print("="*80)

# Check data types
print("\n[11] DATA TYPE CHECK")
print("-" * 80)
print(f"X_train dtype: {X_train.dtype}")
print(f"y_train dtype: {y_train.dtype}")
# issubdtype accepts every floating dtype (float16/32/64, longdouble), so
# only genuinely non-float arrays are flagged.
if not np.issubdtype(X_train.dtype, np.floating):
    print("⚠️  X_train is not float type - may cause issues")
if not np.issubdtype(y_train.dtype, np.floating):
    print("⚠️  y_train is not float type - may cause issues")

# Check for data leakage
print("\n[12] DATA LEAKAGE CHECK")
print("-" * 80)
# Flatten each sequence to a single row so whole sequences can be compared
X_train_2d = X_train.reshape(len(X_train), -1)
X_test_2d = X_test.reshape(len(X_test), -1)

# Check if any test samples appear in training
train_df = pd.DataFrame(X_train_2d)
test_df = pd.DataFrame(X_test_2d)
# Training rows are de-duplicated for the join so each test row matches at
# most once; the result length is then the number of leaked test samples
# rather than the number of (test, train) row pairs.
common = pd.merge(test_df, train_df.drop_duplicates(), how='inner')

if len(common) > 0:
    print(f"⚠️  POSSIBLE DATA LEAKAGE: {len(common)} test samples found in training set!")
else:
    print("✅ No data leakage detected")

print("\n" + "="*80)
print("DATA QUALITY ASSESSMENT COMPLETE")
print("="*80)