# G only KEGG based network architecture

> 

In [1]:
import numpy as np
import pandas as pd

from EnvDL.core import ensure_dir_path_exists 
from EnvDL.dlfn import g2fc_datawrapper, BigDataset, plDNN_general
from EnvDL.dlfn import ResNet2d, BasicBlock2d
from EnvDL.dlfn import LSUV_

import torch
import torch.nn.functional as F # F.mse_loss
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch import nn

import lightning.pytorch as pl
from lightning.pytorch.loggers import TensorBoardLogger

from EnvDL.dlfn import kegg_connections_build, kegg_connections_clean, kegg_connections_append_y_hat, kegg_connections_sanitize_names
from EnvDL.dlfn import VNNHelper, VisableNeuralNetwork, Linear_block_reps
from EnvDL.dlfn import plDNN_general, BigDataset
from EnvDL.dlfn import reverse_edge_dict, reverse_node_props
from EnvDL.dlfn import VNNVAEHelper, plVNNVAE
from EnvDL.dlfn import kegg_connections_build, kegg_connections_clean, kegg_connections_append_y_hat, kegg_connections_sanitize_names
from EnvDL.dlfn import VNNHelper, VisableNeuralNetwork, Linear_block_reps
from EnvDL.dlfn import ListDataset, plVNN
from EnvDL.dlfn import plDNN_general, BigDataset

In [2]:
cache_path = '../nbs_artifacts/02.40_g2fc_G_ACGT_VNN_baseline/'
save_prefix = [e for e in cache_path.split('/') if e != ''][-1]

# Run settings: 
max_epoch  = 20
batch_size = 48

# VNN settings:
default_out_nodes_inp   = 4
default_out_nodes_edge  = 32
default_out_nodes_out   = 1

default_drop_nodes_inp  = 0.0
default_drop_nodes_edge = 0.
default_drop_nodes_out  = 0.0

default_reps_nodes_inp  = 1
default_reps_nodes_edge = 1
default_reps_nodes_out  = 1

In [3]:
use_gpu_num = 0

device = "cuda" if torch.cuda.is_available() else "cpu"
if use_gpu_num in [0, 1]: 
    torch.cuda.set_device(use_gpu_num)
print(f"Using {device} device")

Using cuda device


In [4]:
ensure_dir_path_exists(dir_path = cache_path)

In [5]:
from EnvDL.dlfn import kegg_connections_build, kegg_connections_clean, kegg_connections_append_y_hat, kegg_connections_sanitize_names
from EnvDL.dlfn import VNNHelper, VisableNeuralNetwork, Linear_block_reps
from EnvDL.dlfn import ListDataset, plVNN

## Fit Using VNNHelper

In [6]:

# Same setup as above to create kegg_gene_brite
X = g2fc_datawrapper()
X.set_split()
X.load_all(name_list = ['obs_geno_lookup', 'YMat', 'KEGG_slices',], store=True) 
X.calc_cs('YMat', version = 'np', filter = 'val:train')
ACGT_gene_slice_list =     X.get('KEGG_slices', ops_string='')
parsed_kegg_gene_entries = X.get('KEGG_entries')


# Restrict to only those with pathway
kegg_gene_brite = [e for e in parsed_kegg_gene_entries if 'BRITE' in e.keys()]

# also require to have a non-empty path
kegg_gene_brite = [e for e in kegg_gene_brite if not e['BRITE']['BRITE_PATHS'] == []]

print('Retaining '+ str(round(len(kegg_gene_brite)/len(parsed_kegg_gene_entries), 4)*100)+'%, '+str(len(kegg_gene_brite)
    )+'/'+str(len(parsed_kegg_gene_entries)
    )+' Entries'
    )
# kegg_gene_brite[1]['BRITE']['BRITE_PATHS']

Loading and storing default `phno`.
Retaining 43.53%, 6067/13939 Entries


In [7]:
kegg_connections = kegg_connections_build(kegg_gene_brite = kegg_gene_brite, 
                                          n_genes = 6067) 
kegg_connections = kegg_connections_clean(         kegg_connections = kegg_connections)
kegg_connections = kegg_connections_append_y_hat(  kegg_connections = kegg_connections)
kegg_connections = kegg_connections_sanitize_names(kegg_connections = kegg_connections, 
                                                   replace_chars = {'.':'_'})

