In [32]:
# ============================================================================
# CELL 1: State Management & Utilities (FIXED - Consistent State Size)
# ============================================================================

from pathlib import Path
import json
import numpy as np
import time
from collections import deque

BASE_PATH = Path("C:/Users/natmaw/Documents/Boston Stuff/CS 5100 Foundations of AI/PokeAI")

# Exchange files living next to each other under BASE_PATH.
ACTION_FILE, STATE_FILE, INPUT_FILE, MODEL_FILE = (
    BASE_PATH / file_name
    for file_name in ("action.json", "game_state.json", "input_cache.txt", "model_checkpoint.json")
)

# Vector sizes.
EXPECTED_STATE_DIM = 6   # raw context entries (see normalize_game_state)
PALETTE_DIM = 768
TILE_DIM = 600

# Learning state is always the same length regardless of mode:
# 8 derived features + TILE_DIM tile values + PALETTE_DIM palette values = 1376.
LEARNING_STATE_DIM = 8 + TILE_DIM + PALETTE_DIM

# Single-letter action codes (as written to the input cache) -> action names.
ACTION_MAP = dict(zip(
    "UDLRABSE",
    ("UP", "DOWN", "LEFT", "RIGHT", "A", "B", "Start", "Select"),
))

# Last-seen modification times, used to skip re-reading unchanged files.
last_state_mod_time = last_input_mod_time = 0

def normalize_game_state(raw_state):
    """Normalize a raw context vector for learning.

    Inputs shorter than 6 are zero-padded first. Entries are transformed as:
    [0], [1] divided by 255; [2] clipped to [0, 255]; [3], [4] turned into
    0/1 flags (1.0 when > 0); [5] reduced modulo 4 after int conversion.
    Any entries beyond index 5 are copied through unchanged as floats.

    Returns:
        np.ndarray of dtype float with max(6, len(raw_state)) entries.
    """
    needed = 6 - len(raw_state)
    padded = raw_state if needed <= 0 else list(raw_state) + [0] * needed

    out = np.array(padded, dtype=float)
    x, y, map_id, battle, menu, facing = padded[:6]

    out[0] = x / 255.0
    out[1] = y / 255.0
    out[2] = np.clip(map_id, 0, 255)
    out[3] = 1.0 if battle > 0 else 0.0
    out[4] = 1.0 if menu > 0 else 0.0
    out[5] = int(facing) % 4
    return out

def compute_derived_features(current, prev):
    """Extract 8 temporal features from two consecutive context vectors.

    Returns zeros when there is no previous vector. Otherwise the features are:
    [x delta, y delta, map changed, battle started, battle ended,
     menu opened, menu closed, direction changed] — the last six as 0/1 floats.
    Indices follow the context layout used by normalize_game_state
    (0=x, 1=y, 2=map, 3=battle, 4=menu, 5=direction).
    """
    if prev is None:
        return np.zeros(8)

    def flag(condition):
        return 1.0 if condition else 0.0

    return np.array([
        current[0] - prev[0],
        current[1] - prev[1],
        flag(abs(current[2] - prev[2]) > 0.5),
        flag(current[3] > prev[3]),
        flag(current[3] < prev[3]),
        flag(current[4] > prev[4]),
        flag(current[4] < prev[4]),
        flag(current[5] != prev[5]),
    ])

def build_learning_state(derived, palette, tiles, in_battle):
    """Assemble the fixed-size (8 + TILE_DIM + PALETTE_DIM = 1376) learning vector.

    Layout: [derived(8)] + [tiles(TILE_DIM)] + [palette(PALETTE_DIM)].
    Any component whose length is wrong is replaced by zeros (not padded).
    When ``in_battle`` > 0.5 the tile slots are zeroed but kept, so the shape
    never changes. Small Gaussian noise (std 1e-4) is added to every entry.
    """
    def zeros_if_wrong_size(vec, size):
        return vec if len(vec) == size else np.zeros(size)

    derived = zeros_if_wrong_size(derived, 8)
    tiles = zeros_if_wrong_size(tiles, TILE_DIM)
    palette = zeros_if_wrong_size(palette, PALETTE_DIM)

    # Battle tiles are UI only; blank them while preserving the slot layout.
    if in_battle > 0.5:
        tiles = np.zeros(TILE_DIM)

    combined = np.concatenate([derived, tiles, palette])
    jitter = np.random.randn(len(combined)) * 0.0001
    return combined + jitter

def read_input_cache():
    """Read and parse new entries from the input cache file.

    The file is re-read only when its modification time differs from the one
    recorded on the last successful read, so repeated calls without a new
    write return ``[]``.

    Each non-empty line is comma-separated:
    ``action_code,x,y,map,in_battle,menu_flag,direction`` (extra fields are
    ignored). Action codes are expanded through ACTION_MAP; unknown codes are
    passed through unchanged. Lines with fewer than 7 fields are ignored, and
    lines whose numeric fields do not parse are skipped (with one warning per
    call) instead of discarding the whole file.

    Returns:
        list[dict]: one dict per parsed line with keys 'action', 'x', 'y',
        'map', 'in_battle', 'menu_flag', 'direction'. Empty if the file is
        missing, unchanged, or unreadable.
    """
    global last_input_mod_time

    try:
        current_mod_time = INPUT_FILE.stat().st_mtime
    except OSError:
        # Missing file, or a transient access error while it is being written.
        return []
    if current_mod_time == last_input_mod_time:
        return []

    try:
        with open(INPUT_FILE, 'r') as f:
            lines = f.readlines()
    except OSError as e:
        print(f"[WARN] Error reading input cache: {e}")
        return []

    # Mark this version as consumed only after it was actually read, so a
    # transient read failure is retried on the next call instead of lost.
    last_input_mod_time = current_mod_time

    inputs = []
    skipped = 0
    for line in lines:
        parts = [part.strip() for part in line.strip().split(',')]
        if len(parts) < 7:
            continue
        try:
            x, y, map_id, in_battle, menu_flag, direction = (int(p) for p in parts[1:7])
        except ValueError:
            skipped += 1
            continue
        action_code = parts[0]
        inputs.append({
            'action': ACTION_MAP.get(action_code, action_code),
            'x': x,
            'y': y,
            'map': map_id,
            'in_battle': in_battle,
            'menu_flag': menu_flag,
            'direction': direction
        })

    if skipped:
        print(f"[WARN] Skipped {skipped} malformed line(s) in input cache")

    return inputs

def read_game_state_minimal():
    """Read the current context state from STATE_FILE (minimal version).

    Reads key ``'s'`` (raw context values, normalized via
    normalize_game_state) and ``'ic'`` (input count, default 0) from the JSON
    state file. A missing, null, or short ``'s'`` is zero-padded to
    EXPECTED_STATE_DIM rather than causing the whole read to be discarded.

    Returns:
        tuple: (context_state, raw_position, input_count), where raw_position
        is the first two raw context values as ints. If the file is missing or
        cannot be parsed, returns (zeros(EXPECTED_STATE_DIM), (0, 0), 0).
    """
    def _fallback():
        return np.zeros(EXPECTED_STATE_DIM), (0, 0), 0

    if not STATE_FILE.exists():
        return _fallback()

    try:
        with open(STATE_FILE, 'r') as f:
            data = json.load(f)

        raw = list(data.get('s') or [])
        raw += [0] * (EXPECTED_STATE_DIM - len(raw))
        input_count = data.get('ic', 0)

        context_state = normalize_game_state(raw)
        raw_position = (int(raw[0]), int(raw[1]))

        return context_state, raw_position, input_count

    except Exception:
        # Best-effort, frequently polled read: the file may be mid-write by the
        # producer, so parse errors fall back silently to a zero state.
        return _fallback()

def read_game_state_full():
    """Read the full game state (context vector plus visual features).

    Parses STATE_FILE as JSON and reads:
      - ``'s'``: raw context values, zero-padded to EXPECTED_STATE_DIM if short
        or missing, then normalized via normalize_game_state;
      - ``'p'``: palette values, flattened and zero-padded/truncated to PALETTE_DIM;
      - ``'t'``: tile values, flattened and zero-padded/truncated to TILE_DIM.

    Returns:
        tuple: (context_state, palette_state, tile_state, raw_position), where
        raw_position is the first two raw context values as ints. If the file is
        missing, zero vectors and (0, 0) are returned; on a parse error the same
        fallback is returned after printing a warning.
    """
    def _fit(values, dim):
        # Flatten first: a nested list would otherwise give a 2-D array, which
        # np.pad would pad along both axes. Then force exactly `dim` floats.
        arr = np.asarray(values or [], dtype=float).ravel()
        if arr.size < dim:
            return np.pad(arr, (0, dim - arr.size))
        return arr[:dim]

    def _empty():
        return (np.zeros(EXPECTED_STATE_DIM), np.zeros(PALETTE_DIM),
                np.zeros(TILE_DIM), (0, 0))

    if not STATE_FILE.exists():
        return _empty()

    try:
        with open(STATE_FILE, 'r') as f:
            data = json.load(f)

        # Pad a short/missing context so raw[1] cannot raise IndexError and
        # throw away otherwise valid palette/tile data.
        raw = list(data.get('s') or [])
        raw += [0] * (EXPECTED_STATE_DIM - len(raw))

        context_state = normalize_game_state(raw)
        raw_position = (int(raw[0]), int(raw[1]))
        palette_state = _fit(data.get('p'), PALETTE_DIM)
        tile_state = _fit(data.get('t'), TILE_DIM)

        return context_state, palette_state, tile_state, raw_position

    except Exception as e:
        print(f"[WARN] Error reading full state: {e}")
        return _empty()

def process_cached_input(inp):
    """Turn one parsed input-cache entry into (context_state, raw_position, action).

    Missing numeric fields default to 0 and a missing action to None. The
    context vector is built in the order x, y, map, in_battle, menu_flag,
    direction and normalized with normalize_game_state.
    """
    x = inp.get('x', 0)
    y = inp.get('y', 0)
    rest = [inp.get(field, 0) for field in ('map', 'in_battle', 'menu_flag', 'direction')]

    context_state = normalize_game_state([x, y, *rest])
    return context_state, (x, y), inp.get('action', None)

# Legacy compatibility
def read_game_state(max_retries=3):
    """Legacy entry point kept for older callers.

    Delegates to read_game_state_full() and returns a 6-tuple
    (context, palette, tiles, False, raw_position, None). ``max_retries`` is
    accepted for signature compatibility but is not used.
    """
    context_state, palette_state, tile_state, raw_position = read_game_state_full()
    return (context_state, palette_state, tile_state, False, raw_position, None)

