In [None]:
# ============================================================================
# CELL 1: State Management & Utilities
# ============================================================================
# CHANGES FROM YOUR ORIGINAL:
# 1. Added parse_game_state_data() helper (handles both key formats)
# 2. Added _pad_or_trim() helper for dimension safety
# 3. read_game_state now retries up to max_retries times with a short fixed pause (matches colleague)
# 4. Added all MARKOV_* constants (were missing)
# 5. Added EXPLORATION_MEMORY_FILE, MODEL_CHECKPOINT_FILE, 
#    TAUGHT_EXPLORATION_FILE paths (matches colleague's path structure)
# 6. build_learning_state uses concatenation with battle gating (matches colleague)
# ============================================================================

from pathlib import Path
import json
import numpy as np
import time
from collections import deque

# Root folder shared with the emulator-side script. Raw string so the
# backslashes (e.g. "\U") are not treated as escape sequences.
BASE_PATH = Path(r"C:\Users\HP\Documents\cogai")

# --- Files exchanged with the emulator ---------------------------------------
ACTION_FILE = BASE_PATH / "action.json"
STATE_FILE = BASE_PATH / "game_state.json"
INPUT_FILE = BASE_PATH / "input_cache.txt"

# --- Taught artefacts ----------------------------------------------------------
# Each file is exposed under two names so callers using either spelling work.
MODEL_CHECKPOINT_FILE = BASE_PATH / "taught_model_checkpoint.json"
MODEL_FILE = MODEL_CHECKPOINT_FILE
EXPLORATION_MEMORY_FILE = BASE_PATH / "taught_exploration_memory.json"
TAUGHT_EXPLORATION_FILE = EXPLORATION_MEMORY_FILE
TAUGHT_TRANSITIONS_FILE = BASE_PATH / "taught_transitions.json"

# === MARKOV SIMILARITY WEIGHTS (consistent with colleague) ===
# Mixing weights for the immediate / sequential / partial-context terms, and
# the minimum total score at which a taught action is followed.
MARKOV_IMMEDIATE_WEIGHT = 0.5
MARKOV_SEQUENTIAL_WEIGHT = 0.3
MARKOV_PARTIAL_WEIGHT = 0.2
MARKOV_FAMILIARITY_THRESHOLD = 0.6

# Sequence credit for matching the last 8 / 5 / 3 recent actions.
MARKOV_SEQ_FULL_WEIGHT = 1.0
MARKOV_SEQ_MEDIUM_WEIGHT = 0.6
MARKOV_SEQ_SHORT_WEIGHT = 0.3

# Position bonus by Manhattan distance: 0, <= 2, <= MARKOV_POS_MAX_DIST;
# taught frames farther than MARKOV_POS_MAX_DIST are ignored.
MARKOV_POS_EXACT_BONUS = 0.35
MARKOV_POS_NEAR_BONUS = 0.25
MARKOV_POS_FAR_BONUS = 0.1
MARKOV_POS_MAX_DIST = 5

# Vector sizes: raw context state, palette features, tile features, and the
# full learning state (8 derived features + tiles + palette = 1376).
EXPECTED_STATE_DIM = 6
PALETTE_DIM = 768
TILE_DIM = 600
LEARNING_STATE_DIM = 8 + TILE_DIM + PALETTE_DIM

# Lua short action codes -> action names.
ACTION_MAP = {
    'U': 'UP',
    'D': 'DOWN',
    'L': 'LEFT',
    'R': 'RIGHT',
    'A': 'A',
    'B': 'B',
    'S': 'Start',
    'E': 'Select',
}


def normalize_game_state(raw_state):
    """Scale a raw emulator state vector into the model's context space.

    Short inputs are zero-padded to six entries; entries past the sixth are
    carried through unchanged (as floats).

    Layout of the first six outputs:
        [0] x / 255, [1] y / 255, [2] map id clipped to [0, 255],
        [3] battle flag (1.0 if > 0), [4] menu flag (1.0 if > 0),
        [5] facing modulo 4.
    """
    shortfall = 6 - len(raw_state)
    values = list(raw_state) + [0] * shortfall if shortfall > 0 else raw_state
    out = np.array(values, dtype=float)

    x, y, map_id, battle, menu, facing = (values[i] for i in range(6))
    out[0] = x / 255.0
    out[1] = y / 255.0
    out[2] = np.clip(map_id, 0, 255)
    out[3] = 1.0 if battle > 0 else 0.0
    out[4] = 1.0 if menu > 0 else 0.0
    out[5] = int(facing) % 4
    return out


def compute_derived_features(current, prev):
    """Describe what changed between two normalised context states.

    Returns an 8-element array:
        [x delta, y delta, map changed, battle started, battle ended,
         menu opened, menu closed, facing changed]
    All flags are 0.0/1.0. With no previous state, all eight are zero.
    """
    if prev is None:
        return np.zeros(8)

    def rose(i):
        return 1.0 if current[i] > prev[i] else 0.0

    def fell(i):
        return 1.0 if current[i] < prev[i] else 0.0

    return np.array([
        current[0] - prev[0],
        current[1] - prev[1],
        1.0 if abs(current[2] - prev[2]) > 0.5 else 0.0,
        rose(3),
        fell(3),
        rose(4),
        fell(4),
        1.0 if current[5] != prev[5] else 0.0,
    ])


def build_learning_state(derived, palette, tiles, in_battle):
    """Assemble the perceptron input vector (battle-gated, colleague layout).

    Any input of the wrong length is replaced by a zero vector of the expected
    size. Outside battle the layout is derived(8) + tiles(TILE_DIM) +
    palette(PALETTE_DIM); when ``in_battle > 0.5`` the tile block is left out,
    so the vector is shorter. Gaussian jitter with std 1e-4 is added to every
    element.
    """
    derived = derived if len(derived) == 8 else np.zeros(8)
    tiles = tiles if len(tiles) == TILE_DIM else np.zeros(TILE_DIM)
    palette = palette if len(palette) == PALETTE_DIM else np.zeros(PALETTE_DIM)

    pieces = [derived, palette] if in_battle > 0.5 else [derived, tiles, palette]
    state = np.concatenate(pieces)
    return state + np.random.randn(len(state)) * 0.0001


def _pad_or_trim(arr, target_dim):
    """NEW: Dimension safety helper (from colleague)."""
    if arr.shape[0] < target_dim:
        return np.pad(arr, (0, target_dim - arr.shape[0]))
    elif arr.shape[0] > target_dim:
        return arr[:target_dim]
    return arr


def parse_game_state_data(data):
    """Extract ``(raw, palette, tiles, dead)`` from a game-state dict.

    Long keys ("state", "palette", "tiles") and short keys ("s", "p", "t")
    are both accepted; the long key wins unless its value is falsy. Missing
    or empty fields come back as ``[]``. ``dead`` is ``bool(data["dead"])``
    (False when absent).
    """
    def pick(long_key, short_key):
        return data.get(long_key) or data.get(short_key) or []

    return (
        pick("state", "s"),
        pick("palette", "p"),
        pick("tiles", "t"),
        bool(data.get("dead", False)),
    )


def read_game_state(max_retries=3):
    """Read the emulator's state file and return normalised model inputs.

    Returns ``(context_state, palette_state, tile_state, dead, raw_position)``:
    the normalised context vector (EXPECTED_STATE_DIM entries, via
    normalize_game_state), palette and tile vectors padded/trimmed to
    PALETTE_DIM / TILE_DIM, the death flag, and the un-normalised ``(x, y)``
    taken from the first two raw entries.

    The file is rewritten by another process, so a read may see a partially
    written file (ValueError, which includes json.JSONDecodeError) or may fail
    with an OSError while the writer holds it. Both are retried for up to
    ``max_retries`` attempts (always at least one) with a 1 ms pause between
    them. A missing file, any other error, or running out of attempts yields
    the all-zero default tuple, so callers can always unpack five values.
    """
    def _default():
        return np.zeros(EXPECTED_STATE_DIM), np.zeros(PALETTE_DIM), np.zeros(TILE_DIM), False, (0, 0)

    if not STATE_FILE.exists():
        return _default()

    # max_retries <= 0 used to skip the loop and return None; read at least once.
    attempts = max(1, max_retries)
    for attempt in range(attempts):
        try:
            with open(STATE_FILE, "r") as f:
                data = json.loads(f.read())

            raw, palette_raw, tiles_raw, dead = parse_game_state_data(data)

            raw_x = int(raw[0]) if len(raw) > 0 else 0
            raw_y = int(raw[1]) if len(raw) > 1 else 0
            raw_position = (raw_x, raw_y)

            context_state = normalize_game_state(np.array(raw, dtype=float))
            palette_state = np.array(palette_raw, dtype=float) if palette_raw else np.zeros(PALETTE_DIM)
            tile_state = np.array(tiles_raw, dtype=float) if tiles_raw else np.zeros(TILE_DIM)

            context_state = _pad_or_trim(context_state, EXPECTED_STATE_DIM)
            palette_state = _pad_or_trim(palette_state, PALETTE_DIM)
            tile_state = _pad_or_trim(tile_state, TILE_DIM)

            return context_state, palette_state, tile_state, dead, raw_position

        except (ValueError, OSError):
            # Likely a concurrent write in progress: pause briefly and retry.
            if attempt < attempts - 1:
                time.sleep(0.001)
        except Exception:
            return _default()

    return _default()


def write_action(action_name):
    """Write ``{"action": name}`` to ACTION_FILE for the emulator side to read.

    Truthy names are upper-cased first; falsy values (None, "") are written
    as-is. Write failures are printed, not raised.

    NOTE(review): upper-casing turns "Start"/"Select" into "START"/"SELECT";
    confirm that matches what the Lua side expects.
    """
    payload = {"action": action_name.upper() if action_name else action_name}
    try:
        with open(ACTION_FILE, "w") as handle:
            json.dump(payload, handle)
            handle.flush()
    except Exception as exc:
        print(f"[ERROR] Failed to write action: {exc}")

SyntaxError: (unicode error) 'unicodeescape' codec can't decode bytes in position 2-3: truncated \UXXXXXXXX escape (1559207712.py, line 20)

In [None]:
# ============================================================================
# CELL 2: Perceptron Classes
# ============================================================================
# CHANGES FROM YOUR ORIGINAL:
# 1. ensure_weights() now resizes with old-weight preservation (same as yours)
# 2. update() has dimension mismatch fallback matching colleague's approach
#    (uses min_dim slicing instead of full resize — keeps both paths for safety)
# 3. No functional changes ‚Äî your version was already close
# ============================================================================

class Perceptron:
    """A linear unit with dual (fast/slow) eligibility traces.

    ``kind`` changes behaviour:
      * "action": keeps a ``utility`` score adjusted on each update and clipped
        to [0.1, 2.0] for the "move" group or [0.01, 2.0] otherwise;
      * "entity": output is damped as ``familiarity`` grows, and each update
        records a prediction error that drives learning-rate adaptation.
    Weights are allocated lazily, sized to the first state seen.
    """

    def __init__(self, kind, action=None, group=None, entity_type=None):
        self.kind = kind
        self.action, self.group, self.entity_type = action, group, entity_type

        self.utility = 1.0
        self.weights = None  # allocated by ensure_weights()

        # Fast and slow eligibility traces, combined 70/30 in update().
        self.eligibility_fast = 0.0
        self.eligibility_slow = 0.0

        self.familiarity = 0.0
        self.activation_history = deque(maxlen=10)

        self.learning_rate = 0.01
        self.prediction_errors = deque(maxlen=50)

    def ensure_weights(self, dim):
        """Make ``weights`` a length-*dim* vector.

        New entries are drawn as ``randn * 0.001``; when resizing, the
        overlapping prefix of the old weights is preserved.
        """
        old = self.weights
        if old is not None and len(old) == dim:
            return
        fresh = np.random.randn(dim) * 0.001
        if old is not None:
            keep = min(len(old), dim)
            fresh[:keep] = old[:keep]
        self.weights = fresh

    def predict(self, state):
        """Return the unit's activation for *state*.

        Entity units multiply the raw dot product by
        ``1 / (1 + sqrt(0.5 * familiarity))`` and log ``|raw|`` in
        ``activation_history``.
        """
        self.ensure_weights(len(state))
        # ensure_weights() guarantees len(weights) == len(state) here.
        raw = np.dot(self.weights, state)
        if self.kind != "entity":
            return raw
        damping = 1.0 / (1.0 + np.sqrt(self.familiarity * 0.5))
        self.activation_history.append(abs(raw))
        return raw * damping

    def adapt_learning_rate(self):
        """Nudge ``learning_rate`` once 50 prediction errors are on record.

        A mean error below 0.1 shrinks it by 1% (floor 0.001); above 0.5 it
        grows by 1% (cap 0.05).
        """
        if len(self.prediction_errors) < 50:
            return
        mean_err = np.mean(self.prediction_errors)
        if mean_err < 0.1:
            self.learning_rate = max(0.001, self.learning_rate * 0.99)
        elif mean_err > 0.5:
            self.learning_rate = min(0.05, self.learning_rate * 1.01)

    def update(self, state, error, gamma_fast=0.5, gamma_slow=0.95, stagnation=0.0):
        """Apply one learning step driven by *error* for input *state*.

        Each trace decays by its gamma and gains 1. The weight change is
        ``learning_rate * error * state``, split 70% on the fast trace and 30%
        on the slow one. Action units then adjust ``utility`` (with
        ``stagnation > 0.5`` forcing a decay); entity units update
        ``familiarity`` and log a prediction error.
        """
        self.ensure_weights(len(state))
        # ensure_weights() guarantees matching lengths, so no slicing is needed.

        self.eligibility_fast = gamma_fast * self.eligibility_fast + 1.0
        self.eligibility_slow = gamma_slow * self.eligibility_slow + 1.0
        self.adapt_learning_rate()

        lr = self.learning_rate
        fast_step = 0.7 * lr * error * state * self.eligibility_fast
        slow_step = 0.3 * lr * error * state * self.eligibility_slow
        self.weights += fast_step + slow_step

        if self.kind == "action":
            if error > 0.01:
                if stagnation > 0.5:
                    self.utility *= 0.97  # stagnating: decay utility
                elif error > 0.2:
                    self.utility = min(self.utility * 1.02, 2.0)
                else:
                    self.utility *= 0.995
            floor = 0.1 if self.group == "move" else 0.01
            self.utility = np.clip(self.utility, floor, 2.0)
        elif self.kind == "entity":
            if self.activation_history and np.mean(self.activation_history) > 0.1:
                self.familiarity += 0.03
            prediction = self.predict(state)
            self.prediction_errors.append(abs(prediction - error))


class ControlSwapPerceptron(Perceptron):
    """Perceptron that scores whether to switch control mode.

    The decision mixes the learned linear score (70%) with a saturating
    ``tanh(stagnation / 5)`` term (30%). ``confidence`` tracks how often the
    last 20 recorded swaps paid off.
    """

    def __init__(self):
        super().__init__(kind="control_swap")
        self.swap_history = deque(maxlen=100)
        self.confidence = 0.0

    def should_swap(self, state, movement_stagnation):
        """Return ``(score > 0.5, |score|)``; never swaps before weights exist."""
        if self.weights is None:
            return False, 0.0
        self.ensure_weights(len(state))
        learned = np.dot(self.weights, state)
        pressure = np.tanh(movement_stagnation / 5.0)
        score = learned * 0.7 + pressure * 0.3
        return score > 0.5, abs(score)

    def record_swap_outcome(self, state, swapped, novelty_gained):
        """Log a ``(swapped, novelty_gained)`` outcome and refresh ``confidence``.

        Once 20 outcomes exist, confidence is the fraction of the last 20 where
        a swap happened and novelty exceeded 0.2. *state* is currently unused.
        """
        self.swap_history.append((swapped, novelty_gained))
        if len(self.swap_history) < 20:
            return
        window = list(self.swap_history)[-20:]
        wins = sum(1 for did_swap, gain in window if did_swap and gain > 0.2)
        self.confidence = wins / 20.0

In [None]:
# ============================================================================
# CELL 3: Brain Class — Complete, Aligned with Colleague's AI Agent
# ============================================================================
# KEY CHANGES from your original (Document 4):
#
# SAVE PATHS (sync doc compliance):
#   - Exploration saves to taught_exploration_memory.json (was correct)
#   - Model saves to taught_model_checkpoint.json (was model_checkpoint.json)
#
# ADDED SYSTEMS (were missing):
#   - Taught reference model + blend system (3 tiers)
#   - Colleague's Markov similarity (position-distance, map filtering)
#   - Transition ban system (create, vicinity, lift)
#   - Pattern detection (detect_pattern ‚Äî was pass)
#   - Forced random + blend on stagnation
#   - State stagnation initiator penalty
#   - Productive change detection + on_productive_change
#   - Full determine_control_mode with tile probing
#   - predict_future_error for curiosity scoring
#   - get_transition_attraction
#   - Movement boost gated on repetition
#   - Overworld gate on interaction verification
#   - save_model_checkpoint with blend_stats + markov_stats
#
# KEPT (your teaching-specific features):
#   - cleanup_memory() with MAX_LOCATIONS / MAX_MAPS_IN_MEMORY
#   - get_memory_stats()
#   - recent_actions_buffer
# ============================================================================

