In [215]:
import os
import time
from typing import List, Optional

import pandas as pd
import rich
import wandb
from rich import print
from rich.console import Console
from rich.table import Table
from tqdm.auto import tqdm

def display_pretty(df):
    """Render ``df`` in the notebook without pandas truncating rows or columns.

    The display limits are lifted only for the duration of the call; the
    previous pandas options are restored when the context manager exits.
    """
    # Imported lazily so that this helper is the only part needing IPython.
    from IPython.display import display

    unlimited_view = ("display.max_rows", None, "display.max_columns", None)
    with pd.option_context(*unlimited_view):
        display(df)



def login(project: Optional[str] = None, token: Optional[str] = None):
    """Authenticate with Weights & Biases.

    Arguments left as ``None`` fall back to the ``WANDB_PROJECT`` and
    ``WANDB_API_KEY`` environment variables. Only the API key is forwarded to
    ``wandb.login``; the resolved project is not used by this function.
    """
    project = os.getenv("WANDB_PROJECT") if project is None else project
    api_key = os.getenv("WANDB_API_KEY") if token is None else token
    wandb.login(key=api_key)


def fetch_runs(
    project_list,
    exp_name_list: Optional[List[str]] = None,
    exp_term_list: Optional[List[str]] = None,
    total_runs_limit: Optional[int] = None,
):
    """Collect W&B run summaries and states, grouped by lower-cased ``exp_name``.

    Args:
        project_list: A single project path, or an iterable of project paths,
            each passed to ``wandb.Api().runs(path=...)``. Every listed
            project is scanned.
        exp_name_list: If given, keep only runs whose experiment name equals
            one of these names (compared case-insensitively).
        exp_term_list: If given, keep only runs whose experiment name contains
            at least one of these substrings (compared case-insensitively).
        total_runs_limit: If given, stop once this many runs have been kept.

    Returns:
        ``(exp_name_to_summary_dict, exp_name_to_state_dict)``. Both dicts are
        keyed by experiment name and sorted by key; the values are lists of
        ``run.summaryMetrics`` and ``run.state`` respectively, with one entry
        per kept run in matching order. Runs whose config has no ``exp_name``
        are skipped.
    """
    api = wandb.Api()

    # A bare string is a single project; wrap it so both input forms share
    # one code path and every project (not just the first) is visited.
    project_paths = (
        [project_list] if isinstance(project_list, str) else list(project_list)
    )

    # Experiment names are lower-cased below, so normalise the filters the
    # same way; otherwise a filter containing upper-case could never match.
    allowed_names = (
        None if exp_name_list is None else {name.lower() for name in exp_name_list}
    )
    required_terms = (
        None if exp_term_list is None else [term.lower() for term in exp_term_list]
    )

    exp_name_to_summary_dict = {}
    exp_name_to_state_dict = {}
    num_kept = 0

    for project_path in project_paths:
        if total_runs_limit is not None and num_kept >= total_runs_limit:
            break
        for run in tqdm(api.runs(path=project_path), desc=str(project_path)):
            # Check the limit before keeping a run so that exactly
            # `total_runs_limit` runs are collected (0 keeps none).
            if total_runs_limit is not None and num_kept >= total_runs_limit:
                break

            if "exp_name" not in run.config:
                continue
            exp_name = run.config["exp_name"].lower()

            if allowed_names is not None and exp_name not in allowed_names:
                continue
            if required_terms is not None and not any(
                term in exp_name for term in required_terms
            ):
                continue

            exp_name_to_summary_dict.setdefault(exp_name, []).append(
                run.summaryMetrics
            )
            exp_name_to_state_dict.setdefault(exp_name, []).append(run.state)
            num_kept += 1

    # Sort both mappings by experiment name for stable downstream output.
    exp_name_to_summary_dict = dict(sorted(exp_name_to_summary_dict.items()))
    exp_name_to_state_dict = dict(sorted(exp_name_to_state_dict.items()))
    return exp_name_to_summary_dict, exp_name_to_state_dict


def fetch_run_status(exp_name_to_summary_dict, exp_name_to_state_dict):
    """Summarise the progress of every experiment.

    Args:
        exp_name_to_summary_dict: Mapping of experiment name to a list of
            per-run summary dicts.
        exp_name_to_state_dict: Mapping of experiment name to a list of
            per-run state strings.

    Returns:
        A dict keyed by lower-cased experiment name. Each value holds:

        * ``testing_completed`` - any summary key contains ``"testing/ensemble"``.
        * ``currently_running`` - any run state equals ``"running"``
          (case-insensitive).
        * ``current_iter`` - largest ``global_step`` across the runs, or 0
          when none is recorded.
        * ``model_compiled`` - any summary key contains
          ``"model/num_parameters"``.
        * ``command`` - the experiment's list of summary dicts, unchanged.
    """
    # Look everything up through lower-cased keys so experiments whose names
    # contain upper-case letters are still found in both inputs.
    summaries_by_name = {
        name.lower(): summaries
        for name, summaries in exp_name_to_summary_dict.items()
    }
    states_by_name = {
        name.lower(): states for name, states in exp_name_to_state_dict.items()
    }

    exp_data = {}
    for exp_name in tqdm(list(summaries_by_name), desc="Checking experiments"):
        summaries = summaries_by_name[exp_name]
        summary_keys = {key for summary in summaries for key in summary.keys()}

        testing_completed = any("testing/ensemble" in key for key in summary_keys)
        model_compiled = any("model/num_parameters" in key for key in summary_keys)

        # An experiment without recorded states is reported as not running.
        currently_running = any(
            state.lower() == "running" for state in states_by_name.get(exp_name, [])
        )

        # Skip runs with a missing or null step so max() only sees real values.
        recorded_steps = [
            summary["global_step"]
            for summary in summaries
            if summary.get("global_step") is not None
        ]
        current_iter = max(recorded_steps, default=0)

        exp_data[exp_name] = {
            "testing_completed": testing_completed,
            "currently_running": currently_running,
            "current_iter": current_iter,
            "model_compiled": model_compiled,
            "command": summaries,
        }
    return exp_data


def pretty_print_runs(exp_data):
    """Print a rich table with one row per experiment in ``exp_data``.

    ``exp_data`` maps an experiment name to a dict from which the keys
    ``currently_running``, ``testing_completed``, ``current_iter`` and
    ``model_compiled`` are read. Rows are numbered from 0 in iteration order.
    """
    # (header, add_column keyword arguments) in display order.
    column_specs = (
        ("idx", {"justify": "right"}),
        ("Experiment Name", {"width": 50}),
        ("Currently Running", {"justify": "right"}),
        ("Testing Completed", {"justify": "right"}),
        ("Current Iteration", {"justify": "right"}),
        ("Model Compiled", {"justify": "right"}),
    )
    # Status fields, listed in the same order as the last four columns.
    status_fields = (
        "currently_running",
        "testing_completed",
        "current_iter",
        "model_compiled",
    )

    table = Table(show_header=True, header_style="bold magenta", style="dim")
    for header, options in column_specs:
        table.add_column(header, **options)

    # The outer dict keys become DataFrame columns, so transpose to get one
    # row per experiment.
    status_frame = pd.DataFrame(exp_data).T
    for position, (exp_name, status) in enumerate(status_frame.iterrows()):
        cells = [str(status[field]) for field in status_fields]
        table.add_row(str(position), exp_name, *cells)

    Console().print(table)



In [216]:
import json

# W&B project path to scan. `token` is defined here but not used in this cell.
project = "machinelearningbrewery/gate-0-9-1"
token = None

# Keep only experiments whose name contains one of these terms.
exp_name_to_summary_dict, exp_name_to_state_dict = fetch_runs(
    project_list=[project],
    exp_name_list=None,
    exp_term_list=["9032024", "10032024", "31032024", "30032024"],
    total_runs_limit=None,
)
exp_data = fetch_run_status(exp_name_to_summary_dict, exp_name_to_state_dict)

pretty_print_runs(exp_data)

# Store each dict as JSON in the working directory so later cells can reload
# them without querying W&B again.
for file_name, payload in (
    ("exp_data.json", exp_data),
    ("exp_name_to_summary_dict.json", exp_name_to_summary_dict),
    ("exp_name_to_state_dict.json", exp_name_to_state_dict),
):
    with open(file_name, "w") as f:
        json.dump(payload, f)

  0%|          | 0/4811 [00:00<?, ?it/s]

In [None]:
import json
import pandas as pd

# Model identifiers that are searched for inside experiment names.
model_names = ["ar-vit-b16", "effformer-s0", "bart", "bert", "clip-b16", "whisper", "svit-b16", "siglip-p16", "rnx50-32x4a1", "mpnet", "laion-b16", "flex-b-1200ep", "effv2-rw-s", "dino-b16", "deit3-b16", "convnextv2-base"]

# De-duplicate while keeping the listed order. A set's iteration order for
# strings can change from one interpreter run to the next, which would make
# any "first matching model name" lookup irreproducible.
model_names = list(dict.fromkeys(model_names))

# Reload the three JSON files from the working directory.
with open("exp_data.json", "r") as f:
    exp_data = json.load(f)

with open("exp_name_to_summary_dict.json", "r") as f:
    exp_name_to_summary_dict = json.load(f)

with open("exp_name_to_state_dict.json", "r") as f:
    exp_name_to_state_dict = json.load(f)

In [None]:
# Name fragments that disqualify a summary key from the results table.
excluded_fragments = ("std", "episode", "background", "complete", "similarities", "logits")

# Every "testing" key of the 3-member ensemble seen in any run, de-duplicated
# and sorted.
keys = sorted(
    {
        k
        for metrics in exp_name_to_summary_dict.values()
        for item in metrics
        for k in item.keys()
        if "testing" in k
        and "ensemble_3" in k
        and not any(fragment in k for fragment in excluded_fragments)
    }
)

results = {}

for model_name, summaries in exp_name_to_summary_dict.items():
    name_parts = model_name.split("-")
    row = results[model_name] = {}
    for key in keys:
        for item in summaries:
            if key not in item:
                continue
            # The dataset is the second "-"-separated part of the experiment
            # name; names containing "math" use the second and third parts.
            if "math" not in model_name:
                dataset_name = name_parts[1]
            else:
                dataset_name = name_parts[1] + "-" + name_parts[2]
            # When several runs report the same key, the last one wins.
            row[f"{dataset_name}.{key}"] = item[key]

# Drop entries whose final value is None.
results = {
    model_name: {key: value for key, value in row.items() if value is not None}
    for model_name, row in results.items()
}

# One row per experiment, one column per "<dataset>.<metric key>".
original_df = pd.DataFrame(results).T
df = original_df.copy()


In [None]:
# Start from the untouched results table so this cell can be re-run safely:
# the steps below replace the index, and running them on an already
# processed `df` would fail.
df = original_df.copy()

# Keep the experiment name both as the index label and as a regular column.
df.index.name = "exp_name"
df["exp_name"] = df.index

