In [1]:
import numpy as np
import time
from IPython.display import clear_output

from envs import game2048_env

# Maps the integer action index used by the env / network output to a
# human-readable direction name (index order: UP, RIGHT, DOWN, LEFT).
id_action_dict = dict(enumerate(["UP", "RIGHT", "DOWN", "LEFT"]))

In [2]:
from torch import nn
import torch
from collections import deque
import itertools
import random
import torch.nn.functional as F

In [3]:
# Pick the compute backend: CUDA if present, otherwise Apple MPS, otherwise CPU.
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

print("PyTorch is using :", device)

PyTorch is using : mps


In [4]:
class DQN(nn.Module):
    """Deep Q-Network for the 2048 board.

    Architecture: two parallel convolution branches over the input (kernel
    1x2 and kernel 2x1), each followed by a second 1x2 and 2x1 convolution.
    All six ReLU feature maps are flattened, concatenated and passed to a
    two-layer MLP that outputs one Q-value per action (4 outputs).

    Expected input: float tensor of shape (batch, 16, 4, 4). The 16 channels
    are fixed by ``conv_a``/``conv_b``; the 4x4 spatial size is what yields
    the ``_FLAT_FEATURES`` inputs of the first ``fc`` layer.
    """

    # Flattened width of the six concatenated feature maps for a 4x4 board
    # (128 channels each):
    #   conv_a  -> 4x3, conv_b  -> 3x4 : 2 * 128 * 12 = 3072
    #   conv_aa -> 4x2, conv_bb -> 2x4 : 2 * 128 * 8  = 2048
    #   conv_ab -> 3x3, conv_ba -> 3x3 : 2 * 128 * 9  = 2304
    #                                           total = 7424
    _FLAT_FEATURES = 7424

    def __init__(self):
        super().__init__()

        # First layer: horizontal (1x2) and vertical (2x1) neighbour detectors.
        self.conv_a = nn.Conv2d(16, 128, kernel_size=(1, 2))
        self.conv_b = nn.Conv2d(16, 128, kernel_size=(2, 1))

        # Second layer on the horizontal branch.
        self.conv_aa = nn.Conv2d(128, 128, kernel_size=(1, 2))
        self.conv_ab = nn.Conv2d(128, 128, kernel_size=(2, 1))

        # Second layer on the vertical branch.
        self.conv_ba = nn.Conv2d(128, 128, kernel_size=(1, 2))
        self.conv_bb = nn.Conv2d(128, 128, kernel_size=(2, 1))

        # Head: concatenated features -> 256 hidden units -> 4 action values.
        self.fc = nn.Sequential(
            nn.Linear(self._FLAT_FEATURES, 256),
            nn.ReLU(),
            nn.Linear(256, 4),
        )

    def forward(self, x):
        """Return Q-values of shape (batch, 4) for a batch of boards ``x``."""
        x_a = F.relu(self.conv_a(x))
        x_b = F.relu(self.conv_b(x))

        x_aa = F.relu(self.conv_aa(x_a))
        x_ab = F.relu(self.conv_ab(x_a))

        x_ba = F.relu(self.conv_ba(x_b))
        x_bb = F.relu(self.conv_bb(x_b))

        # Flatten each map to (batch, C*H*W) and concatenate in the original
        # order (a, b, aa, ab, ba, bb) so previously trained fc weights still
        # line up with the same features.
        features = torch.cat(
            [t.flatten(start_dim=1) for t in (x_a, x_b, x_aa, x_ab, x_ba, x_bb)],
            dim=1,
        )
        return self.fc(features)

    def act(self, obs):
        """Return the greedy action (a Python int) for a single observation.

        ``obs`` is one board without batch dimension, e.g. shape (16, 4, 4),
        given as a NumPy array or a tensor on any device. It is converted to
        float32 and moved to the network's device, so callers no longer need
        to do that themselves.

        The forward pass runs under ``torch.no_grad()``: the result is only
        used to pick an action, so building an autograd graph is wasted work.
        """
        param_device = next(self.parameters()).device
        obs_t = torch.as_tensor(obs, dtype=torch.float32, device=param_device)

        with torch.no_grad():
            q_values = self(obs_t.unsqueeze(0))

        return int(torch.argmax(q_values, dim=1)[0].item())

## Training

In [17]:
# --- Hyperparameters ---------------------------------------------------------
# Discount factor applied to the bootstrapped target-network value.
GAMMA = 0.99

