In [None]:
# Fetch the custom snake environment from its repository, move the package
# directory next to the notebook, and install it in editable mode
# (`pip install -e`) so the `snake` package is importable by `gym.make` below.
# NOTE(review): not idempotent — `git clone` errors if the checkout directory
# already exists, so re-running this cell needs the old checkout removed first.
# NOTE(review): consider `%pip` (targets the kernel's env) and pinning a commit.
!git clone https://github.com/JoyPang123/RL-Explore-with-Own-made-Env.git
!mv RL-Explore-with-Own-made-Env/snake ./snake
!pip install -e snake

Cloning into 'RL-Explore-with-Own-made-Env'...
remote: Enumerating objects: 110, done.[K
remote: Counting objects: 100% (110/110), done.[K
remote: Compressing objects: 100% (84/84), done.[K
remote: Total 110 (delta 36), reused 93 (delta 21), pack-reused 0[K
Receiving objects: 100% (110/110), 21.03 KiB | 7.01 MiB/s, done.
Resolving deltas: 100% (36/36), done.
Obtaining file:///content/snake
Collecting pygame
  Downloading pygame-2.1.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.3 MB)
[K     |████████████████████████████████| 18.3 MB 109 kB/s 
Installing collected packages: pygame, snake
  Running setup.py develop for snake
Successfully installed pygame-2.1.0 snake-0.0.1


In [None]:
from collections import Counter

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.multiprocessing as mp
from torch.distributions import Categorical

import torchvision.transforms as transforms
from torchvision.transforms import InterpolationMode

import numpy as np
import matplotlib.pyplot as plt

import cv2

import gym

In [None]:
class ActorCritic(nn.Module):
    """Actor-critic network with two independent convolutional branches.

    Adapted from
    https://github.com/raillab/a2c/blob/master/a2c/model.py

    The actor and the critic have identical architectures but do NOT share
    weights. Each branch consists of four stride-2 convolutions and a
    two-layer MLP. The first Linear layer assumes a 128 x 4 x 4 feature map,
    i.e. a 3-channel input of roughly 64 x 64 pixels.

    Args:
        num_actions: number of discrete actions (width of the actor output).
    """

    @staticmethod
    def _make_branch(out_features):
        """Build one conv + MLP branch that ends in ``out_features`` units."""
        return nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=5, stride=2, padding=2),    # -> (32, 32, 32)
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),   # -> (64, 16, 16)
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),   # -> (64, 8, 8)
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),  # -> (128, 4, 4)
            nn.ReLU(),
            nn.Flatten(start_dim=1),                                  # -> (2048,)
            nn.Linear(128 * 4 * 4, 512),
            nn.ReLU(),
            nn.Linear(512, out_features),
        )

    def __init__(self, num_actions):
        super().__init__()
        # Policy branch: one score per action.
        self.actor = self._make_branch(num_actions)
        # Value branch: a single scalar estimate per input.
        self.critic = self._make_branch(1)

    def forward(self, x):
        """Return ``(log_policy, value)`` for a batch of images ``x``.

        ``log_policy`` is the log-softmax of the actor output over dim 1,
        shape (batch, num_actions); ``value`` has shape (batch, 1).
        """
        log_policy = F.log_softmax(self.actor(x), dim=1)
        state_value = self.critic(x)
        return log_policy, state_value

In [None]:
def run_episode(worker_env, worker_model, N_steps=1000):
    """Roll out one episode with the current policy, keeping the autograd graph.

    The rollout stops when the environment reports ``done`` or after
    ``N_steps`` steps, whichever comes first. A truncated episode is returned
    as-is; it gets no bootstrap value.

    Args:
        worker_env: Gym-style environment whose ``reset()`` and ``step()``
            observations are dicts with a ``"frame"`` image entry
            (presumably an HxWx3 uint8 array — TODO confirm).
        worker_model: module mapping a (1, C, 64, 64) batch to
            ``(log_policy, value)``, e.g. ``ActorCritic``.
        N_steps: maximum number of environment steps; must be >= 1.

    Returns:
        ``(values, logprobs, rewards, actions)``: per-step lists of 1-element
        tensors (critic values, log-probabilities of the chosen actions,
        float32 rewards) and a list of Python-int actions.

    Raises:
        ValueError: if ``N_steps`` < 1.
    """
    if N_steps < 1:
        raise ValueError(f"N_steps must be positive, got {N_steps}")

    # Frame -> float CHW tensor, then resize to the 64x64 input the
    # network's first Linear layer is sized for.
    img_transforms = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((64, 64))
    ])

    state = img_transforms(worker_env.reset()["frame"])
    values, logprobs, rewards, actions = [], [], [], []

    for _ in range(N_steps):
        log_policy, value = worker_model(state.unsqueeze(0).float())
        log_policy = log_policy.view(-1)

        # The model outputs log-softmax scores; Categorical accepts them as
        # logits, and indexing them yields log pi(action | state) directly.
        action = Categorical(logits=log_policy).sample()
        action_id = action.item()

        values.append(value.view(-1))
        logprobs.append(log_policy[action].view(-1))
        actions.append(action_id)

        next_obs, reward, done, _ = worker_env.step(action_id)
        rewards.append(torch.tensor([reward], dtype=torch.float32))
        if done:
            break
        # Only preprocess the next frame when another step will consume it.
        state = img_transforms(next_obs["frame"])

    return values, logprobs, rewards, actions