# Experiment names are "-"-separated; the first part is taken as the
# experiment series and the second as the dataset name.
name_parts = df.index.str.split("-")
df["experiment_series"] = name_parts.str[0]
df["dataset_name"] = name_parts.str[1]

# Tag each experiment with the first entry of `model_names` that occurs in
# its name (None when nothing matches).
df["model_name"] = df["exp_name"].apply(
    lambda x: next((model_name for model_name in model_names if model_name in x), None)
)

df = df.set_index(["dataset_name", "model_name", "experiment_series"], append=True)

# Collapse to one row per model, taking the first non-null value of every
# column across that model's experiments. Only `model_name` is restored as a
# column; the other index levels are discarded by the group-by.
df = df.groupby("model_name").first().reset_index()

# Drop rows in which every value is missing.
df = df.dropna(axis=0, how="all")

# Drop the two winoground top-5 accuracy columns if they are present.
df = df.drop(
    columns=[
        "winoground.testing/ensemble_3-/text_to_image_accuracy_top_5-epoch-mean",
        "winoground.testing/ensemble_3-/image_to_text_accuracy_top_5-epoch-mean",
    ],
    errors="ignore",
)


# remove dataset_name column
# df = df.drop(columns=['dataset_name'])
# unify the rows by model_name


# display_pretty(df)
# pretty_display first 20 columns

In [None]:
# Discard columns that hold no values at all.
df = df.dropna(axis=1, how="all")

# Keep a column only if it has fewer than 14 missing entries and its name
# contains none of these fragments.
unwanted_fragments = ("logits", "ensemble_3--", "global_step")
missing_per_column = df.isnull().sum()
cleaned_df = df[
    [
        col
        for col in df.columns
        if missing_per_column[col] < 14
        and not any(fragment in col for fragment in unwanted_fragments)
    ]
]

# Drop columns that carry no information: those with at most one distinct
# non-null value.
distinct_per_column = cleaned_df.nunique()
cleaned_df = cleaned_df[
    [col for col in cleaned_df.columns if distinct_per_column[col] > 1]
]

# Shorten column names by removing the ensemble prefix and the epoch-mean suffix.
cleaned_df = cleaned_df.rename(
    columns=lambda name: name.replace("testing/ensemble_3-/", "").replace("-epoch-mean", "")
)

# Report how many columns survived.
print(f"Number of columns: {len(cleaned_df.columns)}")

# Missing versus present values, counted per column and then summed.
nans = cleaned_df.isna().sum()
non_nans = cleaned_df.count()
print(f"Number of NaNs: {nans.sum()}")
print(f"Number of non-NaNs: {non_nans.sum()}")

# Missing values counted per row.
nans_per_row = cleaned_df.isna().sum(axis=1)
print(f"Number of NaNs per row: {nans_per_row}")

In [None]:
# Replace the NaNs of each column with draws from N(mu, sigma), where mu and
# sigma are the mean and standard deviation of that column's observed values.
import numpy as np

# Fixed seed so that re-running the notebook reproduces the same imputed values.
IMPUTATION_SEED = 0
imputation_rng = np.random.default_rng(IMPUTATION_SEED)

synthetic_df = cleaned_df.copy()
# The first and last columns are skipped; only the columns in between are
# converted to float and imputed.
for col in synthetic_df.columns[1:-1]:
    values = synthetic_df[col].astype(float)
    missing = values.isna()
    if missing.any():
        mu = values.mean()
        sigma = values.std()
        # The standard deviation is undefined for a single observation; fall
        # back to the mean itself instead of propagating NaN.
        if pd.isna(sigma):
            sigma = 0.0
        values.loc[missing] = imputation_rng.normal(mu, sigma, size=int(missing.sum()))
    synthetic_df[col] = values

In [None]:
# Render the imputed frame in full: display_pretty lifts pandas' row and
# column display limits before calling IPython's display().
display_pretty(synthetic_df)

