# Visible Neural Network - Hyperparameter tuning applied to many y variables individually

In [None]:
# Data ----
from dataGMX.core import get_data # <- Soybean Data
from dataG2F.qol  import ensure_dir_path_exists

# Data Utilities ----
import numpy  as np
import pandas as pd

from EnvDL.dlfn import BigDataset, plDNN_general
from EnvDL.sets import mask_columns

# Model Building  ----
## General ====
import torch
from   torch import nn
import torch.nn.functional as F
from   torch.utils.data import Dataset
from   torch.utils.data import DataLoader

## VNN ====
import sparsevnn
from   sparsevnn.core import\
    VNNHelper, \
    structured_layer_info, \
    SparseLinearCustom
from   sparsevnn.kegg import \
    kegg_connections_build, \
    kegg_connections_clean, \
    kegg_connections_append_y_hat, \
    kegg_connections_sanitize_names

# Hyperparameter Tuning ----
import os # needed for checking history (saved by lightning) 

## Logging with Pytorch Lightning ====
import lightning.pytorch as pl
from   lightning.pytorch.loggers import CSVLogger # used to save the history of each trial (used by ax)

## Adaptive Experimentation Platform ====
from ax.service.ax_client import AxClient, ObjectiveProperties
from ax.utils.notebook.plotting import init_notebook_plotting, render

# For logging experiment results in sql database
from ax.storage.sqa_store.db import init_engine_and_session_factory
from ax.storage.sqa_store.db import get_engine, create_all_tables
from ax.storage.sqa_store.save import save_experiment # saving
from ax.storage.sqa_store.structs import DBSettings # loading
# from ax.storage.sqa_store.load import load_experiment # loading alternate

In [None]:
# Trade some float32 matmul precision for speed: 'medium' lets PyTorch use
# lower-precision internal math for float32 matmuls on hardware that supports it.
torch.set_float32_matmul_precision(precision='medium')

In [None]:
# Initialize Ax's notebook plotting support so Ax plots can be displayed inline.
init_notebook_plotting()

## Setup

In [None]:
# Root directory for all artifacts (Lightning logs, saved outputs) of this notebook.
cache_path = "../nbs_artifacts/gmx_yhat_02/"

In [None]:
# Run settings:
params_run = dict(
    batch_size=256,
    max_epoch=256,  # alternatives tried: 512
)

# Data settings:
params_data = dict(
    # target (phenotype) columns available for modeling
    y_var=[
        'sdwt100', 'ProteinDry', 'OilDry', 'AshDry', 'FiberDry', 'LysineDry',
        'CysteineDry', 'MethionineDry', 'ThreonineDry', 'TryptophanDry', 'IsoleucineDry',
        'LeucineDry', 'HistidineDry', 'PhenylalanineDry', 'ValineDry', 'AlanineDry',
        'ArginineDry', 'AsparticacidDry', 'GlutamicacidDry', 'GlycineDry', 'ProlineDry',
        'SerineDry', 'TyrosineDry', 'SucroseDry', 'Linolenic', 'Linoleic', 'Oleic',
        'Palmitic', 'Moisture', 'RaffinoseDry', 'StachyoseDry', 'Stearic', 'SeedAS',
        'SeedPL', 'SeedW', 'SeedLWR', 'SeedCS', 'SeedL', 'SampleWeight', 'B11', 'Na23',
        'Mg26', 'Al27', 'P31', 'S34', 'K39', 'Ca44', 'Mn55', 'Ni60', 'Cu63', 'Zn66',
        'Fe54', 'Co59', 'Se78', 'Rb85', 'Sr88', 'Mo98', 'Cd111', 'As75'],
    y_resid='None',        # the *string* 'None'; other options: Env, Geno
    y_resid_strat='None',  # the *string* 'None'; other options: naive_mean, filter_mean, ...
    # A random fraction `pr` of genotypes is held out; `rng_seed` makes the draw
    # reproducible. (May switch to a similarity-based approach.)
    holdout_parents=dict(rng_seed=9874325, pr=0.2),
)

In [None]:
## Settings ====
run_hyps = 2  # 5
run_hyps_force = False  # should we run more trials even if the target number has been reached?
max_hyps = 50

# Ax search space. Only options that apply to each parameter type are given:
# `log_scale` belongs to 'range' parameters and `is_ordered`/`sort_values` to
# 'choice' parameters. (Previously `sort_values` was set on a range parameter
# and `log_scale` on a fixed one; Ax has no use for them there.)
params_list = [
    ## Output Size ====
    {
    'name': 'default_out_nodes_inp',
    'type': 'range',
    'bounds': [1, 8],
    'value_type': 'int',
    'log_scale': False
    },
    {
    'name': 'default_out_nodes_edge',
    'type': 'range',
    'bounds': [1, 32],
    'value_type': 'int',
    'log_scale': False
    },
    {
    'name': 'default_out_nodes_out',
    'type': 'fixed',
    'value': 1,  # one output node: each y_var is tuned in its own experiment
    'value_type': 'int',
    },
    ## Dropout ====
    {
    'name': 'default_drop_nodes_inp',
    'type': 'range',
    'bounds': [0.01, 0.99],
    'value_type': 'float',
    'log_scale': False
    },
    {
    'name': 'default_drop_nodes_edge',
    'type': 'range',
    'bounds': [0.01, 0.99],
    'value_type': 'float',
    'log_scale': False
    },
    {
    'name': 'default_drop_nodes_out',
    'type': 'range',
    'bounds': [0.01, 0.99],
    'value_type': 'float',
    'log_scale': False,
    },
    ## Node Repeats ====
    {
    'name': 'default_reps_nodes_inp',
    'type': 'choice',
    'values': [1, 2, 3],
    'value_type': 'int',
    'is_ordered': True,
    'sort_values': True
    },
    {
    'name': 'default_reps_nodes_edge',
    'type': 'choice',
    'values': [1, 2, 3],
    'value_type': 'int',
    'is_ordered': True,
    'sort_values': True
    },
    {
    'name': 'default_reps_nodes_out',
    'type': 'choice',
    'values': [1, 2, 3],
    'value_type': 'int',
    'is_ordered': True,
    'sort_values': True
    },
    ## Node Output Size Scaling ====
    {
    'name': 'default_decay_rate',
    'type': 'choice',
    # 0.0, 0.1, ..., 0.9 followed by 1.0, 2.0, ..., 11.0
    # NOTE(review): 0.1*i yields float artifacts such as 0.30000000000000004.
    # Kept unchanged so the choice values match any previously stored trials.
    'values': [0+(0.1*i) for i in range(10)]+[1.+(1*i) for i in range(11)],
    'value_type': 'float',
    'is_ordered': True,
    'sort_values': True
    }
    ]

