In [None]:
import os
import csv
import statistics

# Root directory; each sub-folder in it holds one experiment
experiments_dir = "experiments"

# Row accumulators (one dict per experiment) for the two output tables
parameters_rows = []  # For the parameters table
summary_rows = []     # For the summary table

# Union of every parameter name encountered across experiments
all_param_keys = set()

# Metric names expected in the main summary CSV (ASR is not among them),
# each mapped to the shorter column it is rendered in. The "Average ..." and
# "Std ..." metrics of a pair share one column because they are shown as avg±std.
summary_column_map = {
    "Average Best Loss": "Loss",
    "Std Best Loss": "Loss",
    "Average Gradient Time": "Grad (s)",
    "Std Gradient Time": "Grad (s)",
    "Average Sampling Time": "Sampling (s)",
    "Std Sampling Time": "Sampling (s)",
    "Average PGD Time": "PGD (s)",
    "Std PGD Time": "PGD (s)",
    "Average Loss Time": "LossTime (s)",
    "Std Loss Time": "LossTime (s)",
    "Average Total Time": "Total (s)",
    "Std Total Time": "Total (s)",
}

# The metric names on their own, in the mapping's insertion order.
summary_keys = list(summary_column_map)

# Final column order for the results table (no "k" column here):
# the experiment ID, each distinct short column once, then the combined
# average ± std ASR.
final_result_keys = [
    "Experiment",
    *dict.fromkeys(summary_column_map.values()),
    "ASR (%)",
]

# Display headers for the summary table: same as the keys, but with the
# percent sign escaped so it is not read as a LaTeX comment.
final_result_display = [name.replace("%", "\\%") for name in final_result_keys]

def latex_escape(text: str) -> str:
    """Escape characters that would break a LaTeX table cell.

    Handles ``_`` and ``%`` as well as ``&`` (which would otherwise be read
    as a column separator), ``#`` and ``$``. All other characters, including
    ``±``, are passed through unchanged.

    Args:
        text: Raw cell or header text.

    Returns:
        The text with each special character prefixed by a backslash.
    """
    # A single translate pass guarantees that the backslashes inserted for
    # one character are never re-processed by a later replacement.
    specials = str.maketrans({
        "_": "\\_",
        "%": "\\%",
        "&": "\\&",
        "#": "\\#",
        "$": "\\$",
    })
    return text.translate(specials)

def format_numeric(value: str, param_name: str = "") -> str:
    """Format one table cell, rendering numbers consistently.

    Rules, applied in order:
      - ``None`` (a missing CSV field) becomes an empty cell.
      - The literals "True" / "False" are returned unchanged.
      - Text that does not parse as a number is returned with LaTeX special
        characters escaped.
      - For integer parameters (seed, iter, search_width, etc.; matched
        case-insensitively on ``param_name``) the number is truncated and
        shown as an integer. Non-finite values (nan, inf) cannot be
        converted and are returned as escaped text instead.
      - Any other number is shown as a float with 4 decimals.

    Args:
        value: Raw cell text, typically read from a CSV file.
        param_name: Name of the column the value belongs to.

    Returns:
        The formatted cell text.
    """
    # csv.DictReader yields None for fields missing from a short row.
    if value is None:
        return ""
    if value in ("True", "False"):
        return value

    try:
        number = float(value)
    except ValueError:
        return latex_escape(value)

    integer_params = (
        "seed", "iter", "search_width", "min_search_width",
        "num_steps", "num_prompts", "k",
    )
    if param_name.lower() in integer_params:
        try:
            return str(int(number))
        except (ValueError, OverflowError):
            # int() raises ValueError for nan and OverflowError for +/-inf.
            return latex_escape(value)
    return f"{number:.4f}"

def combine_avg_std(avg_val: str, std_val: str) -> str:
    """Render an average and its standard deviation as ``avg±std``.

    Both numbers are shown with 4 decimals. When either input is empty,
    equals 'nan' (case-insensitive) or cannot be parsed as a float, the
    placeholder "0.0000±0.0000" is returned instead.
    """
    placeholder = "0.0000±0.0000"

    # Reject missing / explicit-nan inputs before attempting to parse.
    for raw in (avg_val, std_val):
        if not raw or raw.lower() == "nan":
            return placeholder

    try:
        mean, spread = float(avg_val), float(std_val)
    except ValueError:
        return placeholder
    return f"{mean:.4f}±{spread:.4f}"

