# 01 · Data Prep and Splits

## Purpose

Assign artifacts to train/dev/test partitions and remove cross-partition text duplicates (leakage) before downstream training.

## Inputs

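- `data/raw/controls.csv`: canonical NIST control catalog with columns `control_id`, `family`, `title`, `summary`.
- `data/raw/artifacts.csv`: raw artifacts with `artifact_id`, `text`, `evidence_type`, `timestamp`, `gold_controls`, and `gold_rationale`, plus an optional `partition` column.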

## Outputs

- `data/processed/artifacts_with_split.csv` with validated schema and enforced partitions.
- Inline partition and evidence-type summary tables for sanity checks.

## Steps

1. Load raw CSVs, assert required columns, and coerce types (IDs as strings, timestamps as datetime).
2. If `partition` is missing or blank, assign a random 60/20/20 split (train/dev/test) using a fixed RNG seed. The shuffle is not stratified by `evidence_type`.
3. Normalize artifact text (lowercase + trim) to build a leakage hash and detect duplicates.
4. Resolve cross-partition duplicates: for each normalized text that appears in more than one split, keep a single row from the highest-priority split (train, then dev, then test) and drop the others.
5. Drop helper hash columns, persist the processed CSV, and surface partition/evidence counts.
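Steps 3 and 4 hinge on one normalization key. The sketch below mirrors the notebook's rule (lowercase, trim, MD5 digest) as a standalone function so its behavior can be checked in isolation; the name `leakage_key` is hypothetical. Note what the rule does not catch: internal whitespace differences still produce different keys.

```python
import hashlib
from typing import Optional


def leakage_key(text: object) -> Optional[str]:
    """MD5 hex digest of lowercased, trimmed text; None for missing/non-string values."""
    if not isinstance(text, str):
        # pandas represents missing text as NaN (a float), so anything non-str is treated as missing
        return None
    normalized = text.lower().strip()
    # MD5 serves only as a compact fingerprint for grouping, not as a security measure
    return hashlib.md5(normalized.encode("utf-8")).hexdigest()


# Case and leading/trailing whitespace do not change the key...
print(leakage_key("  SSH Root Login ") == leakage_key("ssh root login"))  # True
# ...but internal whitespace does, because only the ends are trimmed
print(leakage_key("ssh  root") == leakage_key("ssh root"))  # False
```

If near-duplicates with irregular spacing turn out to matter, collapsing runs of whitespace before hashing would be a natural extension; the notebook as written does not do this.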

## Acceptance Checks

- Only `train`, `dev`, and `test` appear in the partition column.
- Zero duplicate normalized texts across different partitions.
- `data/processed/artifacts_with_split.csv` exists on disk after execution.
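The first two checks can also be expressed as a reusable function that runs on an in-memory frame, which makes them easy to call from tests. This is a sketch; `check_splits` and `VALID_PARTITIONS` are hypothetical names, it groups on the normalized text directly rather than on its hash (equivalent for grouping), and the file-existence check is left to the filesystem.

```python
import pandas as pd

VALID_PARTITIONS = {"train", "dev", "test"}


def check_splits(df: pd.DataFrame, text_col: str = "text",
                 split_col: str = "partition") -> dict:
    """Evaluate the two data-level acceptance checks; returns {check_name: bool}."""
    # Check 1: every row has a partition, and every label is train/dev/test
    labels = set(df[split_col].dropna().unique())
    only_valid = bool(df[split_col].notna().all()) and labels <= VALID_PARTITIONS

    # Check 2: no normalized text (lowercase + trim) spans more than one partition
    key = df[text_col].str.lower().str.strip()
    splits_per_text = df.assign(_key=key).groupby("_key")[split_col].nunique()
    no_leakage = bool((splits_per_text <= 1).all())

    return {"only_valid_partitions": only_valid,
            "no_cross_partition_duplicates": no_leakage}


demo = pd.DataFrame({
    "text": ["MFA enabled", "mfa enabled ", "Audit log rotated"],
    "partition": ["train", "test", "dev"],
})
print(check_splits(demo))
# {'only_valid_partitions': True, 'no_cross_partition_duplicates': False}
```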

In [1]:
import pandas as pd
import numpy as np
from pathlib import Path
import hashlib

# Set random seed for reproducibility
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)

## 1. Load and validate raw data

In [2]:
# Load controls
controls_path = Path("../data/raw/controls.csv")
controls = pd.read_csv(controls_path, dtype=str)

# Validate required columns and report exactly which ones are missing
required_control_cols = ["control_id", "family", "title", "summary"]
missing_control_cols = [col for col in required_control_cols if col not in controls.columns]
assert not missing_control_cols, f"controls.csv is missing columns: {missing_control_cols}"

print(f"✓ Loaded {len(controls)} controls")
print(f"  Columns: {list(controls.columns)}")
controls.head(3)

✓ Loaded 31 controls
  Columns: ['control_id', 'family', 'title', 'summary']


|   | control_id | family | title | summary |
|---|---|---|---|---|
| 0 | AC-2 | AC | Account Management | Provision, review, and remove accounts; enforc... |
| 1 | AC-6 | AC | Least Privilege | Restrict privileges to the minimum necessary; ... |
| 2 | AC-7 | AC | Unsuccessful Logon Attempts | Enforce lockout thresholds and durations after... |


In [3]:
# Load artifacts (IDs and free-text fields as strings)
artifacts_path = Path("../data/raw/artifacts.csv")
artifacts = pd.read_csv(
    artifacts_path,
    dtype={"artifact_id": str, "text": str, "evidence_type": str,
           "gold_controls": str, "gold_rationale": str},
)

# Validate required columns before touching any of them
required_artifact_cols = ["artifact_id", "text", "evidence_type", "gold_controls"]
missing_artifact_cols = [col for col in required_artifact_cols if col not in artifacts.columns]
assert not missing_artifact_cols, f"artifacts.csv is missing columns: {missing_artifact_cols}"

# Parse timestamps when present; unparseable values become NaT
if "timestamp" in artifacts.columns:
    artifacts["timestamp"] = pd.to_datetime(artifacts["timestamp"], errors="coerce")

print(f"✓ Loaded {len(artifacts)} artifacts")
print(f"  Columns: {list(artifacts.columns)}")
artifacts.head(3)

✓ Loaded 1000 artifacts
  Columns: ['artifact_id', 'text', 'evidence_type', 'timestamp', 'gold_controls', 'gold_rationale']


|   | artifact_id | text | evidence_type | timestamp | gold_controls | gold_rationale |
|---|---|---|---|---|---|---|
| 0 | 10522 | Asset inventory shows 22 untagged cloud instan... | config | 2025-09-18 05:20:00+00:00 | CM-8 | Information system component inventory incompl... |
| 1 | 10738 | API gateway rate limiting configured; DDoS pro... | config | 2025-11-11 09:00:00+00:00 | SC-5;SC-7 | Traffic filtering implemented to prevent DoS a... |
| 2 | 10741 | SSH daemon configuration hardened; root login ... | config | 2025-11-11 09:45:00+00:00 | AC-17;SC-8 | Remote access security enhanced with strong cr... |


## 2. Assign train/dev/test partitions (60/20/20)

In [4]:
# Normalize partition labels: add the column if absent, trim/lowercase, and treat blanks as missing
if "partition" not in artifacts.columns:
    artifacts["partition"] = pd.NA
artifacts["partition"] = (
    artifacts["partition"].astype("string").str.strip().str.lower()
    .replace("", pd.NA).astype(object)
)
missing = artifacts["partition"].isna()

if missing.all():
    print("⚠ No partition column found or all values are missing. Assigning 60/20/20 split...")
elif missing.any():
    print(f"ℹ Partition column exists. Assigning 60/20/20 split to {missing.sum()} unlabeled rows...")
else:
    print("ℹ Partition column exists and is fully populated.")

# Shuffle only the unlabeled rows (seeded via np.random.seed above) and cut at 60% / 80%
unlabeled = artifacts.index[missing]
n = len(unlabeled)
shuffled = unlabeled[np.random.permutation(n)]
train_end = int(0.6 * n)
dev_end = int(0.8 * n)

artifacts.loc[shuffled[:train_end], "partition"] = "train"
artifacts.loc[shuffled[train_end:dev_end], "partition"] = "dev"
artifacts.loc[shuffled[dev_end:], "partition"] = "test"

print("\n✓ Partition distribution:")
print(artifacts["partition"].value_counts().sort_index())

⚠ No partition column found or all values are missing. Assigning 60/20/20 split...

✓ Partition distribution:
partition
dev      200
test     200
train    600
Name: count, dtype: int64

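The cell above shuffles all rows together, so evidence types are balanced across splits only on average. One way to stratify is to split each stratum 60/20/20 independently. The sketch below assumes `evidence_type` is the stratum and uses its own seeded `numpy.random.Generator` rather than the global seed; `stratified_partition` is a hypothetical helper, and swapping it in would change the counts recorded above.

```python
import numpy as np
import pandas as pd


def stratified_partition(df: pd.DataFrame, stratum_col: str, seed: int = 42,
                         percents: tuple = (60, 20, 20)) -> pd.Series:
    """Return train/dev/test labels, splitting each stratum by the given integer percentages."""
    train_pct, dev_pct, test_pct = percents
    assert train_pct + dev_pct + test_pct == 100, "percentages must sum to 100"
    rng = np.random.default_rng(seed)
    labels = pd.Series(index=df.index, dtype=object)
    # Missing strata are grouped together instead of being silently dropped
    strata = df[stratum_col].fillna("<missing>")
    for _, row_labels in df.groupby(strata).groups.items():
        shuffled = rng.permutation(np.asarray(row_labels))
        n = len(shuffled)
        # Integer arithmetic avoids floating-point drift at the cut points
        train_end = n * train_pct // 100
        dev_end = n * (train_pct + dev_pct) // 100
        labels.loc[shuffled[:train_end]] = "train"
        labels.loc[shuffled[train_end:dev_end]] = "dev"
        labels.loc[shuffled[dev_end:]] = "test"
    return labels


demo = pd.DataFrame({"evidence_type": ["config"] * 10 + ["log"] * 5})
demo["partition"] = stratified_partition(demo, "evidence_type")
# config splits 6/2/2 and log splits 3/1/1 (train/dev/test)
print(pd.crosstab(demo["evidence_type"], demo["partition"]))
```

Small strata round down into train first, so a stratum with very few rows may end up with no dev or test examples.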

## 3. Detect and resolve duplicate texts across partitions (leakage guard)

In [5]:
# Create normalized text hash for duplicate detection
artifacts["text_hash"] = artifacts["text"].str.lower().str.strip().apply(
    lambda x: hashlib.md5(x.encode()).hexdigest() if pd.notna(x) else None
)

# Find duplicates across partitions
duplicates = artifacts.groupby("text_hash")["partition"].nunique()
cross_partition_dupes = duplicates[duplicates > 1]

print(f"Found {len(cross_partition_dupes)} unique texts appearing in multiple partitions")

if len(cross_partition_dupes) > 0:
    print("\n⚠ Resolving duplicates (keeping first occurrence, preferring train)...")
    
    # For each duplicate hash, keep only one partition (prefer train > dev > test)
    partition_priority = {"train": 0, "dev": 1, "test": 2}
    
    # Mark rows to keep
    artifacts["_priority"] = artifacts["partition"].map(partition_priority)
    artifacts["_keep"] = False
    
    for text_hash in cross_partition_dupes.index:
        mask = artifacts["text_hash"] == text_hash
        # Stable sort so ties within a partition keep their original file order
        dupe_group = artifacts[mask].sort_values("_priority", kind="stable")
        # Keep only the first (highest priority) occurrence
        first_idx = dupe_group.index[0]
        artifacts.loc[first_idx, "_keep"] = True
    
    # Also keep all non-duplicates
    non_dupe_mask = ~artifacts["text_hash"].isin(cross_partition_dupes.index)
    artifacts.loc[non_dupe_mask, "_keep"] = True
    
    # Remove duplicates
    rows_before = len(artifacts)
    artifacts = artifacts[artifacts["_keep"]].copy()
    rows_after = len(artifacts)
    
    print(f"  Removed {rows_before - rows_after} duplicate rows")
    
    # Clean up helper columns
    artifacts.drop(columns=["_priority", "_keep"], inplace=True)
else:
    print("✓ No cross-partition duplicates found!")

# Drop the hash column
artifacts.drop(columns=["text_hash"], inplace=True)

Found 86 unique texts appearing in multiple partitions

⚠ Resolving duplicates (keeping first occurrence, preferring train)...
  Removed 95 duplicate rows

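The loop above filters the whole frame once per duplicated hash. The same rule can be applied in a few vectorized pandas operations: for every key spanning more than one partition, keep the first row of its highest-priority partition and drop the rest, leaving keys confined to a single partition untouched (within-partition repeats included, as in the loop). `drop_cross_partition_dupes` is a hypothetical helper, and the short `text_hash` values in the demo stand in for MD5 digests.

```python
import pandas as pd

PRIORITY = {"train": 0, "dev": 1, "test": 2}


def drop_cross_partition_dupes(df: pd.DataFrame, key_col: str = "text_hash",
                               split_col: str = "partition") -> pd.DataFrame:
    """Keep one row per key that spans several partitions, preferring train > dev > test."""
    # Number of distinct partitions each row's key appears in
    splits_per_key = df.groupby(key_col)[split_col].transform("nunique")
    is_cross = splits_per_key > 1
    # Stable sort keeps file order among rows with the same priority
    winners = (
        df[is_cross]
        .assign(_priority=df.loc[is_cross, split_col].map(PRIORITY))
        .sort_values("_priority", kind="stable")
        .groupby(key_col)
        .head(1)
        .index
    )
    keep = ~is_cross | df.index.isin(winners)
    return df[keep].copy()


demo = pd.DataFrame({
    "text_hash": ["a", "a", "a", "b", "b", "c"],
    "partition": ["test", "train", "train", "dev", "dev", "test"],
})
result = drop_cross_partition_dupes(demo)
print(result.index.tolist())  # [1, 3, 4, 5]
```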

## 4. Save processed data and print summary statistics

In [6]:
# Create output directory if needed
output_dir = Path("../data/processed")
output_dir.mkdir(parents=True, exist_ok=True)

# Save processed artifacts
output_path = output_dir / "artifacts_with_split.csv"
artifacts.to_csv(output_path, index=False)

print(f"✓ Saved {len(artifacts)} artifacts to {output_path}")
print(f"\nFinal partition distribution:")
print(artifacts["partition"].value_counts().sort_index())

✓ Saved 905 artifacts to ../data/processed/artifacts_with_split.csv

Final partition distribution:
partition
dev      162
test     148
train    595
Name: count, dtype: int64


In [7]:
# Evidence type × partition crosstab
print("\n" + "="*60)
print("Evidence Type × Partition Distribution")
print("="*60)
crosstab = pd.crosstab(artifacts["evidence_type"], artifacts["partition"], margins=True)
print(crosstab)


Evidence Type × Partition Distribution
partition      dev  test  train  All
evidence_type                       
config          57    57    194  308
log             56    48    216  320
ticket          49    43    185  277
All            162   148    595  905

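Raw counts make it hard to see whether deduplication skewed the split within an evidence type. Passing `normalize="index"` to `pd.crosstab` turns each row into partition shares that sum to 1, which can then be compared against the 60/20/20 target. The toy frame below is illustrative only, and the `target` series is an assumption that restates the intended ratios.

```python
import pandas as pd

demo = pd.DataFrame({
    "evidence_type": ["config"] * 5 + ["log"] * 5,
    "partition": ["train", "train", "train", "dev", "test",
                  "train", "train", "dev", "dev", "test"],
})

# Each row sums to 1: the share of that evidence type landing in each partition
shares = pd.crosstab(demo["evidence_type"], demo["partition"], normalize="index")

# Largest absolute gap from the intended ratios, per evidence type
target = pd.Series({"train": 0.6, "dev": 0.2, "test": 0.2})
deviation = (shares - target).abs().max(axis=1)

print(shares.round(2))
print(deviation.round(2))
```

On the real data, the same two lines applied to `artifacts` would flag any evidence type whose split drifted after duplicates were removed.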

## 5. Acceptance checks

In [8]:
print("="*60)
print("ACCEPTANCE CHECKS")
print("="*60)

# Check 1: Only train/dev/test partitions
valid_partitions = {"train", "dev", "test"}
actual_partitions = set(artifacts["partition"].unique())
check1 = actual_partitions.issubset(valid_partitions)
print(f"\n{'✓' if check1 else '✗'} Check 1: Only train/dev/test partitions: {check1}")
print(f"  Found partitions: {sorted(actual_partitions)}")

# Check 2: Zero cross-partition duplicates
artifacts_recheck = pd.read_csv(output_path, dtype={"text": str, "partition": str})
artifacts_recheck["text_hash"] = artifacts_recheck["text"].str.lower().str.strip().apply(
    lambda x: hashlib.md5(x.encode()).hexdigest() if pd.notna(x) else None
)
duplicates_final = artifacts_recheck.groupby("text_hash")["partition"].nunique()
cross_partition_final = (duplicates_final > 1).sum()
check2 = cross_partition_final == 0
print(f"\n{'✓' if check2 else '✗'} Check 2: Zero cross-partition duplicates: {check2}")
print(f"  Duplicate texts across partitions: {cross_partition_final}")

# Check 3: Output file exists
check3 = output_path.exists()
print(f"\n{'✓' if check3 else '✗'} Check 3: Output file exists: {check3}")
print(f"  Path: {output_path}")

# Overall
all_checks_passed = check1 and check2 and check3
print("\n" + "="*60)
if all_checks_passed:
    print("✅ ALL ACCEPTANCE CHECKS PASSED")
else:
    print("❌ SOME ACCEPTANCE CHECKS FAILED")
print("="*60)

ACCEPTANCE CHECKS

✓ Check 1: Only train/dev/test partitions: True
  Found partitions: ['dev', 'test', 'train']

✓ Check 2: Zero cross-partition duplicates: True
  Duplicate texts across partitions: 0

✓ Check 3: Output file exists: True
  Path: ../data/processed/artifacts_with_split.csv

✅ ALL ACCEPTANCE CHECKS PASSED
