In [1]:
import os
import warnings
from glob import glob

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [11]:
def running_mean(x, N=50):
    """Compute the running (moving) mean of a 1-D sequence with window size N.

    Each output element is the unweighted mean of ``N`` consecutive input
    values. Only windows lying entirely inside ``x`` are produced
    (``np.convolve(..., mode='valid')`` semantics).

    Parameters
    ----------
    x : array_like
        1-D sequence of numbers to smooth.
    N : int, default 50
        Window size; must be >= 1.

    Returns
    -------
    numpy.ndarray
        Array of length ``len(x) - N + 1``, or an empty array when ``x`` is
        shorter than the window (including when ``x`` is empty).

    Raises
    ------
    ValueError
        If ``N`` is smaller than 1.
    """
    if N < 1:
        raise ValueError(f"window size N must be >= 1, got {N}")
    values = np.asarray(x)
    if values.size < N:
        # np.convolve swaps its arguments when the kernel is longer than the
        # data, which would silently return N - len(x) + 1 meaningless values.
        return np.empty(0, dtype=float)
    return np.convolve(values, np.ones(N) / N, mode='valid')

def _finish_plot(fig, ax, title, ylabel, save_path):
    """Label, grid and save one of plot_dataset's figures (the caller closes it)."""
    ax.set_title(title)
    ax.set_xlabel("Steps")
    ax.set_ylabel(ylabel)
    # Only draw a legend if at least one run was plotted; otherwise matplotlib
    # emits a "No artists with labels found" warning.
    if ax.get_legend_handles_labels()[0]:
        ax.legend()
    ax.grid(True)
    fig.savefig(save_path, dpi=300)


def plot_dataset(base_path, dataset_name, window=10, step_stride=10):
    """Plot smoothed noise-norm and iso-value curves for every ``reg_*`` run of a dataset.

    For each sub-folder of ``base_path/dataset_name`` whose name starts with
    ``reg_``, the first CSV (in sorted filename order) is read. It is
    subsampled to every ``step_stride``-th row and smoothed with a running
    mean of size ``window``. Each run is drawn as one labelled line. Two
    figures are saved into ``base_path``:

    * ``<dataset_name>_norm.png`` -- ``Predicted_Noise_L2Norm`` vs ``Step``
    * ``<dataset_name>_iso.png``  -- ``IsoValue`` vs ``Step``

    Runs with fewer than ``window`` rows after striding cannot be smoothed.
    They are skipped with a ``UserWarning``.

    Parameters
    ----------
    base_path : str
        Directory containing the dataset folders; the PNGs are written here.
    dataset_name : str
        Name of the dataset sub-folder to plot.
    window : int, default 10
        Running-mean window, counted in rows *after* striding. Must be >= 1.
    step_stride : int, default 10
        Keep every ``step_stride``-th CSV row before smoothing. Must be >= 1.

    Raises
    ------
    ValueError
        If ``window`` or ``step_stride`` is smaller than 1.
    """
    if window < 1 or step_stride < 1:
        raise ValueError(
            f"window and step_stride must be >= 1, got window={window}, "
            f"step_stride={step_stride}"
        )

    dataset_path = os.path.join(base_path, dataset_name)
    reg_folders = sorted(
        f for f in os.listdir(dataset_path)
        if f.startswith("reg_") and os.path.isdir(os.path.join(dataset_path, f))
    )

    norm_fig, norm_ax = plt.subplots(figsize=(10, 6))
    iso_fig, iso_ax = plt.subplots(figsize=(10, 6))
    try:
        for reg in reg_folders:
            # Sorted so "the first CSV" is deterministic (glob order is arbitrary).
            csv_files = sorted(glob(os.path.join(dataset_path, reg, "*.csv")))
            if not csv_files:
                continue

            df = pd.read_csv(csv_files[0]).iloc[::step_stride]
            if len(df) < window:
                warnings.warn(
                    f"{dataset_name}/{reg}: only {len(df)} rows after striding by "
                    f"{step_stride}, fewer than window={window}; skipping.",
                    stacklevel=2,
                )
                continue

            # Each smoothed value is aligned with the last step of its window,
            # so both arrays have length len(df) - window + 1.
            steps_smooth = df["Step"].to_numpy()[window - 1:]
            norm_smooth = running_mean(df["Predicted_Noise_L2Norm"].to_numpy(), N=window)
            iso_smooth = running_mean(df["IsoValue"].to_numpy(), N=window)

            norm_ax.plot(steps_smooth, norm_smooth, label=reg)
            iso_ax.plot(steps_smooth, iso_smooth, label=reg)

        _finish_plot(
            norm_fig, norm_ax,
            title=f"Dataset: {dataset_name} - Norm",
            ylabel="Predicted_Noise_L2Norm",
            save_path=os.path.join(base_path, f"{dataset_name}_norm.png"),
        )
        _finish_plot(
            iso_fig, iso_ax,
            title=f"Dataset: {dataset_name} - Iso",
            ylabel="IsoValue",
            save_path=os.path.join(base_path, f"{dataset_name}_iso.png"),
        )
    finally:
        # Always release both figures, even if reading or plotting a run fails,
        # so looping over many datasets does not accumulate open figures.
        plt.close(norm_fig)
        plt.close(iso_fig)

In [13]:
# Root directory; every sub-directory is treated as a dataset and passed to plot_dataset.
base_dir = "logs/norms_iso_2025-08-23_12-58-03"

# Smoothing parameters: keep every STEP_STRIDE-th CSV row,
# then take a running mean over WINDOW of the kept rows.
WINDOW = 50
STEP_STRIDE = 50

# Sorted so datasets are processed in a reproducible order (os.listdir order is arbitrary).
datasets = sorted(
    d for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d))
)

for ds in datasets:
    plot_dataset(base_dir, ds, window=WINDOW, step_stride=STEP_STRIDE)