In [None]:
# Fetch the snake environment sources, move the `snake` package directory next
# to this notebook, and install it in editable mode.
!git clone https://github.com/JoyPang123/snake_env.git
!mv snake_env/snake ./snake
!pip install -e snake
# Ends the current kernel process so that the runtime restarts automatically
# (presumably so the freshly installed package becomes importable -- confirm).
exit() # Keep this line: it restarts the runtime automatically

In [None]:
# Install the Weights & Biases client and log in; the training cell further
# down reports its metrics through `wandb`.
!pip install wandb
!wandb login

In [None]:
from collections import Counter

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.multiprocessing as mp
from torch.distributions import Categorical

import torchvision.transforms as transforms
from torchvision.transforms import InterpolationMode

from PIL import Image

import wandb
import numpy as np
import matplotlib.pyplot as plt

import cv2

import gym

In [None]:
# Use the first CUDA device when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

In [None]:
class RolloutBuffer:
    """Storage for one batch of collected experience.

    Five parallel lists are kept (``actions``, ``states``, ``logprobs``,
    ``rewards``, ``done``). Callers append to them directly and call
    :meth:`clear` once the batch has been consumed.
    """

    # Names of the per-step lists, in the order they are created.
    _FIELDS = ("actions", "states", "logprobs", "rewards", "done")

    def __init__(self):
        for field in self._FIELDS:
            setattr(self, field, [])

    def clear(self):
        """Empty every list in place, so existing references stay valid."""
        for field in self._FIELDS:
            getattr(self, field).clear()

In [None]:
class ActorCritic(nn.Module):
    """Separate convolutional actor and critic networks.

    Adapted from
    https://github.com/raillab/a2c/blob/master/a2c/model.py

    Both networks share the same architecture (but no weights) up to a
    512-unit hidden layer. The actor ends in a softmax over ``num_actions``;
    the critic ends in a single ``tanh``-squashed value. The first linear
    layer expects ``128 * 4 * 4`` features, i.e. 3x64x64 inputs after four
    stride-2 convolutions.

    Args:
        num_actions: Size of the actor's output distribution.
    """

    def __init__(self, num_actions):
        super().__init__()

        # Policy head: probabilities over the discrete actions.
        self.actor = nn.Sequential(
            *self._backbone_layers(),
            nn.Linear(512, num_actions),
            nn.Softmax(dim=-1)
        )

        # Value head: one scalar per state, squashed into [-1, 1].
        self.critic = nn.Sequential(
            *self._backbone_layers(),
            nn.Linear(512, 1),
            nn.Tanh()
        )

    @staticmethod
    def _backbone_layers():
        """Return a fresh list of the layers common to actor and critic."""
        # (in_channels, out_channels, kernel_size, padding); every conv uses
        # stride 2, halving the spatial size: 64 -> 32 -> 16 -> 8 -> 4.
        conv_specs = (
            (3, 32, 5, 2),
            (32, 64, 3, 1),
            (64, 64, 3, 1),
            (64, 128, 3, 1),
        )

        layers = []
        for in_ch, out_ch, kernel, pad in conv_specs:
            layers += [
                nn.Conv2d(
                    in_channels=in_ch, out_channels=out_ch,
                    kernel_size=kernel, padding=pad, stride=2
                ),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(),
            ]

        layers += [
            nn.Flatten(start_dim=1),  # (2048)
            nn.Linear(128 * 4 * 4, 512),
            nn.ReLU(),
        ]
        return layers

    def act(self, state):
        """Sample an action for ``state``.

        Returns:
            Tuple ``(action, log_prob)``, both detached from the graph.
        """
        policy = Categorical(self.actor(state))
        chosen = policy.sample()
        chosen_logprob = policy.log_prob(chosen)

        return chosen.detach(), chosen_logprob.detach()

    def evaluate(self, state, action):
        """Score previously taken actions under the current parameters.

        Returns:
            Tuple ``(log_probs, state_values, entropy)`` where ``log_probs``
            and ``entropy`` come from the actor's distribution and
            ``state_values`` is the raw critic output.
        """
        policy = Categorical(self.actor(state))

        logprobs = policy.log_prob(action)
        entropy = policy.entropy()
        values = self.critic(state)

        return logprobs, values, entropy

