<a href="https://colab.research.google.com/github/MatiasNazareth1993-coder/Virtual-cell/blob/main/Virtual_cell_simulation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# virtual_cell_model.py
# Requirements: torch, scikit-learn (for metrics), numpy
# pip install torch scikit-learn numpy

import math
import random
import numpy as np
from typing import Tuple
from sklearn.metrics import roc_auc_score, accuracy_score
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# -------------------------
# Synthetic dataset helper
# -------------------------
class SyntheticCellDataset(Dataset):
    """
    Synthetic multimodal cell samples for demonstration purposes.

    Each item is a dict with four float32 feature vectors and two int64 labels:
      - 'genomic'     : (dim_g,) summary features; column 0 acts as a telomere proxy
      - 'proteomic'   : (dim_p,)
      - 'metabolomic' : (dim_m,)
      - 'env'         : (dim_e,) environment signals; column 0 acts as ROS
      - 'fate'        : 0=healthy_division, 1=senescence, 2=apoptosis
      - 'tel'         : telomerase_action label (0 or 1)

    All randomness comes from a private ``np.random.RandomState(seed)``, so a
    given ``seed`` always reproduces the same dataset.
    """
    def __init__(self, n_samples=5000, dims=(64,32,32,8), seed=42):
        super().__init__()
        self.rng = np.random.RandomState(seed)
        self.n = n_samples
        (self.dim_g, self.dim_p, self.dim_m, self.dim_e) = dims
        self._make()

    def _make(self):
        """Draw the feature matrices, then derive both label vectors from them."""
        rng, count = self.rng, self.n

        def gaussian(width):
            return rng.normal(size=(count, width)).astype(np.float32)

        # Draw order matters: it fixes which slice of the RNG stream each modality gets.
        self.genomic = gaussian(self.dim_g)
        self.proteomic = gaussian(self.dim_p)
        self.metabolomic = gaussian(self.dim_m)
        self.env = gaussian(self.dim_e)

        telomere_proxy = self.genomic[:, 0]
        ros = self.env[:, 0]

        # Not used by the labelling rule, but still drawn so the RNG stream
        # (and therefore the label-noise flips below) is unchanged for a seed.
        rng.normal(size=count)

        # telomerase_action = 1 for moderately short telomeres under low ROS.
        needs_telomerase = (telomere_proxy < -0.2) & (ros < 0.5)
        tel_labels = needs_telomerase.astype(np.int64)

        # Fate: healthy by default; very short telomeres without telomerase ->
        # senescence; very high ROS -> apoptosis (assigned last, so it wins).
        fate = np.zeros(count, dtype=np.int64)
        fate[(telomere_proxy < -0.6) & (tel_labels == 0)] = 1
        fate[ros > 2.0] = 2

        # ~2% label noise: overwrite those samples with a uniformly random fate.
        relabel = rng.rand(count) < 0.02
        fate[relabel] = rng.randint(0, 3, size=relabel.sum())

        self.targets_fate = fate
        self.targets_tel = tel_labels

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        sample = {
            name: getattr(self, name)[idx]
            for name in ('genomic', 'proteomic', 'metabolomic', 'env')
        }
        sample['fate'] = self.targets_fate[idx]
        sample['tel'] = self.targets_tel[idx]
        return sample

# -------------------------
# Model definition
# -------------------------
class ModalityEncoder(nn.Module):
    """Two-stage MLP encoder: (Linear -> LayerNorm -> GELU) twice, in_dim -> hid_dim."""
    def __init__(self, in_dim, hid_dim=64):
        super().__init__()
        layers = []
        for fan_in in (in_dim, hid_dim):
            layers += [nn.Linear(fan_in, hid_dim), nn.LayerNorm(hid_dim), nn.GELU()]
        # A single nn.Sequential named `net` keeps the state_dict keys
        # (net.0.*, net.1.*, net.3.*, net.4.*) compatible with saved checkpoints.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Map x of shape (..., in_dim) to an embedding of shape (..., hid_dim)."""
        return self.net(x)

class VirtualCellModel(nn.Module):
    """
    Late-fusion multimodal network with two prediction heads.

    Each modality gets its own ModalityEncoder of width ``hidden // 2``. The four
    embeddings are concatenated, passed through a fusion MLP of width ``hidden``,
    and fed to:
      - fate_head: 3 raw logits (healthy_division / senescence / apoptosis)
      - telomerase_head: a probability in (0, 1), since the head ends in Sigmoid

    forward(...) returns ``(fate_logits, tel_prob)``, with ``tel_prob``'s last
    singleton dimension squeezed away.
    """
    def __init__(self, dims=(64,32,32,8), hidden=128, dropout=0.1):
        super().__init__()
        dim_g, dim_p, dim_m, dim_e = dims
        enc_h = hidden // 2
        # Registers enc_g, enc_p, enc_m, enc_e (in that order) as submodules.
        for tag, width in (('g', dim_g), ('p', dim_p), ('m', dim_m), ('e', dim_e)):
            setattr(self, 'enc_' + tag, ModalityEncoder(width, enc_h))

        self.fusion = nn.Sequential(
            nn.Linear(4 * enc_h, hidden), nn.LayerNorm(hidden), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(hidden, hidden), nn.LayerNorm(hidden), nn.GELU(),
        )

        head_h = hidden // 2
        self.fate_head = nn.Sequential(
            nn.Linear(hidden, head_h), nn.GELU(), nn.Linear(head_h, 3),
        )
        self.telomerase_head = nn.Sequential(
            nn.Linear(hidden, head_h), nn.GELU(), nn.Linear(head_h, 1), nn.Sigmoid(),
        )

    def forward(self, genomic, proteomic, metabolomic, env):
        embeddings = [
            self.enc_g(genomic),
            self.enc_p(proteomic),
            self.enc_m(metabolomic),
            self.enc_e(env),
        ]
        shared = self.fusion(torch.cat(embeddings, dim=-1))
        return self.fate_head(shared), self.telomerase_head(shared).squeeze(-1)

# -------------------------
# Training utilities
# -------------------------
def train_epoch(model, dataloader, optimizer, device, scaler=None, clip_grad=1.0):
    """
    Run one optimization pass over `dataloader` and return the mean training loss.

    The per-batch loss is ``cross_entropy(fate_logits, fate) + 0.5 * BCE(tel_pred, tel)``.
    `model` must return ``(fate_logits, tel_prob)`` with ``tel_prob`` already in
    (0, 1), because ``F.binary_cross_entropy`` is applied to it directly.

    Args:
        model: module called as ``model(genomic, proteomic, metabolomic, env)``.
        dataloader: yields dicts with keys 'genomic', 'proteomic', 'metabolomic',
            'env', 'fate', 'tel'.
        optimizer: optimizer over the model's parameters.
        device: device every batch tensor is moved to.
        scaler: optional gradient scaler (e.g. ``torch.cuda.amp.GradScaler``). When
            given, the loss is scaled for backward and gradients are unscaled
            before clipping. Autocast, if wanted, is the caller's responsibility.
        clip_grad: max gradient L2 norm; ``None`` disables clipping.

    Returns:
        Sample-weighted mean of the combined loss over ``dataloader.dataset``
        (NaN if the dataset is empty).
    """
    model.train()
    total_loss = 0.0
    for batch in dataloader:
        genomic = batch['genomic'].to(device)
        proteomic = batch['proteomic'].to(device)
        metabolomic = batch['metabolomic'].to(device)
        env = batch['env'].to(device)
        fate = batch['fate'].to(device)
        tel = batch['tel'].to(device).float()

        optimizer.zero_grad()
        fate_logits, tel_pred = model(genomic, proteomic, metabolomic, env)

        loss_fate = F.cross_entropy(fate_logits, fate)
        loss_tel = F.binary_cross_entropy(tel_pred, tel)
        # Combined loss: weight telomerase head to encourage correct transient control
        loss = loss_fate + 0.5 * loss_tel

        if scaler is not None:
            scaler.scale(loss).backward()
            if clip_grad is not None:
                # Gradients must be unscaled before their norm can be compared to clip_grad.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            if clip_grad is not None:
                # gradient clipping - helpful when using higher LR
                torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad)
            optimizer.step()
        total_loss += loss.item() * genomic.size(0)

    n_samples = len(dataloader.dataset)
    return total_loss / n_samples if n_samples else float('nan')

@torch.no_grad()
def evaluate(model, dataloader, device):
    """
    Score `model` on `dataloader` without tracking gradients.

    Returns a dict with:
        'accuracy': fate classification accuracy, using the argmax of the fate softmax.
        'telomerase_auc': ROC AUC of the telomerase probabilities against the
            'tel' labels. It is NaN when undefined, e.g. when only one label
            class is present.
    Both values are NaN when the dataloader yields no batches.

    The model is left in eval mode; call ``model.train()`` before training again.
    """
    model.eval()
    all_fates = []
    all_fate_preds = []
    all_tel = []
    all_tel_preds = []
    for batch in dataloader:
        genomic = batch['genomic'].to(device)
        proteomic = batch['proteomic'].to(device)
        metabolomic = batch['metabolomic'].to(device)
        env = batch['env'].to(device)
        fate = batch['fate'].cpu().numpy()
        tel = batch['tel'].cpu().numpy()

        fate_logits, tel_pred = model(genomic, proteomic, metabolomic, env)
        fate_probs = F.softmax(fate_logits, dim=-1).cpu().numpy()
        all_fates.append(fate)
        all_fate_preds.append(np.argmax(fate_probs, axis=1))
        all_tel.append(tel)
        all_tel_preds.append(tel_pred.cpu().numpy())

    if not all_fates:
        # np.concatenate would raise on an empty list; report "no data" as NaN instead.
        return {'accuracy': float('nan'), 'telomerase_auc': float('nan')}

    all_fates = np.concatenate(all_fates)
    all_fate_preds = np.concatenate(all_fate_preds)
    all_tel = np.concatenate(all_tel)
    all_tel_preds = np.concatenate(all_tel_preds)

    acc = accuracy_score(all_fates, all_fate_preds)
    try:
        auc = roc_auc_score(all_tel, all_tel_preds)
    except ValueError:
        # roc_auc_score rejects inputs where the AUC is undefined (e.g. a single class).
        auc = float('nan')
    return {'accuracy': acc, 'telomerase_auc': auc}

