# Data Split
This notebook partitions the merged dataset into four subsets: training, validation, test, and production.

## Goals
1. **Implement Deterministic Splitting**: Make split assignment reproducible across reruns by bucketing rows in a fixed sort order rather than assigning them at random.
2. **Enforce Stratification**: Use stratified bucketing across both `source_dataset` and `label` to ensure the massive size of the TON_IoT dataset does not overwhelm smaller sources like UNSW-NB15 in any specific split.
3. **Isolate Holdout Data**: Maintain the integrity of validation, test, and production sets by ensuring they remain representative of the original data, while confining any future rebalancing or sampling strictly to the training set.

In [22]:
!pip -q install "PyAthena[SQLAlchemy]" sqlalchemy s3fs

In [23]:
import boto3
import sagemaker
import pandas as pd
import numpy as np
from sqlalchemy import create_engine, text

# Display settings
pd.set_option("display.max_rows", 100)
pd.set_option("display.max_columns", None)
pd.set_option("display.max_colwidth", None)
pd.set_option("display.width", None)

## Connect to Athena

In [24]:
sess = sagemaker.Session()
region = boto3.Session().region_name

results_bucket = sess.default_bucket()
athena_results_path = f"s3://{results_bucket}/athena/staging/"

database_name = "aai540_eda"

engine = create_engine(
    f"awsathena+rest://@athena.{region}.amazonaws.com:443/{database_name}",
    connect_args={"s3_staging_dir": athena_results_path, "region_name": region},
)
print("Region:", region)
print("Athena results:", athena_results_path)

Region: us-east-1
Athena results: s3://sagemaker-us-east-1-776673915827/athena/staging/


In [25]:
# Helper functions for queries
def exec_ddl(sql: str):
    with engine.begin() as conn:
        conn.execute(text(sql))

def read_sql(sql: str) -> pd.DataFrame:
    return pd.read_sql(sql, engine)

## Quick EDA before splitting merged dataset

### Data source distribution in the merged dataset

In [26]:
rows_per_source = read_sql(f"""
SELECT
  source_dataset,
  row_count,
  ROUND(
      100.0 * row_count 
      / SUM(row_count) OVER (),
      2
  ) AS percentage
FROM (
    SELECT
        source_dataset,
        COUNT(*) AS row_count
    FROM {database_name}.feature_engineered_cleaned
    GROUP BY source_dataset
) t
ORDER BY row_count DESC
""")

rows_per_source

Unnamed: 0,source_dataset,row_count,percentage
0,TON_IoT,16023652,74.94
1,CIC-IDS2017,2827761,13.22
2,UNSW-NB15,2531675,11.84


### Label distribution in the merged dataset

In [27]:
label_dist = read_sql(f"""
SELECT
  label,
  row_count,
  ROUND(
      100.0 * row_count
      / SUM(row_count) OVER (),
      2
  ) AS percentage
FROM (
    SELECT
        label,
        COUNT(*) AS row_count
    FROM {database_name}.feature_engineered_cleaned
    GROUP BY label
) t
ORDER BY label
""")

label_dist


Unnamed: 0,label,row_count,percentage
0,0,4852041,22.69
1,1,16531047,77.31


### Attack category distribution in the merged dataset

In [28]:
attack_category_dist = read_sql(f"""
SELECT
  attack_category,
  row_count,
  ROUND(100.0 * row_count / SUM(row_count) OVER (), 2) AS percentage
FROM (
    SELECT
        COALESCE(NULLIF(trim(attack_category), ''), 'UNKNOWN') AS attack_category,
        COUNT(*) AS row_count
    FROM {database_name}.feature_engineered_cleaned
    GROUP BY 1
) t
ORDER BY row_count DESC
""")

attack_category_dist

Unnamed: 0,attack_category,row_count,percentage
0,DoS/DDoS,9678799,45.26
1,Normal,4852041,22.69
2,Web Attack,2095484,9.8
3,Reconnaissance,1862138,8.71
4,Brute Force,1643527,7.69
5,Backdoor,510444,2.39
6,Injection,451709,2.11
7,Generic Malware,215651,1.01
8,Exploits,46032,0.22
9,Fuzzing,24226,0.11


### Label distribution by source dataset

In [29]:
label_dist_by_source = read_sql(f"""
SELECT
  source_dataset,
  label,
  row_count,
  ROUND(
      100.0 * row_count 
      / SUM(row_count) OVER (PARTITION BY source_dataset),
      2
  ) AS percentage
FROM (
    SELECT
        source_dataset,
        label,
        COUNT(*) AS row_count
    FROM {database_name}.feature_engineered_cleaned
    GROUP BY source_dataset, label
) t
ORDER BY source_dataset, label
""")

label_dist_by_source


Unnamed: 0,source_dataset,label,row_count,percentage
0,CIC-IDS2017,0,2271205,80.32
1,CIC-IDS2017,1,556556,19.68
2,TON_IoT,0,370348,2.31
3,TON_IoT,1,15653304,97.69
4,UNSW-NB15,0,2210488,87.31
5,UNSW-NB15,1,321187,12.69


### Attack category distribution by source dataset

In [30]:
attack_cat_by_source = read_sql(f"""
SELECT
  source_dataset,
  attack_category,
  row_count,
  ROUND(
      100.0 * row_count
      / SUM(row_count) OVER (PARTITION BY source_dataset),
      2
  ) AS percentage
FROM (
    SELECT
        source_dataset,
        COALESCE(NULLIF(trim(attack_category), ''), 'UNKNOWN') AS attack_category,
        COUNT(*) AS row_count
    FROM {database_name}.feature_engineered_cleaned
    GROUP BY source_dataset, 2
) t
ORDER BY source_dataset, row_count DESC
""")

attack_cat_by_source


Unnamed: 0,source_dataset,attack_category,row_count,percentage
0,CIC-IDS2017,Normal,2271205,80.32
1,CIC-IDS2017,DoS/DDoS,379737,13.43
2,CIC-IDS2017,Reconnaissance,158804,5.62
3,CIC-IDS2017,Brute Force,13832,0.49
4,CIC-IDS2017,Web Attack,2159,0.08
5,CIC-IDS2017,Botnet,1956,0.07
6,CIC-IDS2017,Infiltration,36,0.0
7,CIC-IDS2017,Injection,21,0.0
8,CIC-IDS2017,Exploits,11,0.0
9,TON_IoT,DoS/DDoS,9282720,57.93


## Sampling Before Stratified Dataset Split

The merged dataset contains roughly 21.4 million network flow records, far more than is needed to train a tree-based classification model in the context of this assignment. To reduce computational cost while preserving the statistical properties of the data, we perform a controlled sampling step before applying the stratified train/validation/test/production split.

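The sampling is performed prior to splitting and is stratified by `(source_dataset, label)`, so that:
- The relative contribution of each source dataset (TON_IoT, CIC-IDS2017, UNSW-NB15) is preserved.
- The benign/malicious label distribution within each dataset is maintained.
- Attack categories are sampled at the same rate as the rest of their `(source_dataset, label)` group. Because `attack_category` is not a stratification key, very rare categories (a few dozen rows) can still shrink to a single row or drop out of the sample.

Rows are selected with `rand() < sample_rate`, so group sizes match their targets only approximately, and the sampled table changes on each rerun.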
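As a rough check of the allocation logic, the arithmetic in the sampling query can be replayed in plain Python using the `(source_dataset, label)` counts from the EDA output. This is an illustrative sketch, not part of the pipeline; the dictionary and variable names are chosen here for the example. It suggests that allocating proportionally at both levels reduces to one uniform rate, `TARGET_TOTAL_ROWS / total_rows`, for every group.

```python
# Row counts per (source_dataset, label), copied from the EDA output above.
group_rows = {
    ("CIC-IDS2017", 0): 2_271_205,
    ("CIC-IDS2017", 1): 556_556,
    ("TON_IoT", 0): 370_348,
    ("TON_IoT", 1): 15_653_304,
    ("UNSW-NB15", 0): 2_210_488,
    ("UNSW-NB15", 1): 321_187,
}
TARGET_TOTAL_ROWS = 500_000

total_rows = sum(group_rows.values())

# Rows per source dataset (mirrors the ds_counts CTE).
ds_rows = {}
for (source, _label), n in group_rows.items():
    ds_rows[source] = ds_rows.get(source, 0) + n

sample_rates = {}
for (source, label), n in group_rows.items():
    # ds_targets: share of the target proportional to dataset size
    ds_target = TARGET_TOTAL_ROWS * ds_rows[source] / total_rows
    # group_targets: dataset target split proportionally across labels
    group_target = ds_target * n / ds_rows[source]
    # rates: fraction of the group to keep, capped at 1.0
    sample_rates[(source, label)] = min(1.0, group_target / n)

uniform_rate = TARGET_TOTAL_ROWS / total_rows
for key, rate in sample_rates.items():
    print(key, f"rate={rate:.6f}", f"expected_rows={rate * group_rows[key]:.0f}")

# Expected number of sampled rows for a 21-row category at this rate.
rare_expected = 21 * uniform_rate
print(f"uniform rate={uniform_rate:.6f}, expected rows from 21: {rare_expected:.2f}")
```

At a uniform rate of about 2.3%, a category with only 21 rows (such as CIC-IDS2017 Injection in the EDA output) is expected to contribute roughly half a row to the sample.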

In [31]:
# target total number of rows
TARGET_TOTAL_ROWS = 500_000

In [32]:
# drop sampled table if it already exists
exec_ddl(f"DROP TABLE IF EXISTS {database_name}.merged_sampled")

# delete S3 data directory to avoid HIVE_PATH_ALREADY_EXISTS error
import subprocess
subprocess.run(["aws", "s3", "rm", f"s3://{results_bucket}/aai540/processed/merged_sampled/", "--recursive"], check=False)

# create a sampled version of the merged dataset before splitting
exec_ddl(f"""
CREATE TABLE {database_name}.merged_sampled
WITH (
  format = 'PARQUET',
  external_location = 's3://{results_bucket}/aai540/processed/merged_sampled/',
  parquet_compression = 'SNAPPY'
) AS
WITH base AS (
  SELECT *
  FROM {database_name}.feature_engineered_cleaned
),

-- total rows per dataset
ds_counts AS (
  SELECT source_dataset, COUNT(*) AS ds_rows
  FROM base
  GROUP BY source_dataset
),

-- total rows per dataset + label
group_counts AS (
  SELECT source_dataset, label, COUNT(*) AS group_rows
  FROM base
  GROUP BY source_dataset, label
),

-- allocate target rows per dataset proportional to original size
ds_targets AS (
  SELECT
    source_dataset,
    CAST({TARGET_TOTAL_ROWS} AS DOUBLE)
      * (CAST(ds_rows AS DOUBLE) / SUM(CAST(ds_rows AS DOUBLE)) OVER ())
      AS ds_target_rows
  FROM ds_counts
),

-- allocate dataset targets down to label groups proportionally
group_targets AS (
  SELECT
    g.source_dataset,
    g.label,
    g.group_rows,
    d.ds_rows,
    t.ds_target_rows
      * (CAST(g.group_rows AS DOUBLE) / CAST(d.ds_rows AS DOUBLE))
      AS group_target_rows
  FROM group_counts g
  JOIN ds_counts d
    ON g.source_dataset = d.source_dataset
  JOIN ds_targets t
    ON g.source_dataset = t.source_dataset
),

-- compute sampling rate per (source_dataset, label)
rates AS (
  SELECT
    source_dataset,
    label,
    LEAST(1.0, group_target_rows / CAST(group_rows AS DOUBLE)) AS sample_rate
  FROM group_targets
)

SELECT b.*
FROM base b
JOIN rates r
  ON b.source_dataset = r.source_dataset
 AND b.label = r.label
WHERE rand() < r.sample_rate
""")

delete: s3://sagemaker-us-east-1-776673915827/aai540/processed/merged_sampled/20260216_162623_00070_arpas_6293976a-d77f-4e8d-9fff-32ad4559b896
delete: s3://sagemaker-us-east-1-776673915827/aai540/processed/merged_sampled/20260216_162623_00070_arpas_6c8ef49f-6183-4298-b3be-a3bc4896d74e
delete: s3://sagemaker-us-east-1-776673915827/aai540/processed/merged_sampled/20260216_162623_00070_arpas_f2276255-ba9f-49b5-87be-5580a43d6ca4
delete: s3://sagemaker-us-east-1-776673915827/aai540/processed/merged_sampled/20260216_162623_00070_arpas_6f4f21fc-a618-467f-b1dc-cab0372a7144
delete: s3://sagemaker-us-east-1-776673915827/aai540/processed/merged_sampled/20260216_162623_00070_arpas_929eb017-c773-4237-9498-7aa8b15259fa
delete: s3://sagemaker-us-east-1-776673915827/aai540/processed/merged_sampled/20260216_162623_00070_arpas_ab9ab76c-eb1d-4dd1-9b57-c336298bb72e
delete: s3://sagemaker-us-east-1-776673915827/aai540/processed/merged_sampled/20260216_162623_00070_arpas_b0daef2d-39e3-44b4-a348-9167b809865d

### Validate data source distribution

In [33]:
rows_per_source_sampled = read_sql(f"""
SELECT
  source_dataset,
  rows,
  ROUND(
      100.0 * rows
      / SUM(rows) OVER (),
      2
  ) AS percentage
FROM (
    SELECT
        source_dataset,
        COUNT(*) AS rows
    FROM {database_name}.merged_sampled
    GROUP BY source_dataset
) t
ORDER BY rows DESC
""")

rows_per_source_sampled

Unnamed: 0,source_dataset,rows,percentage
0,TON_IoT,374255,74.97
1,CIC-IDS2017,65996,13.22
2,UNSW-NB15,58965,11.81


### Validate label distribution per dataset

In [34]:
label_dist_by_source_sampled = read_sql(f"""
SELECT
  source_dataset,
  label,
  rows,
  ROUND(
      100.0 * rows
      / SUM(rows) OVER (PARTITION BY source_dataset),
      2
  ) AS percentage
FROM (
    SELECT
        source_dataset,
        label,
        COUNT(*) AS rows
    FROM {database_name}.merged_sampled
    GROUP BY source_dataset, label
) t
ORDER BY source_dataset, label
""")

label_dist_by_source_sampled

Unnamed: 0,source_dataset,label,rows,percentage
0,CIC-IDS2017,0,53004,80.31
1,CIC-IDS2017,1,12992,19.69
2,TON_IoT,0,8587,2.29
3,TON_IoT,1,365668,97.71
4,UNSW-NB15,0,51415,87.2
5,UNSW-NB15,1,7550,12.8


### Validate attack category distribution per dataset

In [35]:
attack_cat_by_source_sampled = read_sql(f"""
SELECT
  source_dataset,
  attack_category,
  row_count,
  ROUND(
      100.0 * row_count
      / SUM(row_count) OVER (PARTITION BY source_dataset),
      2
  ) AS percentage
FROM (
    SELECT
        source_dataset,
        COALESCE(NULLIF(trim(attack_category), ''), 'UNKNOWN') AS attack_category,
        COUNT(*) AS row_count
    FROM {database_name}.merged_sampled
    GROUP BY source_dataset, 2
) t
ORDER BY source_dataset, row_count DESC
""")

attack_cat_by_source_sampled

Unnamed: 0,source_dataset,attack_category,row_count,percentage
0,CIC-IDS2017,Normal,53004,80.31
1,CIC-IDS2017,DoS/DDoS,8823,13.37
2,CIC-IDS2017,Reconnaissance,3740,5.67
3,CIC-IDS2017,Brute Force,327,0.5
4,CIC-IDS2017,Botnet,50,0.08
5,CIC-IDS2017,Web Attack,50,0.08
6,CIC-IDS2017,Infiltration,1,0.0
7,CIC-IDS2017,Exploits,1,0.0
8,TON_IoT,DoS/DDoS,216649,57.89
9,TON_IoT,Web Attack,48926,13.07


## Stratified Dataset Split 
The dataset is partitioned into **train (40%)**, **validation (10%)**, **test (10%)**, and **production (40%)** subsets. To handle the significant class and source imbalance (the **TON_IoT** dataset is roughly six times larger than each of the other sources), we employ a **Stratified Splitting** strategy. Instead of a global split, we use the **`NTILE(100)`** window function partitioned by both `source_dataset` and `label`, so the 100 buckets are computed independently for every combination of data source and class. As a result, each final split (e.g., the 'test' set) contains approximately 10% of the rows from each `(source_dataset, label)` group; the share is not exact because group sizes are rarely divisible by 100.

### Determinism and Proportional Representation
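By ordering the data within each partition by core flow features (`duration`, followed by the packet and byte counts), bucket assignment does not depend on a random seed: for a given `merged_sampled` table, reruns produce the same split, except that rows tied on all seven sort columns may be assigned in arbitrary order. Stratification prevents any single dataset from disproportionately influencing a specific split and keeps the source and label proportions consistent across train, validation, test, and production. Because the buckets are contiguous ranges of the sort order, however, each split covers a different range of `duration` within its group, so feature and attack-category distributions can differ between splits.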
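To see why the split percentages come out close to, but not exactly, 40/10/10/40, the bucket arithmetic can be reproduced offline. The sketch below assumes the usual `NTILE` behavior, where leftover rows are assigned one each to the lowest-numbered buckets. `ntile_sizes` and `split_counts` are hypothetical helpers written for this illustration, and the group sizes are the `(source_dataset, label)` counts from the sampled-table validation above.

```python
import numpy as np

def ntile_sizes(n_rows, n_buckets=100):
    """Rows per bucket under NTILE: remainder rows go to the first buckets."""
    base, extra = divmod(n_rows, n_buckets)
    sizes = np.full(n_buckets, base, dtype=int)
    sizes[:extra] += 1
    return sizes

def split_counts(n_rows):
    """Map buckets 1-40 / 41-50 / 51-60 / 61-100 to the four splits."""
    sizes = ntile_sizes(n_rows)
    return {
        "train": int(sizes[:40].sum()),
        "val": int(sizes[40:50].sum()),
        "test": int(sizes[50:60].sum()),
        "prod": int(sizes[60:].sum()),
    }

# Sampled rows per (source_dataset, label) group, from the validation output.
sampled_group_rows = [53_004, 12_992, 8_587, 365_668, 51_415, 7_550]

totals = {"train": 0, "val": 0, "test": 0, "prod": 0}
for n in sampled_group_rows:
    for split, count in split_counts(n).items():
        totals[split] += count

print(totals)
# {'train': 199739, 'val': 49930, 'test': 49920, 'prod': 199627}
```

Under that assumption the totals agree with the row counts reported by the split verification query, and the small surplus in `train` comes from remainder rows landing in the lowest-numbered buckets.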

In [36]:
# drop the split table if it already exists
exec_ddl(f"DROP TABLE IF EXISTS {database_name}.dataset_split")

# delete S3 data directory to avoid HIVE_PATH_ALREADY_EXISTS error
subprocess.run(["aws", "s3", "rm", f"s3://{results_bucket}/aai540/processed/dataset_split/", "--recursive"], check=False)

# create a new table with stratified splits
exec_ddl(f"""
CREATE TABLE {database_name}.dataset_split
WITH (
  format = 'PARQUET',
  external_location = 's3://{results_bucket}/aai540/processed/dataset_split/',
  parquet_compression = 'SNAPPY'
) AS

-- Using PARTITION BY ensures the 100 buckets are created for EACH dataset/label combo
WITH numbered_data AS (
  SELECT 
    *,
    NTILE(100) OVER (
        PARTITION BY source_dataset, label 
        ORDER BY 
            duration, 
            pkt_total, 
            bytes_total, 
            pkt_fwd, 
            pkt_bwd, 
            bytes_fwd, 
            bytes_bwd
    ) AS split_bucket
  FROM {database_name}.merged_sampled
)

SELECT
  *,
  CASE
    WHEN split_bucket <= 40 THEN 'train'  -- 40% of each dataset/label
    WHEN split_bucket <= 50 THEN 'val'    -- 10% of each dataset/label
    WHEN split_bucket <= 60 THEN 'test'   -- 10% of each dataset/label
    ELSE 'prod'                           -- 40% of each dataset/label
  END AS data_split
FROM numbered_data
""")

delete: s3://sagemaker-us-east-1-776673915827/aai540/processed/dataset_split/20260216_162634_00052_deyw3_2608cdbd-d835-496d-a8de-b7838fc0d196


### Verify splits

### Verify dataset_split has all engineered features

In [37]:
# Check columns in dataset_split - should have all 12 features (7 base + 5 engineered)
sample = read_sql(f"""
SELECT *
FROM {database_name}.dataset_split
LIMIT 1
""")

print("Columns in dataset_split:")
print(list(sample.columns))
print(f"\nTotal columns: {len(sample.columns)}")

# Verify engineered features are present
engineered_features = ['pkt_rate', 'byte_rate', 'bytes_per_pkt', 'pkt_ratio', 'byte_ratio']
missing = [f for f in engineered_features if f not in sample.columns]
if missing:
    print(f"\nMISSING engineered features: {missing}")
else:
    print(f"\n✓ All 5 engineered features present: {engineered_features}")

Columns in dataset_split:
['duration', 'pkt_total', 'bytes_total', 'pkt_fwd', 'pkt_bwd', 'bytes_fwd', 'bytes_bwd', 'label', 'original_attack_type', 'attack_category', 'source_dataset', 'pkt_rate', 'byte_rate', 'bytes_per_pkt', 'pkt_ratio', 'byte_ratio', 'split_bucket', 'data_split']

Total columns: 18

✓ All 5 engineered features present: ['pkt_rate', 'byte_rate', 'bytes_per_pkt', 'pkt_ratio', 'byte_ratio']


In [38]:
read_sql(f"""
SELECT data_split, COUNT(*) AS rows,
       ROUND(COUNT(*) * 100.0 / SUM(COUNT(*)) OVER (), 2) AS pct
FROM {database_name}.dataset_split
GROUP BY data_split
ORDER BY data_split
""")

Unnamed: 0,data_split,rows,pct
0,prod,199627,39.99
1,test,49920,10.0
2,train,199739,40.01
3,val,49930,10.0


### Verify data source distribution

In [39]:
read_sql(f"""
SELECT data_split, source_dataset, COUNT(*) rows,
       ROUND(COUNT(*) * 100.0 / SUM(COUNT(*)) OVER (PARTITION BY data_split), 2) pct_within_split
FROM {database_name}.dataset_split
GROUP BY data_split, source_dataset
ORDER BY data_split, source_dataset;
""")

Unnamed: 0,data_split,source_dataset,rows,pct_within_split
0,prod,CIC-IDS2017,26392,13.22
1,prod,TON_IoT,149675,74.98
2,prod,UNSW-NB15,23560,11.8
3,test,CIC-IDS2017,6600,13.22
4,test,TON_IoT,37430,74.98
5,test,UNSW-NB15,5890,11.8
6,train,CIC-IDS2017,26404,13.22
7,train,TON_IoT,149720,74.96
8,train,UNSW-NB15,23615,11.82
9,val,CIC-IDS2017,6600,13.22


### Verify label distribution in the train split

In [40]:
read_sql(f"""
SELECT 
    label, 
    COUNT(*) as row_count,
    ROUND(COUNT(*) * 100.0 / SUM(COUNT(*)) OVER (), 2) as percent_of_train
FROM {database_name}.dataset_split
WHERE data_split = 'train'
GROUP BY label
ORDER BY label
""")

Unnamed: 0,label,row_count,percent_of_train
0,0,45219,22.64
1,1,154520,77.36


### Verify attack category distribution in the train split

In [41]:
read_sql(f"""
SELECT 
    attack_category, 
    COUNT(*) as row_count,
    ROUND(COUNT(*) * 100.0 / SUM(COUNT(*)) OVER (), 2) as percent_of_train
FROM {database_name}.dataset_split
WHERE data_split = 'train'
GROUP BY attack_category
ORDER BY row_count DESC
""")

Unnamed: 0,attack_category,row_count,percent_of_train
0,DoS/DDoS,109842,54.99
1,Normal,45219,22.64
2,Reconnaissance,26750,13.39
3,Backdoor,11186,5.6
4,Generic Malware,2522,1.26
5,Brute Force,2499,1.25
6,Web Attack,1300,0.65
7,Injection,151,0.08
8,Exploits,146,0.07
9,Fuzzing,104,0.05