In [None]:
def update_params(worker_optim, values, log_probs, rewards,
                  critic_coeff=1.0, gamma=0.9):
    """Apply one actor-critic update using Monte-Carlo discounted returns.

    Args:
        worker_optim: optimizer over the parameters that produced ``values``
            and ``log_probs``.
        values: per-step critic outputs (1-element tensors), in time order.
        log_probs: per-step log-probabilities of the taken actions, same layout.
        rewards: per-step rewards (1-element tensors), same layout.
        critic_coeff: weight of the critic (value) loss in the total loss.
        gamma: discount factor.

    Returns:
        ``(actor_loss, critic_loss, episode_reward)``. ``actor_loss`` is the
        per-step actor loss tensor in reversed time order. ``critic_loss`` is
        the unweighted smooth-L1 value loss (a scalar tensor).
        ``episode_reward`` is the undiscounted reward sum as a Python float.
    """
    # Reverse time so the discounted return can be accumulated in one pass.
    logprobs = torch.cat(log_probs).float().flip(dims=(0,))
    values = torch.cat(values).float().flip(dims=(0,))
    rewards = torch.cat(rewards).float().flip(dims=(0,))

    returns = []
    running = torch.zeros(())
    for reward in rewards:
        running = reward + gamma * running
        returns.append(running)
    # Same (reversed) order as `values` / `logprobs`; carries no gradient.
    returns = torch.stack(returns)

    # The advantage is detached so the actor loss does not train the critic.
    actor_loss = -1 * logprobs * (returns - values).detach()
    critic_loss = F.smooth_l1_loss(values, returns)
    loss = actor_loss.sum() + critic_coeff * critic_loss

    worker_optim.zero_grad()
    loss.backward()
    worker_optim.step()

    return actor_loss, critic_loss, rewards.sum().item()

In [None]:
def worker(model, episodes, env_id="snake:snake-v0", lr=1e-2, log_every=100):
    """Train ``model`` in place for ``episodes`` episodes on its own environment.

    Each episode is rolled out with ``run_episode`` and immediately used for
    one ``update_params`` step. Every ``log_every`` episodes the function
    prints the episode index, the snake's length, the episode reward and a
    per-action histogram.

    Args:
        model: actor-critic module to optimize (updated in place).
        episodes: number of episodes to run.
        env_id: Gym environment id passed to ``gym.make``.
        lr: SGD learning rate (momentum fixed at 0.5).
        log_every: print statistics every this many episodes.
    """
    worker_env = gym.make(env_id)
    worker_optim = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.5)

    try:
        for episode in range(episodes):
            values, logprobs, rewards, actions = run_episode(worker_env, model)
            _, _, episode_reward = update_params(
                worker_optim, values, logprobs, rewards
            )

            if episode % log_every == 0:
                # `.unwrapped` reaches the raw env regardless of gym wrappers.
                snake_length = worker_env.unwrapped.snake.length
                print(f"==========Episode: {episode}============")
                print(f"snake's length: {snake_length}, reward: {episode_reward}")
                action_counts = Counter(actions)
                print(" ".join(
                    f"{action}:{count}" for action, count in sorted(action_counts.items())
                ))
    finally:
        # Release environment resources even if training is interrupted.
        worker_env.close()

In [None]:
def train(args, num_actions=4):
    """Build a fresh ``ActorCritic`` and train it; return the trained model.

    Args:
        args: config dict. ``args["episodes"]`` is the number of episodes per
            worker. ``args.get("num_workers", 1)`` is the number of workers.
            With more than one worker, training runs as parallel processes
            updating a shared-memory model (Hogwild / A3C style).
        num_actions: size of the discrete action space.

    Returns:
        The trained ``ActorCritic`` instance.
    """
    actor_critic = ActorCritic(num_actions)
    episodes = args["episodes"]
    num_workers = args.get("num_workers", 1)

    if num_workers <= 1:
        worker(actor_critic, episodes)
        return actor_critic

    # Parameters live in shared memory so every process updates the same model.
    # NOTE(review): relies on the "fork" start method (Linux/Colab); with
    # "spawn" the notebook-defined `worker` cannot be pickled.
    actor_critic.share_memory()
    processes = [
        mp.Process(target=worker, args=(actor_critic, episodes))
        for _ in range(num_workers)
    ]
    for proc in processes:
        proc.start()
    try:
        for proc in processes:
            proc.join()
    finally:
        # Don't leave orphaned workers behind on interrupt/error.
        for proc in processes:
            if proc.is_alive():
                proc.terminate()
    return actor_critic

In [None]:
# Training configuration consumed by train():
#   episodes    - number of episodes the worker runs
#   num_workers - intended number of worker processes
#                 (NOTE(review): confirm train() actually reads this key)
args = dict(episodes=4000, num_workers=1)

In [None]:
# Launch training with the configuration above (long-running: args["episodes"] episodes).
train(args=args)

snake's length: 1, reward: -9.989999771118164
1:2 2:2 3:1 0:2 
snake's length: 1, reward: -10.055001258850098
0:20 1:9 2:5 3:2 
snake's length: 1, reward: -10.02500057220459
0:16 1:14 2:1 3:3 
snake's length: 1, reward: -10.050000190734863
1:12 3:1 0:2 
snake's length: 1, reward: -9.999999046325684
0:40 1:29 
snake's length: 1, reward: -9.989999771118164
1:4 0:15 
snake's length: 1, reward: -10.045000076293945
0:22 1:32 


KeyboardInterrupt: ignored