In [1]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
import numpy as np

In [2]:


# Define file paths
PREPROCESSED_PATH = '../Data/Preprocessed'
OUTPUT_PATH = '../Outputs'

# Ensure output directory exists
os.makedirs(OUTPUT_PATH, exist_ok=True)

# Set Seaborn style and color palette
sns.set_style("whitegrid")
sns.set_palette("Blues")

def load_data():
    """Load preprocessed training data."""
    try:
        train_data = pd.read_csv(os.path.join(PREPROCESSED_PATH, 'train_data.csv'))
        train_data['date'] = pd.to_datetime(train_data['date'])
        print("Preprocessed data loaded successfully.")
        return train_data
    except FileNotFoundError as e:
        raise FileNotFoundError(f"Error: {e}. Check if train_data.csv exists in {PREPROCESSED_PATH}")

def basic_understanding(data):
    """Provide a basic understanding of the data and check quality."""
    print("\n--- Basic Understanding of Data ---")
    print("Data Shape:", data.shape)
    print("\nData Types:\n", data.dtypes)
    print("\nSummary Statistics:\n", data.describe())
    print("\nMissing Values:\n", data.isnull().sum())
    
    # Data quality checks
    print("\n--- Data Quality Checks ---")
    # Negative rainfall
    negative_rainfall = data[data['rainfall_sum'] < 0]
    print(f"Negative rainfall values: {len(negative_rainfall)}")
    
    # Unrealistic rainfall (>1000 mm daily)
    high_rainfall = data[data['rainfall_sum'] > 1000]
    print(f"Unrealistic rainfall values (>1000 mm): {len(high_rainfall)}")
    
    # Duplicate records
    duplicates = data.duplicated().sum()
    print(f"Duplicate records: {duplicates}")
    
    # Invalid dates
    invalid_dates = data['date'].isnull().sum()
    print(f"Invalid dates: {invalid_dates}")
    
    # Save summary and quality checks
    with open(os.path.join(OUTPUT_PATH, 'data_summary.txt'), 'w') as f:
        f.write(f"Data Shape: {data.shape}\n\n")
        f.write("Data Types:\n")
        f.write(str(data.dtypes) + "\n\n")
        f.write("Summary Statistics:\n")
        f.write(str(data.describe()) + "\n\n")
        f.write("Missing Values:\n")
        f.write(str(data.isnull().sum()) + "\n\n")
        f.write("Data Quality Checks:\n")
        f.write(f"Negative rainfall values: {len(negative_rainfall)}\n")
        f.write(f"Unrealistic rainfall values (>1000 mm): {len(high_rainfall)}\n")
        f.write(f"Duplicate records: {duplicates}\n")
        f.write(f"Invalid dates: {invalid_dates}\n")

def target_variable_analysis(data):
    """Analyze the target variable: rainfall_sum."""
    print("\n--- Target Variable Analysis: rainfall_sum ---")
    plt.figure(figsize=(10, 6))
    sns.histplot(data['rainfall_sum'], bins=50, kde=True, color='steelblue')
    plt.title('Distribution of Daily Rainfall')
    plt.xlabel('Rainfall (mm)')
    plt.ylabel('Frequency')
    plt.savefig(os.path.join(OUTPUT_PATH, 'rainfall_distribution.png'))
    plt.close()
    
    plt.figure(figsize=(10, 6))
    sns.histplot(np.log1p(data['rainfall_sum']), bins=50, kde=True, color='steelblue')
    plt.title('Log-Transformed Distribution of Daily Rainfall')
    plt.xlabel('Log(Rainfall + 1)')
    plt.ylabel('Frequency')
    plt.savefig(os.path.join(OUTPUT_PATH, 'log_rainfall_distribution.png'))
    plt.close()
    
    print("Rainfall distribution plots saved.")

def univariate_analysis(data):
    """Perform univariate analysis on key variables."""
    print("\n--- Univariate Analysis ---")
    station_counts = data['station_name_x'].value_counts()  # Adjusted column name
    print("Number of records per station:\n", station_counts)
    
    plt.figure(figsize=(12, 6))
    station_counts.plot(kind='bar', color='steelblue')
    plt.title('Number of Records per Station')
    plt.xlabel('Station Name')
    plt.ylabel('Number of Records')
    plt.xticks(rotation=45, ha='right')
    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_PATH, 'station_distribution.png'))
    plt.close()
    
    monthly_mean = data.groupby('month')['rainfall_sum'].mean()
    plt.figure(figsize=(10, 6))
    monthly_mean.plot(kind='bar', color='steelblue')
    plt.title('Average Monthly Rainfall')
    plt.xlabel('Month')
    plt.ylabel('Average Rainfall (mm)')
    plt.xticks(ticks=range(12), labels=['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec'])
    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_PATH, 'monthly_rainfall.png'))
    plt.close()
    
    print("Univariate analysis plots saved.")