# Gather experiment folders in sorted order so that the IDs assigned below
# are stable from run to run.
experiment_folders = sorted(
    folder for folder in os.listdir(experiments_dir)
    if os.path.isdir(os.path.join(experiments_dir, folder))
)

# Process each experiment folder and assign IDs (exp1, exp2, etc.)
for i, folder in enumerate(experiment_folders, start=1):
    exp_id = f"exp{i}"
    folder_path = os.path.join(experiments_dir, folder)

    # ------------------ PARAMETERS TABLE ------------------
    # parameters.csv is read as Parameter/Value pairs; "debug_output" is
    # deliberately left out of the table.
    parameters = {}
    params_file = os.path.join(folder_path, "parameters.csv")
    if os.path.exists(params_file):
        with open(params_file, newline="") as csvfile:
            for row in csv.DictReader(csvfile):
                param = row["Parameter"]
                if param == "debug_output":
                    continue
                parameters[param] = row["Value"]
                all_param_keys.add(param)

    # ------------------ EVALUATION: ASR AND "k" ------------------
    # evaluation/summary.csv is read exactly once. Reading it a second time
    # into the same list would count every ASR value twice, which leaves the
    # mean unchanged but biases the sample standard deviation low.
    eval_summary_file = os.path.join(folder_path, "evaluation", "summary.csv")
    asr_values = []
    k_values = []
    if os.path.exists(eval_summary_file):
        with open(eval_summary_file, newline="") as csvfile:
            for row in csv.DictReader(csvfile):
                # A row without a usable ASR is skipped entirely.
                try:
                    asr = float(row["ASR (%)"])
                except (KeyError, TypeError, ValueError):
                    continue
                asr_values.append(asr)
                # A usable ASR is kept even if "total" cannot be parsed.
                try:
                    k_values.append(int(row["total"]))
                except (KeyError, TypeError, ValueError):
                    continue

    # "k" is the rounded mean of the "total" column, or "N/A" when the file
    # is missing or has no usable rows.
    if k_values:
        parameters["k"] = f"{round(statistics.mean(k_values))}"
    else:
        parameters["k"] = "N/A"
    all_param_keys.add("k")

    param_row = {"Experiment": exp_id}
    param_row.update(parameters)
    parameters_rows.append(param_row)

    # ------------------ RESULTS (SUMMARY) TABLE ------------------
    # summary.csv is read as Metric/Value pairs.
    summary = {}
    summary_file = os.path.join(folder_path, "summary.csv")
    if os.path.exists(summary_file):
        with open(summary_file, newline="") as csvfile:
            for row in csv.DictReader(csvfile):
                summary[row["Metric"]] = row["Value"]

    # Combined average ± std ASR; stdev needs at least two data points.
    if asr_values:
        avg_asr = statistics.mean(asr_values)
        std_asr = statistics.stdev(asr_values) if len(asr_values) > 1 else 0.0
        summary["ASR (%)"] = f"{avg_asr:.4f}±{std_asr:.4f}"
    else:
        summary["ASR (%)"] = "N/A"

    # Build a new row for the summary table: each column combines the
    # "Average <metric>" and "Std <metric>" entries of summary.csv.
    summary_row = {"Experiment": exp_id}
    for column, metric in (
        ("Loss", "Best Loss"),
        ("Grad (s)", "Gradient Time"),
        ("Sampling (s)", "Sampling Time"),
        ("PGD (s)", "PGD Time"),
        ("LossTime (s)", "Loss Time"),
        ("Total (s)", "Total Time"),
    ):
        summary_row[column] = combine_avg_std(
            summary.get(f"Average {metric}", ""),
            summary.get(f"Std {metric}", ""),
        )
    summary_row["ASR (%)"] = summary.get("ASR (%)", "N/A")
    summary_rows.append(summary_row)

# ----------------------------------------------------------------------------
# 1) TABLE FOR PARAMETERS
# ----------------------------------------------------------------------------
# "Experiment" comes first; the remaining columns are sorted case-insensitively.
param_columns = ["Experiment"]
param_columns.extend(sorted(all_param_keys, key=lambda name: name.lower()))

