In [None]:
# Hacky way to schedule runs: uncomment the sleep below to wait until the GPUs should be free.
# At the end of the notebook, os._exit(00) kills the kernel, freeing the GPU.
#                          Hours to wait
# import time; time.sleep( 3 * (24*60*60))

# G only KEGG based network architecture

> The important change here is to use two linear/ReLU/dropout repeats per module

In [None]:
import os, re, json
from tqdm import tqdm

import numpy as np
import pandas as pd
pd.set_option('display.max_columns', None)

import plotly.express as px
import plotly.io as pio
pio.templates.default = "plotly_white"

import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch import nn
import torch.nn.functional as F # F.mse_loss

from EnvDL.core import * 
# includes 
    # remove_matching_files 
    # read_json
from EnvDL.dna import *
from EnvDL.dlfn import * 
# includes
    # read_split_info 
    # find_idxs_split_dict
    # train_loop_yx
    # train_error_yx
    # test_loop_yx
    # train_nn_yx
    # yhat_loop_yx
    

import lightning.pytorch as pl
from lightning.pytorch.loggers import TensorBoardLogger

from graphviz import Digraph
# import torchviz

In [None]:
# Index of the GPU to pin this notebook to (only 0 or 1 are accepted below).
use_gpu_num = 0

device = "cuda" if torch.cuda.is_available() else "cpu"
# Only select a GPU when CUDA is actually present: torch.cuda.set_device
# raises on CPU-only machines, which used to make this cell fail there.
if device == "cuda" and use_gpu_num in [0, 1]:
    torch.cuda.set_device(use_gpu_num)
print(f"Using {device} device")

In [None]:
# Where this notebook writes its artifacts; the last non-empty path component
# doubles as a prefix for log/run names.
cache_path = '../nbs_artifacts/02.30_g2fc_G_ACGT_VNN.1/'
save_prefix = list(filter(None, cache_path.split('/')))[-1]

ensure_dir_path_exists(dir_path = cache_path)

In [None]:
# Settings:
## Graph linear blocks: each internal node gets `default_block_reps` repeats of
## Linear(-> default_output_size) -> ReLU -> Dropout(default_dropout_pr)
default_output_size, default_dropout_pr, default_block_reps = 50, 0.1, 2

## Training settings
run_epochs, epochs_run = 15, 0
dataloader_batch_size = 1000

## Functions

## Load data

In [None]:
load_from = '../nbs_artifacts/01.03_g2fc_prep_matrices/'

# Phenotype table; `phno` is an alias of the same DataFrame (not a copy).
phno_geno = pd.read_csv(load_from + 'phno_geno.csv')
phno = phno_geno

# obs_geno_lookup columns: Phno_Idx, Geno_Idx, Is_Phno_Idx
obs_geno_lookup, YMat = [np.load(load_from + fname)
                         for fname in ('obs_geno_lookup.npy', 'YMat.npy')]

In [None]:
# Use the cleaned-up KEGG gene entries and their per-gene ACGT genotype slices
# produced in '../nbs_artifacts/01.05_g2fc_demo_model/'.
load_from = '../nbs_artifacts/01.05_g2fc_demo_model/'
parsed_kegg_gene_entries, ACGT_gene_slice_list = [
    get_cached_result(load_from + fname)
    for fname in ('filtered_kegg_gene_entries.pkl', 'ACGT_gene_slice_list.pkl')]

In [None]:
# Genotype slices must line up one-to-one (by position) with the KEGG gene entries.
assert len(ACGT_gene_slice_list) == len(parsed_kegg_gene_entries), (
    f'{len(ACGT_gene_slice_list)} genotype slices vs '
    f'{len(parsed_kegg_gene_entries)} KEGG gene entries')

In [None]:
## Create train/test/validate indices from json
load_from = '../nbs_artifacts/01.06_g2fc_cluster_genotypes/'

split_info = read_split_info(
    load_from = load_from,
    json_prefix = '2023:9:5:12:8:26')

# Split 'Female/Male' hybrid names into parent columns for the split definitions.
temp = phno.copy()
temp[['Female', 'Male']] = temp['Hybrid'].str.split('/', expand = True)

test_dict = find_idxs_split_dict(
    obs_df = temp,
    split_dict = split_info['test'][0]
)

# Since this applies a predefined model structure, validation is not needed.
# It is kept for future reference; validation rows are drawn from the training rows only.
temp = temp.loc[test_dict['train_idx'], ] # restrict before re-applying

val_dict = find_idxs_split_dict(
    obs_df = temp,
    split_dict = split_info['validate'][0]
)

train_idx = test_dict['train_idx']
test_idx  = test_dict['test_idx']

In [None]:
# # FIXME
# train_idx = train_idx[0:100] # 81169 total
# test_idx = test_idx[0:100]
# # 1/100th
# train_idx = train_idx[0:3000] # 81169 total
# test_idx = test_idx[0:3000]

## Generate Graph for DNN

### Functions for Graph Construction

In [None]:
## Building a Neural Net from an arbitrary graph
# start by finding the top level -- all those keys which are themselves not values
# helper function to get all keys and all values from a dict. Useful when keys don't have unique values.
def find_uniq_keys_values(input_dict):
    """Collect the keys and the de-duplicated values of a dict of lists.

    Intended for adjacency dicts (node -> list of children) where the same
    child can be listed under several keys.

    Args:
        input_dict: mapping of key -> iterable of values.

    Returns:
        dict with 'all_keys' (keys in dict order) and 'all_values' (unique
        values in first-seen order).
    """
    all_keys = list(input_dict.keys())
    # dict.fromkeys de-duplicates while keeping first-seen order. list(set(...))
    # made the order depend on per-process string hash randomisation, so node
    # (and therefore layer/input) order changed between sessions.
    all_values = list(dict.fromkeys(
        value for key in all_keys for value in input_dict[key]))

    return {'all_keys': all_keys,
            'all_values': all_values}

In [None]:
### Find order that nodes in the graph should be called to have all dependencies run when they are called.
# find the dependencies for run order from many dependencies to none
def find_top_nodes(all_key_value_dict):
    """Return the keys that never appear as a value (nodes nothing depends on).

    Args:
        all_key_value_dict: output of `find_uniq_keys_values`
            ({'all_keys': [...], 'all_values': [...]}); items must be hashable.

    Returns:
        list of those keys, in 'all_keys' order.
    """
    values = set(all_key_value_dict['all_values'])  # O(1) membership tests
    return [e for e in all_key_value_dict['all_keys'] if e not in values]

# Input nodes never occur as keys, so they would not be picked up by the
# key-based traversal; find them explicitly.
def find_input_nodes(all_key_value_dict):
    """Return the values that never appear as a key (leaf / input nodes).

    Args:
        all_key_value_dict: output of `find_uniq_keys_values`; items must be hashable.

    Returns:
        list of those values, in 'all_values' order.
    """
    keys = set(all_key_value_dict['all_keys'])
    return [e for e in all_key_value_dict['all_values'] if e not in keys]

### Process paths into graphs

In [None]:
# Keep only entries that carry a BRITE (pathway hierarchy) record.
kegg_gene_brite = list(filter(lambda entry: 'BRITE' in entry.keys(),
                              parsed_kegg_gene_entries))

In [None]:
# ...and that record must list at least one path.
kegg_gene_brite = [entry for entry in kegg_gene_brite
                   if entry['BRITE']['BRITE_PATHS'] != []]

In [None]:
# Share of parsed KEGG gene entries that survived the BRITE filters.
# Fixed-point formatting avoids artefacts like '57.99999999999999%' produced by
# round(x, 4)*100; max(..., 1) guards against an empty entry list.
print('Retaining {:.2f}%, {}/{} Entries'.format(
    100 * len(kegg_gene_brite) / max(len(parsed_kegg_gene_entries), 1),
    len(kegg_gene_brite),
    len(parsed_kegg_gene_entries)))

In [None]:
# kegg_gene_brite[1]['BRITE']['BRITE_PATHS']

### `n`-gene version of the network (use a small `n_genes` for a quick test; graphs are only drawn when `n_genes` <= 10):

In [None]:
# Number of BRITE-annotated genes used to build the network, capped at the
# number available so the graph-building cell cannot index past the end.
n_genes = min(6067, len(kegg_gene_brite))

print('Using '+str(n_genes)+'/'+str(len(kegg_gene_brite))+' genes.')

# Graphviz rendering is only practical for tiny networks.
vis_dot_bool = n_genes <= 10

In [None]:
"""
The goal here is to have a dict with each node and a list of it's children. 
For example, the graph
a--b--d
 |-c--e
Would be parsed into     
{'a':['b', 'c'],
 'b':['d'],
 'c':['e']}
"""
kegg_connections = {}

# for all genes in list
for i in tqdm(range(n_genes)): 
# for i in tqdm(range(len(kegg_gene_brite))):    
    temp = kegg_gene_brite[i]['BRITE']['BRITE_PATHS']
    # clean up to make sure that there are no ":" characters. These can mess up graphviz
    temp = [[temp[j][i].replace(':', '-') for i in range(len(temp[j])) ] for j in range(len(temp))]
    # all paths through graph associated with a gene
    for j in range(len(temp)):
        # steps of the path through the graph
        for k in range(len(temp[j])-1):
            
            # name standardization 
            temp_jk  = temp[j][k]
            temp_jk1 = temp[j][k+1]
            temp_jk  = temp_jk.lower().title().replace(' ', '')
            temp_jk1 = temp_jk1.lower().title().replace(' ', '')
            
            # if this is a new key, add it and add the k+1 entry as it's child
            if temp_jk  not in kegg_connections.keys():
                kegg_connections[temp_jk] = [temp_jk1]
            else: 
                # Check to see if there's a new child to add   
                if temp_jk1 not in kegg_connections[temp_jk]:
                    # make sure that no key contains itself. This was a problem for 'Others' which is now disallowed.
                    if (temp_jk != temp_jk1):
#                         if ((temp_jk  != temp_jk1) & (temp_jk1 != 'Others')):
                        # add it.
                        kegg_connections[temp_jk].extend([temp_jk1])

# kegg_connections                        

In [None]:
# Drop the catch-all 'Others' node as a key...
if 'Others' in kegg_connections:
    kegg_connections.pop('Others')
    print('Removed node "Others"')

# ...and wherever it appears as a child.
for node, children in kegg_connections.items():
    kegg_connections[node] = list(filter(lambda child: child != 'Others', children))


In [None]:
# No node may list itself among its own children.
for node, children in kegg_connections.items():
    kegg_connections[node] = [child for child in children if child != node]

In [None]:
# Some keys have no children, or only children that are themselves dead ends,
# so they never connect back to the SNP inputs. Collect all of them by iterating
# to a fixed point; the next cell removes them and every reference to them.
# The collected set can only grow between passes, so the loop always ends
# (at most len(kegg_connections) + 1 passes); the earlier version stopped
# silently after 100 passes, which could leave dead ends behind in deep graphs.
rm_list = []
while True:
    _rm_set = set(rm_list)
    _new_rm_list = [key for key, children in kegg_connections.items()
                    if all(child in _rm_set for child in children)]
    if len(_new_rm_list) == len(rm_list):
        break
    rm_list = _new_rm_list
rm_list

In [None]:
# Delete the dead-end nodes, then scrub them out of every remaining child list.
for dead_key in rm_list:
    kegg_connections.pop(dead_key)

for node, children in kegg_connections.items():
    kegg_connections[node] = [child for child in children if child not in rm_list]

In [None]:
# Add the output node 'y_hat'. It takes every node that is no other node's
# child (the top of each hierarchy). Any existing 'y_hat' entry is ignored so
# that re-running this cell gives the same result instead of a self-loop.
temp_values = [child for key, children in kegg_connections.items()
               if key != 'y_hat' for child in children]
_child_names = set(temp_values)  # O(1) membership instead of scanning the list

kegg_connections['y_hat'] = [key for key in kegg_connections
                             if key != 'y_hat' and key not in _child_names]

In [None]:
# The full graph is too big to render, so it is only drawn when vis_dot_bool is set.
dot = ''
if vis_dot_bool:
    dot = Digraph()
    for node, children in tqdm(kegg_connections.items()):
        dot.node(node)
        # edge(tail, head): arrows point from each child to the node that consumes it
        for child in children:
            dot.edge(child, node)

dot

Version with the node names replaced by integer IDs to keep the rendered graph small

In [None]:
dot = ''
if vis_dot_bool:
    # Replace each node name with its position (as a string); unknown child
    # names are kept as they are.
    name_to_num_dict = {name: str(num) for num, name in enumerate(kegg_connections)}

    temp = {
        name_to_num_dict[node]: [name_to_num_dict.get(child, child) for child in children]
        for node, children in kegg_connections.items()
    }

    dot = Digraph()
    for node in tqdm(temp.keys()):
        dot.node(node)
        # edge(tail, head): child -> consuming node
        for child in temp[node]:
            dot.edge(child, node)

    # dot.render(directory=cache_path, view=True)
dot

### Setup to build the graph

In [None]:
# start by finding the top level -- all those keys which are themselves not values
res = find_uniq_keys_values(input_dict = kegg_connections)
all_keys = res['all_keys']
all_values = res['all_values']

# Reuse the graph helpers instead of re-implementing them inline:
# outputs are keys that never appear as values; inputs are values that never appear as keys.
output_nodes = find_top_nodes(all_key_value_dict = res)
input_nodes = find_input_nodes(all_key_value_dict = res)

# (output_nodes, input_nodes)

In [None]:
# Find the run order, going from nodes with many dependencies to none.
import warnings

temp = kegg_connections.copy()  # shallow copy: only keys are removed below

# data/leaf nodes (values that never appear as keys)
no_dependants = find_input_nodes(all_key_value_dict = find_uniq_keys_values(input_dict = temp))

# Repeatedly peel off the current top level (keys that no remaining key lists
# as a child). Each pass removes at least one key, so no iteration cap is
# needed (the previous cap of 100 could silently truncate deep graphs).
dependancy_order = []
while True:
    top_nodes = find_top_nodes(all_key_value_dict = find_uniq_keys_values(input_dict = temp))
    if not top_nodes:
        break
    dependancy_order += top_nodes
    for key in top_nodes:
        temp.pop(key)

# Whatever is left is in, or only reachable through, a cycle and never becomes
# a top node, so it is missing from the run order. Make that visible.
if temp:
    warnings.warn(f'{len(temp)} node(s) could not be ordered (cycle in the graph?): '
                  f'{list(temp)[:10]}')

# dependancy_order

In [None]:
# Flip the top-down order in place so that children run before their parents.
# NOTE: this is in-place, so running the cell twice restores the top-down order.
dependancy_order[:] = dependancy_order[::-1]

In [None]:
# Add a node for each data input that only flattens that input; these run first.
# Entries already present are filtered out so re-running this cell does not
# prepend the input nodes a second time (on the first run there are none).
_input_node_set = set(input_nodes)
dependancy_order = input_nodes + [e for e in dependancy_order if e not in _input_node_set]

for key in input_nodes:
    # Input nodes have no child nodes; NeuralNetwork.forward() feeds them
    # directly from the list of input tensors instead.
    kegg_connections[key] = []

In [None]:
# Build a dict from the node names in `no_dependants` to their list index in
# `ACGT_gene_slice_list`. That list is aligned with `parsed_kegg_gene_entries`
# (length asserted when loading), not with the filtered `kegg_gene_brite`, so
# enumerate the full list and apply the same BRITE filters. When the filters
# drop nothing, the indices are identical to positions in `kegg_gene_brite`.
# NOTE(review): keys are the raw last element of each entry's first BRITE path,
# while graph node names were standardized (':'->'-', title-cased, spaces
# removed); lookups only succeed where the two agree — confirm.
brite_node_to_list_idx_dict = {}
for i, entry in enumerate(tqdm(parsed_kegg_gene_entries)):
    if 'BRITE' not in entry.keys() or entry['BRITE']['BRITE_PATHS'] == []:
        continue
    brite_node_to_list_idx_dict[str(entry['BRITE']['BRITE_PATHS'][0][-1])] = i

In [None]:
# Genotype slice for each data (leaf) node, keyed by node name.
input_tensor_dict = {
    node: ACGT_gene_slice_list[brite_node_to_list_idx_dict[node]]
    for node in no_dependants
}

In [None]:
# Expected output width of every node.
#==NOTE! This assumes only dense connections!==
# A fixed default could be replaced by a sort of "distance from output" measure.
output_size_dict = {node: default_output_size for node in dependancy_order}
output_size_dict['y_hat'] = 1  # single prediction per observation

In [None]:
# Dropout probability per node (same default everywhere).
dropout_pr_dict = {node: default_dropout_pr for node in dependancy_order}
dropout_pr_dict['y_hat'] = 0 # not required, output node is purely linear without dropout

In [None]:
# Number of Linear/ReLU/Dropout repeats per node (same default everywhere).
block_rep_dict = {node: default_block_reps for node in dependancy_order}
block_rep_dict['y_hat'] = 1 # not required, output node is purely linear. Not a linear block

In [None]:
# CHANNEL AWARE VERSION -----------------------------------------------------------------------------------
# Input width of each node = sum of its children's output widths.
# A child that is a data tensor contributes its flattened per-observation width,
# i.e. the product of every dim after the first, which is exactly what
# nn.Flatten() (start_dim=1) produces. That equals shape[-1] for 2-D tensors
# (no channel dim) and shape[1]*shape[2] for 3-D tensors (channel dim), and
# also works for more dims instead of silently leaving names unresolved.
#==NOTE! This assumes only dense connections!==
input_size_dict = kegg_connections.copy()

tensor_ndim = len(input_tensor_dict[list(input_tensor_dict.keys())[0]].shape)

def _flat_width(arr):
    """Width of `arr` after nn.Flatten(): product of all dims but the first."""
    shape = tuple(arr.shape)
    if len(shape) < 2:
        raise ValueError(f'input tensors need at least 2 dims (observations, ...), got shape {shape}')
    return int(np.prod(shape[1:]))

_data_nodes = set(no_dependants)
for e in tqdm(input_size_dict.keys()):
    # replace child names with their widths, then sum into one input size
    input_size_dict[e] = np.sum([
        _flat_width(input_tensor_dict[ee]) if ee in _data_nodes else output_size_dict[ee]
        for ee in input_size_dict[e]])

In [None]:
dot = ''
if vis_dot_bool:
    dot = Digraph()
    for node, children in tqdm(kegg_connections.items()):
        # label each node with its expected input/output widths
        dot.node(node, 'in: {!s}\nout: {!s}'.format(input_size_dict[node], output_size_dict[node]))
        # edge(tail, head): child -> consuming node
        for child in children:
            dot.edge(child, node)

dot

## Set up DataSet

Now we have

- A dictionary with the connections: `kegg_connections` (the model's `example_dict` argument)
- The expected input sizes for each node: `input_size_dict` (the model's `example_dict_input_size` argument)
- A dictionary with the input tensors: `input_tensor_dict`
- A list of the input tensors' names: `no_dependants` 
- A list with the order that each module should be called: `dependancy_order`

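To have a fair test of whether the model is working, I want to ensure there is information to learn in the dataset. To this end I'm using just two genotypes. (Note: the cells below actually use the full `train_idx`/`test_idx` split, not just two genotypes.)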

In [None]:
# One float32 tensor per data node, in the same order as input_tensor_dict.
x_list_temp = [torch.from_numpy(genotype_slice).float()
               for genotype_slice in input_tensor_dict.values()]

In [None]:
# calc_cs is fit on the training rows only (so nothing leaks from the test rows);
# apply_cs then applies the result to every row of YMat.
# NOTE(review): "cs" presumably means centre/scale — confirm in EnvDL.
YMat_cs = calc_cs(YMat[train_idx])
y_cs = apply_cs(YMat, YMat_cs)

In [None]:
# Targets as float32; the trailing [:, None] dim is added when the datasets are built.
y_temp = torch.from_numpy(y_cs).float()

In [None]:
# Shape of the first input tensor (torch.Size).
x_list_temp[0].size()

## Set up NeuralNetwork, Data Loader

In [None]:
# Working version ====
# Doesn't pass output node through relu
class NeuralNetwork(nn.Module):
    """Network whose wiring follows a node -> children dependency graph.

    One entry of `nn_layer_list` is built per name in `dependancy_order`:

    * input-tensor nodes: `nn.Flatten()` (keeps dim 0, flattens the rest);
    * internal nodes: an `nn.ModuleList` of `block_reps` repeats of
      Linear -> ReLU -> Dropout;
    * `'y_hat'`: a single `nn.Linear` with no activation.

    A node's input is the concatenation (last dim) of its children's outputs,
    in the order they are listed in `example_dict`.

    Args:
        example_dict: node name -> list of child node names.
        example_dict_input_size: node name -> input width (sum of child widths).
        example_dict_output_size: node name -> output width.
        example_dict_dropout_pr: node name -> dropout probability.
        example_block_rep_dict: node name -> number of Linear/ReLU/Dropout repeats.
        input_tensor_names: names of the data input nodes; position i
            corresponds to `x[i]` in `forward`.
        dependancy_order: evaluation order; each node must come after all of
            its children.
    """
    def __init__(self,
                 example_dict, # contains the node (excluding input tensors)
                 example_dict_input_size, # contains the input sizes (including the tensors)
                 example_dict_output_size,
                 example_dict_dropout_pr,
                 example_block_rep_dict,
                 input_tensor_names,
                 dependancy_order
                ):
        super().__init__()

        def Linear_block(in_size, out_size, drop_pr, block_reps):
            # The first repeat maps in_size -> out_size; later ones out_size -> out_size.
            block_list = []
            for i in range(block_reps):
                block_list += [
                    nn.Linear(in_size if i == 0 else out_size, out_size),
                    nn.ReLU(),
                    nn.Dropout(drop_pr)]
            return nn.ModuleList(block_list)

        # fill in the list in dependancy order.
        layer_list = []
        for key in dependancy_order:
            if key in input_tensor_names:
                layer_list.append(nn.Flatten())
            elif key != 'y_hat':
                layer_list.append(
                    Linear_block(in_size=example_dict_input_size[key],
                                 out_size=example_dict_output_size[key],
                                 drop_pr=example_dict_dropout_pr[key],
                                 block_reps=example_block_rep_dict[key]))
            else:
                # output node: purely linear
                layer_list.append(
                    nn.Linear(example_dict_input_size[key],
                              example_dict_output_size[key]))

        self.nn_layer_list = nn.ModuleList(layer_list)

        # things for get_input_node in forward to work.
        self.example_dict = example_dict
        self.input_tensor_names = input_tensor_names
        self.dependancy_order = dependancy_order

        self.input_tensor_lookup = {name: i for i, name in enumerate(input_tensor_names)}
        # Outputs of the most recent forward pass (by position and by node name).
        self.result_list = []
        self.result_list_lookup = {}

    def forward(self, x):
        """Evaluate every node in dependency order.

        Args:
            x: list of input tensors, ordered like `input_tensor_names`.

        Returns:
            The output of the `'y_hat'` node.
        """
        # Layer index for each node: its first position in the instance's own
        # order. Reading a notebook global here broke models reloaded elsewhere,
        # and the old per-node list scan was quadratic in the number of nodes.
        layer_idx = {}
        for i, key in enumerate(self.dependancy_order):
            layer_idx.setdefault(key, i)

        def get_input_node(input_node):
            return self.result_list[self.result_list_lookup[input_node]]

        # Fresh containers every call; intermediate outputs remain available on
        # the instance after forward returns.
        self.result_list = []
        self.result_list_lookup = {}
        for key in self.dependancy_order:
            layer = self.nn_layer_list[layer_idx[key]]
            self.result_list_lookup[key] = len(self.result_list_lookup)

            if key in self.input_tensor_lookup:  # data node: just flatten its tensor
                out = layer(x[self.input_tensor_lookup[key]]).clone()
            else:
                out = torch.concat(
                    [get_input_node(e) for e in self.example_dict[key]], -1)
                if key != 'y_hat':
                    for module in layer:  # Linear/ReLU/Dropout repeats
                        out = module(out)
                else:
                    out = layer(out).clone()

            self.result_list.append(out)

        return self.result_list[self.result_list_lookup['y_hat']]

In [None]:
# model = NeuralNetwork(example_dict = kegg_connections, 
#                       example_dict_input_size = input_size_dict,
#                       example_dict_output_size = output_size_dict,
#                       example_dict_dropout_pr= dropout_pr_dict,
#                       example_block_rep_dict = block_rep_dict,
#                       input_tensor_names = list(input_tensor_dict.keys()),
#                       dependancy_order = dependancy_order) 
# model = model.to(device)
# model(next(iter(training_dataloader))[1])
# model

In [None]:
class ListDataset(Dataset): # for any G containing matrix with many (phno) to one (geno)
    """Dataset pairing each phenotype observation with its genotype tensors.

    Many phenotype rows can share one genotype; `obs_geno_lookup` maps an
    observation index to its genotype row (column 1).

    Args:
        y: targets, indexed by position 0..len(y)-1.
        x_list: list of genotype tensors, each indexed by genotype row on dim 0.
        obs_idxs: original observation index for each position of `y`, so `y`
            can be passed pre-subset (e.g. to a train split).
        obs_geno_lookup: 2-D array whose column 1 holds the genotype row for
            an observation index.
        transform: optional callable applied to each x tensor.
        target_transform: optional callable applied to the y value.
        **kwargs: `device` may be given to record where the data lives; other
            keys are ignored.
    """
    def __init__(self,
                 y,
                 x_list,
                 obs_idxs, # this is a list of the indexes used. It allows us to pass in smaller
                           # tensors and then get the right genotype
                 obs_geno_lookup,
                 transform = None, target_transform = None,
                 **kwargs
                ):
        # Informational only; previously read from a notebook global.
        self.device = kwargs.get('device', getattr(y, 'device', None))
        self.y = y
        self.x_list = x_list
        self.obs_idxs = obs_idxs
        self.obs_geno_lookup = obs_geno_lookup
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        y_idx = self.y[idx]

        # Genotype row for this observation, via this dataset's own lookup
        # (the previous version read the notebook-global `obs_geno_lookup`).
        idx_geno = self.obs_geno_lookup[self.obs_idxs[idx], 1]
        x_idx = [x[idx_geno, ] for x in self.x_list]

        # target_transform applies to y and transform to x, each independently
        # (previously both used `transform` and both were gated on target_transform).
        if self.target_transform:
            y_idx = self.target_transform(y_idx)
        if self.transform:
            x_idx = [self.transform(x) for x in x_idx]

        return y_idx, x_idx

In [None]:
# Move the genotype tensors to `device` once and share them between both
# datasets: calling .to() separately per dataset made two full device copies.
# Using `device` instead of a hard-coded 'cuda' also lets this run on CPU.
x_list_device = [e.to(device) for e in x_list_temp]

training_dataloader = DataLoader(ListDataset(
        y = y_temp[train_idx][:, None].to(device),  # add a trailing target dim
        x_list = x_list_device,
        obs_idxs = train_idx,
        obs_geno_lookup = obs_geno_lookup
    ),
    batch_size = dataloader_batch_size,
    shuffle = True
)

validation_dataloader = DataLoader(ListDataset(
        y = y_temp[test_idx][:, None].to(device),
        x_list = x_list_device,
        obs_idxs = test_idx,
        obs_geno_lookup = obs_geno_lookup
    ),
    batch_size = dataloader_batch_size,
    shuffle = False
)

# next(iter(training_dataloader))[1]

In [None]:
# next(iter(training_dataloader))[1].shape

In [None]:
model = NeuralNetwork(example_dict = kegg_connections,
                      example_dict_input_size = input_size_dict,
                      example_dict_output_size = output_size_dict,
                      example_dict_dropout_pr= dropout_pr_dict,
                      example_block_rep_dict = block_rep_dict,
                      input_tensor_names = list(input_tensor_dict.keys()),
                      dependancy_order = dependancy_order)

# Use the device chosen in the setup cell rather than hard-coding 'cuda'.
model.to(device)
# LSUV_(model, data = next(iter(training_dataloader))[1])
# model(next(iter(training_dataloader))[1])

In [None]:
class plVNN(pl.LightningModule):
    """Minimal LightningModule that trains `mod` with an MSE loss.

    Args:
        mod: network that takes the list of input tensors of a batch and
            returns predictions shaped like the targets.
    """
    def __init__(self, mod):
        super().__init__()
        self.mod = mod

    def _shared_step(self, batch):
        # batch is (y, [x_1, ..., x_k]) as produced by ListDataset
        y_i, xs_i = batch
        pred = self.mod(xs_i)
        return F.mse_loss(pred, y_i), y_i.shape[0]

    def training_step(self, batch, batch_idx):
        loss, n_obs = self._shared_step(batch)
        # Give the batch size explicitly: the batch is a nested collection with
        # one tensor per input node, which Lightning would otherwise scan to infer it.
        self.log("train_loss", loss, batch_size=n_obs)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, n_obs = self._shared_step(batch)
        self.log('val_loss', loss, batch_size=n_obs)

    def configure_optimizers(self, **kwargs):
        """Adam over all parameters; kwargs (e.g. lr, betas) are passed to torch.optim.Adam."""
        optimizer = torch.optim.Adam(self.parameters(), **kwargs)
        return optimizer

In [None]:

# Epoch budget. Timing note: 2 epochs took 58 minutes, so a 3 day weekend would
# allow about 24*3*2 = 144 epochs (set max_epoch = 150 for such a long run).
max_epoch = 20

In [None]:
logger = TensorBoardLogger("tb_vnn_logs", name=save_prefix+'_no_LSUV')

# Wrap the network for Lightning.
VNN = plVNN(model)
# Lightning calls configure_optimizers() itself inside trainer.fit(); this explicit
# call only builds a standalone optimizer that the Trainer does not use.
# Earlier variant: VNN.configure_optimizers(lr = lr, betas=(beta1, beta2))
optimizer = VNN.configure_optimizers()

trainer = pl.Trainer(max_epochs=max_epoch, logger=logger)
trainer.fit(model=VNN, train_dataloaders=training_dataloader, val_dataloaders=validation_dataloader)

In [None]:
# Whole-module pickle: reloading it needs this notebook's NeuralNetwork class
# to be importable under the same name.
torch.save(VNN.mod, cache_path+'vnn'+'.pt')
# Portable copy of just the weights; load with
# `model.load_state_dict(torch.load(...))` into a freshly built NeuralNetwork.
torch.save(VNN.mod.state_dict(), cache_path+'vnn_state_dict'+'.pt')

In [None]:
# Hard-exit the kernel process (no cleanup handlers run) so the GPU is released.
os._exit(0)