In [None]:
!pip install wandb
!wandb login

In [None]:
from collections import deque
import random

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision.transforms as transforms
import gym

import wandb

In [None]:
class Actor(nn.Module):
    """Convolutional policy network mapping an image observation to action probabilities.

    Four stride-2 Conv2d -> BatchNorm2d -> ReLU stages each halve the spatial
    size, followed by a two-layer MLP head and a softmax over ``num_actions``.
    The flatten size ``128 * 4 * 4`` corresponds to a (3, 64, 64) input, which
    is the size the training loop resizes frames to.
    """

    def __init__(self, num_actions):
        super().__init__()

        # (in_channels, out_channels, kernel_size, padding) per stage; stride is always 2.
        # Output shapes for a (3, 64, 64) input:
        # (32, 32, 32) -> (64, 16, 16) -> (64, 8, 8) -> (128, 4, 4)
        stages = [
            (3, 32, 5, 2),
            (32, 64, 3, 1),
            (64, 64, 3, 1),
            (64, 128, 3, 1),
        ]

        layers = []
        for c_in, c_out, k, pad in stages:
            layers += [
                nn.Conv2d(
                    in_channels=c_in, out_channels=c_out,
                    kernel_size=k, padding=pad, stride=2
                ),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
            ]

        layers += [
            nn.Flatten(start_dim=1),  # (2048)
            nn.Linear(128 * 4 * 4, 512),
            nn.ReLU(),
            nn.Linear(512, num_actions),
            nn.Softmax(dim=-1),
        ]

        # A single Sequential named `actor` keeps state_dict keys ("actor.<i>.*")
        # compatible with saved checkpoints.
        self.actor = nn.Sequential(*layers)

    def forward(self, x):
        """Return per-action probabilities, shape (batch, num_actions)."""
        return self.actor(x)


class Critic(nn.Module):
    """Q-network scoring a (state image, action vector) pair.

    The image goes through four stride-2 Conv2d -> BatchNorm2d -> ReLU stages
    and is flattened to 2048 features (for a (3, 64, 64) input). Those features
    are concatenated with the ``act_dim``-wide action vector and fed to an MLP
    that outputs a single value.
    """

    def __init__(self, act_dim):
        super().__init__()

        # (in_channels, out_channels, kernel_size, padding) per stage; stride is always 2.
        # Output shapes: (32, 32, 32) -> (64, 16, 16) -> (64, 8, 8) -> (128, 4, 4)
        stages = (
            (3, 32, 5, 2),
            (32, 64, 3, 1),
            (64, 64, 3, 1),
            (64, 128, 3, 1),
        )

        trunk = []
        for c_in, c_out, k, pad in stages:
            trunk.extend((
                nn.Conv2d(
                    in_channels=c_in, out_channels=c_out,
                    kernel_size=k, padding=pad, stride=2
                ),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
            ))
        trunk.append(nn.Flatten(start_dim=1))  # (2048)

        # Image encoder; the attribute names `critic` and `fc` are kept for
        # state_dict compatibility with saved checkpoints.
        self.critic = nn.Sequential(*trunk)

        # NOTE(review): the final Tanh bounds the predicted Q-value to (-1, 1);
        # confirm this matches the environment's reward scale and gamma.
        self.fc = nn.Sequential(
            nn.Linear(128 * 4 * 4 + act_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 1),
            nn.Tanh(),
        )

    def forward(self, state, action):
        """Return Q(state, action) with shape (batch, 1).

        ``action`` is concatenated along dim 1 with the image features, so it
        must be a 2-D (batch, act_dim) tensor.
        """
        features = self.critic(state)
        joint = torch.cat((features, action), dim=1)
        return self.fc(joint)

