In [1]:
from models import *
import sys

import numpy as np
import pandas as pd

import sage

# NOTE(review): `tf` is not imported explicitly in this notebook; it presumably
# comes from `from models import *` above — confirm and import it directly.
physical_devices = tf.config.list_physical_devices('GPU')
print("GPUs Available: ", physical_devices, flush=True)

# Enable on-demand memory growth on every detected GPU.
for gpu in physical_devices:
    tf.config.experimental.set_memory_growth(gpu, True)

# Make every GPU except the first one (index 0) visible to TensorFlow.
# With a single GPU this leaves no GPU visible, i.e. TensorFlow runs on CPU.
tf.config.set_visible_devices(physical_devices[1:], 'GPU')

############################################
# Define helper functions
############################################

def _write_sage_csv(labels, label_col, sage_vals, path):
    """Write labels plus a SAGE explanation's values and std to a three-column CSV.

    Columns are ``label_col`` (from ``labels``), ``"sage_value"`` (from
    ``sage_vals.values``) and ``"sage_value_sd"`` (from ``sage_vals.std``).
    The parent directory of ``path`` is created if it does not exist.
    """
    import warnings
    from pathlib import Path

    labels = list(labels)
    values = list(sage_vals.values)
    stds = list(sage_vals.std)
    if not len(labels) == len(values) == len(stds):
        # The CSV is still written (shorter columns are NaN-padded, as the
        # previous row-stacking construction did), but labels and values may
        # no longer correspond, so make the mismatch visible.
        warnings.warn(
            f"SAGE output for {path}: {len(labels)} labels, {len(values)} values, "
            f"{len(stds)} std entries; rows may be mislabelled",
            stacklevel=3,
        )
    # Series columns align on their index: unequal lengths pad with NaN instead
    # of raising, and each column keeps its own dtype (not all-object).
    res = pd.DataFrame({
        label_col: pd.Series(labels, dtype=object),
        "sage_value": pd.Series(values),
        "sage_value_sd": pd.Series(stds),
    })
    # SAGE runs are slow; create the output directory up front instead of
    # losing already-computed values to a missing-directory error in to_csv.
    Path(path).parent.mkdir(parents=True, exist_ok=True)
    res.to_csv(path, index=False)


def save_sages(sage_vals, path, feature_ids=None):
    """Save per-metabolite SAGE values to ``path`` as CSV.

    Columns: metabolite_id, sage_value, sage_value_sd.

    Parameters
    ----------
    sage_vals : object
        SAGE explanation exposing ``.values`` and ``.std``.
    path : str
        Output CSV path; its parent directory is created if needed.
    feature_ids : sequence, optional
        Metabolite ID for each value. Defaults to the notebook-level
        ``data_cols`` list (the original behaviour).
    """
    ids = data_cols if feature_ids is None else feature_ids
    _write_sage_csv(ids, "metabolite_id", sage_vals, path)


def save_pw_sages(sage_vals, pw_groupnames, path):
    """Save pathway-level SAGE values to ``path`` as CSV.

    Columns: pathway_name (from ``pw_groupnames``), sage_value, sage_value_sd.
    The parent directory of ``path`` is created if needed.
    """
    _write_sage_csv(pw_groupnames, "pathway_name", sage_vals, path)

