# Data Split
This notebook partitions the merged dataset into four disjoint subsets: training, validation, test, and production.

## Goals
1. **Implement Deterministic Splitting**: Make split assignment reproducible across reruns by ordering rows on fixed flow features rather than drawing random numbers.
2. **Enforce Stratification**: Bucket rows within each `(source_dataset, label)` combination so that every split keeps the same source and label proportions, and the far larger TON_IoT dataset cannot crowd smaller sources such as UNSW-NB15 out of any split.
3. **Isolate Holdout Data**: Keep the validation, test, and production sets representative of the original data, and confine any future rebalancing or resampling strictly to the training set.

In [6]:
!pip -q install "PyAthena[SQLAlchemy]" sqlalchemy s3fs

In [7]:
import boto3
import sagemaker
import pandas as pd
import numpy as np
from sqlalchemy import create_engine, text

# Display settings
pd.set_option("display.max_rows", 100)
pd.set_option("display.max_columns", None)
pd.set_option("display.max_colwidth", None)
pd.set_option("display.width", None)

## Connect to Athena

In [8]:
sess = sagemaker.Session()
region = boto3.Session().region_name

results_bucket = sess.default_bucket()
athena_results_path = f"s3://{results_bucket}/athena/staging/"

database_name = "aai540_eda"

engine = create_engine(
    f"awsathena+rest://@athena.{region}.amazonaws.com:443/{database_name}",
    connect_args={"s3_staging_dir": athena_results_path, "region_name": region},
)
print("Region:", region)
print("Athena results:", athena_results_path)

Region: us-east-1
Athena results: s3://sagemaker-us-east-1-933747558592/athena/staging/


In [9]:
# Helper functions for queries
def exec_ddl(sql: str):
    with engine.begin() as conn:
        conn.execute(text(sql))

def read_sql(sql: str) -> pd.DataFrame:
    return pd.read_sql(sql, engine)

## Quick EDA before splitting the merged dataset

### Explore a few merged dataset rows

In [10]:
read_sql(f"""
SELECT *
FROM {database_name}.feature_engineered_cleaned
LIMIT 5
""")

| duration | pkt_total | bytes_total | pkt_fwd | pkt_bwd | bytes_fwd | bytes_bwd | label | original_attack_type | attack_category | source_dataset | pkt_rate | byte_rate | bytes_per_pkt | pkt_ratio | byte_ratio |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 4.999909 | 3530 | 0 | 3530 | 0 | 0 | 0 | 1 | ddos | DoS/DDoS | TON_IoT | 706.012849 | 0.0 | 0.0 | 3530.0 | 0.0 |
| 61.819354 | 5 | 0 | 3 | 2 | 0 | 0 | 1 | ddos | DoS/DDoS | TON_IoT | 0.080881 | 0.0 | 0.0 | 1.0 | 0.0 |
| 60.855999 | 5 | 0 | 3 | 2 | 0 | 0 | 1 | ddos | DoS/DDoS | TON_IoT | 0.082161 | 0.0 | 0.0 | 1.0 | 0.0 |
| 60.859135 | 5 | 0 | 3 | 2 | 0 | 0 | 1 | ddos | DoS/DDoS | TON_IoT | 0.082157 | 0.0 | 0.0 | 1.0 | 0.0 |
| 60.860039 | 5 | 0 | 3 | 2 | 0 | 0 | 1 | ddos | DoS/DDoS | TON_IoT | 0.082156 | 0.0 | 0.0 | 1.0 | 0.0 |


### Data source distribution in the merged dataset

In [11]:
rows_per_source = read_sql(f"""
SELECT
  source_dataset,
  COUNT(*) AS row_count
FROM {database_name}.merged_canonical_normalized
GROUP BY source_dataset
ORDER BY row_count DESC
""")

rows_per_source

| source_dataset | row_count |
|---|---:|
| TON_IoT | 21338152 |
| CIC-IDS2017 | 2830743 |
| UNSW-NB15 | 2540047 |


### Label distribution in the merged dataset

In [12]:
label_dist = read_sql(f"""
SELECT
  label,
  COUNT(*) AS row_count
FROM {database_name}.merged_canonical_normalized
GROUP BY label
ORDER BY label
""")

label_dist

| label | row_count |
|---|---:|
| 0 | 5273899 |
| 1 | 21435043 |


### Attack category distribution in the merged dataset

In [13]:
attack_category_dist = read_sql(f"""
SELECT
  COALESCE(NULLIF(trim(attack_category), ''), 'UNKNOWN') AS attack_category,
  COUNT(*) AS row_count
FROM {database_name}.merged_canonical_normalized
GROUP BY 1
ORDER BY row_count DESC
""")

attack_category_dist

| attack_category | row_count |
|---|---:|
| DoS/DDoS | 9937377 |
| Reconnaissance | 6329228 |
| Normal | 5273899 |
| Web Attack | 2111103 |
| Brute Force | 1732403 |
| Backdoor | 510445 |
| Injection | 452680 |
| Generic Malware | 288460 |
| Exploits | 46047 |
| Fuzzing | 24246 |


### Label distribution by source dataset

In [14]:
label_dist_by_source = read_sql(f"""
SELECT
  source_dataset,
  label,
  COUNT(*) AS row_count
FROM {database_name}.merged_canonical_normalized
GROUP BY source_dataset, label
ORDER BY source_dataset, label
""")

label_dist_by_source

| source_dataset | label | row_count |
|---|---|---:|
| CIC-IDS2017 | 0 | 2273097 |
| CIC-IDS2017 | 1 | 557646 |
| TON_IoT | 0 | 782038 |
| TON_IoT | 1 | 20556114 |
| UNSW-NB15 | 0 | 2218764 |
| UNSW-NB15 | 1 | 321283 |


### Attack category distribution by source dataset

In [15]:
attack_cat_by_source = read_sql(f"""
SELECT
  source_dataset,
  COALESCE(NULLIF(trim(attack_category), ''), 'UNKNOWN') AS attack_category,
  COUNT(*) AS row_count
FROM {database_name}.merged_canonical_normalized
GROUP BY source_dataset, 2
ORDER BY source_dataset, row_count DESC
""")

attack_cat_by_source

| source_dataset | attack_category | row_count |
|---|---|---:|
| CIC-IDS2017 | Normal | 2273097 |
| CIC-IDS2017 | DoS/DDoS | 380688 |
| CIC-IDS2017 | Reconnaissance | 158930 |
| CIC-IDS2017 | Brute Force | 13835 |
| CIC-IDS2017 | Web Attack | 2159 |
| CIC-IDS2017 | Botnet | 1966 |
| CIC-IDS2017 | Infiltration | 36 |
| CIC-IDS2017 | Injection | 21 |
| CIC-IDS2017 | Exploits | 11 |
| TON_IoT | DoS/DDoS | 9540336 |


## Sampling Before Stratified Dataset Split

The merged dataset contains about 26.7 million network flow records, more than a tree-based classifier is likely to need for this assignment, and processing all of them is slow and costly. To reduce computational cost while keeping the source and label mix of the data, we draw a controlled sample before applying the stratified train/validation/test/production split.

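The sample is stratified by `(source_dataset, label)`, which ensures that:
- The relative contribution of each source dataset (TON_IoT, CIC-IDS2017, UNSW-NB15) is preserved.
- The benign/malicious label distribution within each dataset is maintained.

Two limitations follow from how the query below draws rows. Each row is kept when `rand()` falls below its group's sampling rate, so group sizes match their targets only approximately, and because `rand()` is unseeded, the sampled table differs on every rerun. Sampling also does not stratify on `attack_category`, so very rare categories can drop out entirely: CIC-IDS2017's Infiltration, Injection, and Exploits rows (36, 21, and 11 rows in the merged table) do not appear in the sample.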
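The two-level allocation in the sampling query (dataset targets first, then label-group targets) can be re-derived in plain Python from the label-by-source counts shown earlier. This is a minimal sketch; `GROUP_ROWS` and `sample_rates` are illustrative names, not part of the notebook. Because both levels are proportional, the dataset and label factors cancel and every group receives the same rate, `TARGET_TOTAL_ROWS / total_rows` (about 1.87% here); a group would get a different rate only if `LEAST(1.0, ...)` capped it. At that rate, a 36-row category such as CIC-IDS2017's Infiltration expects fewer than one sampled row.

```python
# Row counts per (source_dataset, label), copied from the
# "Label distribution by source dataset" output.
GROUP_ROWS = {
    ("CIC-IDS2017", 0): 2_273_097,
    ("CIC-IDS2017", 1): 557_646,
    ("TON_IoT", 0): 782_038,
    ("TON_IoT", 1): 20_556_114,
    ("UNSW-NB15", 0): 2_218_764,
    ("UNSW-NB15", 1): 321_283,
}
TARGET_TOTAL_ROWS = 500_000


def sample_rates(group_rows: dict, target_total: int) -> dict:
    """Mirror the ds_counts / ds_targets / group_targets / rates CTEs in Python."""
    total = sum(group_rows.values())
    ds_rows = {}
    for (ds, _label), n in group_rows.items():
        ds_rows[ds] = ds_rows.get(ds, 0) + n              # ds_counts
    rates = {}
    for (ds, label), n in group_rows.items():
        ds_target = target_total * ds_rows[ds] / total    # ds_targets
        group_target = ds_target * n / ds_rows[ds]        # group_targets
        rates[(ds, label)] = min(1.0, group_target / n)   # rates: LEAST(1.0, ...)
    return rates


rates = sample_rates(GROUP_ROWS, TARGET_TOTAL_ROWS)
for (ds, label), rate in rates.items():
    expected = GROUP_ROWS[(ds, label)] * rate
    print(f"{ds:12s} label={label} rate={rate:.6f} expected_rows={expected:,.0f}")
```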

In [16]:
# target total number of rows
TARGET_TOTAL_ROWS = 500_000

In [17]:
# drop sampled table if it already exists
exec_ddl(f"DROP TABLE IF EXISTS {database_name}.merged_sampled_v2")

# create a sampled version of the merged dataset before splitting
exec_ddl(f"""
CREATE TABLE {database_name}.merged_sampled_v2
WITH (
  format = 'PARQUET',
  external_location = 's3://{results_bucket}/aai540/processed/merged_sampled_v2/',
  parquet_compression = 'SNAPPY'
) AS
WITH base AS (
  SELECT *
  FROM {database_name}.merged_canonical_normalized
),

-- total rows per dataset
ds_counts AS (
  SELECT source_dataset, COUNT(*) AS ds_rows
  FROM base
  GROUP BY source_dataset
),

-- total rows per dataset + label
group_counts AS (
  SELECT source_dataset, label, COUNT(*) AS group_rows
  FROM base
  GROUP BY source_dataset, label
),

-- allocate target rows per dataset proportional to original size
ds_targets AS (
  SELECT
    source_dataset,
    CAST({TARGET_TOTAL_ROWS} AS DOUBLE)
      * (CAST(ds_rows AS DOUBLE) / SUM(CAST(ds_rows AS DOUBLE)) OVER ())
      AS ds_target_rows
  FROM ds_counts
),

-- allocate dataset targets down to label groups proportionally
group_targets AS (
  SELECT
    g.source_dataset,
    g.label,
    g.group_rows,
    d.ds_rows,
    t.ds_target_rows
      * (CAST(g.group_rows AS DOUBLE) / CAST(d.ds_rows AS DOUBLE))
      AS group_target_rows
  FROM group_counts g
  JOIN ds_counts d
    ON g.source_dataset = d.source_dataset
  JOIN ds_targets t
    ON g.source_dataset = t.source_dataset
),

-- compute sampling rate per (source_dataset, label)
rates AS (
  SELECT
    source_dataset,
    label,
    LEAST(1.0, group_target_rows / CAST(group_rows AS DOUBLE)) AS sample_rate
  FROM group_targets
)

SELECT b.*
FROM base b
JOIN rates r
  ON b.source_dataset = r.source_dataset
 AND b.label = r.label
WHERE rand() < r.sample_rate
""")

### Validate data source distribution

In [18]:
read_sql(f"""
SELECT source_dataset, COUNT(*) AS rows
FROM {database_name}.merged_sampled_v2
GROUP BY source_dataset
ORDER BY rows DESC
""")

| source_dataset | rows |
|---|---:|
| TON_IoT | 398775 |
| CIC-IDS2017 | 53063 |
| UNSW-NB15 | 47801 |
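These sampled counts only approximate the 500,000-row target, and they would change on a rerun because the sampling query keeps rows with unseeded `rand()`. If the data fit in memory, one seeded alternative is pandas' `DataFrameGroupBy.sample`, which draws a fixed number of rows (the fraction times the group size, rounded) from each group. The sketch below is illustrative only: the toy frame and the `stratified_sample` helper are hypothetical, and on the full table the same idea would more likely be expressed inside Athena with a deterministic hash rather than by loading 26.7 million rows into pandas.

```python
import numpy as np
import pandas as pd


def stratified_sample(df: pd.DataFrame, frac: float, seed: int = 42) -> pd.DataFrame:
    """Draw the same fraction of rows from every (source_dataset, label) group.

    With a fixed integer seed the result is identical on every rerun, and each
    group contributes a fixed number of rows rather than a random one.
    """
    return df.groupby(["source_dataset", "label"]).sample(frac=frac, random_state=seed)


# Hypothetical toy frame standing in for merged_canonical_normalized.
rng = np.random.default_rng(0)
toy = pd.DataFrame({
    "source_dataset": np.repeat(["TON_IoT", "CIC-IDS2017", "UNSW-NB15"], [8_000, 1_000, 1_000]),
    "label": rng.integers(0, 2, size=10_000),
    "duration": rng.exponential(10.0, size=10_000),
})

first = stratified_sample(toy, frac=0.05)
second = stratified_sample(toy, frac=0.05)
print(first.groupby(["source_dataset", "label"]).size())
print("identical reruns:", first.equals(second))
```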


### Validate label distribution per dataset

In [19]:
read_sql(f"""
SELECT source_dataset, label, COUNT(*) AS rows
FROM {database_name}.merged_sampled_v2
GROUP BY source_dataset, label
ORDER BY source_dataset, label
""")

| source_dataset | label | rows |
|---|---|---:|
| CIC-IDS2017 | 0 | 42606 |
| CIC-IDS2017 | 1 | 10457 |
| TON_IoT | 0 | 14469 |
| TON_IoT | 1 | 384306 |
| UNSW-NB15 | 0 | 41777 |
| UNSW-NB15 | 1 | 6024 |


### Validate attack category distribution per dataset

In [20]:
read_sql(f"""
SELECT
  source_dataset,
  COALESCE(NULLIF(trim(attack_category), ''), 'UNKNOWN') AS attack_category,
  COUNT(*) AS row_count
FROM {database_name}.merged_sampled_v2
GROUP BY source_dataset, 2
ORDER BY source_dataset, row_count DESC
""")

| source_dataset | attack_category | row_count |
|---|---|---:|
| CIC-IDS2017 | Normal | 42606 |
| CIC-IDS2017 | DoS/DDoS | 7234 |
| CIC-IDS2017 | Reconnaissance | 2902 |
| CIC-IDS2017 | Brute Force | 242 |
| CIC-IDS2017 | Web Attack | 46 |
| CIC-IDS2017 | Botnet | 33 |
| TON_IoT | DoS/DDoS | 178628 |
| TON_IoT | Reconnaissance | 114870 |
| TON_IoT | Web Attack | 39368 |
| TON_IoT | Brute Force | 32061 |


## Stratified Dataset Split 
The dataset is partitioned into **train (40%)**, **validation (10%)**, **test (10%)**, and **production (40%)** splits. The sources are heavily imbalanced (**TON_IoT** is roughly 7.5 to 8.4 times as large as each of the other two datasets), as are the labels within each source, so we use a **stratified splitting** strategy. Instead of a single global split, the **`NTILE(100)`** window function is partitioned by both `source_dataset` and `label`, which computes the 100 buckets independently for every combination of data source and class. As a result, each split receives its target share (e.g., 10% for the test set) of every `(source_dataset, label)` group, up to a few rows of rounding because `NTILE` bucket sizes can differ by one row.
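Athena runs on Presto/Trino, whose documentation describes `NTILE(n)` as dividing each partition into `n` buckets whose sizes differ by at most one row, with any remainder rows added one per bucket starting from bucket 1. The sketch below mimics that rule in Python (`ntile` and `split_of` are hypothetical helpers) to show how close "10%" is in practice for the sampled CIC-IDS2017 / label 1 group of 10,457 rows: train, validation, test, and production receive 4,200, 1,050, 1,047, and 4,160 rows.

```python
from collections import Counter


def ntile(num_rows: int, n: int = 100) -> list:
    """Bucket numbers that NTILE(n) assigns to num_rows already-ordered rows.

    Bucket sizes differ by at most one; the remainder rows go one per bucket,
    starting from bucket 1 (the rule documented for Presto/Trino).
    """
    base, remainder = divmod(num_rows, n)
    buckets = []
    for bucket in range(1, n + 1):
        size = base + (1 if bucket <= remainder else 0)
        buckets.extend([bucket] * size)
    return buckets


def split_of(bucket: int) -> str:
    # Same thresholds as the CASE expression in the split query.
    if bucket <= 40:
        return "train"
    if bucket <= 50:
        return "val"
    if bucket <= 60:
        return "test"
    return "prod"


# Group size of the sampled CIC-IDS2017 / label = 1 partition.
counts = Counter(split_of(b) for b in ntile(10_457))
print(dict(counts))
```

Because remainder rows always go to the lowest buckets, train and validation tend to come out slightly larger than production and test when summed over all groups.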

### Determinism and Proportional Representation
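Ordering the rows within each partition by core flow features (`duration`, packet counts, and byte counts) makes the assignment deterministic for a given sampled table, although rows that tie on all seven ordering columns may still be ordered arbitrarily by the engine. This approach prevents any single dataset from disproportionately influencing a specific split, so the model is validated and tested against every telemetry source in the same proportions. One side effect is that each split covers a contiguous range of that ordering: within each group, train receives the flows with the smallest `duration` values and production the largest, so the splits are balanced by source and label but not by flow characteristics. The train split's attack-category mix below reflects this: Reconnaissance outnumbers DoS/DDoS in train even though DoS/DDoS is the largest category in the merged dataset.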
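Because `NTILE` assigns buckets in sort order, ordering by flow features puts a contiguous range of those features into each split. A common alternative that is still deterministic is to derive the bucket from a hash of a stable row key, which is unrelated to the feature values. The sketch below is an illustration only: `hash_bucket`, `assign_split`, and the `flow-{i}` keys are hypothetical, and SHA-256 is used simply because it is in the standard library.

```python
import hashlib


def hash_bucket(row_key: str, n_buckets: int = 100) -> int:
    """Map a stable row key to a bucket in 1..n_buckets, independent of feature values."""
    digest = hashlib.sha256(row_key.encode("utf-8")).hexdigest()
    return int(digest, 16) % n_buckets + 1


def assign_split(row_key: str) -> str:
    # Same 40/10/10/40 thresholds as the CASE expression in the split query.
    bucket = hash_bucket(row_key)
    if bucket <= 40:
        return "train"
    if bucket <= 50:
        return "val"
    if bucket <= 60:
        return "test"
    return "prod"


# Hypothetical row keys; a real key must uniquely and stably identify each flow.
keys = [f"flow-{i}" for i in range(100_000)]
splits = [assign_split(k) for k in keys]
shares = {s: splits.count(s) / len(splits) for s in ("train", "val", "test", "prod")}
print(shares)
```

Hash buckets give each `(source_dataset, label)` group its 40/10/10/40 shares only in expectation. To keep the exact per-group counts that `NTILE` provides, the window could instead be ordered by such a hash; the columns shown in the sample rows above include no obvious unique identifier, so a stable key would have to be built first.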

In [21]:
# drop the split table if it already exists
exec_ddl(f"DROP TABLE IF EXISTS {database_name}.split_v2")

# create a new table with stratified splits
exec_ddl(f"""
CREATE TABLE {database_name}.split_v2
WITH (
  format = 'PARQUET',
  external_location = 's3://{results_bucket}/aai540/processed/split_v2/',
  parquet_compression = 'SNAPPY'
) AS

-- Using PARTITION BY ensures the 100 buckets are created for EACH dataset/label combo
WITH numbered_data AS (
  SELECT 
    *,
    NTILE(100) OVER (
        PARTITION BY source_dataset, label 
        ORDER BY 
            duration, 
            pkt_total, 
            bytes_total, 
            pkt_fwd, 
            pkt_bwd, 
            bytes_fwd, 
            bytes_bwd
    ) AS split_bucket
  FROM {database_name}.merged_sampled_v2
)

SELECT
  *,
  CASE
    WHEN split_bucket <= 40 THEN 'train'  -- 40% of each dataset/label
    WHEN split_bucket <= 50 THEN 'val'    -- 10% of each dataset/label
    WHEN split_bucket <= 60 THEN 'test'   -- 10% of each dataset/label
    ELSE 'prod'                           -- 40% of each dataset/label
  END AS data_split
FROM numbered_data
""")

### Verify splits

In [22]:
read_sql(f"""
SELECT data_split, COUNT(*) AS rows,
       ROUND(COUNT(*) * 100.0 / SUM(COUNT(*)) OVER (), 2) AS pct
FROM {database_name}.split_v2
GROUP BY data_split
ORDER BY data_split
""")


| data_split | rows | pct |
|---|---:|---:|
| prod | 199786 | 39.99 |
| test | 49967 | 10.0 |
| train | 199916 | 40.01 |
| val | 49970 | 10.0 |


### Verify data source distribution

In [23]:
read_sql(f"""
SELECT data_split, source_dataset, COUNT(*) rows,
       ROUND(COUNT(*) * 100.0 / SUM(COUNT(*)) OVER (PARTITION BY data_split), 2) pct_within_split
FROM {database_name}.split_v2
GROUP BY data_split, source_dataset
ORDER BY data_split, source_dataset;
""")

| data_split | source_dataset | rows | pct_within_split |
|---|---|---:|---:|
| prod | CIC-IDS2017 | 21200 | 10.61 |
| prod | TON_IoT | 159489 | 79.83 |
| prod | UNSW-NB15 | 19097 | 9.56 |
| test | CIC-IDS2017 | 5307 | 10.62 |
| test | TON_IoT | 39880 | 79.81 |
| test | UNSW-NB15 | 4780 | 9.57 |
| train | CIC-IDS2017 | 21246 | 10.63 |
| train | TON_IoT | 159526 | 79.8 |
| train | UNSW-NB15 | 19144 | 9.58 |
| val | CIC-IDS2017 | 5310 | 10.63 |
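The queries above check one dimension at a time. A minimal pandas sketch for checking the joint `(source_dataset, label)` shares of every split in a single table, assuming the `source_dataset`, `label`, and `data_split` columns of `split_v2` are loaded into a DataFrame (about 500,000 rows here). The toy frame and both helper functions are hypothetical stand-ins.

```python
import numpy as np
import pandas as pd


def split_shares_by_group(df: pd.DataFrame) -> pd.DataFrame:
    """Fraction of each (source_dataset, label) group that landed in each split."""
    return pd.crosstab(
        index=[df["source_dataset"], df["label"]],
        columns=df["data_split"],
        normalize="index",
    )


def max_share_deviation(shares: pd.DataFrame, targets: dict) -> float:
    """Largest absolute gap between an observed share and its target share."""
    target_row = pd.Series(targets).reindex(shares.columns)
    return float((shares - target_row).abs().to_numpy().max())


# Hypothetical toy frame standing in for the three columns of split_v2.
rng = np.random.default_rng(7)
n = 20_000
toy = pd.DataFrame({
    "source_dataset": rng.choice(["TON_IoT", "CIC-IDS2017", "UNSW-NB15"], size=n, p=[0.8, 0.1, 0.1]),
    "label": rng.integers(0, 2, size=n),
    "data_split": rng.choice(["train", "val", "test", "prod"], size=n, p=[0.4, 0.1, 0.1, 0.4]),
})

TARGETS = {"train": 0.40, "val": 0.10, "test": 0.10, "prod": 0.40}
shares = split_shares_by_group(toy)
print(shares.round(3))
print("max deviation:", round(max_share_deviation(shares, TARGETS), 4))
```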


### Verify label distribution in the train split

In [24]:
read_sql(f"""
SELECT 
    label, 
    COUNT(*) as row_count,
    ROUND(COUNT(*) * 100.0 / SUM(COUNT(*)) OVER (), 2) as percent_of_train
FROM {database_name}.split_v2
WHERE data_split = 'train'
GROUP BY label
ORDER BY label
""")

| label | row_count | percent_of_train |
|---|---:|---:|
| 0 | 39566 | 19.79 |
| 1 | 160350 | 80.21 |


### Verify attack category distribution in the train split

In [25]:
read_sql(f"""
SELECT 
    attack_category, 
    COUNT(*) as row_count,
    ROUND(COUNT(*) * 100.0 / SUM(COUNT(*)) OVER (), 2) as percent_of_train
FROM {database_name}.split_v2
WHERE data_split = 'train'
GROUP BY attack_category
ORDER BY row_count DESC
""")

| attack_category | row_count | percent_of_train |
|---|---:|---:|
| Reconnaissance | 96932 | 48.49 |
| DoS/DDoS | 55244 | 27.63 |
| Normal | 39566 | 19.79 |
| Generic Malware | 3401 | 1.7 |
| Brute Force | 2812 | 1.41 |
| Web Attack | 912 | 0.46 |
| Backdoor | 754 | 0.38 |
| Exploits | 127 | 0.06 |
| Injection | 83 | 0.04 |
| Fuzzing | 73 | 0.04 |
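Source and label balance does not guarantee that the splits are representative in other respects. One quick check is to compare each attack category's share of the train split with its share of the merged table, using the counts from the two category outputs in this notebook. Both outputs list only ten categories, so shares are computed over the displayed rows, and the merged counts predate sampling; since the sampling query applies the same rate to every group, category shares should carry over to the sample in expectation. `compare_shares` is an illustrative helper, not part of the notebook.

```python
import pandas as pd

# Counts copied from the merged-table and train-split category outputs
# (top ten categories only, so shares are over the displayed rows).
overall_counts = pd.Series({
    "DoS/DDoS": 9_937_377, "Reconnaissance": 6_329_228, "Normal": 5_273_899,
    "Web Attack": 2_111_103, "Brute Force": 1_732_403, "Backdoor": 510_445,
    "Injection": 452_680, "Generic Malware": 288_460, "Exploits": 46_047,
    "Fuzzing": 24_246,
})
train_counts = pd.Series({
    "Reconnaissance": 96_932, "DoS/DDoS": 55_244, "Normal": 39_566,
    "Generic Malware": 3_401, "Brute Force": 2_812, "Web Attack": 912,
    "Backdoor": 754, "Exploits": 127, "Injection": 83, "Fuzzing": 73,
})


def compare_shares(overall: pd.Series, split: pd.Series) -> pd.DataFrame:
    """Per-category percentage of rows in the full table vs. in one split."""
    table = pd.DataFrame({
        "overall_pct": overall / overall.sum() * 100,
        "split_pct": split / split.sum() * 100,
    }).fillna(0.0)
    # ratio < 1: under-represented in the split; ratio > 1: over-represented
    table["ratio"] = table["split_pct"] / table["overall_pct"]
    return table.sort_values("ratio")


print(compare_shares(overall_counts, train_counts).round(2))
```

With these counts, Reconnaissance appears in train at roughly twice its overall share, while Web Attack and Injection fall below a tenth of theirs.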


