In [53]:
import json
import pandas as pd
from tools.project import OUTPUT_PATH, RAW_PATH
import plotly.express as px
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import numpy as np

In [2]:
def load_metrics(json_path):
    """
    Load a metrics JSON file into a DataFrame.

    The file is expected to hold a mapping of ``row label -> {metric: value}``;
    each top-level key becomes one row of the result (``orient='index'``).

    Args:
        json_path: Path of the JSON file to read.

    Returns:
        pandas.DataFrame indexed by the top-level JSON keys, with one column
        per metric name.
    """
    # JSON is UTF-8 by specification; decode explicitly so non-ASCII row labels
    # do not depend on the platform's default locale encoding.
    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    return pd.DataFrame.from_dict(data, orient='index')

In [5]:
def rename_columns(df):
    """
    Return ``df`` with its metric columns renamed to LaTeX-ready headers.

    Every ``<metric>_<variant>`` column for the metrics fad/clap/kncc/knco and
    the variants ti/other is mapped to ``<LABEL>$_{<variant>}$``, for example
    fad_ti -> FAD$_{ti}$ and kncc_other -> kNCC$_{other}$. Columns outside
    that set keep their names. The input frame is not modified.
    """
    display_labels = {"fad": "FAD", "clap": "CLAP", "kncc": "kNCC", "knco": "kNCO"}
    variants = ("ti", "other")
    # Doubled braces emit literal "{" / "}" so the subscript reads $_{ti}$.
    latex_headers = {
        f"{metric}_{variant}": f"{label}$_{{{variant}}}$"
        for metric, label in display_labels.items()
        for variant in variants
    }
    return df.rename(columns=latex_headers)

In [11]:
# Export the "_ti" metrics as a LaTeX table with one row per entry of the
# comparison JSON.
metrics_df = load_metrics(OUTPUT_PATH('comparison', 'ti-musicgen-clean.json'))[['fad_ti', 'clap_ti', 'kncc_ti', 'knco_ti']]
latex_df = rename_columns(metrics_df).round(3)
latex_str = latex_df.to_latex(
    index=True,            # Keep the row labels (the category names)
    escape=False,          # Column headers already contain LaTeX markup
    float_format="%.3f",
    caption="Metrics Table with Mean and Variance",  # Example caption
    label="tab:metrics",   # Example label
)

# Write the table as UTF-8 explicitly: row labels come from the JSON keys and
# may contain non-ASCII characters, which the platform default encoding
# (e.g. cp1252 on Windows) cannot always represent.
with open(OUTPUT_PATH('visualization', "metrics_table.tex"), "w", encoding="utf-8") as f:
    f.write(latex_str)

In [41]:
# Radar ("spider") charts: one polar subplot per metric, overlaying the "_ti"
# and "_other" columns of the comparison metrics.
spider_df = load_metrics(OUTPUT_PATH('comparison', 'ti-musicgen-clean.json'))
metrics = ['fad','fad_clap', 'clap']

# 2x2 grid of polar axes; the fourth cell stays empty and gets a blank title.
fig = make_subplots(
    rows=2,
    cols=2,
    specs=[[{'type': 'polar'} for _ in range(2)] for _ in range(2)],
    vertical_spacing=0.15,
    subplot_titles=[*map(str.upper, metrics), ''],
)


def get_position(i):
    """Map a flat subplot index to its 1-based (row, col) cell on a two-column grid."""
    row_idx, col_idx = divmod(i, 2)
    return (row_idx + 1, col_idx + 1)


# (column suffix, legend label, trace colour) for the two compared series.
series = (
    ("ti", "Inwersja Tekstowa + MusicGen", 'blue'),
    ("other", "MusicGen", 'red'),
)

for idx, metric in enumerate(metrics):
    row, col = get_position(idx)
    # Only the first subplot contributes legend entries; the shared
    # legendgroup ties the same series together across subplots.
    in_legend = idx == 0
    for suffix, label, colour in series:
        trace = go.Scatterpolar(
            r=spider_df[f"{metric}_{suffix}"],
            theta=spider_df.index,
            fill='toself',
            name=label,
            legendgroup=label,
            showlegend=in_legend,
            marker_color=colour,
        )
        fig.add_trace(trace, row=row, col=col)

# Horizontal legend centred below the plot area.
legend_style = {
    "orientation": "h",
    "yanchor": "bottom",
    "y": -0.2,
    "xanchor": "center",
    "x": 0.5,
    "font": {"size": 14},
    "bgcolor": "rgba(255,255,255,0.8)",  # white, semi-transparent background
}
fig.update_layout(
    showlegend=True,
    height=1000,
    width=800,
    legend=legend_style,
)
fig.write_image(OUTPUT_PATH('comparison', 'spider.png'))
fig.show()

In [59]:
# Load the raw per-run statistics. JSON is UTF-8 by specification, so decode it
# explicitly instead of relying on the platform's default encoding.
with open(RAW_PATH('run_stats', 'stats.json'), 'r', encoding='utf-8') as fh:
    data = json.load(fh)

