In [3]:
%reload_ext autoreload
%autoreload 2

import causal_pyro
import pyro
import torch

from pyro.distributions import Normal, Uniform

In [None]:
def ODE_example():
    """Draft sketch of an intended ODE-modelling interface (hidden Markov model with an SIR ODE latent).

    Samples the two SIR rates from ``Uniform(0, 1)``, defines the SIR vector field
    (which also records the fluxes as deterministic sites and emits a noisy
    observation of ``X.S``), then -- in draft syntax -- conditions on timed
    observations of ``S_obs`` and applies a continuous-time intervention on
    ``SI_flux`` before simulating.

    NOTE(review): ``ODE_solver``, ``condition``, ``intervene`` and ``simulate`` are
    not defined or imported in this cell; they describe a proposed API, so this
    function is not runnable as-is.

    Returns:
        Whatever ``simulate`` returns for the conditioned, intervened model.
    """
    beta = pyro.sample("beta", Uniform(0, 1))
    gamma = pyro.sample("gamma", Uniform(0, 1))

    def SIR(dX, X):
        """Write SIR derivatives of ``X`` into ``dX`` and sample a noisy ``S_obs``; returns ``X.S``."""
        # TODO: Think about how to deal with name collisions below.
        SI_flux = pyro.deterministic("SI_flux", beta * X.S * X.I)
        IR_flux = pyro.deterministic("IR_flux", gamma * X.I)

        dX.S = -SI_flux
        dX.I = SI_flux - IR_flux
        dX.R = IR_flux

        # TODO: Think about how to deal with names below.
        pyro.sample("S_obs", Normal(X.S, 1.))
        return X.S

    # NOTE(review): SIR reads compartments by attribute (X.S, X.I), which a plain
    # dict does not support -- presumably `simulate` converts this into a State; confirm.
    init_state = {"S": 1.0, "I": 2.0, "R": 3.3, "t": 0.0}
    tspan = torch.tensor([1., 2., 3.])

    # Some representation of time, variable, and value.
    # This draft representation states that we observe `S_obs` at t=1.5 at value 3.7 and at t=2.5 at value 3.9.
    observed_data = {"S_obs": (torch.tensor([1.5, 2.5]), torch.tensor([3.7, 3.9]))}

    # This draft representation says that we intervene to set `SI_flux` at t=1.7 to value 0.0
    # for the rest of the simulation.
    # (Fixed: the two tuple elements were previously missing a separating comma -> SyntaxError.)
    intervention_data = {"SI_flux": (torch.tensor([1.7]), torch.tensor([0.0]))}

    with ODE_solver():
        with condition(data=observed_data):
            with intervene(data=intervention_data, type="continuous"):
                S = simulate(SIR, init_state, tspan)

    # Return the trajectory instead of silently discarding it.
    return S

In [1]:
import pyro
import torch
from pyro.distributions import constraints

from causal_pyro.dynamical.ops import State, simulate
from causal_pyro.dynamical.handlers import ODEDynamics


class SimpleSIRDynamics(ODEDynamics):
    """SIR compartmental dynamics with two learnable, positivity-constrained rates.

    ``beta`` scales the S->I flux (``beta * S * I``) and ``gamma`` scales the I->R
    flux (``gamma * I``). Both are declared with ``pyro.nn.PyroParam`` used as a
    decorator; the decorated method returns the parameter's initial value.
    """

    @pyro.nn.PyroParam(constraint=constraints.positive)
    def beta(self):
        # Initial value of the S->I rate.
        return torch.tensor(0.5)

    @pyro.nn.PyroParam(constraint=constraints.positive)
    def gamma(self):
        # Initial value of the I->R rate.
        return torch.tensor(0.7)

    def diff(self, dX: State[torch.Tensor], X: State[torch.Tensor]):
        """Fill ``dX`` in place with the time-derivatives of the state ``X``."""
        s_to_i = self.beta * X.S * X.I
        i_to_r = self.gamma * X.I

        # Mass leaves S, passes through I, and accumulates in R.
        dX.S = -s_to_i
        dX.I = s_to_i - i_to_r
        dX.R = i_to_r


SIR_simple_model = SimpleSIRDynamics()

# Initial compartment values, each a scalar tensor.
init_state = State(
    S=torch.tensor(1.0),
    I=torch.tensor(2.0),
    R=torch.tensor(3.3),
)
# Time points passed to `simulate`.
tspan = torch.tensor([1.0, 2.0, 3.0])

# NOTE(review): the saved output of this cell is
# `TypeError: rand() received an invalid combination of arguments - got ()`.
# Nothing in this cell calls `torch.rand` directly and the full traceback was not
# kept, so the failing call presumably originates inside `simulate` or the
# dynamics class -- TODO investigate; this cell must not raise under Restart & Run All.
result = simulate(SIR_simple_model, init_state, tspan)

TypeError: rand() received an invalid combination of arguments - got (), but expected one of:
 * (tuple of ints size, *, torch.Generator generator, tuple of names names, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)
 * (tuple of ints size, *, torch.Generator generator, Tensor out, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)
 * (tuple of ints size, *, Tensor out, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)
 * (tuple of ints size, *, tuple of names names, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)