import gc
import random

class Brain:
    def __init__(self):
        """Set up every learning, memory and control-mode field of the agent.

        Side effect: calls ``self.load_exploration_memory()`` partway through,
        which reads ``self.EXPLORATION_MEMORY_FILE`` from disk, so that attribute
        is assigned before the call. Taught transitions and the taught reference
        model start empty here; they are filled by ``load_taught_transitions``
        and ``load_taught_reference``.
        """
        self.perceptrons = []
        self.prev_learning_states = deque(maxlen=50)
        self.prev_context_states = deque(maxlen=10)
        self.last_positions = deque(maxlen=30)
        self.action_history = deque(maxlen=100)
        
        self.control_mode = "move"
        self.timestep = 0
        self.last_action = None
        self.last_direction = 0
        
        # Lower bounds applied to action utilities by blend_from_taught().
        self.MOVE_UTILITY_FLOOR = 0.05
        self.INTERACT_UTILITY_FLOOR = 0.15
        
        # === EXPLORATION MEMORY ===
        # FIX: This is the TAUGHT exploration file (teaching writes, AI reads)
        # NOTE(review): same path as the module-level EXPLORATION_MEMORY_FILE constant.
        self.EXPLORATION_MEMORY_FILE = BASE_PATH / "taught_exploration_memory.json"
        self.exploration_memory = {}
        self.current_map_id = None
        self.SAVE_INTERVAL = 100
        # Once exceeded, cleanup_memory() saves and keeps the most recent half.
        self.MAX_MAPS_IN_MEMORY = 20
        
        # Facing codes 0-3 <-> names, and one-tile deltas (y grows downward: DOWN = +1).
        self.DIRECTION_NAMES = {0: "DOWN", 1: "UP", 2: "LEFT", 3: "RIGHT"}
        self.DIRECTION_TO_INT = {"DOWN": 0, "UP": 1, "LEFT": 2, "RIGHT": 3}
        self.INT_TO_ACTION = {0: "DOWN", 1: "UP", 2: "LEFT", 3: "RIGHT"}
        self.DIRECTION_DELTAS_INT = {0: (0, 1), 1: (0, -1), 2: (-1, 0), 3: (1, 0)}
        self.ACTION_DELTAS = {"UP": (0, -1), "DOWN": (0, 1), "LEFT": (-1, 0), "RIGHT": (1, 0)}
        self.DELTA_TO_DIRECTION = {(0, 1): 0, (0, -1): 1, (-1, 0): 2, (1, 0): 3}
        
        # Disk read; must stay after EXPLORATION_MEMORY_FILE is assigned above.
        self.load_exploration_memory()
        
        # === MARKOV TRANSITION SYSTEM ===
        self.taught_transitions = []
        self.taught_batches = []
        self.taught_metadata = {}
        self.markov_enabled = True
        self.markov_action_count = 0
        self.curiosity_action_count = 0
        self.last_markov_score = 0.0
        self.last_markov_action = None
        self.recent_actions_buffer = deque(maxlen=8)
        
        # === TAUGHT MODEL REFERENCE (NEW) ===
        # Populated by load_taught_reference(); 'loaded' gates all blending.
        self.taught_reference = {'utilities': {}, 'weights': {}, 'loaded': False}
        
        # === BLEND SYSTEM (NEW) ===
        self.blend_tier = 0
        self.last_blend_timestep = 0
        self.BLEND_COOLDOWN = 50
        self.blend_count = 0
        # tier -> (share kept from current AI values, share taken from taught values)
        self.BLEND_RATIOS = {1: (0.80, 0.20), 2: (0.60, 0.40), 3: (0.40, 0.60)}
        # Thresholds read by get_blend_tier(); tier 3 compares state_stagnation_count
        # against STATE_STAGNATION_THRESHOLD * state_stagnation_mult.
        self.BLEND_TIER_TRIGGERS = {
            1: {'pattern_repeats': 3, 'pos_stagnation': 8, 'consecutive': 12},
            2: {'pattern_repeats': 6, 'pos_stagnation': 15, 'consecutive': 15},
            3: {'pattern_repeats': 10, 'state_stagnation_mult': 2.0}
        }
        
        # === ACTION EXECUTION ===
        self.pending_action = None
        self.pending_action_frames = 0
        # confirm_action_executed() gives up waiting for an effect after this many frames.
        self.ACTION_CONFIRM_FRAMES = 3
        self.last_confirmed_action = None
        
        # === TILE INTERACTION ===
        self.INTERACTION_VERIFY_FRAMES = 8
        self.MIN_SUCCESS_RATE_THRESHOLD = 0.1
        self.pending_interaction_verify = None
        self.interaction_verify_countdown = 0
        
        # === MENU TRAP ===
        self.menu_trap_frames = 0
        self.menu_trap_b_boost = 1.0
        self.menu_trap_position = None
        self.B_BOOST_INCREMENT = 0.15
        self.B_BOOST_MAX = 3.0
        self.MENU_TRAP_THRESHOLD = 5
        self.original_b_utility = None
        
        # === MODE SWAPPING ===
        self.DEFAULT_MOVE_TO_INTERACT_THRESHOLD = 15
        self.DEFAULT_INTERACT_TO_MOVE_THRESHOLD = 25
        self.move_to_interact_threshold = self.DEFAULT_MOVE_TO_INTERACT_THRESHOLD
        self.interact_to_move_threshold = self.DEFAULT_INTERACT_TO_MOVE_THRESHOLD
        self.THRESHOLD_INCREMENT = 15
        self.MAX_THRESHOLD = 150
        self.frames_in_current_mode = 0
        self.swap_chain_count = 0
        self.position_at_mode_swap = None
        self.last_map_id = None
        self.last_battle_state = None
        
        # === UNPRODUCTIVE SWAP (NEW) ===
        self.UNPRODUCTIVE_SWAP_THRESHOLD = 3
        self.unproductive_swap_count = 0
        
        # === STATE STAGNATION ===
        self.STATE_STAGNATION_THRESHOLD = 20
        self.state_stagnation_count = 0
        self.last_context_state_hash = None
        self.stagnation_initiator_action = None
        
        # === BOTH MODE ===
        self.BOTH_MODE_STAGNATION_THRESHOLD = 35
        self.BOTH_MODE_SWAP_THRESHOLD = 5
        self.last_direction_for_progress = None
        self.direction_change_counts_as_progress = True
        
        # === NOVELTY WEIGHTS ===
        self.UNVISITED_TILE_BONUS = 1.5
        self.OBSTRUCTION_PENALTY = 0.25
        
        # === TRANSITIONS & DEBT ===
        self.TRANSITION_ATTRACTION_WEIGHT = 0.6
        self.TEMP_DEBT_ACCUMULATION = 0.5
        self.TEMP_DEBT_DECAY = 0.02
        self.TEMP_DEBT_MAX = 15.0
        self.MAX_MAP_DEBT = 10.0
        self.MAX_LOCATION_DEBT = 5.0
        self.DEBT_DECAY_RATE = 0.005
        
        # === TRANSITION BAN (NEW) ===
        self.transition_bans = {}
        self.BAN_VICINITY_RADIUS = 3
        self.BAN_COVERAGE_LIFT_THRESHOLD = 0.6
        self.BAN_TIMEOUT_STEPS = 300
        
        # Multi-scale memory
        self.visited_maps = {}
        self.map_novelty_debt = {}
        self.location_memory = {}
        self.location_novelty = {}
        self.action_execution_count = {}
        # Once exceeded, cleanup_memory() keeps the top half of location_memory by value.
        self.MAX_LOCATIONS = 500
        
        self.swap_perceptron = ControlSwapPerceptron()
        self.error_history = deque(maxlen=100)
        self.numeric_error_history = deque(maxlen=100)
        self.visual_error_history = deque(maxlen=100)
        self._entity_norms_cache = {}
        self._cache_valid = False
        self.innate_entities_spawned = False
        
        # === REPETITION ===
        self.consecutive_action_count = 0
        self.current_repeated_action = None
        self.LEARNING_SLOWDOWN_START = 3
        self.LEARNING_SLOWDOWN_MAX = 10
        self.PENALTY_THRESHOLD = 12
        self.HARD_RESET_THRESHOLD = 18
        
        # === PATTERN (NEW — was pass) ===
        self.PATTERN_CHECK_WINDOW = 50
        self.PATTERN_MIN_REPEATS = 3
        self.PATTERN_MAX_LENGTH = 10
        self.detected_pattern = None
        self.pattern_repeat_count = 0
        
        # === PROBE CACHE ===
        self._cached_probe_action = None
        self._cached_probe_dir = None
        self._probe_cache_position = None

    # =========================================================================
    # CORE
    # =========================================================================
    def add(self, p): self.perceptrons.append(p); self._cache_valid = False
    def actions(self): return [p for p in self.perceptrons if p.kind == "action"]
    def entities(self): return [p for p in self.perceptrons if p.kind == "entity"]
    def get_location_key(self, x, y, mid, bs=5): return (int(mid), int(x//bs)*bs, int(y//bs)*bs)
    def is_near_map_edge(self, x, y): return x < 10 or x > 245 or y < 10 or y > 245
    def record_action_execution(self, a):
        if a: self.action_execution_count[a] = self.action_execution_count.get(a, 0) + 1; self.recent_actions_buffer.append(a)
    def get_position_stagnation(self):
        if len(self.last_positions) < 2: return 0
        cp = self.last_positions[-1]
        return sum(1 for p in reversed(list(self.last_positions)[:-1]) if p == cp)
    def get_group_weight(self, g): return sum(a.utility for a in self.actions() if a.group == g)
    def log_state(self, ls, cs): self.prev_learning_states.append(ls); self.prev_context_states.append(cs)
    def update_position(self, x, y): self.last_positions.append((int(x), int(y)))
    
    # =========================================================================
    # MEMORY MANAGEMENT (kept from your version)
    # =========================================================================
    def cleanup_memory(self):
        if len(self.location_memory) > self.MAX_LOCATIONS:
            sl = sorted(self.location_memory.items(), key=lambda x: x[1], reverse=True)
            self.location_memory = dict(sl[:self.MAX_LOCATIONS // 2])
            self.location_novelty = {k: v for k, v in self.location_novelty.items() if k in self.location_memory}
        if len(self.exploration_memory) > self.MAX_MAPS_IN_MEMORY:
            self.save_exploration_memory()
            sm = sorted(self.exploration_memory.items(), key=lambda x: x[1].get('last_visited_timestep', 0), reverse=True)
            self.exploration_memory = dict(sm[:self.MAX_MAPS_IN_MEMORY // 2])
        self._entity_norms_cache.clear(); self._cache_valid = False; gc.collect()
    def get_memory_stats(self):
        return {'exploration_maps': len(self.exploration_memory), 'location_memory': len(self.location_memory),
                'error_history': len(self.error_history), 'perceptrons': len(self.perceptrons),
                'total_tiles': sum(len(m.get('visited_tiles', set())) for m in self.exploration_memory.values())}

    # =========================================================================
    # TAUGHT REFERENCE + BLEND (NEW)
    # =========================================================================
    def load_taught_reference(self, fp):
        try:
            if not Path(fp).exists(): print(f"  No taught reference at {fp}"); return
            with open(fp, 'r') as f: model = json.load(f)
            if "perceptrons" not in model: return
            for sa in model["perceptrons"].get("actions", []):
                an = sa.get("action")
                if an:
                    self.taught_reference['utilities'][an] = sa.get("utility", 1.0)
                    if sa.get("weights_nonzero"):
                        dim = sa.get("weights_shape", 1376); w = np.zeros(dim)
                        for idx, val in sa["weights_nonzero"]:
                            if idx < dim: w[idx] = val
                        self.taught_reference['weights'][an] = w
            self.taught_reference['loaded'] = True
            print(f"  üìñ Taught reference loaded: {list(self.taught_reference['utilities'].keys())}")
        except Exception as e: print(f"  ‚ö†Ô∏è Error loading taught reference: {e}")

    def blend_from_taught(self, tier):
        if not self.taught_reference['loaded'] or tier not in self.BLEND_RATIOS: return
        if self.timestep - self.last_blend_timestep < self.BLEND_COOLDOWN: return
        ai_w, tw = self.BLEND_RATIOS[tier]; bw = (tier == 3)
        for a in self.actions():
            if a.action not in self.taught_reference['utilities']: continue
            tu = self.taught_reference['utilities'][a.action]; a.utility = ai_w * a.utility + tw * tu
            if tu > 1.0: a.utility = max(a.utility, tu * 0.5)
            fl = self.INTERACT_UTILITY_FLOOR if a.group == "interact" else self.MOVE_UTILITY_FLOOR
            a.utility = max(min(a.utility, 2.0), fl)
            if bw and a.action in self.taught_reference['weights'] and a.weights is not None:
                taw = self.taught_reference['weights'][a.action]; md = min(len(a.weights), len(taw))
                a.weights[:md] = ai_w * a.weights[:md] + tw * taw[:md]
        self.last_blend_timestep = self.timestep; self.blend_tier = tier; self.blend_count += 1

    def get_blend_tier(self):
        t3 = self.BLEND_TIER_TRIGGERS[3]
        if self.detected_pattern and self.pattern_repeat_count >= t3['pattern_repeats']: return 3
        if self.state_stagnation_count >= self.STATE_STAGNATION_THRESHOLD * t3['state_stagnation_mult']: return 3
        t2 = self.BLEND_TIER_TRIGGERS[2]
        if self.detected_pattern and self.pattern_repeat_count >= t2['pattern_repeats']: return 2
        if self.get_position_stagnation() >= t2['pos_stagnation']: return 2
        if self.consecutive_action_count >= t2['consecutive']: return 2
        t1 = self.BLEND_TIER_TRIGGERS[1]
        if self.detected_pattern and self.pattern_repeat_count >= t1['pattern_repeats']: return 1
        if self.get_position_stagnation() >= t1['pos_stagnation']: return 1
        if self.consecutive_action_count >= t1['consecutive']: return 1
        return 0

    def try_blend_if_needed(self):
        if not self.taught_reference['loaded']: return False
        tier = self.get_blend_tier()
        if tier == 0: return False
        if tier <= self.blend_tier and (self.timestep - self.last_blend_timestep) < self.BLEND_COOLDOWN: return False
        self.blend_from_taught(tier); return True

    # =========================================================================
    # MARKOV (REPLACED — matches colleague exactly)
    # =========================================================================
    def load_taught_transitions(self, fp=None):
        fp = fp or TAUGHT_TRANSITIONS_FILE
        try:
            if Path(fp).exists():
                with open(fp, 'r') as f: data = json.load(f)
                self.taught_transitions = []; self.taught_batches = data.get('batches', [])
                for batch in self.taught_batches:
                    bt = batch.get('batch_type', 'steady'); ta = batch.get('trigger_action')
                    for frame in batch.get('frames', []):
                        self.taught_transitions.append({'state': frame.get('state', {}), 'action': frame.get('action'),
                            'recent_actions': frame.get('recent_actions', []), 'frame_offset': frame.get('frame_offset', 0),
                            'batch_type': bt, 'trigger_action': ta})
                self.taught_metadata = data.get('metadata', {})
                print(f"  üìö Taught transitions: {len(self.taught_batches)} batches, {len(self.taught_transitions)} frames")
            else: self.taught_transitions = []; self.taught_batches = []; self.taught_metadata = {}
        except Exception as e:
            print(f"  Error loading transitions: {e}")
            self.taught_transitions = []; self.taught_batches = []; self.taught_metadata = {}

    def extract_partial_context(self, cs, rp=None):
        """Summarise coarse situational flags used by the Markov matcher.

        *rp* is the raw ``(x, y)``; when falsy the position is rebuilt from the
        context state as ``int(cs[i] * 255)``. Returns a dict with:
        ``in_battle`` / ``in_menu`` (cs[3] / cs[4] > 0.5), ``movement_blocked``
        (position stagnation > 3), ``near_transition`` (a transition recorded
        for this map within Manhattan distance 2), and ``tile_probed``
        (the negation of ``should_interact_at_tile``).
        """
        if rp:
            rx, ry = rp[0], rp[1]
        else:
            rx, ry = int(cs[0] * 255), int(cs[1] * 255)
        cm = int(cs[2])
        blocked = self.get_position_stagnation() > 3
        mem = self.get_current_map_memory(cm)

        def as_tuple(pos):
            return tuple(pos) if isinstance(pos, list) else pos

        near = any(
            abs(rx - tp[0]) + abs(ry - tp[1]) <= 2
            for tp in (as_tuple(t['position']) for t in mem.get('transitions', []))
        )
        return {
            'in_battle': cs[3] > 0.5,
            'in_menu': cs[4] > 0.5,
            'movement_blocked': blocked,
            'near_transition': near,
            'tile_probed': not self.should_interact_at_tile(rx, ry, cm),
        }

    def compute_markov_similarity(self, cs, rp=None, taught_frames=None):
        """Find the taught frame most similar to the current situation.

        Args:
            cs: context state; indices 0-5 are read as x/255, y/255, map id,
                battle flag, menu flag and facing (presumably the output of
                normalize_game_state — confirm at call sites).
            rp: optional raw (x, y) position; when falsy it is rebuilt from cs.
            taught_frames: candidate frames. When None, ``self.taught_transitions``
                is used AND candidates whose ``map_id`` differs are skipped; when
                passed explicitly, the map-id filter is not applied.

        A candidate needs an action other than None/"" /"NONE" and must lie
        within MARKOV_POS_MAX_DIST (Manhattan) of the current position. Its
        score mixes an immediate-context term, an action-sequence term (last
        8/5/3 actions matching) and a partial-context term using the
        MARKOV_*_WEIGHT constants. The score is then multiplied by 1.2 for
        "action_change" batches and 1.1 when frame_offset is 0.

        Returns:
            (best_score, best_action, best_index), or (0.0, None, -1) when no
            candidate qualifies.
        """
        frames = taught_frames if taught_frames is not None else self.taught_transitions
        # smc: caller supplied the frames explicitly -> skip the map-id filter below.
        smc = taught_frames is not None
        if not frames: return 0.0, None, -1
        # Raw position: prefer rp, otherwise de-normalise cs[0]/cs[1].
        rx = rp[0] if rp else int(cs[0]*255); ry = rp[1] if rp else int(cs[1]*255)
        cm = int(cs[2]); cd = int(cs[5]); ib = cs[3] > 0.5; im = cs[4] > 0.5
        ca = list(self.action_history); cp = self.extract_partial_context(cs, rp)
        bs, ba, bi = 0.0, None, -1
        for idx, tr in enumerate(frames):
            ts = tr.get('state', {}); ta = tr.get('action'); trc = tr.get('recent_actions', []); bt = tr.get('batch_type', 'steady')
            if not ta or ta == "NONE": continue
            # --- Immediate-context score (isc) ---
            isc = 0.0
            if not smc:
                if ts.get('map_id') != cm: continue
            # Base credit for any candidate that survived the map filter.
            isc += 0.25
            tx, ty = ts.get('x', 0), ts.get('y', 0); pd = abs(rx-tx) + abs(ry-ty)
            if pd == 0: isc += MARKOV_POS_EXACT_BONUS
            elif pd <= 2: isc += MARKOV_POS_NEAR_BONUS
            elif pd <= MARKOV_POS_MAX_DIST: isc += MARKOV_POS_FAR_BONUS
            else: continue
            if ts.get('direction') == cd: isc += 0.2
            tib = ts.get('in_battle', 0) == 1; tim = ts.get('in_menu', 0) == 1
            if tib == ib: isc += 0.1
            if tim == im: isc += 0.1
            # --- Sequence score (ssc): longest matching tail of recent actions (8, else 5, else 3) ---
            # NOTE(review): ca is already a list; list(ca) copies it on every iteration.
            ssc = 0.0
            if trc and ca:
                if len(ca) >= 8 and len(trc) >= 8 and list(ca)[-8:] == trc[-8:]: ssc = MARKOV_SEQ_FULL_WEIGHT
                if ssc < MARKOV_SEQ_MEDIUM_WEIGHT and len(ca) >= 5 and len(trc) >= 5 and list(ca)[-5:] == trc[-5:]: ssc = MARKOV_SEQ_MEDIUM_WEIGHT
                if ssc < MARKOV_SEQ_SHORT_WEIGHT and len(ca) >= 3 and len(trc) >= 3 and list(ca)[-3:] == trc[-3:]: ssc = MARKOV_SEQ_SHORT_WEIGHT
            # --- Partial-context score: fraction of {battle, menu} flags that agree ---
            # cp's flags use the same cs[3]/cs[4] > 0.5 tests as ib/im above.
            pm = sum(1 for a, b in [(tib, cp['in_battle']), (tim, cp['in_menu'])] if a == b)
            total = MARKOV_IMMEDIATE_WEIGHT * isc + MARKOV_SEQUENTIAL_WEIGHT * ssc + MARKOV_PARTIAL_WEIGHT * (pm / 2)
            if bt == "action_change": total *= 1.2
            if tr.get('frame_offset', 0) == 0: total *= 1.1
            # Strictly greater: on ties the earliest frame wins.
            if total > bs: bs, ba, bi = total, ta, idx
        return bs, ba, bi

    def get_markov_action(self, cs, rp=None, taught_frames=None):
        if not self.markov_enabled: return False, None, 0.0
        frames = taught_frames if taught_frames is not None else self.taught_transitions
        if not frames: return False, None, 0.0
        sc, ac, ix = self.compute_markov_similarity(cs, rp, taught_frames=frames)
        self.last_markov_score = sc
        if sc >= MARKOV_FAMILIARITY_THRESHOLD: self.last_markov_action = ac; return True, ac, sc
        return False, None, sc

    # =========================================================================
    # ACTION EXECUTION CONFIRMATION
    # =========================================================================
    def set_pending_action(self, a): self.pending_action = a; self.pending_action_frames = 0
    def confirm_action_executed(self, cs, pcs):
        if self.pending_action is None: return True
        self.pending_action_frames += 1; ae = False
        if pcs is not None:
            if self.pending_action in ["UP","DOWN","LEFT","RIGHT"]:
                ae = cs[0] != pcs[0] or cs[1] != pcs[1] or cs[5] != pcs[5]
            elif self.pending_action in ["A","B","Start","Select"]:
                ae = abs(cs[4]-pcs[4]) > 0.1 or cs[3] != pcs[3] or cs[2] != pcs[2]
        if ae or self.pending_action_frames >= self.ACTION_CONFIRM_FRAMES:
            self.last_confirmed_action = self.pending_action; self.pending_action = None; self.pending_action_frames = 0; return True
        return False
    def should_send_new_action(self):
        return self.pending_action is None or self.pending_action_frames >= self.ACTION_CONFIRM_FRAMES

    # =========================================================================
    # EXPLORATION MEMORY
    # =========================================================================
    def load_exploration_memory(self):
        """Load per-map exploration memory from EXPLORATION_MEMORY_FILE.

        Keys of the form 'map_<id>' become integer map ids, and each value is
        rebuilt with `_deserialize_map_memory`. If the file is missing, or any
        error occurs while reading or parsing it, `exploration_memory` becomes
        an empty dict. Errors are printed, not raised.
        """
        try:
            path = self.EXPLORATION_MEMORY_FILE
            if not path.exists():
                self.exploration_memory = {}
                return
            with open(path, 'r') as fh:
                raw = json.load(fh)
            self.exploration_memory = {}
            for map_key, map_data in raw.items():
                restored = self._deserialize_map_memory(map_data)
                self.exploration_memory[int(map_key.replace('map_', ''))] = restored
            print(f"  Loaded exploration: {len(self.exploration_memory)} maps")
        except Exception as e:
            print(f"  Error loading exploration: {e}")
            self.exploration_memory = {}

    def _deserialize_map_memory(self, d):
        ti = {}
        for tk, td in d.get('tile_interactions', {}).items():
            ti[tk] = {'directions_tried': set(td.get('directions_tried', [])),
                      'direction_attempts': {int(k): v for k, v in td.get('direction_attempts', {}).items()},
                      'direction_successes': {int(k): v for k, v in td.get('direction_successes', {}).items()},
                      'exhausted': td.get('exhausted', False)}
        return {'visited_tiles': set(tuple(t) for t in d.get('visited_tiles', [])),
                'obstructions': set(tuple(t) for t in d.get('obstructions', [])),
                'interactable_objects': d.get('interactable_objects', []),
                'last_visited_timestep': d.get('last_visited_timestep', 0),
                'transitions': d.get('transitions', []), 'temp_debt': d.get('temp_debt', 0.0),
                'tile_interactions': ti}

    def save_exploration_memory(self):
        """Persist `exploration_memory` to EXPLORATION_MEMORY_FILE as JSON.

        Each map is serialized with `_serialize_map_memory` under the key
        'map_<id>'. The JSON is first written to a sibling '<name>.tmp' file,
        which then replaces the target in a single rename. An interruption
        mid-write (e.g. Ctrl+C during a save) therefore leaves the previous file
        intact instead of a truncated one that `load_exploration_memory` would
        discard. Errors are printed, not raised.
        """
        try:
            data = {f'map_{mid}': self._serialize_map_memory(md) for mid, md in self.exploration_memory.items()}
            target = Path(self.EXPLORATION_MEMORY_FILE)
            tmp = target.with_name(target.name + '.tmp')
            with open(tmp, 'w') as f:
                json.dump(data, f)
            # Path.replace overwrites an existing target (os.replace semantics).
            tmp.replace(target)
        except Exception as e:
            print(f"  Error saving exploration: {e}")

    def _serialize_map_memory(self, d):
        sti = {}
        for tk, td in d.get('tile_interactions', {}).items():
            sti[tk] = {'directions_tried': list(td.get('directions_tried', set())),
                       'direction_attempts': {str(k): v for k, v in td.get('direction_attempts', {}).items()},
                       'direction_successes': {str(k): v for k, v in td.get('direction_successes', {}).items()},
                       'exhausted': td.get('exhausted', False)}
        return {'visited_tiles': [list(t) for t in d['visited_tiles']], 'obstructions': [list(t) for t in d['obstructions']],
                'interactable_objects': d['interactable_objects'], 'last_visited_timestep': d['last_visited_timestep'],
                'transitions': d.get('transitions', []), 'temp_debt': d.get('temp_debt', 0.0), 'tile_interactions': sti}

    def get_current_map_memory(self, mid):
        """Return the exploration record for map `mid`, creating an empty one (stamped with the current timestep) on first access."""
        try:
            return self.exploration_memory[mid]
        except KeyError:
            pass
        fresh = {
            'visited_tiles': set(),
            'obstructions': set(),
            'interactable_objects': [],
            'last_visited_timestep': self.timestep,
            'transitions': [],
            'temp_debt': 0.0,
            'tile_interactions': {},
        }
        self.exploration_memory[mid] = fresh
        return fresh

    def record_visited_tile(self, x, y, mid):
        """Mark tile (int(x), int(y)) of map `mid` as visited and stamp the map with the current timestep."""
        memory = self.get_current_map_memory(mid)
        memory['visited_tiles'].add((int(x), int(y)))
        memory['last_visited_timestep'] = self.timestep
    def record_obstruction(self, x, y, mid, d):
        """Mark the tile one step from (x, y) in direction `d` (per DIRECTION_DELTAS_INT) as blocked on map `mid`.

        An unknown direction uses a (0, 0) offset, which marks (x, y) itself.
        """
        step_x, step_y = self.DIRECTION_DELTAS_INT.get(d, (0, 0))
        blocked = (int(x + step_x), int(y + step_y))
        self.get_current_map_memory(mid)['obstructions'].add(blocked)

    def merge_taught_exploration(self, fp):
        """Merge exploration memory saved by a taught session at `fp` into this agent's memory.

        For each map in the file:
        - visited tiles and obstructions are unioned;
        - transitions are appended unless one with the same position and
          direction already exists;
        - interactable objects are appended when not already listed.

        Per-tile interaction records are not merged. A missing file is reported
        and ignored, and other errors are printed, not raised.
        """
        if not Path(fp).exists():
            print(f"  No taught exploration at {fp}")
            return

        def as_pos(p):
            return tuple(p) if isinstance(p, list) else p

        try:
            with open(fp, 'r') as f:
                taught = json.load(f)
            added_transitions = added_interactables = 0
            for map_key, raw in taught.items():
                map_id = int(map_key.replace('map_', ''))
                incoming = self._deserialize_map_memory(raw)
                mine = self.get_current_map_memory(map_id)
                mine['visited_tiles'].update(incoming['visited_tiles'])
                mine['obstructions'].update(incoming['obstructions'])
                for tr in incoming.get('transitions', []):
                    pos = as_pos(tr['position'])
                    duplicate = any(as_pos(e['position']) == pos and e['direction'] == tr['direction']
                                    for e in mine['transitions'])
                    if duplicate:
                        continue
                    mine['transitions'].append(tr)
                    added_transitions += 1
                for obj in incoming.get('interactable_objects', []):
                    if obj in mine['interactable_objects']:
                        continue
                    mine['interactable_objects'].append(obj)
                    added_interactables += 1
            print(f"  Merged: {added_transitions} transitions, {added_interactables} interactables")
        except Exception as e:
            print(f"  Error merging: {e}")

    # =========================================================================
    # TILE INTERACTION — with overworld gate (FIX)
    # =========================================================================
    def get_tile_interaction_key(self, x, y):
        """Return the string key "<x>_<y>" (coordinates truncated to int) used for per-tile interaction records."""
        return "_".join((str(int(x)), str(int(y))))
    def get_tile_interaction_state(self, x, y, mid):
        """Return (creating on first use) the interaction record for tile (x, y) of map `mid`.

        A new record has no tried directions, zero attempts and successes for
        directions 0-3, and is not exhausted.
        """
        tiles = self.get_current_map_memory(mid)['tile_interactions']
        key = self.get_tile_interaction_key(x, y)
        record = tiles.get(key)
        if record is None:
            record = {
                'directions_tried': set(),
                'direction_attempts': {d: 0 for d in range(4)},
                'direction_successes': {d: 0 for d in range(4)},
                'exhausted': False,
            }
            tiles[key] = record
        return record
    def should_interact_at_tile(self, x, y, mid):
        """Decide whether tile (x, y) of map `mid` is still worth probing with A.

        Returns False for exhausted tiles and True while some direction is
        untried. Once all four are tried, returns True only if some direction's
        success rate reaches MIN_SUCCESS_RATE_THRESHOLD.
        """
        ts = self.get_tile_interaction_state(x, y, mid)
        if ts['exhausted']:
            return False
        if len(ts['directions_tried']) < 4:
            return True
        for d in range(4):
            attempts = ts['direction_attempts'].get(d, 0)
            if attempts > 0 and ts['direction_successes'].get(d, 0) / attempts >= self.MIN_SUCCESS_RATE_THRESHOLD:
                return True
        return False
    def get_untried_directions(self, x, y, mid):
        """List the directions 0-3, in ascending order, not yet tried at tile (x, y) of map `mid`."""
        tried = self.get_tile_interaction_state(x, y, mid)['directions_tried']
        return [d for d in range(4) if d not in tried]
    def get_best_interaction_direction(self, x, y, mid):
        """Pick the direction to interact in at tile (x, y) of map `mid`.

        Returns the first untried direction if there is one. Otherwise returns
        the direction with the highest strictly positive success rate (earliest
        on ties), or None if no direction has ever succeeded.
        """
        ts = self.get_tile_interaction_state(x, y, mid)
        untried = self.get_untried_directions(x, y, mid)
        if untried:
            return untried[0]
        best_dir, best_rate = None, 0.0
        for d in range(4):
            attempts = ts['direction_attempts'].get(d, 0)
            if attempts <= 0:
                continue
            rate = ts['direction_successes'].get(d, 0) / attempts
            if rate > best_rate:
                best_dir, best_rate = d, rate
        return best_dir
    def get_best_probe_action(self, rx, ry, cm, cd):
        """Choose the next probing step at tile (rx, ry) of map `cm` while facing `cd`.

        Returns (action_name, direction):
        - ('A', cd) when the current facing is an untried direction, or is the
          best-known direction once all are tried;
        - (INT_TO_ACTION[d], d) to turn toward a better direction d;
        - (None, None) when the tile is not worth probing or no direction works.

        The result is cached for the most recent (rx, ry, cm, cd) key only.
        This method never invalidates that cache itself.
        """
        key = (rx, ry, cm, cd)
        if self._probe_cache_position == key:
            return self._cached_probe_action, self._cached_probe_dir
        if not self.should_interact_at_tile(rx, ry, cm):
            result = (None, None)
        else:
            untried = self.get_untried_directions(rx, ry, cm)
            if untried:
                if cd in untried:
                    result = ('A', cd)
                else:
                    result = (self.INT_TO_ACTION[untried[0]], untried[0])
            else:
                best = self.get_best_interaction_direction(rx, ry, cm)
                if best is None:
                    result = (None, None)
                elif best == cd:
                    result = ('A', cd)
                else:
                    result = (self.INT_TO_ACTION[best], best)
        self._probe_cache_position = key
        self._cached_probe_action, self._cached_probe_dir = result
        return result
    def record_tile_interaction_attempt(self, x, y, mid, d, success):
        """Record one interaction attempt at tile (x, y) of map `mid` in direction `d`.

        Marks `d` as tried and bumps its attempt counter. On success it also
        bumps the success counter and registers [x, y, direction_name] as an
        interactable object of the map (printed once). Afterwards it re-checks
        whether the tile is exhausted.

        The probe-decision cache used by `get_best_probe_action` is keyed only
        on (x, y, map, facing), so this method invalidates it. Otherwise a
        decision computed before this attempt (e.g. "press A facing d") would
        keep being returned after `d` has already been tried.
        """
        ts = self.get_tile_interaction_state(x, y, mid)
        ts['directions_tried'].add(d)
        ts['direction_attempts'][d] = ts['direction_attempts'].get(d, 0) + 1
        if success:
            ts['direction_successes'][d] = ts['direction_successes'].get(d, 0) + 1
            m = self.get_current_map_memory(mid)
            dn = self.DIRECTION_NAMES.get(d, str(d))
            io = [int(x), int(y), dn]
            if io not in m['interactable_objects']:
                m['interactable_objects'].append(io)
                print(f"  üéØ INTERACTABLE: ({x},{y}) {dn}")
        self._check_tile_exhaustion(x, y, mid)
        # FIX: tile state changed, so any cached probe decision may now be stale.
        self._probe_cache_position = None
    def _check_tile_exhaustion(self, x, y, mid):
        ts = self.get_tile_interaction_state(x, y, mid)
        if len(ts['directions_tried']) >= 4 and not any(ts['direction_successes'].get(d,0) > 0 for d in range(4)):
            ts['exhausted'] = True
    def start_interaction_verification(self, x, y, mid, d):
        """Begin watching for the effect of an interaction at (x, y) on map `mid` in direction `d`.

        The result is checked for up to INTERACTION_VERIFY_FRAMES calls of
        `check_interaction_verification`.
        """
        self.pending_interaction_verify = dict(x=x, y=y, map_id=mid, direction=d)
        self.interaction_verify_countdown = self.INTERACTION_VERIFY_FRAMES
    def check_interaction_verification(self, cs, pcs):
        """Resolve a pending interaction check started by `start_interaction_verification`.

        Success is only counted when the previous context `pcs` was in the
        overworld (pcs[3] <= 0.5 and pcs[4] <= 0.5). It then requires one of:
        cs[4] moved by more than 0.1, cs[3] rose above 0.5, or the map id
        (cs[2]) changed. The attempt is recorded as soon as it succeeds, or as
        a failure when the countdown runs out.
        """
        pending = self.pending_interaction_verify
        if pending is None:
            return
        self.interaction_verify_countdown -= 1
        success = False
        # Overworld gate: an A press made inside a menu/battle must not count as a hit.
        if pcs is not None and pcs[3] <= 0.5 and pcs[4] <= 0.5:
            success = (abs(cs[4] - pcs[4]) > 0.1
                       or (cs[3] > 0.5 and pcs[3] <= 0.5)
                       or int(cs[2]) != int(pcs[2]))
        if not success and self.interaction_verify_countdown > 0:
            return
        self.record_tile_interaction_attempt(pending['x'], pending['y'], pending['map_id'], pending['direction'], success)
        self.pending_interaction_verify = None
    def get_tile_interaction_stats(self, mid):
        """Summarize tile probing on map `mid`.

        Returns counts of probed tiles, exhausted tiles, and tiles with at
        least one successful direction.
        """
        records = self.get_current_map_memory(mid).get('tile_interactions', {}).values()
        probed = exhausted = with_success = 0
        for t in records:
            probed += 1
            if t.get('exhausted', False):
                exhausted += 1
            succ = t.get('direction_successes', {})
            if any(succ.get(d, 0) > 0 for d in range(4)):
                with_success += 1
        return {'probed': probed, 'exhausted': exhausted, 'with_success': with_success}
    def get_exploration_coverage(self, mid):
        """Return visited / (visited + obstructed) tiles for map `mid`.

        Returns 0.0 until at least one tile is visited and at least 10 tiles
        are known in total.
        """
        memory = self.get_current_map_memory(mid)
        visited = len(memory['visited_tiles'])
        total = visited + len(memory['obstructions'])
        if visited <= 0 or total < 10:
            return 0.0
        return visited / total

    # =========================================================================
    # TRANSITIONS, BANS, DEBT
    # =========================================================================
    def record_transition(self, fp, fm, tm, d, at):
        """Record leaving map `fm` from position `fp` in direction `d` into map `tm`.

        A transition already known at the same position and direction gets its
        `use_count` bumped and `last_used` refreshed. Otherwise a new entry with
        action type `at` is stored and announced.
        """
        memory = self.get_current_map_memory(fm)
        pos_list = list(fp)
        for known in memory['transitions']:
            if known['position'] == pos_list and known['direction'] == d:
                known['use_count'] += 1
                known['last_used'] = self.timestep
                return
        entry = {'position': pos_list, 'direction': d, 'action': at, 'destination_map': tm,
                 'use_count': 1, 'last_used': self.timestep}
        memory['transitions'].append(entry)
        print(f"  üö™ TRANSITION: Map {fm} ({fp}) ‚Üí Map {tm}")
    def get_transition_attraction(self, cm):
        """Score the known, non-banned exits of map `cm` and return the most attractive one.

        For each exit the score is
            0.5 * (debt_here + 2*temp_here - debt_there - 2*temp_there)
          + 0.5 * (coverage_here - coverage_there),
        multiplied by 1.5 while the exit has been used fewer than 3 times.

        Returns:
            (best_score * TRANSITION_ATTRACTION_WEIGHT, transition). Only
            strictly positive scores are considered, so the transition is None
            when no exit beats staying on the current map.
        """
        transitions = self.get_current_map_memory(cm).get('transitions', [])
        if not transitions:
            return 0.0, None
        here_debt = self.map_novelty_debt.get(cm, 0.0)
        here_temp = self.get_temp_debt(cm)
        here_cov = self.get_exploration_coverage(cm)
        best_score, best_exit = 0.0, None
        for exit_ in transitions:
            if self.is_transition_banned(cm, exit_['position'], exit_['direction']):
                continue
            dest = exit_['destination_map']
            there_debt = self.map_novelty_debt.get(dest, 0.0)
            there_temp = self.get_temp_debt(dest)
            there_cov = self.get_exploration_coverage(dest)
            score = (here_debt + here_temp * 2.0 - there_debt - there_temp * 2.0) * 0.5 + (here_cov - there_cov) * 0.5
            if exit_['use_count'] < 3:
                score *= 1.5
            if score > best_score:
                best_score, best_exit = score, exit_
        return best_score * self.TRANSITION_ATTRACTION_WEIGHT, best_exit
    def create_transition_ban(self, mid, tp, db):
        """Ban using tile `tp` in direction `db` on map `mid`, replacing any existing ban for that map.

        The vicinity part of the ban starts inactive and uses BAN_VICINITY_RADIUS.
        """
        ban = dict(banned_tile=tp, banned_direction=db, vicinity_radius=self.BAN_VICINITY_RADIUS,
                   vicinity_active=False, created_at=self.timestep)
        self.transition_bans[mid] = ban
    def is_transition_banned(self, mid, pos, d):
        """Return True if moving in direction `d` from `pos` is banned on map `mid`.

        The banned tile itself is always banned in the banned direction. Once
        the vicinity is active, every tile within `vicinity_radius` (Manhattan
        distance) is banned in that direction too. Lists and tuples are
        accepted for positions.
        """
        if mid not in self.transition_bans:
            return False
        ban = self.transition_bans[mid]

        def as_tuple(p):
            return tuple(p) if isinstance(p, list) else p

        banned_tile = as_tuple(ban['banned_tile'])
        pos = as_tuple(pos)
        same_direction = d == ban['banned_direction']
        if pos == banned_tile and same_direction:
            return True
        if ban['vicinity_active']:
            distance = abs(pos[0] - banned_tile[0]) + abs(pos[1] - banned_tile[1])
            if distance <= ban['vicinity_radius'] and same_direction:
                return True
        return False
    def is_position_banned(self, mid, x, y, d):
        """Coordinate-based wrapper around `is_transition_banned`."""
        position = (x, y)
        return self.is_transition_banned(mid, position, d)
    def update_transition_ban(self, mid, cp):
        """Activate the vicinity part of map `mid`'s ban once position `cp` is at least 3 tiles (Manhattan) from the banned tile."""
        if mid not in self.transition_bans:
            return
        ban = self.transition_bans[mid]
        if ban['vicinity_active']:
            return
        tile = ban['banned_tile']
        tile = tuple(tile) if isinstance(tile, list) else tile
        if abs(cp[0] - tile[0]) + abs(cp[1] - tile[1]) >= 3:
            ban['vicinity_active'] = True
    def check_ban_lift_conditions(self, mid):
        """Drop map `mid`'s transition ban when any of these holds:

        - an un-banned exit is known on that map;
        - its coverage reaches BAN_COVERAGE_LIFT_THRESHOLD;
        - the ban is older than BAN_TIMEOUT_STEPS.
        """
        if mid not in self.transition_bans:
            return
        ban = self.transition_bans[mid]
        known = self.get_current_map_memory(mid).get('transitions', [])
        open_exits = [t for t in known if not self.is_transition_banned(mid, t['position'], t['direction'])]
        if open_exits:
            lift = True
        elif self.get_exploration_coverage(mid) >= self.BAN_COVERAGE_LIFT_THRESHOLD:
            lift = True
        else:
            lift = self.timestep - ban['created_at'] >= self.BAN_TIMEOUT_STEPS
        if lift:
            del self.transition_bans[mid]
    def get_temp_debt(self, mid):
        """Return map `mid`'s temporary debt.

        The current map returns the stored value. Any other map returns it
        decayed linearly by TEMP_DEBT_DECAY per timestep since its last visit,
        floored at 0. The stored value itself is not modified.
        """
        memory = self.get_current_map_memory(mid)
        stored = memory.get('temp_debt', 0.0)
        if mid == self.current_map_id:
            return stored
        elapsed = self.timestep - memory.get('last_visited_timestep', 0)
        return max(0.0, stored - elapsed * self.TEMP_DEBT_DECAY)
    def accumulate_temp_debt(self, mid):
        """Add TEMP_DEBT_ACCUMULATION to map `mid`'s temporary debt, capped at TEMP_DEBT_MAX."""
        memory = self.get_current_map_memory(mid)
        grown = memory.get('temp_debt', 0.0) + self.TEMP_DEBT_ACCUMULATION
        memory['temp_debt'] = min(self.TEMP_DEBT_MAX, grown)
    def decay_all_debts(self):
        """Shrink the novelty debt of every map except the current one by DEBT_DECAY_RATE, dropping debts below 0.1."""
        keep_factor = 1.0 - self.DEBT_DECAY_RATE
        for mid in list(self.map_novelty_debt):
            if mid == self.current_map_id:
                continue
            decayed = self.map_novelty_debt[mid] * keep_factor
            if decayed < 0.1:
                del self.map_novelty_debt[mid]
            else:
                self.map_novelty_debt[mid] = decayed
    def detect_obstruction(self, pc, cs, rp, prp):
        """Record an obstruction when a directional move left the raw position unchanged.

        Args:
            pc: previous context vector (None disables detection).
            cs: current context vector; cs[2] is the map id and cs[5] the facing direction.
            rp: current raw (x, y) position.
            prp: previous raw (x, y) position (None disables detection).

        Returns:
            True when an obstruction was recorded in front of `rp`, else False.

        A press that only changed the facing direction (cs[5] != pc[5]) while
        the position stayed put is treated as turning in place, not as being
        blocked. The rest of the class treats a facing change alone as an
        executed move as well. No obstruction is recorded in that case; a
        repeated press in the same facing still detects a real wall.
        """
        if pc is None or prp is None or self.last_action not in ['UP', 'DOWN', 'LEFT', 'RIGHT']:
            return False
        if rp != prp:
            return False
        # FIX: turning in place is not evidence of a blocked tile.
        if int(cs[5]) != int(pc[5]):
            return False
        self.record_obstruction(rp[0], rp[1], int(cs[2]), int(cs[5]))
        return True

    # =========================================================================
    # MENU TRAP
    # =========================================================================
    def update_menu_trap_tracking(self, cs, at, rp=None):
        """Track repeated A/B/Start/Select presses that leave the context unchanged, and grow a B boost.

        A "trap" frame is one where `get_context_state_hash(cs)` equals
        `last_context_state_hash` and `at` is A, B, Start or Select. Each such
        frame:
        - increments `menu_trap_frames` and pins `menu_trap_position`;
        - once past MENU_TRAP_THRESHOLD frames, records B's utility once in
          `original_b_utility`;
        - grows `menu_trap_b_boost` by B_BOOST_INCREMENT, up to B_BOOST_MAX.

        This method only updates the boost value; it does not itself change
        B's utility. Leaving the pinned position resets all trap state via
        `reset_menu_trap_boost`.

        Args:
            cs: current context vector; round(cs[0]*255), round(cs[1]*255) is
                used as the position when `rp` is not given.
            at: name of the action just taken.
            rp: optional raw (x, y) position.
        """
        cp = rp if rp else (round(cs[0]*255), round(cs[1]*255))
        # Moved away from where the trap was detected: clear all trap state.
        if self.menu_trap_position is not None and cp != self.menu_trap_position: self.reset_menu_trap_boost(); return
        # NOTE(review): last_context_state_hash is also overwritten by check_state_stagnation,
        # so this comparison depends on whether that method already ran for the current frame.
        if self.get_context_state_hash(cs) == self.last_context_state_hash and at in ["A","B","Start","Select"]:
            self.menu_trap_frames += 1; self.menu_trap_position = cp
            if self.menu_trap_frames > self.MENU_TRAP_THRESHOLD:
                # Remember B's utility once so reset_menu_trap_boost can restore it later.
                if self.original_b_utility is None:
                    for a in self.actions():
                        if a.action == 'B': self.original_b_utility = a.utility; break
                self.menu_trap_b_boost = min(self.B_BOOST_MAX, self.menu_trap_b_boost + self.B_BOOST_INCREMENT)
        # Not a trap frame: reset unless we are still standing on the pinned trap position
        # (with no trap tracked, menu_trap_position is None and this always resets).
        elif cp != self.menu_trap_position: self.reset_menu_trap_boost()
    def reset_menu_trap_boost(self):
        """Clear menu-trap state and, if a boost was active, restore B's remembered utility."""
        if self.menu_trap_b_boost > 1.0 and self.original_b_utility is not None:
            b_action = next((a for a in self.actions() if a.action == 'B'), None)
            if b_action is not None:
                b_action.utility = self.original_b_utility
        self.menu_trap_frames = 0
        self.menu_trap_b_boost = 1.0
        self.menu_trap_position = None
        self.original_b_utility = None

    # =========================================================================
    # STAGNATION, PATTERN, MODE (all NEW or fixed)
    # =========================================================================
    def get_context_state_hash(self, cs):
        """Reduce a context vector to a hashable tuple.

        cs[0], cs[1] and cs[4] are rounded to 2 decimals; cs[2], cs[3] and
        cs[5] are truncated to int.
        """
        c0, c1, c2, c3, c4, c5 = (cs[i] for i in range(6))
        return (round(c0, 2), round(c1, 2), int(c2), int(c3), round(c4, 2), int(c5))
    def check_state_stagnation(self, cs):
        """Count consecutive frames with an identical context hash.

        On the first repeated frame the last action is remembered as the
        stagnation initiator. Returns True once the count reaches
        STATE_STAGNATION_THRESHOLD.
        """
        current = self.get_context_state_hash(cs)
        if current != self.last_context_state_hash:
            self.state_stagnation_count = 0
            self.stagnation_initiator_action = None
        else:
            self.state_stagnation_count += 1
            if self.state_stagnation_count == 1 and self.last_action:
                self.stagnation_initiator_action = self.last_action
        self.last_context_state_hash = current
        return self.state_stagnation_count >= self.STATE_STAGNATION_THRESHOLD
    def apply_stagnation_initiator_penalty(self):
        """Halve the utility of the action that started the current stagnation (respecting its group floor), then forget it."""
        culprit = self.stagnation_initiator_action
        if culprit is None:
            return
        target = next((a for a in self.actions() if a.action == culprit), None)
        if target is not None:
            floor = self.INTERACT_UTILITY_FLOOR if target.group == "interact" else self.MOVE_UTILITY_FLOOR
            target.utility = max(floor, target.utility * 0.5)
        self.stagnation_initiator_action = None
    def should_force_random(self):
        """Return True when the agent looks stuck enough to need a forced random action.

        Any one of these triggers it:
        - position stagnation >= 8;
        - 15+ consecutive identical actions;
        - a detected pattern repeated 4+ times;
        - state stagnation at twice STATE_STAGNATION_THRESHOLD.

        When triggered, `try_blend_if_needed` is called first.
        """
        force = self.get_position_stagnation() >= 8
        if not force:
            force = self.consecutive_action_count >= 15
        if not force:
            force = bool(self.detected_pattern) and self.pattern_repeat_count >= 4
        if not force:
            force = self.state_stagnation_count >= self.STATE_STAGNATION_THRESHOLD * 2
        if force:
            self.try_blend_if_needed()
        return force
    def get_forced_random_action_name(self):
        """Pick a random action among UP/DOWN/LEFT/RIGHT/A/B.

        The currently repeated action and any action in the detected pattern
        are excluded. If that excludes everything, the choice falls back to
        the four directions.
        """
        pattern = self.detected_pattern
        candidates = [
            name for name in ["UP", "DOWN", "LEFT", "RIGHT", "A", "B"]
            if name != self.current_repeated_action and not (pattern and name in pattern)
        ]
        if not candidates:
            candidates = ["UP", "DOWN", "LEFT", "RIGHT"]
        return random.choice(candidates)
    def check_productive_change(self, cs):
        """Detect whether something meaningful changed since the last call.

        Returns (productive, reason). Later checks overwrite the reason of
        earlier ones:
        - "map change" when cs[2] (map id) changed;
        - "battle change" when the battle flag (cs[3] > 0.5) flipped;
        - "moved N" when (cs[0], cs[1]) is more than 0.03 away from
          `position_at_mode_swap`.

        A change of facing (cs[5]) also lowers `state_stagnation_count` by 5
        when `direction_change_counts_as_progress` is set, but is not reported
        as productive. The method then stores the current map, battle flag and
        facing for the next call.

        NOTE(review): `position_at_mode_swap` is not re-anchored here. Once the
        agent is more than 0.03 away from it, every call reports "moved ..."
        until a mode swap elsewhere resets the anchor.
        """
        map_id = int(cs[2])
        in_battle = cs[3] > 0.5
        pos = (cs[0], cs[1])
        productive, reason = False, ""
        if self.last_map_id is not None and map_id != self.last_map_id:
            productive, reason = True, "map change"
        if self.last_battle_state is not None and in_battle != self.last_battle_state:
            productive, reason = True, "battle change"
        anchor = self.position_at_mode_swap
        if anchor is not None:
            dist = np.sqrt((pos[0] - anchor[0]) ** 2 + (pos[1] - anchor[1]) ** 2)
            if dist > 0.03:
                productive, reason = True, f"moved {dist*255:.1f}"
        facing = int(cs[5])
        turned = self.last_direction_for_progress is not None and facing != self.last_direction_for_progress
        if self.direction_change_counts_as_progress and turned:
            self.state_stagnation_count = max(0, self.state_stagnation_count - 5)
        self.last_direction_for_progress = facing
        self.last_map_id = map_id
        self.last_battle_state = in_battle
        return productive, reason
    def on_productive_change(self, r):
        """Reset swap thresholds, stagnation/swap counters and a positive blend tier after real progress.

        `r` (the reason string) is accepted but not used.
        """
        self.move_to_interact_threshold, self.interact_to_move_threshold = (
            self.DEFAULT_MOVE_TO_INTERACT_THRESHOLD, self.DEFAULT_INTERACT_TO_MOVE_THRESHOLD)
        self.swap_chain_count = 0
        self.state_stagnation_count = 0
        self.stagnation_initiator_action = None
        self.unproductive_swap_count = 0
        self.blend_tier = 0 if self.blend_tier > 0 else self.blend_tier
    def on_mode_swap(self, fm, tm):
        """Bookkeeping for a swap into mode `tm`.

        Bumps the swap counters and restarts the mode frame counter. After
        UNPRODUCTIVE_SWAP_THRESHOLD unproductive swaps it knocks down the
        top-ranked action of `tm`. It also raises the exit threshold of the
        mode being entered, capped at MAX_THRESHOLD. `fm` is accepted but not
        used.
        """
        self.swap_chain_count += 1
        self.frames_in_current_mode = 0
        self.unproductive_swap_count += 1
        if self.unproductive_swap_count >= self.UNPRODUCTIVE_SWAP_THRESHOLD:
            self._reset_highest_to_third(tm)
            self.unproductive_swap_count = 0
        if tm != "interact":
            self.move_to_interact_threshold = min(self.MAX_THRESHOLD, self.move_to_interact_threshold + self.THRESHOLD_INCREMENT)
        else:
            self.interact_to_move_threshold = min(self.MAX_THRESHOLD, self.interact_to_move_threshold + self.THRESHOLD_INCREMENT)
    def _reset_highest_to_third(self, mode):
        if mode in ["battle","both"]: return
        g = "move" if mode == "move" else "interact"; ga = sorted([a for a in self.actions() if a.group == g], key=lambda a: a.utility, reverse=True)
        if len(ga) >= 3:
            fl = self.INTERACT_UTILITY_FLOOR if g == "interact" else self.MOVE_UTILITY_FLOOR
            ga[0].utility = max(ga[2].utility * 0.9, fl)
    def should_use_both_mode(self):
        """True when state stagnation or unproductive swaps exceed their "both"-mode thresholds."""
        if self.state_stagnation_count > self.BOTH_MODE_STAGNATION_THRESHOLD:
            return True
        return self.unproductive_swap_count > self.BOTH_MODE_SWAP_THRESHOLD
    def determine_control_mode(self, cs, raw_position=None):
        """Choose this frame's control mode: "battle", "both", "move" or "interact".

        Order of checks:
        1. "battle" whenever cs[3] > 0.5. The stored `control_mode` is not touched.
        2. Count the frame and handle productive changes, which reset the
           thresholds and counters.
        3. "both" when `should_use_both_mode()`. Not stored.
        4. On state stagnation: penalize the initiating action and flip between
           move/interact.
        5. In "move", switch to "interact" after 3+ frames when the current tile
           still has untried directions worth probing.
        6. Threshold-based swaps: move -> interact when position stagnation
           reaches `move_to_interact_threshold`; interact -> move when the tile
           is done (after 5 frames) or after `interact_to_move_threshold` frames.

        Args:
            cs: context vector (cs[0], cs[1] position, cs[2] map id, cs[3] battle flag).
            raw_position: optional raw (x, y) tile position; otherwise int(cs*255) is used.

        Returns:
            The selected mode name.
        """
        if cs[3] > 0.5:
            return "battle"
        self.frames_in_current_mode += 1
        ps = self.get_position_stagnation()
        p, r = self.check_productive_change(cs)
        if p:
            self.on_productive_change(r)
        if self.should_use_both_mode():
            return "both"
        if self.check_state_stagnation(cs):
            self.apply_stagnation_initiator_penalty()
            prev_mode = self.control_mode
            nm = "interact" if prev_mode == "move" else "move"
            self.control_mode = nm
            self.position_at_mode_swap = (cs[0], cs[1])
            # FIX: report the mode being left; previously the already-updated mode was passed as "from".
            self.on_mode_swap(prev_mode, nm)
            self.state_stagnation_count = 0
            return self.control_mode
        rx = raw_position[0] if raw_position else int(cs[0] * 255)
        ry = raw_position[1] if raw_position else int(cs[1] * 255)
        cm = int(cs[2])
        tp = self.should_interact_at_tile(rx, ry, cm)
        ud = self.get_untried_directions(rx, ry, cm)
        # Tile still has untried directions: probe it (not counted as a mode swap).
        if tp and ud and self.control_mode == "move" and self.frames_in_current_mode >= 3:
            self.control_mode = "interact"
            self.position_at_mode_swap = (cs[0], cs[1])
            self.frames_in_current_mode = 0
            return self.control_mode
        if self.control_mode == "move" and ps >= self.move_to_interact_threshold:
            self.control_mode = "interact"
            self.position_at_mode_swap = (cs[0], cs[1])
            self.on_mode_swap("move", "interact")
        elif self.control_mode == "interact":
            if (not tp or not ud) and self.frames_in_current_mode >= 5:
                # Nothing left to probe here: go back to moving (not counted as a mode swap).
                self.control_mode = "move"
                self.position_at_mode_swap = (cs[0], cs[1])
                self.frames_in_current_mode = 0
            elif self.frames_in_current_mode >= self.interact_to_move_threshold:
                self.control_mode = "move"
                self.position_at_mode_swap = (cs[0], cs[1])
                self.on_mode_swap("interact", "move")
        return self.control_mode

    # =========================================================================
    # REPETITION & PATTERN
    # =========================================================================
    def track_consecutive_action(self, an):
        """Update the run length of the action `an` being repeated back-to-back."""
        if an != self.current_repeated_action:
            self.current_repeated_action = an
            self.consecutive_action_count = 1
            return
        self.consecutive_action_count += 1
    def get_learning_multiplier(self, an):
        """Learning-rate multiplier for action `an`.

        Returns 1.0 unless `an` is the action being repeated and its run has
        reached LEARNING_SLOWDOWN_START. From there it falls linearly towards
        0.05, which it reaches at LEARNING_SLOWDOWN_MAX repeats.
        """
        start = self.LEARNING_SLOWDOWN_START
        if an != self.current_repeated_action or self.consecutive_action_count < start:
            return 1.0
        progress = min(1.0, (self.consecutive_action_count - start) / (self.LEARNING_SLOWDOWN_MAX - start))
        return max(0.05, 1.0 - 0.95 * progress)
    def detect_pattern(self):
        """Find a short action loop at the end of `action_history`.

        Needs at least 6 recorded actions. Looks only at the last
        PATTERN_CHECK_WINDOW actions and tries loop lengths 1..PATTERN_MAX_LENGTH
        in order. For each length, it counts how many times the trailing
        sequence repeats back-to-back.

        Returns:
            (pattern_tuple, repeat_count) for the first length repeated at
            least PATTERN_MIN_REPEATS times, else (None, 0).
        """
        if len(self.action_history) < 6:
            return None, 0
        window = list(self.action_history)[-self.PATTERN_CHECK_WINDOW:]
        for length in range(1, self.PATTERN_MAX_LENGTH + 1):
            if len(window) < length * self.PATTERN_MIN_REPEATS:
                continue
            tail = tuple(window[-length:])
            repeats = 0
            for start in range(len(window) - length, -1, -length):
                if tuple(window[start:start + length]) != tail:
                    break
                repeats += 1
            if repeats >= self.PATTERN_MIN_REPEATS:
                return tail, repeats
        return None, 0
    def apply_pattern_penalty(self):
        """Refresh the detected action loop and damp the utility of every action in it.

        The damping factor is max(0.3, 1 - 0.15 * repeats), and each action's
        group floor is respected. With no loop, the stored pattern state is
        cleared.
        """
        pattern, repeats = self.detect_pattern()
        if pattern is None:
            self.detected_pattern = None
            self.pattern_repeat_count = 0
            return
        self.detected_pattern = pattern
        self.pattern_repeat_count = repeats
        factor = max(0.3, 1.0 - repeats * 0.15)
        for name in set(pattern):
            target = next((a for a in self.actions() if a.action == name), None)
            if target is None:
                continue
            floor = self.INTERACT_UTILITY_FLOOR if target.group == "interact" else self.MOVE_UTILITY_FLOOR
            target.utility = max(floor, target.utility * factor)
    def apply_repetition_penalty(self):
        """Penalize the action being repeated back-to-back.

        At HARD_RESET_THRESHOLD repeats its utility drops to the group floor
        and the run counter is reset. At PENALTY_THRESHOLD repeats its utility
        is halved (not below the floor).
        """
        repeated = self.current_repeated_action
        if repeated is None:
            return
        target = next((a for a in self.actions() if a.action == repeated), None)
        if target is None:
            return
        floor = self.INTERACT_UTILITY_FLOOR if target.group == "interact" else self.MOVE_UTILITY_FLOOR
        count = self.consecutive_action_count
        if count >= self.HARD_RESET_THRESHOLD:
            target.utility = floor
            self.consecutive_action_count = 0
        elif count >= self.PENALTY_THRESHOLD:
            target.utility = max(target.utility * 0.5, floor)

    # =========================================================================
    # EXPLORATION TRACKING
    # =========================================================================
    def update_exploration_tracking(self, cs, pcs, rp=None, prp=None):
        """Per-frame bookkeeping of exploration memory for context `cs`.

        Args:
            cs: current context vector. cs[2] is the map id and cs[5] the facing
                direction; int(cs[0]*255), int(cs[1]*255) is the tile position
                when `rp` is not given.
            pcs: previous context vector, or None.
            rp: current raw (x, y) tile position, if available.
            prp: previous raw (x, y) tile position, if available.

        On a map change it:
        - records the transition taken out of the previous map;
        - creates a transition ban at the arrival tile of the new map;
        - calls `on_map_change`.

        Every frame it then:
        - marks the tile visited and accumulates temporary debt;
        - updates and possibly lifts the transition ban;
        - checks for an obstruction and for a pending interaction result;
        - every 300 timesteps, decays novelty debts.
        """
        cm = int(cs[2]); rx = rp[0] if rp else int(cs[0]*255); ry = rp[1] if rp else int(cs[1]*255); cp = (rx, ry)
        # Map changed since the last frame.
        if self.current_map_id is not None and cm != self.current_map_id:
            pm = self.current_map_id
            if pcs is not None and prp is not None:
                # Exit is tagged 'interact' when the last action was A, otherwise 'walk'.
                self.record_transition(prp, pm, cm, int(pcs[5]), 'interact' if self.last_action == 'A' else 'walk')
            if prp is not None:
                # NOTE(review): (ed + 2) % 4 presumably means "opposite of the facing on arrival",
                # i.e. don't walk straight back out — depends on the direction encoding; confirm.
                # Falls back to direction 0 when there is no previous context.
                ed = int(cs[5]) if pcs is not None else 0
                self.create_transition_ban(cm, cp, (ed + 2) % 4)
            self.on_map_change(cm)
        self.current_map_id = cm; self.record_visited_tile(rx, ry, cm); self.accumulate_temp_debt(cm)
        self.update_transition_ban(cm, cp); self.check_ban_lift_conditions(cm)
        if pcs is not None and prp is not None: self.detect_obstruction(pcs, cs, rp, prp)
        self.check_interaction_verification(cs, pcs); self.last_direction = int(cs[5])
        # Periodic decay of per-map novelty debt.
        if self.timestep % 300 == 0: self.decay_all_debts()
    def on_map_change(self, nm):
        """React to entering map `nm`.

        Saves exploration memory, returns to "move" mode with a fresh frame
        counter, and logs what is already known about `nm`.
        """
        self.save_exploration_memory()
        self.control_mode = "move"
        self.frames_in_current_mode = 0
        memory = self.get_current_map_memory(nm)
        visited, obstacles = len(memory['visited_tiles']), len(memory['obstructions'])
        print(f"  üó∫Ô∏è MAP CHANGE ‚Üí {nm}: {visited} visited, {obstacles} obs")

    # =========================================================================
    # ENTITY & LEARNING
    # =========================================================================
    def spawn_innate_entities(self, ls):
        """Create the four built-in "entity" perceptrons once, sized to the learning state `ls`.

        Each entity's weights are zero except at its listed state indices:
        0.5 each for multi-index entities, 1.0 for single-index ones. Indices
        beyond the weight length are skipped.
        """
        if self.innate_entities_spawned:
            return
        innate = (
            ("sense_menu", [5, 6]),
            ("sense_battle", [3, 4]),
            ("sense_movement", [0, 1]),
            ("sense_map_transition", [2]),
        )
        dim = len(ls)
        for entity_type, indices in innate:
            entity = Perceptron("entity", entity_type=entity_type)
            entity.ensure_weights(dim)
            entity.weights = np.zeros(dim)
            value = 0.5 if len(indices) > 1 else 1.0
            for i in indices:
                if i < len(entity.weights):
                    entity.weights[i] = value
            self.add(entity)
        self.innate_entities_spawned = True
    def enforce_utility_floors(self):
        """Raise every action's utility to at least its group floor.

        "move" actions use MOVE_UTILITY_FLOOR and all other groups use
        INTERACT_UTILITY_FLOOR.

        NOTE(review): other methods in this class default the other way
        (non-"interact" -> MOVE floor). The two only differ if a third action
        group exists.
        """
        for action in self.actions():
            floor = self.MOVE_UTILITY_FLOOR if action.group == "move" else self.INTERACT_UTILITY_FLOOR
            if action.utility < floor:
                action.utility = floor
            else:
                action.utility = max(action.utility, floor)
    def stagnation_level(self, w=10):
        """Return 1 - tanh(2 * mean step size) over the last `w` learning states.

        The result is near 1 when states barely change and near 0 when they
        change a lot. Consecutive states are compared over their common prefix
        length. Returns 0.0 until `w` states are available.
        """
        if len(self.prev_learning_states) < w:
            return 0.0
        window = list(self.prev_learning_states)[-w:]
        step_sizes = []
        for prev, cur in zip(window, window[1:]):
            n = min(len(cur), len(prev))
            step_sizes.append(np.linalg.norm(cur[:n] - prev[:n]))
        return 1.0 - np.tanh(np.mean(step_sizes) * 2.0)
    def predict_future_error(self, st, ac, cs, rp=None):
        """Estimate the learning value of taking action `ac` from state `st`.

        Starts from 0.7 * (mean entity prediction x utility, or 0.5 with no
        entities) + 0.3 * ac.utility. It is then divided down by:
        - map, temporary and location novelty debt;
        - a long repetition of `ac`;
        - `ac` belonging to the detected action loop.

        Finally Gaussian noise (sigma 0.05, from np.random) is added.
        """
        entities = self.entities()
        entity_term = np.mean([e.predict(st) * e.utility for e in entities]) if entities else 0.5
        score = entity_term * 0.7 + ac.utility * 0.3
        map_id = int(cs[2])
        loc = self.get_location_key(*(rp if rp else (cs[0] * 255, cs[1] * 255)), map_id)
        debt = (min(self.map_novelty_debt.get(map_id, 0.0), self.MAX_MAP_DEBT)
                + self.get_temp_debt(map_id)
                + min(self.location_novelty.get(loc, 0.0), self.MAX_LOCATION_DEBT) * 0.5)
        score *= 1.0 / (1.0 + debt * 5.0)
        start = self.LEARNING_SLOWDOWN_START
        if ac.action == self.current_repeated_action and self.consecutive_action_count > start:
            score *= 1.0 / (1.0 + (self.consecutive_action_count - start) * 0.15)
        if self.detected_pattern and ac.action in self.detected_pattern:
            score *= 1.0 / (1.0 + self.pattern_repeat_count * 0.2)
        return score + np.random.randn() * 0.05
    def compute_multi_modal_error(self, s, ns):
        """Compare two learning states over their common length.

        The first 8 features are compared element-wise. The remaining tail is
        compared as a whole with the Euclidean norm of its difference.

        Returns:
            (weighted, numeric, visual):
            - weighted: per-feature absolute differences of the first 8
              features times [0.5, 0.5, 10, 5, 3, 2, 1.5, 0.3], plus twice
              the tail norm;
            - numeric: the plain sum of those first-8 differences;
            - visual: the tail norm (0.0 when there is no tail).
        """
        n = min(len(s), len(ns))
        diffs = [abs(ns[i] - s[i]) for i in range(min(8, n))]
        weights = [0.5, 0.5, 10.0, 5.0, 3.0, 2.0, 1.5, 0.3]
        has_tail = n > 8
        visual = np.linalg.norm(ns[8:n] - s[8:n]) if has_tail else 0.0
        weighted = sum(di * wi for di, wi in zip(diffs, weights)) + (visual * 2.0 if has_tail else 0.0)
        return weighted, sum(diffs), visual

    def learn(self, ls, nls, cs, ncs, dead=False, raw_position=None, next_raw_position=None):
        """Run one learning step on the observed transition ls -> nls in context `cs`.

        Steps, in order:
          1. Zero-pad `ls`/`nls` to a common length if their shapes differ, and
             spawn the innate entity perceptrons on the first call.
          2. Update exploration tracking from the previous context / raw position.
          3. Compute the multi-modal error and append it to the error histories.
          4. Update per-map and per-location visit counts and novelty debts.
             Heavily revisited maps (>30 visits) and locations (>25) damp the error.
          5. Update every perceptron with the error, damped for repeated or
             looping actions. Then decay Start/Select weights and apply the
             repetition, pattern and utility-floor rules.
          6. Reward the last action (x1.08, or x1.15 near a map edge, capped at
             2.0) when the position changed and the action is not being spammed.
          7. Periodic memory cleanup / exploration save, then append the last
             action to `action_history`.

        `ncs`, `dead` and `next_raw_position` are accepted for interface
        compatibility but are not used here.
        """
        if ls.shape != nls.shape:
            md = max(len(ls), len(nls))
            ls = np.pad(ls, (0, max(0, md - len(ls))))
            nls = np.pad(nls, (0, max(0, md - len(nls))))
        if not self.innate_entities_spawned:
            self.spawn_innate_entities(ls)
        pc = self.prev_context_states[-1] if self.prev_context_states else None
        pr = getattr(self, '_last_raw_position', None)
        self.update_exploration_tracking(cs, pc, raw_position, pr)
        self._last_raw_position = raw_position

        we, ne, ve = self.compute_multi_modal_error(ls, nls)
        self.error_history.append(we)
        self.numeric_error_history.append(ne)
        self.visual_error_history.append(ve)

        cm = int(cs[2])
        loc = self.get_location_key(*(raw_position if raw_position else (cs[0] * 255, cs[1] * 255)), cm)
        self.visited_maps[cm] = self.visited_maps.get(cm, 0) + 1
        # FIX: MAX_LOCATIONS caps the number of *distinct* locations only. Previously,
        # once the cap was reached, re-visits to already-known locations stopped being
        # counted too, freezing their visit counts and novelty debt.
        if loc in self.location_memory or len(self.location_memory) < self.MAX_LOCATIONS:
            self.location_memory[loc] = self.location_memory.get(loc, 0) + 1
        if self.visited_maps[cm] > 10:
            self.map_novelty_debt[cm] = min(self.MAX_MAP_DEBT, self.map_novelty_debt.get(cm, 0.0) + 0.05 * (self.visited_maps[cm] - 10))
        loc_visits = self.location_memory.get(loc, 0)
        if loc_visits > 15:
            self.location_novelty[loc] = min(self.MAX_LOCATION_DEBT, self.location_novelty.get(loc, 0.0) + 0.1 * (loc_visits - 15))
        # Familiar places teach less: damp the error signal.
        if self.visited_maps[cm] > 30:
            we *= 0.5
        if loc_visits > 25:
            we *= 0.7

        stag = self.stagnation_level()
        lm = self.get_learning_multiplier(self.last_action) if self.last_action else 1.0
        if self.detected_pattern and self.last_action and self.last_action in self.detected_pattern:
            lm *= 0.5
        for p in self.perceptrons:
            m = lm if (p.kind == "action" and p.action == self.last_action) else 1.0
            if p.kind == "action" and self.detected_pattern and p.action in self.detected_pattern:
                m *= 0.5
            p.update(ls, we * m, stagnation=stag)
        for a in self.actions():
            if a.action in ['Start', 'Select'] and a.weights is not None:
                a.weights *= 0.999
        self.apply_repetition_penalty()
        self.apply_pattern_penalty()
        self.enforce_utility_floors()

        # Reward an action that actually moved the player, unless it is being spammed.
        if pc is not None and np.linalg.norm(cs[:2] - pc[:2]) > 0.001 and self.last_action and self.consecutive_action_count < self.PENALTY_THRESHOLD:
            for a in self.actions():
                if a.action == self.last_action:
                    a.utility = min(a.utility * (1.15 if raw_position and self.is_near_map_edge(*raw_position) else 1.08), 2.0)
                    break
        if self.timestep % 1000 == 0:
            self.cleanup_memory()
        if self.timestep % self.SAVE_INTERVAL == 0:
            self.save_exploration_memory()
        self.action_history.append(self.last_action)

    # =========================================================================
    # SAVE/LOAD — FIX: model saves to taught_model_checkpoint.json
    # =========================================================================
    def save_model_checkpoint(self, filepath=None):
        """Write the agent's learnable state to a JSON checkpoint.

        Defaults to BASE_PATH / "taught_model_checkpoint.json". The file holds:
        - the timestep;
        - every action and entity perceptron (utility, familiarity, and weights
          stored sparsely as [index, value] pairs for |value| > 1e-10 plus the
          original length);
        - the novelty-debt tables, with stringified keys;
        - the control mode and the Markov / blend counters.

        Write errors are printed, not raised.
        """
        if filepath is None:
            filepath = BASE_PATH / "taught_model_checkpoint.json"

        def sparse(weights):
            if weights is None:
                return 0, []
            return len(weights), [[i, float(v)] for i, v in enumerate(weights) if abs(v) > 1e-10]

        def stringify_keys(table):
            return {str(k): v for k, v in table.items()}

        action_entries = []
        for a in self.actions():
            shape, nonzero = sparse(a.weights)
            action_entries.append({"action": a.action, "group": a.group, "utility": float(a.utility),
                                   "weights_shape": shape, "weights_nonzero": nonzero,
                                   "learning_rate": float(a.learning_rate), "familiarity": float(a.familiarity)})
        entity_entries = []
        for e in self.entities():
            shape, nonzero = sparse(e.weights)
            entity_entries.append({"entity_type": e.entity_type, "utility": float(e.utility),
                                   "weights_shape": shape, "weights_nonzero": nonzero,
                                   "familiarity": float(e.familiarity)})
        model = {
            "timestep": self.timestep,
            "perceptrons": {"actions": action_entries, "entities": entity_entries},
            "debt_tracking": {
                "map_novelty_debt": stringify_keys(self.map_novelty_debt),
                "location_novelty": stringify_keys(self.location_novelty),
                "visited_maps": stringify_keys(self.visited_maps),
            },
            "control_mode": self.control_mode,
            "markov_stats": {"markov_action_count": self.markov_action_count,
                             "curiosity_action_count": self.curiosity_action_count},
            "blend_stats": {"blend_count": self.blend_count, "last_blend_tier": self.blend_tier},
        }
        try:
            with open(filepath, 'w') as f:
                json.dump(model, f, indent=2)
            print(f"üíæ Model saved: step {self.timestep} ‚Üí {filepath}")
        except Exception as e:
            print(f"‚ùå Save error: {e}")

    def load_taught_model(self, fp):
        """Restore agent state from a checkpoint written by `save_model_checkpoint`.

        Saved perceptrons are matched to existing ones by action name or entity
        type; entries with no matching perceptron are ignored. Weights are only
        rebuilt from their sparse form when the checkpoint has non-zero
        entries. Debt tables, Markov/blend counters and the timestep are
        restored as well.

        Returns:
            The restored timestep, or 0 when the file is missing, lacks a
            "perceptrons" section, or fails to load (the error is printed).
        """
        if not Path(fp).exists():
            return 0

        def dense(entry):
            # Rebuild a dense vector from [[index, value], ...]; out-of-range indices are ignored.
            dim = entry.get("weights_shape", 1376)
            w = np.zeros(dim)
            for idx, val in entry["weights_nonzero"]:
                if idx < dim:
                    w[idx] = val
            return w

        try:
            with open(fp, 'r') as f:
                model = json.load(f)
            if "perceptrons" not in model:
                return 0
            for sa in model["perceptrons"]["actions"]:
                for a in self.actions():
                    if a.action == sa["action"]:
                        a.utility = sa["utility"]
                        a.learning_rate = sa.get("learning_rate", 0.01)
                        a.familiarity = sa.get("familiarity", 0.0)
                        if sa.get("weights_nonzero"):
                            a.weights = dense(sa)
                        break
            for se in model["perceptrons"].get("entities", []):
                for e in self.entities():
                    if e.entity_type == se.get("entity_type"):
                        e.utility = se.get("utility", 1.0)
                        e.familiarity = se.get("familiarity", 0.0)
                        if se.get("weights_nonzero"):
                            e.weights = dense(se)
                        break
            if "debt_tracking" in model:
                d = model["debt_tracking"]
                self.map_novelty_debt = {int(k): v for k, v in d.get("map_novelty_debt", {}).items()}
                self.visited_maps = {int(k): v for k, v in d.get("visited_maps", {}).items()}
                for k, v in d.get("location_novelty", {}).items():
                    # SECURITY NOTE(review): keys are str(key) as written by save_model_checkpoint and
                    # are parsed back with eval(), which executes arbitrary code from the file.
                    # Only load trusted checkpoints; ast.literal_eval would be safer if the keys
                    # are plain tuples of numbers (confirm get_location_key's output type).
                    try:
                        self.location_novelty[eval(k)] = v
                    except Exception:
                        # Skip unparseable keys. Unlike the previous bare `except:`, this no
                        # longer swallows KeyboardInterrupt / SystemExit.
                        pass
            ms = model.get("markov_stats", {})
            self.markov_action_count = ms.get("markov_action_count", 0)
            self.curiosity_action_count = ms.get("curiosity_action_count", 0)
            bs = model.get("blend_stats", {})
            self.blend_count = bs.get("blend_count", 0)
            self.blend_tier = bs.get("last_blend_tier", 0)
            self.timestep = model.get("timestep", 0)
            return self.timestep
        except Exception as e:
            print(f"  ‚ö†Ô∏è Load error: {e}")
            return 0
    def save_model(self, fp=None):
        """Alias for `save_model_checkpoint(fp)`; returns None."""
        self.save_model_checkpoint(fp)
        return None
    def load_model(self, fp=None):
        """Load a checkpoint via `load_taught_model`; a falsy `fp` means BASE_PATH / "taught_model_checkpoint.json"."""
        path = fp if fp else BASE_PATH / "taught_model_checkpoint.json"
        return self.load_taught_model(path)

In [None]:
# ============================================================================
# CELL 6: Teaching Mode Main Loop + Auto Post-Processing on Shutdown
# ============================================================================
# Human plays, Brain learns. On Ctrl+C or every 2000 steps:
#   1. Saves taught_model_checkpoint.json
#   2. Saves taught_exploration_memory.json
#   3. Runs per_map_analysis → updates taught_transitions.json
#   4. Runs nav target extraction → writes taught_nav_targets.json
#   5. Runs battle extraction → writes taught_battle_transitions.json
#   All 5 taught files ready to copy to colleague
# ============================================================================

from collections import defaultdict

# =========================================================================
# POST-PROCESSOR 1: per_map_analysis
# =========================================================================
def _pma_tile_key(state):
    """Return the "x_y" tile key used by per_map_analysis (missing coords -> 0)."""
    return f"{state.get('x', 0)}_{state.get('y', 0)}"


def _pma_action_probabilities(indexed_frames):
    """Per tile: fraction of frames per action, frame count and facing histogram."""
    action_counts = defaultdict(lambda: defaultdict(int))
    facing_counts = defaultdict(lambda: defaultdict(int))
    totals = defaultdict(int)
    for _, fr in indexed_frames:
        s = fr.get('state', {})
        tk = _pma_tile_key(s)
        action_counts[tk][fr.get('action', 'NONE')] += 1
        facing_counts[tk][str(s.get('direction', 0))] += 1
        totals[tk] += 1
    probabilities = {}
    for tk, counts in action_counts.items():
        total = totals[tk]  # >= 1: every tile here was counted at least once
        probs = {a: round(c / total, 3) for a, c in counts.items()}
        probs['total_frames'] = total
        probs['direction_facing'] = dict(facing_counts[tk])
        probabilities[tk] = probs
    return probabilities


def _pma_movement_graph(indexed_frames):
    """Undirected tile adjacency built from consecutive frames whose position changed."""
    edges = defaultdict(set)
    for (gi1, f1), (gi2, f2) in zip(indexed_frames, indexed_frames[1:]):
        if gi2 != gi1 + 1:
            # The player was on another map in between: not a real step.
            continue
        s1, s2 = f1.get('state', {}), f2.get('state', {})
        p1 = (s1.get('x', 0), s1.get('y', 0))
        p2 = (s2.get('x', 0), s2.get('y', 0))
        if p1 != p2:
            t1, t2 = f"{p1[0]}_{p1[1]}", f"{p2[0]}_{p2[1]}"
            edges[t1].add(t2)
            edges[t2].add(t1)
    return {k: sorted(v) for k, v in edges.items()}


def _pma_decision_points(indexed_frames):
    """Frames where the action switched between two different non-NONE actions."""
    points = []
    for (gi_prev, f_prev), (gi, f_cur) in zip(indexed_frames, indexed_frames[1:]):
        if gi != gi_prev + 1:
            continue  # not consecutive in the recording (map was left in between)
        a_prev, a_cur = f_prev.get('action', 'NONE'), f_cur.get('action', 'NONE')
        if a_prev == a_cur or a_prev == 'NONE' or a_cur == 'NONE':
            continue
        s = f_cur.get('state', {})
        points.append({'position': [s.get('x', 0), s.get('y', 0)], 'from_action': a_prev, 'to_action': a_cur,
                       'frame': gi, 'facing': s.get('direction', 0),
                       'context': {'in_battle': s.get('in_battle', 0), 'in_menu': s.get('in_menu', 0)}})
    return points


def _pma_dwell_times(indexed_frames):
    """Per tile: total frames, number of visits, average and longest stay.

    A visit is a run of consecutive recorded frames on the same tile; leaving
    the map ends the run even if the player re-enters on the same tile.
    """
    visits = defaultdict(int)
    runs = defaultdict(list)
    last_tile, last_gi, run = None, None, 0
    for gi, fr in indexed_frames:
        tk = _pma_tile_key(fr.get('state', {}))
        if tk == last_tile and gi == last_gi + 1:
            run += 1
        else:
            if last_tile is not None:
                runs[last_tile].append(run)
            visits[tk] += 1
            last_tile, run = tk, 1
        last_gi = gi
    if last_tile is not None:
        runs[last_tile].append(run)
    return {tk: {'total_frames': sum(r), 'visits': visits[tk],
                 'avg_dwell': round(sum(r) / len(r), 1), 'max_dwell': max(r)}
            for tk, r in runs.items()}


def _pma_finish_segment(seg, segments):
    """Append seg to segments if it covers at least 3 tiles."""
    if len(seg['tiles']) >= 3:
        seg['end'] = seg['tiles'][-1]
        seg['length'] = len(seg['tiles'])
        segments.append(seg)


def _pma_path_segments(indexed_frames):
    """Runs of one directional action over consecutive frames covering >= 3 tiles.

    Repeated positions are collapsed when building 'tiles'. 'frame_start' and
    'frame_end' are the global indices of the first and last frame of the run.
    """
    segments = []
    cur = None
    prev_gi = None
    for gi, fr in indexed_frames:
        s = fr.get('state', {})
        act = fr.get('action', 'NONE')
        x, y = s.get('x', 0), s.get('y', 0)
        contiguous = prev_gi is not None and gi == prev_gi + 1
        prev_gi = gi
        # A gap in the recording or a change of action ends the current run.
        if cur is not None and (not contiguous or act != cur['primary_action']):
            _pma_finish_segment(cur, segments)
            cur = None
        if act not in ('UP', 'DOWN', 'LEFT', 'RIGHT'):
            continue
        if cur is None:
            cur = {'start': [x, y], 'end': [x, y], 'tiles': [[x, y]], 'primary_action': act,
                   'actions': [act], 'length': 1, 'frame_start': gi, 'frame_end': gi}
        else:
            if [x, y] != cur['tiles'][-1]:
                cur['tiles'].append([x, y])
            cur['actions'].append(act)
            cur['frame_end'] = gi
    if cur is not None:
        _pma_finish_segment(cur, segments)
    return segments


def run_per_map_analysis():
    """Compute per-map movement statistics and store them in taught_transitions.json.

    Reads every frame from the 'batches' list of TAUGHT_TRANSITIONS_FILE,
    groups frames by state.map_id (frames without a map_id are ignored) and,
    for each map, computes:

    - action_probabilities: per tile ("x_y"), the fraction of frames on which
      each action was recorded, plus 'total_frames' and a 'direction_facing'
      histogram.
    - movement_graph: undirected tile adjacency from consecutive frames whose
      position changed.
    - decision_points: frames where the action switched from one non-NONE
      action to another.
    - dwell_times: per tile, total frames, visits, average and max run length.
    - path_segments: runs of a single directional action covering >= 3 tiles.

    "Consecutive" means adjacent in the global recording order. Frames of the
    same map that are separated by time on another map are never paired, so
    leaving and re-entering a map does not create phantom edges, decision
    points, merged dwell runs or joined path segments.

    The result is written back to the same file under 'per_map_analysis'
    (all other keys are preserved). Returns None; prints progress.
    NOTE(review): the teaching banner labels this file as Lua-written. If the
    recorder rewrites it while this runs, one of the two writes can be lost;
    confirm the recorder's write pattern.
    """
    if not TAUGHT_TRANSITIONS_FILE.exists():
        print("  ‚ö†Ô∏è No taught_transitions.json ‚Äî skipping per_map_analysis")
        return
    with open(TAUGHT_TRANSITIONS_FILE, 'r', encoding='utf-8') as f:
        data = json.load(f)
    batches = data.get('batches', [])
    if not batches:
        print("  ‚ö†Ô∏è No batches ‚Äî skipping per_map_analysis")
        return
    all_frames = [frame for batch in batches for frame in batch.get('frames', [])]
    print(f"  Analyzing {len(all_frames)} frames...")
    # Keep each frame's global index so per-map code can detect gaps.
    frames_by_map = defaultdict(list)
    for i, frame in enumerate(all_frames):
        mid = frame.get('state', {}).get('map_id')
        if mid is not None:
            frames_by_map[mid].append((i, frame))
    per_map = {}
    for map_id, indexed_frames in frames_by_map.items():
        ap = _pma_action_probabilities(indexed_frames)
        dp = _pma_decision_points(indexed_frames)
        ps = _pma_path_segments(indexed_frames)
        per_map[str(map_id)] = {
            'action_probabilities': ap,
            'movement_graph': _pma_movement_graph(indexed_frames),
            'decision_points': dp,
            'dwell_times': _pma_dwell_times(indexed_frames),
            'path_segments': ps,
        }
        print(f"    Map {map_id}: {len(ap)} tiles, {len(dp)} decisions, {len(ps)} paths")
    data['per_map_analysis'] = per_map
    with open(TAUGHT_TRANSITIONS_FILE, 'w', encoding='utf-8') as f:
        json.dump(data, f)
    print(f"  ‚úÖ per_map_analysis ‚Üí {TAUGHT_TRANSITIONS_FILE.name}")


# =========================================================================
# POST-PROCESSOR 2: taught_nav_targets.json
# =========================================================================
# Output file produced by run_nav_target_extraction.
NAV_TARGETS_PATH = BASE_PATH.joinpath("taught_nav_targets.json")

# Tuning knobs for run_nav_target_extraction. Windows are counted in recorded
# frames; distances are Manhattan distances in tiles.
ANALYSIS_WINDOW_AFTER = 40   # frames after a novelty point scored for forward progress
RECENT_WINDOW = 100          # look-back window for "already visited" checks
MIN_FORWARD_PROGRESS = 0.5   # minimum share of post-point frames on unseen positions
DEDUP_RADIUS_NAV = 2         # targets this close (same map) are treated as duplicates
BACKTRACK_WINDOW = 50        # frames after a point searched for a return
BACKTRACK_PROXIMITY = 3      # distance to the point that counts as having returned

def run_nav_target_extraction():
    """Derive ordered navigation targets from recorded human play.

    Reads every frame of every batch in TAUGHT_TRANSITIONS_FILE (in order) and
    writes NAV_TARGETS_PATH. If the source file is missing or contains no
    frames, an empty targets file is written via _write_empty_nav_targets.

    Stages:
    1. Novelty detection. Each frame yields at most one point, tried in order:
       - 'map_transition': map_id differs from the previous frame's. The point
         is placed at the previous frame's position/direction on the map being
         left; 'destination_map' is the new map.
       - 'battle': in_battle goes from 0 (previous frame) to 1.
       - 'interaction': action 'A' outside battle and menu, followed within the
         next 8 frames by a change of in_menu or of map_id.
       - 'new_area': outside battle and menu, only after the first
         RECENT_WINDOW frames, when the same (map_id, x, y) does not occur in
         frames [i - RECENT_WINDOW, i - 5) and the most recent novelty point is
         not on the same map within DEDUP_RADIUS_NAV (Manhattan).
    2. Filtering. forward_progress_score is the share of the next
       ANALYSIS_WINDOW_AFTER frames (battle/menu frames skipped) whose
       (map_id, x, y) was not visited, outside battle/menu, in the preceding
       RECENT_WINDOW frames. Points scoring below MIN_FORWARD_PROGRESS are
       dropped. Points followed by a backtrack within BACKTRACK_WINDOW frames
       are also dropped: for map transitions, a return to the source map;
       otherwise, coming within BACKTRACK_PROXIMITY (Manhattan) of the point,
       checked from 5 frames after it.
    3. Dedup. Per map, targets are visited best score first and kept only if
       no already-kept target lies within DEDUP_RADIUS_NAV.
    4. Output. Survivors are re-sorted by frame_index, numbered from 1
       ('order'), and written grouped by str(map_id), together with a flat
       'global_order' list and a 'metadata' summary.
    """
    if not TAUGHT_TRANSITIONS_FILE.exists():
        print("  ‚ö†Ô∏è No taught_transitions.json ‚Äî writing empty nav targets")
        _write_empty_nav_targets(); return
    with open(TAUGHT_TRANSITIONS_FILE, 'r') as f:
        data = json.load(f)
    # Flatten all batches into one chronological frame list.
    all_frames = []
    for batch in data.get('batches', []):
        for frame in batch.get('frames', []): all_frames.append(frame)
    if not all_frames:
        print("  ‚ö†Ô∏è No frames ‚Äî writing empty nav targets"); _write_empty_nav_targets(); return
    print(f"  Scanning {len(all_frames)} frames for novelty...")
    # --- Stage 1: novelty detection (at most one point per frame) ---
    novelty_points = []
    for i, frame in enumerate(all_frames):
        s = frame.get('state', {}); x, y = s.get('x',0), s.get('y',0)
        mid = s.get('map_id',0); d = s.get('direction',0)
        ib, im = s.get('in_battle',0), s.get('in_menu',0)
        act = frame.get('action', 'NONE')
        # Previous frame's state ({} for the first frame).
        ps = all_frames[i-1].get('state', {}) if i > 0 else {}
        pmid, pib = ps.get('map_id', mid), ps.get('in_battle', 0)
        # Map change: anchor the point on the tile the player left from.
        if i > 0 and mid != pmid:
            px, py = ps.get('x', x), ps.get('y', y)
            novelty_points.append({'position': [px,py], 'map_id': pmid, 'direction': ps.get('direction',d),
                'frame_index': i, 'novelty_type': 'map_transition', 'destination_map': mid}); continue
        # Battle start (0 -> 1 edge of in_battle).
        if ib == 1 and pib == 0:
            novelty_points.append({'position': [x,y], 'map_id': mid, 'direction': d,
                'frame_index': i, 'novelty_type': 'battle', 'destination_map': None}); continue
        # 'A' pressed in the overworld that opens a menu or changes map within 8 frames.
        if act == 'A' and ib == 0 and im == 0:
            triggered = False
            for j in range(i+1, min(i+9, len(all_frames))):
                fs = all_frames[j].get('state', {})
                if fs.get('in_menu', 0) != im or fs.get('map_id', mid) != mid: triggered = True; break
            if triggered:
                novelty_points.append({'position': [x,y], 'map_id': mid, 'direction': d,
                    'frame_index': i, 'novelty_type': 'interaction', 'destination_map': None}); continue
        # Overworld position not seen in frames [i-RECENT_WINDOW, i-5); skipped if
        # it is within DEDUP_RADIUS_NAV of the last recorded point on this map.
        if ib == 0 and im == 0 and i > RECENT_WINDOW:
            was_recent = any(all_frames[j].get('state',{}).get('map_id')==mid and all_frames[j].get('state',{}).get('x')==x and all_frames[j].get('state',{}).get('y')==y for j in range(max(0,i-RECENT_WINDOW), max(0,i-5)))
            if not was_recent:
                too_close = novelty_points and novelty_points[-1]['map_id']==mid and abs(novelty_points[-1]['position'][0]-x)+abs(novelty_points[-1]['position'][1]-y)<=DEDUP_RADIUS_NAV
                if not too_close:
                    novelty_points.append({'position': [x,y], 'map_id': mid, 'direction': d,
                        'frame_index': i, 'novelty_type': 'new_area', 'destination_map': None})
    # Per-type counts, used only for the progress log.
    tc = defaultdict(int)
    for np_item in novelty_points: tc[np_item['novelty_type']] += 1
    print(f"    Novelty points: {len(novelty_points)} ({', '.join(f'{t}:{c}' for t,c in tc.items())})")
    # --- Stage 2: score forward progress and drop backtracked points ---
    scored = []
    for np_item in novelty_points:
        fi = np_item['frame_index']; mid = np_item['map_id']; px, py = np_item['position']
        # Overworld positions visited in the RECENT_WINDOW frames before the point.
        before = set()
        for j in range(max(0,fi-RECENT_WINDOW), fi):
            js = all_frames[j].get('state', {})
            if js.get('in_battle',0)==0 and js.get('in_menu',0)==0:
                before.add((js.get('map_id',0), js.get('x',0), js.get('y',0)))
        # Share of the following non-battle/menu frames spent on positions not in 'before'.
        after_new, after_total = 0, 0
        for j in range(fi+1, min(fi+1+ANALYSIS_WINDOW_AFTER, len(all_frames))):
            js = all_frames[j].get('state', {})
            if js.get('in_battle',0)==1 or js.get('in_menu',0)==1: continue
            after_total += 1
            if (js.get('map_id',0), js.get('x',0), js.get('y',0)) not in before: after_new += 1
        fwd = after_new / after_total if after_total > 0 else 0.0
        bt = False
        if np_item['novelty_type'] == 'map_transition':
            # For transitions, 'mid' is the map that was left: returning to it is a backtrack.
            for j in range(fi+1, min(fi+1+BACKTRACK_WINDOW, len(all_frames))):
                if all_frames[j].get('state',{}).get('map_id') == mid: bt = True; break
        else:
            # Otherwise: coming back near the point, starting 5 frames after it.
            for j in range(fi+5, min(fi+1+BACKTRACK_WINDOW, len(all_frames))):
                js = all_frames[j].get('state', {})
                if js.get('map_id')==mid and abs(js.get('x',0)-px)+abs(js.get('y',0)-py)<=BACKTRACK_PROXIMITY: bt = True; break
        if bt or fwd < MIN_FORWARD_PROGRESS: continue
        np_item['forward_progress_score'] = round(fwd, 3); scored.append(np_item)
    print(f"    After filtering: {len(scored)} targets")
    # --- Stage 3: per-map dedup, best score wins ---
    deduped = []
    by_map = defaultdict(list)
    for t in scored: by_map[t['map_id']].append(t)
    for mid, targets in by_map.items():
        targets.sort(key=lambda t: t['forward_progress_score'], reverse=True)
        kept = []
        for t in targets:
            tx, ty = t['position']
            if not any(abs(tx-k['position'][0])+abs(ty-k['position'][1])<=DEDUP_RADIUS_NAV for k in kept): kept.append(t)
        deduped.extend(kept)
    # --- Stage 4: chronological numbering and output ---
    deduped.sort(key=lambda t: t['frame_index'])
    tbm = defaultdict(list); go = []
    for order, t in enumerate(deduped, 1):
        mk = str(t['map_id'])
        tbm[mk].append({'position': t['position'], 'direction': t['direction'], 'order': order,
            'progress_type': t['novelty_type'], 'forward_progress_score': t['forward_progress_score'],
            'destination_map': t.get('destination_map'), 'frame_index': t['frame_index']})
        go.append({'map_id': t['map_id'], 'position': t['position'], 'order': order})
    output = {'targets_by_map': dict(tbm), 'global_order': go,
        'metadata': {'total_targets': len(deduped), 'maps_with_targets': sorted(set(t['map_id'] for t in deduped)),
            'analysis_window_after': ANALYSIS_WINDOW_AFTER, 'min_forward_progress': MIN_FORWARD_PROGRESS,
            'dedup_radius': DEDUP_RADIUS_NAV, 'generated_from_frames': len(all_frames)}}
    with open(NAV_TARGETS_PATH, 'w') as f:
        json.dump(output, f, indent=2)
    print(f"  ‚úÖ taught_nav_targets.json ‚Üí {len(deduped)} targets across {len(tbm)} maps")

def _write_empty_nav_targets():
    """Write a taught_nav_targets.json that contains no targets.

    The file has the same shape as run_nav_target_extraction's output (empty
    'targets_by_map' and 'global_order', zero counts in 'metadata') and records
    the current tuning constants.
    """
    with open(NAV_TARGETS_PATH, 'w') as out:
        metadata = {
            'total_targets': 0,
            'maps_with_targets': [],
            'analysis_window_after': ANALYSIS_WINDOW_AFTER,
            'min_forward_progress': MIN_FORWARD_PROGRESS,
            'dedup_radius': DEDUP_RADIUS_NAV,
            'generated_from_frames': 0,
        }
        empty = {'targets_by_map': {}, 'global_order': [], 'metadata': metadata}
        json.dump(empty, out, indent=2)


# =========================================================================
# POST-PROCESSOR 3: taught_battle_transitions.json
# =========================================================================
# Output file produced by run_battle_extraction.
BATTLE_TRANSITIONS_PATH = BASE_PATH.joinpath("taught_battle_transitions.json")

# Map IDs treated as Pokemon Center respawn points by _detect_battle_outcome:
# landing on one of them (other than the battle's own map) shortly after a
# battle is classified as a loss.
# NOTE(review): {1, 2, 3, 4} are placeholders. Replace them with the real
# Pokemon Center map IDs for this ROM; otherwise ordinary map changes right
# after a battle can be mislabelled as losses.
POKEMON_CENTER_MAPS = {1, 2, 3, 4}

def _be_pad_recent_actions(recent):
    """Return exactly 8 entries: the last 8 of `recent`, left-padded with 'NONE'.

    Anything that is not a list/tuple (a missing key, JSON null, or an empty
    Lua table the recorder may have serialised as {}) is treated as an empty
    history instead of crashing the whole extraction.
    """
    if not isinstance(recent, (list, tuple)):
        recent = []
    recent = list(recent)[-8:]
    return ['NONE'] * (8 - len(recent)) + recent


def _be_finalize_battle(battle, outcome):
    """Convert an accumulated battle record into a 'battle_sequences' entry."""
    return {
        'battle_id': battle['battle_id'],
        'start_frame': battle['start_frame'],
        'end_frame': battle['end_frame'],
        'map_id': battle['map_id'],
        'duration_frames': len(battle['frames']),
        'outcome': outcome,
        'frames': battle['frames'],
    }


def run_battle_extraction():
    """Extract battle frames from taught_transitions.json into taught_battle_transitions.json.

    A battle is a run of consecutive frames with state.in_battle == 1, ended by
    a frame with in_battle == 0. Battles shorter than 2 frames are discarded;
    their id is still consumed, so battle_id values may have gaps. A battle's
    outcome comes from _detect_battle_outcome. A battle still running when the
    recording ends is labelled 'unknown'.

    Produces:
    - battle_sequences: one entry per kept battle, with its frames and outcome
    - flat_frames: all kept battle frames in order, each tagged with battle_id
    - metadata: counts, outcomes, average length, and the ten most frequent
      2-4 action windows (NONE excluded), each with a guessed context

    Writes an empty file via _write_empty_battle_transitions when the source
    file is missing or holds no frames.
    """
    if not TAUGHT_TRANSITIONS_FILE.exists():
        print("  ‚ö†Ô∏è No taught_transitions.json ‚Äî writing empty battle transitions")
        _write_empty_battle_transitions(); return

    with open(TAUGHT_TRANSITIONS_FILE, 'r', encoding='utf-8') as f:
        data = json.load(f)

    # Flatten all frames in recording order. Each frame's batch type is kept in a
    # parallel list rather than written into the loaded frame dicts.
    all_frames, batch_types = [], []
    for batch in data.get('batches', []):
        bt = batch.get('batch_type', 'steady')
        for frame in batch.get('frames', []):
            all_frames.append(frame)
            batch_types.append(bt)

    if not all_frames:
        print("  ‚ö†Ô∏è No frames ‚Äî writing empty battle transitions")
        _write_empty_battle_transitions(); return

    print(f"  Scanning {len(all_frames)} frames for battles...")

    # === STEP 1: Find battle sequences ===
    battle_sequences = []
    current_battle = None
    battle_id = 0

    for i, frame in enumerate(all_frames):
        s = frame.get('state', {})
        ib = s.get('in_battle', 0)

        if ib == 1:
            if current_battle is None:
                # Battle just started
                battle_id += 1
                current_battle = {
                    'battle_id': battle_id,
                    'start_frame': i,
                    'end_frame': i,
                    'map_id': s.get('map_id', 0),
                    'frames': [],
                    'recent_actions_buffer': []
                }
            action = frame.get('action', 'NONE')
            current_battle['frames'].append({
                'state': {
                    'map_id': s.get('map_id', 0),
                    'x': s.get('x', 0),
                    'y': s.get('y', 0),
                    'direction': s.get('direction', 0),
                    'in_battle': 1,
                    'in_menu': s.get('in_menu', 0)
                },
                'action': action if action else 'NONE',
                'recent_actions': _be_pad_recent_actions(frame.get('recent_actions')),
                'frame_offset': frame.get('frame_offset', 0),
                'batch_type': batch_types[i]
            })
            current_battle['end_frame'] = i
            # Real button presses only; feeds the Run-pattern heuristic.
            if action and action != 'NONE':
                current_battle['recent_actions_buffer'].append(action)

        elif ib == 0 and current_battle is not None:
            # Battle just ended; frame i is the first frame after it.
            if len(current_battle['frames']) >= 2:  # ignore 1-frame blips
                outcome = _detect_battle_outcome(current_battle, all_frames, i)
                battle_sequences.append(_be_finalize_battle(current_battle, outcome))
            current_battle = None

    # Battle still active at end of data: outcome cannot be observed.
    if current_battle is not None and len(current_battle['frames']) >= 2:
        battle_sequences.append(_be_finalize_battle(current_battle, 'unknown'))

    # === STEP 2: Build flat_frames ===
    flat_frames = []
    for seq in battle_sequences:
        for frame in seq['frames']:
            flat_frame = dict(frame)
            flat_frame['battle_id'] = seq['battle_id']
            flat_frames.append(flat_frame)

    # === STEP 3: Compute metadata ===
    outcomes = defaultdict(int)
    maps_with_battles = set()
    for seq in battle_sequences:
        outcomes[seq['outcome']] += 1
        maps_with_battles.add(seq['map_id'])

    avg_length = (sum(s['duration_frames'] for s in battle_sequences) /
                  len(battle_sequences)) if battle_sequences else 0

    # Most common action sequences (windows of 2-4 non-NONE actions)
    seq_counts = defaultdict(int)
    for seq in battle_sequences:
        actions = [f['action'] for f in seq['frames'] if f['action'] != 'NONE']
        for win_size in [2, 3, 4]:
            for j in range(len(actions) - win_size + 1):
                seq_counts[tuple(actions[j:j+win_size])] += 1

    common_seqs = sorted(seq_counts.items(), key=lambda x: x[1], reverse=True)[:10]
    most_common = [
        {'sequence': list(seq_tuple), 'count': count,
         'context': _guess_sequence_context(list(seq_tuple))}
        for seq_tuple, count in common_seqs
    ]

    metadata = {
        'total_battle_frames': len(flat_frames),
        'battles_recorded': len(battle_sequences),
        'avg_battle_length': round(avg_length, 1),
        'outcomes': dict(outcomes),
        'maps_with_battles': sorted(maps_with_battles),
        'most_common_sequences': most_common
    }

    # === STEP 4: Write output ===
    output = {
        'battle_sequences': battle_sequences,
        'flat_frames': flat_frames,
        'metadata': metadata
    }

    with open(BATTLE_TRANSITIONS_PATH, 'w') as f:
        json.dump(output, f)

    print(f"  ‚úÖ taught_battle_transitions.json:")
    print(f"     Battles: {len(battle_sequences)} | Frames: {len(flat_frames)}")
    print(f"     Outcomes: {dict(outcomes)}")
    print(f"     Avg length: {avg_length:.1f} frames")
    if most_common:
        print(f"     Top sequences:")
        for s in most_common[:3]:
            print(f"       {s['sequence']} x{s['count']} ({s['context']})")


def _detect_battle_outcome(battle, all_frames, end_index):
    """Heuristically label a finished battle as 'run', 'loss' or 'win'.

    Args:
        battle: accumulated battle record; reads 'recent_actions_buffer' (the
            non-NONE actions pressed during the battle, in order) and 'map_id'.
        all_frames: every recorded frame, in order.
        end_index: index in all_frames of the first frame after the battle.

    Returns:
        'run'  -- at least 3 actions were pressed and one of the Run-menu input
                  patterns below occurs in the last (up to) 6 of them, with its
                  final press among the last three actions of that window;
        'loss' -- at least 6 frames exist from end_index on, and within the 10
                  frames starting at end_index the player is on a map in
                  POKEMON_CENTER_MAPS that differs from the battle's map;
        'win'  -- otherwise (the default).
    """
    pressed = battle.get('recent_actions_buffer', [])

    if len(pressed) >= 3:
        window = pressed[-6:] if len(pressed) >= 6 else pressed
        # FireRed battle menu: Fight top-left, Bag top-right, Pokemon
        # bottom-left, Run bottom-right. Each pattern ends on Run + confirm.
        run_patterns = [
            ['DOWN', 'RIGHT', 'A'],
            ['DOWN', 'A'],  # e.g. cursor starting on the top-right option
            ['RIGHT', 'DOWN', 'A'],
        ]
        n = len(window)
        for pattern in run_patterns:
            size = len(pattern)
            # Only matches that finish close to the end of the battle count.
            for start in range(max(0, n - size - 2), n - size + 1):
                if window[start:start + size] == pattern:
                    return 'run'

    if len(all_frames) > end_index + 5:
        battle_map = battle.get('map_id', -1)
        for j in range(end_index, min(end_index + 10, len(all_frames))):
            landed_on = all_frames[j].get('state', {}).get('map_id', -1)
            # Sent to a Pokemon Center somewhere other than where we fought.
            if landed_on in POKEMON_CENTER_MAPS and landed_on != battle_map:
                return 'loss'

    return 'win'


def _guess_sequence_context(seq):
    """Return a human-readable label for a short battle input sequence.

    Exact matches (four A's, two A's, two B's) are tried first. Otherwise, a
    sequence containing both DOWN and A counts as menu navigation, an all-A
    sequence (including an empty one) counts as mashing, and anything else
    gets the generic label.
    """
    exact_matches = (
        (['A', 'A', 'A', 'A'], 'mashing_A_through_text_or_selecting_fight'),
        (['A', 'A'], 'selecting_fight_and_move'),
        (['B', 'B'], 'cancelling_or_backing_out'),
    )
    for pattern, label in exact_matches:
        if seq == pattern:
            return label
    if 'DOWN' in seq and 'A' in seq:
        return 'navigating_menu_then_selecting'
    return 'mashing_A' if all(a == 'A' for a in seq) else 'battle_input_sequence'


def _write_empty_battle_transitions():
    """Write a taught_battle_transitions.json that contains no battles.

    Uses the same top-level keys and metadata fields as run_battle_extraction,
    with zero counts and empty collections.
    """
    empty_metadata = dict(
        total_battle_frames=0,
        battles_recorded=0,
        avg_battle_length=0,
        outcomes={},
        maps_with_battles=[],
        most_common_sequences=[],
    )
    document = dict(battle_sequences=[], flat_frames=[], metadata=empty_metadata)
    with open(BATTLE_TRANSITIONS_PATH, 'w') as out:
        json.dump(document, out, indent=2)


# =========================================================================
# MAIN TEACHING LOOP
# =========================================================================
# Build the brain: one perceptron per button, grouped into movement and interaction.
brain = Brain()

for group_name, buttons in (("move", ["UP", "DOWN", "LEFT", "RIGHT"]),
                            ("interact", ["A", "B", "Start", "Select"])):
    for b in buttons:
        brain.add(Perceptron("action", action=b, group=group_name))

# Taught-mode output files; exploration persistence is redirected to the taught file.
TAUGHT_MODEL_SAVE = BASE_PATH.joinpath("taught_model_checkpoint.json")
TAUGHT_EXPLORATION_SAVE = BASE_PATH.joinpath("taught_exploration_memory.json")
brain.EXPLORATION_MEMORY_FILE = TAUGHT_EXPLORATION_SAVE

# Resume a previous teaching session when a checkpoint is present.
if not TAUGHT_MODEL_SAVE.exists():
    print("üéì FRESH START: No existing taught model")
else:
    loaded_ts = brain.load_taught_model(TAUGHT_MODEL_SAVE)
    print(f"üéì RESUME: Loaded taught model from timestep {loaded_ts}")
    utility_labels = [f'{a.action}:{a.utility:.3f}' for a in brain.actions()]
    print(f"   Utilities: {utility_labels}")

if TAUGHT_EXPLORATION_SAVE.exists():
    brain.load_exploration_memory()
    print(f"   Exploration: {len(brain.exploration_memory)} maps loaded")

# Previous-step state used by the main loop for derived features.
prev_context_state = None
prev_raw_position = None

for banner_line in (
    "="*70,
    "TEACHING MODE ‚Äî Human plays, Brain learns",
    "="*70,
    f"  Model ‚Üí {TAUGHT_MODEL_SAVE.name}",
    f"  Exploration ‚Üí {TAUGHT_EXPLORATION_SAVE.name}",
    f"  Transitions ‚Üí {TAUGHT_TRANSITIONS_FILE.name} (Lua)",
    f"  Nav targets ‚Üí {NAV_TARGETS_PATH.name} (auto)",
    f"  Battle data ‚Üí {BATTLE_TRANSITIONS_PATH.name} (auto)",
    "="*70,
    "Play the game. Ctrl+C to stop, save, and post-process.",
    "="*70,
):
    print(banner_line)

def run_all_post_processing():
    """Run the three post-processors in order, isolating their failures.

    Every step is attempted even if an earlier one raised; an exception is
    reported as a one-line warning instead of propagating.
    """
    steps = (
        ("\n  üìä Running per_map_analysis...", run_per_map_analysis, "per_map_analysis"),
        ("  üìä Running nav target extraction...", run_nav_target_extraction, "Nav targets"),
        ("  üìä Running battle extraction...", run_battle_extraction, "Battle extraction"),
    )
    for announcement, step, label in steps:
        print(announcement)
        try:
            step()
        except Exception as err:
            print(f"    ‚ö†Ô∏è {label} failed: {err}")

# Main teaching loop: observe (state, next_state) pairs while the human plays
# and hand each transition to brain.learn. Ctrl+C saves and post-processes.
try:
    while True:
        context_state, palette_state, tile_state, dead, raw_position = read_game_state()

        # The loop treats a (near) all-zero context vector as "no usable state".
        if np.sum(np.abs(context_state)) < 0.001:
            time.sleep(0.02); continue

        raw_x, raw_y = raw_position
        in_battle = context_state[3]
        current_map = int(context_state[2])
        current_dir = int(context_state[5])

        brain.update_position(raw_x, raw_y)
        derived = compute_derived_features(context_state, prev_context_state)
        learning_state = build_learning_state(derived, palette_state, tile_state, in_battle)
        brain.log_state(learning_state, context_state)

        time.sleep(0.02)

        next_ctx, next_pal, next_til, next_dead, next_raw_pos = read_game_state()
        # Apply the same validity rule to the successor read. Without it, an
        # empty read becomes an all-zero "next state" and brain.learn is fed a
        # transition that did not happen. The step is skipped; the timestep is
        # not advanced.
        # NOTE(review): log_state has already run for this state; confirm that
        # logging a state again on the next iteration is harmless.
        if np.sum(np.abs(next_ctx)) < 0.001:
            continue
        next_derived = compute_derived_features(next_ctx, context_state)
        next_learning_state = build_learning_state(next_derived, next_pal, next_til, next_ctx[3])

        brain.learn(learning_state, next_learning_state, context_state, next_ctx,
                    dead=dead, raw_position=raw_position, next_raw_position=next_raw_pos)

        prev_context_state = context_state.copy()
        prev_raw_position = raw_position
        brain.timestep += 1

        # Logging (every 100 steps)
        if brain.timestep % 100 == 0:
            mem = brain.get_current_map_memory(current_map)
            vc = len(mem['visited_tiles']); oc = len(mem['obstructions'])
            ic = len(mem['interactable_objects']); cov = brain.get_exploration_coverage(current_map)
            ts = brain.get_tile_interaction_stats(current_map)
            dn = brain.DIRECTION_NAMES.get(current_dir, '?')
            tr = mem.get('transitions', [])

            print(f"\n{'='*60}")
            print(f"Step {brain.timestep} | Map {current_map} | ({raw_x},{raw_y}) {dn} | Battle:{int(in_battle)}")
            print(f"  Visited:{vc} Obs:{oc} Coverage:{cov:.0%} Interactables:{ic}")
            print(f"  Probed:{ts['probed']} Exhausted:{ts['exhausted']} Success:{ts['with_success']}")
            if tr: print(f"  Transitions: {len(tr)} known")
            au = sorted([(a.action, a.utility) for a in brain.actions()], key=lambda x: x[1], reverse=True)
            print(f"  Utilities: {' '.join(f'{k}:{v:.2f}' for k,v in au)}")

        # Periodic save (model + exploration every 500)
        if brain.timestep % 500 == 0 and brain.timestep > 0:
            brain.save_model_checkpoint(TAUGHT_MODEL_SAVE)
            brain.save_exploration_memory()
            print(f"  üíæ Auto-saved at step {brain.timestep}")

        # Periodic post-processing (every 2000 steps)
        if brain.timestep % 2000 == 0 and brain.timestep > 0:
            print(f"\n  üìä Periodic post-processing at step {brain.timestep}...")
            run_all_post_processing()
            print(f"  üìä Post-processing complete")

except KeyboardInterrupt:
    print("\n\nüõë Stopping teaching...")

    print("\nüìÅ Step 1/2: Saving model + exploration...")
    brain.save_model_checkpoint(TAUGHT_MODEL_SAVE)
    brain.save_exploration_memory()
    print(f"   ‚úÖ {TAUGHT_MODEL_SAVE.name}")
    print(f"   ‚úÖ {TAUGHT_EXPLORATION_SAVE.name}")

    print("\nüìÅ Step 2/2: Running all post-processors...")
    run_all_post_processing()

    print(f"\n{'='*60}")
    print(f"‚úÖ TEACHING COMPLETE")
    print(f"   Timestep: {brain.timestep}")
    print(f"   Maps: {len(brain.exploration_memory)}")
    print(f"   Tiles: {sum(len(m['visited_tiles']) for m in brain.exploration_memory.values())}")
    print(f"\nüì¶ Files ready to copy to colleague:")
    print(f"   1. {TAUGHT_MODEL_SAVE.name}")
    print(f"   2. {TAUGHT_TRANSITIONS_FILE.name}")
    print(f"   3. {TAUGHT_EXPLORATION_SAVE.name}")
    print(f"   4. {NAV_TARGETS_PATH.name}")
    print(f"   5. {BATTLE_TRANSITIONS_PATH.name}")
    print(f"{'='*60}")

  Loaded exploration: 9 maps
üéì RESUME: Loaded taught model from timestep 25366
   Utilities: ['UP:0.100', 'DOWN:0.100', 'LEFT:0.100', 'RIGHT:0.100', 'A:0.150', 'B:0.150', 'Start:0.150', 'Select:0.150']
  Loaded exploration: 9 maps
   Exploration: 9 maps loaded
TEACHING MODE ‚Äî Human plays, Brain learns
  Model ‚Üí taught_model_checkpoint.json
  Exploration ‚Üí taught_exploration_memory.json
  Transitions ‚Üí taught_transitions.json (Lua)
  Nav targets ‚Üí taught_nav_targets.json (auto)
  Battle data ‚Üí taught_battle_transitions.json (auto)
Play the game. Ctrl+C to stop, save, and post-process.

Step 25400 | Map 14 | (14,11) UP | Battle:0
  Visited:4 Obs:0 Coverage:0% Interactables:0
  Probed:0 Exhausted:0 Success:0
  Transitions: 1 known
  Utilities: A:0.15 B:0.15 Start:0.15 Select:0.15 UP:0.10 DOWN:0.10 LEFT:0.10 RIGHT:0.10

Step 25500 | Map 14 | (14,11) UP | Battle:0
  Visited:4 Obs:0 Coverage:0% Interactables:0
  Probed:0 Exhausted:0 Success:0
  Transitions: 1 known
  Utilities