In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler, RobustScaler, LabelEncoder, OneHotEncoder
from sklearn.impute import SimpleImputer
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_selection import SelectFromModel
from mpl_toolkits.mplot3d import Axes3D
from sklearn.model_selection import train_test_split
import glob
import os
import gc

In [None]:
# ==========================================
# CONFIGURATION
# ==========================================
file_pattern = "data/*.csv" 
files = glob.glob(file_pattern)
files.sort()

numeric_stats_list = []
categorical_stats_list = []

print(f"--- STARTING SANITIZED UNIVARIATE ANALYSIS ON {len(files)} FILES ---\n")

for filepath in files:
    filename = os.path.basename(filepath)
    print(f"Processing: {filename}...")
    
    try:
        df = pd.read_csv(filepath, low_memory=False)
        df.columns = df.columns.str.strip()

        print("\n--- HEALTH CHECK ---")
        print("\n--- CHECKING DATA TYPES ---")
        # We look for columns that are 'object' (strings) but should be numbers
        print(df.dtypes.value_counts())
        
        print("\n--- MISSING & INFINITE VALUES ---")
        # Standard NULL check
        null_counts = df.isnull().sum().sum()
        print(f"Total Null Values: {null_counts}")
        
        # CRITICAL FOR NETWORK DATA: Check for Infinity
        # Select only numeric columns to avoid errors with string columns
        numeric_cols = df.select_dtypes(include=[np.number]).columns
        inf_counts = np.isinf(df[numeric_cols]).sum().sum()
        print(f"Total Infinite Values (inf/-inf): {inf_counts}")

        print("\n--- DUPLICATES ---")
        duplicates = df.duplicated().sum()
        print(f"Duplicate Rows: {duplicates} ({duplicates/len(df)*100:.2f}%)")
        
        # Identify types
        cat_cols = ['Label', 'Protocol', 'Destination Port', 'Source Port']
        flag_cols = [c for c in df.columns if 'Flag' in c]
        cat_cols.extend(flag_cols)
        
        existing_cat = [c for c in cat_cols if c in df.columns]
        existing_num = [c for c in df.columns if c not in existing_cat and pd.api.types.is_numeric_dtype(df[c])]

        # ==========================================
        # A. NUMERICAL ANALYSIS (Sanitized)
        # ==========================================
        if existing_num:
            for col in existing_num:
                series = df[col]
                
                # 1. HARD STATISTICS (Counts of Dirty Data)
                total_rows = len(series)
                neg_count = (series < 0).sum()
                inf_count = np.isinf(series).sum()
                nan_count = series.isnull().sum()
                
                # 2. SANITIZATION FOR MOMENTS (Mean, Skew, etc.)
                # We replace Inf with NaN, then drop NaNs so the math works on valid numbers only
                clean_series = series.replace([np.inf, -np.inf], np.nan).dropna()
                
                if len(clean_series) > 0:
                    # Physics Check: For Flow Duration, we might want to exclude negatives from the Mean calc
                    # (Optional: depends on if you think -1 is a flag or noise. 
                    # For now, we include negatives in stats to show the error magnitude)
                    
                    # Calculated Moments
                    mean_val = clean_series.mean()
                    median_val = clean_series.median()
                    std_val = clean_series.std()
                    min_val = clean_series.min()
                    max_val = clean_series.max()
                    skew_val = clean_series.skew()
                    kurt_val = clean_series.kurt()
                    
                    # Outlier Check (on clean data)
                    Q1 = clean_series.quantile(0.25)
                    Q3 = clean_series.quantile(0.75)
                    IQR = Q3 - Q1
                    outliers = ((clean_series < (Q1 - 1.5 * IQR)) | (clean_series > (Q3 + 1.5 * IQR))).sum()
                else:
                    # If column was 100% Inf or NaN
                    mean_val = median_val = std_val = min_val = max_val = skew_val = kurt_val = outliers = 0

                numeric_stats_list.append({
                    'File': filename,
                    'Feature': col,
                    'Mean': mean_val,
                    'Median': median_val,
                    'Std_Dev': std_val,
                    'Min': min_val,
                    'Max': max_val,
                    'Skewness': skew_val,
                    'Kurtosis': kurt_val,
                    'Neg_Values': neg_count,
                    'Inf_Values': inf_count,
                    'NaN_Values': nan_count,
                    'Outlier_Count': outliers,
                    'Total_Rows': total_rows
                })

        # ==========================================
        # B. CATEGORICAL ANALYSIS
        # ==========================================
        for col in existing_cat:
            unique_count = df[col].nunique()
            top_val = df[col].mode()[0] if not df[col].mode().empty else "N/A"
            top_freq = df[col].value_counts().iloc[0] if unique_count > 0 else 0
            imbalance = top_freq / len(df)
            
            categorical_stats_list.append({
                'File': filename,
                'Feature': col,
                'Cardinality': unique_count,
                'Top_Class': top_val,
                'Imbalance_Ratio': imbalance,
                'Total_Rows': len(df)
            })
    
        # ==========================================
        # 1. PROTOCOL ANALYSIS
        # ==========================================
        if 'Protocol' in df.columns:
            protocol_map = {6: 'TCP', 17: 'UDP', 1: 'ICMP'}
            # Create a temporary column for plotting
            df['Protocol_Name'] = df['Protocol'].map(protocol_map).fillna('Other')
            
            plt.figure(figsize=(8, 5))
            sns.countplot(x=df['Protocol_Name'], order=df['Protocol_Name'].value_counts().index)
            plt.title(f"Distribution of Protocols - {filename}")
            plt.show()

        # ==========================================
        # 2. DESTINATION PORT ANALYSIS
        # ==========================================
        if 'Destination Port' in df.columns:
            top_ports = df['Destination Port'].value_counts().head(10)
            
            plt.figure(figsize=(10, 5))
            sns.barplot(x=top_ports.index, y=top_ports.values, order=top_ports.index)
            plt.title(f"Top 10 Destination Ports - {filename}")
            plt.xlabel("Port Number")
            plt.ylabel("Frequency")
            plt.show()

        # ==========================================
        # 3. FLAG ANALYSIS
        # ==========================================
        flag_cols = [col for col in df.columns if 'Flag' in col]
        if flag_cols:
            # Check if flags are numeric before summing
            if pd.api.types.is_numeric_dtype(df[flag_cols[0]]):
                flag_counts = df[flag_cols].sum().sort_values(ascending=False)
                
                plt.figure(figsize=(12, 6))
                sns.barplot(x=flag_counts.index, y=flag_counts.values)
                plt.title(f"Frequency of Network Flags - {filename}")
                plt.xticks(rotation=45, ha='right')
                plt.show()
            else:
                print("Skipping Flag plot: Flag columns appear to be non-numeric.")

        # ==========================================
        # 4. VARIANCE SCREENING
        # ==========================================
        print(f"\n--- DETECTING USELESS COLUMNS ({filename}) ---")
        numeric_df = df.select_dtypes(include=[np.number])
        std_devs = numeric_df.std()
        constant_cols = std_devs[std_devs == 0].index.tolist()
        
        print(f"Constant Columns Found: {len(constant_cols)}")
        if constant_cols:
            print(constant_cols)

        # ==========================================
        # 5. CORRELATION ANALYSIS
        # ==========================================
        print(f"\n--- CORRELATION MATRIX ({filename}) ---")
        
        # Prepare data for correlation
        # Exclude IDs, Labels, Context, and Zero Variance columns
        ignore_cols = ['Flow ID', 'Source IP', 'Destination IP', 'Timestamp', 'Label', 'Protocol_Name'] + constant_cols
        cols_to_drop = [c for c in ignore_cols if c in df.columns]
        
        # We only correlate the remaining NUMERIC columns
        analysis_df = df.drop(columns=cols_to_drop, errors='ignore').select_dtypes(include=[np.number])
        
        if not analysis_df.empty:
            corr_matrix = analysis_df.corr()

            plt.figure(figsize=(16, 12))
            sns.heatmap(corr_matrix, cmap='coolwarm', vmin=-1, vmax=1)
            plt.title(f"Correlation Matrix - {filename}")
            plt.show()

            # Identify High Correlations (> 0.95)
            upper_tri = corr_matrix.where(np.triu(np.ones(corr_matrix.shape), k=1).astype(bool))
            to_drop_corr = [column for column in upper_tri.columns if any(upper_tri[column].abs() > 0.95)]

            print(f"Number of Highly Correlated (>0.95) Features: {len(to_drop_corr)}")
            print("Sample of redundant features:", to_drop_corr[:10])
        else:
            print("Not enough numeric columns left for correlation analysis.")

    except Exception as e:
        print(f"  [ERROR] {filename}: {e}")
        
    del df
    gc.collect()

