V4.2 update: see readme.md for the change notes.

In [None]:
# Check for TPU availability and set it up
import os

# Check if TPU is available
try:
    import torch_xla
    import torch_xla.core.xla_model as xm
    print("PyTorch XLA already installed")
    TPU_AVAILABLE = True
except ImportError:
    TPU_AVAILABLE = False
    print("PyTorch XLA not found, will attempt to install")

# Install necessary packages including PyTorch/XLA
!pip install pygame-ce pymunk stable-baselines3 stable-baselines3[extra] shimmy>=2.0 optuna
!pip install -q cloud-tpu-client

if not TPU_AVAILABLE:
    # Check what version of PyTorch we need
    import torch
    if torch.__version__.startswith('2'):
        # For PyTorch 2.x
        !pip install -q torch_xla[tpu]>=2.0
    else:
        # For PyTorch 1.x
        !pip install -q torch_xla

    # Restart runtime (required after installing PyTorch/XLA)
    print("TPU support installed. Please restart the runtime now.")
    import IPython
    IPython.display.display(IPython.display.HTML(
        "<script>google.colab.kernel.invokeFunction('notebook.Runtime.restartRuntime', [], {})</script>"
    ))
else:
    # Initialize TPU if available
    import torch_xla.core.xla_model as xm
    device = xm.xla_device()
    print(f"XLA device detected: {device}")

PyTorch XLA not found, will attempt to install
TPU support installed. Please restart the runtime now.


In [None]:
# from google.colab import drive
# drive.mount('/content/drive')

In [None]:
# List the Colab working directory to see which artifacts exist before cleaning up.
!ls /content/

'=2.0'	 capture   game_history   sample_data


In [None]:
!rm -r /content/capture
!rm -r /content/game_history
!rm -r /content/logs

rm: cannot remove '/content/logs': No such file or directory


# Classes

## Recorder

In [None]:
import json
import os
import datetime

class Recorder:
    """
    Persist per-game results in a monthly JSON file under ``./game_history/``.

    The file stores a single dictionary of the form ``{"game_records": [...]}``.
    Every call to ``add_no_limit`` appends one entry and rewrites the file.
    """

    # Task names for which a storage location is defined.
    _SUPPORTED_TASKS = ("game_history_record",)

    def __init__(self, task: str = "game_history_record"):
        """
        Open the JSON file backing this recorder, creating it if necessary.

        Args:
            task: Purpose of the recorder. Only "game_history_record" is
                implemented. "temp_memory" was listed as a planned task but no
                storage location is defined for it.

        Raises:
            ValueError: If ``task`` is not supported. Previously an unsupported
                task left ``json_file_path`` unset and failed with an
                AttributeError a few lines later.
        """
        if task not in self._SUPPORTED_TASKS:
            raise ValueError(
                f"Unsupported recorder task {task!r}; "
                f"supported tasks: {', '.join(self._SUPPORTED_TASKS)}"
            )

        # CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
        # __file__ is not available in a notebook, so paths are relative to the
        # current working directory.
        CURRENT_DIR = ""
        collection_name = self.get_newest_record_name()
        self.json_file_path = CURRENT_DIR + "./game_history/" + collection_name + ".json"

        # Ensure directory exists
        os.makedirs(os.path.dirname(self.json_file_path), exist_ok=True)

        if os.path.exists(self.json_file_path):
            print("Loading the json memory file")
            self.memory = self.load(self.json_file_path)
        else:
            print("The json memory file does not exist. Creating new file.")
            self.memory = {"game_records": []}
            with open(self.json_file_path, "w") as f:
                json.dump(self.memory, f)

    def get(self):
        """Return the in-memory record dictionary (``{"game_records": [...]}``)."""
        print("Getting the json memory")
        return self.memory

    def add_no_limit(self, data: float, ):
        """
        Append one game record and immediately write the file to disk.

        Args:
            data: Value stored under "game_total_duration" for this record.
                The current local time is stored next to it as "timestamp".
        """
        self.memory["game_records"].append({
            "game_total_duration": data,
            "timestamp": str(datetime.datetime.now())
        })

        self.save(self.json_file_path)

    def save(self, file_path):
        """
        Write the in-memory records to ``file_path`` as JSON.

        Saving is best-effort: any failure is printed and otherwise ignored.
        """
        try:
            with open(file_path, 'w') as f:
                json.dump(self.memory, f)
        except Exception as e:
            print(f"Error saving memory to {file_path}: {e}")

    def load(self, file_path):
        """
        Read records from ``file_path``.

        Returns:
            The parsed JSON content, or an empty ``{"game_records": []}``
            dictionary if the file cannot be read or parsed (the error is printed).
        """
        try:
            with open(file_path, 'r') as f:
                return json.load(f)
        except Exception as e:
            print(f"Error loading memory from {file_path}: {e}")
            return {"game_records": []}

    def get_newest_record_name(self) -> str:
        """
        Return the name of the current month's record collection.

        The format is ``record_YYYY-MM``, for example "record_2022-01".
        """
        this_month = datetime.datetime.now().strftime("%Y-%m")
        return "record_" + this_month

## Shapes & Objects

In [None]:
import pymunk
from typing import Tuple, Optional

class Shape:
    """A pymunk body plus an optional collision shape, with its start state remembered for reset."""

    def __init__(
                self,
                position: Tuple[float, float] = (300, 100),
                velocity: Tuple[float, float] = (0, 0),
                body: Optional[pymunk.Body] = None,
                shape: Optional[pymunk.Shape] = None,
            ):
        """
        Initialize a physical shape with associated body.

        Args:
            position: Initial position (x, y) of the body
            velocity: Initial velocity (vx, vy) of the body
            body: The pymunk Body to attach to this shape. When omitted, a new
                default ``pymunk.Body()`` is created.
            shape: The pymunk Shape for collision detection
        """
        # The signature advertises body=None, but the assignments below need a
        # real body; without this fallback the default call raised AttributeError.
        if body is None:
            body = pymunk.Body()

        self.body = body
        self.default_position = position
        self.default_velocity = velocity
        self.body.position = position
        self.body.velocity = velocity
        self.default_angular_velocity = 0

        self.shape = shape

    def reset(self):
        """Reset the body to its default position, velocity and angular velocity."""
        self.body.position = self.default_position
        self.body.velocity = self.default_velocity
        self.body.angular_velocity = self.default_angular_velocity


In [None]:
import pymunk

# from shapes.shape import Shape
from typing import Tuple, Optional

class Circle(Shape):
    """A Shape whose collision geometry is a pymunk circle."""

    def __init__(
                self,
                position: Tuple[float, float] = (300, 100),
                velocity: Tuple[float, float] = (0, 0),
                body: Optional[pymunk.Body] = None,
                shape_radio: float = 20,
                shape_mass: float = 1,
                shape_friction: float = 0.1,
                shape_elasticity: float = 0.8,
            ):
        """
        Initialize a circular physics object.

        Args:
            position: Initial position (x, y) of the circle
            velocity: Initial velocity (vx, vy) of the circle
            body: The pymunk Body to attach this circle to (passed through to
                ``Shape.__init__``)
            shape_radio: Radius of the circle in pixels
            shape_mass: Mass of the circle
            shape_friction: Friction coefficient for the circle
            shape_elasticity: Elasticity (bounciness) of the circle. The
                default 0.8 is the value that used to be hard-coded.
        """
        super().__init__(position, velocity, body)
        self.shape_radio = shape_radio
        self.shape = pymunk.Circle(self.body, shape_radio)
        self.shape.mass = shape_mass
        self.shape.friction = shape_friction
        # Some bounce makes the simulation more interesting.
        self.shape.elasticity = shape_elasticity


## Levels

In [None]:
import random
import pymunk
import pygame
import numpy as np
import time

# from shapes.circle import Circle

def get_level(level: int, space):
    """
    Build the level object that corresponds to ``level``.

    Levels 1, 2 and 3 map to Level1, Level2 and Level3, each constructed with
    ``space``. Any other value raises ValueError.
    """
    # Guard-clause form: each supported level returns immediately.
    if level == 1:
        return Level1(space)
    if level == 2:
        return Level2(space)
    if level == 3:
        return Level3(space)
    raise ValueError("Invalid level number")

