In [1]:
import os
import datetime
import pygame
import random
import math
import time
import numpy as np
import json  # for saving log data
from stable_baselines3 import PPO
from pygame._sdl2.video import Window, Renderer, Texture

# -----------------------------------------------------------
# Force Sensor Integration (extracted from your force sensor code)
# -----------------------------------------------------------

import tkinter as tk
from tkinter import ttk
import serial as pyserial
import threading
from collections import deque

###############################################################################
# Multi-Seed Setup
###############################################################################
SCENARIO_SEEDS = [0, 2, 58]  # We will run the experiment for each of these seeds in turn.

###############################################################################
# Global flags & constants
###############################################################################
USE_FORCE_SENSOR = False  # toggle with F key
force_sensor_input = [0.0, 0.0]
FORCE_SENSOR_AVAILABLE = False

FORCE_SENSOR_SCALE_X = 5.0
FORCE_SENSOR_SCALE_Y = 5.0

XY_FORCE_CAL = 1.33
Z_SMOOTH_ALPHA = 0.2

sample_prev_fs = None
sample_curr_fs = None
sample_lock_fs = threading.Lock()
last_fx_smooth = None
last_fy_smooth = None

try:
    ser = pyserial.Serial('COM5', 115200, timeout=0.01)
    FORCE_SENSOR_AVAILABLE = True
    print("Force sensor connected successfully.")
except Exception as e:
    print("Error: Could not open serial port:", e)
    ser = None
    FORCE_SENSOR_AVAILABLE = False
    print("Force sensor not available. Using keyboard/joystick controls only.")

def serial_reader_fs():
    global sample_prev_fs, sample_curr_fs, last_fx_smooth, last_fy_smooth, force_sensor_input
    while True:
        if ser is None:
            time.sleep(0.1)
            continue
        try:
            line = ser.readline().decode('utf-8').strip()
            if not line:
                continue
            tokens = line.split(',')
            if len(tokens) != 2:
                continue
            try:
                values = list(map(float, tokens))
            except Exception:
                continue
            sample = {
                'fx': -values[0],
                'fy': values[1],
                'timestamp': time.time()
            }
            with sample_lock_fs:
                if sample_curr_fs is None:
                    sample_curr_fs = sample
                    sample_prev_fs = sample
                else:
                    sample_prev_fs = sample_curr_fs
                    sample_curr_fs = sample

            # Interpolate horizontal forces
            with sample_lock_fs:
                sp = sample_prev_fs
                sc = sample_curr_fs
            now = time.time()
            if sp is None or sc is None:
                continue
            t0 = sp['timestamp']
            t1 = sc['timestamp']
            fraction = 1.0 if t1 == t0 else (now - t0) / (t1 - t0)
            fraction = max(0.0, min(1.0, fraction))
            fx_interp = sp['fx'] * (1 - fraction) + sc['fx'] * fraction
            fy_interp = sp['fy'] * (1 - fraction) + sc['fy'] * fraction

            # Exponential smoothing
            if last_fx_smooth is None:
                smoothed_fx = fx_interp
            else:
                smoothed_fx = last_fx_smooth + Z_SMOOTH_ALPHA * (fx_interp - last_fx_smooth)
            if last_fy_smooth is None:
                smoothed_fy = fy_interp
            else:
                smoothed_fy = last_fy_smooth + Z_SMOOTH_ALPHA * (fy_interp - last_fy_smooth)
            last_fx_smooth = smoothed_fx
            last_fy_smooth = smoothed_fy

            # Scale with calibration
            final_fx = smoothed_fx * XY_FORCE_CAL
            final_fy = smoothed_fy * XY_FORCE_CAL
            final_fx *= FORCE_SENSOR_SCALE_X
            final_fy *= FORCE_SENSOR_SCALE_Y

            force_sensor_input = [final_fx, final_fy]
        except Exception as e:
            print("Serial reader error:", e)
            time.sleep(0.01)

if FORCE_SENSOR_AVAILABLE:
    threading.Thread(target=serial_reader_fs, daemon=True).start()

pygame.init()
pygame.joystick.init()

USE_AI_CONTROL = False

# Noise Models
def eeg_noise():
    """ Default: Gaussian EEG noise with σ=190. """
    return np.random.normal(0, 190)

def burst_noise():
    """ Rare large deviations (~20% prob). """
    return np.random.choice([-1,1]) * np.random.normal(0, 190) * (np.random.rand() < 0.2)

from scipy.signal import lfilter
def pink_noise():
    """ Pink noise (1/f). """
    a = [1, -0.95]
    return lfilter([1], a, np.random.normal(0, 290, size=1))[0]

NOISE_FUNCTION = pink_noise  # or burst_noise, etc.

###############################################################################
# Config / Constants
###############################################################################
FULL_VIEW_SIZE = (1200, 800)
RED_ONLY_SIZE  = (1200, 800)

NOISE_MAGNITUDE = 2.5
MIN_NOISE = 0.0
MAX_NOISE = 2.0
NOISE_STEP = 0.1