In [None]:
# Lightning's CSV logs live under the cache directory. The experiment is named
# after the last non-empty path component of `cache_path`.
lightning_log_dir = cache_path + "lightning"
exp_name = list(filter(None, cache_path.split('/')))[-1]

In [None]:
# A parameterization is needed for setup. These values are overwritten by Ax when tuning is occurring.
# They are defined here (ahead of the later cells) to guarantee that other params dicts can be merged into it.
params = {
'default_out_nodes_inp'  : 1,
'default_out_nodes_edge' : 1,
# One output node per target when several y_vars are given as a list; a single
# target (e.g. a bare column-name string) gets one output node. The previous
# `if type(...)` test was always true, so a string y_var got len(string) outputs.
'default_out_nodes_out'  : len(params_data['y_var']) if isinstance(params_data['y_var'], list) else 1,

'default_drop_nodes_inp' : 0.0,
'default_drop_nodes_edge': 0.0,
'default_drop_nodes_out' : 0.0,

'default_reps_nodes_inp' : 1,
'default_reps_nodes_edge': 1,
'default_reps_nodes_out' : 1,

'default_decay_rate'     : 1
}

# Module-level copies of the defaults (vnn_factory_1 unpacks its own copies from the params it receives).
default_out_nodes_inp  = params['default_out_nodes_inp' ]
default_out_nodes_edge = params['default_out_nodes_edge']
default_out_nodes_out  = params['default_out_nodes_out' ]

default_drop_nodes_inp = params['default_drop_nodes_inp' ]
default_drop_nodes_edge= params['default_drop_nodes_edge']
default_drop_nodes_out = params['default_drop_nodes_out' ]

default_reps_nodes_inp = params['default_reps_nodes_inp' ]
default_reps_nodes_edge= params['default_reps_nodes_edge']
default_reps_nodes_out = params['default_reps_nodes_out' ]

default_decay_rate = params['default_decay_rate' ]

In [None]:
# Pull the run/data settings into the short names used by the following cells.
batch_size, max_epoch = params_run['batch_size'], params_run['max_epoch']

y_var = params_data['y_var']

In [None]:
# Prefix for saved artifacts: the last non-empty component of `cache_path`,
# suffixed with the residualization strategy when one is configured.
save_prefix = list(filter(None, cache_path.split('/')))[-1]

if params_data['y_resid_strat'] != 'None':
    save_prefix += '_' + params_data['y_resid_strat']

# make sure the artifact directory is there before anything is written to it
ensure_dir_path_exists(dir_path = cache_path)

In [None]:
use_gpu_num = 0

device = "cuda" if torch.cuda.is_available() else "cpu"
# Only pin a GPU when CUDA is actually available: torch.cuda.set_device raises
# on CPU-only machines, which previously made this cell crash there.
if device == "cuda" and use_gpu_num in [0, 1]:
    torch.cuda.set_device(use_gpu_num)
print(f"Using {device} device")

In [None]:
def _dist_scale_function(out, dist, decay_rate):
    scale = 1/(1+decay_rate*dist)
    out = round(scale * out)
    out = max(1, out)
    return out

In [None]:
def _expand_node_shortcut(vnn_helper, query = 'y_hat'):
    # define new entries
    if True in [True if e in vnn_helper.edge_dict.keys() else False for e in 
                [f'{query}_res_-2', f'{query}_res_-1']
                ]:
        print('Warning! New node name already exists! Overwriting existing node!')

    # Add residual connection in graph
    vnn_helper.edge_dict[f'{query}_res_-2'] = myvnn.edge_dict[query] 
    vnn_helper.edge_dict[f'{query}_res_-1'] = [f'{query}_res_-2']
    vnn_helper.edge_dict[query]             = [f'{query}_res_-2', f'{query}_res_-1']

    # Add new nodes, copying information from query node
    vnn_helper.node_props[f'{query}_res_-2'] = vnn_helper.node_props[query] 
    vnn_helper.node_props[f'{query}_res_-1'] = vnn_helper.node_props[query]

    return vnn_helper


## Load Data

In [None]:
# Data Prep ----
# Load the four datasets used below, in this order:
# 'obs_geno_lookup', 'phno', 'KEGG_slices', 'KEGG_entries'.
obs_geno_lookup, phno, ACGT_gene_slice_list, parsed_kegg_gene_entries = (
    get_data(key) for key in ('obs_geno_lookup', 'phno', 'KEGG_slices', 'KEGG_entries')
)

In [None]:
# This dataset is very close to balanced with respect to genotypes (9-18 observations each).
# Turn a seed value and a holdout proportion into a list of held-out genotypes:
# shuffle the genotypes, then keep taking them in draw order while the
# cumulative share of observations stays within `pr`.
rng = np.random.default_rng( params_data['holdout_parents']['rng_seed'] )

# Reuse the phenotype table loaded in the "Load Data" cell instead of reading it again.
taxa = sorted(set(phno.Taxa))
rng.shuffle(taxa)
taxa = pd.DataFrame(taxa).reset_index().rename(columns={0:'Taxa', 'index':'DrawOrder'})

# number of observations per genotype, ordered by draw order
tmp = (pd.merge(taxa, phno)
       .groupby(['DrawOrder', 'Taxa'])
       .size()
       .reset_index(name='n'))
tmp['cdf'] = tmp['n'].cumsum(0)/tmp['n'].sum()
# filter: genotypes whose cumulative share of observations is within `pr`
holdout_parents = list(tmp.loc[(tmp.cdf <= params_data['holdout_parents']['pr'] ), 'Taxa'])

In [None]:
# Make holdout sets.
# Flag rows whose 'Taxa' is one of the held-out parent genotypes.
mask = mask_columns(df= phno, col_name= 'Taxa', holdouts= holdout_parents)

# A row goes to the test set if any mask column flags it; otherwise it goes to
# the training set.
n_flags = mask.sum(axis=1)
train_mask = n_flags == 0
test_mask  = n_flags > 0

train_idx, test_idx = (m.loc[m].index for m in (train_mask, test_mask))

In [None]:
# Convert y to a residual if requested.

if params_data['y_resid'] == 'None':
    pass
