In [3]:
import numpy as np
import  gym
import sys
sys.path.append('../')
from sarsa.sarsa import run_SARSA


In [4]:
def sigmoid(x):
    """Return the logistic function ``1 / (1 + exp(-x))`` of ``x``."""
    denominator = 1 + np.exp(-x)
    return 1 / denominator

def discCP(state):
    """Map a continuous CartPole observation to a single integer state index.

    Each of the four observation components is bucketed into one of ten
    bins (0-9).  The bin indices are then combined as decimal digits, so the
    result always lies in ``[0, 9999]``.

    Parameters
    ----------
    state : sequence of 4 floats
        ``(x, x_dot, theta, theta_dot)``.  ``theta`` is expected in radians,
        which is how the Gym CartPole environment reports it.

    Returns
    -------
    numpy integer
        ``d0 + 10*d1 + 100*d2 + 1000*d3`` where ``d0..d3`` are the bin
        indices of ``x``, ``x_dot``, ``theta`` and ``theta_dot``.
    """
    x, x_dot, theta, theta_dot = state

    # The 23.99 / 4.8 constants below describe a +/-24 degree range, but the
    # environment reports the pole angle in radians (about +/-0.418).  Without
    # this conversion every angle would fall into bin 4 or 5.
    theta_deg = np.degrees(theta)

    raw_bins = np.floor([
        (x + 4.799) / 0.96,           # position: ten bins of width 0.96
        sigmoid(x_dot) * 9.99,        # velocity squashed by the logistic fn
        (theta_deg + 23.99) / 4.8,    # angle: ten bins of width 4.8 degrees
        sigmoid(theta_dot) * 9.99,    # angular velocity, same squashing
    ])

    # Clamp so an out-of-range observation can never produce a negative digit
    # or a digit above 9, either of which would alias onto another state.
    d = np.clip(raw_bins, 0, 9).astype(int)
    return d[0] + 10 * d[1] + 100 * d[2] + 1000 * d[3]

def prepCartPole(render_mode=None):
    """Build a CartPole-v1 environment, reset it with seed 1 and return it.

    ``render_mode`` is forwarded unchanged to ``gym.make``.  The
    ``_max_episode_steps`` attribute of the returned object is set to 500.
    """
    cartpole = gym.make('CartPole-v1', render_mode=render_mode)
    cartpole._max_episode_steps = 500
    # reset() yields an (observation, info) pair; neither value is used here.
    _obs, _info = cartpole.reset(seed=1)
    return cartpole

if __name__ == '__main__':
    # Train SARSA on the discretized CartPole; `policies` is reused by later cells.
    env = prepCartPole()
    rewards, policies = run_SARSA(
        env,
        total_episodes=50000,
        max_steps=500,
        state_discretize=discCP,
    )



Episode: 20, Average Reward: (14.890526272962516, 0, 15.0), Epsilon: 0.9996
Episode: 40, Average Reward: (28.265858782532643, 0, 28.816), Epsilon: 0.9992
Episode: 60, Average Reward: (27.702367427016707, 0, 28.312), Epsilon: 0.9988
Episode: 80, Average Reward: (39.371291588663574, 0, 40.654), Epsilon: 0.9984
Episode: 100, Average Reward: (51.232794237968406, 0, 53.326), Epsilon: 0.9980
Episode: 120, Average Reward: (34.83998188715899, 0, 35.848), Epsilon: 0.9976
Episode: 140, Average Reward: (31.607216044107112, 0, 32.534), Epsilon: 0.9972
Episode: 160, Average Reward: (37.330471921480765, 0, 38.454), Epsilon: 0.9968
Episode: 180, Average Reward: (41.84680147011953, 0, 43.36), Epsilon: 0.9964
Episode: 200, Average Reward: (36.16306257023322, 0, 37.122), Epsilon: 0.9960
Episode: 220, Average Reward: (22.375343168705562, 0, 22.712), Epsilon: 0.9956
Episode: 240, Average Reward: (24.816526146480133, 0, 25.254), Epsilon: 0.9952
Episode: 260, Average Reward: (23.637394494951767, 0, 24.032),

In [5]:
# Dump the policy snapshots returned by run_SARSA (a dict of arrays).
print(policies)

{'episode_2500': array([0, 0, 0, ..., 0, 0, 0], dtype=int64), 'episode_5000': array([0, 0, 0, ..., 0, 0, 0], dtype=int64), 'episode_7500': array([0, 0, 0, ..., 0, 0, 0], dtype=int64), 'episode_10000': array([0, 0, 0, ..., 0, 0, 0], dtype=int64), 'episode_15000': array([0, 0, 0, ..., 0, 0, 0], dtype=int64), 'episode_20000': array([0, 0, 0, ..., 0, 0, 0], dtype=int64), 'episode_25000': array([0, 0, 0, ..., 0, 0, 0], dtype=int64), 'episode_30000': array([0, 0, 0, ..., 0, 0, 0], dtype=int64), 'episode_35000': array([0, 0, 0, ..., 0, 0, 0], dtype=int64), 'episode_40000': array([0, 0, 0, ..., 0, 0, 0], dtype=int64), 'episode_45000': array([0, 0, 0, ..., 0, 0, 0], dtype=int64), 'episode_50000': array([0, 0, 0, ..., 0, 0, 0], dtype=int64), 'optimal_policy': array([0, 0, 0, ..., 0, 0, 0], dtype=int64)}


In [8]:
import os
import numpy as np
import imageio
import gym
import cv2  


