# Exploratory Data Analysis (EDA) - Criteo CTR Dataset

This notebook performs comprehensive EDA on the Criteo CTR dataset.

## Goals:
- Understand data distribution
- Analyze click rate
- Examine feature distributions
- Identify missing values
- Explore correlations

In [None]:
import sys
sys.path.append('..')

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pyspark.sql import functions as F

from src.config import Config
from src.utils.logging_utils import setup_logging
from src.utils.spark_utils import create_spark_session
from src.data.loader import CriteoDataLoader

# Setup: notebook-wide plotting defaults, applied once before any figure is drawn.
# Use seaborn's 'whitegrid' style for every subsequent plot.
sns.set_style('whitegrid')
# Default figure size (width, height) in inches; individual cells override it via figsize=...
plt.rcParams['figure.figsize'] = (12, 6)

# IPython magic: render matplotlib figures inline in the notebook output.
%matplotlib inline

## 1. Load Configuration and Data

In [None]:
# Read the project settings from the YAML file one directory above the notebook.
config = Config('../config/config.yaml')

# Configure logging at INFO level.
logger = setup_logging(level='INFO')

# Start the Spark session using the values stored under the config's 'spark' section.
spark_settings = config['spark']
spark = create_spark_session(
    app_name=spark_settings['app_name'],
    master=spark_settings['master'],
    executor_memory=spark_settings['executor_memory'],
    driver_memory=spark_settings['driver_memory'],
)

print(f"Spark version: {spark.version}")

In [None]:
import os

# The loader is built from the Spark session and the project config.
loader = CriteoDataLoader(spark, config)

# EDA runs on a sample rather than the raw data to keep the analysis fast.
sample_path = config['data']['sample_path']
print(f"Loading sample data from: {sample_path}")

# The sample is usable only when its path exists and the directory is non-empty;
# otherwise it is (re)built from the raw data and written to sample_path.
sample_available = os.path.exists(sample_path) and len(os.listdir(sample_path)) > 0
if sample_available:
    df = loader.load_parquet(sample_path)
else:
    print("Sample not found. Loading raw data and creating sample...")
    raw_path = config['data']['raw_path']
    df = loader.load_raw_data(raw_path)
    df = loader.create_sample(df, config['data']['sample_size'], sample_path)

print(f"Data loaded: {df.count():,} rows")

## 2. Basic Dataset Information

In [None]:
# Row count is a Spark action (df.count()); the column count comes from df.columns.
# Both values are reused by later cells.
n_rows, n_cols = df.count(), len(df.columns)

print(f"Dataset shape: {n_rows:,} rows × {n_cols} columns")
print(f"\nColumns: {df.columns}")

# Print the column names and types.
df.printSchema()

In [None]:
# Preview the first five rows, truncating long cell values.
df.show(n=5, truncate=True)

## 3. Target Variable Analysis (Click Rate)

In [None]:
# Click distribution: aggregate per class in Spark, then bring the small
# per-class count table into pandas.
click_dist = df.groupBy('click').count().orderBy('click').toPandas()

print("Click Distribution:")
print(click_dist)

# Percentages are rounded for display only; they are NOT used to derive the CTR.
total = click_dist['count'].sum()
click_dist['percentage'] = (click_dist['count'] / total * 100).round(2)

print("\nClick Distribution (with percentages):")
print(click_dist)

# Click-through rate from the raw counts, so later cells (e.g. the imbalance
# ratio in the summary) work with the unrounded value. Summing the matching
# rows yields 0 when the sample contains no click == 1 rows, instead of
# raising IndexError on an empty selection; an empty dataset yields 0.0
# instead of NaN.
n_clicks = click_dist.loc[click_dist['click'] == 1, 'count'].sum()
click_rate = float(n_clicks / total) if total > 0 else 0.0
print(f"\nClick-Through Rate (CTR): {click_rate:.4f} ({click_rate*100:.2f}%)")

In [None]:
# Visualize click distribution.
# Labels and colours are derived from the classes actually present in
# click_dist, so the cell still renders when the sample holds a single class
# (hard-coded two-element label lists make `pie` raise in that case).
class_labels = {0: 'No Click (0)', 1: 'Click (1)'}
class_colors = {0: 'blue', 1: 'red'}
click_values = click_dist['click'].tolist()
plot_labels = [class_labels.get(value, str(value)) for value in click_values]
plot_colors = [class_colors.get(value, 'gray') for value in click_values]

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Bar chart: passing the labels as categories lets matplotlib place and label the ticks.
axes[0].bar(plot_labels, click_dist['count'], color=plot_colors)
axes[0].set_xlabel('Click')
axes[0].set_ylabel('Count')
axes[0].set_title('Click Distribution (Count)')