def filter_by_config(query: dict[str, any]):
    stats = {}
    for k, v in data.items():
        if not all(v['params'][q_k] == q_v for q_k, q_v in query.items()):
            continue
        stats[k] = v
    return stats
# Keep only the runs whose params have tokens_num == 5.
filtered_data = filter_by_config(query={'tokens_num': 5})

In [68]:
def flatten_stats(data: dict):
    """Return a ``run id -> stats`` mapping, dropping every other field of each run record."""
    flattened = {}
    for run_id, run in data.items():
        flattened[run_id] = run['stats']
    return flattened

def get_min_stats(data: dict):
    """
    Summarise the per-run minimum of every statistic across runs.

    For each run, the minimum of each series in ``run['stats']`` is taken;
    those minima are then aggregated over all runs.

    Args:
        data: Mapping of run id -> run record with a ``'stats'`` dict of
            ``stat name -> sequence of values``.

    Returns:
        Dict with ``'<stat>_mean'`` and ``'<stat>_std'`` entries (NumPy mean and
        population standard deviation of the per-run minima), ordered by the
        first appearance of each statistic. Empty input yields an empty dict.
    """
    # Gather one minimum per run for each statistic name.
    minima_by_stat = {}
    for run in data.values():
        for name, series in run['stats'].items():
            minima_by_stat.setdefault(name, []).append(np.min(series))

    # Collapse the per-run minima into mean / std summary values.
    summary = {}
    for name, minima in minima_by_stat.items():
        summary[f'{name}_mean'] = np.mean(minima)
        summary[f'{name}_std'] = np.std(minima)
    return summary


Unnamed: 0,fad_choir_mean,fad_choir_std,fad_piano_mean,fad_piano_std,fad_two-steps_mean,fad_two-steps_std,fad_8bit_mean,fad_8bit_std,fad_b-minor-rock_mean,fad_b-minor-rock_std,...,fad_pirates_mean,fad_pirates_std,fad_metal-solos_mean,fad_metal-solos_std,fad_saxophone-chillout_mean,fad_saxophone-chillout_std,fad_8bit-slow_mean,fad_8bit-slow_std,fad_avg_mean,fad_avg_std
5,4.710213,0.36561,3.801444,0.508196,3.923909,0.736688,8.697683,0.841894,1.542337,0.452974,...,2.87146,0.536448,4.036265,1.222431,2.550864,0.519466,5.463827,0.909338,4.097182,0.310381
10,3.919329,0.265587,2.877155,0.489414,3.466273,0.619526,7.121143,1.511341,1.257988,0.133727,...,2.314139,0.210021,3.434294,0.554083,1.643668,0.52871,5.221861,0.55742,3.476175,0.263368
20,3.430175,0.353355,2.575861,0.419546,3.037685,0.436505,6.299074,1.225706,1.262737,0.412945,...,3.004607,0.367701,2.692576,0.590241,2.281065,0.973261,4.212825,0.750629,3.391981,0.298357


In [85]:
# Token counts swept in the experiment; each one becomes a table column.
token_counts = [5, 10, 20]

# One row per token count (positional index 0..n-1), holding the
# "<stat>_mean" / "<stat>_std" summaries returned by get_min_stats.
stats_df = pd.DataFrame([
    get_min_stats(filter_by_config({'tokens_num': x}))
    for x in token_counts
])

# Row labels of the LaTeX table: the FAD statistic names with the "fad_" prefix
# and "_mean" suffix stripped. Matching on prefix/suffix (rather than substring
# replacement) guarantees the f"fad_{col}_mean" lookups below resolve.
stats_columns = [
    col.removeprefix('fad_').removesuffix('_mean')
    for col in stats_df.columns
    if col.startswith('fad_') and col.endswith('_mean')
]

# Raw strings keep the LaTeX backslashes literal; "\#" and "\p" are invalid
# escape sequences in ordinary string literals.
header_for = {tokens: rf'\# tokenów={tokens}' for tokens in token_counts}
stats_latex_df = pd.DataFrame(index=stats_columns, columns=list(header_for.values()))

for i, tokens in enumerate(token_counts):
    for col in stats_columns:
        mean = stats_df.loc[i, f"fad_{col}_mean"]
        std = stats_df.loc[i, f"fad_{col}_std"]
        stats_latex_df.loc[col, header_for[tokens]] = rf"{mean:.2f} $\pm$ {std:.2f}"

# Every cell is already a formatted string, so no float_format is needed.
# The caption and headers contain Polish diacritics, so write UTF-8 explicitly
# instead of relying on the platform default encoding.
with open(OUTPUT_PATH('visualization', "sweep_runs_table.tex"), "w", encoding="utf-8") as f:
    f.write(stats_latex_df.to_latex(
        index=True,            # Keep the row labels (the category names)
        escape=False,          # Headers and cells already contain LaTeX markup
        caption="FAD ze względu na ilość tokenów",
        label="tab:fad_by_token_num",
    ))