In [None]:
import gymnasium as gym
import numpy as np
import cv2
import stable_baselines3 as sb3
from stable_baselines3 import PPO
from gymnasium.wrappers import ResizeObservation
from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack
from PIL import Image
from stable_baselines3.common.vec_env import VecEnvWrapper
import ale_py

In [None]:
# Paths to the trained PPO checkpoints evaluated below.
# NOTE(review): the names are crossed — the right-paddle test loads
# left_best_model.zip and the left-paddle test loads right_best_model.zip.
# This may be intentional (e.g. checkpoints named after the flipped view);
# confirm against the training notebook before "fixing" it.
model_path_right="../../models/pong/left_best_model.zip"
model_path_left="../../models/pong/right_best_model.zip"

# Testing the right-paddle model

In [None]:
# Right-paddle evaluation: preprocessing wrappers, environment builder, and rollout/GIF helpers.

class MyVecTransposeImage(VecEnvWrapper):
    """Vectorized wrapper converting channel-last image batches to channel-first.

    Observations of shape (n_envs, H, W, C) are returned as (n_envs, C, H, W),
    and the observation space is rewritten from (H, W, C) to (C, H, W).

    :param venv: the vectorized environment to wrap; its observation space is
        expected to be a 3-D ``Box`` of shape (H, W, C).
    :param skip: if True, observations are passed through unchanged and the
        observation space is left as-is, so space and data stay consistent.
    """

    def __init__(self, venv, skip=False):
        super().__init__(venv)
        self.skip = skip

        if skip:
            # Observations are not transposed, so the advertised space must
            # keep the original (H, W, C) layout.
            return

        # e.g. (84, 84, 4) -> (4, 84, 84)
        height, width, channels = self.observation_space.shape
        new_shape = (channels, height, width)

        # Collapse the bounds to scalars: exact when low/high are uniform,
        # otherwise the tightest single-value envelope.
        low_val = self.observation_space.low.min()
        high_val = self.observation_space.high.max()

        self.observation_space = gym.spaces.Box(
            low=low_val,
            high=high_val,
            shape=new_shape,
            dtype=self.observation_space.dtype,
        )

    def reset(self):
        """Reset the wrapped envs and return the transposed first observation."""
        obs = self.venv.reset()
        return self.transpose_observations(obs)

    def step_async(self, actions):
        """Forward the batch of actions to the wrapped vectorized env."""
        self.venv.step_async(actions)

    def step_wait(self):
        """Collect step results, transposing only the observations."""
        obs, rewards, dones, infos = self.venv.step_wait()
        return self.transpose_observations(obs), rewards, dones, infos

    def transpose_observations(self, obs):
        """Transpose an observation batch (or each entry of a dict of batches).

        Dict observations are updated in place and returned.
        """
        if self.skip:
            return obs
        if isinstance(obs, dict):
            for key, val in obs.items():
                obs[key] = self._transpose(val)
            return obs
        return self._transpose(obs)

    def _transpose(self, obs):
        # (n_envs, H, W, C) -> (n_envs, C, H, W); returns a view, no copy.
        return obs.transpose(0, 3, 1, 2)
    

class ScaledFloatFrame(gym.ObservationWrapper):
    """Rescale pixel observations from [0, 255] to float32 values in [0.0, 1.0].

    The observation shape is kept; only the bounds and dtype of the space change.
    """

    def __init__(self, env):
        super().__init__(env)
        original_shape = self.observation_space.shape
        self.observation_space = gym.spaces.Box(
            low=0.0, high=1.0, shape=original_shape, dtype=np.float32
        )

    def observation(self, obs):
        pixels = np.array(obs)
        return pixels.astype(np.float32) / 255.0


class AddChannelDimension(gym.ObservationWrapper):
    """Append a trailing singleton channel axis: (H, W) -> (H, W, 1).

    The resulting observation space is a uint8 ``Box`` over [0, 255].
    """

    def __init__(self, env):
        super().__init__(env)
        height, width = self.observation_space.shape[:2]
        self.observation_space = gym.spaces.Box(
            low=0, high=255, shape=(height, width, 1), dtype=np.uint8
        )

    def observation(self, observation):
        return np.expand_dims(observation, -1)


