In [None]:
import sys
import pickle

sys.path.append("../..")
from helpers import get_model_ds_loss, init_dir, \
                    run_analysis, unpack_stats

## Parameters

In [None]:
# Model / dataset configuration. These values are forwarded to
# get_model_ds_loss in the analysis cell below.
inp_dim = 3      # presumably input feature dimensionality — confirm in helpers
hid_dim = 64     # presumably hidden-layer width; also encoded in the output file names
out_dim = 10     # presumably number of classes (used with loss_mode='CrossEntropy')
nlayer = 5       # also encoded in the output file names
bias = False
# One batch-norm flag per slot; the flags are encoded as a bit string
# (e.g. (False, False, False, False) -> "0000") that names the data folder.
use_bn = (False, False, False, False)
# Join over every flag instead of hard-coding four indices, so changing the
# number of slots in `use_bn` cannot silently drop or mis-index flags.
bn_code = "".join(f"{flag:d}" for flag in use_bn)
mode = "relu"
loss_mode = 'CrossEntropy'
device = 0       # NOTE(review): looks like a CUDA device index — confirm in helpers

nsamp = 500      # passed to get_model_ds_loss; presumably the dataset size

save_model_dir = "models"
# Stats for this batch-norm configuration are written under ./data/<bn_code>.
datafolder = f"./data/{bn_code}"

init_dir(datafolder)

### Analysis: learning-rate sweep

In [None]:
# Learning-rate sweep: for every lr a new model/dataset/loss triple is built
# with get_model_ds_loss, run_analysis is called with that lr, and the
# unpacked validation/training statistics are pickled into `datafolder`.
epochs = 2000
valfreq = 1
lrs = [0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 1.]
# Set to True to also pass `save_model_dir` to run_analysis as the extra
# trailing argument (this replaces a previously commented-out call variant).
save_models = False

# File names encode lr with a single significant digit (`.0e`), so two
# learning rates such as 0.01 and 0.012 would share a file name and the later
# run would silently overwrite the earlier one. Fail fast instead.
lr_tags = [f"{rate:.0e}" for rate in lrs]
assert len(set(lr_tags)) == len(lr_tags), \
    f"learning rates collide in output file names: {lr_tags}"

# NOTE(review): no random seed is set here, so runs for different lrs may start
# from different random initialisations — consider seeding for comparability.
for lr, lr_tag in zip(lrs, lr_tags):
    model, ds, loss_fn = get_model_ds_loss(inp_dim, hid_dim, out_dim,
                                           nlayer, bias, use_bn, mode,
                                           nsamp, device, loss_mode)

    extra_args = (save_model_dir,) if save_models else ()
    val_stats, tr_stats = run_analysis(model, ds, loss_fn, lr, epochs,
                                       valfreq, *extra_args)

    H, delta, fo, ho, error, fostat = unpack_stats(val_stats)
    loss, acc = unpack_stats(tr_stats)

    filename = f"stat_lr{lr_tag}_nl{nlayer}_hid_{hid_dim}_bn{bn_code}"

    stat = {"H": H,
            "delta": delta,
            "fo": fo,
            "ho": ho,
            "error": error,
            "fostat": fostat,
            "loss": loss,
            "acc": acc,
            "lr": lr}

    with open(f'{datafolder}/{filename}.pkl', 'wb') as f:
        pickle.dump(stat, f)


In [None]:
import os
import glob

In [None]:
# Build a lighter copy of every stats pickle that keeps only the
# loss / accuracy / lr entries, mirroring the folder layout:
#   data/<subdir>/<run>.pkl  ->  loss_acc_lr/<subdir>/<run>.pkl
in_data = "data/*"
out_dir = "loss_acc_lr"
kept_keys = ("loss", "acc", "lr")
init_dir(out_dir)

for dirname in sorted(glob.glob(in_data)):
    # Only sub-folders hold per-run stats; skip stray files directly in data/.
    if not os.path.isdir(dirname):
        continue
    child_dir = os.path.join(out_dir, os.path.basename(dirname))
    init_dir(child_dir)
    # The sweep writes its stats as *.pkl; ignore anything else in the folder
    # (e.g. notes or hidden files) rather than crashing in pickle.load.
    for fname in sorted(glob.glob(os.path.join(dirname, "*.pkl"))):
        # pickle.load can execute arbitrary code: only run this on files
        # produced by this notebook.
        with open(fname, 'rb') as f:
            full_stat = pickle.load(f)
        stat = {key: full_stat[key] for key in kept_keys}
        with open(os.path.join(child_dir, os.path.basename(fname)), 'wb') as f:
            pickle.dump(stat, f)