class Levels:
    """
    Base class for game levels.

    Provides factory helpers that build the player ball and the platform as
    pymunk objects. Subclasses decide how many of each to create and add them
    to the physics space.
    """

    def __init__(self, space, window_x: int = 1000, window_y: int = 600):
        """
        Args:
            space: Physics space the level's bodies are added to by subclasses.
            window_x: Window width in pixels.
            window_y: Window height in pixels.
        """
        self.space = space
        self.window_x = window_x
        self.window_y = window_y

    def create_player(self,
                      default_player_position: tuple = None,
                      ball_color = (255, 213, 79),  # Bright yellow ball
                      window_x: int = 1000,
                      window_y: int = 600,
                     ):
        """
        Create the ball with physics properties.

        Args:
            default_player_position: Initial position of the player.
                Defaults to (window_x / 2, window_y / 5).
            ball_color: RGB color stored in the returned description.
            window_x: Window width; the ball radius is int(window_x / 67).
            window_y: Window height, used only for the default position.

        Returns:
            A dict with keys "type", "shape" (the Circle wrapper), "default_position",
            "body", "ball_radius" and "ball_color". Nothing is added to the space here.
        """
        if default_player_position is None:
            default_player_position = (window_x / 2, window_y / 5)
        dynamic_body = pymunk.Body()  # Ball body
        ball_radius = int(window_x / 67)
        player = Circle(
            position=default_player_position,
            velocity=(0, 0),
            body=dynamic_body,
            shape_radio=ball_radius,
            shape_friction=100,
        )
        # Store initial values for reset
        return {
            "type": "player",
            "shape": player,
            "default_position": default_player_position,
            "body": dynamic_body,
            "ball_radius": ball_radius,
            "ball_color": ball_color
        }

    def create_platform(self,
                        platform_shape: str = "circle",
                        platform_proportion: float = 0.4,
                        window_x: int = 1000,
                        window_y: int = 600,
                       ):
        """
        Create the platform with physics properties.

        Args:
            platform_shape: "circle" or "rectangle".
            platform_proportion: Platform size as a fraction of window_x
                (rectangle length, or circle diameter).
            window_x: Window width.
            window_y: Window height; the platform sits at two thirds of it.

        Returns:
            A dict with keys "type", "platform_shape", "shape" (the pymunk shape),
            "default_position", "body" and "platform_length". Note that
            "platform_length" is the RADIUS for a circle and the full length
            for a rectangle.

        Raises:
            ValueError: If ``platform_shape`` is not "circle" or "rectangle".
                Previously an unknown shape failed later with an
                UnboundLocalError because no pymunk shape had been created.
        """
        if platform_shape not in ("circle", "rectangle"):
            raise ValueError(
                f"Unsupported platform_shape {platform_shape!r}; expected 'circle' or 'rectangle'"
            )

        platform_length = int(window_x * platform_proportion)

        # Create game bodies
        kinematic_body = pymunk.Body(body_type=pymunk.Body.KINEMATIC)  # Platform body
        kinematic_body.position = (window_x / 2, (window_y / 3) * 2)
        default_kinematic_position = kinematic_body.position

        if platform_shape == "circle":
            platform_length = platform_length / 2 # radius
            platform = pymunk.Circle(kinematic_body, platform_length)
            # Mass is meaningless for a kinematic body, but it is set to avoid a
            # division-by-zero error.
            platform.mass = 1
        else:  # "rectangle": a 20-pixel-thick box centred on the body
            vs = [(-platform_length/2, -10),
                (platform_length/2, -10),
                (platform_length/2, 10),
                (-platform_length/2, 10)]

            platform = pymunk.Poly(kinematic_body, vs)

        platform.friction = 0.7
        platform.rotation = 0

        return {
            "type": "platform",
            "platform_shape": platform_shape,
            "shape": platform,
            "default_position": default_kinematic_position,
            "body": kinematic_body,
            "platform_length": platform_length,
        }

    # TODO: currently unused. Relies on attributes (screen, PLATFORM_COLOR,
    # BALL_COLOR, kinematic_body, dynamic_body, platform_length, ball_radius)
    # that this base class does not define.
    def _draw_indie_style(self):
        """Draw game objects with indie game aesthetic"""
        # # Draw platform with gradient and glow
        # platform_points = []
        # for v in self.platform.get_vertices():
        #     x, y = v.rotated(self.kinematic_body.angle) + self.kinematic_body.position
        #     platform_points.append((int(x), int(y)))

        # pygame.draw.polygon(self.screen, self.PLATFORM_COLOR, platform_points)
        # pygame.draw.polygon(self.screen, (255, 255, 255), platform_points, 2)

        platform_pos = (int(self.kinematic_body.position[0]), int(self.kinematic_body.position[1]))
        pygame.draw.circle(self.screen, self.PLATFORM_COLOR, platform_pos, self.platform_length)
        pygame.draw.circle(self.screen, (255, 255, 255), platform_pos, self.platform_length, 2)

        # Draw rotation direction indicator
        self._draw_rotation_indicator(platform_pos, self.platform_length, self.kinematic_body.angular_velocity)

        # Draw the ball: filled circle plus a white outline
        ball_pos = (int(self.dynamic_body.position[0]), int(self.dynamic_body.position[1]))
        pygame.draw.circle(self.screen, self.BALL_COLOR, ball_pos, self.ball_radius)
        pygame.draw.circle(self.screen, (255, 255, 255), ball_pos, self.ball_radius, 2)

    # TODO: currently unused (only called from _draw_indie_style).
    def _draw_rotation_indicator(self, position, radius, angular_velocity):
        """
        Draw an indicator showing the platform's rotation direction and speed.

        Args:
            position: (x, y) centre of the platform in screen coordinates.
            radius: Platform radius; arrows are placed 20 pixels inside it.
            angular_velocity: Signed angular velocity. Its sign selects the
                colour and arrow direction, its magnitude the number of arrows (1-3).
        """
        # Only draw the indicator if there's some rotation
        if abs(angular_velocity) < 0.1:
            return

        # Calculate indicator properties based on angular velocity
        indicator_color = (50, 255, 150) if angular_velocity > 0 else (255, 150, 50)
        num_arrows = min(3, max(1, int(abs(angular_velocity))))
        indicator_radius = radius - 20  # Place indicator inside the platform

        # Draw arrow indicators along the platform's circumference
        start_angle = self.kinematic_body.angle

        for i in range(num_arrows):
            # Spread the arrows evenly around the circle
            arrow_angle = start_angle + i * (2 * np.pi / num_arrows)

            # Calculate arrow start point
            base_x = position[0] + int(np.cos(arrow_angle) * indicator_radius)
            base_y = position[1] + int(np.sin(arrow_angle) * indicator_radius)

            # Determine arrow direction based on angular velocity
            if angular_velocity > 0:  # Clockwise
                arrow_end_angle = arrow_angle + 0.3
            else:  # Counter-clockwise
                arrow_end_angle = arrow_angle - 0.3

            tip_x = position[0] + int(np.cos(arrow_end_angle) * (indicator_radius + 15))
            tip_y = position[1] + int(np.sin(arrow_end_angle) * (indicator_radius + 15))

            # Draw arrow line
            pygame.draw.line(self.screen, indicator_color, (base_x, base_y), (tip_x, tip_y), 3)

            # Draw arrowhead (a filled dot at the tip)
            arrowhead_size = 7
            pygame.draw.circle(self.screen, indicator_color, (tip_x, tip_y), arrowhead_size)

pygame-ce 2.5.3 (SDL 2.30.12, Python 3.11.12)


### Level1

In [None]:
class Level1(Levels):
    """
    Level 1: one ball on one platform that spins at a constant speed.

    The spin direction (-1 or +1) is drawn at random in setup() and again on
    every reset(); nothing changes while an episode is running.
    """

    def __init__(self, space):
        super().__init__(space)
        self.space = space

    def setup(self, window_x, window_y):
        """
        Create the ball and the platform, add them to the space and start the
        platform spinning.

        Returns:
            A pair of one-element tuples: (player descriptions, platform descriptions).
        """
        size = {"window_x": window_x, "window_y": window_y}
        ball = super().create_player(**size)
        disc = super().create_platform(**size)

        self.dynamic_body = ball["body"]
        self.kinematic_body = disc["body"]
        self.default_player_position = ball["default_position"]

        self.space.add(self.dynamic_body, ball["shape"].shape)
        self.space.add(self.kinematic_body, disc["shape"])

        # randrange(-1, 2, 2) yields either -1 or 1
        self.kinematic_body.angular_velocity = random.randrange(-1, 2, 2)

        return (ball,), (disc,)

    def action(self):
        """Per-step level behaviour; this level has none."""
        return None

    def reset(self):
        """
        Put the ball back at its start position at rest and pick a new random
        spin direction for the platform.
        """
        ball, disc = self.dynamic_body, self.kinematic_body

        ball.position = self.default_player_position
        ball.angular_velocity = 0
        ball.velocity = (0, 0)
        disc.angular_velocity = random.randrange(-1, 2, 2)

        # Positions were changed by hand, so refresh the collision index.
        for body in (ball, disc):
            self.space.reindex_shapes_for_body(body)

### Level2

In [None]:
class Level2(Levels):
    """
    Level 2: one ball on one platform whose spin direction is re-drawn at random.

    The platform's angular velocity (-1 or +1) is re-drawn whenever more than
    ``angular_velocity_change_timeout`` seconds of wall-clock time have passed
    since the last change.
    """

    def __init__(self, space):
        super().__init__(space)
        self.space = space
        self.angular_velocity_change_timeout = 5 # sec
        self.last_angular_velocity_change_time = time.time()

    def setup(self, window_x, window_y):
        """
        Create the ball and the platform, add them to the space and start the
        platform spinning.

        Returns:
            A pair of one-element tuples: (player descriptions, platform descriptions).
        """
        size = {"window_x": window_x, "window_y": window_y}
        ball = super().create_player(**size)
        disc = super().create_platform(**size)

        self.dynamic_body = ball["body"]
        self.kinematic_body = disc["body"]
        self.default_player_position = ball["default_position"]

        self.space.add(self.dynamic_body, ball["shape"].shape)
        self.space.add(self.kinematic_body, disc["shape"])

        # randrange(-1, 2, 2) yields either -1 or 1
        self.kinematic_body.angular_velocity = random.randrange(-1, 2, 2)

        return (ball,), (disc,)

    def action(self):
        """Re-draw the platform's spin direction once the timeout has elapsed."""
        elapsed = time.time() - self.last_angular_velocity_change_time
        if elapsed <= self.angular_velocity_change_timeout:
            return

        self.kinematic_body.angular_velocity = random.randrange(-1, 2, 2)
        self.last_angular_velocity_change_time = time.time()

    def reset(self):
        """
        Put the ball back at its start position at rest, pick a new random spin
        direction and restart the direction-change timer.
        """
        ball, disc = self.dynamic_body, self.kinematic_body

        ball.position = self.default_player_position
        ball.angular_velocity = 0
        ball.velocity = (0, 0)
        disc.angular_velocity = random.randrange(-1, 2, 2)
        self.last_angular_velocity_change_time = time.time()

        # Positions were changed by hand, so refresh the collision index.
        for body in (ball, disc):
            self.space.reindex_shapes_for_body(body)

### Level3

In [None]:
# Two players
# NOTE: continuous action space and adversarial training
class Level3(Levels):
    """
    Level 3: Two players with adversarial training.

    Two balls share one rectangular platform. A collision handler records when
    the balls hit each other and how fast each one was moving, so the game can
    turn that into rewards.
    """
    def __init__(self, space):
        super().__init__(space)
        self.space = space
        self.last_collision_time = 0
        self.collision_reward_cooldown = 0.5  # seconds
        self.collision_occurred = False
        # Defined up front so they can be read before the first collision or
        # reset; previously they only came into existence in
        # handle_collision() / reset().
        self.collision_impulse_1 = 0
        self.collision_impulse_2 = 0

    def setup(self, window_x, window_y):
        """
        Create both balls and the platform, register the ball-vs-ball collision
        handler and add everything to the space.

        Returns:
            ((player1, player2), (platform,)) as description dicts.
        """
        x = window_x / 5
        player1 = super().create_player(window_x=window_x,
                                        window_y=window_y,
                                        default_player_position=(x*2, window_y / 5)
                                       )
        player2 = super().create_player(window_x=window_x,
                                        window_y=window_y,
                                        ball_color=(194, 238, 84),
                                        default_player_position=(x*3, window_y / 5)
                                       )
        platform = super().create_platform(platform_shape="rectangle",platform_proportion=0.8, window_x=window_x, window_y=window_y)

        # Give each ball its own collision type. This is the key fix: the
        # handler registered below is keyed on the (1, 2) type pair.
        player1["shape"].shape.collision_type = 1
        player2["shape"].shape.collision_type = 2

        # Add collision handler for balls colliding with each other
        handler = self.space.add_collision_handler(1, 2)
        handler.begin = self.handle_collision

        self.space.add(player1["body"], player1["shape"].shape)
        self.space.add(player2["body"], player2["shape"].shape)
        self.space.add(platform["body"], platform["shape"])
        self.dynamic_body1 = player1["body"]
        self.dynamic_body2 = player2["body"]
        self.kinematic_body = platform["body"]
        self.default_player_position1 = player1["default_position"]
        self.default_player_position2 = player2["default_position"]

        self.collision_occurred = False

        return (player1, player2), (platform, )

    def handle_collision(self, arbiter, space, data):
        """
        Collision "begin" callback for ball-vs-ball contact.

        At most once per ``collision_reward_cooldown`` seconds (wall-clock) it
        sets ``collision_occurred`` and stores each ball's speed in
        ``collision_impulse_1`` / ``collision_impulse_2``.

        Returns:
            Always True, so the collision is processed normally.
        """
        current_time = time.time()
        if current_time - self.last_collision_time > self.collision_reward_cooldown:
            self.last_collision_time = current_time
            # Mark that a collision occurred
            self.collision_occurred = True

            # The bodies' speeds at the start of contact are what the reward is based on.
            body1, body2 = arbiter.shapes[0].body, arbiter.shapes[1].body

            # Despite the "impulse" name, this is |vx| + |vy| of each body
            # (mass is not taken into account).
            self.collision_impulse_1 = abs(body1.velocity[0]) + abs(body1.velocity[1])
            self.collision_impulse_2 = abs(body2.velocity[0]) + abs(body2.velocity[1])

            print(f"Collision occurred! Body1 speed: {self.collision_impulse_1:.2f}, Body2 speed: {self.collision_impulse_2:.2f}")
        return True

    def action(self):
        """
        shape state changes in the game
        """
        # Reset collision flag each frame in step function, not here
        pass

    def reset(self):
        """
        Reset the level to its initial state.

        Both balls return to their start positions at rest and all collision
        bookkeeping is cleared.
        """
        self.dynamic_body1.position = self.default_player_position1
        self.dynamic_body1.angular_velocity = 0
        self.dynamic_body1.velocity = (0, 0)
        self.dynamic_body2.position = self.default_player_position2
        self.dynamic_body2.angular_velocity = 0
        self.dynamic_body2.velocity = (0, 0)

        self.collision_occurred = False
        self.last_collision_time = 0
        self.collision_impulse_1 = 0
        self.collision_impulse_2 = 0

        self.space.reindex_shapes_for_body(self.dynamic_body1)
        self.space.reindex_shapes_for_body(self.dynamic_body2)
        self.space.reindex_shapes_for_body(self.kinematic_body)

