# Generate Imbalanced Dataset

Generates synthetic imbalanced classification datasets for XGBoost scaling experiments.

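**Parameters** (via widgets or job params):
- `env`: Environment name (dev/prod)
- `run_mode`: `full` or `smoke` (smoke uses tiny data for quick validation)
- `json_params`: JSON string with additional config overrides
- `catalog`: Catalog name (default `brian_gen_ai`; can also be set by job or bundle variables)
- `schema`: Schema name (default `xgb_scaling`; can also be set by job or bundle variables)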

## Setup Widgets

In [None]:
# Declare the notebook widgets. A job or bundle variable passing a parameter
# with the same name overrides the default given here.
dbutils.widgets.text("env", "dev", "Environment")
dbutils.widgets.dropdown("run_mode", "full", ["full", "smoke"], "Run Mode")

# Remaining free-text widgets: (name, default, label), created in this order.
for widget_name, default_value, widget_label in (
    ("json_params", "{}", "JSON Parameters"),
    ("catalog", "brian_gen_ai", "Catalog"),
    ("schema", "xgb_scaling", "Schema"),
):
    dbutils.widgets.text(widget_name, default_value, widget_label)

## Import and Parse Parameters

In [None]:
import sys
import time

# Make the repo's `src` package importable. When deployed via DAB the repo is
# synced to the workspace, and this notebook sits at <repo>/notebooks/<name>,
# so the repo root is the notebook path minus its last two components.
import os
notebook_path = (
    dbutils.notebook.entry_point.getDbutils()
    .notebook()
    .getContext()
    .notebookPath()
    .get()
)
path_parts = notebook_path.split("/")
repo_root = "/".join(path_parts[:-2])  # strip "notebooks/<notebook_name>"
sys.path.insert(0, "/Workspace" + repo_root)

# Core logic lives in the repo's src package.
from src.main import run, build_exit_result
from src.config import DatasetConfig

# Read the widget values (job parameters take precedence over the defaults).
env, run_mode, json_params, catalog, schema = (
    dbutils.widgets.get(widget)
    for widget in ("env", "run_mode", "json_params", "catalog", "schema")
)

# Resolve the run configuration from the raw parameter strings.
config = run(
    env=env,
    run_mode=run_mode,
    json_params=json_params,
    catalog=catalog,
    schema=schema,
)

## Generate Dataset

In [None]:
from pyspark.sql.functions import (
    rand, randn, when, col, lit, floor, abs as spark_abs,
    concat_ws, pandas_udf, array, struct
)
from pyspark.sql.types import (
    FloatType, IntegerType, StringType, StructType, StructField, ArrayType
)
import pyspark.sql.functions as F

# Categorical feature templates, cycled over by column index so the generated
# data mixes cardinalities the way real-world fields do (country, status,
# category, ...). Per entry:
#   name        - descriptive label for the template
#   cardinality - number of distinct values the column can take
#   prefix      - used both in the column name and in each value string
#   weight      - label shift applied when the column is informative
CATEGORICAL_CONFIGS = [
    {"name": name, "cardinality": cardinality, "prefix": prefix, "weight": weight}
    for name, cardinality, prefix, weight in (
        ("binary", 2, "cat_bin", 0.15),
        ("low_card", 5, "cat_low", 0.10),
        ("medium_card", 20, "cat_med", 0.08),
        ("high_card", 100, "cat_hi", 0.05),
        ("very_high", 500, "cat_vhi", 0.03),
    )
]


def _column_seed(base_seed, column_index, stride):
    """Return the RNG seed for the generated column at ``column_index``.

    Column index 0 is the label, 1..n_numerical are the numerical features, and
    the categorical features follow. Spark's ``rand``/``randn`` seed each
    partition's generator with ``seed + partitionIndex``. With consecutive
    per-column seeds, column k in partition p+1 would reuse the random stream
    of column k+1 in partition p, duplicating runs of values across columns and
    rows. Spacing column seeds ``stride`` (>= the partition count) apart keeps
    every (column, partition) stream distinct.

    Note: changing this scheme changes the generated values for a given seed.
    """
    return base_seed + column_index * stride


def _add_numerical_features(df, n_numerical, n_informative, seed, seed_stride, batch_size):
    """Append float columns ``f0`` .. ``f{n_numerical - 1}`` drawn from ``randn``.

    The first ``n_informative`` columns get ``label * weight`` added so they
    correlate with the label; the rest are pure noise. Columns are added
    ``batch_size`` at a time with one ``select()`` per batch (rather than one
    ``withColumn()`` per column) to keep the logical plan small. Every 200
    columns the lineage is cut with an eager local checkpoint.
    """
    print(f"Generating {n_numerical} numerical features (batches of {batch_size})...")

    for batch_start in range(0, n_numerical, batch_size):
        batch_end = min(batch_start + batch_size, n_numerical)
        new_cols = []

        for i in range(batch_start, batch_end):
            value = randn(_column_seed(seed, i + 1, seed_stride))
            if i < n_informative:
                # Informative: label-1 rows get a mean shift between 0.5 and 1.85.
                weight = 0.5 + (i % 10) * 0.15
                value = value + col("label") * lit(weight)
            new_cols.append(value.cast(FloatType()).alias(f"f{i}"))

        # Select existing + new batch in one operation.
        df = df.select("*", *new_cols)
        print(f"  Batch {batch_start}-{batch_end - 1} done ({batch_end}/{n_numerical})")

        # Checkpoint periodically to cut the lineage and prevent driver OOM.
        if batch_end % 200 == 0 and batch_end < n_numerical:
            print(f"  [Checkpointing at {batch_end} features to cut lineage...]")
            df = df.localCheckpoint(eager=True)

    return df


def _add_categorical_features(df, n_categorical, first_column_index, seed, seed_stride, batch_size):
    """Append string categorical columns to ``df``.

    Column ``k`` uses ``CATEGORICAL_CONFIGS[k % len(CATEGORICAL_CONFIGS)]``. It is
    named ``"<prefix>_<k>"``, and its values look like ``"<prefix>_<j>"`` with
    ``j`` in ``[0, cardinality)``. The first ``max(1, n_categorical // 5)``
    columns are informative (their distribution depends on the label); the rest
    are uniform noise.

    Returns:
        ``(df, categorical_cols)``, where ``categorical_cols`` lists the new
        column names in creation order.
    """
    categorical_cols = []
    if n_categorical == 0:
        return df, categorical_cols

    print(f"\nGenerating {n_categorical} categorical features (batches of {batch_size})...")
    n_informative_cat = max(1, n_categorical // 5)  # ~20% are informative

    for batch_start in range(0, n_categorical, batch_size):
        batch_end = min(batch_start + batch_size, n_categorical)
        new_cols = []

        for cat_offset in range(batch_start, batch_end):
            cfg = CATEGORICAL_CONFIGS[cat_offset % len(CATEGORICAL_CONFIGS)]
            cardinality = cfg["cardinality"]
            prefix = cfg["prefix"]
            col_name = f"{prefix}_{cat_offset}"
            feature_seed = _column_seed(seed, first_column_index + cat_offset, seed_stride)

            if cat_offset < n_informative_cat:
                # Informative: |randn + label * weight|, scaled and wrapped into
                # [0, cardinality), so the label shifts the category distribution.
                cat_int = (
                    floor(
                        spark_abs(randn(feature_seed) + col("label") * lit(cfg["weight"]))
                        * lit(cardinality / 3.0)
                    ) % lit(cardinality)
                ).cast(IntegerType())
            else:
                # Noise: uniform over [0, cardinality).
                cat_int = floor(rand(feature_seed) * lit(cardinality)).cast(IntegerType())

            # Convert to a string value like "cat_bin_0", "cat_low_3".
            new_cols.append(
                concat_ws("_", lit(prefix), cat_int.cast(StringType())).alias(col_name)
            )
            categorical_cols.append(col_name)

        # Select existing + new batch.
        df = df.select("*", *new_cols)
        print(f"  Cat batch {batch_start}-{batch_end - 1} done ({batch_end}/{n_categorical})")

    print(f"  Completed all {n_categorical} categorical features")
    return df, categorical_cols


def generate_imbalanced_dataset(
    spark,
    total_rows: int,
    n_features: int,
    n_informative: int,
    n_categorical: int,
    minority_ratio: float,
    seed: int,
):
    """Generate a large, imbalanced binary-classification dataset with Spark.

    The generator is built for 100M+ rows:
    - Columns are added in batches of 50 with one ``select()`` per batch. One
      ``withColumn()`` per column would build a huge logical plan that can OOM
      the driver.
    - Wide numerical blocks are cut with eager local checkpoints.
    - The result is repartitioned once at the end for the Delta write.

    Columns of the returned DataFrame, in order:
    - ``f0`` .. ``f{n_numerical - 1}``: ``FloatType`` features from ``randn``.
      The first ``n_informative`` are shifted by the label; the rest are noise.
    - Categorical string columns ``<prefix>_<k>``, cycling through
      ``CATEGORICAL_CONFIGS``. About 20% of them (at least one) are informative.
    - ``label``: int; 1 (minority class) with probability ``minority_ratio``.

    Args:
        spark: active SparkSession.
        total_rows: number of rows to generate (>= 0).
        n_features: total feature count (numerical + categorical).
        n_informative: number of numerical features correlated with the label.
        n_categorical: number of categorical features, in [0, n_features].
        minority_ratio: probability of label 1, in [0, 1].
        seed: base RNG seed; per-column seeds are derived from it.

    Returns:
        A Spark DataFrame. It is lazy, except that eager checkpoints may already
        have materialized intermediate numerical columns.

    Raises:
        ValueError: if total_rows is negative, n_categorical is outside
            [0, n_features], or minority_ratio is outside [0, 1].
    """
    # Validate before touching Spark; a negative numerical-feature count would
    # otherwise silently produce a dataset without the requested shape.
    if total_rows < 0:
        raise ValueError(f"total_rows must be >= 0, got {total_rows}")
    if not 0 <= n_categorical <= n_features:
        raise ValueError(
            f"n_categorical must be between 0 and n_features ({n_features}), "
            f"got {n_categorical}"
        )
    if not 0.0 <= minority_ratio <= 1.0:
        raise ValueError(f"minority_ratio must be in [0, 1], got {minority_ratio}")

    n_numerical = n_features - n_categorical

    print(f"Generating: {total_rows:,} rows x {n_features} features")
    print(f"  Numerical features: {n_numerical}")
    print(f"  Categorical features: {n_categorical}")
    print(f"  Informative features: {n_informative}")
    print(f"  Minority ratio: {minority_ratio:.1%}")

    # Target ~2M rows per partition for good parallelism.
    n_partitions = max(8, total_rows // 2_000_000)
    print(f"  Partitions: {n_partitions}")
    print()

    # Space per-column seeds by the partition count so that no two
    # (column, partition) random streams coincide (see _column_seed).
    seed_stride = n_partitions

    # Step 1: base dataframe with id + imbalanced label (1 = minority class).
    df = spark.range(0, total_rows, numPartitions=n_partitions)
    df = df.withColumn(
        "label",
        when(rand(_column_seed(seed, 0, seed_stride)) < minority_ratio, lit(1))
        .otherwise(lit(0))
        .cast(IntegerType()),
    )

    BATCH_SIZE = 50  # Columns per select() to keep the plan size manageable.

    # Step 2: numerical features, in batches.
    df = _add_numerical_features(df, n_numerical, n_informative, seed, seed_stride, BATCH_SIZE)

    # Step 3: categorical features, in batches; their seed indices continue
    # after the label (0) and the numerical columns (1..n_numerical).
    df, categorical_cols = _add_categorical_features(
        df, n_categorical, n_numerical + 1, seed, seed_stride, BATCH_SIZE
    )

    # Step 4: order as numerical, categorical, label (drops id).
    numerical_cols = [f"f{i}" for i in range(n_numerical)]
    df = df.select(numerical_cols + categorical_cols + ["label"])

    # Final repartition for Delta write parallelism.
    target_write_partitions = max(n_partitions, 400 if total_rows >= 50_000_000 else 200)
    print(f"\nRepartitioning to {target_write_partitions} partitions for Delta write...")
    df = df.repartition(target_write_partitions)

    return df

In [None]:
# Build the dataset DataFrame. start_time also anchors the total runtime that
# is reported after the Delta write.
start_time = time.time()

generator_args = {
    field: getattr(config, field)
    for field in (
        "total_rows",
        "n_features",
        "n_informative",
        "n_categorical",
        "minority_ratio",
        "seed",
    )
}
df = generate_imbalanced_dataset(spark=spark, **generator_args)

# Mostly plan-building time. For wide datasets the generator's eager
# localCheckpoint calls mean some data may already have been computed here.
generation_time = time.time() - start_time
print(f"\nDataFrame created in {generation_time:.1f}s (lazy - not materialized yet)")

## Write to Delta Table

In [None]:
# Persist to Delta, replacing any previous contents and schema of the table.
print(f"Writing to: {config.output_table}")

write_start = time.time()

writer = (
    df.write
    .format("delta")
    .mode("overwrite")
    .option("overwriteSchema", "true")
)
writer.saveAsTable(config.output_table)

write_time = time.time() - write_start
total_time = time.time() - start_time  # generation + write

print(f"Write completed in {write_time:.1f}s")
print(f"Total time: {total_time:.1f}s ({total_time / 60:.1f} minutes)")

## Validate Results

In [None]:
# Re-read the persisted table and compare its row count to the configured size.
df_check = spark.table(config.output_table)

row_count = df_check.count()
print(
    "\n".join(
        (
            f"Rows written: {row_count:,}",
            f"Expected:     {config.total_rows:,}",
            f"Match: {row_count == config.total_rows}",
        )
    )
)

In [None]:
# Per-class row counts, ordered by label value.
label_counts = df_check.groupBy("label").count().orderBy("label").collect()

print("\nClass distribution:")
class_distribution = {}
for bucket in label_counts:
    label_value, n_rows = bucket["label"], bucket["count"]
    share = n_rows / row_count * 100
    role = "Majority" if label_value != 1 else "Minority"
    print(f"  Label {label_value} ({role}): {n_rows:,} ({share:.2f}%)")
    class_distribution[label_value] = n_rows

# With both classes present, report label-0 rows per label-1 row.
if len(label_counts) == 2:
    label0_rows, label1_rows = (bucket["count"] for bucket in label_counts)
    print(f"\nImbalance ratio: {label0_rows / label1_rows:.1f}:1")

In [None]:
# Quick sample — show a few numerical + categorical columns plus the label.
# Categorical column names come from CATEGORICAL_CONFIGS (the same list the
# generator cycles through): column k is "<prefix>_<k>". This keeps the sample
# in sync with the generator instead of re-declaring the prefixes here.
n_numerical = config.n_features - config.n_categorical
sample_num = [f"f{i}" for i in range(min(3, n_numerical))]
sample_cat = [
    f"{CATEGORICAL_CONFIGS[k % len(CATEGORICAL_CONFIGS)]['prefix']}_{k}"
    for k in range(min(3, config.n_categorical))
]

sample_cols = sample_num + sample_cat + ["label"]
# sample_cols includes "label", so subtract it from the feature count.
print(f"Sample data (first 5 rows, {len(sample_cols) - 1} selected features + label):")
df_check.select(sample_cols).show(5, truncate=False)

## Exit with Result

In [None]:
# Summarize the run for the job output. result_json is handed straight to
# dbutils.notebook.exit below (presumably a JSON string — confirm in src.main).
result_json = build_exit_result(
    status="ok",
    config=config,
    class_distribution=class_distribution,
    row_count=row_count,
    duration_seconds=total_time,
)

print("\nNotebook result:", result_json, sep="\n")

# Exit with the result so callers can fetch it via the Databricks API.
dbutils.notebook.exit(result_json)