In [None]:
class ReplayMemory:
    """Fixed-capacity FIFO buffer of transitions for experience replay.

    Once ``max_len`` transitions are held, appending a new one silently
    evicts the oldest.
    """

    def __init__(self, max_len):
        # A bounded deque discards from the left automatically when full.
        self.replay = deque(maxlen=max_len)

    def store_experience(self, state, reward,
                         action, next_state,
                         done):
        """Append one transition stored as [state, reward, action, next_state, done]."""
        transition = [state, reward, action, next_state, done]
        self.replay.append(transition)

    def size(self):
        """Number of transitions currently stored."""
        return len(self.replay)

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions uniformly at random.

        Returns None when fewer than ``batch_size`` transitions are stored.
        """
        has_enough = self.size() >= batch_size
        return random.sample(self.replay, k=batch_size) if has_enough else None

In [None]:
class DDPG:
    """DDPG-style actor-critic agent for a discrete action space.

    The actor outputs action probabilities. In the actor update, a
    straight-through Gumbel-softmax sample turns them into a one-hot action
    so the critic's gradient can flow back into the actor. The target networks
    start as exact copies of the online networks and then track them by
    Polyak averaging with weight ``tau`` (see ``soft_update``).

    Args:
        memory_size: capacity of the replay buffer.
        num_actions: number of discrete actions (actor output width and the
            width of the critic's action input).
        actor_lr: Adam learning rate for the actor.
        critic_lr: Adam learning rate for the critic.
        gamma: discount factor used in the TD target.
        tau: interpolation weight for the soft target update.
        device: torch device the networks and batches are placed on.
        img_transforms: callable turning one raw frame into a tensor;
            presumably (C, H, W) with 3 channels, as the networks expect.
    """

    def __init__(self, memory_size, num_actions,
                 actor_lr, critic_lr, gamma,
                 tau, device, img_transforms):
        # Online and target networks.
        self.actor = Actor(num_actions).to(device)
        self.target_actor = Actor(num_actions).to(device)
        self.critic = Critic(num_actions).to(device)
        self.target_critic = Critic(num_actions).to(device)

        # Targets must start identical to the online networks. Otherwise the
        # TD target comes from an unrelated random network until the soft
        # updates slowly pull them together.
        self.target_actor.load_state_dict(self.actor.state_dict())
        self.target_critic.load_state_dict(self.critic.state_dict())

        # Targets are only ever evaluated (BatchNorm uses running statistics).
        self.target_actor.eval()
        self.target_critic.eval()

        # Set up optimizer and criterion
        self.critic_criterion = nn.MSELoss()
        self.actor_optim = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)
        self.critic_optim = torch.optim.Adam(self.critic.parameters(), lr=critic_lr)

        # Set up transforms and other hyper-parameters
        self.device = device
        self.img_transforms = img_transforms
        self.num_actions = num_actions
        self.memory = ReplayMemory(memory_size)
        self.gamma = gamma
        self.tau = tau

    def _to_batch(self, frames):
        """Transform each raw frame and stack the results into one batch tensor on ``self.device``."""
        return torch.stack([self.img_transforms(frame) for frame in frames]).to(self.device)

    def choose_action(self, cur_state, eps):
        """Epsilon-greedy action selection for a single raw frame.

        With probability ``eps`` this returns a uniformly random action index.
        Otherwise it returns the index of the actor's most probable action.
        The actor's train/eval mode is restored afterwards.
        """
        # Exploration
        if np.random.uniform() < eps:
            return np.random.randint(0, self.num_actions)

        # Exploitation: evaluate a batch of one in eval mode (BatchNorm uses
        # running stats) without recording an autograd graph.
        was_training = self.actor.training
        self.actor.eval()
        try:
            with torch.no_grad():
                state = self.img_transforms(cur_state).to(self.device).unsqueeze(0)
                probs = self.actor(state)
            return torch.argmax(probs, dim=-1).item()
        finally:
            self.actor.train(was_training)

    def actor_update(self, batch_data):
        """One actor step that maximises the critic's value of actions sampled from the policy.

        ``batch_data`` is a list of transitions as produced by
        ``ReplayMemory.sample``. Only the current states (index 0) are used.
        """
        if not batch_data:
            return

        states = self._to_batch([transition[0] for transition in batch_data])

        # The actor already ends in Softmax, so its output is a probability
        # vector. Its log is used directly as the logits; clamping avoids
        # log(0) = -inf. A second softmax here would squash the distribution
        # toward uniform.
        probs = self.actor(states)
        logits = torch.log(probs.clamp_min(1e-8))

        # Straight-through Gumbel-softmax: one-hot in the forward pass,
        # differentiable through the soft sample in the backward pass.
        actor_actions = F.gumbel_softmax(logits, hard=True)

        loss = -self.critic(states, actor_actions).mean()
        self.actor_optim.zero_grad()
        # This backward also fills the critic's .grad. critic_update zeroes
        # those before its own backward, so they never reach the critic's step.
        loss.backward()
        self.actor_optim.step()

    def critic_update(self, batch_data):
        """One TD-regression step for the critic.

        The target is ``r + gamma * (1 - done) * Q_target(s', a')``, where
        ``a'`` is the one-hot argmax of the target actor at ``s'``.
        ``batch_data`` holds [state, reward, action, next_state, done] items.
        """
        if not batch_data:
            return

        states, rewards, actions, next_states, dones = zip(*batch_data)

        cur_state_batch = self._to_batch(states)
        next_state_batch = self._to_batch(next_states)
        reward_batch = torch.as_tensor(np.asarray(rewards, dtype=np.float32), device=self.device)
        done_batch = torch.as_tensor(np.asarray(dones, dtype=np.float32), device=self.device)
        action_index = torch.as_tensor(np.asarray(actions, dtype=np.int64), device=self.device)
        action_batch = F.one_hot(action_index, num_classes=self.num_actions).float()

        q_eval = self.critic(cur_state_batch, action_batch).squeeze(1)

        # The TD target is a fixed regression label: do not backpropagate into
        # (or accumulate gradients on) the target networks.
        with torch.no_grad():
            next_probs = self.target_actor(next_state_batch)
            next_action = F.one_hot(next_probs.argmax(dim=1), num_classes=self.num_actions).float()
            next_q = self.target_critic(next_state_batch, next_action).squeeze(1)
            q_target = reward_batch + self.gamma * (1 - done_batch) * next_q

        loss = self.critic_criterion(q_eval, q_target)

        self.critic_optim.zero_grad()
        loss.backward()
        self.critic_optim.step()

    def _soft_update_module(self, online, target):
        """Move ``target`` toward ``online``: target <- tau * online + (1 - tau) * target."""
        for param, target_param in zip(online.parameters(), target.parameters()):
            target_param.copy_(self.tau * param + (1 - self.tau) * target_param)

        # Buffers (BatchNorm running_mean/var) must track too. The targets
        # always run in eval mode, so they never update these themselves.
        # Integer buffers (num_batches_tracked) cannot be averaged; copy them.
        for buf, target_buf in zip(online.buffers(), target.buffers()):
            if target_buf.dtype.is_floating_point:
                target_buf.copy_(self.tau * buf + (1 - self.tau) * target_buf)
            else:
                target_buf.copy_(buf)

    def soft_update(self):
        """Polyak-average the online actor and critic into their target networks."""
        with torch.no_grad():
            self._soft_update_module(self.actor, self.target_actor)
            self._soft_update_module(self.critic, self.target_critic)

In [None]:
env = gym.make("snake:snake-v0", mode="hardworking")
device = "cpu"

# Set up environment hyperparameters
num_actions = env.action_space.n

# Seed the Python, NumPy and PyTorch RNGs used by replay sampling, epsilon-greedy
# exploration, network initialisation and Gumbel sampling, so runs are reproducible.
# NOTE(review): the environment's own RNG is not seeded here.
seed = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Set up training hyperparameters
tau = 0.05
max_time_steps = 100000  # number of gradient updates before training stops
max_iter = 2000  # max environment steps per episode
gamma = 0.9
memory_size = 2000
batch_size = 32
actor_lr = 3e-4
critic_lr = 3e-4

In [None]:
def train(max_time_steps, max_iter, memory_size,
          num_actions, actor_lr, critic_lr,
          gamma, tau, device, batch_size, eps=0.1, environment=None):
    """Train a DDPG agent and checkpoint the best-performing actor/critic.

    Runs episodes of at most ``max_iter`` environment steps until
    ``max_time_steps`` gradient updates have been performed. Updates start
    once the replay buffer holds ``batch_size`` transitions. Every
    ``2 * max_iter`` updates, the mean return of the episodes finished since
    the last report is printed and logged to wandb. Whenever that mean beats
    the best so far, the weights are saved to ``actor_best.pt`` and
    ``critic_best.pt``.

    Args:
        eps: exploration probability passed to ``DDPG.choose_action``.
        environment: gym-style env with the old 4-tuple ``step`` API whose
            observations are dicts containing a ``"frame"`` image. Defaults to
            the notebook-level ``env``.
        The remaining arguments are forwarded to ``DDPG`` or control the loop
        as described above.
    """
    if environment is None:
        environment = env

    # Set up model training
    img_transforms = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((64, 64))
    ])

    ddpg = DDPG(
        memory_size, num_actions,
        actor_lr, critic_lr, gamma,
        tau, device, img_transforms
    )
    # Start below any achievable average so the first report can checkpoint
    # even when early returns are negative.
    max_reward = float("-inf")

    running_reward = 0
    running_episodes = 0

    time_step = 0
    print_freq = max_iter * 2

    while time_step < max_time_steps:
        state = environment.reset()
        current_ep_reward = 0

        for _ in range(max_iter):
            # Act, observe, and store the transition
            action = ddpg.choose_action(state["frame"], eps)
            new_state, reward, done, _info = environment.step(action)

            current_ep_reward += reward
            ddpg.memory.store_experience(state["frame"], reward, action, new_state["frame"], done)
            state = new_state

            if done:
                break

            # Wait until the buffer can supply a full batch
            if ddpg.memory.size() < batch_size:
                continue

            batch_data = ddpg.memory.sample(batch_size)
            ddpg.critic_update(batch_data)
            ddpg.actor_update(batch_data)
            ddpg.soft_update()

            time_step += 1

            # Report only when at least one episode finished in this window;
            # otherwise the average is undefined (division by zero). The
            # counters then carry over to the next report.
            if time_step % print_freq == 0 and running_episodes > 0:
                avg_reward = running_reward / running_episodes

                print(f"Iteration:{running_episodes}, get average reward: {avg_reward:.2f}")

                running_reward = 0
                running_episodes = 0
                log = {
                    "avg_reward": avg_reward,
                }
                wandb.log(log)

                if avg_reward > max_reward:
                    max_reward = avg_reward
                    torch.save(ddpg.actor.state_dict(), "actor_best.pt")
                    torch.save(ddpg.critic.state_dict(), "critic_best.pt")

        running_reward += current_ep_reward
        running_episodes += 1

In [None]:
# Record every hyper-parameter of the run in Weights & Biases.
model_config = {
    "gamma": gamma,
    "max_time_steps": max_time_steps,
    "memory size": memory_size,
    "max_iter": max_iter,
    "tau": tau,
    "batch_size": batch_size,
    "actor_lr": actor_lr,
    "critic_lr": critic_lr,
}
run = wandb.init(
    project="snake_RL",
    resume=False,
    config=model_config,
    name="DDPG"
)

# Pass the configured values (action count read from env.action_space, the
# chosen device) instead of re-typed literals that can drift out of sync.
# Always close the wandb run, even if training raises.
try:
    train(
        max_time_steps, max_iter, memory_size,
        num_actions, actor_lr, critic_lr,
        gamma, tau, device, batch_size
    )
finally:
    run.finish()