# Attempt at making a custom TorchRL environment for 2048

## Helper functions

In [1]:
import torch
from tensordict import TensorDict
from torchrl.envs import GymEnv, EnvBase
from torchrl.envs.utils import check_env_specs, step_mdp
from torchrl.data import Bounded, Composite, Categorical, Binary
import numpy as np
import random

BOARD_SIZE = 4  # the board is BOARD_SIZE x BOARD_SIZE cells
# Action codes understood by move(): 0 = up, 1 = right, 2 = down, 3 = left.
ACTIONS = list(range(4))

def _rand_index(rng, n):
    """Return a uniform random index in ``range(n)`` drawn from *rng*.

    Supports ``numpy.random.Generator`` (``integers``), the stdlib ``random``
    module / ``random.Random`` (``randrange``) and
    ``numpy.random.RandomState`` (``randint``, whose upper bound is exclusive).
    """
    if hasattr(rng, "integers"):
        return int(rng.integers(n))
    if hasattr(rng, "randrange"):
        return rng.randrange(n)
    return int(rng.randint(n))


def add_tile(board, rng=None):
    """Place a new tile on a random empty cell of *board*, in place.

    Tiles are stored as exponents of two: the new tile is 1 (a "2") with
    probability 0.9, otherwise 2 (a "4").

    Args:
        board: 2-D numpy array; 0 marks an empty cell. Modified in place.
        rng: optional random source used for both the cell choice and the
            tile value, so a caller (e.g. a seeded environment) gets
            reproducible boards. Accepts a ``numpy.random.RandomState``, a
            ``numpy.random.Generator`` or a ``random.Random``. When ``None``
            the global ``random`` module is used, as before.

    Returns:
        The same ``board`` object (unchanged if it has no empty cell).
    """
    empty = list(zip(*np.where(board == 0)))
    if not empty:  # board full: nothing to place
        return board
    if rng is None:
        y, x = random.choice(empty)
        roll = random.random()
    else:
        # Index into the list instead of rng.choice: RandomState/Generator
        # .choice would treat the list of (y, x) tuples as a 2-D array.
        y, x = empty[_rand_index(rng, len(empty))]
        roll = rng.random()
    board[y, x] = 1 if roll < 0.9 else 2
    return board

def move_left(board):
    """Slide and merge every row of *board* toward column 0 (a 2048 "left" move).

    Cells hold tile exponents (0 = empty). Two equal adjacent tiles of
    exponent k merge into one tile of exponent k + 1 and score 2 ** (k + 1);
    a tile produced by a merge does not merge again in the same move.

    Works for any 2-D board shape (rows are taken from *board* itself rather
    than a fixed size), so rotated or non-square boards are handled.

    Args:
        board: 2-D numpy array of tile exponents. Not modified.

    Returns:
        ``(new_board, reward)``: a new array of the same shape and dtype, and
        the total score of all merges in this move (an ``int``).
    """
    new_board = np.zeros_like(board)
    reward = 0
    for r, row in enumerate(board):
        tiles = row[row != 0]  # non-empty tiles, order preserved
        merged = []
        i = 0
        while i < len(tiles):
            if i + 1 < len(tiles) and tiles[i] == tiles[i + 1]:
                merged.append(tiles[i] + 1)
                reward += 2 ** int(tiles[i] + 1)
                i += 2  # both tiles consumed by the merge
            else:
                merged.append(tiles[i])
                i += 1
        new_board[r, :len(merged)] = merged
    return new_board, reward

def move(board, direction):
    """Apply one 2048 move to *board* and return ``(new_board, reward)``.

    direction: 0 = up, 1 = right, 2 = down, 3 = left. Every direction is
    reduced to a left move: the board is transformed so the requested
    direction points left, slid/merged with ``move_left``, and transformed back.

    Raises:
        ValueError: if *direction* matches none of the codes above.
    """
    def identity(b):
        return b

    # (direction code, transform before move_left, inverse transform after it)
    transforms = (
        (0, lambda b: np.rot90(b, 1), lambda b: np.rot90(b, -1)),  # up
        (2, lambda b: np.rot90(b, -1), lambda b: np.rot90(b)),     # down
        (3, identity, identity),                                   # left
        (1, np.fliplr, np.fliplr),                                 # right
    )
    for code, to_left, from_left in transforms:
        if direction == code:
            shifted, reward = move_left(to_left(board))
            return from_left(shifted), reward
    raise ValueError("Invalid direction")

def is_game_over(board):
    """Return True when no action in ``ACTIONS`` changes *board*.

    Each action is simulated with ``move``; evaluation stops at the first
    action that alters the board (a legal move exists).
    """
    return all(
        np.array_equal(move(board, action)[0], board) for action in ACTIONS
    )

## TorchRL Env 

