# White-box Performance


In [18]:
import os
import glob
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Union, List, Dict, Tuple

In [19]:
def get_attack_config(attack_type):
    """Look up the analysis/plotting settings for one attack.

    Args:
        attack_type: One of ``"fgsm"``, ``"cw"`` or ``"jsma"``.

    Returns:
        A new dict with the keys ``param_name`` (key inside the CSV's
        ``attack_params`` column), ``param_display``, ``param_format``,
        ``norm_type`` and ``norm_display``.

    Raises:
        ValueError: If ``attack_type`` is not a supported attack.
    """
    # All supported attacks use the L2 norm and a three-decimal parameter
    # format; only the swept parameter (key, axis label) differs.
    swept_parameter = {
        "fgsm": ("eps", "Epsilon"),
        "cw": ("c", "C Parameter"),
        "jsma": ("theta", "Theta"),
    }
    spec = next(
        (entry for name, entry in swept_parameter.items() if attack_type == name),
        None,
    )
    if spec is None:
        raise ValueError(f"Unsupported attack type: {attack_type}")

    key, label = spec
    return {
        "param_name": key,
        "param_display": label,
        "param_format": "{:.3f}",
        "norm_type": "l2",
        "norm_display": "L2",
    }


def extract_data(file_path, encoding_name, attack_config):
    """Load one attack-sweep CSV into a list of flat records.

    Each CSV row becomes a dict with keys ``encoding``, ``param_value``,
    ``accuracy`` and ``pct_change_<norm_type>``. The swept parameter is read
    from the stringified dict in the ``attack_params`` column using
    ``attack_config["param_name"]``; when that cell cannot be parsed or is not
    a dict, ``param_value`` is ``None`` for that row.

    Args:
        file_path: Path of the CSV file to read.
        encoding_name: Label stored in the ``encoding`` field of every record.
        attack_config: Dict from ``get_attack_config`` (uses ``param_name``
            and ``norm_type``).

    Returns:
        List of record dicts. Errors while reading/processing the file are
        printed and the records collected so far are returned (best effort).
    """
    data_points = []

    try:
        df = pd.read_csv(file_path)
        param_name = attack_config["param_name"]
        pct_change_col = f"pct_change_{attack_config['norm_type']}"

        for _, row in df.iterrows():
            try:
                # SECURITY(review): eval() executes arbitrary code stored in
                # the CSV. Only acceptable for trusted, self-generated sweep
                # files (assumed here -- confirm). ast.literal_eval is the safe
                # alternative if the stored params are plain literals.
                param_dict = eval(row["attack_params"])
                param_value = param_dict.get(param_name, None)
            except Exception:
                # Unparseable text, NaN cells (eval rejects non-strings) and
                # non-dict values all yield a missing parameter, not a crash.
                # Unlike a bare except, this lets KeyboardInterrupt through.
                param_value = None

            data_points.append(
                {
                    "encoding": encoding_name,
                    "param_value": param_value,
                    "accuracy": row["accuracy"],
                    pct_change_col: row[pct_change_col],
                }
            )

    except Exception as e:
        # Best effort: report the problem and keep whatever was collected.
        print(f"Error processing {file_path}: {e}")

    return data_points