else:
    # Residualizing y has not been ported to this dataset yet (the earlier
    # implementation, written for data with Env/Year/Hybrid columns, is kept
    # below for reference). Fail loudly instead of silently training on raw y
    # when a residualization was requested.
    raise NotImplementedError(
        f"y_resid={params_data['y_resid']!r} is not supported for GMX yet; set it to 'None'.")
# TODO update for GMX
# else:
#     if params_data['y_resid_strat'] == 'naive_mean':
#         # use only data in the training set (especially since testers will be more likely to be found across envs)
#         # get enviromental means, subtract from observed value
#         tmp = phno.loc[train_idx, ]
#         env_mean = tmp.groupby(['Env_Idx']
#                      ).agg(Env_Mean = (y_var, 'mean')
#                      ).reset_index()
#         tmp = phno.merge(env_mean)
#         tmp.loc[:, y_var] = tmp.loc[:, y_var] - tmp.loc[:, 'Env_Mean']
#         phno = tmp.drop(columns='Env_Mean')

#     if params_data['y_resid_strat'] == 'filter_mean':
#         # for adjusting to environment we could use _all_ observations but ideally we will use the same set of genotypes across all observations
#         def minimum_hybrids_for_env(tmp = phno.loc[:, ['Env', 'Year', 'Hybrid']],
#                                     year = 2014):
#             # Within each year what hybrids are most common?
#             tmp = tmp.loc[(tmp.Year == year), ].groupby(['Env', 'Hybrid']).count().reset_index().sort_values('Year')

#             all_envs = set(tmp.Env)
#             # if we filter on the number of sites a hybrid is planted at, what is the largest number of sites we can ask for before we lose a location?
#             # site counts for sets which contain all envs
#             i = max([i for i in list(set(tmp.Year)) if len(set(tmp.loc[(tmp.Year >= i), 'Env'])) == len(all_envs)])

#             before = len(set(tmp.loc[:, 'Hybrid']))
#             after  = len(set(tmp.loc[(tmp.Year >= i), 'Hybrid']))
#             print(f'Reducing {year} hybrids from {before} to {after} ({round(100*after/before)}%).')
#             tmp = tmp.loc[(tmp.Year >= i), ['Env', 'Hybrid']].reset_index(drop=True)
#             return tmp


#         tmp = phno.loc[:, ['Env', 'Year', 'Hybrid']]
#         filter_hybrids = [minimum_hybrids_for_env(tmp = phno.loc[:, ['Env', 'Year', 'Hybrid']], year = i) 
#                           for i in list(set(phno.Year)) ]
#         env_mean = pd.concat(filter_hybrids).merge(phno, how = 'left')

#         env_mean = env_mean.groupby(['Env_Idx']
#                           ).agg(Env_Mean = (y_var, 'mean')
#                           ).reset_index()

#         tmp = phno.merge(env_mean)
#         tmp.loc[:, y_var] = tmp.loc[:, y_var] - tmp.loc[:, 'Env_Mean']
#         phno = tmp.drop(columns='Env_Mean')


In [None]:
# Center and scale the target value(s).
n_missing = phno.loc[:, y_var].isna().sum().sum()  # second sum is for multiple y_vars
if n_missing != 0:
    # explicit check rather than `assert`, which is skipped under `python -O`
    raise ValueError(f'{n_missing} missing value(s) in the selected y_var column(s)')

y = phno.loc[:, y_var].to_numpy()  # added to make multiple ys work
# Statistics come from the training rows only, to prevent information leakage.
# NOTE(review): train_idx is a pandas index from the holdout mask; using it as
# numpy positions assumes phno has a default RangeIndex — confirm.
y_c = y[train_idx].mean(axis=0)
y_s = y[train_idx].std(axis=0)
# A trait that is constant over the training rows has std 0, which would turn
# its whole column into nan/inf; leave such columns centered but unscaled.
y_s = np.where(y_s == 0, 1.0, y_s)

y = (y - y_c)/y_s

## Fit Using VNNHelper

