In [1]:
import torch
from torch import nn
import torch.optim as optim
import numpy as np
import random

In [2]:
class NN2048(nn.Module):
    """Value network used to score 2048 afterstates.

    Input: a batch of boards shaped (N, input_size, 4, 4); ``make_input`` builds
    (16, 4, 4) one-hot planes, one plane per tile exponent. Output: (N, 1).

    Architecture: 1x1 convs (conv_1 .. conv_6) mix the input planes and feed a set
    of branches of small row (k,1), column (1,k) and square (2,2) convolutions,
    plus a max-pooled 1x1 branch (conv_5) and a plain 1x1 branch (conv_6).
    Each branch is flattened and mapped to a scalar by its own linear head
    (eleven heads, W_*), and the scalars are summed.
    The head input sizes are hard-wired for a 4x4 board. For example, conv_a (2,1)
    followed by conv_aa (2,1) leaves a 2x4 map, hence ``filter2 * 8``.

    Args:
        input_size: number of input planes (channels).
        filter1: channels produced by the first-level branch convolutions.
        filter2: channels produced by the second-level branch convolutions.
        drop_prob: dropout probability applied to every flattened branch output
            before its linear head. Dropout is only active in train mode, and the
            default 0. leaves outputs unchanged. (Previously this argument was
            accepted but ignored.)
    """

    def __init__(self, input_size=16, filter1=128, filter2=1024, drop_prob=0.):
        super().__init__()
        # Module names and creation order are kept as-is so that saved state_dicts
        # still load and parameter initialisation consumes the RNG in the same order.

        # 1x1 mixing convolutions feeding the branches below.
        self.conv_1 = nn.Conv2d(in_channels=input_size, out_channels=input_size, kernel_size=(1,1), padding=0)
        self.conv_2 = nn.Conv2d(in_channels=input_size, out_channels=input_size, kernel_size=(1,1), padding=0)
        self.conv_3 = nn.Conv2d(in_channels=input_size, out_channels=input_size, kernel_size=(1,1), padding=0)
        self.conv_4 = nn.Conv2d(in_channels=input_size, out_channels=input_size, kernel_size=(1,1), padding=0)
        self.conv_5 = nn.Conv2d(in_channels=input_size, out_channels=input_size, kernel_size=(1,1), padding=0)
        self.conv_6 = nn.Conv2d(in_channels=input_size, out_channels=input_size, kernel_size=(1,1), padding=0)

        # First-level branch convolutions: column windows (k,1), row windows (1,k), 2x2 squares.
        self.conv_a = nn.Conv2d(in_channels=input_size, out_channels=filter1, kernel_size=(2,1), padding=0)
        self.conv_a3 = nn.Conv2d(in_channels=input_size, out_channels=filter1, kernel_size=(3,1), padding=0)
        self.conv_a4 = nn.Conv2d(in_channels=input_size, out_channels=filter1, kernel_size=(4,1), padding=0)
        self.conv_b = nn.Conv2d(in_channels=input_size, out_channels=filter1, kernel_size=(1,2), padding=0)
        self.conv_b3 = nn.Conv2d(in_channels=input_size, out_channels=filter1, kernel_size=(1,3), padding=0)
        self.conv_b4 = nn.Conv2d(in_channels=input_size, out_channels=filter1, kernel_size=(1,4), padding=0)
        self.conv_c = nn.Conv2d(in_channels=input_size, out_channels=filter1, kernel_size=(2,2), padding=0)

        # Second-level convolutions stacked on the first-level branches.
        self.conv_aa = nn.Conv2d(in_channels=filter1, out_channels=filter2, kernel_size=(2,1), padding=0)
        self.conv_ab = nn.Conv2d(in_channels=filter1, out_channels=filter2, kernel_size=(1,2), padding=0)
        self.conv_ba = nn.Conv2d(in_channels=filter1, out_channels=filter2, kernel_size=(2,1), padding=0)
        self.conv_bb = nn.Conv2d(in_channels=filter1, out_channels=filter2, kernel_size=(1,2), padding=0)

        self.conv_ab3 = nn.Conv2d(in_channels=filter1, out_channels=filter2, kernel_size=(1,3), padding=0)
        self.conv_ba3 = nn.Conv2d(in_channels=filter1, out_channels=filter2, kernel_size=(3,1), padding=0)
        self.conv_ab4 = nn.Conv2d(in_channels=filter1, out_channels=filter2, kernel_size=(1,4), padding=0)
        self.conv_ba4 = nn.Conv2d(in_channels=filter1, out_channels=filter2, kernel_size=(4,1), padding=0)
        self.conv_c2 = nn.Conv2d(in_channels=filter1, out_channels=filter2, kernel_size=(2,2), padding=0)
        self.pool = nn.MaxPool2d(2)

        self.relu = nn.ReLU()
        # Linear heads; in_features = channels * spatial positions left on a 4x4 board.
        self.W_aa = nn.Linear(filter2 * 8, 1)   # 4x4 -(2,1)-> 3x4 -(2,1)-> 2x4
        self.W_ab = nn.Linear(filter2 * 9, 1)   # 4x4 -(2,1)-> 3x4 -(1,2)-> 3x3
        self.W_ba = nn.Linear(filter2 * 9, 1)   # 4x4 -(1,2)-> 4x3 -(2,1)-> 3x3
        self.W_bb = nn.Linear(filter2 * 8, 1)   # 4x4 -(1,2)-> 4x3 -(1,2)-> 4x2
        self.W_ab3 = nn.Linear(filter2 * 4, 1)  # 4x4 -(3,1)-> 2x4 -(1,3)-> 2x2
        self.W_ba3 = nn.Linear(filter2 * 4, 1)  # 4x4 -(1,3)-> 4x2 -(3,1)-> 2x2
        self.W_ab4 = nn.Linear(filter2 * 1, 1)  # 4x4 -(4,1)-> 1x4 -(1,4)-> 1x1
        self.W_ba4 = nn.Linear(filter2 * 1, 1)  # 4x4 -(1,4)-> 4x1 -(4,1)-> 1x1
        self.W_c = nn.Linear(filter2 * 1, 1)    # 4x4 -(2,2)-> 3x3 -(2,2)-> 2x2 -pool-> 1x1
        self.W_5 = nn.Linear(input_size * 4, 1)   # pooled 2x2 map of the 1x1 branch
        self.W_6 = nn.Linear(input_size * 16, 1)  # full 4x4 map of the 1x1 branch

        # Dropout has no parameters or buffers, so adding it changes neither the
        # state_dict keys nor the initialisation of the layers above.
        self.dropout = nn.Dropout(p=drop_prob)

    def flatten(self, x):
        """Collapse every dimension after the batch dimension."""
        N = x.size()[0]
        return x.view(N, -1)

    def forward(self, x):
        """Return the predicted value of each board in ``x`` as an (N, 1) tensor."""
        x = x.float()
        x1 = self.relu(self.conv_1(x))
        x2 = self.relu(self.conv_2(x))
        x3 = self.relu(self.conv_3(x))
        x4 = self.relu(self.conv_4(x))
        x5 = self.flatten(self.relu(self.conv_5(self.pool(x))))
        x6 = self.flatten(self.relu(self.conv_6(x)))

        a = self.relu(self.conv_a(x1))
        b = self.relu(self.conv_b(x1))
        c = self.relu(self.conv_c(x2))
        a3 = self.relu(self.conv_a3(x3))
        b3 = self.relu(self.conv_b3(x3))
        a4 = self.relu(self.conv_a4(x4))
        b4 = self.relu(self.conv_b4(x4))

        aa = self.flatten(self.relu(self.conv_aa(a)))
        ab = self.flatten(self.relu(self.conv_ab(a)))
        ba = self.flatten(self.relu(self.conv_ba(b)))
        bb = self.flatten(self.relu(self.conv_bb(b)))

        ab3 = self.flatten(self.relu(self.conv_ab3(a3)))
        ba3 = self.flatten(self.relu(self.conv_ba3(b3)))
        ab4 = self.flatten(self.relu(self.conv_ab4(a4)))
        ba4 = self.flatten(self.relu(self.conv_ba4(b4)))
        c2 = self.relu(self.conv_c2(c))
        c3 = self.flatten(self.pool(c2))

        # Dropout on each flattened feature vector; identity when p == 0 or in eval mode.
        d = self.dropout
        out = self.W_aa(d(aa)) + self.W_ab(d(ab)) + self.W_ba(d(ba)) + self.W_bb(d(bb)) + \
              self.W_ab4(d(ab4)) + self.W_ba4(d(ba4)) + self.W_c(d(c3)) + \
              self.W_ab3(d(ab3)) + self.W_ba3(d(ba3)) + self.W_5(d(x5)) + self.W_6(d(x6))

        return out