### --

In [None]:
def get_level(level: int, space):
    """
    Return the level instance for a level number.

    This re-defines the helper of the same name from the Levels cell with the
    same behaviour: 1 -> Level1, 2 -> Level2, 3 -> Level3, each built with
    ``space``. Anything else raises ValueError.
    """
    # Reject unsupported numbers first, then dispatch.
    if level not in (1, 2, 3):
        raise ValueError("Invalid level number")
    if level == 1:
        return Level1(space)
    if level == 2:
        return Level2(space)
    return Level3(space)

## Game class

In [None]:
import pymunk
import pygame
import random
import time
import numpy as np
import os
import numpy as np
import base64
import math
import matplotlib.pyplot as plt
# import IPython.display as ipd

from typing import Dict, Tuple, Optional
# from IPython.display import display, Image, clear_output
from io import BytesIO
# from record import Recorder
# from levels.levels import get_level

class BalancingBallGame:
    """
    A physics-based balancing ball game that can run standalone or be used as a Gym environment.
    """

    # Game constants

    # Visual settings for indie style
    BACKGROUND_COLOR = (41, 50, 65)  # Dark blue background
    BALL_COLOR = (255, 213, 79)  # Bright yellow ball
    PLATFORM_COLOR = (235, 64, 52)  # Red platform

    def __init__(self,
                 render_mode: str = "human",
                 sound_enabled: bool = True,
                 difficulty: str = "medium",
                 window_x: int = 1000,
                 window_y: int = 600,
                 max_step: int = 30000,
                 player_ball_speed: int = 5,
                 reward_staying_alive: float = 0.1,
                 reward_ball_centered: float = 0.2,
                 penalty_falling: float = -10.0,
                 level: int = 2,
                 fps: int = 120,
                 platform_shape: str = "circle",
                 platform_proportion: float = 0.4,
                 capture_per_second: Optional[int] = None,
                 max_force: float = 50000.0,  # Maximum horizontal force
                 collision_reward: float = 5.0,  # Increased collision reward
                 speed_reward_multiplier: float = 0.01,  # Reward for maintaining speed
                 opponent_fall_bonus: float = 15.0,  # Bonus for causing opponent to fall
                 survival_bonus: float = 0.5,  # Bonus for staying alive when opponent falls
                 platform_distance_penalty: float = 0.02,  # Penalty for being far from platform center
                ):
        """
        Initialize the balancing ball game.

        Args:
            render_mode: "human" for visible window, "rgb_array" for gym env,
                "headless" for no rendering. Pygame is set up only for "human",
                "rgb_array", "rgb_array_and_human" and "rgb_array_and_human_in_colab".
            sound_enabled: Whether to enable sound effects
            difficulty: Game difficulty level ("easy", "medium", "hard")
            window_x: Window width in pixels
            window_y: Window height in pixels
            max_step: 1 step = 1/fps, if fps = 120, 1 step = 1/120
            player_ball_speed: Stored on the instance; not used in this initializer
            reward_staying_alive: Per-step reward for each player that has not fallen
            reward_ball_centered: Stored on the instance; presumably used by the
                centre-reward helper (TODO confirm)
            penalty_falling: Reward assigned to a player on the step it falls
            level: Level number passed to get_level()
            fps: frame per second; the physics step is 1/fps
            platform_shape: Accepted for API compatibility; NOTE(review): not
                referenced in this initializer, the level decides the platform shape
            platform_proportion: platform_length = window_x * platform_proportion
            capture_per_second: save game screen as a image every second, None means no capture
            max_force: Maximum horizontal force that can be applied; actions are
                scaled by it in step()
            collision_reward: Scale factor for the collision reward computed in step()
            speed_reward_multiplier: Multiplier for speed-based rewards
            opponent_fall_bonus: Stored on the instance (usage is outside this method)
            survival_bonus: Stored on the instance (usage is outside this method)
            platform_distance_penalty: Per-pixel penalty for horizontal distance
                from the platform centre, applied in step()

        Raises:
            ValueError: If the level creates more than two players, or (from
                get_level) if ``level`` is not a known level number.
        """
        # Game parameters
        self.max_step = max_step
        self.reward_staying_alive = reward_staying_alive
        self.reward_ball_centered = reward_ball_centered
        self.penalty_falling = penalty_falling
        self.fps = fps
        self.window_x = window_x
        self.window_y = window_y
        self.player_ball_speed = player_ball_speed
        self.max_force = max_force
        self.collision_reward = collision_reward
        self.speed_reward_multiplier = speed_reward_multiplier
        self.opponent_fall_bonus = opponent_fall_bonus
        self.survival_bonus = survival_bonus
        self.platform_distance_penalty = platform_distance_penalty

        self.recorder = Recorder("game_history_record")
        self.render_mode = render_mode
        self.sound_enabled = sound_enabled
        self.difficulty = difficulty

        # Sound handles must exist even when pygame/sound setup is skipped
        # (e.g. headless mode with sound_enabled=True): step() reads
        # self.sound_fall and would otherwise raise AttributeError when a ball
        # falls. _load_sounds() overwrites them when sounds are available.
        self.sound_bounce = None
        self.sound_fall = None

        platform_length = int(window_x * platform_proportion)
        # Helper defined elsewhere in this class; called with the nominal platform length.
        self._get_x_axis_max_reward_rate(platform_length)

        # Initialize physics space
        self.space = pymunk.Space()
        self.space.gravity = (0, 9810)  # gravity points toward +y
        self.space.damping = 0.9

        self.level = get_level(level, self.space)
        players, platforms = self.level.setup(self.window_x, self.window_y)
        self.dynamic_body_players = []
        self.kinematic_body_platforms = []
        self.players_color = []
        self.player_alive = []  # Track which players are still alive

        for player in players:
            self.dynamic_body_players.append(player["body"])
            self.players_color.append(player["ball_color"])
            self.player_alive.append(True)

        for platform in platforms:
            self.kinematic_body_platforms.append(platform["body"])
            # TODO: the variable name is unclear. self.platform_shape holds the
            # pymunk shape object and is only set for rectangular platforms.
            if (platform["platform_shape"] == "rectangle"):
                self.platform_shape = platform["shape"]

        self.ball_radius = players[0]["ball_radius"]
        self.platform_length = platforms[0]["platform_length"]
        self.num_players = len(players)

        # Fail before any window or folder is created.
        if self.num_players > 2:
            raise ValueError("Warning!!! collision reward calculation in step() can only work for two players now")

        # Game state tracking
        self.steps = 0
        self.start_time = time.time()
        self.game_over = False
        self.score = [0] * self.num_players  # Score for each player
        self.winner = None
        self.last_speeds = [0] * self.num_players  # Track last speed for each player
        self.players_fell_this_step = [False] * self.num_players  # Track who fell this step

        # Initialize Pygame if needed
        if self.render_mode in ["human", "rgb_array", "rgb_array_and_human", "rgb_array_and_human_in_colab"]:
            self._setup_pygame()
        else:
            print("render_mode is not human or rgb_array, so no pygame setup.")

        # Set difficulty parameters
        self._apply_difficulty()
        self.capture_per_second = capture_per_second

        # Create folders for captures if needed
        # CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
        CURRENT_DIR = "."
        os.makedirs(os.path.dirname(CURRENT_DIR + "/capture/"), exist_ok=True)


    def _setup_pygame(self):
        """
        Set up PyGame for rendering.

        Initialises pygame, loads sounds when enabled and creates the drawing
        surface that matches ``self.render_mode``:

        - "human": a visible window plus a font.
        - "rgb_array": an off-screen surface.
        - "rgb_array_and_human": not supported yet (no surface is created).
        - "rgb_array_and_human_in_colab": an off-screen surface plus an HTML
          placeholder displayed in the notebook.

        A clock is created in every case.
        """
        pygame.init()
        self.frame_count = 0

        if self.sound_enabled:
            self._load_sounds()

        if self.render_mode == "human":
            self.screen = pygame.display.set_mode((self.window_x, self.window_y))
            pygame.display.set_caption("Balancing Ball - Indie Game")
            self.font = pygame.font.Font(None, int(self.window_x / 34))

        elif self.render_mode == "rgb_array":
            self.screen = pygame.Surface((self.window_x, self.window_y))

        elif self.render_mode == "rgb_array_and_human": # todo
            print("rgb_array_and_human mode is not supported yet.")

        elif self.render_mode == "rgb_array_and_human_in_colab": # todo
            from pymunk.pygame_util import DrawOptions
            # The notebook-level IPython imports are commented out, so the names
            # used below (previously `ipd.HTML` and `display`) were undefined and
            # this branch raised NameError. Import them where they are needed.
            from IPython.display import HTML, display

            self.screen = pygame.Surface((self.window_x, self.window_y))  # Create hidden surface

            # Set up display in Colab
            self.draw_options = DrawOptions(self.screen)
            html_display = HTML('''
                <div id="pygame-output" style="width:100%;">
                    <img id="pygame-img" style="width:100%;">
                </div>
            ''')
            self.display_handle = display(html_display, display_id='pygame_display')

            self.last_update_time = time.time()
            self.update_interval = 1.0 / 15  # Update display at 15 FPS to avoid overwhelming Colab
            self.font = pygame.font.Font(None, int(self.window_x / 34))


        else:
            print("Invalid render mode. Using headless mode.")

        self.clock = pygame.time.Clock()

        # Create custom draw options for indie style

    def _load_sounds(self):
        """Load game sound effects"""
        try:
            pygame.mixer.init()
            self.sound_bounce = pygame.mixer.Sound("assets/bounce.wav") if os.path.exists("assets/bounce.wav") else None
            self.sound_fall = pygame.mixer.Sound("assets/fall.wav") if os.path.exists("assets/fall.wav") else None
        except Exception:
            print("Sound loading error")
            self.sound_enabled = False
            pass

    def _apply_difficulty(self):
        """Apply difficulty settings to the game"""
        if self.difficulty == "easy":
            self.max_platform_speed = 1.5
            self.ball_elasticity = 0.5
        elif self.difficulty == "medium":
            self.max_platform_speed = 2.5
            self.ball_elasticity = 0.7
        else:  # hard
            self.max_platform_speed = 3.5
            self.ball_elasticity = 0.9

        # self.circle.shape.elasticity = self.ball_elasticity

    def reset(self) -> np.ndarray:
        """Reset the game state and return the initial observation"""
        # Reset physics objects
        self.level.reset()

        # Reset game state
        self.steps = 0
        self.start_time = time.time()
        self.game_over = False
        self.score = [0] * self.num_players
        self.winner = None
        self.player_alive = [True] * self.num_players
        self.last_speeds = [0] * self.num_players
        self.players_fell_this_step = [False] * self.num_players

        # Return initial observation
        return self._get_observation()

    def step(self, actions: list = []) -> Tuple[np.ndarray, list, bool, Dict]:
        """
        Take a step in the game using the given actions.

        Args:
            actions: List of continuous actions [-1.0, 1.0] for each player controlling horizontal force

        Returns:
            observation: Game state observation
            rewards: List of rewards for each player
            terminated: Whether episode is done
            info: Additional information
        """
        # Reset collision and fall tracking for this step
        self.level.collision_occurred = False
        self.players_fell_this_step = [False] * self.num_players
        
        # Apply actions to players (horizontal forces)
        for i in range(len(self.dynamic_body_players)):
            if i < len(actions) and self.player_alive[i]:
                # Scale action to force range
                force_magnitude = actions[i] * self.max_force
                force_vector = pymunk.Vec2d(force_magnitude, 0)
                self.dynamic_body_players[i].apply_force_at_world_point(
                    force_vector,
                    self.dynamic_body_players[i].position
                )

        self.level.action()

        # Step the physics simulation
        self.space.step(1/self.fps)

        # Check game state
        self.steps += 1
        terminated = False
        rewards = [0] * self.num_players
        player_velocities = []

        # Check collision reward if available
        collision_occurred = getattr(self.level, 'collision_occurred', False)
        collision_impulse_1 = getattr(self.level, 'collision_impulse_1', 0)
        collision_impulse_2 = getattr(self.level, 'collision_impulse_2', 0)

        # Check if balls fall off screen and calculate rewards
        alive_count = 0
        platform_center_x = self.kinematic_body_platforms[0].position[0]
        
        for i, player in enumerate(self.dynamic_body_players):
            if not self.player_alive[i]:
                continue

            ball_x = player.position[0]
            ball_y = player.position[1]
            player_velocities.append(player.velocity)

            # Check if player falls
            if (ball_y > self.kinematic_body_platforms[0].position[1] + 50 or
                ball_x < 0 or ball_x > self.window_x):

                self.player_alive[i] = False
                self.players_fell_this_step[i] = True
                rewards[i] = self.penalty_falling

                if self.sound_enabled and self.sound_fall:
                    self.sound_fall.play()
            else:
                alive_count += 1
                
                # 基礎生存獎勵
                survival_reward = self.reward_staying_alive

                # 速度獎勵 - 鼓勵保持移動
                current_speed = abs(player.velocity[0]) + abs(player.velocity[1])
                speed_reward = min(current_speed * self.speed_reward_multiplier, 0.1)  # 限制最大速度獎勵
                
                # 平台中心獎勵 - 鼓勵靠近平台中心但不要太極端
                center_reward = self._calculate_center_reward(ball_x)
                
                # 平台距離懲罰 - 距離平台中心太遠會有小懲罰
                distance_from_platform = abs(ball_x - platform_center_x)
                distance_penalty = -min(distance_from_platform * self.platform_distance_penalty, 0.5)

                rewards[i] = survival_reward + speed_reward + center_reward + distance_penalty
                self.score[i] += rewards[i]
                self.last_speeds[i] = current_speed

        # 處理碰撞獎勵 - 更智能的獎勵系統
        if collision_occurred and len(player_velocities) >= 2:
            # 基於碰撞時的衝量差距給予獎勵
            impulse_diff = collision_impulse_1 - collision_impulse_2
            
            # 獎勵較高衝量的玩家，懲罰較低衝量的玩家
            if abs(impulse_diff) > 0.1:  # 避免微小差距的獎勵
                collision_reward_1 = impulse_diff * self.collision_reward * 0.1
                collision_reward_2 = -impulse_diff * self.collision_reward * 0.1
                
                # 限制碰撞獎勵的範圍
                collision_reward_1 = np.clip(collision_reward_1, -2.0, 2.0)
                collision_reward_2 = np.clip(collision_reward_2, -2.0, 2.0)
                
                if self.player_alive[0]:
                    rewards[0] += collision_reward_1
                if self.player_alive[1]:
                    rewards[1] += collision_reward_2
                    
                print(f"Collision rewards: P1: {collision_reward_1:.2f}, P2: {collision_reward_2:.2f}")

        # 處理對手掉落的獎勵
        for i in range(self.num_players):
            if self.player_alive[i]:  # 如果這個玩家還活著
                # 檢查是否有對手在這步掉落
                opponents_fell = any(self.players_fell_this_step[j] for j in range(self.num_players) if j != i)
                if opponents_fell:
                    rewards[i] += self.opponent_fall_bonus  # 獲得擊敗對手的獎勵
                    print(f"Player {i+1} gets opponent fall bonus: {self.opponent_fall_bonus}")

        # Check if game should end
        if alive_count <= 1 or self.steps >= self.max_step:
            print("Final Scores: ", self.score)
            terminated = True
            self.game_over = True

            # Determine winner (last player alive or highest score)
            if alive_count == 1:
                self.winner = next(i for i in range(self.num_players) if self.player_alive[i])
                # Give bonus to winner
                rewards[self.winner] += self.survival_bonus * self.steps / 100  # 生存時間越長獎勵越多
                print(f"Winner: Player {self.winner + 1}")
            elif alive_count == 0:
                self.winner = None  # Draw
                print("Draw - all players fell")
            else:
                # Game ended due to max steps, winner is highest score
                self.winner = np.argmax(self.score)
                print(f"Time limit reached. Winner by score: Player {self.winner + 1}")

            result = {
                "game_total_duration": f"{time.time() - self.start_time:.2f}",
                "scores": self.score,
                "winner": self.winner,
                "steps": self.steps
            }
            self.recorder.add_no_limit(result)

        return self._get_observation(), rewards, terminated

    def _get_observation(self) -> np.ndarray:
        """Convert game state to observation for RL agent"""
        # update particles and draw them
        screen_data = self.render() # 获取数据

        if self.capture_per_second is not None and self.frame_count % self.capture_per_second == 0:  # Every second at 60 FPS
            pygame.image.save(self.screen, f"capture/frame_{self.frame_count/60}.png")

        self.frame_count += 1
        return screen_data

    def _calculate_center_reward(self, ball_x):
        """Calculate reward based on how close ball is to center"""
        distance_from_center = abs(ball_x - self.window_x/2)
        if distance_from_center < self.reward_width:
            normalized_distance = distance_from_center / self.reward_width
            return self.reward_ball_centered * (1.0 - normalized_distance)
        return 0

    def render(self) -> Optional[np.ndarray]:
        """Render the current game state"""
        if self.render_mode == "headless":
            return None

        # Clear screen with background color
        self.screen.fill(self.BACKGROUND_COLOR)

        # Custom drawing (for indie style)
        self._draw_indie_style()


        # Update display if in human mode
        if self.render_mode == "human":
            # Draw game information
            self._draw_game_info()
            pygame.display.flip()
            self.clock.tick(self.fps)
            return None

        elif self.render_mode == "rgb_array":
            # Return RGB array for gym environment
            return pygame.surfarray.array3d(self.screen)

        elif self.render_mode == "rgb_array_and_human": # todo
            print("rgb_array_and_human mode is not supported yet.")

        elif self.render_mode == "rgb_array_and_human_in_colab":
            self.space.debug_draw(self.draw_options)
            current_time = time.time()
            if current_time - self.last_update_time >= self.update_interval:
                # Convert Pygame surface to an image that can be displayed in Colab
                buffer = BytesIO()
                pygame.image.save(self.screen, buffer, 'PNG')
                buffer.seek(0)
                img_data = base64.b64encode(buffer.read()).decode('utf-8')

                # Update the HTML image
                self.display_handle.update(ipd.HTML(f'''
                    <div id="pygame-output" style="width:100%;">
                        <img id="pygame-img" src="data:image/png;base64,{img_data}" style="width:100%;">
                    </div>
                '''))

                self.last_update_time = current_time
            return pygame.surfarray.array3d(self.screen)
        else:
            pass

    def get(self):
        print("Getting the json memory")
        return self.memory

    def add_no_limit(self, data: float, ):
        """
        Add a records.

        Args:
            role: The role of the sender (e.g., 'user', 'assistant')
            message: The message content
        """
        self.memory["game_records"].append({
            "game_total_duration": data,
            "timestamp": str(datetime.datetime.now())
        })

        self.save(self.json_file_path)

    def save(self, file_path):
        try:
            with open(file_path, 'w') as f:
                json.dump(self.memory, f)
        except Exception as e:
            print(f"Error saving memory to {file_path}: {e}")

    def load(self, file_path):
        """Read a JSON record store from ``file_path``.

        Args:
            file_path: Path of the JSON file to read.

        Returns:
            The decoded JSON content, or a fresh ``{"game_records": []}`` when
            the file cannot be opened or parsed (the error is printed).
        """
        try:
            with open(file_path, 'r') as handle:
                loaded = json.load(handle)
        except Exception as err:
            print(f"Error loading memory from {file_path}: {err}")
            loaded = {"game_records": []}
        return loaded

    def get_newest_record_name(self) -> str:
        """Return the record-set name for the current month.

        Returns:
            ``"record_"`` followed by the current local year and month as
            ``YYYY-MM``, e.g. ``"record_2022-01"``.
        """
        return f"record_{datetime.datetime.now():%Y-%m}"

    def _draw_indie_style(self):
        """Draw game objects with indie game aesthetic"""
        # # Draw platform with gradient and glow
        for i in range(len(self.dynamic_body_players)):
            ball_pos = (int(self.dynamic_body_players[i].position[0]), int(self.dynamic_body_players[i].position[1]))
            pygame.draw.circle(self.screen, self.players_color[i], ball_pos, self.ball_radius)
            pygame.draw.circle(self.screen, (255, 255, 255), ball_pos, self.ball_radius, 2)

        for platform in self.kinematic_body_platforms:
            if (platform["platform_shape"] == "rectangle"): # TODO 變數名不清晰
                platform_points = []
                for v in self.platform_shape.get_vertices():
                    x, y = v.rotated(platform.angle) + platform.position
                    platform_points.append((int(x), int(y)))

                pygame.draw.polygon(self.screen, self.PLATFORM_COLOR, platform_points)
                pygame.draw.polygon(self.screen, (255, 255, 255), platform_points, 2)
            else:  # Circle platform
                platform_pos = (int(platform.position[0]), int(platform.position[1]))
                pygame.draw.circle(self.screen, self.PLATFORM_COLOR, platform_pos, self.platform_length)
                pygame.draw.circle(self.screen, (255, 255, 255), platform_pos, self.platform_length, 2)

            # Draw rotation direction indicator
            self._draw_rotation_indicator(platform_pos, self.platform_length, platform.angular_velocity, platform)


    def _draw_rotation_indicator(self, position, radius, angular_velocity, body):
        """Draw an indicator showing the platform's rotation direction and speed"""
        # Only draw the indicator if there's some rotation
        if abs(angular_velocity) < 0.1:
            return

        # Calculate indicator properties based on angular velocity
        indicator_color = (50, 255, 150) if angular_velocity > 0 else (255, 150, 50)
        num_arrows = min(3, max(1, int(abs(angular_velocity))))
        indicator_radius = radius - 20  # Place indicator inside the platform

        # Draw arrow indicators along the platform's circumference
        start_angle = body.angle

        for i in range(num_arrows):
            # Calculate arrow position
            arrow_angle = start_angle + i * (2 * np.pi / num_arrows)

            # Calculate arrow start and end points
            base_x = position[0] + int(np.cos(arrow_angle) * indicator_radius)
            base_y = position[1] + int(np.sin(arrow_angle) * indicator_radius)

            # Determine arrow direction based on angular velocity
            if angular_velocity > 0:  # Clockwise
                arrow_end_angle = arrow_angle + 0.3
            else:  # Counter-clockwise
                arrow_end_angle = arrow_angle - 0.3

            tip_x = position[0] + int(np.cos(arrow_end_angle) * (indicator_radius + 15))
            tip_y = position[1] + int(np.sin(arrow_end_angle) * (indicator_radius + 15))

            # Draw arrow line
            pygame.draw.line(self.screen, indicator_color, (base_x, base_y), (tip_x, tip_y), 3)

            # Draw arrowhead
            arrowhead_size = 7
            pygame.draw.circle(self.screen, indicator_color, (tip_x, tip_y), arrowhead_size)

    def _draw_game_info(self):
        """Draw game information on screen"""
        # Create texts
        time_text = f"Time: {time.time() - self.start_time:.1f}"
        score_texts = [f"P{i+1}: {self.score[i]:.1f}" for i in range(self.num_players)]

        # Render texts
        time_surface = self.font.render(time_text, True, (255, 255, 255))
        score_surfaces = [self.font.render(text, True, (255, 255, 255)) for text in score_texts]

        # Draw text backgrounds and texts
        pygame.draw.rect(self.screen, (0, 0, 0, 128),
                        (5, 5, time_surface.get_width() + 10, time_surface.get_height() + 5))
        self.screen.blit(time_surface, (10, 10))

        # Draw scores
        y_offset = 40
        for i, surface in enumerate(score_surfaces):
            color = self.players_color[i] if i < len(self.players_color) else (255, 255, 255)
            pygame.draw.rect(self.screen, (0, 0, 0, 128),
                            (5, y_offset, surface.get_width() + 10, surface.get_height() + 5))
            colored_surface = self.font.render(score_texts[i], True, color)
            self.screen.blit(colored_surface, (10, y_offset))
            y_offset += 30

        # Draw game over screen
        if self.game_over:
            if self.winner is not None:
                game_over_text = f"WINNER: Player {self.winner + 1} - Press R to restart"
            else:
                game_over_text = "DRAW - Press R to restart"
            game_over_surface = self.font.render(game_over_text, True, (255, 255, 255))

            # Draw semi-transparent background
            overlay = pygame.Surface((self.window_x, self.window_y), pygame.SRCALPHA)
            overlay.fill((0, 0, 0, 128))
            self.screen.blit(overlay, (0, 0))

            # Draw text
            self.screen.blit(game_over_surface,
                           (self.window_x/2 - game_over_surface.get_width()/2,
                            self.window_y/2 - game_over_surface.get_height()/2))

    def _get_x_axis_max_reward_rate(self, platform_length):
        """
        ((self.platform_length / 2) - 5) for calculate the distance to the
        center of game window coordinates. The closer you are, the higher the reward.

        When the ball is to be 10 points away from the center coordinates,
        it should be 1 - ((self.platform_length - 10) * self.x_axis_max_reward_rate)
        """
        self.reward_width = (platform_length / 2) - 5
        self.x_axis_max_reward_rate = 2 / self.reward_width
        print("self.x_axis_max_reward_rate: ", self.x_axis_max_reward_rate)

    def _reward_calculator(self, ball_x):
        # score & reward
        step_reward = 1/100

        rw = abs(ball_x - self.window_x/2)
        if rw < self.reward_width:
            x_axis_reward_rate = 1 + ((self.reward_width - abs(ball_x - self.window_x/2)) * self.x_axis_max_reward_rate)
            step_reward = self.steps * 0.01 * x_axis_reward_rate  # Simplified reward calculation

            if self.steps % 500 == 0:
                step_reward += self.steps/100
                print("check point: ", self.steps/500)

            return step_reward
        else:
            return 0

    def close(self):
        """Close the game and clean up resources"""
        if self.render_mode in ["human", "rgb_array"]:
            pygame.quit()

    def run_standalone(self):
        """Run the game in standalone mode with keyboard controls"""
        if self.render_mode not in ["human", "rgb_array_and_human_in_colab"]:
            raise ValueError("Standalone mode requires render_mode='human' or 'rgb_array_and_human_in_colab'")

        running = True
        while running:
            # Handle events
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    running = False
                elif event.type == pygame.KEYDOWN:
                    if event.key == pygame.K_r and self.game_over:
                        self.reset()

            # Process keyboard controls for continuous actions
            keys = pygame.key.get_pressed()
            actions = []

            # Player 1 controls (Arrow keys)
            if keys[pygame.K_LEFT]:
                actions.append(-1.0)  # Full left force
            elif keys[pygame.K_RIGHT]:
                actions.append(1.0)   # Full right force
            else:
                actions.append(0.0)   # No force

            # Player 2 controls (WASD)
            if len(self.dynamic_body_players) > 1:
                if keys[pygame.K_a]:
                    actions.append(-1.0)  # Full left force
                elif keys[pygame.K_d]:
                    actions.append(1.0)   # Full right force
                else:
                    actions.append(0.0)   # No force

            # Take game step
            if not self.game_over:
                self.step(actions)

            # Render
            self.render()

        self.close()

