In [None]:
import json
import os
from typing import Tuple, Union

import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import pandas as pd
import wandb

In [None]:
def wandb2pd(exp_runs):
    """Flatten a collection of W&B runs into one DataFrame with one row per run.

    The resulting columns are, in order: ``name``, every config entry whose
    key does not start with an underscore, and every summary entry.

    Args:
        exp_runs: Iterable of run objects exposing ``name``, ``config`` (a
            mapping) and ``summary._json_dict`` (a mapping).

    Returns:
        pandas.DataFrame: name, config and summary columns concatenated
        side by side (``axis=1``).
    """
    summaries = []
    configs = []
    names = []
    for exp in exp_runs:
        # NOTE(review): `_json_dict` is an underscore-prefixed attribute of the
        # summary object; confirm it is still available when upgrading wandb.
        summaries.append(exp.summary._json_dict)
        # Keys with a leading underscore are dropped from the config columns.
        configs.append({key: value for key, value in exp.config.items() if not key.startswith('_')})
        names.append(exp.name)

    # Build each frame once from the collected records, then join column-wise.
    name_df = pd.DataFrame({'name': names})
    config_df = pd.DataFrame.from_records(configs)
    summary_df = pd.DataFrame.from_records(summaries)
    return pd.concat([name_df, config_df, summary_df], axis=1)


In [None]:
def get_gpu_count(exp_runs):
    """Return the ``gpu_count`` recorded in the first readable run metadata file.

    Runs are tried in order. For each one, ``wandb-metadata.json`` is
    downloaded into the current working directory, parsed, and deleted again.
    The first run whose metadata contains ``gpu_count`` determines the result.

    Args:
        exp_runs: Iterable of run objects exposing ``id`` and
            ``file(name).download()``.

    Returns:
        The ``gpu_count`` value of the first run whose metadata could be read,
        or ``None`` when no run yields one (including an empty ``exp_runs``).
    """
    metadata_name = "wandb-metadata.json"

    for exp in exp_runs:
        try:
            try:
                # `replace=True` lets a stale copy left by an earlier failure be
                # overwritten; the `with` block closes the returned file handle.
                with exp.file(metadata_name).download(replace=True) as handle:
                    metadata = json.load(handle)
            finally:
                # Clean up on failure too, so a broken download/parse does not
                # leave a file behind for the next run to trip over.
                if os.path.exists(metadata_name):
                    os.remove(metadata_name)
            return metadata["gpu_count"]
        except Exception:
            # Best effort: report the run and move on to the next one.
            # KeyboardInterrupt/SystemExit are deliberately not swallowed.
            print(f"{exp.id}:failed to fetch data")
            continue

    return None
            

In [None]:
def describe_exp_runs(path:str, is_filter:bool=True) -> Tuple[int, float]:
    """Print and return the run count and GPU-hours for one W&B project.

    GPU-hours are computed as ``sum(_runtime) / 3600 * gpu_count``, where
    ``gpu_count`` comes from the first run whose metadata file can be read.
    NOTE(review): this applies a single GPU count to every run in the
    project; confirm that all runs of a project used the same hardware.

    Args:
        path: Project path passed to ``wandb.Api().runs``.
        is_filter: When True, only runs whose state is ``"finished"`` are
            queried; when False, the query is unfiltered.

    Returns:
        ``(model_num, calc_time)``: the number of runs and the calculation
        time in hours. A project without runs yields ``(0, 0.0)``.

    Raises:
        RuntimeError: If the project has runs but none of them provides a
            GPU count.
    """
    print(path)
    api = wandb.Api()

    # Only pass `filters` when filtering is requested, so the unfiltered call
    # is exactly `api.runs(path=path)`.
    query = {"path": path}
    if is_filter:
        query["filters"] = {"state": "finished"}
    exp_runs = api.runs(**query)

    df = wandb2pd(exp_runs)
    gpu_count = get_gpu_count(exp_runs)

    model_num = len(df)
    print(f"Num of Models: {model_num}")

    if model_num == 0:
        # No runs means no `_runtime` column and no metadata to read.
        calc_time = 0.0
    elif gpu_count is None:
        raise RuntimeError(
            f"{path}: could not determine gpu_count from any run's wandb-metadata.json"
        )
    else:
        # `_runtime` is divided by 3600 to express the total in hours.
        calc_time = (df["_runtime"].sum() / 3600) * gpu_count
    print(f"Calculation time: {calc_time} hour")

    print()
    return (model_num, calc_time)


In [None]:
# Running totals over every project summarised below. The aggregation cells
# add to these with `+=`, so re-run this cell first when re-running them.
all_model_num, all_calc_time = 0, 0

### BGC

In [None]:
# NOTE(review): "entity_name/project_adam" is listed twice, so it is queried
# and accumulated twice — confirm this is intended.
path_list = [
    "entity_name/project_momentum",
    "entity_name/project_adam",
    "entity_name/project_momentum_sgd",
    "entity_name/project_adam"
]
for path in path_list:
    if path not in ("entity_name/project_momentum", "entity_name/project_adam"):
        # Regular projects: finished runs only, counted towards both totals.
        model_num, calc_time = describe_exp_runs(path)
        all_model_num = all_model_num + model_num
    else:
        # These projects are queried without the state filter and their time
        # is scaled by 8.
        # NOTE(review): the reason for the factor 8 is not visible here, and
        # model_num is NOT added to all_model_num in this branch — confirm both.
        model_num, calc_time = describe_exp_runs(path, is_filter=False)
        calc_time = calc_time * 8

    all_calc_time = all_calc_time + calc_time

### DomainBed

In [None]:
dataset_list = [
    "ColoredMNIST",
    "PACS",
    "VLCS",
    "OfficeHome",
    "TerraIncognita",
    "DomainNet",
    "RotatedMNIST"
]
algorithm_list = [
    "ERM",
    "IRM"
]
# Visit every (dataset, algorithm) pair, datasets in the outer position.
# Project names here follow "<algorithm>_<dataset>_<optimizer>".
for dataset, algorithm in [(d, a) for d in dataset_list for a in algorithm_list]:
    path_list = [
        f"entity_name/{algorithm}_{dataset}_momentum_sgd",
        f"entity_name/{algorithm}_{dataset}_adam"
    ]

    for path in path_list:
        model_num, calc_time = describe_exp_runs(path)
        # Fold this project's numbers into the running totals.
        all_model_num = all_model_num + model_num
        all_calc_time = all_calc_time + calc_time

### WILDS

In [None]:
dataset_list = [
    "WILDS_civilcomments",
    "WILDS_Amazon"
]
algorithm_list = [
    "ERM",
    "IRM"
]
# Visit every (dataset, algorithm) pair, datasets in the outer position.
# Project names here follow "<dataset>_<algorithm>_<optimizer>".
for dataset, algorithm in [(d, a) for d in dataset_list for a in algorithm_list]:
    path_list = [
        f"entity_name/{dataset}_{algorithm}_momentum_sgd",
        f"entity_name/{dataset}_{algorithm}_adam"
    ]
    for path in path_list:
        model_num, calc_time = describe_exp_runs(path)
        # Fold this project's numbers into the running totals.
        all_model_num = all_model_num + model_num
        all_calc_time = all_calc_time + calc_time

In [None]:
# Report the grand totals accumulated by the aggregation cells.
print(
    f"all model num: {all_model_num}",
    f"all calculation time: {all_calc_time} hour",
    sep="\n",
)