# Pie chart: same data shown as shares of the total.
axes[1].pie(
    click_dist['count'],
    labels=plot_labels,
    autopct='%1.2f%%',
    colors=plot_colors
)
axes[1].set_title('Click Distribution (Percentage)')

plt.tight_layout()
plt.show()

print(f"\n⚠️ CLASS IMBALANCE: Only {click_rate*100:.2f}% of samples are clicks!")
print("This will require special handling (scale_pos_weight in XGBoost)")

## 4. Missing Value Analysis

In [None]:
# Calculate missing values for all columns.
# All null counts are computed in ONE Spark aggregation rather than one
# filter().count() job per column: count() only counts non-null values, and
# when() without otherwise() yields null for rows where the condition is false,
# so each expression counts exactly the rows where the column is null.
all_columns = list(df.columns)
if all_columns:
    null_count_row = df.agg(*[
        F.count(F.when(F.col(name).isNull(), F.lit(1))).alias(name)
        for name in all_columns
    ]).first()
    # Pair counts with columns by position so the lookup does not depend on name handling.
    null_counts = dict(zip(all_columns, null_count_row))
else:
    null_counts = {}

missing_data = [
    {
        'column': name,
        'missing_count': null_counts[name],
        # n_rows comes from the dataset-shape cell; guard against an empty dataset.
        'missing_percentage': (null_counts[name] / n_rows) * 100 if n_rows else 0.0,
    }
    for name in all_columns
]

# Explicit column names keep the filter below valid even when missing_data is empty.
missing_df = pd.DataFrame(
    missing_data,
    columns=['column', 'missing_count', 'missing_percentage']
)
missing_df = missing_df[missing_df['missing_count'] > 0].sort_values(
    'missing_percentage', ascending=False
)

print(f"Columns with missing values: {len(missing_df)} / {n_cols}")
print("\nTop 10 columns with most missing values:")
print(missing_df.head(10))

In [None]:
# Horizontal bar chart of the (up to) 20 columns with the highest missing percentage.
if len(missing_df) == 0:
    print("No missing values found in the dataset!")
else:
    top_missing = missing_df.head(20)
    bar_positions = range(len(top_missing))

    fig, ax = plt.subplots(figsize=(12, 6))
    ax.barh(bar_positions, top_missing['missing_percentage'])
    ax.set_yticks(bar_positions)
    ax.set_yticklabels(top_missing['column'])
    ax.set_xlabel('Missing Percentage (%)')
    ax.set_ylabel('Column')
    ax.set_title('Top 20 Columns with Missing Values')
    # missing_df is sorted descending; inverting the axis puts the first row at the top.
    ax.invert_yaxis()
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

## 5. Numerical Features Analysis

In [None]:
# Names of the numerical feature columns, as listed in the config.
numerical_cols = config['features']['numerical_cols']

# Pull a random subset into pandas for local analysis: the target is at most
# 100,000 rows, expressed as a fraction of n_rows for Spark's sample().
sample_size = min(100000, n_rows)
sampling_fraction = sample_size / n_rows
df_sample = df.sample(
    withReplacement=False,
    fraction=sampling_fraction,
    seed=42,
).toPandas()

print(f"Analyzing {len(df_sample):,} sampled rows for numerical features...")

In [None]:
# Descriptive statistics (pandas describe()) for the numerical columns of the sample.
numerical_summary = df_sample[numerical_cols].describe()

print("Numerical Features Summary:")
print(numerical_summary)