print("\n--- SANITIZED ANALYSIS COMPLETE ---")

# How many numeric features to preview per file. The header below is derived from this,
# so the printed text always matches the behaviour.
PREVIEW_FEATURES_PER_FILE = 3

print(f"--- NUMERICAL STATS INSPECTION (First {PREVIEW_FEATURES_PER_FILE} entries per file) ---")
# Per-file counter of how many features have been printed so far.
shown_per_file = {}

for entry in numeric_stats_list:
    fname = entry['File']
    shown = shown_per_file.get(fname, 0)
    if shown >= PREVIEW_FEATURES_PER_FILE:
        continue

    print(f"[{fname}] Feature: {entry['Feature']}")
    print(f"   Mean: {entry['Mean']:.4f} | Median: {entry['Median']:.4f} | Std: {entry['Std_Dev']:.4f}")
    print(f"   Skew: {entry['Skewness']:.2f} | Kurt: {entry['Kurtosis']:.2f}")
    print(f"   Health: Neg={entry['Neg_Values']}, Inf={entry['Inf_Values']}, NaN={entry['NaN_Values']}")
    print("-" * 40)
    shown_per_file[fname] = shown + 1

print("\n" + "="*60 + "\n")

print("--- CATEGORICAL STATS INSPECTION (All entries) ---")
for entry in categorical_stats_list:
    # Flag features whose most frequent value covers more than 95% of rows.
    warning = " [IMBALANCED]" if entry['Imbalance_Ratio'] > 0.95 else ""

    print(f"[{entry['File']}] Feature: {entry['Feature']}")
    print(f"   Unique Values (Cardinality): {entry['Cardinality']}")
    print(f"   Most Frequent: '{entry['Top_Class']}' ({entry['Imbalance_Ratio']*100:.2f}% of data){warning}")
    print("-" * 40)