In [3]:
def make_input(grid):
    """One-hot encode a 4x4 board of tile exponents into (16, 4, 4) float planes.

    Plane ``k`` holds a 1 at every cell whose exponent is ``k``; all other
    entries are 0.
    """
    rows, cols = np.indices((4, 4))
    one_hot = np.zeros(shape=(16, 4, 4))
    one_hot[grid[rows, cols], rows, cols] = 1
    return one_hot

def add_two(mat):
    """Place exponent 1 (a "2" tile) on a uniformly random empty (zero) cell.

    ``mat`` is modified in place and also returned. When no cell is empty,
    np.random.randint(0, 0) raises ValueError.
    """
    empty_cells = np.transpose(np.nonzero(mat == 0))
    chosen = empty_cells[np.random.randint(0, len(empty_cells))]
    mat[tuple(chosen)] = 1
    return mat

In [4]:
# singleScore[k] is the score credited for a tile of exponent k (value 2**k):
# (k - 1) * 2**k, the total merge score of building that tile out of 2-tiles.
# Index 0 is an empty cell and scores 0.
singleScore = [0] + [(k - 1) * 2**k for k in range(1, 17)]
# Precomputed single-line move table; move() indexes it by four tile exponents.
moveDict = np.load('move.npy')

def move(list):
    """Look up the result of sliding one 4-cell line toward its first cell.

    The four tile exponents in ``list`` index the first four axes of the
    precomputed ``moveDict`` table; everything after those axes is returned.
    """
    key = tuple(list[k] for k in range(4)) + (slice(None),)
    return moveDict[key]

