In [None]:
from reverb.training.utils import DEFAULT_TRAINING_KWARGS, DEFAULT_MODEL_KWARGS, DEFAULT_DATA_KWARGS
import segmentation_models_pytorch as smp

# Experiment registry. Each top-level key is passed as `mode` to `train` by the
# training loop in this notebook; `run_name` is the checkpoint/run identifier,
# to which the loop appends a repeat index.
experiments = {
    "supervised": {
        "run_name": "whale/supervised_cw0.1",
        "training_kwargs": dict(
            max_epochs=60,
            lr=1e-4,
            class_weights=[0.1, 1.0],
        ),
        "model_kwargs": DEFAULT_MODEL_KWARGS,
        "data_kwargs": {"feature_class": "whale_truncated"},
    },
    "semi_supervised": {
        "run_name": "whale/semi_supervised_cw0.1",
        "training_kwargs": dict(
            max_epochs=60,
            alpha=0.98,
            lr=1e-4,
            class_weights=[0.1, 1.0],
            consistency_lambda=0.1,
        ),
        "model_kwargs": DEFAULT_MODEL_KWARGS,
        "data_kwargs": {"feature_class": "whale_truncated"},
    },
    # Two-stage experiment: `training_kwargs` drive the first stage and
    # `finetuning_kwargs` the subsequent supervised fine-tuning stage.
    "synthetic_pretrain": {
        "run_name": "whale/synthetic_pretrain",
        "training_kwargs": dict(
            max_epochs=25,
            lr=5e-5,
            class_weights=[0.1, 1.0],
        ),
        "finetuning_kwargs": dict(
            max_epochs=30,
            lr=5e-6,
            class_weights=[0.1, 1.0],
        ),
        "model_kwargs": DEFAULT_MODEL_KWARGS,
        "data_kwargs": {"feature_class": "whale_truncated"},
    },
}

In [None]:
from reverb.training.utils import train, get_eval_dataloaders, compute_results_over_eval_sets, save_evaluation_results
# Evaluation dataloaders for the same feature class the experiments train on.
# Later cells look up the 'valid' entry of this mapping.
eval_dataloaders = get_eval_dataloaders(
    feature_class="whale_truncated",
)


In [None]:
# Train every experiment three times (repeat index appended to the run name),
# fine-tune the synthetic-pretraining runs, and evaluate each final run.
for experiment, experiment_config in experiments.items():
    # Per-experiment settings do not change between repeats.
    model_kwargs = experiment_config['model_kwargs']
    data_kwargs = experiment_config['data_kwargs']
    training_kwargs = experiment_config['training_kwargs']
    needs_finetuning = experiment == "synthetic_pretrain"

    for i in range(3):
        repeat_run_name = f"{experiment_config['run_name']}_{i}"
        # The pre-training stage is stored under a "_pre" suffix so that the
        # fine-tuned model can take the plain repeat name.
        run_name = (repeat_run_name + "_pre") if needs_finetuning else repeat_run_name

        # Train the model
        train(
            run_name=run_name,
            mode=experiment,
            model_kwargs=model_kwargs,
            data_kwargs=data_kwargs,
            training_kwargs=training_kwargs,
            background_dir="../analysis/results/UP05_predictions/no_whales",
        )

        if needs_finetuning:
            # Fine-tune in supervised mode, starting from the run trained above.
            train(
                run_name=repeat_run_name,
                mode="supervised",
                model_kwargs=model_kwargs,
                data_kwargs=data_kwargs,
                training_kwargs=experiment_config.get('finetuning_kwargs', {}),
                pretrain_path=run_name,
            )
            # Evaluate the fine-tuned run, not the pre-training stage.
            run_name = repeat_run_name

        # Evaluate the model
        results = compute_results_over_eval_sets(
            run_name,
            eval_dataloaders,
            model_kwargs=model_kwargs,
            threshold=0.8,
        )
        save_evaluation_results(run_name, results)


In [None]:
import os
import json
import pandas as pd
experiment_names = experiments.keys()

# Root directory holding one sub-folder per run.
# NOTE(review): assumes a run named 'whale/<name>_<i>' is stored under
# './checkpoints/whale/<name>_<i>/' -- confirm against `train` / `save_evaluation_results`.
experiments_root = './checkpoints/whale'

# Metrics copied from each eval_results.json into the long-format table.
metrics_of_interest = ('miou', 'precision', 'recall')

flattened_data = []