class FireResetEnv(gym.Wrapper):
    """Take one FIRE action right after every reset.

    Presumably used so games that wait for FIRE before play starts begin
    immediately. If that FIRE step ends the episode, the env is reset again
    (without firing).
    """

    def __init__(self, env=None):
        super().__init__(env)
        action_meanings = env.unwrapped.get_action_meanings()
        assert 'FIRE' in action_meanings, "Environment does not support 'FIRE' action"
        assert len(action_meanings) >= 3, "Action space too small for expected actions"
        # Look up FIRE's actual index rather than assuming it is action 1;
        # the assertion above only guarantees FIRE exists somewhere.
        self.fire_action = action_meanings.index('FIRE')

    def step(self, action):
        return self.env.step(action)

    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)
        obs, _, terminated, truncated, _ = self.env.step(self.fire_action)
        if terminated or truncated:
            # The FIRE step ended the episode; start over.
            obs, info = self.env.reset(**kwargs)
        return obs, info


# Single-env wrapper that moves the channel axis to the front: (H, W, C) -> (C, H, W).
class ChannelFirstWrapper(gym.ObservationWrapper):
    """Transpose per-env observations from (H, W, C) to (C, H, W).

    Element-wise low/high bounds are transposed together with the shape.
    """

    def __init__(self, env):
        super().__init__(env)
        source = self.observation_space
        height, width, channels = source.shape
        chw_low = source.low.transpose(2, 0, 1)
        chw_high = source.high.transpose(2, 0, 1)
        self.observation_space = gym.spaces.Box(
            low=chw_low,
            high=chw_high,
            shape=(channels, height, width),
            dtype=source.dtype,
        )

    def observation(self, obs):
        # (H, W, C) -> (C, H, W)
        return obs.transpose(2, 0, 1)


# Build the preprocessed, frame-stacked, channel-first evaluation environment.
def create_test_env(env_name, n_stack=4):
    """Create the single-env test pipeline and print its shape after each stage.

    Per-env chain: ``gym.make`` (grayscale obs, rgb_array render) ->
    FireResetEnv -> ResizeObservation(84x84) -> AddChannelDimension ->
    ScaledFloatFrame.  Vectorized chain: DummyVecEnv ->
    VecFrameStack(n_stack) -> MyVecTransposeImage.

    :param env_name: Gymnasium environment id, e.g. "PongNoFrameskip-v4".
    :param n_stack: number of consecutive frames to stack.
    :return: the fully wrapped vectorized environment.
    """

    def make_single_env():
        base = gym.make(env_name, obs_type="grayscale", render_mode="rgb_array")
        resized = ResizeObservation(FireResetEnv(base), (84, 84))
        preprocessed = ScaledFloatFrame(AddChannelDimension(resized))
        print("ScaledFloatFrame     : {}".format(preprocessed.observation_space.shape))
        return preprocessed

    vec_env = DummyVecEnv([make_single_env])
    print("DummyVecEnv          : {}".format(vec_env.observation_space.shape))
    stacked = VecFrameStack(vec_env, n_stack=n_stack)  # stacks the last n_stack frames
    print("VecFrameStack         : {}".format(stacked.observation_space.shape))
    transposed = MyVecTransposeImage(stacked)
    print("MyVecTransposeImage  : {}".format(transposed.observation_space.shape))
    return transposed


# Write a list of rendered frames to an animated GIF.
def save_video_as_gif(frames, filename="gameplay.gif", duration=20, loop=0):
    """Save a sequence of image arrays as an animated GIF.

    :param frames: iterable of arrays accepted by ``PIL.Image.fromarray``
        (e.g. the ``rgb_array`` renders collected during a rollout).
    :param filename: output path of the GIF.
    :param duration: display time of each frame, in milliseconds.
    :param loop: GIF loop count; 0 loops forever.
    :raises ValueError: if ``frames`` yields no frames (previously an
        opaque IndexError).
    """
    images = [Image.fromarray(frame) for frame in frames]
    if not images:
        raise ValueError("save_video_as_gif: no frames to save")
    first, *rest = images
    first.save(
        filename,
        save_all=True,
        append_images=rest,
        duration=duration,
        loop=loop,
    )
    print(f"Gameplay saved as {filename}")


