In [1]:
import wandb
api = wandb.Api()
import pandas as pd
import matplotlib.patheffects as PathEffects
import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter
import matplotlib.ticker as mticker
import seaborn as sns
from tqdm import tqdm
plt.rcParams["text.usetex"] = False

import matplotlib
plt.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['lines.linewidth'] = 2
matplotlib.rcParams['font.family'] = "sans-serif"
matplotlib.rcParams['legend.fontsize'] = 12
matplotlib.rcParams['axes.labelsize'] = 12
matplotlib.rcParams['xtick.labelsize'] = 12
matplotlib.rcParams['ytick.labelsize'] = 12
matplotlib.rcParams['legend.title_fontsize'] = 12
matplotlib.rcParams['axes.spines.top'] = False
matplotlib.rcParams['axes.spines.right'] = False
matplotlib.rcParams['figure.dpi'] = 250
matplotlib.rcParams['figure.figsize'] = (5,5)
cmap = plt.get_cmap("viridis")
import seaborn as sns
colors_b = sns.color_palette("colorblind")
colors = sns.color_palette("Set2")
import matplotlib.cm as cm
viridis = cm.get_cmap('flare', 8)  # sample 8 colors
colors = [viridis(i) for i in range(8)]  # RGBA values
plt.rcParams["axes.prop_cycle"] = plt.cycler(color=colors)

import numpy as np
%config InlineBackend.figure_formats = ['svg']

  viridis = cm.get_cmap('flare', 8)  # sample 8 colors


In [2]:
# Hyper-parameter grid of interest: only runs whose config matches one value from EACH list
# are collected. Matching is exact float equality against the logged config values, so these
# literals must be identical to the ones used when launching the sweep.
lrs = [0.0005, 0.001, 0.002, 0.004, 0.008, 0.016, 0.032]
wds = [0.1]
beta1s = [0.9875, 0.975, 0.95, 0.9]
beta2s = [0.99375, 0.9875, 0.975, 0.95, 0.9, 0.8]


def _run_matches_grid(config):
    """Return True when lr, weight_decay, beta1 and beta2 of `config` all lie on the grid.

    Missing keys count as "not on the grid" instead of raising KeyError.
    """
    return (
        config.get('lr') in lrs
        and config.get('weight_decay') in wds
        and config.get('beta1') in beta1s
        and config.get('beta2') in beta2s
    )


def _describe(config):
    """Short text tag with the swept hyper-parameters of a run, used for progress logging."""
    return ('wd ' + str(config['weight_decay']) + ', lr ' + str(config['lr'])
            + ', b1 ' + str(config['beta1']) + ', b2 ' + str(config['beta2']))


def _history_column(run, key):
    """Fetch one logged metric of `run` as a 1-D float array (empty if nothing was logged).

    The metric is selected by column name rather than by position, so the result does not
    depend on the column order of the returned history frame.
    """
    hist = run.history(keys=[key])
    if hist is None or len(hist) == 0 or key not in hist:
        return np.array([], dtype=float)
    return np.asarray(hist[key], dtype=float)


def get_data(project, control_length=None, api_client=None):
    """Collect final perplexities for every run of a W&B project that lies on the grid.

    Args:
        project: W&B project path passed to ``runs(...)`` of the API client.
        control_length: unused; kept so existing calls such as
            ``get_data(project, control_length=31)`` keep working.
        api_client: object exposing ``runs(project)``; defaults to the module-level ``api``.
            Mainly useful for injecting a stub.

    Returns:
        list of dicts, one per matching run that logged both ``valid/ppl`` and ``train/ppl``,
        with keys optim, num_steps, bs (grad_accumulation_steps * micro_batch_size), lr, wd,
        beta1, beta2, ppl_train (mean of the last 10 train/ppl values), ppl_all (the whole
        train/ppl trace) and ppl_test (the last valid/ppl value).
        Matching runs lacking either metric are reported with an 'ERROR!!!' line and skipped.
    """
    client = api if api_client is None else api_client
    res = []
    for run in client.runs(project):
        cfg = run.config
        if not _run_matches_grid(cfg):
            continue
        print(_describe(cfg))
        # Fetch each metric exactly once: every run.history call is a separate API request.
        ppl_valid = _history_column(run, 'valid/ppl')
        ppl_all = _history_column(run, 'train/ppl')
        # Check for missing metrics BEFORE indexing; indexing an empty history used to crash.
        if ppl_valid.size == 0 or ppl_all.size == 0:
            print('ERROR!!! ---- ' + _describe(cfg))
            continue
        print(ppl_valid[-1])
        print(ppl_all[-1])
        res.append({
            'optim': cfg['optim'],
            'num_steps': cfg['steps_budget'],
            'bs': cfg['grad_accumulation_steps'] * cfg['micro_batch_size'],
            'lr': cfg['lr'],
            'wd': cfg['weight_decay'],
            'beta1': cfg['beta1'],
            'beta2': cfg['beta2'],
            'ppl_train': np.mean(ppl_all[-10:]),  # smooth the noisy final training ppl
            'ppl_all': ppl_all,
            'ppl_test': ppl_valid[-1],
        })
    return res

