# Data Exploration

This section exercises the data-processing utilities on the `base` table and the `static_0` table group.

In [None]:
import sys
sys.path.insert(0, "..")

import polars as pl
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from src.data_processing import (
    load_table_group,
    downcast_dtypes,
    drop_high_missing_cols,
    drop_high_cardinality_string_cols,
    preprocess_table,
    get_table_info,
)

# Shared plot styling: seaborn "whitegrid" theme with the "muted" palette,
# then 120-dpi figures drawn on a white background. The rcParams are set
# after the theme so they are not overwritten by it.
sns.set_theme(style="whitegrid", palette="muted")
plt.rcParams["figure.dpi"] = 120
plt.rcParams["figure.facecolor"] = "white"

In [None]:
# Data directory handed to every load_table_group call in this notebook.
# It is a relative path, so it resolves against the current working directory.
DATA_PATH = "../data/"

## Load Base Table

In [None]:
# Load the "train" split of the base table group. Later cells read the
# `case_id`, `target` and `WEEK_NUM` columns from this frame.
base = load_table_group(DATA_PATH, "base", split="train")
print(f"Base table shape: {base.shape}")
# Last expression of the cell, so the notebook renders the preview.
base.head()

In [None]:
# Summarise the raw base table. The value returned by get_table_info is the
# cell's last expression, so the notebook displays it.
get_table_info(base)

In [None]:
# Run preprocess_table on the base table, then report the resulting shape
# and summary for comparison with the raw table.
base_processed = preprocess_table(base)
# NOTE(review): the leading "\n" presumably separates this line from output
# printed by preprocess_table itself — confirm against src.data_processing.
print(f"\nAfter preprocessing: {base_processed.shape}")
get_table_info(base_processed)

## Load Static_0 Table

This table has multiple chunks (static_0_0, static_0_1, etc.) that need to be concatenated.

In [None]:
# Load the "train" split of the static_0 table group. Per the note above,
# the group is stored as several chunks; load_table_group is expected to
# return them concatenated — TODO confirm in src.data_processing.
static_0 = load_table_group(DATA_PATH, "static_0", split="train")
print(f"Static_0 table shape: {static_0.shape}")
# Last expression of the cell, so the notebook renders the preview.
static_0.head()

In [None]:
# Snapshot of static_0 before any preprocessing. `info_before` is kept as a
# baseline for the downcast memory comparison in a later cell.
info_before = get_table_info(static_0)
summary_lines = [
    f"Shape: {info_before['shape']}",
    f"Memory: {info_before['estimated_memory_mb']:.2f} MB",
    f"Dtype counts: {info_before['dtype_counts']}",
    f"Columns with >50% missing: {len(info_before['columns_with_high_missing'])}",
]
print("\n".join(summary_lines))

In [None]:
# Exercise downcast_dtypes and compare estimated memory against the
# pre-downcast snapshot stored in `info_before`.
static_0_downcasted = downcast_dtypes(static_0)
info_downcasted = get_table_info(static_0_downcasted)

mem_before_mb = info_before["estimated_memory_mb"]
mem_after_mb = info_downcasted["estimated_memory_mb"]
reduction_pct = (1 - mem_after_mb / mem_before_mb) * 100

print(f"Memory before downcast: {mem_before_mb:.2f} MB")
print(f"Memory after downcast: {mem_after_mb:.2f} MB")
print(f"Memory reduction: {reduction_pct:.1f}%")

In [None]:
# Exercise drop_high_missing_cols on the raw static_0 frame and report how
# the column count changes with a 0.98 threshold.
n_cols_before = static_0.shape[1]
print(f"Columns before: {n_cols_before}")

static_0_no_missing = drop_high_missing_cols(static_0, threshold=0.98)
n_cols_after = static_0_no_missing.shape[1]
print(f"Columns after (threshold=0.98): {n_cols_after}")

In [None]:
# Exercise drop_high_cardinality_string_cols on the raw static_0 frame with
# max_unique=10_000 and report the remaining column count.
# NOTE(review): presumably string columns with more than `max_unique`
# distinct values are removed — confirm against src.data_processing.
static_0_no_high_card = drop_high_cardinality_string_cols(static_0, max_unique=10_000)
print(f"Columns after dropping high-cardinality strings: {static_0_no_high_card.shape[1]}")

In [None]:
# Run preprocess_table (the full pipeline) on the raw static_0 frame, then
# report the final shape and summary.
static_0_processed = preprocess_table(static_0)
print(f"\nFinal shape after full preprocessing: {static_0_processed.shape}")
# Last expression of the cell, so the notebook displays the summary.
get_table_info(static_0_processed)

---

# Exploratory Data Analysis

## (a) Target Distribution & Temporal Drift

In [None]:
# --- Class balance ---------------------------------------------------------
# One row per distinct target value with its frequency, ordered by target.
target_counts = base["target"].value_counts().sort("target").to_pandas()
total = target_counts["count"].sum()
# Sum the matching rows instead of taking `.values[0]`: when the frame holds
# no target == 1 rows this gives a 0% default rate rather than an IndexError.
n_defaults = target_counts.loc[target_counts["target"] == 1, "count"].sum()
default_rate = n_defaults / total

fig, axes = plt.subplots(1, 2, figsize=(13, 4))

# Left panel: bar chart of the target counts, annotated with count and share.
ax = axes[0]
bars = ax.bar(
    target_counts["target"].astype(str),
    target_counts["count"],
    color=["#4C72B0", "#DD8452"],
    edgecolor="black",
    linewidth=0.5,
)
for bar, count in zip(bars, target_counts["count"]):
    ax.text(
        bar.get_x() + bar.get_width() / 2,
        bar.get_height(),
        f"{count:,}\n({count/total:.1%})",
        ha="center",
        va="bottom",
        fontsize=9,
    )
ax.set_xlabel("Target")
ax.set_ylabel("Count")
ax.set_title(f"Target Distribution (default rate = {default_rate:.2%})")
# Plain tick labels avoid scientific notation on large counts.
ax.ticklabel_format(axis="y", style="plain")

# --- Weekly default rate ---------------------------------------------------
# `weekly` is reused by the case-volume plot in the next cell.
weekly = (
    base.group_by("WEEK_NUM")
    .agg([
        pl.col("target").mean().alias("default_rate"),
        pl.col("target").count().alias("n_cases"),
    ])
    .sort("WEEK_NUM")
    .to_pandas()
)

# Right panel: weekly default rate plus a least-squares linear trend, whose
# slope is shown in the legend as a simple drift indicator.
ax = axes[1]
ax.plot(weekly["WEEK_NUM"], weekly["default_rate"], color="#4C72B0", linewidth=1.2)
z = np.polyfit(weekly["WEEK_NUM"], weekly["default_rate"], 1)
ax.plot(
    weekly["WEEK_NUM"],
    np.polyval(z, weekly["WEEK_NUM"]),
    "--",
    color="#DD8452",
    linewidth=1.5,
    label=f"trend (slope={z[0]:.5f})",
)
ax.set_xlabel("WEEK_NUM")
ax.set_ylabel("Default Rate")
ax.set_title("Default Rate by Week (temporal drift)")
ax.legend(fontsize=9)

fig.tight_layout()
plt.show()

In [None]:
# Weekly case volume, drawn from the `weekly` aggregate built in the
# previous cell. Bars are one week wide so they touch.
fig, ax = plt.subplots(figsize=(13, 3))
ax.bar(
    weekly["WEEK_NUM"],
    weekly["n_cases"],
    width=1.0,
    color="#4C72B0",
    edgecolor="none",
)
ax.set(xlabel="WEEK_NUM", ylabel="Number of Cases", title="Case Volume by Week")
# Plain tick labels avoid scientific notation on large counts.
ax.ticklabel_format(axis="y", style="plain")
fig.tight_layout()
plt.show()

## (b) Missing Rates Across Table Groups

In [None]:
# Table groups scanned for missing values.
TABLE_GROUPS = [
    "base", "static_0", "static_cb_0",
    "person_1", "person_2",
    "applprev_1", "applprev_2",
    "credit_bureau_a_1", "credit_bureau_a_2",
    "credit_bureau_b_1", "credit_bureau_b_2",
    "debitcard_1", "deposit_1", "other_1",
    "tax_registry_a_1", "tax_registry_b_1", "tax_registry_c_1",
]

# One record per (table group, column) holding that column's null fraction.
missing_summary = []
# Groups that contributed nothing, reported below so gaps in the summary are
# visible instead of being dropped silently.
skipped_groups = []
for tg in TABLE_GROUPS:
    try:
        df = load_table_group(DATA_PATH, tg, split="train")
    except FileNotFoundError:
        skipped_groups.append(tg)
        continue
    n = df.height
    if n == 0:
        # A missing rate is undefined without rows (and would divide by zero).
        skipped_groups.append(tg)
        continue
    # Null counts for all columns in a single pass: {column name: null count}.
    null_counts = df.null_count().row(0, named=True)
    missing_summary.extend(
        {"table_group": tg, "column": col, "missing_rate": n_null / n}
        for col, n_null in null_counts.items()
        if col != "case_id"  # join key, not a feature
    )

if skipped_groups:
    print(f"Skipped table groups (not found or empty): {skipped_groups}")

# The explicit schema keeps the frame well-formed, with a `missing_rate`
# column to filter on, even if no table could be loaded.
missing_df = pl.DataFrame(
    missing_summary,
    schema={
        "table_group": pl.Utf8,
        "column": pl.Utf8,
        "missing_rate": pl.Float64,
    },
)
print(f"Total feature columns across all tables: {missing_df.height}")
for threshold in (0.5, 0.9, 0.98):
    n_over = missing_df.filter(pl.col("missing_rate") > threshold).height
    print(f"Columns with >{threshold:.0%} missing: {n_over}")

In [None]:
# Roll the column-level missing rates up to one row per table group, with
# the groups that have the highest average missing rate first.
per_table_aggs = [
    pl.col("missing_rate").mean().alias("avg_missing"),
    pl.col("missing_rate").max().alias("max_missing"),
    (pl.col("missing_rate") > 0.98).sum().alias("cols_gt_98pct"),
    pl.len().alias("n_cols"),
]
table_miss = (
    missing_df.group_by("table_group")
    .agg(per_table_aggs)
    .sort("avg_missing", descending=True)
    .to_pandas()
)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))
ax_avg, ax_sparse = axes

# Left: average missing rate per group, with a dashed guide at 50%.
ax_avg.barh(
    table_miss["table_group"],
    table_miss["avg_missing"],
    color="#4C72B0",
    edgecolor="none",
)
ax_avg.set_xlabel("Average Missing Rate")
ax_avg.set_title("Average Missing Rate per Table Group")
ax_avg.invert_yaxis()
ax_avg.axvline(0.5, color="grey", linestyle="--", linewidth=0.8, alpha=0.7)

# Right: number of nearly-empty (>98% missing) columns per group. Groups
# with at least one such column are drawn in orange.
colors = []
for n_sparse in table_miss["cols_gt_98pct"]:
    colors.append("#DD8452" if n_sparse > 0 else "#4C72B0")
ax_sparse.barh(
    table_miss["table_group"],
    table_miss["cols_gt_98pct"],
    color=colors,
    edgecolor="none",
)
ax_sparse.set_xlabel("Number of Columns")
ax_sparse.set_title("Columns with >98% Missing per Table Group")
ax_sparse.invert_yaxis()

fig.tight_layout()
plt.show()

In [None]:
# Columns (from any table group) whose missing rate exceeds 90%, worst first.
top_missing = (
    missing_df.filter(pl.col("missing_rate") > 0.90)
    .sort("missing_rate", descending=True)
)
print(f"Columns with >90% missing ({top_missing.height} total):")
# Only the 30 worst columns are printed; the header above gives the full count.
print(top_missing.head(30))

## (c) Feature Drift Detection

Compare numeric feature distributions in **early weeks** (WEEK_NUM 0–30) vs **late weeks** (WEEK_NUM 61–91).
For each feature we measure:
- **Relative shift in mean**: `|mean_late - mean_early| / (std_overall + ε)`
- **Relative shift in std**: `|std_late - std_early| / (std_overall + ε)`
- **Shift in missing rate**: `|miss_late - miss_early|`

Features with large shifts are candidates for dropping to improve model stability.

In [None]:
# Attach WEEK_NUM (held in the base table) to every static_0 row. The left
# join keeps all static_0 rows; rows with no match in `base` get a null
# WEEK_NUM and so fall into neither window below.
static_with_week = static_0.join(
    base.select("case_id", "WEEK_NUM"), on="case_id", how="left"
)

# Every numeric dtype qualifies. Testing `is_numeric()` rather than a fixed
# whitelist also picks up narrow or unsigned types (Int8, Int16, UInt32, ...),
# e.g. when the table has been downcast. `case_id` is the join key, not a
# feature.
numeric_cols = [
    name
    for name, dtype in static_0.schema.items()
    if name != "case_id" and dtype.is_numeric()
]
print(f"Numeric columns to analyze: {len(numeric_cols)}")

# Window boundaries (inclusive) on WEEK_NUM.
EARLY_MAX = 30
LATE_MIN = 61

early = static_with_week.filter(pl.col("WEEK_NUM") <= EARLY_MAX)
late = static_with_week.filter(pl.col("WEEK_NUM") >= LATE_MIN)
# Report the real upper bound of the late window rather than a hard-coded week.
last_week = static_with_week["WEEK_NUM"].max()
print(f"Early weeks (0-{EARLY_MAX}): {early.height:,} rows")
print(f"Late weeks ({LATE_MIN}-{last_week}): {late.height:,} rows")

In [None]:
# Guards the normalisation against a zero overall standard deviation.
EPS = 1e-9


def _numeric_window_stats(frame, cols):
    """Return mean, std and null count for every column in *cols* in one pass.

    Values are cast to Float64 before the mean/std are taken. The result maps
    ``"<col>__mean"``, ``"<col>__std"`` and ``"<col>__nulls"`` to plain Python
    values (``None`` where polars yields null). An empty *cols* gives ``{}``.
    """
    if not cols:
        return {}
    aggs = []
    for name in cols:
        as_float = pl.col(name).cast(pl.Float64)
        aggs.extend([
            as_float.mean().alias(f"{name}__mean"),
            as_float.std().alias(f"{name}__std"),
            pl.col(name).null_count().alias(f"{name}__nulls"),
        ])
    return frame.select(aggs).row(0, named=True)


# Missing rates divide by the window sizes, so fail with a clear message
# instead of a bare ZeroDivisionError when a window has no rows.
if early.height == 0 or late.height == 0:
    raise ValueError(
        "Drift analysis needs rows in both windows "
        f"(early={early.height}, late={late.height}); "
        "check EARLY_MAX / LATE_MIN and the WEEK_NUM join."
    )

# Standard deviation over the whole table, used to normalise the shifts.
overall_stats = static_0.select([
    pl.col(c).cast(pl.Float64).std().alias(f"{c}__std") for c in numeric_cols
])

# One aggregation per window instead of several Series passes per column.
early_stats = _numeric_window_stats(early, numeric_cols)
late_stats = _numeric_window_stats(late, numeric_cols)

drift_records = []
for col in numeric_cols:
    std_all = overall_stats[f"{col}__std"][0]
    mean_e = early_stats[f"{col}__mean"]
    mean_l = late_stats[f"{col}__mean"]
    # Skip columns with no usable values overall or in either window.
    if std_all is None or mean_e is None or mean_l is None:
        continue
    scale = float(std_all) + EPS

    # A window std can be null (fewer than two values); treat that as 0.
    std_e = early_stats[f"{col}__std"] or 0
    std_l = late_stats[f"{col}__std"] or 0
    miss_e = early_stats[f"{col}__nulls"] / early.height
    miss_l = late_stats[f"{col}__nulls"] / late.height

    mean_shift = abs(mean_l - mean_e) / scale
    std_shift = abs(std_l - std_e) / scale
    miss_shift = abs(miss_l - miss_e)

    drift_records.append({
        "column": col,
        "mean_early": round(mean_e, 4),
        "mean_late": round(mean_l, 4),
        "mean_shift": round(mean_shift, 4),
        "std_shift": round(std_shift, 4),
        "miss_early": round(miss_e, 4),
        "miss_late": round(miss_l, 4),
        "miss_shift": round(miss_shift, 4),
    })

# The explicit schema keeps the sort (and later filters) valid even when no
# column produced a record.
drift_df = pl.DataFrame(
    drift_records,
    schema={
        "column": pl.Utf8,
        "mean_early": pl.Float64,
        "mean_late": pl.Float64,
        "mean_shift": pl.Float64,
        "std_shift": pl.Float64,
        "miss_early": pl.Float64,
        "miss_late": pl.Float64,
        "miss_shift": pl.Float64,
    },
).sort("mean_shift", descending=True)
print(f"Analyzed {drift_df.height} numeric features for drift")
# Last expression of the cell, so the notebook renders the top of the ranking.
drift_df.head(20)

In [None]:
top_n = 25

# drift_df is already ordered by mean_shift, so its head is the mean-drift
# ranking; the other two rankings need their own sort.
top_drift = drift_df.head(top_n).to_pandas()
top_std = drift_df.sort("std_shift", descending=True).head(top_n).to_pandas()
top_miss = drift_df.sort("miss_shift", descending=True).head(top_n).to_pandas()

# One entry per panel: (ranking, metric column, bar colour, x label, title).
panels = [
    (top_drift, "mean_shift", "#DD8452", "Normalised Mean Shift",
     f"Top {top_n} Features by Mean Drift"),
    (top_std, "std_shift", "#55A868", "Normalised Std Shift",
     f"Top {top_n} Features by Std Drift"),
    (top_miss, "miss_shift", "#8172B2", "Δ Missing Rate",
     f"Top {top_n} Features by Missing-Rate Shift"),
]

fig, axes = plt.subplots(1, 3, figsize=(17, 6))
for ax, (ranking, metric, bar_color, x_label, title) in zip(axes, panels):
    ax.barh(ranking["column"], ranking[metric], color=bar_color, edgecolor="none")
    ax.set_xlabel(x_label)
    ax.set_title(title)
    # Largest shift at the top of each panel.
    ax.invert_yaxis()

fig.tight_layout()
plt.show()

In [None]:
# A feature is flagged when ANY of its shifts exceeds the matching threshold.
MEAN_SHIFT_THRESHOLD = 0.3
STD_SHIFT_THRESHOLD = 0.3
MISS_SHIFT_THRESHOLD = 0.1

drift_flagged = drift_df.filter(
    (pl.col("mean_shift") > MEAN_SHIFT_THRESHOLD)
    | (pl.col("std_shift") > STD_SHIFT_THRESHOLD)
    | (pl.col("miss_shift") > MISS_SHIFT_THRESHOLD)
).sort("mean_shift", descending=True)

print(f"Features flagged for drift (any criterion): {drift_flagged.height}")

# Overlay early vs late histograms for the six flagged features with the
# largest mean shift.
top_6 = drift_flagged.head(6)["column"].to_list()
if top_6:
    fig, axes = plt.subplots(2, 3, figsize=(15, 7))
    axes = axes.flatten()
    for ax, col in zip(axes, top_6):
        vals_e = early[col].drop_nulls().cast(pl.Float64).to_numpy()
        vals_l = late[col].drop_nulls().cast(pl.Float64).to_numpy()
        ax.set_title(col, fontsize=9)
        ax.tick_params(labelsize=7)

        # Clip the shared bin range to the pooled 1st-99th percentiles so
        # outliers do not flatten the histograms. Both are computed in one call.
        pooled = np.concatenate([vals_e, vals_l])
        if pooled.size:
            lo, hi = np.nanpercentile(pooled, [1, 99])
        else:
            lo = hi = np.nan
        if not (np.isfinite(lo) and np.isfinite(hi)):
            # No finite values to bin: annotate the panel rather than passing
            # NaN bin edges to hist().
            ax.text(0.5, 0.5, "no finite values", ha="center", va="center",
                    transform=ax.transAxes, fontsize=8)
            continue
        if hi <= lo:
            # (Near-)constant feature: widen the range so the bins have a
            # non-zero width and the density stays finite.
            lo, hi = lo - 0.5, hi + 0.5

        bins = np.linspace(lo, hi, 50)
        ax.hist(vals_e, bins=bins, alpha=0.5, density=True, label="early", color="#4C72B0")
        ax.hist(vals_l, bins=bins, alpha=0.5, density=True, label="late", color="#DD8452")
        ax.legend(fontsize=7)
    # Hide the unused panels when fewer than six features were flagged.
    for spare_ax in axes[len(top_6):]:
        spare_ax.set_visible(False)
    fig.suptitle("Distribution Comparison: Early vs Late Weeks (top drifted features)", fontsize=11)
    fig.tight_layout()
    plt.show()

## (d) Candidate Drift-Prone Features to Drop

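Features are flagged as drift-prone if **any** of these hold:
- Normalised mean shift > 0.3
- Normalised std shift > 0.3
- Missing rate shift > 0.1

The final candidate list additionally includes every `static_0` column whose overall missing rate exceeds 98%.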

In [None]:
# static_0 columns that are almost entirely null (>98% missing).
static_0_sparse = missing_df.filter(
    (pl.col("table_group") == "static_0") & (pl.col("missing_rate") > 0.98)
)
high_missing_cols = static_0_sparse["column"].to_list()

# static_0 numeric columns flagged by any of the drift criteria.
drift_prone_cols = drift_flagged["column"].to_list()

# Union of both lists, de-duplicated and sorted alphabetically.
candidates_to_drop = sorted({*drift_prone_cols, *high_missing_cols})

print(f"Drift-prone features (static_0): {len(drift_prone_cols)}")
print(f"High-missing features (>98%, static_0): {len(high_missing_cols)}")
print(f"Combined unique candidates to drop: {len(candidates_to_drop)}")
print()
print("Candidate features to drop:")

# Reason labels in display order, each paired with the columns it applies to.
reason_lookup = (
    ("drift", set(drift_prone_cols)),
    (">98% missing", set(high_missing_cols)),
)
for col in candidates_to_drop:
    reasons = [label for label, members in reason_lookup if col in members]
    print(f"  {col:45s} [{', '.join(reasons)}]")

In [None]:
# Emit the candidate list as a ready-to-paste Python literal. Despite the
# name, it contains the high-missing columns as well as the drift-flagged ones.
listing = ["DRIFT_PRONE_FEATURES = ["]
listing.extend(f'    "{col}",' for col in candidates_to_drop)
listing.append("]")
print("\n".join(listing))