In [None]:
import pandas as pd
import numpy as np
from typing import List, Optional

def dataset_sanity_check(
    df: pd.DataFrame,
    target_col: str,
    interval_cols: Optional[List[str]] = None,
    max_unique_for_cat: int = 50
) -> None:
    """
    Print a quick sanity report for a raw DataFrame.

    The report covers:
     - shape and dtype counts
     - missing values per column
     - target distribution (as percentages)
     - a suggested numeric (interval) vs categorical feature split
     - a warning for very small (<100 rows) or very large (>200,000 rows) data

    Args:
      df: raw DataFrame
      target_col: name of the target column; never listed as a feature
      interval_cols: optional list of numeric feature names. Names that are
                     not columns of df are reported and ignored. If None,
                     numeric-dtype columns are used.
      max_unique_for_cat: only used when interval_cols is None. A non-numeric
                          column with more unique values than this is listed
                          as high-cardinality instead of categorical.

    Raises:
      ValueError: if target_col is not a column of df.
    """
    # Validate first so a bad call fails before a partial report is printed.
    if target_col not in df.columns:
        raise ValueError(f"Target column '{target_col}' not found in DataFrame!")

    print("─── DATASET SANITY CHECK ───")
    print(f"Shape: {df.shape[0]} rows × {df.shape[1]} cols")
    print("\nColumn dtypes:")
    print(df.dtypes.value_counts().to_string(), "\n")

    # Missing
    missing = df.isna().sum()
    if missing.any():
        print("Missing values:")
        print(missing[missing > 0].sort_values(), "\n")
    else:
        print("No missing values.\n")

    # Target distribution
    print(f"Target distribution ({target_col}):")
    target_pct = df[target_col].value_counts(normalize=True).mul(100).round(2)
    # Append only "%" to each value; the trailing blank line is emitted by
    # print itself. Putting "\n" inside the values shows up as a literal
    # backslash-n on every row of the printed Series.
    print((target_pct.astype(str) + "%").to_string(), "\n")

    # Feature type suggestions
    high_card_cols: List[str] = []
    if interval_cols is None:
        # Infer numeric features by dtype; the target is never a feature.
        numeric = set(df.select_dtypes(include=[np.number]).columns)
        interval_cols = [
            c for c in df.columns if c in numeric and c != target_col
        ]
        # Remaining columns are categorical unless their cardinality is too
        # high to be a plausible category (IDs, free text, timestamps, ...).
        cat_cols = []
        for c in df.columns:
            if c in numeric or c == target_col:
                continue
            if df[c].nunique() > max_unique_for_cat:
                high_card_cols.append(c)
            else:
                cat_cols.append(c)
    else:
        # User-provided: keep only names that actually exist in df.
        requested = list(interval_cols)
        known = set(df.columns)
        unknown = [c for c in requested if c not in known]
        if unknown:
            print(
                "⚠️  Warning: interval_cols not found in DataFrame "
                f"(ignored): {unknown}"
            )
        interval_cols = [
            c for c in requested if c in known and c != target_col
        ]
        chosen = set(interval_cols)
        cat_cols = [
            c for c in df.columns if c not in chosen and c != target_col
        ]

    print(f"Suggested numeric (interval) cols ({len(interval_cols)}): {interval_cols}")
    print(f"Suggested categorical cols ({len(cat_cols)}): {cat_cols}")
    if high_card_cols:
        print(
            f"High-cardinality non-numeric cols, more than {max_unique_for_cat} "
            f"unique values ({len(high_card_cols)}): {high_card_cols}"
        )
    print()

    # Warn about very small or very large datasets
    if df.shape[0] < 100:
        print("⚠️  Warning: fewer than 100 samples—GANs may overfit or collapse.")
    elif df.shape[0] > 200_000:
        print("⚠️  Warning: very large dataset—training may be slow.")

    print("──────────────────────────────\n")


# ─── Example usage ───
if __name__ == "__main__":
    # Load any dataset
    df = pd.read_csv("/content/drive/MyDrive/Katabatic/Data/Nursery/nursery.csv")

    # Run the check. In the recorded run of this cell every Nursery column is
    # object dtype, so there are no numeric columns to name; let the function
    # infer the numeric/categorical split instead.
    dataset_sanity_check(
        df,
        target_col="class",
        # you can also explicitly tell it which interval cols to use, e.g.
        # interval_cols=["some_numeric_col"] -- the names must exist in df
    )


─── DATASET SANITY CHECK ───
Shape: 12960 rows × 9 cols

Column dtypes:
object    9 

No missing values.

Target distribution (class):
class
not_recom     33.33%\n
priority      32.92%\n
spec_prior     31.2%\n
very_recom     2.53%\n
recommend      0.02%\n
Name: proportion, dtype: object
Suggested numeric (interval) cols (5): ['age', 'education-num', 'capital-gain', 'capital-loss', 'hours-per-week']
Suggested categorical cols (8): ['parents', 'has_nurs', 'form', 'children', 'housing', 'finance', 'social', 'health']

──────────────────────────────

