# EDA Quick Reference (seaborn)

> 目标：在 30–90 分钟内把数据结构、质量、分布、关系、时间结构讲清楚，并产出可复用图表与统计量。

**约定**
- 主表：`df`
- 目标列：`TARGET`
- 时间列：`TIME_COL`（若存在）
- 资产/分组列：`GROUP_COL`（若存在）

本 notebook 只包含可直接复制的 EDA 常用片段；按需运行与删改。


In [None]:
# --- imports ---
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.model_selection import TimeSeriesSplit

# --- global display / plotting defaults ---
# Raise pandas' display limits: up to 200 columns, 160 characters per line.
pd.set_option("display.width", 160)
pd.set_option("display.max_columns", 200)

# Use seaborn's "whitegrid" theme for the figures drawn below.
sns.set_theme(style="whitegrid")


## 0. 载入数据与基础检查

- 读入后立刻检查：行列数、列类型、缺失、重复、时间范围（若存在）


In [None]:
# Load the main table into `df`: uncomment and edit ONE of the readers below.
# df = pd.read_csv("path/to/data.csv")
# df = pd.read_parquet("path/to/data.parquet")

# Displayed value of the cell: (n_rows, n_cols) followed by the first 3 rows.
df.shape, df.head(3)


In [None]:
# --- basic schema ---
# DataFrame.info() prints its report and returns None, so it is called
# directly; wrapping it in display() would only add a stray "None" line.
df.info()

# Summary statistics for every column (one row per column, first 30 shown).
display(df.describe(include="all").T.head(30))

# duplicates: number of rows that exactly repeat an earlier row
dup_rows = df.duplicated().sum()
dup_rows


In [None]:
# --- missingness ---
topk = 30

# Fraction of missing values per column, worst first.
na = df.isna().mean().sort_values(ascending=False)
na_top = na.head(topk)

# A bare expression in the middle of a cell is not rendered by the notebook,
# so the table is shown explicitly.
display(na_top)

# quick missingness plot (top-k)
plt.figure()
# Pin horizontal bars instead of relying on seaborn inferring the orientation
# from the label type (column labels are not guaranteed to be strings).
sns.barplot(x=na_top.values, y=na_top.index, orient="h")
plt.title(f"Missing rate (top {topk})")
plt.xlabel("missing rate")
plt.ylabel("")
plt.show()


## 1. 类型与切片

- 数值列 / 类别列 / 时间列（若存在）
- 仅在子集上画图（如抽样、过滤时间窗口）


In [None]:
# --- column buckets ---
# num_cols: every numeric column.
# cat_cols: object / category / bool columns.
# Columns of any other dtype end up in neither list.
num_cols = df.select_dtypes(include="number").columns.tolist()
cat_cols = df.select_dtypes(
    include=["object", "category", "bool"],
).columns.tolist()

# Preview: first 10 names of each bucket, then the size of each bucket.
(num_cols[:10], cat_cols[:10], len(num_cols), len(cat_cols))


In [None]:
# --- optional sampling for speed ---
# To plot on at most 200,000 randomly sampled rows (fixed seed), uncomment the
# next line and drop the plain assignment below it.
# df_plot = df.sample(min(len(df), 200_000), random_state=0) if len(df) > 200_000 else df
# NOTE: this binds `df_plot` to the very same object as `df`; no copy is made.
df_plot = df
# Displayed value of the cell: shape of the frame used for plotting.
df_plot.shape


## 2. 单变量分布（数值）

- `histplot`：形状、偏态、厚尾
- `boxplot`：离群点与稳健范围
- `kdeplot`：更平滑的形状对比


In [None]:
# --- choose a feature to inspect ---
# Defaults to the first numeric column; stays None when there is none.
COL = next(iter(num_cols), None)
COL


In [None]:
# Univariate view of the selected numeric feature; skipped when COL is None.
if COL is not None:
    # Histogram with 60 bins and a KDE curve drawn on top.
    _, hist_ax = plt.subplots()
    sns.histplot(data=df_plot, x=COL, bins=60, kde=True, ax=hist_ax)
    hist_ax.set_title(f"hist + kde: {COL}")
    plt.show()

    # Horizontal box plot of the same column.
    _, box_ax = plt.subplots()
    sns.boxplot(data=df_plot, x=COL, ax=box_ax)
    box_ax.set_title(f"box: {COL}")
    plt.show()


## 3. 目标列（若存在）

- 分布与离群点
- 与关键特征的关系（散点/回归线）