# Display Results
df_stats_num = pd.DataFrame(numeric_stats_list)
df_stats_cat = pd.DataFrame(categorical_stats_list)

# An empty stats list yields a frame with no columns; column lookups would raise KeyError.
_NO_NUMERIC_STATS_MSG = "No numeric statistics were collected."

# CHECK 1: negative values (the Init_Win_bytes mystery)
print("\n=== NEGATIVE VALUE ANALYSIS ===")
if df_stats_num.empty:
    print(_NO_NUMERIC_STATS_MSG)
else:
    # Features with at least one negative value, most affected first.
    negatives = df_stats_num.loc[
        df_stats_num['Neg_Values'] > 0,
        ['File', 'Feature', 'Min', 'Neg_Values', 'Total_Rows'],
    ]
    print(negatives.sort_values(by='Neg_Values', ascending=False).head(10))

# CHECK 2: Distribution Shape (Valid Stats)
print("\n=== SKEWNESS REPORT (Cleaned) ===")
if df_stats_num.empty:
    print(_NO_NUMERIC_STATS_MSG)
else:
    skew_report = df_stats_num[['File', 'Feature', 'Skewness', 'Kurtosis']]
    print(skew_report.sort_values(by='Skewness', ascending=False).head(10))

In [None]:
# ==========================================
# CONFIGURATION
# ==========================================
file_pattern = "data/*.csv"  # Update this!
files = glob.glob(file_pattern)
files.sort()

# Sampling Ratio: 10% of the data is usually enough for EDA
# If you have < 16GB RAM, lower this to 0.05 (5%)
SAMPLE_FRAC = 0.10 

global_samples = []

print(f"--- 1. BUILDING GLOBAL STRATIFIED SAMPLE ({SAMPLE_FRAC*100}%) ---")

