In [None]:
# One-time setup (uncomment if gymnasium is not installed in this kernel's environment):
# %pip install 'gymnasium[box2d]'

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions.categorical import Categorical

import numpy as np
import matplotlib.pyplot as plt
import gymnasium as gym

from tqdm import *
from ema_pytorch import EMA
from gymnasium.wrappers import *

In [None]:
def make_env(env_id):
    """Return a zero-argument factory that builds the environment `env_id`.

    The factory calls `gym.make` with `render_mode='rgb_array'`, so that
    `env.render()` on the created environment returns an image array.
    Nothing is constructed until the returned factory is called.
    """
    return lambda: gym.make(env_id, render_mode='rgb_array')

# Build the environment and show its first rendered frame as a sanity check.
env_id = 'CartPole-v1'
env = make_env(env_id)()
# Gymnasium's reset() returns an (observation, info) pair; unpack it so that
# `obs` is the observation itself rather than the whole tuple.
obs, _ = env.reset()

img = env.render()
fig, ax = plt.subplots()
ax.imshow(img)
ax.set_title(f'{env_id}: initial frame')
ax.set_axis_off()
plt.show()

In [None]:
# Probe the environment once: take a fresh observation and read the metadata
# of its observation/action spaces, then print them for inspection.
init_obs, _ = env.reset()
obs_shape, action_shape = env.observation_space.shape, env.action_space.shape
n_action = env.action_space.n

print(init_obs.shape)
print('obs_shape:', obs_shape)
print('action_shape:', action_shape)
print('n_action:', n_action)

In [None]:
class Actor(nn.Module):
    """Policy network: maps an observation to a categorical action distribution.

    Args:
        obs_dim: size of the observation vector (last dimension of the input).
        n_actions: number of discrete actions (size of the logits output).
        hidden_dim: width of the two hidden layers. Defaults to 64, which
            reproduces the original fixed architecture and keeps the
            state_dict layout (`model.0`, `model.2`, `model.4`) unchanged.
    """

    def __init__(self, obs_dim, n_actions, hidden_dim=64):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, n_actions)
        )

    def forward(self, x):
        """Return a `Categorical` distribution built from the network's logits for `x`."""
        logits = self.model(x)
        return Categorical(logits=logits)

class Critic(nn.Module):
    """Value network: maps an observation to a single scalar value estimate.

    Args:
        obs_dim: size of the observation vector (last dimension of the input).
        hidden_dim: width of the two hidden layers. Defaults to 64, which
            reproduces the original fixed architecture and keeps the
            state_dict layout (`model.0`, `model.2`, `model.4`) unchanged.
    """

    def __init__(self, obs_dim, hidden_dim=64):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, x):
        """Return the value estimate for `x`; the output keeps a trailing dimension of size 1."""
        return self.model(x)

In [None]:
# Hyperparameters and training setup.
gamma = 0.99     # discount factor used when accumulating returns
episodes = 200   # number of training episodes

# Release the preview environment before replacing it with the training one.
env.close()
env = make_env(env_id)()

# Device is pinned to CPU; uncomment the next line to pick CUDA when available.
#device = "cuda" if torch.cuda.is_available() else "cpu"
device = 'cpu'

# Size both networks from the environment instead of hard-coding 4 / 2, so the
# actor and the critic always agree on the observation dimension.
actor = Actor(obs_shape[0], int(n_action)).to(device)
critic = Critic(obs_shape[0]).to(device)
actor_optim = torch.optim.Adam(actor.parameters(), lr=1e-3)
critic_optim = torch.optim.Adam(critic.parameters(), lr=1e-3)


In [None]:
# A2C training: collect one rollout per episode, then do one critic update and
# one actor update from the discounted returns of that rollout.
max_episode_frames = 300  # cap the rollout length so a single episode cannot grow too large

