In [1]:
import wandb
import pandas as pd
import matplotlib.pyplot as plt

In [23]:

# --- Sweep selection ---------------------------------------------------------
# Task tag; it is spliced into the wandb project path when sweeps are fetched.
task = "SC"

# IDs of the sweeps to analyse. Alternatives kept for reference:
#   ['7e0do624']
#   ['wauv85y4', 'gfwkoa2e', 'i81h7lqv', '7befnpfs']
sweep_ids = ["22j8955i"]

# Hyperparameter columns used as grouping keys when runs are aggregated.
sweep_params = [
    "lr",
    "pdrop",
    "scheduler_patience",
    "scheduler_factor",
    "dt_max",
]

# Client for the wandb public API (used to look up sweeps and their runs).
api = wandb.Api()

In [24]:
# Collect one record per run: its hyperparameters plus the test accuracy
# associated with the best validation accuracy. `data` is a plain list of
# dicts; it is turned into a DataFrame in a later cell.

# Record field name -> key in the wandb run config. The insertion order here
# is the column order of the resulting records.
_CONFIG_KEYS = {
    'n_layers': 'nb_layers',
    'n_hiddens': 'nb_hiddens',
    'lr': 'lr',
    'dt_min': 'dt_min',
    'dt_max': 'dt_max',
    'pdrop': 'pdrop',
    'scheduler_patience': 'scheduler_patience',
    'scheduler_factor': 'scheduler_factor',
}


def _test_acc_at_best_valid(history):
    """Return the 'test acc' value from the row right after the best 'valid acc'.

    Parameters
    ----------
    history : pandas.DataFrame
        Per-run history table as returned by ``run.history()``.

    Returns
    -------
    The 'test acc' entry at (row label of the maximum 'valid acc') + 1, or
    ``None`` when either column is missing, 'valid acc' holds no values,
    the following row does not exist, or the selected 'test acc' is NaN.
    """
    if 'valid acc' not in history.columns or 'test acc' not in history.columns:
        return None

    # An all-NaN column has no meaningful idxmax (depending on the pandas
    # version it returns NaN, warns, or raises), so drop NaNs first. For a
    # column with values this selects the same row as a plain idxmax().
    valid_acc = history['valid acc'].dropna()
    if valid_acc.empty:
        return None

    # NOTE(review): the +1 presumably reflects test accuracy being logged in
    # the row following the validation accuracy -- confirm against the
    # training script's logging order.
    test_row = valid_acc.idxmax() + 1
    if test_row not in history.index:
        return None

    test_acc = history['test acc'].loc[test_row]
    # A NaN here means no test accuracy was logged at that row; treat it as
    # "no result" rather than recording a NaN score.
    return None if pd.isna(test_acc) else test_acc


data = []

for sweep_id in sweep_ids:
    # Build the full "<entity>/<project>/<sweep>" path in its own variable so
    # the loop variable keeps holding the bare sweep ID.
    sweep_path = 'maximes_crew/S3_' + task + '_runs/' + sweep_id
    sweep = api.sweep(sweep_path)

    for run in sweep.runs:
        # NOTE(review): run.history() returns a sampled history by default;
        # confirm the runs log few enough rows that nothing is skipped.
        test_acc = _test_acc_at_best_valid(run.history())
        if test_acc is None:
            # Runs without a usable test accuracy are left out.
            continue

        record = {name: run.config.get(key) for name, key in _CONFIG_KEYS.items()}
        record['test_acc'] = test_acc
        data.append(record)



In [25]:
# Preview only the first few collected records instead of dumping the whole list.
data[:5]

[{'n_layers': 3,
  'n_hiddens': 512,
  'lr': 0.00024976901806777346,
  'dt_min': 0.09943137406416136,
  'dt_max': 20.531631352743,
  'pdrop': 0.4686489890974874,
  'scheduler_patience': 5,
  'scheduler_factor': 0.895254422210165,
  'test_acc': 0.9193032026864474},
 {'n_layers': 3,
  'n_hiddens': 512,
  'lr': 0.0018103546457709568,
  'dt_min': 0.02102401111397289,
  'dt_max': 29.804501422887657,
  'pdrop': 0.2779253695475523,
  'scheduler_patience': 5,
  'scheduler_factor': 0.6355862428603751,
  'test_acc': 0.9358460304731356},
 {'n_layers': 3,
  'n_hiddens': 512,
  'lr': 0.005395520158892739,
  'dt_min': 0.05125576159312153,
  'dt_max': 20.78084018227121,
  'pdrop': 0.14288388204676125,
  'scheduler_patience': 5,
  'scheduler_factor': 0.801976997017799,
  'test_acc': 0.9224545158380112},
 {'n_layers': 3,
  'n_hiddens': 512,
  'lr': 0.0009419479129965356,
  'dt_min': 0.056846789617321994,
  'dt_max': 3.867904202864722,
  'pdrop': 0.41331972796069794,
  'scheduler_patience': 5,
  'schedu

In [26]:
# Build a table with one row per collected run.
df = pd.DataFrame(data)

# Annotate every run with the number of runs sharing its exact values
# for the sweep parameters.
runs_per_config = df.groupby(sweep_params).transform('size')
df['run_count'] = runs_per_config

# Mean and standard deviation of test accuracy per sweep-parameter setting.
# run_count is included in the keys so it is carried through as a column.
group_keys = [*sweep_params, 'run_count']
test_acc_stats = df.groupby(group_keys).agg({'test_acc': ['mean', 'std']})
df_grouped = test_acc_stats.reset_index()

# Rank the settings from highest to lowest mean test accuracy.
df_grouped_sorted = (
    df_grouped
    .sort_values(by=('test_acc', 'mean'), ascending=False)
    .reset_index(drop=True)
)

In [27]:
# Show the ten settings with the highest mean test accuracy.
df_grouped_sorted.head(10)

Unnamed: 0_level_0,lr,pdrop,scheduler_patience,scheduler_factor,dt_max,run_count,test_acc,test_acc
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,mean,std
0,0.002208,0.488026,5,0.629532,23.405067,1,0.944113,
1,0.001743,0.351245,5,0.593245,1.637379,1,0.943668,
2,0.00219,0.473331,5,0.640709,7.653211,1,0.943276,
3,0.001608,0.37128,5,0.563697,5.072047,1,0.943214,
4,0.003329,0.489024,5,0.534703,1.158927,1,0.943013,
5,0.002968,0.251589,10,0.517185,10.86349,1,0.943004,
6,0.002459,0.410064,5,0.528143,2.562579,1,0.942941,
7,0.002264,0.350946,5,0.674465,3.940317,1,0.942832,
8,0.001376,0.360981,5,0.538128,25.650849,1,0.942659,
9,0.001544,0.377533,5,0.595938,17.821705,1,0.942487,
