# Import packages

In [1]:
from pathlib import Path
import os, sys
# Search the current working directory and then each of its ancestors for a
# directory named 'Multifirefly-Project'.
for p in [Path.cwd()] + list(Path.cwd().parents):
    if p.name == 'Multifirefly-Project':
        # Make the project root the working directory and put the methods folder
        # first on sys.path (presumably so the project-local imports below resolve).
        os.chdir(p)
        sys.path.insert(0, str(p / 'multiff_analysis/multiff_code/methods'))
        break
# NOTE(review): if no directory with that name is found, the loop ends silently
# and neither the working directory nor sys.path is changed.

%load_ext autoreload
%autoreload 2


from data_wrangling import specific_utils, combine_info_utils
from pattern_discovery import pattern_by_trials, pattern_by_trials, cluster_analysis, organize_patterns_and_features, category_class
from decision_making_analysis.cluster_replacement import cluster_replacement_utils
from decision_making_analysis.decision_making import decision_making_class, decision_making_utils, intended_targets_classes
from decision_making_analysis.GUAT import GUAT_collect_info_class, GUAT_combine_info_class
from decision_making_analysis.compare_GUAT_and_TAFT import GUAT_vs_TAFT_class, GUAT_vs_TAFT_x_sessions_class, helper_GUAT_vs_TAFT_class
from visualization.matplotlib_tools import plot_trials, plot_behaviors_utils
from visualization.animation import animation_class
from null_behaviors import show_null_trajectory, find_best_arc, curvature_utils, curv_of_traj_utils
from machine_learning.ml_methods import regression_utils, classification_utils, prep_ml_data_utils, hyperparam_tuning_class
from visualization.plotly_polar_tools import plotly_utils_polar, plotly_for_ff_polar, plotly_for_trajectory_polar
from machine_learning.ml_methods import ml_methods_class
from visualization.dash_tools.dash_main_class_methods import dash_applied_to_GUAT_TAFT
from decision_making_analysis.advanced_modeling import model_choices

import os, sys
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from os.path import exists
import math
import copy
import matplotlib.pyplot as plt
import pandas as pd
import itertools
import matplotlib.pyplot as plt
import gc
from scipy import stats
from IPython.display import HTML
from matplotlib import rc
from sklearn.svm import SVC
from sklearn.ensemble import AdaBoostClassifier, BaggingClassifier
from sklearn.neural_network import MLPClassifier, MLPRegressor
from sklearn.linear_model import LinearRegression
from sklearn.neighbors import KNeighborsRegressor
from sklearn.naive_bayes import GaussianNB
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
import warnings
import os, sys, sys
from importlib import reload
from sklearn.exceptions import ConvergenceWarning


# Start from matplotlib's defaults, THEN apply the notebook-specific overrides.
# Previously the reset ran after the animation settings and silently discarded them.
matplotlib.rcParams.update(matplotlib.rcParamsDefault)
# Render animations as interactive JS/HTML. The earlier `animation.html = "html5"`
# assignment was immediately overridden by this call, so it has been dropped.
rc('animation', html='jshtml')
# Effectively remove the size cap (in MB) on animations embedded in the notebook.
matplotlib.rcParams['animation.embed_limit'] = 2**128
# NOTE(review): presumably works around the duplicate-OpenMP-runtime abort — confirm it is still needed.
os.environ['KMP_DUPLICATE_LIB_OK']='True'
# Show floats with 5 decimals in DataFrames, no scientific notation in numpy, at most 50 rows.
pd.set_option('display.float_format', lambda x: '%.5f' % x)
np.set_printoptions(suppress=True)
pd.options.display.max_rows = 50



Set up logging configuration.


# try now

In [None]:
from decision_making_analysis.advanced_modeling.choice_model import *

# try

In [None]:
# Source file: multiff_code/methods/decision_making_analysis/advanced_modeling/model_choices.py

In [2]:
import math
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader



In [4]:
# Smoke test: exercise both training loops from the `model_choices` module on
# purely random data, to check that the code runs end to end.
torch.manual_seed(0)

# 3A) Success head on random data
# Labels are coin flips drawn independently of Xs, so the loss is expected to
# stay close to ln(2) ~= 0.693.
Xs = torch.randn(512, 10)
ys = torch.bernoulli(torch.full((512,), 0.5))
ds = model_choices.SuccessDataset(Xs, ys)
dl = DataLoader(ds, batch_size=64, shuffle=True)
succ = model_choices.SuccessPredictor(in_dim=10)
model_choices.train_success_head(succ, dl, epochs=2)

# 3B) Choice model with masking
# Build 256 episodes with a random number of options each; the chosen option is
# also random, so nothing is learnable here either.
items = []
for _ in range(256):
    K = torch.randint(low=2, high=6, size=(1,)).item()  # 2..5 options
    option_feats = torch.randn(K, 8)
    p_succ_opt = torch.sigmoid(torch.randn(K))
    extra_costs = torch.randn(K, 2)
    chosen = torch.randint(low=0, high=K, size=(1,)).item()
    items.append({
        "option_features": option_feats,
        "p_succ_opt": p_succ_opt,
        "extra_costs": extra_costs,
        "option_mask": torch.ones(K, dtype=torch.bool),
        "chosen_index": torch.tensor(chosen),
    })
dc = model_choices.ChoiceDataset(items)
# choice_collate is expected to pad the variable-K episodes into one batch — TODO confirm against the module.
dlc = DataLoader(dc, batch_size=32, shuffle=True, collate_fn=model_choices.choice_collate)
chooser = model_choices.ChoiceScorer(option_dim=8, extra_cost_dim=2, use_psucc=True)
model_choices.train_choice_model(chooser, dlc, epochs=2)

print("Smoke test complete.")

[Success] epoch 1: loss=0.6924
[Success] epoch 2: loss=0.6911
[Choice] epoch 1: loss=1.1762
[Choice] epoch 2: loss=1.1666
Smoke test complete.


In [5]:
# Display the module repr to confirm which file `model_choices` was imported from.
model_choices

<module 'decision_making_analysis.advanced_modeling.model_choices' from '/Users/dusiyi/Documents/Multifirefly-Project/multiff_analysis/multiff_code/methods/decision_making_analysis/advanced_modeling/model_choices.py'>

# Try 2: hand-written toy data for the success head (3A) and the choice model (3B)

In [7]:
# ---- 3A: success head toy data ----
import torch
from torch.utils.data import DataLoader

# Four hand-written stops with three features each, plus a binary label per stop.
X_stop = torch.tensor([
    [ 0.2, -0.3, 0.1],   # stop features (D_s = 3)
    [ 1.0,  0.4, 0.5],
    [-0.7,  0.2, 0.9],
    [ 0.3,  0.1, 0.2],
], dtype=torch.float32)
y_inside = torch.tensor([1., 1., 0., 1.], dtype=torch.float32)

# NOTE: SuccessDataset, SuccessPredictor and train_success_head are used unqualified,
# so the cell that defines them must have been run first.
ds = SuccessDataset(X_stop, y_inside)
dl = DataLoader(ds, batch_size=2, shuffle=True)

model = SuccessPredictor(in_dim=3)
train_success_head(model, dl, epochs=3, lr=1e-3)

# Print the predicted success probability for each of the four training stops.
with torch.no_grad():
    print("p_hat_succ:", model(X_stop).tolist())


[Success] epoch 1: loss=0.7340
[Success] epoch 2: loss=0.7292
[Success] epoch 3: loss=0.7245
p_hat_succ: [0.46155235171318054, 0.4903356432914734, 0.45318105816841125, 0.45039984583854675]


In [8]:
# ---- 3B: choice model toy batch (variable K) ----
# Two hand-written episodes with different option counts (K=3 and K=2).
items = [
    {   # episode 1: K=3 options (e.g., {retry, ff#7, ff#12})
        "option_features": torch.tensor([[0.1, 0.2], [-0.3, 0.5], [0.4, -0.1]], dtype=torch.float32),  # [3, D_o=2]
        "p_succ_opt":     torch.tensor([0.70, 0.25, 0.55], dtype=torch.float32),                       # [3]
        "extra_costs":    torch.tensor([[0.9, 0.1], [0.2, 0.7], [0.6, 0.3]], dtype=torch.float32),     # [3, D_c=2]
        "chosen_index":   torch.tensor(0),  # picked the first option
    },
    {   # episode 2: K=2 options
        "option_features": torch.tensor([[0.0, -0.2], [0.8, 0.1]], dtype=torch.float32),               # [2, 2]
        "p_succ_opt":     torch.tensor([0.40, 0.65], dtype=torch.float32),                              # [2]
        "extra_costs":    torch.tensor([[0.3, 0.4], [0.1, 0.2]], dtype=torch.float32),                 # [2, 2]
        "chosen_index":   torch.tensor(1),
    },
]