100%|██████████| 6067/6067 [00:00<00:00, 58683.04it/s]

Removed node "Others"





In [8]:
# initialize helper for input nodes
myvnn = VNNHelper(edge_dict = kegg_connections)

myvnn.nodes_inp[0:10]

# Get a mapping of brite names to tensor list index
find_names = myvnn.nodes_inp # e.g. ['100383860', '100278565', ... ]
lookup_dict = {}

# the only difference lookup_dict and brite_node_to_list_idx_dict above is that this is made using the full set of genes in the list 
# whereas that is made using kegg_gene_brite which is a subset
for i in range(len(parsed_kegg_gene_entries)):
    if 'BRITE' not in parsed_kegg_gene_entries[i].keys():
        pass
    elif parsed_kegg_gene_entries[i]['BRITE']['BRITE_PATHS'] == []:
        pass
    else:
        name = parsed_kegg_gene_entries[i]['BRITE']['BRITE_PATHS'][0][-1]
        if name in find_names:
            lookup_dict[name] = i
lookup_dict    

{'100278565': 0,
 '100383860': 1,
 '100383837': 3,
 '100191673': 8,
 '100275685': 9,
 '103630585': 13,
 '100194370': 16,
 '100194192': 17,
 '100273289': 18,
 '100037826': 19,
 '100192899': 21,
 '100304252': 26,
 '100280063': 28,
 '103630746': 30,
 '100282726': 31,
 '100285519': 33,
 '100272424': 35,
 '100381824': 36,
 '100277985': 40,
 '100283343': 42,
 '100284607': 45,
 '100191772': 48,
 '100277385': 49,
 '100381530': 50,
 '103630860': 51,
 '100384769': 52,
 '103630917': 55,
 '100194239': 56,
 '103644387': 57,
 '103631010': 58,
 '103631056': 63,
 '100191905': 64,
 '100281199': 68,
 '100283731': 70,
 '103631177': 72,
 '100281329': 76,
 '100274981': 78,
 '100382167': 86,
 '100383301': 87,
 '100285254': 90,
 '100280817': 91,
 '542672': 92,
 '542290': 93,
 '103631462': 94,
 '100283425': 96,
 '103631568': 97,
 '100280768': 98,
 '103631596': 102,
 '542230': 105,
 '103631647': 106,
 '100383383': 107,
 '100274307': 111,
 '100273485': 112,
 '100191592': 113,
 '100273170': 114,
 '100384051': 11

In [9]:
# # if permuting gene identities
# torch.manual_seed(5461)

# keys = [e for e in lookup_dict.keys()]

# # vals = [lookup_dict[e] for e in lookup_dict.keys()]
# # dict(zip(keys, [int(i) for i in torch.randperm(len(keys))]))

# idx = torch.tensor([lookup_dict[e] for e in myvnn.nodes_inp])
# idx = idx[torch.randperm(idx.shape[0])]
# idx = [int(i) for i in idx]
# temp = dict(zip(myvnn.nodes_inp, idx))

# randomized_lookup_dict = {}
# for e in lookup_dict.keys():
#     if e not in temp.keys():
#         randomized_lookup_dict[e] = lookup_dict[e]
#     else:
#         randomized_lookup_dict[e] = temp[e]

# lookup_dict = randomized_lookup_dict

In [10]:
brite_node_to_list_idx_dict = {}
for i in range(len(kegg_gene_brite)):
    brite_node_to_list_idx_dict[str(kegg_gene_brite[i]['BRITE']['BRITE_PATHS'][0][-1])] = i        

# Get the input sizes for the graph
size_in_zip = zip(myvnn.nodes_inp, [np.prod(ACGT_gene_slice_list[lookup_dict[e]].shape[1:]) for e  in myvnn.nodes_inp])


In [11]:

# init input node sizes
myvnn.set_node_props(key = 'inp', node_val_zip = size_in_zip)

# init node output sizes
myvnn.set_node_props(key = 'out', node_val_zip = zip(myvnn.nodes_inp, [default_out_nodes_inp  for e in myvnn.nodes_inp]))
myvnn.set_node_props(key = 'out', node_val_zip = zip(myvnn.nodes_edge,[default_out_nodes_edge for e in myvnn.nodes_edge]))
myvnn.set_node_props(key = 'out', node_val_zip = zip(myvnn.nodes_out, [default_out_nodes_out  for e in myvnn.nodes_out]))


# # options should be controlled by node_props
myvnn.set_node_props(key = 'flatten', node_val_zip = zip(
    myvnn.nodes_inp, 
    [True for e in myvnn.nodes_inp]))

myvnn.set_node_props(key = 'reps', node_val_zip = zip(myvnn.nodes_inp, [default_reps_nodes_inp  for e in myvnn.nodes_inp]))
myvnn.set_node_props(key = 'reps', node_val_zip = zip(myvnn.nodes_edge,[default_reps_nodes_edge for e in myvnn.nodes_edge]))
myvnn.set_node_props(key = 'reps', node_val_zip = zip(myvnn.nodes_out, [default_reps_nodes_out  for e in myvnn.nodes_out]))

myvnn.set_node_props(key = 'drop', node_val_zip = zip(myvnn.nodes_inp, [default_drop_nodes_inp  for e in myvnn.nodes_inp]))
myvnn.set_node_props(key = 'drop', node_val_zip = zip(myvnn.nodes_edge,[default_drop_nodes_edge for e in myvnn.nodes_edge]))
myvnn.set_node_props(key = 'drop', node_val_zip = zip(myvnn.nodes_out, [default_drop_nodes_out  for e in myvnn.nodes_out]))

# init edge node input size (propagate forward input/edge outpus)
myvnn.calc_edge_inp()

# myvnn.mk_digraph(include = ['node_name', 'inp_size', 'out_size'])
# myvnn.mk_digraph(include = [''])

In [12]:
from EnvDL.dlfn import plDNN_general, BigDataset

In [13]:
vals = X.get('KEGG_slices', ops_string='asarray from_numpy float')

In [14]:
# restrict to the tensors that will be used
vals = [vals[lookup_dict[i]] for i in myvnn.nodes_inp]
# send to gpu
vals = [val.to('cuda') for val in vals]

In [15]:
# replace lookup so that it matches the lenght of the input tensors
new_lookup_dict = {}
for i in range(len(myvnn.nodes_inp)):
    new_lookup_dict[myvnn.nodes_inp[i]] = i
    # print((myvnn.nodes_inp[i], i))
    # break

In [16]:
## start insert

In [17]:
node_props = myvnn.node_props
# Linear_block = Linear_block_reps,
edge_dict = myvnn.edge_dict
dependancy_order = myvnn.dependancy_order
node_to_inp_num_dict = new_lookup_dict

In [18]:
# turn dependancy_order into a set of nodes that can be represented together on the edges of each matrix.
# walking through the dependency order, what is the maximum set of nodes that only depend on nodes which have already been calculated?

def find_nodes_dependency_clusters(
        dependancy_order,
        edge_dict):
    def get_current_window(available_nodes:list, # ordered list of nodes
                        connection_dict: dict, # dict of connections with parent->[children]
                        node_list_is_ordered:bool = True # If the node list isn't ordered then all will be searched
                        ):
        current_window = []
        for j in range(len(available_nodes)):
            prop_node = available_nodes[j]
            # does proposed node need any nodes in the current window?
            if True not in [True if e in current_window else False for e in connection_dict[prop_node]]:
                current_window += [prop_node]
            else:
                if node_list_is_ordered:
                    break
                else: # If the node list isn't ordered then all will be searched
                    pass
        return current_window

    # These are the nodes that can be in each matrix without being needed for the next output. 
    # This is _not_ including the nodes that need to be kept around as input for future nodes.
    temp_nodes = dependancy_order.copy()
    nodes_in_matrices = []
    for _ in range(len(temp_nodes)):
        if temp_nodes == []:
            break
        else:
            nodes_in_matrices += [
                get_current_window(available_nodes = temp_nodes,
                            connection_dict = edge_dict,
                            node_list_is_ordered = True)]

            temp_nodes = [e for e in temp_nodes if e not in nodes_in_matrices[-1]]
            # print(nodes_in_matrices)
    return(nodes_in_matrices)

nodes_as_matricies = find_nodes_dependency_clusters(
        dependancy_order = dependancy_order,
        edge_dict = edge_dict)

nodes_as_matricies = dict(zip([i for i in range(len(nodes_as_matricies))], nodes_as_matricies))
nodes_as_matricies

{0: ['103627197',
  '103646242',
  '103634739',
  '100273995',
  '113633883',
  '100191313',
  '103628156',
  '100382167',
  '103635427',
  '100216610',
  '100283696',
  '100285210',
  '100274222',
  '100283179',
  '100282762',
  '100281008',
  '103627891',
  '103642826',
  '103649923',
  '100282427',
  '100283558',
  '100381699',
  '103625778',
  '103653335',
  '103649912',
  '103626041',
  '100191330',
  '100285329',
  '100280690',
  '100192928',
  '100282883',
  '541967',
  '100194011',
  '100382343',
  '103643234',
  '103654581',
  '100285438',
  '100286035',
  '103625822',
  '100501431',
  '100193061',
  '103652996',
  '103637620',
  '100383192',
  '100192071',
  '100281536',
  '100283456',
  '103626932',
  '100502407',
  '103645955',
  '100279929',
  '100285571',
  '100857003',
  '103654422',
  '100273009',
  '100274751',
  '103647953',
  '100217183',
  '103636076',
  '100283135',
  '103654004',
  '103640907',
  '541937',
  '103630633',
  '103634847',
  '542532',
  '103637987',
 

In [19]:
# What steps need an input to be passed through unchanged? (identity)
# A -------C   Here C needs A as an identity in the previous matrix.
#   \     /    
#    --B--
# 
#    B A       C
# A |# 1|   A |#|
#           B |#|    

def find_nodes_dependency_identity(
        edge_dict,
        nodes_as_matricies
        ):
    connection_dict = edge_dict
    identity_nodes = {}
    for matrix_level in range(max(nodes_as_matricies.keys()), 0, -1):
        # print(matrix_level)
        if matrix_level == max(nodes_as_matricies.keys()):
            for_higher_level = []
        else:
            for_higher_level = identity_nodes[matrix_level]

        needed_inputs = list(set(sum([connection_dict[node] for node in nodes_as_matricies[matrix_level]], [])))        

        identity_nodes[matrix_level-1] = [e for e in list(set(needed_inputs+for_higher_level))
                                        if e not in nodes_as_matricies[matrix_level-1]]
        # print(identity_nodes)
    return(identity_nodes)

identity_nodes = find_nodes_dependency_identity(
        edge_dict = edge_dict,
        nodes_as_matricies = nodes_as_matricies
        )
# inputs to one layer that need to be preserved for the next layer
[(e, len(identity_nodes[e])) for e in identity_nodes.keys()]

[(6, 0), (5, 87), (4, 151), (3, 2027), (2, 6088), (1, 3580), (0, 0)]

In [20]:
node_matrix_summary = {}
for key in nodes_as_matricies.keys():
    node_matrix_summary[key] = {'calc':tuple(nodes_as_matricies[key])}
    if key in identity_nodes.keys():
        node_matrix_summary[key]['identity'] = tuple(identity_nodes[key])
    else:
        node_matrix_summary[key]['identity'] = []

# node_matrix_summary
        
print('Note: Layer 0 is the raw inputs.')
[(e, 
  'calc',
  len(node_matrix_summary[e]['calc']),
  'id',
  len(node_matrix_summary[e]['identity'])
  ) for e in node_matrix_summary.keys()]

Note: Layer 0 is the raw inputs.


[(0, 'calc', 3580, 'id', 0),
 (1, 'calc', 2531, 'id', 3580),
 (2, 'calc', 1247, 'id', 6088),
 (3, 'calc', 833, 'id', 2027),
 (4, 'calc', 486, 'id', 151),
 (5, 'calc', 122, 'id', 87),
 (6, 'calc', 68, 'id', 0),
 (7, 'calc', 1, 'id', 0)]

In [21]:
# # 0->1
# i = 1
# j = i+1

# inp_nodes = node_matrix_summary[i]['identity']
# out_nodes = node_matrix_summary[i]['calc']+node_matrix_summary[j]['identity']


In [22]:
# # row lookup: positions on edge
# row_nodes     = [e for e in inp_nodes]
# row_node_idxs = [i for i in range(len(inp_nodes))]
# row_size      = torch.Tensor([node_props[e]['inp'] for e in inp_nodes]).to(torch.int)
# row_stop      = torch.cumsum(row_size, 0)
# row_start     = torch.concat([torch.Tensor([0]).to(torch.int), row_stop[0:-1]])

# row_start

In [23]:
# # col lookup: positions on edge
# col_nodes     = [e for e in out_nodes]
# col_node_idxs = [i for i in range(len(out_nodes))]
# col_size      = torch.Tensor([node_props[e]['out'] for e in out_nodes]).to(torch.int)
# col_stop      = torch.cumsum(col_size, 0)
# col_start     = torch.concat([torch.Tensor([0]).to(torch.int), col_stop[0:-1]])
# col_start


In [24]:
def calc_node_slice_info(node_matrix_summary, 
                         layer = 1):
    i = layer
    j = i+1

    inp_nodes = node_matrix_summary[i]['identity']
    out_nodes = node_matrix_summary[i]['calc']+tuple(node_matrix_summary[j]['identity'])

    # row lookup: positions on edge
    row_nodes     = [e for e in inp_nodes]
    row_node_idxs = [i for i in range(len(inp_nodes))]
    row_size      = torch.Tensor([node_props[e]['inp'] for e in inp_nodes]).to(torch.int)
    row_stop      = torch.cumsum(row_size, 0)
    row_start     = torch.concat([torch.Tensor([0]).to(torch.int), row_stop[0:-1]])

    # col lookup: positions on edge
    col_nodes     = [e for e in out_nodes]
    col_node_idxs = [i for i in range(len(out_nodes))]
    col_size      = torch.Tensor([node_props[e]['out'] for e in out_nodes]).to(torch.int)
    col_stop      = torch.cumsum(col_size, 0)
    col_start     = torch.concat([torch.Tensor([0]).to(torch.int), col_stop[0:-1]])


    return {
        'row_nodes': row_nodes,
        'row_node_idxs': row_node_idxs,
        'row_size': row_size,
        'row_stop': row_stop,
        'row_start': row_start,
        'col_nodes': col_nodes,
        'col_node_idxs': col_node_idxs,
        'col_size': col_size,
        'col_stop': col_stop,
        'col_start': col_start
        }

node_slice_dict = {}
for i in range(1, max(node_matrix_summary.keys())):
    node_slice_dict[i] = calc_node_slice_info(
        node_matrix_summary = node_matrix_summary, 
        layer = i)

In [46]:
# go from slice dict to matrix

# node_slice_dict
mat_key = 2

row_nodes     = node_slice_dict[mat_key]['row_nodes']
row_node_idxs = node_slice_dict[mat_key]['row_node_idxs']
row_size      = node_slice_dict[mat_key]['row_size']
row_stop      = node_slice_dict[mat_key]['row_stop']
row_start     = node_slice_dict[mat_key]['row_start']

col_nodes     = node_slice_dict[mat_key]['col_nodes']
col_node_idxs = node_slice_dict[mat_key]['col_node_idxs']
col_size      = node_slice_dict[mat_key]['col_size']
col_stop      = node_slice_dict[mat_key]['col_stop']
col_start     = node_slice_dict[mat_key]['col_start']

In [47]:
# check that the expected nodes are present.
# all col_nodes passed?
False not in [
    # col_node passed?
    False not in [
        # dependency in row_nodes?
        True if e in row_nodes else False for e in edge_dict[col_node]]
        for col_node in col_nodes]

False

In [44]:
# init matrix (will convert to sparse soon)
mat = torch.zeros(row_stop[-1], col_stop[-1])
print(mat.shape)

torch.Size([88240, 36632])


In [45]:
# each interaction is a tiny matrix. 

# iterate over nodes in columns, 
#   iterate over nodes in rows
#       init slice as
#       - identity (eye)
#       - random

i = 0

col_nodes[i]
ding = True
# col_start[col_node_idxs[i]], col_stop[col_node_idxs[i]]


# edge_dict[node]
c_idx1 = col_start[col_node_idxs[i]]
c_idx2 = col_stop[col_node_idxs[i]]
c_size = col_size[col_node_idxs[i]]

current_row_nodes = edge_dict[col_nodes[i]]

for current_row_node in current_row_nodes:  
    j = [i for i in range(len(row_nodes)) if row_nodes[i] == current_row_node][0]
    r_idx1 = row_start[row_node_idxs[j]]
    r_idx2 = row_stop[row_node_idxs[j]]
    r_size = row_size[row_node_idxs[j]]

    if col_nodes[i] == row_nodes[j]:
        # identity
        if ding:
            print('ding')
            ding = False
        mat[r_idx1:r_idx2, c_idx1:c_idx2] = torch.eye(c_size)
    else:
        # random
        mat[r_idx1:r_idx2, c_idx1:c_idx2] = torch.randn(r_size, c_size)
        #TODO what about kaiming init?

# mat[:, c_idx1:c_idx2]


In [None]:
# import plotly.express as px
# px.imshow(mat)

In [None]:


node = node_matrix_summary[1]['identity'][0]

node_props[node]['inp'], node_props[node]['out']


In [None]:
node_props

In [None]:
## end insert

In [None]:
model = VisableNeuralNetwork(
    node_props = myvnn.node_props,
    Linear_block = Linear_block_reps,
    edge_dict = myvnn.edge_dict,
    dependancy_order = myvnn.dependancy_order,
    node_to_inp_num_dict = new_lookup_dict
)
model = model.to('cuda')
# # with torch.no_grad(): print(model(vals))

In [None]:
# # if randomizing y
# torch.manual_seed(2608434)

# y_trn = X.get('YMat', ops_string='cs filter:val:train asarray from_numpy float')
# y_trn = y_trn[torch.randperm(y_trn.shape[0])]


# y_val = X.get('YMat', ops_string='cs filter:val:train asarray from_numpy float')
# y_val = y_val[torch.randperm(y_val.shape[0])]


In [None]:

training_dataloader = DataLoader(BigDataset(
    lookups_are_filtered = True,
    lookup_obs =  X.get('val:train',       ops_string='                   asarray from_numpy'), 
    lookup_geno = X.get('obs_geno_lookup', ops_string='   filter:val:train asarray from_numpy'),
    y =           X.get('YMat',            ops_string='cs filter:val:train asarray from_numpy float cuda:0')[:, None],
    G =           vals,
    G_type = 'list',
    # send_batch_to_gpu = 'cuda:0'
    ),
    batch_size = batch_size,
    shuffle = True
)

validation_dataloader = DataLoader(BigDataset(
    lookups_are_filtered = True,
    lookup_obs =  X.get('val:test',        ops_string='                   asarray from_numpy'), 
    lookup_geno = X.get('obs_geno_lookup', ops_string='   filter:val:test asarray from_numpy'),
    y =           X.get('YMat',            ops_string='cs filter:val:test asarray from_numpy float cuda:0')[:, None],
    G =           vals,
    G_type = 'list',
    # send_batch_to_gpu = 'cuda:0'
    ),
    batch_size = batch_size,
    shuffle = False
)


In [None]:
LSUV_(model, data = next(iter(training_dataloader))[1])

In [None]:
VNN = plDNN_general(model)  

optimizer = VNN.configure_optimizers()

logger = TensorBoardLogger("tb_vnn_logs", name=save_prefix)
trainer = pl.Trainer(max_epochs=max_epoch, logger=logger)

trainer.fit(model=VNN, train_dataloaders=training_dataloader, val_dataloaders=validation_dataloader)


In [None]:
import time, json
save_time = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())

json_path = cache_path+''.join(['lookup_dict','__'+save_time,'.json'])
with open(json_path, 'w', encoding='utf-8') as f: 
    json.dump(new_lookup_dict, f, ensure_ascii=False, indent=4)    

pt_path = cache_path+''.join([save_prefix,'__'+save_time,'.pt'])

torch.save(VNN.mod, pt_path)