Unnamed: 0,model_name,acdc.dice_loss,acdc.loss,acdc.mIoU,acdc.mean_accuracy,acdc.overall_accuracy,ade20k.ce_loss,ade20k.dice_loss,ade20k.focal_loss,ade20k.loss,ade20k.mIoU,ade20k.mean_accuracy,ade20k.overall_accuracy,aircraft.accuracy_top_1,aircraft.loss,chexpert.testing/ensemble_3/0-aps,chexpert.testing/ensemble_3/0-auc,chexpert.testing/ensemble_3/0-bs,chexpert.testing/ensemble_3/1-aps,chexpert.testing/ensemble_3/1-auc,chexpert.testing/ensemble_3/1-bs,chexpert.testing/ensemble_3/2-aps,chexpert.testing/ensemble_3/2-auc,chexpert.testing/ensemble_3/2-bs,chexpert.testing/ensemble_3/3-aps,chexpert.testing/ensemble_3/3-auc,chexpert.testing/ensemble_3/3-bs,chexpert.testing/ensemble_3/4-aps,chexpert.testing/ensemble_3/4-auc,chexpert.testing/ensemble_3/4-bs,chexpert.testing/ensemble_3/aps-macro,chexpert.testing/ensemble_3/auc-macro,chexpert.testing/ensemble_3/bs-macro,chexpert.testing/ensemble_3/loss,cifar100.accuracy_top_1,cifar100.accuracy_top_5,cifar100.loss,cityscapes.ce_loss,cityscapes.dice_loss,cityscapes.focal_loss,cityscapes.loss,cityscapes.mIoU,cityscapes.mean_accuracy,cityscapes.overall_accuracy,clevr.accuracy_top_1,clevr.accuracy_top_1_colour,clevr.accuracy_top_1_count,clevr.accuracy_top_1_material,clevr.accuracy_top_1_shape,clevr.accuracy_top_1_size,clevr.accuracy_top_1_yes_no,clevr.loss,clevr.loss_colour,clevr.loss_count,clevr.loss_material,clevr.loss_shape,clevr.loss_size,clevr.loss_yes_no,clevr-math.accuracy_top_1,clevr-math.accuracy_top_5,clevr-math.loss,coco.ce_loss,coco.dice_loss,coco.focal_loss,coco.loss,coco.mIoU,coco.mean_accuracy,coco.overall_accuracy,cubirds.accuracy_top_1,cubirds.loss,diabetic.testing/ensemble_3/0-aps,diabetic.testing/ensemble_3/0-auc,diabetic.testing/ensemble_3/0-bs,diabetic.testing/ensemble_3/1-aps,diabetic.testing/ensemble_3/1-auc,diabetic.testing/ensemble_3/1-bs,diabetic.testing/ensemble_3/2-aps,diabetic.testing/ensemble_3/2-auc,diabetic.testing/ensemble_3/2-bs,diabetic.testing/ensemble_3/3-aps,diabetic.testing/ensemble_3/3-auc,d
iabetic.testing/ensemble_3/3-bs,diabetic.testing/ensemble_3/4-aps,diabetic.testing/ensemble_3/4-auc,diabetic.testing/ensemble_3/4-bs,diabetic.testing/ensemble_3/aps-macro,diabetic.testing/ensemble_3/auc-macro,diabetic.testing/ensemble_3/bs-macro,diabetic.testing/ensemble_3/loss,dtextures.accuracy_top_1,dtextures.loss,flickr30k.image_to_text_accuracy,flickr30k.image_to_text_accuracy_top_5,flickr30k.image_to_text_loss,flickr30k.loss,flickr30k.text_to_image_accuracy,flickr30k.text_to_image_accuracy_top_5,flickr30k.text_to_image_loss,food101.accuracy_top_1,food101.accuracy_top_5,food101.loss,fungi.accuracy_top_1,fungi.loss,ham10k.testing/ensemble_3/0-aps,ham10k.testing/ensemble_3/0-auc,ham10k.testing/ensemble_3/0-bs,ham10k.testing/ensemble_3/1-aps,ham10k.testing/ensemble_3/1-auc,ham10k.testing/ensemble_3/1-bs,ham10k.testing/ensemble_3/2-aps,ham10k.testing/ensemble_3/2-auc,ham10k.testing/ensemble_3/2-bs,ham10k.testing/ensemble_3/3-aps,ham10k.testing/ensemble_3/3-auc,ham10k.testing/ensemble_3/3-bs,ham10k.testing/ensemble_3/4-aps,ham10k.testing/ensemble_3/4-auc,ham10k.testing/ensemble_3/4-bs,ham10k.testing/ensemble_3/5-aps,ham10k.testing/ensemble_3/5-auc,ham10k.testing/ensemble_3/5-bs,ham10k.testing/ensemble_3/6-aps,ham10k.testing/ensemble_3/6-auc,ham10k.testing/ensemble_3/6-bs,ham10k.testing/ensemble_3/aps-macro,ham10k.testing/ensemble_3/auc-macro,ham10k.testing/ensemble_3/bs-macro,ham10k.testing/ensemble_3/loss,happy.accuracy_top_1,happy.accuracy_top_1_individual,happy.accuracy_top_1_species,happy.accuracy_top_5,happy.accuracy_top_5_individual,happy.accuracy_top_5_species,happy.loss,happy.loss_individual,happy.loss_species,hmdb51.accuracy_top_1,hmdb51.accuracy_top_5,hmdb51.loss,imagenet1k.accuracy_top_1,imagenet1k.accuracy_top_5,imagenet1k.loss,iwildcam.loss,iwildcam.mae_loss,iwildcam.mse_loss,kinetics.accuracy_top_1,kinetics.accuracy_top_5,kinetics.loss,mini.accuracy_top_1,mini.loss,newyorkercaptioncontest.image_to_text_accuracy,newyorkercaptioncontest.image_to_text_acc
uracy_top_5,newyorkercaptioncontest.image_to_text_loss,newyorkercaptioncontest.loss,newyorkercaptioncontest.text_to_image_accuracy,newyorkercaptioncontest.text_to_image_accuracy_top_5,newyorkercaptioncontest.text_to_image_loss,nyu.ce_loss,nyu.dice_loss,nyu.focal_loss,nyu.loss,nyu.mIoU,nyu.mean_accuracy,nyu.overall_accuracy,omniglot.accuracy_top_1,omniglot.loss,pascal.ce_loss,pascal.dice_loss,pascal.focal_loss,pascal.loss,pascal.mIoU,pascal.mean_accuracy,pascal.overall_accuracy,places365.accuracy_top_1,places365.accuracy_top_5,places365.loss,pokemonblipcaptions.image_to_text_accuracy,pokemonblipcaptions.image_to_text_accuracy_top_5,pokemonblipcaptions.image_to_text_loss,pokemonblipcaptions.loss,pokemonblipcaptions.text_to_image_accuracy,pokemonblipcaptions.text_to_image_accuracy_top_5,pokemonblipcaptions.text_to_image_loss,ucf.accuracy_top_1,ucf.accuracy_top_5,ucf.loss,vgg.accuracy_top_1,vgg.loss,winoground.image_to_text_accuracy,winoground.image_to_text_loss,winoground.loss,winoground.text_to_image_accuracy,winoground.text_to_image_loss,exp_name
0,ar-vit-b16,0.559243,0.562368,54.338646,79.443932,78.977887,1.439675,0.623715,0.291468,0.969018,31.810069,42.73518,62.651232,0.941239,0.262639,0.744916,0.920518,0.075839,0.518032,0.734637,0.19314,0.42149,0.708754,0.193698,0.777515,0.844941,0.176546,0.50848,0.866368,0.100756,0.594087,0.815044,0.147996,0.449167,40.814091,71.516716,2.278613,0.208765,0.616431,0.031152,0.148707,61.390024,75.50766,93.116434,52.169037,35.75423,45.621979,59.972759,51.118793,60.342369,60.226028,0.871289,1.508475,1.156839,0.659476,0.968794,0.650669,0.636307,59.341446,98.925003,0.935315,1.555699,0.570834,0.33662,1.073815,30.990615,43.468811,63.096144,0.940776,0.354046,0.895456,0.802728,0.137832,0.107936,0.636813,0.064423,0.511298,0.825096,0.100858,0.354947,0.933615,0.017826,0.722953,0.983035,0.009418,0.518518,0.836257,0.066072,0.180801,0.772806,1.478162,0.05225,0.184427,3.882412,3.889763,0.050999,0.195006,3.897114,84.751724,96.998093,0.537537,0.766097,0.887478,0.818958,0.971394,0.043687,0.985305,0.973967,0.058963,0.952532,0.999191,0.00348,0.756843,0.935443,0.052584,0.9647,0.999539,0.002152,0.861583,0.990035,0.016942,0.731135,0.972677,0.01902,0.867294,0.977464,0.028118,0.361344,52.962318,7.375529,98.54911,59.551193,19.19805,99.904335,3.098909,6.139311,0.058507,28.385416,58.658855,3.175461,77.671432,92.912804,0.965971,1.587222,1.587222,5.39161,23.338408,45.657619,4.014843,0.97935,0.077078,0.049986,0.185464,3.902347,3.907792,0.050875,0.190297,3.913237,1.709063,0.234403,0.316457,1.114036,12.077721,23.015949,32.480359,0.985692,0.07505,1.112825,0.490195,0.223187,0.769829,26.895896,35.571377,71.831284,47.252628,79.488289,2.015198,0.535616,0.722426,1.908139,1.936036,0.554228,0.741958,1.963933,65.259743,86.329659,1.414807,0.944496,0.302574,0.380417,0.693004,0.692399,0.541528,0.691794,debug-acdc-ar-vit-b16-512-9032024
1,bart,0.737925,0.742864,21.72679,41.123699,40.4233,3.738331,0.673217,0.874897,2.493215,0.494601,1.675647,13.021367,0.607579,1.220632,0.231186,0.664104,0.124336,0.298783,0.566562,0.185531,0.310067,0.599408,0.178448,0.487589,0.641896,0.26397,0.358567,0.794473,0.109397,0.337238,0.653289,0.172336,0.532958,2.318869,11.076831,4.521831,0.042907,0.62489,0.047516,0.242374,1.946948,5.263158,36.992007,42.365562,12.892879,44.629971,50.508053,33.378548,53.698559,59.094204,0.957619,2.079662,1.213407,0.693644,1.099306,0.689525,0.645299,42.060528,96.925003,1.342714,3.807623,0.926801,0.893084,2.593539,1.020467,2.102062,20.372541,0.498045,2.027175,0.738117,0.510582,0.195352,0.069083,0.500102,0.064319,0.149236,0.502309,0.126553,0.02753,0.500293,0.026796,0.020873,0.5,0.020453,0.200968,0.502657,0.086695,0.222981,0.498092,1.908545,0.017732,0.084908,4.152012,4.155223,0.017214,0.088209,4.158434,10.172654,29.689827,3.971809,0.338282,2.287416,0.234877,0.763432,0.086201,0.932007,0.875827,0.130516,0.023492,0.634929,0.013141,0.297102,0.794204,0.089906,0.169623,0.706891,0.012355,0.313963,0.889467,0.043419,0.188845,0.88362,0.032783,0.308559,0.792624,0.058331,0.254716,30.496447,1.740585,59.252312,47.53735,5.266884,89.807808,4.406802,7.463479,1.350124,4.947917,20.768229,3.776443,0.487532,2.04204,6.707815,1.837848,1.837848,6.41076,0.248242,1.30844,6.130118,0.371851,1.901379,0.015378,0.080413,4.162674,4.166707,0.017732,0.082002,4.170739,4.327356,0.197133,0.724424,2.175923,0.323242,10.0,3.23242,0.805374,0.745138,2.498958,0.508702,0.521702,1.670969,1.438254,3.285745,34.903458,2.657618,10.188266,5.275402,0.518153,0.62477,3.740649,3.588059,0.548483,0.647289,3.435468,7.003171,20.19405,4.236597,0.577748,1.631277,0.50125,0.690207,0.691926,0.487778,0.693645,debug-acdc-bart-9032024
2,bert,0.767772,0.773026,32.129586,56.076954,55.467281,1.06542,0.630138,-0.017199,-0.278009,20.084024,78.377641,109.454734,0.612453,2.501446,0.315078,0.722873,0.120661,0.294241,0.574516,0.1894,0.318631,0.598586,0.179187,0.538825,0.671968,0.248399,0.378656,0.813468,0.111149,0.369086,0.676282,0.169759,0.536359,9.265526,27.935907,4.021411,0.452299,0.575627,0.020173,-0.090542,0.811853,27.765484,75.470982,42.528553,12.821336,44.374405,49.844261,34.649845,54.034348,59.468624,0.954478,2.07966,1.200574,0.693457,1.098016,0.688531,0.645494,44.757233,97.462502,1.291237,2.41501,1.095427,0.239328,0.854072,9.519236,35.55395,63.334344,0.448636,4.651384,0.735423,0.49977,0.194606,0.06926,0.5,0.064484,0.148644,0.488873,0.128573,0.026536,0.521204,0.024792,0.018406,0.49971,0.018068,0.199654,0.501911,0.086105,0.319376,0.389783,4.101505,0.016786,0.090259,4.138282,4.142645,0.01807,0.094844,4.147008,17.389938,42.398548,3.537323,0.293462,2.409081,0.320978,0.791878,0.083413,0.935442,0.882398,0.127037,0.024942,0.648399,0.013233,0.33174,0.804688,0.086741,0.145963,0.829001,0.011961,0.347485,0.900174,0.041341,0.220755,0.897139,0.031867,0.332472,0.821954,0.056514,0.23933,35.818462,1.949207,69.687721,50.442513,7.344993,93.540031,4.101583,7.200733,1.002434,5.46875,24.414062,3.773712,1.086957,4.223945,6.529994,2.174762,2.174762,8.593822,0.250827,1.285168,6.139731,0.304308,1.998082,0.019444,0.08946,4.14245,4.145315,0.018712,0.089369,4.14818,3.460258,0.22904,0.13708,0.921306,4.65503,17.400046,16.7844,0.870129,0.53097,3.795896,0.280748,-0.124146,6.335429,-7.292712,14.571794,51.051065,4.898752,16.048599,5.012216,0.519072,0.647289,3.602581,3.352214,0.54159,0.674632,3.101847,6.478405,23.127453,4.151929,0.566426,2.808384,0.484583,0.698199,0.693588,0.517917,0.688977,debug-acdc-bert-9032024
3,clip-b16,0.583313,0.586698,29.563564,52.64336,52.723345,1.039815,0.629292,0.199019,0.699743,45.14831,58.369999,72.696039,0.961861,0.212422,0.765198,0.920979,0.073965,0.551915,0.754613,0.196419,0.428896,0.711965,0.261576,0.807658,0.868467,0.163832,0.49489,0.866615,0.102052,0.609711,0.824528,0.159569,0.370855,71.516716,92.167595,1.026104,0.190074,0.563548,0.03069,0.134229,68.788779,81.364349,94.013561,52.544109,35.589828,45.708652,60.555397,52.042732,60.848839,60.533722,0.865286,1.486372,1.153551,0.65211,0.945489,0.647927,0.636439,61.461845,99.112503,0.881165,1.318337,0.543554,0.28277,0.913485,38.089237,49.46423,68.02459,0.977226,0.164157,0.914735,0.83858,0.123853,0.133944,0.664063,0.067604,0.572381,0.857944,0.089783,0.466375,0.942948,0.020832,0.690638,0.979751,0.011338,0.555615,0.856657,0.062682,0.128799,0.852414,0.653928,0.063447,0.212557,3.786345,3.80002,0.059223,0.220848,3.813695,91.298401,98.639702,0.310249,0.850268,0.635735,0.860287,0.974267,0.037059,0.989463,0.980555,0.056214,0.841906,0.990037,0.004641,0.794845,0.942675,0.047178,0.997368,0.999965,0.000702,0.912438,0.99308,0.014604,0.767574,0.981193,0.015277,0.880554,0.980253,0.025097,0.343366,64.846588,29.980165,99.713013,74.044312,48.088627,100.0,2.312976,4.610023,0.01593,37.760418,67.96875,2.760987,81.593674,95.770058,0.754218,1.456479,1.456479,4.384621,21.912863,44.366909,4.08327,0.959975,0.148252,0.066286,0.214269,3.76708,3.778309,0.059256,0.214788,3.789538,2.011645,0.228527,0.396605,1.324538,7.596339,16.139887,23.913012,0.98769,0.076278,0.746756,0.506998,0.136958,0.516096,32.732588,40.355453,78.716546,53.465958,84.134743,1.742838,0.565028,0.753676,1.762938,1.731197,0.625689,0.806526,1.699456,73.829659,92.540024,0.97616,0.987282,0.066546,0.425278,0.692561,0.692498,0.551806,0.692435,debug-acdc-clip-b16-512-31032024
4,convnextv2-base,0.41436,0.41601,33.452904,61.323517,60.739414,0.402027,0.625735,0.245454,0.883061,0.905346,20.288492,65.366348,0.959871,0.21719,0.747682,0.908532,0.079875,0.553183,0.754638,0.20635,0.438143,0.717511,0.211536,0.805775,0.867016,0.174546,0.534408,0.875168,0.113653,0.615838,0.824573,0.157192,0.565546,82.404457,97.173569,0.674328,-0.069855,0.72239,0.056311,0.363216,69.889688,108.802656,82.651205,52.532558,35.41935,45.800877,60.15786,52.083954,61.043514,60.706806,0.859377,1.46419,1.143953,0.650135,0.936152,0.645701,0.635208,79.319077,99.800003,0.465694,2.378376,0.900627,0.4378,1.552991,35.865328,50.359252,74.035147,1.360745,39.523504,0.922949,0.863,0.113302,0.128984,0.681692,0.068313,0.654535,0.881703,0.086248,0.415841,0.947274,0.020049,0.672167,0.973885,0.010908,0.558895,0.869511,0.059764,0.202634,0.506648,-7.079307,0.058243,0.203995,3.83791,3.849165,0.052925,0.202193,3.86042,92.663116,98.975471,0.283863,0.839228,96.458665,0.903905,0.981648,0.03372,0.991025,0.986166,0.038251,0.929646,0.998955,0.00426,0.840736,0.955604,0.042114,0.994987,0.999929,0.000563,0.873177,0.989729,0.015825,0.773781,0.979206,0.017671,0.901037,0.984462,0.021772,0.548383,87.320435,74.880013,99.760841,91.933617,83.899117,99.968109,0.827558,1.643827,0.011288,51.757812,81.445312,2.053579,85.052353,96.733139,0.646129,1.577138,1.577138,4.913095,48.116394,74.939789,2.385665,0.620759,-212.457994,0.069102,0.214146,3.763579,3.775449,0.060507,0.219102,3.787318,2.557237,0.22391,0.000821,2.038372,0.323242,10.0,3.23242,0.93195,0.347292,20.361689,0.155569,5.075783,13.144416,18.651888,23.841524,64.291122,54.378284,84.750984,1.692317,0.568934,0.805607,1.698529,1.726562,0.613971,0.806526,1.754595,84.249474,95.401688,0.616269,0.726614,10.605822,0.510417,0.691239,0.691657,0.462778,0.692075,debug-acdc-convnextv2-base-9032024
5,deit3-b16,0.562037,0.565171,50.162196,75.539291,75.143587,1.385229,0.636853,0.265982,0.914186,35.081536,46.845039,63.900622,0.948787,0.275883,0.760484,0.921575,0.072582,0.53772,0.753387,0.195167,0.455444,0.720776,0.193854,0.815632,0.868783,0.158535,0.544855,0.880372,0.091234,0.622827,0.828979,0.142274,0.375662,61.703823,86.65406,1.380363,0.207643,0.626092,0.030988,0.146506,63.926107,76.805351,93.096244,52.554836,35.53352,45.774403,60.464371,52.206165,60.860577,60.50684,0.868381,1.527259,1.150773,0.654099,0.944456,0.648676,0.635717,73.696053,99.737503,0.637889,1.460949,0.650212,0.310589,1.005667,32.558817,43.375015,64.245912,0.95799,0.280003,0.90887,0.819149,0.125971,0.114768,0.639484,0.060179,0.521323,0.831294,0.097416,0.37465,0.93476,0.016775,0.675209,0.979598,0.011056,0.518964,0.840857,0.062279,0.39017,0.804581,1.2067,0.05439,0.193441,3.862041,3.872587,0.050661,0.199591,3.883134,87.252182,97.754097,0.44436,0.794328,0.820706,0.861763,0.977199,0.035715,0.989487,0.980184,0.053412,0.886808,0.996763,0.004748,0.758586,0.934775,0.055448,0.994598,0.999929,0.000986,0.878274,0.990244,0.014839,0.75695,0.979761,0.016559,0.875209,0.979836,0.025958,0.328734,58.385941,17.457464,99.314415,66.81633,33.664551,99.968109,2.687502,5.344537,0.030465,34.114582,68.619789,2.655795,81.779495,94.4813,0.807097,1.378354,1.378354,4.135864,28.670059,53.417755,3.570246,0.985332,0.056905,0.065306,0.211758,3.805508,3.815632,0.060417,0.213965,3.825756,1.517677,0.222096,0.269498,0.982946,11.62466,22.610285,42.770438,0.986616,0.072521,0.898437,0.5526,0.167201,0.625892,28.876969,36.208042,77.461112,48.337345,79.934875,1.979457,0.542509,0.745864,1.856413,1.902985,0.527803,0.776195,1.949558,75.037758,91.618851,0.975303,0.959775,0.189089,0.5,0.691406,0.691406,0.5,0.691406,debug-acdc-deit3-b16-512-9032024
6,dino-b16,0.560648,0.563765,27.994537,50.658131,51.255843,1.341776,0.638767,0.259287,0.889526,33.174095,44.23954,64.168709,0.963227,0.182379,0.758957,0.921371,0.075205,0.542244,0.752738,0.209443,0.438833,0.716822,0.210768,0.792471,0.855844,0.176927,0.526088,0.873399,0.106371,0.611719,0.824035,0.155743,0.456763,61.196259,86.236069,1.408972,0.198762,0.636738,0.029259,0.140554,63.871944,75.661316,93.391345,52.526684,35.310368,45.746029,60.657906,52.477352,60.679329,60.30722,0.868729,1.52149,1.154731,0.653266,0.941955,0.647975,0.636708,60.241444,98.862503,0.90671,1.526463,0.655372,0.31977,1.043477,29.845586,41.188881,61.742191,0.955455,0.263008,0.900429,0.815154,0.136868,0.145276,0.664927,0.065153,0.507019,0.829809,0.101821,0.465422,0.94445,0.018043,0.701961,0.979458,0.008796,0.544022,0.84676,0.066136,0.295474,0.800249,1.110317,0.056811,0.196414,3.818525,3.823214,0.057296,0.209132,3.827904,86.480362,97.449043,0.473914,0.798328,0.849084,0.843227,0.969427,0.043269,0.987924,0.976791,0.059655,0.826146,0.993358,0.004652,0.795615,0.945071,0.048861,0.989724,0.999858,0.001315,0.888333,0.991102,0.014408,0.750876,0.974313,0.01873,0.868835,0.97856,0.02727,0.408693,59.447559,19.532871,99.362244,68.318436,36.652813,99.984055,2.626684,5.22336,0.030008,26.432291,54.101562,3.342566,68.330406,89.080482,1.330247,1.435504,1.435504,4.336665,18.455582,40.203323,4.243761,0.896685,0.395391,0.058547,0.198825,3.830781,3.834272,0.055156,0.207082,3.837764,3.59455,0.227783,0.819929,2.427963,5.989678,10.475529,25.268564,0.987855,0.074817,0.968016,0.522662,0.180554,0.661319,27.321649,34.652905,68.636712,47.250439,79.217384,2.011414,0.519991,0.726333,2.079312,2.029545,0.56204,0.783088,1.979779,58.800213,81.716248,1.729808,0.959417,0.20126,0.450417,0.692785,0.692889,0.424444,0.692992,debug-acdc-dino-b16-512-9032024
7,effformer-s0,0.539577,0.716573,28.253764,62.672143,37.187413,3.786354,0.632552,0.423205,1.787743,28.240358,22.420348,60.357365,0.735181,311.535767,0.649992,0.891769,0.08444,0.40248,0.67359,0.174388,0.405115,0.682954,0.165812,0.752828,0.8322,0.175415,0.448669,0.849232,0.095979,0.531817,0.785949,0.139207,0.361917,14.42078,39.848724,3.635036,0.302926,0.45198,-0.01563,0.417855,52.670607,78.518658,46.565338,45.213261,14.183392,44.793575,53.588108,44.800533,54.358669,59.569378,0.951993,2.076338,1.20389,0.689646,1.074091,0.686052,0.644838,56.594734,98.75,1.018385,2.366133,0.868089,0.459146,2.399551,27.165891,14.095085,51.018565,0.724367,179.395859,0.871587,0.732111,0.157192,0.068866,0.574034,0.053108,0.292004,0.711891,0.116917,0.210072,0.867923,0.02509,0.237263,0.93637,0.017806,0.335958,0.764466,0.074023,0.32684,0.68517,3.573644,0.038708,0.154528,4.017169,4.024096,0.038708,0.151284,4.031023,60.177311,85.432098,1.584574,0.59217,1340.480591,0.579945,0.916613,0.062671,0.965203,0.934708,0.094052,0.049584,0.756288,0.012972,0.467295,0.848766,0.075528,0.219378,0.945436,0.011889,0.670395,0.966323,0.027899,0.334014,0.936192,0.029112,0.469402,0.900618,0.044875,0.230441,47.809746,3.992725,91.626762,55.324039,11.749557,98.898514,3.461311,6.641551,0.281072,2.148438,9.635417,4.05603,41.839832,69.749039,3.072035,2.74767,2.74767,13.900888,0.237898,1.248966,6.129591,0.835908,1495.147827,0.041885,0.155532,4.020712,4.026254,0.037177,0.152445,4.031796,4.288333,0.211576,0.070679,1.38355,-0.295494,10.184449,34.348615,0.957604,0.157702,0.504879,0.368006,-0.493287,3.958249,21.225499,11.496292,38.701943,24.689142,54.302212,3.354646,0.519991,0.672564,2.226351,2.214739,0.516085,0.721507,2.203126,0.792812,4.968287,4.749167,0.867747,30.777681,0.473056,0.692109,0.693594,0.4875,0.695079,debug-aircraft-fs-classification-effformer-s0-...
8,effv2-rw-s,0.40563,0.407194,61.47151,87.221474,87.253705,2.218599,0.694129,0.465928,1.444493,14.218676,20.438848,42.741177,0.905912,1.201142,0.736924,0.900437,0.077381,0.533335,0.751606,0.212149,0.430741,0.710683,0.221275,0.805519,0.870372,0.162695,0.499439,0.858335,0.116572,0.601192,0.818286,0.158014,0.382158,65.266716,89.77906,1.313418,0.202399,0.57462,0.032058,0.142042,64.238483,77.438736,93.1606,39.637341,11.754904,38.795815,50.374279,33.523605,51.114349,52.27631,1.009231,2.079872,1.352569,0.692819,1.099139,0.691973,0.691156,42.085526,96.412498,1.420667,2.593563,0.944034,0.567035,1.741864,10.237104,15.753405,39.177912,0.916984,0.494343,0.912593,0.842967,0.120161,0.125692,0.664431,0.06427,0.602877,0.863237,0.091198,0.344796,0.950563,0.018053,0.701306,0.974989,0.01131,0.537453,0.859237,0.060998,0.207131,0.596898,14.336714,0.025866,0.118693,4.049098,4.051786,0.024795,0.121285,4.054474,85.862808,97.076744,0.597491,0.737152,5.838009,0.807705,0.965943,0.043065,0.980309,0.969181,0.066337,0.890409,0.997876,0.004213,0.70157,0.909157,0.05751,0.949925,0.998901,0.002267,0.864451,0.991342,0.017774,0.769353,0.975707,0.016452,0.85196,0.972587,0.02966,0.668098,87.752937,75.872589,99.633293,91.717018,83.449989,99.984055,0.801419,1.580514,0.022324,4.947917,18.815104,4.361852,73.517426,91.386269,1.219578,1.925611,1.925611,6.455574,0.408564,2.161771,6.063885,0.628709,27.530499,0.032929,0.149876,3.982876,3.986747,0.038461,0.156207,3.990618,2.01298,0.227779,0.399157,1.326675,8.232061,15.147624,28.677225,0.986207,0.094821,1.566818,0.693605,0.300327,1.025733,9.28407,13.209156,56.804983,50.31086,81.924255,1.881371,0.495634,0.586627,10.6997,10.896326,0.462316,0.586627,11.092952,3.363787,15.255964,4.968294,0.884897,0.608705,0.499583,0.691508,0.691589,0.500278,0.691669,debug-acdc-effv2-rw-s-9032024
9,flex-b-1200ep,0.541337,0.544269,53.079033,77.973427,77.711524,1.12193,0.636601,0.209301,0.746021,42.138797,54.628418,70.487888,0.944313,0.311596,0.761663,0.916881,0.074291,0.557579,0.770244,0.203239,0.438248,0.715043,0.222159,0.815037,0.875956,0.158398,0.529572,0.880561,0.102428,0.62042,0.831737,0.152103,0.376406,53.224522,80.632965,1.761222,0.193869,0.617445,0.028996,0.136098,67.524648,79.642441,93.640814,51.83083,34.895889,45.917717,59.302315,50.384392,59.906704,60.578979,0.869842,1.487887,1.153855,0.665756,0.982806,0.650563,0.634901,59.863155,98.875,0.925399,1.380164,0.621037,0.293294,0.951781,35.139361,45.798512,65.967574,0.952227,0.329608,0.913311,0.832195,0.123623,0.118343,0.646564,0.060532,0.572244,0.854787,0.092918,0.409283,0.939713,0.016247,0.752629,0.984683,0.009379,0.553162,0.851589,0.06054,0.401484,0.783887,1.357674,0.050109,0.188551,3.846513,3.851919,0.050447,0.201574,3.857324,89.095078,98.074043,0.386901,0.816184,0.80634,0.875722,0.974775,0.038177,0.986084,0.975213,0.062957,0.813408,0.99415,0.006338,0.771325,0.933202,0.053791,0.960272,0.999468,0.002326,0.914955,0.994609,0.015137,0.767361,0.945151,0.016291,0.869875,0.973796,0.02786,0.44848,70.991737,42.286407,99.697067,79.88855,59.793053,99.984055,1.987586,3.957114,0.018057,22.200521,56.770832,3.426509,82.047234,95.236572,0.767253,1.526898,1.526898,4.69312,9.108842,23.494297,4.906845,0.971718,0.100641,0.051608,0.190478,3.859116,3.861938,0.049895,0.200019,3.864762,2.372985,0.234913,0.507031,1.584009,6.8884,13.225321,30.002304,0.987893,0.06881,0.954676,0.63163,0.175795,0.646925,28.666773,37.213219,76.762781,50.224388,82.161781,1.86101,0.56204,0.726333,1.919998,1.801394,0.645221,0.795726,1.68279,58.11311,82.267441,1.723406,0.94918,0.273503,0.496806,0.692402,0.691165,0.531389,0.689927,debug-acdc-flex-b-1200ep-9032024