In [None]:
def vnn_factory_1(parsed_kegg_gene_entries, params):
    """Build a VNNHelper graph from KEGG BRITE paths and set its node properties.

    Steps performed below:
      1. Keep only KEGG entries with a non-empty ``BRITE_PATHS``. Build, clean
         and append a ``y_hat`` node to the connection graph, then sanitize its
         names (called with ``replace_chars={'.': '_'}``).
      2. Set each node's input size, output size, flatten flag, replicate count
         and dropout from ``params``. An input node's input size is the product
         of the non-leading dimensions of its slice in the *global*
         ``ACGT_gene_slice_list``.
      3. Rescale output sizes with ``_dist_scale_function``, using each node's
         distance (``sparsevnn.core.vertex_from_end``) from the last node in
         dependency order and ``params['default_decay_rate']``.
      4. Expand nodes with ``reps > 1`` into chains ``A -> A_1 -> A_2 ...``
         (consumers of ``A`` are re-pointed to the last replicate) and rebuild
         the helper.
      5. Compute edge-node input sizes with ``calc_edge_inp``.

    Args:
        parsed_kegg_gene_entries: list of parsed KEGG gene entry dicts.
        params: dict with the ``default_*`` keys read below.

    Returns:
        ``(myvnn, lookup_dict)``. ``lookup_dict`` maps each input-node name to
        its index in ``parsed_kegg_gene_entries``; the same index is used to
        pick that node's slice from ``ACGT_gene_slice_list``.
        NOTE(review): ``new_lookup_dict`` (input-node name -> position in
        ``myvnn.nodes_inp``) is computed at the end but not returned.
    """

    # banner so each call's parameters are easy to find in the output
    print(''.join('#' for i in range(80)))
    print(params)
    print(''.join('#' for i in range(80)))
    
    
    default_out_nodes_inp  = params['default_out_nodes_inp' ]
    default_out_nodes_edge = params['default_out_nodes_edge'] 
    default_out_nodes_out  = params['default_out_nodes_out' ]

    default_drop_nodes_inp = params['default_drop_nodes_inp' ] 
    default_drop_nodes_edge= params['default_drop_nodes_edge'] 
    default_drop_nodes_out = params['default_drop_nodes_out' ] 

    default_reps_nodes_inp = params['default_reps_nodes_inp' ]
    default_reps_nodes_edge= params['default_reps_nodes_edge']
    default_reps_nodes_out = params['default_reps_nodes_out' ]



    default_decay_rate = params['default_decay_rate' ]



    # Clean up KEGG Pathways -------------------------------------------------------
    # Same setup as above to create kegg_gene_brite
    # Restrict to only those with pathway
    kegg_gene_brite = [e for e in parsed_kegg_gene_entries if 'BRITE' in e.keys()]

    # also require to have a non-empty path
    kegg_gene_brite = [e for e in kegg_gene_brite if not e['BRITE']['BRITE_PATHS'] == []]

    # report the share of entries kept
    print('Retaining '+ str(round(len(kegg_gene_brite)/len(parsed_kegg_gene_entries), 4)*100)+'%, '+str(len(kegg_gene_brite)
        )+'/'+str(len(parsed_kegg_gene_entries)
        )+' Entries'
        )
    # kegg_gene_brite[1]['BRITE']['BRITE_PATHS']


    kegg_connections = kegg_connections_build(kegg_gene_brite = kegg_gene_brite, 
                                            n_genes = len(kegg_gene_brite)) 
    kegg_connections = kegg_connections_clean(         kegg_connections = kegg_connections)
    #TODO think about removing 
    # "Not Included In
    # Pathway Or Brite"
    # or reinstate 'Others'

    kegg_connections = kegg_connections_append_y_hat(  kegg_connections = kegg_connections)
    kegg_connections = kegg_connections_sanitize_names(kegg_connections = kegg_connections, 
                                                    replace_chars = {'.':'_'})


    # Initialize helper for input nodes --------------------------------------------
    myvnn = VNNHelper(edge_dict = kegg_connections)

    # Get a mapping of brite names to tensor list index
    find_names = myvnn.nodes_inp # e.g. ['100383860', '100278565', ... ]
    lookup_dict = {}

    # The only difference between lookup_dict and brite_node_to_list_idx_dict below is that this one is
    # built from the full list of parsed entries, whereas that one is built from kegg_gene_brite (a subset).
    # An entry's node name is the last element of its first BRITE path.
    for i in range(len(parsed_kegg_gene_entries)):
        if 'BRITE' not in parsed_kegg_gene_entries[i].keys():
            pass
        elif parsed_kegg_gene_entries[i]['BRITE']['BRITE_PATHS'] == []:
            pass
        else:
            name = parsed_kegg_gene_entries[i]['BRITE']['BRITE_PATHS'][0][-1]
            if name in find_names:
                lookup_dict[name] = i
    # lookup_dict    

    # NOTE(review): brite_node_to_list_idx_dict is built here but not used later in this function.
    brite_node_to_list_idx_dict = {}
    for i in range(len(kegg_gene_brite)):
        brite_node_to_list_idx_dict[str(kegg_gene_brite[i]['BRITE']['BRITE_PATHS'][0][-1])] = i        

    # Get the input sizes for the graph: product of each slice's shape after the
    # first axis. Reads the global ACGT_gene_slice_list, indexed via lookup_dict.
    size_in_zip = zip(myvnn.nodes_inp, [np.prod(ACGT_gene_slice_list[lookup_dict[e]].shape[1:]) for e  in myvnn.nodes_inp])

    # Set node defaults ------------------------------------------------------------
    # init input node sizes
    myvnn.set_node_props(key = 'inp', node_val_zip = size_in_zip)

    # init node output sizes
    myvnn.set_node_props(key = 'out', node_val_zip = zip(myvnn.nodes_inp, [default_out_nodes_inp  for e in myvnn.nodes_inp]))
    myvnn.set_node_props(key = 'out', node_val_zip = zip(myvnn.nodes_edge,[default_out_nodes_edge for e in myvnn.nodes_edge]))
    myvnn.set_node_props(key = 'out', node_val_zip = zip(myvnn.nodes_out, [default_out_nodes_out  for e in myvnn.nodes_out]))

    # # options should be controlled by node_props
    myvnn.set_node_props(key = 'flatten', node_val_zip = zip(myvnn.nodes_inp, [True for e in myvnn.nodes_inp]))

    myvnn.set_node_props(key = 'reps', node_val_zip = zip(myvnn.nodes_inp, [default_reps_nodes_inp  for e in myvnn.nodes_inp]))
    myvnn.set_node_props(key = 'reps', node_val_zip = zip(myvnn.nodes_edge,[default_reps_nodes_edge for e in myvnn.nodes_edge]))
    myvnn.set_node_props(key = 'reps', node_val_zip = zip(myvnn.nodes_out, [default_reps_nodes_out  for e in myvnn.nodes_out]))

    myvnn.set_node_props(key = 'drop', node_val_zip = zip(myvnn.nodes_inp, [default_drop_nodes_inp  for e in myvnn.nodes_inp]))
    myvnn.set_node_props(key = 'drop', node_val_zip = zip(myvnn.nodes_edge,[default_drop_nodes_edge for e in myvnn.nodes_edge]))
    myvnn.set_node_props(key = 'drop', node_val_zip = zip(myvnn.nodes_out, [default_drop_nodes_out  for e in myvnn.nodes_out]))


    # Scale node outputs by distance -----------------------------------------------
    # distance of each node from the last node in dependency order
    dist = sparsevnn.core.vertex_from_end(
        edge_dict = myvnn.edge_dict,
        end =myvnn.dependancy_order[-1]
    )

    # overwrite node outputs with a size inversely proportional to distance from prediction node
    for query in list(dist.keys()):
        myvnn.node_props[query]['out'] = _dist_scale_function(
            out = myvnn.node_props[query]['out'],
            dist = dist[query],
            decay_rate = default_decay_rate)
        

    # Expand out node replicates ---------------------------------------------------
    # kegg_connections_expanded = sparsevnn.core.expand_edge_dict(vnn_helper = myvnn, edge_dict = myvnn.edge_dict)

    # ======================================================= #
    # one place to add residual connections would be here.    #
    # edit the links before instatinating the new VNNHelper   #
    # the important thing to do is to edit the graph before   #
    # calulating the inputs for each node.                    #
    # ======================================================= #

            # # expand then copy over the properties that have already been defined.
            # myvnn_exp = VNNHelper(edge_dict = kegg_connections_expanded)

            # import re
            # for new_key in list(myvnn_exp.node_props.keys()):
            #     if new_key in myvnn.node_props.keys():
            #         # copy directly
            #         myvnn_exp.node_props[new_key] = myvnn.node_props[new_key]
            #     else:
            #         # check for a key that matches the query key after removing the replicate information
            #         query = new_key 
            #         suffix = re.findall('_rep_\d+$', query)[0]

            #         query = query.removesuffix(suffix)
            #         if query in myvnn.node_props.keys():
            #             myvnn_exp.node_props[new_key] = myvnn.node_props[query]
            #         else:
            #             print(f'WARNING: no entry {query} found for {new_key}') 

            # # now main vnn is the expanded version
            # myvnn = myvnn_exp


            # # Cleanup ----------------------------------------------------------------------
            # # now the original VNNHelper isn't needed
            # myvnn = myvnn_exp


    # expand out graph.
