In [None]:
# Run this cell in Google Colab to set up the environment
import os
if 'COLAB_GPU' in os.environ or 'COLAB_RELEASE_TAG' in os.environ:
    !git clone https://github.com/ManuelVigelius/spatially_expanding_flow.git
    os.chdir('spatially_expanding_flow')
    !pip install -q diffusers transformers accelerate datasets plotly

In [None]:
import os

import numpy as np
import plotly.graph_objects as go
from datasets import load_dataset
from huggingface_hub import login
from plotly.subplots import make_subplots

import config
from experiments import run_all_experiments

# Authenticate with the Hugging Face Hub without writing a secret into the notebook.
# If the HF_TOKEN environment variable is set (and non-empty), it is used.
# Otherwise token=None makes `login` prompt for a token interactively.
login(token=os.environ.get("HF_TOKEN") or None)

# COCO validation split; only the first `config.num_samples` examples are used below.
dataset = load_dataset("detection-datasets/coco", split="val")

In [None]:
# Materialise the first `config.num_samples` dataset rows as a plain list,
# then run every configured experiment over them.
sample_indices = range(config.num_samples)
dataset_samples = [example for example in dataset.select(sample_indices)]

all_experiment_results = run_all_experiments(
    config.experiment_configs,
    dataset_samples,
    config,
)

In [None]:
# Subplot (row, col) for each experiment key in `all_experiment_results`.
experiment_positions = {
    "sd3_caption=False": (1, 1),
    "flux_caption=False": (1, 2),
    "auraflow_caption=False": (1, 3),
}

# One colour per compression size; cycled if there are more sizes than colours.
colors = [
    '#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
    '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf'
]

# One figure per metric: mean loss vs. timestep for each compression size,
# with one subplot per model.
for metric, metric_label in [("velocity", "Velocity MSE"), ("latent", "Latent MSE")]:
    fig = make_subplots(
        rows=1, cols=3,
        subplot_titles=["SD3", "FLUX", "AuraFlow"],
        horizontal_spacing=0.08,
    )
    # Labels that already have a legend entry. Tracking this (rather than
    # showing the legend only for column 1) keeps the legend intact even when
    # the first experiment's results are missing.
    labels_in_legend = set()

    for exp_key, (row, col) in experiment_positions.items():
        if exp_key not in all_experiment_results:
            print(f"Skipping {exp_key}: no results found for {metric_label}")
            continue

        # results: {timestep: {compression_size: [per-sample losses]}}
        results = all_experiment_results[exp_key][metric]
        avg_results = {t: {size: np.mean(losses) for size, losses in sizes.items()}
                       for t, sizes in results.items()}
        t_values = sorted(avg_results.keys())

        for i, key in enumerate(config.compression_sizes):
            loss_values = [avg_results[t][key] for t in t_values]
            label = f'Size {key}'

            fig.add_trace(
                go.Scatter(
                    x=t_values, y=loss_values, mode='lines+markers',
                    name=label, line=dict(color=colors[i % len(colors)]),
                    marker=dict(size=4),
                    showlegend=label not in labels_in_legend,
                    # Shared legendgroup toggles a size across all subplots at once.
                    legendgroup=label,
                ),
                row=row, col=col
            )
            labels_in_legend.add(label)

    fig.update_layout(
        title=f'{metric_label}: SD3 vs FLUX vs AuraFlow',
        height=450, width=1200,
        legend=dict(orientation="h", yanchor="bottom", y=-0.25, xanchor="center", x=0.5),
        hovermode='x unified'
    )
    # Without row/col these apply to every subplot's axes.
    fig.update_xaxes(title_text="Timestep Index")
    fig.update_yaxes(title_text=metric_label)
    fig.show()