# batch_size=2 puts both episodes in one batch, so choice_collate has to pad K=2 up to K=3.
loader = DataLoader(ChoiceDataset(items), batch_size=2, shuffle=True, collate_fn=choice_collate)
chooser = ChoiceScorer(option_dim=2, extra_cost_dim=2, use_psucc=True)

# Three passes of hand-rolled gradient descent (learning rate 1e-2) instead of an optimizer.
for _ in range(3):
    for batch in loader:
        logits = chooser(batch["option_features"], batch["mask"], batch["p_succ_opt"], batch["extra_costs"])
        loss = masked_cross_entropy(logits, batch["targets"], batch["mask"])
        chooser.zero_grad(); loss.backward()
        for p in chooser.parameters(): 
            if p.grad is not None:
                p.data -= 1e-2 * p.grad   # tiny manual SGD step
    # Reports the loss of the last batch of this pass (computed before that batch's update).
    print("train loss:", float(loss))


train loss: 0.875237226486206
train loss: 0.8737525939941406
train loss: 0.8722797632217407


# Try 3: end-to-end run on synthetic data (success head, then choice model)

## run

In [13]:
# End-to-end synthetic run. The helpers used below are called unqualified, so the
# cell(s) defining them must have been executed before this one.
torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
# 1) Train success head on synthetic stops
# in_dim=6 matches the six feature columns produced by generate_synthetic_success_data.
Xtr, ytr, Xva, yva = generate_synthetic_success_data(N=5000)
succ = SuccessPredictor(in_dim=6)
dl_tr = DataLoader(SuccessDataset(Xtr, ytr), batch_size=128, shuffle=True)
train_success_head(succ, dl_tr, epochs=5, lr=1e-3, device=device)
evaluate_success_head(succ.to(device), Xva.to(device), yva.to(device))

# 2) Build a synthetic choice dataset that uses p_succ from the success head
items = build_choice_items(succ, N=3000)
# Sequential 85/15 train/validation split of the generated episodes (no shuffling before the split).
ntr = int(0.85 * len(items))
dc_tr = ChoiceDataset(items[:ntr])
dc_va = ChoiceDataset(items[ntr:])
dl_tr_c = DataLoader(dc_tr, batch_size=64, shuffle=True, collate_fn=choice_collate)
dl_va_c = DataLoader(dc_va, batch_size=128, shuffle=False, collate_fn=choice_collate)

chooser = ChoiceScorer(option_dim=2, extra_cost_dim=2, use_psucc=True)
train_choice_model(chooser, dl_tr_c, epochs=5, lr=1e-3, device=device)
acc = eval_choice_top1(chooser.to(device), dl_va_c, device=device)
print(f"[Choice/Val] top-1 accuracy={acc:.3f}")

Using device: cpu
[Success] epoch 1: loss=0.5714
[Success] epoch 2: loss=0.4938
[Success] epoch 3: loss=0.4550
[Success] epoch 4: loss=0.4413
[Success] epoch 5: loss=0.4312
[Success/Val] acc=0.812  corr(p, -miss)=0.38  corr(p, bright)=-0.01  corr(p, align)=0.88
[Choice] epoch 1: loss=1.0445
[Choice] epoch 2: loss=0.8030
[Choice] epoch 3: loss=0.7667
[Choice] epoch 4: loss=0.7603
[Choice] epoch 5: loss=0.7538
[Choice/Val] top-1 accuracy=0.704


## functions

In [None]:
"""
Tiny PyTorch skeleton for MultiFF retry/switch modeling.

Implements the two core training steps and leaves hooks for the optional
variants. The code is intentionally lightweight: small MLPs, masking for
variable option counts, and minimal training loops.

You can paste this into a file and adapt the Dataset stubs to your
preprocessing.
"""
from __future__ import annotations

import math
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader


# -------------------------------
# Utilities
# -------------------------------