def write_action(action_name):
    """Write ``{"action": NAME}`` to ACTION_FILE, upper-casing truthy names.

    Falsy names (e.g. None or "") are written unchanged. Write failures are
    printed and not raised.
    """
    normalized = action_name.upper() if action_name else action_name
    try:
        with open(ACTION_FILE, "w") as handle:
            json.dump({"action": normalized}, handle)
            handle.flush()
    except Exception as e:
        print(f"[ERROR] Failed to write action: {e}")

In [33]:
# ============================================================================
# CELL 2: Perceptron Classes (FIXED - Shape Safety)
# ============================================================================

class Perceptron:
    """Linear unit with fast/slow eligibility traces.

    ``kind`` selects the extra behaviour:
      - ``"action"``: carries a ``utility`` that ``update`` nudges up or down
        and clips (floor 0.1 for group "move", 0.01 otherwise; cap 2.0);
      - ``"entity"``: damps its output by a familiarity-based novelty factor
        and records activation / prediction-error history.
    Weights are created lazily, sized to the first state seen.
    """

    def __init__(self, kind, action=None, group=None, entity_type=None):
        self.kind = kind
        self.action = action
        self.group = group
        self.entity_type = entity_type

        self.utility = 1.0
        self.weights = None

        self.eligibility_fast = 0.0
        self.eligibility_slow = 0.0

        self.familiarity = 0.0
        self.activation_history = deque(maxlen=10)

        self.learning_rate = 0.01
        self.prediction_errors = deque(maxlen=50)

    def ensure_weights(self, dim):
        """Initialize or resize weights to length ``dim``.

        New weights are small random values (std 0.001). When resizing, the
        overlapping prefix of the old weights is preserved.
        """
        if self.weights is None:
            self.weights = np.random.randn(dim) * 0.001
        elif len(self.weights) != dim:
            old_weights = self.weights
            self.weights = np.random.randn(dim) * 0.001
            # Copy over what we can from old weights
            min_len = min(len(old_weights), dim)
            self.weights[:min_len] = old_weights[:min_len]

    def predict(self, state):
        """Return the dot product of weights and ``state``.

        For entity perceptrons the result is scaled by
        1 / (1 + sqrt(0.5 * familiarity)) and the undamped magnitude is appended
        to ``activation_history``.
        """
        self.ensure_weights(len(state))
        raw_activation = np.dot(self.weights, state)

        if self.kind == "entity":
            novelty_factor = 1.0 / (1.0 + np.sqrt(self.familiarity * 0.5))
            decayed_activation = raw_activation * novelty_factor
            self.activation_history.append(abs(raw_activation))
            return decayed_activation
        else:
            return raw_activation

    def adapt_learning_rate(self):
        """Once 50 prediction errors are recorded, shrink the learning rate when
        the mean error is < 0.1 (floor 0.001) or grow it when > 0.5 (cap 0.05)."""
        if len(self.prediction_errors) >= 50:
            avg_error = np.mean(self.prediction_errors)

            if avg_error < 0.1:
                self.learning_rate = max(0.001, self.learning_rate * 0.99)
            elif avg_error > 0.5:
                self.learning_rate = min(0.05, self.learning_rate * 1.01)

    def update(self, state, error, gamma_fast=0.5, gamma_slow=0.95, stagnation=0.0):
        """Apply one learning step toward reducing ``error`` on ``state``.

        The weight change is a 0.7/0.3 blend of fast- and slow-trace updates.
        ``state`` may be any 1-D sequence; it is converted to a float array so
        list inputs no longer fail with ``float * list``. Action perceptrons
        also adjust ``utility``; entity perceptrons update familiarity and their
        prediction-error history.
        """
        state = np.asarray(state, dtype=float)
        # ensure_weights always leaves len(weights) == len(state).
        self.ensure_weights(len(state))

        self.eligibility_fast = gamma_fast * self.eligibility_fast + 1.0
        self.eligibility_slow = gamma_slow * self.eligibility_slow + 1.0

        self.adapt_learning_rate()

        fast_update = 0.7 * self.learning_rate * error * state * self.eligibility_fast
        slow_update = 0.3 * self.learning_rate * error * state * self.eligibility_slow
        self.weights += fast_update + slow_update

        if self.kind == "action":
            if error > 0.01:
                if stagnation > 0.5:
                    self.utility *= 0.97
                elif error > 0.2:
                    self.utility = min(self.utility * 1.02, 2.0)
                else:
                    self.utility *= 0.995

            if self.group == "move":
                self.utility = np.clip(self.utility, 0.1, 2.0)
            else:
                self.utility = np.clip(self.utility, 0.01, 2.0)

        if self.kind == "entity" and len(self.activation_history) > 0:
            recent_avg = np.mean(self.activation_history)
            if recent_avg > 0.1:
                self.familiarity += 0.03

        if self.kind == "entity":
            # NOTE: predict() also appends to activation_history as a side effect.
            prediction = self.predict(state)
            self.prediction_errors.append(abs(prediction - error))


class ControlSwapPerceptron(Perceptron):
    """Perceptron that decides whether to swap control mode.

    Keeps a rolling history (last 100) of (swapped, novelty_gained) outcomes
    and a ``confidence`` equal to the fraction of the last 20 outcomes that
    were swaps with novelty_gained > 0.2.
    """

    def __init__(self):
        super().__init__(kind="control_swap")
        self.swap_history = deque(maxlen=100)
        self.confidence = 0.0

    def should_swap(self, state, movement_stagnation):
        """Return (decision, strength) for swapping control mode.

        Score = 0.7 * (weights . state) + 0.3 * tanh(movement_stagnation / 5).
        The decision is score > 0.5 and the strength is |score|. Before any
        weights exist, returns (False, 0.0).
        """
        if self.weights is None:
            return False, 0.0

        self.ensure_weights(len(state))
        learned_score = np.dot(self.weights, state)
        stagnation_term = np.tanh(movement_stagnation / 5.0)
        score = learned_score * 0.7 + stagnation_term * 0.3
        return score > 0.5, abs(score)

    def record_swap_outcome(self, state, swapped, novelty_gained):
        """Log a swap outcome and refresh ``confidence`` once 20 outcomes exist.

        ``state`` is accepted for interface compatibility but is not used.
        """
        self.swap_history.append((swapped, novelty_gained))

        window = 20
        if len(self.swap_history) < window:
            return
        last_outcomes = list(self.swap_history)[-window:]
        wins = sum(1 for did_swap, gain in last_outcomes if did_swap and gain > 0.2)
        self.confidence = wins / 20.0

In [34]:
# ============================================================================
# CELL 3: Brain Class - COMPLETE MEMORY SAFE VERSION
# ============================================================================

import gc

