
### *Precision medicine aims to tailor medical treatments to each patient's unique characteristics. Understanding drug-gene interactions is crucial for this goal. These interactions help identify the most effective drugs for individual patients based on their genetic makeup. Genetic variability greatly affects how people respond to drugs, including differences in drug metabolism, transport, and target proteins.*

**We explore training infinitely wide Bayesian neural networks with the Neural Tangents library.**

JAX is an open-source Python library developed by Google Research for high-performance numerical computing and machine learning. Its API closely mirrors NumPy (jax.numpy). It adds composable function transformations for automatic differentiation (jax.grad), just-in-time compilation (jax.jit), and vectorization (jax.vmap), and it runs on CPU, GPU, and TPU.
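A minimal sketch of the NumPy-like API combined with the grad and jit transformations described above. The loss function and the toy arrays are hypothetical illustrations and are not part of the original notebook.

```python
import jax
import jax.numpy as jnp

# Hypothetical example: mean squared error of a linear model, written with jax.numpy.
def mse_loss(w, x, y):
    pred = x @ w
    return jnp.mean((pred - y) ** 2)

# jax.grad differentiates with respect to the first argument (w);
# jax.jit compiles the resulting gradient function with XLA.
grad_fn = jax.jit(jax.grad(mse_loss))

x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([1.0, 2.0])
w = jnp.zeros(2)

g = grad_fn(w, x, y)
print(g)  # expected gradient at w = 0: (-7, -10)
```

The same code runs unchanged on a GPU or TPU when one is available; JAX places arrays on the default device automatically.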

In [None]:
!pip install jax jaxlib --upgrade



Neural Tangents is a Python library built on top of JAX that provides tools for analyzing the behavior of neural networks, particularly in the infinite-width limit.

In [None]:
!pip install neural-tangents



In [None]:

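from jax import random

# Synthetic dataset: random features stand in for drug-gene interaction profiles.
# Note: key1 is reused for all three arrays below (key2 is unused), so the arrays
# are not independent draws; splitting a separate key per array is better practice.
key1, key2 = random.split(random.PRNGKey(1))
Drug_set1 = random.normal(key1, (20, 100))  # 20 drugs x 100 gene features, no missing interactions (training inputs)
Drug_set2 = random.normal(key1, (50, 100))  # 50 drugs whose interaction with one particular gene is missing (test inputs)
specific_interaction_set1 = random.uniform(key1, shape=(20, 1))  # known values of that interaction for set 1 (training targets)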
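Because the synthetic-data cell reuses key1 for all three arrays, they are not independent random draws. A minimal sketch of the same setup with one key per array; the names x_train, x_test, and y_train are hypothetical, and the resulting values differ from the original arrays (and therefore from the predictions printed later in this notebook).

```python
from jax import random

# Hypothetical variant: split one key per array so the three arrays are independent.
key_train, key_test, key_target = random.split(random.PRNGKey(1), 3)

n_train, n_test, n_genes = 20, 50, 100
x_train = random.normal(key_train, (n_train, n_genes))     # drugs with complete interaction data
x_test = random.normal(key_test, (n_test, n_genes))        # drugs whose target-gene interaction is missing
y_train = random.uniform(key_target, shape=(n_train, 1))   # known interaction scores in [0, 1)

print(x_train.shape, x_test.shape, y_train.shape)
```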


We define the network with neural_tangents.stax, a library that closely mirrors JAX's stax library. In JAX's stax, each layer is a pair of functions, init_fn and apply_fn: init_fn randomly initializes the parameters for a given input shape, and apply_fn computes the layer's outputs for specific inputs and parameters.

In neural_tangents.stax, each layer is instead a triplet of functions: init_fn, apply_fn, and kernel_fn. The first two behave like their stax counterparts. The third, kernel_fn, analytically computes the infinite-width kernels corresponding to the layer: the Gaussian process (NNGP) kernel and the neural tangent kernel (NTK).

In [None]:
from jax import random
from neural_tangents import stax

init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Dense(512), stax.Relu(),
    stax.Dense(512), stax.Relu(),
    stax.Dense(1)
)

Next, we explore the exact prior over functions in the infinite-width limit using kernel_fn. The call kernel = kernel_fn(k_1, k_2, get) computes the kernel between two sets of inputs, k_1 and k_2, where get selects which kernel to return. kernel_fn can compute two distinct kernels: the NNGP (Neural Network Gaussian Process) kernel, which characterizes the Bayesian infinite-width network, and the NTK (Neural Tangent Kernel), which describes how the network's outputs evolve under gradient descent training.


*   k_1 and k_2 could both be the training data, if you want to compute the kernel matrix among all training examples.
*   k_1 could be the training data and k_2 could be the test data, if you want to compute the cross-kernel matrix between training and test examples.
*   k_1 and k_2 could be entirely different sets of data, unrelated to training or testing, for some other analysis or comparison.

In [None]:
## Compute the cross-kernel matrices between training (Drug_set1) and test (Drug_set2) examples.
import jax.numpy as jnp
nngp = kernel_fn(Drug_set1, Drug_set2, 'nngp')  # shape (20, 50)
ntk = kernel_fn(Drug_set1, Drug_set2, 'ntk')    # shape (20, 50)

The nt.predict.gradient_descent_mse_ensemble function belongs to the predict module of the Neural Tangents library (imported as nt). Rather than actually running gradient descent, it computes in closed form the predictions of an infinite ensemble of infinitely wide networks trained by gradient descent on a mean squared error (MSE) loss.

The predict_fn it returns supports two modes. In "NNGP" mode it computes the Bayesian posterior, which is equivalent to gradient descent training with all but the last-layer weights frozen. In "NTK" mode it computes the distribution of networks after gradient descent training of all layers.

In [None]:
import neural_tangents as nt

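predict_fn = nt.predict.gradient_descent_mse_ensemble(kernel_fn, Drug_set1, specific_interaction_set1)

# NNGP (Bayesian posterior) mean for the missing interactions:
# missing_interaction_nngp = predict_fn(x_test=Drug_set2, get='nngp')

# NTK mean after (infinitely long) gradient descent training; shape (50, 1)
missing_interaction_ntk = predict_fn(x_test=Drug_set2, get='ntk')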



In [None]:
print(missing_interaction_ntk)

[[0.49183676]
 [0.4087034 ]
 [0.42373955]
 [0.39780936]
 [0.4478033 ]
 [0.39586216]
 [0.4310285 ]
 [0.41442433]
 [0.3682715 ]
 [0.4265475 ]
 [0.41246817]
 [0.37055683]
 [0.34291947]
 [0.42827386]
 [0.39322692]
 [0.42034617]
 [0.46186164]
 [0.3531958 ]
 [0.29335907]
 [0.48094597]
 [0.4656588 ]
 [0.3632041 ]
 [0.3443497 ]
 [0.36552286]
 [0.4129476 ]
 [0.37792048]
 [0.35632676]
 [0.40715608]
 [0.41331872]
 [0.35961902]
 [0.43581748]
 [0.4082433 ]
 [0.42031568]
 [0.3894707 ]
 [0.29209775]
 [0.45162088]
 [0.31991184]
 [0.41406566]
 [0.44170672]
 [0.38474664]
 [0.31312734]
 [0.38084272]
 [0.43842906]
 [0.27906573]
 [0.46072945]
 [0.33172214]
 [0.44671074]
 [0.3238895 ]
 [0.3953954 ]
 [0.3058522 ]]