def masked_cross_entropy(logits: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """
    Mean cross-entropy over a batch whose rows may hold different numbers of options.

    logits:  [B, Kmax] raw scores, padded up to the largest option count
    targets: [B] index of the chosen option in each row
    mask:    [B, Kmax] boolean, True for real options and False for padding

    Returns the scalar F.cross_entropy (default mean reduction) of the masked scores.
    """
    # Padding slots receive a large negative score (-1e9, a finite stand-in for -inf)
    # so they contribute essentially nothing to the softmax normaliser.
    padding_score = torch.full_like(logits, -1e9)
    scores = torch.where(mask, logits, padding_score)
    return F.cross_entropy(scores, targets)


# -------------------------------
# 3B) Choice scorer over {retry} ∪ {other targets}
# -------------------------------



# -------------------------------
# Optional variants (4)
# -------------------------------

class RetrySwitchHead(nn.Module):
    """Binary head for the retry-vs-switch decision after a near-miss; outputs P(retry)."""
    def __init__(self, in_dim: int, hidden: List[int] = [32]):
        super().__init__()
        # Single-logit network built by the shared `mlp` helper.
        self.net = mlp(in_dim, hidden, 1)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """Return P(retry) for each feature vector, with the trailing singleton dim removed."""
        retry_logit = self.net(features)
        p_retry = torch.sigmoid(retry_logit)
        return p_retry.squeeze(-1)


class TwoStagePolicy(nn.Module):
    """
    Two-stage decision model.

    Stage 1 (`retry_head`): probability of retrying, from near-miss features.
    Stage 2 (`choice`): logits over the candidate targets, from a ChoiceScorer.
    Both stages are always evaluated; combining them is left to the caller.
    """
    def __init__(self, retry_in_dim: int, option_dim: int, extra_cost_dim: int = 0, use_psucc: bool = True):
        super().__init__()
        self.retry_head = RetrySwitchHead(retry_in_dim)
        self.choice = ChoiceScorer(option_dim, extra_cost_dim, use_psucc)

    def forward(self, retry_features: torch.Tensor,
                option_features: torch.Tensor, mask: torch.Tensor,
                p_succ_opt: Optional[torch.Tensor] = None,
                extra_costs: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the pair (P(retry), switch logits) for the given batch."""
        stage_one = self.retry_head(retry_features)
        stage_two = self.choice(option_features, mask, p_succ_opt, extra_costs)
        return stage_one, stage_two


class HazardHead(nn.Module):
    """
    Per-step switch hazard after a near-miss.

    h_t = P(switch at t | no switch before t), obtained by passing each step's
    features through an MLP followed by a sigmoid. `survival_nll` provides the
    matching discrete-time survival negative log-likelihood.
    """
    def __init__(self, in_dim: int, hidden: List[int] = [32]):
        super().__init__()
        self.net = mlp(in_dim, hidden, 1)

    def forward(self, step_features: torch.Tensor, step_mask: torch.Tensor) -> torch.Tensor:
        """
        step_features: [B, T, D]
        step_mask:     [B, T], True where the time step is valid
        Returns the hazards [B, T] in [0, 1], forced to 0 on invalid steps.
        """
        hazards = torch.sigmoid(self.net(step_features)).squeeze(-1)
        keep = step_mask.float()
        return hazards * keep

    @staticmethod
    def survival_nll(h: torch.Tensor, event_index: torch.Tensor, step_mask: torch.Tensor) -> torch.Tensor:
        """
        Mean negative log-likelihood of the observed switch times.

        h:           [B, T] hazards
        event_index: [B] step at which the switch happened, or -1 when censored
                     (no switch inside the window)
        step_mask:   [B, T]

        Uses log S_t = sum_{k < t} log(1 - h_k) and log f(t) = log S_t + log h_t.
        """
        eps = 1e-6
        # Per-step log-probability of NOT switching; masked steps contribute 0.
        log_no_switch = torch.log(torch.clamp(1 - h, min=eps)) * step_mask
        log_survival = torch.cumsum(log_no_switch, dim=1)  # [B, T]
        n_rows, _ = h.shape
        per_row = []
        for row in range(n_rows):
            t_star = event_index[row].item()
            if t_star < 0:
                # Censored: the row survived through its last valid step.
                valid_steps = step_mask[row].nonzero(as_tuple=False)
                last = int(valid_steps[-1])
                per_row.append(-log_survival[row, last])
                continue
            # Observed switch: survive every step before t_star, then switch at t_star.
            if t_star > 0:
                log_surv_before = log_survival[row, t_star - 1]
            else:
                log_surv_before = torch.tensor(0.0, device=h.device)
            log_hazard = torch.log(torch.clamp(h[row, t_star], min=eps))
            per_row.append(-(log_surv_before + log_hazard))
        return torch.stack(per_row).mean()


# -------------------------------
# Belief / POMDP-ish helper (very simple)
# -------------------------------

@dataclass
class BeliefState:
    """Pair of evidence counts from which a success probability is derived."""
    alpha: torch.Tensor  # accumulated evidence for success
    beta: torch.Tensor   # accumulated evidence for failure

    def p_succ(self) -> torch.Tensor:
        """Return alpha / (alpha + beta), with 1e-9 added to the denominator to avoid 0/0."""
        total_evidence = self.alpha + self.beta + 1e-9
        return self.alpha / total_evidence


def update_belief(
    belief: BeliefState,
    flash_strength: torch.Tensor,
    miss_distance: Optional[torch.Tensor] = None,
    decay: float = 0.95,
) -> BeliefState:
    """
    Heuristic one-step belief update; returns a new BeliefState.

    Both evidence terms are first shrunk by `decay`. `flash_strength` is then added
    to the success evidence, and `miss_distance` (when given) is added to the
    failure evidence. The input `belief` is not modified.
    """
    decayed_beta = decay * belief.beta
    new_beta = decayed_beta if miss_distance is None else decayed_beta + miss_distance
    new_alpha = decay * belief.alpha + flash_strength
    return BeliefState(alpha=new_alpha, beta=new_beta)


# -------------------------------
# Datasets (stubs you will replace)
# -------------------------------

class SuccessDataset(Dataset):
    """
    Dataset of (stop_features [D_s], label in {0, 1}) pairs.

    X: [N, D_s] features; y: [N] labels. Both may be torch tensors or anything
    `torch.as_tensor` accepts (e.g. numpy arrays), so callers holding numpy data
    no longer hit an AttributeError on `.float()` / `.long()`.
    """
    def __init__(self, X: torch.Tensor, y: torch.Tensor):
        # as_tensor is a no-op for tensors and converts array-likes without a needless copy.
        X = torch.as_tensor(X)
        y = torch.as_tensor(y)
        assert X.ndim == 2 and y.ndim == 1
        # Labels are stored as integers, so any fractional part of y is truncated here.
        self.X, self.y = X.float(), y.long()

    def __len__(self) -> int:
        return len(self.y)

    def __getitem__(self, i: int) -> Tuple[torch.Tensor, torch.Tensor]:
        # The label is returned as float so it can be fed directly to a BCE loss.
        return self.X[i], self.y[i].float()


class ChoiceDataset(Dataset):
    """
    Thin wrapper exposing a list of episode dicts as a Dataset.

    Documented keys of each episode dict:
      - option_features: [K, D_o]
      - option_mask: [K] (bool)
      - chosen_index: int in [0, K-1]
      - p_succ_opt: Optional [K]
      - extra_costs: Optional [K, D_c]
    The dicts are stored and returned as-is; nothing is validated or copied.
    """
    def __init__(self, items: List[Dict[str, torch.Tensor]]):
        self.items = items

    def __len__(self):
        """Number of stored episodes."""
        return len(self.items)

    def __getitem__(self, i: int) -> Dict[str, torch.Tensor]:
        """Return the i-th episode dict (the same object that was passed in)."""
        episode = self.items[i]
        return episode


def choice_collate(batch: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
    """
    Pad a list of variable-K episodes into one batch.

    Each episode must provide "option_features" [K, D_o] and "chosen_index".
    "p_succ_opt" [K] and "extra_costs" [K, D_c] are included in the output only
    when EVERY episode in the batch has them. Any "option_mask" entry is ignored;
    the mask is rebuilt from each episode's K.

    Returns a dict with:
      - option_features: [B, Kmax, D_o], zero-padded
      - mask:            [B, Kmax] bool, True for real options
      - targets:         [B] long, the chosen option index
      - p_succ_opt:      [B, Kmax]      (only if present in all episodes)
      - extra_costs:     [B, Kmax, D_c] (only if present in all episodes)

    Raises:
        ValueError: if the batch is empty, or if an episode's chosen_index is
            outside 0..K-1 (which would otherwise point at a padded slot and
            silently produce a huge masked loss).
    """
    if len(batch) == 0:
        raise ValueError("choice_collate received an empty batch")

    Kmax = max(item["option_features"].shape[0] for item in batch)
    B = len(batch)
    D_o = batch[0]["option_features"].shape[1]
    option_features = torch.zeros(B, Kmax, D_o)
    mask = torch.zeros(B, Kmax, dtype=torch.bool)
    targets = torch.zeros(B, dtype=torch.long)
    p_succ_opt = None
    extra_costs = None

    has_ps = all("p_succ_opt" in item for item in batch)
    has_ec = all("extra_costs" in item for item in batch)

    if has_ps:
        p_succ_opt = torch.zeros(B, Kmax)
    if has_ec:
        D_c = batch[0]["extra_costs"].shape[1]
        extra_costs = torch.zeros(B, Kmax, D_c)

    for b, item in enumerate(batch):
        K = item["option_features"].shape[0]
        chosen = int(item["chosen_index"])
        if not 0 <= chosen < K:
            raise ValueError(
                f"chosen_index {chosen} is outside the valid range 0..{K - 1} for batch item {b}"
            )
        option_features[b, :K] = item["option_features"]
        mask[b, :K] = True
        targets[b] = chosen
        if has_ps:
            p_succ_opt[b, :K] = item["p_succ_opt"]
        if has_ec:
            extra_costs[b, :K] = item["extra_costs"]

    out = {"option_features": option_features, "mask": mask, "targets": targets}
    if has_ps:
        out["p_succ_opt"] = p_succ_opt
    if has_ec:
        out["extra_costs"] = extra_costs
    return out


# -------------------------------
# Training loops (minimal)
# -------------------------------

def train_success_head(model: SuccessPredictor, loader: DataLoader, epochs: int = 10, lr: float = 1e-3,
                       device: str = "cpu") -> None:
    """
    Fit the success head with AdamW and binary cross-entropy on probabilities.

    The model is moved to `device` and left in training mode. `loader` must yield
    (features, float labels) pairs, and the model must output probabilities.
    After every epoch the sample-weighted mean loss is printed.
    """
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    criterion = nn.BCELoss()
    model.train()
    for ep in range(epochs):
        total = 0.0
        for features, labels in loader:
            features = features.to(device)
            labels = labels.to(device)
            loss = criterion(model(features), labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Weight by batch size so the epoch figure is a per-sample mean.
            total += loss.item() * features.size(0)
        print(f"[Success] epoch {ep+1}: loss={total/len(loader.dataset):.4f}")


def train_choice_model(model: ChoiceScorer, loader: DataLoader, epochs: int = 10, lr: float = 1e-3,
                       device: str = "cpu") -> None:
    """
    Fit the choice scorer with AdamW and masked cross-entropy.

    The model is moved to `device` and left in training mode. `loader` must yield
    dict batches with "option_features", "mask" and "targets", plus optional
    "p_succ_opt" / "extra_costs" (passed to the model as None when absent).
    After every epoch the sample-weighted mean loss is printed.
    """
    model.to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=lr)
    model.train()

    def _optional(batch, key):
        # Fetch an optional batch entry and move it to the device if it exists.
        value = batch.get(key)
        return value if value is None else value.to(device)

    for ep in range(epochs):
        total = 0.0
        for batch in loader:
            opt.zero_grad()
            feats = batch["option_features"].to(device)
            valid = batch["mask"].to(device)
            chosen = batch["targets"].to(device)
            succ_probs = _optional(batch, "p_succ_opt")
            costs = _optional(batch, "extra_costs")

            scores = model(feats, valid, succ_probs, costs)
            loss = masked_cross_entropy(scores, chosen, valid)
            loss.backward()
            opt.step()
            # Weight by batch size so the epoch figure is a per-episode mean.
            total += loss.item() * feats.size(0)
        print(f"[Choice] epoch {ep+1}: loss={total/len(loader.dataset):.4f}")


# -------------------------------
# Worked example with synthetic data (end-to-end)
# -------------------------------

def _pearsonr(x: torch.Tensor, y: torch.Tensor) -> float:
    xm = (x - x.mean()) / (x.std() + 1e-8)
    ym = (y - y.mean()) / (y.std() + 1e-8)
    return float((xm * ym).mean().item())


def generate_synthetic_success_data(N: int = 4000) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Build a synthetic success-head dataset and return (X_train, y_train, X_val, y_val).

    Columns of X (Ds=6), in order:
    [miss_dist, flash_recency, flash_brightness, alignment, speed, curvature].
    Labels are Bernoulli draws from a noisy logistic model of those columns.
    The global torch seed is reset to 42 on every call; the split is 80/20.
    """
    torch.manual_seed(42)
    # The order of the random draws below determines the generated data; keep it fixed.
    miss = 1.5 * torch.rand(N)                    # in [0, 1.5); larger = worse
    recency = torch.rand(N)                       # lower = more recent flash = better
    bright = torch.rand(N)                        # brighter = better
    align = torch.rand(N) * 2 - 1                 # cosine alignment in [-1, 1)
    speed = 0.5 + 0.5 * torch.rand(N)             # in [0.5, 1.0)
    curvature = 0.3 * torch.rand(N)               # small turns

    columns = [miss, recency, bright, align, speed, curvature]
    X = torch.stack(columns, dim=1)

    # Ground-truth logit: intercept -1.0 plus one weight per column, accumulated left to right.
    weights = [-2.5, -0.8, 2.0, 1.6, 0.3, -0.2]
    lin = -1.0
    for weight, column in zip(weights, columns):
        lin = lin + weight * column
    logit_noise = 0.25 * torch.randn(N)
    p = torch.sigmoid(lin + logit_noise)
    y = torch.bernoulli(p).float()

    shuffled = torch.randperm(N)
    n_train = int(0.8 * N)
    train_rows = shuffled[:n_train]
    val_rows = shuffled[n_train:]
    return X[train_rows], y[train_rows], X[val_rows], y[val_rows]


def evaluate_success_head(model: SuccessPredictor, Xva: torch.Tensor, yva: torch.Tensor) -> None:
    """
    Print validation accuracy of the success head plus three sanity correlations.

    Predictions are thresholded at 0.5 for the accuracy. The correlations compare
    the predicted probability with -Xva[:, 0], Xva[:, 2] and Xva[:, 3], which the
    print statement labels as -miss, bright and align.
    """
    with torch.no_grad():
        p = model(Xva)
        predicted_positive = p > 0.5
        actual_positive = yva > 0.5
        acc = (predicted_positive == actual_positive).float().mean().item()
        neg_miss, bright, align = -Xva[:, 0], Xva[:, 2], Xva[:, 3]
        rho_miss = _pearsonr(p, neg_miss)      # higher when the miss is smaller
        rho_bri = _pearsonr(p, bright)         # higher when brighter
        rho_align = _pearsonr(p, align)        # higher when aligned
    print(f"[Success/Val] acc={acc:.3f}  corr(p, -miss)={rho_miss:.2f}  corr(p, bright)={rho_bri:.2f}  corr(p, align)={rho_align:.2f}")


def build_choice_items(success_model: SuccessPredictor, N: int = 2500) -> List[Dict[str, torch.Tensor]]:
    """Construct a synthetic choice dataset using the same latent factors.

    Each of the N episodes has 2..5 options. The observed choice is the argmax of a
    noisy utility built from the TRUE success logit and the costs, while the stored
    "p_succ_opt" is the prediction of `success_model` on the option's stop features.

    option_features: [alignment, cluster_density]  (Do=2)
    extra_costs:    [time_to_go, turn_cost]       (Dc=2)

    The stop features are moved to whatever device `success_model` lives on and the
    predictions are brought back to the CPU, so a model previously trained on an
    accelerator no longer triggers a device-mismatch error here. All returned
    tensors are CPU tensors. The global torch seed is reset to 123 on every call.
    """
    torch.manual_seed(123)

    # Device of the model's parameters; fall back to CPU for a parameter-free model.
    try:
        model_device = next(success_model.parameters()).device
    except StopIteration:
        model_device = torch.device("cpu")

    items: List[Dict[str, torch.Tensor]] = []
    for _ in range(N):
        K = int(torch.randint(low=2, high=6, size=(1,)).item())  # 2..5 options

        # Latents per option (the draw order fixes the generated data)
        cluster = torch.rand(K)                           # more = denser cluster nearby
        distance = 0.2 + 1.8 * torch.rand(K)             # time-to-go proxy
        turn = torch.rand(K)                              # normalized turn cost 0..1
        align = torch.clamp(1.0 - 2.0 * turn + 0.2 * torch.randn(K), -1.0, 1.0)  # correlated with turn
        recency = torch.rand(K)
        bright = torch.rand(K)
        speed = 0.5 + 0.5 * torch.rand(K)
        curvature = 0.3 * torch.rand(K)
        miss = torch.clamp(0.2 * distance + 0.7 * (1 - bright) + 0.1 * torch.randn(K), 0.0, 1.5)

        # True success logits (same as in success generator, without extra noise)
        lin_true = (
            -1.0 + (-2.5) * miss + (-0.8) * recency + 2.0 * bright + 1.6 * align + 0.3 * speed + (-0.2) * curvature
        )
        p_true = torch.sigmoid(lin_true)

        # p_succ from the trained head, evaluated on the model's own device
        stop_feats = torch.stack([miss, recency, bright, align, speed, curvature], dim=1)
        with torch.no_grad():
            p_hat = success_model(stop_feats.to(model_device)).cpu()

        # Underlying utility that generates the observed choice
        # U = w1*logit(p_succ_true) - w2*time - w3*turn + w4*cluster + w5*align + noise
        logit_true = torch.logit(p_true.clamp(1e-4, 1 - 1e-4))
        U = 1.6 * logit_true - 1.0 * distance - 0.8 * turn + 0.6 * cluster + 0.3 * align + 0.30 * torch.randn(K)
        chosen = int(torch.argmax(U).item())

        option_features = torch.stack([align, cluster], dim=1)  # [K,2]
        extra_costs = torch.stack([distance, turn], dim=1)      # [K,2]

        items.append({
            "option_features": option_features,
            "extra_costs": extra_costs,
            "p_succ_opt": p_hat,
            "option_mask": torch.ones(K, dtype=torch.bool),
            "chosen_index": torch.tensor(chosen),
        })
    return items


def eval_choice_top1(model: ChoiceScorer, loader: DataLoader, device: str = "cpu") -> float:
    """
    Top-1 accuracy of the choice model over all episodes yielded by `loader`.

    The model is switched to eval mode (it is NOT moved to `device`; only the batch
    tensors are). Padded option slots are excluded before taking the argmax, so a
    padded position can never be counted as the prediction — consistent with how
    masked_cross_entropy treats them during training. Returns 0.0 for an empty loader.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in loader:
            option_features = batch["option_features"].to(device)
            mask = batch["mask"].to(device)
            targets = batch["targets"].to(device)
            p_succ_opt = batch.get("p_succ_opt")
            extra_costs = batch.get("extra_costs")
            if p_succ_opt is not None:
                p_succ_opt = p_succ_opt.to(device)
            if extra_costs is not None:
                extra_costs = extra_costs.to(device)

            logits = model(option_features, mask, p_succ_opt, extra_costs)
            # Rule out padding explicitly; harmless if the model already masks its output.
            logits = logits.masked_fill(~mask, float("-inf"))
            pred = torch.argmax(logits, dim=1)
            correct += (pred == targets).sum().item()
            total += targets.numel()
    return correct / max(1, total)



# --- Example with evaluation & validation (drop-in runnable) ---
# You can import this file as a module and call `run_toy_example_with_eval()` to see
# a full train/val split, metrics (accuracy, log-loss, ROC-AUC for success head;
# top-1/top-2 accuracy & NLL for the choice model), and a concrete episode that
# matches the toy you asked about.

from typing import Any


def _roc_auc_basic(y_true: np.ndarray, y_score: np.ndarray) -> float:
    """Compute ROC-AUC from scratch (no sklearn)."""
    y_true = y_true.astype(np.float64)
    y_score = y_score.astype(np.float64)
    n_pos = int(y_true.sum())
    n_neg = int(len(y_true) - n_pos)
    if n_pos == 0 or n_neg == 0:
        return float("nan")
    # Rank scores (average tie handling via argsort twice)
    order = np.argsort(y_score)
    ranks = np.empty_like(order)
    ranks[order] = np.arange(1, len(y_score) + 1)  # 1-based ranks
    sum_ranks_pos = ranks[y_true == 1].sum()
    auc = (sum_ranks_pos - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
    return float(auc)


def _brier(y_true: np.ndarray, p: np.ndarray) -> float:
    return float(np.mean((p - y_true) ** 2))


def _make_toy_success_data(n: int = 1200, d: int = 6, seed: int = 123) -> Tuple[np.ndarray, np.ndarray]:
    rng = np.random.default_rng(seed)
    X = rng.normal(size=(n, d))
    w = rng.normal(size=(d,))
    b = 0.15
    logits = X @ w + b
    p = 1 / (1 + np.exp(-logits))
    y = rng.binomial(1, p)
    return X.astype(np.float32), y.astype(np.float32)


def _toy_episode_you_requested() -> Dict[str, torch.Tensor]:
    return {
        # episode 1: K=3 options (e.g., {retry, ff#7, ff#12})
        "option_features": torch.tensor([[0.1, 0.2], [-0.3, 0.5], [0.4, -0.1]], dtype=torch.float32),  # [3, D_o=2]
        "p_succ_opt":     torch.tensor([0.70, 0.25, 0.55], dtype=torch.float32),                       # [3]
        "extra_costs":    torch.tensor([[0.9, 0.1], [0.2, 0.7], [0.6, 0.3]], dtype=torch.float32),     # [3, D_c=2]
        "chosen_index":   torch.tensor(0),  # picked the first option (retry)
    }


def _make_toy_choice_items(n_episodes: int = 600, Do: int = 2, Dc: int = 2, seed: int = 7) -> List[Dict[str, torch.Tensor]]:
    """
    Build a list of variable-K toy episodes.

    The first item is always the fixed episode from `_toy_episode_you_requested()`;
    it is followed by `n_episodes - 1` synthetic episodes (2..5 options each) whose
    'observed' choice is the argmax of a noisy latent utility.
    """
    rng = np.random.default_rng(seed)
    episodes: List[Dict[str, torch.Tensor]] = [_toy_episode_you_requested()]

    # Latent generator weights: random for the features, a fixed penalty for the costs.
    w_feat = rng.normal(size=(Do,))
    w_cost = -0.6 * np.ones(Dc)
    alpha_ps = 1.0

    n_synthetic = n_episodes - 1
    for _ in range(n_synthetic):
        # The order of the draws below fixes the generated data.
        K = rng.integers(2, 6)
        feat = rng.normal(size=(K, Do))
        second_feature = feat[:, 1] if Do > 1 else 0
        ps_noise = rng.normal(scale=0.4, size=K)
        ps_logits = 0.9 * feat[:, 0] + 0.7 * second_feature + ps_noise
        p_succ = 1 / (1 + np.exp(-ps_logits))
        costs = rng.normal(size=(K, Dc))
        util_noise = rng.normal(scale=0.25, size=K)
        util = feat @ w_feat + alpha_ps * p_succ + costs @ w_cost + util_noise
        chosen = int(np.argmax(util))
        episode = {
            "option_features": torch.tensor(feat, dtype=torch.float32),
            "p_succ_opt": torch.tensor(p_succ, dtype=torch.float32),
            "extra_costs": torch.tensor(costs, dtype=torch.float32),
            "chosen_index": torch.tensor(chosen),
        }
        episodes.append(episode)
    return episodes


def _split_idx(n: int, val_frac: float = 0.2, seed: int = 0) -> Tuple[np.ndarray, np.ndarray]:
    rng = np.random.default_rng(seed)
    perm = rng.permutation(n)
    cut = int((1 - val_frac) * n)
    return perm[:cut], perm[cut:]


def _eval_choice(model: ChoiceScorer, loader: DataLoader, device: str = "cpu") -> Dict[str, Any]:
    """
    Evaluate the choice model over every episode yielded by `loader`.

    Returns a dict with:
      - "top1_acc": fraction of episodes whose best-scored option is the chosen one
      - "top2_acc": fraction whose chosen option is among the two best-scored
      - "avg_nll":  masked cross-entropy averaged over episodes

    Padded option slots are excluded before ranking, so they can never be counted
    as a prediction (consistent with masked_cross_entropy). When a batch offers
    fewer than two option slots, top-2 degrades to top-1 instead of raising.
    The model is switched to eval mode but is not moved to `device`.
    """
    model.eval()
    n, correct1, correct2 = 0, 0, 0
    total_nll = 0.0
    with torch.no_grad():
        for batch in loader:
            option_features = batch["option_features"].to(device)
            mask = batch["mask"].to(device)
            targets = batch["targets"].to(device)
            p_succ_opt = batch.get("p_succ_opt")
            extra_costs = batch.get("extra_costs")
            if p_succ_opt is not None:
                p_succ_opt = p_succ_opt.to(device)
            if extra_costs is not None:
                extra_costs = extra_costs.to(device)
            logits = model(option_features, mask, p_succ_opt, extra_costs)
            # masked_cross_entropy returns the batch MEAN; re-weight by batch size so
            # the final figure is a per-episode average.
            nll = masked_cross_entropy(logits, targets, mask)
            total_nll += float(nll) * targets.size(0)
            # Rank only real options; harmless if the model already masks its output.
            ranked_logits = logits.masked_fill(~mask, float("-inf"))
            # top-1
            top1 = torch.argmax(ranked_logits, dim=1)
            correct1 += int((top1 == targets).sum())
            # top-2 (or top-1 when only one option slot exists)
            k = min(2, ranked_logits.size(1))
            _, topk_idx = torch.topk(ranked_logits, k=k, dim=1)
            correct2 += int((topk_idx == targets.unsqueeze(1)).any(dim=1).sum())
            n += targets.size(0)
    return {
        "top1_acc": correct1 / max(n, 1),
        "top2_acc": correct2 / max(n, 1),
        "avg_nll": total_nll / max(n, 1),
    }


def run_toy_example_with_eval(device: str = "cpu") -> None:
    """
    End-to-end demo with train/val split and metrics for both heads.
    Prints:
      - Success head: accuracy, log-loss, ROC-AUC, Brier score on val set
      - Choice model: top-1 / top-2 accuracy and average NLL on val set
      - The softmax over the options of the fixed first toy episode

    The toy success data is generated as numpy arrays, so it is converted to torch
    tensors before being wrapped in SuccessDataset (which calls tensor methods on
    its inputs). Tensors fed to the trained models are moved to `device`, because
    the training helpers move the models there.
    """
    # --- 3A) Success head ---
    X, y = _make_toy_success_data(n=2000, d=6, seed=11)
    tr_idx, va_idx = _split_idx(len(y), val_frac=0.2, seed=11)
    Xtr, ytr = X[tr_idx], y[tr_idx]
    Xva, yva = X[va_idx], y[va_idx]

    succ_model = SuccessPredictor(in_dim=X.shape[1])
    train_ds = SuccessDataset(torch.from_numpy(Xtr), torch.from_numpy(ytr))
    tr_loader = DataLoader(train_ds, batch_size=64, shuffle=True)
    train_success_head(succ_model, tr_loader, epochs=8, lr=2e-3, device=device)

    with torch.no_grad():
        Xva_t = torch.tensor(Xva, dtype=torch.float32).to(device)
        p_va = succ_model(Xva_t).cpu().numpy()
    acc = float(((p_va >= 0.5).astype(np.float32) == yva).mean())
    eps = 1e-6  # clip probabilities away from 0/1 so the logs stay finite
    logloss = float(-(yva * np.log(np.clip(p_va, eps, 1 - eps)) + (1 - yva) * np.log(np.clip(1 - p_va, eps, 1 - eps))).mean())
    auc = _roc_auc_basic(yva, p_va)
    brier = _brier(yva, p_va)

    print("=== Success head (val) ===")
    print(f"Accuracy: {acc:.3f}  LogLoss: {logloss:.3f}  ROC-AUC: {auc:.3f}  Brier: {brier:.3f}")

    # --- 3B) Choice model ---
    items = _make_toy_choice_items(n_episodes=1200, Do=2, Dc=2, seed=22)
    tr_idx, va_idx = _split_idx(len(items), val_frac=0.2, seed=22)
    train_items = [items[i] for i in tr_idx]
    val_items   = [items[i] for i in va_idx]

    train_loader = DataLoader(ChoiceDataset(train_items), batch_size=64, shuffle=True, collate_fn=choice_collate)
    val_loader   = DataLoader(ChoiceDataset(val_items),   batch_size=64, shuffle=False, collate_fn=choice_collate)

    chooser = ChoiceScorer(option_dim=2, extra_cost_dim=2, use_psucc=True)
    train_choice_model(chooser, train_loader, epochs=8, lr=2e-3, device=device)

    stats = _eval_choice(chooser, val_loader, device=device)
    print("=== Choice model (val) ===")
    print(f"Top-1 Acc: {stats['top1_acc']:.3f}  Top-2 Acc: {stats['top2_acc']:.3f}  Avg NLL: {stats['avg_nll']:.3f}")

    # Also print the fixed first episode with the model's softmax scores, to show
    # how it ranks {retry, ff#7, ff#12}. A batch dimension of 1 is added and every
    # option is marked valid.
    first = items[0]
    with torch.no_grad():
        n_options = first["option_features"].shape[0]
        logits = chooser(
            first["option_features"][None, ...].to(device),
            torch.ones(1, n_options, dtype=torch.bool, device=device),
            first["p_succ_opt"][None, ...].to(device),
            first["extra_costs"][None, ...].to(device),
        )
        probs = torch.softmax(logits, dim=1).cpu().numpy()[0]
    print("Episode 1 — softmax over options [retry, ff#7, ff#12]:", np.round(probs, 3).tolist(), " chosen:", int(first["chosen_index"]))

# End of example section


# functions

In [6]:
"""
Tiny PyTorch skeleton for MultiFF retry/switch modeling.

Implements the two core training steps and leaves hooks for the optional
variants. The code is intentionally lightweight: small MLPs, masking for
variable option counts, and minimal training loops.

You can paste this into a file and adapt the Dataset stubs to your
preprocessing.
"""
from __future__ import annotations

import math
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader


# -------------------------------
# Utilities
# -------------------------------

def mlp(in_dim: int, hidden: List[int], out_dim: int, dropout: float = 0.0) -> nn.Sequential:
    """
    Build a plain feed-forward network.

    Each entry of ``hidden`` contributes ``Linear -> ReLU`` (followed by
    ``Dropout(dropout)`` when ``dropout > 0``); a final ``Linear`` maps the
    last width to ``out_dim``. With an empty ``hidden`` the result is a single
    ``Linear(in_dim, out_dim)``.

    Args:
        in_dim: size of the input feature vector.
        hidden: widths of the hidden layers, in order.
        out_dim: size of the output vector.
        dropout: dropout probability; no Dropout modules are added when <= 0.

    Returns:
        nn.Sequential holding the modules in application order.
    """
    widths = [in_dim, *hidden]
    modules: List[nn.Module] = []
    # Walk consecutive (fan_in, fan_out) pairs: in_dim->h0, h0->h1, ...
    for fan_in, fan_out in zip(widths[:-1], widths[1:]):
        modules.append(nn.Linear(fan_in, fan_out))
        modules.append(nn.ReLU())
        if dropout > 0:
            modules.append(nn.Dropout(dropout))
    # Output projection from the last width (in_dim itself if no hidden layers).
    modules.append(nn.Linear(widths[-1], out_dim))
    return nn.Sequential(*modules)


def masked_cross_entropy(logits: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """
    Mean cross-entropy over a batch whose rows may have different numbers of
    valid options.

    Args:
        logits: [B, Kmax] raw scores.
        targets: [B] index of the chosen option for each row.
        mask: [B, Kmax] boolean, True where the option is valid.

    Returns:
        Scalar tensor: ``F.cross_entropy`` of the masked logits (default mean
        reduction).
    """
    # Invalid slots get a large negative *finite* score (-1e9, not a true
    # -inf), which drives their softmax weight to effectively zero.
    blocked = torch.full_like(logits, -1e9)
    effective_logits = torch.where(mask, logits, blocked)
    return F.cross_entropy(effective_logits, targets)


# -------------------------------
# 3A) Capture-success predictor
# -------------------------------

class SuccessPredictor(nn.Module):
    """
    Capture-success head: maps stop features to a success probability.

    The network is built by ``mlp`` with a single output unit; ``forward``
    applies a sigmoid, so the module emits probabilities, not logits.

    Inputs:  stop_features [B, D_s]
    Output:  p_succ        [B]
    """
    def __init__(self, in_dim: int, hidden: List[int] = [32]):
        super().__init__()
        # One scalar output per sample, no dropout.
        self.net = mlp(in_dim=in_dim, hidden=hidden, out_dim=1, dropout=0.0)

    def forward(self, stop_features: torch.Tensor) -> torch.Tensor:
        """Return sigmoid(net(stop_features)) with the trailing unit dim removed."""
        raw_score = self.net(stop_features)
        prob = torch.sigmoid(raw_score)
        return prob.squeeze(-1)


# -------------------------------
# 3B) Choice scorer over {retry} ∪ {other targets}
# -------------------------------

class ChoiceScorer(nn.Module):
    """
    Scores each candidate option with a utility U_i(t) and produces logits.

    The per-option input to the MLP is the concatenation (along the last
    axis) of ``option_features``, then ``p_succ_opt`` (only when
    ``use_psucc=True``), then ``extra_costs`` (when given).

    Forward inputs:
      - option_features: [B, Kmax, D_o]
      - mask: [B, Kmax] boolean (True for valid candidates)
      - p_succ_opt: Optional[Tensor] [B, Kmax] (per-option success proba);
        required when ``use_psucc=True``, ignored otherwise
      - extra_costs: Optional[Tensor] [B, Kmax, D_c] (e.g., time-to-go, turn cost);
        required when ``extra_cost_dim > 0``

    Returns:
      - logits: [B, Kmax]; invalid slots hold -1e9 (a large finite negative,
        not a true -inf)

    Raises:
      ValueError: if ``option_features`` is not 3-D, if ``p_succ_opt`` is
        missing while ``use_psucc=True``, or if ``extra_costs`` is missing /
        has a last dimension different from ``extra_cost_dim``.
    """
    def __init__(self, option_dim: int, extra_cost_dim: int = 0, use_psucc: bool = True,
                 hidden: List[int] = [64, 32]):
        super().__init__()
        self.use_psucc = use_psucc
        # Remembered so forward() can validate extra_costs with a clear message
        # instead of failing inside the first Linear layer.
        self.extra_cost_dim = extra_cost_dim
        in_dim = option_dim + extra_cost_dim + (1 if use_psucc else 0)
        self.net = mlp(in_dim, hidden, 1, dropout=0.0)

    def forward(
        self,
        option_features: torch.Tensor,
        mask: torch.Tensor,
        p_succ_opt: Optional[torch.Tensor] = None,
        extra_costs: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if option_features.ndim != 3:
            raise ValueError(
                f"option_features must be [B, Kmax, D_o]; got shape {tuple(option_features.shape)}"
            )

        feats = [option_features]
        if self.use_psucc:
            if p_succ_opt is None:
                raise ValueError("p_succ_opt must be provided when use_psucc=True")
            feats.append(p_succ_opt.unsqueeze(-1))  # [B, Kmax, 1]

        if extra_costs is None:
            if self.extra_cost_dim > 0:
                raise ValueError(
                    f"extra_costs must be provided when extra_cost_dim={self.extra_cost_dim}"
                )
        else:
            if extra_costs.shape[-1] != self.extra_cost_dim:
                raise ValueError(
                    f"extra_costs last dim is {extra_costs.shape[-1]}, "
                    f"expected extra_cost_dim={self.extra_cost_dim}"
                )
            feats.append(extra_costs)

        X = torch.cat(feats, dim=-1)  # [B, Kmax, Din]
        logits = self.net(X).squeeze(-1)  # [B, Kmax]
        # Push invalid options to a very low score so they are effectively
        # never selected by argmax/softmax.
        return logits.masked_fill(~mask, -1e9)


# -------------------------------
# Optional variants (4)
# -------------------------------

class RetrySwitchHead(nn.Module):
    """
    Binary head for the retry-vs-switch decision after a near-miss.

    ``forward`` returns a probability (sigmoid already applied), interpreted
    as P(retry).
    """
    def __init__(self, in_dim: int, hidden: List[int] = [32]):
        super().__init__()
        # Single scalar output; mlp's default dropout applies.
        self.net = mlp(in_dim, hidden, out_dim=1)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """Map features [..., in_dim] to P(retry) with the trailing unit dim squeezed."""
        retry_score = self.net(features)
        p_retry = torch.sigmoid(retry_score)
        return p_retry.squeeze(-1)


class TwoStagePolicy(nn.Module):
    """
    Two-stage retry/switch policy.

    Stage 1 (``retry_head``): binary retry vs switch on near-miss features.
    Stage 2 (``choice``): scores the other targets with a ChoiceScorer.

    ``forward`` always evaluates both stages and returns both outputs; how
    they are combined (e.g. only using stage 2 on a switch) is up to the
    caller.
    """
    def __init__(self, retry_in_dim: int, option_dim: int, extra_cost_dim: int = 0, use_psucc: bool = True):
        super().__init__()
        self.retry_head = RetrySwitchHead(in_dim=retry_in_dim)
        self.choice = ChoiceScorer(
            option_dim=option_dim,
            extra_cost_dim=extra_cost_dim,
            use_psucc=use_psucc,
        )

    def forward(self, retry_features: torch.Tensor,
                option_features: torch.Tensor, mask: torch.Tensor,
                p_succ_opt: Optional[torch.Tensor] = None,
                extra_costs: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns:
            (p_retry, logits_switch): the retry probability from stage 1 and
            the per-option logits from stage 2.
        """
        stage1_p_retry = self.retry_head(retry_features)
        stage2_logits = self.choice(option_features, mask, p_succ_opt, extra_costs)
        return (stage1_p_retry, stage2_logits)


class HazardHead(nn.Module):
    """
    Discrete-time hazard after a near-miss: h_t = P(switch at t | not switched yet).

    Input per step features X_t -> h_t via sigmoid(MLP).
    Loss implemented via discrete-time survival likelihood (``survival_nll``).
    """
    def __init__(self, in_dim: int, hidden: List[int] = [32]):
        super().__init__()
        self.net = mlp(in_dim, hidden, 1)

    def forward(self, step_features: torch.Tensor, step_mask: torch.Tensor) -> torch.Tensor:
        """
        step_features: [B, T, D]
        step_mask: [B, T] (True for valid time steps)
        Returns hazards h_t in [0,1]: [B, T]; masked steps are forced to 0.
        """
        h = torch.sigmoid(self.net(step_features)).squeeze(-1)
        return h * step_mask.float()

    @staticmethod
    def survival_nll(h: torch.Tensor, event_index: torch.Tensor, step_mask: torch.Tensor) -> torch.Tensor:
        """
        Negative log-likelihood for discrete-time hazards, averaged over the batch.

        For an observed switch at step t:  -(sum_{k<t} log(1-h_k) + log h_t).
        For a censored sequence:           -(sum over valid k of log(1-h_k)).
        Masked steps contribute zero to every sum, so a censored sequence with
        no valid step yields a loss of 0 instead of raising.

        Args:
            h: [B, T] hazards.
            event_index: [B] integer index of the switch step; any negative
                value (conventionally -1) marks a censored sequence.
            step_mask: [B, T] validity mask (True / non-zero for valid steps).

        Returns:
            Scalar tensor with the mean NLL.

        Raises:
            ValueError: if the batch is empty.
            IndexError: if an observed event index is >= T.
        """
        eps = 1e-6
        B, T = h.shape
        if B == 0:
            raise ValueError("survival_nll received an empty batch")

        events = event_index.to(device=h.device, dtype=torch.long)
        observed = events >= 0
        # One check for the whole batch keeps the out-of-range error explicit.
        if bool((events >= T).any()):
            raise IndexError(f"event_index contains a step >= T (T={T})")

        valid = step_mask.to(device=h.device, dtype=h.dtype)
        log1m_h = torch.log(torch.clamp(1 - h, min=eps)) * valid
        cumsums = torch.cumsum(log1m_h, dim=1)  # [B, T], inclusive prefix sums

        # Prepend a zero column so log_surv[:, t] = sum_{k < t} log(1 - h_k).
        # Column T is then the sum over all steps, used for censored rows.
        log_surv = torch.cat([cumsums.new_zeros(B, 1), cumsums], dim=1)  # [B, T+1]
        surv_idx = torch.where(observed, events, torch.full_like(events, T))
        surv = log_surv.gather(1, surv_idx.unsqueeze(1)).squeeze(1)  # [B]

        # log h at the event step; censored rows read a dummy index (0) whose
        # value is discarded by the torch.where below.
        h_event = h.gather(1, events.clamp(min=0).unsqueeze(1)).squeeze(1)
        log_h = torch.log(torch.clamp(h_event, min=eps))
        log_h = torch.where(observed, log_h, torch.zeros_like(log_h))

        return (-(surv + log_h)).mean()


# -------------------------------
# Belief / POMDP-ish helper (very simple)
# -------------------------------

@dataclass
class BeliefState:
    """Pair of evidence tensors summarising belief about capture success."""
    # Accumulated evidence in favour of success.
    alpha: torch.Tensor
    # Accumulated evidence in favour of failure.
    beta: torch.Tensor

    def p_succ(self) -> torch.Tensor:
        """Return alpha / (alpha + beta + 1e-9); the small constant avoids 0/0."""
        total_evidence = self.alpha + self.beta + 1e-9
        return self.alpha / total_evidence


def update_belief(
    belief: BeliefState,
    flash_strength: torch.Tensor,
    miss_distance: Optional[torch.Tensor] = None,
    decay: float = 0.95,
) -> BeliefState:
    """
    Heuristic belief update; returns a new BeliefState and leaves ``belief``
    untouched.

    Both evidence terms are first multiplied by ``decay``. ``flash_strength``
    is then added to the success evidence, and ``miss_distance`` (when given)
    is added as-is to the failure evidence; no rescaling happens here, so any
    scaling must be applied by the caller.
    """
    new_alpha = decay * belief.alpha + flash_strength
    new_beta = decay * belief.beta
    if miss_distance is not None:
        new_beta = new_beta + miss_distance
    return BeliefState(alpha=new_alpha, beta=new_beta)


# -------------------------------
# Datasets (stubs you will replace)
# -------------------------------

class SuccessDataset(Dataset):
    """
    Tensor-backed dataset for the success head.

    Each item is ``(stop_features [D_s], label)`` where the label is returned
    as a float scalar. Labels are stored as integers (``y.long()``) and cast
    back to float on access.
    """
    def __init__(self, X: torch.Tensor, y: torch.Tensor):
        # X must be a [N, D_s] matrix and y a [N] vector.
        assert X.ndim == 2 and y.ndim == 1
        self.X = X.float()
        self.y = y.long()

    def __len__(self) -> int:
        return len(self.y)

    def __getitem__(self, i: int) -> Tuple[torch.Tensor, torch.Tensor]:
        features = self.X[i]
        label = self.y[i].float()
        return features, label


class ChoiceDataset(Dataset):
    """
    Thin list-backed dataset of choice episodes.

    Every item is a dict with keys:
      - option_features: [K, D_o]
      - option_mask: [K] (bool)
      - chosen_index: int in [0, K-1]
      - p_succ_opt: Optional [K]
      - extra_costs: Optional [K, D_c]

    Items are stored and returned by reference; nothing is copied or checked.
    """
    def __init__(self, items: List[Dict[str, torch.Tensor]]):
        self.items = items

    def __len__(self):
        n_items = len(self.items)
        return n_items

    def __getitem__(self, i: int) -> Dict[str, torch.Tensor]:
        episode = self.items[i]
        return episode


def choice_collate(batch: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
    """
    Pad a list of variable-size choice episodes into dense batch tensors.

    Args:
        batch: list of item dicts. Required keys: ``option_features`` [K, D_o]
            and ``chosen_index``. Optional keys: ``option_mask`` [K],
            ``p_succ_opt`` [K], ``extra_costs`` [K, D_c].

    Returns:
        Dict with
          - ``option_features``: [B, Kmax, D_o], zero-padded
          - ``mask``: [B, Kmax] bool; padding is False, real slots follow the
            item's ``option_mask`` when present and are True otherwise
          - ``targets``: [B] long, the chosen index per episode
          - ``p_succ_opt`` [B, Kmax] / ``extra_costs`` [B, Kmax, D_c]: only
            included when *every* item in the batch carries that key

    Raises:
        ValueError: on an empty batch, or when an item's ``chosen_index`` is
            outside ``0..K-1`` or points at an option its mask marks invalid.
    """
    if not batch:
        raise ValueError("choice_collate received an empty batch")

    B = len(batch)
    Kmax = max(item["option_features"].shape[0] for item in batch)
    D_o = batch[0]["option_features"].shape[1]
    option_features = torch.zeros(B, Kmax, D_o)
    mask = torch.zeros(B, Kmax, dtype=torch.bool)
    targets = torch.zeros(B, dtype=torch.long)

    # Optional tensors are emitted only if all items provide them.
    has_ps = all("p_succ_opt" in item for item in batch)
    has_ec = all("extra_costs" in item for item in batch)
    p_succ_opt = torch.zeros(B, Kmax) if has_ps else None
    extra_costs = torch.zeros(B, Kmax, batch[0]["extra_costs"].shape[1]) if has_ec else None

    for b, item in enumerate(batch):
        K = item["option_features"].shape[0]
        option_features[b, :K] = item["option_features"]

        # Honour a per-item validity mask when supplied; default to all valid.
        item_mask = item.get("option_mask")
        if item_mask is None:
            mask[b, :K] = True
        else:
            mask[b, :K] = torch.as_tensor(item_mask, dtype=torch.bool)

        # A bad target would otherwise show up as a ~1e9 loss or a crash
        # deep inside cross-entropy; fail here with the offending episode.
        chosen = int(item["chosen_index"])
        if not 0 <= chosen < K:
            raise ValueError(
                f"chosen_index {chosen} out of range for item {b} with {K} options"
            )
        if not bool(mask[b, chosen]):
            raise ValueError(
                f"chosen_index {chosen} of item {b} refers to an option masked as invalid"
            )
        targets[b] = chosen

        if has_ps:
            p_succ_opt[b, :K] = item["p_succ_opt"]
        if has_ec:
            extra_costs[b, :K] = item["extra_costs"]

    out = {"option_features": option_features, "mask": mask, "targets": targets}
    if has_ps:
        out["p_succ_opt"] = p_succ_opt
    if has_ec:
        out["extra_costs"] = extra_costs
    return out


# -------------------------------
# Training loops (minimal)
# -------------------------------

def train_success_head(model: SuccessPredictor, loader: DataLoader, epochs: int = 10, lr: float = 1e-3,
                       device: str = "cpu") -> None:
    """
    Minimal training loop for the success head.

    Moves ``model`` to ``device``, puts it in train mode, and optimises binary
    cross-entropy (``nn.BCELoss``, so the model must output probabilities)
    with AdamW. After each epoch it prints the sample-weighted batch loss
    divided by ``len(loader.dataset)``.

    Args:
        model: module mapping a feature batch to per-sample probabilities.
        loader: yields ``(X, y)`` batches.
        epochs: number of passes over ``loader``.
        lr: AdamW learning rate.
        device: device string the model and batches are moved to.
    """
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    criterion = nn.BCELoss()
    model.train()
    n_samples_total = None
    for epoch_idx in range(epochs):
        running_loss = 0.0
        for features, labels in loader:
            features = features.to(device)
            labels = labels.to(device)
            probs = model(features)
            batch_loss = criterion(probs, labels)
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()
            # Weight by batch size so a smaller final batch is not over-counted.
            running_loss += batch_loss.item() * features.size(0)
        n_samples_total = len(loader.dataset)
        print(f"[Success] epoch {epoch_idx+1}: loss={running_loss/n_samples_total:.4f}")


def train_choice_model(model: ChoiceScorer, loader: DataLoader, epochs: int = 10, lr: float = 1e-3,
                       device: str = "cpu") -> None:
    """
    Minimal training loop for the choice scorer.

    Moves ``model`` to ``device``, puts it in train mode, and minimises
    ``masked_cross_entropy`` with AdamW. Batches are dicts with keys
    ``option_features``, ``mask``, ``targets`` and optionally ``p_succ_opt``
    / ``extra_costs``. After each epoch it prints the sample-weighted batch
    loss divided by ``len(loader.dataset)``.

    Args:
        model: scorer called as ``model(option_features, mask, p_succ_opt, extra_costs)``.
        loader: yields batch dicts as described above.
        epochs: number of passes over ``loader``.
        lr: AdamW learning rate.
        device: device string the model and batches are moved to.
    """
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    model.train()
    for epoch_idx in range(epochs):
        running_loss = 0.0
        for batch in loader:
            optimizer.zero_grad()
            feats = batch["option_features"].to(device)
            valid = batch["mask"].to(device)
            chosen = batch["targets"].to(device)

            # Optional inputs: absent keys stay None and are passed through.
            psucc = batch.get("p_succ_opt")
            costs = batch.get("extra_costs")
            psucc = None if psucc is None else psucc.to(device)
            costs = None if costs is None else costs.to(device)

            scores = model(feats, valid, psucc, costs)
            batch_loss = masked_cross_entropy(scores, chosen, valid)
            batch_loss.backward()
            optimizer.step()
            # Weight by batch size so a smaller final batch is not over-counted.
            running_loss += batch_loss.item() * feats.size(0)
        print(f"[Choice] epoch {epoch_idx+1}: loss={running_loss/len(loader.dataset):.4f}")


# -------------------------------
# Tiny smoke test (random data)
# -------------------------------
# Smoke test: trains both heads for two epochs on purely random data to check
# that shapes, masking and the training loops fit together. It is not meant to
# learn anything; the statements below consume the global torch RNG in order,
# so reordering them changes the printed losses.
if __name__ == "__main__":
    torch.manual_seed(0)

    # 3A) Success head on random data
    # Labels are fair coin flips drawn independently of Xs, so the BCE loss is
    # expected to stay close to ln(2) ~= 0.693.
    Xs = torch.randn(512, 10)
    ys = torch.bernoulli(torch.full((512,), 0.5))
    ds = SuccessDataset(Xs, ys)
    dl = DataLoader(ds, batch_size=64, shuffle=True)
    succ = SuccessPredictor(in_dim=10)
    train_success_head(succ, dl, epochs=2)

    # 3B) Choice model with masking
    # Episodes have a random number of options so that choice_collate has to
    # pad and mask; the chosen option is drawn uniformly at random.
    items = []
    for _ in range(256):
        K = torch.randint(low=2, high=6, size=(1,)).item()  # 2..5 options
        option_feats = torch.randn(K, 8)
        p_succ_opt = torch.sigmoid(torch.randn(K))  # squashed into (0, 1)
        extra_costs = torch.randn(K, 2)
        chosen = torch.randint(low=0, high=K, size=(1,)).item()  # uniform in 0..K-1
        items.append({
            "option_features": option_feats,
            "p_succ_opt": p_succ_opt,
            "extra_costs": extra_costs,
            "option_mask": torch.ones(K, dtype=torch.bool),  # every option valid
            "chosen_index": torch.tensor(chosen),
        })
    dc = ChoiceDataset(items)
    dlc = DataLoader(dc, batch_size=32, shuffle=True, collate_fn=choice_collate)
    # Dimensions must match the synthetic items: 8 option features, 2 extra costs.
    chooser = ChoiceScorer(option_dim=8, extra_cost_dim=2, use_psucc=True)
    train_choice_model(chooser, dlc, epochs=2)

    print("Smoke test complete.")


[Success] epoch 1: loss=0.6924
[Success] epoch 2: loss=0.6911
[Choice] epoch 1: loss=1.1762
[Choice] epoch 2: loss=1.1666
Smoke test complete.