# Element-wise "tile exponent -> singleScore[exponent]" mapping; np.vectorize lets
# the plain list lookup be applied to a whole board array at once.
lookup = np.vectorize(lambda exponent: singleScore[exponent])

def getScore(matrix):
    """Return the game score implied by a board of tile exponents.

    Each cell with exponent k contributes singleScore[k]. This uses NumPy fancy
    indexing rather than np.vectorize, which runs a Python-level loop per cell
    (this function is evaluated for every candidate move). An empty integer
    board returns 0 instead of raising.
    """
    score_table = np.asarray(singleScore)
    return np.sum(score_table[np.asarray(matrix)])

def getMove(grid):
    """List the legal moves available from ``grid``.

    Returns (new_grid, direction, score) tuples for each direction 0..3 whose
    resulting board differs from ``grid``; score is getScore(new_grid).
    """
    candidates = ((moveGrid(grid, direction), direction) for direction in range(4))
    return [(after, direction, getScore(after))
            for after, direction in candidates
            if not isSame(grid, after)]
        
def moveGrid(grid,i):
    """Return the 4x4 board produced by sliding ``grid`` in direction ``i``.

    Directions: 0 = up, 1 = left, 2 = down, 3 = right. Each row (or each column,
    for up/down) is passed through ``move``, which slides a line toward index 0.
    Right/down reverse each line before and after ``move``. The result is a new
    int array; any other ``i`` returns None.
    """
    if i not in (0, 1, 2, 3):
        return None
    vertical = i in (0, 2)   # up/down act on columns, i.e. rows of the transpose
    reverse = i in (2, 3)    # down/right slide toward the last index
    lines = np.transpose(grid) if vertical else grid
    shifted = []
    for r in range(4):
        line = lines[r, :]
        shifted.append(np.flip(move(np.flip(line))) if reverse else move(line))
    board = np.stack(shifted, axis=0).astype(int)
    return board.T if vertical else board

def isSame(grid1,grid2):
    """Return True when grid1 and grid2 agree element-wise (NumPy broadcasting applies)."""
    matches = grid1 == grid2
    return np.all(matches)

