# In [None]:
import os
import datetime
import pygame
import random
import math
import time
import numpy as np
import json  # for saving log data
from stable_baselines3 import PPO
from pygame._sdl2.video import Window, Renderer, Texture

# -----------------------------------------------------------
# Force Sensor Integration (extracted from your force sensor code)
# -----------------------------------------------------------


import tkinter as tk
from tkinter import ttk
import serial as pyserial
import threading
from collections import deque

###############################################################################
# Multi-Seed Setup
###############################################################################
SCENARIO_SEEDS = [0, 2, 58]  # We will run the experiment for each of these seeds in turn.

###############################################################################
# Global flags & constants
###############################################################################
# --- Force sensor -------------------------------------------------------------
USE_FORCE_SENSOR = False  # toggle with F key
# Latest processed [fx, fy] reading; rebound by serial_reader_fs() for every sample.
force_sensor_input = [0.0, 0.0]
# Set to True further down only if the serial port opens successfully.
FORCE_SENSOR_AVAILABLE = False

# Per-axis gains applied last in serial_reader_fs(), after calibration.
FORCE_SENSOR_SCALE_X = 5.0
FORCE_SENSOR_SCALE_Y = 5.0

# Calibration multiplier applied to the smoothed X/Y forces.
XY_FORCE_CAL = 1.33
# Exponential-smoothing factor. Despite the "Z" in its name, serial_reader_fs()
# applies it to the X and Y forces.
Z_SMOOTH_ALPHA = 0.2

# --- Experiment / mode state --------------------------------------------------
skip_remaining_seeds = False
# Active assistance mode. Code in this file compares it against 0.5, 1.0 and "manual".
current_gamma_mode = None
# Zero-based mode index; get_mode_number() reports it as index + 1.
gamma_mode_index = 0

# --- Shared state of the serial reader thread ---------------------------------
# Previous and most recent raw samples (dicts with 'fx', 'fy', 'timestamp'),
# updated while holding sample_lock_fs.
sample_prev_fs = None
sample_curr_fs = None
sample_lock_fs = threading.Lock()
# Last smoothed value per axis; None until the first sample has been processed.
last_fx_smooth = None
last_fy_smooth = None

# Constants for high gamma threshold
HIGH_GAMMA_THRESHOLD = 0.65

# Track which goal should be completed next
expected_goal_idx = 0

# Try to open the force sensor's serial port at import time. Any failure is
# reported and leaves `ser` as None, so the rest of the program falls back to
# keyboard/joystick input.
# NOTE(review): the port name 'COM5' is hard-coded (Windows-style name) —
# confirm it matches the machine the study runs on.
try:
    # 115200 baud; the 10 ms read timeout keeps readline() from blocking for long.
    ser = pyserial.Serial('COM5', 115200, timeout=0.01)
    FORCE_SENSOR_AVAILABLE = True
    print("Force sensor connected successfully.")
except Exception as e:
    print("Error: Could not open serial port:", e)
    ser = None
    FORCE_SENSOR_AVAILABLE = False
    print("Force sensor not available. Using keyboard/joystick controls only.")

def serial_reader_fs():
    """Read force-sensor lines from the serial port forever and publish the result.

    Intended as a daemon-thread target; it never returns. Each non-empty line
    must consist of exactly two comma-separated floats (``fx,fy``); lines with
    a different field count or non-numeric fields are skipped silently.

    Processing per accepted line:
      1. The X reading is negated and the sample is timestamped.
      2. The (previous, current) sample pair is updated under ``sample_lock_fs``.
      3. The pair is linearly interpolated, then exponentially smoothed with
         ``Z_SMOOTH_ALPHA``.
      4. The result is multiplied by ``XY_FORCE_CAL`` and by the per-axis
         ``FORCE_SENSOR_SCALE_X`` / ``FORCE_SENSOR_SCALE_Y`` gains and stored in
         the module-level ``force_sensor_input`` as ``[fx, fy]``.

    Any other exception raised while reading or decoding is printed and the
    loop retries after 10 ms.
    """
    global sample_prev_fs, sample_curr_fs, last_fx_smooth, last_fy_smooth, force_sensor_input
    while True:
        if ser is None:
            # No port available: idle instead of spinning.
            time.sleep(0.1)
            continue
        try:
            line = ser.readline().decode('utf-8').strip()
            if not line:
                # Nothing usable was read (e.g. the read timed out).
                continue
            tokens = line.split(',')
            if len(tokens) != 2:
                continue
            try:
                raw_fx, raw_fy = map(float, tokens)
            except ValueError:
                # Non-numeric field: drop the line.
                continue
            sample = {
                'fx': -raw_fx,
                'fy': raw_fy,
                'timestamp': time.time()
            }

            # Single critical section: shift the two-sample window and take a
            # consistent snapshot of it. On the very first sample both slots
            # point at the same sample.
            with sample_lock_fs:
                sample_prev_fs = sample if sample_curr_fs is None else sample_curr_fs
                sample_curr_fs = sample
                sp = sample_prev_fs
                sc = sample_curr_fs

            # Interpolate horizontal forces.
            # NOTE(review): `now` is taken after the newest sample was
            # timestamped, so in practice the clamped fraction is 1.0 and the
            # newest sample is used as-is. Kept to preserve existing behaviour.
            now = time.time()
            t0 = sp['timestamp']
            t1 = sc['timestamp']
            fraction = 1.0 if t1 == t0 else (now - t0) / (t1 - t0)
            fraction = max(0.0, min(1.0, fraction))
            fx_interp = sp['fx'] * (1 - fraction) + sc['fx'] * fraction
            fy_interp = sp['fy'] * (1 - fraction) + sc['fy'] * fraction

            # Exponential smoothing; the first sample seeds the filter.
            if last_fx_smooth is None:
                smoothed_fx = fx_interp
            else:
                smoothed_fx = last_fx_smooth + Z_SMOOTH_ALPHA * (fx_interp - last_fx_smooth)
            if last_fy_smooth is None:
                smoothed_fy = fy_interp
            else:
                smoothed_fy = last_fy_smooth + Z_SMOOTH_ALPHA * (fy_interp - last_fy_smooth)
            last_fx_smooth = smoothed_fx
            last_fy_smooth = smoothed_fy

            # Scale with calibration, then with the per-axis gains.
            final_fx = smoothed_fx * XY_FORCE_CAL
            final_fy = smoothed_fy * XY_FORCE_CAL
            final_fx *= FORCE_SENSOR_SCALE_X
            final_fy *= FORCE_SENSOR_SCALE_Y

            # Rebind (rather than mutate) so readers always see a complete pair.
            force_sensor_input = [final_fx, final_fy]
        except Exception as e:
            # Best-effort reader: report and keep going.
            print("Serial reader error:", e)
            time.sleep(0.01)

# Start the background reader only when the serial port was opened. The thread
# is a daemon, so it does not keep the process alive at exit.
if FORCE_SENSOR_AVAILABLE:
    threading.Thread(target=serial_reader_fs, daemon=True).start()

# Initialise pygame and its joystick subsystem before any joystick is queried.
pygame.init()
pygame.joystick.init()

# Read by move_dot(): when True (and there is human input) the context-adaptive
# gamma branch is used.
USE_AI_CONTROL = False

# Noise Models
def eeg_noise():
    """Return one zero-mean Gaussian noise sample with standard deviation 190."""
    sigma = 190
    return np.random.normal(loc=0, scale=sigma)

def burst_noise():
    """Return a burst sample: zero about 80% of the time, otherwise Gaussian.

    When the burst fires (probability ~0.2) the value is a randomly signed
    draw from a zero-mean Gaussian with standard deviation 190.
    """
    # The three draws are made in this order: sign, magnitude, fire decision.
    sign = np.random.choice([-1, 1])
    magnitude = np.random.normal(0, 190)
    fires = np.random.rand() < 0.2
    # Multiplying by the boolean zeroes the sample when the burst does not fire.
    return sign * magnitude * fires

from scipy.signal import lfilter
def pink_noise():
    """Return one noise sample passed through a first-order IIR filter.

    A single Gaussian sample (mean 0, standard deviation 290) is filtered with
    numerator ``[1]`` and denominator ``[1, -0.95]``.

    NOTE(review): no filter state is carried between calls, so with a single
    input sample the filter output equals the input. Successive calls are
    therefore uncorrelated draws, not a 1/f process. Carrying state would
    change the noise statistics, so confirm intent before changing this.
    """
    numerator = [1]
    denominator = [1, -0.95]
    white = np.random.normal(0, 290, size=1)
    filtered = lfilter(numerator, denominator, white)
    return filtered[0]

# Noise generator used by the experiment; any zero-argument callable returning
# one sample can be assigned here.
NOISE_FUNCTION = pink_noise  # or burst_noise, etc.

###############################################################################
# Config / Constants
###############################################################################
# Size (width, height) in pixels of the playable game area.
GAME_AREA_SIZE = (1200, 800)

# Window sizes. The full view is 400 px wider than the game area, which leaves
# 200 px on each side for UI.
FULL_VIEW_SIZE = (1600, 800)
RED_ONLY_SIZE  = (1600, 800)

# Top-left offset of the game area inside the full view. The x offset (200)
# centres the game area horizontally; there is no vertical offset.
GAME_AREA_X = (FULL_VIEW_SIZE[0] - GAME_AREA_SIZE[0]) // 2  # Center the game area
GAME_AREA_Y = 0

# Noise level and its adjustment range/step.
# NOTE(review): the initial NOISE_MAGNITUDE (2.5) is above MAX_NOISE (2.0) —
# confirm this is intended.
NOISE_MAGNITUDE = 2.5
MIN_NOISE = 0.0
MAX_NOISE = 2.0
NOISE_STEP = 0.1

# Sizes elsewhere in the file are written for a 600x600 layout and multiplied
# by SCALING_FACTOR, the mean of the per-axis ratios to the current game area.
OLD_WINDOW_SIZE   = (600, 600)
SCALING_FACTOR_X  = GAME_AREA_SIZE[0] / OLD_WINDOW_SIZE[0]
SCALING_FACTOR_Y  = GAME_AREA_SIZE[1] / OLD_WINDOW_SIZE[1]
SCALING_FACTOR    = (SCALING_FACTOR_X + SCALING_FACTOR_Y) / 2