def get_sage_pws(model, ref_data, test_data, dim, pw_groups, pw_groupnames, pw_type,
                 n_background=10, batch_size=512, out_dir="results/sage_values/VAE"):
    """Compute grouped (pathway-level) SAGE values for one latent dimension and save them.

    Parameters
    ----------
    model : object
        Must expose ``encode_mu(data)``; column ``dim`` of its output is explained.
    ref_data : array-like
        Reference rows; the first ``n_background`` go to the marginal imputer.
    test_data : array-like
        Rows on which SAGE values are estimated.
    dim : int
        Latent dimension index. Dimension 0 is written to disk as dimension 18.
    pw_groups : list of list of int
        Column indices making up each pathway group.
    pw_groupnames : sequence of str
        Labels written alongside the SAGE values.
    pw_type : {"super", "sub"}
        Selects the output file prefix (``superpw`` or ``subpw``).
    n_background : int, default 10
        Number of ``ref_data`` rows handed to the imputer.
    batch_size : int, default 512
        Batch size passed to the permutation sampler.
    out_dir : str, default "results/sage_values/VAE"
        Directory for the CSV output.

    Returns
    -------
    The explanation object returned by the ``sage.PermutationSampler`` call.

    Raises
    ------
    ValueError
        If ``pw_type`` is not "super" or "sub". This is checked before any
        computation; previously an unknown value ran the full (slow) SAGE
        computation and then silently saved nothing.
    """
    if pw_type not in ("super", "sub"):
        raise ValueError(f"pw_type must be 'super' or 'sub', got {pw_type!r}")

    def get_dim_vals(dat):
        # The dim:dim+1 slice keeps a trailing axis of length 1.
        # assumes encode_mu returns a 2-D (n_samples, latent_dim) output — TODO confirm
        return model.encode_mu(dat)[:, dim:(dim + 1)]

    # The SAGE target is the model's own prediction on test_data, so the 'mse'
    # loss measures how far imputed predictions move away from it.
    dim_output = get_dim_vals(test_data)

    # NOTE: any callable function that returns a prediction is allowed in PermutationSampler
    imputer_pw = sage.GroupedMarginalImputer(ref_data[:n_background], pw_groups)
    sampler_pw = sage.PermutationSampler(get_dim_vals, imputer_pw, 'mse')
    sage_values_pw = sampler_pw(test_data, Y=dim_output, batch_size=batch_size)

    # rename dimension 0 to dimension 18
    dim_idx = 18 if dim == 0 else dim

    save_pw_sages(sage_values_pw, pw_groupnames,
                  f"{out_dir}/{pw_type}pw_dim_{dim_idx}.csv")

    return sage_values_pw


def get_sage_mets(model, ref_data, test_data, dim, n_background=10, batch_size=10,
                  out_dir="results/sage_values/VAE"):
    """Compute per-metabolite SAGE values for one latent dimension and save them.

    Parameters
    ----------
    model : object
        Must expose ``encode_mu(data)``; column ``dim`` of its output is explained.
    ref_data : array-like
        Reference rows; the first ``n_background`` go to the marginal imputer.
    test_data : array-like
        Rows on which SAGE values are estimated.
    dim : int
        Latent dimension index. Dimension 0 is written to disk as dimension 18.
    n_background : int, default 10
        Number of ``ref_data`` rows handed to the imputer.
    batch_size : int, default 10
        Batch size passed to the permutation sampler.
    out_dir : str, default "results/sage_values/VAE"
        Directory for the CSV output (``met_dim_<idx>.csv``).

    Returns
    -------
    The explanation object returned by the ``sage.PermutationSampler`` call.
    """

    def get_dim_vals(dat):
        # The dim:dim+1 slice keeps a trailing axis of length 1.
        # assumes encode_mu returns a 2-D (n_samples, latent_dim) output — TODO confirm
        return model.encode_mu(dat)[:, dim:(dim + 1)]

    # The SAGE target is the model's own prediction on test_data.
    dim_output = get_dim_vals(test_data)

    # NOTE: any callable function that returns a prediction is allowed in PermutationSampler
    imputer = sage.MarginalImputer(ref_data[:n_background])
    sampler = sage.PermutationSampler(get_dim_vals, imputer, 'mse')
    sage_values = sampler(test_data, Y=dim_output, batch_size=batch_size)

    # rename dimension 0 to dimension 18
    dim_idx = 18 if dim == 0 else dim

    save_sages(sage_values, f"{out_dir}/met_dim_{dim_idx}.csv")

    return sage_values