for exp_name in experiment_names:
    # Match folders on the configured run name, not on the experiment key, and
    # require the remainder to be a pure repeat index. This keeps "*_pre"
    # pre-training checkpoints and runs from other configurations (for example
    # a different class-weight suffix) out of the statistics.
    run_prefix = os.path.basename(experiments[exp_name]['run_name']) + '_'
    # Sorted so the row order of the saved CSVs is reproducible.
    matching_folders = sorted(
        d for d in os.listdir(experiments_root)
        if os.path.isdir(os.path.join(experiments_root, d))
        and d.startswith(run_prefix)
        and d[len(run_prefix):].isdigit()
    )

    for folder in matching_folders:
        results_path = os.path.join(experiments_root, folder, 'eval_results.json')
        if not os.path.isfile(results_path):
            # Run exists but has not been evaluated (yet); skip it.
            continue
        with open(results_path, 'r') as f:
            datasets = json.load(f)
        for dataset, metrics in datasets.items():
            for metric, value in metrics.items():
                if metric in metrics_of_interest:
                    flattened_data.append({
                        'Experiment': exp_name,  # Group under common experiment name
                        'Repeat': folder,
                        'Dataset': dataset,
                        'Metric': metric,
                        'Value': value
                    })

if not flattened_data:
    # Without this guard the groupby below fails with an opaque KeyError.
    raise FileNotFoundError(
        f"No eval_results.json with the expected metrics found under {experiments_root!r}"
    )

# Convert to DataFrame (one row per experiment / repeat / dataset / metric)
df = pd.DataFrame(flattened_data)

# Compute mean and SEM over repeats for each experiment
grouped_values = df.groupby(['Experiment', 'Dataset', 'Metric'])['Value']

mean_df = (
    grouped_values
    .mean()
    .reset_index()
    .rename(columns={'Value': 'Mean'})
)

sem_df = (
    grouped_values
    .sem()
    .reset_index()
    .rename(columns={'Value': 'Std_Error'})
)

# Merge summaries
summary_df = pd.merge(mean_df, sem_df, on=['Experiment', 'Dataset', 'Metric'])

# Save outputs
df.to_csv('individual_repeat_results.csv', index=False)
summary_df.to_csv('supervised_experiment_summary.csv', index=False)

print("Saved individual repeat results and summary statistics.")


In [None]:
# Keep only the mIoU rows of the summary table.
is_miou = summary_df['Metric'] == 'miou'
miou_df = summary_df[is_miou]

# Show one table per evaluation dataset, in order of first appearance.
for dataset in miou_df['Dataset'].unique():
    print(f"\n--- Dataset: {dataset} ---")
    per_dataset = miou_df[miou_df['Dataset'] == dataset]
    # The metric column is constant here, so drop it from the display.
    display(per_dataset.drop(columns=['Metric']))


In [None]:
from tabulate import tabulate

# Filter for the validation dataset and metrics of interest
filtered = summary_df[
    (summary_df['Dataset'] == 'valid') &
    (summary_df['Metric'].isin(['miou', 'precision', 'recall']))
]

# Pivot so that rows are experiments and columns are (statistic, metric) pairs
pivot_df = filtered.pivot(index='Experiment', columns='Metric', values=['Mean', 'Std_Error'])

# Desired order of metrics
metrics_order = ['miou', 'precision', 'recall']

# tabulate's 'latex_booktabs' format escapes LaTeX special characters in cell
# text, so a literal "\pm" would come out as "\textbackslash{}pm". The cells
# therefore carry a plain placeholder that is swapped for math-mode "$\pm$"
# once the table has been rendered.
pm_placeholder = "±"
rows = []

for experiment in pivot_df.index:
    row = [experiment]  # Start with experiment name
    for metric in metrics_order:
        mean = pivot_df.loc[experiment, ('Mean', metric)]
        sem = pivot_df.loc[experiment, ('Std_Error', metric)]
        row.append(f"{mean:.3f} {pm_placeholder} {sem:.3f}")
    rows.append(row)

# Column headers
headers = ['Model Type', 'IoU', 'Precision', 'Recall']

# Generate LaTeX table, then restore the plus-minus sign as valid LaTeX
latex_table = tabulate(rows, headers=headers, tablefmt='latex_booktabs')
latex_table = latex_table.replace(pm_placeholder, r"$\pm$")

