# Testing TRIP analysis with ChiRho

## Toy model network

In [3]:
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

from chirho_diffeqpy import DiffEqPy, ATempParams
from chirho.dynamical.handlers.solver import TorchDiffEq
from chirho.dynamical.ops import simulate
from chirho.dynamical.handlers.trajectory import LogTrajectory
from chirho.dynamical.handlers import StaticBatchObservation, StaticIntervention, DynamicIntervention
from chirho.counterfactual.handlers import MultiWorldCounterfactual
from chirho.observational.handlers import condition
from chirho.dynamical.ops import State
import numpy as np
import torch
from pyro import sample, set_rng_seed
from pyro.distributions import Uniform, Poisson
from typing import Tuple, Optional, Union, Dict
from functools import partial
import matplotlib.pyplot as plt
import pyro
from pyro.infer.autoguide import AutoDelta, AutoMultivariateNormal
from contextlib import nullcontext
pyro.settings.set(module_local_params=True)

# Define types for clarity
ArrayLike = Union[np.ndarray, torch.Tensor]

# Define the pure dynamics function for a simple gene regulatory network
def gene_network_dynamics(state: State[ArrayLike], atemp_params: Dict[str, ArrayLike]) -> State[ArrayLike]:
    # State contains gene expression levels
    gene1 = state["gene1"]  # Activator
    gene2 = state["gene2"]  # Repressed by gene1
    gene3 = state["gene3"]  # Activated by gene2
    
    # Parameters
    production_rates = atemp_params["production_rates"]  # Basal production rates
    degradation_rates = atemp_params["degradation_rates"]  # Degradation rates
    activation_strength = atemp_params["activation_strength"]  # How strongly gene2 activates gene3
    repression_strength = atemp_params["repression_strength"]  # How strongly gene1 represses gene2
    
    # Dynamics (change in expression levels)
    dgene1_dt = production_rates[0] - degradation_rates[0] 
    
    # Gene2 is repressed by gene1 (higher gene1 means lower production of gene2)
    dgene2_dt = production_rates[1] / (1 + repression_strength * gene1) - degradation_rates[1] * gene2
    
    # Gene3 is activated by gene2 (higher gene2 means higher production of gene3)
    dgene3_dt = production_rates[2] * (activation_strength * gene2) / (1 + activation_strength * gene2) - degradation_rates[2] * gene3
    
    return dict(gene1=dgene1_dt, gene2=dgene2_dt, gene3=dgene3_dt)

# Define a prior over parameters
def prior():
    # Sample parameters from priors
    production_rates = torch.stack([
        sample("prod_rate_1", Uniform(0.5, 1.5)),
        sample("prod_rate_2", Uniform(0.5, 1.5)),
        sample("prod_rate_3", Uniform(0.5, 1.5))
    ])
    
    degradation_rates = torch.stack([
        sample("deg_rate_1", Uniform(0.1, 0.3)),
        sample("deg_rate_2", Uniform(0.1, 0.3)),
        sample("deg_rate_3", Uniform(0.1, 0.3))
    ])
    
    activation_strength = sample("activation", Uniform(1.0, 3.0))
    repression_strength = sample("repression", Uniform(1.0, 3.0))
    
    # Initial state
    initial_state = dict(
        gene1=sample("initial_gene1", Uniform(0.1, 1.0)),
        gene2=sample("initial_gene2", Uniform(0.1, 1.0)),
        gene3=sample("initial_gene3", Uniform(0.1, 1.0))
    )
    
    # Parameters
    atemp_params = dict(
        production_rates=production_rates,
        degradation_rates=degradation_rates,
        activation_strength=activation_strength,
        repression_strength=repression_strength
    )
    
    return initial_state, atemp_params

# Setup for running the simulation
def simulate_gene_network(solver, times):
    from chirho.dynamical.handlers import simulate
    
    with solver:
        # Run the model with the solver
        initial_state, atemp_params = prior()
        trajectory = simulate(
            gene_network_dynamics,
            initial_state,
            atemp_params,
            times
        )
    
    return trajectory

# Plot function
def plot_trajectory(times, trajectory):
    plt.figure(figsize=(10, 6))
    plt.plot(times, trajectory["gene1"], label="Gene 1 (Activator)")
    plt.plot(times, trajectory["gene2"], label="Gene 2 (Repressed by Gene 1)")
    plt.plot(times, trajectory["gene3"], label="Gene 3 (Activated by Gene 2)")
    plt.xlabel("Time")
    plt.ylabel("Gene Expression Level")
    plt.title("Gene Regulatory Network Simulation")
    plt.legend()
    plt.grid(True)
    plt.show()

In [4]:
# Setup time points and solvers
from chirho.dynamical.solvers import TorchDiffEq

# Time points for simulation
times = torch.linspace(0, 20, 100)

# Create solver
torchdiffeq_solver = TorchDiffEq(rtol=1e-5, atol=1e-7, method="dopri5")

# Set seed for reproducibility
pyro.set_rng_seed(42)

# Run simulation
trajectory = simulate_gene_network(torchdiffeq_solver, times)

# Plot results
plot_trajectory(times, trajectory)

ModuleNotFoundError: No module named 'chirho.dynamical.solvers'