# [What is Reinforcement Learning?](https://gymnasium.farama.org/introduction/basic_usage/)

<img src="https://gymnasium.farama.org/_images/AE_loop_dark.png" alt="Agent-environment interaction loop" width="50%" height="auto">

# 1. Load a game
> Run `install.sh` first.

In [None]:
# Move up to the repository root (the directory that should contain the
# `environment` package imported in the next cell).
# Comment kept on its own line: IPython may pass the rest of a magic line to
# the magic as its argument, so a trailing comment could end up in the path.
%cd ..

> ## Environment

In [None]:
from environment.environment import WarehouseBrawl, RenderMode
import numpy as np
import matplotlib.pyplot as plt

# Create the environment; frames are rendered as RGB arrays so matplotlib can show them.
env = WarehouseBrawl(mode=RenderMode.RGB_ARRAY)

# Advance one step with an arbitrary action: an all-zero 10-element vector for
# each agent (key 0 = agent 1, key 1 = agent 2).
# NOTE(review): gymnasium convention is to call env.reset() before step();
# confirm WarehouseBrawl does not require it here.
env.step({agent_id: np.zeros(10, dtype=int) for agent_id in (0, 1)})

# Grab the current frame and display it rotated 90 degrees clockwise, without axes.
img = env.camera.get_frame(env, mode=RenderMode.RGB_ARRAY)
plt.imshow(np.rot90(img, -1))
plt.axis('off')
plt.show()

> ## Observation

In [None]:
# The environment's observation *space* (not an observation). It exposes
# .shape, .low and .high — presumably a gymnasium Box; confirm.
observation = env.get_observation_space()

print(f"observation.shape {observation.shape}")
print(f"Lower bounds of the intervals\n{observation.low}")
print(f"Upper bounds of the intervals\n{observation.high}")


In [None]:
# Let the helper print its own listing of the observation sections.
env.obs_helper.print_all_sections()

# Then dump the raw mapping: section name -> (index start, index end).
for section_name, index_range in env.obs_helper.sections.items():
    print(section_name, index_range)

> ## Action

In [None]:
actions = env.get_action_space()

# Print the named sections of the action vector. A bare expression here would be
# discarded: Jupyter only echoes the value of a cell's *last* expression.
print(list(env.act_helper.sections.keys()))

print(actions.shape)

# Action vector layout (index: key, meaning); 1 = pressed, 0 = not pressed:
# 0: W     (Aim up)
# 1: A     (Left)
# 2: S     (Aim down/fastfall)
# 3: D     (Right)
# 4: Space (Jump)
# 5: H     (Pickup/Throw)
# 6: L     (Dash/Dodge)
# 7: J     (Light Attack)
# 8: K     (Heavy Attack)
# 9: G     (Taunt)

# 2. Build an Agent and Test

> ## Define Helper Functions for Testing
> You don't need to read this section.

In [None]:
import pygame
import numpy as np
import time

def visualize_with_pygame(env, window, video_writer=None):
    """Draw the current frame of ``env`` into a pygame window, optionally recording it.

    The frame is fetched as an RGB array, rotated 90 degrees clockwise and
    mirrored horizontally before being shown. When a ``video_writer`` is given,
    the same mirrored frame is written to it so the video matches the display.

    Args:
        env: Environment whose ``camera.get_frame(env, mode=RenderMode.RGB_ARRAY)``
            returns the frame to show.
        window: Display surface returned by a previous call, or ``None`` to create
            a window sized to the current frame.
        video_writer: Optional object with a ``writeFrame(frame)`` method (e.g.
            ``skvideo.io.FFmpegWriter``). It receives uint8 frames of shape
            (height, width, 3).

    Returns:
        The pygame display surface; pass it back in on the next call.
    """
    img = np.rot90(env.camera.get_frame(env, mode=RenderMode.RGB_ARRAY), -1)
    frame_height, frame_width = img.shape[:2]

    if window is None:
        window = pygame.display.set_mode((frame_width, frame_height))
        pygame.display.set_caption('WarehouseBrawl Visualization')

    # pygame surfaces are indexed [x][y], so swap the first two axes of the image.
    surface = pygame.surfarray.make_surface(np.transpose(img, (1, 0, 2)))
    # Mirror horizontally for display.
    surface = pygame.transform.flip(surface, True, False)
    # The window is sized from the first frame; if a later frame differs, rescale
    # it to fit instead of clipping it.
    window_size = window.get_size()
    if surface.get_size() != window_size:
        surface = pygame.transform.scale(surface, window_size)
    window.blit(surface, (0, 0))
    pygame.display.flip()

    if video_writer is not None:
        # Flip along the width axis to match the mirrored display.
        frame_to_save = np.flip(img, axis=1)
        if frame_to_save.dtype != np.uint8:
            # NOTE(review): a plain cast assumes values are already in 0-255.
            frame_to_save = frame_to_save.astype(np.uint8)
        video_writer.writeFrame(frame_to_save)

    return window

def run_game_with_visualization(policy_func_agnet1, policy_func_agnet2, max_episode_length=300, save_video=False, save_video_path='game_video.mp4', delay_every_frame=0):
    """Play one episode in the module-level ``env`` with two policies, rendering every step.

    The environment is reset and stepped once with all-zero actions to obtain an
    initial observation. The loop then runs until the episode terminates or
    truncates, ``max_episode_length`` steps have been taken, or the pygame
    window is closed. Each step's reward is printed on a single, overwritten
    line.

    Args:
        policy_func_agnet1: Callable mapping an observation to agent 1's action.
            (Parameter name kept as-is for callers passing it by keyword.)
        policy_func_agnet2: Callable mapping an observation to agent 2's action.
        max_episode_length: Maximum number of policy steps to run.
        save_video: If True, also write every rendered frame to ``save_video_path``
            (requires ``skvideo``).
        save_video_path: Output path for the recorded video.
        delay_every_frame: Seconds to sleep after each step (0 disables the delay).
    """
    video_writer = None
    if save_video:
        import skvideo.io
        video_writer = skvideo.io.FFmpegWriter(save_video_path, outputdict={'-pix_fmt': 'yuv420p'})

    try:
        pygame.init()
        env.reset()
        observation, *_ = env.step({
            0: np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
            1: np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
        })

        terminated = False
        truncated = False
        n_episode_length = 0
        window = None

        while not terminated and not truncated and n_episode_length < max_episode_length:
            # Stop right away when the window is closed, without stepping again.
            if any(event.type == pygame.QUIT for event in pygame.event.get()):
                break

            # NOTE(review): both policies receive the same observation object; if
            # env.step returns per-agent observations, select each agent's here.
            action = {
                0: policy_func_agnet1(observation),  # agent 1
                1: policy_func_agnet2(observation),  # agent 2
            }
            observation, reward, terminated, truncated, _ = env.step(action)
            print("Reward:", reward, " " * 64, end="\r")

            window = visualize_with_pygame(env, window, video_writer=video_writer)

            n_episode_length += 1
            if delay_every_frame > 0:
                time.sleep(delay_every_frame)
    finally:
        # Always finalize the video file and release the display, even if a
        # policy or the environment raises mid-episode.
        if video_writer is not None:
            video_writer.close()
        pygame.quit()


> ## Policy (Agent)
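> Edit this function to build your own agent. It receives the observation returned by `env.step` and must return a 10-element action array (see the action list above).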

In [None]:
def policy(observation):
    """Return a 10-element action vector: Jump with 5% probability, else all zeros.

    ``observation`` is ignored by this placeholder; replace the body with your
    own agent logic. A uniformly random baseline is
    ``env.get_action_space().sample()``.
    """
    action = np.zeros(10, dtype=int)
    if np.random.rand() < 0.05:
        action[4] = 1  # index 4 = Space (Jump)
    return action

> ## Test

In [None]:
# Pit the policy against itself for at most 100 steps and record the episode.
run_game_with_visualization(
    policy_func_agnet1=policy,
    policy_func_agnet2=policy,
    save_video=True,
    save_video_path="game_video.mp4",
    max_episode_length=100,
)

In [None]:
from IPython.display import Video

# Embed the recorded video inline. The path must point at the file written by
# the previous cell (relative paths resolve against the current directory).
Video("./game_video.mp4", embed=True, width=640, height=360)