In [None]:
pip install ema_pytorch

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import numpy as np
import matplotlib.pyplot as plt
import gymnasium as gym
import flappy_bird_gymnasium
from ema_pytorch import EMA

In [None]:
def make_env(env_id):
    """Build a zero-argument factory that creates a fresh `env_id` environment.

    Each call of the returned factory constructs a new environment with
    `render_mode='rgb_array'`, so `env.render()` yields image arrays.
    Environment wrappers, if ever needed, belong inside the factory.
    """
    def _factory():
        return gym.make(env_id, render_mode='rgb_array')

    return _factory

env_id = 'FlappyBird-v0'
env = make_env(env_id)()
obs = env.reset()[0]
image = env.render()

# Sanity check: display the first rendered frame, without axis ticks.
fig, ax = plt.subplots()
ax.imshow(image)
ax.axis('off')
plt.show()

In [None]:
class BirdPolicy(nn.Module):
    """Two-hidden-layer MLP that maps an observation vector to action probabilities.

    The defaults reproduce the original hard-coded sizes: 180 input features,
    128 hidden units and 2 actions. The 180 is presumably the FlappyBird-v0
    observation size; TODO confirm against `env.observation_space`.

    Args:
        obs_dim: Size of the input observation vector.
        hidden_dim: Width of both hidden layers.
        n_actions: Number of discrete actions (size of the output distribution).
    """

    def __init__(self, obs_dim=180, hidden_dim=128, n_actions=2):
        super().__init__()
        # Attribute names and creation order are unchanged, so state_dict keys
        # (and parameter init order under a fixed seed) match earlier checkpoints.
        self.fc1 = nn.Linear(obs_dim, hidden_dim)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(hidden_dim, n_actions)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Map input of shape (..., obs_dim) to probabilities of shape (..., n_actions)."""
        x = self.relu1(self.fc1(x))
        x = self.relu2(self.fc2(x))
        # Output head is fc3. Previously fc2 was applied a second time, which
        # produced a softmax over hidden_dim "actions" instead of n_actions.
        x = self.fc3(x)
        return self.softmax(x)

In [None]:
def collect_traj(env, policy, device):
    """Roll out one episode, sampling actions from `policy`.

    Args:
        env: A Gymnasium-style environment. `reset()` returns `(obs, info)` and
            `step(a)` returns `(obs, reward, terminated, truncated, info)`.
        policy: Callable mapping a (1, obs_dim) float tensor to action
            probabilities of shape (1, n_actions).
        device: Device the observation tensors are moved to.

    Returns:
        (log_probs, rewards). `log_probs` holds one shape-(1,) tensor per step,
        each still attached to the autograd graph. `rewards` holds the per-step
        rewards returned by `env.step`.
    """
    obs, _ = env.reset()
    log_probs = []
    rewards = []
    done = False

    while not done:
        obs_t = torch.as_tensor(obs, dtype=torch.float32, device=device).unsqueeze(0)  # add batch dim
        probs = policy(obs_t)
        dist = torch.distributions.Categorical(probs)
        action = dist.sample()
        log_probs.append(dist.log_prob(action))
        obs, reward, terminated, truncated, _ = env.step(action.item())
        rewards.append(reward)
        # An episode ends on either termination or truncation (e.g. a time
        # limit). Ignoring `truncated` would keep stepping a finished env.
        done = terminated or truncated
    return log_probs, rewards

def _discounted_returns(rewards, gamma):
    """Return [G_0, ..., G_{T-1}] with G_t = r_t + gamma * G_{t+1} and G_T = 0."""
    returns = []
    running = 0.0
    # Accumulate back to front, then reverse once. This is O(T); inserting at
    # the front of the list on each step would be O(T^2).
    for reward in reversed(rewards):
        running = reward + gamma * running
        returns.append(running)
    returns.reverse()
    return returns


def REINFORCE(env, episodes, policy, ema, device, optimizer,
              gamma=0.99, eval_every=1000, ckpt_path="./weight_epoch.pt"):
    """Train `policy` with Monte-Carlo REINFORCE, one episode per update.

    After each optimizer step, `ema.update()` is called. On every episode index
    divisible by `eval_every` (including the first), one rollout of `ema` is run,
    its total reward is printed, and `ema.state_dict()` is saved to `ckpt_path`.

    Args:
        env: Gymnasium-style environment used by `collect_traj`.
        episodes: Number of training episodes.
        policy: Trainable policy network whose parameters `optimizer` updates.
        ema: EMA wrapper around `policy`. It must be callable like the policy
            and expose `update()` and `state_dict()`.
        device: Device for observation and return tensors.
        optimizer: Optimizer over `policy`'s parameters.
        gamma: Discount factor for returns.
        eval_every: Evaluation/checkpoint interval in episodes. A falsy value
            disables evaluation.
        ckpt_path: Path the EMA state_dict is written to.
    """
    for episode in range(episodes):
        log_probs, rewards = collect_traj(env, policy, device)

        returns = torch.tensor(
            _discounted_returns(rewards, gamma), dtype=torch.float32, device=device
        )
        # Policy-gradient surrogate: mean over steps of -log pi(a_t | s_t) * G_t.
        policy_loss = -(torch.stack(log_probs).reshape(-1) * returns).mean()

        optimizer.zero_grad()
        policy_loss.backward()
        optimizer.step()
        ema.update()

        if eval_every and episode % eval_every == 0:
            # Evaluation needs no gradients; skip building an autograd graph.
            with torch.no_grad():
                _, eval_rewards = collect_traj(env, ema, device)
            print(f'Episode {episode+1}, Total Reward: {sum(eval_rewards)}')
            torch.save(ema.state_dict(), ckpt_path)

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
policy = BirdPolicy().to(device)
ema = EMA(policy, beta = 0.99, update_every = 1)
optimizer = optim.Adam(policy.parameters(), lr=1e-4)

# Train for 20000 episodes. `ema` must be passed: omitting it shifted `device`
# and `optimizer` into the wrong positional slots and raised a TypeError.
NUM_EPISODES = 20000
REINFORCE(env, NUM_EPISODES, policy, ema=ema, device=device, optimizer=optimizer)

In [None]:
import imageio
from PIL import Image

In [None]:
PATH = "./weight_epoch.pt"
policy = BirdPolicy().to(device)
ema = EMA(policy, beta = 0.99, update_every = 1)
# map_location lets a checkpoint saved on GPU be restored on a CPU-only machine
# (and vice versa). torch.load unpickles the file, so only load checkpoints you
# produced yourself or otherwise trust.
ema.load_state_dict(torch.load(PATH, map_location=device))

In [None]:
def play_and_record_modified(env_names, model, episodes=1, filename='gameplay.gif',
                             fps=30, device=None):
    """Play `model` in several environments at once and save a side-by-side GIF.

    One environment is created per entry in `env_names`. Each GIF frame is the
    horizontal concatenation of every environment's current render. An env that
    has finished keeps contributing its last rendered frame until all are done.

    Args:
        env_names: Environment ids passed to `make_env`.
        model: Callable mapping a (1, obs_dim) float tensor to action
            probabilities (e.g. `BirdPolicy` or its EMA wrapper).
        episodes: Number of episodes to record; all frames go into one GIF.
        filename: Output GIF path.
        fps: Frame rate passed to `imageio.mimsave`.
        device: Device for observation tensors. Defaults to the device of the
            model's first parameter, or CPU if it has none.
    """
    if device is None:
        # Follow the model's weights instead of assuming CUDA is available.
        params = getattr(model, 'parameters', None)
        first_param = next(iter(params()), None) if callable(params) else None
        device = first_param.device if first_param is not None else torch.device('cpu')

    envs = [make_env(name)() for name in env_names]
    frames = []
    try:
        with torch.no_grad():  # inference only
            for _ in range(episodes):
                obses = [env.reset()[0] for env in envs]
                done = [False] * len(envs)
                while not all(done):
                    tiles = []
                    for i, env in enumerate(envs):
                        if not done[i]:
                            obs = torch.as_tensor(
                                obses[i], dtype=torch.float32, device=device
                            ).unsqueeze(0)
                            probs = model(obs)
                            action = torch.distributions.Categorical(probs).sample()
                            obses[i], _, terminated, truncated, _ = env.step(action.item())
                            done[i] = terminated or truncated
                        tiles.append(env.render())
                    # Stitch the per-env renders side by side into one frame.
                    frames.append(Image.fromarray(np.concatenate(tiles, axis=1)))
    finally:
        # Close once, after all episodes. Closing inside the frame loop left
        # closed envs to be reset again when episodes > 1.
        for env in envs:
            env.close()

    imageio.mimsave(filename, frames, fps=fps)
# Example usage: record three side-by-side FlappyBird games.
# Play with `ema`: it is the module the checkpoint was loaded into and the one
# evaluated during training. `policy` was just re-created from scratch above.
env_names = ['FlappyBird-v0', 'FlappyBird-v0', 'FlappyBird-v0']  # The environments you want to run
play_and_record_modified(env_names, ema, episodes=1, filename='FlappyBird_Triple.gif')