In [1]:
# ============================================================================
# CELL 1: State Management & Utilities (CONSISTENT VERSION)
# ============================================================================

from pathlib import Path
import json
import numpy as np
import time
from collections import deque

# Directory that holds every file exchanged with the game-side script.
# NOTE(review): absolute, machine-specific path — the notebook will not find
# its files on another machine; consider making this configurable.
BASE_PATH = Path("C:/Users/natmaw/Documents/Boston Stuff/CS 5100 Foundations of AI/PokeAI")
ACTION_FILE = BASE_PATH / "action.json"        # written by write_action()
STATE_FILE = BASE_PATH / "game_state.json"     # read by read_game_state()
INPUT_FILE = BASE_PATH / "input_cache.txt"     # not referenced in the visible cells
MODEL_FILE = BASE_PATH / "model_checkpoint.json"  # not referenced in the visible cells
TAUGHT_TRANSITIONS_FILE = BASE_PATH / "taught_transitions.json"

# Vector sizes used when padding/truncating what is read from STATE_FILE.
EXPECTED_STATE_DIM = 6   # length of the context state vector
PALETTE_DIM = 768        # length of the palette vector
TILE_DIM = 600           # length of the tile vector
# 8 derived features + tiles + palette = 8 + 600 + 768 = 1376
LEARNING_STATE_DIM = 8 + TILE_DIM + PALETTE_DIM  # 1376

# Action code mapping (from Lua short codes): one-letter code -> button name.
ACTION_MAP = {
    'U': 'UP', 'D': 'DOWN', 'L': 'LEFT', 'R': 'RIGHT',
    'A': 'A', 'B': 'B', 'S': 'Start', 'E': 'Select'
}

def normalize_game_state(raw_state):
    """Return a float array with the first six context fields rescaled.

    Inputs shorter than six entries are zero-padded first. Per field:
    indices 0 and 1 are divided by 255, index 2 is clipped to [0, 255],
    indices 3 and 4 become 1.0 when positive (else 0.0), and index 5 is
    truncated to int and taken modulo 4. Entries past index 5 pass through
    unchanged.
    """
    values = list(raw_state)
    # Multiplying a list by a negative count yields [], so this only pads.
    values.extend([0] * (6 - len(values)))

    result = np.array(values, dtype=float)
    pos_x, pos_y, map_id, battle_flag, menu_flag, facing = values[:6]

    result[0] = pos_x / 255.0
    result[1] = pos_y / 255.0
    result[2] = np.clip(map_id, 0, 255)
    result[3] = 1.0 if battle_flag > 0 else 0.0
    result[4] = 1.0 if menu_flag > 0 else 0.0
    result[5] = int(facing) % 4
    return result

def compute_derived_features(current, prev):
    """Build the 8 temporal features that compare two context states.

    Order of the returned array: x delta, y delta, map changed, battle
    started, battle ended, menu opened, menu closed, direction changed.
    Returns eight zeros when there is no previous state.
    """
    if prev is None:
        return np.zeros(8)

    def rose(idx):
        # Flag field went up between the two snapshots.
        return 1.0 if current[idx] > prev[idx] else 0.0

    def fell(idx):
        # Flag field went down between the two snapshots.
        return 1.0 if current[idx] < prev[idx] else 0.0

    features = [
        current[0] - prev[0],
        current[1] - prev[1],
        1.0 if abs(current[2] - prev[2]) > 0.5 else 0.0,
        rose(3),
        fell(3),
        rose(4),
        fell(4),
        1.0 if current[5] != prev[5] else 0.0,
    ]
    return np.array(features)

def build_learning_state(derived, palette, tiles, in_battle):
    """
    Assemble the fixed-size learning vector: derived(8) + tiles + palette.

    Any component whose length is not the expected one is replaced by zeros,
    and the tile component is zeroed whenever ``in_battle`` exceeds 0.5.
    Gaussian noise scaled by 0.0001 is added to the concatenated vector.
    """
    segments = []
    for vector, expected_len in ((derived, 8), (tiles, TILE_DIM), (palette, PALETTE_DIM)):
        segments.append(vector if len(vector) == expected_len else np.zeros(expected_len))

    if in_battle > 0.5:
        # Tiles are blanked during battle.
        segments[1] = np.zeros(TILE_DIM)

    state = np.concatenate(segments)
    jitter = np.random.randn(len(state)) * 0.0001
    return state + jitter

def read_game_state():
    """
    Load the latest game snapshot from STATE_FILE.

    Returns a 5-tuple:
        (context_state, palette_state, tile_state, dead, raw_position)

    The palette and tile vectors are zero-padded or truncated to PALETTE_DIM
    and TILE_DIM. When the file is missing, or anything goes wrong while
    reading or parsing it, an all-zero snapshot with dead=False and position
    (0, 0) is returned instead of raising.
    """
    def empty_snapshot():
        return (np.zeros(EXPECTED_STATE_DIM), np.zeros(PALETTE_DIM),
                np.zeros(TILE_DIM), False, (0, 0))

    def as_fixed_vector(values, size):
        # Empty/missing input becomes zeros; otherwise pad or cut to `size`.
        vector = np.array(values, dtype=float) if values else np.zeros(size)
        length = len(vector)
        if length < size:
            return np.pad(vector, (0, size - length))
        if length > size:
            return vector[:size]
        return vector

    if not STATE_FILE.exists():
        return empty_snapshot()

    try:
        with open(STATE_FILE, 'r') as f:
            data = json.load(f)

        # Short keys take precedence over their long-form equivalents.
        raw = data.get('s', data.get('state', [0, 0, 0, 0, 0, 0]))
        palette_raw = data.get('p', data.get('palette', []))
        tiles_raw = data.get('t', data.get('tiles', []))
        dead = bool(data.get('dead', False))

        raw_position = (
            int(raw[0]) if len(raw) > 0 else 0,
            int(raw[1]) if len(raw) > 1 else 0,
        )

        context_state = normalize_game_state(np.array(raw, dtype=float))
        palette_state = as_fixed_vector(palette_raw, PALETTE_DIM)
        tile_state = as_fixed_vector(tiles_raw, TILE_DIM)

        return context_state, palette_state, tile_state, dead, raw_position

    except Exception:
        # Any read/parse failure yields the zero snapshot rather than an error.
        return empty_snapshot()

def write_action(action_name):
    """Serialize the chosen action to ACTION_FILE as ``{"action": <name>}``.

    A truthy name is upper-cased before writing; a falsy one is written as
    given. Failures while writing are reported with a printed message and
    are not raised.
    """
    outgoing = action_name.upper() if action_name else action_name
    try:
        with open(ACTION_FILE, "w") as handle:
            json.dump({"action": outgoing}, handle)
            handle.flush()
    except Exception as e:
        print(f"[ERROR] Failed to write action: {e}")

In [2]:
# ============================================================================
# CELL 2: Perceptron Classes (FIXED - Shape Safety)
# ============================================================================

class Perceptron:
    """Linear unit with lazily sized weights and two eligibility traces.

    ``kind`` selects the kind-specific behaviour implemented here:
    "action" units maintain a ``utility`` value, "entity" units maintain a
    ``familiarity`` value that damps their activation. Any other kind only
    gets the plain linear prediction and weight update.
    """

    def __init__(self, kind, action=None, group=None, entity_type=None):
        self.kind = kind
        self.action = action
        self.group = group
        self.entity_type = entity_type

        self.utility = 1.0
        self.weights = None  # allocated on first use by ensure_weights()

        self.eligibility_fast = 0.0
        self.eligibility_slow = 0.0

        self.familiarity = 0.0
        self.activation_history = deque(maxlen=10)

        self.learning_rate = 0.01
        self.prediction_errors = deque(maxlen=50)

    def ensure_weights(self, dim):
        """Make ``self.weights`` a vector of length ``dim``.

        New vectors are small random values. When the length changes, the
        overlapping prefix of the old weights is carried over.
        """
        if self.weights is None:
            self.weights = np.random.randn(dim) * 0.001
        elif len(self.weights) != dim:
            previous = self.weights
            self.weights = np.random.randn(dim) * 0.001
            keep = min(len(previous), dim)
            self.weights[:keep] = previous[:keep]

    def predict(self, state):
        """Return the activation for ``state`` (any 1-D array-like).

        For "entity" units the dot product is scaled by a factor that shrinks
        as ``familiarity`` grows, and the absolute raw activation is appended
        to ``activation_history``.
        """
        state = np.asarray(state)
        self.ensure_weights(len(state))
        raw_activation = np.dot(self.weights, state)

        if self.kind != "entity":
            return raw_activation

        novelty_factor = 1.0 / (1.0 + np.sqrt(self.familiarity * 0.5))
        self.activation_history.append(abs(raw_activation))
        return raw_activation * novelty_factor

    def adapt_learning_rate(self):
        """Nudge the learning rate once 50 prediction errors are recorded.

        Mean error below 0.1 decays the rate (floor 0.001); above 0.5 grows
        it (ceiling 0.05).
        """
        if len(self.prediction_errors) < 50:
            return
        avg_error = np.mean(self.prediction_errors)
        if avg_error < 0.1:
            self.learning_rate = max(0.001, self.learning_rate * 0.99)
        elif avg_error > 0.5:
            self.learning_rate = min(0.05, self.learning_rate * 1.01)

    def update(self, state, error, gamma_fast=0.5, gamma_slow=0.95, stagnation=0.0):
        """Apply one weight update driven by ``error`` on ``state``.

        Args:
            state: 1-D array-like input vector (lists and tuples are accepted).
            error: scalar error signal that scales the update.
            gamma_fast, gamma_slow: decay factors of the two eligibility traces.
            stagnation: for "action" units, values above 0.5 shrink utility.
        """
        # Coerce up front: multiplying a Python float by a plain list raises
        # TypeError, so list/tuple states would otherwise break the update.
        state = np.asarray(state)
        # ensure_weights() guarantees len(weights) == len(state); no further
        # shape check is needed afterwards.
        self.ensure_weights(len(state))

        self.eligibility_fast = gamma_fast * self.eligibility_fast + 1.0
        self.eligibility_slow = gamma_slow * self.eligibility_slow + 1.0

        self.adapt_learning_rate()

        # The update is split 70/30 between the fast and slow traces.
        fast_update = 0.7 * self.learning_rate * error * state * self.eligibility_fast
        slow_update = 0.3 * self.learning_rate * error * state * self.eligibility_slow
        self.weights += fast_update + slow_update

        if self.kind == "action":
            if error > 0.01:
                if stagnation > 0.5:
                    self.utility *= 0.97
                elif error > 0.2:
                    self.utility = min(self.utility * 1.02, 2.0)
                else:
                    self.utility *= 0.995

            # "move" actions keep a higher utility floor than the others.
            floor = 0.1 if self.group == "move" else 0.01
            self.utility = np.clip(self.utility, floor, 2.0)

        if self.kind == "entity":
            if len(self.activation_history) > 0:
                recent_avg = np.mean(self.activation_history)
                if recent_avg > 0.1:
                    self.familiarity += 0.03

            # NOTE(review): this records |prediction - error|, i.e. the
            # post-update prediction compared against the error signal itself;
            # confirm that is the intended quantity. predict() also appends to
            # activation_history as a side effect.
            prediction = self.predict(state)
            self.prediction_errors.append(abs(prediction - error))