# Colour palette, as (R, G, B) tuples with 0-255 channels.
WHITE  = (255, 255, 255)
BLACK  = (0, 0, 0)
RED    = (255, 60, 60)
GREEN  = (60, 180, 60)
BLUE   = (60, 120, 255)
YELLOW = (240, 230, 60)
ORANGE = (255, 140, 0)
GRAY   = (128, 128, 128)
LIGHT_GRAY = (200, 200, 200)
DARK_GRAY = (80, 80, 80)
TEXT_COLOR = (30, 30, 40)
HIGHLIGHT_COLOR = (70, 70, 230)
EXPECTED_GOAL_COLOR = (0, 200, 100)  # Brighter green for the expected goal
WRONG_GOAL_COLOR = (255, 0, 0)  # Bright red for wrong goals
HISTORY_COLOR = (100, 100, 115)  # Darker gray for movement history

# Warm, creamy background color (more noticeable reading mode)
BACKGROUND_COLOR = (250, 240, 210)  # Stronger warm cream tone

# Text and arrow sizing, scaled from the 600x600 reference layout and
# truncated to whole pixels.
FONT_COLOR = TEXT_COLOR
FONT_SIZE = int(18 * SCALING_FACTOR)
TITLE_FONT_SIZE = int(22 * SCALING_FACTOR)
ARROW_LENGTH = int(60 * SCALING_FACTOR)

# Obstacle geometry and movement limits (pixels, scaled from the reference layout).
OBSTACLE_RADIUS  = int(10 * SCALING_FACTOR)
# Extra margin added to OBSTACLE_RADIUS by check_collision().
COLLISION_BUFFER = int(5 * SCALING_FACTOR)
# When False, check_collision() and inside_obstacle() always return False.
ENABLE_OBSTACLES = True
# Upper bound on the per-step movement distance used by move_dot().
MAX_SPEED        = 3 * SCALING_FACTOR

# Dot and target sizes. The detection radius is the sum of both radii.
DOT_RADIUS            = int(15 * SCALING_FACTOR)
TARGET_RADIUS         = int(10 * SCALING_FACTOR)
GOAL_DETECTION_RADIUS = DOT_RADIUS + TARGET_RADIUS

GHOST_TRAIL_DURATION  = 3.0
# Position history; get_recent_direction() reads entries as (x, y, timestamp).
recent_positions      = []
last_reset_time       = time.time()

# Length in seconds of the history window used by get_recent_direction().
RECENT_DIR_LOOKBACK   = 1.0
GOAL_SWITCH_THRESHOLD = 0.05