# -------------------------
# Main training script
# -------------------------
def main():
    """
    Train VirtualCellModel on synthetic data and checkpoint the best model.

    Steps:
      - Split 8000 synthetic samples 80/20 into train/validation, with a fixed seed.
      - Train for `epochs` epochs with AdamW and a per-epoch schedule
        (linear warmup, then cosine decay).
      - Print metrics after every epoch.
      - Save the state_dict with the highest validation telomerase AUC to
        'best_virtual_cell_model.pt'.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Data
    ds = SyntheticCellDataset(n_samples=8000)
    train_size = int(0.8 * len(ds))
    val_size = len(ds) - train_size
    # Dedicated generator so the train/val split is reproducible between runs.
    train_ds, val_ds = torch.utils.data.random_split(
        ds, [train_size, val_size], generator=torch.Generator().manual_seed(42))

    # Pinned host memory only helps host->GPU copies.
    pin = device.type == 'cuda'
    train_loader = DataLoader(train_ds, batch_size=128, shuffle=True, num_workers=2, pin_memory=pin)
    val_loader = DataLoader(val_ds, batch_size=256, shuffle=False, num_workers=1, pin_memory=pin)

    # Model
    model = VirtualCellModel(dims=(ds.dim_g, ds.dim_p, ds.dim_m, ds.dim_e), hidden=256, dropout=0.12)
    model.to(device)

    # ---- HIGHER LEARNING RATE SETTING ----
    # Peak lr=5e-3 is higher than the typical 1e-3 (1e-2 is worth trying too).
    # Warmup plus cosine decay keeps it stable.
    lr = 5e-3
    weight_decay = 1e-4
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    # The scheduler is stepped once per epoch, so the schedule is sized in epochs.
    # If it were sized in optimizer steps, it would never leave warmup in 25 epochs.
    epochs = 25
    warmup_epochs = 2

    def lr_lambda(epoch):
        if epoch < warmup_epochs:
            # epoch + 1 so the very first epoch does not train with LR 0.
            return float(epoch + 1) / float(warmup_epochs)
        # cosine decay after warmup
        progress = (epoch - warmup_epochs) / max(1, epochs - warmup_epochs)
        return 0.5 * (1.0 + math.cos(math.pi * progress))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

    best_val = -1.0
    for epoch in range(epochs):
        epoch_lr = optimizer.param_groups[0]['lr']  # LR actually used for this epoch
        train_loss = train_epoch(model, train_loader, optimizer, device, clip_grad=1.0)
        scheduler.step()

        metrics = evaluate(model, val_loader, device)
        print(f"Epoch {epoch+1:02d} | train_loss={train_loss:.4f} | val_acc={metrics['accuracy']:.4f} | tel_auc={metrics['telomerase_auc']:.4f} | lr={epoch_lr:.5f}")

        # simple checkpointing (a NaN AUC never compares greater, so it is never saved)
        if metrics['telomerase_auc'] > best_val:
            best_val = metrics['telomerase_auc']
            torch.save(model.state_dict(), 'best_virtual_cell_model.pt')

    print("Training finished. Best telomerase AUC:", best_val)

# Entry point: run the training script when this file/cell is executed directly,
# but not when it is imported as a module.
if __name__ == "__main__":
    main()




Epoch 01 | train_loss=1.4372 | val_acc=0.4150 | tel_auc=0.4693 | lr=0.00003




Epoch 02 | train_loss=1.1201 | val_acc=0.8900 | tel_auc=0.5489 | lr=0.00005




Epoch 03 | train_loss=0.7321 | val_acc=0.8900 | tel_auc=0.7092 | lr=0.00007




Epoch 04 | train_loss=0.6517 | val_acc=0.8900 | tel_auc=0.7534 | lr=0.00010




Epoch 05 | train_loss=0.5254 | val_acc=0.9000 | tel_auc=0.8882 | lr=0.00013




Epoch 06 | train_loss=0.3668 | val_acc=0.9175 | tel_auc=0.9464 | lr=0.00015




Epoch 07 | train_loss=0.2681 | val_acc=0.9137 | tel_auc=0.9624 | lr=0.00018




Epoch 08 | train_loss=0.2052 | val_acc=0.9250 | tel_auc=0.9678 | lr=0.00020




Epoch 09 | train_loss=0.1518 | val_acc=0.9269 | tel_auc=0.9691 | lr=0.00022




Epoch 10 | train_loss=0.1018 | val_acc=0.9275 | tel_auc=0.9675 | lr=0.00025




Epoch 11 | train_loss=0.0648 | val_acc=0.9306 | tel_auc=0.9675 | lr=0.00028




Epoch 12 | train_loss=0.0413 | val_acc=0.9263 | tel_auc=0.9695 | lr=0.00030




Epoch 13 | train_loss=0.0249 | val_acc=0.9263 | tel_auc=0.9696 | lr=0.00033




Epoch 14 | train_loss=0.0174 | val_acc=0.9231 | tel_auc=0.9687 | lr=0.00035




Epoch 15 | train_loss=0.0146 | val_acc=0.9244 | tel_auc=0.9715 | lr=0.00038




Epoch 16 | train_loss=0.0130 | val_acc=0.9287 | tel_auc=0.9679 | lr=0.00040




Epoch 17 | train_loss=0.0177 | val_acc=0.9200 | tel_auc=0.9709 | lr=0.00043




Epoch 18 | train_loss=0.0270 | val_acc=0.9269 | tel_auc=0.9731 | lr=0.00045




Epoch 19 | train_loss=0.0207 | val_acc=0.9225 | tel_auc=0.9722 | lr=0.00047




Epoch 20 | train_loss=0.0185 | val_acc=0.9169 | tel_auc=0.9740 | lr=0.00050




Epoch 21 | train_loss=0.0209 | val_acc=0.9269 | tel_auc=0.9719 | lr=0.00052




Epoch 22 | train_loss=0.0242 | val_acc=0.9294 | tel_auc=0.9688 | lr=0.00055




Epoch 23 | train_loss=0.0193 | val_acc=0.9287 | tel_auc=0.9724 | lr=0.00057




Epoch 24 | train_loss=0.0198 | val_acc=0.9300 | tel_auc=0.9782 | lr=0.00060




Epoch 25 | train_loss=0.0123 | val_acc=0.9319 | tel_auc=0.9740 | lr=0.00063
Training finished. Best telomerase AUC: 0.97820967489316


In [None]:
# RL-based telomerase controller (PPO) with a simple cell simulator environment.
# This code is intended to be run in a local Python environment with PyTorch installed.
# It creates a lightweight gym-like environment that simulates cell states and the
# outcome probabilities (healthy division / senescence / apoptosis) influenced by
# telomerase actions. A PPO agent learns when to transiently activate telomerase
# to maximize long-term tissue health while minimizing cancer risk.
#
# To run: execute this cell; training progress is printed every 50 episodes.
# Requirements: torch, numpy
# pip install torch numpy
#
# Note: This is an educational simulation — not biological guidance.

import math, random, time
from dataclasses import dataclass
from typing import Tuple, List
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# -------------------------
# Simple cell environment
# -------------------------
class SimpleCellEnv:
    """
    Minimal episodic environment representing a single cell's lifecycle across T steps.

    State: [telomere_proxy, ros, age_norm, cumulative_tel_actions_norm]. Age and the
        activation count are divided by (max_steps + 1).
    Action: 0 (no telomerase), 1 (activate telomerase transiently).
    Dynamics:
        - telomere_proxy shortens by about 0.2 per step (plus noise).
        - Action 1 restores 0.5 * (1 - 0.1 * prior_activations). The restoration
          is floored at 0, so heavy use never *shortens* telomeres.
        - ROS performs a small random walk.
    Reward: 0 on non-terminal steps. At step max_steps an outcome is sampled from a
        softmax over (healthy, senescence, apoptosis) logits, worth +1 / -1 / -2.
        A cancer-risk penalty of 0.2 per activation beyond 1.5 is subtracted.
    """
    def __init__(self, max_steps=10, seed=0):
        self.max_steps = max_steps
        self.rng = np.random.RandomState(seed)
        self.reset()

    def reset(self):
        """Start a new episode and return the initial observation."""
        # telomere_proxy: higher is healthier. Start near zero with variance.
        self.tel = float(self.rng.normal(loc=0.0, scale=0.3))
        # ROS baseline
        self.ros = float(self.rng.normal(loc=0.2, scale=0.2))
        self.age = 0.0
        self.cumulative_tel_actions = 0.0
        self.step_count = 0
        self.done = False
        return self._get_obs()

    def _get_obs(self):
        """Return the float32 observation vector for the current state."""
        return np.array([self.tel, self.ros, self.age / (self.max_steps + 1), self.cumulative_tel_actions / (self.max_steps + 1)], dtype=np.float32)

    def step(self, action: int) -> Tuple[np.ndarray, float, bool, dict]:
        """
        Advance the cell by one time step.

        Returns (obs, reward, done, info). The reward is 0 until the terminal step.
        At the terminal step, info holds 'outcome' (int: 0 healthy, 1 senescence,
        2 apoptosis), 'probs' (list of 3 floats) and 'cancer_risk' (float).

        Raises:
            ValueError: if action is not 0 or 1.
            RuntimeError: if called after the episode ended; call reset() first.
        """
        if action not in (0, 1):
            raise ValueError(f"action must be 0 or 1, got {action!r}")
        if self.done:
            # Without this guard every extra call would re-sample and re-pay the terminal reward.
            raise RuntimeError("episode has ended; call reset() before step()")

        # Action effect: transiently increases telomere length
        if action == 1:
            # Diminishing returns, floored at 0 so repeated activation never shortens telomeres.
            self.tel += max(0.0, 0.5 * (1.0 - 0.1 * self.cumulative_tel_actions))
            self.cumulative_tel_actions += 1.0
        # Natural telomere shortening and ROS fluctuation
        self.tel -= 0.2 + 0.05 * self.rng.randn()
        self.ros += 0.05 * self.rng.randn()

        self.age += 1.0
        self.step_count += 1

        done = False
        reward = 0.0
        info = {}

        if self.step_count >= self.max_steps:
            # Outcome logits: healthier with larger tel and lower ros.
            tel_score = self.tel
            ros_score = -self.ros
            logits = np.array([1.2*tel_score + 0.2*ros_score,    # healthy
                               -0.5*tel_score + 0.6*(-ros_score), # senescence: short tel / high ros
                               -0.8*ros_score + 0.1*(1.0 - tel_score)])  # apoptosis: high ros
            logits += 0.3 * self.rng.randn(3)
            probs = np.exp(logits - np.max(logits))  # numerically stable softmax
            probs = probs / probs.sum()

            outcome = self.rng.choice(3, p=probs)
            if outcome == 0:
                reward = 1.0  # healthy division
            elif outcome == 1:
                reward = -1.0  # senescence
            else:
                reward = -2.0  # apoptosis

            # Cancer risk grows with activations beyond the 1.5 threshold.
            cancer_risk = max(0.0, (self.cumulative_tel_actions - 1.5) * 0.2)
            reward -= cancer_risk

            done = True
            info['outcome'] = int(outcome)
            info['probs'] = probs.tolist()
            info['cancer_risk'] = float(cancer_risk)

        obs = self._get_obs()
        self.done = done
        return obs, reward, done, info

# -------------------------
# PPO agent (policy + value)
# -------------------------
class MLPActorCritic(nn.Module):
    """
    Shared-trunk actor-critic for a binary action space.

    ``forward(x)`` with x of shape (batch, obs_dim) returns ``(logit, value)``,
    both of shape (batch,):
      - ``logit`` is the pre-sigmoid log-odds of choosing action 1.
      - ``value`` is the critic's state-value estimate.
    """
    def __init__(self, obs_dim, hidden=128):
        super().__init__()
        self.shared = nn.Sequential(
            nn.Linear(obs_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
        )

        def make_head():
            # hidden -> 64 -> 1 scalar head
            return nn.Sequential(nn.Linear(hidden, 64), nn.ReLU(), nn.Linear(64, 1))

        self.policy = make_head()  # logit for the binary action
        self.value = make_head()

    def forward(self, x):
        features = self.shared(x)
        return self.policy(features).squeeze(-1), self.value(features).squeeze(-1)

# Helper: compute advantages (GAE)
def compute_gae(rewards, values, dones, gamma=0.99, lam=0.95):
    """
    Generalized Advantage Estimation over a single trajectory.

    Args:
        rewards: per-step rewards, length T (ints or floats).
        values: value estimates, normally length T + 1 with a bootstrap value last.
            If only T values are given, the missing next value is treated as 0.
        dones: per-step terminal flags (1.0 = episode ended at that step), length T.
        gamma: discount factor.
        lam: GAE lambda.

    Returns:
        (advantages, returns) as float64 arrays of length T,
        where returns = advantages + values[:T].
    """
    n = len(rewards)
    # Always accumulate in floating point. np.zeros_like on an int reward list
    # would give an int array and silently truncate every advantage.
    advantages = np.zeros(n, dtype=np.float64)
    lastgae = 0.0
    for t in reversed(range(n)):
        nonterminal = 1.0 - dones[t]
        next_value = values[t + 1] if t + 1 < len(values) else 0.0
        delta = rewards[t] + gamma * next_value * nonterminal - values[t]
        lastgae = delta + gamma * lam * nonterminal * lastgae
        advantages[t] = lastgae
    returns = advantages + np.asarray(values[:n], dtype=np.float64)
    return advantages, returns

# -------------------------
# Training loop (vectorized episodes)
# -------------------------
def train_ppo(env_ctor, steps=2000, batch_size=32, epochs=20, clip_eps=0.2,
              policy_lr=3e-4, value_lr=1e-3, gamma=0.99):
    """
    Train an MLPActorCritic with a simple single-episode PPO loop.

    Each of the `steps` iterations does the following:
      - Collect one full episode with the current policy.
      - Compute GAE advantages.
      - Perform 8 clipped-PPO gradient passes over that episode, using one Adam
        optimizer (lr=policy_lr) on ``policy_loss + 0.5 * value_loss - 0.01 * entropy``.

    Args:
        env_ctor: zero-argument callable returning an env with ``reset()``,
            ``step(action)`` and ``_get_obs()`` (e.g. SimpleCellEnv).
        steps: number of episodes to train on.
        batch_size, epochs, value_lr: accepted for API compatibility but
            currently unused. The update always runs 8 passes with a single optimizer.
        clip_eps: PPO probability-ratio clipping range.
        policy_lr: Adam learning rate for all parameters.
        gamma: discount factor passed to compute_gae.

    Returns:
        (actor_critic, list of per-episode total rewards).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    env = env_ctor()
    obs_dim = env._get_obs().shape[0]
    ac = MLPActorCritic(obs_dim).to(device)
    # A single optimizer drives the policy, value and shared trunk.
    optimizer = optim.Adam(ac.parameters(), lr=policy_lr)
    ppo_passes = 8  # gradient passes over each collected episode

    all_rewards = []
    print(f"Training PPO for {steps} episodes. Device: {device}")

    for ep in range(steps):
        obs = env.reset()
        ep_reward = 0.0
        traj_obs, traj_rewards, traj_values = [], [], []
        traj_logps, traj_actions, traj_dones = [], [], []

        # Roll out one episode with the current (frozen) policy.
        while True:
            obs_t = torch.tensor(obs, dtype=torch.float32).unsqueeze(0).to(device)
            with torch.no_grad():
                logit, value = ac(obs_t)
                prob = torch.sigmoid(logit).cpu().numpy()[0]
            action = 1 if random.random() < prob else 0
            # log-probability of the sampled action under the behaviour policy
            logp = math.log(prob + 1e-8) if action == 1 else math.log(1.0 - prob + 1e-8)

            next_obs, reward, done, info = env.step(action)
            traj_obs.append(obs.copy())
            traj_rewards.append(reward)
            traj_values.append(value.cpu().numpy()[0])
            traj_logps.append(logp)
            traj_actions.append(action)
            traj_dones.append(float(done))

            obs = next_obs
            ep_reward += reward
            if done:
                break

        # value for terminal next state assumed 0
        values_np = np.array(traj_values + [0.0], dtype=np.float32)
        advantages, returns = compute_gae(traj_rewards, values_np, traj_dones, gamma=gamma)

        obs_tensor = torch.tensor(np.array(traj_obs), dtype=torch.float32).to(device)
        actions_tensor = torch.tensor(np.array(traj_actions), dtype=torch.float32).to(device)
        old_logps = torch.tensor(np.array(traj_logps), dtype=torch.float32).to(device)
        advantages_tensor = torch.tensor(advantages, dtype=torch.float32).to(device)
        returns_tensor = torch.tensor(returns, dtype=torch.float32).to(device)

        # Normalize advantages. std() of a single element is NaN, which would
        # poison every parameter, so length-1 episodes keep their raw advantage.
        if advantages_tensor.numel() > 1:
            advantages_tensor = (advantages_tensor - advantages_tensor.mean()) / (advantages_tensor.std() + 1e-8)

        for _ in range(ppo_passes):
            logits, values_pred = ac(obs_tensor)
            probs = torch.sigmoid(logits)
            new_logps = actions_tensor * torch.log(probs + 1e-8) + (1 - actions_tensor) * torch.log(1 - probs + 1e-8)
            ratio = torch.exp(new_logps - old_logps)
            surr1 = ratio * advantages_tensor
            surr2 = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages_tensor
            policy_loss = -torch.min(surr1, surr2).mean()
            value_loss = F.mse_loss(values_pred, returns_tensor)
            entropy = -(probs * torch.log(probs + 1e-8) + (1-probs) * torch.log(1-probs + 1e-8)).mean()
            loss = policy_loss + 0.5 * value_loss - 0.01 * entropy

            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(ac.parameters(), 0.5)
            optimizer.step()

        all_rewards.append(ep_reward)

        # Logging
        if (ep + 1) % 50 == 0:
            last_mean = np.mean(all_rewards[-50:])
            print(f"Episode {ep+1}/{steps} | avg reward (last 50): {last_mean:.3f} | last ep reward: {ep_reward:.3f}")
    return ac, all_rewards

# -------------------------
# Run training (short) and show results
# -------------------------
if __name__ == "__main__":
    start = time.time()
    # Smaller training for demo purposes (you can increase steps to 2000+ for real training)
    policy, rewards = train_ppo(lambda: SimpleCellEnv(max_steps=8, seed=int(time.time()%10000)), steps=400, policy_lr=2e-4)
    end = time.time()
    print(f"Finished training in {end-start:.2f}s. Mean reward last 50: {np.mean(rewards[-50:]):.3f}")

    # Show a sample rollout of the learned policy (inference only).
    env = SimpleCellEnv(max_steps=8, seed=1234)
    obs = env.reset()
    traj = []
    # train_ppo may have placed the policy on CUDA; build inputs on the same device.
    policy_device = next(policy.parameters()).device
    for _ in range(20):
        obs_t = torch.tensor(obs, dtype=torch.float32, device=policy_device).unsqueeze(0)
        with torch.no_grad():
            logit, _value = policy(obs_t)
        prob = torch.sigmoid(logit).item()
        action = 1 if random.random() < prob else 0
        next_obs, reward, done, info = env.step(action)
        traj.append((obs.copy(), action, reward, done, info))
        obs = next_obs
        if done:
            break

    print("Sample rollout (obs, action, reward, done, info):")
    for x in traj:
        print(x)



Training PPO for 400 episodes. Device: cpu
Episode 50/400 | avg reward (last 50): -1.132 | last ep reward: -1.900
Episode 100/400 | avg reward (last 50): -1.064 | last ep reward: -0.300
Episode 150/400 | avg reward (last 50): -1.116 | last ep reward: -0.300
Episode 200/400 | avg reward (last 50): -1.400 | last ep reward: -2.300
Episode 250/400 | avg reward (last 50): -1.212 | last ep reward: -2.300
Episode 300/400 | avg reward (last 50): -1.644 | last ep reward: -0.300
Episode 350/400 | avg reward (last 50): -1.208 | last ep reward: -0.300
Episode 400/400 | avg reward (last 50): -0.980 | last ep reward: -0.300
Finished training in 11.63s. Mean reward last 50: -0.980
Sample rollout (obs, action, reward, done, info):
(array([0.4298121 , 0.13746962, 0.        , 0.        ], dtype=float32), 1, 0.0, False, {})
(array([0.76584154, 0.18182777, 0.11111111, 0.11111111], dtype=float32), 0, 0.0, False, {})
(array([0.5228621 , 0.15000159, 0.22222222, 0.11111111], dtype=float32), 1, 0.0, False, {})

In [None]:
# Retry execution with graceful fallback if PyTorch missing.
# This cell checks for torch; if unavailable, it prints installation instructions and exits.
# Import PyTorch up front, so a missing or broken install stops this cell
# immediately with install instructions instead of a NameError later on.
try:
    import torch
    import torch.nn as nn
    import torch.optim as optim
    import torch.nn.functional as F
except Exception as e:
    # Intentionally broad top-level catch: a broken torch install may fail with
    # errors other than ImportError. The error is printed, then execution stops.
    print("PyTorch import failed:", e)
    print("If you want to run the RL demo locally, install PyTorch first. Example:")
    print("  pip install torch --index-url https://download.pytorch.org/whl/cpu")
    # SystemExit halts the cell/script; its message is reported as the exit reason.
    raise SystemExit("Missing PyTorch — aborting execution here.")

import math, random, time
from dataclasses import dataclass
from typing import Tuple, List
import numpy as np

# -------------------------
# Simple cell environment (same as before)
# -------------------------
class SimpleCellEnv:
    """
    Minimal episodic environment representing a single cell's lifecycle across T steps.

    State: [telomere_proxy, ros, age_norm, cumulative_tel_actions_norm]. Age and the
        activation count are divided by (max_steps + 1).
    Action: 0 (no telomerase), 1 (activate telomerase transiently).
    Dynamics:
        - telomere_proxy shortens by about 0.2 per step (plus noise).
        - Action 1 restores 0.5 * (1 - 0.1 * prior_activations). The restoration
          is floored at 0, so heavy use never *shortens* telomeres.
        - ROS performs a small random walk.
    Reward: 0 on non-terminal steps. At step max_steps an outcome is sampled from a
        softmax over (healthy, senescence, apoptosis) logits, worth +1 / -1 / -2.
        A cancer-risk penalty of 0.2 per activation beyond 1.5 is subtracted.
    """
    def __init__(self, max_steps=10, seed=0):
        self.max_steps = max_steps
        self.rng = np.random.RandomState(seed)
        self.reset()

    def reset(self):
        """Start a new episode and return the initial observation."""
        self.tel = float(self.rng.normal(loc=0.0, scale=0.3))
        self.ros = float(self.rng.normal(loc=0.2, scale=0.2))
        self.age = 0.0
        self.cumulative_tel_actions = 0.0
        self.step_count = 0
        self.done = False
        return self._get_obs()

    def _get_obs(self):
        """Return the float32 observation vector for the current state."""
        return np.array([self.tel, self.ros, self.age / (self.max_steps + 1), self.cumulative_tel_actions / (self.max_steps + 1)], dtype=np.float32)

    def step(self, action: int):
        """
        Advance the cell by one time step.

        Returns (obs, reward, done, info). The reward is 0 until the terminal step.
        At the terminal step, info holds 'outcome' (int: 0 healthy, 1 senescence,
        2 apoptosis), 'probs' (list of 3 floats) and 'cancer_risk' (float).

        Raises:
            ValueError: if action is not 0 or 1.
            RuntimeError: if called after the episode ended; call reset() first.
        """
        if action not in (0, 1):
            raise ValueError(f"action must be 0 or 1, got {action!r}")
        if self.done:
            # Without this guard every extra call would re-sample and re-pay the terminal reward.
            raise RuntimeError("episode has ended; call reset() before step()")
        if action == 1:
            # Diminishing returns, floored at 0 so repeated activation never shortens telomeres.
            self.tel += max(0.0, 0.5 * (1.0 - 0.1 * self.cumulative_tel_actions))
            self.cumulative_tel_actions += 1.0
        # natural decline + noise
        self.tel -= 0.2 + 0.05 * self.rng.randn()
        self.ros += 0.05 * self.rng.randn()
        self.age += 1.0
        self.step_count += 1
        done = False
        reward = 0.0
        info = {}
        if self.step_count >= self.max_steps:
            tel_score = self.tel
            ros_score = -self.ros
            logits = np.array([1.2*tel_score + 0.2*ros_score,            # healthy
                               -0.5*tel_score + 0.6*(-ros_score),        # senescence
                               -0.8*ros_score + 0.1*(1.0 - tel_score)])  # apoptosis
            logits += 0.3 * self.rng.randn(3)
            probs = np.exp(logits - np.max(logits))  # numerically stable softmax
            probs = probs / probs.sum()
            outcome = self.rng.choice(3, p=probs)
            if outcome == 0:
                reward = 1.0
            elif outcome == 1:
                reward = -1.0
            else:
                reward = -2.0
            cancer_risk = max(0.0, (self.cumulative_tel_actions - 1.5) * 0.2)
            reward -= cancer_risk
            done = True
            info['outcome'] = int(outcome)
            info['probs'] = probs.tolist()
            info['cancer_risk'] = float(cancer_risk)
        obs = self._get_obs()
        self.done = done
        return obs, reward, done, info

