In [None]:
import ray
import time
import numpy as np
import scipy
import torch
import sklearn
import supersuit as ss
import matplotlib.pyplot as plt
import os
import cv2

%matplotlib inline

from ray import tune
from ray.rllib.algorithms.ddpg import DDPG, DDPGConfig
from ray.rllib.algorithms.sac import SAC, SACConfig
from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv
from pettingzoo.mpe import simple_tag_v3
import gymnasium as gym

# Checkpoint location handed to ``algo.restore`` further down in this script.
checkpoint_path = r"C:\Users\rkamatar\Downloads\checkpoint_000013"
# File name (not a full path) of the video written by ``render_environment``.
video_filename='sac_policy_11_27_T14_37_checkpoint13.mp4'

def env_creator():
    """Build the wrapped ``simple_tag_v3`` parallel environment.

    The base environment uses 2 good agents, 4 adversaries, no obstacles,
    100 cycles per episode, continuous actions and ``rgb_array`` rendering.
    SuperSuit wrappers are then applied in this order: observation padding,
    action-space padding, 3-frame stacking, and a cast to ``np.float32``.

    Returns:
        The fully wrapped environment.
    """
    base_env = simple_tag_v3.parallel_env(
        num_good=2,
        num_adversaries=4,
        num_obstacles=0,
        max_cycles=100,
        continuous_actions=True,
        render_mode="rgb_array",
    )
    padded_env = ss.pad_action_space_v0(ss.pad_observations_v0(base_env))
    stacked_env = ss.frame_stack_v1(padded_env, 3)
    # Final cast so observations come out as float32.
    return ss.dtype_v0(stacked_env, np.float32)

# env = env_creator()

def load_checkpoint(checkpoint_path, config):
    """Build a SAC algorithm from ``config`` and restore it from a checkpoint.

    Args:
        checkpoint_path: Checkpoint location passed to ``algo.restore``.
        config: Algorithm configuration passed to ``SAC(config=...)``.

    Returns:
        The ``SAC`` instance after ``restore`` has been called on it.
    """
    algo = SAC(config=config)
    # Without this call ``checkpoint_path`` is ignored and the caller gets a
    # freshly initialised algorithm instead of the saved one.
    algo.restore(checkpoint_path)
    return algo

def save_rendered_frame(frame, frame_count):
    """Write one rendered frame to ``frame_dir`` as ``frame_<frame_count>.png``.

    Args:
        frame: Image array in RGB channel order; it is converted to BGR
            before being handed to OpenCV for writing.
        frame_count: Index embedded in the output file name.

    Note:
        Reads the module-level ``frame_dir``, which must be defined before
        the first call.
    """
    # cv2.imwrite does not create missing directories; it only signals the
    # failure through its return value, so make sure the target exists.
    os.makedirs(frame_dir, exist_ok=True)

    # Convert frame from RGB to BGR for OpenCV compatibility
    bgr_frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)

    # Save the frame as an image file
    cv2.imwrite(f'{frame_dir}/frame_{frame_count}.png', bgr_frame)

def save_frame(obs, frame_count):
    """Save every agent's observation as a PNG and return the next frame index.

    Args:
        obs: Mapping of agent id to an observation array.
        frame_count: Index embedded in the output file names
            (``frame_<frame_count>_<agent>.png``).

    Returns:
        ``frame_count + 1``.

    Note:
        Reads the module-level ``frame_dir``, which must be defined before
        the first call.
    """
    # cv2.imwrite does not create missing directories, so ensure it exists.
    os.makedirs(frame_dir, exist_ok=True)

    for agent, observation in obs.items():
        # Ensure the observation is in the correct format (HxWxC)
        if observation.ndim == 3 and observation.shape[0] == 3:  # If CxHxW format
            frame = observation.transpose(1, 2, 0)
        else:
            frame = observation

        # Convert to uint8 if necessary.
        # NOTE(review): the scaling assumes values in [0, 1] -- confirm.
        if frame.dtype != np.uint8:
            frame = (frame * 255).astype(np.uint8)

        # Only a true HxWx3 image can be channel-swapped; a 1-D or 2-D array
        # whose last axis happens to have length 3 is not a colour image and
        # cv2.cvtColor would reject it.
        if frame.ndim == 3 and frame.shape[-1] == 3:
            frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)

        frame_file = f'{frame_dir}/frame_{frame_count}_{agent}.png'
        cv2.imwrite(frame_file, frame)

    return frame_count + 1

