The Idea
==========

The idea is that the Gene Regulatory Network (GRN), although highly complex, can be approximated by a simple neural network.

Flax is a neural-network library built on top of JAX that makes it easy to create neural networks. This is what the basic workflow looks like:

In [6]:
import sys

sys.path.insert(0, '../src')

import jax
import jax.numpy as jnp
from jax import random, jit
import matplotlib.pyplot as plt
import seaborn as sns

from neural_network import (
    init_params, get_regulatory_function,
    visualise_network_function, RegulatoryNetwork
)

from dynamics import (
    apply_threshold, get_neighbor_average,
)

# Plotting style
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (12, 4)
plt.rcParams['figure.dpi'] = 100

print("✓ Imports successful")
print(f"JAX version: {jax.__version__}")
print(f"JAX devices: {jax.devices()}")

✓ Imports successful
JAX version: 0.4.35
JAX devices: [CpuDevice(id=0)]
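
The `get_neighbor_average` helper imported above lives in `../src/dynamics.py`, which is not shown in this notebook. As a rough illustration of what such a helper could do, here is a minimal sketch that averages the left and right neighbor of each cell. The name `neighbor_average_sketch` is hypothetical, and the periodic boundary handling is an assumption; the real implementation may differ.

```python
import jax.numpy as jnp


def neighbor_average_sketch(pattern):
    """Mean of the left and right neighbor of each cell.

    Assumption: periodic boundaries, i.e. the first and last cells
    are treated as neighbors of each other.
    """
    pattern = jnp.asarray(pattern, dtype=jnp.float32)
    left = jnp.roll(pattern, 1)    # left[i]  = pattern[i - 1]
    right = jnp.roll(pattern, -1)  # right[i] = pattern[i + 1]
    return 0.5 * (left + right)


pattern = jnp.array([1, 0, 1, 0, 1, 0, 1, 0])
avg = neighbor_average_sketch(pattern)
print(avg)  # values: 0, 1, 0, 1, 0, 1, 0, 1
```

For an alternating pattern, every cell holding a 1 sits between two 0s and vice versa, so the averaged pattern is the original one with 0 and 1 swapped.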


In [13]:
# We create a neural network from flax like this:

model = RegulatoryNetwork()

# We have to decide what the input looks like. Let's define a dummy pattern and apply the neighbor averaging:
pattern = jnp.array([1, 0, 1, 0, 1, 0, 1, 0])
neighbor_average = get_neighbor_average(pattern)

# Now we create a random seed and key:
seed = 42
key = random.PRNGKey(seed)

key, subkey = random.split(key)
# Initialize with input shape (1,) since the network processes one scalar at a time
params = init_params(model, subkey, (1,))

# Now we can access the (randomly initialised) network function:
regulatory_function = get_regulatory_function(model, params)
print(f"Regulatory function output: {regulatory_function(neighbor_average)}")

Regulatory function output: [ 0.         -0.09629551  0.         -0.09629551  0.         -0.09629551
  0.         -0.09629551]