# Test the trained model and record video
def test_model_and_save_video(model_path, env_name, timesteps=2500, gif_filename="gameplay.gif"):
    """Roll out a trained PPO model and save the rendered gameplay as a GIF.

    The rollout stops at the first episode end or after ``timesteps`` steps,
    whichever comes first.

    :param model_path: path to the saved PPO model (``.zip``).
    :param env_name: Gymnasium environment id passed to ``create_test_env``.
    :param timesteps: maximum number of environment steps.
    :param gif_filename: output path of the GIF.
    """
    model = PPO.load(model_path)
    env = create_test_env(env_name)
    try:
        obs = env.reset()
        frames = []
        # Plain float accumulator: works even if the loop runs zero times
        # (the old int-then-array mix crashed on `.sum()` in that case).
        total_reward = 0.0

        for t in range(timesteps):
            action, _ = model.predict(obs, deterministic=True)
            obs, rewards, dones, _ = env.step(action)
            total_reward += float(rewards.sum())

            # `envs` is forwarded through the VecEnv wrappers to the DummyVecEnv.
            frames.append(env.envs[0].render())

            if dones.any():
                print(f"Episode finished after {t + 1} timesteps.")
                break

        if frames:
            save_video_as_gif(frames, filename=gif_filename)
        else:
            print("No frames recorded; skipping GIF.")

        print(f"Total Reward: {total_reward}")
    finally:
        # Release the emulator even if the rollout or GIF export fails.
        env.close()


# Evaluate the right-paddle model (up to 4500 steps) and export its gameplay GIF.
test_model_and_save_video(
    model_path_right,
    "PongNoFrameskip-v4",
    gif_filename="../../videos/pong/right_paddle_gameplay.gif",
    timesteps=4500,
)


ScaledFloatFrame     : (84, 84, 1)
DummyVecEnv          : (84, 84, 1)
VecFrameStack         : (84, 84, 4)
MyVecTransposeImage  : (4, 84, 84)
Gameplay saved as videos/right_paddle_gameplay.gif
Total Reward: 14.0


# Testing the left-paddle model

In [6]:
# Left-paddle evaluation: same pipeline as above plus InvertPongWrapper, with mirrored GIF output.

class InvertPongWrapper(gym.ObservationWrapper):
    """Mirror observations left-to-right and swap paired actions.

    Observations are flipped along axis 1 (the width axis of the (H, W, C)
    frames this wrapper receives in ``create_test_env``). Before stepping,
    action 2 <-> 3 and 4 <-> 5 are swapped; all other actions pass through.
    """

    # (incoming action, action forwarded to the wrapped env)
    _ACTION_SWAPS = (
        (2, 3),  # RIGHT -> LEFT
        (3, 2),  # LEFT -> RIGHT
        (4, 5),  # RIGHTFIRE -> LEFTFIRE
        (5, 4),  # LEFTFIRE -> RIGHTFIRE
    )

    def __init__(self, env):
        super().__init__(env)
        # Same bounds, shape and dtype as the wrapped space: a horizontal
        # flip does not change uniform pixel bounds.
        base = self.observation_space
        self.observation_space = gym.spaces.Box(
            low=base.low, high=base.high, shape=base.shape, dtype=base.dtype
        )

    def observation(self, obs):
        """Return the observation mirrored along the width axis."""
        return np.flip(obs, axis=1)

    def step(self, action):
        """Swap the paired action, then step and flip the observation."""
        # NOTE(review): in ALE Pong, RIGHT/LEFT are generally understood to
        # move the paddle vertically, which a horizontal flip would not
        # change. Keep this swap as long as the loaded model was trained with
        # this same wrapper — confirm against the training code first.
        for incoming, forwarded in self._ACTION_SWAPS:
            if action == incoming:
                action = forwarded
                break
        return super().step(action)

class MyVecTransposeImage(VecEnvWrapper):
    """Vectorized wrapper converting channel-last image batches to channel-first.

    Observations of shape (n_envs, H, W, C) are returned as (n_envs, C, H, W),
    and the observation space is rewritten from (H, W, C) to (C, H, W).

    :param venv: the vectorized environment to wrap; its observation space is
        expected to be a 3-D ``Box`` of shape (H, W, C).
    :param skip: if True, observations are passed through unchanged and the
        observation space is left as-is, so space and data stay consistent.
    """

    def __init__(self, venv, skip=False):
        super().__init__(venv)
        self.skip = skip

        if skip:
            # Observations are not transposed, so the advertised space must
            # keep the original (H, W, C) layout.
            return

        # e.g. (84, 84, 4) -> (4, 84, 84)
        height, width, channels = self.observation_space.shape
        new_shape = (channels, height, width)

        # Collapse the bounds to scalars: exact when low/high are uniform,
        # otherwise the tightest single-value envelope.
        low_val = self.observation_space.low.min()
        high_val = self.observation_space.high.max()

        self.observation_space = gym.spaces.Box(
            low=low_val,
            high=high_val,
            shape=new_shape,
            dtype=self.observation_space.dtype,
        )

    def reset(self):
        """Reset the wrapped envs and return the transposed first observation."""
        obs = self.venv.reset()
        return self.transpose_observations(obs)

    def step_async(self, actions):
        """Forward the batch of actions to the wrapped vectorized env."""
        self.venv.step_async(actions)

    def step_wait(self):
        """Collect step results, transposing only the observations."""
        obs, rewards, dones, infos = self.venv.step_wait()
        return self.transpose_observations(obs), rewards, dones, infos

    def transpose_observations(self, obs):
        """Transpose an observation batch (or each entry of a dict of batches).

        Dict observations are updated in place and returned.
        """
        if self.skip:
            return obs
        if isinstance(obs, dict):
            for key, val in obs.items():
                obs[key] = self._transpose(val)
            return obs
        return self._transpose(obs)

    def _transpose(self, obs):
        # (n_envs, H, W, C) -> (n_envs, C, H, W); returns a view, no copy.
        return obs.transpose(0, 3, 1, 2)
    