# Full LaTeX table
latex_full = f"""
\\begin{{table}}
\\centering
\\small
\\setlength{{\\tabcolsep}}{{4pt}}
\\caption{{Blue whale call segmentation performance on the UP05 validation dataset. 
The mean performance and standard error over three independent training runs are reported for each metric.}}
\\label{{tab:results_whales}}
{latex_table}
\\end{{table}}
"""

print(latex_full)


In [None]:
from reverb.training.utils import load_best_model_from_run
# Load one trained model per experiment for the qualitative comparison below.
# The '_2' suffix selects the repeat with index 2, i.e. the last of the three
# repeats produced by the training loop.
models = {}
for exp_name, exp_config in experiments.items():
    selected_run = exp_config['run_name'] + '_2'
    model = load_best_model_from_run(
        run_name=selected_run,
        model_kwargs=exp_config['model_kwargs'],
        training_kwargs=exp_config['training_kwargs'],
    )
    models[exp_name] = model


In [None]:
import torch
import torch.nn.functional as F
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
valid_dataloader = eval_dataloaders['valid']

# The batch used for every qualitative plot below. Predictions are computed on
# exactly this batch so that sample i of `all_pred_masks[name]` always belongs
# to sample i of `valid_dataset_batch`. Re-iterating the dataloader and keeping
# whichever batch came last breaks that alignment as soon as the loader yields
# more than one batch or shuffles.
valid_dataset_batch = next(iter(valid_dataloader))
valid_images = valid_dataset_batch[0].to(device)

# Probability cut-off for the positive class (same value as used for evaluation).
threshold = 0.8

all_pred_masks = {}

for name, model in models.items():
    model.to(device)
    model.eval()
    with torch.no_grad():
        y_pred = model(valid_images)
        probs = F.softmax(y_pred, dim=1)  # softmax over the class axis

        # Get class 1 probability and apply threshold
        class1_prob = probs[:, 1, :, :]
        all_pred_masks[name] = (class1_prob > threshold).long().cpu().numpy()

In [None]:
# Boolean flag per sample: True where the ground-truth mask has any non-zero pixel.
true_masks = valid_dataset_batch[1]
non_zero_indices = true_masks.sum(dim=(1, 2)) > 0

# Report the positions of those samples within the batch.
positive_sample_ids = non_zero_indices.nonzero(as_tuple=True)[0].tolist()
print("Indices with non-zero true masks:", positive_sample_ids)


In [None]:
import os
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.pyplot as plt

# Use one (large) font size for every figure drawn from here on.
plt.rcParams['font.size'] = 36

def apply_crop(input, indices):
    """Select rows of a 2-D array, with `indices` counted from the bottom row.

    The array is flipped vertically, the rows listed in `indices` are taken,
    and the selection is flipped back again.
    """
    bottom_up = np.flipud(input)
    selected_rows = bottom_up[indices, :]
    return np.flipud(selected_rows)

# Substrings identifying experiment families (not referenced by the plotting code in this cell).
exp_patterns = ["supervised", "test_finetune"]

# Two-colour map for the binary masks: first colour for value 0, second for value 1.
colors = ["darkblue", "cyan"]
cmap_masks = plt.cm.colors.ListedColormap(colors)
def _trim_trailing(array, n_rows, n_cols):
    """Drop the last `n_rows` rows and `n_cols` columns of the two trailing axes.

    Unlike ``array[..., :-n_rows, :-n_cols]`` a trim amount of 0 keeps the full
    extent instead of producing an empty array.
    """
    row_stop = max(array.shape[-2] - n_rows, 0)
    col_stop = max(array.shape[-1] - n_cols, 0)
    return array[..., :row_stop, :col_stop]