#     update_edge_links = {}
#     nodes = myvnn.dependancy_order

#     for query in [e for e in reversed(nodes)]:
#         if myvnn.node_props[query]['reps'] == 1:
#             pass
#         else:
#             reps = myvnn.node_props[query]['reps']
#             # set to 1 so that we can copy all the props (except input) over
#             myvnn.node_props[query]['reps'] = 1
            
#             for i in range(reps-1,0,-1):
#                 if i == 0:
#                     # no replicates
#                     pass
#                 else:
#                     if i == 1:
#                         # print({f'{query}_{i}':[f'{query}']})
#                         update_edge_links[f'{query}_{i}'] = f'{query}'
#                     else:
#                         # print({f'{query}_{i}':[f'{query}_{i-1}']})
#                         update_edge_links[f'{query}_{i}'] = f'{query}_{i-1}'

#                     # copy over all properties except input (input will either be set for the data or calculated on the fly)
#                     myvnn.node_props[f'{query}_{i}'] = {k:myvnn.node_props[query][k] for k in myvnn.node_props[query] if k != 'inp'}


#     # Now there should be new nodes in the helper but the links need to be updated to point to the right names. 

#     if True:
#         # update existing links
#         # create a lookup dictionary to map the old names to new names
#         old_to_new = {update_edge_links[k]:k for k in update_edge_links}
#         # old_to_new

#         for k in myvnn.edge_dict:
#             myvnn.edge_dict[k] = [e if e not in old_to_new.keys() else old_to_new[e] for e in myvnn.edge_dict[k]]

#         # add in new nodes
#         for k in update_edge_links:
#             myvnn.edge_dict[k] = [update_edge_links[k]]

#         # overwrite dependancy order
#         # myvnn.dependancy_order = VNNHelper(edge_dict= myvnn.edge_dict).dependancy_order
# # myvnn_updated = VNNHelper(edge_dict= myvnn.edge_dict)
# # myvnn_updated.node_props = myvnn.node_props
#         # myvnn = myvnn_updated

    # expand out graph.
    update_edge_links = {}
    # only nodes with more than one replicate need expanding
    nodes = [node for node in myvnn.dependancy_order if myvnn.node_props[node]['reps'] > 1]

    node_expansion_dict = {
        node: [node if i==0 else f'{node}_{i}' for i in range(myvnn.node_props[node]['reps'])]
        for node in nodes}
    #   current       1st          2nd (new)      3rd (new)
    # {'100798274': ['100798274', '100798274_1', '100798274_2'], ...

    # The keys don't change here. Every reference to an expanded node is re-pointed
    # to its last replicate; new k:v pairs are inserted afterwards.
    myvnn.edge_dict = {k:[e if e not in node_expansion_dict.keys() 
        else node_expansion_dict[e][-1]
        for e in myvnn.edge_dict[k] ] for k in myvnn.edge_dict}

    # now insert connections to new nodes: A -> A_1 -> A_2 (each replicate takes the previous one as input)
    for node in node_expansion_dict:
        for pair in zip(node_expansion_dict[node][1:], node_expansion_dict[node]):
            myvnn.edge_dict[pair[0]] = [pair[1]]

    # now add those new nodes
    # each replicate copies the original's props except 'inp' (set later by calc_edge_inp)
    for node in node_expansion_dict:
        for new_node in node_expansion_dict[node][1:]:
            myvnn.node_props[new_node] = {k:myvnn.node_props[node][k] for k in myvnn.node_props[node] if k != 'inp'}

    # rebuild the helper from the edited graph so its node lists/dependency order
    # reflect the new nodes, then carry over the node properties
    new_vnn = VNNHelper(edge_dict= myvnn.edge_dict)
    new_vnn.node_props = myvnn.node_props
    myvnn = new_vnn


    # init edge node input size (propagate forward input/edge outputs)
    myvnn.calc_edge_inp()

    # replace lookup so that it matches the length of the input tensors
    # NOTE(review): new_lookup_dict is built here but not returned; callers receive lookup_dict.
    new_lookup_dict = {}
    for i in range(len(myvnn.nodes_inp)):
        new_lookup_dict[myvnn.nodes_inp[i]] = i
    
    return myvnn,  lookup_dict # new_lookup_dict is computed above but not returned

# NOTE: despite its name, `new_lookup_dict` receives vnn_factory_1's `lookup_dict`
# (input-node name -> index into the full KEGG entry/slice lists).
myvnn, new_lookup_dict = vnn_factory_1(parsed_kegg_gene_entries = parsed_kegg_gene_entries, params = params)

### Calculate each node's membership in each matrix and its position within it