In [None]:
class PPO:
    """PPO agent with a clipped surrogate objective.

    ``policy`` is the network being optimised. ``policy_old`` is a copy that
    is used to collect experience and is re-synchronised with ``policy`` at
    the end of every :meth:`update`. States, actions and log-probabilities
    are written to ``self.buffer`` by :meth:`select_action`; the matching
    rewards and done flags are appended to the buffer by the caller.

    Args:
        action_dim: Number of discrete actions.
        lr_actor: Adam learning rate for the actor parameters.
        lr_critic: Adam learning rate for the critic parameters.
        gamma: Discount factor used when computing returns.
        k_epochs: Number of optimisation passes over the buffer per update.
        eps_clip: Clipping range of the probability ratio.
    """

    def __init__(self, action_dim, lr_actor, lr_critic, gamma, k_epochs, eps_clip):
        self.gamma = gamma
        self.eps_clip = eps_clip
        self.k_epochs = k_epochs

        self.buffer = RolloutBuffer()

        self.policy = ActorCritic(action_dim).to(device)
        # One optimizer, two parameter groups with independent learning rates.
        self.optimizer = torch.optim.Adam([
            {'params': self.policy.actor.parameters(), 'lr': lr_actor},
            {'params': self.policy.critic.parameters(), 'lr': lr_critic}
        ])

        self.policy_old = ActorCritic(action_dim).to(device)
        self.policy_old.load_state_dict(self.policy.state_dict())

        self.criterion = nn.MSELoss()

    def select_action(self, state, store=True):
        """Sample one action from ``policy_old`` for a single state.

        Args:
            state: Observation with a leading batch axis of size 1
                (``action.item()`` below only works for a single action).
            store: When true (the default) the state, action and its
                log-probability are appended to the rollout buffer. Pass
                ``False`` for evaluation rollouts, where no reward/done is
                recorded and the buffer would otherwise become misaligned.

        Returns:
            The sampled action as a Python int.
        """
        # NOTE(review): ActorCritic contains BatchNorm layers and neither
        # network is switched to eval mode here, so the statistics used while
        # acting on one sample may differ from those used on the full batch
        # in `update` -- confirm this is intended.
        with torch.no_grad():
            state = torch.as_tensor(state, dtype=torch.float32, device=device)
            action, action_logprob = self.policy_old.act(state)

        if store:
            self.buffer.states.append(state)
            self.buffer.actions.append(action)
            self.buffer.logprobs.append(action_logprob)

        return action.item()

    def update(self):
        """Run ``k_epochs`` PPO steps on the buffered rollout, then clear it.

        Does nothing when the buffer is empty.

        Raises:
            ValueError: If the buffer lists do not all have the same length
                (for example, actions were stored without rewards).
        """
        buffer = self.buffer
        n_steps = len(buffer.states)
        if n_steps == 0:
            return

        lengths = {
            "states": n_steps,
            "actions": len(buffer.actions),
            "logprobs": len(buffer.logprobs),
            "rewards": len(buffer.rewards),
            "done": len(buffer.done),
        }
        if len(set(lengths.values())) != 1:
            # zip() below would silently truncate and hide the misalignment.
            raise ValueError(f"Rollout buffer is misaligned: {lengths}")

        # Discounted return per step, accumulated backwards through the
        # rollout and reset at every episode boundary.
        returns = []
        discounted_reward = 0
        for reward, done in zip(reversed(buffer.rewards), reversed(buffer.done)):
            if done:
                discounted_reward = 0
            discounted_reward = reward + (self.gamma * discounted_reward)
            returns.append(discounted_reward)
        returns.reverse()

        # Normalizing the rewards (std of a single sample is undefined, so a
        # one-step rollout is left as is instead of turning into NaN).
        rewards = torch.tensor(returns, dtype=torch.float32).to(device)
        if rewards.numel() > 1:
            rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-7)

        # Build the batch. Each stored state already has a leading batch axis
        # of size 1, so concatenating keeps every other axis intact (a bare
        # squeeze would also drop genuine size-1 axes, or the batch axis
        # itself for a one-step rollout).
        old_states = torch.cat(buffer.states, dim=0).detach().to(device)
        old_actions = torch.stack(buffer.actions, dim=0).reshape(n_steps).detach().to(device)
        old_logprobs = torch.stack(buffer.logprobs, dim=0).reshape(n_steps).detach().to(device)

        # Optimize policy for K epochs
        for _ in range(self.k_epochs):
            # Evaluating old actions and values
            logprobs, state_values, dist_entropy = self.policy.evaluate(old_states, old_actions)

            # Match state_values tensor dimensions with rewards tensor
            state_values = state_values.reshape(n_steps)

            # Finding the ratio (pi_theta / pi_theta__old)
            ratios = torch.exp(logprobs - old_logprobs)

            # Finding Surrogate Loss
            advantages = rewards - state_values.detach()
            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1 - self.eps_clip, 1 + self.eps_clip) * advantages

            # Final loss of clipped objective PPO: policy term, value
            # regression (weight 0.5) and entropy bonus (weight 0.01).
            loss = -torch.min(surr1, surr2) + 0.5 * self.criterion(state_values, rewards) - 0.01 * dist_entropy

            # Take gradient step
            self.optimizer.zero_grad()
            loss.mean().backward()
            self.optimizer.step()

        # Copy new weights into old policy
        self.policy_old.load_state_dict(self.policy.state_dict())

        # Clear buffer
        self.buffer.clear()

In [None]:
# Best average reward logged so far; starts very low so that the first logged
# average will normally exceed it and produce a checkpoint.
max_reward = -1e9