In [None]:
# --- set target column name ---
TARGET = "y"  # TODO: rename to the actual target column of `df`
# Displayed value of the cell: True when `df` has a column with that name.
TARGET in df.columns


In [None]:
# Univariate view of the target; skipped when `df` has no TARGET column.
if TARGET in df.columns:
    # Histogram with 60 bins and a KDE curve drawn on top.
    _, hist_ax = plt.subplots()
    sns.histplot(data=df_plot, x=TARGET, bins=60, kde=True, ax=hist_ax)
    hist_ax.set_title(f"target hist: {TARGET}")
    plt.show()

    # Horizontal box plot of the target.
    _, box_ax = plt.subplots()
    sns.boxplot(data=df_plot, x=TARGET, ax=box_ax)
    box_ax.set_title(f"target box: {TARGET}")
    plt.show()


## 4. 类别变量分布与分组差异

- `countplot`：类别频率（含不平衡）
- `boxplot/violinplot`：类别 vs 数值（含 target）


In [None]:
# --- countplot for a categorical column ---
# Defaults to the first categorical-like column; stays None when there is none.
CAT = next(iter(cat_cols), None)
CAT


In [None]:
# Frequency of the `topn` most common levels of CAT; skipped when CAT is None.
if CAT is not None:
    topn = 30
    # Counts per level, most frequent first (value_counts' default order).
    vc = df_plot[CAT].astype("object").value_counts().head(topn)

    # Figure height grows with the number of bars, with a floor of 3 inches.
    plt.figure(figsize=(8, max(3, 0.25 * len(vc))))
    # Pin horizontal bars: the labels may be booleans or other number-like
    # values, in which case seaborn cannot tell the category axis from the
    # count axis by type alone.
    sns.barplot(x=vc.values, y=vc.index, orient="h")
    plt.title(f"count: {CAT} (top {topn})")
    plt.xlabel("count")
    plt.ylabel("")
    plt.show()


In [None]:
# --- category vs numeric ---
if (CAT is not None) and (len(num_cols) > 0):
    # Prefer the target on the y-axis; otherwise use the first numeric column.
    # NOTE: this rebinds the notebook-level YCOL.
    YCOL = TARGET if TARGET in df.columns else num_cols[0]
    topn = 15

    # Keep only rows belonging to the `topn` most frequent levels.
    cat_values = df_plot[CAT].astype("object")
    topcats = cat_values.value_counts().head(topn).index
    d = df_plot[cat_values.isin(topcats)].copy()

    # Pin the plotted levels (most frequent first). A pandas `category`
    # column would otherwise still contribute its filtered-out levels as
    # empty slots on the x-axis.
    order = list(topcats)

    plt.figure(figsize=(10, 4))
    sns.boxplot(data=d, x=CAT, y=YCOL, order=order)
    plt.title(f"box: {YCOL} by {CAT} (top {topn})")
    plt.xticks(rotation=45, ha="right")
    plt.tight_layout()
    plt.show()

    plt.figure(figsize=(10, 4))
    # cut=0 stops each violin at the observed min/max of its group.
    sns.violinplot(data=d, x=CAT, y=YCOL, order=order, inner="quartile", cut=0)
    plt.title(f"violin: {YCOL} by {CAT} (top {topn})")
    plt.xticks(rotation=45, ha="right")
    plt.tight_layout()
    plt.show()


## 5. 双变量关系（数值 vs 数值）

- `scatterplot`：线性/非线性、簇结构、异方差
- `regplot`：趋势线（`lowess=True` 适合非线性；该选项依赖 `statsmodels`，需提前安装）


In [None]:
# --- pick x feature for bivariate plots ---
# y-axis: the target when present, otherwise the second numeric column.
YCOL = TARGET if TARGET in df.columns else (num_cols[1] if len(num_cols) > 1 else None)

# x-axis: the first numeric column that is not already used as YCOL.
# Taking num_cols[0] unconditionally would pick the target itself whenever the
# target is the first numeric column, and the bivariate plots would be skipped.
XCOL = next((c for c in num_cols if c != YCOL), None)
XCOL, YCOL