def render_environment(env, algo, frame_dir, video_dir, num_episodes=5, video_filename="output_video.mp4", fps=10):
    """Roll out ``algo`` in ``env`` and record every rendered frame.

    Each non-``None`` frame returned by ``env.render()`` is written as
    ``<frame_dir>/frame_<n>.png`` and appended to
    ``<video_dir>/<video_filename>``.

    Args:
        env: Parallel multi-agent environment. ``reset()`` must return
            ``(obs, info)`` and ``step(actions)`` must return
            ``(obs, rewards, terminated, truncated, info)``, all keyed by agent.
        algo: Object exposing ``compute_single_action(obs, policy_id=...)``.
        frame_dir: Directory for the per-frame PNG files (created if missing).
        video_dir: Directory for the video file (created if missing).
        num_episodes: Number of episodes to roll out.
        video_filename: File name of the video inside ``video_dir``.
        fps: Frame rate passed to ``cv2.VideoWriter``.
    """
    # OpenCV does not create missing directories; writes would fail silently.
    os.makedirs(frame_dir, exist_ok=True)
    os.makedirs(video_dir, exist_ok=True)

    frame_count = 0
    # Created lazily, once the size of the first rendered frame is known.
    video_writer = None

    try:
        for episode in range(num_episodes):
            obs, _ = env.reset()
            done = {agent: False for agent in obs}
            episode_reward = 0
            step = 0

            while not all(done.values()):
                actions = {
                    agent: algo.compute_single_action(agent_obs, policy_id=agent)
                    for agent, agent_obs in obs.items()
                }
                obs, rewards, terminated, truncated, _ = env.step(actions)
                done = {agent: terminated[agent] or truncated[agent] for agent in obs}
                episode_reward += sum(rewards.values())

                rendered_frame = env.render()  # Get pixel array from render()

                if rendered_frame is not None:
                    # Convert once (RGB -> BGR) and reuse for both outputs.
                    bgr_frame = cv2.cvtColor(rendered_frame, cv2.COLOR_RGB2BGR)

                    # Write into the directory given by the ``frame_dir``
                    # argument rather than relying on a module-level global.
                    cv2.imwrite(f'{frame_dir}/frame_{frame_count}.png', bgr_frame)

                    if video_writer is None:
                        height, width = rendered_frame.shape[:2]
                        fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # Codec for MP4
                        video_writer = cv2.VideoWriter(f"{video_dir}/{video_filename}", fourcc, fps, (width, height))

                    video_writer.write(bgr_frame)
                    frame_count += 1

                time.sleep(0.1)
                step += 1

            print(f"Episode {episode + 1} finished after {step} steps. Total reward: {episode_reward}")
    finally:
        # Always close the env and finalize the video file, even if the
        # rollout raised part-way through.
        env.close()
        if video_writer is not None:
            video_writer.release()

    print(f"Frames saved in {frame_dir}")
    print(f"Video saved as {video_filename}")
class CustomParallelPettingZooEnv(ParallelPettingZooEnv):
    """``ParallelPettingZooEnv`` whose ``render`` delegates to the wrapped env."""

    def __init__(self, env):
        """Initialise the parent wrapper and keep a direct handle on ``env``."""
        super().__init__(env)
        # Kept so that render() can reach the environment passed in here.
        self.env = env

    def render(self, mode='human'):
        """Return ``self.env.render()``.

        ``mode`` is accepted for signature compatibility but is not forwarded
        to the wrapped environment.
        """
        rendered = self.env.render()
        return rendered

# Start Ray; ``ignore_reinit_error=True`` is passed so that re-running this
# code in the same session does not fail on a second ``ray.init`` call.
ray.init(ignore_reinit_error=True)

# Register the wrapped PettingZoo environment under a name. The SAC config
# below refers to the same name via the literal 'simple_tag'.
env_name = "simple_tag"
tune.register_env(env_name, lambda config: ParallelPettingZooEnv(env_creator()))

# frame_dir = './outputs/saved_frames'
# video_dir = "./outputs/saved_video"
# if not os.path.exists(frame_dir):
#     os.makedirs(frame_dir)
#     os.makedirs(video_dir)
# SAC configuration used to rebuild the algorithm before restoring the
# checkpoint. The env name must match the one registered with
# ``tune.register_env``.
# NOTE(review): the ``multi_agent`` section is commented out, yet
# ``render_environment`` calls ``compute_single_action(..., policy_id=agent)``.
# Confirm the restored checkpoint really exposes per-agent policy ids.
config = (
            SACConfig()
            .environment(env='simple_tag')
            .framework("torch")
            .rollouts(num_rollout_workers=7)
            # .resources(num_cpus_per_worker=1, num_gpus_per_worker=1/8)
            .training(
                lr = 1e-4,
                tau=.01,
                train_batch_size=1024,
                gamma=.95,
            )
            # .multi_agent(
            #     policies={agent: (None, env.observation_space(agent), env.action_space(agent), {})
            #             for agent in env.possible_agents},
            #     policy_mapping_fn=lambda agent_id, *args, **kwargs: agent_id,
            # )
        )
# Load the algorithm from the checkpoint
# loaded_algo = load_checkpoint(checkpoint_path, config)
# The algorithm is built and restored inline here instead of going through
# ``load_checkpoint``.
algo = SAC(config=config)
algo.restore(checkpoint_path)

# Separate environment instance used only for rendering the rollouts.
render_env = env_creator()
print(algo.get_policy())
# NOTE(review): ``policy`` is not used anywhere in the visible code.
policy = algo.get_policy()
# Render the environment with the loaded algorithm
# ``frame_dir`` must be defined at module level before the call below:
# ``save_rendered_frame`` reads it as a global rather than as an argument.
frame_dir = './outputs/saved_frames'
video_dir = "./outputs/saved_video"
render_environment(render_env, algo, frame_dir, video_dir, num_episodes=2, video_filename=video_filename)