In [5]:
def Vchange(grid, v):
    """Augment one training example with the 8 symmetries of a square board.

    ``grid`` is (channels, H, W) with H == W. Axes 1 and 2 are flipped and
    swapped to give, in order: identity, row flip, column flip, both flips, and
    the same four applied to the transposed board.

    Returns:
        (xtrain, ytrain): an array stacking the 8 boards, and an array repeating
        the target ``v`` 8 times.
    """
    variants = []
    for base in (grid, grid.swapaxes(1, 2)):
        col_flipped = base[:, :, ::-1]
        variants += [base, base[:, ::-1, :], col_flipped, col_flipped[:, ::-1, :]]
    return np.array(variants), np.array([v] * len(variants))

def gen_sample_and_learn(model, optimizer, loss_fn, is_train = False, explorationProb=0.1):
    """Play one game of 2048 with ``model`` choosing moves, optionally learning online.

    Boards are grids of tile exponents (0 = empty, k = tile 2**k). At every step,
    each legal move's afterstate is scored as the score gained by the move plus
    the model's predicted value, and the best move is taken. When training, a
    uniformly random legal move is taken instead with probability
    ``explorationProb``.

    When ``is_train`` is set, the previous afterstate is regressed toward that
    best value, over its 8 board symmetries, with one optimizer step per move.
    The target becomes 0 once no move is left. This is a TD-style update:
    target = reward + predicted value of the next afterstate.

    Args:
        model: value network; inputs are sent to the device its parameters live on.
        optimizer: optimizer over ``model``'s parameters (unused unless is_train).
        loss_fn: loss between predictions and targets (unused unless is_train).
        is_train: whether to update the model while playing.
        explorationProb: epsilon for epsilon-greedy exploration (training only).

    Returns:
        (game_len, max_tile, game_score, last_loss): the number of moves played,
        the largest tile value on the final board, the final score, and the loss
        of the last training step (0 when not training).
    """
    # Follow the model's device instead of hard-coding CUDA.
    device = next(model.parameters()).device
    model.eval()
    game_len = 0
    game_score = 0
    # `np.int` was removed in NumPy 1.24; the builtin `int` is the documented replacement.
    last_grid1 = add_two(np.zeros((4, 4), dtype=int))
    last_grid2 = make_input(last_grid1)
    last_loss = 0

    while True:
        # add_two mutates its argument, so grid_array and last_grid1 are the same array.
        grid_array = add_two(last_grid1)
        board_list = getMove(grid_array)
        if board_list:
            boards = np.array([make_input(g) for g, m, s in board_list])
            # Scoring candidates needs no autograd graph.
            with torch.no_grad():
                values = model(torch.from_numpy(boards).to(device)).flatten().tolist()
            game_len += 1
            best_v = None
            for i, (g, m, s) in enumerate(board_list):
                # Reward earned by this move plus the predicted value of its afterstate.
                v = (s - game_score) + values[i]
                if best_v is None or v > best_v:
                    best_v = v
                    best_score = s
                    best_grid1 = g
                    best_grid2 = boards[i]
        else:
            # No legal move: the game is over and the terminal target is 0.
            best_v = 0
            best_grid1 = None
            best_grid2 = None

        if is_train:
            # Fit V(previous afterstate) toward best_v on all 8 symmetric copies.
            x, y = Vchange(last_grid2, best_v)
            x = torch.from_numpy(x).to(device)
            y = torch.from_numpy(y).unsqueeze(dim=1).to(device).float()
            model.train()
            optimizer.zero_grad()
            pred = model(x)
            # NOTE(review): the /4 loss scaling is unexplained in the source; kept as-is.
            loss = loss_fn(pred, y) / 4
            last_loss = loss.item()
            loss.backward()
            optimizer.step()
            model.eval()

        if not board_list:
            break

        # epsilon-greedy exploration (only while training)
        if is_train and random.random() < explorationProb:
            idx = random.randint(0, len(board_list) - 1)
            game_score = board_list[idx][2]
            last_grid1 = board_list[idx][0]
            last_grid2 = boards[idx]
        else:
            game_score = best_score
            last_grid1 = best_grid1
            last_grid2 = best_grid2

    return game_len, 2**grid_array.max(), game_score, last_loss

In [6]:
# Optimiser hyperparameters for the value network.
lr = 1e-3
weight_decay = 1e-6
beta1 = 0.8

# Use the GPU when one is available; otherwise fall back to CPU instead of
# failing inside .cuda().
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = NN2048().to(device)
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay, betas=(beta1, 0.999))
# Passed to train() as the loss function for the TD regression.
loss=nn.MSELoss()