def prepCartPole(render_mode=None):
    """Create a CartPole-v1 environment, seed-reset it and hand it back.

    NOTE: this duplicates the ``prepCartPole`` defined in the earlier
    training cell; the two definitions are identical.

    ``render_mode`` is passed through to ``gym.make``; ``_max_episode_steps``
    is set to 500 on the returned object and ``reset(seed=1)`` is called once.
    """
    cartpole_env = gym.make('CartPole-v1', render_mode=render_mode)
    cartpole_env._max_episode_steps = 500
    # The (observation, info) pair from the initial reset is not needed.
    _first_obs, _first_info = cartpole_env.reset(seed=1)
    return cartpole_env

def save_frames_as_gif(frames, path):
    """Write the sequence ``frames`` to ``path`` as a GIF at 20 frames per second."""
    gif_fps = 20
    imageio.mimsave(path, frames, fps=gif_fps)

def embed_text(frame, text, position=(10, 60), font_scale=1, color=(0, 0, 255), thickness=2):
    """Draw ``text`` on ``frame`` via ``cv2.putText`` and return the result.

    The text uses ``cv2.FONT_HERSHEY_SIMPLEX`` with ``cv2.LINE_AA`` line
    type; ``position``, ``font_scale``, ``color`` and ``thickness`` are
    forwarded to OpenCV as given.
    """
    font = cv2.FONT_HERSHEY_SIMPLEX
    return cv2.putText(
        frame, text, position, font, font_scale, color,
        thickness=thickness, lineType=cv2.LINE_AA,
    )

def run_env_with_policy_and_save_frames(env, policy, folder_name, game_count=10,
                                        state_discretize=None, max_steps=500):
    """Play several episodes with ``policy`` and save all frames as one GIF.

    Parameters
    ----------
    env : environment
        Must follow the five-value ``step`` / two-value ``reset`` API and
        return an RGB image array from ``render()``.
    policy : indexable
        ``policy[observation]`` gives the action for the (optionally
        discretized) observation.
    folder_name : str
        Output directory; created if missing.  The GIF is written to
        ``<folder_name>/policy_<basename of folder_name>_games.gif``.
    game_count : int, default 10
        Number of episodes to record.
    state_discretize : callable, optional
        Applied to every observation before it is used to index ``policy``.
    max_steps : int, default 500
        Upper bound on the number of actions taken per episode.

    Each rendered frame is annotated with the game and step number.  When an
    episode ends through termination or truncation, its last frame gets a
    "Game Over" caption and is repeated 60 times.
    """
    os.makedirs(folder_name, exist_ok=True)

    frames = []

    for game in range(1, game_count + 1):
        observation, _info = env.reset()
        if state_discretize:
            observation = state_discretize(observation)

        terminated = False
        truncated = False
        step = 0

        # Strict '<' so that at most `max_steps` actions are taken; the old
        # `<= 500` bound allowed one extra iteration.
        while not (terminated or truncated) and step < max_steps:
            action = int(policy[observation])
            observation, _reward, terminated, truncated, _info = env.step(action)

            if state_discretize:
                observation = state_discretize(observation)

            # OpenCV colours are BGR: convert, draw, then convert back for the GIF.
            frame = cv2.cvtColor(env.render(), cv2.COLOR_RGB2BGR)
            frame = embed_text(frame, f'Game: {game}, Step: {step}', position=(10, 60), color=(0, 0, 255))  # Red text
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            step += 1

        # `frame` is always bound here: an episode can only be terminated or
        # truncated after at least one pass through the loop above.
        if terminated or truncated:
            frame = embed_text(frame, "Game Over", position=(10, 100), font_scale=1.2, color=(0, 0, 255))
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            # Hold the final frame so the caption stays visible in the GIF.
            frames.extend([frame] * 60)

    # Use only the last path component in the file name, so a nested
    # folder_name such as "out/run1" does not inject a separator into it.
    gif_label = os.path.basename(os.path.normpath(folder_name))
    save_frames_as_gif(frames, os.path.join(folder_name, f'policy_{gif_label}_games.gif'))


def run_policies_and_save_frames(policies, state_discretize=None):
    """Record a GIF for every policy in ``policies``.

    Parameters
    ----------
    policies : dict
        Maps a policy name to the policy itself.  The name is also used as
        the output folder for that policy's GIF.
    state_discretize : callable, optional
        Forwarded to ``run_env_with_policy_and_save_frames``.

    A fresh ``rgb_array`` environment is created per policy and is always
    closed afterwards.  A failure for one policy is reported on stdout and
    does not stop the remaining policies from being processed.
    """
    for policy_key, policy in policies.items():
        env = None
        try:
            print(f"Running policy: {policy_key}")
            env = prepCartPole(render_mode="rgb_array")
            run_env_with_policy_and_save_frames(env, policy, folder_name=policy_key, state_discretize=state_discretize)
        except Exception as e:
            # Deliberate best-effort: report and continue with the next policy.
            print(f"Error running policy: {policy_key}")
            print(e)
        finally:
            # Release the renderer resources; previously each env was leaked.
            if env is not None:
                env.close()

# Relies on `policies` and `discCP` defined by the earlier training cell.
run_policies_and_save_frames(policies, state_discretize=discCP)



Running policy: episode_2500
Running policy: episode_5000
Running policy: episode_7500
Running policy: episode_10000
Running policy: episode_15000
Running policy: episode_20000
Running policy: episode_25000
Running policy: episode_30000
Running policy: episode_35000
Running policy: episode_40000
Running policy: episode_45000
Running policy: episode_50000
Running policy: optimal_policy