In [None]:
# Distribution plots for numerical features.
# The grid is sized from the number of columns instead of a fixed 5x3 layout,
# so more than 15 columns no longer index past the axes array and fewer
# columns do not leave a mostly empty figure. 13 columns still give 5 rows
# and a 15x20 figure.
grid_cols = 3
grid_rows = max(1, -(-len(numerical_cols) // grid_cols))  # ceiling division
fig, axes = plt.subplots(grid_rows, grid_cols, figsize=(15, 4 * grid_rows))
axes = axes.flatten()

for ax, col_name in zip(axes, numerical_cols):
    # Nulls cannot be binned, so drop them for the histogram only.
    values = df_sample[col_name].dropna()

    if len(values) > 0:
        ax.hist(values, bins=50, edgecolor='black', alpha=0.7)
        ax.set_title(f'{col_name} Distribution')
        ax.set_xlabel('Value')
        ax.set_ylabel('Frequency')
        ax.grid(True, alpha=0.3)
    else:
        # Column is entirely null in the sample: show a placeholder panel.
        ax.text(0.5, 0.5, 'No data', ha='center', va='center')
        ax.set_title(f'{col_name} Distribution (No Data)')

# Remove the unused panels at the end of the grid.
for unused_ax in axes[len(numerical_cols):]:
    fig.delaxes(unused_ax)

plt.tight_layout()
plt.show()

## 6. Categorical Features Analysis

In [None]:
# Names of the categorical feature columns, as listed in the config.
categorical_cols = config['features']['categorical_cols']

# Cardinality = number of distinct values per column. Each entry triggers its
# own distinct().count() on the Spark DataFrame.
cardinality_data = [
    {
        'column': name,
        'unique_values': df.select(name).distinct().count(),
    }
    for name in categorical_cols
]

# Highest-cardinality columns first.
cardinality_df = pd.DataFrame(cardinality_data)
cardinality_df = cardinality_df.sort_values('unique_values', ascending=False)

print("Categorical Features Cardinality:")
print(cardinality_df)

In [None]:
# Horizontal bar chart of distinct-value counts, one bar per categorical column.
bar_positions = range(len(cardinality_df))

fig, ax = plt.subplots(figsize=(12, 6))
ax.barh(bar_positions, cardinality_df['unique_values'])
ax.set_yticks(bar_positions)
ax.set_yticklabels(cardinality_df['column'])
ax.set_xlabel('Number of Unique Values')
ax.set_ylabel('Column')
ax.set_title('Categorical Features Cardinality')
# cardinality_df is sorted descending; inverting the axis puts the first row at the top.
ax.invert_yaxis()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

## 7. Correlation Analysis (Numerical Features)

In [None]:
# Pairwise correlations (pandas corr()) between the target and the numerical
# features, computed on the pandas sample.
correlation_cols = ['click'] + numerical_cols
correlation_matrix = df_sample[correlation_cols].corr()

# Heatmap of the full matrix; the colour scale is centred on zero.
heatmap_options = {
    'annot': False,
    'cmap': 'coolwarm',
    'center': 0,
    'square': True,
    'linewidths': 0.5,
}
fig, ax = plt.subplots(figsize=(14, 12))
sns.heatmap(correlation_matrix, ax=ax, **heatmap_options)
ax.set_title('Correlation Matrix (Numerical Features + Target)')
plt.tight_layout()
plt.show()

# The 'click' column of the matrix, sorted from most positive to most negative.
target_corr = correlation_matrix['click'].sort_values(ascending=False)
print("\nCorrelation with target (click):")
print(target_corr)

## 8. Summary and Key Findings

In [None]:
# Final report: restates values computed by the earlier cells (n_rows, n_cols,
# click_rate, missing_df, numerical_cols, categorical_cols, cardinality_df,
# target_corr), so those cells must have been run first.
print("="*60)
print("EDA SUMMARY")
print("="*60)
print(f"\n1. Dataset Size: {n_rows:,} rows × {n_cols} columns")
print(f"\n2. Click-Through Rate: {click_rate:.4f} ({click_rate*100:.2f}%)")
print(f"   - Class Imbalance: {(1-click_rate)*100:.2f}% No Click, {click_rate*100:.2f}% Click")
# The ratio is undefined when there are no clicks; avoid a ZeroDivisionError.
if click_rate > 0:
    print(f"   - Imbalance Ratio: {(1-click_rate)/click_rate:.1f}:1")
else:
    print("   - Imbalance Ratio: undefined (no clicks in the data)")
print(f"\n3. Missing Values: {len(missing_df)} columns have missing values")
if len(missing_df) > 0:
    # missing_df is sorted by missing percentage (descending), so row 0 is the worst column.
    print(f"   - Highest missing: {missing_df.iloc[0]['column']} ({missing_df.iloc[0]['missing_percentage']:.2f}%)")
print(f"\n4. Numerical Features: {len(numerical_cols)} features (I1-I13)")
print(f"\n5. Categorical Features: {len(categorical_cols)} features (C1-C26)")
print(f"   - Cardinality ranges from {cardinality_df['unique_values'].min():,} to {cardinality_df['unique_values'].max():,}")
print(f"\n6. Strongest correlations with target:")
# "Strongest" means largest magnitude: rank by absolute value so that strong
# negative correlations are reported too, while printing the signed value.
feature_corr = target_corr[target_corr.index != 'click']
strongest_first = feature_corr.abs().sort_values(ascending=False).index
top_corr = feature_corr.loc[strongest_first].head(3)
for feat, corr_val in top_corr.items():
    print(f"   - {feat}: {corr_val:.4f}")
print("\n" + "="*60)
# The takeaways below are fixed text; they are not derived from the values above.
print("\nKey Takeaways:")
print("- Severe class imbalance requires special handling (scale_pos_weight)")
print("- Missing values need imputation")
print("- High cardinality categoricals need encoding (count/target encoding)")
print("- Weak correlations suggest non-linear relationships (good for tree models)")
print("="*60)

In [None]:
# Stop Spark session once the analysis is finished.
# NOTE(review): cells that use `spark` or the Spark DataFrame `df` presumably
# cannot be re-run after this point without re-creating the session first.
spark.stop()
print("Spark session stopped.")