In [17]:
%reload_ext autoreload
%autoreload 2

import pyro
import torch
from pyro.distributions import constraints, Normal

from causal_pyro.dynamical.ops import State, simulate
from causal_pyro.dynamical.handlers import (
    ODEDynamics,
    PointInterruption,
    PointIntervention,
    PointObservation,
    simulate,
)

In [18]:
class SimpleSIRDynamics(ODEDynamics):
    """Susceptible-Infected-Recovered (SIR) compartmental model written as an ODE.

    The right-hand side produced by ``diff`` is::

        dS/dt = -beta * S * I
        dI/dt =  beta * S * I - gamma * I
        dR/dt =  gamma * I

    ``beta`` and ``gamma`` are declared with ``pyro.nn.PyroParam`` and constrained
    to be positive; their initial values are 0.5 and 0.7 respectively.
    """

    @pyro.nn.PyroParam(constraint=constraints.positive)
    def beta(self):
        """Initial value of the infection rate (multiplies S * I)."""
        return torch.tensor(0.5)

    @pyro.nn.PyroParam(constraint=constraints.positive)
    def gamma(self):
        """Initial value of the recovery rate (multiplies I)."""
        return torch.tensor(0.7)

    def diff(self, dX: State[torch.Tensor], X: State[torch.Tensor]):
        """Write the time derivatives of the state ``X`` into ``dX`` in place."""
        # Each flow leaves one compartment and enters another, so compute it once.
        infection_flow = self.beta * X.S * X.I
        recovery_flow = self.gamma * X.I
        dX.S = -infection_flow
        dX.I = infection_flow - recovery_flow
        dX.R = recovery_flow


# PyroModule parameters are synchronised with Pyro's global param store, so start
# from an empty store: re-running this cell then always uses the initial values of
# beta and gamma declared on SimpleSIRDynamics.
pyro.clear_param_store()
SIR_simple_model = SimpleSIRDynamics()

init_state = State(S=torch.tensor(1.0), I=torch.tensor(2.0), R=torch.tensor(3.3))
tspan = torch.tensor([1.0, 2.0, 3.0, 4.0])

# Time at which S is observed, and the observed value.
OBS_TIME = 2.9
S_obs = torch.tensor(10.0)

# Replacement state for PointIntervention experiments (not used in this cell).
new_state = State(S=torch.tensor(10.0))


def loglikelihood(state: State[torch.Tensor]) -> torch.Tensor:
    """Unit-scale Gaussian log-likelihood of the observed ``S_obs`` given ``state.S``."""
    return Normal(state.S, 1).log_prob(S_obs)


# Solve once with the observation handler and once without it.
with PointObservation(time=OBS_TIME, loglikelihood=loglikelihood):
    result1 = simulate(SIR_simple_model, init_state, tspan)

result2 = simulate(SIR_simple_model, init_state, tspan)

print(result1)
print(result2)

# An observation only scores the trajectory; it must not change it.
for compartment in ("S", "I", "R"):
    assert torch.allclose(
        getattr(result1, compartment), getattr(result2, compartment), atol=1e-4
    ), compartment

# Under a trace, the observation shows up as a sample site carrying its log factor.
with pyro.poutine.trace() as tr:
    with PointObservation(time=OBS_TIME, loglikelihood=loglikelihood):
        simulate(SIR_simple_model, init_state, tspan)

print(tr.trace.nodes)

State({'S': tensor([1.0000, 0.4254, 0.2488, 0.1836], grad_fn=<CatBackward0>), 'I': tensor([2.0000, 1.3780, 0.8036, 0.4435], grad_fn=<CatBackward0>), 'R': tensor([3.3000, 4.4966, 5.2476, 5.6728], grad_fn=<CatBackward0>)})
State({'I': tensor([2.0000, 1.3780, 0.8036, 0.4435], grad_fn=<ViewBackward0>), 'R': tensor([3.3000, 4.4966, 5.2476, 5.6728], grad_fn=<ViewBackward0>), 'S': tensor([1.0000, 0.4254, 0.2488, 0.1836], grad_fn=<ViewBackward0>)})
OrderedDict([('obs_2.9000000953674316', {'type': 'sample', 'name': 'obs_2.9000000953674316', 'fn': Unit(log_factor: -48.359619140625), 'is_observed': True, 'args': (), 'kwargs': {}, 'value': tensor([]), 'infer': {'is_auxiliary': True}, 'scale': 1.0, 'mask': None, 'cond_indep_stack': (), 'done': True, 'stop': False, 'continuation': None})])


In [19]:
# Interrupt the solve just before and just after t = 3.0 (one of the tspan points);
# the outer handler is t = 2.99, the inner one t = 3.01.
with PointInterruption(time=2.99), PointInterruption(time=3.01):
    result3 = simulate(SIR_simple_model, init_state, tspan)

In [None]:
# The same two interruptions as the previous cell, nested in the opposite order.
# Stored under its own name so it does not overwrite `result3`.
with PointInterruption(time=3.01):
    with PointInterruption(time=2.99):
        result3_swapped = simulate(SIR_simple_model, init_state, tspan)

# Displays whether the nesting order of the interruptions affects the trajectory.
all(
    torch.allclose(getattr(result3, compartment), getattr(result3_swapped, compartment))
    for compartment in ("S", "I", "R")
)

In [28]:
# S trajectory from the run with the PointObservation handler active.
result1.S

tensor([1.0000, 0.4254, 0.2488, 0.1836], grad_fn=<CatBackward0>)

In [29]:
# S trajectory from the run without any handler.
result2.S

tensor([1.0000, 0.4254, 0.2488, 0.1836], grad_fn=<ViewBackward0>)

In [26]:
# S trajectory from the run with interruptions at t = 2.99 and t = 3.01.
result3.S

tensor([1.0000, 0.4254, 0.2488, 0.1836], grad_fn=<CatBackward0>)