In [None]:
# Scatter + LOWESS trend of YCOL against XCOL; skipped unless both columns are
# set and distinct.
if (XCOL is not None) and (YCOL is not None) and (XCOL != YCOL):
    # Rows where either value is missing cannot be plotted.
    d = df_plot[[XCOL, YCOL]].dropna()
    # downsample if too large
    if len(d) > 200_000:
        d = d.sample(200_000, random_state=0)

    plt.figure()
    sns.scatterplot(data=d, x=XCOL, y=YCOL, s=10, alpha=0.4)
    plt.title(f"scatter: {YCOL} vs {XCOL}")
    plt.show()

    # The LOWESS fit gets much slower as the number of points grows, so it is
    # run on a smaller fixed-seed subsample than the scatter plot above.
    # Frames with at most `lowess_cap` rows are used unchanged.
    lowess_cap = 20_000
    d_fit = d.sample(lowess_cap, random_state=0) if len(d) > lowess_cap else d

    plt.figure()
    # NOTE: lowess=True needs statsmodels to be installed.
    sns.regplot(data=d_fit, x=XCOL, y=YCOL, scatter_kws=dict(s=10, alpha=0.3), lowess=True)
    plt.title(f"regplot lowess: {YCOL} vs {XCOL}")
    plt.show()


## 6. 相关性与共线性（数值列）

- `corr()` + `heatmap`
- 只对一部分列画热力图（避免太大）


In [None]:
# --- correlation heatmap ---
# Only the first (at most 30) numeric columns are used, to keep the figure readable.
k = min(30, len(num_cols))
sel = num_cols[:k]

# A correlation matrix needs at least two columns; with none selected the
# heatmap call would fail on an empty matrix.
if k >= 2:
    corr = df_plot[sel].corr()

    plt.figure(figsize=(10, 8))
    # center=0 puts the colormap midpoint at zero correlation.
    sns.heatmap(corr, center=0, square=True)
    plt.title(f"corr heatmap (first {k} numeric cols)")
    plt.tight_layout()
    plt.show()
else:
    print(f"corr heatmap skipped: need at least 2 numeric columns, found {k}")


## 7. Pairplot（小样本/小列数时）

- 小数据快速扫：分布 + 两两关系
- 大数据时先抽样


In [None]:
# --- pairplot ---
# Columns to plot: the target (when `df` has it) followed by the first four
# numeric columns, with duplicates removed and order preserved.
target_part = [TARGET] if TARGET in df.columns else []
cols = list(dict.fromkeys(target_part + num_cols[:4]))

# A pair plot needs at least two columns.
if len(cols) >= 2:
    d = df_plot[cols].dropna()
    max_rows = 20_000
    if len(d) > max_rows:
        # Fixed seed so the same subsample is drawn on every run.
        d = d.sample(max_rows, random_state=0)
    # corner=True draws only the lower triangle; histograms on the diagonal.
    sns.pairplot(d, corner=True, diag_kind="hist")
    plt.show()


## 8. 时间序列（若存在）

- 先确保时间列为 `datetime`
- 画总体走势、分组走势
- 检查缺口、重复时间戳、频率是否稳定


In [None]:
# --- set time column name (optional) ---
TIME_COL = "date"  # TODO: rename to the actual time column of `df`, if one exists
# Displayed value of the cell: True when `df` has a column with that name.
TIME_COL in df.columns


In [None]:
# Parse TIME_COL to datetime and sort the plotting frame chronologically.
if TIME_COL in df.columns:
    raw_time = df_plot[TIME_COL]
    parsed_time = pd.to_datetime(raw_time, errors="coerce")

    # errors="coerce" silently turns unparseable values into NaT. Report how
    # many were lost so they are not mistaken for originally-missing values.
    n_unparsed = int(parsed_time.isna().sum() - raw_time.isna().sum())
    if n_unparsed > 0:
        print(f"{TIME_COL}: {n_unparsed} value(s) could not be parsed as datetime -> NaT")

    # NOTE: while `df_plot` is still the same object as `df`, this assignment
    # converts the column on `df` as well.
    df_plot[TIME_COL] = parsed_time
    display(df_plot[[TIME_COL]].describe())
    df_plot = df_plot.sort_values(TIME_COL)


In [None]:
# Target over time plus a rolling mean; needs both TIME_COL and TARGET.
if (TIME_COL in df_plot.columns) and (TARGET in df_plot.columns):
    d = df_plot[[TIME_COL, TARGET]].dropna()

    # Raw series; errorbar=None disables the confidence band.
    _, line_ax = plt.subplots(figsize=(12, 4))
    sns.lineplot(data=d, x=TIME_COL, y=TARGET, errorbar=None, ax=line_ax)
    line_ax.set_title(f"time line: {TARGET}")
    plt.tight_layout()
    plt.show()

    # rolling mean
    d2 = d.set_index(TIME_COL).sort_index()
    if d2.index.inferred_type == "datetime64":
        # Datetime index: 7-day time-based window.
        roll = d2[TARGET].rolling("7D").mean()
    else:
        # Any other index: fixed window of 50 rows.
        roll = d2[TARGET].rolling(50).mean()

    _, roll_ax = plt.subplots(figsize=(12, 4))
    roll_ax.plot(roll.index, roll.values)
    roll_ax.set_title(f"rolling mean: {TARGET}")
    plt.tight_layout()
    plt.show()