# Per-episode step cap (also passed to the environment as `max_iter`) and the
# total number of environment steps to train for.
max_ep_len = 2000
max_training_timesteps = int(1e6)

# Log the average episode reward every `print_freq` environment steps.
print_freq = max_ep_len * 4

# PPO hyper-parameters: environment steps collected between updates,
# optimisation epochs per update, ratio clipping range, and discount factor.
update_timestep = max_ep_len * 4
k_epochs = 40
eps_clip = 0.2
gamma = 0.99

# Learning rates of the actor and critic parameter groups.
lr_actor = 3e-4
lr_critic = 3e-4

env = gym.make("snake:snake-v0", max_iter=max_ep_len, mode="hardworking")

ppo_agent = PPO(env.action_space.n, lr_actor, lr_critic, gamma, k_epochs, eps_clip)

# Reward sum and episode count accumulated since the last log line.
running_reward = 0
running_episodes = 0

# Total number of environment steps taken so far.
time_step = 0

# Frame preprocessing: convert to a tensor, then resize to 64x64.
img_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((64, 64))
])

# Hyper-parameters recorded with the Weights & Biases run.
model_config = {
    "gamma": gamma,
    "max_training_timesteps": max_training_timesteps,
    "eps_clip": eps_clip
}
run = wandb.init(
    project="snake_RL",
    resume=False,
    config=model_config,
    name="PPO"
)

# Main training loop: play episodes until the step budget is used up, running
# a PPO update every `update_timestep` steps and logging the average episode
# reward every `print_freq` steps.
while time_step <= max_training_timesteps:
    state = env.reset()
    current_ep_reward = 0

    for t in range(1, max_ep_len + 1):
        # Select action with policy
        action = ppo_agent.select_action(img_transforms(state["frame"]).unsqueeze(0))
        state, reward, done, _ = env.step(action)

        # Saving the episode information. The environment is reset after
        # `max_ep_len` steps whether or not it reported `done`, so that step
        # is always stored as an episode boundary; otherwise discounted
        # returns would leak from the next episode into this one.
        ppo_agent.buffer.rewards.append(reward)
        ppo_agent.buffer.done.append(done or t == max_ep_len)

        time_step += 1
        current_ep_reward += reward

        # update PPO agent
        if time_step % update_timestep == 0:
            ppo_agent.update()

        # Skip the log tick while no episode has finished since the last one:
        # there is nothing to average yet and the division would fail.
        if time_step % print_freq == 0 and running_episodes > 0:
            avg_reward = running_reward / running_episodes

            print(f"Iteration:{running_episodes}, get average reward: {avg_reward:.2f}")

            running_reward = 0
            running_episodes = 0
            log = {
                "avg_reward": avg_reward,
            }
            wandb.log(log)

            # Keep the weights of the best-scoring logging window.
            if avg_reward > max_reward:
                max_reward = avg_reward
                torch.save(ppo_agent.policy.state_dict(), "best.pt")

        if done:
            break

    # The finished episode counts towards the next log line.
    running_reward += current_ep_reward
    running_episodes += 1

In [None]:
def save_fig(ppo_agent, env, img_transforms, file_name):
    """Play one episode with the agent's ``policy_old`` and save it as a GIF.

    Actions are sampled directly from ``ppo_agent.policy_old`` rather than
    through ``ppo_agent.select_action``: the latter appends every state,
    action and log-probability to the training buffer, and since no
    rewards/done flags are recorded here the buffer would be left misaligned.

    Args:
        ppo_agent: Agent exposing a ``policy_old`` module with an ``act``
            method.
        env: Environment whose ``reset``/``step`` observations are dicts with
            a ``"frame"`` entry.
        img_transforms: Callable turning a frame into the policy's input
            tensor (without batch axis).
        file_name: Destination path of the GIF.

    The total episode reward is printed; nothing is returned.
    """
    policy = ppo_agent.policy_old
    # Run the observation on whatever device the policy lives on.
    first_param = next(policy.parameters(), None)
    policy_device = first_param.device if first_param is not None else torch.device("cpu")

    state = env.reset()
    gif_frames = [state["frame"]]
    rewards = 0

    # Run the game
    while True:
        # Select action with policy, without recording it for training
        obs = img_transforms(state["frame"]).unsqueeze(0)
        with torch.no_grad():
            obs = torch.as_tensor(obs, dtype=torch.float32, device=policy_device)
            action, _logprob = policy.act(obs)

        state, reward, done, _info = env.step(action.item())
        rewards += reward
        gif_frames.append(state["frame"])

        if done:
            break

    print(rewards)
    # Append the frames
    img, *imgs = [Image.fromarray(frame) for frame in gif_frames]
    img.save(
        fp=file_name, format="GIF", append_images=imgs,
        save_all=True, optimize=True, duration=150,
        loop=0
    )

In [None]:
# Record one rollout of the agent as an animated GIF.
save_fig(
    ppo_agent=ppo_agent,
    env=env,
    img_transforms=img_transforms,
    file_name="test.gif",
)