In [1]:
!pip install wandb

Collecting wandb
  Downloading wandb-0.17.1-py3-none-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (6.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.8/6.8 MB[0m [31m29.8 MB/s[0m eta [36m0:00:00[0m
Collecting docker-pycreds>=0.4.0 (from wandb)
  Downloading docker_pycreds-0.4.0-py2.py3-none-any.whl (9.0 kB)
Collecting gitpython!=3.1.29,>=1.0.0 (from wandb)
  Downloading GitPython-3.1.43-py3-none-any.whl (207 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m207.3/207.3 kB[0m [31m15.0 MB/s[0m eta [36m0:00:00[0m
Collecting sentry-sdk>=1.0.0 (from wandb)
  Downloading sentry_sdk-2.5.0-py2.py3-none-any.whl (289 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m289.5/289.5 kB[0m [31m10.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting setproctitle (from wandb)
  Downloading setproctitle-1.3.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86

In [2]:
import pandas as pd
import wandb
from tqdm import tqdm
import requests
import os
import numpy as np
import matplotlib.pyplot as plt
import sklearn
from sklearn import metrics
import matplotlib.pyplot as plt
import scipy.stats as stats
import matplotlib.colors as mcolors

In [3]:
# Fetch the helper module providing `wandb_results` (imported in the next cell) into the working directory.
# NOTE(review): this pulls from the `master` branch, so the code that runs here can change between
# executions; consider pinning a specific commit hash in the URL for reproducibility.
!wget -q -O read_wandb.py https://raw.githubusercontent.com/eilamshapira/HumanChoicePrediction/master/RunningScripts/read_wandb.py


In [4]:
from read_wandb import wandb_results

# Arguments for the read_wandb helper (the first is presumably the W&B project name).
WANDB_PROJECT = "Strategy_Transfer_TACL"
WANDB_USERNAME = "noor25"
api = wandb_results(WANDB_PROJECT, wandb_username=WANDB_USERNAME)

# Substring identifying the metric columns; default `metric` for result_metric().
BASE_METRIC = "accuracy_per_mean_user_and_bot"

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [5]:
def result_metric(sweeps, group_name, drop_list=(0,), drop_HPT=False, metric=BASE_METRIC, epoch="best"):
    """Summarise a W&B metric per hyper-parameter configuration across seeds.

    Runs returned by ``api.get_sweeps_results`` are grouped by every config
    column that varies across the sweeps, except ``config_seed`` (so the runs in
    a group presumably differ only by seed). For each group one metric column is
    reported: the column with the highest mean when ``epoch == "best"``,
    otherwise the ``{metric}_epoch{epoch}`` column.

    Args:
        sweeps: list of sweep ids to download and combine.
        group_name: index name applied when grouping is by a single column.
        drop_list: group labels to remove from the summary if present.
        drop_HPT: if True, drop every varying config column except
            ``config_LLM_SIM_SIZE`` and ``config_seed`` and group by those only.
        metric: substring that identifies the metric columns.
        epoch: ``"best"``, or an epoch number to report for every group.

    Returns:
        ``(summary_df, best_col)``: ``summary_df`` has columns ``mean``, ``std``,
        ``values`` (per-run values), ``CI`` (bootstrap interval of the mean, 95%
        by default) and, for ``epoch == "best"``, ``epoch``; ``best_col`` maps
        each group to the name of the reported metric column.

    Raises:
        ValueError: if no column matches ``metric`` / ``epoch``.
    """
    import re

    df = api.get_sweeps_results(sweeps, metric=metric)
    np.set_printoptions(precision=3)  # global NumPy print setting, kept for the notebook's displays
    excluded = ("config_wandb_run_id", "config_online_simulation_size")
    config_cols = [c for c in df.columns if "config_" in c and c not in excluded]
    HPT_cols = [col for col in config_cols if df[col].nunique() > 1]
    print(HPT_cols)
    if drop_HPT:
        df = df.drop([c for c in HPT_cols if c not in ["config_LLM_SIM_SIZE", "config_seed"]], axis=1)
        HPT_cols = ["config_LLM_SIM_SIZE", "config_seed"]

    group_keys = [c for c in HPT_cols if c != "config_seed"]
    # Group the full frame so non-numeric hyper-parameters can still be keys, but
    # aggregate only numeric value columns (keys stay out of mean()/apply()).
    value_cols = [c for c in df.select_dtypes(include=np.number).columns if c not in group_keys]
    grouped = df.groupby(group_keys)[value_cols]
    mean_df = grouped.mean()

    if epoch == "best":
        # Metric columns whose name before "_epoch" ends like `metric`
        # (presumably to skip derived variants of the metric).
        candidate_cols = [c for c in mean_df.columns
                          if metric in c and metric[-4:] == c.split("_epoch")[0][-4:]]
    else:
        # (?!\d) keeps epoch=1 from also matching epoch10..epoch19.
        pattern = re.compile(rf"{re.escape(metric)}_epoch{re.escape(str(epoch))}(?!\d)")
        candidate_cols = [c for c in mean_df.columns if pattern.search(c)]
    if not candidate_cols:
        raise ValueError(f"no column matches metric={metric!r} at epoch={epoch!r}")
    # Per group, the candidate column with the highest mean across seeds.
    best_col = mean_df[candidate_cols].idxmax(axis=1)

    def chosen(group):
        # This group's values in its selected metric column, one per run.
        return group[best_col.loc[group.name]]

    result = grouped.apply(lambda g: chosen(g).values)
    means = grouped.apply(lambda g: chosen(g).mean())
    stds = grouped.apply(lambda g: chosen(g).std())

    df_cols = {'mean': means, 'std': stds, 'values': result.values}
    if epoch == "best":
        df_cols['epoch'] = best_col.apply(lambda x: int(x.split("epoch")[1]) if "epoch" in x else "last")
    # bootstrap_ci resamples at random, so CI bounds may differ slightly between calls.
    df_cols['CI'] = result.apply(lambda x: bootstrap_ci(x))

    summary_df = pd.DataFrame(df_cols, index=best_col.index)
    for d in drop_list:
        if d in summary_df.index:
            summary_df = summary_df.drop(d)
    if len(summary_df.index.names) == 1:
        return summary_df.rename_axis(group_name), best_col
    return summary_df, best_col

def bootstrap_ci(data, n_bootstrap=1000, ci=0.95, seed=None):
    """Percentile-bootstrap confidence interval for the mean of ``data``.

    Resamples ``data`` with replacement ``n_bootstrap`` times, takes the mean of
    each resample, and returns the ``(1 - ci) / 2`` and ``(1 + ci) / 2``
    percentiles of those means.

    Args:
        data: sequence of numbers (flattened), e.g. one metric value per seed.
            NaN entries are ignored, matching pandas' ``mean``/``std``.
        n_bootstrap: number of bootstrap resamples.
        ci: confidence level in [0, 1].
        seed: ``None`` draws from NumPy's global legacy RNG (so ``np.random.seed``
            still controls it); an int or ``np.random.Generator`` gives an
            independent, reproducible stream.

    Returns:
        ``(lower, upper)`` as floats, or ``(nan, nan)`` if ``data`` has no
        non-NaN values.
    """
    values = np.asarray(data, dtype=float).ravel()
    values = values[~np.isnan(values)]
    if values.size == 0:
        return float("nan"), float("nan")

    draw = np.random.choice if seed is None else np.random.default_rng(seed).choice
    # One (n_bootstrap, n) draw replaces a Python loop of n_bootstrap resamples.
    resamples = draw(values, size=(n_bootstrap, values.size), replace=True)
    boot_means = resamples.mean(axis=1)
    lower_bound = np.percentile(boot_means, (1 - ci) / 2 * 100)
    upper_bound = np.percentile(boot_means, (1 + ci) / 2 * 100)
    return float(lower_bound), float(upper_bound)


# Hyperparameter tuning

For every configuration tested in the sweep, the table reports the mean and standard deviation of the metric across seeds, the individual per-seed values, and a 95% bootstrap confidence interval for the mean.

With `epoch="best"`, each configuration is reported at the epoch whose mean across seeds is highest, and the `epoch` column shows which epoch that was. With an integer such as `epoch=5`, every configuration is reported at epoch 5.

To analyse several sweeps together, pass all of their IDs in the list given as the first argument of `result_metric`.

In [6]:
# Output directory for sweep CSV exports (not written to in this cell).
directory = 'sweeps_csvs'
os.makedirs(directory, exist_ok=True)  # no-op if present; avoids the exists()/makedirs race

# Render summary tables in full: the `values` and `CI` cells hold long lists/tuples.
pd.set_option('display.max_rows', None)      # Show all rows
pd.set_option('display.max_columns', None)   # Show all columns
pd.set_option('display.width', None)         # Don't truncate output width
pd.set_option('display.max_colwidth', None)  # Don't truncate long cell contents

# Summarise sweep uv5plu7y per ENV_LEARNING_RATE, each rate at its best epoch.
sweep_results = result_metric(["uv5plu7y"], "ENV_LEARNING_RATE", drop_HPT=False, epoch="best")
display(sweep_results[0])

Total number of sweeps: 1
Download sweep_id='uv5plu7y' data...


100%|██████████| 50/50 [00:01<00:00, 42.30it/s]


['config_seed', 'config_ENV_LEARNING_RATE']


Unnamed: 0_level_0,mean,std,values,epoch,CI
ENV_LEARNING_RATE,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
1e-05,0.824694,0.002791,"[0.8187546641184331, 0.8252406663560218, 0.826214489725285, 0.8266381106783883, 0.8249175069908017, 0.8252686450684819, 0.8236738322305607, 0.8290665225454534, 0.8219331845420441, 0.8252368331360409]",24,"(0.8230917615369797, 0.8261816958270607)"
4e-05,0.83223,0.00276,"[0.8310511093495802, 0.8344106165169124, 0.834303786097392, 0.8285816764333936, 0.8351172081837431, 0.8317349874301958, 0.8291874979070114, 0.8327948121840498, 0.8361939785672542, 0.8289199705602457]",23,"(0.8305371988328074, 0.8339500260581147)"
0.0001,0.835409,0.002328,"[0.8336611740063499, 0.8376899116149543, 0.8339766343307125, 0.8328760652927926, 0.8349468564661433, 0.8329203658073057, 0.8347944294903027, 0.8360939724239859, 0.840124253547841, 0.8370061911135623]",23,"(0.8341000506377076, 0.8368240185569856)"
0.0004,0.836536,0.002431,"[0.8361362067694211, 0.8386272994163164, 0.8363856260889975, 0.8347931071906355, 0.8376913079589968, 0.8348269974584533, 0.8405553343202583, 0.8317071035527487, 0.8378762066105551, 0.8367606553345749]",15,"(0.8350944321068903, 0.8379377549983265)"
0.001,0.83504,0.001669,"[0.8351450250340013, 0.8340406438412765, 0.8325816478645481, 0.8362437694204464, 0.8373806498395585, 0.8345594161672811, 0.8333715050324814, 0.8378170413000189, 0.8347505572346025, 0.8345050670351348]",16,"(0.8340947308870514, 0.8360002158517299)"


# Selected epoch per configuration, and results at a fixed epoch

In [7]:
# Metric column chosen for each learning rate; its `_epochN` suffix is the selected epoch.
best_metric_cols = sweep_results[1]
best_metric_cols

config_ENV_LEARNING_RATE
0.00001    ENV_Test_accuracy_per_mean_user_and_bot_epoch24
0.00004    ENV_Test_accuracy_per_mean_user_and_bot_epoch23
0.00010    ENV_Test_accuracy_per_mean_user_and_bot_epoch23
0.00040    ENV_Test_accuracy_per_mean_user_and_bot_epoch15
0.00100    ENV_Test_accuracy_per_mean_user_and_bot_epoch16
dtype: object

In [8]:
# Same sweep, but every learning rate is read at epoch 15 instead of its best epoch.
sweep_results = result_metric(
    ["uv5plu7y"],
    "ENV_LEARNING_RATE",
    drop_HPT=False,
    epoch=15,
)
sweep_results[0]

Total number of sweeps: 1
Download sweep_id='uv5plu7y' data...
['config_seed', 'config_ENV_LEARNING_RATE']


Unnamed: 0_level_0,mean,std,values,CI
ENV_LEARNING_RATE,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
1e-05,0.821183,0.003821,"[0.8134131067180385, 0.8204955035855951, 0.8227819819577116, 0.8201752804536674, 0.8237023259050674, 0.826374419424073, 0.8187055093672332, 0.826057600321759, 0.8208589153628287, 0.8192702692641709]","(0.8190718305937473, 0.8232402305059686)"
4e-05,0.829366,0.001497,"[0.8289412052110847, 0.8294317922930174, 0.831165813522334, 0.8286308454461573, 0.8302466925076388, 0.8264862888859114, 0.8294468256107264, 0.8299786846463071, 0.8314786700410169, 0.8278567758200123]","(0.8284313045414282, 0.8302145734686696)"
0.0001,0.834219,0.002533,"[0.8294652922801309, 0.8363183009629958, 0.8314106343669891, 0.8364977922260179, 0.8359671987589636, 0.8329739010924435, 0.8343533523497321, 0.8370477063930639, 0.8356633044699742, 0.8324962159876711]","(0.8327602469462583, 0.8355839407484279)"
0.0004,0.836536,0.002431,"[0.8361362067694211, 0.8386272994163164, 0.8363856260889975, 0.8347931071906355, 0.8376913079589968, 0.8348269974584533, 0.8405553343202583, 0.8317071035527487, 0.8378762066105551, 0.8367606553345749]","(0.83520544006212, 0.8378570765596862)"
0.001,0.83327,0.002809,"[0.8351800279806842, 0.828436634097084, 0.8327247247486428, 0.835948013728704, 0.8294179182380255, 0.8355281231532306, 0.8308690253567366, 0.8334412769363276, 0.8352973399582154, 0.8358585021559682]","(0.8316864410835465, 0.8348547935389355)"
