In [2]:
import os
import equinox as eqx
import xcquinox as xce

In [4]:
def make_net(xorc, level, depth, nhidden, ninput = None, use = None, spin_scaling = None, lob = None, ueg_limit = None,
                random_seed = None, savepath = None, configfile = 'network.config'):
    """Build an exchange ('X') or correlation ('C') network and optionally save it.

    Any of ``ninput``, ``depth``, ``nhidden``, ``use``, ``spin_scaling``, ``lob``
    and ``ueg_limit`` that is passed as ``None`` is filled in from a table of
    defaults keyed by functional level and by X/C.

    Parameters
    ----------
    xorc : str
        'X' or 'C' (case-insensitive); selects ``xce.net.eX`` or ``xce.net.eC``.
    level : str
        One of 'GGA', 'MGGA', 'NL' or 'NONLOCAL' (case-insensitive).
        'NONLOCAL' is treated as an alias for 'NL'.
    depth, nhidden : int or None
        Forwarded to the network constructor as ``depth`` and ``n_hidden``.
    ninput, use, spin_scaling, lob : optional
        Forwarded to the network constructor (``ninput`` as ``n_input``).
    ueg_limit : optional
        Written to the returned/saved config.
    random_seed : int, optional
        Forwarded as ``seed``; defaults to 92017.
    savepath : str, optional
        If truthy, the directory is created if needed, the config is written to
        ``savepath/configfile`` as tab-separated ``key\\tvalue`` lines and the
        network is serialised to ``savepath/xc.eqx``.
    configfile : str
        File name of the config file inside ``savepath``.

    Returns
    -------
    tuple
        ``(net, config)``: the constructed network and the dict of resolved settings.

    Raises
    ------
    AssertionError
        If ``level`` is not one of the accepted level names.
    ValueError
        If ``xorc`` is neither 'X' nor 'C'.
    """
    defaults_dct = {'GGA': {'X': {'ninput' : 1, 'depth': 3, 'nhidden': 16, 'use': [1], 'spin_scaling': True, 'lob': 1.804, 'ueg_limit':True},
                            'C': {'ninput': 3, 'depth': 3, 'nhidden': 16, 'use': [2], 'spin_scaling': False, 'lob': 2.0, 'ueg_limit':True}
                           },
                    'MGGA': {'X': {'ninput' : 2, 'depth': 3, 'nhidden': 16, 'use': [1, 2], 'spin_scaling': True, 'lob': 1.174, 'ueg_limit':True},
                            'C': {'ninput': 4, 'depth': 3, 'nhidden': 16, 'use': [2, 3], 'spin_scaling': False, 'lob': 2.0, 'ueg_limit':True}
                           },
                    'NL': {'X': {'ninput' : 18, 'depth': 3, 'nhidden': 16, 'use': None, 'spin_scaling': True, 'lob': 0, 'ueg_limit':True},
                            'C': {'ninput': 16, 'depth': 3, 'nhidden': 16, 'use': None, 'spin_scaling': False, 'lob': 0, 'ueg_limit':True}
                           }
                   }
    level_key = level.upper()
    assert level_key in ['GGA', 'MGGA', 'NONLOCAL', 'NL'], f'Unsupported level: {level!r}'
    # 'NONLOCAL' is accepted above, but the defaults table only has an 'NL' entry.
    if level_key == 'NONLOCAL':
        level_key = 'NL'

    xorc_key = xorc.upper()
    # Fail early with a clear message instead of an accidental KeyError/UnboundLocalError later.
    if xorc_key not in ('X', 'C'):
        raise ValueError(f"xorc must be 'X' or 'C', got {xorc!r}")

    defaults = defaults_dct[level_key][xorc_key]
    # Explicit `is not None` checks so that falsy overrides (0, False, []) are honoured.
    ninput = ninput if ninput is not None else defaults['ninput']
    depth = depth if depth is not None else defaults['depth']
    nhidden = nhidden if nhidden is not None else defaults['nhidden']
    use = use if use is not None else defaults['use']
    spin_scaling = spin_scaling if spin_scaling is not None else defaults['spin_scaling']
    ueg_limit = ueg_limit if ueg_limit is not None else defaults['ueg_limit']
    lob = lob if lob is not None else defaults['lob']
    random_seed = random_seed if random_seed is not None else 92017
    config = {'ninput':ninput,
              'depth':depth,
              'nhidden':nhidden,
              'use':use,
              'spin_scaling':spin_scaling,
              'ueg_limit': ueg_limit,
              'lob':lob,
             'random_seed': random_seed}

    # NOTE(review): ueg_limit is recorded in the config but is not forwarded to the
    # network constructor here -- confirm whether eX/eC accept it and should receive it.
    net_cls = xce.net.eX if xorc_key == 'X' else xce.net.eC
    net = net_cls(n_input=ninput, use=use, depth=depth, n_hidden=nhidden, spin_scaling=spin_scaling, lob=lob, seed=random_seed)

    if savepath:
        # An existing directory is fine; any other failure should surface immediately.
        os.makedirs(savepath, exist_ok=True)
        with open(os.path.join(savepath, configfile), 'w') as f:
            for k, v in config.items():
                f.write(f'{k}\t{v}\n')
        # The network is the second argument of tree_serialise_leaves, not a path component.
        eqx.tree_serialise_leaves(os.path.join(savepath, 'xc.eqx'), net)

    return net, config

def make_xcfunc(x_net_path, c_net_path):
    """Placeholder that is not implemented yet.

    Accepts two path arguments, performs no work and returns ``None``.
    """
    return None

In [None]:
# Build and save seeded X and C networks for every combination of
# functional level, random seed and constraint setting.
rseeds = [92017, 17920]
dirs = ['gga', 'mgga', 'nl']
constrs = ['c', 'nc']
randir = '/home/awills/Documents/Research/xcquinox/scripts/script_data/ctests/ran'
DEPTH = 3
NHIDDEN = 16
for direc in dirs:
    for sidx, seed in enumerate(rseeds):
        for con in constrs:
            # One directory per network. X and C need distinct prefixes, because
            # make_net writes the same file names ('network.config', 'xc.eqx')
            # into whatever savepath it is given.
            xsubdir = f'x_{DEPTH}_{NHIDDEN}_{con}{sidx}_{direc}'
            csubdir = f'c_{DEPTH}_{NHIDDEN}_{con}{sidx}_{direc}'
            xfp = os.path.join(randir, xsubdir)
            cfp = os.path.join(randir, csubdir)
            if con == 'c':
                #leave inputs as none so the defaults are imposed, which use constraints
                x, xc = make_net('x', level=direc, depth=DEPTH, nhidden=NHIDDEN, random_seed = seed, savepath = xfp)
                c, cc = make_net('c', level=direc, depth=DEPTH, nhidden=NHIDDEN, random_seed = seed, savepath = cfp)
            if con == 'nc':
                #manually set ueg_limit/spin_scaling/lobs/use/ninputs
                # NOTE(review): in the visible part of this cell these overrides are only
                # assigned -- confirm they are forwarded to make_net.
                ueg = False
                lob = 0
                ss = False
                use = []