class ScaledFloatFrame(gym.ObservationWrapper):
    """Rescale pixel observations from [0, 255] to float32 values in [0.0, 1.0].

    The observation shape is kept; only the bounds and dtype of the space change.
    """

    def __init__(self, env):
        super().__init__(env)
        original_shape = self.observation_space.shape
        self.observation_space = gym.spaces.Box(
            low=0.0, high=1.0, shape=original_shape, dtype=np.float32
        )

    def observation(self, obs):
        pixels = np.array(obs)
        return pixels.astype(np.float32) / 255.0


class AddChannelDimension(gym.ObservationWrapper):
    """Append a trailing singleton channel axis: (H, W) -> (H, W, 1).

    The resulting observation space is a uint8 ``Box`` over [0, 255].
    """

    def __init__(self, env):
        super().__init__(env)
        height, width = self.observation_space.shape[:2]
        self.observation_space = gym.spaces.Box(
            low=0, high=255, shape=(height, width, 1), dtype=np.uint8
        )

    def observation(self, observation):
        return np.expand_dims(observation, -1)


class FireResetEnv(gym.Wrapper):
    """Take one FIRE action right after every reset.

    Presumably used so games that wait for FIRE before play starts begin
    immediately. If that FIRE step ends the episode, the env is reset again
    (without firing).
    """

    def __init__(self, env=None):
        super().__init__(env)
        action_meanings = env.unwrapped.get_action_meanings()
        assert 'FIRE' in action_meanings, "Environment does not support 'FIRE' action"
        assert len(action_meanings) >= 3, "Action space too small for expected actions"
        # Look up FIRE's actual index rather than assuming it is action 1;
        # the assertion above only guarantees FIRE exists somewhere.
        self.fire_action = action_meanings.index('FIRE')

    def step(self, action):
        return self.env.step(action)

    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)
        obs, _, terminated, truncated, _ = self.env.step(self.fire_action)
        if terminated or truncated:
            # The FIRE step ended the episode; start over.
            obs, info = self.env.reset(**kwargs)
        return obs, info


# Single-env wrapper that moves the channel axis to the front: (H, W, C) -> (C, H, W).
class ChannelFirstWrapper(gym.ObservationWrapper):
    """Transpose per-env observations from (H, W, C) to (C, H, W).

    Element-wise low/high bounds are transposed together with the shape.
    """

    def __init__(self, env):
        super().__init__(env)
        source = self.observation_space
        height, width, channels = source.shape
        chw_low = source.low.transpose(2, 0, 1)
        chw_high = source.high.transpose(2, 0, 1)
        self.observation_space = gym.spaces.Box(
            low=chw_low,
            high=chw_high,
            shape=(channels, height, width),
            dtype=source.dtype,
        )

    def observation(self, obs):
        # (H, W, C) -> (C, H, W)
        return obs.transpose(2, 0, 1)