# -------------------------
# PPO agent (policy + value)
# -------------------------
class MLPActorCritic(nn.Module):
    """
    Shared-trunk actor-critic for a binary action space.

    ``forward(x)`` with x of shape (batch, obs_dim) returns ``(logit, value)``,
    both of shape (batch,):
      - ``logit`` is the pre-sigmoid log-odds of choosing action 1.
      - ``value`` is the critic's state-value estimate.
    """
    def __init__(self, obs_dim, hidden=128):
        super().__init__()
        trunk = [nn.Linear(obs_dim, hidden), nn.ReLU(), nn.Linear(hidden, hidden), nn.ReLU()]
        self.shared = nn.Sequential(*trunk)

        def make_head():
            # hidden -> 64 -> 1 scalar head
            return nn.Sequential(nn.Linear(hidden, 64), nn.ReLU(), nn.Linear(64, 1))

        self.policy = make_head()
        self.value = make_head()

    def forward(self, x):
        features = self.shared(x)
        policy_logit = self.policy(features).squeeze(-1)
        state_value = self.value(features).squeeze(-1)
        return policy_logit, state_value

def compute_gae(rewards, values, dones, gamma=0.99, lam=0.95):
    """
    Generalized Advantage Estimation over a single trajectory.

    `values` normally has one more entry than `rewards` (the bootstrap value).
    If it does not, the missing next value is treated as 0. A done flag of 1.0
    stops both bootstrapping and advantage propagation across that step.

    Returns (advantages, returns) as float32 arrays, with returns = advantages + values[:T].
    """
    adv = np.zeros_like(rewards, dtype=np.float32)
    running = 0.0
    horizon = len(rewards)
    for t in range(horizon - 1, -1, -1):
        mask = 1.0 - dones[t]
        bootstrap = values[t + 1] if t + 1 < len(values) else 0.0
        td_error = rewards[t] + gamma * bootstrap * mask - values[t]
        running = td_error + gamma * lam * mask * running
        adv[t] = running
    return adv, adv + values[:len(adv)]

def train_ppo(env_ctor, steps=200, clip_eps=0.2, policy_lr=2e-4, gamma=0.99):
    """
    Train an MLPActorCritic with a simple single-episode PPO loop.

    Each of the `steps` iterations does the following:
      - Collect one full episode with the current policy.
      - Compute GAE advantages.
      - Perform 6 clipped-PPO gradient passes over that episode with one Adam
        optimizer on ``policy_loss + 0.5 * value_loss - 0.01 * entropy``.
    Progress is printed every 40 episodes.

    Args:
        env_ctor: zero-argument callable returning an env with ``reset()``,
            ``step(action)`` and ``_get_obs()`` (e.g. SimpleCellEnv).
        steps: number of episodes to train on.
        clip_eps: PPO probability-ratio clipping range.
        policy_lr: Adam learning rate for all parameters.
        gamma: discount factor passed to compute_gae.

    Returns:
        (actor_critic, list of per-episode total rewards).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    env = env_ctor()
    obs_dim = env._get_obs().shape[0]
    ac = MLPActorCritic(obs_dim).to(device)
    optimizer = optim.Adam(ac.parameters(), lr=policy_lr)
    ppo_passes = 6  # gradient passes over each collected episode
    all_rewards = []
    start_time = time.time()
    for ep in range(steps):
        obs = env.reset()
        traj_obs, traj_actions, traj_rewards, traj_values, traj_logps, traj_dones = [], [], [], [], [], []
        ep_reward = 0.0
        # Roll out one episode with the current (frozen) policy.
        while True:
            obs_t = torch.tensor(obs, dtype=torch.float32).unsqueeze(0).to(device)
            with torch.no_grad():
                logit, value = ac(obs_t)
                prob = torch.sigmoid(logit).cpu().numpy()[0]
            action = 1 if random.random() < prob else 0
            logp = math.log(prob + 1e-8) if action == 1 else math.log(1.0 - prob + 1e-8)
            next_obs, reward, done, info = env.step(action)
            traj_obs.append(obs.copy())
            traj_actions.append(action)
            traj_rewards.append(reward)
            traj_values.append(value.cpu().numpy()[0])
            traj_logps.append(logp)
            traj_dones.append(float(done))
            obs = next_obs
            ep_reward += reward
            if done:
                break
        # value of the state after the terminal step is taken as 0
        values_np = np.array(traj_values + [0.0], dtype=np.float32)
        advantages, returns = compute_gae(traj_rewards, values_np, traj_dones, gamma=gamma)
        obs_tensor = torch.tensor(np.array(traj_obs), dtype=torch.float32).to(device)
        actions_tensor = torch.tensor(np.array(traj_actions), dtype=torch.float32).to(device)
        old_logps = torch.tensor(np.array(traj_logps), dtype=torch.float32).to(device)
        advantages_tensor = torch.tensor(advantages, dtype=torch.float32).to(device)
        returns_tensor = torch.tensor(returns, dtype=torch.float32).to(device)
        # Normalize advantages. std() of a single element is NaN, which would
        # poison every parameter, so length-1 episodes keep their raw advantage.
        if advantages_tensor.numel() > 1:
            advantages_tensor = (advantages_tensor - advantages_tensor.mean()) / (advantages_tensor.std() + 1e-8)
        # PPO update
        for _ in range(ppo_passes):
            logits, values_pred = ac(obs_tensor)
            probs = torch.sigmoid(logits)
            new_logps = actions_tensor * torch.log(probs + 1e-8) + (1 - actions_tensor) * torch.log(1 - probs + 1e-8)
            ratio = torch.exp(new_logps - old_logps)
            surr1 = ratio * advantages_tensor
            surr2 = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages_tensor
            policy_loss = -torch.min(surr1, surr2).mean()
            value_loss = F.mse_loss(values_pred, returns_tensor)
            entropy = -(probs * torch.log(probs + 1e-8) + (1-probs) * torch.log(1-probs + 1e-8)).mean()
            loss = policy_loss + 0.5 * value_loss - 0.01 * entropy
            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(ac.parameters(), 0.5)
            optimizer.step()
        all_rewards.append(ep_reward)
        if (ep + 1) % 40 == 0:
            avg = np.mean(all_rewards[-40:])
            elapsed = time.time() - start_time
            print(f"Episode {ep+1}/{steps} | avg reward (last 40): {avg:.3f} | elapsed: {elapsed:.1f}s")
    return ac, all_rewards

# Run a short training session for demonstration
if __name__ == "__main__":
    ac_model, rewards = train_ppo(lambda: SimpleCellEnv(max_steps=8, seed=42), steps=200, policy_lr=2e-4)
    print("Training complete. Mean reward (last 50):", np.mean(rewards[-50:]))

    # Show a sample rollout (inference only).
    env = SimpleCellEnv(max_steps=8, seed=999)
    obs = env.reset()
    rollout = []
    # train_ppo may have placed the model on CUDA; build inputs on the same device.
    model_device = next(ac_model.parameters()).device
    for _ in range(20):
        obs_t = torch.tensor(obs, dtype=torch.float32, device=model_device).unsqueeze(0)
        with torch.no_grad():
            logit, _value = ac_model(obs_t)
            prob = torch.sigmoid(logit).item()
        action = 1 if random.random() < prob else 0
        next_obs, reward, done, info = env.step(action)
        rollout.append((obs.copy(), action, reward, done, info))
        obs = next_obs
        if done:
            break
    print("Sample rollout:")
    for entry in rollout:
        print(entry)


Episode 40/200 | avg reward (last 40): -1.353 | elapsed: 0.8s
Episode 80/200 | avg reward (last 40): -1.220 | elapsed: 1.8s
Episode 120/200 | avg reward (last 40): -1.080 | elapsed: 2.9s
Episode 160/200 | avg reward (last 40): -1.035 | elapsed: 3.9s
Episode 200/200 | avg reward (last 40): -1.090 | elapsed: 4.8s
Training complete. Mean reward (last 50): -1.0519999999999998
Sample rollout:
(array([0.0944445 , 0.02831017, 0.        , 0.        ], dtype=float32), 1, 0.0, False, {})
(array([ 0.40775123, -0.00413487,  0.11111111,  0.11111111], dtype=float32), 1, 0.0, False, {})
(array([ 0.57943785, -0.10870337,  0.22222222,  0.22222222], dtype=float32), 1, 0.0, False, {})
(array([ 0.7066214 , -0.06143871,  0.33333334,  0.33333334], dtype=float32), 1, 0.0, False, {})
(array([ 0.8766315 , -0.04567734,  0.44444445,  0.44444445], dtype=float32), 1, 0.0, False, {})
(array([ 1.0321345 , -0.07491842,  0.5555556 ,  0.5555556 ], dtype=float32), 1, 0.0, False, {})
(array([ 1.091555  , -0.03426724,  0.

In [None]:
# ppo_telomerase.py
# Requirements: torch, numpy
# pip install torch numpy
import math, random, time
from typing import Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# -------------------------
# Simple cell environment
# -------------------------
class SimpleCellEnv:
    """
    Minimal episodic environment representing a single cell's lifecycle across T steps.

    State: [telomere_proxy, ros, age_norm, cumulative_tel_actions_norm]
        where age and the activation count are divided by (max_steps + 1).
    Action: 0 (no telomerase), 1 (activate telomerase transiently)
    Reward: terminal only: +1 healthy, -1 senescence, -2 apoptosis, minus a cancer-risk
        penalty of 0.2 per activation beyond 1.5 (never negative).
    """
    def __init__(self, max_steps=10, seed=0):
        self.max_steps = max_steps
        self.rng = np.random.RandomState(seed)
        self.reset()

    def reset(self):
        """Sample a fresh cell, clear counters and return the initial observation."""
        self.tel = float(self.rng.normal(loc=0.0, scale=0.3))
        self.ros = float(self.rng.normal(loc=0.2, scale=0.2))
        self.age = 0.0
        self.cumulative_tel_actions = 0.0
        self.step_count = 0
        self.done = False
        return self._get_obs()

    def _get_obs(self):
        """Return the float32 observation vector described in the class docstring."""
        return np.array([
            self.tel,
            self.ros,
            self.age / (self.max_steps + 1),
            self.cumulative_tel_actions / (self.max_steps + 1)
        ], dtype=np.float32)

    def step(self, action: int):
        """
        Advance the cell by one timestep.

        Args:
            action: 1 to activate telomerase this step, 0 otherwise.

        Returns:
            (obs, reward, done, info). reward is 0.0 until the final step; on the final
            step info holds 'outcome', 'probs' and 'cancer_risk'.

        Raises:
            ValueError: if action is not 0 or 1.
            RuntimeError: if called after the episode has terminated (call reset()).
        """
        # Explicit checks instead of `assert`, which disappears under `python -O`.
        if action not in (0, 1):
            raise ValueError(f"action must be 0 or 1, got {action!r}")
        # Without this guard a finished episode would keep stepping and pay out
        # another terminal reward on every further call.
        if self.done:
            raise RuntimeError("step() called after the episode terminated; call reset() first")

        if action == 1:
            # transient restoration, diminishing returns
            self.tel += 0.5 * (1.0 - 0.1 * self.cumulative_tel_actions)
            self.cumulative_tel_actions += 1.0

        # natural decline + noise
        self.tel -= 0.2 + 0.05 * self.rng.randn()
        self.ros += 0.05 * self.rng.randn()
        self.age += 1.0
        self.step_count += 1

        reward = 0.0
        done = False
        info = {}

        if self.step_count >= self.max_steps:
            # outcome logits depend on tel and ros (toy model)
            tel_score = self.tel
            ros_score = -self.ros
            logits = np.array([
                1.2*tel_score + 0.2*ros_score,                      # healthy
                -0.5*tel_score + 0.6*(-ros_score),                  # senescence (short tel)
                -0.8*ros_score + 0.1*(1.0 - tel_score)              # apoptosis (high ros)
            ])
            logits += 0.3 * self.rng.randn(3)  # stochasticity
            # numerically stable softmax
            probs = np.exp(logits - np.max(logits))
            probs = probs / probs.sum()

            outcome = self.rng.choice(3, p=probs)
            if outcome == 0:
                reward = 1.0
            elif outcome == 1:
                reward = -1.0
            else:
                reward = -2.0

            # cancer risk penalty for repeated telomerase activations
            cancer_risk = max(0.0, (self.cumulative_tel_actions - 1.5) * 0.2)
            reward -= cancer_risk

            done = True
            info['outcome'] = int(outcome)
            info['probs'] = probs.tolist()
            info['cancer_risk'] = float(cancer_risk)

        obs = self._get_obs()
        self.done = done
        return obs, reward, done, info

# -------------------------
# PPO actor-critic model
# -------------------------
class MLPActorCritic(nn.Module):
    """
    Actor-critic MLP with a shared trunk and two scalar heads.

    forward(x) returns (logit, value) with the trailing size-1 dimension removed:
    `logit` is the Bernoulli logit for choosing action 1 and `value` is the
    state-value estimate.
    """
    def __init__(self, obs_dim, hidden=128):
        super().__init__()
        # Submodule attribute names and creation order are kept as-is so the
        # state_dict keys and seeded weight initialisation are unchanged.
        self.shared = nn.Sequential(
            nn.Linear(obs_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
        )
        self.policy = self._scalar_head(hidden)  # binary action logit
        self.value = self._scalar_head(hidden)

    @staticmethod
    def _scalar_head(in_features):
        """Two-layer MLP (in_features -> 64 -> 1) used for both heads."""
        return nn.Sequential(nn.Linear(in_features, 64), nn.ReLU(), nn.Linear(64, 1))

    def forward(self, x):
        features = self.shared(x)
        return self.policy(features).squeeze(-1), self.value(features).squeeze(-1)

# Generalized Advantage Estimation
def compute_gae(rewards, values, dones, gamma=0.99, lam=0.95):
    """
    Generalized Advantage Estimation over a single trajectory.

    Args:
        rewards: per-step rewards (length T).
        values: value estimates; may include one extra bootstrap entry (length T+1).
            Any step whose successor index is past the end of `values` bootstraps from 0.
        dones: per-step terminal flags (1.0 cuts both the bootstrap and the GAE carry).
        gamma: discount factor.
        lam: GAE lambda.

    Returns:
        (advantages, returns): float32 advantages and returns = advantages + values[:T].
    """
    advantages = np.zeros_like(rewards, dtype=np.float32)
    running = 0.0
    n_values = len(values)
    for t in range(len(rewards) - 1, -1, -1):
        mask = 1.0 - dones[t]
        bootstrap = values[t + 1] if t + 1 < n_values else 0.0
        td_error = rewards[t] + gamma * bootstrap * mask - values[t]
        running = td_error + gamma * lam * mask * running
        advantages[t] = running
    return advantages, advantages + values[:len(advantages)]

# -------------------------
# PPO training loop
# -------------------------
def train_ppo(env_ctor, steps=1000, clip_eps=0.2, policy_lr=2e-4, gamma=0.99,
              ppo_epochs=6, log_interval=50):
    """
    Train an MLPActorCritic with single-episode PPO on a binary-action environment.

    Each iteration collects one full episode with the current policy, computes GAE
    advantages, then takes `ppo_epochs` clipped-surrogate gradient steps on it.

    Args:
        env_ctor: zero-argument callable returning an env exposing reset() -> obs,
            step(action) -> (obs, reward, done, info) and _get_obs().
        steps: number of training episodes (not environment steps).
        clip_eps: PPO probability-ratio clip range.
        policy_lr: Adam learning rate shared by the actor and critic heads.
        gamma: discount factor passed to compute_gae.
        ppo_epochs: gradient steps taken on each collected episode.
        log_interval: print progress every `log_interval` episodes (0/None disables).

    Returns:
        (ac, all_rewards): the trained model, left on the training device (CUDA when
        available), and the list of undiscounted per-episode rewards.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    env = env_ctor()
    obs_dim = env._get_obs().shape[0]
    ac = MLPActorCritic(obs_dim).to(device)
    optimizer = optim.Adam(ac.parameters(), lr=policy_lr)
    all_rewards = []
    start_time = time.time()

    for ep in range(steps):
        obs = env.reset()
        traj_obs, traj_actions, traj_rewards, traj_values, traj_logps, traj_dones = [], [], [], [], [], []
        ep_reward = 0.0

        # collect one episode (you can vectorize for efficiency)
        while True:
            obs_t = torch.tensor(obs, dtype=torch.float32).unsqueeze(0).to(device)
            with torch.no_grad():
                logit, value = ac(obs_t)
                prob = torch.sigmoid(logit).cpu().numpy()[0]
            action = 1 if random.random() < prob else 0
            # log-prob of the sampled action under the behaviour policy (PPO ratio denominator)
            logp = math.log(prob + 1e-8) if action == 1 else math.log(1.0 - prob + 1e-8)

            next_obs, reward, done, info = env.step(action)
            traj_obs.append(obs.copy())
            traj_actions.append(action)
            traj_rewards.append(reward)
            traj_values.append(value.cpu().numpy()[0])
            traj_logps.append(logp)
            traj_dones.append(float(done))

            obs = next_obs
            ep_reward += reward
            if done:
                break

        # compute GAE (terminal bootstrap value 0)
        values_np = np.array(traj_values + [0.0], dtype=np.float32)
        advantages, returns = compute_gae(traj_rewards, values_np, traj_dones, gamma=gamma)

        # tensors
        obs_tensor = torch.tensor(np.array(traj_obs), dtype=torch.float32).to(device)
        actions_tensor = torch.tensor(np.array(traj_actions), dtype=torch.float32).to(device)
        old_logps = torch.tensor(np.array(traj_logps), dtype=torch.float32).to(device)
        advs = torch.tensor(advantages, dtype=torch.float32).to(device)
        rets = torch.tensor(returns, dtype=torch.float32).to(device)
        # Normalise advantages. The default (unbiased) std of a single element is NaN,
        # which would turn every weight into NaN for one-step episodes, so skip it then.
        if advs.numel() > 1:
            advs = (advs - advs.mean()) / (advs.std() + 1e-8)

        # PPO update: multiple epochs over this episode's data
        for _ in range(ppo_epochs):
            logits, values_pred = ac(obs_tensor)
            probs = torch.sigmoid(logits)
            new_logps = actions_tensor * torch.log(probs + 1e-8) + (1 - actions_tensor) * torch.log(1 - probs + 1e-8)
            ratio = torch.exp(new_logps - old_logps)
            surr1 = ratio * advs
            surr2 = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advs
            policy_loss = -torch.min(surr1, surr2).mean()
            value_loss = F.mse_loss(values_pred, rets)
            entropy = -(probs * torch.log(probs + 1e-8) + (1-probs) * torch.log(1-probs + 1e-8)).mean()
            loss = policy_loss + 0.5 * value_loss - 0.01 * entropy

            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(ac.parameters(), 0.5)
            optimizer.step()

        all_rewards.append(ep_reward)
        # logging
        if log_interval and (ep + 1) % log_interval == 0:
            avg = np.mean(all_rewards[-log_interval:])
            elapsed = time.time() - start_time
            print(f"Episode {ep+1}/{steps} | avg reward (last {log_interval}): {avg:.3f} | elapsed: {elapsed:.1f}s")
    return ac, all_rewards