def estimate_param_for_target_pct_change(
    df: pd.DataFrame,
    target_pct_change: float,
    param_col: str = "param_value",
    pct_change_col: str = "pct_change_linf",
    encoding_col: str = None,
    encoding_value: str = None,
    interpolation_method: str = "linear",
) -> tuple[float, bool]:
    """
    Estimate the parameter value that would produce a specific percentage change.

    Rows are optionally restricted to ``df[encoding_col] == encoding_value``.
    Rows with a missing parameter or percentage change are dropped, and the
    rest are sorted by ``param_col``. The target is then located between two
    consecutive observations and the parameter is interpolated.

    Args:
        df: Sweep data.
        target_pct_change: Desired percentage change in the norm.
        param_col: Column holding the attack parameter.
        pct_change_col: Column holding the percentage change.
        encoding_col: Optional column to filter on.
        encoding_value: Value of ``encoding_col`` to keep.
        interpolation_method: ``"linear"``, or ``"log"`` (interpolate in
            log10-log10 space; all involved values must be positive).

    Returns:
        ``(estimated_param, was_clamped)``. If the target lies outside the
        observed range, the parameter of the nearest extreme observation is
        returned and ``was_clamped`` is ``True``.

    Raises:
        ValueError: If no usable rows remain, if ``interpolation_method`` is
            unknown, or if log interpolation meets non-positive values.
    """
    if encoding_col is not None and encoding_value is not None:
        df_filtered = df[df[encoding_col] == encoding_value]
    else:
        df_filtered = df

    # Rows whose parameter could not be parsed (NaN) cannot be interpolated.
    df_filtered = df_filtered.dropna(subset=[param_col, pct_change_col])
    df_filtered = df_filtered.sort_values(by=param_col)
    param_values = df_filtered[param_col].to_numpy(dtype=float)
    pct_changes = df_filtered[pct_change_col].to_numpy(dtype=float)

    if len(pct_changes) == 0:
        raise ValueError(
            f"No usable rows to estimate '{param_col}' from '{pct_change_col}' "
            f"(filter: {encoding_col}={encoding_value!r})"
        )

    low, high = pct_changes.min(), pct_changes.max()
    if target_pct_change < low:
        print(
            f"Warning: Target {target_pct_change}% is below the available range "
            f"[{low:.4f}%, {high:.4f}%]. Clamping to min."
        )
        return param_values[np.argmin(pct_changes)], True

    if target_pct_change > high:
        print(
            f"Warning: Target {target_pct_change}% is above the available range "
            f"[{low:.4f}%, {high:.4f}%]. Clamping to max."
        )
        return param_values[np.argmax(pct_changes)], True

    if len(pct_changes) == 1:
        # The single observation equals the target: nothing to interpolate.
        return param_values[0], False

    # Prefer an increasing segment (the expected monotone case); otherwise
    # accept a decreasing one. Since the target lies within [min, max], some
    # consecutive pair must bracket it, so the final default is defensive.
    pairs = range(len(pct_changes) - 1)
    lower_idx = next(
        (i for i in pairs if pct_changes[i] <= target_pct_change <= pct_changes[i + 1]),
        None,
    )
    if lower_idx is None:
        lower_idx = next(
            (i for i in pairs if pct_changes[i] >= target_pct_change >= pct_changes[i + 1]),
            len(pct_changes) - 2,
        )

    lower_param = param_values[lower_idx]
    upper_param = param_values[lower_idx + 1]
    lower_pct = pct_changes[lower_idx]
    upper_pct = pct_changes[lower_idx + 1]

    if upper_pct == lower_pct:
        # Flat segment: the target equals both endpoints, and interpolating
        # would divide by zero (0/0 -> NaN). The lower parameter reaches it.
        return lower_param, False

    if interpolation_method == "linear":
        estimated_param = lower_param + (target_pct_change - lower_pct) * (
            upper_param - lower_param
        ) / (upper_pct - lower_pct)
    elif interpolation_method == "log":
        if (
            lower_param <= 0
            or lower_pct <= 0
            or upper_param <= 0
            or upper_pct <= 0
            or target_pct_change <= 0
        ):
            raise ValueError("Log interpolation requires positive values")
        log_lower_param = np.log10(lower_param)
        log_upper_param = np.log10(upper_param)
        log_lower_pct = np.log10(lower_pct)
        log_upper_pct = np.log10(upper_pct)
        log_target = np.log10(target_pct_change)
        log_estimated_param = log_lower_param + (log_target - log_lower_pct) * (
            log_upper_param - log_lower_param
        ) / (log_upper_pct - log_lower_pct)
        estimated_param = 10**log_estimated_param
    else:
        raise ValueError("Method must be 'linear' or 'log'")

    return estimated_param, False


def estimate_accuracy_for_param(
    df: pd.DataFrame,
    param_value: float,
    param_col: str = "param_value",
    accuracy_col: str = "accuracy",
    encoding_col: str = None,
    encoding_value: str = None,
    interpolation_method: str = "linear",
) -> float:
    """
    Estimate the accuracy for a specific parameter value using interpolation.

    Args:
        df: DataFrame containing parameter values and corresponding accuracy values
        param_value: The parameter value to find the accuracy for
        param_col: Column name for the parameter values
        accuracy_col: Column name for the accuracy values
        encoding_col: Optional column name for encoding/dataset type
        encoding_value: Optional specific encoding value to filter by
        interpolation_method: 'linear', or 'log' (log10-log10 space; falls back
            to linear when a bracketing accuracy is <= 0)

    Returns:
        Estimated accuracy at the specified parameter value. Values outside the
        observed range are extrapolated from the nearest segment (a warning is
        printed). With a single observation, its accuracy is returned.

    Raises:
        ValueError: If no usable rows remain, if ``interpolation_method`` is
            unknown, or if log mode meets non-positive parameter values.
    """
    if encoding_col is not None and encoding_value is not None:
        df_filtered = df[df[encoding_col] == encoding_value]
    else:
        df_filtered = df

    # Rows with a missing parameter or accuracy cannot be interpolated.
    df_filtered = df_filtered.dropna(subset=[param_col, accuracy_col])
    df_filtered = df_filtered.sort_values(by=param_col)

    param_values = df_filtered[param_col].to_numpy(dtype=float)
    accuracy_values = df_filtered[accuracy_col].to_numpy(dtype=float)

    if len(param_values) == 0:
        raise ValueError(
            f"No usable rows to estimate '{accuracy_col}' from '{param_col}' "
            f"(filter: {encoding_col}={encoding_value!r})"
        )

    if param_value < param_values.min() or param_value > param_values.max():
        print(
            f"Warning: Parameter value {param_value} is outside the available range "
            f"[{param_values.min():.4f}, {param_values.max():.4f}]. Extrapolating."
        )

    if len(param_values) == 1:
        # Only one observation: constant extrapolation is all that is possible.
        return accuracy_values[0]

    lower_idx = next(
        (
            i
            for i in range(len(param_values) - 1)
            if param_values[i] <= param_value <= param_values[i + 1]
        ),
        None,
    )
    if lower_idx is None:
        # Outside the range: extrapolate from the first or last segment.
        lower_idx = 0 if param_value < param_values[0] else len(param_values) - 2

    # Get the bracketing points
    lower_param = param_values[lower_idx]
    upper_param = param_values[lower_idx + 1]
    lower_acc = accuracy_values[lower_idx]
    upper_acc = accuracy_values[lower_idx + 1]

    if upper_param == lower_param:
        # Zero-width segment (duplicate parameter values): interpolating would
        # divide by zero, so return the accuracy observed at that parameter.
        return lower_acc

    if interpolation_method == "linear":
        estimated_accuracy = lower_acc + (param_value - lower_param) * (
            upper_acc - lower_acc
        ) / (upper_param - lower_param)
    elif interpolation_method == "log":
        if lower_param <= 0 or upper_param <= 0 or param_value <= 0:
            raise ValueError("Log interpolation requires positive parameter values")

        if lower_acc <= 0 or upper_acc <= 0:
            print("Warning: Accuracy values contain zeros. Using linear interpolation.")
            estimated_accuracy = lower_acc + (param_value - lower_param) * (
                upper_acc - lower_acc
            ) / (upper_param - lower_param)
        else:
            log_lower_param = np.log10(lower_param)
            log_upper_param = np.log10(upper_param)
            log_lower_acc = np.log10(lower_acc)
            log_upper_acc = np.log10(upper_acc)
            log_param = np.log10(param_value)

            log_estimated_acc = log_lower_acc + (log_param - log_lower_param) * (
                log_upper_acc - log_lower_acc
            ) / (log_upper_param - log_lower_param)

            estimated_accuracy = 10**log_estimated_acc
    else:
        raise ValueError("Method must be 'linear' or 'log'")

    return estimated_accuracy

