In [2]:
import os, pickle
import equinox as eqx
import xcquinox as xce



In [62]:
def make_net(xorc, level, depth, nhidden, ninput = None, use = None, spin_scaling = None, lob = None, ueg_limit = None,
                random_seed = None, savepath = None, configfile = 'network.config'):
    defaults_dct = {'GGA': {'X': {'ninput' : 1, 'depth': 3, 'nhidden': 16, 'use': [1], 'spin_scaling': True, 'lob': 1.804, 'ueg_limit':True},
                            'C': {'ninput': 1, 'depth': 3, 'nhidden': 16, 'use': [2], 'spin_scaling': False, 'lob': 2.0, 'ueg_limit':True}
                           },
                    'MGGA': {'X': {'ninput' : 2, 'depth': 3, 'nhidden': 16, 'use': [1, 2], 'spin_scaling': True, 'lob': 1.174, 'ueg_limit':True},
                            'C': {'ninput': 2, 'depth': 3, 'nhidden': 16, 'use': [2, 3], 'spin_scaling': False, 'lob': 2.0, 'ueg_limit':True}
                           },
                    'NL': {'X': {'ninput' : 15, 'depth': 3, 'nhidden': 16, 'use': None, 'spin_scaling': True, 'lob': 1.174, 'ueg_limit':True},
                            'C': {'ninput': 16, 'depth': 3, 'nhidden': 16, 'use': None, 'spin_scaling': False, 'lob': 2.0, 'ueg_limit':True}
                           }
                   }
    assert level.upper() in ['GGA', 'MGGA', 'NONLOCAL', 'NL']

    ninput = ninput if ninput is not None else defaults_dct[level.upper()][xorc.upper()]['ninput']
    depth = depth if depth is not None else defaults_dct[level.upper()][xorc.upper()]['depth']
    nhidden = nhidden if nhidden is not None else defaults_dct[level.upper()][xorc.upper()]['nhidden']
    use = use if use is not None else defaults_dct[level.upper()][xorc.upper()]['use']
    spin_scaling = spin_scaling if spin_scaling is not None else defaults_dct[level.upper()][xorc.upper()]['spin_scaling']
    ueg_limit = ueg_limit if ueg_limit is not None else defaults_dct[level.upper()][xorc.upper()]['ueg_limit']
    lob = lob if lob is not None else defaults_dct[level.upper()][xorc.upper()]['lob']
    random_seed = random_seed if random_seed is not None else 92017
    config = {'ninput':ninput,
              'depth':depth,
              'nhidden':nhidden,
              'use':use,
              'spin_scaling':spin_scaling,
              'ueg_limit': ueg_limit,
              'lob':lob,
             'random_seed': random_seed}
    if xorc.upper() == 'X':    
        net = xce.net.eX(n_input=ninput, use=use, depth=depth, n_hidden=nhidden, spin_scaling=spin_scaling, lob=lob, seed=random_seed)
    elif xorc.upper() == 'C':
        net = xce.net.eC(n_input=ninput, use=use, depth=depth, n_hidden=nhidden, spin_scaling=spin_scaling, lob=lob, seed=random_seed)
    
    if savepath:
        try:
            os.makedirs(savepath)
        except Exception as e:
            print(e)
            print(f'Exception raised in creating {savepath}.')
        with open(os.path.join(savepath, configfile), 'w') as f:
            for k, v in config.items():
                f.write(f'{k}\t{v}\n')
        with open(os.path.join(savepath, configfile+'.pkl'), 'wb') as f:
            pickle.dump(config, f)
        eqx.tree_serialise_leaves(os.path.join(savepath, 'xc.eqx'), net)

    return net, config

def get_net(xorc, level, net_path, configfile='network.config', netfile='xc.eqx'):
    with open(os.path.join(net_path, configfile+'.pkl'), 'rb') as f:
        params = pickle.load(f)
    #network parameters
    depth = params['depth']
    nodes = params['nhidden']
    use = params['use']
    inp = params['ninput']
    ss = params['spin_scaling']
    lob = params['lob']
    ueg = params['ueg_limit']
    seed = params['random_seed']
    xcf = netfile

    net, _ = make_net(xorc=xorc, level=level, depth=depth, nhidden=nodes, ninput=inp, use=use,
                       spin_scaling = ss, lob = lob, ueg_limit = ueg, random_seed = seed)
    return net


