# Module 2: Data Splitting Strategies

**Training Objective:** Master data splitting techniques to prevent overfitting and data leakage in ML projects.

**Scope:**
- Random Split: Standard Train/Validation/Test split
- Stratified Sampling: Handling imbalanced datasets
- Time-based Split: Correctly splitting time-series data
- Cross-Validation concepts

## Context and Requirements

- **Training day:** Day 1 - Data Preparation Fundamentals
- **Notebook type:** Demo
- **Technical requirements:**
  - Databricks Runtime 14.x LTS or newer
  - Unity Catalog enabled
  - Permissions: CREATE TABLE, SELECT, MODIFY
- **Dependencies:** `01_EDA_and_Validation.ipynb` (creates `customer_bronze` table)
- **Execution time:** ~20 minutes

> **Important:** This notebook saves split data for use in subsequent modules.

## Theoretical Introduction

**Why do we split data?**

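If we evaluate a model on the same data it was trained on, it can score well simply by memorizing the answers, so the score tells us nothing about how it handles new data. A model that fits the training data closely but generalizes poorly is **overfitting**, and held-out data is how we detect it.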

**The Three Sets:**

| Set | Percentage | Purpose | When Used |
|-----|------------|---------|-----------|
| **Training** | 60-80% | Model learns patterns (weights) | During training |
| **Validation** | 10-20% | Tune hyperparameters | During development |
| **Test** | 10-20% | Final evaluation ("exam") | Only once at the end |

**Key Concepts:**

| Technique | When to Use | Problem it Solves |
|-----------|-------------|-------------------|
| **Random Split** | Default for most cases | Basic train/test separation |
| **Stratified Split** | Imbalanced classes (e.g., 1% fraud) | Preserves class distribution |
| **Time-based Split** | Time-series data | Prevents "time travel" data leakage |
| **Cross-Validation** | Limited data, need robust estimate | Every row is used for validation exactly once |

**Data Leakage Warning:**
> ⚠️ Never use test data for any preprocessing decisions (imputation, scaling parameters). Always fit transformers on training data only!
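
The warning above can be illustrated with a minimal plain-Python sketch. The salary values and the helper names `fit_standardizer` and `apply_standardizer` are made up for illustration and are not part of this notebook's pipeline. The point is only that scaling parameters are learned from the training values and then reused unchanged on the test values.

```python
from statistics import mean, pstdev


def fit_standardizer(values):
    """Learn scaling parameters (mean, std) from the given values only."""
    mu = mean(values)
    sigma = pstdev(values) or 1.0  # fall back to 1.0 if all values are identical
    return mu, sigma


def apply_standardizer(values, mu, sigma):
    """Apply previously learned parameters; nothing is re-fitted here."""
    return [(v - mu) / sigma for v in values]


# Illustrative values (assumption, not taken from customer_bronze)
train_salary = [40_000.0, 50_000.0, 60_000.0]
test_salary = [55_000.0, 200_000.0]

# Parameters come from the training set only
mu, sigma = fit_standardizer(train_salary)

train_scaled = apply_standardizer(train_salary, mu, sigma)
test_scaled = apply_standardizer(test_salary, mu, sigma)  # reuses train parameters

print(f"mean={mu:.1f}, std={sigma:.1f}")
print("test (scaled):", [round(v, 2) for v in test_scaled])
```

Fitting on train and test together would let the extreme test value shift the mean and standard deviation, which is exactly the leakage the warning describes.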

## Per-User Isolation

Run the initialization script for per-user catalog and schema isolation:

In [None]:
%run ./00_Setup

**Load Data:**

In [None]:
# Load Data
df = spark.table("customer_bronze")

## Section 1: Train / Validation / Test Split

**Why three sets?**
As noted above, evaluating a model on its own training data rewards memorization and hides **overfitting**.
To prevent this, we use a "Hold-out" strategy:

1.  **Training Set (60-80%)**: The model sees this data and learns patterns (weights).
2.  **Validation Set (10-20%)**: Used to tune "Hyperparameters" (e.g., tree depth, learning rate). We evaluate the model here *during* development.
3.  **Test Set (10-20%)**: The "Final Exam". Used **only once** at the very end to estimate how the model will perform in the real world. We *never* tune based on this set.

In [None]:
# Random Split
train_df, val_df, test_df = df.randomSplit([0.6, 0.2, 0.2], seed=42)

print(f"Total: {df.count()}")
print(f"Train: {train_df.count()}")
print(f"Val:   {val_df.count()}")
print(f"Test:  {test_df.count()}")

# Save for next modules
# This is crucial! We save the split data so subsequent notebooks (Imputation, Feature Eng)
# work on the exact same Training set, preventing Data Leakage from Test data.
# Note: val_df is not persisted here; only the train and test tables are saved.
train_df.write.mode("overwrite").saveAsTable(f"{catalog_name}.{schema_name}.customer_train")
test_df.write.mode("overwrite").saveAsTable(f"{catalog_name}.{schema_name}.customer_test")
print("✅ Saved 'customer_train' and 'customer_test' tables.")


## Section 2: Stratified Sampling

**The Imbalance Problem:**
Imagine a Fraud Detection dataset where only 1% of transactions are fraud.
If we do a random split, the Test set might end up with **zero** fraud cases just by bad luck. A model that always predicts "not fraud" would then look perfect (100% accuracy) on that Test set but fail in production.
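
How likely is that bad luck? A rough estimate, assuming each row is assigned to the Test set independently with a fixed probability (a simplification of how a random split behaves); the function name and the row counts are illustrative:

```python
def prob_no_positives_in_test(n_positive, test_fraction):
    """Probability that none of n_positive rows lands in the test set,
    assuming each row goes to test independently with test_fraction."""
    return (1.0 - test_fraction) ** n_positive


# Illustrative counts of fraud rows in the whole dataset (assumption)
few = prob_no_positives_in_test(n_positive=5, test_fraction=0.2)
many = prob_no_positives_in_test(n_positive=50, test_fraction=0.2)

print(f"5 fraud rows:  {few:.3f}")
print(f"50 fraud rows: {many:.6f}")
```

With only a handful of positive rows the risk is substantial (roughly one in three for 5 rows and a 20% Test set), and it shrinks quickly as the number of positive rows grows.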

**Stratification** forces the split to respect the original distribution of classes (e.g., keeping roughly 1% fraud in each of Train, Validation, and Test).

In [None]:
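# Simulate a rare target variable 'is_vip' based on salary
from pyspark.sql.functions import when, col

df_strat = df.withColumn("is_vip", when(col("salary") > 150000, 1).otherwise(0))

# Check the class distribution before splitting
display(df_strat.groupBy("is_vip").count())

# Stratified split using sampleBy
# fractions maps each class value to the share of that class sampled into the TRAINING set (80%).
# sampleBy samples each row independently, so the resulting shares are approximate, not exact.
fractions = {0: 0.8, 1: 0.8}
train_strat = df_strat.stat.sampleBy("is_vip", fractions, seed=42)

# exceptAll keeps duplicate rows; subtract() applies DISTINCT semantics and would silently drop them
test_strat = df_strat.exceptAll(train_strat)

print("Train Distribution:")
display(train_strat.groupBy("is_vip").count())

print("Test Distribution:")
display(test_strat.groupBy("is_vip").count())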
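
To make the idea of stratification concrete outside Spark, here is a small plain-Python sketch that splits each class separately so the class proportions carry over to both sets. The function `stratified_split` and the toy rows are hypothetical and only illustrate the technique; this is not what `sampleBy` does internally.

```python
import random
from collections import defaultdict


def stratified_split(rows, label_of, train_fraction=0.8, seed=42):
    """Split rows into (train, test), cutting each class at train_fraction."""
    rng = random.Random(seed)

    # Group rows by class label
    by_class = defaultdict(list)
    for row in rows:
        by_class[label_of(row)].append(row)

    train, test = [], []
    for label in sorted(by_class):
        members = by_class[label]
        rng.shuffle(members)  # random order within the class
        cut = round(len(members) * train_fraction)
        train.extend(members[:cut])
        test.extend(members[cut:])
    return train, test


# Toy data (assumption): 100 customers, 10% of them VIP
rows = [{"id": i, "is_vip": 1 if i < 10 else 0} for i in range(100)]
train_rows, test_rows = stratified_split(rows, lambda r: r["is_vip"])

print("train VIP share:", sum(r["is_vip"] for r in train_rows) / len(train_rows))
print("test VIP share: ", sum(r["is_vip"] for r in test_rows) / len(test_rows))
```

Because every class is cut at the same fraction, both sets keep the 10% VIP share of the toy data, which a purely random split only achieves on average.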

## Section 3: Cross-Validation (Concept)

In Cross-Validation (k-fold), we split the data into $k$ folds. We train $k$ times, each time using $k-1$ folds for training and 1 fold for validation.

*Note: In Spark ML, `CrossValidator` wraps an estimator and performs the fold splitting automatically during fitting, so we usually do not split the DataFrame into folds by hand.*

```python
# Concept Code (We will use this in the Pipeline notebook - Module 6)
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

# cv = CrossValidator(estimator=pipeline, estimatorParamMaps=paramGrid, evaluator=evaluator, numFolds=3)
```
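
The k-fold procedure described above can be sketched in plain Python by generating row indices for each fold. The helper `k_fold_indices` is hypothetical and uses contiguous folds for readability; in practice rows are usually shuffled or assigned to folds at random first.

```python
def k_fold_indices(n_rows, k):
    """Yield (train_idx, val_idx) pairs; each row is in validation exactly once."""
    if not 2 <= k <= n_rows:
        raise ValueError("k must be between 2 and n_rows")

    # Spread the remainder over the first folds so sizes differ by at most 1
    fold_sizes = [n_rows // k + (1 if i < n_rows % k else 0) for i in range(k)]

    start = 0
    for size in fold_sizes:
        val_idx = list(range(start, start + size))
        train_idx = list(range(0, start)) + list(range(start + size, n_rows))
        yield train_idx, val_idx
        start += size


folds = list(k_fold_indices(n_rows=10, k=3))
for i, (train_idx, val_idx) in enumerate(folds, start=1):
    print(f"fold {i}: train={train_idx} val={val_idx}")
```

Each of the $k$ iterations trains on $k-1$ folds and validates on the remaining one; averaging the $k$ validation scores gives a more stable estimate than a single hold-out split.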

## Section 4: Time-based Split

**The Time Travel Problem (Data Leakage):**
In time-series problems (e.g., Stock Price, Sales Forecasting), the order of data matters.
If we do a random split, the model might learn from "future" data (e.g., sales in December) to predict "past" data (sales in January of the same year). In production the future is never available at prediction time, so such an evaluation is overly optimistic.

**Solution:**
We must split by time.
- **Train:** Oldest data (e.g., Jan - Oct).
- **Test:** Newest data (e.g., Nov - Dec).

In [None]:
from pyspark.sql.functions import to_date, col

# Ensure registration_date is a DATE column
df_time = df.withColumn("registration_date", to_date("registration_date"))

# Dynamic split date: the date that separates the oldest ~80% of rows from the newest ~20%.
# We compute the 80th percentile on Unix timestamps and convert the result back to a date.
split_ts = df_time.selectExpr("percentile_approx(to_unix_timestamp(registration_date), 0.8)").collect()[0][0]
split_date = spark.sql(f"select to_date(from_unixtime({split_ts}))").collect()[0][0]

print(f"Dynamic Split Date (80% cutoff): {split_date}")

train_time = df_time.filter(col("registration_date") < split_date)
test_time = df_time.filter(col("registration_date") >= split_date)

print(f"Train (Historical): {train_time.count()}")
print(f"Test (Recent): {test_time.count()}")


## Best Practices

### 🎯 Splitting Strategy Guide:

| Scenario | Recommended Approach | Rationale |
|----------|---------------------|-----------|
| General classification/regression | 60/20/20 random split | Standard approach |
| Imbalanced classes (<5% minority) | Stratified split | Preserves class distribution |
| Time-series / forecasting | Time-based split | Prevents data leakage |
| Small dataset (<1000 rows) | Cross-validation (5-10 folds) | Uses all data efficiently |
| Large dataset (>1M rows) | Simple random split | Stratification less critical unless the minority class is very rare |

### ⚠️ Common Mistakes to Avoid:

1. **Using test data for feature engineering** → Data leakage
2. **Shuffling time-series before split** → Invalid evaluation
3. **Not setting random seed** → Non-reproducible results
4. **Using validation set for final evaluation** → Overly optimistic results
5. **Forgetting to stratify imbalanced datasets** → Unrepresentative splits

### 💡 Pro Tips:

- Always set `seed=42` (or any fixed number) for reproducibility
- Save split data as tables for pipeline consistency
- Document your split ratios and strategy
- For production: consider rolling window validation for time-series
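
The rolling window idea from the last tip can be sketched as follows. The helper `time_series_splits` and its parameters are hypothetical. With `window=None` the training window expands over time; with a fixed `window` it rolls forward and drops the oldest periods. In both cases every training period precedes every test period.

```python
def time_series_splits(n_periods, initial_train, horizon, window=None):
    """Return a list of (train_periods, test_periods) index lists in time order."""
    if initial_train < 1 or horizon < 1:
        raise ValueError("initial_train and horizon must be positive")

    splits = []
    train_end = initial_train
    while train_end + horizon <= n_periods:
        # Expanding window starts at 0; rolling window keeps only the last `window` periods
        train_start = 0 if window is None else max(0, train_end - window)
        train = list(range(train_start, train_end))
        test = list(range(train_end, train_end + horizon))
        splits.append((train, test))
        train_end += horizon
    return splits


months = ["Jan", "Feb", "Mar", "Apr", "May", "Jun",
          "Jul", "Aug", "Sep", "Oct", "Nov", "Dec"]

expanding = time_series_splits(len(months), initial_train=6, horizon=2)
rolling = time_series_splits(len(months), initial_train=6, horizon=2, window=6)

for train, test in rolling:
    print("train:", [months[i] for i in train], "-> test:", [months[i] for i in test])
```

Evaluating on several consecutive test windows shows how stable the model is over time, instead of relying on a single cutoff date.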

## Summary

### What we achieved:

- **Random Split**: Created train/validation/test sets with `randomSplit()`
- **Stratified Sampling**: Used `sampleBy()` to approximately preserve class distribution
- **Cross-Validation**: Understood k-fold concept for robust evaluation
- **Time-based Split**: Applied temporal split to prevent data leakage

### Key Takeaways:

| # | Principle |
|---|-----------|
| 1 | **Always split before preprocessing** - fit transformers on train only |
| 2 | **Use stratification for imbalanced data** - ensures representative splits |
| 3 | **Respect temporal order for time-series** - no "time travel" allowed |
| 4 | **Test set is sacred** - use it only once at the very end |
| 5 | **Save splits as tables** - ensures consistency across pipeline |

### Data Pipeline Status:

| Table | Created | Used By |
|-------|---------|---------|
| `customer_bronze` | Module 1 | This module |
| `customer_train` | ✅ This module | Modules 3-7 |
| `customer_test` | ✅ This module | Module 6 (evaluation) |

### Next Steps:

📚 **Next Module:** Module 3 - Data Imputing (handling missing values)

## Cleanup

Optionally remove demo tables created during exercises:

In [None]:
# Cleanup - remove demo tables created in this notebook

# ⚠️ WARNING: Do NOT delete customer_train and customer_test - they are needed for subsequent modules!

# Uncomment the lines below only after you have finished all subsequent modules:

# spark.sql(f"DROP TABLE IF EXISTS {catalog_name}.{schema_name}.customer_train")
# spark.sql(f"DROP TABLE IF EXISTS {catalog_name}.{schema_name}.customer_test")

# print("✅ All demo tables removed")

print("ℹ️ Cleanup disabled (uncomment code to remove demo tables)")