In [None]:
def vnn_factory_2(vnn_helper, node_to_inp_num_dict):
    """Group the VNN's nodes into layers and build one structured matrix per layer.

    Layer assignment: a node with no inputs goes to layer 0. Any other node goes
    one layer past the deepest layer among its inputs. For layer ``i`` the
    description passed to ``structured_layer_info`` is:

    * ``out``: the nodes assigned to layer ``i``;
    * ``inp``: the nodes assigned to layer ``i - 1`` (layer 0 uses its own nodes);
    * ``eye``: pass-through nodes needed at layer ``i`` or above that are not
      produced by layer ``i - 1``, so they must be carried forward unchanged.

    Args:
        vnn_helper: a ``VNNHelper`` with ``node_props``, ``edge_dict`` and
            ``dependancy_order`` populated (e.g. by ``vnn_factory_1``).
        node_to_inp_num_dict: input-node lookup from ``vnn_factory_1``. Accepted
            for interface compatibility; the layer construction does not use it.

    Returns:
        List of ``structured_layer_info(..., as_sparse=True)`` results, one per
        layer index from 0 to the deepest layer.
    """
    myvnn = vnn_helper

    node_props = myvnn.node_props
    edge_dict = myvnn.edge_dict
    dependancy_order = myvnn.dependancy_order
    # `node_to_inp_num_dict` is deliberately not re-bound to the global
    # `new_lookup_dict` here: doing so ignored the caller's argument and raised
    # NameError whenever that global did not exist.

    # Check the dependency order ---------------------------------------------------
    # every node's inputs must appear before the node itself
    seen = set()
    for d in dependancy_order:
        if all(e in seen for e in edge_dict[d]):
            seen.add(d)
        else:
            print('error!')
            break


    # Build output nodes: layer index for every node --------------------------------
    d_out = {0:[]}
    for d in dependancy_order:
        if edge_dict[d] == []:
            d_out[0].append(d)
        else:
            # one layer past the deepest layer holding any of this node's inputs
            d_out_i = 1+max(sum([[key for key in d_out.keys() if e in d_out[key]]
                    for e in edge_dict[d]], []))

            if d_out_i not in d_out.keys():
                d_out[d_out_i] = []
            d_out[d_out_i].append(d)


    # Build pass-through ("eye") nodes ----------------------------------------------
    # Walk from the top layer down. Nodes needed by layer i (or carried from
    # above) that layer i-1 does not produce must be passed through unchanged.
    d_eye = {}
    tally = []
    for i in range(max(d_out.keys()), min(d_out.keys()), -1):
        nodes_needed = sum([edge_dict[e] for e in d_out[i]], [])+tally
        # drop what the previous layer already provides, then dedupe
        nodes_needed = [e for e in nodes_needed if e not in d_out[i-1]]
        nodes_needed = list(set(nodes_needed))
        tally = nodes_needed
        d_eye[i] = nodes_needed


    dd = {}
    for i in d_eye.keys():
        dd[i] = {'out': d_out[i],
                'inp': d_out[i-1],
                'eye': d_eye[i]}
    # plus special 0 layer that handles the snps

    dd[0] = {'out': d_out[0],
            'inp': d_out[0],
            'eye': []}


    # check that each output node's inputs are satisfied by the same layer's inputs (inp and eye)
    for i in dd.keys():
        for e in dd[i]['out']:
            node_pass_list = [True if ee in dd[i]['inp']+dd[i]['eye'] else False
                            for ee in edge_dict[e]]
            if False not in node_pass_list:
                pass
            else:
                print('exit')


    # Debug summary of layer sizes:
    # print("Layer\t#In\t#Out")
    # for i in range(min(dd.keys()), max(dd.keys())+1, 1):
    #     node_in      = [node_props[e]['out'] for e in dd[i]['inp']+dd[i  ]['eye'] ]
    #     if i == max(dd.keys()):
    #         node_out = [node_props[e]['out'] for e in dd[i]['out'] ]
    #     else:
    #         node_out = [node_props[e]['out'] for e in dd[i]['out']+dd[i+1]['eye']]
    #     print(f'{i}:\t{sum(node_in)}\t{sum(node_out)}')

    M_list = [structured_layer_info(i = ii, node_groups = dd, node_props= node_props, edge_dict = edge_dict, as_sparse=True) for ii in range(0, max(dd.keys())+1)]
    return M_list

### Creating Structured Matrices for Layers
M_list = vnn_factory_2(vnn_helper = myvnn, node_to_inp_num_dict = new_lookup_dict)

### Set Up Dataloaders Using `M_list`

In [None]:
lookup_dict = new_lookup_dict

# Reuse the KEGG slices already loaded in the "Load Data" cell
# (`ACGT_gene_slice_list = get_data('KEGG_slices')`) instead of reading them
# again. Only the slices that feed the first layer are converted, in the order
# of its input rows, so unused genes are never copied into float tensors.
n_obs = ACGT_gene_slice_list[0].shape[0]
vals = torch.concat([
    torch.from_numpy(ACGT_gene_slice_list[lookup_dict[i]]).to(torch.float).reshape(n_obs, -1)
    for i in M_list[0].row_inp
    ], axis = 1)

vals = vals.to('cuda')

In [None]:
def _make_dataloader(obs_idx, shuffle):
    """Wrap a ``BigDataset`` over the observations ``obs_idx`` in a DataLoader.

    Uses the module-level ``obs_geno_lookup``, standardized ``y``, genotype
    tensor ``vals`` and ``batch_size``.
    """
    dataset = BigDataset(
        lookups_are_filtered = False,
        lookup_obs  = torch.from_numpy(np.array(obs_idx)),
        lookup_geno = torch.from_numpy(obs_geno_lookup),
        y           = torch.from_numpy(y).to(torch.float32),
        G           = vals,
        G_type      = 'raw',
        send_batch_to_gpu = 'cuda:0',
    )
    return DataLoader(dataset, batch_size = batch_size, shuffle = shuffle)


# training batches are reshuffled every epoch; validation order is kept fixed
training_dataloader   = _make_dataloader(train_idx, shuffle = True)
validation_dataloader = _make_dataloader(test_idx,  shuffle = False)

## Structured Layer

In [None]:
class NeuralNetwork(nn.Module):
    """Apply a fixed sequence of layers, feeding each layer's output to the next.

    The layers are stored in an ``nn.ModuleList`` so their parameters are
    registered with this module.
    """

    def __init__(self, layer_list):
        super().__init__()
        self.layer_list = nn.ModuleList(layer_list)

    def forward(self, x):
        out = x
        for layer in self.layer_list:
            out = layer(out)
        return out

In [None]:
def vnn_factory_3(M_list):
    """Turn each structured layer description into a ``SparseLinearCustom`` layer.

    Each entry of ``M_list`` supplies a sparse ``weight`` (its coalesced
    indices give the connectivity, its values the initial weights), a
    ``bias``, gradient flags and a dropout probability. Every layer except the
    last applies ``F.relu``; the final layer has no nonlinearity.

    Returns:
        List of ``SparseLinearCustom`` layers, in the same order as ``M_list``.
    """
    n_layers = len(M_list)
    layers = []
    for idx, info in enumerate(M_list):
        is_last = idx == n_layers - 1
        layers.append(SparseLinearCustom(
            info.weight.shape[1],  # in features (columns of the weight matrix)
            info.weight.shape[0],  # out features (rows of the weight matrix)
            connectivity        = torch.LongTensor(info.weight.coalesce().indices()),
            custom_weights      = info.weight.coalesce().values(),
            custom_bias         = info.bias.clone().detach(),
            weight_grad_bool    = info.weight_grad_bool,
            bias_grad_bool      = info.bias_grad_bool,
            dropout_p           = info.dropout_p,
            nonlinear_transform = None if is_last else F.relu,
            ))
    return layers

## Tiny Test Study

| type  | value key | value type  |
|------:|--|--|
| range | bounds | list |
| choice| values | list |
| fixed | value  | atomic |