def find_accuracy_at_target_pct_change(
    results_df: pd.DataFrame,
    target_pct_change: float,
    attack_config: dict,
    encoding: str,
    interpolation_method: str = "linear",
    verbose: bool = False,
) -> dict:
    """
    Estimate one encoding's accuracy at a target % change in the norm.

    First the attack parameter that reaches ``target_pct_change`` is
    interpolated. Then the accuracy at that parameter is interpolated from
    the same encoding's rows.

    Returns:
        Dict with ``encoding``, ``target_pct_change``, ``param_value``,
        ``param_name``, ``accuracy``, ``norm_type`` and ``was_clamped`` (True
        when the target fell outside the observed range and the parameter was
        clamped to the nearest observation).
    """
    norm = attack_config["norm_type"]
    name = attack_config["param_name"]
    fmt = attack_config.get("param_format", "{:.4f}")

    # Both estimators look at the same encoding's rows with the same method.
    shared_kwargs = dict(
        param_col="param_value",
        encoding_col="encoding",
        encoding_value=encoding,
        interpolation_method=interpolation_method,
    )

    estimated_param, clamped = estimate_param_for_target_pct_change(
        results_df,
        target_pct_change=target_pct_change,
        pct_change_col=f"pct_change_{norm}",
        **shared_kwargs,
    )
    estimated_acc = estimate_accuracy_for_param(
        results_df,
        param_value=estimated_param,
        accuracy_col="accuracy",
        **shared_kwargs,
    )

    if verbose:
        print(f"\nFor {encoding} encoding at {target_pct_change}% change in {norm} norm:")
        print(f"  {name} = {fmt.format(estimated_param)}")
        print(f"  Estimated accuracy: {estimated_acc:.4f}")

    return dict(
        encoding=encoding,
        target_pct_change=target_pct_change,
        param_value=estimated_param,
        param_name=name,
        accuracy=estimated_acc,
        norm_type=norm,
        was_clamped=clamped,
    )