In [7]:
num_epochs = 500

def train(model, optimizer, loss, epochs=None, exploration_prob=0):
    """Play self-play training games and print one summary line per game.

    Args:
        model, optimizer, loss: forwarded to gen_sample_and_learn; ``loss`` is its loss_fn.
        epochs: number of games to play. None (default) reads the global
            ``num_epochs`` at call time, matching the previous behaviour.
        exploration_prob: epsilon for epsilon-greedy move selection; 0 means greedy.

    Prints ``epoch <i> <game_len> <max_tile> <game_score> <last_loss>`` per game.
    """
    total = num_epochs if epochs is None else epochs
    # range() terminates for any count; `while epoch != num_epochs` looped forever
    # when the count was negative.
    for epoch in range(1, int(total) + 1):
        game_len, max_tile, game_score, last_loss = gen_sample_and_learn(
            model, optimizer, loss, True, exploration_prob)
        print('epoch', epoch, game_len, max_tile, game_score, last_loss)
    
# Self-play training for num_epochs games with greedy move selection (train() passes exploration probability 0).
train(model, optimizer, loss)

epoch 1 103 64 740 9871229.0
epoch 2 118 64 928 3312308.5
epoch 3 157 128 1452 2230751.75
epoch 4 172 128 1576 1470255.625
epoch 5 207 128 1932 1086375.375
epoch 6 163 128 1492 826240.375
epoch 7 209 128 2000 760496.125
epoch 8 425 512 5628 3388011.25
epoch 9 216 256 2408 552132.5
epoch 10 253 256 2776 652378.0625
epoch 11 328 256 3824 1228282.25
epoch 12 434 256 5276 2206250.5
epoch 13 319 256 3708 691608.5625
epoch 14 360 256 4180 1004677.75
epoch 15 240 256 2684 606612.5
epoch 16 491 512 6732 1481501.5
epoch 17 368 256 4184 905483.0
epoch 18 465 512 6464 1924369.0
epoch 19 505 512 6976 2630374.0
epoch 20 603 512 8112 2274636.5
epoch 21 390 256 4668 1052487.875
epoch 22 463 512 6036 2056889.75
epoch 23 166 128 1536 420576.28125
epoch 24 371 256 4284 658851.5625
epoch 25 308 256 3488 677235.75
epoch 26 211 128 2076 433458.03125
epoch 27 411 256 4784 599067.0
epoch 28 197 128 1840 385357.5625
epoch 29 192 128 1744 192557.171875
epoch 30 250 256 2756 330656.0625
epoch 31 232 128 2304 21

epoch 246 1033 1024 16416 11471766.0
epoch 247 531 512 7260 10119644.0
epoch 248 557 512 7544 4643081.0
epoch 249 292 256 3356 5100809.0
epoch 250 813 1024 12672 7455433.0
epoch 251 874 1024 13372 4952278.5
epoch 252 925 1024 14784 24434740.0
epoch 253 572 512 7680 6636346.0
epoch 254 1080 1024 17052 29712848.0
epoch 255 776 1024 12224 15956453.0
epoch 256 531 512 7260 13164506.0
epoch 257 601 512 8104 11093490.0
epoch 258 1050 1024 16564 12632287.0
epoch 259 340 256 3936 10183050.0
epoch 260 1020 1024 16240 19750808.0
epoch 261 413 256 5056 10624762.0
epoch 262 914 1024 14680 17902844.0
epoch 263 452 512 5928 5736547.0
epoch 264 539 512 7340 6697895.0
epoch 265 311 256 3516 4873318.5
epoch 266 1009 1024 16024 5677433.0
epoch 267 851 1024 13164 16597510.0
epoch 268 369 256 4248 6903490.0
epoch 269 448 512 5904 5102643.5
epoch 270 537 512 7328 4599804.0
epoch 271 443 512 5868 5102554.0
epoch 272 795 1024 12472 12834579.0
epoch 273 722 1024 11600 13047177.0
epoch 274 932 1024 14880 56893