# Replay memory
BATCH_SIZE = 16           # transitions sampled per gradient step
BUFFER_SIZE = 50_000      # maxlen of the replay deque
MIN_REPLAY_SIZE = 1_000   # random transitions collected before training

# Epsilon-greedy exploration, linearly interpolated over EPSILON_DECAY steps
EPSILON_START = 0.1
EPSILON_END = 1e-4
EPSILON_DECAY = 10_000

# Optimisation
TARGET_UPDATE_FREQ = 500  # steps between copies online_net -> target_net
LEARNING_RATE = 1e-4

In [18]:
# Training environment. The cells below use it through reset(),
# step(action) -> (obs, reward, done, info) and action_space.sample().
env = game2048_env.Game2048Env()

In [19]:
# Experience-replay memory, plus rolling windows over the last 100 episodes
# (total reward and info['highest']) used only for logging.
replay_buffer = deque(maxlen=BUFFER_SIZE)
rew_buffer, highest_buffer = deque(maxlen=100), deque(maxlen=100)

episode_reward = 0

# The online network is trained; the target network is a copy of it that is
# re-synced every TARGET_UPDATE_FREQ steps in the training loop.
online_net, target_net = (DQN().to(device) for _ in range(2))
target_net.load_state_dict(online_net.state_dict())

optimizer = torch.optim.Adam(params=online_net.parameters(), lr=LEARNING_RATE)

In [None]:
# Initialize Replay Buffer

# Bound on training length so Restart & Run All terminates; set to None to
# train until interrupted (the previous behaviour).
N_TRAIN_STEPS = 20_000
LOG_EVERY = 1000


def to_channels_first(raw_obs):
    """Move the env observation's last axis to the front.

    Assumes the env returns channels-last boards (presumably (4, 4, 16) --
    confirm in game2048_env); the DQN's Conv2d layers expect (16, 4, 4).
    Same result as the deprecated ``np.rollaxis(raw_obs, 2)``.
    """
    return np.moveaxis(raw_obs, 2, 0)


# Warm-up: fill the replay buffer with uniformly random play so the first
# gradient step already has enough transitions to sample a batch from.
obs = to_channels_first(env.reset())
for _ in range(MIN_REPLAY_SIZE):
    action = env.action_space.sample()

    new_obs, rew, done, info = env.step(action)
    new_obs = to_channels_first(new_obs)
    replay_buffer.append((obs, action, rew, done, new_obs))
    obs = new_obs

    if done:
        obs = to_channels_first(env.reset())


# Main training loop
obs = to_channels_first(env.reset())
steps = itertools.count() if N_TRAIN_STEPS is None else range(N_TRAIN_STEPS)

for step in steps:
    # Linearly anneal exploration from EPSILON_START to EPSILON_END.
    epsilon = np.interp(step, [0, EPSILON_DECAY], [EPSILON_START, EPSILON_END])

    if np.random.random() <= epsilon:
        action = env.action_space.sample()
    else:
        # Acting never needs gradients.
        with torch.no_grad():
            action = online_net.act(torch.as_tensor(obs, dtype=torch.float32).to(device))

    new_obs, rew, done, info = env.step(action)
    new_obs = to_channels_first(new_obs)
    replay_buffer.append((obs, action, rew, done, new_obs))
    obs = new_obs

    episode_reward += rew

    if done:
        obs = to_channels_first(env.reset())

        rew_buffer.append(episode_reward)
        highest_buffer.append(info['highest'])

        episode_reward = 0

    # --- Gradient step on a random minibatch ---------------------------------
    transitions = random.sample(replay_buffer, BATCH_SIZE)
    obses, actions, rews, dones, new_obses = (
        np.asarray(column) for column in zip(*transitions)
    )

    obses_t = torch.as_tensor(obses, dtype=torch.float32).to(device)
    actions_t = torch.as_tensor(actions, dtype=torch.int64).unsqueeze(1).to(device)
    rews_t = torch.as_tensor(rews, dtype=torch.float32).unsqueeze(1).to(device)
    dones_t = torch.as_tensor(dones, dtype=torch.float32).unsqueeze(1).to(device)
    new_obses_t = torch.as_tensor(new_obses, dtype=torch.float32).to(device)

    # TD targets come from the target network and must be constants:
    # without no_grad, loss.backward() also backpropagated into target_net and
    # its .grad buffers accumulated forever (they are never zeroed).
    with torch.no_grad():
        max_target_q_values = target_net(new_obses_t).max(dim=1, keepdim=True)[0]
        targets = rews_t + GAMMA * (1 - dones_t) * max_target_q_values

    # Q(s, a) of the actions actually taken.
    q_values = online_net(obses_t)
    action_q_values = torch.gather(input=q_values, dim=1, index=actions_t)

    loss = F.smooth_l1_loss(action_q_values, targets)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Periodically sync the target network with the online network.
    if step % TARGET_UPDATE_FREQ == 0:
        target_net.load_state_dict(online_net.state_dict())

    # Logging (prints nan until the first episode finishes, without the
    # "Mean of empty slice" RuntimeWarning).
    if step % LOG_EVERY == 0:
        avg_rew = np.mean(rew_buffer) if rew_buffer else float('nan')
        avg_highest = np.mean(highest_buffer) if highest_buffer else float('nan')
        print('Step :', step)
        print('Avg Rew :', np.round(avg_rew, 2))
        print('Highest :', np.round(avg_highest))
        print('\n')

