In [1]:
import json
import torch

# Load the patient record and the model specification from the working
# directory. JSON is UTF-8 by specification, so the encoding is pinned rather
# than left to the platform's locale default.
with open("patient_information.json", encoding="utf-8") as f:
    patient = json.load(f)

with open("model.json", encoding="utf-8") as f:
    model = json.load(f)


In [3]:
# Ordered state names, plus the reverse lookup name -> position that is used
# to index into the state vector.
states = [*model["states"].keys()]
state_index = dict(zip(states, range(len(states))))


In [4]:
# Control channels as declared in the model's control spec; n_controls is
# later passed to the policy as its number of outputs.
control_channels = model["control_spec"]["channels"]
n_controls = len(control_channels)


In [5]:
# Inspection only: show which state names the model file defines.
print(model["states"].keys())


dict_keys(['T', 'E', 'H', 'PDL1', 'TGFb', 'ctDNA', 'Adr', 'Cyc', 'Tax', 'Tam', 'IO', 'TIL', 'N', 'BM'])


In [6]:
class AffineSigmoidPolicy(torch.nn.Module):
    """Affine state-feedback policy squashed into [0, 1] by a sigmoid.

    The selected state inputs (optionally concatenated with a scaled time
    feature) go through a single linear layer followed by a sigmoid, giving
    one value per control channel. Rescaling to per-drug bounds is left to
    the caller.

    Args:
        n_state_inputs: number of state components fed to the policy.
        n_controls: number of control outputs.
        time_feature: if True, ``forward`` appends ``t_scaled`` as one extra
            input column.
        init_scale: factor applied to the default ``torch.nn.Linear``
            initialisation of both weight and bias (small values keep the
            initial outputs close to 0.5).
    """

    def __init__(self, n_state_inputs, n_controls, time_feature=True, init_scale=0.1):
        super().__init__()
        self.time_feature = time_feature
        in_dim = n_state_inputs + (1 if time_feature else 0)
        # simple linear layer -> controls
        self.linear = torch.nn.Linear(in_dim, n_controls)
        # shrink the default initialisation in place
        with torch.no_grad():
            self.linear.weight.mul_(init_scale)
            self.linear.bias.mul_(init_scale)

    def forward(self, x, t_scaled):
        """Compute controls in [0, 1].

        Args:
            x: (batch, n_state_inputs) tensor of state inputs.
            t_scaled: (batch, 1) tensor, intended to lie in [0, 1]; ignored
                when ``time_feature`` is False.

        Returns:
            (batch, n_controls) tensor. For floating-point ``x`` the result
            has the dtype of ``x``; otherwise it has the layer's dtype.
        """
        inp = torch.cat([x, t_scaled], dim=-1) if self.time_feature else x

        # The layer keeps its own parameter dtype (float32 unless converted),
        # while callers may supply e.g. float64 states. Cast on the way in so
        # the matmul does not fail on a dtype mismatch; the cast is a no-op
        # when the dtypes already agree and it is differentiable.
        param_dtype = self.linear.weight.dtype
        logits = self.linear(inp.to(param_dtype))

        # Sigmoid to get [0,1]; hand the result back in the caller's dtype.
        out_dtype = x.dtype if x.is_floating_point() else param_dtype
        return torch.sigmoid(logits).to(out_dtype)


In [7]:
# Policy hyperparameters are read from the patient file's control spec.
policy_cfg = patient["control_spec"]["policy"]
state_inputs = policy_cfg["state_inputs"]  # e.g. ['T','E','H','PDL1','ctDNA','N','BM']
n_state_inputs = len(state_inputs)

# Build the policy first, then place it on the target device.
policy = AffineSigmoidPolicy(
    n_state_inputs,
    n_controls,
    time_feature=policy_cfg.get("time_feature", True),
    init_scale=policy_cfg.get("init_scale", 0.1),
)
policy = policy.to("cpu")  # or device


In [None]:
def _state_names(model):
    """Return the ordered state names declared in ``model["states"]``.

    Accepts either a mapping keyed by state name or a sequence whose items
    are plain names or dicts carrying a ``"name"`` entry.
    """
    states = model["states"]
    if isinstance(states, dict):
        return list(states)
    return [s["name"] if isinstance(s, dict) else s for s in states]


def simulate_patient(policy, patient, model, device="cpu", dtype=torch.float64):
    """Roll the system forward under a closed-loop policy.

    Args:
        policy: callable ``policy(state_inp, t_scaled)`` receiving a
            (1, n_state_inputs) tensor and a (1, 1) tensor and returning a
            (1, n_controls) tensor.
        patient: dict providing ``tgrid`` (``start_day``, ``num_days``,
            ``step_days``), ``initial_conditions`` and
            ``control_spec.policy.state_inputs``.
        model: dict providing ``states``; ``initial_conditions`` is consulted
            for any state the patient does not specify (defaulting to 0.0
            when the model has no value for it either).
        device: device for the tensors created here.
        dtype: dtype for the tensors created here.

    Returns:
        ``(traj, controls)`` with shapes (num_days + 1, n_states) and
        (num_days + 1, n_controls).

    Note:
        Stepping is delegated to ``rk4_step_full_system``, which must already
        be defined in the notebook; it is not defined in this cell.
    """
    # time grid
    tgrid = patient["tgrid"]
    t0 = float(tgrid["start_day"])
    num_days = int(tgrid["num_days"])
    dt = float(tgrid["step_days"])
    steps = num_days  # one integration step of length dt per grid day

    # Derive the state ordering from the model that was passed in, instead of
    # relying on a notebook-level lookup built from some other model object.
    state_names = _state_names(model)
    position = {name: i for i, name in enumerate(state_names)}

    # initial condition: patient value first, then model value, then 0.0
    patient_ic = patient["initial_conditions"]
    u0_list = []
    for name in state_names:
        if name in patient_ic:
            u0_list.append(float(patient_ic[name]))
        else:
            u0_list.append(float(model["initial_conditions"].get(name, 0.0)))
    u = torch.tensor(u0_list, device=device, dtype=dtype)

    traj = []
    controls = []

    # precompute indices for policy inputs
    idxs = [position[name] for name in patient["control_spec"]["policy"]["state_inputs"]]

    t = t0
    for k in range(steps + 1):
        traj.append(u)

        # Time feature is the elapsed fraction of the horizon, so it stays in
        # [0, 1] whatever start_day / step_days are, and is 0 for an empty
        # horizon. Equals t / num_days when start_day == 0 and step_days == 1.
        elapsed_fraction = k / steps if steps > 0 else 0.0

        # compute control from policy
        state_inp = u[idxs].unsqueeze(0)          # (1, n_state_inputs)
        t_scaled = torch.tensor([[elapsed_fraction]], device=device, dtype=dtype)
        u_raw = policy(state_inp, t_scaled).squeeze(0)  # (n_controls,)

        # Controls are stored as returned by the policy; per-drug bound
        # rescaling is not applied here.
        controls.append(u_raw)

        # advance the state by one step; no step is taken after the last sample
        if k < steps:
            u = rk4_step_full_system(u, u_raw, t, dt, model)
            t += dt

    traj = torch.stack(traj, dim=0)         # (T+1, n_states)
    controls = torch.stack(controls, dim=0) # (T+1, n_controls)
    return traj, controls
