In [1]:
import tensorboard as tb
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
import numpy as np
import os
import re

In [2]:
class TensorBoardAnalyzer:
    """Collect final test accuracies from TensorBoard logs and render a LaTeX table.

    Every sub-directory of ``log_dir`` is treated as one experiment run whose
    event files live in ``<run>/summary``.  Runs whose directory name contains
    ``baseline`` feed the baseline row; runs whose name contains ``swin`` (and
    does not contain ``98``) are grouped by the method name and prune ratio
    parsed out of the directory name.
    """

    # Scalar tag whose last logged value is taken as a run's final test accuracy.
    TEST_ACC_TAG = 'iter_0/test/acc'
    # Sparsity levels shown as table columns.  The column spec, the baseline row
    # and the "Sparsity" header row in generate_latex_table are laid out for
    # exactly these four values and must be kept in sync with this tuple.
    PRUNE_RATIOS = (60, 80, 90, 95)
    # Methods shown as table rows, in display order.
    METHODS = ("SNIP", "SNIP_DD", "GraSP", "GraSP_SD")

    def __init__(self, log_dir):
        """Remember the root directory that holds one sub-directory per run."""
        self.log_dir = log_dir
        self.experiment_data = {"baseline": []}

    def _read_final_test_acc(self, experiment_path):
        """Return the last ``TEST_ACC_TAG`` value of a run, or None if unavailable.

        Only the most recently created event file in ``<experiment_path>/summary``
        is read.  None is returned when the summary directory, an event file, or
        the scalar tag is missing.
        """
        summary_dir = os.path.join(experiment_path, 'summary')
        # isdir (rather than exists) so a stray regular file named "summary"
        # is skipped instead of making os.listdir raise.
        if not os.path.isdir(summary_dir):
            return None

        event_files = [
            os.path.join(summary_dir, f)
            for f in os.listdir(summary_dir)
            if f.startswith("events.out.tfevents")
        ]
        if not event_files:
            return None

        latest_event_file = max(event_files, key=os.path.getctime)
        event_acc = EventAccumulator(latest_event_file)
        event_acc.Reload()
        if self.TEST_ACC_TAG not in event_acc.Tags()['scalars']:
            return None
        return event_acc.Scalars(self.TEST_ACC_TAG)[-1].value

    def extract_test_acc(self):
        """Extract the final test accuracy of every relevant run under ``log_dir``.

        Returns:
            dict: ``{"baseline": [acc, ...], method: {prune_ratio: [acc, ...]}}``.
            The same dict is stored on ``self.experiment_data``.
        """
        # Start from a clean slate so repeated calls do not append duplicates.
        self.experiment_data = {"baseline": []}

        for experiment_dir in sorted(os.listdir(self.log_dir)):
            full_experiment_path = os.path.join(self.log_dir, experiment_dir)
            if not os.path.isdir(full_experiment_path):
                continue

            # Handle baseline experiments.
            if 'baseline' in experiment_dir:
                test_acc = self._read_final_test_acc(full_experiment_path)
                if test_acc is not None:
                    self.experiment_data["baseline"].append(test_acc)
                continue

            # Handle pruning experiments; names containing "98" are excluded.
            if 'swin' not in experiment_dir or '98' in experiment_dir:
                continue

            method_match = re.search(r'cifar10_swin_vit_(.*?)_prune_ratio', experiment_dir)
            if method_match is None:
                continue
            base_name = method_match.group(1)
            per_ratio = self.experiment_data.setdefault(
                base_name, {pr: [] for pr in self.PRUNE_RATIOS}
            )

            ratio_match = re.search(r'prune_ratio_(\d+)', experiment_dir)
            if ratio_match is None:
                continue
            prune_ratio = int(ratio_match.group(1))

            test_acc = self._read_final_test_acc(full_experiment_path)
            if test_acc is not None:
                # setdefault: a ratio outside PRUNE_RATIOS is still collected
                # (it is simply not shown in the table) instead of raising KeyError.
                per_ratio.setdefault(prune_ratio, []).append(test_acc)

        return self.experiment_data

    def compute_statistics(self, values):
        """Return ``(mean, std)`` of ``values``.

        The standard deviation is NumPy's default population form (``ddof=0``).
        Callers should not pass an empty sequence: NumPy then yields ``nan``.
        """
        mean_val = np.mean(values)
        std_dev_val = np.std(values)
        return mean_val, std_dev_val

    def generate_latex_table(self, stats):
        """Render ``stats`` (as returned by ``extract_test_acc``) as a LaTeX table.

        The baseline row is emitted only when baseline values exist.  Cells with
        no data are rendered as ``-``.  Underscores in method names are escaped
        so the output compiles in LaTeX text mode.
        """
        table_header = r"""
        \begin{table}[!b]
        \centering
        \caption{Test Accuracy (\%) on CIFAR-10 with Swin.}
        \resizebox{\linewidth}{!}{
        \begin{tabular}{|l|c|c|c|c|}
        \hline
        """

        table_rows = ""

        # Baseline row (skipped when there are no values, to avoid printing "nan").
        baseline_values = stats.get("baseline")
        if baseline_values:
            baseline_mean, _ = self.compute_statistics(baseline_values)
            table_rows += "Baseline & {:.2f} & - & - & - \\\\\n".format(baseline_mean)
            table_rows += "\\hline\n"

        # Header for prune ratios
        table_rows += "Sparsity & 60 & 80 & 90 & 95 \\\\\n"
        table_rows += "\\hline\n"

        # Data rows
        for method in self.METHODS:
            # A bare "_" is a subscript operator in LaTeX and errors in text mode.
            row_data = [method.replace("_", "\\_")]
            method_stats = stats.get(method, {})
            for pr in self.PRUNE_RATIOS:
                values = method_stats.get(pr)
                if values:
                    mean, std_dev = self.compute_statistics(values)
                    row_data.append("{:.2f} $\\pm$ {:.2f}".format(mean, std_dev))
                else:
                    row_data.append("-")
            table_rows += " & ".join(row_data) + " \\\\\n\\hline\n"

        table_footer = r"""
        \end{tabular}}
        \label{swin}
        \end{table}
        """

        return table_header + table_rows + table_footer

    @staticmethod
    def _has_values(experiment_data):
        """Return True if at least one accuracy value was collected."""
        for entry in experiment_data.values():
            if isinstance(entry, dict):
                if any(entry.values()):
                    return True
            elif entry:
                return True
        return False

    def analyze(self):
        """Extract accuracies from the logs and return the LaTeX table.

        Raises:
            ValueError: if no accuracy value was found in any run.
        """
        experiment_data = self.extract_test_acc()
        # The dict always carries a "baseline" key, so test for actual values
        # rather than for an empty dict.
        if not self._has_values(experiment_data):
            raise ValueError("No relevant test_acc values found in the TensorBoard logs.")

        latex_table = self.generate_latex_table(experiment_data)
        return latex_table