In [None]:
# Keep only the rows of cleaned_df that have no missing value, then renumber
# them. reset_index(drop=True) discards the old index directly instead of
# materialising it as an "index" column that has to be dropped again (which
# would also fail if the index carried a name).
model_df = cleaned_df.dropna().reset_index(drop=True)
# Remove the two string identifier columns so every remaining column is numeric.
# (presumably each row is later used as one sample in torch — confirm downstream)
model_df = model_df.drop(columns=["exp_name", "model_name"])
# cast all columns to float
model_df = model_df.astype(float)
# Per-column mean and standard deviation of the numeric frame.
column_means = model_df.mean()
column_sd = model_df.std()
# NOTE(review): the statistics above are computed before this final dropna; if
# it ever removes rows (e.g. NaNs introduced by the float cast) they will not
# describe exactly the rows left in model_df — confirm the intended order.
model_df = model_df.dropna()

display_pretty(model_df)
display_pretty(column_means)
display_pretty(column_sd)

Unnamed: 0,acdc.dice_loss,acdc.loss,acdc.mIoU,acdc.mean_accuracy,acdc.overall_accuracy,ade20k.ce_loss,ade20k.dice_loss,ade20k.focal_loss,ade20k.loss,ade20k.mIoU,ade20k.mean_accuracy,ade20k.overall_accuracy,aircraft.accuracy_top_1,aircraft.loss,chexpert.testing/ensemble_3/0-aps,chexpert.testing/ensemble_3/0-auc,chexpert.testing/ensemble_3/0-bs,chexpert.testing/ensemble_3/1-aps,chexpert.testing/ensemble_3/1-auc,chexpert.testing/ensemble_3/1-bs,chexpert.testing/ensemble_3/2-aps,chexpert.testing/ensemble_3/2-auc,chexpert.testing/ensemble_3/2-bs,chexpert.testing/ensemble_3/3-aps,chexpert.testing/ensemble_3/3-auc,chexpert.testing/ensemble_3/3-bs,chexpert.testing/ensemble_3/4-aps,chexpert.testing/ensemble_3/4-auc,chexpert.testing/ensemble_3/4-bs,chexpert.testing/ensemble_3/aps-macro,chexpert.testing/ensemble_3/auc-macro,chexpert.testing/ensemble_3/bs-macro,chexpert.testing/ensemble_3/loss,cifar100.accuracy_top_1,cifar100.accuracy_top_5,cifar100.loss,cityscapes.ce_loss,cityscapes.dice_loss,cityscapes.focal_loss,cityscapes.loss,cityscapes.mIoU,cityscapes.mean_accuracy,cityscapes.overall_accuracy,clevr.accuracy_top_1,clevr.accuracy_top_1_colour,clevr.accuracy_top_1_count,clevr.accuracy_top_1_material,clevr.accuracy_top_1_shape,clevr.accuracy_top_1_size,clevr.accuracy_top_1_yes_no,clevr.loss,clevr.loss_colour,clevr.loss_count,clevr.loss_material,clevr.loss_shape,clevr.loss_size,clevr.loss_yes_no,clevr-math.accuracy_top_1,clevr-math.accuracy_top_5,clevr-math.loss,coco.ce_loss,coco.dice_loss,coco.focal_loss,coco.loss,coco.mIoU,coco.mean_accuracy,coco.overall_accuracy,cubirds.accuracy_top_1,cubirds.loss,diabetic.testing/ensemble_3/0-aps,diabetic.testing/ensemble_3/0-auc,diabetic.testing/ensemble_3/0-bs,diabetic.testing/ensemble_3/1-aps,diabetic.testing/ensemble_3/1-auc,diabetic.testing/ensemble_3/1-bs,diabetic.testing/ensemble_3/2-aps,diabetic.testing/ensemble_3/2-auc,diabetic.testing/ensemble_3/2-bs,diabetic.testing/ensemble_3/3-aps,diabetic.testing/ensemble_3/3-auc,diabetic.tes
ting/ensemble_3/3-bs,diabetic.testing/ensemble_3/4-aps,diabetic.testing/ensemble_3/4-auc,diabetic.testing/ensemble_3/4-bs,diabetic.testing/ensemble_3/aps-macro,diabetic.testing/ensemble_3/auc-macro,diabetic.testing/ensemble_3/bs-macro,diabetic.testing/ensemble_3/loss,dtextures.accuracy_top_1,dtextures.loss,flickr30k.image_to_text_accuracy,flickr30k.image_to_text_accuracy_top_5,flickr30k.image_to_text_loss,flickr30k.loss,flickr30k.text_to_image_accuracy,flickr30k.text_to_image_accuracy_top_5,flickr30k.text_to_image_loss,food101.accuracy_top_1,food101.accuracy_top_5,food101.loss,fungi.accuracy_top_1,fungi.loss,ham10k.testing/ensemble_3/0-aps,ham10k.testing/ensemble_3/0-auc,ham10k.testing/ensemble_3/0-bs,ham10k.testing/ensemble_3/1-aps,ham10k.testing/ensemble_3/1-auc,ham10k.testing/ensemble_3/1-bs,ham10k.testing/ensemble_3/2-aps,ham10k.testing/ensemble_3/2-auc,ham10k.testing/ensemble_3/2-bs,ham10k.testing/ensemble_3/3-aps,ham10k.testing/ensemble_3/3-auc,ham10k.testing/ensemble_3/3-bs,ham10k.testing/ensemble_3/4-aps,ham10k.testing/ensemble_3/4-auc,ham10k.testing/ensemble_3/4-bs,ham10k.testing/ensemble_3/5-aps,ham10k.testing/ensemble_3/5-auc,ham10k.testing/ensemble_3/5-bs,ham10k.testing/ensemble_3/6-aps,ham10k.testing/ensemble_3/6-auc,ham10k.testing/ensemble_3/6-bs,ham10k.testing/ensemble_3/aps-macro,ham10k.testing/ensemble_3/auc-macro,ham10k.testing/ensemble_3/bs-macro,ham10k.testing/ensemble_3/loss,happy.accuracy_top_1,happy.accuracy_top_1_individual,happy.accuracy_top_1_species,happy.accuracy_top_5,happy.accuracy_top_5_individual,happy.accuracy_top_5_species,happy.loss,happy.loss_individual,happy.loss_species,hmdb51.accuracy_top_1,hmdb51.accuracy_top_5,hmdb51.loss,imagenet1k.accuracy_top_1,imagenet1k.accuracy_top_5,imagenet1k.loss,iwildcam.loss,iwildcam.mae_loss,iwildcam.mse_loss,kinetics.accuracy_top_1,kinetics.accuracy_top_5,kinetics.loss,mini.accuracy_top_1,mini.loss,newyorkercaptioncontest.image_to_text_accuracy,newyorkercaptioncontest.image_to_text_accuracy_top_5
,newyorkercaptioncontest.image_to_text_loss,newyorkercaptioncontest.loss,newyorkercaptioncontest.text_to_image_accuracy,newyorkercaptioncontest.text_to_image_accuracy_top_5,newyorkercaptioncontest.text_to_image_loss,nyu.ce_loss,nyu.dice_loss,nyu.focal_loss,nyu.loss,nyu.mIoU,nyu.mean_accuracy,nyu.overall_accuracy,omniglot.accuracy_top_1,omniglot.loss,pascal.ce_loss,pascal.dice_loss,pascal.focal_loss,pascal.loss,pascal.mIoU,pascal.mean_accuracy,pascal.overall_accuracy,places365.accuracy_top_1,places365.accuracy_top_5,places365.loss,pokemonblipcaptions.image_to_text_accuracy,pokemonblipcaptions.image_to_text_accuracy_top_5,pokemonblipcaptions.image_to_text_loss,pokemonblipcaptions.loss,pokemonblipcaptions.text_to_image_accuracy,pokemonblipcaptions.text_to_image_accuracy_top_5,pokemonblipcaptions.text_to_image_loss,ucf.accuracy_top_1,ucf.accuracy_top_5,ucf.loss,vgg.accuracy_top_1,vgg.loss,winoground.image_to_text_accuracy,winoground.image_to_text_loss,winoground.loss,winoground.text_to_image_accuracy,winoground.text_to_image_loss
0,0.559243,0.562368,54.338646,79.443932,78.977887,1.439675,0.623715,0.291468,0.969018,31.810069,42.73518,62.651232,0.941239,0.262639,0.744916,0.920518,0.075839,0.518032,0.734637,0.19314,0.42149,0.708754,0.193698,0.777515,0.844941,0.176546,0.50848,0.866368,0.100756,0.594087,0.815044,0.147996,0.449167,40.814091,71.516716,2.278613,0.208765,0.616431,0.031152,0.148707,61.390024,75.50766,93.116434,52.169037,35.75423,45.621979,59.972759,51.118793,60.342369,60.226028,0.871289,1.508475,1.156839,0.659476,0.968794,0.650669,0.636307,59.341446,98.925003,0.935315,1.555699,0.570834,0.33662,1.073815,30.990615,43.468811,63.096144,0.940776,0.354046,0.895456,0.802728,0.137832,0.107936,0.636813,0.064423,0.511298,0.825096,0.100858,0.354947,0.933615,0.017826,0.722953,0.983035,0.009418,0.518518,0.836257,0.066072,0.180801,0.772806,1.478162,0.05225,0.184427,3.882412,3.889763,0.050999,0.195006,3.897114,84.751724,96.998093,0.537537,0.766097,0.887478,0.818958,0.971394,0.043687,0.985305,0.973967,0.058963,0.952532,0.999191,0.00348,0.756843,0.935443,0.052584,0.9647,0.999539,0.002152,0.861583,0.990035,0.016942,0.731135,0.972677,0.01902,0.867294,0.977464,0.028118,0.361344,52.962318,7.375529,98.54911,59.551193,19.19805,99.904335,3.098909,6.139311,0.058507,28.385416,58.658855,3.175461,77.671432,92.912804,0.965971,1.587222,1.587222,5.39161,23.338408,45.657619,4.014843,0.97935,0.077078,0.049986,0.185464,3.902347,3.907792,0.050875,0.190297,3.913237,1.709063,0.234403,0.316457,1.114036,12.077721,23.015949,32.480359,0.985692,0.07505,1.112825,0.490195,0.223187,0.769829,26.895896,35.571377,71.831284,47.252628,79.488289,2.015198,0.535616,0.722426,1.908139,1.936036,0.554228,0.741958,1.963933,65.259743,86.329659,1.414807,0.944496,0.302574,0.380417,0.693004,0.692399,0.541528,0.691794
2,0.583313,0.586698,29.563564,52.64336,52.723345,1.039815,0.629292,0.199019,0.699743,45.14831,58.369999,72.696039,0.961861,0.212422,0.765198,0.920979,0.073965,0.551915,0.754613,0.196419,0.428896,0.711965,0.261576,0.807658,0.868467,0.163832,0.49489,0.866615,0.102052,0.609711,0.824528,0.159569,0.370855,71.516716,92.167595,1.026104,0.190074,0.563548,0.03069,0.134229,68.788779,81.364349,94.013561,52.544109,35.589828,45.708652,60.555397,52.042732,60.848839,60.533722,0.865286,1.486372,1.153551,0.65211,0.945489,0.647927,0.636439,61.461845,99.112503,0.881165,1.318337,0.543554,0.28277,0.913485,38.089237,49.46423,68.02459,0.977226,0.164157,0.914735,0.83858,0.123853,0.133944,0.664063,0.067604,0.572381,0.857944,0.089783,0.466375,0.942948,0.020832,0.690638,0.979751,0.011338,0.555615,0.856657,0.062682,0.128799,0.852414,0.653928,0.063447,0.212557,3.786345,3.80002,0.059223,0.220848,3.813695,91.298401,98.639702,0.310249,0.850268,0.635735,0.860287,0.974267,0.037059,0.989463,0.980555,0.056214,0.841906,0.990037,0.004641,0.794845,0.942675,0.047178,0.997368,0.999965,0.000702,0.912438,0.99308,0.014604,0.767574,0.981193,0.015277,0.880554,0.980253,0.025097,0.343366,64.846588,29.980165,99.713013,74.044312,48.088627,100.0,2.312976,4.610023,0.01593,37.760418,67.96875,2.760987,81.593674,95.770058,0.754218,1.456479,1.456479,4.384621,21.912863,44.366909,4.08327,0.959975,0.148252,0.066286,0.214269,3.76708,3.778309,0.059256,0.214788,3.789538,2.011645,0.228527,0.396605,1.324538,7.596339,16.139887,23.913012,0.98769,0.076278,0.746756,0.506998,0.136958,0.516096,32.732588,40.355453,78.716546,53.465958,84.134743,1.742838,0.565028,0.753676,1.762938,1.731197,0.625689,0.806526,1.699456,73.829659,92.540024,0.97616,0.987282,0.066546,0.425278,0.692561,0.692498,0.551806,0.692435
3,0.562037,0.565171,50.162196,75.539291,75.143587,1.385229,0.636853,0.265982,0.914186,35.081536,46.845039,63.900622,0.948787,0.275883,0.760484,0.921575,0.072582,0.53772,0.753387,0.195167,0.455444,0.720776,0.193854,0.815632,0.868783,0.158535,0.544855,0.880372,0.091234,0.622827,0.828979,0.142274,0.375662,61.703823,86.65406,1.380363,0.207643,0.626092,0.030988,0.146506,63.926107,76.805351,93.096244,52.554836,35.53352,45.774403,60.464371,52.206165,60.860577,60.50684,0.868381,1.527259,1.150773,0.654099,0.944456,0.648676,0.635717,73.696053,99.737503,0.637889,1.460949,0.650212,0.310589,1.005667,32.558817,43.375015,64.245912,0.95799,0.280003,0.90887,0.819149,0.125971,0.114768,0.639484,0.060179,0.521323,0.831294,0.097416,0.37465,0.93476,0.016775,0.675209,0.979598,0.011056,0.518964,0.840857,0.062279,0.39017,0.804581,1.2067,0.05439,0.193441,3.862041,3.872587,0.050661,0.199591,3.883134,87.252182,97.754097,0.44436,0.794328,0.820706,0.861763,0.977199,0.035715,0.989487,0.980184,0.053412,0.886808,0.996763,0.004748,0.758586,0.934775,0.055448,0.994598,0.999929,0.000986,0.878274,0.990244,0.014839,0.75695,0.979761,0.016559,0.875209,0.979836,0.025958,0.328734,58.385941,17.457464,99.314415,66.81633,33.664551,99.968109,2.687502,5.344537,0.030465,34.114582,68.619789,2.655795,81.779495,94.4813,0.807097,1.378354,1.378354,4.135864,28.670059,53.417755,3.570246,0.985332,0.056905,0.065306,0.211758,3.805508,3.815632,0.060417,0.213965,3.825756,1.517677,0.222096,0.269498,0.982946,11.62466,22.610285,42.770438,0.986616,0.072521,0.898437,0.5526,0.167201,0.625892,28.876969,36.208042,77.461112,48.337345,79.934875,1.979457,0.542509,0.745864,1.856413,1.902985,0.527803,0.776195,1.949558,75.037758,91.618851,0.975303,0.959775,0.189089,0.5,0.691406,0.691406,0.5,0.691406
4,0.560648,0.563765,27.994537,50.658131,51.255843,1.341776,0.638767,0.259287,0.889526,33.174095,44.23954,64.168709,0.963227,0.182379,0.758957,0.921371,0.075205,0.542244,0.752738,0.209443,0.438833,0.716822,0.210768,0.792471,0.855844,0.176927,0.526088,0.873399,0.106371,0.611719,0.824035,0.155743,0.456763,61.196259,86.236069,1.408972,0.198762,0.636738,0.029259,0.140554,63.871944,75.661316,93.391345,52.526684,35.310368,45.746029,60.657906,52.477352,60.679329,60.30722,0.868729,1.52149,1.154731,0.653266,0.941955,0.647975,0.636708,60.241444,98.862503,0.90671,1.526463,0.655372,0.31977,1.043477,29.845586,41.188881,61.742191,0.955455,0.263008,0.900429,0.815154,0.136868,0.145276,0.664927,0.065153,0.507019,0.829809,0.101821,0.465422,0.94445,0.018043,0.701961,0.979458,0.008796,0.544022,0.84676,0.066136,0.295474,0.800249,1.110317,0.056811,0.196414,3.818525,3.823214,0.057296,0.209132,3.827904,86.480362,97.449043,0.473914,0.798328,0.849084,0.843227,0.969427,0.043269,0.987924,0.976791,0.059655,0.826146,0.993358,0.004652,0.795615,0.945071,0.048861,0.989724,0.999858,0.001315,0.888333,0.991102,0.014408,0.750876,0.974313,0.01873,0.868835,0.97856,0.02727,0.408693,59.447559,19.532871,99.362244,68.318436,36.652813,99.984055,2.626684,5.22336,0.030008,26.432291,54.101562,3.342566,68.330406,89.080482,1.330247,1.435504,1.435504,4.336665,18.455582,40.203323,4.243761,0.896685,0.395391,0.058547,0.198825,3.830781,3.834272,0.055156,0.207082,3.837764,3.59455,0.227783,0.819929,2.427963,5.989678,10.475529,25.268564,0.987855,0.074817,0.968016,0.522662,0.180554,0.661319,27.321649,34.652905,68.636712,47.250439,79.217384,2.011414,0.519991,0.726333,2.079312,2.029545,0.56204,0.783088,1.979779,58.800213,81.716248,1.729808,0.959417,0.20126,0.450417,0.692785,0.692889,0.424444,0.692992
5,0.40563,0.407194,61.47151,87.221474,87.253705,2.218599,0.694129,0.465928,1.444493,14.218676,20.438848,42.741177,0.905912,1.201142,0.736924,0.900437,0.077381,0.533335,0.751606,0.212149,0.430741,0.710683,0.221275,0.805519,0.870372,0.162695,0.499439,0.858335,0.116572,0.601192,0.818286,0.158014,0.382158,65.266716,89.77906,1.313418,0.202399,0.57462,0.032058,0.142042,64.238483,77.438736,93.1606,39.637341,11.754904,38.795815,50.374279,33.523605,51.114349,52.27631,1.009231,2.079872,1.352569,0.692819,1.099139,0.691973,0.691156,42.085526,96.412498,1.420667,2.593563,0.944034,0.567035,1.741864,10.237104,15.753405,39.177912,0.916984,0.494343,0.912593,0.842967,0.120161,0.125692,0.664431,0.06427,0.602877,0.863237,0.091198,0.344796,0.950563,0.018053,0.701306,0.974989,0.01131,0.537453,0.859237,0.060998,0.207131,0.596898,14.336714,0.025866,0.118693,4.049098,4.051786,0.024795,0.121285,4.054474,85.862808,97.076744,0.597491,0.737152,5.838009,0.807705,0.965943,0.043065,0.980309,0.969181,0.066337,0.890409,0.997876,0.004213,0.70157,0.909157,0.05751,0.949925,0.998901,0.002267,0.864451,0.991342,0.017774,0.769353,0.975707,0.016452,0.85196,0.972587,0.02966,0.668098,87.752937,75.872589,99.633293,91.717018,83.449989,99.984055,0.801419,1.580514,0.022324,4.947917,18.815104,4.361852,73.517426,91.386269,1.219578,1.925611,1.925611,6.455574,0.408564,2.161771,6.063885,0.628709,27.530499,0.032929,0.149876,3.982876,3.986747,0.038461,0.156207,3.990618,2.01298,0.227779,0.399157,1.326675,8.232061,15.147624,28.677225,0.986207,0.094821,1.566818,0.693605,0.300327,1.025733,9.28407,13.209156,56.804983,50.31086,81.924255,1.881371,0.495634,0.586627,10.6997,10.896326,0.462316,0.586627,11.092952,3.363787,15.255964,4.968294,0.884897,0.608705,0.499583,0.691508,0.691589,0.500278,0.691669
6,0.541337,0.544269,53.079033,77.973427,77.711524,1.12193,0.636601,0.209301,0.746021,42.138797,54.628418,70.487888,0.944313,0.311596,0.761663,0.916881,0.074291,0.557579,0.770244,0.203239,0.438248,0.715043,0.222159,0.815037,0.875956,0.158398,0.529572,0.880561,0.102428,0.62042,0.831737,0.152103,0.376406,53.224522,80.632965,1.761222,0.193869,0.617445,0.028996,0.136098,67.524648,79.642441,93.640814,51.83083,34.895889,45.917717,59.302315,50.384392,59.906704,60.578979,0.869842,1.487887,1.153855,0.665756,0.982806,0.650563,0.634901,59.863155,98.875,0.925399,1.380164,0.621037,0.293294,0.951781,35.139361,45.798512,65.967574,0.952227,0.329608,0.913311,0.832195,0.123623,0.118343,0.646564,0.060532,0.572244,0.854787,0.092918,0.409283,0.939713,0.016247,0.752629,0.984683,0.009379,0.553162,0.851589,0.06054,0.401484,0.783887,1.357674,0.050109,0.188551,3.846513,3.851919,0.050447,0.201574,3.857324,89.095078,98.074043,0.386901,0.816184,0.80634,0.875722,0.974775,0.038177,0.986084,0.975213,0.062957,0.813408,0.99415,0.006338,0.771325,0.933202,0.053791,0.960272,0.999468,0.002326,0.914955,0.994609,0.015137,0.767361,0.945151,0.016291,0.869875,0.973796,0.02786,0.44848,70.991737,42.286407,99.697067,79.88855,59.793053,99.984055,1.987586,3.957114,0.018057,22.200521,56.770832,3.426509,82.047234,95.236572,0.767253,1.526898,1.526898,4.69312,9.108842,23.494297,4.906845,0.971718,0.100641,0.051608,0.190478,3.859116,3.861938,0.049895,0.200019,3.864762,2.372985,0.234913,0.507031,1.584009,6.8884,13.225321,30.002304,0.987893,0.06881,0.954676,0.63163,0.175795,0.646925,28.666773,37.213219,76.762781,50.224388,82.161781,1.86101,0.56204,0.726333,1.919998,1.801394,0.645221,0.795726,1.68279,58.11311,82.267441,1.723406,0.94918,0.273503,0.496806,0.692402,0.691165,0.531389,0.689927
7,0.520081,0.522795,52.978235,76.889046,76.51343,1.048589,0.624491,0.203771,0.719495,43.298936,56.930649,72.41426,0.966846,0.221655,0.741901,0.913758,0.075899,0.539624,0.753206,0.205921,0.448471,0.720275,0.222195,0.805125,0.866017,0.161697,0.509018,0.870259,0.109162,0.608828,0.824703,0.154975,0.407713,70.302551,91.660034,1.06008,0.184794,0.570769,0.028498,0.131195,67.742234,80.418747,93.719254,48.487133,30.442471,45.474068,53.17514,43.599636,58.244164,59.998108,0.907428,1.770499,1.160291,0.68906,1.053632,0.668349,0.637308,62.892765,99.324997,0.827101,1.447088,0.545314,0.314955,1.003824,33.512416,44.926025,65.840537,0.960155,0.240599,0.91311,0.835811,0.126224,0.122144,0.660734,0.068143,0.562645,0.846124,0.096087,0.401839,0.951168,0.019877,0.617302,0.971352,0.013501,0.523408,0.853038,0.064767,0.123883,0.860968,0.659663,0.060597,0.206621,3.794994,3.804226,0.057939,0.208581,3.813458,91.266754,98.611549,0.313706,0.842181,0.660167,0.842749,0.968644,0.04274,0.988198,0.97859,0.058722,0.827556,0.979636,0.00483,0.771123,0.933417,0.05146,0.951575,0.997056,0.001507,0.907871,0.992591,0.014589,0.752821,0.972647,0.017143,0.863128,0.974654,0.027284,0.333215,57.220684,14.871855,99.569519,65.594597,31.205141,99.984055,2.772584,5.521682,0.023486,8.138021,26.953125,4.798939,73.19973,92.710999,1.123901,1.675611,1.675611,5.803219,0.967108,4.203484,5.872526,0.897042,0.35646,0.069035,0.220928,3.770031,3.779792,0.061791,0.219406,3.789553,7.243077,0.211077,1.752911,4.947794,4.662798,8.295491,24.528856,0.989498,0.065853,0.634229,0.578389,0.123549,0.471767,28.272915,37.401096,86.262763,53.59457,84.244202,1.748257,0.561121,0.829044,1.795741,1.734928,0.596278,0.840763,1.674115,57.399578,83.279221,1.670199,0.967784,0.141057,0.503333,0.691324,0.692098,0.465833,0.692873
8,0.569954,0.573168,26.637793,45.135708,44.428602,2.801601,0.663587,0.618606,1.861663,7.885203,12.281738,34.173813,0.865545,0.410747,0.58307,0.850711,0.094618,0.374402,0.649563,0.178796,0.381303,0.660961,0.168672,0.694,0.783807,0.199228,0.36783,0.815965,0.102192,0.480121,0.752201,0.148701,0.39597,34.195858,66.779457,2.521271,0.382805,0.735695,0.055323,0.259349,40.807861,51.317764,88.49276,50.055836,26.761499,45.287842,58.64756,50.223183,59.649158,59.780823,0.913369,1.914166,1.162567,0.67096,0.983904,0.661022,0.639356,55.608555,98.762497,1.031942,3.113923,0.946919,0.700406,2.107524,6.723784,9.565587,28.878717,0.890719,0.3637,0.865286,0.741399,0.161343,0.088189,0.582621,0.062392,0.323501,0.740253,0.120493,0.2832,0.896791,0.01741,0.269673,0.942264,0.016362,0.36597,0.780666,0.0756,0.181052,0.777135,0.678762,0.041063,0.154404,4.009116,4.013331,0.037976,0.16202,4.017546,74.066223,93.018661,0.946926,0.669236,1.106042,0.359611,0.854019,0.077736,0.953224,0.912322,0.109894,0.028216,0.681271,0.013127,0.395048,0.808127,0.081875,0.10692,0.85748,0.012032,0.359842,0.917428,0.040355,0.240438,0.909968,0.031327,0.349042,0.848659,0.052335,0.23577,50.135525,6.444029,93.827019,58.440392,17.454763,99.426018,3.167848,6.144279,0.191418,2.539062,11.653646,4.97019,74.864128,91.907768,1.041796,2.126656,2.126656,8.544108,0.2741,1.264481,6.224344,0.940737,0.267648,0.046628,0.170605,3.975625,3.978631,0.041277,0.165164,3.981636,1.643746,0.24666,0.27283,1.054117,5.939086,11.816404,33.382884,0.957677,0.162069,2.086416,0.557664,0.407046,1.38638,5.759263,9.394987,50.078716,40.786995,73.092712,2.325439,0.550322,0.692096,1.891234,2.009156,0.535616,0.725414,2.127078,1.294926,5.893235,4.973573,0.876574,0.404099,0.519028,0.690921,0.691666,0.489167,0.692411
9,0.563822,0.566977,28.463864,51.654057,52.494857,1.072201,0.619685,0.212186,0.731589,44.009128,57.46204,72.577981,0.967634,0.260399,0.766135,0.924673,0.072733,0.554731,0.751981,0.21454,0.432703,0.708278,0.210237,0.803859,0.862392,0.169955,0.501037,0.869564,0.109246,0.611693,0.823378,0.155342,0.597469,63.66441,88.893311,1.283344,0.183475,0.568213,0.028194,0.13115,67.56236,80.447975,93.860353,52.720375,36.358265,45.811024,60.372128,51.729858,61.233616,60.822571,0.861514,1.471626,1.146104,0.654637,0.935412,0.649316,0.6359,66.825653,99.612503,0.75808,1.488715,0.477856,0.327155,1.034926,35.625401,46.899635,66.249278,0.958542,0.278561,0.915418,0.840405,0.122768,0.129465,0.666696,0.065037,0.589973,0.859788,0.090119,0.422405,0.956924,0.01643,0.729525,0.97726,0.010672,0.557357,0.860215,0.061005,0.210345,0.842759,0.682935,0.056835,0.206564,3.810198,3.817325,0.058828,0.213233,3.824452,92.789696,98.991295,0.256328,0.82851,0.748053,0.851429,0.97175,0.039306,0.987977,0.977851,0.062106,0.890206,0.997269,0.004486,0.768076,0.938914,0.05398,0.97564,0.999716,0.001221,0.919825,0.993471,0.014998,0.74947,0.981174,0.019676,0.877517,0.980021,0.027967,0.422519,67.433548,35.233807,99.633293,76.920288,53.856525,99.984055,2.138341,4.254755,0.021927,13.736979,36.458332,4.411101,76.031013,93.67807,0.968266,1.567638,1.567638,5.406486,10.658875,26.239731,4.855248,0.905992,0.370918,0.052678,0.190996,3.85473,3.858982,0.048307,0.195952,3.863235,2.033109,0.275219,0.403655,1.34022,5.053129,12.547319,25.973695,0.989714,0.062805,0.611597,0.608258,0.11635,0.453631,26.841742,32.580418,86.303788,53.917469,84.70118,1.708433,0.61489,0.814338,1.588563,1.59018,0.63534,0.836857,1.591796,53.039112,77.910751,1.934597,0.982548,0.086357,0.490417,0.692107,0.692022,0.498472,0.691937
10,0.553526,0.556584,50.14661,73.968163,73.538479,1.858981,0.657545,0.379678,1.224504,21.476005,30.365595,54.444919,0.926141,0.39376,0.747918,0.918431,0.071667,0.547558,0.760648,0.169015,0.429228,0.71719,0.174877,0.784364,0.855442,0.165036,0.505465,0.86821,0.09416,0.602907,0.823984,0.134951,0.333266,33.101116,63.465366,2.73043,0.219624,0.661856,0.031785,0.156249,59.525335,71.298241,92.819376,51.649399,34.229546,45.51556,60.071442,49.915203,60.029301,60.148026,0.885685,1.630075,1.160648,0.660678,0.991474,0.653835,0.637036,58.274345,98.9375,0.960974,1.555133,0.80987,0.317219,1.057118,28.551111,40.713482,61.225108,0.913455,0.429755,0.878316,0.77171,0.146602,0.090369,0.59061,0.064799,0.443008,0.795592,0.106851,0.371258,0.934801,0.017967,0.475159,0.969423,0.01325,0.451622,0.812427,0.069894,0.161142,0.780598,1.070834,0.040939,0.153154,3.968911,3.973705,0.040601,0.173973,3.978498,72.189827,92.318504,1.097515,0.761707,0.913576,0.774684,0.953781,0.046254,0.981073,0.963086,0.074252,0.494265,0.969134,0.009236,0.653669,0.915118,0.06273,0.889988,0.980706,0.003816,0.788575,0.982799,0.021548,0.666245,0.961392,0.019854,0.749785,0.960859,0.033956,0.272389,48.368454,3.211475,93.525429,54.310925,9.721982,98.899872,3.522871,6.803833,0.241907,17.252604,44.661457,3.483211,15.073529,39.809784,4.416458,1.641602,1.641602,5.905246,11.121742,28.513062,4.676317,0.929279,0.242047,0.038461,0.156602,3.995448,4.003531,0.033999,0.157582,4.011615,1.68058,0.234688,0.300354,1.089514,10.621736,22.94566,30.830579,0.984515,0.071435,0.933329,0.571881,0.159717,0.644788,23.969041,29.131876,74.803309,24.735661,56.166264,3.31001,0.558134,0.756664,2.101982,2.125624,0.551241,0.741958,2.149267,41.939747,69.661736,2.391057,0.910556,0.343595,0.489167,0.691657,0.692423,0.486111,0.69319