# -------------------------
# Run training (adjust steps for longer training)
# -------------------------
if __name__ == "__main__":
    model, rewards = train_ppo(lambda: SimpleCellEnv(max_steps=8, seed=42), steps=400, policy_lr=2e-4)
    print("Training complete. Mean reward (last 50):", np.mean(rewards[-50:]))

    # Show a sample rollout.
    # train_ppo leaves the model on CUDA when one is available, so observations are
    # created on the model's device to avoid a CPU/GPU mismatch in the forward pass.
    device = next(model.parameters()).device
    env = SimpleCellEnv(max_steps=8, seed=1234)
    obs = env.reset()
    rollout = []
    for _ in range(20):  # upper bound; the env reports done after max_steps steps
        obs_t = torch.tensor(obs, dtype=torch.float32, device=device).unsqueeze(0)
        with torch.no_grad():
            logit, _ = model(obs_t)
            prob = torch.sigmoid(logit).item()
        action = 1 if random.random() < prob else 0
        next_obs, reward, done, info = env.step(action)
        rollout.append((obs.copy(), action, reward, done, info))
        obs = next_obs
        if done:
            break
    print("Sample rollout:")
    for step in rollout:
        print(step)


Episode 50/400 | avg reward (last 50): -1.178 | elapsed: 1.0s
Episode 100/400 | avg reward (last 50): -0.996 | elapsed: 1.9s
Episode 150/400 | avg reward (last 50): -1.152 | elapsed: 3.0s
Episode 200/400 | avg reward (last 50): -1.056 | elapsed: 4.0s
Episode 250/400 | avg reward (last 50): -1.260 | elapsed: 5.1s
Episode 300/400 | avg reward (last 50): -1.452 | elapsed: 6.1s
Episode 350/400 | avg reward (last 50): -0.712 | elapsed: 7.2s
Episode 400/400 | avg reward (last 50): -1.484 | elapsed: 8.3s
Training complete. Mean reward (last 50): -1.4839999999999998
Sample rollout:
(array([0.4298121 , 0.13746962, 0.        , 0.        ], dtype=float32), 1, 0.0, False, {})
(array([0.76584154, 0.18182777, 0.11111111, 0.11111111], dtype=float32), 1, 0.0, False, {})
(array([0.9728621 , 0.15000159, 0.22222222, 0.22222222], dtype=float32), 1, 0.0, False, {})
(array([1.1720773 , 0.03786734, 0.33333334, 0.33333334], dtype=float32), 1, 0.0, False, {})
(array([1.2645755 , 0.08746465, 0.44444445, 0.44444

In [None]:
import time, random, math
from typing import List, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# -------------------------
# Multi-cell tissue environment
# -------------------------
class Cell:
    """
    A single simulated cell for the (identity-free) tissue model.

    Attributes:
        tel: telomere proxy.
        ros: oxidative-stress level.
        age: number of steps taken since the last reset.
        cumulative_tel: number of telomerase activations since the last reset.

    All randomness is drawn from the `rng` passed in, which callers may share
    between cells, so the order of draws below is significant.
    """
    def __init__(self, rng):
        self.rng = rng
        self.reset()

    def reset(self):
        """Resample tel/ros and zero the age and activation counters."""
        draw = self.rng.normal
        self.tel = float(draw(loc=0.0, scale=0.3))
        self.ros = float(draw(loc=0.2, scale=0.2))
        self.age, self.cumulative_tel = 0.0, 0.0

    def step(self, action):
        """Advance one timestep; action == 1 activates telomerase for this cell."""
        if action == 1:
            # gain shrinks by 10% of the base for every earlier activation
            diminishing = 1.0 - 0.1 * self.cumulative_tel
            self.tel += 0.5 * diminishing
            self.cumulative_tel += 1.0
        # telomere loss is drawn before the ROS noise (shared-RNG ordering)
        self.tel -= 0.15 + 0.05 * self.rng.randn()
        self.ros += 0.03 * self.rng.randn()
        self.age += 1.0

    def get_obs(self, max_steps):
        """Return float32 [tel, ros, age / (max_steps + 1), cumulative_tel]."""
        features = (self.tel, self.ros, self.age / (max_steps + 1), self.cumulative_tel)
        return np.array(features, dtype=np.float32)


class TissueEnv:
    """
    Tissue of M cells. Episode runs T steps. At the end, each cell's outcome
    (healthy, senescent, apoptotic) is sampled from a softmax over noisy logits
    computed from its tel and ros.

    Reward: 0.0 on non-terminal steps. On the final step it is the sum over cells of
    (+1 healthy, -1 senescent, -2 apoptotic), minus a cancer penalty. Each cell
    contributes 0.3 per telomerase activation beyond 1.2, and never a negative amount.

    `budget` is stored for callers (e.g. the trainer) to enforce; step() does not check it.
    """
    def __init__(self, M=10, T=6, budget=3, seed=0):
        self.M = M
        self.T = T
        self.budget = budget
        self.rng = np.random.RandomState(seed)
        # NOTE: reset() replaces these cells, but building them here also consumes RNG
        # draws that seeded runs depend on, so it is kept.
        self.cells = [Cell(self.rng) for _ in range(M)]
        self.step_count = 0
        self.done = False

    def reset(self):
        """Create M fresh cells and return the initial (per_cell, global_feat) observation."""
        self.cells = [Cell(self.rng) for _ in range(self.M)]
        self.step_count = 0
        self.done = False
        return self._get_obs()

    def _get_obs(self):
        """
        Returns:
            per_cell: (M, 4) float32 array stacked from each Cell.get_obs(T).
            global_feat: float32 [mean tel, mean ros, step_count / (T + 1)].
        """
        per_cell = np.stack([c.get_obs(self.T) for c in self.cells], axis=0)  # (M,4)
        mean_tel = per_cell[:,0].mean()
        mean_ros = per_cell[:,1].mean()
        step_frac = self.step_count / (self.T + 1)
        global_feat = np.array([mean_tel, mean_ros, step_frac], dtype=np.float32)
        return per_cell, global_feat

    def step(self, actions: List[int]):
        """
        Apply one 0/1 action per cell and advance the tissue by one timestep.

        Returns:
            ((per_cell, global_feat), reward, done, info). On the final step info holds
            'per_cell_rewards' and 'cancer_penalty'.

        Raises:
            ValueError: if len(actions) != M.
            RuntimeError: if called after the episode has ended (call reset()).
        """
        if len(actions) != self.M:
            raise ValueError(f"expected {self.M} actions, got {len(actions)}")
        # Without this guard, steps past T would re-score outcomes and pay another
        # terminal reward on every call.
        if self.done:
            raise RuntimeError("step() called after the episode terminated; call reset() first")
        for c, a in zip(self.cells, actions):
            c.step(int(a))
        self.step_count += 1
        done = self.step_count >= self.T
        reward = 0.0
        info = {}
        if done:
            # compute cell outcomes
            rewards = []
            cancer_penalty = 0.0
            for c in self.cells:
                tel_score = c.tel
                ros_score = -c.ros
                logits = np.array([1.1*tel_score + 0.2*ros_score,
                                   -0.6*tel_score + 0.5*(-ros_score),
                                   -0.9*ros_score + 0.05*(1.0 - tel_score)])
                logits += 0.25 * self.rng.randn(3)
                # numerically stable softmax
                probs = np.exp(logits - np.max(logits))
                probs = probs / probs.sum()
                outcome = self.rng.choice(3, p=probs)
                if outcome == 0:
                    r = 1.0
                elif outcome == 1:
                    r = -1.0
                else:
                    r = -2.0
                rewards.append(r)
                # cancer risk if cumulative tel activations exceed threshold (per-cell)
                cancer_penalty += max(0.0, (c.cumulative_tel - 1.2) * 0.3)
            reward = float(sum(rewards) - cancer_penalty)
            info['per_cell_rewards'] = rewards
            info['cancer_penalty'] = cancer_penalty
        obs = self._get_obs()
        self.done = done
        return obs, reward, done, info

# -------------------------
# Policy / value network for tissue
# -------------------------
class TissuePolicy(nn.Module):
    """
    Per-cell Bernoulli policy plus a single whole-tissue value estimate.

    A shared MLP encodes every cell and a second MLP encodes the global summary.
    Each cell's logit is read from [cell encoding, global encoding]; the value is
    read from [mean of the cell encodings, global encoding].
    """
    def __init__(self, per_cell_dim=4, global_dim=3, hidden=128):
        super().__init__()
        half = hidden // 2
        # Creation order (cell_enc, global_enc, policy_head, value_head) matches the
        # original so seeded initialisation and state_dict keys are unchanged.
        self.cell_enc = nn.Sequential(
            nn.Linear(per_cell_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, half), nn.ReLU(),
        )
        self.global_enc = nn.Sequential(nn.Linear(global_dim, half), nn.ReLU())
        self.policy_head = nn.Sequential(
            nn.Linear(half + half, half), nn.ReLU(),
            nn.Linear(half, 1),  # logit per cell
        )
        self.value_head = nn.Sequential(
            nn.Linear(half + half, hidden), nn.ReLU(),
            nn.Linear(hidden, 1),
        )

    def forward(self, per_cell_obs, global_feat):
        """
        Args:
            per_cell_obs: 2-D tensor with one row per cell.
            global_feat: 1-D global summary vector.

        Returns:
            (logits, value): one logit per cell, and a 0-d tissue value.
        """
        n_cells = per_cell_obs.shape[0]
        cell_codes = self.cell_enc(per_cell_obs)
        global_code = self.global_enc(global_feat)  # encoded once, shared by both heads
        tiled_global = global_code.unsqueeze(0).expand(n_cells, -1)
        logits = self.policy_head(torch.cat([cell_codes, tiled_global], dim=-1)).squeeze(-1)
        tissue_summary = torch.cat([cell_codes.mean(dim=0), global_code], dim=-1)
        value = self.value_head(tissue_summary).squeeze(-1)
        return logits, value

# -------------------------
# PPO utilities (GAE, etc.)
# -------------------------
def compute_gae_seq(rewards, values, dones, gamma=0.99, lam=0.95):
    """
    Generalized Advantage Estimation for one episode of scalar rewards.

    Args:
        rewards: length-T rewards.
        values: value estimates, optionally with a trailing bootstrap entry (T+1).
            A missing successor value is treated as 0.
        dones: length-T terminal flags; 1.0 stops bootstrapping and resets the carry.
        gamma: discount factor.
        lam: GAE lambda.

    Returns:
        (adv, returns): float32 advantages and returns = adv + values[:T].
    """
    horizon = len(rewards)
    adv = np.zeros(horizon, dtype=np.float32)
    carry = 0.0
    t = horizon - 1
    while t >= 0:
        not_done = 1.0 - dones[t]
        v_next = values[t + 1] if t + 1 < len(values) else 0.0
        delta = rewards[t] + gamma * v_next * not_done - values[t]
        carry = delta + gamma * lam * not_done * carry
        adv[t] = carry
        t -= 1
    return adv, adv + values[:horizon]

# -------------------------
# Training loop for PPO on TissueEnv
# -------------------------
def train_tissue_ppo(env_ctor, episodes=500, budget=3, clip_eps=0.2, policy_lr=3e-4, gamma=0.99,
                     ppo_epochs=6, log_interval=50):
    """
    Train a TissuePolicy with single-episode PPO on a TissueEnv-like environment.

    At each timestep the policy emits one Bernoulli logit per cell. Actions are sampled
    independently. If more than `budget` cells fire, they are replaced by the `budget`
    cells with the highest activation probability. A timestep's joint log-probability
    is the sum of the per-cell Bernoulli log-probabilities.

    Args:
        env_ctor: zero-argument callable returning an env with attribute M,
            reset() -> (per_cell_obs, global_feat), _get_obs(), and
            step(actions) -> ((per_cell_obs, global_feat), reward, done, info).
        episodes: number of training episodes.
        budget: maximum activations per timestep, enforced here (not by the env).
        clip_eps: PPO ratio clip range.
        policy_lr: Adam learning rate.
        gamma: discount factor passed to compute_gae_seq.
        ppo_epochs: gradient steps per collected episode.
        log_interval: print progress every `log_interval` episodes (0/None disables).

    Returns:
        (policy, all_rewards): the trained model, left on the training device (CUDA when
        available), and the per-episode undiscounted rewards.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    env = env_ctor()
    M = env.M
    init_cells, init_global = env._get_obs()
    obs_dim = init_cells.shape[1]
    global_dim = init_global.shape[0]
    policy = TissuePolicy(per_cell_dim=obs_dim, global_dim=global_dim).to(device)
    optimizer = optim.Adam(policy.parameters(), lr=policy_lr)
    all_rewards = []
    start_time = time.time()
    eps = 1e-8

    for ep in range(episodes):
        per_cell_obs, global_feat = env.reset()
        traj_obs_cells, traj_globals, traj_actions = [], [], []
        traj_logps, traj_values, traj_rewards, traj_dones = [], [], [], []
        ep_reward = 0.0

        while True:
            per_cell_tensor = torch.tensor(per_cell_obs, dtype=torch.float32).to(device)  # (M, per_cell_dim)
            global_tensor = torch.tensor(global_feat, dtype=torch.float32).to(device)     # (global_dim,)
            # Acting needs no autograd graph; the update below recomputes everything.
            with torch.no_grad():
                logits, value = policy(per_cell_tensor, global_tensor)
            probs = torch.sigmoid(logits).cpu().numpy()
            # sample bernoulli per cell
            sampled = (np.random.rand(M) < probs).astype(np.int32)
            if sampled.sum() > budget:
                # NOTE(review): the top-k replacement is deterministic, yet its log-prob is
                # still scored as independent Bernoullis below, so the PPO ratio is only an
                # approximation whenever the budget binds.
                topk_idx = np.argsort(-probs)[:budget]
                actions = np.zeros(M, dtype=np.int32)
                actions[topk_idx] = 1
            else:
                actions = sampled

            # joint log-prob = sum of per-cell Bernoulli log-probs
            logps = actions * np.log(probs + eps) + (1 - actions) * np.log(1 - probs + eps)
            logp_sum = float(logps.sum())
            (next_per_cell_obs, next_global), reward, done, info = env.step(actions.tolist())
            traj_obs_cells.append(per_cell_obs.copy())
            traj_globals.append(global_feat.copy())
            traj_actions.append(actions.copy())
            traj_logps.append(logp_sum)
            traj_values.append(float(value.cpu().numpy()))
            traj_rewards.append(reward)
            traj_dones.append(float(done))
            ep_reward += reward

            per_cell_obs, global_feat = next_per_cell_obs, next_global
            if done:
                break

        # compute GAE for this episode (single scalar reward per timestep, terminal bootstrap 0)
        values_np = np.array(traj_values + [0.0], dtype=np.float32)
        advs, returns = compute_gae_seq(traj_rewards, values_np, traj_dones, gamma=gamma)

        obs_cells_tensor = torch.tensor(np.array(traj_obs_cells), dtype=torch.float32).to(device)  # (T, M, per_cell_dim)
        globals_tensor = torch.tensor(np.array(traj_globals), dtype=torch.float32).to(device)      # (T, global_dim)
        actions_tensor = torch.tensor(np.array(traj_actions), dtype=torch.float32).to(device)     # (T, M)
        old_logps_tensor = torch.tensor(np.array(traj_logps), dtype=torch.float32).to(device)    # (T,)
        advs_tensor = torch.tensor(advs, dtype=torch.float32).to(device)
        returns_tensor = torch.tensor(returns, dtype=torch.float32).to(device)
        # The default (unbiased) std of a single element is NaN, so skip normalisation for T == 1.
        if advs_tensor.numel() > 1:
            advs_tensor = (advs_tensor - advs_tensor.mean()) / (advs_tensor.std() + 1e-8)

        horizon = len(traj_obs_cells)
        for _ in range(ppo_epochs):
            # One forward pass per timestep yields log-probs, values and entropy together.
            new_logps, values_pred, entropies = [], [], []
            for t in range(horizon):
                logits_t, val_t = policy(obs_cells_tensor[t], globals_tensor[t])
                probs_t = torch.sigmoid(logits_t)
                log_on = torch.log(probs_t + 1e-8)
                log_off = torch.log(1 - probs_t + 1e-8)
                actions_t = actions_tensor[t]
                new_logps.append((actions_t * log_on + (1 - actions_t) * log_off).sum())
                entropies.append(-(probs_t * log_on + (1 - probs_t) * log_off).sum())
                values_pred.append(val_t)
            new_logps = torch.stack(new_logps)      # (T,)
            # val_t is 0-d, so stacking already gives (T,); no squeeze (keeps T == 1 as (1,)).
            values_pred = torch.stack(values_pred)  # (T,)

            ratio = torch.exp(new_logps - old_logps_tensor)
            surr1 = ratio * advs_tensor
            surr2 = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advs_tensor
            policy_loss = -torch.min(surr1, surr2).mean()
            value_loss = F.mse_loss(values_pred, returns_tensor)
            # entropy summed across cells, averaged over timesteps
            entropy_term = torch.stack(entropies).mean()
            loss = policy_loss + 0.5 * value_loss - 0.01 * entropy_term

            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(policy.parameters(), 1.0)
            optimizer.step()

        all_rewards.append(ep_reward)
        if log_interval and (ep + 1) % log_interval == 0:
            avg = np.mean(all_rewards[-log_interval:])
            elapsed = time.time() - start_time
            print(f"Episode {ep+1}/{episodes} | avg tissue reward (last{log_interval}): {avg:.3f} | elapsed: {elapsed:.1f}s")

    return policy, all_rewards

# -------------------------
# Demo run (short)
# -------------------------
if __name__ == "__main__":
    def env_ctor():
        return TissueEnv(M=10, T=6, budget=3, seed=42)

    policy_model, rewards = train_tissue_ppo(env_ctor, episodes=300, budget=3, policy_lr=3e-4)
    print("Training finished. Mean reward last 50:", np.mean(rewards[-50:]))

    # Show one sample rollout from trained policy.
    # train_tissue_ppo leaves the model on CUDA when available: build inputs on the
    # model's device and copy probabilities back to the CPU before using NumPy.
    device = next(policy_model.parameters()).device
    env = env_ctor()
    per_cell_obs, global_feat = env.reset()
    rollout = []
    for t in range(env.T):
        per_cell_tensor = torch.tensor(per_cell_obs, dtype=torch.float32, device=device)
        global_tensor = torch.tensor(global_feat, dtype=torch.float32, device=device)
        with torch.no_grad():
            logits, val = policy_model(per_cell_tensor, global_tensor)
        probs = torch.sigmoid(logits).cpu().numpy()
        sampled = (np.random.rand(env.M) < probs).astype(int)
        if sampled.sum() > env.budget:
            # same top-k budget rule as used during training
            topk = np.argsort(-probs)[:env.budget]
            actions = np.zeros(env.M, dtype=int)
            actions[topk] = 1
        else:
            actions = sampled
        (per_cell_obs, global_feat), reward, done, info = env.step(actions.tolist())
        rollout.append((actions.copy(), reward, info))
        if done:
            break

    print("Sample rollout (actions per timestep, reward, info):")
    for step in rollout:
        print(step)

Episode 50/300 | avg tissue reward (last50): -11.025 | elapsed: 5.8s
Episode 100/300 | avg tissue reward (last50): -10.964 | elapsed: 10.7s
Episode 150/300 | avg tissue reward (last50): -11.815 | elapsed: 16.5s
Episode 200/300 | avg tissue reward (last50): -12.072 | elapsed: 21.5s
Episode 250/300 | avg tissue reward (last50): -12.682 | elapsed: 27.8s
Episode 300/300 | avg tissue reward (last50): -11.736 | elapsed: 32.8s
Training finished. Mean reward last 50: -11.736400000000001
Sample rollout (actions per timestep, reward, info):
(array([0, 0, 0, 1, 0, 1, 0, 0, 0, 1]), 0.0, {})
(array([0, 0, 0, 1, 0, 1, 0, 0, 0, 1]), 0.0, {})
(array([0, 0, 0, 1, 0, 1, 0, 0, 0, 1]), 0.0, {})
(array([0, 0, 0, 1, 0, 1, 0, 0, 0, 1]), 0.0, {})
(array([0, 0, 0, 1, 0, 1, 0, 0, 0, 1]), 0.0, {})
(array([0, 0, 0, 1, 0, 1, 0, 0, 0, 1]), -13.32, {'per_cell_rewards': [-2.0, -1.0, -2.0, -1.0, -1.0, 1.0, 1.0, -2.0, -1.0, -1.0], 'cancer_penalty': 4.32})


In [None]:
pip install numpy
pip install torch   # choose the correct CUDA/CPU wheel for your machine


In [None]:
# Multi-cell tissue environment with per-cell identity (stem vs differentiated)
# and PPO training where the policy can learn to preferentially target stem cells.
# Educational demo. Save as `tissue_with_id.py` and run locally.
# Requirements: torch, numpy
# pip install torch numpy

import time, random, math
from typing import List, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# -------------------------
# Cell and Tissue environment with identity
# -------------------------
class Cell:
    """
    A simulated cell with a fixed identity: stem (is_stem=True) or differentiated.

    Identity sets the initial tel/ros distributions, the telomerase gain, the
    per-step telomere loss and the ROS noise scale. It is exposed to the policy as
    the last observation feature. Randomness comes from the shared `rng`, so the
    order of draws below matters.
    """
    def __init__(self, rng, is_stem: bool):
        self.rng = rng
        self.is_stem = bool(is_stem)
        self.reset()

    def reset(self):
        """Resample tel/ros from the identity-specific distributions and clear counters."""
        # stem cells start with longer telomeres and lower ROS on average
        if self.is_stem:
            tel_mu, tel_sd, ros_mu, ros_sd = 0.2, 0.25, 0.15, 0.15
        else:
            tel_mu, tel_sd, ros_mu, ros_sd = 0.0, 0.3, 0.25, 0.2
        self.tel = float(self.rng.normal(loc=tel_mu, scale=tel_sd))
        self.ros = float(self.rng.normal(loc=ros_mu, scale=ros_sd))
        self.age = 0.0
        self.cumulative_tel = 0.0

    def step(self, action):
        """Advance one timestep; action == 1 activates telomerase for this cell."""
        stem = self.is_stem
        if action == 1:
            # stem cells gain more; the gain shrinks 8% of the base per earlier activation
            gain = 0.6 if stem else 0.45
            self.tel += gain * (1.0 - 0.08 * self.cumulative_tel)
            self.cumulative_tel += 1.0
        # stem telomeres shorten more slowly
        shorten = 0.12 if stem else 0.18
        self.tel -= shorten + 0.04 * self.rng.randn()
        # ROS is a zero-mean random walk; identity changes only its noise scale
        # (differentiated cells fluctuate more), not its drift
        ros_sd = 0.02 if stem else 0.04
        self.ros += ros_sd * self.rng.randn()
        self.age += 1.0

    def get_obs(self, max_steps):
        """Return float32 [tel, ros, age/(max_steps+1), cumulative_tel, is_stem (1.0/0.0)]."""
        identity = float(self.is_stem)
        features = [self.tel, self.ros, self.age / (max_steps + 1), self.cumulative_tel, identity]
        return np.array(features, dtype=np.float32)


class TissueEnvWithID:
    """
    Tissue environment with per-cell identity. M cells, episode length T.

    Identity: max(1, round(M * stem_fraction)) randomly chosen cells are stem cells.
    The choice is re-drawn on every reset(). Identity reaches the policy as a
    per-cell input feature.

    Reward: 0.0 on non-terminal steps. On the final step it is the sum over cells of
    (+1 healthy, -1 senescent, -2 apoptotic) minus a cancer penalty. The penalty is
    weighted 0.5 for stem cells and 0.2 for differentiated cells.

    `budget` (intended activations per step) is stored for callers to enforce;
    step() does not check it.
    """
    def __init__(self, M=12, T=6, budget=3, stem_fraction=0.3, seed=0):
        self.M = M
        self.T = T
        self.budget = budget
        self.rng = np.random.RandomState(seed)
        self.stem_fraction = stem_fraction
        self.cells = []
        self.step_count = 0
        self.done = False
        self._make_cells()

    def _make_cells(self):
        """Draw a new stem/differentiated assignment and build M fresh cells."""
        self.cells = []
        num_stem = max(1, int(round(self.M * self.stem_fraction)))
        stem_indices = set(self.rng.choice(self.M, size=num_stem, replace=False))
        for i in range(self.M):
            is_stem = i in stem_indices
            self.cells.append(Cell(self.rng, is_stem))

    def reset(self):
        """Rebuild the tissue and return the initial (per_cell, global_feat) observation."""
        self._make_cells()
        # NOTE: Cell.__init__ already calls reset(); this second pass re-samples tel/ros.
        # It is kept because removing it would change seeded trajectories.
        for c in self.cells:
            c.reset()
        self.step_count = 0
        self.done = False
        return self._get_obs()

    def _get_obs(self):
        """
        Returns:
            per_cell: (M, 5) float32 array stacked from each Cell.get_obs(T).
            global_feat: float32 [mean tel, mean ros, stem fraction, step_count / (T + 1)].
        """
        per_cell = np.stack([c.get_obs(self.T) for c in self.cells], axis=0)  # (M,5)
        mean_tel = per_cell[:,0].mean()
        mean_ros = per_cell[:,1].mean()
        stem_frac = np.mean([1.0 if c.is_stem else 0.0 for c in self.cells])
        step_frac = self.step_count / (self.T + 1)
        global_feat = np.array([mean_tel, mean_ros, stem_frac, step_frac], dtype=np.float32)
        return per_cell, global_feat

    def step(self, actions: List[int]):
        """
        Apply one 0/1 action per cell and advance the tissue by one timestep.

        Returns:
            ((per_cell, global_feat), reward, done, info). On the final step info holds
            'per_cell_rewards' and 'cancer_penalty'.

        Raises:
            ValueError: if len(actions) != M.
            RuntimeError: if called after the episode has ended (call reset()).
        """
        if len(actions) != self.M:
            raise ValueError(f"expected {self.M} actions, got {len(actions)}")
        # Prevent steps past T from re-scoring outcomes and paying another terminal reward.
        if self.done:
            raise RuntimeError("step() called after the episode terminated; call reset() first")
        for c, a in zip(self.cells, actions):
            c.step(int(a))
        self.step_count += 1
        done = self.step_count >= self.T
        reward = 0.0
        info = {}
        if done:
            rewards = []
            cancer_penalty = 0.0
            # make cancer risk heavier for stem cells (biologically plausible)
            for c in self.cells:
                tel_score = c.tel
                ros_score = -c.ros
                logits = np.array([1.2*tel_score + 0.15*ros_score,
                                   -0.6*tel_score + 0.5*(-ros_score),
                                   -0.9*ros_score + 0.05*(1.0 - tel_score)])
                logits += 0.25 * self.rng.randn(3)
                # numerically stable softmax
                probs = np.exp(logits - np.max(logits))
                probs = probs / probs.sum()
                outcome = self.rng.choice(3, p=probs)
                if outcome == 0:
                    r = 1.0
                elif outcome == 1:
                    r = -1.0
                else:
                    r = -2.0
                rewards.append(r)
                # cancer risk weight higher for stem cells
                risk_weight = 0.5 if c.is_stem else 0.2
                cancer_penalty += risk_weight * max(0.0, (c.cumulative_tel - 1.2) * 0.3)
            reward = float(sum(rewards) - cancer_penalty)
            info['per_cell_rewards'] = rewards
            info['cancer_penalty'] = cancer_penalty
        obs = self._get_obs()
        self.done = done
        return obs, reward, done, info

# -------------------------
# Policy / value network updated to accept per-cell identity
# -------------------------
class TissuePolicyWithID(nn.Module):
    """
    Identity-aware tissue policy: one Bernoulli logit per cell and a scalar tissue value.

    The network shape matches the identity-free tissue policy. Identity enters only
    as an extra per-cell input feature (the last one), hence per_cell_dim defaults
    to 5 and global_dim to 4.
    """
    def __init__(self, per_cell_dim=5, global_dim=4, hidden=128):
        super().__init__()
        width = hidden // 2
        # keep creation order cell_enc -> global_enc -> policy_head -> value_head so
        # seeded initialisation and state_dict keys are unchanged
        self.cell_enc = nn.Sequential(
            nn.Linear(per_cell_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, width), nn.ReLU(),
        )
        self.global_enc = nn.Sequential(nn.Linear(global_dim, width), nn.ReLU())
        self.policy_head = nn.Sequential(
            nn.Linear(2 * width, width), nn.ReLU(),
            nn.Linear(width, 1),  # logit per cell
        )
        self.value_head = nn.Sequential(
            nn.Linear(2 * width, hidden), nn.ReLU(),
            nn.Linear(hidden, 1),
        )

    def forward(self, per_cell_obs, global_feat):
        """
        Args:
            per_cell_obs: 2-D tensor, one row per cell.
            global_feat: 1-D global summary vector.

        Returns:
            (logits, value): per-cell logits and a 0-d tissue value.
        """
        encoded_cells = self.cell_enc(per_cell_obs)
        encoded_global = self.global_enc(global_feat)  # computed once for both heads
        repeated = encoded_global.unsqueeze(0).expand(per_cell_obs.shape[0], -1)
        policy_in = torch.cat([encoded_cells, repeated], dim=-1)
        value_in = torch.cat([encoded_cells.mean(dim=0), encoded_global], dim=-1)
        return self.policy_head(policy_in).squeeze(-1), self.value_head(value_in).squeeze(-1)

# -------------------------
# PPO utilities and training (similar to earlier tissue PPO)
# -------------------------
def compute_gae_seq(rewards, values, dones, gamma=0.99, lam=0.95):
    """
    Generalized Advantage Estimation for one episode of scalar rewards.

    `values` may carry a trailing bootstrap entry (length T+1); a successor index
    past its end bootstraps from 0. A done flag of 1.0 at step t cuts both the
    bootstrap and the accumulated GAE term.

    Returns:
        (advantages, returns): float32 advantages of length T and
        returns = advantages + values[:T].
    """
    n = len(rewards)
    advantages = np.zeros(n, dtype=np.float32)
    gae = 0.0
    for t in reversed(range(n)):
        cont = 1.0 - dones[t]
        next_v = values[t + 1] if t + 1 < len(values) else 0.0
        gae = (rewards[t] + gamma * next_v * cont - values[t]) + gamma * lam * cont * gae
        advantages[t] = gae
    returns = advantages + values[:n]
    return advantages, returns

def train_tissue_ppo_with_id(env_ctor, episodes=400, budget=3, clip_eps=0.2, policy_lr=3e-4, gamma=0.99):
    """Train a ``TissuePolicyWithID`` with PPO on a tissue environment.

    Each step the policy emits one logit per cell, and cells are treated with
    independent Bernoulli(sigmoid(logit)) draws. If more than ``budget`` cells
    are drawn, the ``budget`` cells with the highest probabilities are treated
    instead.

    NOTE(review): when that top-k override fires, the stored log-probability
    is the Bernoulli log-prob of the overridden action, not the probability
    with which it was actually chosen. Old and new log-probs stay consistent,
    so the PPO ratio is well-defined, but the gradient estimate is biased.

    Args:
        env_ctor: zero-argument callable building the environment; it must
            expose ``M``, ``reset()``, ``step(actions)`` and ``_get_obs()``.
        episodes: number of training episodes (one PPO update per episode).
        budget: maximum number of cells treated per step.
        clip_eps: PPO ratio clipping range.
        policy_lr: Adam learning rate.
        gamma: discount factor passed to ``compute_gae_seq``.

    Returns:
        ``(policy, all_rewards)``: the trained policy (on CUDA when available)
        and the total reward of every episode.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    env = env_ctor()
    M = env.M
    init_cells, init_global = env._get_obs()
    per_cell_dim = init_cells.shape[1]   # features per cell
    global_dim = init_global.shape[0]    # tissue-level features
    policy = TissuePolicyWithID(per_cell_dim=per_cell_dim, global_dim=global_dim).to(device)
    optimizer = optim.Adam(policy.parameters(), lr=policy_lr)
    eps = 1e-8
    all_rewards = []
    start_time = time.time()

    for ep in range(episodes):
        per_cell_obs, global_feat = env.reset()
        traj_obs_cells, traj_globals, traj_actions, traj_logps = [], [], [], []
        traj_values, traj_rewards, traj_dones = [], [], []
        ep_reward = 0.0

        while True:
            per_cell_tensor = torch.tensor(per_cell_obs, dtype=torch.float32, device=device)
            global_tensor = torch.tensor(global_feat, dtype=torch.float32, device=device)
            # Acting needs no autograd graph; gradients are recomputed in the PPO update.
            with torch.no_grad():
                logits, value = policy(per_cell_tensor, global_tensor)
            probs = torch.sigmoid(logits).cpu().numpy()
            sampled = (np.random.rand(M) < probs).astype(np.int32)
            # Enforce the treatment budget: fall back to the top-`budget` cells.
            if sampled.sum() > budget:
                actions = np.zeros(M, dtype=np.int32)
                actions[np.argsort(-probs)[:budget]] = 1
            else:
                actions = sampled
            # Joint log-prob under independent per-cell Bernoullis.
            logps = actions * np.log(probs + eps) + (1 - actions) * np.log(1 - probs + eps)
            (next_per_cell_obs, next_global), reward, done, info = env.step(actions.tolist())
            traj_obs_cells.append(per_cell_obs.copy())
            traj_globals.append(global_feat.copy())
            traj_actions.append(actions.copy())
            traj_logps.append(float(logps.sum()))
            traj_values.append(value.item())
            traj_rewards.append(reward)
            traj_dones.append(float(done))
            ep_reward += reward
            per_cell_obs, global_feat = next_per_cell_obs, next_global
            if done:
                break

        # GAE with a zero bootstrap after the final step.
        values_np = np.array(traj_values + [0.0], dtype=np.float32)
        advs, returns = compute_gae_seq(traj_rewards, values_np, traj_dones, gamma=gamma)

        obs_cells_tensor = torch.tensor(np.array(traj_obs_cells), dtype=torch.float32, device=device)
        globals_tensor = torch.tensor(np.array(traj_globals), dtype=torch.float32, device=device)
        actions_tensor = torch.tensor(np.array(traj_actions), dtype=torch.float32, device=device)
        old_logps_tensor = torch.tensor(np.array(traj_logps), dtype=torch.float32, device=device)
        advs_tensor = torch.tensor(advs, dtype=torch.float32, device=device)
        returns_tensor = torch.tensor(returns, dtype=torch.float32, device=device)
        # std() of a single element is NaN, so only normalise with >= 2 steps.
        if advs_tensor.numel() > 1:
            advs_tensor = (advs_tensor - advs_tensor.mean()) / (advs_tensor.std() + 1e-8)

        # PPO update: 6 full-batch epochs over this episode's trajectory.
        for _ in range(6):
            new_logps, values_pred, entropies = [], [], []
            for t in range(len(traj_obs_cells)):
                logits_t, val_t = policy(obs_cells_tensor[t], globals_tensor[t])
                probs_t = torch.sigmoid(logits_t)
                actions_t = actions_tensor[t]
                logp_cells = actions_t * torch.log(probs_t + 1e-8) + (1 - actions_t) * torch.log(1 - probs_t + 1e-8)
                new_logps.append(logp_cells.sum())
                values_pred.append(val_t)
                # Sum of per-cell Bernoulli entropies.
                ent = -(probs_t * torch.log(probs_t + 1e-8) + (1 - probs_t) * torch.log(1 - probs_t + 1e-8)).sum()
                entropies.append(ent)
            new_logps = torch.stack(new_logps)
            values_pred = torch.stack(values_pred).squeeze(-1)
            ratio = torch.exp(new_logps - old_logps_tensor)
            surr1 = ratio * advs_tensor
            surr2 = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advs_tensor
            policy_loss = -torch.min(surr1, surr2).mean()
            value_loss = F.mse_loss(values_pred, returns_tensor)
            entropy_term = torch.stack(entropies).mean()
            loss = policy_loss + 0.5 * value_loss - 0.01 * entropy_term

            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(policy.parameters(), 1.0)
            optimizer.step()

        all_rewards.append(ep_reward)
        if (ep + 1) % 50 == 0:
            avg = np.mean(all_rewards[-50:])
            elapsed = time.time() - start_time
            print(f"Episode {ep+1}/{episodes} | avg tissue reward (last50): {avg:.3f} | elapsed: {elapsed:.1f}s")

    return policy, all_rewards

# -------------------------
# Demo run (short)
# -------------------------
if __name__ == "__main__":
    env_ctor = lambda: TissueEnvWithID(M=12, T=6, budget=3, stem_fraction=0.25, seed=42)
    policy_model, rewards = train_tissue_ppo_with_id(env_ctor, episodes=300, budget=3, policy_lr=3e-4)
    print("Training finished. Mean reward last 50:", np.mean(rewards[-50:]))

    # Show one sample rollout from the trained policy.
    # The trainer places the policy on CUDA when available, so build inputs on its device.
    model_device = next(policy_model.parameters()).device
    env = env_ctor()
    per_cell_obs, global_feat = env.reset()
    rollout = []
    for t in range(env.T):
        per_cell_tensor = torch.tensor(per_cell_obs, dtype=torch.float32, device=model_device)
        global_tensor = torch.tensor(global_feat, dtype=torch.float32, device=model_device)
        with torch.no_grad():
            logits, val = policy_model(per_cell_tensor, global_tensor)
        probs = torch.sigmoid(logits).cpu().numpy()
        sampled = (np.random.rand(env.M) < probs).astype(int)
        # Same budget rule as training: fall back to the top-`budget` cells.
        if sampled.sum() > env.budget:
            actions = np.zeros(env.M, dtype=int)
            actions[np.argsort(-probs)[:env.budget]] = 1
        else:
            actions = sampled
        (per_cell_obs, global_feat), reward, done, info = env.step(actions.tolist())
        rollout.append((actions.copy(), reward, info))
        if done:
            break

    print("Sample rollout (actions per timestep, reward, info):")
    for step in rollout:
        print(step)

Episode 50/300 | avg tissue reward (last50): -10.479 | elapsed: 4.7s
Episode 100/300 | avg tissue reward (last50): -10.239 | elapsed: 9.5s
Episode 150/300 | avg tissue reward (last50): -11.616 | elapsed: 14.0s
Episode 200/300 | avg tissue reward (last50): -12.250 | elapsed: 18.6s
Episode 250/300 | avg tissue reward (last50): -12.084 | elapsed: 22.3s
Episode 300/300 | avg tissue reward (last50): -10.533 | elapsed: 26.1s
Training finished. Mean reward last 50: -10.5326
Sample rollout (actions per timestep, reward, info):
(array([1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0]), 0.0, {})
(array([0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0]), 0.0, {})
(array([0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1]), 0.0, {})
(array([0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0]), 0.0, {})
(array([0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0]), 0.0, {})
(array([0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1]), -10.336, {'per_cell_rewards': [-2.0, -1.0, 1.0, -1.0, -2.0, -2.0, 1.0, -1.0, -1.0, 1.0, -2.0, -1.0], 'cancer_penalty': 0.33599999999999997})


In [None]:
python tissue_with_id.py


In [None]:
pip install numpy
pip install torch


In [None]:
def gumbel_topk_sample(logits, k, tau=0.5, hard=False):
    """
    Gumbel-softmax relaxation of top-k selection over a 1-D logits vector.

    Gumbel noise is added to ``logits`` and a softmax is taken at temperature ``tau``.

    Args:
        logits: 1-D float tensor of per-item scores, shape (M,).
        k: number of items to select when ``hard=True``; values above M select
            every item.
        tau: softmax temperature (> 0); lower values give a peakier soft vector.
        hard: if True, return a straight-through hard mask.

    Returns:
        Tensor of shape (M,).
        - hard=False: the perturbed softmax, which sums to 1. It is a single
          softmax, not a k-hot relaxation, so its mass is not spread over k items.
        - hard=True: the forward value is the 0/1 mask of the k largest perturbed
          entries; gradients flow through the soft vector (straight-through).
    """
    gumbel = -torch.log(-torch.log(torch.rand_like(logits) + 1e-9) + 1e-9)
    y_soft = F.softmax((logits + gumbel) / tau, dim=0)
    if not hard:
        return y_soft
    k = min(int(k), y_soft.shape[0])  # torch.topk raises when k > M
    topk_idx = torch.topk(y_soft, k).indices
    hard_mask = torch.zeros_like(y_soft).scatter_(0, topk_idx, 1.0)
    # Straight-through: forward value is hard_mask, gradient is that of y_soft.
    return (hard_mask - y_soft).detach() + y_soft


In [None]:
# tissue_gumbel_topk.py
# Requirements: torch, numpy
# pip install torch numpy

import time, math, random
from typing import List
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# -------------------------
# Gumbel-Top-k helper (differentiable)
# -------------------------
def sample_gumbel(shape, device='cpu', eps=1e-9):
    """Draw i.i.d. standard Gumbel(0, 1) noise of the given shape.

    Uses the inverse-CDF transform -log(-log(U)) with U ~ Uniform[0, 1);
    ``eps`` keeps both logarithms finite at the edges of the interval.
    """
    uniform = torch.rand(shape, device=device)
    inner = -torch.log(uniform + eps)
    return -torch.log(inner + eps)

def gumbel_topk_soft(logits: torch.Tensor, k: int, tau: float = 0.5, hard: bool = True):
    """
    Gumbel-Top-k selection with a straight-through gradient.

    - logits: (M,) tensor of per-item logits
    - k: how many items to select; values above M select all M items
    - tau: softmax temperature (higher -> softer). For tau > 0 it does not
      change WHICH items are selected, because top-k of (logits + g) / tau
      equals top-k of (logits + g). It only shapes the soft weights and their
      gradients.
    - hard: chooses what the first return value is (see below)

    Returns (out, hard_mask):
      out:
        hard=True  -> straight-through tensor. Its forward value equals
                      hard_mask (sums to min(k, M)); its backward pass uses the
                      gradient of the tempered softmax.
        hard=False -> the tempered softmax over all M items (sums to 1).
      hard_mask: (M,) float 0/1 mask of the selected items; carries no gradient.

    Implementation notes:
      - Gumbel noise is added to the logits, followed by a softmax at temperature tau.
      - The min(k, M) largest soft weights form the hard mask.
    """
    device = logits.device
    M = logits.shape[0]
    k = min(int(k), M)  # torch.topk raises when k > M; a budget above M means "select all"

    # Gumbel noise
    g = sample_gumbel((M,), device=device)
    y = (logits + g) / (tau + 1e-12)
    y_soft = F.softmax(y, dim=0)  # soft selection across all M

    _, topk_idx = torch.topk(y_soft, k)
    hard_mask = torch.zeros_like(y_soft)
    hard_mask[topk_idx] = 1.0

    if hard:
        # Straight-through: hard_mask in the forward pass, gradient of y_soft in the backward pass.
        return (hard_mask - y_soft).detach() + y_soft, hard_mask
    return y_soft, hard_mask

# -------------------------
# Environment with identity (stem vs differentiated)
# -------------------------
class Cell:
    """Toy single-cell state.

    The state holds a telomere-length proxy (``tel``), an ROS level (``ros``),
    the age in steps, and the number of telomerase treatments received
    (``cumulative_tel``). Stem cells start with a higher mean ``tel`` and a
    lower mean ``ros``. They gain more from treatment, shorten more slowly per
    step and have smaller ROS drift noise.
    """

    def __init__(self, rng, is_stem: bool):
        self.rng = rng  # numpy RandomState shared with the environment; all noise is drawn from it
        self.is_stem = bool(is_stem)
        self.reset()

    def reset(self):
        """Draw a fresh initial telomere/ROS state and zero age and treatment count."""
        if self.is_stem:
            self.tel = float(self.rng.normal(loc=0.2, scale=0.25))
            self.ros = float(self.rng.normal(loc=0.15, scale=0.15))
        else:
            self.tel = float(self.rng.normal(loc=0.0, scale=0.3))
            self.ros = float(self.rng.normal(loc=0.25, scale=0.2))
        self.age = 0.0
        self.cumulative_tel = 0.0

    def step(self, action):
        """Advance one time step; ``action == 1`` applies telomerase first."""
        if action == 1:
            gain = 0.6 if self.is_stem else 0.45
            # Diminishing returns per prior treatment, clamped at zero. Without
            # the clamp, more than 12.5 treatments would make the "gain" negative
            # and telomerase would shorten telomeres.
            self.tel += gain * max(0.0, 1.0 - 0.08 * self.cumulative_tel)
            self.cumulative_tel += 1.0
        base_shorten = 0.12 if self.is_stem else 0.18
        self.tel -= base_shorten + 0.04 * self.rng.randn()
        ros_noise = 0.02 if self.is_stem else 0.04
        self.ros += ros_noise * self.rng.randn()
        self.age += 1.0

    def get_obs(self, max_steps):
        """Return float32 [tel, ros, age / (max_steps + 1), cumulative_tel, is_stem]."""
        return np.array([self.tel, self.ros, self.age / (max_steps + 1), self.cumulative_tel, float(self.is_stem)], dtype=np.float32)

class TissueEnvWithID:
    """Tissue of ``M`` cells, a random subset of which are stem cells.

    An episode lasts ``T`` steps. Every step earns reward 0.0 except the final
    one. There, each cell's fate (healthy / senescent / apoptotic) is sampled
    from a hand-written softmax over its telomere and ROS state and scored
    +1 / -1 / -2. A cancer penalty then grows with each cell's cumulative
    telomerase treatments and is weighted more heavily for stem cells.
    """

    def __init__(self, M=12, T=6, budget=3, stem_fraction=0.25, seed=0):
        self.M = M
        self.T = T
        self.budget = budget
        self.rng = np.random.RandomState(seed)
        self.stem_fraction = stem_fraction
        self._make_cells()
        self.step_count = 0

    def _make_cells(self):
        """Re-draw which cells are stem cells (at least one) and build fresh cells."""
        n_stem = max(1, int(round(self.M * self.stem_fraction)))
        stem_ids = set(self.rng.choice(self.M, size=n_stem, replace=False))
        self.cells = [Cell(self.rng, idx in stem_ids) for idx in range(self.M)]

    def reset(self):
        """Start a new episode and return ``(per_cell_obs, global_feat)``."""
        self._make_cells()
        # Each new Cell already drew its state in __init__; this second pass
        # re-draws it, which advances the shared RNG stream.
        for cell in self.cells:
            cell.reset()
        self.step_count = 0
        return self._get_obs()

    def _get_obs(self):
        """Return (M, 5) per-cell features and [mean_tel, mean_ros, stem_frac, step_frac]."""
        per_cell = np.stack([cell.get_obs(self.T) for cell in self.cells], axis=0)
        stem_frac = np.mean([1.0 if cell.is_stem else 0.0 for cell in self.cells])
        global_feat = np.array(
            [per_cell[:, 0].mean(), per_cell[:, 1].mean(), stem_frac, self.step_count / (self.T + 1)],
            dtype=np.float32,
        )
        return per_cell, global_feat

    def step(self, actions: List[int]):
        """Apply one 0/1 action per cell; return ``(obs, reward, done, info)``."""
        assert len(actions) == self.M
        for cell, act in zip(self.cells, actions):
            cell.step(int(act))
        self.step_count += 1
        done = self.step_count >= self.T
        reward = 0.0
        info = {}
        if done:
            fate_reward = (1.0, -1.0, -2.0)  # healthy, senescent, apoptotic
            rewards = []
            cancer_penalty = 0.0
            for cell in self.cells:
                tel_score, ros_score = cell.tel, -cell.ros
                logits = np.array([1.2*tel_score + 0.15*ros_score,
                                   -0.6*tel_score + 0.5*(-ros_score),
                                   -0.9*ros_score + 0.05*(1.0 - tel_score)])
                logits += 0.25 * self.rng.randn(3)
                probs = np.exp(logits - np.max(logits))
                probs = probs / probs.sum()
                outcome = self.rng.choice(3, p=probs)
                rewards.append(fate_reward[outcome])
                risk_weight = 0.5 if cell.is_stem else 0.2
                cancer_penalty += risk_weight * max(0.0, (cell.cumulative_tel - 1.2) * 0.3)
            reward = float(sum(rewards) - cancer_penalty)
            info['per_cell_rewards'] = rewards
            info['cancer_penalty'] = cancer_penalty
        return self._get_obs(), reward, done, info

# -------------------------
# Policy with per-cell identity
# -------------------------
class TissuePolicyWithID(nn.Module):
    """Actor-critic network over a variable number of cells.

    Each cell's features are encoded independently and concatenated with an
    encoding of the tissue-level features to give one logit per cell. The
    critic reads the mean cell encoding plus the global encoding.
    """

    def __init__(self, per_cell_dim=5, global_dim=4, hidden=128):
        super().__init__()
        self.cell_enc = nn.Sequential(
            nn.Linear(per_cell_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden//2),
            nn.ReLU()
        )
        self.global_enc = nn.Sequential(
            nn.Linear(global_dim, hidden//2),
            nn.ReLU()
        )
        self.policy_head = nn.Sequential(
            nn.Linear((hidden//2) + (hidden//2), hidden//2),
            nn.ReLU(),
            nn.Linear(hidden//2, 1)  # logit per cell
        )
        self.value_head = nn.Sequential(
            nn.Linear((hidden//2) + (hidden//2), hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1)
        )

    def forward(self, per_cell_obs, global_feat):
        """Return ``(logits, value)``: logits has shape (M,) and value is a 0-d tensor.

        per_cell_obs is (M, per_cell_dim); global_feat is a 1-D (global_dim,) tensor.
        """
        n_cells = per_cell_obs.shape[0]
        cell_h = self.cell_enc(per_cell_obs)      # (M, h2)
        # Encode the global features once and reuse the result for both heads.
        global_h = self.global_enc(global_feat)   # (h2,)
        combined = torch.cat([cell_h, global_h.unsqueeze(0).expand(n_cells, -1)], dim=-1)
        logits = self.policy_head(combined).squeeze(-1)
        # Critic input: permutation-invariant mean over cells plus global encoding.
        value_input = torch.cat([cell_h.mean(dim=0), global_h], dim=-1)
        value = self.value_head(value_input).squeeze(-1)
        return logits, value

# -------------------------
# PPO utilities
# -------------------------
def compute_gae_seq(rewards, values, dones, gamma=0.99, lam=0.95):
    """Generalized Advantage Estimation over one trajectory.

    Args:
        rewards: per-step rewards, length T.
        values: value estimates; ``values[t + 1]`` is used as the bootstrap for
            step t when present, otherwise 0.0.
        dones: per-step terminal flags (1.0 stops bootstrapping after that step).
        gamma: discount factor.
        lam: GAE smoothing factor.

    Returns:
        (advantages, returns): ``advantages`` is float32 of length T and
        ``returns = advantages + values[:T]``.
    """
    horizon = len(rewards)
    advantages = np.zeros(horizon, dtype=np.float32)
    running = 0.0
    for step in range(horizon - 1, -1, -1):
        mask = 1.0 - dones[step]
        bootstrap = values[step + 1] if step + 1 < len(values) else 0.0
        td_error = rewards[step] + gamma * bootstrap * mask - values[step]
        running = td_error + gamma * lam * mask * running
        advantages[step] = running
    return advantages, advantages + values[:horizon]

# -------------------------
# Training loop with Gumbel-top-k
# -------------------------
def _plackett_luce_logp(scores, order):
    """Log-probability of picking ``order`` (in that order, without replacement) from softmax(scores).

    Taking the top-k of ``scores + Gumbel noise`` samples exactly this
    Plackett-Luce distribution, so the value is a proper log-probability for
    PPO importance ratios.

    Args:
        scores: (M,) tensor of per-item scores.
        order: (k,) integer tensor of distinct item indices, in pick order.

    Returns:
        0-d tensor, differentiable with respect to ``scores``.
    """
    available = torch.ones_like(scores, dtype=torch.bool)
    logp = scores.new_zeros(())
    for idx in order.tolist():
        # Log-softmax of the picked item over the items not yet picked.
        remaining = scores.masked_fill(~available, float('-inf'))
        logp = logp + scores[idx] - torch.logsumexp(remaining, dim=0)
        available[idx] = False
    return logp


def train_tissue_ppo_gumbel(env_ctor, episodes=400, budget=3, clip_eps=0.2,
                            policy_lr=3e-4, gamma=0.99, tau=0.6, device=None):
    """Train a ``TissuePolicyWithID`` with PPO, selecting cells by Gumbel-top-k sampling.

    Each step, ``min(budget, M)`` cells are drawn without replacement from
    softmax(logits / tau) via the Gumbel-top-k trick. The selection order is
    stored so that the PPO update can recompute the exact Plackett-Luce
    log-probability of the SAME selection under the current policy.
    Re-sampling fresh Gumbel noise in the update, or taking the log of a
    straight-through hard mask, makes old and new log-probs incomparable and
    the importance ratio overflows.

    Args:
        env_ctor: zero-argument callable building the environment; it must
            expose ``M``, ``reset()``, ``step(actions)`` and ``_get_obs()``.
        episodes: number of training episodes (one PPO update per episode).
        budget: number of cells treated per step (capped at M).
        clip_eps: PPO ratio clipping range.
        policy_lr: Adam learning rate.
        gamma: discount factor passed to ``compute_gae_seq``.
        tau: sampling temperature; higher values make selection closer to uniform.
        device: torch device; defaults to CUDA when available, else CPU.

    Returns:
        ``(policy, all_rewards)``: the trained policy and the total reward of every episode.
    """
    device = device or (torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu'))
    env = env_ctor()
    M = env.M
    k = min(budget, M)
    init_cells, init_global = env._get_obs()
    policy = TissuePolicyWithID(per_cell_dim=init_cells.shape[1], global_dim=init_global.shape[0]).to(device)
    optimizer = optim.Adam(policy.parameters(), lr=policy_lr)
    temp = tau + 1e-12  # guard against tau == 0, as gumbel_topk_soft does
    all_rewards = []
    start_time = time.time()

    for ep in range(episodes):
        per_cell_obs, global_feat = env.reset()
        traj_obs_cells, traj_globals, traj_orders, traj_logps = [], [], [], []
        traj_values, traj_rewards, traj_dones = [], [], []
        ep_reward = 0.0

        while True:
            per_cell_tensor = torch.tensor(per_cell_obs, dtype=torch.float32, device=device)
            global_tensor = torch.tensor(global_feat, dtype=torch.float32, device=device)
            with torch.no_grad():
                logits, value = policy(per_cell_tensor, global_tensor)
                scores = logits / temp
                # Gumbel-top-k: the k largest perturbed scores are a draw of k
                # cells without replacement; topk returns them sorted, i.e. in pick order.
                perturbed = scores + sample_gumbel(scores.shape, device=device)
                order = torch.topk(perturbed, k).indices
                old_logp = _plackett_luce_logp(scores, order).item()

            actions = np.zeros(M, dtype=int)
            actions[order.cpu().numpy()] = 1
            (next_per_cell_obs, next_global), reward, done, info = env.step(actions.tolist())

            traj_obs_cells.append(per_cell_obs.copy())
            traj_globals.append(global_feat.copy())
            traj_orders.append(order)
            traj_logps.append(old_logp)
            traj_values.append(value.item())
            traj_rewards.append(reward)
            traj_dones.append(float(done))
            ep_reward += reward

            per_cell_obs, global_feat = next_per_cell_obs, next_global
            if done:
                break

        # GAE with a zero bootstrap after the final step.
        values_np = np.array(traj_values + [0.0], dtype=np.float32)
        advs, returns = compute_gae_seq(traj_rewards, values_np, traj_dones, gamma=gamma)

        obs_cells_tensor = torch.tensor(np.array(traj_obs_cells), dtype=torch.float32, device=device)
        globals_tensor = torch.tensor(np.array(traj_globals), dtype=torch.float32, device=device)
        old_logps_tensor = torch.tensor(np.array(traj_logps), dtype=torch.float32, device=device)
        advs_tensor = torch.tensor(advs, dtype=torch.float32, device=device)
        returns_tensor = torch.tensor(returns, dtype=torch.float32, device=device)
        # std() of a single element is NaN, so only normalise with >= 2 steps.
        if advs_tensor.numel() > 1:
            advs_tensor = (advs_tensor - advs_tensor.mean()) / (advs_tensor.std() + 1e-8)

        # PPO update: 6 full-batch epochs over this episode's trajectory.
        for _ in range(6):
            new_logps, values_pred, entropies = [], [], []
            for t, order_t in enumerate(traj_orders):
                logits_t, val_t = policy(obs_cells_tensor[t], globals_tensor[t])
                scores_t = logits_t / temp
                new_logps.append(_plackett_luce_logp(scores_t, order_t))
                values_pred.append(val_t)
                # Exploration bonus: entropy of the first-pick distribution softmax(scores).
                log_p = F.log_softmax(scores_t, dim=0)
                entropies.append(-(log_p.exp() * log_p).sum())
            new_logps = torch.stack(new_logps)
            values_pred = torch.stack(values_pred).squeeze(-1)
            ratio = torch.exp(new_logps - old_logps_tensor)
            surr1 = ratio * advs_tensor
            surr2 = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advs_tensor
            policy_loss = -torch.min(surr1, surr2).mean()
            value_loss = F.mse_loss(values_pred, returns_tensor)
            entropy_term = torch.stack(entropies).mean()
            loss = policy_loss + 0.5 * value_loss - 0.01 * entropy_term

            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(policy.parameters(), 1.0)
            optimizer.step()

        all_rewards.append(ep_reward)

        if (ep + 1) % 50 == 0:
            avg = np.mean(all_rewards[-50:])
            elapsed = time.time() - start_time
            print(f"Episode {ep+1}/{episodes} | avg tissue reward (last50): {avg:.3f} | elapsed: {elapsed:.1f}s")

    return policy, all_rewards

# -------------------------
# Demo run
# -------------------------
if __name__ == "__main__":
    env_ctor = lambda: TissueEnvWithID(M=12, T=6, budget=3, stem_fraction=0.25, seed=42)
    policy_model, rewards = train_tissue_ppo_gumbel(env_ctor, episodes=300, budget=3, policy_lr=3e-4, tau=0.8)
    print("Training finished. Mean reward last 50:", np.mean(rewards[-50:]))

    # Greedy rollout: treat the top-`budget` cells by logit each step (no Gumbel noise).
    # The trained policy may live on CUDA, so build input tensors on its device.
    model_device = next(policy_model.parameters()).device
    env = env_ctor()
    per_cell_obs, global_feat = env.reset()
    rollout = []
    for t in range(env.T):
        per_cell_tensor = torch.tensor(per_cell_obs, dtype=torch.float32, device=model_device)
        global_tensor = torch.tensor(global_feat, dtype=torch.float32, device=model_device)
        with torch.no_grad():
            logits, val = policy_model(per_cell_tensor, global_tensor)
        _, topk_idx = torch.topk(logits, min(env.budget, env.M))
        actions = np.zeros(env.M, dtype=int)
        actions[topk_idx.cpu().numpy()] = 1
        (per_cell_obs, global_feat), reward, done, info = env.step(actions.tolist())
        rollout.append((actions.copy(), reward, info))
        if done:
            break

    print("Sample rollout (actions per timestep, reward, info):")
    for step in rollout:
        print(step)


Episode 50/300 | avg tissue reward (last50): -10.973 | elapsed: 5.9s
Episode 100/300 | avg tissue reward (last50): -11.190 | elapsed: 12.2s
Episode 150/300 | avg tissue reward (last50): -10.649 | elapsed: 16.0s
Episode 200/300 | avg tissue reward (last50): -11.590 | elapsed: 20.1s
Episode 250/300 | avg tissue reward (last50): -10.846 | elapsed: 23.7s
Episode 300/300 | avg tissue reward (last50): -10.004 | elapsed: 27.2s
Training finished. Mean reward last 50: -10.004159999999999
Sample rollout (actions per timestep, reward, info):
(array([0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0]), 0.0, {})
(array([0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0]), 0.0, {})
(array([0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0]), 0.0, {})
(array([0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0]), 0.0, {})
(array([0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0]), 0.0, {})
(array([0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0]), -6.864, {'per_cell_rewards': [-2.0, -1.0, 1.0, -1.0, -2.0, -1.0, 1.0, 1.0, 1.0, 1.0, -2.0, -2.0], 'cancer_penalty': 0.8639999999999999})


In [None]:
pip install torch numpy


In [None]:
python tissue_with_predictor_gumbel_topk.py


In [None]:
# tissue_with_predictor_gumbel_topk.py
# Full pipeline: train/load VirtualCellModel predictor -> TissueEnv that uses predictor ->
# PPO agent with Gumbel-top-k differentiable selection -> training loop.
#
# Requirements: torch, numpy
# pip install torch numpy

import time, math, random
from typing import List, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import os

# -------------------------
# Synthetic dataset to train predictor
# -------------------------
class SyntheticCellDataset(torch.utils.data.Dataset):
    """
    Synthetic per-cell classification data.

    Features (float32, 5 columns): [tel, ros, age / 10, cumulative_tel, is_stem].
    Labels: 0 = healthy, 1 = senescent, 2 = apoptotic, produced by a toy rule
    and then re-randomised for about 3% of samples.
    """
    def __init__(self, n=10000, seed=42):
        self.rng = np.random.RandomState(seed)
        self.n = n
        self._make()

    def _make(self):
        """Draw features and labels (the draw order fixes the RNG stream)."""
        rng, n = self.rng, self.n
        tel = rng.normal(loc=0.0, scale=0.4, size=(n, 1)).astype(np.float32)
        ros = rng.normal(loc=0.2, scale=0.25, size=(n, 1)).astype(np.float32)
        age = rng.randint(0, 8, size=(n, 1)).astype(np.float32)
        cum_tel = rng.randint(0, 4, size=(n, 1)).astype(np.float32)
        is_stem = (rng.rand(n, 1) < 0.25).astype(np.float32)
        features = np.concatenate([tel, ros, age / 10.0, cum_tel, is_stem], axis=1)

        tel_flat, ros_flat, cum_flat = tel[:, 0], ros[:, 0], cum_tel[:, 0]
        labels = np.zeros(n, dtype=np.int64)
        # Apoptosis when ROS is very high ...
        labels[ros_flat > 0.9] = 2
        # ... then senescence for short telomeres without telomerase.
        # NOTE(review): this second assignment overrides apoptosis for samples
        # matching both rules; confirm that precedence is intended.
        labels[(tel_flat < -0.6) & (cum_flat < 1.0)] = 1
        # Label noise: re-draw ~3% of labels uniformly.
        noisy = rng.rand(n) < 0.03
        labels[noisy] = rng.randint(0, 3, noisy.sum())
        self.X = features.astype(np.float32)
        self.y = labels

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        """Return (features row, label as a Python int)."""
        return self.X[idx], int(self.y[idx])

# -------------------------
# VirtualCellModel predictor (supervised)
# -------------------------
class VirtualCellModel(nn.Module):
    """
    Per-cell outcome classifier: an MLP with two hidden layers, each followed
    by LayerNorm and GELU. It maps a feature vector to 3 logits
    (healthy, senescent, apoptotic).
    """
    def __init__(self, in_dim=5, hidden=128):
        super().__init__()
        half = hidden // 2
        # Layer order defines both parameter-init RNG use and state_dict keys (net.0 ... net.6).
        layers = [
            nn.Linear(in_dim, hidden), nn.LayerNorm(hidden), nn.GELU(),
            nn.Linear(hidden, half), nn.LayerNorm(half), nn.GELU(),
            nn.Linear(half, 3),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Map x of shape (batch, in_dim) to unnormalised logits of shape (batch, 3)."""
        return self.net(x)

def train_predictor(save_path='predictor.pt', epochs=8, batch_size=256, lr=1e-3, device=None):
    """Fit a VirtualCellModel on SyntheticCellDataset using an 80/20 train/validation split.

    Args:
        save_path: file that receives the state_dict of the best-validation
            epoch; pass None to skip writing to disk.
        epochs: number of passes over the training split.
        batch_size: mini-batch size for both loaders.
        lr: AdamW learning rate (weight decay 1e-4).
        device: torch device; defaults to CUDA when available, else CPU.

    Returns:
        (model, best_acc): the model with its best-validation-epoch weights
        loaded (not simply the last epoch's), and that validation accuracy.
    """
    device = device or (torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu'))
    ds = SyntheticCellDataset(n=12000)
    train_size = int(len(ds) * 0.8)
    val_size = len(ds) - train_size
    train_ds, val_ds = torch.utils.data.random_split(ds, [train_size, val_size])
    train_loader = torch.utils.data.DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=2)
    val_loader = torch.utils.data.DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=1)

    model = VirtualCellModel(in_dim=5, hidden=128).to(device)
    opt = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    best_acc = 0.0
    best_state = None
    for ep in range(epochs):
        model.train()
        total_loss = 0.0
        for X, y in train_loader:
            X = X.to(device)
            y = y.to(device)
            loss = F.cross_entropy(model(X), y)
            opt.zero_grad()
            loss.backward()
            opt.step()
            total_loss += loss.item() * X.size(0)

        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for X, y in val_loader:
                X = X.to(device)
                y = y.to(device)
                pred = model(X).argmax(dim=1)
                correct += (pred == y).sum().item()
                total += X.size(0)
        acc = correct / total if total else 0.0
        print(f"[Predictor] Epoch {ep+1}/{epochs} loss={total_loss/len(train_loader.dataset):.4f} val_acc={acc:.4f}")
        if best_state is None or acc > best_acc:
            best_acc = acc
            # Clone so later optimizer steps cannot mutate the snapshot.
            best_state = {name: t.detach().clone() for name, t in model.state_dict().items()}
            if save_path is not None:
                torch.save(best_state, save_path)

    # Return the weights that achieved best_acc, matching the reported accuracy.
    if best_state is not None:
        model.load_state_dict(best_state)
    print("[Predictor] Best val acc:", best_acc)
    return model, best_acc

# -------------------------
# Gumbel-Top-k helper (differentiable)
# -------------------------
def sample_gumbel(shape, device='cpu', eps=1e-9):
    """Draw i.i.d. standard Gumbel(0, 1) noise of the given shape.

    Uses the inverse-CDF transform -log(-log(U)) with U ~ Uniform[0, 1);
    ``eps`` keeps both logarithms finite at the edges of the interval.
    """
    uniform = torch.rand(shape, device=device)
    inner = -torch.log(uniform + eps)
    return -torch.log(inner + eps)

def gumbel_topk_soft(logits: torch.Tensor, k: int, tau: float = 0.5, hard: bool = True):
    """
    Gumbel-Top-k selection with a straight-through gradient.

    - logits: (M,) tensor of per-item logits
    - k: how many items to select; values above M select all M items
    - tau: softmax temperature (higher -> softer); for tau > 0 it does not
      change which items are selected, only the soft weights and gradients
    - hard: chooses what the first return value is (see below)

    Returns (out, hard_mask):
      out:
        hard=True  -> straight-through tensor. Its forward value equals
                      hard_mask; its backward pass uses the gradient of the
                      tempered softmax.
        hard=False -> the tempered softmax over all M items (sums to 1).
      hard_mask: (M,) float 0/1 mask with min(k, M) ones; carries no gradient.
    """
    device = logits.device
    M = logits.shape[0]
    k = min(int(k), M)  # torch.topk raises when k > M; a budget above M means "select all"
    g = sample_gumbel((M,), device=device)
    y = (logits + g) / (tau + 1e-12)
    y_soft = F.softmax(y, dim=0)
    _, topk_idx = torch.topk(y_soft, k)
    hard_mask = torch.zeros_like(y_soft)
    hard_mask[topk_idx] = 1.0
    if hard:
        # Straight-through: hard_mask in the forward pass, gradient of y_soft in the backward pass.
        return (hard_mask - y_soft).detach() + y_soft, hard_mask
    return y_soft, hard_mask

# -------------------------
# Environment uses predictor for outcome probabilities
# -------------------------
class Cell:
    """Toy single-cell state.

    The state holds a telomere-length proxy (``tel``), an ROS level (``ros``),
    the age in steps, and the number of telomerase treatments received
    (``cumulative_tel``). Stem cells start with a higher mean ``tel`` and a
    lower mean ``ros``. They gain more from treatment, shorten more slowly per
    step and have smaller ROS drift noise.
    """

    def __init__(self, rng, is_stem: bool):
        self.rng = rng  # numpy RandomState shared with the environment; all noise is drawn from it
        self.is_stem = bool(is_stem)
        self.reset()

    def reset(self):
        """Draw a fresh initial telomere/ROS state and zero age and treatment count."""
        if self.is_stem:
            self.tel = float(self.rng.normal(loc=0.2, scale=0.25))
            self.ros = float(self.rng.normal(loc=0.15, scale=0.15))
        else:
            self.tel = float(self.rng.normal(loc=0.0, scale=0.3))
            self.ros = float(self.rng.normal(loc=0.25, scale=0.2))
        self.age = 0.0
        self.cumulative_tel = 0.0

    def step(self, action):
        """Advance one time step; ``action == 1`` applies telomerase first."""
        if action == 1:
            gain = 0.6 if self.is_stem else 0.45
            # Diminishing returns per prior treatment, clamped at zero. Without
            # the clamp, more than 12.5 treatments would make the "gain" negative
            # and telomerase would shorten telomeres.
            self.tel += gain * max(0.0, 1.0 - 0.08 * self.cumulative_tel)
            self.cumulative_tel += 1.0
        base_shorten = 0.12 if self.is_stem else 0.18
        self.tel -= base_shorten + 0.04 * self.rng.randn()
        ros_noise = 0.02 if self.is_stem else 0.04
        self.ros += ros_noise * self.rng.randn()
        self.age += 1.0

    def get_obs(self, max_steps):
        """Return float32 [tel, ros, age / (max_steps + 1), cumulative_tel, is_stem]."""
        return np.array([self.tel, self.ros, self.age / (max_steps + 1), self.cumulative_tel, float(self.is_stem)], dtype=np.float32)

class TissueEnvWithPredictor:
    """
    Tissue environment whose terminal cell fates come from a trained predictor.

    The predictor maps each cell's (M, 5) features to 3 logits. On the final
    step these are softmaxed, one fate per cell is sampled, and each fate is
    scored +1 (outcome 0), -1 (outcome 1) or -2 (outcome 2). The cancer
    penalty is then subtracted. Every other step earns reward 0.0.
    """
    def __init__(self, predictor: VirtualCellModel, device=None, M=12, T=6, budget=3, stem_fraction=0.25, seed=0):
        self.predictor = predictor
        self.device = device or (torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu'))
        # The predictor is only used for inference inside this environment.
        self.predictor.to(self.device)
        self.predictor.eval()
        self.M = M
        self.T = T
        self.budget = budget
        self.rng = np.random.RandomState(seed)
        self.stem_fraction = stem_fraction
        self._make_cells()
        self.step_count = 0

    def _make_cells(self):
        """Re-draw which cells are stem cells (at least one) and build fresh cells."""
        n_stem = max(1, int(round(self.M * self.stem_fraction)))
        stem_ids = set(self.rng.choice(self.M, size=n_stem, replace=False))
        self.cells = [Cell(self.rng, idx in stem_ids) for idx in range(self.M)]

    def reset(self):
        """Start a new episode and return ``(per_cell_obs, global_feat)``."""
        self._make_cells()
        # Each new Cell already drew its state in __init__; this second pass
        # re-draws it, which advances the shared RNG stream.
        for cell in self.cells:
            cell.reset()
        self.step_count = 0
        return self._get_obs()

    def _get_obs(self):
        """Return (M, 5) per-cell features and [mean_tel, mean_ros, stem_frac, step_frac]."""
        per_cell = np.stack([cell.get_obs(self.T) for cell in self.cells], axis=0)
        stem_frac = np.mean([1.0 if cell.is_stem else 0.0 for cell in self.cells])
        global_feat = np.array(
            [per_cell[:, 0].mean(), per_cell[:, 1].mean(), stem_frac, self.step_count / (self.T + 1)],
            dtype=np.float32,
        )
        return per_cell, global_feat

    def step(self, actions: List[int]):
        """Apply one 0/1 action per cell; return ``(obs, reward, done, info)``."""
        assert len(actions) == self.M
        for cell, act in zip(self.cells, actions):
            cell.step(int(act))
        self.step_count += 1
        done = self.step_count >= self.T
        reward = 0.0
        info = {}
        if done:
            feats = np.stack([cell.get_obs(self.T) for cell in self.cells], axis=0).astype(np.float32)
            with torch.no_grad():
                logits = self.predictor(torch.tensor(feats, dtype=torch.float32).to(self.device))  # (M, 3)
                probs = F.softmax(logits, dim=-1).cpu().numpy()
            fate_reward = (1.0, -1.0, -2.0)
            rewards = []
            cancer_penalty = 0.0
            for cell, p in zip(self.cells, probs):
                # Renormalise so rng.choice accepts float32 rounding.
                p = p / p.sum()
                outcome = self.rng.choice(3, p=p)
                rewards.append(fate_reward[outcome])
                risk_weight = 0.5 if cell.is_stem else 0.2
                cancer_penalty += risk_weight * max(0.0, (cell.cumulative_tel - 1.2) * 0.3)
            reward = float(sum(rewards) - cancer_penalty)
            info['per_cell_rewards'] = rewards
            info['cancer_penalty'] = cancer_penalty
        return self._get_obs(), reward, done, info

# -------------------------
# Policy (same as before) - per-cell identity supported
# -------------------------
class TissuePolicyWithID(nn.Module):
    """Actor-critic network that scores every cell and values the whole tissue.

    Each cell's feature row is encoded independently by ``cell_enc``. The
    tissue-level feature vector is encoded once by ``global_enc`` and shared by
    the per-cell policy head and the value head.

    Args:
        per_cell_dim: number of features per cell (columns of ``per_cell_obs``).
        global_dim: length of the 1-D tissue-level feature vector.
        hidden: width of the first cell-encoder layer; cell and global
            embeddings are ``hidden // 2`` wide.
    """

    def __init__(self, per_cell_dim=5, global_dim=4, hidden=128):
        super().__init__()
        half = hidden // 2
        # Submodules are registered in the original order so state_dict keys
        # and parameter initialisation order are unchanged.
        self.cell_enc = nn.Sequential(
            nn.Linear(per_cell_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, half),
            nn.ReLU(),
        )
        self.global_enc = nn.Sequential(
            nn.Linear(global_dim, half),
            nn.ReLU(),
        )
        # [cell embedding | global embedding] -> one selection logit per cell.
        self.policy_head = nn.Sequential(
            nn.Linear(half + half, half),
            nn.ReLU(),
            nn.Linear(half, 1),
        )
        # [mean cell embedding | global embedding] -> scalar state value.
        self.value_head = nn.Sequential(
            nn.Linear(half + half, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
        )

    def forward(self, per_cell_obs, global_feat):
        """Return per-cell selection logits and a scalar state value.

        Args:
            per_cell_obs: ``(M, per_cell_dim)`` tensor, one row per cell.
            global_feat: ``(global_dim,)`` tensor of tissue-level features.

        Returns:
            ``(logits, value)`` where ``logits`` has shape ``(M,)`` and
            ``value`` is a 0-dim tensor.
        """
        n_cells = per_cell_obs.shape[0]
        cell_h = self.cell_enc(per_cell_obs)      # (M, hidden//2)
        global_h = self.global_enc(global_feat)   # (hidden//2,), computed once and reused
        combined = torch.cat([cell_h, global_h.unsqueeze(0).expand(n_cells, -1)], dim=-1)
        logits = self.policy_head(combined).squeeze(-1)
        value_input = torch.cat([cell_h.mean(dim=0), global_h], dim=-1)
        value = self.value_head(value_input).squeeze(-1)
        return logits, value

# -------------------------
# PPO train loop integrating predictor and Gumbel-top-k
# -------------------------
def compute_gae_seq(rewards, values, dones, gamma=0.99, lam=0.95):
    """Generalised Advantage Estimation over one trajectory.

    Args:
        rewards: per-step rewards, length T.
        values: value estimates, at least T entries; ``values[t + 1]`` is used
            as the bootstrap for step t when it exists, otherwise 0.0.
        dones: per-step terminal flags (1.0 stops bootstrapping and the
            advantage recursion across that step).
        gamma: discount factor.
        lam: GAE lambda.

    Returns:
        ``(advantages, returns)``: float32 advantages of length T and
        ``advantages + values[:T]``.
    """
    n_steps = len(rewards)
    n_values = len(values)
    advantages = np.zeros(n_steps, dtype=np.float32)
    running = 0.0
    for t in range(n_steps - 1, -1, -1):
        not_done = 1.0 - dones[t]
        bootstrap = values[t + 1] if t + 1 < n_values else 0.0
        td_error = rewards[t] + gamma * bootstrap * not_done - values[t]
        running = td_error + gamma * lam * not_done * running
        advantages[t] = running
    return advantages, advantages + values[:n_steps]

def _gumbel_topk_set_log_prob(logits, selected, max_exact_k=6):
    """Log-probability that Gumbel-top-k sampling on ``logits`` selects exactly ``selected``.

    Gumbel-top-k is equivalent to drawing k items without replacement from
    softmax(logits) (Plackett-Luce), so the probability of an unordered set is
    the sum over its k! orderings of the sequential-draw probabilities. The
    result is a deterministic function of (logits, selection), which is what a
    PPO importance ratio needs.

    Args:
        logits: 1-D tensor of per-cell logits; gradients flow through it.
        selected: 1-D tensor/mask of the same length; entries > 0.5 are chosen.
        max_exact_k: above this set size the k! enumeration is replaced by the
            single ordering by descending logit (a lower bound on the set
            log-probability) to keep the cost bounded.

    Returns:
        0-dim tensor (0.0 when nothing is selected).
    """
    import itertools

    chosen = torch.nonzero(selected > 0.5, as_tuple=True)[0].tolist()
    if not chosen:
        return logits.new_zeros(())
    if len(chosen) > max_exact_k:
        orderings = [sorted(chosen, key=lambda j: float(logits[j]), reverse=True)]
    else:
        orderings = itertools.permutations(chosen)

    per_order = []
    for order in orderings:
        log_p = logits.new_zeros(())
        drawn = torch.zeros_like(logits, dtype=torch.bool)
        for j in order:
            # Normalise over the cells not yet drawn in this ordering.
            log_p = log_p + logits[j] - torch.logsumexp(logits.masked_fill(drawn, float('-inf')), dim=-1)
            # Fresh mask each draw: masked_fill keeps the old one for backward.
            drawn = drawn.clone()
            drawn[j] = True
        per_order.append(log_p)
    return torch.logsumexp(torch.stack(per_order), dim=0)


def _predictor_ppo_rollout(env, policy, budget, tau, device):
    """Play one episode with the current policy (no gradients) and record it.

    Returns:
        ``(traj, ep_reward)`` where ``traj`` maps ``obs_cells``, ``globals``,
        ``actions``, ``logps``, ``values``, ``rewards`` and ``dones`` to
        per-step lists, and ``ep_reward`` is the summed episode reward.
    """
    traj = {key: [] for key in ('obs_cells', 'globals', 'actions', 'logps', 'values', 'rewards', 'dones')}
    ep_reward = 0.0
    per_cell_obs, global_feat = env.reset()
    while True:
        with torch.no_grad():
            per_cell_tensor = torch.tensor(per_cell_obs, dtype=torch.float32).to(device)
            global_tensor = torch.tensor(global_feat, dtype=torch.float32).to(device)
            logits, value = policy(per_cell_tensor, global_tensor)  # logits (M,), value scalar
            _, hard_mask = gumbel_topk_soft(logits, k=budget, tau=tau, hard=True)
            # Threshold rather than int-cast: if the hard mask is a straight-through
            # sum (hard - soft + soft) a selected entry may be 0.99999994.
            chosen = (hard_mask > 0.5).to(torch.float32)
            logp = _gumbel_topk_set_log_prob(logits, chosen)

        actions = chosen.long().cpu().tolist()
        (next_per_cell_obs, next_global), reward, done, _info = env.step(actions)

        traj['obs_cells'].append(per_cell_obs.copy())
        traj['globals'].append(global_feat.copy())
        traj['actions'].append(actions)
        traj['logps'].append(logp.item())
        traj['values'].append(value.item())
        traj['rewards'].append(reward)
        traj['dones'].append(float(done))
        ep_reward += reward

        per_cell_obs, global_feat = next_per_cell_obs, next_global
        if done:
            break
    return traj, ep_reward


def _predictor_ppo_update(policy, optimizer, traj, device, clip_eps, gamma, lam, ppo_epochs,
                          value_coef, entropy_coef, max_grad_norm):
    """Apply ``ppo_epochs`` clipped-surrogate PPO steps to one recorded episode."""
    values_np = np.array(traj['values'] + [0.0], dtype=np.float32)
    advs, returns = compute_gae_seq(traj['rewards'], values_np, traj['dones'], gamma=gamma, lam=lam)

    def as_tensor(x):
        return torch.tensor(np.array(x), dtype=torch.float32).to(device)

    obs_cells = as_tensor(traj['obs_cells'])
    obs_globals = as_tensor(traj['globals'])
    actions = as_tensor(traj['actions'])
    old_logps = as_tensor(traj['logps'])
    returns_t = as_tensor(returns)
    advs_t = as_tensor(advs)
    # std() of a single element is NaN; only normalise with two or more steps.
    if advs_t.numel() > 1:
        advs_t = (advs_t - advs_t.mean()) / (advs_t.std() + 1e-8)

    for _ in range(ppo_epochs):
        new_logps, values_pred, entropies = [], [], []
        for t in range(obs_cells.shape[0]):
            logits_t, val_t = policy(obs_cells[t], obs_globals[t])
            # Same deterministic set log-prob as the rollout, so the ratio is 1 at epoch 0.
            new_logps.append(_gumbel_topk_set_log_prob(logits_t, actions[t]))
            values_pred.append(val_t.reshape(()))
            # Entropy of softmax over cell logits as a deterministic exploration bonus.
            log_p = F.log_softmax(logits_t, dim=-1)
            entropies.append(-(log_p.exp() * log_p).sum())

        ratio = torch.exp(torch.stack(new_logps) - old_logps)
        surr1 = ratio * advs_t
        surr2 = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advs_t
        policy_loss = -torch.min(surr1, surr2).mean()
        value_loss = F.mse_loss(torch.stack(values_pred), returns_t)
        entropy_term = torch.stack(entropies).mean()
        loss = policy_loss + value_coef * value_loss - entropy_coef * entropy_term

        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(policy.parameters(), max_grad_norm)
        optimizer.step()


def train_tissue_ppo_with_predictor(env_ctor, episodes=400, budget=3, clip_eps=0.2,
                                    policy_lr=3e-4, gamma=0.99, tau=0.6, device=None,
                                    lam=0.95, ppo_epochs=6, value_coef=0.5,
                                    entropy_coef=0.01, max_grad_norm=1.0, log_every=50):
    """Train a ``TissuePolicyWithID`` with PPO on a predictor-driven tissue env.

    Each step, ``budget`` cells are selected with ``gumbel_topk_soft`` (hard
    mask). The PPO importance ratio uses the exact Gumbel-top-k set
    log-probability of that selection, computed from the logits, so old and
    new log-probs refer to the same action rather than to two independent
    noise draws. This assumes the hard mask is the top-``budget`` of
    Gumbel-perturbed logits (TODO confirm against ``gumbel_topk_soft``).

    Args:
        env_ctor: zero-argument callable returning the environment.
        episodes: number of training episodes.
        budget: cells selected per step (k of Gumbel-top-k).
        clip_eps: PPO clipping range.
        policy_lr: Adam learning rate.
        gamma: discount factor.
        tau: Gumbel-softmax temperature passed to ``gumbel_topk_soft``.
        device: torch device; defaults to CUDA when available, else CPU.
        lam: GAE lambda.
        ppo_epochs: gradient steps per collected episode.
        value_coef: weight of the value-function MSE loss.
        entropy_coef: weight of the entropy bonus.
        max_grad_norm: gradient-norm clipping threshold.
        log_every: print a progress line every this many episodes (0 disables).

    Returns:
        ``(policy, all_rewards)``: the trained policy and per-episode rewards.
    """
    device = device or (torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu'))
    env = env_ctor()
    init_cells, init_global = env._get_obs()
    policy = TissuePolicyWithID(per_cell_dim=init_cells.shape[1], global_dim=init_global.shape[0]).to(device)
    optimizer = optim.Adam(policy.parameters(), lr=policy_lr)
    all_rewards = []
    start_time = time.time()

    for ep in range(episodes):
        traj, ep_reward = _predictor_ppo_rollout(env, policy, budget, tau, device)
        _predictor_ppo_update(policy, optimizer, traj, device, clip_eps, gamma, lam, ppo_epochs,
                              value_coef, entropy_coef, max_grad_norm)
        all_rewards.append(ep_reward)

        if log_every and (ep + 1) % log_every == 0:
            avg = np.mean(all_rewards[-log_every:])
            elapsed = time.time() - start_time
            print(f"[PPO] Episode {ep+1}/{episodes} | avg reward (last{log_every}): {avg:.3f} | elapsed: {elapsed:.1f}s")

    return policy, all_rewards

# -------------------------
# Main: train predictor (or load), then train PPO using predictor-driven env
# -------------------------
if __name__ == "__main__":
    # Entry point: obtain a fate predictor (train or load), then train the PPO
    # tissue policy in an environment that samples outcomes from that predictor.
    DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    PREDICTOR_PATH = "predictor.pt"

    # 1) Train predictor if no saved weights exist; otherwise load them.
    if not os.path.exists(PREDICTOR_PATH):
        print("[Main] Training predictor from synthetic data...")
        predictor, acc = train_predictor(save_path=PREDICTOR_PATH, epochs=10, batch_size=256, lr=1e-3, device=DEVICE)
    else:
        print(f"[Main] Loading predictor weights from {PREDICTOR_PATH}")
        predictor = VirtualCellModel(in_dim=5, hidden=128).to(DEVICE)
        predictor.load_state_dict(torch.load(PREDICTOR_PATH, map_location=DEVICE))
    # The environment only runs inference, so put the predictor in eval mode on
    # both paths; any dropout/batch-norm layers then behave the same either way.
    predictor.eval()

    # 2) Create env that uses predictor
    env_ctor = lambda: TissueEnvWithPredictor(predictor=predictor, device=DEVICE, M=12, T=6, budget=3, stem_fraction=0.25, seed=42)

    # 3) Train PPO agent with Gumbel-top-k selection
    print("[Main] Training PPO agent using predictor-driven environment...")
    policy_model, rewards = train_tissue_ppo_with_predictor(env_ctor, episodes=300, budget=3, policy_lr=3e-4, tau=0.8, device=DEVICE)
    print("[Main] PPO training finished. Mean reward (last 50):", np.mean(rewards[-50:]))

    # Save policy if desired
    torch.save(policy_model.state_dict(), "policy_model.pt")
    print("[Main] Saved policy_model.pt")


[Main] Training predictor from synthetic data...
[Predictor] Epoch 1/10 loss=0.3015 val_acc=0.9646
[Predictor] Epoch 2/10 loss=0.1493 val_acc=0.9675
[Predictor] Epoch 3/10 loss=0.1359 val_acc=0.9712
[Predictor] Epoch 4/10 loss=0.1250 val_acc=0.9746
[Predictor] Epoch 5/10 loss=0.1202 val_acc=0.9779
[Predictor] Epoch 6/10 loss=0.1179 val_acc=0.9729
[Predictor] Epoch 7/10 loss=0.1170 val_acc=0.9746
[Predictor] Epoch 8/10 loss=0.1165 val_acc=0.9742
[Predictor] Epoch 9/10 loss=0.1169 val_acc=0.9750
[Predictor] Epoch 10/10 loss=0.1153 val_acc=0.9750
[Predictor] Best val acc: 0.9779166666666667
[Main] Training PPO agent using predictor-driven environment...
[PPO] Episode 50/300 | avg reward (last50): -3.357 | elapsed: 3.5s
[PPO] Episode 100/300 | avg reward (last50): -3.584 | elapsed: 7.7s
[PPO] Episode 150/300 | avg reward (last50): -3.758 | elapsed: 11.3s
[PPO] Episode 200/300 | avg reward (last50): -3.421 | elapsed: 14.8s
[PPO] Episode 250/300 | avg reward (last50): -3.730 | elapsed: 18.9s