# Example usage: scan the Swin / CIFAR-10 pruning runs under `log_dir`,
# build the LaTeX results table, and print it to stdout.
log_dir = '/mnt/ssd_3/DDOPaI/runs/pruning/cifar10/swin'
analyzer = TensorBoardAnalyzer(log_dir=log_dir)
latex_table = analyzer.analyze()
print(latex_table)



        \begin{table}[!b]
        \centering
        \caption{Test Accuracy (\%) on CIFAR-10 with Swin.}
        \resizebox{\linewidth}{!}{
        \begin{tabular}{|l|c|c|c|c|}
        \hline
        Baseline & 84.22 & - & - & - \\
\hline
Sparsity & 60 & 80 & 90 & 95 \\
\hline
SNIP & 83.75 $\pm$ 0.80 & 82.85 $\pm$ 0.47 & 80.11 $\pm$ 1.07 & 78.60 $\pm$ 0.55 \\
\hline
SNIP_DD & 84.37 $\pm$ 0.25 & 83.16 $\pm$ 0.66 & 80.60 $\pm$ 0.75 & 79.16 $\pm$ 0.51 \\
\hline
GraSP & 83.72 $\pm$ 0.70 & 82.56 $\pm$ 0.49 & 81.76 $\pm$ 0.18 & 80.55 $\pm$ 0.34 \\
\hline
GraSP_SD & 83.68 $\pm$ 0.64 & 83.49 $\pm$ 0.75 & 82.87 $\pm$ 1.13 & 82.58 $\pm$ 1.53 \\
\hline

        \end{tabular}}
        \label{swin}
        \end{table}
        
