In [None]:
import pandas as pd
from tensorboard.backend.event_processing import event_accumulator
import os, re, shutil
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
def load_nested_tb_logs(root_dir):
    """Collect every scalar series found under ``root_dir`` into one DataFrame.

    The directory tree is walked recursively. Every folder that directly
    contains at least one file whose name starts with ``events.out.tfevents``
    is loaded with an ``EventAccumulator``; each scalar tag found there
    contributes its events as rows.

    Args:
        root_dir: Path of the directory tree to search.

    Returns:
        A long-format ``pandas.DataFrame``. Besides the fields carried by the
        scalar events themselves it has ``metric_tag`` (the scalar tag) and
        ``run_group`` (name of the folder that holds the event file), and
        ``wall_time`` is converted from epoch seconds to datetime. When no
        scalar data is found a message is printed and an empty frame that
        still has the expected columns is returned.
    """
    # Column layout of the result. The first three are presumably the fields
    # of TensorBoard's ScalarEvent -- TODO confirm against the installed version.
    expected_columns = ['wall_time', 'step', 'value', 'metric_tag', 'run_group']
    all_data = []

    for root, _dirs, files in os.walk(root_dir):
        # Only folders that directly hold an event file are of interest.
        if not any(f.startswith("events.out.tfevents") for f in files):
            continue

        # normpath strips a trailing separator first; without it
        # basename("some/run/") is "" and the run label would be lost.
        folder_name = os.path.basename(os.path.normpath(root))

        # NOTE(review): EventAccumulator is created with its default size
        # guidance, which may keep only a sample of very long scalar series
        # -- confirm whether that matters for these logs.
        acc = event_accumulator.EventAccumulator(root)
        acc.Reload()

        for tag in acc.Tags()['scalars']:
            events = acc.Scalars(tag)
            if not events:
                # A tag without events would produce a frame that has no
                # 'wall_time' column and break the conversion below.
                continue

            df_temp = pd.DataFrame(events)
            df_temp['metric_tag'] = tag
            df_temp['run_group'] = folder_name
            all_data.append(df_temp)

    if not all_data:
        print("No event files found in the specified path.")
        # Keep the schema so that column look-ups downstream do not raise.
        return pd.DataFrame(columns=expected_columns)

    # Combine all found data
    master_df = pd.concat(all_data, ignore_index=True)

    # wall_time arrives as seconds since the epoch.
    master_df['wall_time'] = pd.to_datetime(master_df['wall_time'], unit='s')

    return master_df

In [None]:
# Directory tree that is searched for TensorBoard event files.
path = "../train_outputs/2059323/Test_record/ACL_ViT16_aclifa_2gpu/tensorboard/epoch0"

# Gather every scalar found below `path` into a single DataFrame.
df = load_nested_tb_logs(path)

# Quick sanity check: number of rows loaded and the first few of them.
print(f"Loaded {len(df)} data points.")
print(df.head())

In [None]:
def get_thr(ss: str):
    """Extract the first parenthesised number in ``ss`` as a float.

    Args:
        ss: Label to search, e.g. a folder name containing ``(0.5)``.

    Returns:
        The number inside the first ``(<number>)`` group as a ``float``, or
        ``None`` when ``ss`` contains no such group.
    """
    # Accept "12", "12.", "12.5" and ".5" only. A looser digits-and-dots
    # class would also match "(.)" or "(1.2.3)" and make float() raise.
    match = re.search(r'\((\d+\.?\d*|\.\d+)\)', ss)
    if match is None:
        return None
    return float(match.group(1))

def get_metric(ss: str):
    """Return whatever follows the first ``(<digits/dots>)_`` in ``ss``.

    The captured text runs to the end of the string. ``None`` is returned
    when ``ss`` contains no such marker.
    """
    found = re.search(r'\([\d\.]+\)_(.*)$', ss)
    return found.group(1) if found is not None else None

def get_audio_type(ss: str):
    """Return the first of ``std``, ``silence`` or ``noise`` occurring in ``ss``.

    ``None`` is returned when none of the three keywords is present.
    """
    hit = re.search(r'(std|silence|noise)', ss)
    if hit is None:
        return None
    return hit[1]

def get_dataset(ss: str):
    """Return the first known dataset name occurring in ``ss``.

    The recognised names are the alternatives of the pattern below. ``None``
    is returned when none of them occurs.
    """
    dataset_pattern = re.compile(
        r'(avs_ms3|avs_s4|vggss|exvggss|vggsound|flickr|exflickr|avatar)'
    )
    found = dataset_pattern.search(ss)
    if found is None:
        return None
    return found[1]

In [None]:
# Derive the plotting columns from the raw labels, then drop the raw columns.
# Because this step removes `run_group` / `metric_tag`, running the cell a
# second time used to fail with a KeyError; the guard makes a re-run a no-op.
if {'wall_time', 'metric_tag', 'run_group'}.issubset(df.columns):
    run_labels = df['run_group'].astype(str)
    tag_labels = df['metric_tag'].astype(str)

    # NOTE(review): audio_type is parsed from the scalar tag while the other
    # three fields come from the folder name -- confirm this is intended.
    df = df.assign(
        threshold=run_labels.map(get_thr),
        metric=run_labels.map(get_metric),
        audio_type=tag_labels.map(get_audio_type),
        dataset=run_labels.map(get_dataset),
    ).drop(columns=['wall_time', 'metric_tag', 'run_group'])
else:
    print("Raw label columns not found in df (already parsed, or nothing was loaded); skipping.")

In [None]:
# Print the column summary (dtypes, non-null counts) of df.
df.info()
# Uncomment to list the rows whose `dataset` is null:
# df[df['dataset'].isnull()]

In [None]:
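# One-off maintenance snippet, kept commented out: renames the entries of `path` whose name contains '(s4)' so that they use '_s4' instead.
# wrong_list = list(filter(lambda x: re.search(r'\(s4\)', x), os.listdir(path)))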
# corrected_list = [x.replace('(s4)', '_s4') for x in wrong_list]

# for wr, corr in zip(wrong_list, corrected_list):
#     # print(os.path.join(path, wr), '-->', os.path.join(path, corr))
#     shutil.move(os.path.join(path, wr), os.path.join(path, corr))

In [None]:
def plot_all_metrics(df):
    """Draw one line plot per distinct, non-null value of ``df['metric']``.

    Each figure plots ``value`` against ``threshold``, coloured by ``dataset``
    and with the line style selected by ``audio_type``. Every figure is shown
    with ``plt.show()`` and then closed; nothing is returned.

    Args:
        df: DataFrame providing the columns ``metric``, ``threshold``,
            ``value``, ``dataset`` and ``audio_type``.
    """
    sns.set_theme(style="whitegrid")

    # Dash pattern per audio type: "" draws a solid line, a tuple is an
    # (on, off) dash sequence.
    style_map = {
        'std': "",          # Solid
        'noise': (5, 5),    # Dashed
        'silence': (1, 2),  # Dotted
    }

    # Rows whose metric is null can never satisfy `== m` below, so they would
    # only yield an empty figure titled "Metric: None"; leave them out.
    metrics = df['metric'].dropna().unique()

    for m in metrics:
        # Rows of this metric, ordered by threshold.
        subset = df[df['metric'] == m].sort_values('threshold')

        # Explicit figure/axes so each plot is configured and closed on its own.
        fig, ax = plt.subplots(figsize=(10, 6))

        sns.lineplot(
            data=subset,
            x='threshold',
            y='value',
            hue='dataset',        # Color by dataset
            style='audio_type',   # Line style by audio type
            dashes=style_map,     # Apply our specific styles
            markers=True,         # Keep dots on the data points
            linewidth=2,
            ax=ax,
        )

        ax.set_title(f"Metric: {m}", fontsize=15, fontweight='bold')
        ax.set_xlabel("Threshold")
        ax.set_ylabel("Value")
        ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')  # Legend outside the plot
        fig.tight_layout()

        # Optional: save each plot
        # fig.savefig(f"plot_{m}.png")

        plt.show()
        # Release the figure so many metrics do not pile up open figures.
        plt.close(fig)


In [None]:
# Draw one figure per metric present in df.
plot_all_metrics(df)