class ControlSwapPerceptron(Perceptron):
    """Perceptron of kind "control_swap" that scores whether to swap control mode."""

    def __init__(self):
        super().__init__(kind="control_swap")
        self.swap_history = deque(maxlen=100)
        self.confidence = 0.0

    def should_swap(self, state, movement_stagnation):
        """Return ``(swap?, |score|)`` for the given state.

        The score blends the learned linear score (70%) with a tanh-squashed
        stagnation term (30%); a swap is suggested when it exceeds 0.5.
        Without weights the answer is always ``(False, 0.0)``.
        """
        if self.weights is None:
            return False, 0.0

        self.ensure_weights(len(state))
        learned_part = np.dot(self.weights, state) * 0.7
        stagnation_part = np.tanh(movement_stagnation / 5.0) * 0.3
        score = learned_part + stagnation_part

        return score > 0.5, abs(score)

    def record_swap_outcome(self, state, swapped, novelty_gained):
        """Log a swap decision and refresh ``confidence``.

        Once at least 20 outcomes exist, confidence is the share of the last
        20 entries that were swaps with novelty above 0.2. ``state`` is not
        used.
        """
        self.swap_history.append((swapped, novelty_gained))
        if len(self.swap_history) < 20:
            return

        window = list(self.swap_history)[-20:]
        wins = len([1 for did_swap, gain in window if did_swap and gain > 0.2])
        self.confidence = wins / 20.0

In [3]:
# ============================================================================
# CELL 3: Brain Class - CONSISTENT WITH COLLEAGUE'S VERSION
# ============================================================================
# Includes:
# - Markov imitation learning system
# - Taught transitions loading
# - Model checkpoint with markov_stats
# - Consistent data structures
# ============================================================================

import gc

# === MARKOV SYSTEM CONSTANTS ===
# None of these constants is referenced in the visible part of the notebook;
# their consumers are presumably further down in Brain — TODO confirm.
MARKOV_FAMILIARITY_THRESHOLD = 0.6
# The three weights below sum to 1.0 (0.50 + 0.30 + 0.20).
MARKOV_IMMEDIATE_WEIGHT = 0.50
MARKOV_SEQUENTIAL_WEIGHT = 0.30
MARKOV_PARTIAL_WEIGHT = 0.20
POSITION_TOLERANCE = 2

# NOTE(review): same name and same value as the definition in Cell 1; this
# re-declaration is redundant.
TAUGHT_TRANSITIONS_FILE = BASE_PATH / "taught_transitions.json"

