In [3]:
import pandas as pd
import numpy as np

def _sample_per_class(df, label_col, sample_fraction, random_state=42):
    """
    Draw the same fraction of rows from every class found in ``label_col``.

    Every class is sampled independently with the same ``random_state``.
    The groups are iterated explicitly rather than passed through
    ``groupby(...).apply(...)``: having ``apply`` operate on the grouping
    column is deprecated in recent pandas releases, and the label column
    must survive into the sampled output.

    Args:
        df (pd.DataFrame): Dataset to sample from.
        label_col (str): Column name containing (hashable) class labels.
        sample_fraction (float): Fraction of each class to keep.
        random_state (int): Seed used for the draw within each class.

    Returns:
        pd.DataFrame: The sampled rows, ordered by class label.
    """
    pieces = [
        group.sample(frac=sample_fraction, random_state=random_state)
        for _, group in df.groupby(label_col)
    ]
    # Classes too small to contribute a single row yield empty frames;
    # leave them out so concat only sees real data.
    pieces = [piece for piece in pieces if not piece.empty]
    if not pieces:
        # Nothing was sampled (empty input or a zero fraction): return an
        # empty frame that still carries the original columns.
        return df.iloc[0:0]
    return pd.concat(pieces)


def sample_dataset(input_path, output_path, label_col, sample_fraction=0.1):
    """
    Load a dataset, sample a subset while keeping class balance, and save it to a new file.

    The label column is converted to strings before grouping (list/array
    labels become comma-joined strings), and it is written to the output
    file in that converted form.

    Args:
        input_path (str): Path to the input dataset in .parquet format.
        output_path (str): Path to save the sampled dataset in .parquet format.
        label_col (str): Column name containing class labels.
        sample_fraction (float): Fraction of each class to keep, between 0 and 1 (default 10%).

    Returns:
        None: Displays sampling stats and saves the dataset.

    Raises:
        ValueError: If ``sample_fraction`` is outside [0, 1], or if
            ``label_col`` is not a column of the loaded dataset.
    """
    # Reject an unusable fraction before paying for the load
    if not 0 <= sample_fraction <= 1:
        raise ValueError(
            f"sample_fraction must be between 0 and 1, got {sample_fraction!r}"
        )

    print("Loading dataset...")
    # Load the dataset
    df = pd.read_parquet(input_path)
    print(f"Original Dataset Size: {df.shape}")

    # Ensure the label column exists
    if label_col not in df.columns:
        raise ValueError(f"'{label_col}' not found in the dataset columns: {df.columns}")

    # Convert unhashable types in the label column to strings so that the
    # values can be used as group keys
    df[label_col] = df[label_col].apply(
        lambda x: ','.join(map(str, x)) if isinstance(x, (list, np.ndarray)) else str(x)
    )

    # Perform class-balanced sampling
    print(f"Sampling {sample_fraction * 100}% of the dataset while maintaining class balance...")
    sampled_df = _sample_per_class(df, label_col, sample_fraction, random_state=42)

    # Save the sampled dataset
    print(f"Saving sampled dataset to {output_path}...")
    sampled_df.to_parquet(output_path, index=False)

    print(f"Sampled Dataset Size: {sampled_df.shape}")
    print("Sampling complete!")

# Script entry point: runs only when the file is executed directly
if __name__ == "__main__":
    # Class-label column and the share of rows to keep (10%)
    LABEL_COLUMN = "Set_Fingerprint"
    SAMPLE_FRACTION = 0.1

    # Source dataset and the destination for the sampled copy
    INPUT_FILE = "../data/processed/ssh_attacks_decoded.parquet"
    OUTPUT_FILE = "../data/ssh_attacks_sampled_decoded.parquet"

    # Kick off the sampling run
    sample_dataset(
        input_path=INPUT_FILE,
        output_path=OUTPUT_FILE,
        label_col=LABEL_COLUMN,
        sample_fraction=SAMPLE_FRACTION,
    )


Loading dataset...
Original Dataset Size: (233035, 4)
Sampling 10.0% of the dataset while maintaining class balance...
Saving sampled dataset to ../data/ssh_attacks_sampled_decoded.parquet...
Sampled Dataset Size: (23297, 4)
Sampling complete!