acdc.dice_loss                                           0.559774
acdc.loss                                                0.562896
acdc.mIoU                                               41.505707
acdc.mean_accuracy                                      64.750026
acdc.overall_accuracy                                   64.587687
ade20k.ce_loss                                           1.733339
ade20k.dice_loss                                         0.645262
ade20k.focal_loss                                        0.361829
ade20k.loss                                              1.153950
ade20k.mIoU                                             28.975941
ade20k.mean_accuracy                                    38.724790
ade20k.overall_accuracy                                 56.661637
aircraft.accuracy_top_1                                  0.909008
aircraft.loss                                            0.450296
chexpert.testing/ensemble_3/0-aps                        0.690759
chexpert.t

acdc.dice_loss                                           0.076245
acdc.loss                                                0.077020
acdc.mIoU                                               14.442974
acdc.mean_accuracy                                      16.444012
acdc.overall_accuracy                                   16.376422
ade20k.ce_loss                                           0.867984
ade20k.dice_loss                                         0.023768
ade20k.focal_loss                                        0.214972
ade20k.loss                                              0.573179
ade20k.mIoU                                             15.680938
ade20k.mean_accuracy                                    19.773612
ade20k.overall_accuracy                                 19.243818
aircraft.accuracy_top_1                                  0.104611
aircraft.loss                                            0.382560
chexpert.testing/ensemble_3/0-aps                        0.161094
chexpert.t

