<a href="https://www.kaggle.com/code/jvthunder/rl-2048?scriptVersionId=140301439" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

In [None]:
!pip install termtables
!pip install pickle5
!pip install gym
!pip install six
!pip install stable_baselines3

In [None]:
import numpy as np
import termtables as tt
import random
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow import keras

from collections import deque
import time

import gym
from gym import spaces
from gym.utils import seeding

In [None]:
class Base2048Env(gym.Env):
    """Gym environment for the game 2048 on a square ``SIZE`` x ``SIZE`` board.

    Observations are a binary array of shape ``(layers, width, height)`` with
    ``layers = width * height``; plane ``k`` marks the cells that hold the
    tile ``2 ** (k + 1)``.  Actions are the integers 0-3, labelled by
    ``ACTION_STRING`` as left, up, right and down.  ``step`` returns the
    4-tuple ``(observation, reward, done, info)``.
    """

    # Tile value that ends the episode as soon as it appears on the board.
    WIN_TILE = 2048

    def __init__(self, SIZE=4, board=0, seed=None):
        """Create the environment.

        Args:
            SIZE: Side length of the square board.
            board: Either an integer sentinel (the default ``0``), meaning
                "start from a fresh random board", or an array-like of shape
                ``(SIZE, SIZE)`` holding the initial tile values.
            seed: Optional seed; when given it seeds NumPy's *global* RNG,
                which is the generator ``add_random`` draws from.

        Raises:
            ValueError: If a custom ``board`` does not have shape
                ``(SIZE, SIZE)``.
        """
        super().__init__()
        if seed is not None:
            np.random.seed(seed)

        self.SIZE = SIZE
        self.width = SIZE
        self.height = SIZE
        self.layers = self.width * self.height

        self.ACTION_STRING = ['left', 'up', 'right', 'down']

        self.observation_space = spaces.Box(0, 1, (self.layers, self.width, self.height), dtype=int)
        self.action_space = spaces.Discrete(4)

        self.board = np.zeros((self.height, self.width), dtype=int)
        self.Matrix = np.zeros((self.layers, self.width, self.height), int)

        self.score = 0
        self.move_cnt = 0
        self.prv_move = '-'
        self.is_game_over = False

        if isinstance(board, (int, np.integer)):
            self.reset()
        else:
            # Copy into an int array so the caller's object is never aliased,
            # and reject boards that do not match the configured size.
            custom = np.array(board, dtype=int)
            if custom.shape != (self.height, self.width):
                raise ValueError(
                    f"board must have shape {(self.height, self.width)}, "
                    f"got {custom.shape}"
                )
            self.board = custom
            # Keep the cached observation in sync with the supplied board;
            # step() returns self.Matrix unchanged on invalid moves.
            self.stack()

    def render(self):
        """Print the turn counter, score, previous move and the board grid."""
        lines = [
            f"2048 Board State Turn {self.move_cnt}",
            f"Score: {self.score}",
            f"Previous move: {self.prv_move}",
            tt.to_string(self.board, style=tt.styles.ascii_thin),
        ]

        if self.is_game_over:
            # The episode also ends when WIN_TILE is reached, so only report
            # a loss when that tile is absent.
            if np.any(np.asarray(self.board) == self.WIN_TILE):
                lines.append(f"You win. Your score is {self.score}")
            else:
                lines.append(f"You lose. Your score is {self.score}")

        print("\n".join(lines))

    def stack(self):
        """Encode ``self.board`` into ``self.Matrix`` and return it.

        Every non-empty cell sets a single 1 in the plane selected by its
        tile value: 2 -> plane 0, 4 -> plane 1, and so on.
        """
        board = np.asarray(self.board)
        encoded = np.zeros((self.layers, self.width, self.height), int)
        rows, cols = np.nonzero(board)
        for i, j in zip(rows, cols):
            # bit_length() - 2 equals floor(log2(value)) - 1, computed with
            # exact integer arithmetic instead of floating point.
            dim = int(board[i, j]).bit_length() - 2
            encoded[dim, i, j] = 1
        self.Matrix = encoded
        return self.Matrix

    def reset(self):
        """Start a new game with two random tiles and return the observation."""
        self.board = np.zeros((self.height, self.width), dtype=int)
        self.add_random()
        self.add_random()
        self.score = 0
        self.move_cnt = 0
        self.prv_move = '-'
        self.is_game_over = False
        return self.stack()

    def add_random(self):
        """Place a new tile (4 with probability 1/10, else 2) on an empty cell.

        Does nothing when the board has no empty cell.
        """
        # argwhere scans in row-major order, i.e. the same (row, column)
        # ordering a nested row/column loop would produce.
        available = np.argwhere(np.asarray(self.board) == 0)
        if len(available) == 0:
            return

        # Draw the tile value first, then the position, so that the sequence
        # of RNG calls is stable for a given seed.
        number = 4 if np.random.randint(0, 10) == 0 else 2
        x, y = available[np.random.randint(0, len(available))]
        self.board[x][y] = number

    def merge_left(self, board):
        """Slide and merge every row of ``board`` to the left.

        A tile produced by a merge cannot merge again during the same move,
        e.g. ``[2, 2, 4, 8]`` becomes ``[4, 4, 8, 0]``.

        Args:
            board: 2-D array-like of tile values; it is not modified.

        Returns:
            ``(new_board, valid, inc_score)`` where ``new_board`` is a new
            integer array, ``valid`` is True if any tile moved or merged, and
            ``inc_score`` is the sum of the values of the merged tiles.
        """
        valid = False
        inc_score = 0
        new_board = np.zeros((self.height, self.width), dtype=int)

        for i in range(self.height):
            write_pos = 0       # next free column in the output row
            can_merge = False   # may the tile at write_pos - 1 still merge?

            for j in range(self.width):
                tile = board[i][j]
                if tile == 0:
                    continue

                if can_merge and new_board[i][write_pos - 1] == tile:
                    merged = 2 * int(tile)
                    new_board[i][write_pos - 1] = merged
                    inc_score += merged
                    # Lock the merged tile for the rest of this move.
                    can_merge = False
                    valid = True
                else:
                    new_board[i][write_pos] = tile
                    if write_pos != j:
                        valid = True
                    write_pos += 1
                    can_merge = True

        return new_board, valid, inc_score

    def check_game_over(self):
        """Return True if ``WIN_TILE`` is on the board or no move is possible."""
        board = np.asarray(self.board)
        if np.any(board == self.WIN_TILE):
            return True

        # Rotating by k quarter turns and merging left probes direction k.
        for k in range(4):
            _, valid, _ = self.merge_left(np.rot90(board, k=k))
            if valid:
                return False
        return True

    def step(self, action):
        """Apply ``action`` and return ``(observation, reward, done, info)``.

        If the previous episode has ended, the board is reset first and the
        action is applied to the new game.  An action outside the action
        space, or a move that changes nothing, leaves the board untouched and
        yields reward 0 together with the cached observation.  Otherwise the
        reward is the score gained by the merges of this move, and a random
        tile is added afterwards.
        """
        if self.is_game_over:
            self.reset()

        if action not in self.action_space:
            self.prv_move = '-'
            return self.Matrix, 0, self.is_game_over, {}

        action = int(action)
        self.prv_move = self.ACTION_STRING[action]

        # Rotate so that the requested direction becomes "left", merge, and
        # rotate back afterwards.
        rotated = np.rot90(self.board, k=action)
        merged, valid, inc_score = self.merge_left(rotated)
        if not valid:
            return self.Matrix, 0, self.is_game_over, {}

        # Copy so the board is a plain array rather than a rotated view.
        self.board = np.rot90(merged, k=-action).copy()
        self.score += inc_score
        self.move_cnt += 1
        self.add_random()
        self.is_game_over = self.check_game_over()

        return self.stack(), inc_score, self.is_game_over, {}