def analyze_at_threshold(
    attack_type: str,
    results_df: pd.DataFrame,
    thresholds: list = None,
    encodings: list = None,
    training_mode: str = "baseline",
    create_table: bool = True,
    baseline_model_metrics: dict = None,
    interpolation_method: str = "log",
) -> pd.DataFrame:
    """
    Estimate accuracy at each % change budget and optionally print a summary.

    For every (threshold, encoding) pair, the attack parameter reaching the
    threshold is interpolated and the accuracy at that parameter is estimated.
    The estimates are then compared against the clean baseline accuracies.
    The printed table has a 'Notes' column flagging targets that were
    outside the observed range (clamped).

    Args:
        attack_type: Attack name accepted by ``get_attack_config``.
        results_df: Sweep data with ``encoding``, ``param_value``,
            ``accuracy`` and ``pct_change_<norm>`` columns.
        thresholds: Target % changes in the norm (default [3.0, 10.0, 20.0]).
        encodings: Encodings to analyse (default: all in ``results_df``).
        training_mode: Label used only in the printed heading.
        create_table: Whether to print the formatted summary table.
        baseline_model_metrics: Mapping encoding -> {"accuracy": float} with
            the clean accuracies; required.
        interpolation_method: 'linear' or 'log' (default 'log').

    Returns:
        One row per (threshold, encoding) with the fields produced by
        ``find_accuracy_at_target_pct_change`` plus ``residual`` (baseline
        minus estimated accuracy) and ``residual_pct_change`` (that drop as a
        percentage of the baseline).

    Raises:
        ValueError: If ``attack_type`` is unsupported.
        TypeError: If ``baseline_model_metrics`` is not provided.
    """
    attack_config = get_attack_config(attack_type)
    if thresholds is None:
        thresholds = [3.0, 10.0, 20.0]
    if baseline_model_metrics is None:
        raise TypeError(
            "analyze_at_threshold() requires baseline_model_metrics "
            "(mapping encoding -> {'accuracy': float})"
        )
    if encodings is None:
        encodings = results_df["encoding"].unique()

    results_table = pd.DataFrame(
        [
            find_accuracy_at_target_pct_change(
                results_df,
                target_pct_change=threshold,
                attack_config=attack_config,
                encoding=encoding,
                interpolation_method=interpolation_method,
            )
            for threshold in thresholds
            for encoding in encodings
        ]
    )

    # Clean accuracy aligned with each row of the table.
    baseline_acc = results_table["encoding"].map(
        lambda enc: baseline_model_metrics[enc]["accuracy"]
    )
    results_table["residual"] = baseline_acc - results_table["accuracy"]
    results_table["residual_pct_change"] = (
        (baseline_acc - results_table["accuracy"]) / baseline_acc * 100
    )

    if create_table:
        print(f"\n{'='*80}")
        print(f"SUMMARY TABLE FOR {attack_type.upper()} ATTACK ({training_mode})")
        print(f"{'='*80}")

        # Format the table for display
        display_table = results_table.copy()
        display_table["param_value"] = display_table["param_value"].map(
            attack_config["param_format"].format
        )
        display_table["baseline_accuracy"] = baseline_acc.map(lambda v: f"{v:.4f}")
        display_table["accuracy"] = display_table["accuracy"].map(lambda x: f"{x:.4f}")
        display_table["target_pct_change"] = display_table["target_pct_change"].map(
            lambda x: f"{x:.1f}%"
        )
        # Show the drop as a negative percentage. Formatting -x (rather than
        # prefixing "-") avoids "--1.23%" when accuracy actually increased.
        display_table["residual_pct_change"] = display_table["residual_pct_change"].map(
            lambda x: f"{-x:.2f}%"
        )
        display_table["Notes"] = display_table["was_clamped"].map(
            lambda x: "*Target not reached" if x else ""
        )

        # Insertion order of this mapping is also the display column order.
        column_renames = {
            "encoding": "Encoding",
            "target_pct_change": f'% Change in {attack_config["norm_display"]}',
            "param_value": attack_config["param_display"],
            "baseline_accuracy": "Baseline Accuracy",
            "accuracy": "Accuracy",
            "residual_pct_change": "Decrease in Accuracy (%)",
            "Notes": "Notes",
        }
        display_table = display_table.rename(columns=column_renames)[
            list(column_renames.values())
        ]

        # option_context restores the caller's previous display settings,
        # unlike reset_option, which resets them to pandas defaults.
        with pd.option_context(
            "display.max_rows", None,
            "display.max_columns", None,
            "display.width", 1000,
        ):
            print(display_table)

    return results_table



def filter_params_for_visibility(results_df, attack_type):
    """Keep only a readable subset of swept parameter values for plotting.

    Args:
        results_df: DataFrame with a ``param_value`` column.
        attack_type: ``"fgsm"``, ``"cw"`` or ``"jsma"``; selects which
            hand-picked parameter values are kept.

    Returns:
        The rows of ``results_df`` whose ``param_value`` is in the attack's
        list (exact match via ``Series.isin``).

    Raises:
        ValueError: If ``attack_type`` has no configured parameter list
            (matching ``get_attack_config``).
    """
    # FGSM and JSMA are swept over the same grid.
    eps_like = [0.0005, 0.001, 0.01, 0.05, 0.1, 0.5, 1]
    keep_by_attack = {
        "fgsm": eps_like,
        "jsma": eps_like,
        "cw": [0.001, 0.01, 0.05, 0.1, 0.5, 1, 10, 100, 1000],
    }
    try:
        keep_params = keep_by_attack[attack_type]
    except KeyError:
        raise ValueError(f"Unsupported attack type: {attack_type}") from None
    return results_df[results_df["param_value"].isin(keep_params)]