# Header cells are escaped; body cells go through format_numeric, which
# picks the number format from the column name.
escaped_headers = [latex_escape(col) for col in param_columns]
param_body_lines = [
    "  "
    + " & ".join(
        format_numeric(row.get(col, ""), param_name=col) for col in param_columns
    )
    + " \\\\\n"
    for row in parameters_rows
]

# Assemble the table in one pass; \resizebox scales it to the text width.
latex_params = "".join(
    [
        "\\begin{table}[ht]\n",
        "\\centering\n",
        "\\resizebox{\\textwidth}{!}{%\n",
        "  \\begin{tabular}{" + "l" * len(param_columns) + "}\n",
        "  \\hline\n",
        "  " + " & ".join(escaped_headers) + " \\\\\n",
        "  \\hline\n",
        *param_body_lines,
        "  \\hline\n",
        "  \\end{tabular}\n",
        "} % end of resizebox\n",
        "\\caption{Experiment parameters.}\n",
        "\\label{tab:experiment_parameters}\n",
        "\\end{table}\n",
    ]
)

# ----------------------------------------------------------------------------
# 2) TABLE FOR RESULTS (SUMMARY)
# ----------------------------------------------------------------------------
# Columns whose cells already hold final text (avg±std or "N/A"); they are
# only escaped. Any other column is passed through format_numeric.
text_columns = (
    "Grad (s)", "Sampling (s)", "PGD (s)", "LossTime (s)", "Total (s)",
    "Loss", "ASR (%)",
)

summary_body_lines = []
for row in summary_rows:
    cells = []
    for key in final_result_keys:
        raw_cell = row.get(key, "")
        if key in text_columns:
            cells.append(latex_escape(raw_cell))
        else:
            cells.append(format_numeric(raw_cell, param_name=key))
    summary_body_lines.append("  " + " & ".join(cells) + " \\\\\n")

# Assemble the table in one pass; \resizebox scales it to the text width.
# The display headers are used as-is because they are already escaped.
latex_summary = "".join(
    [
        "\\begin{table}[ht]\n",
        "\\centering\n",
        "\\resizebox{\\textwidth}{!}{%\n",
        "  \\begin{tabular}{" + "l" * len(final_result_display) + "}\n",
        "  \\hline\n",
        "  " + " & ".join(final_result_display) + " \\\\\n",
        "  \\hline\n",
        *summary_body_lines,
        "  \\hline\n",
        "  \\end{tabular}\n",
        "} % end of resizebox\n",
        "\\caption{Summary of experiment results.}\n",
        "\\label{tab:experiment_summary}\n",
        "\\end{table}\n",
    ]
)

# Emit both tables, each preceded by a LaTeX comment banner.
print("%% ========== TABLE: PARAMETERS ========== %%")
print(latex_params)
print("\n%% ========== TABLE: SUMMARY ========== %%")
print(latex_summary)


\begin{table}[ht]
\centering
\resizebox{\textwidth}{!}{%
  \begin{tabular}{llllllllllllllll}
  \hline
  Experiment & alpha & dynamic\_search & eps & gcg\_attack & joint\_eval & k & min\_search\_width & model & name & num\_prompts & num\_steps & pgd\_after\_gcg & pgd\_attack & search\_width & seed \\
  \hline
  exp1 & 4/255 & False & 64/255 & False & False & 25.0000 & 0 & llava & Llava - PGD Only & 20 & 600 & False & True & 0 & 1 \\
  exp2 & 4/255 & False & 96/255 & True & True & 25.0000 & 512 & llava & Gemma - Joint Eval Long & 5 & 400 & False & True & 512 & 1 \\
  exp3 & 0/255 & False & 0/255 & True & False & 25.0000 & 512 & llava & Llava - GCG Only & 20 & 250 & False & False & 512 & 1 \\
  exp4 & 4/255 & False & 64/255 & False & False & 25.0000 & 0 & gemma & Gemma - PGD Only & 10 & 600 & False & True & 0 & 1 \\
  exp5 & 0/255 & False & 0/255 & True & False & 25.0000 & 512 & gemma & Gemma - GCG Only & 10 & 250 & False & False & 512 & 1 \\
  exp6 & 4/255 & False & 64/255 & True & False