class Brain:
    def __init__(self):
        """Set up all agent state, tunable thresholds and bookkeeping containers.

        Side effects: calls load_exploration_memory(), which reads
        EXPLORATION_MEMORY_FILE from disk (and prints a status line), and
        creates a ControlSwapPerceptron.
        """
        # Every registered perceptron (see add(), actions(), entities()).
        self.perceptrons = []
        
        # State tracking (reduced for memory): bounded histories.
        self.prev_learning_states = deque(maxlen=10)
        self.prev_context_states = deque(maxlen=5)
        self.last_positions = deque(maxlen=15)
        self.action_history = deque(maxlen=30)
        
        self.control_mode = "move"
        self.timestep = 0
        self.last_action = None
        self.last_direction = 0
        
        self.MOVE_UTILITY_FLOOR = 0.05
        self.INTERACT_UTILITY_FLOOR = 0.15
        
        # === EXPLORATION MEMORY ===
        # NOTE(review): save_exploration_memory() writes to this same
        # "taught_..." file — confirm that overwriting it is intended.
        self.EXPLORATION_MEMORY_FILE = BASE_PATH / "taught_exploration_memory.json"
        self.exploration_memory = {}  # map_id -> per-map memory dict
        self.current_map_id = None
        self.SAVE_INTERVAL = 100
        self.MAX_MAPS_IN_MEMORY = 20  # cleanup_memory() trims above this
        
        # Direction mapping: integer facing codes <-> names / actions / deltas.
        # DIRECTION_NAMES and INT_TO_ACTION currently hold identical contents.
        self.DIRECTION_NAMES = {0: "DOWN", 1: "UP", 2: "LEFT", 3: "RIGHT"}
        self.DIRECTION_TO_INT = {"DOWN": 0, "UP": 1, "LEFT": 2, "RIGHT": 3}
        self.INT_TO_ACTION = {0: "DOWN", 1: "UP", 2: "LEFT", 3: "RIGHT"}
        
        # (dx, dy) per direction; DOWN is +y and UP is -y.
        self.DIRECTION_DELTAS_INT = {0: (0, 1), 1: (0, -1), 2: (-1, 0), 3: (1, 0)}
        self.ACTION_DELTAS = {"UP": (0, -1), "DOWN": (0, 1), "LEFT": (-1, 0), "RIGHT": (1, 0)}
        self.DELTA_TO_DIRECTION = {(0, 1): 0, (0, -1): 1, (-1, 0): 2, (1, 0): 3}
        
        # Reads EXPLORATION_MEMORY_FILE; must run after the attributes above exist.
        self.load_exploration_memory()
        
        # === ACTION EXECUTION ===
        # Used by set_pending_action() / confirm_action_executed().
        self.pending_action = None
        self.pending_action_frames = 0
        self.ACTION_CONFIRM_FRAMES = 3  # frames after which a pending action is confirmed regardless
        self.last_confirmed_action = None
        
        # === TILE INTERACTION ===
        self.INTERACTION_VERIFY_FRAMES = 8
        self.MIN_SUCCESS_RATE_THRESHOLD = 0.1  # used by should_interact_at_tile()
        self.pending_interaction_verify = None
        self.interaction_verify_countdown = 0
        
        # === MENU TRAP ===
        # Used by update_menu_trap_tracking() / reset_menu_trap_boost().
        self.menu_trap_frames = 0
        self.menu_trap_b_boost = 1.0
        self.menu_trap_position = None
        self.B_BOOST_INCREMENT = 0.15
        self.B_BOOST_MAX = 3.0
        self.MENU_TRAP_THRESHOLD = 5
        self.original_b_utility = None
        
        # === MODE SWAPPING ===
        self.DEFAULT_MOVE_TO_INTERACT_THRESHOLD = 15
        self.DEFAULT_INTERACT_TO_MOVE_THRESHOLD = 25
        self.move_to_interact_threshold = self.DEFAULT_MOVE_TO_INTERACT_THRESHOLD
        self.interact_to_move_threshold = self.DEFAULT_INTERACT_TO_MOVE_THRESHOLD
        self.THRESHOLD_INCREMENT = 15
        self.MAX_THRESHOLD = 150
        self.frames_in_current_mode = 0
        self.swap_chain_count = 0
        self.position_at_mode_swap = None
        self.last_map_id = None
        self.last_battle_state = None
        
        # === STAGNATION ===
        self.STATE_STAGNATION_THRESHOLD = 20
        self.state_stagnation_count = 0
        self.last_context_state_hash = None  # compared in update_menu_trap_tracking()
        self.stagnation_initiator_action = None
        self.STAGNATION_INITIATOR_PENALTY = 0.7
        self.unproductive_swap_count = 0
        self.UNPRODUCTIVE_SWAP_THRESHOLD = 3
        
        # === BOTH MODE ===
        self.BOTH_MODE_STAGNATION_THRESHOLD = 35
        self.BOTH_MODE_SWAP_THRESHOLD = 5
        self.last_direction_for_progress = None
        self.direction_change_counts_as_progress = True
        
        # === NOVELTY ===
        self.UNVISITED_TILE_BONUS = 1.5
        self.OBSTRUCTION_PENALTY = 0.25
        
        # === TRANSITIONS ===
        self.TRANSITION_ATTRACTION_WEIGHT = 0.6
        self.TEMP_DEBT_ACCUMULATION = 0.5
        self.TEMP_DEBT_DECAY = 0.02
        self.TEMP_DEBT_MAX = 15.0
        
        # === DEBT ===
        self.MAX_MAP_DEBT = 10.0
        self.MAX_LOCATION_DEBT = 5.0
        self.DEBT_DECAY_RATE = 0.005
        
        # === BANS ===
        self.transition_bans = {}  # keyed by map_id (see is_transition_banned)
        self.BAN_VICINITY_RADIUS = 3
        self.BAN_COVERAGE_LIFT_THRESHOLD = 0.6
        self.BAN_TIMEOUT_STEPS = 300
        
        # Multi-scale memory
        self.visited_maps = {}
        self.map_novelty_debt = {}
        self.location_memory = {}
        self.location_novelty = {}
        self.action_execution_count = {}  # action name -> times executed
        self.MAX_LOCATIONS = 500  # cleanup_memory() trims location_memory above this
        
        self.swap_perceptron = ControlSwapPerceptron()
        
        # Error history (reduced)
        self.error_history = deque(maxlen=100)
        self.numeric_error_history = deque(maxlen=100)
        self.visual_error_history = deque(maxlen=100)
        
        # Cache that add() and cleanup_memory() invalidate.
        self._entity_norms_cache = {}
        self._cache_valid = False
        self.innate_entities_spawned = False
        
        # === REPETITION ===
        self.consecutive_action_count = 0
        self.current_repeated_action = None
        self.LEARNING_SLOWDOWN_START = 3
        self.LEARNING_SLOWDOWN_MAX = 10
        self.PENALTY_THRESHOLD = 12
        self.HARD_RESET_THRESHOLD = 18
        
        # === PATTERN ===
        self.PATTERN_CHECK_WINDOW = 30
        self.PATTERN_MIN_REPEATS = 3
        self.PATTERN_MAX_LENGTH = 10
        self.detected_pattern = None
        self.pattern_repeat_count = 0

        # === PROBE CACHE ===
        # Single-entry memo used by get_best_probe_action().
        self._cached_probe_action = None
        self._cached_probe_dir = None
        self._probe_cache_position = None
        
        # === MARKOV SYSTEM (NEW) ===
        self.taught_batches = []
        self.taught_transitions = []  # Flat list of all frames
        self.taught_metadata = {}
        self.markov_action_count = 0
        self.curiosity_action_count = 0
        self.last_markov_score = 0.0
        self.last_markov_action = None
        self.recent_actions_buffer = deque(maxlen=8)  # fed by record_action_execution()

    # =========================================================================
    # CORE METHODS
    # =========================================================================
    
    def add(self, p):
        self.perceptrons.append(p)
        self._cache_valid = False

    def actions(self):
        return [p for p in self.perceptrons if p.kind == "action"]

    def entities(self):
        return [p for p in self.perceptrons if p.kind == "entity"]

    def get_location_key(self, x, y, map_id, bin_size=5):
        return (int(map_id), int(x // bin_size) * bin_size, int(y // bin_size) * bin_size)

    def is_near_map_edge(self, x, y):
        return x < 10 or x > 245 or y < 10 or y > 245

    def record_action_execution(self, action_name):
        if action_name:
            self.action_execution_count[action_name] = self.action_execution_count.get(action_name, 0) + 1
            self.recent_actions_buffer.append(action_name)

    def get_position_stagnation(self):
        if len(self.last_positions) < 2:
            return 0
        current_pos = self.last_positions[-1]
        return sum(1 for pos in reversed(list(self.last_positions)[:-1]) if pos == current_pos)

    def get_group_weight(self, group):
        return sum(a.utility for a in self.actions() if a.group == group)

    def log_state(self, learning_state, context_state):
        self.prev_learning_states.append(learning_state)
        self.prev_context_states.append(context_state)

    def update_position(self, x, y):
        self.last_positions.append((int(x), int(y)))

    # =========================================================================
    # ACTION EXECUTION CONFIRMATION
    # =========================================================================
    
    def set_pending_action(self, action_name):
        self.pending_action = action_name
        self.pending_action_frames = 0
    
    def confirm_action_executed(self, context_state, prev_context_state):
        if self.pending_action is None:
            return True
        self.pending_action_frames += 1
        action_executed = False
        if prev_context_state is not None:
            if self.pending_action in ["UP", "DOWN", "LEFT", "RIGHT"]:
                pos_changed = (context_state[0] != prev_context_state[0] or 
                              context_state[1] != prev_context_state[1])
                dir_changed = context_state[5] != prev_context_state[5]
                action_executed = pos_changed or dir_changed
            elif self.pending_action in ["A", "B", "Start", "Select"]:
                menu_changed = abs(context_state[4] - prev_context_state[4]) > 0.1
                battle_changed = context_state[3] != prev_context_state[3]
                map_changed = context_state[2] != prev_context_state[2]
                action_executed = menu_changed or battle_changed or map_changed
        if action_executed or self.pending_action_frames >= self.ACTION_CONFIRM_FRAMES:
            self.last_confirmed_action = self.pending_action
            self.pending_action = None
            self.pending_action_frames = 0
            return True
        return False
    
    def should_send_new_action(self):
        return self.pending_action is None or self.pending_action_frames >= self.ACTION_CONFIRM_FRAMES

    # =========================================================================
    # MENU TRAP
    # =========================================================================
    
    def update_menu_trap_tracking(self, context_state, action_taken, raw_position=None):
        current_pos = raw_position if raw_position else (round(context_state[0] * 255), round(context_state[1] * 255))
        if self.menu_trap_position is not None and current_pos != self.menu_trap_position:
            self.reset_menu_trap_boost()
            return
        if self.get_context_state_hash(context_state) == self.last_context_state_hash:
            if action_taken in ["A", "B", "Start", "Select"]:
                self.menu_trap_frames += 1
                self.menu_trap_position = current_pos
                if self.menu_trap_frames > self.MENU_TRAP_THRESHOLD:
                    if self.original_b_utility is None:
                        for a in self.actions():
                            if a.action == 'B':
                                self.original_b_utility = a.utility
                                break
                    self.menu_trap_b_boost = min(self.B_BOOST_MAX, self.menu_trap_b_boost + self.B_BOOST_INCREMENT)
        elif current_pos != self.menu_trap_position:
            self.reset_menu_trap_boost()

    def reset_menu_trap_boost(self):
        if self.menu_trap_b_boost > 1.0 and self.original_b_utility is not None:
            for a in self.actions():
                if a.action == 'B':
                    a.utility = self.original_b_utility
                    break
        self.menu_trap_frames = 0
        self.menu_trap_b_boost = 1.0
        self.menu_trap_position = None
        self.original_b_utility = None

    # =========================================================================
    # MEMORY MANAGEMENT
    # =========================================================================
    
    def cleanup_memory(self):
        if len(self.location_memory) > self.MAX_LOCATIONS:
            sorted_locs = sorted(self.location_memory.items(), key=lambda x: x[1], reverse=True)
            self.location_memory = dict(sorted_locs[:self.MAX_LOCATIONS // 2])
            self.location_novelty = {k: v for k, v in self.location_novelty.items() if k in self.location_memory}
        
        if len(self.exploration_memory) > self.MAX_MAPS_IN_MEMORY:
            self.save_exploration_memory()
            sorted_maps = sorted(self.exploration_memory.items(),
                                key=lambda x: x[1].get('last_visited_timestep', 0), reverse=True)
            self.exploration_memory = dict(sorted_maps[:self.MAX_MAPS_IN_MEMORY // 2])
        
        self._entity_norms_cache.clear()
        self._cache_valid = False
        gc.collect()
    
    def get_memory_stats(self):
        stats = {
            'exploration_maps': len(self.exploration_memory),
            'location_memory': len(self.location_memory),
            'error_history': len(self.error_history),
            'perceptrons': len(self.perceptrons),
        }
        total_tiles = sum(len(m.get('visited_tiles', set())) for m in self.exploration_memory.values())
        stats['total_tiles'] = total_tiles
        return stats

    # =========================================================================
    # EXPLORATION MEMORY
    # =========================================================================
    
    def load_exploration_memory(self):
        try:
            if self.EXPLORATION_MEMORY_FILE.exists():
                with open(self.EXPLORATION_MEMORY_FILE, 'r') as f:
                    data = json.load(f)
                    self.exploration_memory = {}
                    for map_key, map_data in data.items():
                        map_id = int(map_key.replace('map_', ''))
                        self.exploration_memory[map_id] = self._deserialize_map_memory(map_data)
                print(f"  Loaded exploration memory: {len(self.exploration_memory)} maps")
            else:
                self.exploration_memory = {}
        except Exception as e:
            print(f"  Error loading exploration memory: {e}")
            self.exploration_memory = {}

    def _deserialize_map_memory(self, map_data):
        tile_interactions = {}
        for tile_key, tile_data in map_data.get('tile_interactions', {}).items():
            tile_interactions[tile_key] = {
                'directions_tried': set(tile_data.get('directions_tried', [])),
                'direction_attempts': {int(k): v for k, v in tile_data.get('direction_attempts', {}).items()},
                'direction_successes': {int(k): v for k, v in tile_data.get('direction_successes', {}).items()},
                'exhausted': tile_data.get('exhausted', False)
            }
        return {
            'visited_tiles': set(tuple(t) for t in map_data.get('visited_tiles', [])),
            'obstructions': set(tuple(t) for t in map_data.get('obstructions', [])),
            'interactable_objects': map_data.get('interactable_objects', []),
            'last_visited_timestep': map_data.get('last_visited_timestep', 0),
            'transitions': map_data.get('transitions', []),
            'temp_debt': map_data.get('temp_debt', 0.0),
            'tile_interactions': tile_interactions
        }

    def save_exploration_memory(self):
        try:
            data = {}
            for map_id, map_data in self.exploration_memory.items():
                data[f'map_{map_id}'] = self._serialize_map_memory(map_data)
            with open(self.EXPLORATION_MEMORY_FILE, 'w') as f:
                json.dump(data, f)
        except Exception as e:
            print(f"  Error saving exploration memory: {e}")

    def _serialize_map_memory(self, map_data):
        serialized_ti = {}
        for tile_key, td in map_data.get('tile_interactions', {}).items():
            serialized_ti[tile_key] = {
                'directions_tried': list(td.get('directions_tried', set())),
                'direction_attempts': {str(k): v for k, v in td.get('direction_attempts', {}).items()},
                'direction_successes': {str(k): v for k, v in td.get('direction_successes', {}).items()},
                'exhausted': td.get('exhausted', False)
            }
        return {
            'visited_tiles': [list(t) for t in map_data['visited_tiles']],
            'obstructions': [list(t) for t in map_data['obstructions']],
            'interactable_objects': map_data['interactable_objects'],
            'last_visited_timestep': map_data['last_visited_timestep'],
            'transitions': map_data.get('transitions', []),
            'temp_debt': map_data.get('temp_debt', 0.0),
            'tile_interactions': serialized_ti
        }

    def merge_taught_exploration(self, filepath):
        """Merge taught exploration memory into current memory."""
        if not Path(filepath).exists():
            print(f"  No taught exploration at {filepath}")
            return
        try:
            with open(filepath, 'r') as f:
                taught_data = json.load(f)
            for map_key, map_data in taught_data.items():
                map_id = int(map_key.replace('map_', ''))
                taught_mem = self._deserialize_map_memory(map_data)
                if map_id in self.exploration_memory:
                    # Merge
                    self.exploration_memory[map_id]['visited_tiles'].update(taught_mem['visited_tiles'])
                    self.exploration_memory[map_id]['obstructions'].update(taught_mem['obstructions'])
                    for obj in taught_mem['interactable_objects']:
                        if obj not in self.exploration_memory[map_id]['interactable_objects']:
                            self.exploration_memory[map_id]['interactable_objects'].append(obj)
                    for trans in taught_mem['transitions']:
                        existing = [t for t in self.exploration_memory[map_id]['transitions'] 
                                   if t['position'] == trans['position'] and t['direction'] == trans['direction']]
                        if not existing:
                            self.exploration_memory[map_id]['transitions'].append(trans)
                else:
                    self.exploration_memory[map_id] = taught_mem
            print(f"  Merged taught exploration: {len(taught_data)} maps")
        except Exception as e:
            print(f"  Error merging taught exploration: {e}")

    def get_current_map_memory(self, map_id):
        if map_id not in self.exploration_memory:
            self.exploration_memory[map_id] = {
                'visited_tiles': set(), 'obstructions': set(), 'interactable_objects': [],
                'last_visited_timestep': self.timestep, 'transitions': [], 'temp_debt': 0.0,
                'tile_interactions': {}
            }
        return self.exploration_memory[map_id]

    def record_visited_tile(self, x, y, map_id):
        memory = self.get_current_map_memory(map_id)
        memory['visited_tiles'].add((int(x), int(y)))
        memory['last_visited_timestep'] = self.timestep

    def record_obstruction(self, x, y, map_id, direction):
        dx, dy = self.DIRECTION_DELTAS_INT.get(direction, (0, 0))
        memory = self.get_current_map_memory(map_id)
        memory['obstructions'].add((int(x + dx), int(y + dy)))

    # =========================================================================
    # TILE INTERACTION
    # =========================================================================
    
    def get_tile_interaction_key(self, x, y):
        return f"{int(x)}_{int(y)}"
    
    def get_tile_interaction_state(self, x, y, map_id):
        memory = self.get_current_map_memory(map_id)
        tile_key = self.get_tile_interaction_key(x, y)
        if tile_key not in memory['tile_interactions']:
            memory['tile_interactions'][tile_key] = {
                'directions_tried': set(),
                'direction_attempts': {0: 0, 1: 0, 2: 0, 3: 0},
                'direction_successes': {0: 0, 1: 0, 2: 0, 3: 0},
                'exhausted': False
            }
        return memory['tile_interactions'][tile_key]
    
    def should_interact_at_tile(self, x, y, map_id):
        tile_state = self.get_tile_interaction_state(x, y, map_id)
        if tile_state['exhausted']:
            return False
        if len(tile_state['directions_tried']) < 4:
            return True
        for d in range(4):
            attempts = tile_state['direction_attempts'].get(d, 0)
            successes = tile_state['direction_successes'].get(d, 0)
            if attempts > 0 and successes / attempts >= self.MIN_SUCCESS_RATE_THRESHOLD:
                return True
        return False
    
    def get_untried_directions(self, x, y, map_id):
        tile_state = self.get_tile_interaction_state(x, y, map_id)
        return [d for d in range(4) if d not in tile_state['directions_tried']]

    def get_best_probe_action(self, raw_x, raw_y, current_map, current_dir):
        cache_key = (raw_x, raw_y, current_map, current_dir)
        if self._probe_cache_position == cache_key:
            return self._cached_probe_action, self._cached_probe_dir
        
        if not self.should_interact_at_tile(raw_x, raw_y, current_map):
            result = (None, None)
        else:
            untried = self.get_untried_directions(raw_x, raw_y, current_map)
            if not untried:
                best_dir = self.get_best_interaction_direction(raw_x, raw_y, current_map)
                if best_dir is not None:
                    result = ('A', current_dir) if current_dir == best_dir else (self.INT_TO_ACTION[best_dir], best_dir)
                else:
                    result = (None, None)
            elif current_dir in untried:
                result = ('A', current_dir)
            else:
                target_dir = untried[0]
                result = (self.INT_TO_ACTION[target_dir], target_dir)
        
        self._probe_cache_position = cache_key
        self._cached_probe_action, self._cached_probe_dir = result
        return result

    def get_best_interaction_direction(self, x, y, map_id):
        """Return the direction (0-3) with the highest successes/attempts ratio.

        Only directions with at least one attempt are considered, and a ratio
        must be strictly greater than 0.0 to win, so None is returned when no
        attempted direction has ever succeeded. Ties keep the lowest direction.
        """
        record = self.get_tile_interaction_state(x, y, map_id)
        winner = None
        top_ratio = 0.0
        for direction in range(4):
            tries = record['direction_attempts'].get(direction, 0)
            if tries > 0:
                ratio = record['direction_successes'].get(direction, 0) / tries
                if ratio > top_ratio:
                    winner, top_ratio = direction, ratio
        return winner

    def get_tile_interaction_stats(self, map_id):
        """Summarise the 'tile_interactions' records stored for a map.

        Returns a dict with:
          'probed'       - number of tile records,
          'exhausted'    - records whose 'exhausted' flag is truthy,
          'with_success' - records with a positive success count in any of
                           the four directions.
        """
        probes = self.get_current_map_memory(map_id).get('tile_interactions', {})
        exhausted_total = 0
        success_total = 0
        for entry in probes.values():
            if entry.get('exhausted', False):
                exhausted_total += 1
            if any(entry.get('direction_successes', {}).get(direction, 0) > 0 for direction in range(4)):
                success_total += 1
        return {
            'probed': len(probes),
            'exhausted': exhausted_total,
            'with_success': success_total,
        }

    def get_exploration_coverage(self, map_id):
        """Return visited / (visited + obstructions) for a map.

        Yields 0.0 when nothing has been visited yet or when fewer than ten
        tiles (visited plus obstructions) are known for the map.
        """
        map_memory = self.get_current_map_memory(map_id)
        n_visited = len(map_memory['visited_tiles'])
        n_blocked = len(map_memory['obstructions'])
        n_known = n_visited + n_blocked
        if n_visited == 0 or n_known < 10:
            return 0.0
        return n_visited / n_known

    # =========================================================================
    # TRANSITIONS & DEBT
    # =========================================================================
    
    def record_transition(self, from_pos, from_map, to_map, direction, action_type):
        """Register (or refresh) a map transition in the memory of the map being left.

        A transition is identified by its source position and facing direction.
        A known transition only has its 'use_count' bumped and 'last_used' set
        to the current timestep; a new one is appended and announced on stdout.
        """
        known = self.get_current_map_memory(from_map)['transitions']
        origin = list(from_pos)
        existing = next(
            (entry for entry in known
             if entry['position'] == origin and entry['direction'] == direction),
            None,
        )
        if existing is not None:
            existing['use_count'] += 1
            existing['last_used'] = self.timestep
            return

        fresh = {
            'position': origin, 'direction': direction, 'action': action_type,
            'destination_map': to_map, 'use_count': 1, 'last_used': self.timestep
        }
        known.append(fresh)
        print(f"  üö™ TRANSITION FOUND: Map {from_map} ({from_pos}) ‚Üí Map {to_map}")

    def is_transition_banned(self, map_id, position, direction):
        """Tell whether leaving `map_id` from `position` facing `direction` is banned.

        A ban entry stores one 'banned_tile' and one 'banned_direction' per map.
        List positions (on either side) are converted to tuples before comparing.
        """
        bans = self.transition_bans
        if map_id in bans:
            entry = bans[map_id]
            tile = entry['banned_tile']
            if isinstance(tile, list):
                tile = tuple(tile)
            if isinstance(position, list):
                position = tuple(position)
            return True if (position == tile and direction == entry['banned_direction']) else False
        return False

    def get_temp_debt(self, map_id):
        """Return the map's stored 'temp_debt', or 0.0 when none has been recorded."""
        return self.get_current_map_memory(map_id).get('temp_debt', 0.0)

    def accumulate_temp_debt(self, map_id):
        """Raise the map's 'temp_debt' by TEMP_DEBT_ACCUMULATION, capped at TEMP_DEBT_MAX."""
        map_memory = self.get_current_map_memory(map_id)
        raised = map_memory.get('temp_debt', 0.0) + self.TEMP_DEBT_ACCUMULATION
        map_memory['temp_debt'] = min(self.TEMP_DEBT_MAX, raised)

    def decay_all_debts(self):
        """Shrink the novelty debt of every map except the current one.

        Each debt is multiplied by (1 - DEBT_DECAY_RATE); entries that fall
        below 0.1 are removed from map_novelty_debt.
        """
        debts = self.map_novelty_debt
        # Iterate over a snapshot of the keys because entries may be deleted.
        for map_id in tuple(debts.keys()):
            if map_id == self.current_map_id:
                continue
            decayed = debts[map_id] * (1.0 - self.DEBT_DECAY_RATE)
            if decayed < 0.1:
                del debts[map_id]
            else:
                debts[map_id] = decayed

    # =========================================================================
    # STAGNATION & MODE
    # =========================================================================

    def get_context_state_hash(self, context_state):
        """Build a hashable 6-tuple from the first six context-state entries.

        Entries 0, 1 and 4 are rounded to two decimals; entries 2, 3 and 5 are
        truncated to int.
        """
        first, second, third, fourth, fifth, sixth = (context_state[i] for i in range(6))
        return (
            round(first, 2),
            round(second, 2),
            int(third),
            int(fourth),
            round(fifth, 2),
            int(sixth),
        )

    def should_use_both_mode(self):
        """Report whether either stagnation counter exceeds its 'both'-mode threshold.

        True when state_stagnation_count is above BOTH_MODE_STAGNATION_THRESHOLD
        or unproductive_swap_count is above BOTH_MODE_SWAP_THRESHOLD.
        """
        if self.state_stagnation_count > self.BOTH_MODE_STAGNATION_THRESHOLD:
            return True
        return self.unproductive_swap_count > self.BOTH_MODE_SWAP_THRESHOLD

    # =========================================================================
    # EXPLORATION TRACKING
    # =========================================================================
    
    def update_exploration_tracking(self, context_state, prev_context_state, raw_position=None, prev_raw_position=None):
        """Update per-map exploration bookkeeping for the current frame.

        Steps, in order:
          1. On a map change (relative to current_map_id) record the transition,
             if the previous context and raw position are known, and call
             on_map_change.
          2. Store the new current map, mark the tile as visited, accumulate
             the map's temporary debt and remember the facing direction.
          3. Every 300th timestep decay the novelty debts.

        When raw_position is falsy the tile coordinates are derived from the
        first two context entries multiplied by 255.
        """
        map_now = int(context_state[2])
        if raw_position:
            tile_x = raw_position[0]
        else:
            tile_x = int(context_state[0] * 255)
        if raw_position:
            tile_y = raw_position[1]
        else:
            tile_y = int(context_state[1] * 255)

        map_before = self.current_map_id
        if map_before is not None and map_now != map_before:
            have_history = prev_context_state is not None and prev_raw_position is not None
            if have_history:
                facing_before = int(prev_context_state[5])
                how = 'interact' if self.last_action == 'A' else 'walk'
                self.record_transition(prev_raw_position, map_before, map_now, facing_before, how)
            self.on_map_change(map_now)

        self.current_map_id = map_now
        self.record_visited_tile(tile_x, tile_y, map_now)
        self.accumulate_temp_debt(map_now)
        self.last_direction = int(context_state[5])

        due_for_decay = self.timestep % 300 == 0
        if due_for_decay:
            self.decay_all_debts()

    def on_map_change(self, new_map):
        """React to entering a new map.

        Persists exploration memory, resets the control mode to "move" with a
        zeroed frame counter, and prints how many tiles of the new map have
        already been visited.
        """
        self.save_exploration_memory()
        self.control_mode, self.frames_in_current_mode = "move", 0
        visited_total = len(self.get_current_map_memory(new_map)['visited_tiles'])
        print(f"  üó∫Ô∏è MAP CHANGE ‚Üí {new_map}: {visited_total} visited")

    # =========================================================================
    # UTILITY & LEARNING HELPERS
    # =========================================================================
    
    def enforce_utility_floors(self):
        """Clamp every action perceptron's utility to the floor of its group.

        "move" actions use MOVE_UTILITY_FLOOR; every other group uses
        INTERACT_UTILITY_FLOOR.
        """
        for perceptron in self.actions():
            if perceptron.group == "move":
                minimum = self.MOVE_UTILITY_FLOOR
            else:
                minimum = self.INTERACT_UTILITY_FLOOR
            perceptron.utility = max(perceptron.utility, minimum)

    def stagnation_level(self, window=10):
        """Score how little the learning state moved over the last `window` frames.

        Returns 0.0 while fewer than `window` states are stored. Otherwise the
        mean Euclidean distance between consecutive states is squashed with
        1 - tanh(2 * mean), so identical states give 1.0 and large jumps
        approach 0.0.
        """
        history = self.prev_learning_states
        if len(history) < window:
            return 0.0
        tail = list(history)[-window:]
        step_sizes = [np.linalg.norm(later - earlier) for earlier, later in zip(tail, tail[1:])]
        return 1.0 - np.tanh(np.mean(step_sizes) * 2.0)

    def track_consecutive_action(self, action_name):
        """Maintain the streak counter for the most recently repeated action.

        A repeat of current_repeated_action extends the streak by one; any
        other action becomes the new repeated action with a streak of 1.
        """
        is_repeat = action_name == self.current_repeated_action
        if not is_repeat:
            self.current_repeated_action = action_name
            self.consecutive_action_count = 1
            return
        self.consecutive_action_count += 1

    def get_learning_multiplier(self, action_name):
        """Return the learning-rate multiplier for an action based on its repeat streak.

        The multiplier is 1.0 unless `action_name` is the currently repeated
        action and its streak has reached LEARNING_SLOWDOWN_START. From there
        it falls linearly with the streak, bottoming out at 0.05 once the
        streak reaches LEARNING_SLOWDOWN_MAX.
        """
        if action_name != self.current_repeated_action:
            return 1.0
        streak = self.consecutive_action_count
        start = self.LEARNING_SLOWDOWN_START
        if streak < start:
            return 1.0
        span = self.LEARNING_SLOWDOWN_MAX - start
        progress = min(1.0, (streak - start) / span)
        return max(0.05, 1.0 - 0.95 * progress)

    def apply_repetition_penalty(self):
        """Reduce the utility of the action that is being repeated.

        Only the first action perceptron matching current_repeated_action is
        touched. At HARD_RESET_THRESHOLD or more repeats its utility is halved
        and the streak counter is reset to 0; at PENALTY_THRESHOLD or more it
        is scaled by 0.7. The group's utility floor is respected in both cases.
        """
        repeated = self.current_repeated_action
        if repeated is None:
            return
        target = next((p for p in self.actions() if p.action == repeated), None)
        if target is None:
            return

        if target.group == "interact":
            floor = self.INTERACT_UTILITY_FLOOR
        else:
            floor = self.MOVE_UTILITY_FLOOR

        if self.consecutive_action_count >= self.HARD_RESET_THRESHOLD:
            target.utility = max(floor, target.utility * 0.5)
            self.consecutive_action_count = 0
        elif self.consecutive_action_count >= self.PENALTY_THRESHOLD:
            target.utility = max(target.utility * 0.7, floor)

    def apply_pattern_penalty(self):
        """Placeholder hook: currently performs no work and returns None."""
        return None

    def spawn_innate_entities(self, learning_state):
        """Create the four built-in "sense" entity perceptrons, once.

        Each entity gets a zero weight vector as long as `learning_state`, with
        a few leading feature slots pre-set: 0.5 per slot when the entity
        watches two slots, 1.0 when it watches one. Slots beyond the vector
        length are skipped. Does nothing when innate_entities_spawned is set.
        """
        if self.innate_entities_spawned:
            return

        sense_layout = (
            ("sense_menu", [5, 6]),
            ("sense_battle", [3, 4]),
            ("sense_movement", [0, 1]),
            ("sense_map_transition", [2]),
        )
        for sense_name, watched_slots in sense_layout:
            sensor = Perceptron("entity", entity_type=sense_name)
            sensor.ensure_weights(len(learning_state))
            sensor.weights = np.zeros(len(learning_state))
            seed = 0.5 if len(watched_slots) > 1 else 1.0
            for slot in watched_slots:
                if slot < len(sensor.weights):
                    sensor.weights[slot] = seed
            self.add(sensor)

        self.innate_entities_spawned = True

    def compute_multi_modal_error(self, state, next_state):
        """Measure the change between two learning states.

        Returns (weighted, numeric, visual):
          numeric  - sum of absolute differences over the first (up to) 8 entries,
          visual   - Euclidean norm of the difference over entries 8 onwards
                     (0.0 when the states have at most 8 shared entries),
          weighted - per-entry weighted sum of the first 8 differences plus
                     twice the visual term.
        Only the entries both states share are compared.
        """
        shared = min(len(state), len(next_state))
        channel_weights = [0.5, 0.5, 10.0, 5.0, 3.0, 2.0, 1.5, 0.3]

        deltas = []
        for i in range(min(8, shared)):
            deltas.append(abs(next_state[i] - state[i]))

        weighted = sum(delta * w for delta, w in zip(deltas, channel_weights[:len(deltas)]))
        numeric = sum(deltas)

        if shared > 8:
            visual = np.linalg.norm(next_state[8:shared] - state[8:shared])
            weighted += visual * 2.0
        else:
            visual = 0.0
        return weighted, numeric, visual

    # =========================================================================
    # MARKOV IMITATION LEARNING SYSTEM
    # =========================================================================
    
    def load_taught_transitions(self, filepath):
        """Read taught batches from a JSON file and flatten their frames.

        Sets taught_batches (the file's 'batches' list), taught_metadata (its
        'metadata' dict) and taught_transitions (one dict per frame, carrying
        the owning batch's 'batch_type' and 'trigger_action'). A missing file
        is reported and leaves the object untouched; any error while reading
        is printed instead of raised.
        """
        if not Path(filepath).exists():
            print(f"  No taught transitions at {filepath}")
            return
        try:
            with open(filepath, 'r') as handle:
                payload = json.load(handle)
            self.taught_batches = payload.get('batches', [])
            self.taught_metadata = payload.get('metadata', {})

            # Flatten every batch into a single frame list for similarity search.
            flat_frames = self.taught_transitions = []
            for batch in self.taught_batches:
                for frame in batch.get('frames', []):
                    entry = {
                        'batch_type': batch.get('batch_type'),
                        'trigger_action': batch.get('trigger_action'),
                        'state': frame.get('state', {}),
                        'action': frame.get('action'),
                        'recent_actions': frame.get('recent_actions', [])
                    }
                    flat_frames.append(entry)

            print(f"  ‚úÖ Loaded taught transitions: {len(self.taught_batches)} batches, {len(self.taught_transitions)} frames")
        except Exception as e:
            print(f"  ‚ùå Error loading taught transitions: {e}")

    def compute_markov_similarity(self, current_state, current_recent, taught_frame):
        """Score how closely the live situation matches one taught frame.

        Three partial scores are combined with the module-level weights
        MARKOV_IMMEDIATE_WEIGHT, MARKOV_SEQUENTIAL_WEIGHT and MARKOV_PARTIAL_WEIGHT:
          immediate  - same map, nearby position (within POSITION_TOLERANCE on
                       both axes), same facing, battle and menu flags;
                       0.0 when the maps differ,
          sequential - length of the common suffix of the two recent-action
                       lists (at most 8 compared), mapped to tiers,
          partial    - battle/menu flag agreement plus a fixed base credit.
        """
        taught_state = taught_frame.get('state', {})
        taught_recent = taught_frame.get('recent_actions', [])

        def agrees(field):
            # True when both snapshots hold the same value (or both lack the key).
            return current_state.get(field) == taught_state.get(field)

        # --- immediate: exact situation match, only meaningful on the same map ---
        immediate = 0.0
        if agrees('map_id'):
            immediate += 0.25
            dx = abs(current_state.get('x', 0) - taught_state.get('x', 0))
            dy = abs(current_state.get('y', 0) - taught_state.get('y', 0))
            if dx <= POSITION_TOLERANCE and dy <= POSITION_TOLERANCE:
                immediate += 0.35 * (1.0 - (dx + dy) / (2 * POSITION_TOLERANCE + 1))
            for field, credit in (('direction', 0.20), ('in_battle', 0.10), ('in_menu', 0.10)):
                if agrees(field):
                    immediate += credit

        # --- sequential: how many of the latest actions match, newest first ---
        sequential = 0.0
        if current_recent and taught_recent:
            depth = min(8, len(current_recent), len(taught_recent))
            streak = 0
            for back in range(1, depth + 1):
                if current_recent[-back] == taught_recent[-back]:
                    streak += 1
                    continue
                break
            if streak >= 8:
                sequential = 1.0
            elif streak >= 5:
                sequential = 0.7
            elif streak >= 3:
                sequential = 0.4
            else:
                sequential = streak / 8.0 * 0.3

        # --- partial: coarse mode agreement, independent of the map ---
        partial = 0.0
        if agrees('in_battle'):
            partial += 0.4
        if agrees('in_menu'):
            partial += 0.4
        partial += 0.2  # base credit granted to every frame

        return (immediate * MARKOV_IMMEDIATE_WEIGHT +
                sequential * MARKOV_SEQUENTIAL_WEIGHT +
                partial * MARKOV_PARTIAL_WEIGHT)

    def get_markov_action(self, current_state, raw_position):
        """Look up the taught frame most similar to the current situation.

        `current_state` may be a dict or an indexable sequence laid out as
        [x, y, map_id, in_battle, menu_flag, direction]; x and y always come
        from `raw_position`. Returns (action, score) where action is the best
        frame's action if the score reaches MARKOV_FAMILIARITY_THRESHOLD and
        None otherwise. Also stores the result in last_markov_score and
        last_markov_action. With no taught frames it returns (None, 0.0)
        without touching those attributes.
        """
        if not self.taught_transitions:
            return None, 0.0

        # Normalise both accepted input formats into one snapshot dict.
        if isinstance(current_state, dict):
            snapshot = {
                'map_id': current_state.get('map', current_state.get('map_id', 0)),
                'x': raw_position[0],
                'y': raw_position[1],
                'direction': current_state.get('direction', 0),
                'in_battle': current_state.get('in_battle', 0),
                'in_menu': current_state.get('in_menu', current_state.get('menu_flag', 0))
            }
        else:
            snapshot = {
                'map_id': int(current_state[2]),
                'x': raw_position[0],
                'y': raw_position[1],
                'direction': int(current_state[5]),
                'in_battle': int(current_state[3]),
                'in_menu': int(current_state[4])
            }

        history = list(self.recent_actions_buffer)

        # Linear scan; the first frame with the strictly highest score wins.
        top_action, top_score = None, 0.0
        for candidate in self.taught_transitions:
            similarity = self.compute_markov_similarity(snapshot, history, candidate)
            if similarity > top_score:
                top_score, top_action = similarity, candidate.get('action')

        familiar = top_score >= MARKOV_FAMILIARITY_THRESHOLD
        self.last_markov_score = top_score
        self.last_markov_action = top_action if familiar else None

        if familiar:
            return top_action, top_score
        return None, top_score

    # =========================================================================
    # MAIN LEARN METHOD
    # =========================================================================
    
    def learn(self, learning_state, next_learning_state, context_state, next_context_state, dead=False,
            raw_position=None, next_raw_position=None):
        """Run one learning step from an observed state change.

        Args:
            learning_state: feature vector before the action.
            next_learning_state: feature vector after the action.
            context_state: context vector for the current frame; index 2 is read
                as the map id and indices 0/1 as normalised x/y (scaled by 255
                when no raw position is supplied).
            next_context_state: accepted for interface compatibility; not read
                in this method.
            dead: accepted for interface compatibility; not read in this method.
            raw_position: (x, y) tile position for the current frame, or None.
            next_raw_position: accepted for interface compatibility; not read
                in this method.

        Side effects, in order: exploration tracking, error histories, visit
        counters and novelty debt, perceptron updates, repetition/pattern
        penalties and utility floors, periodic memory cleanup and persistence,
        and finally `last_action` is appended to `action_history`.
        `self.timestep` is read but not advanced here.
        """
        
        # Bring both state vectors to a common length by padding the shorter
        # one at the end (np.pad with its default mode).
        if len(learning_state) != len(next_learning_state):
            max_dim = max(len(learning_state), len(next_learning_state))
            learning_state = np.pad(learning_state, (0, max(0, max_dim - len(learning_state))))
            next_learning_state = np.pad(next_learning_state, (0, max(0, max_dim - len(next_learning_state))))
        
        # The built-in "sense" entities are created lazily on the first call.
        if not self.innate_entities_spawned:
            self.spawn_innate_entities(learning_state)
        
        # The previous raw position is cached on the instance between calls so a
        # map change can be attributed to the tile the agent came from.
        prev_context = self.prev_context_states[-1] if self.prev_context_states else None
        prev_raw = getattr(self, '_last_raw_position', None)
        self.update_exploration_tracking(context_state, prev_context, raw_position, prev_raw)
        self._last_raw_position = raw_position
        
        # The histories record the error BEFORE the familiarity scaling applied
        # further down.
        weighted_error, numeric_error, visual_error = self.compute_multi_modal_error(learning_state, next_learning_state)
        self.error_history.append(weighted_error)
        self.numeric_error_history.append(numeric_error)
        self.visual_error_history.append(visual_error)
        
        current_map = int(context_state[2])
        # Location key from the raw tile position, falling back to the scaled
        # (un-truncated) context coordinates when raw_position is falsy.
        loc = self.get_location_key(*(raw_position if raw_position else (context_state[0]*255, context_state[1]*255)), current_map)
        
        self.visited_maps[current_map] = self.visited_maps.get(current_map, 0) + 1
        
        # NOTE: once location_memory holds MAX_LOCATIONS keys, counts for
        # already-known locations stop increasing too, not just new insertions.
        if len(self.location_memory) < self.MAX_LOCATIONS:
            self.location_memory[loc] = self.location_memory.get(loc, 0) + 1
        
        # After more than 10 recorded visits the map accrues novelty debt that
        # grows with the visit count, capped at MAX_MAP_DEBT.
        if self.visited_maps[current_map] > 10:
            self.map_novelty_debt[current_map] = min(self.MAX_MAP_DEBT, 
                self.map_novelty_debt.get(current_map, 0.0) + 0.05 * (self.visited_maps[current_map] - 10))
        
        # Damp the learning signal for heavily visited maps and locations; the
        # two factors multiply when both conditions hold.
        if self.visited_maps[current_map] > 30:
            weighted_error *= 0.5
        if self.location_memory.get(loc, 0) > 25:
            weighted_error *= 0.7
        
        stagnation = self.stagnation_level()
        learning_mult = self.get_learning_multiplier(self.last_action) if self.last_action else 1.0
        
        # Every perceptron is updated; only the action perceptron matching
        # last_action has its error scaled by the repetition multiplier.
        for p in self.perceptrons:
            mult = learning_mult if (p.kind == "action" and p.action == self.last_action) else 1.0
            p.update(learning_state, weighted_error * mult, stagnation=stagnation)
        
        # Penalties run before the floors so utilities never end below the floor.
        self.apply_repetition_penalty()
        self.apply_pattern_penalty()
        self.enforce_utility_floors()
        
        # Periodic housekeeping; both conditions are also true at timestep 0.
        if self.timestep % 1000 == 0:
            self.cleanup_memory()
        
        if self.timestep % self.SAVE_INTERVAL == 0:
            self.save_exploration_memory()
        
        self.action_history.append(self.last_action)

    # =========================================================================
    # MODEL CHECKPOINT - CONSISTENT FORMAT
    # =========================================================================

    def save_model_checkpoint(self, filepath=None):
        """Write the model state to a JSON checkpoint (MODEL_FILE by default).

        The file holds the timestep, all action and entity perceptrons (weights
        stored sparsely as [index, value] pairs for entries with magnitude above
        0.01, values rounded to 4 decimals), the debt-tracking dictionaries with
        stringified keys, the control mode and the Markov/curiosity counters.
        Write errors are printed, not raised.
        """
        target = MODEL_FILE if filepath is None else filepath

        def rounded(value):
            # All floats in the checkpoint are stored with 4 decimals.
            return round(float(value), 4)

        def sparse(weights):
            # Keep only weights that are meaningfully non-zero.
            if weights is None:
                return []
            keep = np.where(np.abs(weights) > 0.01)[0]
            return [[int(i), rounded(weights[i])] for i in keep]

        def dimension(weights):
            # Perceptrons without weights are recorded with the default width.
            return int(len(weights)) if weights is not None else 1376

        actions_data = [
            {
                "action": a.action,
                "group": a.group,
                "utility": rounded(a.utility),
                "weights_shape": dimension(a.weights),
                "weights_nonzero": sparse(a.weights),
                "learning_rate": rounded(a.learning_rate),
                "familiarity": rounded(a.familiarity)
            }
            for a in self.actions()
        ]

        entities_data = [
            {
                "entity_type": e.entity_type,
                "utility": rounded(e.utility),
                "weights_shape": dimension(e.weights),
                "weights_nonzero": sparse(e.weights),
                "familiarity": rounded(e.familiarity)
            }
            for e in self.entities()
        ]

        # JSON object keys must be strings, so every key is stringified.
        debt_tracking = {
            "map_novelty_debt": {str(k): rounded(v) for k, v in self.map_novelty_debt.items()},
            "location_novelty": {str(k): rounded(v) for k, v in self.location_novelty.items()},
            "visited_maps": {str(k): int(v) for k, v in self.visited_maps.items()}
        }

        model_data = {
            "timestep": int(self.timestep),
            "perceptrons": {
                "actions": actions_data,
                "entities": entities_data
            },
            "debt_tracking": debt_tracking,
            "control_mode": self.control_mode,
            "markov_stats": {
                "markov_action_count": int(self.markov_action_count),
                "curiosity_action_count": int(self.curiosity_action_count)
            }
        }

        try:
            with open(target, 'w') as f:
                json.dump(model_data, f, indent=2)
            print(f"üíæ Model saved: step {self.timestep}, markov={self.markov_action_count}, curiosity={self.curiosity_action_count}")
        except Exception as e:
            print(f"‚ùå Save error: {e}")

    def load_taught_model(self, filepath):
        """Restore perceptrons, debt tracking and Markov counters from a JSON checkpoint.

        Args:
            filepath: path of a checkpoint in the format written by
                save_model_checkpoint.

        Returns:
            The timestep stored in the checkpoint, or 0 when the file does not
            exist or loading fails (the error is printed, not raised).

        Only perceptrons that already exist on this object are updated, matched
        by action name or entity type. Tuple-like location keys such as
        "(1, 2, 3)" are parsed with ast.literal_eval, which accepts Python
        literals only, so a tampered checkpoint cannot execute code. Keys or
        values that cannot be parsed are skipped and counted.
        """
        import ast  # function-scope import: the notebook's import cell lies outside this block

        if not Path(filepath).exists():
            return 0

        def dense_weights(record):
            """Expand a record's sparse [index, value] pairs into a dense vector."""
            size = record.get("weights_shape", 1376)
            dense = np.zeros(size)
            for idx, val in record.get("weights_nonzero", []):
                # Negative indices are rejected too: they would silently wrap
                # around and overwrite weights at the end of the vector.
                if 0 <= idx < size:
                    dense[idx] = val
            return dense

        try:
            with open(filepath, 'r') as f:
                model_data = json.load(f)

            self.timestep = model_data.get("timestep", 0)
            self.control_mode = model_data.get("control_mode", "move")

            perceptron_data = model_data.get("perceptrons", {})

            # Load action perceptrons
            for a_data in perceptron_data.get("actions", []):
                # .get() keeps one malformed record from aborting the whole load,
                # matching how entity records are looked up below.
                wanted_action = a_data.get("action")
                for a in self.actions():
                    if a.action == wanted_action:
                        a.utility = a_data.get("utility", 1.0)
                        a.learning_rate = a_data.get("learning_rate", 0.01)
                        a.familiarity = a_data.get("familiarity", 0.0)
                        a.weights = dense_weights(a_data)
                        break

            # Load entity perceptrons
            for e_data in perceptron_data.get("entities", []):
                wanted_type = e_data.get("entity_type")
                for e in self.entities():
                    if e.entity_type == wanted_type:
                        e.utility = e_data.get("utility", 1.0)
                        e.familiarity = e_data.get("familiarity", 0.0)
                        e.weights = dense_weights(e_data)
                        break

            # Load debt tracking
            debt_data = model_data.get("debt_tracking", {})
            self.map_novelty_debt = {int(k): float(v) for k, v in debt_data.get("map_novelty_debt", {}).items()}
            self.visited_maps = {int(k): int(v) for k, v in debt_data.get("visited_maps", {}).items()}

            # Load location novelty; keys were written with str(), so tuple keys
            # arrive as "(...)" strings and are parsed back as literals.
            skipped = 0
            for k, v in debt_data.get("location_novelty", {}).items():
                try:
                    key = ast.literal_eval(k) if k.startswith('(') else k
                    self.location_novelty[key] = float(v)
                except (ValueError, SyntaxError, TypeError):
                    # Best effort: an unparseable key or value is dropped, not fatal.
                    skipped += 1
            if skipped:
                print(f"  Skipped {skipped} unparseable location_novelty entries")

            # Load markov stats
            markov_stats = model_data.get("markov_stats", {})
            self.markov_action_count = markov_stats.get("markov_action_count", 0)
            self.curiosity_action_count = markov_stats.get("curiosity_action_count", 0)

            print(f"‚úÖ Model loaded: step {self.timestep}")
            return self.timestep

        except Exception as e:
            print(f"‚ùå Load error: {e}")
            return 0

    # Alias for backwards compatibility
    def save_model(self, filepath=None):
        """Backward-compatible alias: forwards to save_model_checkpoint and returns None."""
        self.save_model_checkpoint(filepath)
        return None
    
    def load_model(self, filepath=None):
        """Backward-compatible alias for load_taught_model; defaults to MODEL_FILE.

        Returns whatever load_taught_model returns.
        """
        return self.load_taught_model(MODEL_FILE if filepath is None else filepath)

In [None]:
# ============================================================================
# CELL 6: Main Loop - AI Control with Markov + Curiosity
# ============================================================================
# Matches colleague's version with:
# - Markov imitation learning
# - Curiosity-driven exploration
# - Proper logging and stats
# ============================================================================

import random

brain = Brain()

# Register one action perceptron per button: the four movement keys first,
# then the four interaction keys (registration order is preserved).
for group_name, button_names in (
    ("move", ["UP", "DOWN", "LEFT", "RIGHT"]),
    ("interact", ["A", "B", "Start", "Select"]),
):
    for b in button_names:
        brain.add(Perceptron("action", action=b, group=group_name))

# Transfer learning: restore a saved checkpoint and merge the taught
# exploration memory when a checkpoint file is present.
TAUGHT_MODEL_PATH = BASE_PATH / "model_checkpoint.json"
TAUGHT_EXPLORATION_PATH = BASE_PATH / "taught_exploration_memory.json"

if not TAUGHT_MODEL_PATH.exists():
    print("No taught model found - starting fresh")
else:
    loaded_ts = brain.load_taught_model(TAUGHT_MODEL_PATH)
    print(f"üéì TRANSFER: Loaded model from timestep {loaded_ts}")
    print(f"   Utilities: {[f'{a.action}:{a.utility:.3f}' for a in brain.actions()]}")
    brain.merge_taught_exploration(TAUGHT_EXPLORATION_PATH)

# Taught frames feed the Markov imitation lookup.
brain.load_taught_transitions(TAUGHT_TRANSITIONS_FILE)

# Tunables read by the control loop and by select_action.
exploration_weight = 1.3
forced_explore_prob = 0.18
prev_context_state = None
prev_raw_position = None

# Startup report: Markov settings, curiosity settings, persistence location.
rule = "=" * 70

print(rule)
print("AI CONTROL - Hybrid Markov + Curiosity")
print(rule)
print("MARKOV SYSTEM:")
print(f"  - Familiarity threshold: {MARKOV_FAMILIARITY_THRESHOLD}")
print(f"  - Immediate weight: {MARKOV_IMMEDIATE_WEIGHT}")
print(f"  - Sequential weight: {MARKOV_SEQUENTIAL_WEIGHT}")
print(f"  - Partial weight: {MARKOV_PARTIAL_WEIGHT}")
print(f"  - Taught batches: {len(brain.taught_batches)}")
print(f"  - Taught frames: {len(brain.taught_transitions)}")
taught_meta = brain.taught_metadata
if taught_meta:
    print(f"  - Action changes recorded: {taught_meta.get('action_changes', 0)}")
    print(f"  - Maps in teaching: {taught_meta.get('maps_visited', [])}")
print(rule)
print("CURIOSITY SYSTEM:")
print(f"  - Forced random exploration: {forced_explore_prob:.0%}")
print(f"  - 'Both' mode threshold: stagnation > {brain.BOTH_MODE_STAGNATION_THRESHOLD}")
print(f"  - Unvisited tile bonus: {brain.UNVISITED_TILE_BONUS}x")
print(f"  - Obstruction penalty: {brain.OBSTRUCTION_PENALTY}x")
print(rule)
print(f"PERSISTENT MEMORY: {brain.EXPLORATION_MEMORY_FILE}")
print(rule)

def select_action(brain, learning_state, context_state, raw_position):
    """
    Hybrid action selection: Markov first, then curiosity.

    The Markov suggestion from ``brain.get_markov_action`` is used when its
    score reaches ``MARKOV_FAMILIARITY_THRESHOLD`` and a perceptron with that
    action name exists. Otherwise the choice falls back to curiosity: with
    probability ``forced_explore_prob`` a uniformly random action, else the
    action with the highest jittered utility after tile bonuses/penalties.

    Args:
        brain: agent object providing ``get_markov_action``, ``actions``,
            ``get_current_map_memory``, the ``markov_action_count`` /
            ``curiosity_action_count`` counters and the ``ACTION_DELTAS``,
            ``UNVISITED_TILE_BONUS`` and ``OBSTRUCTION_PENALTY`` attributes.
        learning_state: kept for interface compatibility; not read here.
        context_state: indexable context vector; index 2 is read as the map id.
        raw_position: ``(x, y)`` pair of the current tile position.

    Returns:
        The chosen action perceptron (an element of ``brain.actions()``), or
        ``None`` when ``brain.actions()`` is empty.
    """
    # Try Markov first
    markov_action, markov_score = brain.get_markov_action(context_state, raw_position)
    candidates = brain.actions()

    if markov_action and markov_score >= MARKOV_FAMILIARITY_THRESHOLD:
        # Find the perceptron for this action. The Markov counter is bumped only
        # once a match is found, so an unmatched suggestion is not counted as
        # both a Markov and a curiosity decision.
        for a in candidates:
            if a.action == markov_action:
                brain.markov_action_count += 1
                return a

    # Nothing to choose from: report "no action" instead of raising IndexError.
    if not candidates:
        return None

    # Fall back to curiosity/exploration
    brain.curiosity_action_count += 1

    # Forced random exploration
    if random.random() < forced_explore_prob:
        return random.choice(candidates)

    # Simple utility-based selection with exploration bonuses
    current_map = int(context_state[2])
    raw_x, raw_y = raw_position
    memory = brain.get_current_map_memory(current_map)
    visited_tiles = memory['visited_tiles']
    obstructions = memory['obstructions']

    best_action = None
    best_score = -float('inf')

    for a in candidates:
        score = a.utility

        # Movement bonuses
        # NOTE(review): the multipliers assume non-negative utilities; with a
        # negative utility the "bonus" lowers the score and the "penalty"
        # raises it - confirm utilities are kept >= 0.
        if a.group == "move":
            dx, dy = brain.ACTION_DELTAS.get(a.action, (0, 0))
            target = (raw_x + dx, raw_y + dy)

            if target not in visited_tiles:
                score *= brain.UNVISITED_TILE_BONUS
            if target in obstructions:
                score *= brain.OBSTRUCTION_PENALTY

        # Add randomness (uniform factor in [0.9, 1.1) to break ties)
        score *= (0.9 + random.random() * 0.2)

        if score > best_score:
            best_score = score
            best_action = a

    # best_action stays None only if no score compared greater than -inf
    # (e.g. NaN utilities); fall back to the first action in that case.
    return best_action if best_action else candidates[0]

# Main control loop. Each pass reads the game state, possibly issues an action,
# prints periodic status, then re-reads the state and lets the brain learn from
# the observed transition. The loop has no exit condition of its own, so it
# runs until the kernel is interrupted; the KeyboardInterrupt handler at the
# bottom persists progress before the cell ends.
try:
    while True:
        # Read state
        context_state, palette_state, tile_state, dead, raw_position = read_game_state()

        raw_x, raw_y = raw_position
        in_battle = context_state[3]
        current_map = int(context_state[2])
        current_dir = int(context_state[5])

        brain.update_position(raw_x, raw_y)

        derived = compute_derived_features(context_state, prev_context_state)
        learning_state = build_learning_state(derived, palette_state, tile_state, in_battle)

        brain.log_state(learning_state, context_state)

        # Action execution confirmation
        brain.confirm_action_executed(context_state, prev_context_state)

        if brain.should_send_new_action():
            action = select_action(brain, learning_state, context_state, raw_position)

            if action is not None:
                write_action(action.action)
                brain.last_action = action.action
                brain.set_pending_action(action.action)
                brain.record_action_execution(action.action)
                brain.update_menu_trap_tracking(context_state, action.action, raw_position=raw_position)
            else:
                write_action("NONE")
        else:
            # No new decision this pass: keep re-sending the pending action, if any.
            if brain.pending_action:
                write_action(brain.pending_action)

        # === LOGGING ===
        # Detailed status report every 100 timesteps.
        if brain.timestep % 100 == 0:
            memory = brain.get_current_map_memory(current_map)
            visited_count = len(memory['visited_tiles'])
            obs_count = len(memory['obstructions'])
            interactables = len(memory['interactable_objects'])
            coverage = brain.get_exploration_coverage(current_map)
            transitions = memory.get('transitions', [])
            tile_stats = brain.get_tile_interaction_stats(current_map)

            tile_needs_probing = brain.should_interact_at_tile(raw_x, raw_y, current_map)
            probe_action, probe_dir = brain.get_best_probe_action(raw_x, raw_y, current_map, current_dir)

            dir_name = brain.DIRECTION_NAMES.get(current_dir, '?')
            mode = brain.control_mode

            is_both_mode = brain.should_use_both_mode()
            mode_display = "BOTH ⚡" if is_both_mode else mode

            # Markov stats
            total_actions = brain.markov_action_count + brain.curiosity_action_count
            markov_ratio = brain.markov_action_count / max(1, total_actions)

            print(f"\n{'='*70}")
            print(f"Step {brain.timestep} | Map {current_map} | Pos ({raw_x}, {raw_y}) facing {dir_name}")
            print(f"  Mode: {mode_display} | Battle: {int(in_battle)} | Stagnation: {brain.state_stagnation_count}")

            # Markov status
            print(f"\n  🧠 DECISION MODE:")
            print(f"     Markov: {brain.markov_action_count} ({markov_ratio:.1%}) | Curiosity: {brain.curiosity_action_count} ({1-markov_ratio:.1%})")
            print(f"     Last Markov score: {brain.last_markov_score:.3f} (threshold: {MARKOV_FAMILIARITY_THRESHOLD})")
            if brain.last_markov_action:
                print(f"     Last Markov suggestion: {brain.last_markov_action}")

            # Exploration status
            print(f"\n  📊 EXPLORATION:")
            print(f"     Visited: {visited_count} | Obstructions: {obs_count} | Coverage: {coverage:.0%}")
            print(f"     Interactables found: {interactables}")

            # Tile probing
            print(f"\n  🎯 TILE PROBING:")
            print(f"     Tiles probed: {tile_stats['probed']} | Exhausted: {tile_stats['exhausted']} | With success: {tile_stats['with_success']}")

            if tile_needs_probing:
                if probe_action == 'A':
                    print(f"     Current tile: READY TO PROBE (facing untried direction)")
                elif probe_action:
                    print(f"     Current tile: NEED TO TURN {probe_action} first")
                else:
                    print(f"     Current tile: NEEDS PROBING (checking directions)")
            else:
                print(f"     Current tile: EXHAUSTED or fully probed")

            # Transitions (only the first three are listed)
            if transitions:
                print(f"\n  🚪 TRANSITIONS: {len(transitions)} known")
                for t in transitions[:3]:
                    # Positions may be stored as lists; normalise to a tuple before the ban lookup.
                    pos = tuple(t['position']) if isinstance(t['position'], list) else t['position']
                    banned = "🚫" if brain.is_transition_banned(current_map, pos, t['direction']) else ""
                    print(f"     ({pos[0]},{pos[1]}) → Map {t['destination_map']} (used {t['use_count']}x) {banned}")

            # Debt info
            map_debt = brain.map_novelty_debt.get(current_map, 0.0)
            temp_debt = brain.get_temp_debt(current_map)

            if map_debt > 0.1 or temp_debt > 0.1:
                print(f"\n  💳 DEBT: map={map_debt:.2f}/{brain.MAX_MAP_DEBT}, temp={temp_debt:.2f}")

            # Menu trap status
            if brain.menu_trap_b_boost > 1.0:
                print(f"\n  🔒 MENU TRAP: B boost {brain.menu_trap_b_boost:.2f}x ({brain.menu_trap_frames} frames)")

            # Pending action
            if brain.pending_action:
                print(f"\n  ⏳ Pending: {brain.pending_action} ({brain.pending_action_frames}/{brain.ACTION_CONFIRM_FRAMES})")

            # Utilities, highest first
            action_utils = sorted([(a.action, a.utility) for a in brain.actions()], key=lambda x: x[1], reverse=True)
            print(f"\n  ⚡ Utilities: {' '.join([f'{k}:{v:.2f}' for k,v in action_utils])}")

        # === MILESTONES ===
        # Every 500 timesteps: print totals across all maps and persist to disk.
        if brain.timestep % 500 == 0 and brain.timestep > 0:
            total_visited = sum(len(m['visited_tiles']) for m in brain.exploration_memory.values())
            total_obs = sum(len(m['obstructions']) for m in brain.exploration_memory.values())
            total_interactables = sum(len(m['interactable_objects']) for m in brain.exploration_memory.values())
            total_transitions = sum(len(m.get('transitions', [])) for m in brain.exploration_memory.values())
            total_probed = sum(len(m.get('tile_interactions', {})) for m in brain.exploration_memory.values())
            total_exhausted = sum(
                sum(1 for t in m.get('tile_interactions', {}).values() if t.get('exhausted', False))
                for m in brain.exploration_memory.values()
            )

            total_actions = brain.markov_action_count + brain.curiosity_action_count
            markov_ratio = brain.markov_action_count / max(1, total_actions)

            print(f"\n{'#'*70}")
            print(f"# MILESTONE {brain.timestep}")
            print(f"# Maps explored: {len(brain.exploration_memory)}")
            print(f"# Tiles visited: {total_visited} | Obstructions: {total_obs}")
            print(f"# Interactables: {total_interactables} | Transitions: {total_transitions}")
            print(f"# Tiles probed: {total_probed} | Exhausted: {total_exhausted}")
            print(f"#")
            print(f"# HYBRID DECISION STATS:")
            print(f"#   Markov (imitation): {brain.markov_action_count} ({markov_ratio:.1%})")
            print(f"#   Curiosity (explore): {brain.curiosity_action_count} ({1-markov_ratio:.1%})")
            print(f"#   Taught transitions: {len(brain.taught_transitions)}")
            print(f"{'#'*70}")

            brain.save_model_checkpoint(BASE_PATH / "model_checkpoint.json")
            brain.save_exploration_memory()

        # Short pause before re-reading the state so the written action can take effect.
        time.sleep(0.02)

        # Learn
        # Note: `dead` is rebound here, so learn() receives the flag from this second read.
        next_context, next_palette, next_tiles, dead, next_raw_position = read_game_state()
        next_in_battle = next_context[3]
        next_derived = compute_derived_features(next_context, context_state)
        next_learning_state = build_learning_state(next_derived, next_palette, next_tiles, next_in_battle)

        brain.learn(learning_state, next_learning_state, context_state, next_context, dead=dead,
                    raw_position=raw_position, next_raw_position=next_raw_position)

        prev_context_state = context_state.copy()
        prev_raw_position = raw_position
        brain.timestep += 1
except KeyboardInterrupt:
    # Interrupting the kernel is the only way out of the loop. Persist whatever
    # was learned since the last 500-step milestone instead of discarding it,
    # and end the cell without a traceback.
    brain.save_model_checkpoint(BASE_PATH / "model_checkpoint.json")
    brain.save_exploration_memory()
    print(f"\nInterrupted at step {brain.timestep} - checkpoint and exploration memory saved")

  Loaded exploration memory: 13 maps
✅ Model loaded: step 233500
🎓 TRANSFER: Loaded model from timestep 233500
   Utilities: ['UP:0.100', 'DOWN:0.100', 'LEFT:0.100', 'RIGHT:0.100', 'A:0.150', 'B:0.150', 'Start:0.150', 'Select:0.150']
  Merged taught exploration: 13 maps
  ✅ Loaded taught transitions: 5605 batches, 39874 frames
AI CONTROL - Hybrid Markov + Curiosity
MARKOV SYSTEM:
  - Familiarity threshold: 0.6
  - Immediate weight: 0.5
  - Sequential weight: 0.3
  - Partial weight: 0.2
  - Taught batches: 5605
  - Taught frames: 39874
  - Action changes recorded: 3012
  - Maps in teaching: [248, 32, 33, 35, 36, 223, 38, 40, 10, 11, 43, 13, 14, 15, 16, 17, 50, 81, 20, 21, 22, 23, 24, 25, 26, 19, 172, 91, 232, 12]
CURIOSITY SYSTEM:
  - Forced random exploration: 18%
  - 'Both' mode threshold: stagnation > 35
  - Unvisited tile bonus: 1.5x
  - Obstruction penalty: 0.25x
PERSISTENT MEMORY: C:\Users\natmaw\Documents\Boston Stuff\CS 5100 Foundations of AI\PokeAI\taught_exploration_mem