# Module 2: Data Splitting Strategies

## Business Context: TechCorp HR Analytics (continued)

**Where We Are:**
In Module 1, we explored TechCorp's employee data and discovered patterns: salary depends on age (experience), country (market rates), and recruitment source. Now we need to split this data properly before training our Salary Prediction Model.

**Why This Matters for HR:**
If we test the model on data it has already "seen", it will appear accurate but fail on new hires. Proper splitting ensures:
1. **Fair evaluation:** Model is tested on truly unseen employees
2. **No cheating:** A model that merely memorizes individual records is exposed, because it is scored on records it has never seen
3. **Realistic performance:** Reflects how well the model will work on future hires

---

**Training Objective:** Master data splitting techniques that detect overfitting and prevent data leakage in ML projects.

**Scope:**
- Random Split: Standard Train/Validation/Test split
- Stratified Sampling: Handling imbalanced datasets
- Time-based Split: Correctly splitting time-series data
- Cross-Validation concepts

## Context and Requirements

- **Training day:** Day 1 - Data Preparation Fundamentals
- **Notebook type:** Demo
- **Technical requirements:**
  - Databricks Runtime 14.x LTS or newer
  - Unity Catalog enabled
  - Permissions: CREATE TABLE, SELECT, MODIFY
- **Dependencies:** `01_EDA_and_Validation.ipynb` (creates `customer_bronze` table)
- **Execution time:** ~20 minutes

> **Important:** This notebook saves split data for use in subsequent modules.

## Theoretical Introduction

**Why do we split data?**

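A model can memorize its training data ("cheat") instead of learning general patterns. This is called **Overfitting**. If we test the model on the same data we trained it on, the memorized answers look like accuracy and the overfitting goes unnoticed.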
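
To make the idea concrete, here is a deliberately naive "model" in plain Python that simply memorizes its training rows. It is a hypothetical illustration with made-up numbers, not part of the TechCorp pipeline. It scores perfectly on rows it has seen and much worse on rows it has not, which is exactly the gap a held-out set exposes.

```python
# Toy training data: (country, age) -> salary. Illustration only, not TechCorp data.
train = {("DE", 30): 60000, ("PL", 25): 35000, ("US", 40): 120000}
unseen = {("DE", 31): 62000, ("US", 41): 118000}

# Fallback guess for rows the "model" has never seen: the average training salary
fallback = sum(train.values()) / len(train)


def memorizer(key):
    # Return the memorized salary if this exact row was in training, else the fallback
    return train.get(key, fallback)


def mae(data):
    # Mean absolute error of the memorizer on a {features: salary} dict
    return sum(abs(memorizer(k) - v) for k, v in data.items()) / len(data)


print(f"Error on training rows: {mae(train):.0f}")
print(f"Error on unseen rows:   {mae(unseen):.0f}")
```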

**The Three Sets:**

| Set | Percentage | Purpose | When Used |
|-----|------------|---------|-----------|
| **Training** | 60-80% | Model learns patterns (weights) | During training |
| **Validation** | 10-20% | Tune hyperparameters | During development |
| **Test** | 10-20% | Final evaluation ("exam") | Only once at the end |

**Key Concepts:**

| Technique | When to Use | Problem it Solves |
|-----------|-------------|-------------------|
| **Random Split** | Default for most cases | Basic train/test separation |
| **Stratified Split** | Imbalanced classes (e.g., 1% fraud) | Preserves class distribution |
| **Time-based Split** | Time-series data | Prevents "time travel" data leakage |
| **Cross-Validation** | Limited data, need robust estimate | Uses all data for validation |

**Data Leakage Warning:**
> Never use test data for any preprocessing decisions (e.g., imputation values or scaling parameters). Always fit transformers on training data only!
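
A minimal plain-Python sketch of this rule, using mean imputation as the example transformer (the column values are made up for illustration): the fill value is learned from the training rows only and then reused unchanged on the test rows. In Spark ML the same pattern is `fit()` on the training DataFrame followed by `transform()` on each split.

```python
from statistics import mean

# Toy salary columns with missing values (None). Illustration only.
train_salaries = [52000, None, 61000, 58000, None, 75000]
test_salaries = [None, 190000, 64000]

# Fit: learn the fill value from the TRAINING rows only
fill_value = mean(s for s in train_salaries if s is not None)

# Transform: apply the same train-derived value to both sets
train_imputed = [fill_value if s is None else s for s in train_salaries]
test_imputed = [fill_value if s is None else s for s in test_salaries]

# Leaky version (do NOT do this): the test outlier drags the fill value upwards
leaky_fill = mean(s for s in train_salaries + test_salaries if s is not None)

print(f"Train-only fill value: {fill_value}")
print(f"Leaky fill value:      {leaky_fill:.0f}")
```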

## Per-User Isolation

Run the initialization script for per-user catalog and schema isolation:

In [0]:
%run ./00_Setup

**Load Data:**

In [0]:
# Load Data
df = spark.table(f"{catalog_name}.{schema_name}.customer_bronze")

## Basic Split: Train/Test (80/20)

### Theory
**Basic Split** divides data into two parts: training (80%) and test (20%).

### When to use:
- **Large datasets** (>10,000 samples)
- **Simple models** without hyperparameter tuning
- **Quick prototyping** and baseline models

### Advantages:
- **Simplicity** - easy to implement
- **Speed** - minimal overhead
- **Clarity** - clear structure

### Limitations:
- **No tuning** - without a validation set, there is no clean way to tune hyperparameters
- **Variance** - the score depends on which rows happen to land in the test set

In [0]:
# Split into training and test sets (~80/20; randomSplit assigns rows at random, so the sizes are approximate)
train_data, test_data = df.randomSplit([0.8, 0.2], seed=42)

**What is `seed=42`?**

Setting `seed=42` (or any other fixed number) makes random operations such as splitting or shuffling reproducible: every run produces the same split, so experiments stay consistent and results stay comparable. In Spark, this holds as long as the input data and its partitioning do not change between runs.
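
The plain-Python sketch below illustrates the idea behind a seeded, weighted split; it is not Spark's actual implementation, and `random_split` is a hypothetical helper. Every row draws one uniform random number and lands in the bucket whose cumulative-weight range contains it. Re-running with the same seed reproduces the same split, and the bucket sizes are only approximately 80/20.

```python
import random


def random_split(rows, weights, seed):
    """Assign each row to a bucket by drawing one uniform number per row.

    Weights are normalized, so [0.8, 0.2] and [4, 1] behave the same.
    Bucket sizes are only approximately proportional to the weights.
    """
    total = sum(weights)
    # Cumulative upper bounds, e.g. [0.8, 1.0] for weights [0.8, 0.2]
    bounds, running = [], 0.0
    for w in weights:
        running += w / total
        bounds.append(running)

    rng = random.Random(seed)  # fixed seed -> identical draws on every run
    buckets = [[] for _ in weights]
    for row in rows:
        u = rng.random()
        for i, upper in enumerate(bounds):
            # The last bucket also catches any floating-point rounding gap
            if u < upper or i == len(bounds) - 1:
                buckets[i].append(row)
                break
    return buckets


rows = list(range(1000))
train_a, test_a = random_split(rows, [0.8, 0.2], seed=42)
train_b, test_b = random_split(rows, [0.8, 0.2], seed=42)
print(train_a == train_b, len(train_a), len(test_a))
```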

In [0]:
display(train_data)

In [0]:
display(test_data)

## Section 1: Train / Validation / Test Split

**Why three sets?**
A model evaluated on the data it was trained on can score well simply by memorizing it, so the score hides **Overfitting**.
To get an honest estimate, we use a "hold-out" strategy:

1.  **Training Set (60-80%)**: The model sees this data and learns patterns (weights).
2.  **Validation Set (10-20%)**: Used to tune "Hyperparameters" (e.g., tree depth, learning rate). We evaluate the model here *during* development.
3.  **Test Set (10-20%)**: The "Final Exam". Used **only once** at the very end to estimate how the model will perform in the real world. We *never* tune based on this set.


### Process:
1. Train model on **Training Set**
2. Evaluate different hyperparameters on **Validation Set**
3. Select the best configuration
4. Final evaluation on **Test Set**
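
The four steps can be sketched in plain Python with a toy k-nearest-neighbors regressor whose hyperparameter is `k`. The data is synthetic and the model is chosen only because it needs no libraries; for k-NN, "training" just means storing the training rows.

```python
import random


def knn_predict(train, age, k):
    # Toy k-NN regressor: average salary of the k training rows closest in age
    nearest = sorted(train, key=lambda row: abs(row[0] - age))[:k]
    return sum(salary for _, salary in nearest) / k


def mae(train, data, k):
    # Mean absolute error of the k-NN model (fitted on `train`) over `data`
    return sum(abs(knn_predict(train, age, k) - salary) for age, salary in data) / len(data)


rng = random.Random(42)
# Synthetic rows: salary rises with age plus noise (made-up numbers, illustration only)
rows = []
for _ in range(300):
    age = rng.randint(22, 65)
    rows.append((age, 30000 + 1500 * age + rng.gauss(0, 5000)))
rng.shuffle(rows)
train, val, test = rows[:180], rows[180:240], rows[240:]  # 60/20/20

# Steps 1-3: "train" on the training set, score each candidate k on validation, keep the best
candidates = [1, 3, 5, 10, 25]
val_mae = {k: mae(train, val, k) for k in candidates}
best_k = min(val_mae, key=val_mae.get)

# Step 4: evaluate the chosen configuration once on the untouched test set
test_mae = mae(train, test, best_k)
print(f"best k = {best_k}, test MAE = {test_mae:.0f}")
```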

### Three-way Split (60/20/20)
```
Dataset → Train (60%) + Validation (20%) + Test (20%)
               ↓               ↓               ↓
             Model      Hyperparameter       Final
            fitting         tuning        evaluation
```
- **Use case**: Model selection and hyperparameter optimization
- **Pros**: Unbiased final evaluation, enables tuning
- **Cons**: Reduces training data size

In [0]:
# Random split into train / validation / test (60/20/20)
# randomSplit assigns each row at random, so the counts are only approximately 60/20/20
train_df, val_df, test_df = df.randomSplit([0.6, 0.2, 0.2], seed=42)

print(f"Total: {df.count()}")
print(f"Train: {train_df.count()}")
print(f"Val:   {val_df.count()}")
print(f"Test:  {test_df.count()}")

# Save for next modules
# This is crucial! We save the split data so that subsequent notebooks (Imputation, Feature Engineering)
# work on the exact same Training set, preventing Data Leakage from the Test data.
# Note: val_df is only used in this demo and is not saved.
train_df.write.mode("overwrite").saveAsTable(f"{catalog_name}.{schema_name}.customer_train")
test_df.write.mode("overwrite").saveAsTable(f"{catalog_name}.{schema_name}.customer_test")
print("Saved 'customer_train' and 'customer_test' tables.")


## Section 2: Stratified Sampling

### Why Stratified Sampling?

In imbalanced datasets (e.g. 1% fraud), a pure random split can cause:

- Train/Validation/Test sets with too few or zero minority-class samples  
- Over-optimistic metrics (e.g. high accuracy by always predicting the majority class)  
- Unstable and non-comparable results between runs

**Stratified sampling** ensures each split (Train, Validation, Test) preserves the original class proportions (e.g. ~1% fraud in every set).

### Why do we use stratification?

- Maintains class distribution in all splits  
- Ensures the model sees minority cases during training  
- Provides more reliable and comparable evaluation metrics  
- Reduces risk of “degenerate” splits without the minority class

### Process

1. **Select target for stratification**  
   - Use the classification label (e.g. `is_fraud`, `churn_flag`).

2. **Inspect global class distribution**  
   - Compute counts and percentages for each class.

3. **Define split ratios**  
   - Common: Train / Val / Test = 70% / 15% / 15% (or similar).

4. **Apply stratified split**  
   - Use a method that supports stratification:
     - Single split: stratified train/val/test  
     - Cross-validation: stratified k-fold

5. **Validate distributions per split**  
   - Check that class percentages in Train/Val/Test are close to the original dataset.

6. **Handle special cases**  
   - Time-series: respect temporal order, avoid random shuffling across time  
   - Grouped data (e.g. customers): combine stratification with grouping to avoid leakage
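
Process steps 3 to 5 can be sketched in plain Python for a single stratified train/validation/test split. The helper names and toy data are hypothetical: rows are grouped by class, each group is shuffled and cut with the same ratios, and the class percentages are then compared across splits. Spark's `sampleBy`, used in this section's Spark code, is a sampling-based alternative that preserves the ratios only approximately.

```python
import random
from collections import Counter


def stratified_split(rows, label_of, ratios=(0.7, 0.15, 0.15), seed=42):
    """Split rows into train/val/test so that each split keeps the class proportions.

    Rows are grouped by label, each group is shuffled with a fixed seed,
    and every group is cut with the same ratios.
    """
    rng = random.Random(seed)
    by_class = {}
    for row in rows:
        by_class.setdefault(label_of(row), []).append(row)

    train, val, test = [], [], []
    for label in sorted(by_class):  # sorted -> same order on every run
        group = list(by_class[label])
        rng.shuffle(group)
        n_train = round(len(group) * ratios[0])
        n_val = round(len(group) * ratios[1])
        train += group[:n_train]
        val += group[n_train:n_train + n_val]
        test += group[n_train + n_val:]
    return train, val, test


def class_share(rows, label_of):
    # Step 5: percentage of each class in a split
    counts = Counter(label_of(row) for row in rows)
    return {label: round(100 * n / len(rows), 1) for label, n in sorted(counts.items())}


def is_vip(row):
    return row[1]


# Toy data: 1,000 rows (id, is_vip) with 2% VIPs (illustration only)
rows = [(i, 1 if i % 50 == 0 else 0) for i in range(1000)]
train, val, test = stratified_split(rows, is_vip)
for name, split in [("all", rows), ("train", train), ("val", val), ("test", test)]:
    print(name, len(split), class_share(split, is_vip))
```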

In [0]:
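from pyspark.sql.functions import when, col

# Simulate a rare binary target 'is_vip' (salary > 150,000) on the 80% training split from the Basic Split
df_vip_train = train_data.withColumn("is_vip", when(col("salary") > 150000, 1).otherwise(0))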

In [0]:
display(df_vip_train.groupBy("is_vip").count())

In [0]:
# Add the same rare target 'is_vip' to the full dataset to inspect the global class distribution
from pyspark.sql.functions import when, col

df_strat = df.withColumn("is_vip", when(col("salary") > 150000, 1).otherwise(0))

In [0]:
# Check distribution
display(df_strat.groupBy("is_vip").count())

In [0]:
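# Stratified Split using sampleBy
# fractions = share of each class to draw into the TRAINING set (80% of non-VIPs and 80% of VIPs).
# sampleBy keeps each row independently with that probability, so the result is approximately 80%.
fractions = {0: 0.8, 1: 0.8}
train_strat = df_strat.stat.sampleBy("is_vip", fractions, seed=42)

# Test set = all remaining rows. exceptAll (EXCEPT ALL) keeps duplicate rows,
# whereas subtract (EXCEPT DISTINCT) would drop them.
test_strat = df_strat.exceptAll(train_strat)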
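
The cell above builds the test set as the complement of the `sampleBy` sample. Two details matter: `sampleBy` keeps each row independently with its class's probability, so the 80% is approximate; and `subtract` behaves like SQL `EXCEPT DISTINCT`, while `exceptAll` behaves like `EXCEPT ALL`. The plain-Python sketch below (hypothetical helpers and toy rows, not Spark's implementation) mimics the sampling and shows how a distinct set difference loses duplicated rows.

```python
import random
from collections import Counter


def sample_by(rows, label_of, fractions, seed):
    # Mimics the idea of sampleBy: keep each row independently with probability fractions[label]
    rng = random.Random(seed)
    return [row for row in rows if rng.random() < fractions[label_of(row)]]


def label_of(row):
    # Rows are (department, is_vip) tuples; the stratum is the is_vip flag
    return row[1]


# Toy rows with exact duplicates, as can happen in a bronze table (illustration only)
rows = [("sales", 0)] * 60 + [("it", 0)] * 36 + [("it", 1)] * 4
train = sample_by(rows, label_of, {0: 0.8, 1: 0.8}, seed=42)

# Set difference (like subtract / EXCEPT DISTINCT): duplicated rows vanish from the test set
test_distinct = set(rows) - set(train)

# Multiset difference (like exceptAll / EXCEPT ALL): every row not drawn into train is kept
test_all = list((Counter(rows) - Counter(train)).elements())

print(len(rows), len(train), len(test_distinct), len(test_all))
```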

In [0]:
print("Train Distribution:")
display(train_strat.groupBy("is_vip").count())

In [0]:
print("Test Distribution:")
display(test_strat.groupBy("is_vip").count())

## Section 3: Cross-Validation (Concept)

In Cross-Validation (k-fold), we split the data into $k$ folds. We train $k$ times, each time using $k-1$ folds for training and 1 fold for validation.

*Note: In Spark ML, `CrossValidator` is an object that wraps the model and handles this automatically during training. We usually don't split the DataFrame into folds manually.*

```python
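# Concept code (we will use this in the Pipeline notebook - Module 6)
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

# `pipeline`, `paramGrid` and `evaluator` are defined in Module 6, so these lines stay commented out here.
# cv = CrossValidator(estimator=pipeline, estimatorParamMaps=paramGrid, evaluator=evaluator, numFolds=3)
# cv_model = cv.fit(train_df)  # fits numFolds models per parameter combination, then refits the best one on all of train_df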
```

### Advantages:
- **Better data utilization** - every record is used for training (in $k-1$ folds) and for validation (in exactly one fold)
- **More stable evaluation** - averaging over folds reduces the impact of a single lucky or unlucky split
- **Overfitting detection** - large differences between fold scores suggest the model is sensitive to which rows it sees

### Disadvantages:
- **Computationally expensive** - the model is trained $k$ times for every hyperparameter combination
- **Longer duration** - especially for large models

### Cross-Validation (K-fold)
```
Fold 1: [TEST ] [TRAIN] [TRAIN] [TRAIN] [TRAIN]
Fold 2: [TRAIN] [TEST ] [TRAIN] [TRAIN] [TRAIN]
Fold 3: [TRAIN] [TRAIN] [TEST ] [TRAIN] [TRAIN]
...
Final Score = Average(Fold1, Fold2, ..., FoldK)
```
- **Use case**: Robust model evaluation with limited data
- **Pros**: Maximum data utilization, robust estimates
- **Cons**: Computationally expensive (K times training)
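
The k-fold procedure can be sketched in plain Python. The function names and salaries are hypothetical, and the "model" is a toy that always predicts the training mean. Each row lands in the validation fold exactly once, and the final score is the average over folds, as in the diagram above.

```python
def k_fold_indices(n_rows, k):
    """Yield (train_idx, val_idx) pairs for k-fold cross-validation.

    Assumes the rows are already shuffled; fold sizes differ by at most one.
    """
    fold_sizes = [n_rows // k + (1 if i < n_rows % k else 0) for i in range(k)]
    start = 0
    for size in fold_sizes:
        val_idx = list(range(start, start + size))
        train_idx = list(range(0, start)) + list(range(start + size, n_rows))
        yield train_idx, val_idx
        start += size


def cross_validate(values, k, fit, score):
    # Train k times on k-1 folds, score on the held-out fold, return (average, per-fold scores)
    scores = []
    for train_idx, val_idx in k_fold_indices(len(values), k):
        model = fit([values[i] for i in train_idx])
        scores.append(score(model, [values[i] for i in val_idx]))
    return sum(scores) / len(scores), scores


def fit_mean(train):
    # Toy "model": predict every salary as the training mean
    return sum(train) / len(train)


def mae(model, val):
    # Mean absolute error of the constant prediction on the validation fold
    return sum(abs(v - model) for v in val) / len(val)


# Made-up salaries, already in random order (illustration only)
salaries = [48000, 52000, 61000, 58000, 75000, 66000, 59000, 71000, 54000, 63000]
avg, per_fold = cross_validate(salaries, k=5, fit=fit_mean, score=mae)
print(f"per-fold MAE: {[round(s) for s in per_fold]}, average: {avg:.0f}")
```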

## Section 4: Time-based Split

**The Time Travel Problem (Data Leakage):**
In time-series problems (e.g., Stock Price, Sales Forecasting), the order of data matters.
If we do a random split, the model might learn from "future" data (e.g., sales in December) to predict "past" data (sales in January). This is impossible in real life.

**Solution:**
We must split by time.
- **Train:** Oldest data (e.g., Jan - Oct).
- **Test:** Newest data (e.g., Nov - Dec).
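
A minimal plain-Python sketch of the idea, assuming rows are dicts with a `registration_date` (the helper name and the one-row-per-month data are hypothetical): rows are filtered by a cutoff date instead of being shuffled, so every training row is strictly older than every test row.

```python
from datetime import date


def time_based_split(rows, cutoff):
    # Train: registered strictly before the cutoff; Test: on or after it. No shuffling.
    train = [r for r in rows if r["registration_date"] < cutoff]
    test = [r for r in rows if r["registration_date"] >= cutoff]
    return train, test


# Toy rows: one registration per month of 2023 (illustration only)
rows = [{"id": m, "registration_date": date(2023, m, 15)} for m in range(1, 13)]

train, test = time_based_split(rows, cutoff=date(2023, 11, 1))  # Jan-Oct vs. Nov-Dec
print(len(train), len(test))
```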

In [0]:
from pyspark.sql.functions import to_date, col

# Ensure date format
df_time = df.withColumn("registration_date", to_date("registration_date"))

In [0]:
# Dynamic split date: the date that separates the oldest ~80% of rows from the newest ~20%
# We compute the approximate 80th percentile of registration_date as a Unix timestamp (seconds)
split_ts = df_time.selectExpr("percentile_approx(to_unix_timestamp(registration_date), 0.8)").collect()[0][0]
split_date = spark.sql(f"select to_date(from_unixtime({split_ts}))").collect()[0][0]

print(f"Dynamic Split Date (80% cutoff): {split_date}")
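
Two caveats apply to this cutoff: `percentile_approx` is approximate, and the filters in the next cell send every row registered *on* the split date to the test set. When many employees share a registration date (for example, a batch of hires), the real train share can be far from 80%. The plain-Python sketch below (hypothetical helper names, made-up dates) shows the effect with an exact, position-based cutoff.

```python
from datetime import date


def cutoff_for_fraction(dates, train_fraction):
    # Date at sorted position floor(train_fraction * n); with no ties, exactly that many dates lie before it
    ordered = sorted(dates)
    return ordered[min(int(train_fraction * len(ordered)), len(ordered) - 1)]


def train_share(dates, cutoff):
    # Fraction of rows that the "< cutoff" filter puts into the training set
    return sum(d < cutoff for d in dates) / len(dates)


distinct = [date(2023, m, 1) for m in range(1, 11)]  # ten different dates
batch = [date(2023, m, 1) for m in range(1, 6)] + [date(2023, 9, 1)] * 5  # five hires on one day

for name, dates in [("distinct", distinct), ("batch", batch)]:
    cutoff = cutoff_for_fraction(dates, 0.8)
    print(name, cutoff, train_share(dates, cutoff))
```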

In [0]:
# Rows with a NULL registration_date match neither filter and are excluded from both sets
train_time = df_time.filter(col("registration_date") < split_date)
test_time = df_time.filter(col("registration_date") >= split_date)

print(f"Train (Historical): {train_time.count()}")
print(f"Test (Recent): {test_time.count()}")

## Best Practices

### Splitting Strategy Guide:

| Scenario | Recommended Approach | Rationale |
|----------|---------------------|-----------|
| General classification/regression | 60/20/20 random split | Standard approach |
| Imbalanced classes (<5% minority) | Stratified split | Preserves class distribution |
| Time-series / forecasting | Time-based split | Prevents data leakage |
| Small dataset (<1000 rows) | Cross-validation (5-10 folds) | Uses all data efficiently |
| Large dataset (>1M rows) | Simple random split | Stratification less critical |

### Common Mistakes to Avoid:

1. **Using test data for feature engineering** → Data leakage
2. **Shuffling time-series before split** → Invalid evaluation
3. **Not setting random seed** → Non-reproducible results
4. **Using validation set for final evaluation** → Overly optimistic results
5. **Forgetting to stratify imbalanced datasets** → Unrepresentative splits

### Pro Tips:

- Always set `seed=42` (or any fixed number) for reproducibility
- Save split data as tables for pipeline consistency
- Document your split ratios and strategy
- For production: consider rolling window validation for time-series
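
As an illustration of the last tip, here is a plain-Python sketch of an expanding-window variant: the training window grows fold by fold, while a fixed-length sliding window is the other common choice. The function name and the monthly periods are hypothetical.

```python
def expanding_window_splits(n_periods, n_folds, test_size=1):
    """Yield (train_periods, test_periods) for expanding-window time-series validation.

    Each fold trains on every period before its test window, so no fold uses future data.
    """
    first_test = n_periods - n_folds * test_size
    if first_test < 1:
        raise ValueError("not enough periods for the requested number of folds")
    for fold in range(n_folds):
        start = first_test + fold * test_size
        yield list(range(start)), list(range(start, start + test_size))


# Example: 12 monthly periods (0 = Jan ... 11 = Dec), 3 folds with a 1-month test window
for train, test in expanding_window_splits(12, n_folds=3):
    print(f"train on months 0-{train[-1]}, test on month {test[0]}")
```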

## Summary

### What we achieved:

- **Random Split**: Created train/validation/test sets with `randomSplit()`
- **Stratified Sampling**: Used `sampleBy()` to preserve class distribution
- **Cross-Validation**: Understood k-fold concept for robust evaluation
- **Time-based Split**: Applied temporal split to prevent data leakage

### Key Takeaways:

| # | Principle |
|---|-----------|
| 1 | **Always split before preprocessing** - fit transformers on train only |
| 2 | **Use stratification for imbalanced data** - ensures representative splits |
| 3 | **Respect temporal order for time-series** - no "time travel" allowed |
| 4 | **Test set is sacred** - use it only once at the very end |
| 5 | **Save splits as tables** - ensures consistency across pipeline |

### Data Pipeline Status:

| Table | Created | Used By |
|-------|---------|---------|
| `customer_bronze` | Module 1 | This module |
| `customer_train` | This module | Modules 3-7 |
| `customer_test` | This module | Module 6 (evaluation) |

### Next Steps:

**Next Module:** Module 3 - Data Imputing (handling missing values)

## Cleanup

This notebook creates only `customer_train` and `customer_test`, which subsequent modules depend on, so the cleanup code below is disabled by default:

In [0]:
# Cleanup - this notebook creates only the two split tables below

# WARNING: Do NOT drop customer_train and customer_test until you have finished the subsequent modules - they depend on them!

# Uncomment the lines below to remove them:

# spark.sql(f"DROP TABLE IF EXISTS {catalog_name}.{schema_name}.customer_train")
# spark.sql(f"DROP TABLE IF EXISTS {catalog_name}.{schema_name}.customer_test")

# print("Split tables removed")

print("ℹ️ Cleanup disabled (uncomment the code above to remove the split tables)")