def make_xcfunc(level, x_net_path, c_net_path, configfile = 'network.config', xdspath = None, cdspath = None,
               savepath = None):
    level_dict = {'GGA':2, 'MGGA':3, 'NONLOCAL':4, 'NL':4}
    try:
        with open(os.path.join(x_net_path, 'network.config.pkl'), 'rb') as f:
            xparams = pickle.load(f)
        with open(os.path.join(c_net_path, 'network.config.pkl'), 'rb') as f:
            cparams = pickle.load(f)
    except:
        print('BOTH exchange and correlation networks require a network.config.pkl file to generate the XC functional object.')
        raise
    #create the network to generate the descriptors for saving
    xnet, _ = get_net(xorc='X', net_path = x_net_path)
    cnet, _ = get_net(xorc='C', net_path = x_net_path)

    if xdspath:
        xnet = eqx.tree_deserialise_leaves(xdspath, xnet)
    if cdspath:
        cnet = eqx.tree_deserialise_leaves(cdspath, cnet)

    xc = xce.xc.eXC(grid_models = [xnet, cnet], heg_mult = True, level = level_dict[xlevel.upper()])

    if savepath:
        try:
            os.makedirs(savepath)
        except Exception as e:
            print(e)
            print(f'Exception raised in creating {savepath}.')
        with open(os.path.join(savepath, 'x'+configfile+'.pkl'), 'wb') as f:
            pickle.dump(xparams, f)
        with open(os.path.join(savepath, 'c'+configfile+'.pkl'), 'wb') as f:
            pickle.dump(cparams, f)
        eqx.tree_serialise_leaves(os.path.join(savepath, 'xc.eqx', xc))

    return xc

def get_xcfunc(level, xc_net_path, configfile = 'network.config', xcdsfile = 'xc.eqx'):
    level_dict = {'GGA':2, 'MGGA':3, 'NONLOCAL':4, 'NL':4}
    try:
        with open(os.path.join(xc_net_path, 'x'+configfile+'.pkl'), 'rb') as f:
            xparams = pickle.load(f)
        with open(os.path.join(xc_net_path, 'c'+configfile+'.pkl'), 'rb') as f:
            cparams = pickle.load(f)
    except:
        print('Error in opening separate exchange/correlation configuration files. Both must be present to re-create the network architecture.')
        raise

    #create the network to generate the descriptors for saving
    xnet = get_net(xorc='X', level=level, net_path = xc_net_path, configfile='x'+configfile)
    cnet = get_net(xorc='C', level=level, net_path = xc_net_path, configfile='c'+configfile)
    print(xnet, cnet)
    xc = xce.xc.eXC(grid_models = [xnet, cnet], heg_mult = True, level = level_dict[level.upper()])
    if xcdsfile:
        print('deserializing')
        print('pre', xc)
        xc = eqx.tree_deserialise_leaves(os.path.join(xc_net_path, xcdsfile), xc)
        print('post', xc)
    return xc

    

In [64]:
xc = get_xcfunc('GGA', xc_net_path='/home/awills/Documents/Research/xcquinox/xcquinox/tests/testmakenet/xc',
               )
for gm in xc.grid_models:
    