<!-- {
#     "name": "size",
#     "type": "range",
#     "bounds": [1, 256],
#     "value_type": "int",  
#     "log_scale": False,  # For range, defaults to False.
# },
# {
#     "name": "drop",
#     "type": "choice",
#     "values": [],    
#     "value_type":"int", #"float", "bool", "str"
#     "is_ordered": False, # For choice    
    
#     "value_type": "float",  # Optional, defaults to inference from type of "bounds".
#     "bounds": [0.0, 1.0],
# },
# {
#     "name": "drop",
#     "type": "fixed",
#     "value": [],     
#     "value_type": "float",  # Optional, defaults to inference from type of "bounds".
#     "bounds": [0.0, 1.0],
# }

# # Constraints can be passed to `ax_client.create_experiment` via `parameter_constraints` and `outcome_constraints`. For an example, see https://ax.dev/tutorials/gpei_hartmann_service.html -->


In [None]:
# This is a deliberate trick: Lightning is called from within Ax's evaluation
# function, so every trial's training trace is saved while Ax still chooses the
# next hyperparameters.


def evaluate(parameterization):
    """Build, train and score one VNN for a single Ax trial.

    Reads module-level names defined in earlier cells:
    ``parsed_kegg_gene_entries``, ``lightning_log_dir``, ``exp_name``,
    ``max_epoch``, ``training_dataloader`` and ``validation_dataloader``.

    Args:
        parameterization: hyperparameter dict from Ax (same keys as ``params``).

    Returns:
        ``{"train_loss": (mean, sem)}`` as expected by
        ``AxClient.complete_trial``. ``mean`` is the lowest per-epoch mean
        training loss, and ``sem`` is reported as 0.0.
    """
    myvnn, new_lookup_dict = vnn_factory_1(parsed_kegg_gene_entries = parsed_kegg_gene_entries, params = parameterization)
    M_list = vnn_factory_2(vnn_helper = myvnn, node_to_inp_num_dict = new_lookup_dict)
    layer_list =  vnn_factory_3(M_list = M_list)
    model = NeuralNetwork(layer_list = layer_list)

    # (Lightning calls configure_optimizers itself during fit.)
    VNN = plDNN_general(model)
    logger = CSVLogger(lightning_log_dir, name=exp_name)
    logger.log_hyperparams(params={
        'params': parameterization
    })

    trainer = pl.Trainer(max_epochs=max_epoch, logger=logger)
    trainer.fit(model=VNN, train_dataloaders=training_dataloader, val_dataloaders=validation_dataloader)

    # If we were optimizing the number of training epochs, the final
    # trainer.callback_metrics['train_loss'] would be an effective loss to use.

    # To potentially _overtrain_ models and still let the selection be based on
    # their best possible performance, use the lowest average loss in any epoch.
    # `logger.log_dir` is exactly the version_N directory this run wrote to, so
    # there is no need to parse every directory entry as `version_N` (which fails
    # on any other entry) or to assume this run is the newest version on disk.
    metrics_path = os.path.join(logger.log_dir, 'metrics.csv')
    M = pd.read_csv(metrics_path)
    M = M.loc[:, ['epoch', 'train_loss']].dropna()

    epoch_loss = M.groupby('epoch')['train_loss'].mean()

    train_metric = float(epoch_loss.min())
    print(train_metric)
    return {"train_loss": (train_metric, 0.0)}


In [None]:
# ## Generated variables ====
# # using sql database
# sql_url = "sqlite:///"+f"./{lightning_log_dir}/{exp_name}.sqlite"

# # If the database exists, load it and begin from there
# loaded_db = False
# if os.path.exists(sql_url.split('///')[-1]): # must cleave off the sql dialect prefix
#     # alternate way to load an experiment (after `init_engine_and_session_factory` has been run)
#     # experiment = load_experiment(exp_name) # if this doesn't work, check if the database is named something else and try that.
#     db_settings = DBSettings(url=sql_url)
#     # Instead of URL, can provide a `creator function`; can specify custom encoders/decoders if necessary.
#     ax_client = AxClient(db_settings=db_settings)
#     ax_client.load_experiment_from_database(exp_name)
#     loaded_db = True

# else:
#     ax_client = AxClient()
#     ax_client.create_experiment(
#         name=exp_name,
#         parameters=params_list,
#         objectives={"train_loss": ObjectiveProperties(minimize=True)}
#     )

# run_trials_bool = True
# if run_hyps_force == False:
#     if loaded_db: 
#         # check if we've reached the max number of hyperparamters combinations to test
#         if max_hyps <= (ax_client.generation_strategy.trials_as_df.index.max()+1):
#             run_trials_bool = False

# if run_trials_bool:
#     # run the trials
#     for i in range(run_hyps):
#         parameterization, trial_index = ax_client.get_next_trial()
#         # Local evaluation here can be replaced with deployment to external system.
#         ax_client.complete_trial(trial_index=trial_index, raw_data=evaluate(parameterization))

#     if loaded_db == False:
#         init_engine_and_session_factory(url=sql_url)
#         engine = get_engine()
#         create_all_tables(engine)

#     # save the trials
#     experiment = ax_client.experiment
#     save_experiment(experiment)

In [None]:
def _make_big_dataloader(
        obs_idx,
        obs_geno_lookup,
        y,
        vals,
        batch_size,
        shuffle,
        send_batch_to_gpu):
    """Wrap the observations ``obs_idx`` in a ``BigDataset`` and return a ``DataLoader`` over it.

    Fresh tensors are created for every call, so the train and validation
    datasets never share tensor objects built here.
    """
    dataset = BigDataset(
        lookups_are_filtered = False,
        lookup_obs  = torch.from_numpy(np.array(obs_idx)),
        lookup_geno = torch.from_numpy(obs_geno_lookup),
        y =           torch.from_numpy(y).to(torch.float32),
        G =           vals,
        G_type = 'raw',
        send_batch_to_gpu = send_batch_to_gpu
        )
    return DataLoader(dataset, batch_size = batch_size, shuffle = shuffle)


def _prep_dls(
        y_var,
        train_idx,
        obs_geno_lookup,
        vals,
        batch_size,
        test_idx,
        send_batch_to_gpu = 'cuda:0'):
    """Build the training and validation DataLoaders for the response(s) ``y_var``.

    The phenotype table is loaded with ``get_data('phno')``. The selected y
    column(s) are centered and scaled using the mean/std of the *training*
    rows only, so validation rows cannot leak into the normalization.

    Parameters
    ----------
    y_var : str or list of str
        Column label(s) of the phenotype table to use as targets.
    train_idx, test_idx : array-like of int
        Row positions of the training / validation observations; they are
        used to index the y array directly.
    obs_geno_lookup : numpy.ndarray
        Passed (as a tensor) to ``BigDataset`` as ``lookup_geno``.
    vals
        Genotype data, passed to ``BigDataset`` as ``G`` with ``G_type='raw'``.
    batch_size : int
        Batch size of both loaders.
    send_batch_to_gpu : str, optional
        Passed to ``BigDataset``; defaults to ``'cuda:0'`` as before.

    Returns
    -------
    (DataLoader, DataLoader)
        ``(training_dataloader, validation_dataloader)``; only the training
        loader shuffles.

    Raises
    ------
    ValueError
        If the y column(s) contain missing values, ``train_idx`` is empty, or a
        y column is constant over the training rows (std == 0 would turn every
        target into inf/NaN).
    """
    phno = get_data('phno')
    y_df = phno.loc[:, y_var]

    # Converting to numpy before summing counts NaNs whether y_var selects a
    # single column (Series) or several (DataFrame).
    n_missing = int(y_df.isna().to_numpy().sum())
    if n_missing:
        raise ValueError(
            f'y_var {y_var!r} has {n_missing} missing value(s); none are allowed.')
    if np.size(train_idx) == 0:
        raise ValueError('train_idx is empty; cannot compute centering/scaling statistics.')

    y = y_df.to_numpy()
    # Use training rows only for the statistics to prevent information leakage.
    y_train = y[train_idx]
    y_c = y_train.mean(axis=0)
    y_s = y_train.std(axis=0)
    if np.any(y_s == 0):
        raise ValueError(
            f'y_var {y_var!r} is constant over train_idx (std == 0); cannot scale it.')
    y = (y - y_c)/y_s

    training_dataloader = _make_big_dataloader(
        train_idx, obs_geno_lookup, y, vals, batch_size,
        shuffle = True, send_batch_to_gpu = send_batch_to_gpu)
    validation_dataloader = _make_big_dataloader(
        test_idx, obs_geno_lookup, y, vals, batch_size,
        shuffle = False, send_batch_to_gpu = send_batch_to_gpu)

    return training_dataloader, validation_dataloader

In [None]:
# Tune hyperparameters separately for each response variable. Each y_var gets its
# own Ax experiment named "<last component of cache_path>__<y_var>", persisted to
# "./{lightning_log_dir}/{exp_name}.sqlite" so an interrupted search can resume.
# NOTE(review): only the first two entries of params_data['y_var'] are processed;
# widen the slice to tune every listed y_var.
# The names assigned here (y_var, exp_name, training_dataloader,
# validation_dataloader, ax_client, ...) are deliberately module-level globals;
# `evaluate` presumably reads the dataloaders from them -- TODO confirm.
for ith_y_var in params_data['y_var'][0:2]:

    y_var = ith_y_var
    # Experiment name: last non-empty path component of cache_path + '__' + y_var.
    exp_name = [e for e in cache_path.split('/') if e != ''][-1]
    exp_name += '__'+y_var

    print('#' * 80)
    print(f'experiment: {exp_name}')
    print('#' * 80)
    print('\n')

    training_dataloader, validation_dataloader = _prep_dls(
            y_var = y_var,
            train_idx = train_idx,
            obs_geno_lookup = obs_geno_lookup,
            vals = vals,
            batch_size = batch_size,
            test_idx = test_idx)

    ## Generated variables ====
    # using sql database
    sql_url = "sqlite:///"+f"./{lightning_log_dir}/{exp_name}.sqlite"

    # If the database exists, load it and continue from there.
    # The file path is the URL without the "sqlite:///" dialect prefix.
    loaded_db = os.path.exists(sql_url.split('///')[-1])
    if loaded_db:
        # Ax keeps a single module-global session factory, and
        # `init_engine_and_session_factory` is a no-op once it is set unless
        # force_init=True (AxClient(db_settings=...) does not force it). Re-point
        # it at THIS y_var's database; otherwise, from the second iteration on,
        # the experiment would be loaded from / saved to the first y_var's file.
        init_engine_and_session_factory(url=sql_url, force_init=True)
        db_settings = DBSettings(url=sql_url)
        ax_client = AxClient(db_settings=db_settings)
        ax_client.load_experiment_from_database(exp_name)
    else:
        ax_client = AxClient()
        ax_client.create_experiment(
            name=exp_name,
            parameters=params_list,
            objectives={"train_loss": ObjectiveProperties(minimize=True)}
        )

    run_trials_bool = True
    if not run_hyps_force and loaded_db:
        # Skip this y_var once the stored experiment already holds max_hyps trials.
        # Count trials on the experiment itself: it does not depend on a
        # generation strategy having been saved (only the experiment is saved below).
        if max_hyps <= len(ax_client.experiment.trials):
            run_trials_bool = False

    if run_trials_bool:
        # run the trials
        for _ in range(run_hyps):
            parameterization, trial_index = ax_client.get_next_trial()
            # Local evaluation here can be replaced with deployment to external system.
            ax_client.complete_trial(trial_index=trial_index, raw_data=evaluate(parameterization))

        if not loaded_db:
            # First save for this y_var: point the (global) session factory at
            # this y_var's file -- forced, since an earlier iteration may have
            # initialised it for another file -- and create Ax's tables.
            init_engine_and_session_factory(url=sql_url, force_init=True)
            engine = get_engine()
            create_all_tables(engine)

        # save the trials
        experiment = ax_client.experiment
        save_experiment(experiment)

In [None]:
# Trial summary table for `ax_client`. After the loop above, `ax_client` holds
# the experiment of the LAST y_var processed only.
# Append `.tail()` to show only the most recent trials.
ax_client.generation_strategy.trials_as_df#.tail()

In [None]:
# render(ax_client.get_contour_plot())

In [None]:
# Optimization trace of train_loss for the experiment currently in `ax_client`
# (the last y_var processed by the loop above); objective_optimum=0.0 supplies
# 0.0 as the reference optimum for the plot.
render(ax_client.get_optimization_trace(objective_optimum=0.0)) 

In [None]:
# To check which tables exist in the sqlite database:
# import sqlite3
# con = sqlite3.connect("./foo.db")
# cur = con.cursor()
# # cur.execute(".tables;") # fails: `.tables` is a sqlite3 CLI dot-command, not SQL; query sqlite_master instead
# cur.execute("SELECT name FROM sqlite_master WHERE type='table';").fetchall()
# con.close()