In [None]:
import torch

# Two-way lookup between model_df column names and their positional index.
map_index_to_column_headers = dict(enumerate(model_df.columns))
map_column_headers_to_index = {
    column: idx for idx, column in map_index_to_column_headers.items()
}

# Dense float32 copy of the whole frame for use with torch.
convert_df_to_tensor = model_df.to_numpy()
train_set = torch.tensor(convert_df_to_tensor, dtype=torch.float32)

print(train_set.shape)
print(map_column_headers_to_index)

In [None]:
# Keep only the rows of cleaned_df that have at least one missing value.
rows_with_nans = cleaned_df.loc[cleaned_df.isna().any(axis="columns")]
# Boolean frame of the same shape: True wherever a cell is NaN.
nan_cell_boolean = rows_with_nans.isna()
# One list per incomplete row, holding the names of the columns that are NaN in that row.
NaN_metric_per_row = nan_cell_boolean.apply(
    lambda row_mask: list(row_mask.index[row_mask]), axis="columns"
).tolist()

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

class NeuralNet(nn.Module):
    """Fully connected network: three leaky-ReLU hidden layers and a scaled tanh head.

    The final activation is ``tanh(.) * 10``, so every output lies in [-10, 10].
    Dropout between the hidden layers was tried previously and is left disabled.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # The attribute names l1..l4 (and their creation order) define the
        # parameter names and the order in which weights are initialised.
        self.l1 = nn.Linear(input_size, hidden_size)
        self.l2 = nn.Linear(hidden_size, hidden_size)
        self.l3 = nn.Linear(hidden_size, hidden_size)
        self.l4 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Map ``x`` (last dimension ``input_size``) to ``output_size`` bounded values."""
        activations = x
        for hidden_layer in (self.l1, self.l2, self.l3):
            activations = F.leaky_relu(hidden_layer(activations))
        return F.tanh(self.l4(activations)) * 10