def get_groups(pw_name, annos=None, all_features=None, column_order=None, id_col='COMP_IDstr'):
    """Build SAGE feature groups from a pathway annotation column.

    Parameters
    ----------
    pw_name : str
        Annotation column to group by (e.g. 'SUPER_PATHWAY', 'SUB_PATHWAY').
    annos : DataFrame, optional
        Annotation table; defaults to the notebook-level ``met_annos``.
    all_features : sequence, optional
        Feature IDs checked for pathway membership; defaults to ``feature_names``.
    column_order : sequence, optional
        Feature IDs in model-input column order, used to turn IDs into column
        indices; defaults to ``data_cols``.
    id_col : str, default 'COMP_IDstr'
        Column of ``annos`` holding the feature IDs.

    Returns
    -------
    dict
        'feature_groups': {pathway: [feature ids]} in pandas groupby order
            (rows with a missing pathway are dropped by groupby's default).
        'group_names': pathway names, followed by every ID in ``all_features``
            that belongs to no pathway.
        'groups': one list of column indices per pathway, aligned with the
            first ``len(feature_groups)`` entries of 'group_names'.

    Raises
    ------
    ValueError
        If a grouped feature ID does not appear in ``column_order``.
    """
    annos = met_annos if annos is None else annos
    all_features = feature_names if all_features is None else all_features
    column_order = data_cols if column_order is None else column_order

    # Feature groups
    feature_groups = annos.groupby(pw_name)[id_col].apply(list).to_dict()

    # Set membership instead of scanning every group's list once per feature.
    grouped = {feat for members in feature_groups.values() for feat in members}
    group_names = list(feature_groups) + [col for col in all_features if col not in grouped]

    # ID -> column index, keeping the first occurrence like list.index did.
    col_index = {}
    for i, col in enumerate(column_order):
        col_index.setdefault(col, i)

    # Group indices
    groups = []
    for name, members in feature_groups.items():
        missing = [feat for feat in members if feat not in col_index]
        if missing:
            raise ValueError(f"group {name!r}: feature ids not found in column order: {missing}")
        groups.append([col_index[feat] for feat in members])

    return {'feature_groups': feature_groups, 'group_names': group_names, 'groups': groups}

2022-08-09 05:09:01.433879: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.10.1
Intel(R) Extension for Scikit-learn* enabled (https://github.com/intel/scikit-learn-intelex)


GPUs Available:  []


  from .autonotebook import tqdm as notebook_tqdm
2022-08-09 05:09:03.807983: I tensorflow/compiler/jit/xla_cpu_device.cc:41] Not creating XLA devices, tf_xla_enable_xla_devices not set
2022-08-09 05:09:03.809150: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcuda.so.1
2022-08-09 05:09:03.823093: E tensorflow/stream_executor/cuda/cuda_driver.cc:328] failed call to cuInit: CUDA_ERROR_UNKNOWN: unknown error
2022-08-09 05:09:03.823131: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:169] retrieving CUDA diagnostic information for host: 8557ee92613c
2022-08-09 05:09:03.823139: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:176] hostname: 8557ee92613c
2022-08-09 05:09:03.823277: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:200] libcuda reported version is: 510.39.1
2022-08-09 05:09:03.823296: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:204] kernel reported version is: 510.39.1
2022-08-09 05:09:03.823302

In [2]:
############################################
# Initialize variables and instantiate objects
############################################

twins_path = 'data/BioBank.xlsx'

print("Read Patients")

# Open the workbook once and parse every sheet this notebook needs from it,
# rather than re-opening and re-parsing the file for each read_excel call.
with pd.ExcelFile(twins_path, engine="openpyxl") as workbook:
    twins_train_df = workbook.parse('Training Set')
    twins_test_df = workbook.parse('Testing Set')
    met_annos = workbook.parse('Metabolite Annotations')