OLD_WINDOW_SIZE   = (600, 600)
SCALING_FACTOR_X  = FULL_VIEW_SIZE[0] / OLD_WINDOW_SIZE[0]
SCALING_FACTOR_Y  = FULL_VIEW_SIZE[1] / OLD_WINDOW_SIZE[1]
SCALING_FACTOR    = (SCALING_FACTOR_X + SCALING_FACTOR_Y) / 2

WHITE  = (255, 255, 255)
BLACK  = (0, 0, 0)
RED    = (255, 0, 0)
GREEN  = (0, 200, 0)
BLUE   = (0, 0, 255)
YELLOW = (255, 255, 0)
GRAY   = (128, 128, 128)

FONT_COLOR = (0, 0, 0)
FONT_SIZE = int(16 * SCALING_FACTOR)
ARROW_LENGTH = int(60 * SCALING_FACTOR)

# Obstacles
OBSTACLE_RADIUS  = int(10 * SCALING_FACTOR)
COLLISION_BUFFER = int(5 * SCALING_FACTOR)
ENABLE_OBSTACLES = True
MAX_SPEED        = 3 * SCALING_FACTOR

# Dot & targets
DOT_RADIUS            = int(15 * SCALING_FACTOR)
TARGET_RADIUS         = int(10 * SCALING_FACTOR)
GOAL_DETECTION_RADIUS = DOT_RADIUS + TARGET_RADIUS

GHOST_TRAIL_DURATION  = 3.0
recent_positions      = []
last_reset_time       = time.time()

RECENT_DIR_LOOKBACK   = 1.0
GOAL_SWITCH_THRESHOLD = 0.05