In [None]:
import random
from sklearn.model_selection import ShuffleSplit


def train_eval_neural_network(
    train_df,
    target_columns,
    test_df,
    num_epochs,
    num_hidden_units,
    n_splits=50,
    val_size=0.2,
    lr=0.01,
    weight_decay=0.01,
    random_state=42,
):
    """Train one ``NeuralNet`` per shuffle split and predict missing metrics for one row.

    Every column of ``train_df`` is standardised with its own mean and standard
    deviation. Columns named in ``target_columns`` are the regression targets and
    all remaining columns are the inputs. For each split a fresh network is
    trained full-batch with AdamW on an MSE + L1 loss, scored on the held-out
    part of the split, and then used to predict the targets of ``test_df``.

    Args:
        train_df: Numeric DataFrame with complete rows (inputs and targets).
        target_columns: Names of the columns to predict. Names that are not
            columns of ``train_df`` are ignored; the order of ``train_df`` is used.
        test_df: A single row (Series) whose missing entries are to be predicted.
        num_epochs: Number of full-batch optimisation steps per split (>= 1).
        num_hidden_units: Width of each hidden layer of ``NeuralNet``.
        n_splits: Number of random train/validation splits (one model each).
        val_size: Fraction of rows held out for validation in every split.
        lr: AdamW learning rate.
        weight_decay: AdamW weight decay.
        random_state: Seed for the split generator. Network initialisation is
            not seeded here.

    Returns:
        Dict mapping each target column to a list with one de-normalised
        prediction per split.

    Raises:
        ValueError: If ``num_epochs`` is smaller than 1, if none of
            ``target_columns`` is a column of ``train_df``, or if the cleaned
            test row does not have one value per input column.
    """
    if num_epochs < 1:
        raise ValueError("num_epochs must be at least 1")

    # Resolve column positions from train_df itself rather than from a notebook
    # global, so the function is correct for any frame that is passed in.
    all_columns = list(train_df.columns)
    column_to_index = {col: idx for idx, col in enumerate(all_columns)}
    requested_targets = set(target_columns)
    train_column_names = [col for col in all_columns if col not in requested_targets]
    target_columns = [col for col in all_columns if col in requested_targets]
    if not target_columns:
        raise ValueError("none of target_columns is a column of train_df")
    input_idx = [column_to_index[col] for col in train_column_names]
    target_idx = [column_to_index[col] for col in target_columns]

    train_tensor = torch.tensor(train_df.to_numpy(), dtype=torch.float32)
    # Column-wise statistics used for normalisation.
    full_mean = train_tensor.mean(dim=0)
    full_sd = train_tensor.std(dim=0)
    # A constant column has zero spread; dividing by it would turn the whole
    # column into NaN, so such columns are only centred.
    full_sd = torch.where(full_sd == 0, torch.ones_like(full_sd), full_sd)

    input_mean, input_sd = full_mean[input_idx], full_sd[input_idx]
    target_mean, target_sd = full_mean[target_idx], full_sd[target_idx]
    x_train = (train_tensor[:, input_idx] - input_mean) / input_sd
    y_train = (train_tensor[:, target_idx] - target_mean) / target_sd

    # NOTE(review): after dropping NaNs the first and last remaining entries are
    # discarded by position; this presumably strips non-metric fields that
    # test_df carries but train_df does not. Confirm against how the frames
    # are built.
    test_values = test_df.dropna().iloc[1:-1].astype(float).to_numpy()
    x_test = torch.tensor(test_values, dtype=torch.float32)
    if x_test.shape[0] != len(train_column_names):
        raise ValueError(
            f"test row has {x_test.shape[0]} usable values but the model expects "
            f"{len(train_column_names)} input columns"
        )
    # Parenthesised so the test row is standardised exactly like x_train.
    x_test = (x_test - input_mean) / input_sd

    result_dict = {key: [] for key in target_columns}
    average_train_loss = []
    average_val_loss = []

    kfold = ShuffleSplit(n_splits=n_splits, test_size=val_size, random_state=random_state)
    for train_ids, val_ids in kfold.split(x_train):
        x_train_fold, y_train_fold = x_train[train_ids], y_train[train_ids]
        x_val_fold, y_val_fold = x_train[val_ids], y_train[val_ids]

        model = NeuralNet(x_train_fold.shape[1], num_hidden_units, y_train_fold.shape[1])
        optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

        model.train()
        for _ in range(num_epochs):
            optimizer.zero_grad()
            y_pred = model(x_train_fold)
            loss = F.mse_loss(y_pred, y_train_fold) + F.l1_loss(y_pred, y_train_fold)
            loss.backward()
            optimizer.step()

        # Only the final model is scored, so validation and inference run once
        # per split and without building an autograd graph.
        model.eval()
        with torch.no_grad():
            pred_val = model(x_val_fold)
            val_loss = F.mse_loss(pred_val, y_val_fold) + F.l1_loss(pred_val, y_val_fold)
            # Undo the target normalisation to report values in original units.
            y_pred_rec = model(x_test) * target_sd + target_mean

        average_train_loss.append(loss.item())
        average_val_loss.append(val_loss.item())
        for idx, key in enumerate(target_columns):
            result_dict[key].append(y_pred_rec[idx].item())

    print(f"Mean train loss: {torch.tensor(average_train_loss).mean().item()}")
    print(f"Mean val loss: {torch.tensor(average_val_loss).mean().item()}")
    return result_dict
        
    

In [None]:
# Predict the missing metrics of the first incomplete row while sweeping the
# hidden-layer width over 2, 4, ..., 1024.
first_row_with_nans = rows_with_nans.iloc[0]
# Results of every width are kept here for later comparison; test_output_dict
# still ends up holding the result of the last (widest) run.
test_outputs_by_hidden_units = {}
for i in range(1, 11):
    # Label the loss lines that train_eval_neural_network prints for this width.
    print(f"Hidden units: {2**i}")
    test_output_dict = train_eval_neural_network(train_df=model_df,
                                                 target_columns=NaN_metric_per_row[0],
                                                 test_df=first_row_with_nans,
                                                 num_epochs=200,
                                                 num_hidden_units=2**i)
    test_outputs_by_hidden_units[2**i] = test_output_dict

In [None]:
# Show the predictions from the most recent train_eval_neural_network call
# (the final iteration of the sweep above): one list of values per target column.
print(test_output_dict)