# Build the preprocessed, mirrored, frame-stacked, channel-first evaluation environment.
def create_test_env(env_name, n_stack=4):
    """Create the left-agent test pipeline and print its shape after each stage.

    Per-env chain: ``gym.make`` (grayscale obs, rgb_array render) ->
    FireResetEnv -> ResizeObservation(84x84) -> AddChannelDimension ->
    ScaledFloatFrame -> InvertPongWrapper.  Vectorized chain: DummyVecEnv ->
    VecFrameStack(n_stack) -> MyVecTransposeImage.

    :param env_name: Gymnasium environment id, e.g. "PongNoFrameskip-v4".
    :param n_stack: number of consecutive frames to stack.
    :return: the fully wrapped vectorized environment.
    """

    def make_single_env():
        base = gym.make(env_name, obs_type="grayscale", render_mode="rgb_array")
        resized = ResizeObservation(FireResetEnv(base), (84, 84))
        preprocessed = ScaledFloatFrame(AddChannelDimension(resized))
        print("ScaledFloatFrame     : {}".format(preprocessed.observation_space.shape))
        mirrored = InvertPongWrapper(preprocessed)
        print("InvertPongWrapper    : {}".format(mirrored.observation_space.shape))
        return mirrored

    vec_env = DummyVecEnv([make_single_env])
    print("DummyVecEnv          : {}".format(vec_env.observation_space.shape))
    stacked = VecFrameStack(vec_env, n_stack=n_stack)  # stacks the last n_stack frames
    print("VecFrameStack         : {}".format(stacked.observation_space.shape))
    transposed = MyVecTransposeImage(stacked)
    print("MyVecTransposeImage  : {}".format(transposed.observation_space.shape))
    return transposed


# Function to save gameplay as a GIF from the perspective of the left agent
def save_video_as_left_agent(frames, filename="gameplay_as_left_agent.gif", duration=20, loop=0):
    """Mirror each frame left-to-right and save the sequence as an animated GIF.

    :param frames: iterable of image arrays (e.g. ``rgb_array`` renders of
        shape (H, W, 3)); each is flipped along axis 1 before encoding.
    :param filename: output path of the GIF.
    :param duration: display time of each frame, in milliseconds.
    :param loop: GIF loop count; 0 loops forever.
    :raises ValueError: if ``frames`` yields no frames (previously an
        opaque IndexError).
    """
    # Single pass: flip and convert each frame, without an intermediate list.
    mirrored = [Image.fromarray(np.flip(frame, axis=1)) for frame in frames]
    if not mirrored:
        raise ValueError("save_video_as_left_agent: no frames to save")
    first, *rest = mirrored
    first.save(
        filename,
        save_all=True,
        append_images=rest,
        duration=duration,
        loop=loop,
    )
    print(f"Gameplay saved as {filename}")


def test_model_and_save_video_as_left_agent(model_path, env_name, timesteps=2500, gif_filename="gameplay_as_left_agent.gif"):
    """
    Test a trained model for the left paddle on Pong and save gameplay as a GIF
    from the perspective of the left agent.

    The rollout stops at the first episode end or after ``timesteps`` steps,
    whichever comes first. Rendered frames are mirrored before encoding.

    :param model_path: path to the saved PPO model (``.zip``).
    :param env_name: Gymnasium environment id passed to ``create_test_env``.
    :param timesteps: maximum number of environment steps.
    :param gif_filename: output path of the GIF.
    """
    model = PPO.load(model_path)
    env = create_test_env(env_name)
    try:
        obs = env.reset()
        frames = []
        total_reward = 0.0

        for t in range(timesteps):
            action, _ = model.predict(obs, deterministic=True)
            obs, rewards, dones, _ = env.step(action)
            total_reward += float(rewards.sum())

            # render() is not affected by InvertPongWrapper's observation flip;
            # the GIF helper mirrors the frames instead.
            frames.append(env.envs[0].render())

            if dones.any():
                # The rollout ends here, so no reset is needed before closing.
                print(f"Episode finished after {t + 1} timesteps.")
                break

        if frames:
            save_video_as_left_agent(frames, filename=gif_filename)
        else:
            print("No frames recorded; skipping GIF.")

        print(f"Total Reward: {total_reward}")
    finally:
        # Release the emulator even if the rollout or GIF export fails.
        env.close()


# Evaluate the left-paddle model (up to 2500 steps) and export a mirrored gameplay GIF.
test_model_and_save_video_as_left_agent(
    model_path_left,
    "PongNoFrameskip-v4",
    gif_filename="../../videos/pong/left_paddle_gameplay_as_left_agent.gif",
    timesteps=2500,
)



ScaledFloatFrame     : (84, 84, 1)
InvertPongWrapper    : (84, 84, 1)
DummyVecEnv          : (84, 84, 1)
VecFrameStack         : (84, 84, 4)
MyVecTransposeImage  : (4, 84, 84)
Gameplay saved as videos/left_paddle_gameplay_as_left_agent.gif
Total Reward: 7.0
