In [1]:
import torch
from torch import nn
import torch.optim as optim
import numpy as np
import random

In [2]:
class NN2048(nn.Module):
    """Value network for 2048 boards encoded as one-hot exponent planes.

    Two first-level convolutions (a vertical 2x1 and a horizontal 1x2 kernel)
    each feed two second-level convolutions (again 2x1 and 1x2), giving four
    feature maps.  Each map is flattened, projected to a scalar by its own
    linear head, and the four scalars are summed into the output of shape (N, 1).

    Args:
        input_size: number of input channels (one plane per tile exponent).
        filter1: output channels of each first-level convolution.
        filter2: output channels of each second-level convolution.
        drop_prob: accepted for interface compatibility but currently unused.

    The linear head sizes (``filter2 * 8`` / ``filter2 * 9``) assume a 4x4 board.
    """

    def __init__(self, input_size=16, filter1=512, filter2=4096, drop_prob=0.):
        super(NN2048, self).__init__()
        vertical, horizontal = (2, 1), (1, 2)
        # First level: one vertical and one horizontal two-cell convolution.
        self.conv_a = nn.Conv2d(input_size, filter1, kernel_size=vertical, padding=0)
        self.conv_b = nn.Conv2d(input_size, filter1, kernel_size=horizontal, padding=0)
        # Second level: both kernel orientations applied to each first-level map.
        self.conv_aa = nn.Conv2d(filter1, filter2, kernel_size=vertical, padding=0)
        self.conv_ab = nn.Conv2d(filter1, filter2, kernel_size=horizontal, padding=0)
        self.conv_ba = nn.Conv2d(filter1, filter2, kernel_size=vertical, padding=0)
        self.conv_bb = nn.Conv2d(filter1, filter2, kernel_size=horizontal, padding=0)
        self.relu = nn.ReLU()
        # Spatial sizes on a 4x4 board: aa -> 2x4 = 8, ab -> 3x3 = 9, ba -> 3x3 = 9, bb -> 4x2 = 8.
        self.W_aa = nn.Linear(filter2 * 8, 1)
        self.W_ab = nn.Linear(filter2 * 9, 1)
        self.W_ba = nn.Linear(filter2 * 9, 1)
        self.W_bb = nn.Linear(filter2 * 8, 1)

    def flatten(self, x):
        """Collapse every dimension after the leading batch dimension."""
        return x.view(x.size(0), -1)

    def forward(self, x):
        """Return the summed scalar value of each board in batch ``x`` as shape (N, 1)."""
        x = x.float()
        a = self.relu(self.conv_a(x))
        b = self.relu(self.conv_b(x))
        branches = (
            (self.conv_aa, a, self.W_aa),
            (self.conv_ab, a, self.W_ab),
            (self.conv_ba, b, self.W_ba),
            (self.conv_bb, b, self.W_bb),
        )
        out = None
        for conv, features, head in branches:
            term = head(self.flatten(self.relu(conv(features))))
            # Accumulate left to right: ((aa + ab) + ba) + bb.
            out = term if out is None else out + term
        return out

In [3]:
def make_input(grid):
    """One-hot encode a 4x4 board of tile exponents as a (16, 4, 4) float array.

    Plane ``k`` holds a 1 at every cell whose exponent is ``k`` (exponent 0 is an
    empty cell), and 0 elsewhere.
    """
    rows, cols = np.indices((4, 4))
    planes = np.zeros(shape=(16, 4, 4))
    # Scatter a 1 into plane grid[r, c] at position (r, c) for all 16 cells at once.
    planes[grid[rows, cols], rows, cols] = 1
    return planes

def add_two(mat):
    """Place a tile with exponent 1 on a uniformly random empty (zero) cell of ``mat``.

    ``mat`` is modified in place and also returned.  Raises ``ValueError`` (from
    ``np.random.randint``) when there is no empty cell.
    """
    empty_cells = np.argwhere(mat == 0)
    chosen = empty_cells[np.random.randint(0, len(empty_cells))]
    mat[tuple(chosen)] = 1
    return mat