def create_accuracy_plot(
    dataset,
    results_df,
    attack_config,
    attack_type,
    training_mode,
    use_all_values=True,
    show_plot=True,
    output_dir=None,
):
    """Plot accuracy against the attack parameter, one line per encoding.

    The x-axis is log-scaled with a tick at every swept parameter value. The
    figure is saved as
    ``{output_dir}/{dataset}_{attack_type}_{mode}_accuracy_comparison.png``,
    where ``mode`` is ``training_mode`` with ``/`` replaced by ``_``. For
    example, ``"baseline"`` gives ``..._baseline_accuracy_comparison.png``,
    and a defended run such as ``"pgd/eps0.01"`` gets its own file.

    Args:
        dataset: Dataset name used in the output filename.
        results_df: Sweep data with ``param_value``, ``accuracy`` and
            ``encoding`` columns.
        attack_config: Dict from ``get_attack_config``.
        attack_type: Attack name (filename and optional parameter filtering).
        training_mode: Training-mode label used in the output filename.
        use_all_values: If False, keep only the parameters selected by
            ``filter_params_for_visibility``.
        show_plot: Display the figure; otherwise it is closed after saving.
        output_dir: Destination directory; defaults to the module-level
            ``OUTPUT_DIR``.

    Returns:
        The ``matplotlib.pyplot`` module.
    """
    if not use_all_values:
        results_df = filter_params_for_visibility(results_df, attack_type)

    param_display = attack_config["param_display"]
    param_format = attack_config["param_format"]

    plt.figure(figsize=(10, 6))
    sns.lineplot(
        data=results_df,
        x="param_value",
        y="accuracy",
        hue="encoding",
        marker="o",
        linewidth=2.5,
        markersize=8,
    )

    plt.xscale("log")

    # Unparsed parameters (NaN) cannot be placed on a log axis.
    param_values = sorted(results_df["param_value"].dropna().unique())
    plt.xticks(param_values, [param_format.format(x) for x in param_values])

    plt.xlabel(f"{param_display} (log scale)", fontsize=12)
    plt.ylabel("Accuracy", fontsize=12)
    plt.grid(True, alpha=0.3)
    plt.legend(title="Encoding")

    plt.tight_layout()

    # Save *before* showing: plt.show() may close/clear the current figure
    # (e.g. on inline/non-interactive backends), which would write a blank PNG.
    target_dir = OUTPUT_DIR if output_dir is None else output_dir
    os.makedirs(target_dir, exist_ok=True)
    mode_tag = str(training_mode).replace("/", "_")
    output_path = f"{target_dir}/{dataset}_{attack_type}_{mode_tag}_accuracy_comparison.png"
    plt.savefig(output_path)

    if show_plot:
        plt.show()
    else:
        plt.close()

    return plt


def create_norm_change_plot(
    dataset,
    results_df,
    attack_config,
    attack_type,
    training_mode,
    use_all_values=True,
    show_plot=True,
    output_dir=None,
    thresholds=None,
):
    """Plot the % change in the norm against the attack parameter.

    One line is drawn per encoding on a log-scaled x-axis. Dashed red
    horizontal lines mark each budget in ``thresholds``. The figure is saved as
    ``{output_dir}/{dataset}_{attack_type}_{mode}_{norm}norm_comparison.png``,
    where ``mode`` is ``training_mode`` with ``/`` replaced by ``_``. For
    example, ``"baseline"`` gives ``..._baseline_l2norm_comparison.png``, and
    defended runs such as ``"pgd/eps0.01"`` get their own file.

    Args:
        dataset: Dataset name used in the output filename.
        results_df: Sweep data with ``param_value``, ``encoding`` and
            ``pct_change_<norm_type>`` columns.
        attack_config: Dict from ``get_attack_config``.
        attack_type: Attack name (filename and optional parameter filtering).
        training_mode: Training-mode label used in the output filename.
        use_all_values: If False, keep only the parameters selected by
            ``filter_params_for_visibility``.
        show_plot: Display the figure; otherwise it is closed after saving.
        output_dir: Destination directory; defaults to the module-level
            ``OUTPUT_DIR``.
        thresholds: Budgets (in %) drawn as reference lines; default (3, 10, 20).

    Returns:
        The ``matplotlib.pyplot`` module.
    """
    if not use_all_values:
        results_df = filter_params_for_visibility(results_df, attack_type)
    if thresholds is None:
        thresholds = (3, 10, 20)

    param_display = attack_config["param_display"]
    param_format = attack_config["param_format"]
    norm_type = attack_config["norm_type"]
    norm_display = attack_config["norm_display"]

    plt.figure(figsize=(10, 6))
    sns.lineplot(
        data=results_df,
        x="param_value",
        y=f"pct_change_{norm_type}",
        hue="encoding",
        marker="o",
        linewidth=2.5,
        markersize=8,
    )

    plt.xscale("log")

    # Unparsed parameters (NaN) cannot be placed on a log axis.
    param_values = sorted(results_df["param_value"].dropna().unique())
    plt.xticks(param_values, [param_format.format(x) for x in param_values])

    plt.xlabel(f"{param_display} (log scale)", fontsize=12)
    plt.ylabel(f"% Change in {norm_display} Norm", fontsize=12)

    for budget in thresholds:
        plt.axhline(
            y=budget, color="r", linestyle="--", linewidth=1, label=f"{budget:g}% Budget"
        )

    plt.grid(True, alpha=0.3)
    plt.legend(title="Encoding")

    plt.tight_layout()

    # Save *before* showing: plt.show() may close/clear the current figure
    # (e.g. on inline/non-interactive backends), which would write a blank PNG.
    target_dir = OUTPUT_DIR if output_dir is None else output_dir
    os.makedirs(target_dir, exist_ok=True)
    mode_tag = str(training_mode).replace("/", "_")
    output_path = (
        f"{target_dir}/{dataset}_{attack_type}_{mode_tag}_{norm_type}norm_comparison.png"
    )
    plt.savefig(output_path)

    if show_plot:
        plt.show()
    else:
        plt.close()

    return plt