## GYM env

In [None]:
import gymnasium as gym
import numpy as np
from gymnasium import spaces
import cv2

# from balancing_ball_game import BalancingBallGame

class BalancingBallEnv(gym.Env):
    """
    Gymnasium environment for the Balancing Ball game with continuous action space.

    One Box action in [-1.0, 1.0] per player is forwarded to ``BalancingBallGame``.
    The per-player rewards returned by the game are summed into the single scalar
    reward Gymnasium expects; the individual values stay available through
    ``info['individual_rewards']``.

    ``step`` and ``reset`` are bound in ``__init__`` to the ``*_game_screen`` or
    ``*_state_based`` implementations, depending on ``obs_type``.
    """
    metadata = {'render_modes': ['human', 'rgb_array', 'rgb_array_and_human_in_colab']}

    def __init__(self,
                 render_mode="rgb_array",
                 difficulty="medium",
                 level=3,  # Default to level 3 for adversarial training
                 fps=30,
                 obs_type="game_screen",
                 image_size=(84, 84),
                 num_players=2,
                ):
        """
        render_mode: how to render the environment
            Example: "human" or "rgb_array"
        difficulty: difficulty setting forwarded to the game
        level: game level forwarded to the game
        fps: Frames per second,
            Example: 30
        obs_type: type of observation
            Example: "game_screen" or "state_based"
        image_size: Size to resize images to (height, width)
            Example: (84, 84) - standard for many RL implementations
        num_players: Number of players (2 for adversarial training)

        Raises:
            ValueError: If obs_type is neither "game_screen" nor "state_based".
        """
        super().__init__()

        # Fail fast: reject a bad obs_type before the game object is constructed.
        if obs_type not in ("game_screen", "state_based"):
            raise ValueError("obs_type must be 'game_screen' or 'state_based'")

        self.num_players = num_players

        # Action space: one horizontal force in [-1.0, 1.0] per player
        self.action_space = spaces.Box(low=-1.0, high=1.0, shape=(num_players,), dtype=np.float32)

        # Initialize game
        self.window_x = 300
        self.window_y = 180
        self.platform_shape = "circle"
        self.platform_proportion = 0.333

        # Image preprocessing settings
        self.image_size = image_size

        self.stack_size = 3  # Number of frames to stack
        self.observation_stack = []  # Most recent frames, oldest first
        self.render_mode = render_mode

        self.game = BalancingBallGame(
            render_mode=render_mode,
            sound_enabled=(render_mode == "human"),
            difficulty=difficulty,
            window_x = self.window_x,
            window_y = self.window_y,
            level = level,
            fps = fps,
            platform_shape = self.platform_shape,
            platform_proportion = self.platform_proportion,
        )

        if obs_type == "game_screen":
            channels = 1  # frames are converted to grayscale

            # Image observation space with stacked frames
            self.observation_space = spaces.Box(
                low=0, high=255,
                shape=(self.image_size[0], self.image_size[1], channels * self.stack_size),
                dtype=np.uint8,
            )
            self.step = self.step_game_screen
            self.reset = self.reset_game_screen
        else:  # "state_based" (validated above)
            # State-based observation space for multi-player:
            # [ball1_x, ball1_y, ball1_vx, ball1_vy, ball2_x, ball2_y, ball2_vx, ball2_vy, platform_x, platform_y, platform_angular_velocity]
            obs_size = 4 * num_players + 3  # 4 values per player + 3 platform values
            self.observation_space = spaces.Box(
                low=np.full(obs_size, -1.0),
                high=np.full(obs_size, 1.0),
                dtype=np.float32
            )
            self.step = self.step_state_based
            self.reset = self.reset_state_based

    def _normalize_action(self, action):
        """Return ``action`` as a sequence with exactly ``num_players`` entries.

        Scalars (Python numbers, NumPy scalars, 0-d arrays) have no ``len()``
        and become a one-element list first. Missing entries are padded with
        0.0 (no force) and surplus entries are dropped.
        """
        if np.ndim(action) == 0:
            action = [float(action)]

        if len(action) < self.num_players:
            action = list(action) + [0.0] * (self.num_players - len(action))
        elif len(action) > self.num_players:
            action = action[:self.num_players]
        return action

    def _summarize_step(self, step_rewards):
        """Collapse per-player rewards into ``(total_reward, info)``.

        The sum is used so a single agent can be trained on the multi-player
        game; the per-player values are kept in ``info``.
        """
        per_player = isinstance(step_rewards, list)
        total_reward = sum(step_rewards) if per_player else step_rewards
        info = {
            'individual_rewards': step_rewards if per_player else [step_rewards],
            'winner': getattr(self.game, 'winner', None),
            'scores': getattr(self.game, 'score', [0])
        }
        return total_reward, info

    def _preprocess_observation(self, observation):
        """Process raw game observation for RL training

        Args:
            observation: RGB image from the game

        Returns:
            Grayscale frame of shape (image_size[0], image_size[1], 1)

        Raises:
            ValueError: If the game produced no frame (``observation`` is None),
                which happens for render modes that do not return an array.
        """
        if observation is None:
            raise ValueError(
                "obs_type='game_screen' needs an RGB frame from the game, but the game "
                f"returned None for render_mode={self.render_mode!r}"
            )

        # Swap the first two axes: pygame.surfarray arrays are indexed
        # (x, y, channel) while cv2 expects (row, column, channel).
        frame = np.transpose(observation, (1, 0, 2))
        gray = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)

        # Resize to target size; cv2.resize takes (width, height)
        target_h, target_w = self.image_size[0], self.image_size[1]
        if gray.shape[0] != target_h or gray.shape[1] != target_w:
            gray = cv2.resize(gray, (target_w, target_h), interpolation=cv2.INTER_AREA)

        return np.expand_dims(gray, axis=-1)  # Add channel dimension back

    def step_game_screen(self, action):
        """Take a step in the environment with continuous actions

        Returns:
            (stacked_obs, total_reward, terminated, truncated, info) where
            stacked_obs holds the last ``stack_size`` grayscale frames along the
            channel axis and truncated is always False.
        """
        action = self._normalize_action(action)

        # Take step in the game
        obs, step_rewards, terminated = self.game.step(action)
        frame = self._preprocess_observation(obs)

        # Slide the frame window: append the newest, drop the oldest
        self.observation_stack.append(frame)
        if len(self.observation_stack) > self.stack_size:
            self.observation_stack.pop(0)

        # If the stack isn't full yet (step before reset), pad with the current frame
        while len(self.observation_stack) < self.stack_size:
            self.observation_stack.insert(0, frame)

        stacked_obs = np.concatenate(self.observation_stack, axis=-1)
        total_reward, info = self._summarize_step(step_rewards)

        # Gymnasium expects (observation, reward, terminated, truncated, info)
        return stacked_obs, total_reward, terminated, False, info

    def reset_game_screen(self, seed=None, options=None):
        """Reset the environment and return (stacked_obs, info)"""
        super().reset(seed=seed)  # This properly seeds the environment in Gymnasium

        observation = self._preprocess_observation(self.game.reset())

        # Fill the whole stack with the initial frame
        self.observation_stack = [observation for _ in range(self.stack_size)]
        stacked_obs = np.concatenate(self.observation_stack, axis=-1)

        info = {}
        return stacked_obs, info

    @staticmethod
    def _to_unit_range(value, extent):
        """Map ``value`` from [0, extent] to [-1, 1], clipping anything outside.

        Clipping keeps positions of bodies that left the window (for example a
        ball that fell off) inside the bounds declared by ``observation_space``.
        """
        return np.clip(value / extent * 2 - 1, -1, 1)

    def _get_state_based_observation(self):
        """Convert game state to state-based observation for RL agent

        Layout: [x, y, vx, vy] for every player body of the game, followed by
        [platform_x, platform_y, platform_angular_velocity]; every value lies
        in [-1, 1].
        """
        max_velocity = 1000          # assumed upper bound used for normalization
        max_angular_velocity = 10    # assumed upper bound used for normalization

        obs = []

        # NOTE(review): this iterates over the game's player bodies while the
        # observation space was sized from num_players — confirm they agree.
        for player_body in self.game.dynamic_body_players:
            obs.extend([
                self._to_unit_range(player_body.position[0], self.window_x),
                self._to_unit_range(player_body.position[1], self.window_y),
                np.clip(player_body.velocity[0] / max_velocity, -1, 1),
                np.clip(player_body.velocity[1] / max_velocity, -1, 1),
            ])

        # Only the first platform is observed
        platform_body = self.game.kinematic_body_platforms[0]
        obs.extend([
            self._to_unit_range(platform_body.position[0], self.window_x),
            self._to_unit_range(platform_body.position[1], self.window_y),
            np.clip(platform_body.angular_velocity / max_angular_velocity, -1, 1),
        ])

        return np.array(obs, dtype=np.float32)

    def step_state_based(self, action):
        """Take a step in the environment with state-based observations

        Returns:
            (observation, total_reward, terminated, truncated, info) where
            truncated is always False.
        """
        action = self._normalize_action(action)

        # Take step in the game; the rendered frame is not needed here
        _, step_rewards, terminated = self.game.step(action)

        observation = self._get_state_based_observation()
        total_reward, info = self._summarize_step(step_rewards)

        # Gymnasium expects (observation, reward, terminated, truncated, info)
        return observation, total_reward, terminated, False, info

    def reset_state_based(self, seed=None, options=None):
        """Reset the environment and return (observation, info)"""
        super().reset(seed=seed)  # This properly seeds the environment in Gymnasium

        self.game.reset()
        observation = self._get_state_based_observation()

        info = {}
        return observation, info

    def render(self):
        """Render the environment"""
        return self.game.render()

    def close(self):
        """Clean up resources"""
        self.game.close()