epoch 476 795 1024 12460 117312464.0
epoch 477 965 1024 15372 133862768.0
epoch 478 1896 2048 34764 374880672.0
epoch 479 1550 2048 27708 215254464.0
epoch 480 1206 2048 22188 190834656.0
epoch 481 836 1024 12884 142538432.0
epoch 482 978 1024 15452 113361776.0
epoch 483 1027 1024 16348 79086016.0
epoch 484 950 1024 15036 65243432.0
epoch 485 829 1024 12808 62757164.0
epoch 486 817 1024 12696 57203120.0
epoch 487 1102 1024 17256 85984112.0
epoch 488 1763 2048 31984 85559088.0
epoch 489 1552 2048 27728 114693392.0
epoch 490 1097 1024 17200 88213160.0
epoch 491 985 1024 15528 60734248.0
epoch 492 686 512 9384 71895776.0
epoch 493 1110 1024 17308 7214030.0
epoch 494 1058 1024 16648 35746096.0
epoch 495 1555 2048 27756 76306048.0
epoch 496 1116 1024 17352 23094772.0
epoch 497 1114 1024 17388 16948196.0
epoch 498 1243 1024 19948 25731956.0
epoch 499 1574 2048 27740 34177696.0
epoch 500 1045 1024 16504 59592392.0


In [8]:
# Continue training the same model and optimizer for another num_epochs games;
# train() reads the global num_epochs when it is called.
num_epochs = 500

train(model, optimizer, loss)

epoch 1 2014 2048 36360 188960064.0
epoch 2 546 512 7388 105737840.0
epoch 3 1155 1024 18156 54898752.0
epoch 4 979 1024 15708 84467648.0
epoch 5 689 512 9372 670892.5
epoch 6 1052 1024 16576 3077875.0
epoch 7 785 1024 12368 36728720.0
epoch 8 953 1024 15200 32232240.0
epoch 9 1253 1024 19964 9722018.0
epoch 10 820 1024 12712 6649119.0
epoch 11 1582 2048 28044 42463048.0
epoch 12 1053 1024 16584 43931804.0
epoch 13 1480 2048 26144 94424368.0
epoch 14 658 512 9556 92066496.0
epoch 15 1566 2048 27852 122995936.0
epoch 16 1233 1024 19080 20592682.0
epoch 17 1996 2048 36156 398927392.0
epoch 18 1364 2048 24424 144793472.0
epoch 19 1022 1024 16140 91061368.0
epoch 20 1498 2048 27020 213220416.0
epoch 21 945 1024 15372 403020992.0
epoch 22 2066 2048 37000 210853312.0
epoch 23 1816 2048 32940 750493440.0
epoch 24 1311 2048 23804 567252864.0
epoch 25 1505 2048 27096 357948160.0
epoch 26 1578 2048 28012 177642176.0
epoch 27 1187 1024 18508 137677504.0
epoch 28 582 512 7916 226482496.0
epoch 29 

epoch 223 1432 2048 25996 6065531392.0
epoch 224 1074 1024 16924 1686042880.0
epoch 225 1927 2048 34588 1951083904.0
epoch 226 784 1024 12364 3992686848.0
epoch 227 2045 2048 36744 1245356160.0
epoch 228 1022 1024 16264 1603339776.0
epoch 229 1817 2048 32936 1329981824.0
epoch 230 1836 2048 33144 622635264.0
epoch 231 1902 2048 34780 1960112384.0
epoch 232 1802 2048 32816 3066117120.0
epoch 233 1549 2048 27676 801712896.0
epoch 234 1749 2048 32104 1546373888.0
epoch 235 969 1024 15212 647094528.0
epoch 236 1953 2048 35292 590236544.0
epoch 237 1554 2048 27740 803318656.0
epoch 238 2012 2048 36292 1063100416.0
epoch 239 1813 2048 32912 1720778368.0
epoch 240 549 512 7416 1211834624.0
epoch 241 847 1024 13132 436735104.0
epoch 242 603 512 8172 296181568.0
epoch 243 1628 2048 28616 422208192.0
epoch 244 1593 2048 28140 124422272.0
epoch 245 1928 2048 35100 1704463360.0
epoch 246 1027 1024 16348 805739776.0
epoch 247 1135 1024 17932 489972032.0
epoch 248 1089 1024 17052 1251975.875
epoch 2