def visualize_attack_with_thresholds(
    attack_type="fgsm",
    base_dir="attack_sweep",
    dataset="mirai",
    encodings=None,
    baseline_model_metrics=None,
    training_mode="baseline",
    interpolation_method="linear",
    thresholds=None,
    use_all_values=True,
    show_plots=True,
):
    """
    Generate visualizations and analysis for the specified attack type with threshold analysis.

    Reads ``{base_dir}/{dataset}/{training_mode}/{encoding}_{attack_type}.csv``
    for each encoding and draws the accuracy and %-norm-change plots. It then
    prints the parameter estimated for each threshold and builds the summary
    table via ``analyze_at_threshold``.

    Args:
        attack_type: Attack name accepted by ``get_attack_config``.
        base_dir: Root directory of the sweep CSVs.
        dataset: Dataset sub-directory / filename prefix.
        encodings: Encodings to load (default ["DM", "Stats", "Raw"]).
            Encodings whose CSV yields no rows are skipped with a warning.
        baseline_model_metrics: Mapping encoding -> {"accuracy": float},
            forwarded to ``analyze_at_threshold``.
        training_mode: Sub-directory of the CSVs (e.g. "baseline" or
            "pgd/eps0.01").
        interpolation_method: Used for the printed per-threshold parameter
            estimates. NOTE(review): it is not forwarded to
            ``analyze_at_threshold``, which picks its own method; confirm the
            two agree.
        thresholds: Target % changes (default [3.0, 10.0, 20.0]); used for
            both the printed breakpoints and the summary table.
        use_all_values: If False, plots show only a filtered parameter subset.
        show_plots: Display the figures (they are saved either way).

    Returns:
        ``(accuracy_plot, norm_plot, results_table, results_df)``, or four
        ``None`` values if no CSV produced any data.
    """
    if encodings is None:
        encodings = ["DM", "Stats", "Raw"]
    if thresholds is None:
        thresholds = [3.0, 10.0, 20.0]

    attack_config = get_attack_config(attack_type)

    file_paths = {
        encoding: f"{base_dir}/{dataset}/{training_mode}/{encoding}_{attack_type}.csv"
        for encoding in encodings
    }

    all_data = []
    for encoding, file_path in file_paths.items():
        all_data.extend(extract_data(file_path, encoding, attack_config))

    results_df = pd.DataFrame(all_data)

    if results_df.empty:
        print("No data to plot. Exiting.")
        return None, None, None, None

    # An encoding whose CSV was missing/unreadable has no rows; estimating on
    # it would fail on an empty selection, so it is skipped instead.
    available = set(results_df["encoding"])
    missing = [enc for enc in encodings if enc not in available]
    if missing:
        print(f"Warning: no data for encoding(s) {missing}; skipping them.")
        encodings = [enc for enc in encodings if enc in available]

    accuracy_plot = create_accuracy_plot(
        dataset,
        results_df,
        attack_config,
        attack_type,
        training_mode,
        use_all_values,
        show_plots,
    )

    norm_plot = create_norm_change_plot(
        dataset,
        results_df,
        attack_config,
        attack_type,
        training_mode,
        use_all_values,
        show_plots,
    )

    # The % change column is determined by the attack's configured norm.
    pct_change_col = f"pct_change_{attack_config['norm_type']}"

    for encoding in encodings:
        for target_pct_change in thresholds:
            suggested_param, _ = estimate_param_for_target_pct_change(
                results_df,
                target_pct_change=target_pct_change,
                param_col="param_value",
                encoding_col="encoding",
                encoding_value=encoding,
                pct_change_col=pct_change_col,
                interpolation_method=interpolation_method,
            )
            print(
                f"Breakpoint {target_pct_change}, Param for {encoding} encoding: {suggested_param:.4f}"
            )

    results_table = analyze_at_threshold(
        attack_type=attack_type,
        results_df=results_df,
        thresholds=thresholds,
        encodings=encodings,
        training_mode=training_mode,
        baseline_model_metrics=baseline_model_metrics,
    )

    return accuracy_plot, norm_plot, results_table, results_df

## BASELINE


In [None]:
# Baseline (undefended) models: run every dataset x attack combination.
TRAINING_MODE = "baseline"

# Passed to visualize_attack_with_thresholds, which uses it for the printed
# per-threshold parameter estimates.
INTERPOLATION_METHOD = "log"

# NOTE: the plotting helpers read OUTPUT_DIR as a module-level global, so it
# must be defined before they are called.
OUTPUT_DIR = "results/attack_sweep/plots"
ENCODINGS = ["DM", "Stats", "Raw"]
# CSVs are read from {BASE_DIR}/{dataset}/{TRAINING_MODE}/{encoding}_{attack}.csv
BASE_DIR = "results/attack_sweep"

DATASETS = ["mirai", "unsw-nb15"] 
ATTACK_TYPES = ["fgsm", "cw", "jsma"]

for dataset in DATASETS:
    # Clean (unattacked) accuracy per encoding. The summary table reports the
    # estimated accuracy drop relative to these values. "Flows" is not in
    # ENCODINGS, so its entry is currently unused.
    if dataset == "mirai":
        baseline_model_metrics = {
            "DM": {"accuracy": 0.9916},
            "Stats": {"accuracy": 0.9918},
            "Raw": {"accuracy": 0.6817},
            "Flows": {"accuracy": 0.6714}
        }
    elif dataset == "unsw-nb15":
        baseline_model_metrics = {
            "DM": {"accuracy": 0.9682},
            "Stats": {"accuracy": 0.9703},
            "Raw": {"accuracy": 0.8641},  
            "Flows": {"accuracy": 0.8542} 
        }
    else:
        raise ValueError(f"Unsupported dataset: {dataset}")

    for attack_type in ATTACK_TYPES:
        print(f"Running analysis for {dataset} dataset with {attack_type} attack...")
        # show_plots=False: figures are saved under OUTPUT_DIR and then closed.
        # The printed breakpoints and summary table are the output of interest;
        # the returned values are not used further in this cell.
        accuracy_plot, norm_plot, results_table, _ = visualize_attack_with_thresholds(
            attack_type=attack_type,
            base_dir=BASE_DIR,
            dataset=dataset,
            encodings=ENCODINGS,
            training_mode=TRAINING_MODE,
            interpolation_method=INTERPOLATION_METHOD,
            use_all_values=True,
            baseline_model_metrics=baseline_model_metrics,
            show_plots=False,
        )