def monthly_rainfall_trend(data):
    """Plot monthly rainfall trend over years."""
    print("\n--- Monthly Rainfall Trend ---")
    monthly_trend = data.groupby(['year', 'month'])['rainfall_sum'].mean().unstack()
    
    plt.figure(figsize=(14, 8))
    monthly_trend.plot(kind='line', cmap='Blues', linewidth=2)
    plt.title('Monthly Rainfall Trend Over Years')
    plt.xlabel('Year')
    plt.ylabel('Average Rainfall (mm)')
    plt.legend(['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec'], title='Month')
    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_PATH, 'monthly_rainfall_trend.png'))
    plt.close()
    
    print("Monthly rainfall trend plot saved.")

def rainfall_distribution_by_station(data):
    """Plot rainfall distribution by station name."""
    print("\n--- Rainfall Distribution by Station ---")
    plt.figure(figsize=(14, 8))
    sns.boxplot(data=data, x='station_name_x', y='rainfall_sum', color='steelblue')  # Adjusted column name
    plt.title('Rainfall Distribution by Station')
    plt.xlabel('Station Name')
    plt.ylabel('Rainfall (mm)')
    plt.xticks(rotation=45, ha='right')
    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_PATH, 'rainfall_by_station.png'))
    plt.close()
    
    print("Rainfall distribution by station plot saved.")

def top_5_rainfall_stations(data):
    """Plot yearly rainfall for top 5 stations by total rainfall."""
    print("\n--- Top 5 Rainfall Stations Over Time ---")
    station_totals = data.groupby('station_name_x')['rainfall_sum'].sum().sort_values(ascending=False)  # Adjusted column name
    top_5_stations = station_totals.head(5).index
    
    top_5_data = data[data['station_name_x'].isin(top_5_stations)]  # Adjusted column name
    yearly_totals = top_5_data.groupby(['year', 'station_name_x'])['rainfall_sum'].sum().unstack()  # Adjusted column name
    
    plt.figure(figsize=(14, 8))
    yearly_totals.plot(kind='line', cmap='Blues', linewidth=2)
    plt.title('Yearly Rainfall for Top 5 Stations')
    plt.xlabel('Year')
    plt.ylabel('Total Rainfall (mm)')
    plt.legend(title='Station Name')
    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_PATH, 'top_5_rainfall_stations.png'))
    plt.close()
    
    print("Top 5 rainfall stations plot saved.")

def scatter_month_vs_rainfall(data):
    """Scatter plot of month vs. rainfall_sum."""
    print("\n--- Scatter Plot: Month vs. Rainfall ---")
    plt.figure(figsize=(12, 6))
    sns.scatterplot(data=data, x='month', y='rainfall_sum', color='steelblue', alpha=0.5)
    plt.title('Month vs. Daily Rainfall')
    plt.xlabel('Month')
    plt.ylabel('Rainfall (mm)')
    plt.xticks(ticks=range(1, 13), labels=['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec'])
    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_PATH, 'month_vs_rainfall_scatter.png'))
    plt.close()
    
    print("Month vs. rainfall scatter plot saved.")

def bivariate_analysis(data):
    """Perform bivariate analysis between rainfall and other variables."""
    print("\n--- Bivariate Analysis ---")
    station_avg = data.groupby('station_name_x').agg({'rainfall_sum': 'mean', 'ele(meter)': 'first'}).reset_index()  # Adjusted column name
    plt.figure(figsize=(10, 6))
    sns.scatterplot(data=station_avg, x='ele(meter)', y='rainfall_sum', color='steelblue')
    plt.title('Average Rainfall vs. Elevation')
    plt.xlabel('Elevation (m)')
    plt.ylabel('Average Rainfall (mm)')
    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_PATH, 'rainfall_vs_elevation.png'))
    plt.close()
    
    plt.figure(figsize=(12, 6))
    sns.boxplot(data=data, x='month', y='rainfall_sum', color='steelblue')
    plt.title('Rainfall Distribution by Month')
    plt.xlabel('Month')
    plt.ylabel('Rainfall (mm)')
    plt.xticks(ticks=range(12), labels=['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec'])
    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_PATH, 'rainfall_by_month.png'))
    plt.close()
    
    print("Bivariate analysis plots saved.")

