In [4]:
import wandb
import numpy as np

# Authenticate with Weights & Biases before using the public API below.
wandb.login()

# --- Which runs to process ---------------------------------------------------
entity = "nlp_ls"
projects = [
    "dspro2-predicting-country",
    "dspro2-predicting-region",
    "dspro2-predicting-coordinates",
]
# Every entry is either an explicit run id (e.g. "w1098m89") or a wildcard:
#   "*"    -> all runs of each project
#   "*:N"  -> only the first N runs returned by the API for each project
run_ids = ["*:10"]

# --- Optional: write one fixed key/value pair into each selected run's summary
# Both must be non-None for this step to run (example: "test_data_run_id" / "w1098m89").
summary_key, summary_value = None, None

# --- Optional: store "Best <metric>" entries in each run's summary ------------
# For every summary key starting with "Validation" (accuracy, distance, loss, ...),
# the best value from the run history is written back into the run *summary*.
push_best = True
# False (recommended): only fill in "Best ..." entries that are not in the summary yet.
# True: recompute and overwrite them even if they already exist.
force_recalculate = False

# Access the W&B public API and process every selected run of every project.
api = wandb.Api()
# Explicit (non-wildcard) run ids that were found in at least one project.
all_non_special_used_run_ids = []


def _select_project_run_ids(requested_ids, available_ids, project):
    """Resolve the requested run ids against the runs that exist in one project.

    Args:
        requested_ids: entries of ``run_ids``; explicit run ids, or the wildcards
            "*" (all runs) and "*:N" (the first N runs in API order).
        available_ids: run ids of ``project`` in the order returned by ``api.runs``.
        project: project name, only used for progress messages.

    Returns:
        ``(selected, found_explicit)``: ``selected`` is the de-duplicated list of
        run ids to process (first occurrence wins, so explicit ids listed before a
        wildcard are no longer dropped); ``found_explicit`` lists the explicit ids
        that exist in this project.
    """
    available_set = set(available_ids)
    selected = []
    found_explicit = []
    for requested in requested_ids:
        if requested.startswith("*"):
            # Re-initialised per wildcard entry so a limit never leaks from an
            # earlier entry or project (and a bare "*" no longer hits an unset name).
            limit = None
            if len(requested) > 2 and requested[1] == ":":
                limit = int(requested.split(":")[1])
            print(f"Getting run id's for project {project}")
            # A missing (or zero) limit means "all runs".
            wildcard_ids = available_ids[:limit] if limit else available_ids
            selected.extend(wildcard_ids)
            print(f"Found {len(wildcard_ids)} runs")
        elif requested in available_set:
            selected.append(requested)
            found_explicit.append(requested)
    # dict.fromkeys drops duplicates (e.g. an explicit id also matched by a wildcard)
    # while preserving insertion order.
    return list(dict.fromkeys(selected)), found_explicit


def _best_value(values, metric):
    """Return the best non-NaN value of a metric's history, or NaN if there is none.

    Metrics whose name contains "accuracy" or "correct" (case-insensitive) are
    maximised; every other metric (loss, distance, ...) is minimised.
    """
    # dtype=float also turns None entries into NaN, so they are filtered out below.
    as_float = np.asarray(values, dtype=float)
    as_float = as_float[~np.isnan(as_float)]
    if as_float.size == 0:
        return np.nan
    name = metric.lower()
    return np.max(as_float) if "accuracy" in name or "correct" in name else np.min(as_float)


def _push_best_validation_metrics(run, force):
    """Write "Best <metric>" summary entries for every "Validation*" summary key of ``run``.

    Entries that already exist are kept unless ``force`` is true. Returns the dict
    of entries that were written (empty if nothing needed updating).
    """
    summary_keys = run.summary.keys()
    validation_metrics = [k for k in summary_keys if k.lower().startswith("validation")]
    pending = [m for m in validation_metrics if force or f"Best {m}" not in summary_keys]
    if not pending:
        return {}
    # Fetch the history once per run instead of once per metric.
    # NOTE(review): run.history() returns a sampled history by default (presumably
    # a pandas DataFrame), so the "best" value is an approximation of the true best
    # epoch -- consistent with the approximation already accepted for this script.
    history = run.history()
    best = {}
    for metric in pending:
        if metric not in history:
            print(f"Metric {metric} not found in history of run {run.id}, skipping")
            continue
        best_key = f"Best {metric}"
        best[best_key] = _best_value(history[metric], metric)
        run.summary[best_key] = best[best_key]
    if best:
        run.summary.update()
    return best


for project in projects:
    all_project_runs = [run.id for run in api.runs(f"{entity}/{project}")]
    project_run_ids, found_explicit = _select_project_run_ids(run_ids, all_project_runs, project)
    all_non_special_used_run_ids.extend(found_explicit)

    for run_id in project_run_ids:
        print(f"Checking run {run_id}")
        needs_summary_update = summary_key is not None and summary_value is not None
        if not (needs_summary_update or push_best):
            continue
        # Fetch the run once; both optional steps below work on the same object.
        run = api.run(f"{entity}/{project}/{run_id}")

        if needs_summary_update:
            run.summary[summary_key] = summary_value
            run.summary.update()
            print(f"Successfully updated summary of run {run_id} with {summary_key}: {summary_value}")

        if not push_best:
            continue
        # Runs still in progress have an incomplete history; handle them on a later pass.
        if run.state == "running":
            print(f"Run {run_id} is still running, skipping")
            continue
        best_validation_metrics = _push_best_validation_metrics(run, force_recalculate)
        if best_validation_metrics:
            print(f"Successfully updated summary of run {run_id} with {best_validation_metrics}")
        else:
            print(f"Skipped updating summary of run {run_id}")

# Report explicitly requested (non-wildcard) run ids that were found in none of the projects.
non_special_run_ids = {run_id for run_id in run_ids if not run_id.startswith("*")}
all_non_special_used_run_ids = set(all_non_special_used_run_ids)
missing_run_ids = non_special_run_ids - all_non_special_used_run_ids
# Print whenever something is missing -- including the case where none of the
# requested ids was found at all -- and stay silent when everything was found.
if missing_run_ids:
    print(f"Run ids that were not found: {missing_run_ids}")

Getting run id's for project dspro2-predicting-country
Found 10 runs
Checking run elpusqol
Run elpusqol is still running, skipping
Checking run kth52fnv
Successfully updated summary of run kth52fnv with {'Best Validation Accuracy Top 3': 0.8207401174932765, 'Best Validation Accuracy Top 5': 0.8790510389590876, 'Best Validation Loss': 1.4306674492495821, 'Best Validation Accuracy Top 1': 0.6403684060279159}
Checking run 8pqe6jmh
Successfully updated summary of run 8pqe6jmh with {'Best Validation Accuracy Top 3': 0.8052646603663026, 'Best Validation Accuracy Top 1': 0.6072839821506378, 'Best Validation Loss': 1.4597630012418126, 'Best Validation Accuracy Top 5': 0.8728608561082981}
Checking run ewnspxxj
Successfully updated summary of run ewnspxxj with {'Best Validation Accuracy Top 1': 0.5754616343885692, 'Best Validation Accuracy Top 3': 0.7823369442733297, 'Best Validation Loss': 1.5119558592918878, 'Best Validation Accuracy Top 5': 0.8547710984569618}
Checking run bc69qzqh
Successful