eX(
  n_input=1,
  n_hidden=16,
  ueg_limit=False,
  spin_scaling=True,
  lob=1.804,
  use=[1],
  net=MLP(
    layers=(
      Linear(
        weight=f64[16,1],
        bias=f64[16],
        in_features=1,
        out_features=16,
        use_bias=True
      ),
      Linear(
        weight=f64[16,16],
        bias=f64[16],
        in_features=16,
        out_features=16,
        use_bias=True
      ),
      Linear(
        weight=f64[16,16],
        bias=f64[16],
        in_features=16,
        out_features=16,
        use_bias=True
      ),
      Linear(
        weight=f64[1,16],
        bias=f64[1],
        in_features=16,
        out_features=1,
        use_bias=True
      )
    ),
    activation=<function gelu>,
    final_activation=<function <lambda>>,
    use_bias=True,
    use_final_bias=True,
    in_size=1,
    out_size=1,
    width_size=16,
    depth=3
  ),
  tanh=<wrapped function tanh>,
  lobf=LOB(limit=1.804, sig=<wrapped function sigmoid>),
  sig=<wrapped function sigmoid>,

In [26]:
xnet, xp = make_net('X', 'GGA', 3, 16)


The history saving thread hit an unexpected error (OperationalError('attempt to write a readonly database')).History will not be written to the database.


In [33]:
for idx, layer in enumerate(xnet.net.layers):

    print(idx, layer.weight)

0 [[-0.20489091]
 [-0.79174841]
 [ 0.28844834]
 [-0.43362515]
 [ 0.50391024]
 [-0.98559318]
 [ 0.94832324]
 [ 0.98880652]
 [ 0.60825287]
 [-0.09721325]
 [-0.42161051]
 [ 0.10627861]
 [ 0.72794369]
 [ 0.58913371]
 [-0.83273611]
 [ 0.71938769]]
1 [[-0.17911172 -0.02960953 -0.00651743  0.1887775  -0.09808548 -0.10622367
   0.17663048 -0.01051425 -0.21838071 -0.24993214 -0.19228363 -0.24216076
   0.19877497  0.01767027 -0.07503117  0.20756354]
 [-0.0215123  -0.13539059  0.05323678 -0.0922613   0.11724758  0.04538853
   0.0920973  -0.00854392 -0.11365275  0.02772144 -0.23221039 -0.03953853
  -0.12928641  0.00604485  0.06631281  0.09605728]
 [-0.24430495 -0.06414282  0.04955017 -0.02253693  0.216225   -0.01013105
   0.16478333  0.02455806  0.1479879   0.12867301 -0.12124441 -0.1118201
  -0.21104311  0.17432564 -0.12371706 -0.09656645]
 [ 0.15160833 -0.18017836 -0.17466687  0.11739767  0.02900885  0.14762458
  -0.2430875  -0.22499971 -0.05278997 -0.21727072 -0.13998593  0.04189591
  -0.061690

In [25]:
rseeds = [92017, 17920]
dirs = ['gga', 'mgga', 'nl']
constrs = ['c', 'nc']
randir = '/home/awills/Documents/Research/xcquinox/scripts/script_data/ctests/ran'
DEPTH = 3
NHIDDEN = 16
for didx, direc in enumerate(dirs):
    for sidx, seed in enumerate(rseeds):
        for cidx, con in enumerate(constrs):
            xsubdir = f'x_{DEPTH}_{NHIDDEN}_{con}{sidx}_{direc}'
            csubdir = f'c_{DEPTH}_{NHIDDEN}_{con}{sidx}_{direc}'
            print(xsubdir, csubdir)
            xfp = os.path.join(randir, xsubdir)
            cfp = os.path.join(randir, csubdir)
            if con == 'c':
                print('Constrained network generation')
                #leave inputs as none so the defaults are imposed, which use constraints   
                x, xc = make_net('x', level=direc, depth=DEPTH, nhidden=NHIDDEN, random_seed = seed, savepath = xfp)
                c, cc = make_net('c', level=direc, depth=DEPTH, nhidden=NHIDDEN, random_seed = seed, savepath = cfp)
            if con == 'nc':
                print('Non-constrained network generation')
                #manually set ueg_limit/spin_scaling/lobs/use/ninputs
                ueg = False
                lob = 0
                ss = False
                use = []
                if direc == 'gga':
                    ninput = 3
                elif direc == 'mgga':
                    ninput = 4
                elif direc == 'nl':
                    ninput = 16
                print('Making exchange network in {}'.format(xsubdir))
                x, xc = make_net('x', level=direc, depth=DEPTH, nhidden=NHIDDEN,
                                 ninput = ninput, use = use, spin_scaling = ss, lob = lob, ueg_limit = ueg,
                                 random_seed = seed, savepath = xfp)
                print('Making correlation network in {}'.format(csubdir))
                c, cc = make_net('c', level=direc, depth=DEPTH, nhidden=NHIDDEN,
                                 ninput = ninput, use = use, spin_scaling = ss, lob = lob, ueg_limit = ueg,
                                 random_seed = seed, savepath = cfp)
                
                

x_3_16_c0_gga c_3_16_c0_gga
Constrained network generation
[Errno 17] File exists: '/home/awills/Documents/Research/xcquinox/scripts/script_data/ctests/ran/x_3_16_c0_gga'
Exception raised in creating /home/awills/Documents/Research/xcquinox/scripts/script_data/ctests/ran/x_3_16_c0_gga.
[Errno 17] File exists: '/home/awills/Documents/Research/xcquinox/scripts/script_data/ctests/ran/c_3_16_c0_gga'
Exception raised in creating /home/awills/Documents/Research/xcquinox/scripts/script_data/ctests/ran/c_3_16_c0_gga.
x_3_16_nc0_gga c_3_16_nc0_gga
Non-constrained network generation
Making exchange network in x_3_16_nc0_gga
[Errno 17] File exists: '/home/awills/Documents/Research/xcquinox/scripts/script_data/ctests/ran/x_3_16_nc0_gga'
Exception raised in creating /home/awills/Documents/Research/xcquinox/scripts/script_data/ctests/ran/x_3_16_nc0_gga.
Making correlation network in c_3_16_nc0_gga
[Errno 17] File exists: '/home/awills/Documents/Research/xcquinox/scripts/script_data/ctests/ran/c_3_16