def multivariate_analysis(data):
    """Perform multivariate analysis, including correlations."""
    print("\n--- Multivariate Analysis ---")
    numerical_cols = data.select_dtypes(include=['float64', 'int64']).columns
    correlation_matrix = data[numerical_cols].corr()
    
    plt.figure(figsize=(12, 10))
    sns.heatmap(correlation_matrix, annot=True, cmap='rocket', fmt='.2f')
    plt.title('Correlation Matrix')
    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_PATH, 'correlation_matrix.png'))
    plt.close()
    
    selected_cols = ['rainfall_sum', 'ele(meter)', 'lat(deg)', 'lon(deg)']
    available_cols = [col for col in selected_cols if col in data.columns]
    if len(available_cols) > 1:
        sns.pairplot(data[available_cols], palette='Blues')
        plt.savefig(os.path.join(OUTPUT_PATH, 'pairplot.png'))
        plt.close()
    else:
        print("Insufficient columns for pairplot.")
    
    print("Multivariate analysis plots saved.")

def outlier_detection(data):
    """Detect outliers in rainfall data."""
    print("\n--- Outlier Detection ---")
    plt.figure(figsize=(10, 6))
    sns.boxplot(data['rainfall_sum'], color='steelblue')
    plt.title('Boxplot of Daily Rainfall')
    plt.xlabel('Rainfall (mm)')
    plt.savefig(os.path.join(OUTPUT_PATH, 'rainfall_boxplot.png'))
    plt.close()
    
    Q1 = data['rainfall_sum'].quantile(0.25)
    Q3 = data['rainfall_sum'].quantile(0.75)
    IQR = Q3 - Q1
    lower_bound = Q1 - 1.5 * IQR
    upper_bound = Q3 + 1.5 * IQR
    outliers = data[(data['rainfall_sum'] < lower_bound) | (data['rainfall_sum'] > upper_bound)]
    print(f"Number of outliers in rainfall_sum: {len(outliers)}")
    
    outliers.to_csv(os.path.join(OUTPUT_PATH, 'rainfall_outliers.csv'), index=False)
    print("Outlier detection completed.")

def missing_data_visualization(data):
    """Visualize missing data patterns."""
    print("\n--- Missing Data Visualization ---")
    plt.figure(figsize=(12, 6))
    sns.heatmap(data.isnull(), cbar=False, cmap='rocket')
    plt.title('Missing Data Heatmap')
    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_PATH, 'missing_data_heatmap.png'))
    plt.close()
    print("Missing data heatmap saved.")

def main():
    """Main function to execute EDA steps."""
    data = load_data()
    basic_understanding(data)
    target_variable_analysis(data)
    univariate_analysis(data)
    monthly_rainfall_trend(data)
    rainfall_distribution_by_station(data)
    top_5_rainfall_stations(data)
    scatter_month_vs_rainfall(data)
    bivariate_analysis(data)
    multivariate_analysis(data)
    outlier_detection(data)
    missing_data_visualization(data)
    print("EDA completed successfully.")

if __name__ == "__main__":
    main()

Preprocessed data loaded successfully.

--- Basic Understanding of Data ---
Data Shape: (215504, 21)

Data Types:
 gsid                         int64
station_id                   int64
station_name_x              object
district_x                  object
year                         int64
month                        int64
days                         int64
rainfall_sum               float64
unnamed:_8                 float64
unnamed:_9_x               float64
unnamed:_10                float64
s_n_                       float64
station_name_y              object
basin_office                object
types_of_station            object
district_y                  object
lat(deg)                   float64
lon(deg)                   float64
ele(meter)                 float64
unnamed:_9_y                object
date                datetime64[ns]
dtype: object

Summary Statistics:
                 gsid     station_id           year          month  \
count  215504.000000  215504.000000  215504.0

  func(x=vector, **plot_kwargs)
  func(x=vector, **plot_kwargs)
  func(x=vector, **plot_kwargs)
  func(x=vector, **plot_kwargs)
  func(x=x, y=y, **kwargs)
  func(x=x, y=y, **kwargs)
  func(x=x, y=y, **kwargs)
  func(x=x, y=y, **kwargs)
  func(x=x, y=y, **kwargs)
  func(x=x, y=y, **kwargs)
  func(x=x, y=y, **kwargs)
  func(x=x, y=y, **kwargs)
  func(x=x, y=y, **kwargs)
  func(x=x, y=y, **kwargs)
  func(x=x, y=y, **kwargs)
  func(x=x, y=y, **kwargs)


Multivariate analysis plots saved.

--- Outlier Detection ---
Number of outliers in rainfall_sum: 43018
Outlier detection completed.

--- Missing Data Visualization ---
Missing data heatmap saved.
EDA completed successfully.


<Figure size 1400x800 with 0 Axes>

<Figure size 1400x800 with 0 Axes>