In [13]:
# foraging_sim.py
# Explorer → Forager on a shared DataFrame; no 'unknown' anywhere.
# Explorer reveals mines via per-mine 'revealed' flags; auto-labels when fully revealed / empty.
# Forager softmax decisions + animations that reflect live, shared env.

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter
from matplotlib import colors as mcolors
from dataclasses import dataclass
from typing import Optional, Dict, Tuple, List
import matplotlib.patches as mpatches

# ------------------- Color map for environment richness -------------------
# Matplotlib named colour for each richness label; insertion order (poor, neutral, rich)
# is also the order the legend entries are drawn in.
RICHNESS_COLORS = dict(
    poor="sandybrown",
    neutral="lightgreen",
    rich="gold",
)

# ------------------- Environment initializer (TRUE has no 'unknown') -------------------
def init_gridworld(
    size: int = 3,
    seed: Optional[int] = None,
    # TRUE overall richness distribution
    p_overall: List[float] = (0.30, 0.45, 0.25),  # poor, neutral, rich
    # P(#mines = 0..3)
    p_mines: List[float] = (0.25, 0.40, 0.25, 0.10),
    # Optional preset visible map: dict[(row,col)->label] to override labels (e.g., a known map)
    preset_visible_overall: Optional[Dict[Tuple[int,int], str]] = None,
) -> pd.DataFrame:
    """
    Build an N×N foraging grid as a flat DataFrame (one row per tile, row-major order).

    Each tile carries:
      - TRUE hidden fields: overall ∈ {poor, neutral, rich}; three mine slots whose
        category ∈ {poor, neutral, rich} or None, plus per-slot reward_prob and digs_allowed.
      - VISIBLE overall label: taken from `preset_visible_overall` where given, else TRUE.
      - VISIBLE mine fields: all None. 'Mine i revealed' is False exactly where a TRUE
        mine exists, so the explorer has something to uncover.
    The centre tile is the base: always TRUE 'poor' with no mines.

    Args:
        size: grid side length N (must be >= 1).
        seed: seed for numpy's Generator; None gives a non-reproducible grid.
        p_overall: weights for TRUE overall richness (poor, neutral, rich); normalised.
        p_mines: weights for the number of mines 0..3; normalised. 'rich' tiles are
            conditioned on >= 2 mines, 'poor' tiles on <= 2 mines.
        preset_visible_overall: optional {(row, col): label} overriding the visible label.

    Returns:
        DataFrame with Location/Row/Col/is_center, the TRUE* columns, the visible mine
        columns, the 'Mine i revealed' flags and the runtime columns.

    Raises:
        ValueError: if size < 1, a probability vector has the wrong length / negative,
            non-finite or all-zero weights, or a preset entry has a label outside
            {poor, neutral, rich} or coordinates off the grid.
    """
    valid_labels = ("poor", "neutral", "rich")

    def normalized(weights, length: int, name: str) -> np.ndarray:
        # Validate a weight vector and rescale it to sum to 1.
        arr = np.asarray(weights, dtype=float)
        if (arr.shape != (length,) or not np.all(np.isfinite(arr))
                or np.any(arr < 0) or arr.sum() <= 0):
            raise ValueError(
                f"{name} must be {length} finite, non-negative weights with a positive sum; "
                f"got {weights!r}"
            )
        return arr / arr.sum()

    if size < 1:
        raise ValueError(f"size must be >= 1; got {size!r}")
    p_overall = normalized(p_overall, 3, "p_overall")
    p_mines = normalized(p_mines, 4, "p_mines")

    # Validate the preset up front: the simulation promises no 'unknown' labels, and an
    # off-grid coordinate would otherwise be silently ignored.
    preset = dict(preset_visible_overall or {})
    for (rr, cc), lab in preset.items():
        if lab not in valid_labels:
            raise ValueError(f"preset label for {(rr, cc)} must be one of {valid_labels}; got {lab!r}")
        if not (0 <= rr < size and 0 <= cc < size):
            raise ValueError(f"preset coordinate {(rr, cc)} is outside the {size}x{size} grid")

    rng = np.random.default_rng(seed)
    true_overall_labels = np.array(valid_labels)

    mine_cat_given_overall = {
        "poor":    {"poor": 0.70, "neutral": 0.25, "rich": 0.05},
        "neutral": {"poor": 0.25, "neutral": 0.50, "rich": 0.25},
        "rich":    {"poor": 0.10, "neutral": 0.30, "rich": 0.60},
    }
    reward_prob_map = {"poor": 0.20, "neutral": 0.50, "rich": 0.80}

    def sample_k(ov: str) -> int:
        # Number of mines on a tile, conditioned on its TRUE overall label.
        base = p_mines.copy()
        if ov == "rich":
            base[:2] = 0.0   # at least 2 mines
        elif ov == "poor":
            base[3] = 0.0   # cannot have 3 mines
        s = base.sum()
        if s == 0:
            return 2 if ov == "rich" else (1 if ov == "poor" else 0)
        base /= s
        return int(rng.choice([0, 1, 2, 3], p=base))

    def sample_mine_cat(ov: str, k: int) -> List[Optional[str]]:
        # Categories for the k occupied slots, padded with None to exactly 3 slots.
        if k == 0:
            return [None, None, None]
        probs_map = mine_cat_given_overall[ov]
        cats = np.array(["poor", "neutral", "rich"])
        probs = np.array([probs_map["poor"], probs_map["neutral"], probs_map["rich"]], dtype=float)
        probs = probs / probs.sum()
        sampled = rng.choice(cats, size=k, p=probs).tolist()
        return (sampled + [None] * (3 - k))[:3]

    digs_range = {"poor": (1, 2), "neutral": (2, 4), "rich": (3, 6)}
    def digs_allowed_for(cat: Optional[str]) -> Optional[int]:
        # Inclusive integer draw from the category's digs range; None for an empty slot.
        if cat is None: return None
        low, high = digs_range[cat]
        return int(rng.integers(low, high + 1))

    N = size
    center = (N // 2, N // 2)
    rows = []

    for r in range(N):
        for c in range(N):
            is_center = (r, c) == center
            if is_center:
                overall_true = "poor"   # base is empty poor
                t1 = t2 = t3 = None
                rp1 = rp2 = rp3 = None
                d1 = d2 = d3 = None
                env_empty_true = True
            else:
                overall_true = rng.choice(true_overall_labels, p=p_overall).item()
                k = sample_k(overall_true)
                t1, t2, t3 = sample_mine_cat(overall_true, k)
                rp1 = reward_prob_map[t1] if t1 else None
                rp2 = reward_prob_map[t2] if t2 else None
                rp3 = reward_prob_map[t3] if t3 else None
                d1 = digs_allowed_for(t1)
                d2 = digs_allowed_for(t2)
                d3 = digs_allowed_for(t3)
                env_empty_true = (k == 0)

            rows.append({
                "Location": f"{r}:{c}", "Row": r, "Col": c, "is_center": is_center,

                # TRUE (no 'unknown' anywhere)
                "TRUE Number 1 mines": t1,
                "TRUE Number 2 mines": t2,
                "TRUE Number 3 mines": t3,
                "TRUE Mine 1 reward_prob": rp1,
                "TRUE Mine 2 reward_prob": rp2,
                "TRUE Mine 3 reward_prob": rp3,
                "TRUE Mine 1 digs_allowed": d1,
                "TRUE Mine 2 digs_allowed": d2,
                "TRUE Mine 3 digs_allowed": d3,
                "TRUE env_empty": env_empty_true,
                "TRUE Overall Richness": overall_true,

                # VISIBLE initial mines are hidden (revealed flags False where a TRUE mine exists)
                "Number 1 mines": None,
                "Number 2 mines": None,
                "Number 3 mines": None,
                "Mine 1 reward_prob": None,
                "Mine 2 reward_prob": None,
                "Mine 3 reward_prob": None,
                "Mine 1 digs_allowed": None,
                "Mine 2 digs_allowed": None,
                "Mine 3 digs_allowed": None,

                "Mine 1 revealed": (t1 is None),
                "Mine 2 revealed": (t2 is None),
                "Mine 3 revealed": (t3 is None),

                # Runtime
                "Overall Richness of this environment": None,  # set below
                "env_empty": env_empty_true,
                "visited_by_explorer": False,
            })

    env = pd.DataFrame(rows)

    # Visible overall richness starts as TRUE, then any (already validated) preset overrides it.
    env["Overall Richness of this environment"] = env["TRUE Overall Richness"]
    for (rr, cc), lab in preset.items():
        env.loc[(env["Row"] == rr) & (env["Col"] == cc),
                "Overall Richness of this environment"] = lab

    return env

In [14]:
# ------------------- Runtime helpers (operate IN-PLACE on the shared env) -------------------
def ensure_runtime_columns(env: pd.DataFrame) -> pd.DataFrame:
    """
    Add any missing runtime columns to `env` IN PLACE and return the same object.

    - 'Mine i digs_remaining': created as a copy of 'Mine i digs_allowed' when absent.
      When present, NA entries are back-filled from 'Mine i digs_allowed', but only for
      slots whose visible 'Number i mines' category is still set. MVTAgent._dig clears the
      category and digs_remaining (keeping digs_allowed) when a mine is dug out, so an
      unguarded back-fill would resurrect depleted mines whenever this runs again on the
      shared env (e.g. when a second forager is created).
    - 'Mine i revealed': if missing, inferred as "visible category set OR no TRUE mine".
    - 'visited_by_explorer' / 'env_empty': default False when missing.
    """
    for i in (1, 2, 3):
        digs_col = f"Mine {i} digs_allowed"
        rem_col = f"Mine {i} digs_remaining"
        cat_col = f"Number {i} mines"
        if rem_col not in env.columns:
            env[rem_col] = env[digs_col]
            continue
        mask = env[rem_col].isna() & env[digs_col].notna()
        if cat_col in env.columns:
            # Only refill live (still-categorised) slots; a cleared category means depleted.
            mask &= env[cat_col].notna()
        env.loc[mask, rem_col] = env.loc[mask, digs_col]

    # Ensure revealed flags exist (if missing, infer from visible mine fields)
    for i in (1, 2, 3):
        rev_col = f"Mine {i} revealed"
        if rev_col not in env.columns:
            # revealed if visible category is not None OR TRUE mine absent
            env[rev_col] = env[f"Number {i} mines"].notna() | env[f"TRUE Number {i} mines"].isna()

    if "visited_by_explorer" not in env.columns:
        env["visited_by_explorer"] = False
    if "env_empty" not in env.columns:
        env["env_empty"] = False
    return env  # same object

def _indexify(env: pd.DataFrame) -> pd.DataFrame:
    """Ensure we have a MultiIndex on (Row, Col) while KEEPING columns."""
    if list(env.index.names) != ["Row", "Col"]:
        env.set_index(["Row", "Col"], inplace=True, drop=False)
    return env

def _cell_total_remaining_digs_df(env_idx: pd.DataFrame, pos: Tuple[int, int]) -> int:
    s = 0
    for i in (1, 2, 3):
        rem = env_idx.loc[pos, f"Mine {i} digs_remaining"]
        if pd.notna(rem): s += int(float(rem))
    return s

In [None]:
# ------------------- Explorer -------------------
@dataclass
class ExplorerConfig:
    """Tunable parameters for ExplorerAgent (read through `agent.cfg`)."""
    init_resource: int = 100  # starting resource budget
    move_cost: int = 4  # resource paid for each move to a neighbouring tile
    scan_cost: int = 2  # resource paid for each scan (reveals one hidden mine slot)
    gamma: float = 1.0  # common multiplier on both the stay and leave values in _values()
    beta_local: float = 0.6  # weight on the stay (scan here) value
    beta_global: float = 0.4  # weight on the leave (move) value
    avoid_base: bool = True  # once the base has been left, exclude it from candidate moves
    no_backtrack: bool = True  # never revisit an already visited block

class ExplorerAgent:
    """
    Scouting agent that walks the grid ahead of the forager and reveals hidden mines.

    - Single scan per cell (reveals ONE hidden mine slot).
    - No backtracking when cfg.no_backtrack is set (cannot enter a previously visited cell).
    - If neither move nor scan is possible, STOP (do not drain resources).
    - Writes directly into the SHARED env (no copy is made).
    - Auto-refresh labeling:
        * After first visit, if the tile truly has no mines -> label 'poor'.
        * When all TRUE mines in the tile are revealed -> set label to TRUE overall.
    - Transfer map to Forager: export_env_for_forager() returns the SAME env reference.
    """
    def __init__(self, env_df: pd.DataFrame, cfg: "ExplorerConfig", seed: Optional[int] = None):
        self.cfg = cfg
        self.rng = np.random.default_rng(seed)

        # Both helpers mutate and return env_df itself, so self.env IS the caller's frame.
        self.env = ensure_runtime_columns(env_df)
        _indexify(self.env)

        self.resource = cfg.init_resource
        size = int(self.env.index.get_level_values(0).max()) + 1
        self.base = (size // 2, size // 2)
        self.pos = self.base
        self.left_base_once = False
        self.t = 0
        self.total_moves = 0
        self.log: List[Dict] = []

        # enforce single scan per cell
        self.scanned_once: set[Tuple[int,int]] = set()

        # open base
        self.env.loc[self.pos, 'visited_by_explorer'] = True
        self._autolabel_if_ready(self.pos)  # may set 'poor' if base has no mines

        # Animation buffers (one entry appended per _snapshot call)
        self.frames_pos: List[Tuple[int, int]] = []
        self.frames_resource: List[float] = []
        self.frames_action: List[str] = []
        self.frames_unrevealed_mask: List[np.ndarray] = []
        self.frames_reward_dummy: List[float] = []
        self.frames_decision: List[str] = []
        self._snapshot("start", "starting")

    # ---- utilities ----
    def _has_true_mine(self, pos: Tuple[int, int], i: int) -> bool:
        """True when the hidden TRUE layout has a mine in slot i at pos (None/NaN = no mine)."""
        return bool(pd.notna(self.env.loc[pos, f"TRUE Number {i} mines"]))

    def _slot_hidden(self, pos: Tuple[int, int], i: int) -> bool:
        """True when slot i at pos holds a TRUE mine whose 'Mine i revealed' flag is still False."""
        return self._has_true_mine(pos, i) and not bool(self.env.loc[pos, f"Mine {i} revealed"])

    def _neighbors(self):
        """4-connected in-grid neighbours of self.pos, filtered by avoid_base / no_backtrack."""
        r, c = self.pos
        neigh = [(r-1, c), (r+1, c), (r, c-1), (r, c+1)]
        valid = [(rr, cc) for rr, cc in neigh if (rr, cc) in self.env.index]
        if self.cfg.avoid_base and self.left_base_once:
            valid = [p for p in valid if p != self.base]
        if self.cfg.no_backtrack:
            valid = [p for p in valid if not bool(self.env.loc[p, 'visited_by_explorer'])]
        return valid

    def _hidden_mines_here(self) -> List[int]:
        """Slot ids (1..3) at the current position that still hide a TRUE mine."""
        return [i for i in (1, 2, 3) if self._slot_hidden(self.pos, i)]

    def _global_hidden_count(self) -> int:
        """Number of still-hidden TRUE mine slots over the whole grid."""
        return sum(self._slot_hidden(pos, i) for pos in self.env.index for i in (1, 2, 3))

    def _fully_revealed(self, pos: Tuple[int,int]) -> bool:
        """True when every TRUE mine slot at pos has revealed=True (vacuously True if none)."""
        return not any(self._slot_hidden(pos, i) for i in (1, 2, 3))

    def _autolabel_if_ready(self, pos: Tuple[int, int]) -> None:
        """
        Refresh 'Overall Richness of this environment' if:
          - explorer has visited this pos, AND
          - (a) tile has no TRUE mines -> label 'poor'
          - (b) OR all TRUE mines at the tile are revealed -> set to TRUE overall.
        """
        if not bool(self.env.loc[pos, "visited_by_explorer"]):
            return

        has_true_mine = any(self._has_true_mine(pos, i) for i in (1, 2, 3))
        if not has_true_mine:
            self.env.loc[pos, "Overall Richness of this environment"] = "poor"
            return

        if self._fully_revealed(pos):
            self.env.loc[pos, "Overall Richness of this environment"] = self.env.loc[pos, "TRUE Overall Richness"]

    def _snapshot(self, action: str, decision: str = ""):
        """
        Append one animation frame: unrevealed-mine mask, position, resource, labels.

        `decision` defaults to "" so single-argument calls are valid.
        """
        idx = self.env
        nrows = int(idx.index.get_level_values(0).max()) + 1
        ncols = int(idx.index.get_level_values(1).max()) + 1

        # Cell is 'unrevealed' if any TRUE mine there is still hidden.
        mask = np.zeros((nrows, ncols), dtype=bool)
        for pos in idx.index:
            r, c = pos
            mask[r, c] = any(self._slot_hidden(pos, i) for i in (1, 2, 3))

        self.frames_unrevealed_mask.append(mask)
        self.frames_pos.append(tuple(self.pos))
        self.frames_resource.append(float(self.resource))
        self.frames_action.append(action)
        self.frames_decision.append(decision)
        self.frames_reward_dummy.append(0.0)

    def _log(self, **kw):
        """Append a log record, filling in step/position/resource/move and hidden counts."""
        kw.setdefault('step', self.t)
        kw.setdefault('row', self.pos[0])
        kw.setdefault('col', self.pos[1])
        kw.setdefault('resource', self.resource)
        kw.setdefault('total_moves', self.total_moves)
        kw.setdefault('hidden_global', self._global_hidden_count())
        self.log.append(kw)

    def _reveal_one_mine_here(self) -> bool:
        """Scan once here: copy one random hidden TRUE slot into the visible columns."""
        # enforce single-scan-per-cell
        if self.pos in self.scanned_once: return False
        candidates = self._hidden_mines_here()
        if not candidates or self.resource < self.cfg.scan_cost: return False

        mine_id = int(self.rng.choice(candidates))
        self.resource -= self.cfg.scan_cost

        # reveal from TRUE_* fields into visible
        tcat = self.env.loc[self.pos, f"TRUE Number {mine_id} mines"]
        trp  = self.env.loc[self.pos, f"TRUE Mine {mine_id} reward_prob"]
        tda  = self.env.loc[self.pos, f"TRUE Mine {mine_id} digs_allowed"]

        self.env.loc[self.pos, f"Number {mine_id} mines"] = tcat
        self.env.loc[self.pos, f"Mine {mine_id} reward_prob"] = trp
        self.env.loc[self.pos, f"Mine {mine_id} digs_allowed"] = tda
        self.env.loc[self.pos, f"Mine {mine_id} digs_remaining"] = tda
        self.env.loc[self.pos, f"Mine {mine_id} revealed"] = True

        self.scanned_once.add(self.pos)

        # After scanning, try to auto-label if now fully revealed
        self._autolabel_if_ready(self.pos)

        self._log(action='scan', mine=mine_id, revealed_cat=tcat, digs_allowed=tda, reward_prob=trp)
        self._snapshot("scan","decision")
        return True

    def _move(self) -> bool:
        """Pay move_cost and step to a neighbour sampled in proportion to eps + #hidden mines."""
        neigh = self._neighbors()
        if not neigh or self.resource < self.cfg.move_cost: return False

        # prefer neighbors with more hidden mines (simple heuristic)
        eps = 1e-3
        weights = [eps + sum(self._slot_hidden(n, i) for i in (1, 2, 3)) for n in neigh]
        probs = np.array(weights, dtype=float); probs /= probs.sum()

        choice_idx = int(self.rng.choice(len(neigh), p=probs))
        choice = neigh[choice_idx]

        # pay & move
        self.resource -= self.cfg.move_cost
        if (self.pos == self.base) and (choice != self.base): self.left_base_once = True
        self.pos = choice
        self.total_moves += 1

        # OPEN tile for forager and attempt autolabel (e.g., visited empty -> 'poor')
        self.env.loc[self.pos, 'visited_by_explorer'] = True
        self._autolabel_if_ready(self.pos)

        self._log(action='move', move_probs=list(np.round(probs, 4)))
        self._snapshot("move","move")
        return True

    def _values(self) -> Tuple[float, float]:
        """Return (stay, leave) heuristic values compared in step()."""
        R_local = len(self._hidden_mines_here())
        R_global = self._global_hidden_count()
        tmove = max(self.total_moves, 1)
        stay = self.cfg.beta_local * (self.cfg.move_cost / max(self.cfg.scan_cost, 1e-9)) * self.cfg.gamma * (R_local / tmove)
        leave = self.cfg.beta_global * (max(self.cfg.scan_cost, 1e-9) / max(self.cfg.move_cost, 1e-9)) * self.cfg.gamma * (tmove / max(R_global, 1))
        return float(stay), float(leave)

    def step(self):
        """One decision: move if leaving is preferred or scanning is impossible, else scan."""
        # If neither action is affordable, stop this phase
        if (self.resource < self.cfg.scan_cost) and (self.resource < self.cfg.move_cost):
            self._log(action='halt', decision='insufficient_resources'); self._snapshot("halt","insufficient_resources")
            self.t += 1; return

        v_stay, v_leave = self._values()
        can_scan_here = (self.resource >= self.cfg.scan_cost) and (self.pos not in self.scanned_once) and (len(self._hidden_mines_here()) > 0)

        did_action = False
        # Prefer moving if leave "value" higher or scanning impossible
        if (not can_scan_here) or (v_leave > v_stay):
            did_action = self._move()

        # If couldn't move, try single scan here (once per cell)
        if (not did_action) and can_scan_here:
            did_action = self._reveal_one_mine_here()

        # If still no action possible -> STOP PHASE
        if not did_action:
            self._log(action='halt', decision='no_actions_left'); self._snapshot("halt","no_actions_left")

        self.t += 1

    # --- Transfer the map for forager decisions (same shared DataFrame) ---
    def export_env_for_forager(self) -> pd.DataFrame:
        return self.env

    def run(self, max_steps: int = 300) -> pd.DataFrame:
        """Step until a halt is logged or max_steps is reached; return the log as a DataFrame."""
        for _ in range(max_steps):
            if (self.resource < self.cfg.scan_cost) and (self.resource < self.cfg.move_cost):
                self._log(action='halt', decision='insufficient_resources'); self._snapshot("halt","insufficient_resources")
                break
            self.step()
            if self.log and self.log[-1].get('action') == 'halt': break
        return pd.DataFrame(self.log)

In [None]:
# ------------------- Forager (softmax stay/leave + softmax neighbor move) -------------------
@dataclass
class MVTConfig:
    """
    Tunable parameters for MVTAgent (read through `agent.cfg`).

    The three dict fields default to None and are filled with fresh per-instance
    dicts in __post_init__, so instances never share mutable defaults.

    Raises:
        ValueError: if move_cost or dig_cost is not positive (both are used as
            divisors in MVTAgent.value_stay / value_leave).
    """
    init_resource: int = 100
    move_cost: int = 10
    dig_cost: int = 5
    gamma: float = 1.0
    beta_trust: float = 0.7             # weight on the visible label vs perceived value in moves
    reward_amount: float = 1.0          # reward added per successful dig
    avoid_base: bool = True
    mine_choice_value: Optional[Dict[str, float]] = None   # per-category mine weight
    env_factor: Optional[Dict[str, float]] = None          # per-label multiplier on stay value
    label_value: Optional[Dict[str, float]] = None         # per-label score for neighbour moves
    stay_leave_temp: float = 0.7        # softmax temperature for stay vs leave
    move_temp: float = 0.7              # softmax temperature for neighbor move

    def __post_init__(self):
        if self.move_cost <= 0 or self.dig_cost <= 0:
            raise ValueError(
                f"move_cost and dig_cost must be positive; got move_cost={self.move_cost}, "
                f"dig_cost={self.dig_cost}"
            )
        if self.mine_choice_value is None:
            self.mine_choice_value = {'rich': 1.0, 'neutral': 0.6, 'poor': 0.2}
        if self.env_factor is None:
            self.env_factor = {'rich': 0.8, 'neutral': 0.5, 'poor': 0.2}
        if self.label_value is None:
            self.label_value = {'rich': 0.8, 'neutral': 0.5, 'poor': 0.2}

def _softmax(x: np.ndarray, temp: float = 1.0) -> np.ndarray:
    x = np.array(x, dtype=float) / max(temp, 1e-9)
    x = x - np.max(x)
    e = np.exp(x)
    return e / np.maximum(e.sum(), 1e-12)

class MVTAgent:
    """
    Forager that digs the mines revealed by the explorer on the shared env.

    - Movement restricted to env['visited_by_explorer'] == True (opened tiles).
    - Mines are diggable iff digs_remaining > 0 and the visible category is set.
    - Stay vs leave chosen via SOFTMAX over [value_stay, value_leave].
    - Neighbor move chosen via SOFTMAX over logits = beta*label + (1-beta)*perceived.
    - Operates on the SAME shared env (from Explorer.export_env_for_forager()).
    - Animation: every snapshot appends one entry to EVERY frames_* buffer (including
      frames_decision), so the buffers stay the same length.
    """
    def __init__(self, env_df: pd.DataFrame, cfg: "MVTConfig", seed: Optional[int] = None):
        self.cfg = cfg
        self.rng = np.random.default_rng(seed)

        # ensure_runtime_columns/_indexify mutate and return env_df itself: self.env is shared.
        self.env = ensure_runtime_columns(env_df)
        _indexify(self.env)

        self.resource = cfg.init_resource
        size = int(self.env.index.get_level_values(0).max()) + 1
        self.base = (size // 2, size // 2)
        self.pos = self.base
        self.left_base_once = False

        self.total_reward = 0.0
        self.total_digs = 0
        self.t = 0
        self.log: List[Dict] = []

        self.perceived_value: Dict[Tuple[int, int], float] = {}
        self.Nrows = int(self.env.index.get_level_values(0).max()) + 1
        self.Ncols = int(self.env.index.get_level_values(1).max()) + 1

        # Animation buffers (one entry per snapshot)
        self.frames_intensity: List[np.ndarray] = []
        self.frames_has_mines: List[np.ndarray] = []
        self.frames_pos: List[Tuple[int, int]] = []
        self.frames_reward: List[float] = []
        self.frames_resource: List[float] = []
        self.frames_action: List[str] = []
        self.frames_decision: List[str] = []

        self.update_current_richness()
        self._log_enter()
        self._snapshot_grid_state(action_label="start",decision_label="starting")

    # ---------- helpers ----------
    def _cell_overall(self, pos=None):
        """Visible overall richness label of pos (default: current position)."""
        if pos is None: pos = self.pos
        return self.env.loc[pos, 'Overall Richness of this environment']

    def _available_mines(self, pos=None) -> List[int]:
        """Slot ids at pos with a visible category and digs_remaining > 0."""
        if pos is None: pos = self.pos
        mines = []
        for i in (1, 2, 3):
            rem = self.env.loc[pos, f'Mine {i} digs_remaining']
            cat = self.env.loc[pos, f'Number {i} mines']
            if pd.notna(rem) and float(rem) > 0 and pd.notna(cat):
                mines.append(i)
        return mines

    def _mine_choice_weights(self, mines: List[int]) -> np.ndarray:
        """Sampling weights for `mines` at the current position (category value, floored at 1e-9)."""
        ov = self._cell_overall()
        ov_fac = self.cfg.env_factor.get(ov, 0.5)
        vals = []
        for i in mines:
            cat = self.env.loc[self.pos, f'Number {i} mines']
            v = self.cfg.mine_choice_value.get(cat, ov_fac)
            vals.append(max(v, 1e-9))
        return np.array(vals, dtype=float)

    def _neighbors(self):
        """4-connected in-grid neighbours opened by the explorer (base excluded once left)."""
        r, c = self.pos
        neigh = [(r-1, c), (r+1, c), (r, c-1), (r, c+1)]
        valid = [(rr, cc) for rr, cc in neigh if (rr, cc) in self.env.index]
        if self.cfg.avoid_base and self.left_base_once:
            valid = [p for p in valid if p != self.base]
        # only tiles opened by explorer
        valid = [p for p in valid if bool(self.env.loc[p, 'visited_by_explorer'])]
        return valid

    # ---------- perceived/current richness ----------
    def _cell_perceived_value(self, pos: Tuple[int, int]) -> float:
        """Mean over available mines of category value scaled by remaining digs (0 if none)."""
        mines = self._available_mines(pos)
        if not mines: return 0.0
        vals = []
        for i in mines:
            cat = self.env.loc[pos, f'Number {i} mines']
            rem = self.env.loc[pos, f'Mine {i} digs_remaining']
            base_v = self.cfg.mine_choice_value.get(cat, 0.5)
            # 6 = upper bound of the 'rich' digs range sampled by init_gridworld
            norm_rem = min(int(rem), 6) / 6.0
            vals.append(base_v * (0.5 + 0.5 * norm_rem))
        return float(np.mean(vals))

    def update_current_richness(self):
        """Recompute perceived_value for every tile."""
        self.perceived_value = {pos: self._cell_perceived_value(pos) for pos in self.env.index}

    # ---------- values ----------
    def _avg_reward(self) -> float:
        return self.total_reward / self.total_digs if self.total_digs > 0 else 0.0

    def value_stay(self) -> float:
        avg_rew = self._avg_reward()
        env_fac = self.cfg.env_factor.get(self._cell_overall(), 0.5)
        return avg_rew * env_fac * ((self.cfg.move_cost / self.cfg.dig_cost) * self.cfg.gamma)

    def value_leave(self) -> float:
        avg_rew = self._avg_reward()
        return avg_rew * self.cfg.dig_cost * ((self.cfg.dig_cost / self.cfg.move_cost) * self.cfg.gamma)

    # ---------- logging ----------
    def _log(self, **kw):
        kw.setdefault('step', self.t)
        kw.setdefault('row', self.pos[0]); kw.setdefault('col', self.pos[1])
        kw.setdefault('resource', self.resource)
        kw.setdefault('reward_total', self.total_reward)
        kw.setdefault('total_digs', self.total_digs)
        kw.setdefault('overall', self._cell_overall())
        kw.setdefault('perceived_here', self.perceived_value.get(self.pos, 0.0))
        self.log.append(kw)

    def _log_enter(self):
        self._log(action='enter', decision=None, v_stay=None, v_leave=None, note='entered_cell')

    def _update_cell_empty_flag(self, pos=None):
        if pos is None: pos = self.pos
        has_any = len(self._available_mines(pos)) > 0
        self.env.loc[pos, 'env_empty'] = (not has_any)

    # ---------- depletion snapshots for animation ----------
    def _snapshot_grid_state(self, action_label: str = "", decision_label: str = ""):
        """Append one frame to every frames_* buffer, including frames_decision."""
        MAX_DIGS_PER_CELL = 18  # 3 slots x at most 6 digs each
        inten = np.zeros((self.Nrows, self.Ncols), dtype=float)
        mask  = np.zeros((self.Nrows, self.Ncols), dtype=bool)
        for (r, c) in self.env.index:
            total = _cell_total_remaining_digs_df(self.env, (r, c))
            inten[r, c] = min(total / MAX_DIGS_PER_CELL, 1.0)
            mask[r, c] = (total > 0)

        self.frames_intensity.append(inten)
        self.frames_has_mines.append(mask)
        self.frames_pos.append(tuple(self.pos))
        self.frames_reward.append(float(self.total_reward))
        self.frames_resource.append(float(self.resource))
        self.frames_action.append(action_label if action_label else "")
        # Previously dropped: keeps frames_decision aligned with the other buffers.
        self.frames_decision.append(decision_label if decision_label else "")

    def _post_action_snapshot(self, action_label: str, decision_label: str):
        self._snapshot_grid_state(action_label=action_label, decision_label=decision_label)

    # ---------- actions ----------
    def _auto_leave_if_empty(self) -> bool:
        """If no mine is diggable here, move away (or halt when stuck); True if it acted."""
        if len(self._available_mines()) > 0: return False
        if self.resource < self.cfg.move_cost or len(self._neighbors()) == 0:
            self._log(action='halt', decision='stuck_no_mines', v_stay=None, v_leave=None, note='no_mines_no_move')
            self._post_action_snapshot("halt","stuck_no_mines")
            return True
        self._move(auto_leave=True)
        return True

    def _dig(self):
        """Pay dig_cost, dig one weighted-random available mine and update depletion state."""
        if self.resource < self.cfg.dig_cost:
            self._log(action='halt', decision='no_resource_to_dig',
                      v_stay=self.value_stay(), v_leave=self.value_leave())
            self._post_action_snapshot("halt","no_resource_to_dig")
            return False

        mines = self._available_mines()
        if not mines: return False

        w = self._mine_choice_weights(mines)
        mine_id = int(self.rng.choice(mines, p=w / w.sum()))
        cat = self.env.loc[self.pos, f'Number {mine_id} mines']
        p = self.env.loc[self.pos, f'Mine {mine_id} reward_prob']

        # pay dig cost
        self.resource -= self.cfg.dig_cost

        success = (p is not None) and (self.rng.random() < float(p))
        if success: self.total_reward += self.cfg.reward_amount

        rem_col = f"Mine {mine_id} digs_remaining"
        new_rem = int(self.env.loc[self.pos, rem_col]) - 1
        self.env.loc[self.pos, rem_col] = new_rem

        # A dug-out mine loses its visible category/reward/remaining (digs_allowed is kept).
        depleted = (new_rem <= 0)
        if depleted:
            self.env.loc[self.pos, f'Number {mine_id} mines'] = None
            self.env.loc[self.pos, f'Mine {mine_id} reward_prob'] = None
            self.env.loc[self.pos, rem_col] = None

        self.total_digs += 1
        self._update_cell_empty_flag()
        self.perceived_value[self.pos] = self._cell_perceived_value(self.pos)

        self._log(action='dig', decision='stay',
                  v_stay=self.value_stay(), v_leave=self.value_leave(),
                  mine=mine_id, mine_cat=cat, success=bool(success),
                  reward_gained=(self.cfg.reward_amount if success else 0.0),
                  depleted=bool(depleted),
                  remaining_after=(None if depleted else new_rem))
        self._post_action_snapshot("dig","stay")
        return True

    def _move(self, auto_leave: bool = False):
        """Pay move_cost and move to a softmax-sampled opened neighbour."""
        self.update_current_richness()

        if self.resource < self.cfg.move_cost:
            self._log(action='halt', decision='no_resource_to_move',
                    v_stay=self.value_stay(), v_leave=self.value_leave())
            self._post_action_snapshot("halt","no_resource_to_move")
            return False

        neigh = self._neighbors()
        if not neigh:
            self._log(action='halt', decision='no_neighbors',
                    v_stay=self.value_stay(), v_leave=self.value_leave())
            self._post_action_snapshot("halt","no_neighbors")
            return False

        beta = self.cfg.beta_trust
        labels = [self.env.loc[n, 'Overall Richness of this environment'] for n in neigh]
        label_scores = np.array([self.cfg.label_value.get(l, 0.5) for l in labels], dtype=float)
        perceived_scores = np.array([self.perceived_value.get(n, 0.0) for n in neigh], dtype=float)

        # SOFTMAX neighbor selection (adds randomness)
        logits = beta * label_scores + (1.0 - beta) * perceived_scores
        probs = _softmax(logits, temp=self.cfg.move_temp)

        choice_idx = int(self.rng.choice(len(neigh), p=probs))
        choice = neigh[choice_idx]
        choice_label = labels[choice_idx]

        # pay and move
        self.resource -= self.cfg.move_cost
        if (self.pos == self.base) and (choice != self.base): self.left_base_once = True
        self.pos = choice

        decision = 'auto_leave' if auto_leave else 'leave'
        self._log(action='move', decision=decision,
                v_stay=self.value_stay(), v_leave=self.value_leave(),
                move_choice_label=choice_label, move_probs=list(np.round(probs, 4)),
                move_label_scores=list(np.round(label_scores, 3)),
                move_perceived_scores=list(np.round(perceived_scores, 3)))
        self._post_action_snapshot("move", decision)
        self._log_enter()
        return True

    # ---------- main loop (softmax stay vs leave) ----------
    def step(self):
        """One decision: halt, auto-leave an empty tile, or softmax-choose dig vs move."""
        can_dig = (self.resource >= self.cfg.dig_cost) and (len(self._available_mines()) > 0)
        can_move = (self.resource >= self.cfg.move_cost) and (len(self._neighbors()) > 0)

        if not can_dig and not can_move:
            self._log(action='halt', decision='insufficient_actions', v_stay=None, v_leave=None)
            self._post_action_snapshot("halt","insufficient_actions")
            self.t += 1; return

        # auto-leave if no mines here but movement is possible
        if len(self._available_mines()) == 0 and can_move:
            self._move(auto_leave=True)
            self.t += 1; return

        v_stay, v_leave = self.value_stay(), self.value_leave()
        logits = np.array([
            v_stay if can_dig else -1e9,
            v_leave if can_move else -1e9
        ], dtype=float)
        probs = _softmax(logits, temp=self.cfg.stay_leave_temp)
        choice = int(self.rng.choice([0,1], p=probs))  # 0=stay(dig), 1=leave(move)

        if choice == 0 and can_dig:
            self._dig()
        elif choice == 1 and can_move:
            self._move(auto_leave=False)
        else:
            # fallback
            if can_dig: self._dig()
            else:       self._move(auto_leave=False)

        self.t += 1

    def run(self, max_steps: int = 300) -> pd.DataFrame:
        """Step until a halt is logged or max_steps is reached; return the log as a DataFrame."""
        for _ in range(max_steps):
            if (self.resource < self.cfg.dig_cost) and (self.resource < self.cfg.move_cost):
                self._log(action='halt', decision='insufficient_resources', v_stay=None, v_leave=None)
                self._post_action_snapshot("halt","insufficient resources"); break
            self.step()
            if self.log and self.log[-1].get('action') == 'halt': break
        return pd.DataFrame(self.log)

In [None]:
# ------------------- Animation helpers -------------------
def _env_index_or_copy(env: pd.DataFrame) -> pd.DataFrame:
    if list(env.index.names) == ["Row", "Col"]:
        return env
    return env.set_index(["Row", "Col"], drop=False)

def animate_explorer(env: pd.DataFrame, agent: ExplorerAgent, outpath: str = "explore_animation.gif"):
    """
    Render the explorer's recorded trajectory to an animated GIF.

    Tiles are colored by TRUE richness (no 'unknown' blocks). In each frame, tiles whose
    snapshot in ``agent.frames_unrevealed_mask[frame]`` is truthy (they still contain
    HIDDEN mines) are dimmed to 60% brightness.

    Args:
        env: grid DataFrame with "Row", "Col" and "TRUE Overall Richness", either indexed
            by (Row, Col) or carrying Row/Col as columns.
        agent: an explorer that has already run. Its per-frame lists (frames_unrevealed_mask,
            frames_pos, frames_resource, frames_action, frames_decision) drive the animation.
        outpath: destination GIF path, written with PillowWriter at 6 fps.

    Raises:
        ValueError: if the agent recorded no frames (there is nothing to animate).
    """
    frames = len(agent.frames_unrevealed_mask)
    if frames == 0:
        # Without this check the failure surfaces as an opaque IndexError inside PillowWriter.
        raise ValueError("animate_explorer: agent recorded no frames; run the explorer first")
    pos_seq = agent.frames_pos
    resource_seq = agent.frames_resource
    action_seq = agent.frames_action
    decision_seq = agent.frames_decision

    env_idx = _env_index_or_copy(env)
    Nrows = int(env_idx.index.get_level_values(0).max()) + 1
    Ncols = int(env_idx.index.get_level_values(1).max()) + 1

    # TRUE richness does not change during playback, so resolve each tile's color once
    # instead of re-reading the DataFrame for every tile on every frame.
    base_colors = {
        (r, c): np.array(mcolors.to_rgb(RICHNESS_COLORS.get(cell["TRUE Overall Richness"], "white")))
        for (r, c), cell in env_idx.iterrows()
    }

    fig, ax = plt.subplots(figsize=(5, 5))
    try:
        legend_patches = [mpatches.Patch(color=col, label=lab) for lab, col in RICHNESS_COLORS.items()]
        ax.legend(handles=legend_patches, loc="upper right", title="Richness (TRUE for explorer)")

        tiles = {}
        for (r, c), base_rgb in base_colors.items():
            rect = plt.Rectangle((c - 0.5, r - 0.5), 1, 1, facecolor=base_rgb, edgecolor="black")
            ax.add_patch(rect)
            tiles[(r, c)] = rect

        # Axes/layout (y inverted so row 0 is at the top)
        ax.set_xlim(-0.5, Ncols - 0.5)
        ax.set_ylim(Nrows - 0.5, -0.5)
        ax.set_xticks(range(Ncols))
        ax.set_yticks(range(Nrows))
        ax.grid(True, linestyle=":", linewidth=0.5)
        ax.set_aspect("equal")
        ax.set_title("Explorer (TRUE map color; dim = has hidden mines)")

        # Agent marker & HUD
        agent_dot, = ax.plot([], [], "bo", markersize=10)
        hud_text = ax.text(0.02, 0.98, "", transform=ax.transAxes, va='top', ha='left')

        def init_anim():
            agent_dot.set_data([], [])
            hud_text.set_text("")
            return [agent_dot, hud_text] + list(tiles.values())

        def update(frame):
            frame = min(frame, frames - 1)  # clamp defensively to the last recorded frame

            # Base TRUE color, dimmed where this frame's snapshot still has hidden mine slots.
            mask = agent.frames_unrevealed_mask[frame]
            for (r, c), rect in tiles.items():
                base_rgb = base_colors[(r, c)]
                rect.set_facecolor(np.clip(base_rgb * 0.6, 0, 1) if mask[r, c] else base_rgb)

            rr, cc = pos_seq[frame]
            agent_dot.set_data([cc], [rr])
            hud_text.set_text(f"t={frame}  action={action_seq[frame]}\nresource={resource_seq[frame]:.2f}\ndecision={decision_seq[frame]}")
            return [agent_dot, hud_text] + list(tiles.values())

        anim = FuncAnimation(fig, update, init_func=init_anim, frames=frames, interval=150, blit=True, repeat=False)
        anim.save(outpath, writer=PillowWriter(fps=6))
    finally:
        # Always release the figure, even if saving fails, so repeated runs do not leak figures.
        plt.close(fig)

def animate_forager(env: pd.DataFrame, agent: MVTAgent, outpath: str = "forage_animation.gif"):
    """
    Render the forager's recorded trajectory to an animated GIF.

    Tile hue comes from the visible label ('Overall Richness of this environment') in
    ``agent.env``, i.e. the shared DataFrame that the explorer labeled. The GIF is rendered
    after ``agent.run()`` has finished, so every frame shows the labels as they stand at
    render time (the final state). It is not a per-frame history of the labels.

    Per-frame brightness comes from the agent's snapshots:
      - tiles with no mines left (``frames_has_mines[frame][r, c]`` falsy) are drawn at 20%;
      - otherwise brightness is 0.30 + 0.70 * ``frames_intensity[frame][r, c]``
        (remaining-digs intensity, expected in [0, 1]).

    Args:
        env: grid DataFrame with "Row" and "Col", either indexed by (Row, Col) or carrying
            them as columns; it defines the grid geometry.
        agent: a forager that has already run. Its per-frame lists (frames_intensity,
            frames_has_mines, frames_pos, frames_reward, frames_resource, frames_action,
            frames_decision) drive the animation.
        outpath: destination GIF path, written with PillowWriter at 6 fps.

    Raises:
        ValueError: if the agent recorded no frames (there is nothing to animate).
    """
    frames = len(agent.frames_intensity)
    if frames == 0:
        # Without this check the failure surfaces as an opaque IndexError inside PillowWriter.
        raise ValueError("animate_forager: agent recorded no frames; run the forager first")
    pos_seq = agent.frames_pos
    reward_seq = agent.frames_reward
    resource_seq = agent.frames_resource
    action_seq = agent.frames_action
    decision_seq = agent.frames_decision

    env_idx = _env_index_or_copy(env)
    Nrows = int(env_idx.index.get_level_values(0).max()) + 1
    Ncols = int(env_idx.index.get_level_values(1).max()) + 1

    # Visible labels cannot change while the GIF is being written, so resolve each tile's
    # hue once. Indexing agent.env by (Row, Col) also makes the lookup work when the agent
    # holds a non-indexed frame.
    vis_env = _env_index_or_copy(agent.env)
    visible_colors = {
        (r, c): np.array(mcolors.to_rgb(
            RICHNESS_COLORS.get(vis_env.loc[(r, c), "Overall Richness of this environment"], "white")
        ))
        for (r, c) in env_idx.index
    }

    fig, ax = plt.subplots(figsize=(5, 5))
    try:
        legend_patches = [mpatches.Patch(color=col, label=lab) for lab, col in RICHNESS_COLORS.items()]
        ax.legend(handles=legend_patches, loc="upper right", title="Richness (visible)")

        tiles = {}
        for (r, c), base_rgb in visible_colors.items():
            rect = plt.Rectangle((c - 0.5, r - 0.5), 1, 1, facecolor=base_rgb, edgecolor="black")
            ax.add_patch(rect)
            tiles[(r, c)] = rect

        # Axes/layout (y inverted so row 0 is at the top)
        ax.set_xlim(-0.5, Ncols - 0.5)
        ax.set_ylim(Nrows - 0.5, -0.5)
        ax.set_xticks(range(Ncols))
        ax.set_yticks(range(Nrows))
        ax.grid(True, linestyle=":", linewidth=0.5)
        ax.set_aspect("equal")
        ax.set_title("Forager (action & reward)")

        # Agent marker & HUD
        agent_dot, = ax.plot([], [], "ro", markersize=12)
        hud_text = ax.text(0.02, 0.98, "", transform=ax.transAxes, va='top', ha='left')

        def init_anim():
            agent_dot.set_data([], [])
            hud_text.set_text("")
            return [agent_dot, hud_text] + list(tiles.values())

        def update(frame):
            frame = min(frame, frames - 1)  # clamp defensively to the last recorded frame

            inten = agent.frames_intensity[frame]   # remaining-digs intensity per tile
            hasm = agent.frames_has_mines[frame]    # mines-present mask per tile

            for (r, c), rect in tiles.items():
                # Depleted tiles keep their hue but are drawn very dim.
                scale = 0.30 + 0.70 * float(inten[r, c]) if hasm[r, c] else 0.20
                rect.set_facecolor(np.clip(visible_colors[(r, c)] * scale, 0, 1))

            rr, cc = pos_seq[frame]
            agent_dot.set_data([int(cc)], [int(rr)])
            # One field group per line so the HUD fits inside the 5-inch figure.
            hud_text.set_text(
                f"t={frame}  action={action_seq[frame]}\n"
                f"reward_total={reward_seq[frame]:.2f}  resource={resource_seq[frame]:.2f}\n"
                f"decision={decision_seq[frame]}"
            )
            return [agent_dot, hud_text] + list(tiles.values())

        anim = FuncAnimation(fig, update, init_func=init_anim, frames=frames, interval=150, blit=True, repeat=False)
        anim.save(outpath, writer=PillowWriter(fps=6))
    finally:
        # Always release the figure, even if saving fails, so repeated runs do not leak figures.
        plt.close(fig)

In [18]:
# ------------------- Run (two-phase) -------------------
# Fixed seed so Restart & Run All reproduces the same map, trajectories, GIFs and summary.
# Set SEED = None for a fresh random run each time.
SEED: Optional[int] = 42

if __name__ == "__main__":
    # Distinct derived seeds keep the map, explorer and forager on independent random streams.
    if SEED is None:
        env_seed = explorer_seed = forager_seed = None
    else:
        env_seed, explorer_seed, forager_seed = SEED, SEED + 1, SEED + 2

    # === Build env ===
    # Visible map respects presets (if passed); otherwise uses TRUE labels. No 'unknown' anywhere.
    env = init_gridworld(size=3, seed=env_seed)

    # === Phase 1: Explorer (writes IN-PLACE on env and transfers same map) ===
    ex_cfg = ExplorerConfig(init_resource=100, move_cost=4, scan_cost=2, gamma=1.0,
                            beta_local=0.6, beta_global=0.4, avoid_base=True, no_backtrack=True)
    explorer = ExplorerAgent(env, ex_cfg, seed=explorer_seed)
    ex_traj = explorer.run(max_steps=300)
    animate_explorer(env, explorer, outpath="explore_animation.gif")

    # === Hand the explored map to the Forager ===
    env_for_forager = explorer.export_env_for_forager()

    # === Phase 2: Forager (softmax stay/leave + softmax neighbor; move only on opened tiles) ===
    fg_cfg = MVTConfig(init_resource=100, move_cost=4, dig_cost=2,
                       gamma=1.0, beta_trust=0.7, reward_amount=1.0, avoid_base=True,
                       stay_leave_temp=0.7, move_temp=0.7)
    forager = MVTAgent(env_for_forager, fg_cfg, seed=forager_seed)
    traj = forager.run(max_steps=300)
    animate_forager(env_for_forager, forager, outpath="forage_animation.gif")

    # --- Summary prints ---
    print("=== Exploration Summary ===")
    print(f"Explorer steps: {len(ex_traj)} | Remaining resource: {explorer.resource:.2f} | Hidden mines left: {explorer._global_hidden_count()}")
    print("=== Foraging Summary ===")
    print(f"Total reward received: {forager.total_reward:.2f}")
    print(f"Resource remaining: {forager.resource:.2f}")

=== Exploration Summary ===
Explorer steps: 15 | Remaining resource: 56.00 | Hidden mines left: 4
=== Foraging Summary ===
Total reward received: 6.00
Resource remaining: 0.00
