# In [None]:
import numpy as np
import pybullet as p
import pybullet_data
import math
import time
import gymnasium as gym
from gymnasium import spaces

from src.physics_utils import PyBulletSim, CameraManager
from src.robot_manager import RobotController

class RobotGymEnv(gym.Env):
	"""Gymnasium environment: a robot end effector moves in a plane above a table.

	Each episode places the end effector, a small cube and a target at random
	positions inside a rectangular workspace. The reward grows as the cube gets
	closer to the target and as the end effector gets closer to the cube; the
	episode terminates when the cube is within 5 cm of the target.

	Action space: ``Discrete(n_directions + 2)``
		* ``0 .. n_directions - 1``: move the end effector in the XY plane along
		  the direction ``i * 2*pi / n_directions`` (0 is +X, increasing
		  counter-clockwise).
		* ``n_directions``: rotate the gripper with positive velocity.
		* ``n_directions + 1``: rotate the gripper with negative velocity.

	Observation space: ``Box(shape=(7,), dtype=float32)`` holding
	``[EE_x, EE_y, EE_yaw, Cube_x, Cube_y, Target_x, Target_y]`` where all XY
	values are relative to the centre of the workspace.
	"""

	metadata = {'render_modes': ['human', 'rgb_array']}

	def __init__(self, n_directions=8, render_mode='human', max_steps=200):
		"""Connect to PyBullet and load the table, cube and robot.

		Args:
			n_directions: Number of evenly spaced planar movement directions.
			render_mode: ``'human'`` opens the PyBullet GUI; any other value
				uses a headless (``p.DIRECT``) connection.
			max_steps: Number of ``step()`` calls after which the episode is
				reported as truncated.
		"""
		super().__init__()

		self.n_directions = n_directions
		self.render_mode = render_mode
		self.max_steps = max_steps

		# Episode state. Both are (re)assigned by reset(); they are created
		# here so that calling step() too early fails with a clear message
		# instead of an AttributeError.
		self.current_step = 0
		self.target_pos = None

		# --- Action Space ---
		self.action_space = spaces.Discrete(n_directions + 2)

		# --- Observation Space ---
		# [EE_x, EE_y, EE_yaw, Cube_x, Cube_y, Target_x, Target_y]
		self.observation_space = spaces.Box(
			low=-np.inf, high=np.inf, shape=(7,), dtype=np.float32
		)

		# Initialize Simulation
		connection_mode = p.GUI if render_mode == 'human' else p.DIRECT
		self.sim = PyBulletSim(connection_mode)

		# --- Load Table & Define Boundaries ---
		p.setAdditionalSearchPath(pybullet_data.getDataPath())
		self.table_id = p.loadURDF(
			"table/table.urdf",
			basePosition=[0.5, 0, 0],
			baseOrientation=[0, 0, 1, 1],
			useFixedBase=True,
		)

		# Define boundaries (Table surface area)
		self.x_min, self.x_max = 0.2, 0.8
		self.y_min, self.y_max = -0.3, 0.3
		self.z_height = 0.65  # Height to hover above table

		# Draw workspace boundary lines
		self._draw_workspace_lines()
		self.target_debug_items = []  # Store IDs of target visualization

		# Load Cube far below the scene; reset() moves it onto the table.
		self.cube_id = p.loadURDF("cube_small.urdf", basePosition=[0, 0, -10])

		# Initialize Robot
		self.robot = RobotController(
			"franka_panda/panda.urdf",
			scale=1.5,
			initial_base_pos=[-0.3, 0, self.z_height - 0.4]
		)

		# Set specific speeds for RL steps
		self.robot.ee_velocity = 0.5
		self.robot.gripper_rot_velocity = 2.0

	def _draw_workspace_lines(self):
		"""Draw a red rectangle of debug lines around the workspace at ``z_height``."""
		z = self.z_height
		corners = [
			[self.x_min, self.y_min, z],
			[self.x_max, self.y_min, z],
			[self.x_max, self.y_max, z],
			[self.x_min, self.y_max, z],
		]

		line_color = [1, 0, 0]  # Red
		line_width = 2

		# Connect each corner to the next one, wrapping around to close the loop.
		for i, start_pos in enumerate(corners):
			end_pos = corners[(i + 1) % len(corners)]
			p.addUserDebugLine(start_pos, end_pos, lineColorRGB=line_color, lineWidth=line_width)

	def _get_direction_vector(self, action_idx):
		"""Return the ``[vx, vy, 0]`` velocity for a movement action index.

		The direction angle is ``action_idx * 2*pi / n_directions`` and the
		vector length is ``self.robot.ee_velocity``.
		"""
		angle = action_idx * (2 * np.pi / self.n_directions)
		speed = self.robot.ee_velocity

		# Z velocity is 0 for planar movement
		return [math.cos(angle) * speed, math.sin(angle) * speed, 0]

	def _draw_target_circle(self, x, y):
		"""Replace the target marker with a green circle centred on ``(x, y)``.

		The circle is approximated by 12 debug line segments drawn 5 mm above
		``z_height``; their IDs are kept in ``self.target_debug_items`` so the
		next call can remove them.
		"""
		# Remove old target visualization
		for item in self.target_debug_items:
			p.removeUserDebugItem(item)
		self.target_debug_items = []

		radius = 0.03
		z = self.z_height + 0.005
		color = [0, 1, 0]  # Green
		num_segments = 12

		for i in range(num_segments):
			angle = 2 * math.pi * i / num_segments
			next_angle = 2 * math.pi * (i + 1) / num_segments
			p1 = [x + radius * math.cos(angle), y + radius * math.sin(angle), z]
			p2 = [x + radius * math.cos(next_angle), y + radius * math.sin(next_angle), z]
			item_id = p.addUserDebugLine(p1, p2, lineColorRGB=color, lineWidth=3)
			self.target_debug_items.append(item_id)

	def _get_observation(self):
		"""Return the state vector as a float32 array of length 7.

		Layout: ``[EE_x_rel, EE_y_rel, EE_yaw, Cube_x_rel, Cube_y_rel,
		Target_x_rel, Target_y_rel]`` where ``*_rel`` values are offsets from
		the workspace centre and ``EE_yaw`` is the third Euler angle of the
		end-effector orientation.
		"""
		# Get raw positions
		ee_pos, ee_orn = self.robot.get_ee_pose()
		cube_pos, _ = p.getBasePositionAndOrientation(self.cube_id)

		# Calculate center of workspace
		center_x = (self.x_min + self.x_max) / 2
		center_y = (self.y_min + self.y_max) / 2

		# EE Yaw (Orientation)
		ee_yaw = p.getEulerFromQuaternion(ee_orn)[2]

		return np.array(
			[
				ee_pos[0] - center_x,
				ee_pos[1] - center_y,
				ee_yaw,
				cube_pos[0] - center_x,
				cube_pos[1] - center_y,
				self.target_pos[0] - center_x,
				self.target_pos[1] - center_y,
			],
			dtype=np.float32,
		)

	def _compute_reward(self, ee_pos, cube_pos):
		"""Return ``(reward, terminated)`` for the current poses.

		* End effector outside the workspace rectangle: flat ``-10.0`` penalty,
		  episode continues.
		* Otherwise: ``-0.05 - (2.0 * d(cube, target) + 0.5 * d(ee, cube))``
		  using XY distances, plus ``+20.0`` and termination when the cube is
		  closer than 0.05 to the target.
		"""
		x, y = ee_pos[0], ee_pos[1]

		# 1. Boundary Check (Punishment)
		if x < self.x_min or x > self.x_max or y < self.y_min or y > self.y_max:
			return -10.0, False

		# 2. Distance Rewards
		dist_cube_target = math.hypot(
			cube_pos[0] - self.target_pos[0], cube_pos[1] - self.target_pos[1]
		)
		dist_ee_cube = math.hypot(ee_pos[0] - cube_pos[0], ee_pos[1] - cube_pos[1])

		# Action Cost (Encourage efficiency)
		action_cost = -0.05

		# Combined Reward
		# We prioritize moving the cube (2.0) over just reaching the cube (0.5)
		reward = action_cost - (2.0 * dist_cube_target + 0.5 * dist_ee_cube)

		# Success Bonus
		if dist_cube_target < 0.05:
			print("Target Reached!")
			return reward + 20.0, True

		return reward, False

	def step(self, action):
		"""Apply one discrete action and advance the simulation.

		Args:
			action: Index into the action space (see the class docstring). An
				index outside the known ranges commands zero velocity.

		Returns:
			``(observation, reward, terminated, truncated, info)`` following
			the Gymnasium step API; ``info`` is always an empty dict.

		Raises:
			RuntimeError: If ``reset()`` has not been called yet.
		"""
		if self.target_pos is None:
			raise RuntimeError("reset() must be called before step()")

		self.current_step += 1

		# --- Map Discrete Action to Continuous Velocity ---
		lin_vel = [0, 0, 0]
		rot_vel = 0
		if 0 <= action < self.n_directions:
			lin_vel = self._get_direction_vector(action)
		elif action == self.n_directions:
			rot_vel = self.robot.gripper_rot_velocity
		elif action == self.n_directions + 1:
			rot_vel = -self.robot.gripper_rot_velocity

		# --- Apply Control ---
		self.robot.move_ee_velocity(lin_vel, maintain_height=self.z_height)
		self.robot.rotate_gripper_velocity(rot_vel)

		# --- Step Simulation ---
		# Execute action for multiple physics steps (Action Repetition)
		# 20 steps * 1/240s ~= 0.083s per RL step.
		# This makes the robot move ~4cm per step and reduces the step counter frequency.
		for _ in range(20):
			p.stepSimulation()
			if self.render_mode == 'human':
				# Slow the loop down so the GUI is watchable.
				time.sleep(self.sim.time_step)

		# --- Get Observation ---
		observation = self._get_observation()

		# Check Step Limit
		truncated = self.current_step >= self.max_steps

		# --- Check Boundaries & Calculate Reward ---
		ee_pos, _ = self.robot.get_ee_pose()
		cube_pos, _ = p.getBasePositionAndOrientation(self.cube_id)
		reward, terminated = self._compute_reward(ee_pos, cube_pos)

		info = {}

		return observation, reward, terminated, truncated, info

	def _sample_workspace_xy(self, avoid_xy=None, min_dist=0.0):
		"""Draw a uniform random ``(x, y)`` inside the workspace rectangle.

		Uses ``self.np_random`` (x first, then y) so results follow the seed
		passed to ``reset()``. When ``avoid_xy`` is given, candidates are
		re-drawn until one lies strictly farther than ``min_dist`` from it.
		"""
		while True:
			x = self.np_random.uniform(self.x_min, self.x_max)
			y = self.np_random.uniform(self.y_min, self.y_max)
			if avoid_xy is None:
				return x, y
			if math.hypot(x - avoid_xy[0], y - avoid_xy[1]) > min_dist:
				return x, y

	def reset(self, seed=None, options=None):
		"""Start a new episode with random robot, cube and target positions.

		Args:
			seed: Optional seed forwarded to ``gym.Env.reset`` to seed
				``self.np_random``.
			options: Accepted for API compatibility; not used.

		Returns:
			``(observation, info)`` where ``info`` is an empty dict.
		"""
		super().reset(seed=seed)
		self.current_step = 0

		# 1. Generate Random Spawn Position for Robot
		# Use self.np_random for Gymnasium compatibility
		robot_x, robot_y = self._sample_workspace_xy()
		spawn_pos = [robot_x, robot_y, self.z_height]

		# Fixed orientation (gripper pointing down)
		spawn_orn = [math.pi, 0, 0]

		# 2. Calculate Joint Angles and Reset Robot
		joint_poses = self.robot.inverse_kinematics(spawn_pos, spawn_orn)
		# Only the first 7 joints are reset (presumably the arm joints of the
		# Panda; confirm against RobotController).
		for i in range(7):
			p.resetJointState(self.robot.robot_id, i, joint_poses[i])
			self.robot.joints[i].set_position(joint_poses[i])

		# 3. Generate Random Spawn Position for Cube (avoiding robot)
		cube_x, cube_y = self._sample_workspace_xy((robot_x, robot_y), 0.15)
		p.resetBasePositionAndOrientation(
			self.cube_id, [cube_x, cube_y, self.z_height], [0, 0, 0, 1]
		)

		# 4. Generate Random Target Position (avoiding cube start)
		# Ensure target isn't exactly where cube spawns
		target_x, target_y = self._sample_workspace_xy((cube_x, cube_y), 0.1)

		self.target_pos = [target_x, target_y]
		self._draw_target_circle(target_x, target_y)

		# Get initial observation
		observation = self._get_observation()

		return observation, {}

	@staticmethod
	def _key_down(keys, key):
		"""Return a truthy value when ``key`` is reported as held down in ``keys``."""
		return keys.get(key, 0) & p.KEY_IS_DOWN

	def run_interactive(self):
		"""Drive the environment by hand from the keyboard until Ctrl+C.

		Arrow keys select the movement action closest to right/up/left/down,
		``q``/``e`` rotate the gripper (and take precedence over movement),
		``r`` resets the episode. While no key is held the robot is commanded
		to stand still and the step counter does not advance.
		"""
		print("Interactive Mode: Use Arrow Keys to Move, Q/E to Rotate.")
		print("Press 'R' to Reset.")

		self.reset()

		# Calculate indices for cardinal directions based on n_directions
		# 0 is usually East (Right), increasing counter-clockwise
		# Listed in priority order: the first pressed key wins.
		move_bindings = [
			(p.B3G_RIGHT_ARROW, 0),
			(p.B3G_UP_ARROW, self.n_directions // 4),
			(p.B3G_LEFT_ARROW, self.n_directions // 2),
			(p.B3G_DOWN_ARROW, 3 * self.n_directions // 4),
		]
		rotate_bindings = [
			(ord('q'), self.n_directions),      # Rotate Positive
			(ord('e'), self.n_directions + 1),  # Rotate Negative
		]

		try:
			while True:
				keys = p.getKeyboardEvents()

				# Map keys to actions dynamically; rotation overrides movement.
				action = next(
					(idx for key, idx in move_bindings if self._key_down(keys, key)), None
				)
				action = next(
					(idx for key, idx in rotate_bindings if self._key_down(keys, key)), action
				)

				# Reset
				if self._key_down(keys, ord('r')):
					print("Resetting...")
					self.reset()
					time.sleep(0.5)
					continue

				if action is not None:
					# Step environment (counts as an action step)
					obs, reward, terminated, truncated, info = self.step(action)
					if terminated or truncated:
						print(f"Episode ended. Reward: {reward:.2f}")
						self.reset()
						time.sleep(0.5)
				else:
					# Idle mode: Maintain position, don't count steps
					# We must explicitly stop the robot, otherwise previous velocity persists
					self.robot.move_ee_velocity([0, 0, 0], maintain_height=self.z_height)
					self.robot.rotate_gripper_velocity(0)
					p.stepSimulation()
					time.sleep(self.sim.time_step)

		except KeyboardInterrupt:
			print("Exiting Interactive Mode...")

	def close(self):
		"""Close the underlying simulation wrapper."""
		self.sim.close()

# --- Main Execution using Gym Env ---
# Runs the environment with uniformly random actions until interrupted
# (Ctrl+C / kernel interrupt), resetting whenever an episode ends.
env = RobotGymEnv(n_directions=8, render_mode='human', max_steps=50)

try:
	# Everything after construction lives inside the try block so the
	# simulation is closed even if camera setup or the first reset fails.
	cam = CameraManager(target_pos=[0, 0, 0], distance=2, yaw=20, pitch=-45)

	# Swap the random-action loop below for this call to drive by keyboard:
	# env.run_interactive()

	# A single reset is enough to start the first episode.
	observation, info = env.reset()

	print("Starting Gym Environment Loop...")
	while True:
		# Sample a random action
		action = env.action_space.sample()

		# Step the environment
		observation, reward, terminated, truncated, info = env.step(action)

		if terminated or truncated:
			print("Resetting environment...")
			observation, info = env.reset()
			time.sleep(0.5)  # Pause briefly on reset

except KeyboardInterrupt:
	print("Stopping...")
finally:
	env.close()