Running analysis for mirai dataset with fgsm attack...
Breakpoint 3.0, Param for DM encoding: 0.0024
Breakpoint 10.0, Param for DM encoding: 0.0095
Breakpoint 20.0, Param for DM encoding: 0.0196
Breakpoint 3.0, Param for Stats encoding: 0.0129
Breakpoint 10.0, Param for Stats encoding: 0.0438
Breakpoint 20.0, Param for Stats encoding: 0.0882
Breakpoint 3.0, Param for Raw encoding: 0.0146
Breakpoint 10.0, Param for Raw encoding: 0.0491
Breakpoint 20.0, Param for Raw encoding: 0.1002

SUMMARY TABLE FOR FGSM ATTACK (baseline)
  Encoding % Change in L2 Epsilon Baseline Accuracy Accuracy Decrease in Accuracy (%) Notes
0       DM           3.0%   0.002            0.9916   0.9916                   -0.00%      
1    Stats           3.0%   0.013            0.9918   0.9916                   -0.02%      
2      Raw           3.0%   0.015            0.6817   0.4670                  -31.49%      
3       DM          10.0%   0.009            0.9916   0.9916                   -0.00%      
4    Stats 

## DEFENSES


In [None]:
def format_results_for_table(
    results_collection,
    attack_types,
    training_modes,
    eps_values,
    encodings,
    budgets=None,
    dataset_label="UNSW",
):
    """
    Format results into LaTeX tables and CSV files for easy copying.

    Parameters:
    - results_collection: Nested dict where
      results_collection[budget][encoding][attack][training][eps] = accuracy.
      Missing entries are rendered as an em dash (LaTeX) or an empty cell (CSV).
    - attack_types: List of attack types (e.g., ['fgsm', 'cw'])
    - training_modes: List of training modes (e.g., ['pgd', 'trades', 'mart'])
    - eps_values: List of epsilon values (e.g., [0.001, 0.01, 0.1])
    - encodings: List of encoding types (e.g., ['DM', 'Stats', 'Raw'])
    - budgets: List of budget percentages (default [3, 10, 20])
    - dataset_label: Dataset name shown in the caption; its lower-cased form
      is used in the LaTeX label (default "UNSW").

    Returns:
    - Dict mapping each training mode to {"latex": str, "csv": str}.
    """
    if budgets is None:
        budgets = [3, 10, 20]

    # Budget + Encoding columns, then one column per (attack, eps) pair. This
    # width is used for both the title \multicolumn and the first \cmidrule.
    n_cols = 2 + len(attack_types) * len(eps_values)
    col_spec = ("c" * len(eps_values) + "|") * len(attack_types)
    label_key = dataset_label.lower()

    def lookup(budget, encoding, attack, training_mode, eps):
        # Tolerate a missing level anywhere in the nested dict.
        return (
            results_collection.get(budget, {})
            .get(encoding, {})
            .get(attack, {})
            .get(training_mode, {})
            .get(eps)
        )

    results = {}

    for training_mode in training_modes:
        # Generate LaTeX table
        latex = f"""\\begin{{table*}}[t]
  \\centering
  \\begin{{subtable}}[t]{{0.49\\textwidth}}
    \\caption{{{dataset_label}: Accuracy of defenses}}
    \\label{{tab:{label_key}_{training_mode}_defense}}
    \\scriptsize
    \\setlength{{\\tabcolsep}}{{2pt}}
      \\begin{{tabular}}{{lc|{col_spec}}}
    \\toprule
    \\multicolumn{{{n_cols}}}{{c}}{{\\textbf{{{training_mode.upper()} ($\\epsilon$)}}}} \\\\
    \\cmidrule(lr){{1-{n_cols}}}
    \\multirow{{2}}{{*}}{{Budget}} & \\multirow{{2}}{{*}}{{Enc.}} """

        # Epsilon header row, one group per attack
        for _ in attack_types:
            for eps in eps_values:
                latex += f"& {eps} "
            latex += "\n       "

        latex += """\\\\
    """

        # One cmidrule under each attack's epsilon group
        start_col = 3
        for _ in attack_types:
            end_col = start_col + len(eps_values) - 1
            latex += f"\\cmidrule(lr){{{start_col}-{end_col}}}"
            start_col = end_col + 1

        # Attack-name header row
        latex += "\n      & "
        for attack in attack_types:
            latex += f"& \\multicolumn{{{len(eps_values)}}}{{c|}}{{{attack.upper()}}} "

        latex += """\\\\
    \\midrule\n"""

        # Data rows: one multirow block per budget
        for budget in budgets:
            latex += f"    \\multirow{{{len(encodings)}}}{{*}}{{{budget}\\%}} \n"

            for encoding in encodings:
                latex += f"     & {encoding}   "
                for attack in attack_types:
                    for eps in eps_values:
                        value = lookup(budget, encoding, attack, training_mode, eps)
                        latex += f" & {value:.4f}" if value is not None else " & —"
                latex += " \\\\\n"

            if budget != budgets[-1]:
                latex += "    \\midrule\n"

        latex += """    \\bottomrule
  \\end{tabular}
  \\end{subtable}
\\end{table*}"""

        # Generate CSV (same values; missing entries left empty)
        header = ["Budget", "Encoding"] + [
            f"{attack.upper()}/{eps}" for attack in attack_types for eps in eps_values
        ]
        csv_lines = [",".join(header)]
        for budget in budgets:
            for encoding in encodings:
                cells = [f"{budget}%", f"{encoding}"]
                for attack in attack_types:
                    for eps in eps_values:
                        value = lookup(budget, encoding, attack, training_mode, eps)
                        cells.append(f"{value:.4f}" if value is not None else "")
                csv_lines.append(",".join(cells))
        csv = "\n".join(csv_lines) + "\n"

        results[training_mode] = {"latex": latex, "csv": csv}

    return results