for episode in tqdm(range(1, episodes + 1), desc="A2C training progress: ~~"):
    obs, _ = env.reset()
    log_probs = []
    values = []
    rewards = []
    terminated = False

    for _ in range(max_episode_frames):
        obs_t = torch.as_tensor(obs, dtype=torch.float32, device=device)
        dist = actor(obs_t)
        value = critic(obs_t)
        action = dist.sample()
        # step() returns (obs, reward, terminated, truncated, info); both flags
        # end the episode, but only `terminated` marks a true terminal state.
        obs, reward, terminated, truncated, _ = env.step(action.item())

        log_probs.append(dist.log_prob(action))
        values.append(value)
        rewards.append(float(reward))

        if terminated or truncated:
            break

    # If the rollout was cut off (time limit or frame cap) rather than reaching
    # a terminal state, bootstrap the tail of the return from the critic
    # instead of pretending the future reward is zero.
    if terminated:
        R = 0.0
    else:
        with torch.no_grad():
            R = critic(torch.as_tensor(obs, dtype=torch.float32, device=device)).item()

    # Discounted returns, accumulated back-to-front and then reversed
    # (append + reverse avoids the quadratic cost of insert(0, ...)).
    returns = []
    for r in reversed(rewards):
        R = r + gamma * R
        returns.append(R)
    returns.reverse()

    returns = torch.tensor(returns, dtype=torch.float32, device=device)
    # The critic's output keeps a trailing size-1 dimension, so the stacked
    # values are (T, 1). Drop it: subtracting (T, 1) from (T,) would broadcast
    # to a (T, T) matrix and corrupt both losses.
    values = torch.stack(values).squeeze(-1)
    log_probs = torch.stack(log_probs)
    advantage = returns - values

    critic_loss = advantage.pow(2).mean()
    critic_optim.zero_grad()
    critic_loss.backward()
    critic_optim.step()  # update critic

    # The advantage is detached so the policy gradient does not flow into the critic.
    actor_loss = (-log_probs * advantage.detach()).mean()
    actor_optim.zero_grad()
    actor_loss.backward()
    actor_optim.step()  # update actor

    # Periodically overwrite the checkpoint with the latest actor weights.
    if episode % 20 == 0:
        torch.save(actor.state_dict(), "./weight_epoch.pt")
    

In [None]:
import imageio
from PIL import Image

# Reload the actor weights written by the training cell into a fresh network.
PATH = "./weight_epoch.pt"
# Size the actor from the environment (as the critic is) rather than hard-coding 4 / 2.
actor = Actor(obs_shape[0], int(n_action)).to(device)
# map_location keeps the load working when the checkpoint was saved from a
# different device than the one used here.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
actor.load_state_dict(torch.load(PATH, map_location=device))

In [None]:
def play_and_record_modified(env_names, model, episodes=1, filename='gameplay.gif'):
    """Run `model` in several environments side by side and save the frames as a GIF.

    Each step, every environment that is still running takes one action
    sampled from `model`; the rendered frames of all environments are
    concatenated horizontally into one image. Finished environments keep
    contributing their last rendered frame until all of them are done.

    Args:
        env_names: environment ids; one environment is created per entry via `make_env`.
        model: callable mapping a batched observation tensor to a distribution
            object exposing `.sample()` (e.g. an `Actor`).
        episodes: number of episodes to record back to back.
        filename: output path handed to `imageio.mimsave`.
    """
    # Create multiple environments
    envs = [make_env(name)() for name in env_names]
    frames = []
    try:
        for _ in range(episodes):
            obses = [env.reset()[0] for env in envs]
            done = [False] * len(envs)
            while not all(done):
                tiles = []
                for i, env in enumerate(envs):
                    if not done[i]:
                        obs = torch.FloatTensor(obses[i]).unsqueeze(0).to(device)
                        # Inference only: no autograd graph is needed here.
                        with torch.no_grad():
                            action = model(obs).sample()[0]
                        obses[i], _, terminated, truncated, _ = env.step(action.cpu().numpy())
                        # Treat truncation (time limit) as the end of the episode too;
                        # otherwise a policy that never fails would loop forever.
                        done[i] = terminated or truncated
                    tiles.append(env.render())

                # Merge the per-environment frames horizontally, convert to a PIL image.
                frames.append(Image.fromarray(np.concatenate(tiles, axis=1)))
    finally:
        # Close the environments once, after all episodes, so that later
        # episodes never reset an already-closed environment.
        for env in envs:
            env.close()

    # Save frames as GIF
    imageio.mimsave(filename, frames, fps=30)

# Demo: put the actor in eval mode and record one episode in three
# side-by-side CartPole environments.
actor.eval()
env_names = ['CartPole-v1'] * 3
play_and_record_modified(env_names, actor, episodes=1, filename='./CartPole-v1.gif')