## Test

In [None]:
# from balancing_ball_game import BalancingBallGame

def run_standalone_game(render_mode="human", difficulty="medium", capture_per_second=3, window_x=1000, window_y=600, level=3):
    """Run the game in standalone mode with visual display.

    Builds a ``BalancingBallGame`` with a circular platform (proportion 0.333)
    at 30 FPS and hands control to its keyboard-driven loop.

    Args:
        render_mode: Render mode forwarded to the game.
        difficulty: Difficulty setting forwarded to the game.
        capture_per_second: Capture setting forwarded to the game. It used to
            be ignored in favor of a hard-coded 3, which is still the default.
        window_x: Window width in pixels.
        window_y: Window height in pixels.
        level: Game level forwarded to the game.
    """
    game = BalancingBallGame(
        render_mode=render_mode,
        difficulty=difficulty,
        window_x=window_x,
        window_y=window_y,
        platform_shape="circle",
        platform_proportion=0.333,
        level=level,
        fps=30,
        capture_per_second=capture_per_second,
    )

    game.run_standalone()

def test_gym_env(episodes=3, difficulty="medium"):
    """Test the OpenAI Gym environment with continuous actions"""
    import time
    # from gym_env import BalancingBallEnv

    fps = 30
    env = BalancingBallEnv(
        render_mode="rgb_array_and_human_in_colab",
        difficulty=difficulty,
        fps=fps,
        level=3,  # Use level 3 for adversarial training
        num_players=2,
    )

    for episode in range(episodes):
        observation, info = env.reset()
        total_reward = 0
        step = 0
        done = False

        while not done:
            # Sample continuous actions for both players
            action = env.action_space.sample()  # Returns array of shape (2,) with values in [-1, 1]

            # Take step
            observation, reward, terminated, truncated, info = env.step(action)

            done = terminated or truncated
            total_reward += reward
            step += 1

            # Render
            env.render()

            # Print some info
            if step % 100 == 0:
                print(f"Step {step}: Action: {action}, Reward: {reward:.2f}, Individual Rewards: {info.get('individual_rewards', [])}")

        winner = info.get('winner', None)
        winner_text = f"Winner: Player {winner + 1}" if winner is not None else "Draw"
        print(f"Episode {episode+1}: Steps: {step}, Total Reward: {total_reward:.2f}, {winner_text}")

    env.close()