In [4]:
# singleScore[k]: game score accumulated while building a 2**k tile purely from
# 2-tiles, i.e. (k - 1) * 2**k for k >= 1, and 0 for an empty cell (k = 0).
singleScore = [0] + [(k - 1) * 2 ** k for k in range(1, 17)]
# Precomputed move table loaded from disk; `move` indexes it with a row's four
# tile exponents and takes the remaining axis as the moved row.
moveDict = np.load(file='move.npy')

def move(list):
    """Look up the result of moving one 4-cell row toward index 0 in ``moveDict``.

    ``list`` holds the row's four tile exponents (only the first four entries are
    read); the returned value is the table's remaining axis for that row.
    """
    first, second, third, fourth = (list[k] for k in range(4))
    return moveDict[first, second, third, fourth, :]

def lookup(x):
    """Map tile exponent(s) ``x`` to their ``singleScore`` values, elementwise.

    Accepts a scalar, nested list or array of integer exponents.  Array input
    gives an array of the same shape; scalar input gives a NumPy scalar.

    Uses one NumPy fancy-indexing operation instead of an ``np.vectorize``
    wrapper, which is a Python-level loop and was hit for every candidate move
    on every step.  Also accepts empty input, where ``np.vectorize`` raised.
    """
    return np.asarray(singleScore)[np.asarray(x)]

def getScore(matrix):
    """Total ``singleScore`` value over every cell of the exponent board ``matrix``."""
    per_cell = lookup(matrix)
    return per_cell.sum()

def getMove(grid):
    """List every move that changes ``grid`` as ``(new_grid, direction, score)`` tuples.

    Directions are tried in order 0..3 (see ``moveGrid``) and ``score`` is
    ``getScore(new_grid)``.  An empty list means no move changes the board.
    """
    # Lazily pair each direction with its resulting board so that moveGrid,
    # isSame and getScore run once per direction, in direction order.
    moved = ((moveGrid(grid, direction), direction) for direction in range(4))
    return [(after, direction, getScore(after))
            for after, direction in moved
            if not isSame(grid, after)]
        
def moveGrid(grid, i):
    """Return a new integer board after moving all tiles of ``grid`` in direction ``i``.

    ``i``: 0 = up, 1 = left, 2 = down, 3 = right; any other value returns ``None``.
    Each row is processed by ``move``, which moves toward index 0 (left).
    Right/down moves mirror each row around that call.  Vertical moves work on
    the transposed board and transpose the result back.
    """
    if i not in (0, 1, 2, 3):
        return None
    vertical = i == 0 or i == 2
    toward_end = i == 2 or i == 3
    lines = np.transpose(grid) if vertical else grid

    def shift(line):
        if toward_end:
            return np.flip(move(np.flip(line)))
        return move(line)

    moved = np.stack([shift(lines[row, :]) for row in range(4)], axis=0).astype(int)
    return moved.T if vertical else moved

def isSame(grid1,grid2):
    """True when every element of ``grid1`` equals the matching element of ``grid2``."""
    # "all equal" == "none differ"
    return not np.any(grid1 != grid2)

In [5]:
def Vchange(grid, v):
    """Expand one (C, H, W) board into its 8 mirror/transpose variants with a shared target.

    Variants, in order: identity, flip H, flip W, flip H and W, then the same four
    applied to the H/W-transposed board.  Returns ``(xtrain, ytrain)`` where
    ``xtrain`` stacks the 8 boards and ``ytrain`` is ``[v] * 8`` as an array.
    """
    transposed = grid.swapaxes(1, 2)
    variants = []
    for base in (grid, transposed):
        flipped_cols = base[:, :, ::-1]
        variants.extend([base, base[:, ::-1, :], flipped_cols, flipped_cols[:, ::-1, :]])
    xtrain = np.array(variants)
    ytrain = np.array([v] * 8)
    return xtrain, ytrain

def _td_update(model, optimizer, loss_fn, state, target, device):
    """One optimizer step regressing ``model`` on the 8 symmetries of ``state`` toward ``target``; returns the loss value."""
    x, y = Vchange(state, target)
    x = torch.from_numpy(x).to(device)
    y = torch.from_numpy(y).unsqueeze(dim=1).to(device).float()
    model.train()
    optimizer.zero_grad()
    pred = model(x)
    loss = loss_fn(pred, y) / 2
    loss_value = loss.item()
    loss.backward()
    optimizer.step()
    model.eval()
    return loss_value