Step : 0
Avg Rew : nan
Highest : nan


Step : 1000
Avg Rew : -1648.0
Highest : 96.0


Step : 2000
Avg Rew : -1408.0
Highest : 102.0


Step : 3000
Avg Rew : -666.6
Highest : 93.0


Step : 4000
Avg Rew : -317.73
Highest : 105.0


Step : 5000
Avg Rew : -191.3
Highest : 101.0


Step : 6000
Avg Rew : 27.28
Highest : 104.0


Step : 7000
Avg Rew : 222.97
Highest : 115.0


Step : 8000
Avg Rew : 350.91
Highest : 120.0


Step : 9000
Avg Rew : 384.46
Highest : 121.0


Step : 10000
Avg Rew : 425.44
Highest : 121.0


Step : 11000
Avg Rew : 498.94
Highest : 125.0


Step : 12000
Avg Rew : 557.82
Highest : 126.0




## Playing

In [21]:
def softmax(x, axis=0):
    """Numerically stable softmax of ``x`` along ``axis`` (default 0).

    The maximum is subtracted per slice along ``axis`` before exponentiating.
    Large values therefore cannot overflow, and no slice can underflow to
    all-zeros. The previous global-max shift could do that for 2-D input and
    return NaN. Entries equal to ``-inf`` get probability 0, provided their
    slice has a finite maximum. Behaviour for 1-D input is unchanged.
    """
    x = np.asarray(x)
    e_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e_x / e_x.sum(axis=axis, keepdims=True)

In [33]:
env = game2048_env.Game2048Env()
env.seed()

RENDER_DELAY = 0.0               # seconds to pause after each rendered move
N_ACTIONS = len(id_action_dict)  # UP, RIGHT, DOWN, LEFT

# A single reset is enough (the old cell reset twice). Convert the board to
# the channels-first layout the network was trained on.
next_state = np.moveaxis(env.reset(), 2, 0)

done = False

counter = 0
illegal_counter = 0

while not done:
    counter += 1
    clear_output(wait=True)

    # Q-values of the *current* board. They are reused by the illegal-move
    # fallback below; the old fallback evaluated the stale training-loop
    # global `obs` instead of this board.
    state_t = torch.as_tensor(next_state, dtype=torch.float32).to(device)
    with torch.no_grad():
        q_values = online_net(state_t.unsqueeze(0)).cpu().numpy().flatten()

    action = int(np.argmax(q_values))  # greedy, same choice as online_net.act
    print("Action :", id_action_dict[action])
    print("Counter :", counter)

    raw_state, reward, done, info = env.step(action)

    # Illegal-move fallback: sample among actions not yet tried on this board,
    # weighted by softmax(Q). The old code retried once per turn and could
    # re-pick the same illegal move, so the episode could loop without end.
    # This assumes an illegal move leaves the board unchanged (presumably --
    # confirm in game2048_env), so the same q_values stay valid.
    tried = {action}
    while info['illegal_move'] and not done:
        illegal_counter += 1
        if len(tried) == N_ACTIONS:
            # Every action was illegal from this board: nothing left to play.
            done = True
            break
        masked_q = q_values.astype(np.float64)
        masked_q[list(tried)] = -np.inf
        action = int(np.random.choice(N_ACTIONS, p=softmax(masked_q)))
        tried.add(action)
        raw_state, reward, done, info = env.step(action)

    next_state = np.moveaxis(raw_state, 2, 0)

    print("Illegal Counter :", illegal_counter)

    print(info)

    env.render()

    time.sleep(RENDER_DELAY)

    if done:
        print("You LOST !")
        env.reset()

Action : DOWN
Counter : 1625


KeyboardInterrupt: 