for filepath in files:
    filename = os.path.basename(filepath)
    print(f"Sampling: {filename}...")
    
    try:
        # Load Raw File
        df = pd.read_csv(filepath, low_memory=False)
        
        # Basic Cleanup (Headers)
        df.columns = df.columns.str.strip()
        
        # STRATIFIED SAMPLING
        # We group by Label to ensure we keep a piece of every attack type
        # group_keys=False prevents Pandas from adding an extra index
        if 'Label' in df.columns:
            # We take 10% of EACH class. 
            # If a class has fewer than 10 rows, we take all of them to preserve rare attacks.
            sample = df.groupby('Label', group_keys=False).apply(
                lambda x: x.sample(frac=SAMPLE_FRAC) if len(x) > 10 else x
            )
        else:
            # Fallback if no label (should not happen in CIC-IDS2017)
            sample = df.sample(frac=SAMPLE_FRAC, random_state=42)
            
        print(f"   > Original: {len(df)} rows | Sampled: {len(sample)} rows")
        global_samples.append(sample)
        
    except Exception as e:
        print(f"   [ERROR] Could not sample {filename}: {e}")
    
    # Memory Cleanup
    del df
    gc.collect()

# ==========================================
# 2. AGGREGATING THE GLOBAL DATASET
# ==========================================
print("\n--- MERGING SAMPLES ---")
if not global_samples:
    # pd.concat([]) would fail with an unhelpful "No objects to concatenate".
    raise ValueError(f"No samples were collected from {file_pattern!r}; check the path and the errors above.")
# ignore_index: every file's index restarts at 0, so a plain concat would carry duplicate row labels.
global_df = pd.concat(global_samples, axis=0, ignore_index=True)
print(f"Global Dataset Shape: {global_df.shape}")
print("\nGlobal Label Distribution:")
print(global_df['Label'].value_counts())

# ==========================================
# 3. GLOBAL DATA CLEANING (For Analysis Only)
# ==========================================
# Convert +/-inf to NaN so the variance/correlation checks below only see finite numbers.
# NOTE: `global_df` itself is rebound to the sanitized frame, so later cells
# (including the train/test split) receive NaN where the raw data had inf.

print("\n--- SANITIZING GLOBAL SAMPLE ---")

global_df = global_df.replace([np.inf, -np.inf], np.nan)

# Numeric-only view used by the variance and skewness checks.
numeric_df = global_df.select_dtypes(include=[np.number])

# ==========================================
# 4. GLOBAL VARIANCE ANALYSIS (The "Drop List")
# ==========================================
print("\n--- DETECTING GLOBALLY CONSTANT COLUMNS ---")
# Zero standard deviation => the column carries no information. All-NaN columns have a
# NaN std (never == 0), so they are caught separately below.
std_devs = numeric_df.std()
constant_cols = std_devs[std_devs == 0].index.tolist()

# Columns that are 100% Null (Empty), numeric or not.
null_cols = global_df.columns[global_df.isnull().all()].tolist()

# sorted(): iteration order of a set of strings changes between interpreter runs
# (hash randomization), so sort to keep the drop list reproducible.
cols_to_drop = sorted(set(constant_cols) | set(null_cols))

print(f"Features with ZERO variance or 100% null across ALL files: {len(cols_to_drop)}")
print(cols_to_drop)

# ==========================================
# 5. GLOBAL CORRELATION ANALYSIS
# ==========================================
print("\n--- GLOBAL CORRELATION ANALYSIS ---")

# Identifiers, the target and the constant/empty columns are left out of the matrix.
excluded_cols = ['Flow ID', 'Source IP', 'Destination IP', 'Timestamp', 'Label', *cols_to_drop]
corr_data = (
    global_df
    .drop(columns=[name for name in excluded_cols if name in global_df.columns])
    .select_dtypes(include=[np.number])
)

corr_matrix = corr_data.corr()

fig, ax = plt.subplots(figsize=(20, 16))
sns.heatmap(corr_matrix, cmap='coolwarm', vmin=-1, vmax=1, ax=ax)
ax.set_title("Global Correlation Matrix (Representative Sample)")
plt.show()