def gen_sample_and_learn(model, optimizer, loss_fn, is_train = False, explorationProb=0.1):
    """Play one game of 2048 with a greedy afterstate-value policy, optionally learning online.

    At each step every move that changes the board is scored as
    ``(score after the move - current score) + model(board after the move)``, and
    the highest-scoring move is played.  When ``is_train`` is true:

    * the model's value for the previously chosen post-move board is regressed
      toward that best score (0 once no move is left), using MSE / 2 over the
      board's 8 symmetries;
    * with probability ``explorationProb`` a uniformly random move is played
      instead of the best one (epsilon-greedy).

    Args:
        model: value network; assumed to map a batch of (16, 4, 4) one-hot boards
            to shape (N, 1).  Inputs are sent to the device of its parameters.
        optimizer: optimizer over ``model``'s parameters (only used when training).
        loss_fn: loss called as ``loss_fn(pred, target)`` (only used when training).
        is_train: whether to update the model and explore.
        explorationProb: probability of a random move while training.

    Returns:
        ``(game_len, max_tile, game_score, last_loss)``: number of moves played,
        the largest tile value on the final board, the accumulated score, and the
        loss of the last update (0 when not training).
    """
    device = next(model.parameters()).device
    model.eval()
    game_len = 0
    game_score = 0
    # New game: one random tile now, the second one at the top of the loop.
    last_grid1 = add_two(np.zeros((4, 4), dtype=int))
    last_grid2 = make_input(last_grid1)
    last_loss = 0

    while True:
        # Spawn a random tile on the board left by the previous move.
        grid_array = add_two(last_grid1)
        board_list = getMove(grid_array)
        if board_list:
            boards = np.array([make_input(g) for g, m, s in board_list])
            # Inference only: no autograd graph, and a single device->host transfer.
            with torch.no_grad():
                values = model(torch.from_numpy(boards).to(device)).flatten().tolist()
            game_len += 1
            best_v = None
            for i, (g, m, s) in enumerate(board_list):
                v = (s - game_score) + values[i]
                if best_v is None or v > best_v:
                    best_v = v
                    best_score = s
                    best_grid1 = g
                    best_grid2 = boards[i]
        else:
            # Terminal position: no move changes the board, target value is 0.
            best_v = 0
            best_grid1 = None
            best_grid2 = None

        if is_train:
            last_loss = _td_update(model, optimizer, loss_fn, last_grid2, best_v, device)

        if not board_list:
            break

        # Epsilon-greedy: occasionally play a uniformly random move while training.
        if is_train and random.random() < explorationProb:
            idx = random.randint(0, len(board_list) - 1)
            game_score = board_list[idx][2]
            last_grid1 = board_list[idx][0]
            last_grid2 = boards[idx]
        else:
            game_score = best_score
            last_grid1 = best_grid1
            last_grid2 = best_grid2

    return game_len, 2**grid_array.max(), game_score, last_loss

In [6]:
# Training hyperparameters, read by `train` when it is called.
num_epochs, lr = 200, 0.001
weight_decay = 0  # Adam L2 penalty; the original notebook kept 1e-5 as a commented-out alternative

def train(model, epochs=None):
    """Train ``model`` by self-play, one game of 2048 per epoch, printing per-game stats.

    Args:
        model: value network, updated in place.
        epochs: number of games to play.  Defaults to the notebook-level
            ``num_epochs`` read at call time (as before).  A value <= 0 plays no
            games; the old ``while epoch != num_epochs`` loop never terminated then.
    """
    if epochs is None:
        epochs = num_epochs
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay, betas=(0.5, 0.999))
    loss_fn = nn.MSELoss()
    for epoch in range(1, int(epochs) + 1):
        game_len, max_score, game_score, last_loss = gen_sample_and_learn(model, optimizer, loss_fn, True)
        print('epoch', epoch, game_len, max_score, game_score, last_loss)
    