In [2]:
class Game2048Env(EnvBase):
    """TorchRL environment for a single game of 2048.

    The board is a ``BOARD_SIZE x BOARD_SIZE`` numpy int array whose cells hold
    tile exponents (0 = empty). The observation is that board as float32, the
    action is one of ``ACTIONS``, and the reward is the merge score returned
    by ``move``. ``done`` becomes True once no action can change the board.

    NOTE(review): the env keeps a single board in ``self.board``, so only the
    default (unbatched) ``batch_size`` is actually supported.
    """

    def __init__(self, device="cpu", batch_size=None):
        super().__init__(device=device, batch_size=batch_size)
        # Use the batch size as stored by EnvBase (also used for the returned
        # TensorDicts) so the composite specs always agree with the env.
        self.observation_spec = Composite(
            observation=Bounded(
                low=0.0,
                high=18.0,  # tile exponent: max tile is 2^18 in perfect conditions
                shape=(BOARD_SIZE, BOARD_SIZE),
                dtype=torch.float32,
                device=device,
            ),
            shape=self.batch_size,
        )
        self.action_spec = Categorical(
            n=len(ACTIONS),
            shape=(),
            dtype=torch.int64,
            device=device,
        )
        self.reward_spec = Bounded(
            low=0.0,
            high=float('inf'),
            shape=(1,),
            dtype=torch.float32,
            device=device,
        )
        self.done_spec = Composite(
            done=Binary(
                shape=(1,),
                dtype=torch.bool,
                device=device,
            ),
            shape=self.batch_size,
        )
        self.board = None
        self.rng = None  # numpy RandomState for tile spawning; set by _set_seed
        self._set_seed(None)

    def _observation(self):
        """Return the current board as a float32 tensor on ``self.device``."""
        # copy() yields a C-contiguous array: move() returns rotated/flipped
        # numpy views with negative strides, which torch.from_numpy rejects.
        return torch.from_numpy(self.board.copy().astype(np.float32)).to(self.device)

    def _reset(self, tensordict=None):
        """Start a new game: an empty board with two random tiles.

        ``tensordict`` is not read. Returns the initial observation with
        ``done`` set to False.
        """
        self.board = np.zeros((BOARD_SIZE, BOARD_SIZE), dtype=int)
        self.board = add_tile(self.board, self.rng)
        self.board = add_tile(self.board, self.rng)
        return TensorDict(
            {
                "observation": self._observation(),
                "done": torch.zeros(1, dtype=torch.bool, device=self.device),
            },
            batch_size=self.batch_size,
            device=self.device,
        )

    def _step(self, tensordict):
        """Apply ``tensordict["action"]`` and return observation, reward and done.

        A new tile is spawned (from ``self.rng``, like in ``_reset``) only when
        the move actually changed the board.
        """
        action = tensordict["action"].item()
        old_board = self.board.copy()
        self.board, reward = move(self.board, action)
        if not np.array_equal(old_board, self.board):
            self.board = add_tile(self.board, self.rng)
        done = is_game_over(self.board)
        return TensorDict(
            {
                "observation": self._observation(),
                "reward": torch.tensor(
                    [float(reward)], dtype=torch.float32, device=self.device
                ),
                # Shape (1,), matching done_spec and the tensor from _reset.
                "done": torch.tensor([done], dtype=torch.bool, device=self.device),
            },
            batch_size=self.batch_size,
            device=self.device,
        )

    def _set_seed(self, seed):
        """Seed tile spawning (``self.rng``) and the global torch RNG.

        When *seed* is None a fresh seed is drawn from numpy's global RNG.
        Returns the seed actually used.
        """
        if seed is None:
            seed = np.random.randint(0, 2**31 - 1)
        self.rng = np.random.RandomState(seed)
        torch.manual_seed(seed)
        return seed

In [6]:
env = Game2048Env()
td = env.reset()
print("\nInitial State:")
print(td["observation"].numpy())
print(f"Done: {td['done'].item()}")

# Ten manual steps cycling through the action codes 0, 1, 2, 3
# (up, right, down, left).
for step in range(10):
    action = torch.tensor(step % 4, dtype=torch.int64)
    td["action"] = action
    print(f"\nStep {step + 1}, Action: {action}")
    td = env.step(td)
    nxt = td["next"]
    print("td params:", td.keys())
    print("next params:", nxt.keys())
    print(nxt["observation"].numpy())
    print(f"Reward: {nxt['reward'].item()}")
    print(f"Done: {nxt['done'].item()}")
    td = env.step_mdp(td)  # move to the next step


Initial State:
[[1. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 2.]]
Done: False

Step 1, Action: 0
td params: _StringKeys(dict_keys(['observation', 'done', 'terminated', 'action', 'next']))
next params: _StringKeys(dict_keys(['observation', 'reward', 'done', 'terminated']))
[[1. 0. 0. 2.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 1.]]
Reward: 0.0
Done: False

Step 2, Action: 1
td params: _StringKeys(dict_keys(['observation', 'action', 'done', 'terminated', 'next']))
next params: _StringKeys(dict_keys(['observation', 'reward', 'done', 'terminated']))
[[0. 0. 1. 2.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [1. 0. 0. 1.]]
Reward: 0.0
Done: False

Step 3, Action: 2
td params: _StringKeys(dict_keys(['observation', 'action', 'done', 'terminated', 'next']))
next params: _StringKeys(dict_keys(['observation', 'reward', 'done', 'terminated']))
[[1. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 2.]
 [1. 0. 1. 1.]]
Reward: 0.0
Done: False

Step 4, Action: 3
td params: _StringKeys(dict_keys(['observation', 'ac