In [None]:
import pandas as pd
import numpy as np
from typing import List, Optional

def dataset_sanity_check(
    df: pd.DataFrame,
    target_col: str,
    interval_cols: Optional[List[str]] = None,
    max_unique_for_cat: int = 50
) -> None:
    """
    Print a quick sanity summary of df.

    Sections printed, in order:
     - shape, dtype counts
     - missing values per column
     - target distribution (percentage per class)
     - suggested interval vs categorical split of the feature columns
     - a warning when the row count is very small or very large

    Args:
      df: raw DataFrame
      target_col: name of the target column
      interval_cols: optional list of numeric feature names;
                     if None, every numeric-dtype column except the target
                     is used
      max_unique_for_cat: categorical columns with more distinct values than
                          this are additionally reported as high-cardinality
                          (they are still listed as categorical)

    Raises:
      ValueError: if target_col is not a column of df. The check runs first,
                  so nothing is printed for an invalid call.
    """
    # Validate up front so a bad target name does not leave half a report.
    if target_col not in df.columns:
        raise ValueError(f"Target column '{target_col}' not found in DataFrame!")

    n_rows, n_cols = df.shape
    print("─── DATASET SANITY CHECK ───")
    print(f"Shape: {n_rows} rows × {n_cols} cols")
    print("\nColumn dtypes:")
    print(df.dtypes.value_counts().to_string(), "\n")

    # Missing
    missing = df.isna().sum()
    if missing.any():
        print("Missing values:")
        print(missing[missing > 0].sort_values().to_string(), "\n")
    else:
        print("No missing values.\n")

    # Target distribution
    print(f"Target distribution ({target_col}):")
    pct = df[target_col].value_counts(normalize=True).mul(100).round(2)
    # Append only "%" to each value; a newline inside the element would be
    # rendered as a literal "\n" when the Series is printed.
    print((pct.astype(str) + "%").to_string(), "\n")

    # Feature type suggestions
    if interval_cols is None:
        # Infer numeric features by dtype; the target is never a feature.
        interval_cols = [
            c for c in df.select_dtypes(include=[np.number]).columns
            if c != target_col
        ]
    else:
        # User-provided: keep the list as given, but point out typos.
        interval_cols = list(interval_cols)
        unknown = [c for c in interval_cols if c not in df.columns]
        if unknown:
            print(f"Warning: interval_cols not found in DataFrame: {unknown}")

    excluded = set(interval_cols) | {target_col}
    cat_cols = [c for c in df.columns if c not in excluded]
    high_card = [c for c in cat_cols if df[c].nunique() > max_unique_for_cat]

    print(f"Suggested numeric (interval) cols ({len(interval_cols)}): {interval_cols}")
    print(f"Suggested categorical cols ({len(cat_cols)}): {cat_cols}\n")
    if high_card:
        print(
            f"High-cardinality categorical cols (> {max_unique_for_cat} unique): "
            f"{high_card}\n"
        )

    # Warn about very small or very large datasets
    if n_rows < 100:
        print("  Warning: fewer than 100 samples—GANs may overfit or collapse.")
    elif n_rows > 200_000:
        print(" Warning: very large dataset—training may be slow.")

    print("──────────────────────────────\n")


# ─── Example usage ───
if __name__ == "__main__":
    # Read the CSV to inspect.
    df = pd.read_csv("/content/drive/MyDrive/Katabatic/Data/Adult/adult-official.csv")

    # Print the summary, naming the interval columns explicitly instead of
    # letting the function infer them from dtypes.
    dataset_sanity_check(
        df=df,
        target_col="class",
        interval_cols=[
            "age",
            "education-num",
            "capital-gain",
            "capital-loss",
            "hours-per-week",
        ],
    )


─── DATASET SANITY CHECK ───
Shape: 48842 rows × 15 cols

Column dtypes:
object    15 

No missing values.

Target distribution (class):
class
<=50K    76.07%\n
>50K     23.93%\n
Name: proportion, dtype: object
Suggested numeric (interval) cols (5): ['age', 'education-num', 'capital-gain', 'capital-loss', 'hours-per-week']
Suggested categorical cols (9): ['workclass', 'fnlwgt', 'education', 'marital-status', 'occupation', 'relationship', 'race', 'sex', 'native-country']

──────────────────────────────



In [None]:
import pandas as pd
import numpy as np
import re
from sklearn.preprocessing import LabelEncoder, StandardScaler
import os

def interval_to_mid(x):
    """Convert interval strings to midpoints, e.g., '20-30' => 25.0"""
    if pd.isna(x): return x
    parts = re.findall(r'-?[\d\.]+|inf', str(x))
    if len(parts) == 2:
        lo, hi = parts
        lo = float(lo) if lo not in ("-inf", "inf") else float(hi)
        hi = float(hi) if hi not in ("-inf", "inf") else float(lo)
        return (lo + hi) / 2
    try:
        return float(x)
    except:
        return pd.NA