print("\n--- HIGHLY CORRELATED PAIRS (> 0.95) ---")
# Keep only the strict upper triangle so each pair is considered once. A column is flagged
# when it correlates (|r| > 0.95) with any column that precedes it.
upper_mask = np.triu(np.ones(corr_matrix.shape), k=1).astype(bool)
upper_tri = corr_matrix.where(upper_mask)
to_drop_corr = [name for name in upper_tri.columns if (upper_tri[name].abs() > 0.95).any()]

print(f"Number of Redundant Features: {len(to_drop_corr)}")
print("List of Redundant Features (Consider dropping):")
print(to_drop_corr)

# ==========================================
# 6. GLOBAL DISTRIBUTION CHECK (SKEWNESS)
# ==========================================
print("\n--- GLOBAL SKEWNESS CHECK ---")
# Skewness of every numeric column of the sanitized global sample, most right-skewed first;
# used to judge whether log-scaling is warranted.
skew_series = numeric_df.skew()
skew_series = skew_series.sort_values(ascending=False)
print("Top 10 Most Skewed Features (Global):")
print(skew_series.head(10))

# Drop lists consumed by the preprocessor cells below.
bad_features = dict(
    constant_columns=cols_to_drop,
    redundant_correlation=to_drop_corr,
)

# NOTE(review): nothing is written to disk here; `bad_features` only lives in the kernel.
print("\n--- SAVING FEATURE LISTS ---")

In [None]:


class CICIDS_Preprocessor:
    """Fit/transform preprocessor for CIC-IDS style flow tables.

    Pipeline learned in ``fit`` and replayed in ``transform``:
      1. drop the static bad-feature list, plus columns found highly correlated during fit;
      2. replace +/-inf with NaN;
      3. median-impute the numeric columns;
      4. scheme 'A': RobustScaler on numerics + LabelEncoder on categoricals;
         scheme 'B': log1p(clip >= 0) + StandardScaler on numerics + one-hot 'Protocol';
      5. keep only the features chosen by an embedded RandomForest (mean-importance threshold).

    Args:
        bad_feature_dict: dict with optional keys "constant_columns" and
            "redundant_correlation", each a list of column names to drop.
        scheme: 'A' or 'B' (see above).
        correlation_threshold: |Pearson r| above which a column is dropped during fit.
    """

    # Columns treated as categorical when present; every other column is numeric.
    _CATEGORICAL_CANDIDATES = ('Protocol', 'Destination Port')

    def __init__(self, bad_feature_dict, scheme='A', correlation_threshold=0.95):
        self.scheme = scheme
        self.corr_thresh = correlation_threshold

        # 1. PARSE STATIC DROP LIST
        # Constant + redundant columns merged into one de-duplicated list (first-seen order).
        self.static_drop_cols = list(dict.fromkeys(
            bad_feature_dict.get("constant_columns", []) +
            bad_feature_dict.get("redundant_correlation", [])
        ))

        # Storage for fitted objects
        self.scaler = None
        self.imputer = None
        self.encoders = {}
        self.ohe = None
        self.feature_selector = None
        self.selected_features = None
        self.dynamic_drop_cols = []  # Columns found to be correlated during fit()

    def _split_columns(self, X):
        """Return (categorical, numeric) column-name lists for X.

        Only categorical candidates actually present in X are returned. A frame without
        'Protocol' (never present, or dropped as constant/correlated) therefore no longer
        causes a KeyError.
        """
        cat_cols = [c for c in self._CATEGORICAL_CANDIDATES if c in X.columns]
        num_cols = [c for c in X.columns if c not in cat_cols]
        return cat_cols, num_cols

    def fit(self, X, y=None):
        """Learn the drops, imputer, scaler/encoders and selected features from (X, y).

        Args:
            X: feature DataFrame (identifier and label columns already removed).
            y: class labels aligned with X; required by the RandomForest selector.

        Returns:
            self
        """
        print(f"--- FITTING PREPROCESSOR (SCHEME {self.scheme}) ---")

        # 1. APPLY STATIC DROPS (only columns that exist in this frame)
        existing_drop = [c for c in self.static_drop_cols if c in X.columns]
        X_clean = X.drop(columns=existing_drop)
        print(f"   > Dropped {len(existing_drop)} static bad features.")

        # 2. CLEAN INFINITY
        X_clean = X_clean.replace([np.inf, -np.inf], np.nan)

        # 3. DYNAMIC CORRELATION CHECK (safety net): drop any column whose |corr| with an
        #    earlier column exceeds the threshold on this particular sample.
        numeric_df = X_clean.select_dtypes(include=[np.number])
        corr_matrix = numeric_df.corr().abs()
        upper = corr_matrix.where(np.triu(np.ones(corr_matrix.shape), k=1).astype(bool))
        self.dynamic_drop_cols = [column for column in upper.columns if any(upper[column] > self.corr_thresh)]

        if self.dynamic_drop_cols:
            print(f"   > Dropped {len(self.dynamic_drop_cols)} additional correlated features.")
            X_clean = X_clean.drop(columns=self.dynamic_drop_cols)

        cat_cols, num_cols = self._split_columns(X_clean)

        # 4. IMPUTATION FIT. keep_empty_features=True keeps all-NaN columns (imputed as 0),
        #    so transform() always returns len(num_cols) columns and assignment cannot misalign.
        self.imputer = SimpleImputer(strategy='median', keep_empty_features=True)
        X_num = pd.DataFrame(self.imputer.fit_transform(X_clean[num_cols]),
                             columns=num_cols, index=X_clean.index)

        # 5. SCALING & ENCODING FIT, on the *imputed* values, because transform() scales
        #    imputed data (same order as an impute -> scale pipeline).
        self.encoders = {}
        self.ohe = None
        if self.scheme == 'A':
            self.scaler = RobustScaler().fit(X_num)
            for col in cat_cols:
                self.encoders[col] = LabelEncoder().fit(X_clean[col].astype(str))

        elif self.scheme == 'B':
            self.scaler = StandardScaler().fit(np.log1p(X_num.clip(lower=0)))
            if 'Protocol' in X_clean.columns:
                self.ohe = OneHotEncoder(sparse_output=False, handle_unknown='ignore')
                self.ohe.fit(X_clean[['Protocol']])

        # 6. FEATURE SELECTION (Embedded Random Forest)
        print("   > Running Embedded Feature Selection...")
        # Selector input: imputed numerics + frequency-encoded categoricals.
        X_temp = X_clean.copy()
        X_temp[num_cols] = X_num.to_numpy()
        for col in cat_cols:
            freqs = X_temp[col].value_counts()
            # NaN categories have no frequency; encode them as 0 instead of NaN.
            X_temp[col] = X_temp[col].map(freqs).fillna(0)

        rf = RandomForestClassifier(n_estimators=50, max_depth=10, random_state=42, n_jobs=-1)
        rf.fit(X_temp, y)

        self.feature_selector = SelectFromModel(rf, prefit=True, threshold="mean")
        self.selected_features = X_temp.columns[self.feature_selector.get_support()]
        print(f"   > Selected {len(self.selected_features)} features out of {X_temp.shape[1]}")

        return self

    def transform(self, X):
        """Apply the fitted preprocessing to X and return only the selected features.

        The result keeps X's index, so rows stay aligned with their labels. Under
        scheme 'B' the one-hot Protocol columns ("Proto_<i>") are always kept.
        """
        # 1. STATIC & DYNAMIC DROPS
        all_drops = set(self.static_drop_cols) | set(self.dynamic_drop_cols)
        X_processed = X.drop(columns=[c for c in X.columns if c in all_drops])
        X_processed = X_processed.replace([np.inf, -np.inf], np.nan)

        cat_cols, num_cols = self._split_columns(X_processed)

        # 2. IMPUTATION
        X_processed[num_cols] = self.imputer.transform(X_processed[num_cols])

        # 3. SCHEME TRANSFORMS
        if self.scheme == 'A':
            X_processed[num_cols] = self.scaler.transform(X_processed[num_cols])
            for col in cat_cols:
                le = self.encoders.get(col)
                if le is None:
                    continue  # column was absent at fit time; leave it untouched
                labels = X_processed[col].astype(str)
                # Unseen categories map to the first learned class. isin() is hash-based,
                # unlike a per-row `x in le.classes_` scan.
                labels = labels.where(labels.isin(le.classes_), le.classes_[0])
                X_processed[col] = le.transform(labels)

        elif self.scheme == 'B':
            X_processed[num_cols] = np.log1p(X_processed[num_cols].clip(lower=0))
            X_processed[num_cols] = self.scaler.transform(X_processed[num_cols])

            if self.ohe is not None and 'Protocol' in X_processed.columns:
                proto_ohe = self.ohe.transform(X_processed[['Protocol']])
                # Assign positionally so X's original index is preserved (no reset_index).
                for i in range(proto_ohe.shape[1]):
                    X_processed[f"Proto_{i}"] = proto_ohe[:, i]
                X_processed = X_processed.drop(columns=['Protocol'])

        # 4. FILTER SELECTED FEATURES (names that no longer exist, e.g. 'Protocol'
        #    under scheme B, are skipped).
        final_cols = [c for c in self.selected_features if c in X_processed.columns]

        if self.scheme == 'B':
            # Always keep the one-hot Protocol columns; de-duplicate while preserving order
            # so the column layout is identical across calls and runs.
            final_cols += [c for c in X_processed.columns if 'Proto_' in c]
            final_cols = list(dict.fromkeys(final_cols))

        return X_processed[final_cols]