In [None]:
DATASET = "mirai"

ATTACK_TYPES = ["fgsm", "jsma", "cw"]
TRAINING_MODES = ["pgd", "trades", "mart"]
EPS_VALUES = [0.001, 0.01, 0.1]


INTERPOLATION_METHOD = "log"

# NOTE: the plotting helpers read OUTPUT_DIR as a module-level global.
OUTPUT_DIR = "results/attack_sweep/plots"
ENCODINGS = ["DM", "Stats", "Raw"]
BASE_DIR = "results/attack_sweep"

# Clean (unattacked) accuracy per encoding; the summary tables report the
# estimated accuracy drop relative to these values.
if DATASET == "mirai":
    baseline_model_metrics = {
        "DM": {"accuracy": 0.9916},
        "Stats": {"accuracy": 0.9918},
        "Raw": {"accuracy": 0.6817},
    }
elif DATASET == "unsw-nb15":
    baseline_model_metrics = {
        "DM": {"accuracy": 0.9682},
        "Stats": {"accuracy": 0.9703},
        "Raw": {"accuracy": 0.8641},
    }
else:
    raise ValueError(f"Unsupported dataset: {DATASET}")

# results_collection[budget][encoding][attack][training_mode][eps] = accuracy,
# the nesting expected by format_results_for_table.
results_collection = {}


for attack_type in ATTACK_TYPES:
    for training_mode in TRAINING_MODES:
        for eps in EPS_VALUES:
            # CSVs are read from {BASE_DIR}/{DATASET}/{training_mode}/eps{eps}/.
            accuracy_plot, norm_plot, results_table, results_df = (
                visualize_attack_with_thresholds(
                    attack_type=attack_type,
                    base_dir=BASE_DIR,
                    dataset=DATASET,
                    encodings=ENCODINGS,
                    training_mode=f"{training_mode}/eps{eps}",
                    interpolation_method=INTERPOLATION_METHOD,
                    use_all_values=True,
                    baseline_model_metrics=baseline_model_metrics,
                    show_plots=False,
                )
            )
            print(
                f"Results for {attack_type} attack with {training_mode} training and eps={eps}:"
            )

            print(results_df)

            if results_table is None:
                # No sweep CSVs were found for this configuration.
                continue

            # results_table has one row per (budget, encoding), holding the
            # accuracy interpolated at that % change in the norm. The raw
            # sweep rows in results_df are not aligned to budgets, so they
            # must not be bucketed by row position.
            for _, row in results_table.iterrows():
                target = float(row["target_pct_change"])
                # Store 3.0 as 3 so keys read naturally next to integer budgets.
                budget = int(target) if target.is_integer() else target
                per_eps = (
                    results_collection.setdefault(budget, {})
                    .setdefault(row["encoding"], {})
                    .setdefault(attack_type, {})
                    .setdefault(training_mode, {})
                )
                per_eps[eps] = float(row["accuracy"])

Breakpoint 3.0, Param for DM encoding: 0.0024
Breakpoint 10.0, Param for DM encoding: 0.0094
Breakpoint 20.0, Param for DM encoding: 0.0194
Breakpoint 3.0, Param for Stats encoding: 0.0132
Breakpoint 10.0, Param for Stats encoding: 0.0448
Breakpoint 20.0, Param for Stats encoding: 0.0882
Breakpoint 3.0, Param for Raw encoding: 0.0147
Breakpoint 10.0, Param for Raw encoding: 0.0492
Breakpoint 20.0, Param for Raw encoding: 0.1000

SUMMARY TABLE FOR FGSM ATTACK (pgd/eps0.001)
  Encoding % Change in L2 Epsilon Baseline Accuracy Accuracy Decrease in Accuracy (%)                Notes
0       DM           3.0%   0.002            0.9916   0.9916                   -0.00%                     
1    Stats           3.0%   0.013            0.9918   0.9914                   -0.04%                     
2      Raw           3.0%   0.015            0.6817   0.5589                  -18.01%  *Target not reached
3       DM          10.0%   0.009            0.9916   0.9916                   -0.00%         