In [1]:
from reverb.training.utils import DEFAULT_TRAINING_KWARGS, DEFAULT_MODEL_KWARGS, DEFAULT_DATA_KWARGS
import segmentation_models_pytorch as smp

# Registry of multiclass ablation experiments, keyed by a short experiment name.
# Each entry holds the run-name prefix (a repeat index is appended per run) and
# the three keyword groups that are handed to the training/evaluation helpers.
multiclass_experiments = {
    "baseline": dict(
        run_name="ablations/multiclass/baseline",
        training_kwargs=dict(
            max_epochs=25,
            # One weight per class (three classes, see model_kwargs below).
            class_weights=[0.5, 1.0, 1.0],
        ),
        model_kwargs=dict(classes=3),
        data_kwargs=dict(feature_class="multiclass"),
    ),
}

In [2]:
from reverb.training.utils import train, get_eval_dataloaders, compute_results_over_eval_sets, save_evaluation_results
# Build the evaluation dataloaders for the "multiclass" feature class once;
# the same object is passed to every evaluation call in the training loop below.
eval_dataloaders = get_eval_dataloaders(feature_class="multiclass")


loading annotations into memory...
Done (t=0.24s)
creating index...
index created!
loading annotations into memory...
Done (t=0.05s)
creating index...
index created!
loading annotations into memory...
Done (t=0.00s)
creating index...
index created!
loading annotations into memory...
Done (t=0.00s)
creating index...
index created!
loading annotations into memory...
Done (t=0.01s)
creating index...
index created!
loading annotations into memory...
Done (t=0.01s)
creating index...
index created!


In [None]:
# Train and evaluate every configured experiment twice (repeat indices 1 and 2).
# Each repeat gets its own run name: "<run_name>_<i>".
for experiment, experiment_config in multiclass_experiments.items():
    # The keyword groups are the same for every repeat of an experiment.
    training_kwargs = experiment_config['training_kwargs']
    model_kwargs = experiment_config['model_kwargs']
    data_kwargs = experiment_config['data_kwargs']

    for i in (1, 2):
        run_name = f"{experiment_config['run_name']}_{i}"

        # Supervised training for this repeat.
        train(
            run_name=run_name,
            mode="supervised",
            model_kwargs=model_kwargs,
            data_kwargs=data_kwargs,
            training_kwargs=training_kwargs,
        )

        # Score the trained run on all evaluation sets and persist the metrics.
        results = compute_results_over_eval_sets(
            run_name,
            eval_dataloaders,
            model_kwargs=model_kwargs,
            training_kwargs=training_kwargs,
        )
        save_evaluation_results(run_name, results)


In [3]:
import os
import json
import pandas as pd
experiment_names = multiclass_experiments.keys()

# Root directory containing one folder per repeat, named '<experiment>_<n>'
# (e.g. 'baseline_1/', 'baseline_2/'), matching the run names used for training.
experiments_root = './checkpoints/ablations/multiclass'

# Metrics copied from each eval_results.json into the long-format table.
metrics_of_interest = ('miou', 'precision', 'recall')

# One record per (experiment, repeat folder, dataset, metric).
flattened_data = []

for exp_name in experiment_names:
    prefix = exp_name + '_'
    # Keep only '<exp_name>_<digits>' directories. A bare prefix test would also
    # pull in repeats of any other experiment whose name extends this one
    # (e.g. 'baseline_aug_1' would be pooled under 'baseline').
    # Sorting makes the row order of the per-repeat CSV reproducible, because
    # os.listdir returns entries in arbitrary order.
    matching_folders = sorted(
        d for d in os.listdir(experiments_root)
        if d.startswith(prefix)
        and d[len(prefix):].isdigit()
        and os.path.isdir(os.path.join(experiments_root, d))
    )

    for folder in matching_folders:
        results_path = os.path.join(experiments_root, folder, 'eval_results.json')
        if not os.path.isfile(results_path):
            # No evaluation results saved in this folder; skip it.
            continue
        with open(results_path, 'r') as f:
            datasets = json.load(f)
        for dataset, metrics in datasets.items():
            for metric, value in metrics.items():
                if metric in metrics_of_interest:
                    flattened_data.append({
                        'Experiment': exp_name,  # Group under common experiment name
                        'Repeat': folder,
                        'Dataset': dataset,
                        'Metric': metric,
                        'Value': value
                    })

# Fail with an explicit message instead of the opaque KeyError that grouping a
# column-less, empty DataFrame would raise further down.
if not flattened_data:
    raise FileNotFoundError(
        f"No eval_results.json with any of {metrics_of_interest} found under "
        f"{experiments_root!r} for experiments {list(experiment_names)}"
    )

# Convert to DataFrame (long format: one row per repeat/dataset/metric)
df = pd.DataFrame(flattened_data)

# Compute mean and SEM over repeats for each experiment.
# Note: SEM is NaN for any group that has only a single repeat.
mean_df = (
    df.groupby(['Experiment', 'Dataset', 'Metric'])['Value']
    .mean()
    .reset_index()
    .rename(columns={'Value': 'Mean'})
)

sem_df = (
    df.groupby(['Experiment', 'Dataset', 'Metric'])['Value']
    .sem()
    .reset_index()
    .rename(columns={'Value': 'Std_Error'})
)

# Merge summaries into one table with both 'Mean' and 'Std_Error' columns
summary_df = pd.merge(mean_df, sem_df, on=['Experiment', 'Dataset', 'Metric'])

# Save outputs (written to the current working directory)
df.to_csv('individual_repeat_results.csv', index=False)
summary_df.to_csv('multiclass_experiment_summary.csv', index=False)

print("Saved individual repeat results and summary statistics.")


Saved individual repeat results and summary statistics.


In [None]:
# Restrict the summary to the mIoU rows.
is_miou = summary_df['Metric'] == 'miou'
miou_df = summary_df.loc[is_miou]

# Show one table per dataset, in order of first appearance; the 'Metric'
# column is dropped because it is constant after the filter above.
for dataset in miou_df['Dataset'].unique():
    print(f"\n--- Dataset: {dataset} ---")
    per_dataset = miou_df.loc[miou_df['Dataset'] == dataset]
    display(per_dataset.drop(columns=['Metric']))