## Train

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import gymnasium as gym
import sys
import optuna

from stable_baselines3 import PPO
from stable_baselines3.common.policies import ActorCriticPolicy, ActorCriticCnnPolicy  # MLP policy instead of CNN
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.callbacks import CheckpointCallback, EvalCallback
from stable_baselines3.common.evaluation import evaluate_policy

class Train:
    """Build, train and evaluate a PPO agent on ``BalancingBallEnv``.

    The constructor creates a vectorised training environment, a single-env
    evaluation environment and a PPO model (loaded from disk or freshly
    created). ``train_ppo`` runs training with checkpoint and evaluation
    callbacks; ``evaluate`` scores a saved model.
    """

    def __init__(self,
                 learning_rate=0.0003,
                 n_steps=2048,
                 batch_size=64,
                 n_epochs=10,
                 gamma=0.99,
                 gae_lambda=0.95,
                 ent_coef=0.01,
                 vf_coef=0.5,
                 max_grad_norm=0.5,
                 policy_kwargs=None,
                 n_envs=4,
                 difficulty="medium",
                 level=3,  # Default to level 3 for adversarial training
                 load_model=None,
                 log_dir="./logs/",
                 model_dir="./models/",
                 obs_type="game_screen",
                 num_players=2,  # Number of players for adversarial training
                ):
        """Create the environments and the PPO model.

        Args:
            learning_rate, gamma, gae_lambda, ent_coef, vf_coef: Accepted for
                interface compatibility. NOTE(review): a freshly created model
                uses the tuned values hard-coded below rather than these
                arguments - confirm whether that is intended.
            n_steps, batch_size, n_epochs, max_grad_norm: Forwarded to ``PPO``
                when a new model is created.
            policy_kwargs: Policy keyword arguments for a new model. When
                ``None`` a default architecture is built that matches
                ``obs_type``.
            n_envs: Number of parallel training environments.
            difficulty: Difficulty forwarded to ``BalancingBallEnv``.
            level: Game level forwarded to ``BalancingBallEnv``.
            load_model: Path of a saved PPO model to continue from; when
                falsy a new model is created.
            log_dir: Directory for tensorboard/evaluation logs (created if
                missing).
            model_dir: Directory for checkpoints and final model (created if
                missing).
            obs_type: ``"game_screen"`` selects ``ActorCriticCnnPolicy``; any
                other value selects ``ActorCriticPolicy``.
            num_players: Number of players forwarded to ``BalancingBallEnv``.
        """
        # Create directories
        os.makedirs(log_dir, exist_ok=True)
        os.makedirs(model_dir, exist_ok=True)
        self.log_dir = log_dir
        self.model_dir = model_dir
        self.n_envs = n_envs
        self.obs_type = obs_type
        self.level = level
        self.num_players = num_players

        # Setup training environments
        self.env = make_vec_env(
            self.make_env(render_mode="rgb_array", difficulty=difficulty, obs_type=obs_type, num_players=num_players),
            n_envs=n_envs
        )

        # Setup evaluation environment (always a single env)
        self.eval_env = make_vec_env(
            self.make_env(render_mode="rgb_array", difficulty=difficulty, obs_type=obs_type, num_players=num_players),
            n_envs=1
        )

        # Create the PPO model
        if load_model:
            print(f"Loading model from {load_model}")
            self.model = PPO.load(
                load_model,
                env=self.env,
                tensorboard_log=log_dir,
            )
        else:
            # Tuned hyperparameters, chosen specifically for adversarial training
            hyper_param = {
                'learning_rate': 0.0001,  # lower learning rate for stability
                'gamma': 0.995,  # higher discount factor to weight long-term reward
                'clip_range': 0.15,  # narrower clip range for stability
                'gae_lambda': 0.98,  # higher GAE lambda
                'ent_coef': 0.02,  # higher entropy coefficient for more exploration
                'vf_coef': 0.5,
            }

            is_image_obs = obs_type == "game_screen"

            if policy_kwargs is None:
                policy_kwargs = {
                    "net_arch": [512, 512, 256],  # deeper network for complex strategies
                    "activation_fn": torch.nn.ReLU,
                }
                if is_image_obs:
                    # `features_dim` is an argument of the CNN feature
                    # extractor only. The flatten extractor used by the
                    # non-image policy does not accept it, so passing it
                    # unconditionally breaks model creation for other
                    # observation types.
                    policy_kwargs["features_extractor_kwargs"] = {"features_dim": 512}

            policy = ActorCriticCnnPolicy if is_image_obs else ActorCriticPolicy
            print("obs type: ", self.obs_type)
            print("policy: ", policy)
            print("num_players: ", self.num_players)

            # PPO for continuous action space with adversarial training
            self.model = PPO(
                policy=policy,
                env=self.env,
                learning_rate=hyper_param["learning_rate"],
                n_steps=n_steps,
                batch_size=batch_size,
                n_epochs=n_epochs,
                gamma=hyper_param["gamma"],
                clip_range=hyper_param["clip_range"],
                gae_lambda=hyper_param["gae_lambda"],
                ent_coef=hyper_param["ent_coef"],
                vf_coef=hyper_param["vf_coef"],
                max_grad_norm=max_grad_norm,
                tensorboard_log=log_dir,
                policy_kwargs=policy_kwargs,
                verbose=1,
            )

    def make_env(self, render_mode="rgb_array", difficulty="medium", obs_type="game_screen", num_players=2):
        """
        Return a zero-argument factory that builds one ``BalancingBallEnv``.

        The factory is what ``make_vec_env`` calls once per parallel
        environment. The game level is taken from ``self.level``.
        """
        def _init():
            env = BalancingBallEnv(
                render_mode=render_mode,
                difficulty=difficulty,
                level=self.level,
                obs_type=obs_type,
                num_players=num_players
            )
            return env
        return _init

    def train_ppo(self,
                  total_timesteps=1000000,
                  save_freq=10000,
                  eval_freq=10000,
                  eval_episodes=5,
                 ):
        """
        Train the PPO agent to play the Balancing Ball game.

        Args:
            total_timesteps: Total number of steps to train for
            save_freq: How often to save checkpoints (in timesteps)
            eval_freq: How often to evaluate the model (in timesteps)
            eval_episodes: Number of episodes to evaluate on

        Returns:
            The trained model (``self.model``).

        Checkpoints, the best evaluated model and the final model are written
        to ``self.model_dir``; evaluation logs go to ``self.log_dir``.
        """
        # The frequencies are given in timesteps but handed to the callbacks
        # divided by n_envs. Clamp to 1 so a frequency smaller than n_envs
        # does not floor to 0.
        checkpoint_every = max(save_freq // self.n_envs, 1)
        evaluate_every = max(eval_freq // self.n_envs, 1)

        # Setup callbacks
        checkpoint_callback = CheckpointCallback(
            save_freq=checkpoint_every,
            save_path=self.model_dir,
            name_prefix="ppo_balancing_ball_" + str(self.obs_type),
        )

        eval_callback = EvalCallback(
            self.eval_env,
            best_model_save_path=self.model_dir,
            log_path=self.log_dir,
            eval_freq=evaluate_every,
            n_eval_episodes=eval_episodes,
            deterministic=True,
            render=False
        )

        # Train the model
        print("Starting training...")
        self.model.learn(
            total_timesteps=total_timesteps,
            callback=[checkpoint_callback, eval_callback],
        )

        # Save the final model
        final_path = os.path.join(self.model_dir, "ppo_balancing_ball_final_" + str(self.obs_type))
        self.model.save(final_path)

        print("Training completed!")
        return self.model

    def evaluate(self, model_path, n_episodes=10, difficulty="medium"):
        """
        Evaluate a saved model on this trainer's training environment.

        Args:
            model_path: Path to the saved model
            n_episodes: Number of episodes to evaluate on
            difficulty: Currently unused; the environment keeps the difficulty
                given to the constructor.

        Returns:
            Tuple ``(mean_reward, std_reward)`` as reported by
            ``evaluate_policy``.

        Note:
            ``self.env`` is closed when this method finishes (also on error),
            so the trainer cannot be used for further training afterwards.
        """
        # Load the model
        model = PPO.load(model_path)

        try:
            # Evaluate
            mean_reward, std_reward = evaluate_policy(
                model,
                self.env,
                n_eval_episodes=n_episodes,
                deterministic=True,
                render=True
            )

            print(f"Mean reward: {mean_reward:.2f} +/- {std_reward:.2f}")
            return mean_reward, std_reward
        finally:
            self.env.close()


# if args.mode == "train":
#     train_ppo(
#         total_timesteps=args.timesteps,
#         difficulty=args.difficulty,
#         n_envs=args.n_envs,
#         load_model=args.load_model,
#         eval_episodes=args.eval_episodes,
#     )
# else:
#     if args.load_model is None:
#         print("Error: Must provide --load_model for evaluation")
#     else:
#         evaluate(
#             model_path=args.load_model,
#             n_episodes=args.eval_episodes,
#             difficulty=args.difficulty
#         )

## Optuna

In [None]:
class Optuna_optimize:
    """Optuna-driven PPO hyperparameter search on ``BalancingBallEnv``.

    One single-env vectorised environment is created in the constructor and
    shared by every trial for both training and evaluation.
    """

    def __init__(self, obs_type="game_screen", num_players=2):
        """Store the observation type / player count and build the shared env."""
        self.obs_type = obs_type
        self.num_players = num_players
        self.env = make_vec_env(
            self.make_env(render_mode="rgb_array", difficulty="medium", obs_type=self.obs_type, num_players=num_players),
            n_envs=1
        )

    def make_env(self, render_mode="rgb_array", difficulty="medium", obs_type="game_screen", num_players=2):
        """
        Return a zero-argument factory that builds one level-3
        ``BalancingBallEnv``, for use with ``make_vec_env``.
        """
        def _init():
            env = BalancingBallEnv(
                render_mode=render_mode,
                difficulty=difficulty,
                level=3,  # Level 3 for adversarial training
                obs_type=obs_type,
                num_players=num_players
            )
            return env
        return _init

    def optuna_parameter_tuning(self, n_trials):
        """Run an Optuna study over ``objective`` and report the best trial.

        Args:
            n_trials: Number of trials to run.

        Returns:
            The study's best trial (``study.best_trial``), whose ``params``
            and ``value`` are also printed.

        Note:
            The shared environment is closed and removed from the instance
            when this method finishes (also on error), so a new
            ``Optuna_optimize`` is needed for another study.
        """
        print("You are using optuna for automatic parameter tuning, it will create a new model")

        # NOTE(review): `objective` never reports intermediate values
        # (trial.report / trial.should_prune), so this pruner has nothing to
        # act on and trials always run to completion.
        pruner = optuna.pruners.HyperbandPruner(
            min_resource=100,        # minimum resource
            max_resource='auto',   # maximum resource ('auto' or an integer)
            reduction_factor=3     # reduction factor (eta)
        )

        # Create the study object with the pruner attached
        study = optuna.create_study(direction='maximize', pruner=pruner)

        # Run the optimisation
        try:
            study.optimize(self.objective, n_trials=n_trials)

            # Analyse the results
            best_trial = study.best_trial
            print("最佳試驗的超參數：", best_trial.params)
            print("最佳試驗的平均回報：", best_trial.value)

            df = study.trials_dataframe()
            print(df.head())

            # Hand the result back so callers can use the tuned parameters.
            return best_trial
        finally:
            self.env.close()
            del self.env


    def objective(self, trial):
        """Train and score one PPO model for a single Optuna trial.

        Samples the PPO hyperparameters from ``trial``, trains for 50000
        timesteps on the shared environment and returns the mean reward over
        10 evaluation episodes.
        """
        import gc

        # 1. Suggest hyperparameters - adjusted for continuous action space
        learning_rate = trial.suggest_float('learning_rate', 1e-5, 1e-3, log=True)
        gamma = trial.suggest_float('gamma', 0.95, 0.999)
        clip_range = trial.suggest_float('clip_range', 0.1, 0.3)
        gae_lambda = trial.suggest_float('gae_lambda', 0.8, 0.99)
        ent_coef = trial.suggest_float('ent_coef', 0.005, 0.02)  # Lower for continuous actions
        vf_coef = trial.suggest_float('vf_coef', 0.1, 1)
        # features_dim = trial.suggest_categorical('features_dim', [128, 256, 512])

        policy_kwargs = {
            # "features_extractor_kwargs": {"features_dim": features_dim},
            "net_arch": [256, 256],  # Architecture for continuous actions
        }

        # Fixed (not tuned) PPO settings
        n_steps=2048
        batch_size=64
        n_epochs=10
        max_grad_norm=0.5

        policy = ActorCriticCnnPolicy if self.obs_type == "game_screen" else ActorCriticPolicy
        print("obs type: ", self.obs_type)
        print("policy: ", policy)

        # 3. Build the model - PPO for continuous action space
        model = PPO(
                policy=policy,
                env=self.env,
                learning_rate=learning_rate,
                n_steps=n_steps,
                batch_size=batch_size,
                n_epochs=n_epochs,
                gamma=gamma,
                clip_range=clip_range,
                gae_lambda=gae_lambda,
                ent_coef=ent_coef,
                vf_coef=vf_coef,
                max_grad_norm=max_grad_norm,
                tensorboard_log=None,
                policy_kwargs=policy_kwargs,
                verbose=0,
            )

        try:
            # 4. Train the model
            model.learn(total_timesteps=50000)  # Increased timesteps for adversarial training
            # 5. Evaluate the model
            mean_reward = evaluate_policy(model, self.env, n_eval_episodes=10)[0]
        finally:
            # Always cleanup so models from earlier trials do not accumulate
            del model
            gc.collect()

            # TPU_AVAILABLE is a notebook-level flag set by the setup cell.
            if TPU_AVAILABLE:
                import torch_xla.core.xla_model as xm
                xm.mark_step()

        return mean_reward

# Training

In [None]:
import gc

# Training driver cell: either runs an Optuna hyperparameter search or trains
# a PPO agent directly, depending on `do_optimization` below.
def get_tpu_memory_info():
    """Placeholder for reporting TPU memory information.

    Currently a no-op: the body is `pass`, so nothing is printed and None is
    returned.
    """
    pass

# Call the (currently no-op) memory report
get_tpu_memory_info()

# Settings for the direct-training branch. These are notebook-level names;
# `batch_size` is also read by the later evaluation cell.
n_envs = 1
batch_size = 64
n_steps = 2048

# Choose whether to do hyperparameter optimization or direct training
do_optimization = True

if do_optimization: # obs_type options: game_screen, state_based
    # Hyperparameter search: n_envs / batch_size / n_steps above are not
    # passed to the optimizer in this branch.
    optuna_optimizer = Optuna_optimize(obs_type="state_based", num_players=2)
    n_trials = 10
    best_trial = optuna_optimizer.optuna_parameter_tuning(n_trials=n_trials)
    # Print whatever the tuner handed back
    print(f"best_trial found: {best_trial}")
else:
    # Create trainer for adversarial training
    training = Train(
        n_steps=n_steps,
        batch_size=batch_size,
        difficulty="medium",
        n_envs=n_envs,
        level=3,  # Level 3 for adversarial training
        load_model=None,  # Start fresh for adversarial training
        obs_type='game_screen',
        num_players=2,  # Two players for adversarial training
    )

    # Run training with continuous action space
    total_timesteps = 1000000  # More timesteps for adversarial training

    # Checkpoint and evaluate every 10000 timesteps; train_ppo returns the
    # trained model.
    model = training.train_ppo(
        total_timesteps=total_timesteps,
        eval_episodes=5,
        save_freq=10000,
        eval_freq=10000
    )

    print("Adversarial training completed!")

  gym.logger.warn(
  gym.logger.warn(
[I 2025-05-25 11:43:30,804] A new study created in memory with name: no-name-cf79fa37-826e-4068-898d-e08b684502d8


The json memory file does not exist. Creating new file.
self.x_axis_max_reward_rate:  0.0449438202247191
You are using optuna for automatic parameter tuning, it will create a new model
obs type:  state_based
policy:  <class 'stable_baselines3.common.policies.ActorCriticPolicy'>




[1;30;43m流式输出内容被截断，只能显示最后 5000 行内容。[0m
Score:  [71.5103851894361, 64.54027205228513]
Score:  [90.39725099429307, 110.88679703357585]
Score:  [49.75308252358208, 89.3075658558906]
Score:  [21.50306249000874, 104.03911789945634]
Score:  [117.17661213774251, 34.71912151611687]
Score:  [98.08142717786336, 65.53466543716486]
Score:  [86.87598148525804, 80.84931521824133]
Score:  [86.58965865116053, 58.16426457366322]
Score:  [83.72798354286031, 80.56170713594156]
Score:  [66.04723297629107, 90.34738276454482]
Score:  [87.61823187909509, 112.61663355063448]
Score:  [75.73530085625964, 84.1326205382812]
Score:  [54.13658758184229, 107.15087091355221]
Score:  [101.05884122686092, 53.00314883398961]
Score:  [91.0606880528226, 128.3290146743305]
Score:  [86.6204587499034, 71.00534817050153]
Score:  [113.72910318299839, 76.7748816753822]
Score:  [119.79097672107585, 92.61676759919948]
Score:  [115.71538824574571, 105.20579856877916]
Score:  [88.40545328357585, 45.70147940031814]
Score:  [103.33

[I 2025-05-25 11:59:59,532] Trial 0 finished with value: 219.58195299999997 and parameters: {'learning_rate': 8.494030397369189e-05, 'gamma': 0.9952161147986843, 'clip_range': 0.10552714909933264, 'gae_lambda': 0.8597218726804645, 'ent_coef': 0.013737970368979356, 'vf_coef': 0.72382445893691}. Best is trial 0 with value: 219.58195299999997.


[1;30;43m流式输出内容被截断，只能显示最后 5000 行内容。[0m
Score:  [119.79097672107585, 57.588077064742095]
Score:  [98.1606480266228, 102.09151531066397]
Score:  [89.57817079837162, 145.52592295702974]
Score:  [90.35034760451828, 111.31540829147573]
Score:  [77.399692763803, 119.79097672107585]
Score:  [60.221034884894856, 65.47357342907792]
Score:  [119.79097672107585, 118.29595198149252]
Score:  [90.22733148369736, 80.71155435457041]
Score:  [98.84091314939435, 117.11482709964713]
Score:  [80.69619677068151, 101.09148763004424]
Score:  [94.14030355484508, 74.95600650586844]
Score:  [119.79097672107585, 85.95654441459206]
Score:  [104.59268069242479, 85.13230847542381]
Score:  [118.5986459919092, 75.24787391220583]
Score:  [82.04923679761093, 117.26540250232586]
Score:  [108.74398208948446, 100.72193298115546]
Score:  [74.12522290381324, 83.86110955871271]
Score:  [59.270428116352036, 114.33586831462969]
Score:  [119.79097672107585, 115.846483054684]
Score:  [92.75375856742352, 109.48716531815205]
Sco

[I 2025-05-25 12:46:36,850] Trial 1 finished with value: 219.58195299999997 and parameters: {'learning_rate': 6.61449871952151e-05, 'gamma': 0.9670777688990231, 'clip_range': 0.12019069014002766, 'gae_lambda': 0.8036593222765015, 'ent_coef': 0.00726002921357752, 'vf_coef': 0.10684985908166082}. Best is trial 0 with value: 219.58195299999997.


[1;30;43m流式输出内容被截断，只能显示最后 5000 行内容。[0m
Score:  [103.1791304330306, 97.0069229953962]
Score:  [90.14598062732586, 108.3812463949287]
Score:  [88.82094006830951, 110.37930139037238]
Score:  [121.0055809086085, 67.23627216261508]
Score:  [119.79097672107585, 128.60568378310825]
Score:  [79.72444044645046, 106.90242545260313]
Score:  [126.7933055343228, 129.90707478767035]
Score:  [103.55825752785586, 105.92779203911421]
Score:  [119.79097672107585, 104.52109521065918]
Score:  [59.57406809266782, 105.90408330988072]
Score:  [119.79097672107584, 104.14562082879388]
Score:  [102.36193331376776, 57.772851120156034]
Score:  [113.30197881241389, 83.14681949162954]
Score:  [113.86628816601956, 88.57847293858967]
Score:  [73.34539607739488, 106.13347639555502]
Score:  [117.30858841596017, 109.49773713774252]
Score:  [92.57780977535515, 119.79097672107585]
Score:  [76.5285877470083, 95.7068541593407]
Score:  [75.33204331975607, 108.85635172107584]
Score:  [106.7129864934571, 103.72806741346051]


[I 2025-05-25 14:04:37,708] Trial 2 finished with value: 219.58195299999997 and parameters: {'learning_rate': 5.5711220721563506e-05, 'gamma': 0.9591281800578875, 'clip_range': 0.2285640361778947, 'gae_lambda': 0.9430857346999462, 'ent_coef': 0.0173932398969839, 'vf_coef': 0.4905456714177241}. Best is trial 0 with value: 219.58195299999997.


[1;30;43m流式输出内容被截断，只能显示最后 5000 行内容。[0m
Score:  [118.2696169890943, 64.83688514511782]
Score:  [30.84688288154154, 54.513847917726366]
Score:  [76.53673235230299, 71.9508857286138]
Score:  [84.72904113994447, 54.141294285376084]
Score:  [70.71415042141842, 52.58708259987889]
Score:  [77.38617215319945, 65.1333248077811]
Score:  [74.91779663033579, 104.56234478195896]
Score:  [53.13393547810166, 64.91917438053426]
Score:  [49.03440970720074, 56.15790563472786]
Score:  [90.96600046063362, 68.65856550526026]
Score:  [81.55310839417568, 103.82883058144921]
Score:  [101.37031235222011, 76.52934686294633]
Score:  [98.85917328254001, 91.64690268372235]
Score:  [90.74844817931643, 83.97775405241141]
Score:  [43.35352410380575, 94.76550386900318]
Score:  [71.2850899989836, 58.688296572866605]
Score:  [48.361857520743754, 70.3435414244575]
Score:  [90.27093240380489, 91.2776055542588]
Score:  [70.48862386330268, 99.54131432754366]
Score:  [120.88201849633877, 99.9549483901105]
Score:  [81.76953

In [None]:
# Copy the best model to a stable location, with a timestamp in the file name
# so earlier copies are not overwritten.
# NOTE(review): the destination is under /content/drive, which presumably only
# exists once Google Drive is mounted (the mount cell near the top of the
# notebook is commented out) and the RL_Models folder has been created - confirm
# before relying on this copy.
# NOTE(review): /content/models/best_model.zip is presumably written by the
# evaluation callback during direct training; it will not exist after an
# Optuna-only run - verify.
!cp /content/models/best_model.zip /content/drive/MyDrive/RL_Models/best_model_$(date +%Y%m%d_%H%M%S).zip

# Optional: Monitor TPU usage by listing processes that have the accelerator
# device file open. TPU_AVAILABLE is set by the setup cell at the top.
if TPU_AVAILABLE:
    !sudo lsof -w /dev/accel0

In [None]:
# Evaluate a previously saved model, if one exists on disk.
# Relies on `os` and `batch_size` from earlier cells.
model_path = "/content/models/best_model.zip"

if os.path.exists(model_path):
    print(f"Loading model from {model_path} for evaluation")

    # Build a trainer only to obtain an environment to evaluate on. No
    # `load_model` is given, so the constructor also creates a fresh PPO model;
    # the saved model is loaded separately inside `evaluate`.
    # obs_type / level / num_players are left at the Train defaults.
    # NOTE(review): the saved model must have been trained with the same
    # observation type as this environment - confirm.
    eval_trainer = Train(
        n_steps=1024,
        batch_size=batch_size,
        difficulty="medium",
        n_envs=1  # Use 1 env for evaluation
    )

    # Evaluate the model for 5 episodes; Train.evaluate closes the trainer's
    # environment when it finishes.
    eval_trainer.evaluate(
        model_path=model_path,
        n_episodes=5,
        difficulty="medium"
    )
else:
    print(f"Model not found at {model_path}")

# --

In [None]:
# Test the adversarial training environment by running the standalone game at
# level 3 in a 1000x600 window (run_standalone_game is defined in an earlier
# cell; presumably it renders inline in Colab with this render mode - verify).
run_standalone_game(render_mode="rgb_array_and_human_in_colab", difficulty="medium", window_x=1000, window_y=600, level=3)
# Alternative: exercise the Gym wrapper with random actions instead.
# test_gym_env(difficulty="medium")

In [None]:
# Example of creating the environment with continuous action space for adversarial training:
# build the env, take one random step, and show one channel of the observation.
env = BalancingBallEnv(
    render_mode="rgb_array",
    difficulty="medium",
    fps=30,
    obs_type="game_screen",
    image_size=(84, 84),
    level=3,  # Level 3 for adversarial training
    num_players=2,  # Two players
)

# Reset environment to get initial observation
obs, info = env.reset()

# Print observation and action space info
print(f"Observation shape: {obs.shape}")  # author expects (84, 84, 3): 3 stacked grayscale frames - TODO confirm against BalancingBallEnv
print(f"Action space: {env.action_space}")  # author expects Box(low=-1, high=1, shape=(2,)) - TODO confirm
print(f"Action space shape: {env.action_space.shape}")  # author expects (2,), one entry per player - TODO confirm

# Test a random continuous action
action = env.action_space.sample()
print(f"Sample action: {action}")  # presumably 2 values between -1 and 1 - see action space printed above

# Take a step
obs, reward, terminated, truncated, info = env.step(action)
print(f"Step result - Reward: {reward}, Individual rewards: {info.get('individual_rewards', [])}")

# Display a sample observation: only index 0 of the last axis is shown
import matplotlib.pyplot as plt
plt.figure(figsize=(4, 4))
plt.imshow(obs[:,:,0], cmap='gray')
plt.title("Adversarial Training - Grayscale Observation")
plt.axis('off')
plt.show()

# Release the environment
env.close()