In [1]:
import os
import datetime
import pygame
import random
import math
import time
import numpy as np
import json
import sys
from stable_baselines3 import PPO
from pygame._sdl2.video import Window, Renderer, Texture

# -----------------------------------------------------------
# User Study Configuration
# -----------------------------------------------------------

# Goals and scenarios configuration
# Every goal must be reached this many times before the environment counts as complete.
REQUIRED_SUCCESSES_PER_GOAL = 7
ENVIRONMENT_SEEDS = [1, 2, 58, 487]  # Seeds to cycle through
CURRENT_SEED_INDEX = 0  # Index into ENVIRONMENT_SEEDS of the active environment

# Data organization
BASE_DATA_FOLDER = "user_study_data"  # Root folder for all recorded output
# Captured once at import time; names this run's session folder and error log.
SESSION_ID = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')

# Tracking variables
goal_completion_counts = {}  # Will track successes per goal for current environment

# -----------------------------------------------------------
# Force Sensor Integration (from your original code)
# -----------------------------------------------------------

import tkinter as tk
from tkinter import ttk
import serial as pyserial
import threading
from collections import deque

# Global flag to choose input mode (False: PS5/keyboard, True: force sensor)
USE_FORCE_SENSOR = False  # You can toggle this with the F key

# Global variable for force sensor horizontal input (fx, fy)
# Rebound by the serial reader thread and read by the main loop.
force_sensor_input = [0.0, 0.0]

# Global flag to indicate if force sensor is available
FORCE_SENSOR_AVAILABLE = False

# Force sensor scaling constants
# Applied after calibration, as the last step before publishing the input.
FORCE_SENSOR_SCALE_X = 5.0
FORCE_SENSOR_SCALE_Y = 5.0

# Calibration constant
XY_FORCE_CAL = 1.33  # Multiplies the smoothed fx/fy readings
# Exponential smoothing factor; despite the "Z" in the name it is applied to fx/fy.
Z_SMOOTH_ALPHA = 0.2

# Global variables for serial reading and smoothing
sample_prev_fs = None  # Second most recent parsed sample, None until data arrives
sample_curr_fs = None  # Most recent parsed sample, None until data arrives
sample_lock_fs = threading.Lock()  # Guards the two sample slots above
last_fx_smooth = None  # Previous smoothed fx, None before the first sample
last_fy_smooth = None  # Previous smoothed fy, None before the first sample

# Try to open the serial port
# NOTE(review): the port name 'COM5' and baud rate are hard-coded here.
try:
    ser = pyserial.Serial('COM5', 115200, timeout=0.01)
    FORCE_SENSOR_AVAILABLE = True
    print("Force sensor connected successfully.")
except Exception as e:
    # Any failure to open the port downgrades to keyboard/joystick input.
    print("Error: Could not open serial port:", e)
    ser = None
    FORCE_SENSOR_AVAILABLE = False
    print("Force sensor not available. Using keyboard/joystick controls only.")

def serial_reader_fs():
    """Continuously convert serial force-sensor lines into game input.

    Loops forever, so it is meant to run on a daemon thread. A usable line
    holds exactly eight comma-separated floats: four weights, fx, fy,
    magnitude and angle. The horizontal components are interpolated between
    the two most recent samples, exponentially smoothed, scaled, and then
    published by rebinding the global ``force_sensor_input`` to ``[fx, fy]``.
    Malformed lines are skipped; other errors are printed and the loop
    continues after a short pause.
    """
    global sample_prev_fs, sample_curr_fs, last_fx_smooth, last_fy_smooth, force_sensor_input

    while True:
        # Without an open port there is nothing to read; idle instead of spinning.
        if ser is None:
            time.sleep(0.1)
            continue
        try:
            raw_line = ser.readline().decode('utf-8').strip()
            if not raw_line:
                continue
            fields = raw_line.split(',')
            if len(fields) != 8:
                continue
            try:
                numbers = [float(field) for field in fields]
            except Exception:
                continue

            reading = {
                'weights': numbers[0:4],
                'fx': numbers[4],
                'fy': numbers[5],
                'magnitude': numbers[6],
                'angle': numbers[7],
                'timestamp': time.time()
            }

            # Shift the sample window; the very first sample fills both slots.
            with sample_lock_fs:
                if sample_curr_fs is None:
                    sample_prev_fs = reading
                else:
                    sample_prev_fs = sample_curr_fs
                sample_curr_fs = reading

            # Take a consistent snapshot of the two samples for interpolation.
            with sample_lock_fs:
                older, newer = sample_prev_fs, sample_curr_fs
            now = time.time()
            if older is None or newer is None:
                continue

            t_old, t_new = older['timestamp'], newer['timestamp']
            if t_new == t_old:
                blend = 1.0
            else:
                blend = (now - t_old) / (t_new - t_old)
            blend = max(0.0, min(1.0, blend))
            fx_interp = older['fx'] * (1 - blend) + newer['fx'] * blend
            fy_interp = older['fy'] * (1 - blend) + newer['fy'] * blend

            # Exponential smoothing; the first value seeds the filter unchanged.
            smoothed_fx = (
                fx_interp if last_fx_smooth is None
                else last_fx_smooth + Z_SMOOTH_ALPHA * (fx_interp - last_fx_smooth)
            )
            smoothed_fy = (
                fy_interp if last_fy_smooth is None
                else last_fy_smooth + Z_SMOOTH_ALPHA * (fy_interp - last_fy_smooth)
            )
            last_fx_smooth, last_fy_smooth = smoothed_fx, smoothed_fy

            # Calibration first, then the game-environment scale.
            force_sensor_input = [
                smoothed_fx * XY_FORCE_CAL * FORCE_SENSOR_SCALE_X,
                smoothed_fy * XY_FORCE_CAL * FORCE_SENSOR_SCALE_Y,
            ]
        except Exception as e:
            print("Serial reader error:", e)
            time.sleep(0.01)