WINDOW_CENTER = (FULL_VIEW_SIZE[0] // 2, FULL_VIEW_SIZE[1] // 2)
START_POS = [WINDOW_CENTER[0], WINDOW_CENTER[1]]
dot_pos   = START_POS.copy()

gamma         = 0.2
reached_goal  = False
targets       = []
current_target_idx = 0
obstacles     = []

# Stats
goal_counters = {}
failure_counter = 0

USE_RAW_ONLY_FOR_GOAL_DETECTION = True

joystick = None
if pygame.joystick.get_count() > 0:
    joystick = pygame.joystick.Joystick(0)
    joystick.init()
    print("Joystick initialized:", joystick.get_name())
else:
    print("No joystick detected.")

AXIS_L2 = 4
AXIS_R2 = 5

# Create windows
window1 = Window("2D Environment: Full View", size=FULL_VIEW_SIZE)
renderer1 = Renderer(window1, vsync=True)
window2 = Window("2D Environment: Red Arrow Only", size=RED_ONLY_SIZE)
renderer2 = Renderer(window2, vsync=True)

def create_compatible_surface(size):
    return pygame.Surface(size, flags=pygame.SRCALPHA)

surface_full = create_compatible_surface(FULL_VIEW_SIZE)
surface_red_only = create_compatible_surface(RED_ONLY_SIZE)

font = pygame.font.Font(None, FONT_SIZE)

def surface_to_texture(renderer, surf):
    if surf.get_bitsize() != 32:
        surf = surf.convert_alpha()
    return Texture.from_surface(renderer, surf)

# Helper functions
def distance(pos1, pos2):
    return math.hypot(pos1[0] - pos2[0], pos1[1] - pos2[1])

def line_circle_intersection(start, end, circle_center, radius):
    dx = end[0] - start[0]
    dy = end[1] - start[1]
    cx = circle_center[0] - start[0]
    cy = circle_center[1] - start[1]
    l2 = dx*dx + dy*dy
    if l2 == 0:
        return distance(start, circle_center) <= radius
    t = max(0, min(1, (cx*dx + cy*dy) / l2))
    proj_x = start[0] + t * dx
    proj_y = start[1] + t * dy
    return distance((proj_x, proj_y), circle_center) <= radius

def check_collision(pos, new_pos):
    if not ENABLE_OBSTACLES:
        return False
    for obstacle_pos in obstacles:
        if line_circle_intersection(pos, new_pos, obstacle_pos, OBSTACLE_RADIUS + COLLISION_BUFFER):
            return True
    return False

def inside_obstacle(pos):
    if not ENABLE_OBSTACLES:
        return False
    for obstacle_pos in obstacles:
        if distance(pos, obstacle_pos) <= (OBSTACLE_RADIUS + DOT_RADIUS):
            return True
    return False

def get_recent_direction():
    if len(recent_positions) < 2:
        return [0, 0]
    current_time = time.time()
    valid_points = []
    for (x, y, t) in reversed(recent_positions):
        if (current_time - t) <= RECENT_DIR_LOOKBACK:
            valid_points.append((x, y, t))
        else:
            break
    if len(valid_points) < 2:
        return [0, 0]
    valid_points.sort(key=lambda p: p[2])
    x1, y1, t1 = valid_points[0]
    x2, y2, t2 = valid_points[-1]
    dt = t2 - t1
    if dt < 0.001:
        return [0, 0]
    vx = (x2 - x1) / dt
    vy = (y2 - y1) / dt
    mag = math.hypot(vx, vy)
    return [vx/mag, vy/mag] if mag > 0 else [0, 0]

def compute_perfect_direction(dot_pos, goal_pos, obstacles):
    gx = goal_pos[0] - dot_pos[0]
    gy = goal_pos[1] - dot_pos[1]
    goal_dist = math.hypot(gx, gy)
    if goal_dist < 1e-6:
        return [0, 0]
    goal_dir = [gx / goal_dist, gy / goal_dist]

    repulse_x = 0.0
    repulse_y = 0.0
    repulsion_radius = 27 * SCALING_FACTOR
    repulsion_gain   = 30000.0

    for obs in obstacles:
        dx = dot_pos[0] - obs[0]
        dy = dot_pos[1] - obs[1]
        dist_o = math.hypot(dx, dy)
        if dist_o < 1e-6:
            continue
        if dist_o < repulsion_radius:
            push_dir_x = dx / dist_o
            push_dir_y = dy / dist_o
            strength = repulsion_gain / (dist_o**2)
            repulse_x += push_dir_x * strength
            repulse_y += push_dir_y * strength

    w_px = goal_dir[0] + repulse_x
    w_py = goal_dir[1] + repulse_y
    norm = math.hypot(w_px, w_py)
    if norm < 1e-6:
        return [0, 0]
    return [w_px / norm, w_py / norm]

class GammaPredictor:
    def __init__(self, model_path="gamma_ppo_model.zip"):
        try:
            self.model = PPO.load(model_path)
        except:
            self.model = None
        self.max_dist = np.sqrt(FULL_VIEW_SIZE[0]**2 + FULL_VIEW_SIZE[1]**2)

    def prepare_observation(self, dot_pos, target_pos, human_input):
        dot_pos = np.array(dot_pos, dtype=np.float32)
        target_pos = np.array(target_pos, dtype=np.float32)
        to_target = target_pos - dot_pos
        dist = np.linalg.norm(to_target)
        perfect_dir = to_target / dist if dist > 0 else np.array([0, 0], dtype=np.float32)
        h_mag = np.linalg.norm(human_input)
        human_dir = human_input / h_mag if h_mag > 0 else np.array([0, 0], dtype=np.float32)
        normalized_dist = dist / self.max_dist
        obs_dist_ratio = 1.0
        obs = np.concatenate([
            dot_pos,
            human_dir,
            target_pos,
            perfect_dir,
            [normalized_dist],
            [obs_dist_ratio]
        ]).astype(np.float32)
        return obs

    def predict_gamma(self, dot_pos, target_pos, human_input):
        if self.model is None:
            return 0.2
        obs = self.prepare_observation(dot_pos, target_pos, human_input)
        obs_batched = obs[np.newaxis, :]
        action, _ = self.model.predict(obs_batched, deterministic=True)
        return float(action[0])

gamma_predictor = GammaPredictor()

def predict_human_target(human_input):
    global current_target_idx
    dist_to_current = distance(dot_pos, targets[current_target_idx])
    close_threshold = GOAL_DETECTION_RADIUS * 2
    if dist_to_current < close_threshold:
        return current_target_idx

    if human_input[0] == 0 and human_input[1] == 0:
        return current_target_idx

    h_mag = math.hypot(human_input[0], human_input[1])
    h_dir = [h_input/h_mag for h_input in human_input] if h_mag > 0 else [0, 0]

    if USE_RAW_ONLY_FOR_GOAL_DETECTION:
        recent_dir = [0, 0]
    else:
        recent_dir = get_recent_direction()

    best_score = float('-inf')
    best_idx   = current_target_idx

    for i, targ in enumerate(targets):
        to_tx = targ[0] - dot_pos[0]
        to_ty = targ[1] - dot_pos[1]
        to_mag = math.hypot(to_tx, to_ty)
        if to_mag == 0:
            continue
        to_dir = [to_tx/to_mag, to_ty/to_mag]
        align_human  = (h_dir[0]*to_dir[0] + h_dir[1]*to_dir[1])
        align_recent = (recent_dir[0]*to_dir[0] + recent_dir[1]*to_dir[1])
        score = (align_human * 0.8) + (align_recent * 0.2)
        if score > best_score:
            best_score = score
            best_idx   = i

    return best_idx

data_log = []
trial_start_time = time.time()
current_trajectory = []
trial_outcome = None

save_folder = "user_study_data"
os.makedirs(save_folder, exist_ok=True)

def get_save_filename(seed):
    session_timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
    return os.path.join(save_folder, f"data_log_seed{seed}_{session_timestamp}.json")

save_filename = None

def save_data_log(seed):
    if not save_filename:
        return
    with open(save_filename, "w") as f:
        json.dump(data_log, f, indent=4)
    print(f"Data log updated and saved to {save_filename} for seed={seed}")

def move_dot(human_input):
    global dot_pos, gamma, reached_goal, current_target_idx, USE_AI_CONTROL, trial_outcome

    # -----------------------------------------------------------------
    # 1) Compute raw user direction (H) and perfect direction (W)
    # -----------------------------------------------------------------
    h_dx, h_dy = human_input
    h_mag = math.hypot(h_dx, h_dy)
    h_dir = [h_dx / h_mag, h_dy / h_mag] if h_mag > 0 else [0, 0]

    target_pos = targets[current_target_idx]
    w_dir = compute_perfect_direction(dot_pos, target_pos, obstacles)

    # Possibly compute step_size from user magnitude
    input_mag = min(max(h_mag / MAX_SPEED, 0), 1)
    step_size = MAX_SPEED * input_mag

    # If AI control is used, do dynamic gamma logic
    if USE_AI_CONTROL and h_mag > 0:
        dist_to_target = distance(dot_pos, target_pos)
        min_obs_distance = min(distance(dot_pos, obs) for obs in obstacles) if obstacles else float('inf')
        goal_threshold = GOAL_DETECTION_RADIUS * 3
        obs_threshold = (OBSTACLE_RADIUS + DOT_RADIUS) * 3
        base_gamma = 0.2

        if dist_to_target < goal_threshold:
            goal_factor = 1.0 - (dist_to_target / goal_threshold)
            base_gamma = max(base_gamma, 0.2 + 0.5 * goal_factor)
        if min_obs_distance < obs_threshold:
            obs_factor = 1.0 - (min_obs_distance / obs_threshold)
            base_gamma = max(base_gamma, 0.4 + 0.5 * obs_factor)
        if dist_to_target < goal_threshold and min_obs_distance < obs_threshold:
            base_gamma = max(base_gamma, 0.7)

        noise = random.uniform(-0.05, 0.05)
        final_gamma = base_gamma + noise
        gamma = max(0.0, min(1.0, final_gamma))

    # Degrade static gamma modes to push preference for dynamic
    if 0.0 <= gamma < 0.05:
        degrade_mode = "manual"
    elif 0.95 < gamma <= 1.0:
        degrade_mode = "full_ai"
    elif abs(gamma - 0.5) < 0.06:  # e.g., 0.44-0.56
        degrade_mode = "half"
    else:
        degrade_mode = "none"

    w_move_x = gamma * w_dir[0] * step_size
    w_move_y = gamma * w_dir[1] * step_size

    # If gamma=0.5 degrade the AI path with random offset
    if degrade_mode == "half":
        offset = 0.2
        w_dir[0] += np.random.uniform(-offset, offset)
        w_dir[1] += np.random.uniform(-offset, offset)
        mag_w = math.hypot(w_dir[0], w_dir[1])
        if mag_w > 1e-4:
            w_dir[0] /= mag_w
            w_dir[1] /= mag_w
        w_move_x = gamma * w_dir[0] * step_size
        w_move_y = gamma * w_dir[1] * step_size

    # Next, handle user noise
    noise_x = np.random.normal(0, NOISE_MAGNITUDE)
    noise_y = np.random.normal(0, NOISE_MAGNITUDE)
    if degrade_mode == "manual":
        # amplify noise
        noise_x *= 1.5
        noise_y *= 1.5

    if degrade_mode == "full_ai":
        # user input ignored
        h_dir = [0, 0]

    noisy_dx = h_dir[0] + noise_x
    noisy_dy = h_dir[1] + noise_y
    nm = math.hypot(noisy_dx, noisy_dy)
    if nm > 0:
        noisy_dx /= nm
        noisy_dy /= nm

    h_move_x = (1 - gamma) * noisy_dx * step_size
    h_move_y = (1 - gamma) * noisy_dy * step_size

    final_dx = w_move_x + h_move_x
    final_dy = w_move_y + h_move_y

    new_x = dot_pos[0] + final_dx
    new_y = dot_pos[1] + final_dy

    if not check_collision(dot_pos, [new_x, new_y]):
        dot_pos[0] = max(0, min(FULL_VIEW_SIZE[0], new_x))
        dot_pos[1] = max(0, min(FULL_VIEW_SIZE[1], new_y))

    if inside_obstacle(dot_pos):
        print("Collision with obstacle -> resetting!")
        trial_outcome = "collision"
        global failure_counter
        failure_counter += 1
        reset()
        return [0,0], [0,0], [0,0]

    final_mag = math.hypot(final_dx, final_dy)
    x_dir = [final_dx / final_mag, final_dy / final_mag] if final_mag > 0 else [0, 0]

    dist_to_goal = distance(dot_pos, target_pos)
    if dist_to_goal < GOAL_DETECTION_RADIUS:
        global reached_goal
        reached_goal = True
        trial_outcome = "success"
        if current_target_idx not in goal_counters:
            goal_counters[current_target_idx] = 1
        else:
            goal_counters[current_target_idx] += 1
        pygame.time.set_timer(pygame.USEREVENT, 1000)

    return h_dir, w_dir, x_dir

def reset():
    global dot_pos, reached_goal, current_target_idx, gamma
    global recent_positions, last_reset_time, trial_start_time, current_trajectory, trial_outcome
    global failure_counter
    
    if trial_start_time is not None and len(current_trajectory) > 0:
        trial_duration = time.time() - trial_start_time
        goal_reached = targets[current_target_idx] if reached_goal else None
        mode = "AI" if USE_AI_CONTROL else "Manual"
        trial_timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")

        if trial_outcome == "manual_reset":
            failure_counter += 1

        data_log.append({
            "timestamp": trial_timestamp,
            "mode": mode,
            "trial_duration": trial_duration,
            "trial_outcome": trial_outcome if trial_outcome else "manual_reset",
            "goal_reached": goal_reached,
            "trajectory": current_trajectory.copy()
        })
        print(f"Trial recorded: {data_log[-1]}")
        save_data_log(current_seed)

    dot_pos = START_POS.copy()
    reached_goal = False
    current_target_idx = 0
    gamma = 0.95
    recent_positions.clear()
    last_reset_time = time.time()
    trial_start_time = time.time()
    current_trajectory.clear()
    trial_outcome = None
    pygame.time.set_timer(pygame.USEREVENT, 0)

def initialize_environment_fixed(seed):
    global obstacles, targets, goal_counters
    random.seed(seed)
    np.random.seed(seed)
    obstacles.clear()
    targets.clear()
    margin = 50 * SCALING_FACTOR
    min_goal_gap = 200 * SCALING_FACTOR
    N_GOALS = 8
    N_OBSTACLES = 5

    new_goals = []
    attempts = 0
    while len(new_goals) < N_GOALS and attempts < 1000:
        x = random.uniform(margin, FULL_VIEW_SIZE[0] - margin)
        y = random.uniform(margin, FULL_VIEW_SIZE[1] - margin)
        candidate = (x, y)
        if all(distance(candidate, g) >= min_goal_gap for g in new_goals):
            new_goals.append(candidate)
        attempts += 1
    targets.extend(new_goals)
    
    goal_counters = {i: 0 for i in range(len(targets))}

    new_obstacles = []
    if len(new_goals) > 1:
        obstacle_goals = random.sample(new_goals, k=min(min(N_GOALS-1, N_OBSTACLES), len(new_goals)-1))
    else:
        obstacle_goals = new_goals
    for goal in obstacle_goals:
        t = random.uniform(0.6, 0.8)
        base_point = (START_POS[0] + t * (goal[0] - START_POS[0]),
                      START_POS[1] + t * (goal[1] - START_POS[1]))
        vec = (goal[0] - START_POS[0], goal[1] - START_POS[1])
        vec_norm = math.hypot(vec[0], vec[1])
        if vec_norm < 1e-6:
            perp = (0, 0)
        else:
            perp = (-vec[1] / vec_norm, vec[0] / vec_norm)
        offset_mag = random.uniform(20 * SCALING_FACTOR, 40 * SCALING_FACTOR)
        offset = (perp[0] * offset_mag * random.choice([-1,1]),
                  perp[1] * offset_mag * random.choice([-1,1]))
        candidate = (base_point[0] + offset[0], base_point[1] + offset[1])
        candidate = (max(margin, min(candidate[0], FULL_VIEW_SIZE[0] - margin)),
                     max(margin, min(candidate[1], FULL_VIEW_SIZE[1] - margin)))
        valid = True
        if distance(candidate, START_POS) < (DOT_RADIUS + OBSTACLE_RADIUS + 10):
            valid = False
        if distance(candidate, goal) < (TARGET_RADIUS + OBSTACLE_RADIUS + 20):
            valid = False
        for obs in new_obstacles:
            if distance(candidate, obs) < (2 * OBSTACLE_RADIUS + 10):
                valid = False
        if valid:
            new_obstacles.append(candidate)
    obstacles.extend(new_obstacles)
    print(f"Environment initialized with fixed seed {seed}.")

def draw_arrow(surface, color, start_pos, direction, length=ARROW_LENGTH):
    dx, dy = direction
    if dx == 0 and dy == 0:
        return
    mag = math.hypot(dx, dy)
    dx /= mag
    dy /= mag
    end_x = start_pos[0] + dx * length
    end_y = start_pos[1] + dy * length
    pygame.draw.line(surface, color, start_pos, (end_x, end_y), int(2 * SCALING_FACTOR))
    arrow_size = 7 * SCALING_FACTOR
    angle = math.atan2(dy, dx)
    arrow1_x = end_x - arrow_size * math.cos(angle + math.pi/6)
    arrow1_y = end_y - arrow_size * math.sin(angle + math.pi/6)
    arrow2_x = end_x - arrow_size * math.cos(angle - math.pi/6)
    arrow2_y = end_y - arrow_size * math.sin(angle - math.pi/6)
    pygame.draw.line(surface, color, (end_x, end_y), (arrow1_x, arrow1_y), int(2 * SCALING_FACTOR))
    pygame.draw.line(surface, color, (end_x, end_y), (arrow2_x, arrow2_y), int(2 * SCALING_FACTOR))

def render_full_view(surface, h_dir, w_dir, x_dir):
    surface.fill(WHITE)
    if ENABLE_OBSTACLES:
        for obstacle_pos in obstacles:
            pygame.draw.circle(surface, GRAY, (int(obstacle_pos[0]), int(obstacle_pos[1])), OBSTACLE_RADIUS)
    for i, target in enumerate(targets):
        pygame.draw.circle(surface, YELLOW, (int(target[0]), int(target[1])), TARGET_RADIUS)
        num_text = font.render(str(i + 1), True, BLACK)
        surface.blit(num_text, (target[0] - 5, target[1] - 12))
    curr_t = targets[current_target_idx]
    pygame.draw.circle(surface, BLACK, (int(curr_t[0]), int(curr_t[1])), TARGET_RADIUS + 2, int(2 * SCALING_FACTOR))

    now = time.time()
    while len(recent_positions) > 0 and (now - recent_positions[0][2]) > GHOST_TRAIL_DURATION:
        recent_positions.pop(0)
    if len(recent_positions) > 1:
        for idx in range(len(recent_positions) - 1):
            x1, y1, t1 = recent_positions[idx]
            x2, y2, t2 = recent_positions[idx+1]
            pygame.draw.line(surface, (200, 200, 200), (x1, y1), (x2, y2), 2)
    pygame.draw.circle(surface, BLACK, (int(dot_pos[0]), int(dot_pos[1])), DOT_RADIUS, int(2 * SCALING_FACTOR))

    # draw arrows
    if h_dir != [0,0]:
        draw_arrow(surface, BLUE, (int(dot_pos[0]), int(dot_pos[1])), h_dir, ARROW_LENGTH)
    if w_dir != [0,0]:
        draw_arrow(surface, GREEN, (int(dot_pos[0]), int(dot_pos[1])), w_dir, ARROW_LENGTH)
    if x_dir != [0,0]:
        draw_arrow(surface, RED, (int(dot_pos[0]), int(dot_pos[1])), x_dir, ARROW_LENGTH)

    g_txt = font.render(f"Gamma: {gamma:.2f}", True, FONT_COLOR)
    surface.blit(g_txt, (10, 10))
    form_txt = font.render(f"Movement = {gamma:.2f}W + {1-gamma:.2f}H", True, FONT_COLOR)
    surface.blit(form_txt, (10, 40))
    noise_txt = font.render(f"Noise σ: {NOISE_MAGNITUDE:.2f}", True, FONT_COLOR)
    surface.blit(noise_txt, (10, 100))
    instr_txt = font.render("L2/R2: gamma, [/]: noise, R: reset", True, FONT_COLOR)
    surface.blit(instr_txt, (10, 70))
    mode_txt = font.render(f"Control: {'AI' if USE_AI_CONTROL else 'Manual'} (Y to toggle)", True, FONT_COLOR)
    surface.blit(mode_txt, (10, 160))

    force_mode_txt = font.render(f"Input: {'Force Sensor' if USE_FORCE_SENSOR else 'Joystick/Keyboard'} (F to toggle)", True, FONT_COLOR)
    surface.blit(force_mode_txt, (10, 190))
    if not FORCE_SENSOR_AVAILABLE:
        no_sensor_txt = font.render("Force sensor not available!", True, RED)
        surface.blit(no_sensor_txt, (10, 220))

    elapsed_time = time.time() - last_reset_time
    timer_text = font.render(f"Time: {elapsed_time:.1f}s", True, FONT_COLOR)
    surface.blit(timer_text, (10, 130))
    if reached_goal:
        r_txt = font.render(f"Goal Reached in {elapsed_time:.1f}s! Auto-resetting...", True, FONT_COLOR)
        surface.blit(r_txt, (150, 110))
    seed_txt = font.render(f"Scenario Seed: {current_seed}", True, FONT_COLOR)
    surface.blit(seed_txt, (10, 250))

    legend_y = FULL_VIEW_SIZE[1] - int(100 * SCALING_FACTOR)
    legend_spacing = int(30 * SCALING_FACTOR)
    legend_items = [
        ("Green Arrow: Perfect Path (W)", GREEN),
        ("Blue Arrow: Human Movement (H)", BLUE),
        ("Red Arrow: Dot's Movement", RED),
        ("Gray line: Movement History", (200, 200, 200))
    ]
    for i, (lbl, color) in enumerate(legend_items):
        label = font.render(lbl, True, color)
        surface.blit(label, (10, legend_y + i*legend_spacing))

    counter_x = FULL_VIEW_SIZE[0] - 200
    counter_y = 10
    counter_spacing = int(18 * SCALING_FACTOR)
    counter_header = font.render("Results:", True, BLACK)
    surface.blit(counter_header, (counter_x, counter_y))
    counter_y += counter_spacing
    for i in range(len(targets)):
        count = goal_counters.get(i, 0)
        goal_txt = font.render(f"Goal {i+1}: {count}", True, GREEN)
        surface.blit(goal_txt, (counter_x, counter_y))
        counter_y += counter_spacing
    failure_txt = font.render(f"Failures: {failure_counter}", True, RED)
    surface.blit(failure_txt, (counter_x, counter_y))

def render_red_only(surface, x_dir):
    surface.fill(WHITE)
    if ENABLE_OBSTACLES:
        for obstacle_pos in obstacles:
            pygame.draw.circle(surface, GRAY, (int(obstacle_pos[0]), int(obstacle_pos[1])), OBSTACLE_RADIUS)
    for i, target in enumerate(targets):
        color = GREEN if i == current_target_idx else RED
        pygame.draw.circle(surface, color, (int(target[0]), int(target[1])), TARGET_RADIUS)
        if i == current_target_idx:
            pygame.draw.circle(surface, BLACK, (int(target[0]), int(target[1])), TARGET_RADIUS + 2, int(2 * SCALING_FACTOR))
        num_text = font.render(str(i + 1), True, BLACK)
        surface.blit(num_text, (target[0] - 5, target[1] - 12))
    pygame.draw.circle(surface, BLACK, (int(dot_pos[0]), int(dot_pos[1])), DOT_RADIUS, int(2 * SCALING_FACTOR))
    if x_dir != [0,0]:
        draw_arrow(surface, RED, (int(dot_pos[0]), int(dot_pos[1])), x_dir, ARROW_LENGTH)

    elapsed_time = time.time() - last_reset_time
    timer_text = font.render(f"Time: {elapsed_time:.1f}s", True, BLACK)
    surface.blit(timer_text, (10, 10))

    if USE_FORCE_SENSOR:
        mode_text = font.render("Force Sensor Mode", True, BLACK)
        surface.blit(mode_text, (10, 40))

    if reached_goal:
        completion_text = font.render("Goal Reached!", True, BLACK)
        text_rect = completion_text.get_rect(center=(RED_ONLY_SIZE[0]/2, 40))
        surface.blit(completion_text, text_rect)

    counter_x = RED_ONLY_SIZE[0] - 200
    counter_y = 10
    counter_spacing = int(18 * SCALING_FACTOR)
    counter_header = font.render("Counter:", True, BLACK)
    surface.blit(counter_header, (counter_x, counter_y))
    counter_y += counter_spacing
    for i in range(len(targets)):
        count = goal_counters.get(i, 0)
        goal_txt = font.render(f"Goal {i+1}: {count}", True, GREEN)
        surface.blit(goal_txt, (counter_x, counter_y))
        counter_y += counter_spacing
    failure_txt = font.render(f"Failures: {failure_counter}", True, RED)
    surface.blit(failure_txt, (counter_x, counter_y))

# Main experiment loop
current_seed = None

for s in SCENARIO_SEEDS:
    current_seed = s
    data_log = []
    save_filename = get_save_filename(s)

    initialize_environment_fixed(s)

    global trial_start_time, failure_counter
    trial_start_time = time.time()
    failure_counter = 0
    reset()

    running = True
    clock = pygame.time.Clock()

    while running:
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                running = False
                break
            if event.type == pygame.KEYDOWN:
                if event.key == pygame.K_f:
                    if not FORCE_SENSOR_AVAILABLE and not USE_FORCE_SENSOR:
                        print("Force sensor not available!")
                    else:
                        USE_FORCE_SENSOR = not USE_FORCE_SENSOR
                        print("Force sensor mode:", USE_FORCE_SENSOR)
                if event.key == pygame.K_LEFTBRACKET:
                    global NOISE_MAGNITUDE
                    NOISE_MAGNITUDE = max(MIN_NOISE, NOISE_MAGNITUDE - NOISE_STEP)
                elif event.key == pygame.K_RIGHTBRACKET:
                    NOISE_MAGNITUDE = min(MAX_NOISE, NOISE_MAGNITUDE + NOISE_STEP)
                if event.key == pygame.K_r:
                    global trial_outcome
                    trial_outcome = "manual_reset"
                    reset()

                if event.key == pygame.K_SPACE and USE_FORCE_SENSOR:
                    USE_AI_CONTROL = not USE_AI_CONTROL
                    print(f"{'AI' if USE_AI_CONTROL else 'Manual'} control enabled (Space Key)")

            if joystick and event.type == pygame.JOYBUTTONDOWN:
                if event.button == 2:
                    trial_outcome = "manual_reset"
                    reset()
                if event.button == 3:
                    USE_AI_CONTROL = not USE_AI_CONTROL
                    print(f"{'AI' if USE_AI_CONTROL else 'Manual'} control enabled")

            if event.type == pygame.USEREVENT:
                if not reached_goal:
                    trial_outcome = "timeout"
                reset()

        if not running:
            break

        # Check if we have 2 successes for each goal. If yes, move to next seed.
        # i.e., if all goals have 2 or more in goal_counters => break from the loop.
        all_goals_2x = all(count >= 2 for count in goal_counters.values())
        if all_goals_2x:
            print("All goals reached at least 2 times. Moving on to next environment...")
            running = False
            continue  # jump out of the while loop

        if not reached_goal:
            # Force sensor or fallback
            if USE_FORCE_SENSOR and FORCE_SENSOR_AVAILABLE:
                dx, dy = force_sensor_input
            else:
                if USE_FORCE_SENSOR and not FORCE_SENSOR_AVAILABLE:
                    USE_FORCE_SENSOR = False

                dx, dy = 0.0, 0.0
                keys = pygame.key.get_pressed()
                if keys[pygame.K_LEFT]:
                    dx -= 1
                if keys[pygame.K_RIGHT]:
                    dx += 1
                if keys[pygame.K_UP]:
                    dy -= 1
                if keys[pygame.K_DOWN]:
                    dy += 1

                if joystick:
                    axis_0 = joystick.get_axis(0)
                    axis_1 = joystick.get_axis(1)
                    deadzone = 0.1
                    if abs(axis_0) > deadzone or abs(axis_1) > deadzone:
                        dx = axis_0
                        dy = axis_1
                    else:
                        dx, dy = 0.0, 0.0
                    l2_val = joystick.get_axis(AXIS_L2)
                    r2_val = joystick.get_axis(AXIS_R2)
                    if l2_val > 0.1:
                        gamma = max(0.0, gamma - 0.01)
                    if r2_val > 0.1:
                        gamma = min(1.0, gamma + 0.01)

            if abs(dx) < 0.1 and abs(dy) < 0.1:
                dx, dy = 0.0, 0.0

            if not USE_FORCE_SENSOR:
                dx *= MAX_SPEED
                dy *= MAX_SPEED

            human_input = [dx, dy]
            proposed_idx = predict_human_target(human_input)
            current_target_idx = proposed_idx

            h_dir, w_dir, x_dir = move_dot(human_input)
            if not reached_goal:
                recent_positions.append((dot_pos[0], dot_pos[1], time.time()))
                current_trajectory.append((dot_pos[0], dot_pos[1], gamma))
        else:
            h_dir, w_dir, x_dir = [0,0], [0,0], [0,0]

        render_full_view(surface_full, h_dir, w_dir, x_dir)
        render_red_only(surface_red_only, x_dir)

        tex1 = surface_to_texture(renderer1, surface_full)
        tex2 = surface_to_texture(renderer2, surface_red_only)

        renderer1.clear()
        tex1.draw(dstrect=(0, 0, FULL_VIEW_SIZE[0], FULL_VIEW_SIZE[1]))
        renderer1.present()

        renderer2.clear()
        tex2.draw(dstrect=(0, 0, RED_ONLY_SIZE[0], RED_ONLY_SIZE[1]))
        renderer2.present()

        clock.tick(60)

    # After while loop ends for this seed (either user closed or all goals done):
    save_data_log(s)
    print(f"Finished environment seed: {s}")

# After finishing all seeds:
if ser is not None:
    ser.close()

pygame.quit()
print("All seeds completed. Exiting.")


pygame 2.6.0 (SDL 2.28.4, Python 3.10.9)
Hello from the pygame community. https://www.pygame.org/contribute.html
Error: Could not open serial port: module 'serial' has no attribute 'Serial'
Force sensor not available. Using keyboard/joystick controls only.
Joystick initialized: DualSense Wireless Controller


Exception: Can't get attribute '_make_function' on <module 'cloudpickle.cloudpickle' from 'c:\\Users\\mhfar\\anaconda3\\lib\\site-packages\\cloudpickle\\cloudpickle.py'>
Exception: Can't get attribute '_make_function' on <module 'cloudpickle.cloudpickle' from 'c:\\Users\\mhfar\\anaconda3\\lib\\site-packages\\cloudpickle\\cloudpickle.py'>
Exception: Can't get attribute '_make_function' on <module 'cloudpickle.cloudpickle' from 'c:\\Users\\mhfar\\anaconda3\\lib\\site-packages\\cloudpickle\\cloudpickle.py'>


Environment initialized with fixed seed 0.
AI control enabled
Trial recorded: {'timestamp': '2025-03-17 01:12:57', 'mode': 'AI', 'trial_duration': 4.58158802986145, 'trial_outcome': 'success', 'goal_reached': (517.9239668585399, 247.31394185221012), 'trajectory': [(600.0, 400.0, 0.95), (600.0, 400.0, 0.95), (600.0, 400.0, 0.95), (600.0, 400.0, 0.95), (600.0, 400.0, 0.95), (600.0, 400.0, 0.95), (600.0, 400.0, 0.95), (600.0, 400.0, 0.95), (600.0, 400.0, 0.95), (600.0, 400.0, 0.95), (600.0, 400.0, 0.95), (600.0, 400.0, 0.95), (600.0, 400.0, 0.95), (600.0, 400.0, 0.95), (600.0, 400.0, 0.95), (600.0, 400.0, 0.95), (600.0, 400.0, 0.95), (600.0, 400.0, 0.95), (600.0, 400.0, 0.95), (600.0, 400.0, 0.95), (600.0, 400.0, 0.95), (600.0, 400.0, 0.95), (600.0, 400.0, 0.95), (600.0, 400.0, 0.95), (600.0, 400.0, 0.95), (600.0, 400.0, 0.95), (600.0, 400.0, 0.95), (600.0, 400.0, 0.95), (600.0, 400.0, 0.95), (600.0, 400.0, 0.95), (600.0, 400.0, 0.95), (600.0, 400.0, 0.95), (600.0, 400.0, 0.95), (600.0, 4

IndexError: list index out of range