print("Read Data")

# these arrays are used for score calculations
twins_train = twins_train_df.values
twins_test = twins_test_df.values

model_path = 'models/'

# Define paths for saved VAE model
path_vae = model_path + 'VAE.h5'
path_encoder = model_path + 'VAE_encoder.h5'
path_decoder = model_path + 'VAE_decoder.h5'

# Data & model configuration
input_dim = twins_train_df.shape[1]
params = load_vae_parameters(optimal=False)

intermediate_dim = max(params['encoder_units'], params['decoder_units'])
latent_dim = 18

kl_beta = params['kl_beta']
learning_rate = params['learning_rate']

# Training-only settings; not referenced by the SAGE cells below.
batch_size = 1024
n_epochs = 1000

# instantiate model
mtmodel = mtVAE(input_dim,
                intermediate_dim,
                latent_dim,
                kl_beta,
                learning_rate)

print("Load Model")

# load model
mtmodel.vae.load_weights(path_vae)
mtmodel.encoder.load_weights(path_encoder)
mtmodel.decoder.load_weights(path_decoder)

print("Load Weights")

# Metabolite IDs/names used to label grouped SAGE values.
data_cols = met_annos['COMP_IDstr'].to_list()
feature_names = met_annos['COMP_IDstr'].to_list()
biochemical_names = met_annos['BIOCHEMICAL'].to_list()

# SAGE groups are column indices into twins_train/twins_test, obtained by
# looking metabolite IDs up in data_cols. The annotation rows must therefore
# line up one-to-one, and in order, with the data columns.
assert len(data_cols) == input_dim, (
    f"{len(data_cols)} annotated metabolites but {input_dim} data columns; "
    "SAGE group indices would not line up with the model inputs"
)
if set(twins_train_df.columns) == set(data_cols):
    # The data columns carry the same IDs, so their order can be checked too.
    assert twins_train_df.columns.to_list() == data_cols, (
        "Training Set columns are not in annotation order; "
        "SAGE group indices would point at the wrong metabolites"
    )
print("Start SAGE")
print("Start SAGE")

Read Patients
Read Data
Load Model
Load Weights
Start SAGE


2022-08-09 05:12:49.449587: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-08-09 05:12:49.457817: I tensorflow/compiler/jit/xla_gpu_device.cc:99] Not creating XLA devices, tf_xla_enable_xla_devices not set
2022-08-09 05:12:49.472988: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:196] None of the MLIR optimization passes are enabled (registered 0 passes)
2022-08-09 05:12:49.475657: I tensorflow/core/platform/profile_utils/cpu_utils.cc:112] CPU Frequency: 2749780000 Hz


In [3]:
%%capture
############################################
# MAIN PART OF SCRIPT
############################################
from pathlib import Path

# Pathway groupings depend only on the annotation sheet, not on the latent
# dimension, so build them once instead of twice per dimension.
superpw = get_groups('SUPER_PATHWAY')
subpw = get_groups('SUB_PATHWAY')

# Iterate with `dim`, not `latent_dim`: reusing `latent_dim` as the loop
# variable overwrote the model configuration value set in the previous cell.
for dim in range(latent_dim):
    # Touch an empty "<dim>.ignore" file when work on a dimension starts
    # (presumably a progress marker for long, unattended runs).
    Path(f"{dim}.ignore").touch()

    sage_superpw_values = get_sage_pws(mtmodel, twins_train, twins_test, dim,
                                       superpw['groups'], superpw['group_names'], "super")
    sage_subpw_values = get_sage_pws(mtmodel, twins_train, twins_test, dim,
                                     subpw['groups'], subpw['group_names'], "sub")
    sage_met_values = get_sage_mets(mtmodel, twins_train, twins_test, dim)

KeyboardInterrupt: 

In [4]:
# Final cell: announce that execution has reached the end of the notebook.
print("DONE!!", end="\n")


DONE!!