# Start the force sensor serial reader thread if sensor is available
# The thread is a daemon, so it does not keep the process alive at exit.
if FORCE_SENSOR_AVAILABLE:
    threading.Thread(target=serial_reader_fs, daemon=True).start()

# -----------------------------------------------------------
# Game initialization
# -----------------------------------------------------------

pygame.init()
pygame.joystick.init()

# When True, move_dot() recomputes gamma every step from goal/obstacle distances.
USE_AI_CONTROL = False

# Config / Constants
FULL_VIEW_SIZE = (1200, 800)
RED_ONLY_SIZE  = (1200, 800)

# Standard deviation of the Gaussian noise added to the human direction in move_dot().
# NOTE(review): the initial value 2.5 is above MAX_NOISE (2.0) — confirm this is intended.
NOISE_MAGNITUDE = 2.5
MIN_NOISE = 0.0
MAX_NOISE = 2.0
NOISE_STEP = 0.1

# Sizes below are expressed relative to an older 600x600 layout and scaled up.
OLD_WINDOW_SIZE   = (600, 600)
SCALING_FACTOR_X  = FULL_VIEW_SIZE[0] / OLD_WINDOW_SIZE[0]
SCALING_FACTOR_Y  = FULL_VIEW_SIZE[1] / OLD_WINDOW_SIZE[1]
SCALING_FACTOR    = (SCALING_FACTOR_X + SCALING_FACTOR_Y) / 2

# RGB colour tuples
WHITE  = (255, 255, 255)
BLACK  = (0, 0, 0)
RED    = (255, 0, 0)
GREEN  = (0, 200, 0)
BLUE   = (0, 0, 255)
YELLOW = (255, 255, 0)
GRAY   = (128, 128, 128)

FONT_COLOR = (0, 0, 0)
FONT_SIZE = int(16 * SCALING_FACTOR)
ARROW_LENGTH = int(60 * SCALING_FACTOR)

OBSTACLE_RADIUS      = int(10 * SCALING_FACTOR)
COLLISION_BUFFER     = int(5 * SCALING_FACTOR)  # Extra margin used by check_collision()
ENABLE_OBSTACLES     = True
MAX_SPEED            = 3 * SCALING_FACTOR  # Largest step move_dot() takes per call

DOT_RADIUS            = int(15 * SCALING_FACTOR)
TARGET_RADIUS         = int(10 * SCALING_FACTOR)
# The goal counts as reached once the dot centre is closer than this to the target.
GOAL_DETECTION_RADIUS = DOT_RADIUS + TARGET_RADIUS

GHOST_TRAIL_DURATION  = 3.0
recent_positions      = []  # (x, y, timestamp) samples appended by the main loop
last_reset_time       = time.time()

RECENT_DIR_LOOKBACK   = 1.0  # Seconds of history used by get_recent_direction()
GOAL_SWITCH_THRESHOLD = 0.05