class Brain:
    def __init__(self):
        """Initialise the agent's state.

        Creates the perceptron pool, bounded short-term histories, the tuning
        constants (mostly ALL_CAPS attributes read by other methods), a
        ControlSwapPerceptron for mode-swap decisions, and loads persisted
        per-map exploration memory from BASE_PATH / "exploration_memory.json"
        via load_exploration_memory().
        """
        self.perceptrons = []

        # REDUCED DEQUE SIZES
        # Bounded histories: the oldest entries drop off automatically (maxlen).
        self.prev_learning_states = deque(maxlen=10)
        self.prev_context_states = deque(maxlen=5)
        self.last_positions = deque(maxlen=15)
        self.action_history = deque(maxlen=30)

        self.control_mode = "move"
        self.timestep = 0
        self.last_action = None
        self.last_direction = 0

        # Utility minimums applied by enforce_utility_floors() and
        # apply_repetition_penalty().
        self.MOVE_UTILITY_FLOOR = 0.05
        self.INTERACT_UTILITY_FLOOR = 0.15

        # === EXPLORATION MEMORY ===
        self.EXPLORATION_MEMORY_FILE = BASE_PATH / "exploration_memory.json"
        self.exploration_memory = {}  # map_id -> per-map memory dict
        self.current_map_id = None
        self.SAVE_INTERVAL = 100
        self.MAX_MAPS_IN_MEMORY = 20  # cleanup_memory() prunes to half of this when exceeded

        # Direction mapping
        # Facing codes 0-3 correspond to DOWN/UP/LEFT/RIGHT.
        self.DIRECTION_NAMES = {0: "DOWN", 1: "UP", 2: "LEFT", 3: "RIGHT"}
        self.DIRECTION_TO_INT = {"DOWN": 0, "UP": 1, "LEFT": 2, "RIGHT": 3}
        self.INT_TO_ACTION = {0: "DOWN", 1: "UP", 2: "LEFT", 3: "RIGHT"}

        # (dx, dy) tile steps; y grows downward (DOWN is +1).
        self.DIRECTION_DELTAS_INT = {0: (0, 1), 1: (0, -1), 2: (-1, 0), 3: (1, 0)}
        self.ACTION_DELTAS = {"UP": (0, -1), "DOWN": (0, 1), "LEFT": (-1, 0), "RIGHT": (1, 0)}
        self.DELTA_TO_DIRECTION = {(0, 1): 0, (0, -1): 1, (-1, 0): 2, (1, 0): 3}

        # Must run after EXPLORATION_MEMORY_FILE and MAX_MAPS_IN_MEMORY are set,
        # because load_exploration_memory() reads both.
        self.load_exploration_memory()

        # === ACTION EXECUTION ===
        self.pending_action = None
        self.pending_action_frames = 0
        self.ACTION_CONFIRM_FRAMES = 3
        self.last_confirmed_action = None

        # === TILE INTERACTION ===
        self.INTERACTION_VERIFY_FRAMES = 8
        self.MIN_SUCCESS_RATE_THRESHOLD = 0.1  # per-direction success rate used by should_interact_at_tile()
        self.pending_interaction_verify = None
        self.interaction_verify_countdown = 0

        # === MENU TRAP ===
        self.menu_trap_frames = 0
        self.menu_trap_b_boost = 1.0
        self.menu_trap_position = None
        self.B_BOOST_INCREMENT = 0.15
        self.B_BOOST_MAX = 3.0
        self.MENU_TRAP_THRESHOLD = 5
        self.original_b_utility = None

        # === MODE SWAPPING ===
        self.DEFAULT_MOVE_TO_INTERACT_THRESHOLD = 15
        self.DEFAULT_INTERACT_TO_MOVE_THRESHOLD = 25
        self.move_to_interact_threshold = self.DEFAULT_MOVE_TO_INTERACT_THRESHOLD
        self.interact_to_move_threshold = self.DEFAULT_INTERACT_TO_MOVE_THRESHOLD
        self.THRESHOLD_INCREMENT = 15
        self.MAX_THRESHOLD = 150
        self.frames_in_current_mode = 0
        self.swap_chain_count = 0
        self.position_at_mode_swap = None
        self.last_map_id = None
        self.last_battle_state = None

        # === STAGNATION ===
        self.STATE_STAGNATION_THRESHOLD = 20
        self.state_stagnation_count = 0
        self.last_context_state_hash = None
        self.stagnation_initiator_action = None
        self.STAGNATION_INITIATOR_PENALTY = 0.7
        self.unproductive_swap_count = 0
        self.UNPRODUCTIVE_SWAP_THRESHOLD = 3

        # === BOTH MODE ===
        self.BOTH_MODE_STAGNATION_THRESHOLD = 35
        self.BOTH_MODE_SWAP_THRESHOLD = 5
        self.last_direction_for_progress = None
        self.direction_change_counts_as_progress = True

        # === NOVELTY ===
        self.UNVISITED_TILE_BONUS = 1.5
        self.OBSTRUCTION_PENALTY = 0.25

        # === TRANSITIONS ===
        # TEMP_DEBT_*: per-map temporary debt added by accumulate_temp_debt()
        # (capped at TEMP_DEBT_MAX) and reduced per step away in get_temp_debt().
        self.TRANSITION_ATTRACTION_WEIGHT = 0.6
        self.TEMP_DEBT_ACCUMULATION = 0.5
        self.TEMP_DEBT_DECAY = 0.02
        self.TEMP_DEBT_MAX = 15.0

        # === DEBT ===
        self.MAX_MAP_DEBT = 10.0
        self.MAX_LOCATION_DEBT = 5.0
        self.DEBT_DECAY_RATE = 0.005  # multiplicative decay applied by decay_all_debts()

        # === BANS ===
        self.transition_bans = {}
        self.BAN_VICINITY_RADIUS = 3
        self.BAN_COVERAGE_LIFT_THRESHOLD = 0.6
        self.BAN_TIMEOUT_STEPS = 300

        # Multi-scale memory - LIMITED
        self.visited_maps = {}
        self.map_novelty_debt = {}
        self.location_memory = {}
        self.location_novelty = {}
        self.action_execution_count = {}
        self.MAX_LOCATIONS = 500  # cleanup_memory() prunes location_memory to half of this

        self.swap_perceptron = ControlSwapPerceptron()

        # REDUCED ERROR HISTORY
        self.error_history = deque(maxlen=100)
        self.numeric_error_history = deque(maxlen=100)
        self.visual_error_history = deque(maxlen=100)

        # Cleared (and marked invalid) by cleanup_memory(); invalidated by add().
        self._entity_norms_cache = {}
        self._cache_valid = False
        self.innate_entities_spawned = False

        # === REPETITION ===
        # Thresholds used by get_learning_multiplier() and apply_repetition_penalty().
        self.consecutive_action_count = 0
        self.current_repeated_action = None
        self.LEARNING_SLOWDOWN_START = 3
        self.LEARNING_SLOWDOWN_MAX = 10
        self.PENALTY_THRESHOLD = 12
        self.HARD_RESET_THRESHOLD = 18

        # === PATTERN ===
        self.PATTERN_CHECK_WINDOW = 30
        self.PATTERN_MIN_REPEATS = 3
        self.PATTERN_MAX_LENGTH = 10
        self.detected_pattern = None
        self.pattern_repeat_count = 0

        # === PROBE CACHE ===
        self._cached_probe_action = None
        self._cached_probe_dir = None
        self._probe_cache_position = None

        # === TEACHING MODE ===
        self.teaching_mode = True
        self.demonstration_count = 0
        self.context_action_stats = {}
        self.MAX_CONTEXT_STATS = 50  # cleanup_memory() prunes to half of this when exceeded

    # =========================================================================
    # CORE METHODS
    # =========================================================================
    
    def add(self, p):
        self.perceptrons.append(p)
        self._cache_valid = False

    def actions(self):
        return [p for p in self.perceptrons if p.kind == "action"]

    def entities(self):
        return [p for p in self.perceptrons if p.kind == "entity"]

    def get_location_key(self, x, y, map_id, bin_size=5):
        return (int(map_id), int(x // bin_size) * bin_size, int(y // bin_size) * bin_size)

    def is_near_map_edge(self, x, y):
        return x < 10 or x > 245 or y < 10 or y > 245

    def record_action_execution(self, action_name):
        if action_name:
            self.action_execution_count[action_name] = self.action_execution_count.get(action_name, 0) + 1

    def get_position_stagnation(self):
        if len(self.last_positions) < 2:
            return 0
        current_pos = self.last_positions[-1]
        return sum(1 for pos in reversed(list(self.last_positions)[:-1]) if pos == current_pos)

    def get_group_weight(self, group):
        return sum(a.utility for a in self.actions() if a.group == group)

    def log_state(self, learning_state, context_state):
        self.prev_learning_states.append(learning_state)
        self.prev_context_states.append(context_state)

    def update_position(self, x, y):
        self.last_positions.append((int(x), int(y)))

    # =========================================================================
    # MEMORY MANAGEMENT
    # =========================================================================
    
    def cleanup_memory(self):
        if len(self.location_memory) > self.MAX_LOCATIONS:
            sorted_locs = sorted(self.location_memory.items(), key=lambda x: x[1], reverse=True)
            self.location_memory = dict(sorted_locs[:self.MAX_LOCATIONS // 2])
            self.location_novelty = {k: v for k, v in self.location_novelty.items() if k in self.location_memory}
        
        if len(self.context_action_stats) > self.MAX_CONTEXT_STATS:
            keys = list(self.context_action_stats.keys())
            for k in keys[:-self.MAX_CONTEXT_STATS // 2]:
                del self.context_action_stats[k]
        
        if len(self.exploration_memory) > self.MAX_MAPS_IN_MEMORY:
            self.save_exploration_memory()
            sorted_maps = sorted(self.exploration_memory.items(),
                                key=lambda x: x[1].get('last_visited_timestep', 0), reverse=True)
            self.exploration_memory = dict(sorted_maps[:self.MAX_MAPS_IN_MEMORY // 2])
        
        self._entity_norms_cache.clear()
        self._cache_valid = False
        gc.collect()
    
    def get_memory_stats(self):
        stats = {
            'exploration_maps': len(self.exploration_memory),
            'location_memory': len(self.location_memory),
            'context_stats': len(self.context_action_stats),
            'error_history': len(self.error_history),
            'perceptrons': len(self.perceptrons),
        }
        total_tiles = sum(len(m.get('visited_tiles', set())) for m in self.exploration_memory.values())
        stats['total_tiles'] = total_tiles
        return stats

    # =========================================================================
    # EXPLORATION MEMORY
    # =========================================================================
    
    def load_exploration_memory(self):
        try:
            if self.EXPLORATION_MEMORY_FILE.exists():
                with open(self.EXPLORATION_MEMORY_FILE, 'r') as f:
                    data = json.load(f)
                    self.exploration_memory = {}
                    items = list(data.items())[-self.MAX_MAPS_IN_MEMORY:]
                    for map_key, map_data in items:
                        map_id = int(map_key.replace('map_', ''))
                        self.exploration_memory[map_id] = self._deserialize_map_memory(map_data)
                print(f"  Loaded exploration memory: {len(self.exploration_memory)} maps")
            else:
                self.exploration_memory = {}
        except Exception as e:
            print(f"  Error loading exploration memory: {e}")
            self.exploration_memory = {}

    def _deserialize_map_memory(self, map_data):
        tile_interactions = {}
        ti_data = map_data.get('tile_interactions', {})
        ti_items = list(ti_data.items())[-100:]
        for tile_key, tile_data in ti_items:
            tile_interactions[tile_key] = {
                'directions_tried': set(tile_data.get('directions_tried', [])),
                'direction_attempts': {int(k): v for k, v in tile_data.get('direction_attempts', {}).items()},
                'direction_successes': {int(k): v for k, v in tile_data.get('direction_successes', {}).items()},
                'exhausted': tile_data.get('exhausted', False)
            }
        return {
            'visited_tiles': set(tuple(t) for t in map_data.get('visited_tiles', [])[-1000:]),
            'obstructions': set(tuple(t) for t in map_data.get('obstructions', [])[-500:]),
            'interactable_objects': map_data.get('interactable_objects', [])[-50:],
            'last_visited_timestep': map_data.get('last_visited_timestep', 0),
            'transitions': map_data.get('transitions', [])[-20:],
            'temp_debt': map_data.get('temp_debt', 0.0),
            'tile_interactions': tile_interactions
        }

    def save_exploration_memory(self):
        try:
            data = {f'map_{mid}': self._serialize_map_memory(md) for mid, md in self.exploration_memory.items()}
            with open(self.EXPLORATION_MEMORY_FILE, 'w') as f:
                json.dump(data, f)
        except Exception as e:
            print(f"  Error saving exploration memory: {e}")

    def _serialize_map_memory(self, map_data):
        serialized_ti = {}
        for tile_key, td in list(map_data.get('tile_interactions', {}).items())[-100:]:
            serialized_ti[tile_key] = {
                'directions_tried': list(td.get('directions_tried', set())),
                'direction_attempts': {str(k): v for k, v in td.get('direction_attempts', {}).items()},
                'direction_successes': {str(k): v for k, v in td.get('direction_successes', {}).items()},
                'exhausted': td.get('exhausted', False)
            }
        return {
            'visited_tiles': list(map_data['visited_tiles'])[-1000:],
            'obstructions': list(map_data['obstructions'])[-500:],
            'interactable_objects': map_data['interactable_objects'][-50:],
            'last_visited_timestep': map_data['last_visited_timestep'],
            'transitions': map_data.get('transitions', [])[-20:],
            'temp_debt': map_data.get('temp_debt', 0.0),
            'tile_interactions': serialized_ti
        }

    def get_current_map_memory(self, map_id):
        if map_id not in self.exploration_memory:
            self.exploration_memory[map_id] = {
                'visited_tiles': set(), 'obstructions': set(), 'interactable_objects': [],
                'last_visited_timestep': self.timestep, 'transitions': [], 'temp_debt': 0.0,
                'tile_interactions': {}
            }
        return self.exploration_memory[map_id]

    def record_visited_tile(self, x, y, map_id):
        memory = self.get_current_map_memory(map_id)
        if len(memory['visited_tiles']) < 1000:
            memory['visited_tiles'].add((int(x), int(y)))
        memory['last_visited_timestep'] = self.timestep

    def record_obstruction(self, x, y, map_id, direction):
        dx, dy = self.DIRECTION_DELTAS_INT.get(direction, (0, 0))
        memory = self.get_current_map_memory(map_id)
        if len(memory['obstructions']) < 500:
            memory['obstructions'].add((int(x + dx), int(y + dy)))

    # =========================================================================
    # TILE INTERACTION
    # =========================================================================
    
    def get_tile_interaction_key(self, x, y):
        return f"{int(x)}_{int(y)}"
    
    def get_tile_interaction_state(self, x, y, map_id):
        memory = self.get_current_map_memory(map_id)
        tile_key = self.get_tile_interaction_key(x, y)
        if tile_key not in memory['tile_interactions']:
            memory['tile_interactions'][tile_key] = {
                'directions_tried': set(),
                'direction_attempts': {0: 0, 1: 0, 2: 0, 3: 0},
                'direction_successes': {0: 0, 1: 0, 2: 0, 3: 0},
                'exhausted': False
            }
        return memory['tile_interactions'][tile_key]
    
    def should_interact_at_tile(self, x, y, map_id):
        tile_state = self.get_tile_interaction_state(x, y, map_id)
        if tile_state['exhausted']:
            return False
        if len(tile_state['directions_tried']) < 4:
            return True
        for d in range(4):
            attempts = tile_state['direction_attempts'].get(d, 0)
            successes = tile_state['direction_successes'].get(d, 0)
            if attempts > 0 and successes / attempts >= self.MIN_SUCCESS_RATE_THRESHOLD:
                return True
        return False
    
    def get_untried_directions(self, x, y, map_id):
        tile_state = self.get_tile_interaction_state(x, y, map_id)
        return [d for d in range(4) if d not in tile_state['directions_tried']]

    def get_exploration_coverage(self, map_id):
        memory = self.get_current_map_memory(map_id)
        visited = len(memory['visited_tiles'])
        obstructions = len(memory['obstructions'])
        if visited == 0 or visited + obstructions < 10:
            return 0.0
        return visited / (visited + obstructions)

    # =========================================================================
    # TRANSITIONS & DEBT
    # =========================================================================
    
    def record_transition(self, from_pos, from_map, to_map, direction, action_type):
        memory = self.get_current_map_memory(from_map)
        for t in memory['transitions']:
            if t['position'] == from_pos and t['direction'] == direction:
                t['use_count'] += 1
                t['last_used'] = self.timestep
                return
        memory['transitions'].append({
            'position': from_pos, 'direction': direction, 'action': action_type,
            'destination_map': to_map, 'use_count': 1, 'last_used': self.timestep
        })
        print(f"  üö™ TRANSITION FOUND: Map {from_map} ({from_pos}) ‚Üí Map {to_map}")

    def get_temp_debt(self, map_id):
        memory = self.get_current_map_memory(map_id)
        raw_debt = memory.get('temp_debt', 0.0)
        if map_id != self.current_map_id:
            steps_away = self.timestep - memory.get('last_visited_timestep', 0)
            return max(0.0, raw_debt - steps_away * self.TEMP_DEBT_DECAY)
        return raw_debt

    def accumulate_temp_debt(self, map_id):
        memory = self.get_current_map_memory(map_id)
        memory['temp_debt'] = min(self.TEMP_DEBT_MAX, memory.get('temp_debt', 0.0) + self.TEMP_DEBT_ACCUMULATION)

    def decay_all_debts(self):
        for map_id in list(self.map_novelty_debt.keys()):
            if map_id != self.current_map_id:
                self.map_novelty_debt[map_id] *= (1.0 - self.DEBT_DECAY_RATE)
                if self.map_novelty_debt[map_id] < 0.1:
                    del self.map_novelty_debt[map_id]

    # =========================================================================
    # EXPLORATION TRACKING
    # =========================================================================
    
    def update_exploration_tracking(self, context_state, prev_context_state, raw_position=None, prev_raw_position=None):
        current_map = int(context_state[2])
        raw_x = raw_position[0] if raw_position else int(context_state[0] * 255)
        raw_y = raw_position[1] if raw_position else int(context_state[1] * 255)
        
        if self.current_map_id is not None and current_map != self.current_map_id:
            if prev_context_state is not None and prev_raw_position is not None:
                self.record_transition(prev_raw_position, self.current_map_id, current_map,
                    int(prev_context_state[5]), 'interact' if self.last_action == 'A' else 'walk')
            self.on_map_change(current_map)
        
        self.current_map_id = current_map
        self.record_visited_tile(raw_x, raw_y, current_map)
        self.accumulate_temp_debt(current_map)
        self.last_direction = int(context_state[5])
        
        if self.timestep % 300 == 0:
            self.decay_all_debts()

    def on_map_change(self, new_map):
        self.save_exploration_memory()
        self.control_mode = "move"
        self.frames_in_current_mode = 0
        memory = self.get_current_map_memory(new_map)
        print(f"  üó∫Ô∏è MAP CHANGE ‚Üí {new_map}: {len(memory['visited_tiles'])} visited")

    # =========================================================================
    # UTILITY & LEARNING HELPERS
    # =========================================================================
    
    def enforce_utility_floors(self):
        for a in self.actions():
            floor = self.MOVE_UTILITY_FLOOR if a.group == "move" else self.INTERACT_UTILITY_FLOOR
            a.utility = max(a.utility, floor)

    def stagnation_level(self, window=10):
        if len(self.prev_learning_states) < window:
            return 0.0
        recent = list(self.prev_learning_states)[-window:]
        return 1.0 - np.tanh(np.mean([np.linalg.norm(recent[i] - recent[i-1]) for i in range(1, len(recent))]) * 2.0)

    def track_consecutive_action(self, action_name):
        """Track the length of the current run of identical actions.

        A different action starts a new run of length 1; repeating the
        current action extends the run by one.
        """
        if action_name != self.current_repeated_action:
            self.current_repeated_action = action_name
            self.consecutive_action_count = 1
            return
        self.consecutive_action_count += 1

    def get_learning_multiplier(self, action_name):
        """Return the learning-rate multiplier for ``action_name``.

        Returns 1.0 unless ``action_name`` is the action currently being
        repeated and its run length has reached ``LEARNING_SLOWDOWN_START``.
        From there the multiplier falls linearly from 1.0 towards a floor of
        0.05, reached at ``LEARNING_SLOWDOWN_MAX``.

        If ``LEARNING_SLOWDOWN_MAX <= LEARNING_SLOWDOWN_START``, the previous
        code divided by zero (or by a negative span). Such a run is now
        treated as fully slowed down.
        """
        start = self.LEARNING_SLOWDOWN_START
        if action_name != self.current_repeated_action or self.consecutive_action_count < start:
            return 1.0
        span = self.LEARNING_SLOWDOWN_MAX - start
        if span <= 0:
            progress = 1.0
        else:
            progress = min(1.0, (self.consecutive_action_count - start) / span)
        return max(0.05, 1.0 - 0.95 * progress)

    def apply_repetition_penalty(self):
        """Reduce the utility of the action that is being spammed.

        Only the first action perceptron whose name matches
        ``current_repeated_action`` is affected.
        - At ``HARD_RESET_THRESHOLD`` repeats, its utility is halved and the
          repeat counter resets to 0.
        - At ``PENALTY_THRESHOLD`` repeats, its utility is scaled by 0.7.
        - In both cases the utility never drops below the group floor.
        """
        target = self.current_repeated_action
        if target is None:
            return
        repeated = next((a for a in self.actions() if a.action == target), None)
        if repeated is None:
            return
        if repeated.group == "interact":
            floor = self.INTERACT_UTILITY_FLOOR
        else:
            floor = self.MOVE_UTILITY_FLOOR
        run_length = self.consecutive_action_count
        if run_length >= self.HARD_RESET_THRESHOLD:
            repeated.utility = max(floor, repeated.utility * 0.5)
            self.consecutive_action_count = 0
        elif run_length >= self.PENALTY_THRESHOLD:
            repeated.utility = max(repeated.utility * 0.7, floor)

    def apply_pattern_penalty(self):
        """Hook for penalizing repeated multi-action patterns.

        Currently a deliberate no-op ("simplified for now"); it is still
        called from ``learn`` so a real implementation can be dropped in.
        """
        return None

    def spawn_innate_entities(self, learning_state):
        """Create the four built-in "sense" entity perceptrons once.

        Each entity's weights are zeros of ``len(learning_state)``, except at
        the derived-feature indices it watches. Those get 1.0 when the entity
        watches a single feature and 0.5 each when it watches several.
        Indices outside the vector are skipped. Later calls are no-ops once
        ``innate_entities_spawned`` is set.
        """
        if self.innate_entities_spawned:
            return
        dim = len(learning_state)
        innate_specs = (
            ("sense_menu", (5, 6)),
            ("sense_battle", (3, 4)),
            ("sense_movement", (0, 1)),
            ("sense_map_transition", (2,)),
        )
        for entity_type, watched in innate_specs:
            sensor = Perceptron("entity", entity_type=entity_type)
            sensor.ensure_weights(dim)
            sensor.weights = np.zeros(dim)
            gain = 0.5 if len(watched) > 1 else 1.0
            for feature in watched:
                if feature < dim:
                    sensor.weights[feature] = gain
            self.add(sensor)
        self.innate_entities_spawned = True

    def compute_multi_modal_error(self, state, next_state):
        """Measure how much the learning state changed between two steps.

        Only the common prefix of the two vectors is compared:
        - The first (up to) 8 entries, the derived temporal features per the
          ``LEARNING_STATE_DIM`` layout, are compared element-wise.
        - Everything after index 8 (tile/palette features) is compared as a
          single Euclidean distance.

        Args:
            state: Learning-state vector before the step (array or sequence).
            next_state: Learning-state vector after the step.

        Returns:
            ``(weighted, numeric, visual)``:
            - ``numeric``: the plain sum of the absolute per-feature deltas.
            - ``visual``: the L2 norm of the remaining delta (0.0 if none).
            - ``weighted``: the per-feature deltas combined with fixed
              weights, plus ``2 * visual``.
        """
        # Accept plain sequences too; array slices are subtracted below.
        state = np.asarray(state, dtype=float)
        next_state = np.asarray(next_state, dtype=float)
        min_len = min(len(state), len(next_state))

        diffs = [abs(next_state[i] - state[i]) for i in range(min(8, min_len))]
        # Index 2 (map changed) and 3 (battle started) dominate the signal.
        weights = [0.5, 0.5, 10.0, 5.0, 3.0, 2.0, 1.5, 0.3]
        weighted = sum(d * w for d, w in zip(diffs, weights))

        # The visual distance is computed once and reused for both outputs.
        visual = np.linalg.norm(next_state[8:min_len] - state[8:min_len]) if min_len > 8 else 0.0
        if min_len > 8:
            weighted += visual * 2.0

        numeric = sum(diffs)
        return weighted, numeric, visual

    # =========================================================================
    # MAIN LEARN METHOD
    # =========================================================================
    
    def learn(self, learning_state, next_learning_state, context_state, next_context_state, dead=False,
            raw_position=None, next_raw_position=None):
        """Run one learning step from a ``(state, next_state)`` pair.

        Pipeline:
        1. Zero-pad the shorter learning vector so both have equal length.
        2. Spawn the innate entities on first use.
        3. Update exploration tracking.
        4. Compute the multi-modal error and append it to the histories.
        5. Bump the map and location visit counters and the map novelty debt.
        6. Damp the error on heavily revisited maps/locations.
        7. Update every perceptron, then apply the repetition/pattern
           penalties and utility floors.
        8. Run periodic cleanup and saves.

        Args:
            learning_state: Learning vector for this step.
            next_learning_state: Learning vector for the next step.
            context_state: Normalized context vector for this step. Index 2
                is the map id; indices 0/1 are the position scaled by 1/255.
            next_context_state: Accepted for interface compatibility; not
                read by this method.
            dead: Accepted for interface compatibility; not read by this method.
            raw_position: Optional un-normalized ``(x, y)``. When falsy, the
                scaled position * 255 is used for the location key.
            next_raw_position: Accepted for interface compatibility; not read
                by this method.
        """
        # Zero-pad the shorter vector so both have the same length.
        if len(learning_state) != len(next_learning_state):
            max_dim = max(len(learning_state), len(next_learning_state))
            learning_state = np.pad(learning_state, (0, max_dim - len(learning_state)))
            next_learning_state = np.pad(next_learning_state, (0, max_dim - len(next_learning_state)))

        if not self.innate_entities_spawned:
            self.spawn_innate_entities(learning_state)

        prev_context = self.prev_context_states[-1] if self.prev_context_states else None
        prev_raw = getattr(self, '_last_raw_position', None)
        self.update_exploration_tracking(context_state, prev_context, raw_position, prev_raw)
        self._last_raw_position = raw_position

        weighted_error, numeric_error, visual_error = self.compute_multi_modal_error(learning_state, next_learning_state)
        self.error_history.append(weighted_error)
        self.numeric_error_history.append(numeric_error)
        self.visual_error_history.append(visual_error)

        current_map = int(context_state[2])
        if raw_position:
            position = raw_position
        else:
            position = (context_state[0] * 255, context_state[1] * 255)
        loc = self.get_location_key(*position, current_map)

        map_visits = self.visited_maps.get(current_map, 0) + 1
        self.visited_maps[current_map] = map_visits

        # MAX_LOCATIONS bounds how many distinct keys are stored. Locations
        # that are already tracked must keep counting, otherwise the revisit
        # damping below freezes as soon as the table fills up.
        if loc in self.location_memory or len(self.location_memory) < self.MAX_LOCATIONS:
            self.location_memory[loc] = self.location_memory.get(loc, 0) + 1

        # Past 10 visits, novelty debt for this map grows each step (capped).
        if map_visits > 10:
            self.map_novelty_debt[current_map] = min(
                self.MAX_MAP_DEBT,
                self.map_novelty_debt.get(current_map, 0.0) + 0.05 * (map_visits - 10))

        # Learn less from places that have been seen many times.
        if map_visits > 30:
            weighted_error *= 0.5
        if self.location_memory.get(loc, 0) > 25:
            weighted_error *= 0.7

        stagnation = self.stagnation_level()
        learning_mult = self.get_learning_multiplier(self.last_action) if self.last_action else 1.0

        for p in self.perceptrons:
            # Only the repeated action's own perceptron is slowed down.
            mult = learning_mult if (p.kind == "action" and p.action == self.last_action) else 1.0
            p.update(learning_state, weighted_error * mult, stagnation=stagnation)

        self.apply_repetition_penalty()
        self.apply_pattern_penalty()
        self.enforce_utility_floors()

        if self.timestep % 1000 == 0:
            self.cleanup_memory()

        if self.timestep % self.SAVE_INTERVAL == 0:
            self.save_exploration_memory()

        self.action_history.append(self.last_action)

    # =========================================================================
    # TEACHING MODE
    # =========================================================================
        
    def learn_from_human_action(self, learning_state, human_action, context_state):
        """Learn from one human demonstration (teaching mode).

        Args:
            learning_state: Learning-state vector observed when the human acted.
            human_action: Action name pressed by the human. None or "NONE"
                means there was no demonstration, and the call is a no-op.
            context_state: Normalized context vector. Index 2 is the map id;
                indices 3/4 are read by ``_detect_context``.

        Effects:
            * Increments ``demonstration_count``.
            * Counts the action under ``"<context>_<map id>"`` in
              ``context_action_stats``. ``MAX_CONTEXT_STATS`` caps the number
              of distinct context keys; contexts already tracked keep
              counting after the cap is reached.
            * Multiplies the utility of the first matching action perceptron
              by 1.05, capped at 2.0.
            * Updates every action perceptron: +0.1 for the demonstrated
              action and -0.02 for all others, with zero stagnation.
        """
        if human_action is None or human_action == "NONE":
            return

        self.demonstration_count += 1
        context = self._detect_context(context_state)

        context_key = f"{context}_{int(context_state[2])}"
        stats = self.context_action_stats
        # The cap limits how many distinct context keys exist; it must not
        # stop counting for contexts that are already being tracked.
        if context_key in stats or len(stats) < self.MAX_CONTEXT_STATS:
            per_context = stats.setdefault(context_key, {})
            per_context[human_action] = per_context.get(human_action, 0) + 1

        for a in self.actions():
            if a.action == human_action:
                a.utility = min(a.utility * 1.05, 2.0)
                break

        for a in self.actions():
            a.ensure_weights(len(learning_state))
            if a.action == human_action:
                a.update(learning_state, 0.1, stagnation=0.0)
            else:
                a.update(learning_state, -0.02, stagnation=0.0)

    def _detect_context(self, context_state):
        if context_state[3] > 0.5:
            return "battle"
        elif context_state[4] > 0.5:
            return "menu"
        else:
            return "overworld"

    def print_teaching_stats(self):
        """Print a short summary of the recorded human demonstrations.

        Shows the 5 contexts with the most demonstrations. For each one it
        lists the 3 most frequent actions with their share of that context.
        Prints nothing when no statistics exist.
        """
        stats = self.context_action_stats
        if not stats:
            return
        print(f"\n{'='*50}")
        print(f"üìö TEACHING STATS (Demos: {self.demonstration_count})")
        ranked = sorted(stats.items(), key=lambda kv: sum(kv[1].values()), reverse=True)
        for key, per_action in ranked[:5]:
            denom = sum(per_action.values())
            print(f"\n  {key}:")
            top_actions = sorted(per_action.items(), key=lambda kv: kv[1], reverse=True)[:3]
            for name, n in top_actions:
                print(f"    {name}: {n} ({n/denom*100:.0f}%)")

    # =========================================================================
    # MODEL SAVE/LOAD
    # =========================================================================

    def save_model(self, filepath=None):
        """Serialize action/entity perceptrons and learning statistics to JSON.

        Weights are stored sparsely: the full vector length plus
        ``(index, value)`` pairs for entries with ``|w| > 1e-6``.

        The JSON is first written to a temporary file in the target directory
        and then moved over the checkpoint with ``os.replace``. An interrupted
        save therefore leaves the previous checkpoint intact instead of a
        truncated, unloadable file.

        Args:
            filepath: Destination path (str or Path); defaults to ``MODEL_FILE``.

        Errors raised while writing the file are printed rather than raised.
        """
        import os
        import tempfile

        if filepath is None:
            filepath = MODEL_FILE
        filepath = Path(filepath)

        def sparse_weights(weights):
            # Keep only entries that are meaningfully non-zero.
            kept = np.where(np.abs(weights) > 1e-6)[0]
            return [(int(i), float(weights[i])) for i in kept]

        actions_data = [
            {
                "action": a.action, "group": a.group, "utility": float(a.utility),
                "weights_shape": int(len(a.weights)), "weights_nonzero": sparse_weights(a.weights),
                "learning_rate": float(a.learning_rate), "familiarity": float(a.familiarity)
            }
            for a in self.actions() if a.weights is not None
        ]

        entities_data = [
            {
                "entity_type": e.entity_type, "utility": float(e.utility),
                "weights_shape": int(len(e.weights)), "weights_nonzero": sparse_weights(e.weights),
                "familiarity": float(e.familiarity)
            }
            for e in self.entities() if e.weights is not None
        ]

        model_data = {
            "timestep": int(self.timestep),
            "perceptrons": {"actions": actions_data, "entities": entities_data},
            "debt_tracking": {
                "map_novelty_debt": {int(k): float(v) for k, v in self.map_novelty_debt.items()},
                "visited_maps": {int(k): int(v) for k, v in self.visited_maps.items()}
            },
            "teaching_stats": {
                "demonstration_count": int(self.demonstration_count),
                "context_action_stats": self.context_action_stats
            },
            "control_mode": self.control_mode
        }

        tmp_name = None
        try:
            fd, tmp_name = tempfile.mkstemp(dir=filepath.parent,
                                            prefix=filepath.name + ".", suffix=".tmp")
            with os.fdopen(fd, 'w') as f:
                json.dump(model_data, f)
            os.replace(tmp_name, filepath)
            tmp_name = None  # now owned by the real checkpoint path
            print(f"üíæ Model saved: {self.timestep} steps, {self.demonstration_count} demos")
        except Exception as e:
            print(f"‚ùå Save error: {e}")
        finally:
            # Remove a half-written temp file left behind by a failed save.
            if tmp_name is not None:
                try:
                    os.remove(tmp_name)
                except OSError:
                    pass

    def load_model(self, filepath=None):
        """Restore state written by ``save_model``.

        Restores:
        - action perceptrons (matched by action name; the first match wins);
        - ``timestep`` and ``control_mode``;
        - map novelty debt and visit counters;
        - teaching statistics.

        Entity perceptrons present in the file are NOT restored by this
        method.

        Args:
            filepath: Checkpoint path (str or Path); defaults to ``MODEL_FILE``.

        Returns:
            True on success. False when the file does not exist or cannot be
            read or parsed; the error is printed in that case.
        """
        if filepath is None:
            filepath = MODEL_FILE
        # Accept plain strings as well as Path objects (str has no .exists()).
        filepath = Path(filepath)

        if not filepath.exists():
            print(f"‚ÑπÔ∏è No saved model at {filepath}")
            return False

        try:
            with open(filepath, 'r') as f:
                model_data = json.load(f)

            self.timestep = model_data.get("timestep", 0)
            self.control_mode = model_data.get("control_mode", "move")

            for a_data in model_data.get("perceptrons", {}).get("actions", []):
                for a in self.actions():
                    if a.action == a_data["action"]:
                        a.utility = a_data["utility"]
                        a.learning_rate = a_data.get("learning_rate", 0.01)
                        a.familiarity = a_data.get("familiarity", 0.0)
                        weights_shape = a_data["weights_shape"]
                        # Rebuild the dense vector from the sparse (index, value) pairs.
                        a.weights = np.zeros(weights_shape)
                        for idx, val in a_data["weights_nonzero"]:
                            if idx < weights_shape:
                                a.weights[idx] = val
                        break

            # JSON object keys are strings; convert map ids back to int.
            debt_data = model_data.get("debt_tracking", {})
            self.map_novelty_debt = {int(k): float(v) for k, v in debt_data.get("map_novelty_debt", {}).items()}
            self.visited_maps = {int(k): int(v) for k, v in debt_data.get("visited_maps", {}).items()}

            teaching_data = model_data.get("teaching_stats", {})
            self.demonstration_count = teaching_data.get("demonstration_count", 0)
            self.context_action_stats = teaching_data.get("context_action_stats", {})

            print(f"‚úÖ Model loaded: {self.timestep} steps, {self.demonstration_count} demos")
            return True
        except Exception as e:
            print(f"‚ùå Load error: {e}")
            return False

In [35]:
# # ============================================================================
# # CELL 4: Action Selection - Updated with All Fixes
# # ============================================================================
# # CHANGES:
# # 1. Added FORCED_EXPLORE_PROB (18%) for random exploration
# # 2. Added "both" mode handling - allows all actions when stuck
# # 3. Added turn-for-probing override - allows turns even in interact mode
# # ============================================================================

# import random  # Add to imports if not present

# GBA_ACTIONS = ["Up", "Down", "Left", "Right", "A", "B", "Start", "Select"]
# ACTION_DELTAS = {"UP": (0, -1), "DOWN": (0, 1), "LEFT": (-1, 0), "RIGHT": (1, 0)}
# DIRECTION_TO_ACTION = {0: "DOWN", 1: "UP", 2: "LEFT", 3: "RIGHT"}
# ACTION_TO_DIRECTION = {"DOWN": 0, "UP": 1, "LEFT": 2, "RIGHT": 3}

# def manhattan_distance(pos1, pos2):
#     return abs(pos1[0] - pos2[0]) + abs(pos1[1] - pos2[1])


# def anticipatory_action(brain, learning_state, context_state, 
#                        exploration_weight=1.3, min_interact_prob=0.15,
#                        raw_position=None,
#                        forced_explore_prob=0.18):  # NEW: 18% forced random
#     """
#     Action selection with all fixes:
#     1. Forced random exploration (18%)
#     2. "Both" mode when extremely stuck
#     3. Turn-for-probing override
#     4. Tile-based interaction probing
#     5. Novelty-driven movement
#     """
#     actions_list = brain.actions()
#     if not actions_list:
#         return Perceptron("action", action="UP", group="move")

#     mode = brain.determine_control_mode(context_state, raw_position=raw_position)
#     current_map = int(context_state[2])
#     current_dir = int(context_state[5])
    
#     raw_x = raw_position[0] if raw_position else int(context_state[0] * 255)
#     raw_y = raw_position[1] if raw_position else int(context_state[1] * 255)
#     current_pos = (raw_x, raw_y)
    
#     # Get exploration memory
#     memory = brain.get_current_map_memory(current_map)
#     visited_tiles = memory['visited_tiles']
#     obstructions = memory['obstructions']
    
#     # Get tile interaction state
#     tile_needs_probing = brain.should_interact_at_tile(raw_x, raw_y, current_map)
    
#     # NEW: Get best probe action (handles turn-then-interact)
#     probe_action, probe_dir = brain.get_best_probe_action(raw_x, raw_y, current_map, current_dir)
    
#     # Get transition info
#     transition_attraction, best_transition = brain.get_transition_attraction(current_map)
#     coverage = brain.get_exploration_coverage(current_map)
    
#     # === BUILD ALLOWED ACTIONS LIST ===
#     if mode == "battle":
#         # In battle, use group weights to decide
#         move_weight = brain.get_group_weight("move")
#         interact_weight = brain.get_group_weight("interact")
#         total = move_weight + interact_weight + 1e-9
#         if random.random() < move_weight / total:
#             allowed = [a for a in actions_list if a.group == "move"]
#         else:
#             allowed = [a for a in actions_list if a.group == "interact"]
#         all_actions = actions_list  # Fallback
        
#     elif mode == "both":
#         # NEW: "Both" mode - allow everything
#         allowed = actions_list
#         all_actions = actions_list
        
#     elif mode == "interact":
#         allowed = [a for a in actions_list if a.group == "interact"]
#         all_actions = None
        
#         # NEW: Turn-for-probing override
#         # If we need to turn to probe, allow that movement action
#         if probe_action and probe_action in ['UP', 'DOWN', 'LEFT', 'RIGHT']:
#             turn_actions = [a for a in actions_list if a.action == probe_action]
#             if turn_actions:
#                 # Add the turn action to allowed list
#                 allowed = allowed + turn_actions
        
#     else:  # move
#         allowed = [a for a in actions_list if a.group == "move"]
#         all_actions = None

#     if not allowed:
#         allowed = actions_list

#     # === NEW: FORCED RANDOM EXPLORATION (18%) ===
#     if random.random() < forced_explore_prob:
#         chosen = random.choice(allowed)
#         brain.record_action_execution(chosen.action)
#         brain.track_consecutive_action(chosen.action)
        
#         # Still start interaction verification if it's an A press on a probeable tile
#         if chosen.action == 'A' and tile_needs_probing:
#             brain.start_interaction_verification(raw_x, raw_y, current_map, current_dir)
        
#         return chosen

#     # === SCORE ACTIONS ===
#     action_scores = []
    
#     for a in allowed:
#         predicted = brain.predict_future_error(learning_state, a, context_state, raw_position=raw_position)
        
#         # --- MOVE ACTIONS ---
#         if a.group == "move":
#             if mode in ["move", "both"]:
#                 predicted *= exploration_weight
            
#             dx, dy = ACTION_DELTAS.get(a.action, (0, 0))
#             target_tile = (raw_x + dx, raw_y + dy)
#             action_direction = ACTION_TO_DIRECTION.get(a.action, -1)
            
#             # BONUS: Unvisited tile
#             if target_tile not in visited_tiles:
#                 predicted *= brain.UNVISITED_TILE_BONUS
            
#             # PENALTY: Known obstruction
#             if target_tile in obstructions:
#                 predicted *= brain.OBSTRUCTION_PENALTY
            
#             # PENALTY: Transition ban
#             if brain.is_position_banned(current_map, raw_x, raw_y, action_direction):
#                 predicted *= 0.05
            
#             # BONUS: Toward transition when well-explored
#             if transition_attraction > 0.3 and best_transition and coverage > 0.5:
#                 trans_pos = tuple(best_transition['position']) if isinstance(best_transition['position'], list) else best_transition['position']
#                 if manhattan_distance(target_tile, trans_pos) < manhattan_distance(current_pos, trans_pos):
#                     predicted *= (1.0 + transition_attraction)
            
#             # NEW: If this is a turn needed for probing, boost it
#             if probe_action == a.action and probe_dir is not None:
#                 predicted *= 2.0  # Strong boost for needed turn
            
#             # Random factor for variety
#             predicted *= (0.9 + random.random() * 0.2)
        
#         # --- INTERACT ACTIONS ---
#         elif a.group == "interact":
#             predicted = max(predicted, min_interact_prob)
            
#             # Menu trap B-boost
#             if a.action == 'B':
#                 predicted *= brain.menu_trap_b_boost
            
#             # A-press logic
#             if a.action == 'A':
#                 if tile_needs_probing and probe_action == 'A':
#                     # We're facing an untried direction - strong boost!
#                     predicted *= 3.0
#                 elif tile_needs_probing:
#                     # Tile needs probing but we need to turn first
#                     predicted *= 0.5  # Mild penalty - turn should happen instead
#                 else:
#                     # Tile exhausted
#                     predicted *= 0.1
            
#             # Start/Select - always penalize, no boost
#             if a.action in ['Start', 'Select']:
#                 predicted *= 0.3
        
#         action_scores.append((a, predicted))

#     # === SELECT BEST ===
#     if action_scores:
#         best_action = max(action_scores, key=lambda x: x[1])[0]
#         best_score = max(s for _, s in action_scores)
        
#         if best_score > 0.01:
#             brain.record_action_execution(best_action.action)
#             brain.track_consecutive_action(best_action.action)
            
#             # Start interaction verification for A-press on probeable tile
#             if best_action.action == 'A' and tile_needs_probing:
#                 brain.start_interaction_verification(raw_x, raw_y, current_map, current_dir)
            
#             return best_action
    
#     # === FALLBACKS ===
    
#     # Battle fallback
#     if mode == "battle" and all_actions:
#         all_scores = [(a, brain.predict_future_error(learning_state, a, context_state, raw_position=raw_position)) 
#                       for a in all_actions]
#         if all_scores:
#             best_action = max(all_scores, key=lambda x: x[1])[0]
#             brain.record_action_execution(best_action.action)
#             brain.track_consecutive_action(best_action.action)
#             return best_action
    
#     # Move fallback: prefer unvisited
#     if mode in ["move", "both"]:
#         for a in allowed:
#             if a.group == "move":
#                 dx, dy = ACTION_DELTAS.get(a.action, (0, 0))
#                 target = (raw_x + dx, raw_y + dy)
#                 if target not in visited_tiles and target not in obstructions:
#                     brain.record_action_execution(a.action)
#                     brain.track_consecutive_action(a.action)
#                     return a
    
#     # Generic fallback
#     if allowed:
#         best = max(allowed, key=lambda a: a.utility)
#         brain.record_action_execution(best.action)
#         brain.track_consecutive_action(best.action)
#         return best
    
#     best = max(actions_list, key=lambda a: a.utility)
#     brain.record_action_execution(best.action)
#     brain.track_consecutive_action(best.action)
#     return best

In [36]:
# ============================================================================
# CELL 5: Taught Transitions Manager (Markov Imitation Learning)
# ============================================================================

# Taught-transition recordings read by TaughtTransitionsManager.load_transitions.
TRANSITIONS_FILE = BASE_PATH.joinpath("taught_transitions.json")

class TaughtTransitionsManager:
    """
    Manages taught transition data for Markov imitation learning.

    Taught data is read from ``TRANSITIONS_FILE`` as
    ``{"batches": [...], "metadata": {...}}``. Every frame of every batch is
    flattened into ``frame_index`` for lookup.

    Similarity Scoring (weighted sum of three components, each in [0, 1]):
    - Immediate (50%): same map (required), position within
      ``POSITION_TOLERANCE`` tiles on both axes, same direction, same
      battle/menu flags.
    - Sequential (30%): length of the matching trailing run of the last 8
      actions (full credit at 8, stepped credit at 5 and 3).
    - Partial (20%): same in_battle / in_menu flags. When a brain is
      supplied, it also rewards being within 3 tiles of a known transition
      and standing on a tile the brain still wants to interact with;
      otherwise it gives flat partial credit. The brain terms depend only
      on the current state, so they are computed once per lookup.
    """

    def __init__(self):
        self.batches = []
        self.metadata = {}
        self.frame_index = []  # Flat list of all frames for fast lookup
        self.last_load_time = 0  # mtime of TRANSITIONS_FILE at the last successful load

        # Similarity weights
        self.IMMEDIATE_WEIGHT = 0.50
        self.SEQUENTIAL_WEIGHT = 0.30
        self.PARTIAL_WEIGHT = 0.20

        # Thresholds
        self.SIMILARITY_THRESHOLD = 0.6
        self.POSITION_TOLERANCE = 2  # Tiles

    def load_transitions(self, force=False):
        """Load taught transitions from ``TRANSITIONS_FILE``.

        The file is re-parsed only when its mtime differs from the last
        successful load, or when ``force`` is true.

        Returns:
            True if data is loaded (or already current). False if the file is
            missing or could not be parsed; the error is printed.
        """
        if not TRANSITIONS_FILE.exists():
            return False

        try:
            mod_time = TRANSITIONS_FILE.stat().st_mtime
            if not force and mod_time == self.last_load_time:
                return True  # Already loaded

            with open(TRANSITIONS_FILE, 'r') as f:
                data = json.load(f)

            self.batches = data.get('batches', [])
            self.metadata = data.get('metadata', {})
            self.last_load_time = mod_time

            # Build flat frame index for fast lookup
            self._build_frame_index()

            print(f"‚úÖ Loaded {len(self.batches)} batches, {len(self.frame_index)} frames")
            return True

        except Exception as e:
            print(f"‚ùå Error loading transitions: {e}")
            return False

    def _build_frame_index(self):
        """Flatten all batches into ``frame_index`` (one dict per taught frame)."""
        self.frame_index = []

        for batch in self.batches:
            batch_type = batch.get('batch_type', 'steady')
            trigger = batch.get('trigger_action', None)

            for frame in batch.get('frames', []):
                self.frame_index.append({
                    'batch_type': batch_type,
                    'trigger_action': trigger,
                    'state': frame.get('state', {}),
                    'action': frame.get('action'),
                    'recent_actions': frame.get('recent_actions', [])
                })

    def compute_immediate_similarity(self, current_state, taught_state):
        """
        Immediate context similarity (50% weight), in [0, 1].

        Returns 0.0 when the map ids differ. Otherwise it adds:
        - 0.25 for the map match;
        - up to 0.35 when both axis offsets are within ``POSITION_TOLERANCE``
          (scaled down with Manhattan distance);
        - 0.20 for the same direction;
        - 0.10 each for matching battle and menu flags.
        """
        score = 0.0

        # Map match (required)
        if current_state.get('map_id') != taught_state.get('map_id'):
            return 0.0
        score += 0.25

        # Position within tolerance
        dx = abs(current_state.get('x', 0) - taught_state.get('x', 0))
        dy = abs(current_state.get('y', 0) - taught_state.get('y', 0))
        if dx <= self.POSITION_TOLERANCE and dy <= self.POSITION_TOLERANCE:
            # Closer = higher score
            dist = dx + dy
            score += 0.35 * (1.0 - dist / (2 * self.POSITION_TOLERANCE + 1))

        # Direction match
        if current_state.get('direction') == taught_state.get('direction'):
            score += 0.20

        # Battle/menu state match
        if current_state.get('in_battle') == taught_state.get('in_battle'):
            score += 0.10
        if current_state.get('in_menu') == taught_state.get('in_menu'):
            score += 0.10

        return score

    def compute_sequential_similarity(self, current_recent, taught_recent):
        """
        Sequential similarity (30% weight), in [0, 1].

        Based on how many of the last (up to 8) actions match from the end:
        - 8 matches give 1.0;
        - at least 5 give 0.7;
        - at least 3 give 0.4;
        - fewer give ``matches / 8 * 0.3``.
        An empty history on either side gives 0.0.
        """
        if not current_recent or not taught_recent:
            return 0.0

        match_8 = self._count_trailing_matches(current_recent, taught_recent, 8)
        if match_8 >= 8:
            return 1.0
        if match_8 >= 5:
            return 0.7
        if match_8 >= 3:
            return 0.4
        # Partial credit for any matches
        return match_8 / 8.0 * 0.3

    def _count_trailing_matches(self, list1, list2, max_check):
        """Count matching elements from the end of both lists (at most ``max_check``)."""
        matches = 0
        len1 = len(list1)
        len2 = len(list2)

        for i in range(1, min(max_check + 1, len1 + 1, len2 + 1)):
            if list1[-i] == list2[-i]:
                matches += 1
            else:
                break

        return matches

    def _brain_context_terms(self, current_state, brain):
        """Return the brain-derived partial-similarity increments, in order.

        These depend only on ``current_state``, not on the taught frame. When
        ``brain`` is None, the result is a flat 0.20 credit. Otherwise it
        holds 0.20 for each condition that holds:
        - the current tile is within 3 tiles (per axis) of a known transition;
        - ``brain.should_interact_at_tile`` is true for the current tile.
        """
        if brain is None:
            # Without brain, give partial credit
            return (0.20,)

        map_id = current_state.get('map_id', 0)
        x = current_state.get('x', 0)
        y = current_state.get('y', 0)
        terms = []

        # Near transition
        memory = brain.get_current_map_memory(map_id)
        for t in memory.get('transitions', []):
            pos = t.get('position', [0, 0])
            if abs(x - pos[0]) <= 3 and abs(y - pos[1]) <= 3:
                terms.append(0.20)
                break

        # Tile still worth probing
        if brain.should_interact_at_tile(x, y, map_id):
            terms.append(0.20)

        return tuple(terms)

    @staticmethod
    def _partial_from_terms(current_state, taught_state, brain_terms):
        """Combine the flag matches with precomputed brain terms (capped at 1.0)."""
        score = 0.0
        if current_state.get('in_battle') == taught_state.get('in_battle'):
            score += 0.30
        if current_state.get('in_menu') == taught_state.get('in_menu'):
            score += 0.30
        # Added one at a time to keep the original floating-point accumulation order.
        for bonus in brain_terms:
            score += bonus
        return min(score, 1.0)

    def compute_partial_similarity(self, current_state, taught_state, brain=None):
        """
        Partial context similarity (20% weight), in [0, 1].

        Gives 0.30 each for matching in_battle and in_menu flags. It then adds
        the brain-derived terms (near a transition, tile still worth probing),
        or a flat 0.20 without a brain.
        """
        brain_terms = self._brain_context_terms(current_state, brain)
        return self._partial_from_terms(current_state, taught_state, brain_terms)

    def _score_frame(self, current_state, current_recent, taught_frame, brain_terms):
        """Score one taught frame using precomputed brain terms."""
        taught_state = taught_frame.get('state', {})
        taught_recent = taught_frame.get('recent_actions', [])

        immediate = self.compute_immediate_similarity(current_state, taught_state)
        sequential = self.compute_sequential_similarity(current_recent, taught_recent)
        partial = self._partial_from_terms(current_state, taught_state, brain_terms)

        total = (immediate * self.IMMEDIATE_WEIGHT +
                 sequential * self.SEQUENTIAL_WEIGHT +
                 partial * self.PARTIAL_WEIGHT)

        return total, {
            'immediate': immediate,
            'sequential': sequential,
            'partial': partial
        }

    def compute_similarity(self, current_state, current_recent, taught_frame, brain=None):
        """Return ``(total, {'immediate', 'sequential', 'partial'})`` for one taught frame."""
        brain_terms = self._brain_context_terms(current_state, brain)
        return self._score_frame(current_state, current_recent, taught_frame, brain_terms)

    def find_best_action(self, current_state, current_recent, brain=None):
        """
        Find the best action from taught data based on similarity.

        Loads the transitions file lazily if no frames are indexed yet. The
        brain is queried once per call rather than once per taught frame.

        Returns:
            ``(action, score, info)`` when the best score reaches
            ``SIMILARITY_THRESHOLD``. Otherwise ``(None, best_score, None)``,
            where ``best_score`` is 0.0 when no frames are available.
        """
        if not self.frame_index:
            self.load_transitions()

        if not self.frame_index:
            return None, 0.0, None

        # Loop-invariant: the brain-derived terms depend only on current_state.
        brain_terms = self._brain_context_terms(current_state, brain)

        best_action = None
        best_score = 0.0
        best_details = None
        best_frame = None

        for frame in self.frame_index:
            score, details = self._score_frame(current_state, current_recent, frame, brain_terms)

            if score > best_score:
                best_score = score
                best_action = frame.get('action')
                best_details = details
                best_frame = frame

        # Only return if above threshold
        if best_score >= self.SIMILARITY_THRESHOLD:
            return best_action, best_score, {
                'details': best_details,
                'matched_state': best_frame.get('state') if best_frame else None,
                'batch_type': best_frame.get('batch_type') if best_frame else None
            }

        return None, best_score, None

    def get_stats(self):
        """Return summary statistics about the loaded transitions, or None if none are loaded."""
        if not self.batches:
            return None

        action_counts = {}
        for frame in self.frame_index:
            action = frame.get('action', 'NONE')
            action_counts[action] = action_counts.get(action, 0) + 1

        return {
            'total_batches': len(self.batches),
            'total_frames': len(self.frame_index),
            'action_changes': self.metadata.get('action_changes', 0),
            'maps_visited': self.metadata.get('maps_visited', []),
            'action_distribution': action_counts
        }

    def print_stats(self):
        """Print transition statistics."""
        stats = self.get_stats()
        if not stats:
            print("No transitions loaded")
            return

        print(f"\n{'='*50}")
        print(f"üìä TAUGHT TRANSITIONS")
        print(f"{'='*50}")
        print(f"  Batches: {stats['total_batches']}")
        print(f"  Frames: {stats['total_frames']}")
        print(f"  Action changes: {stats['action_changes']}")
        print(f"  Maps: {stats['maps_visited']}")
        print(f"\n  Action distribution:")
        for action, count in sorted(stats['action_distribution'].items(),
                                     key=lambda x: x[1], reverse=True):
            pct = count / stats['total_frames'] * 100
            print(f"    {action}: {count} ({pct:.1f}%)")


# Global instance. NOTE: the main-loop cell (CELL 6) rebinds this name to a
# fresh TaughtTransitionsManager and calls load_transitions() on it.
transitions_manager = TaughtTransitionsManager()

In [37]:
# ============================================================================
# CELL 6: Main Loop - WITH TRANSITION LEARNING
# ============================================================================
# Setup for the polling loop: builds the Brain with one action perceptron per
# button, restores the saved model and taught transitions, and initializes the
# per-run state the loop reads and updates.

import gc

brain = Brain()

# Action perceptrons: D-pad buttons form the "move" group, the rest "interact".
for b in ["UP", "DOWN", "LEFT", "RIGHT"]:
    brain.add(Perceptron("action", action=b, group="move"))
for b in ["A", "B", "Start", "Select"]:
    brain.add(Perceptron("action", action=b, group="interact"))

# Load existing model
brain.load_model()

# Fresh transitions manager for this run; load_transitions() populates it.
transitions_manager = TaughtTransitionsManager()
transitions_manager.load_transitions()

# Run counters
total_inputs_processed = 0
batches_processed = 0

# Visual cache: last palette/tile features read from the game state. They are
# reused for every input; VISUAL_UPDATE_INTERVAL is measured in batches.
cached_palette = np.zeros(PALETTE_DIM)
cached_tiles = np.zeros(TILE_DIM)
last_visual_update = 0
VISUAL_UPDATE_INTERVAL = 50

# Recent actions tracking (mirrors Lua's tracking): rolling window of the
# newest RECENT_ACTIONS_SIZE human actions, passed to find_best_action().
recent_actions_buffer = []
RECENT_ACTIONS_SIZE = 8

print("="*70)
# The emoji was previously stored as mojibake ("üéì"); print the real glyph.
print("🎓 TEACHING MODE + MARKOV TRANSITIONS")
print("="*70)
print(f"Input file: {INPUT_FILE}")
print(f"Transitions file: {TRANSITIONS_FILE}")
print(f"Memory file: {brain.EXPLORATION_MEMORY_FILE}")
print(f"Model file: {MODEL_FILE}")
print("="*70)

# Print initial transition stats
transitions_manager.print_stats()

# Seconds to sleep between polls of INPUT_FILE.
POLL_INTERVAL = 2.0
# Context state / raw position of the previously processed frame (None before
# the first frame); used as the "from" side of brain.learn().
prev_context_state = None
prev_raw_position = None
# Reset to 0 so the existing input file is processed on the first poll.
last_input_mod_time = 0

def add_to_recent_actions(action):
    """Record *action* in the rolling window used for similarity matching.

    Falsy actions (None, '') are ignored. Once the window grows beyond
    RECENT_ACTIONS_SIZE entries, the module-level ``recent_actions_buffer``
    is rebound to a new list holding only the newest RECENT_ACTIONS_SIZE
    entries.
    """
    global recent_actions_buffer
    if not action:
        return
    recent_actions_buffer.append(action)
    if len(recent_actions_buffer) <= RECENT_ACTIONS_SIZE:
        return
    recent_actions_buffer = recent_actions_buffer[-RECENT_ACTIONS_SIZE:]

# Main polling loop. It waits for the Lua side to rewrite INPUT_FILE, then
# replays every recorded frame through the brain (imitation + transition
# learning) and persists the model after each batch. Runs until interrupted.

# Learning-state vector built for the previous frame. It is reused as the
# "from" state of brain.learn() so it keeps the derived (temporal) features it
# had when it was the current frame. (Rebuilding it with
# compute_derived_features(prev, None) always produced all-zero features.)
prev_learning_state = None

while True:
    # Check if input file exists
    if not INPUT_FILE.exists():
        print(f"[Waiting for {INPUT_FILE.name}...]")
        time.sleep(POLL_INTERVAL)
        continue

    # Only process the file when it has changed (new mtime) and is non-empty.
    try:
        input_stat = INPUT_FILE.stat()  # one stat() call for mtime and size
    except OSError:
        # The file can vanish or be locked between exists() and stat();
        # retry on the next poll.
        time.sleep(POLL_INTERVAL)
        continue

    if input_stat.st_mtime == last_input_mod_time or input_stat.st_size == 0:
        time.sleep(POLL_INTERVAL)
        continue
    last_input_mod_time = input_stat.st_mtime

    # Read inputs. Each line is expected to be
    #   action_code,x,y,map,in_battle,menu_flag,direction[,...]
    # Malformed lines (too few fields or non-integer values) are skipped one by
    # one, so a single bad line no longer discards the whole batch.
    inputs = []
    skipped_lines = 0
    try:
        with open(INPUT_FILE, 'r') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                parts = line.split(',')
                if len(parts) < 7:
                    skipped_lines += 1
                    continue
                try:
                    x, y, map_id, in_battle, menu_flag, direction = [
                        int(p) for p in parts[1:7]
                    ]
                except ValueError:
                    skipped_lines += 1
                    continue
                action_code = parts[0]
                inputs.append({
                    'action': ACTION_MAP.get(action_code, action_code),
                    'x': x,
                    'y': y,
                    'map': map_id,
                    'in_battle': in_battle,
                    'menu_flag': menu_flag,
                    'direction': direction
                })
    except (OSError, ValueError) as e:
        # OSError: open/read failure; ValueError also covers UnicodeDecodeError.
        print(f"[Read error: {e}]")
        time.sleep(POLL_INTERVAL)
        continue

    if skipped_lines:
        print(f"[Skipped {skipped_lines} malformed line(s)]")

    if not inputs:
        time.sleep(POLL_INTERVAL)
        continue

    # Process batch
    batches_processed += 1
    batch_size = len(inputs)
    total_inputs_processed += batch_size

    print(f"\n{'='*60}")
    print(f"📦 BATCH #{batches_processed}: {batch_size} inputs")

    # Reload transitions periodically
    if batches_processed % 5 == 0:
        transitions_manager.load_transitions()

    # Update the visual cache every VISUAL_UPDATE_INTERVAL batches, and also
    # immediately while nothing has been cached yet. Otherwise the first
    # VISUAL_UPDATE_INTERVAL - 1 batches would learn from all-zero visuals.
    if (not np.any(cached_palette)
            or batches_processed - last_visual_update >= VISUAL_UPDATE_INTERVAL):
        _, palette, tiles, _ = read_game_state_full()
        if np.any(palette != 0):
            cached_palette = palette
            cached_tiles = tiles
            last_visual_update = batches_processed

    # Process inputs
    action_counts = {}
    taught_matches = 0

    for inp in inputs:
        inp_context, inp_raw_pos, human_action = process_cached_input(inp)

        brain.update_position(inp_raw_pos[0], inp_raw_pos[1])

        derived = compute_derived_features(inp_context, prev_context_state)
        learning_state = build_learning_state(derived, cached_palette, cached_tiles, inp_context[3])

        brain.log_state(learning_state, inp_context)

        if human_action:
            # Track recent actions
            add_to_recent_actions(human_action)

            # Learn from human action
            brain.learn_from_human_action(learning_state, human_action, inp_context)
            brain.last_action = human_action
            action_counts[human_action] = action_counts.get(human_action, 0) + 1

            # Check similarity with taught transitions
            current_state = {
                'map_id': inp.get('map'),
                'x': inp.get('x'),
                'y': inp.get('y'),
                'direction': inp.get('direction'),
                'in_battle': inp.get('in_battle'),
                'in_menu': inp.get('menu_flag')
            }

            taught_action, sim_score, _ = transitions_manager.find_best_action(
                current_state, recent_actions_buffer, brain
            )

            if taught_action and sim_score >= 0.6:
                taught_matches += 1
                # Could log or use this for analysis

        # Learn the (previous frame -> this frame) transition using the state
        # vector exactly as it was built for the previous frame.
        if prev_learning_state is not None and prev_context_state is not None:
            brain.learn(prev_learning_state, learning_state, prev_context_state, inp_context,
                        dead=False, raw_position=prev_raw_position, next_raw_position=inp_raw_pos)

        prev_context_state = inp_context.copy()
        # Copy so later in-place changes by callees cannot alter the stored state.
        prev_learning_state = np.copy(learning_state)
        prev_raw_position = inp_raw_pos
        brain.timestep += 1

    # Summary (inputs is non-empty here, so inp_context/inp_raw_pos are set)
    current_map = int(inp_context[2])
    memory = brain.get_current_map_memory(current_map)

    print(f"  Actions: {action_counts}")
    print(f"  Map {current_map} | Pos ({inp_raw_pos[0]}, {inp_raw_pos[1]})")
    print(f"  Visited: {len(memory['visited_tiles'])} | Demos: {brain.demonstration_count}")
    print(f"  Taught matches: {taught_matches}/{batch_size} ({taught_matches/max(1,batch_size)*100:.0f}%)")

    # Utilities, highest first
    utils = sorted([(a.action, a.utility) for a in brain.actions()],
                   key=lambda x: x[1], reverse=True)
    print(f"  Utils: {' '.join([f'{k}:{v:.2f}' for k,v in utils])}")

    # Save every batch
    print("\n  💾 SAVING...")
    try:
        brain.save_exploration_memory()
        brain.save_model()
        print("     ✓ Saved")
    except Exception as e:
        # Best-effort persistence: report and keep polling.
        print(f"     ✗ Error: {e}")

    # Memory cleanup every 10 batches
    if batches_processed % 10 == 0:
        stats = brain.get_memory_stats()
        print(f"\n  📊 MEMORY: Maps={stats['exploration_maps']}, Tiles={stats['total_tiles']}")

        # Print transition stats
        transitions_manager.print_stats()

        brain.cleanup_memory()
        gc.collect()

    inputs = None
    gc.collect()

  Loaded exploration memory: 0 maps
✅ Model loaded: 0 steps, 0 demos
✅ Loaded 0 batches, 0 frames
🎓 TEACHING MODE + MARKOV TRANSITIONS
Input file: C:\Users\natmaw\Documents\Boston Stuff\CS 5100 Foundations of AI\PokeAI\input_cache.txt
Transitions file: C:\Users\natmaw\Documents\Boston Stuff\CS 5100 Foundations of AI\PokeAI\taught_transitions.json
Memory file: C:\Users\natmaw\Documents\Boston Stuff\CS 5100 Foundations of AI\PokeAI\exploration_memory.json
Model file: C:\Users\natmaw\Documents\Boston Stuff\CS 5100 Foundations of AI\PokeAI\model_checkpoint.json
No transitions loaded

📦 BATCH #1: 109 inputs
  Actions: {'Start': 12, 'DOWN': 42, 'A': 55}
  Map 17 | Pos (13, 13)
  Visited: 1 | Demos: 109
  Taught matches: 0/109 (0%)
  Utils: DOWN:1.99 A:1.99 Start:0.99 UP:0.59 LEFT:0.59 RIGHT:0.59 B:0.59 Select:0.59

  💾 SAVING...
💾 Model saved: 109 steps, 109 demos
     ✓ Saved


KeyboardInterrupt: 