In [None]:
# ==========================================
# 0. SETUP BAD FEATURES LIST
# ==========================================
# `bad_features` is normally produced by the global EDA cell above, shaped like
#   {"constant_columns": [...], "redundant_correlation": [...]}
# On a fresh kernel it may not exist yet, so fall back to empty drop lists.
bad_features_missing = 'bad_features' not in globals()
if bad_features_missing:
    print("Warning: bad_features dictionary not found. Creating empty placeholder.")
    bad_features = dict(constant_columns=[], redundant_correlation=[])

# ==========================================
# 1. PREPARE X AND y
# ==========================================
print("--- 1. SEPARATING X AND y ---")
identifiers = ['Flow ID', 'Source IP', 'Destination IP', 'Timestamp']
# Features = everything except the target and whichever identifier columns are present.
non_feature_cols = ['Label'] + [name for name in identifiers if name in global_df.columns]
X = global_df.drop(columns=non_feature_cols)
y = global_df['Label']

# ==========================================
# 2. SAFE STRATIFIED SPLIT
# ==========================================
print("\n--- 2. HANDLING RARE CLASSES & SPLITTING ---")

# Stratified splitting needs at least 2 members per class, so single-sample classes
# are held out of the split and appended to the training set only.
label_sizes = y.value_counts()
rare_classes = label_sizes[label_sizes < 2].index.tolist()

if rare_classes:
    print(f"   > Detected rare classes (1 sample): {rare_classes}")
    print("   > Removing them from split and forcing them into Train set...")

is_rare = y.isin(rare_classes)
is_common = ~is_rare

X_train_c, X_test, y_train_c, y_test = train_test_split(
    X[is_common],
    y[is_common],
    test_size=0.2,
    random_state=42,
    stratify=y[is_common],
)

# Rare rows go to the training set only; the test set is purely the stratified split.
X_train = pd.concat([X_train_c, X[is_rare]])
y_train = pd.concat([y_train_c, y[is_rare]])

print(f"   > Training Data: {X_train.shape}")
print(f"   > Testing Data:  {X_test.shape}")

# ==========================================
# 3. APPLY PREPROCESSOR
# ==========================================
print("\n--- 3. FITTING PREPROCESSOR ---")

# Scheme 'A' (RobustScaler + label encoding), parameterised with the EDA drop lists.
# Fitting uses the training split only, so nothing is learned from the test rows.
preprocessor = CICIDS_Preprocessor(bad_feature_dict=bad_features, scheme='A')
preprocessor.fit(X_train, y_train)

print("\n--- 4. TRANSFORMING ---")
X_train_processed, X_test_processed = (preprocessor.transform(split) for split in (X_train, X_test))

for caption, frame in (("Final Train Shape: ", X_train_processed),
                       ("Final Test Shape:  ", X_test_processed)):
    print(f"{caption}{frame.shape}")