In [1]:
import wandb
import pandas as pd
import matplotlib.pyplot as plt


In [3]:
# Open a client for the wandb public API
api = wandb.Api()

# Path of the sweep whose runs are analysed
sweep_id = 'maximes_crew/S3_SHD_runs/lcz0m1fp'

# Resolve the sweep and take its runs
sweep = api.sweep(sweep_id)
runs = sweep.runs

# Config keys of interest for this sweep (all of them are collected below)
sweep_params = ['drop2', 'prenorm', 'premix', 'mix', 'residual1', 'residual2', 'nb_state', 'nb_hiddens']

# Record column -> run.config key, in the order the columns appear in each
# record. Only 'n_layers' is stored under a different config key ('nb_layers').
core_columns = {
    'n_layers': 'nb_layers',
    'nb_hiddens': 'nb_hiddens',
    's4': 's4',
    'normalization': 'normalization',
    'use_readout_layer': 'use_readout_layer',
    'nb_state': 'nb_state',
    'pure_complex': 'pure_complex',
    'activation': 'activation',
    'lr': 'lr',
    'dt_min': 'dt_min',
    'dt_max': 'dt_max',
    'pdrop': 'pdrop',
    'scheduler_patience': 'scheduler_patience',
    'scheduler_factor': 'scheduler_factor',
}

# Additional config keys, placed after 'best_valid_acc' in each record
extra_keys = ['drop2', 'prenorm', 'premix', 'mix', 'residual1', 'residual2']

# One dict per kept run; this list is turned into a DataFrame later
data = []

for run in runs:
    config = run.config
    summary = run.summary
    best_valid_acc = summary.get('best valid acc')

    # Runs with no 'best valid acc' entry in their summary are left out
    if best_valid_acc is None:
        continue

    # Keys absent from the config come back as None from .get()
    record = {column: config.get(key) for column, key in core_columns.items()}
    record['best_valid_acc'] = best_valid_acc
    record.update((key, config.get(key)) for key in extra_keys)
    data.append(record)




In [4]:
# One row per collected run
df = pd.DataFrame(data)

# Tag every run with the number of runs that share its sweep configuration
df['run_count'] = df.groupby(sweep_params)['best_valid_acc'].transform('size')

# Mean best_valid_acc of each configuration, with the sweep parameters as columns
mean_acc = df.groupby(sweep_params)['best_valid_acc'].mean().reset_index()

# One (configuration, run_count) row per distinct configuration
config_counts = df[sweep_params + ['run_count']].drop_duplicates()

# Put the run count next to the mean accuracy of each configuration
df_grouped = pd.merge(mean_acc, config_counts, how='left', on=sweep_params)

# Rank configurations from highest to lowest mean accuracy
df_grouped_sorted = (
    df_grouped
    .sort_values('best_valid_acc', ascending=False)
    .reset_index(drop=True)
)

# Show the ranked table as the cell output
df_grouped_sorted

Unnamed: 0,drop2,prenorm,premix,mix,residual1,residual2,nb_state,nb_hiddens,normalization,pure_complex,use_readout_layer,best_valid_acc,run_count
0,False,True,True,GLU,True,False,2,128,batchnorm,True,True,0.944515,5
1,True,True,True,GLU,True,False,2,128,batchnorm,True,True,0.94087,5
2,True,False,False,GLU,True,True,2,128,batchnorm,True,True,0.939931,5
3,True,True,False,GLU,False,False,2,128,batchnorm,True,True,0.939126,5
4,True,False,False,GLU,False,False,2,128,batchnorm,True,True,0.936253,5
5,False,False,False,GLU,False,False,2,128,batchnorm,True,True,0.935574,5
6,True,False,False,GLU,False,True,2,128,batchnorm,True,True,0.935283,5
7,True,True,False,GLU,True,False,2,128,batchnorm,True,True,0.934959,5
8,False,False,False,GLU,True,True,2,128,batchnorm,True,True,0.934549,5
9,True,True,True,GLU,False,False,2,128,batchnorm,True,True,0.933176,5