In [None]:
# --- optional grouped time series ---
GROUP_COL = "asset"  # TODO: rename if exists

# Runs only when the time, group and target columns are all present.
required_cols = [TIME_COL, GROUP_COL, TARGET]
if all(c in df_plot.columns for c in required_cols):
    d = df_plot[required_cols].dropna()

    # keep top groups for readability
    topn = 8
    group_values = d[GROUP_COL].astype("object")
    topg = group_values.value_counts().head(topn).index
    keep = d[GROUP_COL].astype("object").isin(topg)
    d = d[keep]

    # One line per kept group; errorbar=None disables the confidence band.
    _, ax = plt.subplots(figsize=(12, 4))
    sns.lineplot(data=d, x=TIME_COL, y=TARGET, hue=GROUP_COL, errorbar=None, ax=ax)
    ax.set_title(f"grouped time line: {TARGET} by {GROUP_COL} (top {topn})")
    plt.tight_layout()
    plt.show()


## 9. 切分与泄露检查（时间序列场景）

- 时间切分：训练集严格早于测试集
- `TimeSeriesSplit` 用于交叉验证（仅示例）


In [None]:
# --- TimeSeriesSplit template ---
# The frame is sorted by TIME_COL inside this cell, so `df` itself does not
# need to be pre-sorted. Skipped when TIME_COL is absent.
if TIME_COL in df.columns:
    df2 = df.sort_values(TIME_COL).reset_index(drop=True)
    tscv = TimeSeriesSplit(n_splits=5)

    # One tuple per fold:
    # (fold, first/last train position, first/last test position, train size, test size).
    # Positions refer to rows of the sorted frame `df2`.
    splits = [
        (fold, tr.min(), tr.max(), te.min(), te.max(), len(tr), len(te))
        for fold, (tr, te) in enumerate(tscv.split(df2), 1)
    ]

    # Translate positions into TIME_COL values so each fold can be checked for
    # "train ends before test starts".
    t = df2[TIME_COL]
    split_summary = pd.DataFrame(
        [
            {
                "fold": fold,
                "train_rows": n_tr,
                "test_rows": n_te,
                "train_from": t.iloc[tr_lo],
                "train_to": t.iloc[tr_hi],
                "test_from": t.iloc[te_lo],
                "test_to": t.iloc[te_hi],
            }
            for fold, tr_lo, tr_hi, te_lo, te_hi, n_tr, n_te in splits
        ]
    )

    # An expression inside an `if` body is never rendered by the notebook, so
    # the summary is displayed explicitly.
    display(split_summary)


## 10. 离群点与裁剪（可选）

- 分位数裁剪（winsorize 风格）
- IQR 规则（更稳健）


In [None]:
# --- quantile clipping ---
# Column to clip: the target when present, otherwise the first numeric column.
# NOTE: this rebinds the notebook-level COL.
COL = TARGET if TARGET in df.columns else (num_cols[0] if len(num_cols) else None)

if COL is not None:
    # Clip bounds: the 0.1% and 99.9% quantiles of the column.
    lo, hi = df[COL].quantile([0.001, 0.999])
    # The clipped values go into a new column on `df`; the original column is kept.
    df[f"{COL}_clip"] = df[COL].clip(lower=lo, upper=hi)

    # An expression inside an `if` body is never rendered by the notebook, so
    # the bounds and the number of affected rows are printed explicitly.
    n_clipped = int(((df[COL] < lo) | (df[COL] > hi)).sum())
    print(f"{COL}: clipped to [{lo}, {hi}] ({n_clipped} row(s) affected)")


## 11. EDA 快速结论清单（写 presentation 用）

- 数据规模：行列数、时间跨度、分组数量
- 质量：缺失、重复、异常值、类型问题
- 目标：分布（偏态/厚尾/异常），是否需要变换
- 特征：与目标的关系、强相关/冗余特征
- 时间结构：趋势/季节性/突变、是否存在泄露风险
