In [4]:
import pygame
import random
import math
import time
import numpy as np
from stable_baselines3 import PPO

# Pygame 2.x with SDL2
from pygame._sdl2.video import Window, Renderer, Texture

pygame.init()
pygame.joystick.init()

###############################################################################
# Config / Constants (Match Training Exactly)
###############################################################################
FULL_VIEW_SIZE = (1200, 800)   # Must match training
RED_ONLY_SIZE  = (1200, 800)   
MAX_SPEED = 3                  # Remove scaling factor
NOISE_MAGNITUDE = 0.5          # Match training
DOT_RADIUS = 30                # Original size from training
TARGET_RADIUS = 10             
GOAL_DETECTION_RADIUS = DOT_RADIUS + TARGET_RADIUS

# Colors and visual settings
WHITE = (255, 255, 255)
BLACK = (0, 0, 0)
RED = (255, 0, 0)
GREEN = (0, 200, 0)
BLUE = (0, 0, 255)
YELLOW = (255, 255, 0)
GRAY = (128, 128, 128)

# Environment parameters (match training)
NUM_GOALS = 3
GHOST_TRAIL_DURATION = 3.0  
RECENT_DIR_LOOKBACK = 1.0
GOAL_SWITCH_THRESHOLD = 0.05
START_POS = [FULL_VIEW_SIZE[0]//2, FULL_VIEW_SIZE[1]//2]

# Initialize state
dot_pos = START_POS.copy()
gamma = 0.2
reached_goal = False
targets = []
current_target_idx = 0
obstacles = []
recent_positions = []
last_reset_time = time.time()

###############################################################################
# Neural Network Integration (Fixed)
###############################################################################
class GammaPredictor:
    def __init__(self, model_path="ppo_dynamic_arbitration_simple"):
        policy_kwargs = dict(
            net_arch=dict(pi=[256, 128], vf=[256, 128])
        )
        
        self.model = PPO.load(
            model_path,
            custom_objects={
                "policy_kwargs": policy_kwargs,
                "clip_range": 0.2
            }
        )
        self.max_dist = math.sqrt(FULL_VIEW_SIZE[0]**2 + FULL_VIEW_SIZE[1]**2)

    def prepare_observation(self, dot_pos, target_pos, human_input):
        dot_pos = np.array(dot_pos, dtype=np.float32)
        target_pos = np.array(target_pos, dtype=np.float32)
        human_input = np.array(human_input, dtype=np.float32)
        
        to_target = target_pos - dot_pos
        dist = np.linalg.norm(to_target)
        perfect_dir = to_target / dist if dist > 0 else np.zeros(2, dtype=np.float32)
        
        h_mag = np.linalg.norm(human_input)
        human_dir = human_input / h_mag if h_mag > 0 else np.zeros(2, dtype=np.float32)
        
        return np.concatenate([
            dot_pos,
            human_dir,
            target_pos,
            perfect_dir,
            [dist / self.max_dist]
        ])

    def predict_gamma(self, dot_pos, target_pos, human_input):
        obs = self.prepare_observation(dot_pos, target_pos, human_input)
        action, _ = self.model.predict(obs.reshape(1, -1), deterministic=True)
        return np.clip(action.item(), 0.0, 0.4)  # Use .item() instead of float()

gamma_predictor = GammaPredictor()

###############################################################################
# Environment Core Logic (Fixed)
###############################################################################
def generate_targets():
    targets.clear()
    for _ in range(NUM_GOALS):
        while True:
            # Match training's target generation range (100-1100x, 100-700y)
            pos = [
                random.randint(100, FULL_VIEW_SIZE[0]-100),
                random.randint(100, FULL_VIEW_SIZE[1]-100)
            ]
            if all(distance(pos, o) > 100 for o in obstacles):  # Simple collision check
                targets.append(pos)
                break

def move_dot(human_input):
    global dot_pos, gamma, reached_goal, current_target_idx

    # Convert positions to numpy arrays
    dot_np = np.array(dot_pos, dtype=np.float32)
    target_np = np.array(targets[current_target_idx], dtype=np.float32)
    
    # Predict gamma with proper typing
    gamma = gamma_predictor.predict_gamma(
        dot_pos=dot_np,
        target_pos=target_np,
        human_input=np.array(human_input, dtype=np.float32)
    )

    # Calculate directions with numpy arrays
    h_vec = np.array(human_input, dtype=np.float32)
    h_mag = np.linalg.norm(h_vec)
    
    if h_mag > 0:
        noise = np.random.normal(0, NOISE_MAGNITUDE, 2).astype(np.float32)
        h_dir = (h_vec + noise) / (h_mag + np.linalg.norm(noise))
    else:
        h_dir = np.zeros(2, dtype=np.float32)

    w_dir = target_np - dot_np
    w_mag = np.linalg.norm(w_dir)
    w_dir = w_dir / w_mag if w_mag > 0 else np.zeros(2, dtype=np.float32)

    # Combine directions
    combined_dir = gamma * w_dir + (1 - gamma) * h_dir
    combined_dir /= np.linalg.norm(combined_dir) if np.linalg.norm(combined_dir) > 0 else 1.0

    # Update position
    new_pos = dot_np + combined_dir * MAX_SPEED
    new_pos = np.clip(new_pos, [0, 0], FULL_VIEW_SIZE)
    
    if not check_collision(dot_np, new_pos):
        dot_pos[:] = new_pos.tolist()  # Maintain list type for pygame

    # Goal check
    if np.linalg.norm(new_pos - target_np) < GOAL_DETECTION_RADIUS:
        reached_goal = True
        pygame.time.set_timer(pygame.USEREVENT, 1000)

    return h_dir, w_dir, combined_dir

###############################################################################
# Visualization & UI (No Changes Needed)
###############################################################################
# [Keep all rendering and window management code unchanged from original]
# [Maintain existing reset(), draw_arrow(), render_full_view(), etc]

###############################################################################
# Main Loop with Fixed Input Handling
###############################################################################
def main():
    window1 = Window("Full View", size=FULL_VIEW_SIZE)
    renderer1 = Renderer(window1, vsync=True)
    window2 = Window("Red Arrow View", size=RED_ONLY_SIZE)
    renderer2 = Renderer(window2, vsync=True)

    # Surfaces and initialization
    surface_full = pygame.Surface(FULL_VIEW_SIZE, pygame.SRCALPHA)
    surface_red = pygame.Surface(RED_ONLY_SIZE, pygame.SRCALPHA)
    font = pygame.font.Font(None, 24)
    
    #generate_obstacles()
    generate_targets()
    
    clock = pygame.time.Clock()
    running = True
    
    while running:
        # Event handling (remove manual gamma controls)
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                running = False
            elif event.type == pygame.KEYDOWN:
                if event.key == pygame.K_r:
                    reset()
            elif event.type == pygame.USEREVENT and reached_goal:
                reset()

        # Movement handling
        human_input = np.zeros(2)
        keys = pygame.key.get_pressed()
        if keys[pygame.K_LEFT]: human_input[0] -= 1
        if keys[pygame.K_RIGHT]: human_input[0] += 1
        if keys[pygame.K_UP]: human_input[1] -= 1
        if keys[pygame.K_DOWN]: human_input[1] += 1
        
        # Joystick handling (remove gamma triggers)
        if pygame.joystick.get_count() > 0:
            joystick = pygame.joystick.Joystick(0)
            axis_0 = joystick.get_axis(0)
            axis_1 = joystick.get_axis(1)
            if abs(axis_0) > 0.1 or abs(axis_1) > 0.1:
                human_input = np.array([axis_0, axis_1])

        # Update environment
        if not reached_goal:
            h_dir, w_dir, x_dir = move_dot(human_input * MAX_SPEED)
            recent_positions.append((*dot_pos, time.time()))
        
        # Rendering
        surface_full.fill(WHITE)
        surface_red.fill(WHITE)
        render_full_view(surface_full, h_dir, w_dir, x_dir)
        render_red_only(surface_red, x_dir)
        
        # Update windows
        tex1 = Texture.from_surface(renderer1, surface_full.convert_alpha())
        tex2 = Texture.from_surface(renderer2, surface_red.convert_alpha())
        
        renderer1.clear()
        tex1.draw()
        renderer1.present()
        
        renderer2.clear()
        tex2.draw()
        renderer2.present()
        
        clock.tick(60)

    pygame.quit()

if __name__ == "__main__":
    main()

NameError: name 'check_collision' is not defined