# Use the GPU when present; otherwise fall back to the CPU (much slower with the
# default 512/4096 filters) instead of failing on `.cuda()`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = NN2048().to(device)
train(model)

epoch 1 161 128 1484 74637272.0
epoch 2 208 128 1996 11621043.0
epoch 3 158 128 1456 6250300.0
epoch 4 233 128 2368 4457950.5
epoch 5 418 512 5652 11269672.0
epoch 6 240 128 2432 3352961.0
epoch 7 229 256 2528 3913269.0
epoch 8 231 256 2564 2083026.5
epoch 9 337 256 3808 6051264.5
epoch 10 555 512 7484 7848240.5
epoch 11 421 256 5112 7212492.5
epoch 12 405 512 5472 8995341.0
epoch 13 469 512 6248 4431961.0
epoch 14 293 256 3304 2544085.25
epoch 15 261 256 2976 3569450.5
epoch 16 296 256 3344 1650875.5
epoch 17 543 512 7384 16167885.0
epoch 18 593 512 8024 7118882.0
epoch 19 557 512 7552 13999809.0
epoch 24 438 512 5828 18028396.0
epoch 25 554 512 7528 48800396.0
epoch 26 274 256 3156 29821330.0
epoch 27 433 512 5776 30421734.0
epoch 28 279 256 3180 14801079.0
epoch 29 480 512 6232 12242843.0
epoch 30 647 512 8956 36532196.0
epoch 31 711 1024 11076 94463824.0
epoch 32 314 256 3560 57427928.0
epoch 33 357 256 4112 20142252.0
epoch 34 550 512 7448 39173344.0
epoch 35 788 1024 12384 922016

KeyboardInterrupt: 

In [7]:
# NOTE: this rebinds the notebook-level `num_epochs`, which `train` also reads
# when called without an explicit count.
num_epochs = 50


def test(model, epochs=None):
    """Evaluate ``model`` greedily (no learning, no exploration), printing per-game stats.

    Args:
        model: value network to evaluate.
        epochs: number of games to play.  Defaults to the notebook-level
            ``num_epochs`` read at call time.  A value <= 0 plays no games; the
            old ``while epoch != num_epochs`` loop never terminated then.
    """
    if epochs is None:
        epochs = num_epochs
    for epoch in range(1, int(epochs) + 1):
        game_len, max_score, game_score, last_loss = gen_sample_and_learn(model, None, None, False)
        print('epoch', epoch, game_len, max_score, game_score, last_loss)


test(model)

epoch 1 556 512 7536 0
epoch 2 411 512 5544 0
epoch 3 561 512 7600 0
epoch 4 946 1024 15004 0
epoch 5 530 512 7244 0
epoch 6 548 512 7468 0
epoch 7 1502 2048 27080 0
epoch 8 1068 1024 16712 0
epoch 9 726 1024 11408 0
epoch 10 994 1024 15920 0
epoch 11 602 512 8108 0
epoch 12 612 512 8432 0
epoch 13 492 512 6764 0
epoch 14 575 512 7696 0
epoch 15 1068 1024 16880 0
epoch 16 289 256 3280 0
epoch 17 549 512 7584 0
epoch 18 407 512 5484 0
epoch 19 987 1024 15800 0
epoch 20 1041 1024 16480 0
epoch 21 845 1024 13120 0
epoch 22 507 512 6892 0
epoch 23 505 512 6880 0
epoch 24 527 512 7240 0
epoch 25 515 512 7068 0
epoch 26 1033 1024 16416 0
epoch 27 859 1024 13244 0
epoch 28 1023 1024 16268 0
epoch 29 1233 1024 19800 0
epoch 30 814 1024 12652 0
epoch 31 1324 1024 20896 0
epoch 32 595 512 8044 0
epoch 33 811 1024 12652 0
epoch 34 511 512 7044 0
epoch 35 652 512 9020 0
epoch 36 541 512 7360 0
epoch 37 1066 1024 16844 0
epoch 38 277 256 3168 0
epoch 39 549 512 7416 0
epoch 40 787 1024 12380 0
epoc