epoch 437 2062 2048 36940 3851333120.0
epoch 438 1885 2048 34048 3571223552.0
epoch 439 1042 1024 16476 3425934848.0
epoch 440 1421 2048 25916 4096365056.0
epoch 441 1060 1024 16632 948045696.0
epoch 442 1966 2048 35412 142237632.0
epoch 443 2061 2048 36924 1318673408.0
epoch 444 1422 2048 25416 2350256128.0
epoch 445 1939 2048 35208 1343001856.0
epoch 446 2003 2048 36220 1423047936.0
epoch 447 2055 2048 36892 1080984064.0
epoch 448 1305 2048 23688 2530927104.0
epoch 449 1935 2048 35148 792673920.0
epoch 450 2029 2048 36488 1119353344.0
epoch 451 1808 2048 32848 6366258176.0
epoch 452 1065 1024 16864 3243029504.0
epoch 453 1940 2048 35128 1228801536.0
epoch 454 925 1024 14844 4487952384.0
epoch 455 919 1024 14716 3229803520.0
epoch 456 857 512 12428 112800664.0
epoch 457 1811 2048 32868 528680064.0
epoch 458 1025 1024 16276 438739712.0
epoch 459 1815 2048 32924 242790304.0
epoch 460 2034 2048 36636 233969024.0
epoch 461 1129 1024 17888 455596896.0
epoch 462 1931 2048 35116 323034752.0


In [9]:
num_epochs = 100

def test(model):
    epoch = 0
    while epoch != num_epochs:
        epoch += 1
        game_len, max_score, game_score, last_loss = gen_sample_and_learn(model, None, None, False)
        print ('epoch', epoch, game_len, max_score, game_score, last_loss)

# Evaluation: num_epochs greedy games with no learning, so the printed last_loss is always 0.
test(model)

epoch 1 2026 2048 36412 0
epoch 2 2040 2048 36680 0
epoch 3 1060 1024 16660 0
epoch 4 801 1024 12488 0
epoch 5 727 1024 11628 0
epoch 6 1042 1024 16476 0
epoch 7 796 1024 12476 0
epoch 8 925 1024 14280 0
epoch 9 1552 2048 27728 0
epoch 10 1801 2048 32800 0
epoch 11 1547 2048 27692 0
epoch 12 1054 1024 16588 0
epoch 13 1038 1024 16456 0
epoch 14 1031 1024 16408 0
epoch 15 1864 2048 33568 0
epoch 16 1809 2048 32856 0
epoch 17 1822 2048 32972 0
epoch 18 1553 2048 27736 0
epoch 19 1823 2048 32980 0
epoch 20 569 512 7628 0
epoch 21 1617 2048 29264 0
epoch 22 2048 2048 36760 0
epoch 23 1042 1024 16476 0
epoch 24 1044 1024 16488 0
epoch 25 1553 2048 27760 0
epoch 26 1004 1024 15984 0
epoch 27 1972 2048 35884 0
epoch 28 1041 1024 16472 0
epoch 29 2080 2048 37264 0
epoch 30 1559 2048 27804 0
epoch 31 1488 2048 26956 0
epoch 32 1813 2048 32888 0
epoch 33 1914 2048 34892 0
epoch 34 1544 2048 27648 0
epoch 35 1058 1024 16620 0
epoch 36 975 1024 15676 0
epoch 37 1054 1024 16588 0
epoch 38 980 1024 

In [10]:
import os
# Checkpoint is written to <experiment_dir>/<filename>.
experiment_dir = "model"
filename = "model6.pth.tar"
# Stored in the checkpoint as 'epoch'; the two training cells above each ran 500 games.
num_epochs = 1000

def save_model(state, filename='model.pth.tar'):
    """Serialize ``state`` with torch.save to ``<experiment_dir>/<filename>``.

    ``experiment_dir`` (a module-level setting) is created if it does not exist
    yet; previously torch.save failed on a fresh checkout without that directory.
    """
    os.makedirs(experiment_dir, exist_ok=True)
    path = os.path.join(experiment_dir, filename)
    torch.save(state, path)

# Save CPU copies of the weights without moving the live model: model.cpu()
# relocates the model in place, which breaks later cells that feed it CUDA inputs.
cpu_state_dict = model.state_dict()
for name in list(cpu_state_dict):
    cpu_state_dict[name] = cpu_state_dict[name].cpu()

save_model({
    'epoch': num_epochs,
    'state_dict': cpu_state_dict,
    'optimizer': optimizer.state_dict(),
}, filename)