def preprocess_dataset(file_path, target_column, interval_cols=None, output_path="processed_data.csv"):
    """
    Load a CSV and turn it into a fully numeric table.

    Steps:
      1. read the CSV and strip whitespace from the column names;
      2. convert interval-string columns to float midpoints
         (unparseable entries become NaN);
      3. label-encode the target column;
      4. standardise the numeric feature columns;
      5. one-hot encode the remaining object/category columns;
      6. concatenate [numeric, one-hot, target] column-wise.

    Args:
      file_path: path of the CSV to load
      target_column: name of the target column (after whitespace stripping)
      interval_cols: columns holding interval strings such as '20-30';
                     if None, object columns containing a 'digits-digits'
                     pattern are detected automatically (the target is never
                     selected); pass [] to skip the conversion
      output_path: accepted for compatibility with existing callers;
                   this function does not write any file

    Returns:
      A 6-tuple (df_processed, numeric_cols, cat_cols, epochs, target_column,
      numeric_cols). numeric_cols appears twice because existing callers
      unpack six values. epochs is 100 below 50000 rows and 150 otherwise.

    Raises:
      KeyError: if target_column or any entry of interval_cols is not a
                column of the loaded file.
    """
    # Load dataset
    df = pd.read_csv(file_path)
    df.columns = df.columns.str.strip()

    if target_column not in df.columns:
        raise KeyError(f"Target column '{target_column}' not found in {file_path}")

    # Auto-detect interval columns if not provided
    if interval_cols is None:
        # A per-column loop also copes with a frame that has no object columns.
        interval_cols = [
            col for col in df.select_dtypes(include=['object']).columns
            if col != target_column
            and df[col].astype(str).str.contains(r'\d+\s*-\s*\d+').any()
        ]
    else:
        unknown = [col for col in interval_cols if col not in df.columns]
        if unknown:
            raise KeyError(f"Interval columns not found in {file_path}: {unknown}")

    def _midpoint_as_float(value):
        # interval_to_mid signals "unparseable" with pd.NA, which would leave
        # the column as object dtype; NaN keeps it a plain float column.
        mid = interval_to_mid(value)
        return np.nan if pd.isna(mid) else float(mid)

    # Convert intervals to midpoints
    for col in interval_cols:
        df[col] = df[col].map(_midpoint_as_float).astype("float64")

    # Separate types: every numeric dtype counts, not only int64/float64,
    # so no numeric feature is silently dropped from the output.
    numeric_cols = [
        col for col in df.select_dtypes(include=[np.number]).columns
        if col != target_column
    ]

    # Encode target
    le = LabelEncoder()
    df[target_column] = le.fit_transform(df[target_column])

    # Scale numeric columns if they exist
    if numeric_cols:
        scaler = StandardScaler()
        df[numeric_cols] = scaler.fit_transform(df[numeric_cols])
        df_numeric = df[numeric_cols].reset_index(drop=True)
    else:
        df_numeric = pd.DataFrame()

    # Identify categorical columns (the target is already integer-coded here)
    cat_cols = df.select_dtypes(include=['object', 'category']).columns.tolist()

    # One-hot encode categorical columns if they exist
    if cat_cols:
        df_cat = pd.get_dummies(df[cat_cols], drop_first=False).astype(int).reset_index(drop=True)
    else:
        df_cat = pd.DataFrame()

    # Target column
    df_target = df[[target_column]].reset_index(drop=True)

    # Final dataset
    df_processed = pd.concat([df_numeric, df_cat, df_target], axis=1)

    # Hyperparameter adjustment
    num_rows = df_processed.shape[0]
    epochs = 100 if num_rows < 50000 else 150
    print(f"📊 Dataset size: {num_rows} rows → EPOCHS = {epochs}")

    return df_processed, numeric_cols, cat_cols, epochs, target_column, numeric_cols
# Example usage:
data_path = "/content/drive/MyDrive/Katabatic/Data/Adult/adult-official.csv"
target = "class"
output_file = "processed_data.csv"
# An empty list means no column goes through interval-to-midpoint conversion
# (passing None would trigger auto-detection instead). An explicit list would be:
# ["age", "education-num", "capital-gain", "capital-loss", "hours-per-week"]
interval_columns = []

(
    df_processed,
    numeric_cols,
    cat_cols,
    EPOCHS,
    TARGET_COL,
    NUMERIC_COLS,
) = preprocess_dataset(
    file_path=data_path,
    target_column=target,
    interval_cols=interval_columns,
    output_path=output_file,
)


📊 Dataset size: 48842 rows → EPOCHS = 100