WINDOW_CENTER = (FULL_VIEW_SIZE[0] // 2, FULL_VIEW_SIZE[1] // 2)
START_POS = [WINDOW_CENTER[0], WINDOW_CENTER[1]]
dot_pos   = START_POS.copy()

gamma         = 0.2  # Blend weight: share of each step taken along the assist direction
reached_goal  = False
targets       = []
current_target_idx = 0
obstacles     = []

# Counters for goal completions and failures
goal_counters = {}  # Will track completions per goal
failure_counter = 0  # Combined counter for collisions and manual resets

# Decide if we skip "recent direction" in target detection
USE_RAW_ONLY_FOR_GOAL_DETECTION = True

# Joystick
joystick = None
if pygame.joystick.get_count() > 0:
    # Only the first detected joystick is used.
    joystick = pygame.joystick.Joystick(0)
    joystick.init()
    print("Joystick initialized:", joystick.get_name())
else:
    print("No joystick detected.")

# Joystick axis indices read in the main loop to lower (L2) / raise (R2) gamma.
AXIS_L2 = 4
AXIS_R2 = 5

# -----------------------------------------------------------
# Windows and Renderers
# -----------------------------------------------------------
# Two separate SDL2 windows, each with its own vsync-enabled renderer.
window1 = Window("2D Environment: Full View", size=FULL_VIEW_SIZE)
renderer1 = Renderer(window1, vsync=True)

window2 = Window("2D Environment: Red Arrow Only", size=RED_ONLY_SIZE)
renderer2 = Renderer(window2, vsync=True)

# -----------------------------------------------------------
# Surfaces
# -----------------------------------------------------------
def create_compatible_surface(size):
    """Return a new pygame Surface of ``size`` created with the SRCALPHA flag."""
    surface = pygame.Surface(size, flags=pygame.SRCALPHA)
    return surface

# Off-screen drawing targets, one per window; uploaded as textures each frame.
surface_full = create_compatible_surface(FULL_VIEW_SIZE)
surface_red_only = create_compatible_surface(RED_ONLY_SIZE)

# Load font
# Passing None selects pygame's built-in default font.
font = pygame.font.Font(None, FONT_SIZE)

def surface_to_texture(renderer, surf):
    """Create a Texture on ``renderer`` from ``surf``.

    Surfaces that are not 32 bits per pixel are passed through
    ``convert_alpha()`` first.
    """
    source = surf if surf.get_bitsize() == 32 else surf.convert_alpha()
    return Texture.from_surface(renderer, source)

# -----------------------------------------------------------
# Helper functions
# -----------------------------------------------------------
def distance(pos1, pos2):
    """Return the Euclidean distance between two 2D points."""
    delta_x = pos1[0] - pos2[0]
    delta_y = pos1[1] - pos2[1]
    return math.hypot(delta_x, delta_y)

def line_circle_intersection(start, end, circle_center, radius):
    """Return True if the segment start->end passes within ``radius`` of the centre.

    The closest point on the segment is found by projecting the centre onto
    the segment and clamping the projection to the segment's end points. A
    zero-length segment is treated as the single point ``start``.
    """
    seg_x = end[0] - start[0]
    seg_y = end[1] - start[1]
    seg_len_sq = seg_x * seg_x + seg_y * seg_y

    if seg_len_sq == 0:
        closest = start
    else:
        rel_x = circle_center[0] - start[0]
        rel_y = circle_center[1] - start[1]
        t = max(0, min(1, (rel_x * seg_x + rel_y * seg_y) / seg_len_sq))
        closest = (start[0] + t * seg_x, start[1] + t * seg_y)

    return distance(closest, circle_center) <= radius

def check_collision(pos, new_pos):
    """Return True if the move pos -> new_pos crosses any obstacle.

    Each obstacle is treated as a circle of OBSTACLE_RADIUS plus
    COLLISION_BUFFER. Always False when obstacles are disabled.
    """
    if not ENABLE_OBSTACLES:
        return False
    hit_radius = OBSTACLE_RADIUS + COLLISION_BUFFER
    return any(
        line_circle_intersection(pos, new_pos, center, hit_radius)
        for center in obstacles
    )

def inside_obstacle(pos):
    """Return True if a dot centred at ``pos`` touches or overlaps any obstacle.

    Contact means the centre distance is at most OBSTACLE_RADIUS + DOT_RADIUS.
    Always False when obstacles are disabled.
    """
    if not ENABLE_OBSTACLES:
        return False
    contact_distance = OBSTACLE_RADIUS + DOT_RADIUS
    return any(distance(pos, center) <= contact_distance for center in obstacles)

def get_recent_direction():
    """Return the unit direction of travel over the last RECENT_DIR_LOOKBACK seconds.

    Walks ``recent_positions`` from newest to oldest, stopping at the first
    sample older than the look-back window, then measures displacement
    between the oldest and newest samples kept. Returns ``[0, 0]`` when
    there are fewer than two usable samples, when they span under 1 ms, or
    when there was no net movement.
    """
    if len(recent_positions) < 2:
        return [0, 0]

    now = time.time()
    window = []
    for sample in reversed(recent_positions):
        x, y, stamp = sample
        if now - stamp > RECENT_DIR_LOOKBACK:
            break
        window.append((x, y, stamp))

    if len(window) < 2:
        return [0, 0]

    ordered = sorted(window, key=lambda point: point[2])
    x_old, y_old, t_old = ordered[0]
    x_new, y_new, t_new = ordered[-1]

    elapsed = t_new - t_old
    if elapsed < 0.001:
        return [0, 0]

    vx = (x_new - x_old) / elapsed
    vy = (y_new - y_old) / elapsed
    speed = math.hypot(vx, vy)
    if speed > 0:
        return [vx / speed, vy / speed]
    return [0, 0]

# -----------------------------------------------------------
# Potential-field approach for "perfect movement"
# -----------------------------------------------------------
def compute_perfect_direction(dot_pos, goal_pos, obstacles):
    """Return a unit vector steering from ``dot_pos`` towards ``goal_pos``.

    The result is the unit goal direction plus an inverse-square repulsion
    from every obstacle closer than the repulsion radius, re-normalised.
    Returns ``[0, 0]`` when the dot is already at the goal or when
    attraction and repulsion cancel out.
    """
    to_goal_x = goal_pos[0] - dot_pos[0]
    to_goal_y = goal_pos[1] - dot_pos[1]
    goal_dist = math.hypot(to_goal_x, to_goal_y)
    if goal_dist < 1e-6:
        return [0, 0]
    goal_dir = [to_goal_x / goal_dist, to_goal_y / goal_dist]

    repulsion_radius = 27 * SCALING_FACTOR
    repulsion_gain = 30000.0
    push_x = 0.0
    push_y = 0.0

    for obs in obstacles:
        away_x = dot_pos[0] - obs[0]
        away_y = dot_pos[1] - obs[1]
        gap = math.hypot(away_x, away_y)
        # Ignore obstacles exactly under the dot (no defined direction)
        # and those outside the repulsion radius.
        if not (1e-6 <= gap < repulsion_radius):
            continue
        unit_x = away_x / gap
        unit_y = away_y / gap
        strength = repulsion_gain / (gap**2)
        push_x += unit_x * strength
        push_y += unit_y * strength

    combined_x = goal_dir[0] + push_x
    combined_y = goal_dir[1] + push_y
    length = math.hypot(combined_x, combined_y)
    if length < 1e-6:
        return [0, 0]
    return [combined_x / length, combined_y / length]

# -----------------------------------------------------------
# PPO Model for gamma
# -----------------------------------------------------------
class GammaPredictor:
    """Wrap a trained PPO policy that predicts the blending factor gamma.

    If the model file cannot be loaded, the predictor keeps working and
    returns a constant fallback gamma instead.
    """

    # Returned by predict_gamma() when no model could be loaded.
    FALLBACK_GAMMA = 0.2

    def __init__(self, model_path="gamma_ppo_model.zip"):
        """Try to load the PPO model from ``model_path``.

        A load failure is reported and leaves ``self.model`` as None.
        """
        try:
            self.model = PPO.load(model_path)
        except Exception as exc:
            # Catch Exception rather than using a bare except so that
            # KeyboardInterrupt/SystemExit still propagate, and report the
            # reason instead of failing silently.
            print(f"Could not load PPO model '{model_path}': {exc}. Using fallback gamma.")
            self.model = None
        # Window diagonal, used to normalise the distance to the target.
        self.max_dist = np.sqrt(FULL_VIEW_SIZE[0]**2 + FULL_VIEW_SIZE[1]**2)

    def prepare_observation(self, dot_pos, target_pos, human_input):
        """Build the 10-element float32 observation vector for the policy.

        Layout: dot position (2), unit human direction (2), target position
        (2), unit direction to the target (2), normalised target distance
        (1), obstacle distance ratio (1, always 1.0 here).
        """
        dot_pos = np.array(dot_pos, dtype=np.float32)
        target_pos = np.array(target_pos, dtype=np.float32)
        # Convert explicitly so lists, tuples and arrays are all handled the same way.
        human_input = np.asarray(human_input)

        to_target = target_pos - dot_pos
        dist = np.linalg.norm(to_target)
        perfect_dir = to_target / dist if dist > 0 else np.array([0, 0], dtype=np.float32)

        h_mag = np.linalg.norm(human_input)
        human_dir = human_input / h_mag if h_mag > 0 else np.array([0, 0], dtype=np.float32)

        normalized_dist = dist / self.max_dist
        # Since obstacle info is not provided here, default obs_dist_ratio to 1.0:
        obs_dist_ratio = 1.0
        obs = np.concatenate([
            dot_pos,
            human_dir,
            target_pos,
            perfect_dir,
            [normalized_dist],
            [obs_dist_ratio]
        ]).astype(np.float32)
        return obs

    def predict_gamma(self, dot_pos, target_pos, human_input):
        """Return the policy's gamma as a float, or the fallback without a model."""
        if self.model is None:
            return self.FALLBACK_GAMMA
        obs = self.prepare_observation(dot_pos, target_pos, human_input)
        obs_batched = obs[np.newaxis, :]
        action, _ = self.model.predict(obs_batched, deterministic=True)
        # The observation is batched, so the action comes back with at least
        # one extra dimension. Flatten before converting: calling float() on
        # an array with ndim > 0 is deprecated in NumPy.
        return float(np.asarray(action, dtype=np.float64).ravel()[0])

# Module-level predictor built at import time with the default model path.
# NOTE(review): no visible code in this file calls gamma_predictor — confirm it is still needed.
gamma_predictor = GammaPredictor()

# -----------------------------------------------------------
# Predict human target function
# -----------------------------------------------------------
def predict_human_target(human_input):
    """Guess which target the user is heading for and return its index.

    The current target is kept while the dot is within twice the goal
    detection radius of it, or when there is no input. Otherwise every
    target is scored by how well its direction aligns with the input
    direction (weight 0.8) and the recent travel direction (weight 0.2,
    zeroed when USE_RAW_ONLY_FOR_GOAL_DETECTION is set). The first
    best-scoring target wins.
    """
    # Near the current target: do not switch away from it.
    if distance(dot_pos, targets[current_target_idx]) < GOAL_DETECTION_RADIUS * 2:
        return current_target_idx

    # No input gives no information about intent.
    if human_input[0] == 0 and human_input[1] == 0:
        return current_target_idx

    input_mag = math.hypot(human_input[0], human_input[1])
    if input_mag > 0:
        heading = [human_input[0] / input_mag, human_input[1] / input_mag]
    else:
        heading = [0, 0]

    # ignore recent direction when only the raw input should count
    recent_dir = [0, 0] if USE_RAW_ONLY_FOR_GOAL_DETECTION else get_recent_direction()

    chosen = current_target_idx
    top_score = float('-inf')
    for idx, candidate in enumerate(targets):
        off_x = candidate[0] - dot_pos[0]
        off_y = candidate[1] - dot_pos[1]
        span = math.hypot(off_x, off_y)
        if span == 0:
            continue
        unit = [off_x / span, off_y / span]
        human_alignment = heading[0] * unit[0] + heading[1] * unit[1]
        recent_alignment = recent_dir[0] * unit[0] + recent_dir[1] * unit[1]
        score = (human_alignment * 0.8) + (recent_alignment * 0.2)
        if score > top_score:
            top_score = score
            chosen = idx

    return chosen

# -----------------------------------------------------------
# Data Logging Functions
# -----------------------------------------------------------
data_log = []  # Trial records appended by reset(); cleared after each save
trial_start_time = time.time()  # Wall-clock start of the current trial
current_trajectory = []  # (x, y) dot positions recorded during the current trial
trial_outcome = None  # "success", "collision", "manual_reset", "timeout", or None while running

def setup_data_folders():
    """Create this session's output directory tree and return its root path.

    Layout under ``BASE_DATA_FOLDER/SESSION_ID``: one ``env_seed_<seed>``
    folder per environment seed beneath both ``ai_control`` and
    ``manual_control``, plus a ``summary`` folder.
    """
    session_folder = os.path.join(BASE_DATA_FOLDER, SESSION_ID)
    mode_folders = (
        os.path.join(session_folder, "ai_control"),
        os.path.join(session_folder, "manual_control"),
    )

    for seed in ENVIRONMENT_SEEDS:
        seed_dir = f"env_seed_{seed}"
        for mode_folder in mode_folders:
            os.makedirs(os.path.join(mode_folder, seed_dir), exist_ok=True)

    # Summary statistics live next to the mode folders.
    os.makedirs(os.path.join(session_folder, "summary"), exist_ok=True)

    print(f"Data folders created at {session_folder}")
    return session_folder

def save_trajectory_data():
    """Write the logged trials to disk and fold them into the session summary.

    Does nothing when ``data_log`` is empty. Otherwise:

    1. Dumps ``data_log`` to a timestamped ``trajectory_*.json`` file in the
       folder for the current control mode and environment seed.
    2. Loads ``summary/session_summary.json`` if it exists, adds the
       per-outcome counts for these trials, and writes it back.
    3. Clears ``data_log``.

    The target folders are created on demand, and an existing trajectory
    file is never overwritten.

    NOTE(review): the folder is picked from the control mode active at save
    time, even if the logged trials were recorded in a different mode.
    """
    global data_log

    if not data_log:
        return

    # Current control mode and seed decide where the data goes.
    mode_key = "ai_control" if USE_AI_CONTROL else "manual_control"
    current_seed = ENVIRONMENT_SEEDS[CURRENT_SEED_INDEX]
    env_key = f"env_seed_{current_seed}"

    session_folder = os.path.join(BASE_DATA_FOLDER, SESSION_ID)
    environment_folder = os.path.join(session_folder, mode_key, env_key)
    summary_folder = os.path.join(session_folder, "summary")

    # Create the folders here too, so a save cannot fail just because
    # setup_data_folders() has not run or a folder was removed.
    os.makedirs(environment_folder, exist_ok=True)
    os.makedirs(summary_folder, exist_ok=True)

    # The timestamp only has one-second resolution; add a numeric suffix
    # rather than overwrite a file saved earlier in the same second.
    timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
    base_name = f"trajectory_{timestamp}"
    filename = os.path.join(environment_folder, f"{base_name}.json")
    duplicate = 1
    while os.path.exists(filename):
        filename = os.path.join(environment_folder, f"{base_name}_{duplicate}.json")
        duplicate += 1

    with open(filename, "w") as f:
        json.dump(data_log, f, indent=2)

    print(f"Trajectory data saved to {filename}")

    # Create or update the summary data.
    summary_file = os.path.join(summary_folder, "session_summary.json")
    summary_data = {}
    if os.path.exists(summary_file):
        with open(summary_file, "r") as f:
            summary_data = json.load(f)

    stats = summary_data.setdefault(env_key, {}).setdefault(mode_key, {
        "total_trials": 0,
        "successful_trials": 0,
        "collisions": 0,
        "manual_resets": 0,
        "goal_completions": {}
    })

    for trial in data_log:
        stats["total_trials"] += 1
        outcome = trial["trial_outcome"]

        if outcome == "success":
            stats["successful_trials"] += 1
            # Track goal completions, keyed by the stringified goal.
            goal = str(trial.get("goal_reached", "unknown"))
            completions = stats["goal_completions"]
            completions[goal] = completions.get(goal, 0) + 1
        elif outcome == "collision":
            stats["collisions"] += 1
        elif outcome == "manual_reset":
            stats["manual_resets"] += 1

    with open(summary_file, "w") as f:
        json.dump(summary_data, f, indent=2)

    # Clear data log for next set of trials
    data_log.clear()

# -----------------------------------------------------------
# Movement functions
# -----------------------------------------------------------
def move_dot(human_input):
    """Advance the dot one step, blending noisy human input with the assist direction.

    ``human_input`` is an ``(dx, dy)`` pair. Its magnitude, capped at
    MAX_SPEED, sets the step size. The step is ``gamma`` times the assist
    direction plus ``1 - gamma`` times the human direction perturbed by
    Gaussian noise. Under AI control with non-zero input, gamma is first
    recomputed from the distances to the target and the nearest obstacle.

    Side effects: may move ``dot_pos``; touching an obstacle records a
    collision and calls ``reset()``; reaching the goal sets
    ``reached_goal``, updates the completion counters and arms a 1 s
    USEREVENT timer.

    Returns ``(h_dir, w_dir, x_dir)``: the human, assist and resulting unit
    directions, or three zero vectors after a collision reset.
    """
    global dot_pos, gamma, reached_goal, current_target_idx, USE_AI_CONTROL, trial_outcome
    global goal_completion_counts, failure_counter

    h_dx, h_dy = human_input
    h_mag = math.hypot(h_dx, h_dy)
    h_dir = [h_dx / h_mag, h_dy / h_mag] if h_mag > 0 else [0, 0]

    target_pos = targets[current_target_idx]
    w_dir = compute_perfect_direction(dot_pos, target_pos, obstacles)

    # Step length scales with input strength, capped at MAX_SPEED.
    input_mag = min(max(h_mag / MAX_SPEED, 0), 1)
    step_size = MAX_SPEED * input_mag

    # Distances that drive the gamma heuristic below.
    dist_to_target = distance(dot_pos, target_pos)
    min_obs_distance = min(
        (distance(dot_pos, obs) for obs in obstacles),
        default=float('inf'),
    )

    if USE_AI_CONTROL and h_mag > 0:
        goal_threshold = GOAL_DETECTION_RADIUS * 3
        obs_threshold = (OBSTACLE_RADIUS + DOT_RADIUS) * 3
        near_goal = dist_to_target < goal_threshold
        near_obstacle = min_obs_distance < obs_threshold

        # Low assistance far from everything, more near goals and obstacles.
        base_gamma = 0.2
        if near_goal:
            # Factor is 1 at the goal and 0 at the threshold (max 0.7).
            goal_factor = 1.0 - (dist_to_target / goal_threshold)
            base_gamma = max(base_gamma, 0.2 + 0.5 * goal_factor)
        if near_obstacle:
            # Factor is 1 at the obstacle and 0 at the threshold (max 0.9).
            obs_factor = 1.0 - (min_obs_distance / obs_threshold)
            base_gamma = max(base_gamma, 0.4 + 0.5 * obs_factor)
        if near_goal and near_obstacle:
            # Near both: never drop below 0.7.
            base_gamma = max(base_gamma, 0.7)

        # Small random jitter, then clamp to [0, 1].
        jitter = random.uniform(-0.05, 0.05)
        gamma = max(0.0, min(1.0, base_gamma + jitter))

    w_move_x = gamma * w_dir[0] * step_size
    w_move_y = gamma * w_dir[1] * step_size

    h_move_x, h_move_y = 0, 0
    if h_mag > 0:
        # Perturb the human direction, then re-normalise it.
        jitter_x = np.random.normal(0, NOISE_MAGNITUDE)
        jitter_y = np.random.normal(0, NOISE_MAGNITUDE)
        noisy_dx = h_dir[0] + jitter_x
        noisy_dy = h_dir[1] + jitter_y
        noisy_mag = math.hypot(noisy_dx, noisy_dy)
        if noisy_mag > 0:
            noisy_dx /= noisy_mag
            noisy_dy /= noisy_mag
        h_move_x = (1 - gamma) * noisy_dx * step_size
        h_move_y = (1 - gamma) * noisy_dy * step_size

    final_dx = w_move_x + h_move_x
    final_dy = w_move_y + h_move_y
    candidate = [dot_pos[0] + final_dx, dot_pos[1] + final_dy]

    # Commit the move only if the path is clear, keeping the dot on screen.
    if not check_collision(dot_pos, candidate):
        dot_pos[0] = max(0, min(FULL_VIEW_SIZE[0], candidate[0]))
        dot_pos[1] = max(0, min(FULL_VIEW_SIZE[1], candidate[1]))

    if inside_obstacle(dot_pos):
        print("Collision with obstacle -> resetting!")
        trial_outcome = "collision"
        failure_counter += 1
        reset()
        return [0,0], [0,0], [0,0]

    final_mag = math.hypot(final_dx, final_dy)
    if final_mag > 0:
        x_dir = [final_dx / final_mag, final_dy / final_mag]
    else:
        x_dir = [0, 0]

    if distance(dot_pos, target_pos) < GOAL_DETECTION_RADIUS:
        reached_goal = True
        trial_outcome = "success"

        # General per-goal counter (for display).
        goal_counters[current_target_idx] = goal_counters.get(current_target_idx, 0) + 1

        # User-study completion tracker, keyed by the goal index as a string.
        goal_key = str(current_target_idx)
        goal_completion_counts[goal_key] = goal_completion_counts.get(goal_key, 0) + 1

        print(f"Goal {current_target_idx} completed {goal_completion_counts.get(goal_key, 0)}/{REQUIRED_SUCCESSES_PER_GOAL} times")

        # Delay before reset
        pygame.time.set_timer(pygame.USEREVENT, 1000)

    return h_dir, w_dir, x_dir

# -----------------------------------------------------------
# Reset function with improved data management
# -----------------------------------------------------------
def reset():
    """Record the finished trial (if any) and return the game to its start state.

    A trial is recorded only when at least one trajectory point was
    collected. A missing ``trial_outcome`` is treated as "manual_reset",
    and that same resolved value is used for the failure counter, the
    stored record and the log line. After recording, the environment is
    advanced if every goal has been completed often enough, then the dot,
    target index, gamma, trajectory buffers and timers are reset.
    """
    global dot_pos, reached_goal, current_target_idx, gamma
    global recent_positions, last_reset_time, trial_start_time, current_trajectory, trial_outcome
    global failure_counter, data_log

    if trial_start_time is not None and len(current_trajectory) > 0:
        trial_duration = time.time() - trial_start_time
        goal_reached = targets[current_target_idx] if reached_goal else None
        mode = "AI" if USE_AI_CONTROL else "Manual"
        trial_timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")

        # Resolve the outcome once, so the counter, the record and the log
        # message cannot disagree (previously an unset outcome was stored as
        # "manual_reset" but printed as None and left uncounted).
        outcome = trial_outcome if trial_outcome is not None else "manual_reset"

        # Collisions are counted in move_dot(); manual resets are counted here.
        if outcome == "manual_reset":
            failure_counter += 1

        data_log.append({
            "timestamp": trial_timestamp,
            "mode": mode,
            "trial_duration": trial_duration,
            "trial_outcome": outcome,
            "goal_reached": goal_reached,
            "target_index": current_target_idx,
            "trajectory": current_trajectory.copy()
        })

        print(f"Trial recorded: {outcome} for goal {current_target_idx}")

    # Check if we need to change environment
    check_environment_completion()

    dot_pos = START_POS.copy()
    reached_goal = False
    current_target_idx = 0
    gamma = 0.2
    recent_positions.clear()
    last_reset_time = time.time()
    trial_start_time = time.time()
    current_trajectory.clear()
    trial_outcome = None
    # An interval of 0 cancels the pending post-goal reset timer.
    pygame.time.set_timer(pygame.USEREVENT, 0)

# -----------------------------------------------------------
# Environment management functions
# -----------------------------------------------------------
def check_environment_completion():
    """Switch to the next environment once every goal has enough successes.

    Returns True if the environment was changed, False otherwise (including
    when no targets exist yet).
    """
    if not targets:
        return False

    every_goal_done = all(
        goal_completion_counts.get(str(index), 0) >= REQUIRED_SUCCESSES_PER_GOAL
        for index in range(len(targets))
    )
    enough_goals_tracked = len(goal_completion_counts) >= len(targets)

    if every_goal_done and enough_goals_tracked:
        change_environment()
        return True
    return False

def change_environment():
    """Change to the next environment by updating the seed.

    Saves the pending trial data, advances CURRENT_SEED_INDEX (wrapping
    around to the first seed after the last), clears the per-goal
    completion counts, initializes the environment for the new seed, and
    announces the change on the console and via the on-screen message timer.
    """
    global CURRENT_SEED_INDEX, ENVIRONMENT_SEEDS, goal_completion_counts, data_log

    # Save current trajectory data before changing environment
    # (must run before the index changes: the save path depends on the current seed)
    save_trajectory_data()

    # Move to next seed
    CURRENT_SEED_INDEX = (CURRENT_SEED_INDEX + 1) % len(ENVIRONMENT_SEEDS)
    new_seed = ENVIRONMENT_SEEDS[CURRENT_SEED_INDEX]

    # Reset goal completion counts for new environment
    goal_completion_counts = {}

    # Initialize the new environment
    # NOTE(review): initialize_environment_fixed is not defined in the visible
    # part of this file — confirm it exists at module level.
    initialize_environment_fixed(new_seed)

    # Display message about environment change
    print(f"\n========================================")
    print(f"ENVIRONMENT CHANGED TO SEED {new_seed}")
    print(f"All goals completed {REQUIRED_SUCCESSES_PER_GOAL} times!")
    print(f"========================================\n")

    # Show a message on screen for a few seconds
    display_environment_change_message(new_seed)

def display_environment_change_message(seed):
    """Start the environment-change notice by recording the current time.

    ``seed`` is accepted but not used here; only the global
    ``environment_change_message_time`` is set, which the main loop clears
    again once the message duration has elapsed.
    """
    global environment_change_message_time
    started_at = time.time()
    environment_change_message_time = started_at
    
# Environment change message timer
# None means no message is active; otherwise the time.time() value at which it started.
environment_change_message_time = None
ENVIRONMENT_MESSAGE_DURATION = 5.0  # Show message for 5 seconds

# -----------------------------------------------------------
# Main function and execution
# -----------------------------------------------------------
def main():
    """Run the user-study game loop until a quit event arrives.

    Sets up the data folders and the first environment, then loops at up
    to 60 FPS: handle events, read the force sensor or keyboard/joystick,
    predict the intended target, move the dot, record the trajectory and
    draw both windows. Pending data is saved on quit, and the serial port
    and pygame are shut down on normal exit.

    Controls: F toggles force-sensor input, [ and ] lower/raise the noise
    magnitude, R or joystick button 2 resets the trial, joystick button 3
    toggles AI control, Ctrl+S saves the logged data, L2/R2 lower/raise gamma.
    """
    global running, environment_change_message_time
    # These module-level settings are reassigned below. Without the global
    # declaration Python treats them as locals of main(): the first read
    # raises UnboundLocalError, and assignments such as trial_outcome or
    # current_target_idx would never reach reset() and move_dot().
    global USE_FORCE_SENSOR, USE_AI_CONTROL, NOISE_MAGNITUDE
    global trial_outcome, gamma, current_target_idx

    # Setup data folders for the study
    setup_data_folders()

    # Set the initial environment
    # NOTE(review): initialize_environment_fixed, render_full_view and
    # render_red_only are not defined in the visible part of this file; they
    # must exist at module level before main() is called.
    current_seed = ENVIRONMENT_SEEDS[CURRENT_SEED_INDEX]
    initialize_environment_fixed(current_seed)

    # Main loop
    running = True
    clock = pygame.time.Clock()

    print("\n=== USER STUDY MODE ACTIVE ===")
    print(f"Required successes per goal: {REQUIRED_SUCCESSES_PER_GOAL}")
    print(f"Environments: {ENVIRONMENT_SEEDS}")
    print("=================================\n")

    while running:
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                # Save any remaining data before quitting
                save_trajectory_data()
                running = False
            if event.type == pygame.KEYDOWN:
                # Press F to toggle between force sensor and PS5/keyboard control
                if event.key == pygame.K_f:
                    if not FORCE_SENSOR_AVAILABLE and not USE_FORCE_SENSOR:
                        print("Force sensor not available!")
                    else:
                        USE_FORCE_SENSOR = not USE_FORCE_SENSOR
                        print("Force sensor mode:", USE_FORCE_SENSOR)
                # pygame spells the bracket key constants in all caps.
                if event.key == pygame.K_LEFTBRACKET:
                    NOISE_MAGNITUDE = max(MIN_NOISE, NOISE_MAGNITUDE - NOISE_STEP)
                elif event.key == pygame.K_RIGHTBRACKET:
                    NOISE_MAGNITUDE = min(MAX_NOISE, NOISE_MAGNITUDE + NOISE_STEP)
                if event.key == pygame.K_r:
                    trial_outcome = "manual_reset"
                    reset()
                # Add a key to save data without ending the study
                if event.key == pygame.K_s and (pygame.key.get_mods() & pygame.KMOD_CTRL):
                    save_trajectory_data()
                    print("Data saved manually!")
            if joystick and event.type == pygame.JOYBUTTONDOWN:
                if event.button == 2:  # X button on PS controller
                    trial_outcome = "manual_reset"
                    reset()
                if event.button == 3:  # Y button on PS controller
                    USE_AI_CONTROL = not USE_AI_CONTROL
                    print(f"{'AI' if USE_AI_CONTROL else 'Manual'} control enabled")
            if event.type == pygame.USEREVENT:
                # Fired by the timer armed in move_dot() after a goal is reached.
                if not reached_goal:
                    trial_outcome = "timeout"
                reset()

        if not reached_goal:
            # Use force sensor input if enabled; otherwise, use keyboard/joystick input.
            if USE_FORCE_SENSOR and FORCE_SENSOR_AVAILABLE:
                dx, dy = force_sensor_input
            else:
                # If force sensor is selected but not available, silently fall back to keyboard/joystick
                if USE_FORCE_SENSOR and not FORCE_SENSOR_AVAILABLE:
                    USE_FORCE_SENSOR = False

                dx, dy = 0.0, 0.0
                keys = pygame.key.get_pressed()
                if keys[pygame.K_LEFT]:
                    dx -= 1
                if keys[pygame.K_RIGHT]:
                    dx += 1
                if keys[pygame.K_UP]:
                    dy -= 1
                if keys[pygame.K_DOWN]:
                    dy += 1

                if joystick:
                    axis_0 = joystick.get_axis(0)
                    axis_1 = joystick.get_axis(1)
                    deadzone = 0.1
                    # The stick overrides the arrow keys only while it is
                    # deflected. An idle stick leaves keyboard input in place,
                    # so the keyboard still works with a controller plugged in.
                    if abs(axis_0) > deadzone or abs(axis_1) > deadzone:
                        dx = axis_0
                        dy = axis_1
                    l2_val = joystick.get_axis(AXIS_L2)
                    r2_val = joystick.get_axis(AXIS_R2)
                    if l2_val > 0.1:
                        gamma = max(0.0, gamma - 0.01)
                    if r2_val > 0.1:
                        gamma = min(1.0, gamma + 0.01)

            # Treat very small input on both axes as no input at all.
            if abs(dx) < 0.1 and abs(dy) < 0.1:
                dx, dy = 0.0, 0.0

            # For keyboard/joystick, apply MAX_SPEED scaling
            # (For force sensor, scaling is already applied in the serial_reader_fs function)
            if not USE_FORCE_SENSOR:
                dx *= MAX_SPEED
                dy *= MAX_SPEED

            human_input = [dx, dy]

            # Publish the predicted target so move_dot() steers towards it.
            proposed_idx = predict_human_target(human_input)
            current_target_idx = proposed_idx

            h_dir, w_dir, x_dir = move_dot(human_input)
            if not reached_goal:
                recent_positions.append((dot_pos[0], dot_pos[1], time.time()))
                current_trajectory.append((dot_pos[0], dot_pos[1]))
        else:
            h_dir, w_dir, x_dir = [0,0], [0,0], [0,0]

        render_full_view(surface_full, h_dir, w_dir, x_dir)
        render_red_only(surface_red_only, x_dir)

        tex1 = surface_to_texture(renderer1, surface_full)
        tex2 = surface_to_texture(renderer2, surface_red_only)

        renderer1.clear()
        tex1.draw(dstrect=(0, 0, FULL_VIEW_SIZE[0], FULL_VIEW_SIZE[1]))
        renderer1.present()

        renderer2.clear()
        tex2.draw(dstrect=(0, 0, RED_ONLY_SIZE[0], RED_ONLY_SIZE[1]))
        renderer2.present()

        # Environment message timer
        if environment_change_message_time is not None:
            if time.time() - environment_change_message_time > ENVIRONMENT_MESSAGE_DURATION:
                environment_change_message_time = None

        clock.tick(60)

    # Be sure to close the serial port when exiting
    if ser is not None:
        ser.close()

    pygame.quit()

# -----------------------------------------------------------
# Entry point - run as standalone executable
# -----------------------------------------------------------
if __name__ == "__main__":
    try:
        main()
    except Exception as e:
        import traceback

        # Log any uncaught exception, with its traceback, to a per-session file.
        error_log_path = os.path.join(BASE_DATA_FOLDER, f"error_log_{SESSION_ID}.txt")
        log_folder = os.path.dirname(error_log_path)
        os.makedirs(log_folder, exist_ok=True)

        with open(error_log_path, "w") as f:
            occurred_at = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
            f.write(f"Error occurred at {occurred_at}\n")
            f.write(f"Exception: {str(e)}\n")
            traceback.print_exc(file=f)

        print(f"An error occurred: {e}")
        print(f"Error details have been saved to {error_log_path}")

        # Release hardware and pygame resources even on the failure path,
        # then signal the failure through the exit status.
        if ser is not None:
            ser.close()
        pygame.quit()
        sys.exit(1)

pygame 2.6.1 (SDL 2.28.4, Python 3.12.3)
Hello from the pygame community. https://www.pygame.org/contribute.html
Error: Could not open serial port: could not open port 'COM5': FileNotFoundError(2, 'The system cannot find the file specified.', None, 2)
Force sensor not available. Using keyboard/joystick controls only.
Joystick initialized: DualSense Wireless Controller
Data folders created at user_study_data\20250228_141742
An error occurred: name 'initialize_environment_fixed' is not defined
Error details have been saved to user_study_data\error_log_20250228_141742.txt


AttributeError: 'tuple' object has no attribute 'tb_frame'