# Data Split
This notebook partitions the merged dataset into four subsets: training, validation, test, and production.

## Goals
1. **Implement Deterministic Splitting**: Ensure the split is reproducible across reruns by bucketing rows on a fixed sort order rather than by random sampling.
2. **Enforce Stratification**: Use stratified bucketing across both `source_dataset` and `label` to ensure the massive size of the TON_IoT dataset does not overwhelm smaller sources like UNSW-NB15 in any specific split.
3. **Isolate Holdout Data**: Maintain the integrity of validation, test, and production sets by ensuring they remain representative of the original data, while confining any future rebalancing or sampling strictly to the training set.

In [2]:
!pip -q install "PyAthena[SQLAlchemy]" sqlalchemy s3fs

In [3]:
import boto3
import sagemaker
import pandas as pd
import numpy as np
from sqlalchemy import create_engine, text

# Display settings
pd.set_option("display.max_rows", 100)
pd.set_option("display.max_columns", None)
pd.set_option("display.max_colwidth", None)
pd.set_option("display.width", None)

## Connect to Athena

In [4]:
sess = sagemaker.Session()
region = boto3.Session().region_name

results_bucket = sess.default_bucket()
athena_results_path = f"s3://{results_bucket}/athena/staging/"

database_name = "aai540_eda"

engine = create_engine(
    f"awsathena+rest://@athena.{region}.amazonaws.com:443/{database_name}",
    connect_args={"s3_staging_dir": athena_results_path, "region_name": region},
)
print("Region:", region)
print("Athena results:", athena_results_path)

Region: us-east-1
Athena results: s3://sagemaker-us-east-1-128131109986/athena/staging/


In [5]:
# Helper functions for queries
def exec_ddl(sql: str):
    with engine.begin() as conn:
        conn.execute(text(sql))

def read_sql(sql: str) -> pd.DataFrame:
    return pd.read_sql(sql, engine)

## Quick EDA before splitting merged dataset

### Explore a few merged dataset rows

In [14]:
read_sql(f"""
SELECT *
FROM {database_name}.feature_engineered_cleaned
LIMIT 5
""")

Unnamed: 0,duration,pkt_total,bytes_total,pkt_fwd,pkt_bwd,bytes_fwd,bytes_bwd,label,original_attack_type,attack_category,source_dataset,pkt_rate,byte_rate,bytes_per_pkt,pkt_ratio,byte_ratio
0,0.000313,2,0,1,1,0,0,1,scanning,Reconnaissance,TON_IoT,6389.776358,0.0,0.0,0.5,0.0
1,2e-06,2,0,1,1,0,0,1,scanning,Reconnaissance,TON_IoT,1000000.0,0.0,0.0,0.5,0.0
2,1.8e-05,2,0,1,1,0,0,1,scanning,Reconnaissance,TON_IoT,111111.111111,0.0,0.0,0.5,0.0
3,1.8e-05,2,0,1,1,0,0,1,scanning,Reconnaissance,TON_IoT,111111.111111,0.0,0.0,0.5,0.0
4,61.220886,5,0,3,2,0,0,1,ddos,DoS/DDoS,TON_IoT,0.081671,0.0,0.0,1.0,0.0


### Data source distribution in the merged dataset

In [7]:
rows_per_source = read_sql(f"""
SELECT
  source_dataset,
  COUNT(*) AS row_count
FROM {database_name}.merged_canonical_normalized
GROUP BY source_dataset
ORDER BY row_count DESC
""")

rows_per_source

Unnamed: 0,source_dataset,row_count
0,TON_IoT,21338152
1,CIC-IDS2017,2830743
2,UNSW-NB15,2540047


### Label distribution in the merged dataset

In [8]:
label_dist = read_sql(f"""
SELECT
  label,
  COUNT(*) AS row_count
FROM {database_name}.merged_canonical_normalized
GROUP BY label
ORDER BY label
""")

label_dist

Unnamed: 0,label,row_count
0,0,5273899
1,1,21435043


### Attack category distribution in the merged dataset

In [9]:
attack_category_dist = read_sql(f"""
SELECT
  COALESCE(NULLIF(trim(attack_category), ''), 'UNKNOWN') AS attack_category,
  COUNT(*) AS row_count
FROM {database_name}.merged_canonical_normalized
GROUP BY 1
ORDER BY row_count DESC
""")

attack_category_dist

Unnamed: 0,attack_category,row_count
0,DoS/DDoS,9937377
1,Reconnaissance,6329228
2,Normal,5273899
3,Web Attack,2111103
4,Brute Force,1732403
5,Backdoor,510445
6,Injection,452680
7,Generic Malware,288460
8,Exploits,46047
9,Fuzzing,24246


### Label distribution by source dataset

In [11]:
label_dist_by_source = read_sql(f"""
SELECT
  source_dataset,
  label,
  COUNT(*) AS row_count
FROM {database_name}.merged_canonical_normalized
GROUP BY source_dataset, label
ORDER BY source_dataset, label
""")

label_dist_by_source

Unnamed: 0,source_dataset,label,row_count
0,CIC-IDS2017,0,2273097
1,CIC-IDS2017,1,557646
2,TON_IoT,0,782038
3,TON_IoT,1,20556114
4,UNSW-NB15,0,2218764
5,UNSW-NB15,1,321283


### Attack category distribution by source dataset

In [13]:
attack_cat_by_source = read_sql(f"""
SELECT
  source_dataset,
  COALESCE(NULLIF(trim(attack_category), ''), 'UNKNOWN') AS attack_category,
  COUNT(*) AS row_count
FROM {database_name}.merged_canonical_normalized
GROUP BY source_dataset, 2
ORDER BY source_dataset, row_count DESC
""")

attack_cat_by_source

Unnamed: 0,source_dataset,attack_category,row_count
0,CIC-IDS2017,Normal,2273097
1,CIC-IDS2017,DoS/DDoS,380688
2,CIC-IDS2017,Reconnaissance,158930
3,CIC-IDS2017,Brute Force,13835
4,CIC-IDS2017,Web Attack,2159
5,CIC-IDS2017,Botnet,1966
6,CIC-IDS2017,Infiltration,36
7,CIC-IDS2017,Injection,21
8,CIC-IDS2017,Exploits,11
9,TON_IoT,DoS/DDoS,9540336


## Stratified Dataset Split 
The dataset is partitioned into four splits: **train (40%)**, **validation (10%)**, **test (10%)**, and **production (40%)**. To handle the significant class and source imbalance (**TON_IoT** has roughly 8x as many rows as either of the other two sources), we use a **stratified splitting** strategy. Instead of a global split, we apply the **`NTILE(100)`** window function partitioned by both `source_dataset` and `label`, so the 100 buckets are computed independently for every combination of data source and class. Buckets 1-40 become train, 41-50 validation, 51-60 test, and 61-100 production. Because `NTILE` bucket sizes within a partition differ by at most one row, each split receives its target share of every source/label combination to within a few rows (for example, the 'test' set holds very nearly 10% of each).
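
To make the bucketing concrete, here is a minimal plain-Python sketch of the same idea on toy data. The helper names (`ntile_buckets`, `split_name`, `assign_splits`) and the toy rows are illustrative assumptions and are not part of the notebook; the SQL cell further down remains the actual implementation.

```python
from collections import defaultdict


def ntile_buckets(n_rows, n_buckets=100):
    """Return 1-based bucket numbers for n_rows ordered rows, NTILE-style.

    The first (n_rows % n_buckets) buckets each receive one extra row.
    """
    base, extra = divmod(n_rows, n_buckets)
    buckets = []
    for b in range(1, n_buckets + 1):
        size = base + 1 if b <= extra else base
        buckets.extend([b] * size)
    return buckets


def split_name(bucket):
    """Map a bucket in 1..100 to a split name (40/10/10/40)."""
    if bucket <= 40:
        return "train"
    if bucket <= 50:
        return "val"
    if bucket <= 60:
        return "test"
    return "prod"


def assign_splits(rows, strata=("source_dataset", "label"),
                  order_by=("duration", "pkt_total")):
    """Bucket rows independently within each stratum, then label the split."""
    groups = defaultdict(list)
    for row in rows:
        groups[tuple(row[c] for c in strata)].append(row)

    out = []
    for key in sorted(groups):
        ordered = sorted(groups[key], key=lambda r: tuple(r[c] for c in order_by))
        for row, bucket in zip(ordered, ntile_buckets(len(ordered))):
            out.append({**row, "split_bucket": bucket, "data_split": split_name(bucket)})
    return out


# Toy data (hypothetical): one large stratum and one small stratum.
toy_rows = [
    {"source_dataset": "A", "label": 1, "duration": i * 0.5, "pkt_total": i}
    for i in range(1000)
] + [
    {"source_dataset": "B", "label": 0, "duration": i * 0.25, "pkt_total": i}
    for i in range(200)
]

counts = defaultdict(int)
for r in assign_splits(toy_rows):
    counts[(r["source_dataset"], r["data_split"])] += 1

# Each stratum is split 40/10/10/40 on its own:
# A -> 400/100/100/400 rows, B -> 80/20/20/80 rows.
for key in sorted(counts):
    print(key, counts[key])
```

Because the buckets are computed per stratum, the small stratum keeps the same 40/10/10/40 proportions as the large one instead of being diluted by it.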

### Determinism and Proportional Representation
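By ordering the data within each partition by core flow features (`duration`, packet counts, and byte counts), bucket assignment depends only on the data rather than on a random seed, so reruns produce the same split as long as the sort keys distinguish the rows; rows that tie on all seven keys may be ordered arbitrarily by the engine. Stratification prevents any single dataset from disproportionately influencing a specific split and keeps the source and label mix of every split aligned with the full dataset. Note that the buckets are contiguous ranges of the sort order, so attributes that are not stratified, such as `attack_category`, are not guaranteed to follow the same proportions in every split.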
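
One caveat may be worth checking here. A window `ORDER BY` fixes the relative order only of rows that differ on the listed columns. Rows that tie on every sort column but differ elsewhere (for example in `attack_category`) are peers, and SQL engines generally do not promise a particular order among peers, so a group of tied rows that straddles a bucket boundary could be divided differently between runs. The sketch below imitates this in plain Python, using two input orders to stand in for two physical scan orders, and shows that appending a unique tie-breaker removes the ambiguity. The `row_id` column and the `bucket_assignment` helper are hypothetical; the merged table is not known to contain such a column.

```python
def bucket_assignment(rows, key_cols, n_buckets):
    """Sort rows by key_cols and return {row_id: bucket} using NTILE-style sizes."""
    ordered = sorted(rows, key=lambda r: tuple(r[c] for c in key_cols))
    base, extra = divmod(len(ordered), n_buckets)
    assignment = {}
    start = 0
    for b in range(1, n_buckets + 1):
        size = base + 1 if b <= extra else base
        for row in ordered[start:start + size]:
            assignment[row["row_id"]] = b
        start += size
    return assignment


# Ten toy rows that all tie on (duration, pkt_total) but differ in other columns.
rows = [
    {
        "row_id": i,  # hypothetical unique identifier
        "duration": 0.0,
        "pkt_total": 2,
        "attack_category": "Reconnaissance" if i % 2 else "DoS/DDoS",
    }
    for i in range(10)
]
# The same rows arriving in a different physical order.
rows_other_order = list(reversed(rows))

sort_keys = ["duration", "pkt_total"]

# Without a tie-breaker, the assignment follows the arrival order.
ties_a = bucket_assignment(rows, sort_keys, n_buckets=2)
ties_b = bucket_assignment(rows_other_order, sort_keys, n_buckets=2)
print("same assignment without tie-breaker:", ties_a == ties_b)

# With a unique final sort key, the order is total and the assignment is fixed.
stable_a = bucket_assignment(rows, sort_keys + ["row_id"], n_buckets=2)
stable_b = bucket_assignment(rows_other_order, sort_keys + ["row_id"], n_buckets=2)
print("same assignment with tie-breaker:", stable_a == stable_b)
```

If no natural unique key exists, a hash over the full row could play the same role, at the cost of an extra computed column; that choice is left open here.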

In [37]:
# drop the split table if it already exists
exec_ddl(f"DROP TABLE IF EXISTS {database_name}.data_split")

# create a new table with stratified splits
exec_ddl(f"""
CREATE TABLE {database_name}.data_split
WITH (
  format = 'PARQUET',
  external_location = 's3://{results_bucket}/aai540/processed/data_split/',
  parquet_compression = 'SNAPPY'
) AS

-- Using PARTITION BY ensures the 100 buckets are created for EACH dataset/label combo
WITH numbered_data AS (
  SELECT 
    *,
    NTILE(100) OVER (
        PARTITION BY source_dataset, label 
        ORDER BY 
            duration, 
            pkt_total, 
            bytes_total, 
            pkt_fwd, 
            pkt_bwd, 
            bytes_fwd, 
            bytes_bwd
    ) AS split_bucket
  FROM {database_name}.merged_canonical_normalized
)

SELECT
  *,
  CASE
    WHEN split_bucket <= 40 THEN 'train'  -- 40% of each dataset/label
    WHEN split_bucket <= 50 THEN 'val'    -- 10% of each dataset/label
    WHEN split_bucket <= 60 THEN 'test'   -- 10% of each dataset/label
    ELSE 'prod'                           -- 40% of each dataset/label
  END AS data_split
FROM numbered_data
""")

### Verify splits

In [38]:
read_sql(f"""
SELECT data_split, COUNT(*) AS rows,
       ROUND(COUNT(*) * 100.0 / SUM(COUNT(*)) OVER (), 2) AS pct
FROM {database_name}.data_split
GROUP BY data_split
ORDER BY data_split
""")


Unnamed: 0,data_split,rows,pct
0,prod,10683504,40.0
1,test,2670890,10.0
2,train,10683652,40.0
3,val,2670896,10.0
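
The split totals above can be cross-checked with simple arithmetic. The sketch below takes the six source/label counts reported earlier in this notebook and applies the usual `NTILE` sizing rule (the first `n mod 100` buckets of a partition receive one extra row). The helper `rows_in_buckets` and the constant names are illustrative assumptions rather than code from the notebook.

```python
# Row counts per (source_dataset, label), copied from the
# "Label distribution by source dataset" output.
STRATUM_COUNTS = {
    ("CIC-IDS2017", 0): 2273097,
    ("CIC-IDS2017", 1): 557646,
    ("TON_IoT", 0): 782038,
    ("TON_IoT", 1): 20556114,
    ("UNSW-NB15", 0): 2218764,
    ("UNSW-NB15", 1): 321283,
}

# Inclusive bucket ranges used by the CASE expression.
SPLIT_RANGES = {
    "train": (1, 40),
    "val": (41, 50),
    "test": (51, 60),
    "prod": (61, 100),
}


def rows_in_buckets(n_rows, first, last, n_buckets=100):
    """Rows that fall in buckets first..last when n_rows are NTILE-bucketed."""
    base, extra = divmod(n_rows, n_buckets)
    n_selected = last - first + 1
    # Buckets 1..extra hold one extra row each; count how many are selected.
    n_with_extra = max(0, min(last, extra) - first + 1)
    return n_selected * base + n_with_extra


expected = {
    split: sum(rows_in_buckets(n, lo, hi) for n in STRATUM_COUNTS.values())
    for split, (lo, hi) in SPLIT_RANGES.items()
}
print(expected)
# {'train': 10683652, 'val': 2670896, 'test': 2670890, 'prod': 10683504}
```

These values agree with the query output, and they also explain why `train` holds slightly more rows than `prod`: the extra rows from each partition's remainder always go to the lowest-numbered buckets.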


### Verify data source distribution

In [39]:
read_sql(f"""
SELECT data_split, source_dataset, COUNT(*) rows,
       ROUND(COUNT(*) * 100.0 / SUM(COUNT(*)) OVER (PARTITION BY data_split), 2) pct_within_split
FROM {database_name}.data_split
GROUP BY data_split, source_dataset
ORDER BY data_split, source_dataset;
""")

Unnamed: 0,data_split,source_dataset,rows,pct_within_split
0,prod,CIC-IDS2017,1132277,10.6
1,prod,TON_IoT,8535240,79.89
2,prod,UNSW-NB15,1015987,9.51
3,test,CIC-IDS2017,283070,10.6
4,test,TON_IoT,2133810,79.89
5,test,UNSW-NB15,254010,9.51
6,train,CIC-IDS2017,1132320,10.6
7,train,TON_IoT,8535292,79.89
8,train,UNSW-NB15,1016040,9.51
9,val,CIC-IDS2017,283076,10.6


### Verify label distribution in the train split

In [46]:
read_sql(f"""
SELECT 
    label, 
    COUNT(*) as row_count,
    ROUND(COUNT(*) * 100.0 / SUM(COUNT(*)) OVER (), 2) as percent_of_train
FROM {database_name}.data_split
WHERE data_split = 'train'
GROUP BY label
ORDER BY label
""")

Unnamed: 0,label,row_count,percent_of_train
0,0,2109598,19.75
1,1,8574054,80.25


### Verify attack category distribution in the train split

In [47]:
read_sql(f"""
SELECT 
    attack_category, 
    COUNT(*) as row_count,
    ROUND(COUNT(*) * 100.0 / SUM(COUNT(*)) OVER (), 2) as percent_of_train
FROM {database_name}.data_split
WHERE data_split = 'train'
GROUP BY attack_category
ORDER BY row_count DESC
""")

Unnamed: 0,attack_category,row_count,percent_of_train
0,Reconnaissance,5194650,48.62
1,DoS/DDoS,2940024,27.52
2,Normal,2109598,19.75
3,Generic Malware,179424,1.68
4,Brute Force,152437,1.43
5,Web Attack,50316,0.47
6,Backdoor,40553,0.38
7,Exploits,7272,0.07
8,Injection,4393,0.04
9,Fuzzing,4219,0.04
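
A quick way to read this table is to ask what fraction of each category's rows ended up in train. If the split were representative by category, every value would sit near 40%. The sketch below computes that fraction from the two outputs already shown in this notebook (the overall attack category distribution and the train distribution above); the dictionary names are illustrative.

```python
# Overall counts, copied from "Attack category distribution in the merged dataset".
OVERALL = {
    "DoS/DDoS": 9937377,
    "Reconnaissance": 6329228,
    "Normal": 5273899,
    "Web Attack": 2111103,
    "Brute Force": 1732403,
    "Backdoor": 510445,
    "Injection": 452680,
    "Generic Malware": 288460,
    "Exploits": 46047,
    "Fuzzing": 24246,
}

# Train counts, copied from the output directly above.
TRAIN = {
    "Reconnaissance": 5194650,
    "DoS/DDoS": 2940024,
    "Normal": 2109598,
    "Generic Malware": 179424,
    "Brute Force": 152437,
    "Web Attack": 50316,
    "Backdoor": 40553,
    "Exploits": 7272,
    "Injection": 4393,
    "Fuzzing": 4219,
}

# Fraction of each category's rows that landed in the train split.
train_fraction = {cat: TRAIN[cat] / OVERALL[cat] for cat in OVERALL}

for cat, frac in sorted(train_fraction.items(), key=lambda kv: kv[1], reverse=True):
    print(f"{cat:<16} {frac:6.1%}")
```

`Normal` lands at 40% because it coincides with `label = 0`, which is one of the stratification keys. The attack categories range from about 1% (`Injection`) to about 82% (`Reconnaissance`). One plausible reading is that ordering by `duration` and packet counts concentrates some categories in particular bucket ranges, so the other splits may see a different attack mix than train; comparing the same query across `val`, `test`, and `prod` would confirm or rule this out.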