def plot_valid_predictions(indices,                       # sample indices to plot
                           images,                        # shape (N, 1, H, W) or (N, H, W)
                           masks_true,                    # shape (N, H, W)
                           all_pred_masks,                # dict: name -> np.ndarray (N, H, W)
                           save_dir="results",
                           trim_rows=17,
                           vmin=0.25, vmax=0.75,
                           cmap_masks="tab20",
                           trim_cols=4):
    """
    Draw one figure per sample index: raw image | GT mask | one panel per model.

    Parameters
    ----------
    indices : iterable of int
        Positions of the samples to plot within `images` / `masks_true` /
        each array of `all_pred_masks`.
    images : np.ndarray
        (N, 1, H, W) or (N, H, W); a singleton channel axis is squeezed away.
    masks_true : np.ndarray
        Ground-truth masks, (N, H, W).
    all_pred_masks : dict
        Model name -> predicted masks, (N, H, W). Panel order follows dict order.
    save_dir : str
        Output directory; created if missing.
    trim_rows, trim_cols : int
        Number of trailing rows / columns removed from every array before
        cropping. 0 keeps the full extent.
    vmin, vmax : float
        Colour limits of the raw-image panel.
    cmap_masks : str or Colormap
        Colour map used for the mask panels.

    Each figure is saved as <save_dir>/whale_predictions_<idx>.png and shown.
    """
    os.makedirs(save_dir, exist_ok=True)

    # Ensure images are 3-D (N, H, W) for imshow.
    if images.ndim == 4:          # (N, C, H, W)
        images = images.squeeze(1)  # assume single-channel
    images = _trim_trailing(images, trim_rows, trim_cols)
    masks_true = _trim_trailing(masks_true, trim_rows, trim_cols)

    model_names = list(all_pred_masks.keys())
    n_cols = 2 + len(model_names)

    # Row selection for the 15-19 band of the frequency axis.
    # NOTE(review): assumes the trimmed arrays have 399 rows spanning 1.01-50.0
    # (presumably Hz) -- confirm against the spectrogram settings.
    spec_frequencies = np.linspace(1.01, 50.0, 399)
    freq_mask = (spec_frequencies >= 15) & (spec_frequencies <= 19)
    freq_indices = np.where(freq_mask)[0]
    cropped_frequencies = spec_frequencies[freq_indices][::-1]

    # Y-tick positions are the cropped rows closest to the labelled frequencies;
    # they are identical for every figure, so compute them once.
    major_tick_freqs = [16, 18]
    major_tick_positions = [
        np.abs(cropped_frequencies - freq).argmin() for freq in major_tick_freqs
    ]

    for idx in indices:
        fig, axs = plt.subplots(1, n_cols,
                                figsize=(3 * n_cols, 2),
                                constrained_layout=True)

        # --- column 0: raw spectrogram ---------------------------------------------------------
        image = apply_crop(images[idx], freq_indices)
        axs[0].imshow(image, aspect="auto",
                      cmap="jet", vmin=vmin, vmax=vmax)

        # --- column 1: ground-truth mask -------------------------------------------------------
        mask_true = apply_crop(masks_true[idx], freq_indices)
        axs[1].imshow(mask_true, aspect="auto",
                      cmap=cmap_masks)

        # --- remaining columns: one per model --------------------------------------------------
        for j, name in enumerate(model_names, start=2):
            pred_mask = _trim_trailing(all_pred_masks[name][idx], trim_rows, trim_cols)
            pred_mask = apply_crop(pred_mask, freq_indices)
            axs[j].imshow(pred_mask,
                          aspect="auto", cmap=cmap_masks)

        # cosmetic: the first column keeps labelled y-ticks for context ...
        first_ax = axs[0]
        first_ax.set_yticks(major_tick_positions)
        first_ax.set_yticklabels(major_tick_freqs)
        # Minor ticks are switched on before the x-ticks are cleared, as in the
        # original ordering.
        first_ax.minorticks_on()
        first_ax.set_xticks([])
        first_ax.set_xticklabels([])

        # ... while mask / prediction panels show no ticks or spines at all.
        for ax in axs[1:]:
            ax.set_xticks([]); ax.set_xticklabels([])
            ax.set_yticks([]); ax.set_yticklabels([])
            ax.axis("off")

        # save & show ---------------------------------------------------------------------------
        out_path = os.path.join(save_dir, f"whale_predictions_{idx}.png")
        # Save through the figure handle rather than the implicit current figure.
        fig.savefig(out_path, dpi=300, transparent=True)
        plt.show()
        plt.close(fig)

# Convert the reference validation batch (images, masks) to NumPy for plotting.
images_cpu, masks_cpu = (
    valid_dataset_batch[k].cpu().numpy() for k in (0, 1)
)

# Imported for its side effect: presumably registers the 'science' matplotlib style.
import scienceplots

# Samples to visualise, as positions within the validation batch.
example_indices = [4, 8, 25, 36, 101]

with plt.style.context(['science']):
    plot_valid_predictions(
        indices=example_indices,
        images=images_cpu,
        masks_true=masks_cpu,
        all_pred_masks=all_pred_masks,
        cmap_masks=cmap_masks,
    )