# Center point of the game area
GAME_CENTER = (GAME_AREA_X + GAME_AREA_SIZE[0] // 2, GAME_AREA_Y + GAME_AREA_SIZE[1] // 2)
START_POS = [GAME_CENTER[0], GAME_CENTER[1]]
# Current dot position; starts as a copy so START_POS is not mutated through it.
dot_pos   = START_POS.copy()

# Current assistance level and per-trial state.
gamma         = 0.2
reached_goal  = False
# Goal positions; indexed as target[0], target[1] throughout the file.
targets       = []
# Index into `targets` of the goal currently being steered toward.
current_target_idx = 0
# Obstacle centre positions.
obstacles     = []

goal_counters = {}
failure_counter = 0
wrong_goal_message_time = 0  # Time when wrong goal message was shown

# When True, predict_human_target() scores goals using the raw input direction
# only and ignores the recent movement direction.
USE_RAW_ONLY_FOR_GOAL_DETECTION = True

# Variables for autonomous behavior
has_started_moving = False
movement_start_time = 0
target_lock_time = 3.0  # seconds before locking target
last_redirect_time = 0
redirect_interval = 5.0  # seconds between potential redirects
redirect_chance = 0.3   # chance to redirect to another goal
user_intended_target = None   # Store what the user initially wanted
initial_target_selected = False

# Use the first detected joystick, if any; `joystick` stays None otherwise.
joystick = None
if pygame.joystick.get_count() > 0:
    joystick = pygame.joystick.Joystick(0)
    joystick.init()
    print("Joystick initialized:", joystick.get_name())
else:
    print("No joystick detected.")

# Joystick axis indices for the L2 / R2 controls.
# NOTE(review): axis numbering depends on the controller and driver — confirm
# for the device used in the study.
AXIS_L2 = 4
AXIS_R2 = 5

# Two SDL2 windows, each with its own vsync'd renderer: the full view and the
# "red arrow only" view.
window1 = Window("2D Environment: Full View", size=FULL_VIEW_SIZE)
renderer1 = Renderer(window1, vsync=True)
window2 = Window("2D Environment: Red Arrow Only", size=RED_ONLY_SIZE)
renderer2 = Renderer(window2, vsync=True)

def create_compatible_surface(size):
    """Create an off-screen surface of `size` with per-pixel alpha (SRCALPHA)."""
    alpha_flags = pygame.SRCALPHA
    return pygame.Surface(size, flags=alpha_flags)

# Off-screen drawing targets, one per window.
surface_full = create_compatible_surface(FULL_VIEW_SIZE)
surface_red_only = create_compatible_surface(RED_ONLY_SIZE)

# Initialize fonts
pygame.font.init()
# Passing None selects pygame's built-in default font.
font = pygame.font.Font(None, FONT_SIZE)
title_font = pygame.font.Font(None, TITLE_FONT_SIZE)

def surface_to_texture(renderer, surf):
    """Build a texture for `renderer` from `surf`.

    Surfaces that are not 32 bits per pixel are first converted with
    ``convert_alpha()``; 32-bit surfaces are used unchanged.
    """
    is_32_bit = surf.get_bitsize() == 32
    source = surf if is_32_bit else surf.convert_alpha()
    return Texture.from_surface(renderer, source)

def distance(pos1, pos2):
    """Return the Euclidean distance between two points.

    Only the first two coordinates of each point are used.
    """
    delta_x = pos1[0] - pos2[0]
    delta_y = pos1[1] - pos2[1]
    return math.hypot(delta_x, delta_y)

# Correct line_circle_intersection
def line_circle_intersection(start, end, circle_center, radius):
    """Return True if the segment start->end comes within `radius` of the centre.

    The centre is projected onto the segment, the projection parameter is
    clamped to [0, 1], and the distance from that closest point to the centre
    is compared with `radius`. A zero-length segment is treated as a point.
    """
    seg_x = end[0] - start[0]
    seg_y = end[1] - start[1]
    to_center_x = circle_center[0] - start[0]
    to_center_y = circle_center[1] - start[1]
    seg_len_sq = seg_x*seg_x + seg_y*seg_y

    # Degenerate segment: only the start point can touch the circle.
    if seg_len_sq == 0:
        return distance(start, circle_center) <= radius

    # Parameter of the closest point on the segment, clamped to its ends.
    t = (to_center_x*seg_x + to_center_y*seg_y) / seg_len_sq
    t = max(0, min(1, t))
    closest = (start[0] + t * seg_x, start[1] + t * seg_y)
    return distance(closest, circle_center) <= radius

def check_collision(pos, new_pos):
    """Return True if moving from `pos` to `new_pos` would hit any obstacle.

    An obstacle is hit when the movement segment passes within
    OBSTACLE_RADIUS + COLLISION_BUFFER of its centre. Always False when
    obstacles are disabled.
    """
    if not ENABLE_OBSTACLES:
        return False
    return any(
        line_circle_intersection(pos, new_pos, center, OBSTACLE_RADIUS + COLLISION_BUFFER)
        for center in obstacles
    )

def inside_obstacle(pos):
    """Return True if a dot centred at `pos` overlaps any obstacle.

    Overlap means the centre distance is at most OBSTACLE_RADIUS + DOT_RADIUS.
    Always False when obstacles are disabled.
    """
    if not ENABLE_OBSTACLES:
        return False
    return any(
        distance(pos, center) <= (OBSTACLE_RADIUS + DOT_RADIUS)
        for center in obstacles
    )

def get_recent_direction():
    """Return the unit direction of travel over the last RECENT_DIR_LOOKBACK seconds.

    Uses the oldest and newest entries of `recent_positions` inside the
    look-back window. Returns [0, 0] when fewer than two samples are available,
    when they are less than 1 ms apart, or when there was no net movement.
    """
    if len(recent_positions) < 2:
        return [0, 0]

    now = time.time()
    window = []
    # Walk the history newest-first and stop at the first sample that is too old.
    for (px, py, stamp) in reversed(recent_positions):
        if not ((now - stamp) <= RECENT_DIR_LOOKBACK):
            break
        window.append((px, py, stamp))

    if len(window) < 2:
        return [0, 0]

    ordered = sorted(window, key=lambda entry: entry[2])
    x_old, y_old, t_old = ordered[0]
    x_new, y_new, t_new = ordered[-1]

    elapsed = t_new - t_old
    if elapsed < 0.001:
        return [0, 0]

    vel_x = (x_new - x_old) / elapsed
    vel_y = (y_new - y_old) / elapsed
    speed = math.hypot(vel_x, vel_y)
    if speed > 0:
        return [vel_x/speed, vel_y/speed]
    return [0, 0]

def compute_perfect_direction(dot_pos, goal_pos, obstacles,
                              repulsion_radius=None, repulsion_gain=30000.0):
    """Return the unit steering direction toward the goal, pushed away from obstacles.

    The result is the normalised sum of the unit vector toward `goal_pos` and
    an inverse-square repulsion from every obstacle closer than
    `repulsion_radius`.

    Args:
        dot_pos: Current (x, y) position.
        goal_pos: Goal (x, y) position.
        obstacles: Iterable of obstacle centre positions (x, y).
        repulsion_radius: Distance below which an obstacle repels. Defaults to
            ``27 * SCALING_FACTOR``, the value that used to be hard-coded.
        repulsion_gain: Numerator of the inverse-square repulsion strength.

    Returns:
        A two-element list ``[dx, dy]`` of unit length, or ``[0, 0]`` when the
        dot is already on the goal or the goal pull and repulsion cancel out.
    """
    if repulsion_radius is None:
        # Resolved lazily so an explicit radius needs no module-level constant.
        repulsion_radius = 27 * SCALING_FACTOR

    gx = goal_pos[0] - dot_pos[0]
    gy = goal_pos[1] - dot_pos[1]
    goal_dist = math.hypot(gx, gy)
    if goal_dist < 1e-6:
        return [0, 0]
    goal_dir = [gx / goal_dist, gy / goal_dist]

    repulse_x = 0.0
    repulse_y = 0.0
    for obs in obstacles:
        dx = dot_pos[0] - obs[0]
        dy = dot_pos[1] - obs[1]
        dist_o = math.hypot(dx, dy)
        if dist_o < 1e-6:
            # Exactly on the obstacle centre: the push direction is undefined.
            continue
        if dist_o < repulsion_radius:
            strength = repulsion_gain / (dist_o**2)
            repulse_x += (dx / dist_o) * strength
            repulse_y += (dy / dist_o) * strength

    w_px = goal_dir[0] + repulse_x
    w_py = goal_dir[1] + repulse_y
    norm = math.hypot(w_px, w_py)
    if norm < 1e-6:
        return [0, 0]
    return [w_px / norm, w_py / norm]

def draw_gamma_gauge(surface, gamma_value, x, y, width=150, height=80):
    """
    Draw a horizontal bar gauge for gamma value that fills from left to right.

    The bar geometry and colour use gamma clamped to [0, 1], so an out-of-range
    value can neither draw outside the bar nor produce an invalid (negative)
    colour channel. The numeric label always shows the unclamped value.

    Args:
        surface: Surface to draw on
        gamma_value: Assistance level, nominally between 0 and 1
        x, y: Position of the gauge (top-left corner)
        width, height: Dimensions of the gauge
    """
    level = max(0.0, min(1.0, gamma_value))

    # Draw gauge background
    pygame.draw.rect(surface, LIGHT_GRAY, (x, y, width, height), 0)
    pygame.draw.rect(surface, DARK_GRAY, (x, y, width, height), 2)

    # Draw title
    gauge_title = font.render("Assistance", True, TEXT_COLOR)
    title_rect = gauge_title.get_rect(center=(x + width//2, y + 15))
    surface.blit(gauge_title, title_rect)

    # Bar metrics
    bar_height = 20
    bar_y = y + 35
    bar_width = width - 20  # Padding on sides
    bar_x = x + 10

    # Draw the background bar
    pygame.draw.rect(surface, GRAY, (bar_x, bar_y, bar_width, bar_height))

    # Draw the value bar (filled from left to right)
    if level > 0:
        fill_width = int(bar_width * level)

        # Color gradient from green to red
        if level < 0.5:
            # Green to yellow gradient
            color = (int(255 * (level * 2)), 180, 60)
        else:
            # Yellow to red gradient
            color = (255, int(180 - (level - 0.5) * 2 * 120), 60)

        pygame.draw.rect(surface, color, (bar_x, bar_y, fill_width, bar_height))

    # Draw border around bar
    pygame.draw.rect(surface, DARK_GRAY, (bar_x, bar_y, bar_width, bar_height), 2)

    # Draw marker at current value
    if level > 0:
        marker_x = bar_x + int(bar_width * level)
        pygame.draw.line(surface, BLACK, (marker_x, bar_y - 3), (marker_x, bar_y + bar_height + 3), 2)

    # Draw value text (raw, unclamped value)
    value_text = font.render(f"{gamma_value:.2f}", True, TEXT_COLOR)
    value_rect = value_text.get_rect(center=(x + width//2, bar_y + bar_height + 15))
    surface.blit(value_text, value_rect)

def draw_controller_guide(surface, x, y, width=200, height=160):
    """Draw the controller help panel: a titled box with one text row per control.

    The L2/R2 row is shown only while the gamma mode is "manual".
    """
    # Panel fill and outline
    panel = (x, y, width, height)
    pygame.draw.rect(surface, LIGHT_GRAY, panel, 0)
    pygame.draw.rect(surface, DARK_GRAY, panel, 2)

    # Centred heading
    heading = font.render("Controller Guide", True, TEXT_COLOR)
    surface.blit(heading, heading.get_rect(center=(x + width//2, y + 15)))

    # One label per row, top to bottom
    labels = ["Left Stick - Move Dot", "Square - Reset Position"]
    if current_gamma_mode == "manual":
        labels.append("L2/R2 - Change Assistance")

    left = x + 20
    first_row_y = y + 40
    row_spacing = 35
    for row, label in enumerate(labels):
        rendered = font.render(label, True, TEXT_COLOR)
        surface.blit(rendered, (left, first_row_y + row * row_spacing))

class GammaPredictor:
    """Predict the assistance level (gamma) with a saved PPO policy.

    If the model cannot be loaded, the predictor stays usable and
    `predict_gamma` returns a fixed fallback of 0.2.
    """

    def __init__(self, model_path="gamma_ppo_model.zip"):
        """Load the PPO model from `model_path`; on failure report it and continue without one."""
        try:
            self.model = PPO.load(model_path)
        except Exception as exc:
            # Best-effort load: a missing or incompatible model must not stop
            # the experiment. Ctrl-C and SystemExit are no longer swallowed.
            print(f"GammaPredictor: could not load model from {model_path!r}: {exc}")
            self.model = None
        # Diagonal of the game area, used to normalise distances.
        self.max_dist = np.sqrt(GAME_AREA_SIZE[0]**2 + GAME_AREA_SIZE[1]**2)

    def prepare_observation(self, dot_pos, target_pos, human_input):
        """Build the 10-element float32 observation vector.

        Layout: dot position (2), unit human input direction (2), target
        position (2), unit direction to the target (2), distance to the target
        divided by `max_dist` (1), and a constant obstacle ratio of 1.0 (1).

        Args:
            dot_pos: Current (x, y) position.
            target_pos: Target (x, y) position.
            human_input: Raw (dx, dy) input; may be a list, tuple or array.
        """
        dot_pos = np.array(dot_pos, dtype=np.float32)
        target_pos = np.array(target_pos, dtype=np.float32)
        # Convert first: dividing a plain list or tuple by a float raises TypeError.
        human_input = np.asarray(human_input, dtype=float)

        to_target = target_pos - dot_pos
        dist = np.linalg.norm(to_target)
        perfect_dir = to_target / dist if dist > 0 else np.array([0, 0], dtype=np.float32)

        h_mag = np.linalg.norm(human_input)
        human_dir = human_input / h_mag if h_mag > 0 else np.array([0, 0], dtype=np.float32)

        normalized_dist = dist / self.max_dist
        obs_dist_ratio = 1.0  # constant placeholder; obstacles are not observed here
        obs = np.concatenate([
            dot_pos,
            human_dir,
            target_pos,
            perfect_dir,
            [normalized_dist],
            [obs_dist_ratio]
        ]).astype(np.float32)
        return obs

    def predict_gamma(self, dot_pos, target_pos, human_input):
        """Return the model's gamma for the given state, or 0.2 when no model is loaded."""
        if self.model is None:
            return 0.2
        obs = self.prepare_observation(dot_pos, target_pos, human_input)
        obs_batched = obs[np.newaxis, :]
        action, _ = self.model.predict(obs_batched, deterministic=True)
        # Take the first scalar explicitly instead of relying on float() of a
        # one-element array, which newer NumPy versions deprecate.
        return float(np.asarray(action).reshape(-1)[0])

# Module-wide predictor instance, constructed at import time with the default model path.
gamma_predictor = GammaPredictor()

def predict_human_target(human_input):
    """Guess which target the user is steering toward.

    The current target is kept when the dot is within twice
    GOAL_DETECTION_RADIUS of it, or when there is effectively no input.
    Otherwise every target is scored by how well its direction from the dot
    aligns with the input direction (weight 0.8) and with the recent movement
    direction (weight 0.2; zero while USE_RAW_ONLY_FOR_GOAL_DETECTION is set).
    The best-scoring index is returned.

    Args:
        human_input: Raw (dx, dy) input.

    Returns:
        Index into `targets`. If no target could be scored, the current
        `current_target_idx` is returned unchanged, so callers should still
        range-check the result.
    """
    # A stale index (e.g. left over from a scenario with more targets) or an
    # empty target list must not raise; it just disables the proximity lock.
    try:
        current_goal = targets[current_target_idx]
    except IndexError:
        current_goal = None

    close_threshold = GOAL_DETECTION_RADIUS * 2
    if current_goal is not None and distance(dot_pos, current_goal) < close_threshold:
        return current_target_idx

    # This also covers the exact-zero input case.
    h_mag = math.hypot(human_input[0], human_input[1])
    if h_mag <= 1e-6:
        return current_target_idx
    h_dir = [component / h_mag for component in human_input]

    recent_dir = [0, 0]  # default
    if not USE_RAW_ONLY_FOR_GOAL_DETECTION:
        recent_dir = get_recent_direction()

    best_score = float('-inf')
    best_idx = current_target_idx

    for i, targ in enumerate(targets):
        to_tx = targ[0] - dot_pos[0]
        to_ty = targ[1] - dot_pos[1]
        to_mag = math.hypot(to_tx, to_ty)
        if to_mag == 0:
            # The dot is exactly on this target: it has no direction to score.
            continue
        to_dir = [to_tx / to_mag, to_ty / to_mag]
        align_human = h_dir[0] * to_dir[0] + h_dir[1] * to_dir[1]
        align_recent = recent_dir[0] * to_dir[0] + recent_dir[1] * to_dir[1]
        score = (align_human * 0.8) + (align_recent * 0.2)
        if score > best_score:
            best_score = score
            best_idx = i

    return best_idx

# Find a goal in the general direction but not the intended one
def find_alternative_goal(intended_idx, human_dir):
    """Pick a goal other than the intended one, preferring a similar direction.

    Among the other targets whose direction from the dot is within roughly
    60 degrees of the intended target's direction (cosine > 0.5), the best
    aligned one is returned (the earliest on ties). If none qualifies, a random
    other target is chosen. Returns 0 when there is at most one target or no
    intended index. `human_dir` is accepted but not used.
    """
    if len(targets) <= 1 or intended_idx is None:
        return 0

    def direction_from_dot(point):
        # Unit vector from the dot to `point`; left unnormalised at zero length.
        vx = point[0] - dot_pos[0]
        vy = point[1] - dot_pos[1]
        length = math.hypot(vx, vy)
        if length > 0:
            return [vx/length, vy/length]
        return [vx, vy]

    wanted_dir = direction_from_dot(targets[intended_idx])

    best_idx = None
    best_cosine = None
    for idx, goal in enumerate(targets):
        if idx == intended_idx:
            continue
        goal_dir = direction_from_dot(goal)
        cosine = (wanted_dir[0] * goal_dir[0] +
                  wanted_dir[1] * goal_dir[1])
        # Keep the first candidate with the highest cosine above 0.5.
        if cosine > 0.5 and (best_cosine is None or cosine > best_cosine):
            best_idx = idx
            best_cosine = cosine

    if best_idx is not None:
        return best_idx

    # No similar-direction goal: fall back to any other target at random.
    others = [idx for idx in range(len(targets)) if idx != intended_idx]
    return random.choice(others) if others else 0

# In-memory experiment log; save_data_log() dumps it to disk as JSON.
data_log = []
trial_start_time = time.time()
current_trajectory = []
# Outcome of the current trial; None until one has been recorded.
trial_outcome = None

# Output directory for log files, created at import time if it does not exist.
save_folder = "user_study_data"
os.makedirs(save_folder, exist_ok=True)

def get_save_filename(seed):
    """Return a timestamped log path inside `save_folder` for the given seed.

    The timestamp is the current local time formatted as YYYYMMDD_HHMMSS.
    """
    stamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
    file_name = f"data_log_seed{seed}_{stamp}.json"
    return os.path.join(save_folder, file_name)

# Path of the JSON log for the current run; save_data_log() does nothing while this is falsy.
save_filename = None

def save_data_log(seed):
    """Write the whole `data_log` to `save_filename` as indented JSON.

    Does nothing while `save_filename` is unset. The log is serialised to a
    string before the file is opened, so a serialisation error (for example a
    non-JSON-serialisable value in the log) is raised without truncating a
    previously saved file.

    Args:
        seed: Scenario seed, used only in the confirmation message.
    """
    if not save_filename:
        return
    payload = json.dumps(data_log, indent=4)
    with open(save_filename, "w") as f:
        f.write(payload)
    print(f"Data log updated and saved to {save_filename} for seed={seed}")

def get_mode_number(gamma_mode):
    """Return the 1-based mode number used for reporting.

    When `gamma_mode_index` is set, the result is simply that index plus one.
    Otherwise the number is derived from `gamma_mode`: a float maps to 1
    (< 0.1), 2 (< 0.6) or 3; "manual" maps to 4; "ai" maps to 5; anything else
    maps to 0.
    """
    if gamma_mode_index is not None:
        return gamma_mode_index + 1

    if isinstance(gamma_mode, float):
        if gamma_mode < 0.1:
            return 1
        return 2 if gamma_mode < 0.6 else 3

    if gamma_mode == "manual":
        return 4
    return 5 if gamma_mode == "ai" else 0

def move_dot(human_input):
    global dot_pos, gamma, reached_goal, current_target_idx, USE_AI_CONTROL, trial_outcome
    global has_started_moving, movement_start_time, last_redirect_time
    global user_intended_target, initial_target_selected, expected_goal_idx
    global wrong_goal_message_time, failure_counter, current_gamma_mode

    # Variables to track time spent going for the wrong goal
    if 'wrong_goal_time' not in globals():
        global wrong_goal_time, wrong_goal_correcting, wrong_goal_correction_end
        wrong_goal_time = 0
        wrong_goal_correcting = False
        wrong_goal_correction_end = 0

    if len(targets) == 0:
        return [0,0], [0,0], [0,0]

    h_dx, h_dy = human_input
    h_mag = math.hypot(h_dx, h_dy)
    h_dir = [h_dx / h_mag, h_dy / h_mag] if h_mag > 0 else [0, 0]

    # Only move if there's human input
    if h_mag <= 1e-6:
        return [0,0], [0,0], [0,0]

    # Track when movement begins and capture initial target selection
    if not has_started_moving and h_mag > 0:
        has_started_moving = True
        movement_start_time = time.time()
        
        # If user intended target hasn't been set yet, set it now based on their initial input
        if not initial_target_selected:
            proposed_idx = predict_human_target(human_input)
            current_target_idx = proposed_idx if proposed_idx < len(targets) else 0
            user_intended_target = current_target_idx
            initial_target_selected = True
            print(f"Initial user target set to: {user_intended_target}")

    # Make sure we have valid indices
    if current_target_idx >= len(targets):
        current_target_idx = 0
    if expected_goal_idx >= len(targets):
        expected_goal_idx = 0
    
    # Check if we're heading toward wrong goal and update timer
    now = time.time()
    if current_target_idx != expected_goal_idx and not wrong_goal_correcting:
        wrong_goal_time += 1/60  # Assuming 60 FPS
    else:
        wrong_goal_time = 0
    
    # If correction period is over, reset the flag
    if wrong_goal_correcting and now > wrong_goal_correction_end:
        wrong_goal_correcting = False
        print("IDA: Correction period ended, returning control to user")
    
    # Use the current target for movement unless we're in correction mode
    if wrong_goal_correcting and current_gamma_mode == 0.5:  # IDA mode
        target_pos = targets[expected_goal_idx]  # Use correct goal during correction
    else:
        target_pos = targets[current_target_idx]  # Use user's selected goal

    w_dir = compute_perfect_direction(dot_pos, target_pos, obstacles)
    input_mag = min(max(h_mag / MAX_SPEED, 0), 1)
    step_size = MAX_SPEED * input_mag

    # -------------- MODE 2: IDA (Binary Intervention) --------------
    if current_gamma_mode == 0.5:  # IDA mode
        # Default to no intervention
        ida_gamma = 0.0
        
        # 1. Check for imminent collision with obstacles
        collision_imminent = False
        collision_threshold = (OBSTACLE_RADIUS + DOT_RADIUS) * 1.5  # Buffer zone
        
        for obs in obstacles:
            # Calculate position after human input
            h_move_x = h_dir[0] * step_size
            h_move_y = h_dir[1] * step_size
            potential_pos = [dot_pos[0] + h_move_x, dot_pos[1] + h_move_y]
            
            # Check if this would lead to collision
            if distance(potential_pos, obs) < collision_threshold:
                collision_imminent = True
                break
        
        # 2. Check if human action is universally worse (pointing away from all goals)
        universally_worse = True
        
        for i, goal in enumerate(targets):
            # Direction to this goal
            to_goal = [goal[0] - dot_pos[0], goal[1] - dot_pos[1]]
            goal_dist = math.hypot(to_goal[0], to_goal[1])
            
            if goal_dist > 0:
                goal_dir = [to_goal[0]/goal_dist, to_goal[1]/goal_dist]
                # Calculate dot product (alignment with human direction)
                alignment = h_dir[0]*goal_dir[0] + h_dir[1]*goal_dir[1]
                
                # If pointing somewhat toward any goal, it's not universally worse
                if alignment > -0.3:  # Allow up to ~110 degrees off
                    universally_worse = False
                    break
        
        # 3. Check if user has been going for the wrong goal for too long
        # The threshold is 3 seconds of consistently moving toward wrong goal
        if wrong_goal_time > 3.0 and not wrong_goal_correcting:
            wrong_goal_correcting = True  # Start correction period
            wrong_goal_correction_end = now + 0.225  # Correct for 0.225 seconds (increased by 50%)
            print(f"IDA: User going to wrong goal for {wrong_goal_time:.1f}s, providing brief nudge")
        
        # Set binary intervention based on conditions
        if collision_imminent or universally_worse or wrong_goal_correcting:
            # For the brief correction nudge, use a lower gamma
            if wrong_goal_correcting:
                ida_gamma = 0.8  # Partial intervention for correction nudge
                # If correcting wrong goal, use the expected goal
                target_pos = targets[expected_goal_idx]
                w_dir = compute_perfect_direction(dot_pos, target_pos, obstacles)
                print("IDA mode: Brief correction nudge active")
            else:
                ida_gamma = 1.0  # Full intervention for collisions or universally worse actions
                print("IDA mode: Full intervention activated")
        else:
            ida_gamma = 0.0  # No intervention
        
        # Apply the binary gamma
        gamma = ida_gamma
        
    # -------------- MODE 3: Reddy's Approach (Continuous Blending) --------------
    elif current_gamma_mode == 1.0:  # Reddy's approach
        # Base gamma - lower than before to make it less forceful
        base_gamma = 0.25
        
        # 1. Goal proximity factor - gradual increase as agent approaches goal
        dist_to_target = distance(dot_pos, target_pos)
        goal_threshold = GOAL_DETECTION_RADIUS * 4
        
        if dist_to_target < goal_threshold:
            goal_factor = 1.0 - (dist_to_target / goal_threshold)
            # Reduced assistance compared to before
            base_gamma = max(base_gamma, 0.25 + 0.25 * goal_factor)
        
        # 2. Obstacle proximity factor
        min_obs_distance = min(distance(dot_pos, obs) for obs in obstacles) if obstacles else float('inf')
        obs_threshold = (OBSTACLE_RADIUS + DOT_RADIUS) * 2.5
        
        if min_obs_distance < obs_threshold:
            obs_factor = 1.0 - (min_obs_distance / obs_threshold)
            # Proportionally increase assistance near obstacles but less forceful
            base_gamma = max(base_gamma, 0.3 + 0.3 * obs_factor)
        
        # 3. Goal switching - be more responsive to human input changes
        # More aggressive detection of goal changes
        if h_mag > 0.3:  # Only when there's significant input
            proposed_idx = predict_human_target(human_input)
            
            # If the human seems to be aiming at a different goal than current,
            # be more willing to switch
            if proposed_idx != current_target_idx and proposed_idx < len(targets):
                # Calculate angle between current movement and direction to proposed goal
                proposed_goal = targets[proposed_idx]
                to_proposed = [proposed_goal[0] - dot_pos[0], proposed_goal[1] - dot_pos[1]]
                p_dist = math.hypot(to_proposed[0], to_proposed[1])
                
                if p_dist > 0:
                    p_dir = [to_proposed[0]/p_dist, to_proposed[1]/p_dist]
                    alignment = h_dir[0]*p_dir[0] + h_dir[1]*p_dir[1]
                    
                    # If strongly aligned with a new goal, switch to it
                    if alignment > 0.8:  # Higher threshold for clearer intent
                        current_target_idx = proposed_idx
                        print(f"Mode 3: Detected goal change to {current_target_idx+1}")
                        # When switching goals, reduce gamma temporarily to allow control
                        base_gamma = max(0.1, base_gamma - 0.2)
        
        # 4. Very rarely select wrong goal - about 1% chance per second
        if random.random() < 0.0002 and len(targets) > 1:  # 0.0002 per frame ≈ 1.2% per second at 60fps
            # Only if we're not too close to current goal
            if dist_to_target > goal_threshold:
                # Pick a random goal that's not the current one
                options = [i for i in range(len(targets)) if i != current_target_idx]
                if options:
                    new_target = random.choice(options)
                    current_target_idx = new_target
                    print(f"Mode 3: Randomly switched to goal {current_target_idx+1}")
        
        # 5. Action selection within tolerance of human input
        # Calculate alignment between human direction and optimal direction
        alignment = h_dir[0]*w_dir[0] + h_dir[1]*w_dir[1]
        
        # If human is already aligned with optimal direction, reduce assistance
        if alignment > 0.7:  # Fairly well aligned (cosine similarity > 0.7, ~45 degrees)
            base_gamma = max(0.1, base_gamma - 0.15)
        
        # Add small noise to make gamma feel more natural
        noise = random.uniform(-0.05, 0.05)
        reddy_gamma = base_gamma + noise
        
        # Ensure gamma stays in valid range and cap it lower than before
        gamma = max(0.1, min(0.60, reddy_gamma))
        
    # -------------- MODE 5: Context-adaptive PPO --------------
    elif USE_AI_CONTROL and h_mag > 0:
        # Modified AI control logic with context-aware gamma
        dist_to_target = distance(dot_pos, target_pos)
        min_obs_distance = min(distance(dot_pos, obs) for obs in obstacles) if obstacles else float('inf')
        
        # Increased threshold to detect goals from further away
        goal_threshold = GOAL_DETECTION_RADIUS * 6  # Increased from *3 to *6
        obs_threshold = (OBSTACLE_RADIUS + DOT_RADIUS) * 3
        base_gamma = 0.35
        
        # Count goals in human's movement direction
        goals_in_direction = 0
        direction_threshold = 0.7  # Cosine similarity threshold (about 45 degrees)
        
        if h_mag > 0:
            h_dir_norm = [h_input/h_mag for h_input in human_input]
            
            for targ in targets:
                to_goal = [targ[0] - dot_pos[0], targ[1] - dot_pos[1]]
                to_goal_mag = math.hypot(to_goal[0], to_goal[1])
                if to_goal_mag > 0:
                    to_goal_dir = [to_goal[0]/to_goal_mag, to_goal[1]/to_goal_mag]
                    alignment = h_dir_norm[0]*to_goal_dir[0] + h_dir_norm[1]*to_goal_dir[1]
                    if alignment > direction_threshold:
                        goals_in_direction += 1
        
        # Adjust gamma based on goal clarity
        if goals_in_direction == 1:
            # One clear goal - increase gamma earlier and higher
            extended_threshold = goal_threshold * 1.5
            if dist_to_target < extended_threshold:
                goal_factor = 1.0 - (dist_to_target / extended_threshold)
                base_gamma = max(base_gamma, 0.45 + 0.5 * goal_factor)
        else:
            # Multiple goals in direction - more conservative
            if dist_to_target < goal_threshold:
                goal_factor = 1.0 - (dist_to_target / goal_threshold)
                base_gamma = max(base_gamma, 0.35 + 0.4 * goal_factor)
        
        # Obstacle handling remains the same
        if min_obs_distance < obs_threshold:
            obs_factor = 1.0 - (min_obs_distance / obs_threshold)
            base_gamma = max(base_gamma, 0.45 + 0.5 * obs_factor)
        
        # Combined situation handling
        if dist_to_target < goal_threshold and min_obs_distance < obs_threshold:
            base_gamma = max(base_gamma, 0.7 + (0.1 if goals_in_direction == 1 else 0))
        
        # NEW: When very close to a goal or obstacle, still allow goal switching 
        # if user clearly indicates a different goal
        very_close_to_goal = dist_to_target < GOAL_DETECTION_RADIUS * 1.5
        very_close_to_obstacle = min_obs_distance < (OBSTACLE_RADIUS + DOT_RADIUS) * 1.2
        near_end_of_trajectory = very_close_to_goal  # End of trajectory is when we're close to a goal
        
        if (very_close_to_goal or very_close_to_obstacle or near_end_of_trajectory) and gamma > 0.7:
            # Make it even easier to switch goals when close to completion
            # Check if human is pointing toward a different goal - use lower threshold
            if h_mag > 0.3:  # Lower input threshold for goal switching
                proposed_idx = predict_human_target(human_input)
                
                if proposed_idx != current_target_idx and proposed_idx < len(targets):
                    proposed_goal = targets[proposed_idx]
                    to_proposed = [proposed_goal[0] - dot_pos[0], proposed_goal[1] - dot_pos[1]]
                    p_dist = math.hypot(to_proposed[0], to_proposed[1])
                    
                    if p_dist > 0:
                        p_dir = [to_proposed[0]/p_dist, to_proposed[1]/p_dist]
                        alignment = h_dir[0]*p_dir[0] + h_dir[1]*p_dir[1]
                        
                        # Lower threshold when near end of trajectory
                        intent_threshold = 0.75 if near_end_of_trajectory else 0.85
                        
                        # If pointing toward different goal with sufficient intent, allow the switch
                        if alignment > intent_threshold:
                            current_target_idx = proposed_idx
                            print(f"Mode 5: Goal switch to {current_target_idx+1} at end of trajectory")
                            # More significantly reduce gamma to ensure the switch works
                            base_gamma = 0.3
        
        noise = random.uniform(-0.05, 0.05)
        final_gamma = base_gamma + noise
        gamma = max(0.0, min(1.0, final_gamma))

    elif gamma > HIGH_GAMMA_THRESHOLD:
        # High gamma mode - system ignores user input but requires it to move
        print("High assistance level - system is using input for movement only")
    elif gamma < 0.05:
        print("Manual mode, gamma ~0 => dot is attracted to obstacles")
        ox, oy = 0.0, 0.0
        for obs in obstacles:
            dxo = obs[0] - dot_pos[0]
            dyo = obs[1] - dot_pos[1]
            dist_o = math.hypot(dxo, dyo)
            if dist_o > 1e-6:
                ox += dxo / dist_o
                oy += dyo / dist_o
        mo = math.hypot(ox, oy)
        if mo > 1e-6:
            ox /= mo
            oy /= mo
        w_dir = [ox, oy]
    elif abs(gamma - 0.5) < 0.06:
        print("Manual mode, gamma ~0.5 => partial obstacle attraction")
        ox, oy = 0.0, 0.0
        for obs in obstacles:
            dxo = obs[0] - dot_pos[0]
            dyo = obs[1] - dot_pos[1]
            dist_o = math.hypot(dxo, dyo)
            if dist_o > 1e-6:
                ox += dxo / dist_o
                oy += dyo / dist_o
        mo = math.hypot(ox, oy)
        if mo > 1e-6:
            ox /= mo
            oy /= mo
        blend_ratio = 0.3
        w_dir[0] = (1 - blend_ratio) * w_dir[0] + blend_ratio * ox
        w_dir[1] = (1 - blend_ratio) * w_dir[1] + blend_ratio * oy

    # When gamma is high, still allow user to control target selection but blend movement
    now = time.time()
    
    # First, always allow the user to select their target regardless of gamma
    if h_mag > 0.2:  # As long as there's meaningful input
        proposed_idx = predict_human_target(human_input)
        
        # If human is clearly pointing toward a different goal
        if proposed_idx != current_target_idx and proposed_idx < len(targets):
            proposed_goal = targets[proposed_idx]
            to_proposed = [proposed_goal[0] - dot_pos[0], proposed_goal[1] - dot_pos[1]]
            p_dist = math.hypot(to_proposed[0], to_proposed[1])
            
            if p_dist > 0:
                p_dir = [to_proposed[0]/p_dist, to_proposed[1]/p_dist]
                alignment = h_dir[0]*p_dir[0] + h_dir[1]*p_dir[1]
                
                # Use more lenient threshold for high gamma
                intent_threshold = 0.6 if gamma > HIGH_GAMMA_THRESHOLD else 0.7
                
                # If sufficiently aligned with different goal, allow the switch
                if alignment > intent_threshold:
                    current_target_idx = proposed_idx
                    print(f"Goal switched to {current_target_idx+1}")
                    # Update target_pos after switching
                    target_pos = targets[current_target_idx]
                    w_dir = compute_perfect_direction(dot_pos, target_pos, obstacles)
    
    # Then apply movement blending based on gamma
    if gamma > 0.9:  # Only for extremely high gamma, reduce human control to minimal
        # Still allow a little input influence even at highest gamma
        w_move_x = 0.95 * w_dir[0] * step_size  
        w_move_y = 0.95 * w_dir[1] * step_size
        h_move_x = 0.05 * h_dir[0] * step_size  # 5% human influence
        h_move_y = 0.05 * h_dir[1] * step_size  # 5% human influence
    else:
        # Normal blending operation for all other gamma values
        w_move_x = gamma * w_dir[0] * step_size
        w_move_y = gamma * w_dir[1] * step_size

        noise_x = np.random.normal(0, NOISE_MAGNITUDE)
        noise_y = np.random.normal(0, NOISE_MAGNITUDE)
        noisy_dx = h_dir[0] + noise_x
        noisy_dy = h_dir[1] + noise_y
        nm = math.hypot(noisy_dx, noisy_dy)
        if nm > 0:
            noisy_dx /= nm
            noisy_dy /= nm

        h_move_x = (1 - gamma) * noisy_dx * step_size
        h_move_y = (1 - gamma) * noisy_dy * step_size

    final_dx = w_move_x + h_move_x
    final_dy = w_move_y + h_move_y

    new_x = dot_pos[0] + final_dx
    new_y = dot_pos[1] + final_dy

    if not check_collision(dot_pos, [new_x, new_y]):
        # Account for game area offset
        dot_pos[0] = max(GAME_AREA_X, min(GAME_AREA_X + GAME_AREA_SIZE[0], new_x))
        dot_pos[1] = max(GAME_AREA_Y, min(GAME_AREA_Y + GAME_AREA_SIZE[1], new_y))

    if inside_obstacle(dot_pos):
        print("Collision with obstacle -> resetting!")
        trial_outcome = "collision"
        failure_counter += 1
        reset()
        return [0,0], [0,0], [0,0]

    final_mag = math.hypot(final_dx, final_dy)
    x_dir = [final_dx/final_mag, final_dy/final_mag] if final_mag > 0 else [0, 0]

    # Check if a goal has been reached
    dist_to_goal = distance(dot_pos, target_pos)
    if dist_to_goal < GOAL_DETECTION_RADIUS:
        # Check if this is the expected goal
        if current_target_idx == expected_goal_idx:
            # Successfully reached the expected goal
            reached_goal = True
            trial_outcome = "success"
            
            # Increment goal counter
            if current_target_idx not in goal_counters:
                goal_counters[current_target_idx] = 1
            else:
                goal_counters[current_target_idx] += 1
                
            # Check if we've reached this goal 4 times
            # If so, move to the next goal
            if goal_counters[current_target_idx] >= 4:
                # Move to the next expected goal
                expected_goal_idx = (expected_goal_idx + 1) % len(targets)
                print(f"Goal {current_target_idx+1} completed 4 times! Moving to next goal: {expected_goal_idx+1}")
        else:
            # Wrong goal reached
            print(f"Wrong goal reached! Expected: {expected_goal_idx+1}, Reached: {current_target_idx+1}")
            trial_outcome = "wrong_goal"
            failure_counter += 1
            wrong_goal_message_time = time.time()
            reset()
            return [0,0], [0,0], [0,0]
            
        pygame.time.set_timer(pygame.USEREVENT, 1000)

    return h_dir, w_dir, x_dir

def reset():
    """Log the trial that just ended (if any) and restore per-trial state.

    A record is appended to ``data_log`` and persisted with
    ``save_data_log(current_seed)`` only when ``trial_start_time`` is set and
    ``current_trajectory`` holds at least one sample.  Afterwards the dot is
    moved back to ``START_POS``, the per-trial globals are reinitialised and
    ``pygame.time.set_timer(pygame.USEREVENT, 0)`` is issued.
    """
    global dot_pos, reached_goal, current_target_idx, gamma
    global recent_positions, last_reset_time, trial_start_time, current_trajectory, trial_outcome
    global failure_counter, has_started_moving, movement_start_time, last_redirect_time
    global user_intended_target, initial_target_selected

    trial_was_running = trial_start_time is not None and len(current_trajectory) > 0
    if trial_was_running:
        elapsed = time.time() - trial_start_time

        # A goal position is only recorded for a reached goal with a valid index.
        hit_valid_goal = reached_goal and 0 <= current_target_idx < len(targets)
        goal_reached = targets[current_target_idx] if hit_valid_goal else None

        if USE_AI_CONTROL:
            control_label = "AI"
        else:
            control_label = "Manual"
        stamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")

        # Only an outcome explicitly set to "manual_reset" bumps the failure
        # counter; a missing outcome is merely *logged* under that label below.
        if trial_outcome == "manual_reset":
            failure_counter += 1

        record = {
            "timestamp": stamp,
            "mode": control_label,
            "trial_duration": elapsed,
            "trial_outcome": trial_outcome or "manual_reset",
            "goal_reached": goal_reached,
            "trajectory": current_trajectory.copy(),
        }
        data_log.append(record)
        print(f"Trial recorded: {data_log[-1]}")
        save_data_log(current_seed)

    # --- Fresh per-trial state -------------------------------------------
    dot_pos = START_POS.copy()
    reached_goal = False
    current_target_idx = 0

    # Manual gamma mode starts at 0.5; every other mode starts at 0.95.
    gamma = 0.5 if current_gamma_mode == "manual" else 0.95

    recent_positions.clear()
    current_trajectory.clear()
    last_reset_time = time.time()
    trial_start_time = time.time()

    trial_outcome = None
    has_started_moving = False
    movement_start_time = 0
    last_redirect_time = 0
    user_intended_target = None
    initial_target_selected = False

    # A millis value of 0 turns the USEREVENT timer off.
    pygame.time.set_timer(pygame.USEREVENT, 0)

def initialize_environment_fixed(seed):
    """Rebuild goals and obstacles deterministically from ``seed``.

    Seeds both ``random`` and ``numpy.random``, empties and refills the global
    ``targets`` and ``obstacles`` lists in place, and resets ``goal_counters``
    (one zero entry per goal) and ``expected_goal_idx`` (0).

    Args:
        seed: value passed to ``random.seed`` and ``np.random.seed``.
    """
    global obstacles, targets, goal_counters, expected_goal_idx

    random.seed(seed)
    np.random.seed(seed)
    obstacles.clear()
    targets.clear()

    # Layout parameters; the margin is measured from the game-area edges.
    margin = 50 * SCALING_FACTOR
    min_goal_gap = 200 * SCALING_FACTOR
    N_GOALS = 8
    N_OBSTACLES = 5

    x_lo = GAME_AREA_X + margin
    x_hi = GAME_AREA_X + GAME_AREA_SIZE[0] - margin
    y_lo = GAME_AREA_Y + margin
    y_hi = GAME_AREA_Y + GAME_AREA_SIZE[1] - margin

    # Rejection-sample goal positions: at most 1000 draws, each accepted only
    # if it keeps the minimum gap to every goal placed so far.
    placed_goals = []
    for _ in range(1000):
        if len(placed_goals) >= N_GOALS:
            break
        gx = random.uniform(x_lo, x_hi)
        gy = random.uniform(y_lo, y_hi)
        spot = (gx, gy)
        if all(distance(spot, other) >= min_goal_gap for other in placed_goals):
            placed_goals.append(spot)
    targets.extend(placed_goals)

    # Reset goal counters and expected goal
    goal_counters = dict.fromkeys(range(len(targets)), 0)
    expected_goal_idx = 0

    # Pick the goals that get an obstacle near their start->goal line; with
    # more than one goal at least one goal is always left without one.
    if len(placed_goals) > 1:
        sample_size = min(N_GOALS - 1, N_OBSTACLES, len(placed_goals) - 1)
        anchor_goals = random.sample(placed_goals, k=sample_size)
    else:
        anchor_goals = placed_goals

    placed_obstacles = []
    for goal in anchor_goals:
        # Point 60-80% of the way from the start to the goal.
        t = random.uniform(0.6, 0.8)
        dx = goal[0] - START_POS[0]
        dy = goal[1] - START_POS[1]
        base_x = START_POS[0] + t * dx
        base_y = START_POS[1] + t * dy

        # Unit vector perpendicular to the start->goal line (zero if degenerate).
        path_len = math.hypot(dx, dy)
        if path_len < 1e-6:
            perp_x, perp_y = 0, 0
        else:
            perp_x, perp_y = -dy / path_len, dx / path_len

        offset_mag = random.uniform(20 * SCALING_FACTOR, 40 * SCALING_FACTOR)
        # NOTE: the x and y components draw their signs independently, so the
        # resulting offset is not always perpendicular to the path.
        off_x = perp_x * offset_mag * random.choice([-1, 1])
        off_y = perp_y * offset_mag * random.choice([-1, 1])

        # Make sure obstacles are within game area
        candidate = (max(x_lo, min(base_x + off_x, x_hi)),
                     max(y_lo, min(base_y + off_y, y_hi)))

        # Reject candidates crowding the start, their own goal, or another obstacle.
        blocked = (
            distance(candidate, START_POS) < (DOT_RADIUS + OBSTACLE_RADIUS + 10)
            or distance(candidate, goal) < (TARGET_RADIUS + OBSTACLE_RADIUS + 20)
            or any(distance(candidate, obs) < (2 * OBSTACLE_RADIUS + 10)
                   for obs in placed_obstacles)
        )
        if not blocked:
            placed_obstacles.append(candidate)

    obstacles.extend(placed_obstacles)
    print(f"Environment initialized with fixed seed {seed}. #goals={len(targets)}, #obstacles={len(obstacles)}")


def draw_arrow(surface, color, start_pos, direction, length=ARROW_LENGTH):
    """Draw an arrow of ``length`` pixels from ``start_pos`` along ``direction``.

    Args:
        surface: pygame surface to draw on.
        color: colour used for the shaft and both head strokes.
        start_pos: (x, y) tail of the arrow.
        direction: (dx, dy) vector; only its orientation is used. A zero
            vector draws nothing.
        length: shaft length in pixels.
    """
    dx, dy = direction
    mag = math.hypot(dx, dy)
    if mag == 0:
        return
    ux, uy = dx / mag, dy / mag

    tip = (start_pos[0] + ux * length, start_pos[1] + uy * length)

    # pygame.draw.line draws nothing for width < 1, so keep at least one
    # pixel when the scaled width truncates to zero (SCALING_FACTOR < 0.5).
    stroke_px = max(1, int(2 * SCALING_FACTOR))
    pygame.draw.line(surface, color, start_pos, tip, stroke_px)

    # Two head strokes, swept 30 degrees either side of the heading.
    head_size = 7 * SCALING_FACTOR
    heading = math.atan2(uy, ux)
    for spread in (math.pi / 6, -math.pi / 6):
        barb = (tip[0] - head_size * math.cos(heading + spread),
                tip[1] - head_size * math.sin(heading + spread))
        pygame.draw.line(surface, color, tip, barb, stroke_px)

def render_full_view(surface, h_dir, w_dir, x_dir):
    """Draw the full view: scene, direction arrows and both info panels.

    Args:
        surface: pygame surface to draw on; it is cleared first.
        h_dir: direction drawn as the blue arrow (legend: human movement).
        w_dir: direction drawn as the green arrow (legend: perfect path).
        x_dir: direction drawn as the red arrow (legend: dot's movement).
            Each arrow is skipped when its direction equals ``[0, 0]``.

    Side effects:
        Resets the global ``current_target_idx`` to 0 when it is past the end
        of ``targets`` and drops entries older than ``GHOST_TRAIL_DURATION``
        from ``recent_positions``.
    """
    global current_target_idx

    now = time.time()
    # pygame fills a circle when width == 0, so a scaled width that truncates
    # to zero (SCALING_FACTOR < 0.5) would paint solid discs over the goals.
    stroke_px = max(1, int(2 * SCALING_FACTOR))
    dot_center = (int(dot_pos[0]), int(dot_pos[1]))

    surface.fill(BACKGROUND_COLOR)

    # Thin border around the game area.
    game_area_rect = pygame.Rect(GAME_AREA_X, GAME_AREA_Y, GAME_AREA_SIZE[0], GAME_AREA_SIZE[1])
    pygame.draw.rect(surface, LIGHT_GRAY, game_area_rect, 1)

    if ENABLE_OBSTACLES:
        for obstacle_pos in obstacles:
            pygame.draw.circle(surface, GRAY, (int(obstacle_pos[0]), int(obstacle_pos[1])), OBSTACLE_RADIUS)

    # All goals: the expected one gets its own fill colour and a black ring.
    for i, target in enumerate(targets):
        center = (int(target[0]), int(target[1]))
        is_expected = i == expected_goal_idx
        goal_color = EXPECTED_GOAL_COLOR if is_expected else YELLOW
        pygame.draw.circle(surface, goal_color, center, TARGET_RADIUS)
        if is_expected:
            pygame.draw.circle(surface, BLACK, center, TARGET_RADIUS + 2, stroke_px)
        num_text = font.render(str(i + 1), True, BLACK)
        surface.blit(num_text, (target[0] - 5, target[1] - 12))

    # Highlight the target currently being steered toward.
    if len(targets) > 0:
        if current_target_idx >= len(targets):
            current_target_idx = 0
        curr_t = targets[current_target_idx]

        if current_target_idx != expected_goal_idx:
            # Dashed blue ring: every other one of 16 chords is drawn.
            segments = 16
            radius = TARGET_RADIUS + 5
            for seg in range(0, segments, 2):
                start_angle = seg * 2 * math.pi / segments
                end_angle = (seg + 1) * 2 * math.pi / segments
                start_pos = (curr_t[0] + radius * math.cos(start_angle),
                             curr_t[1] + radius * math.sin(start_angle))
                end_pos = (curr_t[0] + radius * math.cos(end_angle),
                           curr_t[1] + radius * math.sin(end_angle))
                pygame.draw.line(surface, BLUE, start_pos, end_pos, 2)
        else:
            # Solid black ring when the selected target is the expected one.
            pygame.draw.circle(surface, BLACK, (int(curr_t[0]), int(curr_t[1])),
                               TARGET_RADIUS + 2, stroke_px)

    # Ghost trail: prune expired samples, then connect the remaining ones.
    while len(recent_positions) > 0 and (now - recent_positions[0][2]) > GHOST_TRAIL_DURATION:
        recent_positions.pop(0)
    for (x1, y1, _t1), (x2, y2, _t2) in zip(recent_positions, recent_positions[1:]):
        pygame.draw.line(surface, HISTORY_COLOR, (x1, y1), (x2, y2), 2)

    # The dot itself.
    pygame.draw.circle(surface, BLACK, dot_center, DOT_RADIUS, stroke_px)

    # Directional arrows.
    if h_dir != [0, 0]:
        draw_arrow(surface, BLUE, dot_center, h_dir, ARROW_LENGTH)
    if w_dir != [0, 0]:
        draw_arrow(surface, GREEN, dot_center, w_dir, ARROW_LENGTH)
    if x_dir != [0, 0]:
        draw_arrow(surface, RED, dot_center, x_dir, ARROW_LENGTH)

    # ---------------- Left info panel (outside game area) ----------------
    left_x = 10

    mode_number = get_mode_number(current_gamma_mode)
    mode_title = title_font.render(f"Condition {mode_number}", True, HIGHLIGHT_COLOR)
    surface.blit(mode_title, (left_x, 20))

    y_pos = 60

    def put_line(text, color, advance=30):
        """Render one panel line at the running y position, then move down."""
        nonlocal y_pos
        surface.blit(font.render(text, True, color), (left_x, y_pos))
        y_pos += advance

    put_line(f"Gamma: {gamma:.2f}", TEXT_COLOR)
    put_line(f"Movement = {gamma:.2f}W + {1-gamma:.2f}H", TEXT_COLOR, advance=40)

    put_line(f"Current Goal: {expected_goal_idx+1}", EXPECTED_GOAL_COLOR)
    if expected_goal_idx < len(targets) and expected_goal_idx in goal_counters:
        put_line(f"Progress: {goal_counters[expected_goal_idx]}/4", TEXT_COLOR)

    controls = [
        "L2/R2: gamma",
        "[/]: noise",
        "R: reset",
        f"Noise σ: {NOISE_MAGNITUDE:.2f}",
        f"Control: {'AI' if USE_AI_CONTROL else 'Manual'}",
        f"Input: {'Force Sensor' if USE_FORCE_SENSOR else 'Joystick/Keyboard'}"
    ]
    for control in controls:
        put_line(control, TEXT_COLOR)

    if not FORCE_SENSOR_AVAILABLE and USE_FORCE_SENSOR:
        put_line("Force sensor not available!", RED)

    # Timer and seed, separated from the controls by a small gap.
    y_pos += 20
    elapsed_time = now - last_reset_time
    put_line(f"Time: {elapsed_time:.1f}s", TEXT_COLOR)
    put_line(f"Scenario Seed: {current_seed}", TEXT_COLOR)

    # Legend: each label is rendered in the colour of the item it describes
    # (the history entry uses HISTORY_COLOR, matching the trail drawn above).
    legend_y = FULL_VIEW_SIZE[1] - 140
    surface.blit(font.render("Legend:", True, TEXT_COLOR), (left_x, legend_y))
    legend_items = [
        ("Green Arrow: Perfect Path (W)", GREEN),
        ("Blue Arrow: Human Movement (H)", BLUE),
        ("Red Arrow: Dot's Movement", RED),
        ("Gray line: Movement History", HISTORY_COLOR)
    ]
    for row, (label, color) in enumerate(legend_items):
        surface.blit(font.render(label, True, color), (left_x, legend_y + 30 + row * 25))

    # ---------------- Right info panel: results ----------------
    right_x = GAME_AREA_X + GAME_AREA_SIZE[0] + 10
    surface.blit(font.render("Results:", True, TEXT_COLOR), (right_x, 20))

    result_y = 60
    for i in range(len(targets)):
        count = goal_counters.get(i, 0)
        row_color = EXPECTED_GOAL_COLOR if i == expected_goal_idx else GREEN
        surface.blit(font.render(f"Goal {i+1}: {count}/4", True, row_color), (right_x, result_y))
        result_y += 30

    failures_txt = font.render(f"Failures: {failure_counter}", True, RED)
    surface.blit(failures_txt, (right_x, result_y))

    # Gamma gauge sits below the failure counter.
    draw_gamma_gauge(surface, gamma, right_x, result_y + 40, 150, 80)

    # Completion message when goal is reached.
    if reached_goal:
        r_txt = title_font.render(f"Goal Reached in {elapsed_time:.1f}s!", True, GREEN)
        surface.blit(r_txt, r_txt.get_rect(center=(GAME_CENTER[0], 80)))

    # "Wrong Goal!" banner stays up for 2 seconds after the wrong-goal timestamp.
    if now - wrong_goal_message_time < 2.0:
        wrong_txt = title_font.render(f"Wrong Goal! Go to Goal {expected_goal_idx+1}", True, WRONG_GOAL_COLOR)
        surface.blit(wrong_txt, wrong_txt.get_rect(center=(GAME_CENTER[0], 80)))

def render_red_only(surface, x_dir):
    """Draw the reduced view that shows only the red (resulting) arrow.

    Args:
        surface: pygame surface to draw on; it is cleared first.
        x_dir: direction drawn as the red arrow; skipped when it equals
            ``[0, 0]``.

    Side effects:
        Drops entries older than ``GHOST_TRAIL_DURATION`` from
        ``recent_positions``.
    """
    now = time.time()
    # pygame fills a circle when width == 0, so a scaled width that truncates
    # to zero (SCALING_FACTOR < 0.5) would paint solid discs over the goals.
    stroke_px = max(1, int(2 * SCALING_FACTOR))
    dot_center = (int(dot_pos[0]), int(dot_pos[1]))

    surface.fill(BACKGROUND_COLOR)

    # Thin border around the game area.
    game_area_rect = pygame.Rect(GAME_AREA_X, GAME_AREA_Y, GAME_AREA_SIZE[0], GAME_AREA_SIZE[1])
    pygame.draw.rect(surface, LIGHT_GRAY, game_area_rect, 1)

    if ENABLE_OBSTACLES:
        for obstacle_pos in obstacles:
            pygame.draw.circle(surface, GRAY, (int(obstacle_pos[0]), int(obstacle_pos[1])), OBSTACLE_RADIUS)

    # Ghost trail: prune expired samples, then connect the remaining ones.
    while len(recent_positions) > 0 and (now - recent_positions[0][2]) > GHOST_TRAIL_DURATION:
        recent_positions.pop(0)
    for (x1, y1, _t1), (x2, y2, _t2) in zip(recent_positions, recent_positions[1:]):
        pygame.draw.line(surface, HISTORY_COLOR, (x1, y1), (x2, y2), 2)

    # Goals: the expected goal uses EXPECTED_GOAL_COLOR, all others are red.
    for i, target in enumerate(targets):
        center = (int(target[0]), int(target[1]))
        is_expected = i == expected_goal_idx
        pygame.draw.circle(surface, EXPECTED_GOAL_COLOR if is_expected else RED, center, TARGET_RADIUS)

        # Ring around the currently selected target: black when it is the
        # expected goal, blue when it is a different one.
        if i == current_target_idx:
            highlight_color = BLACK if is_expected else BLUE
            pygame.draw.circle(surface, highlight_color, center, TARGET_RADIUS + 2, stroke_px)

        num_text = font.render(str(i + 1), True, BLACK)
        surface.blit(num_text, (target[0] - 5, target[1] - 12))

    # The dot and its resulting movement arrow.
    pygame.draw.circle(surface, BLACK, dot_center, DOT_RADIUS, stroke_px)
    if x_dir != [0, 0]:
        draw_arrow(surface, RED, dot_center, x_dir, ARROW_LENGTH)

    # ---------------- Left side UI ----------------
    left_x = 10

    mode_number = get_mode_number(current_gamma_mode)
    mode_title = title_font.render(f"Condition {mode_number}", True, HIGHLIGHT_COLOR)
    surface.blit(mode_title, (left_x, 20))

    elapsed_time = now - last_reset_time
    timer_text = font.render(f"Time: {elapsed_time:.1f}s", True, TEXT_COLOR)
    surface.blit(timer_text, (left_x, 60))

    goal_txt = font.render(f"Goal {expected_goal_idx+1}: {goal_counters.get(expected_goal_idx, 0)}/4", True, EXPECTED_GOAL_COLOR)
    surface.blit(goal_txt, (left_x, 100))

    if USE_FORCE_SENSOR:
        force_text = font.render("Force Sensor Mode", True, TEXT_COLOR)
        surface.blit(force_text, (left_x, 140))

    # ---------------- Right side UI: results ----------------
    right_x = GAME_AREA_X + GAME_AREA_SIZE[0] + 10
    surface.blit(font.render("Results:", True, TEXT_COLOR), (right_x, 20))

    result_y = 60
    for i in range(len(targets)):
        count = goal_counters.get(i, 0)
        row_color = EXPECTED_GOAL_COLOR if i == expected_goal_idx else GREEN
        surface.blit(font.render(f"Goal {i+1}: {count}/4", True, row_color), (right_x, result_y))
        result_y += 30

    failures_txt = font.render(f"Failures: {failure_counter}", True, RED)
    surface.blit(failures_txt, (right_x, result_y))

    # Gamma gauge in the lower-left corner.
    draw_gamma_gauge(surface, gamma, 10, FULL_VIEW_SIZE[1] - 100, 150, 80)

    # Completion message when goal is reached.
    if reached_goal:
        completion_text = title_font.render("Goal Reached!", True, GREEN)
        surface.blit(completion_text, completion_text.get_rect(center=(GAME_CENTER[0], 80)))

    # "Wrong Goal!" banner stays up for 2 seconds after the wrong-goal timestamp.
    if now - wrong_goal_message_time < 2.0:
        wrong_txt = title_font.render(f"Wrong Goal! Go to Goal {expected_goal_idx+1}", True, WRONG_GOAL_COLOR)
        surface.blit(wrong_txt, wrong_txt.get_rect(center=(GAME_CENTER[0], 80)))

def skip_to_next_environment():
    """Save the current seed's log and stop the running trial loop.

    Reads the module-level ``current_seed`` and ``SCENARIO_SEEDS``. Always
    sets the module-level ``running`` flag to False (the main ``while
    running`` loop polls it). When ``current_seed`` is the last entry of
    ``SCENARIO_SEEDS`` it additionally sets ``skip_remaining_seeds`` to True.

    Raises:
        ValueError: if ``current_seed`` is not in ``SCENARIO_SEEDS``
            (propagated from ``list.index``).
    """
    # ``global`` is a compile-time directive: it applies to the whole function
    # no matter where it is written, so it is declared once, unconditionally,
    # and only for the names this function actually assigns.
    global running, skip_remaining_seeds

    # Save data for the current seed
    save_data_log(current_seed)

    # Figure out if we are at the last seed
    current_index = SCENARIO_SEEDS.index(current_seed)
    if current_index < len(SCENARIO_SEEDS) - 1:
        print("Skipping to the next environment (same gamma mode).")
    else:
        # Already at the last seed -> skip to the next gamma mode
        print("Already at the last environment seed. Skipping to next gamma mode.")
        # Set a flag so we skip any remaining seeds in this gamma mode
        skip_remaining_seeds = True

    # Environment state is deliberately left untouched here: stopping the
    # current loop is enough for the surrounding seed loop to move on.
    running = False

###############################################################################
# MAIN EXPERIMENT LOOP
###############################################################################

# Runs every gamma mode against every scenario seed. For each (mode, seed)
# pair the environment is re-initialised, the interactive trial loop runs at
# 60 FPS until all goals are completed four times (or the run is skipped /
# the window is closed), and the log for that seed is saved.
gamma_modes = [0.0, 0.5, 1.0, "manual", "ai"]
current_seed = None
save_filename = None
skip_remaining_seeds = False
# Set when a pygame.QUIT event arrives so that closing the window ends the
# whole experiment instead of merely advancing to the next seed.
quit_requested = False

for gamma_mode_index, gamma_mode in enumerate(gamma_modes):
    current_gamma_mode = gamma_mode
    print(f"\n===== STARTING GAMMA MODE {gamma_mode_index+1} = {gamma_mode} =====\n")
    skip_remaining_seeds = False  # Reset at the start of each gamma mode

    for s_index, s in enumerate(SCENARIO_SEEDS):

        if skip_remaining_seeds and s_index > 0:
            print(f"Skipping seed {s} to move to next gamma mode")
            continue

        current_seed = s
        data_log = []
        save_filename = get_save_filename(s)

        initialize_environment_fixed(s)

        # This code runs at module level, so these assignments already bind
        # module globals; no ``global`` statement is needed (or valid once a
        # name has been assigned earlier in the same module).
        trial_start_time = time.time()
        failure_counter = 0
        expected_goal_idx = 0  # Always start with goal 1
        reset()

        running = True
        clock = pygame.time.Clock()

        while running:
            # Check if the trial has been running for too long (15 seconds timeout)
            # NOTE(review): presumably reset() restarts trial_start_time;
            # otherwise this branch would fire on every frame after the first
            # timeout -- confirm against reset().
            current_time = time.time()
            trial_duration = current_time - trial_start_time

            if trial_duration > 15.0 and not reached_goal:
                print(f"Time limit exceeded ({trial_duration:.1f}s) -> resetting!")
                trial_outcome = "timeout"
                failure_counter += 1
                reset()

            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    # Abort the entire experiment, not just this seed.
                    quit_requested = True
                    running = False
                    break
                if event.type == pygame.KEYDOWN:
                    if event.key == pygame.K_f:
                        if not FORCE_SENSOR_AVAILABLE and not USE_FORCE_SENSOR:
                            print("Force sensor not available!")
                        else:
                            USE_FORCE_SENSOR = not USE_FORCE_SENSOR
                            print("Force sensor mode:", USE_FORCE_SENSOR)
                    if event.key == pygame.K_LEFTBRACKET:
                        NOISE_MAGNITUDE = max(MIN_NOISE, NOISE_MAGNITUDE - NOISE_STEP)
                    elif event.key == pygame.K_RIGHTBRACKET:
                        NOISE_MAGNITUDE = min(MAX_NOISE, NOISE_MAGNITUDE + NOISE_STEP)
                    if event.key == pygame.K_r:
                        trial_outcome = "manual_reset"
                        reset()

                    if event.key == pygame.K_n:
                        skip_to_next_environment()
                        print("N key pressed: Skipping to next environment (same gamma mode)")

                    # NOTE(review): while the goal is not reached,
                    # USE_AI_CONTROL is reassigned from gamma_mode on every
                    # frame further down, so this toggle (and joystick
                    # button 3) is overwritten almost immediately.
                    if event.key == pygame.K_SPACE and USE_FORCE_SENSOR:
                        USE_AI_CONTROL = not USE_AI_CONTROL
                        print(f"{'AI' if USE_AI_CONTROL else 'Manual'} control enabled (Space Key)")

                if joystick and event.type == pygame.JOYBUTTONDOWN:
                    if event.button == 2:
                        trial_outcome = "manual_reset"
                        reset()
                    if event.button == 3:
                        USE_AI_CONTROL = not USE_AI_CONTROL
                        print(f"{'AI' if USE_AI_CONTROL else 'Manual'} control enabled")
                    if event.button == 4:  # Add L1 for skipping to next environment
                        skip_to_next_environment()
                        print("L1 pressed: Skipping to next environment (same gamma mode)")

                if event.type == pygame.USEREVENT:
                    if not reached_goal:
                        trial_outcome = "timeout"
                    reset()

            if not running:
                break

            if len(targets) == 0:
                # If no targets at all, skip this environment
                print("No targets in this environment. Skipping...")
                running = False
                continue

            # Check if all goals are completed 4 times
            if all(goal_counters.get(i, 0) >= 4 for i in range(len(targets))):
                print("All goals completed 4 times each. Moving to next environment...")
                running = False
                continue

            if not reached_goal:
                # Decide if we fix gamma or use AI or manual
                if isinstance(gamma_mode, float):
                    gamma = gamma_mode
                    USE_AI_CONTROL = False
                elif gamma_mode == "ai":
                    USE_AI_CONTROL = True
                else:
                    # "manual"
                    USE_AI_CONTROL = False

                if USE_FORCE_SENSOR and FORCE_SENSOR_AVAILABLE:
                    dx, dy = force_sensor_input
                else:
                    # Force-sensor mode cannot stay on without a sensor.
                    if USE_FORCE_SENSOR and not FORCE_SENSOR_AVAILABLE:
                        USE_FORCE_SENSOR = False

                    dx, dy = 0.0, 0.0
                    keys = pygame.key.get_pressed()
                    if keys[pygame.K_LEFT]:
                        dx -= 1
                    if keys[pygame.K_RIGHT]:
                        dx += 1
                    if keys[pygame.K_UP]:
                        dy -= 1
                    if keys[pygame.K_DOWN]:
                        dy += 1

                    if joystick:
                        # NOTE(review): when a joystick is present its stick
                        # always replaces the keyboard direction, including
                        # zeroing it while the stick rests in the deadzone.
                        axis_0 = joystick.get_axis(0)
                        axis_1 = joystick.get_axis(1)
                        deadzone = 0.1
                        if abs(axis_0) > deadzone or abs(axis_1) > deadzone:
                            dx = axis_0
                            dy = axis_1
                        else:
                            dx, dy = 0.0, 0.0
                        l2_val = joystick.get_axis(AXIS_L2)
                        r2_val = joystick.get_axis(AXIS_R2)
                        # In "manual" mode the triggers nudge gamma, clamped to [0, 1].
                        if gamma_mode == "manual":
                            if l2_val > 0.1:
                                gamma = max(0.0, gamma - 0.003)
                            if r2_val > 0.1:
                                gamma = min(1.0, gamma + 0.003)

                # Treat very small inputs on both axes as no input.
                if abs(dx) < 0.1 and abs(dy) < 0.1:
                    dx, dy = 0.0, 0.0

                if not USE_FORCE_SENSOR:
                    dx *= MAX_SPEED
                    dy *= MAX_SPEED

                human_input = [dx, dy]

                # For gamma <= HIGH_GAMMA_THRESHOLD, use human input to select target
                if gamma <= HIGH_GAMMA_THRESHOLD:
                    proposed_idx = predict_human_target(human_input)
                    current_target_idx = proposed_idx if proposed_idx < len(targets) else 0
                # Above the threshold current_target_idx is left unchanged,
                # so the previously selected target is kept.

                h_dir, w_dir, x_dir = move_dot(human_input)
                # reached_goal is re-checked here, after move_dot() has run,
                # before logging this frame's position.
                if not reached_goal:
                    recent_positions.append((dot_pos[0], dot_pos[1], time.time()))
                    current_trajectory.append((dot_pos[0], dot_pos[1], gamma))
            else:
                h_dir, w_dir, x_dir = [0,0], [0,0], [0,0]

            render_full_view(surface_full, h_dir, w_dir, x_dir)
            render_red_only(surface_red_only, x_dir)

            tex1 = surface_to_texture(renderer1, surface_full)
            tex2 = surface_to_texture(renderer2, surface_red_only)

            renderer1.clear()
            tex1.draw(dstrect=(0, 0, FULL_VIEW_SIZE[0], FULL_VIEW_SIZE[1]))
            renderer1.present()

            renderer2.clear()
            tex2.draw(dstrect=(0, 0, RED_ONLY_SIZE[0], RED_ONLY_SIZE[1]))
            renderer2.present()

            clock.tick(60)

        # Always persist the log for this seed, including on quit.
        # (skip_to_next_environment() also saves, so a skipped seed is saved twice.)
        save_data_log(s)
        print(f"Finished environment seed: {s} (Gamma mode={gamma_mode})")

        if quit_requested:
            break

    if quit_requested:
        break

if ser is not None:
    ser.close()

pygame.quit()
if quit_requested:
    print("Quit requested. Exiting.")
else:
    print("All seeds completed. Exiting.")

pygame 2.6.0 (SDL 2.28.4, Python 3.10.9)
Hello from the pygame community. https://www.pygame.org/contribute.html
Error: Could not open serial port: module 'serial' has no attribute 'Serial'
Force sensor not available. Using keyboard/joystick controls only.
Joystick initialized: DualSense Wireless Controller

===== STARTING GAMMA MODE 1 = 0.0 =====

Environment initialized with fixed seed 0. #goals=5, #obstacles=3
Initial user target set to: 3
Manual mode, gamma ~0 => dot is attracted to obstacles
Manual mode, gamma ~0 => dot is attracted to obstacles
Manual mode, gamma ~0 => dot is attracted to obstacles
Manual mode, gamma ~0 => dot is attracted to obstacles
Manual mode, gamma ~0 => dot is attracted to obstacles
Manual mode, gamma ~0 => dot is attracted to obstacles
Manual mode, gamma ~0 => dot is attracted to obstacles
Manual mode, gamma ~0 => dot is attracted to obstacles
Manual mode, gamma ~0 => dot is attracted to obstacles
Manual mode, gamma ~0 => dot is attracted to obstacles
Man