# Pull every run of the sweep from W&B and keep only the grid points of interest.
# This cell is network-bound and may take a while on large projects.
print('running..')
res = get_data(
    "orvi-things/scaling_betas_adam_double_chinchilla",
    control_length=31,
)
print('done..')

running..
wd 0.1, lr 0.032, b1 0.9875, b2 0.99375
1964.4521323408753
1992.3216819264856
wd 0.1, lr 0.032, b1 0.9875, b2 0.975
2128.520332177053
2107.57168234044
wd 0.1, lr 0.032, b1 0.9875, b2 0.95
57971.9708233657
1.1216771099066047e+290
wd 0.1, lr 0.032, b1 0.9875, b2 0.9875
2213.041606521971
2177.545528450089
wd 0.1, lr 0.016, b1 0.9875, b2 0.99375
19.873338292226503
20.34268321629274
wd 0.1, lr 0.032, b1 0.9875, b2 0.9
57971.9708233657
1.0478508741246578e+262
wd 0.1, lr 0.032, b1 0.9875, b2 0.8
57971.9708233657
3.511893743678942e+132
wd 0.1, lr 0.016, b1 0.9875, b2 0.975
19.877175979325504
20.35040114713966
wd 0.1, lr 0.016, b1 0.9875, b2 0.9875
19.813289286369862
20.218494550613173
wd 0.1, lr 0.008, b1 0.9875, b2 0.975
19.957945335942313
20.409169737617265
wd 0.1, lr 0.016, b1 0.9875, b2 0.9
57971.9708233657
7.874616342148216e+300
wd 0.1, lr 0.016, b1 0.9875, b2 0.8
57971.9708233657
2.4696556030798696e+186
wd 0.1, lr 0.016, b1 0.9875, b2 0.95
21.345648995267133
21.80709934314181
w

20.289939259532225
wd 0.1, lr 0.008, b1 0.9, b2 0.9
20.00246714087959
20.45588734877871
wd 0.1, lr 0.004, b1 0.95, b2 0.9
20.15085554389549
20.699625418865942
wd 0.1, lr 0.002, b1 0.975, b2 0.95
20.23406247502833
20.698593992930284
wd 0.1, lr 0.001, b1 0.975, b2 0.95
20.878595117873857
21.399184701845332
wd 0.1, lr 0.004, b1 0.95, b2 0.975
20.181179544314226
20.729959470107413
wd 0.1, lr 0.004, b1 0.9, b2 0.8
20.659741781551567
21.27250694846127
wd 0.1, lr 0.016, b1 0.9, b2 0.99375
2108.446000600806
2098.371236990458
wd 0.1, lr 0.016, b1 0.95, b2 0.9875
19.879986517591455
20.352298333834465
wd 0.1, lr 0.002, b1 0.95, b2 0.8
20.476813192225844
21.09893297566941
wd 0.1, lr 0.001, b1 0.95, b2 0.99375
21.23347235084892
21.846920069523968
wd 0.1, lr 0.004, b1 0.975, b2 0.99375
19.968917063192265
20.437075415717082
wd 0.1, lr 0.016, b1 0.95, b2 0.95
19.87327041876756
20.33126934621812
wd 0.1, lr 0.004, b1 0.9, b2 0.95
20.412065912448647
20.90646360184479
wd 0.1, lr 0.002, b1 0.975, b2 0.975


In [3]:
# Tabulate the collected runs (one row per run), order them by learning rate and save to disk.
# Fail loudly on an empty result: otherwise sort_values raises an opaque KeyError('lr'), and a
# fix that skipped the sort would silently overwrite a good pickle with an empty table.
if not res:
    raise ValueError("get_data returned no matching runs; not writing the results pickle")
res1 = pd.DataFrame(res).sort_values(by="lr")
res1.to_pickle("Adam_160M_SP_SL2048_6Btok_BS256.pkl")
# Optional narrower selection (disabled): keep only batch size 256, 6200 steps and wd 0.1.
#res1=res1[(res1['bs']==256) & (res1['num_steps']==6200)& (res1['wd']==0.1)]