In [None]:
env = Base2048Env()
obs = env.reset()

# Smoke test: play ten randomly sampled moves and print the board after each.
for _ in range(10):
    observation, reward, done, info = env.step(env.action_space.sample())
    env.render()

In [None]:
from stable_baselines3 import PPO

def make_env():
    """Return a newly constructed ``Base2048Env`` with default arguments."""
    return Base2048Env()

env = make_env()
env.reset()

# PPO hyperparameters, gathered in one mapping and passed through as keywords.
ppo_params = {
    "n_steps": 1024,
    "batch_size": 32,
    "gamma": 0.9,
    "learning_rate": 0.0001,
    "ent_coef": 0.01,
    "n_epochs": 32,
}
model = PPO('MlpPolicy', env, verbose=1, **ppo_params)

In [None]:
# Training budget, in environment steps.
TIMESTEPS = 1000 * 1024

model.learn(
    total_timesteps=TIMESTEPS,
    reset_num_timesteps=False,
    tb_log_name="PPO",
)

# The checkpoint file is named after the number of timesteps trained.
checkpoint_path = f"models/PPO/{TIMESTEPS}"
model.save(checkpoint_path)

In [None]:
import torch

env = Base2048Env()
observation = env.reset()
env.render()
done = False

# Roll out the trained policy for at most 1000 moves, stopping at game over.
for _turn in range(1000):
    # Add a leading batch axis before querying the policy, then take the
    # single action back out of the batched prediction.
    batched_obs = torch.FloatTensor(observation).unsqueeze(0)
    predicted, _ = model.predict(batched_obs)
    action = predicted[0]

    observation, reward, done, info = env.step(action)
    env.render()
    if done:
        break